1 /// TF_Tensor wrapper.
2 module tfd.tensor;
3 
4 import std.traits : isScalarType;
5 import std.typecons : tuple;
6 
7 import mir.ndslice.slice : Slice, SliceKind;
8 import mir.rc.slim_ptr : createSlimRC, SlimRCPtr;
9 
10 import tfd.c_api;
11 version (Windows) alias size_t = object.size_t;
12 
13 import tfd.testing : assertStatus;
14 
15 // TODO(karita): support all dtypes in TF
16 
/// Compile-time pairing of a D element type with its TensorFlow dtype.
///
/// Params:
///   _D  = the D type used for elements on the D side.
///   _tf = the matching TF_DataType enum value.
struct TFDPair(_D, TF_DataType _tf)
{
  /// D element type.
  alias D = _D;

  /// TensorFlow dtype (TF_DataType enum value).
  enum tf = _tf;
}
24 
25 import std.meta : AliasSeq;
26 import std.complex : Complex;
27 import std.numeric : CustomFloat;
28 
/// 16-bit IEEE 754-2008 half precision (10 mantissa bits, 5 exponent bits).
/// See https://en.wikipedia.org/wiki/Half-precision_floating-point_format
alias half = CustomFloat!(/* precision */ 10, /* exponentWidth */ 5);

/// 16-bit brain floating point (7 mantissa bits, 8 exponent bits).
/// See https://en.wikipedia.org/wiki/Bfloat16_floating-point_format
alias bfloat16 = CustomFloat!(/* precision */ 7, /* exponentWidth */ 8);
33 
/// All supported (D type, TF_DataType) pairs.
///
/// The `tfType` templates are generated from this list (one specialization
/// per D type), so each D type must appear at most once; a repeated D type
/// would make `tfType!ThatType` ambiguous.
alias tfdTypePairList = AliasSeq!(
    TFDPair!(float, TF_FLOAT),
    TFDPair!(double, TF_DOUBLE),
    TFDPair!(int, TF_INT32),
    TFDPair!(ubyte, TF_UINT8),
    TFDPair!(short, TF_INT16),
    TFDPair!(byte, TF_INT8),

    // TODO(karita): specialize this for conversion.
    TFDPair!(string, TF_STRING),

    // NOTE: TF_COMPLEX is only the legacy name of TF_COMPLEX64 (same enum
    // value in the TF C API), so it is intentionally not listed. A second
    // TFDPair!(Complex!float, ...) entry would generate two identical
    // `tfType(T : Complex!float)` overloads and make
    // `tfType!(Complex!float)` ambiguous.
    TFDPair!(Complex!float, TF_COMPLEX64),

    TFDPair!(long, TF_INT64),
    TFDPair!(bool, TF_BOOL),

    // TODO(karita)
    // TF_QINT8 = 11,
    // TF_QUINT8 = 12,
    // TF_QINT32 = 13,

    TFDPair!(bfloat16, TF_BFLOAT16),

    // TODO(karita)
    // TF_QINT16 = 15,
    // TF_QUINT16 = 16,

    TFDPair!(ushort, TF_UINT16),
    TFDPair!(Complex!double, TF_COMPLEX128),
    TFDPair!(half, TF_HALF),

    // TODO(karita): what is this?
    // TF_RESOURCE = 20,
    // TF_VARIANT = 21,

    TFDPair!(uint, TF_UINT32),
    TFDPair!(ulong, TF_UINT64),
);
74 
// Generate one `tfType` overload per entry of tfdTypePairList. Each overload
// is specialized on the pair's D type and evaluates to the pair's
// TF_DataType value (e.g. tfType!float == TF_FLOAT).
static foreach (pair; tfdTypePairList)
{
  template tfType(T : pair.D)
  {
    enum tfType = pair.tf;
  }
}
79 
80 
/// Creates an uninitialized tensor with dtype of T.
///
/// Params:
///   dims = length of each dimension; pass no dims for a rank-0 (scalar)
///     tensor.
/// Returns: the TF_Tensor returned by TF_AllocateTensor, sized
///   `T.sizeof * product(dims)` bytes. Its contents are not initialized.
@trusted
TF_Tensor* empty(T, size_t N)(long[N] dims...)
{
  import core.checkedint : mulu;

  // Total element count and byte size, computed with overflow detection so
  // that a negative dimension or an oversized shape is caught instead of
  // silently wrapping into a bogus allocation request.
  size_t num_values = 1;
  bool overflow = false;
  foreach (d; dims)
  {
    assert(d >= 0, "tensor dimensions must be non-negative");
    num_values = mulu(num_values, cast(size_t) d, overflow);
  }
  const size_t num_bytes = mulu(T.sizeof, num_values, overflow);
  assert(!overflow, "tensor byte size overflows size_t");

  static if (N == 0)
  {
    // A rank-0 tensor has no dims array to pass.
    const dimsPtr = null;
  }
  else
  {
    const dimsPtr = dims.ptr;
  }
  return TF_AllocateTensor(tfType!T, dimsPtr, N, num_bytes);
}
101 
102 
/// Creates a tensor from a given range.
///
/// Params:
///   dims = shape of the new tensor.
///   source = input range whose elements are copied, in order, into the
///     tensor buffer. When its length is known it must equal the product
///     of `dims`.
/// Returns: a newly allocated TF_Tensor holding a copy of `source`; the
///   caller owns it and must release it with TF_DeleteTensor.
TF_Tensor* makeTF_Tensor(size_t N, SourceRange)(long[N] dims, SourceRange source)
{
  import std.algorithm.mutation : copy;
  import std.range.primitives : ElementType, hasLength;
  import std.traits : Unqual;

  // Strip qualifiers so e.g. a `const(int)[]` source yields an `int` tensor
  // with a writable target buffer instead of failing to compile.
  alias T = Unqual!(ElementType!SourceRange);
  TF_Tensor* t = empty!T(dims);
  T[] target = (cast(T*) TF_TensorData(t))[0 .. cast(size_t) TF_TensorElementCount(t)];
  static if (hasLength!SourceRange)
  {
    // A short source would otherwise leave part of the uninitialized buffer
    // untouched without any diagnostic.
    assert(source.length == target.length,
           "source length does not match the product of dims");
  }
  copy(source, target);
  return t;
}
115 
/// Creates a tensor from a given mir.ndslice.Slice.
///
/// The tensor gets the slice's shape, and its buffer is filled from the
/// slice's elements as enumerated by mir's `flattened`.
TF_Tensor* makeTF_Tensor(Iterator, size_t N, SliceKind kind)(Slice!(Iterator, N, kind) slice)
{
  import mir.ndslice.topology : flattened;

  long[N] dims;
  foreach (axis, len; slice.shape)
  {
    dims[axis] = len;
  }
  auto elements = slice.flattened;
  return makeTF_Tensor(dims, elements);
}
127 
///
version (tfd_test)
@nogc nothrow
unittest
{
  import mir.ndslice : iota;

  const slice = iota(2, 3);
  auto tensor = slice.makeTF_Tensor;
  scope (exit) TF_DeleteTensor(tensor);

  // The tensor mirrors the slice's shape; iota elements are long (TF_INT64).
  assert(TF_NumDims(tensor) == 2);
  assert(TF_Dim(tensor, 0) == 2);
  assert(TF_Dim(tensor, 1) == 3);
  assert(TF_TensorType(tensor) == TF_INT64);
}
139 
/// Creates a rank-0 tensor holding the single value `scalar`.
TF_Tensor* makeTF_Tensor(T)(T scalar) if (isScalarType!T)
{
  // No dimensions, exactly one element.
  long[0] noDims;
  T[1] single = scalar;
  return makeTF_Tensor(noDims, single[]);
}
146 
///
version (tfd_test)
@nogc nothrow
unittest
{
  // A D `int` scalar maps to a TF_INT32 tensor.
  TF_Tensor* zero = makeTF_Tensor(0);
  scope (exit) TF_DeleteTensor(zero);
  assert(TF_TensorType(zero) == TF_INT32);
}
156 
/// Tensor freed by dtor (RAII) with convenient accessors for its data,
/// shape and dtype.
struct TensorOwner
{
  import mir.ndslice.slice : Contiguous;
  import mir.rc.array : RCArray;

  /// Base pointer. Owned: deleted with TF_DeleteTensor by the destructor.
  TF_Tensor* base;
  alias base this;

  // Not copyable: a copy would delete the same tensor twice.
  @disable this(this);

  /// Dtor. Deletes the owned tensor; a null `base` (e.g. a
  /// default-initialized owner) is skipped.
  @nogc nothrow @trusted
  ~this()
  {
    if (this.base !is null)
    {
      TF_DeleteTensor(this.base);
    }
  }

  /// Returns the tensor buffer viewed as a `T[]` of `elementCount` elements.
  /// Asserts that `T` corresponds to the tensor's dtype.
  @trusted
  inout(T)[] payload(T)() inout
  {
    assert(tfType!T == this.dataType, "T does not match the tensor dtype");
    auto tp = cast(inout(T)*) TF_TensorData(this.base);
    return tp[0 .. cast(size_t) this.elementCount];
  }

  /// Returns the number of elements.
  @nogc nothrow @trusted
  long elementCount() const
  {
    return TF_TensorElementCount(this.base);
  }

  /// Returns the number of dimensions.
  @nogc nothrow @trusted
  int ndim() const
  {
    return TF_NumDims(this.base);
  }

  /// Returns the length of dimension `i` (0-based).
  @nogc nothrow @trusted
  long dim(int i) const
  {
    return TF_Dim(this.base, i);
  }

  /// Returns a tensor shape (one length per dimension).
  @nogc nothrow @trusted
  RCArray!long shape() const
  {
    const int rank = this.ndim;
    auto ret = RCArray!long(rank);
    foreach (int i; 0 .. rank)
    {
      ret[i] = this.dim(i);
    }
    return ret;
  }

  /// Returns data type, i.e. element type enum identifier.
  @nogc nothrow @trusted
  TF_DataType dataType() const
  {
    return TF_TensorType(this.base);
  }

  /// Returns the tensor data as a contiguous slice shaped like `slice`.
  /// Asserts that the tensor has the same rank and lengths as `slice`.
  /// The element type `T` defaults to the element type of `slice`'s iterator.
  Slice!(T*, N, Contiguous)
  slicedAs(Iterator, size_t N, SliceKind kind, T = typeof(Iterator.init[0]))(
      Slice!(Iterator, N, kind) slice)
  {
    // Check the rank first so the per-dimension checks never go past it.
    assert(this.ndim == N, "rank mismatch");
    static foreach (i; 0 .. N)
    {
      assert(this.dim(cast(int) i) == slice.length!i, "length mismatch");
    }
    return cast(typeof(return)) this.sliced!(T, N);
  }

  /// Returns the tensor data as a contiguous N-dim slice of `T`.
  /// Asserts that the tensor rank is `N`; `payload` asserts the dtype.
  Slice!(T*, N, Contiguous) sliced(T, size_t N)()
  {
    import mir.ndslice.slice : sliced;

    assert(this.ndim == N, "rank mismatch");

    // Read each length directly rather than building a whole ref-counted
    // shape array once per dimension.
    size_t[N] lengths;
    static foreach (i; 0 .. N)
    {
      lengths[i] = cast(size_t) this.dim(cast(int) i);
    }
    return this.payload!T.sliced(lengths);
  }

  /// Returns a reference to the only element of a rank-0 tensor.
  ref inout(T) scalar(T)() inout
  {
    assert(this.ndim == 0, "not a rank-0 tensor");
    return this.payload!T[0];
  }
}
253 
/// Ref-counted handle (mir `SlimRCPtr`) to a `TensorOwner`; the owner's
/// destructor calls TF_DeleteTensor on the wrapped tensor.
alias Tensor = SlimRCPtr!TensorOwner;
256 
/// Allocates a ref-counted (RC) Tensor.
///
/// All arguments are forwarded to `makeTF_Tensor` (scalar, range with dims,
/// or mir slice), and the resulting TF_Tensor is wrapped in a TensorOwner.
/// TODO(karita): non-allocated (borrowed) version.
@trusted
Tensor tensor(Args ...)(Args args)
{
  import core.lifetime : forward;

  TF_Tensor* raw = makeTF_Tensor(forward!args);
  return createSlimRC!TensorOwner(raw);
}
265 
/// Wraps an existing TF_Tensor in a ref-counted Tensor.
///
/// Ownership of `t` moves to the result: the wrapping TensorOwner calls
/// TF_DeleteTensor on it in its destructor, so the caller must not delete
/// `t` itself.
@nogc nothrow @trusted
Tensor tensor(TF_Tensor* t)
{
  auto owned = createSlimRC!TensorOwner(t);
  return owned;
}
271 
/// Make a scalar RCTensor.
version (tfd_test)
@nogc nothrow @safe
unittest
{
  const scalarTensor = tensor(123);

  // Contents: one int32 element and an empty shape.
  assert(scalarTensor.payload!int[0] == 123);
  assert(scalarTensor.ndim == 0);
  assert(scalarTensor.elementCount == 1);
  assert(scalarTensor.shape.length == 0);
  assert(scalarTensor.dataType == TF_INT32);

  // Reference counting: a copy bumps the counter, leaving its scope drops it.
  assert(scalarTensor._counter == 1);
  {
    const second = scalarTensor;
    assert(scalarTensor._counter == 2);
    assert(second._counter == 2);
  }
  assert(scalarTensor._counter == 1);
}
295 
/// Make an empty multi-dim RCTensor.
version (tfd_test)
@nogc nothrow @safe
unittest
{
  static immutable expectedShape = [1, 2, 3];

  // Wrap an uninitialized 1x2x3 double tensor.
  const cube = tensor(empty!double(1, 2, 3));

  assert(cube.ndim == 3);
  assert(cube.shape[] == expectedShape);
  assert(cube.dataType == TF_DOUBLE);
}
309 
/// Make a Tensor from iota slice.
version (tfd_test)
@nogc nothrow @safe
unittest
{
  import mir.ndslice : iota;

  auto s = iota(2, 3);
  auto t = s.tensor;
  assert(t.dataType == TF_INT64);
  assert(t.shape[] == s.shape);
  assert(t.elementCount == 6);

  // Viewing the tensor with the source slice's shape round-trips the data.
  const st = t.slicedAs(s);
  assert(st == s);
}