def __init__(self, size, alpha):
    """Create Prioritized Proportional Replay buffer.

    Parameters
    ----------
    size: int
        Max number of transitions to store in the buffer. When the buffer
        overflows the old memories are dropped.
    alpha: float
        How much prioritization is used
        (0 - no prioritization, 1 - full prioritization).

    See Also
    --------
    ReplayBuffer.__init__()
    """
    super(ProportionalReplayBuffer, self).__init__(size)
    # alpha == 0 is the documented "no prioritization" setting, so only
    # negative exponents are rejected.
    assert alpha >= 0
    self._alpha = alpha

    # The segment trees are sized to the smallest power of two >= size.
    it_capacity = 1
    while it_capacity < size:
        it_capacity *= 2

    self._it_sum = SumSegmentTree(it_capacity)
    self._it_min = MinSegmentTree(it_capacity)
    # Initial value of the running maximum priority.
    self._max_priority = 1.0
def test_prefixsum_idx():
    """Check find_prefixsum_idx on a tree where only leaves 2 and 3 are set."""
    tree = SumSegmentTree(4)
    tree[2] = 1.0
    tree[3] = 3.0

    # (prefix-sum mass, index the lookup must return)
    cases = [
        (0.0, 2),
        (0.5, 2),
        (0.99, 2),
        (1.01, 3),
        (3.00, 3),
        (4.00, 3),
    ]
    for mass, expected_idx in cases:
        assert tree.find_prefixsum_idx(mass) == expected_idx
def test_prefixsum_idx2():
    """Check find_prefixsum_idx on a tree with all four leaves populated."""
    tree = SumSegmentTree(4)
    for leaf, value in enumerate((0.5, 1.0, 1.0, 3.0)):
        tree[leaf] = value

    # (prefix-sum mass, index the lookup must return)
    cases = [
        (0.00, 0),
        (0.55, 1),
        (0.99, 1),
        (1.51, 2),
        (3.00, 3),
        (5.50, 3),
    ]
    for mass, expected_idx in cases:
        assert tree.find_prefixsum_idx(mass) == expected_idx
def test_tree_set():
    """Check range sums after assigning two leaves of a SumSegmentTree."""
    tree = SumSegmentTree(4)
    tree[2] = 1.0
    tree[3] = 3.0

    # (positional arguments passed to tree.sum, expected result)
    cases = [
        ((), 4.0),
        ((0, 2), 0.0),
        ((0, 3), 1.0),
        ((2, 3), 1.0),
        ((2, -1), 1.0),
        ((2, 4), 4.0),
    ]
    for bounds, expected in cases:
        assert np.isclose(tree.sum(*bounds), expected)
def test_tree_set_overlap():
    """Check range sums after the same leaf is assigned twice."""
    tree = SumSegmentTree(4)
    for value in (1.0, 3.0):
        tree[2] = value

    # (positional arguments passed to tree.sum, expected result)
    cases = [
        ((), 3.0),
        ((2, 3), 3.0),
        ((2, -1), 3.0),
        ((2, 4), 3.0),
        ((1, 2), 0.0),
    ]
    for bounds, expected in cases:
        assert np.isclose(tree.sum(*bounds), expected)
class ProportionalReplayBuffer(ReplayBuffer):
    """Replay buffer with proportional prioritized sampling.

    Each stored transition has a priority; `priority ** alpha` is kept in a
    sum segment tree (used for sampling) and a min segment tree (used to
    normalize importance weights).
    """

    def __init__(self, size, alpha):
        """Create Prioritized Proportional Replay buffer.

        Parameters
        ----------
        size: int
            Max number of transitions to store in the buffer. When the buffer
            overflows the old memories are dropped.
        alpha: float
            How much prioritization is used
            (0 - no prioritization, 1 - full prioritization).

        See Also
        --------
        ReplayBuffer.__init__()
        """
        super().__init__(size)
        # alpha == 0 is the documented "no prioritization" setting, so only
        # negative exponents are rejected.
        assert alpha >= 0
        self._alpha = alpha

        # The segment trees are sized to the smallest power of two >= size.
        it_capacity = 1
        while it_capacity < size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        # Running maximum priority; new transitions are stored with it.
        self._max_priority = 1.0

    def add(self, *args, **kwargs):
        """See ReplayBuffer.add"""
        # Capture the slot before the parent advances its write position.
        idx = self._next_idx
        super().add(*args, **kwargs)
        scaled_priority = self._max_priority ** self._alpha
        self._it_sum[idx] = scaled_priority
        self._it_min[idx] = scaled_priority

    def _sample_proportional(self, batch_size):
        """Draw `batch_size` buffer indexes proportionally to stored priority.

        Returns
        -------
        list of int
            Sampled indexes; repeats are possible.
        """
        n_stored = len(self._storage)
        # The end bound of `sum` is exclusive, so the range must extend to
        # `n_stored` for the most recently stored slot to be reachable.
        total = self._it_sum.sum(0, n_stored)
        res = []
        for _ in range(batch_size):
            # TODO(szymon): should we ensure no repeats?
            mass = random.random() * total
            idx = self._it_sum.find_prefixsum_idx(mass)
            # Guard against floating-point rounding pushing the lookup past
            # the last stored transition.
            res.append(min(idx, n_stored - 1))
        return res

    def sample(self, batch_size, beta):
        """Sample a batch of experiences.

        Compared to ReplayBuffer.sample it also returns importance weights
        and idxes of sampled experiences.

        Parameters
        ----------
        batch_size: int
            How many transitions to sample.
        beta: float
            To what degree to use importance weights
            (0 - no corrections, 1 - full correction).

        Returns
        -------
        obs_batch: np.array
            batch of observations
        act_batch: np.array
            batch of actions executed given obs_batch
        rew_batch: np.array
            rewards received as results of executing act_batch
        next_obs_batch: np.array
            next set of observations seen after executing act_batch
        done_mask: np.array
            done_mask[i] = 1 if executing act_batch[i] resulted in
            the end of an episode and 0 otherwise.
        weights: np.array
            Array of shape (batch_size,) holding the importance weight of
            each sampled transition, normalized by the minimum stored
            priority.
        idxes: list of int
            Indexes in the buffer of the sampled experiences.
        """
        # beta == 0 is the documented "no corrections" setting.
        assert beta >= 0

        idxes = self._sample_proportional(batch_size)

        p_min = self._it_min.min()
        weights = np.array(
            [(self._it_sum[idx] / p_min) ** (-beta) for idx in idxes]
        )
        encoded_sample = self._encode_sample(idxes)
        return tuple(list(encoded_sample) + [weights, idxes])

    def update_priorities(self, idxes, priorities):
        """Update priorities of sampled transitions.

        Sets priority of transition at index idxes[i] in buffer
        to priorities[i].

        Parameters
        ----------
        idxes: [int]
            List of idxes of sampled transitions.
        priorities: [float]
            List of updated priorities corresponding to transitions at the
            sampled idxes denoted by variable `idxes`.
        """
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            assert priority > 0
            assert 0 <= idx < len(self._storage)
            scaled_priority = priority ** self._alpha
            self._it_sum[idx] = scaled_priority
            self._it_min[idx] = scaled_priority

            self._max_priority = max(self._max_priority, priority)