def test_set_get_state(self):
    """Round-trip a populated buffer's state into a fresh buffer.

    A buffer is filled via `self._add_multi_agent_batch_to_buffer`, its
    state is captured with `get_state()`, and that state is restored into a
    second, identically configured buffer with `set_state()`. Each
    per-policy sub-buffer of the original must report the same state as the
    sub-buffer under the same id in the restored copy, and both buffers must
    agree on `_num_added`.
    """
    capacity = 15

    def make_buffer():
        # Both buffers share one configuration, so any mismatch found below
        # can only come from the get_state/set_state round trip.
        return MultiAgentReplayBuffer(
            capacity=capacity,
            replay_mode="independent",
            learning_starts=0,
            num_shards=1,
        )

    source = make_buffer()
    self._add_multi_agent_batch_to_buffer(source, num_policies=2, num_batches=1)
    snapshot = source.get_state()

    restored = make_buffer()
    restored.set_state(snapshot)

    # The overall state is the set of the per-policy sub-buffer states.
    for policy_id, sub_buffer in source.replay_buffers.items():
        expected = sub_buffer.get_state()
        actual = restored.replay_buffers[policy_id].get_state()
        assert expected == actual
    assert source._num_added == restored._num_added
def set_state(self, state: Dict[str, Any]) -> None:
    """Restores all local state to the provided `state`.

    Stores the `"last_added_batches"` entry on this buffer, then forwards the
    full `state` dict to `MultiAgentReplayBuffer.set_state` to restore the
    remaining state.

    Args:
        state: The new state to set this buffer. Can be obtained by
            calling `self.get_state()`.

    Raises:
        KeyError: If `state` has no `"last_added_batches"` entry.
    """
    self.last_added_batches = state["last_added_batches"]
    # Accessing the method through the class object yields a plain function,
    # so the instance must be passed explicitly. Omitting `self` here would
    # bind `state` to `self` and fail with a missing-argument TypeError.
    MultiAgentReplayBuffer.set_state(self, state)