module dopt.nnet.data.cifar;

import std.algorithm;
import std.array;
import std.exception : enforce;
import std.file;
import std.format : format;
import std.path : buildPath;
import std.range;
import std.stdio;
import std.typecons;

import dopt.nnet.data;

/**
Loads the binary version of the CIFAR-10 dataset.

Params:
    path = Directory containing `data_batch_1.bin` to `data_batch_5.bin` and `test_batch.bin`.

Returns:
    A tuple with fields `train` and `test`, each a `BatchIterator` built from
    float feature vectors (declared shape `[3, 32, 32]`) and 10-element one-hot
    label vectors. The five `data_batch_*.bin` files form the training set and
    `test_batch.bin` forms the test set.

Throws:
    `std.file.FileException` if a batch file cannot be read, or `Exception` if a
    file is not a whole number of records or contains a label `>= 10`.
*/
auto loadCIFAR10(string path)
{
    auto batchFiles = [
        "data_batch_1.bin",
        "data_batch_2.bin",
        "data_batch_3.bin",
        "data_batch_4.bin",
        "data_batch_5.bin",
        "test_batch.bin"
    ];

    // One label byte per record; that byte is the class index in [0, 10).
    return loadCIFAR(path, batchFiles, 1, 0, 10);
}

/**
Loads the binary version of the CIFAR-100 dataset.

Params:
    path = Directory containing `train.bin` and `test.bin`.

Returns:
    A tuple with fields `train` and `test`, each a `BatchIterator` built from
    float feature vectors (declared shape `[3, 32, 32]`) and 100-element one-hot
    label vectors. `train.bin` forms the training set and `test.bin` the test set.

Throws:
    `std.file.FileException` if a batch file cannot be read, or `Exception` if a
    file is not a whole number of records or contains a label `>= 100`.
*/
auto loadCIFAR100(string path)
{
    auto batchFiles = ["train.bin", "test.bin"];

    // Two label bytes per record; the second one (index 1) is used as the class
    // index. In the published CIFAR-100 binary layout this is the fine label.
    return loadCIFAR(path, batchFiles, 2, 1, 100);
}

private
{
    // Number of pixel bytes in one CIFAR image: 3 planes of 32x32 bytes.
    enum size_t cifarPixelBytes = 3 * 32 * 32;

    /*
    Reads CIFAR-format batch files from `path` and wraps them in train/test iterators.

    Each record is `labelBytes` header bytes followed by `cifarPixelBytes` pixel bytes;
    the header byte at `labelIdx` is the class index, one-hot encoded into `numLabels` slots.
    Every file except the last contributes to the training set; the last file is the test set.
    The training iterator is constructed with a final flag of `true` and the test iterator
    with `false` (presumably enabling shuffling for training only; confirm against
    SupervisedBatchIterator).
    */
    auto loadCIFAR(string path, string[] batchFiles, size_t labelBytes, size_t labelIdx, size_t numLabels)
    {
        enforce(batchFiles.length >= 2,
            "loadCIFAR needs at least one training file followed by a test file");
        enforce(labelIdx < labelBytes, "label index must lie within the label header");

        alias T = float;
        T[][] features;
        T[][] labels;
        size_t numTrain;

        foreach(i, fileName; batchFiles)
        {
            auto fullPath = buildPath(path, fileName);
            appendCIFARRecords(cast(const(ubyte)[])read(fullPath), fullPath,
                labelBytes, labelIdx, numLabels, features, labels);

            // The split point is wherever the training files end, instead of a fixed record count.
            if(i + 1 < batchFiles.length)
            {
                numTrain = features.length;
            }
        }

        BatchIterator trainData = new SupervisedBatchIterator(
            features[0 .. numTrain],
            labels[0 .. numTrain],
            [[cast(size_t)3, 32, 32], [numLabels]],
            true
        );

        BatchIterator testData = new SupervisedBatchIterator(
            features[numTrain .. $],
            labels[numTrain .. $],
            [[cast(size_t)3, 32, 32], [numLabels]],
            false
        );

        return tuple!("train", "test")(trainData, testData);
    }

    /*
    Decodes every record in one raw batch buffer, appending scaled pixels to `features`
    and a one-hot label vector to `labels`; `source` is only used in error messages.
    */
    void appendCIFARRecords(T)(const(ubyte)[] raw, string source, size_t labelBytes, size_t labelIdx,
        size_t numLabels, ref T[][] features, ref T[][] labels)
    {
        immutable recordSize = cifarPixelBytes + labelBytes;

        // A partial trailing record would otherwise become a silently truncated feature vector.
        enforce(raw.length % recordSize == 0,
            format("%s: size %s is not a multiple of the %s-byte record size",
                source, raw.length, recordSize));

        foreach(record; raw.chunks(recordSize))
        {
            immutable label = record[labelIdx];
            enforce(label < numLabels,
                format("%s: label %s is out of range for %s classes", source, label, numLabels));

            // Map each byte b in [0, 255] to b / 128 - 1, i.e. the range [-1, 0.9921875].
            features ~= record[labelBytes .. $]
                .map!(x => cast(T)x / 128.0f - 1.0f)
                .array();

            auto oneHot = new T[numLabels];
            oneHot[] = 0; // floating-point arrays default-initialise to NaN in D
            oneHot[label] = 1.0f;
            labels ~= oneHot;
        }
    }
}