def set_state(self, state: Dict[str, Any]) -> None:
        """Restores all local state to the provided `state`.

        Restores `last_added_batches` from `state` and then delegates to
        `MultiAgentPrioritizedReplayBuffer.set_state` with the same `state`.

        Args:
            state: The new state to set this buffer. Can be obtained by
                calling `self.get_state()`. Must contain the key
                `"last_added_batches"`.

        Raises:
            KeyError: If `state` has no `"last_added_batches"` entry.
        """
        self.last_added_batches = state["last_added_batches"]
        # The method is accessed through the class, so the instance has to be
        # passed explicitly. Without it, `state` would be bound to `self` and
        # the call would fail with a missing-argument TypeError.
        MultiAgentPrioritizedReplayBuffer.set_state(self, state)
Example #2
0
    def test_update_priorities(self):
        """Tests `update_priorities` and the get_state/set_state round trip.

        Fills a `MultiAgentPrioritizedReplayBuffer` with five batches, lowers
        the priority of indices 0, 2, 3 and 4, and verifies that sampling
        afterwards returns almost exclusively index 1. The same sampling
        check is then repeated on a second buffer restored from the first
        buffer's state.
        """
        num_batches = 5
        buffer_size = 15

        # Buffer needs to be in independent mode, lockstep is not supported
        buffer = MultiAgentPrioritizedReplayBuffer(
            capacity=buffer_size,
            prioritized_replay_alpha=self.alpha,
            prioritized_replay_beta=self.beta,
            replay_mode="independent",
            replay_sequence_length=2,
            learning_starts=0,
            num_shards=1,
        )

        # Insert n samples; the buffer length must grow by one per insert.
        for i in range(num_batches):
            data = self._generate_data()
            buffer.add(data, weight=1.0)
            self.assertEqual(len(buffer), i + 1)

        # Fetch records, their indices and weights.
        mabatch = buffer.sample(3)
        self.assertIsInstance(mabatch, MultiAgentBatch)
        samplebatch = mabatch.policy_batches[DEFAULT_POLICY_ID]

        weights = samplebatch["weights"]
        indices = samplebatch["batch_indexes"]
        # 3 sampled items yield 6 entries (presumably because each item is a
        # sequence of length 2, per `replay_sequence_length=2` above).
        check(weights, np.ones(shape=(6,)))
        self.assertEqual(len(indices), 6)
        # Sampling must not change the number of stored items.
        self.assertEqual(len(buffer), num_batches)
        policy_buffer = buffer.replay_buffers[DEFAULT_POLICY_ID]
        self.assertEqual(policy_buffer._next_idx, num_batches)

        # Update weight of indices 0, 2, 3, 4, like in our
        # PrioritizedReplayBuffer tests
        priority_dict = {
            DEFAULT_POLICY_ID: (
                np.array([0, 2, 3, 4]),
                np.array([0.01, 0.01, 0.01, 0.01]),
            )
        }

        buffer.update_priorities(priority_dict)

        # Expect to sample almost only index 1
        # (which still has a weight of 1.0).
        for _ in range(10):
            mabatch = buffer.sample(1000)
            self.assertIsInstance(mabatch, MultiAgentBatch)
            samplebatch = mabatch.policy_batches[DEFAULT_POLICY_ID]
            indices = samplebatch["batch_indexes"]
            self.assertTrue(1900 < np.sum(indices) < 2200)

        # Test get_state/set_state.
        state = buffer.get_state()
        # NOTE(review): unlike `buffer`, this one is created without
        # `replay_sequence_length=2` -- confirm that this is intended.
        new_buffer = MultiAgentPrioritizedReplayBuffer(
            capacity=buffer_size,
            prioritized_replay_alpha=self.alpha,
            prioritized_replay_beta=self.beta,
            replay_mode="independent",
            learning_starts=0,
            num_shards=1,
        )
        new_buffer.set_state(state)
        # The restored buffer must show the same sampling distribution.
        batch = new_buffer.sample(1000).policy_batches[DEFAULT_POLICY_ID]
        indices = batch["batch_indexes"]
        self.assertTrue(1900 < np.sum(indices) < 2200)