def test_set_shape_error(self):
    """`set_shape` on a CustomTypeSpec-backed KerasTensor raises ValueError.

    The expected message demands a `with_shape` method on the TypeSpec.
    """
    custom_spec = CustomTypeSpec([3, None], tf.int32)
    wrapped = keras_tensor.KerasTensor(custom_spec)
    expected_msg = "Keras requires TypeSpec to have a `with_shape` method"
    with self.assertRaisesRegex(ValueError, expected_msg):
        wrapped.set_shape([3, 3])
def test_missing_shape_error(self):
    """Building a KerasTensor from a spec with no `shape` raises ValueError."""
    shapeless_spec = CustomTypeSpec(None, tf.int32)
    # Remove the attribute so the constructor sees a spec without a shape.
    del shapeless_spec.shape
    pattern = "KerasTensor only supports TypeSpecs that have a shape field; .*"
    with self.assertRaisesRegex(ValueError, pattern):
        keras_tensor.KerasTensor(shapeless_spec)
def test_set_shape(self, spec, new_shape, expected_shape):
    """`set_shape` updates the wrapped spec, which stays compatible with `spec`.

    Args:
      spec: TypeSpec wrapped by the KerasTensor under test.
      new_shape: Shape passed to `KerasTensor.set_shape`.
      expected_shape: Expected `as_list()` of the resulting shape, or None
        when the resulting shape should have unknown rank.
    """
    kt = keras_tensor.KerasTensor(spec)
    kt.set_shape(new_shape)
    result_shape = kt.type_spec.shape
    if expected_shape is not None:
        self.assertEqual(result_shape.as_list(), expected_shape)
    else:
        self.assertIsNone(result_shape.rank)
    self.assertTrue(kt.type_spec.is_compatible_with(spec))
def test_missing_dtype_error(self):
    """Reading `dtype` on a KerasTensor whose spec lacks one raises.

    Construction itself must succeed; only the `dtype` access fails.
    """
    dtypeless_spec = CustomTypeSpec(None, tf.int32)
    del dtypeless_spec.dtype
    kt = keras_tensor.KerasTensor(dtypeless_spec)
    with self.assertRaisesRegex(
        AttributeError,
        "KerasTensor wraps TypeSpec .* which does not have a dtype."):
        _ = kt.dtype  # Property access alone should trigger the error.
def test_wrong_shape_type_error(self):
    """A spec whose `shape` is not a TensorShape is rejected with TypeError."""
    bad_spec = CustomTypeSpec(None, tf.int32)
    bad_spec.shape = "foo"
    msg_regex = ("KerasTensor requires that wrapped TypeSpec's shape is a "
                 "TensorShape; .*")
    with self.assertRaisesRegex(TypeError, msg_regex):
        keras_tensor.KerasTensor(bad_spec)
def test_wrong_dtype_type_error(self):
    """Reading `dtype` raises TypeError when the spec's dtype is not a DType.

    Construction succeeds; the error surfaces on `dtype` access.
    """
    bad_spec = CustomTypeSpec(None, tf.int32)
    bad_spec.dtype = "foo"
    kt = keras_tensor.KerasTensor(bad_spec)
    msg_regex = "KerasTensor requires that wrapped TypeSpec's dtype is a DType; .*"
    with self.assertRaisesRegex(TypeError, msg_regex):
        _ = kt.dtype  # Property access alone should trigger the error.
def test_repr_and_string(self):
    """Checks `str()` and `repr()` of KerasTensors from several sources.

    Covers a dense spec, a spec with an inferred value, a sparse spec, the
    output of a Keras `Dense` layer, and outputs of wrapped TF ops
    (`tf.reshape`, `tf.unstack`).

    NOTE(review): the expected layer names ('dense', 'tf.reshape',
    'tf.unstack') look like first-use auto-generated names, so this test
    presumably relies on creation order and a fresh Keras naming state —
    confirm against the enclosing test class setup.
    """

    def check(tensor, want_str, want_repr):
        # str() is asserted before repr(), as in a plain sequential check.
        self.assertEqual(want_str, str(tensor))
        self.assertEqual(want_repr, repr(tensor))

    # Dense TensorSpec with a fully known shape.
    check(
        keras_tensor.KerasTensor(
            type_spec=tf.TensorSpec(shape=(1, 2, 3), dtype=tf.float32)),
        "KerasTensor(type_spec=TensorSpec(shape=(1, 2, 3), "
        "dtype=tf.float32, name=None))",
        "<KerasTensor: shape=(1, 2, 3) dtype=float32>")

    # An inferred value is included in both representations.
    check(
        keras_tensor.KerasTensor(
            type_spec=tf.TensorSpec(shape=(2, ), dtype=tf.int32),
            inferred_value=[2, 3]),
        "KerasTensor(type_spec=TensorSpec(shape=(2,), "
        "dtype=tf.int32, name=None), inferred_value=[2, 3])",
        "<KerasTensor: shape=(2,) dtype=int32 inferred_value=[2, 3]>")

    # Non-dense specs are shown via their type_spec rather than shape/dtype.
    check(
        keras_tensor.KerasTensor(
            type_spec=tf.SparseTensorSpec(shape=(1, 2, 3), dtype=tf.float32)),
        "KerasTensor(type_spec=SparseTensorSpec("
        "TensorShape([1, 2, 3]), tf.float32))",
        "<KerasTensor: type_spec=SparseTensorSpec("
        "TensorShape([1, 2, 3]), tf.float32)>")

    # Output of a Keras layer carries a name and a "created by" description.
    model_input = layers.Input(shape=(3, 5))
    dense_out = layers.Dense(10)(model_input)
    check(
        dense_out,
        "KerasTensor(type_spec=TensorSpec(shape=(None, 3, 10), "
        "dtype=tf.float32, name=None), name='dense/BiasAdd:0', "
        "description=\"created by layer 'dense'\")",
        "<KerasTensor: shape=(None, 3, 10) dtype=float32 (created "
        "by layer 'dense')>")

    # A TF op applied to a KerasTensor is attributed to a 'tf.*' layer.
    reshaped = tf.reshape(dense_out, shape=(3, 5, 2))
    check(
        reshaped,
        "KerasTensor(type_spec=TensorSpec(shape=(3, 5, 2), dtype=tf.float32, "
        "name=None), name='tf.reshape/Reshape:0', description=\"created "
        "by layer 'tf.reshape'\")",
        "<KerasTensor: shape=(3, 5, 2) dtype=float32 (created "
        "by layer 'tf.reshape')>")

    # Multi-output op: each output gets its own index in the tensor name.
    pieces = tf.unstack(reshaped)
    piece_repr = ("<KerasTensor: shape=(5, 2) dtype=float32 "
                  "(created by layer 'tf.unstack')>")
    for idx in range(3):
        check(
            pieces[idx],
            "KerasTensor(type_spec=TensorSpec(shape=(5, 2), dtype=tf.float32, "
            "name=None), name='tf.unstack/unstack:%s', description=\"created "
            "by layer 'tf.unstack'\")" % (idx, ),
            piece_repr)
def test_set_shape_equals_expected_shape(self):
    """Regression test for b/203201161.

    DenseSpec has both a `_shape` and a `_shape_tuple` field, and both need
    to be updated by `set_shape` for the spec to compare equal to a freshly
    built one.
    """
    partial_spec = tf.TensorSpec([8, None], tf.int32)
    kt = keras_tensor.KerasTensor(partial_spec)
    kt.set_shape([8, 3])
    expected_spec = tf.TensorSpec([8, 3], tf.int32)
    self.assertEqual(kt.type_spec, expected_spec)
def test_shape(self, spec, expected_shape):
    """A KerasTensor's `shape.as_list()` equals `expected_shape`.

    Args:
      spec: TypeSpec to wrap.
      expected_shape: Expected list form of the wrapped tensor's shape.
    """
    observed = keras_tensor.KerasTensor(spec).shape.as_list()
    self.assertEqual(observed, expected_shape)