/**
	Contains functions for loading the CIFAR-10 and CIFAR-100 image classification datasets from the
	binary versions distributed at https://www.cs.toronto.edu/~kriz/cifar.html.
*/
module dopt.nnet.data.cifar;

import std.algorithm;
import std.array;
import std.file;
import std.range;
import std.stdio;
import std.typecons;

import dopt.nnet.data;

/**
	Loads the CIFAR-10 dataset.

	Params:
		path = the directory containing data_batch_1.bin, ..., data_batch_5.bin, and test_batch.bin.
		validation = if true, the last 10,000 training examples are held out as the test set instead of the examples in test_batch.bin.

	Returns:
		A named tuple with the fields train and test, each a BatchIterator.
*/
auto loadCIFAR10(string path, bool validation = false)
{
	auto batchFiles = [
		"data_batch_1.bin",
		"data_batch_2.bin",
		"data_batch_3.bin",
		"data_batch_4.bin",
		"data_batch_5.bin",
		"test_batch.bin"
	];

	// One label byte per record, the class label is stored at byte 0, and there are 10 classes.
	return loadCIFAR(path, batchFiles, 1, 0, 10, validation);
}

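/*
	A minimal usage sketch (not part of the original source), assuming the CIFAR-10 binary archive has
	been extracted into the hypothetical directory "data/cifar-10-batches-bin":

		auto cifar10 = loadCIFAR10("data/cifar-10-batches-bin");
		BatchIterator trainIter = cifar10.train; // iterates over the 50,000 training examples
		BatchIterator testIter = cifar10.test;   // iterates over the 10,000 test examples
*/
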
/**
	Loads the CIFAR-100 dataset, one-hot encoding the fine (100 class) labels.

	Params:
		path = the directory containing train.bin and test.bin.
		validation = if true, the last 10,000 training examples are held out as the test set instead of the examples in test.bin.
*/
auto loadCIFAR100(string path, bool validation = false)
{
	auto batchFiles = ["train.bin", "test.bin"];

	// Two label bytes per record (coarse, then fine); the fine label at byte 1 is used, with 100 classes.
	return loadCIFAR(path, batchFiles, 2, 1, 100, validation);
}

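/*
	Sketch of the validation split (not part of the original source), again using a hypothetical
	extraction directory. With validation = true, the returned test iterator covers the last 10,000
	training examples rather than the contents of test.bin:

		auto cifar100 = loadCIFAR100("data/cifar-100-binary", true);
		BatchIterator trainIter = cifar100.train; // first 40,000 training examples
		BatchIterator validIter = cifar100.test;  // last 10,000 training examples, held out
*/
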
private
{
	/*
		Shared loader for the CIFAR binary formats. Each record consists of labelBytes label bytes
		followed by a 32x32 RGB image stored as 3 * 32 * 32 = 3,072 bytes in channel-major order.
		The byte at labelIdx is one-hot encoded into a vector of length numLabels.
	*/
	auto loadCIFAR(string path, string[] batchFiles, size_t labelBytes, size_t labelIdx, size_t numLabels, bool valid)
	{
		auto batches = batchFiles.map!(x => path ~ "/" ~ x).array();

		alias T = float;
		T[][] features;
		T[][] labels;

		foreach(b; batches)
		{
			ubyte[] raw = cast(ubyte[])read(b);

			// Process the file one record at a time.
			foreach(tmp; raw.chunks(3 * 32 * 32 + labelBytes))
			{
				// Scale the pixel values from [0, 255] to approximately [-1, 1).
				auto f = tmp[labelBytes .. $]
					.map!(x => cast(T)x / 128.0f - 1.0f)
					.array();

				// One-hot encode the class label.
				auto ls = new T[numLabels];
				ls[] = 0;
				ls[tmp[labelIdx]] = 1.0f;
				labels ~= ls;
				features ~= f;
			}
		}

		// Hold out the last 10,000 training examples as a validation set when requested; otherwise
		// train on all 50,000 training examples and test on the proper test set.
		size_t numTrain = valid ? 40_000 : 50_000;

		BatchIterator trainData = new SupervisedBatchIterator(
			features[0 .. numTrain],
			labels[0 .. numTrain],
			[[cast(size_t)3, 32, 32], [numLabels]],
			true
		);

		BatchIterator testData = new SupervisedBatchIterator(
			features[numTrain .. numTrain + 10_000],
			labels[numTrain .. numTrain + 10_000],
			[[cast(size_t)3, 32, 32], [numLabels]],
			false
		);

		return tuple!("train", "test")(trainData, testData);
	}
}