1 /// TF Session module.
2 module tfd.session;
3 
4 import tfd.c_api;
5 import tfd.tensor : Tensor;
6 import tfd.testing : assertStatus;
7 
8 
/// Owning wrapper for `TF_Session`.
///
/// A `Session` owns the underlying `TF_Session` together with a `TF_Status`
/// that is reused by every TF C API call it makes; both are released by the
/// destructor. Each call's status is checked with `assertStatus`.
struct Session
{
  import tfd.tensor : Tensor, TensorOwner;
  import tfd.op : Operation;

  /// Raw session data; `null` once the session has been closed.
  TF_Session* base;
  alias base this;

  /// Status object shared by all TF C API calls made through this session.
  TF_Status* status;

  /// Not copyable: `base` and `status` are owned by exactly one `Session`.
  @disable this(this);

  /// Constructs a new session executing `graph`.
  ///
  /// Params:
  ///   graph = graph to run in this session.
  ///   useXLA = must be `false`; XLA compilation is not supported yet.
  @nogc nothrow @trusted
  this(scope TF_Graph* graph, bool useXLA = false)
  {
    // TODO(karita): support XLA
    assert(!useXLA, "XLA is not supported yet.");
    this.status = TF_NewStatus();
    TF_SessionOptions* opts = TF_NewSessionOptions();
    scope (exit) TF_DeleteSessionOptions(opts);
    // TF_EnableXLACompilation(opts, useXLA);
    this.base = TF_NewSession(graph, opts, this.status);
    assertStatus(this.status);
  }

  /// Closes the session if it is still open, then deletes the status.
  @nogc nothrow @trusted
  ~this()
  {
    this.close();
    // `status` is still null for a default-initialized `Session`.
    if (this.status !is null)
    {
      TF_DeleteStatus(this.status);
      this.status = null;
    }
  }

  /// Closes and deletes the underlying `TF_Session` explicitly.
  ///
  /// Does nothing when the session is already closed (`base` is null), so
  /// it may be called before the destructor runs.
  @nogc nothrow @trusted
  void close()
  {
    if (base !is null)
    {
      TF_CloseSession(base, this.status);
      assertStatus(this.status);
      TF_DeleteSession(base, this.status);
      assertStatus(this.status);
      base = null;
    }
  }

  /// Runs the session to evaluate `outputs` with `inputs` fed by `inputValues`.
  ///
  /// Params:
  ///   inputs = operations to feed; paired element-wise with `inputValues`.
  ///   inputValues = tensors fed to `inputs`; must have the same length.
  ///   outputs = operations whose values are fetched.
  ///   outputValues = receives the fetched tensor for each of `outputs`;
  ///                  must have the same length as `outputs`.
  ///   targets = operations executed without fetching their values.
  @nogc nothrow @trusted
  void run(Operation[] inputs, Tensor[] inputValues,
           Operation[] outputs, Tensor[] outputValues,
           Operation[] targets = [])
  {
    import std.container.array : Array;
    import std.range : empty;

    import mir.rc.slim_ptr : createSlimRC;

    assert(this.base !is null, "Session is closed.");
    assert(inputs.length == inputValues.length,
           "inputs and inputValues must have the same length.");
    assert(outputs.length == outputValues.length,
           "outputs and outputValues must have the same length.");

    Array!TF_Output baseInputs;
    baseInputs.reserve(inputs.length);
    foreach (ref x; inputs)
    {
      baseInputs ~= x.base;
    }

    Array!TF_Output baseOutputs;
    baseOutputs.reserve(outputs.length);
    foreach (ref x; outputs)
    {
      baseOutputs ~= x.base;
    }

    Array!(TF_Tensor*) baseInputValues;
    baseInputValues.reserve(inputValues.length);
    foreach (ref x; inputValues)
    {
      baseInputValues ~= x.base;
    }

    // Filled in by TF_SessionRun.
    Array!(TF_Tensor*) baseOutputValues;
    baseOutputValues.length = outputValues.length;

    Array!(TF_Operation*) baseTargets;
    baseTargets.reserve(targets.length);
    foreach (ref x; targets)
    {
      baseTargets ~= x.base.oper;
    }

    // `&array[0]` fails the bounds check of an empty `Array`, so every group
    // that is empty (e.g. no inputs/outputs in a targets-only run) is passed
    // as a null pointer together with a zero count.
    const bool noInputs = inputs.empty;
    const bool noOutputs = outputs.empty;
    const bool noTargets = targets.empty;
    TF_SessionRun(
        this.base, null,
        noInputs ? null : &baseInputs[0],
        noInputs ? null : &baseInputValues[0],
        cast(int) inputs.length,
        noOutputs ? null : &baseOutputs[0],
        noOutputs ? null : &baseOutputValues[0],
        cast(int) outputs.length,
        noTargets ? null : &baseTargets[0],
        cast(int) targets.length,
        null, this.status);
    assertStatus(this.status);

    // Wrap each fetched TF_Tensor* into a Tensor via createSlimRC.
    foreach (i; 0 .. outputs.length)
    {
      outputValues[i] = createSlimRC!TensorOwner(baseOutputValues[i]);
    }
  }

  /// Runs in python-like usage: `session.run([a, b], [x: xValue])`.
  ///
  /// Params:
  ///   outputs = operations to fetch.
  ///   inputs = maps each fed operation to its tensor value.
  /// Returns: the fetched tensors, in the same order as `outputs`.
  nothrow @trusted
  Tensor[N] run(size_t N)(Operation[N] outputs, Tensor[Operation] inputs)
  {
    // Collect keys and values in one traversal so each operation stays
    // paired with its own tensor; separate `.keys` / `.values` calls are
    // not documented to return matching orders.
    Operation[] feedOps;
    Tensor[] feedValues;
    feedOps.reserve(inputs.length);
    feedValues.reserve(inputs.length);
    foreach (kv; inputs.byKeyValue)
    {
      feedOps ~= kv.key;
      feedValues ~= kv.value;
    }

    Tensor[N] ret;
    this.run(feedOps, feedValues, outputs[], ret[]);
    return ret;
  }

}
129 
130 
/// `@nogc nothrow @safe` usage of the slice-based `Session.run`, followed by
/// `write`/`read` calls on the file `tmp.pb`.
version (tfd_test)
@nogc nothrow @safe
unittest
{
  import tfd.tensor : tensor, Tensor;
  import tfd.graph : newGraph;
  import tfd.op : Operation;

  with (newGraph)
  {
    Operation x = placeholder!int("x");
    Operation two = constant(2);
    Operation add = x + two;

    // Static arrays (not array literals) are used so that no GC allocation
    // happens inside this @nogc block.
    Operation[1] inops;
    inops[0] = x;
    Tensor[1] inputs;
    inputs[0] = 3.tensor;
    Operation[1] outops;
    outops[0] = add;
    Tensor[1] outputs;
    session.run(inops, inputs, outops, outputs);
    assert(outputs[0].scalar!int == 5);

    // NOTE(review): tmp.pb is left in the working directory; nothing here
    // removes it.
    write("tmp.pb");
  }
  with (newGraph)
  {
    read("tmp.pb");
    // NOTE(review): the restored graph is not checked yet; the lookups
    // below are disabled.
    // auto x = operationByName("x");
    // auto add = operationByName("add");
  }
}
166 
// TODO(karita): more interesting example. e.g., logistic regression.
/// Python-like `Session.run` usage: fetch a list of operations while feeding
/// placeholders through an associative array.
version (tfd_test)
unittest
{
  import tfd;

  // Scalar case: feeding 3 into `x + 2` must give 5.
  with (newGraph)
  {
    Operation xIn = placeholder!int("x");
    Operation twoOp = constant(2);
    Operation sumOp = xIn + twoOp;

    auto fetched = session.run([sumOp], [xIn: 3.tensor]);
    assert(fetched[0].scalar!int == 5);
  }

  // Tensor case: a 2x3x4 float placeholder plus a constant holding the same
  // iota values must equal those values doubled.
  with (newGraph)
  {
    import mir.ndslice : as, iota;

    auto ramp = iota(2, 3, 4).as!float;

    Operation xIn = placeholder!float("x", 2, 3, 4);
    Operation rampOp = constant(ramp);
    Operation sumOp = xIn + rampOp;

    auto fetched = session.run([sumOp], [xIn: ramp.tensor]);
    assert(fetched[0].sliced!(float, 3) == ramp * 2);
  }
}