class SyncUniformExperienceReplayer(ExperienceReplayer):
    """
    For synchronous off-policy training.

    Example algorithms: DDPG, SAC
    """

    def __init__(self, experience_spec, batch_size):
        self._experience_spec = experience_spec
        self._buffer = ReplayBuffer(experience_spec, batch_size)
        self._data_iter = None

    @tf.function
    def observe(self, exp):
        """Store one batch of experience into the replay buffer.

        Args:
            exp (Experience): input experience to be stored. For the sync
                driver, `exp` has the shape (`env_batch_size`, ...) with
                `num_envs` == 1 and `unroll_length` == 1.
        """
        outer_rank = get_outer_rank(exp, self._experience_spec)

        if outer_rank == 1:
            self._buffer.add_batch(exp, exp.env_id)
        elif outer_rank == 3:
            # The shape is [learn_queue_cap, unroll_length, env_batch_size, ...]
            for q in tf.range(tf.shape(exp.step_type)[0]):
                for t in tf.range(tf.shape(exp.step_type)[1]):
                    bat = tf.nest.map_structure(lambda x: x[q, t, ...], exp)
                    self._buffer.add_batch(bat, bat.env_id)
        else:
            raise ValueError("Unsupported outer rank %s of `exp`" % outer_rank)

    def replay(self, sample_batch_size, mini_batch_length):
        """Get a random batch.

        Args:
            sample_batch_size (int): number of sequences
            mini_batch_length (int): the length of each sequence

        Returns:
            Experience: experience batch in batch major (B, T, ...)
        """
        return self._buffer.get_batch(sample_batch_size, mini_batch_length)

    def replay_all(self):
        return self._buffer.gather_all()

    def clear(self):
        self._buffer.clear()

    @property
    def batch_size(self):
        return self._buffer.num_environments
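A minimal usage sketch of the replayer above, assuming the surrounding ALF imports (`ExperienceReplayer`, `ReplayBuffer`, `get_outer_rank`) are available. The `Exp` namedtuple and its spec are invented for illustration and stand in for the real `Experience` structure; only the constructor, `observe()`, and `replay()` signatures come from the class itself.

import collections
import tensorflow as tf

# Hypothetical experience structure for the example; a real Experience has
# more fields, but only `env_id` (and the outer batch shape) matters here.
Exp = collections.namedtuple('Exp', ['env_id', 'step_type', 'reward'])

exp_spec = Exp(
    env_id=tf.TensorSpec(shape=(), dtype=tf.int32),
    step_type=tf.TensorSpec(shape=(), dtype=tf.int32),
    reward=tf.TensorSpec(shape=(), dtype=tf.float32))

env_batch_size = 4
replayer = SyncUniformExperienceReplayer(exp_spec, batch_size=env_batch_size)

# One sync-driver step: outer shape (env_batch_size, ...), so observe() takes
# the outer_rank == 1 branch and appends one timestep for each environment.
replayer.observe(
    Exp(env_id=tf.range(env_batch_size, dtype=tf.int32),
        step_type=tf.zeros((env_batch_size, ), tf.int32),
        reward=tf.ones((env_batch_size, ), tf.float32)))

# Sample 2 random sequences of length 1, returned batch major as (B, T, ...).
training_batch = replayer.replay(sample_batch_size=2, mini_batch_length=1)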
class CyclicOneTimeExperienceReplayer(SyncUniformExperienceReplayer):
    """
    A one-time experience replayer that stores a total of T + 1 timesteps, so
    that every T timesteps of rollout, the last step of the previous rollout
    plus the new T timesteps are stored and used in training. This ensures
    that every timestep is used in computing the loss.

    Every replay_all() is assumed to follow an observe() sequentially.
    Not thread-safe.

    Example algorithms: IMPALA, PPO2
    """

    def __init__(self, experience_spec, batch_size, num_actors, rollout_length,
                 learn_queue_cap):
        # Set max_length of the buffer to rollout_length + 1 to ensure
        # all timesteps are used in training.
        self._experience_spec = experience_spec
        self._buffer = ReplayBuffer(
            experience_spec, batch_size, max_length=rollout_length + 1)
        self._data_iter = None

        assert num_actors > 0
        assert batch_size % num_actors == 0
        assert learn_queue_cap > 0
        assert num_actors % learn_queue_cap == 0

        # The replay buffer can contain exp from all actors, but only stores or
        # retrieves from a fixed number of actors (as many as learn_queue_cap)
        # on every observe or replay_all.
        #
        # We store the env_ids corresponding to the observe, and later
        # retrieve the corresponding experience from the buffer using these ids.
        shape = (batch_size // num_actors * learn_queue_cap, 1)
        self._env_id_shape = shape
        self._env_id = tf.Variable(
            tf.zeros(shape, dtype=tf.int32), shape=shape, dtype=tf.int32)
        # Number of exp's buffered, to check that every observe is replayed once.
        self._buffered = tf.Variable(tf.zeros((), dtype=tf.int32))

    @tf.function
    def observe(self, exp):
        """Store one batch of experience into the replay buffer.

        Args:
            exp (Experience): input experience to be stored.
        """
        super().observe(exp)
        tf.debugging.assert_equal(self._buffered, 0)
        # Get batch env_ids from only the first TimeStep.
        self._env_id.assign(tf.reshape(exp.env_id[:, 0], self._env_id_shape))
        self._buffered.assign_add(1)

    @tf.function
    def replay_all(self):
        # Only replays the last gathered batch of environments,
        # which corresponds to one actor in the case of IMPALA.
        tf.debugging.assert_equal(self._buffered, 1)
        self._buffered.assign_add(-1)
        return self._buffer.gather_all(self._env_id)

    def clear(self):
        # No need to clear, as a new batch of timesteps overwrites the oldest
        # timesteps in the buffer.
        pass

    def replay(self, unused_sample_batch_size, unused_mini_batch_length):
        """Not to be used; use replay_all() instead.

        Args:
            unused_sample_batch_size (int): number of sequences
            unused_mini_batch_length (int): the length of each sequence
        """
        raise Exception('Should not be called, use replay_all instead.')
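The contract implied by the `_buffered` bookkeeping is easiest to see in a sketch. The sizes below are illustrative, and `experience_spec` and `exp_from_learn_queue` are placeholders for whatever the async learn queue produces; only the constructor arguments and the observe()/replay_all() pairing come from the class above.

replayer = CyclicOneTimeExperienceReplayer(
    experience_spec,        # placeholder: spec of a single timestep
    batch_size=8,           # total environments across all actors
    num_actors=2,           # so 4 environments per actor
    rollout_length=16,      # the buffer keeps 16 + 1 = 17 timesteps per env
    learn_queue_cap=1)      # one actor's rollout is consumed per learn step

# Placeholder experience with outer shape
# [learn_queue_cap, unroll_length, envs_per_actor, ...].
replayer.observe(exp_from_learn_queue)

# Must follow every observe() exactly once; it returns only the environments
# recorded by that observe() (one actor's worth in the IMPALA setting).
batch = replayer.replay_all()

# Calling observe() twice in a row, or replay_all() twice in a row, trips the
# tf.debugging.assert_equal checks on self._buffered.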
def test_replay_buffer(self):
    dim = 20
    max_length = 4
    num_envs = 8
    data_spec = DataItem(
        env_id=tf.TensorSpec(shape=(), dtype=tf.int32),
        x=tf.TensorSpec(shape=(dim, ), dtype=tf.float32),
        t=tf.TensorSpec(shape=(), dtype=tf.int32))
    replay_buffer = ReplayBuffer(
        data_spec=data_spec,
        num_environments=num_envs,
        max_length=max_length)

    def _get_batch(env_ids, t, x):
        batch_size = len(env_ids)
        x = (x * tf.expand_dims(tf.range(batch_size, dtype=tf.float32), 1) *
             tf.expand_dims(tf.range(dim, dtype=tf.float32), 0))
        return DataItem(
            env_id=tf.constant(env_ids),
            x=x,
            t=t * tf.ones((batch_size, ), tf.int32))

    batch1 = _get_batch([0, 4, 7], t=0, x=0.1)
    replay_buffer.add_batch(batch1, batch1.env_id)
    self.assertArrayEqual(replay_buffer._current_size,
                          [1, 0, 0, 0, 1, 0, 0, 1])
    self.assertArrayEqual(replay_buffer._current_pos,
                          [1, 0, 0, 0, 1, 0, 0, 1])
    with self.assertRaises(tf.errors.InvalidArgumentError):
        replay_buffer.get_batch(8, 1)

    batch2 = _get_batch([1, 2, 3, 5, 6], t=0, x=0.2)
    replay_buffer.add_batch(batch2, batch2.env_id)
    self.assertArrayEqual(replay_buffer._current_size,
                          [1, 1, 1, 1, 1, 1, 1, 1])
    self.assertArrayEqual(replay_buffer._current_pos,
                          [1, 1, 1, 1, 1, 1, 1, 1])

    batch = replay_buffer.gather_all()
    self.assertEqual(batch.t.shape, [8, 1])

    with self.assertRaises(tf.errors.InvalidArgumentError):
        replay_buffer.get_batch(8, 2)

    replay_buffer.get_batch(13, 1)
    batch = replay_buffer.get_batch(8, 1)
    # squeeze the time dimension
    batch = tf.nest.map_structure(lambda bat: tf.squeeze(bat, axis=1), batch)
    bat1 = tf.nest.map_structure(
        lambda bat: tf.gather(bat, batch1.env_id, axis=0), batch)
    bat2 = tf.nest.map_structure(
        lambda bat: tf.gather(bat, batch2.env_id, axis=0), batch)
    self.assertArrayEqual(bat1.env_id, batch1.env_id)
    self.assertArrayEqual(bat1.x, batch1.x)
    self.assertArrayEqual(bat1.t, batch1.t)
    self.assertArrayEqual(bat2.env_id, batch2.env_id)
    self.assertArrayEqual(bat2.x, batch2.x)
    self.assertArrayEqual(bat2.t, batch2.t)

    for t in range(1, 10):
        batch3 = _get_batch([0, 4, 7], t=t, x=0.3)
        j = (t + 1) % max_length
        s = min(t + 1, max_length)
        replay_buffer.add_batch(batch3, batch3.env_id)
        self.assertArrayEqual(replay_buffer._current_size,
                              [s, 1, 1, 1, s, 1, 1, s])
        self.assertArrayEqual(replay_buffer._current_pos,
                              [j, 1, 1, 1, j, 1, 1, j])

    batch2 = _get_batch([1, 2, 3, 5, 6], t=1, x=0.2)
    replay_buffer.add_batch(batch2, batch2.env_id)
    batch = replay_buffer.get_batch(8, 1)
    # squeeze the time dimension
    batch = tf.nest.map_structure(lambda bat: tf.squeeze(bat, axis=1), batch)
    bat3 = tf.nest.map_structure(
        lambda bat: tf.gather(bat, batch3.env_id, axis=0), batch)
    bat2 = tf.nest.map_structure(
        lambda bat: tf.gather(bat, batch2.env_id, axis=0), batch)
    self.assertArrayEqual(bat3.env_id, batch3.env_id)
    self.assertArrayEqual(bat3.x, batch3.x)
    self.assertArrayEqual(bat2.env_id, batch2.env_id)
    self.assertArrayEqual(bat2.x, batch2.x)

    batch = replay_buffer.get_batch(8, 2)
    t2 = []
    t3 = []
    for t in range(2):
        batch_t = tf.nest.map_structure(lambda b: b[:, t], batch)
        bat3 = tf.nest.map_structure(
            lambda bat: tf.gather(bat, batch3.env_id, axis=0), batch_t)
        bat2 = tf.nest.map_structure(
            lambda bat: tf.gather(bat, batch2.env_id, axis=0), batch_t)
        t2.append(bat2.t)
        self.assertArrayEqual(bat3.env_id, batch3.env_id)
        self.assertArrayEqual(bat3.x, batch3.x)
        self.assertArrayEqual(bat2.env_id, batch2.env_id)
        self.assertArrayEqual(bat2.x, batch2.x)
        t3.append(bat3.t)

    # Test time consistency
    self.assertArrayEqual(t2[0] + 1, t2[1])
    self.assertArrayEqual(t3[0] + 1, t3[1])

    batch = replay_buffer.get_batch(128, 2)
    self.assertArrayEqual(batch.t[:, 0] + 1, batch.t[:, 1])
    self.assertEqual(batch.t.shape, [128, 2])

    batch = replay_buffer.get_batch(10, 2)
    self.assertArrayEqual(batch.t[:, 0] + 1, batch.t[:, 1])
    self.assertEqual(batch.t.shape, [10, 2])

    batch = replay_buffer.get_batch(4, 2)
    self.assertArrayEqual(batch.t[:, 0] + 1, batch.t[:, 1])
    self.assertEqual(batch.t.shape, [4, 2])

    # Test gather_all()
    with self.assertRaises(tf.errors.InvalidArgumentError):
        replay_buffer.gather_all()

    for t in range(2, 10):
        batch4 = _get_batch([1, 2, 3, 5, 6], t=t, x=0.4)
        replay_buffer.add_batch(batch4, batch4.env_id)

    batch = replay_buffer.gather_all()
    self.assertEqual(batch.t.shape, [8, 4])
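The `s`/`j` assertions in the loop above encode the ring-buffer bookkeeping of this TF ReplayBuffer. A tiny standalone helper (hypothetical, not part of ALF) makes the rule explicit:

def ring_indices(num_added, max_length):
    """Size and write position of one environment after `num_added` add_batch calls."""
    current_size = min(num_added, max_length)   # valid items, capped at capacity
    current_pos = num_added % max_length        # next slot to overwrite (wraps)
    return current_size, current_pos

# Matches the loop above: at t = 9 the environment has received 10 items,
# so with max_length = 4 we expect s = 4 and j = 10 % 4 = 2.
assert ring_indices(10, 4) == (4, 2)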
class SyncExperienceReplayer(ExperienceReplayer):
    """
    For synchronous off-policy training.

    Example algorithms: DDPG, SAC
    """

    def __init__(self,
                 experience_spec,
                 batch_size,
                 max_length,
                 num_earliest_frames_ignored=0,
                 prioritized_sampling=False,
                 name="SyncExperienceReplayer"):
        """Create a ReplayBuffer.

        Args:
            experience_spec (nested TensorSpec): spec describing a single item
                that can be stored in the replayer.
            batch_size (int): number of environments.
            max_length (int): the maximum number of items that can be stored
                for a single environment.
            num_earliest_frames_ignored (int): ignore the earliest so many
                frames when sampling from the buffer. This is typically
                required when FrameStack is used.
            prioritized_sampling (bool): use prioritized sampling if this is
                True.
            name (str): name of the underlying ReplayBuffer.
        """
        super().__init__()
        self._experience_spec = experience_spec
        self._buffer = ReplayBuffer(
            experience_spec,
            batch_size,
            max_length=max_length,
            prioritized_sampling=prioritized_sampling,
            num_earliest_frames_ignored=num_earliest_frames_ignored,
            name=name)
        self._data_iter = None

    def observe(self, exp):
        """Store one batch of experience into the replay buffer.

        For the sync driver, `exp` has the shape (`env_batch_size`, ...)
        with `num_envs` == 1 and `unroll_length` == 1.
        """
        outer_rank = alf.nest.utils.get_outer_rank(exp, self._experience_spec)

        if outer_rank == 1:
            self._buffer.add_batch(exp, exp.env_id)
        elif outer_rank == 3:
            # The shape is [learn_queue_cap, unroll_length, env_batch_size, ...]
            for q in range(exp.step_type.shape[0]):
                for t in range(exp.step_type.shape[1]):
                    bat = alf.nest.map_structure(lambda x: x[q, t, ...], exp)
                    self._buffer.add_batch(bat, bat.env_id)
        else:
            raise ValueError("Unsupported outer rank %s of `exp`" % outer_rank)

    def replay(self, sample_batch_size, mini_batch_length):
        """Get a random batch.

        Args:
            sample_batch_size (int): number of sequences
            mini_batch_length (int): the length of each sequence

        Returns:
            tuple:
            - nested Tensors: the samples. Their shapes are
              [batch_size, batch_length, ...].
            - BatchInfo: information about the batch. Its shapes are
              [batch_size].
                - env_ids: environment id for each sequence
                - positions: starting position in the replay buffer for each
                  sequence
                - importance_weights: importance weight divided by the average
                  of all non-zero importance weights in the buffer
        """
        return self._buffer.get_batch(sample_batch_size, mini_batch_length)

    def replay_all(self):
        return self._buffer.gather_all()

    def clear(self):
        self._buffer.clear()

    def update_priority(self, env_ids, positions, priorities):
        """Update the priorities for the given experiences.

        Args:
            env_ids (Tensor): 1-D int64 Tensor.
            positions (Tensor): 1-D int64 Tensor with the same shape as
                ``env_ids``. The positions should be obtained from the
                BatchInfo returned by ``get_batch()``.
            priorities (Tensor): 1-D Tensor with the same shape as ``env_ids``,
                giving the new priorities.
        """
        self._buffer.update_priority(env_ids, positions, priorities)

    @property
    def batch_size(self):
        return self._buffer.num_environments

    @property
    def total_size(self):
        return self._buffer.total_size

    @property
    def replay_buffer(self):
        return self._buffer
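A hedged sketch of the prioritized-replay round trip through this replayer. `experience_spec`, `num_envs`, `exp`, and `compute_sequence_priorities` are placeholders for whatever the training pipeline provides; only the observe()/replay()/update_priority() signatures and the BatchInfo fields come from the code above.

replayer = SyncExperienceReplayer(
    experience_spec,                # placeholder: spec of a single timestep
    batch_size=num_envs,            # number of parallel environments
    max_length=100000,
    num_earliest_frames_ignored=3,  # e.g. when FrameStack(4) is used
    prioritized_sampling=True)

# Rollout: `exp` has outer shape (num_envs, ...) under the sync driver.
replayer.observe(exp)

# Training: sample sequences together with their BatchInfo bookkeeping.
samples, batch_info = replayer.replay(sample_batch_size=64, mini_batch_length=8)

# Write the new priorities back for exactly the sequences just sampled.
new_priorities = compute_sequence_priorities(samples)  # placeholder, shape [64]
replayer.update_priority(batch_info.env_ids, batch_info.positions, new_priorities)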
def test_replay_buffer(self, allow_multiprocess, with_replacement):
    replay_buffer = ReplayBuffer(
        data_spec=self.data_spec,
        num_environments=self.num_envs,
        max_length=self.max_length,
        allow_multiprocess=allow_multiprocess)

    batch1 = get_batch([0, 4, 7], self.dim, t=0, x=0.1)
    replay_buffer.add_batch(batch1, batch1.env_id)
    self.assertEqual(replay_buffer._current_size,
                     torch.tensor([1, 0, 0, 0, 1, 0, 0, 1]))
    self.assertEqual(replay_buffer._current_pos,
                     torch.tensor([1, 0, 0, 0, 1, 0, 0, 1]))
    self.assertRaises(AssertionError, replay_buffer.get_batch, 8, 1)

    batch2 = get_batch([1, 2, 3, 5, 6], self.dim, t=0, x=0.2)
    replay_buffer.add_batch(batch2, batch2.env_id)
    self.assertEqual(replay_buffer._current_size,
                     torch.tensor([1, 1, 1, 1, 1, 1, 1, 1]))
    self.assertEqual(replay_buffer._current_pos,
                     torch.tensor([1, 1, 1, 1, 1, 1, 1, 1]))

    batch = replay_buffer.gather_all()
    self.assertEqual(list(batch.t.shape), [8, 1])

    # test that RingBuffer detaches gradients of inputs
    self.assertFalse(batch.x.requires_grad)

    self.assertRaises(AssertionError, replay_buffer.get_batch, 8, 2)
    replay_buffer.get_batch(13, 1)[0]

    batch = replay_buffer.get_batch(8, 1)[0]
    # squeeze the time dimension
    batch = alf.nest.map_structure(lambda bat: bat.squeeze(1), batch)
    bat1 = alf.nest.map_structure(lambda bat: bat[batch1.env_id], batch)
    bat2 = alf.nest.map_structure(lambda bat: bat[batch2.env_id], batch)
    self.assertEqual(bat1.env_id, batch1.env_id)
    self.assertEqual(bat1.x, batch1.x)
    self.assertEqual(bat1.t, batch1.t)
    self.assertEqual(bat2.env_id, batch2.env_id)
    self.assertEqual(bat2.x, batch2.x)
    self.assertEqual(bat2.t, batch2.t)

    for t in range(1, 10):
        batch3 = get_batch([0, 4, 7], self.dim, t=t, x=0.3)
        j = t + 1
        s = min(t + 1, self.max_length)
        replay_buffer.add_batch(batch3, batch3.env_id)
        self.assertEqual(replay_buffer._current_size,
                         torch.tensor([s, 1, 1, 1, s, 1, 1, s]))
        self.assertEqual(replay_buffer._current_pos,
                         torch.tensor([j, 1, 1, 1, j, 1, 1, j]))

    batch2 = get_batch([1, 2, 3, 5, 6], self.dim, t=1, x=0.2)
    replay_buffer.add_batch(batch2, batch2.env_id)
    batch = replay_buffer.get_batch(8, 1)[0]
    # squeeze the time dimension
    batch = alf.nest.map_structure(lambda bat: bat.squeeze(1), batch)
    bat3 = alf.nest.map_structure(lambda bat: bat[batch3.env_id], batch)
    bat2 = alf.nest.map_structure(lambda bat: bat[batch2.env_id], batch)
    self.assertEqual(bat3.env_id, batch3.env_id)
    self.assertEqual(bat3.x, batch3.x)
    self.assertEqual(bat2.env_id, batch2.env_id)
    self.assertEqual(bat2.x, batch2.x)

    batch = replay_buffer.get_batch(8, 2)[0]
    t2 = []
    t3 = []
    for t in range(2):
        batch_t = alf.nest.map_structure(lambda b: b[:, t], batch)
        bat3 = alf.nest.map_structure(lambda bat: bat[batch3.env_id], batch_t)
        bat2 = alf.nest.map_structure(lambda bat: bat[batch2.env_id], batch_t)
        t2.append(bat2.t)
        self.assertEqual(bat3.env_id, batch3.env_id)
        self.assertEqual(bat3.x, batch3.x)
        self.assertEqual(bat2.env_id, batch2.env_id)
        self.assertEqual(bat2.x, batch2.x)
        t3.append(bat3.t)

    # Test time consistency
    self.assertEqual(t2[0] + 1, t2[1])
    self.assertEqual(t3[0] + 1, t3[1])

    batch = replay_buffer.get_batch(128, 2)[0]
    self.assertEqual(batch.t[:, 0] + 1, batch.t[:, 1])
    self.assertEqual(list(batch.t.shape), [128, 2])

    batch = replay_buffer.get_batch(10, 2)[0]
    self.assertEqual(batch.t[:, 0] + 1, batch.t[:, 1])
    self.assertEqual(list(batch.t.shape), [10, 2])

    batch = replay_buffer.get_batch(4, 2)[0]
    self.assertEqual(batch.t[:, 0] + 1, batch.t[:, 1])
    self.assertEqual(list(batch.t.shape), [4, 2])

    # Test gather_all()
    # Expect an exception because the sizes of the environments are not all
    # the same.
    self.assertRaises(AssertionError, replay_buffer.gather_all)

    for t in range(2, 10):
        batch4 = get_batch([1, 2, 3, 5, 6], self.dim, t=t, x=0.4)
        replay_buffer.add_batch(batch4, batch4.env_id)
    batch = replay_buffer.gather_all()
    self.assertEqual(list(batch.t.shape), [8, 4])

    # Test clear()
    replay_buffer.clear()
    self.assertEqual(replay_buffer.total_size, 0)
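For contrast with the TF test earlier: in the loop above the PyTorch buffer's `_current_pos` keeps growing past `max_length` (j = t + 1, unbounded) while `_current_size` caps at `max_length`. A hypothetical helper, not ALF code, spells out the bookkeeping implied by those assertions; the modulo write slot is an inference from the capped size.

def ring_indices_torch(num_added, max_length):
    """Size and (unwrapped) position of one environment after `num_added` adds."""
    current_size = min(num_added, max_length)  # valid items, capped at capacity
    current_pos = num_added                    # monotonically increasing counter
    write_slot = current_pos % max_length      # physical slot of the next write
    return current_size, current_pos, write_slot

# After the t-loop above (10 additions, max_length = 4): size 4, pos 10.
assert ring_indices_torch(10, 4) == (4, 10, 2)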