#!/usr/bin/env dub
/+ dub.sdl:
dependency "dopt" path=".."
+/
module mnistlogit;
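
//This file is a single-file dub package: with dub installed, it can be run
//directly as `./mnistlogit.d <data directory>`, or with
//`dub run --single mnistlogit.d -- <data directory>`.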

/*
	This example trains a logistic regression model on the MNIST dataset of hand-written digits.

	The MNIST dataset contains small monochrome images of hand-written digits, and the goal is to determine which
	digit each image contains.
*/
void main(string[] args)
{
	import std.algorithm : maxIndex;
	import std.range : zip, chunks;
	import std.stdio : stderr, writeln;

	import dopt.core;
	import dopt.nnet;
	import dopt.online;

	if(args.length != 2)
	{
		stderr.writeln("Usage: mnistlogit.d <data directory>");
		return;
	}

	//Load the MNIST dataset of hand-written digits. Download the binary files from http://yann.lecun.com/exdb/mnist/
	auto data = loadMNIST(args[1]);
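	//Judging by the calls below, the dataset addresses its two splits by index:
	//0 selects the training set and 1 selects the test set.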

	//Create the variable nodes required to pass data into the operation graph
	size_t batchSize = 100;
	auto features = float32([batchSize, 28 * 28]);
	auto labels = float32([batchSize, 10]);
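	//Each row of labels holds a one-hot encoding of the digit: for example, the
	//digit 3 corresponds to the row [0, 0, 0, 1, 0, 0, 0, 0, 0, 0].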

	//Construct a logistic regression model: an affine map followed by a softmax
	auto W = float32([28 * 28, 10]);
	auto b = float32([10]);
	auto linComb = matmul(features, W) + b.repeat(batchSize);

	//Subtract the maximum activation before exponentiating, so that exp cannot
	//overflow; the maximum is repeated to match linComb's shape before subtraction
	auto maxElem = linComb.maxElement().reshape([1, 1]).repeat([batchSize, 10]);
	auto numerator = exp(linComb - maxElem);
	auto denominator = numerator.sum([1]).reshape([batchSize, 1]).repeat([1, 10]);
	auto preds = numerator / denominator;
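	/*
		The lines above implement the softmax function

			preds[i, j] = exp(z[i, j] - m) / sum_k exp(z[i, k] - m)

		with z = linComb and m = max z. Subtracting m leaves each ratio unchanged,
		since the common factor exp(-m) cancels, but keeps every argument of exp()
		non-positive, so the exponentials cannot overflow.
	*/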

	//Create a symbol to represent the training loss function
	auto lossSym = crossEntropy(preds, labels);
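	/*
		For one-hot labels y and predicted probabilities p, the cross-entropy loss is

			L = -sum_{i, j} y[i, j] * log(p[i, j])

		up to a normalisation constant, depending on how crossEntropy is defined.
		Minimising it pushes each predicted row of preds towards its one-hot label.
	*/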

	//Create an optimiser that can use minibatches of labelled data to update the parameters of the model.
	//lr is the step size; passing it after the projections map assumes adam follows the usual
	//dopt.online optimiser signature. `null` means no parameter projections are applied.
	auto lr = float32([], [0.001f]);
	auto updater = adam([lossSym], [W, b], null, lr);
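	/*
		Adam keeps exponentially decaying averages of each parameter's gradient and
		of its elementwise square, and divides each step by the root of the squared
		average, so every parameter gets its own adaptive step size.
	*/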

	size_t bidx;
	float[] fs = new float[features.volume];
	float[] ls = new float[labels.volume];
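	//fs and ls are host-side staging buffers holding exactly one minibatch:
	//features.volume is batchSize * 28 * 28 and labels.volume is batchSize * 10.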

	//Iterate for 50 epochs of training
	foreach(e; 0 .. 50)
	{
		float totloss = 0;
		float tot = 0;

		//Shuffle the training set so each epoch visits the minibatches in a new order
		data.shuffle(0);

		do
		{
			//Copy the next training minibatch into [fs, ls]; getBatch returns the
			//index of the next batch, wrapping back to zero after a full pass
			bidx = data.getBatch([fs, ls], bidx, 0);

			//Run one Adam update on this minibatch and read back the scalar loss
			auto loss = updater([
				features: Buffer(fs),
				labels: Buffer(ls)
			]);

			totloss += loss[0].as!float[0];
			tot++;
		}
		while(bidx != 0);

		//Write out the average training loss for this epoch
		writeln(e, ": ", totloss / tot);
	}

	int correct;
	int total;
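
	//The training loop above always completes a full pass, so bidx is zero again
	//here and the test loop starts from the first test batch.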
	do
	{
		//Copy the next test minibatch into [fs, ls]; the final argument selects
		//the test set rather than the training set
		bidx = data.getBatch([fs, ls], bidx, 1);

		//Make some predictions for this minibatch
		auto pred = evaluate([preds], [
			features: Buffer(fs)
		])[0].as!float;

		//Determine the accuracy of these predictions by comparing the index of the
		//largest predicted probability with the index set in the one-hot label
		foreach(p, t; zip(pred.chunks(10), ls.chunks(10)))
		{
			if(p.maxIndex == t.maxIndex)
			{
				correct++;
			}

			total++;
		}
	}
	while(bidx != 0);

	//Write out the accuracy of the model on the test set
	writeln(correct / cast(float)total);
}