Ejemplo n.º 1
0
 def test_set_shape_error(self):
     """Expects set_shape on a CustomTypeSpec-backed KerasTensor to fail.

     The test requires a ValueError whose message asks for a `with_shape`
     method (presumably CustomTypeSpec does not define one).
     """
     custom_spec = CustomTypeSpec([3, None], tf.int32)
     keras_t = keras_tensor.KerasTensor(custom_spec)
     expected_msg = "Keras requires TypeSpec to have a `with_shape` method"
     # Callable form of assertRaisesRegex: invokes keras_t.set_shape([3, 3]).
     self.assertRaisesRegex(
         ValueError, expected_msg, keras_t.set_shape, [3, 3]
     )
Ejemplo n.º 2
0
 def test_missing_shape_error(self):
     """KerasTensor construction rejects a TypeSpec with no `shape` attribute."""
     shapeless_spec = CustomTypeSpec(None, tf.int32)
     # Remove the attribute entirely so the spec has no shape field at all.
     del shapeless_spec.shape
     pattern = "KerasTensor only supports TypeSpecs that have a shape field; .*"
     # Callable form: invokes keras_tensor.KerasTensor(shapeless_spec).
     self.assertRaisesRegex(
         ValueError, pattern, keras_tensor.KerasTensor, shapeless_spec
     )
Ejemplo n.º 3
0
 def test_set_shape(self, spec, new_shape, expected_shape):
     """set_shape refines the wrapped spec while staying compatible with it.

     Presumably parameterized by a decorator outside this view.

     Args:
       spec: the TypeSpec wrapped by the KerasTensor under test.
       new_shape: the shape passed to `set_shape`.
       expected_shape: the expected `as_list()` of the resulting shape, or
         None when the resulting shape should have unknown rank.
     """
     keras_t = keras_tensor.KerasTensor(spec)
     keras_t.set_shape(new_shape)
     result_shape = keras_t.type_spec.shape
     if expected_shape is not None:
         self.assertEqual(result_shape.as_list(), expected_shape)
     else:
         # Unknown-rank shapes report a rank of None.
         self.assertIsNone(result_shape.rank)
     # The refined spec must still be compatible with the original one.
     self.assertTrue(keras_t.type_spec.is_compatible_with(spec))
Ejemplo n.º 4
0
 def test_missing_dtype_error(self):
     """Reading `dtype` fails when the wrapped TypeSpec has no dtype attribute."""
     dtypeless_spec = CustomTypeSpec(None, tf.int32)
     del dtypeless_spec.dtype
     # Construction must succeed; the error is expected only on `dtype` access.
     keras_t = keras_tensor.KerasTensor(dtypeless_spec)
     pattern = "KerasTensor wraps TypeSpec .* which does not have a dtype."
     with self.assertRaisesRegex(AttributeError, pattern):
         _ = keras_t.dtype
Ejemplo n.º 5
0
 def test_wrong_shape_type_error(self):
     """Construction rejects a TypeSpec whose `shape` is not a TensorShape."""
     bad_spec = CustomTypeSpec(None, tf.int32)
     bad_spec.shape = "foo"  # a str, deliberately not a TensorShape
     pattern = ("KerasTensor requires that wrapped TypeSpec's shape is a "
                "TensorShape; .*")
     # Callable form: invokes keras_tensor.KerasTensor(bad_spec).
     self.assertRaisesRegex(
         TypeError, pattern, keras_tensor.KerasTensor, bad_spec
     )
Ejemplo n.º 6
0
 def test_wrong_dtype_type_error(self):
     """Reading `dtype` fails when the wrapped spec's dtype is not a DType."""
     bad_spec = CustomTypeSpec(None, tf.int32)
     bad_spec.dtype = "foo"  # a str, deliberately not a DType
     # Construction must succeed; the error is expected only on `dtype` access.
     keras_t = keras_tensor.KerasTensor(bad_spec)
     pattern = ("KerasTensor requires that wrapped TypeSpec's dtype is a "
                "DType; .*")
     with self.assertRaisesRegex(TypeError, pattern):
         _ = keras_t.dtype
Ejemplo n.º 7
0
    def test_repr_and_string(self):
        """Checks str() and repr() of KerasTensors from several origins.

        Covers a plain TensorSpec, a TensorSpec with an inferred value, a
        SparseTensorSpec, the output of a Dense layer, and the outputs of the
        `tf.reshape` and `tf.unstack` op layers.
        """
        # Plain dense TensorSpec.
        kt = keras_tensor.KerasTensor(
            type_spec=tf.TensorSpec(shape=(1, 2, 3), dtype=tf.float32))
        expected_str = ("KerasTensor(type_spec=TensorSpec(shape=(1, 2, 3), "
                        "dtype=tf.float32, name=None))")
        expected_repr = "<KerasTensor: shape=(1, 2, 3) dtype=float32>"
        self.assertEqual(expected_str, str(kt))
        self.assertEqual(expected_repr, repr(kt))

        # TensorSpec with an inferred value: both forms include the value.
        kt = keras_tensor.KerasTensor(type_spec=tf.TensorSpec(shape=(2, ),
                                                              dtype=tf.int32),
                                      inferred_value=[2, 3])
        expected_str = ("KerasTensor(type_spec=TensorSpec(shape=(2,), "
                        "dtype=tf.int32, name=None), inferred_value=[2, 3])")
        expected_repr = (
            "<KerasTensor: shape=(2,) dtype=int32 inferred_value=[2, 3]>")
        self.assertEqual(expected_str, str(kt))
        self.assertEqual(expected_repr, repr(kt))

        # Non-TensorSpec specs: repr falls back to showing the full type_spec.
        kt = keras_tensor.KerasTensor(
            type_spec=tf.SparseTensorSpec(shape=(1, 2, 3), dtype=tf.float32))
        expected_str = ("KerasTensor(type_spec=SparseTensorSpec("
                        "TensorShape([1, 2, 3]), tf.float32))")
        expected_repr = ("<KerasTensor: type_spec=SparseTensorSpec("
                         "TensorShape([1, 2, 3]), tf.float32)>")
        self.assertEqual(expected_str, str(kt))
        self.assertEqual(expected_repr, repr(kt))

        # NOTE(review): the expected names 'dense', 'tf.reshape' and
        # 'tf.unstack' assume these are the first such layers created since
        # layer-name uniquification was last reset -- confirm the test
        # fixture guarantees that.
        inp = layers.Input(shape=(3, 5))
        kt = layers.Dense(10)(inp)
        expected_str = (
            "KerasTensor(type_spec=TensorSpec(shape=(None, 3, 10), "
            "dtype=tf.float32, name=None), name='dense/BiasAdd:0', "
            "description=\"created by layer 'dense'\")")
        expected_repr = (
            "<KerasTensor: shape=(None, 3, 10) dtype=float32 (created "
            "by layer 'dense')>")
        self.assertEqual(expected_str, str(kt))
        self.assertEqual(expected_repr, repr(kt))

        kt = tf.reshape(kt, shape=(3, 5, 2))
        expected_str = (
            "KerasTensor(type_spec=TensorSpec(shape=(3, 5, 2), dtype=tf.float32, "
            "name=None), name='tf.reshape/Reshape:0', description=\"created "
            "by layer 'tf.reshape'\")")
        expected_repr = (
            "<KerasTensor: shape=(3, 5, 2) dtype=float32 (created "
            "by layer 'tf.reshape')>")
        self.assertEqual(expected_str, str(kt))
        self.assertEqual(expected_repr, repr(kt))

        # Unstacking a (3, 5, 2) tensor along the default axis 0 should yield
        # exactly 3 tensors. Assert the count so that missing or extra
        # outputs are reported instead of being silently skipped.
        kts = tf.unstack(kt)
        self.assertEqual(len(kts), 3)
        expected_repr = ("<KerasTensor: shape=(5, 2) dtype=float32 "
                         "(created by layer 'tf.unstack')>")
        for i, unstacked in enumerate(kts):
            # Only the output index in the name differs between outputs.
            expected_str = (
                "KerasTensor(type_spec=TensorSpec(shape=(5, 2), dtype=tf.float32, "
                "name=None), name='tf.unstack/unstack:%s', description=\"created "
                "by layer 'tf.unstack'\")" % (i, ))
            self.assertEqual(expected_str, str(unstacked))
            self.assertEqual(expected_repr, repr(unstacked))
Ejemplo n.º 8
0
 def test_set_shape_equals_expected_shape(self):
     """set_shape yields a spec equal to one built directly with that shape.

     Regression test for b/203201161: DenseSpec has both a `_shape` and a
     `_shape_tuple` field, and both must be updated by set_shape.
     """
     partial_spec = tf.TensorSpec([8, None], tf.int32)
     keras_t = keras_tensor.KerasTensor(partial_spec)
     keras_t.set_shape([8, 3])
     expected_spec = tf.TensorSpec([8, 3], tf.int32)
     self.assertEqual(keras_t.type_spec, expected_spec)
Ejemplo n.º 9
0
 def test_shape(self, spec, expected_shape):
     """The KerasTensor's `shape`, as a list, matches the expected value.

     Presumably parameterized by a decorator outside this view.

     Args:
       spec: the TypeSpec to wrap.
       expected_shape: the expected `shape.as_list()` of the wrapped tensor.
     """
     actual = keras_tensor.KerasTensor(spec).shape.as_list()
     self.assertEqual(actual, expected_shape)