/**
    Contains the automatic differentiation framework.

    Authors: Henry Gouk
*/
module dopt.core.grads;

import std.exception;

import dopt.core.grads.basic;
import dopt.core.grads.math;
import dopt.core.grads.nnet;
import dopt.core.ops;
import dopt.core.types;

/**
    The type of a function that, given an operation and the gradient of the objective with respect to that
    operation, returns the gradients of the objective with respect to each of the operation's dependencies.
*/
alias Gradient = Operation[] delegate(Operation op, Operation parentGrad);

/**
    Registers the gradient functions provided by the basic, math, and nnet submodules.
*/
void initialize()
{
    dopt.core.grads.basic.initialize();
    dopt.core.grads.math.initialize();
    dopt.core.grads.nnet.initialize();
}

/**
    Computes the gradient of a scalar-valued operation with respect to several dependencies.

    This function provides an implementation of automatic differentiation that can be used to greatly simplify
    the process of optimising objective functions. The particular technique used by this function is known as
    reverse mode automatic differentiation.

    Params:
        objective = The function being differentiated.
        wrt = The (indirect) dependencies that $(D objective) is being differentiated with respect to.

    Returns:
        An array of operations that evaluate to the derivative of $(D objective) with respect to each of the
        elements of $(D wrt).
*/
Operation[] grad(Operation objective, Operation[] wrt)
{
    import std.algorithm : canFind, countUntil, map;
    import std.array : array;
    import std.conv : to;
    import std.range : retro, zip;

    enforce(objective.outputType.volume == 1, "The objective must have a volume of one");
    enforce(objective.outputType.elementType == DataType.float32, "The objective must have a floating point type");

    Operation[] ops;

    void traverse(Operation op)
    {
        foreach(d; op.deps)
        {
            if(!ops.canFind(d))
            {
                traverse(d);
            }
        }

        ops ~= op;
    }

    //Topologically sort the operations
    traverse(objective);

    Operation[Operation] grads;

    //TODO: when I implement a 'ones' operation, replace this line
    grads[objective] = float32(objective.outputType.shape, [1.0f]);

    //Iterate through the operations in reverse order (reverse mode autodiff)
    foreach(op; ops.retro)
    {
        //Get the function that will let us compute the gradient of op w.r.t. its deps
        auto gradFunc = mGradients.get(op.opType, null);
        auto opGrad = grads.get(op, null);

        if(gradFunc is null || opGrad is null)
        {
            //This op, or its parent, is not differentiable, so we will just assume its derivative is zero everywhere
            continue;
        }

        //Compute the derivative: d(op)/d(op.deps)
        auto depGrads = gradFunc(op, opGrad);

        //Add these to grads. If there is already an entry for one of the deps, then that dep has more than one
        //parent, so we add this grad to the existing one: the chain rule sums the contributions from each parent.
        foreach(d, g; zip(op.deps, depGrads))
        {
            auto currentGrad = grads.get(d, null);

            if(currentGrad is null)
            {
                grads[d] = g;
            }
            else
            {
                grads[d] = currentGrad + g;
            }
        }
    }

    auto errIdx = wrt.countUntil!(x => grads.get(x, null) is null);

    enforce(errIdx == -1, "Could not find wrt[" ~ errIdx.to!string ~ "] in the operation graph");

    return wrt.map!(x => grads[x]).array();
}

///
unittest
{
    import std.random : uniform;
    import dopt.core : evaluate;

    auto x = float32();
    auto y = x * x;
    auto gradY = grad(y, [x]);

    auto r = uniform(-100.0f, 100.0f);

    auto gradYwrtX = gradY.evaluate([
        x: Buffer([r])
    ])[0];

    assert(gradYwrtX.as!float[0] == r + r);
}

/**
    Registers a gradient function for the operation type given by $(D opName).
*/
void registerGradient(string opName, Gradient g)
{
    enforce((opName in mGradients) is null, "A gradient is already registered for operation '" ~ opName ~ "'");

    mGradients[opName] = g;
}

/**
    Removes any gradient function registered for the operation type given by $(D opName).
*/
void deregisterGradient(string opName)
{
    mGradients.remove(opName);
}

/**
    Returns the names of all operation types that currently have a registered gradient function.
*/
string[] listAllGradients()
{
    return mGradients.keys.dup;
}

private
{
    Gradient[string] mGradients;
}
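
//Illustrative sketches of the registration API and of gradient accumulation. "myCustomOp" is a made-up
//operation type name, and the second test assumes the same Buffer/evaluate usage as the documented
//unittest for grad() above.

unittest
{
    import std.algorithm : canFind;

    //Register a pass-through gradient for a hypothetical operation type, then remove it again.
    //The delegate simply forwards the parent gradient to a single dependency.
    Gradient g = delegate Operation[](Operation op, Operation parentGrad)
    {
        return [parentGrad];
    };

    registerGradient("myCustomOp", g);
    assert(listAllGradients().canFind("myCustomOp"));

    deregisterGradient("myCustomOp");
    assert(!listAllGradients().canFind("myCustomOp"));
}

unittest
{
    import dopt.core : evaluate;

    //Exercises the accumulation branch of grad(): x is a dependency of both the multiply and the add in
    //y = x * x + x, so its gradient contributions are summed, giving dy/dx = 2x + 1.
    auto x = float32();
    auto y = x * x + x;
    auto gradY = grad(y, [x]);

    auto gradYwrtX = gradY.evaluate([
        x: Buffer([3.0f])
    ])[0];

    assert(gradYwrtX.as!float[0] == 7.0f);
}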