1 module dopt.nnet.data.svhn;
2 
3 import std.algorithm;
4 import std.array;
5 import std.file;
6 import std.range;
7 import std.typecons;
8 
9 import dopt.nnet.data;
10 
/**
 * Loads SVHN from pre-extracted binary files and wraps the splits in batch iterators.
 *
 * Reads `train_X.bin`, `extra_X.bin` and `test_X.bin` (features) plus the matching
 * `train_y.bin`, `extra_y.bin` and `test_y.bin` (labels) from `directory`.
 *
 * Features: each file is read as a flat run of unsigned bytes and split into samples of
 * 32 * 32 * 3 values. Each byte `x` is rescaled to `x / 128 - 1`, giving values in
 * [-1, 0.9921875]. Each sample is handed to the iterator with shape [3, 32, 32]. That
 * presumably means the files store channel-major pixels (TODO confirm against the
 * extraction script).
 *
 * Labels: one unsigned byte per sample. Label `l` must lie in [1, 10] and is one-hot
 * encoded at index `l - 1` of a length-10 vector.
 *
 * The training set is the concatenation of the `train` and `extra` splits.
 *
 * Params:
 *     directory  = Directory containing the six binary files.
 *     validation = When true, the first 10,000 training samples are held out and returned
 *                  as the `test` split. The files' own test split is discarded, and the
 *                  remaining samples form `train`.
 *
 * Returns: A `Tuple` with named fields `train` and `test`, each a `BatchIterator`.
 *
 * Throws: `FileException` if a file cannot be read.
 *         `Exception` if:
 *           - a feature file's size is not a multiple of the sample size;
 *           - a label is outside [1, 10];
 *           - feature and label counts disagree;
 *           - `validation` is requested with fewer than 10,000 training samples.
 */
auto loadSVHN(string directory, bool validation = false)
{
    import std.conv : text;
    import std.exception : enforce;

    enum size_t sampleSize = 32 * 32 * 3;
    enum size_t numClasses = 10;
    enum size_t validationSize = 10_000;

    // Reads a feature file and returns one float[] of length sampleSize per sample.
    auto loadFeatures(string filename)
    {
        auto raw = cast(ubyte[])read(directory ~ "/" ~ filename);

        // Without this check, a truncated or corrupt file would yield a short final sample.
        enforce(raw.length % sampleSize == 0,
            text(filename, ": size ", raw.length, " is not a multiple of the sample size ",
                 sampleSize));

        return raw
              .map!(x => cast(float)x / 128.0f - 1.0f)
              .chunks(sampleSize)
              .map!(x => x.array())
              .array();
    }

    // Reads a label file and returns one one-hot float[numClasses] per sample.
    auto loadLabels(string filename)
    {
        auto lbls = cast(ubyte[])read(directory ~ "/" ~ filename);
        auto ret = new float[][lbls.length];

        foreach(i, l; lbls)
        {
            // Label l is stored at index l - 1, so anything outside [1, numClasses] would
            // index out of bounds. That access is unchecked in -release builds of @system code.
            enforce(l >= 1 && l <= numClasses,
                text(filename, ": label ", l, " at position ", i, " is outside [1, ",
                     numClasses, "]"));

            ret[i] = new float[numClasses];
            ret[i][] = 0.0f; // D default-initialises floats to NaN, so clear explicitly.
            ret[i][l - 1] = 1.0f;
        }

        return ret;
    }

    auto trainFeatures = loadFeatures("train_X.bin") ~ loadFeatures("extra_X.bin");
    auto testFeatures = loadFeatures("test_X.bin");
    auto trainLabels = loadLabels("train_y.bin") ~ loadLabels("extra_y.bin");
    auto testLabels = loadLabels("test_y.bin");

    enforce(trainFeatures.length == trainLabels.length,
        text("SVHN training set has ", trainFeatures.length, " feature samples but ",
             trainLabels.length, " labels"));
    enforce(testFeatures.length == testLabels.length,
        text("SVHN test set has ", testFeatures.length, " feature samples but ",
             testLabels.length, " labels"));

    if(validation)
    {
        enforce(trainFeatures.length >= validationSize,
            text("SVHN training set has only ", trainFeatures.length,
                 " samples; cannot hold out ", validationSize, " for validation"));

        // Hold out the head of the training data; the files' own test split is not used.
        testFeatures = trainFeatures[0 .. validationSize];
        testLabels = trainLabels[0 .. validationSize];
        trainFeatures = trainFeatures[validationSize .. $];
        trainLabels = trainLabels[validationSize .. $];
    }

    // NOTE(review): the final boolean argument appears to toggle shuffling
    // (on for train, off for test). Confirm against SupervisedBatchIterator.
    BatchIterator trainData = new SupervisedBatchIterator(
        trainFeatures,
        trainLabels,
        [[cast(size_t)3, 32, 32], [cast(size_t)numClasses]],
        true
    );

    BatchIterator testData = new SupervisedBatchIterator(
        testFeatures,
        testLabels,
        [[cast(size_t)3, 32, 32], [cast(size_t)numClasses]],
        false
    );

    return tuple!("train", "test")(trainData, testData);
}