/**
    Contains an implementation of batch normalisation.

    Authors: Henry Gouk
*/
module dopt.nnet.layers.batchnorm;

import dopt.core;
import dopt.nnet;
import dopt.nnet.util;
import dopt.online;

/**
    Encapsulates additional options for batchnorm layers.
*/
class BatchNormOptions
{
    this()
    {
        _gammaInit = constantInit(1.0f);
        _betaInit = constantInit(0.0f);
        _gammaDecay = 0;
        _momentum = 0.9f;
        _maxgain = float.infinity;
        _lipschitz = float.infinity;
    }

    mixin(dynamicProperties(
        "ParamInitializer", "gammaInit",
        "ParamInitializer", "betaInit",
        "Projection", "gammaProj",
        "Projection", "betaProj",
        "float", "maxgain",
        "float", "gammaDecay",
        "float", "momentum",
        "float", "lipschitz"
    ));
}

///
unittest
{
    //Create a BatchNormOptions object with the default parameters
    auto opts = new BatchNormOptions()
               .gammaInit(constantInit(1.0f))
               .betaInit(constantInit(0.0f))
               .gammaProj(null)
               .betaProj(null)
               .gammaDecay(0.0f)
               .momentum(0.9f);

    //Options can also be read back again later
    assert(opts.gammaDecay == 0.0f);
    assert(opts.momentum == 0.9f);
}
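
///
unittest
{
    //A hedged sketch (not from the original test suite): configuring the maxgain and
    //lipschitz constraint options and reading them back. The constants are illustrative only.
    auto opts = new BatchNormOptions()
               .maxgain(2.0f)
               .lipschitz(10.0f);

    assert(opts.maxgain == 2.0f);
    assert(opts.lipschitz == 10.0f);
}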

///
Layer batchNorm(Layer input, BatchNormOptions opts = new BatchNormOptions())
{
    /*
        Apologies to anyone trying to understand how I've implemented BN---this is a bit hacky!
        What we're doing is packing the running mean/variance estimates provided during the training
        forward propagation into the same tensor as the normalised layer activations. The batchNormTrain
        function then separates these out into three different operation nodes. We can then use the
        projected gradient descent operator to constrain the mean/var model parameters to be equal to
        these running statistics.
    */

    import std.array : array;
    import std.range : repeat;

    auto x = input.output;
    auto xTr = input.trainOutput;

    //Scale (gamma) and shift (beta) parameters, one element per channel
    auto gamma = float32([1, x.shape[1], 1, 1]);
    auto beta = float32([x.shape[1]]);

    opts._gammaInit(gamma);
    opts._betaInit(beta);

    //Running estimates of the per-channel mean and variance, updated during training
    auto mean = float32([x.shape[1]]);
    auto var = float32([x.shape[1]], repeat(1.0f, x.shape[1]).array());

    //batchNormTrain packs the normalised activations and the updated running statistics into a single result
    auto bnop = xTr.batchNormTrain(gamma, beta, mean, var, opts._momentum);
    auto yTr = bnop[0];
    auto meanUpdateSym = bnop[1];
    auto varUpdateSym = bnop[2];

    auto y = x.batchNormInference(gamma, beta, mean, var);

    //Apply only the scaling part of batch norm (zero mean and shift), so the gain introduced by gamma can be measured
    auto before = xTr;
    auto zeros = float32Constant([before.shape[1]], repeat(0.0f, before.shape[1]).array());
    auto after = before.batchNormInference(gamma, zeros, zeros, var);

    before = before.reshape([before.shape[0], before.volume / before.shape[0]]);
    after = after.reshape([after.shape[0], after.volume / after.shape[0]]);

    //Rescales gamma so that the ratio of post- to pre-normalisation activation norms does not exceed opts.maxgain
    Operation maxGainProj(Operation newGamma)
    {
        auto beforeNorms = sum(before * before, [1]) + 1e-8;
        auto afterNorms = sum(after * after, [1]) + 1e-8;
        auto mg = maxElement(sqrt(afterNorms / beforeNorms));

        if(opts._gammaProj is null)
        {
            return newGamma * (1.0f / max(float32Constant([], [1.0f]), mg / opts.maxgain));
        }
        else
        {
            return opts._gammaProj(newGamma * (1.0f / max(float32Constant([], [1.0f]), mg / opts.maxgain)));
        }
    }

    //Rescales gamma so that the Lipschitz constant of the batch norm transform does not exceed opts.lipschitz
    Operation lipschitzProj(Operation newGamma)
    {
        auto norm = (newGamma / sqrt(varUpdateSym.reshape(newGamma.shape) + 1e-6)).abs().maxElement();

        auto g = newGamma * (1.0f / max(float32Constant(1.0f), norm / opts.lipschitz));

        if(opts._gammaProj is null)
        {
            return g;
        }
        else
        {
            return opts._gammaProj(g);
        }
    }

    Projection gammaProj = opts._gammaProj;

    if(opts.maxgain != float.infinity)
    {
        gammaProj = &maxGainProj;
    }
    else if(opts.lipschitz != float.infinity)
    {
        gammaProj = &lipschitzProj;
    }

    //These "projections" overwrite the mean/var parameters with the running statistics computed by batchNormTrain
    Operation meanUpdater(Operation ignored)
    {
        return meanUpdateSym;
    }

    Operation varUpdater(Operation ignored)
    {
        return varUpdateSym;
    }

    return new Layer([input], y, yTr, [
        Parameter(gamma, opts._gammaDecay == 0.0f ? null : opts._gammaDecay * sum(gamma * gamma), gammaProj),
        Parameter(beta, null, opts._betaProj),
        Parameter(mean, null, &meanUpdater),
        Parameter(var, null, &varUpdater)
    ]);
}

unittest
{
    auto x = float32([3, 2], [1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f]);

    auto layers = dataSource(x).batchNorm();
    auto network = new DAGNetwork([x], [layers]);

    auto trloss = layers.trainOutput.sum();

    auto updater = adam([trloss], network.params, network.paramProj);

    for(size_t i = 0; i < 1000; i++)
    {
        updater(null);
    }

    import std.math : approxEqual;

    assert(approxEqual(layers.params[2].symbol.value.get!float, [3.0f, 4.0f]));
}
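
unittest
{
    //A hedged usage sketch (not part of the original tests): building a batchnorm layer with a
    //finite maxgain constraint. The maxgain value of 2.0f is illustrative only, and only
    //construction and the registered parameters are exercised here.
    auto x = float32([3, 2], [1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f]);

    auto opts = new BatchNormOptions()
               .maxgain(2.0f);

    auto layers = dataSource(x).batchNorm(opts);
    auto network = new DAGNetwork([x], [layers]);

    //The layer registers gamma, beta, and the running mean/variance as its parameters
    assert(layers.params.length == 4);
}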