1 #!/usr/bin/env dub
2 /+ dub.sdl:
3 dependency "dopt" path=".."
4 dependency "progress-d" version="~>1.0.0"
5 +/
6 module cifar100;
7 
8 import dopt.core;
9 import dopt.nnet;
10 import dopt.online;
11 import progress;
12 
/**
	Trains a wide residual network on CIFAR-100 and prints running loss/accuracy
	for the training and test sets after every minibatch.

	Params:
		args = command line arguments; args[1] must be the directory that
		       loadCIFAR100 reads the dataset from.

	NOTE(review): a usage error is reported on stderr but the process still exits
	with status 0 because main is declared void -- confirm whether callers/scripts
	rely on that before changing the return type.
*/
void main(string[] args)
{
	import std.format : format;
	import std.stdio : stderr, writeln;

	//Total number of passes over the training set
	enum numEpochs = 200;

	if(args.length != 2)
	{
		stderr.writeln("Usage: cifar100.d <data directory>");
		return;
	}

	writeln("Loading data...");
	auto data = loadCIFAR100(args[1]);

	//Wrap the training set in an ImageTransformer.
	//NOTE(review): presumably (4, 4, true, false) selects the amount of jitter and which flips
	//are used for data augmentation -- confirm against the ImageTransformer constructor.
	data.train = new ImageTransformer(data.train, 4, 4, true, false);

	writeln("Constructing network graph...");
	size_t batchSize = 48;
	auto features = float32([batchSize, 3, 32, 32]);
	auto labels = float32([batchSize, 100]);

	//NOTE(review): presumably 16 and 4 are the depth and width multiplier of the wide ResNet -- confirm.
	auto preds = wideResNet(features, 16, 4)
				.dense(100)
				.softmax();

	auto network = new DAGNetwork([features], [preds]);

	//Separate loss symbols are built from the layer's trainOutput and output expressions, so that
	//the optimiser and the evaluation plan each use the matching variant of the network.
	auto lossSym = crossEntropy(preds.trainOutput, labels) + network.paramLoss;
	auto testLossSym = crossEntropy(preds.output, labels) + network.paramLoss;

	writeln("Creating optimiser...");
	auto learningRate = float32(0.1f);
	auto momentumRate = float32(0.9);
	auto updater = sgd([lossSym, preds.trainOutput], network.params, network.paramProj, learningRate, momentumRate);
	auto testPlan = compile([testLossSym, preds.output]);

	writeln("Training...");

	//Host-side staging buffers that each minibatch is copied into
	float[] fs = new float[features.volume];
	float[] ls = new float[labels.volume];

	//Iterate for numEpochs epochs of training!
	foreach(e; 0 .. numEpochs)
	{
		float trainLoss = 0;
		float testLoss = 0;
		float trainAcc = 0;
		float testAcc = 0;
		float trainNum = 0;
		float testNum = 0;

		//Decreasing the learning rate after a while often results in better performance.
		//The rate is divided by five at epochs 60, 120, and 160.
		if(e == 60)
		{
			learningRate.value.set([0.02f]);
		}
		else if(e == 120)
		{
			learningRate.value.set([0.004f]);
		}
		else if(e == 160)
		{
			learningRate.value.set([0.0008f]);
		}

		data.train.restart();
		data.test.restart();

		//NOTE(review): integer division drops any final partial batch from the progress total, and
		//the sample counters below always advance by batchSize -- confirm how getBatch/finished
		//behave when the dataset length is not a multiple of batchSize.
		auto trainProgress = new Progress(data.train.length / batchSize);

		while(!data.train.finished())
		{
			//Get the next batch of training data (put into [fs, ls]).
			data.train.getBatch([fs, ls]);

			//Make an update to the model parameters using the minibatch of training data
			auto res = updater([
				features: buffer(fs),
				labels: buffer(ls)
			]);

			//res[0] is the loss and res[1] the predictions, in the order they were passed to sgd.
			//The loss is scaled by batchSize so that dividing by trainNum gives a running average.
			trainLoss += res[0].get!float[0] * batchSize;
			trainAcc += computeAccuracy(ls, res[1].get!float);
			trainNum += batchSize;

			float loss = trainLoss / trainNum;
			float acc = trainAcc / trainNum;

			trainProgress.title = format("Epoch: %03d  Loss: %02.4f  Acc: %.4f", e + 1, loss, acc);
			trainProgress.next();
		}

		writeln();

		auto testProgress = new Progress(data.test.length / batchSize);

		while(!data.test.finished())
		{
			//Get the next batch of testing data
			data.test.getBatch([fs, ls]);

			//Make some predictions
			auto res = testPlan.execute([
				features: buffer(fs),
				labels: buffer(ls)
			]);

			//Same running-average bookkeeping as for the training set
			testLoss += res[0].get!float[0] * batchSize;
			testAcc += computeAccuracy(ls, res[1].get!float);
			testNum += batchSize;

			float loss = testLoss / testNum;
			float acc = testAcc / testNum;

			testProgress.title = format("            Loss: %02.4f  Acc: %.4f", loss, acc);
			testProgress.next();
		}

		writeln();
		writeln();
	}
}
138 
/**
	Counts how many predictions select the same class as the corresponding label.

	Both arrays are treated as a sequence of consecutive rows, each numClasses elements long.
	A row counts as correct when the index of its largest prediction equals the index of the
	largest value in the matching label row. Rows are paired up in order, and counting stops
	at the end of the shorter array.

	Note that despite the name, the result is a count and not a fraction: callers are expected
	to divide by the number of examples themselves.

	Params:
		ls = flattened label rows (e.g., one-hot vectors).
		preds = flattened prediction rows, laid out the same way as ls.
		numClasses = number of elements per row. Defaults to 100, the number of CIFAR-100 classes.

	Returns:
		The number of rows for which the predicted class matches the label, or 0 if numClasses is 0.
*/
float computeAccuracy(float[] ls, float[] preds, size_t numClasses = 100)
{
	import std.algorithm : maxIndex;
	import std.range : chunks, zip;

	//chunks requires a non-zero chunk size; with zero classes there is nothing to classify
	if(numClasses == 0)
	{
		return 0;
	}

	float correct = 0;

	foreach(p, t; zip(preds.chunks(numClasses), ls.chunks(numClasses)))
	{
		if(p.maxIndex == t.maxIndex)
		{
			correct++;
		}
	}

	return correct;
}