1 /**
2     This module enables operation graphs to be evaluated using CPU kernels.
3 
4     Authors: Henry Gouk
5 */
6 module dopt.cpu;
7 
8 import std.exception;
9 
10 import dopt.core;
11 
//Module constructor for the CPU backend.
//
//Being a shared static constructor, this runs once for the whole program before main. It initializes each of the
//CPU kernel submodules and then installs the CPU backend as the default evaluator and default plan compiler.
shared static this()
{
    import dopt.cpu.basic;
    import dopt.cpu.math;
    import dopt.cpu.nnet;
    import dopt.cpu.random;

    //Each submodule's initialize() presumably registers its kernels via registerCPUKernel -- TODO confirm
    dopt.cpu.basic.initialize();
    dopt.cpu.math.initialize();
    dopt.cpu.nnet.initialize();
    dopt.cpu.random.initialize();

    //toDelegate converts the free function pointer &evaluateCPU into a delegate
    import std.functional : toDelegate;
    defaultEvaluator = toDelegate(&evaluateCPU);
    //Plans built through the default compiler are executed on the CPU
    defaultCompiler = (Operation[] ops) { return new CPUPlan(ops); };
}
28 
29 /**
30     Common interface for all CPU kernels.
31 */
32 interface CPUKernel
33 {
34     void execute(Operation op, const(Buffer)[] inputs, Buffer output);
35 }
36 
37 /**
38     Convenience class that allows one to wrap a delegate and implement CPUKernel.
39 */
40 class CPUKernelDelegate : CPUKernel
41 {
42     public
43     {
44         this(void delegate(Operation, const(Buffer)[], Buffer) kern)
45         {
46             mKernel = kern;
47         }
48 
49         void execute(Operation op, const(Buffer)[] inputs, Buffer output)
50         {
51             mKernel(op, inputs, output);
52         }
53     }
54 
55     private
56     {
57         void delegate(Operation op, const(Buffer)[], Buffer) mKernel;
58     }
59 }
60 
61 /**
62     Registers a kernel for the specified operation.
63 
64     Params:
65         opName = The name of the operation.
66         kernel = A kernel that can execute operations of the type specified by opName.
67 
68     Throws:
69         If there is already a kernel registered for the operation.
70 */
71 void registerCPUKernel(string opName, CPUKernel kernel)
72 {
73     enforce((opName in mKernels) is null, "A CPUKernel is already registered for the operation '" ~ opName ~ "'");
74 
75     mKernels[opName] = kernel;
76 }
77 
78 /**
79     Deregisters the kernel associated with the specified operation.
80 
81     Params:
82         opName = The name of the operation that should have its kernel deregistered.
83 */
84 void deregisterCPUKernel(string opName)
85 {
86     mKernels.remove(opName);
87 }
88 
89 /**
90     Provides a list of operations for which a CPUKernel has been registered.
91 
92     Returns:
93         An array of operation names.
94 */
95 string[] listAllCPUOperations()
96 {
97     return mKernels.keys.dup ~ ["constant", "variable", "reshape"];
98 }
99 
/**
    A Plan that executes its outputs on the CPU.

    Each execution evaluates the plan's outputs with evaluateCPU and copies the resulting bytes into the buffers
    supplied by the caller.
*/
class CPUPlan : Plan
{
    public
    {
        /**
            Params:
                outputs = The operations whose values this plan computes.
        */
        this(Operation[] outputs)
        {
            super(outputs);
        }
    }

    protected
    {
        override void executeImpl(Buffer[Operation] args, Buffer[] rets)
        {
            import std.range : zip;

            auto computed = evaluateCPU(mOutputs, args);

            //zip pairs each computed value with its destination, stopping at the shorter of the two arrays
            foreach(pair; zip(computed, rets))
            {
                auto src = pair[0].as!ubyte;
                auto dst = pair[1].as!ubyte;

                //Copy the raw bytes into the caller's buffer
                dst[] = src[];
            }
        }
    }
}
125 
126 /**
127     Evaluates an several nodes from the operation graph using the CPU.
128 
129     If the elements of $(D ops) have common dependencies, then each dependency is evaluated only once. For this
130     reason it is recommended that this overload is used when multiple nodes should be evaluated.
131 
132     Params:
133         ops = The nodes of the operation graph that values should be computed for.
134         args = A set of variable assignments.
135 
136     Returns:
137         An array of $(D Buffer) objects, each containing the value of the corresponding element in $(D ops).
138 */
139 Buffer[] evaluateCPU(Operation[] ops, Buffer[Operation] args = null)
140 {
141     import std.algorithm : canFind, filter;
142     import std.array : array;
143 
144     //Toposort the operations by dependency
145     Operation[] sortedOps = topologicalSort(ops)
146                                   .filter!(x => !canFind(args.keys, x))
147                                   .array();
148 
149     //Count the number of references to each operation
150     int[Operation] refCounts;
151 
152     foreach(o; ops)
153     {
154         refCounts[o]++;
155     }
156 
157     foreach(o; sortedOps)
158     {
159         foreach(d; o.deps)
160         {
161             refCounts[d]++;
162         }
163     }
164 
165     //Start executing the operations
166     Buffer[Operation] results = args.dup;
167 
168     foreach(o; sortedOps)
169     {
170         import std.conv : to;
171         import std.stdio : stdout, write, writeln;
172 
173         //Check for some easy optimizations
174         if(o.opType == "variable" && !("variable" in mKernels))
175         {
176             results[o] = cast(Buffer)o.attributes["default"].get!Buffer;
177             continue;
178         }
179         else if(o.opType == "constant" && !("constant" in mKernels))
180         {
181             results[o] = cast(Buffer)o.attributes["default"].get!Buffer;
182             continue;
183         }
184         else if(o.opType == "reshape" && !("reshape" in mKernels))
185         {
186             results[o] = results[o.deps[0]];
187             continue;
188         }
189 
190         //Allocate a buffer for the output of this operation
191         auto output = Buffer(new ubyte[o.outputType.volume * o.outputType.elementType.sizeOf()]);
192         results[o] = output;
193 
194         //Get the input buffers
195         Buffer[] inputs;
196 
197         foreach(d; o.deps)
198         {
199             inputs ~= results[d];
200             refCounts[d]--;
201         }
202 
203         //Execute the operation
204         auto kern = mKernels.get(o.opType, null);
205 
206         if(kern is null)
207         {
208             throw new Exception("No CPU kernel registered for operation " ~ o.opType);
209         }
210 
211         kern.execute(o, inputs, output);
212 
213         foreach(d; o.deps)
214         {
215             //Remove the pointer to this buffer if we don't need it anymore
216             //This will allow the GC to collect it at some point, if required
217             if(refCounts[d] == 0)
218             {
219                 results[d] = Buffer([]);
220             }
221         }
222     }
223 
224     Buffer[] returnVals = new Buffer[ops.length];
225 
226     foreach(i, o; ops)
227     {
228         returnVals[i] = results[o];
229     }
230 
231     return returnVals;
232 }
233 
private
{
    //Registry of CPU kernels, keyed by operation type name. registerCPUKernel adds entries,
    //deregisterCPUKernel removes them, and evaluateCPU and listAllCPUOperations read it.
    //
    //Module-level variables in D are thread-local unless declared shared or __gshared. The built-in kernels are
    //set up from this module's shared static constructor (presumably via the submodules' initialize functions),
    //which runs only once, in the main thread. A thread-local registry would therefore be empty in every other
    //thread, and evaluateCPU would throw "No CPU kernel registered" there. __gshared makes the single registry
    //visible to all threads. Access is not synchronized, so kernels should be (de)registered before evaluation
    //starts on other threads.
    __gshared CPUKernel[string] mKernels;
}