/**
    Contains an implementation of stochastic gradient descent that relies on automatic differentiation.

    Authors: Henry Gouk
*/
module dopt.online.sgd;

import dopt.core;
import dopt.online;

/**
    Creates a delegate that can be used to perform a step using the stochastic gradient descent update rule.

    This function relies on automatic differentiation, so the objective (which must have a volume of 1) must be
    differentiable w.r.t. all elements of wrt. The returned delegate performs minimisation.
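
    The update rule used is SGD with momentum: each element w of wrt is given a momentum buffer v, and each
    step computes v = momentumRate * v + learningRate * g, followed by w = w - v, where g is the gradient of
    the objective with respect to w.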

    Params:
        outputs = an array of Operations that will be evaluated at each update step. The first element is taken
                  to be the objective function to be minimised.
        wrt = an array of Operations that we want the derivative of the objective with respect to.
        projs = a mapping from elements of wrt to projection functions, each of which is applied to the
                corresponding parameter after it has been updated.
        learningRate = the value used to scale the size of the gradient used in the update rule.
        momentumRate = scaling factor for the previous update.

    Returns:
        A delegate that is used to actually perform the update steps. The optimised values are stored in the
        "default" attributes of the elements of wrt.
*/
Updater sgd(Operation[] outputs, Operation[] wrt, Projection[Operation] projs,
    Operation learningRate = float32([], [0.01f]), Operation momentumRate = float32([], [0.0f]))
{
    import std.algorithm : map;
    import std.array : array;
    import std.range : zip;

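    //The first element of outputs is the objective function to be minimised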
    auto objective = outputs[0];

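    //Use automatic differentiation to get the gradient of the objective w.r.t. each element of wrt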
    auto grads = grad(objective, wrt);

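    //Create a momentum buffer with the same shape as each gradient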
    auto momentum = grads
                   .map!(x => float32(x.shape))
                   .array();

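    //Compute the updated momentum values: v' = momentumRate * v + learningRate * g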
    auto newMomentum = zip(grads, momentum)
                      .map!(x => x[1] * momentumRate + learningRate * x[0])
                      .array();

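    //Compute the updated parameter values: w' = w - v'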
    auto newvals = zip(wrt, newMomentum)
                  .map!(x => x[0] - x[1])
                  .array();

    //Apply projections
    foreach(i; 0 .. newvals.length)
    {
        if(wrt[i] in projs)
        {
            newvals[i] = projs[wrt[i]](newvals[i]);
        }
    }

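    //Compile a single plan that evaluates the outputs, the new parameter values, and the new momentum values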
    auto updatePlan = compile(outputs ~ newvals ~ newMomentum);

    import std.range : chain;

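    //The new parameter and momentum values are written straight into the existing value buffers of the
    //corresponding variables, so the optimised values persist between calls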
    auto newbufs = chain(wrt, momentum)
                  .map!(x => x.value)
                  .array();

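    //Prepend freshly allocated buffers that will receive the values of the outputs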
    newbufs = outputs.map!(x => Buffer(new ubyte[x.volume * x.elementType.sizeOf])).array() ~ newbufs;

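    //Each call performs a single update step and returns the evaluated outputs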
    Buffer[] update(Buffer[Operation] args)
    {
        updatePlan.execute(args, newbufs);

        return newbufs[0 .. outputs.length];
    }

    return &update;
}

///
unittest
{
    import std.random : uniform;

    //Generate some points
    auto xdata = new float[100];
    auto ydata = new float[100];

    foreach(i; 0 .. 100)
    {
        xdata[i] = uniform(-10.0f, 10.0f);
        ydata[i] = 3.0f * xdata[i] + 2.0f;
    }

    //Create the model
    auto x = float32([]);
    auto m = float32([]);
    auto c = float32([]);

    auto yhat = m * x + c;
    auto y = float32([]);

    //Create an SGD updater
    auto updater = sgd([(yhat - y) * (yhat - y)], [m, c], null, float32([], [0.001f]), float32([], [0.9f]));

    //Iterate for a while
    float loss;

    foreach(i; 0 .. 300)
    {
        size_t j = i % 100;

        loss = updater([
            x: Buffer(xdata[j .. j + 1]),
            y: Buffer(ydata[j .. j + 1])
        ])[0].as!float[0];
    }

    //Print the loss after 300 iterations. Let the user decide whether it's good enough to be considered a pass.
    import std.stdio : writeln;
    writeln(
        "SGD loss: ", loss, " ",
        "m=", m.value.as!float[0], ", ",
        "c=", c.value.as!float[0], " ",
        "(expected m=3, c=2)");
}