1 module dopt.nnet.models.wrn;
2 
3 import std.math : isNaN;
4 
5 import dopt.core;
6 import dopt.nnet;
7 import dopt.nnet.util;
8 import dopt.nnet.models.maybe;
9 import dopt.online;
10 
/**
    Options consumed by wideResNet when it assembles a Wide Residual Network.

    The backing fields (`_dropout`, `_maxgainNorm`, ...) and their accessors come from the `dynamicProperties`
    mixin at the bottom of the class; that helper is defined outside this file.
    NOTE(review): presumably it generates chainable getter/setter pairs like the other dopt option classes — confirm.

    A NaN value for `maxgainNorm` or `lipschitzNorm` means "this regulariser is not in use".
*/
class WRNOptions
{
    /**
        Sets the defaults: no dropout, neither maxgain nor Lipschitz regularisation selected (both NaN),
        an unbounded max norm, no spectral decay, a weight decay of 1e-4, and strides of 1, 2 and 2 for the three
        block groups.
    */
    this()
    {
        _dropout = false;
        _maxgainNorm = float.nan;
        _lipschitzNorm = float.nan;
        _maxNorm = float.infinity;
        _spectralDecay = 0.0f;
        _weightDecay = 0.0001f;
        _stride = [1, 2, 2];
    }

    /**
        Checks that the selected regularisation settings form a supported combination.

        Throws:
            Exception (raised by `enforce`) if `maxgainNorm` is set to anything other than 2, or if both
            `maxgainNorm` and `lipschitzNorm` are set at the same time.
    */
    void verify()
    {
        import std.exception : enforce;

        // A regulariser counts as "selected" as soon as its norm is no longer NaN.
        immutable bool usingMaxgain = !isNaN(_maxgainNorm);
        immutable bool usingLipschitz = !isNaN(_lipschitzNorm);

        if(usingMaxgain)
        {
            enforce(_maxgainNorm == 2.0f, "Only a maxgainNorm of 2 is currently supported.");
        }

        // The two regularisers are mutually exclusive.
        enforce(!(usingMaxgain && usingLipschitz),
            "WRN models currently only support using one of maxgain and the lipschitz constraint");
    }

    mixin(dynamicProperties(
        "bool", "dropout",
        "float", "maxgainNorm",
        "float", "lipschitzNorm",
        "float", "maxNorm",
        "float", "spectralDecay",
        "float", "weightDecay",
        "size_t[3]", "stride"
    ));
}
55 
/**
    Builds a Wide Residual Network on top of `features`.

    The resulting graph is: a 3x3 convolution producing 16 feature maps, three groups of residual blocks producing
    `16 * width`, `32 * width` and `64 * width` feature maps (using `opts.stride[0]`, `[1]` and `[2]` respectively),
    a batch norm, a ReLU, and finally a mean pool over dimensions 2 and 3.

    Params:
        features = operation providing the network input. `features.shape[2 .. $]` is handed to `projConvParams`
                   when a Lipschitz constraint is active.
        depth = nominal network depth. Each group holds `(depth - 4) / 6` residual blocks (integer division), so
                this must be at least 10 for every group to contain a block.
        width = widening factor applied to the per-group feature map counts.
        opts = regularisation and stride settings; validated with `opts.verify()`.

    Returns:
        The final layer of the network (output of the mean pool).

    Throws:
        Exception if `opts` is null, if `depth` is smaller than 10, or if `opts.verify()` rejects the options.
*/
Layer wideResNet(Operation features, size_t depth, size_t width, WRNOptions opts = new WRNOptions())
{
    import std.exception : enforce;

    enforce(opts !is null, "wideResNet requires a non-null WRNOptions instance.");

    // depth is unsigned: anything below 4 would wrap around in (depth - 4), and 4..9 would give zero blocks per
    // group, leaving the residual groups with nothing to build.
    enforce(depth >= 10, "The depth of a wide ResNet must be at least 10 (one residual block per group).");

    opts.verify();

    immutable size_t n = (depth - 4) / 6;

    // maxgain regularisation is only active when the norm was explicitly set to 2.
    float maxgain = float.infinity;

    if(opts.maxgainNorm == 2.0f)
    {
        maxgain = opts.maxNorm;
    }

    // The Lipschitz constraint is active when lipschitzNorm is not NaN; lambda is then the norm bound.
    float lambda = float.infinity;
    float lipschitzNorm = float.nan;

    if(!isNaN(opts.lipschitzNorm))
    {
        lipschitzNorm = opts.lipschitzNorm;
        lambda = opts.maxNorm;
    }

    // Stays null (no projection) unless a finite lambda was selected.
    Projection filterProj;
    Operation lambdaSym = float32Constant(lambda);

    if(lambda != float.infinity)
    {
        filterProj = projConvParams(lambdaSym, features.shape[2 .. $], [1, 1], [1, 1], lipschitzNorm);
    }

    // NOTE(review): the final batchNorm below only receives maxgain, whereas the batch norms inside wrnBlock also
    // receive .lipschitz(lambda) — confirm whether that omission is intentional.
    auto pred = dataSource(features)
               .conv2D(16, [3, 3], new Conv2DOptions()
                    .padding([1, 1])
                    .useBias(false)
                    .weightDecay(opts.weightDecay)
                    .maxgain(maxgain)
                    .filterProj(filterProj))
               .wrnBlock(16 * width, n, opts.stride[0], opts)
               .wrnBlock(32 * width, n, opts.stride[1], opts)
               .wrnBlock(64 * width, n, opts.stride[2], opts)
               .batchNorm(new BatchNormOptions().maxgain(maxgain))
               .relu()
               .meanPool();

    return pred;
}
102 
/**
    Appends one group of `n` pre-activation residual blocks to `inLayer`.

    Each block is: batch norm, ReLU, 3x3 conv (strided by `s` in the first block only), batch norm, ReLU, optional
    dropout (rate 0.3 when `opts.dropout` is set), 3x3 conv, followed by an element-wise addition with a shortcut.
    The shortcut is the block input itself when its shape already matches the residual branch; otherwise it is a
    strided 1x1 convolution of the block input.

    Params:
        inLayer = layer the group is attached to.
        u = number of feature maps requested from every convolution in the group.
        n = number of residual blocks in the group; must be greater than zero.
        s = stride of the first block. All later blocks use a stride of 1.
        opts = regularisation settings (maxgain / Lipschitz norm, max norm, weight decay, dropout).

    Returns:
        The layer produced by the last residual block.

    Throws:
        Exception if `n` is zero, since there would be no block (and therefore no layer) to return.
*/
private Layer wrnBlock(Layer inLayer, size_t u, size_t n, size_t s, WRNOptions opts)
{
    import std.exception : enforce;

    // Without this guard the loop below never runs and a null Layer would be handed back to the caller.
    enforce(n > 0, "A wide ResNet group must contain at least one residual block.");

    // maxgain regularisation is only active when the norm was explicitly set to 2.
    float maxgain = float.infinity;

    if(opts.maxgainNorm == 2.0f)
    {
        maxgain = opts.maxNorm;
    }

    // The Lipschitz constraint is active when lipschitzNorm is not NaN; lambda is then the norm bound.
    float lambda = float.infinity;
    float lipschitzNorm = float.nan;

    if(!isNaN(opts.lipschitzNorm))
    {
        lipschitzNorm = opts.lipschitzNorm;
        lambda = opts.maxNorm;
    }

    Operation lambdaSym = float32Constant(lambda);

    // Fresh option objects are created on every call so that per-conv tweaks (stride, filterProj) do not leak
    // between layers.
    auto convOpts()
    {
        return new Conv2DOptions()
            .padding([1, 1])
            .useBias(false)
            .weightDecay(opts.weightDecay)
            .maxgain(maxgain);
    }

    auto bnOpts()
    {
        return new BatchNormOptions()
            .maxgain(maxgain)
            .lipschitz(lambda);
    }

    Layer res;

    for(size_t i = 0; i < n; i++)
    {
        res = inLayer
            .batchNorm(bnOpts())
            .relu();
        
        // Remains null (no projection) unless a finite lambda was selected.
        Projection filterProj = null;

        if(lambda != float.infinity)
        {
            filterProj = projConvParams(lambdaSym, res.trainOutput.shape[2 .. $], [s, s], [1, 1], lipschitzNorm);
        }

        res = res
             .conv2D(u, [3, 3], convOpts().stride([s, s]).filterProj(filterProj))
             .batchNorm(bnOpts())
             .relu()
             .maybeDropout(opts.dropout ? 0.3f : 0.0f);
        
        if(lambda != float.infinity)
        {
            filterProj = projConvParams(lambdaSym, res.trainOutput.shape[2 .. $], [1, 1], [1, 1], lipschitzNorm);
        }
        
        res = res
             .conv2D(u, [3, 3], convOpts().filterProj(filterProj));
        
        Layer shortcut = inLayer;
        
        // Compare the complete shapes rather than only the channel count: a strided block that keeps the number of
        // feature maps still shrinks the spatial dimensions, so the identity shortcut could not be added to it.
        if(inLayer.output.shape != res.output.shape)
        {
            if(lambda != float.infinity)
            {
                // NOTE(review): this passes the same fourth argument ([1, 1]) as the padded 3x3 convs above, although
                // the 1x1 conv below sets no padding — confirm that this is what projConvParams expects.
                filterProj = projConvParams(lambdaSym, inLayer.trainOutput.shape[2 .. $], [s, s], [1, 1],
                    lipschitzNorm);
            }

            shortcut = inLayer.conv2D(u, [1, 1], new Conv2DOptions()
                                                .stride([s, s])
                                                .useBias(false)
                                                .weightDecay(opts.weightDecay)
                                                .maxgain(maxgain)
                                                .filterProj(filterProj));
        }

        // Residual connection: sum the branch and the shortcut for both the inference and the training graph.
        res = new Layer(
            [res, shortcut],
            res.output + shortcut.output,
            res.trainOutput + shortcut.trainOutput,
            []
        );

        inLayer = res;

        // Only the first block of a group is strided.
        s = 1;
    }

    return res;
}
199 
/**
    Averages `input` over dimensions 2 and 3.

    The same expression is built twice: once from `input.output` and once from `input.trainOutput`. Each result has
    the shape `[shape[0], shape[1]]` of the corresponding operand.

    Params:
        input = layer whose outputs are indexed with four dimensions (`shape[0]` .. `shape[3]`).

    Returns:
        A new layer, with `input` as its only parent, holding the two pooled expressions.
*/
private Layer meanPool(Layer input)
{
    // Builds the averaging expression for a single operation.
    Operation average(Operation x)
    {
        auto dim0 = x.shape[0];
        auto dim1 = x.shape[1];
        auto mapSize = x.shape[2] * x.shape[3];
        float invMapSize = 1.0f / mapSize;

        // Flatten so that every row holds one map, sum each row, then restore the two leading dimensions.
        auto flattened = x.reshape([dim0 * dim1, mapSize]);
        auto rowTotals = flattened.sum([1]);

        return rowTotals.reshape([dim0, dim1]) * invMapSize;
    }

    auto pooled = average(input.output);
    auto pooledTrain = average(input.trainOutput);

    return new Layer([input], pooled, pooledTrain, []);
}