예제 #1
0
 def test_raises_if_rank_is_not_scalar_static(self):
   """`assert_rank_in` itself raises ValueError for a non-scalar static rank.

   The second allowed rank is a length-2 numpy array, which the test expects
   to be rejected with a "Rank must be a scalar" message.
   """
   with self.test_session():
     tensor = constant_op.constant((42, 43), name="my_tensor")
     scalar_rank = np.array(1, dtype=np.int32)
     vector_rank = np.array((2, 1), dtype=np.int32)  # the offending entry
     with self.assertRaisesRegexp(ValueError, "Rank must be a scalar"):
       check_ops.assert_rank_in(tensor, (scalar_rank, vector_rank))
예제 #2
0
 def test_raises_if_rank_is_not_scalar_static(self):
   """A non-scalar entry among static ranks makes `assert_rank_in` raise.

   Ranks are built as int32 numpy arrays from 1 and (2, 1); the latter is
   not a scalar and must trigger ValueError("Rank must be a scalar...").
   """
   tensor = constant_op.constant((42, 43), name="my_tensor")
   allowed = tuple(np.array(r, dtype=np.int32) for r in (1, (2, 1)))
   with self.assertRaisesRegexp(ValueError, "Rank must be a scalar"):
     check_ops.assert_rank_in(tensor, allowed)
예제 #3
0
 def test_raises_if_rank_is_not_integer_static(self):
   """A fractional static rank (0.5) is rejected with a TypeError.

   The expected message names the int32 dtype the ranks must have.
   """
   with self.test_session():
     tensor = constant_op.constant((42, 43), name="my_tensor")
     candidate_ranks = (1, .5)
     expected_msg = "must be of type <dtype: 'int32'>"
     with self.assertRaisesRegexp(TypeError, expected_msg):
       check_ops.assert_rank_in(tensor, candidate_ranks)
예제 #4
0
 def test_rank_one_tensor_raises_if_rank_mismatches_static_rank(self):
   """A rank-1 constant checked against ranks (0, 2) raises ValueError.

   The error message must mention the tensor name and the word "rank".
   Building the check and evaluating the guarded identity both happen inside
   the expected-exception context.
   """
   with self.test_session():
     vector = constant_op.constant((42, 43), name="my_tensor")
     with self.assertRaisesRegexp(ValueError, "my_tensor.*rank"):
       rank_check = check_ops.assert_rank_in(vector, (0, 2))
       with ops.control_dependencies([rank_check]):
         array_ops.identity(vector).eval()
예제 #5
0
 def test_rank_one_tensor_doesnt_raise_if_rank_matches_static_rank(self):
   """A rank-1 constant passes whenever 1 is among the allowed ranks.

   The same three allowed ranks are supplied in several different orders;
   each guarded identity must evaluate without error.
   """
   with self.test_session():
     vector = constant_op.constant([42, 43], name="my_tensor")
     orderings = ((0, 1, 2), (1, 0, 2), (1, 2, 0))
     for ranks in orderings:
       rank_check = check_ops.assert_rank_in(vector, ranks)
       with ops.control_dependencies([rank_check]):
         array_ops.identity(vector).eval()
예제 #6
0
 def test_rank_zero_tensor_doesnt_raise_if_rank_matches_dynamic_rank(self):
   """A scalar fed into a placeholder passes when 0 is an allowed rank.

   The placeholder is created without a static shape, so the rank is only
   known from the fed value (42.0).
   """
   with self.test_session():
     scalar_ph = array_ops.placeholder(dtypes.float32, name="my_tensor")
     feed = {scalar_ph: 42.0}
     for ranks in ((0, 1, 2), (1, 0, 2), (1, 2, 0)):
       rank_check = check_ops.assert_rank_in(scalar_ph, ranks)
       with ops.control_dependencies([rank_check]):
         array_ops.identity(scalar_ph).eval(feed_dict=feed)
예제 #7
0
 def test_rank_zero_tensor_raises_if_rank_mismatch_dynamic_rank(self):
   """Feeding a scalar when only ranks (1, 2) are allowed fails at run time.

   The custom message "fail" must prefix the op error, followed by the
   tensor name and the word "rank".
   """
   with self.test_session():
     scalar_ph = array_ops.placeholder(dtypes.float32, name="my_tensor")
     rank_check = check_ops.assert_rank_in(scalar_ph, (1, 2), message="fail")
     with ops.control_dependencies([rank_check]):
       # Only the eval() is expected to fail; building the check must not.
       with self.assertRaisesOpError("fail.*my_tensor.*rank"):
         array_ops.identity(scalar_ph).eval(feed_dict={scalar_ph: 42.0})
예제 #8
0
 def test_rank_zero_tensor_raises_if_rank_mismatch_static_rank(self):
   """A scalar constant checked against ranks (1, 2) raises ValueError.

   This variant runs through `self.evaluate`; its regex does not require the
   tensor name, only the custom message and the allowed ranks.
   """
   scalar = constant_op.constant(42, name="my_tensor")
   expected = "fail.*must have rank.*in.*1.*2"
   with self.assertRaisesRegexp(ValueError, expected):
     rank_check = check_ops.assert_rank_in(scalar, (1, 2), message="fail")
     with ops.control_dependencies([rank_check]):
       self.evaluate(array_ops.identity(scalar))
예제 #9
0
 def test_rank_zero_tensor_raises_if_rank_mismatch_static_rank(self):
   """A scalar constant checked against ranks (1, 2) raises ValueError.

   The message must contain the custom "fail" prefix, the tensor name, and
   the list of allowed ranks.
   """
   with self.test_session():
     scalar = constant_op.constant(42, name="my_tensor")
     expected = "fail.*my_tensor.*must have rank.*in.*1.*2"
     with self.assertRaisesRegexp(ValueError, expected):
       rank_check = check_ops.assert_rank_in(scalar, (1, 2), message="fail")
       with ops.control_dependencies([rank_check]):
         array_ops.identity(scalar).eval()
예제 #10
0
 def test_rank_one_tensor_raises_if_rank_mismatches_dynamic_rank(self):
   """Feeding a length-2 vector when only ranks (0, 2) are allowed fails.

   The placeholder has no static shape, so the mismatch surfaces as an op
   error when the guarded identity is evaluated.
   """
   with self.test_session():
     vector_ph = array_ops.placeholder(dtypes.float32, name="my_tensor")
     rank_check = check_ops.assert_rank_in(vector_ph, (0, 2))
     fed_value = (42.0, 43.0)
     with ops.control_dependencies([rank_check]):
       with self.assertRaisesOpError("my_tensor.*rank"):
         array_ops.identity(vector_ph).eval(feed_dict={vector_ph: fed_value})
예제 #11
0
 def test_raises_if_rank_is_not_integer_dynamic(self):
   """A float32 rank placeholder leads to a TypeError naming int32.

   The whole build-and-evaluate sequence sits inside the expected-exception
   context; statements after the raising call are not reached.
   """
   with self.test_session():
     tensor = constant_op.constant(
         (42, 43), dtype=dtypes.float32, name="my_tensor")
     float_rank = array_ops.placeholder(dtypes.float32, name="rank_tensor")
     expected_msg = "must be of type <dtype: 'int32'>"
     with self.assertRaisesRegexp(TypeError, expected_msg):
       rank_check = check_ops.assert_rank_in(tensor, (1, float_rank))
       with ops.control_dependencies([rank_check]):
         array_ops.identity(tensor).eval(feed_dict={float_rank: .5})
예제 #12
0
 def test_raises_if_rank_is_not_scalar_dynamic(self):
   """A rank placeholder fed a vector fails the scalar check at run time.

   Both allowed ranks are int32 placeholders with no static shape; the
   second one is fed [2, 1], which the test expects to produce an op error
   containing "Rank must be a scalar".
   """
   with self.test_session():
     tensor = constant_op.constant(
         (42, 43), dtype=dtypes.float32, name="my_tensor")
     first_rank = array_ops.placeholder(dtypes.int32, name="rank0_tensor")
     second_rank = array_ops.placeholder(dtypes.int32, name="rank1_tensor")
     with self.assertRaisesOpError("Rank must be a scalar"):
       rank_check = check_ops.assert_rank_in(tensor, (first_rank, second_rank))
       with ops.control_dependencies((rank_check,)):
         feed = {
             first_rank: 1,
             second_rank: [2, 1],
         }
         array_ops.identity(tensor).eval(feed_dict=feed)
예제 #13
0
 def test_raises_if_rank_is_not_integer_static(self):
   """A fractional static rank is rejected with a TypeError naming int32."""
   with self.test_session():
     tensor = constant_op.constant((42, 43), name="my_tensor")
     ranks = (1, .5)
     with self.assertRaisesRegexp(
         TypeError, "must be of type <dtype: 'int32'>"):
       check_ops.assert_rank_in(tensor, ranks)
예제 #14
0
 def test_rank_zero_tensor_doesnt_raise_if_rank_matches_static_rank(self):
   """A scalar constant passes whenever 0 is among the allowed ranks.

   The allowed ranks are tried in several orders, each evaluated through
   `self.evaluate`.
   """
   scalar = constant_op.constant(42, name="my_tensor")
   orderings = ((0, 1, 2), (1, 0, 2), (1, 2, 0))
   for ranks in orderings:
     rank_check = check_ops.assert_rank_in(scalar, ranks)
     with ops.control_dependencies([rank_check]):
       self.evaluate(array_ops.identity(scalar))