1 /**
    Contains functions for creating variable and constant nodes, and for manipulating the shapes of operation outputs.
3 
4     Authors: Henry Gouk
5 */
6 module dopt.core.ops.basic;
7 
8 import dopt.core : allocate;
9 import dopt.core.ops;
10 import dopt.core.types;
11 
12 import std.algorithm;
13 import std.array;
14 import std.exception;
15 import std.functional;
16 import std.range;
17 import std.variant;
18 
19 package
20 {
21     void initialize()
22     {
23         registerOperation("slice", OpDef(toDelegate(&verifySlice), toDelegate(&judgeSlice)));
24         registerOperation("pad", OpDef(toDelegate(&verifyPad), toDelegate(&judgePad)));
25         registerOperation("reshape", OpDef(toDelegate(&verifyReshape), toDelegate(&judgeReshape)));
26         registerOperation("transpose", OpDef(toDelegate(&verifyTranspose), toDelegate(&judgeTranspose)));
27         registerOperation("repeat", OpDef(toDelegate(&verifyRepeat), toDelegate(&judgeRepeat)));
28         registerOperation("variable", OpDef(toDelegate(&verifyVariable), toDelegate(&judgeVariable)));
29         registerOperation("constant", OpDef(toDelegate(&verifyVariable), toDelegate(&judgeVariable)));
30     }
31 }
32 
33 private
34 {
35     bool verifySlice(Operation op)
36     {
37         if(("start" in op.attributes) is null || ("stop" in op.attributes) is null)
38         {
39             return false;
40         }
41 
42         auto startVar = op.attributes["start"];
43         auto stopVar = op.attributes["stop"];
44 
45         if(startVar.peek!(size_t[]) is null || stopVar.peek!(size_t[]) is null)
46         {
47             return false;
48         }
49 
50         auto start = startVar.get!(size_t[]);
51         auto stop = stopVar.get!(size_t[]);
52 
53         return op.deps.length == 1
54             && start.length == stop.length
55             && start.length == op.deps[0].outputType.rank
56             && zip(start, op.deps[0].outputType.shape).all!(x => x[0] < x[1])
57             && zip(stop, op.deps[0].outputType.shape).all!(x => x[0] <= x[1])
58             && zip(start, stop).all!(x => x[0] < x[1]);
59     }
60 
61     TensorType judgeSlice(Operation op)
62     {
63         auto start = op
64                     .attributes["start"]
65                     .get!(size_t[]);
66 
67         auto stop = op
68                    .attributes["stop"]
69                    .get!(size_t[]);
70 
71         auto shape = zip(start, stop)
72                     .map!(x => x[1] - x[0])
73                     .array();
74 
75         return TensorType(op.deps[0].outputType.elementType, shape);
76     }
77 
78     bool verifyPad(Operation op)
79     {
80         if(("before" in op.attributes) is null || ("after" in op.attributes) is null)
81         {
82             return false;
83         }
84 
85         auto beforeVar = op.attributes["before"];
86         auto afterVar = op.attributes["after"];
87 
88         if(beforeVar.peek!(size_t[]) is null || afterVar.peek!(size_t[]) is null)
89         {
90             return false;
91         }
92 
93         auto before = beforeVar.get!(size_t[]);
94         auto after = afterVar.get!(size_t[]);
95 
96         return op.deps.length == 1
97             && before.length == after.length
98             && before.length == op.deps[0].outputType.rank;
99     }
100 
101     TensorType judgePad(Operation op)
102     {
103         auto before = op
104                      .attributes["before"]
105                      .get!(size_t[]);
106 
107         auto after = op
108                     .attributes["after"]
109                     .get!(size_t[]);
110 
111         auto shape = zip(before, after, op.deps[0].outputType.shape)
112                     .map!(x => x[0] + x[1] + x[2])
113                     .array();
114 
115         return TensorType(op.deps[0].outputType.elementType, shape);
116     }
117 
118     bool verifyReshape(Operation op)
119     {
120         auto newShape = "shape" in op.attributes;
121 
122         return op.deps.length == 1
123             && newShape !is null
124             && newShape.peek!(size_t[]) !is null
125             && newShape.get!(size_t[]).fold!((a, b) => a * b)(cast(size_t)1) == op.deps[0].outputType.volume;
126     }
127 
128     TensorType judgeReshape(Operation op)
129     {
130         return TensorType(op.deps[0].outputType.elementType, op.attributes["shape"].get!(size_t[]));
131     }
132 
133     bool verifyTranspose(Operation op)
134     {
135         auto newOrder = "order" in op.attributes;
136 
137         return op.deps.length == 1
138             && newOrder !is null
139             && newOrder.peek!(size_t[]) !is null
140             && newOrder.get!(size_t[]).dup.sort().equal(iota(0, op.deps[0].outputType.rank));
141     }
142 
143     TensorType judgeTranspose(Operation op)
144     {
145         auto order = op
146                     .attributes["order"]
147                     .get!(size_t[]);
148 
149         auto newShape = order
150                        .map!(x => op.deps[0].outputType.shape[x])
151                        .array();
152 
153         return TensorType(op.deps[0].outputType.elementType, newShape);
154     }
155 
156     bool verifyRepeat(Operation op)
157     {
158         if(("repetitions" in op.attributes) is null)
159         {
160             return false;
161         }
162 
163         auto reps = op.attributes["repetitions"].get!(size_t[]);
164 
165         return op.deps.length == 1
166             && reps.length == op.deps[0].rank
167             && reps.all!(x => x > 0);
168     }
169 
170     TensorType judgeRepeat(Operation op)
171     {
172         auto reps = op.attributes["repetitions"].get!(size_t[]);
173         auto shape = op.deps[0].shape.dup;
174         shape[] *= reps[];
175 
176         return TensorType(op.deps[0].elementType, shape);
177     }
178 
179     bool verifyVariable(Operation op)
180     {
181         return op.deps.length == 0
182             && ("type" in op.attributes) !is null
183             && op.attributes["type"].peek!TensorType !is null;
184     }
185 
186     TensorType judgeVariable(Operation op)
187     {
188         return op.attributes["type"].get!TensorType;
189     }
190 }
191 
192 public
193 {
194     /**
195         Slices the result of an operation.
196 
197         Params:
198             input = The operation that should be sliced.
199             start = The starting indices for each dimension.
200             stop = The stopping indices for each dimension.
201 
202         Returns:
203             The new $(D Operation).
204     */
205     Operation slice(Operation input, size_t[] start, size_t[] stop,
206         string mod = __MODULE__, size_t line = __LINE__)
207     {
208         return createOperation("slice", [input], ["start": Variant(start), "stop": Variant(stop)], mod, line);
209     }
210 
211     ///
212     unittest
213     {
214         import dopt.core : evaluate;
215 
216         auto s1 = int32([3, 3], [
217             1, 2, 3,
218             4, 5, 6,
219             7, 8, 9
220         ]).slice([1, 1], [3, 3]);
221 
222         assert(s1.evaluate().get!int == [
223             5, 6,
224             8, 9
225         ]);
226     }
227 
228     /**
229         Pads the result of an operation with zeros in each dimension.
230 
231         Params:
232             input = The operation that should be padded.
233             before = The amount of padding that should be prepended for each dimension.
234             after = The amount of padding that should be appended for each dimension.
235 
236         Returns:
237             The new $(D Operation).
238     */
239     Operation pad(Operation input, size_t[] before, size_t[] after,
240         string mod = __MODULE__, size_t line = __LINE__)
241     {
242         return createOperation("pad", [input], ["before": Variant(before), "after": Variant(after)], mod, line);
243     }
244 
245     ///
246     unittest
247     {
248         import dopt.core : evaluate;
249 
250         auto p1 = int32([1, 1], [3]).pad([2, 1], [3, 3]);
251 
252         assert(p1.evaluate().get!int == [
253             0, 0, 0, 0, 0,
254             0, 0, 0, 0, 0,
255             0, 3, 0, 0, 0,
256             0, 0, 0, 0, 0,
257             0, 0, 0, 0, 0,
258             0, 0, 0, 0, 0
259         ]);
260     }
261 
262     /**
        Reinterprets the output of an operation as a tensor with a different shape but the same volume.
264 
265         Params:
266             input = The operation to be reshaped.
267             shape = The new shape.
268 
269         Returns:
270             The new $(D Operation).
271     */
272     Operation reshape(Operation input, size_t[] shape, string mod = __MODULE__, size_t line = __LINE__)
273     {
274         return createOperation("reshape", [input], ["shape": Variant(shape)], mod, line);
275     }
276 
277     ///
278     unittest
279     {
280         import dopt.core : evaluate;
281 
282         auto r1 = float32([2, 2], [1.0f, 2.0f, 3.0f, 4.0f]).reshape([1, 4]);
283 
284         assert(r1.shape == [1, 4]);
285         assert(r1.evaluate().get!float == [1.0f, 2.0f, 3.0f, 4.0f]);
286     }
287 
288     /**
        Reorders the dimensions of the output of an operation.
290 
291         Params:
292             input = The operation that should have its dimensions reordered.
293             order = Determines how the dimensions are permuted.
294 
295         Notes:
296             Currently only implemented for rank 2 tensors.
297 
298         Returns:
299             The new $(D Operation).
300     */
301     Operation transpose(Operation input, size_t[] order, string mod = __MODULE__, size_t line = __LINE__)
302     {
303         return createOperation("transpose", [input], ["order": Variant(order)], mod, line);
304     }
305 
306     ///
307     unittest
308     {
309         import dopt.core : evaluate;
310 
311         auto t1 = float32([2, 2], [1.0f, 2.0f, 3.0f, 4.0f]).transpose([1, 0]);
312 
313         assert(t1.evaluate().get!float == [1.0f, 3.0f, 2.0f, 4.0f]);
314     }
315 
316     /**
317         Repeats the output of an operation along each axis the given number of times.
318 
319         Params:
320             input = The operation to have its output repeated.
321             repetitions = The number of repetitions to perform along each axis.
322 
        Returns:
324             The new $(D Operation).
325     */
326     Operation repeat(Operation input, size_t[] repetitions, string mod = __MODULE__,
327         size_t line = __LINE__)
328     {
329         enforce(repetitions.length == input.rank,
330             "The length of repetitions must be the same as the rank of the input.");
331         
332         return createOperation("repeat", [input], ["repetitions": Variant(repetitions)], mod, line);
333     }
334 
335     ///
336     unittest
337     {
338         import dopt.core : evaluate;
339         
340         auto r1 = float32([1, 1], [3.0f]).repeat([2, 3]);
341         auto r2 = float32([2, 2], [1.0f, 2.0f, 3.0f, 4.0f]).repeat([3, 2]);
342 
343         assert(r1.evaluate().get!float == [
344             3.0f, 3.0f, 3.0f,
345             3.0f, 3.0f, 3.0f
346         ]);
347 
348         assert(r2.evaluate().get!float == [
349             1.0f, 2.0f, 1.0f, 2.0f,
350             3.0f, 4.0f, 3.0f, 4.0f,
351             1.0f, 2.0f, 1.0f, 2.0f,
352             3.0f, 4.0f, 3.0f, 4.0f,
353             1.0f, 2.0f, 1.0f, 2.0f,
354             3.0f, 4.0f, 3.0f, 4.0f
355         ]);
356     }
357 
358     /**
359         Repeats the output of an operation the given number of times.
360 
361         A new dimension is added, allowing one to index each of these repetitions.
362 
363         Params:
364             input = The operation to have its output repeated.
365             repetitions = The number of repetitions to perform.
366         
        Returns:
368             The new $(D Operation).
369     */
370     Operation repeat(Operation input, size_t repetitions, string mod = __MODULE__, size_t line = __LINE__)
371     {
372         auto vec = input.reshape([1, input.volume]);
373 
374         import std.range : drepeat = repeat;
375         import std.array : array;
376 
377         auto pattern = float32Constant([repetitions, 1], drepeat(1.0f, repetitions).array());
378         auto r = pattern.matmul(vec);
379         
380         return r.reshape([repetitions] ~ input.shape, mod, line);
381     }
382 
383     ///
384     unittest
385     {
386         import dopt.core : evaluate;
387 
388         auto r1 = float32([2], [1.0f, 2.0f]).repeat(3);
389 
390         assert(r1.evaluate().get!float == [
391             1.0f, 2.0f,
392             1.0f, 2.0f,
393             1.0f, 2.0f
394         ]);
395     }
396 
397     /**
398         Creates a variable with the given type.
399 
        If no default value is provided, then the variable will have a default value of all zeros. The default
        value is copied into a buffer that is associated with the returned operation.
402 
403         Params:
404             type = The type of the variable
405             defaultVal = The default value of the variable. The array should store the elements in row major order.
406 
407         Returns:
408             The newly created variable
409     */
410     Operation variable(TensorType type, void[] defaultVal = null, string mod = __MODULE__, size_t line = __LINE__)
411     {
412         auto bufSize = type.volume * sizeOf(type.elementType);
413 
414         if(defaultVal is null)
415         {
416             defaultVal = new ubyte[bufSize];
417         }
418         else
419         {
            enforce(defaultVal.length == bufSize, "The length of defaultVal does not match the buffer size implied by type.");
421         }
422 
423         auto op = createOperation("variable", [], ["type": Variant(type)], mod, line);
424         auto buf = allocate(bufSize);
425         buf.set(defaultVal);
426         op.setBuffer(buf);
427 
428         return op;
429     }
430 
431     /**
432         Creates a variable with the given shape and float32 elements.
433 
        If no default value is provided, then the variable will have a default value of all zeros. The default
        value is copied into a buffer that is associated with the returned operation.
436 
437         Params:
438             size = The shape of the variable
439             defaultVal = The default value of the variable. The array should store the elements in row major order.
440 
441         Returns:
442             The newly created variable
443     */
444     Operation float32(size_t[] size = [], float[] defaultVal = null, string mod = __MODULE__, size_t line = __LINE__)
445     {
446         return variable(TensorType(DataType.float32, size), defaultVal, mod, line);
447     }
448 
449     ///
450     Operation float32(float defaultVal, string mod = __MODULE__, size_t line = __LINE__)
451     {
452         return float32([], [defaultVal], mod, line);
453     }
454 
455     /**
456         Creates a variable with the given shape and int32 elements.
457 
        If no default value is provided, then the variable will have a default value of all zeros. The default
        value is copied into a buffer that is associated with the returned operation.
460 
461         Params:
462             size = The shape of the variable
463             defaultVal = The default value of the variable. The array should store the elements in row major order.
464 
465         Returns:
466             The newly created variable
467     */
468     Operation int32(size_t[] size = [], int[] defaultVal = null, string mod = __MODULE__, size_t line = __LINE__)
469     {
470         return variable(TensorType(DataType.int32, size), defaultVal, mod, line);
471     }
472 
473     ///
474     Operation int32(int defaultVal, string mod = __MODULE__, size_t line = __LINE__)
475     {
476         return int32([], [defaultVal], mod, line);
477     }
478 
479     /**
480         Creates a constant with the given type.
481 
482         Params:
483             type = The type of the constant
484             val = The value of the constant. The array should store the elements in row major order.
485 
486         Returns:
487             The newly created constant
488     */
489     Operation constant(TensorType type, void[] val, string mod = __MODULE__, size_t line = __LINE__)
490     {
491         auto bufSize = type.volume * sizeOf(type.elementType);
492 
493         if(val is null)
494         {
495             val = new ubyte[bufSize];
496         }
497         else
498         {
            enforce(val.length == bufSize, "The length of val does not match the buffer size implied by type.");
500         }
501 
502         auto op = createOperation("constant", [], ["type": Variant(type)], mod, line);
503         auto buf = allocate(bufSize);
504         buf.set(val);
505         op.setBuffer(buf);
506         
507         return op;
508     }
509 
510     /**
511         Creates a constant with the given shape and float32 values.
512 
513         Params:
514             size = The shape of the constant
515             val = The value of the constant. The array should store the elements in row major order.
516 
517         Returns:
518             The newly created constant
519     */
520     Operation float32Constant(size_t[] size, float[] val, string mod = __MODULE__, size_t line = __LINE__)
521     {
522         return constant(TensorType(DataType.float32, size), val, mod, line);
523     }
524 
525     ///
526     Operation float32Constant(float val, string mod = __MODULE__, size_t line = __LINE__)
527     {
528         return float32Constant([], [val], mod, line);
529     }
530 
531     /**
532         Creates a constant with the given shape and int32 values.
533 
534         Params:
535             size = The shape of the constant
536             val = The value of the constant. The array should store the elements in row major order.
537 
538         Returns:
539             The newly created constant
540     */
541     Operation int32Constant(size_t[] size, int[] val, string mod = __MODULE__, size_t line = __LINE__)
542     {
543         return constant(TensorType(DataType.int32, size), val, mod, line);
544     }
545 
546     ///
547     Operation int32Constant(int val, string mod = __MODULE__, size_t line = __LINE__)
548     {
        return int32Constant([], [val], mod, line);
550     }
551 }