Example #1
0
 def grad(self, inputs, g_outputs):
     """Return one gradient term per input ``(x, i0, i1, amt)``.

     The first output gradient is passed through unchanged as the term
     for ``x``. The two index inputs ``i0`` and ``i1`` each get a
     disconnected gradient. The term for ``amt`` is
     ``diagonal_subtensor`` applied to the output gradient with the same
     two indices.
     """
     # Unpacking enforces that exactly four inputs were supplied.
     _, i0, i1, _ = inputs
     out_grad = g_outputs[0]
     index_grads = [DisconnectedType()() for _ in (i0, i1)]
     return [out_grad, *index_grads, diagonal_subtensor(out_grad, i0, i1)]
Example #2
0
    def grad(self, inputs, gout):
        """Return the gradient terms for the inputs ``(x, repeats)``.

        Only a 0-d ``repeats`` is supported. In that case the output
        gradient is reshaped to the shape of ``x`` with one extra axis of
        length ``repeats`` inserted, and then summed over that extra axis.
        ``repeats`` itself receives a disconnected gradient.

        Raises
        ------
        NotImplementedError
            If ``repeats`` is 1-d.
        ValueError
            If ``repeats`` has any other number of dimensions.
        """
        x, repeats = inputs
        (gz,) = gout

        if repeats.ndim == 1:
            # For this implementation, we would need to specify the length
            # of repeats in order to split gz in the right way to sum
            # the good part.
            raise NotImplementedError(
                "The gradient is not implemented for 1-d repeats."
            )
        if repeats.ndim != 0:
            raise ValueError(
                f"repeats must have 0 or 1 dimensions, got {repeats.ndim}."
            )

        # Position of the extra axis (of length `repeats`) in the reshaped
        # gradient.
        if self.axis is None:
            axis = x.ndim
        elif self.axis >= 0:
            axis = self.axis + 1
        else:
            axis = self.axis + x.ndim + 1

        shape = [x.shape[k] for k in range(x.ndim)]
        shape.insert(axis, repeats)

        return [
            gz.reshape(shape, x.ndim + 1).sum(axis=axis),
            DisconnectedType()(),
        ]
Example #3
0
 def grad(self, inputs, output_grads):
     """Return the gradient terms for the two inputs.

     Part of the single output gradient is scaled by 0.5 and the result is
     passed to ``irfft_op`` together with ``inputs[1]``; the second input
     receives a disconnected gradient.
     """
     (out_grad,) = output_grads
     s = inputs[1]
     # Halve the entries at positions 1 up to (but excluding)
     # s[-1] // 2 + s[-1] % 2 along the second-to-last axis; the last axis
     # is taken whole. These entries are double-counted by the real-IFFT
     # due to symmetry, whereas the first element (and, for even
     # transforms, the last one) is unique.
     upper = (s[-1] // 2) + (s[-1] % 2)
     leading = [slice(None)] * (out_grad.ndim - 2)
     idx = leading + [slice(1, upper), slice(None)]
     out_grad = set_subtensor(out_grad[idx], out_grad[idx] * 0.5)
     return [irfft_op(out_grad, s), DisconnectedType()()]
Example #4
0
 def grad(self, inputs, output_grads):
     """Return the gradient terms for the two inputs.

     ``rfft_op`` is applied to the single output gradient together with
     ``inputs[1]``, and part of the result is scaled by 2; the second
     input receives a disconnected gradient.
     """
     (out_grad,) = output_grads
     s = inputs[1]
     freq_grad = rfft_op(out_grad, s)
     # Double the entries at positions 1 up to (but excluding)
     # s[-1] // 2 + s[-1] % 2 along the second-to-last axis; the last axis
     # is taken whole. These entries represent both positive and negative
     # frequencies, whereas the first element (and, for even transforms,
     # the last one) is unique.
     upper = (s[-1] // 2) + (s[-1] % 2)
     leading = [slice(None)] * (freq_grad.ndim - 2)
     idx = leading + [slice(1, upper), slice(None)]
     freq_grad = set_subtensor(freq_grad[idx], freq_grad[idx] * 2)
     return [freq_grad, DisconnectedType()()]
Example #5
0
def test_disconnected_cost_grad():
    # Checks that a cost declared disconnected through the known_grads
    # mechanism is treated as disconnected by the rest of the system.
    # Ops built around minigraphs (such as OpFromGraph and scan) rely on
    # this so they can implement Op.grad by passing ograds to known_grads.

    x = iscalar()
    y = iscalar()
    cost = x + y
    assert cost.dtype in discrete_dtypes

    disconnected_cost = {cost: DisconnectedType()()}
    try:
        grad(
            cost,
            [x, y],
            known_grads=disconnected_cost,
            disconnected_inputs="raise",
        )
    except DisconnectedInputError:
        # Expected outcome: the disconnection was detected and reported.
        pass
    else:
        raise AssertionError("A disconnected gradient has been ignored.")
 def grad(self, inputs, grads):
     """Return a fresh disconnected gradient for every entry of ``inputs``.

     ``grads`` is not used.
     """
     return [DisconnectedType()() for _ in inputs]
Example #7
0
 def grad(self, input, output_gradients):
     """Return ``output_gradients`` padded with disconnected gradients.

     The padding has ``len(input) - 1`` entries, all of which are the same
     disconnected variable object (it is created once and repeated).
     """
     padding = [DisconnectedType()()] * (len(input) - 1)
     return output_gradients + padding
Example #8
0
 def grad(self, inputs, g_outputs):
     """Return the gradient terms for the first three inputs.

     The term for ``inputs[0]`` is ``inc_diagonal_subtensor`` applied to a
     zero array shaped like ``inputs[0]``, the two following inputs, and
     the first output gradient. Those two following inputs each receive a
     disconnected gradient.
     """
     zeros = at.zeros_like(inputs[0])
     first_grad = inc_diagonal_subtensor(
         zeros, inputs[1], inputs[2], g_outputs[0]
     )
     return [first_grad] + [DisconnectedType()() for _ in range(2)]
Example #9
0
 def grad(self, inp, grads):
     """Return the gradient terms for the two inputs.

     The single output gradient is reshaped to the shape of the first
     input; the second input receives a disconnected gradient.
     """
     # Unpacking enforces exactly two inputs and exactly one output gradient.
     x, _ = inp
     (g_out,) = grads
     x_grad = reshape(g_out, shape(x), ndim=x.ndim)
     return [x_grad, DisconnectedType()()]
Example #10
0
 def grad(self, inputs, output_gradients):
     """Return a disconnected gradient followed by ``output_gradients``.

     ``inputs`` is not used.
     """
     leading = [DisconnectedType()()]
     return leading + output_gradients
Example #11
0
    def test_grad_override(self, cls_ofg):
        # Exercises grad_overrides given as a single callable, as an op
        # instance, and as a per-input list mixing callables, "default",
        # NullType and DisconnectedType entries.
        x, y = vectors("xy")

        def rand_vec():
            # 16 uniform random values in the configured float type.
            return np.random.rand(16).astype(config.floatX)

        def go(inps, gs):
            lhs, rhs = inps
            (g, ) = gs
            return [g * rhs * 2, g * lhs * 1.5]

        dedz = vector("dedz")
        op_mul_grad = cls_ofg([x, y, dedz], go([x, y], [dedz]))

        op_mul = cls_ofg([x, y], [x * y], grad_overrides=go)
        op_mul2 = cls_ofg([x, y], [x * y], grad_overrides=op_mul_grad)

        # A single override, either a plain function or an op instance.
        xx, yy = vector("xx"), vector("yy")
        for op in [op_mul, op_mul2]:
            total = tt_sum(op(xx, yy))
            dx, dy = grad(total, [xx, yy])
            fn = function([xx, yy], [dx, dy])
            xv = rand_vec()
            yv = rand_vec()
            dxv, dyv = fn(xv, yv)
            assert np.allclose(yv * 2, dxv)
            assert np.allclose(xv * 1.5, dyv)

        # A list of overrides, one entry per input.
        def go1(inps, gs):
            _, w, _ = inps
            return gs[0] * w * 2

        def go2(inps, gs):
            x, _, _ = inps
            return gs[0] * x * 1.5

        w, b = vectors("wb")
        # The third input keeps its default gradient (no override).
        op_linear = cls_ofg(
            [x, w, b],
            [x * w + b],
            grad_overrides=[go1, go2, "default"],
        )
        xx, ww, bb = vector("xx"), vector("yy"), vector("bb")
        total = tt_sum(op_linear(xx, ww, bb))
        dx, dw, db = grad(total, [xx, ww, bb])
        fn = function([xx, ww, bb], [dx, dw, db])
        xv = rand_vec()
        wv = rand_vec()
        bv = rand_vec()
        dxv, dwv, dbv = fn(xv, wv, bv)
        assert np.allclose(wv * 2, dxv)
        assert np.allclose(xv * 1.5, dwv)
        assert np.allclose(np.ones(16, dtype=config.floatX), dbv)

        # NullType and DisconnectedType variables as override entries.
        overrides = [go1, NullType()(), DisconnectedType()()]
        op_linear2 = cls_ofg([x, w, b], [x * w + b], grad_overrides=overrides)
        total2 = tt_sum(op_linear2(xx, ww, bb))
        dx2, dw2, db2 = grad(
            total2,
            [xx, ww, bb],
            return_disconnected="Disconnected",
            disconnected_inputs="ignore",
            null_gradients="return",
        )
        assert isinstance(dx2.type, TensorType)
        assert dx2.ndim == 1
        assert isinstance(dw2.type, NullType)
        assert isinstance(db2.type, DisconnectedType)