def test_with_dynamic_batch(self, t, expected_shape):
    """Check the spec returned by `get_tensor_spec(t, True)` for input `t`.

    Args:
        t: The input to derive a spec from, or a zero-argument callable
            that produces it (called here before use).
        expected_shape: The list expected from `result.shape.as_list()`, or
            `None` when the resulting shape is expected to have unknown rank.
    """
    # Inputs may be supplied lazily as factories; materialize them first.
    tensor = t() if callable(t) else t

    # NOTE(review): the positional `True` is presumably the dynamic-batch
    # flag (per the test name) -- confirm against get_tensor_spec's signature.
    spec = tf_utils.get_tensor_spec(tensor, True)

    # The derived spec must still accept the value it was derived from.
    self.assertTrue(spec.is_compatible_with(tensor))

    if expected_shape is not None:
        self.assertEqual(spec.shape.as_list(), expected_shape)
        return
    self.assertIsNone(spec.shape.rank)
def test_with_keras_tensor_with_ragged_spec(self):
    """Check that a ragged-spec KerasTensor yields a `tf.RaggedTensorSpec`.

    Builds a `KerasTensor` wrapping a `tf.RaggedTensorSpec` of shape
    `(None, None, 1)` and asserts that `tf_utils.get_tensor_spec` returns
    an instance of `tf.RaggedTensorSpec` for it.
    """
    ragged_spec = tf.RaggedTensorSpec(shape=(None, None, 1))
    keras_tensor = keras.engine.keras_tensor.KerasTensor(ragged_spec)

    spec = tf_utils.get_tensor_spec(keras_tensor)

    self.assertIsInstance(spec, tf.RaggedTensorSpec)