示例#1
0
    def test_serialize_deserialize2(self):
        """Round-trip a dataset whose elements are single, non-tuple tensors.

        Only the first component of each ``self.td`` element is kept. That
        bare element is serialized with ``serialize_to_example`` and decoded
        with the function built by ``make_dataset_deserialize_fn``. The first
        decoded element must equal the first original element.
        """
        first_only = self.td.map(lambda first, _second: first)
        expected = next(iter(first_only))

        serialized = first_only.map(serialize_to_example)
        restored = serialized.map(make_dataset_deserialize_fn(serialized))

        actual = next(iter(restored))
        self.assertAllEqual(expected, actual)
示例#2
0
    def test_serialize_deserialize0(self):
        """Round-trip the two-component elements of ``self.td`` unchanged.

        The first element before serialization is compared component by
        component with the first element after deserialization.
        """
        source = self.td
        expected = next(iter(source))

        serialized = source.map(serialize_to_example)
        restored = serialized.map(make_dataset_deserialize_fn(serialized))

        actual = next(iter(restored))
        # Index explicitly (not zip) so a missing component raises IndexError.
        for index in (0, 1):
            self.assertAllEqual(expected[index], actual[index])
示例#3
0
    def test_serialize_deserialize1(self):
        """Round-trip three-component elements with mixed dtypes.

        Each element becomes ``(first, float32 cast of first, second)``. All
        three components of the first element must survive serialization and
        deserialization unchanged.
        """
        widened = self.td.map(
            lambda feat, label: (feat, tf.cast(feat, tf.float32), label))
        expected = next(iter(widened))

        serialized = widened.map(serialize_to_example)
        restored = serialized.map(make_dataset_deserialize_fn(serialized))

        actual = next(iter(restored))
        # Index explicitly (not zip) so a missing component raises IndexError.
        for index in (0, 1, 2):
            self.assertAllEqual(expected[index], actual[index])
示例#4
0
    def test_serialize_deserialize_shape3(self):
        """Deserialize with static shapes and check the resulting element_spec.

        With ``set_shape=True`` and ``set_dimension=True``, the restored
        dataset must advertise fully static specs: a (28, 28, 3) uint8 tensor
        and an int64 scalar. The first element's values must also match the
        original.
        """
        source = self.td
        expected = next(iter(source))

        serialized = source.map(serialize_to_example)
        decode = make_dataset_deserialize_fn(serialized,
                                             set_shape=True,
                                             set_dimension=True)
        restored = serialized.map(decode)

        spec = restored.element_spec
        self.assertEqual(spec[0],
                         tf.TensorSpec(shape=(28, 28, 3), dtype=tf.uint8))
        self.assertEqual(spec[1],
                         tf.TensorSpec(shape=(), dtype=tf.int64, name=None))

        actual = next(iter(restored))
        for index in (0, 1):
            self.assertAllEqual(expected[index], actual[index])
示例#5
0
    def test_serialize_deserialize_variable_shape0(self):
        """Round-trip elements without fixing any static shape.

        Elements first go through ``random_size``, which presumably varies
        their sizes. With ``set_shape=False`` and ``set_dimension=False``, the
        restored dataset must advertise unknown-shape specs (float32 and
        int64) and still reproduce the first element's values.
        """
        resized = self.td.map(random_size)
        expected = next(iter(resized))

        serialized = resized.map(serialize_to_example)
        decode = make_dataset_deserialize_fn(serialized,
                                             set_shape=False,
                                             set_dimension=False)
        restored = serialized.map(decode)

        spec = restored.element_spec
        self.assertEqual(spec[0],
                         tf.TensorSpec(shape=None, dtype=tf.float32))
        self.assertEqual(spec[1],
                         tf.TensorSpec(shape=None, dtype=tf.int64, name=None))

        # NOTE(review): `expected` and `actual` come from two separate
        # iterators. This comparison assumes `random_size` gives the same
        # output on every pass (e.g. it is seeded) — confirm.
        actual = next(iter(restored))
        for index in (0, 1):
            self.assertAllEqual(expected[index], actual[index])
示例#6
0
    def test_serialize_deserialize_variable_shape2(self):
        """Check that a fixed shape from ``set_shape=True`` rejects other shapes.

        A single 16x16x3 uint8 element, paired with an int64 scalar, is
        prepended to ``self.td`` before serialization. The deserializer is
        built with ``set_shape=True`` and ``set_dimension=False``. The
        prepended element must round-trip unchanged. Pulling the next element
        (the first from ``self.td``) must then raise ``InvalidArgumentError``.
        """
        td = self.td

        # Eager tensors (iter(td) below requires eager mode), so the values
        # captured by from_tensors are fixed and can be compared later.
        s0 = tf.cast(
            tf.random.uniform([16, 16, 3], maxval=255, dtype=tf.int32),
            tf.uint8)
        s1 = tf.random.uniform((), maxval=255, dtype=tf.int64)
        td_v = tf.data.Dataset.from_tensors((s0, s1))
        td = td_v.concatenate(td)

        td = td.map(serialize_to_example)
        td = td.map(
            make_dataset_deserialize_fn(td,
                                        set_shape=True,
                                        set_dimension=False))

        it = iter(td)
        x, y = next(it)
        # The prepended element must decode to exactly what was serialized.
        self.assertAllEqual(s0, x)
        self.assertAllEqual(s1, y)

        # NOTE(review): presumably the fixed shape is taken from the first
        # (16x16x3) element, so a self.td element of a different shape fails
        # to decode — confirm against make_dataset_deserialize_fn.
        with pytest.raises(InvalidArgumentError):
            next(it)