Code Example #1
from typing import Dict, List, Tuple
import random

import numpy as np

# ReplayBuffer, SumSegmentTree, and MinSegmentTree are assumed to be defined
# elsewhere in the surrounding project.


class PrioritizedReplayBuffer(ReplayBuffer):
    """Prioritized Replay buffer.
    
    Attributes:
        max_priority (float): max priority
        tree_ptr (int): next index of tree
        alpha (float): alpha parameter for prioritized replay buffer
        sum_tree (SumSegmentTree): sum tree for priorities
        min_tree (MinSegmentTree): min tree tracking the minimum priority, used to get the max weight
        
    """
    def __init__(
        self,
        obs_dim: int,
        size: int,
        batch_size: int = 32,
        alpha: float = 0.6,
        n_step: int = 1,
        gamma: float = 0.99,
    ):
        """Initialization."""
        assert alpha >= 0

        super(PrioritizedReplayBuffer,
              self).__init__(obs_dim, size, batch_size, n_step, gamma)
        self.max_priority, self.tree_ptr = 1.0, 0
        self.alpha = alpha

        # capacity must be positive and a power of 2.
        tree_capacity = 1
        while tree_capacity < self.max_size:
            tree_capacity *= 2

        self.sum_tree = SumSegmentTree(tree_capacity)
        self.min_tree = MinSegmentTree(tree_capacity)

    def store(
        self,
        obs: np.ndarray,
        act: int,
        rew: float,
        next_obs: np.ndarray,
        done: bool,
    ) -> Tuple[np.ndarray, np.ndarray, float, np.ndarray, bool]:
        """Store experience and priority."""
        transition = super().store(obs, act, rew, next_obs, done)

        if transition:
            self.sum_tree[self.tree_ptr] = self.max_priority**self.alpha
            self.min_tree[self.tree_ptr] = self.max_priority**self.alpha
            self.tree_ptr = (self.tree_ptr + 1) % self.max_size

        return transition

    def sample_batch(self, beta: float = 0.4) -> Dict[str, np.ndarray]:
        """Sample a batch of experiences."""
        assert len(self) >= self.batch_size
        assert beta > 0

        indices = self._sample_proportional()

        obs = self.obs_buf[indices]
        next_obs = self.next_obs_buf[indices]
        acts = self.acts_buf[indices]
        rews = self.rews_buf[indices]
        done = self.done_buf[indices]
        weights = np.array([self._calculate_weight(i, beta) for i in indices])

        return dict(
            obs=obs,
            next_obs=next_obs,
            acts=acts,
            rews=rews,
            done=done,
            weights=weights,
            indices=indices,
        )

    def update_priorities(self, indices: List[int], priorities: np.ndarray):
        """Update priorities of sampled transitions."""
        assert len(indices) == len(priorities)

        for idx, priority in zip(indices, priorities):
            assert priority > 0
            assert 0 <= idx < len(self)

            self.sum_tree[idx] = priority**self.alpha
            self.min_tree[idx] = priority**self.alpha

            self.max_priority = max(self.max_priority, priority)

    def _sample_proportional(self) -> List[int]:
        """Sample indices based on proportions."""
        indices = []
        p_total = self.sum_tree.sum(0, len(self) - 1)
        segment = p_total / self.batch_size

        for i in range(self.batch_size):
            a = segment * i
            b = segment * (i + 1)
            upperbound = random.uniform(a, b)
            idx = self.sum_tree.retrieve(upperbound)
            indices.append(idx)

        return indices

    def _calculate_weight(self, idx: int, beta: float):
        """Calculate the weight of the experience at idx."""
        # get max weight
        p_min = self.min_tree.min() / self.sum_tree.sum()
        max_weight = (p_min * len(self))**(-beta)

        # calculate weights
        p_sample = self.sum_tree[idx] / self.sum_tree.sum()
        weight = (p_sample * len(self))**(-beta)
        weight = weight / max_weight

        return weight
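
For context, here is a minimal sketch (not part of the original example) of how this buffer is typically driven. The environment and learner are replaced by placeholder random data, and it assumes the base ReplayBuffer exposes __len__ and batch_size, as the assertions in sample_batch imply.

import numpy as np

obs_dim = 4
buffer = PrioritizedReplayBuffer(obs_dim=obs_dim, size=2000, batch_size=32, alpha=0.6)

beta, beta_increment = 0.4, 1e-4           # anneal beta toward 1 over training
obs = np.zeros(obs_dim, dtype=np.float32)  # placeholder observation

for frame in range(10_000):
    # placeholder transition; a real agent would act in an environment here
    act = np.random.randint(2)
    next_obs = np.random.randn(obs_dim).astype(np.float32)
    buffer.store(obs, act, rew=float(np.random.rand()), next_obs=next_obs, done=False)
    obs = next_obs

    if len(buffer) >= buffer.batch_size:
        beta = min(1.0, beta + beta_increment)
        batch = buffer.sample_batch(beta)
        # the learner would weight its loss by batch["weights"] and return
        # per-sample TD errors; random positive values stand in for them here
        td_errors = np.random.rand(len(batch["indices"]))
        buffer.update_priorities(batch["indices"], td_errors + 1e-6)

Annealing beta toward 1 over training is what restores unbiased updates, as prescribed in the prioritized experience replay paper.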
Code Example #2
import numpy as np

# ReplayBuffer, SumSegmentTree, and MinSegmentTree are assumed to be defined
# elsewhere in the surrounding project.


class PrioritizedReplayBuffer(ReplayBuffer):
    def __init__(self, buffer_size, input_dim, batch_size, alpha):

        super(PrioritizedReplayBuffer, self).__init__(buffer_size, input_dim,
                                                      batch_size)

        # For PER. Parameter settings.
        self.max_priority, self.tree_ptr = 1.0, 0
        self.alpha = alpha

        tree_capacity = 1
        while tree_capacity < self.buffer_size:
            tree_capacity *= 2

        self.sum_tree = SumSegmentTree(tree_capacity)
        self.min_tree = MinSegmentTree(tree_capacity)

    def store(self, state: np.ndarray, action: int, reward: float,
              next_state: np.ndarray, done: int):

        super().store(state, action, reward, next_state, done)

        self.sum_tree[self.tree_ptr] = self.max_priority**self.alpha
        self.min_tree[self.tree_ptr] = self.max_priority**self.alpha
        self.tree_ptr = (self.tree_ptr + 1) % self.buffer_size

    def batch_load(self, beta):

        # Fetching the indices could be parallelized, and the weights could be computed in the same function.
        indices = self._sample_proportional_indices()

        weights = np.array(
            [self._calculate_weight(idx, beta) for idx in indices])

        return dict(states=self.state_buffer[indices],
                    actions=self.action_buffer[indices],
                    rewards=self.reward_buffer[indices],
                    next_states=self.next_state_buffer[indices],
                    dones=self.done_buffer[indices],
                    weights=weights,
                    indices=indices)

    def update_priorities(self, indices, priorities):

        # This loop could also be parallelized.
        for idx, priority in zip(indices, priorities):

            self.sum_tree[idx] = priority**self.alpha
            self.min_tree[idx] = priority**self.alpha

            self.max_priority = max(self.max_priority, priority)

    def _sample_proportional_indices(self):

        indices = []
        p_total = self.sum_tree.sum(0, len(self) - 1)
        segment = p_total / self.batch_size

        # Could be parallelized, e.g. with multiprocessing.
        for i in range(self.batch_size):
            a = segment * i
            b = segment * (i + 1)
            sample = np.random.uniform(a, b)
            idx = self.sum_tree.retrieve(sample)  # returns the tree index for the sampled value
            indices.append(idx)

        return indices

    def _calculate_weight(self, idx, beta):

        # This max-weight term only needs to be computed once per batch.
        p_min = self.min_tree.min() / self.sum_tree.sum()
        max_weight = (p_min * len(self))**(-beta)

        p_sample = self.sum_tree[idx] / self.sum_tree.sum()
        weight = (p_sample * len(self))**(-beta)
        weight /= max_weight
        return weight
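
The comments above suggest batching the weight computation. Below is a hedged sketch of that idea (not part of the original code): the max-weight normalizer is computed once per batch and the (-beta) exponentiation is vectorized with NumPy. It assumes the same sum_tree / min_tree attributes and the base buffer's __len__ used in the class above.

import numpy as np

def batch_weights(buffer, indices, beta):
    n = len(buffer)
    p_total = buffer.sum_tree.sum()

    # max-weight normalizer, computed once per batch
    p_min = buffer.min_tree.min() / p_total
    max_weight = (p_min * n) ** (-beta)

    # per-sample probabilities (the tree lookup itself stays per index)
    p_samples = np.array([buffer.sum_tree[idx] for idx in indices]) / p_total
    return (p_samples * n) ** (-beta) / max_weight

With a helper like this, batch_load could build all weights in one call instead of invoking _calculate_weight once per index.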
Code Example #3
import random
from collections import namedtuple

import numpy as np
import torch

# SumSegmentTree, MinSegmentTree, and the torch `device` are assumed to be
# defined elsewhere in the surrounding project.


class ReplayBuffer:
    """Fixed-size buffer to store experience tuples."""
    def __init__(self, action_size, buffer_size, batch_size, alpha):
        """Initialize a ReplayBuffer object.

        Params
        ======
            action_size (int): dimension of each action
            buffer_size (int): maximum size of buffer
            batch_size (int): size of each training batch
            alpha (float): prioritization exponent (alpha) for PER
        """
        self.max_priority = 1.0
        self.alpha = alpha

        # capacity must be positive and a power of 2.
        self.tree_capacity = 1
        while self.tree_capacity < buffer_size:
            self.tree_capacity *= 2

        self.sum_tree = SumSegmentTree(self.tree_capacity)
        self.min_tree = MinSegmentTree(self.tree_capacity)

        self.action_size = action_size
        self.memory = []
        self.batch_size = batch_size
        self.experience = namedtuple(
            "Experience",
            field_names=["state", "action", "reward", "next_state", "done"])

    def add(self, t, state, action, reward, next_state, done):
        """Add a new experience to memory."""
        e = self.experience(state, action, reward, next_state, done)

        idx = t % self.tree_capacity
        if t >= self.tree_capacity:
            self.memory[idx] = e
        else:
            self.memory.append(e)

        # insert experience index in priority tree
        self.sum_tree[idx] = self.max_priority**self.alpha
        self.min_tree[idx] = self.max_priority**self.alpha

    def sample(self, beta):
        """Sampling a batch of relevant experiences from memory."""
        indices = self.relevant_sample_indx()

        idxs = np.vstack(indices).astype(int)  # np.int was removed from NumPy; use the builtin int
        states = torch.from_numpy(
            np.vstack([self.memory[i].state
                       for i in indices])).float().to(device)
        actions = torch.from_numpy(
            np.vstack([self.memory[i].action
                       for i in indices])).long().to(device)
        rewards = torch.from_numpy(
            np.vstack([self.memory[i].reward
                       for i in indices])).float().to(device)
        next_states = torch.from_numpy(
            np.vstack([self.memory[i].next_state
                       for i in indices])).float().to(device)
        dones = torch.from_numpy(
            np.vstack([self.memory[i].done
                       for i in indices]).astype(np.uint8)).float().to(device)
        weights = torch.from_numpy(
            np.array([self.isw(i, beta) for i in indices])).float().to(device)

        return (idxs, states, actions, rewards, next_states, dones, weights)

    def relevant_sample_indx(self):
        """Selecting most informative sample indices."""
        indices = []
        p_total = self.sum_tree.sum(0, len(self) - 1)
        segment = p_total / self.batch_size

        for i in range(self.batch_size):
            a = segment * i
            b = segment * (i + 1)
            upperbound = random.uniform(a, b)
            idx = self.sum_tree.retrieve(upperbound)
            indices.append(idx)

        return indices

    def update_priorities(self, indices, priorities):
        """Update priorities of sampled transitions."""
        assert indices.shape[0] == priorities.shape[0]

        for idx, priority in zip(indices.flatten(), priorities.flatten()):
            assert priority > 0
            assert 0 <= idx < len(self)

            self.sum_tree[idx] = priority**self.alpha
            self.min_tree[idx] = priority**self.alpha

            self.max_priority = max(self.max_priority, priority)

    def isw(self, idx, beta):
        """Compute Importance Sample Weight."""
        # get max weight
        p_min = self.min_tree.min() / self.sum_tree.sum()
        max_weight = (p_min * len(self))**(-beta)

        # calculate weights
        p_sample = self.sum_tree[idx] / self.sum_tree.sum()
        weight = (p_sample * len(self))**(-beta)
        is_weight = weight / max_weight

        return is_weight

    def __len__(self):
        """Return the current size of internal memory."""
        return len(self.memory)
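
All of the examples above (and the one below) call into SumSegmentTree and MinSegmentTree without showing them. The following is a minimal sketch, modeled on the commonly used array-backed segment tree, of the interface these classes are assumed to expose (item assignment, sum, min, retrieve); it is an assumption for illustration, not the original authors' implementation.

import operator


class SegmentTree:
    """Array-backed segment tree; `capacity` must be a power of two."""

    def __init__(self, capacity, operation, init_value):
        self.capacity = capacity
        self.operation = operation
        self.tree = [init_value for _ in range(2 * capacity)]

    def _operate(self, start, end, node, node_start, node_end):
        # Reduce `operation` over the inclusive leaf range [start, end].
        if start == node_start and end == node_end:
            return self.tree[node]
        mid = (node_start + node_end) // 2
        if end <= mid:
            return self._operate(start, end, 2 * node, node_start, mid)
        if start >= mid + 1:
            return self._operate(start, end, 2 * node + 1, mid + 1, node_end)
        return self.operation(
            self._operate(start, mid, 2 * node, node_start, mid),
            self._operate(mid + 1, end, 2 * node + 1, mid + 1, node_end),
        )

    def operate(self, start=0, end=None):
        # `end` is exclusive, matching the sum(0, len(self) - 1) calls above.
        if end is None:
            end = self.capacity
        if end <= 0:
            end += self.capacity
        return self._operate(start, end - 1, 1, 0, self.capacity - 1)

    def __setitem__(self, idx, val):
        # Update a leaf, then recompute its ancestors up to the root.
        idx += self.capacity
        self.tree[idx] = val
        idx //= 2
        while idx >= 1:
            self.tree[idx] = self.operation(self.tree[2 * idx], self.tree[2 * idx + 1])
            idx //= 2

    def __getitem__(self, idx):
        return self.tree[self.capacity + idx]


class SumSegmentTree(SegmentTree):
    def __init__(self, capacity):
        super().__init__(capacity, operator.add, 0.0)

    def sum(self, start=0, end=None):
        return self.operate(start, end)

    def retrieve(self, upperbound):
        # Walk from the root to the leaf whose cumulative-priority bracket
        # contains `upperbound`; this is what makes proportional sampling O(log n).
        idx = 1
        while idx < self.capacity:  # descend until a leaf is reached
            left = 2 * idx
            if self.tree[left] > upperbound:
                idx = left
            else:
                upperbound -= self.tree[left]
                idx = left + 1
        return idx - self.capacity


class MinSegmentTree(SegmentTree):
    def __init__(self, capacity):
        super().__init__(capacity, min, float("inf"))

    def min(self, start=0, end=None):
        return self.operate(start, end)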
Code Example #4
import random

import numpy as np
import torch

# ReplayBuffer, SumSegmentTree, MinSegmentTree, and the torch `device` are
# assumed to be defined elsewhere in the surrounding project.


class PrioritizedReplayBuffer(ReplayBuffer):
    def __init__(self, action_size, buffer_size, batch_size, seed, alpha=0.6):
        super(PrioritizedReplayBuffer, self).__init__(action_size, buffer_size, batch_size, seed)
        
        #capacity must be positive and a power of 2
        tree_capacity = 1
        while tree_capacity < self.buffer_size:
            tree_capacity *= 2
        self.sum_tree = SumSegmentTree(tree_capacity)
        self.min_tree = MinSegmentTree(tree_capacity)
        self.max_priority, self.tree_ptr = 1.0, 0
        self.alpha = alpha
        
    def add(self, state, action, reward, next_state, done):
        
        self.sum_tree[self.tree_ptr] = self.max_priority**self.alpha
        self.min_tree[self.tree_ptr] = self.max_priority**self.alpha
        super().add(state, action, reward, next_state, done)
        self.tree_ptr = (self.tree_ptr + 1) % self.buffer_size
        
        # Earlier variant, kept commented out: when the buffer wraps, shift every
        # priority left by one slot instead of overwriting the slot at tree_ptr.
        # if self.tree_ptr == self.buffer_size - 1:
        #     for i in range(0, self.buffer_size - 1):
        #         self.sum_tree[i] = self.sum_tree[i + 1]
        #         self.min_tree[i] = self.min_tree[i + 1]
        #     self.sum_tree[self.tree_ptr] = self.max_priority ** self.alpha
        #     self.min_tree[self.tree_ptr] = self.max_priority ** self.alpha

    def sample(self, beta=0.4):
        indices = self._sample_proportional()
        
        # drop tree leaves that point past the currently filled part of memory
        indices = [index for index in indices if index < len(self.memory)]
        states = torch.from_numpy(np.vstack([self.memory[index].state for index in indices])).float().to(device)
        actions = torch.from_numpy(np.vstack([self.memory[index].action for index in indices])).long().to(device)
        rewards = torch.from_numpy(np.vstack([self.memory[index].reward for index in indices])).float().to(device)
        next_states = torch.from_numpy(np.vstack([self.memory[index].next_state for index in indices])).float().to(device)
        dones = torch.from_numpy(np.vstack([self.memory[index].done for index in indices]).astype(np.uint8)).float().to(device)
        weights = torch.from_numpy(np.vstack([self._cal_weight(index, beta) for index in indices])).float().to(device)
         
        return (states, actions, rewards, next_states, dones, weights, indices)
        
    def update_priority(self, indices, loss_for_prior):
        for idx, priority in zip(indices, loss_for_prior):
            self.sum_tree[idx] = priority ** self.alpha
            self.min_tree[idx] = priority ** self.alpha
            
            self.max_priority = max(self.max_priority, priority)
        
    def _sample_proportional(self):
        indices = []
        p_total = self.sum_tree.sum() #sum(0, len(self.memory)-1)
        segment = p_total / self.batch_size
        
        for i in range(self.batch_size):
            start = segment * i
            end = start + segment
            upper = random.uniform(start, end)
            index = self.sum_tree.retrieve(upper)
            indices.append(index)
        return indices
    
    def _cal_weight(self, index, beta):
        sum_priority = self.sum_tree.sum()
        min_priority = self.min_tree.min()
        current_priority = self.sum_tree[index]

        # Full form, kept for reference:
        # max_w = (len(self.memory) * (min_priority / sum_priority)) ** (-beta)
        # current_w = (len(self.memory) * (current_priority / sum_priority)) ** (-beta)
        # return current_w / max_w
        return (min_priority / current_priority) ** beta
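
The commented-out weight in _cal_weight and the simplified return value are the same quantity: with p = priority / sum_priority, the ratio (N * p_i)^(-beta) / (N * p_min)^(-beta) cancels both N = len(self.memory) and sum_priority, leaving (min_priority / current_priority)^beta. A quick numeric check of that identity, using made-up values rather than anything from the original code:

import numpy as np

N, beta = 1000, 0.4
sum_p, min_p, cur_p = 50.0, 0.01, 0.3

full = ((N * cur_p / sum_p) ** -beta) / ((N * min_p / sum_p) ** -beta)
short = (min_p / cur_p) ** beta
assert np.isclose(full, short)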