/// TF_Graph wrapper.
module tfd.graph;

import std.string : fromStringz;

import mir.rc.slim_ptr : createSlimRC, SlimRCPtr;

import tfd.c_api;
import tfd.testing : assertStatus;


/// TF_Graph freed by dtor (RAII) with convenient methods.
struct GraphOwner
{
  /// Raw pointer.
  TF_Graph* ptr;
  /// Status pointer.
  TF_Status* status;
  alias ptr this;

  // Not copyable.
  @disable this(this);

  /// Deletes the owned graph and status objects.
  @nogc nothrow @trusted
  ~this()
  {
    TF_DeleteGraph(this.ptr);
    TF_DeleteStatus(this.status);
  }

  /**
  Loads serialized graph (GraphDef proto).

  Params:
    proto = serialized GraphDef bytes to import into this graph.
  */
  @nogc nothrow @trusted
  void deserialize(const(void)[] proto)
  {
    auto buffer = TF_NewBufferFromString(proto.ptr, proto.length);
    scope (exit) TF_DeleteBuffer(buffer);
    auto opts = TF_NewImportGraphDefOptions;
    // The options object is owned by us and must be released explicitly.
    scope (exit) TF_DeleteImportGraphDefOptions(opts);
    TF_GraphImportGraphDef(this.ptr, buffer, opts, this.status);
    assertStatus(this.status);
  }

  /**
  Returns serialized bytes (GraphDef proto).

  The returned buffer is newly allocated with TF_NewBuffer; callers in this
  module release it with TF_DeleteBuffer.
  */
  @nogc nothrow @system
  TF_Buffer* serialize()
  {
    auto buffer = TF_NewBuffer;
    TF_GraphToGraphDef(this.ptr, buffer, this.status);
    assertStatus(this.status);
    return buffer;
  }

  /**
  Reads serialized bytes (GraphDef proto) from file and imports them.

  Params:
    fileName = zero-terminated path of the file to read.
    block = number of bytes requested per fread call (must be positive).
  */
  @nogc nothrow @trusted
  void read(const(char)* fileName, size_t block = 1024)
  {
    import core.stdc.stdio : fclose, feof, ferror, fopen, fread;
    import core.stdc.stdlib : free, realloc;

    // A zero block would never reach EOF.
    assert(block > 0, "block must be positive");

    size_t len;
    void* data = realloc(null, block);
    assert(data, "realloc failed");
    scope (exit) free(data);

    auto fp = fopen(fileName, "rb");
    assert(fp, "fopen failed");
    scope (exit) fclose(fp);

    // Stop on a read error as well as EOF; feof alone would spin forever.
    while (!feof(fp) && !ferror(fp))
    {
      // Append after the bytes already read instead of overwriting them.
      len += fread(cast(ubyte*) data + len, 1, block, fp);
      // Keep `block` bytes of free space for the next read.
      void* grown = realloc(data, len + block);
      assert(grown, "realloc failed");
      data = grown;
    }
    this.deserialize(data[0 .. len]);
  }

  /**
  Writes serialized bytes (GraphDef proto) to a given file.

  Params:
    fileName = zero-terminated path of the file to write.
  */
  @nogc nothrow @trusted
  void write(const(char)* fileName)
  {
    import core.stdc.stdio : fclose, fopen, fwrite;

    auto buffer = this.serialize();
    scope (exit) TF_DeleteBuffer(buffer);
    auto fp = fopen(fileName, "wb");
    assert(fp, "fopen failed");
    // Closing the stream flushes buffered data and releases the handle.
    scope (exit) fclose(fp);
    const written = fwrite(buffer.data, 1, buffer.length, fp);
    assert(written == buffer.length, "fwrite failed");
  }
}


/// Shared GraphOwner type.
struct Graph
{
  import tfd.session : Session;
  import tfd.tensor : tfType;
  import tfd.op : Operation;

  /// Base reference counted pointer.
  SlimRCPtr!GraphOwner base;
  alias base this;

  /// Returns true if this graph has an operation with the given name.
  @nogc nothrow @trusted
  bool hasOperationByName(const(char)* name)
  {
    auto opr = TF_GraphOperationByName(this.ptr, name);
    return opr !is null;
  }

  /// Gets an operation by name. Asserts that the operation exists.
  @nogc nothrow @trusted
  Operation getOperationByName(const(char)* name)
  {
    auto opr = TF_GraphOperationByName(this.ptr, name);
    assert(opr);
    return Operation(TF_Output(opr, 0), this);
  }

  /**
  Creates a placeholder in this graph.

  Params:
    name = zero-terminated operation name.
    dims = shape of the placeholder; the "shape" attribute is only set
           when at least one dimension is given.
  */
  @trusted
  Operation placeholder(T, size_t N)(
      const(char)* name,
      long[N] dims...) scope return
  {
    TF_OperationDescription* desc = TF_NewOperation(this.ptr, "Placeholder", name);
    TF_SetAttrType(desc, "dtype", tfType!T);
    static if (N != 0)
    {
      TF_SetAttrShape(desc, "shape", dims.ptr, dims.length);
    }
    TF_Operation* op = TF_FinishOperation(desc, this.status);
    assertStatus(this.status);
    assert(op);
    return Operation(TF_Output(op, 0), this.base);
  }

  /// ditto.
  Operation placeholder(T, size_t N)(long[N] dims ...)
  {
    return placeholder!T("", dims);
  }

  /**
  Creates a constant in this graph.

  Params:
    x = value converted with makeTF_Tensor and stored as the "value" attribute.
    name = zero-terminated operation name.
  */
  @trusted
  Operation constant(S)(S x, const(char)* name = "const")
  {
    import tfd.tensor : makeTF_Tensor;

    // TODO(karita) free TF_Tensor when op is freed?
    auto t = x.makeTF_Tensor;
    TF_OperationDescription* desc = TF_NewOperation(this.ptr, "Const", name);
    TF_SetAttrTensor(desc, "value", t, this.status);
    assertStatus(this.status);
    TF_SetAttrType(desc, "dtype", TF_TensorType(t));
    TF_Operation* op = TF_FinishOperation(desc, this.status);
    assertStatus(this.status);
    assert(op !is null);
    return Operation(TF_Output(op, 0), this);
  }

  /// Creates a Session in this graph.
  @nogc nothrow @trusted
  Session session()
  {
    return Session(this.ptr);
  }
}

/// Creates a new reference-counted Graph object.
@nogc nothrow @trusted
Graph newGraph()
{
  return Graph(createSlimRC!GraphOwner(TF_NewGraph(), TF_NewStatus()));
}

/// Export/import graph.
version (tfd_test)
nothrow
unittest
{
  import tfd.tensor;

  TF_Buffer* buffer;
  scope (exit) TF_DeleteBuffer(buffer);
  {
    auto graph = newGraph;
    with (graph)
    {
      auto a = placeholder!int("a");
      assert(TF_GraphOperationByName(graph, "a"));
      auto b = constant(3, "b");
      assert(TF_GraphOperationByName(graph, "b"));
      // TODO(karita): provide name "add", identity?
      auto add = a + b;
      assert(TF_GraphOperationByName(graph, "add"));
    }
    buffer = graph.serialize;
    // for coverage
    graph.write("tmp.bin");
  }
  with (newGraph) {
    // Import from the GraphDef (protobuf)
    deserialize(buffer.data[0 .. buffer.length]);
    auto a = getOperationByName("a");
    auto add = getOperationByName("add");
    const t = session.run([add], [a: 1.tensor])[0].tensor;
    assert(t.scalar!int == 1 + 3);
  }
}