def add(self, batch: SampleBatchType, **kwargs) -> None:
    """Adds a batch of experiences using reservoir sampling.

    While fewer than ``self.capacity`` timesteps have been offered in total,
    the batch is stored through the base ``ReplayBuffer.add``. After that,
    each incoming batch overwrites a uniformly chosen occupied slot with
    probability ``len(self._storage) / self._num_add_calls`` and is
    discarded otherwise.

    Args:
        batch: SampleBatch to add to this buffer's storage.
        **kwargs: Accepted for interface compatibility; not used here.
    """
    # Update add counts.
    self._num_add_calls += 1
    # Update our timesteps counts.
    self._num_timesteps_added += batch.count
    self._num_timesteps_added_wrap += batch.count

    if self._num_timesteps_added < self.capacity:
        ReplayBuffer.add(self, batch)
        return

    # Eviction of older samples has already started (buffer is "full").
    self._eviction_started = True
    # Draw from [0, num_add_calls - 1] (randint is inclusive) and only
    # replace when the index hits an occupied slot. `self.capacity` is
    # compared against timestep counts above, and the storage may hold
    # fewer than `capacity` entries when eviction starts (the batch that
    # first crosses the threshold is never appended). Bounding by
    # `capacity` could therefore index past the end of `self._storage`.
    idx = random.randint(0, self._num_add_calls - 1)
    if idx < len(self._storage):
        self._num_evicted += 1
        self._evicted_hit_stats.push(self._hit_count[idx])
        self._hit_count[idx] = 0
        self._storage[idx] = batch

        assert batch.count > 0, batch
        warn_replay_capacity(item=batch, num_items=self.capacity / batch.count)
def test_episodes_unit(self):
    """Tests adding, sampling, and eviction of episodes."""

    def sampled_fractions(buf, num_ids, num_samples=200):
        """Samples `buf` one episode at a time and returns, for each episode
        id in ``range(num_ids)``, the fraction of draws that returned it."""
        counts = {_id: 0 for _id in range(num_ids)}
        for _ in range(num_samples):
            sample = buf.sample(1)
            # Every draw must contain exactly one episode.
            assert len(sample[SampleBatch.SEQ_LENS]) == 1
            counts[sample[SampleBatch.EPS_ID][0]] += 1
        return np.array(list(counts.values())) / num_samples

    buffer = ReplayBuffer(capacity=18, storage_unit="episodes")
    batches = [
        SampleBatch({
            SampleBatch.T: [0, 1, 2, 3],
            SampleBatch.ACTIONS: 4 * [np.random.choice([0, 1])],
            SampleBatch.REWARDS: 4 * [np.random.rand()],
            SampleBatch.DONES: [False, False, False, True],
            SampleBatch.SEQ_LENS: [4],
            SampleBatch.EPS_ID: 4 * [i],
        }) for i in range(3)
    ]
    batches.append(
        SampleBatch({
            SampleBatch.T: [0, 1, 0, 1],
            SampleBatch.ACTIONS: 4 * [np.random.choice([0, 1])],
            SampleBatch.REWARDS: 4 * [np.random.rand()],
            SampleBatch.DONES: [False, True, False, True],
            SampleBatch.SEQ_LENS: [2, 2],
            SampleBatch.EPS_ID: [3, 3, 4, 4],
        }))
    for batch in batches:
        buffer.add(batch)

    # All episodes should be sampled equally often, even though they were
    # added in different batches.
    assert np.allclose(
        sampled_fractions(buffer, 5),
        [1 / 5, 1 / 5, 1 / 5, 1 / 5, 1 / 5],
        atol=0.1,
    )

    # Episode 6 does not end inside this batch (its last DONES flag is
    # False), so it should not be added to the buffer.
    buffer.add(
        SampleBatch({
            SampleBatch.T: [0, 1, 0, 1],
            SampleBatch.ACTIONS: 4 * [np.random.choice([0, 1])],
            SampleBatch.REWARDS: 4 * [np.random.rand()],
            SampleBatch.DONES: [False, True, False, False],
            SampleBatch.SEQ_LENS: [2, 2],
            SampleBatch.EPS_ID: [5, 5, 6, 6],
        }))

    # Episode 6 was dropped for not ending inside the batch, so it must
    # never be sampled.
    assert np.allclose(
        sampled_fractions(buffer, 7),
        [1 / 6, 1 / 6, 1 / 6, 1 / 6, 1 / 6, 1 / 6, 0],
        atol=0.1,
    )

    # Add another batch to evict the first batch.
    buffer.add(
        SampleBatch({
            SampleBatch.T: [0, 1, 2, 3],
            SampleBatch.ACTIONS: 4 * [np.random.choice([0, 1])],
            SampleBatch.REWARDS: 4 * [np.random.rand()],
            SampleBatch.DONES: [False, False, False, True],
            SampleBatch.SEQ_LENS: [4],
            SampleBatch.EPS_ID: 4 * [7],
        }))
    # After adding 1 more batch, eviction has started. 24 timesteps have
    # been offered in total, 2 of which (episode 6) were discarded.
    assert len(buffer) == 6
    assert buffer._num_timesteps_added == 4 * 6 - 2
    assert buffer._num_timesteps_added_wrap == 4
    assert buffer._next_idx == 1
    assert buffer._eviction_started is True

    # Episode 0 has been overwritten by episode 7, and episode 6 was never
    # stored.
    assert np.allclose(
        sampled_fractions(buffer, 8),
        [0, 1 / 6, 1 / 6, 1 / 6, 1 / 6, 1 / 6, 0, 1 / 6],
        atol=0.1,
    )
def test_sequences_unit(self):
    """Tests adding, sampling and eviction of sequences."""

    def sampled_fractions(buf, ids, num_samples=200):
        """Samples `buf` one sequence at a time and returns, for each
        "batch_id" in `ids`, the fraction of draws that returned it."""
        counts = {_id: 0 for _id in ids}
        for _ in range(num_samples):
            sample = buf.sample(1)
            # Every draw must contain exactly one sequence.
            assert len(sample[SampleBatch.SEQ_LENS]) == 1
            counts[sample["batch_id"][0]] += 1
        return np.array(list(counts.values())) / num_samples

    buffer = ReplayBuffer(capacity=10, storage_unit="sequences")
    batches = [
        SampleBatch({
            SampleBatch.T: i * [np.random.random((4, ))],
            SampleBatch.ACTIONS: i * [np.random.choice([0, 1])],
            SampleBatch.REWARDS: i * [np.random.rand()],
            SampleBatch.DONES: i * [np.random.choice([False, True])],
            SampleBatch.SEQ_LENS: [i],
            "batch_id": i * [i],
        }) for i in range(1, 4)
    ]
    batches.append(
        SampleBatch({
            SampleBatch.T: 4 * [np.random.random((4, ))],
            SampleBatch.ACTIONS: 4 * [np.random.choice([0, 1])],
            SampleBatch.REWARDS: 4 * [np.random.rand()],
            SampleBatch.DONES: 4 * [np.random.choice([False, True])],
            SampleBatch.SEQ_LENS: [2, 2],
            "batch_id": 4 * [4],
        }))
    for batch in batches:
        buffer.add(batch)

    # Out of the five stored sequences, two come from the last batch
    # (batch_id 4). Because they are stored separately, batch_id 4 should
    # be sampled twice as often as the others.
    assert np.allclose(
        sampled_fractions(buffer, range(1, 5)),
        [1 / 5, 1 / 5, 1 / 5, 2 / 5],
        atol=0.1,
    )

    # Add another batch to evict.
    buffer.add(
        SampleBatch({
            SampleBatch.T: 5 * [np.random.random((4, ))],
            SampleBatch.ACTIONS: 5 * [np.random.choice([0, 1])],
            SampleBatch.REWARDS: 5 * [np.random.rand()],
            SampleBatch.DONES: 5 * [np.random.choice([False, True])],
            SampleBatch.SEQ_LENS: [5],
            "batch_id": 5 * [5],
        }))
    # After adding 1 more batch, eviction has started with 15 timesteps
    # added in total.
    assert len(buffer) == 5
    assert buffer._num_timesteps_added == sum(range(1, 6))
    assert buffer._num_timesteps_added_wrap == 5
    assert buffer._next_idx == 1
    assert buffer._eviction_started is True

    # The first batch should no longer be sampled; the other batches should
    # be sampled as before.
    assert np.allclose(
        sampled_fractions(buffer, range(2, 6)),
        [1 / 5, 1 / 5, 2 / 5, 1 / 5],
        atol=0.1,
    )