module dopt.nnet.data.cifar;

import std.algorithm;
import std.array;
import std.file;
import std.range;
import std.stdio;
import std.typecons;

import dopt.nnet.data;

/**
    Loads the CIFAR-10 dataset from the binary batch files found in path.

    If validation is true, the last 10,000 training images are held out and
    returned in place of the test set.
*/
auto loadCIFAR10(string path, bool validation = false)
{
    auto batchFiles = [
        "data_batch_1.bin",
        "data_batch_2.bin",
        "data_batch_3.bin",
        "data_batch_4.bin",
        "data_batch_5.bin",
        "test_batch.bin"
    ];

    //CIFAR-10 records carry a single label byte at offset 0, with 10 classes
    return loadCIFAR(path, batchFiles, 1, 0, 10, validation);
}

/**
    Loads the CIFAR-100 dataset from the binary batch files found in path.

    If validation is true, the last 10,000 training images are held out and
    returned in place of the test set.
*/
auto loadCIFAR100(string path, bool validation = false)
{
    auto batchFiles = ["train.bin", "test.bin"];

    //CIFAR-100 records carry two label bytes (coarse, fine); the fine label at offset 1 is used, with 100 classes
    return loadCIFAR(path, batchFiles, 2, 1, 100, validation);
}

private
{
    auto loadCIFAR(string path, string[] batchFiles, size_t labelBytes, size_t labelIdx, size_t numLabels, bool valid)
    {
        auto batches = batchFiles.map!(x => path ~ "/" ~ x).array();

        alias T = float;
        T[][] features;
        T[][] labels;

        foreach(b; batches)
        {
            ubyte[] raw = cast(ubyte[])read(b);

            //Each record consists of labelBytes label bytes followed by a 3x32x32 image
            foreach(tmp; raw.chunks(3 * 32 * 32 + labelBytes))
            {
                //Rescale pixel values from [0, 255] to approximately [-1, 1]
                auto f = tmp[labelBytes .. $]
                        .map!(x => cast(T)x / 128.0f - 1.0f)
                        .array();

                //One-hot encode the label
                auto ls = new T[numLabels];
                ls[] = 0;
                ls[tmp[labelIdx]] = 1.0f;
                labels ~= ls;
                features ~= f;
            }
        }

        //When a validation set is requested, hold out the last 10,000 training examples
        size_t numTrain = valid ? 40_000 : 50_000;

        BatchIterator trainData = new SupervisedBatchIterator(
            features[0 .. numTrain],
            labels[0 .. numTrain],
            [[cast(size_t)3, 32, 32], [numLabels]],
            true
        );

        BatchIterator testData = new SupervisedBatchIterator(
            features[numTrain .. numTrain + 10_000],
            labels[numTrain .. numTrain + 10_000],
            [[cast(size_t)3, 32, 32], [numLabels]],
            false
        );

        return tuple!("train", "test")(trainData, testData);
    }
}
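
/*
    Example usage: a minimal sketch, not part of the module itself. It assumes the
    CIFAR-10 binary batch files have been extracted to "data/cifar-10-batches-bin";
    that path and the function name exampleLoadCIFAR10 are illustrative only.
    The block is wrapped in version(none) so it is never compiled.
*/
version(none)
{
    void exampleLoadCIFAR10()
    {
        //Load the full 50,000-example training set and the 10,000-example test set
        auto data = loadCIFAR10("data/cifar-10-batches-bin");
        auto train = data.train;
        auto test = data.test;

        //Requesting a validation split instead holds out the last 10,000 training
        //examples and returns them in place of the test set
        auto dataWithValidation = loadCIFAR10("data/cifar-10-batches-bin", true);
    }
}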