module dopt.cpu.math;

import dopt.core;
import dopt.cpu;

import std.algorithm;
import std.conv;
import std.functional;
import std.math;
import std.range;

package
{
    void initialize()
    {
        //Register the generated elementwise kernels, then the hand-written ones
        mixin(generateRegistrations());

        registerCPUKernel("matmul", new CPUKernelDelegate(toDelegate(&matmulKernel)));
        registerCPUKernel("sum", new CPUKernelDelegate(toDelegate(&sumKernel)));
        registerCPUKernel("maxElement", new CPUKernelDelegate(toDelegate(&maxElementKernel)));
        registerCPUKernel("argmin", new CPUKernelDelegate(toDelegate(&argminKernel)));
    }

    mixin(generateKernels());
}

private
{
    //exp that works for both integer and floating point types, routing
    //integers through float
    T expCast(T)(T t)
    {
        static if(is(T : int))
        {
            return cast(int)exp(cast(float)t);
        }
        else
        {
            return exp(t);
        }
    }

    //sqrt that works for both integer and floating point types
    T sqrtCast(T)(T t)
    {
        static if(is(T : int))
        {
            return cast(int)sqrt(cast(float)t);
        }
        else
        {
            return sqrt(t);
        }
    }

    //atan2 that works for both integer and floating point types
    T atan2(T)(T a, T b)
    {
        static if(is(T : int))
        {
            return cast(int)std.math.atan2(cast(float)a, cast(float)b);
        }
        else
        {
            return std.math.atan2(a, b);
        }
    }

    //Signum: returns -1, 0, or 1 with the same type as the argument
    T sgn(T)(T t)
    {
        return cast(T)((t > 0) - (t < 0));
    }

    void matmulKernel(Operation op, const(void[])[] inputs, void[] output)
    {
        if(op.outputType.elementType == DataType.float32)
        {
            auto ashape = op.deps[0].outputType.shape;
            auto bshape = op.deps[1].outputType.shape;

            import cblas;

            //C = 1 * A * B + 0 * C, where A is [M, K] and B is [K, N], row-major
            gemm(Order.RowMajor, Transpose.NoTrans, Transpose.NoTrans,
                cast(int)ashape[0], cast(int)bshape[1], cast(int)ashape[1], 1.0f, cast(float *)inputs[0].ptr,
                cast(int)ashape[1], cast(float *)inputs[1].ptr, cast(int)bshape[1], 0.0f,
                cast(float *)output.ptr, cast(int)bshape[1]);
        }
        else
        {
            throw new Exception("Not implemented.");
        }
    }

    void sumKernel(Operation op, const(void[])[] inputs, void[] output)
    {
        void run(T)()
        {
            import std.algorithm : fold;

            //Sums over a single axis: highstride is the volume of the reduced
            //axis and everything after it, lowstride the volume after it only
            void process(const(T)[] inbuf, T[] outbuf, size_t highstride, size_t lowstride)
            {
                import std.array : array;
                import std.range : iota;
                import std.parallelism : parallel;

                foreach(o; iota(0, outbuf.length / lowstride).array().parallel)
                {
                    if(lowstride == 1)
                    {
                        outbuf[o] = inbuf[o * highstride .. (o + 1) * highstride]
                                   .fold!((a, b) => a + b)(cast(T)0);
                    }
                    else
                    {
                        outbuf[o * lowstride .. (o + 1) * lowstride] = 0;

                        for(size_t i = 0; i < highstride / lowstride; i++)
                        {
                            outbuf[o * lowstride .. (o + 1) * lowstride] +=
                                inbuf[o * highstride + i * lowstride .. o * highstride + (i + 1) * lowstride];
                        }
                    }
                }
            }

            auto axes = op.attributes["axes"].get!(size_t[]);
            auto shape = op.deps[0].shape.dup;

            auto inbuf = cast(const(T)[])inputs[0];
            T[] outbuf;

            //Reduce over one axis at a time, feeding each intermediate result
            //back in as the next input
            foreach(axis; axes)
            {
                auto newvol = shape.fold!((a, b) => a * b)(size_t(1)) / shape[axis];
                size_t lowstride;

                if(axis == shape.length - 1)
                {
                    lowstride = 1;
                }
                else
                {
                    lowstride = shape[axis + 1 .. $].fold!((a, b) => a * b)(size_t(1));
                }

                size_t highstride = lowstride * shape[axis];

                outbuf = new T[newvol];
                process(inbuf, outbuf, highstride, lowstride);
                inbuf = outbuf;

                shape[axis] = 1;
            }

            output[] = outbuf[];
        }

        switch(op.outputType.elementType)
        {
            case DataType.float32:
                run!float();
                break;

            case DataType.int32:
                run!int();
                break;

            default:
                throw new Exception("Not implemented.");
        }
    }
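    //A minimal, self-contained sketch of the stride arithmetic the reduction
    //kernels rely on; the shape and axis are assumptions chosen purely for
    //illustration
    unittest
    {
        import std.algorithm : fold;

        size_t[] shape = [2, 3, 4];
        size_t axis = 1;

        auto lowstride = shape[axis + 1 .. $].fold!((a, b) => a * b)(size_t(1));
        auto highstride = lowstride * shape[axis];
        auto newvol = shape.fold!((a, b) => a * b)(size_t(1)) / shape[axis];

        //Reducing axis 1 of a row-major [2, 3, 4] tensor folds 3 strided
        //slices of length 4 into each of the 8 output groups
        assert(lowstride == 4);
        assert(highstride == 12);
        assert(newvol == 8);
    }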
    void maxElementKernel(Operation op, const(void[])[] inputs, void[] output)
    {
        void run(T)()
        {
            import std.algorithm : fold, max;
            import std.parallelism : parallel;

            //Same stride decomposition as sumKernel, but accumulating with max
            void process(const(T)[] inbuf, T[] outbuf, size_t highstride, size_t lowstride)
            {
                foreach(o; iota(0, outbuf.length / lowstride).array().parallel)
                {
                    if(lowstride == 1)
                    {
                        outbuf[o] = inbuf[o * highstride .. (o + 1) * highstride].fold!((a, b) => max(a, b))(-T.max);
                    }
                    else
                    {
                        outbuf[o * lowstride .. (o + 1) * lowstride] = -T.max;

                        for(size_t i = 0; i < highstride / lowstride; i++)
                        {
                            for(size_t j = 0; j < lowstride; j++)
                            {
                                outbuf[o * lowstride + j] = max(outbuf[o * lowstride + j],
                                    inbuf[o * highstride + i * lowstride + j]);
                            }
                        }
                    }
                }
            }

            auto axes = op.attributes["axes"].get!(size_t[]);
            auto shape = op.deps[0].shape.dup;

            auto inbuf = cast(const(T)[])inputs[0];
            T[] outbuf;

            foreach(axis; axes)
            {
                auto newvol = shape.fold!((a, b) => a * b)(size_t(1)) / shape[axis];
                size_t lowstride;

                if(axis == shape.length - 1)
                {
                    lowstride = 1;
                }
                else
                {
                    lowstride = shape[axis + 1 .. $].fold!((a, b) => a * b)(size_t(1));
                }

                size_t highstride = lowstride * shape[axis];

                outbuf = new T[newvol];
                process(inbuf, outbuf, highstride, lowstride);
                inbuf = outbuf;

                shape[axis] = 1;
            }

            output[] = outbuf[];
        }

        switch(op.outputType.elementType)
        {
            case DataType.float32:
                run!float();
                break;

            case DataType.int32:
                run!int();
                break;

            default:
                throw new Exception("Not implemented.");
        }
    }

    void argminKernel(Operation op, const(void[])[] inputs, void[] output)
    {
        void run(T)()
        {
            auto inbuf = cast(const(T)[])inputs[0];
            auto outbuf = cast(int[])output;

            //Decompose the shape into [outer, inner, vol], where inner is the
            //extent of the axis being reduced over
            size_t axis = op.attributes["axis"].get!size_t;
            size_t outer = 1;
            size_t inner;
            size_t vol = 1;

            for(size_t i = 0; i < op.deps[0].rank; i++)
            {
                if(i < axis)
                {
                    outer *= op.deps[0].shape[i];
                }
                else if(i > axis)
                {
                    vol *= op.deps[0].shape[i];
                }
                else
                {
                    inner = op.deps[0].shape[i];
                }
            }

            //Track the running minimum for each trailing position, recording
            //the index along the reduced axis whenever it improves
            auto vals = new T[vol];

            for(size_t o = 0; o < outer; o++)
            {
                vals[] = T.max;

                for(size_t i = 0; i < inner; i++)
                {
                    for(size_t j = 0; j < vol; j++)
                    {
                        if(inbuf[o * vol * inner + i * vol + j] < vals[j])
                        {
                            vals[j] = inbuf[o * vol * inner + i * vol + j];
                            outbuf[o * vol + j] = cast(int)i;
                        }
                    }
                }
            }
        }

        switch(op.deps[0].outputType.elementType)
        {
            case DataType.float32:
                run!float();
                break;

            case DataType.int32:
                run!int();
                break;

            default:
                throw new Exception("Not implemented.");
        }
    }

    import std.meta : AliasSeq;

    //The names of the elementwise operations to generate kernels for
    enum arith = AliasSeq!("add", "sub", "mul", "div");
    enum comp = AliasSeq!("lt", "lte", "gt", "gte", "eq", "neq");
    enum binfunc = AliasSeq!("max", "min", "pow", "atan2");
    enum unfunc = AliasSeq!("neg", "abs", "sgn", "exp", "log", "sqrt", "sin", "cos", "tan", "asin", "acos",
                            "atan", "sinh", "cosh", "tanh", "asinh", "acosh", "atanh");

    //Maps an operation name to the D operator or helper that implements it;
    //operations without an entry fall back to a function call of the same name
    enum opsymbol = ["add": "+", "sub": "-", "mul": "*", "div": "/", "lt": "<", "lte": "<=",
                     "gt": ">", "gte": ">=", "eq": "==", "neq": "!=", "neg": "-",
                     "exp": "expCast", "sqrt": "sqrtCast"];

    //Maps a D type to the corresponding DataType member
    enum types = ["float": "float32", "int": "int32"];
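    //As a concrete reference for the generators below, a call like
    //generateSingleKernel("add", "float", "ins[0][i] + ins[1][i]") produces
    //(modulo whitespace) the following source, which is then mixed in:
    //
    //  void addKernel_float(Operation op, const(void[])[] inputs, void[] output)
    //  {
    //      auto ins = inputs.map!(x => cast(const(float)[])x).array();
    //      auto outs = cast(float[])output;
    //
    //      for(size_t i = 0; i < outs.length; i++)
    //      {
    //          outs[i] = cast(float)(ins[0][i] + ins[1][i]);
    //      }
    //  }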
"gt", "gte", "eq", "neq"); 316 enum binfunc = AliasSeq!("max", "min", "pow", "atan2"); 317 enum unfunc = AliasSeq!("neg", "abs", "sgn", "exp", "log", "sqrt", "sin", "cos", "tan", "asin", "acos", 318 "atan", "sinh", "cosh", "tanh", "asinh", "acosh", "atanh"); 319 320 enum opsymbol = ["add": "+", "sub": "-", "mul": "*", "div": "/", "lt": "<", "lte": "<=", 321 "gt": ">", "gte": ">=", "eq": "==", "neq": "!=", "neg": "-", 322 "exp": "expCast", "sqrt": "sqrtCast"]; 323 324 enum types = ["float": "float32", "int": "int32"]; 325 326 string generateRegistrations() 327 { 328 import std.array : appender; 329 330 auto strBuilder = appender!string(); 331 332 static foreach(x; AliasSeq!(arith, comp, binfunc, unfunc)) 333 { 334 strBuilder.put("registerCPUKernel(\"" ~ x ~ "\", new CPUKernelDelegate(toDelegate(&" ~ x ~ "Kernel)));"); 335 } 336 337 return strBuilder.data; 338 } 339 340 string generateKernels() 341 { 342 auto kernelStrings = appender!(string); 343 344 //This is used for generating a kernel for a specific operation and type combination 345 string generateSingleKernel(string op, string dtype, string expr) 346 { 347 return 348 "void " ~ op ~ "Kernel_" ~ dtype ~ "(Operation op, const(void[])[] inputs, void[] output) 349 { 350 auto ins = inputs.map!(x => cast(const(" ~ dtype ~ ")[])x).array(); 351 auto outs = cast(" ~ dtype ~ "[])output; 352 353 for(size_t i = 0; i < outs.length; i++) 354 { 355 outs[i] = cast(" ~ dtype ~ ")(" ~ expr ~ "); 356 } 357 } 358 "; 359 } 360 361 //Dispatches the arguments to the correct kernel, as determined by the output type of the operation 362 string generateTypedKernel(string op, string[string] types) 363 { 364 string ret = 365 "void " ~ op ~ "Kernel(Operation op, const(void[])[] inputs, void[] output) 366 { 367 switch(op.outputType.elementType) 368 { 369 "; 370 371 foreach(dtype, vtype; types) 372 { 373 ret ~= "case DataType." ~ vtype ~ ": " ~ op ~ "Kernel_" ~ dtype ~ "(op, inputs, output); break;\n"; 374 } 375 376 ret ~= "default: throw new Exception(\"Unknown data type\"); 377 } 378 } 379 "; 380 381 return ret; 382 } 383 384 string sym; 385 string expr; 386 387 //Iterate over each type of (binary) operation and generate the kernels 388 static foreach(op; AliasSeq!(arith, comp, binfunc)) 389 { 390 sym = opsymbol.get(op, ""); 391 392 if(sym == "") 393 { 394 expr = op ~ "(ins[0][i], ins[1][i])"; 395 } 396 else 397 { 398 expr = "ins[0][i] " ~ sym ~ "ins[1][i]"; 399 } 400 401 kernelStrings.put(generateTypedKernel(op, types)); 402 kernelStrings.put( 403 types 404 .keys 405 .map!(x => generateSingleKernel(op, x, expr)) 406 .joiner() 407 .to!string 408 ); 409 } 410 411 //Generates kernels for unary operations 412 static foreach(op; unfunc) 413 { 414 sym = opsymbol.get(op, ""); 415 416 if(sym == "") 417 { 418 expr = op ~ "(cast(float)ins[0][i])"; 419 } 420 else 421 { 422 expr = sym ~ "(ins[0][i])"; 423 } 424 425 kernelStrings.put(generateTypedKernel(op, types)); 426 kernelStrings.put( 427 types 428 .keys 429 .map!(x => generateSingleKernel(op, x, expr)) 430 .joiner() 431 .to!string 432 ); 433 } 434 435 //Return all the source code we've generated so it can be mixed in 436 return kernelStrings.data; 437 } 438 }