/**
Contains some utilities for constructing graphs for common loss functions.

Authors: Henry Gouk
*/
module dopt.nnet.losses;

import dopt.core;

/**
Creates a cross entropy loss term suitable for multiclass classification problems.

It is assumed that the two input operations are rank-2 tensors, where the first dimension is an index into the
batch, and the second dimension is an index into the label probabilities.

The loss is computed as -sum(groundTruth * log(hypothesis + eps)) divided by the size of the first (batch)
dimension of hypothesis.

Params:
    hypothesis = The predictions made by a model.
    groundTruth = The true values for the labels, as provided by the training dataset.
    eps = A small constant added to hypothesis before taking the logarithm, to avoid taking the logarithm of zero.
        Defaults to 1e-6f.

Returns:
    An $(D Operation) representing the mean cross entropy loss.
*/
Operation crossEntropy(Operation hypothesis, Operation groundTruth, float eps = 1e-6f)
{
    // Summing over the whole tensor and scaling by -1/batchSize yields the mean (over the batch)
    // of the per-example cross entropy.
    return sum(groundTruth * log(hypothesis + eps)) * (-1.0f / hypothesis.shape[0]);
}

/**
Creates a squared error loss term suitable for regression (and multi-target regression) problems.

The squared differences between hypothesis and groundTruth are summed over all elements, and the result is divided
by the size of the first dimension of hypothesis (assumed to be the batch dimension). Note that this is not divided
by the number of targets, so it is not a per-element mean.

Params:
    hypothesis = The predictions made by the model.
    groundTruth = The true values for the targets, as provided by the training dataset.

Returns:
    An $(D Operation) representing the squared error, summed over targets and averaged over the batch.
*/
Operation squaredError(Operation hypothesis, Operation groundTruth)
{
    auto diff = hypothesis - groundTruth;

    // Scale by 1/batchSize so the loss magnitude does not depend on the batch size.
    return sum(diff * diff) * (1.0f / hypothesis.shape[0]);
}