/**
    Contains an implementation of batch normalisation.

    Authors: Henry Gouk
*/
module dopt.nnet.layers.batchnorm;

import dopt.core;
import dopt.nnet;
import dopt.nnet.layers.util;
import dopt.online;

/**
    Encapsulates additional options for batchnorm layers.
*/
class BatchNormOptions
{
    this()
    {
        _gammaInit = constantInit(1.0f);
        _betaInit = constantInit(0.0f);
        _gammaDecay = 0;
        _momentum = 0.9f;
    }

    mixin(dynamicProperties(
        "ParamInitializer", "gammaInit",
        "ParamInitializer", "betaInit",
        "Projection", "gammaProj",
        "Projection", "betaProj",
        "float", "gammaDecay",
        "float", "momentum"
    ));
}

///
unittest
{
    //Create a BatchNormOptions object with the default parameters
    auto opts = new BatchNormOptions()
               .gammaInit(constantInit(1.0f))
               .betaInit(constantInit(0.0f))
               .gammaProj(null)
               .betaProj(null)
               .gammaDecay(0.0f)
               .momentum(0.9f);

    //Options can also be read back again later
    assert(opts.gammaDecay == 0.0f);
    assert(opts.momentum == 0.9f);
}

/**
    Creates a batch normalisation layer.

    Params:
        input = The layer producing the activations to be normalised.
        opts = Additional options for the batchnorm layer.

    Returns:
        The new batchnorm layer.
*/
Layer batchNorm(Layer input, BatchNormOptions opts = new BatchNormOptions())
{
    /*
        Apologies to anyone trying to understand how I've implemented BN: this is a bit hacky!
        The running mean/variance estimates produced during the training forward pass are packed
        into the same output as the normalised layer activations. The batchNormTrain function then
        separates these out into three different operation nodes. We can then use the projected
        gradient descent operator to constrain the mean/var model parameters to be equal to these
        running statistics.
    */

    import std.array : array;
    import std.range : repeat;

    auto x = input.output;
    auto xTr = input.trainOutput;

    //Scale (gamma) and shift (beta) parameters, one per feature/channel
    auto gamma = float32([1, x.shape[1], 1, 1]);
    auto beta = float32([x.shape[1]]);

    opts._gammaInit(gamma);
    opts._betaInit(beta);

    //Running mean/variance estimates used at inference time (the variance is initialised to one)
    auto mean = float32([x.shape[1]]);
    auto var = float32([x.shape[1]], repeat(1.0f, x.shape[1]).array());

    //Training pass: normalise with batch statistics and also emit the updated running statistics
    auto bnop = xTr.batchNormTrain(gamma, beta, mean, var, opts._momentum);
    auto yTr = bnop[0];
    auto meanUpdateSym = bnop[1];
    auto varUpdateSym = bnop[2];

    //Inference pass: normalise with the running statistics
    auto y = x.batchNormInference(gamma, beta, mean, var);

    //These "projections" overwrite the mean/var parameters with the running statistics after each update
    Operation meanUpdater(Operation ignored)
    {
        return meanUpdateSym;
    }

    Operation varUpdater(Operation ignored)
    {
        return varUpdateSym;
    }

    return new Layer([input], y, yTr, [
        Parameter(gamma, opts._gammaDecay == 0.0f ? null : opts._gammaDecay * sum(gamma * gamma), opts._gammaProj),
        Parameter(beta, null, opts._betaProj),
        Parameter(mean, null, &meanUpdater),
        Parameter(var, null, &varUpdater)
    ]);
}

unittest
{
    auto x = float32([3, 2], [1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f]);

    auto layers = dataSource(x).batchNorm();
    auto network = new DAGNetwork([x], [layers]);

    auto trloss = layers.trainOutput.sum();

    auto updater = adam([trloss], network.params, network.paramProj);

    for(size_t i = 0; i < 1000; i++)
    {
        updater(null);
    }

    import std.math : approxEqual;

    //The running mean parameter should have converged to the per-feature means of x
    assert(approxEqual(layers.params[2].symbol.value.as!float, [3.0f, 4.0f]));
}
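
/*
    Illustrative sketch (not part of the original test suite): the mean/var Parameters above are
    kept in sync with the running statistics purely through their projection functions, which
    amounts to an exponential moving average of the batch statistics. Assuming the common
    convention running = momentum * running + (1 - momentum) * batch, the loop below shows that
    the running estimate converges to the batch statistic, which is what the unittest above relies
    on when it checks the mean parameter against [3.0f, 4.0f].
*/
unittest
{
    import std.math : approxEqual;

    float momentum = 0.9f;    //same default as BatchNormOptions
    float batchMean = 3.0f;   //mean of the first feature in the unittest above
    float runningMean = 0.0f; //the running estimate in this sketch starts at zero

    //Repeatedly apply the assumed exponential moving average update
    for(size_t i = 0; i < 1000; i++)
    {
        runningMean = momentum * runningMean + (1.0f - momentum) * batchMean;
    }

    assert(approxEqual(runningMean, batchMean));
}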