import random
from typing import Dict, List, Tuple

import numpy as np

# ReplayBuffer, SumSegmentTree, and MinSegmentTree are assumed to be defined
# elsewhere in the project; a sketch of the segment trees appears further down.


class PrioritizedReplayBuffer(ReplayBuffer):
    """Prioritized Replay buffer.

    Attributes:
        max_priority (float): max priority seen so far
        tree_ptr (int): next index to write in the trees
        alpha (float): how much prioritization is used (0 = uniform sampling)
        sum_tree (SumSegmentTree): sum tree for proportional sampling
        min_tree (MinSegmentTree): min tree tracking the minimum priority,
            used to compute the maximum importance-sampling weight
    """

    def __init__(
        self,
        obs_dim: int,
        size: int,
        batch_size: int = 32,
        alpha: float = 0.6,
        n_step: int = 1,
        gamma: float = 0.99,
    ):
        """Initialization."""
        assert alpha >= 0
        super(PrioritizedReplayBuffer, self).__init__(
            obs_dim, size, batch_size, n_step, gamma
        )
        self.max_priority, self.tree_ptr = 1.0, 0
        self.alpha = alpha

        # Capacity must be positive and a power of 2.
        tree_capacity = 1
        while tree_capacity < self.max_size:
            tree_capacity *= 2

        self.sum_tree = SumSegmentTree(tree_capacity)
        self.min_tree = MinSegmentTree(tree_capacity)

    def store(
        self,
        obs: np.ndarray,
        act: int,
        rew: float,
        next_obs: np.ndarray,
        done: bool,
    ) -> Tuple[np.ndarray, np.ndarray, float, np.ndarray, bool]:
        """Store experience and priority."""
        transition = super().store(obs, act, rew, next_obs, done)

        # With n-step returns, the parent may buffer transitions internally and
        # return nothing until a full n-step transition is ready.
        if transition:
            self.sum_tree[self.tree_ptr] = self.max_priority ** self.alpha
            self.min_tree[self.tree_ptr] = self.max_priority ** self.alpha
            self.tree_ptr = (self.tree_ptr + 1) % self.max_size

        return transition

    def sample_batch(self, beta: float = 0.4) -> Dict[str, np.ndarray]:
        """Sample a batch of experiences."""
        assert len(self) >= self.batch_size
        assert beta > 0

        indices = self._sample_proportional()

        obs = self.obs_buf[indices]
        next_obs = self.next_obs_buf[indices]
        acts = self.acts_buf[indices]
        rews = self.rews_buf[indices]
        done = self.done_buf[indices]
        weights = np.array([self._calculate_weight(i, beta) for i in indices])

        return dict(
            obs=obs,
            next_obs=next_obs,
            acts=acts,
            rews=rews,
            done=done,
            weights=weights,
            indices=indices,
        )

    def update_priorities(self, indices: List[int], priorities: np.ndarray):
        """Update priorities of sampled transitions."""
        assert len(indices) == len(priorities)

        for idx, priority in zip(indices, priorities):
            assert priority > 0
            assert 0 <= idx < len(self)

            self.sum_tree[idx] = priority ** self.alpha
            self.min_tree[idx] = priority ** self.alpha

            self.max_priority = max(self.max_priority, priority)

    def _sample_proportional(self) -> List[int]:
        """Sample indices proportionally to priority."""
        indices = []
        p_total = self.sum_tree.sum(0, len(self) - 1)
        segment = p_total / self.batch_size

        # Stratified sampling: draw one index from each equal slice of the
        # total priority mass.
        for i in range(self.batch_size):
            a = segment * i
            b = segment * (i + 1)
            upperbound = random.uniform(a, b)
            idx = self.sum_tree.retrieve(upperbound)
            indices.append(idx)

        return indices

    def _calculate_weight(self, idx: int, beta: float):
        """Calculate the importance-sampling weight of the experience at idx."""
        # Get max weight (from the minimum sampling probability).
        p_min = self.min_tree.min() / self.sum_tree.sum()
        max_weight = (p_min * len(self)) ** (-beta)

        # Calculate the weight and normalize by the max weight.
        p_sample = self.sum_tree[idx] / self.sum_tree.sum()
        weight = (p_sample * len(self)) ** (-beta)
        weight = weight / max_weight

        return weight
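# Hypothetical usage sketch for the class above (not from the original source):
# the dummy transition loop, the td_errors placeholder, and all constants here
# are assumptions for illustration. beta is typically annealed toward 1.0 over
# training so the importance-sampling correction becomes exact.
def _per_usage_sketch():
    buffer = PrioritizedReplayBuffer(obs_dim=4, size=1024, batch_size=8)

    obs = np.zeros(4, dtype=np.float32)
    for step in range(64):
        next_obs = np.random.randn(4).astype(np.float32)
        buffer.store(obs, act=step % 2, rew=1.0, next_obs=next_obs, done=False)
        obs = next_obs

    batch = buffer.sample_batch(beta=0.4)

    # The learner would scale each sample's loss by batch["weights"] before
    # backprop, then feed the |TD error| magnitudes back as new priorities.
    td_errors = np.random.rand(buffer.batch_size)  # placeholder for |TD errors|
    buffer.update_priorities(batch["indices"], np.abs(td_errors) + 1e-6)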
# An alternative implementation of the same proportional-prioritization scheme;
# its parent ReplayBuffer is assumed to expose flat numpy buffers
# (state_buffer, action_buffer, and so on).
class PrioritizedReplayBuffer(ReplayBuffer):

    def __init__(self, buffer_size, input_dim, batch_size, alpha):
        super(PrioritizedReplayBuffer, self).__init__(buffer_size, input_dim, batch_size)

        # PER parameter settings.
        self.max_priority, self.tree_ptr = 1.0, 0
        self.alpha = alpha

        # Tree capacity must be a power of 2.
        tree_capacity = 1
        while tree_capacity < self.buffer_size:
            tree_capacity *= 2

        self.sum_tree = SumSegmentTree(tree_capacity)
        self.min_tree = MinSegmentTree(tree_capacity)

    def store(self, state: np.ndarray, action: int, reward: float,
              next_state: np.ndarray, done: int):
        super().store(state, action, reward, next_state, done)

        # New transitions enter with the max priority so they are sampled
        # at least once.
        self.sum_tree[self.tree_ptr] = self.max_priority ** self.alpha
        self.min_tree[self.tree_ptr] = self.max_priority ** self.alpha
        self.tree_ptr = (self.tree_ptr + 1) % self.buffer_size

    def batch_load(self, beta):
        # Fetching the indices could be parallelized, and the weights can be
        # computed in the same pass.
        indices = self._sample_proportional_indices()
        weights = np.array(
            [self._calculate_weight(idx, beta) for idx in indices])
        return dict(states=self.state_buffer[indices],
                    actions=self.action_buffer[indices],
                    rewards=self.reward_buffer[indices],
                    next_states=self.next_state_buffer[indices],
                    dones=self.done_buffer[indices],
                    weights=weights,
                    indices=indices)

    def update_priorities(self, indices, priorities):
        # This loop could also be parallelized.
        for idx, priority in zip(indices, priorities):
            self.sum_tree[idx] = priority ** self.alpha
            self.min_tree[idx] = priority ** self.alpha
            self.max_priority = max(self.max_priority, priority)

    def _sample_proportional_indices(self):
        indices = []
        p_total = self.sum_tree.sum(0, len(self) - 1)
        segment = p_total / self.batch_size

        # Could be parallelized, e.g. with multiprocessing.
        for i in range(self.batch_size):
            a = segment * i
            b = segment * (i + 1)
            sample = np.random.uniform(a, b)
            idx = self.sum_tree.retrieve(sample)  # tree index of this sample
            indices.append(idx)
        return indices

    def _calculate_weight(self, idx, beta):
        # p_min and max_weight only need to be computed once per batch.
        p_min = self.min_tree.min() / self.sum_tree.sum()
        max_weight = (p_min * len(self)) ** (-beta)

        p_sample = self.sum_tree[idx] / self.sum_tree.sum()
        weight = (p_sample * len(self)) ** (-beta)
        weight /= max_weight
        return weight
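# All of the buffers in this section rely on SumSegmentTree / MinSegmentTree
# without showing them. Below is a minimal sketch of compatible trees, modeled
# on the widely used OpenAI Baselines design; the exact names and the inclusive
# [start, end] range convention are assumptions, not the original projects'
# code. Unfilled leaves stay at the identity element (0 for sum, +inf for min),
# so sum()/min() over the full tree remain correct for a partially filled buffer.
import operator


class SegmentTree:
    def __init__(self, capacity, operation, init_value):
        assert capacity > 0 and capacity & (capacity - 1) == 0, \
            "capacity must be a positive power of 2"
        self.capacity = capacity
        self.operation = operation
        # Node 0 is unused; root at 1; leaves at [capacity, 2 * capacity).
        self.tree = [init_value] * (2 * capacity)

    def __setitem__(self, idx, val):
        # Write the leaf, then rebuild every ancestor on the path to the root.
        idx += self.capacity
        self.tree[idx] = val
        idx //= 2
        while idx >= 1:
            self.tree[idx] = self.operation(self.tree[2 * idx], self.tree[2 * idx + 1])
            idx //= 2

    def __getitem__(self, idx):
        return self.tree[self.capacity + idx]

    def _reduce(self, start, end, node, node_start, node_end):
        # Apply self.operation over leaves [start, end] (inclusive), recursing
        # into whichever child subtrees the query range overlaps.
        if start == node_start and end == node_end:
            return self.tree[node]
        mid = (node_start + node_end) // 2
        if end <= mid:
            return self._reduce(start, end, 2 * node, node_start, mid)
        if start > mid:
            return self._reduce(start, end, 2 * node + 1, mid + 1, node_end)
        return self.operation(
            self._reduce(start, mid, 2 * node, node_start, mid),
            self._reduce(mid + 1, end, 2 * node + 1, mid + 1, node_end),
        )

    def reduce(self, start=0, end=None):
        if end is None:
            end = self.capacity - 1
        return self._reduce(start, end, 1, 0, self.capacity - 1)


class SumSegmentTree(SegmentTree):
    def __init__(self, capacity):
        super().__init__(capacity, operator.add, 0.0)

    def sum(self, start=0, end=None):
        return self.reduce(start, end)

    def retrieve(self, upperbound):
        # Descend from the root toward the leaf whose prefix sum first exceeds
        # upperbound; this samples leaves proportionally to their values.
        idx = 1
        while idx < self.capacity:  # while not yet at a leaf
            left = 2 * idx
            if self.tree[left] > upperbound:
                idx = left
            else:
                upperbound -= self.tree[left]
                idx = left + 1
        return idx - self.capacity


class MinSegmentTree(SegmentTree):
    def __init__(self, capacity):
        super().__init__(capacity, min, float("inf"))

    def min(self, start=0, end=None):
        return self.reduce(start, end)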
from collections import namedtuple

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class ReplayBuffer:
    """Fixed-size buffer to store experience tuples, with prioritized sampling."""

    def __init__(self, action_size, buffer_size, batch_size, alpha):
        """Initialize a ReplayBuffer object.

        Params
        ======
            action_size (int): dimension of each action
            buffer_size (int): maximum size of buffer
            batch_size (int): size of each training batch
            alpha (float): PER alpha value (degree of prioritization)
        """
        self.max_priority = 1.0
        self.alpha = alpha

        # Capacity must be positive and a power of 2.
        self.tree_capacity = 1
        while self.tree_capacity < buffer_size:
            self.tree_capacity *= 2

        self.sum_tree = SumSegmentTree(self.tree_capacity)
        self.min_tree = MinSegmentTree(self.tree_capacity)

        self.action_size = action_size
        self.memory = []
        self.batch_size = batch_size
        self.experience = namedtuple(
            "Experience",
            field_names=["state", "action", "reward", "next_state", "done"])

    def add(self, t, state, action, reward, next_state, done):
        """Add a new experience to memory; t is the caller's running step counter."""
        e = self.experience(state, action, reward, next_state, done)
        idx = t % self.tree_capacity
        if t >= self.tree_capacity:
            # The buffer is full: overwrite the oldest experience.
            self.memory[idx] = e
        else:
            self.memory.append(e)

        # Insert the experience in the priority trees with max priority.
        self.sum_tree[idx] = self.max_priority ** self.alpha
        self.min_tree[idx] = self.max_priority ** self.alpha

    def sample(self, beta):
        """Sample a batch of experiences from memory, proportional to priority."""
        indices = self.relevant_sample_indx()
        idxs = np.vstack(indices).astype(int)  # np.int is removed in NumPy >= 1.24
        states = torch.from_numpy(
            np.vstack([self.memory[i].state for i in indices])).float().to(device)
        actions = torch.from_numpy(
            np.vstack([self.memory[i].action for i in indices])).long().to(device)
        rewards = torch.from_numpy(
            np.vstack([self.memory[i].reward for i in indices])).float().to(device)
        next_states = torch.from_numpy(
            np.vstack([self.memory[i].next_state for i in indices])).float().to(device)
        dones = torch.from_numpy(
            np.vstack([self.memory[i].done
                       for i in indices]).astype(np.uint8)).float().to(device)
        weights = torch.from_numpy(
            np.array([self.isw(i, beta) for i in indices])).float().to(device)

        return (idxs, states, actions, rewards, next_states, dones, weights)

    def relevant_sample_indx(self):
        """Select sample indices proportionally to priority (stratified)."""
        indices = []
        p_total = self.sum_tree.sum(0, len(self) - 1)
        segment = p_total / self.batch_size

        for i in range(self.batch_size):
            a = segment * i
            b = segment * (i + 1)
            upperbound = random.uniform(a, b)
            idx = self.sum_tree.retrieve(upperbound)
            indices.append(idx)

        return indices

    def update_priorities(self, indices, priorities):
        """Update priorities of sampled transitions."""
        assert indices.shape[0] == priorities.shape[0]

        for idx, priority in zip(indices.flatten(), priorities.flatten()):
            assert priority > 0
            assert 0 <= idx < len(self)

            self.sum_tree[idx] = priority ** self.alpha
            self.min_tree[idx] = priority ** self.alpha
            self.max_priority = max(self.max_priority, priority)

    def isw(self, idx, beta):
        """Compute the importance-sampling weight for index idx."""
        # Get max weight (from the minimum sampling probability).
        p_min = self.min_tree.min() / self.sum_tree.sum()
        max_weight = (p_min * len(self)) ** (-beta)

        # Calculate the weight and normalize by the max weight.
        p_sample = self.sum_tree[idx] / self.sum_tree.sum()
        weight = (p_sample * len(self)) ** (-beta)
        is_weight = weight / max_weight

        return is_weight

    def __len__(self):
        """Return the current size of internal memory."""
        return len(self.memory)
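# Hypothetical usage sketch for the variant above (not from the original
# source). Unlike the tree_ptr-based buffers, this class expects the caller to
# pass a running step counter t: t % tree_capacity is the write index, and
# t >= tree_capacity switches from appending to overwriting. All transitions
# below are dummy placeholders.
def _counter_based_usage_sketch():
    per = ReplayBuffer(action_size=2, buffer_size=512, batch_size=4, alpha=0.6)

    state = np.zeros(8, dtype=np.float32)
    for t in range(32):
        next_state = np.random.randn(8).astype(np.float32)
        per.add(t, state, action=t % 2, reward=0.0,
                next_state=next_state, done=False)
        state = next_state

    idxs, states, actions, rewards, next_states, dones, weights = per.sample(beta=0.4)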
# A PER variant built on a namedtuple-based ReplayBuffer (note the extra seed
# argument and the singular update_priority name in this one).
class PrioritizedReplayBuffer(ReplayBuffer):

    def __init__(self, action_size, buffer_size, batch_size, seed, alpha=0.6):
        super(PrioritizedReplayBuffer, self).__init__(action_size, buffer_size,
                                                      batch_size, seed)

        # Capacity must be positive and a power of 2.
        tree_capacity = 1
        while tree_capacity < self.buffer_size:
            tree_capacity *= 2

        self.sum_tree = SumSegmentTree(tree_capacity)
        self.min_tree = MinSegmentTree(tree_capacity)
        self.max_priority, self.tree_ptr = 1.0, 0
        self.alpha = alpha

    def add(self, state, action, reward, next_state, done):
        # New transitions enter with max priority.
        self.sum_tree[self.tree_ptr] = self.max_priority ** self.alpha
        self.min_tree[self.tree_ptr] = self.max_priority ** self.alpha
        super().add(state, action, reward, next_state, done)
        self.tree_ptr = (self.tree_ptr + 1) % self.buffer_size

    def sample(self, beta=0.4):
        indices = self._sample_proportional()
        # Guard against tree leaves beyond the filled region; note that this
        # can shrink the batch below batch_size.
        indices = [index for index in indices if index < len(self.memory)]

        states = torch.from_numpy(
            np.vstack([self.memory[index].state for index in indices])).float().to(device)
        actions = torch.from_numpy(
            np.vstack([self.memory[index].action for index in indices])).long().to(device)
        rewards = torch.from_numpy(
            np.vstack([self.memory[index].reward for index in indices])).float().to(device)
        next_states = torch.from_numpy(
            np.vstack([self.memory[index].next_state for index in indices])).float().to(device)
        dones = torch.from_numpy(
            np.vstack([self.memory[index].done
                       for index in indices]).astype(np.uint8)).float().to(device)
        weights = torch.from_numpy(
            np.vstack([self._cal_weight(index, beta) for index in indices])).float().to(device)

        return (states, actions, rewards, next_states, dones, weights, indices)

    def update_priority(self, indices, loss_for_prior):
        for idx, priority in zip(indices, loss_for_prior):
            self.sum_tree[idx] = priority ** self.alpha
            self.min_tree[idx] = priority ** self.alpha
            self.max_priority = max(self.max_priority, priority)

    def _sample_proportional(self):
        indices = []
        # Summing the whole tree equals sum(0, len(self.memory) - 1) because
        # unfilled leaves hold 0.
        p_total = self.sum_tree.sum()
        segment = p_total / self.batch_size

        for i in range(self.batch_size):
            start = segment * i
            end = start + segment
            upper = random.uniform(start, end)
            index = self.sum_tree.retrieve(upper)
            indices.append(index)
        return indices

    def _cal_weight(self, index, beta):
        min_priority = self.min_tree.min()
        current_priority = self.sum_tree[index]
        # Equivalent to the fully normalized IS weight
        #   (N * P(i)) ** (-beta) / (N * P_min) ** (-beta):
        # both N and the tree sum cancel (see the derivation below).
        return (min_priority / current_priority) ** beta
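# Why the simplified return in _cal_weight equals the standard normalized form
# used by the other buffers above: with sampling probability
# P(i) = p_i / sum_k p_k, where p_i is the alpha-exponentiated priority already
# stored in the trees, and N the number of stored transitions,
#
#     w_i / max_j w_j = (N * P(i))^(-beta) / (N * P_min)^(-beta)
#                     = (P(i) / P_min)^(-beta)
#                     = (p_min / p_i)^beta
#
# N and the tree sum cancel, so (min_priority / current_priority) ** beta is
# the fully normalized importance-sampling weight.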