/**
    Contains some utilities for constructing graphs for common loss functions.

    Authors: Henry Gouk
*/
module dopt.nnet.losses;

import dopt.core;
9 
/**
    Creates a cross entropy loss term suitable for multiclass classification problems.

    It is assumed that the two input operations are rank-2 tensors, where the first dimension is an index into the
    batch, and the second dimension indexes the label probabilities.

    Params:
        hypothesis = The predictions made by the model.
        groundTruth = The true values for the labels, as provided by the training dataset.

    Returns:
        An $(D Operation) representing the mean cross entropy loss.
*/
Operation crossEntropy(Operation hypothesis, Operation groundTruth)
{
    //The small constant avoids taking the log of zero; the sum is scaled to give the mean loss over the batch.
    return sum(groundTruth * log(hypothesis + 1e-6f)) * (-1.0f / hypothesis.shape[0]);
}
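
//Illustrative sketch of how crossEntropy composes with the rest of a graph. It assumes dopt.core provides a
//float32 function for creating variable Operations with a given shape; treat that constructor, and the shapes
//chosen here, as assumptions made for the sake of the example rather than a definitive reference.
unittest
{
    //A batch of 4 examples, each with a distribution over 3 classes. Both inputs are rank-2 tensors with the
    //batch index first, as crossEntropy expects.
    auto hypothesis = float32([4, 3]);
    auto groundTruth = float32([4, 3]);

    //The loss is just another Operation, so it can be combined, differentiated, and evaluated like any other
    //node in the graph. This example only constructs the graph.
    auto loss = crossEntropy(hypothesis, groundTruth);
}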
27 
/**
    Creates a squared error loss term suitable for regression (and multi-target regression) problems.

    Params:
        hypothesis = The predictions made by the model.
        groundTruth = The true values for the targets, as provided by the training dataset.

    Returns:
        An $(D Operation) representing the squared error, summed over the target dimensions and averaged over
        the batch.
*/
Operation squaredError(Operation hypothesis, Operation groundTruth)
{
    auto diff = hypothesis - groundTruth;

    //Sum the squared differences over all elements, then average over the batch.
    return sum(diff * diff) * (1.0f / hypothesis.shape[0]);
}
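
//As above, an illustrative sketch for squaredError. The float32 variable constructor is assumed from dopt.core,
//and the shapes are chosen only for the example.
unittest
{
    //A batch of 8 examples, each with 2 regression targets.
    auto hypothesis = float32([8, 2]);
    auto groundTruth = float32([8, 2]);

    //Produces an Operation holding the squared error summed over targets and averaged over the batch.
    auto loss = squaredError(hypothesis, groundTruth);
}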