def test_ref_type_shape_args_raises(self):
    """Variable-valued shape arguments are rejected with a TypeError.

    Both `num_rows` and `batch_shape` are passed as `Variable`s. The error
    message must match "<arg name>.*reference".
    """
    with self.assertRaisesRegex(TypeError, "num_rows.*reference"):
      linalg_lib.LinearOperatorIdentity(num_rows=variables_module.Variable(2))

    with self.assertRaisesRegex(TypeError, "batch_shape.*reference"):
      linalg_lib.LinearOperatorIdentity(
          num_rows=2, batch_shape=variables_module.Variable([3]))
Example #2
0
    def test_is_x_flags(self):
        """Checks the default hint flags and rejection of a non-True hint."""
        # The is_x hint flags all default to True.
        operator = linalg_lib.LinearOperatorIdentity(num_rows=2)
        self.assertTrue(operator.is_positive_definite)
        self.assertTrue(operator.is_non_singular)
        self.assertTrue(operator.is_self_adjoint)

        # An identity is always non-singular (self-adjoint, etc.), so passing
        # a hint that is not True (here None) must raise ValueError.
        with self.assertRaisesRegex(ValueError, "is always non-singular"):
            linalg_lib.LinearOperatorIdentity(
                num_rows=2,
                is_non_singular=None,
            )
Example #3
0
 def test_negative_num_rows_raises_dynamic(self):
     """A negative num_rows fed at run time fails the non-negativity check."""
     # cached_session() replaces the deprecated test_session().
     with self.cached_session():
         num_rows = array_ops.placeholder(dtypes.int32)
         operator = linalg_lib.LinearOperatorIdentity(
             num_rows, assert_proper_shapes=True)
         with self.assertRaisesOpError("must be non-negative"):
             operator.to_dense().eval(feed_dict={num_rows: -2})
Example #4
0
 def test_negative_batch_shape_raises_dynamic(self):
     """A batch_shape with a negative entry, fed at run time, is rejected."""
     with self.cached_session():
         shape_ph = array_ops.placeholder(dtypes.int32)
         identity = linalg_lib.LinearOperatorIdentity(
             num_rows=2, batch_shape=shape_ph, assert_proper_shapes=True)
         with self.assertRaisesOpError("must be non-negative"):
             dense = identity.to_dense()
             dense.eval(feed_dict={shape_ph: [-2]})
Example #5
0
 def test_negative_num_rows_raises_dynamic(self):
     """num_rows=-2 with an unspecified static shape is rejected."""
     with self.cached_session():
         negative_rows = array_ops.placeholder_with_default(-2, shape=None)
         # Construction and evaluation both sit inside the context so the
         # error is caught wherever it is raised.
         with self.assertRaisesError("must be non-negative"):
             identity = linalg_lib.LinearOperatorIdentity(
                 negative_rows, assert_proper_shapes=True)
             dense = identity.to_dense()
             self.evaluate(dense)
Example #6
0
    def test_broadcast_matmul_dynamic_shapes(self):
        """matmul broadcasts batch dims when every shape is fed at run time."""
        # The automated base-class tests cannot cover this because
        # tf.batch_matmul does not broadcast.
        with self.cached_session() as sess:
            x = array_ops.placeholder(dtypes.float32)
            num_rows = array_ops.placeholder(dtypes.int32)
            batch_shape = array_ops.placeholder(dtypes.int32)
            operator = linalg_lib.LinearOperatorIdentity(
                num_rows, batch_shape=batch_shape)

            # Operator shape (2, 1, 3, 3) with x shape (1, 2, 3, 4) broadcasts
            # to (2, 2, 3, 4). Adding zeros of that shape to x gives the
            # expected matmul result.
            broadcast_zeros = array_ops.zeros(shape=(2, 2, 3, 4), dtype=x.dtype)
            expected = x + broadcast_zeros
            operator_matmul = operator.matmul(x)

            feeds = {x: rng.rand(1, 2, 3, 4), num_rows: 3, batch_shape: (2, 1)}
            actual_val, expected_val = sess.run(
                [operator_matmul, expected], feed_dict=feeds)
            self.assertAllClose(actual_val, expected_val)
Example #7
0
 def test_non_scalar_num_rows_raises_dynamic(self):
     """A rank-1 num_rows fed at run time fails the scalar check."""
     with self.cached_session():
         rows_ph = array_ops.placeholder(dtypes.int32)
         identity = linalg_lib.LinearOperatorIdentity(
             rows_ph, assert_proper_shapes=True)
         with self.assertRaisesOpError("must be a 0-D Tensor"):
             dense = identity.to_dense()
             dense.eval(feed_dict={rows_ph: [2]})
 def test_non_1d_batch_shape_raises_dynamic(self):
   """A scalar batch_shape with no static shape fails the 1-D check."""
   with self.cached_session():
     scalar_batch_shape = array_ops.placeholder_with_default(2, shape=None)
     with self.assertRaisesError("must be a 1-D"):
       identity = linalg_lib.LinearOperatorIdentity(
           num_rows=2,
           batch_shape=scalar_batch_shape,
           assert_proper_shapes=True)
       self.evaluate(identity.to_dense())
Example #9
0
 def test_identity_cholesky_type(self):
     """cholesky() of an identity operator returns an identity operator."""
     identity = linalg_lib.LinearOperatorIdentity(
         num_rows=2, is_positive_definite=True, is_self_adjoint=True)
     chol = identity.cholesky()
     self.assertIsInstance(chol, linalg_lib.LinearOperatorIdentity)
Example #10
0
 def test_float16_matmul(self):
     """matmul with a float16 identity returns its input unchanged.

     float16 is tested here rather than in the base test class because
     tf.matrix_solve does not work with float16.
     """
     with self.cached_session():
         identity = linalg_lib.LinearOperatorIdentity(
             num_rows=2, dtype=dtypes.float16)
         rhs = rng.randn(2, 3).astype(np.float16)
         product = identity.matmul(rhs)
         self.assertAllClose(rhs, self.evaluate(product))
  def test_wrong_matrix_dimensions_raises_dynamic(self):
    """matmul of a 2-row identity with a 3x3 matrix raises when evaluated."""
    # shape=None on both inputs leaves their static shapes unspecified.
    rows = array_ops.placeholder_with_default(2, shape=None)
    x_val = rng.rand(3, 3).astype(np.float32)
    x = array_ops.placeholder_with_default(x_val, shape=None)

    with self.cached_session():
      with self.assertRaisesError("Dimensions.*not.compatible"):
        identity = linalg_lib.LinearOperatorIdentity(
            rows, assert_proper_shapes=True)
        product = identity.matmul(x)
        self.evaluate(product)
Example #12
0
    def test_wrong_matrix_dimensions_raises_dynamic(self):
        """Feeding num_rows=2 with a 3x3 x makes matmul fail at run time."""
        rows_ph = array_ops.placeholder(dtypes.int32)
        x_ph = array_ops.placeholder(dtypes.float32)

        with self.cached_session():
            identity = linalg_lib.LinearOperatorIdentity(
                rows_ph, assert_proper_shapes=True)
            product = identity.matmul(x_ph)
            feeds = {rows_ph: 2, x_ph: rng.rand(3, 3)}
            with self.assertRaisesOpError("Incompatible.*dimensions"):
                product.eval(feed_dict=feeds)
    def test_zeros_matmul(self):
        """Multiplying identity and zeros, in either order, yields zeros."""
        identity = linalg_lib.LinearOperatorIdentity(num_rows=2)
        zeros = linalg_lib.LinearOperatorZeros(num_rows=2)
        # assertIsInstance reports the actual type on failure, unlike
        # assertTrue(isinstance(...)).
        self.assertIsInstance(identity.matmul(zeros),
                              linalg_lib.LinearOperatorZeros)
        self.assertIsInstance(zeros.matmul(identity),
                              linalg_lib.LinearOperatorZeros)
  def _operator_and_matrix(self, build_info, dtype, use_placeholder):
    """Builds an identity operator and the dense matrix it should equal.

    Args:
      build_info: object with a `shape` whose last two entries must be equal;
        leading entries are used as the batch shape.
      dtype: dtype for both the operator and the dense matrix.
      use_placeholder: not used by this implementation.

    Returns:
      Tuple `(operator, mat)`.
    """
    dims = list(build_info.shape)
    assert dims[-1] == dims[-2]
    num_rows, batch_shape = dims[-1], dims[:-2]

    operator = linalg_lib.LinearOperatorIdentity(
        num_rows, batch_shape=batch_shape, dtype=dtype)
    mat = linalg_ops.eye(num_rows, batch_shape=batch_shape, dtype=dtype)
    return operator, mat
  def test_default_batch_shape_broadcasts_with_everything_dynamic(self):
    """Without batch_shape, matmul returns a batched x of unspecified shape."""
    # The base-class tests cannot cover this because tf.batch_matmul does not
    # broadcast.
    with self.cached_session():
      x_val = rng.randn(1, 2, 3, 4)
      x = array_ops.placeholder_with_default(x_val, shape=None)
      identity = linalg_lib.LinearOperatorIdentity(num_rows=3, dtype=x.dtype)
      actual, expected = self.evaluate([identity.matmul(x), x])
      self.assertAllClose(actual, expected)
  def test_identity_solve(self):
    """Solving with an identity keeps the right-hand operator's structure.

    I.solve(I) is an identity, and I.solve(3 I) is a scaled identity with
    multiplier 3.
    """
    operator1 = linalg_lib.LinearOperatorIdentity(num_rows=2)
    operator2 = linalg_lib.LinearOperatorScaledIdentity(
        num_rows=2, multiplier=3.)
    # assertIsInstance reports the actual type on failure.
    self.assertIsInstance(operator1.solve(operator1),
                          linalg_lib.LinearOperatorIdentity)

    operator_solve = operator1.solve(operator2)
    self.assertIsInstance(operator_solve,
                          linalg_lib.LinearOperatorScaledIdentity)
    self.assertAllClose(3., self.evaluate(operator_solve.multiplier))
  def test_default_batch_shape_broadcasts_with_everything_static(self):
    """Without batch_shape, matmul returns x and preserves its static shape."""
    # tf.batch_matmul does not broadcast, so the base-class tests cannot
    # cover this case.
    with self.cached_session() as sess:
      x = random_ops.random_normal(shape=(1, 2, 3, 4))
      identity = linalg_lib.LinearOperatorIdentity(num_rows=3, dtype=x.dtype)
      product = identity.matmul(x)

      self.assertAllEqual(product.get_shape(), x.get_shape())
      # Fetch both in one run so they see the same random draw.
      product_val, x_val = sess.run([product, x])
      self.assertAllClose(product_val, x_val)
Example #18
0
    def test_default_batch_shape_broadcasts_with_everything_dynamic(self):
        """Without batch_shape, matmul returns a fed batched x unchanged."""
        # These cannot be done in the automated (base test class) tests since
        # they test shapes that tf.batch_matmul cannot handle: it does not
        # broadcast. cached_session() replaces the deprecated test_session().
        with self.cached_session() as sess:
            x = array_ops.placeholder(dtypes.float32)
            operator = linalg_lib.LinearOperatorIdentity(num_rows=3,
                                                         dtype=x.dtype)

            operator_matmul = operator.matmul(x)
            expected = x

            feed_dict = {x: rng.randn(1, 2, 3, 4)}

            self.assertAllClose(
                *sess.run([operator_matmul, expected], feed_dict=feed_dict))
  def operator_and_matrix(
      self, build_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False):
    """Builds an identity operator and the dense matrix it should equal.

    Args:
      build_info: object with a `shape` whose last two entries must be equal;
        leading entries are used as the batch shape.
      dtype: dtype for both the operator and the dense matrix.
      use_placeholder: not used by this implementation.
      ensure_self_adjoint_and_pd: ignored; the identity is already Hermitian
        positive definite.

    Returns:
      Tuple `(operator, mat)`.
    """
    del ensure_self_adjoint_and_pd

    dims = list(build_info.shape)
    assert dims[-1] == dims[-2]
    num_rows, batch_shape = dims[-1], dims[:-2]

    operator = linalg_lib.LinearOperatorIdentity(
        num_rows, batch_shape=batch_shape, dtype=dtype)
    return operator, linalg_ops.eye(
        num_rows, batch_shape=batch_shape, dtype=dtype)
Example #20
0
    def _operator_and_mat_and_feed_dict(self, shape, dtype, use_placeholder):
        """Builds an identity operator, its dense matrix, and a feed dict.

        Args:
          shape: sequence whose last two entries must be equal; leading
            entries are used as the batch shape.
          dtype: dtype for both the operator and the dense matrix.
          use_placeholder: selects an empty feed dict (truthy) or None.

        Returns:
          Tuple `(operator, mat, feed_dict)`.
        """
        dims = list(shape)
        assert dims[-1] == dims[-2]
        num_rows, batch_shape = dims[-1], dims[:-2]

        operator = linalg_lib.LinearOperatorIdentity(
            num_rows, batch_shape=batch_shape, dtype=dtype)
        mat = linalg_ops.eye(num_rows, batch_shape=batch_shape, dtype=dtype)

        # LinearOperatorIdentity takes no Tensor args, so there is never
        # anything to feed.
        feed_dict = {} if use_placeholder else None
        return operator, mat, feed_dict
  def test_broadcast_matmul_static_shapes(self):
    """matmul broadcasts batch dims when all shapes are static."""
    # The automated base-class tests cannot cover this because
    # tf.batch_matmul does not broadcast.
    with self.cached_session() as sess:
      x = random_ops.random_normal(shape=(1, 2, 3, 4))
      identity = linalg_lib.LinearOperatorIdentity(
          num_rows=3, batch_shape=(2, 1), dtype=x.dtype)

      # Operator shape (2, 1, 3, 3) with x shape (1, 2, 3, 4) broadcasts to
      # (2, 2, 3, 4). Adding zeros of that shape to x gives the expected
      # matmul result.
      broadcast_zeros = array_ops.zeros(shape=(2, 2, 3, 4), dtype=x.dtype)
      expected = x + broadcast_zeros
      product = identity.matmul(x)

      self.assertAllEqual(product.get_shape(), expected.get_shape())
      product_val, expected_val = sess.run([product, expected])
      self.assertAllClose(product_val, expected_val)
Example #22
0
    def test_identity_matmul(self):
        """Products of identity and scaled-identity operators keep structure.

        I @ I is an identity, (3 I) @ (3 I) is a scaled identity with
        multiplier 9, and I combined with 3 I in either order is a scaled
        identity with multiplier 3.
        """
        operator1 = linalg_lib.LinearOperatorIdentity(num_rows=2)
        operator2 = linalg_lib.LinearOperatorScaledIdentity(num_rows=2,
                                                            multiplier=3.)
        self.assertIsInstance(operator1.matmul(operator1),
                              linalg_lib.LinearOperatorIdentity)

        square = operator2.matmul(operator2)
        self.assertIsInstance(square, linalg_lib.LinearOperatorScaledIdentity)
        # (3 I)(3 I) = 9 I.
        self.assertAllClose(9., self.evaluate(square.multiplier))

        operator_matmul = operator1.matmul(operator2)
        self.assertIsInstance(operator_matmul,
                              linalg_lib.LinearOperatorScaledIdentity)
        self.assertAllClose(3., self.evaluate(operator_matmul.multiplier))

        operator_matmul = operator2.matmul(operator1)
        self.assertIsInstance(operator_matmul,
                              linalg_lib.LinearOperatorScaledIdentity)
        self.assertAllClose(3., self.evaluate(operator_matmul.multiplier))
Example #23
0
 def test_wrong_matrix_dimensions_raises_static(self):
     """matmul of a 2-row identity with a 3x3 array raises ValueError."""
     operator = linalg_lib.LinearOperatorIdentity(num_rows=2)
     x = rng.randn(3, 3).astype(np.float32)
     with self.assertRaisesRegex(ValueError, "Dimensions.*not compatible"):
         operator.matmul(x)
Example #24
0
 def test_assert_self_adjoint(self):
     """assert_self_adjoint() on an identity operator runs without error."""
     with self.cached_session():
         identity = linalg_lib.LinearOperatorIdentity(num_rows=2)
         check = identity.assert_self_adjoint()
         check.run()
Example #25
0
 def test_assert_positive_definite(self):
     """assert_positive_definite() on an identity operator runs cleanly."""
     with self.cached_session():
         identity = linalg_lib.LinearOperatorIdentity(num_rows=2)
         check = identity.assert_positive_definite()
         check.run()
Example #26
0
 def test_identity_inverse_type(self):
     """inverse() of a non-singular identity is again an identity operator."""
     identity = linalg_lib.LinearOperatorIdentity(
         num_rows=2, is_non_singular=True)
     inv = identity.inverse()
     self.assertIsInstance(inv, linalg_lib.LinearOperatorIdentity)
Example #27
0
 def test_non_scalar_num_rows_raises_static(self):
     """A static rank-1 num_rows is rejected with ValueError."""
     with self.assertRaisesRegex(ValueError, "must be a 0-D Tensor"):
         linalg_lib.LinearOperatorIdentity(num_rows=[2])
Example #28
0
 def test_negative_batch_shape_raises_static(self):
     """A static batch_shape with a negative entry is rejected."""
     with self.assertRaisesRegex(ValueError, "must be non-negative"):
         linalg_lib.LinearOperatorIdentity(num_rows=2, batch_shape=[-2])
Example #29
0
 def test_non_integer_num_rows_raises_static(self):
     """A float num_rows is rejected with TypeError."""
     with self.assertRaisesRegex(TypeError, "must be integer"):
         linalg_lib.LinearOperatorIdentity(num_rows=2.)
Example #30
0
 def test_non_1d_batch_shape_raises_static(self):
     """A static scalar batch_shape is rejected because it is not 1-D."""
     with self.assertRaisesRegex(ValueError, "must be a 1-D"):
         linalg_lib.LinearOperatorIdentity(num_rows=2, batch_shape=2)