1 /**
2     This package facilitates the construction of various nodes in the operation graph.
3 
4     Authors: Henry Gouk
5 */
6 module dopt.core.ops;
7 
8 import std.array;
9 import std.exception;
10 import std.variant;
11 
12 import dopt.core.types;
13 
14 public
15 {
16     import dopt.core.ops.basic;
17     import dopt.core.ops.math;
18     import dopt.core.ops.nnet;
19     import dopt.core.ops.random;
20 }
21 
/**
Initializes every dopt.core.ops submodule by calling its own initialize function.

The submodules are initialized in a fixed order: basic, math, nnet and finally random.
*/
void initialize()
{
    //Presumably each submodule registers its operation definitions here (see registerOperation) -- confirm
    dopt.core.ops.basic.initialize();
    dopt.core.ops.math.initialize();
    dopt.core.ops.nnet.initialize();
    dopt.core.ops.random.initialize();
}
29 
/// Callback that checks a freshly constructed Operation; construction fails when it returns false.
alias Verifier = bool delegate(Operation op);

/// Callback that computes the TensorType of the value produced by an Operation.
alias Judge = TensorType delegate(Operation op);

/**
Bundles the callbacks that implement procedures specific to one type of operation.

An OpDef is associated with an operation type name by passing it to registerOperation.
*/
struct OpDef
{
    /**
    A verifier is used to ensure that an Operation object is correctly constructed.
    */
    Verifier verifier;

    /**
    A judge produces a TensorType object that specifies the type of the result of an operation of this type.
    */
    Judge judge;
}
48 
49 /**
50 A node in the expression graph
51 */
/**
A node in the expression graph.

Each node records its operation type, its operands, a set of attributes and the type of tensor it produces.
*/
class Operation
{
    public
    {
        /**
        Returns a string identifying the type of this operation. This is the same string used when registering the
        operation with the registerOperation function.
        */
        @property string opType()
        {
            return mOpType;
        }

        /**
        Returns a TensorType object that specifies the type of tensor obtained by evaluating this operation.

        The value is computed once, by the judge of the registered OpDef, when the operation is constructed.
        */
        @property TensorType outputType()
        {
            return mOutputType;
        }

        /**
        Returns the list of operands for this operation.

        This is the internal array, not a copy.
        */
        @property Operation[] deps()
        {
            return mDeps;
        }

        /**
        Returns an associative array that maps strings to operation specific attributes.
        */
        @property Variant[string] attributes()
        {
            return mAttributes;
        }

        /**
        Convenience method for pointwise operations.

        If exactly one of the two operands has rank 0, it is first broadcast to the shape of the other operand
        using repeat and reshape. Internally, this then calls the appropriate function (add, sub, mul or div) from
        dopt.core.ops.math.

        Params:
            rhs = The right-hand operand.
            mod = Module of the instantiation site, forwarded to every operation created here.
            line = Line of the instantiation site, forwarded to every operation created here.
        */
        Operation opBinary(string op)(Operation rhs, string mod = __MODULE__, size_t line = __LINE__)
        {
            //Broadcast a scalar operand so both sides have the same shape
            if(rhs.rank == 0 && this.rank != 0)
            {
                return this.opBinary!op(rhs.repeat(this.volume, mod, line).reshape(this.shape, mod, line), mod, line);
            }
            else if(this.rank == 0 && rhs.rank != 0)
            {
                return this.repeat(rhs.volume, mod, line).reshape(rhs.shape, mod, line).opBinary!op(rhs, mod, line);
            }

            static if(op == "+")
            {
                return this.add(rhs, mod, line);
            }
            else static if(op == "-")
            {
                return this.sub(rhs, mod, line);
            }
            else static if(op == "*")
            {
                return this.mul(rhs, mod, line);
            }
            else static if(op == "/")
            {
                return this.div(rhs, mod, line);
            }
            else
            {
                static assert(0, "Unknown binary operation '" ~ op ~ "'");
            }
        }

        /**
        Convenience overload that wraps i in a rank 0 int32 operation and applies op to it.
        */
        Operation opBinary(string op)(int i, string mod = __MODULE__, size_t line = __LINE__)
        {
            auto bc = int32([], [i]);

            return opBinary!op(bc, mod, line);
        }

        /**
        Convenience overload that wraps i in a rank 0 float32 operation and applies op to it.
        */
        Operation opBinary(string op)(float i, string mod = __MODULE__, size_t line = __LINE__)
        {
            auto bc = float32([], [i]);

            return opBinary!op(bc, mod, line);
        }

        /**
        Supports expressions in which the Operation is the right-hand operand, e.g. `2 * x` or `1.0f - x`.

        Addition and multiplication are forwarded to opBinary with the operands swapped. For subtraction and
        division the scalar is wrapped in a rank 0 float32 (for float) or int32 (for int) operation so that the
        operand order is preserved. The caller's mod and line are forwarded in every case.
        */
        Operation opBinaryRight(string op, T)(T t, string mod = __MODULE__, size_t line = __LINE__)
        {
            static if(op == "*" || op == "+")
            {
                //Commutative, so swapping the operands does not change the result
                return opBinary!op(t, mod, line);
            }
            else static if((op == "-" || op == "/") && is(T == float))
            {
                return float32([], [t]).opBinary!op(this, mod, line);
            }
            else static if((op == "-" || op == "/") && is(T == int))
            {
                return int32([], [t]).opBinary!op(this, mod, line);
            }
            else
            {
                static assert(0, "Not implemented.");
            }
        }

        /**
        Supports unary negation (`-x`), which is built with the neg function.
        */
        Operation opUnary(string op)()
        {
            static if(op == "-")
            {
                return neg(this);
            }
            else
            {
                //The condition must be false for the assertion to fire; a bare string literal is always true
                static assert(0, "Unknown unary operation '" ~ op ~ "'");
            }
        }

        /**
        Returns a textual description of this operation.

        For operations of type "variable" this is the object's address, used as a unique identifier. Otherwise
        it is opType followed by the recursively formatted operands in parentheses.
        */
        override string toString()
        {
            import std.algorithm : joiner, map;
            import std.conv : to;

            //If it's a variable, we should have some unique identifier
            if(opType == "variable")
            {
                //This is very ugly. Someone please come up with a better way.
                return to!string(cast(void *)this);
            }
            else
            {
                return opType ~ "(" ~ deps.map!(x => x.toString).joiner(", ").to!string ~ ")";
            }
        }

        /**
        Returns the Buffer stored in the "default" attribute.

        Throws a RangeError if there is no "default" attribute, and a VariantException if it does not hold a Buffer.
        */
        Buffer value()
        {
            return attributes["default"].get!Buffer;
        }

        /// Shortcut for outputType.shape.
        auto shape()
        {
            return outputType.shape;
        }

        /// Shortcut for outputType.elementType.
        auto elementType()
        {
            return outputType.elementType;
        }

        /// Shortcut for outputType.volume.
        auto volume()
        {
            return outputType.volume;
        }

        /// Shortcut for outputType.rank.
        auto rank()
        {
            return outputType.rank;
        }
    }

    public
    {
        string mOpType;              //Name of the registered OpDef for this operation
        string mModule;              //Module where the operation was instantiated
        size_t mLine;                //Line where the operation was instantiated
        Operation[] mDeps;           //Operands
        Variant[string] mAttributes; //Operation specific attributes
        TensorType mOutputType;      //Result of the judge, computed in the constructor

        /**
        Constructs an operation. It then checks the operation with the registered verifier and computes its output
        type with the registered judge.

        Params:
            opType = Name of a registered operation definition.
            deps = Operands; the array is copied.
            attribs = Attributes; the associative array is shallow-copied.
            mod = Module of the instantiation site, reported if verification fails.
            line = Line of the instantiation site, reported if verification fails.

        Throws:
            Exception if opType has not been registered, or if the registered verifier rejects the operation.
        */
        this(string opType, Operation[] deps, Variant[string] attribs, string mod, size_t line)
        {
            import std.conv : to;

            auto def = opType in mOpDefs;

            //The constructor is public, so it cannot rely on createOperation having checked this already
            enforce(def !is null,
                "Cannot create operation because there is no operation definition registered with the name '" ~
                opType ~ "'");

            mOpType = opType;
            mDeps = deps.array;
            mAttributes = attribs.dup;
            mModule = mod;
            mLine = line;

            //The verifier receives this object before mOutputType has been set
            enforce(def.verifier(this),
                "Operation of type \"" ~ opType ~ "\" failed verification. Instantiated at " ~ mod ~ ":" ~
                line.to!string);

            mOutputType = makeJudgement(this);
        }
    }
}
243 
244 /**
245 Registers an operation definition with the given identifier.
246 */
/**
Registers an operation definition under the given identifier.

Params:
    name = Identifier for the operation type; this is the string later passed to createOperation.
    def = Callbacks used to verify operations of this type and compute their output type.

Throws:
    Exception if an operation has already been registered under name.
*/
void registerOperation(string name, OpDef def)
{
    enforce(name !in mOpDefs, "There is already an operation registered with the name '" ~ name ~ "'");
    mOpDefs[name] = def;
}
253 
254 /**
255 Returns a list of identifiers for operations that have been registered so far.
256 */
/**
Returns a list of identifiers for operations that have been registered so far.

The order of the identifiers is unspecified. The returned array is newly allocated, so the caller may modify it
without affecting the registry.
*/
string[] listAllOperations()
{
    //.keys already allocates a fresh array, so no additional copy is needed
    return mOpDefs.keys;
}
261 
262 /**
263 Creates an operation of the given type, with the given dependencies and attributes.
264 */
/**
Creates an operation of the given type, with the given dependencies and attributes.

Params:
    opType = Name of a registered operation definition.
    deps = Operands of the new operation.
    attribs = Operation specific attributes.
    mod = Module of the instantiation site, reported if verification fails.
    line = Line of the instantiation site, reported if verification fails.

Returns:
    The newly constructed Operation.

Throws:
    Exception if opType has not been registered, or if the registered verifier rejects the operation. Anything
    thrown by the registered judge also propagates.
*/
Operation createOperation(string opType, Operation[] deps = [], Variant[string] attribs = null,
    string mod = __MODULE__, size_t line = __LINE__)
{
    enforce(opType in mOpDefs,
        "Cannot create operation because there is no operation definition registered with the name '" ~ opType ~ "'");

    return new Operation(opType, deps, attribs, mod, line);
}
277 
/**
Orders the given operations, together with everything they transitively depend on, so that every operation
appears after all of its dependencies.

The graph is assumed to be acyclic; a cycle would cause unbounded recursion.

Params:
    ops = The operations to sort. Their dependencies are included in the result even if not listed here.

Returns:
    A new array containing each reachable operation exactly once, in dependency order.
*/
Operation[] topologicalSort(Operation[] ops)
{
    Operation[] sortedOps;

    //Operations already appended to sortedOps. Operation does not override toHash or opEquals, so lookups use
    //object identity, which gives the same result as a linear canFind scan but in constant expected time.
    bool[Operation] visited;

    void toposort(Operation o)
    {
        if(o in visited)
        {
            return;
        }

        foreach(d; o.deps)
        {
            toposort(d);
        }

        visited[o] = true;
        sortedOps ~= o;
    }

    foreach(o; ops)
    {
        toposort(o);
    }

    return sortedOps;
}
306 
private
{
    /**
    Registry of operation definitions, keyed by the name passed to registerOperation.
    */
    OpDef[string] mOpDefs;

    /**
    Computes the output type of op by calling the judge registered for its operation type.

    Throws:
        Exception if no definition has been registered for op.opType.
    */
    TensorType makeJudgement(Operation op)
    {
        auto typeName = op.opType;
        OpDef* definition = typeName in mOpDefs;

        enforce(definition !is null, "Cannot make judgement for unknown operation '" ~ typeName ~ "'");

        return definition.judge(op);
    }
}