// -*- c-file-style: "google"; c-basic-offset: 2 -*-
2 
/// Helpers for building TensorFlow graphs and running sessions via the TensorFlow C API.
4 module tfd.graph;
5 
6 import std.string : fromStringz;
7 
8 import tensorflow.c_api;
9 import tensorflow.op_def_pb;
10 
/// Adds a "Placeholder" operation to `graph` and returns it.
///
/// Params:
///   N     = rank of `dims`; with the default 0 no "shape" attribute is set.
///   graph = graph that receives the new operation.
///   s     = status written by `TF_FinishOperation`; checked with an assert.
///   name  = operation name (default "feed").
///   dtype = value stored in the "dtype" attribute (default TF_INT32).
///   dims  = value stored in the "shape" attribute when N != 0.
/// Returns: the finished operation (asserted non-null).
@nogc nothrow @trusted
TF_Operation* Placeholder(size_t N = 0)(
    TF_Graph* graph,
    TF_Status* s,
    const(char)* name = "feed",
    TF_DataType dtype = TF_INT32,
    const long[N] dims = [])
{
  auto description = TF_NewOperation(graph, "Placeholder", name);
  TF_SetAttrType(description, "dtype", dtype);

  // The "shape" attribute is only written for a non-empty `dims`.
  static if (N > 0)
    TF_SetAttrShape(description, "shape", dims.ptr, dims.length);

  TF_Operation* placeholder = TF_FinishOperation(description, s);
  assert(TF_GetCode(s) == TF_OK, TF_Message(s).fromStringz);
  assert(placeholder !is null);
  return placeholder;
}
32 
/// Short name for `Tensorflow__AttrValue`, the struct that `GetAttrValue`
/// fills in (presumably the protobuf-c generated type from
/// `tensorflow.op_def_pb` — TODO confirm).
alias AttrValue = Tensorflow__AttrValue;
34 
/// Reads attribute `attr_name` of `oper` and decodes it into `attr_value`.
///
/// The attribute is fetched as a serialized AttrValue proto with
/// `TF_OperationGetAttrValueProto` and decoded with
/// `tensorflow__attr_value__unpack`. The decoded top-level struct is copied
/// into `*attr_value`. Data it points to is not freed here and stays
/// referenced by that copy.
///
/// Params:
///   oper       = operation to query.
///   attr_name  = name of the attribute to read.
///   attr_value = destination; written only when this returns true.
///   s          = receives the status of the TF lookup.
/// Returns: true when the lookup succeeded and the proto could be decoded.
///   If the lookup succeeds but decoding fails, this returns false while
///   `s` still reports TF_OK.
@nogc nothrow @trusted
bool GetAttrValue(
    TF_Operation* oper, const(char)* attr_name,
    AttrValue* attr_value, TF_Status* s)
{
  import core.stdc.stdlib : free;

  TF_Buffer* buffer = TF_NewBuffer();
  scope (exit) TF_DeleteBuffer(buffer);

  TF_OperationGetAttrValueProto(oper, attr_name, buffer, s);
  if (TF_GetCode(s) != TF_OK) return false;

  // A null allocator makes protobuf-c use its default (malloc/free) allocator.
  auto unpacked = tensorflow__attr_value__unpack(
      null,
      buffer.length,
      cast(const(ubyte)*) buffer.data);
  if (unpacked is null) return false;

  *attr_value = *unpacked;
  // Release only the top-level struct returned by unpack; previously it was
  // leaked on every call. Its fields were shallow-copied into `*attr_value`,
  // which still references the nested data. So the
  // `*_free_unpacked` helper, which would free that nested data too, must
  // not be used here.
  free(unpacked);
  return true;
}
57 
58 
/// Asserts that `s` reports TF_OK. On failure, the text returned by
/// `TF_Message(s)` becomes the assert message.
@nogc nothrow @trusted
void assertStatus(TF_Status* s)
{
  assert(TF_OK == TF_GetCode(s), fromStringz(TF_Message(s)));
}
65 
66 
/// Maps a D element type `T` to the matching `TF_DataType`.
///
/// Type qualifiers (const/immutable/shared) are ignored. Each mapped D type
/// is meant to have the same byte size as its TF counterpart, because
/// `makeTensor` sizes buffers as `T.sizeof * num_values`. Unsupported types
/// are rejected at compile time rather than silently mapped to TF_INT32.
template dtype(T)
{
  // TODO(karita): support the remaining dtypes in TF (string, half, complex, ...)

  // `immutable T` collapses const/shared, so e.g. `const(int)` maps like `int`.
  static if (is(immutable T == immutable float))
    enum dtype = TF_FLOAT;
  else static if (is(immutable T == immutable double))
    enum dtype = TF_DOUBLE;
  else static if (is(immutable T == immutable byte))
    enum dtype = TF_INT8;
  else static if (is(immutable T == immutable ubyte))
    enum dtype = TF_UINT8;
  else static if (is(immutable T == immutable short))
    enum dtype = TF_INT16;
  else static if (is(immutable T == immutable ushort))
    enum dtype = TF_UINT16;
  else static if (is(immutable T == immutable int))
    enum dtype = TF_INT32;
  else static if (is(immutable T == immutable long))
    enum dtype = TF_INT64;
  else static if (is(immutable T == immutable bool))
    enum dtype = TF_BOOL;
  else
    static assert(false, "dtype: unsupported element type " ~ T.stringof);
}
69 
70 
/// Creates a tensor of element type `T` with shape `dims` and copies the
/// element data from `values` into it.
///
/// Params:
///   T        = element type; its TF type comes from `dtype!T`.
///   num_dims = rank of the tensor (0 for a scalar).
///   dims     = extent of each dimension; every entry must be non-negative.
///   values   = pointer to at least `product(dims)` contiguous elements;
///              may be null only when that product is 0.
/// Returns: a newly allocated tensor. Release it with `TF_DeleteTensor`, or
///   hand it to something that does.
TF_Tensor* makeTensor(T, size_t num_dims)(
  const long[num_dims] dims, const(T)* values)
{
  import core.stdc.string : memcpy;

  size_t num_values = 1;
  foreach (d; dims)
  {
    // A negative extent would wrap around in size_t and request a huge
    // allocation.
    assert(d >= 0, "makeTensor: dimensions must be non-negative");
    num_values *= d;
  }
  const size_t num_bytes = T.sizeof * num_values;
  assert(num_bytes == 0 || values !is null,
         "makeTensor: values must not be null for a non-empty tensor");

  static if (num_dims == 0)
  {
    // Rank-0 tensor: there is no dimension array to pass.
    const(long)* dimsPtr = null;
  }
  else
  {
    const(long)* dimsPtr = dims.ptr;
  }
  TF_Tensor* t = TF_AllocateTensor(dtype!T, dimsPtr, num_dims, num_bytes);
  assert(t !is null, "makeTensor: TF_AllocateTensor failed");
  memcpy(TF_TensorData(t), values, num_bytes);
  return t;
}
95 
96 
/// Creates a rank-0 (scalar) tensor that holds a copy of `scalar`.
TF_Tensor* makeTensor(T)(const(T) scalar)
{
  // A scalar has an empty shape; its single element is read through a
  // pointer to the parameter.
  long[0] noDims;
  const(T)* valuePtr = &scalar;
  return makeTensor!(T, 0)(noDims, valuePtr);
}
103 
104 
/// Adds a "Const" operation that holds tensor `t` to `graph`.
///
/// The "dtype" attribute is taken from `TF_TensorType(t)`. This function
/// does not delete `t`.
///
/// Params:
///   t     = tensor stored in the "value" attribute; must not be null.
///   graph = graph that receives the new operation.
///   s     = status; checked with an assert after setting the tensor
///           attribute and again after finishing the operation.
///   name  = operation name (default "const").
/// Returns: the finished operation (asserted non-null).
@nogc nothrow @trusted
TF_Operation* Const(
    TF_Tensor* t,
    TF_Graph* graph,
    TF_Status* s,
    const(char)* name = "const")
{
  assert(t !is null);
  TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name);
  TF_SetAttrTensor(desc, "value", t, s);
  // Check right away: TF_FinishOperation below writes `s` again, which
  // would otherwise hide a failure to set the tensor attribute.
  assertStatus(s);
  TF_SetAttrType(desc, "dtype", TF_TensorType(t));
  TF_Operation* op = TF_FinishOperation(desc, s);
  assertStatus(s);
  assert(op !is null);
  return op;
}
122 
123 
/// Adds a "Const" operation that holds the int scalar `v` to `graph`.
///
/// The temporary tensor built for `v` is deleted once the operation has
/// been created.
///
/// Params:
///   v     = scalar value of the constant.
///   graph = graph that receives the new operation.
///   s     = status, checked by `Const`.
///   name  = operation name (default "scalar").
/// Returns: the finished operation.
@nogc nothrow @trusted
TF_Operation* ScalarConst(int v, TF_Graph* graph, TF_Status* s,
                          const(char)* name = "scalar")
{
  TF_Tensor* t = makeTensor(v);
  // Const() never deletes its tensor argument, so release it here, after
  // the operation has been finished.
  scope (exit) TF_DeleteTensor(t);
  return Const(t, graph, s, name);
}
132 
133 
/// Adds an "AddN" operation that sums output 0 of `l` and output 0 of `r`.
///
/// Params:
///   l     = operation providing the first summand (its output 0).
///   r     = operation providing the second summand (its output 0).
///   graph = graph that receives the new operation.
///   s     = status; checked with an assert after finishing the operation.
///   name  = operation name (default "add").
/// Returns: the finished operation (asserted non-null).
@nogc nothrow @trusted
TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
                  TF_Status* s, const(char)* name = "add")
{
  auto desc = TF_NewOperation(graph, "AddN", name);
  TF_Output[2] summands = [TF_Output(l, 0), TF_Output(r, 0)];
  TF_AddInputList(desc, summands.ptr, cast(int) summands.length);

  auto added = TF_FinishOperation(desc, s);
  assertStatus(s);
  assert(added !is null);
  return added;
}
148 
/// Adds an "AddN" operation that sums the outputs `l` and `r`.
///
/// Unlike the `TF_Operation*` overload, this one does not assert on `s`;
/// the caller is expected to check it.
///
/// Params:
///   l     = first summand.
///   r     = second summand.
///   graph = graph that receives the new operation.
///   s     = status written by `TF_FinishOperation`.
///   name  = operation name (default "add").
/// Returns: whatever `TF_FinishOperation` returns (presumably null on
///   failure — check `s`).
@nogc nothrow @trusted
TF_Operation* Add(TF_Output l, TF_Output r,
                  TF_Graph* graph, TF_Status* s,
                  const(char)* name = "add")
{
  TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
  TF_Output[2] inputs = [l, r];
  TF_AddInputList(desc, inputs.ptr, cast(int) inputs.length);
  return TF_FinishOperation(desc, s);
}
160 
/// Port of the C API graph test in `tensorflow/c/c_api_test.cc`.
unittest
{
  import std.stdio;
  writeln("CAPI Graph test");

  TF_Status* s = TF_NewStatus();
  scope (exit) TF_DeleteStatus(s);
  TF_Graph* graph = TF_NewGraph();
  scope (exit) TF_DeleteGraph(graph);

  // Make a placeholder operation.
  TF_Operation* feed = Placeholder(graph, s);
  assertStatus(s);

  // Test TF_Operation*() query functions.
  assert(TF_OperationName(feed).fromStringz == "feed");
  assert(TF_OperationOpType(feed).fromStringz == "Placeholder");
  assert(TF_OperationDevice(feed).fromStringz == "");
  assert(TF_OperationNumOutputs(feed) == 1);
  assert(TF_OperationOutputType(TF_Output(feed, 0)) == TF_INT32);
  assert(TF_OperationOutputListLength(feed, "output", s) == 1);
  assertStatus(s);
  assert(TF_OperationNumInputs(feed) == 0);
  assert(TF_OperationOutputNumConsumers(TF_Output(feed, 0)) == 0);
  assert(TF_OperationNumControlInputs(feed) == 0);
  assert(TF_OperationNumControlOutputs(feed) == 0);

  // TODO(karita): implement AttrValue type switching by `value_case`
  AttrValue attrValue;
  assert(GetAttrValue(feed, "dtype", &attrValue, s));
  assert(attrValue.type == TENSORFLOW__DATA_TYPE__DT_INT32);

  // Test not found errors in TF_Operation*() query functions.
  assert(TF_OperationOutputListLength(feed, "bogus", s) == -1);
  assert(TF_GetCode(s) == TF_INVALID_ARGUMENT);
  assert(!GetAttrValue(feed, "missing", &attrValue, s));
  assert(TF_Message(s).fromStringz ==
         "Operation 'feed' has no attr named 'missing'.");

  // Make a constant oper with the scalar "3".
  TF_Operation* three = ScalarConst(3, graph, s);
  assertStatus(s);

  // Add oper, built from the two operations above.
  TF_Operation* add = Add(feed, three, graph, s);
  assertStatus(s);
  assert(TF_OperationOpType(add).fromStringz == "AddN");
  assert(TF_OperationNumInputs(add) == 2);
}
206 
/// Owns a `TF_Session` together with the feeds, fetches and targets passed
/// to `TF_SessionRun`.
///
/// The session is closed and deleted by `CloseAndDelete`, which the
/// destructor also calls. Fed input tensors are owned by this struct and
/// deleted after each `Run`. Fetched output tensors stay owned until the
/// next `SetOutputs`/`Run` or until the session is closed.
struct Session
{
  import std.container.array : Array;

  TF_Session* session_;
  alias session_ this;

  Array!TF_Output inputs_;
  Array!(TF_Tensor*) input_values_;
  Array!TF_Output outputs_;
  Array!(TF_Tensor*) output_values_;
  Array!(TF_Operation*) targets_;

  // A copy would share `session_` and the tensor pointers, and every copy's
  // destructor would close/delete them again (double free).
  @disable this(this);

  /// Deletes every fed input tensor and empties `input_values_`.
  void DeleteInputValues()
  {
    foreach (v; this.input_values_)
    {
      TF_DeleteTensor(v);
    }
    this.input_values_.clear();
  }

  /// Deletes every non-null fetched output tensor and empties
  /// `output_values_`.
  void ResetOutputValues()
  {
    foreach (v; this.output_values_)
    {
      if (v !is null)
      {
        TF_DeleteTensor(v);
      }
    }
    output_values_.clear();
  }

public:

  /// Constructs a new session for `graph`; the creation status goes to `s`.
  ///
  /// `useXLA` must be false: XLA is not supported yet (asserted).
  this(TF_Graph* graph, TF_Status* s, bool useXLA = false)
  {
    // TODO(karita): support XLA
    // Checked before allocating options so a failed assert leaks nothing.
    assert(!useXLA, "XLA is not supported yet.");
    TF_SessionOptions* opts = TF_NewSessionOptions();
    // TF_EnableXLACompilation(opts, useXLA);
    this.session_ = TF_NewSession(graph, opts, s);
    TF_DeleteSessionOptions(opts);
  }

  ~this()
  {
    TF_Status* s = TF_NewStatus();
    this.CloseAndDelete(s);
    assertStatus(s);
    TF_DeleteStatus(s);
  }

  /// Deletes held input/output tensors, then closes and deletes the session
  /// if one is open.
  ///
  /// The close status is asserted. The status of `TF_DeleteSession` is left
  /// in `s` for the caller to check. Safe to call more than once.
  void CloseAndDelete(TF_Status* s)
  {
    DeleteInputValues();
    ResetOutputValues();
    if (session_ !is null) {
      TF_CloseSession(session_, s);
      assertStatus(s);
      TF_DeleteSession(session_, s);
      session_ = null;
    }
  }

  /// Sets the feeds for the next `Run`, using output 0 of each key operation.
  ///
  /// This struct takes ownership of the tensors in `inputs`. They are deleted
  /// after the next `Run`, on the next `SetInputs`, or on close. Previously
  /// set input tensors are deleted first.
  void SetInputs(TF_Tensor*[TF_Operation*] inputs)
  {
    this.DeleteInputValues();
    this.inputs_.clear();
    this.inputs_.reserve(inputs.length);
    this.input_values_.reserve(inputs.length);
    // Both arrays are filled in the same iteration, so pairs stay aligned.
    foreach (op, tensor; inputs)
    {
      this.inputs_ ~= TF_Output(cast(TF_Operation*) op, 0);
      this.input_values_ ~= tensor;
    }
  }

  /// Sets the operations whose output 0 is fetched by `Run`; previously
  /// fetched output tensors are deleted.
  void SetOutputs(size_t N)(TF_Operation*[N] outputs...)
  {
    this.ResetOutputValues();
    this.outputs_.clear();
    this.outputs_.reserve(N);
    foreach (op; outputs)
    {
      this.outputs_ ~= TF_Output(op, 0);
    }
    this.output_values_.length = N;
  }

  /// Sets the target operations passed to `TF_SessionRun` by `Run`,
  /// replacing any previously set targets.
  void SetTargets(size_t N)(TF_Operation*[N] targets...)
  {
    this.targets_.clear();
    this.targets_.reserve(N);
    foreach (op; targets)
    {
      this.targets_ ~= op;
    }
  }

  /// Runs the session with the current feeds, fetches and targets.
  ///
  /// Fetched tensors are stored in `output_values_` (owned by this struct).
  /// The fed tensors are deleted afterwards, so `SetInputs` must be called
  /// again before the next `Run` (asserted).
  void Run(TF_Status* s)
  {
    assert(this.inputs_.length == input_values_.length,
          "Call SetInputs() before Run()");
    this.ResetOutputValues();
    this.output_values_.length = this.outputs_.length;

    // TF_SessionRun takes raw pointers; pass null for empty lists.
    const inputs_ptr = inputs_.empty ? null : &inputs_[0];
    auto input_values_ptr =
        input_values_.empty ? null : &input_values_[0];

    const TF_Output* outputs_ptr =
        outputs_.empty ? null : &outputs_[0];
    TF_Tensor** output_values_ptr =
        output_values_.empty ? null : &output_values_[0];

    const targets_ptr = targets_.empty ? null : &targets_[0];

    TF_SessionRun(
        session_, null,
        inputs_ptr, input_values_ptr, cast(int) inputs_.length,
        outputs_ptr, output_values_ptr, cast(int) outputs_.length,
        targets_ptr, cast(int) targets_.length,
        null, s);
    this.DeleteInputValues();
  }
}
330 
/// Port of the C API session test in `tensorflow/c/c_api_test.cc`.
unittest
{
  TF_Status* s = TF_NewStatus();
  scope (exit) TF_DeleteStatus(s);
  TF_Graph* graph = TF_NewGraph();
  scope (exit) TF_DeleteGraph(graph);

  // Make a placeholder operation.
  TF_Operation* feed = Placeholder(graph, s);
  assertStatus(s);

  // Make a constant operation with the scalar "2".
  TF_Operation* two = ScalarConst(2, graph, s);
  assertStatus(s);

  // Add operation.
  TF_Operation* add = Add(feed, two, graph, s);
  assertStatus(s);

  // Create a session for this graph. It is declared after the scope(exit)
  // guards above, so it is destroyed (closed) before the graph and status
  // are deleted.
  auto session = Session(graph, s);
  assertStatus(s);

  // Run the graph. The session owns the fed tensor and deletes it after Run.
  session.SetInputs([feed: makeTensor(3)]);
  session.SetOutputs(add);
  session.Run(s);
  assertStatus(s);
  TF_Tensor* result = session.output_values_[0];
  assert(result !is null);
  assert(TF_TensorType(result) == TF_INT32);
  int* resultVal = cast(int*) TF_TensorData(result);
  assert(2 + 3 == *resultVal);
}
365