/// TF_Graph wrapper.
module tfd.graph;

import std.string : fromStringz;

import mir.rc.slim_ptr : createSlimRC, SlimRCPtr;

import tfd.c_api;
import tfd.testing : assertStatus;


/// TF_Graph freed by dtor (RAII) with convenient methods.
struct GraphOwner
{
    /// Raw pointer.
    TF_Graph* ptr;
    /// Status pointer used by every TF call made through this owner.
    TF_Status* status;
    alias ptr this;

    // Not copyable.
    @disable this(this);

    /// Deletes the owned TF_Graph and TF_Status.
    @nogc nothrow @trusted
    ~this()
    {
        TF_DeleteGraph(this.ptr);
        TF_DeleteStatus(this.status);
    }

    /// Loads serialized graph (GraphDef proto) into this graph.
    /// Params:
    ///     proto = serialized GraphDef bytes.
    @nogc nothrow @trusted
    void deserialize(const(void)[] proto)
    {
        auto buffer = TF_NewBufferFromString(proto.ptr, proto.length);
        scope (exit) TF_DeleteBuffer(buffer);
        auto opts = TF_NewImportGraphDefOptions;
        // The options object is allocated by the C API and must be released too.
        scope (exit) TF_DeleteImportGraphDefOptions(opts);
        TF_GraphImportGraphDef(this.ptr, buffer, opts, this.status);
        assertStatus(this.status);
    }

    /// Returns serialized bytes (GraphDef proto).
    /// The caller owns the returned buffer and must release it with TF_DeleteBuffer.
    @nogc nothrow @system
    TF_Buffer* serialize()
    {
        auto buffer = TF_NewBuffer;
        TF_GraphToGraphDef(this.ptr, buffer, this.status);
        assertStatus(this.status);
        return buffer;
    }

    /// Reads serialized bytes (GraphDef proto) from a file and imports them
    /// into this graph via `deserialize`.
    /// Params:
    ///     fileName = null-terminated path of the file to read.
    ///     block = bytes requested per `fread` call; the read buffer grows by this much.
    @nogc nothrow @trusted
    void read(const(char)* fileName, size_t block = 1024)
    {
        import core.stdc.stdio : fclose, ferror, fopen, fread;
        import core.stdc.stdlib : free, realloc;

        assert(block > 0, "block must be positive");
        auto fp = fopen(fileName, "rb");
        assert(fp, "fopen failed");
        scope (exit) fclose(fp);

        size_t len;
        auto data = cast(ubyte*) realloc(null, block);
        assert(data, "realloc failed");
        // Evaluated at scope exit, so it frees the latest (possibly grown) buffer.
        scope (exit) free(data);

        while (true)
        {
            // Append after the bytes read so far instead of overwriting the start.
            const inc = fread(data + len, 1, block, fp);
            len += inc;
            // A short read means EOF (or a read error, checked after the loop).
            if (inc < block)
                break;
            auto grown = cast(ubyte*) realloc(data, len + block);
            assert(grown, "realloc failed");
            data = grown;
        }
        assert(!ferror(fp), "fread failed");
        this.deserialize(data[0 .. len]);
    }

    /// Writes serialized bytes (GraphDef proto) to a given file.
    /// Params:
    ///     fileName = null-terminated path of the file to (over)write.
    @nogc nothrow @trusted
    void write(const(char)* fileName)
    {
        import core.stdc.stdio : fclose, fopen, fwrite;

        auto buffer = this.serialize();
        scope (exit) TF_DeleteBuffer(buffer);
        auto fp = fopen(fileName, "wb");
        assert(fp, "fopen failed");
        const written = fwrite(buffer.data, 1, buffer.length, fp);
        // Close explicitly so buffered data is flushed and flush errors are seen.
        const closed = fclose(fp);
        assert(written == buffer.length, "fwrite failed");
        assert(closed == 0, "fclose failed");
    }
}

/// TF_Operation wrapper used in Graph.
struct Operation
{
    /// Raw pointer.
    TF_Operation* base;
    /// Graph scope containing this operation.
    SlimRCPtr!GraphOwner graph;
    alias base this;

    /// Binary operator for +: adds output 0 of both operands with an "AddN" op.
    /// The first such node is named "add"; later ones get "add_1", "add_2", ...
    /// because TF rejects duplicate node names within one graph.
    @trusted Operation opBinary(string op : "+")(Operation rhs)
    {
        import core.stdc.stdio : snprintf;

        assert(this.graph == rhs.graph, "operands belong to different graphs");

        // Pick the first unused node name of the form "add" / "add_<i>".
        char[32] name;
        snprintf(name.ptr, name.length, "add");
        for (int i = 1; TF_GraphOperationByName(this.graph.ptr, name.ptr) !is null; ++i)
        {
            snprintf(name.ptr, name.length, "add_%d", i);
        }

        TF_OperationDescription* desc = TF_NewOperation(this.graph.ptr, "AddN", name.ptr);
        TF_Output[2] inputs = [TF_Output(this.base, 0), TF_Output(rhs.base, 0)];
        TF_AddInputList(desc, inputs.ptr, cast(int) inputs.length);
        TF_Operation* added = TF_FinishOperation(desc, this.graph.status);
        assertStatus(this.graph.status);
        assert(added !is null);
        return Operation(added, this.graph);
    }
}

/// Shared GraphOwner type.
struct Graph
{
    import tfd.session : Session;
    import tfd.tensor : tfType;

    /// Base reference counted pointer.
    SlimRCPtr!GraphOwner base;
    alias base this;

    /// Returns true if this graph contains an operation named `name`.
    @nogc nothrow @trusted
    bool hasOperationByName(const(char)* name)
    {
        return TF_GraphOperationByName(this.ptr, name) !is null;
    }

    /// Gets an operation by name; asserts that it exists.
    @nogc nothrow @trusted
    Operation getOperationByName(const(char)* name)
    {
        auto opr = TF_GraphOperationByName(this.ptr, name);
        assert(opr, "operation not found");
        return Operation(opr, this.base);
    }

    /// Creates a placeholder of element type `T` in this graph.
    /// When `dims` is non-empty it is set as the "shape" attribute.
    @trusted
    Operation placeholder(T, size_t N)(
        const(char)* name,
        long[N] dims...) scope return
    {
        TF_OperationDescription* desc = TF_NewOperation(this.ptr, "Placeholder", name);
        TF_SetAttrType(desc, "dtype", tfType!T);
        static if (N != 0)
        {
            TF_SetAttrShape(desc, "shape", dims.ptr, dims.length);
        }
        TF_Operation* op = TF_FinishOperation(desc, this.status);
        assertStatus(this.status);
        assert(op);
        return Operation(op, this.base);
    }

    /// ditto.
    Operation placeholder(T, size_t N)(long[N] dims ...)
    {
        return placeholder!T("", dims);
    }

    /// Creates a constant holding `x` in this graph.
    @trusted
    Operation constant(S)(S x, const(char)* name = "const")
    {
        import tfd.tensor : makeTF_Tensor;

        // TODO(karita) free TF_Tensor when op is freed?
        auto t = x.makeTF_Tensor;
        TF_OperationDescription* desc = TF_NewOperation(this.ptr, "Const", name);
        TF_SetAttrTensor(desc, "value", t, this.status);
        assertStatus(this.status);
        TF_SetAttrType(desc, "dtype", TF_TensorType(t));
        TF_Operation* op = TF_FinishOperation(desc, this.status);
        assertStatus(this.status);
        assert(op !is null);
        return Operation(op, this.base);
    }

    /// Creates a Session in this graph.
    @nogc nothrow @trusted
    Session session()
    {
        return Session(this.ptr);
    }
}

/// Creates a new reference-counted Graph owning a fresh TF_Graph and TF_Status.
@nogc nothrow @trusted
Graph newGraph()
{
    return Graph(createSlimRC!GraphOwner(TF_NewGraph(), TF_NewStatus()));
}

/// Export/import graph.
version (tfd_test)
nothrow
unittest
{
    import core.stdc.stdio : removeFile = remove;
    import tfd.tensor;

    TF_Buffer* buffer;
    scope (exit) TF_DeleteBuffer(buffer);
    {
        auto graph = newGraph;
        with (graph)
        {
            auto a = placeholder!int("a");
            assert(TF_GraphOperationByName(graph, "a"));
            auto b = constant(3, "b");
            assert(TF_GraphOperationByName(graph, "b"));
            // TODO(karita): provide name "add", identity?
            auto add = a + b;
            assert(TF_GraphOperationByName(graph, "add"));
            // A second "+" must not collide with the existing "add" node.
            auto add2 = add + b;
            assert(TF_GraphOperationByName(graph, "add_1"));
        }
        buffer = graph.serialize;
        // for coverage
        graph.write("tmp.bin");
    }
    scope (exit) removeFile("tmp.bin");
    with (newGraph) {
        // Import from the GraphDef (protobuf)
        deserialize(buffer.data[0 .. buffer.length]);
        auto a = getOperationByName("a");
        auto add = getOperationByName("add");
        const t = session.run([add], [a: 1.tensor])[0].tensor;
        assert(t.scalar!int == 1 + 3);
    }
    with (newGraph) {
        // Import from the file written above; a small block size makes the
        // file span several fread calls.
        read("tmp.bin", 16);
        assert(hasOperationByName("a"));
        assert(hasOperationByName("add"));
        assert(hasOperationByName("add_1"));
    }
}