/**
    Contains an implementation of stochastic gradient descent that relies on automatic differentiation.

    Authors: Henry Gouk
*/
module dopt.online.sgd;

import dopt.core;
import dopt.online;

/**
    Creates a delegate that can be used to perform a step using the stochastic gradient descent update rule.

    This function relies on automatic differentiation, so the objective (which must have a volume of 1) must be
    differentiable w.r.t. all elements of wrt. The returned delegate performs minimisation.

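    The standard momentum update computed by the returned delegate is v = momentumRate * v + learningRate * g
    followed by w = w - v, where g is the gradient of the objective with respect to the parameter w. When
    nesterov is true, the accelerated form v = momentumRate * v - learningRate * g followed by
    w = w + momentumRate * v - learningRate * g is used instead.
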
    Params:
        outputs = an array of Operations that should be evaluated at each update step. The first element is
        interpreted as the objective function to be minimised.
        wrt = an array of Operations that we want the derivative of the objective with respect to.
        projs = a mapping from elements of wrt to projection functions that are applied to the corresponding
        new values after each update.
        learningRate = the value used to scale the size of the gradient used in the update rule.
        momentumRate = scaling factor for the previous update.
        nesterov = indicates whether Nesterov's accelerated gradient should be used.

    Returns:
         A delegate that is used to actually perform the update steps. The optimised values are stored in the "default"
         attributes of the elements of wrt.
*/
Updater sgd(Operation[] outputs, Operation[] wrt, Projection[Operation] projs,
    Operation learningRate = float32([], [0.01f]), Operation momentumRate = float32([], [0.0f]), bool nesterov = false)
{
    import std.algorithm : map;
    import std.array : array;
    import std.range : zip;

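    //The first element of outputs is interpreted as the objective function to be minimised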
    auto objective = outputs[0];

    auto grads = grad(objective, wrt);

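    //Create a velocity (momentum) variable with the same shape as each parameter being optimised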
    auto momentum = grads
                   .map!(x => float32(x.shape))
                   .array();

    Operation[] newMomentum;
    Operation[] newvals;

    if(nesterov)
    {
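        //Nesterov's accelerated gradient: v = momentumRate * v - learningRate * g,
        //then w = w + momentumRate * v - learningRate * g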
        newMomentum = zip(grads, momentum)
                      .map!(x => x[1] * momentumRate - learningRate * x[0])
                      .array();

        newvals = zip(wrt, newMomentum, grads)
                  .map!(x => x[0] + momentumRate * x[1] - learningRate * x[2])
                  .array();
    }
    else
    {
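        //Classical momentum: v = momentumRate * v + learningRate * g, then w = w - v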
        newMomentum = zip(grads, momentum)
                      .map!(x => x[1] * momentumRate + learningRate * x[0])
                      .array();

        newvals = zip(wrt, newMomentum)
                  .map!(x => x[0] - x[1])
                  .array();
    }

    //Apply projections
    for(size_t i = 0; i < newvals.length; i++)
    {
        if(wrt[i] in projs)
        {
            newvals[i] = projs[wrt[i]](newvals[i]);
        }
    }

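    //Compile a single plan that evaluates the requested outputs alongside the parameter and momentum updates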
    auto updatePlan = compile(outputs ~ newvals ~ newMomentum);

    import std.range : chain;

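    //The results for newvals and newMomentum are written directly into the value buffers of wrt and
    //momentum, so the parameters and velocities are updated in place whenever the plan is executed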
    auto newbufs = chain(wrt, momentum)
                  .map!(x => x.value)
                  .array();

    newbufs = outputs.map!(x => allocate(x.volume * x.elementType.sizeOf)).array() ~ newbufs;

    DeviceBuffer[] update(DeviceBuffer[Operation] args)
    {
        updatePlan.execute(args, newbufs);

        return newbufs[0 .. outputs.length];
    }

    return &update;
}

///
unittest
{
    import std.random : uniform;

    //Generate some points
    auto xdata = new float[100];
    auto ydata = new float[100];

    foreach(i; 0 .. 100)
    {
        xdata[i] = uniform(-10.0f, 10.0f);
        ydata[i] = 3.0f * xdata[i] + 2.0f;
    }

    //Create the model
    auto x = float32([]);
    auto m = float32([]);
    auto c = float32([]);

    auto yhat = m * x + c;
    auto y = float32([]);

    //Create an SGD updater
    auto updater = sgd([(yhat - y) * (yhat - y)], [m, c], null, float32([], [0.001f]), float32([], [0.9f]));

    //Iterate for a while
    float loss;

    for(size_t i = 0; i < 300; i++)
    {
        size_t j = i % 100;

        loss = updater([
            x: buffer(xdata[j .. j + 1]),
            y: buffer(ydata[j .. j + 1])
        ])[0].get!float[0];
    }
    //Print the loss after 300 iterations. Let the user decide whether it's good enough to be considered a pass.
    import std.stdio : writeln;
    writeln(
        "SGD loss: ", loss, "    ",
        "m=", m.value.get!float[0], ", ",
        "c=", c.value.get!float[0], "    ",
        "(expected m=3, c=2)");
}
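
//A minimal sketch of how the projs parameter can be used. sgd consults the mapping after computing each
//parameter's new value and writes back the projected value instead. The shrinkage map below (multiplying
//the new value by a constant) is a stand-in chosen because it only uses operations that already appear
//in this module; a real projection might clip onto a norm ball instead.
unittest
{
    auto w = float32([], [1.0f]);
    auto loss = w * w;

    Projection[Operation] projs;
    projs[w] = (Operation newval) => newval * float32([], [0.5f]);

    auto updater = sgd([loss], [w], projs);

    //There are no placeholder inputs in this graph, so no argument buffers are required
    updater(null);
}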