def test_data_buffer(self):
    """Exercise DataBuffer end-to-end (PyTorch variant).

    Covers: adding batches, ring-buffer wraparound at ``capacity``,
    gradient detachment of stored tensors, retrieval by indices,
    save/load through ``Checkpointer``, and ``clear()``.
    """
    dim = 20
    capacity = 256
    # Three specs whose last dims sum to `dim`: (), (dim//3 - 1,), (dim - dim//3,)
    data_spec = (TensorSpec(shape=()),
                 TensorSpec(shape=(dim // 3 - 1, )),
                 TensorSpec(shape=(dim - dim // 3, )))

    data_buffer = DataBuffer(data_spec=data_spec, capacity=capacity)

    def _get_batch(batch_size):
        # Build a random batch matching data_spec. requires_grad=True so we
        # can verify that DataBuffer detaches what it stores.
        x = torch.randn(batch_size, dim, requires_grad=True)
        x = (x[:, 0], x[:, 1:dim // 3], x[..., dim // 3:])
        return x

    data_buffer.add_batch(_get_batch(100))
    self.assertEqual(int(data_buffer.current_size), 100)

    batch = _get_batch(1000)
    # test that the created batch has gradients
    self.assertTrue(batch[0].requires_grad)
    data_buffer.add_batch(batch)
    ret = data_buffer.get_batch(2)
    # test that DataBuffer detaches gradients of inputs
    self.assertFalse(ret[0].requires_grad)
    # 1000 > capacity, so the buffer is full and holds the last `capacity` rows.
    self.assertEqual(int(data_buffer.current_size), capacity)
    ret = data_buffer.get_batch_by_indices(torch.arange(capacity))
    self.assertEqual(ret[0], batch[0][-capacity:])
    self.assertEqual(ret[1], batch[1][-capacity:])
    self.assertEqual(ret[2], batch[2][-capacity:])

    batch = _get_batch(100)
    data_buffer.add_batch(batch)
    ret = data_buffer.get_batch_by_indices(
        torch.arange(data_buffer.current_size - 100,
                     data_buffer.current_size))
    self.assertEqual(ret[0], batch[0])
    self.assertEqual(ret[1], batch[1])
    # batch has only 100 rows (< capacity), so compare the whole tensor; the
    # previous `batch[2][-capacity:]` slice was copy-paste residue (a no-op).
    self.assertEqual(ret[2], batch[2])

    # Test checkpoint working
    with tempfile.TemporaryDirectory() as checkpoint_directory:
        checkpoint = Checkpointer(
            checkpoint_directory, data_buffer=data_buffer)
        checkpoint.save(10)
        # Load into a fresh buffer to prove the state round-trips.
        data_buffer = DataBuffer(data_spec=data_spec, capacity=capacity)
        checkpoint = Checkpointer(
            checkpoint_directory, data_buffer=data_buffer)
        global_step = checkpoint.load()
        self.assertEqual(global_step, 10)

    # The restored buffer must return the same last-added batch.
    ret = data_buffer.get_batch_by_indices(
        torch.arange(data_buffer.current_size - 100,
                     data_buffer.current_size))
    self.assertEqual(ret[0], batch[0])
    self.assertEqual(ret[1], batch[1])
    self.assertEqual(ret[2], batch[2])

    data_buffer.clear()
    self.assertEqual(int(data_buffer.current_size), 0)
def test_data_buffer(self):
    """Exercise DataBuffer end-to-end (TensorFlow variant).

    Covers: adding batches, ring-buffer wraparound at ``capacity``,
    retrieval by indices, and save/restore through ``tf.train.Checkpoint``.
    """
    dim = 20
    capacity = 256
    # Three specs whose last dims sum to `dim`: (), (dim//3 - 1,), (dim - dim//3,)
    data_spec = (tf.TensorSpec(shape=(), dtype=tf.float32),
                 tf.TensorSpec(shape=(dim // 3 - 1, ), dtype=tf.float32),
                 tf.TensorSpec(shape=(dim - dim // 3, ), dtype=tf.float32))

    data_buffer = DataBuffer(data_spec=data_spec, capacity=capacity)

    def _get_batch(batch_size):
        # Build a random batch matching data_spec.
        x = tf.random.normal(shape=(batch_size, dim))
        x = (x[:, 0], x[:, 1:dim // 3], x[..., dim // 3:])
        return x

    data_buffer.add_batch(_get_batch(100))
    self.assertEqual(int(data_buffer.current_size), 100)

    batch = _get_batch(1000)
    data_buffer.add_batch(batch)
    # 1000 > capacity, so the buffer is full and holds the last `capacity` rows.
    self.assertEqual(int(data_buffer.current_size), capacity)
    ret = data_buffer.get_batch_by_indices(tf.range(capacity))
    self.assertArrayEqual(ret[0], batch[0][-capacity:])
    self.assertArrayEqual(ret[1], batch[1][-capacity:])
    self.assertArrayEqual(ret[2], batch[2][-capacity:])

    batch = _get_batch(100)
    data_buffer.add_batch(batch)
    ret = data_buffer.get_batch_by_indices(
        tf.range(data_buffer.current_size - 100, data_buffer.current_size))
    self.assertArrayEqual(ret[0], batch[0])
    self.assertArrayEqual(ret[1], batch[1])
    # batch has only 100 rows (< capacity), so compare the whole tensor; the
    # previous `batch[2][-capacity:]` slice was copy-paste residue (a no-op).
    self.assertArrayEqual(ret[2], batch[2])

    # Test checkpoint working
    with tempfile.TemporaryDirectory() as checkpoint_directory:
        checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
        checkpoint = tf.train.Checkpoint(data_buffer=data_buffer)
        checkpoint.save(file_prefix=checkpoint_prefix)
        # Restore into a fresh buffer to prove the state round-trips.
        data_buffer = DataBuffer(data_spec=data_spec, capacity=capacity)
        checkpoint = tf.train.Checkpoint(data_buffer=data_buffer)
        status = checkpoint.restore(
            tf.train.latest_checkpoint(checkpoint_directory))
        # Fail loudly if any checkpointed value was left unmatched.
        status.assert_consumed()
        ret = data_buffer.get_batch_by_indices(
            tf.range(data_buffer.current_size - 100,
                     data_buffer.current_size))
        self.assertArrayEqual(ret[0], batch[0])
        self.assertArrayEqual(ret[1], batch[1])
        self.assertArrayEqual(ret[2], batch[2])