def test_incompatible_dimensions_raise(self):
    """Checks that running the compatibility assertion on a mismatched operand fails.

    The operand comes from `rng.rand(2, 4, 4)` (presumably shape (2, 4, 4)).
    The stub operator is constructed with 3 (presumably its domain dimension).
    Running the op returned by `assert_compatible_matrix_dimensions` must
    raise an op error whose message contains "Incompatible matrix dimensions".
    """
    with self.cached_session():
        mismatched = ops.convert_to_tensor(rng.rand(2, 4, 4))
        stub_operator = DomainDimensionStubOperator(3)
        with self.assertRaisesOpError("Incompatible matrix dimensions"):
            shape_check = (
                linear_operator_util.assert_compatible_matrix_dimensions(
                    stub_operator, mismatched))
            shape_check.run()
 def test_incompatible_dimensions_raise(self):
   """Checks that running the compatibility assertion on a mismatched operand fails.

   The operand comes from `rng.rand(2, 4, 4)` (presumably shape (2, 4, 4)).
   The stub operator is constructed with 3 (presumably its domain dimension).
   Running the op returned by `assert_compatible_matrix_dimensions` must
   raise an op error whose message contains "Incompatible matrix dimensions".
   """
   with self.cached_session():
     mismatched = ops.convert_to_tensor(rng.rand(2, 4, 4))
     stub_operator = DomainDimensionStubOperator(3)
     with self.assertRaisesOpError("Incompatible matrix dimensions"):
       shape_check = linear_operator_util.assert_compatible_matrix_dimensions(
           stub_operator, mismatched)
       shape_check.run()
 def test_compatible_dimensions_do_not_raise(self):
     """Checks that running the compatibility assertion on a matching operand succeeds.

     The operand comes from `rng.rand(2, 3, 4)` (presumably shape (2, 3, 4)).
     The stub operator is constructed with 3. Running the returned assertion
     op is expected to complete without raising.
     """
     with self.cached_session():
         compatible = ops.convert_to_tensor(rng.rand(2, 3, 4))
         stub_operator = DomainDimensionStubOperator(3)
         # Any exception raised here fails the test.
         shape_check = (
             linear_operator_util.assert_compatible_matrix_dimensions(
                 stub_operator, compatible))
         shape_check.run()
 def test_compatible_dimensions_do_not_raise(self):
   """Checks that running the compatibility assertion on a matching operand succeeds.

   The operand comes from `rng.rand(2, 3, 4)` (presumably shape (2, 3, 4)).
   The stub operator is constructed with 3. Running the returned assertion
   op is expected to complete without raising.
   """
   with self.cached_session():
     compatible = ops.convert_to_tensor(rng.rand(2, 3, 4))
     stub_operator = DomainDimensionStubOperator(3)
     # Any exception raised here fails the test.
     shape_check = linear_operator_util.assert_compatible_matrix_dimensions(
         stub_operator, compatible)
     shape_check.run()
 def _solve(self, rhs, adjoint=False, adjoint_arg=False):
     """Solves against this operator by dividing by its multiplier matrix.

     Args:
       rhs: Right-hand side tensor.
       adjoint: Forwarded as `conjugate=` to `self._make_multiplier_matrix`.
       adjoint_arg: If True, `rhs` is replaced by `linalg.adjoint(rhs)` first.

     Returns:
       `rhs` (optionally adjointed) divided by the tensor returned from
       `self._make_multiplier_matrix(conjugate=adjoint)`.
     """
     if adjoint_arg:
         rhs = linalg.adjoint(rhs)
     if self._assert_proper_shapes:
         # Make `rhs` depend on the dimension-compatibility assertion so the
         # check is executed before `rhs` is consumed.
         rhs = control_flow_ops.with_dependencies(
             [linear_operator_util.assert_compatible_matrix_dimensions(
                 self, rhs)],
             rhs)
     divisor = self._make_multiplier_matrix(conjugate=adjoint)
     return rhs / divisor
 def _matmul(self, x, adjoint=False, adjoint_arg=False):
     """Multiplies by this operator via its multiplier matrix.

     Args:
       x: Right-hand operand tensor.
       adjoint: Forwarded as `conjugate=` to `self._make_multiplier_matrix`.
       adjoint_arg: If True, `x` is replaced by `linalg.adjoint(x)` first.

     Returns:
       `x` (optionally adjointed) multiplied by the tensor returned from
       `self._make_multiplier_matrix(conjugate=adjoint)`.
     """
     if adjoint_arg:
         x = linalg.adjoint(x)
     if self._assert_proper_shapes:
         # Make `x` depend on the dimension-compatibility assertion so the
         # check is executed before `x` is consumed.
         x = control_flow_ops.with_dependencies(
             [linear_operator_util.assert_compatible_matrix_dimensions(
                 self, x)],
             x)
     factor = self._make_multiplier_matrix(conjugate=adjoint)
     return x * factor
Ejemplo n.º 7
0
  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    """Returns an all-zeros tensor with the shape of the requested product.

    Args:
      x: Right-hand operand tensor.
      adjoint: If True, use the adjoint of this operator. In the square case
        it has no effect; in the non-square case it selects `_num_columns`
        instead of `_num_rows` as the output row count.
      adjoint_arg: If True, the operand is the adjoint of `x`, i.e. the last
        two dimensions of `x` are swapped in the output shape.

    Returns:
      `array_ops.zeros` of dtype `x.dtype` with the product's shape, passed
      through `self._possibly_broadcast_batch_shape`.
    """
    if self._assert_proper_shapes:
      if adjoint_arg:
        # The shape check must see the operand actually being multiplied, so
        # materialize its adjoint. `x` is now already in its effective
        # orientation. Clear `adjoint_arg` so the output-shape logic below
        # does not swap the last two dimensions a second time, which would
        # otherwise yield a transposed (wrong) output shape.
        x = linalg.adjoint(x)
        adjoint_arg = False
      aps = linear_operator_util.assert_compatible_matrix_dimensions(self, x)
      x = control_flow_ops.with_dependencies([aps], x)
    if self.is_square:
      # Note that adjoint has no effect since this matrix is self-adjoint.
      if adjoint_arg:
        output_shape = array_ops.concat([
            array_ops.shape(x)[:-2],
            [array_ops.shape(x)[-1], array_ops.shape(x)[-2]]], axis=0)
      else:
        output_shape = array_ops.shape(x)

      return self._possibly_broadcast_batch_shape(
          array_ops.zeros(shape=output_shape, dtype=x.dtype))

    # Non-square: output rows come from the operator (columns when adjoint),
    # output columns from the effective operand's last dimension.
    x_shape = array_ops.shape(x)
    n = self._num_columns if adjoint else self._num_rows
    m = x_shape[-2] if adjoint_arg else x_shape[-1]

    output_shape = array_ops.concat([x_shape[:-2], [n, m]], axis=0)

    zeros = array_ops.zeros(shape=output_shape, dtype=x.dtype)
    return self._possibly_broadcast_batch_shape(zeros)
 def _matmul(self, x, adjoint=False, adjoint_arg=False):
   """Returns the operand itself, broadcast to this operator's batch shape.

   Args:
     x: Right-hand operand tensor.
     adjoint: Unused; this matrix is self-adjoint.
     adjoint_arg: If True, the operand is `linalg.adjoint(x)`.

   Returns:
     The (optionally adjointed) operand passed through
     `self._possibly_broadcast_batch_shape`.
   """
   operand = x
   if adjoint_arg:
     operand = linalg.adjoint(operand)
   if self._assert_proper_shapes:
     # Tie the operand to the dimension-compatibility assertion.
     operand = control_flow_ops.with_dependencies(
         [linear_operator_util.assert_compatible_matrix_dimensions(
             self, operand)],
         operand)
   return self._possibly_broadcast_batch_shape(operand)
Ejemplo n.º 9
0
    def _matmul(self, x, adjoint=False, adjoint_arg=False):
        """Returns an all-zeros tensor with the shape of the requested product.

        Args:
          x: Right-hand operand tensor.
          adjoint: If True, use the adjoint of this operator. In the square
            case it has no effect; in the non-square case it selects
            `_num_columns` instead of `_num_rows` as the output row count.
          adjoint_arg: If True, the operand is the adjoint of `x`, i.e. the
            last two dimensions of `x` are swapped in the output shape.

        Returns:
          `array_ops.zeros` of dtype `x.dtype` with the product's shape,
          passed through `self._possibly_broadcast_batch_shape`.
        """
        if self._assert_proper_shapes:
            if adjoint_arg:
                # The shape check must see the operand actually being
                # multiplied, so materialize its adjoint. `x` is now already
                # in its effective orientation. Clear `adjoint_arg` so the
                # output-shape logic below does not swap the last two
                # dimensions a second time, which would otherwise yield a
                # transposed (wrong) output shape.
                x = linalg.adjoint(x)
                adjoint_arg = False
            aps = linear_operator_util.assert_compatible_matrix_dimensions(
                self, x)
            x = control_flow_ops.with_dependencies([aps], x)
        if self.is_square:
            # Note that adjoint has no effect since this matrix is self-adjoint.
            if adjoint_arg:
                output_shape = array_ops.concat([
                    array_ops.shape(x)[:-2],
                    [array_ops.shape(x)[-1],
                     array_ops.shape(x)[-2]]
                ],
                                                axis=0)
            else:
                output_shape = array_ops.shape(x)

            return self._possibly_broadcast_batch_shape(
                array_ops.zeros(shape=output_shape, dtype=x.dtype))

        # Non-square: output rows come from the operator (columns when
        # adjoint), output columns from the effective operand's last dimension.
        x_shape = array_ops.shape(x)
        n = self._num_columns if adjoint else self._num_rows
        m = x_shape[-2] if adjoint_arg else x_shape[-1]

        output_shape = array_ops.concat([x_shape[:-2], [n, m]], axis=0)

        zeros = array_ops.zeros(shape=output_shape, dtype=x.dtype)
        return self._possibly_broadcast_batch_shape(zeros)
 def test_compatible_dimensions_do_not_raise(self):
     """Checks that evaluating the compatibility assertion on a matching operand succeeds.

     The operand comes from `rng.rand(2, 3, 4)` (presumably shape (2, 3, 4)).
     The stub operator is constructed with 3. Evaluating the returned
     assertion is expected to complete without raising.
     """
     compatible = ops.convert_to_tensor(rng.rand(2, 3, 4))
     stub_operator = DomainDimensionStubOperator(3)
     # Any exception raised here fails the test.
     shape_check = linear_operator_util.assert_compatible_matrix_dimensions(
         stub_operator, compatible)
     self.evaluate(shape_check)
Ejemplo n.º 11
0
 def _matmul(self, x, adjoint=False, adjoint_arg=False):
   """Returns the operand itself, broadcast to this operator's batch shape.

   Args:
     x: Right-hand operand tensor.
     adjoint: Unused; this matrix is self-adjoint.
     adjoint_arg: If True, the operand is `linalg.adjoint(x)`.

   Returns:
     The (optionally adjointed) operand passed through
     `self._possibly_broadcast_batch_shape`.
   """
   operand = x
   if adjoint_arg:
     operand = linalg.adjoint(operand)
   if self._assert_proper_shapes:
     # Tie the operand to the dimension-compatibility assertion.
     operand = control_flow_ops.with_dependencies(
         [linear_operator_util.assert_compatible_matrix_dimensions(
             self, operand)],
         operand)
   return self._possibly_broadcast_batch_shape(operand)
Ejemplo n.º 12
0
 def test_incompatible_dimensions_raise(self):
   """Checks that evaluating the compatibility assertion on a mismatched operand fails.

   The operand comes from `rng.rand(2, 4, 4)` (presumably shape (2, 4, 4)).
   The stub operator is constructed with 3 (presumably its domain dimension).
   Evaluating the returned assertion must raise an op error whose message
   contains "Dimensions are not compatible".
   """
   mismatched = ops.convert_to_tensor(rng.rand(2, 4, 4))
   stub_operator = DomainDimensionStubOperator(3)
   # pylint: disable=g-error-prone-assert-raises
   with self.assertRaisesOpError("Dimensions are not compatible"):
     shape_check = linear_operator_util.assert_compatible_matrix_dimensions(
         stub_operator, mismatched)
     self.evaluate(shape_check)
 def _solve(self, rhs, adjoint=False, adjoint_arg=False):
   """Solves against this operator by dividing by its cached multiplier.

   Args:
     rhs: Right-hand side tensor.
     adjoint: If True, divide by `self._multiplier_matrix_conj` instead of
       `self._multiplier_matrix`.
     adjoint_arg: If True, `rhs` is replaced by `linalg.adjoint(rhs)` first.

   Returns:
     The (optionally adjointed) `rhs` divided by the selected multiplier.
   """
   if adjoint_arg:
     rhs = linalg.adjoint(rhs)
   divisor = (self._multiplier_matrix_conj if adjoint
              else self._multiplier_matrix)
   if self._assert_proper_shapes:
     # Make `rhs` depend on the dimension-compatibility assertion.
     rhs = control_flow_ops.with_dependencies(
         [linear_operator_util.assert_compatible_matrix_dimensions(self, rhs)],
         rhs)
   return rhs / divisor
 def _matmul(self, x, adjoint=False, adjoint_arg=False):
   """Multiplies by this operator via its cached multiplier.

   Args:
     x: Right-hand operand tensor.
     adjoint: If True, multiply by `self._multiplier_matrix_conj` instead of
       `self._multiplier_matrix`.
     adjoint_arg: If True, `x` is replaced by `linalg.adjoint(x)` first.

   Returns:
     The (optionally adjointed) `x` multiplied by the selected multiplier.
   """
   if adjoint_arg:
     x = linalg.adjoint(x)
   factor = (self._multiplier_matrix_conj if adjoint
             else self._multiplier_matrix)
   if self._assert_proper_shapes:
     # Make `x` depend on the dimension-compatibility assertion.
     x = control_flow_ops.with_dependencies(
         [linear_operator_util.assert_compatible_matrix_dimensions(self, x)],
         x)
   return x * factor