module dopt.cpu.math;

import dopt.core;
import dopt.cpu;

import std.algorithm;
import std.conv;
import std.functional;
import std.math;
import std.range;

package
{
    void initialize()
    {
        mixin(generateRegistrations());

        registerCPUKernel("matmul", new CPUKernelDelegate(toDelegate(&matmulKernel)));
        registerCPUKernel("sum", new CPUKernelDelegate(toDelegate(&sumKernel)));
        registerCPUKernel("maxElement", new CPUKernelDelegate(toDelegate(&maxElementKernel)));
        registerCPUKernel("argmin", new CPUKernelDelegate(toDelegate(&argminKernel)));
    }

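    //The element-wise kernels (add, exp, lt, ...) are generated as strings by
    //generateKernels() and compiled in here; see the example above
    //generateSingleKernel() for the shape of the code this produces.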
    mixin(generateKernels());
}

private
{
    //std.math.exp is only defined for floating point types, so integer
    //operands are computed in single precision and truncated back to int.
    T expCast(T)(T t)
    {
        static if(is(T : int))
        {
            return cast(int)exp(cast(float)t);
        }
        else
        {
            return exp(t);
        }
    }

    //Same approach as expCast: compute sqrt in single precision for ints.
    T sqrtCast(T)(T t)
    {
        static if(is(T : int))
        {
            return cast(int)sqrt(cast(float)t);
        }
        else
        {
            return sqrt(t);
        }
    }

    //Wraps std.math.atan2 so that it can also be applied to integer operands.
    T atan2(T)(T a, T b)
    {
        static if(is(T : int))
        {
            return cast(int)std.math.atan2(cast(float)a, cast(float)b);
        }
        else
        {
            return std.math.atan2(a, b);
        }
    }

    //Branchless signum: returns -1, 0, or 1 with the same type as the operand.
    T sgn(T)(T t)
    {
        return cast(T)((t > 0) - (t < 0));
    }

    void matmulKernel(Operation op, const(void[])[] inputs, void[] output)
    {
        if(op.outputType.elementType == DataType.float32)
        {
            auto ashape = op.deps[0].outputType.shape;
            auto bshape = op.deps[1].outputType.shape;

            import cblas;

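            //Computes C = 1.0 * A * B + 0.0 * C in row-major storage, where A is
            //M-by-K = ashape[0]-by-ashape[1], B is K-by-N = ashape[1]-by-bshape[1],
            //and C is M-by-N. Each leading dimension is the row width of the
            //corresponding matrix.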
            gemm(Order.RowMajor, Transpose.NoTrans, Transpose.NoTrans,
                cast(int)ashape[0], cast(int)bshape[1], cast(int)ashape[1], 1.0, cast(float *)inputs[0].ptr,
                cast(int)ashape[1], cast(float *)inputs[1].ptr, cast(int)bshape[1], 0,
                cast(float *)output.ptr, cast(int)bshape[1]);
        }
        else
        {
            throw new Exception("matmul is only implemented for float32.");
        }
    }

    void sumKernel(Operation op, const(void[])[] inputs, void[] output)
    {
        void run(T)()
        {
            import std.algorithm : fold;

            void process(const(T)[] inbuf, T[] outbuf, size_t highstride, size_t lowstride)
            {
                import std.array : array;
                import std.range : iota;
                import std.parallelism : parallel;

                foreach(o; iota(0, outbuf.length / lowstride).array().parallel)
                {
                    if(lowstride == 1)
                    {
                        outbuf[o] = inbuf[o * highstride .. (o + 1) * highstride]
                                   .fold!((a, b) => a + b)(cast(T)0);
                    }
                    else
                    {
                        outbuf[o * lowstride .. (o + 1) * lowstride] = 0;

                        for(size_t i = 0; i < highstride / lowstride; i++)
                        {
                            outbuf[o * lowstride .. (o + 1) * lowstride] +=
                                inbuf[o * highstride + i * lowstride .. o * highstride + (i + 1) * lowstride];
                        }
                    }
                }
            }

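            //The reduction is applied one axis at a time. For each axis,
            //lowstride is the number of contiguous elements below that axis and
            //highstride spans one full slice along it. For example, summing a
            //[2, 3, 4] tensor over axis 1 gives lowstride = 4, highstride = 12,
            //and newvol = 8: each of the two outer slices collapses its three
            //rows of four elements into a single row of four.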
            auto axes = op.attributes["axes"].get!(size_t[]);
            auto shape = op.deps[0].shape.dup;

            auto inbuf = cast(const(T)[])inputs[0];
            T[] outbuf;

            foreach(axis; axes)
            {
                auto newvol = shape.fold!((a, b) => a * b)(size_t(1)) / shape[axis];
                size_t lowstride;

                if(axis == shape.length - 1)
                {
                    lowstride = 1;
                }
                else
                {
                    lowstride = shape[axis + 1 .. $].fold!((a, b) => a * b)(size_t(1));
                }

                size_t highstride = lowstride * shape[axis];

                outbuf = new T[newvol];
                process(inbuf, outbuf, highstride, lowstride);
                inbuf = outbuf;

                shape[axis] = 1;
            }

            output[] = outbuf[];
        }

        switch(op.outputType.elementType)
        {
            case DataType.float32:
                run!float();
                break;

            case DataType.int32:
                run!int();
                break;

            default:
                throw new Exception("sum is not implemented for this data type.");
        }
    }

    void maxElementKernel(Operation op, const(void[])[] inputs, void[] output)
    {
        void run(T)()
        {
            import std.algorithm : fold, max;
            import std.parallelism : parallel;

            void process(const(T)[] inbuf, T[] outbuf, size_t highstride, size_t lowstride)
            {
                foreach(o; iota(0, outbuf.length / lowstride).array().parallel)
                {
                    if(lowstride == 1)
                    {
                        outbuf[o] = inbuf[o * highstride .. (o + 1) * highstride].fold!((a, b) => max(a, b))(-T.max);
                    }
                    else
                    {
                        outbuf[o * lowstride .. (o + 1) * lowstride] = -T.max;

                        for(size_t i = 0; i < highstride / lowstride; i++)
                        {
                            for(size_t j = 0; j < lowstride; j++)
                            {
                                outbuf[o * lowstride + j] = max(outbuf[o * lowstride + j],
                                    inbuf[o * highstride + i * lowstride + j]);
                            }
                        }
                    }
                }
            }

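            //Same per-axis stride decomposition as sumKernel, with max as the
            //reduction operator and -T.max as the identity element.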
            auto axes = op.attributes["axes"].get!(size_t[]);
            auto shape = op.deps[0].shape.dup;

            auto inbuf = cast(const(T)[])inputs[0];
            T[] outbuf;

            foreach(axis; axes)
            {
                auto newvol = shape.fold!((a, b) => a * b)(size_t(1)) / shape[axis];
                size_t lowstride;

                if(axis == shape.length - 1)
                {
                    lowstride = 1;
                }
                else
                {
                    lowstride = shape[axis + 1 .. $].fold!((a, b) => a * b)(size_t(1));
                }

                size_t highstride = lowstride * shape[axis];

                outbuf = new T[newvol];
                process(inbuf, outbuf, highstride, lowstride);
                inbuf = outbuf;

                shape[axis] = 1;
            }

            output[] = outbuf[];
        }

        switch(op.outputType.elementType)
        {
            case DataType.float32:
                run!float();
                break;

            case DataType.int32:
                run!int();
                break;

            default:
                throw new Exception("maxElement is not implemented for this data type.");
        }
    }

    void argminKernel(Operation op, const(void[])[] inputs, void[] output)
    {
        void run(T)()
        {
            auto inbuf = cast(const(T)[])inputs[0];
            auto outbuf = cast(int[])output;

            size_t axis = op.attributes["axis"].get!size_t;
            size_t outer = 1;
            size_t inner;
            size_t vol = 1;

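            //View the input as an [outer, inner, vol] tensor, where inner is
            //the extent of the reduction axis: outer collects the dimensions
            //before the axis and vol those after it. For a [2, 3, 4] tensor
            //with axis = 1, this gives outer = 2, inner = 3, vol = 4, and the
            //output holds outer * vol = 8 indices.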
            for(size_t i = 0; i < op.deps[0].rank; i++)
            {
                if(i < axis)
                {
                    outer *= op.deps[0].shape[i];
                }
                else if(i > axis)
                {
                    vol *= op.deps[0].shape[i];
                }
                else
                {
                    inner = op.deps[0].shape[i];
                }
            }

            auto vals = new T[vol];

            for(size_t o = 0; o < outer; o++)
            {
                vals[] = T.max;

                //Ensure every output index is defined, even when all elements
                //along the axis equal T.max and the comparison below never fires.
                outbuf[o * vol .. (o + 1) * vol] = 0;

                for(size_t i = 0; i < inner; i++)
                {
                    for(size_t j = 0; j < vol; j++)
                    {
                        if(inbuf[o * vol * inner + i * vol + j] < vals[j])
                        {
                            vals[j] = inbuf[o * vol * inner + i * vol + j];
                            outbuf[o * vol + j] = cast(int)i;
                        }
                    }
                }
            }
        }

        switch(op.deps[0].outputType.elementType)
        {
            case DataType.float32:
                run!float();
                break;

            case DataType.int32:
                run!int();
                break;

            default:
                throw new Exception("argmin is not implemented for this data type.");
        }
    }

    import std.meta : AliasSeq;

    enum arith = AliasSeq!("add", "sub", "mul", "div");
    enum comp = AliasSeq!("lt", "lte", "gt", "gte", "eq", "neq");
    enum binfunc = AliasSeq!("max", "min", "pow", "atan2");
    enum unfunc = AliasSeq!("neg", "abs", "sgn", "exp", "log", "sqrt", "sin", "cos", "tan", "asin", "acos",
        "atan", "sinh", "cosh", "tanh", "asinh", "acosh", "atanh");

    enum opsymbol = ["add": "+", "sub": "-", "mul": "*", "div": "/", "lt": "<", "lte": "<=",
        "gt": ">", "gte": ">=", "eq": "==", "neq": "!=", "neg": "-",
        "exp": "expCast", "sqrt": "sqrtCast"];

    enum types = ["float": "float32", "int": "int32"];

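    //Operations with an entry in opsymbol are emitted as operator expressions
    //(or calls to the helper templates above); anything without an entry falls
    //back to a call of the same name, e.g. "pow(ins[0][i], ins[1][i])".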
    string generateRegistrations()
    {
        import std.array : appender;

        auto strBuilder = appender!string();

        static foreach(x; AliasSeq!(arith, comp, binfunc, unfunc))
        {
            strBuilder.put("registerCPUKernel(\"" ~ x ~ "\", new CPUKernelDelegate(toDelegate(&" ~ x ~ "Kernel)));");
        }

        return strBuilder.data;
    }

    string generateKernels()
    {
        import std.array : appender;

        auto kernelStrings = appender!string();

        //This is used for generating a kernel for a specific operation and type combination
        string generateSingleKernel(string op, string dtype, string expr)
        {
            return
                "void " ~ op ~ "Kernel_" ~ dtype ~ "(Operation op, const(void[])[] inputs, void[] output)
                {
                    auto ins = inputs.map!(x => cast(const(" ~ dtype ~ ")[])x).array();
                    auto outs = cast(" ~ dtype ~ "[])output;

                    for(size_t i = 0; i < outs.length; i++)
                    {
                        outs[i] = cast(" ~ dtype ~ ")(" ~ expr ~ ");
                    }
                }
                ";
        }
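
        //For example, generateSingleKernel("add", "float", "ins[0][i] + ins[1][i]")
        //produces code equivalent to:
        //
        //    void addKernel_float(Operation op, const(void[])[] inputs, void[] output)
        //    {
        //        auto ins = inputs.map!(x => cast(const(float)[])x).array();
        //        auto outs = cast(float[])output;
        //
        //        for(size_t i = 0; i < outs.length; i++)
        //        {
        //            outs[i] = cast(float)(ins[0][i] + ins[1][i]);
        //        }
        //    }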

        //Dispatches the arguments to the correct kernel, as determined by the output type of the operation
        string generateTypedKernel(string op, string[string] types)
        {
            string ret =
                "void " ~ op ~ "Kernel(Operation op, const(void[])[] inputs, void[] output)
                {
                    switch(op.outputType.elementType)
                    {
                        ";

            foreach(dtype, vtype; types)
            {
                ret ~= "case DataType." ~ vtype ~ ": " ~ op ~ "Kernel_" ~ dtype ~ "(op, inputs, output); break;\n";
            }

            ret ~= "default: throw new Exception(\"Unknown data type\");
                }
            }
            ";

            return ret;
        }
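
        //For example, generateTypedKernel("add", types) produces a dispatcher
        //equivalent to:
        //
        //    void addKernel(Operation op, const(void[])[] inputs, void[] output)
        //    {
        //        switch(op.outputType.elementType)
        //        {
        //            case DataType.float32: addKernel_float(op, inputs, output); break;
        //            case DataType.int32: addKernel_int(op, inputs, output); break;
        //            default: throw new Exception("Unknown data type");
        //        }
        //    }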

        string sym;
        string expr;

        //Iterate over each type of (binary) operation and generate the kernels
        static foreach(op; AliasSeq!(arith, comp, binfunc))
        {
            sym = opsymbol.get(op, "");

            if(sym == "")
            {
                expr = op ~ "(ins[0][i], ins[1][i])";
            }
            else
            {
                expr = "ins[0][i] " ~ sym ~ " ins[1][i]";
            }

            kernelStrings.put(generateTypedKernel(op, types));
            kernelStrings.put(
                         types
                        .keys
                        .map!(x => generateSingleKernel(op, x, expr))
                        .joiner()
                        .to!string
            );
        }

        //Generates kernels for unary operations. Function-based operations are
        //evaluated in single precision and then cast back to the output type.
        static foreach(op; unfunc)
        {
            sym = opsymbol.get(op, "");

            if(sym == "")
            {
                expr = op ~ "(cast(float)ins[0][i])";
            }
            else
            {
                expr = sym ~ "(ins[0][i])";
            }

            kernelStrings.put(generateTypedKernel(op, types));
            kernelStrings.put(
                         types
                        .keys
                        .map!(x => generateSingleKernel(op, x, expr))
                        .joiner()
                        .to!string
            );
        }

        //Return all the source code we've generated so it can be mixed in
        return kernelStrings.data;
    }
}