Example #1
0
def assert_compatible_matrix_dimensions(operator, x):
  """Build a runtime assertion that `x` fits the operator's domain.

  With `operator.shape[-2:] == [M, N]` and `x.shape[-2:] == [Q, R]`, the
  product `operator.matmul(x)` is only defined when `N == Q`.  The returned
  `Assert` op fires when that equality does not hold.  Static shape checks
  are already performed by the `LinearOperator` base class, so only the
  dynamic (tensor-valued) comparison is built here.

  Args:
    operator:  `LinearOperator`.
    x:  `Tensor` passed as the argument to `solve`/`matmul`.

  Returns:
    The `Assert` `Op` produced by `check_ops.assert_equal`.
  """
  # Dynamic size of the second-to-last axis of `x` (Q in the docstring).
  x_rows = array_ops.shape(x)[-2]
  # Dynamic domain dimension of the operator (N in the docstring).
  op_domain_dim = operator.domain_dimension_tensor()
  # Wording intentionally mirrors the error raised by the static check in
  # the base class.
  incompatible_msg = ("Dimensions are not compatible.  "
                      "shape[-2] of argument to be the same as this operator")
  return check_ops.assert_equal(x_rows, op_domain_dim, message=incompatible_msg)
Example #2
0
 def _assert_self_adjoint(self):
     """Return an op asserting the dense matrix equals its adjoint.

     Default (fallback) implementation: materializes the operator with
     `self.to_dense()` and compares it against `linalg.adjoint` of that
     dense matrix via `check_ops.assert_equal`.

     Returns:
       The `Op` produced by `check_ops.assert_equal`.
     """
     # Emit the warning before the potentially expensive densification so it
     # precedes the slow work it describes.  `warning` is used instead of the
     # deprecated `warn` alias.
     logging.warning(
         "Using (possibly slow) default implementation of assert_self_adjoint."
         "  Requires conversion to a dense matrix.")
     dense = self.to_dense()
     return check_ops.assert_equal(
         dense,
         linalg.adjoint(dense),
         message="Matrix was not equal to its adjoint.")
Example #3
0
def assert_zero_imag_part(x, message=None, name="assert_zero_imag_part"):
  """Returns `Op` that asserts Tensor `x` has no non-zero imaginary parts.

  Args:
    x:  Numeric `Tensor`, real, integer, or complex.
    message:  A string message to prepend to failure message.
    name:  A name to give this `Op`.

  Returns:
    An `Op` that asserts every entry of `x` has zero imaginary part.  For
    non-complex dtypes this is a `no_op`, since such values cannot have an
    imaginary component.
  """
  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    dtype = x.dtype

    # Only complex tensors can carry a non-zero imaginary part.  Previously
    # only floating dtypes short-circuited, so integer/bool inputs still built
    # a (trivially true, or failing for bool) runtime comparison.
    if not dtype.is_complex:
      return control_flow_ops.no_op()

    zero = ops.convert_to_tensor(0, dtype=dtypes.real_dtype(dtype))
    return check_ops.assert_equal(zero, math_ops.imag(x), message=message)
 def _assert_self_adjoint(self):
     """Return an op that checks self-adjointness via `row == col`.

     Delegates to `check_ops.assert_equal`, comparing `self.row` against
     `self.col`; a mismatch means the operator is not self-adjoint.
     """
     failure_msg = ("row and col are not the same, and "
                    "so this operator is not self-adjoint.")
     first, second = self.row, self.col
     return check_ops.assert_equal(first, second, message=failure_msg)
 def _assert_self_adjoint(self):
     """Return an op that checks the multiplier has zero imaginary part.

     Self-adjointness is asserted by requiring `imag(self.multiplier)` to
     equal a zeros tensor of the same shape and dtype.
     """
     mult_imag = math_ops.imag(self.multiplier)
     all_zero = array_ops.zeros_like(mult_imag)
     return check_ops.assert_equal(
         all_zero, mult_imag, message="LinearOperator was not self-adjoint")