1 module dopt.core.cpu.math;
2 
3 import dopt.core;
4 
5 import std.algorithm;
6 import std.conv;
7 import std.functional;
8 import std.math;
9 import std.range;
10 
package
{
    /**
        Registers every CPU kernel implemented by this module with the dopt
        CPU backend.

        The elementwise kernels (arithmetic, comparison, binary and unary
        functions) are generated at compile time — generateRegistrations()
        produces one registerCPUKernel call per generated kernel — while the
        reduction-style kernels below it are hand-written and registered
        explicitly.
    */
    void initialize()
    {
        mixin(generateRegistrations());

        registerCPUKernel("matmul", new CPUKernelDelegate(toDelegate(&matmulKernel)));
        registerCPUKernel("sum", new CPUKernelDelegate(toDelegate(&sumKernel)));
        registerCPUKernel("maxElement", new CPUKernelDelegate(toDelegate(&maxElementKernel)));
        registerCPUKernel("argmin", new CPUKernelDelegate(toDelegate(&argminKernel)));
    }

    //Mix in the source of the compile-time generated elementwise kernels.
    mixin(generateKernels());
}
25 
26 private
27 {
28     T expCast(T)(T t)
29     {
30         static if(is(T : int))
31         {
32             return cast(int)exp(cast(float)t);
33         }
34         else
35         {
36             return exp(t);
37         }
38     }
39 
40     T sqrtCast(T)(T t)
41     {
42         static if(is(T : int))
43         {
44             return cast(int)sqrt(cast(float)t);
45         }
46         else
47         {
48             return sqrt(t);
49         }
50     }
51 
52     T sgn(T)(T t)
53     {
54         return cast(T)((t > 0) - (t < 0));
55     }
56 
    /**
        Matrix multiply kernel: output = inputs[0] * inputs[1] via BLAS gemm.

        Only float32 is implemented; any other element type throws.

        NOTE(review): assumes both operands are rank-2, row-major, with
        deps[0] of shape (m, k) and deps[1] of shape (k, n) — the shape
        indexing below depends on this; confirm the op's type verifier
        guarantees it.
    */
    void matmulKernel(Operation op, const(Buffer)[] inputs, Buffer output)
    {
        if(op.outputType.elementType == DataType.float32)
        {
            auto ashape = op.deps[0].outputType.shape;
            auto bshape = op.deps[1].outputType.shape;

            import cblas;

            //C = 1.0 * A * B + 0.0 * C.  Leading dimensions are the row
            //widths: lda = k (ashape[1]), ldb = ldc = n (bshape[1]).
            gemm(Order.RowMajor, Transpose.NoTrans, Transpose.NoTrans,
                cast(int)ashape[0], cast(int)bshape[1], cast(int)ashape[1], 1.0, cast(float *)inputs[0].as!float.ptr,
                cast(int)ashape[1], cast(float *)inputs[1].as!float.ptr, cast(int)bshape[1], 0,
                cast(float *)output.as!float.ptr, cast(int)bshape[1]);
        }
        else
        {
            throw new Exception("Not implemented.");
        }
    }
76 
77     void sumKernel(Operation op, const(Buffer)[] inputs, Buffer output)
78     {
79         void run(T)()
80         {
81             import std.algorithm : fold, sort;
82 
83             void process(const(T)[] inbuf, T[] outbuf, size_t highstride, size_t lowstride)
84             {
85                 import std.array : array;
86                 import std.range : iota;
87                 import std.parallelism : parallel;
88 
89                 //for(size_t o = 0; o < outbuf.length / lowstride; o++)
90                 foreach(o; iota(0, outbuf.length / lowstride).array().parallel)
91                 {
92                     outbuf[o * lowstride .. (o + 1) * lowstride] = 0;
93 
94                     for(size_t i = 0; i < highstride / lowstride; i++)
95                     {
96                         outbuf[o * lowstride .. (o + 1) * lowstride] +=
97                             inbuf[o * highstride + i * lowstride .. o * highstride + (i + 1) * lowstride];
98                     }
99                 }
100             }
101 
102             auto axes = op.attributes["axes"].get!(size_t[]);
103             auto shape = op.deps[0].shape.dup;
104 
105             auto inbuf = inputs[0].as!T;
106             T[] outbuf;
107 
108             foreach(axis; axes)
109             {
110                 //auto axis = axes[0];
111                 auto newvol = shape.fold!((a, b) => a * b)(size_t(1)) / shape[axis];
112                 size_t lowstride;
113                 
114                 if(axis == shape.length - 1)
115                 {
116                     lowstride = 1;
117                 }
118                 else
119                 {
120                     lowstride = shape[axis + 1 .. $].fold!((a, b) => a * b)(size_t(1));
121                 }
122 
123                 size_t highstride = lowstride * shape[axis];
124 
125                 outbuf = new T[newvol];
126                 process(inbuf, outbuf, highstride, lowstride);
127                 inbuf = outbuf;
128 
129                 shape[axis] = 1;
130             }
131 
132             output.as!T[] = outbuf[];
133         }
134 
135         switch(op.outputType.elementType)
136         {
137             case DataType.float32:
138                 run!float();
139                 break;
140 
141             case DataType.int32:
142                 run!int();
143                 break;
144 
145             default:
146                 throw new Exception("Not implemented.");
147         }
148     }
149 
150     void maxElementKernel(Operation op, const(Buffer)[] inputs, Buffer output)
151     {
152         void run(T)()
153         {
154             import std.algorithm : fold, max, sort;
155 
156             void process(const(T)[] inbuf, T[] outbuf, size_t highstride, size_t lowstride)
157             {
158                 for(size_t o = 0; o < outbuf.length / lowstride; o++)
159                 {
160                     outbuf[o * lowstride .. (o + 1) * lowstride] = -T.max;
161 
162                     for(size_t i = 0; i < highstride / lowstride; i++)
163                     {
164                         for(size_t j = 0; j < lowstride; j++)
165                         {
166                             outbuf[o * lowstride + j] = max(outbuf[o * lowstride + j],
167                                 inbuf[o * highstride + i * lowstride + j]);
168                         }
169                     }
170                 }
171             }
172 
173             auto axes = op.attributes["axes"].get!(size_t[]);
174             auto shape = op.deps[0].shape.dup;
175 
176             auto inbuf = inputs[0].as!T;
177             T[] outbuf;
178 
179             foreach(axis; axes)
180             {
181                 //auto axis = axes[0];
182                 auto newvol = shape.fold!((a, b) => a * b)(size_t(1)) / shape[axis];
183                 size_t lowstride;
184                 
185                 if(axis == shape.length - 1)
186                 {
187                     lowstride = 1;
188                 }
189                 else
190                 {
191                     lowstride = shape[axis + 1 .. $].fold!((a, b) => a * b)(size_t(1));
192                 }
193 
194                 size_t highstride = lowstride * shape[axis];
195 
196                 outbuf = new T[newvol];
197                 process(inbuf, outbuf, highstride, lowstride);
198                 inbuf = outbuf;
199 
200                 shape[axis] = 1;
201             }
202 
203             output.as!T[] = outbuf[];
204         }
205 
206         switch(op.outputType.elementType)
207         {
208             case DataType.float32:
209                 run!float();
210                 break;
211 
212             case DataType.int32:
213                 run!int();
214                 break;
215 
216             default:
217                 throw new Exception("Not implemented.");
218         }
219     }
220 
221     void argminKernel(Operation op, const(Buffer)[] inputs, Buffer output)
222     {
223         void run(T)()
224         {
225             auto inbuf = inputs[0].as!T;
226             auto outbuf = output.as!int;
227 
228             size_t axis = op.attributes["axis"].get!size_t;
229             size_t outer = 1;
230             size_t inner;
231             size_t vol = 1;
232 
233             for(size_t i = 0; i < op.deps[0].rank; i++)
234             {
235                 if(i < axis)
236                 {
237                     outer *= op.deps[0].shape[i];
238                 }
239                 else if(i > axis)
240                 {
241                     vol *= op.deps[0].shape[i];
242                 }
243                 else
244                 {
245                     inner = op.deps[0].shape[i];
246                 }
247             }
248 
249             auto vals = new T[vol];
250 
251             for(size_t o = 0; o < outer; o++)
252             {
253                 vals[] = T.max;
254                 
255                 for(size_t i = 0; i < inner; i++)
256                 {
257                     for(size_t j = 0; j < vol; j++)
258                     {
259                         if(inbuf[o * vol * inner + i * vol + j] < vals[j])
260                         {
261                             vals[j] = inbuf[o * vol * inner + i * vol + j];
262                             outbuf[o * vol + j] = cast(int)i;
263                         }
264                     }
265                 }
266             }
267         }
268 
269         switch(op.deps[0].outputType.elementType)
270         {
271             case DataType.float32:
272                 run!float();
273                 break;
274 
275             case DataType.int32:
276                 run!int();
277                 break;
278 
279             default:
280                 throw new Exception("Not implemented.");
281         }
282     }
283 
    //Operation names for which elementwise kernels are generated below.
    //Binary arithmetic, mapped to infix operators:
    immutable string[] arith = ["add", "sub", "mul", "div"];
    //Binary comparisons, mapped to infix operators:
    immutable string[] comp = ["lt", "lte", "gt", "gte", "eq", "neq"];
    //Binary functions, mapped to same-named function calls:
    immutable string[] binfunc = ["max", "min", "pow"];
    //Unary operations; those in opsymbol map to an operator/wrapper, the rest
    //to same-named function calls:
    immutable string[] unfunc = ["neg", "abs", "sgn", "exp", "log", "sqrt"];
288     
289     string generateRegistrations()
290     {
291         return chain(arith, comp, binfunc, unfunc)
292               .map!(x => "registerCPUKernel(\"" ~ x ~ "\", new CPUKernelDelegate(toDelegate(&" ~ x ~ "Kernel)));")
293               .joiner("\n")
294               .to!string;
295     }
296 
    /**
        Produces, as a single string of D source, the elementwise kernels for
        every operation named in arith, comp, binfunc, and unfunc, plus a
        type-dispatching wrapper for each.  The result is mixed into this
        module at compile time.

        Operations present in opsymbol expand to an operator (or wrapper
        function) applied to the inputs; operations absent from it expand to
        a function call of the same name (e.g. max, abs), resolved against
        this module's imports when the generated code is mixed in.
    */
    string generateKernels()
    {
        //Maps operation names to the operator (or wrapper function) used in
        //the generated expression.  exp/sqrt go through the *Cast wrappers
        //above so integral types compile.
        string[string] opsymbol = ["add": "+", "sub": "-", "mul": "*", "div": "/", "lt": "<", "lte": "<=",
                                         "gt": ">", "gte": ">=", "eq": "==", "neq": "!=", "neg": "-",
                                         "exp": "expCast", "sqrt": "sqrtCast"];

        //Maps D type names to the corresponding DataType enum member.
        string[string] types = ["float": "float32", "int": "int32"];

        string[] kernelStrings;

        //This is used for generating a kernel for a specific operation and type combination
        string generateSingleKernel(string op, string dtype, string expr)
        {
            return
                "void " ~ op ~ "Kernel_" ~ dtype ~ "(Operation op, const(Buffer)[] inputs, Buffer output)
                {
                    auto ins = inputs.map!(x => x.as!" ~ dtype ~ ").array();
                    auto outs = output.as!" ~ dtype ~ ";

                    for(size_t i = 0; i < outs.length; i++)
                    {
                        outs[i] = cast(" ~ dtype ~ ")(" ~ expr ~ ");
                    }
                }
                ";
        }

        //Dispatches the arguments to the correct kernel, as determined by the output type of the operation
        string generateTypedKernel(string op, string[string] types)
        {
            string ret =
                "void " ~ op ~ "Kernel(Operation op, const(Buffer)[] inputs, Buffer output)
                {
                    switch(op.outputType.elementType)
                    {
                        ";

            foreach(dtype, vtype; types)
            {
                ret ~= "case DataType." ~ vtype ~ ": " ~ op ~ "Kernel_" ~ dtype ~ "(op, inputs, output); break;\n";
            }

            ret ~= "default: throw new Exception(\"Unknown data type\");
                }
            }
            ";

            return ret;
        }

        //Iterate over each type of (binary) operation and generate the kernels
        foreach(op; chain(arith, comp, binfunc))
        {
            string sym = opsymbol.get(op, "");
            string expr;

            //No symbol registered: emit a binary function call instead.
            if(sym == "")
            {
                expr = op ~ "(ins[0][i], ins[1][i])";
            }
            else
            {
                expr = "ins[0][i] " ~ sym ~ "ins[1][i]";
            }

            auto mux = generateTypedKernel(op, types);
            auto kerns = types
                        .keys
                        .map!(x => generateSingleKernel(op, x, expr))
                        .joiner()
                        .to!string;

            kernelStrings ~= mux;
            kernelStrings ~= kerns;
        }

        //Generates kernels for unary operations
        foreach(op; unfunc)
        {
            string sym = opsymbol.get(op, "");
            string expr;

            //No symbol registered: emit a unary function call of the same
            //name; otherwise apply the operator/wrapper prefix-style.
            if(sym == "")
            {
                expr = op ~ "(ins[0][i])";
            }
            else
            {
                expr = sym ~ "(ins[0][i])";
            }

            auto mux = generateTypedKernel(op, types);
            auto kerns = types
                        .keys
                        .map!(x => generateSingleKernel(op, x, expr))
                        .joiner()
                        .to!string;

            kernelStrings ~= mux;
            kernelStrings ~= kerns;
        }

        //Return all the source code we've generated so it can be mixed in
        return kernelStrings.joiner().to!string;
    }
402 }