/**
    Contains an implementation of the regularisation techniques presented in Gouk et al. (2018).

    Gouk, H., Frank, E., Pfahringer, B., & Cree, M. (2018). Regularisation of Neural Networks by Enforcing Lipschitz
    Continuity. arXiv preprint arXiv:1804.04368.

    Authors: Henry Gouk
*/
module dopt.nnet.lipschitz;

import std.exception : enforce;

import dopt.core;
import dopt.online;

/**
    Returns a projection function that can be used to constrain the operator norm of a matrix.

    The operator norm induced by the vector p-norm is used.

    Params:
        maxnorm = A scalar value indicating the maximum allowable operator norm.
        p = The vector p-norm that will induce the operator norm.

    Returns:
        A projection function that can be used with the online optimisation algorithms.
*/
Projection projMatrix(Operation maxnorm, float p = 2.0f)
{
    Operation proj(Operation param)
    {
        auto norm = matrixNorm(param, p);

        return maxNorm(param, norm, maxnorm);
    }

    return &proj;
}

/**
    Computes the operator norm of a matrix induced by the vector p-norm.

    Params:
        param = The matrix whose norm will be computed.
        p = The vector p-norm that induces the operator norm. Must be 1, 2, or float.infinity.
        n = The number of power iterations used to approximate the norm when p = 2.

    Returns:
        An Operation that evaluates to the operator norm of param. The result is exact for p = 1 and
        p = float.infinity, and an approximation for p = 2.
*/
Operation matrixNorm(Operation param, float p, size_t n = 2)
{
    enforce(param.rank == 2, "This function only operates on matrices");

    if(p == 1.0f)
    {
        /*
        The matrix norm induced by the L1 vector norm is max_j Sum_i abs(a_ij), which is the maximum absolute
        column sum. We are doing a row sum here because the weight matrices in dense and convolutional layers are
        transposed before being multiplied with the input features.
        */
        return param.abs().sum([1]).maxElement();
    }
    else if(p == 2.0f)
    {
        /*
        The matrix norm induced by the L2 vector norm is the spectral norm, i.e., the largest singular value. It
        is approximated with n iterations of the power method applied to W * W^T; the resulting estimate of the
        dominant eigenvector is normalised and mapped through W^T to recover the largest singular value.
        */
        auto x = uniformSample([param.shape[0], 1]) * 2.0f - 1.0f;
        auto weightsT = param.transpose([1, 0]);
        auto wwT = param.matmul(weightsT);

        foreach(i; 0 .. n)
        {
            x = matmul(wwT, x);
        }

        auto v = x / sqrt(sum(x * x));
        auto y = matmul(weightsT, v);

        return sqrt(sum(y * y));
    }
    else if(p == float.infinity)
    {
        /*
        The matrix norm induced by the L_infty vector norm is max_i Sum_j abs(a_ij), which is the maximum absolute
        row sum. We are doing a column sum here because the weight matrices in dense and convolutional layers are
        transposed before being multiplied with the input features.
        */
        return param.abs().sum([0]).maxElement();
    }
    else
    {
        import std.conv : to;

        throw new Exception("Cannot compute matrix norm for p=" ~ p.to!string);
    }
}
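/*
    A minimal sanity check of matrixNorm, not part of the original module. For p = 1 and p = float.infinity the
    norms have exact closed forms, so they can be verified directly. This assumes float32 stores elements in
    row-major order.
*/
unittest
{
    import std.math : abs;

    auto w = float32([2, 2], [1.0f, -2.0f,
                              3.0f,  4.0f]);

    //Maximum absolute row sum under this module's convention: |3| + |4| = 7
    assert(abs(matrixNorm(w, 1.0f).evaluate().get!float[0] - 7.0f) < 1e-4f);

    //Maximum absolute column sum under this module's convention: |-2| + |4| = 6
    assert(abs(matrixNorm(w, float.infinity).evaluate().get!float[0] - 6.0f) < 1e-4f);
}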
/**
    Returns a projection function that can be used to constrain the operator norm of the linear map performed by a
    convolution operation.

    Params:
        maxnorm = A scalar value indicating the maximum allowable operator norm.
        inShape = The spatial dimensions of the input feature maps.
        stride = The stride of the convolution.
        padding = The padding of the convolution.
        p = The vector p-norm that will induce the operator norm.

    Returns:
        A projection function that can be used with the online optimisation algorithms.
*/
Projection projConvParams(Operation maxnorm, size_t[] inShape, size_t[] stride, size_t[] padding, float p = 2.0f)
{
    Operation proj(Operation param)
    {
        auto norm = convParamsNorm(param, inShape, stride, padding, p);

        return maxNorm(param, norm, maxnorm);
    }

    return &proj;
}

/**
    Computes the operator norm of the linear map performed by a convolution operation with the given parameters.

    Params:
        param = The convolution kernel.
        inShape = The spatial dimensions of the input feature maps.
        stride = The stride of the convolution.
        padding = The padding of the convolution.
        p = The vector p-norm that induces the operator norm. Must be 1, 2, or float.infinity.
        n = The number of power iterations used to approximate the norm when p = 2.
*/
Operation convParamsNorm(Operation param, size_t[] inShape, size_t[] stride, size_t[] padding, float p = 2.0f,
    size_t n = 2)
{
    if(p == 2.0f)
    {
        /*
        Approximate the largest singular value of the convolution with the power method, where a forward
        convolution followed by a transposed convolution plays the role of multiplication by W^T * W.
        */
        auto x = uniformSample([cast(size_t)1, param.shape[1]] ~ inShape) * 2.0f - 1.0f;

        foreach(i; 0 .. n)
        {
            x = x
               .convolution(param, padding, stride)
               .convolutionTranspose(param, padding, stride);
        }

        auto v = x / sqrt(sum(x * x));
        auto y = convolution(v, param, padding, stride);

        return sqrt(sum(y * y));
    }
    else if(p == 1.0f || p == float.infinity)
    {
        //Turns out this is equivalent to the induced norm of the flattened kernel matrix, but only for p in {1, infinity}
        if(param.rank != 2)
        {
            param = param.reshape([param.shape[0], param.volume / param.shape[0]]);
        }

        return matrixNorm(param, p);
    }
    else
    {
        import std.conv : to;

        throw new Exception("Cannot compute convolution params norm for p=" ~ p.to!string);
    }
}

unittest
{
    //Smoke test: check that the power method runs on a small kernel and print the resulting estimate
    auto k = float32([1, 1, 3, 3], [
        1, 2, 3, 4, 5, 6, 7, 8, 9
    ]);

    auto norm = convParamsNorm(k, [200, 200], [1, 1], [1, 1], 2.0f);

    import std.stdio : writeln;
    writeln(norm.evaluate().get!float[0]);
}

/**
    Performs a projection of param such that the norm of the result will be less than or equal to maxval.

    If norm is already less than or equal to maxval then param is returned unchanged; otherwise it is rescaled by
    maxval / norm.
*/
Operation maxNorm(Operation param, Operation norm, Operation maxval)
{
    return param * (1.0f / max(float32([], [1.0f]), norm / maxval));
}
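/*
    An illustrative end-to-end check of projMatrix, not part of the original module. With p = 1, the first matrix
    below has an operator norm of 7 under this module's convention, so projecting with maxnorm = 3.5 should rescale
    every entry by 0.5, while a matrix that already satisfies the constraint should pass through unchanged. This
    assumes row-major element order and the scalar broadcasting already relied upon by maxNorm.
*/
unittest
{
    import std.math : abs;

    auto proj = projMatrix(float32([], [3.5f]), 1.0f);

    //Norm is max(|1| + |2|, |3| + |4|) = 7 > 3.5, so the projection should halve every entry
    auto w = float32([2, 2], [1.0f, 2.0f, 3.0f, 4.0f]);
    auto result = proj(w).evaluate().get!float;
    assert(abs(result[0] - 0.5f) < 1e-4f);
    assert(abs(result[3] - 2.0f) < 1e-4f);

    //Norm is 1 <= 3.5, so the constraint is already satisfied and the parameter is left unchanged
    auto small = float32([2, 2], [0.5f, 0.5f, 0.5f, 0.5f]);
    auto unchanged = proj(small).evaluate().get!float;
    assert(abs(unchanged[0] - 0.5f) < 1e-4f);
}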