    def test_update_priorities(self):
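        """Tests that sampling frequencies track update_priorities calls."""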
        memory = PrioritizedReplayBuffer(
            capacity=self.capacity, alpha=self.alpha)

        # Insert n samples.
        num_records = 5
        for i in range(num_records):
            data = self._generate_data()
            memory.add(data, weight=1.0)
            self.assertTrue(len(memory) == i + 1)
            self.assertTrue(memory._next_idx == i + 1)

        # Fetch records, their indices and weights.
        batch = memory.sample(3, beta=self.beta)
        weights = batch["weights"]
        indices = batch["batch_indexes"]
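        # All records were added with the same priority (1.0), so their
        # importance-sampling weights should all normalize to 1.0.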
        check(weights, np.ones(shape=(3, )))
        self.assertEqual(3, len(indices))
        self.assertTrue(len(memory) == num_records)
        self.assertTrue(memory._next_idx == num_records)

        # Update priorities of indices 0, 2, 3, and 4 to very small values.
        memory.update_priorities(np.array([0, 2, 3, 4]),
                                 np.array([0.01, 0.01, 0.01, 0.01]))
        # Expect to sample almost only index 1
        # (which still has a priority of 1.0).
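        # With priorities of 0.01 vs 1.0, index 1 dominates the sampling
        # distribution, so the sum over 1000 sampled indices should land
        # close to 1000.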
        for _ in range(10):
            batch = memory.sample(1000, beta=self.beta)
            indices = batch["batch_indexes"]
            self.assertTrue(970 < np.sum(indices) < 1100)

        # Update priorities of indices 0 and 1 to values >> 0.01.
        # Expect to sample 0 and 1 equally (and some 2s, 3s, and 4s).
        for _ in range(10):
            rand = np.random.random() + 0.2
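            # The +0.2 offset keeps the new priorities well above the 0.01
            # of indices 2, 3, and 4.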
            memory.update_priorities(np.array([0, 1]), np.array([rand, rand]))
            batch = memory.sample(1000, beta=self.beta)
            indices = batch["batch_indexes"]
            # Expect the sum to be biased toward higher values due to some
            # 2s, 3s, and 4s.
            self.assertTrue(400 < np.sum(indices) < 800)

        # Update priorities to a 1:2 ratio.
        # Expect to sample index 1 roughly twice as often as index 0,
        # plus very few samples of indices 2, 3, or 4.
        for _ in range(10):
            rand = np.random.random() + 0.2
            memory.update_priorities(np.array([0, 1]),
                                     np.array([rand, rand * 2]))
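            # With P(i) proportional to priority^alpha, a 1:2 priority
            # ratio makes index 1 about 2^alpha times as likely as index 0
            # (exactly 2x for alpha=1.0).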
            batch = memory.sample(1000, beta=self.beta)
            indices = batch["batch_indexes"]
            self.assertTrue(600 < np.sum(indices) < 850)

        # Update priorities to a 1:4 ratio.
        # Expect to sample index 1 roughly four times as often as index 0,
        # plus very few samples of indices 2, 3, or 4.
        for _ in range(10):
            rand = np.random.random() + 0.2
            memory.update_priorities(np.array([0, 1]),
                                     np.array([rand, rand * 4]))
            batch = memory.sample(1000, beta=self.beta)
            indices = batch["batch_indexes"]
            self.assertTrue(750 < np.sum(indices) < 950)

        # Update priorities to a 1:9 ratio.
        # Expect to sample index 1 roughly nine times as often as index 0,
        # plus very few samples of indices 2, 3, or 4.
        for _ in range(10):
            rand = np.random.random() + 0.2
            memory.update_priorities(np.array([0, 1]),
                                     np.array([rand, rand * 9]))
            batch = memory.sample(1000, beta=self.beta)
            indices = batch["batch_indexes"]
            self.assertTrue(850 < np.sum(indices) < 1100)

        # Insert n more samples.
        num_records = 5
        for i in range(num_records):
            data = self._generate_data()
            memory.add(data, weight=1.0)
            self.assertTrue(len(memory) == i + 6)
            self.assertTrue(memory._next_idx == (i + 6) % self.capacity)

        # Update all 10 priorities (spanning several orders of magnitude)
        # and sample batches of size >100.
        memory.update_priorities(
            np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
            np.array([0.001, 0.1, 2., 8., 16., 32., 64., 128., 256., 512.]))
        counts = Counter()
        for _ in range(10):
            batch = memory.sample(
                np.random.randint(100, 600), beta=self.beta)
            indices = batch["batch_indexes"]
            for i in indices:
                counts[i] += 1
        print(counts)
        # Expect the sampling counts to be ordered like the priorities.
        self.assertTrue(
            counts[9] >= counts[8] >= counts[7] >= counts[6] >= counts[5] >=
            counts[4] >= counts[3] >= counts[2] >= counts[1] >= counts[0])

    def test_add(self):
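        """Tests insertion up to and beyond capacity, plus state restore."""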
        memory = PrioritizedReplayBuffer(capacity=2, alpha=self.alpha)

        # Assert that size and insertion index are 0 before inserting.
        self.assertEqual(len(memory), 0)
        self.assertEqual(memory._next_idx, 0)

        # Insert single record.
        data = self._generate_data()
        memory.add(data, weight=0.5)
        self.assertTrue(len(memory) == 1)
        self.assertTrue(memory._next_idx == 1)

        # Insert single record.
        data = self._generate_data()
        memory.add(data, weight=0.1)
        self.assertTrue(len(memory) == 2)
        self.assertTrue(memory._next_idx == 0)

        # Insert over capacity.
        data = self._generate_data()
        memory.add(data, weight=1.0)
        self.assertTrue(len(memory) == 2)
        self.assertTrue(memory._next_idx == 1)

        # Test get_state/set_state.
        state = memory.get_state()
        new_memory = PrioritizedReplayBuffer(capacity=2, alpha=self.alpha)
        new_memory.set_state(state)
        self.assertTrue(len(new_memory) == 2)
        self.assertTrue(new_memory._next_idx == 1)

    def test_sequence_size(self):
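        """Tests that multi-timestep batches consume capacity per timestep."""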
        # Seq-len=1.
        memory = PrioritizedReplayBuffer(capacity=100, alpha=0.1)
        for _ in range(200):
            memory.add(self._generate_data(), weight=None)
        assert len(memory._storage) == 100, len(memory._storage)
        assert memory.stats()["added_count"] == 200, memory.stats()
        # Test get_state/set_state.
        state = memory.get_state()
        new_memory = PrioritizedReplayBuffer(capacity=100, alpha=0.1)
        new_memory.set_state(state)
        assert len(new_memory._storage) == 100, len(new_memory._storage)
        assert new_memory.stats()["added_count"] == 200, new_memory.stats()

        # Seq-len=5.
        memory = PrioritizedReplayBuffer(capacity=100, alpha=0.1)
        for _ in range(40):
            memory.add(SampleBatch.concat_samples(
                [self._generate_data() for _ in range(5)]),
                       weight=None)
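        # capacity=100 is counted in timesteps: only 20 (=100/5) 5-timestep
        # batches fit in storage, while `added_count` tracks all 200
        # timesteps ever added.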
        assert len(memory._storage) == 20, len(memory._storage)
        assert memory.stats()["added_count"] == 200, memory.stats()
        # Test get_state/set_state.
        state = memory.get_state()
        new_memory = PrioritizedReplayBuffer(capacity=100, alpha=0.1)
        new_memory.set_state(state)
        assert len(new_memory._storage) == 20, len(new_memory._storage)
        assert new_memory.stats()["added_count"] == 200, new_memory.stats()

    def test_alpha_parameter(self):
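        """Tests that a very small alpha yields near-uniform sampling."""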
        # Test sampling from a prioritized replay buffer with a very small
        # alpha (should behave just like a regular ReplayBuffer).
        memory = PrioritizedReplayBuffer(self.capacity, alpha=0.01)
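        # With P(i) proportional to priority^alpha, alpha=0.01 flattens all
        # priorities toward 1, regardless of the random weights given below.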

        # Insert n samples.
        num_records = 5
        for i in range(num_records):
            data = self._generate_data()
            memory.add(data, weight=np.random.rand())
            self.assertTrue(len(memory) == i + 1)
            self.assertTrue(memory._next_idx == i + 1)
        # Test get_state/set_state.
        state = memory.get_state()
        new_memory = PrioritizedReplayBuffer(self.capacity, alpha=0.01)
        new_memory.set_state(state)
        self.assertTrue(len(new_memory) == num_records)
        self.assertTrue(new_memory._next_idx == num_records)

        # Fetch records, their indices and weights.
        batch = memory.sample(1000, beta=self.beta)
        indices = batch["batch_indexes"]
        counts = Counter()
        for i in indices:
            counts[i] += 1
        print(counts)
        # Expect an approximately uniform distribution of indices.
        self.assertTrue(all(100 < i < 300 for i in counts.values()))
        # Test get_state/set_state.
        state = memory.get_state()
        new_memory = PrioritizedReplayBuffer(self.capacity, alpha=0.01)
        new_memory.set_state(state)
        batch = new_memory.sample(1000, beta=self.beta)
        indices = batch["batch_indexes"]
        counts = Counter()
        for i in indices:
            counts[i] += 1
        self.assertTrue(all(100 < i < 300 for i in counts.values()))