/**
    Contains an implementation of dense (i.e., fully connected) layers.

    Authors: Henry Gouk
*/
module dopt.nnet.layers.dense;

import dopt.core;
import dopt.nnet;
import dopt.nnet.util;
import dopt.online;

/**
    Encapsulates additional options for dense layers.
*/
class DenseOptions
{
    this()
    {
        _weightInit = heGaussianInit();
        _biasInit = constantInit(0.0f);
        _useBias = true;
        _weightDecay = 0.0f;
        _maxgain = float.infinity;
        _spectralDecay = 0.0f;
    }

    mixin(dynamicProperties(
        "ParamInitializer", "weightInit",
        "ParamInitializer", "biasInit",
        "Projection", "weightProj",
        "Projection", "biasProj",
        "float", "maxgain",
        "float", "weightDecay",
        "float", "spectralDecay",
        "bool", "useBias"
    ));
}

///
unittest
{
    //Create a DenseOptions object with the default parameters
    auto opts = new DenseOptions()
               .weightInit(heGaussianInit())
               .biasInit(constantInit(0.0f))
               .weightProj(null)
               .biasProj(null)
               .weightDecay(0.0f)
               .spectralDecay(0.0f)
               .maxgain(float.infinity)
               .useBias(true);

    //Options can also be read back again later
    assert(opts.weightDecay == 0.0f);
    assert(opts.useBias == true);
}

/**
    Creates a fully connected (also known as dense) layer.

    Params:
        input = The previous layer in the network.
        numOutputs = The number of units in this layer.
        opts = Additional options with sensible default values.

    Returns:
        The new layer.
*/
Layer dense(Layer input, size_t numOutputs, DenseOptions opts = new DenseOptions())
{
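    //Adds two optional loss terms, treating a null Operation as zero, so the
    //regularisation penalties below can be accumulated without special casing.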
    Operation safeAdd(Operation op1, Operation op2)
    {
        if(op1 is null && op2 is null)
        {
            return null;
        }
        else if(op1 is null)
        {
            return op2;
        }
        else if(op2 is null)
        {
            return op1;
        }
        else
        {
            return op1 + op2;
        }
    }

    auto x = input.output;
    auto xTr = input.trainOutput;

    //Flatten the inputs into matrices with one row per example
    x = x.reshape([x.shape[0], x.volume / x.shape[0]]);
    xTr = xTr.reshape([xTr.shape[0], xTr.volume / xTr.shape[0]]);

    auto weights = float32([numOutputs, x.shape[1]]);
    opts._weightInit(weights);

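    //Optional regularisation terms for the weights: an L2 weight decay penalty,
    //and a penalty on the spectral norm of the weight matrix. Each term is only
    //added when its coefficient is non-zero.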
    Operation weightLoss;
    weightLoss = safeAdd(
        weightLoss,
        (opts.weightDecay == 0.0f) ? null : (opts.weightDecay * sum(weights * weights))
    );
    weightLoss = safeAdd(
        weightLoss,
        (opts.spectralDecay == 0.0f) ? null : (opts.spectralDecay * spectralNorm(weights))
    );

    auto weightProj = opts._weightProj;

    auto y = matmul(x, weights.transpose([1, 0]));
    auto yTr = matmul(xTr, weights.transpose([1, 0]));

    auto before = xTr.reshape([xTr.shape[0], xTr.volume / xTr.shape[0]]);
    auto after = yTr.reshape([yTr.shape[0], yTr.volume / yTr.shape[0]]);

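    //When opts.maxgain is finite, this projection rescales the updated weights so
    //that the largest observed input-to-output gain (the per-example ratio of
    //output to input L2 norms, maximised over the batch) does not exceed maxgain.
    //Any user-supplied weight projection is applied after this rescaling.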
    Operation maxGainProj(Operation newWeights)
    {
        auto beforeNorms = sum(before * before, [1]) + 1e-8;
        auto afterNorms = sum(after * after, [1]) + 1e-8;
        auto mg = maxElement(sqrt(afterNorms / beforeNorms));

        auto scaled = newWeights * (1.0f / max(float32Constant([], [1.0f]), mg / opts.maxgain));

        if(opts.weightProj is null)
        {
            return scaled;
        }
        else
        {
            return opts._weightProj(scaled);
        }
    }

    if(opts.maxgain != float.infinity)
    {
        weightProj = &maxGainProj;
    }

    Parameter[] params = [
        Parameter(weights, weightLoss, weightProj)
    ];

    if(opts.useBias)
    {
        auto bias = float32([numOutputs]);
        opts._biasInit(bias);

        y = y + bias.repeat(y.shape[0]);
        yTr = yTr + bias.repeat(yTr.shape[0]);

        params ~= Parameter(bias, null, opts.biasProj);
    }

    return new Layer([input], y, yTr, params);
}
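
///
unittest
{
    //A minimal usage sketch: this assumes the dataSource helper from dopt.nnet
    //for wrapping an input Operation in a Layer. A real network would chain
    //further layers and a loss on top of this.
    auto features = float32([128, 784]);

    auto layer = dataSource(features).dense(10);

    //Each example is flattened and mapped to numOutputs units
    assert(layer.output.shape == [128, 10]);
}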
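
/**
    Estimates the squared spectral norm (the largest singular value, squared) of a
    weight matrix, using a few iterations of the power method applied to W^T W.
*/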
private Operation spectralNorm(Operation weights, size_t numIts = 1)
{
    //The power iteration vector lives in the input space of the layer, so it
    //must have weights.shape[1] elements.
    auto x = uniformSample([weights.shape[1], 1]) * 2.0f - 1.0f;

    foreach(i; 0 .. numIts)
    {
        x = matmul(weights.transpose([1, 0]), matmul(weights, x));
    }

    auto v = x / sqrt(sum(x * x));
    auto y = matmul(weights, v);

    return sum(y * y);
}