1 /** 2 Contains an implementation of AMSGrad that relies on automatic differentiation 3 4 Authors: Henry Gouk 5 */ 6 module dopt.online.amsgrad; 7 8 import dopt.core; 9 import dopt.online; 10 11 /** 12 Creates a delegate that can be used to perform a step using the AMSGrad update rule. 13 14 This function relies on automatic differentiation, so the objective (which must have a volume of 1) must be 15 differentiable w.r.t. all elements of wrt. The returned delegate performs minimisation. 16 17 Params: 18 outputs = An array of outputs. The first element of this array is the objective function to be minimised. 19 wrt = An array of Operations that we want the derivative of objective with respect to. 20 projs = Projection functions that can be applied when updating the values of elements in $(D wrt). 21 alpha = The step size. 22 beta1 = Fading factor for the first moment of the gradient. 23 beta2 = Fading factor for the second moment of the gradient. 24 eps = To prevent division by zero. 25 26 Returns: 27 A delegate that is used to actually perform the update steps. The optimised values are stored in the 28 $(D value) properties of the elements of $(D wrt). The delegate returns the values computed for each element of the 29 $(D outputs) array. This can be useful for keeping track of several different performance metrics in a 30 prequential manner. 
*/
Updater amsgrad(Operation[] outputs, Operation[] wrt, Projection[Operation] projs,
    Operation alpha = float32([], [0.001f]), Operation beta1 = float32([], [0.9f]),
    Operation beta2 = float32([], [0.999f]), Operation eps = float32([], [1e-8]))
{
    import std.algorithm : map;
    import std.array : array;
    import std.range : zip;

    auto objective = outputs[0];

    auto grads = grad(objective, wrt);

    //Optimiser state: first moment (m), second moment (v), and the running
    //maximum of the second moment (vhat) that distinguishes AMSGrad from Adam.
    auto means = wrt.map!(x => float32(x.shape)).array();
    auto vars = wrt.map!(x => float32(x.shape)).array();
    auto varhats = wrt.map!(x => float32(x.shape)).array();

    //Accumulated powers of beta1/beta2, used to bias-correct the step size.
    auto b1 = float32([], [1.0f]);
    auto b2 = float32([], [1.0f]);
    auto nb1 = b1 * beta1;
    auto nb2 = b2 * beta2;
    auto eta = alpha * sqrt(1.0f - nb2) / (1.0f - nb1);

    //m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
    auto newMeans = grads
                   .zip(means)
                   .map!(x => beta1 * x[1] + (1.0f - beta1) * x[0])
                   .array();

    //v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
    auto newVars = grads
                  .zip(vars)
                  .map!(x => beta2 * x[1] + (1.0f - beta2) * x[0] * x[0])
                  .array();

    //vhat_t = max(vhat_{t-1}, v_t). This must be taken against the *updated*
    //second moment estimates (newVars); maxing against the stale values, as the
    //previous revision did, leaves vhat permanently at its initial value.
    auto newVarhats = varhats
                     .zip(newVars)
                     .map!(x => max(x[0], x[1]))
                     .array();

    //AMSGrad update: w_t = w_{t-1} - eta * m_t / (sqrt(vhat_t) + eps).
    //Dividing by sqrt(v_t) instead would reduce this to plain Adam and leave
    //the vhat state computed above unused.
    auto newvals = zip(wrt, newMeans, newVarhats)
                  .map!(x => x[0] - eta * (x[1] / (sqrt(x[2]) + eps)))
                  .array();

    //Apply projections
    for(size_t i = 0; i < newvals.length; i++)
    {
        if(wrt[i] in projs)
        {
            newvals[i] = projs[wrt[i]](newvals[i]);
        }
    }

    auto updatePlan = compile(outputs ~ newvals ~ newMeans ~ newVars ~ newVarhats ~ [nb1, nb2]);

    import std.range : chain;

    //Target buffers for the compiled plan: the updated state is written
    //straight back into the value buffers of the corresponding state operations.
    auto newbufs = chain(wrt, means, vars, varhats, [b1, b2])
                  .map!(x => x.value)
                  .array();

    newbufs = outputs.map!(x => Buffer(new ubyte[x.volume * x.elementType.sizeOf])).array() ~ newbufs;

    Buffer[] update(Buffer[Operation] args)
    {
        updatePlan.execute(args, newbufs);

        return newbufs[0 .. outputs.length];
    }

    return &update;
}

///
unittest
{
    import std.random : uniform;

    //Generate some points
    auto xdata = new float[100];
    auto ydata = new float[100];

    foreach(i; 0 .. 100)
    {
        xdata[i] = uniform(-10.0f, 10.0f);
        ydata[i] = 3.0f * xdata[i] + 2.0f;
    }

    //Create the model
    auto x = float32([]);
    auto m = float32([]);
    auto c = float32([]);

    auto yhat = m * x + c;
    auto y = float32([]);

    //Create an AMSGrad updater
    auto updater = amsgrad([(yhat - y) * (yhat - y)], [m, c], null, float32([], [0.1f]));

    //Iterate for a while
    float loss;

    for(size_t i = 0; i < 300; i++)
    {
        size_t j = i % 100;

        loss = updater([
            x: Buffer(xdata[j .. j + 1]),
            y: Buffer(ydata[j .. j + 1])
        ])[0].as!float[0];
    }

    //Print the loss after 300 iterations. Let the user decide whether it's good enough to be considered a pass.
    import std.stdio : writeln;
    writeln(
        "AMSGrad loss: ", loss, " ",
        "m=", m.value.as!float[0], ", ",
        "c=", c.value.as!float[0], " ",
        "(expected m=3, c=2)");
}