def init_priority_tree(self): """Organized here for clean inheritance.""" self.priority_tree = SumTree( T=self.T, B=self.B, off_backward=self.off_backward, off_forward=self.off_forward, default_value=self.default_priority**self.alpha, )
import math

import numpy as np

# The imports and namedarraytuple definitions below are assumed from the
# rlpyt-style library layout this module follows; the field sets shown
# match their usage in this file.
from rlpyt.algos.utils import discount_return_n_step
from rlpyt.replays.base import BaseReplayBuffer
from rlpyt.replays.sum_tree import SumTree
from rlpyt.utils.buffer import (buffer_from_example, buffer_func,
    get_leading_dims, numpify_buffer, torchify_buffer)
from rlpyt.utils.collections import namedarraytuple
from rlpyt.utils.misc import extract_sequences
from rlpyt.utils.quick_args import save__init__args

EPS = 1e-6  # Guards against zero priorities in importance-sampling weights.

SamplesFromReplay = namedarraytuple("SamplesFromReplay",
    ["observation", "action", "reward", "done"])
SamplesFromReplayPri = namedarraytuple("SamplesFromReplayPri",
    SamplesFromReplay._fields + ("is_weights",))


class PrioritizedReplay(object):
    """Mixin adding sum-tree prioritization to a non-sequence replay buffer."""

    def __init__(self, alpha=0.6, beta=0.4, default_priority=1, unique=False,
            **kwargs):
        super().__init__(**kwargs)
        save__init__args(locals())
        self.init_priority_tree()

    def init_priority_tree(self):
        """Organized here for clean inheritance."""
        self.priority_tree = SumTree(
            T=self.T,
            B=self.B,
            off_backward=self.off_backward,
            off_forward=self.off_forward,
            default_value=self.default_priority ** self.alpha,
        )

    def set_beta(self, beta):
        self.beta = beta

    def append_samples(self, samples):
        T, idxs = super().append_samples(samples)
        self.priority_tree.advance(T)  # Progress priority_tree cursor.
        return T, idxs

    def sample_batch(self, batch_B):
        (T_idxs, B_idxs), priorities = self.priority_tree.sample(batch_B,
            unique=self.unique)
        batch = self.extract_batch(T_idxs, B_idxs)
        is_weights = (1. / (priorities + EPS)) ** self.beta  # Unnormalized.
        is_weights /= max(is_weights)  # Normalize.
        is_weights = torchify_buffer(is_weights).float()
        return SamplesFromReplayPri(*batch, is_weights=is_weights)

    def update_batch_priorities(self, priorities):
        priorities = numpify_buffer(priorities)
        self.priority_tree.update_batch_priorities(priorities ** self.alpha)
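
# ----------------------------------------------------------------------
# Illustrative sketch (not used by the classes in this file): the sampling
# and importance-weight math that PrioritizedReplay.sample_batch applies,
# written in plain NumPy without a SumTree.  All names below are
# hypothetical; only alpha, beta, and EPS mirror the module above.
def _demo_prioritized_weights(batch_B=32, alpha=0.6, beta=0.4, seed=0):
    rng = np.random.default_rng(seed)
    td_errors = np.abs(rng.normal(size=100))  # Stand-in for |TD error|.
    priorities = td_errors ** alpha  # The tree stores priority ** alpha.
    probs = priorities / priorities.sum()  # P(i), as tree sampling induces.
    idxs = rng.choice(len(probs), size=batch_B, p=probs)
    is_weights = (1. / (priorities[idxs] + EPS)) ** beta  # Unnormalized.
    is_weights /= is_weights.max()  # Normalize so the largest weight is 1.
    return idxs, is_weights
# After max-normalization these weights match the (N * P(i)) ** -beta
# correction from the PER paper, since P(i) differs from 1 / priority only
# by factors constant across the batch, which the normalization cancels.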
class PrioritizedSequenceReplay(object):
    """Mixin adding sum-tree prioritization to a sequence replay buffer."""

    def __init__(self, alpha=0.6, beta=0.4, default_priority=1, unique=False,
            **kwargs):
        """Fix the SamplesFromReplay sequence length (``batch_T``) here, so
        the priority tree can track where not to sample; otherwise it would
        have to temporarily subtract entries from the tree every time it
        samples."""
        super().__init__(**kwargs)
        save__init__args(locals())
        assert self.batch_T is not None  # Must be assigned.
        self.init_priority_tree()

    def init_priority_tree(self):
        off_backward = math.ceil((1 + self.off_backward + self.batch_T) /
            self.rnn_state_interval)  # +1 in case interval aligned?  TODO: check.
        self.priority_tree = SumTree(
            T=self.T // self.rnn_state_interval,
            B=self.B,
            off_backward=off_backward,
            off_forward=math.ceil(self.off_forward / self.rnn_state_interval),
            default_value=self.default_priority ** self.alpha,
        )

    def set_beta(self, beta):
        self.beta = beta

    def append_samples(self, samples):
        t, rsi = self.t, self.rnn_state_interval
        T, idxs = super().append_samples(samples)
        if rsi <= 1:  # All or no rnn states stored.
            self.priority_tree.advance(T)
        else:  # Some rnn states stored: one tree slot per stored state.
            n = self.t // rsi - t // rsi
            if self.t < t:  # Cursor wrapped.
                n += self.T // rsi
            self.priority_tree.advance(n)
        return T, idxs

    def sample_batch(self, batch_B):
        (tree_T_idxs, B_idxs), priorities = self.priority_tree.sample(
            batch_B, unique=self.unique)
        # Map tree slots back to buffer time indices.
        if self.rnn_state_interval > 1:
            T_idxs = tree_T_idxs * self.rnn_state_interval
        else:
            T_idxs = tree_T_idxs
        batch = self.extract_batch(T_idxs, B_idxs, self.batch_T)
        is_weights = (1. / (priorities + EPS)) ** self.beta  # Unnormalized.
        is_weights /= max(is_weights)  # Normalize.
        is_weights = torchify_buffer(is_weights).float()
        return SamplesFromReplayPri(*batch, is_weights=is_weights)

    def update_batch_priorities(self, priorities):
        priorities = numpify_buffer(priorities)
        self.priority_tree.update_batch_priorities(priorities ** self.alpha)
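
# ----------------------------------------------------------------------
# Illustrative sketch of the tree-advance arithmetic in append_samples
# above, assuming the tree keeps one slot per stored rnn state (i.e. one
# per rsi time steps).  Hypothetical standalone helper, not part of the
# class.
def _tree_advance_count(t_old, t_new, T, rsi):
    n = t_new // rsi - t_old // rsi
    if t_new < t_old:  # Cursor wrapped past the end of the buffer.
        n += T // rsi
    return n

assert _tree_advance_count(t_old=10, t_new=22, T=100, rsi=4) == 3  # Slots 12, 16, 20.
assert _tree_advance_count(t_old=96, t_new=6, T=100, rsi=4) == 2   # Wrap: slots 0, 4.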
class RlWithUlPrioritizedReplayBuffer(BaseReplayBuffer):
    """Replay prioritized by empirical n-step returns:
    ``pri = 1 + alpha * return ** beta``."""

    def __init__(self, example, size, B, replay_T, discount, n_step_return,
            alpha, beta):
        self.T = T = math.ceil(size / B)
        self.B = B
        self.size = T * B
        self.t = 0  # Cursor.
        self.replay_T = replay_T
        self.discount = discount
        self.n_step_return = n_step_return
        self.alpha = alpha
        self.beta = beta
        self.samples = buffer_from_example(example, (T, B),
            share_memory=getattr(self, "async_", False))  # False unless an async mixin sets ``async_``.
        if n_step_return > 1:
            self.samples_return_ = buffer_from_example(example.reward, (T, B))
            self.samples_done_n = buffer_from_example(example.done, (T, B))
        else:
            self.samples_return_ = self.samples.reward
            self.samples_done_n = self.samples.done
        self._buffer_full = False
        self.init_priority_tree()

    def init_priority_tree(self):
        self.priority_tree = SumTree(
            T=self.T,
            B=self.B,
            off_backward=self.n_step_return,
            off_forward=0,
            default_value=1,
            enable_input_priorities=True,
            input_priority_shift=self.n_step_return - 1,
        )

    def append_samples(self, samples):
        T, B = get_leading_dims(samples, n_dim=2)
        assert B == self.B
        t = self.t
        if t + T > self.T:  # Wrap.
            idxs = np.arange(t, t + T) % self.T
        else:
            idxs = slice(t, t + T)
        self.samples[idxs] = samples
        new_returns = self.compute_returns(T)
        if not self._buffer_full and t + T >= self.T:
            self._buffer_full = True
        self.t = (t + T) % self.T
        priorities = 1 + self.alpha * new_returns ** self.beta
        self.priority_tree.advance(T, priorities=priorities)
        return T, idxs

    def sample_batch(self, batch_B):
        T_idxs, B_idxs = self.sample_idxs(batch_B)
        return self.extract_batch(T_idxs, B_idxs, self.replay_T)

    def compute_returns(self, T):
        """Compute the n-step returns using the new rewards just written into
        the buffer, but before the buffer cursor is advanced.  Input ``T`` is
        the number of new time steps just written.  E.g. for a 2-step return,
        the first return computed here is at ``t - 1``, using the reward at
        ``t - 1`` and the new reward at ``t``; the last is at ``t + T - 2``,
        using the new reward at ``t + T - 1``.  Uses the ABSOLUTE VALUE of
        rewards: any reward magnitude is treated as good signal for
        prioritization."""
        t, s, nm1 = self.t, self.samples, self.n_step_return - 1
        if self.n_step_return == 1:
            idxs = np.arange(t, t + T) % self.T
            return np.abs(s.reward[idxs])  # return = reward, done_n = done.
        if t - nm1 >= 0 and t + T <= self.T:  # No wrap (operate in-place).
            reward = np.abs(s.reward[t - nm1:t + T])
            done = s.done[t - nm1:t + T]
            return_dest = self.samples_return_[t - nm1:t - nm1 + T]
            done_n_dest = self.samples_done_n[t - nm1:t - nm1 + T]
            discount_return_n_step(reward, done, n_step=self.n_step_return,
                discount=self.discount, return_dest=return_dest,
                done_n_dest=done_n_dest)
            return return_dest.copy()
        else:  # Wrap (copies); on the first call this wrongly wraps to the
            # (still empty) buffer end, which is tolerated.
            idxs = np.arange(t - nm1, t + T) % self.T
            reward = np.abs(s.reward[idxs])
            done = s.done[idxs]
            dest_idxs = idxs[:-nm1]
            return_, done_n = discount_return_n_step(reward, done,
                n_step=self.n_step_return, discount=self.discount)
            self.samples_return_[dest_idxs] = return_
            self.samples_done_n[dest_idxs] = done_n
            return return_

    def init_priority_tree(self):
        pass  # (placeholder removed)

    def sample_idxs(self, batch_B):
        (T_idxs, B_idxs), _priorities = self.priority_tree.sample(batch_B,
            unique=True)
        return T_idxs, B_idxs  # Priorities discarded: no IS weights here.

    def extract_batch(self, T_idxs, B_idxs, T):
        s = self.samples
        batch = SamplesFromReplay(
            observation=self.extract_observation(T_idxs, B_idxs, T),
            action=buffer_func(s.action, extract_sequences, T_idxs, B_idxs, T),
            reward=extract_sequences(s.reward, T_idxs, B_idxs, T),
            done=extract_sequences(s.done, T_idxs, B_idxs, T),
        )
        return torchify_buffer(batch)

    def extract_observation(self, T_idxs, B_idxs, T):
        return buffer_func(self.samples.observation, extract_sequences,
            T_idxs, B_idxs, T)
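
# ----------------------------------------------------------------------
# Illustrative reference for the quantity prioritized above, assuming the
# same truncation semantics as the external discount_return_n_step helper
# (each n-step sum stops at the first done, terminal reward included).
# Hypothetical helpers, not part of the buffer.
def _n_step_return_ref(reward, done, n_step, discount):
    """return_[t] = sum_{k < n_step} discount**k * reward[t + k], truncated
    at the first done; done_n[t] flags whether the window hit a done."""
    rlen = len(reward) - (n_step - 1)
    return_ = np.zeros(rlen)
    done_n = np.zeros(rlen, dtype=bool)
    for t in range(rlen):
        g = 1.0
        for k in range(n_step):
            return_[t] += g * reward[t + k]  # Terminal reward is included.
            if done[t + k]:
                done_n[t] = True
                break
            g *= discount
    return return_, done_n

def _demo_return_priorities(alpha=1.0, beta=0.5, discount=0.99):
    reward = np.abs(np.array([0., 1., 0., 2., 0.]))  # Buffer uses |reward|.
    done = np.array([False, False, True, False, False])
    return_, done_n = _n_step_return_ref(reward, done, n_step=2,
        discount=discount)
    # return_ = [0.99, 1.0, 0.0, 2.0]; done_n = [False, True, True, False].
    return 1 + alpha * return_ ** beta  # Priority, as in append_samples.
# Unlike TD-error prioritization, these priorities are fixed when samples
# are appended (via advance(T, priorities=...)), so this buffer has no
# update_batch_priorities and skips the importance-sampling correction.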