Example #1
 def __init__(self, limit, ob_shape, ac_shape, alpha, beta, demos_eps=0.1,
              ranked=False, max_priority=1.0):
     """`alpha` determines how much prioritization is used
     0: none, equivalent to uniform sampling
     1: full prioritization
     `beta` (defined in `__init__`) represents to what degree importance weights are used.
     """
     super(PrioritizedRB, self).__init__(limit, ob_shape, ac_shape)
     assert 0. <= alpha <= 1.
     assert beta > 0, "beta must be positive"
     self.alpha = alpha
     self.beta = beta
     self.max_priority = max_priority
     # Calculate the segment tree capacity suited to the user-specified limit
     self.st_cap = segment_tree_capacity(limit)
     # Create segment tree objects as the data structures holding the priorities.
     # They provide an efficient way of computing cumulative sums and minima of priorities
     self.sum_st = SumSegmentTree(self.st_cap)  # with `operator.add` operation
     self.min_st = MinSegmentTree(self.st_cap)  # with `min` operation
     # Whether it is the ranked version or not
     self.ranked = ranked
     if self.ranked:
         # Create a dict that will contain all the (index, priority) pairs
         self.i_p = {}
     # Define the priority bonus assigned to demos stored in the replay buffer (if stored)
     self.demos_eps = demos_eps
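
As context for the `alpha`/`beta` knobs above, here is a minimal standalone sketch of the PER sampling math they control (plain numpy; the array and variable names are illustrative, not part of the class):

import numpy as np

# Illustrative PER math: alpha shapes the sampling distribution,
# beta shapes the importance-sampling correction.
priorities = np.array([1.0, 2.0, 4.0])  # e.g. TD-error magnitudes
alpha, beta = 0.6, 0.4

# P(i) = p_i^alpha / sum_k p_k^alpha  (alpha=0 -> uniform, alpha=1 -> fully proportional)
probs = priorities ** alpha
probs /= probs.sum()

# w_i = (N * P(i))^-beta, normalized by the max weight (beta=1 -> full correction)
n = len(priorities)
iws = (n * probs) ** (-beta)
iws /= iws.max()

print(probs, iws)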
Example #2
 def __init__(self, limit, ob_shape, ac_shape, max_priority=1.0):
     """Reuse of the 'PrioritizedRB' constructor w/:
         - `alpha` arbitrarily set to 1. (unused)
         - `beta` arbitrarily set to 1. (unused)
         - `ranked` set to True (necessary to have access to the ranks)
     """
     super(UnrealRB, self).__init__(limit, ob_shape, ac_shape, 1., 1.,
                                    ranked=True, max_priority=max_priority)
     # Create two extra `SumSegmentTree` objects: one for 'bad' transitions, one for 'good' ones
     self.b_sum_st = SumSegmentTree(self.st_cap)  # with `operator.add` operation
     self.g_sum_st = SumSegmentTree(self.st_cap)  # with `operator.add` operation
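
`segment_tree_capacity` itself is not shown in these snippets; assuming it follows the usual convention for baselines-style segment trees, whose capacity must be a power of two, a hypothetical stand-in could look like:

def segment_tree_capacity(limit):
    # Hypothetical stand-in: round `limit` up to the next power of two,
    # the usual capacity requirement for baselines-style segment trees.
    cap = 1
    while cap < limit:
        cap *= 2
    return cap

assert segment_tree_capacity(100) == 128
assert segment_tree_capacity(128) == 128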
Example #3
def test_prefixsum_idx():
    tree = SumSegmentTree(4)

    tree[2] = 1.0
    tree[3] = 3.0

    assert tree.find_prefixsum_idx(0.0) == 2
    assert tree.find_prefixsum_idx(0.5) == 2
    assert tree.find_prefixsum_idx(0.99) == 2
    assert tree.find_prefixsum_idx(1.01) == 3
    assert tree.find_prefixsum_idx(3.00) == 3
    assert tree.find_prefixsum_idx(4.00) == 3
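
The semantics pinned down by this test can be mirrored with a naive linear scan; a reference sketch only (the tree computes the same answer in O(log n)):

def naive_find_prefixsum_idx(values, prefixsum):
    # Highest index i such that values[0] + ... + values[i-1] <= prefixsum,
    # matching what the test above expects from SumSegmentTree.find_prefixsum_idx.
    running = 0.0
    for i, v in enumerate(values):
        if running + v > prefixsum:
            return i
        running += v
    return len(values) - 1

assert naive_find_prefixsum_idx([0.0, 0.0, 1.0, 3.0], 0.5) == 2
assert naive_find_prefixsum_idx([0.0, 0.0, 1.0, 3.0], 1.01) == 3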
Example #4
def test_prefixsum_idx2():
    tree = SumSegmentTree(4)

    tree[0] = 0.5
    tree[1] = 1.0
    tree[2] = 1.0
    tree[3] = 3.0

    assert tree.find_prefixsum_idx(0.00) == 0
    assert tree.find_prefixsum_idx(0.55) == 1
    assert tree.find_prefixsum_idx(0.99) == 1
    assert tree.find_prefixsum_idx(1.51) == 2
    assert tree.find_prefixsum_idx(3.00) == 3
    assert tree.find_prefixsum_idx(5.50) == 3
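
These lookups implement categorical sampling: drawing a uniform height in [0, total) and mapping it through the prefix-sum lookup selects index i with probability values[i] / total. A quick Monte Carlo check (illustrative sketch, linear scan standing in for the tree):

import random
from collections import Counter

values = [0.5, 1.0, 1.0, 3.0]  # same contents as the tree in the test above
total = sum(values)
draws = Counter()
for _ in range(20000):
    u = random.random() * total  # uniform height on the stacked blocks
    running = 0.0
    for i, v in enumerate(values):
        if running + v > u:
            draws[i] += 1
            break
        running += v

# Empirical frequencies approach values[i] / total, i.e. roughly (0.09, 0.18, 0.18, 0.55)
print({i: round(draws[i] / 20000, 2) for i in range(len(values))})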
Example #5
def test_tree_set():
    tree = SumSegmentTree(4)

    tree[2] = 1.0
    tree[3] = 3.0

    assert np.isclose(tree.sum(), 4.0)
    assert np.isclose(tree.sum(0, 2), 0.0)
    assert np.isclose(tree.sum(0, 3), 1.0)
    assert np.isclose(tree.sum(2, 3), 1.0)
    assert np.isclose(tree.sum(2, -1), 1.0)
    assert np.isclose(tree.sum(2, 4), 4.0)
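
The `sum(start, end)` behavior this test exercises is slice-like: `end` is exclusive and a negative `end` wraps like a Python index. A plain-list equivalent (reference sketch only):

def naive_sum(values, start=0, end=None):
    # Mirrors the behavior checked above: `end` is exclusive and,
    # when negative, wraps around like a Python index.
    if end is None:
        end = len(values)
    if end < 0:
        end += len(values)
    return sum(values[start:end])

vals = [0.0, 0.0, 1.0, 3.0]
assert naive_sum(vals) == 4.0
assert naive_sum(vals, 2, -1) == 1.0  # covers index 2 only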
Example #6
def test_tree_set_overlap():
    tree = SumSegmentTree(4)

    tree[2] = 1.0
    tree[2] = 3.0

    assert np.isclose(tree.sum(), 3.0)
    assert np.isclose(tree.sum(2, 3), 3.0)
    assert np.isclose(tree.sum(2, -1), 3.0)
    assert np.isclose(tree.sum(2, 4), 3.0)
    assert np.isclose(tree.sum(1, 2), 0.0)
Example #7
class UnrealRB(PrioritizedRB):
    """'Reinforcement Learning w/ unsupervised Auxiliary Tasks' replay buffer implementation
    Reference: https://arxiv.org/pdf/1611.05397.pdf
    """

    def __init__(self, limit, ob_shape, ac_shape, max_priority=1.0):
        """Reuse of the 'PrioritizedRB' constructor w/:
            - `alpha` arbitrarily set to 1. (unused)
            - `beta` arbitrarily set to 1. (unused)
            - `ranked` set to True (necessary to have access to the ranks)
        """
        super(UnrealRB, self).__init__(limit, ob_shape, ac_shape, 1., 1.,
                                       ranked=True, max_priority=max_priority)
        # Create two extra `SumSegmentTree` objects: one for 'bad' transitions, one for 'good' ones
        self.b_sum_st = SumSegmentTree(self.st_cap)  # with `operator.add` operation
        self.g_sum_st = SumSegmentTree(self.st_cap)  # with `operator.add` operation

    def _sample_unreal(self, batch_size):
        """Sample uniformly from which virtual sub-buffer to pick: bad or good transitions,
        then sample uniformly a transition from the previously picked virtual sub-buffer.
        Implemented w/ segment tree.
        Since `b_sum_st` and `g_sum_st` contain priorities in {0,1}, sampling according to
        priorities samples a transition uniformly from the transitions having priority 1
        in the previously sampled virtual sub-buffer (bad or good transitions).
        Note: the priorities were used (in `update_priorities`) to determine in which
        virtual sub-buffer a transition belongs.
        """
        factor = 5  # hyperparam?
        assert factor > 1
        assert self.num_entries
        transition_idxs = []
        # Sum the priorities (in {0,1}) of the transitions currently in the buffers
        # Since values are in {0,1}, the sums correspond to cardinalities (#b and #g)
        b_p_total = self.b_sum_st.sum(end=self.num_entries)
        g_p_total = self.g_sum_st.sum(end=self.num_entries)
        # `start` is 0 by default; `end` is exclusive, as in Python slicing
        # Ensure no repeats once the number of memory entries is high enough
        no_repeats = self.num_entries > factor * batch_size
        for _ in range(batch_size):
            while True:
                # Sample a value uniformly from the unit interval
                unit_u_sample = random.random()  # ~U[0,1]
                # Sample a number in {0,1} to decide whether to sample a b/g transition
                is_g = np.random.randint(2)
                p_total = g_p_total if is_g else b_p_total
                # Scale the sampled value to the total sum of priorities
                u_sample = unit_u_sample * p_total
                # Retrieve the transition index associated with `u_sample`
                # i.e. in which block the sample landed
                b_o_g_sum_st = self.g_sum_st if is_g else self.b_sum_st
                transition_idx = b_o_g_sum_st.find_prefixsum_idx(u_sample)
                if not no_repeats or transition_idx not in transition_idxs:
                    transition_idxs.append(transition_idx)
                    break
        return np.array(transition_idxs)

    def sample(self, batch_size):
        return super()._sample(batch_size, self._sample_unreal)

    def n_step_lookahead_sample(self, batch_size, n, gamma):
        return super().n_step_lookahead_sample(batch_size, n, gamma)

    def append(self, *args, **kwargs):
        super().append(*args, **kwargs)
        idx = self.latest_entry_idx
        # Register the newly added element in both the 'good' and 'bad' virtual sub-buffers
        self.b_sum_st[idx] = 1
        self.g_sum_st[idx] = 1

    def update_priorities(self, idxs, priorities):
        # Update priorities via the legacy method, used w/ ranked approach
        idxs, priorities = super().update_priorities(idxs, priorities)
        # Register whether a transition is b/g in the UNREAL-specific sum trees
        for idx, priority in zipsame(idxs, priorities):
            # Decide whether the transition to be added is good or bad
            # Get the rank from the priority
            # Note: UnrealRB inherits from PER w/ 'ranked' set to True
            if idx < self.num_demos:
                # When the transition is from the demos, always set it as 'good' regardless
                self.b_sum_st[idx] = 0
                self.g_sum_st[idx] = 1
            else:
                rank = (1. / priority) - 1
                thres = floor(.5 * self.num_entries)
                is_g = int(rank < thres)  # cast the bool to an int for the 0/1 priorities
                # Fill the good and bad sum segment trees w/ the obtained value
                self.b_sum_st[idx] = 1 - is_g
                self.g_sum_st[idx] = is_g
        if debug:
            # Verify updates
            # Compute the cardinalities of virtual sub-buffers
            b_num_entries = self.b_sum_st.sum(end=self.num_entries)
            g_num_entries = self.g_sum_st.sum(end=self.num_entries)
            print("[num entries]    b: {}    | g: {}".format(b_num_entries, g_num_entries))
            print("total num entries: {}".format(self.num_entries))

    def __repr__(self):
        fmt = "UnrealRB(limit={}, ob_shape={}, ac_shape={}, max_priority={})"
        return fmt.format(self.limit, self.ob_shape, self.ac_shape, self.max_priority)
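
The trick `_sample_unreal` relies on, with 0/1 priorities turning prioritized sampling into uniform sampling over one sub-buffer, can be demonstrated standalone; a sketch with a plain membership list standing in for `g_sum_st` (names illustrative):

import random
from collections import Counter

g_mask = [0, 1, 1, 0, 1]  # hypothetical 'good' membership flags, one per slot
total = sum(g_mask)       # cardinality of the 'good' sub-buffer

def sample_good_idx():
    # Prefix-sum lookup over 0/1 values: only slots marked 1 can be landed on,
    # each with probability 1/total, i.e. uniform over the 'good' sub-buffer.
    u = random.random() * total
    running = 0
    for i, v in enumerate(g_mask):
        if running + v > u:
            return i
        running += v
    return len(g_mask) - 1

counts = Counter(sample_good_idx() for _ in range(9000))
print(counts)  # roughly 3000 each for indices 1, 2 and 4; none for 0 and 3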
Example #8
class PrioritizedRB(RB):
    """'Prioritized Experience Replay' replay buffer implementation
    Reference: https://arxiv.org/pdf/1511.05952.pdf
    """

    def __init__(self, limit, ob_shape, ac_shape, alpha, beta, demos_eps=0.1,
                 ranked=False, max_priority=1.0):
        """`alpha` determines how much prioritization is used
        0: none, equivalent to uniform sampling
        1: full prioritization
        `beta` (defined in `__init__`) represents to what degree importance weights are used.
        """
        super(PrioritizedRB, self).__init__(limit, ob_shape, ac_shape)
        assert 0. <= alpha <= 1.
        assert beta > 0, "beta must be positive"
        self.alpha = alpha
        self.beta = beta
        self.max_priority = max_priority
        # Calculate the segment tree capacity suited to the user-specified limit
        self.st_cap = segment_tree_capacity(limit)
        # Create segment tree objects as the data structures holding the priorities.
        # They provide an efficient way of computing cumulative sums and minima of priorities
        self.sum_st = SumSegmentTree(self.st_cap)  # with `operator.add` operation
        self.min_st = MinSegmentTree(self.st_cap)  # with `min` operation
        # Whether it is the ranked version or not
        self.ranked = ranked
        if self.ranked:
            # Create a dict that will contain all the (index, priority) pairs
            self.i_p = {}
        # Define the priority bonus assigned to demos stored in the replay buffer (if stored)
        self.demos_eps = demos_eps

    def _sample_w_priorities(self, batch_size):
        """Sample in proportion to priorities, implemented w/ segment tree.
        This function samples a batch of transitions indices, directly from the priorities.
        Segment trees enable the emulation of a categorical sampling process by
        relying on what they shine at: computing cumulative sums.
        Imagine a stack of blocks.
        Each block (transition) has a height equal to its priority.
        The total height is therefore the sum of all priorities, `p_total`.
        `u_sample` is drawn from U[0,1], so `u_sample * p_total` is a value
        uniformly sampled from U[0,p_total], i.e. a height on the stacked blocks.
        `find_prefixsum_idx` returns the highest index (block id) such that the sum of
        preceding priorities (block heights) is <= the uniformly sampled height.
        The process is equivalent to sampling from a categorical distribution over
        the transitions (it may even be how some libraries implement categorical sampling).
        Since the height is sampled uniformly, the probability of landing in a block is
        proportional to the height of said block. The height being the priority, the
        higher the priority, the higher the probability of being selected.
        """
        assert self.num_entries

        transition_idxs = []
        # Sum the priorities of the transitions currently in the buffer
        p_total = self.sum_st.sum(end=self.num_entries - 1)
        # `start` is 0 by default; `end` is exclusive, as in Python slicing
        # Divide equally into `batch_size` ranges (Appendix B.2.1 of the PER paper)
        p_pieces = p_total / batch_size
        # Sample `batch_size` samples independently, each from within its associated range,
        # a procedure referred to as 'stratified sampling'
        for i in range(batch_size):
            # Sample a value uniformly from the unit interval
            unit_u_sample = random.random()  # ~U[0,1]
            # Scale and shift the sampled value to lie within the range of cumulative priorities
            u_sample = (unit_u_sample * p_pieces) + (i * p_pieces)
            # Retrieve the transition index associated with `u_sample`
            # i.e. in which block the sample landed
            transition_idx = self.sum_st.find_prefixsum_idx(u_sample)
            transition_idxs.append(transition_idx)
        return np.array(transition_idxs)

    def _sample(self, batch_size, sampling_fn):
        """Sample from the replay buffer according to assigned priorities
        while using importance weights to offset the biasing effect of non-uniform sampling.
        `beta` (defined in `__init__`) represents to what degree importance weights are used.
        """
        # Sample transition idxs according to the sampling function
        idxs = sampling_fn(batch_size=batch_size)

        # Initialize importance weights
        iws = []
        # Compute the lowest sampling probability among transitions currently in the buffer,
        # equal to the lowest priority divided by the sum of all priorities
        lowest_prob = self.min_st.min(end=self.num_entries) / self.sum_st.sum(end=self.num_entries)
        # Compute the maximum weight, used to scale the weights (Section 3.4 of the PER paper)
        max_weight = (self.num_entries * lowest_prob) ** (-self.beta)

        # Create a weight for every selected transition
        for idx in idxs:
            # Compute the probability assigned to the transition
            prob_transition = self.sum_st[idx] / self.sum_st.sum(end=self.num_entries)
            # Compute the transition weight
            weight_transition = (self.num_entries * prob_transition) ** (-self.beta)
            iws.append(weight_transition / max_weight)

        # Collect batch of transitions w/ iws and indices
        weighted_transitions = super().batchify(idxs)
        weighted_transitions['iws'] = array_min2d(np.array(iws))
        weighted_transitions['idxs'] = np.array(idxs)
        return weighted_transitions

    def sample(self, batch_size):
        return self._sample(batch_size, self._sample_w_priorities)

    def sample_uniform(self, batch_size):
        return super().sample(batch_size=batch_size)

    def n_step_lookahead_sample(self, batch_size, n, gamma):
        """Sample from the replay buffer according to assigned priorities.
        This function is for n-step TD backups, where n > 1
        """
        assert n > 1
        # Sample a batch of transitions
        transitions = self.sample(batch_size=batch_size)
        # Expand each transition w/ a n-step TD lookahead
        lookahead_batch = super().lookahead(transitions=transitions, n=n, gamma=gamma)
        # Add iws and indices to the dict
        lookahead_batch['iws'] = transitions['iws']
        lookahead_batch['idxs'] = transitions['idxs']
        return lookahead_batch

    def append(self, *args, **kwargs):
        super().append(*args, **kwargs)
        idx = self.latest_entry_idx
        # Assign the highest priority value to newly added elements (line 6 of Algorithm 1 in the PER paper)
        self.sum_st[idx] = self.max_priority ** self.alpha
        self.min_st[idx] = self.max_priority ** self.alpha

    def update_priorities(self, idxs, priorities):
        """Update priorities according to the PER paper, i.e. by updating
        only the priorities of sampled transitions. Priority `priorities[i]` is
        assigned to the transition at index `idxs[i]`.
        Note: not in use in the vanilla setting, but here if needed in extensions.
        """
        global debug
        if self.ranked:
            # Override the priorities to be 1 / (rank(priority) + 1)
            # Add the new (index, priority) pairs to the dict
            self.i_p.update({i: p for i, p in zipsame(idxs, priorities)})
            # Rank the indices by priorities
            i_sorted_by_p = sorted(self.i_p.items(), key=lambda t: t[1], reverse=True)
            # Create the (index, rank) dict; rank = position in the priority-sorted order
            i_r = {i: rank for rank, (i, _) in enumerate(i_sorted_by_p)}
            # Unpack indices and ranks
            _idxs, ranks = zipsame(*i_r.items())
            # Override the indices and priorities
            idxs = list(_idxs)
            priorities = [1. / (rank + 1) for rank in ranks]  # start ranks at 1
            if debug:
                # Verify that the priorities have been properly overridden
                for idx, priority in zipsame(idxs, priorities):
                    print("index: {}    | priority: {}".format(idx, priority))

        assert len(idxs) == len(priorities), "the two arrays must be the same length"
        for idx, priority in zipsame(idxs, priorities):
            assert priority > 0, "priorities must be positive"
            assert 0 <= idx < self.num_entries, "no element in buffer associated w/ index"
            if idx < self.num_demos:
                # Add a priority bonus when replaying a demo
                priority += self.demos_eps
            self.sum_st[idx] = priority ** self.alpha
            self.min_st[idx] = priority ** self.alpha
            # Update max priority currently in the buffer
            self.max_priority = max(priority, self.max_priority)

        if self.ranked:
            # Return indices and associated overridden priorities
            # Note: returned values are only used in the UNREAL priority update function
            return idxs, priorities

    def __repr__(self):
        fmt = "PrioritizedRB(limit={}, ob_shape={}, ac_shape={}, alpha={}, beta={}, "
        fmt += "demos_eps={}, ranked={}, max_priority={})"
        return fmt.format(self.limit, self.ob_shape, self.ac_shape, self.alpha, self.beta,
                          self.demos_eps, self.ranked, self.max_priority)
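
Finally, the stratified sampling performed in `_sample_w_priorities`, one draw per equal slice of the cumulative priority mass, can be sketched standalone (illustrative names, linear scan in place of the tree):

import random

priorities = [0.5, 1.0, 1.0, 3.0]
p_total = sum(priorities)
batch_size = 4
p_pieces = p_total / batch_size  # width of each stratum

def find_prefixsum_idx(values, prefixsum):
    running = 0.0
    for i, v in enumerate(values):
        if running + v > prefixsum:
            return i
        running += v
    return len(values) - 1

# One uniform draw per stratum: u_sample = (unit draw + stratum index) * piece width
idxs = [find_prefixsum_idx(priorities, (random.random() + i) * p_pieces)
        for i in range(batch_size)]
print(idxs)  # high-priority blocks are hit more often, but every stratum contributes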