/**
    Contains an implementation of AMSGrad that relies on automatic differentiation

    Authors: Henry Gouk
*/
module dopt.online.amsgrad;

import dopt.core;
import dopt.online;

/**
    Creates a delegate that can be used to perform a step using the AMSGrad update rule.

    This function relies on automatic differentiation, so the objective (which must have a volume of 1) must be
    differentiable w.r.t. all elements of $(D wrt). The returned delegate performs minimisation.

    Params:
        outputs = An array of outputs. The first element of this array is the objective function to be minimised.
        wrt = An array of Operations that we want the derivative of the objective with respect to.
        projs = Projection functions that can be applied when updating the values of elements in $(D wrt).
        alpha = The step size.
        beta1 = Fading factor for the first moment of the gradient.
        beta2 = Fading factor for the second moment of the gradient.
        eps = A small constant used to prevent division by zero.

    Returns:
         A delegate that is used to actually perform the update steps. The optimised values are stored in the
         $(D value) properties of the elements of $(D wrt). The delegate returns the values computed for each element
         of the $(D outputs) array. This can be useful for keeping track of several different performance metrics in a
         prequential manner.
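
    The update rule implemented below is the one given by Reddi et al. in "On the Convergence of Adam and
    Beyond" (ICLR 2018), with Adam-style bias correction added. For a parameter w with gradient g:

    ---
    m = beta1 * m + (1.0f - beta1) * g
    v = beta2 * v + (1.0f - beta2) * g * g
    vhat = max(vhat, v)
    w = w - eta * m / (sqrt(vhat) + eps)
    ---

    where eta = alpha * sqrt(1 - beta2^t) / (1 - beta1^t) at iteration t.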
*/
Updater amsgrad(Operation[] outputs, Operation[] wrt, Projection[Operation] projs,
    Operation alpha = float32([], [0.001f]), Operation beta1 = float32([], [0.9f]),
    Operation beta2 = float32([], [0.999f]), Operation eps = float32([], [1e-8f]))
{
    import std.algorithm : map;
    import std.array : array;
    import std.range : zip;

    auto objective = outputs[0];
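
    //Gradients of the objective, plus buffers for the optimiser state: the
    //first moment (mean), second moment (uncentred variance), and the running
    //maximum of the second moment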
    auto grads = grad(objective, wrt);
    auto means = wrt.map!(x => float32(x.shape)).array();
    auto vars = wrt.map!(x => float32(x.shape)).array();
    auto varhats = wrt.map!(x => float32(x.shape)).array();
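
    //b1 and b2 accumulate beta1^t and beta2^t; eta is then the bias-corrected
    //step size, computed in the same way as Adam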
    auto b1 = float32([], [1.0f]);
    auto b2 = float32([], [1.0f]);
    auto nb1 = b1 * beta1;
    auto nb2 = b2 * beta2;
    auto eta = alpha * sqrt(1.0f - nb2) / (1.0f - nb1);
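
    //Exponential moving averages of the gradient and the elementwise squared gradient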
    auto newMeans = grads
                   .zip(means)
                   .map!(x => beta1 * x[1] + (1.0f - beta1) * x[0])
                   .array();

    auto newVars = grads
                   .zip(vars)
                   .map!(x => beta2 * x[1] + (1.0f - beta2) * x[0] * x[0])
                   .array();
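
    //The AMSGrad modification to Adam: normalise by the elementwise maximum of
    //all second moment estimates seen so far, so the normaliser never shrinks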
    auto newVarhats = varhats
                     .zip(newVars)
                     .map!(x => max(x[0], x[1]))
                     .array();

    auto newvals = zip(wrt, newMeans, newVarhats)
                  .map!(x => x[0] - eta * (x[1] / (sqrt(x[2]) + eps)))
                  .array();

    //Apply projections
    for(size_t i = 0; i < newvals.length; i++)
    {
        if(wrt[i] in projs)
        {
            newvals[i] = projs[wrt[i]](newvals[i]);
        }
    }
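
    //Compile a single plan that evaluates the outputs and all of the state updates together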
    auto updatePlan = compile(outputs ~ newvals ~ newMeans ~ newVars ~ newVarhats ~ [nb1, nb2]);

    import std.range : chain;
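
    //The plan writes each result directly into the value buffer of the variable
    //it updates, so the optimiser state persists between calls; only the
    //outputs are given freshly allocated buffers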
    auto newbufs = chain(wrt, means, vars, varhats, [b1, b2])
                  .map!(x => x.value)
                  .array();

    newbufs = outputs.map!(x => Buffer(new ubyte[x.volume * x.elementType.sizeOf])).array() ~ newbufs;
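
    //Performs a single AMSGrad step and returns the values computed for each output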
    Buffer[] update(Buffer[Operation] args)
    {
        updatePlan.execute(args, newbufs);

        return newbufs[0 .. outputs.length];
    }

    return &update;
}

///
unittest
{
    import std.random : uniform;

    //Generate some points
    auto xdata = new float[100];
    auto ydata = new float[100];

    foreach(i; 0 .. 100)
    {
        xdata[i] = uniform(-10.0f, 10.0f);
        ydata[i] = 3.0f * xdata[i] + 2.0f;
    }

    //Create the model
    auto x = float32([]);
    auto m = float32([]);
    auto c = float32([]);

    auto yhat = m * x + c;
    auto y = float32([]);

    //Create an AMSGrad updater
    auto updater = amsgrad([(yhat - y) * (yhat - y)], [m, c], null, float32([], [0.1f]));

    //Iterate for a while
    float loss;

    for(size_t i = 0; i < 300; i++)
    {
        size_t j = i % 100;

        loss = updater([
            x: Buffer(xdata[j .. j + 1]),
            y: Buffer(ydata[j .. j + 1])
        ])[0].as!float[0];
    }

    //Print the loss after 300 iterations. Let the user decide whether it's good enough to be considered a pass.
    import std.stdio : writeln;
    writeln(
        "AMSGrad loss: ", loss, "    ",
        "m=", m.value.as!float[0], ", ",
        "c=", c.value.as!float[0], "    ",
        "(expected m=3, c=2)");
}