Code example #1
  def test_static_dims_broadcast_rhs_has_extra_dims_dynamic(self):
    # Since the second arg has extra dims, and the domain dim of the first arg
    # is larger than the number of linear equations, code will "flip" the extra
    # dims of the second arg to the far right, making extra linear equations
    # (then call the matrix function, then flip back).
    # We have verified that this optimization indeed happens.  How? We stepped
    # through with a debugger.
    # batch_shape = [2]
    matrix = rng.rand(3, 3)
    rhs = rng.rand(2, 3, 2)
    matrix_broadcast = matrix + np.zeros((2, 1, 1))

    matrix_ph = array_ops.placeholder(dtypes.float64, shape=[None, None])
    rhs_ph = array_ops.placeholder(dtypes.float64, shape=[None, None, None])

    with self.cached_session():
      result = linear_operator_util.matrix_solve_with_broadcast(matrix_ph,
                                                                rhs_ph)
      self.assertAllEqual(3, result.shape.ndims)
      expected = linalg_ops.matrix_solve(matrix_broadcast, rhs)
      self.assertAllClose(
          expected.eval(),
          result.eval(feed_dict={
              matrix_ph: matrix,
              rhs_ph: rhs
          }))
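
The broadcast being asserted here can be reproduced outside TensorFlow. Below is a minimal NumPy sketch of the same semantics (variable names are illustrative, not part of the test file): the batch-less matrix is broadcast against the batch dims of the rhs, and each system is solved.

import numpy as np

rng = np.random.RandomState(0)
matrix = rng.rand(3, 3)    # no batch dims
rhs = rng.rand(2, 3, 2)    # batch_shape = [2]

# Broadcast the matrix against the rhs batch dims, then solve each system.
matrix_broadcast = matrix + np.zeros((2, 1, 1))    # shape (2, 3, 3)
expected = np.linalg.solve(matrix_broadcast, rhs)  # shape (2, 3, 2)
assert expected.shape == (2, 3, 2)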
Code example #2
def _test_solve_base(self, use_placeholder, shapes_info, dtype, adjoint,
                     adjoint_arg, with_batch):
    # If batch dimensions are omitted, but there are
    # no batch dimensions for the linear operator, then
    # skip the test case. This is already checked with
    # with_batch=True.
    if not with_batch and len(shapes_info.shape) <= 2:
        return
    with self.session(graph=ops.Graph()) as sess:
        sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
        operator, mat = self.operator_and_matrix(
            shapes_info, dtype, use_placeholder=use_placeholder)
        rhs = self.make_rhs(operator, adjoint=adjoint, with_batch=with_batch)
        # If adjoint_arg, solve A X = (rhs^H)^H = rhs.
        if adjoint_arg:
            op_solve = operator.solve(linalg.adjoint(rhs),
                                      adjoint=adjoint,
                                      adjoint_arg=adjoint_arg)
        else:
            op_solve = operator.solve(rhs,
                                      adjoint=adjoint,
                                      adjoint_arg=adjoint_arg)
        mat_solve = linear_operator_util.matrix_solve_with_broadcast(
            mat, rhs, adjoint=adjoint)
        if not use_placeholder:
            self.assertAllEqual(op_solve.shape, mat_solve.shape)
        op_solve_v, mat_solve_v = sess.run([op_solve, mat_solve])
        self.assertAC(op_solve_v, mat_solve_v)
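
For reference, the property this test checks against matrix_solve_with_broadcast(mat, rhs, adjoint=adjoint) can be stated in plain NumPy: with adjoint=True, the system solved is A^H x = rhs. A minimal illustrative sketch, not part of the test harness:

import numpy as np

rng = np.random.RandomState(0)
a = rng.rand(3, 3) + 1j * rng.rand(3, 3)
rhs = rng.rand(3, 2) + 1j * rng.rand(3, 2)

# adjoint=True: solve A^H x = rhs, i.e. use the conjugate transpose of A.
x = np.linalg.solve(a.conj().T, rhs)
assert np.allclose(a.conj().T @ x, rhs)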
Code example #3
 def _test_solve(self, with_batch):
   for use_placeholder in self._use_placeholder_options:
     for build_info in self._operator_build_infos:
       for dtype in self._dtypes_to_test:
         for adjoint in self._adjoint_options:
           for adjoint_arg in self._adjoint_arg_options:
             with self.test_session(graph=ops.Graph()) as sess:
               sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
               operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
                   build_info, dtype, use_placeholder=use_placeholder)
               rhs = self._make_rhs(
                   operator, adjoint=adjoint, with_batch=with_batch)
               # If adjoint_arg, solve A X = (rhs^H)^H = rhs.
               if adjoint_arg:
                 op_solve = operator.solve(
                     linalg.adjoint(rhs),
                     adjoint=adjoint,
                     adjoint_arg=adjoint_arg)
               else:
                 op_solve = operator.solve(
                     rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
               mat_solve = linear_operator_util.matrix_solve_with_broadcast(
                   mat, rhs, adjoint=adjoint)
               if not use_placeholder:
                 self.assertAllEqual(op_solve.get_shape(),
                                     mat_solve.get_shape())
               op_solve_v, mat_solve_v = sess.run(
                   [op_solve, mat_solve], feed_dict=feed_dict)
               self.assertAC(op_solve_v, mat_solve_v)
Code example #4
    def test_static_dims_broadcast_rhs_has_extra_dims_dynamic(self):
        # Since the second arg has extra dims, and the domain dim of the first arg
        # is larger than the number of linear equations, code will "flip" the extra
        # dims of the second arg to the far right, making extra linear equations
        # (then call the matrix function, then flip back).
        # We have verified that this optimization indeed happens.  How? We stepped
        # through with a debugger.
        # batch_shape = [2]
        matrix = rng.rand(3, 3)
        rhs = rng.rand(2, 3, 2)
        matrix_broadcast = matrix + np.zeros((2, 1, 1))

        matrix_ph = array_ops.placeholder(dtypes.float64, shape=[None, None])
        rhs_ph = array_ops.placeholder(dtypes.float64,
                                       shape=[None, None, None])

        with self.cached_session():
            result = linear_operator_util.matrix_solve_with_broadcast(
                matrix_ph, rhs_ph)
            self.assertAllEqual(3, result.shape.ndims)
            expected = linalg_ops.matrix_solve(matrix_broadcast, rhs)
            self.assertAllClose(
                self.evaluate(expected),
                result.eval(feed_dict={
                    matrix_ph: matrix,
                    rhs_ph: rhs
                }))
Code example #5
 def _test_solve(self, with_batch):
   for use_placeholder in self._use_placeholder_options:
     for build_info in self._operator_build_infos:
       # If batch dimensions are omitted, but there are
       # no batch dimensions for the linear operator, then
       # skip the test case. This is already checked with
       # with_batch=True.
       if not with_batch and len(build_info.shape) <= 2:
         continue
       for dtype in self._dtypes_to_test:
         for adjoint in self._adjoint_options:
           for adjoint_arg in self._adjoint_arg_options:
             with self.test_session(graph=ops.Graph()) as sess:
               sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
               operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
                   build_info, dtype, use_placeholder=use_placeholder)
               rhs = self._make_rhs(
                   operator, adjoint=adjoint, with_batch=with_batch)
               # If adjoint_arg, solve A X = (rhs^H)^H = rhs.
               if adjoint_arg:
                 op_solve = operator.solve(
                     linalg.adjoint(rhs),
                     adjoint=adjoint,
                     adjoint_arg=adjoint_arg)
               else:
                 op_solve = operator.solve(
                     rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
               mat_solve = linear_operator_util.matrix_solve_with_broadcast(
                   mat, rhs, adjoint=adjoint)
               if not use_placeholder:
                 self.assertAllEqual(op_solve.get_shape(),
                                     mat_solve.get_shape())
               op_solve_v, mat_solve_v = sess.run(
                   [op_solve, mat_solve], feed_dict=feed_dict)
               self.assertAC(op_solve_v, mat_solve_v)
Code example #6
 def _test_solve(self, with_batch):
     for use_placeholder in self._use_placeholder_options:
         for build_info in self._operator_build_infos:
             for dtype in self._dtypes_to_test:
                 for adjoint in self._adjoint_options:
                     for adjoint_arg in self._adjoint_arg_options:
                         with self.test_session(graph=ops.Graph()) as sess:
                             sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
                             operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
                                 build_info,
                                 dtype,
                                 use_placeholder=use_placeholder)
                             rhs = self._make_rhs(operator,
                                                  adjoint=adjoint,
                                                  with_batch=with_batch)
                             # If adjoint_arg, solve A X = (rhs^H)^H = rhs.
                             if adjoint_arg:
                                 op_solve = operator.solve(
                                     linalg.adjoint(rhs),
                                     adjoint=adjoint,
                                     adjoint_arg=adjoint_arg)
                             else:
                                 op_solve = operator.solve(
                                     rhs,
                                     adjoint=adjoint,
                                     adjoint_arg=adjoint_arg)
                             mat_solve = linear_operator_util.matrix_solve_with_broadcast(
                                 mat, rhs, adjoint=adjoint)
                             if not use_placeholder:
                                 self.assertAllEqual(
                                     op_solve.get_shape(),
                                     mat_solve.get_shape())
                             op_solve_v, mat_solve_v = sess.run(
                                 [op_solve, mat_solve], feed_dict=feed_dict)
                             self.assertAC(op_solve_v, mat_solve_v)
Code example #7
def _test_solve_base(self, use_placeholder, shapes_info, dtype, adjoint,
                     adjoint_arg, blockwise_arg, with_batch):
    # If batch dimensions are omitted, but there are
    # no batch dimensions for the linear operator, then
    # skip the test case. This is already checked with
    # with_batch=True.
    if not with_batch and len(shapes_info.shape) <= 2:
        return
    with self.session(graph=ops.Graph()) as sess:
        sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
        operator, mat = self.operator_and_matrix(
            shapes_info, dtype, use_placeholder=use_placeholder)
        rhs = self.make_rhs(operator, adjoint=adjoint, with_batch=with_batch)
        # If adjoint_arg, solve A X = (rhs^H)^H = rhs.
        if adjoint_arg:
            op_solve = operator.solve(linalg.adjoint(rhs),
                                      adjoint=adjoint,
                                      adjoint_arg=adjoint_arg)
        else:
            op_solve = operator.solve(rhs,
                                      adjoint=adjoint,
                                      adjoint_arg=adjoint_arg)
        mat_solve = linear_operator_util.matrix_solve_with_broadcast(
            mat, rhs, adjoint=adjoint)
        if not use_placeholder:
            self.assertAllEqual(op_solve.shape, mat_solve.shape)

        # If the operator is blockwise, test both blockwise rhs and `Tensor` rhs;
        # else test only `Tensor` rhs. In both cases, evaluate all results in a
        # single `sess.run` call to avoid re-sampling the random rhs in graph mode.
        if blockwise_arg and len(operator.operators) > 1:
            split_rhs = linear_operator_util.split_arg_into_blocks(
                operator._block_domain_dimensions(),  # pylint: disable=protected-access
                operator._block_domain_dimension_tensors,  # pylint: disable=protected-access
                rhs,
                axis=-2)
            if adjoint_arg:
                split_rhs = [linalg.adjoint(y) for y in split_rhs]
            split_solve = operator.solve(split_rhs,
                                         adjoint=adjoint,
                                         adjoint_arg=adjoint_arg)
            self.assertEqual(len(split_solve), len(operator.operators))
            split_solve = linear_operator_util.broadcast_matrix_batch_dims(
                split_solve)
            fused_block_solve = array_ops.concat(split_solve, axis=-2)
            op_solve_v, mat_solve_v, fused_block_solve_v = sess.run(
                [op_solve, mat_solve, fused_block_solve])

            # Check that the operator and matrix give the same solution when the rhs
            # is blockwise.
            self.assertAC(mat_solve_v, fused_block_solve_v)
        else:
            op_solve_v, mat_solve_v = sess.run([op_solve, mat_solve])

        # Check that the operator and matrix give the same solution when the rhs is
        # a `Tensor`.
        self.assertAC(op_solve_v, mat_solve_v)
Code example #8
    def test_static_dims_broadcast_matrix_has_extra_dims(self):
        # batch_shape = [2]
        matrix = rng.rand(2, 3, 3)
        rhs = rng.rand(3, 7)
        rhs_broadcast = rhs + np.zeros((2, 1, 1))

        result = linear_operator_util.matrix_solve_with_broadcast(matrix, rhs)
        self.assertAllEqual((2, 3, 7), result.shape)
        expected = linalg_ops.matrix_solve(matrix, rhs_broadcast)
        self.assertAllClose(*self.evaluate([expected, result]))
Code example #9
 def _dense_solve(self, rhs, adjoint=False, adjoint_arg=False):
     """Solve by conversion to a dense matrix."""
     if self.is_square is False:  # pylint: disable=g-bool-id-comparison
         raise NotImplementedError(
             "Solve is not yet implemented for non-square operators.")
     rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
     if self._can_use_cholesky():
         return linalg_ops.cholesky_solve(
             linalg_ops.cholesky(self.to_dense()), rhs)
     return linear_operator_util.matrix_solve_with_broadcast(
         self.to_dense(), rhs, adjoint=adjoint)
Code example #10
  def test_static_dims_broadcast(self):
    # batch_shape = [2]
    matrix = rng.rand(3, 3)
    rhs = rng.rand(2, 3, 7)
    matrix_broadcast = matrix + np.zeros((2, 1, 1))

    with self.cached_session():
      result = linear_operator_util.matrix_solve_with_broadcast(matrix, rhs)
      self.assertAllEqual((2, 3, 7), result.get_shape())
      expected = linalg_ops.matrix_solve(matrix_broadcast, rhs)
      self.assertAllClose(expected.eval(), result.eval())
Code example #11
 def _solve(self, rhs, adjoint=False, adjoint_arg=False):
   """Default implementation of _solve."""
   if self.is_square is False:
     raise NotImplementedError(
         "Solve is not yet implemented for non-square operators.")
   logging.warn(
       "Using (possibly slow) default implementation of solve."
       "  Requires conversion to a dense matrix and O(N^3) operations.")
   rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
   if self._can_use_cholesky():
     return linear_operator_util.cholesky_solve_with_broadcast(
         linalg_ops.cholesky(self.to_dense()), rhs)
   return linear_operator_util.matrix_solve_with_broadcast(
       self.to_dense(), rhs, adjoint=adjoint)
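
The Cholesky branch above uses a standard identity: for a positive-definite matrix, one Cholesky factorization followed by two triangular solves gives the same answer as a general solve, at lower cost. A small NumPy sketch of that identity (illustrative only; assumes a symmetric positive-definite matrix):

import numpy as np

rng = np.random.RandomState(0)
a = rng.rand(4, 4)
spd = a @ a.T + 4.0 * np.eye(4)   # symmetric positive definite
rhs = rng.rand(4, 2)

chol = np.linalg.cholesky(spd)       # spd = chol @ chol.T
y = np.linalg.solve(chol, rhs)       # forward substitution
x_chol = np.linalg.solve(chol.T, y)  # back substitution
assert np.allclose(x_chol, np.linalg.solve(spd, rhs))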
Code example #12
 def _solve(self, rhs, adjoint=False, adjoint_arg=False):
     """Default implementation of _solve."""
     if self.is_square is False:
         raise NotImplementedError(
             "Solve is not yet implemented for non-square operators.")
     logging.warn(
         "Using (possibly slow) default implementation of solve."
         "  Requires conversion to a dense matrix and O(N^3) operations.")
     rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
     if self._can_use_cholesky():
         return linear_operator_util.cholesky_solve_with_broadcast(
             linalg_ops.cholesky(self.to_dense()), rhs)
     return linear_operator_util.matrix_solve_with_broadcast(
         self.to_dense(), rhs, adjoint=adjoint)
Code example #13
  def test_dynamic_dims_broadcast_64bit(self):
    # batch_shape = [2, 2]
    matrix = rng.rand(2, 3, 3)
    rhs = rng.rand(2, 1, 3, 7)
    matrix_broadcast = matrix + np.zeros((2, 2, 1, 1))
    rhs_broadcast = rhs + np.zeros((2, 2, 1, 1))

    matrix_ph = array_ops.placeholder_with_default(matrix, shape=None)
    rhs_ph = array_ops.placeholder_with_default(rhs, shape=None)

    result, expected = self.evaluate([
        linear_operator_util.matrix_solve_with_broadcast(matrix_ph, rhs_ph),
        linalg_ops.matrix_solve(matrix_broadcast, rhs_broadcast)
    ])
    self.assertAllClose(expected, result)
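
The shapes in this test broadcast on both sides: the matrix contributes batch dims [2] and the rhs contributes [2, 1], which broadcast to batch_shape = [2, 2]. A short NumPy sketch of the same shape arithmetic (illustrative names only):

import numpy as np

rng = np.random.RandomState(0)
matrix = rng.rand(2, 3, 3)
rhs = rng.rand(2, 1, 3, 7)

# Pad both operands so they share batch_shape [2, 2], as the test does.
matrix_broadcast = matrix + np.zeros((2, 2, 1, 1))  # shape (2, 2, 3, 3)
rhs_broadcast = rhs + np.zeros((2, 2, 1, 1))        # shape (2, 2, 3, 7)
expected = np.linalg.solve(matrix_broadcast, rhs_broadcast)
assert expected.shape == (2, 2, 3, 7)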
Code example #14
    def test_static_dims_broadcast_rhs_has_extra_dims(self):
        # Since the second arg has extra dims, and the domain dim of the first arg
        # is larger than the number of linear equations, code will "flip" the extra
        # dims of the second arg to the far right, making extra linear equations
        # (then call the matrix function, then flip back).
        # We have verified that this optimization indeed happens.  How? We stepped
        # through with a debugger.
        # batch_shape = [2]
        matrix = rng.rand(3, 3)
        rhs = rng.rand(2, 3, 2)
        matrix_broadcast = matrix + np.zeros((2, 1, 1))

        result = linear_operator_util.matrix_solve_with_broadcast(matrix, rhs)
        self.assertAllEqual((2, 3, 2), result.shape)
        expected = linalg_ops.matrix_solve(matrix_broadcast, rhs)
        self.assertAllClose(*self.evaluate([expected, result]))
Code example #15
    def _solve(self, rhs, adjoint=False, adjoint_arg=False):
        if self.base_operator.is_non_singular is False:
            raise ValueError(
                "Solve not implemented unless this is a perturbation of a "
                "non-singular LinearOperator.")
        # The Woodbury formula gives:
        # https://en.wikipedia.org/wiki/Woodbury_matrix_identity
        #   (L + UDV^H)^{-1}
        #   = L^{-1} - L^{-1} U (D^{-1} + V^H L^{-1} U)^{-1} V^H L^{-1}
        #   = L^{-1} - L^{-1} U C^{-1} V^H L^{-1}
        # where C is the capacitance matrix, C := D^{-1} + V^H L^{-1} U
        # Note also that, with ^{-H} being the inverse of the adjoint,
        #   (L + UDV^H)^{-H}
        #   = L^{-H} - L^{-H} V C^{-H} U^H L^{-H}
        l = self.base_operator
        if adjoint:
            # If adjoint, U and V have flipped roles in the operator.
            v, u = self._get_uv_as_tensors()
            # Capacitance should still be computed with u=self.u and v=self.v, which
            # after the "flip" on the line above means u=v, v=u. I.e. no need to
            # "flip" in the capacitance call, since the call to
            # matrix_solve_with_broadcast below is done with the `adjoint` argument,
            # and this takes care of things.
            capacitance = self._make_capacitance(u=v, v=u)
        else:
            u, v = self._get_uv_as_tensors()
            capacitance = self._make_capacitance(u=u, v=v)

        # L^{-1} rhs
        linv_rhs = l.solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
        # V^H L^{-1} rhs
        vh_linv_rhs = math_ops.matmul(v, linv_rhs, adjoint_a=True)
        # C^{-1} V^H L^{-1} rhs
        if self._use_cholesky:
            capinv_vh_linv_rhs = linalg_ops.cholesky_solve(
                linalg_ops.cholesky(capacitance), vh_linv_rhs)
        else:
            capinv_vh_linv_rhs = linear_operator_util.matrix_solve_with_broadcast(
                capacitance, vh_linv_rhs, adjoint=adjoint)
        # U C^{-1} V^H L^{-1} rhs
        u_capinv_vh_linv_rhs = math_ops.matmul(u, capinv_vh_linv_rhs)
        # L^{-1} U C^{-1} V^H L^{-1} rhs
        linv_u_capinv_vh_linv_rhs = l.solve(u_capinv_vh_linv_rhs,
                                            adjoint=adjoint)

        # L^{-1} - L^{-1} U C^{-1} V^H L^{-1}
        return linv_rhs - linv_u_capinv_vh_linv_rhs
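
The Woodbury identity quoted in the comments above can be checked numerically. A minimal NumPy sketch, independent of the operator code (all names are illustrative; real-valued case, so ^H reduces to ^T):

import numpy as np

rng = np.random.RandomState(0)
n, k = 5, 2
l = np.diag(rng.rand(n) + 1.0)   # well-conditioned base matrix L
u = rng.rand(n, k)
v = rng.rand(n, k)
d = np.diag(rng.rand(k) + 1.0)   # diagonal update D

linv = np.linalg.inv(l)
capacitance = np.linalg.inv(d) + v.T @ linv @ u  # C = D^{-1} + V^H L^{-1} U
woodbury = linv - linv @ u @ np.linalg.inv(capacitance) @ v.T @ linv
assert np.allclose(woodbury, np.linalg.inv(l + u @ d @ v.T))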
Code example #16
  def test_static_dims_broadcast_rhs_has_extra_dims(self):
    # Since the second arg has extra dims, and the domain dim of the first arg
    # is larger than the number of linear equations, code will "flip" the extra
    # dims of the second arg to the far right, making extra linear equations
    # (then call the matrix function, then flip back).
    # We have verified that this optimization indeed happens.  How? We stepped
    # through with a debugger.
    # batch_shape = [2]
    matrix = rng.rand(3, 3)
    rhs = rng.rand(2, 3, 2)
    matrix_broadcast = matrix + np.zeros((2, 1, 1))

    with self.cached_session():
      result = linear_operator_util.matrix_solve_with_broadcast(matrix, rhs)
      self.assertAllEqual((2, 3, 2), result.get_shape())
      expected = linalg_ops.matrix_solve(matrix_broadcast, rhs)
      self.assertAllClose(expected.eval(), self.evaluate(result))
Code example #17
    def _solve(self, rhs, adjoint=False, adjoint_arg=False):
        if self.base_operator.is_non_singular is False:
            raise ValueError(
                "Solve not implemented unless this is a perturbation of a "
                "non-singular LinearOperator.")
        # The Woodbury formula gives:
        # https://en.wikipedia.org/wiki/Woodbury_matrix_identity
        #   (L + UDV^H)^{-1}
        #   = L^{-1} - L^{-1} U (D^{-1} + V^H L^{-1} U)^{-1} V^H L^{-1}
        #   = L^{-1} - L^{-1} U C^{-1} V^H L^{-1}
        # where C is the capacitance matrix, C := D^{-1} + V^H L^{-1} U
        # Note also that, with ^{-H} being the inverse of the adjoint,
        #   (L + UDV^H)^{-H}
        #   = L^{-H} - L^{-H} V C^{-H} U^H L^{-H}
        l = self.base_operator
        if adjoint:
            v = self.u
            u = self.v
        else:
            v = self.v
            u = self.u

        # L^{-1} rhs
        linv_rhs = l.solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
        # V^H L^{-1} rhs
        vh_linv_rhs = linear_operator_util.matmul_with_broadcast(
            v, linv_rhs, adjoint_a=True)
        # C^{-1} V^H L^{-1} rhs
        if self._use_cholesky:
            capinv_vh_linv_rhs = linear_operator_util.cholesky_solve_with_broadcast(
                self._chol_capacitance, vh_linv_rhs)
        else:
            capinv_vh_linv_rhs = linear_operator_util.matrix_solve_with_broadcast(
                self._capacitance, vh_linv_rhs, adjoint=adjoint)
        # U C^{-1} V^H L^{-1} rhs
        u_capinv_vh_linv_rhs = linear_operator_util.matmul_with_broadcast(
            u, capinv_vh_linv_rhs)
        # L^{-1} U C^{-1} V^H L^{-1} rhs
        linv_u_capinv_vh_linv_rhs = l.solve(u_capinv_vh_linv_rhs,
                                            adjoint=adjoint)

        # L^{-1} - L^{-1} U C^{-1} V^H L^{-1}
        return linv_rhs - linv_u_capinv_vh_linv_rhs
Code example #18
  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    if self.base_operator.is_non_singular is False:
      raise ValueError(
          "Solve not implemented unless this is a perturbation of a "
          "non-singular LinearOperator.")
    # The Woodbury formula gives:
    # https://en.wikipedia.org/wiki/Woodbury_matrix_identity
    #   (L + UDV^H)^{-1}
    #   = L^{-1} - L^{-1} U (D^{-1} + V^H L^{-1} U)^{-1} V^H L^{-1}
    #   = L^{-1} - L^{-1} U C^{-1} V^H L^{-1}
    # where C is the capacitance matrix, C := D^{-1} + V^H L^{-1} U
    # Note also that, with ^{-H} being the inverse of the adjoint,
    #   (L + UDV^H)^{-H}
    #   = L^{-H} - L^{-H} V C^{-H} U^H L^{-H}
    l = self.base_operator
    if adjoint:
      v = self.u
      u = self.v
    else:
      v = self.v
      u = self.u

    # L^{-1} rhs
    linv_rhs = l.solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
    # V^H L^{-1} rhs
    vh_linv_rhs = linear_operator_util.matmul_with_broadcast(
        v, linv_rhs, adjoint_a=True)
    # C^{-1} V^H L^{-1} rhs
    if self._use_cholesky:
      capinv_vh_linv_rhs = linear_operator_util.cholesky_solve_with_broadcast(
          self._chol_capacitance, vh_linv_rhs)
    else:
      capinv_vh_linv_rhs = linear_operator_util.matrix_solve_with_broadcast(
          self._capacitance, vh_linv_rhs, adjoint=adjoint)
    # U C^{-1} V^H L^{-1} rhs
    u_capinv_vh_linv_rhs = linear_operator_util.matmul_with_broadcast(
        u, capinv_vh_linv_rhs)
    # L^{-1} U C^{-1} V^H L^{-1} rhs
    linv_u_capinv_vh_linv_rhs = l.solve(u_capinv_vh_linv_rhs, adjoint=adjoint)

    # L^{-1} - L^{-1} U C^{-1} V^H L^{-1}
    return linv_rhs - linv_u_capinv_vh_linv_rhs
Code example #19
 def _test_solve(self, with_batch):
     for use_placeholder in self._use_placeholder_options:
         for build_info in self._operator_build_infos:
             # If batch dimensions are omitted, but there are
             # no batch dimensions for the linear operator, then
             # skip the test case. This is already checked with
             # with_batch=True.
             if not with_batch and len(build_info.shape) <= 2:
                 continue
             for dtype in self._dtypes_to_test:
                 for adjoint in self._adjoint_options:
                     for adjoint_arg in self._adjoint_arg_options:
                         with self.test_session(graph=ops.Graph()) as sess:
                             sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
                             operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
                                 build_info,
                                 dtype,
                                 use_placeholder=use_placeholder)
                             rhs = self._make_rhs(operator,
                                                  adjoint=adjoint,
                                                  with_batch=with_batch)
                             # If adjoint_arg, solve A X = (rhs^H)^H = rhs.
                             if adjoint_arg:
                                 op_solve = operator.solve(
                                     linalg.adjoint(rhs),
                                     adjoint=adjoint,
                                     adjoint_arg=adjoint_arg)
                             else:
                                 op_solve = operator.solve(
                                     rhs,
                                     adjoint=adjoint,
                                     adjoint_arg=adjoint_arg)
                             mat_solve = linear_operator_util.matrix_solve_with_broadcast(
                                 mat, rhs, adjoint=adjoint)
                             if not use_placeholder:
                                 self.assertAllEqual(
                                     op_solve.get_shape(),
                                     mat_solve.get_shape())
                             op_solve_v, mat_solve_v = sess.run(
                                 [op_solve, mat_solve], feed_dict=feed_dict)
                             self.assertAC(op_solve_v, mat_solve_v)
Code example #20
    def test_dynamic_dims_broadcast_64bit(self):
        # batch_shape = [2, 2]
        matrix = rng.rand(2, 3, 3)
        rhs = rng.rand(2, 1, 3, 7)
        matrix_broadcast = matrix + np.zeros((2, 2, 1, 1))
        rhs_broadcast = rhs + np.zeros((2, 2, 1, 1))

        matrix_ph = array_ops.placeholder(dtypes.float64)
        rhs_ph = array_ops.placeholder(dtypes.float64)

        with self.cached_session() as sess:
            result, expected = sess.run(
                [
                    linear_operator_util.matrix_solve_with_broadcast(
                        matrix_ph, rhs_ph),
                    linalg_ops.matrix_solve(matrix_broadcast, rhs_broadcast)
                ],
                feed_dict={
                    matrix_ph: matrix,
                    rhs_ph: rhs,
                })
            self.assertAllClose(expected, result)
Code example #21
  def test_dynamic_dims_broadcast_64bit(self):
    # batch_shape = [2, 2]
    matrix = rng.rand(2, 3, 3)
    rhs = rng.rand(2, 1, 3, 7)
    matrix_broadcast = matrix + np.zeros((2, 2, 1, 1))
    rhs_broadcast = rhs + np.zeros((2, 2, 1, 1))

    matrix_ph = array_ops.placeholder(dtypes.float64)
    rhs_ph = array_ops.placeholder(dtypes.float64)

    with self.cached_session() as sess:
      result, expected = sess.run(
          [
              linear_operator_util.matrix_solve_with_broadcast(
                  matrix_ph, rhs_ph),
              linalg_ops.matrix_solve(matrix_broadcast, rhs_broadcast)
          ],
          feed_dict={
              matrix_ph: matrix,
              rhs_ph: rhs,
          })
      self.assertAllClose(expected, result)
Code example #22
def _test_solve_base(
    self,
    use_placeholder,
    shapes_info,
    dtype,
    adjoint,
    adjoint_arg,
    with_batch):
  # If batch dimensions are omitted, but there are
  # no batch dimensions for the linear operator, then
  # skip the test case. This is already checked with
  # with_batch=True.
  if not with_batch and len(shapes_info.shape) <= 2:
    return
  with self.session(graph=ops.Graph()) as sess:
    sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
    operator, mat = self.operator_and_matrix(
        shapes_info, dtype, use_placeholder=use_placeholder)
    rhs = self.make_rhs(
        operator, adjoint=adjoint, with_batch=with_batch)
    # If adjoint_arg, solve A X = (rhs^H)^H = rhs.
    if adjoint_arg:
      op_solve = operator.solve(
          linalg.adjoint(rhs),
          adjoint=adjoint,
          adjoint_arg=adjoint_arg)
    else:
      op_solve = operator.solve(
          rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
    mat_solve = linear_operator_util.matrix_solve_with_broadcast(
        mat, rhs, adjoint=adjoint)
    if not use_placeholder:
      self.assertAllEqual(op_solve.get_shape(),
                          mat_solve.get_shape())
    op_solve_v, mat_solve_v = sess.run([op_solve, mat_solve])
    self.assertAC(op_solve_v, mat_solve_v)