/**
    Contains an implementation of batch normalisation.

    Authors: Henry Gouk
*/
module dopt.nnet.layers.batchnorm;

import dopt.core;
import dopt.nnet;
import dopt.nnet.layers.util;
import dopt.online;

/**
    Encapsulates additional options for batchnorm layers.
*/
class BatchNormOptions
{
    this()
    {
        _gammaInit = constantInit(1.0f);
        _betaInit = constantInit(0.0f);
        _gammaDecay = 0;
        _momentum = 0.9f;
    }

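    /*
        The dynamicProperties mixin below generates a private "_name" field, plus chainable
        getter/setter properties, for each type/name pair listed. This is what enables the
        fluent configuration style shown in the unittest that follows this class.
    */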
    mixin(dynamicProperties(
        "ParamInitializer", "gammaInit",
        "ParamInitializer", "betaInit",
        "Projection", "gammaProj",
        "Projection", "betaProj",
        "float", "gammaDecay",
        "float", "momentum"
    ));
}

///
unittest
{
    //Create a BatchNormOptions object and explicitly set each option to its default value
    auto opts = new BatchNormOptions()
               .gammaInit(constantInit(1.0f))
               .betaInit(constantInit(0.0f))
               .gammaProj(null)
               .betaProj(null)
               .gammaDecay(0.0f)
               .momentum(0.9f);

    //Options can also be read back again later
    assert(opts.gammaDecay == 0.0f);
    assert(opts.momentum == 0.9f);
}

/**
    Creates a batch normalisation layer.

    Params:
        input = The layer that batch normalisation should be applied to.
        opts = Additional options for the batch normalisation layer.

    Returns:
        The new layer.
*/
Layer batchNorm(Layer input, BatchNormOptions opts = new BatchNormOptions())
{
    /*Apologies to anyone trying to understand how I've implemented BN---this is a bit hacky!
      What we're doing is packing the running mean/variance estimates produced during the training
      forward propagation into the same tensor as the normalised layer activations. The batchNormTrain
      function then separates these out into three different operation nodes. We can then use the
      projected gradient descent operator to constrain the mean/var model parameters to be equal to
      these running statistics.
    */
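    /*
        For reference: batch normalisation computes y = gamma * (x - mean) / sqrt(var + eps) + beta.
        batchNormTrain normalises with statistics computed from the current minibatch and also emits
        updated running estimates (blended according to the momentum option), while batchNormInference
        normalises with the stored running mean/var directly.
    */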

    import std.array : array;
    import std.range : repeat;

    auto x = input.output;
    auto xTr = input.trainOutput;

    auto gamma = float32([1, x.shape[1], 1, 1]);
    auto beta = float32([x.shape[1]]);

    opts._gammaInit(gamma);
    opts._betaInit(beta);

    auto mean = float32([x.shape[1]]);
    auto var = float32([x.shape[1]], repeat(1.0f, x.shape[1]).array());

    auto bnop = xTr.batchNormTrain(gamma, beta, mean, var, opts._momentum);
    auto yTr = bnop[0];
    auto meanUpdateSym = bnop[1];
    auto varUpdateSym = bnop[2];

    auto y = x.batchNormInference(gamma, beta, mean, var);

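    //These two functions are used as the 'projections' for the mean and var parameters below. They
    //ignore their argument and simply return the symbolic running statistic updates produced by
    //batchNormTrain, so each projection step resets mean/var to the latest running estimates.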
    Operation meanUpdater(Operation ignored)
    {
        return meanUpdateSym;
    }

    Operation varUpdater(Operation ignored)
    {
        return varUpdateSym;
    }

    return new Layer([input], y, yTr, [
        Parameter(gamma, opts._gammaDecay == 0.0f ? null : opts._gammaDecay * sum(gamma * gamma), opts._gammaProj),
        Parameter(beta, null, opts._betaProj),
        Parameter(mean, null, &meanUpdater),
        Parameter(var, null, &varUpdater)
    ]);
}

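//Train a small model containing a batchnorm layer and check that the running statistics are
//maintained correctly by the projection functions.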
unittest
{
    auto x = float32([3, 2], [1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f]);

    auto layers = dataSource(x).batchNorm();
    auto network = new DAGNetwork([x], [layers]);

    auto trloss = layers.trainOutput.sum();

    auto updater = adam([trloss], network.params, network.paramProj);

    for(size_t i = 0; i < 1000; i++)
    {
        updater(null);
    }

    import std.math : approxEqual;

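    //params[2] is the running mean parameter; it should have converged to the per-column means
    //of x, which are 3 and 4.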
    assert(approxEqual(layers.params[2].symbol.value.as!float, [3.0f, 4.0f]));
}