amsgrad

Creates a delegate that can be used to perform a step using the AMSGrad update rule.

This function relies on automatic differentiation, so the objective (the first element of outputs, which must have a volume of 1, i.e. contain a single element) must be differentiable with respect to every element of wrt. The returned delegate performs minimisation.
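
For reference, here is a minimal scalar sketch of a single AMSGrad step, using plain D floats rather than dopt Operations. It is only illustrative: the actual updater works element-wise on tensors, and implementation details may differ slightly.

import std.algorithm : max;
import std.math : sqrt;

//One AMSGrad step for a single scalar parameter.
//g is the gradient of the objective with respect to param at this step;
//m, v and vhat are state carried between steps (initialised to zero).
void amsgradStep(ref float param, float g,
                 ref float m, ref float v, ref float vhat,
                 float alpha = 0.001f, float beta1 = 0.9f,
                 float beta2 = 0.999f, float eps = 1e-8f)
{
    m = beta1 * m + (1.0f - beta1) * g;       //decaying first moment estimate
    v = beta2 * v + (1.0f - beta2) * g * g;   //decaying second moment estimate
    vhat = max(vhat, v);                      //largest second moment seen so far (the AMSGrad modification)
    param -= alpha * m / (sqrt(vhat) + eps);  //step scaled by the stabilised second moment
}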

amsgrad(
    Operation[] outputs,
    Operation[] wrt,
    Projection[Operation] projs,
    Operation alpha = float32([], [0.001f]),
    Operation beta1 = float32([], [0.9f]),
    Operation beta2 = float32([], [0.999f]),
    Operation eps = float32([], [1e-8])
)

Parameters

outputs
Type: Operation[]

An array of outputs. The first element of this array is the objective function to be minimised.

wrt
Type: Operation[]

An array of Operations that the objective should be differentiated with respect to. These are the parameters that will be optimised.

projs
Type: Projection[Operation]

Projection functions that can be applied when updating the values of the elements of wrt (see the projection sketch at the end of the Examples section).

alpha
Type: Operation

The step size.

beta1
Type: Operation

Fading factor for the first moment of the gradient.

beta2
Type: Operation

Fading factor for the second moment of the gradient.

eps
Type: Operation

A small constant added to the denominator to prevent division by zero.

Return Value

Type: Updater

A delegate that is used to actually perform the update steps. The optimised values are stored in the value properties of the elements of wrt. The delegate returns the values computed for each element of the outputs array. This can be useful for keeping track of several different performance metrics in a prequential manner.
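
For instance, additional outputs can be monitored without affecting the optimisation, since only the first element of outputs is minimised. A minimal sketch, reusing the model variables (yhat, y, m, c, x, xdata, ydata) defined in the example below:

//The squared error is minimised; the raw error yhat - y is merely monitored
auto sqErr = (yhat - y) * (yhat - y);
auto rawErr = yhat - y;

auto metricUpdater = amsgrad([sqErr, rawErr], [m, c], null);

auto results = metricUpdater([x: Buffer(xdata[0 .. 1]), y: Buffer(ydata[0 .. 1])]);
auto sqErrVal = results[0].as!float[0];  //value of sqErr (drives the update)
auto rawErrVal = results[1].as!float[0]; //value of rawErr (tracked only)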

Examples

import std.random : uniform;

//dopt imports (module paths assumed)
import dopt.core;
import dopt.online;

//Generate some points
auto xdata = new float[100];
auto ydata = new float[100];

foreach(i; 0 .. 100)
{
    xdata[i] = uniform(-10.0f, 10.0f);
    ydata[i] = 3.0f * xdata[i] + 2.0f;
}

//Create the model
auto x = float32([]);
auto m = float32([]);
auto c = float32([]);

auto yhat = m * x + c;
auto y = float32([]);

//Create an AMSGrad updater with a step size of 0.1
auto updater = amsgrad([(yhat - y) * (yhat - y)], [m, c], null, float32([], [0.1f]));

//Iterate for a while
float loss;

for(size_t i = 0; i < 300; i++)
{
    size_t j = i % 100;

    loss = updater([
        x: Buffer(xdata[j .. j + 1]),
        y: Buffer(ydata[j .. j + 1])
    ])[0].as!float[0];
}

//Print the loss after 300 iterations, along with the fitted parameters
import std.stdio : writeln;
writeln(
    "AMSGrad loss: ", loss, "    ",
    "m=", m.value.as!float[0], ", ",
    "c=", c.value.as!float[0], "    ",
    "(expected m=3, c=2)");
