module dopt.core.grads.basic;

import dopt.core.grads;
import dopt.core.ops;

package
{
    /**
        Registers the gradient functions defined in this module under the
        names "transpose", "slice", "pad", "reshape" and "repeat".
    */
    void initialize()
    {
        import std.functional : toDelegate;

        //registerGradient is given a delegate wrapping each plain function below
        registerGradient("transpose", (&transposeGrad).toDelegate());
        registerGradient("slice", (&sliceGrad).toDelegate());
        registerGradient("pad", (&padGrad).toDelegate());
        registerGradient("reshape", (&reshapeGrad).toDelegate());
        registerGradient("repeat", (&repeatGrad).toDelegate());
    }
}

private
{
    /**
        Gradient of "transpose": transposes parentGrad using the inverse of the
        permutation stored in the "order" attribute of op.

        Params:
            op = The operation whose "order" attribute is read.
            parentGrad = The gradient with respect to the output of op.

        Returns:
            A single-element array containing the transposed parentGrad.
    */
    Operation[] transposeGrad(Operation op, Operation parentGrad)
    {
        import std.algorithm : countUntil;

        auto permutation = op.attributes["order"].get!(size_t[]);

        //inverse[i] is the position at which the value i occurs in permutation
        auto inverse = new size_t[permutation.length];

        foreach(i, ref position; inverse)
        {
            position = cast(size_t)permutation.countUntil(i);
        }

        return [parentGrad.transpose(inverse)];
    }

    /**
        Gradient of "slice": pads parentGrad, using the "start" attribute as
        the leading padding and (shape of the first dependency - "stop") as the
        trailing padding.

        Params:
            op = The operation whose "start" and "stop" attributes are read.
            parentGrad = The gradient with respect to the output of op.

        Returns:
            A single-element array containing the padded parentGrad.
    */
    Operation[] sliceGrad(Operation op, Operation parentGrad)
    {
        auto leading = op.attributes["start"].get!(size_t[]);

        //Start from a private copy of the input shape, then subtract "stop" elementwise
        auto trailing = op.deps[0].outputType.shape.dup;
        auto sliceEnd = op.attributes["stop"].get!(size_t[]);
        trailing[] -= sliceEnd[];

        return [parentGrad.pad(leading, trailing)];
    }

    //Check sliceGrad: only the element selected by the slice receives a non-zero gradient
    unittest
    {
        import dopt.core;
        import std.algorithm : equal;

        auto a = float32([4, 4], [
            1.0f, 2.0f, 3.0f, 4.0f,
            1.0f, 2.0f, 3.0f, 4.0f,
            1.0f, 2.0f, 3.0f, 4.0f,
            1.0f, 2.0f, 3.0f, 4.0f]);

        auto b = float32([4, 4], [
            5.0f, 6.0f, 7.0f, 8.0f,
            5.0f, 6.0f, 7.0f, 8.0f,
            5.0f, 6.0f, 7.0f, 8.0f,
            5.0f, 6.0f, 7.0f, 8.0f]);

        auto c = slice(a * b, [1, 1], [2, 2]);

        auto expected = [
            0, 0, 0, 0,
            0, 6, 0, 0,
            0, 0, 0, 0,
            0, 0, 0, 0];

        auto actual = evaluate(grad(c, [a]))[0].get!float;

        assert(actual.equal(expected));
    }

    /**
        Gradient of "pad": slices parentGrad from the "before" attribute up to
        ("before" + shape of the first dependency).

        Params:
            op = The operation whose "before" attribute is read.
            parentGrad = The gradient with respect to the output of op.

        Returns:
            A single-element array containing the sliced parentGrad.
    */
    Operation[] padGrad(Operation op, Operation parentGrad)
    {
        auto lower = op.attributes["before"].get!(size_t[]);

        //Upper bound = input shape offset by the leading padding, computed on a copy
        auto upper = op.deps[0].outputType.shape.dup;
        upper[] += lower[];

        return [parentGrad.slice(lower, upper)];
    }

    /**
        Gradient of "reshape": reshapes parentGrad to the shape of the first
        dependency of op.

        Params:
            op = The operation whose first dependency supplies the target shape.
            parentGrad = The gradient with respect to the output of op.

        Returns:
            A single-element array containing the reshaped parentGrad.
    */
    Operation[] reshapeGrad(Operation op, Operation parentGrad)
    {
        auto inputShape = op.deps[0].outputType.shape;

        return [parentGrad.reshape(inputShape)];
    }

    /**
        Gradient of "repeat": reshapes parentGrad so that every repetition
        count gets a dimension of its own, then sums over those dimensions.

        Params:
            op = The operation whose "repetitions" attribute is read.
            parentGrad = The gradient with respect to the output of op.

        Returns:
            A single-element array containing the summed gradient.
    */
    Operation[] repeatGrad(Operation op, Operation parentGrad)
    {
        import std.array : array;
        import std.range : iota, roundRobin;

        auto reps = op.attributes["repetitions"].get!(size_t[]);

        //Interleave repetition counts with the input dimensions: [r0, d0, r1, d1, ...]
        auto interleaved = roundRobin(reps, op.deps[0].shape).array();
        auto expanded = parentGrad.reshape(interleaved);

        //The repetition counts sit at the even indices of the interleaved shape
        auto repetitionAxes = iota(0, interleaved.length, 2).array();

        return [expanded.sum(repetitionAxes)];
    }
}