1 /**
2     Contains common maths operations.
3 
4     Authors: Henry Gouk
5 */
6 module dopt.core.ops.math;
7 
8 import std.algorithm;
9 import std.conv;
10 import std.functional;
11 import std.range;
12 
13 import dopt.core.ops;
14 import dopt.core.types;
15 
package
{
    /**
        Registers the verifier and output-type judge for every operation provided by this module.
    */
    void initialize()
    {
        //Registers a two-input pointwise operation: both inputs must have the same type, and the
        //output has that same type.
        void registerPointwiseBinary(string opName)
        {
            bool verifier(Operation op)
            {
                return op.deps.length == 2 && op.deps[0].outputType == op.deps[1].outputType;
            }

            TensorType judge(Operation op)
            {
                return TensorType(op.deps[0].outputType);
            }

            registerOperation(opName, OpDef(&verifier, &judge));
        }

        //Registers a one-input pointwise operation: the output has the same type as the input.
        void registerPointwiseUnary(string opName)
        {
            bool verifier(Operation op)
            {
                //judge reads deps[0], so anything other than exactly one input must be rejected here
                //instead of causing an out-of-bounds access when the output type is computed.
                return op.deps.length == 1;
            }

            TensorType judge(Operation op)
            {
                return TensorType(op.deps[0].outputType);
            }

            registerOperation(opName, OpDef(&verifier, &judge));
        }

        //opName is a compile-time string in each iteration, so it can be passed directly; no mixin needed
        static foreach(opName; AliasSeq!(arith, comp, binfunc))
        {
            registerPointwiseBinary(opName);
        }

        static foreach(opName; unfunc)
        {
            registerPointwiseUnary(opName);
        }

        registerOperation("matmul", OpDef(toDelegate(&verifyMatmul), toDelegate(&judgeMatmul)));
        registerOperation("sum", OpDef(toDelegate(&verifySum), toDelegate(&judgeSum)));
        registerOperation("argmin", OpDef(toDelegate(&verifyArgmin), toDelegate(&judgeArgmin)));

        //maxElement and sum are both reduction operations, so they share a verifier and judge
        registerOperation("maxElement", OpDef(toDelegate(&verifySum), toDelegate(&judgeSum)));
    }
}
68 
private
{
    import std.meta : AliasSeq;

    //Names of the pointwise operations defined by this module. Binary ops (arith, comp, binfunc)
    //get two-operand constructors; unary ops (unfunc) get one-operand constructors.
    enum arith = AliasSeq!("add", "sub", "mul", "div");
    enum comp = AliasSeq!("lt", "lte", "gt", "gte", "eq", "neq");
    enum binfunc = AliasSeq!("max", "min", "pow");

    //NOTE(review): "atan2" is conventionally a two-argument function, yet it is listed here as unary,
    //so its generated constructor takes a single operand. Moving it to binfunc would change the
    //public signature of atan2, so it is left unchanged -- confirm the intended semantics.
    enum unfunc = AliasSeq!("neg", "abs", "sgn", "exp", "log", "sqrt", "sin", "cos", "tan", "asin", "acos",
        "atan", "atan2", "sinh", "cosh", "tanh", "asinh", "acosh", "atanh");

    /**
        Generates (at compile time) the source of one constructor function per pointwise operation.

        Returns:
            D source code declaring a two-operand constructor for each name in arith, comp and
            binfunc, and a one-operand constructor for each name in unfunc.
    */
    string createAllCtors()
    {
        import std.array : appender;

        //Produces the source of a function named opName that takes numDeps Operation parameters
        //(p0, p1, ...) and forwards them to createOperation with no attributes.
        string createOpCtor(string opName, size_t numDeps)
        {
            auto params = iota(0, numDeps)
                        .map!(x => "Operation p" ~ x.to!string)
                        .joiner(", ")
                        .to!string();

            auto args = iota(0, numDeps)
                    .map!(x => "p" ~ x.to!string)
                    .joiner(", ")
                    .to!string;

            return "
                    Operation " ~ opName ~ "(" ~ params ~ ", string mod = __MODULE__, size_t line = __LINE__)
                    {
                        return createOperation(\"" ~ opName ~ "\", [" ~ args ~ "], null, mod, line);
                    }
                ";
        }

        auto strBuilder = appender!string();

        static foreach(f; AliasSeq!(arith, comp, binfunc))
        {
            strBuilder.put(createOpCtor(f, 2));
        }

        static foreach(f; unfunc)
        {
            strBuilder.put(createOpCtor(f, 1));
        }

        return strBuilder.data;
    }

    //matmul: exactly two rank-2 inputs with the same element type, where the column count of the
    //first equals the row count of the second.
    bool verifyMatmul(Operation op)
    {
        return op.deps.length == 2
            && op.deps[0].outputType.rank == 2
            && op.deps[1].outputType.rank == 2
            && op.deps[0].outputType.elementType == op.deps[1].outputType.elementType
            && op.deps[0].outputType.shape[1] == op.deps[1].outputType.shape[0];
    }

    //matmul output: [rows of lhs, columns of rhs], with the element type of lhs.
    TensorType judgeMatmul(Operation op)
    {
        return TensorType(op.deps[0].outputType.elementType,
            [op.deps[0].outputType.shape[0], op.deps[1].outputType.shape[1]]);
    }

    /**
        Verifier shared by the "sum" and "maxElement" reductions.

        Accepts the operation only if it has exactly one input and a size_t[] "axes" attribute
        whose entries are all smaller than the input's rank and contain no duplicates.
    */
    bool verifySum(Operation op)
    {
        if(op.deps.length != 1)
        {
            return false;
        }

        if(("axes" in op.attributes) is null || op.attributes["axes"].peek!(size_t[]) is null)
        {
            return false;
        }

        auto axes = op.attributes["axes"].get!(size_t[]);

        //Sort a private copy so the attribute itself is not reordered, then check for duplicates
        return axes.all!(x => x < op.deps[0].rank) &&
               axes.dup.sort().uniq().count() == axes.length;
    }

    /**
        Output type for the "sum" and "maxElement" reductions: the input's shape with every axis
        listed in the "axes" attribute removed, keeping the input's element type.
    */
    TensorType judgeSum(Operation op)
    {
        auto t = op.deps[0].outputType;
        auto axes = op.attributes["axes"].get!(size_t[]);

        auto newShape = t
                       .shape
                       .zip(iota(0, t.shape.length))
                       .filter!(x => !axes.canFind(x[1]))
                       .map!(x => x[0])
                       .array();

        return TensorType(t.elementType, newShape);
    }

    //argmin: exactly one input and a size_t "axis" attribute smaller than the input's rank.
    bool verifyArgmin(Operation op)
    {
        return op.deps.length == 1
            && ("axis" in op.attributes)
            && (op.attributes["axis"].peek!size_t !is null)
            && (op.attributes["axis"].get!size_t < op.deps[0].rank);
    }

    //argmin output: the input's shape with the reduced axis set to length 1, holding int32 indices.
    TensorType judgeArgmin(Operation op)
    {
        auto shape = op.deps[0].shape.dup;
        shape[op.attributes["axis"].get!size_t] = 1;

        return TensorType(DataType.int32, shape);
    }
}

//Insert the generated pointwise constructors (add, sub, ..., atanh) at module scope
mixin(createAllCtors());
195 
/**
    Multiplies two rank-2 tensors using standard matrix multiplication.

    Both operands must be rank-2 and share an element type, and the number of columns in lhs must
    equal the number of rows in rhs. The result has shape [lhs.shape[0], rhs.shape[1]].

    Params:
        lhs = The left-hand operand.
        rhs = The right-hand operand.
        mod = Module in which the operation is created (defaults to the caller's module).
        line = Line at which the operation is created (defaults to the caller's line).

    Returns:
        An operation representing the matrix product of lhs and rhs.
*/
Operation matmul(Operation lhs, Operation rhs, string mod = __MODULE__, size_t line = __LINE__)
{
    Operation[] operands = [lhs, rhs];

    return createOperation("matmul", operands, null, mod, line);
}

///
unittest
{
    import dopt.core : evaluate;

    //A column vector multiplied by a row vector yields their outer product
    auto column = float32([2, 1], [
        1.0f,
        2.0f
    ]);

    auto row = float32([1, 2], [
        3.0f, 4.0f
    ]);

    auto product = matmul(column, row);

    assert(product.evaluate().get!float == [
        3.0f, 4.0f,
        6.0f, 8.0f
    ]);
}
232 
/**
    Computes a sum reduction along the specified axes.

    Params:
        op = The input to the reduction.
        axes = The axes the reduction should be performed along. If empty, the reduction is
               performed along every axis of op.
        mod = Module in which the operation is created (defaults to the caller's module).
        line = Line at which the operation is created (defaults to the caller's line).

    Returns:
        The resulting operation. A rank-0 input is returned reshaped to its own shape.

    Throws:
        Exception if op is rank-2, exactly one axis is given, and that axis is not 0 or 1.
*/
Operation sum(Operation op, size_t[] axes = [], string mod = __MODULE__, size_t line = __LINE__)
{
    import std.array : array;
    import std.range : repeat;
    import std.variant : Variant;

    //A rank-0 tensor has no axes to reduce over
    if(op.rank == 0)
    {
        return op.reshape(op.shape);
    }

    if(axes.length == 0)
    {
        axes = iota(0, op.rank).array();
    }

    //Temporary speed enhancement: use BLAS to do row/col sums of matrices
    if(op.rank == 2 && axes.length == 1)
    {
        if(axes[0] >= op.rank)
        {
            throw new Exception("axes[0] must be less than op.rank");
        }

        //The ones vector is float32 and matmul requires both operands to share an element type,
        //so the shortcut only applies to float32 inputs; other types use the generic sum below.
        if(op.outputType.elementType == DataType.float32)
        {
            if(axes[0] == 1)
            {
                //[rows, cols] x [cols, 1] -> [rows, 1]: one sum per row
                auto ones = float32Constant([op.shape[1], 1], repeat(1.0f, op.shape[1]).array());

                return op.matmul(ones, mod, line).reshape([op.shape[0]]);
            }
            else
            {
                //[1, rows] x [rows, cols] -> [1, cols]: one sum per column
                auto ones = float32Constant([1, op.shape[0]], repeat(1.0f, op.shape[0]).array());

                return ones.matmul(op, mod, line).reshape([op.shape[1]]);
            }
        }
    }

    return createOperation("sum", [op], ["axes": Variant(axes)], mod, line);
}

///
unittest
{
    import dopt.core : evaluate;

    auto s1 = float32([2], [0.5, 1.5]).sum();           //Sum of a vector
    auto s2 = float32([2, 2], [0, 1, 2, 5]).sum();      //Sum of every element
    auto s3 = float32([2, 2], [0, 1, 2, 5]).sum([0]);   //Sum of each column
    auto s4 = float32([2, 2], [0, 1, 2, 5]).sum([1]);   //Sum of each row

    assert(s1.evaluate().get!float == [2.0f]);
    assert(s2.evaluate().get!float == [8.0f]);
    assert(s3.evaluate().get!float == [2.0f, 6.0f]);
    assert(s4.evaluate().get!float == [1.0f, 7.0f]);
}
299 
/**
    Finds the index of the smallest element along one dimension of a tensor.

    The result holds int32 indices and has the input's shape, with the reduced dimension
    collapsed to length 1.

    Params:
        input = The tensor to search.
        axis = The dimension along which the minimum is located. Must be less than the rank of input.
        mod = Module in which the operation is created (defaults to the caller's module).
        line = Line at which the operation is created (defaults to the caller's line).

    Returns:
        The new argmin operation.
*/
Operation argmin(Operation input, size_t axis, string mod = __MODULE__, size_t line = __LINE__)
{
    import std.variant : Variant;

    auto attrs = ["axis": Variant(axis)];

    return createOperation("argmin", [input], attrs, mod, line);
}

///
unittest
{
    import dopt.core : evaluate;

    //Rank-1 input: the smallest value (1.0) sits at index 3
    auto vec = float32([5], [4.0f, 2.0f, 6.0f, 1.0f, 2.0f]);
    assert(vec.argmin(0).evaluate().get!int == [3]);

    //Rank-2 input, searched along each row
    auto mat = float32([2, 3], [
        5.0f, 1.0f, 3.0f,
        6.0f, 7.0f, 2.0f
    ]);
    assert(mat.argmin(1).evaluate().get!int == [1, 2]);
}
332 
/**
    Reduces a tensor by taking the maximum over the given axes.

    Params:
        op = The tensor to reduce.
        axes = The axes to reduce over. An empty list means every axis of op is reduced.
        mod = Module in which the operation is created (defaults to the caller's module).
        line = Line at which the operation is created (defaults to the caller's line).

    Returns:
        The reduction operation. A rank-0 input is returned reshaped to its own shape.
*/
Operation maxElement(Operation op, size_t[] axes = [], string mod = __MODULE__, size_t line = __LINE__)
{
    import std.variant : Variant;

    //Nothing to reduce for a rank-0 tensor
    if(op.rank == 0)
        return op.reshape(op.shape);

    //An empty axis list means "reduce over everything"
    auto reduceAxes = axes.length != 0 ? axes : iota(0, op.rank).array();

    return createOperation("maxElement", [op], ["axes": Variant(reduceAxes)], mod, line);
}

///
unittest
{
    import dopt.core : evaluate;

    auto mat = float32([2, 2],[
        1.0f, 4.0f,
        3.0f, 6.0f
    ]);

    auto overall = mat.maxElement();      //Max value in the entire tensor
    auto perColumn = mat.maxElement([0]); //Max of each column
    auto perRow = mat.maxElement([1]);    //Max of each row

    assert(overall.evaluate().get!float == [6.0f]);
    assert(perColumn.evaluate().get!float == [3.0f, 6.0f]);
    assert(perRow.evaluate().get!float == [4.0f, 6.0f]);
}