1 /**
2     Contains common maths operations.
3 
4     Authors: Henry Gouk
5 */
6 module dopt.core.ops.math;
7 
8 import std.algorithm;
9 import std.conv;
10 import std.functional;
11 import std.range;
12 
13 import dopt.core.ops;
14 import dopt.core.types;
15 
package
{
    /**
        Registers every operation defined in this module via registerOperation.

        Each registration pairs a verifier, which accepts or rejects an Operation,
        with a judge, which returns the TensorType that Operation produces.
    */
    void initialize()
    {
        // Pointwise binary ops: exactly two inputs whose output types compare equal.
        // The result has the same type as the first input.
        void registerPointwiseBinary(string opName)
        {
            bool verifier(Operation op)
            {
                return op.deps.length == 2 && op.deps[0].outputType == op.deps[1].outputType;
            }

            TensorType judge(Operation op)
            {
                return TensorType(op.deps[0].outputType);
            }

            registerOperation(opName, OpDef(&verifier, &judge));
        }

        // Pointwise unary ops: exactly one input; the result has the input's type.
        void registerPointwiseUnary(string opName)
        {
            bool verifier(Operation op)
            {
                // The judge below reads deps[0], so exactly one dependency must be
                // present; anything else is rejected here rather than indexing out of
                // bounds later.
                return op.deps.length == 1;
            }

            TensorType judge(Operation op)
            {
                return TensorType(op.deps[0].outputType);
            }

            registerOperation(opName, OpDef(&verifier, &judge));
        }

        foreach(opName; chain(arith, comp, binfunc))
        {
            registerPointwiseBinary(opName);
        }

        foreach(opName; unfunc)
        {
            registerPointwiseUnary(opName);
        }

        registerOperation("matmul", OpDef(toDelegate(&verifyMatmul), toDelegate(&judgeMatmul)));
        registerOperation("sum", OpDef(toDelegate(&verifySum), toDelegate(&judgeSum)));
        registerOperation("argmin", OpDef(toDelegate(&verifyArgmin), toDelegate(&judgeArgmin)));

        // maxElement is a reduction with the same validity and shape rules as sum,
        // so it reuses sum's verifier and judge.
        registerOperation("maxElement", OpDef(toDelegate(&verifySum), toDelegate(&judgeSum)));
    }
}
68 
69 private
70 {
71     immutable string[] arith = ["add", "sub", "mul", "div"];
72     immutable string[] comp = ["lt", "lte", "gt", "gte", "eq", "neq"];
73     immutable string[] binfunc = ["max", "min", "pow"];
74     immutable string[] unfunc = ["neg", "abs", "sgn", "exp", "log", "sqrt"];
75 
76     string createAllCtors()
77     {
78         string createOpCtor(string opName, size_t numDeps)
79         {
80             auto params = iota(0, numDeps)
81                         .map!(x => "Operation p" ~ x.to!string)
82                         .joiner(", ")
83                         .to!string();
84 
85             auto args = iota(0, numDeps)
86                     .map!(x => "p" ~ x.to!string)
87                     .joiner(", ")
88                     .to!string;
89 
90             return "
91                     Operation " ~ opName ~ "(" ~ params ~ ", string mod = __MODULE__, size_t line = __LINE__)
92                     {
93                         return createOperation(\"" ~ opName ~ "\", [" ~ args ~ "], null, mod, line);
94                     }
95                 ";
96         }
97 
98         string binctors = chain(arith, comp, binfunc)
99                          .map!(x => createOpCtor(x, 2))
100                          .joiner("\n")
101                          .to!string;
102 
103         auto unctors = unfunc
104                       .map!(x => createOpCtor(x, 1))
105                       .joiner("\n")
106                       .to!string;
107 
108         return binctors ~ unctors;
109     }
110 
111     bool verifyMatmul(Operation op)
112     {
113         return op.deps.length == 2
114             && op.deps[0].outputType.rank == 2
115             && op.deps[1].outputType.rank == 2
116             && op.deps[0].outputType.elementType == op.deps[1].outputType.elementType
117             && op.deps[0].outputType.shape[1] == op.deps[1].outputType.shape[0];
118     }
119 
120     TensorType judgeMatmul(Operation op)
121     {
122         return TensorType(op.deps[0].outputType.elementType,
123             [op.deps[0].outputType.shape[0], op.deps[1].outputType.shape[1]]);
124     }
125 
126     bool verifySum(Operation op)
127     {
128         if(op.deps.length != 1)
129         {
130             return false;
131         }
132 
133         if(("axes" in op.attributes) is null || op.attributes["axes"].peek!(size_t[]) is null)
134         {
135             return false;
136         }
137 
138         auto axes = op.attributes["axes"].get!(size_t[]);
139 
140         return axes.all!(x => x < op.deps[0].rank) &&
141                axes.map!(x => size_t(x)).array().sort().uniq().count() == axes.length;
142     }
143 
144     TensorType judgeSum(Operation op)
145     {
146         auto t = op.deps[0].outputType;
147         auto axes = op.attributes["axes"].get!(size_t[]);
148 
149         auto newShape = t
150                        .shape
151                        .zip(iota(0, t.shape.length))
152                        .filter!(x => !axes.canFind(x[1]))
153                        .map!(x => x[0])
154                        .array();
155 
156         return TensorType(t.elementType, newShape);
157     }
158 
159     bool verifyArgmin(Operation op)
160     {
161         return op.deps.length == 1
162             && ("axis" in op.attributes)
163             && (op.attributes["axis"].peek!size_t !is null)
164             && (op.attributes["axis"].get!size_t < op.deps[0].rank);
165     }
166 
167     TensorType judgeArgmin(Operation op)
168     {
169         auto shape = op.deps[0].shape.dup;
170         shape[op.attributes["axis"].get!size_t] = 1;
171 
172         return TensorType(DataType.int32, shape);
173     }
174 }
175 
// Expands to one free-function constructor per operation named in arith, comp,
// binfunc (binary) and unfunc (unary), e.g. add(p0, p1) and neg(p0).
mixin(createAllCtors());
177 
/**
    Creates an operation that multiplies two rank-2 tensors as matrices.

    Params:
        lhs = The left-hand operand.
        rhs = The right-hand operand. Its first dimension must equal the second
              dimension of lhs, and its element type must match lhs.
        mod = Forwarded to createOperation; defaults to the calling module.
        line = Forwarded to createOperation; defaults to the calling line.

    Returns:
        The resulting operation.
*/
Operation matmul(Operation lhs, Operation rhs, string mod = __MODULE__, size_t line = __LINE__)
{
    Operation[] operands = [lhs, rhs];

    return createOperation("matmul", operands, null, mod, line);
}
192 
///
unittest
{
    import dopt.core.cpu : evaluate;

    // A 2x1 column times a 1x2 row yields a 2x2 outer product.
    auto column = float32([2, 1], [1.0f, 2.0f]);
    auto row = float32([1, 2], [3.0f, 4.0f]);

    auto product = matmul(column, row);

    assert(product.evaluate().as!float == [
        3.0f, 4.0f,
        6.0f, 8.0f
    ]);
}
214 
/**
    Computes a sum reduction along the specified axes.

    Params:
        op = The input to the reduction.
        axes = The axes the reduction should be performed along. If empty, every
               axis of op is reduced. The array is copied, so changing it after
               this call does not affect the returned operation.
        mod = Forwarded to createOperation; defaults to the calling module.
        line = Forwarded to createOperation; defaults to the calling line.

    Returns:
        The resulting operation. The reduced axes are removed from its shape. If op
        has rank 0, op reshaped to its own shape is returned and axes is ignored.
*/
Operation sum(Operation op, size_t[] axes = [], string mod = __MODULE__, size_t line = __LINE__)
{
    import std.variant : Variant;

    // A rank-0 tensor has no axes to reduce over.
    if(op.rank == 0)
    {
        return op.reshape(op.shape);
    }

    // Store a private copy: the Variant would otherwise alias the caller's array,
    // and mutating it later would silently change this operation's attributes.
    size_t[] reduceAxes;

    if(axes.length == 0)
    {
        reduceAxes = iota(0, op.rank).array();
    }
    else
    {
        reduceAxes = axes.dup;
    }

    return createOperation("sum", [op], ["axes": Variant(reduceAxes)], mod, line);
}
241 
///
unittest
{
    import dopt.core.cpu : evaluate;

    // With the default (empty) axes, a vector is summed completely.
    auto vectorTotal = float32([2], [0.5, 1.5]).sum();

    // Reductions of the matrix [[0, 1], [0, 5]].
    auto total = float32([2, 2], [0, 1, 0, 5]).sum();          // every axis
    auto alongAxis0 = float32([2, 2], [0, 1, 0, 5]).sum([0]);  // down each column
    auto alongAxis1 = float32([2, 2], [0, 1, 0, 5]).sum([1]);  // across each row

    assert(vectorTotal.evaluate().as!float == [2.0f]);
    assert(total.evaluate().as!float == [6.0f]);
    assert(alongAxis0.evaluate().as!float == [0.0f, 6.0f]);
    assert(alongAxis1.evaluate().as!float == [1.0f, 5.0f]);
}
257 
/**
    Creates an operation that finds the index of the smallest element along one
    dimension.

    Params:
        input = The operation to perform argmin on.
        axis = The dimension the argmin should be performed over. Must be less than
               the rank of input.
        mod = Forwarded to createOperation; defaults to the calling module.
        line = Forwarded to createOperation; defaults to the calling line.

    Returns:
        The new argmin operation. Its elements are int32, and its shape matches
        input except that dimension axis has extent 1.
*/
Operation argmin(Operation input, size_t axis, string mod = __MODULE__, size_t line = __LINE__)
{
    import std.variant : Variant;

    auto attrs = ["axis": Variant(axis)];

    return createOperation("argmin", [input], attrs, mod, line);
}
274 
///
unittest
{
    import dopt.core : evaluate;

    // Rank-1 input: the smallest value (1.0) is at index 3.
    auto a = float32([5], [4.0f, 2.0f, 6.0f, 1.0f, 2.0f]).argmin(0);

    // Rank-2 input reduced along axis 1: the row minima are in columns 1 and 2.
    auto b = float32([2, 3], [
        5.0f, 1.0f, 3.0f,
        6.0f, 7.0f, 2.0f
    ]).argmin(1);

    assert(a.evaluate().as!int == [3]);
    assert(b.evaluate().as!int == [1, 2]);
}
290 
/**
    Computes a max reduction along the specified axes.

    Params:
        op = The input to the reduction.
        axes = The axes the reduction should be performed along. If empty, every
               axis of op is reduced. The array is copied, so changing it after
               this call does not affect the returned operation.
        mod = Forwarded to createOperation; defaults to the calling module.
        line = Forwarded to createOperation; defaults to the calling line.

    Returns:
        The resulting operation. If op has rank 0, op reshaped to its own shape is
        returned and axes is ignored.
*/
Operation maxElement(Operation op, size_t[] axes = [], string mod = __MODULE__, size_t line = __LINE__)
{
    import std.variant : Variant;

    // A rank-0 tensor has no axes to reduce over.
    if(op.rank == 0)
    {
        return op.reshape(op.shape);
    }

    // Store a private copy: the Variant would otherwise alias the caller's array,
    // and mutating it later would silently change this operation's attributes.
    size_t[] reduceAxes;

    if(axes.length == 0)
    {
        reduceAxes = iota(0, op.rank).array();
    }
    else
    {
        reduceAxes = axes.dup;
    }

    return createOperation("maxElement", [op], ["axes": Variant(reduceAxes)], mod, line);
}