1 /// TF Session module.
2 module tfd.session;
3 
4 import tfd.c_api;
5 import tfd.graph : Operation;
6 import tfd.tensor : Tensor;
7 import tfd.testing : assertStatus;
8 
9 
/// Owning wrapper for a `TF_Session`.
///
/// The struct holds the raw session handle together with the `TF_Status`
/// object that receives the outcome of every C API call made through it.
/// It is not copyable; the destructor closes the session and deletes the
/// status object.
struct Session
{
  import tfd.tensor : Tensor, TensorOwner;
  import tfd.graph : Operation;

  /// Raw session data. Reset to `null` by `close`.
  TF_Session* base;
  alias base this;

  /// Status object passed to every C API call made by this wrapper.
  TF_Status* status;

  /// Not copyable
  @disable this(this);

  /// Constructs a new session for `graph`.
  ///
  /// Params:
  ///   graph = graph the session is created on.
  ///   useXLA = must be `false`; XLA is not supported yet.
  @nogc nothrow @trusted
  this(scope TF_Graph* graph, bool useXLA = false)
  {
    // TODO(karita): support XLA
    assert(!useXLA, "XLA is not supported yet.");
    this.status = TF_NewStatus();
    TF_SessionOptions* opts = TF_NewSessionOptions();
    // The options are only needed while the session is being created.
    scope (exit) TF_DeleteSessionOptions(opts);
    // TF_EnableXLACompilation(opts, useXLA);
    this.base = TF_NewSession(graph, opts, this.status);
    assertStatus(this.status);
  }

  /// Closes the session if it is still open, then deletes the status object.
  @nogc nothrow @trusted
  ~this()
  {
    this.close();
    TF_DeleteStatus(this.status);
    // Do not keep a dangling pointer around after the status is deleted.
    this.status = null;
  }

  /// Closes and deletes the underlying session explicitly.
  ///
  /// Does nothing when the session has already been closed (`base` is
  /// `null`), so calling it more than once is harmless.
  @nogc nothrow @trusted
  void close()
  {
    if (base !is null)
    {
      TF_CloseSession(base, this.status);
      assertStatus(this.status);
      TF_DeleteSession(base, this.status);
      assertStatus(this.status);
      // Mark as closed so that a later close() or the destructor is a no-op.
      base = null;
    }
  }

  /// Runs session to evaluate outputs by given inputs.
  ///
  /// Any of the three operation lists may be empty; an empty list is passed
  /// to `TF_SessionRun` as a `null` pointer with a count of zero.
  ///
  /// Params:
  ///   inputs = operations to feed.
  ///   inputValues = tensors fed to `inputs`; must have the same length.
  ///   outputs = operations to evaluate.
  ///   outputValues = receives one tensor per entry of `outputs`; must have
  ///                  the same length. Existing elements are overwritten.
  ///   targets = operations passed to `TF_SessionRun` as target operations.
  @nogc nothrow @trusted
  void run(Operation[] inputs, Tensor[] inputValues,
           Operation[] outputs, Tensor[] outputValues,
           Operation[] targets = [])
  {
    import std.container.array : Array;
    import std.range : empty;

    import mir.rc.slim_ptr : createSlimRC;

    assert(inputs.length == inputValues.length);
    assert(outputs.length == outputValues.length);

    // Unwrap everything into contiguous arrays of raw C API handles.
    Array!TF_Output baseInputs;
    baseInputs.reserve(inputs.length);
    foreach (x; inputs)
    {
      baseInputs ~= TF_Output(x.base);
    }

    Array!TF_Output baseOutputs;
    baseOutputs.reserve(outputs.length);
    foreach (x; outputs)
    {
      baseOutputs ~= TF_Output(x.base);
    }

    Array!(TF_Tensor*) baseInputValues;
    baseInputValues.reserve(inputValues.length);
    foreach (x; inputValues)
    {
      baseInputValues ~= x.base;
    }

    // Filled in by TF_SessionRun, hence sized rather than merely reserved.
    Array!(TF_Tensor*) baseOutputValues;
    baseOutputValues.length = outputValues.length;

    Array!(TF_Operation*) baseTargets;
    baseTargets.reserve(targets.length);
    foreach (x; targets)
    {
      baseTargets ~= x.base;
    }

    // An empty Array has no element 0 to take the address of, so every
    // pointer argument falls back to null when its list is empty.
    TF_SessionRun(
        this.base, null,
        inputs.empty ? null : &baseInputs[0],
        inputs.empty ? null : &baseInputValues[0],
        cast(int) inputs.length,
        outputs.empty ? null : &baseOutputs[0],
        outputs.empty ? null : &baseOutputValues[0],
        cast(int) outputs.length,
        targets.empty ? null : &baseTargets[0], cast(int) targets.length,
        null, this.status);
    assertStatus(this.status);

    // Wrap the raw result tensors and hand them back to the caller.
    foreach (i; 0 .. outputs.length)
    {
      outputValues[i] = createSlimRC!TensorOwner(baseOutputValues[i]);
    }
  }

  /// Runs in python-like usage.
  ///
  /// Params:
  ///   outputs = operations to evaluate.
  ///   inputs = map from fed operation to the tensor fed to it; may be empty.
  ///
  /// Returns: one tensor per entry of `outputs`, in the same order.
  nothrow @trusted
  Tensor[N] run(size_t N)(Operation[N] outputs, Tensor[Operation] inputs)
  {
    Tensor[N] ret;
    this.run(inputs.keys, inputs.values, outputs[], ret[]);
    return ret;
  }

}
130 
131 
/// Usage that stays within nothrow, @nogc and @safe code.
version (tfd_test)
@nogc nothrow @safe
unittest
{
  import std.typecons : tuple;
  import tfd.graph : newGraph, Operation;
  import tfd.tensor : tensor, Tensor;

  // Build `x + 2`, evaluate it with x = 3, then write the graph to a file.
  with (newGraph)
  {
    Operation xOp = placeholder!int("x");
    Operation twoOp = constant(2);
    Operation sumOp = xOp + twoOp;

    // Static arrays are used for the feed and fetch lists in this @nogc test.
    Operation[1] feedOps;
    Tensor[1] feedVals;
    Operation[1] fetchOps;
    Tensor[1] fetchVals;

    feedOps[0] = xOp;
    feedVals[0] = 3.tensor;
    fetchOps[0] = sumOp;

    session.run(feedOps, feedVals, fetchOps, fetchVals);
    assert(fetchVals[0].scalar!int == 5);

    write("tmp.pb");
  }

  // Read the file written above into another graph.
  with (newGraph)
  {
    read("tmp.pb");
    // auto x = operationByName("x");
    // auto add = operationByName("add");
  }
}
166 
/// Python-like usage of `Session.run`.
/// TODO(karita): more interesting example. e.g., logistic regression.
version (tfd_test)
unittest
{
  import tfd;

  // Scalar case: feed 3 into an int placeholder and add the constant 2.
  with (newGraph)
  {
    Operation xOp = placeholder!int("x");
    Operation twoOp = constant(2);
    Operation sumOp = xOp + twoOp;

    Tensor sumVal = session.run([sumOp], [xOp: 3.tensor])[0];
    assert(sumVal.scalar!int == 5);
  }

  // Tensor case: add a 2x3x4 float slice to itself through a placeholder.
  with (newGraph)
  {
    import mir.ndslice : as, iota;

    auto ramp = iota(2, 3, 4).as!float;

    Operation xOp = placeholder!float("x", 2, 3, 4);
    Operation rampOp = constant(ramp);
    Operation sumOp = xOp + rampOp;

    Tensor sumVal = session.run([sumOp], [xOp: ramp.tensor])[0];
    assert(sumVal.sliced!(float, 3) == ramp * 2);
  }
}