Example #1
 def __init__(self,
              limit,
              ob_shape,
              ac_shape,
              alpha,
              beta,
              demos_eps=0.1,
              ranked=False,
              max_priority=1.0):
     """`alpha` determines how much prioritization is used
     0: none, equivalent to uniform sampling
     1: full prioritization
     `beta` determines to what degree importance sampling weights are used
     (0: no correction, 1: full correction).
     """
     super(PrioritizedReplayBuffer, self).__init__(limit, ob_shape,
                                                   ac_shape)
     assert 0. <= alpha <= 1.
     assert beta > 0, "beta must be positive"
     self.alpha = alpha
     self.beta = beta
     self.max_priority = max_priority
     # Calculate the segment tree capacity suited to the user-specified limit
     self.st_cap = segment_tree_capacity(limit)
     # Create segment tree objects as the data structure holding the priorities.
     # They provide an efficient way of computing a cumulative sum of priorities
     self.sum_st = SumSegmentTree(
         self.st_cap)  # with `operator.add` operation
     self.min_st = MinSegmentTree(self.st_cap)  # with `min` operation
     # Whether it is the ranked version or not
     self.ranked = ranked
     if self.ranked:
         # Create a dict that will contain all the (index, priority) pairs
         self.i_p = {}
     # Define the priority bonus assigned to demos stored in the replay buffer (if stored)
     self.demos_eps = demos_eps
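
The `alpha`/`beta` semantics in the docstring above follow the PER paper: priorities are raised to the power `alpha` before being normalized into sampling probabilities, and `beta` scales the importance-sampling correction. A minimal standalone sketch of those two formulas (NumPy only, toy numbers, not part of the listed project):

import numpy as np

priorities = np.array([1.0, 3.0, 2.0])   # hypothetical toy priorities
alpha, beta = 0.6, 0.4
N = len(priorities)

# P(i) = p_i**alpha / sum_k p_k**alpha   (alpha=0 -> uniform, alpha=1 -> fully proportional)
probs = priorities**alpha / np.sum(priorities**alpha)

# Importance-sampling weights w_i = (N * P(i))**(-beta), normalized by the largest weight
weights = (N * probs)**(-beta)
weights /= weights.max()

print(probs, weights)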
Example #2
 def __init__(self, limit, ob_shape, ac_shape, max_priority=1.0):
     """Reuse of the 'PrioritizedReplayBuffer' constructor w/:
         - `alpha` arbitrarily set to 1. (unused)
         - `beta` arbitrarily set to 1. (unused)
         - `ranked` set to True (necessary to have access to the ranks)
     """
     super(UnrealReplayBuffer, self).__init__(limit, ob_shape, ac_shape,
                                              1., 1., ranked=True,
                                              max_priority=max_priority)
     # Create two extra `SumSegmentTree` objects: one for 'bad' transitions, one for 'good' ones
     self.b_sum_st = SumSegmentTree(
         self.st_cap)  # with `operator.add` operation
     self.g_sum_st = SumSegmentTree(
         self.st_cap)  # with `operator.add` operation
Example #3
    def __init__(self, max_steps, num_processes, gamma, prio_alpha, obs_shape,
                 action_space, recurrent_hidden_state_size, device):
        self.max_steps = int(max_steps)
        self.num_processes = num_processes
        self.gamma = gamma
        self.device = device

        # stored episode data
        self.obs = torch.zeros(self.max_steps, *obs_shape)
        self.recurrent_hidden_states = torch.zeros(
            self.max_steps, recurrent_hidden_state_size)
        self.returns = torch.zeros(self.max_steps, 1)
        if action_space.__class__.__name__ == 'Discrete':
            self.actions = torch.zeros(self.max_steps, 1).long()
        else:
            self.actions = torch.zeros(self.max_steps, action_space.shape[0])
        self.masks = torch.ones(self.max_steps, 1)
        self.next_idx = 0
        self.num_steps = 0

        # store (full) episode stats
        self.episode_step_count = 0
        self.episode_rewards = deque()
        self.episode_steps = deque()

        # currently running (accumulating) episodes
        self.running_episodes = [[] for _ in range(num_processes)]

        if prio_alpha > 0:
            """
            Sampling priority is enabled if prio_alpha > 0
            Priority algorithm ripped from OpenAI Baselines
            https://github.com/openai/baselines/blob/master/baselines/deepq/replay_buffer.py
            """
            self.prio_alpha = prio_alpha
            tree_capacity = 1 << math.ceil(math.log2(self.max_steps))
            self.prio_sum_tree = SumSegmentTree(tree_capacity)
            self.prio_min_tree = MinSegmentTree(tree_capacity)
            self.prio_max = 1.0
        else:
            self.prio_alpha = 0
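
The segment trees referenced above expect a power-of-two capacity (so the binary tree is complete), which is what the `tree_capacity` expression computes. A quick standalone check of that expression, for illustration only:

import math

def next_pow2(n):
    # Mirrors `1 << math.ceil(math.log2(n))` from the snippet above
    return 1 << math.ceil(math.log2(n))

assert next_pow2(5) == 8
assert next_pow2(8) == 8        # exact powers of two are left unchanged
assert next_pow2(1000) == 1024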
Example #4
# Imports assumed by this class; SumSegmentTree/MinSegmentTree are taken here from
# OpenAI Baselines (see the link inside the class), though the original project may
# ship its own copies.
import math
import random
from collections import deque

import numpy as np
import torch
from baselines.common.segment_tree import MinSegmentTree, SumSegmentTree


class ReplayStorage:
    def __init__(self, max_steps, num_processes, gamma, prio_alpha, obs_shape,
                 action_space, recurrent_hidden_state_size, device):
        self.max_steps = int(max_steps)
        self.num_processes = num_processes
        self.gamma = gamma
        self.device = device

        # stored episode data
        self.obs = torch.zeros(self.max_steps, *obs_shape)
        self.recurrent_hidden_states = torch.zeros(
            self.max_steps, recurrent_hidden_state_size)
        self.returns = torch.zeros(self.max_steps, 1)
        if action_space.__class__.__name__ == 'Discrete':
            self.actions = torch.zeros(self.max_steps, 1).long()
        else:
            self.actions = torch.zeros(self.max_steps, action_space.shape[0])
        self.masks = torch.ones(self.max_steps, 1)
        self.next_idx = 0
        self.num_steps = 0

        # store (full) episode stats
        self.episode_step_count = 0
        self.episode_rewards = deque()
        self.episode_steps = deque()

        # currently running (accumulating) episodes
        self.running_episodes = [[] for _ in range(num_processes)]

        if prio_alpha > 0:
            """
            Sampling priority is enabled if prio_alpha > 0
            Priority algorithm ripped from OpenAI Baselines
            https://github.com/openai/baselines/blob/master/baselines/deepq/replay_buffer.py
            """
            self.prio_alpha = prio_alpha
            tree_capacity = 1 << math.ceil(math.log2(self.max_steps))
            self.prio_sum_tree = SumSegmentTree(tree_capacity)
            self.prio_min_tree = MinSegmentTree(tree_capacity)
            self.prio_max = 1.0
        else:
            self.prio_alpha = 0

    def _process_rewards(self, trajectory):
        has_positive = False
        reward_sum = 0.
        r = 0.
        for t in trajectory[::-1]:
            reward = t['reward']
            reward_sum += reward
            if reward > (0. + 1e-5):
                has_positive = True
            r = reward + self.gamma * r
            t['return'] = r
        return has_positive, reward_sum

    def _add_trajectory(self, trajectory):
        has_positive, reward_sum = self._process_rewards(trajectory)
        if not has_positive:
            return
        trajectory_len = len(trajectory)
        prev_idx = self.next_idx
        for transition in trajectory:
            self.obs[self.next_idx].copy_(transition['obs'])
            self.recurrent_hidden_states[self.next_idx].copy_(
                transition['rhs'])
            self.actions[self.next_idx].copy_(transition['action'])
            self.returns[self.next_idx].copy_(transition['return'])
            self.masks[self.next_idx] = 1.0
            prev_idx = self.next_idx
            if self.prio_alpha:
                self.prio_sum_tree[
                    self.next_idx] = self.prio_max**self.prio_alpha
                self.prio_min_tree[
                    self.next_idx] = self.prio_max**self.prio_alpha
            self.next_idx = (self.next_idx + 1) % self.max_steps
            self.num_steps = min(self.max_steps, self.num_steps + 1)
        self.masks[prev_idx] = 0.0

        # update stats of stored full trajectories (episodes)
        while self.episode_step_count + trajectory_len > self.max_steps:
            steps_popped = self.episode_steps.popleft()
            self.episode_rewards.popleft()
            self.episode_step_count -= steps_popped
        self.episode_step_count += trajectory_len
        self.episode_steps.append(trajectory_len)
        self.episode_rewards.append(reward_sum)

    def _sample_proportional(self, sample_size):
        res = []
        for _ in range(sample_size):
            mass = random.random() * self.prio_sum_tree.sum(
                0, self.num_steps - 1)
            idx = self.prio_sum_tree.find_prefixsum_idx(mass)
            res.append(idx)
        return res

    def insert(self, obs, rhs, actions, rewards, dones):
        for n in range(self.num_processes):
            self.running_episodes[n].append(
                dict(obs=obs[n].clone(),
                     rhs=rhs[n].clone(),
                     action=actions[n].clone(),
                     reward=rewards[n].clone()))
        for n, done in enumerate(dones):
            if done:
                self._add_trajectory(self.running_episodes[n])
                self.running_episodes[n] = []

    def update_priorities(self, indices, priorities):
        """Update priorities of sampled transitions.

        Sets the priority of the transition at index indices[i] in the buffer
        to priorities[i].

        Parameters
        ----------
        indices: [int]
            List of indices of sampled transitions
        priorities: [float]
            List of updated priorities corresponding to
            transitions at the sampled indices.
        """
        if not self.prio_alpha:
            return
        assert len(indices) == len(priorities)
        for idx, priority in zip(indices, priorities):
            priority = max(priority, 1e-6)
            assert priority > 0
            assert 0 <= idx < self.num_steps
            self.prio_sum_tree[idx] = priority**self.prio_alpha
            self.prio_min_tree[idx] = priority**self.prio_alpha

            self.prio_max = max(self.prio_max, priority)

    def feed_forward_generator(self, batch_size, num_batches=None, beta=0.):
        """Generate batches of sampled experiences.

        Parameters
        ----------
        batch_size: int
            Size of each sampled batch
        num_batches: int
            Number of batches to sample
        beta: float
            To what degree to use importance weights
            (0 - no corrections, 1 - full correction)
        """

        batch_count = 0
        sample_size = num_batches * batch_size if num_batches else self.num_steps

        if self.prio_alpha > 0:
            indices = self._sample_proportional(sample_size)
            if beta > 0:
                # compute importance sampling weights to correct for the
                # bias introduced by sampling in a non-uniform manner
                weights = []
                p_min = self.prio_min_tree.min() / self.prio_sum_tree.sum()
                max_weight = (p_min * self.num_steps)**(-beta)
                for i in indices:
                    p_sample = self.prio_sum_tree[i] / self.prio_sum_tree.sum()
                    weight = (p_sample * self.num_steps)**(-beta)
                    weights.append(weight / max_weight)
                weights = torch.tensor(weights,
                                       dtype=torch.float32).unsqueeze(1)
            else:
                weights = torch.ones((len(indices), 1), dtype=torch.float32)
        else:
            if sample_size * 3 < self.num_steps:
                indices = random.sample(range(self.num_steps), sample_size)
            else:
                indices = np.random.permutation(self.num_steps)[:sample_size]
            weights = None

        for si in range(0, len(indices), batch_size):
            indices_batch = indices[si:min(len(indices), si + batch_size)]
            if len(indices_batch) < batch_size:
                return

            weights_batch = None if weights is None else \
                weights[si:min(len(indices), si + batch_size)].to(self.device)

            obs_batch = self.obs[indices_batch].to(self.device)
            recurrent_hidden_states_batch = self.recurrent_hidden_states[
                indices_batch].to(self.device)
            actions_batch = self.actions[indices_batch].to(self.device)
            returns_batch = self.returns[indices_batch].to(self.device)
            masks_batch = self.masks[indices_batch].to(self.device)

            yield obs_batch, recurrent_hidden_states_batch, actions_batch, returns_batch, \
                  masks_batch, weights_batch, indices_batch

            batch_count += 1
            if num_batches and batch_count >= num_batches:
                return
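
`_process_rewards` above fills in discounted returns by walking the trajectory backwards (r_t = reward_t + gamma * r_{t+1}) and flags whether any reward was positive, which `_add_trajectory` then uses to keep only rewarding episodes. A standalone sketch of that backward pass on made-up data (only the dict keys match the snippet):

gamma = 0.9
trajectory = [{'reward': 0.0}, {'reward': 0.0}, {'reward': 1.0}]  # toy rewards

r = 0.0
for t in reversed(trajectory):
    r = t['reward'] + gamma * r   # discounted return, bootstrapped from later steps
    t['return'] = r

print([round(t['return'], 4) for t in trajectory])   # [0.81, 0.9, 1.0]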
Example #5
class UnrealReplayBuffer(PrioritizedReplayBuffer):
    """'Reinforcement Learning w/ unsupervised Auxiliary Tasks' replay buffer implementation
    Reference: https://arxiv.org/pdf/1611.05397.pdf
    """
    def __init__(self, limit, ob_shape, ac_shape, max_priority=1.0):
        """Reuse of the 'PrioritizedReplayBuffer' constructor w/:
            - `alpha` arbitrarily set to 1. (unused)
            - `beta` arbitrarily set to 1. (unused)
            - `ranked` set to True (necessary to have access to the ranks)
        """
        super(UnrealReplayBuffer, self).__init__(limit, ob_shape, ac_shape,
                                                 1., 1., ranked=True,
                                                 max_priority=max_priority)
        # Create two extra `SumSegmentTree` objects: one for 'bad' transitions, one for 'good' ones
        self.b_sum_st = SumSegmentTree(
            self.st_cap)  # with `operator.add` operation
        self.g_sum_st = SumSegmentTree(
            self.st_cap)  # with `operator.add` operation

    def _sample_unreal(self, batch_size):
        """Sample uniformly from which virtual sub-buffer to pick: bad or good transitions,
        then sample uniformly a transition from the previously picked virtual sub-buffer.
        Implemented w/ segment tree.
        Since `b_sum_st` and `g_sum_st` contain priorities in {0,1}, sampling according to
        priorities samples a transition uniformly from the transitions having priority 1
        in the previously sampled virtual sub-buffer (bad or good transitions).
        Note: the priorities were used (in `update_priorities`) to determine in which
        virtual sub-buffer a transition belongs.
        """
        factor = 5  # hyperparam?
        assert factor > 1
        assert self.num_entries
        transition_idxs = []
        # Sum the priorities (in {0,1}) of the transitions currently in the buffers
        # Since values are in {0,1}, the sums correspond to cardinalities (#b and #g)
        b_p_total = self.b_sum_st.sum(end=self.num_entries)
        g_p_total = self.g_sum_st.sum(end=self.num_entries)
        # `start` is 0 by default, `end` is length - 1 (classic python)
        # Ensure no repeats once the number of memory entries is high enough
        no_repeats = self.num_entries > factor * batch_size
        for _ in range(batch_size):
            while True:
                # Sample a value uniformly from the unit interval
                unit_u_sample = random.random()  # ~U[0,1]
                # Sample a number in {0,1} to decide whether to sample a b/g transition
                is_g = np.random.randint(2)
                p_total = g_p_total if is_g else b_p_total
                # Scale the sampled value to the total sum of priorities
                u_sample = unit_u_sample * p_total
                # Retrieve the transition index associated with `u_sample`
                # i.e. in which block the sample landed
                b_o_g_sum_st = self.g_sum_st if is_g else self.b_sum_st
                transition_idx = b_o_g_sum_st.find_prefixsum_idx(u_sample)
                if not no_repeats or transition_idx not in transition_idxs:
                    transition_idxs.append(transition_idx)
                    break
        return np.array(transition_idxs)

    def sample(self, batch_size):
        return super()._sample(batch_size, self._sample_unreal)

    def lookahead_sample(self, batch_size, n, gamma):
        return super().lookahead_sample(batch_size, n, gamma)

    def append(self, *args, **kwargs):
        super().append(*args, **kwargs)
        idx = self.latest_entry_idx
        # Add newly added elements to 'good' and 'bad' virtual sub-buffer
        self.b_sum_st[idx] = 1
        self.g_sum_st[idx] = 1

    def update_priorities(self, idxs, priorities):
        # Update priorities via the legacy method, used w/ ranked approach
        idxs, priorities = super().update_priorities(idxs, priorities)
        # Register whether a transition is b/g in the UNREAL-specific sum trees
        for idx, priority in zipsame(idxs, priorities):
            # Decide whether the transition to be added is good or bad
            # Get the rank from the priority
            # Note: UnrealReplayBuffer inherits from PER w/ 'ranked' set to True
            if idx < self.num_demos:
                # When the transition is from the demos, always set it as 'good' regardless
                self.b_sum_st[idx] = 0
                self.g_sum_st[idx] = 1
            else:
                rank = (1. / priority) - 1
                thres = floor(.5 * self.num_entries)
                is_g = rank < thres
                is_g *= 1  # HAXX: multiply by 1 to cast the bool into an int
                # Fill the good and bad sum segment trees w/ the obtained value
                self.b_sum_st[idx] = 1 - is_g
                self.g_sum_st[idx] = is_g
        if debug:
            # Verify updates
            # Compute the cardinalities of virtual sub-buffers
            b_num_entries = self.b_sum_st.sum(end=self.num_entries)
            g_num_entries = self.g_sum_st.sum(end=self.num_entries)
            print("[num entries]    b: {}    | g: {}".format(
                b_num_entries, g_num_entries))
            print("total num entries: {}".format(self.num_entries))

    def __repr__(self):
        fmt = "UnrealReplayBuffer(limit={}, ob_shape={}, ac_shape={}, max_priority={})"
        return fmt.format(self.limit, self.ob_shape, self.ac_shape,
                          self.max_priority)
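
`_sample_unreal` above exploits the fact that proportional sampling over 0/1 'priorities' is just uniform sampling over the entries set to 1. A minimal emulation of that idea with a plain cumulative sum, where np.searchsorted stands in for `find_prefixsum_idx` (the project's tree classes are not used here):

import numpy as np

rng = np.random.default_rng(0)

# 1 marks membership in the 'good' virtual sub-buffer, 0 marks 'bad'
good_flags = np.array([0, 1, 1, 0, 1], dtype=float)

def sample_uniform_among_ones(flags):
    prefix = np.cumsum(flags)        # prefix sums, as a SumSegmentTree would provide
    u = rng.random() * prefix[-1]    # height drawn uniformly in [0, total)
    return int(np.searchsorted(prefix, u, side='right'))

counts = np.zeros(len(good_flags))
for _ in range(10000):
    counts[sample_uniform_among_ones(good_flags)] += 1

# Only indices flagged 1 are ever drawn, each about a third of the time
print(counts / counts.sum())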
Example #6
class PrioritizedReplayBuffer(ReplayBuffer):
    """'Prioritized Experience Replay' replay buffer implementation
    Reference: https://arxiv.org/pdf/1511.05952.pdf
    """
    def __init__(self,
                 limit,
                 ob_shape,
                 ac_shape,
                 alpha,
                 beta,
                 demos_eps=0.1,
                 ranked=False,
                 max_priority=1.0):
        """`alpha` determines how much prioritization is used
        0: none, equivalent to uniform sampling
        1: full prioritization
        `beta` determines to what degree importance sampling weights are used
        (0: no correction, 1: full correction).
        """
        super(PrioritizedReplayBuffer, self).__init__(limit, ob_shape,
                                                      ac_shape)
        assert 0. <= alpha <= 1.
        assert beta > 0, "beta must be positive"
        self.alpha = alpha
        self.beta = beta
        self.max_priority = max_priority
        # Calculate the segment tree capacity suited to the user-specified limit
        self.st_cap = segment_tree_capacity(limit)
        # Create segment tree objects as the data structure holding the priorities.
        # They provide an efficient way of computing a cumulative sum of priorities
        self.sum_st = SumSegmentTree(
            self.st_cap)  # with `operator.add` operation
        self.min_st = MinSegmentTree(self.st_cap)  # with `min` operation
        # Whether it is the ranked version or not
        self.ranked = ranked
        if self.ranked:
            # Create a dict that will contain all the (index, priority) pairs
            self.i_p = {}
        # Define the priority bonus assigned to demos stored in the replay buffer (if stored)
        self.demos_eps = demos_eps

    def _sample_w_priorities(self, batch_size):
        """Sample in proportion to priorities, implemented w/ segment tree.
        This function samples a batch of transitions indices, directly from the priorities.
        Segment trees enable the emulation of a categorical sampling process by
        relying on what they shine at: computing cumulative sums.
        Imagine a stack of blocks.
        Each block (transition) has a height equal to its priority.
        The total height therefore is the sum of all priorities, `p_total`.
        `u_sample` is sampled from a U[0,1]. `u_sample * p_total` consequently is
        a value uniformly sampled from U[0,p_total], a height on the stacked blocks.
        `find_prefixsum_idx` returns the highest index (block id) such that the sum of
        preceding priorities (block heights) is <= the uniformly sampled height.
        The process is equivalent to sampling from a categorical distribution over
        the transitions (it might even be how some libraries implement categorical sampling).
        Since the height is sampled uniformly, the prob of landing in a block is proportional
        to the height of said block. The height being the priority value, the higher the
        priority, the higher the prob of being selected.
        """
        assert self.num_entries

        transition_idxs = []
        # Sum the priorities of the transitions currently in the buffer
        p_total = self.sum_st.sum(end=self.num_entries - 1)
        # `start` is 0 by default, `end` is length - 1
        # Divide equally into `batch_size` ranges (appendix B.2.1)
        p_pieces = p_total / batch_size
        # Sample `batch_size` samples independently, each from within the associated range
        # which is referred to as 'stratified sampling'
        for i in range(batch_size):
            # Sample a value uniformly from the unit interval
            unit_u_sample = random.random()  # ~U[0,1]
            # Scale and shift the sampled value to be within the range of cumulative priorities
            u_sample = (unit_u_sample * p_pieces) + (i * p_pieces)
            # Retrieve the transition index associated with `u_sample`
            # i.e. in which block the sample landed
            transition_idx = self.sum_st.find_prefixsum_idx(u_sample)
            transition_idxs.append(transition_idx)
        return np.array(transition_idxs)

    def _sample(self, batch_size, sampling_fn):
        """Sample from the replay buffer according to assigned priorities
        while using importance weights to offset the biasing effect of non-uniform sampling.
        `beta` (defined in `__init__`) represents to what degree importance weights are used.
        """
        # Sample transition idxs according to the sampling function
        idxs = sampling_fn(batch_size=batch_size)

        # Initialize importance weights
        iws = []
        # Create var for lowest sampling prob among transitions currently in the buffer,
        # equal to lowest priority divided by the sum of all priorities
        lowest_prob = self.min_st.min(end=self.num_entries) / self.sum_st.sum(
            end=self.num_entries)
        # Create a max weight var for weight scaling purposes (eq. in Sec. 3.4 of the PER paper)
        max_weight = (self.num_entries * lowest_prob)**(-self.beta)

        # Create a weight for every selected transition
        for idx in idxs:
            # Compute the probability assigned to the transition
            prob_transition = self.sum_st[idx] / self.sum_st.sum(
                end=self.num_entries)
            # Compute the transition weight
            weight_transition = (self.num_entries *
                                 prob_transition)**(-self.beta)
            iws.append(weight_transition / max_weight)

        # Collect batch of transitions w/ iws and indices
        weighted_transitions = super().batchify(idxs)
        weighted_transitions['iws'] = array_min2d(np.array(iws))
        weighted_transitions['idxs'] = np.array(idxs)
        return weighted_transitions

    def sample(self, batch_size):
        return self._sample(batch_size, self._sample_w_priorities)

    def sample_uniform(self, batch_size):
        return super().sample(batch_size=batch_size)

    def lookahead_sample(self, batch_size, n, gamma):
        """Sample from the replay buffer according to assigned priorities.
        This function is for n-step TD backups, where n > 1
        """
        assert n > 1
        # Sample a batch of transitions
        transitions = self.sample(batch_size=batch_size)
        # Expand each transition w/ a n-step TD lookahead
        lookahead_batch = super().lookahead(transitions=transitions,
                                            n=n,
                                            gamma=gamma)
        # Add iws and indices to the dict
        lookahead_batch['iws'] = transitions['iws']
        lookahead_batch['idxs'] = transitions['idxs']
        return lookahead_batch

    def append(self, *args, **kwargs):
        super().append(*args, **kwargs)
        idx = self.latest_entry_idx
        # Assign highest priority value to newly added elements (line 6 alg PER paper)
        self.sum_st[idx] = self.max_priority**self.alpha
        self.min_st[idx] = self.max_priority**self.alpha

    def update_priorities(self, idxs, priorities):
        """Update priorities according to the PER paper, i.e. by updating
        only the priority of sampled transitions. A priority priorities[i] is
        assigned to the transition at index indices[i].
        Note: not in use in the vanilla setting, but here if needed in extensions.
        """
        global debug
        if self.ranked:
            # Override the priorities to be 1 / (rank(priority) + 1)
            # Add the new (index, priority) pairs to the dict
            self.i_p.update({i: p for i, p in zipsame(idxs, priorities)})
            # Rank the indices by priorities
            i_sorted_by_p = sorted(self.i_p.items(),
                                   key=lambda t: t[1],
                                   reverse=True)
            # Create the index, rank dict
            i_r = {i: i_sorted_by_p.index((i, p)) for i, p in self.i_p.items()}
            # Unpack indices and ranks
            _idxs, ranks = zipsame(*i_r.items())
            # Override the indices and priorities
            idxs = list(_idxs)
            priorities = [1. / (rank + 1)
                          for rank in ranks]  # start ranks at 1
            if debug:
                # Verify that the priorities have been properly overridden
                for idx, priority in zipsame(idxs, priorities):
                    print("index: {}    | priority: {}".format(idx, priority))

        assert len(idxs) == len(
            priorities), "the two arrays must be the same length"
        for idx, priority in zipsame(idxs, priorities):
            assert priority > 0, "priorities must be positive"
            assert 0 <= idx < self.num_entries, "no element in buffer associated w/ index"
            if idx < self.num_demos:
                # Add a priority bonus when replaying a demo
                priority += self.demos_eps
            self.sum_st[idx] = priority**self.alpha
            self.min_st[idx] = priority**self.alpha
            # Update max priority currently in the buffer
            self.max_priority = max(priority, self.max_priority)

        if self.ranked:
            # Return indices and associated overridden priorities
            # Note: returned values are only used in the UNREAL priority update function
            return idxs, priorities

    def __repr__(self):
        fmt = "PrioritizedReplayBuffer(limit={}, ob_shape={}, ac_shape={}, alpha={}, beta={}, "
        fmt += "demos_eps={}, ranked={}, max_priority={})"
        return fmt.format(self.limit, self.ob_shape, self.ac_shape, self.alpha,
                          self.beta, self.demos_eps, self.ranked,
                          self.max_priority)
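
To make the block-stacking analogy in `_sample_w_priorities` concrete, here is a standalone sketch of the stratified scheme it describes: the total priority mass is split into `batch_size` equal ranges and one height is drawn uniformly inside each range, with np.searchsorted standing in for `find_prefixsum_idx` (toy priorities, independent of the classes above):

import numpy as np

rng = np.random.default_rng(0)

priorities = np.array([1.0, 3.0, 2.0, 4.0])   # toy 'block heights'
prefix = np.cumsum(priorities)                # cumulative heights of the stacked blocks
p_total = prefix[-1]

batch_size = 3
p_piece = p_total / batch_size                # equal ranges, as in appendix B.2.1 of PER

idxs = []
for i in range(batch_size):
    u = i * p_piece + rng.random() * p_piece  # uniform draw inside the i-th range
    idxs.append(int(np.searchsorted(prefix, u, side='right')))

# Taller blocks (higher priorities) are hit proportionally more often
print(idxs)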