import numpy as np

# Assumed import path: these tests match OpenAI Baselines' segment_tree
# module; MinSegmentTree is used by the buffer classes further down.
from baselines.common.segment_tree import SumSegmentTree, MinSegmentTree


def test_tree_set_overlap():
    tree = SumSegmentTree(4)
    tree[2] = 1.0
    tree[2] = 3.0

    assert np.isclose(tree.sum(), 3.0)
    assert np.isclose(tree.sum(2, 3), 3.0)
    assert np.isclose(tree.sum(2, -1), 3.0)
    assert np.isclose(tree.sum(2, 4), 3.0)
    assert np.isclose(tree.sum(1, 2), 0.0)
def test_prefixsum_idx():
    tree = SumSegmentTree(4)
    tree[2] = 1.0
    tree[3] = 3.0

    assert tree.find_prefixsum_idx(0.0) == 2
    assert tree.find_prefixsum_idx(0.5) == 2
    assert tree.find_prefixsum_idx(0.99) == 2
    assert tree.find_prefixsum_idx(1.01) == 3
    assert tree.find_prefixsum_idx(3.00) == 3
    assert tree.find_prefixsum_idx(4.00) == 3
def test_prefixsum_idx2():
    tree = SumSegmentTree(4)
    tree[0] = 0.5
    tree[1] = 1.0
    tree[2] = 1.0
    tree[3] = 3.0

    assert tree.find_prefixsum_idx(0.00) == 0
    assert tree.find_prefixsum_idx(0.55) == 1
    assert tree.find_prefixsum_idx(0.99) == 1
    assert tree.find_prefixsum_idx(1.51) == 2
    assert tree.find_prefixsum_idx(3.00) == 3
    assert tree.find_prefixsum_idx(5.50) == 3
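# The prefix-sum tests above pin down the contract of find_prefixsum_idx(p):
# return the highest index i such that arr[0] + ... + arr[i - 1] <= p,
# located by a root-to-leaf descent through the sum tree. The sketch below is
# a minimal, self-contained illustration consistent with those tests;
# TinySumTree is a hypothetical name, and the real SumSegmentTree also
# supports arbitrary range queries.
class TinySumTree:
    """Array-backed sum tree; leaves live at indices [capacity, 2*capacity)."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.values = [0.0] * (2 * capacity)

    def __setitem__(self, idx, val):
        idx += self.capacity
        self.values[idx] = val
        idx //= 2
        while idx >= 1:  # propagate the new value up to the root
            self.values[idx] = self.values[2 * idx] + self.values[2 * idx + 1]
            idx //= 2

    def sum(self):
        return self.values[1]  # the root holds the total mass

    def find_prefixsum_idx(self, prefixsum):
        # go left while the left subtree covers prefixsum, otherwise
        # subtract the left subtree's mass and go right
        idx = 1
        while idx < self.capacity:
            if self.values[2 * idx] > prefixsum:
                idx = 2 * idx
            else:
                prefixsum -= self.values[2 * idx]
                idx = 2 * idx + 1
        return idx - self.capacity


tiny = TinySumTree(4)
tiny[2] = 1.0
tiny[3] = 3.0
assert tiny.find_prefixsum_idx(0.99) == 2
assert tiny.find_prefixsum_idx(1.01) == 3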
def test_tree_set():
    tree = SumSegmentTree(4)
    tree[2] = 1.0
    tree[3] = 3.0

    assert np.isclose(tree.sum(), 4.0)
    assert np.isclose(tree.sum(0, 2), 0.0)
    assert np.isclose(tree.sum(0, 3), 1.0)
    assert np.isclose(tree.sum(2, 3), 1.0)
    assert np.isclose(tree.sum(2, -1), 1.0)
    assert np.isclose(tree.sum(2, 4), 4.0)
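# test_tree_set also fixes the range-query convention: tree.sum(start, end)
# is a half-open interval whose end may be negative, exactly like a Python
# slice. A quick numpy cross-check of the expected values, with a plain
# array standing in for the tree:
import numpy as np

vals = np.array([0.0, 0.0, 1.0, 3.0])
assert np.isclose(vals[0:2].sum(), 0.0)   # mirrors tree.sum(0, 2)
assert np.isclose(vals[2:-1].sum(), 1.0)  # mirrors tree.sum(2, -1)
assert np.isclose(vals[2:4].sum(), 4.0)   # mirrors tree.sum(2, 4)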
def __init__(self, buffer_shapes, size_in_transitions, T, sample_transitions,
             alpha, env_name):
    """Create Prioritized Replay buffer."""
    super(PrioritizedReplayBuffer, self).__init__(buffer_shapes,
                                                  size_in_transitions, T,
                                                  sample_transitions)
    assert alpha >= 0
    self._alpha = alpha

    self.size_in_transitions = size_in_transitions
    it_capacity = 1
    while it_capacity < size_in_transitions:
        it_capacity *= 2

    self._it_sum = SumSegmentTree(it_capacity)
    self._it_min = MinSegmentTree(it_capacity)
    self._max_priority = 1.0

    self.T = T
    self.buffers['td'] = np.zeros([self.size, self.T])
    self.buffers['e'] = np.zeros([self.size, self.T])
    self.env_name = env_name
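# Every constructor above rounds the tree capacity up to the next power of
# two, since both segment trees index a perfect binary tree. The loop is
# equivalent to the closed forms checked below; next_pow2 is a hypothetical
# helper, not part of any of these codebases.
import math

def next_pow2(n):
    """Smallest power of two >= n, for n >= 1."""
    cap = 1
    while cap < n:
        cap *= 2
    return cap

assert next_pow2(1_000_000) == 1 << 20
assert next_pow2(1_000_000) == 1 << math.ceil(math.log2(1_000_000))
assert next_pow2(1_000_000) == 1 << (1_000_000 - 1).bit_length()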
class ReplayStorage:
    def __init__(self, max_steps, num_processes, gamma, prio_alpha, obs_shape,
                 action_space, recurrent_hidden_state_size, device):
        self.max_steps = max_steps
        self.num_processes = num_processes
        self.gamma = gamma
        self.device = device

        # stored episode data
        self.obs = torch.zeros(max_steps, *obs_shape)
        self.recurrent_hidden_states = torch.zeros(
            max_steps, recurrent_hidden_state_size)
        self.returns = torch.zeros(max_steps, 1)
        if action_space.__class__.__name__ == 'Discrete':
            self.actions = torch.zeros(max_steps, 1).long()
        else:
            self.actions = torch.zeros(max_steps, action_space.shape[0])
        self.masks = torch.ones(max_steps, 1)
        self.next_idx = 0
        self.num_steps = 0

        # store (full) episode stats
        self.episode_step_count = 0
        self.episode_rewards = deque()
        self.episode_steps = deque()

        # currently running (accumulating) episodes
        self.running_episodes = [[] for _ in range(num_processes)]

        if prio_alpha > 0:
            # Sampling priority is enabled if prio_alpha > 0.
            # Priority algorithm adapted from OpenAI Baselines:
            # https://github.com/openai/baselines/blob/master/baselines/deepq/replay_buffer.py
            self.prio_alpha = prio_alpha
            tree_capacity = 1 << math.ceil(math.log2(self.max_steps))
            self.prio_sum_tree = SumSegmentTree(tree_capacity)
            self.prio_min_tree = MinSegmentTree(tree_capacity)
            self.prio_max = 1.0
        else:
            self.prio_alpha = 0

    def _process_rewards(self, trajectory):
        has_positive = False
        reward_sum = 0.
        r = 0.
        for t in trajectory[::-1]:
            reward = t['reward']
            reward_sum += reward
            if reward > (0. + 1e-5):
                has_positive = True
            r = reward + self.gamma * r
            t['return'] = r
        return has_positive, reward_sum

    def _add_trajectory(self, trajectory):
        has_positive, reward_sum = self._process_rewards(trajectory)
        if not has_positive:
            return
        trajectory_len = len(trajectory)
        prev_idx = self.next_idx
        for transition in trajectory:
            self.obs[self.next_idx].copy_(transition['obs'])
            self.recurrent_hidden_states[self.next_idx].copy_(
                transition['rhs'])
            self.actions[self.next_idx].copy_(transition['action'])
            self.returns[self.next_idx].copy_(transition['return'])
            self.masks[self.next_idx] = 1.0
            prev_idx = self.next_idx
            if self.prio_alpha:
                self.prio_sum_tree[
                    self.next_idx] = self.prio_max ** self.prio_alpha
                self.prio_min_tree[
                    self.next_idx] = self.prio_max ** self.prio_alpha
            self.next_idx = (self.next_idx + 1) % self.max_steps
            self.num_steps = min(self.max_steps, self.num_steps + 1)
        self.masks[prev_idx] = 0.0

        # update stats of stored full trajectories (episodes)
        while self.episode_step_count + trajectory_len > self.max_steps:
            steps_popped = self.episode_steps.popleft()
            self.episode_rewards.popleft()
            self.episode_step_count -= steps_popped
        self.episode_step_count += trajectory_len
        self.episode_steps.append(trajectory_len)
        self.episode_rewards.append(reward_sum)

    def _sample_proportional(self, sample_size):
        res = []
        for _ in range(sample_size):
            mass = random.random() * self.prio_sum_tree.sum(
                0, self.num_steps - 1)
            idx = self.prio_sum_tree.find_prefixsum_idx(mass)
            res.append(idx)
        return res

    def insert(self, obs, rhs, actions, rewards, dones):
        for n in range(self.num_processes):
            self.running_episodes[n].append(
                dict(obs=obs[n].clone(),
                     rhs=rhs[n].clone(),
                     action=actions[n].clone(),
                     reward=rewards[n].clone()))
        for n, done in enumerate(dones):
            if done:
                self._add_trajectory(self.running_episodes[n])
                self.running_episodes[n] = []

    def update_priorities(self, indices, priorities):
        """Update priorities of sampled transitions.

        Sets priority of transition at index indices[i] in buffer
        to priorities[i].

        Parameters
        ----------
        indices: [int]
            List of indices of sampled transitions
        priorities: [float]
            List of updated priorities corresponding to transitions at
            the sampled indices.
        """
        if not self.prio_alpha:
            return
        assert len(indices) == len(priorities)
        for idx, priority in zip(indices, priorities):
            priority = max(priority, 1e-6)
            assert priority > 0
            assert 0 <= idx < self.num_steps
            self.prio_sum_tree[idx] = priority ** self.prio_alpha
            self.prio_min_tree[idx] = priority ** self.prio_alpha
            self.prio_max = max(self.prio_max, priority)

    def feed_forward_generator(self, batch_size, num_batches=None, beta=0.):
        """Generate batches of sampled experiences.

        Parameters
        ----------
        batch_size: int
            Size of each sampled batch
        num_batches: int
            Number of batches to sample
        beta: float
            To what degree to use importance weights
            (0 - no corrections, 1 - full correction)
        """
        batch_count = 0
        # num_batches may be None, in which case make one pass over the
        # whole buffer
        sample_size = num_batches * batch_size if num_batches else self.num_steps
        if self.prio_alpha > 0:
            indices = self._sample_proportional(sample_size)
            if beta > 0:
                # compute importance sampling weights to correct for the
                # bias introduced by sampling in a non-uniform manner
                weights = []
                p_min = self.prio_min_tree.min() / self.prio_sum_tree.sum()
                max_weight = (p_min * self.num_steps) ** (-beta)
                for i in indices:
                    p_sample = self.prio_sum_tree[i] / self.prio_sum_tree.sum()
                    weight = (p_sample * self.num_steps) ** (-beta)
                    weights.append(weight / max_weight)
                weights = torch.tensor(weights,
                                       dtype=torch.float32).unsqueeze(1)
            else:
                weights = torch.ones((len(indices), 1), dtype=torch.float32)
        else:
            if sample_size * 3 < self.num_steps:
                indices = random.sample(range(self.num_steps), sample_size)
            else:
                indices = np.random.permutation(self.num_steps)[:sample_size]
            weights = None
        for si in range(0, len(indices), batch_size):
            indices_batch = indices[si:min(len(indices), si + batch_size)]
            if len(indices_batch) < batch_size:
                return
            weights_batch = None if weights is None else \
                weights[si:min(len(indices), si + batch_size)].to(self.device)
            obs_batch = self.obs[indices_batch].to(self.device)
            recurrent_hidden_states_batch = self.recurrent_hidden_states[
                indices_batch].to(self.device)
            actions_batch = self.actions[indices_batch].to(self.device)
            returns_batch = self.returns[indices_batch].to(self.device)
            masks_batch = self.masks[indices_batch].to(self.device)
            yield obs_batch, recurrent_hidden_states_batch, actions_batch, \
                returns_batch, masks_batch, weights_batch, indices_batch
            batch_count += 1
            if num_batches and batch_count >= num_batches:
                return
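# The weight loop in feed_forward_generator (and in every sample method
# below) computes w_i = (N * P(i)) ** (-beta), normalized by the largest
# possible weight, which belongs to the minimum-probability entry tracked by
# the min tree. A standalone numpy sketch of that arithmetic; the numbers
# are illustrative only:
import numpy as np

priorities = np.array([4.0, 1.0, 1.0, 2.0])  # already alpha-exponentiated
probs = priorities / priorities.sum()
n, beta = len(priorities), 0.5

weights = (n * probs) ** (-beta)
max_weight = (n * probs.min()) ** (-beta)
# normalizing by the lowest-probability entry's weight caps weights at 1
assert np.isclose((weights / max_weight).max(), 1.0)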
class PrioritizedReplayBuffer(ReplayBuffer):
    def __init__(self, size, alpha):
        """Create Prioritized Replay buffer.

        Parameters
        ----------
        size: int
            Max number of transitions to store in the buffer. When the buffer
            overflows the old memories are dropped.
        alpha: float
            how much prioritization is used
            (0 - no prioritization, 1 - full prioritization)

        See Also
        --------
        ReplayBuffer.__init__
        """
        super(PrioritizedReplayBuffer, self).__init__(size)
        assert alpha > 0
        self._alpha = alpha

        it_capacity = 1
        while it_capacity < size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0

    def add(self, *args, **kwargs):
        """See ReplayBuffer.store_effect"""
        idx = self._next_idx
        super().add(*args, **kwargs)
        self._it_sum[idx] = self._max_priority ** self._alpha
        self._it_min[idx] = self._max_priority ** self._alpha

    def _sample_proportional(self, batch_size):
        res = []
        for _ in range(batch_size):
            # TODO(szymon): should we ensure no repeats?
            mass = random.random() * self._it_sum.sum(0, len(self._storage) - 1)
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)
        return res

    def sample(self, batch_size, beta):
        """Sample a batch of experiences.

        Compared to ReplayBuffer.sample it also returns importance weights
        and idxes of sampled experiences.

        Parameters
        ----------
        batch_size: int
            How many transitions to sample.
        beta: float
            To what degree to use importance weights
            (0 - no corrections, 1 - full correction)

        Returns
        -------
        obs_batch: np.array
            batch of observations
        act_batch: np.array
            batch of actions executed given obs_batch
        rew_batch: np.array
            rewards received as results of executing act_batch
        next_obs_batch: np.array
            next set of observations seen after executing act_batch
        done_mask: np.array
            done_mask[i] = 1 if executing act_batch[i] resulted in the end of
            an episode and 0 otherwise.
        weights: np.array
            Array of shape (batch_size,) and dtype np.float32 denoting the
            importance weight of each sampled transition
        idxes: np.array
            Array of shape (batch_size,) and dtype np.int32 of indexes in the
            buffer of sampled experiences
        """
        assert beta > 0

        idxes = self._sample_proportional(batch_size)

        weights = []
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * len(self._storage)) ** (-beta)

        for idx in idxes:
            p_sample = self._it_sum[idx] / self._it_sum.sum()
            weight = (p_sample * len(self._storage)) ** (-beta)
            weights.append(weight / max_weight)
        weights = np.array(weights)
        encoded_sample = self._encode_sample(idxes)
        return tuple(list(encoded_sample) + [weights, idxes])

    def update_priorities(self, idxes, priorities):
        """Update priorities of sampled transitions.

        Sets priority of transition at index idxes[i] in buffer
        to priorities[i].

        Parameters
        ----------
        idxes: [int]
            List of idxes of sampled transitions
        priorities: [float]
            List of updated priorities corresponding to transitions at the
            sampled idxes denoted by variable `idxes`.
        """
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            assert priority > 0
            assert 0 <= idx < len(self._storage)
            self._it_sum[idx] = priority ** self._alpha
            self._it_min[idx] = priority ** self._alpha
            self._max_priority = max(self._max_priority, priority)

    def save(self, path):
        """Save the priority memory in case of a crash.

        Parameters
        ----------
        path: str
            The network path (as given in the model's yaml file) where the
            model is being saved
        """
        info = {
            "alpha": self._alpha,
            "it_sum": self._it_sum,
            "it_min": self._it_min,
            "max_priority": self._max_priority,
            "next_idx": self._next_idx,
            "storage": self._storage,
            "maxsize": self._maxsize
        }
        with open(path + "/adaptive_memory.info", "wb") as file:
            p = pickle.Pickler(file)
            p.fast = True
            p.dump(info)
            p.memo.clear()

    def load(self, path):
        """Load the parameters of a saved-off memory file.

        Parameters
        ----------
        path: str
            The path where the saved-off file exists
        """
        restore_path = path + "/adaptive_memory.info"
        if os.path.exists(restore_path):
            with open(restore_path, "rb") as file:
                p = pickle.Unpickler(file)
                info = p.load()
                p.memo.clear()
            self._alpha = info["alpha"]
            self._it_sum = info["it_sum"]
            self._it_min = info["it_min"]
            self._max_priority = info["max_priority"]
            self._next_idx = info["next_idx"]
            self._storage = info["storage"]
            self._maxsize = info["maxsize"]
class DoublePrioritizedStateRecycledReplayBuffer(ReplayBuffer):
    def __init__(self, size, alpha1, alpha2=1.0, candidates_size=5,
                 env_id='PongNoFrameskip-v4'):
        """Create a Double Prioritized State Recycled ReplayBuffer.

        :param size: int
            Max number of transitions to store in the buffer.
        :param alpha1: float
            The rate of the prioritization of sampling.
        :param alpha2: float
            The rate of the prioritization of replacement.
        :param candidates_size: int
            The number of the candidates chosen in replacement.
        :param env_id: str
            The name of the gym [atari] environment.
        """
        super().__init__(size)
        assert alpha1 >= 0
        self._alpha1 = alpha1
        assert alpha2 >= 0
        self._alpha2 = alpha2
        assert candidates_size > 0
        self.candidates_size = candidates_size
        self.env_id = env_id

        it_capacity = 1
        while it_capacity < size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        # inverse-priority tree used to find low-priority replacement victims
        self._it_inverse_sum = SumSegmentTree(it_capacity)
        self._max_priority = 1.0

    def not_full(self):
        return self._next_idx >= len(self._storage)

    def replacement_candidates(self, candidates_size=None):
        candidates_idxes = self._replacement_candidate_proportional(
            candidates_size)
        candidates = [self._storage[idx] for idx in candidates_idxes]
        return candidates_idxes, candidates

    def state_recycle(self, idxes, data, td_errors, max_priority_set):
        for i in range(len(idxes)):
            self._storage[idxes[i]] = data[i]
            if max_priority_set:
                self._it_sum[idxes[i]] = self._max_priority ** self._alpha1
                self._it_min[idxes[i]] = self._max_priority ** self._alpha1
                self._it_inverse_sum[
                    idxes[i]] = self._max_priority ** -self._alpha2
        self._next_idx = idxes[np.argmin(td_errors)]

    def add(self, obs_t, action, reward, obs_tp1, done, env_clone_state=None,
            timestamp=None, idx=None):
        data = (obs_t, action, reward, obs_tp1, done, env_clone_state,
                timestamp)
        if idx is None:  # an explicit idx of 0 is valid, so test against None
            idx = self._next_idx
        if self.not_full():
            self._storage.append(data)
        else:
            self._storage[idx] = data
        self._next_idx = (self._next_idx + 1) % self._maxsize
        self._it_sum[idx] = self._max_priority ** self._alpha1
        self._it_min[idx] = self._max_priority ** self._alpha1
        self._it_inverse_sum[idx] = self._max_priority ** -self._alpha2

    def _sample_proportional(self, batch_size):
        res = []
        p_total = self._it_sum.sum(0, len(self._storage) - 1)
        every_range_len = p_total / batch_size
        for i in range(batch_size):
            mass = random.random() * every_range_len + i * every_range_len
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)
        return res

    def _replacement_candidate_proportional(self, candidates_size):
        res = []
        if not candidates_size:
            candidates_size = self.candidates_size
        p_total = self._it_inverse_sum.sum(0, len(self._storage) - 1)
        every_range_len = p_total / candidates_size
        for i in range(candidates_size):
            mass = random.random() * every_range_len + i * every_range_len
            idx = self._it_inverse_sum.find_prefixsum_idx(mass)
            res.append(idx)
        return res

    def _encode_sample(self, idxes):
        obses_t, actions, rewards, obses_tp1, dones, env_states, timestamps = \
            [], [], [], [], [], [], []
        for i in idxes:
            data = self._storage[i]
            obs_t, action, reward, obs_tp1, done, env_state, timestamp = data
            obses_t.append(np.array(obs_t, copy=False))
            actions.append(np.array(action, copy=False))
            rewards.append(reward)
            obses_tp1.append(np.array(obs_tp1, copy=False))
            dones.append(done)
            env_states.append(np.array(env_state, copy=False))
            timestamps.append(np.array(timestamp, copy=False))
        return np.array(obses_t), np.array(actions), np.array(rewards), \
            np.array(obses_tp1), np.array(dones), np.array(env_states), \
            np.array(timestamps)

    def sample(self, batch_size, beta=1.0):
        """Sample a batch of experiences.

        Compared to ReplayBuffer.sample it also returns importance weights
        and idxes of sampled experiences.

        Parameters
        ----------
        batch_size: int
            How many transitions to sample.
        beta: float
            To what degree to use importance weights
            (0 - no corrections, 1 - full correction)

        Returns
        -------
        obs_batch: np.array
            batch of observations
        act_batch: np.array
            batch of actions executed given obs_batch
        rew_batch: np.array
            rewards received as results of executing act_batch
        next_obs_batch: np.array
            next set of observations seen after executing act_batch
        done_mask: np.array
            done_mask[i] = 1 if executing act_batch[i] resulted in the end of
            an episode and 0 otherwise.
        weights: np.array
            Array of shape (batch_size,) and dtype np.float32 denoting the
            importance weight of each sampled transition
        idxes: np.array
            Array of shape (batch_size,) and dtype np.int32 of indexes in the
            buffer of sampled experiences
        """
        assert beta > 0
        idxes = self._sample_proportional(batch_size)
        weights = []
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * len(self._storage)) ** (-beta)
        for idx in idxes:
            p_sample = self._it_sum[idx] / self._it_sum.sum()
            weight = (p_sample * len(self._storage)) ** (-beta)
            weights.append(weight / max_weight)
        weights = np.array(weights)
        encoded_sample = self._encode_sample(idxes)
        return tuple(list(encoded_sample) + [weights, idxes])

    def update_priorities(self, idxes, priorities):
        """Update priorities of sampled transitions.

        Sets priority of transition at index idxes[i] in buffer
        to priorities[i].

        Parameters
        ----------
        idxes: [int]
            List of idxes of sampled transitions
        priorities: [float]
            List of updated priorities corresponding to transitions at the
            sampled idxes denoted by variable `idxes`.
        """
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            assert priority > 0
            assert 0 <= idx < len(self._storage)
            self._it_sum[idx] = priority ** self._alpha1
            self._it_min[idx] = priority ** self._alpha1
            self._it_inverse_sum[idx] = priority ** -self._alpha2
            self._max_priority = max(self._max_priority, priority)
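# Both samplers above stratify the draw: the total priority mass is split
# into equal-width segments and one uniform draw is taken per segment, so
# samples spread across the whole prefix-sum range instead of clustering.
# A minimal sketch of the stratification; stratified_masses is a
# hypothetical helper:
import random

def stratified_masses(p_total, batch_size):
    seg = p_total / batch_size
    # one uniform draw inside each of the batch_size equal-width segments
    return [random.random() * seg + i * seg for i in range(batch_size)]

masses = stratified_masses(10.0, 5)
assert all(i * 2.0 <= m < (i + 1) * 2.0 for i, m in enumerate(masses))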
class PrioritizedReplayBuffer(object):
    def __init__(self, size, frame_history_len, alpha, lander=False):
        """This is a memory efficient implementation of the replay buffer.

        The specific memory optimizations used here are:
            - only store each frame once rather than k times
              even if every observation normally consists of k last frames
            - store frames as np.uint8 (actually it is most time-efficient
              to cast them back to float32 on GPU to minimize memory
              transfer time)
            - store frame_t and frame_(t+1) in the same buffer.

        For the typical use case of an Atari Deep RL buffer with 1M frames,
        the total memory footprint of this buffer is
        10^6 * 84 * 84 bytes ~= 7 gigabytes.

        Warning! Assumes that returning a frame of zeros at the beginning
        of the episode, when there are fewer frames than
        `frame_history_len`, is acceptable.

        Parameters
        ----------
        size: int
            Max number of transitions to store in the buffer. When the buffer
            overflows the old memories are dropped.
        frame_history_len: int
            Number of memories to be retrieved for each observation.
        """
        self.lander = lander

        self.size = size
        self.frame_history_len = frame_history_len

        self.next_idx = 0
        self.num_in_buffer = 0

        self.obs = None
        self.action = None
        self.reward = None
        self.done = None

        assert 0 <= alpha <= 1
        self.alpha = alpha

        it_capacity = 1
        while it_capacity < size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0

    def can_sample(self, batch_size):
        """Returns true if `batch_size` different transitions can be sampled
        from the buffer."""
        return batch_size + 1 <= self.num_in_buffer

    def _encode_sample(self, idxes):
        obs_batch = np.concatenate(
            [self._encode_observation(idx)[None] for idx in idxes], 0)
        act_batch = self.action[idxes]
        rew_batch = self.reward[idxes]
        next_obs_batch = np.concatenate(
            [self._encode_observation(idx + 1)[None] for idx in idxes], 0)
        done_mask = np.array(
            [1.0 if self.done[idx] else 0.0 for idx in idxes],
            dtype=np.float32)
        return obs_batch, act_batch, rew_batch, next_obs_batch, done_mask

    def sample(self, batch_size, beta):
        """Sample `batch_size` different transitions (now proportional to
        priority, and also returns importance weights).

        The i-th sampled transition is the following: when observing
        `obs_batch[i]`, action `act_batch[i]` was taken, after which reward
        `rew_batch[i]` was received and subsequent observation
        `next_obs_batch[i]` was observed, unless the episode was done, which
        is represented by `done_mask[i]` which is equal to 1 if the episode
        has ended as a result of that action.

        Parameters
        ----------
        batch_size: int
            How many transitions to sample.
        beta: float
            To what degree to use importance weights
            (0 - no corrections, 1 - full correction)

        Returns
        -------
        obs_batch: np.array
            Array of shape
            (batch_size, img_h, img_w, img_c * frame_history_len)
            and dtype np.uint8
        act_batch: np.array
            Array of shape (batch_size,) and dtype np.int32
        rew_batch: np.array
            Array of shape (batch_size,) and dtype np.float32
        next_obs_batch: np.array
            Array of shape
            (batch_size, img_h, img_w, img_c * frame_history_len)
            and dtype np.uint8
        done_mask: np.array
            Array of shape (batch_size,) and dtype np.float32
        """

        def proportional(batch_size):
            # stratified draw: one sample per equal-width segment of the
            # total priority mass
            res = []
            p_total = self._it_sum.sum(0, self.num_in_buffer - 1)
            every_range_len = p_total / batch_size
            for i in range(batch_size):
                mass = random.random() * every_range_len + i * every_range_len
                idx = self._it_sum.find_prefixsum_idx(mass)
                res.append(idx)
            return res

        assert beta > 0
        assert self.can_sample(batch_size)

        idxes = proportional(batch_size)

        weights = []
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * self.num_in_buffer) ** (-beta)
        for idx in idxes:
            p_sample = self._it_sum[idx] / self._it_sum.sum()
            weight = (p_sample * self.num_in_buffer) ** (-beta)
            weights.append(weight / max_weight)
        weights = np.array(weights)

        encoded = self._encode_sample(idxes)
        return list(encoded) + [weights, idxes]

    def encode_recent_observation(self):
        """Return the most recent `frame_history_len` frames.

        Returns
        -------
        observation: np.array
            Array of shape (img_h, img_w, img_c * frame_history_len)
            and dtype np.uint8, where observation[:, :, i*img_c:(i+1)*img_c]
            encodes the frame at time `t - frame_history_len + i`
        """
        assert self.num_in_buffer > 0
        return self._encode_observation((self.next_idx - 1) % self.size)

    def update_priorities(self, idxes, priorities):
        """Update priorities of sampled transitions
        (copied from the Baselines code).

        Sets priority of transition at index idxes[i] in buffer
        to priorities[i].

        Parameters
        ----------
        idxes: [int]
            List of idxes of sampled transitions
        priorities: [float]
            List of updated priorities corresponding to transitions at the
            sampled idxes denoted by variable `idxes`.
        """
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            assert priority > 0
            assert 0 <= idx < self.num_in_buffer
            self._it_sum[idx] = priority ** self.alpha
            self._it_min[idx] = priority ** self.alpha
            self._max_priority = max(self._max_priority, priority)

    def _encode_observation(self, idx):
        end_idx = idx + 1  # make noninclusive
        start_idx = end_idx - self.frame_history_len
        # this checks if we are using low-dimensional observations, such as
        # RAM state, in which case we just directly return the latest RAM
        if len(self.obs.shape) == 2:
            return self.obs[end_idx - 1]
        # if there weren't enough frames ever in the buffer for context
        if start_idx < 0 and self.num_in_buffer != self.size:
            start_idx = 0
        for idx in range(start_idx, end_idx - 1):
            if self.done[idx % self.size]:
                start_idx = idx + 1
        missing_context = self.frame_history_len - (end_idx - start_idx)
        # if zero padding is needed for missing context
        # or we are on the boundary of the buffer
        if start_idx < 0 or missing_context > 0:
            frames = [
                np.zeros_like(self.obs[0]) for _ in range(missing_context)
            ]
            for idx in range(start_idx, end_idx):
                frames.append(self.obs[idx % self.size])
            return np.concatenate(frames, 2)
        else:
            # this optimization has the potential to save about 30% compute
            # time \o/
            img_h, img_w = self.obs.shape[1], self.obs.shape[2]
            return self.obs[start_idx:end_idx].transpose(1, 2, 0, 3).reshape(
                img_h, img_w, -1)

    def store_frame(self, frame):
        """Store a single frame in the buffer at the next available index,
        overwriting old frames if necessary.

        Parameters
        ----------
        frame: np.array
            Array of shape (img_h, img_w, img_c) and dtype np.uint8;
            the frame to be stored

        Returns
        -------
        idx: int
            Index at which the frame is stored.
            To be used for `store_effect` later.
        """
        if self.obs is None:
            self.obs = np.empty(
                [self.size] + list(frame.shape),
                dtype=np.float32 if self.lander else np.uint8)
            self.action = np.empty([self.size], dtype=np.int32)
            self.reward = np.empty([self.size], dtype=np.float32)
            # the builtin bool, since the np.bool alias was removed in
            # NumPy 1.24
            self.done = np.empty([self.size], dtype=bool)
        ret = self.next_idx
        self.obs[ret] = frame
        # new transitions enter with maximal priority so they are sampled
        # at least once
        self._it_sum[ret] = self._max_priority ** self.alpha
        self._it_min[ret] = self._max_priority ** self.alpha

        self.next_idx = (self.next_idx + 1) % self.size
        self.num_in_buffer = min(self.size, self.num_in_buffer + 1)
        return ret

    def store_effect(self, idx, action, reward, done):
        """Store the effects of the action taken after observing the frame
        stored at index idx. The reason `store_frame` and `store_effect` are
        broken up into two functions is so that one can call
        `encode_recent_observation` in between.

        Parameters
        ----------
        idx: int
            Index in buffer of recently observed frame
            (returned by `store_frame`).
        action: int
            Action that was performed upon observing this frame.
        reward: float
            Reward that was received when the action was performed.
        done: bool
            True if the episode was finished after performing that action.
        """
        self.action[idx] = action
        self.reward[idx] = reward
        self.done[idx] = done
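# _encode_observation returns frame_history_len frames concatenated along
# the channel axis, zero-padding histories that are too short. A toy shape
# check of that layout; the sizes are illustrative:
import numpy as np

frames = [np.full((84, 84, 1), i, dtype=np.uint8) for i in range(4)]
stacked = np.concatenate(frames, 2)
assert stacked.shape == (84, 84, 4)
assert stacked[0, 0].tolist() == [0, 1, 2, 3]  # oldest frame first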
class DoublePrioritizedReplayBuffer(ReplayBuffer):
    def __init__(self, size, alpha, epsilon, timesteps, initial_p, final_p):
        super(DoublePrioritizedReplayBuffer, self).__init__(size)
        assert alpha > 0
        self._alpha = alpha
        self._epsilon = epsilon
        self._beta_schedule = LinearSchedule(timesteps,
                                             initial_p=initial_p,
                                             final_p=final_p)

        it_capacity = 1
        while it_capacity < size:
            it_capacity *= 2

        # two independent priority structures over the same storage
        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0
        self._it_sum2 = SumSegmentTree(it_capacity)
        self._it_min2 = MinSegmentTree(it_capacity)
        self._max_priority2 = 1.0

    def add(self, *args, **kwargs):
        idx = self._next_idx
        super().add(*args, **kwargs)
        self._it_sum[idx] = self._max_priority ** self._alpha
        self._it_min[idx] = self._max_priority ** self._alpha
        self._it_sum2[idx] = self._max_priority2 ** self._alpha
        self._it_min2[idx] = self._max_priority2 ** self._alpha

    def _sample_proportional(self, batch_size):
        res = []
        for _ in range(batch_size):
            mass = random.random() * self._it_sum.sum(0, len(self._storage) - 1)
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)
        return res

    def _sample_proportional2(self, batch_size):
        res = []
        for _ in range(batch_size):
            mass = random.random() * self._it_sum2.sum(0, len(self._storage) - 1)
            idx = self._it_sum2.find_prefixsum_idx(mass)
            res.append(idx)
        return res

    def sample(self, batch_size, time_step):
        beta = self._beta_schedule.value(time_step)
        assert beta > 0
        idxes = self._sample_proportional(batch_size)
        self.idxes = idxes  # keep to update priorities later
        weights = []
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * len(self._storage)) ** (-beta)
        for idx in idxes:
            p_sample = self._it_sum[idx] / self._it_sum.sum()
            weight = (p_sample * len(self._storage)) ** (-beta)
            weights.append(weight / max_weight)
        weights = np.array(weights)
        encoded_sample = self._encode_sample(idxes)
        return encoded_sample + (weights, )

    def sample_qmap(self, batch_size, time_step, n_steps=1):
        beta = self._beta_schedule.value(time_step)
        assert beta > 0
        idxes = self._sample_proportional2(batch_size)
        self.idxes2 = idxes  # keep to update priorities later
        weights = []
        p_min = self._it_min2.min() / self._it_sum2.sum()
        max_weight = (p_min * len(self._storage)) ** (-beta)
        for idx in idxes:
            p_sample = self._it_sum2[idx] / self._it_sum2.sum()
            weight = (p_sample * len(self._storage)) ** (-beta)
            weights.append(weight / max_weight)
        weights = np.array(weights)
        encoded_sample = self._encode_qmap_sample(idxes, n_steps)
        return encoded_sample + (weights, )

    def update_priorities(self, td_errors):
        priorities = np.abs(td_errors) + self._epsilon
        idxes = self.idxes
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            assert priority > 0
            assert 0 <= idx < len(self._storage)
            self._it_sum[idx] = priority ** self._alpha
            self._it_min[idx] = priority ** self._alpha
            self._max_priority = max(self._max_priority, priority)

    def update_priorities_qmap(self, td_errors):
        priorities = np.abs(td_errors) + self._epsilon
        idxes = self.idxes2
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            assert priority > 0
            assert 0 <= idx < len(self._storage)
            self._it_sum2[idx] = priority ** self._alpha
            self._it_min2[idx] = priority ** self._alpha
            self._max_priority2 = max(self._max_priority2, priority)
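# This buffer anneals beta internally through LinearSchedule rather than
# taking beta as a sample() argument. A minimal sketch with the usual
# Baselines-style semantics (an assumption about the imported class, not a
# verified copy):
class LinearScheduleSketch:
    def __init__(self, schedule_timesteps, initial_p, final_p):
        self.schedule_timesteps = schedule_timesteps
        self.initial_p = initial_p
        self.final_p = final_p

    def value(self, t):
        # linear interpolation from initial_p to final_p, then held constant
        fraction = min(float(t) / self.schedule_timesteps, 1.0)
        return self.initial_p + fraction * (self.final_p - self.initial_p)


beta = LinearScheduleSketch(100, initial_p=0.4, final_p=1.0)
assert beta.value(0) == 0.4 and beta.value(200) == 1.0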
class PrioritizedReplayBuffer(ReplayBuffer):
    def __init__(self, size, alpha):
        """Create Prioritized Replay buffer.

        Parameters
        ----------
        size: int
            Max number of transitions to store in the buffer. When the buffer
            overflows the old memories are dropped.
        alpha: float
            how much prioritization is used
            (0 - no prioritization, 1 - full prioritization)

        See Also
        --------
        ReplayBuffer.__init__
        """
        super(PrioritizedReplayBuffer, self).__init__(size)
        assert alpha >= 0
        self._alpha = alpha

        it_capacity = 1
        while it_capacity < size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0

    def add(self, *args, **kwargs):
        """See ReplayBuffer.store_effect"""
        idx = self._next_idx
        super().add(*args, **kwargs)
        self._it_sum[idx] = self._max_priority ** self._alpha
        self._it_min[idx] = self._max_priority ** self._alpha

    def _sample_proportional(self, batch_size):
        res = []
        for _ in range(batch_size):
            # TODO(szymon): should we ensure no repeats?
            mass = random.random() * self._it_sum.sum(0, len(self._storage) - 1)
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)
        return res

    def sample(self, batch_size, beta):
        """Sample a batch of experiences.

        Compared to ReplayBuffer.sample it also returns importance weights
        and idxes of sampled experiences.

        Parameters
        ----------
        batch_size: int
            How many transitions to sample.
        beta: float
            To what degree to use importance weights
            (0 - no corrections, 1 - full correction)

        Returns
        -------
        obs_batch: np.array
            batch of observations
        act_batch: np.array
            batch of actions executed given obs_batch
        rew_batch: np.array
            rewards received as results of executing act_batch
        next_obs_batch: np.array
            next set of observations seen after executing act_batch
        done_mask: np.array
            done_mask[i] = 1 if executing act_batch[i] resulted in the end of
            an episode and 0 otherwise.
        weights: np.array
            Array of shape (batch_size,) and dtype np.float32 denoting the
            importance weight of each sampled transition
        idxes: np.array
            Array of shape (batch_size,) and dtype np.int32 of indexes in the
            buffer of sampled experiences
        """
        assert beta > 0

        idxes = self._sample_proportional(batch_size)

        weights = []
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * len(self._storage)) ** (-beta)

        for idx in idxes:
            p_sample = self._it_sum[idx] / self._it_sum.sum()
            weight = (p_sample * len(self._storage)) ** (-beta)
            weights.append(weight / max_weight)
        weights = np.array(weights)
        encoded_sample = self._encode_sample(idxes)
        return tuple(list(encoded_sample) + [weights, idxes])

    def update_priorities(self, idxes, priorities):
        """Update priorities of sampled transitions.

        Sets priority of transition at index idxes[i] in buffer
        to priorities[i].

        Parameters
        ----------
        idxes: [int]
            List of idxes of sampled transitions
        priorities: [float]
            List of updated priorities corresponding to transitions at the
            sampled idxes denoted by variable `idxes`.
        """
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            assert priority > 0
            assert 0 <= idx < len(self._storage)
            self._it_sum[idx] = priority ** self._alpha
            self._it_min[idx] = priority ** self._alpha
            self._max_priority = max(self._max_priority, priority)
class PrioritizedReplayBuffer_NStep(ReplayBuffer_NStep):
    def __init__(self, size, alpha):
        """Create Prioritized Replay buffer.

        Parameters
        ----------
        size: int
            Max number of transitions to store in the buffer. When the buffer
            overflows the old memories are dropped.
        alpha: float
            how much prioritization is used
            (0 - no prioritization, 1 - full prioritization)

        See Also
        --------
        ReplayBuffer.__init__
        """
        super(PrioritizedReplayBuffer_NStep, self).__init__(size)
        assert alpha > 0
        self._alpha = alpha

        it_capacity = 1
        while it_capacity < size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0

    def add(self, *args, **kwargs):
        """See ReplayBuffer.store_effect"""
        idx = self._next_idx
        super().add(*args, **kwargs)
        self._it_sum[idx] = self._max_priority ** self._alpha
        self._it_min[idx] = self._max_priority ** self._alpha

    def _sample_proportional(self, batch_size):
        res = []
        for _ in range(batch_size):
            # TODO(szymon): should we ensure no repeats?
            mass = random.random() * self._it_sum.sum(0, len(self._storage) - 1)
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)
        return res

    def sample(self, batch_size, beta):
        """Sample a batch of experiences.

        Parameters
        ----------
        batch_size: int
            How many transitions to sample.
        beta: float
            To what degree to use importance weights
            (0 - no corrections, 1 - full correction). The correction scales
            the gradients of high-error samples down; increase it as training
            progresses.

        Returns
        -------
        obs_batch: np.array
            batch of observations
        """
        assert beta > 0
        idxes = self._sample_proportional(batch_size)
        weights = []
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * len(self._storage)) ** (-beta)
        for idx in idxes:
            p_sample = self._it_sum[idx] / self._it_sum.sum()
            weight = (p_sample * len(self._storage)) ** (-beta)
            weights.append(weight / max_weight)
        weights = np.array(weights)
        encoded_sample = self._encode_sample(idxes)
        return tuple(list(encoded_sample) + [weights, idxes])

    def update_priorities(self, idxes, priorities):
        """Update priorities of sampled transitions.

        Parameters
        ----------
        idxes: [int]
            List of idxes of sampled transitions
        priorities: [float]
            List of updated priorities corresponding to transitions at the
            sampled idxes denoted by variable `idxes`.
        """
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            assert priority > 0
            assert 0 <= idx < len(self._storage)
            self._it_sum[idx] = priority ** self._alpha
            self._it_min[idx] = priority ** self._alpha
            self._max_priority = max(self._max_priority, priority)
class PrioritizedReplayBuffer(ReplayBuffer):
    def __init__(self, size, alpha):
        """
        Create Prioritized Replay buffer.

        See Also ReplayBuffer.__init__

        :param size: (int) Max number of transitions to store in the buffer.
            When the buffer overflows the old memories are dropped.
        :param alpha: (float) how much prioritization is used
            (0 - no prioritization, 1 - full prioritization)
        """
        super(PrioritizedReplayBuffer, self).__init__(size)
        assert alpha >= 0
        self._alpha = alpha

        it_capacity = 1
        while it_capacity < size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0

    def add(self, obs_t, action, reward, obs_tp1, done):
        """
        add a new transition to the buffer

        :param obs_t: (Any) the last observation
        :param action: ([float]) the action
        :param reward: (float) the reward of the transition
        :param obs_tp1: (Any) the current observation
        :param done: (bool) is the episode done
        """
        idx = self._next_idx
        super().add(obs_t, action, reward, obs_tp1, done)
        self._it_sum[idx] = self._max_priority ** self._alpha
        self._it_min[idx] = self._max_priority ** self._alpha

    def _sample_proportional(self, batch_size):
        res = []
        for _ in range(batch_size):
            # TODO(szymon): should we ensure no repeats?
            mass = random.random() * self._it_sum.sum(0, len(self._storage) - 1)
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)
        return res

    def sample(self, batch_size, beta=1.0):
        """
        Sample a batch of experiences.

        Compared to ReplayBuffer.sample it also returns importance weights
        and idxes of sampled experiences. Note that beta must be strictly
        positive to satisfy the assertion below, so the default requests the
        full importance-sampling correction.

        :param batch_size: (int) How many transitions to sample.
        :param beta: (float) To what degree to use importance weights
            (0 - no corrections, 1 - full correction)
        :return:
            - obs_batch: (numpy Any) batch of observations
            - act_batch: (numpy float) batch of actions executed given
              obs_batch
            - rew_batch: (numpy float) rewards received as results of
              executing act_batch
            - next_obs_batch: (numpy Any) next set of observations seen after
              executing act_batch
            - done_mask: (numpy bool) done_mask[i] = 1 if executing
              act_batch[i] resulted in the end of an episode and 0 otherwise.
            - weights: (numpy float) Array of shape (batch_size,) and dtype
              np.float32 denoting the importance weight of each sampled
              transition
            - idxes: (numpy int) Array of shape (batch_size,) and dtype
              np.int32 of indexes in the buffer of sampled experiences
        """
        assert beta > 0
        idxes = self._sample_proportional(batch_size)
        weights = []
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * len(self._storage)) ** (-beta)
        for idx in idxes:
            p_sample = self._it_sum[idx] / self._it_sum.sum()
            weight = (p_sample * len(self._storage)) ** (-beta)
            weights.append(weight / max_weight)
        weights = np.array(weights)
        encoded_sample = self._encode_sample(idxes)
        return tuple(list(encoded_sample) + [weights, idxes])

    def update_priorities(self, idxes, priorities):
        """
        Update priorities of sampled transitions.

        Sets priority of transition at index idxes[i] in buffer
        to priorities[i].

        :param idxes: ([int]) List of idxes of sampled transitions
        :param priorities: ([float]) List of updated priorities corresponding
            to transitions at the sampled idxes denoted by variable `idxes`.
        """
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            assert priority > 0
            assert 0 <= idx < len(self._storage)
            self._it_sum[idx] = priority ** self._alpha
            self._it_min[idx] = priority ** self._alpha
            self._max_priority = max(self._max_priority, priority)
class PrioritizedMemory(Memory):
    def __init__(self, limit, alpha, transition_small_epsilon=1e-6,
                 demo_epsilon=0.2, nb_rollout_steps=100):
        super(PrioritizedMemory, self).__init__(limit, nb_rollout_steps)
        assert alpha > 0
        self._alpha = alpha
        self._transition_small_epsilon = transition_small_epsilon
        self._demo_epsilon = demo_epsilon

        it_capacity = 1
        while it_capacity < self.maxsize:
            it_capacity *= 2  # size must be a power of 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0

    def append(self, *args, **kwargs):
        with self.condition:
            idx = self._next_idx
            if not super().append(*args, **kwargs):
                return
            self._it_sum[idx] = self._max_priority
            self._it_min[idx] = self._max_priority

    def append_demonstration(self, *args, **kwargs):
        with self.lock:
            idx = self._next_idx
            if not super().append(*args, **kwargs, count=False):
                return
            self._it_sum[idx] = self._max_priority
            self._it_min[idx] = self._max_priority
            self.num_demonstrations += 1

    def _sample_proportional(self, batch_size, pretrain):
        with self.lock:
            res = []
            if pretrain:
                # np.random.random_integers is deprecated; randint has an
                # exclusive upper bound, hence nb_entries rather than
                # nb_entries - 1
                res = np.random.randint(low=0, high=self.nb_entries,
                                        size=batch_size)
                return res
            for _ in range(batch_size):
                while True:
                    mass = np.random.uniform(
                        0, self._it_sum.sum(0, len(self.storage) - 1))
                    idx = self._it_sum.find_prefixsum_idx(mass)
                    if idx not in res:
                        res.append(idx)
                        break
            return res

    def sample_prioritized(self, batch_size, beta, pretrain=False):
        with self.lock:
            idxes = self._sample_proportional(batch_size, pretrain)
            demos = [i < self.num_demonstrations for i in idxes]
            weights = []
            p_sum = self._it_sum.sum()
            for idx in idxes:
                p_sample = self._it_sum[idx] / p_sum
                weight = ((1.0 / p_sample) * (1.0 / len(self.storage))) ** beta
                weights.append(weight)
            weights = np.array(weights) / np.max(weights)
            encoded_sample = self._get_batches_for_idxes(idxes)
            encoded_sample['weights'] = array_min2d(weights)
            encoded_sample['idxes'] = idxes
            encoded_sample['demos'] = array_min2d(demos)
            return encoded_sample

    def sample_rollout(self, batch_size, nsteps, beta, gamma, pretrain=False):
        with self.lock:
            batches = self.sample_prioritized(batch_size, beta, pretrain)
            n_step_batches = {
                storable_element: []
                for storable_element in self.storable_elements
            }
            n_step_batches["step_reached"] = []
            idxes = batches["idxes"]
            for idx in idxes:
                local_idxes = list(range(idx, min(idx + nsteps, len(self))))
                transitions = self._get_batches_for_idxes(local_idxes)
                summed_reward = 0
                count = 0
                terminal = 0.0
                terminals = transitions['terminals1']
                r = transitions['rewards']
                for i in range(len(r)):
                    summed_reward += (gamma ** i) * r[i]
                    count = i
                    if terminals[i]:
                        terminal = 1.0
                        break
                n_step_batches["step_reached"].append(count)
                n_step_batches["obs1"].append(transitions["obs1"][count])
                n_step_batches["terminals1"].append(terminal)
                n_step_batches["rewards"].append(summed_reward)
                n_step_batches["states1"].append(transitions["states1"][count])
                n_step_batches["aux1"].append(transitions["aux1"][count])
                n_step_batches["actions"].append(transitions["actions"][0])
            n_step_batches['demos'] = batches['demos']
            n_step_batches = {
                k: array_min2d(v)
                for k, v in n_step_batches.items()
            }
            n_step_batches['weights'] = batches['weights']
            n_step_batches['idxes'] = idxes
            return batches, n_step_batches, sum(batches['demos']) / batch_size

    def update_priorities(self, idxes, td_errors, actor_losses=0.0):
        with self.lock:
            priorities = td_errors + \
                (actor_losses ** 2) + self._transition_small_epsilon
            # boost demonstration transitions so they keep being sampled
            for i in range(len(priorities)):
                if idxes[i] < self.num_demonstrations:
                    priorities[i] += np.max(priorities) * self._demo_epsilon
            assert len(idxes) == len(priorities)
            for idx, priority in zip(idxes, priorities):
                assert priority > 0
                assert 0 <= idx < len(self.storage)
                self._it_sum[idx] = priority ** self._alpha
                self._it_min[idx] = priority ** self._alpha
                self._max_priority = max(self._max_priority,
                                         priority ** self._alpha)
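# The inner loop of sample_rollout computes a truncated n-step return: the
# sum of gamma**i * r_i up to and including the first terminal transition.
# A toy check of that recurrence; the values are illustrative:
import numpy as np

gamma = 0.9
rewards = [1.0, 1.0, 1.0]
terminals = [False, True, False]

summed_reward, terminal = 0.0, 0.0
for i, r in enumerate(rewards):
    summed_reward += gamma ** i * r
    if terminals[i]:
        terminal = 1.0
        break  # the third reward is never reached

assert np.isclose(summed_reward, 1.0 + 0.9)
assert terminal == 1.0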