Exemplo n.º 1
0
    def test_data_buffer(self):
        """Exercise DataBuffer add/get, overflow, checkpointing and clear.

        The stored item is a 3-tuple of tensors sliced from one random
        ``(batch_size, dim)`` matrix. The test checks that:

        * ``current_size`` grows with ``add_batch`` and saturates at
          ``capacity``;
        * results returned by the buffer do not require grad even though the
          inputs do;
        * after overflow, the buffer holds the most recent ``capacity`` rows;
        * a ``Checkpointer`` save/load round trip into a fresh buffer restores
          both ``current_size`` and the stored contents;
        * ``clear()`` resets ``current_size`` to 0.
        """
        dim = 20
        capacity = 256
        # Field widths sum to ``dim``: one scalar column, (dim // 3 - 1)
        # columns, and the remaining (dim - dim // 3) columns.
        data_spec = (TensorSpec(shape=()), TensorSpec(shape=(dim // 3 - 1, )),
                     TensorSpec(shape=(dim - dim // 3, )))

        data_buffer = DataBuffer(data_spec=data_spec, capacity=capacity)

        def _get_batch(batch_size):
            # Slice one random matrix so the three fields match ``data_spec``.
            # requires_grad=True lets the test verify that the buffer detaches.
            x = torch.randn(batch_size, dim, requires_grad=True)
            return (x[:, 0], x[:, 1:dim // 3], x[..., dim // 3:])

        def _assert_fields_equal(ret, expected):
            # Compare every field of a buffer result with the reference tuple.
            for i, want in enumerate(expected):
                self.assertEqual(ret[i], want)

        def _get_latest(buffer, n):
            # Fetch the ``n`` most recently added rows of ``buffer``.
            size = buffer.current_size
            return buffer.get_batch_by_indices(torch.arange(size - n, size))

        data_buffer.add_batch(_get_batch(100))
        self.assertEqual(int(data_buffer.current_size), 100)
        batch = _get_batch(1000)
        # test that the created batch has gradients
        self.assertTrue(batch[0].requires_grad)
        data_buffer.add_batch(batch)
        ret = data_buffer.get_batch(2)
        # test that DataBuffer detaches gradients of inputs
        self.assertFalse(ret[0].requires_grad)
        # 100 + 1000 rows exceed the capacity, so the size saturates.
        self.assertEqual(int(data_buffer.current_size), capacity)
        # Only the last ``capacity`` rows of the big batch should remain.
        ret = data_buffer.get_batch_by_indices(torch.arange(capacity))
        _assert_fields_equal(ret, [field[-capacity:] for field in batch])

        num_new = 100
        batch = _get_batch(num_new)
        data_buffer.add_batch(batch)
        # num_new < capacity, so the whole new batch is the newest data.
        _assert_fields_equal(_get_latest(data_buffer, num_new), batch)

        # Test checkpoint working
        with tempfile.TemporaryDirectory() as checkpoint_directory:
            checkpoint = Checkpointer(checkpoint_directory,
                                      data_buffer=data_buffer)
            checkpoint.save(10)
            # Load into a fresh buffer so the checks below can only pass if
            # the checkpoint actually carried the state.
            data_buffer = DataBuffer(data_spec=data_spec, capacity=capacity)
            checkpoint = Checkpointer(checkpoint_directory,
                                      data_buffer=data_buffer)
            global_step = checkpoint.load()
            self.assertEqual(global_step, 10)

        # The restored size must match the saved one; the index range used
        # by _get_latest is derived from it.
        self.assertEqual(int(data_buffer.current_size), capacity)
        _assert_fields_equal(_get_latest(data_buffer, num_new), batch)

        data_buffer.clear()
        self.assertEqual(int(data_buffer.current_size), 0)
Exemplo n.º 2
0
    def test_data_buffer(self):
        """Exercise the TF DataBuffer: add/get, overflow and checkpointing.

        The stored item is a 3-tuple of float32 tensors sliced from one random
        ``(batch_size, dim)`` matrix. The test checks that:

        * ``current_size`` grows with ``add_batch`` and saturates at
          ``capacity``;
        * after overflow, the buffer holds the most recent ``capacity`` rows;
        * a ``tf.train.Checkpoint`` save/restore round trip into a fresh
          buffer consumes every saved value and restores both
          ``current_size`` and the stored contents.
        """
        dim = 20
        capacity = 256
        # Field widths sum to ``dim``: one scalar column, (dim // 3 - 1)
        # columns, and the remaining (dim - dim // 3) columns.
        data_spec = (tf.TensorSpec(shape=(), dtype=tf.float32),
                     tf.TensorSpec(shape=(dim // 3 - 1, ), dtype=tf.float32),
                     tf.TensorSpec(shape=(dim - dim // 3, ), dtype=tf.float32))

        data_buffer = DataBuffer(data_spec=data_spec, capacity=capacity)

        def _get_batch(batch_size):
            # Slice one random matrix so the three fields match ``data_spec``.
            x = tf.random.normal(shape=(batch_size, dim))
            return (x[:, 0], x[:, 1:dim // 3], x[..., dim // 3:])

        def _assert_fields_equal(ret, expected):
            # Compare every field of a buffer result with the reference tuple.
            for i, want in enumerate(expected):
                self.assertArrayEqual(ret[i], want)

        def _get_latest(buffer, n):
            # Fetch the ``n`` most recently added rows of ``buffer``.
            size = buffer.current_size
            return buffer.get_batch_by_indices(tf.range(size - n, size))

        data_buffer.add_batch(_get_batch(100))
        self.assertEqual(int(data_buffer.current_size), 100)
        batch = _get_batch(1000)
        data_buffer.add_batch(batch)
        # 100 + 1000 rows exceed the capacity, so the size saturates.
        self.assertEqual(int(data_buffer.current_size), capacity)
        # Only the last ``capacity`` rows of the big batch should remain.
        ret = data_buffer.get_batch_by_indices(tf.range(capacity))
        _assert_fields_equal(ret, [field[-capacity:] for field in batch])

        num_new = 100
        batch = _get_batch(num_new)
        data_buffer.add_batch(batch)
        # num_new < capacity, so the whole new batch is the newest data.
        _assert_fields_equal(_get_latest(data_buffer, num_new), batch)

        # Test checkpoint working
        with tempfile.TemporaryDirectory() as checkpoint_directory:
            checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
            checkpoint = tf.train.Checkpoint(data_buffer=data_buffer)
            checkpoint.save(file_prefix=checkpoint_prefix)

            # Restore into a fresh buffer so the checks below can only pass
            # if the checkpoint actually carried the state.
            data_buffer = DataBuffer(data_spec=data_spec, capacity=capacity)
            checkpoint = tf.train.Checkpoint(data_buffer=data_buffer)
            status = checkpoint.restore(
                tf.train.latest_checkpoint(checkpoint_directory))
            status.assert_consumed()

        # The restored size must match the saved one; the index range used
        # by _get_latest is derived from it.
        self.assertEqual(int(data_buffer.current_size), capacity)
        _assert_fields_equal(_get_latest(data_buffer, num_new), batch)