class SyncExperienceReplayer(ExperienceReplayer):
    """
    For synchronous off-policy training. Example algorithms: DDPG, SAC
    """

    def __init__(self,
                 experience_spec,
                 batch_size,
                 max_length,
                 num_earliest_frames_ignored=0,
                 prioritized_sampling=False,
                 name="SyncExperienceReplayer"):
        """Create a SyncExperienceReplayer.

        Args:
            experience_spec (nested TensorSpec): spec describing a single item
                that can be stored in the replayer.
            batch_size (int): number of environments.
            max_length (int): the maximum number of items that can be stored
                for a single environment.
            num_earliest_frames_ignored (int): ignore this many of the earliest
                frames when sampling from the buffer. This is typically
                required when FrameStack is used.
            prioritized_sampling (bool): use prioritized sampling if this is
                True.
        """
        super().__init__()
        self._experience_spec = experience_spec
        self._buffer = ReplayBuffer(
            experience_spec,
            batch_size,
            max_length=max_length,
            prioritized_sampling=prioritized_sampling,
            num_earliest_frames_ignored=num_earliest_frames_ignored,
            name=name)
        self._data_iter = None

    def observe(self, exp):
        """
        For the sync driver, `exp` has the shape (`env_batch_size`, ...)
        with `num_envs`==1 and `unroll_length`==1.
        """
        outer_rank = alf.nest.utils.get_outer_rank(exp, self._experience_spec)

        if outer_rank == 1:
            self._buffer.add_batch(exp, exp.env_id)
        elif outer_rank == 3:
            # The shape is [learn_queue_cap, unroll_length, env_batch_size, ...]
            for q in range(exp.step_type.shape[0]):
                for t in range(exp.step_type.shape[1]):
                    bat = alf.nest.map_structure(lambda x: x[q, t, ...], exp)
                    self._buffer.add_batch(bat, bat.env_id)
        else:
            raise ValueError("Unsupported outer rank %s of `exp`" % outer_rank)

    def replay(self, sample_batch_size, mini_batch_length):
        """Get a random batch.

        Args:
            sample_batch_size (int): number of sequences.
            mini_batch_length (int): the length of each sequence.
        Returns:
            tuple:
                - nested Tensors: the samples. Their shapes are
                  [sample_batch_size, mini_batch_length, ...].
                - BatchInfo: information about the batch. Its shapes are
                  [sample_batch_size].
                    - env_ids: environment id for each sequence.
                    - positions: starting position in the replay buffer for
                      each sequence.
                    - importance_weights: the priority of each sequence
                      divided by the average of all non-zero priorities in
                      the buffer.
        """
        return self._buffer.get_batch(sample_batch_size, mini_batch_length)

    def replay_all(self):
        return self._buffer.gather_all()

    def clear(self):
        self._buffer.clear()

    def update_priority(self, env_ids, positions, priorities):
        """Update the priorities for the given experiences.

        Args:
            env_ids (Tensor): 1-D int64 Tensor.
            positions (Tensor): 1-D int64 Tensor with the same shape as
                ``env_ids``. The positions should be obtained from the
                BatchInfo returned by ``get_batch()``.
            priorities (Tensor): 1-D float Tensor with the same shape as
                ``env_ids``, giving the new priorities.
        """
        self._buffer.update_priority(env_ids, positions, priorities)

    @property
    def batch_size(self):
        return self._buffer.num_environments

    @property
    def total_size(self):
        return self._buffer.total_size

    @property
    def replay_buffer(self):
        return self._buffer
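
# ----------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the module above). The
# experience structure ``Exp``, the helper name ``_usage_sketch`` and the
# concrete shapes below are assumptions made up for this example; a real
# experience nest only needs to match ``experience_spec`` and expose the
# ``env_id`` (and, for outer rank 3, ``step_type``) fields that ``observe()``
# relies on.
# ----------------------------------------------------------------------------
def _usage_sketch():
    from collections import namedtuple

    import torch
    import alf

    Exp = namedtuple('Exp', ['step_type', 'env_id', 'observation'])
    spec = Exp(
        step_type=alf.TensorSpec((), torch.int32),
        env_id=alf.TensorSpec((), torch.int64),
        observation=alf.TensorSpec((4, )))

    replayer = SyncExperienceReplayer(
        experience_spec=spec, batch_size=2, max_length=100)

    # One step of experience from 2 environments (outer rank 1 w.r.t. the spec).
    exp = Exp(
        step_type=torch.zeros(2, dtype=torch.int32),
        env_id=torch.arange(2),
        observation=torch.randn(2, 4))
    replayer.observe(exp)

    # Sample 8 sequences of length 1; sample shapes are [8, 1, ...].
    samples, batch_info = replayer.replay(
        sample_batch_size=8, mini_batch_length=1)
    return samples, batch_info
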
def test_prioritized_replay(self):
    replay_buffer = ReplayBuffer(
        data_spec=self.data_spec,
        num_environments=self.num_envs,
        max_length=self.max_length,
        prioritized_sampling=True)
    self.assertRaises(AssertionError, replay_buffer.get_batch, 1, 1)

    batch1 = get_batch([1], self.dim, x=0.25, t=0)
    replay_buffer.add_batch(batch1, batch1.env_id)
    batch, batch_info = replay_buffer.get_batch(1, 1)
    self.assertEqual(batch_info.env_ids,
                     torch.tensor([1], dtype=torch.int64))
    self.assertEqual(batch_info.importance_weights, 1.)
    self.assertEqual(batch_info.importance_weights, torch.tensor([1.]))

    self.assertRaises(AssertionError, replay_buffer.get_batch, 1, 2)

    batch2 = get_batch([1], self.dim, x=0.5, t=1)
    replay_buffer.add_batch(batch1, batch1.env_id)
    batch, batch_info = replay_buffer.get_batch(4, 2)
    self.assertEqual(batch_info.env_ids,
                     torch.tensor([1], dtype=torch.int64))
    self.assertEqual(batch_info.importance_weights, torch.tensor([1.]))
    self.assertEqual(batch_info.importance_weights, torch.tensor([1.] * 4))

    batch, batch_info = replay_buffer.get_batch(1000, 1)
    n0 = (replay_buffer.circular(batch_info.positions) == 0).sum()
    n1 = (replay_buffer.circular(batch_info.positions) == 1).sum()
    self.assertEqual(n0, 500)
    self.assertEqual(n1, 500)

    replay_buffer.update_priority(
        env_ids=torch.tensor([1, 1], dtype=torch.int64),
        positions=torch.tensor([0, 1], dtype=torch.int64),
        priorities=torch.tensor([0.5, 1.5]))
    batch, batch_info = replay_buffer.get_batch(1000, 1)
    n0 = (replay_buffer.circular(batch_info.positions) == 0).sum()
    n1 = (replay_buffer.circular(batch_info.positions) == 1).sum()
    self.assertEqual(n0, 250)
    self.assertEqual(n1, 750)

    batch2 = get_batch([0, 2], self.dim, x=0.5, t=1)
    replay_buffer.add_batch(batch2, batch2.env_id)
    batch, batch_info = replay_buffer.get_batch(1000, 1)

    def _get(env_id, pos):
        flag = ((batch_info.env_ids == env_id) *
                (batch_info.positions == replay_buffer._pad(pos, env_id)))
        w = batch_info.importance_weights[torch.nonzero(
            flag, as_tuple=True)[0]]
        return flag.sum(), w

    n0, w0 = _get(0, 0)
    n1, w1 = _get(1, 0)
    n2, w2 = _get(1, 1)
    n3, w3 = _get(2, 0)
    self.assertEqual(n0, 300)
    self.assertEqual(n1, 100)
    self.assertEqual(n2, 300)
    self.assertEqual(n3, 300)
    self.assertTrue(torch.all(w0 == 1.2))
    self.assertTrue(torch.all(w1 == 0.4))
    self.assertTrue(torch.all(w2 == 1.2))
    self.assertTrue(torch.all(w3 == 1.2))

    replay_buffer.update_priority(
        env_ids=torch.tensor([1, 2], dtype=torch.int64),
        positions=torch.tensor([1, 0], dtype=torch.int64),
        priorities=torch.tensor([1.0, 1.0]))
    batch, batch_info = replay_buffer.get_batch(1000, 1)

    n0, w0 = _get(0, 0)
    n1, w1 = _get(1, 0)
    n2, w2 = _get(1, 1)
    n3, w3 = _get(2, 0)
    self.assertEqual(n0, 375)
    self.assertEqual(n1, 125)
    self.assertEqual(n2, 250)
    self.assertEqual(n3, 250)
    self.assertTrue(torch.all(w0 == 1.5))
    self.assertTrue(torch.all(w1 == 0.5))
    self.assertTrue(torch.all(w2 == 1.0))
    self.assertTrue(torch.all(w3 == 1.0))
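
# ----------------------------------------------------------------------------
# Sketch of the arithmetic behind the expected counts and weights in the test
# above (illustration only, based on two assumptions consistent with the
# test's numbers: newly added items default to the current maximum priority,
# and an item's importance weight is its priority divided by the mean of all
# non-zero priorities). The four items are ordered as
# (env0, pos0), (env1, pos0), (env1, pos1), (env2, pos0).
# ----------------------------------------------------------------------------
def _priority_arithmetic_sketch():
    import torch

    # Before the second update_priority(): the two freshly added items are
    # assumed to take the current max priority 1.5.
    before = torch.tensor([1.5, 0.5, 1.5, 1.5])
    # Sampling probability is proportional to priority:
    # 1000 * before / before.sum() -> [300, 100, 300, 300]
    weights_before = before / before[before > 0].mean()  # [1.2, 0.4, 1.2, 1.2]

    # After update_priority([1, 2], [1, 0], [1.0, 1.0]):
    after = torch.tensor([1.5, 0.5, 1.0, 1.0])
    counts = 1000 * after / after.sum()                   # [375, 125, 250, 250]
    weights_after = after / after[after > 0].mean()       # [1.5, 0.5, 1.0, 1.0]
    return weights_before, counts, weights_after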