class ReplayStorage:

    def __init__(self, max_steps, num_processes, gamma, prio_alpha, obs_shape,
                 action_space, recurrent_hidden_state_size, device):
        self.max_steps = int(max_steps)
        self.num_processes = num_processes
        self.gamma = gamma
        self.device = device

        # stored episode data
        self.obs = torch.zeros(self.max_steps, *obs_shape)
        self.recurrent_hidden_states = torch.zeros(
            self.max_steps, recurrent_hidden_state_size)
        self.returns = torch.zeros(self.max_steps, 1)
        if action_space.__class__.__name__ == 'Discrete':
            self.actions = torch.zeros(self.max_steps, 1).long()
        else:
            self.actions = torch.zeros(self.max_steps, action_space.shape[0])
        self.masks = torch.ones(self.max_steps, 1)
        self.next_idx = 0
        self.num_steps = 0

        # store (full) episode stats
        self.episode_step_count = 0
        self.episode_rewards = deque()
        self.episode_steps = deque()

        # currently running (accumulating) episodes
        self.running_episodes = [[] for _ in range(num_processes)]

        if prio_alpha > 0:
            # Sampling priority is enabled if prio_alpha > 0.
            # Priority algorithm ripped from OpenAI Baselines:
            # https://github.com/openai/baselines/blob/master/baselines/deepq/replay_buffer.py
            self.prio_alpha = prio_alpha
            tree_capacity = 1 << math.ceil(math.log2(self.max_steps))
            self.prio_sum_tree = SumSegmentTree(tree_capacity)
            self.prio_min_tree = MinSegmentTree(tree_capacity)
            self.prio_max = 1.0
        else:
            self.prio_alpha = 0

    def _process_rewards(self, trajectory):
        has_positive = False
        reward_sum = 0.
        r = 0.
        # Traverse the trajectory backwards to accumulate discounted returns
        for t in trajectory[::-1]:
            reward = t['reward']
            reward_sum += reward
            if reward > (0. + 1e-5):
                has_positive = True
            r = reward + self.gamma * r
            t['return'] = r
        return has_positive, reward_sum

    def _add_trajectory(self, trajectory):
        has_positive, reward_sum = self._process_rewards(trajectory)
        # Only store trajectories that contain at least one positive reward
        if not has_positive:
            return
        trajectory_len = len(trajectory)
        prev_idx = self.next_idx
        for transition in trajectory:
            self.obs[self.next_idx].copy_(transition['obs'])
            self.recurrent_hidden_states[self.next_idx].copy_(transition['rhs'])
            self.actions[self.next_idx].copy_(transition['action'])
            self.returns[self.next_idx].copy_(transition['return'])
            self.masks[self.next_idx] = 1.0
            prev_idx = self.next_idx
            if self.prio_alpha:
                # New transitions enter with the maximum priority seen so far
                self.prio_sum_tree[self.next_idx] = self.prio_max**self.prio_alpha
                self.prio_min_tree[self.next_idx] = self.prio_max**self.prio_alpha
            self.next_idx = (self.next_idx + 1) % self.max_steps
            self.num_steps = min(self.max_steps, self.num_steps + 1)
        # Mark the end of the trajectory
        self.masks[prev_idx] = 0.0

        # update stats of stored full trajectories (episodes)
        while self.episode_step_count + trajectory_len > self.max_steps:
            steps_popped = self.episode_steps.popleft()
            self.episode_rewards.popleft()
            self.episode_step_count -= steps_popped
        self.episode_step_count += trajectory_len
        self.episode_steps.append(trajectory_len)
        self.episode_rewards.append(reward_sum)

    def _sample_proportional(self, sample_size):
        # Draw indices with probability proportional to their priority mass
        res = []
        for _ in range(sample_size):
            mass = random.random() * self.prio_sum_tree.sum(0, self.num_steps - 1)
            idx = self.prio_sum_tree.find_prefixsum_idx(mass)
            res.append(idx)
        return res

    def insert(self, obs, rhs, actions, rewards, dones):
        # Accumulate the latest transition for each parallel process
        for n in range(self.num_processes):
            self.running_episodes[n].append(dict(obs=obs[n].clone(),
                                                 rhs=rhs[n].clone(),
                                                 action=actions[n].clone(),
                                                 reward=rewards[n].clone()))
        # Flush every episode that just terminated into storage
        for n, done in enumerate(dones):
            if done:
                self._add_trajectory(self.running_episodes[n])
                self.running_episodes[n] = []

    def update_priorities(self, indices, priorities):
        """Update priorities of sampled transitions.
        Sets the priority of the transition at index indices[i] in the buffer
        to priorities[i].

        Parameters
        ----------
        indices: [int]
            List of indices of sampled transitions
        priorities: [float]
            List of updated priorities corresponding to transitions at the
            sampled indices.
        """
        if not self.prio_alpha:
            return
        assert len(indices) == len(priorities)
        for idx, priority in zip(indices, priorities):
            priority = max(priority, 1e-6)
            assert priority > 0
            assert 0 <= idx < self.num_steps
            self.prio_sum_tree[idx] = priority**self.prio_alpha
            self.prio_min_tree[idx] = priority**self.prio_alpha
            self.prio_max = max(self.prio_max, priority)

    def feed_forward_generator(self, batch_size, num_batches=None, beta=0.):
        """Generate batches of sampled experiences.

        Parameters
        ----------
        batch_size: int
            Size of each sampled batch
        num_batches: int
            Number of batches to sample
        beta: float
            To what degree to use importance weights
            (0 - no corrections, 1 - full correction)
        """
        batch_count = 0
        # `num_batches=None` means: sweep over everything currently stored
        sample_size = num_batches * batch_size if num_batches else self.num_steps

        if self.prio_alpha > 0:
            indices = self._sample_proportional(sample_size)
            if beta > 0:
                # compute importance sampling weights to correct for the
                # bias introduced by sampling in a non-uniform manner
                weights = []
                p_min = self.prio_min_tree.min() / self.prio_sum_tree.sum()
                max_weight = (p_min * self.num_steps)**(-beta)
                for i in indices:
                    p_sample = self.prio_sum_tree[i] / self.prio_sum_tree.sum()
                    weight = (p_sample * self.num_steps)**(-beta)
                    weights.append(weight / max_weight)
                weights = torch.tensor(weights, dtype=torch.float32).unsqueeze(1)
            else:
                weights = torch.ones((len(indices), 1), dtype=torch.float32)
        else:
            # Uniform sampling when prioritization is disabled
            if sample_size * 3 < self.num_steps:
                indices = random.sample(range(self.num_steps), sample_size)
            else:
                indices = np.random.permutation(self.num_steps)[:sample_size]
            weights = None

        for si in range(0, len(indices), batch_size):
            indices_batch = indices[si:min(len(indices), si + batch_size)]
            if len(indices_batch) < batch_size:
                return

            weights_batch = None if weights is None else \
                weights[si:min(len(indices), si + batch_size)].to(self.device)

            obs_batch = self.obs[indices_batch].to(self.device)
            recurrent_hidden_states_batch = self.recurrent_hidden_states[indices_batch].to(self.device)
            actions_batch = self.actions[indices_batch].to(self.device)
            returns_batch = self.returns[indices_batch].to(self.device)
            masks_batch = self.masks[indices_batch].to(self.device)

            yield obs_batch, recurrent_hidden_states_batch, actions_batch, returns_batch, \
                masks_batch, weights_batch, indices_batch

            batch_count += 1
            if num_batches and batch_count >= num_batches:
                return
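# --- Illustrative sketch (not part of the original module) -----------------------------
# Minimal, self-contained illustration of what `ReplayStorage.feed_forward_generator`
# does with its segment trees when `prio_alpha > 0`: indices are drawn proportionally to
# their priority mass and each draw gets an importance weight (N * P(i))^(-beta),
# normalized by the largest possible weight. Plain Python lists stand in for
# `SumSegmentTree`/`MinSegmentTree`; the helper names `toy_sample_proportional` and
# `toy_importance_weights` are hypothetical and do not exist in the codebase.
import random


def toy_sample_proportional(priorities, sample_size):
    """Draw indices with probability proportional to priority (linear scan in place
    of `find_prefixsum_idx`)."""
    total = sum(priorities)
    idxs = []
    for _ in range(sample_size):
        mass = random.random() * total
        acc = 0.
        for i, p in enumerate(priorities):
            acc += p
            if mass < acc:
                idxs.append(i)
                break
    return idxs


def toy_importance_weights(priorities, idxs, beta):
    """w_i = (N * P(i))^(-beta), normalized by the maximum weight over the buffer."""
    total = sum(priorities)
    n = len(priorities)
    max_weight = ((min(priorities) / total) * n)**(-beta)
    return [((priorities[i] / total) * n)**(-beta) / max_weight for i in idxs]


# Example usage (priorities are assumed to be already exponentiated by alpha):
# idxs = toy_sample_proportional([0.1, 0.5, 1.0, 2.0], sample_size=6)
# iws = toy_importance_weights([0.1, 0.5, 1.0, 2.0], idxs, beta=0.4)
# ----------------------------------------------------------------------------------------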
class UnrealReplayBuffer(PrioritizedReplayBuffer):
    """'Reinforcement Learning with Unsupervised Auxiliary Tasks' replay buffer implementation
    Reference: https://arxiv.org/pdf/1611.05397.pdf
    """

    def __init__(self, limit, ob_shape, ac_shape, max_priority=1.0):
        """Reuse of the 'PrioritizedReplayBuffer' constructor w/:
        - `alpha` arbitrarily set to 1. (unused)
        - `beta` arbitrarily set to 1. (unused)
        - `ranked` set to True (necessary to have access to the ranks)
        """
        # Note: `ranked` and `max_priority` are passed as keyword arguments so that they
        # are not absorbed by `demos_eps` in the parent constructor's signature
        super(UnrealReplayBuffer, self).__init__(limit, ob_shape, ac_shape, 1., 1.,
                                                 ranked=True, max_priority=max_priority)
        # Create two extra `SumSegmentTree` objects: one for 'bad' transitions, one for 'good' ones
        self.b_sum_st = SumSegmentTree(self.st_cap)  # with `operator.add` operation
        self.g_sum_st = SumSegmentTree(self.st_cap)  # with `operator.add` operation

    def _sample_unreal(self, batch_size):
        """Sample uniformly which virtual sub-buffer to pick from (bad or good transitions),
        then sample uniformly a transition from the picked virtual sub-buffer.
        Implemented w/ segment trees.
        Since `b_sum_st` and `g_sum_st` contain priorities in {0,1}, sampling according to
        priorities samples a transition uniformly from the transitions having priority 1
        in the previously sampled virtual sub-buffer (bad or good transitions).
        Note: the priorities were used (in `update_priorities`) to determine in which
        virtual sub-buffer a transition belongs.
        """
        factor = 5  # hyperparam?
        assert factor > 1
        assert self.num_entries
        transition_idxs = []
        # Sum the priorities (in {0,1}) of the transitions currently in the buffers.
        # Since values are in {0,1}, the sums correspond to cardinalities (#b and #g)
        b_p_total = self.b_sum_st.sum(end=self.num_entries)
        g_p_total = self.g_sum_st.sum(end=self.num_entries)
        # `start` is 0 by default, `end` is length - 1 (classic python)
        # Ensure no repeats once the number of memory entries is high enough
        no_repeats = self.num_entries > factor * batch_size
        for _ in range(batch_size):
            while True:
                # Sample a value uniformly from the unit interval
                unit_u_sample = random.random()  # ~U[0,1]
                # Sample a number in {0,1} to decide whether to sample a b/g transition
                is_g = np.random.randint(2)
                p_total = g_p_total if is_g else b_p_total
                # Scale the sampled value to the total sum of priorities
                u_sample = unit_u_sample * p_total
                # Retrieve the transition index associated with `u_sample`,
                # i.e. in which block the sample landed
                b_o_g_sum_st = self.g_sum_st if is_g else self.b_sum_st
                transition_idx = b_o_g_sum_st.find_prefixsum_idx(u_sample)
                if not no_repeats or transition_idx not in transition_idxs:
                    transition_idxs.append(transition_idx)
                    break
        return np.array(transition_idxs)

    def sample(self, batch_size):
        return super()._sample(batch_size, self._sample_unreal)

    def lookahead_sample(self, batch_size, n, gamma):
        return super().lookahead_sample(batch_size, n, gamma)

    def append(self, *args, **kwargs):
        super().append(*args, **kwargs)
        idx = self.latest_entry_idx
        # Add newly added elements to both the 'good' and 'bad' virtual sub-buffers
        self.b_sum_st[idx] = 1
        self.g_sum_st[idx] = 1

    def update_priorities(self, idxs, priorities):
        # Update priorities via the legacy method, used w/ the ranked approach
        idxs, priorities = super().update_priorities(idxs, priorities)
        # Register whether a transition is b/g in the UNREAL-specific sum trees
        for idx, priority in zipsame(idxs, priorities):
            # Decide whether the transition to be added is good or bad.
            # The rank is recovered from the priority.
            # Note: UnrealReplayBuffer inherits from PER w/ 'ranked' set to True
            if idx < self.num_demos:
                # When the transition is from the demos, always set it as 'good' regardless
                self.b_sum_st[idx] = 0
                self.g_sum_st[idx] = 1
            else:
                rank = (1. / priority) - 1
                thres = floor(.5 * self.num_entries)
                is_g = int(rank < thres)  # cast the bool into an int
                # Fill the good and bad sum segment trees w/ the obtained value
                self.b_sum_st[idx] = 1 - is_g
                self.g_sum_st[idx] = is_g
        if debug:
            # Verify updates: compute the cardinalities of the virtual sub-buffers
            b_num_entries = self.b_sum_st.sum(end=self.num_entries)
            g_num_entries = self.g_sum_st.sum(end=self.num_entries)
            print("[num entries] b: {} | g: {}".format(b_num_entries, g_num_entries))
            print("total num entries: {}".format(self.num_entries))

    def __repr__(self):
        fmt = "UnrealReplayBuffer(limit={}, ob_shape={}, ac_shape={}, max_priority={})"
        return fmt.format(self.limit, self.ob_shape, self.ac_shape, self.max_priority)
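# --- Illustrative sketch (not part of the original module) -----------------------------
# Minimal stand-in for the UNREAL-style two-step sampling implemented above: each
# transition is flagged as 'good' or 'bad' by comparing its rank (recovered from the
# rank-based priority 1 / (rank + 1)) against the median, then a sub-buffer is picked
# uniformly and a flagged index is picked uniformly within it. Plain {0,1} flag lists
# replace the two `SumSegmentTree`s; the helper names `toy_split_good_bad` and
# `toy_unreal_pick` are hypothetical and assume both sub-buffers are non-empty.
import random
from math import floor


def toy_split_good_bad(priorities):
    """Return (bad, good) {0,1} flag lists from rank-based priorities 1 / (rank + 1)."""
    thres = floor(.5 * len(priorities))
    good = [1 if ((1. / p) - 1) < thres else 0 for p in priorities]
    bad = [1 - g for g in good]
    return bad, good


def toy_unreal_pick(bad, good):
    """Pick a sub-buffer uniformly at random, then a flagged index uniformly within it."""
    flags = good if random.randint(0, 1) else bad
    candidates = [i for i, f in enumerate(flags) if f]
    return random.choice(candidates)


# Example usage: with ranks 0..3 (priorities 1., 1/2., 1/3., 1/4.), the two best-ranked
# transitions land in the 'good' sub-buffer and the other two in the 'bad' one:
# bad, good = toy_split_good_bad([1., 1. / 2., 1. / 3., 1. / 4.])
# idx = toy_unreal_pick(bad, good)
# ----------------------------------------------------------------------------------------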
class PrioritizedReplayBuffer(ReplayBuffer):
    """'Prioritized Experience Replay' replay buffer implementation
    Reference: https://arxiv.org/pdf/1511.05952.pdf
    """

    def __init__(self, limit, ob_shape, ac_shape, alpha, beta,
                 demos_eps=0.1, ranked=False, max_priority=1.0):
        """`alpha` determines how much prioritization is used:
        0: none, equivalent to uniform sampling
        1: full prioritization
        `beta` represents to what degree importance weights are used.
        """
        super(PrioritizedReplayBuffer, self).__init__(limit, ob_shape, ac_shape)
        assert 0. <= alpha <= 1.
        assert beta > 0, "beta must be positive"
        self.alpha = alpha
        self.beta = beta
        self.max_priority = max_priority
        # Calculate the segment tree capacity suited to the user-specified limit
        self.st_cap = segment_tree_capacity(limit)
        # Create segment tree objects as data collection structures for priorities.
        # They provide an efficient way of calculating a cumulative sum of priorities
        self.sum_st = SumSegmentTree(self.st_cap)  # with `operator.add` operation
        self.min_st = MinSegmentTree(self.st_cap)  # with `min` operation
        # Whether it is the ranked version or not
        self.ranked = ranked
        if self.ranked:
            # Create a dict that will contain all the (index, priority) pairs
            self.i_p = {}
        # Define the priority bonus assigned to demos stored in the replay buffer (if stored)
        self.demos_eps = demos_eps

    def _sample_w_priorities(self, batch_size):
        """Sample in proportion to priorities, implemented w/ a segment tree.
        This function samples a batch of transition indices directly from the priorities.
        Segment trees enable the emulation of a categorical sampling process by relying
        on what they shine at: computing cumulative sums.
        Imagine a stack of blocks. Each block (transition) has a height equal to its
        priority. The total height therefore is the sum of all priorities, `p_total`.
        `unit_u_sample` is sampled from U[0,1], so `unit_u_sample * p_total` is a value
        uniformly sampled from U[0,p_total], i.e. a height on the stacked blocks.
        `find_prefixsum_idx` returns the highest index (block id) such that the sum of
        the preceding priorities (block heights) is <= the uniformly sampled height.
        The process is equivalent to sampling from a categorical distribution over the
        transitions (it might even be how some libraries implement categorical sampling).
        Since the height is sampled uniformly, the probability of landing in a block is
        proportional to the height of said block. The height being the priority value,
        the higher the priority, the higher the probability of being selected.
        """
        assert self.num_entries
        transition_idxs = []
        # Sum the priorities of the transitions currently in the buffer
        p_total = self.sum_st.sum(end=self.num_entries - 1)
        # `start` is 0 by default, `end` is length - 1
        # Divide equally into `batch_size` ranges (appendix B.2.1 of the PER paper)
        p_pieces = p_total / batch_size
        # Sample `batch_size` samples independently, each from within the associated
        # range, which is referred to as 'stratified sampling'
        for i in range(batch_size):
            # Sample a value uniformly from the unit interval
            unit_u_sample = random.random()  # ~U[0,1]
            # Scale and shift the sampled value to be within the range of cumulative priorities
            u_sample = (unit_u_sample * p_pieces) + (i * p_pieces)
            # Retrieve the transition index associated with `u_sample`,
            # i.e. in which block the sample landed
            transition_idx = self.sum_st.find_prefixsum_idx(u_sample)
            transition_idxs.append(transition_idx)
        return np.array(transition_idxs)

    def _sample(self, batch_size, sampling_fn):
        """Sample from the replay buffer according to assigned priorities while using
        importance weights to offset the biasing effect of non-uniform sampling.
        `beta` (defined in `__init__`) represents to what degree importance weights
        are used.
        """
        # Sample transition idxs according to the sampling function
        idxs = sampling_fn(batch_size=batch_size)
        # Initialize importance weights
        iws = []
        # Compute the lowest sampling prob among transitions currently in the buffer,
        # equal to the lowest priority divided by the sum of all priorities
        lowest_prob = self.min_st.min(end=self.num_entries) / self.sum_st.sum(end=self.num_entries)
        # Compute the maximum weight, for weight scaling purposes (section 3.4 of the PER paper)
        max_weight = (self.num_entries * lowest_prob)**(-self.beta)
        # Create a weight for every selected transition
        for idx in idxs:
            # Compute the probability assigned to the transition
            prob_transition = self.sum_st[idx] / self.sum_st.sum(end=self.num_entries)
            # Compute the transition weight
            weight_transition = (self.num_entries * prob_transition)**(-self.beta)
            iws.append(weight_transition / max_weight)
        # Collect the batch of transitions w/ iws and indices
        weighted_transitions = super().batchify(idxs)
        weighted_transitions['iws'] = array_min2d(np.array(iws))
        weighted_transitions['idxs'] = np.array(idxs)
        return weighted_transitions

    def sample(self, batch_size):
        return self._sample(batch_size, self._sample_w_priorities)

    def sample_uniform(self, batch_size):
        return super().sample(batch_size=batch_size)

    def lookahead_sample(self, batch_size, n, gamma):
        """Sample from the replay buffer according to assigned priorities.
        This function is for n-step TD backups, where n > 1
        """
        assert n > 1
        # Sample a batch of transitions
        transitions = self.sample(batch_size=batch_size)
        # Expand each transition w/ an n-step TD lookahead
        lookahead_batch = super().lookahead(transitions=transitions, n=n, gamma=gamma)
        # Add iws and indices to the dict
        lookahead_batch['iws'] = transitions['iws']
        lookahead_batch['idxs'] = transitions['idxs']
        return lookahead_batch

    def append(self, *args, **kwargs):
        super().append(*args, **kwargs)
        idx = self.latest_entry_idx
        # Assign the highest priority value to newly added elements
        # (line 6 of Algorithm 1 in the PER paper)
        self.sum_st[idx] = self.max_priority**self.alpha
        self.min_st[idx] = self.max_priority**self.alpha

    def update_priorities(self, idxs, priorities):
        """Update priorities according to the PER paper, i.e. by updating only the
        priority of sampled transitions. A priority priorities[i] is assigned to the
        transition at index idxs[i].
        Note: not in use in the vanilla setting, but here if needed in extensions.
        """
        global debug
        if self.ranked:
            # Override the priorities to be 1 / (rank(priority) + 1)
            # Add the new (index, priority) pairs to the dict
            self.i_p.update({i: p for i, p in zipsame(idxs, priorities)})
            # Rank the indices by priorities
            i_sorted_by_p = sorted(self.i_p.items(), key=lambda t: t[1], reverse=True)
            # Create the (index, rank) dict
            i_r = {i: i_sorted_by_p.index((i, p)) for i, p in self.i_p.items()}
            # Unpack indices and ranks
            _idxs, ranks = zipsame(*i_r.items())
            # Override the indices and priorities
            idxs = list(_idxs)
            priorities = [1. / (rank + 1) for rank in ranks]  # start ranks at 1
            if debug:
                # Verify that the priorities have been properly overridden
                for idx, priority in zipsame(idxs, priorities):
                    print("index: {} | priority: {}".format(idx, priority))
        assert len(idxs) == len(priorities), "the two arrays must be the same length"
        for idx, priority in zipsame(idxs, priorities):
            assert priority > 0, "priorities must be positive"
            assert 0 <= idx < self.num_entries, "no element in buffer associated w/ index"
            if idx < self.num_demos:
                # Add a priority bonus when replaying a demo
                priority += self.demos_eps
            self.sum_st[idx] = priority**self.alpha
            self.min_st[idx] = priority**self.alpha
            # Update the max priority currently in the buffer
            self.max_priority = max(priority, self.max_priority)
        if self.ranked:
            # Return indices and associated overridden priorities
            # Note: returned values are only used in the UNREAL priority update function
            return idxs, priorities

    def __repr__(self):
        fmt = "PrioritizedReplayBuffer(limit={}, ob_shape={}, ac_shape={}, alpha={}, beta={}, "
        fmt += "demos_eps={}, ranked={}, max_priority={})"
        return fmt.format(self.limit, self.ob_shape, self.ac_shape, self.alpha, self.beta,
                          self.demos_eps, self.ranked, self.max_priority)
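# --- Illustrative sketch (not part of the original module) -----------------------------
# Minimal, self-contained version of the stratified sampling performed by
# `_sample_w_priorities` above: the total priority mass is split into `batch_size`
# equal strata and exactly one index is drawn per stratum, here with a plain prefix-sum
# scan in place of `SumSegmentTree.find_prefixsum_idx`. The helper names
# `toy_find_prefixsum_idx` and `toy_stratified_sample` are hypothetical.
import random


def toy_find_prefixsum_idx(priorities, mass):
    """Smallest index whose inclusive cumulative priority exceeds `mass`."""
    acc = 0.
    for i, p in enumerate(priorities):
        acc += p
        if mass < acc:
            return i
    return len(priorities) - 1


def toy_stratified_sample(priorities, batch_size):
    """One proportional draw per equal-mass stratum (PER paper, appendix B.2.1)."""
    p_pieces = sum(priorities) / batch_size
    return [toy_find_prefixsum_idx(priorities, random.random() * p_pieces + i * p_pieces)
            for i in range(batch_size)]


# Example usage: with priorities [3., 1., 1., 1.] and batch_size=2, the first stratum
# covers mass [0, 3) and always returns index 0, while the second covers [3, 6) and
# returns index 1, 2, or 3 uniformly:
# idxs = toy_stratified_sample([3., 1., 1., 1.], batch_size=2)
# ----------------------------------------------------------------------------------------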