sgd

Creates a delegate that can be used to perform a step using the stochastic gradient descent update rule.

This function relies on automatic differentiation, so the objective (the first element of outputs, which must have a volume of 1, i.e. a scalar) must be differentiable w.r.t. all elements of wrt. The returned delegate performs minimisation.
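For reference, the classical momentum formulation of this update is sketched below, where $w$ denotes the parameters in wrt, $L$ the objective, $\eta$ the learningRate, and $\mu$ the momentumRate. This is the textbook rule; the implementation may differ in minor details such as sign conventions.

$$v_{t+1} = \mu v_t - \eta \nabla_w L(w_t)$$
$$w_{t+1} = w_t + v_{t+1}$$

With $\mu = 0$ (the default momentumRate) this reduces to plain gradient descent, $w_{t+1} = w_t - \eta \nabla_w L(w_t)$.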

sgd(
    Operation[] outputs,
    Operation[] wrt,
    Projection[Operation] projs,
    Operation learningRate = float32([], [0.01f]),
    Operation momentumRate = float32([], [0.0f])
)

Parameters

wrt
Type: Operation[]

An array of Operations that the derivative of the objective should be taken with respect to.

learningRate
Type: Operation

The value used to scale the size of the gradient in the update rule.

momentumRate
Type: Operation

The scaling factor applied to the previous update.

Return Value

Type: Updater

A delegate that is used to actually perform the update steps. The optimised values are stored in the "default" attributes of the elements of wrt.
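Conceptually, the returned Updater is a delegate that maps a set of input bindings to the evaluated outputs, as the example below demonstrates. A rough sketch of its shape, which is illustrative and may not match the library's exact alias definition:

//Illustrative sketch only: the delegate takes a map from input Operations
//to Buffers holding their values, and returns one Buffer per element of outputs.
alias Updater = Buffer[] delegate(Buffer[Operation] args);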

Examples

import dopt.core;
import dopt.online;
import std.random : uniform;

//Generate some points on the line y = 3x + 2
auto xdata = new float[100];
auto ydata = new float[100];

foreach(i; 0 .. 100)
{
    xdata[i] = uniform(-10.0f, 10.0f);
    ydata[i] = 3.0f * xdata[i] + 2.0f;
}

//Create the model
auto x = float32([]);
auto m = float32([]);
auto c = float32([]);

auto yhat = m * x + c;
auto y = float32([]);

//Create an SGD updater with a learning rate of 0.001 and a momentum rate of 0.9
auto updater = sgd([(yhat - y) * (yhat - y)], [m, c], null, float32([], [0.001f]), float32([], [0.9f]));

//Iterate over the data for 300 steps, one sample at a time
float loss;

for(size_t i = 0; i < 300; i++)
{
    size_t j = i % 100;

    loss = updater([
        x: Buffer(xdata[j .. j + 1]),
        y: Buffer(ydata[j .. j + 1])
    ])[0].as!float[0];
}

//Print the loss from the final iteration. Let the user decide whether it's good enough to be considered a pass.
import std.stdio : writeln;
writeln(
    "SGD loss: ", loss, "    ",
    "m=", m.value.as!float[0], ", ",
    "c=", c.value.as!float[0], "    ",
    "(expected m=3, c=2)");

Meta