Example No. 1
    def L_op(self, inputs, outputs, gradients):
        """
        Cholesky decomposition reverse-mode gradient update.

        Symbolic expression for reverse-mode Cholesky gradient taken from [#]_

        References
        ----------
        .. [#] I. Murray, "Differentiation of the Cholesky decomposition",
           http://arxiv.org/abs/1602.07527

        """

        dz = gradients[0]
        chol_x = outputs[0]

        # Replace the Cholesky decomposition with 1 if there are NaNs,
        # or solve_upper_triangular will throw a ValueError.
        if self.on_error == "nan":
            ok = ~atm.any(atm.isnan(chol_x))
            chol_x = at.switch(ok, chol_x, 1)
            dz = at.switch(ok, dz, 1)

        # deal with upper triangular by converting to lower triangular
        if not self.lower:
            chol_x = chol_x.T
            dz = dz.T

        def tril_and_halve_diagonal(mtx):
            """Extracts lower triangle of square matrix and halves diagonal."""
            return at.tril(mtx) - at.diag(at.diagonal(mtx) / 2.0)

        def conjugate_solve_triangular(outer, inner):
            """Computes L^{-T} P L^{-1} for lower-triangular L."""
            return solve_upper_triangular(
                outer.T, solve_upper_triangular(outer.T, inner.T).T
            )

        s = conjugate_solve_triangular(
            chol_x, tril_and_halve_diagonal(chol_x.T.dot(dz))
        )

        if self.lower:
            grad = at.tril(s + s.T) - at.diag(at.diagonal(s))
        else:
            grad = at.triu(s + s.T) - at.diag(at.diagonal(s))

        if self.on_error == "nan":
            return [at.switch(ok, grad, np.nan)]
        else:
            return [grad]
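For intuition, here is a minimal NumPy/SciPy sketch of the same update outside the symbolic graph (`cholesky_grad_numpy` and the finite-difference check are illustrative additions, not part of the library). It verifies the directional derivative of the cost `sum(dz * cholesky(a))` along a symmetric perturbation:

import numpy as np
from scipy.linalg import solve_triangular

def cholesky_grad_numpy(L, dz):
    """Mirrors the symbolic update above for lower-triangular L."""
    def phi(m):
        # Lower triangle with the diagonal halved, as in
        # tril_and_halve_diagonal.
        return np.tril(m) - np.diag(np.diagonal(m) / 2.0)

    # s = L^{-T} phi(L^T dz) L^{-1}, via two upper-triangular solves.
    p = phi(L.T @ dz)
    s = solve_triangular(L.T, solve_triangular(L.T, p.T, lower=False).T,
                         lower=False)
    return np.tril(s + s.T) - np.diag(np.diagonal(s))

rng = np.random.default_rng(0)
a = np.eye(4) + 0.1 * rng.standard_normal((4, 4))
a = a @ a.T                                # symmetric positive definite
L = np.linalg.cholesky(a)
dz = np.tril(rng.standard_normal((4, 4)))  # upstream gradient w.r.t. L
grad = cholesky_grad_numpy(L, dz)

# Central finite difference along a symmetric direction e.
def cost(m):
    return np.sum(dz * np.linalg.cholesky(m))

e = rng.standard_normal((4, 4))
e = (e + e.T) / 2.0
eps = 1e-6
fd = (cost(a + eps * e) - cost(a - eps * e)) / (2.0 * eps)
assert np.isclose(np.sum(grad * e), fd, rtol=1e-4)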
Example No. 2
    def L_op(self, inputs, outputs, gradients):
        # Modified from aesara/tensor/slinalg.py
        # No handling for on_error = 'nan'
        dz = gradients[0]
        chol_x = outputs[0]

        # NaN handling (`on_error == "nan"`) from the original
        # implementation, kept here for reference:
        #
        # ok = ~tm.any(tm.isnan(chol_x))
        # chol_x = aet.switch(ok, chol_x, 1)
        # dz = aet.switch(ok, dz, 1)

        # deal with upper triangular by converting to lower triangular
        if not self.lower:
            chol_x = chol_x.T
            dz = dz.T

        def tril_and_halve_diagonal(mtx):
            """Extracts lower triangle of square matrix and halves diagonal."""
            return aet.tril(mtx) - aet.diag(aet.diagonal(mtx) / 2.0)

        def conjugate_solve_triangular(outer, inner):
            """Computes L^{-T} P L^{-1} for lower-triangular L."""
            return gpu_solve_upper_triangular(
                outer.T,
                gpu_solve_upper_triangular(outer.T, inner.T).T)

        s = conjugate_solve_triangular(
            chol_x, tril_and_halve_diagonal(chol_x.T.dot(dz)))

        if self.lower:
            grad = aet.tril(s + s.T) - aet.diag(aet.diagonal(s))
        else:
            grad = aet.triu(s + s.T) - aet.diag(aet.diagonal(s))

        return [grad]
Example No. 3
    def grad(self, inp, cost_grad):
        """
        Notes
        -----
        The gradient is currently implemented for matrices only.

        """
        a, val = inp
        grad = cost_grad[0]
        if a.dtype.startswith("complex"):
            return [None, None]
        elif a.ndim > 2:
            raise NotImplementedError("%s: gradient is currently implemented"
                                      " for matrices only" %
                                      self.__class__.__name__)
        wr_a = fill_diagonal(grad, 0)  # valid for any number of dimensions
        # diag is only valid for matrices
        wr_val = aet.diag(grad).sum()
        return [wr_a, wr_val]
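As a quick sanity check, a NumPy sketch of the same rule (illustrative, not library code): `fill_diagonal(a, val)` overwrites a's diagonal, so the cost is independent of a's diagonal entries (`wr_a` is the upstream gradient with its diagonal zeroed), while `val` is written to every diagonal position (`wr_val` is the sum of the upstream gradient's diagonal):

import numpy as np

g = np.arange(9.0).reshape(3, 3)  # upstream gradient d cost / d output

wr_a = g.copy()
np.fill_diagonal(wr_a, 0.0)       # d cost / d a: the diagonal contributes nothing
wr_val = np.trace(g)              # d cost / d val: 0.0 + 4.0 + 8.0 = 12.0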
Example No. 4
import aesara.tensor as aet

def tril_and_halve_diagonal(mtx):
    """Extracts lower triangle of square matrix and halves diagonal."""
    return aet.tril(mtx) - aet.diag(aet.diagonal(mtx) / 2.0)
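The NumPy analogue makes the effect concrete (input values are just np.arange output, chosen for illustration):

import numpy as np

m = np.arange(1.0, 10.0).reshape(3, 3)
np.tril(m) - np.diag(np.diagonal(m) / 2.0)
# array([[0.5, 0. , 0. ],
#        [4. , 2.5, 0. ],
#        [7. , 8. , 4.5]])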
Example No. 5
def test_jax_basic():
    rng = np.random.default_rng(28494)

    x = matrix("x")
    y = matrix("y")
    b = vector("b")

    # `ScalarOp`
    z = cosh(x**2 + y / 3.0)

    # `[Inc]Subtensor`
    out = aet_subtensor.set_subtensor(z[0], -10.0)
    out = aet_subtensor.inc_subtensor(out[0, 1], 2.0)
    out = out[:5, :3]

    out_fg = FunctionGraph([x, y], [out])

    test_input_vals = [
        np.tile(np.arange(10), (10, 1)).astype(config.floatX),
        np.tile(np.arange(10, 20), (10, 1)).astype(config.floatX),
    ]
    (jax_res,) = compare_jax_and_py(out_fg, test_input_vals)

    # Confirm that the `Subtensor` slice operations are correct
    assert jax_res.shape == (5, 3)

    # Confirm that the `IncSubtensor` operations are correct
    assert jax_res[0, 0] == -10.0
    assert jax_res[0, 1] == -8.0

    out = clip(x, y, 5)
    out_fg = FunctionGraph([x, y], [out])
    compare_jax_and_py(out_fg, test_input_vals)

    out = aet.diagonal(x, 0)
    out_fg = FunctionGraph([x], [out])
    compare_jax_and_py(
        out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)])

    out = aet_slinalg.cholesky(x)
    out_fg = FunctionGraph([x], [out])
    compare_jax_and_py(
        out_fg,
        [(np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
            config.floatX)],
    )

    # not sure why this isn't working yet with lower=False
    out = aet_slinalg.Cholesky(lower=False)(x)
    out_fg = FunctionGraph([x], [out])
    compare_jax_and_py(
        out_fg,
        [(np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
            config.floatX)],
    )

    out = aet_slinalg.solve(x, b)
    out_fg = FunctionGraph([x, b], [out])
    compare_jax_and_py(
        out_fg,
        [
            np.eye(10).astype(config.floatX),
            np.arange(10).astype(config.floatX),
        ],
    )

    out = aet.diag(b)
    out_fg = FunctionGraph([b], [out])
    compare_jax_and_py(out_fg, [np.arange(10).astype(config.floatX)])

    out = aet_nlinalg.det(x)
    out_fg = FunctionGraph([x], [out])
    compare_jax_and_py(
        out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)])

    out = aet_nlinalg.matrix_inverse(x)
    out_fg = FunctionGraph([x], [out])
    compare_jax_and_py(
        out_fg,
        [(np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
            config.floatX)],
    )
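For readers without the test harness, a minimal sketch of what `compare_jax_and_py` exercises, assuming an Aesara installation with JAX available (`mode="JAX"` selects Aesara's JAX linker; the variable names here are illustrative):

import numpy as np
import aesara
import aesara.tensor as aet

x = aet.matrix("x")
out = aet.diagonal(x, 0)

f_jax = aesara.function([x], out, mode="JAX")  # compile via the JAX backend
f_py = aesara.function([x], out)               # reference backend

val = np.arange(16.0).reshape((4, 4)).astype(aesara.config.floatX)
assert np.allclose(f_jax(val), f_py(val))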