def test_mixin_sampling_sequences(self):
    """Test sampling of sequences.

    Repeatedly adds the same generated episode batch to a mix-in buffer
    configured with ``storage_unit="sequences"`` and a 50% replay ratio,
    samples after every add, and checks that the mean size of the returned
    default-policy batch is close to ``2 * len(batch)``.
    """
    # 50% replay ratio.
    buffer = MultiAgentMixInReplayBuffer(
        capacity=100,
        storage_unit="sequences",
        replay_ratio=0.5,
    )

    # If we insert and replay n times, expect roughly return batches of
    # len 6 (replay_ratio=0.5 -> 50% replayed samples -> 2 new and 2
    # old sequences with an average length of 1.5 each).
    # NOTE(review): the "len 6" figure presumes len(batch) == 3; the
    # assertion below is expressed relative to len(batch) - confirm against
    # `_generate_episodes`.
    results = []
    batch = self._generate_episodes()
    for _ in range(400):
        buffer.add(batch)
        sample = buffer.sample(10)
        # Use the unittest assertion so the check survives `python -O` and
        # reports the offending object on failure.
        self.assertIsInstance(sample, MultiAgentBatch)
        results.append(len(sample.policy_batches[DEFAULT_POLICY_ID]))

    # Sampling is stochastic, so only the mean over many draws is compared.
    self.assertAlmostEqual(np.mean(results), 2 * len(batch), delta=0.1)
def test_mixin_sampling_episodes(self):
    """Test sampling of episodes.

    Repeatedly adds the same generated episode batch to a mix-in buffer
    configured with ``storage_unit="episodes"`` and a 50% replay ratio,
    samples after every add, and checks that the mean size of the returned
    default-policy batch equals ``2 * (len(batch) - 1)``.
    """
    # 50% replay ratio.
    buffer = MultiAgentMixInReplayBuffer(
        capacity=self.capacity,
        storage_unit="episodes",
        replay_ratio=0.5,
    )

    # If we insert and replay n times, expect each returned batch to hold
    # 1 new and 1 old sample (replay_ratio=0.5 -> 50% replayed samples),
    # each of length two on average.
    results = []
    batch = self._generate_episodes()
    for _ in range(20):
        buffer.add(batch)
        sample = buffer.sample(2)
        # Use the unittest assertion so the check survives `python -O` and
        # reports the offending object on failure.
        self.assertIsInstance(sample, MultiAgentBatch)
        results.append(len(sample.policy_batches[DEFAULT_POLICY_ID]))

    # One timestep in the generated batch does not belong to the episode
    # and thus gets dropped. Full episodes are of length two.
    self.assertAlmostEqual(np.mean(results), 2 * (len(batch) - 1))
def test_mixin_sampling_timesteps(self):
    """Test different mixin ratios with timesteps.

    Exercises the mix-in buffer with replay ratios of 0.333, 0.9, 0.0 and
    1.0. For each ratio it checks the size of the returned default-policy
    batch, either exactly for single draws or on average over 100 draws.
    """
    # 33% replay ratio.
    buffer = MultiAgentMixInReplayBuffer(
        capacity=self.capacity,
        storage_unit="timesteps",
        replay_ratio=0.333,
    )
    # Expect exactly 0 samples to be returned (buffer empty).
    sample = buffer.sample(10)
    self.assertEqual(len(sample.policy_batches), 0)

    batch = self._generate_single_timesteps()

    # If we insert-2x and replay n times, expect roughly return batches of
    # len 3 (replay_ratio=0.33 -> 33% replayed samples -> 2 new and 1
    # old sample on average in each returned value).
    results = []
    for _ in range(100):
        buffer.add(batch)
        buffer.add(batch)
        sample = buffer.sample(3)
        self.assertIsInstance(sample, MultiAgentBatch)
        results.append(len(sample.policy_batches[DEFAULT_POLICY_ID]))
    self.assertAlmostEqual(np.mean(results), 3.0, delta=0.2)

    # If we insert-1x and replay n times, expect roughly return batches of
    # len 1.5 (replay_ratio=0.33 -> 33% replayed samples -> 1 new and 0.5
    # old samples on average in each returned value).
    results = []
    for _ in range(100):
        buffer.add(batch)
        sample = buffer.sample(5)
        self.assertIsInstance(sample, MultiAgentBatch)
        results.append(len(sample.policy_batches[DEFAULT_POLICY_ID]))
    self.assertAlmostEqual(np.mean(results), 1.5, delta=0.2)

    # 90% replay ratio.
    buffer = MultiAgentMixInReplayBuffer(
        capacity=self.capacity,
        replay_ratio=0.9,
    )

    # If we insert and replay n times, expect roughly return batches of
    # len 10 (replay_ratio=0.9 -> 90% replayed samples -> 1 new and 9 old
    # samples on average in each returned value).
    results = []
    for _ in range(100):
        buffer.add(batch)
        sample = buffer.sample(10)
        self.assertIsInstance(sample, MultiAgentBatch)
        results.append(len(sample.policy_batches[DEFAULT_POLICY_ID]))
    self.assertAlmostEqual(np.mean(results), 10.0, delta=0.2)

    # 0% replay ratio -> Only new samples.
    buffer = MultiAgentMixInReplayBuffer(
        capacity=self.capacity,
        replay_ratio=0.0,
    )
    # Add a new batch.
    batch = self._generate_single_timesteps()
    buffer.add(batch)
    # Expect exactly 1 batch to be returned.
    sample = buffer.sample(1)
    self.assertIsInstance(sample, MultiAgentBatch)
    # assertEqual (rather than assertTrue on a comparison) reports the
    # actual length when the check fails.
    self.assertEqual(len(sample), 1)
    # Expect exactly 0 samples to be returned (nothing new to be returned;
    # no replay allowed (replay_ratio=0.0)).
    sample = buffer.sample(1)
    self.assertIsInstance(sample, MultiAgentBatch)
    self.assertEqual(len(sample.policy_batches), 0)

    # If we insert and replay n times, expect roughly return batches of
    # len 1 (replay_ratio=0.0 -> 0% replayed samples -> 1 new and 0 old
    # samples on average in each returned value).
    results = []
    for _ in range(100):
        buffer.add(batch)
        sample = buffer.sample(1)
        self.assertIsInstance(sample, MultiAgentBatch)
        results.append(len(sample.policy_batches[DEFAULT_POLICY_ID]))
    self.assertAlmostEqual(np.mean(results), 1.0, delta=0.2)

    # 100% replay ratio -> Only replayed (old) samples.
    buffer = MultiAgentMixInReplayBuffer(
        capacity=self.capacity,
        replay_ratio=1.0,
    )
    # Expect exactly 0 samples to be returned (buffer empty).
    sample = buffer.sample(1)
    self.assertEqual(len(sample.policy_batches), 0)
    # Add a new batch.
    batch = self._generate_single_timesteps()
    buffer.add(batch)
    # Expect exactly 1 sample to be returned (the new batch).
    sample = buffer.sample(1)
    self.assertIsInstance(sample, MultiAgentBatch)
    self.assertEqual(len(sample), 1)
    # Another replay -> Expect exactly 1 sample to be returned.
    sample = buffer.sample(1)
    self.assertIsInstance(sample, MultiAgentBatch)
    self.assertEqual(len(sample), 1)

    # If we replay n times, expect roughly return batches of
    # len 1 (replay_ratio=1.0 -> 100% replayed samples -> 0 new and 1 old
    # samples on average in each returned value).
    results = []
    for _ in range(100):
        sample = buffer.sample(1)
        self.assertIsInstance(sample, MultiAgentBatch)
        results.append(len(sample.policy_batches[DEFAULT_POLICY_ID]))
    self.assertAlmostEqual(np.mean(results), 1.0)