module dopt.nnet.data;

public
{
    import dopt.nnet.data.cifar;
    import dopt.nnet.data.imagetransformer;
    import dopt.nnet.data.mnist;
    import dopt.nnet.data.sins;
    import dopt.nnet.data.svhn;
}

import std.exception : enforce;

/**
    Iterates over a dataset in minibatches.

    Implementations produce one or more flat $(D float[]) components per batch
    (e.g. features and labels) and track their own position in the dataset.
*/
interface BatchIterator
{
    ///Returns the shape of each per-example component (e.g. [featureShape, labelShape]).
    size_t[][] shape();

    ///Returns the number of elements per example for each component (product of that component's dims).
    size_t[] volume();

    ///Returns the total number of examples in the dataset.
    size_t length();

    ///Fills the caller-provided buffers with the next minibatch.
    void getBatch(float[][] batchData);

    ///Returns true once every example has been consumed since the last restart.
    bool finished();

    ///Rewinds iteration to the start of the dataset (reshuffling, if enabled).
    void restart();
}

/**
    A $(D BatchIterator) specialisation for supervised learning tasks.

    Holds parallel arrays of feature and label vectors, and optionally shuffles
    the iteration order on each $(D restart).
*/
class SupervisedBatchIterator : BatchIterator
{
    public
    {
        /**
            Constructs a supervised batch iterator.

            Params:
                features = one flat feature vector per example.
                labels   = one flat label vector per example; must be parallel to $(D features).
                shape    = exactly two component shapes: [featureShape, labelShape].
                shuffle  = whether $(D restart) should randomise the iteration order.

            Throws:
                Exception if $(D features) and $(D labels) differ in length, or
                if $(D shape) does not contain exactly two components.
        */
        this(float[][] features, float[][] labels, size_t[][] shape, bool shuffle)
        {
            import std.algorithm : fold, map;
            import std.array : array;
            import std.range : iota;

            enforce(features.length == labels.length, "features.length != labels.length");

            //getBatch fills exactly two buffers (features and labels), so a
            //supervised iterator needs exactly two component shapes. Enforcing
            //this here fails fast instead of crashing later in getBatch.
            enforce(shape.length == 2, "shape.length != 2");

            mFeatures = features.dup;
            mLabels = labels.dup;
            mShape = shape.map!(x => x.dup).array;
            mShuffle = shuffle;

            //Seed the fold with 1 so an empty (scalar) component shape yields
            //a volume of 1 instead of throwing on an empty range.
            mVolumes = shape
                      .map!(y => y.fold!((a, b) => a * b)(size_t(1)))
                      .array();

            mIndices = iota(0, mFeatures.length).array();
        }

        ///Returns the two component shapes: [featureShape, labelShape].
        size_t[][] shape()
        {
            return mShape;
        }

        ///Returns the per-example element counts: [featureVolume, labelVolume].
        size_t[] volume()
        {
            return mVolumes;
        }

        ///Returns the number of examples in the dataset.
        size_t length()
        {
            return mFeatures.length;
        }

        ///Returns true when the iteration position has reached or passed the end of the dataset.
        bool finished()
        {
            return mFront >= length;
        }

        ///Resets the iteration position to zero, reshuffling the index order first when shuffling is enabled.
        void restart()
        {
            import std.random : randomShuffle;

            if(mShuffle)
            {
                mIndices.randomShuffle();
            }

            mFront = 0;
        }

        /**
            Copies the next minibatch into the caller's buffers.

            Params:
                batchData = exactly two buffers: batchData[0] receives features,
                            batchData[1] receives labels. Each buffer's length must
                            be an exact multiple of the corresponding component
                            volume, and both must imply the same batch size.

            The batch size is inferred from the buffer lengths. Both buffers are
            zero-filled before copying, so if fewer than batchSize examples remain
            the trailing entries are left as zeros (a zero-padded partial batch).
        */
        void getBatch(float[][] batchData)
        {
            import std.algorithm : map, joiner, copy;
            import std.range : drop, take;

            //Check the size of the arguments
            enforce(batchData.length == 2, "SupervisedBatchIterator.getBatch expects two arrays to fill.");
            enforce(batchData[0].length % volume[0] == 0, "batchData[0].length % volume[0] != 0");
            enforce(batchData[1].length % volume[1] == 0, "batchData[1].length % volume[1] != 0");
            enforce(batchData[0].length / volume[0] == batchData[1].length / volume[1],
                "batchData[0].length / volume[0] != batchData[1].length / volume[1]");

            size_t batchSize = batchData[0].length / volume[0];

            //Zero-fill so a partial final batch is zero-padded rather than
            //containing stale data.
            batchData[0][] = 0;
            batchData[1][] = 0;

            mIndices.drop(mFront)
                    .take(batchSize)
                    .map!(x => mFeatures[x])
                    .joiner()
                    .copy(batchData[0]);

            mIndices.drop(mFront)
                    .take(batchSize)
                    .map!(x => mLabels[x])
                    .joiner()
                    .copy(batchData[1]);

            mFront += batchSize;
        }
    }

    protected
    {
        float[][] mFeatures; //one flat feature vector per example
        float[][] mLabels;   //one flat label vector per example, parallel to mFeatures
        size_t[][] mShape;   //[featureShape, labelShape]
        size_t[] mVolumes;   //[featureVolume, labelVolume]
        bool mShuffle;       //reshuffle mIndices on restart?
        size_t mFront;       //index (into mIndices) of the next example to serve
        size_t[] mIndices;   //iteration order over the examples
    }
}