def test_set_get_state(self):
    """Round-trip a MultiAgentReplayBuffer's state into a fresh buffer.

    Fills one buffer, copies its ``get_state()`` into a second, identically
    configured buffer via ``set_state()``, and checks that both buffers end
    up with sub-buffers for the same policy IDs, equal sub-buffer states,
    and the same ``_num_added`` counter.
    """
    num_policies = 2
    buffer_size = 15
    num_batches = 1

    # Both buffers must be configured identically for the state comparison
    # to be meaningful, so build them from one shared set of kwargs.
    buffer_kwargs = dict(
        capacity=buffer_size,
        replay_mode="independent",
        learning_starts=0,
        num_shards=1,
    )

    buffer = MultiAgentReplayBuffer(**buffer_kwargs)
    self._add_multi_agent_batch_to_buffer(
        buffer, num_policies=num_policies, num_batches=num_batches
    )
    state = buffer.get_state()

    another_buffer = MultiAgentReplayBuffer(**buffer_kwargs)
    another_buffer.set_state(state)

    # Both buffers must hold sub-buffers for exactly the same policy IDs.
    # Checked up front because the per-ID loop below only looks in one
    # direction, and indexing `replay_buffers[_id]` could silently create a
    # missing entry if it is a defaultdict-like mapping (not visible here).
    assert set(buffer.replay_buffers) == set(another_buffer.replay_buffers)

    # State is equal to set of states of underlying buffers
    for _id, _buffer in buffer.replay_buffers.items():
        assert _buffer.get_state() == another_buffer.replay_buffers[_id].get_state()
    assert buffer._num_added == another_buffer._num_added
def get_state(self) -> Dict[str, Any]:
    """Return this buffer's serializable local state.

    The result is the state dict produced by
    ``MultiAgentReplayBuffer.get_state`` for this instance, extended with a
    ``"last_added_batches"`` entry.

    Returns:
        The serializable local state.
    """
    # Read our own attribute first, before delegating to the base class.
    last_added = self.last_added_batches

    # NOTE(review): the base implementation is named explicitly instead of
    # being reached via super(); if an intermediate class in the MRO also
    # overrides get_state, its contribution is skipped -- confirm intended.
    state = MultiAgentReplayBuffer.get_state(self)

    # The attribute is stored by reference (no copy), and it overrides any
    # same-named key the base state may already contain.
    state.update(last_added_batches=last_added)
    return state