module dopt.nnet.models.wrn;

import std.math : isNaN;

import dopt.core;
import dopt.nnet;
import dopt.nnet.util;
import dopt.nnet.models.maybe;
import dopt.online;

class WRNOptions
{
    this()
    {
        _dropout = false;
        _maxgainNorm = float.nan;
        _lipschitzNorm = float.nan;
        _maxNorm = float.infinity;
        _spectralDecay = 0.0f;
        _weightDecay = 0.0001f;
        _stride = [1, 2, 2];
    }

    void verify()
    {
        import std.exception : enforce;

        int regCtr;

        if(!isNaN(_maxgainNorm))
        {
            regCtr++;

            enforce(_maxgainNorm == 2.0f, "Only a maxgainNorm of 2 is currently supported.");
        }

        if(!isNaN(_lipschitzNorm))
        {
            regCtr++;
        }

        enforce(regCtr <= 1, "WRN models currently only support using one of maxgain and the Lipschitz constraint.");
    }

    mixin(dynamicProperties(
        "bool", "dropout",
        "float", "maxgainNorm",
        "float", "lipschitzNorm",
        "float", "maxNorm",
        "float", "spectralDecay",
        "float", "weightDecay",
        "size_t[3]", "stride"
    ));
}

Layer wideResNet(Operation features, size_t depth, size_t width, WRNOptions opts = new WRNOptions())
{
    //A wide residual network of depth d contains three groups of n residual blocks each, where d = 6n + 4
    size_t n = (depth - 4) / 6;

    opts.verify();

    float maxgain = float.infinity;

    if(opts.maxgainNorm == 2.0f)
    {
        maxgain = opts.maxNorm;
    }

    float lambda = float.infinity;
    float lipschitzNorm = float.nan;

    if(!isNaN(opts.lipschitzNorm))
    {
        lipschitzNorm = opts.lipschitzNorm;
        lambda = opts.maxNorm;
    }

    Projection filterProj;
    Operation lambdaSym = float32Constant(lambda);

    if(lambda != float.infinity)
    {
        filterProj = projConvParams(lambdaSym, features.shape[2 .. $], [1, 1], [1, 1], lipschitzNorm);
    }

    auto pred = dataSource(features)
               .conv2D(16, [3, 3], new Conv2DOptions()
                    .padding([1, 1])
                    .useBias(false)
                    .weightDecay(opts.weightDecay)
                    .spectralDecay(opts.spectralDecay)
                    .maxgain(maxgain)
                    .filterProj(filterProj))
               .wrnBlock(16 * width, n, opts.stride[0], opts)
               .wrnBlock(32 * width, n, opts.stride[1], opts)
               .wrnBlock(64 * width, n, opts.stride[2], opts)
               .batchNorm(new BatchNormOptions().maxgain(maxgain).lipschitz(lambda))
               .relu()
               .meanPool();

    return pred;
}
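
/*
    A minimal usage sketch, not part of the original module: building a
    WRN-28-10 classifier for CIFAR-10 sized inputs. It assumes dopt's
    `float32` variable constructor and the `dense`/`softmax` layer helpers
    from dopt.nnet.

        auto features = float32([128, 3, 32, 32]);
        auto preds = features
                    .wideResNet(28, 10)
                    .dense(10)
                    .softmax();
*/
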
private Layer wrnBlock(Layer inLayer, size_t u, size_t n, size_t s, WRNOptions opts)
{
    float maxgain = float.infinity;

    if(opts.maxgainNorm == 2.0f)
    {
        maxgain = opts.maxNorm;
    }

    float lambda = float.infinity;
    float lipschitzNorm = float.nan;

    if(!isNaN(opts.lipschitzNorm))
    {
        lipschitzNorm = opts.lipschitzNorm;
        lambda = opts.maxNorm;
    }

    Operation lambdaSym = float32Constant(lambda);

    auto convOpts()
    {
        return new Conv2DOptions()
              .padding([1, 1])
              .useBias(false)
              .weightDecay(opts.weightDecay)
              .spectralDecay(opts.spectralDecay)
              .maxgain(maxgain);
    }

    auto bnOpts()
    {
        return new BatchNormOptions()
              .maxgain(maxgain)
              .lipschitz(lambda);
    }

    Layer res;

    for(size_t i = 0; i < n; i++)
    {
        //Pre-activation residual branch: BN-ReLU-conv-BN-ReLU-(dropout)-conv
        res = inLayer
             .batchNorm(bnOpts())
             .relu();

        Projection filterProj = null;

        if(lambda != float.infinity)
        {
            filterProj = projConvParams(lambdaSym, res.trainOutput.shape[2 .. $], [s, s], [1, 1], lipschitzNorm);
        }

        res = res
             .conv2D(u, [3, 3], convOpts().stride([s, s]).filterProj(filterProj))
             .batchNorm(bnOpts())
             .relu()
             .maybeDropout(opts.dropout ? 0.3f : 0.0f);

        if(lambda != float.infinity)
        {
            filterProj = projConvParams(lambdaSym, res.trainOutput.shape[2 .. $], [1, 1], [1, 1], lipschitzNorm);
        }

        res = res
             .conv2D(u, [3, 3], convOpts().filterProj(filterProj));

        Layer shortcut = inLayer;

        //When the channel count changes, project the shortcut with a strided 1x1 convolution so the branches can be summed
        if(inLayer.output.shape[1] != res.output.shape[1])
        {
            if(lambda != float.infinity)
            {
                filterProj = projConvParams(lambdaSym, inLayer.trainOutput.shape[2 .. $], [s, s], [1, 1],
                    lipschitzNorm);
            }

            shortcut = inLayer.conv2D(u, [1, 1], new Conv2DOptions()
                              .stride([s, s])
                              .useBias(false)
                              .weightDecay(opts.weightDecay)
                              .spectralDecay(opts.spectralDecay)
                              .maxgain(maxgain)
                              .filterProj(filterProj));
        }

        res = new Layer(
            [res, shortcut],
            res.output + shortcut.output,
            res.trainOutput + shortcut.trainOutput,
            []
        );

        inLayer = res;

        //Only the first block in each group downsamples
        s = 1;
    }

    return res;
}

private Layer meanPool(Layer input)
{
    //Global average pooling: reduce each feature map to its mean value
    Operation meanPoolImpl(Operation inp)
    {
        auto mapVol = inp.shape[2] * inp.shape[3];
        float scale = 1.0f / mapVol;

        return inp.reshape([inp.shape[0] * inp.shape[1], inp.shape[2] * inp.shape[3]])
                  .sum([1])
                  .reshape([inp.shape[0], inp.shape[1]]) * scale;
    }

    auto y = meanPoolImpl(input.output);
    auto yTr = meanPoolImpl(input.trainOutput);

    return new Layer([input], y, yTr, []);
}
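
/*
    A sketch of enabling the Lipschitz constraint, with illustrative values.
    It assumes dynamicProperties generates fluent setters, mirroring the
    Conv2DOptions/BatchNormOptions setters used above.

        auto opts = new WRNOptions()
                   .lipschitzNorm(2.0f) //project conv filters in the L2 norm
                   .maxNorm(4.0f);      //lambda: the per-layer norm bound
        auto net = features.wideResNet(16, 4, opts);
*/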