/**
    Gradient definitions for the arithmetic operations: element-wise binary and unary operations,
    `matmul`, and `sum`.

    Every gradient function in this module takes the operation being differentiated (`op`) and the
    gradient flowing into that operation's output (`parentGrad`), and returns one `Operation` per
    entry of `op.deps`, in the same order as `op.deps`.
*/
module dopt.core.grads.math;

import dopt.core.grads;
import dopt.core.ops;

package
{
    /**
        Registers every gradient function defined in this module with `registerGradient`, keyed by
        the name of the operation it differentiates.
    */
    void initialize()
    {
        import std.functional;

        // Builds, at compile time, one `registerGradient("<name>", toDelegate(&<name>Grad));`
        // statement per listed operation. Each name in the list must have a matching
        // `<name>Grad` function in this module.
        string createRegisterGradientCalls()
        {
            auto ops = ["add", "sub", "mul", "div", "pow", "min", "max",
                        "neg", "abs", "exp", "log", "sqrt"];

            string ret;

            foreach(o; ops)
            {
                ret ~= "registerGradient(\"" ~ o ~ "\", toDelegate(&" ~ o ~ "Grad));\n";
            }

            return ret;
        }

        mixin(createRegisterGradientCalls());

        // These two are registered by hand rather than through the generated list above.
        registerGradient("matmul", toDelegate(&matmulGrad));
        registerGradient("sum", toDelegate(&sumGrad));
    }
}

private
{
    /**
        Gradient of `matmul(A, B)`.

        Returns: `[parentGrad * B^T, A^T * parentGrad]`, where `A = op.deps[0]` and
                 `B = op.deps[1]`. The transposes swap axes 0 and 1.
    */
    Operation[] matmulGrad(Operation op, Operation parentGrad)
    {
        return [
            matmul(parentGrad, transpose(op.deps[1], [1, 0])),
            matmul(transpose(op.deps[0], [1, 0]), parentGrad)
        ];
    }

    /**
        Gradient of `sum`.

        When the result of the sum has a volume of one, `parentGrad` is repeated once per input
        element and reshaped to the input's shape. Otherwise the "axes" attribute of `op` is read,
        `parentGrad` is reshaped so that each listed axis has extent 1, and the result is repeated
        along those axes by the input's original extent.

        Returns: a single-element array holding the gradient with respect to `op.deps[0]`.
    */
    Operation[] sumGrad(Operation op, Operation parentGrad)
    {
        if(op.volume == 1)
        {
            return [parentGrad.repeat(op.deps[0].volume).reshape(op.deps[0].shape)];
        }
        else
        {
            auto axes = op.attributes["axes"].get!(size_t[]);

            // Copy the input shape so that overwriting entries does not touch the dependency's own shape array.
            auto tmpShape = op.deps[0].shape.dup;
            auto reps = new size_t[tmpShape.length];
            reps[] = 1;

            foreach(a; axes)
            {
                // Listed axes: extent 1 in the intermediate shape, repeated back to full extent.
                reps[a] = tmpShape[a];
                tmpShape[a] = 1;
            }

            auto tmp = parentGrad.reshape(tmpShape);

            return [tmp.repeat(reps)];
        }
    }

    /// Gradient of `a + b`: `parentGrad` is passed unchanged to both operands.
    Operation[] addGrad(Operation op, Operation parentGrad)
    {
        return [parentGrad, parentGrad];
    }

    /// Gradient of `a - b`: `[parentGrad, -parentGrad]`.
    Operation[] subGrad(Operation op, Operation parentGrad)
    {
        return [parentGrad, neg(parentGrad)];
    }

    /// Gradient of `a * b`: `[parentGrad * b, parentGrad * a]`.
    Operation[] mulGrad(Operation op, Operation parentGrad)
    {
        return [parentGrad * op.deps[1], parentGrad * op.deps[0]];
    }

    /// Gradient of `a / b`: `[parentGrad / b, -(parentGrad * a) / (b * b)]`.
    Operation[] divGrad(Operation op, Operation parentGrad)
    {
        return [
            parentGrad / op.deps[1],
            neg(parentGrad * op.deps[0]) / (op.deps[1] * op.deps[1])
        ];
    }

    /**
        Gradient of `pow(x, y)`, where `x = op.deps[0]` and `y = op.deps[1]`.

        Returns: `[parentGrad * y * pow(x, y - 1), parentGrad * pow(x, y) * log(x)]`.
    */
    Operation[] powGrad(Operation op, Operation parentGrad)
    {
        return [
            parentGrad * op.deps[1] * pow(op.deps[0], op.deps[1] - 1),
            // d/dy x^y = x^y * ln(x). `op` already is x^y, so it is reused rather than recomputed.
            parentGrad * op * log(op.deps[0])
        ];
    }

    /**
        Gradient of `min(a, b)`.

        `parentGrad` is multiplied by `eq(operand, op)` for each operand.
        NOTE(review): presumably `eq` yields an element-wise 0/1 mask, in which case an element
        where both operands are equal receives `parentGrad` on both sides - confirm against `eq`.
    */
    Operation[] minGrad(Operation op, Operation parentGrad)
    {
        return [
            op.deps[0].eq(op) * parentGrad,
            op.deps[1].eq(op) * parentGrad
        ];
    }

    /**
        Gradient of `max(a, b)`.

        `parentGrad` is multiplied by `eq(operand, op)` for each operand.
        NOTE(review): presumably `eq` yields an element-wise 0/1 mask, in which case an element
        where both operands are equal receives `parentGrad` on both sides - confirm against `eq`.
    */
    Operation[] maxGrad(Operation op, Operation parentGrad)
    {
        return [
            op.deps[0].eq(op) * parentGrad,
            op.deps[1].eq(op) * parentGrad
        ];
    }

    /// Gradient of `-x`: `[-parentGrad]`.
    Operation[] negGrad(Operation op, Operation parentGrad)
    {
        return [neg(parentGrad)];
    }

    /// Gradient of `abs(x)`: `[parentGrad * sgn(x)]`.
    Operation[] absGrad(Operation op, Operation parentGrad)
    {
        return [parentGrad * sgn(op.deps[0])];
    }

    /// Gradient of `exp(x)`: `[parentGrad * exp(x)]`, reusing `op` as `exp(x)`.
    Operation[] expGrad(Operation op, Operation parentGrad)
    {
        return [parentGrad * op];
    }

    /// Gradient of `log(x)`: `[parentGrad / x]`.
    Operation[] logGrad(Operation op, Operation parentGrad)
    {
        return [parentGrad / op.deps[0]];
    }

    /**
        Gradient of `sqrt(x)`: `[parentGrad / (2 * sqrt(x))]`.

        `op` already is `sqrt(x)`. The doubling is written as `op + op` so that both operands of
        every operator remain `Operation` values and no scalar constant has to be introduced.
    */
    Operation[] sqrtGrad(Operation op, Operation parentGrad)
    {
        return [parentGrad / (op + op)];
    }
}