module dopt.nnet.data.imagetransformer;

import dopt.nnet.data;

/**
 * A BatchIterator decorator that applies random augmentation to the images produced by another BatchIterator.
 *
 * Only the first tensor of each batch (batchData[0]) is modified; any other tensors are passed through untouched.
 * Each image in that tensor is independently:
 *  - randomly translated by up to jitterX columns / jitterY rows. The uncovered border is filled by mirroring the
 *    image (edge pixels are repeated).
 *  - mirrored left-to-right with probability 0.5, if flipX is set.
 *  - mirrored top-to-bottom with probability 0.5, if flipY is set.
 *
 * The code indexes an image as img[c * rows * cols + y * cols + x] with [planes, rows, cols] = shape()[0], i.e. a
 * planar layout. NOTE(review): it also assumes volume()[0] equals planes * rows * cols; confirm for wrapped datasets.
 */
class ImageTransformer : BatchIterator
{
    public
    {
        /**
         * Params:
         *  imageDataset = the iterator whose batches are augmented.
         *  jitterX = maximum horizontal translation in pixels; must not exceed the image width.
         *  jitterY = maximum vertical translation in pixels; must not exceed the image height.
         *  flipX = whether images are randomly mirrored left-to-right.
         *  flipY = whether images are randomly mirrored top-to-bottom.
         *  folds = currently unused; accepted for interface compatibility.
         *
         * Throws: Exception if jitterX exceeds the image width or jitterY exceeds the image height.
         */
        this(BatchIterator imageDataset, size_t jitterX, size_t jitterY, bool flipX, bool flipY, size_t[] folds = [0])
        {
            import std.exception : enforce;

            mDataset = imageDataset;
            mJitterX = jitterX;
            mJitterY = jitterY;
            mFlipX = flipX;
            mFlipY = flipY;

            auto dims = this.shape()[0];

            //Reflection padding mirrors at most one full image width/height into the border. Larger jitter
            //values would make the mirrored source slices cover border cells that contain no image data.
            enforce(jitterX <= dims[2], "jitterX must not exceed the image width");
            enforce(jitterY <= dims[1], "jitterY must not exceed the image height");

            mPadded = new float[dims[0] * (dims[1] + 2 * jitterY) * (dims[2] + 2 * jitterX)];
        }

        /// Forwards to the wrapped iterator.
        size_t[][] shape()
        {
            return mDataset.shape();
        }

        /// Forwards to the wrapped iterator.
        size_t[] volume()
        {
            return mDataset.volume();
        }

        /// Forwards to the wrapped iterator.
        size_t length()
        {
            return mDataset.length;
        }

        /// Forwards to the wrapped iterator.
        bool finished()
        {
            return mDataset.finished();
        }

        /// Forwards to the wrapped iterator.
        void restart()
        {
            mDataset.restart();
        }

        /**
         * Fills batchData from the wrapped iterator, then augments every image in batchData[0] in place.
         */
        void getBatch(float[][] batchData)
        {
            import std.random : uniform;
            import std.range : chunks;

            mDataset.getBatch(batchData);

            auto dims = this.shape()[0];

            foreach(img; batchData[0].chunks(volume[0]))
            {
                if(mJitterX != 0 || mJitterY != 0)
                {
                    padImage(img, dims);
                    randomCrop(img, dims);
                }

                if(mFlipX && uniform(0.0f, 1.0f) < 0.5f)
                {
                    flipHorizontal(img, dims);
                }

                if(mFlipY && uniform(0.0f, 1.0f) < 0.5f)
                {
                    flipVertical(img, dims);
                }
            }
        }
    }

    protected
    {
        BatchIterator mDataset; /// The iterator being augmented.
        size_t mJitterX;        /// Maximum horizontal translation in pixels.
        size_t mJitterY;        /// Maximum vertical translation in pixels.
        bool mFlipX;            /// Whether random left-to-right mirroring is enabled.
        bool mFlipY;            /// Whether random top-to-bottom mirroring is enabled.
        float[] mPadded;        /// Scratch buffer holding one image plus its mirrored border.
    }

    private
    {
        /// Copies img into the centre of mPadded and fills the surrounding border by mirroring the image content.
        void padImage(const(float)[] img, size_t[] dims)
        {
            import std.algorithm : reverse;

            immutable size_t planes = dims[0];
            immutable size_t height = dims[1];
            immutable size_t width = dims[2];
            immutable size_t pw = width + 2 * mJitterX;
            immutable size_t ph = height + 2 * mJitterY;

            foreach(c; 0 .. planes)
            {
                immutable size_t plane = c * ph * pw;

                foreach(y; 0 .. height)
                {
                    immutable size_t row = plane + (y + mJitterY) * pw;
                    immutable size_t src = c * height * width + y * width;

                    //Copy the image row into the middle of the padded row
                    mPadded[row + mJitterX .. row + mJitterX + width] = img[src .. src + width];

                    if(mJitterX != 0)
                    {
                        //Left border: the first mJitterX pixels of the row, mirrored
                        mPadded[row .. row + mJitterX] = mPadded[row + mJitterX .. row + 2 * mJitterX];
                        mPadded[row .. row + mJitterX].reverse();

                        //Right border: the last mJitterX pixels of the row, mirrored
                        immutable size_t end = row + mJitterX + width;
                        mPadded[end .. end + mJitterX] = mPadded[end - mJitterX .. end];
                        mPadded[end .. end + mJitterX].reverse();
                    }
                }

                //Top and bottom borders mirror whole padded rows; copying the left/right borders too fills the corners
                foreach(y; 0 .. mJitterY)
                {
                    immutable size_t top = plane + y * pw;
                    immutable size_t topSrc = plane + (2 * mJitterY - y - 1) * pw;
                    mPadded[top .. top + pw] = mPadded[topSrc .. topSrc + pw];

                    immutable size_t bottom = plane + (ph - y - 1) * pw;
                    immutable size_t bottomSrc = plane + (ph - 2 * mJitterY + y) * pw;
                    mPadded[bottom .. bottom + pw] = mPadded[bottomSrc .. bottomSrc + pw];
                }
            }
        }

        /// Overwrites img with a randomly positioned, image-sized window of mPadded.
        void randomCrop(float[] img, size_t[] dims)
        {
            import std.random : uniform;

            immutable size_t planes = dims[0];
            immutable size_t height = dims[1];
            immutable size_t width = dims[2];
            immutable size_t pw = width + 2 * mJitterX;
            immutable size_t ph = height + 2 * mJitterY;

            //Offsets come from the closed interval [0, 2 * jitter]. An offset equal to jitter reproduces the
            //original image, and both extreme shifts are reachable, so the translation is symmetric. The closed
            //interval also allows zero jitter along one axis; the half-open uniform(0, 0) would throw.
            immutable size_t xOff = uniform!"[]"(0, 2 * mJitterX);
            immutable size_t yOff = uniform!"[]"(0, 2 * mJitterY);

            foreach(c; 0 .. planes)
            {
                foreach(y; 0 .. height)
                {
                    immutable size_t dst = c * height * width + y * width;
                    immutable size_t src = c * ph * pw + (y + yOff) * pw + xOff;
                    img[dst .. dst + width] = mPadded[src .. src + width];
                }
            }
        }

        /// Mirrors every row of img left-to-right, in place.
        void flipHorizontal(float[] img, size_t[] dims)
        {
            import std.algorithm : reverse;
            import std.range : chunks;

            foreach(row; img.chunks(dims[2]))
            {
                row.reverse();
            }
        }

        /// Mirrors every column of each plane of img top-to-bottom, in place.
        void flipVertical(float[] img, size_t[] dims)
        {
            import std.algorithm : reverse;
            import std.range : drop, stride, take;

            foreach(c; 0 .. dims[0])
            {
                foreach(x; 0 .. dims[2])
                {
                    //Walk down column x of plane c and reverse it
                    img.drop(c * dims[1] * dims[2] + x)
                       .stride(dims[2])
                       .take(dims[1])
                       .reverse();
                }
            }
        }
    }
}