def _BandedTriangularSolveGrad(op, grad):
  """Gradient for BandedTriangularSolve.

  Args:
    op: the BandedTriangularSolve op whose inputs are the banded matrix
      `a` (bands stacked along axis -2) and right-hand side `b`.
    grad: gradient of the loss with respect to the op's output.

  Returns:
    A pair `(grad_a, grad_b)` of gradients with respect to the op's two
    inputs, with batch dimensions unbroadcast back to the input shapes.
  """
  banded, rhs = op.inputs[0], op.inputs[1]
  num_bands = array_ops.shape(banded)[-2]
  adjoint_a = op.get_attr("adjoint")
  lower_a = op.get_attr("lower")
  solution = op.outputs[0]

  # Backprop through the solve: solve the adjoint system against `grad`.
  grad_b = linalg_ops.banded_triangular_solve(
      banded, grad, lower=lower_a, adjoint=not adjoint_a)

  # pylint: disable=invalid-unary-operand-type
  if adjoint_a:
    grad_a = -math_ops.matmul(solution, grad_b, adjoint_b=True)
  else:
    grad_a = -math_ops.matmul(grad_b, solution, adjoint_b=True)
  # pylint: enable=invalid-unary-operand-type

  # Only the diagonals actually stored in the band representation carry
  # gradient; extract them from the dense outer product.
  band_range = (-(num_bands - 1), 0) if lower_a else (0, num_bands - 1)
  grad_a = array_ops.matrix_diag_part(grad_a, k=band_range,
                                      align="LEFT_RIGHT")

  # If the static batch shapes are equal, we don't need to unbroadcast.
  if (banded.shape.is_fully_defined() and rhs.shape.is_fully_defined() and
      banded.shape[:-2] == rhs.shape[:-2]):
    return grad_a, grad_b

  # Otherwise sum the gradients over the broadcast batch dimensions and
  # reshape back to each input's dynamic shape.
  a_shape = array_ops.shape(banded)
  b_shape = array_ops.shape(rhs)
  ra, rb = array_ops.broadcast_gradient_args(a_shape[:-2], b_shape[:-2])
  grad_a = array_ops.reshape(math_ops.reduce_sum(grad_a, axis=ra), a_shape)
  grad_b = array_ops.reshape(math_ops.reduce_sum(grad_b, axis=rb), b_shape)
  return grad_a, grad_b
def _verifySolve(self,
                 x,
                 y,
                 lower=True,
                 adjoint=False,
                 batch_dims=None,
                 use_placeholder=False,
                 dtypes=(np.float32, np.float64)):
  """Checks banded_triangular_solve(x, y) against numpy.linalg.solve.

  Args:
    x: banded matrix as stacked diagonals (bands along axis 0).
    y: right-hand side.
    lower: whether the banded matrix represents a lower-triangular matrix.
    adjoint: whether to solve against the conjugate transpose.
    batch_dims: optional list of leading batch dimensions to tile onto
      the inputs before solving.
    use_placeholder: feed inputs through placeholder_with_default with
      unknown static shape.
    dtypes: numpy dtypes to run the comparison under.
  """

  def _dense_from_bands(bands, lower=True):
    """Expands stacked diagonals into a dense n x n triangular matrix."""
    n = len(bands[0])
    dense = np.zeros(n * n, dtype=bands.dtype)
    if lower:
      for i, band in enumerate(bands):
        dense[n * i:n * n:n + 1] = band[i:]
    else:
      for i, band in enumerate(np.flip(bands, 0)):
        dense[i:(n - i) * n:n + 1] = band[:(n - i)]
    return dense.reshape(n, n)

  for np_type in dtypes:
    a = x.astype(np_type)
    b = y.astype(np_type)
    # numpy.linalg.solve needs a dense matrix whose strictly upper or
    # lower triangle is explicitly zero, so expand the bands first.
    a_np = _dense_from_bands(a, lower=lower) if a.size > 0 else a
    if adjoint:
      a_np = np.conj(np.transpose(a_np))

    if batch_dims is not None:
      tiling = batch_dims + [1, 1]
      a = np.tile(a, tiling)
      a_np = np.tile(a_np, tiling)
      b = np.tile(b, tiling)

    with self.cached_session():
      a_tf, b_tf = a, b
      if use_placeholder:
        a_tf = array_ops.placeholder_with_default(a_tf, shape=None)
        b_tf = array_ops.placeholder_with_default(b_tf, shape=None)
      tf_ans = linalg_ops.banded_triangular_solve(
          a_tf, b_tf, lower=lower, adjoint=adjoint)
      tf_val = self.evaluate(tf_ans)
      np_ans = np.linalg.solve(a_np, b)
      self.assertEqual(np_ans.shape, tf_val.shape)
      self.assertAllClose(np_ans, tf_val)