Example #1
  def sqrt_solve(self, x):
    """Computes `solve(self, x)`.

    Doesn't actually do the sqrt! Named as such to agree with API.

    To compute `inv(M + V D V.T) x`, we use the Woodbury matrix identity:
      inv(M + V D V.T) = inv(M) - inv(M) V inv(C) V.T inv(M)
    where,
      C = inv(D) + V.T inv(M) V.
    See: https://en.wikipedia.org/wiki/Woodbury_matrix_identity

    Args:
      x: `Tensor`

    Returns:
      inv_of_self_times_x: `Tensor`
    """
    minv_x = linalg_ops.matrix_triangular_solve(self._m, x)
    vt_minv_x = math_ops.matmul(self._v, minv_x, transpose_a=True)
    cinv_vt_minv_x = linalg_ops.matrix_solve(
        self._woodbury_sandwiched_term(), vt_minv_x)
    v_cinv_vt_minv_x = math_ops.matmul(self._v, cinv_vt_minv_x)
    minv_v_cinv_vt_minv_x = linalg_ops.matrix_triangular_solve(
        self._m, v_cinv_vt_minv_x)
    return minv_x - minv_v_cinv_vt_minv_x
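The Woodbury identity the docstring relies on is easy to check numerically. Below is a minimal NumPy sketch (not part of the original class; `m`, `v`, `d` are made-up stand-ins for `self._m`, `self._v`, and the diagonal `D`) verifying that the staged solve matches a direct solve against `M + V D V.T`:

import numpy as np

rng = np.random.RandomState(0)
m = np.tril(rng.randn(4, 4)) + 4. * np.eye(4)   # well-conditioned lower-triangular M
v = rng.randn(4, 2)
d = np.diag(rng.rand(2) + 1.)
x = rng.randn(4, 1)

direct = np.linalg.solve(m + v @ d @ v.T, x)

minv_x = np.linalg.solve(m, x)
c = np.linalg.inv(d) + v.T @ np.linalg.solve(m, v)        # the sandwiched term C
woodbury = minv_x - np.linalg.solve(m, v @ np.linalg.solve(c, v.T @ minv_x))

np.testing.assert_allclose(direct, woodbury, atol=1e-10)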
 def testWrongDimensions(self):
   randn = np.random.RandomState(0).randn
   for dtype in self.float_types:
     lhs = constant_op.constant(randn(3, 3), dtype=dtype)
     rhs = constant_op.constant(randn(4, 3), dtype=dtype)
     with self.assertRaises(ValueError):
       linalg_ops.matrix_triangular_solve(lhs, rhs)
     with self.assertRaises(ValueError):
       linalg_ops.matrix_triangular_solve(lhs, rhs)
 def testNonSquareCoefficientMatrix(self):
   rng = np.random.RandomState(0)
   for dtype in self.float_types:
     a = rng.randn(3, 4).astype(dtype)
     b = rng.randn(4, 4).astype(dtype)
     with self.assertRaises(ValueError):
       linalg_ops.matrix_triangular_solve(a, b)
     with self.assertRaises(ValueError):
       linalg_ops.matrix_triangular_solve(a, b)
  def _verifySolve(self,
                   x,
                   y,
                   lower=True,
                   adjoint=False,
                   batch_dims=None,
                   use_gpu=False):
    for np_type in [np.float32, np.float64]:
      a = x.astype(np_type)
      b = y.astype(np_type)
      # For numpy.solve we have to explicitly zero out the strictly
      # upper or lower triangle.
      if lower and a.size > 0:
        a_np = np.tril(a)
      elif a.size > 0:
        a_np = np.triu(a)
      else:
        a_np = a
      if adjoint:
        a_np = np.conj(np.transpose(a_np))

      if batch_dims is not None:
        a = np.tile(a, batch_dims + [1, 1])
        a_np = np.tile(a_np, batch_dims + [1, 1])
        b = np.tile(b, batch_dims + [1, 1])

      with self.test_session(use_gpu=use_gpu):
        tf_ans = linalg_ops.matrix_triangular_solve(
            a, b, lower=lower, adjoint=adjoint)
        out = tf_ans.eval()
        np_ans = np.linalg.solve(a_np, b)
        self.assertEqual(np_ans.shape, tf_ans.get_shape())
        self.assertEqual(np_ans.shape, out.shape)
        self.assertAllClose(np_ans, out)
Example #5
def matrix_triangular_solve_with_broadcast(matrix,
                                           rhs,
                                           lower=True,
                                           adjoint=False,
                                           name=None):
  """Solves triangular systems of linear equations with by backsubstitution.

  Works identically to `tf.matrix_triangular_solve`, but broadcasts batch dims
  of `matrix` and `rhs` (by replicating) if they are determined statically to be
  different, or if static shapes are not fully defined.  Thus, this may result
  in an inefficient replication of data.

  Args:
    matrix: A Tensor. Must be one of the following types:
      `float64`, `float32`, `complex64`, `complex128`. Shape is `[..., M, M]`.
    rhs: A `Tensor`. Must have the same `dtype` as `matrix`.
      Shape is `[..., M, K]`.
    lower: An optional `bool`. Defaults to `True`. Indicates whether the
      innermost matrices in `matrix` are lower or upper triangular.
    adjoint: An optional `bool`. Defaults to `False`. Indicates whether to solve
      with matrix or its (block-wise) adjoint.
    name: A name for the operation (optional).

  Returns:
    `Tensor` with same `dtype` as `matrix` and shape `[..., M, K]`.
  """
  with ops.name_scope(name, "MatrixTriangularSolve", [matrix, rhs]):
    matrix, rhs = broadcast_matrix_batch_dims([matrix, rhs])
    return linalg_ops.matrix_triangular_solve(
        matrix,
        rhs,
        lower=lower,
        adjoint=adjoint)
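A hedged usage sketch of the broadcasting behaviour described in the docstring, mirroring the tests further down; it assumes the `linear_operator_util` and `linalg_ops` modules imported elsewhere in these examples, and NumPy inputs that TensorFlow converts to float64 tensors:

import numpy as np

rng = np.random.RandomState(0)
matrix = np.tril(rng.rand(2, 3, 3)) + 3. * np.eye(3)   # batch_shape = [2]
rhs = rng.rand(3, 7)                                   # no batch dims

# rhs is replicated across the batch of `matrix`; the result has shape [2, 3, 7].
result = linear_operator_util.matrix_triangular_solve_with_broadcast(matrix, rhs)

# Equivalent to broadcasting by hand before calling the core op.
expected = linalg_ops.matrix_triangular_solve(matrix, rhs + np.zeros((2, 1, 1)))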
  def _verifySolve(self,
                   x,
                   y,
                   lower=True,
                   adjoint=False,
                   batch_dims=None,
                   use_placeholder=False,
                   dtypes=(np.float32, np.float64)):
    for np_type in dtypes:
      a = x.astype(np_type)
      b = y.astype(np_type)
      # For numpy.solve we have to explicitly zero out the strictly
      # upper or lower triangle.
      if lower and a.size > 0:
        a_np = np.tril(a)
      elif a.size > 0:
        a_np = np.triu(a)
      else:
        a_np = a
      if adjoint:
        a_np = np.conj(np.transpose(a_np))

      if batch_dims is not None:
        a = np.tile(a, batch_dims + [1, 1])
        a_np = np.tile(a_np, batch_dims + [1, 1])
        b = np.tile(b, batch_dims + [1, 1])

      with self.test_session(use_gpu=True) as sess:
        if use_placeholder:
          a_tf = array_ops.placeholder(a.dtype)
          b_tf = array_ops.placeholder(b.dtype)
          tf_ans = linalg_ops.matrix_triangular_solve(
              a_tf, b_tf, lower=lower, adjoint=adjoint)
          tf_val = sess.run(tf_ans, feed_dict={a_tf: a, b_tf: b})
          np_ans = np.linalg.solve(a_np, b)
        else:
          a_tf = constant_op.constant(a)
          b_tf = constant_op.constant(b)
          tf_ans = linalg_ops.matrix_triangular_solve(
              a_tf, b_tf, lower=lower, adjoint=adjoint)
          tf_val = tf_ans.eval()
          np_ans = np.linalg.solve(a_np, b)
          self.assertEqual(np_ans.shape, tf_ans.get_shape())
        self.assertEqual(np_ans.shape, tf_val.shape)
        self.assertAllClose(np_ans, tf_val)
Example #7
def TriAngInvCompositeGrad(l, grad):
  num_rows = array_ops.shape(l)[-1]
  batch_shape = array_ops.shape(l)[:-2]
  l_inverse = linalg_ops.matrix_triangular_solve(l,
                                                 linalg_ops.eye(
                                                     num_rows,
                                                     batch_shape=batch_shape,
                                                     dtype=l.dtype))
  return _GradWithInverseL(l, l_inverse, grad)
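For reference, the `l_inverse` trick above is just "solve against the identity"; a tiny NumPy sketch (illustrative only):

import numpy as np

l = np.tril(np.random.RandomState(0).randn(3, 3)) + 3. * np.eye(3)
l_inv = np.linalg.solve(l, np.eye(3))       # same role as matrix_triangular_solve(l, eye)
np.testing.assert_allclose(l @ l_inv, np.eye(3), atol=1e-12)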
Example #8
def TriAngSolveCompositeGrad(l, grad):
  # Gradient is l^{-H} @ ((l^{H} @ grad) * (tril(ones)-1/2*eye)) @ l^{-1}

  # Compute ((l^{H} @ grad) * (tril(ones)-1/2*eye)) = middle
  middle = math_ops.matmul(l, grad, adjoint_a=True)
  middle = array_ops.matrix_set_diag(middle,
                                     0.5 * array_ops.matrix_diag_part(middle))
  middle = array_ops.matrix_band_part(middle, -1, 0)

  # Compute l^{-H} @ middle = z
  l_inverse_middle = linalg_ops.matrix_triangular_solve(l, middle, adjoint=True)

  # We need to compute z @ l^{-1}. With matrix_triangular_solve we
  # actually compute l^{-H} @ z^{H} = grad_a. Since we later add grad_a^{H},
  # we can omit the conjugate transpose here.
  z_h = math_ops.conj(array_ops.matrix_transpose(l_inverse_middle))
  grad_a = linalg_ops.matrix_triangular_solve(l, z_h, adjoint=True)
  grad_a += linalg.adjoint(grad_a)
  return grad_a * 0.5
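The staged computation above avoids forming `l^{-1}` explicitly. As a sanity check, the following NumPy sketch (real-valued, so adjoint is plain transpose; not part of the original module) compares it against the direct formula from the comment:

import numpy as np

rng = np.random.RandomState(0)
l = np.tril(rng.randn(4, 4)) + 4. * np.eye(4)
grad = rng.randn(4, 4)

mask = np.tril(np.ones((4, 4))) - 0.5 * np.eye(4)   # tril(ones) - 1/2*eye
middle = (l.T @ grad) * mask

# Direct formula: 0.5 * sym(l^{-H} @ middle @ l^{-1}).
l_inv = np.linalg.inv(l)
direct = l_inv.T @ middle @ l_inv
direct = 0.5 * (direct + direct.T)

# Staged version, as in TriAngSolveCompositeGrad.
z = np.linalg.solve(l.T, middle)            # l^{-H} @ middle
grad_a = np.linalg.solve(l.T, z.T)          # l^{-H} @ z^{H}
grad_a = 0.5 * (grad_a + grad_a.T)

np.testing.assert_allclose(direct, grad_a, atol=1e-10)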
Example #9
def mvn_tril_log_prob(loc, scale_tril, x):
  """Computes the MVN log pdf under tril scale. Doesn't handle batches."""
  x0 = x - loc
  z = linalg_ops.matrix_triangular_solve(
      scale_tril, x0[..., array_ops.newaxis])[..., 0]
  log_det_cov = 2. * math_ops.reduce_sum(math_ops.log(
      array_ops.matrix_diag_part(scale_tril)), axis=-1)
  d = math_ops.cast(array_ops.shape(scale_tril)[-1], log_det_cov.dtype)
  return -0.5 * (math_ops.reduce_sum(math_ops.square(z), axis=-1)
                 + d * np.log(2. * np.pi) + log_det_cov)
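A minimal cross-check of the formula above against SciPy (SciPy is an assumption here and is not used elsewhere in these examples):

import numpy as np
from scipy.stats import multivariate_normal

rng = np.random.RandomState(0)
loc = rng.randn(3)
scale_tril = np.tril(rng.randn(3, 3)) + 3. * np.eye(3)
x = rng.randn(3)

z = np.linalg.solve(scale_tril, x - loc)
log_det_cov = 2. * np.sum(np.log(np.diag(scale_tril)))
log_prob = -0.5 * (z @ z + 3. * np.log(2. * np.pi) + log_det_cov)

cov = scale_tril @ scale_tril.T
np.testing.assert_allclose(log_prob, multivariate_normal(loc, cov).logpdf(x))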
  def test_static_dims_broadcast(self):
    # batch_shape = [2]
    matrix = rng.rand(2, 3, 3)
    rhs = rng.rand(3, 7)
    rhs_broadcast = rhs + np.zeros((2, 1, 1))

    with self.cached_session():
      result = linear_operator_util.matrix_triangular_solve_with_broadcast(
          matrix, rhs)
      self.assertAllEqual((2, 3, 7), result.get_shape())
      expected = linalg_ops.matrix_triangular_solve(matrix, rhs_broadcast)
      self.assertAllEqual(expected.eval(), result.eval())
 def _VerifyTriangularSolve(self, a, b, lower, adjoint, atol):
   clean_a = np.tril(a) if lower else np.triu(a)
   with self.test_session() as sess:
     placeholder_a = MakePlaceholder(a)
     placeholder_ca = MakePlaceholder(clean_a)
     placeholder_b = MakePlaceholder(b)
     with self.test_scope():
       x = linalg_ops.matrix_triangular_solve(
           placeholder_a, placeholder_b, lower=lower, adjoint=adjoint)
     verification = math_ops.matmul(placeholder_ca, x, adjoint_a=adjoint)
     self._VerifyTriangularSolveBase(sess, placeholder_a, placeholder_ca,
                                     placeholder_b, a, clean_a, b,
                                     verification, atol)
def matrix_triangular_solve_with_broadcast(matrix,
                                           rhs,
                                           lower=True,
                                           adjoint=False,
                                           name=None):
  """Solves triangular systems of linear equations with by backsubstitution.

  Works identically to `tf.matrix_triangular_solve`, but broadcasts batch dims
  of `matrix` and `rhs` (by replicating) if they are determined statically to be
  different, or if static shapes are not fully defined.  Thus, this may result
  in an inefficient replication of data.

  Args:
    matrix: A Tensor. Must be one of the following types:
      `float64`, `float32`, `complex64`, `complex128`. Shape is `[..., M, M]`.
    rhs: A `Tensor`. Must have the same `dtype` as `matrix`.
      Shape is `[..., M, K]`.
    lower: An optional `bool`. Defaults to `True`. Indicates whether the
      innermost matrices in `matrix` are lower or upper triangular.
    adjoint: An optional `bool`. Defaults to `False`. Indicates whether to solve
      with matrix or its (block-wise) adjoint.
    name: A name for the operation (optional).

  Returns:
    `Tensor` with same `dtype` as `matrix` and shape `[..., M, K]`.
  """
  with ops.name_scope(name, "MatrixTriangularSolve", [matrix, rhs]):
    matrix = ops.convert_to_tensor(matrix, name="matrix")
    rhs = ops.convert_to_tensor(rhs, name="rhs", dtype=matrix.dtype)

    # If either matrix/rhs has extra dims, we can reshape to get rid of them.
    matrix, rhs, reshape_inv, still_need_to_transpose = _reshape_for_efficiency(
        matrix, rhs, adjoint_a=adjoint)

    # lower indicates whether the matrix is lower triangular. If we have
    # manually taken adjoint inside _reshape_for_efficiency, it is now upper tri
    if not still_need_to_transpose and adjoint:
      lower = not lower

    # This will broadcast by brute force if we still need to.
    matrix, rhs = broadcast_matrix_batch_dims([matrix, rhs])

    solution = linalg_ops.matrix_triangular_solve(
        matrix,
        rhs,
        lower=lower,
        adjoint=adjoint and still_need_to_transpose)

    return reshape_inv(solution)
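The `lower = not lower` flip above relies on a simple fact worth spelling out: the adjoint of a lower-triangular matrix is upper-triangular. A one-line NumPy reminder (illustrative only):

import numpy as np

l = np.tril(np.random.RandomState(0).randn(3, 3))
assert np.allclose(np.triu(l.T), l.T)   # transpose of lower-triangular is upper-triangular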
Example #13
  def _woodbury_sandwiched_term(self):
    """Computes the sandwiched term in the Woodbury identity.

    Computes the "`C`" in the identity:
       inv(M + V D V.T) = inv(M) - inv(M) V inv(C) V.T inv(M)
    where,
       C = inv(D) + V.T inv(M) V.

    See: https://en.wikipedia.org/wiki/Woodbury_matrix_identity

    Returns:
      woodbury_sandwich_term: A `Tensor` to be used like `C`, above.
    """
    minv_v = linalg_ops.matrix_triangular_solve(self._m, self._v)
    vt_minv_v = math_ops.matmul(self._v, minv_v, adjoint_a=True)
    return self._d_inv.add_to_tensor(vt_minv_v)
Example #14
def _MatrixTriangularSolveGrad(op, grad):
    """Gradient for MatrixTriangularSolve."""
    a = op.inputs[0]
    adjoint_a = op.get_attr("adjoint")
    lower_a = op.get_attr("lower")
    c = op.outputs[0]
    grad_b = linalg_ops.matrix_triangular_solve(a, grad, lower=lower_a, adjoint=not adjoint_a)
    if adjoint_a:
        grad_a = -math_ops.batch_matmul(c, grad_b, adj_y=True)
    else:
        grad_a = -math_ops.batch_matmul(grad_b, c, adj_y=True)
    if lower_a:
        grad_a = array_ops.matrix_band_part(grad_a, -1, 0)
    else:
        grad_a = array_ops.matrix_band_part(grad_a, 0, -1)
    return (grad_a, grad_b)
  def test_static_dims_broadcast_rhs_has_extra_dims(self):
    # Since the second arg has extra dims, and the domain dim of the first arg
    # is larger than the number of linear equations, code will "flip" the extra
    # dims of the first arg to the far right, making extra linear equations
    # (then call the matrix function, then flip back).
    # We have verified that this optimization indeed happens.  How? We stepped
    # through with a debugger.
    # batch_shape = [2]
    matrix = rng.rand(3, 3)
    rhs = rng.rand(2, 3, 2)
    matrix_broadcast = matrix + np.zeros((2, 1, 1))

    with self.cached_session():
      result = linear_operator_util.matrix_triangular_solve_with_broadcast(
          matrix, rhs)
      self.assertAllEqual((2, 3, 2), result.get_shape())
      expected = linalg_ops.matrix_triangular_solve(matrix_broadcast, rhs)
      self.assertAllClose(expected.eval(), result.eval())
Example #16
    def test_static_dims_broadcast_rhs_has_extra_dims(self):
        # Since the second arg has extra dims, and the domain dim of the first arg
        # is larger than the number of linear equations, code will "flip" the extra
        # dims of the first arg to the far right, making extra linear equations
        # (then call the matrix function, then flip back).
        # We have verified that this optimization indeed happens.  How? We stepped
        # through with a debugger.
        # batch_shape = [2]
        matrix = rng.rand(3, 3)
        rhs = rng.rand(2, 3, 2)
        matrix_broadcast = matrix + np.zeros((2, 1, 1))

        with self.cached_session():
            result = linear_operator_util.matrix_triangular_solve_with_broadcast(
                matrix, rhs)
            self.assertAllEqual((2, 3, 2), result.get_shape())
            expected = linalg_ops.matrix_triangular_solve(
                matrix_broadcast, rhs)
            self.assertAllClose(expected.eval(), self.evaluate(result))
Example #17
def _MatrixTriangularSolveGrad(op, grad):
    """Gradient for MatrixTriangularSolve."""
    a = op.inputs[0]
    adjoint_a = op.get_attr("adjoint")
    lower_a = op.get_attr("lower")
    c = op.outputs[0]
    grad_b = linalg_ops.matrix_triangular_solve(a,
                                                grad,
                                                lower=lower_a,
                                                adjoint=not adjoint_a)
    if adjoint_a:
        grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)
    else:
        grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)
    if lower_a:
        grad_a = array_ops.matrix_band_part(grad_a, -1, 0)
    else:
        grad_a = array_ops.matrix_band_part(grad_a, 0, -1)
    return (grad_a, grad_b)
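A finite-difference sanity check of the analytic gradients above, written as a NumPy sketch for the `lower=True`, `adjoint=False` case with the scalar loss `sum(c)` (the helper names here are made up for illustration):

import numpy as np

rng = np.random.RandomState(0)
a = np.tril(rng.randn(3, 3)) + 3. * np.eye(3)
b = rng.randn(3, 2)
grad = np.ones((3, 2))                       # d(loss)/d(c) for loss = sum(c)

c = np.linalg.solve(a, b)
grad_b = np.linalg.solve(a.T, grad)          # a^{-H} @ grad
grad_a = np.tril(-grad_b @ c.T)              # band_part(-grad_b @ c^H, -1, 0)

def loss(a_, b_):
  return np.sum(np.linalg.solve(np.tril(a_), b_))

eps = 1e-6
num_grad_a = np.zeros_like(a)
for i in range(3):
  for j in range(3):
    da = np.zeros_like(a)
    da[i, j] = eps
    num_grad_a[i, j] = (loss(a + da, b) - loss(a - da, b)) / (2. * eps)

np.testing.assert_allclose(grad_a, num_grad_a, atol=1e-5)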
    def test_dynamic_dims_broadcast_64bit(self):
        # batch_shape = [2]
        matrix = rng.rand(2, 3, 3)
        rhs = rng.rand(3, 7)
        rhs_broadcast = rhs + np.zeros((2, 1, 1))

        matrix_ph = array_ops.placeholder(dtypes.float64)
        rhs_ph = array_ops.placeholder(dtypes.float64)

        with self.cached_session() as sess:
            result, expected = sess.run([
                linear_operator_util.matrix_triangular_solve_with_broadcast(
                    matrix_ph, rhs_ph),
                linalg_ops.matrix_triangular_solve(matrix, rhs_broadcast)
            ],
                                        feed_dict={
                                            matrix_ph: matrix,
                                            rhs_ph: rhs,
                                        })
            self.assertAllClose(expected, result)
Example #19
def _CholeskyGrad(op, grad):
  """Gradient for Cholesky."""

  # Gradient is l^{-H} @ ((l^{H} @ grad) * (tril(ones)-1/2*eye)) @ l^{-1}
  l = op.outputs[0]
  num_rows = array_ops.shape(l)[-1]
  batch_shape = array_ops.shape(l)[:-2]
  l_inverse = linalg_ops.matrix_triangular_solve(
      l, linalg_ops.eye(num_rows, batch_shape=batch_shape, dtype=l.dtype))

  middle = math_ops.matmul(l, grad, adjoint_a=True)
  middle = array_ops.matrix_set_diag(middle,
                                     0.5 * array_ops.matrix_diag_part(middle))
  middle = array_ops.matrix_band_part(middle, -1, 0)

  grad_a = math_ops.matmul(
      math_ops.matmul(l_inverse, middle, adjoint_a=True), l_inverse)

  grad_a += math_ops.conj(array_ops.matrix_transpose(grad_a))
  return grad_a * 0.5
Example #20
def _CholeskyGrad(op, grad):
    """Gradient for Cholesky."""

    # Gradient is l^{-H} @ ((l^{H} @ grad) * (tril(ones)-1/2*eye)) @ l^{-1}
    l = op.outputs[0]
    num_rows = array_ops.shape(l)[-1]
    batch_shape = array_ops.shape(l)[:-2]
    l_inverse = linalg_ops.matrix_triangular_solve(
        l, linalg_ops.eye(num_rows, batch_shape=batch_shape, dtype=l.dtype))

    middle = math_ops.matmul(l, grad, adjoint_a=True)
    middle = array_ops.matrix_set_diag(
        middle, 0.5 * array_ops.matrix_diag_part(middle))
    middle = array_ops.matrix_band_part(middle, -1, 0)

    grad_a = math_ops.matmul(
        math_ops.matmul(l_inverse, middle, adjoint_a=True), l_inverse)

    grad_a += _linalg.adjoint(grad_a)
    return grad_a * 0.5
Example #21
  def _define_full_covariance_probs(self, shard_id, shard):
    """Defines the full covariance probabilities per example in a class.

    Updates a matrix with dimension num_examples X num_classes.

    Args:
      shard_id: id of the current shard.
      shard: current data shard, 1 X num_examples X dimensions.
    """
    diff = shard - self._means
    cholesky = linalg_ops.cholesky(self._covs + self._min_var)
    log_det_covs = 2.0 * math_ops.reduce_sum(
        math_ops.log(array_ops.matrix_diag_part(cholesky)), 1)
    x_mu_cov = math_ops.square(
        linalg_ops.matrix_triangular_solve(
            cholesky, array_ops.transpose(
                diff, perm=[0, 2, 1]), lower=True))
    diag_m = array_ops.transpose(math_ops.reduce_sum(x_mu_cov, 1))
    self._probs[shard_id] = -0.5 * (diag_m + math_ops.to_float(self._dimensions)
                                    * math_ops.log(2 * np.pi) + log_det_covs)
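The `x_mu_cov` term above is the standard Cholesky route to the Mahalanobis distance: `diff^T inv(cov) diff == || solve(cholesky(cov), diff) ||^2`. A small NumPy sketch of that identity (illustrative only, not part of the original class):

import numpy as np

rng = np.random.RandomState(0)
a = rng.randn(3, 3)
cov = a @ a.T + 3. * np.eye(3)
diff = rng.randn(3)

chol = np.linalg.cholesky(cov)
direct = diff @ np.linalg.solve(cov, diff)
via_chol = np.sum(np.linalg.solve(chol, diff) ** 2)

np.testing.assert_allclose(direct, via_chol)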
  def test_dynamic_dims_broadcast_64bit(self):
    # batch_shape = [2]
    matrix = rng.rand(2, 3, 3)
    rhs = rng.rand(3, 7)
    rhs_broadcast = rhs + np.zeros((2, 1, 1))

    matrix_ph = array_ops.placeholder(dtypes.float64)
    rhs_ph = array_ops.placeholder(dtypes.float64)

    with self.cached_session() as sess:
      result, expected = sess.run(
          [
              linear_operator_util.matrix_triangular_solve_with_broadcast(
                  matrix_ph, rhs_ph),
              linalg_ops.matrix_triangular_solve(matrix, rhs_broadcast)
          ],
          feed_dict={
              matrix_ph: matrix,
              rhs_ph: rhs,
          })
      self.assertAllEqual(expected, result)
  def testNotInvertible(self):
    # The input should be invertible.
    # The matrix is singular because it has a zero on the diagonal.
    singular_matrix = np.array(
        [[[1., 0., 0.],
          [-1., 0., 0.],
          [0., -1., 1.]],
         [[1., 0., 0.],
          [-1., 1., 0.],
          [0., -1., 0.]],
         [[1., 0., 0.],
          [-1., 1., 0.],
          [0., -1., 1.]]])
    rhs = np.array([[3.], [5.], [1.]])

    expected = np.array([
        [[3.], [np.inf], [np.inf]],
        [[3.], [8.], [np.inf]],
        [[3.], [8.], [9.]]])

    with self.cached_session(use_gpu=False):
      ans = linalg_ops.matrix_triangular_solve(singular_matrix, rhs)
      self.assertAllClose(self.evaluate(ans), expected)
 def _forward(self, x):
   with ops.control_dependencies(self._assertions(x)):
     shape = array_ops.shape(x)
     return linalg_ops.matrix_triangular_solve(
         x, linalg_ops.eye(shape[-1], batch_shape=shape[:-2]), lower=True)
 def _batch_sqrt_solve(self, rhs):
   return linalg_ops.matrix_triangular_solve(self._chol, rhs, lower=True)
Example #26
 def _inverse(self, y):
   x = y - self.loc
   x, sample_shape = self.shaper.make_batch_of_event_sample_matrices(x)
   x = linalg_ops.matrix_triangular_solve(self.scale, x)
   x = self.shaper.undo_make_batch_of_event_sample_matrices(x, sample_shape)
   return x
Example #27
 def _batch_sqrt_solve(self, rhs):
   return linalg_ops.matrix_triangular_solve(self._chol, rhs, lower=True)
 def _solve(self, rhs, adjoint=False, adjoint_arg=False):
   rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
   return linalg_ops.matrix_triangular_solve(
       self._tril, rhs, lower=True, adjoint=adjoint)
Example #29
 def _inverse(self, x):
   x -= self.loc
   x, sample_shape = self.shaper.make_batch_of_event_sample_matrices(x)
   x = linalg_ops.matrix_triangular_solve(self.scale, x)
   x = self.shaper.undo_make_batch_of_event_sample_matrices(x, sample_shape)
   return x
Example #30
 def _TriangularSolve(x, r):
   """Equiv to matmul(x, adjoint(matrix_inverse(r))) if r is upper-tri."""
   return _linalg.adjoint(
       linalg_ops.matrix_triangular_solve(
           r, _linalg.adjoint(x), lower=False, adjoint=False))
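The docstring's equivalence can be checked directly in NumPy for real inputs (illustrative only):

import numpy as np

rng = np.random.RandomState(0)
r = np.triu(rng.randn(3, 3)) + 3. * np.eye(3)   # upper-triangular
x = rng.randn(2, 3)

via_solve = np.linalg.solve(r, x.T).T       # adjoint(triangular_solve(r, adjoint(x)))
via_inverse = x @ np.linalg.inv(r).T        # matmul(x, adjoint(matrix_inverse(r)))
np.testing.assert_allclose(via_solve, via_inverse, atol=1e-12)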
Example #31
 def _solve(self, rhs, adjoint=False):
   return linalg_ops.matrix_triangular_solve(
       self._tril, rhs, lower=True, adjoint=adjoint)
Example #32
 def _forward(self, x):
   with ops.control_dependencies(self._assertions(x)):
     shape = array_ops.shape(x)
     return linalg_ops.matrix_triangular_solve(
         x, linalg_ops.eye(shape[-1], batch_shape=shape[:-2]), lower=True)
Example #33
 def loop_fn(i):
   a = array_ops.gather(x, i) if stack_a else x
   b = array_ops.gather(y, i) if stack_b else y
   return linalg_ops.matrix_triangular_solve(a, b,
                                             lower=lower,
                                             adjoint=adjoint)
Example #34
 def _TriangularSolve(x, r):
   """Equiv to matmul(x, adjoint(matrix_inverse(r))) if r is upper-tri."""
   return _linalg.adjoint(
       linalg_ops.matrix_triangular_solve(
           r, _linalg.adjoint(x), lower=False, adjoint=False))
Example #35
 def _solve(self, rhs, adjoint=False, adjoint_arg=False):
     rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
     return linalg_ops.matrix_triangular_solve(self._tril,
                                               rhs,
                                               lower=True,
                                               adjoint=adjoint)
Example #36
 def loop_fn(i):
     a = array_ops.gather(x, i) if stack_a else x
     b = array_ops.gather(y, i) if stack_b else y
     return linalg_ops.matrix_triangular_solve(
         a, b, lower=lower, adjoint=adjoint)
Example #37
def TriAngInvCompositeGrad(l, grad):
    num_rows = array_ops.shape(l)[-1]
    batch_shape = array_ops.shape(l)[:-2]
    l_inverse = linalg_ops.matrix_triangular_solve(
        l, linalg_ops.eye(num_rows, batch_shape=batch_shape, dtype=l.dtype))
    return _GradWithInverseL(l, l_inverse, grad)