/**
    Contains common maths operations.

    Authors: Henry Gouk
*/
module dopt.core.ops.math;

import std.algorithm;
import std.conv;
import std.functional;
import std.range;

import dopt.core.ops;
import dopt.core.types;

package
{
    void initialize()
    {
        void registerPointwiseBinary(string opName)
        {
            bool verifier(Operation op)
            {
                return op.deps.length == 2 && op.deps[0].outputType == op.deps[1].outputType;
            }

            TensorType judge(Operation op)
            {
                return TensorType(op.deps[0].outputType);
            }

            registerOperation(opName, OpDef(&verifier, &judge));
        }

        void registerPointwiseUnary(string opName)
        {
            bool verifier(Operation op)
            {
                //A pointwise unary op must have exactly one dependency, otherwise judge would
                //dereference a missing element
                return op.deps.length == 1;
            }

            TensorType judge(Operation op)
            {
                return TensorType(op.deps[0].outputType);
            }

            registerOperation(opName, OpDef(&verifier, &judge));
        }

        foreach(opName; chain(arith, comp, binfunc))
        {
            registerPointwiseBinary(opName);
        }

        foreach(opName; unfunc)
        {
            registerPointwiseUnary(opName);
        }

        registerOperation("matmul", OpDef(toDelegate(&verifyMatmul), toDelegate(&judgeMatmul)));
        registerOperation("sum", OpDef(toDelegate(&verifySum), toDelegate(&judgeSum)));
        registerOperation("argmin", OpDef(toDelegate(&verifyArgmin), toDelegate(&judgeArgmin)));

        //maxElement is a reduction like sum, so it reuses sum's verifier and type judge
        registerOperation("maxElement", OpDef(toDelegate(&verifySum), toDelegate(&judgeSum)));
    }
}

private
{
    immutable string[] arith = ["add", "sub", "mul", "div"];
    immutable string[] comp = ["lt", "lte", "gt", "gte", "eq", "neq"];
    immutable string[] binfunc = ["max", "min", "pow"];
    immutable string[] unfunc = ["neg", "abs", "sgn", "exp", "log", "sqrt"];

    string createAllCtors()
    {
        string createOpCtor(string opName, size_t numDeps)
        {
            auto params = iota(0, numDeps)
                         .map!(x => "Operation p" ~ x.to!string)
                         .joiner(", ")
                         .to!string();

            auto args = iota(0, numDeps)
                       .map!(x => "p" ~ x.to!string)
                       .joiner(", ")
                       .to!string;

            return "
                Operation " ~ opName ~ "(" ~ params ~ ", string mod = __MODULE__, size_t line = __LINE__)
                {
                    return createOperation(\"" ~ opName ~ "\", [" ~ args ~ "], null, mod, line);
                }
            ";
        }

        auto binctors = chain(arith, comp, binfunc)
                       .map!(x => createOpCtor(x, 2))
                       .joiner("\n")
                       .to!string;

        auto unctors = unfunc
                      .map!(x => createOpCtor(x, 1))
                      .joiner("\n")
                      .to!string;

        return binctors ~ unctors;
    }

    bool verifyMatmul(Operation op)
    {
        return op.deps.length == 2
            && op.deps[0].outputType.rank == 2
            && op.deps[1].outputType.rank == 2
            && op.deps[0].outputType.elementType == op.deps[1].outputType.elementType
            && op.deps[0].outputType.shape[1] == op.deps[1].outputType.shape[0];
    }

    TensorType judgeMatmul(Operation op)
    {
        return TensorType(op.deps[0].outputType.elementType,
            [op.deps[0].outputType.shape[0], op.deps[1].outputType.shape[1]]);
    }

    bool verifySum(Operation op)
    {
        if(op.deps.length != 1)
        {
            return false;
        }

        if(("axes" in op.attributes) is null || op.attributes["axes"].peek!(size_t[]) is null)
        {
            return false;
        }

        auto axes = op.attributes["axes"].get!(size_t[]);

        //Every axis must be in range, and no axis may be repeated
        return axes.all!(x => x < op.deps[0].rank) &&
               axes.dup.sort().uniq().count() == axes.length;
    }

    TensorType judgeSum(Operation op)
    {
        auto t = op.deps[0].outputType;
        auto axes = op.attributes["axes"].get!(size_t[]);

        //Drop the reduced axes from the output shape
        auto newShape = t
                       .shape
                       .zip(iota(0, t.shape.length))
                       .filter!(x => !axes.canFind(x[1]))
                       .map!(x => x[0])
                       .array();

        return TensorType(t.elementType, newShape);
    }

    bool verifyArgmin(Operation op)
    {
        return op.deps.length == 1
            && ("axis" in op.attributes)
            && (op.attributes["axis"].peek!size_t !is null)
            && (op.attributes["axis"].get!size_t < op.deps[0].rank);
    }

    TensorType judgeArgmin(Operation op)
    {
        auto shape = op.deps[0].shape.dup;
        shape[op.attributes["axis"].get!size_t] = 1;

        return TensorType(DataType.int32, shape);
    }
}

mixin(createAllCtors());
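//A minimal usage sketch for the mixin-generated constructors. The assertion values assume the
//backend evaluates "add" elementwise; the registration above only fixes the arity and output
//type of these ops, not their numerical semantics.
unittest
{
    import dopt.core.cpu : evaluate;

    auto a = float32([2], [1.0f, 2.0f]);
    auto b = float32([2], [3.0f, 4.0f]);

    //add is one of the constructors produced by mixin(createAllCtors())
    auto c = add(a, b);

    assert(c.evaluate().as!float == [4.0f, 6.0f]);
}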
/**
    Computes the matrix multiplication between two rank-2 tensors.

    Params:
        lhs = The tensor on the left-hand side of the operation.
        rhs = The tensor on the right-hand side of the operation.

    Returns:
        The resulting operation.
*/
Operation matmul(Operation lhs, Operation rhs, string mod = __MODULE__, size_t line = __LINE__)
{
    return createOperation("matmul", [lhs, rhs], null, mod, line);
}

///
unittest
{
    import dopt.core.cpu : evaluate;

    auto a = float32([2, 1], [
        1.0f,
        2.0f
    ]);

    auto b = float32([1, 2], [
        3.0f, 4.0f
    ]);

    auto c = matmul(a, b);

    assert(c.evaluate().as!float == [
        3.0f, 4.0f,
        6.0f, 8.0f
    ]);
}

/**
    Computes a sum reduction along the specified axes.

    Params:
        op = The input to the reduction.
        axes = The axes the reduction should be performed along. An empty array means the
               reduction is performed over all axes.

    Returns:
        The resulting operation.
*/
Operation sum(Operation op, size_t[] axes = [], string mod = __MODULE__, size_t line = __LINE__)
{
    import std.variant : Variant;

    //A rank-0 tensor has nothing to reduce over
    if(op.rank == 0)
    {
        return op.reshape(op.shape);
    }

    //An empty axes array means "reduce over every axis"
    if(axes.length == 0)
    {
        axes = iota(0, op.rank).array();
    }

    return createOperation("sum", [op], ["axes": Variant(axes)], mod, line);
}

///
unittest
{
    import dopt.core.cpu : evaluate;

    auto s1 = float32([2], [0.5, 1.5]).sum();
    auto s2 = float32([2, 2], [0, 1, 0, 5]).sum();
    auto s3 = float32([2, 2], [0, 1, 0, 5]).sum([0]);
    auto s4 = float32([2, 2], [0, 1, 0, 5]).sum([1]);

    assert(s1.evaluate().as!float == [2.0f]);
    assert(s2.evaluate().as!float == [6.0f]);
    assert(s3.evaluate().as!float == [0.0f, 6.0f]);
    assert(s4.evaluate().as!float == [1.0f, 5.0f]);
}

/**
    Performs an argmin over the specified dimension.

    Params:
        input = The operation to perform argmin on.
        axis = The dimension the argmin should be performed over.

    Returns:
        The new argmin operation.
*/
Operation argmin(Operation input, size_t axis, string mod = __MODULE__, size_t line = __LINE__)
{
    import std.variant : Variant;

    return createOperation("argmin", [input], ["axis": Variant(axis)], mod, line);
}

///
unittest
{
    import dopt.core : evaluate;

    auto a = float32([5], [4.0f, 2.0f, 6.0f, 1.0f, 2.0f]).argmin(0);

    auto b = float32([2, 3], [
        5.0f, 1.0f, 3.0f,
        6.0f, 7.0f, 2.0f
    ]).argmin(1);

    assert(a.evaluate().as!int == [3]);
    assert(b.evaluate().as!int == [1, 2]);
}
/**
    Computes a max reduction along the specified axes.

    Params:
        op = The input to the reduction.
        axes = The axes the reduction should be performed along. An empty array means the
               reduction is performed over all axes.

    Returns:
        The resulting operation.
*/
Operation maxElement(Operation op, size_t[] axes = [], string mod = __MODULE__, size_t line = __LINE__)
{
    import std.variant : Variant;

    //A rank-0 tensor has nothing to reduce over
    if(op.rank == 0)
    {
        return op.reshape(op.shape);
    }

    //An empty axes array means "reduce over every axis"
    if(axes.length == 0)
    {
        axes = iota(0, op.rank).array();
    }

    return createOperation("maxElement", [op], ["axes": Variant(axes)], mod, line);
}
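//A usage sketch for maxElement, mirroring the sum examples above. The expected values assume
//the backend takes the maximum along the given axes, with the reduced axes removed from the
//output shape as judgeSum dictates; those semantics live in the backend, not in this module.
unittest
{
    import dopt.core.cpu : evaluate;

    auto m1 = float32([2, 2], [0.0f, 1.0f, 4.0f, 3.0f]).maxElement();
    auto m2 = float32([2, 2], [0.0f, 1.0f, 4.0f, 3.0f]).maxElement([0]);
    auto m3 = float32([2, 2], [0.0f, 1.0f, 4.0f, 3.0f]).maxElement([1]);

    assert(m1.evaluate().as!float == [4.0f]);
    assert(m2.evaluate().as!float == [4.0f, 3.0f]);
    assert(m3.evaluate().as!float == [1.0f, 4.0f]);
}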