/**
    Contains an implementation of stochastic gradient descent that relies on automatic differentiation

    Authors: Henry Gouk
*/
module dopt.online.sgd;

import dopt.core;
import dopt.online;

/**
    Creates a delegate that can be used to perform a step using the stochastic gradient descent update rule.

    This function relies on automatic differentiation, so the objective (which must have a volume of 1) must be
    differentiable w.r.t. all elements of wrt. The returned delegate performs minimisation.

    Params:
        outputs = an array of Operations. The first element is the objective to be minimised; any remaining
                  elements are also evaluated and returned by the updater on each call.
        wrt = an array of Operations that we want the derivative of the objective with respect to.
        projs = a mapping from elements of wrt to projections applied to their new values after each update.
        learningRate = the value used to scale the size of the gradient used in the update rule
        momentumRate = scaling factor for the previous update
        nesterov = indicates whether Nesterov's accelerated gradient should be used

    Returns:
        A delegate that is used to actually perform the update steps. The optimised values are stored in the
        value buffers of the elements of wrt.
*/
Updater sgd(Operation[] outputs, Operation[] wrt, Projection[Operation] projs,
    Operation learningRate = float32([], [0.01f]), Operation momentumRate = float32([], [0.0f]),
    bool nesterov = false)
{
    import std.algorithm : map;
    import std.array : array;
    import std.range : zip;

    auto objective = outputs[0];

    auto grads = grad(objective, wrt);

    //One velocity variable per parameter, with the same shape as that parameter's gradient
    auto momentum = grads
                   .map!(x => float32(x.shape))
                   .array();

    Operation[] newMomentum;
    Operation[] newvals;

    if(nesterov)
    {
        //Nesterov's accelerated gradient:
        //  v_new = momentumRate * v - learningRate * g
        //  w_new = w + momentumRate * v_new - learningRate * g
        newMomentum = zip(grads, momentum)
                     .map!(x => x[1] * momentumRate - learningRate * x[0])
                     .array();

        newvals = zip(wrt, newMomentum, grads)
                 .map!(x => x[0] + momentumRate * x[1] - learningRate * x[2])
                 .array();
    }
    else
    {
        //Classical momentum:
        //  v_new = momentumRate * v + learningRate * g
        //  w_new = w - v_new
        newMomentum = zip(grads, momentum)
                     .map!(x => x[1] * momentumRate + learningRate * x[0])
                     .array();

        newvals = zip(wrt, newMomentum)
                 .map!(x => x[0] - x[1])
                 .array();
    }

    //Apply projections to any parameters that have one registered
    for(size_t i = 0; i < newvals.length; i++)
    {
        if(wrt[i] in projs)
        {
            newvals[i] = projs[wrt[i]](newvals[i]);
        }
    }

    //Compile a plan that evaluates the outputs, the updated parameter values, and the updated velocities
    auto updatePlan = compile(outputs ~ newvals ~ newMomentum);

    import std.range : chain;

    //The updated values are written straight back into the value buffers of wrt and momentum;
    //fresh buffers are allocated to hold the outputs
    auto newbufs = chain(wrt, momentum)
                  .map!(x => x.value)
                  .array();

    newbufs = outputs.map!(x => allocate(x.volume * x.elementType.sizeOf)).array() ~ newbufs;

    DeviceBuffer[] update(DeviceBuffer[Operation] args)
    {
        updatePlan.execute(args, newbufs);

        return newbufs[0 .. outputs.length];
    }

    return &update;
}

///
unittest
{
    import std.random : uniform;

    //Generate some points
    auto xdata = new float[100];
    auto ydata = new float[100];

    foreach(i; 0 .. 100)
    {
        xdata[i] = uniform(-10.0f, 10.0f);
        ydata[i] = 3.0f * xdata[i] + 2.0f;
    }

    //Create the model
    auto x = float32([]);
    auto m = float32([]);
    auto c = float32([]);

    auto yhat = m * x + c;
    auto y = float32([]);

    //Create an SGD updater
    auto updater = sgd([(yhat - y) * (yhat - y)], [m, c], null, float32([], [0.001f]), float32([], [0.9f]));

    //Iterate for a while
    float loss;

    for(size_t i = 0; i < 300; i++)
    {
        size_t j = i % 100;

        loss = updater([
            x: buffer(xdata[j .. j + 1]),
            y: buffer(ydata[j .. j + 1])
        ])[0].get!float[0];
    }

    //Print the loss after 300 iterations.
    //Let the user decide whether it's good enough to be considered a pass.
    import std.stdio : writeln;
    writeln(
        "SGD loss: ", loss, " ",
        "m=", m.value.get!float[0], ", ",
        "c=", c.value.get!float[0], " ",
        "(expected m=3, c=2)");
}
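
//A second, hedged usage sketch: the same linear fit as the example above, but with the nesterov flag
//enabled so that the Nesterov accelerated gradient branch of sgd() is exercised. Only the API already
//used in this module (sgd, float32, buffer) is assumed; the hyperparameter values are illustrative,
//not tuned settings.
unittest
{
    import std.random : uniform;

    //Generate some points on the line y = 3x + 2
    auto xdata = new float[100];
    auto ydata = new float[100];

    foreach(i; 0 .. 100)
    {
        xdata[i] = uniform(-10.0f, 10.0f);
        ydata[i] = 3.0f * xdata[i] + 2.0f;
    }

    //Create the model
    auto x = float32([]);
    auto m = float32([]);
    auto c = float32([]);

    auto yhat = m * x + c;
    auto y = float32([]);

    //Pass nesterov=true as the final argument to select the Nesterov update rule
    auto updater = sgd([(yhat - y) * (yhat - y)], [m, c], null, float32([], [0.001f]),
        float32([], [0.9f]), true);

    //Iterate for a while
    float loss;

    for(size_t i = 0; i < 300; i++)
    {
        size_t j = i % 100;

        loss = updater([
            x: buffer(xdata[j .. j + 1]),
            y: buffer(ydata[j .. j + 1])
        ])[0].get!float[0];
    }

    import std.stdio : writeln;
    writeln("Nesterov SGD loss: ", loss, " m=", m.value.get!float[0], ", c=", c.value.get!float[0],
        " (expected m=3, c=2)");
}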