/**
Contains an implementation of stochastic gradient descent that relies on automatic differentiation.

Authors: Henry Gouk
*/
module dopt.online.sgd;

import dopt.core;
import dopt.online;

/**
Creates a delegate that can be used to perform a step using the stochastic gradient descent update rule.

This function relies on automatic differentiation, so the objective (the first element of outputs, which must
have a volume of 1) must be differentiable w.r.t. all elements of wrt. The returned delegate performs
minimisation.

The update rule, with momentum, is

    v = momentumRate * v + learningRate * gradient
    w = w - v

where w is an element of wrt and v is its associated momentum term.

Params:
    outputs = An array of Operations to be evaluated at each step. The first element is taken to be the
              objective to be minimised.
    wrt = An array of Operations that we want the derivative of the objective with respect to.
    projs = Used to project the new values of elements of wrt back into a valid range after each update.
            Elements of wrt without an entry in this associative array are left unprojected.
    learningRate = The value used to scale the size of the gradient used in the update rule.
    momentumRate = Scaling factor for the previous update.

Returns:
    A delegate that is used to actually perform the update steps. The optimised values are stored in the "default"
    attributes of the elements of wrt.
*/
Updater sgd(Operation[] outputs, Operation[] wrt, Projection[Operation] projs,
    Operation learningRate = float32([], [0.01f]), Operation momentumRate = float32([], [0.0f]))
{
    import std.algorithm : map;
    import std.array : array;
    import std.range : chain, zip;

    auto objective = outputs[0];

    auto grads = grad(objective, wrt);

    //One momentum term per parameter, with a matching shape
    auto momentum = grads
                   .map!(x => float32(x.shape))
                   .array();

    //v = momentumRate * v + learningRate * gradient
    auto newMomentum = zip(grads, momentum)
                      .map!(x => x[1] * momentumRate + learningRate * x[0])
                      .array();

    //w = w - v
    auto newvals = zip(wrt, newMomentum)
                  .map!(x => x[0] - x[1])
                  .array();

    //Apply projections
    for(size_t i = 0; i < newvals.length; i++)
    {
        if(wrt[i] in projs)
        {
            newvals[i] = projs[wrt[i]](newvals[i]);
        }
    }

    auto updatePlan = compile(outputs ~ newvals ~ newMomentum);

    //The parameters and momenta are updated in place through their existing value buffers; fresh buffers are
    //allocated to hold the outputs.
    auto newbufs = chain(wrt, momentum)
                  .map!(x => x.value)
                  .array();

    newbufs = outputs.map!(x => Buffer(new ubyte[x.volume * x.elementType.sizeOf])).array() ~ newbufs;

    Buffer[] update(Buffer[Operation] args)
    {
        updatePlan.execute(args, newbufs);

        return newbufs[0 .. outputs.length];
    }

    return &update;
}

///
unittest
{
    import std.random : uniform;

    //Generate some points on the line y = 3x + 2
    auto xdata = new float[100];
    auto ydata = new float[100];

    foreach(i; 0 .. 100)
    {
        xdata[i] = uniform(-10.0f, 10.0f);
        ydata[i] = 3.0f * xdata[i] + 2.0f;
    }

    //Create the model
    auto x = float32([]);
    auto m = float32([]);
    auto c = float32([]);

    auto yhat = m * x + c;
    auto y = float32([]);

    //Create an SGD updater that minimises the squared error
    auto updater = sgd([(yhat - y) * (yhat - y)], [m, c], null, float32([], [0.001f]), float32([], [0.9f]));

    //Iterate for a while
    float loss;

    for(size_t i = 0; i < 300; i++)
    {
        size_t j = i % 100;

        loss = updater([
            x: Buffer(xdata[j .. j + 1]),
            y: Buffer(ydata[j .. j + 1])
        ])[0].as!float[0];
    }

    //Print the loss after 300 iterations. Let the user decide whether it's good enough to be considered a pass.
    import std.stdio : writeln;
    writeln(
        "SGD loss: ", loss, " ",
        "m=", m.value.as!float[0], ", ",
        "c=", c.value.as!float[0], " ",
        "(expected m=3, c=2)");
}
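
//A minimal sketch of the projs argument, assuming only the API already used above: a Projection maps the
//proposed new value of a parameter to a constrained value before it is written back. The shrink-towards-zero
//delegate below is purely illustrative, standing in for a genuine projection such as clipping to a norm ball,
//so this test only exercises the plumbing and does not assert a particular optimum.
unittest
{
    //A single-parameter model
    auto x = float32([]);
    auto w = float32([]);
    auto yhat = w * x;
    auto y = float32([]);

    //Register a projection for w. Parameters without an entry are left unprojected.
    Projection[Operation] projs;
    projs[w] = (Operation newW) => newW * float32([], [0.5f]);

    //The default learning and momentum rates are used here
    auto updater = sgd([(yhat - y) * (yhat - y)], [w], projs);

    //Perform a single update step; the projection is applied to w after the gradient step
    updater([
        x: Buffer([1.0f]),
        y: Buffer([2.0f])
    ]);
}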