1 module dopt.nnet.data;
2 
3 public
4 {
5     import dopt.nnet.data.cifar;
6     import dopt.nnet.data.imagetransformer;
7     import dopt.nnet.data.mnist;
8     import dopt.nnet.data.sins;
9     import dopt.nnet.data.svhn;
10 }
11 
12 import std.exception : enforce;
13 
/**
    An iterator that fills caller-supplied flat $(D float) buffers with successive batches of data.

    A batch is made up of one or more arrays. For example, $(D SupervisedBatchIterator) fills exactly two
    (features and labels). The exact number and layout of the arrays is defined by each implementation.
*/
interface BatchIterator
{
    /// Returns the shape of a single example, one entry per array filled by $(D getBatch).
    size_t[][] shape();

    /**
        Returns the number of elements in a single example, one entry per array filled by $(D getBatch).

        In $(D SupervisedBatchIterator) this is the product of the corresponding $(D shape) entry.
    */
    size_t[] volume();

    /// Returns the number of examples covered by one pass over the data.
    size_t length();

    /**
        Fills the arrays in $(D batchData) with the next batch and advances the iterator.

        The expected count and sizes of the arrays are implementation-defined. $(D SupervisedBatchIterator)
        expects two arrays whose lengths are multiples of the corresponding $(D volume) entries.
    */
    void getBatch(float[][] batchData);

    /// Returns $(D true) once the current pass over the data has been exhausted.
    bool finished();

    /// Starts a new pass over the data (implementations may also reorder the data here).
    void restart();
}
23 
/**
    A $(D BatchIterator) specialisation for supervised learning tasks.

    Stores one feature vector and one label vector per example. Each call to $(D getBatch) fills two flat buffers:
    the first with the feature vectors of the next batch of examples laid out back to back, the second with the
    matching label vectors. Examples are visited in the order given by an internal index permutation. This
    permutation starts as the identity and is randomly reshuffled by $(D restart) when shuffling is enabled.
*/
class SupervisedBatchIterator : BatchIterator
{
    public
    {
        /**
            Constructs an iterator over the given examples.

            Params:
                features = One flattened feature vector per example. Only the outer array is copied; the inner
                    arrays are shared with the caller.
                labels = One flattened label vector per example, parallel to $(D features). Copied in the same
                    shallow manner as $(D features).
                shape = The shape of a single example for each buffer filled by $(D getBatch): $(D shape[0])
                    for the features and $(D shape[1]) for the labels. An empty shape has a volume of 1.
                shuffle = Whether $(D restart) randomly permutes the order in which examples are visited.

            Throws:
                $(D Exception) if $(D features.length != labels.length), or if $(D shape) is empty.
        */
        this(float[][] features, float[][] labels, size_t[][] shape, bool shuffle)
        {
            import std.algorithm : fold, map;
            import std.array : array;
            import std.range : iota;

            enforce(features.length == labels.length, "features.length != labels.length");
            enforce(shape.length != 0, "shape.length == 0");

            mFeatures = features.dup;
            mLabels = labels.dup;
            mShape = shape.map!(x => x.dup).array;
            mShuffle = shuffle;

            //Seed the product with 1 so that an empty (scalar) shape has a volume of 1 instead of
            //making fold throw on an empty range
            mVolumes = shape
                      .map!(y => y.fold!((a, b) => a * b)(size_t(1)))
                      .array();

            //Start with the identity ordering; restart() permutes it when shuffling is enabled
            mIndices = iota(0, mFeatures.length).array();
        }

        /// Returns the per-example shapes given at construction (features first, then labels).
        size_t[][] shape()
        {
            return mShape;
        }

        /// Returns the number of elements in one example for each entry of $(D shape).
        size_t[] volume()
        {
            return mVolumes;
        }

        /// Returns the number of examples.
        size_t length()
        {
            return mFeatures.length;
        }

        /// Returns $(D true) once $(D getBatch) has advanced past the last example of the current pass.
        bool finished()
        {
            return mFront >= length;
        }

        /**
            Rewinds to the first example.

            If shuffling was enabled at construction, the visiting order is first randomly permuted. A newly
            constructed iterator visits examples in their original order until $(D restart) is called.
        */
        void restart()
        {
            import std.random : randomShuffle;

            if(mShuffle)
            {
                mIndices.randomShuffle();
            }

            mFront = 0;
        }

        /**
            Fills $(D batchData[0]) with features and $(D batchData[1]) with labels for the next batch of
            examples, then advances the iterator by the batch size.

            The batch size is inferred from the buffers as $(D batchData[0].length / volume[0]). If fewer
            examples remain than fit in the buffers, the trailing part of each buffer is left zeroed.

            Throws:
                $(D Exception) if $(D batchData) does not hold exactly two arrays, if fewer than two shapes were
                supplied, if either per-example volume is zero, or if the buffer lengths are not consistent
                multiples of the per-example volumes.
        */
        void getBatch(float[][] batchData)
        {
            import std.algorithm : map, joiner, copy;
            import std.range : drop, take;

            auto vols = volume;

            //Check the size of the arguments
            enforce(batchData.length == 2, "SupervisedBatchIterator.getBatch expects two arrays to fill.");
            enforce(vols.length >= 2, "SupervisedBatchIterator.getBatch requires shapes for both features and labels.");
            //A zero volume would otherwise cause an integer division by zero below
            enforce(vols[0] != 0 && vols[1] != 0, "SupervisedBatchIterator.getBatch requires non-zero volumes.");
            enforce(batchData[0].length % vols[0] == 0, "batchData[0].length % volume[0] != 0");
            enforce(batchData[1].length % vols[1] == 0, "batchData[1].length % volume[1] != 0");
            enforce(batchData[0].length / vols[0] == batchData[1].length / vols[1],
                "batchData[0].length / volume[0] != batchData[1].length / volume[1]");

            immutable size_t batchSize = batchData[0].length / vols[0];

            //Zero the buffers first so that a partial final batch is padded with zeros
            batchData[0][] = 0;
            batchData[1][] = 0;

            //Select the indices of this batch once; the range is copied by each map below, so it can be reused
            auto batchIndices = mIndices.drop(mFront).take(batchSize);

            batchIndices.map!(x => mFeatures[x])
                        .joiner()
                        .copy(batchData[0]);

            batchIndices.map!(x => mLabels[x])
                        .joiner()
                        .copy(batchData[1]);

            mFront += batchSize;
        }
    }

    protected
    {
        //Shallow copy of the per-example feature vectors passed to the constructor
        float[][] mFeatures;
        //Shallow copy of the per-example label vectors, parallel to mFeatures
        float[][] mLabels;
        //Per-example shape of each buffer (features, labels)
        size_t[][] mShape;
        //Product of each entry of mShape
        size_t[] mVolumes;
        //Whether restart() permutes mIndices
        bool mShuffle;
        //Position in mIndices of the next example to hand out
        size_t mFront;
        //Order in which examples are visited; a permutation of 0 .. length - 1
        size_t[] mIndices;
    }
}