/**
    Contains the automatic differentiation framework.

    Authors: Henry Gouk
*/
module dopt.core.grads;

import std.exception;

import dopt.core.grads.basic;
import dopt.core.grads.math;
import dopt.core.grads.nnet;
import dopt.core.ops;
import dopt.core.types;

/**
    The type of a gradient computation routine.

    Given an operation and the gradient of the objective with respect to that operation's output
    ($(D parentGrad)), a $(D Gradient) delegate returns the gradients of the objective with respect
    to each element of $(D op.deps), in the same order.
*/
alias Gradient = Operation[] delegate(Operation op, Operation parentGrad);
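
/*
    As a minimal sketch (not part of this module), a $(D Gradient) delegate for a hypothetical
    elementwise "square" operation could look like the following. The operation name and its
    behaviour are assumptions made purely for illustration:

    Operation[] squareGrad(Operation op, Operation parentGrad)
    {
        //d(x^2)/dx = 2x, so scale the incoming gradient by both copies of the input added together
        return [parentGrad * (op.deps[0] + op.deps[0])];
    }
*/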

/**
    Registers the gradient implementations provided by the submodules of $(D dopt.core.grads).
*/
void initialize()
{
    dopt.core.grads.basic.initialize();
    dopt.core.grads.math.initialize();
    dopt.core.grads.nnet.initialize();
}

/**
    Computes the gradient of a scalar-valued operation with respect to several dependencies.

    This function provides an implementation of automatic differentiation that can be used to greatly simplify the
    process of optimising objective functions. The particular technique used by this function is known as
    reverse mode automatic differentiation.

    Params:
        objective = The operation being differentiated.
        wrt = The (indirect) dependencies that $(D objective) is being differentiated with respect to.

    Returns:
        An array of operations that evaluate to the derivative of $(D objective) with respect to each of the
        elements of $(D wrt).
*/
Operation[] grad(Operation objective, Operation[] wrt)
{
    import std.algorithm : canFind, countUntil, map;
    import std.array : array;
    import std.conv : to;
    import std.range : retro, zip;

    enforce(objective.outputType.volume == 1, "The objective must have a volume of one");
    enforce(objective.outputType.elementType == DataType.float32, "The objective must have a floating point type");

    Operation[] ops;

    void traverse(Operation op)
    {
        foreach(d; op.deps)
        {
            if(!ops.canFind(d))
            {
                traverse(d);
            }
        }

        ops ~= op;
    }

    //Topologically sort the operations
    traverse(objective);

    Operation[Operation] grads;

    //TODO: when I implement a 'ones' operation, replace this line
    grads[objective] = float32(objective.outputType.shape, [1.0f]);

    //Iterate through the operations in reverse order (reverse mode autodiff)
    foreach(op; ops.retro)
    {
        //Get the function that will let us compute the gradient of op w.r.t. its deps
        auto gradFunc = mGradients.get(op.opType, null);
        auto opGrad = grads.get(op, null);

        if(gradFunc is null || opGrad is null)
        {
            //This op, or its parent, is not differentiable, so we will just assume its derivative is zero everywhere
            continue;
        }

        //Compute the derivative: d(op)/d(op.deps)
        auto depGrads = gradFunc(op, opGrad);

        //Add these to grads. If there is already an entry for one of the deps, then that dep has more than one
        //parent, so we accumulate by adding the new gradient to the existing one (the chain rule sums the
        //contributions from each path).
        foreach(d, g; zip(op.deps, depGrads))
        {
            auto currentGrad = grads.get(d, null);

            if(currentGrad is null)
            {
                grads[d] = g;
            }
            else
            {
                grads[d] = currentGrad + g;
            }
        }
    }

    auto errIdx = wrt.countUntil!(x => grads.get(x, null) is null);

    enforce(errIdx == -1, "Could not find wrt[" ~ errIdx.to!string ~ "] in the operation graph");

    return wrt.map!(x => grads[x]).array();
}

///
unittest
{
    import std.random : uniform;
    import dopt.core : evaluate;

    auto x = float32();
    auto y = x * x;
    auto gradY = grad(y, [x]);

    auto r = uniform(-100.0f, 100.0f);

    auto gradYwrtX = gradY.evaluate([
        x: Buffer([r])
    ])[0];

    assert(gradYwrtX.as!float[0] == r + r);
}
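
/*
    A second, hedged example: differentiating with respect to two operands. This sketch assumes the
    same $(D float32), $(D Buffer), and $(D evaluate) API used in the example above, and that the
    gradient of a product follows the usual rule d(xy)/dx = y, d(xy)/dy = x.
*/
unittest
{
    import dopt.core : evaluate;

    auto x = float32();
    auto y = float32();
    auto z = x * y;

    //grad returns one operation per element of wrt, in the same order
    auto gradZ = grad(z, [x, y]);

    auto results = gradZ.evaluate([
        x: Buffer([3.0f]),
        y: Buffer([4.0f])
    ]);

    //dz/dx = y and dz/dy = x
    assert(results[0].as!float[0] == 4.0f);
    assert(results[1].as!float[0] == 3.0f);
}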

/**
    Registers a gradient routine for the operation type given by $(D opName).

    An exception is thrown if a gradient has already been registered for $(D opName).
*/
void registerGradient(string opName, Gradient g)
{
    enforce((opName in mGradients) is null, "A gradient is already registered for operation '" ~ opName ~ "'");

    mGradients[opName] = g;
}
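
/*
    A minimal sketch of registering and deregistering a gradient routine. The operation name
    "exampleSquare" is hypothetical and chosen to avoid clashing with any gradients registered by the
    submodules of this package.
*/
unittest
{
    import std.algorithm : canFind;

    //Gradient of a hypothetical elementwise squaring operation: d(x^2)/dx = 2x
    Operation[] exampleSquareGrad(Operation op, Operation parentGrad)
    {
        return [parentGrad * (op.deps[0] + op.deps[0])];
    }

    registerGradient("exampleSquare", &exampleSquareGrad);
    assert(listAllGradients().canFind("exampleSquare"));

    deregisterGradient("exampleSquare");
    assert(!listAllGradients().canFind("exampleSquare"));
}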

/**
    Deregisters the gradient routine associated with the operation type given by $(D opName), if one exists.
*/
void deregisterGradient(string opName)
{
    mGradients.remove(opName);
}

/**
    Returns a list of the operation types that currently have a registered gradient routine.
*/
string[] listAllGradients()
{
    return mGradients.keys.dup;
}

private
{
    Gradient[string] mGradients;
}