Example 1
    def test_set_get_state(self):
        """Round-trip a MultiAgentReplayBuffer's state via get_state/set_state.

        Fills one buffer, restores its state into a freshly constructed buffer
        with identical settings, and checks that both buffers expose the same
        set of per-policy sub-buffers, that each sub-buffer's state matches,
        and that the ``_num_added`` counters agree.
        """
        num_policies = 2
        buffer_size = 15
        num_batches = 1

        def _make_buffer():
            # Both buffers use identical constructor arguments, so any equality
            # between them after set_state() must come from the restored state.
            return MultiAgentReplayBuffer(
                capacity=buffer_size,
                replay_mode="independent",
                learning_starts=0,
                num_shards=1,
            )

        buffer = _make_buffer()
        self._add_multi_agent_batch_to_buffer(
            buffer, num_policies=num_policies, num_batches=num_batches
        )

        state = buffer.get_state()

        another_buffer = _make_buffer()
        another_buffer.set_state(state)

        # Compare the policy-id sets before indexing. The per-id loop below only
        # walks the original buffer's ids, so without this check it would miss
        # sub-buffers that exist only in the restored buffer. (If replay_buffers
        # is a defaultdict -- TODO confirm -- indexing a missing id would also
        # silently create it.)
        assert set(buffer.replay_buffers) == set(another_buffer.replay_buffers)

        # The restored state must equal the state of each underlying buffer.
        for policy_id, sub_buffer in buffer.replay_buffers.items():
            assert (
                sub_buffer.get_state()
                == another_buffer.replay_buffers[policy_id].get_state()
            )

        assert buffer._num_added == another_buffer._num_added
Example 2
    def get_state(self) -> Dict[str, Any]:
        """Serialize this buffer's local state.

        Takes the state produced by ``MultiAgentReplayBuffer.get_state`` and
        adds this instance's ``last_added_batches`` attribute to it.

        Returns:
            The dict returned by the parent implementation, extended (in place)
            with a ``"last_added_batches"`` entry.
        """
        # Read the attribute before the parent call to keep the original
        # evaluation order.
        last_added = self.last_added_batches
        # The parent is named explicitly rather than via super(), so this always
        # uses MultiAgentReplayBuffer's implementation regardless of the MRO.
        state = MultiAgentReplayBuffer.get_state(self)
        state["last_added_batches"] = last_added
        return state