1 /// TF_Graph wrapper.
2 module tfd.graph;
3 
4 import std.string : fromStringz;
5 
6 import mir.rc.slim_ptr : createSlimRC, SlimRCPtr;
7 
8 import tfd.c_api;
9 import tfd.testing : assertStatus;
10 
11 
/// TF_Graph freed by dtor (RAII) with convenient methods.
///
/// Owns a `TF_Graph*` and a `TF_Status*`; the status is reused by every
/// method below that needs one. Both handles are released in the destructor,
/// and copying is disabled so they are released only once.
struct GraphOwner
{
  /// Raw pointer.
  TF_Graph* ptr;
  /// Status pointer shared by the methods of this owner.
  TF_Status* status;
  alias ptr this;

  // Not copyable: the destructor frees `ptr` and `status`.
  @disable this(this);

  @nogc nothrow @trusted
  ~this()
  {
    // A default-initialized (`GraphOwner.init`) value holds null handles;
    // only release what was actually set.
    if (this.ptr !is null) TF_DeleteGraph(this.ptr);
    if (this.status !is null) TF_DeleteStatus(this.status);
  }

  /// Loads serialized graph (GraphDef proto) into this graph.
  ///
  /// Params:
  ///   proto = serialized GraphDef bytes.
  ///
  /// The resulting status is checked with `assertStatus`.
  @nogc nothrow @trusted
  void deserialize(const(void)[] proto)
  {
    auto buffer = TF_NewBufferFromString(proto.ptr, proto.length);
    scope (exit) TF_DeleteBuffer(buffer);
    auto opts = TF_NewImportGraphDefOptions;
    // The import options are allocated here and must be released here too.
    scope (exit) TF_DeleteImportGraphDefOptions(opts);
    TF_GraphImportGraphDef(this.ptr, buffer, opts, this.status);
    assertStatus(this.status);
  }

  /// Returns serialized bytes (GraphDef proto).
  ///
  /// Returns: a newly allocated TF_Buffer; the caller releases it with
  ///   `TF_DeleteBuffer`.
  @nogc nothrow @system
  TF_Buffer* serialize()
  {
    auto buffer = TF_NewBuffer;
    TF_GraphToGraphDef(this.ptr, buffer, this.status);
    assertStatus(this.status);
    return buffer;
  }

  /// Reads serialized bytes (GraphDef proto) from file and imports them.
  ///
  /// Params:
  ///   fileName = null-terminated path of the file to read.
  ///   block = number of bytes requested per `fread` call; must be > 0.
  @nogc nothrow @trusted
  void read(const(char)* fileName, size_t block = 1024)
  {
    import core.stdc.stdio : fclose, ferror, fopen, fread;
    import core.stdc.stdlib : free, realloc;

    assert(block > 0, "block must be positive");

    auto fp = fopen(fileName, "rb");
    assert(fp, "fopen failed");
    scope (exit) fclose(fp);

    size_t len;
    ubyte* data = cast(ubyte*) realloc(null, block);
    assert(data, "realloc failed");
    // `data` is reassigned when the buffer grows; scope(exit) frees
    // whatever it points to when the function leaves.
    scope (exit) free(data);

    while (true)
    {
      // Append after the bytes already read (not at the start of the buffer).
      const n = fread(data + len, 1, block, fp);
      len += n;
      // A short read means end-of-file or an I/O error; stop either way.
      if (n < block) break;
      // Grow via a temporary so the old block is still freed on failure.
      auto grown = cast(ubyte*) realloc(data, len + block);
      assert(grown, "realloc failed");
      data = grown;
    }
    assert(!ferror(fp), "fread failed");
    this.deserialize(data[0 .. len]);
  }

  /// Writes serialized bytes (GraphDef proto) to a given file.
  ///
  /// Params:
  ///   fileName = null-terminated path of the file to (over)write.
  @nogc nothrow @trusted
  void write(const(char)* fileName)
  {
    import core.stdc.stdio : fclose, fopen, fwrite;

    auto buffer = this.serialize();
    scope (exit) TF_DeleteBuffer(buffer);
    auto fp = fopen(fileName, "wb");
    assert(fp, "fopen failed");
    // Closing releases the handle and flushes buffered bytes to the file.
    scope (exit) fclose(fp);
    const written = fwrite(buffer.data, 1, buffer.length, fp);
    assert(written == buffer.length, "fwrite failed");
  }
}
91 
/// TF_Operation wrapper used in Graph.
struct Operation
{
  /// Raw pointer.
  TF_Operation* base;
  /// Graph scope containing this operation (shared reference-counted owner).
  SlimRCPtr!GraphOwner graph;
  alias base this;

  /// Binary operator for +.
  ///
  /// Adds an `AddN` node to the graph with inputs `(this, 0)` and `(rhs, 0)`.
  /// The first such node is named "add"; because TF rejects duplicate node
  /// names within a graph, later ones are named "add_1", "add_2", ...
  /// Both operands must belong to the same graph.
  @trusted  Operation opBinary(string op : "+")(Operation rhs)
  {
    import core.stdc.stdio : snprintf;

    assert(this.graph == rhs.graph);

    // Pick the first unused name among "add", "add_1", "add_2", ...
    char[32] name;
    snprintf(name.ptr, name.length, "add");
    for (uint i = 1; TF_GraphOperationByName(this.graph, name.ptr) !is null; ++i)
    {
      snprintf(name.ptr, name.length, "add_%u", i);
    }

    TF_OperationDescription* desc = TF_NewOperation(this.graph, "AddN", name.ptr);
    TF_Output[2] inputs;
    inputs[0] = TF_Output(this.base, 0);
    inputs[1] = TF_Output(rhs.base, 0);
    TF_AddInputList(desc, inputs.ptr, 2);
    TF_Operation* op = TF_FinishOperation(desc, this.graph.status);
    assertStatus(this.graph.status);
    assert(op !is null);
    return Operation(op, this.graph);
  }

}
119 
/// Shared GraphOwner type.
///
/// Wraps a reference-counted `SlimRCPtr!GraphOwner`; copies of a Graph refer
/// to the same underlying GraphOwner.
struct Graph
{
  import tfd.session : Session;
  import tfd.tensor : tfType;

  /// Base reference counted pointer.
  SlimRCPtr!GraphOwner base;
  alias base this;

  /// Returns: true if this graph contains an operation named `name`.
  @nogc nothrow @trusted
  bool hasOperationByName(const(char)* name)
  {
    return TF_GraphOperationByName(this.ptr, name) !is null;
  }

  /// Get an operation by name.
  ///
  /// The operation is expected to exist; a missing one trips an `assert`.
  @nogc nothrow @trusted
  Operation getOperationByName(const(char)* name)
  {
    auto opr = TF_GraphOperationByName(this.ptr, name);
    assert(opr, "operation not found");
    return Operation(opr, this.base);
  }

  /// Creates a placeholder in this graph.
  ///
  /// Params:
  ///   T = element type, mapped to a TF dtype via `tfType`.
  ///   name = operation name.
  ///   dims = shape; when no dims are given, no "shape" attribute is set.
  @trusted
  Operation placeholder(T, size_t N)(
      const(char)* name,
      long[N] dims...) scope return
  {
    TF_OperationDescription* desc = TF_NewOperation(this.ptr, "Placeholder", name);
    TF_SetAttrType(desc, "dtype", tfType!T);
    static if (N != 0)
    {
      TF_SetAttrShape(desc, "shape", dims.ptr, dims.length);
    }
    TF_Operation* op = TF_FinishOperation(desc, this.status);
    assertStatus(this.status);
    assert(op);
    return Operation(op, this.base);
  }

  /// ditto.
  // NOTE(review): forwards an empty operation name; confirm TF accepts it.
  Operation placeholder(T, size_t N)(long[N] dims ...)
  {
    return placeholder!T("", dims);
  }

  /// Creates a constant in this graph.
  ///
  /// Params:
  ///   x = value converted to a TF_Tensor by `makeTF_Tensor`.
  ///   name = operation name.
  @trusted
  Operation constant(S)(S x, const(char)* name = "const")
  {
    import tfd.tensor : makeTF_Tensor;

    // TODO(karita) free TF_Tensor when op is freed?
    auto t = x.makeTF_Tensor;
    TF_OperationDescription* desc = TF_NewOperation(this.ptr, "Const", name);
    TF_SetAttrTensor(desc, "value", t, this.status);
    assertStatus(this.status);
    TF_SetAttrType(desc, "dtype", TF_TensorType(t));
    TF_Operation* op = TF_FinishOperation(desc, this.status);
    assertStatus(this.status);
    assert(op !is null);
    return Operation(op, this.base);
  }

  /// Creates a Session in this graph.
  @nogc nothrow @trusted
  Session session()
  {
    return Session(this.ptr);
  }
}
195 
/// Creates a new reference-counted Graph object.
///
/// A fresh TF_Graph and TF_Status are allocated (graph first) and handed to
/// a GraphOwner held by a SlimRCPtr.
@nogc nothrow @trusted
Graph newGraph()
{
  auto rawGraph = TF_NewGraph();
  auto rawStatus = TF_NewStatus();
  return Graph(createSlimRC!GraphOwner(rawGraph, rawStatus));
}
203 
/// Export/import graph: build a graph, serialize it, re-import the bytes into
/// a fresh graph and run the imported "add" operation.
version (tfd_test)
nothrow
unittest
{
  static import core.stdc.stdio;
  import tfd.tensor;

  TF_Buffer* buffer;
  scope (exit) TF_DeleteBuffer(buffer);
  {
    auto graph = newGraph;
    with (graph)
    {
      auto a = placeholder!int("a");
      assert(TF_GraphOperationByName(graph, "a"));
      auto b = constant(3, "b");
      assert(TF_GraphOperationByName(graph, "b"));
      // TODO(karita): provide name "add", identity?
      auto add = a + b;
      assert(TF_GraphOperationByName(graph, "add"));
    }
    buffer = graph.serialize;
    // for coverage; delete the temporary file afterwards so the test leaves
    // nothing behind in the working directory.
    scope (exit) core.stdc.stdio.remove("tmp.bin");
    graph.write("tmp.bin");
  }
  with (newGraph) {
    // Import from the GraphDef (protobuf)
    deserialize(buffer.data[0 .. buffer.length]);
    auto a = getOperationByName("a");
    auto add = getOperationByName("add");
    const t = session.run([add], [a: 1.tensor])[0].tensor;
    assert(t.scalar!int == 1 + 3);
  }
}