1 /**
2     This package facilitates the construction of various nodes in the operation graph.
3 
4     Authors: Henry Gouk
5 */
6 module dopt.core.ops;
7 
8 import std.array;
9 import std.exception;
10 import std.variant;
11 
12 import dopt.core.types;
13 
14 public
15 {
16     import dopt.core.ops.basic;
17     import dopt.core.ops.math;
18     import dopt.core.ops.nnet;
19     import dopt.core.ops.random;
20 }
21 
/**
Runs the initialize function of every operation submodule: basic, math, nnet and random, in that order.

Presumably each submodule registers its operations via registerOperation here. If so, calling this function twice
will throw, because registerOperation rejects duplicate names.
*/
void initialize()
{
    import basicOps = dopt.core.ops.basic;
    import mathOps = dopt.core.ops.math;
    import nnetOps = dopt.core.ops.nnet;
    import randomOps = dopt.core.ops.random;

    basicOps.initialize();
    mathOps.initialize();
    nnetOps.initialize();
    randomOps.initialize();
}
29 
/// Checks a newly constructed Operation; returning false makes the Operation constructor throw.
alias Verifier = bool delegate(Operation);
/// Computes the TensorType of the result of an Operation.
alias Judge = TensorType delegate(Operation);

/**
Holds the callbacks that implement procedures specific to the type of an operation.

An OpDef is stored under a name by registerOperation. When an Operation of that type is constructed, the verifier is
run on it and the judge computes its output type.
*/
struct OpDef
{
    /**
    A verifier is used to ensure that an Operation object is correctly constructed. It should return true when the
    operation is valid and false otherwise.
    */
    Verifier verifier;

    /**
    A judge produces a TensorType object that specifies the type of the result of an operation of this type.
    */
    Judge judge;
}
48 
/**
A node in the expression graph.

Each Operation records its type name, its operands (deps), a set of attributes, and the TensorType computed by the
judge registered for its type. Arithmetic operators are overloaded so that expressions such as `a + b`, `a * 2.0f`
and `-a` create new nodes in the graph.
*/
class Operation
{
    public
    {
        /**
        Returns a string identifying the type of this operation. This is the same string used when registering the
        operation with the registerOperation method.
        */
        @property string opType()
        {
            return mOpType;
        }

        /**
        Returns a TensorType object that specifies the type of tensor obtained by evaluating this operation.
        */
        @property TensorType outputType()
        {
            return mOutputType;
        }

        /**
        Returns a list of operands for this operation.
        */
        @property Operation[] deps()
        {
            return mDeps;
        }

        /**
        Returns an associative array that maps strings to operation specific attributes.
        */
        @property Variant[string] attributes()
        {
            return mAttributes;
        }

        /**
        Convenience method for pointwise operations.

        Internally, this just calls the appropriate function from dopt.core.ops.math. If exactly one operand has
        rank 0, it is first repeated and reshaped to the shape of the other operand.

        Params:
            rhs = The right hand operand.
            mod = Module of the call site; forwarded to the operations created here.
            line = Line of the call site; forwarded to the operations created here.
        */
        Operation opBinary(string op)(Operation rhs, string mod = __MODULE__, size_t line = __LINE__)
        {
            // Broadcast a rank-0 operand up to the shape of the other operand.
            if(rhs.rank == 0 && this.rank != 0)
            {
                return this.opBinary!op(rhs.repeat(this.volume, mod, line).reshape(this.shape, mod, line), mod, line);
            }
            else if(this.rank == 0 && rhs.rank != 0)
            {
                return this.repeat(rhs.volume, mod, line).reshape(rhs.shape, mod, line).opBinary!op(rhs, mod, line);
            }

            static if(op == "+")
            {
                return this.add(rhs, mod, line);
            }
            else static if(op == "-")
            {
                return this.sub(rhs, mod, line);
            }
            else static if(op == "*")
            {
                return this.mul(rhs, mod, line);
            }
            else static if(op == "/")
            {
                return this.div(rhs, mod, line);
            }
            else
            {
                static assert(0, "Unknown binary operation '" ~ op ~ "'");
            }
        }

        /**
        Applies a pointwise operation with an int operand. The operand is wrapped with int32Constant (presumably a
        rank-0 constant, which the Operation overload then broadcasts).
        */
        Operation opBinary(string op)(int i, string mod = __MODULE__, size_t line = __LINE__)
        {
            auto bc = int32Constant(i);

            return opBinary!op(bc, mod, line);
        }

        /**
        Applies a pointwise operation with a float operand. The operand is wrapped with float32Constant (presumably
        a rank-0 constant, which the Operation overload then broadcasts).
        */
        Operation opBinary(string op)(float i, string mod = __MODULE__, size_t line = __LINE__)
        {
            auto bc = float32Constant(i);

            return opBinary!op(bc, mod, line);
        }

        /**
        Supports expressions where a scalar is the left operand, e.g. `2.0f * x` or `1.0f - x`.

        `+` and `*` are forwarded to opBinary with the operands swapped. `-` and `/` are only supported for float
        left operands; any other combination is rejected at compile time.
        */
        Operation opBinaryRight(string op, T)(T t, string mod = __MODULE__, size_t line = __LINE__)
        {
            static if(op == "*" || op == "+")
            {
                // Commutative: reuse the left-hand overloads and keep the caller's location.
                return opBinary!op(t, mod, line);
            }
            else static if((op == "-" || op == "/") && is(T == float))
            {
                // Not commutative: the constant must be the left operand. Call opBinary explicitly so the caller's
                // location is preserved.
                return float32Constant(t).opBinary!op(this, mod, line);
            }
            else
            {
                static assert(0, "Not implemented.");
            }
        }

        /**
        Supports unary negation (`-x`), implemented with neg from dopt.core.ops.math.
        */
        Operation opUnary(string op)()
        {
            static if(op == "-")
            {
                return neg(this);
            }
            else
            {
                // The condition (0) is required; a lone string literal would always pass.
                static assert(0, "Unknown unary operation '" ~ op ~ "'");
            }
        }

        /**
        Returns a textual representation of the expression rooted at this operation, e.g. `add(x, y)`.
        Variables are printed as their address so that distinct variables produce distinct strings.
        */
        override string toString()
        {
            import std.algorithm : joiner, map;
            import std.conv : to;

            //If it's a variable, we should have some unique identifier
            if(opType == "variable")
            {
                //This is very ugly. Someone please come up with a better way.
                return to!string(cast(void *)this);
            }
            else
            {
                return opType ~ "(" ~ deps.map!(x => x.toString).joiner(", ").to!string ~ ")";
            }
        }

        /// Returns the DeviceBuffer most recently attached with setBuffer.
        DeviceBuffer value()
        {
            return mBuffer;
        }

        /// Attaches a DeviceBuffer to this operation; it is returned by value().
        void setBuffer(DeviceBuffer buf)
        {
            mBuffer = buf;
        }

        /// Shorthand for outputType.shape.
        auto shape()
        {
            return outputType.shape;
        }

        /// Shorthand for outputType.elementType.
        auto elementType()
        {
            return outputType.elementType;
        }

        /// Shorthand for outputType.volume.
        auto volume()
        {
            return outputType.volume;
        }

        /// Shorthand for outputType.rank.
        auto rank()
        {
            return outputType.rank;
        }
    }

    public
    {
        string mOpType;               /// Registered name of this operation's type.
        string mModule;               /// Module passed to the constructor (the creating call site).
        size_t mLine;                 /// Line passed to the constructor (the creating call site).
        Operation[] mDeps;            /// Operands; a copy of the array given to the constructor.
        Variant[string] mAttributes;  /// Attributes; a copy of the array given to the constructor.
        TensorType mOutputType;       /// Result type computed by the registered judge.
        DeviceBuffer mBuffer;         /// Buffer attached with setBuffer.

        /**
        Constructs an operation, verifies it with the OpDef registered under opType, and computes its output type
        with that OpDef's judge.

        Params:
            opType = Name the operation type was registered under.
            deps = Operands; the array is copied.
            attribs = Operation specific attributes; the associative array is copied.
            mod = Module of the creating call site, recorded and used in error messages.
            line = Line of the creating call site, recorded and used in error messages.

        Throws: Exception if opType is not registered or if the verifier rejects the operation.
        */
        this(string opType, Operation[] deps, Variant[string] attribs, string mod, size_t line)
        {
            import std.conv : to;

            // The constructor is public, so do not assume createOperation has already validated opType.
            auto def = opType in mOpDefs;
            enforce(def !is null,
                "Cannot create operation because there is no operation definition registered with the name '" ~
                opType ~ "'");

            mOpType = opType;
            mDeps = deps.array;
            mAttributes = attribs.dup;
            mModule = mod;
            mLine = line;

            enforce(def.verifier(this),
                "Operation of type \"" ~ opType ~ "\" failed verification. Instantiated at " ~ mod ~ ":" ~
                line.to!string);

            mOutputType = makeJudgement(this);
        }
    }
}
249 
/**
Registers an operation definition with the given identifier.

Params:
    name = Identifier for the operation type; createOperation looks definitions up by this name.
    def = The verifier and judge for operations of this type.

Throws: Exception if a definition is already registered under name.
*/
void registerOperation(string name, OpDef def)
{
    const bool alreadyRegistered = (name in mOpDefs) !is null;
    enforce(!alreadyRegistered, "There is already an operation registered with the name '" ~ name ~ "'");

    mOpDefs[name] = def;
}
259 
/**
Returns a list of identifiers for operations that have been registered so far.

Each call returns a newly allocated array, so callers may modify it. The order of the identifiers is unspecified.
*/
string[] listAllOperations()
{
    // An associative array's .keys already allocates a fresh array, so an extra .dup would only copy it again.
    return mOpDefs.keys;
}
267 
/**
Creates an operation of the given type, with the given dependencies and attributes.

Params:
    opType = Name the operation type was registered under (see registerOperation).
    deps = Operands of the new operation.
    attribs = Operation specific attributes.
    mod = Module of the call site; recorded on the operation and used in verification error messages.
    line = Line of the call site; recorded on the operation and used in verification error messages.

Returns: The newly constructed Operation.

Throws: Exception if opType has not been registered, or if the new operation fails verification.
*/
Operation createOperation(string opType, Operation[] deps = [], Variant[string] attribs = null,
    string mod = __MODULE__, size_t line = __LINE__)
{
    enforce(opType in mOpDefs,
        "Cannot create operation because there is no operation definition registered with the name '" ~ opType ~ "'");

    return new Operation(opType, deps, attribs, mod, line);
}
283 
/**
Returns the given operations and all of their transitive dependencies, ordered so that every operation comes after
all of its dependencies.

Each operation appears exactly once, even if it is reachable from several inputs.

Params:
    ops = The operations to sort. Their dependencies do not need to be listed.

Returns: A new array containing the operations in dependency order.

Throws: Exception if the graph reachable from ops contains a cycle.
*/
Operation[] topologicalSort(Operation[] ops)
{
    Operation[] sortedOps;

    // Visit state per operation: false = on the current recursion path, true = already appended to sortedOps.
    // Operation does not override opEquals/toHash, so keys compare by identity, the same test canFind used.
    // The lookup is O(1) rather than a linear scan of sortedOps.
    bool[Operation] finished;

    void toposort(Operation o)
    {
        if(auto state = o in finished)
        {
            // Reaching a node that is still on the recursion path means the graph has a cycle.
            enforce(*state, "Cannot topologically sort operations because the graph contains a cycle");
            return;
        }

        finished[o] = false;

        foreach(d; o.deps)
        {
            toposort(d);
        }

        finished[o] = true;
        sortedOps ~= o;
    }

    foreach(o; ops)
    {
        toposort(o);
    }

    return sortedOps;
}
312 
private
{
    /// Operation definitions, keyed by the name given to registerOperation.
    OpDef[string] mOpDefs;

    /*
    Finds the OpDef registered for op's type and returns the TensorType computed by its judge.
    Throws an Exception when op's type has no registered definition.
    */
    TensorType makeJudgement(Operation op)
    {
        string typeName = op.opType;
        OpDef* found = typeName in mOpDefs;

        enforce(found !is null, "Cannot make judgement for unknown operation '" ~ typeName ~ "'");

        return found.judge(op);
    }
}