import numpy as np

# Note: the import path of `SumSegmentTree` below is an assumption; adjust it to
# wherever the segment tree implementation lives in this repo
from segment_tree import SumSegmentTree


def test_prefixsum_idx():
    tree = SumSegmentTree(4)
    tree[2] = 1.0
    tree[3] = 3.0
    assert tree.find_prefixsum_idx(0.0) == 2
    assert tree.find_prefixsum_idx(0.5) == 2
    assert tree.find_prefixsum_idx(0.99) == 2
    assert tree.find_prefixsum_idx(1.01) == 3
    assert tree.find_prefixsum_idx(3.00) == 3
    assert tree.find_prefixsum_idx(4.00) == 3


def test_prefixsum_idx2():
    tree = SumSegmentTree(4)
    tree[0] = 0.5
    tree[1] = 1.0
    tree[2] = 1.0
    tree[3] = 3.0
    assert tree.find_prefixsum_idx(0.00) == 0
    assert tree.find_prefixsum_idx(0.55) == 1
    assert tree.find_prefixsum_idx(0.99) == 1
    assert tree.find_prefixsum_idx(1.51) == 2
    assert tree.find_prefixsum_idx(3.00) == 3
    assert tree.find_prefixsum_idx(5.50) == 3


def test_tree_set():
    tree = SumSegmentTree(4)
    tree[2] = 1.0
    tree[3] = 3.0
    assert np.isclose(tree.sum(), 4.0)
    assert np.isclose(tree.sum(0, 2), 0.0)
    assert np.isclose(tree.sum(0, 3), 1.0)
    assert np.isclose(tree.sum(2, 3), 1.0)
    assert np.isclose(tree.sum(2, -1), 1.0)
    assert np.isclose(tree.sum(2, 4), 4.0)


def test_tree_set_overlap():
    tree = SumSegmentTree(4)
    tree[2] = 1.0
    tree[2] = 3.0  # overwrite the value at index 2
    assert np.isclose(tree.sum(), 3.0)
    assert np.isclose(tree.sum(2, 3), 3.0)
    assert np.isclose(tree.sum(2, -1), 3.0)
    assert np.isclose(tree.sum(2, 4), 3.0)
    assert np.isclose(tree.sum(1, 2), 0.0)
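# For intuition on what the tests above exercise, here is a minimal sketch of how a
# `find_prefixsum_idx`-style query can be answered by descending an array-backed binary
# sum tree. This is an illustrative reference (assuming a power-of-two capacity), not
# necessarily how `SumSegmentTree` is implemented in this repo:
def find_prefixsum_idx_sketch(values, prefixsum):
    """Return the highest index `idx` such that sum(values[:idx]) <= `prefixsum`."""
    capacity = len(values)
    assert capacity > 0 and capacity & (capacity - 1) == 0, "capacity must be a power of two"
    # Build the flat tree: leaves live in tree[capacity:], internal nodes in tree[1:capacity]
    tree = [0.0] * capacity + list(values)
    for i in reversed(range(1, capacity)):
        tree[i] = tree[2 * i] + tree[2 * i + 1]
    # Descend from the root: go left if the left subtree's sum exceeds `prefixsum`,
    # otherwise subtract the left subtree's sum and go right
    idx = 1
    while idx < capacity:
        if tree[2 * idx] > prefixsum:
            idx = 2 * idx
        else:
            prefixsum -= tree[2 * idx]
            idx = 2 * idx + 1
    return idx - capacity

# Mirrors `test_prefixsum_idx` above: priorities [0, 0, 1, 3]
assert find_prefixsum_idx_sketch([0.0, 0.0, 1.0, 3.0], 0.5) == 2
assert find_prefixsum_idx_sketch([0.0, 0.0, 1.0, 3.0], 1.01) == 3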
class UnrealRB(PrioritizedRB):
    """'Reinforcement Learning with Unsupervised Auxiliary Tasks' replay buffer implementation
    Reference: https://arxiv.org/pdf/1611.05397.pdf
    """

    def __init__(self, limit, ob_shape, ac_shape, max_priority=1.0):
        """Reuse the `PrioritizedRB` constructor w/:
        - `alpha` arbitrarily set to 1. (unused)
        - `beta` arbitrarily set to 1. (unused)
        - `ranked` set to True (necessary to have access to the ranks)
        """
        # Note: `ranked` and `max_priority` must be passed as keyword arguments; passing
        # them positionally would erroneously assign them to `demos_eps` and `ranked`
        super(UnrealRB, self).__init__(limit, ob_shape, ac_shape, 1., 1.,
                                       ranked=True, max_priority=max_priority)
        # Create two extra `SumSegmentTree` objects: one for 'bad' transitions, one for 'good' ones
        self.b_sum_st = SumSegmentTree(self.st_cap)  # with `operator.add` operation
        self.g_sum_st = SumSegmentTree(self.st_cap)  # with `operator.add` operation

    def _sample_unreal(self, batch_size):
        """Sample uniformly which virtual sub-buffer to pick from (bad or good transitions),
        then sample a transition uniformly from the picked virtual sub-buffer.
        Implemented w/ segment trees. Since `b_sum_st` and `g_sum_st` contain priorities in
        {0,1}, sampling according to those priorities amounts to sampling uniformly among
        the transitions having priority 1 in the picked virtual sub-buffer.
        Note: the priorities were used (in `update_priorities`) to determine in which
        virtual sub-buffer a transition belongs.
        """
        factor = 5  # could arguably be exposed as a hyperparameter
        assert factor > 1
        assert self.num_entries > 0
        transition_idxs = []
        # Sum the priorities (in {0,1}) of the transitions currently in the buffers.
        # Since the values are in {0,1}, the sums correspond to cardinalities (#b and #g)
        # (`start` defaults to 0, `end` is exclusive, as in classic Python slicing)
        b_p_total = self.b_sum_st.sum(end=self.num_entries)
        g_p_total = self.g_sum_st.sum(end=self.num_entries)
        # Ensure no repeats once the number of memory entries is high enough
        no_repeats = self.num_entries > factor * batch_size
        for _ in range(batch_size):
            while True:
                # Sample a value uniformly from the unit interval
                unit_u_sample = random.random()  # ~U[0,1]
                # Sample a number in {0,1} to decide whether to sample a bad or good transition
                is_g = np.random.randint(2)
                p_total = g_p_total if is_g else b_p_total
                # Scale the sampled value to the total sum of priorities
                u_sample = unit_u_sample * p_total
                # Retrieve the transition index associated with `u_sample`,
                # i.e. in which block the sample landed
                b_o_g_sum_st = self.g_sum_st if is_g else self.b_sum_st
                transition_idx = b_o_g_sum_st.find_prefixsum_idx(u_sample)
                if not no_repeats or transition_idx not in transition_idxs:
                    transition_idxs.append(transition_idx)
                    break
        return np.array(transition_idxs)

    def sample(self, batch_size):
        return super()._sample(batch_size, self._sample_unreal)

    def n_step_lookahead_sample(self, batch_size, n, gamma):
        # Defer to the parent implementation (which relies on this class's `sample`)
        return super().n_step_lookahead_sample(batch_size, n, gamma)

    def append(self, *args, **kwargs):
        super().append(*args, **kwargs)
        idx = self.latest_entry_idx
        # Add newly appended elements to both the 'good' and 'bad' virtual sub-buffers
        self.b_sum_st[idx] = 1
        self.g_sum_st[idx] = 1

    def update_priorities(self, idxs, priorities):
        # Update priorities via the parent method, used w/ the ranked approach
        idxs, priorities = super().update_priorities(idxs, priorities)
        # Register whether a transition is bad or good in the UNREAL-specific sum trees
        for idx, priority in zipsame(idxs, priorities):
            # Decide whether the transition is good or bad by recovering the rank from
            # the priority (UnrealRB inherits from PrioritizedRB w/ `ranked` set to True,
            # so every priority is of the form 1 / (rank + 1))
            if idx < self.num_demos:
                # When the transition is from the demos, always set it as 'good' regardless
                self.b_sum_st[idx] = 0
                self.g_sum_st[idx] = 1
            else:
                rank = (1. / priority) - 1
                thres = floor(.5 * self.num_entries)
                # A transition counts as 'good' if its rank lies in the top half of the buffer
                is_g = int(rank < thres)
                # Fill the good and bad sum segment trees w/ the obtained value
                self.b_sum_st[idx] = 1 - is_g
                self.g_sum_st[idx] = is_g
        if debug:
            # Verify the updates by computing the cardinalities of the virtual sub-buffers
            b_num_entries = self.b_sum_st.sum(end=self.num_entries)
            g_num_entries = self.g_sum_st.sum(end=self.num_entries)
            print("[num entries] b: {} | g: {}".format(b_num_entries, g_num_entries))
            print("total num entries: {}".format(self.num_entries))

    def __repr__(self):
        fmt = "UnrealRB(limit={}, ob_shape={}, ac_shape={}, max_priority={})"
        return fmt.format(self.limit, self.ob_shape, self.ac_shape, self.max_priority)
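# Worked example of the good/bad split performed in `UnrealRB.update_priorities` above:
# w/ the ranked scheme, priorities are of the form 1 / (rank + 1), so the rank is
# recovered as (1 / priority) - 1, and a transition counts as 'good' when its rank lies
# in the top half of the buffer. A minimal standalone check (the value of `num_entries`
# below is made up purely for illustration):
from math import floor

num_entries = 10
thres = floor(.5 * num_entries)  # ranks 0..4 count as 'good', ranks 5..9 as 'bad'
for rank in range(num_entries):
    priority = 1. / (rank + 1)
    recovered_rank = (1. / priority) - 1
    assert round(recovered_rank) == rank
    is_g = int(recovered_rank < thres)
    assert is_g == (1 if rank < thres else 0)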
class PrioritizedRB(RB):
    """'Prioritized Experience Replay' replay buffer implementation
    Reference: https://arxiv.org/pdf/1511.05952.pdf
    """

    def __init__(self, limit, ob_shape, ac_shape, alpha, beta, demos_eps=0.1, ranked=False, max_priority=1.0):
        """`alpha` determines how much prioritization is used:
            0: none, equivalent to uniform sampling
            1: full prioritization
        `beta` represents to what degree importance weights are used.
        """
        super(PrioritizedRB, self).__init__(limit, ob_shape, ac_shape)
        assert 0. <= alpha <= 1.
        assert beta > 0, "beta must be positive"
        self.alpha = alpha
        self.beta = beta
        self.max_priority = max_priority
        # Calculate the segment tree capacity suited to the user-specified limit
        self.st_cap = segment_tree_capacity(limit)
        # Create segment tree objects as the data collection structures for priorities.
        # They provide an efficient way of calculating cumulative sums of priorities
        self.sum_st = SumSegmentTree(self.st_cap)  # with `operator.add` operation
        self.min_st = MinSegmentTree(self.st_cap)  # with `min` operation
        # Whether it is the ranked version or not
        self.ranked = ranked
        if self.ranked:
            # Create a dict that will contain all the (index, priority) pairs
            self.i_p = {}
        # Define the priority bonus assigned to demos stored in the replay buffer (if any)
        self.demos_eps = demos_eps

    def _sample_w_priorities(self, batch_size):
        """Sample in proportion to priorities, implemented w/ a segment tree.
        This function samples a batch of transition indices directly from the priorities.
        Segment trees enable the emulation of a categorical sampling process by relying
        on what they shine at: computing cumulative sums.
        Imagine a stack of blocks, where each block (transition) has a height equal to
        its priority. The total height therefore is the sum of all priorities, `p_total`.
        A value drawn from U[0,1] and scaled by `p_total` is uniformly sampled from
        U[0,p_total], i.e. a height on the stack of blocks.
        `find_prefixsum_idx` returns the highest index (block id) such that the sum of
        the preceding priorities (block heights) is <= the uniformly sampled height.
        The process is equivalent to sampling from a categorical distribution over the
        transitions (it might even be how some libraries implement categorical sampling).
        Since the height is sampled uniformly, the probability of landing in a block is
        proportional to the height of said block. The height being the priority value,
        the higher the priority, the higher the probability of being selected.
        """
        assert self.num_entries > 0
        transition_idxs = []
        # Sum the priorities of the transitions currently in the buffer
        # (`start` defaults to 0, `end` is exclusive, as in classic Python slicing)
        p_total = self.sum_st.sum(end=self.num_entries)
        # Divide equally into `batch_size` ranges (appendix B.2.1 of the PER paper)
        p_pieces = p_total / batch_size
        # Sample `batch_size` values independently, each from within its associated
        # range, which is referred to as 'stratified sampling'
        for i in range(batch_size):
            # Sample a value uniformly from the unit interval
            unit_u_sample = random.random()  # ~U[0,1]
            # Scale and shift the sampled value to lie within the i-th range of cumulative priorities
            u_sample = (unit_u_sample * p_pieces) + (i * p_pieces)
            # Retrieve the transition index associated with `u_sample`,
            # i.e. in which block the sample landed
            transition_idx = self.sum_st.find_prefixsum_idx(u_sample)
            transition_idxs.append(transition_idx)
        return np.array(transition_idxs)

    def _sample(self, batch_size, sampling_fn):
        """Sample from the replay buffer according to the assigned priorities, while
        using importance weights to offset the biasing effect of non-uniform sampling.
        `beta` (defined in `__init__`) represents to what degree importance weights
        are used.
        """
        # Sample transition indices according to the sampling function
        idxs = sampling_fn(batch_size=batch_size)
        # Initialize the importance weights
        iws = []
        # Compute the lowest sampling prob among the transitions currently in the buffer,
        # equal to the lowest priority divided by the sum of all priorities
        lowest_prob = self.min_st.min(end=self.num_entries) / self.sum_st.sum(end=self.num_entries)
        # Compute the maximum weight, for weight scaling purposes (Section 3.4 of the PER paper)
        max_weight = (self.num_entries * lowest_prob) ** (-self.beta)
        # Create a weight for every selected transition
        for idx in idxs:
            # Compute the probability assigned to the transition
            prob_transition = self.sum_st[idx] / self.sum_st.sum(end=self.num_entries)
            # Compute the transition weight
            weight_transition = (self.num_entries * prob_transition) ** (-self.beta)
            iws.append(weight_transition / max_weight)
        # Collect a batch of transitions w/ importance weights and indices
        weighted_transitions = super().batchify(idxs)
        weighted_transitions['iws'] = array_min2d(np.array(iws))
        weighted_transitions['idxs'] = np.array(idxs)
        return weighted_transitions

    def sample(self, batch_size):
        return self._sample(batch_size, self._sample_w_priorities)

    def sample_uniform(self, batch_size):
        return super().sample(batch_size=batch_size)

    def n_step_lookahead_sample(self, batch_size, n, gamma):
        """Sample from the replay buffer according to the assigned priorities.
        This function is for n-step TD backups, where n > 1
        """
        assert n > 1
        # Sample a batch of transitions
        transitions = self.sample(batch_size=batch_size)
        # Expand each transition w/ an n-step TD lookahead
        lookahead_batch = super().lookahead(transitions=transitions, n=n, gamma=gamma)
        # Add the importance weights and indices to the dict
        lookahead_batch['iws'] = transitions['iws']
        lookahead_batch['idxs'] = transitions['idxs']
        return lookahead_batch

    def append(self, *args, **kwargs):
        super().append(*args, **kwargs)
        idx = self.latest_entry_idx
        # Assign the highest priority value to newly added elements
        # (line 6 of the algorithm in the PER paper)
        self.sum_st[idx] = self.max_priority ** self.alpha
        self.min_st[idx] = self.max_priority ** self.alpha

    def update_priorities(self, idxs, priorities):
        """Update priorities according to the PER paper, i.e. by updating only the
        priorities of the sampled transitions. The priority `priorities[i]` is assigned
        to the transition at index `idxs[i]`.
        Note: not in use in the vanilla setting, but here if needed in extensions.
        """
        global debug
        if self.ranked:
            # Override the priorities to be 1 / (rank(priority) + 1)
            # Add the new (index, priority) pairs to the dict
            self.i_p.update({i: p for i, p in zipsame(idxs, priorities)})
            # Rank the indices by priority, in decreasing order
            i_sorted_by_p = sorted(self.i_p.items(), key=lambda t: t[1], reverse=True)
            # Create the (index, rank) dict
            i_r = {i: rank for rank, (i, _) in enumerate(i_sorted_by_p)}
            # Unpack the indices and ranks
            _idxs, ranks = zipsame(*i_r.items())
            # Override the indices and priorities
            idxs = list(_idxs)
            priorities = [1. / (rank + 1) for rank in ranks]  # start ranks at 1
            if debug:
                # Verify that the priorities have been properly overridden
                for idx, priority in zipsame(idxs, priorities):
                    print("index: {} | priority: {}".format(idx, priority))
        assert len(idxs) == len(priorities), "the two arrays must be the same length"
        for idx, priority in zipsame(idxs, priorities):
            assert priority > 0, "priorities must be positive"
            assert 0 <= idx < self.num_entries, "no element in buffer associated w/ index"
            if idx < self.num_demos:
                # Add a priority bonus when replaying a demo
                priority += self.demos_eps
            self.sum_st[idx] = priority ** self.alpha
            self.min_st[idx] = priority ** self.alpha
            # Update the max priority currently in the buffer
            self.max_priority = max(priority, self.max_priority)
        if self.ranked:
            # Return the indices and associated overridden priorities
            # Note: the returned values are only used in the UNREAL priority update function
            return idxs, priorities

    def __repr__(self):
        fmt = "PrioritizedRB(limit={}, ob_shape={}, ac_shape={}, alpha={}, beta={}, "
        fmt += "demos_eps={}, ranked={}, max_priority={})"
        return fmt.format(self.limit, self.ob_shape, self.ac_shape, self.alpha,
                          self.beta, self.demos_eps, self.ranked, self.max_priority)
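# Standalone numeric sketch of the importance-weight computation performed in
# `PrioritizedRB._sample` above. The class reads the `p_i^alpha` values and their sum
# from the segment trees; this sketch mirrors the same equations w/ plain numpy (the
# raw priorities, `alpha` and `beta` below are made-up values, purely for illustration):
import numpy as np

alpha, beta = 0.6, 0.4
stored = np.array([1.0, 2.0, 0.5, 4.0]) ** alpha  # p_i^alpha, as stored in the sum tree
probs = stored / stored.sum()                     # P(i) = p_i^alpha / sum_k p_k^alpha
num_entries = len(stored)
weights = (num_entries * probs) ** (-beta)        # w_i = (N * P(i))^-beta
max_weight = (num_entries * probs.min()) ** (-beta)  # largest possible weight (lowest prob)
iws = weights / max_weight                        # normalized importance weights, in (0, 1]
assert np.isclose(iws.max(), 1.0)  # the lowest-probability transition gets weight 1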