1 /// TF_Graph wrapper.
2 module tfd.graph;
3 
import std.string : fromStringz;

import mir.rc.slim_ptr : createSlimRC, SlimRCPtr;

import tfd.c_api;
import tfd.testing : assertStatus;
10 
11 
/// TF_Graph freed by dtor (RAII) with convenient methods.
struct GraphOwner
{
  /// Raw pointer.
  TF_Graph* ptr;
  /// Status pointer.
  TF_Status* status;
  alias ptr this;

  // Not copyable.
  @disable this(this);

  /// Deletes the owned TF_Graph and TF_Status.
  @nogc nothrow @trusted
  ~this()
  {
    TF_DeleteGraph(this.ptr);
    TF_DeleteStatus(this.status);
  }

  /// Loads serialized graph (GraphDef proto) into this graph.
  ///
  /// Params:
  ///   proto = serialized GraphDef bytes.
  @nogc nothrow @trusted
  void deserialize(const(void)[] proto)
  {
    auto buffer = TF_NewBufferFromString(proto.ptr, proto.length);
    scope (exit) TF_DeleteBuffer(buffer);
    auto opts = TF_NewImportGraphDefOptions();
    // The import options are caller-owned; release them after the import.
    scope (exit) TF_DeleteImportGraphDefOptions(opts);
    TF_GraphImportGraphDef(this.ptr, buffer, opts, this.status);
    assertStatus(this.status);
  }

  /// Returns serialized bytes (GraphDef proto).
  ///
  /// The returned buffer is newly allocated by TF_NewBuffer; the caller
  /// must release it with TF_DeleteBuffer.
  @nogc nothrow @system
  TF_Buffer* serialize()
  {
    auto buffer = TF_NewBuffer;
    TF_GraphToGraphDef(this.ptr, buffer, this.status);
    assertStatus(this.status);
    return buffer;
  }

  /// Reads serialized bytes (GraphDef proto) from a file and imports them
  /// into this graph via `deserialize`.
  ///
  /// Params:
  ///   fileName = null-terminated path of the file to read.
  ///   block = number of bytes requested per `fread` call; must be positive.
  @nogc nothrow @trusted
  void read(const(char)* fileName, size_t block = 1024)
  {
    import core.stdc.stdio : fclose, ferror, fopen, fread;
    import core.stdc.stdlib : free, realloc;

    assert(block > 0, "block must be positive");
    auto fp = fopen(fileName, "rb");
    assert(fp, "fopen failed");
    scope (exit) fclose(fp);

    size_t len;
    size_t capacity = block;
    auto data = cast(ubyte*) realloc(null, capacity);
    assert(data, "realloc failed");
    // `data` may move on realloc below; scope(exit) frees its final value.
    scope (exit) free(data);

    while (true)
    {
      if (capacity - len < block)
      {
        // Grow geometrically so large files are not copied quadratically.
        capacity = capacity * 2 > len + block ? capacity * 2 : len + block;
        data = cast(ubyte*) realloc(data, capacity);
        assert(data, "realloc failed");
      }
      // Append after the bytes already read, not at the buffer start.
      const inc = fread(data + len, 1, block, fp);
      len += inc;
      // A short count from fread means end-of-file or a read error.
      if (inc < block) break;
    }
    assert(!ferror(fp), "fread failed");
    this.deserialize(data[0 .. len]);
  }

  /// Writes serialized bytes (GraphDef proto) to a given file.
  ///
  /// Params:
  ///   fileName = null-terminated path of the file to (over)write.
  @nogc nothrow @trusted
  void write(const(char)* fileName)
  {
    import core.stdc.stdio : fclose, fopen, fwrite;

    auto buffer = this.serialize();
    scope (exit) TF_DeleteBuffer(buffer);
    auto fp = fopen(fileName, "wb");
    assert(fp, "fopen failed");
    const written = fwrite(buffer.data, 1, buffer.length, fp);
    // Close unconditionally (outside assert) so the handle is released and
    // the data flushed even when asserts are compiled out.
    const closed = fclose(fp);
    assert(written == buffer.length, "fwrite failed");
    assert(closed == 0, "fclose failed");
  }
}
91 
92 
/// Reference-counted (shared) handle to a GraphOwner.
struct Graph
{
  import tfd.session : Session;
  import tfd.tensor : tfType;
  import tfd.op : Operation;

  /// Base reference counted pointer.
  SlimRCPtr!GraphOwner base;
  alias base this;

  /// Returns: true if this graph contains an operation named `name`.
  @nogc nothrow @trusted
  bool hasOperationByName(const(char)* name)
  {
    return TF_GraphOperationByName(this.ptr, name) !is null;
  }

  /// Gets output 0 of the operation named `name`.
  ///
  /// The operation must exist; a missing name fails the assertion.
  @nogc nothrow @trusted
  Operation getOperationByName(const(char)* name)
  {
    auto opr = TF_GraphOperationByName(this.ptr, name);
    assert(opr, "operation not found");
    return Operation(TF_Output(opr, 0), this);
  }

  /// Creates a placeholder in this graph.
  ///
  /// Params:
  ///   name = null-terminated operation name.
  ///   dims = static shape; when empty (N == 0) no "shape" attr is set.
  /// Returns: output 0 of the new Placeholder operation.
  @trusted
  Operation placeholder(T, size_t N)(
      const(char)* name,
      long[N] dims...) scope return
  {
    TF_OperationDescription* desc = TF_NewOperation(this.ptr, "Placeholder", name);
    TF_SetAttrType(desc, "dtype", tfType!T);
    static if (N != 0)
    {
      TF_SetAttrShape(desc, "shape", dims.ptr, dims.length);
    }
    TF_Operation* op = TF_FinishOperation(desc, this.status);
    assertStatus(this.status);
    assert(op);
    return Operation(TF_Output(op, 0), this.base);
  }

  /// ditto.
  Operation placeholder(T, size_t N)(long[N] dims ...)
  {
    return placeholder!T("", dims);
  }

  /// Creates a constant in this graph holding the value `x`.
  ///
  /// Params:
  ///   x = value converted to a TF_Tensor via `makeTF_Tensor`.
  ///   name = null-terminated operation name.
  /// Returns: output 0 of the new Const operation.
  @trusted
  Operation constant(S)(S x, const(char)* name = "const")
  {
    import tfd.tensor : makeTF_Tensor;

    auto t = x.makeTF_Tensor;
    // TF_SetAttrTensor copies the value into the attribute and the caller
    // keeps ownership of `t`, so the temporary tensor is released here.
    scope (exit) TF_DeleteTensor(t);
    TF_OperationDescription* desc = TF_NewOperation(this.ptr, "Const", name);
    TF_SetAttrTensor(desc, "value", t, this.status);
    assertStatus(this.status);
    TF_SetAttrType(desc, "dtype", TF_TensorType(t));
    TF_Operation* op = TF_FinishOperation(desc, this.status);
    assertStatus(this.status);
    assert(op !is null);
    return Operation(TF_Output(op, 0), this);
  }

  /// Creates a Session for this graph.
  @nogc nothrow @trusted
  Session session()
  {
    return Session(this.ptr);
  }
}
169 
/// Creates a new reference-counted Graph object.
///
/// The returned Graph owns a freshly created TF_Graph and TF_Status, which
/// GraphOwner's destructor deletes.
@nogc nothrow @trusted
Graph newGraph()
{
  import mir.rc.slim_ptr : createSlimRC;

  auto rawGraph = TF_NewGraph();
  auto rawStatus = TF_NewStatus();
  return Graph(createSlimRC!GraphOwner(rawGraph, rawStatus));
}
177 
/// Export/import graph.
version (tfd_test)
nothrow
unittest
{
  import core.stdc.stdio : remove;
  import tfd.tensor;

  TF_Buffer* buffer;
  scope (exit) TF_DeleteBuffer(buffer);
  {
    auto graph = newGraph;
    with (graph)
    {
      auto a = placeholder!int("a");
      assert(TF_GraphOperationByName(graph, "a"));
      auto b = constant(3, "b");
      assert(TF_GraphOperationByName(graph, "b"));
      // TODO(karita): provide name "add", identity?
      auto add = a + b;
      assert(TF_GraphOperationByName(graph, "add"));
    }
    buffer = graph.serialize;

    // Round-trip through a file to cover write() and read().
    graph.write("tmp.bin");
    // Do not leave the temporary file behind after the test.
    scope (exit) remove("tmp.bin");
    auto loaded = newGraph;
    loaded.read("tmp.bin");
    assert(loaded.hasOperationByName("a"));
    assert(loaded.hasOperationByName("add"));
  }
  with (newGraph) {
    // Import from the GraphDef (protobuf)
    deserialize(buffer.data[0 .. buffer.length]);
    auto a = getOperationByName("a");
    auto add = getOperationByName("add");
    const t = session.run([add], [a: 1.tensor])[0].tensor;
    assert(t.scalar!int == 1 + 3);
  }
}