/**
    Contains an implementation of stochastic gradient descent that relies on automatic differentiation

    Authors: Henry Gouk
*/
module dopt.online.sgd;

import dopt.core;
import dopt.online;

/**
    Creates a delegate that can be used to perform a step using the stochastic gradient descent update rule.

    This function relies on automatic differentiation, so the objective (the first element of outputs, which must
    have a volume of 1) must be differentiable w.r.t. all elements of wrt. The returned delegate performs
    minimisation using the update rule v = momentumRate * v + learningRate * g, where g is the gradient of the
    objective, followed by w = w - v for each element w of wrt.

    Params:
        outputs = an array of Operations that should be evaluated at each update step. The first element is taken
        to be the objective (loss) that will be minimised.
        wrt = an array of Operations that we want the derivative of the objective with respect to.
        learningRate = the value used to scale the size of the gradient used in the update rule
        momentumRate = scaling factor for the previous update

    Returns:
        A delegate that is used to actually perform the update steps. The optimised values are stored in the "default"
        attributes of the elements of wrt.
*/
Updater sgd(Operation[] outputs, Operation[] wrt,
    Operation learningRate = float32([], [0.01f]), Operation momentumRate = float32([], [0.0f]))
{
    import std.algorithm : map;
    import std.array : array;
    import std.range : chain, zip;

    auto objective = outputs[0];

    //Differentiate the objective with respect to each parameter
    auto grads = grad(objective, wrt);

    //Create a momentum variable for each parameter
    auto momentum = grads
                   .map!(x => float32(x.shape))
                   .array();

    //v = momentumRate * v + learningRate * g
    auto newMomentum = zip(grads, momentum)
                      .map!(x => x[1] * momentumRate + learningRate * x[0])
                      .array();

    //w = w - v
    auto newvals = zip(wrt, newMomentum)
                  .map!(x => x[0] - x[1])
                  .array();

    auto updatePlan = new CUDAPlan(outputs ~ newvals ~ newMomentum);

    //The new parameter values and momenta are written straight back into the existing variable buffers
    auto newbufs = chain(wrt, momentum)
                  .map!(x => x.value)
                  .array();

    //Prepend freshly allocated buffers to hold the requested outputs
    newbufs = outputs.map!(x => Buffer(new ubyte[x.volume * x.elementType.sizeOf])).array() ~ newbufs;

    Buffer[] update(Buffer[Operation] args)
    {
        updatePlan.execute(args, newbufs);

        return newbufs[0 .. outputs.length];
    }

    return &update;
}

///
unittest
{
    import std.random : uniform;

    //Generate some points
    auto xdata = new float[100];
    auto ydata = new float[100];

    foreach(i; 0 .. 100)
    {
        xdata[i] = uniform(-10.0f, 10.0f);
        ydata[i] = 3.0f * xdata[i] + 2.0f;
    }

    //Create the model
    auto x = float32([]);
    auto m = float32([]);
    auto c = float32([]);

    auto yhat = m * x + c;
    auto y = float32([]);

    //Create an SGD updater that minimises the squared error
    auto updater = sgd([(yhat - y) * (yhat - y)], [m, c], float32([], [0.001f]), float32([], [0.9f]));

    //Iterate for a while
    float loss;

    for(size_t i = 0; i < 300; i++)
    {
        size_t j = i % 100;

        loss = updater([
            x: Buffer(xdata[j .. j + 1]),
            y: Buffer(ydata[j .. j + 1])
        ])[0].as!float[0];
    }

    //Print the loss after 300 iterations. Let the user decide whether it's good enough to be considered a pass.
    import std.stdio : writeln;
    writeln(
        "SGD loss: ", loss, " ",
        "m=", m.value.as!float[0], ", ",
        "c=", c.value.as!float[0], " ",
        "(expected m=3, c=2)");
}
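/*
    For reference, the update rule encoded by the graph constructed in sgd can be written out directly. The
    function below is a minimal host-side sketch of SGD with momentum on plain arrays, kept independent of
    dopt's graph machinery; its name and parameters are illustrative only and are not part of the dopt API.
*/
private void sgdStepSketch(float[] weights, const(float)[] gradients, float[] velocity,
    float learningRate, float momentumRate)
{
    foreach(i; 0 .. weights.length)
    {
        //v = momentumRate * v + learningRate * g
        velocity[i] = momentumRate * velocity[i] + learningRate * gradients[i];

        //w = w - v
        weights[i] -= velocity[i];
    }
}

unittest
{
    import std.math : approxEqual;

    //A single step starting from w = 1 with zero velocity and a gradient of 0.5
    float[] w = [1.0f];
    float[] v = [0.0f];

    sgdStepSketch(w, [0.5f], v, 0.1f, 0.9f);

    assert(approxEqual(v[0], 0.05f));
    assert(approxEqual(w[0], 0.95f));
}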