def _verifySolve(self, x, y, batch_dims=None):
  """Compares `linalg_ops.matrix_solve` with `np.linalg.solve`.

  For each dtype in (float32, float64, complex64, complex128) and for both
  values of `adjoint`, solves the system with TensorFlow -- once with constant
  inputs and once through placeholders -- and checks the result's shape and
  values against the NumPy reference.

  Args:
    x: NumPy array used as the coefficient matrix.
    y: NumPy array used as the right-hand side.
    batch_dims: Optional list of leading batch dimensions; when given, the
      matrix and right-hand side are tiled with `batch_dims + [1, 1]`.
  """
  for np_type in [np.float32, np.float64, np.complex64, np.complex128]:
    # Single-precision types get a looser tolerance.
    if np_type in (np.float32, np.complex64):
      tol = 1e-5
    else:
      tol = 1e-12
    for adjoint in False, True:
      # Membership test, not identity: `np_type is [...]` compares against a
      # freshly built list and is therefore always False.
      if np_type in (np.float32, np.float64):
        # `.real` is an ndarray attribute, not a method; dropping the
        # imaginary part explicitly avoids a lossy complex->real cast.
        a = x.real.astype(np_type)
        b = y.real.astype(np_type)
      else:
        a = x.astype(np_type)
        b = y.astype(np_type)
      # Bound for every dtype (conj is a no-op for real dtypes).
      a_np = np.conj(np.transpose(a)) if adjoint else a
      if batch_dims is not None:
        a = np.tile(a, batch_dims + [1, 1])
        a_np = np.tile(a_np, batch_dims + [1, 1])
        b = np.tile(b, batch_dims + [1, 1])
      np_ans = np.linalg.solve(a_np, b)
      for use_placeholder in False, True:
        with self.test_session(use_gpu=True) as sess:
          if use_placeholder:
            a_ph = array_ops.placeholder(dtypes.as_dtype(np_type))
            b_ph = array_ops.placeholder(dtypes.as_dtype(np_type))
            tf_ans = linalg_ops.matrix_solve(a_ph, b_ph, adjoint=adjoint)
            out = sess.run(tf_ans, {a_ph: a, b_ph: b})
          else:
            tf_ans = linalg_ops.matrix_solve(a, b, adjoint=adjoint)
            out = tf_ans.eval()
            # The placeholders above carry no shape, so the static shape is
            # only checked for constant inputs.
            self.assertEqual(tf_ans.get_shape(), out.shape)
          self.assertEqual(np_ans.shape, out.shape)
          self.assertAllClose(np_ans, out, atol=tol, rtol=tol)
 def _verifySolve(self, x, y, batch_dims=None):
   """Compares XLA `matrix_solve` with `np.linalg.solve` for real dtypes.

   Iterates over the float32/float64 types present in `self.float_types`,
   both `adjoint` settings, and constant vs. placeholder inputs, checking the
   result's shape and values against NumPy.

   Args:
     x: NumPy array used as the coefficient matrix.
     y: NumPy array used as the right-hand side.
     batch_dims: Optional list of leading batch dimensions; when given, the
       matrix and right-hand side are tiled with `batch_dims + [1, 1]`.
   """
   real_types = self.float_types & {np.float32, np.float64}
   for np_type in real_types:
     tol = 1e-4 if np_type == np.float32 else 1e-12
     tf_dtype = dtypes.as_dtype(np_type)
     for adjoint in False, True:
       a = x.astype(np_type)
       b = y.astype(np_type)
       # For real dtypes the adjoint is just the transpose.
       a_np = np.transpose(a) if adjoint else a
       if batch_dims is not None:
         reps = batch_dims + [1, 1]
         a, a_np, b = np.tile(a, reps), np.tile(a_np, reps), np.tile(b, reps)
       np_ans = np.linalg.solve(a_np, b)
       for use_placeholder in False, True:
         with self.session() as sess:
           if not use_placeholder:
             with self.test_scope():
               tf_ans = linalg_ops.matrix_solve(a, b, adjoint=adjoint)
             out = sess.run(tf_ans)
             # Static shape is only known for constant inputs.
             self.assertEqual(tf_ans.get_shape(), out.shape)
           else:
             with self.test_scope():
               a_ph = array_ops.placeholder(tf_dtype)
               b_ph = array_ops.placeholder(tf_dtype)
               tf_ans = linalg_ops.matrix_solve(a_ph, b_ph, adjoint=adjoint)
             out = sess.run(tf_ans, {a_ph: a, b_ph: b})
           self.assertEqual(np_ans.shape, out.shape)
           self.assertAllClose(np_ans, out, atol=tol, rtol=tol)
 def testNonSquareMatrix(self):
   """Solving against a non-square coefficient matrix must raise ValueError."""
   with self.test_session(), self.assertRaises(ValueError):
     non_square = constant_op.constant([[1., 2., 3.], [3., 4., 5.]])
     linalg_ops.matrix_solve(non_square, non_square)
Example #4
0
 def testWrongDimensions(self):
     """An rhs whose row count differs from the matrix must raise ValueError."""
     with self.test_session():
         # 2x2 matrix against a right-hand side with a single row.
         identity = constant_op.constant([[1., 0.], [0., 1.]])
         one_row_rhs = constant_op.constant([[1., 0.]])
         with self.assertRaises(ValueError):
             linalg_ops.matrix_solve(identity, one_row_rhs)
 def testWrongDimensions(self):
   """An rhs whose row count differs from the matrix must raise ValueError."""
   with self.test_session():
     # 2x2 matrix against a right-hand side with a single row.
     identity = constant_op.constant([[1., 0.], [0., 1.]])
     one_row_rhs = constant_op.constant([[1., 0.]])
     with self.assertRaises(ValueError):
       linalg_ops.matrix_solve(identity, one_row_rhs)
 def _verifySolve(self, x, y, batch_dims=None):
   """Compares `linalg_ops.matrix_solve` with `np.linalg.solve`.

   For each dtype in (float32, float64, complex64, complex128) and for both
   values of `adjoint`, solves the system with TensorFlow -- once with constant
   inputs and once through placeholders -- and checks the result's shape and
   values against the NumPy reference.

   Args:
     x: NumPy array used as the coefficient matrix.
     y: NumPy array used as the right-hand side.
     batch_dims: Optional list of leading batch dimensions; when given, the
       matrix and right-hand side are tiled with `batch_dims + [1, 1]`.
   """
   for np_type in [np.float32, np.float64, np.complex64, np.complex128]:
     # Single-precision types get a looser tolerance.
     if np_type in (np.float32, np.complex64):
       tol = 1e-5
     else:
       tol = 1e-12
     for adjoint in False, True:
       # Membership test, not identity: `np_type is [...]` compares against a
       # freshly built list and is therefore always False.
       if np_type in (np.float32, np.float64):
         # `.real` is an ndarray attribute, not a method; dropping the
         # imaginary part explicitly avoids a lossy complex->real cast.
         a = x.real.astype(np_type)
         b = y.real.astype(np_type)
       else:
         a = x.astype(np_type)
         b = y.astype(np_type)
       # Bound for every dtype (conj is a no-op for real dtypes).
       a_np = np.conj(np.transpose(a)) if adjoint else a
       if batch_dims is not None:
         a = np.tile(a, batch_dims + [1, 1])
         a_np = np.tile(a_np, batch_dims + [1, 1])
         b = np.tile(b, batch_dims + [1, 1])
       np_ans = np.linalg.solve(a_np, b)
       for use_placeholder in False, True:
         with self.test_session(use_gpu=True) as sess:
           if use_placeholder:
             a_ph = array_ops.placeholder(dtypes.as_dtype(np_type))
             b_ph = array_ops.placeholder(dtypes.as_dtype(np_type))
             tf_ans = linalg_ops.matrix_solve(a_ph, b_ph, adjoint=adjoint)
             out = sess.run(tf_ans, {a_ph: a, b_ph: b})
           else:
             tf_ans = linalg_ops.matrix_solve(a, b, adjoint=adjoint)
             out = tf_ans.eval()
             # The placeholders above carry no shape, so the static shape is
             # only checked for constant inputs.
             self.assertEqual(tf_ans.get_shape(), out.shape)
           self.assertEqual(np_ans.shape, out.shape)
           self.assertAllClose(np_ans, out, atol=tol, rtol=tol)
Example #7
0
 def testNonSquareMatrix(self):
     """Solving against a non-square coefficient matrix must raise ValueError."""
     with self.test_session(), self.assertRaises(ValueError):
         non_square = constant_op.constant([[1., 2., 3.], [3., 4., 5.]])
         linalg_ops.matrix_solve(non_square, non_square)
 def testNotInvertible(self):
   """A singular coefficient matrix must make the solve op fail at run time."""
   # Every row sums to zero, so the matrix is singular.
   singular_rows = [[1., 0., -1.], [-1., 1., 0.], [0., -1., 1.]]
   with self.test_session():
     with self.assertRaisesOpError("Input matrix is not invertible."):
       singular = constant_op.constant(singular_rows)
       linalg_ops.matrix_solve(singular, singular).eval()
Example #9
0
 def testNotInvertible(self):
     """A singular coefficient matrix must make the solve op fail at run time."""
     # Every row sums to zero, so the matrix is singular.
     singular_rows = [[1., 0., -1.], [-1., 1., 0.], [0., -1., 1.]]
     with self.test_session():
         with self.assertRaisesOpError("Input matrix is not invertible."):
             singular = constant_op.constant(singular_rows)
             linalg_ops.matrix_solve(singular, singular).eval()
 def testConcurrent(self, adjoint):
   """Two identically seeded solves must yield bit-identical results."""
   with self.session() as sess:
     # Four normals with the same op seed: lhs1, lhs2, rhs1, rhs2.
     lhs1, lhs2, rhs1, rhs2 = [
         random_ops.random_normal([3, 3], seed=42) for _ in range(4)]
     with self.test_scope():
       s1 = linalg_ops.matrix_solve(lhs1, rhs1, adjoint=adjoint)
       s2 = linalg_ops.matrix_solve(lhs2, rhs2, adjoint=adjoint)
     first, second = sess.run([s1, s2])
     self.assertAllEqual(first, second)
Example #11
0
    def testWrongDimensions(self):
        """matrix_solve must reject mismatched row counts and batch shapes."""
        expected_errors = (ValueError, errors_impl.InvalidArgumentError)

        # Row counts differ: a 2x2 matrix against a 1x2 right-hand side.
        identity = constant_op.constant([[1., 0.], [0., 1.]])
        one_row_rhs = constant_op.constant([[1., 0.]])
        with self.assertRaises(expected_errors):
            self.evaluate(linalg_ops.matrix_solve(identity, one_row_rhs))

        # Batch dimensions differ: (2, 6) for the matrix vs (2, 3) for the rhs.
        batched_matrix = np.random.normal(size=(2, 6, 2, 2))
        batched_rhs = np.random.normal(size=(2, 3, 2, 2))
        with self.assertRaises(expected_errors):
            self.evaluate(linalg_ops.matrix_solve(batched_matrix, batched_rhs))
Example #12
0
def _EigGrad(op, grad_e, grad_v):
    """Gradient for Eig.

    Based on eq. 4.77 from the paper by Christoph Boeddeker et al.,
    https://arxiv.org/abs/1701.00392. See also "Computation of eigenvalue and
    eigenvector derivatives for a general complex-valued eigensystem" by
    Nico van der Aa. For now only the distinct-eigenvalue case is considered.

    Args:
      op: The forward Eig op. `op.outputs[0]` holds the eigenvalues; when the
        op's `compute_v` attribute is set, `op.outputs[1]` holds the
        eigenvectors.
      grad_e: Upstream gradient with respect to the eigenvalues.
      grad_v: Upstream gradient with respect to the eigenvectors. It enters
        the computation only when `compute_v` is set; otherwise it is used
        solely as a control dependency.

    Returns:
      The gradient with respect to `op.inputs[0]`, cast to that input's dtype.
    """
    e = op.outputs[0]
    compute_v = op.get_attr("compute_v")
    # a = op.inputs[0], which satisfies
    # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,i]
    with ops.control_dependencies([grad_e, grad_v]):
        if compute_v:
            v = op.outputs[1]
            vt = _linalg.adjoint(v)
            # Construct the matrix f(i,j) = (i != j ? 1 / (e_j - e_i) : 0):
            # expand_dims(e, -2)[i, j] = e_j and expand_dims(e, -1)[i, j] = e_i,
            # and the diagonal is then overwritten with zeros.
            # Because of the term involving f, the gradient is expected to blow
            # up when eigenvalues are not unique (how far depends on how
            # _SafeReciprocal, defined elsewhere, treats near-zero differences).
            # Mathematically this should not be surprising, since for (k-fold)
            # degenerate eigenvalues, the corresponding eigenvectors are only defined
            # up to arbitrary rotation in a (k-dimensional) subspace.
            f = array_ops.matrix_set_diag(
                _SafeReciprocal(
                    array_ops.expand_dims(e, -2) -
                    array_ops.expand_dims(e, -1)), array_ops.zeros_like(e))
            # The update below uses the complex conjugate of f.
            f = math_ops.conj(f)
            # vgv = v^H grad_v.
            vgv = math_ops.matmul(vt, grad_v)
            mid = array_ops.matrix_diag(grad_e)
            # Diagonal matrix holding the real part of diag(vgv), cast back to
            # vgv's dtype.
            diag_grad_part = array_ops.matrix_diag(
                array_ops.matrix_diag_part(
                    math_ops.cast(math_ops.real(vgv), vgv.dtype)))
            mid += f * (
                vgv - math_ops.matmul(math_ops.matmul(vt, v), diag_grad_part))
            # vt is formally invertible as long as the original matrix is
            # diagonalizable. In practice, however, vt may be ill-conditioned
            # when the original matrix is close to a non-diagonalizable one.
            # grad_a = inv(vt) @ mid @ vt.
            grad_a = linalg_ops.matrix_solve(vt, math_ops.matmul(mid, vt))
        else:
            # The forward op did not return eigenvectors, so recompute them.
            _, v = linalg_ops.eig(op.inputs[0])
            vt = _linalg.adjoint(v)
            # vt is formally invertible as long as the original matrix is
            # diagonalizable. In practice, however, vt may be ill-conditioned
            # when the original matrix is close to a non-diagonalizable one.
            # grad_a = inv(vt) @ diag(grad_e) @ vt.
            grad_a = linalg_ops.matrix_solve(
                vt, math_ops.matmul(array_ops.matrix_diag(grad_e), vt))
        return math_ops.cast(grad_a, op.inputs[0].dtype)
 def testConcurrent(self):
   """Identically seeded solves agree pairwise for both adjoint settings."""
   with self.session(use_gpu=True) as sess:
     solves = []
     for adjoint_ in (False, True):
       # Four normals with the same op seed: lhs1, lhs2, rhs1, rhs2.
       lhs1, lhs2, rhs1, rhs2 = [
           random_ops.random_normal([3, 3], seed=42) for _ in range(4)]
       solves.extend([
           linalg_ops.matrix_solve(lhs1, rhs1, adjoint=adjoint_),
           linalg_ops.matrix_solve(lhs2, rhs2, adjoint=adjoint_),
       ])
     results = sess.run(solves)
     # Results come in (s1, s2) pairs, one pair per adjoint setting.
     for i in range(0, len(results), 2):
       self.assertAllEqual(results[i], results[i + 1])
 def testConcurrent(self):
   """Identically seeded solves agree pairwise for both adjoint settings."""
   with self.test_session(use_gpu=True) as sess:
     solves = []
     for adjoint_ in (False, True):
       # Four normals with the same op seed: lhs1, lhs2, rhs1, rhs2.
       lhs1, lhs2, rhs1, rhs2 = [
           random_ops.random_normal([3, 3], seed=42) for _ in range(4)]
       solves.extend([
           linalg_ops.matrix_solve(lhs1, rhs1, adjoint=adjoint_),
           linalg_ops.matrix_solve(lhs2, rhs2, adjoint=adjoint_),
       ])
     results = sess.run(solves)
     # Results come in (s1, s2) pairs, one pair per adjoint setting.
     for i in range(0, len(results), 2):
       self.assertAllEqual(results[i], results[i + 1])
Example #15
0
    def benchmarkMatrixSolveOp(self):
        """Benchmarks matrix_solve on CPU and, when available, on GPU.

        For every adjoint setting, every shape in `self.matrix_shapes` and every
        right-hand-side count (1, 2 and the matrix's last dimension), a fresh
        graph pinned to each target device is built and timed.
        """
        # (device spec, benchmark-name template) per target device.
        targets = [("/cpu:0",
                    "matrix_solve_cpu_shape_{matrix_shape}_num_rhs_"
                    "{num_rhs}_adjoint_{adjoint}")]
        if test.is_gpu_available(True):
            targets.append(("/gpu:0",
                            "matrix_solve_gpu_shape_{matrix_shape}_num_rhs_"
                            "{num_rhs}_adjoint_{adjoint}"))
        for adjoint in False, True:
            for matrix_shape in self.matrix_shapes:
                for num_rhs in 1, 2, matrix_shape[-1]:
                    for device_spec, name_template in targets:
                        with ops.Graph().as_default(), \
                            session.Session(config=benchmark.benchmark_config()) as sess, \
                            ops.device(device_spec):
                            matrix, rhs = self._GenerateTestData(
                                matrix_shape, num_rhs)
                            x = linalg_ops.matrix_solve(matrix,
                                                        rhs,
                                                        adjoint=adjoint)
                            self.evaluate(
                                variables.global_variables_initializer())
                            self.run_op_benchmark(
                                sess,
                                control_flow_ops.group(x),
                                min_iters=25,
                                store_memory_usage=False,
                                name=name_template.format(
                                    matrix_shape=matrix_shape,
                                    num_rhs=num_rhs,
                                    adjoint=adjoint))
  def test_broadcast_matmul_and_solve(self):
    """LinearOperatorDiag matmul/solve broadcast like an explicit batch matrix.

    These shapes cannot be covered by the automated base-class tests, because
    tf.matmul itself does not broadcast.
    """
    with self.test_session() as sess:
      x = random_ops.random_normal(shape=(2, 2, 3, 4))

      # Operator batch shape (2, 1) is broadcast to (2, 2) when combined with x.
      diag = random_ops.random_uniform(shape=(2, 1, 3))
      operator = linalg.LinearOperatorDiag(diag, is_self_adjoint=True)
      self.assertAllEqual((2, 1, 3, 3), operator.shape)

      # Dense counterpart that already has the broadcast batch shape (2, 2).
      mat = array_ops.matrix_diag(array_ops.concat((diag, diag), 1))
      self.assertAllEqual((2, 2, 3, 3), mat.get_shape())

      def _assert_same(op_result, dense_result):
        # Static shapes must match, and so must the evaluated values.
        self.assertAllEqual(op_result.get_shape(), dense_result.get_shape())
        self.assertAllClose(*sess.run([op_result, dense_result]))

      _assert_same(operator.matmul(x), math_ops.matmul(mat, x))
      _assert_same(operator.solve(x), linalg_ops.matrix_solve(mat, x))
Example #17
0
def _MatrixSolveGrad(op, grad):
  """Gradient for MatrixSolve.

  The forward op computes `c` with `a @ c = b`, or `adjoint(a) @ c = b` when
  its `adjoint` attribute is set. Then:
    grad_b = solve(adjoint of the forward system matrix, grad)
    grad_a = -grad_b @ c^H            (adjoint=False)
    grad_a = -c @ grad_b^H            (adjoint=True)
  Conjugate transposes (not plain transposes) are used, so the result is also
  correct for complex inputs.

  Args:
    op: The forward MatrixSolve op.
    grad: Upstream gradient with respect to `op.outputs[0]`.

  Returns:
    A `(grad_a, grad_b)` tuple of gradients for the op's two inputs.
  """
  a = op.inputs[0]
  # Honor the forward op's adjoint attribute instead of assuming it is False.
  adjoint_a = op.get_attr("adjoint")
  c = op.outputs[0]
  grad_b = linalg_ops.matrix_solve(a, grad, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)
  else:
    grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)
  return (grad_a, grad_b)
  def _verifySolve(self, x, y, batch_dims=None):
    """Compares `linalg_ops.matrix_solve` with `np.linalg.solve`.

    Runs every combination of `adjoint` in (False, True) and dtype in
    (float32, float64, complex64, complex128), checking the static shape, the
    evaluated shape and the values against NumPy.

    Args:
      x: NumPy array used as the coefficient matrix.
      y: NumPy array used as the right-hand side.
      batch_dims: Optional list of leading batch dimensions; when given, the
        matrix and right-hand side are tiled with `batch_dims + [1, 1]`.
    """
    for adjoint in False, True:
      for np_type in [np.float32, np.float64, np.complex64, np.complex128]:
        # Membership test, not identity: `np_type is [...]` compares against
        # a freshly built list and is therefore always False.
        if np_type in (np.float32, np.float64):
          # `.real` is an ndarray attribute, not a method.
          a = x.real.astype(np_type)
          b = y.real.astype(np_type)
        else:
          a = x.astype(np_type)
          b = y.astype(np_type)
        if adjoint:
          a_np = np.conj(np.transpose(a))
        else:
          a_np = a
        if batch_dims is not None:
          a = np.tile(a, batch_dims + [1, 1])
          a_np = np.tile(a_np, batch_dims + [1, 1])
          b = np.tile(b, batch_dims + [1, 1])

        np_ans = np.linalg.solve(a_np, b)
        with self.test_session():
          tf_ans = linalg_ops.matrix_solve(a, b, adjoint=adjoint)
          out = tf_ans.eval()
          self.assertEqual(tf_ans.get_shape(), out.shape)
          self.assertEqual(np_ans.shape, out.shape)
          self.assertAllClose(np_ans, out)
Example #19
0
    def sqrt_solve(self, x):
        """Computes `solve(self, x)`.

        Despite its name, no square root is taken; the name matches the API.

        The inverse of (M + V D V.T) is applied via the Woodbury matrix
        identity:
          inv(M + V D V.T) = inv(M) - inv(M) V inv(C) V.T inv(M)
        where
          C = inv(D) + V.T inv(M) V.
        See: https://en.wikipedia.org/wiki/Woodbury_matrix_identity

        Args:
          x: `Tensor`

        Returns:
          inv_of_self_times_x: `Tensor`
        """
        # inv(M) x. NOTE(review): matrix_triangular_solve treats self._m as
        # triangular (lower by default) -- confirm M is stored that way.
        minv_x = linalg_ops.matrix_triangular_solve(self._m, x)
        # inv(C) V.T inv(M) x, with C given by _woodbury_sandwiched_term().
        vt_minv_x = math_ops.matmul(self._v, minv_x, transpose_a=True)
        sandwich = self._woodbury_sandwiched_term()
        cinv_vt_minv_x = linalg_ops.matrix_solve(sandwich, vt_minv_x)
        # inv(M) V inv(C) V.T inv(M) x, the Woodbury correction term.
        correction = linalg_ops.matrix_triangular_solve(
            self._m, math_ops.matmul(self._v, cinv_vt_minv_x))
        return minv_x - correction
  def test_static_dims_broadcast_rhs_has_extra_dims_dynamic(self):
    """matrix_solve_with_broadcast handles an rhs with extra batch dims.

    Shapes are hidden behind placeholders, so only the rank is known
    statically. Because the rhs has extra dims, the implementation is expected
    to "flip" them to the far right as extra right-hand sides, call the matrix
    function, and flip them back (the original author confirmed this
    optimization by stepping through with a debugger).
    """
    # batch_shape = [2]
    matrix = rng.rand(3, 3)
    rhs = rng.rand(2, 3, 2)
    # Reference: the matrix explicitly broadcast to the rhs batch shape.
    matrix_broadcast = matrix + np.zeros((2, 1, 1))

    matrix_ph = array_ops.placeholder(dtypes.float64, shape=[None, None])
    rhs_ph = array_ops.placeholder(dtypes.float64, shape=[None, None, None])
    feeds = {matrix_ph: matrix, rhs_ph: rhs}

    with self.cached_session():
      result = linear_operator_util.matrix_solve_with_broadcast(
          matrix_ph, rhs_ph)
      self.assertAllEqual(3, result.shape.ndims)
      expected = linalg_ops.matrix_solve(matrix_broadcast, rhs)
      expected_value = expected.eval()
      result_value = result.eval(feed_dict=feeds)
      self.assertAllClose(expected_value, result_value)
    def test_broadcast_matmul_and_solve(self):
        """LinearOperatorDiag matmul/solve broadcast like an explicit batch matrix.

        These shapes cannot be covered by the automated base-class tests,
        because tf.matmul itself does not broadcast.
        """
        with self.cached_session():
            x = random_ops.random_normal(shape=(2, 2, 3, 4))

            # Operator batch shape (2, 1) is broadcast to (2, 2) with x.
            diag = random_ops.random_uniform(shape=(2, 1, 3))
            operator = linalg.LinearOperatorDiag(diag, is_self_adjoint=True)
            self.assertAllEqual((2, 1, 3, 3), operator.shape)

            # Dense counterpart that already has the broadcast batch shape.
            mat = array_ops.matrix_diag(array_ops.concat((diag, diag), 1))
            self.assertAllEqual((2, 2, 3, 3), mat.shape)

            def _assert_same(op_result, dense_result):
                # Static shapes must match, and so must the evaluated values.
                self.assertAllEqual(op_result.shape, dense_result.shape)
                self.assertAllClose(
                    *self.evaluate([op_result, dense_result]))

            _assert_same(operator.matmul(x), math_ops.matmul(mat, x))
            _assert_same(operator.solve(x), linalg_ops.matrix_solve(mat, x))
Example #22
0
    def _verifySolve(self, x, y, batch_dims=None):
        """Compares `linalg_ops.matrix_solve` with `np.linalg.solve`.

        Runs every combination of `adjoint` in (False, True) and dtype in
        (float32, float64, complex64, complex128), checking the static shape,
        the evaluated shape and the values against NumPy.

        Args:
          x: NumPy array used as the coefficient matrix.
          y: NumPy array used as the right-hand side.
          batch_dims: Optional list of leading batch dimensions; when given,
            the matrix and right-hand side are tiled with `batch_dims + [1, 1]`.
        """
        for adjoint in False, True:
            for np_type in [
                    np.float32, np.float64, np.complex64, np.complex128
            ]:
                # Membership test, not identity: `np_type is [...]` compares
                # against a freshly built list and is therefore always False.
                if np_type in (np.float32, np.float64):
                    # `.real` is an ndarray attribute, not a method.
                    a = x.real.astype(np_type)
                    b = y.real.astype(np_type)
                else:
                    a = x.astype(np_type)
                    b = y.astype(np_type)
                if adjoint:
                    a_np = np.conj(np.transpose(a))
                else:
                    a_np = a
                if batch_dims is not None:
                    a = np.tile(a, batch_dims + [1, 1])
                    a_np = np.tile(a_np, batch_dims + [1, 1])
                    b = np.tile(b, batch_dims + [1, 1])

                np_ans = np.linalg.solve(a_np, b)
                with self.test_session():
                    tf_ans = linalg_ops.matrix_solve(a, b, adjoint=adjoint)
                    out = tf_ans.eval()
                    self.assertEqual(tf_ans.get_shape(), out.shape)
                    self.assertEqual(np_ans.shape, out.shape)
                    self.assertAllClose(np_ans, out)
    def test_static_dims_broadcast_rhs_has_extra_dims_dynamic(self):
        """matrix_solve_with_broadcast handles an rhs with extra batch dims.

        Shapes are hidden behind placeholders, so only the rank is known
        statically. Because the rhs has extra dims, the implementation is
        expected to "flip" them to the far right as extra right-hand sides,
        call the matrix function, and flip them back (the original author
        confirmed this optimization by stepping through with a debugger).
        """
        # batch_shape = [2]
        matrix = rng.rand(3, 3)
        rhs = rng.rand(2, 3, 2)
        # Reference: the matrix explicitly broadcast to the rhs batch shape.
        matrix_broadcast = matrix + np.zeros((2, 1, 1))

        matrix_ph = array_ops.placeholder(dtypes.float64, shape=[None, None])
        rhs_ph = array_ops.placeholder(dtypes.float64,
                                       shape=[None, None, None])
        feeds = {matrix_ph: matrix, rhs_ph: rhs}

        with self.cached_session():
            result = linear_operator_util.matrix_solve_with_broadcast(
                matrix_ph, rhs_ph)
            self.assertAllEqual(3, result.shape.ndims)
            expected = linalg_ops.matrix_solve(matrix_broadcast, rhs)
            expected_value = self.evaluate(expected)
            result_value = result.eval(feed_dict=feeds)
            self.assertAllClose(expected_value, result_value)
 def test_solve(self):
   """operator.solve must agree with matrix_solve on the dense matrix."""
   self._skip_if_tests_to_skip_contains("solve")
   for use_placeholder in False, True:
     for shape in self._shapes_to_test:
       for dtype in self._dtypes_to_test:
         for adjoint in False, True:
           for adjoint_arg in False, True:
             with self.test_session(graph=ops.Graph()) as sess:
               sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
               operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
                   shape, dtype, use_placeholder=use_placeholder)
               rhs = self._make_rhs(operator, adjoint=adjoint)
               # With adjoint_arg the operator adjoints its argument, so pass
               # rhs^H: it then solves A X = (rhs^H)^H = rhs.
               solve_arg = (linear_operator_util.matrix_adjoint(rhs)
                            if adjoint_arg else rhs)
               op_solve = operator.solve(
                   solve_arg, adjoint=adjoint, adjoint_arg=adjoint_arg)
               mat_solve = linalg_ops.matrix_solve(mat, rhs, adjoint=adjoint)
               if not use_placeholder:
                 # Static shapes are only comparable without placeholders.
                 self.assertAllEqual(
                     op_solve.get_shape(), mat_solve.get_shape())
               op_value, mat_value = sess.run([op_solve, mat_solve],
                                              feed_dict=feed_dict)
               self.assertAC(op_value, mat_value)
 def test_solve(self):
   """operator.solve must agree with matrix_solve on the dense matrix."""
   self._skip_if_tests_to_skip_contains("solve")
   for use_placeholder in self._use_placeholder_options:
     for shape in self._shapes_to_test:
       for dtype in self._dtypes_to_test:
         for adjoint in self._adjoint_options:
           for adjoint_arg in self._adjoint_arg_options:
             with self.test_session(graph=ops.Graph()) as sess:
               sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
               operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
                   shape, dtype, use_placeholder=use_placeholder)
               rhs = self._make_rhs(operator, adjoint=adjoint)
               # With adjoint_arg the operator adjoints its argument, so pass
               # rhs^H: it then solves A X = (rhs^H)^H = rhs.
               solve_arg = linalg.adjoint(rhs) if adjoint_arg else rhs
               op_solve = operator.solve(
                   solve_arg, adjoint=adjoint, adjoint_arg=adjoint_arg)
               mat_solve = linalg_ops.matrix_solve(mat, rhs, adjoint=adjoint)
               if not use_placeholder:
                 # Static shapes are only comparable without placeholders.
                 self.assertAllEqual(op_solve.get_shape(),
                                     mat_solve.get_shape())
               op_value, mat_value = sess.run(
                   [op_solve, mat_solve], feed_dict=feed_dict)
               self.assertAC(op_value, mat_value)
Example #26
0
  def sqrt_solve(self, x):
    """Computes `solve(self, x)`.

    Despite its name, no square root is taken; the name matches the API.

    The inverse of (M + V D V.T) is applied via the Woodbury matrix identity:
      inv(M + V D V.T) = inv(M) - inv(M) V inv(C) V.T inv(M)
    where
      C = inv(D) + V.T inv(M) V.
    See: https://en.wikipedia.org/wiki/Woodbury_matrix_identity

    Args:
      x: `Tensor`

    Returns:
      inv_of_self_times_x: `Tensor`
    """
    # inv(M) x. NOTE(review): matrix_triangular_solve treats self._m as
    # triangular (lower by default) -- confirm M is stored that way.
    minv_x = linalg_ops.matrix_triangular_solve(self._m, x)
    # inv(C) V.T inv(M) x, with C given by _woodbury_sandwiched_term().
    vt_minv_x = math_ops.matmul(self._v, minv_x, transpose_a=True)
    sandwich = self._woodbury_sandwiched_term()
    cinv_vt_minv_x = linalg_ops.matrix_solve(sandwich, vt_minv_x)
    # inv(M) V inv(C) V.T inv(M) x, the Woodbury correction term.
    correction = linalg_ops.matrix_triangular_solve(
        self._m, math_ops.matmul(self._v, cinv_vt_minv_x))
    return minv_x - correction
 def testBatchResultSize(self):
   """Batched solve and solve_ls keep the batch dims in the static shape."""
   # Three 3x3 matrices, each paired with a 3x1 right-hand side. Only static
   # shapes are checked, so the (singular) values are never evaluated.
   matrix = np.array([1., 2., 3., 4., 5., 6., 7., 8., 9.] * 3).reshape(3, 3, 3)
   rhs = np.array([1., 2., 3.] * 3).reshape(3, 3, 1)
   answer = linalg_ops.matrix_solve(matrix, rhs)
   ls_answer = linalg_ops.matrix_solve_ls(matrix, rhs)
   for result in (ls_answer, answer):
     self.assertEqual(result.get_shape(), [3, 3, 1])
Example #28
0
 def testBatchResultSize(self):
   """Batched solve and solve_ls keep the batch dims in the static shape."""
   # Three 3x3 matrices, each paired with a 3x1 right-hand side. Only static
   # shapes are checked, so the (singular) values are never evaluated.
   matrix = np.array([1., 2., 3., 4., 5., 6., 7., 8., 9.] * 3).reshape(3, 3, 3)
   rhs = np.array([1., 2., 3.] * 3).reshape(3, 3, 1)
   answer = linalg_ops.matrix_solve(matrix, rhs)
   ls_answer = linalg_ops.matrix_solve_ls(matrix, rhs)
   for result in (ls_answer, answer):
     self.assertEqual(result.get_shape(), [3, 3, 1])
 def testBatchResultSize(self):
   """Batched solve and solve_ls keep the batch dims in the static shape."""
   # Three 3x3 identity matrices, each paired with a 3x1 right-hand side.
   matrix = np.array([1., 0., 0., 0., 1., 0., 0., 0., 1.] * 3).reshape(3, 3, 3)  # pylint: disable=too-many-function-args
   rhs = np.array([1., 2., 3.] * 3).reshape(3, 3, 1)  # pylint: disable=too-many-function-args
   answer = linalg_ops.matrix_solve(matrix, rhs)
   ls_answer = linalg_ops.matrix_solve_ls(matrix, rhs)
   for result in (ls_answer, answer):
     self.assertEqual(result.get_shape(), [3, 3, 1])
Example #30
0
def _MatrixSolveGrad(op, grad):
  """Gradients for MatrixSolve.

  For `c = solve(a, b)` this returns
    grad_b = solve(transpose(a), grad)
    grad_a = -grad_b @ transpose(c)
  NOTE(review): plain transposes are only the correct adjoint for real
  dtypes.

  Args:
    op: The forward MatrixSolve op.
    grad: Upstream gradient with respect to `op.outputs[0]`.

  Returns:
    A `(grad_a, grad_b)` tuple of gradients for the op's two inputs.
  """
  # TODO(rmlarsen): Get rid of explicit transpose after adding
  # adjoint_a attribute to solver.
  a_transposed = array_ops.transpose(op.inputs[0])
  grad_b = linalg_ops.matrix_solve(a_transposed, grad)
  solution = op.outputs[0]
  grad_a = -math_ops.matmul(grad_b, solution, transpose_b=True)
  return grad_a, grad_b
Example #31
0
 def _solve(self, rhs, adjoint=False, adjoint_arg=False):
   """Solves `A X = rhs` (or `A^H X = rhs`) through a dense fallback.

   Raises:
     NotImplementedError: If `self.is_square` is exactly False.
   """
   # `is False` (not falsiness): any other value, e.g. an unknown hint,
   # falls through.
   if self.is_square is False:
     raise NotImplementedError(
         "Solve is not yet implemented for non-square operators.")
   if adjoint_arg:
     rhs = linear_operator_util.matrix_adjoint(rhs)
   if not self._can_use_cholesky():
     return linalg_ops.matrix_solve(
         self._get_cached_dense_matrix(), rhs, adjoint=adjoint)
   # NOTE(review): `adjoint` is not forwarded here; presumably Cholesky is
   # only chosen for self-adjoint operators (A^H == A) -- confirm.
   return linalg_ops.cholesky_solve(self._get_cached_chol(), rhs)
Example #32
0
def _MatrixSolveGrad(op, grad):
    """Gradients for MatrixSolve.

    For `c = solve(a, b)` this returns
      grad_b = solve(transpose(a), grad)
      grad_a = -grad_b @ transpose(c)
    NOTE(review): plain transposes are only the correct adjoint for real
    dtypes.

    Args:
      op: The forward MatrixSolve op.
      grad: Upstream gradient with respect to `op.outputs[0]`.

    Returns:
      A `(grad_a, grad_b)` tuple of gradients for the op's two inputs.
    """
    # TODO(rmlarsen): Get rid of explicit transpose after adding
    # adjoint_a attribute to solver.
    a_transposed = array_ops.transpose(op.inputs[0])
    grad_b = linalg_ops.matrix_solve(a_transposed, grad)
    solution = op.outputs[0]
    grad_a = -math_ops.matmul(grad_b, solution, transpose_b=True)
    return grad_a, grad_b
Example #33
0
 def testConcurrent(self):
     """Stateless-seeded solves agree pairwise for both adjoint settings."""
     seed = [42, 24]
     matrix_shape = [3, 3]
     all_ops = []
     for adjoint_ in (False, True):
         # Four normals from the same stateless seed: lhs1, lhs2, rhs1, rhs2.
         lhs1, lhs2, rhs1, rhs2 = [
             stateless_random_ops.stateless_random_normal(matrix_shape,
                                                          seed=seed)
             for _ in range(4)
         ]
         all_ops.extend([
             linalg_ops.matrix_solve(lhs1, rhs1, adjoint=adjoint_),
             linalg_ops.matrix_solve(lhs2, rhs2, adjoint=adjoint_),
         ])
     val = self.evaluate(all_ops)
     # Results come in (s1, s2) pairs, one pair per adjoint setting.
     for first, second in zip(val[::2], val[1::2]):
         self.assertAllEqual(first, second)
    def test_static_dims_broadcast_matrix_has_extra_dims(self):
        """matrix_solve_with_broadcast broadcasts an rhs lacking batch dims."""
        # batch_shape = [2]
        matrix = rng.rand(2, 3, 3)
        rhs = rng.rand(3, 7)
        # Reference: the rhs explicitly broadcast to the matrix batch shape.
        rhs_broadcast = rhs + np.zeros((2, 1, 1))

        result = linear_operator_util.matrix_solve_with_broadcast(matrix, rhs)
        self.assertAllEqual((2, 3, 7), result.shape)
        expected = linalg_ops.matrix_solve(matrix, rhs_broadcast)
        expected_value, result_value = self.evaluate([expected, result])
        self.assertAllClose(expected_value, result_value)
  def test_static_dims_broadcast(self):
    """matrix_solve_with_broadcast broadcasts a matrix lacking batch dims."""
    # batch_shape = [2]
    matrix = rng.rand(3, 3)
    rhs = rng.rand(2, 3, 7)
    # Reference: the matrix explicitly broadcast to the rhs batch shape.
    matrix_broadcast = matrix + np.zeros((2, 1, 1))

    with self.cached_session():
      result = linear_operator_util.matrix_solve_with_broadcast(matrix, rhs)
      self.assertAllEqual((2, 3, 7), result.get_shape())
      expected = linalg_ops.matrix_solve(matrix_broadcast, rhs)
      expected_value = expected.eval()
      self.assertAllEqual(expected_value, result.eval())
Example #36
0
def _MatrixSolveGrad(op, grad):
    """Gradient for MatrixSolve.

    The forward op computes `c` with `a @ c = b`, or `adjoint(a) @ c = b` when
    its `adjoint` attribute is set. `grad_b` solves the adjoint of that system
    against `grad`; `grad_a` is the matching negated outer product.

    Args:
      op: The forward MatrixSolve op.
      grad: Upstream gradient with respect to `op.outputs[0]`.

    Returns:
      A `(grad_a, grad_b)` tuple of gradients for the op's two inputs.
    """
    adjoint_a = op.get_attr("adjoint")
    grad_b = linalg_ops.matrix_solve(op.inputs[0], grad, adjoint=not adjoint_a)
    c = op.outputs[0]
    if not adjoint_a:
        # grad_a = -grad_b @ c^H
        grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)
    else:
        # grad_a = -c @ grad_b^H
        grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)
    return grad_a, grad_b
Example #37
0
def _MatrixSolveGrad(op, grad):
    """Gradient for MatrixSolve.

    Backpropagates `grad` (the gradient w.r.t. the op's output) to both
    inputs. The rhs gradient comes from solving with the same matrix but the
    opposite adjoint flag. The matrix gradient is the negated outer product of
    that rhs gradient with the forward solution, taken in the order the
    "adjoint" attribute requires.

    Returns:
      Tuple (grad wrt matrix, grad wrt rhs).
    """
    matrix = op.inputs[0]
    solution = op.outputs[0]
    adjoint = op.get_attr("adjoint")
    grad_rhs = linalg_ops.matrix_solve(matrix, grad, adjoint=not adjoint)
    if not adjoint:
        grad_matrix = -math_ops.matmul(grad_rhs, solution, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
    else:
        grad_matrix = -math_ops.matmul(solution, grad_rhs, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
    return (grad_matrix, grad_rhs)
Example #38
0
def _MatrixSolveGrad(op, grad):
  """Gradient for MatrixSolve.

  Returns (grad_a, grad_b) for the inputs (a, b) of a MatrixSolve op whose
  output is c. grad_b solves against `a` with the adjoint flag flipped.
  grad_a is -grad_b @ c^H, or -c @ grad_b^H when the op used the adjoint.
  """
  a, c = op.inputs[0], op.outputs[0]
  adjoint_a = op.get_attr("adjoint")
  grad_b = linalg_ops.matrix_solve(a, grad, adjoint=not adjoint_a)
  # Only the operand order of the outer product depends on the adjoint flag.
  left, right = (c, grad_b) if adjoint_a else (grad_b, c)
  grad_a = -math_ops.matmul(left, right, adjoint_b=True)
  return (grad_a, grad_b)
Example #39
0
 def _solve(self, rhs, adjoint=False, adjoint_arg=False):
   """Default implementation of _solve."""
   if self.is_square is False:
     raise NotImplementedError(
         "Solve is not yet implemented for non-square operators.")
   logging.warn(
       "Using (possibly slow) default implementation of solve."
       "  Requires conversion to a dense matrix and O(N^3) operations.")
   rhs = linear_operator_util.matrix_adjoint(rhs) if adjoint_arg else rhs
   if self._can_use_cholesky():
     return linalg_ops.cholesky_solve(self._get_cached_chol(), rhs)
   return linalg_ops.matrix_solve(
       self._get_cached_dense_matrix(), rhs, adjoint=adjoint)
  def benchmarkMatrixSolveOp(self):
    """Benchmarks matrix_solve on CPU and, when available, on GPU.

    Each combination of adjoint flag, matrix shape and right-hand-side count
    is timed in its own fresh graph and session. The CPU run always happens,
    and the GPU run immediately follows it when a GPU is present.
    """
    gpu_available = test.is_gpu_available(True)
    cpu_name = ("matrix_solve_cpu_shape_{matrix_shape}_num_rhs_{num_rhs}_"
                "adjoint_{adjoint}")
    gpu_name = ("matrix_solve_gpu_shape_{matrix_shape}_num_rhs_"
                "{num_rhs}_adjoint_{adjoint}")

    def _time_on(device, name_template, shape, rhs_count, adj):
      # Build a fresh graph pinned to `device` and time a single solve.
      with ops.Graph().as_default(), \
          session.Session() as sess, \
          ops.device(device):
        lhs, rhs = self._GenerateTestData(shape, rhs_count)
        solution = linalg_ops.matrix_solve(lhs, rhs, adjoint=adj)
        variables.global_variables_initializer().run()
        self.run_op_benchmark(
            sess,
            control_flow_ops.group(solution),
            min_iters=25,
            store_memory_usage=False,
            name=name_template.format(
                matrix_shape=shape, num_rhs=rhs_count, adjoint=adj))

    for adj in (False, True):
      for shape in self.matrix_shapes:
        for rhs_count in (1, 2, shape[-1]):
          _time_on("/cpu:0", cpu_name, shape, rhs_count, adj)
          if gpu_available:
            _time_on("/gpu:0", gpu_name, shape, rhs_count, adj)
Example #41
0
 def _solve(self, rhs, adjoint=False, adjoint_arg=False):
     """Default implementation of _solve."""
     if self.is_square is False:
         raise NotImplementedError(
             "Solve is not yet implemented for non-square operators.")
     logging.warn(
         "Using (possibly slow) default implementation of solve."
         "  Requires conversion to a dense matrix and O(N^3) operations.")
     rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
     if self._can_use_cholesky():
         return linalg_ops.cholesky_solve(
             linalg_ops.cholesky(self.to_dense()), rhs)
     return linalg_ops.matrix_solve(self.to_dense(), rhs, adjoint=adjoint)
 def _verifySolve(self, x, y, adjoint):
   """Checks matrix_solve against np.linalg.solve for each real float type.

   Args:
     x: Coefficient matrix (or batch) as a NumPy array.
     y: Right-hand side(s) as a NumPy array.
     adjoint: Whether to solve against the adjoint of `x`.
   """
   for np_type in self.float_types & {np.float32, np.float64}:
     tol = 1e-12
     if np_type == np.float32:
       tol = 1e-4
     lhs = x.astype(np_type)
     rhs = y.astype(np_type)
     # For these real dtypes the adjoint is a transpose of the last two axes.
     lhs_np = np.swapaxes(lhs, -2, -1) if adjoint else lhs
     np_ans = np.linalg.solve(lhs_np, rhs)
     with self.session() as sess:
       with self.test_scope():
         tf_ans = linalg_ops.matrix_solve(lhs, rhs, adjoint=adjoint)
       out = sess.run(tf_ans)
       self.assertEqual(tf_ans.shape, out.shape)
       self.assertEqual(np_ans.shape, out.shape)
       self.assertAllClose(np_ans, out, atol=tol, rtol=tol)
  def testSolve(self):
    """operator.solve matches a dense matrix_solve for several batch shapes."""
    with self.test_session():
      for batch_shape in ((), (2, 3)):
        for k in (1, 4):
          operator, mat = self._build_operator_and_mat(batch_shape, k)
          # Solve 5 systems at once; the count is arbitrary.
          rhs_shape = batch_shape + (k, 5)
          x = self._rng.randn(*rhs_shape)
          dense_solution = linalg_ops.matrix_solve(mat, x).eval()
          self._compare_results(
              expected=dense_solution, actual=operator.solve(x))
Example #44
0
  def testSolve(self):
    """Compares operator.solve with matrix_solve on the equivalent dense matrix."""
    batch_shapes = ((), (2, 3))
    system_sizes = (1, 4)
    num_systems = 5  # Arbitrary number of simultaneous right-hand sides.
    with self.test_session():
      for batch_shape in batch_shapes:
        for k in system_sizes:
          operator, mat = self._build_operator_and_mat(batch_shape, k)
          x = self._rng.randn(*(batch_shape + (k, num_systems)))
          want = linalg_ops.matrix_solve(mat, x).eval()
          got = operator.solve(x)
          self._compare_results(expected=want, actual=got)
Example #45
0
  def test_dynamic_dims_broadcast_64bit(self):
    """Batch broadcasting still works when static shapes are hidden."""
    # batch_shape = [2, 2]
    lhs = rng.rand(2, 3, 3)
    rhs = rng.rand(2, 1, 3, 7)
    pad = np.zeros((2, 2, 1, 1))
    lhs_full = lhs + pad
    rhs_full = rhs + pad

    # shape=None leaves the static shape unknown.
    lhs_ph = array_ops.placeholder_with_default(lhs, shape=None)
    rhs_ph = array_ops.placeholder_with_default(rhs, shape=None)

    broadcast_solution = linear_operator_util.matrix_solve_with_broadcast(
        lhs_ph, rhs_ph)
    direct_solution = linalg_ops.matrix_solve(lhs_full, rhs_full)
    result, expected = self.evaluate([broadcast_solution, direct_solution])
    self.assertAllClose(expected, result)
    def test_static_dims_broadcast_rhs_has_extra_dims(self):
        """An unbatched matrix is broadcast against a batched rhs.

        Because the rhs carries the extra dims and the first argument's domain
        dim is larger than the number of linear equations, the helper should
        "flip" the extra dims to the far right as additional linear equations,
        call the matrix function once, and flip back. That this optimization
        really happens was verified by stepping through it with a debugger.
        """
        # batch_shape = [2]
        lhs = rng.rand(3, 3)
        rhs = rng.rand(2, 3, 2)
        tiled_lhs = np.zeros((2, 1, 1)) + lhs

        solved = linear_operator_util.matrix_solve_with_broadcast(lhs, rhs)
        self.assertAllEqual((2, 3, 2), solved.shape)
        reference = linalg_ops.matrix_solve(tiled_lhs, rhs)
        reference_val, solved_val = self.evaluate([reference, solved])
        self.assertAllClose(reference_val, solved_val)
def matrix_solve_with_broadcast(matrix, rhs, adjoint=False, name=None):
  """Solve systems of linear equations, broadcasting batch dimensions.

  Args:
    matrix: Coefficient matrix (or batch of matrices); converted to a tensor.
    rhs: Right-hand side(s); converted to a tensor with `matrix`'s dtype.
    adjoint: If True, solve against the adjoint of `matrix`.
    name: Optional name scope; defaults to "MatrixSolveWithBroadcast".

  Returns:
    The solution. It is mapped back through the inverse reshape returned by
    `_reshape_for_efficiency`.
  """
  with ops.name_scope(name, "MatrixSolveWithBroadcast", [matrix, rhs]):
    lhs = ops.convert_to_tensor(matrix, name="matrix")
    rhs_t = ops.convert_to_tensor(rhs, name="rhs", dtype=lhs.dtype)

    # Reshape away extra dims where possible. This step may already account
    # for the adjoint, in which case the solve below must not apply it again.
    lhs, rhs_t, undo_reshape, needs_transpose = _reshape_for_efficiency(
        lhs, rhs_t, adjoint_a=adjoint)

    # Any remaining batch mismatch is broadcast explicitly.
    lhs, rhs_t = broadcast_matrix_batch_dims([lhs, rhs_t])

    apply_adjoint = adjoint and needs_transpose
    solution = linalg_ops.matrix_solve(lhs, rhs_t, adjoint=apply_adjoint)
    return undo_reshape(solution)
 def test_solve(self):
   """operator.solve agrees with a dense matrix_solve across all test configs."""
   self._maybe_skip("solve")
   with self.test_session() as sess:
     configs = ((use_placeholder, shape, dtype, adjoint)
                for use_placeholder in (False, True)
                for shape in self._shapes_to_test
                for dtype in self._dtypes_to_test
                for adjoint in (False, True))
     for use_placeholder, shape, dtype, adjoint in configs:
       operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
           shape, dtype, use_placeholder=use_placeholder)
       rhs = self._make_rhs(operator, adjoint=adjoint)
       from_operator = operator.solve(rhs, adjoint=adjoint)
       from_dense = linalg_ops.matrix_solve(mat, rhs, adjoint=adjoint)
       # Static shapes are only comparable when no placeholders are used.
       if not use_placeholder:
         self.assertAllEqual(from_operator.get_shape(), from_dense.get_shape())
       operator_val, dense_val = sess.run([from_operator, from_dense],
                                          feed_dict=feed_dict)
       self.assertAC(operator_val, dense_val)
  def test_static_dims_broadcast_rhs_has_extra_dims(self):
    """An unbatched matrix is broadcast against a batched rhs.

    Because the rhs carries the extra dims and the first argument's domain dim
    is larger than the number of linear equations, the helper should "flip"
    the extra dims to the far right as additional linear equations, call the
    matrix function once, and flip back. That this optimization really
    happens was verified by stepping through it with a debugger.
    """
    # batch_shape = [2]
    lhs = rng.rand(3, 3)
    rhs = rng.rand(2, 3, 2)
    tiled_lhs = np.zeros((2, 1, 1)) + lhs

    with self.cached_session():
      solved = linear_operator_util.matrix_solve_with_broadcast(lhs, rhs)
      self.assertAllEqual((2, 3, 2), solved.get_shape())
      reference = linalg_ops.matrix_solve(tiled_lhs, rhs)
      self.assertAllClose(reference.eval(), self.evaluate(solved))
  def testSqrtSolve(self):
    """Applying sqrt_solve twice matches a full solve with the dense matrix.

    Square roots are not unique, but S^{-T} S^{-1} x = A^{-1} x should hold
    regardless. Here S = S^T, so S^{-1} S^{-1} x = A^{-1} x.
    """
    with self.test_session():
      for batch_shape in ((), (2, 3)):
        for k in (1, 4):
          operator, mat = self._build_operator_and_mat(batch_shape, k)
          # Five simultaneous systems; the number is arbitrary.
          x = self._rng.randn(*(batch_shape + (k, 5)))
          want = linalg_ops.matrix_solve(mat, x).eval()
          got = operator.sqrt_solve(operator.sqrt_solve(x))
          self._compare_results(expected=want, actual=got)
Example #51
0
  def testMultiplyInverse(self):
    """multiply_inverse on a sparse vector matches an explicit dense solve.

    Builds an EmbeddingKFACFB block and applies its inverse to an
    IndexedSlices vector. The nonzero outputs are checked against
    matrix_solve(full_fisher_block, dense_vector) at the same indices.
    """
    with ops.Graph().as_default(), self.test_session() as sess:
      # Fixed graph-level seed, presumably for reproducibility of any random
      # initialization inside the block.
      random_seed.set_random_seed(200)

      # Create a Fisher Block.
      vocab_size = 5
      block = fb.EmbeddingKFACFB(lc.LayerCollection(), vocab_size)

      # Add some examples.
      inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]])
      outputs = array_ops.constant([[0.], [1.], [2.]])
      block.register_additional_minibatch(inputs, outputs)

      # Instantiate factor's variables. Ensure it doesn't fail.
      grads = outputs**2.
      # Zero damping, presumably so the result can be compared against the
      # inverse of the undamped full Fisher block below.
      damping = array_ops.constant(0.)
      block.instantiate_factors(((grads,),), damping)
      block._input_factor.instantiate_cov_variables()
      block._output_factor.instantiate_cov_variables()
      block.register_inverse()
      block._input_factor.instantiate_inv_variables()
      block._output_factor.instantiate_inv_variables()

      # Create a sparse update: ones at indices 1, 3 and 4. dense_vector
      # below is the same vector in dense form.
      indices = array_ops.constant([1, 3, 4])
      values = array_ops.constant([[1.], [1.], [1.]])
      sparse_vector = ops.IndexedSlices(
          values, indices, dense_shape=[vocab_size, 1])
      dense_vector = array_ops.reshape([0., 1., 0., 1., 1.], [vocab_size, 1])

      # Compare Fisher-vector product against explicit result.
      result = block.multiply_inverse(sparse_vector)
      expected_result = linalg_ops.matrix_solve(block.full_fisher_block(),
                                                dense_vector)

      sess.run(tf_variables.global_variables_initializer())
      # result.values[i] is compared with the dense solution at indices[i].
      self.assertAlmostEqual(
          sess.run(expected_result[1]), sess.run(result.values[0]))
      self.assertAlmostEqual(
          sess.run(expected_result[3]), sess.run(result.values[1]))
      self.assertAlmostEqual(
          sess.run(expected_result[4]), sess.run(result.values[2]))
  def _solve(self, rhs, adjoint=False):
    """Solves (L + U D V^H) x = rhs (or its adjoint) via the Woodbury identity.

    Args:
      rhs: Right-hand side(s) to solve against.
      adjoint: If True, solve with the adjoint operator (L + U D V^H)^H.

    Returns:
      The solution tensor.

    Raises:
      ValueError: If `base_operator.is_non_singular` is False.
    """
    if self.base_operator.is_non_singular is False:
      raise ValueError(
          "Solve not implemented unless this is a perturbation of a "
          "non-singular LinearOperator.")
    # The Woodbury formula gives:
    # https://en.wikipedia.org/wiki/Woodbury_matrix_identity
    #   (L + UDV^H)^{-1}
    #   = L^{-1} - L^{-1} U (D^{-1} + V^H L^{-1} U)^{-1} V^H L^{-1}
    #   = L^{-1} - L^{-1} U C^{-1} V^H L^{-1}
    # where C is the capacitance matrix, C := D^{-1} + V^H L^{-1} U
    # Note also that, with ^{-H} being the inverse of the adjoint,
    #   (L + UDV^H)^{-H}
    #   = L^{-H} - L^{-H} V C^{-H} U^H L^{-H}
    l = self.base_operator
    # The adjoint formula has the same structure with U and V exchanged, so
    # swap them here and pass `adjoint` through to the L and C solves.
    if adjoint:
      v = self.u
      u = self.v
    else:
      v = self.v
      u = self.u

    # The step comments below are written for the non-adjoint case. With
    # adjoint=True, read L^{-H} and C^{-H}, with U and V swapped.
    # L^{-1} rhs
    linv_rhs = l.solve(rhs, adjoint=adjoint)
    # V^H L^{-1} rhs
    vh_linv_rhs = math_ops.matmul(v, linv_rhs, adjoint_a=True)
    # C^{-1} V^H L^{-1} rhs
    if self._use_cholesky:
      # NOTE(review): cholesky_solve takes no adjoint flag. This presumes the
      # capacitance is self-adjoint whenever _use_cholesky is set — confirm.
      capinv_vh_linv_rhs = linalg_ops.cholesky_solve(
          self._chol_capacitance, vh_linv_rhs)
    else:
      capinv_vh_linv_rhs = linalg_ops.matrix_solve(
          self._capacitance, vh_linv_rhs, adjoint=adjoint)
    # U C^{-1} V^H L^{-1} rhs
    u_capinv_vh_linv_rhs = math_ops.matmul(u, capinv_vh_linv_rhs)
    # L^{-1} U C^{-1} V^H L^{-1} rhs
    linv_u_capinv_vh_linv_rhs = l.solve(u_capinv_vh_linv_rhs, adjoint=adjoint)

    # L^{-1} - L^{-1} U C^{-1} V^H L^{-1}
    return linv_rhs - linv_u_capinv_vh_linv_rhs
  def test_dynamic_dims_broadcast_64bit(self):
    """Batch broadcasting works when shapes are only known at run time."""
    # batch_shape = [2, 2]
    matrix = rng.rand(2, 3, 3)
    rhs = rng.rand(2, 1, 3, 7)
    matrix_broadcast = matrix + np.zeros((2, 2, 1, 1))
    rhs_broadcast = rhs + np.zeros((2, 2, 1, 1))

    # Placeholders without a shape hide all static shape information.
    matrix_ph = array_ops.placeholder(dtypes.float64)
    rhs_ph = array_ops.placeholder(dtypes.float64)

    with self.cached_session() as sess:
      result, expected = sess.run(
          [
              linear_operator_util.matrix_solve_with_broadcast(
                  matrix_ph, rhs_ph),
              linalg_ops.matrix_solve(matrix_broadcast, rhs_broadcast)
          ],
          feed_dict={
              matrix_ph: matrix,
              rhs_ph: rhs,
          })
      # The two solves take different code paths (the helper broadcasts
      # internally), so compare floating-point results with a tolerance.
      self.assertAllClose(expected, result)
 def _solve(self, rhs, adjoint=False):
   """Solves against the stored matrix, via Cholesky when `_is_spd` is set."""
   if not self._is_spd:
     return linalg_ops.matrix_solve(self._matrix, rhs, adjoint=adjoint)
   # `adjoint` is not forwarded here. Presumably `_is_spd` means symmetric
   # positive definite, and such a matrix is its own adjoint.
   return linalg_ops.cholesky_solve(self._chol, rhs)
 def _solve(self, rhs, adjoint=False, adjoint_arg=False):
   """Solves against the dense test matrix; `adjoint_arg` is unsupported."""
   rhs_tensor = ops.convert_to_tensor(rhs, name="rhs")
   assert not adjoint_arg, "Not implemented for this test class."
   return linalg_ops.matrix_solve(
       self._matrix, rhs_tensor, adjoint=adjoint)
Example #56
0
def matrix_exponential(input, name=None):  # pylint: disable=redefined-builtin
  r"""Computes the matrix exponential of one or more square matrices.

  exp(A) = \sum_{n=0}^\infty A^n/n!

  The exponential is computed using a combination of the scaling and squaring
  method and the Pade approximation. Details can be found in:
  Nicholas J. Higham, "The scaling and squaring method for the matrix
  exponential revisited," SIAM J. Matrix Anal. Applic., 26:1179-1193, 2005.

  The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
  form square matrices. The output is a tensor of the same shape as the input
  containing the exponential for all input submatrices `[..., :, :]`.

  Args:
    input: A `Tensor`. Must be `float16`, `float32`, `float64`, `complex64`,
      or `complex128` with shape `[..., M, M]`.
    name:  A name to give this `Op` (optional).

  Returns:
    the matrix exponential of the input.

  Raises:
    ValueError: An unsupported type is provided as input.

  @compatibility(scipy)
  Equivalent to scipy.linalg.expm
  @end_compatibility
  """
  with ops.name_scope(name, 'matrix_exponential', [input]):
    matrix = ops.convert_to_tensor(input, name='input')
    # 0x0 matrices: nothing to exponentiate, return the input unchanged.
    if matrix.shape[-2:] == [0, 0]:
      return matrix
    # Remember the batch shape (static if fully known, otherwise a dynamic
    # shape tensor) so the result can be reshaped back at the end.
    batch_shape = matrix.shape[:-2]
    if not batch_shape.is_fully_defined():
      batch_shape = array_ops.shape(matrix)[:-2]

    # reshaping the batch makes the where statements work better
    # (all batch dims are collapsed into one, giving a rank-3 tensor)
    matrix = array_ops.reshape(
        matrix, array_ops.concat(([-1], array_ops.shape(matrix)[-2:]), axis=0))
    # Per-matrix 1-norm (maximum absolute column sum). After the reshape the
    # tensor is rank 3, so the inner sum runs over axis 1 (rows) and yields
    # column sums. The outer max then reduces the last axis.
    l1_norm = math_ops.reduce_max(
        math_ops.reduce_sum(math_ops.abs(matrix),
                            axis=array_ops.size(array_ops.shape(matrix)) - 2),
        axis=-1)
    # Scalar constant in the dtype of l1_norm.
    const = lambda x: constant_op.constant(x, l1_norm.dtype)
    # Nested `where`: for each matrix, picks cases[i] for the first threshold
    # vals[i] that its l1_norm is below, falling back to cases[-1].
    def _nest_where(vals, cases):
      assert len(vals) == len(cases) - 1
      if len(vals) == 1:
        return array_ops.where(
            math_ops.less(l1_norm, const(vals[0])), cases[0], cases[1])
      else:
        return array_ops.where(
            math_ops.less(l1_norm, const(vals[0])), cases[0],
            _nest_where(vals[1:], cases[1:]))

    if matrix.dtype in [dtypes.float16, dtypes.float32, dtypes.complex64]:
      # Lower-precision path: Pade approximants of degree 3, 5 and 7.
      # squarings = max(floor(log2(l1_norm / maxnorm)), 0).
      # NOTE(review): Higham's algorithm uses ceil here. With floor, a scaled
      # norm can land in [maxnorm, 2 * maxnorm). Confirm this is intended.
      maxnorm = const(3.925724783138660)
      squarings = math_ops.maximum(
          math_ops.floor(
              math_ops.log(l1_norm / maxnorm) / math_ops.log(const(2.0))), 0)
      # The degree-3/5 approximants use the unscaled matrix. They are only
      # selected when l1_norm is below a threshold smaller than maxnorm, in
      # which case squarings is 0.
      u3, v3 = _matrix_exp_pade3(matrix)
      u5, v5 = _matrix_exp_pade5(matrix)
      # The highest-degree approximant is applied to A / 2**squarings.
      u7, v7 = _matrix_exp_pade7(
          matrix / math_ops.pow(
              constant_op.constant(2.0, dtype=matrix.dtype),
              math_ops.cast(squarings, matrix.dtype))[...,
                                                      array_ops.newaxis,
                                                      array_ops.newaxis])
      # 1-norm thresholds for choosing the approximant degree (presumably the
      # theta_m values from the Higham paper cited above).
      conds = (4.258730016922831e-001, 1.880152677804762e+000)
      u = _nest_where(conds, (u3, u5, u7))
      v = _nest_where(conds, (v3, v5, v7))
    elif matrix.dtype in [dtypes.float64, dtypes.complex128]:
      # Double-precision path: same scheme with degrees 3, 5, 7, 9 and 13.
      maxnorm = const(5.371920351148152)
      squarings = math_ops.maximum(
          math_ops.floor(
              math_ops.log(l1_norm / maxnorm) / math_ops.log(const(2.0))), 0)
      u3, v3 = _matrix_exp_pade3(matrix)
      u5, v5 = _matrix_exp_pade5(matrix)
      u7, v7 = _matrix_exp_pade7(matrix)
      u9, v9 = _matrix_exp_pade9(matrix)
      u13, v13 = _matrix_exp_pade13(
          matrix / math_ops.pow(
              constant_op.constant(2.0, dtype=matrix.dtype),
              math_ops.cast(squarings, matrix.dtype))[...,
                                                      array_ops.newaxis,
                                                      array_ops.newaxis])
      conds = (1.495585217958292e-002,
               2.539398330063230e-001,
               9.504178996162932e-001,
               2.097847961257068e+000)
      u = _nest_where(conds, (u3, u5, u7, u9, u13))
      v = _nest_where(conds, (v3, v5, v7, v9, v13))
    else:
      raise ValueError(
          'tf.linalg.expm does not support matrices of type %s' % matrix.dtype)
    # Rational approximant r = (v - u)^{-1} (u + v), computed as a linear
    # solve rather than an explicit inverse.
    numer = u + v
    denom = -u + v
    result = linalg_ops.matrix_solve(denom, numer)
    max_squarings = math_ops.reduce_max(squarings)

    # Undo the scaling by repeated squaring. The loop runs max(squarings)
    # times, and `where` stops squaring each matrix once it has been squared
    # its own `squarings` times.
    i = const(0.0)
    c = lambda i, r: math_ops.less(i, max_squarings)
    def b(i, r):
      return i+1, array_ops.where(math_ops.less(i, squarings),
                                  math_ops.matmul(r, r), r)
    _, result = control_flow_ops.while_loop(c, b, [i, result])
    # Restore the original batch dims, dynamically if shapes are not static.
    if not matrix.shape.is_fully_defined():
      return array_ops.reshape(
          result,
          array_ops.concat((batch_shape, array_ops.shape(result)[-2:]), axis=0))
    return array_ops.reshape(result, batch_shape.concatenate(result.shape[-2:]))
 def _solve(self, rhs, adjoint=False, adjoint_arg=False):
   """Solves against the stored matrix, optionally using the adjoint of rhs."""
   if adjoint_arg:
     rhs = linear_operator_util.matrix_adjoint(rhs)
   if not self._is_spd:
     return linalg_ops.matrix_solve(self._matrix, rhs, adjoint=adjoint)
   # `adjoint` is not forwarded here. Presumably `_is_spd` means symmetric
   # positive definite, and such a matrix is its own adjoint.
   return linalg_ops.cholesky_solve(self._chol, rhs)