1 module dopt.nnet.data.cifar;
2 
3 import std.algorithm;
4 import std.array;
5 import std.file;
6 import std.range;
7 import std.stdio;
8 import std.typecons;
9 
10 import dopt.nnet.data;
11 
/**
 * Loads the CIFAR-10 dataset from its binary batch files.
 *
 * Params:
 *     path = Directory that contains `data_batch_1.bin` .. `data_batch_5.bin`
 *            and `test_batch.bin`.
 *
 * Returns:
 *     The result of `loadCIFAR` for these files: a tuple with the named
 *     fields `train` and `test`.
 */
auto loadCIFAR10(string path)
{
	// Record layout handed to the shared loader: one label byte per record,
	// read from offset 0, expanded into a one-hot vector of 10 entries.
	enum size_t labelBytes = 1;
	enum size_t labelIdx = 0;
	enum size_t numLabels = 10;

	string[] trainFiles = [
		"data_batch_1.bin",
		"data_batch_2.bin",
		"data_batch_3.bin",
		"data_batch_4.bin",
		"data_batch_5.bin"
	];

	// Order matters: the loader treats the leading records as the training
	// split, so the test batch has to come last.
	string[] allFiles = trainFiles ~ "test_batch.bin";

	return loadCIFAR(path, allFiles, labelBytes, labelIdx, numLabels);
}
25 
/**
 * Loads the CIFAR-100 dataset from its binary files.
 *
 * Params:
 *     path = Directory that contains `train.bin` and `test.bin`.
 *
 * Returns:
 *     The result of `loadCIFAR` for these files: a tuple with the named
 *     fields `train` and `test`.
 */
auto loadCIFAR100(string path)
{
	// Record layout handed to the shared loader: two label bytes per record,
	// of which the byte at offset 1 selects the class out of 100.
	// NOTE(review): presumably offset 1 is the "fine" label -- confirm against
	// the dataset's format description.
	enum size_t labelBytes = 2;
	enum size_t labelIdx = 1;
	enum size_t numLabels = 100;

	// The training file must precede the test file; the loader splits the
	// concatenated records by position.
	string[] files = ["train.bin", "test.bin"];

	return loadCIFAR(path, files, labelBytes, labelIdx, numLabels);
}
32 
private
{
	/**
	 * Reads a sequence of CIFAR-style binary files and splits the records
	 * into a training and a test iterator.
	 *
	 * Every record consists of `labelBytes` label bytes followed by
	 * 3 * 32 * 32 pixel bytes. Pixel bytes are rescaled with
	 * `x / 128 - 1`; the label byte at `labelIdx` is expanded into a one-hot
	 * vector of length `numLabels`.
	 *
	 * The records of all files are concatenated in the order given by
	 * `batchFiles`. The first 50,000 records form the training split and all
	 * remaining records form the test split.
	 *
	 * Params:
	 *     path       = Directory containing the batch files.
	 *     batchFiles = File names, relative to `path`, in load order.
	 *     labelBytes = Number of label bytes at the start of each record.
	 *     labelIdx   = Offset (within the label bytes) of the label to use.
	 *     numLabels  = Number of classes, i.e. the length of the one-hot vector.
	 *
	 * Returns:
	 *     A tuple with the named fields `train` and `test`, each a
	 *     `BatchIterator`.
	 *
	 * Throws:
	 *     `Exception` if `labelIdx` does not address a label byte, if a file's
	 *     size is not a whole number of records, if a label is not below
	 *     `numLabels`, or if fewer than 50,000 records were loaded in total.
	 *     `FileException` (from `std.file.read`) if a file cannot be read.
	 */
	auto loadCIFAR(string path, string[] batchFiles, size_t labelBytes, size_t labelIdx, size_t numLabels)
	{
		import std.conv : text;
		import std.exception : enforce;
		import std.path : buildPath;
		import std.typecons : tuple;

		alias T = float;

		enum size_t pixelsPerImage = 3 * 32 * 32;
		enum size_t numTrain = 50_000;

		enforce(labelIdx < labelBytes,
			text("labelIdx (", labelIdx, ") must be smaller than labelBytes (", labelBytes, ")"));

		immutable recordSize = pixelsPerImage + labelBytes;

		T[][] features;
		T[][] labels;

		foreach(name; batchFiles)
		{
			// buildPath copes with trailing separators and platform differences.
			auto fileName = buildPath(path, name);
			auto raw = cast(ubyte[])read(fileName);

			// A trailing partial record would otherwise yield a short feature
			// vector (or an out-of-bounds label access).
			enforce(raw.length % recordSize == 0,
				text("CIFAR file '", fileName, "' has ", raw.length,
					" bytes, which is not a multiple of the record size ", recordSize));

			foreach(record; raw.chunks(recordSize))
			{
				immutable label = record[labelIdx];

				enforce(label < numLabels,
					text("CIFAR file '", fileName, "' contains label ", label,
						", expected a value below ", numLabels));

				// float.init is NaN in D, so the one-hot vector must be
				// zeroed explicitly before the hot entry is set.
				auto oneHot = new T[numLabels];
				oneHot[] = 0;
				oneHot[label] = 1.0f;

				labels ~= oneHot;
				features ~= record[labelBytes .. $]
					.map!(x => cast(T)x / 128.0f - 1.0f)
					.array();
			}
		}

		enforce(features.length >= numTrain,
			text("Expected at least ", numTrain, " CIFAR records but loaded only ", features.length));

		auto shape = [[cast(size_t)3, 32, 32], [numLabels]];

		// NOTE(review): the final boolean differs between the two splits
		// (true for train, false for test) -- presumably a shuffle flag;
		// confirm against SupervisedBatchIterator.
		BatchIterator trainData = new SupervisedBatchIterator(
			features[0 .. numTrain],
			labels[0 .. numTrain],
			shape,
			true
		);

		BatchIterator testData = new SupervisedBatchIterator(
			features[numTrain .. $],
			labels[numTrain .. $],
			shape,
			false
		);

		return tuple!("train", "test")(trainData, testData);
	}
}