def test_tuple_deriv(self):
    """Check that gradients flow through Tile tuple()/element() operations.

    A matrix product is written in Tile so that both operands are packed into
    a tuple and then unpacked with element() before multiplying. The product
    is summed over both axes, differentiated with respect to the left operand,
    and the resulting gradient function is compiled and run on 3x3 inputs.
    No assertion is made on the gradient values; the test only exercises
    composition and invocation.
    """
    lhs = tile.Value.from_ndims(2)
    rhs = tile.Value.from_ndims(2)

    # (I, K) x (K, J) -> (I, J), typed with the common dtype of both operands.
    product_dims = (lhs.shape.dims[0], rhs.shape.dims[1])
    product_dtype = tile.common_dtype(lhs.shape.dtype, rhs.shape.dtype)
    product_shape = tile.Shape(product_dtype, product_dims)

    tile_code = """ function (A[I, K], B[K, J]) -> (O) { T = tuple(A, B); C = element(T, 0); D = element(T, 1); O[i, j : I, J] = +(C[i, k] * D[k, j]); } """
    matmul = tile.Operation(tile_code, [('A', lhs), ('B', rhs)], [('O', product_shape)])
    product = matmul.outputs['O']

    # Reduce to a scalar so a single gradient w.r.t. the left operand exists.
    total = op.summation(product, [0, 1])
    grad_lhs = op.gradients(total, [lhs])[0]

    func = tile.compose(
        self._ctx,
        self._dev,
        inputs=[('A', lhs), ('B', rhs)],
        outputs=[('DA', grad_lhs)],
    )
    invoker = plaidml.Invoker(self._ctx, func)
    invoker.set_input('A', self.make_inited_tensor((3, 3)))
    invoker.set_input('B', self.make_inited_tensor((3, 3)))
    grad_tensor = self.make_output_tensor(invoker.get_output_shape('DA'))
    invoker.set_output('DA', grad_tensor)
    invoker.invoke()
def reduce_log_sum_exp(unused_ctx, data, axes=None, keepdims=1):
    """Compute log(sum(exp(data))) over the requested axes.

    Args:
        unused_ctx: Not used by this function.
        data: Input tensor value; must expose ``shape.ndims``.
        axes: Axes to reduce over. ``None`` reduces every axis. A single
            integer is treated as a one-element axis list; lists and tuples
            are used as-is; any other iterable is converted to a tuple.
        keepdims: Passed through unchanged to ``op.summation``.

    Returns:
        A one-element tuple containing the reduced value.
    """
    import numbers

    if axes is None:
        axes = range(data.shape.ndims)
    if isinstance(axes, numbers.Integral):
        # A bare integer axis: tuple(int) would raise TypeError, so wrap it.
        axes = (axes,)
    elif not isinstance(axes, (list, tuple)):
        axes = tuple(axes)
    # NOTE(review): this is the direct formula, so exp() may overflow for large
    # inputs; a max-shifted (log-sum-exp trick) form would be numerically
    # stabler — confirm whether callers need that.
    return (op.log(op.summation(op.exp(data), axes=axes, keepdims=keepdims)), )
def reduce_sum_square(unused_ctx, data, axes=None, keepdims=1):
    """Compute sum(data * data) over the requested axes.

    Args:
        unused_ctx: Not used by this function.
        data: Input tensor value; must expose ``shape.ndims``.
        axes: Axes to reduce over. ``None`` reduces every axis. A single
            integer is treated as a one-element axis list; lists and tuples
            are used as-is; any other iterable is converted to a tuple.
        keepdims: Passed through unchanged to ``op.summation``.

    Returns:
        A one-element tuple containing the reduced value.
    """
    import numbers

    if axes is None:
        axes = range(data.shape.ndims)
    if isinstance(axes, numbers.Integral):
        # A bare integer axis: tuple(int) would raise TypeError, so wrap it.
        axes = (axes,)
    elif not isinstance(axes, (list, tuple)):
        axes = tuple(axes)
    return (op.summation(data * data, axes=axes, keepdims=keepdims), )