1 // -*- c-basic-style: google, c-basic-offset: 2 -*-
2 
/** D wrappers for protobuf-c message structs whose D declarations are generated by dpp.
    https://github.com/protobuf-c/protobuf-c
    https://github.com/atilaneves/dpp
*/
7 module tfd.protobuf;
8 
9 
/// Opaque protobuf-c type; only referenced through a pointer below, so its
/// layout does not need to be declared here.
struct ProtobufCMessageDescriptor;

/// Opaque protobuf-c type; only referenced through a pointer below.
struct ProtobufCMessageUnknownField;

/** Header that every protobuf-c message struct embeds as its `base` field.

    `isMessage` recognizes message structs by the *name* of this type, so it
    must keep the name `ProtobufCMessage`.
    NOTE(review): field order and types are assumed to mirror the C struct in
    protobuf-c.h; confirm against that header before changing anything here.
*/
struct ProtobufCMessage
{
  const(ProtobufCMessageDescriptor)* descriptor;
  // presumably the number of entries in `unknown_fields` (n_<name> convention)
  uint n_unknown_fields;
  ProtobufCMessageUnknownField* unknown_fields;
}
20 
/** Returns true when `T` looks like a protobuf-c message struct, i.e. it has
    a member `base` whose (unqualified) type is named `ProtobufCMessage`.

    A single pointer to a message struct also qualifies, because member access
    through one level of pointer auto-dereferences in D (`T.init.base`
    compiles for `Msg*`, but not for `Msg**`). Type qualifiers are ignored, so
    `const(Msg)` and `immutable(Msg)*` are recognized too.
*/
bool isMessage(T)() {
  import std.traits : Unqual;

  static if (__traits(compiles, T.init.base))
    // name-based comparison to avoid module name mangling: a ProtobufCMessage
    // declared in any module matches. Unqual drops const/immutable so that
    // qualified messages are not rejected.
    return Unqual!(typeof(T.init.base)).stringof == "ProtobufCMessage";
  else
    return false;
}
28 
/**
 * Mixes in an `opDispatch` that reads member `name` of the protobuf-c struct
 * `base` and returns it in a D-friendly form:
 *
 * - pointer to a message struct: a `Message!T` holding a copy of the pointee,
 *   or a default `Message!T()` when the pointer is null;
 * - pointer `name` with a sibling count member `n_<name>`: an array of
 *   `n_<name>` elements. When the elements are message pointers this is a
 *   newly allocated `Message!T[]` (null elements stay default-initialized);
 *   otherwise it is a slice directly over the C buffer (no copy);
 * - anything `fromStringz` accepts (C strings): a slice over the C string;
 * - any other member: returned unchanged.
 */
mixin template OpDispatchMixin(alias base)
{
  auto opDispatch(string name)()
  {
    import std.string : fromStringz;
    import std.traits : hasMember, isPointer, PointerTarget;

    alias Outer = typeof(base);
    enum countName = "n_" ~ name;

    auto field = __traits(getMember, base, name);
    alias Field = typeof(field);

    static if (isPointer!Field && isMessage!Field)
    {
      // single nested message: wrap the pointee, empty wrapper when null
      alias Wrapped = Message!(PointerTarget!Field);
      if (field is null)
        return Wrapped();
      return Wrapped(*field);
    }
    else static if (isPointer!Field && hasMember!(Outer, countName))
    {
      // counted array: `field` points at `count` consecutive elements
      auto count = __traits(getMember, base, countName);
      alias Item = typeof(*Field.init);

      static if (isPointer!Item && isMessage!Item)
      {
        // array of message pointers: build a fresh array of wrappers
        alias Wrapped = Message!(PointerTarget!Item);
        auto wrapped = new Wrapped[count];
        foreach (idx; 0 .. count)
        {
          auto item = field[idx];
          if (item !is null)
            wrapped[idx] = Wrapped(*item);
        }
        return wrapped;
      }
      else
      {
        // plain elements: expose the C buffer as a slice
        return field[0 .. count];
      }
    }
    else static if (__traits(compiles, fromStringz(field)))
    {
      // NUL-terminated C string
      return fromStringz(field);
    }
    else
    {
      return field;
    }
  }
}
83 
/**
 * D-friendly wrapper around a protobuf-c message struct `Base`.
 *
 * Member reads go through the mixed-in `opDispatch`, which converts C
 * strings, counted arrays and nested message pointers into D values.
 * `alias base this` lets a `Message!Base` implicitly convert to the raw
 * `Base`.
 */
struct Message(Base)
{
  /// The wrapped protobuf-c struct, held by value.
  Base base;
  alias base this;

  mixin OpDispatchMixin!base;

  /// Debug dump of the message via `toStringImpl`. TensorProto and AttrValue
  /// only print their type name; presumably because they hold union members
  /// that `toStringImpl` cannot print yet (TODO confirm).
  string toString()
  {
    import tensorflow.op_def_pb : Tensorflow__TensorProto, Tensorflow__AttrValue;

    enum bool namesOnly = is(Base == Tensorflow__TensorProto)
      || is(Base == Tensorflow__AttrValue);

    static if (namesOnly)
    {
      return Base.stringof;
    }
    else
    {
      return toStringImpl!(Base)(this.base);
    }
  }
}
100 
/**
 * Renders a protobuf-c message struct as `TypeName {field: value, ...}` for
 * debugging.
 *
 * Members are printed in declaration order. The following are skipped:
 * `base` (the embedded protobuf-c header), names starting with `_`, and
 * members that are not values (e.g. nested type declarations).
 * Values are read through `Message!Base.opDispatch`, so C strings, counted
 * arrays and nested messages appear in their converted form. String values
 * are wrapped in double quotes (not escaped).
 *
 * Params:
 *   base = the message struct to print.
 * Returns: the formatted text, e.g. `Child {name: "child"}`; a struct with no
 *   printable members gives `TypeName {}`.
 */
string toStringImpl(Base)(Base base) if (isMessage!Base)
{
  import std.array : join;
  import std.conv : text;
  import std.traits : isSomeString;

  string[] fields;
  static foreach (name; __traits(allMembers, Base))
  {
    // static foreach does not open a new scope per iteration; the inner
    // braces do, so `child` can be declared once per member.
    {
      static if (name == "base" || name[0] == '_')
      {
        // TODO(karita): support union value
      }
      else static if (is(typeof(__traits(getMember, base, name))))
      {
        auto child = Message!Base(base).opDispatch!name;
        alias Child = typeof(child);

        static if (is(Child : Message!T, T))
          fields ~= name ~ ": " ~ child.toString;
        else static if (isSomeString!Child)
          fields ~= name ~ ": \"" ~ child.text ~ "\"";
        else
          fields ~= name ~ ": " ~ child.text;
      }
    }
  }
  return Base.stringof ~ " {" ~ fields.join(", ") ~ "}";
}
147 
148 
/// protobuf-c message wrapper
unittest
{
  import core.stdc.config : c_ulong;

  // protobuf-c + dpp generates structs like these
  struct Child
  {
    ProtobufCMessage base;
    char* name;
  }

  struct Sample
  {
    ProtobufCMessage base;
    c_ulong n_farray;
    float* farray;
    char* name;
    Child* child;
    c_ulong n_children;
    Child** children;
  }

  // static tests
  static assert(isMessage!Child);
  static assert(isMessage!Sample);
  static assert(!isMessage!float);

  // example objects
  Child child = { name: cast(char*) "child" };
  Sample base = { n_farray: 2,
                  farray: [0.1f, 0.2f],
                  name: cast(char*) "foo\0",
                  child: &child,
                  n_children: 2,
                  children: [&child, &child]
  };

  // D friendly wrapper
  auto s = Message!Sample(base);

  // a default-initialized wrapper (null pointers, zero counts) prints safely
  assert(Message!Sample().toString ==
         `Sample {n_farray: 0, farray: [], name: "", `
         ~ `child: Child {name: ""}, n_children: 0, children: []}`);
  // building a wrapper from another wrapper gives the same view
  assert(Message!Sample(s).toString == s.toString);

  // basic type access
  assert(s.n_farray == 2);
  // string conversion
  assert(s.name == "foo");
  // array conversion
  assert(s.farray == [0.1f, 0.2f]);
  // nested message conversion
  assert(s.child.name == "child");
  // pretty print for debugging
  assert(s.toString ==
         `Sample {n_farray: 2, farray: [0.1, 0.2], name: "foo", `
         ~ `child: Child {name: "child"}, n_children: 2, `
         ~ `children: [Child {name: "child"}, Child {name: "child"}]}`);
}
208 
/// test with tf proto
unittest
{
  import std.algorithm.searching : startsWith;
  import tensorflow.op_def_pb;

  static assert(isMessage!Tensorflow__OpDef);
  static assert(isMessage!Tensorflow__OpDef__ArgDef);
  static assert(isMessage!Tensorflow__OpDef__AttrDef);

  // printing an empty OpDef must work; check its layout instead of writing
  // to stdout during the test run
  assert(Message!Tensorflow__OpDef().toString.startsWith("Tensorflow__OpDef {"));
}
220 
221 
/// Prints every op returned by `TF_GetAllOpList`, one
/// `Message!Tensorflow__OpDef` dump per op (debugging aid).
///
/// Throws: `Exception` if the op list cannot be unpacked.
void printAllTFOps()
{
  import std.exception : enforce;
  import std.stdio : writefln;
  import tensorflow.c_api;
  import tensorflow.op_def_pb;

  alias OpDef = Message!Tensorflow__OpDef;

  auto buf = TF_GetAllOpList();
  scope (exit) TF_DeleteBuffer(buf);

  // protobuf-c allocates the unpacked message tree; it must be released with
  // the matching generated free function (null = default allocator).
  auto opList = tensorflow__op_list__unpack(null, buf.length, cast(const(ubyte)*) buf.data);
  // enforce (unlike assert) still guards the dereference in -release builds
  enforce(opList !is null, "unpack failed");
  scope (exit) tensorflow__op_list__free_unpacked(opList, null);

  foreach (rawOp; opList.op[0 .. opList.n_op])
  {
    auto op = OpDef(*rawOp);
    writefln("\n%s\n", op);
  }
}
244