1 /**
2 Contains an implementation of AMSGrad that relies on automatic differentiation
3
4 Authors: Henry Gouk
5 */
6 module dopt.online.amsgrad;
7
8 import dopt.core;
9 import dopt.online;
10
/**
Creates a delegate that can be used to perform a step using the AMSGrad update rule.

This function relies on automatic differentiation, so the objective (which must have a volume of 1) must be
differentiable w.r.t. all elements of $(D wrt). The returned delegate performs minimisation.

AMSGrad differs from Adam in how the step is scaled. It maintains a running element-wise maximum of the (uncorrected)
second moment estimate, and scales the update by that maximum rather than by the current second moment estimate.

Params:
    outputs = An array of outputs. The first element of this array is the objective function to be minimised. Must
        contain at least one element.
    wrt = An array of Operations that we want the derivative of objective with respect to.
    projs = Projection functions that can be applied when updating the values of elements in $(D wrt). May be
        $(D null) if no projections are required.
    alpha = The step size.
    beta1 = Fading factor for the first moment of the gradient.
    beta2 = Fading factor for the second moment of the gradient.
    eps = A small constant added to the denominator of the update to prevent division by zero.

Returns:
    A delegate that is used to actually perform the update steps. The optimised values are stored in the
    $(D value) properties of the elements of $(D wrt). The delegate returns the values computed for each element of the
    $(D outputs) array. This can be useful for keeping track of several different performance metrics in a
    prequential manner.
*/
Updater amsgrad(Operation[] outputs, Operation[] wrt, Projection[Operation] projs,
    Operation alpha = float32([], [0.001f]), Operation beta1 = float32([], [0.9f]),
    Operation beta2 = float32([], [0.999f]), Operation eps = float32([], [1e-8f]))
{
    import std.algorithm : map;
    import std.array : array;
    import std.range : chain, zip;

    assert(outputs.length > 0, "amsgrad requires at least one output (the objective to be minimised)");

    auto objective = outputs[0];

    auto grads = grad(objective, wrt);

    //Optimiser state, one variable per element of wrt: first moment, second moment, and the running maximum of
    //the second moment
    auto means = wrt.map!(x => float32(x.shape)).array();
    auto vars = wrt.map!(x => float32(x.shape)).array();
    auto varhats = wrt.map!(x => float32(x.shape)).array();

    //b1 and b2 hold beta1^t and beta2^t. They start at 1 and are multiplied by beta1/beta2 on every step, giving the
    //bias-corrected step size eta
    auto b1 = float32([], [1.0f]);
    auto b2 = float32([], [1.0f]);
    auto nb1 = b1 * beta1;
    auto nb2 = b2 * beta2;
    auto eta = alpha * sqrt(1.0f - nb2) / (1.0f - nb1);

    auto newMeans = grads
                   .zip(means)
                   .map!(x => beta1 * x[1] + (1.0f - beta1) * x[0])
                   .array();

    auto newVars = grads
                  .zip(vars)
                  .map!(x => beta2 * x[1] + (1.0f - beta2) * x[0] * x[0])
                  .array();

    //The running maximum must include the second moment estimate computed on *this* step
    auto newVarhats = varhats
                     .zip(newVars)
                     .map!(x => max(x[0], x[1]))
                     .array();

    //Scale the update by the running maximum of the second moment. This is what distinguishes AMSGrad from Adam.
    auto newvals = zip(wrt, newMeans, newVarhats)
                  .map!(x => x[0] - eta * (x[1] / (sqrt(x[2]) + eps)))
                  .array();

    //Apply projections (projs may be null, in which case the lookup simply fails)
    foreach(i, ref v; newvals)
    {
        if(auto proj = wrt[i] in projs)
        {
            v = (*proj)(v);
        }
    }

    //The order of the compiled outputs must match the order of the buffers in newbufs below, so that each new
    //value is written back into the storage of the variable it replaces
    auto updatePlan = compile(outputs ~ newvals ~ newMeans ~ newVars ~ newVarhats ~ [nb1, nb2]);

    auto newbufs = chain(wrt, means, vars, varhats, [b1, b2])
                  .map!(x => x.value)
                  .array();

    //Freshly allocated buffers receive the values of the requested outputs
    newbufs = outputs.map!(x => Buffer(new ubyte[x.volume * x.elementType.sizeOf])).array() ~ newbufs;

    Buffer[] update(Buffer[Operation] args)
    {
        updatePlan.execute(args, newbufs);

        return newbufs[0 .. outputs.length];
    }

    return &update;
}
100
///
unittest
{
    import std.random : uniform;
    import std.stdio : writeln;

    enum size_t numPoints = 100;
    enum size_t numIterations = 300;

    //Generate some points on the line y = 3x + 2
    auto xdata = new float[numPoints];
    auto ydata = new float[numPoints];

    foreach(i; 0 .. numPoints)
    {
        xdata[i] = uniform(-10.0f, 10.0f);
        ydata[i] = 3.0f * xdata[i] + 2.0f;
    }

    //Create the model
    auto x = float32([]);
    auto m = float32([]);
    auto c = float32([]);

    auto yhat = m * x + c;
    auto y = float32([]);

    //Create an AMSGrad updater with a step size of 0.1 and no projections
    auto updater = amsgrad([(yhat - y) * (yhat - y)], [m, c], null, float32([], [0.1f]));

    //Iterate for a while, cycling through the points one at a time
    float loss;

    foreach(i; 0 .. numIterations)
    {
        size_t j = i % numPoints;

        loss = updater([
            x: Buffer(xdata[j .. j + 1]),
            y: Buffer(ydata[j .. j + 1])
        ])[0].as!float[0];
    }

    //Print the loss after numIterations iterations. Let the user decide whether it's good enough to be considered
    //a pass.
    writeln(
        "AMSGrad loss: ", loss, " ",
        "m=", m.value.as!float[0], ", ",
        "c=", c.value.as!float[0], " ",
        "(expected m=3, c=2)");
}