# Common imports assumed by the implementations collected below. The
# ReplayBuffer / ReplayMemory base classes and the SumSegmentTree /
# MinSegmentTree helpers come from each snippet's own codebase and are not
# defined here.
import random
from typing import Dict, List

import numpy as np
import torch


class PrioritizedReplayBuffer(ReplayBuffer):
    def __init__(self, max_size, alpha):
        """Create Prioritized Replay buffer.

        Parameters
        ----------
        max_size: int
            Max number of transitions to store in the buffer. When the buffer
            overflows the old memories are dropped.
        alpha: float
            how much prioritization is used
            (0: no prioritization, 1: full prioritization)

        See Also
        --------
        ReplayBuffer.__init__
        """
        super(PrioritizedReplayBuffer, self).__init__(max_size)
        assert alpha >= 0
        self._alpha = alpha

        # Segment-tree capacity must be a power of 2 at least as large as the
        # buffer.
        it_capacity = 1
        while it_capacity < max_size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0

    def add(self, state, action, next_state, reward, done):
        """See ReplayBuffer.add"""
        idx = self.ptr
        super().add(state, action, next_state, reward, done)
        # New transitions get the current maximum priority so they are
        # sampled at least once before their priority is updated.
        self._it_sum[idx] = self._max_priority ** self._alpha
        self._it_min[idx] = self._max_priority ** self._alpha

    def _sample_proportional(self, batch_size):
        res = []
        p_total = self._it_sum.sum(0, len(self._storage) - 1)
        every_range_len = p_total / batch_size
        for i in range(batch_size):
            # Stratified sampling: one uniform draw per equal slice of the
            # total priority mass.
            mass = random.random() * every_range_len + i * every_range_len
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)
        return res

    def sample(self, batch_size, beta):
        """Sample a batch of experiences.

        Compared to ReplayBuffer.sample it also returns importance weights
        and idxes of sampled experiences.

        Parameters
        ----------
        batch_size: int
            How many transitions to sample.
        beta: float
            To what degree to use importance weights
            (0: no corrections, 1: full correction)

        Returns
        -------
        obs_batch: np.array
            batch of observations
        act_batch: np.array
            batch of actions executed given obs_batch
        rew_batch: np.array
            rewards received as results of executing act_batch
        next_obs_batch: np.array
            next set of observations seen after executing act_batch
        done_mask: np.array
            done_mask[i] = 1 if executing act_batch[i] resulted in the end of
            an episode and 0 otherwise.
        weights: torch.Tensor
            Tensor of shape (batch_size,) and dtype torch.float32 denoting
            the importance weight of each sampled transition
        idxes: [int]
            Indexes in buffer of sampled experiences
        """
        assert beta > 0

        idxes = self._sample_proportional(batch_size)

        weights = []
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * len(self._storage)) ** (-beta)

        for idx in idxes:
            p_sample = self._it_sum[idx] / self._it_sum.sum()
            weight = (p_sample * len(self._storage)) ** (-beta)
            weights.append(weight / max_weight)
        weights = np.array(weights)

        encoded_sample = self._encode_sample(idxes)
        return tuple(
            list(encoded_sample)
            + [torch.as_tensor(weights, device=self.device, dtype=torch.float32),
               idxes])

    def update_priorities(self, idxes, priorities):
        """Update priorities of sampled transitions.

        Sets priority of transition at index idxes[i] in buffer
        to priorities[i].

        Parameters
        ----------
        idxes: [int]
            List of idxes of sampled transitions
        priorities: [float]
            List of updated priorities corresponding to transitions at the
            sampled idxes denoted by variable `idxes`.
        """
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            assert priority > 0
            assert 0 <= idx < len(self._storage)
            self._it_sum[idx] = priority ** self._alpha  # update priority
            self._it_min[idx] = priority ** self._alpha
            self._max_priority = max(self._max_priority, priority)
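
# Usage sketch (not part of the snippet above): how a training step would
# typically drive this buffer. `agent` and `compute_td_errors` are
# hypothetical placeholders; only `sample` and `update_priorities` come from
# the buffer API, and the unpacking order of the encoded batch depends on the
# base class's `_encode_sample`.
def hypothetical_training_step(buffer, agent, batch_size=32, beta=0.4, eps=1e-6):
    *batch, weights, idxes = buffer.sample(batch_size, beta)
    td_errors = compute_td_errors(agent, batch)   # hypothetical helper, returns a torch tensor
    loss = (weights * td_errors.pow(2)).mean()    # weights correct the sampling bias
    loss.backward()
    # New priorities are the absolute TD errors plus a small epsilon so no
    # transition ever reaches zero sampling probability.
    buffer.update_priorities(idxes, td_errors.abs().detach().cpu().numpy() + eps)
    return loss
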
class PrioritizedReplayBuffer(ReplayBuffer):
    def __init__(self, size, alpha):
        """Create Prioritized Replay buffer.

        Parameters
        ----------
        size: int
            Max number of transitions to store in the buffer. When the buffer
            overflows the old memories are dropped.
        alpha: float
            how much prioritization is used
            (0 - no prioritization, 1 - full prioritization)

        See Also
        --------
        ReplayBuffer.__init__
        """
        super(PrioritizedReplayBuffer, self).__init__(size)
        assert alpha >= 0
        self._alpha = alpha

        # Use double the soft capacity of the PER for the segment trees to
        # allow for any overflow over the soft capacity limit before samples
        # are removed.
        self.it_capacity = 1
        while self.it_capacity < size * 2:
            self.it_capacity *= 2

        self._it_sum = SumSegmentTree(self.it_capacity)
        self._it_min = MinSegmentTree(self.it_capacity)
        self._max_priority = 1.0

    def add(self, *args, **kwargs):
        idx = self._next_idx
        assert idx < self.it_capacity, (
            "Number of samples in replay memory exceeds capacity of segment "
            "trees. Please increase the capacity of the segment trees or the "
            "frequency at which samples are removed from the replay memory.")
        super().add(*args, **kwargs)
        self._it_sum[idx] = self._max_priority ** self._alpha
        self._it_min[idx] = self._max_priority ** self._alpha

    def remove(self, num_samples):
        super().remove(num_samples)
        self._it_sum.remove_items(num_samples)
        self._it_min.remove_items(num_samples)

    def _sample_proportional(self, batch_size):
        res = []
        p_total = self._it_sum.sum(0, len(self._storage) - 1)
        every_range_len = p_total / batch_size
        for i in range(batch_size):
            mass = random.random() * every_range_len + i * every_range_len
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)
        return res

    def sample(self, batch_size, beta):
        """Sample a batch of experiences.

        Compared to ReplayBuffer.sample it also returns importance weights
        and idxes of sampled experiences.

        Parameters
        ----------
        batch_size: int
            How many transitions to sample.
        beta: float
            To what degree to use importance weights
            (0 - no corrections, 1 - full correction)

        Returns
        -------
        obs_batch: np.array
            batch of observations
        act_batch: np.array
            batch of actions executed given obs_batch
        rew_batch: np.array
            rewards received as results of executing act_batch
        next_obs_batch: np.array
            next set of observations seen after executing act_batch
        done_mask: np.array
            done_mask[i] = 1 if executing act_batch[i] resulted in the end of
            an episode and 0 otherwise.
        gammas: np.array
            product of gammas for N-step returns
        weights: np.array
            Array of shape (batch_size,) and dtype np.float32
            denoting importance weight of each sampled transition
        idxes: np.array
            Array of shape (batch_size,) and dtype np.int32
            indexes in buffer of sampled experiences
        """
        assert beta > 0

        idxes = self._sample_proportional(batch_size)

        weights = []
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * len(self._storage)) ** (-beta)

        for idx in idxes:
            p_sample = self._it_sum[idx] / self._it_sum.sum()
            weight = (p_sample * len(self._storage)) ** (-beta)
            weights.append(weight / max_weight)
        weights = np.array(weights)

        encoded_sample = self._encode_sample(idxes)
        return tuple(list(encoded_sample) + [weights, idxes])

    def update_priorities(self, idxes, priorities):
        """Update priorities of sampled transitions.

        Sets priority of transition at index idxes[i] in buffer
        to priorities[i].

        Parameters
        ----------
        idxes: [int]
            List of idxes of sampled transitions
        priorities: [float]
            List of updated priorities corresponding to transitions at the
            sampled idxes denoted by variable `idxes`.
        """
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            assert priority > 0
            assert 0 <= idx < len(self._storage)
            self._it_sum[idx] = priority ** self._alpha
            self._it_min[idx] = priority ** self._alpha
            self._max_priority = max(self._max_priority, priority)
class PrioritizedReplayBuffer(ReplayBuffer): """Prioritized Replay buffer. Attributes: max_priority (float): max priority tree_ptr (int): next index of tree alpha (float): alpha parameter for prioritized replay buffer sum_tree (SumSegmentTree): sum tree for prior min_tree (MinSegmentTree): min tree for min prior to get max weight """ def __init__(self, obs_dim: int, size: int, batch_size: int, alpha: float = 0.6): """Initialization.""" assert alpha >= 0 super(PrioritizedReplayBuffer, self).__init__(obs_dim, size, batch_size) self.max_priority, self.tree_ptr = 1.0, 0 self.alpha = alpha # capacity must be positive and a power of 2. tree_capacity = 1 while tree_capacity < self.max_size: tree_capacity *= 2 self.sum_tree = SumSegmentTree(tree_capacity) self.min_tree = MinSegmentTree(tree_capacity) def store(self, obs: np.ndarray, act: int, rew: float, next_obs: np.ndarray, done: bool): """Store experience and priority.""" super().store(obs, act, rew, next_obs, done) self.sum_tree[self.tree_ptr] = self.max_priority**self.alpha self.min_tree[self.tree_ptr] = self.max_priority**self.alpha self.tree_ptr = (self.tree_ptr + 1) % self.max_size def sample_batch(self, beta: float = 0.4) -> Dict[str, np.ndarray]: """Sample a batch of experiences.""" assert len(self) >= self.batch_size assert beta > 0 indices = self._sample_proportional() obs = self.obs_buf[indices] next_obs = self.next_obs_buf[indices] acts = self.acts_buf[indices] rews = self.rews_buf[indices] done = self.done_buf[indices] weights = np.array([self._calculate_weight(i, beta) for i in indices]) return dict( obs=obs, next_obs=next_obs, acts=acts, rews=rews, done=done, weights=weights, indices=indices, ) def update_priorities(self, indices: List[int], priorities: np.ndarray): """Update priorities of sampled transitions.""" assert len(indices) == len(priorities) for idx, priority in zip(indices, priorities): assert priority > 0 assert 0 <= idx < len(self) self.sum_tree[idx] = priority**self.alpha self.min_tree[idx] = priority**self.alpha self.max_priority = max(self.max_priority, priority) def _sample_proportional(self) -> List[int]: """Sample indices based on proportions.""" indices = [] p_total = self.sum_tree.sum(0, len(self) - 1) segment = p_total / self.batch_size for i in range(self.batch_size): a = segment * i b = segment * (i + 1) upperbound = random.uniform(a, b) idx = self.sum_tree.retrieve(upperbound) indices.append(idx) return indices def _calculate_weight(self, idx: int, beta: float): """Calculate the weight of the experience at idx.""" # get max weight p_min = self.min_tree.min() / self.sum_tree.sum() max_weight = (p_min * len(self))**(-beta) # calculate weights p_sample = self.sum_tree[idx] / self.sum_tree.sum() weight = (p_sample * len(self))**(-beta) weight = weight / max_weight return weight
class PrioritizedReplayBufferTorch(ReplayBufferTorch):
    def __init__(self, size, alpha, device):
        super().__init__(size, device)
        assert alpha >= 0
        self._alpha = alpha

        it_capacity = 2
        while it_capacity < size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0

    def add(self, obs_t, action, reward, obs_tp1, done):
        idx = self._next_idx
        data = (obs_t, action, reward, obs_tp1, done)
        if self._next_idx >= len(self._storage):
            self._storage.append(data)
        else:
            self._storage[self._next_idx] = data
        self._next_idx = (self._next_idx + 1) % self._maxsize

        # New transitions enter with the maximum priority seen so far.
        self._it_sum[idx] = self._max_priority ** self._alpha
        self._it_min[idx] = self._max_priority ** self._alpha

    def _sample_proportional(self, batch_size):
        res = []
        p_total = self._it_sum.sum(0, len(self._storage) - 1)
        every_range_len = p_total / batch_size
        for i in range(batch_size):
            mass = random.random() * every_range_len + i * every_range_len
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)
        return res

    def sample(self, batch_size, beta):
        assert beta > 0

        idxes = self._sample_proportional(batch_size)

        weights = []
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * len(self._storage)) ** (-beta)

        for idx in idxes:
            p_sample = self._it_sum[idx] / self._it_sum.sum()
            weight = (p_sample * len(self._storage)) ** (-beta)
            weights.append(weight / max_weight)
        weights = torch.tensor(weights, dtype=torch.float32, device=self.device)

        encoded_sample = self._encode_sample(idxes)
        return tuple(list(encoded_sample) + [weights, idxes])

    def update_priorities(self, idxes, priorities):
        assert len(idxes) == len(priorities)
        # Validate the whole batch up front instead of inside the loop.
        assert all(0 <= x < len(self._storage) for x in idxes)
        assert (priorities > 0).all()
        self._max_priority = max(self._max_priority, max(priorities))
        for idx, priority in zip(idxes, priorities):
            self._it_sum[idx] = priority ** self._alpha
            self._it_min[idx] = priority ** self._alpha
class PrioritizedReplayBuffer(ReplayBuffer):
    def __init__(self, args, buffer_id):
        """Create Prioritized Replay buffer.

        The relevant settings are read from ``args``:

        args.size: int
            Max number of transitions to store in the buffer. When the buffer
            overflows the old memories are dropped.
        args.replay_alpha: float
            how much prioritization is used
            (0 - no prioritization, 1 - full prioritization)
        args.replay_beta: float
            To what degree to use importance weights
            (0 - no corrections, 1 - full correction)

        See Also
        --------
        ReplayBuffer.__init__
        """
        super(PrioritizedReplayBuffer, self).__init__(args, buffer_id)
        assert args.replay_alpha > 0
        self._alpha = args.replay_alpha
        self._beta = args.replay_beta

        it_capacity = 1
        while it_capacity < self.args.size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0

    def add(self, obs_t, action, reward, obs_tp1, done, weight):
        """See ReplayBuffer.store_effect"""
        idx = self._next_idx
        super(PrioritizedReplayBuffer, self).add(obs_t, action, reward,
                                                 obs_tp1, done, weight)
        if weight is None:
            weight = self._max_priority
        self._it_sum[idx] = weight ** self._alpha
        self._it_min[idx] = weight ** self._alpha

    def _sample_proportional(self, batch_size):
        res = []
        for _ in range(batch_size):
            mass = random.random() * self._it_sum.sum(0, len(self._storage))
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)
        return np.array(res, dtype=np.int32)

    def sample_idxes(self, batch_size):
        return self._sample_proportional(batch_size)

    def sample_with_weights_and_idxes(self, idxes):
        weights = []
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * len(self._storage)) ** (-self._beta)

        for idx in idxes:
            p_sample = self._it_sum[idx] / self._it_sum.sum()
            weight = (p_sample * len(self._storage)) ** (-self._beta)
            weights.append(weight / max_weight)
        weights = np.array(weights)

        encoded_sample = self._encode_sample(idxes)
        return list(encoded_sample) + [weights, idxes]

    def sample(self, batch_size):
        idxes = self.sample_idxes(batch_size)
        return self.sample_with_weights_and_idxes(idxes)

    def update_priorities(self, idxes, priorities):
        """Update priorities of sampled transitions.

        Sets priority of transition at index idxes[i] in buffer
        to priorities[i].

        Parameters
        ----------
        idxes: [int]
            List of idxes of sampled transitions
        priorities: [float]
            List of updated priorities corresponding to transitions at the
            sampled idxes denoted by variable `idxes`.
        """
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            assert priority > 0
            assert 0 <= idx < len(self._storage)
            delta = priority ** self._alpha - self._it_sum[idx]  # change in priority mass (currently unused)
            self._it_sum[idx] = priority ** self._alpha
            self._it_min[idx] = priority ** self._alpha
            self._max_priority = max(self._max_priority, priority)
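
# The variant above splits sampling into two phases (pick indices, then build
# the weighted batch). A minimal call sequence; `compute_td_errors` and the
# batch size are hypothetical, only the buffer methods come from the class:
def hypothetical_two_phase_sample(buffer, compute_td_errors, batch_size=64):
    idxes = buffer.sample_idxes(batch_size)
    *batch, weights, idxes = buffer.sample_with_weights_and_idxes(idxes)
    td_errors = compute_td_errors(batch, weights)   # hypothetical callback
    buffer.update_priorities(idxes, np.abs(td_errors) + 1e-6)
    return td_errors
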
class PrioritizedReplayBuffer(ReplayBuffer):
    def __init__(self, size, frame_history_len, alpha, num_branches,
                 non_pixel_dimension, add_non_pixel=False):
        """
        Parameters
        ----------
        alpha: float
            how much prioritization is used
            (0 - no prioritization, 1 - full prioritization)
        """
        super(PrioritizedReplayBuffer, self).__init__(
            size, frame_history_len, non_pixel_dimension, add_non_pixel)
        self.num_branches = num_branches
        assert alpha > 0
        self._alpha = alpha

        it_capacity = 1
        while it_capacity < size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0

    def sample(self, batch_size, beta):
        """Sample a batch of experiences.

        Compared to ReplayBuffer.sample it also returns importance weights
        and idxes of sampled experiences.

        Parameters
        ----------
        batch_size: int
            How many transitions to sample.
        beta: float
            To what degree to use importance weights
            (0 - no corrections, 1 - full correction)

        Returns
        -------
        obs_batch: np.array
            Array of shape
            (batch_size, img_c * frame_history_len, img_h, img_w)
            and dtype np.uint8
        act_batch: np.array
            Array of shape (batch_size,) and dtype np.int32
        rew_batch: np.array
            Array of shape (batch_size,) and dtype np.float32
        next_obs_batch: np.array
            Array of shape
            (batch_size, img_c * frame_history_len, img_h, img_w)
            and dtype np.uint8
        done_mask: np.array
            Array of shape (batch_size,) and dtype np.float32
        weights: np.array
            Array of shape (batch_size,) and dtype np.float32
            denoting importance weight of each sampled transition
        idxes: np.array
            Array of shape (batch_size,) and dtype np.int32
            indexes in buffer of sampled experiences
        """
        assert beta > 0

        idxes = self._sample_proportional(batch_size)

        weights = []
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * self.num_in_buffer) ** (-beta)

        for idx in idxes:
            p_sample = self._it_sum[idx] / self._it_sum.sum()
            weight = (p_sample * self.num_in_buffer) ** (-beta)
            weights.append(weight / max_weight)
        weights = np.array(weights)

        encoded_sample = self._encode_sample(idxes)
        return tuple(list(encoded_sample) + [weights, idxes])

    def store_frame(self, frame, non_pixel_feature):
        """Store a single frame in the buffer at the next available index,
        overwriting old frames if necessary.

        Parameters
        ----------
        frame: np.array
            Array of shape (img_h, img_w, img_c) and dtype np.uint8
            the frame to be stored

        Returns
        -------
        idx: int
            Index at which the frame is stored.
            To be used for `store_effect` later.
        """
        # If the observation is an image, move the channel axis to the front.
        if len(frame.shape) > 1:
            frame = frame.transpose(2, 0, 1)

        if self.obs is None:
            # Lazily allocate the storage arrays on the first stored frame.
            self.obs = np.empty([self.size] + list(frame.shape), dtype=np.uint8)
            self.action = np.empty([self.size, self.num_branches], dtype=np.int32)
            self.reward = np.empty([self.size], dtype=np.float32)
            self.done = np.empty([self.size], dtype=np.bool_)
            if self.add_non_pixel:
                self.non_pixel_obs = np.empty(
                    [self.size, self.non_pixel_dimension], dtype=np.float32)

        self.obs[self.next_idx] = frame
        if self.add_non_pixel:
            self.non_pixel_obs[self.next_idx] = non_pixel_feature

        ret = self.next_idx
        self.next_idx = (self.next_idx + 1) % self.size
        self.num_in_buffer = min(self.size, self.num_in_buffer + 1)

        return ret

    def store_effect(self, idx, action, reward, done):
        self.action[idx] = action
        self.reward[idx] = reward
        self.done[idx] = done
        # Newly stored transitions get the current maximum priority.
        self._it_sum[idx] = self._max_priority ** self._alpha
        self._it_min[idx] = self._max_priority ** self._alpha

    def _sample_proportional(self, batch_size):
        res = []
        for _ in range(batch_size):
            mass = random.random() * self._it_sum.sum(0, self.num_in_buffer - 1)
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)
        return res

    def update_priorities(self, idxes, priorities):
        """Update priorities of sampled transitions.

        Sets priority of transition at index idxes[i] in buffer
        to priorities[i].

        Parameters
        ----------
        idxes: [int]
            List of idxes of sampled transitions
        priorities: [float]
            List of updated priorities corresponding to transitions at the
            sampled idxes denoted by variable `idxes`.
        """
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            assert priority > 0
            assert 0 <= idx < self.num_in_buffer
            self._it_sum[idx] = priority ** self._alpha
            self._it_min[idx] = priority ** self._alpha
            self._max_priority = max(self._max_priority, priority)
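
# Usage sketch for the frame-based variant above: the environment loop stores
# the raw frame first, uses the returned index, and records the effect of the
# chosen action afterwards. `env`, `policy`, and `encode_recent_observation`
# (a helper the dqn_utils-style base class typically provides) are assumptions
# and are not defined in this snippet.
def hypothetical_frame_loop_step(buffer, env, policy, frame, non_pixel_feature=None):
    idx = buffer.store_frame(frame, non_pixel_feature)
    obs = buffer.encode_recent_observation()        # assumed base-class helper
    action = policy(obs)
    next_frame, reward, done, _ = env.step(action)
    buffer.store_effect(idx, action, reward, done)  # also seeds the max priority
    return next_frame, done
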
class PrioritizedReplayMemory(ReplayMemory):
    def __init__(self, alpha=0.6, capacity=100000, replace=False,
                 tuple_class=Transition):
        super().__init__(capacity, replace, tuple_class)
        assert alpha >= 0

        self.max_priority, self.tree_ptr = 1.0, 0
        self.alpha = alpha

        # Tree capacity has to be a positive power of 2; round the memory
        # capacity up to the next power of 2.
        m = np.ceil(np.log(self.capacity) / np.log(2))
        tree_capacity = np.power(2, m).astype(int)

        self.sum_tree = SumSegmentTree(tree_capacity)
        self.min_tree = MinSegmentTree(tree_capacity)

    def add(self, record):
        super().add(record)

        self.sum_tree[self.tree_ptr] = self.max_priority ** self.alpha
        self.min_tree[self.tree_ptr] = self.max_priority ** self.alpha
        self.tree_ptr = (self.tree_ptr + 1) % self.capacity

    def sample(self, batch_size, beta: float = 0.4) -> Dict[str, np.ndarray]:
        """Sample a batch of experiences."""
        assert len(self) >= batch_size
        assert beta > 0

        indices = self._sample_proportional(batch_size)
        weights = np.array([self._calculate_weight(i, beta) for i in indices])

        result = self._reformat(indices)
        result['indices'] = indices
        result['weights'] = weights
        return result

    def update_priorities(self, indices: List[int], priorities: np.ndarray):
        """Update priorities of sampled transitions."""
        assert len(indices) == len(priorities)

        for idx, priority in zip(indices, priorities):
            assert priority > 0
            assert 0 <= idx < len(self)

            self.sum_tree[idx] = priority ** self.alpha
            self.min_tree[idx] = priority ** self.alpha
            self.max_priority = max(self.max_priority, priority)

    def _sample_proportional(self, batch_size) -> List[int]:
        """Sample indices based on proportions."""
        indices = []
        p_total = self.sum_tree.sum(0, len(self) - 1)
        segment = p_total / batch_size

        for i in range(batch_size):
            a = segment * i
            b = segment * (i + 1)
            upperbound = random.uniform(a, b)
            idx = self.sum_tree.retrieve(upperbound)
            indices.append(idx)

        return indices

    def _calculate_weight(self, idx: int, beta: float):
        """Calculate the weight of the experience at idx."""
        # get max weight
        p_min = self.min_tree.min() / self.sum_tree.sum()
        max_weight = (p_min * len(self)) ** (-beta)

        # calculate weights
        p_sample = self.sum_tree[idx] / self.sum_tree.sum()
        weight = (p_sample * len(self)) ** (-beta)
        weight = weight / max_weight

        return weight
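
# None of the snippets above define the segment trees they rely on. The sketch
# below is a minimal, simplified reconstruction of the OpenAI Baselines-style
# API they assume (`sum`, `min`, `find_prefixsum_idx` / `retrieve`,
# `__setitem__`, `__getitem__`); it favors clarity over speed (the range
# reduction is O(capacity) instead of O(log capacity)) and omits the
# `remove_items` method used by one of the variants.
import operator


class SimpleSegmentTree:
    def __init__(self, capacity, operation, neutral_element):
        assert capacity > 0 and capacity & (capacity - 1) == 0, \
            "capacity must be a positive power of 2"
        self._capacity = capacity
        self._operation = operation
        self._neutral = neutral_element
        # Node 1 is the root; leaves live at indices capacity .. 2*capacity-1.
        self._value = [neutral_element for _ in range(2 * capacity)]

    def __setitem__(self, idx, val):
        # Write the leaf, then recompute every ancestor up to the root.
        idx += self._capacity
        self._value[idx] = val
        idx //= 2
        while idx >= 1:
            self._value[idx] = self._operation(
                self._value[2 * idx], self._value[2 * idx + 1])
            idx //= 2

    def __getitem__(self, idx):
        return self._value[self._capacity + idx]

    def reduce(self, start=0, end=None):
        # Reduce over the half-open leaf range [start, end).
        if end is None:
            end = self._capacity
        result = self._neutral
        for i in range(start, end):
            result = self._operation(result, self._value[self._capacity + i])
        return result


class SimpleSumSegmentTree(SimpleSegmentTree):
    def __init__(self, capacity):
        super().__init__(capacity, operator.add, 0.0)

    def sum(self, start=0, end=None):
        return self.reduce(start, end)

    def find_prefixsum_idx(self, prefixsum):
        # Descend from the root, always walking toward the child whose subtree
        # still contains the requested prefix mass.
        idx = 1
        while idx < self._capacity:
            if self._value[2 * idx] > prefixsum:
                idx = 2 * idx
            else:
                prefixsum -= self._value[2 * idx]
                idx = 2 * idx + 1
        return idx - self._capacity

    retrieve = find_prefixsum_idx  # name used by the dict-based variants


class SimpleMinSegmentTree(SimpleSegmentTree):
    def __init__(self, capacity):
        super().__init__(capacity, min, float("inf"))

    def min(self, start=0, end=None):
        return self.reduce(start, end)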