/**
Contains common maths operations.

Authors: Henry Gouk
*/
module dopt.core.ops.math;

import std.algorithm;
import std.conv;
import std.functional;
import std.range;

import dopt.core.ops;
import dopt.core.types;

package
{
    /**
    Registers a verifier/judge pair for every operation defined in this module.

    The pointwise binary operations (the arith, comp and binfunc lists) and the pointwise unary operations (the
    unfunc list) share generic verifiers/judges, while matmul, sum, argmin and maxElement use dedicated ones.
    */
    void initialize()
    {
        //Registers an elementwise operation that takes two inputs of identical type. The output has that same type.
        void registerPointwiseBinary(string opName)
        {
            bool verifier(Operation op)
            {
                return op.deps.length == 2 && op.deps[0].outputType == op.deps[1].outputType;
            }

            TensorType judge(Operation op)
            {
                return TensorType(op.deps[0].outputType);
            }

            registerOperation(opName, OpDef(&verifier, &judge));
        }

        //Registers an elementwise operation that takes a single input. The output has the same type as the input.
        void registerPointwiseUnary(string opName)
        {
            bool verifier(Operation op)
            {
                //The judge below reads op.deps[0], so exactly one dependency is required. Accepting anything else
                //would let a malformed operation through verification and fail later in the judge.
                return op.deps.length == 1;
            }

            TensorType judge(Operation op)
            {
                return TensorType(op.deps[0].outputType);
            }

            registerOperation(opName, OpDef(&verifier, &judge));
        }

        //opName is a compile-time string here, so the helpers can be called directly; no string mixin is needed.
        static foreach(opName; AliasSeq!(arith, comp, binfunc))
        {
            registerPointwiseBinary(opName);
        }

        static foreach(opName; unfunc)
        {
            registerPointwiseUnary(opName);
        }

        registerOperation("matmul", OpDef(toDelegate(&verifyMatmul), toDelegate(&judgeMatmul)));
        registerOperation("sum", OpDef(toDelegate(&verifySum), toDelegate(&judgeSum)));
        registerOperation("argmin", OpDef(toDelegate(&verifyArgmin), toDelegate(&judgeArgmin)));

        //maxElement and sum are both reduction operations
        registerOperation("maxElement", OpDef(toDelegate(&verifySum), toDelegate(&judgeSum)));
    }
}

private
{
    import std.meta : AliasSeq;

    //Names of the pointwise operations. createAllCtors generates a constructor function for each of these, and
    //initialize registers a verifier/judge pair for each of them.
    enum arith = AliasSeq!("add", "sub", "mul", "div");
    enum comp = AliasSeq!("lt", "lte", "gt", "gte", "eq", "neq");
    enum binfunc = AliasSeq!("max", "min", "pow");

    //NOTE(review): atan2 is conventionally a two-argument function, but it is listed here as unary, so its generated
    //constructor takes a single operand. Moving it to binfunc would change the public signature of atan2, so confirm
    //against the backend implementations before doing so.
    enum unfunc = AliasSeq!("neg", "abs", "sgn", "exp", "log", "sqrt", "sin", "cos", "tan", "asin", "acos",
        "atan", "atan2", "sinh", "cosh", "tanh", "asinh", "acosh", "atanh");

    //Generates, at compile time, the D source for a constructor function for every pointwise operation. Binary
    //operations get constructors with two operands (p0, p1) and unary operations get constructors with one (p0).
    //Each generated constructor forwards its operands to createOperation with no attributes.
    string createAllCtors()
    {
        //Builds the source of a single constructor named opName that takes numDeps Operation operands.
        string createOpCtor(string opName, size_t numDeps)
        {
            auto params = iota(0, numDeps)
                         .map!(x => "Operation p" ~ x.to!string)
                         .joiner(", ")
                         .to!string();

            auto args = iota(0, numDeps)
                       .map!(x => "p" ~ x.to!string)
                       .joiner(", ")
                       .to!string;

            return "
    Operation " ~ opName ~ "(" ~ params ~ ", string mod = __MODULE__, size_t line = __LINE__)
    {
        return createOperation(\"" ~ opName ~ "\", [" ~ args ~ "], null, mod, line);
    }
    ";
        }

        import std.array : appender;

        auto strBuilder = appender!string();

        static foreach(f; AliasSeq!(arith, comp, binfunc))
        {
            strBuilder.put(createOpCtor(f, 2));
        }

        static foreach(f; unfunc)
        {
            strBuilder.put(createOpCtor(f, 1));
        }

        return strBuilder.data;
    }

    //A matmul is valid when it has two rank-2 inputs with the same element type and the inner dimensions agree
    //(columns of the left-hand side equal rows of the right-hand side).
    bool verifyMatmul(Operation op)
    {
        return op.deps.length == 2
            && op.deps[0].outputType.rank == 2
            && op.deps[1].outputType.rank == 2
            && op.deps[0].outputType.elementType == op.deps[1].outputType.elementType
            && op.deps[0].outputType.shape[1] == op.deps[1].outputType.shape[0];
    }

    //The matmul output has the lhs element type and shape [lhs rows, rhs columns].
    TensorType judgeMatmul(Operation op)
    {
        return TensorType(op.deps[0].outputType.elementType,
            [op.deps[0].outputType.shape[0], op.deps[1].outputType.shape[1]]);
    }

    //Verifies a reduction (used for both sum and maxElement). It requires exactly one dependency and an "axes"
    //attribute of type size_t[]. Every axis must be less than the input rank, and the axes must be distinct.
    bool verifySum(Operation op)
    {
        if(op.deps.length != 1)
        {
            return false;
        }

        if(("axes" in op.attributes) is null || op.attributes["axes"].peek!(size_t[]) is null)
        {
            return false;
        }

        auto axes = op.attributes["axes"].get!(size_t[]);

        //Sort a copy so duplicates become adjacent without mutating the stored attribute.
        return axes.all!(x => x < op.deps[0].rank)
            && axes.dup.sort().uniq().count() == axes.length;
    }

    //The output type of a reduction is the input type with every axis listed in "axes" removed from the shape.
    TensorType judgeSum(Operation op)
    {
        auto t = op.deps[0].outputType;
        auto axes = op.attributes["axes"].get!(size_t[]);

        auto newShape = t
                       .shape
                       .zip(iota(0, t.shape.length))
                       .filter!(x => !axes.canFind(x[1]))
                       .map!(x => x[0])
                       .array();

        return TensorType(t.elementType, newShape);
    }

    //An argmin is valid when it has one dependency and a size_t "axis" attribute that is less than the input rank.
    bool verifyArgmin(Operation op)
    {
        return op.deps.length == 1
            && ("axis" in op.attributes)
            && (op.attributes["axis"].peek!size_t !is null)
            && (op.attributes["axis"].get!size_t < op.deps[0].rank);
    }

    //The argmin output is int32. It has the input shape, except the reduced axis has size 1.
    TensorType judgeArgmin(Operation op)
    {
        auto shape = op.deps[0].shape.dup;
        shape[op.attributes["axis"].get!size_t] = 1;

        return TensorType(DataType.int32, shape);
    }

    //Shared implementation of the sum and maxElement constructors.
    //A rank-0 input is reshaped to its own shape instead of being reduced, and an empty axes list selects every axis.
    //A caller-supplied axes array is duplicated so that later mutation by the caller cannot alter the stored attribute.
    Operation createReduction(string opName, Operation op, size_t[] axes, string mod, size_t line)
    {
        import std.variant : Variant;

        if(op.rank == 0)
        {
            return op.reshape(op.shape);
        }

        if(axes.length == 0)
        {
            axes = iota(0, op.rank).array();
        }
        else
        {
            axes = axes.dup;
        }

        return createOperation(opName, [op], ["axes": Variant(axes)], mod, line);
    }
}

mixin(createAllCtors());

/**
Computes the matrix multiplication between two rank-2 tensors.

Params:
    lhs = The tensor on the left-hand side of the operation.
    rhs = The tensor on the right-hand side of the operation.

Returns:
    The resulting operation.
*/
Operation matmul(Operation lhs, Operation rhs, string mod = __MODULE__, size_t line = __LINE__)
{
    return createOperation("matmul", [lhs, rhs], null, mod, line);
}

///
unittest
{
    import dopt.core : evaluate;

    auto a = float32([2, 1], [
        1.0f,
        2.0f
    ]);

    auto b = float32([1, 2], [
        3.0f, 4.0f
    ]);

    auto c = matmul(a, b);

    assert(c.evaluate().as!float == [
        3.0f, 4.0f,
        6.0f, 8.0f
    ]);
}

/**
Computes a sum reduction along the specified axes.

Params:
    op = The input to the reduction.
    axes = The axes the reduction should be performed along. If empty, the reduction is performed along every axis.

Returns:
    The resulting operation. The reduced axes are removed from the output shape. If op has rank 0, it is reshaped to
    its own shape rather than reduced.
*/
Operation sum(Operation op, size_t[] axes = [], string mod = __MODULE__, size_t line = __LINE__)
{
    return createReduction("sum", op, axes, mod, line);
}

///
unittest
{
    import dopt.core : evaluate;

    auto s1 = float32([2], [0.5, 1.5]).sum();
    auto s2 = float32([2, 2], [0, 1, 0, 5]).sum();
    auto s3 = float32([2, 2], [0, 1, 0, 5]).sum([0]);
    auto s4 = float32([2, 2], [0, 1, 0, 5]).sum([1]);

    assert(s1.evaluate().as!float == [2.0f]);
    assert(s2.evaluate().as!float == [6.0f]);
    assert(s3.evaluate().as!float == [0.0f, 6.0f]);
    assert(s4.evaluate().as!float == [1.0f, 5.0f]);
}

/**
Performs an argmin over the specified dimension.

Params:
    input = The operation to perform argmin on.
    axis = The dimension the argmin should be performed over. Must be less than the rank of input.

Returns:
    The new argmin operation. Its element type is int32, and its shape matches the input except that the reduced
    dimension has size 1.
*/
Operation argmin(Operation input, size_t axis, string mod = __MODULE__, size_t line = __LINE__)
{
    import std.variant : Variant;

    return createOperation("argmin", [input], ["axis": Variant(axis)], mod, line);
}

///
unittest
{
    import dopt.core : evaluate;

    auto a = float32([5], [4.0f, 2.0f, 6.0f, 1.0f, 2.0f]).argmin(0);

    auto b = float32([2, 3], [
        5.0f, 1.0f, 3.0f,
        6.0f, 7.0f, 2.0f
    ]).argmin(1);

    assert(a.evaluate().as!int == [3]);
    assert(b.evaluate().as!int == [1, 2]);
}

/**
Computes a max reduction along the specified axes.

Params:
    op = The input to the reduction.
    axes = The axes the reduction should be performed along. If empty, the reduction is performed along every axis.

Returns:
    The resulting operation. The reduced axes are removed from the output shape. If op has rank 0, it is reshaped to
    its own shape rather than reduced.
*/
Operation maxElement(Operation op, size_t[] axes = [], string mod = __MODULE__, size_t line = __LINE__)
{
    return createReduction("maxElement", op, axes, mod, line);
}

///
unittest
{
    import dopt.core : evaluate;

    auto a = float32([2, 2],[
        1.0f, 4.0f,
        3.0f, 6.0f
    ]);

    assert(a.maxElement.evaluate().as!float == [6.0f]); //Max value in the entire tensor
    assert(a.maxElement([0]).evaluate().as!float == [3.0f, 6.0f]); //Max of each column
    assert(a.maxElement([1]).evaluate().as!float == [4.0f, 6.0f]); //Max of each row
}