/**
    Contains an implementation of batch normalisation.

    Authors: Henry Gouk
*/
module dopt.nnet.layers.batchnorm;

import dopt.core;
import dopt.nnet;
import dopt.nnet.util;
import dopt.online;

/**
    Encapsulates additional options for batchnorm layers.
*/
class BatchNormOptions
{
    this()
    {
        _gammaInit = constantInit(1.0f);
        _betaInit = constantInit(0.0f);
        _gammaDecay = 0;
        _momentum = 0.9f;
        _maxgain = float.infinity;
        _lipschitz = float.infinity;
    }

    mixin(dynamicProperties(
        "ParamInitializer", "gammaInit",
        "ParamInitializer", "betaInit",
        "Projection", "gammaProj",
        "Projection", "betaProj",
        "float", "maxgain",
        "float", "gammaDecay",
        "float", "momentum",
        "float", "lipschitz"
    ));
}

///
unittest
{
    //Create a BatchNormOptions object with the default parameters
    auto opts = new BatchNormOptions()
               .gammaInit(constantInit(1.0f))
               .betaInit(constantInit(0.0f))
               .gammaProj(null)
               .betaProj(null)
               .gammaDecay(0.0f)
               .momentum(0.9f);

    //Options can also be read back again later
    assert(opts.gammaDecay == 0.0f);
    assert(opts.momentum == 0.9f);
}

/**
    Creates a batch normalisation layer.

    The layer normalises the activations of the given input layer using minibatch statistics at training
    time, and running estimates of the mean and variance at inference time.

    Params:
        input = The layer whose activations should be normalised.
        opts = Additional options for the batch normalisation layer.

    Returns:
        The new batch normalisation layer.
*/
Layer batchNorm(Layer input, BatchNormOptions opts = new BatchNormOptions())
{
    /*
        Apologies to anyone trying to understand how I've implemented BN: this is a bit hacky!
        What we're doing is packing the running mean/variance estimates provided during the training
        forward propagation into the same tensor as the normalised layer activations. The batchNormTrain
        function then separates these out into three different operation nodes. We can then use the
        projected gradient descent operator to constrain the mean/var model parameters to be equal to
        these running statistics.
    */
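    /*
        For reference, this is the standard batch normalisation computation being assembled below
        (a general description of the technique, not anything specific to dopt's internals):

            y = gamma * (x - mean) / sqrt(var + eps) + beta

        At training time the minibatch statistics are used for mean/var, while running estimates of
        them are tracked with an exponential moving average controlled by the `momentum` option; at
        inference time those running estimates are used instead.
    */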

    import std.array : array;
    import std.range : repeat;

    auto x = input.output;
    auto xTr = input.trainOutput;

    auto gamma = float32([1, x.shape[1], 1, 1]);
    auto beta = float32([x.shape[1]]);

    opts._gammaInit(gamma);
    opts._betaInit(beta);

    auto mean = float32([x.shape[1]]);
    auto var = float32([x.shape[1]], repeat(1.0f, x.shape[1]).array());

    auto bnop = xTr.batchNormTrain(gamma, beta, mean, var, opts._momentum);
    auto yTr = bnop[0];
    auto meanUpdateSym = bnop[1];
    auto varUpdateSym = bnop[2];

    auto y = x.batchNormInference(gamma, beta, mean, var);

    //Used below to estimate the gain of the layer: `after` is the input scaled by gamma and the running
    //standard deviation, with no mean subtraction or shift applied
    auto before = xTr;
    auto zeros = float32Constant([before.shape[1]], repeat(0.0f, before.shape[1]).array());
    auto after = before.batchNormInference(gamma, zeros, zeros, var);

    before = before.reshape([before.shape[0], before.volume / before.shape[0]]);
    after = after.reshape([after.shape[0], after.volume / after.shape[0]]);

    //Rescales gamma so that the ratio between the norms of the layer's outputs and inputs does not
    //exceed opts.maxgain
    Operation maxGainProj(Operation newGamma)
    {
        auto beforeNorms = sum(before * before, [1]) + 1e-8;
        auto afterNorms = sum(after * after, [1]) + 1e-8;
        auto mg = maxElement(sqrt(afterNorms / beforeNorms));

        if(opts._gammaProj is null)
        {
            return newGamma * (1.0f / max(float32Constant([], [1.0f]), mg / opts.maxgain));
        }
        else
        {
            return opts._gammaProj(newGamma * (1.0f / max(float32Constant([], [1.0f]), mg / opts.maxgain)));
        }
    }

    //Rescales gamma so that the largest per-feature scaling factor, gamma / sqrt(var), does not exceed
    //opts.lipschitz
    Operation lipschitzProj(Operation newGamma)
    {
        auto norm = (newGamma / sqrt(varUpdateSym + 1e-6)).abs().maxElement();

        return newGamma * (1.0f / max(float32Constant(1.0f), norm / opts.lipschitz));
    }

    Projection gammaProj = opts._gammaProj;

    if(opts.maxgain != float.infinity)
    {
        gammaProj = &maxGainProj;
    }

    //The mean/var "projections" simply overwrite the stored parameters with the running statistics
    //produced by batchNormTrain, as described in the comment at the top of this function
    Operation meanUpdater(Operation ignored)
    {
        return meanUpdateSym;
    }

    Operation varUpdater(Operation ignored)
    {
        return varUpdateSym;
    }

    return new Layer([input], y, yTr, [
        Parameter(gamma, opts._gammaDecay == 0.0f ? null : opts._gammaDecay * sum(gamma * gamma), gammaProj),
        Parameter(beta, null, opts._betaProj),
        Parameter(mean, null, &meanUpdater),
        Parameter(var, null, &varUpdater)
    ]);
}

unittest
{
    auto x = float32([3, 2], [1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f]);

    auto layers = dataSource(x).batchNorm();
    auto network = new DAGNetwork([x], [layers]);

    auto trloss = layers.trainOutput.sum();

    auto updater = adam([trloss], network.params, network.paramProj);

    for(size_t i = 0; i < 1000; i++)
    {
        updater(null);
    }

    import std.math : approxEqual;

    //After training, the running mean estimate (the third parameter) should have converged to the
    //per-feature means of x, which are [3, 4]
    assert(approxEqual(layers.params[2].symbol.value.get!float, [3.0f, 4.0f]));
}
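
//Additional usage sketch: constructs a batchNorm layer with non-default options and checks that its
//parameters are registered. It assumes only the dataSource helper and Layer.params property already
//exercised by the test above.
unittest
{
    auto x = float32([3, 2], [1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f]);

    //Constrain the maximum gain of the layer and use a slower-moving running average
    auto opts = new BatchNormOptions()
               .maxgain(2.0f)
               .momentum(0.99f);

    auto layer = dataSource(x).batchNorm(opts);

    //The layer should register gamma, beta, and the running mean/variance as its parameters
    assert(layer.params.length == 4);
    assert(opts.maxgain == 2.0f);
}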