#!/usr/bin/env dub
/+ dub.sdl:
dependency "dopt" path=".."
+/
module mnistlogit;

/*
    This example trains a logistic regression model on the MNIST dataset of hand-written images.

    The MNIST dataset contains small monochrome images of hand-written digits, and the goal is to determine which
    digit each image contains.
*/
void main(string[] args)
{
    import std.algorithm : joiner, maxIndex;
    import std.array : array;
    import std.range : zip, chunks;
    import std.stdio : stderr, writeln;

    import dopt.core;
    import dopt.nnet;
    import dopt.online;

    if(args.length != 2)
    {
        stderr.writeln("Usage: mnistlogit <data directory>");
        return;
    }

    //Load the MNIST dataset of hand-written digits. Download the binary files from http://yann.lecun.com/exdb/mnist/
    auto data = loadMNIST(args[1]);

    //Create the variable nodes required to pass data into the operation graph
    size_t batchSize = 100;
    auto features = float32([batchSize, 28 * 28]);
    auto labels = float32([batchSize, 10]);

    //Construct a logistic regression model: a single affine map followed by a softmax
    auto W = float32([28 * 28, 10]);
    auto b = float32([10]);
    auto linComb = matmul(features, W) + b.repeat(batchSize);

    //Numerically stable softmax: subtract the global max before exponentiating so exp() cannot overflow
    auto numerator = exp(linComb - linComb.maxElement());
    auto denominator = numerator.sum([1]).reshape([batchSize, 1]).repeat([1, 10]);
    auto preds = numerator / denominator;

    //Create a symbol to represent the training loss function
    auto lossSym = crossEntropy(preds, labels);

    //Create an optimiser that can use minibatches of labelled data to update the parameters of the model.
    //The learning-rate operand is passed to adam explicitly — previously `lr` was constructed but never used,
    //so this line had no effect on training.
    auto lr = float32([], [0.001f]);
    auto updater = adam([lossSym], [W, b], null, lr);

    size_t bidx;
    float[] fs = new float[features.volume];
    float[] ls = new float[labels.volume];

    //Iterate for 50 epochs of training
    foreach(e; 0 .. 50)
    {
        float totloss = 0;
        float tot = 0;

        //Reshuffle the training fold (fold index 0) at the start of each epoch
        data.shuffle(0);

        do
        {
            //Get the next batch of training data (put into [fs, ls]). Update bidx with the next batch index.
            bidx = data.getBatch([fs, ls], bidx, 0);

            auto loss = updater([
                features: Buffer(fs),
                labels: Buffer(ls)
            ]);

            totloss += loss[0].as!float[0];
            tot++;
        }
        while(bidx != 0);

        //Write out the mean training loss for this epoch
        writeln(e, ": ", totloss / tot);
    }

    int correct;
    int total;

    //Evaluate on the test fold (fold index 1). bidx is 0 here because the training loop above
    //only terminates once getBatch wraps around to the start of the fold.
    do
    {
        //Get the next batch of test data (put into [fs, ls]). Update bidx with the next batch index.
        bidx = data.getBatch([fs, ls], bidx, 1);

        //Make some predictions for this minibatch
        auto pred = evaluate([preds], [
            features: Buffer(fs)
        ])[0].as!float;

        //Determine the accuracy of these predictions using the ground truth data. Each row of 10
        //values is one example; the predicted/true class is the index of the largest entry.
        foreach(p, t; zip(pred.chunks(10), ls.chunks(10)))
        {
            if(p.maxIndex == t.maxIndex)
            {
                correct++;
            }

            total++;
        }
    }
    while(bidx != 0);

    //Write out the accuracy of the model on the test set
    writeln(correct / cast(float)total);
}