module dopt.nnet.data.mnist;

import std.algorithm;
import std.array;
import std.file;
import std.range;
import std.typecons;

import dopt.nnet.data;

/**
    Loads the MNIST dataset from four uncompressed IDX files stored in a single directory.

    The directory must contain `train-images-idx3-ubyte`, `train-labels-idx1-ubyte`,
    `t10k-images-idx3-ubyte` and `t10k-labels-idx1-ubyte`.

    Pixel values are rescaled from the byte range [0, 255] to [-1, 0.9921875] by computing
    `x / 128 - 1`. Labels are expanded into one-hot vectors of length 10.

    Params:
        path = The directory containing the four MNIST files.

    Returns:
        A named tuple with the members `train` and `test`, each of which is a `BatchIterator`.

    Throws:
        `Exception` if a file is too short, does not begin with the expected IDX magic number, has a
        payload whose size disagrees with its header, contains a label outside of [0, 9], or if the
        number of images differs from the number of labels. Errors raised by `std.file.read` are
        propagated unchanged.
*/
auto loadMNIST(string path)
{
    import std.conv : text;
    import std.exception : enforce;
    import std.path : buildPath;

    enum size_t numFeatures = 28 * 28;
    enum size_t numLabels = 10;

    //Decode the big-endian 32-bit integer stored at the given byte offset (IDX headers are big-endian)
    static uint readUint(const(ubyte)[] bytes, size_t offset)
    {
        return (cast(uint)bytes[offset] << 24)
            | (cast(uint)bytes[offset + 1] << 16)
            | (cast(uint)bytes[offset + 2] << 8)
            | cast(uint)bytes[offset + 3];
    }

    //Read an entire IDX file and verify that it is long enough and starts with the expected magic number
    ubyte[] loadIdx(string filename, uint expectedMagic, size_t headerLength)
    {
        ubyte[] raw = cast(ubyte[])read(filename);

        enforce(raw.length >= headerLength,
            text(filename, " is too short to contain an IDX header (", raw.length, " bytes)"));

        //A mismatch here most commonly means the file has not been decompressed
        enforce(readUint(raw, 0) == expectedMagic,
            text(filename, " does not start with the expected IDX magic number; is it still compressed?"));

        return raw;
    }

    //Load an image file, returning one normalised slice of numFeatures elements per image
    T[][] loadFeatures(T)(string filename)
    {
        //Magic number, image count, row count, column count
        enum size_t headerLength = 16;

        ubyte[] raw = loadIdx(filename, 0x00000803, headerLength);

        //Kept as ulong so the size check below cannot overflow when size_t is 32 bits wide
        const ulong declaredInstances = readUint(raw, 4);

        enforce(readUint(raw, 8) == 28 && readUint(raw, 12) == 28,
            text(filename, " does not contain 28x28 images"));

        //Skip over the header
        raw = raw[headerLength .. $];

        enforce(raw.length == declaredInstances * numFeatures,
            text(filename, " declares ", declaredInstances, " images but contains ", raw.length,
                " bytes of pixel data"));

        const size_t numInstances = raw.length / numFeatures;

        //Convert the ubytes to T and normalise the whole buffer in one pass
        T[] features = raw.map!(x => cast(T)x).array();
        features[] /= 128.0f;
        features[] -= 1.0f;

        //Each instance is a slice into the shared buffer, not a separate allocation
        auto result = new T[][numInstances];

        foreach(i, ref instance; result)
        {
            instance = features[i * numFeatures .. (i + 1) * numFeatures];
        }

        return result;
    }

    //Load a label file, returning one one-hot encoded slice of numLabels elements per label
    T[][] loadLabels(T)(string filename)
    {
        //Magic number, label count
        enum size_t headerLength = 8;

        ubyte[] raw = loadIdx(filename, 0x00000801, headerLength);

        const ulong declaredInstances = readUint(raw, 4);

        //Skip over the header
        raw = raw[headerLength .. $];

        enforce(raw.length == declaredInstances,
            text(filename, " declares ", declaredInstances, " labels but contains ", raw.length));

        const size_t numInstances = raw.length;

        //Floating point arrays default to NaN in D, so the buffer must be zeroed explicitly
        T[] labels = new T[numInstances * numLabels];
        labels[] = 0.0;

        auto result = new T[][numInstances];

        foreach(i, classIndex; raw)
        {
            //Without this check a corrupt label would index outside of its one-hot slice
            enforce(classIndex < numLabels,
                text(filename, " contains the invalid label ", classIndex, " at index ", i));

            result[i] = labels[i * numLabels .. (i + 1) * numLabels];
            result[i][classIndex] = 1.0;
        }

        return result;
    }

    auto trainFeatures = loadFeatures!float(buildPath(path, "train-images-idx3-ubyte"));
    auto trainLabels = loadLabels!float(buildPath(path, "train-labels-idx1-ubyte"));
    auto testFeatures = loadFeatures!float(buildPath(path, "t10k-images-idx3-ubyte"));
    auto testLabels = loadLabels!float(buildPath(path, "t10k-labels-idx1-ubyte"));

    enforce(trainFeatures.length == trainLabels.length,
        text("The MNIST training set has ", trainFeatures.length, " images but ", trainLabels.length, " labels"));
    enforce(testFeatures.length == testLabels.length,
        text("The MNIST test set has ", testFeatures.length, " images but ", testLabels.length, " labels"));

    //NOTE(review): the nested array presumably gives the feature and label shapes, and the final flag
    //presumably controls shuffling (on for train, off for test) -- confirm against SupervisedBatchIterator
    BatchIterator trainData = new SupervisedBatchIterator(trainFeatures, trainLabels, [[1, 28, 28], [10]], true);
    BatchIterator testData = new SupervisedBatchIterator(testFeatures, testLabels, [[1, 28, 28], [10]], false);

    return tuple!("train", "test")(trainData, testData);
}