1 module dopt.core.grads.basic;
2 
3 import dopt.core.grads;
4 import dopt.core.ops;
5 
package
{
    /**
        Registers this module's gradient functions for the "transpose", "slice", "pad", "reshape" and "repeat"
        operations by calling registerGradient once per operation name.
    */
    void initialize()
    {
        import std.functional : toDelegate;

        //Associates an operation name with the function that constructs its gradient
        static struct Entry
        {
            string opName;
            Operation[] function(Operation, Operation) gradFn;
        }

        auto table = [
            Entry("transpose", &transposeGrad),
            Entry("slice", &sliceGrad),
            Entry("pad", &padGrad),
            Entry("reshape", &reshapeGrad),
            Entry("repeat", &repeatGrad)
        ];

        //Entries are registered in the order they are listed above
        foreach(entry; table)
        {
            registerGradient(entry.opName, toDelegate(entry.gradFn));
        }
    }
}
19 
20 private
21 {
22     Operation[] transposeGrad(Operation op, Operation parentGrad)
23     {
24         import std.algorithm : countUntil, map;
25         import std.array : array;
26         import std.range : iota;
27 
28         auto order = op
29                     .attributes["order"]
30                     .get!(size_t[]);
31 
32         auto newOrder = iota(0, order.length)
33                        .map!(x => cast(size_t)order.countUntil(x))
34                        .array();
35 
36         return [parentGrad.transpose(newOrder)];
37     }
38 
39     Operation[] sliceGrad(Operation op, Operation parentGrad)
40     {
41         auto before = op.attributes["start"].get!(size_t[]);
42         auto after = op.deps[0].outputType.shape.dup;
43         after[] -= op.attributes["stop"].get!(size_t[])[];
44 
45         return [parentGrad.pad(before, after)];
46     }
47 
48     //Test the sliceGrad function
49     unittest
50     {
51         import dopt.core;
52 
53         auto a = float32([4, 4], [
54             1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
55             1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f]);
56 
57         auto b = float32([4, 4], [
58             5.0f, 6.0f, 7.0f, 8.0f, 5.0f, 6.0f, 7.0f, 8.0f,
59             5.0f, 6.0f, 7.0f, 8.0f, 5.0f, 6.0f, 7.0f, 8.0f]);
60 
61         auto c = slice(a * b, [1, 1], [2, 2]);
62 
63         import std.algorithm : equal;
64 
65         assert(evaluate(grad(c, [a]))[0].get!float.equal(
66             [0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]));
67     }
68 
69     Operation[] padGrad(Operation op, Operation parentGrad)
70     {
71         auto start = op.attributes["before"].get!(size_t[]);
72         auto stop = op.deps[0].outputType.shape.dup;
73         stop[] += start[];
74 
75         return [parentGrad.slice(start, stop)];
76     }
77 
78     Operation[] reshapeGrad(Operation op, Operation parentGrad)
79     {
80         return [parentGrad.reshape(op.deps[0].outputType.shape)];
81     }
82 
83     Operation[] repeatGrad(Operation op, Operation parentGrad)
84     {
85         import std.array : array;
86         import std.range : iota, roundRobin;
87 
88         auto reps = op.attributes["repetitions"].get!(size_t[]);
89         
90         //Add some new dimensions that explicitly represent the repetitions
91         auto tmpShape = roundRobin(reps, op.deps[0].shape).array();
92         auto tmp = parentGrad.reshape(tmpShape);
93 
94         //Sum over these dimensions
95         return [tmp.sum(iota(0, tmpShape.length, 2).array())];
96     }
97 }