import random
from typing import Optional

import numpy as np

# assumed imports: ReplayBuffer, the segment trees, and VecNormalize are taken from
# stable-baselines (stable_baselines.common.buffers, .segment_tree and .vec_env)
from stable_baselines.common.buffers import ReplayBuffer
from stable_baselines.common.segment_tree import SumSegmentTree, MinSegmentTree
from stable_baselines.common.vec_env import VecNormalize


class PrioritizedReplayBuffer(ReplayBuffer):
    def __init__(self, size, alpha):
        """
        Create Prioritized Replay buffer.

        See Also ReplayBuffer.__init__

        :param size: (int) Max number of transitions to store in the buffer. When the buffer overflows the old
            memories are dropped.
        :param alpha: (float) how much prioritization is used (0 - no prioritization, 1 - full prioritization)
        """
        super(PrioritizedReplayBuffer, self).__init__(size)
        assert alpha >= 0
        self._alpha = alpha

        it_capacity = 1
        while it_capacity < size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0

    def add(self, obs_t, action, reward, obs_tp1, done):
        """
        add a new transition to the buffer

        :param obs_t: (Any) the last observation
        :param action: ([float]) the action
        :param reward: (float) the reward of the transition
        :param obs_tp1: (Any) the current observation
        :param done: (bool) is the episode done
        """
        idx = self._next_idx
        super().add(obs_t, action, reward, obs_tp1, done)
        self._it_sum[idx] = self._max_priority ** self._alpha
        self._it_min[idx] = self._max_priority ** self._alpha

    def _sample_proportional(self, batch_size):
        res = []
        for _ in range(batch_size):
            # TODO(szymon): should we ensure no repeats?
            mass = random.random() * self._it_sum.sum(0, len(self._storage) - 1)
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)
        return res

    def sample(self, batch_size, beta=0):
        """
        Sample a batch of experiences.

        compared to ReplayBuffer.sample
        it also returns importance weights and idxes
        of sampled experiences.

        :param batch_size: (int) How many transitions to sample.
        :param beta: (float) To what degree to use importance weights (0 - no corrections, 1 - full correction)
        :return:
            - obs_batch: (np.ndarray) batch of observations
            - act_batch: (numpy float) batch of actions executed given obs_batch
            - rew_batch: (numpy float) rewards received as results of executing act_batch
            - next_obs_batch: (np.ndarray) next set of observations seen after executing act_batch
            - done_mask: (numpy bool) done_mask[i] = 1 if executing act_batch[i] resulted in the end of an episode
                and 0 otherwise.
            - weights: (numpy float) Array of shape (batch_size,) and dtype np.float32 denoting importance weight
                of each sampled transition
            - idxes: (numpy int) Array of shape (batch_size,) and dtype np.int32 indexes in buffer of sampled
                experiences
        """
        assert beta > 0
        idxes = self._sample_proportional(batch_size)
        weights = []
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * len(self._storage)) ** (-beta)

        for idx in idxes:
            p_sample = self._it_sum[idx] / self._it_sum.sum()
            weight = (p_sample * len(self._storage)) ** (-beta)
            weights.append(weight / max_weight)
        weights = np.array(weights)
        encoded_sample = self._encode_sample(idxes)
        return tuple(list(encoded_sample) + [weights, idxes])

    def update_priorities(self, idxes, priorities):
        """
        Update priorities of sampled transitions.

        sets priority of transition at index idxes[i] in buffer
        to priorities[i].

        :param idxes: ([int]) List of idxes of sampled transitions
        :param priorities: ([float]) List of updated priorities corresponding to transitions at the sampled idxes
            denoted by variable `idxes`.
        """
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            assert priority > 0
            assert 0 <= idx < len(self._storage)
            self._it_sum[idx] = priority ** self._alpha
            self._it_min[idx] = priority ** self._alpha

            self._max_priority = max(self._max_priority, priority)
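# Usage sketch (not part of the buffer code above): fill the prioritized buffer with
# dummy transitions and run one add -> sample -> update_priorities cycle. The 4-dim
# observations and the random "TD errors" are placeholders, not a real agent.
per_buffer = PrioritizedReplayBuffer(size=2 ** 10, alpha=0.6)
for _ in range(256):
    per_buffer.add(np.random.randn(4), 0, float(np.random.randn()), np.random.randn(4), False)

obs_b, act_b, rew_b, next_obs_b, done_b, weights, idxes = per_buffer.sample(32, beta=0.4)
# in a real agent the new priorities would be |TD error| + epsilon; random values here
new_priorities = np.abs(np.random.randn(len(idxes))) + 1e-6
per_buffer.update_priorities(idxes, new_priorities)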
class TwoWayPrioritizedReplayBuffer(ReplayBuffer):
    def __init__(self, size, alpha):
        super(TwoWayPrioritizedReplayBuffer, self).__init__(size)
        assert alpha >= 0
        self._alpha = alpha

        it_capacity = 1
        while it_capacity < size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0

    def add(self, obs_t, action, reward, obs_tp1, done):
        idx = self._next_idx
        super().add(obs_t, action, reward, obs_tp1, done)
        self._it_sum[idx] = self._max_priority ** self._alpha
        self._it_min[idx] = self._max_priority ** self._alpha

    def extend(self, obs_t, action, reward, obs_tp1, done):
        idx = self._next_idx
        super().extend(obs_t, action, reward, obs_tp1, done)
        # give every newly written slot the current maximum priority
        while idx != self._next_idx:
            self._it_sum[idx] = self._max_priority ** self._alpha
            self._it_min[idx] = self._max_priority ** self._alpha
            idx = (idx + 1) % self._maxsize

    def _sample_proportional(self, batch_size):
        total = self._it_sum.sum(0, len(self._storage) - 1)
        # TODO(szymon): should we ensure no repeats?
        mass = np.random.random(size=batch_size) * total
        idx = self._it_sum.find_prefixsum_idx(mass)
        return idx

    def sample(self, batch_size: int, beta: float = 0, env: Optional[VecNormalize] = None):
        assert beta > 0
        idxes = self._sample_proportional(batch_size)
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * len(self._storage)) ** (-beta)
        p_sample = self._it_sum[idxes] / self._it_sum.sum()
        weights = (p_sample * len(self._storage)) ** (-beta) / max_weight
        encoded_sample = self._encode_sample(idxes, env=env)
        return tuple(list(encoded_sample) + [weights, idxes])

    def _sample_invproportional(self, batch_size):
        # note: the (start, end) arguments are reversed compared to _sample_proportional
        total = self._it_sum.sum(len(self._storage) - 1, 0)
        # TODO(szymon): should we ensure no repeats?
        mass = np.random.random(size=batch_size) * total
        idx = self._it_sum.find_prefixsum_idx(mass)
        return idx

    def invsample(self, batch_size: int, beta: float = 0, env: Optional[VecNormalize] = None):
        assert beta > 0
        idxes = self._sample_invproportional(batch_size)
        encoded_sample = self._encode_sample(idxes, env=env)
        return list(encoded_sample)

    def update_priorities(self, idxes, priorities):
        assert len(idxes) == len(priorities)
        assert np.min(priorities) > 0
        assert np.min(idxes) >= 0
        assert np.max(idxes) < len(self._storage)
        self._it_sum[idxes] = priorities ** self._alpha
        self._it_min[idxes] = priorities ** self._alpha

        self._max_priority = max(self._max_priority, np.max(priorities))
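# Sketch of the two-way buffer's extra entry point (assumed usage, not from the source):
# sample() draws proportionally to priority and returns importance weights and indices,
# while invsample() uses _sample_invproportional (which reverses the sum range) and
# returns only the encoded batch without weights or indices.
two_way = TwoWayPrioritizedReplayBuffer(size=2 ** 10, alpha=0.6)
for _ in range(256):
    two_way.add(np.random.randn(4), 0, 0.0, np.random.randn(4), False)

proportional_batch = two_way.sample(32, beta=0.4)  # (..., weights, idxes)
inverse_batch = two_way.invsample(32, beta=0.4)    # [obs, act, rew, next_obs, done]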
def test_max_interval_tree():
    """
    test Segment Tree data structure
    """
    tree = MinSegmentTree(4)

    tree[0] = 1.0
    tree[2] = 0.5
    tree[3] = 3.0

    assert np.isclose(tree.min(), 0.5)
    assert np.isclose(tree.min(0, 2), 1.0)
    assert np.isclose(tree.min(0, 3), 0.5)
    assert np.isclose(tree.min(0, -1), 0.5)
    assert np.isclose(tree.min(2, 4), 0.5)
    assert np.isclose(tree.min(3, 4), 3.0)

    tree[2] = 0.7

    assert np.isclose(tree.min(), 0.7)
    assert np.isclose(tree.min(0, 2), 1.0)
    assert np.isclose(tree.min(0, 3), 0.7)
    assert np.isclose(tree.min(0, -1), 0.7)
    assert np.isclose(tree.min(2, 4), 0.7)
    assert np.isclose(tree.min(3, 4), 3.0)

    tree[2] = 4.0

    assert np.isclose(tree.min(), 1.0)
    assert np.isclose(tree.min(0, 2), 1.0)
    assert np.isclose(tree.min(0, 3), 1.0)
    assert np.isclose(tree.min(0, -1), 1.0)
    assert np.isclose(tree.min(2, 4), 3.0)
    assert np.isclose(tree.min(2, 3), 4.0)
    assert np.isclose(tree.min(2, -1), 4.0)
    assert np.isclose(tree.min(3, 4), 3.0)
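# Companion sketch (not in the original test file): the proportional samplers in the
# buffers above rely on SumSegmentTree, whose find_prefixsum_idx maps a sampled
# prefix-sum mass back to a buffer index. Values here are arbitrary illustrations.
def test_sum_tree_prefixsum_sketch():
    sum_tree = SumSegmentTree(4)

    sum_tree[0] = 1.0
    sum_tree[2] = 3.0

    assert np.isclose(sum_tree.sum(), 4.0)
    assert np.isclose(sum_tree.sum(0, 2), 1.0)
    assert sum_tree.find_prefixsum_idx(0.5) == 0  # mass below 1.0 falls into slot 0
    assert sum_tree.find_prefixsum_idx(2.5) == 2  # remaining mass falls into slot 2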
class PrioritizedReplayBuffer(ReplayBuffer):
    def __init__(self, size, alpha):
        """
        Create Prioritized Replay buffer.

        See Also ReplayBuffer.__init__

        :param size: (int) Max number of transitions to store in the buffer. When the buffer overflows the old
            memories are dropped.
        :param alpha: (float) how much prioritization is used (0 - no prioritization, 1 - full prioritization)
        """
        super(PrioritizedReplayBuffer, self).__init__(size)
        assert alpha >= 0
        self._alpha = alpha

        it_capacity = 1
        while it_capacity < size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0

    def add(self, obs_t, action, reward, obs_tp1, done):
        """
        add a new transition to the buffer

        :param obs_t: (Any) the last observation
        :param action: ([float]) the action
        :param reward: (float) the reward of the transition
        :param obs_tp1: (Any) the current observation
        :param done: (bool) is the episode done
        """
        idx = self._next_idx
        super().add(obs_t, action, reward, obs_tp1, done)
        self._it_sum[idx] = self._max_priority ** self._alpha
        self._it_min[idx] = self._max_priority ** self._alpha

    def extend(self, obs_t, action, reward, obs_tp1, done):
        """
        add a new batch of transitions to the buffer

        :param obs_t: (Union[Tuple[Union[np.ndarray, int]], np.ndarray]) the last batch of observations
        :param action: (Union[Tuple[Union[np.ndarray, int]], np.ndarray]) the batch of actions
        :param reward: (Union[Tuple[float], np.ndarray]) the batch of the rewards of the transition
        :param obs_tp1: (Union[Tuple[Union[np.ndarray, int]], np.ndarray]) the current batch of observations
        :param done: (Union[Tuple[bool], np.ndarray]) terminal status of the batch

        Note: uses the same names as .add to keep compatibility with named argument passing
        but expects iterables and arrays with more than 1 dimension
        """
        idx = self._next_idx
        super().extend(obs_t, action, reward, obs_tp1, done)
        while idx != self._next_idx:
            self._it_sum[idx] = self._max_priority ** self._alpha
            self._it_min[idx] = self._max_priority ** self._alpha
            idx = (idx + 1) % self._maxsize

    def _sample_proportional(self, batch_size):
        total = self._it_sum.sum(0, len(self._storage) - 1)
        # TODO(szymon): should we ensure no repeats?
        mass = np.random.random(size=batch_size) * total
        idx = self._it_sum.find_prefixsum_idx(mass)
        return idx

    def sample(self, batch_size: int, beta: float = 0, env: Optional[VecNormalize] = None):
        """
        Sample a batch of experiences.

        compared to ReplayBuffer.sample
        it also returns importance weights and idxes
        of sampled experiences.

        :param batch_size: (int) How many transitions to sample.
        :param beta: (float) To what degree to use importance weights (0 - no corrections, 1 - full correction)
        :param env: (Optional[VecNormalize]) associated gym VecEnv to normalize the observations/rewards
            when sampling
        :return:
            - obs_batch: (np.ndarray) batch of observations
            - act_batch: (numpy float) batch of actions executed given obs_batch
            - rew_batch: (numpy float) rewards received as results of executing act_batch
            - next_obs_batch: (np.ndarray) next set of observations seen after executing act_batch
            - done_mask: (numpy bool) done_mask[i] = 1 if executing act_batch[i] resulted in the end of an episode
                and 0 otherwise.
            - weights: (numpy float) Array of shape (batch_size,) and dtype np.float32 denoting importance weight
                of each sampled transition
            - idxes: (numpy int) Array of shape (batch_size,) and dtype np.int32 indexes in buffer of sampled
                experiences
        """
        assert beta > 0
        idxes = self._sample_proportional(batch_size)
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * len(self._storage)) ** (-beta)
        p_sample = self._it_sum[idxes] / self._it_sum.sum()
        weights = (p_sample * len(self._storage)) ** (-beta) / max_weight
        encoded_sample = self._encode_sample(idxes, env=env)
        return tuple(list(encoded_sample) + [weights, idxes])

    def update_priorities(self, idxes, priorities):
        """
        Update priorities of sampled transitions.

        sets priority of transition at index idxes[i] in buffer
        to priorities[i].

        :param idxes: ([int]) List of idxes of sampled transitions
        :param priorities: ([float]) List of updated priorities corresponding to transitions at the sampled idxes
            denoted by variable `idxes`.
        """
        assert len(idxes) == len(priorities)
        assert np.min(priorities) > 0
        assert np.min(idxes) >= 0
        assert np.max(idxes) < len(self._storage)
        self._it_sum[idxes] = priorities ** self._alpha
        self._it_min[idxes] = priorities ** self._alpha

        self._max_priority = max(self._max_priority, np.max(priorities))
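# Numeric sketch (illustrative values only) of the importance-weight computation in
# sample() above: w_i = (N * P(i)) ** (-beta), normalized by the largest possible
# weight (derived from the minimum probability tracked in the MinSegmentTree) so the
# weights stay in (0, 1].
example_priorities = np.array([1.0, 2.0, 4.0, 1.0]) ** 0.6      # priorities ** alpha
example_probs = example_priorities / example_priorities.sum()   # P(i) = _it_sum[i] / _it_sum.sum()
n_transitions, example_beta = len(example_probs), 0.4
example_max_weight = (example_probs.min() * n_transitions) ** (-example_beta)
example_weights = (example_probs * n_transitions) ** (-example_beta) / example_max_weight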
class PrioritizedReplayBuffer(ReplayBuffer):
    def __init__(self, size, alpha):
        """Create Prioritized Replay buffer.

        Parameters
        ----------
        size: int
            Max number of transitions to store in the buffer. When the buffer
            overflows the old memories are dropped.
        alpha: float
            how much prioritization is used
            (0 - no prioritization, 1 - full prioritization)

        See Also
        --------
        ReplayBuffer.__init__
        """
        super(PrioritizedReplayBuffer, self).__init__(size)
        assert alpha >= 0
        self._alpha = alpha

        it_capacity = 1
        while it_capacity < size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0

    def add(self, *args, **kwargs):
        """See ReplayBuffer.store_effect"""
        idx = self._next_idx
        super().add(*args, **kwargs)
        self._it_sum[idx] = self._max_priority ** self._alpha
        self._it_min[idx] = self._max_priority ** self._alpha

    def _sample_proportional(self, batch_size):
        res = []
        for _ in range(batch_size):
            # TODO(szymon): should we ensure no repeats?
            mass = random.random() * self._it_sum.sum(0, len(self._storage) - 1)
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)
        return res

    def sample(self, batch_size, beta):
        """Sample a batch of experiences.

        compared to ReplayBuffer.sample
        it also returns importance weights and idxes
        of sampled experiences.

        Parameters
        ----------
        batch_size: int
            How many transitions to sample.
        beta: float
            To what degree to use importance weights
            (0 - no corrections, 1 - full correction)

        Returns
        -------
        obs_batch: np.array
            batch of observations
        act_batch: np.array
            batch of actions executed given obs_batch
        R_batch: np.array
            returns received as results of executing act_batch
        weights: np.array
            Array of shape (batch_size,) and dtype np.float32
            denoting importance weight of each sampled transition
        idxes: np.array
            Array of shape (batch_size,) and dtype np.int32
            indexes in buffer of sampled experiences
        """
        idxes = self._sample_proportional(batch_size)

        if beta > 0:
            weights = []
            p_min = self._it_min.min() / self._it_sum.sum()
            max_weight = (p_min * len(self._storage)) ** (-beta)

            for idx in idxes:
                p_sample = self._it_sum[idx] / self._it_sum.sum()
                weight = (p_sample * len(self._storage)) ** (-beta)
                weights.append(weight / max_weight)
            weights = np.array(weights)
        else:
            weights = np.ones_like(idxes, dtype=np.float32)

        encoded_sample = self._encode_sample(idxes)
        return tuple(list(encoded_sample) + [weights, idxes])

    def update_priorities(self, idxes, priorities):
        """Update priorities of sampled transitions.

        sets priority of transition at index idxes[i] in buffer
        to priorities[i].

        Parameters
        ----------
        idxes: [int]
            List of idxes of sampled transitions
        priorities: [float]
            List of updated priorities corresponding to
            transitions at the sampled idxes denoted by
            variable `idxes`.
        """
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            # clamp to a small positive value so a zero TD error never yields a zero priority
            priority = max(priority, 1e-6)
            assert priority > 0
            assert 0 <= idx < len(self._storage)
            self._it_sum[idx] = priority ** self._alpha
            self._it_min[idx] = priority ** self._alpha

            self._max_priority = max(self._max_priority, priority)
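# Sketch (assumed usage, not from the source): unlike the earlier variants, this
# version accepts beta == 0 in sample() and then falls back to uniform (all-ones)
# importance weights instead of asserting. Dummy transitions only.
legacy_per = PrioritizedReplayBuffer(size=2 ** 10, alpha=0.6)
for _ in range(64):
    legacy_per.add(np.random.randn(4), 0, 0.0, np.random.randn(4), False)

uncorrected = legacy_per.sample(16, beta=0)    # weights entry is all ones
corrected = legacy_per.sample(16, beta=0.4)    # weights follow (N * P(i)) ** (-beta)
weights_uncorrected, idxes_uncorrected = uncorrected[-2], uncorrected[-1]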