コード例 #1
0
    def get_state(self) -> Dict[str, Any]:
        """Collect the complete local state of this buffer.

        The state produced by ``MultiAgentPrioritizedReplayBuffer.get_state``
        is extended with this buffer's ``last_added_batches`` entry.

        Returns:
            The serializable local state.
        """
        # Read the attribute before delegating, matching the original
        # evaluation order.
        recent_batches = self.last_added_batches
        state = MultiAgentPrioritizedReplayBuffer.get_state(self)
        state["last_added_batches"] = recent_batches
        return state
コード例 #2
0
    def test_update_priorities(self):
        """Check that priority updates skew sampling and survive a state round-trip.

        Steps:
            1. Fill a buffer with ``num_batches`` items of weight 1.0.
            2. Sample once and verify weights, indices and buffer bookkeeping.
            3. Lower the priority of indices 0, 2, 3 and 4 so that index 1
               dominates subsequent sampling.
            4. Copy the state into a fresh buffer via ``get_state`` /
               ``set_state`` and verify the same sampling skew.
        """
        num_batches = 5
        buffer_size = 15

        # Buffer needs to be in independent mode, lockstep is not supported
        buffer = MultiAgentPrioritizedReplayBuffer(
            capacity=buffer_size,
            prioritized_replay_alpha=self.alpha,
            prioritized_replay_beta=self.beta,
            replay_mode="independent",
            replay_sequence_length=2,
            learning_starts=0,
            num_shards=1,
        )

        # Insert n samples
        for i in range(num_batches):
            data = self._generate_data()
            buffer.add(data, weight=1.0)
            self.assertEqual(len(buffer), i + 1)

        # Fetch records, their indices and weights.
        mabatch = buffer.sample(3)
        self.assertIsInstance(mabatch, MultiAgentBatch)
        samplebatch = mabatch.policy_batches[DEFAULT_POLICY_ID]

        weights = samplebatch["weights"]
        indices = samplebatch["batch_indexes"]
        # Sampling 3 items is expected to yield 6 rows here.
        check(weights, np.ones(shape=(6,)))
        self.assertEqual(len(indices), 6)
        self.assertEqual(len(buffer), num_batches)
        policy_buffer = buffer.replay_buffers[DEFAULT_POLICY_ID]
        self.assertEqual(policy_buffer._next_idx, num_batches)

        # Update weight of indices 0, 2, 3, 4, like in our
        # PrioritizedReplayBuffer tests
        priority_dict = {
            DEFAULT_POLICY_ID: (
                np.array([0, 2, 3, 4]),
                np.array([0.01, 0.01, 0.01, 0.01]),
            )
        }

        buffer.update_priorities(priority_dict)

        # Expect to sample almost only index 1
        # (which still has a weight of 1.0).
        # Presumably 1000 sampled items give ~2000 index rows (two per item,
        # as in the 3 -> 6 check above), so the index sum should be near 2000.
        for _ in range(10):
            mabatch = buffer.sample(1000)
            self.assertIsInstance(mabatch, MultiAgentBatch)
            samplebatch = mabatch.policy_batches[DEFAULT_POLICY_ID]
            indices = samplebatch["batch_indexes"]
            self.assertTrue(1900 < np.sum(indices) < 2200)

        # Test get_state/set_state.
        state = buffer.get_state()
        new_buffer = MultiAgentPrioritizedReplayBuffer(
            capacity=buffer_size,
            prioritized_replay_alpha=self.alpha,
            prioritized_replay_beta=self.beta,
            replay_mode="independent",
            learning_starts=0,
            num_shards=1,
        )
        new_buffer.set_state(state)
        batch = new_buffer.sample(1000).policy_batches[DEFAULT_POLICY_ID]
        indices = batch["batch_indexes"]
        self.assertTrue(1900 < np.sum(indices) < 2200)