1 module dopt.nnet.data.mnist;
2 
3 import std.algorithm;
4 import std.array;
5 import std.file;
6 import std.range;
7 import std.typecons;
8 
9 import dopt.nnet.data;
10 
/**
 * Loads the MNIST dataset from a directory of uncompressed IDX files.
 *
 * Four files are read from directly beneath `path`: `train-images-idx3-ubyte`,
 * `train-labels-idx1-ubyte`, `t10k-images-idx3-ubyte` and `t10k-labels-idx1-ubyte`.
 *
 * Pixel bytes are rescaled from [0, 255] to [-1, 1) by computing `x / 128 - 1`, and each label
 * byte is expanded into a 10-element one-hot vector.
 *
 * Params:
 *     path = Directory containing the four MNIST files.
 *
 * Returns:
 *     A tuple with the named fields `train` and `test`, each holding a `BatchIterator` built from
 *     the corresponding feature/label pair with the shapes `[[1, 28, 28], [10]]`.
 *
 * Throws:
 *     `FileException` if one of the files cannot be read, and `Exception` if a file is too short
 *     to contain its header or a label file contains a value outside 0-9.
 */
auto loadMNIST(string path)
{
    import std.conv : text;
    import std.exception : enforce;
    import std.typecons : tuple;

    //Reads an IDX image file and returns one normalised slice of 28 * 28 values per image
    T[][] loadFeatures(T)(string filename)
    {
        enum size_t headerSize = 16;
        enum size_t numFeatures = 28 * 28;

        //Load the data from disk
        ubyte[] raw = cast(ubyte[])read(filename);

        //Fail with a recoverable, descriptive error rather than a RangeError on truncated files
        enforce(raw.length >= headerSize,
            text("MNIST image file '", filename, "' is too short to contain a ", headerSize, " byte header"));

        //Skip over the header
        raw = raw[headerSize .. $];

        //Get the number of complete instances in this file
        const size_t numInstances = raw.length / numFeatures;

        //Convert the ubytes to T, ignoring any trailing bytes that do not form a complete image
        T[] features = raw[0 .. numInstances * numFeatures].map!(x => cast(T)x).array();

        //Rescale from [0, 255] to [-1, 1)
        features[] /= 128.0f;
        features[] -= 1.0f;

        //Each instance is a slice into the single contiguous feature buffer
        T[][] result = new T[][numInstances];

        foreach(i, ref instance; result)
        {
            instance = features[i * numFeatures .. (i + 1) * numFeatures];
        }

        return result;
    }

    //Reads an IDX label file and returns a one-hot encoded slice of 10 values per label
    T[][] loadLabels(T)(string filename)
    {
        enum size_t headerSize = 8;
        enum size_t numLabels = 10;

        //Load the data from disk
        ubyte[] raw = cast(ubyte[])read(filename);

        //Fail with a recoverable, descriptive error rather than a RangeError on truncated files
        enforce(raw.length >= headerSize,
            text("MNIST label file '", filename, "' is too short to contain a ", headerSize, " byte header"));

        //Skip over the header; every remaining byte is one label
        raw = raw[headerSize .. $];

        const size_t numInstances = raw.length;

        //Allocate the one-hot storage. It must be zeroed explicitly because floating point
        //types default initialise to NaN.
        T[] labels = new T[numInstances * numLabels];
        labels[] = 0.0;

        T[][] result = new T[][numInstances];

        //Create the one-hot encoding and set up references to the appropriate slice for each instance
        foreach(i, label; raw)
        {
            //An out of range label would otherwise index past the end of this instance's slice
            enforce(label < numLabels,
                text("MNIST label file '", filename, "' contains invalid label ", label, " at index ", i));

            result[i] = labels[i * numLabels .. (i + 1) * numLabels];
            result[i][label] = 1.0;
        }

        return result;
    }

    auto trainFeatures = loadFeatures!float(path ~ "/train-images-idx3-ubyte");
    auto trainLabels = loadLabels!float(path ~ "/train-labels-idx1-ubyte");
    auto testFeatures = loadFeatures!float(path ~ "/t10k-images-idx3-ubyte");
    auto testLabels = loadLabels!float(path ~ "/t10k-labels-idx1-ubyte");

    //NOTE(review): the final boolean argument is presumably a shuffle flag (enabled for training
    //only) -- confirm against SupervisedBatchIterator's constructor.
    BatchIterator trainData = new SupervisedBatchIterator(trainFeatures, trainLabels, [[1, 28, 28], [10]], true);
    BatchIterator testData = new SupervisedBatchIterator(testFeatures, testLabels, [[1, 28, 28], [10]], false);

    return tuple!("train", "test")(trainData, testData);
}