def type_spec_with_shape(spec, shape):
    """Returns TypeSpec `spec` with its shape set to `shape`.

    Args:
      spec: A `tf.TypeSpec`. `tf.TensorSpec`, `tf.RaggedTensorSpec` and
        `tf.SparseTensorSpec` are handled explicitly; any other spec must
        provide a `with_shape(shape)` method.
      shape: The new shape (anything accepted by `tf.TensorShape` or by the
        corresponding spec constructor / `with_shape` method).

    Returns:
      A spec with the requested shape. NOTE: for `tf.TensorSpec` this is the
      *same* object, mutated in place (see the TODO below). For every other
      supported spec it is a newly constructed object.

    Raises:
      ValueError: If `spec` is not one of the explicitly handled types and
        has no `with_shape` method.
    """
    if isinstance(spec, tf.TensorSpec):
        # pylint: disable=protected-access
        # TODO(b/203201161) Figure out why mutation is needed here, and remove it.
        # (TensorSpec objects should be immutable; and we should not be
        # modifying private fields.)
        shape = tf.TensorShape(shape)
        spec._shape = shape
        # Per b/203201161, DenseSpec may also cache the shape as
        # `_shape_tuple`. Both fields must be updated, or the mutated spec
        # can compare unequal to a TensorSpec built directly with `shape`.
        # The field is guarded because not every TF version defines it.
        if hasattr(spec, "_shape_tuple"):
            # `as_list()` is only valid for shapes of known rank.
            spec._shape_tuple = (
                None if shape.rank is None else tuple(shape.as_list())
            )
        return spec
    elif isinstance(spec, tf.RaggedTensorSpec):
        return tf.RaggedTensorSpec(
            shape,
            spec.dtype,
            spec.ragged_rank,
            spec.row_splits_dtype,
            spec.flat_values_spec,
        )
    elif isinstance(spec, tf.SparseTensorSpec):
        return tf.SparseTensorSpec(shape, spec.dtype)
    elif hasattr(spec, "with_shape"):
        # TODO(edloper): Consider adding .with_shape method to TensorSpec,
        # RaggedTensorSpec, and SparseTensorSpec.
        return spec.with_shape(shape)
    else:
        # TODO(edloper): Consider moving this check to the KerasTensor constructor.
        raise ValueError(
            "Keras requires TypeSpec to have a `with_shape` method "
            "that returns a copy of `self` with an updated shape.")
def common_spec(x, y):
    """Returns a spec of the same kind as `x` whose shape is common to `x` and `y`.

    The shape is `get_common_shape(x.shape, y.shape)`. Everything else
    (spec class, dtype, name for dense specs, ragged layout for ragged
    specs) is taken from `x`; only `y.shape` is read from `y`.

    Args:
      x: A `tf.SparseTensorSpec`, `tf.RaggedTensorSpec`, or a dense spec
        exposing `shape`, `dtype` and `name`.
      y: A spec whose `shape` is combined with `x.shape`.
        NOTE(review): y's class/dtype are not checked against x; presumably
        callers guarantee they agree. Confirm.

    Returns:
      A `tf.SparseTensorSpec`, `tf.RaggedTensorSpec` or `tf.TensorSpec`
      matching the kind of `x`.
    """
    common_shape = get_common_shape(x.shape, y.shape)
    if isinstance(x, tf.SparseTensorSpec):
        return tf.SparseTensorSpec(common_shape, x.dtype)
    elif isinstance(x, tf.RaggedTensorSpec):
        # Carry over x's ragged layout. The constructor defaults would assume
        # ragged_rank == rank - 1 and int64 row splits, which is wrong when
        # x has uniform inner dimensions or int32 row splits. They would also
        # fail when `common_shape` has unknown rank, because ragged_rank is
        # then required.
        return tf.RaggedTensorSpec(
            common_shape,
            x.dtype,
            ragged_rank=x.ragged_rank,
            row_splits_dtype=x.row_splits_dtype,
        )
    return tf.TensorSpec(common_shape, x.dtype, x.name)
class TestGetTensorSpec(parameterized.TestCase):
    """Tests for `tf_utils.get_tensor_spec`."""

    def _check_spec(self, t, expected_shape, *extra_args):
        """Runs get_tensor_spec on `t` (or on `t()` if callable) and checks it.

        `extra_args` are forwarded positionally after the value, so calling
        this with no extras invokes `get_tensor_spec(value)` unchanged.
        A None `expected_shape` means the result must have unknown rank.
        """
        value = t() if callable(t) else t
        spec = tf_utils.get_tensor_spec(value, *extra_args)
        self.assertTrue(spec.is_compatible_with(value))
        if expected_shape is not None:
            self.assertEqual(spec.shape.as_list(), expected_shape)
        else:
            self.assertIsNone(spec.shape.rank)

    @parameterized.parameters([
        (lambda: tf.constant([[1, 2]]), [1, 2]),
        (tf.TensorSpec([8, 3], tf.int32), [8, 3]),
        (tf.TensorSpec([8], tf.int32), [8]),
        (tf.TensorSpec([], tf.int32), []),
        (tf.TensorSpec(None, tf.int32), None),
        (tf.RaggedTensorSpec([8, 3], tf.int32), [8, 3]),
        (tf.SparseTensorSpec([8, 3], tf.int32), [8, 3]),
    ])
    def test_without_dynamic_batch(self, t, expected_shape):
        self._check_spec(t, expected_shape)

    @parameterized.parameters([
        (lambda: tf.constant([[1, 2]]), [None, 2]),
        (tf.TensorSpec([8, 3], tf.int32), [None, 3]),
        (tf.TensorSpec([8], tf.int32), [None]),
        (tf.TensorSpec([], tf.int32), []),
        (tf.TensorSpec(None, tf.int32), None),
        (tf.RaggedTensorSpec([8, 3], tf.int32), [None, 3]),
        (tf.SparseTensorSpec([8, 3], tf.int32), [None, 3]),
    ])
    def test_with_dynamic_batch(self, t, expected_shape):
        # The second positional argument (True) requests a dynamic batch dim.
        self._check_spec(t, expected_shape, True)

    def test_with_keras_tensor_with_ragged_spec(self):
        ragged_spec = tf.RaggedTensorSpec(shape=(None, None, 1))
        wrapped = keras.engine.keras_tensor.KerasTensor(ragged_spec)
        self.assertIsInstance(
            tf_utils.get_tensor_spec(wrapped), tf.RaggedTensorSpec)
class KerasTensorTest(test_combinations.TestCase):
    """Tests for `keras_tensor.KerasTensor` and related module helpers.

    Covers str/repr formatting, shape and set_shape handling (including
    `keras_tensor.type_spec_with_shape`), error messages for malformed
    TypeSpecs, and `keras_tensor.keras_tensor_from_tensor` mask handling.
    """

    def test_repr_and_string(self):
        """Checks str() and repr() for several kinds of KerasTensor.

        NOTE(review): the expected names 'dense', 'tf.reshape' and
        'tf.unstack' look like Keras auto-generated layer names, so this test
        presumably relies on running in a fresh naming scope (and on the
        statement order below). Confirm before reordering.
        """
        # Plain dense spec.
        kt = keras_tensor.KerasTensor(
            type_spec=tf.TensorSpec(shape=(1, 2, 3), dtype=tf.float32)
        )
        expected_str = (
            "KerasTensor(type_spec=TensorSpec(shape=(1, 2, 3), "
            "dtype=tf.float32, name=None))"
        )
        expected_repr = "<KerasTensor: shape=(1, 2, 3) dtype=float32>"
        self.assertEqual(expected_str, str(kt))
        self.assertEqual(expected_repr, repr(kt))

        # Dense spec with an inferred_value, which appears in both outputs.
        kt = keras_tensor.KerasTensor(
            type_spec=tf.TensorSpec(shape=(2,), dtype=tf.int32),
            inferred_value=[2, 3],
        )
        expected_str = (
            "KerasTensor(type_spec=TensorSpec(shape=(2,), "
            "dtype=tf.int32, name=None), inferred_value=[2, 3])"
        )
        expected_repr = (
            "<KerasTensor: shape=(2,) dtype=int32 inferred_value=[2, 3]>"
        )
        self.assertEqual(expected_str, str(kt))
        self.assertEqual(expected_repr, repr(kt))

        # Sparse spec: repr shows the whole type_spec, not shape/dtype.
        kt = keras_tensor.KerasTensor(
            type_spec=tf.SparseTensorSpec(shape=(1, 2, 3), dtype=tf.float32)
        )
        expected_str = (
            "KerasTensor(type_spec=SparseTensorSpec("
            "TensorShape([1, 2, 3]), tf.float32))"
        )
        expected_repr = (
            "<KerasTensor: type_spec=SparseTensorSpec("
            "TensorShape([1, 2, 3]), tf.float32)>"
        )
        self.assertEqual(expected_str, str(kt))
        self.assertEqual(expected_repr, repr(kt))

        # Output of a Keras layer: includes name and producing-layer info.
        inp = layers.Input(shape=(3, 5))
        kt = layers.Dense(10)(inp)
        expected_str = (
            "KerasTensor(type_spec=TensorSpec(shape=(None, 3, 10), "
            "dtype=tf.float32, name=None), name='dense/BiasAdd:0', "
            "description=\"created by layer 'dense'\")"
        )
        expected_repr = (
            "<KerasTensor: shape=(None, 3, 10) dtype=float32 (created "
            "by layer 'dense')>"
        )
        self.assertEqual(expected_str, str(kt))
        self.assertEqual(expected_repr, repr(kt))

        # Output of a TF op applied to a KerasTensor.
        kt = tf.reshape(kt, shape=(3, 5, 2))
        expected_str = (
            "KerasTensor(type_spec=TensorSpec(shape=(3, 5, 2), "
            "dtype=tf.float32, name=None), name='tf.reshape/Reshape:0', "
            "description=\"created by layer 'tf.reshape'\")"
        )
        expected_repr = (
            "<KerasTensor: shape=(3, 5, 2) dtype=float32 (created "
            "by layer 'tf.reshape')>"
        )
        self.assertEqual(expected_str, str(kt))
        self.assertEqual(expected_repr, repr(kt))

        # A TF op with multiple outputs: each output's name carries its index.
        kts = tf.unstack(kt)
        for i in range(3):
            expected_str = (
                "KerasTensor(type_spec=TensorSpec(shape=(5, 2), "
                "dtype=tf.float32, name=None), name='tf.unstack/unstack:%s', "
                "description=\"created by layer 'tf.unstack'\")" % (i,)
            )
            expected_repr = (
                "<KerasTensor: shape=(5, 2) dtype=float32 "
                "(created by layer 'tf.unstack')>"
            )
            self.assertEqual(expected_str, str(kts[i]))
            self.assertEqual(expected_repr, repr(kts[i]))

    @parameterized.parameters(
        {"property_name": "values"},
        {"property_name": "indices"},
        {"property_name": "dense_shape"},
    )
    def test_sparse_instance_property(self, property_name):
        """Checks a sparse Input's component property can be a model output.

        The model is checked directly and again after a get_config /
        from_config round trip.
        """
        inp = layers.Input(shape=[3], sparse=True)
        out = getattr(inp, property_name)
        model = training.Model(inp, out)

        x = tf.SparseTensor(
            [[0, 0], [0, 1], [1, 1], [1, 2]], [1, 2, 3, 4], [2, 3]
        )
        expected_property = getattr(x, property_name)
        self.assertAllEqual(model(x), expected_property)

        # The same output must survive serialization and deserialization.
        model_config = model.get_config()
        model2 = training.Model.from_config(model_config)
        self.assertAllEqual(model2(x), expected_property)

    @parameterized.parameters(
        [
            (tf.TensorSpec([2, 3], tf.int32), [2, 3]),
            (tf.RaggedTensorSpec([2, None]), [2, None]),
            (tf.SparseTensorSpec([8]), [8]),
            (CustomTypeSpec([3, 8], tf.int32), [3, 8]),
        ]
    )
    def test_shape(self, spec, expected_shape):
        """Checks KerasTensor.shape reflects the wrapped spec's shape."""
        kt = keras_tensor.KerasTensor(spec)
        self.assertEqual(kt.shape.as_list(), expected_shape)

    @parameterized.parameters(
        [
            (tf.TensorSpec([8, 3], tf.int32), [8, 3], [8, 3]),
            (tf.TensorSpec([None, 3], tf.int32), [8, 3], [8, 3]),
            (tf.TensorSpec([8, 3], tf.int32), [None, 3], [8, 3]),
            (tf.TensorSpec(None, tf.int32), [8, 3], [8, 3]),
            (tf.TensorSpec(None, tf.int32), [8, None], [8, None]),
            (tf.TensorSpec(None, tf.int32), None, None),
            (tf.RaggedTensorSpec([2, None, None]), [2, None, 5], [2, None, 5]),
            (tf.SparseTensorSpec([8]), [8], [8]),
            (CustomTypeSpec2([3, None], tf.int32), [3, 8], [3, 8]),
        ]
    )
    def test_set_shape(self, spec, new_shape, expected_shape):
        """Checks set_shape updates the spec's shape and stays compatible.

        Per the parameters above, known dimensions from either the original
        spec or `new_shape` end up in the result. A None `expected_shape`
        means the result must have unknown rank.
        """
        kt = keras_tensor.KerasTensor(spec)
        kt.set_shape(new_shape)
        if expected_shape is None:
            self.assertIsNone(kt.type_spec.shape.rank)
        else:
            self.assertEqual(kt.type_spec.shape.as_list(), expected_shape)
        self.assertTrue(kt.type_spec.is_compatible_with(spec))

    def test_set_shape_error(self):
        """Checks set_shape fails for a custom spec without `with_shape`."""
        spec = CustomTypeSpec([3, None], tf.int32)
        kt = keras_tensor.KerasTensor(spec)
        with self.assertRaisesRegex(
            ValueError, "Keras requires TypeSpec to have a `with_shape` method"
        ):
            kt.set_shape([3, 3])

    def test_set_shape_equals_expected_shape(self):
        # Tests b/203201161: DenseSpec has both a _shape and a _shape_tuple
        # field, and we need to be sure both get updated.
        kt = keras_tensor.KerasTensor(tf.TensorSpec([8, None], tf.int32))
        kt.set_shape([8, 3])
        self.assertEqual(kt.type_spec, tf.TensorSpec([8, 3], tf.int32))

    def test_type_spec_with_shape_equals_expected_shape(self):
        # Tests b/203201161: DenseSpec has both a _shape and a _shape_tuple
        # field, and we need to be sure both get updated.
        spec1 = tf.TensorSpec([8, None], tf.int32)
        spec2 = keras_tensor.type_spec_with_shape(spec1, [8, 3])
        expected = tf.TensorSpec([8, 3], tf.int32)
        self.assertEqual(spec2, expected)

    def test_missing_shape_error(self):
        """Checks constructing a KerasTensor rejects a spec with no `shape`."""
        spec = CustomTypeSpec(None, tf.int32)
        del spec.shape
        with self.assertRaisesRegex(
            ValueError,
            "KerasTensor only supports TypeSpecs that have a shape field; .*",
        ):
            keras_tensor.KerasTensor(spec)

    def test_wrong_shape_type_error(self):
        """Checks constructing a KerasTensor rejects a non-TensorShape shape."""
        spec = CustomTypeSpec(None, tf.int32)
        spec.shape = "foo"
        with self.assertRaisesRegex(
            TypeError,
            "KerasTensor requires that wrapped TypeSpec's shape is a "
            "TensorShape; .*",
        ):
            keras_tensor.KerasTensor(spec)

    def test_missing_dtype_error(self):
        """Checks a missing dtype is reported when `dtype` is accessed.

        Construction itself must succeed (it happens outside the
        assertRaises block); the error is expected only on attribute access.
        """
        spec = CustomTypeSpec(None, tf.int32)
        del spec.dtype
        kt = keras_tensor.KerasTensor(spec)
        with self.assertRaisesRegex(
            AttributeError,
            "KerasTensor wraps TypeSpec .* which does not have a dtype.",
        ):
            kt.dtype  # Attribute access alone is expected to raise.

    def test_wrong_dtype_type_error(self):
        """Checks a non-DType dtype is reported when `dtype` is accessed."""
        spec = CustomTypeSpec(None, tf.int32)
        spec.dtype = "foo"
        kt = keras_tensor.KerasTensor(spec)
        with self.assertRaisesRegex(
            TypeError,
            "KerasTensor requires that wrapped TypeSpec's dtype is a DType; .*",
        ):
            kt.dtype  # Attribute access alone is expected to raise.

    def test_from_tensor_mask_tensor_is_none(self):
        """Checks no `_keras_mask` is attached when the source tensor has none."""
        tensor = tf.constant([1.0])
        kt = keras_tensor.keras_tensor_from_tensor(tensor)
        self.assertIsNone(getattr(kt, "_keras_mask", None))

    def test_from_tensor_mask_tensor_is_not_none(self):
        """Checks a source tensor's `_keras_mask` is converted to a KerasTensor."""
        tensor = tf.constant([1.0])
        tensor._keras_mask = tf.constant([1.0])
        kt = keras_tensor.keras_tensor_from_tensor(tensor)
        self.assertIsInstance(kt._keras_mask, keras_tensor.KerasTensor)
def test_with_keras_tensor_with_ragged_spec(self):
    """Checks get_tensor_spec returns a RaggedTensorSpec for a ragged KerasTensor.

    NOTE(review): this duplicates the test of the same name in
    TestGetTensorSpec. If this copy really sits at module level rather than
    inside a TestCase, test runners will not supply `self`. Confirm whether
    it should be removed or moved into a class.
    """
    ragged_spec = tf.RaggedTensorSpec(shape=(None, None, 1))
    wrapped = keras.engine.keras_tensor.KerasTensor(ragged_spec)
    result = tf_utils.get_tensor_spec(wrapped)
    self.assertIsInstance(result, tf.RaggedTensorSpec)