module dopt.nnet.models.wrn;

import std.math : isNaN;

import dopt.core;
import dopt.nnet;
import dopt.nnet.util;
import dopt.nnet.models.maybe;
import dopt.online;

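/**
    Encapsulates the options for constructing a wide residual network.

    By default, dropout is disabled, neither the max-gain nor the Lipschitz
    constraint is applied, spectral decay is zero, weight decay is 0.0001, and
    the three groups of residual blocks use strides of 1, 2, and 2.
*/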
class WRNOptions
{
    this()
    {
        _dropout = false;
        _maxgainNorm = float.nan;   // NaN indicates that the max-gain constraint is disabled
        _lipschitzNorm = float.nan; // NaN indicates that the Lipschitz constraint is disabled
        _maxNorm = float.infinity;
        _spectralDecay = 0.0f;
        _weightDecay = 0.0001f;
        _stride = [1, 2, 2];        // The stride used by the first block of each group
    }

    void verify()
    {
        import std.exception : enforce;

        int regCtr;

        if(!isNaN(_maxgainNorm))
        {
            regCtr++;

            enforce(_maxgainNorm == 2.0f, "Only a maxgainNorm of 2 is currently supported.");
        }

        if(!isNaN(_lipschitzNorm))
        {
            regCtr++;
        }
        enforce(regCtr <= 1, "WRN models currently support using at most one of maxgain and the Lipschitz constraint.");
    }

    mixin(dynamicProperties(
        "bool", "dropout",
        "float", "maxgainNorm",
        "float", "lipschitzNorm",
        "float", "maxNorm",
        "float", "spectralDecay",
        "float", "weightDecay",
        "size_t[3]", "stride"
    ));
}

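/**
    Creates a wide residual network in the style of Zagoruyko and Komodakis
    (2016), "Wide Residual Networks".

    The network comprises an initial 3x3 convolution followed by three groups
    of pre-activation residual blocks, with (depth - 4) / 6 blocks per group.
    The width parameter multiplies the number of feature maps used in each
    group (16, 32, and 64 before widening). The returned layer ends with
    global mean pooling, so the caller is expected to append a classifier.

    A minimal usage sketch constructing a WRN-28-10, assuming a CIFAR-10 style
    batch of 32x32 RGB images created with the float32 variable constructor
    from dopt.core:

    ---
    auto features = float32([100, 3, 32, 32]);
    auto wrn = wideResNet(features, 28, 10);
    auto pred = wrn.output;
    ---
*/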
Layer wideResNet(Operation features, size_t depth, size_t width, WRNOptions opts = new WRNOptions())
{
    // Each of the three groups of residual blocks contains n blocks,
    // so depth should satisfy depth = 6n + 4
    size_t n = (depth - 4) / 6;

    opts.verify();

    float maxgain = float.infinity;

    if(opts.maxgainNorm == 2.0f)
    {
        maxgain = opts.maxNorm;
    }
    float lambda = float.infinity;
    float lipschitzNorm = float.nan;

    if(!isNaN(opts.lipschitzNorm))
    {
        lipschitzNorm = opts.lipschitzNorm;
        lambda = opts.maxNorm;
    }

    Projection filterProj;
    Operation lambdaSym = float32Constant(lambda);

    if(lambda != float.infinity)
    {
        // Projection that keeps the initial convolution's filters inside the Lipschitz ball
        filterProj = projConvParams(lambdaSym, features.shape[2 .. $], [1, 1], [1, 1], lipschitzNorm);
    }
    auto pred = dataSource(features)
               .conv2D(16, [3, 3], new Conv2DOptions()
                    .padding([1, 1])
                    .useBias(false)
                    .weightDecay(opts.weightDecay)
                    .spectralDecay(opts.spectralDecay)
                    .maxgain(maxgain)
                    .filterProj(filterProj))
               .wrnBlock(16 * width, n, opts.stride[0], opts)
               .wrnBlock(32 * width, n, opts.stride[1], opts)
               .wrnBlock(64 * width, n, opts.stride[2], opts)
               .batchNorm(new BatchNormOptions().maxgain(maxgain).lipschitz(lambda))
               .relu()
               .meanPool();

    return pred;
}

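/**
    Appends a group of residual blocks to the given layer.

    Params:
        inLayer = The layer the group should be appended to.
        u = The number of feature maps produced by each convolution in the group.
        n = The number of residual blocks in the group.
        s = The stride of the first block in the group; the remaining blocks
            always use a stride of one.
        opts = The options used when constructing the network.
*/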
private Layer wrnBlock(Layer inLayer, size_t u, size_t n, size_t s, WRNOptions opts)
{
    float maxgain = float.infinity;

    if(opts.maxgainNorm == 2.0f)
    {
        maxgain = opts.maxNorm;
    }

    float lambda = float.infinity;
    float lipschitzNorm = float.nan;

    if(!isNaN(opts.lipschitzNorm))
    {
        lipschitzNorm = opts.lipschitzNorm;
        lambda = opts.maxNorm;
    }

    Operation lambdaSym = float32Constant(lambda);

    // Create a fresh options object for each convolution in the block
    auto convOpts()
    {
        return new Conv2DOptions()
            .padding([1, 1])
            .useBias(false)
            .weightDecay(opts.weightDecay)
            .spectralDecay(opts.spectralDecay)
            .maxgain(maxgain);
    }

    // Create a fresh options object for each batch norm operation in the block
    auto bnOpts()
    {
        return new BatchNormOptions()
            .maxgain(maxgain)
            .lipschitz(lambda);
    }

    Layer res;
    for(size_t i = 0; i < n; i++)
    {
        // Pre-activation ordering: batch norm and ReLU are applied before each convolution
        res = inLayer
            .batchNorm(bnOpts())
            .relu();

        Projection filterProj = null;

        if(lambda != float.infinity)
        {
            filterProj = projConvParams(lambdaSym, res.trainOutput.shape[2 .. $], [s, s], [1, 1], lipschitzNorm);
        }

        res = res
             .conv2D(u, [3, 3], convOpts().stride([s, s]).filterProj(filterProj))
             .batchNorm(bnOpts())
             .relu()
             .maybeDropout(opts.dropout ? 0.3f : 0.0f);

        if(lambda != float.infinity)
        {
            filterProj = projConvParams(lambdaSym, res.trainOutput.shape[2 .. $], [1, 1], [1, 1], lipschitzNorm);
        }

        res = res
             .conv2D(u, [3, 3], convOpts().filterProj(filterProj));

        Layer shortcut = inLayer;

        // Project the shortcut with a 1x1 convolution when the number of feature maps
        // changes or the block downsamples; otherwise the identity mapping is used.
        // The stride check is required because a strided residual branch cannot be
        // added to an unstrided identity shortcut.
        if(inLayer.output.shape[1] != res.output.shape[1] || s != 1)
        {
            if(lambda != float.infinity)
            {
                filterProj = projConvParams(lambdaSym, inLayer.trainOutput.shape[2 .. $], [s, s], [1, 1],
                    lipschitzNorm);
            }

            shortcut = inLayer.conv2D(u, [1, 1], new Conv2DOptions()
                                                .stride([s, s])
                                                .useBias(false)
                                                .weightDecay(opts.weightDecay)
                                                .spectralDecay(opts.spectralDecay)
                                                .maxgain(maxgain)
                                                .filterProj(filterProj));
        }

        // Add the residual branch to the shortcut connection
        res = new Layer(
            [res, shortcut],
            res.output + shortcut.output,
            res.trainOutput + shortcut.trainOutput,
            []
        );

        inLayer = res;

        // Only the first block in the group is strided
        s = 1;
    }

    return res;
}

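/**
    Performs global mean pooling: each feature map is reduced to a single
    value by averaging over its spatial dimensions.
*/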
private Layer meanPool(Layer input)
{
    Operation meanPoolImpl(Operation inp)
    {
        // The number of elements in each feature map
        auto mapVol = inp.shape[2] * inp.shape[3];
        float scale = 1.0f / mapVol;

        // Flatten each feature map into a row, sum along it, and rescale to obtain the mean
        return inp.reshape([inp.shape[0] * inp.shape[1], inp.shape[2] * inp.shape[3]])
                  .sum([1])
                  .reshape([inp.shape[0], inp.shape[1]]) * scale;
    }

    auto y = meanPoolImpl(input.output);
    auto yTr = meanPoolImpl(input.trainOutput);

    return new Layer([input], y, yTr, []);
}