/**
    Contains an implementation of batch normalisation.

    Authors: Henry Gouk
*/
module dopt.nnet.layers.batchnorm;

import dopt.core;
import dopt.nnet;
import dopt.nnet.util;
import dopt.online;

/**
    Encapsulates additional options for batch normalisation layers.

    This includes initialisers and projections for the scale (gamma) and shift (beta)
    parameters, an L2 decay coefficient for gamma, the momentum used when updating the
    running mean/variance estimates, and optional constraints on the maximum gain and
    Lipschitz constant of the layer.
*/
class BatchNormOptions
{
    this()
    {
        _gammaInit = constantInit(1.0f);
        _betaInit = constantInit(0.0f);
        _gammaDecay = 0;
        _momentum = 0.9f;
        _maxgain = float.infinity;
        _lipschitz = float.infinity;
    }

    mixin(dynamicProperties(
        "ParamInitializer", "gammaInit",
        "ParamInitializer", "betaInit",
        "Projection", "gammaProj",
        "Projection", "betaProj",
        "float", "maxgain",
        "float", "gammaDecay",
        "float", "momentum",
        "float", "lipschitz"
    ));
}

///
unittest
{
    //Create a BatchNormOptions object, explicitly setting each option to its default value
    auto opts = new BatchNormOptions()
               .gammaInit(constantInit(1.0f))
               .betaInit(constantInit(0.0f))
               .gammaProj(null)
               .betaProj(null)
               .gammaDecay(0.0f)
               .momentum(0.9f);

    //Options can also be read back again later
    assert(opts.gammaDecay == 0.0f);
    assert(opts.momentum == 0.9f);
}

/**
    Creates a batch normalisation layer.

    The activations of the input layer are normalised using per-channel batch statistics,
    then scaled and shifted by the learnable gamma and beta parameters. Running estimates
    of the mean and variance are maintained during training and used at inference time.

    Params:
        input = The layer producing the activations to be normalised.
        opts = Additional options, with sensible defaults.

    Returns:
        The new batch normalisation layer.
*/
Layer batchNorm(Layer input, BatchNormOptions opts = new BatchNormOptions())
{
    /*Apologies to anyone trying to understand how I've implemented BN: this is a bit hacky!
      What we're doing is packing the running mean/variance estimates provided during the training
      forward propagation into the same tensor as the normalised layer activations. The batchNormTrain
      function then separates these out into three different operation nodes. We can then use the projected
      gradient descent operator to constrain the mean/var model parameters to be equal to these running
      statistics.
    */

    import std.array : array;
    import std.range : repeat;

    auto x = input.output;
    auto xTr = input.trainOutput;

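    //Learnable per-channel scale (gamma) and shift (beta) parameters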
    auto gamma = float32([1, x.shape[1], 1, 1]);
    auto beta = float32([x.shape[1]]);

    opts._gammaInit(gamma);
    opts._betaInit(beta);

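    //Running estimates of the per-channel mean and variance, used at inference time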
    auto mean = float32([x.shape[1]]);
    auto var = float32([x.shape[1]], repeat(1.0f, x.shape[1]).array());

    auto bnop = xTr.batchNormTrain(gamma, beta, mean, var, opts._momentum);
    auto yTr = bnop[0];
    auto meanUpdateSym = bnop[1];
    auto varUpdateSym = bnop[2];

    auto y = x.batchNormInference(gamma, beta, mean, var);

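    //"before" holds the raw training activations, while "after" applies only the scaling part
    //of batch norm (gamma / sqrt(var), with no centring or shift). Both are flattened to one
    //row per example so their norms can be compared when estimating the gain of the layer.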
    auto before = xTr;
    auto zeros = float32Constant([before.shape[1]], repeat(0.0f, before.shape[1]).array());
    auto after = before.batchNormInference(gamma, zeros, zeros, var);

    before = before.reshape([before.shape[0], before.volume / before.shape[0]]);
    after = after.reshape([after.shape[0], after.volume / after.shape[0]]);

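    //Projection that rescales gamma so the maximum per-example gain of the layer (the ratio
    //of output norm to input norm) does not exceed opts.maxgain. Any user-supplied gamma
    //projection is then applied to the rescaled value.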
    Operation maxGainProj(Operation newGamma)
    {
        auto beforeNorms = sum(before * before, [1]) + 1e-8;
        auto afterNorms = sum(after * after, [1]) + 1e-8;
        auto mg = maxElement(sqrt(afterNorms / beforeNorms));

        if(opts._gammaProj is null)
        {
            return newGamma * (1.0f / max(float32Constant([], [1.0f]), mg / opts.maxgain));
        }
        else
        {
            return opts._gammaProj(newGamma * (1.0f / max(float32Constant([], [1.0f]), mg / opts.maxgain)));
        }
    }

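    //Projection that rescales gamma so the Lipschitz constant of the layer (the largest
    //per-channel value of |gamma| / sqrt(var)) does not exceed opts.lipschitz.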
    Operation lipschitzProj(Operation newGamma)
    {
        auto norm = (newGamma / sqrt(varUpdateSym + 1e-6)).abs().maxElement();

        return newGamma * (1.0f / max(float32Constant(1.0f), norm / opts.lipschitz));
    }

    Projection gammaProj = opts._gammaProj;

    if(opts.maxgain != float.infinity)
    {
        gammaProj = &maxGainProj;
    }

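    //The lipschitz option is assumed to select the Lipschitz projection when it is set to a
    //finite value.
    if(opts.lipschitz != float.infinity)
    {
        gammaProj = &lipschitzProj;
    }

    //These "projections" overwrite the mean/var parameters with the running statistics
    //computed during the training forward pass, as described in the comment at the top of
    //this function.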
    Operation meanUpdater(Operation ignored)
    {
        return meanUpdateSym;
    }

    Operation varUpdater(Operation ignored)
    {
        return varUpdateSym;
    }

    return new Layer([input], y, yTr, [
        Parameter(gamma, opts._gammaDecay == 0.0f ? null : opts._gammaDecay * sum(gamma * gamma), gammaProj),
        Parameter(beta, null, opts._betaProj),
        Parameter(mean, null, &meanUpdater),
        Parameter(var, null, &varUpdater)
    ]);
}
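
//A minimal sketch (not part of the original tests) showing how a batchNorm layer might be
//constructed with non-default options; it reuses the dataSource/DAGNetwork API exercised by
//the unittest below.
unittest
{
    auto x = float32([4, 3]);

    //Constrain the maximum gain of the layer and slow down the running statistics
    auto opts = new BatchNormOptions()
               .maxgain(2.0f)
               .momentum(0.99f);

    auto layer = dataSource(x).batchNorm(opts);
    auto network = new DAGNetwork([x], [layer]);

    //gamma, beta, and the running mean/variance are registered as the layer's parameters
    assert(layer.params.length == 4);
}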

unittest
{
    auto x = float32([3, 2], [1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f]);

    auto layers = dataSource(x).batchNorm();
    auto network = new DAGNetwork([x], [layers]);

    auto trloss = layers.trainOutput.sum();

    auto updater = adam([trloss], network.params, network.paramProj);

    for(size_t i = 0; i < 1000; i++)
    {
        updater(null);
    }

    import std.math : approxEqual;

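    //After training, the running mean (params[2]) should have converged to the per-feature means of x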
    assert(approxEqual(layers.params[2].symbol.value.get!float, [3.0f, 4.0f]));
}