1 /**
2     Contains common maths operations.
3 
4     Authors: Henry Gouk
5 */
6 module dopt.core.ops.math;
7 
8 import std.algorithm;
9 import std.conv;
10 import std.functional;
11 import std.range;
12 
13 import dopt.core.ops;
14 import dopt.core.types;
15 
16 package
17 {
18     void initialize()
19     {
20         void registerPointwiseBinary(string opName)
21         {
22             bool verifier(Operation op)
23             {
24                 return op.deps.length == 2 && op.deps[0].outputType == op.deps[1].outputType;
25             }
26 
27             TensorType judge(Operation op)
28             {
29                 return TensorType(op.deps[0].outputType);
30             }
31 
32             registerOperation(opName, OpDef(&verifier, &judge));
33         }
34 
35         void registerPointwiseUnary(string opName)
36         {
37             bool verifier(Operation op)
38             {
39                 return true;
40             }
41 
42             TensorType judge(Operation op)
43             {
44                 return TensorType(op.deps[0].outputType);
45             }
46 
47             registerOperation(opName, OpDef(&verifier, &judge));
48         }
49 
50         static foreach(opName; AliasSeq!(arith, comp, binfunc))
51         {
52             mixin("registerPointwiseBinary(\"" ~ opName ~ "\");");
53         }
54 
55         static foreach(opName; unfunc)
56         {
57             mixin("registerPointwiseUnary(\"" ~ opName ~ "\");");
58         }
59 
60         registerOperation("matmul", OpDef(toDelegate(&verifyMatmul), toDelegate(&judgeMatmul)));
61         registerOperation("sum", OpDef(toDelegate(&verifySum), toDelegate(&judgeSum)));
62         registerOperation("argmin", OpDef(toDelegate(&verifyArgmin), toDelegate(&judgeArgmin)));
63 
64         //maxElement and sum are both reduction operations
65         registerOperation("maxElement", OpDef(toDelegate(&verifySum), toDelegate(&judgeSum)));
66     }
67 }
68 
69 private
70 {
71     import std.meta : AliasSeq;
72     enum arith = AliasSeq!("add", "sub", "mul", "div");
73     enum comp = AliasSeq!("lt", "lte", "gt", "gte", "eq", "neq");
74     enum binfunc = AliasSeq!("max", "min", "pow");
75     enum unfunc = AliasSeq!("neg", "abs", "sgn", "exp", "log", "sqrt", "sin", "cos", "tan", "asin", "acos",
76         "atan", "atan2", "sinh", "cosh", "tanh", "asinh", "acosh", "atanh");
77 
78     string createAllCtors()
79     {
80         string createOpCtor(string opName, size_t numDeps)
81         {
82             auto params = iota(0, numDeps)
83                         .map!(x => "Operation p" ~ x.to!string)
84                         .joiner(", ")
85                         .to!string();
86 
87             auto args = iota(0, numDeps)
88                     .map!(x => "p" ~ x.to!string)
89                     .joiner(", ")
90                     .to!string;
91 
92             return "
93                     Operation " ~ opName ~ "(" ~ params ~ ", string mod = __MODULE__, size_t line = __LINE__)
94                     {
95                         return createOperation(\"" ~ opName ~ "\", [" ~ args ~ "], null, mod, line);
96                     }
97                 ";
98         }
99 
100         import std.array : appender;
101 
102         auto strBuilder = appender!string();
103 
104         /*string binctors = chain(arith, comp, binfunc)
105                          .map!(x => createOpCtor(x, 2))
106                          .joiner("\n")
107                          .to!string;
108 
109         auto unctors = unfunc
110                       .map!(x => createOpCtor(x, 1))
111                       .joiner("\n")
112                       .to!string;
113 
114         return binctors ~ unctors;*/
115 
116         static foreach(f; AliasSeq!(arith, comp, binfunc))
117         {
118             strBuilder.put(createOpCtor(f, 2));
119         }
120 
121         static foreach(f; unfunc)
122         {
123             strBuilder.put(createOpCtor(f, 1));
124         }
125 
126         return strBuilder.data;
127     }
128 
129     bool verifyMatmul(Operation op)
130     {
131         return op.deps.length == 2
132             && op.deps[0].outputType.rank == 2
133             && op.deps[1].outputType.rank == 2
134             && op.deps[0].outputType.elementType == op.deps[1].outputType.elementType
135             && op.deps[0].outputType.shape[1] == op.deps[1].outputType.shape[0];
136     }
137 
138     TensorType judgeMatmul(Operation op)
139     {
140         return TensorType(op.deps[0].outputType.elementType,
141             [op.deps[0].outputType.shape[0], op.deps[1].outputType.shape[1]]);
142     }
143 
144     bool verifySum(Operation op)
145     {
146         if(op.deps.length != 1)
147         {
148             return false;
149         }
150 
151         if(("axes" in op.attributes) is null || op.attributes["axes"].peek!(size_t[]) is null)
152         {
153             return false;
154         }
155 
156         auto axes = op.attributes["axes"].get!(size_t[]);
157 
158         return axes.all!(x => x < op.deps[0].rank) &&
159                axes.map!(x => size_t(x)).array().sort().uniq().count() == axes.length;
160     }
161 
162     TensorType judgeSum(Operation op)
163     {
164         auto t = op.deps[0].outputType;
165         auto axes = op.attributes["axes"].get!(size_t[]);
166 
167         auto newShape = t
168                        .shape
169                        .zip(iota(0, t.shape.length))
170                        .filter!(x => !axes.canFind(x[1]))
171                        .map!(x => x[0])
172                        .array();
173 
174         return TensorType(t.elementType, newShape);
175     }
176 
177     bool verifyArgmin(Operation op)
178     {
179         return op.deps.length == 1
180             && ("axis" in op.attributes)
181             && (op.attributes["axis"].peek!size_t !is null)
182             && (op.attributes["axis"].get!size_t < op.deps[0].rank);
183     }
184 
185     TensorType judgeArgmin(Operation op)
186     {
187         auto shape = op.deps[0].shape.dup;
188         shape[op.attributes["axis"].get!size_t] = 1;
189 
190         return TensorType(DataType.int32, shape);
191     }
192 }
193 
//Generate one public constructor function per pointwise operation name (e.g. add, neg) at module scope
mixin(createAllCtors());
195 
/**
    Creates an operation that multiplies two rank-2 tensors as matrices.

    The column count of lhs must equal the row count of rhs. The result has as many rows as lhs and as
    many columns as rhs.

    Params:
        lhs = The matrix on the left-hand side of the product.
        rhs = The matrix on the right-hand side of the product.
        mod = Module name forwarded to `createOperation`; defaults to the calling module.
        line = Line number forwarded to `createOperation`; defaults to the calling line.

    Returns:
        The resulting operation.
*/
Operation matmul(Operation lhs, Operation rhs, string mod = __MODULE__, size_t line = __LINE__)
{
    Operation[] operands = [lhs, rhs];

    return createOperation("matmul", operands, null, mod, line);
}
210 
///
unittest
{
    import dopt.core : evaluate;

    //A (2 x 1) column times a (1 x 2) row gives their (2 x 2) outer product
    auto column = float32([2, 1], [1.0f, 2.0f]);
    auto row = float32([1, 2], [3.0f, 4.0f]);

    auto product = column.matmul(row);

    assert(product.evaluate().as!float == [
        3.0f, 4.0f,
        6.0f, 8.0f
    ]);
}
232 
/**
    Computes a sum reduction along the specified axes.

    Params:
        op = The input to the reduction.
        axes = The axes the reduction should be performed along. An empty list (the default) reduces along
               every axis. The array is copied, so the caller may reuse or modify it afterwards.
        mod = Module name forwarded to `createOperation`; defaults to the calling module.
        line = Line number forwarded to `createOperation`; defaults to the calling line.

    Returns:
        The resulting operation. A rank-0 input has nothing to reduce, so it is returned reshaped to its own
        shape.
*/
Operation sum(Operation op, size_t[] axes = [], string mod = __MODULE__, size_t line = __LINE__)
{
    import std.variant : Variant;

    if(op.rank == 0)
    {
        return op.reshape(op.shape);
    }

    if(axes.length == 0)
    {
        //No axes given: reduce along every axis of the input
        axes = iota(0, op.rank).array();
    }
    else
    {
        //Copy the caller's array; storing their slice directly would let later writes to it silently
        //change this operation's "axes" attribute
        axes = axes.dup;
    }

    return createOperation("sum", [op], ["axes": Variant(axes)], mod, line);
}
259 
///
unittest
{
    import dopt.core : evaluate;

    //Summing a vector, or a matrix with no axes given, totals every element
    auto vectorTotal = float32([2], [0.5, 1.5]).sum();
    assert(vectorTotal.evaluate().as!float == [2.0f]);

    auto matrixTotal = float32([2, 2], [0, 1, 0, 5]).sum();
    assert(matrixTotal.evaluate().as!float == [6.0f]);

    //Reducing axis 0 gives per-column totals, reducing axis 1 gives per-row totals
    auto columnTotals = float32([2, 2], [0, 1, 0, 5]).sum([0]);
    assert(columnTotals.evaluate().as!float == [0.0f, 6.0f]);

    auto rowTotals = float32([2, 2], [0, 1, 0, 5]).sum([1]);
    assert(rowTotals.evaluate().as!float == [1.0f, 5.0f]);
}
275 
/**
    Performs an argmin over the specified axis.

    Params:
        input = The operation to perform argmin on.
        axis = The dimension the argmin should be performed over.
        mod = Module name forwarded to `createOperation`; defaults to the calling module.
        line = Line number forwarded to `createOperation`; defaults to the calling line.

    Returns:
        The new argmin operation. Its output is an int32 tensor shaped like input, except that dimension
        axis has extent 1.
*/
Operation argmin(Operation input, size_t axis, string mod = __MODULE__, size_t line = __LINE__)
{
    import std.variant : Variant;

    auto attributes = ["axis": Variant(axis)];

    return createOperation("argmin", [input], attributes, mod, line);
}
292 
///
unittest
{
    import dopt.core : evaluate;

    //The smallest element (1.0) sits at index 3
    auto vectorArgmin = float32([5], [4.0f, 2.0f, 6.0f, 1.0f, 2.0f]).argmin(0);
    assert(vectorArgmin.evaluate().as!int == [3]);

    //Per-row minimum: 1.0 at column 1 in the first row, 2.0 at column 2 in the second
    auto rowArgmins = float32([2, 3], [
        5.0f, 1.0f, 3.0f,
        6.0f, 7.0f, 2.0f
    ]).argmin(1);
    assert(rowArgmins.evaluate().as!int == [1, 2]);
}
308 
/**
    Computes a max reduction along the specified axes.

    Params:
        op = The input to the reduction.
        axes = The axes the reduction should be performed along. An empty list (the default) reduces along
               every axis. The array is copied, so the caller may reuse or modify it afterwards.
        mod = Module name forwarded to `createOperation`; defaults to the calling module.
        line = Line number forwarded to `createOperation`; defaults to the calling line.

    Returns:
        The resulting operation. A rank-0 input has nothing to reduce, so it is returned reshaped to its own
        shape.
*/
Operation maxElement(Operation op, size_t[] axes = [], string mod = __MODULE__, size_t line = __LINE__)
{
    import std.variant : Variant;

    if(op.rank == 0)
    {
        return op.reshape(op.shape);
    }

    if(axes.length == 0)
    {
        //No axes given: reduce along every axis of the input
        axes = iota(0, op.rank).array();
    }
    else
    {
        //Copy the caller's array; storing their slice directly would let later writes to it silently
        //change this operation's "axes" attribute
        axes = axes.dup;
    }

    return createOperation("maxElement", [op], ["axes": Variant(axes)], mod, line);
}
335 
///
unittest
{
    import dopt.core : evaluate;

    auto grid = float32([2, 2], [
        1.0f, 4.0f,
        3.0f, 6.0f
    ]);

    //No axes: the largest value anywhere in the tensor
    auto overall = grid.maxElement();
    assert(overall.evaluate().as!float == [6.0f]);

    //Axis 0: the largest value in each column
    assert(grid.maxElement([0]).evaluate().as!float == [3.0f, 6.0f]);

    //Axis 1: the largest value in each row
    assert(grid.maxElement([1]).evaluate().as!float == [4.0f, 6.0f]);
}