/**
This package facilitates the construction of various nodes in the operation graph.

Authors: Henry Gouk
*/
module dopt.core.ops;

import std.array;
import std.exception;
import std.variant;

import dopt.core.types;

public
{
    import dopt.core.ops.basic;
    import dopt.core.ops.math;
    import dopt.core.ops.nnet;
    import dopt.core.ops.random;
}

/**
Calls the initialize function of each operation submodule (basic, math, nnet and random).

NOTE(review): presumably this is where each submodule registers its operations via registerOperation, so it should
be called before any operation is created -- confirm against the submodules.
*/
void initialize()
{
    dopt.core.ops.basic.initialize();
    dopt.core.ops.math.initialize();
    dopt.core.ops.nnet.initialize();
    dopt.core.ops.random.initialize();
}

/// Checks an Operation during construction; returning false causes construction to throw.
alias Verifier = bool delegate(Operation);

/// Computes the TensorType of the result produced by an Operation.
alias Judge = TensorType delegate(Operation);

/**
Contains methods to perform procedures specific to the type of an operation.
*/
struct OpDef
{
    /**
    A verifier is used to ensure that an Operation object was correctly constructed.
    */
    Verifier verifier;

    /**
    A judge produces a TensorType object that specifies the type of the result of an operation of this type.
    */
    Judge judge;
}

/**
A node in the expression graph.
*/
class Operation
{
    public
    {
        /**
        Returns a string identifying the type of this operation. This is the same string used when registering the
        operation with the registerOperation method.
        */
        @property string opType()
        {
            return mOpType;
        }

        /**
        Returns a TensorType object that specifies the type of tensor obtained by evaluating this operation.
        */
        @property TensorType outputType()
        {
            return mOutputType;
        }

        /**
        Returns a list of operands for this operation.
        */
        @property Operation[] deps()
        {
            return mDeps;
        }

        /**
        Returns an associative array that maps strings to operation specific attributes.
        */
        @property Variant[string] attributes()
        {
            return mAttributes;
        }

        /**
        Convenience method for pointwise operations.

        Internally, this just calls the appropriate function from dopt.core.ops.math. If exactly one of the two
        operands has rank 0, it is first repeated and reshaped so that it has the same shape as the other operand.

        Params:
            rhs = The right hand side operand.
            mod = Module of the caller, forwarded to the operations created here.
            line = Line of the caller, forwarded to the operations created here.
        */
        Operation opBinary(string op)(Operation rhs, string mod = __MODULE__, size_t line = __LINE__)
        {
            //Broadcast a scalar operand so that both operands have the same shape
            if(rhs.rank == 0 && this.rank != 0)
            {
                return this.opBinary!op(rhs.repeat(this.volume, mod, line).reshape(this.shape, mod, line), mod, line);
            }
            else if(this.rank == 0 && rhs.rank != 0)
            {
                return this.repeat(rhs.volume, mod, line).reshape(rhs.shape, mod, line).opBinary!op(rhs, mod, line);
            }

            static if(op == "+")
            {
                return this.add(rhs, mod, line);
            }
            else static if(op == "-")
            {
                return this.sub(rhs, mod, line);
            }
            else static if(op == "*")
            {
                return this.mul(rhs, mod, line);
            }
            else static if(op == "/")
            {
                return this.div(rhs, mod, line);
            }
            else
            {
                static assert(0, "Unknown binary operation '" ~ op ~ "'");
            }
        }

        /**
        Wraps the integer constant in a rank 0 int32 operation and applies the binary operator to it.
        */
        Operation opBinary(string op)(int i, string mod = __MODULE__, size_t line = __LINE__)
        {
            auto bc = int32([], [i]);

            return opBinary!op(bc, mod, line);
        }

        /**
        Wraps the float constant in a rank 0 float32 operation and applies the binary operator to it.
        */
        Operation opBinary(string op)(float i, string mod = __MODULE__, size_t line = __LINE__)
        {
            auto bc = float32([], [i]);

            return opBinary!op(bc, mod, line);
        }

        /**
        Supports expressions where a constant appears on the left hand side of the operator.

        "+" and "*" are supported for any constant type accepted by opBinary. "-" and "/" are supported for float
        and int constants, which are wrapped in a rank 0 float32 or int32 operation respectively. The caller's
        module and line are forwarded to the operations that are created.
        */
        Operation opBinaryRight(string op, T)(T t, string mod = __MODULE__, size_t line = __LINE__)
        {
            static if(op == "*" || op == "+")
            {
                //Commutative, so the operands can simply be swapped
                return opBinary!op(t, mod, line);
            }
            else static if((op == "-" || op == "/") && is(T == float))
            {
                return float32([], [t]).opBinary!op(this, mod, line);
            }
            else static if((op == "-" || op == "/") && is(T == int))
            {
                return int32([], [t]).opBinary!op(this, mod, line);
            }
            else
            {
                static assert(0, "Not implemented.");
            }
        }

        /**
        Supports unary negation. Any other unary operator is rejected at compile time.
        */
        Operation opUnary(string op)()
        {
            static if(op == "-")
            {
                return neg(this);
            }
            else
            {
                //The condition must be false: a bare string literal would always pass the static assert
                static assert(0, "Unknown unary operation '" ~ op ~ "'");
            }
        }

        /**
        Returns a textual representation of the expression rooted at this operation.

        Variables are represented by the address of their Operation object, so distinct variables produce distinct
        strings. Every other operation is rendered as opType(dep1, dep2, ...), recursing into its dependencies.
        */
        override string toString()
        {
            import std.algorithm : joiner, map;
            import std.conv : to;

            //If it's a variable, we should have some unique identifier
            if(opType == "variable")
            {
                //This is very ugly. Someone please come up with a better way.
                return to!string(cast(void *)this);
            }
            else
            {
                return opType ~ "(" ~ deps.map!(x => x.toString).joiner(", ").to!string ~ ")";
            }
        }

        /**
        Returns the Buffer stored in the "default" attribute. That attribute must be present and must hold a Buffer.
        */
        Buffer value()
        {
            return attributes["default"].get!Buffer;
        }

        /// Returns the shape of outputType.
        auto shape()
        {
            return outputType.shape;
        }

        /// Returns the element type of outputType.
        auto elementType()
        {
            return outputType.elementType;
        }

        /// Returns the volume of outputType.
        auto volume()
        {
            return outputType.volume;
        }

        /// Returns the rank of outputType.
        auto rank()
        {
            return outputType.rank;
        }
    }

    public
    {
        //NOTE(review): these fields and the constructor are public; callers are presumably expected to use
        //createOperation and the accessors above -- confirm before relying on them directly.
        string mOpType;
        string mModule; //Module in which this operation was instantiated
        size_t mLine; //Line on which this operation was instantiated
        Operation[] mDeps;
        Variant[string] mAttributes;
        TensorType mOutputType;

        /**
        Constructs an operation of a registered type, verifies it, and computes its output type.

        The dependencies and attributes are copied. The verifier registered for opType is run on the new object,
        and the registered judge is then used to compute outputType.

        Throws:
            Exception if no operation definition is registered under opType, or if verification fails.
        */
        this(string opType, Operation[] deps, Variant[string] attribs, string mod, size_t line)
        {
            import std.conv : to;

            auto def = opType in mOpDefs;

            enforce(def !is null,
                "Cannot create operation because there is no operation definition registered with the name '" ~
                opType ~ "'");

            mOpType = opType;
            mDeps = deps.array;
            mAttributes = attribs.dup;
            mModule = mod;
            mLine = line;

            enforce(def.verifier(this),
                "Operation of type \"" ~ opType ~ "\" failed verification. Instantiated at " ~ mod ~ ":" ~
                line.to!string);

            mOutputType = makeJudgement(this);
        }
    }
}

/**
Registers an operation definition with the given identifier.

Throws:
    Exception if an operation is already registered under name.
*/
void registerOperation(string name, OpDef def)
{
    enforce((name in mOpDefs) is null, "There is already an operation registered with the name '" ~ name ~ "'");

    mOpDefs[name] = def;
}

/**
Returns a list of identifiers for operations that have been registered so far.
*/
string[] listAllOperations()
{
    //.keys already allocates a new array, so no extra copy is needed
    return mOpDefs.keys;
}

/**
Creates an operation of the given type, with the given dependencies and attributes.

Params:
    opType = Name under which the operation definition was registered.
    deps = Operands of the new operation.
    attribs = Operation specific attributes.
    mod = Module of the caller, used in verification error messages.
    line = Line of the caller, used in verification error messages.

Throws:
    Exception if no operation definition is registered under opType, or if verification fails.
*/
Operation createOperation(string opType, Operation[] deps = [], Variant[string] attribs = null,
    string mod = __MODULE__, size_t line = __LINE__)
{
    enforce(opType in mOpDefs,
        "Cannot create operation because there is no operation definition registered with the name '" ~ opType ~ "'");

    auto op = new Operation(opType, deps, attribs, mod, line);

    return op;
}

/**
Returns ops together with all of their transitive dependencies, ordered so that every operation appears after all
of its dependencies.

Each Operation object appears once; operations are compared by identity. The traversal is recursive, so very deep
graphs may exhaust the stack.
*/
Operation[] topologicalSort(Operation[] ops)
{
    Operation[] sortedOps;

    //Tracks the operations already in sortedOps, so membership checks are O(1) rather than a linear scan
    bool[Operation] visited;

    void toposort(Operation o)
    {
        if(o in visited)
        {
            return;
        }

        foreach(d; o.deps)
        {
            toposort(d);
        }

        visited[o] = true;
        sortedOps ~= o;
    }

    foreach(o; ops)
    {
        toposort(o);
    }

    return sortedOps;
}

private
{
    //Registry of operation definitions, keyed by the name passed to registerOperation
    OpDef[string] mOpDefs;

    /**
    Uses the judge registered for op's type to compute the TensorType of op's result.

    Throws:
        Exception if op's type has no registered definition.
    */
    TensorType makeJudgement(Operation op)
    {
        auto def = op.opType in mOpDefs;

        enforce(def !is null, "Cannot make judgement for unknown operation '" ~ op.opType() ~ "'");

        return def.judge(op);
    }
}