1 /**
2     Contains an implementation of stochastic gradient descent that relies on automatic differentiation
3 
4     Authors: Henry Gouk
5 */
6 module dopt.online.sgd;
7 
8 import dopt.core;
9 import dopt.online;
10 
11 /**
12     Creates a delegate that can be used to perform a step using the stochastic gradient descent update rule.
13     
    This function relies on automatic differentiation, so the objective (the first element of outputs, which must
    have a volume of 1) must be differentiable w.r.t. all elements of wrt. The returned delegate performs
    minimisation.
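
    For each element w of wrt, with a per-element velocity v initialised to zero, the update performed by the
    returned delegate amounts to the following pseudo-code

    ---
    v = momentumRate * v + learningRate * (d objective / d w);
    w = w - v;
    ---

    followed by the projection associated with w in projs, if one was provided.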
16 
17     Params:
        outputs = an array of Operations that should be evaluated at each update step. The first element is taken
                  to be the objective (loss) function that will be minimised.
        wrt = an array of Operations that we want the derivative of the objective with respect to.
        projs = a mapping from elements of wrt to projection operations that are applied to the corresponding
                updated values. May be null if no projections are required.
        learningRate = the value used to scale the size of the gradient used in the update rule.
        momentumRate = scaling factor for the previous update.
22 
23     Returns:
24          A delegate that is used to actually perform the update steps. The optimised values are stored in the "default"
25          attributes of the elements of wrt.
26 */
27 Updater sgd(Operation[] outputs, Operation[] wrt, Projection[Operation] projs,
28     Operation learningRate = float32([], [0.01f]), Operation momentumRate = float32([], [0.0f]))
29 {
30     import std.algorithm : map;
31     import std.array : array;
32     import std.range : zip;
33 
    auto objective = outputs[0];

    //Differentiate the objective with respect to each element of wrt
    auto grads = grad(objective, wrt);

    //Velocity terms, one per element of wrt, initialised to zero
    auto momentum = grads
                   .map!(x => float32(x.shape))
                   .array();

    //v = momentumRate * v + learningRate * grad
    auto newMomentum = zip(grads, momentum)
                      .map!(x => x[1] * momentumRate + learningRate * x[0])
                      .array();

    //w = w - v
    auto newvals = zip(wrt, newMomentum)
                  .map!(x => x[0] - x[1])
                  .array();
49 
50     //Apply projections
51     for(size_t i = 0; i < newvals.length; i++)
52     {
53         if(wrt[i] in projs)
54         {
55             newvals[i] = projs[wrt[i]](newvals[i]);
56         }
57     }
58 
    //Compile a plan that evaluates the outputs, the new parameter values, and the new velocities
    auto updatePlan = compile(outputs ~ newvals ~ newMomentum);

    import std.range : chain;

    //The new parameter and velocity values are written straight into the "default" buffers of wrt and momentum;
    //fresh buffers are allocated to receive the outputs
    auto newbufs = chain(wrt, momentum)
                  .map!(x => x.value)
                  .array();

    newbufs = outputs.map!(x => Buffer(new ubyte[x.volume * x.elementType.sizeOf])).array() ~ newbufs;
68 
69     Buffer[] update(Buffer[Operation] args)
70     {
        //Evaluate the outputs and write the new parameter and velocity values into their buffers
        updatePlan.execute(args, newbufs);
72 
73         return newbufs[0 .. outputs.length];
74     }
75 
76     return &update;
77 }
78 
79 ///
80 unittest
81 {
82     import std.random : uniform;
83 
84     //Generate some points
85     auto xdata = new float[100];
86     auto ydata = new float[100];
87 
88     foreach(i; 0 .. 100)
89     {
90         xdata[i] = uniform(-10.0f, 10.0f);
91         ydata[i] = 3.0f * xdata[i] + 2.0f;
92     }
93 
94     //Create the model
95     auto x = float32([]);
96     auto m = float32([]);
97     auto c = float32([]);
98 
99     auto yhat = m * x + c;
100     auto y = float32([]);
101 
102     //Create an SGD updater
103     auto updater = sgd([(yhat - y) * (yhat - y)], [m, c], null, float32([], [0.001f]), float32([], [0.9f]));
104 
    //Iterate for a while, performing one SGD update per data point
106     float loss;
107 
108     for(size_t i = 0; i < 300; i++)
109     {
110         size_t j = i % 100;
111 
112         loss = updater([
113             x: Buffer(xdata[j .. j + 1]),
114             y: Buffer(ydata[j .. j + 1])
115         ])[0].as!float[0];
116     }
117 
    //Print the loss after 300 iterations. Let the user decide whether it's good enough to be considered a pass.
119     import std.stdio : writeln;
120     writeln(
121         "SGD loss: ", loss, "    ",
122         "m=", m.value.as!float[0], ", ",
123         "c=", c.value.as!float[0], "    ",
124         "(expected m=3, c=2)");
125 }