// -*- c-basic-style: google, c-basic-offset: 2 -*-

/// TF Graph wrapper
module tfd.graph;

import std.string : fromStringz;

import tensorflow.c_api;
import tensorflow.op_def_pb;

/**
 * Creates a new placeholder in a given graph.
 *
 * Params:
 *   N     = number of dimensions in `dims`; when 0, no "shape" attr is set.
 *   graph = graph that receives the new operation.
 *   s     = status filled by TF_FinishOperation; asserted to be TF_OK.
 *   name  = name of the new operation.
 *   dtype = value of the "dtype" attr.
 *   dims  = value of the "shape" attr (only used when `N != 0`).
 * Returns: the new operation, asserted to be non-null.
 */
@nogc nothrow @trusted
TF_Operation* Placeholder(size_t N = 0)(
    TF_Graph* graph,
    TF_Status* s,
    const(char)* name = "feed",
    TF_DataType dtype = TF_INT32,
    const long[N] dims = [])
{
  TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name);
  TF_SetAttrType(desc, "dtype", dtype);
  static if (N != 0)
  {
    TF_SetAttrShape(desc, "shape", dims.ptr, dims.length);
  }
  TF_Operation* op = TF_FinishOperation(desc, s);
  // Use the shared helper so every builder reports failures the same way.
  assertStatus(s);
  assert(op !is null);
  return op;
}

/// Protobuf-c generated message type for an operation attribute value.
alias AttrValue = Tensorflow__AttrValue;

/**
 * Gets an AttrValue from a given operation.
 *
 * Params:
 *   oper       = operation to query.
 *   attr_name  = name of the attribute to read.
 *   attr_value = receives the decoded attribute on success.
 *   s          = status filled by TF_OperationGetAttrValueProto.
 * Returns: true when the attribute was found and its serialized proto was
 *   decoded; false when the lookup failed (`s` holds the error) or when
 *   decoding failed (`s` is left as TF_OK).
 *
 * NOTE: `*attr_value` receives a shallow copy of the decoded message, so any
 * heap-allocated sub-fields it points to (strings, lists, ...) remain
 * allocated and are not released by this function.
 */
@nogc nothrow @trusted
bool GetAttrValue(
    TF_Operation* oper, const(char)* attr_name,
    AttrValue* attr_value, TF_Status* s)
{
  import core.stdc.stdlib : free;

  TF_Buffer* buffer = TF_NewBuffer();
  scope (exit) TF_DeleteBuffer(buffer);

  TF_OperationGetAttrValueProto(oper, attr_name, buffer, s);
  if (TF_GetCode(s) != TF_OK)
  {
    return false;
  }

  auto unpacked = tensorflow__attr_value__unpack(
      null,
      buffer.length,
      cast(const(ubyte)*) buffer.data);
  if (unpacked is null)
  {
    return false;
  }

  *attr_value = *unpacked;
  // The top-level message struct was previously leaked. With a null
  // allocator protobuf-c allocates it on its own via malloc, separately from
  // its sub-fields, so only that shell is released here; the sub-fields are
  // still referenced by the shallow copy in `*attr_value`.
  // NOTE(review): relies on protobuf-c's default malloc-backed allocator.
  free(unpacked);
  return true;
}


/// Asserts that `s` is TF_OK, using the status message as the assertion text.
@nogc nothrow @trusted
void assertStatus(TF_Status* s)
{
  assert(TF_GetCode(s) == TF_OK, TF_Message(s).fromStringz);
}


/// TF_DataType corresponding to the D element type `T`.
/// Currently only specialized for `int` (TF_INT32).
// TODO(karita): support all dtypes in TF
enum dtype(T: int) = TF_INT32;


/// Creates a tensor with dtype of T.
/**
 * The element count is the product of `dims`; exactly `T.sizeof` times that
 * count bytes are copied from `values` into the newly allocated tensor.
 *
 * Params:
 *   T        = element type; its TF_DataType is taken from `dtype!T`.
 *   num_dims = rank of the tensor (`0` for a scalar).
 *   dims     = extent of each dimension; every entry must be non-negative.
 *   values   = pointer to at least product(dims) elements of `T`; it is not
 *              read when the tensor has no elements.
 * Returns: a newly allocated tensor (release with TF_DeleteTensor).
 */
TF_Tensor* makeTensor(T, size_t num_dims)(
    const long[num_dims] dims, const(T)* values)
{
  import core.stdc.string : memcpy;

  size_t num_values = 1;
  foreach (d; dims)
  {
    // A negative extent (e.g. -1 for "unknown") is meaningless for a concrete
    // tensor and would wrap around when multiplied into a size_t.
    assert(d >= 0, "makeTensor: dims must be non-negative");
    num_values *= d;
  }
  const size_t num_bytes = T.sizeof * num_values;

  static if (num_dims == 0)
  {
    // Scalars have no shape array.
    auto dimsPtr = null;
  }
  else
  {
    auto dimsPtr = dims.ptr;
  }
  TF_Tensor* t = TF_AllocateTensor(dtype!T, dimsPtr, num_dims, num_bytes);
  assert(t !is null, "makeTensor: TF_AllocateTensor failed");
  // memcpy with a null source is undefined even for zero bytes, so the copy
  // is skipped for tensors without elements.
  if (num_bytes != 0)
  {
    memcpy(TF_TensorData(t), values, num_bytes);
  }
  return t;
}


/**
 * Creates a rank-0 tensor holding a given scalar.
 *
 * Returns: a newly allocated tensor (release with TF_DeleteTensor).
 */
TF_Tensor* makeTensor(T)(const(T) scalar)
{
  long[0] dims;
  return makeTensor!(T, 0)(dims, &scalar);
}


/**
 * Creates a const operation whose "value" attr is the tensor `t`.
 *
 * `t` is not deleted by this function. `s` is asserted to be TF_OK both
 * after setting the "value" attr and after finishing the operation.
 *
 * Returns: the new operation, asserted to be non-null.
 */
@nogc nothrow @trusted
TF_Operation* Const(
    TF_Tensor* t,
    TF_Graph* graph,
    TF_Status* s,
    const(char)* name = "const")
{
  TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name);
  TF_SetAttrTensor(desc, "value", t, s);
  // Report a failed tensor conversion where it happens instead of relying on
  // the later TF_FinishOperation call to surface it.
  assertStatus(s);
  TF_SetAttrType(desc, "dtype", TF_TensorType(t));
  TF_Operation* op = TF_FinishOperation(desc, s);
  assertStatus(s);
  assert(op !is null);
  return op;
}


/**
 * Creates a scalar const operation holding the int32 value `v`.
 *
 * The temporary tensor built for `v` is deleted before returning.
 */
@nogc nothrow @trusted
TF_Operation* ScalarConst(int v, TF_Graph* graph, TF_Status* s,
                          const(char)* name = "scalar")
{
  TF_Tensor* t = makeTensor(v);
  // Const() does not take ownership of `t`; the operation keeps its own copy
  // of the value once the "value" attr is set, so the temporary tensor must
  // be released here (previously leaked).
  scope (exit) TF_DeleteTensor(t);
  return Const(t, graph, s, name);
}


/// Adds two tensors.
/**
 * Builds an "AddN" operation over two inputs.
 *
 * The `TF_Operation*` overload uses output 0 of `l` and `r`, and asserts that
 * `s` is TF_OK and that the result is non-null. The `TF_Output` overload
 * returns the result of TF_FinishOperation unchecked, leaving status
 * inspection to the caller.
 */
@nogc nothrow @trusted
TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
                  TF_Status* s, const(char)* name = "add")
{
  TF_Operation* op = Add(TF_Output(l, 0), TF_Output(r, 0), graph, s, name);
  assertStatus(s);
  assert(op !is null);
  return op;
}

/// ditto
@nogc nothrow @trusted
TF_Operation* Add(TF_Output l, TF_Output r,
                  TF_Graph* graph, TF_Status* s,
                  const(char)* name = "add")
{
  TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
  TF_Output[2] inputs = [l, r];
  TF_AddInputList(desc, inputs.ptr, cast(int) inputs.length);
  return TF_FinishOperation(desc, s);
}

/// CAPI Graph test in `tensorflow/c/c_api_test.c`
unittest
{
  import std.stdio;
  writeln("CAPI Graph test");

  TF_Status* s = TF_NewStatus();
  scope (exit) TF_DeleteStatus(s);
  TF_Graph* graph = TF_NewGraph();
  scope (exit) TF_DeleteGraph(graph);

  // Make a placeholder operation.
  TF_Operation* feed = Placeholder(graph, s);
  assertStatus(s);

  // Test TF_Operation*() query functions.
  assert(TF_OperationName(feed).fromStringz == "feed");
  assert(TF_OperationOpType(feed).fromStringz == "Placeholder");
  assert(TF_OperationDevice(feed).fromStringz == "");
  assert(TF_OperationNumOutputs(feed) == 1);
  assert(TF_OperationOutputType(TF_Output(feed, 0)) == TF_INT32);
  assert(TF_OperationOutputListLength(feed, "output", s) == 1);
  assertStatus(s);
  assert(TF_OperationNumInputs(feed) == 0);
  assert(TF_OperationOutputNumConsumers(TF_Output(feed, 0)) == 0);
  assert(TF_OperationNumControlInputs(feed) == 0);
  assert(TF_OperationNumControlOutputs(feed) == 0);

  // TODO(karita): implement AttrValue type switching by `value_case`
  AttrValue attrValue;
  assert(GetAttrValue(feed, "dtype", &attrValue, s));
  assert(attrValue.type == TENSORFLOW__DATA_TYPE__DT_INT32);

  // Test not found errors in TF_Operation*() query functions.
  assert(TF_OperationOutputListLength(feed, "bogus", s) == -1);
  assert(TF_GetCode(s) == TF_INVALID_ARGUMENT);
  assert(!GetAttrValue(feed, "missing", &attrValue, s));
  assert(TF_Message(s).fromStringz ==
         "Operation 'feed' has no attr named 'missing'.");

  // Make a constant oper with the scalar "3".
  TF_Operation* three = ScalarConst(3, graph, s);
  assertStatus(s);
  // Add oper.
  Add(feed, three, graph, s);
  assertStatus(s);
}

/**
 * Owner of a TF_Session plus the feeds, fetches and targets of its next run.
 *
 * The destructor closes and deletes the session and deletes any input/output
 * tensors still held. Input tensors passed to SetInputs() are deleted by this
 * struct after Run(); output tensors stay owned by it until the next Run(),
 * SetOutputs() or CloseAndDelete().
 */
struct Session
{
  import std.container.array : Array;

  TF_Session* session_;
  alias session_ this;

  Array!TF_Output inputs_;
  Array!(TF_Tensor*) input_values_;
  Array!TF_Output outputs_;
  Array!(TF_Tensor*) output_values_;
  Array!(TF_Operation*) targets_;

  // A copy would share `session_`, and both destructors would then close and
  // delete the same TF_Session. Copying is therefore disabled.
  @disable this(this);

  /// Deletes the held input tensors and forgets them.
  void DeleteInputValues()
  {
    foreach (v; this.input_values_)
    {
      TF_DeleteTensor(v);
    }
    this.input_values_.clear();
  }

  /// Deletes the non-null held output tensors and forgets them.
  void ResetOutputValues()
  {
    foreach (v; this.output_values_)
    {
      if (v !is null)
      {
        TF_DeleteTensor(v);
      }
    }
    output_values_.clear();
  }

public:

  /// Constructs a new session; `s` receives the TF_NewSession status.
  this(TF_Graph* graph, TF_Status* s, bool useXLA = false)
  {
    TF_SessionOptions* opts = TF_NewSessionOptions();
    // TODO(karita): support XLA
    assert(!useXLA, "XLA is not supported yet.");
    // TF_EnableXLACompilation(opts, useXLA);
    this.session_ = TF_NewSession(graph, opts, s);
    TF_DeleteSessionOptions(opts);
  }

  ~this()
  {
    TF_Status* s = TF_NewStatus();
    this.CloseAndDelete(s);
    assertStatus(s);
    TF_DeleteStatus(s);
  }

  /// Closes and deletes input/output values explicitly.
  void CloseAndDelete(TF_Status* s)
  {
    DeleteInputValues();
    ResetOutputValues();
    if (session_ !is null)
    {
      TF_CloseSession(session_, s);
      assertStatus(s);
      TF_DeleteSession(session_, s);
      session_ = null;
    }
  }

  /// Sets input values; each tensor feeds output 0 of its operation and is
  /// deleted by this struct after the next Run().
  void SetInputs(TF_Tensor*[TF_Operation*] inputs)
  {
    this.DeleteInputValues();
    this.inputs_.clear();
    this.inputs_.reserve(inputs.length);
    this.input_values_.reserve(inputs.length);
    foreach (op, tensor; inputs)
    {
      this.inputs_ ~= TF_Output(cast(TF_Operation*) op, 0);
      this.input_values_ ~= tensor;
    }
  }

  /// Sets output values to fetch (output 0 of each operation).
  void SetOutputs(size_t N)(TF_Operation*[N] outputs...)
  {
    this.ResetOutputValues();
    this.outputs_.clear();
    this.outputs_.reserve(N);
    foreach (op; outputs)
    {
      this.outputs_ ~= TF_Output(op, 0);
    }
    this.output_values_.length = N;
  }

  /// Runs the session; results land in `output_values_`. The input tensors
  /// are deleted afterwards, so SetInputs() must be called before each run.
  void Run(TF_Status* s)
  {
    assert(this.inputs_.length == input_values_.length,
           "Call SetInputs() before Run()");
    this.ResetOutputValues();
    this.output_values_.length = this.outputs_.length;

    const inputs_ptr = inputs_.empty ? null : &inputs_[0];
    auto input_values_ptr =
        input_values_.empty ? null : &input_values_[0];

    const TF_Output* outputs_ptr =
        outputs_.empty ? null : &outputs_[0];
    TF_Tensor** output_values_ptr =
        output_values_.empty ? null : &output_values_[0];

    const targets_ptr = targets_.empty ? null : &targets_[0];

    TF_SessionRun(
        session_, null,
        inputs_ptr, input_values_ptr, cast(int) inputs_.length,
        outputs_ptr, output_values_ptr, cast(int) outputs_.length,
        targets_ptr, cast(int) targets_.length,
        null, s);
    this.DeleteInputValues();
  }
}

/// CAPI Session test in `tensorflow/c/c_api_test.c`
unittest
{
  // Declared before `session` so they are released after its destructor.
  TF_Status* s = TF_NewStatus();
  scope (exit) TF_DeleteStatus(s);
  TF_Graph* graph = TF_NewGraph();
  scope (exit) TF_DeleteGraph(graph);

  // Make a placeholder operation.
  TF_Operation* feed = Placeholder(graph, s);
  assertStatus(s);

  // Make a constant operation with the scalar "2".
  TF_Operation* two = ScalarConst(2, graph, s);
  assertStatus(s);

  // Add operation.
  TF_Operation* add = Add(feed, two, graph, s);
  assertStatus(s);

  // Create a session for this graph.
  auto session = Session(graph, s);
  assertStatus(s);

  // Run the graph.
  session.SetInputs([feed: makeTensor(3)]);
  session.SetOutputs(add);
  session.Run(s);
  assertStatus(s);
  TF_Tensor* result = session.output_values_[0];
  assert(result !is null);
  assert(TF_TensorType(result) == TF_INT32);
  int* resultVal = cast(int*) TF_TensorData(result);
  assert(2 + 3 == *resultVal);
}