/**
    Provides useful tools for constructing neural networks.

    Currently only directed acyclic graphs are supported.

    Authors: Henry Gouk
*/
module dopt.nnet.networks;

import std.algorithm;
import std.array;

import dopt.core;
import dopt.nnet;
import dopt.online;

/**
    Encapsulates the details of a network with a directed acyclic graph structure.

    This class does not provide facilities to actually train the network; that can be accomplished with the
    $(D dopt.online) package.
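
    A rough usage sketch is given below. The $(D dataSource), $(D dense), $(D relu), and $(D softmax) layer
    constructors are assumed to be provided by $(D dopt.nnet), and their exact signatures may differ from what is
    shown here.

    Example:
    ---
    //Variables representing a batch of 100 flattened feature vectors and one-hot labels
    auto features = float32([100, 784]);
    auto labels = float32([100, 10]);

    //Build a small multilayer perceptron
    auto preds = dataSource(features)
                .dense(512)
                .relu()
                .dense(10)
                .softmax();

    //Wrap the resulting graph in a DAGNetwork
    auto network = new DAGNetwork([features], [preds]);
    ---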
*/
class DAGNetwork
{
    public
    {
        /**
            Construct a DAGNetwork with the given inputs and outputs.

            Params:
                inputs = The inputs to the network. This will usually contain a single $(D Operation) representing a
                batch of feature vectors.
                outputs = The outputs (i.e., predictions) of the network.
        */
        this(Operation[] inputs, Layer[] outputs)
        {
            mInputs = inputs.dup;
            mOutputs = outputs.map!(x => x.output).array();
            mTrainOutputs = outputs.map!(x => x.trainOutput).array();

            //Gather the parameter information from every layer in the graph
            auto layers = topologicalSort(outputs);
            auto paramsinfo = layers.map!(x => x.params).joiner().array();
            mParams = paramsinfo.map!(x => x.symbol).array();

            //Sum the per-parameter loss terms and collect any projection functions
            foreach(p; paramsinfo)
            {
                if(p.loss !is null)
                {
                    if(mParameterLoss is null)
                    {
                        mParameterLoss = p.loss;
                    }
                    else
                    {
                        mParameterLoss = mParameterLoss + p.loss;
                    }
                }

                if(p.projection !is null)
                {
                    mParameterProj[p.symbol] = p.projection;
                }
            }

            if(mParameterLoss is null)
            {
                //Prevents an annoying-to-debug segfault in user code when there are no param loss terms
                mParameterLoss = float32([], [0.0f]);
            }
        }

        /**
            The inputs provided when the $(D DAGNetwork) was constructed.
        */
        Operation[] inputs()
        {
            return mInputs.dup;
        }

        /**
            The $(D Operation) objects produced by the output layers provided during construction.
        */
        Operation[] outputs()
        {
            return mOutputs.dup;
        }

        /**
            Separate $(D Operation) objects produced by the $(D trainOutput) of each output layer provided during
            construction.

            These should be used when creating the network optimiser.
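
            For instance, one might construct the loss used for training as follows, reusing the $(D network) and
            $(D labels) from the sketch in the class documentation. The $(D crossEntropy) loss function is assumed
            to be available from $(D dopt.nnet).

            Example:
            ---
            //Compare the training-mode predictions with the ground truth labels
            auto trainLoss = crossEntropy(network.trainOutputs[0], labels);
            ---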
        */
        Operation[] trainOutputs()
        {
            return mTrainOutputs.dup;
        }

        /**
            The sum of all the parameter loss terms.

            This will include all the L2 weight decay terms.
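
            When constructing the training objective, this term can be added to the task loss. A sketch, continuing
            the $(D trainLoss) example from $(D trainOutputs):

            Example:
            ---
            //Fold the regularisation terms into the overall training objective
            auto objective = trainLoss + network.paramLoss;
            ---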
        */
        Operation paramLoss()
        {
            return mParameterLoss;
        }

        /**
            An associative array of projection operations that should be applied to parameters during optimisation.
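
            This is intended to be passed to one of the $(D dopt.online) updaters. A sketch, assuming an $(D sgd)
            updater that accepts the projections as its third argument (the exact signature may differ):

            Example:
            ---
            //The updater applies each projection to its parameter after every update
            auto updater = sgd([objective], network.params, network.paramProj);
            ---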
        */
        Projection[Operation] paramProj()
        {
            return mParameterProj;
        }

        /**
            An array of all the $(D Operation) nodes in the graph representing network parameters.
        */
        Operation[] params()
        {
            return mParams.dup;
        }
    }

    private
    {
        Operation[] mInputs;
        Operation[] mOutputs;
        Operation[] mTrainOutputs;
        Operation[] mParams;
        Operation mParameterLoss;
        Projection[Operation] mParameterProj;
    }
}