1 module dopt.cuda.random;
2 
3 import std.algorithm;
4 import std.conv;
5 import std.functional;
6 import std.math;
7 import std.random;
8 import std.range;
9 
10 import dopt.cuda;
11 import dopt.cuda.nvrtc;
12 import dopt.core.ops;
13 import dopt.core.types;
14 
15 package
16 {
17     extern(C)
18     {
19         struct curandGenerator_st;
20         alias curandGenerator_t = curandGenerator_st *;
21 
22         immutable int CURAND_RNG_PSEUDO_DEFAULT = 100;
23 
24         int function(curandGenerator_t *generator, int rng_type) curandCreateGenerator;
25         int function(curandGenerator_t generator) curandDestroyGenerator;
26         int function(curandGenerator_t generator, ulong seed) curandSetPseudoRandomGeneratorSeed;
27         int function(curandGenerator_t generator, float *outputPtr, size_t num) curandGenerateUniform;
28     }
29 
30     void initialize()
31     {
32         //TODO: make a DerelictCuRAND library
33         import core.sys.posix.dlfcn;
34 		import std..string;
35 		fh = dlopen("libcurand.so".toStringz, RTLD_LAZY);
36 
37         curandCreateGenerator = cast(typeof(curandCreateGenerator))dlsym(fh, "curandCreateGenerator");
38         curandDestroyGenerator = cast(typeof(curandDestroyGenerator))dlsym(fh, "curandDestroyGenerator");
39         curandSetPseudoRandomGeneratorSeed =
40             cast(typeof(curandSetPseudoRandomGeneratorSeed))dlsym(fh, "curandSetPseudoRandomGeneratorSeed");
41         curandGenerateUniform = cast(typeof(curandGenerateUniform))dlsym(fh, "curandGenerateUniform");
42 
43         registerCUDAKernel("uniform", toDelegate(&uniformCtor));
44     }
45 }
46 
47 private
48 {
49     void *fh;
50 
51     CUDAKernel uniformCtor(Operation op)
52     {
53         return new UniformSample(op);
54     }
55 
56     class UniformSample : CUDAKernel
57     {
58         public
59         {
60             this(Operation op)
61             {
62                 mOp = op;
63                 curandCreateGenerator(&mGen, CURAND_RNG_PSEUDO_DEFAULT);
64                 curandSetPseudoRandomGeneratorSeed(mGen, cast(ulong)unpredictableSeed());
65             }
66 
67             ~this()
68             {
69                 curandDestroyGenerator(mGen);
70             }
71 
72             override void execute(const(CUDABuffer)[] inputs, CUDABuffer output)
73             {
74                 curandGenerateUniform(mGen, cast(float *)output.ptr, mOp.volume);
75             }
76         }
77 
78         private
79         {
80             Operation mOp;
81             curandGenerator_t mGen;
82         }
83     }
84 }