1 /**
2     This module enables operation graphs to be evaluated using CPU kernels.
3 
4     Authors: Henry Gouk
5 */
6 module dopt.core.cpu;
7 
8 import std.exception;
9 
10 import dopt.core;
11 
/**
    Initializes the CPU backend by running the initialization routines of the basic and math kernel modules,
    in that order.
*/
void initialize()
{
    import dopt.core.cpu.basic : initBasicKernels = initialize;
    import dopt.core.cpu.math : initMathKernels = initialize;

    initBasicKernels();
    initMathKernels();
}
20 
/**
    Common interface for all CPU kernels.

    A CPU kernel computes the value of a single operation in the graph. Kernels are registered under an
    operation type name with $(D registerCPUKernel) and are looked up by that name in $(D evaluateCPU).
*/
interface CPUKernel
{
    /**
        Computes the value of one operation.

        Params:
            op = The operation being evaluated.
            inputs = The values of the operation's dependencies. $(D evaluateCPU) passes them in the same order
                as $(D op.deps).
            output = The buffer that receives the result. $(D evaluateCPU) allocates it with
                $(D op.outputType.volume * op.outputType.elementType.sizeOf()) bytes before calling this method.
    */
    void execute(Operation op, const(Buffer)[] inputs, Buffer output);
}
28 
/**
    Adapter that lets an ordinary delegate act as a $(D CPUKernel).

    Every call to $(D execute) is forwarded, with the same arguments, to the delegate given at construction.
*/
class CPUKernelDelegate : CPUKernel
{
    /// The delegate that $(D execute) forwards to.
    private void delegate(Operation, const(Buffer)[], Buffer) mKernel;

    /**
        Params:
            kern = The delegate that will perform the computation.
    */
    public this(void delegate(Operation, const(Buffer)[], Buffer) kern)
    {
        this.mKernel = kern;
    }

    /// Invokes the wrapped delegate with $(D op), $(D inputs) and $(D output).
    public void execute(Operation op, const(Buffer)[] inputs, Buffer output)
    {
        this.mKernel(op, inputs, output);
    }
}
52 
/**
    Registers a kernel for the specified operation.

    Params:
        opName = The name of the operation.
        kernel = A kernel that can execute operations of the type specified by opName. Must not be null.

    Throws:
        An $(D Exception) if there is already a kernel registered for the operation, or if $(D kernel) is null.
*/
void registerCPUKernel(string opName, CPUKernel kernel)
{
    enforce((opName in mKernels) is null, "A CPUKernel is already registered for the operation '" ~ opName ~ "'");

    //evaluateCPU treats a null entry as "no kernel" and throws when it reaches such an operation. A null entry
    //for "constant", "variable" or "reshape" would also disable their built-in handling. Reject it up front.
    enforce(kernel !is null, "Cannot register a null CPUKernel for the operation '" ~ opName ~ "'");

    mKernels[opName] = kernel;
}
69 
/**
    Removes whichever kernel is currently registered for the given operation.

    Nothing happens if no kernel is registered under that name.

    Params:
        opName = The name of the operation whose kernel should be removed.
*/
void deregisterCPUKernel(string opName)
{
    //The AA reports whether anything was removed; an absent key is not an error here
    cast(void) mKernels.remove(opName);
}
80 
/**
    Provides a list of operations that can be evaluated on the CPU.

    The list contains every operation with a registered CPUKernel, plus "constant", "variable" and "reshape",
    which $(D evaluateCPU) handles itself when no kernel is registered for them. Each name appears at most once.

    Returns:
        A newly allocated array of operation names.
*/
string[] listAllCPUOperations()
{
    //.keys already returns a freshly allocated array, so no additional copy is needed
    string[] names = mKernels.keys;

    //Operations with built-in handling; skip any that are already present as registered kernels
    foreach(builtin; ["constant", "variable", "reshape"])
    {
        if(builtin !in mKernels)
        {
            names ~= builtin;
        }
    }

    return names;
}
91 
/**
    Evaluates several nodes from the operation graph using the CPU.

    If the elements of $(D ops) have common dependencies, then each dependency is evaluated only once. For this
    reason it is recommended that this overload is used when multiple nodes should be evaluated.

    When no kernel is registered for them, these operations get built-in handling:
    - "constant" and "variable" take the value of their "default" attribute.
    - "reshape" reuses the buffer of its first dependency.

    Intermediate results are released as soon as no remaining operation needs them. The results of the
    operations requested in $(D ops) are always kept.

    Params:
        ops = The nodes of the operation graph that values should be computed for.
        args = A set of variable assignments. Operations present in $(D args) are not evaluated; the supplied
            buffer is used as their value instead.

    Returns:
        An array of $(D Buffer) objects, each containing the value of the corresponding element in $(D ops).

    Throws:
        An $(D Exception) if an operation needs to be evaluated but no CPU kernel is registered for it.
*/
Buffer[] evaluateCPU(Operation[] ops, Buffer[Operation] args = null)
{
    import std.algorithm : filter;
    import std.array : array;

    //Toposort the operations by dependency, skipping any whose values were supplied in args.
    //An AA lookup avoids rebuilding args.keys for every operation.
    Operation[] sortedOps = topologicalSort(ops)
                           .filter!(x => x !in args)
                           .array();

    //Count the number of references to each operation
    int[Operation] refCounts;

    foreach(o; sortedOps)
    {
        foreach(d; o.deps)
        {
            refCounts[d]++;
        }
    }

    //The requested results are returned at the end, so give each one an extra reference that is never
    //released. Otherwise a requested op that another requested op depends on would be cleared
    //before it is returned.
    foreach(o; ops)
    {
        refCounts[o]++;
    }

    //Start executing the operations
    Buffer[Operation] results = args.dup;

    foreach(o; sortedOps)
    {
        //Check for some easy optimizations
        if(o.opType == "variable" && "variable" !in mKernels)
        {
            results[o] = cast(Buffer)o.attributes["default"].get!Buffer;
        }
        else if(o.opType == "constant" && "constant" !in mKernels)
        {
            results[o] = cast(Buffer)o.attributes["default"].get!Buffer;
        }
        else if(o.opType == "reshape" && "reshape" !in mKernels)
        {
            //Reuse the dependency's buffer rather than copying it
            results[o] = results[o.deps[0]];
        }
        else
        {
            //Look up the kernel before allocating anything for this operation
            auto kern = mKernels.get(o.opType, null);

            if(kern is null)
            {
                throw new Exception("No CPU kernel registered for operation " ~ o.opType);
            }

            //Allocate a buffer for the output of this operation
            auto output = Buffer(new ubyte[o.outputType.volume * o.outputType.elementType.sizeOf()]);

            //Gather the input buffers in dependency order
            Buffer[] inputs;

            foreach(d; o.deps)
            {
                inputs ~= results[d];
            }

            kern.execute(o, inputs, output);
            results[o] = output;
        }

        //Release this operation's references to its dependencies. When nothing else needs a buffer,
        //drop our pointer to it so the GC can collect it at some point, if required.
        //This runs for every operation, including the built-in fast paths.
        foreach(d; o.deps)
        {
            if(--refCounts[d] == 0)
            {
                results[d] = Buffer([]);
            }
        }
    }

    Buffer[] returnVals = new Buffer[ops.length];

    foreach(i, o; ops)
    {
        returnVals[i] = results[o];
    }

    return returnVals;
}
194 
/*
    Registry mapping operation type names to the CPUKernel that evaluates them.

    Because this module-level variable is neither shared nor __gshared, D stores it in thread-local storage.
    Each thread therefore has its own registry.
*/
private CPUKernel[string] mKernels;