1 module dopt.nnet.data.svhn;
2 
3 import std.algorithm;
4 import std.array;
5 import std.file;
6 import std.range;
7 import std.typecons;
8 
9 import dopt.nnet.data;
10 
/**
 * Loads the SVHN dataset from raw binary files stored in `directory`.
 *
 * The following files are read: `train_X.bin`, `extra_X.bin`, `test_X.bin` (features) and
 * `train_y.bin`, `extra_y.bin`, `test_y.bin` (labels). The "train" and "extra" files are
 * concatenated to form the training split.
 *
 * Feature files are interpreted as a flat sequence of unsigned bytes, 32 * 32 * 3 bytes per
 * example. Each byte `b` is rescaled to `b / 128 - 1`, i.e. into the range [-1, 1).
 *
 * Label files are interpreted as one unsigned byte per example. A label value `v` must lie in
 * 1 .. 10 and is encoded as a one-hot vector of length 10 with element `v - 1` set to one.
 * NOTE(review): this presumably follows the SVHN convention where label 10 denotes the digit
 * zero -- confirm against the script that produced the .bin files.
 *
 * Params:
 *     directory = Path of the directory containing the six binary files.
 *
 * Returns:
 *     A named tuple with the fields `train` and `test`, each a `BatchIterator` that yields
 *     features with shape [3, 32, 32] and labels with shape [10].
 *
 * Throws:
 *     `FileException` if a file cannot be read, and `Exception` if a feature file's size is not
 *     a whole number of images, a label byte is outside 1 .. 10, or the number of examples in a
 *     feature file does not match the number of labels.
 */
auto loadSVHN(string directory)
{
    import std.exception : enforce;
    import std.format : format;
    import std.path : buildPath;

    // Number of bytes (and, after conversion, floats) that make up a single example.
    enum size_t imageVolume = 32 * 32 * 3;

    // Length of each one-hot label vector.
    enum size_t numClasses = 10;

    // Reads an entire file from the dataset directory as raw bytes.
    // buildPath is used so that an empty directory resolves relative to the working directory
    // rather than the filesystem root.
    ubyte[] readBytes(string filename)
    {
        return cast(ubyte[])read(buildPath(directory, filename));
    }

    auto loadFeatures(string filename)
    {
        auto raw = readBytes(filename);

        // Without this check the final chunk would silently be a truncated image.
        enforce(
            raw.length % imageVolume == 0,
            format("%s: file size %s is not a multiple of %s bytes per image",
                filename, raw.length, imageVolume)
        );

        return raw
              .map!(x => cast(float)x / 128.0f - 1.0f)
              .chunks(imageVolume)
              .map!(x => x.array())
              .array();
    }

    auto loadLabels(string filename)
    {
        auto lbls = readBytes(filename);

        auto ret = new float[][lbls.length];

        foreach(i, lbl; lbls)
        {
            // A label of 0 would wrap around to a huge index, and anything above numClasses
            // would run off the end of the one-hot vector, so reject both explicitly.
            enforce(
                lbl >= 1 && lbl <= numClasses,
                format("%s: label %s at index %s is outside the expected range 1..%s",
                    filename, lbl, i, numClasses)
            );

            // Freshly allocated floats are NaN in D, so the vector must be zeroed explicitly.
            ret[i] = new float[numClasses];
            ret[i][] = 0.0f;
            ret[i][lbl - 1] = 1.0f;
        }

        return ret;
    }

    auto trainFeatures = loadFeatures("train_X.bin") ~ loadFeatures("extra_X.bin");
    auto testFeatures = loadFeatures("test_X.bin");
    auto trainLabels = loadLabels("train_y.bin") ~ loadLabels("extra_y.bin");
    auto testLabels = loadLabels("test_y.bin");

    // Every example needs exactly one label; catch mismatched files before they are paired up.
    enforce(
        trainFeatures.length == trainLabels.length,
        format("training set has %s examples but %s labels",
            trainFeatures.length, trainLabels.length)
    );

    enforce(
        testFeatures.length == testLabels.length,
        format("test set has %s examples but %s labels",
            testFeatures.length, testLabels.length)
    );

    // NOTE(review): the final boolean argument presumably enables shuffling (on for training,
    // off for testing) -- confirm against SupervisedBatchIterator's constructor.
    BatchIterator trainData = new SupervisedBatchIterator(
        trainFeatures,
        trainLabels,
        [[cast(size_t)3, 32, 32], [numClasses]],
        true
    );

    BatchIterator testData = new SupervisedBatchIterator(
        testFeatures,
        testLabels,
        [[cast(size_t)3, 32, 32], [numClasses]],
        false
    );

    return tuple!("train", "test")(trainData, testData);
}