1 #!/usr/bin/env dub
2 /+ dub.sdl:
3 dependency "dopt" path=".."
4 dependency "progress-d" version="~>1.0.0"
5 +/
6 module cifar10;
7 
8 import dopt.core;
9 import dopt.nnet;
10 import dopt.online;
11 import progress;
12 
13 /*
14 	This example trains a VGG19-style network on the CIFAR-10 dataset of tiny images.
15 
16 	VGG networks are fairly easy to understand, compared to some of the more recently presented models like GoogLeNet.
17 	See ``Very Deep Convolutional Networks for Large-Scale Image Recognition'' by Simonyan and Zisserman for more
18 	details. This example uses the dopt.nnet.models package to make defining a VGG model very easy.
19 
20 	The CIFAR-10 dataset contains 60,000 32x32 pixel colour images. Each of these images belongs to one of 10 classes.
21 	In the standard setting, 50,000 of these images are used for training a model, and the other 10,000 are used for
22 	evaluating how well the model works.
23 */
24 
void main(string[] args)
{
	import std.format : format;
	import std.stdio : stderr, writeln;

	if(args.length != 2)
	{
		stderr.writeln("Usage: cifar10.d <data directory>");
		return;
	}

	/*
		Loads the CIFAR-10 dataset. Download this in the binary format from https://www.cs.toronto.edu/~kriz/cifar.html

		This also wraps the Dataset in an ImageTransformer, which will procedurally generate random crops and
		horizontal flips of the training images---a popular form of data augmentation for image datasets.
		NOTE(review): the arguments (4, 4, true, false) presumably select a +/-4 pixel crop jitter on each axis,
		horizontal flips on and vertical flips off -- confirm against the ImageTransformer documentation.
	*/
	writeln("Loading data...");
	auto data = new ImageTransformer(loadCIFAR10(args[1]), 4, 4, true, false);

	/*
		Now we create two variable nodes. ``features'' represents a minibatch of input images (3 x 32 x 32 each), and
		``labels'' represents the 10-value label row corresponding to each of those images.
	*/
	writeln("Constructing network graph...");
	size_t batchSize = 100;
	auto features = float32([batchSize, 3, 32, 32]);
	auto labels = float32([batchSize, 10]);

	/*
		There are a few predefined models in dopt.nnet.models, such as vgg19. We provide it with the variable we want
		to use as the input to this model, tell it what sizes the fully connected layers should be, and then put a
		softmax activation function on the end. The softmax function is the standard activation function when one is
		performing a classification task. The model is regularised using dropout, batch norm, and maxgain.
	*/
	auto preds = vgg19(features, [512, 512], true, true, 3.0f)
		.dense(10, new DenseOptions().maxgain(3.0f))
		.softmax();

	//The DAGNetwork class takes the inputs and outputs of a network and collects what lies between them: the
	//trainable parameters (params), the extra loss terms (paramLoss), and the parameter projections (paramProj)
	//that are used below.
	auto network = new DAGNetwork([features], [preds]);

	/*
		Layer objects have both ``output'' and ``trainOutput'' fields, because operations like dropout perform
		different computations at train and test time. Therefore, we construct two different loss symbols: one for
		optimising, and one for evaluating.
	*/
	auto lossSym = crossEntropy(preds.trainOutput, labels) + network.paramLoss;
	auto testLossSym = crossEntropy(preds.output, labels) + network.paramLoss;

	/*
		Now we set up an optimiser. Adam is good for proof of concepts, due to the fast convergence, however the
		performance of the final model is often slightly worse than that of a model trained with SGD+momentum.
	*/
	writeln("Creating optimiser...");
	auto learningRate = float32([], [0.0001f]);
	auto updater = adam([lossSym, preds.trainOutput], network.params, network.paramProj, learningRate);

	writeln("Training...");

	//Host-side staging buffers for one minibatch; getBatch fills them in place.
	float[] fs = new float[features.volume];
	float[] ls = new float[labels.volume];

	//Train for 140 epochs. The learning rate is dropped by 10x at epochs 100 and 120, so the loop has to run past
	//epoch 120 for the second drop to ever take effect.
	enum numEpochs = 140;

	foreach(e; 0 .. numEpochs)
	{
		float trainLoss = 0;
		float testLoss = 0;
		float trainAcc = 0;
		float testAcc = 0;
		float trainNum = 0;
		float testNum = 0;

		//Decreasing the learning rate after a while often results in better performance.
		if(e == 100)
		{
			learningRate.value.as!float[0] = 0.00001f;
		}
		else if(e == 120)
		{
			learningRate.value.as!float[0] = 0.000001f;
		}

		auto trainProgress = new Progress(data.foldSize(0) / batchSize);

		data.shuffle(0);

		//Batch cursor into the training fold; getBatch returns 0 once the fold has been exhausted.
		size_t trainIdx = 0;

		do
		{
			//Get the next batch of training data (put into [fs, ls]). Update trainIdx with the next batch index.
			trainIdx = data.getBatch([fs, ls], trainIdx, 0);

			//Make an update to the model parameters using the minibatch of training data
			auto res = updater([
				features: Buffer(fs),
				labels: Buffer(ls)
			]);

			//res[0] is the mean loss of the batch, so weight it by the batch size before averaging.
			trainLoss += res[0].as!float[0] * batchSize;
			trainAcc += computeAccuracy(ls, res[1].as!float);
			trainNum += batchSize;

			float loss = trainLoss / trainNum;
			float acc = trainAcc / trainNum;

			trainProgress.title = format("Epoch: %03d  Loss: %02.4f  Acc: %.4f", e + 1, loss, acc);
			trainProgress.next();
		}
		while(trainIdx != 0);

		writeln();

		auto testProgress = new Progress(data.foldSize(1) / batchSize);

		//Independent batch cursor into the test fold.
		size_t testIdx = 0;

		do
		{
			//Get the next batch of testing data
			testIdx = data.getBatch([fs, ls], testIdx, 1);

			//Make some predictions
			auto res = evaluate([testLossSym, preds.output], [
				features: Buffer(fs),
				labels: Buffer(ls)
			]);

			testLoss += res[0].as!float[0] * batchSize;
			testAcc += computeAccuracy(ls, res[1].as!float);
			testNum += batchSize;

			float loss = testLoss / testNum;
			float acc = testAcc / testNum;

			testProgress.title = format("            Loss: %02.4f  Acc: %.4f", loss, acc);
			testProgress.next();
		}
		while(testIdx != 0);

		writeln();
		writeln();
	}
}
170 
/**
	Counts how many examples in a minibatch were classified correctly.

	Both arrays are read as consecutive rows of numClasses values, one row per example. A row counts as correct when
	the index of its largest prediction equals the index of its largest label value. maxIndex returns the first
	maximum on ties.

	Params:
		ls = The target labels for the minibatch, numClasses values per example.
		preds = The model outputs for the same minibatch, laid out the same way as ls.
		numClasses = The number of values per example. Defaults to 10, the number of CIFAR-10 classes.

	Returns:
		The number of correctly classified examples, as a float. This is a count, not a fraction; divide by the number
		of examples to obtain an accuracy.
*/
float computeAccuracy(float[] ls, float[] preds, size_t numClasses = 10)
{
	import std.algorithm : count, maxIndex;
	import std.range : chunks, zip;

	assert(numClasses > 0, "numClasses must be positive");
	//zip stops at the shorter range, so mismatched inputs would otherwise be silently truncated.
	assert(ls.length == preds.length, "ls and preds must have the same length");

	auto correct = zip(preds.chunks(numClasses), ls.chunks(numClasses))
		.count!(pair => pair[0].maxIndex == pair[1].maxIndex);

	return cast(float)correct;
}