1 module dopt.cpu.nnet;
2 
3 import dopt.core;
4 import dopt.cpu;
5 
6 package
7 {
8     void initialize()
9     {
10         import std.functional : toDelegate;
11 
12         registerCPUKernel("convolution", new CPUKernelDelegate(toDelegate(&convolution)));
13         registerCPUKernel("maxpool", new CPUKernelDelegate(toDelegate(&maxpool)));
14         registerCPUKernel("softmax", new CPUKernelDelegate(toDelegate(&softmax)));
15     }
16 }
17 
18 private
19 {
20     void convolution(Operation op, const(void[])[] inputs, void[] output)
21     {
22         size_t[] inDims = op.deps[0].shape[2 .. $];
23         size_t[] outDims = op.shape[2 .. $];
24         size_t[] kernDims = op.deps[1].shape[2 .. $];
25         size_t[] padding = op.attributes["padding"].get!(size_t[]);
26         size_t[] stride = op.attributes["stride"].get!(size_t[]);
27         size_t numOutpus = op.deps[1].shape[0];
28         size_t numInputs = op.deps[1].shape[1];
29         size_t batchSize = op.shape[0];
30         size_t inVol = inDims[0] * inDims[1];
31         size_t outVol = outDims[0] * outDims[1];
32         size_t kernVol = kernDims[0] * kernDims[1];
33 
34         void conv2d(const(float)[] inimg, const(float)[] kern, float[] outimg)
35         {
36             size_t outidx = 0;
37 
38             for(size_t y = 0; y < outDims[0]; y++)
39             {
40                 for(size_t x = 0; x < outDims[1]; x++)
41                 {
42                     float outval = 0;
43 
44                     for(size_t j = 0; j < kernDims[0]; j++)
45                     {
46                         size_t jprime = kernDims[0] - j - 1;
47 
48                         for(size_t i = 0; i < kernDims[1]; i++)
49                         {
50                             size_t iprime = kernDims[1] - i - 1;
51                             ptrdiff_t iny = y * stride[0] - padding[0] + j;
52                             ptrdiff_t inx = x * stride[1] - padding[1] + i;
53 
54                             if(0 <= iny && iny < outDims[0] && 0 <= inx && inx < outDims[1])
55                             {
56                                 outval += kern[jprime * kernDims[1] + iprime] * inimg[iny * inDims[1] + inx];
57                             }
58                         }
59                     }
60 
61                     outimg[outidx] += outval;
62                     outidx++;
63                 }
64             }
65         }
66 
67         auto inbuf = cast(const(float[]))inputs[0];
68         auto kernbuf = cast(const(float[]))inputs[1];
69         auto outbuf = cast(float[])output;
70 
71         for(size_t b = 0; b < batchSize; b++)
72         {
73             for(size_t o = 0; o < numOutpus; o++)
74             {
75                 float[] outimg = outbuf[(b * numOutpus + o) * outVol .. (b * numOutpus + o + 1) * outVol];
76                 outimg[] = 0;
77 
78                 for(size_t i = 0; i < numInputs; i++)
79                 {
80                     const(float)[] inimg = inbuf[(b * numInputs + i) * inVol .. (b * numInputs + i + 1) * inVol];
81                     const(float)[] kern = kernbuf[(o * numInputs + i) * kernVol .. (o * numInputs + i + 1) * kernVol];
82 
83                     conv2d(inimg, kern, outimg);
84                 }
85             }
86         }
87     }
88 
89     void maxpool(Operation op, const(void[])[] inputs, void[] output)
90     {
91         size_t[] poolDims = op.attributes["dims"].get!(size_t[]);
92         size_t[] inDims = op.deps[0].shape[2 .. $];
93         size_t[] outDims = op.shape[2 .. $];
94         size_t numMaps = op.shape[0] * op.shape[1];
95         size_t inVol = inDims[0] * inDims[1];
96         size_t outVol = outDims[0] * outDims[1];
97 
98         void pool(const(float)[] inimg, float[] outimg)
99         {
100             for(size_t y = 0; y < outDims[0]; y++)
101             {
102                 for(size_t x = 0; x < outDims[1]; x++)
103                 {
104                     float maxval = -float.max;
105 
106                     for(size_t j = 0; j < poolDims[0]; j++)
107                     {
108                         for(size_t i = 0; i < poolDims[1]; i++)
109                         {
110                             import std.algorithm : max;
111                             maxval = max(maxval, inimg[(y * poolDims[0] + j) * inDims[1] + x * poolDims[1] + i]);
112                         }
113                     }
114 
115                     outimg[y * outDims[1] + x] = maxval;
116                 }
117             }
118         }
119 
120         float[] outbuf = cast(float[])output;
121         const(float)[] inbuf = cast(const(float)[])inputs[0];
122 
123         for(size_t i = 0; i < numMaps; i++)
124         {
125             pool(inbuf[i * inVol .. (i + 1) * inVol], outbuf[i * outVol .. (i + 1) * outVol]);
126         }
127     }
128 
129     void softmax(Operation op, const(void[])[] inputs, void[] output)
130     {
131         const(float)[] inbuf = cast(const(float[]))inputs[0];
132         float[] outbuf = cast(float[])output;
133 
134         size_t elvol = op.volume / (op.shape[0] * op.shape[1]);
135 
136         for(size_t b = 0; b < op.shape[0]; b++)
137         {
138             for(size_t i = 0; i < elvol; i++)
139             {
140                 float m = -float.max;
141 
142                 for(size_t o = 0; o < op.shape[1]; o++)
143                 {
144                     import std.algorithm : max;
145 
146                     m = max(m, inbuf[b * op.shape[1] * elvol + o * elvol + i]);
147                 }
148 
149                 float s = 0;
150 
151                 for(size_t o = 0; o < op.shape[1]; o++)
152                 {
153                     import std.math : exp;
154 
155                     float pot = exp(inbuf[b * op.shape[1] * elvol + o * elvol + i] - m);
156                     s += pot;
157                     outbuf[b * op.shape[1] * elvol + o * elvol + i] = pot;
158                 }
159 
160                 for(size_t o = 0; o < op.shape[1]; o++)
161                 {
162                     outbuf[b * op.shape[1] * elvol + o * elvol + i] /= s;
163                 }
164             }
165         }
166     }
167 }