/// TF Session module.
module tfd.session;

import tfd.c_api;
import tfd.graph : Operation;
import tfd.tensor : Tensor;
import tfd.testing : assertStatus;


/// Non-copyable wrapper struct owning a TF_Session together with the
/// TF_Status used to check the result of every session call.
struct Session
{
    import tfd.tensor : Tensor, TensorOwner;
    import tfd.graph : Operation;

    /// Raw session data.
    TF_Session* base;
    alias base this;

    /// Status object reused by every TF call made through this session.
    TF_Status* status;

    /// Not copyable: a copy would close/delete the same TF_Session twice.
    @disable this(this);

    /// Constructs a new session on `graph`.
    ///
    /// Params:
    ///   graph = graph whose operations the session will run.
    ///   useXLA = must be false; XLA is not supported yet (asserted).
    @nogc nothrow @trusted
    this(scope TF_Graph* graph, bool useXLA = false)
    {
        // TODO(karita): support XLA
        assert(!useXLA, "XLA is not supported yet.");
        this.status = TF_NewStatus();
        TF_SessionOptions* opts = TF_NewSessionOptions();
        scope (exit) TF_DeleteSessionOptions(opts);
        // TF_EnableXLACompilation(opts, useXLA);
        this.base = TF_NewSession(graph, opts, this.status);
        assertStatus(this.status);
    }

    /// Closes the session (if still open) and releases the status object.
    @nogc nothrow @trusted
    ~this()
    {
        this.close();
        // A default-initialised Session never allocated a status.
        if (this.status !is null)
        {
            TF_DeleteStatus(this.status);
            this.status = null;
        }
    }

    /// Explicitly closes and deletes the underlying TF_Session.
    /// Safe to call more than once: `base` is reset to null afterwards.
    @nogc nothrow @trusted
    void close()
    {
        if (base !is null)
        {
            TF_CloseSession(base, this.status);
            assertStatus(this.status);
            TF_DeleteSession(base, this.status);
            assertStatus(this.status);
            base = null;
        }
    }

    /// Runs the session, feeding `inputValues` into `inputs` and fetching
    /// `outputs` into `outputValues`.
    ///
    /// Params:
    ///   inputs = operations to feed (each is wrapped as `TF_Output(op.base)`).
    ///   inputValues = tensors fed to `inputs`; must have the same length.
    ///   outputs = operations to fetch (each is wrapped as `TF_Output(op.base)`).
    ///   outputValues = receives the fetched tensors; must have the same
    ///                  length as `outputs`.
    ///   targets = operations to run without fetching their outputs.
    ///
    /// Any of the lists may be empty.
    @nogc nothrow @trusted
    void run(Operation[] inputs, Tensor[] inputValues,
             Operation[] outputs, Tensor[] outputValues,
             Operation[] targets = [])
    {
        import std.container.array : Array;
        import std.range : empty;

        import mir.rc.slim_ptr : createSlimRC;

        assert(inputs.length == inputValues.length);
        assert(outputs.length == outputValues.length);

        Array!TF_Output baseInputs;
        baseInputs.reserve(inputs.length);
        foreach (x; inputs)
        {
            baseInputs ~= TF_Output(x.base);
        }

        Array!TF_Output baseOutputs;
        baseOutputs.reserve(outputs.length);
        foreach (x; outputs)
        {
            baseOutputs ~= TF_Output(x.base);
        }

        Array!(TF_Tensor*) baseInputValues;
        baseInputValues.reserve(inputValues.length);
        foreach (x; inputValues)
        {
            baseInputValues ~= x.base;
        }

        // Filled in by TF_SessionRun.
        Array!(TF_Tensor*) baseOutputValues;
        baseOutputValues.length = outputValues.length;

        Array!(TF_Operation*) baseTargets;
        baseTargets.reserve(targets.length);
        foreach (x; targets)
        {
            baseTargets ~= x.base;
        }

        // Indexing an empty Array is out of bounds, so pass null for every
        // empty list instead of taking the address of element 0.
        TF_SessionRun(
            this.base, null,
            inputs.empty ? null : &baseInputs[0],
            inputs.empty ? null : &baseInputValues[0],
            cast(int) inputs.length,
            outputs.empty ? null : &baseOutputs[0],
            outputs.empty ? null : &baseOutputValues[0],
            cast(int) outputs.length,
            targets.empty ? null : &baseTargets[0],
            cast(int) targets.length,
            null, this.status);
        assertStatus(this.status);

        foreach (i; 0 .. outputs.length)
        {
            outputValues[i] = createSlimRC!TensorOwner(baseOutputValues[i]);
        }
    }

    /// Runs in python-like usage, e.g. `session.run([add], [x: 3.tensor])`.
    ///
    /// Params:
    ///   outputs = operations to fetch.
    ///   inputs = map from fed operation to the tensor fed into it.
    /// Returns: the fetched tensors, in the order of `outputs`.
    nothrow @trusted
    Tensor[N] run(size_t N)(Operation[N] outputs, Tensor[Operation] inputs)
    {
        // Build both feed lists in a single pass so each operation is
        // guaranteed to stay paired with its own tensor.
        Operation[] feedOps;
        Tensor[] feedValues;
        feedOps.reserve(inputs.length);
        feedValues.reserve(inputs.length);
        foreach (pair; inputs.byKeyValue)
        {
            feedOps ~= pair.key;
            feedValues ~= pair.value;
        }

        Tensor[N] ret;
        this.run(feedOps, feedValues, outputs[], ret[]);
        return ret;
    }

}


/// nothrow, nogc, and safe usage
version (tfd_test)
@nogc nothrow @safe
unittest
{
    import std.typecons : tuple;
    import tfd.tensor : tensor, Tensor;
    import tfd.graph : newGraph, Operation;

    with (newGraph)
    {
        Operation x = placeholder!int("x");
        Operation two = constant(2);
        Operation add = x + two;

        Operation[1] inops;
        inops[0] = x;
        Tensor[1] inputs;
        inputs[0] = 3.tensor;
        Operation[1] outops;
        outops[0] = add;
        Tensor[1] outputs;
        session.run(inops, inputs, outops, outputs);
        assert(outputs[0].scalar!int == 5);

        write("tmp.pb");
    }
    with (newGraph)
    {
        read("tmp.pb");
        // auto x = operationByName("x");
        // auto add = operationByName("add");
    }
}

/// TODO(karita): more interesting example. e.g., logistic regression.
version (tfd_test)
unittest
{
    import tfd;

    /// scalar add
    with (newGraph)
    {
        Operation x = placeholder!int("x");
        Operation two = constant(2);
        Operation add = x + two;

        Tensor addVal = session.run([add], [x: 3.tensor])[0];
        assert(addVal.scalar!int == 5);
    }

    /// tensor add
    with (newGraph)
    {
        import mir.ndslice : as, iota;

        auto i = iota(2, 3, 4).as!float;

        Operation x = placeholder!float("x", 2, 3, 4);
        Operation two = constant(i);
        Operation add = x + two;

        Tensor addVal = session.run([add], [x: i.tensor])[0];
        assert(addVal.sliced!(float, 3) == i * 2);
    }

    /// targets only: no inputs and no fetched outputs must not index empty buffers
    with (newGraph)
    {
        Operation two = constant(2);
        session.run(null, null, null, null, [two]);
    }
}