Ejemplo n.º 1
0
    def test_mixin_sampling_sequences(self):
        """Test sampling of sequences with a 50% replay ratio.

        Repeatedly adds the same batch to the buffer and samples from it, then
        checks that the mean length of the returned default-policy batches is
        close to twice the length of the added batch.
        """
        # 50% replay ratio.
        buffer = MultiAgentMixInReplayBuffer(
            capacity=100,
            storage_unit="sequences",
            replay_ratio=0.5,
            learning_starts=0,
        )

        # If we insert and replay n times, expect roughly return batches of
        # len 6 (replay_ratio=0.5 -> 50% replayed samples -> 2 new and 2
        # old sequences with an average length of 1.5 each).
        results = []
        batch = self._generate_episodes()
        for _ in range(400):
            buffer.add(batch)
            sample = buffer.sample(10)
            # Use a unittest assertion rather than a bare `assert`, which is
            # stripped when Python runs with -O.
            self.assertIsInstance(sample, MultiAgentBatch)
            results.append(len(sample.policy_batches[DEFAULT_POLICY_ID]))
        self.assertAlmostEqual(np.mean(results), 2 * len(batch), delta=0.1)
    def test_mixin_sampling_episodes(self):
        """Test sampling of episodes with a 50% replay ratio.

        Repeatedly adds the same batch to the buffer and samples two items,
        then checks the mean length of the returned default-policy batches
        against ``2 * (len(batch) - 1)``.
        """
        # 50% replay ratio.
        buffer = MultiAgentMixInReplayBuffer(
            capacity=self.capacity,
            storage_unit="episodes",
            replay_ratio=0.5,
        )

        # If we insert and replay n times, expect return batches of
        # len 2 * (len(batch) - 1) (replay_ratio=0.5 -> 50% replayed samples
        # -> 1 new and 1 old episode, each of length two, in each returned
        # value).
        results = []
        batch = self._generate_episodes()
        for _ in range(20):
            buffer.add(batch)
            sample = buffer.sample(2)
            # Use a unittest assertion rather than a bare `assert`, which is
            # stripped when Python runs with -O.
            self.assertIsInstance(sample, MultiAgentBatch)
            results.append(len(sample.policy_batches[DEFAULT_POLICY_ID]))
        # One timestep in the batch does not belong to a full episode and thus
        # gets dropped. Full episodes are of length two.
        self.assertAlmostEqual(np.mean(results), 2 * (len(batch) - 1))
    def test_mixin_sampling_timesteps(self):
        """Test different mixin ratios with timesteps.

        Exercises buffers with replay ratios of 0.333, 0.9, 0.0 and 1.0. For
        each one, checks the type of the sampled object and the (mean) number
        of timesteps returned for the default policy after adding single
        timestep batches.
        """
        # 33% replay ratio.
        buffer = MultiAgentMixInReplayBuffer(
            capacity=self.capacity,
            storage_unit="timesteps",
            replay_ratio=0.333,
        )
        # Expect exactly 0 samples to be returned (buffer empty).
        sample = buffer.sample(10)
        self.assertEqual(len(sample.policy_batches), 0)

        batch = self._generate_single_timesteps()
        # If we insert-2x and replay n times, expect roughly return batches of
        # len 3 (replay_ratio=0.33 -> 33% replayed samples -> 2 new and 1
        # old sample on average in each returned value).
        results = []
        for _ in range(100):
            buffer.add(batch)
            buffer.add(batch)
            sample = buffer.sample(3)
            # unittest assertions are used throughout instead of bare `assert`
            # statements, which are stripped when Python runs with -O.
            self.assertIsInstance(sample, MultiAgentBatch)
            results.append(len(sample.policy_batches[DEFAULT_POLICY_ID]))
        self.assertAlmostEqual(np.mean(results), 3.0, delta=0.2)

        # If we insert-1x and replay n times, expect roughly return batches of
        # len 1.5 (replay_ratio=0.33 -> 33% replayed samples -> 1 new and 0.5
        # old samples on average in each returned value).
        results = []
        for _ in range(100):
            buffer.add(batch)
            sample = buffer.sample(5)
            self.assertIsInstance(sample, MultiAgentBatch)
            results.append(len(sample.policy_batches[DEFAULT_POLICY_ID]))
        self.assertAlmostEqual(np.mean(results), 1.5, delta=0.2)

        # 90% replay ratio.
        buffer = MultiAgentMixInReplayBuffer(
            capacity=self.capacity,
            replay_ratio=0.9,
        )

        # If we insert and replay n times, expect roughly return batches of
        # len 10 (replay_ratio=0.9 -> 90% replayed samples -> 1 new and 9 old
        # samples on average in each returned value).
        results = []
        for _ in range(100):
            buffer.add(batch)
            sample = buffer.sample(10)
            self.assertIsInstance(sample, MultiAgentBatch)
            results.append(len(sample.policy_batches[DEFAULT_POLICY_ID]))
        self.assertAlmostEqual(np.mean(results), 10.0, delta=0.2)

        # 0% replay ratio -> Only new samples.
        buffer = MultiAgentMixInReplayBuffer(
            capacity=self.capacity,
            replay_ratio=0.0,
        )
        # Add a new batch.
        batch = self._generate_single_timesteps()
        buffer.add(batch)
        # Expect exactly 1 batch to be returned.
        sample = buffer.sample(1)
        self.assertIsInstance(sample, MultiAgentBatch)
        self.assertEqual(len(sample), 1)
        # Expect exactly 0 sample to be returned (nothing new to be returned;
        # no replay allowed (replay_ratio=0.0)).
        sample = buffer.sample(1)
        self.assertIsInstance(sample, MultiAgentBatch)
        self.assertEqual(len(sample.policy_batches), 0)
        # If we insert and replay n times, expect roughly return batches of
        # len 1 (replay_ratio=0.0 -> 0% replayed samples -> 1 new and 0 old
        # samples on average in each returned value).
        results = []
        for _ in range(100):
            buffer.add(batch)
            sample = buffer.sample(1)
            self.assertIsInstance(sample, MultiAgentBatch)
            results.append(len(sample.policy_batches[DEFAULT_POLICY_ID]))
        self.assertAlmostEqual(np.mean(results), 1.0, delta=0.2)

        # 100% replay ratio -> Only replayed samples.
        buffer = MultiAgentMixInReplayBuffer(
            capacity=self.capacity,
            replay_ratio=1.0,
        )
        # Expect exactly 0 samples to be returned (buffer empty).
        sample = buffer.sample(1)
        self.assertEqual(len(sample.policy_batches), 0)
        # Add a new batch.
        batch = self._generate_single_timesteps()
        buffer.add(batch)
        # Expect exactly 1 sample to be returned (the new batch).
        sample = buffer.sample(1)
        self.assertIsInstance(sample, MultiAgentBatch)
        self.assertEqual(len(sample), 1)
        # Another replay -> Expect exactly 1 sample to be returned.
        sample = buffer.sample(1)
        self.assertIsInstance(sample, MultiAgentBatch)
        self.assertEqual(len(sample), 1)
        # If we replay n times, expect roughly return batches of
        # len 1 (replay_ratio=1.0 -> 100% replayed samples -> 0 new and 1 old
        # samples on average in each returned value).
        results = []
        for _ in range(100):
            sample = buffer.sample(1)
            self.assertIsInstance(sample, MultiAgentBatch)
            results.append(len(sample.policy_batches[DEFAULT_POLICY_ID]))
        self.assertAlmostEqual(np.mean(results), 1.0)