module dopt.core.cpu.math;

import dopt.core;

import std.algorithm;
import std.conv;
import std.functional;
import std.math;
import std.range;

package
{
    void initialize()
    {
        //Registers the generated elementwise kernels (arith/comp/binfunc/unfunc) plus the
        //hand-written reduction and matmul kernels below.
        mixin(generateRegistrations());

        registerCPUKernel("matmul", new CPUKernelDelegate(toDelegate(&matmulKernel)));
        registerCPUKernel("sum", new CPUKernelDelegate(toDelegate(&sumKernel)));
        registerCPUKernel("maxElement", new CPUKernelDelegate(toDelegate(&maxElementKernel)));
        registerCPUKernel("argmin", new CPUKernelDelegate(toDelegate(&argminKernel)));
    }

    mixin(generateKernels());
}

private
{
    /**
        exp() that is well-defined for integer operands: computes in float and truncates
        back to int. Floating types go straight through std.math.exp.
    */
    T expCast(T)(T t)
    {
        static if(is(T : int))
        {
            return cast(int)exp(cast(float)t);
        }
        else
        {
            return exp(t);
        }
    }

    /**
        sqrt() that is well-defined for integer operands: computes in float and truncates
        back to int. Floating types go straight through std.math.sqrt.
    */
    T sqrtCast(T)(T t)
    {
        static if(is(T : int))
        {
            return cast(int)sqrt(cast(float)t);
        }
        else
        {
            return sqrt(t);
        }
    }

    ///Signum: returns -1, 0, or 1 with the same type as the operand.
    T sgn(T)(T t)
    {
        return cast(T)((t > 0) - (t < 0));
    }

    /**
        Row-major matrix multiply via BLAS sgemm. Only float32 is implemented;
        other element types throw.
    */
    void matmulKernel(Operation op, const(Buffer)[] inputs, Buffer output)
    {
        if(op.outputType.elementType == DataType.float32)
        {
            auto ashape = op.deps[0].outputType.shape;
            auto bshape = op.deps[1].outputType.shape;

            import cblas;

            //C[a0 x b1] = A[a0 x a1] * B[a1 x b1], no transposition, alpha = 1, beta = 0
            gemm(Order.RowMajor, Transpose.NoTrans, Transpose.NoTrans,
                cast(int)ashape[0], cast(int)bshape[1], cast(int)ashape[1], 1.0, cast(float *)inputs[0].as!float.ptr,
                cast(int)ashape[1], cast(float *)inputs[1].as!float.ptr, cast(int)bshape[1], 0,
                cast(float *)output.as!float.ptr, cast(int)bshape[1]);
        }
        else
        {
            throw new Exception("Not implemented.");
        }
    }

    /**
        Sum-reduction over the axes given in op.attributes["axes"]. The reduction is applied
        one axis at a time; after each pass the reduced axis has extent 1. An empty axes
        array is the identity (the input is copied to the output unchanged).
    */
    void sumKernel(Operation op, const(Buffer)[] inputs, Buffer output)
    {
        void run(T)()
        {
            import std.algorithm : fold, sort;

            //Sums groups of `highstride / lowstride` contiguous runs of length `lowstride`
            void process(const(T)[] inbuf, T[] outbuf, size_t highstride, size_t lowstride)
            {
                import std.array : array;
                import std.range : iota;
                import std.parallelism : parallel;

                foreach(o; iota(0, outbuf.length / lowstride).array().parallel)
                {
                    outbuf[o * lowstride .. (o + 1) * lowstride] = 0;

                    for(size_t i = 0; i < highstride / lowstride; i++)
                    {
                        outbuf[o * lowstride .. (o + 1) * lowstride] +=
                            inbuf[o * highstride + i * lowstride .. o * highstride + (i + 1) * lowstride];
                    }
                }
            }

            auto axes = op.attributes["axes"].get!(size_t[]);
            auto shape = op.deps[0].shape.dup;

            auto inbuf = inputs[0].as!T;
            T[] outbuf;

            //Reducing over no axes is the identity; without this guard outbuf would stay
            //null and the final copy would be a length-mismatched slice assignment.
            if(axes.length == 0)
            {
                output.as!T[] = inbuf[];
                return;
            }

            foreach(axis; axes)
            {
                auto newvol = shape.fold!((a, b) => a * b)(size_t(1)) / shape[axis];
                size_t lowstride;

                if(axis == shape.length - 1)
                {
                    lowstride = 1;
                }
                else
                {
                    lowstride = shape[axis + 1 .. $].fold!((a, b) => a * b)(size_t(1));
                }

                size_t highstride = lowstride * shape[axis];

                outbuf = new T[newvol];
                process(inbuf, outbuf, highstride, lowstride);
                inbuf = outbuf;

                //The reduced axis now has extent 1; later passes see the updated shape
                shape[axis] = 1;
            }

            output.as!T[] = outbuf[];
        }

        switch(op.outputType.elementType)
        {
            case DataType.float32:
                run!float();
                break;

            case DataType.int32:
                run!int();
                break;

            default:
                throw new Exception("Not implemented.");
        }
    }

    /**
        Max-reduction over the axes given in op.attributes["axes"], one axis at a time
        (same traversal scheme as sumKernel). An empty axes array is the identity.
    */
    void maxElementKernel(Operation op, const(Buffer)[] inputs, Buffer output)
    {
        void run(T)()
        {
            import std.algorithm : fold, max, sort;

            //Seed with the smallest representable value. For integral T that is T.min;
            //-T.max would incorrectly exclude int.min from an all-int.min input.
            static if(__traits(isFloating, T))
            {
                enum T seed = -T.max;
            }
            else
            {
                enum T seed = T.min;
            }

            void process(const(T)[] inbuf, T[] outbuf, size_t highstride, size_t lowstride)
            {
                for(size_t o = 0; o < outbuf.length / lowstride; o++)
                {
                    outbuf[o * lowstride .. (o + 1) * lowstride] = seed;

                    for(size_t i = 0; i < highstride / lowstride; i++)
                    {
                        for(size_t j = 0; j < lowstride; j++)
                        {
                            outbuf[o * lowstride + j] = max(outbuf[o * lowstride + j],
                                inbuf[o * highstride + i * lowstride + j]);
                        }
                    }
                }
            }

            auto axes = op.attributes["axes"].get!(size_t[]);
            auto shape = op.deps[0].shape.dup;

            auto inbuf = inputs[0].as!T;
            T[] outbuf;

            //Reducing over no axes is the identity; see sumKernel for rationale.
            if(axes.length == 0)
            {
                output.as!T[] = inbuf[];
                return;
            }

            foreach(axis; axes)
            {
                auto newvol = shape.fold!((a, b) => a * b)(size_t(1)) / shape[axis];
                size_t lowstride;

                if(axis == shape.length - 1)
                {
                    lowstride = 1;
                }
                else
                {
                    lowstride = shape[axis + 1 .. $].fold!((a, b) => a * b)(size_t(1));
                }

                size_t highstride = lowstride * shape[axis];

                outbuf = new T[newvol];
                process(inbuf, outbuf, highstride, lowstride);
                inbuf = outbuf;

                shape[axis] = 1;
            }

            output.as!T[] = outbuf[];
        }

        switch(op.outputType.elementType)
        {
            case DataType.float32:
                run!float();
                break;

            case DataType.int32:
                run!int();
                break;

            default:
                throw new Exception("Not implemented.");
        }
    }

    /**
        Index of the minimum element along the single axis op.attributes["axis"].
        The input is viewed as outer x inner x vol, reducing over the `inner`
        (axis) dimension; the int32 output has shape outer x vol.
    */
    void argminKernel(Operation op, const(Buffer)[] inputs, Buffer output)
    {
        void run(T)()
        {
            auto inbuf = inputs[0].as!T;
            auto outbuf = output.as!int;

            size_t axis = op.attributes["axis"].get!size_t;
            size_t outer = 1;
            size_t inner;
            size_t vol = 1;

            //Factor the shape into (dims before axis) x (axis) x (dims after axis)
            for(size_t i = 0; i < op.deps[0].rank; i++)
            {
                if(i < axis)
                {
                    outer *= op.deps[0].shape[i];
                }
                else if(i > axis)
                {
                    vol *= op.deps[0].shape[i];
                }
                else
                {
                    inner = op.deps[0].shape[i];
                }
            }

            auto vals = new T[vol];

            for(size_t o = 0; o < outer; o++)
            {
                vals[] = T.max;

                //Default each result index to 0, so the output is well-defined even
                //when no element compares strictly less than the T.max seed
                //(e.g. the whole slice equals T.max).
                outbuf[o * vol .. (o + 1) * vol] = 0;

                for(size_t i = 0; i < inner; i++)
                {
                    for(size_t j = 0; j < vol; j++)
                    {
                        if(inbuf[o * vol * inner + i * vol + j] < vals[j])
                        {
                            vals[j] = inbuf[o * vol * inner + i * vol + j];
                            outbuf[o * vol + j] = cast(int)i;
                        }
                    }
                }
            }
        }

        //Dispatch on the *input* element type; the output is always int32 indices
        switch(op.deps[0].outputType.elementType)
        {
            case DataType.float32:
                run!float();
                break;

            case DataType.int32:
                run!int();
                break;

            default:
                throw new Exception("Not implemented.");
        }
    }

    immutable string[] arith = ["add", "sub", "mul", "div"];
    immutable string[] comp = ["lt", "lte", "gt", "gte", "eq", "neq"];
    immutable string[] binfunc = ["max", "min", "pow"];
    immutable string[] unfunc = ["neg", "abs", "sgn", "exp", "log", "sqrt"];

    ///Generates the registerCPUKernel(...) calls for every elementwise op, for mixing into initialize().
    string generateRegistrations()
    {
        return chain(arith, comp, binfunc, unfunc)
              .map!(x => "registerCPUKernel(\"" ~ x ~ "\", new CPUKernelDelegate(toDelegate(&" ~ x ~ "Kernel)));")
              .joiner("\n")
              .to!string;
    }

    /**
        Generates the source for all elementwise kernels: for each op, one type-dispatching
        entry point (opKernel) plus one concrete loop per element type (opKernel_float,
        opKernel_int). The result is mixed in at module scope above.
    */
    string generateKernels()
    {
        //Ops realised with an operator or a named cast-wrapper; everything else maps to a
        //same-named free function (max, min, pow, abs, sgn, log)
        string[string] opsymbol = ["add": "+", "sub": "-", "mul": "*", "div": "/", "lt": "<", "lte": "<=",
                                   "gt": ">", "gte": ">=", "eq": "==", "neq": "!=", "neg": "-",
                                   "exp": "expCast", "sqrt": "sqrtCast"];

        string[string] types = ["float": "float32", "int": "int32"];

        string[] kernelStrings;

        //This is used for generating a kernel for a specific operation and type combination
        string generateSingleKernel(string op, string dtype, string expr)
        {
            return
                "void " ~ op ~ "Kernel_" ~ dtype ~ "(Operation op, const(Buffer)[] inputs, Buffer output)
                {
                    auto ins = inputs.map!(x => x.as!" ~ dtype ~ ").array();
                    auto outs = output.as!" ~ dtype ~ ";

                    for(size_t i = 0; i < outs.length; i++)
                    {
                        outs[i] = cast(" ~ dtype ~ ")(" ~ expr ~ ");
                    }
                }
                ";
        }

        //Dispatches the arguments to the correct kernel, as determined by the output type of the operation
        string generateTypedKernel(string op, string[string] types)
        {
            string ret =
                "void " ~ op ~ "Kernel(Operation op, const(Buffer)[] inputs, Buffer output)
                {
                    switch(op.outputType.elementType)
                    {
                        ";

            foreach(dtype, vtype; types)
            {
                ret ~= "case DataType." ~ vtype ~ ": " ~ op ~ "Kernel_" ~ dtype ~ "(op, inputs, output); break;\n";
            }

            ret ~= "default: throw new Exception(\"Unknown data type\");
                    }
                }
                ";

            return ret;
        }

        //Iterate over each type of (binary) operation and generate the kernels
        foreach(op; chain(arith, comp, binfunc))
        {
            string sym = opsymbol.get(op, "");
            string expr;

            if(sym == "")
            {
                expr = op ~ "(ins[0][i], ins[1][i])";
            }
            else
            {
                //Space on both sides of the operator, e.g. "ins[0][i] <= ins[1][i]"
                expr = "ins[0][i] " ~ sym ~ " ins[1][i]";
            }

            auto mux = generateTypedKernel(op, types);
            auto kerns = types
                        .keys
                        .map!(x => generateSingleKernel(op, x, expr))
                        .joiner()
                        .to!string;

            kernelStrings ~= mux;
            kernelStrings ~= kerns;
        }

        //Generates kernels for unary operations
        foreach(op; unfunc)
        {
            string sym = opsymbol.get(op, "");
            string expr;

            if(sym == "")
            {
                expr = op ~ "(ins[0][i])";
            }
            else
            {
                expr = sym ~ "(ins[0][i])";
            }

            auto mux = generateTypedKernel(op, types);
            auto kerns = types
                        .keys
                        .map!(x => generateSingleKernel(op, x, expr))
                        .joiner()
                        .to!string;

            kernelStrings ~= mux;
            kernelStrings ~= kerns;
        }

        //Return all the source code we've generated so it can be mixed in
        return kernelStrings.joiner().to!string;
    }
}