def test_independent_mode(self):
    """Test the independent mode by adding batches from multiple policies."""
    self.batch_id = 0
    num_batches = 3
    buffer_size = 15
    num_policies = 2

    # Test independent mode with different policy ids using
    # MultiAgentBatches.
    buffer = MultiAgentPrioritizedReplayBuffer(
        capacity=buffer_size,
        replay_mode="independent",
        learning_starts=0,
        num_shards=1,
    )

    self._add_multi_agent_batch_to_buffer(
        buffer, num_policies=num_policies, num_batches=num_batches
    )

    # Sampling with an explicit policy_id should yield a MultiAgentBatch
    # that contains experiences from only that policy.
    for _id in range(num_policies):
        for __id in buffer.sample(4, policy_id=_id).policy_batches[_id][
            "policy_id"
        ]:
            assert __id == _id

    # Sampling without specifying a policy should yield the same number
    # of batches from each policy.
    num_sampled_dict = {_id: 0 for _id in range(num_policies)}
    num_samples = 200
    for _ in range(num_samples):
        num_items = np.random.randint(1, 5)
        for _id, batch in buffer.sample(
            num_items=num_items
        ).policy_batches.items():
            num_sampled_dict[_id] += 1
            assert len(batch) == num_items
    assert np.allclose(
        np.array(list(num_sampled_dict.values())),
        len(num_sampled_dict) * [num_samples],
        atol=0.1,
    )
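
# A minimal sketch of what a helper like `_add_multi_agent_batch_to_buffer`
# might do, assuming `SampleBatch` is imported from
# ray.rllib.policy.sample_batch alongside `MultiAgentBatch`. The field names
# and values below are illustrative, not the test suite's actual helper.
def _demo_add_multi_agent_batch(buffer, num_policies, num_batches):
    for batch_id in range(num_batches):
        # One SampleBatch per policy, tagged with its policy id so tests
        # can verify which policy an experience came from.
        policy_batches = {
            policy_id: SampleBatch(
                {
                    "obs": np.zeros((1, 4)),
                    "policy_id": np.array([policy_id]),
                    "t": np.array([batch_id]),
                }
            )
            for policy_id in range(num_policies)
        }
        # The second argument is the number of env steps the batch covers.
        buffer.add(MultiAgentBatch(policy_batches, num_policies))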

def test_policy_id_of_multi_agent_batches_independent(self):
    """Test if independent sampling yields a MultiAgentBatch with the
    correct policy id."""
    self.batch_id = 0

    # Test independent mode with a single policy id using MultiAgentBatches.
    buffer = MultiAgentPrioritizedReplayBuffer(
        capacity=10, replay_mode="independent", learning_starts=0, num_shards=1
    )

    self._add_multi_agent_batch_to_buffer(buffer, num_policies=1, num_batches=1)

    mabatch = buffer.sample(1)
    assert list(mabatch.policy_batches.keys())[0] == 0
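
# For reference, a MultiAgentBatch keyed by policy id can be built and
# inspected directly. This is a minimal sketch using RLlib's public
# SampleBatch/MultiAgentBatch classes; the data is illustrative.
def _demo_policy_batch_keys():
    sb = SampleBatch({"obs": np.zeros((2, 4)), "rewards": np.ones(2)})
    mab = MultiAgentBatch({0: sb}, env_steps=2)
    # An independent-mode buffer returns sampled batches under the same
    # policy ids they were added with.
    assert list(mab.policy_batches.keys()) == [0]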

def test_update_priorities(self):
    num_batches = 5
    buffer_size = 15

    # The buffer must be in independent mode; lockstep mode does not
    # support prioritized replay.
    buffer = MultiAgentPrioritizedReplayBuffer(
        capacity=buffer_size,
        prioritized_replay_alpha=self.alpha,
        prioritized_replay_beta=self.beta,
        replay_mode="independent",
        replay_sequence_length=2,
        learning_starts=0,
        num_shards=1,
    )

    # Insert n samples.
    for i in range(num_batches):
        data = self._generate_data()
        buffer.add(data, weight=1.0)
        assert len(buffer) == i + 1

    # Fetch records, their indices, and weights.
    # With replay_sequence_length=2, sampling 3 items yields 6 timesteps.
    mabatch = buffer.sample(3)
    assert type(mabatch) == MultiAgentBatch
    samplebatch = mabatch.policy_batches[DEFAULT_POLICY_ID]
    weights = samplebatch["weights"]
    indices = samplebatch["batch_indexes"]
    check(weights, np.ones(shape=(6,)))
    assert 6 == len(indices)
    assert len(buffer) == num_batches
    policy_buffer = buffer.replay_buffers[DEFAULT_POLICY_ID]
    assert policy_buffer._next_idx == num_batches

    # Update the priorities of indices 0, 2, 3 and 4, as in our
    # PrioritizedReplayBuffer tests.
    priority_dict = {
        DEFAULT_POLICY_ID: (
            np.array([0, 2, 3, 4]),
            np.array([0.01, 0.01, 0.01, 0.01]),
        )
    }
    buffer.update_priorities(priority_dict)

    # Expect to sample almost only index 1 (which still has a weight of
    # 1.0). With sequences of length 2, 1000 sampled items contain 2000
    # timesteps, so the index sum should land close to 2000.
    for _ in range(10):
        mabatch = buffer.sample(1000)
        assert type(mabatch) == MultiAgentBatch
        samplebatch = mabatch.policy_batches[DEFAULT_POLICY_ID]
        indices = samplebatch["batch_indexes"]
        self.assertTrue(1900 < np.sum(indices) < 2200)

    # Test get_state/set_state.
    state = buffer.get_state()
    new_buffer = MultiAgentPrioritizedReplayBuffer(
        capacity=buffer_size,
        prioritized_replay_alpha=self.alpha,
        prioritized_replay_beta=self.beta,
        replay_mode="independent",
        learning_starts=0,
        num_shards=1,
    )
    new_buffer.set_state(state)
    batch = new_buffer.sample(1000).policy_batches[DEFAULT_POLICY_ID]
    indices = batch["batch_indexes"]
    self.assertTrue(1900 < np.sum(indices) < 2200)
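
# A small self-contained sketch of the prioritized-replay math the test
# above relies on (Schaul et al., 2016): sampling probabilities follow
# p_i ** alpha, and importance-sampling weights are (N * P(i)) ** -beta,
# normalized by their maximum. The alpha/beta values here are illustrative
# defaults, not the fixture's self.alpha/self.beta.
def _demo_priority_math(alpha=0.6, beta=0.4):
    # After the priority update above, four of five priorities are tiny.
    priorities = np.array([0.01, 1.0, 0.01, 0.01, 0.01])
    probs = priorities ** alpha / np.sum(priorities ** alpha)
    # Index 1 dominates the sampling distribution ...
    assert probs.argmax() == 1
    # ... and gets the smallest importance weight, since it is
    # over-sampled relative to the uniform distribution.
    weights = (len(priorities) * probs) ** -beta
    weights /= weights.max()
    assert weights[1] == weights.min()
    return probs, weights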