module dopt.nnet.models.wrn;

import std.math : isNaN;

import dopt.core;
import dopt.nnet;
import dopt.nnet.util;
import dopt.nnet.models.maybe;
import dopt.online;

/**
    Hyperparameters for the Wide Residual Network (WRN) models of Zagoruyko and Komodakis
    (https://arxiv.org/abs/1605.07146).
*/
class WRNOptions
{
    this()
    {
        _dropout = false;
        _maxgainNorm = float.nan;
        _lipschitzNorm = float.nan;
        _maxNorm = float.infinity;
        _spectralDecay = 0.0f;
        _weightDecay = 0.0001f;
        _stride = [1, 2, 2];
    }

    /**
        Checks that the selected regularisation options are mutually compatible.
    */
    void verify()
    {
        import std.exception : enforce;

        int regCtr;

        if(!isNaN(_maxgainNorm))
        {
            regCtr++;

            enforce(_maxgainNorm == 2.0f, "Only a maxgainNorm of 2 is currently supported.");
        }

        if(!isNaN(_lipschitzNorm))
        {
            regCtr++;
        }

        enforce(regCtr <= 1, "WRN models currently only support using one of maxgain and the Lipschitz constraint.");
    }

    mixin(dynamicProperties(
        "bool", "dropout",
        "float", "maxgainNorm",
        "float", "lipschitzNorm",
        "float", "maxNorm",
        "float", "spectralDecay",
        "float", "weightDecay",
        "size_t[3]", "stride"
    ));
}

/**
    Constructs the convolutional body of a WRN-depth-width model.

    Params:
        features = The network input, with the last two dimensions treated as spatial.
        depth = The total depth of the network. Must be of the form 6n + 4 for a positive integer n.
        width = The widening factor, k.
        opts = Additional options controlling regularisation and group strides.

    Returns:
        The final layer of the WRN body, after global mean pooling.
*/
Layer wideResNet(Operation features, size_t depth, size_t width, WRNOptions opts = new WRNOptions())
{
    //Each of the three groups contains n residual blocks
    size_t n = (depth - 4) / 6;

    opts.verify();

    float maxgain = float.infinity;

    if(opts.maxgainNorm == 2.0f)
    {
        maxgain = opts.maxNorm;
    }

    float lambda = float.infinity;
    float lipschitzNorm = float.nan;

    if(!isNaN(opts.lipschitzNorm))
    {
        lipschitzNorm = opts.lipschitzNorm;
        lambda = opts.maxNorm;
    }

    Projection filterProj;
    Operation lambdaSym = float32Constant(lambda);

    if(lambda != float.infinity)
    {
        filterProj = projConvParams(lambdaSym, features.shape[2 .. $], [1, 1], [1, 1], lipschitzNorm);
    }

    auto pred = dataSource(features)
               .conv2D(16, [3, 3], new Conv2DOptions()
                    .padding([1, 1])
                    .useBias(false)
                    .weightDecay(opts.weightDecay)
                    .maxgain(maxgain)
                    .filterProj(filterProj))
               .wrnBlock(16 * width, n, opts.stride[0], opts)
               .wrnBlock(32 * width, n, opts.stride[1], opts)
               .wrnBlock(64 * width, n, opts.stride[2], opts)
               .batchNorm(new BatchNormOptions().maxgain(maxgain))
               .relu()
               .meanPool();

    return pred;
}

/**
    Appends a group of n pre-activation residual blocks with u output channels to inLayer.
    The first block in the group uses a stride of s; subsequent blocks use a stride of 1.
*/
private Layer wrnBlock(Layer inLayer, size_t u, size_t n, size_t s, WRNOptions opts)
{
    float maxgain = float.infinity;

    if(opts.maxgainNorm == 2.0f)
    {
        maxgain = opts.maxNorm;
    }

    float lambda = float.infinity;
    float lipschitzNorm = float.nan;

    if(!isNaN(opts.lipschitzNorm))
    {
        lipschitzNorm = opts.lipschitzNorm;
        lambda = opts.maxNorm;
    }

    Operation lambdaSym = float32Constant(lambda);

    auto convOpts()
    {
        return new Conv2DOptions()
              .padding([1, 1])
              .useBias(false)
              .weightDecay(opts.weightDecay)
              .maxgain(maxgain);
    }

    auto bnOpts()
    {
        return new BatchNormOptions()
              .maxgain(maxgain)
              .lipschitz(lambda);
    }

    Layer res;

    for(size_t i = 0; i < n; i++)
    {
        //Pre-activation ordering: batch norm and ReLU come before the first convolution
        res = inLayer
             .batchNorm(bnOpts())
             .relu();

        Projection filterProj = null;

        if(lambda != float.infinity)
        {
            filterProj = projConvParams(lambdaSym, res.trainOutput.shape[2 .. $], [s, s], [1, 1], lipschitzNorm);
        }

        res = res
             .conv2D(u, [3, 3], convOpts().stride([s, s]).filterProj(filterProj))
             .batchNorm(bnOpts())
             .relu()
             .maybeDropout(opts.dropout ? 0.3f : 0.0f);

        if(lambda != float.infinity)
        {
            filterProj = projConvParams(lambdaSym, res.trainOutput.shape[2 .. $], [1, 1], [1, 1], lipschitzNorm);
        }

        res = res
             .conv2D(u, [3, 3], convOpts().filterProj(filterProj));

        Layer shortcut = inLayer;

        //If the residual branch changes the number of channels, project the shortcut with a strided 1x1 convolution
        if(inLayer.output.shape[1] != res.output.shape[1])
        {
            if(lambda != float.infinity)
            {
                filterProj = projConvParams(lambdaSym, inLayer.trainOutput.shape[2 .. $], [s, s], [1, 1],
                    lipschitzNorm);
            }

            shortcut = inLayer.conv2D(u, [1, 1], new Conv2DOptions()
                              .stride([s, s])
                              .useBias(false)
                              .weightDecay(opts.weightDecay)
                              .maxgain(maxgain)
                              .filterProj(filterProj));
        }

        //Merge the residual branch and the shortcut by elementwise addition
        res = new Layer(
            [res, shortcut],
            res.output + shortcut.output,
            res.trainOutput + shortcut.trainOutput,
            []
        );

        inLayer = res;

        //Only the first block in the group is strided
        s = 1;
    }

    return res;
}

/**
    Performs global mean pooling over the spatial dimensions, producing one value per feature map.
*/
private Layer meanPool(Layer input)
{
    Operation meanPoolImpl(Operation inp)
    {
        auto mapVol = inp.shape[2] * inp.shape[3];
        float scale = 1.0f / mapVol;

        //Flatten each feature map, sum over it, then rescale the sum by the map's area
        return inp.reshape([inp.shape[0] * inp.shape[1], inp.shape[2] * inp.shape[3]])
                  .sum([1])
                  .reshape([inp.shape[0], inp.shape[1]]) * scale;
    }

    auto y = meanPoolImpl(input.output);
    auto yTr = meanPoolImpl(input.trainOutput);

    return new Layer([input], y, yTr, []);
}
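/*
    Usage sketch (not part of the original module): building a WRN-16-8 classifier for
    CIFAR-10 sized inputs. This assumes `float32`, `dense`, and `softmax` behave as in the
    other dopt examples; treat it as illustrative rather than canonical.
*/
version(none)
{
    Layer wrn16_8Classifier(size_t batchSize)
    {
        //A batch of 32x32 RGB images; depth = 16 gives n = (16 - 4) / 6 = 2 blocks per group
        auto features = float32([batchSize, 3, 32, 32]);

        return wideResNet(features, 16, 8)
              .dense(10)
              .softmax();
    }
}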