def get_state(self) -> Dict[str, Any]:
    """Collect this buffer's local state.

    The state reported by ``MultiAgentPrioritizedReplayBuffer.get_state``
    is used as the base and is extended with this buffer's
    ``last_added_batches`` attribute.

    Returns:
        The parent state dict with an added ``"last_added_batches"`` entry.
    """
    # Explicitly call the named base implementation (not ``super()``) so the
    # same parent state is used regardless of the resolution order.
    state = MultiAgentPrioritizedReplayBuffer.get_state(self)
    state["last_added_batches"] = self.last_added_batches
    return state
def test_update_priorities(self):
    """Check priority updates and get_state/set_state round-tripping.

    The test fills a ``MultiAgentPrioritizedReplayBuffer`` (independent
    replay mode) with ``num_batches`` batches of weight 1.0, lowers the
    priority of indices 0, 2, 3 and 4, and then verifies that sampling is
    dominated by the remaining index 1. Finally the state is copied into a
    fresh buffer via ``get_state``/``set_state`` and the same sampling
    expectation is checked there.
    """
    num_batches = 5
    buffer_size = 15

    # Buffer needs to be in independent mode, lockstep is not supported
    buffer = MultiAgentPrioritizedReplayBuffer(
        capacity=buffer_size,
        prioritized_replay_alpha=self.alpha,
        prioritized_replay_beta=self.beta,
        replay_mode="independent",
        replay_sequence_length=2,
        learning_starts=0,
        num_shards=1,
    )

    # Insert n samples
    for i in range(num_batches):
        data = self._generate_data()
        buffer.add(data, weight=1.0)
        assert len(buffer) == i + 1

    # Fetch records, their indices and weights.
    mabatch = buffer.sample(3)
    assert isinstance(mabatch, MultiAgentBatch)
    samplebatch = mabatch.policy_batches[DEFAULT_POLICY_ID]

    weights = samplebatch["weights"]
    indices = samplebatch["batch_indexes"]

    # NOTE(review): 6 entries for 3 sampled items presumably stems from
    # replay_sequence_length=2 -- confirm against the buffer implementation.
    check(weights, np.ones(shape=(6,)))
    assert len(indices) == 6
    assert len(buffer) == num_batches

    policy_buffer = buffer.replay_buffers[DEFAULT_POLICY_ID]
    assert policy_buffer._next_idx == num_batches

    # Update weight of indices 0, 2, 3, 4, like in our
    # PrioritizedReplayBuffer tests
    priority_dict = {
        DEFAULT_POLICY_ID: (
            np.array([0, 2, 3, 4]),
            np.array([0.01, 0.01, 0.01, 0.01]),
        )
    }
    buffer.update_priorities(priority_dict)

    # Expect to sample almost only index 1
    # (which still has a weight of 1.0).
    for _ in range(10):
        mabatch = buffer.sample(1000)
        assert isinstance(mabatch, MultiAgentBatch)
        samplebatch = mabatch.policy_batches[DEFAULT_POLICY_ID]
        indices = samplebatch["batch_indexes"]
        # The index sum must stay close to 2000, i.e. the sample is
        # dominated by index 1; the bounds leave room for occasional draws
        # of the down-weighted indices.
        self.assertTrue(1900 < np.sum(indices) < 2200)

    # Test get_state/set_state.
    state = buffer.get_state()

    new_buffer = MultiAgentPrioritizedReplayBuffer(
        capacity=buffer_size,
        prioritized_replay_alpha=self.alpha,
        prioritized_replay_beta=self.beta,
        replay_mode="independent",
        learning_starts=0,
        num_shards=1,
    )
    new_buffer.set_state(state)

    # The restored buffer must show the same sampling distribution.
    batch = new_buffer.sample(1000).policy_batches[DEFAULT_POLICY_ID]
    indices = batch["batch_indexes"]
    self.assertTrue(1900 < np.sum(indices) < 2200)