1 /**
2     Contains generic utilities for working with $(D Layer) objects.
3     
4     Authors: Henry Gouk
5 */
6 module dopt.nnet.layers;
7 
8 import dopt;
9 
10 public
11 {
12     import dopt.nnet.layers.batchnorm;
13     import dopt.nnet.layers.conv;
14     import dopt.nnet.layers.datasource;
15     import dopt.nnet.layers.dense;
16     import dopt.nnet.layers.dropout;
17     import dopt.nnet.layers.maxpool;
18     import dopt.nnet.layers.relu;
19     import dopt.nnet.layers.softmax;
20 }
21 
/**
    A single network layer: the expressions that compute its output, the parameters it manages, and the other
    layers it depends on.
*/
class Layer
{
    private Operation mOutput;
    private Operation mTrainOutput;
    private Layer[] mDeps;
    private Parameter[] mParams;

    /**
        Constructs a new layer.

        Both array arguments are copied, so later changes the caller makes to the elements of those arrays are not
        observed by this layer.

        Params:
            deps = Other $(D Layer) objects that this layer depends on.
            outExpr = The output expression to use at test time.
            trainOutExpr = The output expression to use at train time.
            params = Any parameters managed by this layer.
    */
    public this(Layer[] deps, Operation outExpr, Operation trainOutExpr, Parameter[] params)
    {
        this.mOutput = outExpr;
        this.mTrainOutput = trainOutExpr;
        this.mDeps = deps.dup;
        this.mParams = params.dup;
    }

    /**
        Returns:
            A freshly allocated copy of the array of layers this layer depends on.
    */
    public Layer[] deps()
    {
        auto depsCopy = mDeps.dup;

        return depsCopy;
    }

    /**
        Returns:
            A freshly allocated copy of the array of parameters managed by this layer.
    */
    public Parameter[] params()
    {
        auto paramsCopy = mParams.dup;

        return paramsCopy;
    }

    /**
        Returns:
            The output expression to use at test time.
    */
    public Operation output()
    {
        return mOutput;
    }

    /**
        Returns:
            The output expression to use at train time.
    */
    public Operation trainOutput()
    {
        return mTrainOutput;
    }
}
75 
/**
    Sorts a set of layers, together with every layer they transitively depend on, into dependency order.

    The layers are visited depth first in the order given by $(D ops). A layer is appended to the result only after
    all of the layers returned by its $(D deps) method have been appended, and each layer appears in the result at
    most once.

    The dependency graph must be acyclic: there is no cycle detection, so a cycle would cause unbounded recursion.

    Params:
        ops = The layers from which the traversal starts.

    Returns:
        Every layer reachable from $(D ops), each listed once and positioned after all of its dependencies.
*/
Layer[] topologicalSort(Layer[] ops)
{
    Layer[] sortedOps;

    //Tracks which layers have already been appended to sortedOps. Looking a layer up here takes constant time on
    //average, whereas scanning sortedOps on every visit makes the whole sort quadratic in the number of layers.
    bool[Layer] emitted;

    void toposort(Layer o)
    {
        if((o in emitted) !is null)
        {
            return;
        }

        foreach(d; o.deps)
        {
            toposort(d);
        }

        //Record the layer only once it has been appended, so the set always mirrors the contents of sortedOps.
        sortedOps ~= o;
        emitted[o] = true;
    }

    foreach(o; ops)
    {
        toposort(o);
    }

    return sortedOps;
}