1 module dopt.cuda.nvrtc;
2 
3 import std.algorithm;
4 import std.array;
5 import std..string;
6 
7 import derelict.cuda;
8 import derelict.nvrtc;
9 
/**
    Loads the NVRTC shared library through DerelictNVRTC.

    NOTE(review): presumably this has to run before any nvrtc* function in this
    module is called, since those are bound by the loader — confirm with callers.
*/
package void initialize()
{
    DerelictNVRTC.load();
}
17 
private
{
    /**
        Throws if an NVRTC call did not succeed.

        The exception message names the calling module and line (captured through
        the default arguments at the call site) and includes NVRTC's own textual
        description of the failing result code.

        Params:
            res  = The result code returned by an NVRTC function.
            mod  = Module of the call site; defaults to the caller's `__MODULE__`.
            line = Line of the call site; defaults to the caller's `__LINE__`.

        Throws:
            `Exception` when `res != NVRTC_SUCCESS`.
    */
    void nvrtcCheck(nvrtcResult res, string mod = __MODULE__, size_t line = __LINE__)
    {
        import std.conv : to;
        import std.exception : enforce;
        import std.string : fromStringz;

        // enforce takes its message lazily, so the error string is only looked up on failure.
        enforce(res == NVRTC_SUCCESS,
            mod ~ "(" ~ line.to!string ~ "): Failed to execute NVRTC function ("
                ~ nvrtcGetErrorString(res).fromStringz ~ ").");
    }
}
27 
/**
    A CUDA kernel compiled at runtime with NVRTC and launched through the CUDA driver API.

    NOTE(review): assumes the CUDA driver bindings are loaded and a CUDA context is
    current on the calling thread when this class is used — confirm with callers.
    The loaded CUmodule is never unloaded by this class.
*/
class NVRTCKernel
{
    public
    {
        /**
            Constructs an NVRTCKernel by compiling the given code and looking up the entry point.

            Any non-empty compilation log is written to stderr, including when
            compilation fails, so compiler diagnostics are always visible.

            Params:
                entry = The name of the function inside the CUDA code that should be executed.
                code  = A string containing the CUDA code to be compiled.
                arch  = Virtual architecture passed to NVRTC as `--gpu-architecture=<arch>`.
                        Defaults to `"compute_35"`.

            Throws:
                `Exception` if any NVRTC or CUDA driver call reports an error,
                including a failed compilation.
        */
        this(string entry, string code, string arch = "compute_35")
        {
            immutable(char)* entryz = entry.toStringz;
            immutable(char)* codez = code.toStringz;
            auto options = [("--gpu-architecture=" ~ arch).toStringz()];

            nvrtcProgram program;
            nvrtcCreateProgram(&program, codez, entryz, 0, null, null).nvrtcCheck;

            // Release the program on every exit path, including a failed compilation.
            // The result is deliberately ignored so cleanup cannot mask an earlier error.
            scope(exit) nvrtcDestroyProgram(&program);

            // Print the log before checking the compile result: on a compilation error the
            // log holds the compiler diagnostics, which would be lost if we threw first.
            immutable compileResult = nvrtcCompileProgram(program, cast(int)options.length, options.ptr);
            printLog(program);
            compileResult.nvrtcCheck;

            auto ptx = getPTX(program);

            cuCheck(cuModuleLoadDataEx(&mModule, ptx.ptr, 0, null, null));
            cuCheck(cuModuleGetFunction(&mKernel, mModule, entryz));
        }

        /**
            Launches the kernel on a 1-D grid of 1-D blocks and waits for the context to finish.

            Params:
                numBlocks  = Number of blocks in the x dimension.
                numThreads = Number of threads per block in the x dimension.
                args       = Values passed to the kernel, in parameter order.
        */
        void execute(Args...)(uint numBlocks, uint numThreads, Args args)
        {
            execute([numBlocks, 1, 1], [numThreads, 1, 1], args);
        }

        /**
            Launches the kernel with 3-D grid/block dimensions and waits for the context to finish.

            The address of each element of `args` is handed to cuLaunchKernel, so each
            argument presumably has to match the corresponding kernel parameter's size
            and layout — TODO confirm for non-POD types.

            Params:
                numBlocks  = Grid dimensions (x, y, z).
                numThreads = Block dimensions (x, y, z).
                args       = Values passed to the kernel, in parameter order.

            Throws:
                `Exception` if the launch or the synchronization reports an error.
        */
        void execute(Args...)(uint[3] numBlocks, uint[3] numThreads, Args args)
        {
            void*[] argPtrs = new void*[args.length];

            // `ref` makes each &arg the address of the actual parameter, which stays
            // alive until launch() has synchronized and returned.
            foreach(i, ref arg; args)
            {
                argPtrs[i] = cast(void*)&arg;
            }

            launch(numBlocks, numThreads, argPtrs.ptr);
        }

        /**
            Launches the kernel with pre-built parameter pointers and waits for the context to finish.

            Params:
                numBlocks  = Grid dimensions (x, y, z).
                numThreads = Block dimensions (x, y, z).
                argPtrs    = One pointer per kernel parameter, each pointing at that parameter's value.

            Throws:
                `Exception` if the launch or the synchronization reports an error.
        */
        void execute(uint[3] numBlocks, uint[3] numThreads, void*[] argPtrs)
        {
            launch(numBlocks, numThreads, argPtrs.ptr);
        }
    }

    private
    {
        CUmodule mModule;
        CUfunction mKernel;

        // Calls cuLaunchKernel on mKernel (no shared memory, null stream) and then cuCtxSynchronize, checking both.
        void launch(uint[3] numBlocks, uint[3] numThreads, void** params)
        {
            cuCheck(cuLaunchKernel(
                mKernel,
                numBlocks[0], numBlocks[1], numBlocks[2],
                numThreads[0], numThreads[1], numThreads[2],
                0, null, params, null
            ));

            cuCheck(cuCtxSynchronize());
        }

        // Writes the NVRTC compilation log of `program` to stderr when it is non-empty.
        static void printLog(nvrtcProgram program)
        {
            size_t logSize;
            nvrtcGetProgramLogSize(program, &logSize).nvrtcCheck;

            // The reported size includes the trailing null terminator, so 1 means an empty log.
            if(logSize > 1)
            {
                auto log = new char[logSize];
                nvrtcGetProgramLog(program, log.ptr).nvrtcCheck;

                import std.stdio : stderr;
                stderr.writeln(log[0 .. $ - 1]);
            }
        }

        // Returns the PTX generated for a successfully compiled `program`, as filled in by nvrtcGetPTX.
        static char[] getPTX(nvrtcProgram program)
        {
            size_t ptxSize;
            nvrtcGetPTXSize(program, &ptxSize).nvrtcCheck;

            auto ptx = new char[ptxSize];
            nvrtcGetPTX(program, ptx.ptr).nvrtcCheck;
            return ptx;
        }

        // Throws an Exception naming the call site and result code when a CUDA driver call did not return CUDA_SUCCESS.
        static void cuCheck(CUresult res, string mod = __MODULE__, size_t line = __LINE__)
        {
            import std.conv : to;
            import std.exception : enforce;

            enforce(res == CUDA_SUCCESS,
                mod ~ "(" ~ line.to!string ~ "): Failed to execute CUDA driver function (error "
                    ~ res.to!string ~ ").");
        }
    }
}