/**
    Contains an implementation of ADAM that relies on automatic differentiation

    Authors: Henry Gouk
*/
module dopt.online.adam;

import dopt.core;
import dopt.online;

/**
    Creates a delegate that can be used to perform a step using the ADAM update rule.

    This function relies on automatic differentiation, so the objective (which must have a volume of 1) must be
    differentiable w.r.t. all elements of wrt. The returned delegate performs minimisation.

    Params:
        outputs = An array of outputs. The first element of this array is the objective function to be minimised.
                  Must contain at least one element.
        wrt = An array of Operations that the objective is differentiated with respect to. These are the values
              that get updated by each step.
        projs = Projection functions that can be applied when updating the values of elements in $(D wrt). May be
                $(D null), in which case no projections are applied.
        alpha = The step size.
        beta1 = Fading factor for the first moment of the gradient.
        beta2 = Fading factor for the second moment of the gradient.
        eps = Added to the denominator of the update to prevent division by zero.

    Returns:
         A delegate that is used to actually perform the update steps. The optimised values are stored in the
         $(D value) properties of the elements of $(D wrt). The delegate returns the values computed for each element
         of the $(D outputs) array. This can be useful for keeping track of several different performance metrics in
         a prequential manner.

    Throws:
        $(D Exception) if $(D outputs) is empty, because there is then no objective to minimise.
*/
Updater adam(Operation[] outputs, Operation[] wrt, Projection[Operation] projs,
    Operation alpha = float32([], [0.001f]), Operation beta1 = float32([], [0.9f]),
    Operation beta2 = float32([], [0.999f]), Operation eps = float32([], [1e-8]))
{
    import std.algorithm : map;
    import std.array : array;
    import std.exception : enforce;
    import std.range : chain, zip;

    //Fail with a clear message rather than an out-of-bounds access on outputs[0]
    enforce(outputs.length > 0, "adam requires at least one output: outputs[0] is used as the objective.");

    auto objective = outputs[0];

    auto grads = grad(objective, wrt);

    //Running estimates of the first and second moments of the gradient, one per element of wrt.
    //NOTE(review): ADAM expects these to start at zero; presumably float32(shape) zero-initialises -- confirm.
    auto means = wrt.map!(x => float32(x.shape)).array();
    auto vars = wrt.map!(x => float32(x.shape)).array();

    //b1 and b2 start at 1 and are multiplied by beta1 and beta2 on every step (nb1 and nb2 are written back into
    //their buffers below), so nb1 and nb2 are the running powers of beta1 and beta2 used for bias correction.
    auto b1 = float32([], [1.0f]);
    auto b2 = float32([], [1.0f]);
    auto nb1 = b1 * beta1;
    auto nb2 = b2 * beta2;

    //Step size with the bias correction terms folded in
    auto eta = alpha * sqrt(1.0f - nb2) / (1.0f - nb1);

    //x[0] is the gradient, x[1] is the previous first moment estimate
    auto newMeans = grads
                   .zip(means)
                   .map!(x => beta1 * x[1] + (1.0f - beta1) * x[0])
                   .array();

    //x[0] is the gradient, x[1] is the previous second moment estimate
    auto newVars = grads
                  .zip(vars)
                  .map!(x => beta2 * x[1] + (1.0f - beta2) * x[0] * x[0])
                  .array();

    //x[0] is the current value, x[1] the new first moment, x[2] the new second moment
    auto newvals = zip(wrt, newMeans, newVars)
                  .map!(x => x[0] - eta * (x[1] / (sqrt(x[2]) + eps)))
                  .array();

    //Apply projections
    for(size_t i = 0; i < newvals.length; i++)
    {
        if(wrt[i] in projs)
        {
            newvals[i] = projs[wrt[i]](newvals[i]);
        }
    }

    auto updatePlan = compile(outputs ~ newvals ~ newMeans ~ newVars ~ [nb1, nb2]);

    //The order of these buffers must mirror the order of the operations passed to compile above: each new value
    //is paired with the buffer of the variable it replaces.
    auto newbufs = chain(wrt, means, vars, [b1, b2])
                  .map!(x => x.value)
                  .array();

    //Fresh buffers for the user-requested outputs go at the front, matching their position in the plan
    newbufs = outputs.map!(x => allocate(x.volume * x.elementType.sizeOf)).array() ~ newbufs;

    //Performs a single ADAM step and returns the buffers holding the values of outputs
    DeviceBuffer[] update(DeviceBuffer[Operation] args)
    {
        updatePlan.execute(args, newbufs);

        return newbufs[0 .. outputs.length];
    }

    return &update;
}

///
unittest
{
    import std.random : uniform;

    //Generate some points
    auto xdata = new float[100];
    auto ydata = new float[100];

    foreach(i; 0 .. 100)
    {
        xdata[i] = uniform(-10.0f, 10.0f);
        ydata[i] = 3.0f * xdata[i] + 2.0f;
    }

    //Create the model
    auto x = float32([]);
    auto m = float32([]);
    auto c = float32([]);

    auto yhat = m * x + c;
    auto y = float32([]);

    //Create an ADAM updater
    auto updater = adam([(yhat - y) * (yhat - y)], [m, c], null, float32([], [0.1f]));

    //Iterate for a while
    float loss;

    for(size_t i = 0; i < 300; i++)
    {
        size_t j = i % 100;

        loss = updater([
            x: buffer(xdata[j .. j + 1]),
            y: buffer(ydata[j .. j + 1])
        ])[0].get!float[0];
    }

    //Print the loss after 300 iterations. Let the user decide whether it's good enough to be considered a pass.
    import std.stdio : writeln;
    writeln(
        "Adam loss: ", loss, "    ",
        "m=", m.value.get!float[0], ", ",
        "c=", c.value.get!float[0], "    ",
        "(expected m=3, c=2)");
}