/**
    This package contains the framework for constructing and executing operation graphs.

    $(UL
        $(LI $(D dopt.core.ops) provides functions for constructing nodes in the operation graph.)
        $(LI $(D dopt.core.grads) provides functions for computing the derivatives of operations.)
        $(LI $(D dopt.core.cpu) contains a backend that executes operation graphs using the CPU.)
        $(LI $(D dopt.core.cuda) contains a backend that executes operation graphs using a CUDA enabled GPU.)
    )

    Authors: Henry Gouk
*/
module dopt.core;

public
{
    import dopt.core.grads;
    import dopt.core.ops;
    import dopt.core.types;
}

/// A function that computes the values of several operations, given a set of variable assignments.
alias Evaluator = DeviceBuffer[] delegate(Operation[] ops, DeviceBuffer[Operation] args);

/// A function that compiles an operation graph, given its output nodes, into a reusable $(D Plan).
alias Compiler = Plan delegate(Operation[] ops);

/// A function that allocates a $(D DeviceBuffer) capable of holding $(D numBytes) bytes.
alias Allocator = DeviceBuffer delegate(size_t numBytes);

// Process-wide defaults used by evaluate, compile, allocate, buffer and Plan.execute.
// Delegates are default-initialised to null, so each one must be registered through its
// setter below before the functions that depend on it are called.
private __gshared Evaluator mDefaultEvaluator;
private __gshared Compiler mDefaultCompiler;
private __gshared Allocator mDefaultVarAllocator;
private __gshared Allocator mDefaultArgAllocator;

/// Returns: The evaluator currently used by $(D evaluate), or $(D null) if none has been set.
Evaluator defaultEvaluator()
{
    return mDefaultEvaluator;
}

/// Sets the evaluator used by $(D evaluate).
void defaultEvaluator(Evaluator de)
{
    mDefaultEvaluator = de;
}

/// Returns: The compiler currently used by $(D compile), or $(D null) if none has been set.
Compiler defaultCompiler()
{
    return mDefaultCompiler;
}

/// Sets the compiler used by $(D compile).
void defaultCompiler(Compiler de)
{
    mDefaultCompiler = de;
}

/// Returns: The allocator currently used by $(D allocate), or $(D null) if none has been set.
Allocator defaultVarAllocator()
{
    return mDefaultVarAllocator;
}

/// Sets the allocator used by $(D allocate) (and therefore by $(D Plan.execute) for result buffers).
void defaultVarAllocator(Allocator da)
{
    mDefaultVarAllocator = da;
}

/// Returns: The allocator currently used by $(D buffer), or $(D null) if none has been set.
Allocator defaultArgAllocator()
{
    return mDefaultArgAllocator;
}

/// Sets the allocator used by $(D buffer).
void defaultArgAllocator(Allocator da)
{
    mDefaultArgAllocator = da;
}

shared static this()
{
    // Run the initialisation routines of the ops and grads modules once per process.
    dopt.core.ops.initialize();
    dopt.core.grads.initialize();
}

/*
    Returns dg unchanged, or throws an Exception naming the missing default and its setter.
    Without this check, invoking a null delegate crashes with an access violation that gives
    no hint about which default was never registered.
*/
private T requireDefault(T)(T dg, string what, string setterName)
{
    import std.exception : enforce;

    enforce(dg !is null, "No default " ~ what ~ " has been set. Call " ~ setterName ~ "() to register one.");

    return dg;
}

/**
    Evaluates several nodes from the operation graph using the default evaluator.

    Params:
        ops = The nodes of the operation graph that values should be computed for.
        args = A set of variable assignments.

    Returns:
        An array of $(D DeviceBuffer) objects, each containing the value of the corresponding element in $(D ops).

    Throws:
        $(D Exception) if no default evaluator has been set.
*/
DeviceBuffer[] evaluate(Operation[] ops, DeviceBuffer[Operation] args = null)
{
    auto evaluator = requireDefault(mDefaultEvaluator, "evaluator", "defaultEvaluator");

    return evaluator(ops, args);
}

/**
    Evaluates an operation graph with a single root node.

    This overload is here for convenience. Internally, the multi-output version of evaluate is called.

    Params:
        op = The root node of the operation graph.
        args = A set of variable assignments.

    Returns:
        A $(D DeviceBuffer) containing the result of the computation.

    Throws:
        $(D Exception) if no default evaluator has been set.
*/
DeviceBuffer evaluate(Operation op, DeviceBuffer[Operation] args = null)
{
    return evaluate([op], args)[0];
}

/**
    Compile an Operation graph into a reusable execution plan.

    This can be useful in the case where the function might need to be evaluated multiple times, as it will avoid
    repeating initialisation and optimisation procedures.

    Params:
        outputs = The output nodes of the Operation graph.

    Returns:
        A $(D Plan) that can be executed.

    Throws:
        $(D Exception) if no default compiler has been set.
*/
Plan compile(Operation[] outputs)
{
    auto compiler = requireDefault(mDefaultCompiler, "compiler", "defaultCompiler");

    return compiler(outputs);
}

/**
    Allocates a buffer using the default variable allocator.

    Params:
        numBytes = The size of the buffer, in bytes.

    Returns:
        The $(D DeviceBuffer) produced by the default variable allocator.

    Throws:
        $(D Exception) if no default variable allocator has been set.
*/
DeviceBuffer allocate(size_t numBytes)
{
    auto allocator = requireDefault(mDefaultVarAllocator, "variable allocator", "defaultVarAllocator");

    return allocator(numBytes);
}

/**
    Allocates a buffer using the default argument allocator and fills it with $(D vals).

    Params:
        vals = The raw data to store. Its length, which for $(D void[]) is measured in bytes, determines the
               size of the allocation.

    Returns:
        The newly allocated buffer, after $(D set) has been called on it with $(D vals).

    Throws:
        $(D Exception) if no default argument allocator has been set.
*/
DeviceBuffer buffer(void[] vals)
{
    auto allocator = requireDefault(mDefaultArgAllocator, "argument allocator", "defaultArgAllocator");
    auto buf = allocator(vals.length);
    buf.set(vals);

    return buf;
}

/**
    An executable form of an operation graph, produced by $(D compile).

    Concrete subclasses provide the actual computation by implementing $(D executeImpl).
*/
class Plan
{
    public
    {
        /**
            Params:
                outputs = The output nodes of the operation graph. A copy of this array is stored, so later
                          changes to the caller's array do not affect the plan.
        */
        this(Operation[] outputs)
        {
            import std.array : array;

            mOutputs = outputs.array();
        }

        /**
            Executes the plan, allocating a result buffer for each output.

            Each result buffer is obtained from $(D allocate) and is sized as the output's volume multiplied by the
            size of its element type.

            Params:
                args = A set of variable assignments.

            Returns:
                One $(D DeviceBuffer) per output, in the same order as the outputs given to the constructor.
        */
        DeviceBuffer[] execute(DeviceBuffer[Operation] args = null)
        {
            auto rets = new DeviceBuffer[mOutputs.length];

            foreach(i, o; mOutputs)
            {
                rets[i] = allocate(o.outputType.volume * o.outputType.elementType.sizeOf());
            }

            execute(args, rets);

            return rets;
        }

        /**
            Executes the plan, writing the results into caller-provided buffers.

            Params:
                args = A set of variable assignments.
                rets = Buffers that receive the results, one per output, in the same order as the outputs given
                       to the constructor.

            Throws:
                $(D Exception) if $(D rets) does not contain exactly one buffer per output.
        */
        void execute(DeviceBuffer[Operation] args, DeviceBuffer[] rets)
        {
            import std.exception : enforce;

            // Implementations index rets by output position, so a mismatch must be caught up front.
            enforce(rets.length == mOutputs.length,
                "The number of result buffers does not match the number of outputs of this plan.");

            executeImpl(args, rets);
        }
    }

    protected
    {
        /// The output nodes of the operation graph this plan computes.
        Operation[] mOutputs;

        /// Performs the computation, storing the value of $(D mOutputs[i]) into $(D rets[i]).
        abstract void executeImpl(DeviceBuffer[Operation] args, DeviceBuffer[] rets);
    }
}