#!/usr/bin/env dub
/+ dub.sdl:
dependency "dopt" path=".."
dependency "progress-d" version="~>1.0.0"
+/
module sins10;

import dopt.core;
import dopt.nnet;
import dopt.online;
import progress;

/**
    This example trains a Wide Residual Network on the SINS-10 dataset.

    The Wide Residual Network family of architectures is described in the paper "Wide Residual Networks" by
    Sergey Zagoruyko and Nikos Komodakis, presented at the 2016 British Machine Vision Conference.

    The SINS-10 dataset is designed to enable practical significance testing for deep learning experiments. It can be
    downloaded from https://www.cs.waikato.ac.nz/~ml/sins10/
*/
void main(string[] args)
{
    import std.conv : to;
    import std.format : format;
    import std.stdio : stderr, writeln;

    if(args.length != 3)
    {
        stderr.writeln("Usage: ./sins10.d <data directory> <fold index>");
        return;
    }

    writeln("Loading data...");
    size_t fold = args[2].to!size_t;
    auto data = loadSINS10(args[1])[fold];

    //Augment the training images with random crops and horizontal flips
    data.train = new ImageTransformer(data.train, 24, 24, true, false);

    writeln("Constructing network graph...");
    size_t batchSize = 50;
    auto features = float32([batchSize, 3, 96, 96]);
    auto labels = float32([batchSize, 10]);

    auto opts = new WRNOptions();
    opts.stride = [2, 2, 2];

    //WRN-16-8: a wide residual network with depth 16 and widening factor 8
    auto preds = wideResNet(features, 16, 8, opts)
                .dense(10)
                .softmax();

    auto network = new DAGNetwork([features], [preds]);

    //The training loss adds the network's parameter regularisation term to the cross-entropy loss
    auto lossSym = crossEntropy(preds.trainOutput, labels) + network.paramLoss;
    auto testLossSym = crossEntropy(preds.output, labels) + network.paramLoss;

    writeln("Creating optimiser...");
    auto learningRate = float32(0.1f);
    auto momentumRate = float32(0.9f);
    auto updater = sgd([lossSym, preds.trainOutput], network.params, network.paramProj, learningRate, momentumRate);
    auto testPlan = compile([testLossSym, preds.output]);

    writeln("Training...");

    float[] fs = new float[features.volume];
    float[] ls = new float[labels.volume];

    //Iterate for 90 epochs of training!
    foreach(e; 0 .. 90)
    {
        float trainLoss = 0;
        float testLoss = 0;
        float trainAcc = 0;
        float testAcc = 0;
        float trainNum = 0;
        float testNum = 0;

        //Decreasing the learning rate partway through training often results in better performance.
        if(e == 60)
        {
            learningRate.value.set([0.02f]);
        }
        else if(e == 80)
        {
            learningRate.value.set([0.004f]);
        }

        data.train.restart();
        data.test.restart();

        auto trainProgress = new Progress(data.train.length / batchSize);

        while(!data.train.finished())
        {
            //Get the next batch of training data (the features and labels are copied into fs and ls)
            data.train.getBatch([fs, ls]);

            //Make an update to the model parameters using this minibatch of training data
            auto res = updater([
                features: buffer(fs),
                labels: buffer(ls)
            ]);

            trainLoss += res[0].get!float[0] * batchSize;
            trainAcc += computeAccuracy(ls, res[1].get!float);
            trainNum += batchSize;

            float loss = trainLoss / trainNum;
            float acc = trainAcc / trainNum;

            trainProgress.title = format("Epoch: %03d Loss: %02.4f Acc: %.4f", e + 1, loss, acc);
            trainProgress.next();
        }

        writeln();

        auto testProgress = new Progress(data.test.length / batchSize);

        while(!data.test.finished())
        {
            //Get the next batch of testing data
            data.test.getBatch([fs, ls]);

            //Make some predictions
            auto res = testPlan.execute([
                features: buffer(fs),
                labels: buffer(ls)
            ]);

            testLoss += res[0].get!float[0] * batchSize;
            testAcc += computeAccuracy(ls, res[1].get!float);
            testNum += batchSize;

            float loss = testLoss / testNum;
            float acc = testAcc / testNum;

            testProgress.title = format(" Loss: %02.4f Acc: %.4f", loss, acc);
            testProgress.next();
        }

        writeln();
        writeln();
    }
}

/**
    Counts how many examples in a batch were classified correctly.

    Both ls and preds are flattened arrays containing 10 values per example: one-hot labels and predicted class
    probabilities, respectively. An example counts as correct when the largest prediction coincides with the label.
*/
float computeAccuracy(float[] ls, float[] preds)
{
    import std.algorithm : maxIndex;
    import std.range : chunks, zip;

    float correct = 0;

    foreach(p, t; zip(preds.chunks(10), ls.chunks(10)))
    {
        if(p.maxIndex == t.maxIndex)
        {
            correct++;
        }
    }

    return correct;
}
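//The unittest below is a small, illustrative sanity check for computeAccuracy (an addition to the original
//example, not part of the training code). It runs before main when the module is compiled with unit tests
//enabled, e.g. via dmd's -unittest flag.
unittest
{
    //Two examples with 10 classes each; zero every entry first, since float arrays default to NaN in D
    float[] labels = new float[20];
    float[] preds = new float[20];
    labels[] = 0;
    preds[] = 0;

    //Example 0: true class 3, predicted class 3 (correct)
    labels[3] = 1;
    preds[3] = 0.9f;

    //Example 1: true class 5, predicted class 7 (incorrect)
    labels[15] = 1;
    preds[17] = 0.8f;

    //Exactly one of the two examples should be counted as correct
    assert(computeAccuracy(labels, preds) == 1.0f);
}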