/**
    This package facilitates the construction of various nodes in the operation graph.

    Authors: Henry Gouk
*/
module dopt.core.ops;

import std.array;
import std.exception;
import std.variant;

import dopt.core.types;

public
{
    import dopt.core.ops.basic;
    import dopt.core.ops.math;
    import dopt.core.ops.nnet;
    import dopt.core.ops.random;
}

/**
    Calls the initialize function of every operation submodule (basic, math, nnet and random).

    NOTE(review): presumably this is where each submodule registers its operations via registerOperation;
    confirm in the submodules.
*/
void initialize()
{
    dopt.core.ops.basic.initialize();
    dopt.core.ops.math.initialize();
    dopt.core.ops.nnet.initialize();
    dopt.core.ops.random.initialize();
}

/// Checks that an Operation was constructed correctly. Returning false makes construction fail.
alias Verifier = bool delegate(Operation);

/// Computes the TensorType of the result produced by an Operation.
alias Judge = TensorType delegate(Operation);

/**
    Contains methods to perform procedures specific to the type of an operation
*/
struct OpDef
{
    /**
        A verifier is used to ensure that an Operation object was correctly constructed.
    */
    Verifier verifier;

    /**
        A judge produces a TensorType object that specifies the type of the result of an operation of this type.
    */
    Judge judge;
}

/**
    A node in the expression graph
*/
class Operation
{
    public
    {
        /**
            Returns a string identifying the type of this operation. This is the same string used when registering the
            operation with the registerOperation method.
        */
        @property string opType()
        {
            return mOpType;
        }

        /**
            Returns a TensorType object that specifies the type of tensor obtained by evaluating this operation.
        */
        @property TensorType outputType()
        {
            return mOutputType;
        }

        /**
            Returns a list of operands for this operation.
        */
        @property Operation[] deps()
        {
            return mDeps;
        }

        /**
            Returns an associative array that maps strings to operation specific attributes.
        */
        @property Variant[string] attributes()
        {
            return mAttributes;
        }

        /**
            Convenience method for pointwise operations.

            If exactly one of the two operands has rank 0, it is first broadcast to the shape of the other operand
            (by repeating it volume times and reshaping) before the operation is applied.

            Internally, this just calls the appropriate function from dopt.core.ops.math.

            Params:
                rhs = The right hand side operand.
                mod = Module of the call site, recorded on the created operations for error reporting.
                line = Line of the call site, recorded on the created operations for error reporting.
        */
        Operation opBinary(string op)(Operation rhs, string mod = __MODULE__, size_t line = __LINE__)
        {
            //Broadcast a scalar operand to the shape of the other operand
            if(rhs.rank == 0 && this.rank != 0)
            {
                return this.opBinary!op(rhs.repeat(this.volume, mod, line).reshape(this.shape, mod, line), mod, line);
            }
            else if(this.rank == 0 && rhs.rank != 0)
            {
                return this.repeat(rhs.volume, mod, line).reshape(rhs.shape, mod, line).opBinary!op(rhs, mod, line);
            }

            static if(op == "+")
            {
                return this.add(rhs, mod, line);
            }
            else static if(op == "-")
            {
                return this.sub(rhs, mod, line);
            }
            else static if(op == "*")
            {
                return this.mul(rhs, mod, line);
            }
            else static if(op == "/")
            {
                return this.div(rhs, mod, line);
            }
            else
            {
                static assert(0, "Unknown binary operation '" ~ op ~ "'");
            }
        }

        /**
            Applies a pointwise operation with an int scalar, which is wrapped in an int32Constant operation.
        */
        Operation opBinary(string op)(int i, string mod = __MODULE__, size_t line = __LINE__)
        {
            auto bc = int32Constant(i);

            return opBinary!op(bc, mod, line);
        }

        /**
            Applies a pointwise operation with a float scalar, which is wrapped in a float32Constant operation.
        */
        Operation opBinary(string op)(float i, string mod = __MODULE__, size_t line = __LINE__)
        {
            auto bc = float32Constant(i);

            return opBinary!op(bc, mod, line);
        }

        /**
            Supports expressions where a scalar appears on the left hand side, e.g. 2.0f * x or 1 - x.

            The call-site module and line are forwarded so that any error reports the caller's location.
        */
        Operation opBinaryRight(string op, T)(T t, string mod = __MODULE__, size_t line = __LINE__)
        {
            static if(op == "*" || op == "+")
            {
                //These are commutative, so the scalar can be moved to the right hand side
                return opBinary!op(t, mod, line);
            }
            else static if(op == "-" || op == "/")
            {
                //static if does not introduce a scope, so lhs is visible after this block
                static if(is(T == int))
                {
                    auto lhs = int32Constant(t);
                }
                else static if(is(T == float))
                {
                    auto lhs = float32Constant(t);
                }
                else
                {
                    static assert(0, "Not implemented.");
                }

                return lhs.opBinary!op(this, mod, line);
            }
            else
            {
                static assert(0, "Not implemented.");
            }
        }

        /**
            Supports unary negation; any other unary operator is rejected at compile time.
        */
        Operation opUnary(string op)()
        {
            static if(op == "-")
            {
                return neg(this);
            }
            else
            {
                //The condition must be 0: a bare string would be treated as a (true) condition and never fire
                static assert(0, "Unknown unary operation '" ~ op ~ "'");
            }
        }

        /**
            Returns a textual description of this operation: the object's address for "variable" operations,
            otherwise opType(dep1, dep2, ...) built recursively from the dependencies.
        */
        override string toString()
        {
            import std.algorithm : joiner, map;
            import std.conv : to;

            //If it's a variable, we should have some unique identifier
            if(opType == "variable")
            {
                //This is very ugly. Someone please come up with a better way.
                return to!string(cast(void *)this);
            }
            else
            {
                return opType ~ "(" ~ deps.map!(x => x.toString).joiner(", ").to!string ~ ")";
            }
        }

        /// Returns the buffer stored by setBuffer.
        DeviceBuffer value()
        {
            return mBuffer;
        }

        /// Stores a buffer for this operation, later returned by value().
        void setBuffer(DeviceBuffer buf)
        {
            mBuffer = buf;
        }

        /// Shorthand for outputType.shape.
        auto shape()
        {
            return outputType.shape;
        }

        /// Shorthand for outputType.elementType.
        auto elementType()
        {
            return outputType.elementType;
        }

        /// Shorthand for outputType.volume.
        auto volume()
        {
            return outputType.volume;
        }

        /// Shorthand for outputType.rank.
        auto rank()
        {
            return outputType.rank;
        }
    }

    public
    {
        string mOpType;
        string mModule;         //Call-site module, used in verification error messages
        size_t mLine;           //Call-site line, used in verification error messages
        Operation[] mDeps;
        Variant[string] mAttributes;
        TensorType mOutputType;
        DeviceBuffer mBuffer;

        /**
            Constructs an operation, runs the registered verifier and computes the output type with the judge.

            Throws:
                Exception if no definition is registered for opType, or if verification fails.
        */
        this(string opType, Operation[] deps, Variant[string] attribs, string mod, size_t line)
        {
            import std.conv : to;

            //Look up with "in" so an unregistered type yields an Exception rather than a RangeError
            auto def = opType in mOpDefs;

            enforce(def !is null,
                "Cannot create operation because there is no operation definition registered with the name '" ~
                opType ~ "'");

            mOpType = opType;
            mDeps = deps.array;
            mAttributes = attribs.dup;
            mModule = mod;
            mLine = line;

            enforce(def.verifier(this),
                "Operation of type \"" ~ opType ~ "\" failed verification. Instantiated at " ~ mod ~ ":" ~
                line.to!string);

            mOutputType = makeJudgement(this);
        }
    }
}

/**
    Registers an operation definition with the given identifier.

    Throws:
        Exception if an operation is already registered under name.
*/
void registerOperation(string name, OpDef def)
{
    enforce((name in mOpDefs) is null, "There is already an operation registered with the name '" ~ name ~ "'");

    mOpDefs[name] = def;
}

/**
    Returns a list of identifiers for operations that have been registered so far.
*/
string[] listAllOperations()
{
    //.keys already allocates a fresh array, so no extra copy is needed
    return mOpDefs.keys;
}

/**
    Creates an operation of the given type, with the given dependencies and attributes.

    Params:
        opType = Name the operation definition was registered under.
        deps = Operand operations.
        attribs = Operation specific attributes (copied by the constructor).
        mod = Module of the call site, used in error messages.
        line = Line of the call site, used in error messages.

    Throws:
        Exception if opType is not registered or the operation fails verification.
*/
Operation createOperation(string opType, Operation[] deps = [], Variant[string] attribs = null,
    string mod = __MODULE__, size_t line = __LINE__)
{
    enforce(opType in mOpDefs,
        "Cannot create operation because there is no operation definition registered with the name '" ~ opType ~ "'");

    auto op = new Operation(opType, deps, attribs, mod, line);

    return op;
}

/**
    Returns the given operations and all of their transitive dependencies, ordered so that every operation appears
    after all of its dependencies. Each operation appears once.
*/
Operation[] topologicalSort(Operation[] ops)
{
    Operation[] sortedOps;

    //Identity-keyed set of operations already appended; replaces a linear search of sortedOps
    bool[Operation] visited;

    void toposort(Operation o)
    {
        if(o in visited)
        {
            return;
        }

        foreach(d; o.deps)
        {
            toposort(d);
        }

        visited[o] = true;
        sortedOps ~= o;
    }

    foreach(o; ops)
    {
        toposort(o);
    }

    return sortedOps;
}

private
{
    //Registry of operation definitions, keyed by operation type name
    OpDef[string] mOpDefs;

    //Computes the output type of op using the judge registered for its type
    TensorType makeJudgement(Operation op)
    {
        auto def = op.opType in mOpDefs;

        enforce(def !is null, "Cannot make judgement for unknown operation '" ~ op.opType() ~ "'");

        return def.judge(op);
    }
}