/**
    Contains an implementation of ADAM that relies on automatic differentiation

    Authors: Henry Gouk
*/
module dopt.online.adam;

import dopt.core;
import dopt.online;
/**
    Creates a delegate that can be used to perform a step using the ADAM update rule.

    This function relies on automatic differentiation, so the objective (which must have a volume of 1) must be
    differentiable w.r.t. all elements of $(D wrt). The returned delegate performs minimisation.

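    The returned delegate follows the update rule from the Adam paper (Kingma and Ba, 2015): writing g for the
    gradient of the objective at step t, each parameter w in $(D wrt) is updated according to

    ---
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    eta = alpha * sqrt(1 - beta2 ^^ t) / (1 - beta1 ^^ t)
    w = w - eta * m / (sqrt(v) + eps)
    ---

    where the moment estimates m and v start at zero, and eta folds the bias correction for both moment estimates
    into the step size.
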
    Params:
        outputs = An array of Operations that should be evaluated at each update step. The first element of this
        array is the objective function to be minimised.
        wrt = An array of Operations that the derivative of the objective should be taken with respect to.
        projs = Projection functions that can be applied when updating the values of elements in $(D wrt). The
        second example below sketches how these are used.
        alpha = The step size.
        beta1 = Fading factor for the first moment estimate of the gradient.
        beta2 = Fading factor for the second moment estimate of the gradient.
        eps = A small constant used to prevent division by zero.

    Returns:
         A delegate that is used to actually perform the update steps. The optimised values are stored in the
         $(D value) properties of the elements of $(D wrt). The delegate returns the values computed for each
         element of the $(D outputs) array. This can be useful for keeping track of several different performance
         metrics in a prequential manner.
*/
Updater adam(Operation[] outputs, Operation[] wrt, Projection[Operation] projs,
    Operation alpha = float32([], [0.001f]), Operation beta1 = float32([], [0.9f]),
    Operation beta2 = float32([], [0.999f]), Operation eps = float32([], [1e-8f]))
{
    import std.algorithm : map;
    import std.array : array;
    import std.range : chain, zip;

    auto objective = outputs[0];

    auto grads = grad(objective, wrt);
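
    //Variables holding the first and second moment estimates of the gradient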
    auto means = wrt.map!(x => float32(x.shape)).array();
    auto vars = wrt.map!(x => float32(x.shape)).array();

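    //b1 and b2 accumulate beta1 ^^ t and beta2 ^^ t, allowing eta to fold the
    //bias correction for the two moment estimates into the step size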
    auto b1 = float32([], [1.0f]);
    auto b2 = float32([], [1.0f]);
    auto nb1 = b1 * beta1;
    auto nb2 = b2 * beta2;
    auto eta = alpha * sqrt(1.0f - nb2) / (1.0f - nb1);

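    //Exponential moving averages of the gradient and its elementwise square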
    auto newMeans = grads
                   .zip(means)
                   .map!(x => beta1 * x[1] + (1.0f - beta1) * x[0])
                   .array();

    auto newVars = grads
                   .zip(vars)
                   .map!(x => beta2 * x[1] + (1.0f - beta2) * x[0] * x[0])
                   .array();

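    //Compute the new parameter values using the bias-corrected step size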
    auto newvals = zip(wrt, newMeans, newVars)
                  .map!(x => x[0] - eta * (x[1] / (sqrt(x[2]) + eps)))
                  .array();

    //Apply projections
    foreach(i, ref v; newvals)
    {
        if(auto proj = wrt[i] in projs)
        {
            v = (*proj)(v);
        }
    }

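    //Compile a single plan that evaluates the outputs along with all of the
    //state updates: the new parameter values, moment estimates, and beta
    //accumulators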
    auto updatePlan = compile(outputs ~ newvals ~ newMeans ~ newVars ~ [nb1, nb2]);

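    //execute writes each result into the buffer at the corresponding index,
    //so passing the value buffers of wrt, means, vars, b1, and b2 updates the
    //optimiser state in place on every call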
    auto newbufs = chain(wrt, means, vars, [b1, b2])
                  .map!(x => x.value)
                  .array();

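    //Prepend freshly allocated buffers to receive the values of the outputs,
    //matching the order of the operations passed to compile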
    newbufs = outputs.map!(x => Buffer(new ubyte[x.volume * x.elementType.sizeOf])).array() ~ newbufs;

    Buffer[] update(Buffer[Operation] args)
    {
        updatePlan.execute(args, newbufs);

        return newbufs[0 .. outputs.length];
    }

    return &update;
}

///
unittest
{
    import std.random : uniform;

    //Generate some points
    auto xdata = new float[100];
    auto ydata = new float[100];

    foreach(i; 0 .. 100)
    {
        xdata[i] = uniform(-10.0f, 10.0f);
        ydata[i] = 3.0f * xdata[i] + 2.0f;
    }

    //Create the model
    auto x = float32([]);
    auto m = float32([]);
    auto c = float32([]);

    auto yhat = m * x + c;
    auto y = float32([]);

    //Create an ADAM updater
    auto updater = adam([(yhat - y) * (yhat - y)], [m, c], null, float32([], [0.1f]));

    //Iterate for a while
    float loss;

    foreach(i; 0 .. 300)
    {
        size_t j = i % 100;

        loss = updater([
            x: Buffer(xdata[j .. j + 1]),
            y: Buffer(ydata[j .. j + 1])
        ])[0].as!float[0];
    }

    //Print the loss after 300 iterations. Let the user decide whether it's good enough to be considered a pass.
    import std.stdio : writeln;
    writeln(
        "Adam loss: ", loss, "    ",
        "m=", m.value.as!float[0], ", ",
        "c=", c.value.as!float[0], "    ",
        "(expected m=3, c=2)");
}
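
///
unittest
{
    //A minimal sketch of how the projs parameter can be used. The projection
    //here is simply the identity, purely to show the mechanics; a real
    //projection would map the proposed parameter value back onto some
    //feasible set.
    auto w = float32([], [1.0f]);
    auto loss = w * w;

    //The projection receives the Operation computing the proposed new value
    //of w, and returns the Operation whose value should actually be stored
    Projection[Operation] projs;
    projs[w] = delegate Operation(Operation newW) { return newW; };

    //Take a single minimisation step; the projection is applied to the
    //proposed value of w before it is written back
    auto updater = adam([loss], [w], projs);
    updater(null);
}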