def test_matmul(self):
  """Checks operator.matmul against a dense math_ops.matmul reference.

  Iterates over every combination of placeholder use, test shape, dtype,
  adjoint and adjoint_arg; compares static shapes (when known) and the
  evaluated values via assertAC.
  """
  self._skip_if_tests_to_skip_contains("matmul")
  for use_placeholder in False, True:
    for shape in self._shapes_to_test:
      for dtype in self._dtypes_to_test:
        for adjoint in False, True:
          for adjoint_arg in False, True:
            with self.test_session(graph=ops.Graph()) as sess:
              sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
              operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
                  shape, dtype, use_placeholder=use_placeholder)
              x = self._make_x(operator, adjoint=adjoint)
              # If adjoint_arg, compute A X^H^H = A X.
              if adjoint_arg:
                op_matmul = operator.matmul(
                    linear_operator_util.matrix_adjoint(x),
                    adjoint=adjoint, adjoint_arg=adjoint_arg)
              else:
                # Pass adjoint_arg explicitly for consistency with
                # test_solve; it is False here, so behavior is unchanged.
                op_matmul = operator.matmul(
                    x, adjoint=adjoint, adjoint_arg=adjoint_arg)
              mat_matmul = math_ops.matmul(mat, x, adjoint_a=adjoint)
              if not use_placeholder:
                self.assertAllEqual(
                    op_matmul.get_shape(), mat_matmul.get_shape())
              op_matmul_v, mat_matmul_v = sess.run(
                  [op_matmul, mat_matmul], feed_dict=feed_dict)
              self.assertAC(op_matmul_v, mat_matmul_v)
# Example #2
 def testNonBatchMatrix(self):
     """matrix_adjoint of a rank-2 complex matrix transposes and conjugates."""
     a = [[1, 2, 3j], [4, 5, -6j]]  # Shape (2, 3)
     expected = [[1, 4], [2, 5], [-3j, 6j]]  # Shape (3, 2)
     with self.test_session():
         adjointed = linear_operator_util.matrix_adjoint(a)
         self.assertEqual((3, 2), adjointed.get_shape())
         self.assertAllClose(expected, adjointed.eval())
 def test_solve(self):
   """Checks operator.solve against a dense linalg_ops.matrix_solve."""
   self._skip_if_tests_to_skip_contains("solve")
   for use_placeholder in (False, True):
     for shape in self._shapes_to_test:
       for dtype in self._dtypes_to_test:
         for adjoint in (False, True):
           for adjoint_arg in (False, True):
             with self.test_session(graph=ops.Graph()) as sess:
               sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
               operator, mat, feed_dict = (
                   self._operator_and_mat_and_feed_dict(
                       shape, dtype, use_placeholder=use_placeholder))
               rhs = self._make_rhs(operator, adjoint=adjoint)
               # When adjoint_arg, the operator is handed rhs^H and told to
               # adjoint it back, so it solves A X = (rhs^H)^H = rhs.
               solve_arg = (
                   linear_operator_util.matrix_adjoint(rhs)
                   if adjoint_arg else rhs)
               op_solve = operator.solve(
                   solve_arg, adjoint=adjoint, adjoint_arg=adjoint_arg)
               mat_solve = linalg_ops.matrix_solve(mat, rhs, adjoint=adjoint)
               if not use_placeholder:
                 self.assertAllEqual(
                     op_solve.get_shape(), mat_solve.get_shape())
               op_solve_v, mat_solve_v = sess.run(
                   [op_solve, mat_solve], feed_dict=feed_dict)
               self.assertAC(op_solve_v, mat_solve_v)
 def _matmul(self, x, adjoint=False, adjoint_arg=False):
     """Identity matmul: return x (adjointed on request), batch-broadcast."""
     # adjoint is ignored: this matrix is self-adjoint.
     if adjoint_arg:
         x = linear_operator_util.matrix_adjoint(x)
     if self._assert_proper_shapes:
         shape_check = linear_operator_util.assert_compatible_matrix_dimensions(
             self, x)
         x = control_flow_ops.with_dependencies([shape_check], x)
     return self._possibly_broadcast_batch_shape(x)
# Example #5
 def _assert_self_adjoint(self):
     """Default self-adjoint check: compare the dense matrix to its adjoint."""
     logging.warn(
         "Using (possibly slow) default implementation of assert_self_adjoint."
         "  Requires conversion to a dense matrix.")
     matrix = self._get_cached_dense_matrix()
     return check_ops.assert_equal(
         matrix,
         linear_operator_util.matrix_adjoint(matrix),
         message="Matrix was not equal to its adjoint.")
# Example #6
 def testBatchMatrix(self):
     """matrix_adjoint adjoints each matrix in a batch independently."""
     batch_matrix = [
         [[1j, 2, 3], [4, 5, 6]],
         [[11, 22, 33], [44, 55, 66j]],
     ]  # Shape (2, 2, 3)
     expected_adj = [
         [[-1j, 4], [2, 5], [3, 6]],
         [[11, 44], [22, 55], [33, -66j]],
     ]  # Shape (2, 3, 2)
     with self.test_session():
         adjointed = linear_operator_util.matrix_adjoint(batch_matrix)
         self.assertEqual((2, 3, 2), adjointed.get_shape())
         self.assertAllEqual(expected_adj, adjointed.eval())
 def _solve(self, rhs, adjoint=False, adjoint_arg=False):
     """Solve by elementwise division with the (conjugated) multiplier."""
     if adjoint_arg:
         rhs = linear_operator_util.matrix_adjoint(rhs)
     # Adjoint of a diagonal-like multiplier is its elementwise conjugate.
     matrix = (
         self._multiplier_matrix_conj if adjoint else self._multiplier_matrix)
     if self._assert_proper_shapes:
         shape_check = linear_operator_util.assert_compatible_matrix_dimensions(
             self, rhs)
         rhs = control_flow_ops.with_dependencies([shape_check], rhs)
     return rhs / matrix
 def _matmul(self, x, adjoint=False, adjoint_arg=False):
     """Multiply x elementwise by the (conjugated) multiplier matrix."""
     if adjoint_arg:
         x = linear_operator_util.matrix_adjoint(x)
     # Adjoint of a diagonal-like multiplier is its elementwise conjugate.
     matrix = (
         self._multiplier_matrix_conj if adjoint else self._multiplier_matrix)
     if self._assert_proper_shapes:
         shape_check = linear_operator_util.assert_compatible_matrix_dimensions(
             self, x)
         x = control_flow_ops.with_dependencies([shape_check], x)
     return x * matrix
# Example #9
 def _solve(self, rhs, adjoint=False, adjoint_arg=False):
     """Default implementation of _solve via a dense-matrix solve."""
     # `is False` on purpose: is_square may be None (unknown), which must
     # fall through to the attempt below.
     if self.is_square is False:
         raise NotImplementedError(
             "Solve is not yet implemented for non-square operators.")
     logging.warn(
         "Using (possibly slow) default implementation of solve."
         "  Requires conversion to a dense matrix and O(N^3) operations.")
     if adjoint_arg:
         rhs = linear_operator_util.matrix_adjoint(rhs)
     if self._can_use_cholesky():
         # NOTE(review): adjoint is not forwarded here — presumably the
         # Cholesky path implies a self-adjoint operator; confirm.
         return linalg_ops.cholesky_solve(self._get_cached_chol(), rhs)
     return linalg_ops.matrix_solve(
         self._get_cached_dense_matrix(), rhs, adjoint=adjoint)
# Example #10
 def _solve(self, rhs, adjoint=False, adjoint_arg=False):
   """Solve against the lower-triangular factor via back-substitution."""
   if adjoint_arg:
     rhs = linear_operator_util.matrix_adjoint(rhs)
   return linalg_ops.matrix_triangular_solve(
       self._tril, rhs, lower=True, adjoint=adjoint)
# Example #11
 def _solve(self, rhs, adjoint=False, adjoint_arg=False):
   """Solve the diagonal system by multiplying rhs with 1/diag."""
   diag = math_ops.conj(self._diag) if adjoint else self._diag
   if adjoint_arg:
     rhs = linear_operator_util.matrix_adjoint(rhs)
   # Expand to a column so the reciprocal broadcasts across rhs columns.
   reciprocal_diag = array_ops.expand_dims(1. / diag, -1)
   return rhs * reciprocal_diag
# Example #12
 def _matmul(self, x, adjoint=False, adjoint_arg=False):
   """Scale the rows of x by the diagonal (conjugated when adjoint)."""
   diag = math_ops.conj(self._diag) if adjoint else self._diag
   if adjoint_arg:
     x = linear_operator_util.matrix_adjoint(x)
   # Expand to a column so the diagonal broadcasts across x's columns.
   return array_ops.expand_dims(diag, -1) * x