1 #!/usr/bin/env dub
2 /+ dub.sdl:
3 dependency "dopt" path=".."
4 +/
5 module mnist;
6 
7 /*
	This example trains a small convolutional network on the MNIST dataset of hand-written digits. The network used is
	very small by today's standards, but MNIST is a very easy dataset, so this does not really matter.
10 
11 	The MNIST dataset contains small monochrome images of hand-written digits, and the goal is to determine which digit
12 	each image contains.
13 */
// Trains a small convolutional network on the MNIST training split, printing the mean training loss after every
// epoch, and finally prints the classification accuracy of the trained model on the MNIST test split.
//
// Params:
//     args = command-line arguments; args[1] must name the directory containing the MNIST binary files.
void main(string[] args)
{
	import std.algorithm : maxIndex;
	import std.range : zip, chunks;
	import std.stdio : stderr, writeln;

	import dopt.core;
	import dopt.nnet;
	import dopt.online;

	if(args.length != 2)
	{
		stderr.writeln("Usage: mnist.d <data directory>");
		return;
	}

	//Training configuration. numClasses is shared by the label shape, the final dense layer and the accuracy check,
	//so it is defined once to keep them consistent.
	enum size_t batchSize = 100;
	enum size_t numClasses = 10;
	enum numEpochs = 40;
	enum lrDecayEpoch = 30;
	enum initialLR = 0.001f;
	enum decayedLR = 0.0001f;

	//Load the MNIST dataset of hand-written digits. Download the binary files from http://yann.lecun.com/exdb/mnist/
	auto data = loadMNIST(args[1]);

	//Create the variable nodes required to pass data into the operation graph
	auto features = float32([batchSize, 1, 28, 28]);
	auto labels = float32([batchSize, numClasses]);

	//Construct a small convolutional network
	auto preds = dataSource(features)
				.conv2D(32, [5, 5])
				.relu()
				.maxPool([2, 2])
				.conv2D(32, [5, 5])
				.relu()
				.maxPool([2, 2])
				.dense(numClasses)
				.softmax();

	//Construct the DAGNetwork object that can be used to collate all the parameters and loss terms
	auto network = new DAGNetwork([features], [preds]);

	//Create a symbol to represent the training loss function
	auto lossSym = crossEntropy(preds.trainOutput, labels) + network.paramLoss;

	//Create an optimiser that can use minibatches of labelled data to update the weights of the network.
	//lr must be handed to adam as its learning-rate input; if it is not, the decay applied during training
	//below modifies a node the optimiser never reads and has no effect.
	auto lr = float32([], [initialLR]);
	auto updater = adam([lossSym], network.params, network.paramProj, lr);

	size_t bidx;
	float[] fs = new float[features.volume];
	float[] ls = new float[labels.volume];

	//Iterate over the training set numEpochs times
	foreach(e; 0 .. numEpochs)
	{
		float totloss = 0;
		size_t numBatches = 0;

		//Drop the learning rate for the final epochs by overwriting the value held in the lr node
		//(assumes the updater reads lr's current value on every call — confirm against dopt.online.adam)
		if(e == lrDecayEpoch)
		{
			lr.value.as!float[0] = decayedLR;
		}

		data.shuffle(0);

		do
		{
			//Get the next batch of training data (put into [fs, ls]). Update bidx with the next batch index.
			bidx = data.getBatch([fs, ls], bidx, 0);

			auto loss = updater([
				features: Buffer(fs),
				labels: Buffer(ls)
			]);

			totloss += loss[0].as!float[0];
			numBatches++;
		}
		while(bidx != 0);

		//Write out the mean training loss for this epoch (numBatches >= 1 because of the do-while)
		writeln(e, ": ", totloss / numBatches);
	}

	int correct;
	int total;

	//Start the test pass from the first batch rather than relying on the training loop's exit state
	bidx = 0;

	do
	{
		//Get the next batch of test data (put into [fs, ls]). Update bidx with the next batch index.
		bidx = data.getBatch([fs, ls], bidx, 1);

		//Make some predictions for this minibatch
		auto pred = network.outputs[0].evaluate([
			features: Buffer(fs)
		]).as!float;

		//Determine the accuracy of these predictions using the ground truth data: one row of numClasses scores
		//per example, and a prediction is correct when its arg-max matches the label's arg-max
		foreach(p, t; zip(pred.chunks(numClasses), ls.chunks(numClasses)))
		{
			if(p.maxIndex == t.maxIndex)
			{
				correct++;
			}

			total++;
		}
	}
	while(bidx != 0);

	//Write out the accuracy of the model on the test set
	writeln(correct / cast(float)total);
}