1 #!/usr/bin/env dub
2 /+ dub.sdl:
3 dependency "dopt" path=".."
4 dependency "progress-d" version="~>1.0.0"
5 +/
6 module cifar10;
7 
8 import dopt.core;
9 import dopt.nnet;
10 import dopt.online;
11 import progress;
12 
13 /*
14 	This example trains a VGG19-style network on the CIFAR-10 dataset of tiny images.
15 
16 	VGG networks are fairly easy to understand, compared to some of the more recently presented models like GoogLeNet.
17 	See ``Very Deep Convolutional Networks for Large-Scale Image Recognition'' by Simonyan and Zisserman for more
18 	details. This example uses the dopt.nnet.models package to make defining a VGG model very easy.
19 
20 	The CIFAR-10 dataset contains 60,000 32x32 pixel colour images. Each of these images belongs to one of 10 classes.
21 	In the standard setting, 50,000 of these images are used for training a model, and the other 10,000 are used for
22 	evaluating how well the model works.
23 */
24 
/**
	Entry point. Builds a VGG19-style classifier, then alternates between one pass over the CIFAR-10 training
	set (updating the parameters) and one pass over the test set (evaluation only), showing the running loss
	and accuracy of each pass in a progress bar.

	Params:
		args = the command line arguments; args[1] is the directory passed to loadCIFAR10.
*/
void main(string[] args)
{
	import std.format : format;
	import std.stdio : stderr, writeln;

	if(args.length != 2)
	{
		stderr.writeln("Usage: cifar10.d <data directory>");
		return;
	}

	/*
		Loads the CIFAR-10 dataset. Download this in the binary format from https://www.cs.toronto.edu/~kriz/cifar.html

		This also wraps the Dataset in an ImageTransformer, which will procedurally generate random crops and
		horizontal flips of the training images---a popular form of data augmentation for image datasets.
	*/
	writeln("Loading data...");
	auto data = loadCIFAR10(args[1]);
	data.train = new ImageTransformer(data.train, 4, 4, true, false);

	/*
	Now we create two variable nodes. ``features'' is used to represent a minibatch of input images, and ``labels''
	will be used to represent the label corresponding to each of those images.
	*/
	writeln("Constructing network graph...");
	size_t batchSize = 100;
	auto features = float32([batchSize, 3, 32, 32]);
	auto labels = float32([batchSize, 10]);

	/*
	There are a few predefined models in dopt.nnet.models, such as vgg19. We provide it with the variable we want to
	use as the input to this model, tell it what sizes the fully connected layers should be, and then put a softmax
	activation function on the end. The softmax function is the standard activation function when one is performing
	a classification task.
	*/
	auto preds = vgg19(features, [512, 512])
				.dense(10)
				.softmax();

	/*
	The DAGNetwork class takes the inputs and outputs of a network and aggregates what the optimiser needs from
	its layers: the trainable parameters, the parameter loss term, and the parameter projection functions (all
	three are used below).
	*/
	auto network = new DAGNetwork([features], [preds]);

	/*
	Layer objects have both ``output'' and ``trainOutput'' fields, because operations like dropout perform different
	computations at train and test time. Therefore, we construct two different loss symbols: one for optimising, and
	one for evaluating.
	*/
	auto lossSym = crossEntropy(preds.trainOutput, labels) + network.paramLoss;
	auto testLossSym = crossEntropy(preds.output, labels) + network.paramLoss;

	/*
	Now we set up an optimiser. This example uses AMSGrad, a variant of Adam. Adam-style optimisers are good for
	proof of concepts, due to the fast convergence, however the performance of the final model is often slightly
	worse than that of a model trained with SGD+momentum.

	The learning rate is itself a graph node, so that its value can be changed between epochs (see below).
	*/
	writeln("Creating optimiser...");
	auto learningRate = float32([], [0.0001f]);
	auto updater = amsgrad([lossSym, preds.trainOutput], network.params, network.paramProj, learningRate);
	auto testPlan = compile([testLossSym, preds.output]);

	writeln("Training...");

	//Host-side staging buffers that each minibatch is copied into. They are reused for both training and testing.
	float[] fs = new float[features.volume];
	float[] ls = new float[labels.volume];

	//Iterate for 140 epochs of training!
	foreach(e; 0 .. 140)
	{
		//Running totals for this epoch. The *Num counters hold the number of examples seen so far.
		float trainLoss = 0;
		float testLoss = 0;
		float trainAcc = 0;
		float testAcc = 0;
		float trainNum = 0;
		float testNum = 0;

		//Decreasing the learning rate after a while often results in better performance.
		if(e == 100)
		{
			learningRate.value.set([0.00001f]);
		}
		else if(e == 120)
		{
			learningRate.value.set([0.000001f]);
		}

		data.train.restart();
		data.test.restart();

		auto trainProgress = new Progress(data.train.length / batchSize);

		while(!data.train.finished())
		{
			data.train.getBatch([fs, ls]);

			//Make an update to the model parameters using the minibatch of training data
			auto res = updater([
				features: buffer(fs),
				labels: buffer(ls)
			]);

			//res[0] is the value of lossSym and res[1] the value of preds.trainOutput, in the order given to amsgrad.
			//The loss is weighted by batchSize so that dividing by trainNum gives a per-example average.
			trainLoss += res[0].get!float[0] * batchSize;
			trainAcc += computeAccuracy(ls, res[1].get!float);
			trainNum += batchSize;

			float loss = trainLoss / trainNum;
			float acc = trainAcc / trainNum;

			trainProgress.title = format("Epoch: %03d  Loss: %02.4f  Acc: %.4f", e + 1, loss, acc);
			trainProgress.next();
		}

		writeln();

		auto testProgress = new Progress(data.test.length / batchSize);

		while(!data.test.finished())
		{
			data.test.getBatch([fs, ls]);

			//Make some predictions
			auto res = testPlan.execute([
				features: buffer(fs),
				labels: buffer(ls)
			]);

			//res[0] is the value of testLossSym and res[1] the value of preds.output, in the order given to compile.
			testLoss += res[0].get!float[0] * batchSize;
			testAcc += computeAccuracy(ls, res[1].get!float);
			testNum += batchSize;

			float loss = testLoss / testNum;
			float acc = testAcc / testNum;

			testProgress.title = format("            Loss: %02.4f  Acc: %.4f", loss, acc);
			testProgress.next();
		}

		writeln();
		writeln();
	}
}
168 
/**
	Counts how many predictions in a minibatch select the same class as the corresponding label.

	Both arrays are read as consecutive rows of numClasses values, one row per example. A prediction counts as
	correct when the index of its largest value equals the index of the largest value in the matching label row.

	Note that, despite the name, this returns the number of correct predictions rather than a ratio; the caller
	is expected to divide by the number of examples.

	Params:
		ls = the flattened label rows.
		preds = the flattened prediction rows, in the same order as ls.
		numClasses = the number of values per row. Defaults to 10, the number of CIFAR-10 classes.

	Returns:
		The number of rows whose predicted class matches the label, as a float. If the two arrays contain a
		different number of rows, only the rows present in both are compared.
*/
float computeAccuracy(float[] ls, float[] preds, size_t numClasses = 10)
{
	import std.algorithm : count, maxIndex;
	import std.range : chunks, zip;

	//Pair each prediction row with its label row and count the pairs whose arg-max indices agree.
	const numCorrect = zip(preds.chunks(numClasses), ls.chunks(numClasses))
		.count!(pair => pair[0].maxIndex == pair[1].maxIndex);

	return numCorrect;
}