def __init__(self, size, alpha):
    """Create Prioritized Replay buffer.

    Parameters
    ----------
    size: int
        Max number of transitions to store in the buffer. When the buffer
        overflows the old memories are dropped.
    alpha: float
        How much prioritization is used
        (0 - no prioritization, 1 - full prioritization). Must be >= 0.

    See Also
    --------
    ReplayBuffer.__init__
    """
    super(PrioritizedReplayBuffer, self).__init__(size)
    # alpha == 0 is the documented "no prioritization" setting, so only
    # negative exponents are rejected.
    assert alpha >= 0
    self._alpha = alpha

    # Segment-tree capacity: the smallest power of two that is >= size.
    it_capacity = 1
    while it_capacity < size:
        it_capacity *= 2

    self._it_sum = SumSegmentTree(it_capacity)
    self._it_min = MinSegmentTree(it_capacity)
    # Priority assigned to transitions that have not been updated yet.
    self._max_priority = 1.0
class PrioritizedReplayBuffer(ReplayBuffer):
    """Replay buffer that samples transitions proportionally to priority.

    Each stored transition has a priority; ``priority ** alpha`` is kept in
    a sum segment tree (for proportional sampling) and a min segment tree
    (for normalizing the importance weights).
    """

    def __init__(self, size, alpha):
        """Create Prioritized Replay buffer.

        Parameters
        ----------
        size: int
            Max number of transitions to store in the buffer. When the
            buffer overflows the old memories are dropped.
        alpha: float
            How much prioritization is used
            (0 - no prioritization, 1 - full prioritization). Must be >= 0.

        See Also
        --------
        ReplayBuffer.__init__
        """
        super().__init__(size)
        # alpha == 0 is the documented "no prioritization" setting, so only
        # negative exponents are rejected.
        assert alpha >= 0
        self._alpha = alpha

        # Segment-tree capacity: the smallest power of two that is >= size.
        it_capacity = 1
        while it_capacity < size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        # Priority assigned to transitions that have not been updated yet.
        self._max_priority = 1.0

    def add(self, *args, **kwargs):
        """Store a transition and give it the current maximum priority.

        All arguments are forwarded unchanged to ``ReplayBuffer.add``.
        """
        # Capture the slot before delegating; presumably the parent's add()
        # advances ``_next_idx`` -- confirm against ReplayBuffer.
        idx = self._next_idx
        super().add(*args, **kwargs)
        scaled_priority = self._max_priority ** self._alpha
        self._it_sum[idx] = scaled_priority
        self._it_min[idx] = scaled_priority

    def _sample_proportional(self, batch_size):
        """Return a list of ``batch_size`` indexes drawn from the sum tree.

        Each draw scales a uniform random number by the tree's partial sum
        and looks the result up with ``find_prefixsum_idx``.
        """
        # Loop-invariant: drawing samples does not modify the tree. The
        # range bounds are kept exactly as the segment tree is called
        # elsewhere; their inclusive/exclusive semantics are defined by
        # SumSegmentTree.
        total = self._it_sum.sum(0, len(self._storage) - 1)
        return [
            self._it_sum.find_prefixsum_idx(random.random() * total)
            for _ in range(batch_size)
        ]

    def Sample(self, batch_size, beta):
        """Sample a batch of experiences.

        Compared to ReplayBuffer.sample it also returns importance weights
        and indexes of sampled experiences.

        Parameters
        ----------
        batch_size: int
            How many transitions to sample.
        beta: float
            To what degree to use importance weights
            (0 - no corrections, 1 - full correction). Must be > 0.

        Returns
        -------
        tuple
            The elements produced by ``self._encode_sample(indexes)``
            (presumably the observation / action / reward / next-observation
            batches -- defined by ReplayBuffer), followed by:

            weights: np.ndarray
                Array of shape (batch_size,) with the importance weight of
                each sampled transition, divided by the weight of the
                lowest-priority entry in the min tree.
            indexes: list of int
                Indexes in the buffer of the sampled experiences.
        """
        assert beta > 0

        indexes = self._sample_proportional(batch_size)

        # Both values are constant for the whole batch.
        total = self._it_sum.sum()
        n_stored = len(self._storage)

        # The smallest sampling probability yields the largest weight, which
        # is used to normalize all weights.
        p_min = self._it_min.min() / total
        max_weight = (p_min * n_stored) ** (-beta)

        weights = []
        for idx in indexes:
            p_sample = self._it_sum[idx] / total
            weight = (p_sample * n_stored) ** (-beta)
            weights.append(weight / max_weight)
        weights = np.array(weights)

        encoded_sample = self._encode_sample(indexes)
        return tuple(list(encoded_sample) + [weights, indexes])

    def update_priorities(self, indexes, priorities):
        """Update priorities of sampled transitions.

        Sets priority of transition at index indexes[i] in buffer
        to priorities[i].

        Parameters
        ----------
        indexes: [int]
            List of indexes of sampled transitions.
        priorities: [float]
            List of updated priorities corresponding to transitions at the
            sampled indexes denoted by variable `indexes`. Each must be > 0.
        """
        assert len(indexes) == len(priorities)
        for idx, priority in zip(indexes, priorities):
            assert priority > 0
            assert 0 <= idx < len(self._storage)
            scaled_priority = priority ** self._alpha
            self._it_sum[idx] = scaled_priority
            self._it_min[idx] = scaled_priority

            # Track the largest raw priority so new transitions start at it.
            self._max_priority = max(self._max_priority, priority)