Example #1
0
 def test_with_dynamic_batch(self, t, expected_shape):
   """Checks `tf_utils.get_tensor_spec(value, True)` against an expected shape.

   Args:
     t: The value whose spec is requested, or a zero-argument callable that
       returns it. Presumably the callable form lets a parameterized case defer
       building a tensor until the test body runs — TODO confirm against the
       test parameters.
     expected_shape: Expected `shape.as_list()` of the returned spec, or `None`
       when the returned spec is expected to have unknown rank.
   """
   value = t() if callable(t) else t
   spec = tf_utils.get_tensor_spec(value, True)
   # Whatever shape relaxation the second argument applies, the spec must
   # still accept the original value.
   self.assertTrue(spec.is_compatible_with(value))
   if expected_shape is not None:
     self.assertEqual(spec.shape.as_list(), expected_shape)
     return
   self.assertIsNone(spec.shape.rank)
Example #2
0
 def test_with_keras_tensor_with_ragged_spec(self):
   """Checks that a KerasTensor built from a ragged spec yields a ragged spec.

   Wraps a `tf.RaggedTensorSpec` of shape (None, None, 1) in a `KerasTensor`
   and asserts that `tf_utils.get_tensor_spec` returns a
   `tf.RaggedTensorSpec` for it.
   """
   t = keras.engine.keras_tensor.KerasTensor(
       tf.RaggedTensorSpec(shape=(None, None, 1)))
   self.assertIsInstance(tf_utils.get_tensor_spec(t), tf.RaggedTensorSpec)