1 module dopt.cuda.nvrtc;
2 
import std.algorithm;
import std.array;
import std.string;

import derelict.cuda;
import derelict.nvrtc;
9 
/// Loads the NVRTC shared library through DerelictNVRTC.load(). Visible only within the package.
package void initialize()
{
    DerelictNVRTC.load();
}
17 
18 private
19 {
20     void nvrtcCheck(nvrtcResult res, string mod = __MODULE__, size_t line = __LINE__)
21     {
22         import std.conv : to;
23         import std.exception : enforce;
24         enforce(res == NVRTC_SUCCESS,
25             mod ~ "(" ~ line.to!string ~ "): Failed to execute NVRTC function. Error " ~ res.to!string);
26     }
27 }
28 
/**
    Compiles CUDA source at runtime with NVRTC and launches the resulting kernel through the CUDA driver API.
*/
class NVRTCKernel
{
    public
    {
        /**
            Constructs an NVRTCKernel with the given code and entry point.

            The code is compiled to PTX with NVRTC and loaded as a CUDA module. The function named `entry` is
            then looked up in that module. A non-empty NVRTC compilation log is written to stderr, including
            when compilation fails.

            Params:
                entry = The name of the function inside the CUDA code that should be executed.
                code = A string containing the CUDA code to be compiled.
                arch = The virtual architecture passed to NVRTC as `--gpu-architecture=<arch>`.

            Throws:
                Exception if an NVRTC or CUDA driver call does not report success.
        */
        this(string entry, string code, string arch = "compute_35")
        {
            char[] ptx = compileToPTX(entry, code, arch);

            cudaCheck(cuModuleLoadDataEx(&mModule, ptx.ptr, 0, null, null));

            // Do not leak the loaded module if the kernel lookup below fails.
            scope(failure) cuModuleUnload(mModule);

            // NOTE(review): entry must match the symbol name in the compiled PTX. A kernel that is not declared
            // extern "C" presumably has a mangled name and will not be found here. Confirm against callers.
            cudaCheck(cuModuleGetFunction(&mKernel, mModule, entry.toStringz));
        }

        /**
            Launches the kernel on a one-dimensional grid, then synchronizes the current context.

            Params:
                numBlocks = Number of blocks in the x dimension (y and z are 1).
                numThreads = Number of threads per block in the x dimension (y and z are 1).
                args = Kernel arguments. Their addresses are passed to cuLaunchKernel. They are assumed to
                       match the kernel's parameter types; this is not checked here.
        */
        void execute(Args...)(uint numBlocks, uint numThreads, Args args)
        {
            execute([numBlocks, 1, 1], [numThreads, 1, 1], args);
        }

        /**
            Launches the kernel with explicit three-dimensional grid and block sizes, then synchronizes the
            current context.

            Params:
                numBlocks = Grid dimensions (x, y, z).
                numThreads = Block dimensions (x, y, z).
                args = Kernel arguments. Their addresses are passed to cuLaunchKernel.

            Throws:
                Exception if the launch or the synchronization reports an error.
        */
        void execute(Args...)(uint[3] numBlocks, uint[3] numThreads, Args args)
        {
            // One pointer per argument. These point at this function's own parameters, which stay alive for
            // the whole launch because launch() synchronizes before returning.
            void*[Args.length] argPtrs;

            foreach(i, T; Args)
            {
                argPtrs[i] = cast(void *)&args[i];
            }

            launch(numBlocks, numThreads, argPtrs.ptr);
        }

        /**
            Launches the kernel with pre-built argument pointers, then synchronizes the current context.

            Params:
                numBlocks = Grid dimensions (x, y, z).
                numThreads = Block dimensions (x, y, z).
                argPtrs = One pointer per kernel parameter, passed to cuLaunchKernel as kernelParams.

            Throws:
                Exception if the launch or the synchronization reports an error.
        */
        void execute(uint[3] numBlocks, uint[3] numThreads, void*[] argPtrs)
        {
            launch(numBlocks, numThreads, argPtrs.ptr);
        }
    }

    private
    {
        CUmodule mModule;
        CUfunction mKernel;

        /// Launches mKernel with no dynamic shared memory on the null stream, then synchronizes the current context.
        void launch(uint[3] numBlocks, uint[3] numThreads, void** params)
        {
            cudaCheck(cuLaunchKernel(
                mKernel,
                numBlocks[0], numBlocks[1], numBlocks[2],
                numThreads[0], numThreads[1], numThreads[2],
                0, null, params, null
            ));

            // cuLaunchKernel is asynchronous, so errors raised while the kernel runs surface here.
            cudaCheck(cuCtxSynchronize());
        }

        /**
            Compiles `code` to PTX with NVRTC and returns the buffer filled by nvrtcGetPTX.

            The compilation log is fetched and printed before the compilation result is checked, so compiler
            errors are visible even when compilation fails. The NVRTC program is destroyed on every path.
        */
        static char[] compileToPTX(string entry, string code, string arch)
        {
            nvrtcProgram program;

            // The entry name is also passed as the program-name argument of nvrtcCreateProgram.
            nvrtcCreateProgram(&program, code.toStringz, entry.toStringz, 0, null, null).nvrtcCheck;

            try
            {
                auto options = [("--gpu-architecture=" ~ arch).toStringz()];

                // Keep the result instead of checking it immediately: on failure the log holds the errors.
                auto compileResult = nvrtcCompileProgram(program, cast(int)options.length, options.ptr);

                size_t logSize;
                nvrtcGetProgramLogSize(program, &logSize).nvrtcCheck;

                // The reported size counts a trailing terminator, so a size of 1 is an empty log.
                if(logSize > 1)
                {
                    auto log = new char[logSize];
                    nvrtcGetProgramLog(program, log.ptr).nvrtcCheck;

                    import std.stdio : stderr;
                    stderr.writeln(log[0 .. $ - 1]);
                }

                compileResult.nvrtcCheck;

                size_t ptxSize;
                nvrtcGetPTXSize(program, &ptxSize).nvrtcCheck;

                auto ptx = new char[ptxSize];
                nvrtcGetPTX(program, ptx.ptr).nvrtcCheck;

                return ptx;
            }
            finally
            {
                nvrtcDestroyProgram(&program).nvrtcCheck;
            }
        }

        /// Throws an Exception naming the calling module and line when a CUDA driver call did not return CUDA_SUCCESS.
        static void cudaCheck(CUresult res, string mod = __MODULE__, size_t line = __LINE__)
        {
            import std.conv : to;
            import std.exception : enforce;
            enforce(res == CUDA_SUCCESS,
                mod ~ "(" ~ line.to!string ~ "): Failed to execute CUDA function. Error " ~ res.to!string);
        }
    }
}