1 module dopt.core.grads.nnet;
2 
3 import dopt.core.grads;
4 import dopt.core.ops;
5 
package
{
    /**
        Wires the gradient implementations in this module into the global
        gradient registry, keyed by the name of the forward operation.

        Called once during library initialisation.
    */
    void initialize()
    {
        import std.functional : toDelegate;

        // Registrations are independent of one another; listed alphabetically.
        registerGradient("addBias", toDelegate(&addBiasGrad));
        registerGradient("batchNormTrain", toDelegate(&batchNormGrad));
        registerGradient("convolution", toDelegate(&convolutionGrad));
        registerGradient("convolutionFeaturesGrad", toDelegate(&convolutionFeaturesGradGrad));
        registerGradient("maxpool", toDelegate(&maxpoolGrad));
        registerGradient("softmax", toDelegate(&softmaxGrad));
    }
}
20 
21 private
22 {
23     Operation[] convolutionGrad(Operation op, Operation parentGrad)
24     {
25         auto padding = op.attributes["padding"].get!(size_t[]);
26         auto stride = op.attributes["stride"].get!(size_t[]);
27 
28         return [
29             convolutionFeaturesGrad(parentGrad, op.deps[1], op.deps[0].shape, padding, stride),
30             convolutionFiltersGrad(parentGrad, op.deps[0], op.deps[1].shape, padding, stride)
31         ];
32     }
33 
34     Operation[] convolutionFeaturesGradGrad(Operation op, Operation parentGrad)
35     {
36         auto padding = op.attributes["padding"].get!(size_t[]);
37         auto stride = op.attributes["stride"].get!(size_t[]);
38 
39         return [
40             convolution(parentGrad, op.deps[1], padding, stride),
41             convolutionFiltersGrad(parentGrad, op.deps[0], op.deps[1].shape, padding, stride)
42         ];
43     }
44 
45     Operation[] maxpoolGrad(Operation op, Operation parentGrad)
46     {
47         return [dopt.core.ops.nnet.maxpoolGrad(parentGrad, op)];
48     }
49 
50     Operation[] softmaxGrad(Operation op, Operation parentGrad)
51     {
52         return [dopt.core.ops.nnet.softmaxGrad(parentGrad, op)];
53     }
54 
55     Operation[] addBiasGrad(Operation op, Operation parentGrad)
56     {
57         return [parentGrad, dopt.core.ops.nnet.addBiasGrad(parentGrad)];
58     }
59 
60     Operation[] batchNormGrad(Operation op, Operation parentGrad)
61     {
62         auto packedGrads = dopt.core.ops.nnet.batchNormGrad(parentGrad, op.deps[0], op.deps[1]);
63         packedGrads = packedGrads.reshape([packedGrads.volume]);
64 
65         auto v0 = op.deps[0].volume;
66         auto v1 = op.deps[1].volume;
67         auto v2 = op.deps[2].volume;
68 
69         return [
70             packedGrads.slice([0], [v0]).reshape(op.deps[0].shape),
71             packedGrads.slice([v0], [v0 + v1]).reshape(op.deps[1].shape),
72             packedGrads.slice([v0 + v1], [v0 + v1 + v2]).reshape(op.deps[2].shape)
73         ];
74     }
75 }