class PrioritizedReplayBuffer(ReplayBuffer):
    def __init__(self, size, keys, alpha):
        """Create Prioritized Replay buffer.

        Parameters
        ----------
        size: int
            Max number of transitions to store in the buffer. When the buffer
            overflows, the oldest memories are dropped.
        alpha: float
            how much prioritization is used
            (0 - no prioritization, 1 - full prioritization)

        See Also
        --------
        ReplayBuffer.__init__
        """
        super(PrioritizedReplayBuffer, self).__init__(size, keys)
        assert alpha >= 0
        self._alpha = alpha

        # SegmentTrees require a capacity that is a power of 2.
        it_capacity = 1
        while it_capacity < size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0

    def add(self, *args, **kwargs):
        """See ReplayBuffer.store_effect"""
        idx = self._next_idx
        super().add(*args, **kwargs)
        # New transitions enter with maximum priority so they are sampled at least once.
        self._it_sum[idx] = self._max_priority**self._alpha
        self._it_min[idx] = self._max_priority**self._alpha

    def add_batch(self, datas, batch_size):
        for i in range(batch_size):
            self.add({k: datas[k][i] for k in self.keys})

    def _sample_proportional(self, batch_size):
        res = []
        p_total = self._it_sum.sum(0, len(self._storage) - 1)
        every_range_len = p_total / batch_size
        for i in range(batch_size):
            # Stratified sampling: draw one index from each equal-mass segment.
            mass = random.random() * every_range_len + i * every_range_len
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)
        return res

    def sample(self, batch_size, beta):
        """Sample a batch of experiences.

        Compared to ReplayBuffer.sample it also returns importance weights
        and idxes of the sampled experiences.

        Parameters
        ----------
        batch_size: int
            How many transitions to sample.
        beta: float
            To what degree to use importance weights
            (0 - no corrections, 1 - full correction)

        Returns
        -------
        obs_batch: np.array
            batch of observations
        act_batch: np.array
            batch of actions executed given obs_batch
        rew_batch: np.array
            rewards received as results of executing act_batch
        next_obs_batch: np.array
            next set of observations seen after executing act_batch
        done_mask: np.array
            done_mask[i] = 1 if executing act_batch[i] resulted in the end of
            an episode and 0 otherwise.
        weights: np.array
            Array of shape (batch_size,) and dtype np.float32 denoting the
            importance weight of each sampled transition
        idxes: np.array
            Array of shape (batch_size,) and dtype np.int32 with the indexes
            in the buffer of the sampled experiences
        """
        assert beta > 0

        idxes = self._sample_proportional(batch_size)

        weights = []
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * len(self._storage))**(-beta)

        for idx in idxes:
            p_sample = self._it_sum[idx] / self._it_sum.sum()
            weight = (p_sample * len(self._storage))**(-beta)
            weights.append(weight / max_weight)
        weights = np.array(weights)
        encoded_sample = self._encode_sample(idxes)
        return encoded_sample, weights, idxes

    def update_priorities(self, idxes, priorities):
        """Update priorities of sampled transitions.

        Sets the priority of the transition at index idxes[i] in the buffer
        to priorities[i].

        Parameters
        ----------
        idxes: [int]
            List of idxes of sampled transitions
        priorities: [float]
            List of updated priorities corresponding to transitions at the
            sampled idxes denoted by variable `idxes`.
        """
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            assert priority > 0
            assert 0 <= idx < len(self._storage)
            self._it_sum[idx] = priority**self._alpha
            self._it_min[idx] = priority**self._alpha

            self._max_priority = max(self._max_priority, priority)
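
# Usage sketch (illustrative only, not part of the class above): a minimal
# prioritized-replay loop, assuming the parent ReplayBuffer stores dict
# transitions keyed by `keys` (as add_batch implies). `collect_transition`
# and `compute_td_errors` are hypothetical placeholders for the data-collection
# and learner hooks.
import numpy as np

KEYS = ("obs", "act", "rew", "next_obs", "done")
buffer = PrioritizedReplayBuffer(size=2**17, keys=KEYS, alpha=0.6)

total_steps = 100_000
for step in range(total_steps):
    buffer.add(collect_transition())                   # dict with the keys above (assumed helper)
    if step < 1_000:
        continue                                       # warm up before learning
    beta = min(1.0, 0.4 + 0.6 * step / total_steps)    # anneal beta towards full IS correction
    batch, weights, idxes = buffer.sample(batch_size=32, beta=beta)
    td_errors = compute_td_errors(batch, weights)      # assumed learner hook returning TD errors
    # A small epsilon keeps every priority strictly positive, as update_priorities asserts.
    buffer.update_priorities(idxes, np.abs(td_errors) + 1e-6)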
class PrioritizedReplayBuffer():
    """
    PrioritizedReplayBuffer adapted from OpenAI Baselines.
    """
    def __init__(self, size, T_max, learn_start):
        self._storage = []
        self._maxsize = size
        self._next_idx = 0

        # SegmentTrees require a capacity that is a power of 2.
        it_capacity = 1
        while it_capacity < size:
            it_capacity *= 2
        self._sumTree = SumSegmentTree(it_capacity)
        self._minTree = MinSegmentTree(it_capacity)
        self._max_priority = 1.0

        self.e = 0.01      # small constant added to |TD error| so no priority is zero
        self.alpha = 0.5   # trade-off between uniform sampling (0) and sampling only high-priority experience (1)
        self.beta = 0.4    # importance sampling exponent, annealed from 0.4 -> 1.0 over the course of training
        self.beta_increment = (1 - self.beta) / (T_max - learn_start)
        self.abs_error_clipUpper = 1.0
        self.NORMALIZE_BY_BATCH = False  # OpenAI Baselines normalizes IS weights by the whole memory, not by batch

    def __len__(self):
        return len(self._storage)

    def push(self, state, action, next_state, reward):
        idx = self._next_idx

        # New transitions are stored with maximum priority; their real priority
        # is set once they have been sampled and a TD error is available.
        if next_state is not None:
            data = Transition(state.cpu(), action.cpu(), next_state.cpu(), reward.cpu())
        else:
            data = Transition(state.cpu(), action.cpu(), None, reward.cpu())

        if self._next_idx >= len(self._storage):
            self._storage.append(data)
        else:
            self._storage[self._next_idx] = data
        self._next_idx = (self._next_idx + 1) % self._maxsize

        self._sumTree[idx] = self._max_priority**self.alpha
        self._minTree[idx] = self._max_priority**self.alpha

    def sample(self, batch_size):
        indices = []
        batch_sample = []
        weights = []

        # Increase beta each time we sample a new mini-batch, until it reaches 1.0.
        self.beta = min(self.beta + self.beta_increment, 1.0)

        total_priority = self._sumTree.sum(0, len(self._storage) - 1)
        priority_segment = total_priority / batch_size

        min_priority = self._minTree.min() / self._sumTree.sum()
        max_weight_ALL_memory = (min_priority * len(self._storage))**(-self.beta)

        for i in range(batch_size):
            mass = (i + random.random()) * priority_segment
            index = self._sumTree.find_prefixsum_idx(mass)

            # P(j) --> stochastic priority
            stochastic_p = self._sumTree[index] / total_priority
            this_weight_IS = (stochastic_p * len(self._storage))**(-self.beta)
            # Importance sampling weight:
            #   w_j = (1/N * 1/P(j))^beta, normalized by the maximum weight in memory
            this_weight_IS /= max_weight_ALL_memory

            weights.append(this_weight_IS)
            batch_sample.append(self._storage[index])
            indices.append(index)

        return batch_sample, indices, weights

    def update_priority_on_tree(self, tree_idx, abs_TD_errors):
        assert len(tree_idx) == len(abs_TD_errors)
        abs_TD_errors = np.nan_to_num(abs_TD_errors) + self.e
        abs_TD_errors = abs_TD_errors.tolist()

        for index, priority in zip(tree_idx, abs_TD_errors):
            assert priority > 0
            assert 0 <= index < len(self._storage)
            self._sumTree[index] = priority**self.alpha
            self._minTree[index] = priority**self.alpha

            self._max_priority = max(self._max_priority, priority)
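
# Usage sketch (illustrative only): driving this buffer from a PyTorch-style DQN
# loop, assuming Transition is the usual (state, action, next_state, reward)
# namedtuple of tensors. `env_step` and `learn_on_batch` are hypothetical
# placeholders for the environment and learner hooks.
memory = PrioritizedReplayBuffer(size=100_000, T_max=1_000_000, learn_start=10_000)

for t in range(1_000_000):
    state, action, next_state, reward = env_step()        # assumed hook returning tensors
    memory.push(state, action, next_state, reward)        # stored with current max priority
    if t < 10_000 or len(memory) < 32:
        continue
    batch, indices, weights = memory.sample(32)            # beta is annealed internally
    abs_td_errors = learn_on_batch(batch, weights)         # assumed to return |TD errors| as a NumPy array
    memory.update_priority_on_tree(indices, abs_td_errors)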
class PartitionedMemory(Memory):
    def __init__(self, limit, pre_load_data, alpha=.4, start_beta=1., end_beta=1., steps_annealed=1, **kwargs):
        super(PartitionedMemory, self).__init__(**kwargs)

        #The capacity of the replay buffer
        self.limit = limit

        #Transitions are stored in individual PartitionedRingBuffers.
        self.actions = PartitionedRingBuffer(limit)
        self.rewards = PartitionedRingBuffer(limit)
        self.terminals = PartitionedRingBuffer(limit)
        self.observations = PartitionedRingBuffer(limit)
        self.exps = PartitionedRingBuffer(limit)

        assert alpha >= 0
        #how aggressively to sample based on TD error
        self.alpha = alpha
        #how aggressively to compensate for that sampling.
        self.start_beta = start_beta
        self.end_beta = end_beta
        self.steps_annealed = steps_annealed

        #SegmentTrees need a leaf count that is a power of 2
        tree_capacity = 1
        while tree_capacity < self.limit:
            tree_capacity *= 2

        #Create SegmentTrees with this capacity
        self.sum_tree = SumSegmentTree(tree_capacity)
        self.min_tree = MinSegmentTree(tree_capacity)
        self.max_priority = 1.

        #unpack the expert transitions (assumes order recorded by the rl.utils.record_demo_data() method)
        demo_obs, demo_acts, demo_rews, demo_ts, demo_exps = [], [], [], [], []
        self.pre_load_data = pre_load_data
        for demo in self.pre_load_data:
            demo_obs.append(demo[0])
            demo_acts.append(demo[1])
            demo_rews.append(demo[2])
            demo_ts.append(demo[3])
            demo_exps.append(1)

        #pre-load the demonstration data
        self.observations.load(demo_obs)
        self.actions.load(demo_acts)
        self.rewards.load(demo_rews)
        self.terminals.load(demo_ts)
        self.exps.load(demo_exps)

        self.permanent_idx = self.observations.permanent_idx
        assert self.permanent_idx == self.rewards.permanent_idx

        self.next_index = 0
        for idx in range(self.permanent_idx):
            self.sum_tree[idx] = (self.max_priority ** self.alpha)
            self.min_tree[idx] = (self.max_priority ** self.alpha)

    def append(self, observation, action, reward, terminal, expert_or_not, training=True):
        #super() call adds to the deques that hold the most recent info, which is fed to the agent
        #on agent.forward()
        super(PartitionedMemory, self).append(observation, action, reward, terminal, training=training)
        if training:
            self.observations.append(observation)
            self.actions.append(action)
            self.rewards.append(reward)
            self.terminals.append(terminal)
            self.exps.append(expert_or_not)
            #The priority of each new transition is set to the maximum
            self.sum_tree[self.next_index + self.permanent_idx] = self.max_priority ** self.alpha
            self.min_tree[self.next_index + self.permanent_idx] = self.max_priority ** self.alpha
            #shift tree pointer index to keep it in sync with RingBuffers
            self.next_index = ((self.next_index + 1) % (self.limit - self.permanent_idx))

    def sample_proportional(self, batch_size):
        """
        Outputs a list of idxs to sample, based on their priorities.

        This function is public in this memory (vs. private in Sequential and
        Prioritized), because DQfD needs to be able to sample the same idxs
        twice (single step and n-step).
        """
        idxs = list()
        for _ in range(batch_size):
            mass = random.random() * self.sum_tree.sum(0, self.limit - 1)
            idx = self.sum_tree.find_prefixsum_idx(mass)
            idxs.append(idx)
        return idxs

    def sample(self, step, batch_size, n_step, gamma):
        current_beta = self.calculate_beta(step)
        # Sample from the memory.
        idxs = self.sample_proportional(batch_size)
        experiences, ntse, isW, _idx = self._sample(idxs, batch_size, current_beta)
        experiences_n, ntse, isW, _idx = self._sample(idxs, batch_size, current_beta, n_step, gamma)
        return idxs, [[i.state0, i.action, i.reward, i.state1, i.terminal1, i.expert,
                       j.reward, j.state1, j.terminal1, k]
                      for i, j, k in zip(experiences, experiences_n, ntse)], isW

    def _sample(self, idxs, batch_size, beta=1., nstep=1, gamma=1):
        """
        Gathers transition data from the ring buffers. The PartitionedMemory
        separates generating the idxs and returning their transitions, allowing
        this method to be called multiple times with the same idxs.
        """
        #importance sampling weights are a stability measure
        importance_weights = list()

        #The lowest-priority experience defines the maximum importance sampling weight
        prob_min = self.min_tree.min() / self.sum_tree.sum()
        max_importance_weight = (prob_min * self.nb_entries) ** (-beta)

        obs_t, act_t, rews, obs_t1, dones, nste = [], [], [], [], [], []
        experiences = list()

        for idx in idxs:
            while idx < self.window_length + 1:
                idx += 1
            while idx + nstep > self.nb_entries and self.nb_entries < self.limit:
                # We are fine with nstep spilling back to the beginning of the buffer
                # once it has been filled.
                idx -= 1
            terminal0 = self.terminals[idx - 2]
            while terminal0:
                # Skip this transition because the environment was reset here. Select a new, random
                # transition and use this instead. This may cause the batch to contain the same
                # transition twice.
                idx = sample_batch_indexes(self.window_length + 1, self.nb_entries - nstep, size=1)[0]
                terminal0 = self.terminals[idx - 2]
            assert self.window_length + 1 <= idx < self.nb_entries

            #probability of sampling transition is the priority of the transition over the sum of all priorities
            prob_sample = self.sum_tree[idx] / self.sum_tree.sum()
            importance_weight = (prob_sample * self.nb_entries) ** (-beta)
            #normalize weights according to the maximum value
            importance_weights.append(importance_weight / max_importance_weight)

            #assemble the initial state from the ringbuffer.
            state0 = [self.observations[idx - 1]]
            for offset in range(0, self.window_length - 1):
                current_idx = idx - 2 - offset
                assert current_idx >= 1
                current_terminal = self.terminals[current_idx - 1]
                if current_terminal and not self.ignore_episode_boundaries:
                    # The previously handled observation was terminal, don't add the current one.
                    # Otherwise we would leak into a different episode.
                    break
                state0.insert(0, self.observations[current_idx])
            while len(state0) < self.window_length:
                state0.insert(0, zeroed_observation(state0[0]))

            action = self.actions[idx - 1]

            # N-step TD
            reward = 0
            for i in range(nstep):
                reward += (gamma**i) * self.rewards[idx + i - 1]
                if self.terminals[idx + i - 1]:
                    #episode terminated before length of n-step rollout.
                    nstep = i
                    break
            nste.append(nstep)

            terminal1 = self.terminals[idx + nstep - 1]
            expert1 = self.exps[idx + nstep - 1]

            # We assemble the second state in a similar way.
            state1 = [self.observations[idx + nstep - 1]]
            for offset in range(0, self.window_length - 1):
                current_idx = idx + nstep - 1 - offset
                assert current_idx >= 1
                current_terminal = self.terminals[current_idx - 1]
                if current_terminal and not self.ignore_episode_boundaries:
                    # The previously handled observation was terminal, don't add the current one.
                    # Otherwise we would leak into a different episode.
                    break
                state1.insert(0, self.observations[current_idx])
            while len(state1) < self.window_length:
                state1.insert(0, zeroed_observation(state0[0]))

            assert len(state0) == self.window_length
            assert len(state1) == len(state0)
            experiences.append(Experience(state0=state0, action=action, reward=reward,
                                          state1=state1, terminal1=terminal1, expert=expert1))
        assert len(experiences) == batch_size
        return list(experiences), nste, importance_weights, idxs

    def update_priorities(self, idxs, priorities):
        #adjust priorities based on new TD error
        for i, idx in enumerate(idxs):
            assert 0 <= idx < self.limit
            #expert transition priorities receive an extra boost
            if idx < self.permanent_idx:
                priority = (priorities[i] ** self.alpha) + .999
            else:
                priority = (priorities[i] ** self.alpha)
            self.sum_tree[idx] = priority
            self.min_tree[idx] = priority
            self.max_priority = max(self.max_priority, priority)

    def calculate_beta(self, current_step):
        a = float(self.end_beta - self.start_beta) / float(self.steps_annealed)
        b = float(self.start_beta)
        current_beta = min(self.end_beta, a * float(current_step) + b)
        return current_beta

    def get_config(self):
        config = super(PartitionedMemory, self).get_config()
        config['alpha'] = self.alpha
        config['start_beta'] = self.start_beta
        config['end_beta'] = self.end_beta
        config['beta_steps_annealed'] = self.steps_annealed
        config['pre_load_data'] = self.pre_load_data
        return config

    @property
    def nb_entries(self):
        """Return number of observations

        # Returns
            Number of observations
        """
        return len(self.observations)
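
# Usage sketch (illustrative only): the DQfD-style loop this memory supports,
# assuming `demos` follows the (observation, action, reward, terminal) layout
# produced by rl.utils.record_demo_data() and that `window_length` is accepted
# by the Memory base class. `act_and_observe`, `train_on_batch`, and the
# hyperparameter values below are hypothetical placeholders.
memory = PartitionedMemory(limit=500_000, pre_load_data=demos,
                           alpha=0.4, start_beta=0.4, end_beta=1.0,
                           steps_annealed=1_000_000, window_length=4)

# (DQfD typically pre-trains on the demonstration data before acting online.)
total_steps = 1_000_000
for step in range(total_steps):
    obs, action, reward, done = act_and_observe()               # assumed environment/agent hook
    memory.append(obs, action, reward, done, expert_or_not=0)   # online data, not expert
    idxs, batch, is_weights = memory.sample(step, batch_size=32, n_step=10, gamma=0.99)
    abs_td_errors = train_on_batch(batch, is_weights)           # assumed to return |TD errors|
    memory.update_priorities(idxs, abs_td_errors)               # expert idxs keep their priority bonus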