# Example 1
def _BandedTriangularSolveGrad(op, grad):
    """Gradient for BandedTriangularSolve."""
    matrix, rhs = op.inputs[0], op.inputs[1]
    solution = op.outputs[0]
    bands = array_ops.shape(matrix)[-2]
    adjoint = op.get_attr("adjoint")
    lower = op.get_attr("lower")
    # Backprop through the solve by solving against the adjoint system.
    grad_rhs = linalg_ops.banded_triangular_solve(matrix,
                                                  grad,
                                                  lower=lower,
                                                  adjoint=not adjoint)
    if adjoint:
        grad_matrix = -math_ops.matmul(solution, grad_rhs, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
    else:
        grad_matrix = -math_ops.matmul(grad_rhs, solution, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
    # Keep only the bands that are actually stored in the banded input.
    band_range = (-(bands - 1), 0) if lower else (0, bands - 1)
    grad_matrix = array_ops.matrix_diag_part(grad_matrix,
                                             k=band_range,
                                             align="LEFT_RIGHT")
    # If the static batch shapes are equal, we don't need to unbroadcast.
    if (matrix.shape.is_fully_defined() and rhs.shape.is_fully_defined()
            and matrix.shape[:-2] == rhs.shape[:-2]):
        return grad_matrix, grad_rhs
    # Otherwise sum each gradient over its broadcast batch dimensions.
    matrix_shape = array_ops.shape(matrix)
    rhs_shape = array_ops.shape(rhs)
    ra, rb = array_ops.broadcast_gradient_args(matrix_shape[:-2],
                                               rhs_shape[:-2])
    grad_matrix = array_ops.reshape(
        math_ops.reduce_sum(grad_matrix, axis=ra), matrix_shape)
    grad_rhs = array_ops.reshape(
        math_ops.reduce_sum(grad_rhs, axis=rb), rhs_shape)
    return grad_matrix, grad_rhs
# Example 2
    def _verifySolve(self,
                     x,
                     y,
                     lower=True,
                     adjoint=False,
                     batch_dims=None,
                     use_placeholder=False,
                     dtypes=(np.float32, np.float64)):
        """Compare banded_triangular_solve with np.linalg.solve per dtype."""
        for np_type in dtypes:
            a = x.astype(np_type)
            b = y.astype(np_type)

            def expand_banded(diags, lower=True):
                # Expand the banded storage into a dense triangular matrix.
                n = len(diags[0])
                dense = np.zeros(n * n, dtype=diags.dtype)
                if lower:
                    for i, diag in enumerate(diags):
                        dense[n * i:n * n:n + 1] = diag[i:]
                else:
                    for i, diag in enumerate(np.flip(diags, 0)):
                        dense[i:(n - i) * n:n + 1] = diag[:(n - i)]
                return dense.reshape(n, n)

            # numpy's solver needs a dense matrix, with the strictly upper
            # or lower triangle explicitly zeroed out.
            a_np = expand_banded(a, lower=lower) if a.size > 0 else a
            if adjoint:
                a_np = np.conj(np.transpose(a_np))

            if batch_dims is not None:
                a = np.tile(a, batch_dims + [1, 1])
                a_np = np.tile(a_np, batch_dims + [1, 1])
                b = np.tile(b, batch_dims + [1, 1])

            with self.cached_session():
                a_tf, b_tf = a, b
                if use_placeholder:
                    a_tf = array_ops.placeholder_with_default(a_tf, shape=None)
                    b_tf = array_ops.placeholder_with_default(b_tf, shape=None)
                tf_ans = linalg_ops.banded_triangular_solve(a_tf,
                                                            b_tf,
                                                            lower=lower,
                                                            adjoint=adjoint)
                tf_val = self.evaluate(tf_ans)
                np_ans = np.linalg.solve(a_np, b)
                self.assertEqual(np_ans.shape, tf_val.shape)
                self.assertAllClose(np_ans, tf_val)