import random
from typing import Dict, List

import numpy as np

# ReplayBuffer, SumSegmentTree and MinSegmentTree are assumed to be defined
# (or imported) elsewhere in this module.


class PrioritizedReplayBuffer(ReplayBuffer):
    """Prioritized replay buffer.

    Attributes:
        max_priority (float): max priority seen so far
        tree_ptr (int): next index to write in the trees
        alpha (float): how much prioritization is used (0 = uniform sampling)
        sum_tree (SumSegmentTree): sum tree of priorities, for proportional sampling
        min_tree (MinSegmentTree): min tree of priorities, for the max importance weight
    """

    def __init__(self, obs_dim: int, size: int, batch_size: int, alpha: float = 0.6):
        """Initialization."""
        assert alpha >= 0

        super(PrioritizedReplayBuffer, self).__init__(obs_dim, size, batch_size)
        self.max_priority, self.tree_ptr = 1.0, 0
        self.alpha = alpha

        # Tree capacity must be positive and a power of 2.
        tree_capacity = 1
        while tree_capacity < self.max_size:
            tree_capacity *= 2

        self.sum_tree = SumSegmentTree(tree_capacity)
        self.min_tree = MinSegmentTree(tree_capacity)

    def store(self, obs: np.ndarray, act: int, rew: float, next_obs: np.ndarray, done: bool):
        """Store experience and priority."""
        super().store(obs, act, rew, next_obs, done)

        # new transitions get the max priority so they are sampled at least once
        self.sum_tree[self.tree_ptr] = self.max_priority**self.alpha
        self.min_tree[self.tree_ptr] = self.max_priority**self.alpha
        self.tree_ptr = (self.tree_ptr + 1) % self.max_size

    def sample_batch(self, beta: float = 0.4) -> Dict[str, np.ndarray]:
        """Sample a batch of experiences."""
        assert len(self) >= self.batch_size
        assert beta > 0

        indices = self._sample_proportional()

        obs = self.obs_buf[indices]
        next_obs = self.next_obs_buf[indices]
        acts = self.acts_buf[indices]
        rews = self.rews_buf[indices]
        done = self.done_buf[indices]
        weights = np.array([self._calculate_weight(i, beta) for i in indices])

        return dict(
            obs=obs,
            next_obs=next_obs,
            acts=acts,
            rews=rews,
            done=done,
            weights=weights,
            indices=indices,
        )

    def update_priorities(self, indices: List[int], priorities: np.ndarray):
        """Update priorities of sampled transitions."""
        assert len(indices) == len(priorities)

        for idx, priority in zip(indices, priorities):
            assert priority > 0
            assert 0 <= idx < len(self)

            self.sum_tree[idx] = priority**self.alpha
            self.min_tree[idx] = priority**self.alpha

            self.max_priority = max(self.max_priority, priority)

    def _sample_proportional(self) -> List[int]:
        """Sample indices proportionally to the stored priorities."""
        indices = []
        p_total = self.sum_tree.sum(0, len(self) - 1)
        segment = p_total / self.batch_size

        for i in range(self.batch_size):
            a = segment * i
            b = segment * (i + 1)
            upperbound = random.uniform(a, b)
            idx = self.sum_tree.retrieve(upperbound)
            indices.append(idx)

        return indices

    def _calculate_weight(self, idx: int, beta: float):
        """Calculate the importance-sampling weight of the experience at idx."""
        # max weight comes from the smallest sampling probability
        p_min = self.min_tree.min() / self.sum_tree.sum()
        max_weight = (p_min * len(self))**(-beta)

        # weight of the sampled transition, normalized by the max weight
        p_sample = self.sum_tree[idx] / self.sum_tree.sum()
        weight = (p_sample * len(self))**(-beta)
        weight = weight / max_weight

        return weight
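# A minimal usage sketch for PrioritizedReplayBuffer, showing how the weights and
# indices returned by sample_batch feed the loss and the priority update. The
# hyperparameters and the compute_td_errors(batch) helper are assumptions for
# illustration only, not part of the class above; the function is never called here.
def _per_buffer_usage_sketch():
    buffer = PrioritizedReplayBuffer(obs_dim=4, size=10000, batch_size=32, alpha=0.6)
    # ... assume the buffer has been filled with buffer.store(obs, act, rew, next_obs, done) ...

    beta = 0.4  # typically annealed toward 1.0 over training, as in the PER paper
    batch = buffer.sample_batch(beta)

    td_errors = compute_td_errors(batch)  # hypothetical helper, defined elsewhere
    loss = np.mean(batch["weights"] * td_errors**2)  # importance-weighted (e.g. MSE) loss

    # feed back |TD error| plus a small epsilon so no transition gets zero priority
    buffer.update_priorities(batch["indices"], np.abs(td_errors) + 1e-6)
    return loss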
class PrioritizedReplayMemory(ReplayMemory):

    def __init__(self, alpha=0.6, capacity=100000, replace=False, tuple_class=Transition):
        super().__init__(capacity, replace, tuple_class)
        assert alpha >= 0

        self.max_priority, self.tree_ptr = 1.0, 0
        self.alpha = alpha

        # Tree capacity must be positive and a power of 2:
        # round the buffer capacity up to the next power of 2.
        m = np.ceil(np.log(self.capacity) / np.log(2))
        tree_capacity = np.power(2, m).astype(int)

        self.sum_tree = SumSegmentTree(tree_capacity)
        self.min_tree = MinSegmentTree(tree_capacity)

    def add(self, record):
        """Store a transition and give it the current max priority."""
        super().add(record)

        self.sum_tree[self.tree_ptr] = self.max_priority**self.alpha
        self.min_tree[self.tree_ptr] = self.max_priority**self.alpha
        self.tree_ptr = (self.tree_ptr + 1) % self.capacity

    def sample(self, batch_size, beta: float = 0.4) -> Dict[str, np.ndarray]:
        """Sample a batch of experiences."""
        assert len(self) >= batch_size
        assert beta > 0

        indices = self._sample_proportional(batch_size)
        weights = np.array([self._calculate_weight(i, beta) for i in indices])

        result = self._reformat(indices)
        result['indices'] = indices
        result['weights'] = weights
        return result

    def update_priorities(self, indices: List[int], priorities: np.ndarray):
        """Update priorities of sampled transitions."""
        assert len(indices) == len(priorities)

        for idx, priority in zip(indices, priorities):
            assert priority > 0
            assert 0 <= idx < len(self)

            self.sum_tree[idx] = priority**self.alpha
            self.min_tree[idx] = priority**self.alpha

            self.max_priority = max(self.max_priority, priority)

    def _sample_proportional(self, batch_size) -> List[int]:
        """Sample indices proportionally to the stored priorities."""
        indices = []
        p_total = self.sum_tree.sum(0, len(self) - 1)
        segment = p_total / batch_size

        for i in range(batch_size):
            a = segment * i
            b = segment * (i + 1)
            upperbound = random.uniform(a, b)
            idx = self.sum_tree.retrieve(upperbound)
            indices.append(idx)

        return indices

    def _calculate_weight(self, idx: int, beta: float):
        """Calculate the importance-sampling weight of the experience at idx."""
        # max weight comes from the smallest sampling probability
        p_min = self.min_tree.min() / self.sum_tree.sum()
        max_weight = (p_min * len(self))**(-beta)

        # weight of the sampled transition, normalized by the max weight
        p_sample = self.sum_tree[idx] / self.sum_tree.sum()
        weight = (p_sample * len(self))**(-beta)
        weight = weight / max_weight

        return weight
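# A minimal usage sketch for PrioritizedReplayMemory, illustrating the beta annealing
# recommended in the PER paper and the sample / update_priorities loop. The exact keys
# returned by _reformat depend on the ReplayMemory base class, so only the 'indices'
# and 'weights' keys set above are relied on; compute_td_errors(batch) and the
# schedule constants are assumptions for illustration, and the function is never called.
def _per_memory_usage_sketch(memory, total_steps=100000):
    # assume `memory` has already been filled via memory.add(record)
    beta_start, beta_end = 0.4, 1.0
    for step in range(total_steps):
        # linearly anneal beta toward 1.0 so the bias correction becomes exact late in training
        beta = beta_start + (beta_end - beta_start) * step / total_steps

        batch = memory.sample(batch_size=32, beta=beta)
        td_errors = compute_td_errors(batch)  # hypothetical helper, defined elsewhere
        loss = np.mean(batch['weights'] * td_errors**2)  # importance-weighted loss

        # new priority = |TD error| + small epsilon, so no transition gets zero probability
        memory.update_priorities(batch['indices'], np.abs(td_errors) + 1e-6)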