1 module dopt.nnet.data.imagetransformer;
2 
3 import dopt.nnet.data;
4 
/**
    A BatchIterator that wraps another BatchIterator and randomly perturbs the
    images found in the first element of every batch it produces.

    Two kinds of perturbation are supported, each applied independently per image:

    $(UL
        $(LI Jitter: the image is padded on every side with a mirrored copy of
             its own border and a window of the original size is then cropped
             from a random position inside the padded image.)
        $(LI Flips: each image is mirrored along the X and/or Y axis with
             probability 0.5.)
    )

    The first entry of `shape()` is indexed as `[planes, rows, columns]`, and each
    image is addressed as `plane * rows * columns + row * columns + column`.
*/
class ImageTransformer : BatchIterator
{
    public
    {
        /**
            Constructs a transformer around an existing iterator.

            Params:
                imageDataset = The iterator whose batches will be perturbed.
                jitterX = Number of mirrored padding columns added on each side.
                jitterY = Number of mirrored padding rows added on each side.
                flipX = If true, each image is mirrored left/right with probability 0.5.
                flipY = If true, each image is mirrored top/bottom with probability 0.5.
                folds = Accepted for interface compatibility; it is not used by the code in this class.

            Note: the mirrored border is read from the image itself, so a jitter larger
            than the corresponding image dimension makes the padding step read cells of
            the scratch buffer that have not been written for the current image.
        */
        this(BatchIterator imageDataset, size_t jitterX, size_t jitterY, bool flipX, bool flipY, size_t[] folds = [0])
        {
            mDataset = imageDataset;
            mJitterX = jitterX;
            mJitterY = jitterY;
            mFlipX = flipX;
            mFlipY = flipY;

            //Scratch buffer big enough for a single padded image; reused for every image in every batch.
            auto shape = this.shape()[0];
            mPadded = new float[shape[0] * (shape[1] + 2 * jitterY) * (shape[2] + 2 * jitterX)];
        }

        /// Forwards to the wrapped iterator.
        size_t[][] shape()
        {
            return mDataset.shape();
        }

        /// Forwards to the wrapped iterator.
        size_t[] volume()
        {
            return mDataset.volume();
        }

        /// Forwards to the wrapped iterator.
        size_t length()
        {
            return mDataset.length;
        }

        /// Forwards to the wrapped iterator.
        bool finished()
        {
            return mDataset.finished();
        }

        /// Forwards to the wrapped iterator.
        void restart()
        {
            mDataset.restart();
        }

        /**
            Fetches a batch from the wrapped iterator and perturbs, in place, every
            image stored in `batchData[0]`. The other elements of `batchData` are
            left exactly as the wrapped iterator produced them.

            Params:
                batchData = Buffers to be filled by the wrapped iterator. `batchData[0]` is
                            split into consecutive images of `volume()[0]` floats each.
        */
        void getBatch(float[][] batchData)
        {
            import std.algorithm : reverse;
            import std.random : uniform;
            import std.range : chunks, drop, stride, take;

            mDataset.getBatch(batchData);

            auto shape = this.shape()[0];
            immutable size_t planes = shape[0];
            immutable size_t rows = shape[1];
            immutable size_t cols = shape[2];
            immutable size_t planeSize = rows * cols;
            immutable size_t imageVolume = volume()[0];

            //Dimensions of a single padded plane
            immutable size_t pw = cols + 2 * mJitterX;
            immutable size_t ph = rows + 2 * mJitterY;

            foreach(img; batchData[0].chunks(imageVolume))
            {
                if(mJitterX != 0 || mJitterY != 0)
                {
                    //Pad the image. The extra content around the border is filled with reflected image.
                    foreach(c; 0 .. planes)
                    {
                        immutable size_t planeStart = c * ph * pw;

                        foreach(y; 0 .. rows)
                        {
                            immutable size_t rowStart = planeStart + (y + mJitterY) * pw;
                            immutable size_t src = c * planeSize + y * cols;

                            //Copy the row into the centre of the padded plane
                            mPadded[rowStart + mJitterX .. rowStart + mJitterX + cols] = img[src .. src + cols];

                            if(mJitterX != 0)
                            {
                                //Left border: mirror of the first mJitterX pixels of the row
                                mPadded[rowStart .. rowStart + mJitterX] =
                                    mPadded[rowStart + mJitterX .. rowStart + 2 * mJitterX];
                                mPadded[rowStart .. rowStart + mJitterX].reverse();

                                //Right border: mirror of the last mJitterX pixels of the row
                                immutable size_t tail = rowStart + cols;
                                mPadded[tail + mJitterX .. tail + 2 * mJitterX] = mPadded[tail .. tail + mJitterX];
                                mPadded[tail + mJitterX .. tail + 2 * mJitterX].reverse();
                            }
                        }

                        //The rows copied below already contain their left/right borders, so the
                        //corners of the padded plane get filled as a side effect.
                        foreach(y; 0 .. mJitterY)
                        {
                            //Pad the top rows
                            mPadded[planeStart + y * pw .. planeStart + (y + 1) * pw] =
                                mPadded[planeStart + (2 * mJitterY - y - 1) * pw .. planeStart + (2 * mJitterY - y) * pw];

                            //Pad the bottom rows
                            mPadded[planeStart + (ph - y - 1) * pw .. planeStart + (ph - y) * pw] =
                                mPadded[planeStart + (ph - 2 * mJitterY + y) * pw
                                    .. planeStart + (ph - 2 * mJitterY + y + 1) * pw];
                        }
                    }

                    //Choose the top-left corner of the crop. The interval is closed so that:
                    // - an axis with zero jitter yields offset 0 instead of asking uniform() for
                    //   the empty interval [0, 0), which is rejected as an invalid interval;
                    // - the largest offset, 2 * jitter, is reachable, making the shift symmetric.
                    immutable size_t xOff = uniform!"[]"(size_t(0), 2 * mJitterX);
                    immutable size_t yOff = uniform!"[]"(size_t(0), 2 * mJitterY);

                    //Crop the padded image
                    foreach(c; 0 .. planes)
                    {
                        foreach(y; 0 .. rows)
                        {
                            immutable size_t dst = c * planeSize + y * cols;
                            immutable size_t src = c * ph * pw + (y + yOff) * pw + xOff;

                            img[dst .. dst + cols] = mPadded[src .. src + cols];
                        }
                    }
                }

                if(mFlipX && uniform(0.0f, 1.0f) < 0.5f)
                {
                    //Reverse every row of every plane in place
                    foreach(row; img.chunks(cols))
                    {
                        row.reverse();
                    }
                }

                if(mFlipY && uniform(0.0f, 1.0f) < 0.5f)
                {
                    foreach(c; 0 .. planes)
                    {
                        foreach(x; 0 .. cols)
                        {
                            //View a single column of plane c as a range and reverse it in place
                            img.drop(c * planeSize + x)
                                .stride(cols)
                                .take(rows)
                                .reverse();
                        }
                    }
                }
            }
        }
    }

    protected
    {
        BatchIterator mDataset; //Iterator that supplies the unmodified batches
        size_t mJitterX;        //Padding columns added on each side
        size_t mJitterY;        //Padding rows added on each side
        bool mFlipX;            //Whether random left/right mirroring is enabled
        bool mFlipY;            //Whether random top/bottom mirroring is enabled
        float[] mPadded;        //Scratch buffer holding one padded image
    }
}