import random

import numpy as np

# Assumes the OpenAI-baselines-style helpers, e.g.
# from baselines.common.segment_tree import SumSegmentTree, MinSegmentTree
# and a ReplayBuffer base class providing add(), __len__, _storage,
# _next_idx and _encode_sample().


class PrioritizedReplayBuffer(ReplayBuffer):
    def __init__(self, size, alpha):
        super(PrioritizedReplayBuffer, self).__init__(size)
        assert alpha >= 0
        self._alpha = alpha

        # Segment-tree capacity must be a power of two.
        it_capacity = 1
        while it_capacity < size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0

    def add(self, *args, **kwargs):
        """See ReplayBuffer.store_effect"""
        idx = self._next_idx
        super().add(*args, **kwargs)
        # New transitions enter with maximal priority so each is replayed at least once.
        self._it_sum[idx] = self._max_priority ** self._alpha
        self._it_min[idx] = self._max_priority ** self._alpha

    def _sample_proportional(self, batch_size):
        res = []
        p_total = self._it_sum.sum(0, len(self._storage) - 1)
        every_range_len = p_total / batch_size
        for i in range(batch_size):
            # Stratified sampling: one draw per equal slice of the total mass.
            mass = random.random() * every_range_len + i * every_range_len
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)
        return res

    def sample(self, batch_size, beta):
        assert beta > 0

        idxes = self._sample_proportional(batch_size)

        weights = []
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * len(self._storage)) ** (-beta)

        for idx in idxes:
            p_sample = self._it_sum[idx] / self._it_sum.sum()
            weight = (p_sample * len(self._storage)) ** (-beta)
            weights.append(weight / max_weight)
        # Importance weights, shape (batch_size,).
        weights = np.array(weights, dtype=np.float32)
        encoded_sample = self._encode_sample(idxes)
        return tuple(list(encoded_sample) + [weights, idxes])

    def update_priorities(self, idxes, priorities):
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            assert priority > 0
            assert 0 <= idx < len(self._storage)
            self._it_sum[idx] = priority ** self._alpha
            self._it_min[idx] = priority ** self._alpha

            self._max_priority = max(self._max_priority, priority)
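# Hypothetical usage sketch of the buffer above: a minimal PER training loop
# with dummy transitions standing in for a real environment and random numbers
# standing in for the learner's TD errors.
buffer = PrioritizedReplayBuffer(size=1000, alpha=0.6)
batch_size, total_steps = 32, 500

obs = np.zeros(4, dtype=np.float32)
for step in range(total_steps):
    next_obs = np.random.randn(4).astype(np.float32)
    buffer.add(obs, 0, 0.0, next_obs, 0.0)
    obs = next_obs

    if len(buffer) >= batch_size:
        # Anneal beta from 0.4 toward 1.0 so the importance-sampling
        # correction becomes exact by the end of training (Schaul et al., 2015).
        beta = min(1.0, 0.4 + step * (1.0 - 0.4) / total_steps)
        *batch, weights, idxes = buffer.sample(batch_size, beta)

        # Stand-in for TD errors; a small epsilon keeps priorities positive.
        td_errors = np.random.randn(batch_size)
        buffer.update_priorities(idxes, np.abs(td_errors) + 1e-6)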
def test_tree_set_overlap():
    tree = SumSegmentTree(4)

    tree[2] = 1.0
    tree[2] = 3.0

    assert np.isclose(tree.sum(), 3.0)
    assert np.isclose(tree.sum(2, 3), 3.0)
    assert np.isclose(tree.sum(2, -1), 3.0)
    assert np.isclose(tree.sum(2, 4), 3.0)
    assert np.isclose(tree.sum(1, 2), 0.0)
def test_tree_set():
    tree = SumSegmentTree(4)

    tree[2] = 1.0
    tree[3] = 3.0

    assert np.isclose(tree.sum(), 4.0)
    assert np.isclose(tree.sum(0, 2), 0.0)
    assert np.isclose(tree.sum(0, 3), 1.0)
    assert np.isclose(tree.sum(2, 3), 1.0)
    assert np.isclose(tree.sum(2, -1), 1.0)
    assert np.isclose(tree.sum(2, 4), 4.0)
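# A hypothetical companion test (not from the original suite), assuming the
# baselines-style find_prefixsum_idx(prefixsum), which returns the highest
# index i such that arr[0] + ... + arr[i - 1] <= prefixsum. This is the
# operation _sample_proportional relies on.
def test_prefixsum_idx():
    tree = SumSegmentTree(4)

    tree[2] = 1.0
    tree[3] = 3.0

    # Leaves are [0, 0, 1, 3]: mass below 1.0 lands on index 2,
    # mass in [1.0, 4.0) lands on index 3.
    assert tree.find_prefixsum_idx(0.0) == 2
    assert tree.find_prefixsum_idx(0.99) == 2
    assert tree.find_prefixsum_idx(1.01) == 3
    assert tree.find_prefixsum_idx(3.99) == 3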
class PrioritizedReplayBuffer(ReplayBuffer):
    def __init__(self, size, alpha):
        """Create Prioritized Replay buffer.

        Parameters
        ----------
        size: int
            Max number of transitions to store in the buffer. When the buffer
            overflows the old memories are dropped.
        alpha: float
            how much prioritization is used
            (0 - no prioritization, 1 - full prioritization)

        See Also
        --------
        ReplayBuffer.__init__
        """
        super(PrioritizedReplayBuffer, self).__init__(size)
        assert alpha >= 0
        self._alpha = alpha

        it_capacity = 1
        while it_capacity < size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0

    def add(self, *args, **kwargs):
        """See ReplayBuffer.store_effect"""
        idx = self._next_idx
        super().add(*args, **kwargs)
        self._it_sum[idx] = self._max_priority ** self._alpha
        self._it_min[idx] = self._max_priority ** self._alpha

    def _sample_proportional(self, batch_size):
        res = []
        p_total = self._it_sum.sum(0, len(self._storage) - 1)
        every_range_len = p_total / batch_size
        for i in range(batch_size):
            mass = random.random() * every_range_len + i * every_range_len
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)
        return res

    def sample(self, batch_size, beta):
        """Sample a batch of experiences.

        Compared to ReplayBuffer.sample it also returns importance weights
        and idxes of sampled experiences.

        Parameters
        ----------
        batch_size: int
            How many transitions to sample.
        beta: float
            To what degree to use importance weights
            (0 - no corrections, 1 - full correction)

        Returns
        -------
        obs_batch: np.array
            batch of observations
        act_batch: np.array
            batch of actions executed given obs_batch
        rew_batch: np.array
            rewards received as results of executing act_batch
        next_obs_batch: np.array
            next set of observations seen after executing act_batch
        done_mask: np.array
            done_mask[i] = 1 if executing act_batch[i] resulted in the end of
            an episode and 0 otherwise.
        weights: np.array
            Array of shape (batch_size,) and dtype np.float32 denoting the
            importance weight of each sampled transition
        idxes: np.array
            Array of shape (batch_size,) and dtype np.int32 of indexes in the
            buffer of sampled experiences
        """
        assert beta > 0

        idxes = self._sample_proportional(batch_size)

        weights = []
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * len(self._storage)) ** (-beta)

        for idx in idxes:
            p_sample = self._it_sum[idx] / self._it_sum.sum()
            weight = (p_sample * len(self._storage)) ** (-beta)
            weights.append(weight / max_weight)
        weights = np.array(weights, dtype=np.float32)
        encoded_sample = self._encode_sample(idxes)
        return tuple(list(encoded_sample) + [weights, idxes])

    def update_priorities(self, idxes, priorities):
        """Update priorities of sampled transitions.

        Sets priority of transition at index idxes[i] in buffer
        to priorities[i].

        Parameters
        ----------
        idxes: [int]
            List of idxes of sampled transitions
        priorities: [float]
            List of updated priorities corresponding to transitions at the
            sampled idxes denoted by variable `idxes`.
        """
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            assert priority > 0
            assert 0 <= idx < len(self._storage)
            self._it_sum[idx] = priority ** self._alpha
            self._it_min[idx] = priority ** self._alpha

            self._max_priority = max(self._max_priority, priority)
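# A standalone arithmetic check (illustrative only) of the importance weights
# computed in sample(): with sampling probabilities P(i) = p_i**alpha / sum_k p_k**alpha,
# each weight is w_i = (N * P(i))**(-beta), normalized by the maximum weight.
alpha, beta = 0.6, 0.4
p = np.array([1.0, 2.0, 3.0, 4.0]) ** alpha
probs = p / p.sum()                    # sampling probabilities P(i)
weights = (len(p) * probs) ** (-beta)  # w_i = (N * P(i))**(-beta)
weights /= weights.max()               # same normalization as max_weight above
print(weights)  # the rarest (lowest-priority) transition gets weight 1.0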
class PrioritizedReplayBuffer(ReplayBuffer):
    def __init__(self, size, alpha):
        """
        Create Prioritized Replay buffer.

        See Also ReplayBuffer.__init__

        :param size: (int) Max number of transitions to store in the buffer. When the buffer overflows the old
            memories are dropped.
        :param alpha: (float) how much prioritization is used (0 - no prioritization, 1 - full prioritization)
        """
        super(PrioritizedReplayBuffer, self).__init__(size)
        assert alpha >= 0
        self._alpha = alpha

        it_capacity = 1
        while it_capacity < size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0

    def add(self, obs_t, action, reward, obs_tp1, done):
        """
        add a new transition to the buffer

        :param obs_t: (Any) the last observation
        :param action: ([float]) the action
        :param reward: (float) the reward of the transition
        :param obs_tp1: (Any) the current observation
        :param done: (bool) is the episode done
        """
        idx = self._next_idx
        super().add(obs_t, action, reward, obs_tp1, done)
        self._it_sum[idx] = self._max_priority ** self._alpha
        self._it_min[idx] = self._max_priority ** self._alpha

    def _sample_proportional(self, batch_size):
        res = []
        for _ in range(batch_size):
            # TODO(szymon): should we ensure no repeats?
            mass = random.random() * self._it_sum.sum(0, len(self._storage) - 1)
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)
        return res

    def sample(self, batch_size, beta=0.4):
        """
        Sample a batch of experiences.

        Compared to ReplayBuffer.sample it also returns importance weights and idxes
        of sampled experiences.

        :param batch_size: (int) How many transitions to sample.
        :param beta: (float) To what degree to use importance weights (0 - no corrections, 1 - full correction);
            must be positive
        :return:
            - obs_batch: (np.ndarray) batch of observations
            - act_batch: (numpy float) batch of actions executed given obs_batch
            - rew_batch: (numpy float) rewards received as results of executing act_batch
            - next_obs_batch: (np.ndarray) next set of observations seen after executing act_batch
            - done_mask: (numpy bool) done_mask[i] = 1 if executing act_batch[i] resulted in the end of an episode
                and 0 otherwise.
            - weights: (numpy float) Array of shape (batch_size,) and dtype np.float32 denoting importance weight of
                each sampled transition
            - idxes: (numpy int) Array of shape (batch_size,) and dtype np.int32 of indexes in buffer of sampled
                experiences
        """
        assert beta > 0

        idxes = self._sample_proportional(batch_size)

        weights = []
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * len(self._storage)) ** (-beta)

        for idx in idxes:
            p_sample = self._it_sum[idx] / self._it_sum.sum()
            weight = (p_sample * len(self._storage)) ** (-beta)
            weights.append(weight / max_weight)
        weights = np.array(weights, dtype=np.float32)
        encoded_sample = self._encode_sample(idxes)
        return tuple(list(encoded_sample) + [weights, idxes])

    def update_priorities(self, idxes, priorities):
        """
        Update priorities of sampled transitions.

        Sets priority of transition at index idxes[i] in buffer
        to priorities[i].

        :param idxes: ([int]) List of idxes of sampled transitions
        :param priorities: ([float]) List of updated priorities corresponding to transitions at the sampled idxes
            denoted by variable `idxes`.
        """
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            assert priority > 0
            assert 0 <= idx < len(self._storage)
            self._it_sum[idx] = priority ** self._alpha
            self._it_min[idx] = priority ** self._alpha

            self._max_priority = max(self._max_priority, priority)
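# One hypothetical answer to the TODO above (ensuring no repeated indices):
# rejection-sample until the batch is duplicate-free. Written as a standalone
# helper rather than a method of the original class; note it mildly reshapes
# the sampling distribution when batch_size is large relative to the buffer.
def sample_proportional_unique(buffer, batch_size):
    seen = set()
    while len(seen) < batch_size:
        mass = random.random() * buffer._it_sum.sum(0, len(buffer._storage) - 1)
        seen.add(buffer._it_sum.find_prefixsum_idx(mass))
    return list(seen)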
class PrioritizedReplayBuffer(ReplayBuffer):
    def __init__(self, capacity, alpha):
        """Create Prioritized Replay buffer.

        Parameters
        ----------
        capacity: int
            Max number of transitions to store in the buffer. When the buffer
            overflows the old memories are dropped.
        alpha: float
            how much prioritization is used
            (0 - no prioritization, 1 - full prioritization)

        See Also
        --------
        ReplayBuffer.__init__
        """
        super(PrioritizedReplayBuffer, self).__init__(capacity)
        assert alpha >= 0
        self._alpha = alpha

        it_capacity = 1
        while it_capacity < capacity:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0

    def push(self, *args, **kwargs):
        """Store a transition with maximal priority. See ReplayBuffer.push."""
        idx = self._position
        super().push(*args, **kwargs)
        self._it_sum[idx] = self._max_priority ** self._alpha
        self._it_min[idx] = self._max_priority ** self._alpha

    def _sample_proportional(self, batch_size):
        res = []
        p_total = self._it_sum.sum(0, len(self) - 1)
        every_range_len = p_total / batch_size
        for i in range(batch_size):
            mass = random.random() * every_range_len + i * every_range_len
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)
        return res

    def sample(self, batch_size, beta):
        """Sample a batch of experiences.

        Compared to ReplayBuffer.sample it also returns importance weights
        and idxes of sampled experiences.

        Parameters
        ----------
        batch_size: int
            How many transitions to sample.
        beta: float
            To what degree to use importance weights
            (0 - no corrections, 1 - full correction)

        Returns
        -------
        transitions: [Transition]
            batch of transitions
        weights: np.array
            Array of shape (batch_size,) and dtype np.float32 denoting the
            importance weight of each sampled transition
        idxes: np.array
            Array of shape (batch_size,) and dtype np.int32 of indexes in the
            buffer of sampled experiences
        """
        assert beta > 0

        idxes = self._sample_proportional(batch_size)

        weights = []
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * len(self)) ** (-beta)

        for idx in idxes:
            p_sample = self._it_sum[idx] / self._it_sum.sum()
            weight = (p_sample * len(self)) ** (-beta)
            weights.append(weight / max_weight)
        weights = np.array(weights, dtype=np.float32)
        transitions = self._retrieve_sample(idxes)
        return transitions, (weights, idxes)

    def update_priorities(self, idxes, priorities):
        """Update priorities of sampled transitions.

        Sets priority of transition at index idxes[i] in buffer
        to priorities[i].

        Parameters
        ----------
        idxes: [int]
            List of idxes of sampled transitions
        priorities: [float]
            List of updated priorities corresponding to transitions at the
            sampled idxes denoted by variable `idxes`.
        """
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            assert priority > 0
            assert 0 <= idx < len(self)
            self._it_sum[idx] = priority ** self._alpha
            self._it_min[idx] = priority ** self._alpha

            self._max_priority = max(self._max_priority, priority)
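# Hypothetical usage of the Transition-style variant above; the base
# ReplayBuffer.push signature (here taken as state, action, next_state,
# reward) and the Transition type are assumed, with dummy arrays standing
# in for real environment data.
buffer = PrioritizedReplayBuffer(capacity=128, alpha=0.6)
for _ in range(64):
    buffer.push(np.zeros(4), 0, np.zeros(4), 0.0)

transitions, (weights, idxes) = buffer.sample(batch_size=32, beta=0.4)
td_errors = np.random.randn(32)
buffer.update_priorities(idxes, np.abs(td_errors) + 1e-6)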
import random
from typing import Dict, List

import numpy as np


class PrioritizedReplayBuffer(ReplayBuffer):
    """Prioritized Replay buffer.

    Attributes:
        max_priority (float): max priority
        tree_ptr (int): next index of tree
        alpha (float): alpha parameter for prioritized replay buffer
        sum_tree (SumSegmentTree): sum tree for priorities
        min_tree (MinSegmentTree): min tree for the min priority, used to get the max weight
    """

    def __init__(
        self,
        obs_dim: int,
        size: int,
        batch_size: int = 32,
        alpha: float = 0.6,
    ):
        """Initialization."""
        assert alpha >= 0

        super(PrioritizedReplayBuffer, self).__init__(obs_dim, size, batch_size)
        self.max_priority, self.tree_ptr = 1.0, 0
        self.alpha = alpha

        # capacity must be positive and a power of 2.
        tree_capacity = 1
        while tree_capacity < self.max_size:
            tree_capacity *= 2

        self.sum_tree = SumSegmentTree(tree_capacity)
        self.min_tree = MinSegmentTree(tree_capacity)

    def store(
        self,
        obs: np.ndarray,
        act: int,
        rew: float,
        next_obs: np.ndarray,
        done: bool,
    ):
        """Store experience and priority."""
        super().store(obs, act, rew, next_obs, done)

        self.sum_tree[self.tree_ptr] = self.max_priority ** self.alpha
        self.min_tree[self.tree_ptr] = self.max_priority ** self.alpha
        self.tree_ptr = (self.tree_ptr + 1) % self.max_size

    def sample_batch(self, beta: float = 0.4) -> Dict[str, np.ndarray]:
        """Sample a batch of experiences."""
        assert len(self) >= self.batch_size
        assert beta > 0

        indices = self._sample_proportional()

        obs = self.obs_buf[indices]
        next_obs = self.next_obs_buf[indices]
        acts = self.acts_buf[indices]
        rews = self.rews_buf[indices]
        done = self.done_buf[indices]
        weights = np.array([self._calculate_weight(i, beta) for i in indices])

        return dict(
            obs=obs,
            next_obs=next_obs,
            acts=acts,
            rews=rews,
            done=done,
            weights=weights,
            indices=indices,
        )

    def update_priorities(self, indices: List[int], priorities: np.ndarray):
        """Update priorities of sampled transitions."""
        assert len(indices) == len(priorities)

        for idx, priority in zip(indices, priorities):
            assert priority > 0
            assert 0 <= idx < len(self)

            self.sum_tree[idx] = priority ** self.alpha
            self.min_tree[idx] = priority ** self.alpha

            self.max_priority = max(self.max_priority, priority)

    def _sample_proportional(self) -> List[int]:
        """Sample indices based on proportions."""
        indices = []
        p_total = self.sum_tree.sum(0, len(self) - 1)
        segment = p_total / self.batch_size

        for i in range(self.batch_size):
            a = segment * i
            b = segment * (i + 1)
            upperbound = random.uniform(a, b)
            idx = self.sum_tree.retrieve(upperbound)
            indices.append(idx)

        return indices

    def _calculate_weight(self, idx: int, beta: float):
        """Calculate the weight of the experience at idx."""
        # get max weight
        p_min = self.min_tree.min() / self.sum_tree.sum()
        max_weight = (p_min * len(self)) ** (-beta)

        # calculate weights
        p_sample = self.sum_tree[idx] / self.sum_tree.sum()
        weight = (p_sample * len(self)) ** (-beta)
        weight = weight / max_weight

        return weight
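# A minimal usage sketch for the typed variant above; obs_dim, size and
# batch_size are arbitrary example values, and random numbers stand in for
# real observations and TD errors.
buffer = PrioritizedReplayBuffer(obs_dim=4, size=256, batch_size=32, alpha=0.6)
for _ in range(64):
    obs = np.random.randn(4).astype(np.float32)
    next_obs = np.random.randn(4).astype(np.float32)
    buffer.store(obs, 0, 0.0, next_obs, False)

batch = buffer.sample_batch(beta=0.4)
# The per-sample loss is usually scaled by batch["weights"] before reduction.
buffer.update_priorities(batch["indices"], np.abs(np.random.randn(32)) + 1e-6)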
from typing import Optional

import numpy as np

# Assumes stable-baselines-style imports, e.g.
# from stable_baselines.common.vec_env import VecNormalize


class PrioritizedReplayBuffer(ReplayBuffer):
    def __init__(self, size, alpha):
        """
        Create Prioritized Replay buffer.

        See Also ReplayBuffer.__init__

        :param size: (int) Max number of transitions to store in the buffer. When the buffer overflows the old
            memories are dropped.
        :param alpha: (float) how much prioritization is used (0 - no prioritization, 1 - full prioritization)
        """
        super(PrioritizedReplayBuffer, self).__init__(size)
        assert alpha >= 0
        self._alpha = alpha

        it_capacity = 1
        while it_capacity < size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0

    def add(self, obs_t, action, reward, obs_tp1, done):
        """
        add a new transition to the buffer

        :param obs_t: (Any) the last observation
        :param action: ([float]) the action
        :param reward: (float) the reward of the transition
        :param obs_tp1: (Any) the current observation
        :param done: (bool) is the episode done
        """
        idx = self._next_idx
        super().add(obs_t, action, reward, obs_tp1, done)
        self._it_sum[idx] = self._max_priority ** self._alpha
        self._it_min[idx] = self._max_priority ** self._alpha

    def extend(self, obs_t, action, reward, obs_tp1, done):
        """
        add a new batch of transitions to the buffer

        :param obs_t: (Union[Tuple[Union[np.ndarray, int]], np.ndarray]) the last batch of observations
        :param action: (Union[Tuple[Union[np.ndarray, int]], np.ndarray]) the batch of actions
        :param reward: (Union[Tuple[float], np.ndarray]) the batch of the rewards of the transition
        :param obs_tp1: (Union[Tuple[Union[np.ndarray, int]], np.ndarray]) the current batch of observations
        :param done: (Union[Tuple[bool], np.ndarray]) terminal status of the batch

        Note: uses the same names as .add to keep compatibility with named argument passing,
        but expects iterables and arrays with more than 1 dimension
        """
        idx = self._next_idx
        super().extend(obs_t, action, reward, obs_tp1, done)
        while idx != self._next_idx:
            self._it_sum[idx] = self._max_priority ** self._alpha
            self._it_min[idx] = self._max_priority ** self._alpha
            idx = (idx + 1) % self._maxsize

    def _sample_proportional(self, batch_size):
        total = self._it_sum.sum(0, len(self._storage) - 1)
        # TODO(szymon): should we ensure no repeats?
        mass = np.random.random(size=batch_size) * total
        idx = self._it_sum.find_prefixsum_idx(mass)
        return idx

    def sample(self, batch_size: int, beta: float = 0.4, env: Optional[VecNormalize] = None):
        """
        Sample a batch of experiences.

        Compared to ReplayBuffer.sample it also returns importance weights and idxes
        of sampled experiences.

        :param batch_size: (int) How many transitions to sample.
        :param beta: (float) To what degree to use importance weights (0 - no corrections, 1 - full correction);
            must be positive
        :param env: (Optional[VecNormalize]) associated gym VecEnv to normalize the observations/rewards
            when sampling
        :return:
            - obs_batch: (np.ndarray) batch of observations
            - act_batch: (numpy float) batch of actions executed given obs_batch
            - rew_batch: (numpy float) rewards received as results of executing act_batch
            - next_obs_batch: (np.ndarray) next set of observations seen after executing act_batch
            - done_mask: (numpy bool) done_mask[i] = 1 if executing act_batch[i] resulted in the end of an episode
                and 0 otherwise.
            - weights: (numpy float) Array of shape (batch_size,) and dtype np.float32 denoting importance weight of
                each sampled transition
            - idxes: (numpy int) Array of shape (batch_size,) and dtype np.int32 of indexes in buffer of sampled
                experiences
        """
        assert beta > 0
        idxes = self._sample_proportional(batch_size)
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * len(self._storage)) ** (-beta)
        # Vectorized weight computation over the whole batch of indices.
        p_sample = self._it_sum[idxes] / self._it_sum.sum()
        weights = ((p_sample * len(self._storage)) ** (-beta) / max_weight).astype(np.float32)
        encoded_sample = self._encode_sample(idxes, env=env)
        return tuple(list(encoded_sample) + [weights, idxes])

    def update_priorities(self, idxes, priorities):
        """
        Update priorities of sampled transitions.

        Sets priority of transition at index idxes[i] in buffer
        to priorities[i].

        :param idxes: ([int]) List of idxes of sampled transitions
        :param priorities: ([float]) List of updated priorities corresponding to transitions at the sampled idxes
            denoted by variable `idxes`.
        """
        assert len(idxes) == len(priorities)
        assert np.min(priorities) > 0
        assert np.min(idxes) >= 0
        assert np.max(idxes) < len(self._storage)
        self._it_sum[idxes] = priorities ** self._alpha
        self._it_min[idxes] = priorities ** self._alpha
        self._max_priority = max(self._max_priority, np.max(priorities))
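# Hypothetical usage of the vectorized variant above. Unlike the scalar
# versions, idxes and priorities are whole numpy arrays: the segment trees
# are indexed with arrays at once rather than one element at a time.
buffer = PrioritizedReplayBuffer(size=128, alpha=0.6)
for _ in range(64):
    buffer.add(np.zeros(4), 0, 0.0, np.zeros(4), False)

*batch, weights, idxes = buffer.sample(batch_size=32, beta=0.4)
buffer.update_priorities(idxes, np.abs(np.random.randn(32)) + 1e-6)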