Example #1
    def L_op(self, inputs, outputs, output_gradients):
        r"""
        Reverse-mode gradient updates for matrix solve operation c = A \ b.

        Symbolic expression for updates taken from [#]_.

        References
        ----------
        .. [#] M. B. Giles, "An extended collection of matrix derivative results
          for forward and reverse mode automatic differentiation",
          http://eprints.maths.ox.ac.uk/1079/

        """
        A, b = inputs
        c = outputs[0]
        c_bar = output_gradients[0]
        trans_map = {
            "lower_triangular": "upper_triangular",
            "upper_triangular": "lower_triangular",
        }
        trans_solve_op = Solve(
            # update A_structure and lower to account for a transpose operation
            A_structure=trans_map.get(self.A_structure, self.A_structure),
            lower=not self.lower,
        )
        b_bar = trans_solve_op(A.T, c_bar)
        # force outer product if vector second input
        A_bar = -tm.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T)
        if self.A_structure == "lower_triangular":
            A_bar = aet.tril(A_bar)
        elif self.A_structure == "upper_triangular":
            A_bar = aet.triu(A_bar)
        return [A_bar, b_bar]
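The code above implements Giles' reverse-mode identities for c = A \ b: b_bar = A^{-T} c_bar and A_bar = -b_bar c^T. As a standalone sanity check (a NumPy sketch assuming a scalar loss c_bar^T c, not part of the Op itself), the identities can be compared against central finite differences:

import numpy as np

rng = np.random.default_rng(0)
n = 4
A = rng.normal(size=(n, n)) + n * np.eye(n)   # well-conditioned test matrix
b = rng.normal(size=n)
c_bar = rng.normal(size=n)                    # incoming gradient w.r.t. c

c = np.linalg.solve(A, b)

def loss(A, b):
    return c_bar @ np.linalg.solve(A, b)

# Giles' reverse-mode identities: b_bar = A^{-T} c_bar, A_bar = -b_bar c^T
b_bar = np.linalg.solve(A.T, c_bar)
A_bar = -np.outer(b_bar, c)

eps = 1e-6
fd_b = np.array([(loss(A, b + eps * e) - loss(A, b - eps * e)) / (2 * eps)
                 for e in np.eye(n)])
fd_A = np.array([[(loss(A + eps * np.outer(ei, ej), b)
                   - loss(A - eps * np.outer(ei, ej), b)) / (2 * eps)
                  for ej in np.eye(n)]
                 for ei in np.eye(n)])

np.testing.assert_allclose(fd_b, b_bar, atol=1e-5)
np.testing.assert_allclose(fd_A, A_bar, atol=1e-5)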
Example #2
    def L_op(self, inputs, outputs, output_gradients):
        res = super().L_op(inputs, outputs, output_gradients)

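        # a triangular solve only reads one triangle of A, so zero the
        # gradient entries outside it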
        if self.lower:
            res[0] = at.tril(res[0])
        else:
            res[0] = at.triu(res[0])

        return res
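This example restricts the dense gradient from the parent class to one triangle: a triangular solve never reads the entries of A outside that triangle, so their gradients must be zero. A small standalone sketch of that fact, using SciPy's solve_triangular purely for illustration:

import numpy as np
from scipy.linalg import solve_triangular

rng = np.random.default_rng(1)
L = np.tril(rng.normal(size=(3, 3))) + 3 * np.eye(3)
b = rng.normal(size=3)

x = solve_triangular(L, b, lower=True)

L_perturbed = L.copy()
L_perturbed[0, 2] += 10.0                      # entry strictly above the diagonal
x_perturbed = solve_triangular(L_perturbed, b, lower=True)

np.testing.assert_allclose(x, x_perturbed)     # the upper triangle is never read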
Example #3
    def L_op(self, inputs, outputs, gradients):
        """
        Cholesky decomposition reverse-mode gradient update.

        Symbolic expression for reverse-mode Cholesky gradient taken from [#]_.

        References
        ----------
        .. [#] I. Murray, "Differentiation of the Cholesky decomposition",
           http://arxiv.org/abs/1602.07527

        """

        dz = gradients[0]
        chol_x = outputs[0]

        # Replace the cholesky decomposition with 1 if there are nans
        # or solve_upper_triangular will throw a ValueError.
        if self.on_error == "nan":
            ok = ~atm.any(atm.isnan(chol_x))
            chol_x = at.switch(ok, chol_x, 1)
            dz = at.switch(ok, dz, 1)

        # deal with upper triangular by converting to lower triangular
        if not self.lower:
            chol_x = chol_x.T
            dz = dz.T

        def tril_and_halve_diagonal(mtx):
            """Extracts lower triangle of square matrix and halves diagonal."""
            return at.tril(mtx) - at.diag(at.diagonal(mtx) / 2.0)

        def conjugate_solve_triangular(outer, inner):
            """Computes L^{-T} P L^{-1} for lower-triangular L."""
            return solve_upper_triangular(
                outer.T, solve_upper_triangular(outer.T, inner.T).T
            )

        s = conjugate_solve_triangular(
            chol_x, tril_and_halve_diagonal(chol_x.T.dot(dz))
        )

        if self.lower:
            grad = at.tril(s + s.T) - at.diag(at.diagonal(s))
        else:
            grad = at.triu(s + s.T) - at.diag(at.diagonal(s))

        if self.on_error == "nan":
            return [at.switch(ok, grad, np.nan)]
        else:
            return [grad]
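As a standalone check of the expression above (a NumPy/SciPy sketch, not part of the Op): parameterize the symmetric input by its lower triangle, apply Murray's reverse-mode formula exactly as in the code, and compare with central finite differences of a loss that is linear in chol(Sigma):

import numpy as np
from scipy.linalg import cholesky, solve_triangular

rng = np.random.default_rng(0)
n = 4
M = np.tril(rng.normal(size=(n, n))) + 2 * n * np.eye(n)  # lower-triangular parameters
dz = np.tril(rng.normal(size=(n, n)))                      # incoming gradient w.r.t. chol(Sigma)

def sym(M):
    # Build the symmetric Sigma from its lower triangle.
    return M + M.T - np.diag(np.diag(M))

def loss(M):
    return np.sum(dz * cholesky(sym(M), lower=True))

L = cholesky(sym(M), lower=True)

# P = Phi(L^T dz): lower triangle with the diagonal halved
C = L.T @ dz
P = np.tril(C) - np.diag(np.diag(C)) / 2.0

# S = L^{-T} P L^{-1}, then symmetrize exactly as in the Op
S = solve_triangular(L.T, solve_triangular(L.T, P.T, lower=False).T, lower=False)
grad = np.tril(S + S.T) - np.diag(np.diag(S))

eps = 1e-6
fd = np.zeros((n, n))
for i in range(n):
    for j in range(i + 1):
        E = np.zeros((n, n))
        E[i, j] = eps
        fd[i, j] = (loss(M + E) - loss(M - E)) / (2 * eps)

np.testing.assert_allclose(fd, grad, atol=1e-4)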
Example #4
    def L_op(self, inputs, outputs, output_gradients):
        # Modified from aesara/tensor/slinalg.py
        A, b = inputs
        c = outputs[0]
        c_bar = output_gradients[0]

        trans_solve_op = GpuCublasTriangularSolve(not self.lower)
        b_bar = trans_solve_op(A.T, c_bar)

        A_bar = -tm.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T)

        if self.lower:
            A_bar = aet.tril(A_bar)
        else:
            A_bar = aet.triu(A_bar)
        return [A_bar, b_bar]
Example #5
    def L_op(self, inputs, outputs, gradients):
        # Modified from aesara/tensor/slinalg.py
        # No handling for on_error = 'nan'
        dz = gradients[0]
        chol_x = outputs[0]

        # this is for nan mode
        #
        # ok = ~tm.any(tm.isnan(chol_x))
        # chol_x = aet.switch(ok, chol_x, 1)
        # dz = aet.switch(ok, dz, 1)

        # deal with upper triangular by converting to lower triangular
        if not self.lower:
            chol_x = chol_x.T
            dz = dz.T

        def tril_and_halve_diagonal(mtx):
            """Extracts lower triangle of square matrix and halves diagonal."""
            return aet.tril(mtx) - aet.diag(aet.diagonal(mtx) / 2.0)

        def conjugate_solve_triangular(outer, inner):
            """Computes L^{-T} P L^{-1} for lower-triangular L."""
            return gpu_solve_upper_triangular(
                outer.T,
                gpu_solve_upper_triangular(outer.T, inner.T).T)

        s = conjugate_solve_triangular(
            chol_x, tril_and_halve_diagonal(chol_x.T.dot(dz)))

        if self.lower:
            grad = aet.tril(s + s.T) - aet.diag(aet.diagonal(s))
        else:
            grad = aet.triu(s + s.T) - aet.diag(aet.diagonal(s))

        return [grad]