/// TF_Tensor wrapper.
module tfd.tensor;

import std.traits : isScalarType;
import std.typecons : tuple;

import mir.ndslice.slice : Slice, SliceKind;
import mir.rc.slim_ptr : createSlimRC, SlimRCPtr;

import tfd.c_api;
version (Windows) alias size_t = object.size_t;

import tfd.testing : assertStatus;

// TODO(karita): support all dtypes in TF

/// Compile-time pair associating a D type with a TensorFlow dtype.
struct TFDPair(_D, TF_DataType _tf) {
    /// TensorFlow dtype (a TF_DataType enum value).
    enum tf = _tf;
    /// Corresponding D type.
    alias D = _D;
}

import std.meta : AliasSeq;
import std.complex : Complex;
import std.numeric : CustomFloat;

/// IEEE 754-2008 half: https://en.wikipedia.org/wiki/Half-precision_floating-point_format
alias half = CustomFloat!(10, 5);
/// bfloat16: https://en.wikipedia.org/wiki/Bfloat16_floating-point_format
alias bfloat16 = CustomFloat!(7, 8);

/// Supported D/TF type pairs, used to generate `tfType`.
/// Each D type must appear at most once; otherwise `tfType` receives two
/// identical overloads and instantiating it for that type is ambiguous.
alias tfdTypePairList = AliasSeq!(
    TFDPair!(float, TF_FLOAT),
    TFDPair!(double, TF_DOUBLE),
    TFDPair!(int, TF_INT32),
    TFDPair!(ubyte, TF_UINT8),
    TFDPair!(short, TF_INT16),
    TFDPair!(byte, TF_INT8),

    // TODO(karita): specialize this for conversion.
    // NOTE(review): `empty!string` sizes the buffer as `string.sizeof` per
    // element, which presumably does not match TF_STRING's layout -- confirm.
    TFDPair!(string, TF_STRING),

    // NOTE: TF_COMPLEX is TensorFlow's legacy name for TF_COMPLEX64 (both have
    // the same enum value), so Complex!float is listed only once. Listing it
    // twice would make `tfType!(Complex!float)` ambiguous.
    TFDPair!(Complex!float, TF_COMPLEX64),

    TFDPair!(long, TF_INT64),
    TFDPair!(bool, TF_BOOL),

    // TODO(karita)
    // TF_QINT8 = 11,
    // TF_QUINT8 = 12,
    // TF_QINT32 = 13,

    TFDPair!(bfloat16, TF_BFLOAT16),

    // TODO(karita)
    // TF_QINT16 = 15,
    // TF_QUINT16 = 16,

    TFDPair!(ushort, TF_UINT16),
    TFDPair!(Complex!double, TF_COMPLEX128),
    TFDPair!(half, TF_HALF),

    // TODO(karita): no D counterpart is mapped for these dtypes yet.
    // TF_RESOURCE = 20,
    // TF_VARIANT = 21,

    TFDPair!(uint, TF_UINT32),
    TFDPair!(ulong, TF_UINT64),
);

static foreach (type; tfdTypePairList)
{
    /// TF_DataType corresponding to the D type `T`, e.g. `tfType!float == TF_FLOAT`.
    enum tfType(T: type.D) = type.tf;
}

///
version (tfd_test)
unittest
{
    static assert(tfType!float == TF_FLOAT);
    static assert(tfType!int == TF_INT32);
    static assert(tfType!(Complex!float) == TF_COMPLEX64);
    static assert(tfType!(Complex!double) == TF_COMPLEX128);
}

/// Creates an uninitialized tensor with dtype of T.
///
/// Params:
///     dims = length of each dimension; must be non-negative (checked by assertion).
/// Returns: a tensor allocated with `TF_AllocateTensor`. The caller owns it and
///     must release it with `TF_DeleteTensor`, or pass it to `tensor` so that it
///     is released automatically.
@trusted
TF_Tensor* empty(T, size_t N)(long[N] dims...)
{
    size_t numValues = 1;
    foreach (d; dims)
    {
        // A negative length would silently wrap around once converted to size_t.
        assert(d >= 0, "tensor dimensions must be non-negative");
        numValues *= cast(size_t) d;
    }

    static if (N == 0)
    {
        // A scalar has no dimensions; pass null instead of a zero-length array pointer.
        const dimsPtr = null;
    }
    else
    {
        const dimsPtr = dims.ptr;
    }
    return TF_AllocateTensor(
        tfType!T, dimsPtr, N, T.sizeof * numValues);
}


/// Creates a tensor from a given range.
///
/// The tensor element type is the unqualified element type of `source`, so
/// ranges of `const` or `immutable` elements are accepted as well.
///
/// Params:
///     dims = length of each dimension.
///     source = values copied into the tensor in order; it must provide exactly
///         as many elements as `dims` describes (checked by assertion).
/// Returns: a newly allocated tensor owned by the caller.
TF_Tensor* makeTF_Tensor(size_t N, SourceRange)(long[N] dims, SourceRange source)
{
    import std.algorithm.mutation : copy;
    import std.range.primitives : ElementType;
    import std.traits : Unqual;

    // Strip qualifiers: the tensor buffer is written to, so it cannot be const.
    alias T = Unqual!(ElementType!SourceRange);
    TF_Tensor* t = empty!T(dims);
    T[] target = (cast(T*) TF_TensorData(t))[0 .. TF_TensorElementCount(t)];
    // `copy` returns the part of `target` it did not fill; anything left over
    // would remain uninitialized memory inside the tensor.
    auto unfilled = copy(source, target);
    assert(unfilled.length == 0, "source range is shorter than the tensor");
    return t;
}

/// Ranges of immutable elements are accepted.
version (tfd_test)
@nogc nothrow
unittest
{
    static immutable int[3] values = [1, 2, 3];
    long[1] dims = [3];
    auto t = makeTF_Tensor(dims, values[]);
    scope (exit) TF_DeleteTensor(t);
    assert(TF_TensorType(t) == TF_INT32);
    assert((cast(int*) TF_TensorData(t))[0 .. 3] == values[]);
}

/// Creates a tensor from a given mir.ndslice.Slice.
///
/// The tensor has the same shape as `slice`. Elements are copied in the order
/// produced by `mir.ndslice.topology.flattened`.
TF_Tensor* makeTF_Tensor(Iterator, size_t N, SliceKind kind)(Slice!(Iterator, N, kind) slice)
{
    import mir.ndslice.topology : flattened;

    long[N] shape;
    static foreach (i; 0 .. N)
    {
        shape[i] = slice.length!i;
    }
    return makeTF_Tensor(shape, slice.flattened);
}

///
version (tfd_test)
@nogc nothrow
unittest
{
    import mir.ndslice : iota, universal;

    const slice = iota(2, 3);
    auto tensor = slice.makeTF_Tensor;
    scope (exit) TF_DeleteTensor(tensor);
}

/// Creates a 0-dimensional (scalar) tensor holding `scalar`.
TF_Tensor* makeTF_Tensor(T)(T scalar) if (isScalarType!T)
{
    long[0] dims;
    return makeTF_Tensor(dims, (&scalar)[0 .. 1]);
}

///
version (tfd_test)
@nogc nothrow
unittest
{
    auto t = makeTF_Tensor(0);
    scope (exit) TF_DeleteTensor(t);
    assert(TF_TensorType(t) == TF_INT32);
}

/// Tensor freed by dtor (RAII) with convenient accessor methods.
/// Copying is disabled; share it through the ref-counted `Tensor` alias.
struct TensorOwner
{
    import mir.ndslice.slice : Contiguous;
    import mir.rc.array : RCArray;

    /// Base pointer, owned by this struct and released in the dtor.
    TF_Tensor* base;
    alias base this;

    // Not copyable: a copy would call TF_DeleteTensor on `base` a second time.
    @disable this(this);

    /// Dtor. Releases `base` with `TF_DeleteTensor`.
    @nogc nothrow @trusted
    ~this()
    {
        TF_DeleteTensor(this.base);
    }

    /// Returns the tensor data buffer viewed as `elementCount` elements of `T`.
    /// Asserts that `T` corresponds to the tensor's data type.
    @trusted
    inout(T)[] payload(T)() inout
    {
        assert(tfType!T == this.dataType);
        auto tp = cast(inout(T)*) TF_TensorData(this.base);
        return tp[0 .. this.elementCount];
    }

    /// Returns the number of elements.
    @nogc nothrow @trusted
    long elementCount() const
    {
        return TF_TensorElementCount(this.base);
    }

    /// Returns the number of dimensions.
    @nogc nothrow @trusted
    int ndim() const
    {
        return TF_NumDims(this.base);
    }

    /// Returns the tensor shape as a ref-counted array of `ndim` lengths.
    @nogc nothrow @trusted
    RCArray!long shape() const
    {
        int rank = this.ndim;
        auto ret = RCArray!long(rank);
        foreach (i; 0 .. rank)
        {
            ret[i] = TF_Dim(this.base, i);
        }
        return ret;
    }

    /// Returns data type, i.e. element type enum identifier.
    @nogc nothrow @trusted
    TF_DataType dataType() const
    {
        return TF_TensorType(this.base);
    }

    /// Returns a slice over this tensor's data with the same rank, lengths and
    /// element type as `slice`, asserting that rank and lengths match.
    Slice!(T*, N, Contiguous)
    slicedAs(Iterator, size_t N, SliceKind kind, T = typeof(Iterator.init[0]))(
        Slice!(Iterator, N, kind) slice)
    {
        // Check the rank first so the shape lookups below stay in bounds.
        assert(this.ndim == N);
        // Compute the shape once instead of allocating it per dimension.
        auto dims = this.shape;
        static foreach (i; 0 .. N)
        {
            assert(dims[i] == slice.length!i);
        }
        return cast(typeof(return)) this.sliced!(T, N);
    }

    /// Returns a contiguous `N`-dimensional slice over this tensor's data with
    /// element type `T`. Asserts that the tensor has `N` dimensions and (through
    /// `payload`) that `T` matches its data type.
    Slice!(T*, N, Contiguous) sliced(T, size_t N)()
    {
        import mir.ndslice.slice : sliced;

        assert(this.ndim == N);

        // Compute the shape once instead of allocating it per dimension.
        auto dims = this.shape;
        size_t[N] lengths;
        static foreach (i; 0 .. N)
        {
            lengths[i] = dims[i];
        }
        return this.payload!T.sliced(lengths);
    }

    /// Returns a reference to the single element of a 0-dimensional tensor,
    /// viewed as `T`. Asserts that the tensor is 0-dimensional.
    ref inout(T) scalar(T)() inout
    {
        assert(this.ndim == 0);
        return this.payload!T[0];
    }
}

/// Ref-counted handle (mir `SlimRCPtr`) owning a `TensorOwner`.
alias Tensor = SlimRCPtr!TensorOwner;

/// Allocates ref-counted (RC) Tensor. `args` are forwarded to a
/// `makeTF_Tensor` overload (a scalar, a mir slice, or dims plus a range).
/// TODO(karita): non-allocated (borrowed) version.
@trusted
Tensor tensor(Args ...)(Args args)
{
    import core.lifetime : forward;
    return createSlimRC!TensorOwner(makeTF_Tensor(forward!args));
}

/// Wraps an existing tensor in a ref-counted Tensor. Ownership of `t` moves to
/// the result: `TensorOwner`'s dtor deletes it, so the caller must not call
/// `TF_DeleteTensor` on it.
@nogc nothrow @trusted
Tensor tensor(TF_Tensor* t)
{
    return createSlimRC!TensorOwner(t);
}

/// Make a scalar RCTensor.
version (tfd_test)
@nogc nothrow @safe
unittest
{
    const t = tensor(123);

    // check content
    assert(t.payload!int[0] == 123);
    assert(t.ndim == 0);
    assert(t.elementCount == 1);
    assert(t.shape.length == 0);
    assert(t.dataType == TF_INT32);

    // check RC
    assert(t._counter == 1);
    {
        const t1 = t;
        assert(t._counter == 2);
        assert(t1._counter == 2);
    }
    assert(t._counter == 1);
}

/// Make an empty multi-dim RCTensor.
version (tfd_test)
@nogc nothrow @safe
unittest
{
    const t = createSlimRC!TensorOwner(empty!double(1, 2, 3));

    // check content
    assert(t.ndim == 3);
    static immutable expectedShape = [1, 2, 3];
    assert(t.shape[] == expectedShape);
    assert(t.dataType == TF_DOUBLE);
}

/// Make a Tensor from iota slice.
version (tfd_test)
@nogc nothrow @safe
unittest
{
    import mir.ndslice : iota, sliced;

    auto s = iota(2, 3);
    auto t = s.tensor;
    const st = t.slicedAs(s);
    assert(t.dataType == TF_INT64);
    assert(t.shape[] == s.shape);
    assert(st == s);
}