module dopt.cuda.nnet.cudnn7;

import std.algorithm;
import std.array;
import std.functional;

import dopt.cuda;
import dopt.core.ops;

import derelict.cuda;
import derelict.cudnn7;

package
{
    void initializeCuDNN7()
    {
        DerelictCuDNN7.load();

        registerCUDAKernel("convolution", toDelegate(&cudaKernelCtr!ConvolutionForward));
        registerCUDAKernel("convolutionFeaturesGrad", toDelegate(&cudaKernelCtr!ConvolutionFeaturesGrad));
        registerCUDAKernel("convolutionFiltersGrad", toDelegate(&cudaKernelCtr!ConvolutionFiltersGrad));
        registerCUDAKernel("maxpool", toDelegate(&cudaKernelCtr!MaxpoolForward));
        registerCUDAKernel("maxpoolGrad", toDelegate(&cudaKernelCtr!MaxpoolGrad));
        registerCUDAKernel("softmax", toDelegate(&cudaKernelCtr!Softmax));
        registerCUDAKernel("softmaxGrad", toDelegate(&cudaKernelCtr!SoftmaxGrad));
        registerCUDAKernel("relu", toDelegate(&cudaKernelCtr!ReLU));
        registerCUDAKernel("reluGrad", toDelegate(&cudaKernelCtr!ReLUGrad));
        registerCUDAKernel("addBias", toDelegate(&cudaKernelCtr!AddBias));
        registerCUDAKernel("addBiasGrad", toDelegate(&cudaKernelCtr!AddBiasGrad));
        registerCUDAKernel("batchNormTrain", toDelegate(&cudaKernelCtr!BatchNormTrain));
        registerCUDAKernel("batchNormGrad", toDelegate(&cudaKernelCtr!BatchNormGrad));
        registerCUDAKernel("batchNormInference", toDelegate(&cudaKernelCtr!BatchNormInference));

        cudnnCreate(&handle).cudnnCheck();
    }
}
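
/*
    Each name registered above maps an operation type to a kernel constructor. A rough sketch of
    how the backend is expected to use these registrations (illustrative only; the actual dispatch
    lives elsewhere in dopt.cuda):

        //look up the delegate registered for the operation's name, then construct and run a kernel
        CUDAKernel kernel = cudaKernelCtr!ConvolutionForward(op);
        kernel.execute(inputBuffers, outputBuffer);
*/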

private
{
    cudnnHandle_t handle;

    //Checks the status returned by a cuDNN call and throws, with module/line information, on
    //anything other than CUDNN_STATUS_SUCCESS. Usually invoked via UFCS: someCudnnCall(...).cudnnCheck();
    void cudnnCheck(cudnnStatus_t status, string mod = __MODULE__, size_t line = __LINE__)
    {
        import std.conv : to;
        import std.exception : enforce;
        enforce(status == CUDNN_STATUS_SUCCESS, mod ~ "(" ~ line.to!string ~ "): Failed to execute cuDNN function." ~
            " Error code: " ~ cudnnGetErrorString(status).to!string);
    }

    //Constructs a kernel of type K for the given operation. Instantiations of this template are
    //registered with registerCUDAKernel in initializeCuDNN7.
    CUDAKernel cudaKernelCtr(K)(Operation op)
    {
        return new K(op);
    }

    class ConvolutionBase : CUDAKernel
    {
        this(Operation op, int[] inShape, int[] filterShape, int[] outShape)
        {
            mOp = op;

            int padH = 0;
            int padW = 0;
            int strideY = 1;
            int strideX = 1;

            auto padding = op.attributes["padding"].get!(size_t[]);
            padH = cast(int)padding[0];
            padW = cast(int)padding[1];

            auto stride = op.attributes["stride"].get!(size_t[]);
            strideY = cast(int)stride[0];
            strideX = cast(int)stride[1];

            auto dilation = [1LU, 1LU];
            int dilY = cast(int)dilation[0];
            int dilX = cast(int)dilation[1];

            cudnnCreateTensorDescriptor(&xDesc).cudnnCheck();
            cudnnCreateFilterDescriptor(&wDesc).cudnnCheck();
            cudnnCreateConvolutionDescriptor(&convDesc).cudnnCheck();
            cudnnCreateTensorDescriptor(&yDesc).cudnnCheck();

            cudnnSetTensor4dDescriptor(xDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, inShape[0], inShape[1], inShape[2],
                inShape[3]).cudnnCheck();
            cudnnSetFilter4dDescriptor(wDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, filterShape[0], filterShape[1],
                filterShape[2], filterShape[3]).cudnnCheck();
            cudnnSetConvolution2dDescriptor(convDesc, padH, padW, strideY, strideX, dilY, dilX, CUDNN_CONVOLUTION,
                CUDNN_DATA_FLOAT).cudnnCheck();
            cudnnSetTensor4dDescriptor(yDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, outShape[0], outShape[1],
                outShape[2], outShape[3]).cudnnCheck();
        }

        ~this()
        {
            cudnnDestroyFilterDescriptor(wDesc).cudnnCheck();
            cudnnDestroyTensorDescriptor(yDesc).cudnnCheck();
            cudnnDestroyConvolutionDescriptor(convDesc).cudnnCheck();
            cudnnDestroyTensorDescriptor(xDesc).cudnnCheck();
        }

        abstract void execute(const(CUDABuffer)[] inputs, CUDABuffer output);

        Operation mOp;
        cudnnTensorDescriptor_t xDesc;
        cudnnFilterDescriptor_t wDesc;
        cudnnTensorDescriptor_t bDesc;
        cudnnConvolutionDescriptor_t convDesc;
        cudnnTensorDescriptor_t yDesc;
    }

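    //Workspace buffer shared by all of the convolution kernels below. Each kernel constructor
    //grows it (by destroying and re-creating it) when its chosen algorithm needs more scratch
    //space than is currently allocated, so a single allocation is reused across kernels.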
    private static CUDABuffer mWorkspace;

    class ConvolutionForward : ConvolutionBase
    {
        private cudnnConvolutionFwdAlgo_t mAlgo;

        this(Operation op)
        {
            auto inShape = op.deps[0].outputType.shape.map!(x => cast(int)x).array();
            auto filterShape = op.deps[1].outputType.shape.map!(x => cast(int)x).array();
            auto outShape = op.outputType.shape.map!(x => cast(int)x).array();

            super(op, inShape, filterShape, outShape);

            cudnnConvolutionFwdAlgoPerf_t[9] algoPerfs;
            int numAlgos;
            cudnnFindConvolutionForwardAlgorithm(handle, xDesc, wDesc, convDesc, yDesc, cast(int)algoPerfs.length,
                &numAlgos, algoPerfs.ptr).cudnnCheck();

            if(mWorkspace !is null && algoPerfs[0].memory > mWorkspace.numBytes)
            {
                CUDABuffer.destroy(mWorkspace);
                mWorkspace = null;
            }

            if(mWorkspace is null)
            {
                mWorkspace = CUDABuffer.create(algoPerfs[0].memory);
            }

            mAlgo = algoPerfs[0].algo;
        }

        override void execute(const(CUDABuffer)[] inputs, CUDABuffer output)
        {
            auto x = cast(void *)inputs[0].ptr;
            auto w = cast(void *)inputs[1].ptr;
            auto y = cast(void *)output.ptr;
            float alpha = 1;
            float beta = 0;

            auto ws = cast(void *)(mWorkspace.ptr);
            auto wss = mWorkspace.numBytes;
            cudnnConvolutionForward(handle, &alpha, xDesc, x, wDesc, w, convDesc, mAlgo, ws, wss, &beta, yDesc, y)
            .cudnnCheck();

            cuCtxSynchronize();
        }
    }

    class ConvolutionFeaturesGrad : ConvolutionBase
    {
        private cudnnConvolutionBwdDataAlgo_t mAlgo;

        this(Operation op)
        {
            auto inShape = op.shape.map!(x => cast(int)x).array();
            auto filterShape = op.deps[1].shape.map!(x => cast(int)x).array();
            auto outShape = op.deps[0].shape.map!(x => cast(int)x).array();

            super(op, inShape, filterShape, outShape);

            cudnnConvolutionBwdDataAlgoPerf_t[9] algoPerfs;
            int numAlgos;
            cudnnFindConvolutionBackwardDataAlgorithm(handle, wDesc, yDesc, convDesc, xDesc, cast(int)algoPerfs.length,
                &numAlgos, algoPerfs.ptr).cudnnCheck();

            if(mWorkspace !is null && algoPerfs[0].memory > mWorkspace.numBytes)
            {
                CUDABuffer.destroy(mWorkspace);
                mWorkspace = null;
            }

            if(mWorkspace is null)
            {
                mWorkspace = CUDABuffer.create(algoPerfs[0].memory);
            }

            mAlgo = algoPerfs[0].algo;
        }

        override void execute(const(CUDABuffer)[] inputs, CUDABuffer output)
        {
            auto w = cast(void *)inputs[1].ptr;
            auto dy = cast(void *)inputs[0].ptr;
            auto dx = cast(void *)output.ptr;
            float alpha = 1;
            float beta = 0;

            cudnnConvolutionBackwardData(handle, &alpha, wDesc, w, yDesc, dy, convDesc, mAlgo,
                cast(void *)mWorkspace.ptr, mWorkspace.numBytes, &beta, xDesc, dx).cudnnCheck();

            cuCtxSynchronize();
        }
    }

    class ConvolutionFiltersGrad : ConvolutionBase
    {
        private cudnnConvolutionBwdFilterAlgo_t mAlgo;

        this(Operation op)
        {
            auto inShape = op.deps[1].outputType.shape.map!(x => cast(int)x).array();
            auto filterShape = op.outputType.shape.map!(x => cast(int)x).array();
            auto outShape = op.deps[0].outputType.shape.map!(x => cast(int)x).array();

            super(op, inShape, filterShape, outShape);

            cudnnConvolutionBwdFilterAlgoPerf_t[9] algoPerfs;
            int numAlgos;
            cudnnFindConvolutionBackwardFilterAlgorithm(handle, xDesc, yDesc, convDesc, wDesc, cast(int)algoPerfs.length,
                &numAlgos, algoPerfs.ptr).cudnnCheck();

            if(mWorkspace !is null && algoPerfs[0].memory > mWorkspace.numBytes)
            {
                CUDABuffer.destroy(mWorkspace);
                mWorkspace = null;
            }

            if(mWorkspace is null)
            {
                mWorkspace = CUDABuffer.create(algoPerfs[0].memory);
            }

            mAlgo = algoPerfs[0].algo;
        }

        override void execute(const(CUDABuffer)[] inputs, CUDABuffer output)
        {
            auto x = cast(void *)inputs[1].ptr;
            auto dy = cast(void *)inputs[0].ptr;
            auto dw = cast(void *)output.ptr;
            float alpha = 1;
            float beta = 0;

            cudnnConvolutionBackwardFilter(handle, &alpha, xDesc, x, yDesc, dy, convDesc, mAlgo,
                cast(void *)mWorkspace.ptr, mWorkspace.numBytes, &beta, wDesc, dw).cudnnCheck();

            cuCtxSynchronize();
        }
    }

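    //Pooling kernels. The window size comes from the operation's "dims" attribute; the stride is
    //set equal to the window (non-overlapping pooling) and no padding is applied.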
    class MaxpoolBase : CUDAKernel
    {
        this(Operation op, int[] inShape, int[] outShape)
        {
            auto dims = op.attributes["dims"].get!(size_t[]);
            auto poolShape = dims.map!(x => cast(int)x).array();
            auto poolStride = poolShape.dup;

            cudnnCreatePoolingDescriptor(&poolingDesc).cudnnCheck();
            cudnnSetPooling2dDescriptor(poolingDesc, CUDNN_POOLING_MAX, CUDNN_PROPAGATE_NAN, cast(int)poolShape[0],
                cast(int)poolShape[1], 0, 0, cast(int)poolStride[0], cast(int)poolStride[1]).cudnnCheck();

            cudnnCreateTensorDescriptor(&xDesc).cudnnCheck();
            cudnnCreateTensorDescriptor(&yDesc).cudnnCheck();
            cudnnSetTensor4dDescriptor(xDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, inShape[0], inShape[1], inShape[2],
                inShape[3]).cudnnCheck();
            cudnnSetTensor4dDescriptor(yDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, outShape[0], outShape[1],
                outShape[2], outShape[3]).cudnnCheck();
        }

        ~this()
        {
            cudnnDestroyPoolingDescriptor(poolingDesc).cudnnCheck();
            cudnnDestroyTensorDescriptor(xDesc).cudnnCheck();
            cudnnDestroyTensorDescriptor(yDesc).cudnnCheck();
        }

        abstract void execute(const(CUDABuffer)[] inputs, CUDABuffer output);

        cudnnPoolingDescriptor_t poolingDesc;
        cudnnTensorDescriptor_t xDesc;
        cudnnTensorDescriptor_t yDesc;
    }

    class MaxpoolForward : MaxpoolBase
    {
        this(Operation op)
        {
            auto inShape = op.deps[0].outputType.shape.map!(x => cast(int)x).array();
            auto outShape = op.outputType.shape.map!(x => cast(int)x).array();

            super(op, inShape, outShape);
        }

        override void execute(const(CUDABuffer)[] inputs, CUDABuffer output)
        {
            auto x = cast(void *)inputs[0].ptr;
            auto y = cast(void *)output.ptr;
            float alpha = 1;
            float beta = 0;

            cudnnPoolingForward(handle, poolingDesc, &alpha, xDesc, x, &beta, yDesc, y).cudnnCheck();

            cuCtxSynchronize();
        }
    }

    class MaxpoolGrad : MaxpoolBase
    {
        this(Operation op)
        {
            auto inShape = op.deps[2].outputType.shape.map!(x => cast(int)x).array();
            auto outShape = op.deps[1].outputType.shape.map!(x => cast(int)x).array();

            super(op, inShape, outShape);
        }

        override void execute(const(CUDABuffer)[] inputs, CUDABuffer output)
        {
            auto dx = cast(void *)output.ptr;
            auto dy = cast(void *)inputs[0].ptr;
            auto y = cast(void *)inputs[1].ptr;
            auto x = cast(void *)inputs[2].ptr;
            float alpha = 1;
            float beta = 0;

            cudnnPoolingBackward(handle, poolingDesc, &alpha, yDesc, y, yDesc, dy, xDesc, x, &beta, xDesc, dx)
            .cudnnCheck();

            cuCtxSynchronize();
        }
    }

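    //The softmax kernels collapse any trailing dimensions into one, so an input of shape
    //[N, C, d1, d2, ...] is described to cuDNN as N x C x (d1*d2*...) x 1 and the softmax is taken
    //over the channel dimension at every remaining location (CUDNN_SOFTMAX_MODE_CHANNEL).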
    class Softmax : CUDAKernel
    {
        this(Operation op)
        {
            auto shape = op.shape.map!(x => cast(int)x).array();
            auto vol = 1;

            for(size_t i = 2; i < shape.length; i++)
            {
                vol *= shape[i];
            }

            cudnnCreateTensorDescriptor(&desc).cudnnCheck();
            cudnnSetTensor4dDescriptor(desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, shape[0], shape[1], vol, 1)
            .cudnnCheck();
        }

        ~this()
        {
            cudnnDestroyTensorDescriptor(desc).cudnnCheck();
        }

        override void execute(const(CUDABuffer)[] inputs, CUDABuffer output)
        {
            float alpha = 1.0;
            float beta = 0.0;
            auto x = cast(void *)inputs[0].ptr;
            auto y = cast(void *)output.ptr;

            cudnnSoftmaxForward(handle, CUDNN_SOFTMAX_ACCURATE, CUDNN_SOFTMAX_MODE_CHANNEL, &alpha, desc, x, &beta,
                desc, y).cudnnCheck();

            cuCtxSynchronize();
        }

        cudnnTensorDescriptor_t desc;
    }

    class SoftmaxGrad : CUDAKernel
    {
        this(Operation op)
        {
            auto shape = op.shape.map!(x => cast(int)x).array();

            cudnnCreateTensorDescriptor(&desc).cudnnCheck();
            cudnnSetTensor4dDescriptor(desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, shape[0], shape[1],
                reduce!"a * b"(1, shape[2 .. $]), 1).cudnnCheck();
        }

        ~this()
        {
            cudnnDestroyTensorDescriptor(desc).cudnnCheck();
        }

        override void execute(const(CUDABuffer)[] inputs, CUDABuffer output)
        {
            float alpha = 1.0;
            float beta = 0.0;
            auto dy = cast(void *)inputs[0].ptr;
            auto y = cast(void *)inputs[1].ptr;
            auto dx = cast(void *)output.ptr;

            cudnnSoftmaxBackward(handle, CUDNN_SOFTMAX_ACCURATE, CUDNN_SOFTMAX_MODE_CHANNEL, &alpha, desc, y, desc, dy,
                &beta, desc, dx).cudnnCheck();

            cuCtxSynchronize();
        }

        cudnnTensorDescriptor_t desc;
    }

    class ReLU : CUDAKernel
    {
        this(Operation op)
        {
            auto shape = op.shape.map!(x => cast(int)x).array();

            cudnnCreateTensorDescriptor(&desc).cudnnCheck();
            cudnnSetTensor4dDescriptor(desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, shape[0], shape[1],
                reduce!"a * b"(1, shape[2 .. $]), 1).cudnnCheck();

            cudnnCreateActivationDescriptor(&actDesc).cudnnCheck();
            cudnnSetActivationDescriptor(actDesc, CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0.0).cudnnCheck();
        }

        ~this()
        {
            cudnnDestroyTensorDescriptor(desc).cudnnCheck();
            cudnnDestroyActivationDescriptor(actDesc).cudnnCheck();
        }

        override void execute(const(CUDABuffer)[] inputs, CUDABuffer output)
        {
            float alpha = 1.0;
            float beta = 0.0;
            auto x = cast(void *)inputs[0].ptr;
            auto y = cast(void *)output.ptr;

            cudnnActivationForward(handle, actDesc, &alpha, desc, x, &beta, desc, y).cudnnCheck();

            cuCtxSynchronize();
        }

        cudnnTensorDescriptor_t desc;
        cudnnActivationDescriptor_t actDesc;
    }

    class ReLUGrad : CUDAKernel
    {
        this(Operation op)
        {
            auto shape = op.shape.map!(x => cast(int)x).array();

            cudnnCreateTensorDescriptor(&desc).cudnnCheck();
            cudnnSetTensor4dDescriptor(desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, shape[0], shape[1],
                reduce!"a * b"(1, shape[2 .. $]), 1).cudnnCheck();

            cudnnCreateActivationDescriptor(&actDesc).cudnnCheck();
            cudnnSetActivationDescriptor(actDesc, CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0.0).cudnnCheck();
        }

        ~this()
        {
            cudnnDestroyTensorDescriptor(desc).cudnnCheck();
            cudnnDestroyActivationDescriptor(actDesc).cudnnCheck();
        }

        override void execute(const(CUDABuffer)[] inputs, CUDABuffer output)
        {
            float alpha = 1.0;
            float beta = 0.0;
            auto dy = cast(void *)inputs[0].ptr;
            auto y = cast(void *)inputs[1].ptr;
            auto x = cast(void *)inputs[2].ptr;
            auto dx = cast(void *)output.ptr;

            cudnnActivationBackward(handle, actDesc, &alpha, desc, y, desc, dy, desc, x, &beta, desc, dx).cudnnCheck();

            cuCtxSynchronize();
        }

        cudnnTensorDescriptor_t desc;
        cudnnActivationDescriptor_t actDesc;
    }

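    //AddBias copies its first input to the output and then adds the second input (a per-channel
    //bias described as 1 x C x 1 x 1) in place, relying on cuDNN to broadcast the bias over the
    //batch and spatial dimensions.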
    class AddBias : CUDAKernel
    {
        this(Operation op)
        {
            auto shape = op.shape.map!(x => cast(int)x).array();

            cudnnCreateTensorDescriptor(&cDesc).cudnnCheck();
            cudnnSetTensor4dDescriptor(cDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, shape[0], shape[1],
                reduce!"a * b"(1, shape[2 .. $]), 1).cudnnCheck();

            cudnnCreateTensorDescriptor(&aDesc).cudnnCheck();
            cudnnSetTensor4dDescriptor(aDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, shape[1], 1, 1).cudnnCheck();
        }

        ~this()
        {
            cudnnDestroyTensorDescriptor(cDesc).cudnnCheck();
            cudnnDestroyTensorDescriptor(aDesc).cudnnCheck();
        }

        override void execute(const(CUDABuffer)[] inputs, CUDABuffer output)
        {
            cuMemcpy(output.ptr, inputs[0].ptr, output.numBytes);

            float alpha = 1;
            float beta = 1;

            cudnnAddTensor(handle, &alpha, aDesc, cast(void *)inputs[1].ptr, &beta, cDesc, cast(void *)output.ptr)
            .cudnnCheck();
        }

        cudnnTensorDescriptor_t cDesc;
        cudnnTensorDescriptor_t aDesc;
    }

    class AddBiasGrad : CUDAKernel
    {
        this(Operation op)
        {
            auto shape = op.deps[0].shape.map!(x => cast(int)x).array();

            cudnnCreateTensorDescriptor(&dyDesc).cudnnCheck();
            cudnnSetTensor4dDescriptor(dyDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, shape[0], shape[1],
                reduce!"a * b"(1, shape[2 .. $]), 1).cudnnCheck();

            cudnnCreateTensorDescriptor(&dbDesc).cudnnCheck();
            cudnnSetTensor4dDescriptor(dbDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, shape[1], 1, 1).cudnnCheck();
        }

        ~this()
        {
            cudnnDestroyTensorDescriptor(dyDesc).cudnnCheck();
            cudnnDestroyTensorDescriptor(dbDesc).cudnnCheck();
        }

        override void execute(const(CUDABuffer)[] inputs, CUDABuffer output)
        {
            float alpha = 1.0f;
            float beta = 1.0f;

            cudnnConvolutionBackwardBias(handle, &alpha, dyDesc, cast(void *)inputs[0].ptr, &beta, dbDesc,
                cast(void *)output.ptr).cudnnCheck();
        }

        cudnnTensorDescriptor_t dyDesc;
        cudnnTensorDescriptor_t dbDesc;
    }

    abstract class BatchNormBase : CUDAKernel
    {
        this(Operation op)
        {
            //Rank-2 operations are treated as fully connected activations; anything of higher rank
            //is treated as a convolutional feature map.
            if(op.rank == 2)
            {
                mode = CUDNN_BATCHNORM_PER_ACTIVATION;
            }
            else
            {
                mode = CUDNN_BATCHNORM_SPATIAL;
            }

            import std.range;

            //Pad the shape with trailing 1s so it can always be described as a 4D NCHW tensor.
            auto shape = op.deps[0].shape
                        .chain(repeat(1))
                        .map!(x => cast(int)x)
                        .take(4)
                        .array();

            cudnnCreateTensorDescriptor(&xDesc).cudnnCheck();
            cudnnCreateTensorDescriptor(&bnDesc).cudnnCheck();

            cudnnSetTensor4dDescriptor(xDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, shape[0], shape[1], shape[2],
                shape[3]).cudnnCheck();
            cudnnDeriveBNTensorDescriptor(bnDesc, xDesc, mode).cudnnCheck();
        }

        ~this()
        {
            cudnnDestroyTensorDescriptor(xDesc).cudnnCheck();
            cudnnDestroyTensorDescriptor(bnDesc).cudnnCheck();
        }

        cudnnBatchNormMode_t mode;
        cudnnTensorDescriptor_t xDesc;
        cudnnTensorDescriptor_t bnDesc;
    }

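    //Training-mode batch normalisation. Judging by how the buffers are used below, the inputs are
    //(x, scale, bias, running mean, running variance), and the output buffer packs the normalised
    //result followed by the updated running mean and variance, which the higher level API is
    //expected to slice back out into separate nodes.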
    class BatchNormTrain : BatchNormBase
    {
        this(Operation op)
        {
            super(op);

            //cuDNN's exponentialAverageFactor is the weight given to the new statistics, so it is
            //the complement of the momentum attribute.
            mMomentum = 1.0 - op.attributes["momentum"].get!double;
        }

        void execute(const(CUDABuffer)[] inputs, CUDABuffer output)
        {
            float alpha = 1.0f;
            float beta = 0.0f;

            //We're going to pack the running mean/variance after the BN forward prop. Let the higher level
            //API slice them out into different nodes.
            auto mean = output.ptr + inputs[0].numBytes;
            auto var = mean + (output.numBytes - inputs[0].numBytes) / 2;

            cuMemcpy(mean, inputs[3].ptr, inputs[3].numBytes);
            cuMemcpy(var, inputs[4].ptr, inputs[4].numBytes);

            cudnnBatchNormalizationForwardTraining(handle, mode, &alpha, &beta, xDesc,
                cast(void *)inputs[0].ptr, xDesc, cast(void *)output.ptr, bnDesc, cast(void *)inputs[1].ptr,
                cast(void *)inputs[2].ptr, mMomentum, cast(void *)mean, cast(void *)var, 1e-5f, null, null).cudnnCheck();
        }

        double mMomentum;
    }

    class BatchNormGrad : BatchNormBase
    {
        this(Operation op)
        {
            super(op);
        }

        void execute(const(CUDABuffer)[] inputs, CUDABuffer output)
        {
            float alpha = 1.0f;
            float beta = 0.0f;

            //The output buffer packs the gradients with respect to the input, scale, and bias.
            void *dx = cast(void *)(output.ptr);
            void *dscale = cast(void *)(output.ptr + inputs[1].numBytes);
            void *dbias = cast(void *)(output.ptr + inputs[1].numBytes + inputs[2].numBytes);

            cudnnBatchNormalizationBackward(handle, mode, &alpha, &beta, &alpha, &beta, xDesc,
                cast(void *)inputs[1].ptr, xDesc, cast(void *)inputs[0].ptr, xDesc, dx, bnDesc,
                cast(void *)inputs[2].ptr, dscale, dbias, 1e-5f, null, null).cudnnCheck();
        }
    }

    class BatchNormInference : BatchNormBase
    {
        this(Operation op)
        {
            super(op);
        }

        void execute(const(CUDABuffer)[] inputs, CUDABuffer output)
        {
            float alpha = 1.0f;
            float beta = 0.0f;

            cudnnBatchNormalizationForwardInference(handle, mode, &alpha, &beta, xDesc, cast(void *)inputs[0].ptr,
                xDesc, cast(void *)output.ptr, bnDesc, cast(void *)inputs[1].ptr, cast(void *)inputs[2].ptr,
                cast(void *)inputs[3].ptr, cast(void *)inputs[4].ptr, 1e-5).cudnnCheck();
        }
    }
}