Exemplo n.º 1
0
class RankBasedReplay(ExperienceReplay):
    """Rank-based prioritized experience replay.

    Every stored experience has a priority, kept in ``self.heap`` as
    ``priority ** alpha``. Sampling works on 1-based *ranks* rather than raw
    priorities: ``self.sorted[rank - 1]`` is the heap index of the experience
    with that rank, and the sampling probability of a rank is
    ``rank ** (-alpha) / sum_{i=1..N} i ** (-alpha)`` for the current ``N``
    taken from ``self.N_list``. Instead of drawing from that distribution
    directly, one rank is drawn uniformly from each of a set of precomputed
    ranges (from ``load_quantiles``) of equal probability.

    ``self.buffer``, ``self.maxsize``, ``self.next_idx``, ``self.np_random``
    and ``self.encode_samples`` are presumably provided by ``ExperienceReplay``
    (not visible here). Each buffer entry is a two-element list
    ``[experience, heap_idx]``, and each heap item's ``index`` points back
    into the buffer.
    """

    def __init__(self, size, alpha):
        """Create an empty rank-based replay buffer.

        Args:
            size: capacity, forwarded to ``ExperienceReplay.__init__``
                (which presumably sets ``self.maxsize`` -- confirm).
            alpha: prioritization exponent, must be >= 0; ``alpha == 0``
                gives every rank the same sampling probability.
        """
        super(RankBasedReplay, self).__init__(size)
        assert alpha >= 0
        self.alpha = alpha

        # Heap indices ordered by rank (position 0 holds rank 1).
        self.sorted = []
        self.heap = Heap(self.maxsize)
        # Priority given to newly added experiences; raised whenever
        # update_priorities sees a larger one (priorities follow TD error).
        self.max_priority = 1.0

        # For each buffer size N in N_list, range_list holds precomputed rank
        # ranges of equal probability under the Zipf-like rank distribution.
        self.N_list, self.range_list = load_quantiles()
        # Index into N_list / range_list of the ranges currently in use.
        self.range_idx = 0
        # Normalizing constant sum_{i=1..N} i**(-alpha) for each N.
        self.priority_sums = [
            sum(i ** (-alpha) for i in range(1, N + 1)) for N in self.N_list
        ]
        # Smallest possible sampling probability (that of rank N) for each N.
        self.min_priorities = [
            N ** (-alpha) / self.priority_sums[i]
            for i, N in enumerate(self.N_list)
        ]

    def add(self, experience):
        """Store ``experience`` at ``self.next_idx`` with the current max priority.

        While the buffer is still growing, a new buffer slot and heap item are
        created and the new heap index is appended to ``self.sorted``, so the
        newest experience is ranked last until the next :meth:`sort`. Once
        the buffer is full, the slot at ``next_idx`` (the oldest in ring
        order) is overwritten and its existing heap item gets the max
        priority.
        """
        priority = self.max_priority ** self.alpha
        if self.next_idx >= len(self.buffer):
            # Growing: the buffer entry stores its heap index and the heap
            # item stores its buffer index.
            # NOTE(review): this assumes Heap.insert leaves the new item at
            # position next_idx; if it sifts, the stored heap index is stale
            # until sort() refreshes it -- confirm against Heap.
            self.buffer.append([experience, self.next_idx])
            self.heap.insert(HeapItem(priority, self.next_idx))
            self.sorted.append(self.next_idx)
        else:
            # Full: overwrite the experience in place and reset its priority.
            self.buffer[self.next_idx][0] = experience
            heap_idx = self.buffer[self.next_idx][1]
            self.heap[heap_idx].value = priority

        self.next_idx = (self.next_idx + 1) % self.maxsize

        # Switch to the precomputed ranges for the largest N the buffer has
        # reached.
        while (self.range_idx < len(self.N_list) - 1
               and len(self.buffer) >= self.N_list[self.range_idx + 1]):
            self.range_idx += 1

    def _sample_by_rank(self, batch_size):
        """Return a list of 1-based ranks to sample.

        If the buffer holds fewer than ``batch_size`` experiences, every rank
        ``1..len(self.buffer)`` is returned. Otherwise one rank is drawn
        uniformly from each closed interval in the current precomputed
        ranges, so the result length is the number of ranges, not
        ``batch_size`` itself.

        NOTE(review): if ``len(self.buffer)`` is below
        ``self.N_list[self.range_idx]``, the ranges may include ranks beyond
        the buffer size; presumably ``N_list[0]`` is small enough -- confirm
        against ``load_quantiles``.
        """
        n_stored = len(self.buffer)
        if n_stored < batch_size:
            return list(range(1, n_stored + 1))

        # randint's upper bound is exclusive, hence +1 for a closed interval.
        return [
            self.np_random.randint(_range[0], _range[1] + 1)
            for _range in self.range_list[self.range_idx]
        ]

    def sample(self, batch_size, beta):
        """Sample a batch of experiences with importance-sampling weights.

        Args:
            batch_size: requested batch size (see :meth:`_sample_by_rank`).
            beta: importance-sampling exponent, must be > 0.

        Returns:
            The fields returned by ``self.encode_samples`` followed by a numpy
            array of IS weights and the list of sampled heap indices (the
            form :meth:`update_priorities` takes). Weights are divided by the
            largest possible IS weight so they only scale updates downwards.

        Raises:
            ValueError: if the buffer is empty.
        """
        assert beta > 0
        n_stored = len(self.buffer)
        if n_stored == 0:
            # Without this, (p_min * 0) ** -beta below raises an opaque
            # ZeroDivisionError.
            raise ValueError("cannot sample from an empty replay buffer")
        ranks = self._sample_by_rank(batch_size)

        priority_sum = self.priority_sums[self.range_idx]
        # (p_uniform / p_min) ** beta with p_uniform = 1 / n_stored is the
        # largest possible IS weight.
        p_min = self.min_priorities[self.range_idx]
        max_weight = (p_min * n_stored) ** (-beta)

        # Sampling probability of each drawn rank, then its IS weight
        # (p_uniform / p_sample) ** beta, normalized by max_weight.
        p_sample = np.asarray(ranks, dtype=np.float64) ** (-self.alpha) / priority_sum
        weights = (p_sample * n_stored) ** (-beta) / max_weight

        heap_idxs = [self.sorted[rank - 1] for rank in ranks]
        buffer_idxs = [self.heap[heap_idx].index for heap_idx in heap_idxs]
        encoded_sample = self.encode_samples(buffer_idxs, ranked_priority=True)
        return tuple(encoded_sample) + (weights, heap_idxs)

    def update_priorities(self, heap_idxs, priorities):
        """Set new priorities for the experiences at ``heap_idxs``.

        Args:
            heap_idxs: heap indices, e.g. the last element returned by
                :meth:`sample`.
                NOTE(review): sort() calls build_heap, which may move items;
                indices obtained before a sort() are presumably invalid after
                it -- confirm.
            priorities: new strictly positive priorities, one per index
                (e.g. derived from TD error); stored as ``priority ** alpha``.
        """
        assert len(heap_idxs) == len(priorities)
        for idx, priority in zip(heap_idxs, priorities):
            assert priority > 0
            assert 0 <= idx < len(self.heap)

            self.heap[idx].value = priority ** self.alpha
            # add() gives new experiences the largest priority seen so far.
            self.max_priority = max(self.max_priority, priority)

    def sort(self):
        """Re-heapify and rebuild the rank order; meant to be called periodically.

        After ``build_heap`` every buffer entry's stored heap index is
        refreshed, and ``self.sorted`` is rebuilt from ``get_k_largest`` over
        the whole heap (presumably largest priority first, so rank 1 is the
        highest priority -- confirm against Heap).
        """
        self.heap.build_heap()
        for heap_idx in range(len(self.heap)):
            buffer_idx = self.heap[heap_idx].index
            self.buffer[buffer_idx][1] = heap_idx  # keep buffer -> heap link in sync
        self.sorted = self.heap.get_k_largest(len(self.heap))