#!/usr/bin/env dub
/+ dub.sdl:
dependency "dopt" path=".."
+/
module mnist;

/*
    This example trains a small convolutional network on the MNIST dataset of hand-written images. The network used
    is very small by today's standards, but MNIST is a very easy dataset so this does not really matter.

    The MNIST dataset contains small monochrome images of hand-written digits, and the goal is to determine which
    digit each image contains.
*/
void main(string[] args)
{
    import std.algorithm : joiner, maxIndex;
    import std.array : array;
    import std.range : zip, chunks;
    import std.stdio : stderr, writeln;

    import dopt.core;
    import dopt.nnet;
    import dopt.online;

    if(args.length != 2)
    {
        stderr.writeln("Usage: mnist.d <data directory>");
        return;
    }

    //Load the MNIST dataset of hand-written digits. Download the binary files from http://yann.lecun.com/exdb/mnist/
    auto data = loadMNIST(args[1]);

    //Create the variable nodes required to pass data into the operation graph.
    //Images are 28x28 single-channel; labels are one-hot over the 10 digit classes.
    size_t batchSize = 100;
    auto features = float32([batchSize, 1, 28, 28]);
    auto labels = float32([batchSize, 10]);

    //Construct a small convolutional network: two conv->relu->maxpool stages, then a dense softmax classifier
    auto preds = dataSource(features)
                .conv2D(32, [5, 5])
                .relu()
                .maxPool([2, 2])
                .conv2D(32, [5, 5])
                .relu()
                .maxPool([2, 2])
                .dense(10)
                .softmax();

    //Construct the DAGNetwork object that can be used to collate all the parameters and loss terms
    auto network = new DAGNetwork([features], [preds]);

    //Create a symbol to represent the training loss function (data term + parameter regularisation)
    auto lossSym = crossEntropy(preds.trainOutput, labels) + network.paramLoss;

    //Create an optimiser that can use minibatches of labelled data to update the weights of the network.
    //FIX: lr must be wired into the updater — previously it was created but never passed to adam, so the
    //learning-rate drop at epoch 30 below silently had no effect on training.
    auto lr = float32([], [0.001f]);
    auto updater = adam([lossSym], network.params, network.paramProj, lr);

    size_t bidx;
    float[] fs = new float[features.volume];
    float[] ls = new float[labels.volume];

    //Iterate for 40 epochs of training
    foreach(e; 0 .. 40)
    {
        float totloss = 0;
        float tot = 0;

        //Decay the learning rate by 10x for the final 10 epochs
        if(e == 30)
        {
            lr.value.as!float[0] = 0.0001f;
        }

        //Shuffle the training fold (fold 0) before each epoch
        data.shuffle(0);

        do
        {
            //Get the next batch of training data (put into [fs, ls]). Update bidx with the next batch index.
            bidx = data.getBatch([fs, ls], bidx, 0);

            auto loss = updater([
                features: Buffer(fs),
                labels: Buffer(ls)
            ]);

            totloss += loss[0].as!float[0];
            tot++;
        }
        while(bidx != 0); //getBatch returns 0 once the fold has been exhausted

        //Write out the mean training loss for this epoch
        writeln(e, ": ", totloss / tot);
    }

    int correct;
    int total;

    do
    {
        //Get the next batch of test data (fold 1, put into [fs, ls]). Update bidx with the next batch index.
        bidx = data.getBatch([fs, ls], bidx, 1);

        //Make some predictions for this minibatch
        auto pred = network.outputs[0].evaluate([
            features: Buffer(fs)
        ]).as!float;

        //Determine the accuracy of these predictions using the ground truth data. Each example occupies 10
        //consecutive floats (one score/indicator per class), so compare argmax over each 10-element chunk.
        foreach(p, t; zip(pred.chunks(10), ls.chunks(10)))
        {
            if(p.maxIndex == t.maxIndex)
            {
                correct++;
            }

            total++;
        }
    }
    while(bidx != 0);

    //Write out the accuracy of the model on the test set
    writeln(correct / cast(float)total);
}