#!/usr/bin/env dub
/+ dub.sdl:
dependency "dopt" path=".."
dependency "progress-d" version="~>1.0.0"
+/
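//This file is laid out as a single-file dub package: with dub installed (and the script marked
//executable), it should be possible to run it directly, e.g. `./sins10.d <data directory> <fold index>`.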
module sins10;

import dopt.core;
import dopt.nnet;
import dopt.online;
import progress;

/**
    This example trains a Wide Residual Network on the SINS-10 dataset.

    The Wide Residual Network family of architectures is described in the paper "Wide Residual Networks" by
    Sergey Zagoruyko and Nikos Komodakis, published at the 2016 British Machine Vision Conference.

    The SINS-10 dataset is designed to enable practical significance testing for deep learning experiments. It can be
    downloaded from https://www.cs.waikato.ac.nz/~ml/sins10/
*/
void main(string[] args)
{
    import std.algorithm : joiner;
    import std.array : array;
    import std.conv : to;
    import std.format : format;
    import std.range : zip, chunks;
    import std.stdio : stderr, stdout, write, writeln;

    if(args.length != 3)
    {
        stderr.writeln("Usage: ./sins10.d <data directory> <fold index>");
        return;
    }

    writeln("Loading data...");
    size_t fold = args[2].to!size_t;
    auto data = loadSINS10(args[1])[fold];
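    //Wrap the training set in an ImageTransformer so that minibatches are augmented on the fly. The
    //arguments are assumed to control the maximum random shift (24 pixels in each direction) and
    //whether horizontal/vertical flips are applied.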
    data.train = new ImageTransformer(data.train, 24, 24, true, false);

    writeln("Constructing network graph...");
    size_t batchSize = 50;
    auto features = float32([batchSize, 3, 96, 96]);
    auto labels = float32([batchSize, 10]);

    auto opts = new WRNOptions();
    opts.stride = [2, 2, 2];

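    //Build a wide residual network and attach a 10-way dense softmax classifier. The two integer
    //arguments follow the WRN naming convention, so this is assumed to be a WRN-16-8: 16 layers deep
    //with a widening factor of 8.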
    auto preds = wideResNet(features, 16, 8, opts)
                .dense(10)
                .softmax();

    auto network = new DAGNetwork([features], [preds]);

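    //preds.trainOutput is the network's output in training mode and preds.output its output in
    //inference mode; network.paramLoss adds the parameter regularisation (e.g., weight decay) term
    //to both objectives.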
    auto lossSym = crossEntropy(preds.trainOutput, labels) + network.paramLoss;
    auto testLossSym = crossEntropy(preds.output, labels) + network.paramLoss;

    writeln("Creating optimiser...");
    auto learningRate = float32(0.1f);
    auto momentumRate = float32(0.9f);
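    //sgd returns an update plan: each invocation consumes one minibatch and applies an SGD-with-momentum
    //step to network.params. compile produces a forward-only plan that is used for evaluation.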
    auto updater = sgd([lossSym, preds.trainOutput], network.params, network.paramProj, learningRate, momentumRate);
    auto testPlan = compile([testLossSym, preds.output]);

    writeln("Training...");

    //Host-side buffers that each minibatch of images and labels is copied into before being passed
    //to the compiled plans.
    float[] fs = new float[features.volume];
    float[] ls = new float[labels.volume];

    //Iterate for 90 epochs of training!
    foreach(e; 0 .. 90)
    {
        float trainLoss = 0;
        float testLoss = 0;
        float trainAcc = 0;
        float testAcc = 0;
        float trainNum = 0;
        float testNum = 0;

        //Decreasing the learning rate after a while often results in better performance.
        if(e == 60)
        {
            learningRate.value.set([0.02f]);
        }
        else if(e == 80)
        {
            learningRate.value.set([0.004f]);
        }

        data.train.restart();
        data.test.restart();

        auto trainProgress = new Progress(data.train.length / batchSize);

        while(!data.train.finished())
        {
            //Copy the next minibatch of training data into [fs, ls]
            data.train.getBatch([fs, ls]);

            //Make an update to the model parameters using the minibatch of training data
            auto res = updater([
                features: buffer(fs),
                labels: buffer(ls)
            ]);
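            //res[0] holds the scalar training loss for this minibatch; res[1] holds the corresponding softmax outputs.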

            trainLoss += res[0].get!float[0] * batchSize;
            trainAcc += computeAccuracy(ls, res[1].get!float);
            trainNum += batchSize;

            float loss = trainLoss / trainNum;
            float acc = trainAcc / trainNum;

            trainProgress.title = format("Epoch: %03d  Loss: %02.4f  Acc: %.4f", e + 1, loss, acc);
            trainProgress.next();
        }

        writeln();
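        //Evaluate the current model on the held-out test data for this fold.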

        auto testProgress = new Progress(data.test.length / batchSize);

        while(!data.test.finished())
        {
            //Get the next batch of testing data
            data.test.getBatch([fs, ls]);

            //Make some predictions
            auto res = testPlan.execute([
                features: buffer(fs),
                labels: buffer(ls)
            ]);

            testLoss += res[0].get!float[0] * batchSize;
            testAcc += computeAccuracy(ls, res[1].get!float);
            testNum += batchSize;

            float loss = testLoss / testNum;
            float acc = testAcc / testNum;

            testProgress.title = format("            Loss: %02.4f  Acc: %.4f", loss, acc);
            testProgress.next();
        }

        writeln();
        writeln();
    }
}

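/**
    Counts the number of correct predictions in a minibatch.

    Both arrays are flattened batches of 10-element vectors: ls holds the one-hot ground truth labels
    and preds holds the predicted class scores. A prediction is counted as correct when the index of
    its largest score matches that of the corresponding label vector.
*/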
float computeAccuracy(float[] ls, float[] preds)
{
    import std.algorithm : maxIndex;
    import std.range : chunks, zip;

    float correct = 0;

    foreach(p, t; zip(preds.chunks(10), ls.chunks(10)))
    {
        if(p.maxIndex == t.maxIndex)
        {
            correct++;
        }
    }

    return correct;
}