Example #1
def init_priority_tree(self):
    """Organized here for clean inheritance."""
    self.priority_tree = SumTree(
        T=self.T,
        B=self.B,
        off_backward=self.off_backward,
        off_forward=self.off_forward,
        default_value=self.default_priority**self.alpha,
    )
Example #2
def init_priority_tree(self):
    self.priority_tree = SumTree(
        T=self.T,
        B=self.B,
        off_backward=self.n_step_return,
        off_forward=0,
        default_value=1,
        enable_input_priorities=True,
        input_priority_shift=self.n_step_return - 1,
    )
Example #3
def init_priority_tree(self):
    off_backward = math.ceil((1 + self.off_backward + self.batch_T) /
        self.rnn_state_interval)  # +1 in case interval aligned? TODO: check
    self.priority_tree = SumTree(
        T=self.T // self.rnn_state_interval,
        B=self.B,
        off_backward=off_backward,
        off_forward=math.ceil(self.off_forward / self.rnn_state_interval),
        default_value=self.default_priority ** self.alpha,
    )
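A quick way to check what the interval-downsampled tree dimensions come out to is to plug in concrete numbers; the settings below are purely illustrative and not taken from any of the examples here.

import math

# Illustrative settings (not from the examples above).
T = 1000                 # buffer length along the time dimension
batch_T = 40             # length of each replayed sequence
off_backward = 1         # invalid region behind the cursor
off_forward = 1          # invalid region ahead of the cursor
rnn_state_interval = 10  # an rnn state is stored every 10 steps

# The tree only indexes time steps where an rnn state was stored, so its
# length and offsets are expressed in units of the interval.
tree_T = T // rnn_state_interval                                 # 100
tree_off_backward = math.ceil((1 + off_backward + batch_T)
                              / rnn_state_interval)              # ceil(42 / 10) = 5
tree_off_forward = math.ceil(off_forward / rnn_state_interval)   # ceil(1 / 10) = 1
print(tree_T, tree_off_backward, tree_off_forward)               # 100 5 1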
Example #4
class PrioritizedReplay(object):
    def __init__(self,
                 alpha=0.6,
                 beta=0.4,
                 default_priority=1,
                 unique=False,
                 **kwargs):
        super().__init__(**kwargs)
        save__init__args(locals())
        self.init_priority_tree()

    def init_priority_tree(self):
        """Organized here for clean inheritance."""
        self.priority_tree = SumTree(
            T=self.T,
            B=self.B,
            off_backward=self.off_backward,
            off_forward=self.off_forward,
            default_value=self.default_priority**self.alpha,
        )

    def set_beta(self, beta):
        self.beta = beta

    def append_samples(self, samples):
        T, idxs = super().append_samples(samples)
        self.priority_tree.advance(T)  # Progress priority_tree cursor.
        return T, idxs

    def sample_batch(self, batch_B):
        (T_idxs, B_idxs), priorities = self.priority_tree.sample(
            batch_B, unique=self.unique)
        batch = self.extract_batch(T_idxs, B_idxs)
        is_weights = (1. / (priorities + EPS))**self.beta  # Unnormalized.
        is_weights /= max(is_weights)  # Normalize.
        is_weights = torchify_buffer(is_weights).float()
        return SamplesFromReplayPri(*batch, is_weights=is_weights)

    def update_batch_priorities(self, priorities):
        priorities = numpify_buffer(priorities)
        self.priority_tree.update_batch_priorities(priorities**self.alpha)
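A hedged sketch of how these methods typically fit together in a training loop. Here replay_buffer stands for an instance of a concrete subclass (one that also provides extract_batch, T, B, and so on), while samples, progress, and compute_td_abs_errors are placeholders for whatever the surrounding training code supplies.

# Hypothetical usage sketch; all undefined names are placeholders.
replay_buffer.append_samples(samples)           # Write new transitions, advance the tree.
replay_buffer.set_beta(0.4 + 0.6 * progress)    # E.g. anneal beta toward 1 over training.

batch = replay_buffer.sample_batch(batch_B=32)  # Items drawn proportionally to priority.
td_abs_errors = compute_td_abs_errors(batch)    # Placeholder: |TD error| per sampled item.
# batch.is_weights would weight the loss to correct for the non-uniform sampling.
replay_buffer.update_batch_priorities(td_abs_errors)  # New priorities for the sampled items.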
Example #5
class PrioritizedSequenceReplay(object):

    def __init__(self, alpha=0.6, beta=0.4, default_priority=1, unique=False,
            **kwargs):
        """Fix the SampleFromReplay length here, so priority tree can
        track where not to sample (else would have to temporarily subtract
        from tree every time sampling)."""
        super().__init__(**kwargs)
        save__init__args(locals())
        assert self.batch_T is not None  # Must assign.
        self.init_priority_tree()

    def init_priority_tree(self):
        off_backward = math.ceil((1 + self.off_backward + self.batch_T) /
            self.rnn_state_interval)  # +1 in case interval aligned? TODO: check
        self.priority_tree = SumTree(
            T=self.T // self.rnn_state_interval,
            B=self.B,
            off_backward=off_backward,
            off_forward=math.ceil(self.off_forward / self.rnn_state_interval),
            default_value=self.default_priority ** self.alpha,
        )

    def set_beta(self, beta):
        self.beta = beta

    def append_samples(self, samples):
        t, rsi = self.t, self.rnn_state_interval
        T, idxs = super().append_samples(samples)
        if rsi <= 1:  # All or no rnn states stored.
            self.priority_tree.advance(T)
        else:  # Some rnn states stored.
            n = self.t // rsi - t // rsi
            if self.t < t:  # Wrapped.
                n += self.T // rsi
            self.priority_tree.advance(n)
        return T, idxs

    def sample_batch(self, batch_B):
        (tree_T_idxs, B_idxs), priorities = self.priority_tree.sample(
            batch_B, unique=self.unique)
        if self.rnn_state_interval > 1:
            T_idxs = tree_T_idxs * self.rnn_state_interval
        else:
            T_idxs = tree_T_idxs
        batch = self.extract_batch(T_idxs, B_idxs, self.batch_T)
        is_weights = (1. / priorities) ** self.beta
        is_weights /= max(is_weights)  # Normalize.
        is_weights = torchify_buffer(is_weights).float()
        return SamplesFromReplayPri(*batch, is_weights=is_weights)

    def update_batch_priorities(self, priorities):
        priorities = numpify_buffer(priorities)
        self.priority_tree.update_batch_priorities(priorities ** self.alpha)
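Under the interval scheme, append_samples advances the tree by the number of rnn-state slots the cursor crossed, not by the raw number of time steps written. A small numeric sketch with made-up cursor positions:

# Illustrative cursor arithmetic, mirroring append_samples above.
T_total = 1000   # buffer length (self.T)
rsi = 10         # rnn_state_interval
t_before = 95    # cursor before the write (hypothetical)
t_after = 117    # cursor after writing 22 new steps (hypothetical)

n = t_after // rsi - t_before // rsi   # 11 - 9 = 2 new rnn-state slots
if t_after < t_before:                 # cursor wrapped around the buffer
    n += T_total // rsi
print(n)  # 2 -> priority_tree.advance(2)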
Example #6
class RlWithUlPrioritizedReplayBuffer(BaseReplayBuffer):
    """Replay prioritized by empirical n-step returns: pri = 1 + alpha * return ** beta."""
    def __init__(self, example, size, B, replay_T, discount, n_step_return,
                 alpha, beta):
        self.T = T = math.ceil(size / B)
        self.B = B
        self.size = T * B
        self.t = 0  # cursor
        self.replay_T = replay_T
        self.discount = discount
        self.n_step_return = n_step_return
        self.alpha = alpha
        self.beta = beta
        # Note: ``async_`` is assumed to be set elsewhere (e.g. by a base class).
        self.samples = buffer_from_example(example, (T, B),
                                           share_memory=self.async_)
        if n_step_return > 1:
            self.samples_return_ = buffer_from_example(example.reward, (T, B))
            self.samples_done_n = buffer_from_example(example.done, (T, B))
        else:
            self.samples_return_ = self.samples.reward
            self.samples_done_n = self.samples.done
        self._buffer_full = False
        self.init_priority_tree()

    def append_samples(self, samples):
        T, B = get_leading_dims(samples, n_dim=2)
        assert B == self.B
        t = self.t
        if t + T > self.T:  # Wrap.
            idxs = np.arange(t, t + T) % self.T
        else:
            idxs = slice(t, t + T)
        self.samples[idxs] = samples
        new_returns = self.compute_returns(T)
        if not self._buffer_full and t + T >= self.T:
            self._buffer_full = True
        self.t = (t + T) % self.T
        priorities = 1 + self.alpha * new_returns**self.beta
        self.priority_tree.advance(T, priorities=priorities)
        return T, idxs

    def sample_batch(self, batch_B):
        T_idxs, B_idxs = self.sample_idxs(batch_B)
        return self.extract_batch(T_idxs, B_idxs, self.replay_T)

    def compute_returns(self, T):
        """Compute the n-step returns using the new rewards just written into
        the buffer, but before the buffer cursor is advanced.  Input ``T`` is
        the number of new timesteps which were just written.
        Does nothing if `n-step==1`. e.g. if 2-step return, t-1
        is first return written here, using reward at t-1 and new reward at t
        (up through t-1+T from t+T).]

        Use ABSOLUTE VALUE of rewards...it's all good signal for prioritization.
        """
        t, s, nm1 = self.t, self.samples, self.n_step_return - 1
        if self.n_step_return == 1:
            idxs = np.arange(t - nm1, t + T) % self.T
            return_ = np.abs(s.reward[idxs])
            return return_  # return = reward, done_n = done
        if t - nm1 >= 0 and t + T <= self.T:  # No wrap (operate in-place).
            reward = np.abs(s.reward[t - nm1:t + T])
            done = s.done[t - nm1:t + T]
            return_dest = self.samples_return_[t - nm1:t - nm1 + T]
            done_n_dest = self.samples_done_n[t - nm1:t - nm1 + T]
            discount_return_n_step(reward,
                                   done,
                                   n_step=self.n_step_return,
                                   discount=self.discount,
                                   return_dest=return_dest,
                                   done_n_dest=done_n_dest)
            return return_dest.copy()
        else:  # Wrap (copies); Let it (wrongly) wrap at first call.
            idxs = np.arange(t - nm1, t + T) % self.T
            reward = np.abs(s.reward[idxs])
            done = s.done[idxs]
            dest_idxs = idxs[:-nm1]
            return_, done_n = discount_return_n_step(reward,
                                                     done,
                                                     n_step=self.n_step_return,
                                                     discount=self.discount)
            self.samples_return_[dest_idxs] = return_
            self.samples_done_n[dest_idxs] = done_n
            return return_

    def init_priority_tree(self):
        self.priority_tree = SumTree(
            T=self.T,
            B=self.B,
            off_backward=self.n_step_return,
            off_forward=0,
            default_value=1,
            enable_input_priorities=True,
            input_priority_shift=self.n_step_return - 1,
        )

    def sample_idxs(self, batch_B):
        (T_idxs, B_idxs), priorities = self.priority_tree.sample(batch_B,
                                                                 unique=True)
        return T_idxs, B_idxs

    def extract_batch(self, T_idxs, B_idxs, T):
        s = self.samples
        batch = SamplesFromReplay(
            observation=self.extract_observation(T_idxs, B_idxs, T),
            action=buffer_func(s.action, extract_sequences, T_idxs, B_idxs, T),
            reward=extract_sequences(s.reward, T_idxs, B_idxs, T),
            done=extract_sequences(s.done, T_idxs, B_idxs, T),
        )
        return torchify_buffer(batch)

    def extract_observation(self, T_idxs, B_idxs, T):
        return buffer_func(self.samples.observation, extract_sequences, T_idxs,
                           B_idxs, T)
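The priorities written by append_samples follow the formula quoted in the class docstring, pri = 1 + alpha * |return| ** beta. A small numeric sketch (alpha, beta, and the returns below are made up):

import numpy as np

# Illustrative values; not taken from the example above.
alpha, beta = 1.0, 0.5
abs_returns = np.array([0.0, 0.25, 1.0, 4.0])  # |n-step returns| just computed
priorities = 1 + alpha * abs_returns ** beta   # [1.0, 1.5, 2.0, 3.0]
# Every item gets priority >= 1, so zero-return transitions still get sampled;
# larger empirical returns are sampled proportionally more often.
print(priorities)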