1 /**
2     Contains generic utilities for working with $(D Layer) objects.
3     
4     Authors: Henry Gouk
5 */
6 module dopt.nnet.layers;
7 
8 import dopt.core;
9 import dopt.nnet;
10 import dopt.nnet.layers.util;
11 import dopt.online;
12 
13 public
14 {
15     import dopt.nnet.layers.batchnorm;
16     import dopt.nnet.layers.conv;
17     import dopt.nnet.layers.datasource;
18     import dopt.nnet.layers.dense;
19     import dopt.nnet.layers.dropout;
20     import dopt.nnet.layers.maxpool;
21     import dopt.nnet.layers.relu;
22     import dopt.nnet.layers.softmax;
23 }
24 
/**
    Encapsulates the expressions and parameter information that defines a network layer.

    A layer records the other layers it depends on, the parameters it manages, and two output
    expressions: one intended for test time and one intended for train time.
*/
class Layer
{
    private
    {
        Layer[] mDeps;
        Parameter[] mParams;
        Operation mOutput;
        Operation mTrainOutput;
    }

    public
    {
        /**
            Constructs a new layer.

            The $(D deps) and $(D params) arrays are copied, so later changes to the caller's arrays
            are not seen by this layer.

            Params:
                deps = Other $(D Layer) objects that this layer depends on.
                outExpr = The output expression to use at test time.
                trainOutExpr = The output expression to use at train time.
                params = Any parameters managed by this layer.
        */
        this(Layer[] deps, Operation outExpr, Operation trainOutExpr, Parameter[] params)
        {
            this.mOutput = outExpr;
            this.mTrainOutput = trainOutExpr;
            this.mDeps = deps.dup;
            this.mParams = params.dup;
        }

        /**
            Returns:
                A fresh copy of the array of layers this layer depends on. Modifying the returned
                array does not alter this layer.
        */
        Layer[] deps()
        {
            auto depsCopy = this.mDeps.dup;

            return depsCopy;
        }

        /**
            Returns:
                The parameters managed by this layer.

            Unlike $(D deps), this returns the layer's own array rather than a copy, so assigning to
            elements of the result writes through to this layer's stored parameters.
            NOTE(review): confirm whether callers rely on this; otherwise consider returning a copy
            for consistency with $(D deps).
        */
        Parameter[] params()
        {
            return this.mParams;
        }

        /**
            Returns:
                The output expression to use at test time.
        */
        Operation output()
        {
            return this.mOutput;
        }

        /**
            Returns:
                The output expression to use at train time.
        */
        Operation trainOutput()
        {
            return this.mTrainOutput;
        }
    }
}
78 
/**
    Sorts layers, together with all of their transitive dependencies, into dependency order.

    The traversal is a depth-first walk over $(D Layer.deps), starting from each element of
    $(D ops) in turn.

    Params:
        ops = The layers to sort. Any layers they depend on, directly or indirectly, are also
              included in the result even if they are not present in this array.

    Returns:
        A new array containing each reachable layer exactly once. Every layer appears after all of
        the layers returned by its $(D deps) method.
*/
Layer[] topologicalSort(Layer[] ops)
{
    Layer[] sortedOps;

    //Layers that have already been appended to sortedOps. An associative array gives constant-time
    //membership tests; scanning sortedOps linearly on every visit made the sort quadratic.
    bool[Layer] visited;

    void toposort(Layer o)
    {
        if(o in visited)
        {
            return;
        }

        foreach(d; o.deps)
        {
            toposort(d);
        }

        //A layer is only emitted once every one of its dependencies has been emitted
        visited[o] = true;
        sortedOps ~= o;
    }

    foreach(o; ops)
    {
        toposort(o);
    }

    return sortedOps;
}