1 module dopt.core.grads.math;
2 
3 import dopt.core.grads;
4 import dopt.core.ops;
5 
package
{
    /// Registers every gradient implementation in this module with the
    /// global gradient-lookup table used by dopt.core.grads.
    void initialize()
    {
        import std.functional;

        // CTFE helper: emits one registerGradient call per simple math op.
        // Each op name "<name>" is paired with the private <name>Grad
        // function defined later in this module.
        string generateRegistrations()
        {
            enum names = ["add", "sub", "mul", "div", "pow", "min", "max",
                          "neg", "abs", "exp", "log", "sqrt"];

            string code = "";

            foreach(name; names)
            {
                code ~= `registerGradient("` ~ name ~ `", toDelegate(&` ~ name ~ `Grad));` ~ "\n";
            }

            return code;
        }

        mixin(generateRegistrations());

        // These two do not follow the "<name>Grad" batch above in spirit —
        // they are the non-elementwise ops — so register them explicitly.
        registerGradient("matmul", toDelegate(&matmulGrad));
        registerGradient("sum", toDelegate(&sumGrad));
    }
}
33 
34 private
35 {
36     Operation[] matmulGrad(Operation op, Operation parentGrad)
37     {
38         return [matmul(parentGrad, transpose(op.deps[1], [1, 0])), matmul(transpose(op.deps[0], [1, 0]), parentGrad)];
39     }
40 
41     Operation[] sumGrad(Operation op, Operation parentGrad)
42     {
43         if(op.volume == 1)
44         {
45             return [parentGrad.repeat(op.deps[0].volume).reshape(op.deps[0].shape)];
46         }
47         else
48         {
49             auto axes = op.attributes["axes"].get!(size_t[]);
50             auto tmpShape = op.deps[0].shape.dup;
51             auto reps = new size_t[tmpShape.length];
52             reps[] = 1;
53             
54             foreach(a; axes)
55             {
56                 reps[a] = tmpShape[a];
57                 tmpShape[a] = 1;
58             }
59 
60             auto tmp = parentGrad.reshape(tmpShape);
61 
62             return [tmp.repeat(reps)];
63         }
64     }
65 
66     Operation[] addGrad(Operation op, Operation parentGrad)
67     {
68         return [parentGrad, parentGrad];
69     }
70 
71     Operation[] subGrad(Operation op, Operation parentGrad)
72     {
73         return [parentGrad, neg(parentGrad)];
74     }
75 
76     Operation[] mulGrad(Operation op, Operation parentGrad)
77     {
78         return [parentGrad * op.deps[1], parentGrad * op.deps[0]];
79     }
80 
81     Operation[] divGrad(Operation op, Operation parentGrad)
82     {
83         return [
84             parentGrad / op.deps[1],
85             neg(parentGrad * op.deps[0]) / (op.deps[1] * op.deps[1])
86         ];
87     }
88 
89     Operation[] powGrad(Operation op, Operation parentGrad)
90     {
91         return [
92             parentGrad * op.deps[1] * pow(op.deps[0], op.deps[1] - 1),
93             parentGrad * op.deps[1] * log(op.deps[0])
94         ];
95     }
96 
97     Operation[] minGrad(Operation op, Operation parentGrad)
98     {
99         return [
100             op.deps[0].eq(op) * parentGrad,
101             op.deps[1].eq(op) * parentGrad
102         ];
103     }
104 
105     Operation[] maxGrad(Operation op, Operation parentGrad)
106     {
107         return [
108             op.deps[0].eq(op) * parentGrad,
109             op.deps[1].eq(op) * parentGrad
110         ];
111     }
112 
113     Operation[] negGrad(Operation op, Operation parentGrad)
114     {
115         return [neg(parentGrad)];
116     }
117 
118     Operation[] absGrad(Operation op, Operation parentGrad)
119     {
120         return [parentGrad * sgn(op.deps[0])];
121     }
122 
123     Operation[] expGrad(Operation op, Operation parentGrad)
124     {
125         return [parentGrad * op];
126     }
127 
128     Operation[] logGrad(Operation op, Operation parentGrad)
129     {
130         return [parentGrad / op.deps[0]];
131     }
132 
133     Operation[] sqrtGrad(Operation op, Operation parentGrad)
134     {
135         return [parentGrad / op];
136     }
137 }