Example #1
0
    def add(self, batch: SampleBatchType, **kwargs) -> None:
        """Adds a batch of experiences using reservoir sampling.

        While fewer than ``capacity`` timesteps have been added, the batch
        is handed to ``ReplayBuffer.add``. After that the buffer is
        considered full: the n-th added batch (n = number of ``add`` calls
        so far) overwrites a uniformly chosen stored batch with probability
        ``len(self._storage) / n`` and is discarded otherwise.

        Args:
            batch: SampleBatch to add to this buffer's storage.
            **kwargs: Not used by this implementation.
        """
        # Update add counts.
        self._num_add_calls += 1
        # Update our timesteps counts.
        # NOTE(review): if ReplayBuffer.add also updates these counters, they
        # are double-counted on the non-full path below — confirm.
        self._num_timesteps_added += batch.count
        self._num_timesteps_added_wrap += batch.count

        if self._num_timesteps_added < self.capacity:
            ReplayBuffer.add(self, batch)
            return

        # Buffer is "full": from now on, incoming batches replace stored
        # batches at uniformly random slots (not necessarily the oldest).
        self._eviction_started = True
        idx = random.randint(0, self._num_add_calls - 1)
        # `capacity` counts timesteps, while `_storage` holds whole batches
        # and may therefore contain far fewer than `capacity` entries. Only
        # indices below len(self._storage) refer to a stored batch.
        if idx >= len(self._storage):
            # Reservoir sampling decided not to keep this batch.
            return

        # Validate before mutating storage / hit statistics.
        assert batch.count > 0, batch
        self._num_evicted += 1
        self._evicted_hit_stats.push(self._hit_count[idx])
        self._hit_count[idx] = 0
        self._storage[idx] = batch

        warn_replay_capacity(item=batch,
                             num_items=self.capacity / batch.count)
Example #2
0
    def test_episodes_unit(self):
        """Tests adding, sampling, and eviction of episodes.

        The buffer stores each complete episode as a separate item, so
        sampling one item at a time should return every stored episode
        equally often, regardless of which batch it arrived in.
        """

        def sample_distribution(buffer, num_ids, num_samples=1000):
            """Samples one episode at a time and returns per-ID frequencies.

            Frequencies are ordered by episode ID ``0 .. num_ids - 1``. An
            episode ID outside that range raises a KeyError and thus fails
            the test. 1000 samples (instead of 200) make the atol=0.1
            checks below much less likely to fail by chance.
            """
            counts = {_id: 0 for _id in range(num_ids)}
            for _ in range(num_samples):
                sample = buffer.sample(1)
                # Every sample must consist of exactly one episode.
                assert len(sample[SampleBatch.SEQ_LENS]) == 1
                counts[sample[SampleBatch.EPS_ID][0]] += 1
            return np.array(list(counts.values())) / num_samples

        buffer = ReplayBuffer(capacity=18, storage_unit="episodes")

        # Three batches, each holding one complete 4-step episode (IDs 0-2).
        batches = [
            SampleBatch({
                SampleBatch.T: [0, 1, 2, 3],
                SampleBatch.ACTIONS: 4 * [np.random.choice([0, 1])],
                SampleBatch.REWARDS: 4 * [np.random.rand()],
                SampleBatch.DONES: [False, False, False, True],
                SampleBatch.SEQ_LENS: [4],
                SampleBatch.EPS_ID: 4 * [i],
            }) for i in range(3)
        ]

        # One batch holding two complete 2-step episodes (IDs 3 and 4).
        batches.append(
            SampleBatch({
                SampleBatch.T: [0, 1, 0, 1],
                SampleBatch.ACTIONS: 4 * [np.random.choice([0, 1])],
                SampleBatch.REWARDS: 4 * [np.random.rand()],
                SampleBatch.DONES: [False, True, False, True],
                SampleBatch.SEQ_LENS: [2, 2],
                SampleBatch.EPS_ID: [3, 3, 4, 4],
            }))

        for batch in batches:
            buffer.add(batch)

        # All episodes, even though in different batches, should be sampled
        # equally often.
        assert np.allclose(
            sample_distribution(buffer, 5),
            [1 / 5, 1 / 5, 1 / 5, 1 / 5, 1 / 5],
            atol=0.1,
        )

        # Episode 6 is not entirely inside this batch (none of its steps is
        # done), so it should not be added to the buffer; episode 5 is.
        buffer.add(
            SampleBatch({
                SampleBatch.T: [0, 1, 0, 1],
                SampleBatch.ACTIONS: 4 * [np.random.choice([0, 1])],
                SampleBatch.REWARDS: 4 * [np.random.rand()],
                SampleBatch.DONES: [False, True, False, False],
                SampleBatch.SEQ_LENS: [2, 2],
                SampleBatch.EPS_ID: [5, 5, 6, 6],
            }))

        # Episode 6 should be dropped for not ending inside the batch.
        assert np.allclose(
            sample_distribution(buffer, 7),
            [1 / 6, 1 / 6, 1 / 6, 1 / 6, 1 / 6, 1 / 6, 0],
            atol=0.1,
        )

        # Add another batch to evict the first batch.
        buffer.add(
            SampleBatch({
                SampleBatch.T: [0, 1, 2, 3],
                SampleBatch.ACTIONS: 4 * [np.random.choice([0, 1])],
                SampleBatch.REWARDS: 4 * [np.random.rand()],
                SampleBatch.DONES: [False, False, False, True],
                SampleBatch.SEQ_LENS: [4],
                SampleBatch.EPS_ID: 4 * [7],
            }))

        # After adding 1 more batch, eviction has started: 6 batches of 4
        # timesteps (24 in total) were offered, 2 of which (episode 6) were
        # discarded.
        assert len(buffer) == 6
        assert buffer._num_timesteps_added == 4 * 6 - 2
        assert buffer._num_timesteps_added_wrap == 4
        assert buffer._next_idx == 1
        assert buffer._eviction_started is True

        # Episode 0 has been evicted and episode 6 was never stored; all
        # other episodes (including the new episode 7) are equally likely.
        assert np.allclose(
            sample_distribution(buffer, 8),
            [0, 1 / 6, 1 / 6, 1 / 6, 1 / 6, 1 / 6, 0, 1 / 6],
            atol=0.1,
        )
Example #3
0
    def test_sequences_unit(self):
        """Tests adding, sampling and eviction of sequences.

        Each sequence (as delimited by ``SEQ_LENS``) is stored as a separate
        item, so a batch holding two sequences should be sampled twice as
        often as a batch holding one.
        """

        def sample_distribution(buffer, batch_ids, num_samples=1000):
            """Samples one sequence at a time; returns per-batch frequencies.

            Frequencies follow the order of ``batch_ids``. A ``batch_id``
            not listed there raises a KeyError and thus fails the test.
            1000 samples (instead of 200) make the atol=0.1 checks below
            much less likely to fail by chance.
            """
            counts = {_id: 0 for _id in batch_ids}
            for _ in range(num_samples):
                sample = buffer.sample(1)
                # Every sample must consist of exactly one sequence.
                assert len(sample[SampleBatch.SEQ_LENS]) == 1
                counts[sample["batch_id"][0]] += 1
            return np.array(list(counts.values())) / num_samples

        buffer = ReplayBuffer(capacity=10, storage_unit="sequences")

        # Batches 1-3 each hold a single sequence of length 1, 2 and 3.
        batches = [
            SampleBatch({
                SampleBatch.T:
                i * [np.random.random((4, ))],
                SampleBatch.ACTIONS:
                i * [np.random.choice([0, 1])],
                SampleBatch.REWARDS:
                i * [np.random.rand()],
                SampleBatch.DONES:
                i * [np.random.choice([False, True])],
                SampleBatch.SEQ_LENS: [i],
                "batch_id":
                i * [i],
            }) for i in range(1, 4)
        ]

        # Batch 4 holds two sequences of length 2 each.
        batches.append(
            SampleBatch({
                SampleBatch.T:
                4 * [np.random.random((4, ))],
                SampleBatch.ACTIONS:
                4 * [np.random.choice([0, 1])],
                SampleBatch.REWARDS:
                4 * [np.random.rand()],
                SampleBatch.DONES:
                4 * [np.random.choice([False, True])],
                SampleBatch.SEQ_LENS: [2, 2],
                "batch_id":
                4 * [4],
            }))

        for batch in batches:
            buffer.add(batch)

        # Five sequences are stored. The two sequences from the last batch
        # are stored separately, so that batch should be sampled twice as
        # often as each of the others.
        assert np.allclose(
            sample_distribution(buffer, range(1, 5)),
            [1 / 5, 1 / 5, 1 / 5, 2 / 5],
            atol=0.1,
        )

        # Add another batch to evict.
        buffer.add(
            SampleBatch({
                SampleBatch.T:
                5 * [np.random.random((4, ))],
                SampleBatch.ACTIONS:
                5 * [np.random.choice([0, 1])],
                SampleBatch.REWARDS:
                5 * [np.random.rand()],
                SampleBatch.DONES:
                5 * [np.random.choice([False, True])],
                SampleBatch.SEQ_LENS: [5],
                "batch_id":
                5 * [5],
            }))

        # After adding 1 more batch, eviction has started with
        # 1 + 2 + 3 + 4 + 5 = 15 timesteps added in total.
        assert len(buffer) == 5
        assert buffer._num_timesteps_added == sum(range(1, 6))
        assert buffer._num_timesteps_added_wrap == 5
        assert buffer._next_idx == 1
        assert buffer._eviction_started is True

        # The first batch should now not be sampled anymore; the other
        # batches should be sampled as before.
        assert np.allclose(
            sample_distribution(buffer, range(2, 6)),
            [1 / 5, 1 / 5, 2 / 5, 1 / 5],
            atol=0.1,
        )