/// TF Session module.
module tfd.session;

import tfd.c_api;
import tfd.tensor : Tensor;
import tfd.testing : assertStatus;


/// Owning wrapper around a TF_Session handle.
///
/// The session and its status object are released by the destructor, and
/// copying is disabled so exactly one instance performs that cleanup.
struct Session
{
    import tfd.tensor : Tensor, TensorOwner;
    import tfd.op : Operation;

    /// Underlying TF_Session handle; reset to null once the session is closed.
    TF_Session* base;
    /// Forwards unknown members to the raw handle.
    alias base this;

    /// Status object passed to every C API call made through this session.
    TF_Status* status;

    /// Copying is disabled: the destructor deletes `base` and `status`,
    /// so a copy would release them twice.
    @disable this(this);

    /// Creates a session that executes `graph`.
    ///
    /// Params:
    ///     graph = graph the new session runs.
    ///     useXLA = must be false; XLA compilation is not implemented yet.
    @nogc nothrow @trusted
    this(scope TF_Graph* graph, bool useXLA = false)
    {
        // TODO(karita): support XLA
        assert(!useXLA, "XLA is not supported yet.");
        status = TF_NewStatus();

        // Session options are only needed while creating the session.
        TF_SessionOptions* sessionOptions = TF_NewSessionOptions();
        scope (exit) TF_DeleteSessionOptions(sessionOptions);
        // TF_EnableXLACompilation(sessionOptions, useXLA);

        base = TF_NewSession(graph, sessionOptions, status);
        assertStatus(status);
    }

    /// Closes the session if it is still open, then frees the status object.
    @nogc nothrow @trusted
    ~this()
    {
        close();
        TF_DeleteStatus(status);
    }

    /// Closes and deletes the underlying session explicitly.
    /// Calling it again after the session is closed does nothing.
    @nogc nothrow @trusted
    void close()
    {
        // Already closed (or never opened): nothing to release.
        if (base is null)
            return;

        TF_CloseSession(base, status);
        assertStatus(status);
        TF_DeleteSession(base, status);
        assertStatus(status);
        base = null;
    }

    /// Executes the graph: feeds `inputValues` into `inputs`, then stores the
    /// values fetched for `outputs` into `outputValues`.
61 @nogc nothrow @trusted 62 void run(Operation[] inputs, Tensor[] inputValues, 63 Operation[] outputs, Tensor[] outputValues, 64 Operation[] targets = []) 65 { 66 import std.container.array : Array; 67 import std.range : empty; 68 69 import mir.rc.slim_ptr : createSlimRC; 70 71 assert(inputs.length == inputValues.length); 72 assert(outputs.length == outputValues.length); 73 74 Array!TF_Output baseInputs; 75 baseInputs.reserve(inputs.length); 76 foreach (x; inputs) 77 { 78 baseInputs ~= x.base; 79 } 80 81 Array!TF_Output baseOutputs; 82 baseInputs.reserve(outputs.length); 83 foreach (x; outputs) 84 { 85 baseOutputs ~= x.base; 86 } 87 88 Array!(TF_Tensor*) baseInputValues; 89 baseInputValues.reserve(inputValues.length); 90 foreach (x; inputValues) 91 { 92 baseInputValues ~= x.base; 93 } 94 95 Array!(TF_Tensor*) baseOutputValues; 96 baseOutputValues.length = outputValues.length; 97 98 Array!(TF_Operation*) baseTargets; 99 baseInputs.reserve(targets.length); 100 foreach (x; targets) 101 { 102 baseTargets ~= x.base.oper; 103 } 104 105 TF_SessionRun( 106 this.base, null, 107 inputs.empty ? null : &baseInputs[0], &baseInputValues[0], cast(int) inputs.length, 108 outputs.empty ? null : &baseOutputs[0], &baseOutputValues[0], cast(int) outputs.length, 109 targets.empty ? null : &baseTargets[0], cast(int) targets.length, 110 null, this.status); 111 assertStatus(this.status); 112 113 foreach (i; 0 .. outputs.length) 114 { 115 outputValues[i] = createSlimRC!TensorOwner(baseOutputValues[i]); 116 } 117 } 118 119 /// Runs in python-like usage. 
120 nothrow @trusted 121 Tensor[N] run(size_t N)(Operation[N] outputs, Tensor[Operation] inputs) 122 { 123 Tensor[N] ret; 124 this.run(inputs.keys, inputs.values, outputs[], ret[]); 125 return ret; 126 } 127 128 } 129 130 131 /// nothrow, nogc, and safe usage 132 version (tfd_test) 133 @nogc nothrow @safe 134 unittest 135 { 136 import std.typecons : tuple; 137 import tfd.tensor : tensor, Tensor; 138 import tfd.graph : newGraph; 139 import tfd.op : Operation; 140 141 with (newGraph) 142 { 143 Operation x = placeholder!int("x"); 144 Operation two = constant(2); 145 Operation add = x + two; 146 147 Operation[1] inops; 148 inops[0] = x; 149 Tensor[1] inputs; 150 inputs[0] = 3.tensor; 151 Operation[1] outops; 152 outops[0] = add; 153 Tensor[1] outputs; 154 session.run(inops, inputs, outops, outputs); 155 assert(outputs[0].scalar!int == 5); 156 157 write("tmp.pb"); 158 } 159 with (newGraph) 160 { 161 read("tmp.pb"); 162 // auto x = operationByName("x"); 163 // auto add = operationByName("add"); 164 } 165 } 166 167 /// TODO(karita): more interesting example. e.g., logistic regression. 168 version (tfd_test) 169 unittest 170 { 171 import tfd; 172 173 /// scalar add 174 with (newGraph) 175 { 176 Operation x = placeholder!int("x"); 177 Operation two = constant(2); 178 Operation add = x + two; 179 180 Tensor addVal = session.run([add], [x: 3.tensor])[0]; 181 assert(addVal.scalar!int == 5); 182 } 183 184 /// tensor add 185 with (newGraph) 186 { 187 import mir.ndslice : as, iota; 188 189 auto i = iota(2, 3, 4).as!float; 190 191 Operation x = placeholder!float("x", 2, 3, 4); 192 Operation two = constant(i); 193 Operation add = x + two; 194 195 Tensor addVal = session.run([add], [x: i.tensor])[0]; 196 assert(addVal.sliced!(float, 3) == i * 2); 197 } 198 }