class GenericLearner(threading.Thread):
    """Background thread that applies updates computed from replay batches.

    Work items are taken from ``inqueue`` as ``(ra, replay)`` pairs. For each
    pair whose ``replay`` is not ``None``, ``compute_apply(replay)`` is called
    on the local evaluator and ``(ra, replay, td_error)`` is put on
    ``outqueue``.
    """

    def __init__(self, local_evaluator):
        """Set up queues, timers and flags (the thread is not started here).

        Args:
            local_evaluator: object whose ``compute_apply(replay)`` is called
                for every non-``None`` replay batch read from ``inqueue``.
        """
        threading.Thread.__init__(self)
        self.learner_queue_size = WindowStat("size", 50)
        self.local_evaluator = local_evaluator
        self.inqueue = queue.Queue(maxsize=LEARNER_QUEUE_MAX_SIZE)
        self.outqueue = queue.Queue()
        self.queue_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.daemon = True
        self.weights_updated = False
        # Set to True to make run() return. The flag is only checked between
        # steps, and step() blocks on inqueue.get(), so the loop exits once
        # the step in progress has finished.
        self.stopped = False

    def run(self):
        """Call step() repeatedly until ``stopped`` is set."""
        while not self.stopped:
            self.step()

    def step(self):
        """Process one ``(ra, replay)`` item from ``inqueue``."""
        with self.queue_timer:
            ra, replay = self.inqueue.get()
        if replay is not None:
            with self.grad_timer:
                td_error = self.local_evaluator.compute_apply(replay)
            self.outqueue.put((ra, replay, td_error))
        self.learner_queue_size.push(self.inqueue.qsize())
        self.weights_updated = True
class LearnerThread(threading.Thread):
    """Background thread that updates the local model from replay data.

    The learner thread communicates with the main thread through Queues. This
    is needed since Ray operations can only be run on the main thread. In
    addition, moving heavyweight gradient ops session runs off the main thread
    improves overall throughput.
    """

    def __init__(self, local_evaluator):
        """Set up queues, timers and flags (the thread is not started here).

        Args:
            local_evaluator: object whose ``compute_apply(replay)`` is called
                for every non-``None`` replay batch; the ``"td_error"`` entry
                of its result is forwarded on ``outqueue``.
        """
        threading.Thread.__init__(self)
        self.learner_queue_size = WindowStat("size", 50)
        self.local_evaluator = local_evaluator
        self.inqueue = queue.Queue(maxsize=LEARNER_QUEUE_MAX_SIZE)
        self.outqueue = queue.Queue()
        self.queue_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.daemon = True
        self.weights_updated = False
        # Set to True to make run() return. The flag is only checked between
        # steps, and step() blocks on inqueue.get(), so the loop exits once
        # the step in progress has finished.
        self.stopped = False

    def run(self):
        """Call step() repeatedly until ``stopped`` is set."""
        while not self.stopped:
            self.step()

    def step(self):
        """Process one ``(ra, replay)`` item from ``inqueue``.

        For a non-``None`` replay batch, puts
        ``(ra, replay, td_error, replay.count)`` on ``outqueue``.
        """
        with self.queue_timer:
            ra, replay = self.inqueue.get()
        if replay is not None:
            with self.grad_timer:
                td_error = self.local_evaluator.compute_apply(
                    replay)["td_error"]
            self.outqueue.put((ra, replay, td_error, replay.count))
        self.learner_queue_size.push(self.inqueue.qsize())
        self.weights_updated = True
class LearnerThread(threading.Thread):
    """Daemon thread that trains the local worker on sampled trajectories.

    This is for use with AsyncSamplesOptimizer.

    All communication with the main thread goes through Queues, because Ray
    operations can only be run on the main thread. Keeping the heavyweight
    gradient ops session runs off the main thread also improves overall
    throughput.
    """

    def __init__(self, local_worker, minibatch_buffer_size, num_sgd_iter,
                 learner_queue_size, learner_queue_timeout):
        """Create the learner thread (it is not started here).

        Arguments:
            local_worker (RolloutWorker): process local rollout worker holding
                policies this thread will call learn_on_batch() on
            minibatch_buffer_size (int): max number of train batches to store
                in the minibatching buffer
            num_sgd_iter (int): number of passes to learn on per train batch
            learner_queue_size (int): max size of queue of inbound
                train batches to this thread
            learner_queue_timeout (int): raise an exception if the queue has
                been empty for this long in seconds
        """
        super().__init__()
        self.daemon = True
        self.local_worker = local_worker

        # Inbound train batches and outbound per-step batch counts.
        self.inqueue = queue.Queue(maxsize=learner_queue_size)
        self.outqueue = queue.Queue()
        self.minibatch_buffer = MinibatchBuffer(
            inqueue=self.inqueue,
            size=minibatch_buffer_size,
            timeout=learner_queue_timeout,
            num_passes=num_sgd_iter)

        # Metrics.
        self.learner_queue_size = WindowStat("size", 50)
        self.queue_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.load_timer = TimerStat()
        self.load_wait_timer = TimerStat()
        self.stats = {}

        # State flags.
        self.weights_updated = False
        self.stopped = False

    def run(self):
        """Keep stepping until ``stopped`` becomes true."""
        while True:
            if self.stopped:
                break
            self.step()

    def step(self):
        """Learn on the next minibatch and report its size on ``outqueue``."""
        with self.queue_timer:
            train_batch, _ = self.minibatch_buffer.get()

        with self.grad_timer:
            learn_results = self.local_worker.learn_on_batch(train_batch)
            self.weights_updated = True
            self.stats = get_learner_stats(learn_results)

        self.outqueue.put(train_batch.count)
        pending = self.inqueue.qsize()
        self.learner_queue_size.push(pending)
class LearnerThread(threading.Thread):
    """Background thread that updates the local model from replay data.

    The learner thread communicates with the main thread through Queues. This
    is needed since Ray operations can only be run on the main thread. In
    addition, moving heavyweight gradient ops session runs off the main thread
    improves overall throughput.
    """

    def __init__(self, local_worker):
        """Set up queues, timers and flags (the thread is not started here).

        Args:
            local_worker: worker whose ``learn_on_batch()`` is called on each
                replay batch and whose ``policy_config`` is read in ``run()``.
        """
        threading.Thread.__init__(self)
        self.learner_queue_size = WindowStat("size", 50)
        self.local_worker = local_worker
        self.inqueue = queue.Queue(maxsize=LEARNER_QUEUE_MAX_SIZE)
        self.outqueue = queue.Queue()
        self.queue_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.overall_timer = TimerStat()
        self.daemon = True
        self.weights_updated = False
        self.stopped = False
        # Per-policy learner stats from the most recent step.
        self.stats = {}

    def run(self):
        """Call step() repeatedly until ``stopped`` is set."""
        # Switch on eager mode if configured.
        if self.local_worker.policy_config.get("framework") in ["tf2", "tfe"]:
            tf1.enable_eager_execution()
        while not self.stopped:
            self.step()

    def step(self):
        """Process one ``(ra, replay)`` item from ``inqueue``.

        For a non-``None`` replay batch, learns on it and puts
        ``(ra, prio_dict, replay.count)`` on ``outqueue``, where ``prio_dict``
        maps each policy id to ``(batch_indexes, td_error)``.
        """
        with self.overall_timer:
            with self.queue_timer:
                ra, replay = self.inqueue.get()
            if replay is not None:
                prio_dict = {}
                with self.grad_timer:
                    grad_out = self.local_worker.learn_on_batch(replay)
                    for pid, info in grad_out.items():
                        # Only consult the learner stats when the top-level
                        # key is absent, so a result that carries "td_error"
                        # directly does not also need LEARNER_STATS_KEY.
                        if "td_error" in info:
                            td_error = info["td_error"]
                        else:
                            td_error = info[LEARNER_STATS_KEY].get("td_error")
                        # Switch off auto-conversion from numpy to torch/tf
                        # tensors for the indices. This may lead to errors
                        # when sent to the buffer for processing
                        # (may get manipulated if they are part of a tensor).
                        replay.policy_batches[pid].set_get_interceptor(None)
                        prio_dict[pid] = (
                            replay.policy_batches[pid].get("batch_indexes"),
                            td_error)
                        self.stats[pid] = get_learner_stats(info)
                    self.grad_timer.push_units_processed(replay.count)
                self.outqueue.put((ra, prio_dict, replay.count))
            self.learner_queue_size.push(self.inqueue.qsize())
            self.weights_updated = True
        # Nothing was processed when the queue item carried no batch.
        self.overall_timer.push_units_processed(
            replay.count if replay is not None else 0)
class LearnerThread(threading.Thread):
    """Background thread that updates the local model from replay data.

    The learner thread communicates with the main thread through Queues. This
    is needed since Ray operations can only be run on the main thread. In
    addition, moving heavyweight gradient ops session runs off the main thread
    improves overall throughput.
    """

    def __init__(self, local_worker):
        """Set up queues, timers and flags (the thread is not started here).

        Args:
            local_worker: worker whose ``learn_on_batch()`` is called on each
                replay batch taken from ``inqueue``.
        """
        threading.Thread.__init__(self)
        self.learner_queue_size = WindowStat("size", 50)
        self.local_worker = local_worker
        self.inqueue = queue.Queue(maxsize=LEARNER_QUEUE_MAX_SIZE)
        self.outqueue = queue.Queue()
        self.queue_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.overall_timer = TimerStat()
        self.daemon = True
        self.weights_updated = False
        self.stopped = False
        # Per-policy learner stats from the most recent step.
        self.stats = {}

    def run(self):
        """Call step() repeatedly until ``stopped`` is set."""
        while not self.stopped:
            self.step()

    def step(self):
        """Process one ``(ra, replay)`` item from ``inqueue``.

        For a non-``None`` replay batch, learns on it and puts
        ``(ra, prio_dict, replay.count)`` on ``outqueue``, where ``prio_dict``
        maps each policy id to ``(batch_indexes, td_error)``.
        """
        with self.overall_timer:
            with self.queue_timer:
                ra, replay = self.inqueue.get()
            if replay is not None:
                prio_dict = {}
                with self.grad_timer:
                    grad_out = self.local_worker.learn_on_batch(replay)
                    for pid, info in grad_out.items():
                        # Only consult the learner stats when the top-level
                        # key is absent, so a result that carries "td_error"
                        # directly does not also need LEARNER_STATS_KEY.
                        if "td_error" in info:
                            td_error = info["td_error"]
                        else:
                            td_error = info[LEARNER_STATS_KEY].get("td_error")
                        prio_dict[pid] = (
                            replay.policy_batches[pid].data.get(
                                "batch_indexes"),
                            td_error)
                        self.stats[pid] = get_learner_stats(info)
                    self.grad_timer.push_units_processed(replay.count)
                self.outqueue.put((ra, prio_dict, replay.count))
            self.learner_queue_size.push(self.inqueue.qsize())
            self.weights_updated = True
        # Nothing was processed when the queue item carried no batch.
        self.overall_timer.push_units_processed(
            replay.count if replay is not None else 0)
class LearnerThread(threading.Thread):
    """Daemon thread that trains the local worker on sampled trajectories.

    This is for use with AsyncSamplesOptimizer.

    All communication with the main thread goes through Queues, because Ray
    operations can only be run on the main thread. Keeping the heavyweight
    gradient ops session runs off the main thread also improves overall
    throughput.
    """

    def __init__(self, local_worker, minibatch_buffer_size, num_sgd_iter,
                 learner_queue_size, learner_queue_timeout):
        """Create the learner thread (it is not started here).

        ``learner_queue_size`` bounds ``inqueue``; the remaining sizing and
        timeout arguments are forwarded to the MinibatchBuffer that reads
        from that queue.
        """
        super().__init__()
        self.daemon = True
        self.local_worker = local_worker

        # Inbound train batches and outbound per-step batch counts.
        self.inqueue = queue.Queue(maxsize=learner_queue_size)
        self.outqueue = queue.Queue()
        self.minibatch_buffer = MinibatchBuffer(
            inqueue=self.inqueue,
            size=minibatch_buffer_size,
            timeout=learner_queue_timeout,
            num_passes=num_sgd_iter)

        # Metrics.
        self.learner_queue_size = WindowStat("size", 50)
        self.queue_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.load_timer = TimerStat()
        self.load_wait_timer = TimerStat()
        self.stats = {}

        # State flags.
        self.weights_updated = False
        self.stopped = False

    def run(self):
        """Keep stepping until ``stopped`` becomes true."""
        while True:
            if self.stopped:
                break
            self.step()

    def step(self):
        """Learn on the next minibatch and report its size on ``outqueue``."""
        with self.queue_timer:
            train_batch, _ = self.minibatch_buffer.get()

        with self.grad_timer:
            learn_results = self.local_worker.learn_on_batch(train_batch)
            self.weights_updated = True
            self.stats = get_learner_stats(learn_results)

        self.outqueue.put(train_batch.count)
        pending = self.inqueue.qsize()
        self.learner_queue_size.push(pending)
class LearnerThread(threading.Thread):
    """Daemon thread that updates the local model from replay data.

    All communication with the main thread goes through Queues, because Ray
    operations can only be run on the main thread. Keeping the heavyweight
    gradient ops session runs off the main thread also improves overall
    throughput.
    """

    def __init__(self, local_evaluator):
        """Create the learner thread around ``local_evaluator``."""
        super().__init__()
        self.daemon = True
        self.local_evaluator = local_evaluator

        # Inbound (ra, replay) items and outbound priority updates.
        self.inqueue = queue.Queue(maxsize=LEARNER_QUEUE_MAX_SIZE)
        self.outqueue = queue.Queue()

        # Metrics.
        self.learner_queue_size = WindowStat("size", 50)
        self.queue_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.stats = {}

        # State flags.
        self.weights_updated = False
        self.stopped = False

    def run(self):
        """Keep stepping until ``stopped`` becomes true."""
        while True:
            if self.stopped:
                break
            self.step()

    def step(self):
        """Process one ``(ra, replay)`` item from ``inqueue``.

        For a non-``None`` replay batch, puts
        ``(ra, prio_dict, replay.count)`` on ``outqueue``; ``prio_dict`` maps
        each policy id to ``(batch_indexes, td_error)``.
        """
        with self.queue_timer:
            ra, replay = self.inqueue.get()

        if replay is not None:
            priorities = {}
            with self.grad_timer:
                per_policy = self.local_evaluator.compute_apply(replay)
                for policy_id, result in per_policy.items():
                    indexes = replay.policy_batches[policy_id].data.get(
                        "batch_indexes")
                    priorities[policy_id] = (indexes, result.get("td_error"))
                    if "stats" in result:
                        self.stats[policy_id] = result["stats"]
            self.outqueue.put((ra, priorities, replay.count))

        pending = self.inqueue.qsize()
        self.learner_queue_size.push(pending)
        self.weights_updated = True
class LearnerThread(threading.Thread):
    """Daemon thread that trains the local model on sampled trajectories.

    All communication with the main thread goes through Queues, because Ray
    operations can only be run on the main thread. Keeping the heavyweight
    gradient ops session runs off the main thread also improves overall
    throughput.
    """

    def __init__(self, local_evaluator, minibatch_buffer_size, num_sgd_iter,
                 learner_queue_size):
        """Create the learner thread (it is not started here).

        ``learner_queue_size`` bounds ``inqueue``; ``minibatch_buffer_size``
        and ``num_sgd_iter`` are forwarded to the MinibatchBuffer that reads
        from that queue.
        """
        super().__init__()
        self.daemon = True
        self.local_evaluator = local_evaluator

        # Inbound train batches and outbound per-step batch counts.
        self.inqueue = queue.Queue(maxsize=learner_queue_size)
        self.outqueue = queue.Queue()
        self.minibatch_buffer = MinibatchBuffer(
            self.inqueue, minibatch_buffer_size, num_sgd_iter)

        # Metrics.
        self.learner_queue_size = WindowStat("size", 50)
        self.queue_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.load_timer = TimerStat()
        self.load_wait_timer = TimerStat()
        self.stats = {}

        # State flags.
        self.weights_updated = False
        self.stopped = False

    def run(self):
        """Keep stepping until ``stopped`` becomes true."""
        while True:
            if self.stopped:
                break
            self.step()

    def step(self):
        """Learn on the next minibatch and report its size on ``outqueue``."""
        with self.queue_timer:
            train_batch, _ = self.minibatch_buffer.get()

        with self.grad_timer:
            learn_results = self.local_evaluator.learn_on_batch(train_batch)
            self.weights_updated = True
            self.stats = learn_results.get("stats", {})

        self.outqueue.put(train_batch.count)
        pending = self.inqueue.qsize()
        self.learner_queue_size.push(pending)
class LearnerThread(threading.Thread):
    """Daemon thread that trains the local model on sampled trajectories.

    All communication with the main thread goes through Queues, because Ray
    operations can only be run on the main thread. Keeping the heavyweight
    gradient ops session runs off the main thread also improves overall
    throughput.
    """

    def __init__(self, local_evaluator, minibatch_buffer_size, num_sgd_iter):
        """Create the learner thread (it is not started here).

        ``minibatch_buffer_size`` and ``num_sgd_iter`` are forwarded to the
        MinibatchBuffer that reads from ``inqueue``.
        """
        super().__init__()
        self.daemon = True
        self.local_evaluator = local_evaluator

        # Inbound train batches and outbound per-step batch counts.
        self.inqueue = queue.Queue(maxsize=LEARNER_QUEUE_MAX_SIZE)
        self.outqueue = queue.Queue()
        self.minibatch_buffer = MinibatchBuffer(
            self.inqueue, minibatch_buffer_size, num_sgd_iter)

        # Metrics.
        self.learner_queue_size = WindowStat("size", 50)
        self.queue_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.load_timer = TimerStat()
        self.load_wait_timer = TimerStat()
        self.stats = {}

        # State flags.
        self.weights_updated = False
        self.stopped = False

    def run(self):
        """Keep stepping until ``stopped`` becomes true."""
        while True:
            if self.stopped:
                break
            self.step()

    def step(self):
        """Apply one minibatch and report its size on ``outqueue``."""
        with self.queue_timer:
            train_batch, _ = self.minibatch_buffer.get()

        with self.grad_timer:
            apply_results = self.local_evaluator.compute_apply(train_batch)
            self.weights_updated = True
            self.stats = apply_results.get("stats", {})

        self.outqueue.put(train_batch.count)
        pending = self.inqueue.qsize()
        self.learner_queue_size.push(pending)
class LearnerThread(threading.Thread):
    """Daemon thread that updates the local model from replay data.

    All communication with the main thread goes through Queues, because Ray
    operations can only be run on the main thread. Keeping the heavyweight
    gradient ops session runs off the main thread also improves overall
    throughput.
    """

    def __init__(self, local_evaluator):
        """Create the learner thread around ``local_evaluator``."""
        super().__init__()
        self.daemon = True
        self.local_evaluator = local_evaluator

        # Inbound (ra, replay) items and outbound priority updates.
        self.inqueue = queue.Queue(maxsize=LEARNER_QUEUE_MAX_SIZE)
        self.outqueue = queue.Queue()

        # Metrics.
        self.learner_queue_size = WindowStat("size", 50)
        self.queue_timer = TimerStat()
        self.grad_timer = TimerStat()

        # State flags.
        self.weights_updated = False
        self.stopped = False

    def run(self):
        """Keep stepping until ``stopped`` becomes true."""
        while True:
            if self.stopped:
                break
            self.step()

    def step(self):
        """Process one ``(ra, replay)`` item from ``inqueue``.

        For a non-``None`` replay batch, puts
        ``(ra, replay, prio_dict, replay.count)`` on ``outqueue``;
        ``prio_dict`` maps each policy id to ``(batch_indexes, td_error)``.
        """
        with self.queue_timer:
            ra, replay = self.inqueue.get()

        if replay is not None:
            priorities = {}
            with self.grad_timer:
                per_policy = self.local_evaluator.compute_apply(replay)
                for policy_id, result in per_policy.items():
                    indexes = replay.policy_batches[policy_id]["batch_indexes"]
                    priorities[policy_id] = (indexes, result["td_error"])
            # send `replay` back also so that it gets released by the original
            # thread: https://github.com/ray-project/ray/issues/2610
            self.outqueue.put((ra, replay, priorities, replay.count))

        pending = self.inqueue.qsize()
        self.learner_queue_size.push(pending)
        self.weights_updated = True
class LearnerThread(threading.Thread):
    """Background thread that updates the local model from sample trajectories.

    The learner thread communicates with the main thread through Queues. This
    is needed since Ray operations can only be run on the main thread. In
    addition, moving heavyweight gradient ops session runs off the main thread
    improves overall throughput.
    """

    def __init__(self, local_evaluator):
        """Set up queues, timers and counters (the thread is not started).

        Args:
            local_evaluator: object whose ``compute_apply(batch)`` is called
                for every non-``None`` batch read from ``inqueue``.
        """
        threading.Thread.__init__(self)
        self.learner_queue_size = WindowStat("size", 50)
        self.local_evaluator = local_evaluator
        self.inqueue = queue.Queue(maxsize=LEARNER_QUEUE_MAX_SIZE)
        self.outqueue = queue.Queue()
        self.queue_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.daemon = True
        # Number of batches applied so far (a counter, not a flag).
        self.weights_updated = 0
        self.stats = {}
        # Set to True to make run() return. The flag is only checked between
        # steps, and step() blocks on inqueue.get(), so the loop exits once
        # the step in progress has finished.
        self.stopped = False

    def run(self):
        """Call step() repeatedly until ``stopped`` is set."""
        while not self.stopped:
            self.step()

    def step(self):
        """Process one ``(ra, batch)`` item from ``inqueue``.

        For a non-``None`` batch, applies it, bumps ``weights_updated``,
        refreshes ``stats`` when the result has a ``"stats"`` entry, and puts
        ``batch.count`` on ``outqueue``.
        """
        with self.queue_timer:
            ra, batch = self.inqueue.get()
        if batch is not None:
            with self.grad_timer:
                fetches = self.local_evaluator.compute_apply(batch)
                self.weights_updated += 1
                if "stats" in fetches:
                    self.stats = fetches["stats"]
            self.outqueue.put(batch.count)
        self.learner_queue_size.push(self.inqueue.qsize())
class LearnerThread(threading.Thread):
    """Background thread that updates the local model from replay data.

    The learner thread communicates with the main thread through Queues. This
    is needed since Ray operations can only be run on the main thread. In
    addition, moving heavyweight gradient ops session runs off the main thread
    improves overall throughput.
    """

    def __init__(self, local_worker):
        """Set up queues, timers and flags (the thread is not started here).

        Args:
            local_worker: worker whose ``learn_on_batch()`` is called on each
                replay batch and whose ``policy_config`` is read in ``run()``.
        """
        threading.Thread.__init__(self)
        self.learner_queue_size = WindowStat("size", 50)
        self.local_worker = local_worker
        self.inqueue = queue.Queue(maxsize=LEARNER_QUEUE_MAX_SIZE)
        self.outqueue = queue.Queue()
        self.queue_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.overall_timer = TimerStat()
        self.daemon = True
        self.weights_updated = False
        self.stopped = False
        # Finalized LearnerInfoBuilder output of the most recent step.
        self.learner_info = {}

    def run(self):
        """Call step() repeatedly until ``stopped`` is set."""
        # Switch on eager mode if configured.
        if self.local_worker.policy_config.get("framework") in ["tf2", "tfe"]:
            tf1.enable_eager_execution()
        while not self.stopped:
            self.step()

    def step(self):
        """Process one ``(ra, replay)`` item from ``inqueue``.

        For a non-``None`` replay batch, learns on it and puts
        ``(ra, prio_dict, replay.count)`` on ``outqueue``, where ``prio_dict``
        maps each policy id to ``(batch_indexes, td_error)``.
        """
        with self.overall_timer:
            with self.queue_timer:
                ra, replay = self.inqueue.get()
            if replay is not None:
                prio_dict = {}
                with self.grad_timer:
                    # Use LearnerInfoBuilder as a unified way to build the
                    # final results dict from the `learn_on_batch` call.
                    # This makes sure results dicts always have the same
                    # structure no matter the setup (multi-GPU, multi-agent,
                    # minibatch SGD, tf vs torch).
                    learner_info_builder = LearnerInfoBuilder(num_devices=1)
                    multi_agent_results = self.local_worker.learn_on_batch(
                        replay)
                    for pid, results in multi_agent_results.items():
                        learner_info_builder.add_learn_on_batch_results(
                            results, pid)
                        td_error = results["td_error"]
                        # Switch off auto-conversion from numpy to torch/tf
                        # tensors for the indices. This may lead to errors
                        # when sent to the buffer for processing
                        # (may get manipulated if they are part of a tensor).
                        replay.policy_batches[pid].set_get_interceptor(None)
                        prio_dict[pid] = (
                            replay.policy_batches[pid].get("batch_indexes"),
                            td_error)
                    self.learner_info = learner_info_builder.finalize()
                    self.grad_timer.push_units_processed(replay.count)
                self.outqueue.put((ra, prio_dict, replay.count))
            self.learner_queue_size.push(self.inqueue.qsize())
            self.weights_updated = True
        # Nothing was processed when the queue item carried no batch.
        self.overall_timer.push_units_processed(
            replay.count if replay is not None else 0)
class ReservoirReplayBuffer:
    """Replay buffer whose overwrite slot is chosen by a random draw.

    Single-timestep items are appended until `size` timesteps have been
    stored. From then on, the slot for each incoming item is drawn uniformly
    from `[0, add_calls]`; a draw outside the storage range means the next
    item is dropped instead of stored.
    """

    @DeveloperAPI
    def __init__(self, size: int):
        """Create a reservoir replay buffer.

        Args:
            size (int): Max number of timesteps to store in the reservoir
                buffer.
        """
        self._storage = []
        self._maxsize = size
        # Slot the next added item goes to; None means "drop the next item".
        self._next_idx = 0
        self._hit_count = np.zeros(size)
        self._eviction_started = False
        self._num_timesteps_added = 0
        self._num_timesteps_added_wrap = 0
        self._num_timesteps_sampled = 0
        self._evicted_hit_stats = WindowStat("evicted_hit", 1000)
        self._est_size_bytes = 0
        # Total number of add() calls, including dropped items.
        self._add_calls = 0

    def __len__(self):
        """Return the number of stored items."""
        return len(self._storage)

    def _sample_reservior_buffer_next_index(self):
        """Draw the slot for the next item, or None if it is to be dropped."""
        uniform_idx = np.random.randint(0, self._add_calls + 1)
        if uniform_idx < self._maxsize:
            self._next_idx = uniform_idx  # Reservoir Behavior
        else:
            self._next_idx = None

    @DeveloperAPI
    def add(self, item: SampleBatchType):
        """Offer a single-timestep item to the buffer.

        Args:
            item (SampleBatchType): Batch to store; its `count` must be 1.
        """
        # Validate before dividing by item.count below.
        assert item.count == 1, item
        warn_replay_buffer_size(
            item=item, num_items=self._maxsize / item.count)
        if self._next_idx is not None:
            self._num_timesteps_added += item.count
            if self._next_idx >= len(self._storage):
                self._storage.append(item)
                self._est_size_bytes += item.size_bytes()
            else:
                self._storage[self._next_idx] = item
            if self._num_timesteps_added >= self._maxsize:
                self._eviction_started = True
                self._sample_reservior_buffer_next_index()
            else:
                self._next_idx += 1
            # The draw above may have produced None (next item is dropped).
            # Indexing the numpy array with None would select the whole array
            # and reset every hit counter, so only record an eviction when a
            # concrete slot was chosen.
            if self._eviction_started and self._next_idx is not None:
                self._evicted_hit_stats.push(self._hit_count[self._next_idx])
                self._hit_count[self._next_idx] = 0
        else:
            # Dropping only happens once the buffer has been filled.
            assert self._add_calls >= self._maxsize
            self._sample_reservior_buffer_next_index()
        self._add_calls += 1

    def _encode_sample(self, idxes: List[int]) -> SampleBatchType:
        """Concatenate the stored items at `idxes` into one batch."""
        out = SampleBatch.concat_samples([self._storage[i] for i in idxes])
        out.decompress_if_needed()
        return out

    @DeveloperAPI
    def sample(self, num_items: int) -> SampleBatchType:
        """Sample a batch of experiences.

        Indices are drawn uniformly with replacement.

        Args:
            num_items (int): Number of items to sample from this buffer.

        Returns:
            SampleBatchType: concatenated batch of items.
        """
        idxes = [
            random.randint(0, len(self._storage) - 1)
            for _ in range(num_items)
        ]
        self._num_timesteps_sampled += num_items
        return self._encode_sample(idxes)

    @DeveloperAPI
    def stats(self, debug=False):
        """Return counters describing the buffer.

        Args:
            debug (bool): If True, also include the evicted-hit window stats.
        """
        data = {
            "added_count": self._num_timesteps_added,
            "sampled_count": self._num_timesteps_sampled,
            "est_size_bytes": self._est_size_bytes,
            "num_entries": len(self._storage),
        }
        if debug:
            data.update(self._evicted_hit_stats.stats())
        return data
class ReplayBuffer:
    """Circular buffer of sample batches with uniform random sampling."""

    @DeveloperAPI
    def __init__(self,
                 capacity: int = 10000,
                 size: Optional[int] = DEPRECATED_VALUE):
        """Initializes a Replaybuffer instance.

        Args:
            capacity (int): Max number of timesteps to store in the FIFO
                buffer. After reaching this number, older samples will be
                dropped to make space for new ones.
            size (Optional[int]): Deprecated alias for `capacity`. When given,
                `deprecation_warning` is called and the value overrides
                `capacity`.
        """
        # Deprecated args.
        if size != DEPRECATED_VALUE:
            deprecation_warning(
                "ReplayBuffer(size)", "ReplayBuffer(capacity)", error=False)
            capacity = size

        # The actual storage (list of SampleBatches).
        self._storage = []
        self.capacity = capacity
        # The next index to override in the buffer.
        self._next_idx = 0
        self._hit_count = np.zeros(self.capacity)

        # Whether we have already hit our capacity (and have therefore
        # started to evict older samples).
        self._eviction_started = False

        self._num_timesteps_added = 0
        self._num_timesteps_added_wrap = 0
        self._num_timesteps_sampled = 0
        self._evicted_hit_stats = WindowStat("evicted_hit", 1000)
        self._est_size_bytes = 0

    def __len__(self) -> int:
        """Return the number of stored items (batches, not timesteps)."""
        return len(self._storage)

    @DeveloperAPI
    def add(self, item: SampleBatchType, weight: float) -> None:
        """Store `item` at the current write position.

        Args:
            item (SampleBatchType): Batch to store; must have `count > 0`.
            weight (float): Not used by this class.
        """
        assert item.count > 0, item
        warn_replay_capacity(item=item, num_items=self.capacity / item.count)

        self._num_timesteps_added += item.count
        self._num_timesteps_added_wrap += item.count

        if self._next_idx >= len(self._storage):
            self._storage.append(item)
            self._est_size_bytes += item.size_bytes()
        else:
            self._storage[self._next_idx] = item

        # Wrap around storage as a circular buffer once we hit capacity.
        if self._num_timesteps_added_wrap >= self.capacity:
            self._eviction_started = True
            self._num_timesteps_added_wrap = 0
            self._next_idx = 0
        else:
            self._next_idx += 1

        if self._eviction_started:
            self._evicted_hit_stats.push(self._hit_count[self._next_idx])
            self._hit_count[self._next_idx] = 0

    def _encode_sample(self, idxes: List[int]) -> SampleBatchType:
        """Concatenate the stored items at `idxes` into one batch."""
        out = SampleBatch.concat_samples([self._storage[i] for i in idxes])
        out.decompress_if_needed()
        return out

    @DeveloperAPI
    def sample(self, num_items: int) -> SampleBatchType:
        """Sample a batch of experiences.

        Indices are drawn uniformly with replacement.

        Args:
            num_items (int): Number of items to sample from this buffer.

        Returns:
            SampleBatchType: concatenated batch of items.
        """
        idxes = [
            random.randint(0, len(self._storage) - 1)
            for _ in range(num_items)
        ]
        # This is the counter that __init__ creates and that stats(),
        # get_state() and set_state() read and write.
        self._num_timesteps_sampled += num_items
        return self._encode_sample(idxes)

    @DeveloperAPI
    def stats(self, debug=False) -> dict:
        """Return counters describing the buffer.

        Args:
            debug (bool): If True, also include the evicted-hit window stats.
        """
        data = {
            "added_count": self._num_timesteps_added,
            "added_count_wrapped": self._num_timesteps_added_wrap,
            "eviction_started": self._eviction_started,
            "sampled_count": self._num_timesteps_sampled,
            "est_size_bytes": self._est_size_bytes,
            "num_entries": len(self._storage),
        }
        if debug:
            data.update(self._evicted_hit_stats.stats())
        return data

    @DeveloperAPI
    def get_state(self) -> Dict[str, Any]:
        """Returns all local state.

        Returns:
            Dict[str, Any]: The serializable local state.
        """
        state = {"_storage": self._storage, "_next_idx": self._next_idx}
        state.update(self.stats(debug=False))
        return state

    @DeveloperAPI
    def set_state(self, state: Dict[str, Any]) -> None:
        """Restores all local state to the provided `state`.

        Args:
            state (Dict[str, Any]): The new state to set this buffer. Can be
                obtained by calling `self.get_state()`.
        """
        # The actual storage.
        self._storage = state["_storage"]
        self._next_idx = state["_next_idx"]
        # Stats and counts.
        self._num_timesteps_added = state["added_count"]
        self._num_timesteps_added_wrap = state["added_count_wrapped"]
        self._eviction_started = state["eviction_started"]
        self._num_timesteps_sampled = state["sampled_count"]
        self._est_size_bytes = state["est_size_bytes"]
class ReplayBuffer:
    """Circular buffer of transitions with uniform random sampling."""

    @DeveloperAPI
    def __init__(self, size):
        """Create a replay buffer.

        Parameters
        ----------
        size: int
          Max number of transitions to store in the buffer. When the buffer
          overflows the old memories are dropped.
        """
        self._storage = []
        self._maxsize = size
        self._next_idx = 0
        self._hit_count = np.zeros(size)
        self._eviction_started = False
        self._num_added = 0
        self._num_sampled = 0
        self._evicted_hit_stats = WindowStat("evicted_hit", 1000)
        self._est_size_bytes = 0

    def __len__(self):
        """Return the number of stored transitions."""
        return len(self._storage)

    @DeveloperAPI
    def add(self, obs_t, action, reward, obs_tp1, done, weight):
        """Store one transition at the current write position.

        `weight` is not used by this class.
        """
        data = (obs_t, action, reward, obs_tp1, done)
        self._num_added += 1

        if self._next_idx >= len(self._storage):
            self._storage.append(data)
            self._est_size_bytes += sum(sys.getsizeof(d) for d in data)
        else:
            self._storage[self._next_idx] = data
        if self._next_idx + 1 >= self._maxsize:
            self._eviction_started = True
        self._next_idx = (self._next_idx + 1) % self._maxsize
        if self._eviction_started:
            # Record how often the slot about to be overwritten was sampled.
            self._evicted_hit_stats.push(self._hit_count[self._next_idx])
            self._hit_count[self._next_idx] = 0

    def _encode_sample(self, idxes):
        """Gather the transitions at `idxes` into five parallel arrays."""
        obses_t, actions, rewards, obses_tp1, dones = [], [], [], [], []
        for i in idxes:
            obs_t, action, reward, obs_tp1, done = self._storage[i]
            # np.asarray avoids a copy when possible and copies otherwise.
            # np.array(..., copy=False) raises under NumPy 2 whenever a copy
            # cannot be avoided (e.g. for Python scalars or lists).
            obses_t.append(np.asarray(unpack_if_needed(obs_t)))
            actions.append(np.asarray(action))
            rewards.append(reward)
            obses_tp1.append(np.asarray(unpack_if_needed(obs_tp1)))
            dones.append(done)
            self._hit_count[i] += 1
        return (np.array(obses_t), np.array(actions), np.array(rewards),
                np.array(obses_tp1), np.array(dones))

    @DeveloperAPI
    def sample_idxes(self, batch_size):
        """Draw `batch_size` storage indices uniformly with replacement."""
        return np.random.randint(0, len(self._storage), batch_size)

    @DeveloperAPI
    def sample_with_idxes(self, idxes):
        """Return the encoded transitions stored at `idxes`."""
        self._num_sampled += len(idxes)
        return self._encode_sample(idxes)

    @DeveloperAPI
    def sample(self, batch_size):
        """Sample a batch of experiences.

        Parameters
        ----------
        batch_size: int
          How many transitions to sample.

        Returns
        -------
        obs_batch: np.array
          batch of observations
        act_batch: np.array
          batch of actions executed given obs_batch
        rew_batch: np.array
          rewards received as results of executing act_batch
        next_obs_batch: np.array
          next set of observations seen after executing act_batch
        done_mask: np.array
          done_mask[i] = 1 if executing act_batch[i] resulted in
          the end of an episode and 0 otherwise.
        """
        idxes = [
            random.randint(0, len(self._storage) - 1)
            for _ in range(batch_size)
        ]
        self._num_sampled += batch_size
        return self._encode_sample(idxes)

    @DeveloperAPI
    def stats(self, debug=False):
        """Return counters describing the buffer.

        With `debug=True` the evicted-hit window stats are included as well.
        """
        data = {
            "added_count": self._num_added,
            "sampled_count": self._num_sampled,
            "est_size_bytes": self._est_size_bytes,
            "num_entries": len(self._storage),
        }
        if debug:
            data.update(self._evicted_hit_stats.stats())
        return data
class LearnerThread(threading.Thread):
    """Daemon thread that trains the local worker on sampled trajectories.

    All communication with the main thread goes through Queues, because Ray
    operations can only be run on the main thread. Keeping the heavyweight
    gradient ops session runs off the main thread also improves overall
    throughput.
    """

    def __init__(self, local_worker, minibatch_buffer_size, num_sgd_iter,
                 learner_queue_size, learner_queue_timeout):
        """Create the learner thread (it is not started here).

        Arguments:
            local_worker (RolloutWorker): process local rollout worker holding
                policies this thread will call learn_on_batch() on
            minibatch_buffer_size (int): max number of train batches to store
                in the minibatching buffer
            num_sgd_iter (int): number of passes to learn on per train batch
            learner_queue_size (int): max size of queue of inbound
                train batches to this thread
            learner_queue_timeout (int): raise an exception if the queue has
                been empty for this long in seconds
        """
        super().__init__()
        self.daemon = True
        self.local_worker = local_worker

        # Inbound train batches and outbound (count, stats) results.
        self.inqueue = queue.Queue(maxsize=learner_queue_size)
        self.outqueue = queue.Queue()
        self.minibatch_buffer = MinibatchBuffer(
            inqueue=self.inqueue,
            size=minibatch_buffer_size,
            timeout=learner_queue_timeout,
            num_passes=num_sgd_iter,
            init_num_passes=num_sgd_iter)

        # Metrics.
        self.learner_queue_size = WindowStat("size", 50)
        self.queue_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.load_timer = TimerStat()
        self.load_wait_timer = TimerStat()
        self.stats = {}
        self.num_steps = 0

        # State flags.
        self.weights_updated = False
        self.stopped = False

    def run(self):
        """Keep stepping until ``stopped`` becomes true."""
        while True:
            if self.stopped:
                break
            self.step()

    def step(self):
        """Learn on the next minibatch and report ``(count, stats)``."""
        with self.queue_timer:
            train_batch, _ = self.minibatch_buffer.get()

        with self.grad_timer:
            learn_results = self.local_worker.learn_on_batch(train_batch)
            self.weights_updated = True
            self.stats = get_learner_stats(learn_results)

        self.num_steps += 1
        self.outqueue.put((train_batch.count, self.stats))
        pending = self.inqueue.qsize()
        self.learner_queue_size.push(pending)

    def add_learner_metrics(self, result):
        """Add internal metrics to a trainer result dict."""
        timers = (
            ("learner_grad_time_ms", self.grad_timer),
            ("learner_load_time_ms", self.load_timer),
            ("learner_load_wait_time_ms", self.load_wait_timer),
            ("learner_dequeue_time_ms", self.queue_timer),
        )
        # Timer means are converted to milliseconds, rounded to 3 places.
        timing_breakdown = {
            key: round(1000 * timer.mean, 3)
            for key, timer in timers
        }

        info = result["info"]
        info.update({
            "learner_queue": self.learner_queue_size.stats(),
            "learner": copy.deepcopy(self.stats),
            "timing_breakdown": timing_breakdown,
        })
        return result
class LearnerThread(threading.Thread):
    """Daemon thread that trains the local worker on sampled trajectories.

    All communication with the main thread goes through Queues, because Ray
    operations can only be run on the main thread. Keeping the heavyweight
    gradient ops session runs off the main thread also improves overall
    throughput.
    """

    def __init__(self, local_worker: RolloutWorker, minibatch_buffer_size: int,
                 num_sgd_iter: int, learner_queue_size: int,
                 learner_queue_timeout: int):
        """Create the learner thread (it is not started here).

        Args:
            local_worker (RolloutWorker): process local rollout worker holding
                policies this thread will call learn_on_batch() on
            minibatch_buffer_size (int): max number of train batches to store
                in the minibatching buffer
            num_sgd_iter (int): number of passes to learn on per train batch
            learner_queue_size (int): max size of queue of inbound
                train batches to this thread
            learner_queue_timeout (int): raise an exception if the queue has
                been empty for this long in seconds
        """
        super().__init__()
        self.daemon = True
        self.local_worker = local_worker

        # Inbound train batches and outbound (count, stats) results.
        self.inqueue = queue.Queue(maxsize=learner_queue_size)
        self.outqueue = queue.Queue()
        self.minibatch_buffer = MinibatchBuffer(
            inqueue=self.inqueue,
            size=minibatch_buffer_size,
            timeout=learner_queue_timeout,
            num_passes=num_sgd_iter,
            init_num_passes=num_sgd_iter)

        # Metrics.
        self.learner_queue_size = WindowStat("size", 50)
        self.queue_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.load_timer = TimerStat()
        self.load_wait_timer = TimerStat()
        self.learner_info = {}
        self.num_steps = 0

        # State flags.
        self.weights_updated = False
        self.stopped = False

    def run(self) -> None:
        """Keep stepping until ``stopped`` becomes true."""
        # Switch on eager mode if configured.
        framework = self.local_worker.policy_config.get("framework")
        if framework in ["tf2", "tfe"]:
            tf1.enable_eager_execution()
        while True:
            if self.stopped:
                break
            self.step()

    def step(self) -> Optional[_NextValueNotReady]:
        """Learn on the next minibatch and report ``(count, stats)``.

        Returns a ``_NextValueNotReady`` instance when fetching the minibatch
        raises ``queue.Empty``; otherwise returns None.
        """
        with self.queue_timer:
            try:
                train_batch, _ = self.minibatch_buffer.get()
            except queue.Empty:
                return _NextValueNotReady()

        with self.grad_timer:
            # Use LearnerInfoBuilder as a unified way to build the final
            # results dict from `learn_on_loaded_batch` call(s).
            # This makes sure results dicts always have the same structure
            # no matter the setup (multi-GPU, multi-agent, minibatch SGD,
            # tf vs torch).
            builder = LearnerInfoBuilder(num_devices=1)
            per_policy = self.local_worker.learn_on_batch(train_batch)
            for policy_id, policy_results in per_policy.items():
                builder.add_learn_on_batch_results(policy_results, policy_id)
            self.learner_info = builder.finalize()
            learner_stats = {}
            for policy_id, policy_info in self.learner_info.items():
                learner_stats[policy_id] = policy_info[LEARNER_STATS_KEY]
            self.weights_updated = True

        self.num_steps += 1
        self.outqueue.put((train_batch.count, learner_stats))
        pending = self.inqueue.qsize()
        self.learner_queue_size.push(pending)

    def add_learner_metrics(self, result: Dict) -> Dict:
        """Add internal metrics to a trainer result dict."""
        timers = (
            ("learner_grad_time_ms", self.grad_timer),
            ("learner_load_time_ms", self.load_timer),
            ("learner_load_wait_time_ms", self.load_wait_timer),
            ("learner_dequeue_time_ms", self.queue_timer),
        )
        # Timer means are converted to milliseconds, rounded to 3 places.
        timing_breakdown = {
            key: round(1000 * timer.mean, 3)
            for key, timer in timers
        }

        info = result["info"]
        info.update({
            "learner_queue": self.learner_queue_size.stats(),
            LEARNER_INFO: copy.deepcopy(self.learner_info),
            "timing_breakdown": timing_breakdown,
        })
        return result
class CustomValueReplayBuffer(object):
    """Fixed-capacity replay buffer holding custom keys/values in each entry.

    The normal rllib Replay Buffer is hard coded as to what it can hold; this
    buffer stores whatever keys `keys_to_types_dict` declares. Until `size`
    entries are stored, entries are appended. After that, a new entry
    replaces slot `idx` drawn uniformly from `[0, entries seen so far]` only
    if `idx < size`; otherwise the new entry is dropped.
    """

    def __init__(self, size, keys_to_types_dict=None, can_pack_list=None):
        """Create the buffer.

        Parameters
        ----------
        size: int
            Max number of entries to store in the buffer.
        keys_to_types_dict: dict or None
            Maps each stored key to a callable applied to the stored value
            when sampling. `np.array` entries are replaced by a no-copy
            array conversion. Defaults to obs/actions/rewards/new_obs/dones.
        can_pack_list: iterable or None
            Keys whose values go through `pack_if_needed` on add and
            `unpack_if_needed` on sample. Defaults to {"obs", "new_obs"}.
        """
        if keys_to_types_dict is not None:
            # Copy: the converter substitution below must not rewrite the
            # dict object owned by the caller.
            self.keys_to_types_dict = dict(keys_to_types_dict)
        else:
            # Default Values
            self.keys_to_types_dict = {
                "obs": np.array,
                "actions": np.array,
                "rewards": float,
                "new_obs": np.array,
                "dones": bool
            }
        for k in self.keys_to_types_dict:
            if self.keys_to_types_dict[k] == np.array:
                # np.asarray only copies when needed. np.array(copy=False)
                # raises on NumPy >= 2 whenever a copy cannot be avoided.
                self.keys_to_types_dict[k] = np.asarray
        self.expected_keys = sorted(self.keys_to_types_dict.keys())
        if can_pack_list is not None:
            self.can_pack_list = set(can_pack_list)
        else:
            self.can_pack_list = {"obs", "new_obs"}
        self._storage = []
        self._maxsize = size
        self._next_idx = 0
        self._hit_count = np.zeros(size)
        self._num_added = 0
        self._num_sampled = 0
        self._evicted_hit_stats = WindowStat("evicted_hit", 1000)
        self._est_size_bytes = 0

    def __len__(self):
        """Return the number of entries currently stored."""
        return len(self._storage)

    def add(self, **kwargs):
        """Add one entry; keyword names must equal the expected keys."""
        assert sorted(kwargs) == self.expected_keys, (
            sorted(kwargs), self.expected_keys)
        for k in kwargs:
            if k in self.can_pack_list:
                kwargs[k] = pack_if_needed(kwargs[k])
        # Entries are stored as lists ordered like `self.expected_keys`.
        data = [kwargs[k] for k in self.expected_keys]
        if len(self._storage) < self._maxsize:
            self._storage.append(data)
            self._est_size_bytes += sum(sys.getsizeof(d) for d in data)
        else:
            # `high` is exclusive, so idx is uniform over the number of
            # entries seen so far (including this one).
            idx = np.random.randint(0, self._num_added + 1)
            if idx < self._maxsize:
                # Keep the size estimate in step with what is stored.
                evicted = self._storage[idx]
                self._est_size_bytes -= sum(
                    sys.getsizeof(d) for d in evicted)
                self._est_size_bytes += sum(sys.getsizeof(d) for d in data)
                self._storage[idx] = data
                self._evicted_hit_stats.push(self._hit_count[idx])
                self._hit_count[idx] = 0
        self._num_added += 1

    def _encode_sample(self, idxes):
        """Build `{key: [converted value per index]}` and count the hits."""
        batch = {k: [] for k in self.expected_keys}
        for i in idxes:
            data = self._storage[i]
            for data_item, k in zip(data, self.expected_keys):
                if k in self.can_pack_list:
                    data_item = unpack_if_needed(data_item)
                batch[k].append(self.keys_to_types_dict[k](data_item))
            self._hit_count[i] += 1
        return batch

    def sample(self, batch_size):
        """Sample a batch of experiences uniformly with replacement.

        Parameters
        ----------
        batch_size: int
            How many transitions to sample.

        Returns
        -------
        batch in dictionary form
        """
        last = len(self._storage) - 1
        idxes = [random.randint(0, last) for _ in range(batch_size)]
        self._num_sampled += batch_size
        return self._encode_sample(idxes)

    def stats(self, debug=False):
        """Return usage counters; with `debug` also the evicted-hit stats."""
        data = {
            "added_count": self._num_added,
            "sampled_count": self._num_sampled,
            "est_size_bytes": self._est_size_bytes,
            "num_entries": len(self._storage),
        }
        if debug:
            data.update(self._evicted_hit_stats.stats())
        return data

    def clear(self):
        """Reset the buffer by re-running __init__ with the same settings."""
        self.__init__(size=self._maxsize,
                      keys_to_types_dict=self.keys_to_types_dict,
                      can_pack_list=self.can_pack_list)

    def get_single_epoch_batch_generator(self, batch_size):
        """Return a generator function yielding len(self)/batch_size samples.

        The number of batches is re-evaluated against the current buffer
        length on every iteration.
        """

        def gen():
            batch_idx = 0
            while batch_idx < int(len(self) / batch_size):
                sample = self.sample(batch_size)
                batch_idx += 1
                yield sample

        return gen
class SyncLearnerThread(threading.Thread):
    """Background thread that updates the local model from sample trajectories.

    PZH: We wish to mock the behavior of a normal PPO pipeline. This means
    allowing multiple SGD epochs and mini-batching within each SGD epoch.
    These requirements are not supported by the current APPO implementation.

    The learner thread communicates with the main thread through Queues. This
    is needed since Ray operations can only be run on the main thread. In
    addition, moving heavyweight gradient ops session runs off the main thread
    improves overall throughput.
    """

    def __init__(self, local_worker, minibatch_buffer_size, num_sgd_iter,
                 learner_queue_size, learner_queue_timeout, num_gpus,
                 sgd_batch_size):
        """Set up queues, timers and one parallel optimizer per policy.

        Args:
            local_worker: worker providing `tf_sess` and the trainable
                policies.
            minibatch_buffer_size: accepted but unused; the minibatch buffer
                is created with size 1.
            num_sgd_iter: number of SGD epochs run per loaded batch in
                step().
            learner_queue_size: max size of the inbound queue.
            learner_queue_timeout: timeout passed to the minibatch buffer.
            num_gpus: falsy selects "/cpu:0"; otherwise ceil(num_gpus) GPU
                devices are used.
            sgd_batch_size: SGD batch size, rounded down to a multiple of
                the number of devices.
        """
        threading.Thread.__init__(self)
        self.learner_queue_size = WindowStat("size", 50)
        self.local_worker = local_worker
        self.inqueue = queue.Queue(maxsize=learner_queue_size)
        self.outqueue = queue.Queue()
        # NOTE(review): confirm that MinibatchBuffer accepts the
        # `num_sgd_iter` keyword used here.
        self.minibatch_buffer = MinibatchBuffer(
            inqueue=self.inqueue,
            size=1,
            # size=minibatch_buffer_size,
            timeout=learner_queue_timeout,
            # num_sgd_iter=num_sgd_iter,
            num_sgd_iter=1,
            init_num_passes=1)
        self.queue_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.load_timer = TimerStat()
        self.load_wait_timer = TimerStat()
        self.daemon = True
        self.weights_updated = False
        self.stats = {"train_timesteps": 0}
        self.stopped = False
        self.num_steps = 0
        self.num_sgd_iter = num_sgd_iter

        # ===== Copied the initialization in multi_gpu_optimizer
        if not num_gpus:
            self.devices = ["/cpu:0"]
        else:
            self.devices = [
                "/gpu:{}".format(i) for i in range(int(math.ceil(num_gpus)))
            ]
        # Round the batch size down to a multiple of the device count.
        self.batch_size = int(sgd_batch_size / len(self.devices)) * len(
            self.devices)
        assert self.batch_size % len(self.devices) == 0
        assert self.batch_size >= len(self.devices), "batch size too small"
        self.per_device_batch_size = int(self.batch_size / len(self.devices))

        self.policies = dict(
            local_worker.foreach_trainable_policy(lambda p, i: (i, p)))

        self.optimizers = {}
        with local_worker.tf_sess.graph.as_default():
            with local_worker.tf_sess.as_default():
                for policy_id, policy in self.policies.items():
                    with tf.variable_scope(policy_id, reuse=tf.AUTO_REUSE):
                        if policy._state_inputs:
                            rnn_inputs = policy._state_inputs + [
                                policy._seq_lens
                            ]
                        else:
                            rnn_inputs = []
                        self.optimizers[policy_id] = (
                            LocalSyncParallelOptimizer(
                                policy._optimizer, self.devices,
                                [v for _, v in policy._loss_inputs],
                                rnn_inputs, self.per_device_batch_size,
                                policy.copy))

                self.sess = local_worker.tf_sess
                self.sess.run(tf.global_variables_initializer())

    def run(self):
        """Call step() repeatedly until `self.stopped` is set."""
        while not self.stopped:
            self.step()

    def step(self):
        """Load one train batch, run the SGD epochs on it, publish results.

        Puts the train batch's `count` on `outqueue`, followed by `None` when
        the minibatch buffer reports empty.
        """
        with self.queue_timer:
            batch, _ = self.minibatch_buffer.get()

        # Handle everything as if multiagent
        if isinstance(batch, SampleBatch):
            batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count)

        # TODO maybe we should do the normalization here
        num_loaded_tuples = {}
        with self.load_timer:
            # The loop variable is deliberately not called `batch`: rebinding
            # it would make `batch.count` below report the last policy's
            # batch instead of the whole train batch.
            for policy_id, policy_batch in batch.policy_batches.items():
                if policy_id not in self.policies:
                    continue

                policy = self.policies[policy_id]
                policy._debug_vars()
                tuples = policy._get_loss_inputs_dict(
                    policy_batch, shuffle=True)
                data_keys = [ph for _, ph in policy._loss_inputs]
                if policy._state_inputs:
                    state_keys = policy._state_inputs + [policy._seq_lens]
                else:
                    state_keys = []
                num_loaded_tuples[policy_id] = (
                    self.optimizers[policy_id].load_data(
                        self.sess, [tuples[k] for k in data_keys],
                        [tuples[k] for k in state_keys]))

        fetches = {}
        with self.grad_timer:
            for policy_id, tuples_per_device in num_loaded_tuples.items():
                optimizer = self.optimizers[policy_id]
                num_batches = max(
                    1,
                    int(tuples_per_device) // int(self.per_device_batch_size))
                logger.debug("== sgd epochs for {} ==".format(policy_id))
                for i in range(self.num_sgd_iter):
                    iter_extra_fetches = defaultdict(list)
                    # Visit the minibatches in a fresh random order per epoch.
                    permutation = np.random.permutation(num_batches)
                    for batch_index in range(num_batches):
                        batch_fetches = optimizer.optimize(
                            self.sess, permutation[batch_index] *
                            self.per_device_batch_size)
                        for k, v in batch_fetches[LEARNER_STATS_KEY].items():
                            iter_extra_fetches[k].append(v)
                    logger.debug("{} {}".format(i,
                                                _averaged(iter_extra_fetches)))
                # Only the last epoch's averaged stats are kept per policy.
                fetches[policy_id] = _averaged(iter_extra_fetches)

        # Not support multiagent recording now.
        self.stats.update(fetches["default_policy"])
        self.stats["train_timesteps"] += tuples_per_device
        self.num_steps += 1
        self.stats["update_steps"] = self.num_steps

        self.outqueue.put(batch.count)
        self.learner_queue_size.push(self.inqueue.qsize())
        self.weights_updated = True

        if self.minibatch_buffer.is_empty():
            # Send signal to optimizer
            self.outqueue.put(None)
class PrioritizedReplayBuffer(ReplayBuffer):
    """Replay buffer that samples transitions in proportion to a priority."""

    def __init__(self, size, alpha):
        """Create Prioritized Replay buffer.

        Parameters
        ----------
        size: int
            Max number of transitions to store in the buffer. When the buffer
            overflows the old memories are dropped.
        alpha: float
            how much prioritization is used
            (0 - no prioritization, 1 - full prioritization);
            must be strictly positive.

        See Also
        --------
        ReplayBuffer.__init__
        """
        super(PrioritizedReplayBuffer, self).__init__(size)
        assert alpha > 0
        self._alpha = alpha

        # The segment trees are sized to the smallest power of two >= size.
        it_capacity = 1
        while it_capacity < size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0
        self._prio_change_stats = WindowStat("reprio", 1000)

    def add(self, obs_t, action, reward, obs_tp1, done, weight):
        """Store a transition and set its priority to `weight ** alpha`.

        When `weight` is None, the largest priority seen so far is used.
        """
        # Capture the slot first: the parent add() advances `_next_idx`.
        idx = self._next_idx
        super(PrioritizedReplayBuffer, self).add(obs_t, action, reward,
                                                 obs_tp1, done, weight)
        if weight is None:
            weight = self._max_priority
        self._it_sum[idx] = weight**self._alpha
        self._it_min[idx] = weight**self._alpha

    def _sample_proportional(self, batch_size):
        """Draw `batch_size` indexes by prefix-sum lookup of random masses."""
        res = []
        for _ in range(batch_size):
            # TODO(szymon): should we ensure no repeats?
            mass = random.random() * self._it_sum.sum(0, len(self._storage))
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)
        return res

    def sample(self, batch_size, beta):
        """Sample a batch of experiences.

        compared to ReplayBuffer.sample
        it also returns importance weights and idxes
        of sampled experiences.

        Parameters
        ----------
        batch_size: int
            How many transitions to sample.
        beta: float
            To what degree to use importance weights
            (0 - no corrections, 1 - full correction)

        Returns
        -------
        obs_batch: np.array
            batch of observations
        act_batch: np.array
            batch of actions executed given obs_batch
        rew_batch: np.array
            rewards received as results of executing act_batch
        next_obs_batch: np.array
            next set of observations seen after executing act_batch
        done_mask: np.array
            done_mask[i] = 1 if executing act_batch[i] resulted in
            the end of an episode and 0 otherwise.
        weights: np.array
            one importance weight per sampled transition, normalized by the
            largest possible weight
        idxes: list
            indexes in buffer of sampled experiences
        """
        # beta == 0 is the documented "no corrections" setting, so it must
        # be accepted; it yields weights that are all 1.
        assert beta >= 0
        self._num_sampled += batch_size

        idxes = self._sample_proportional(batch_size)

        # Neither the total priority nor the buffer length changes while
        # the weights are computed.
        total = self._it_sum.sum()
        num_entries = len(self._storage)
        p_min = self._it_min.min() / total
        max_weight = (p_min * num_entries)**(-beta)

        weights = []
        for idx in idxes:
            p_sample = self._it_sum[idx] / total
            weight = (p_sample * num_entries)**(-beta)
            weights.append(weight / max_weight)
        weights = np.array(weights)
        encoded_sample = self._encode_sample(idxes)
        return tuple(list(encoded_sample) + [weights, idxes])

    def update_priorities(self, idxes, priorities):
        """Update priorities of sampled transitions.

        sets priority of transition at index idxes[i] in buffer
        to priorities[i].

        Parameters
        ----------
        idxes: [int]
            List of idxes of sampled transitions
        priorities: [float]
            List of updated priorities corresponding to
            transitions at the sampled idxes denoted by
            variable `idxes`.
        """
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            assert priority > 0
            assert 0 <= idx < len(self._storage)
            scaled = priority**self._alpha
            # Record the change relative to the value stored before the
            # update.
            delta = scaled - self._it_sum[idx]
            self._prio_change_stats.push(delta)
            self._it_sum[idx] = scaled
            self._it_min[idx] = scaled

            self._max_priority = max(self._max_priority, priority)

    def stats(self, debug=False):
        """Return parent stats; with `debug` also priority-change stats."""
        parent = ReplayBuffer.stats(self, debug)
        if debug:
            parent.update(self._prio_change_stats.stats())
        return parent
class ReplayBuffer(object):
    """FIFO buffer of (obs, action, reward, next_obs, done) transitions."""

    def __init__(self, size):
        """Create Replay buffer.

        Parameters
        ----------
        size: int
          Max number of transitions to store in the buffer. When the buffer
          overflows the old memories are dropped.
        """
        self._storage = []
        self._maxsize = size
        self._next_idx = 0
        self._hit_count = np.zeros(size)
        self._eviction_started = False
        self._num_added = 0
        self._num_sampled = 0
        self._evicted_hit_stats = WindowStat("evicted_hit", 1000)
        self._est_size_bytes = 0

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self._storage)

    def add(self, obs_t, action, reward, obs_tp1, done, weight):
        """Store one transition, overwriting the oldest slot once full.

        `weight` is accepted for interface compatibility and is not used
        here.
        """
        data = (obs_t, action, reward, obs_tp1, done)
        self._num_added += 1

        if self._next_idx >= len(self._storage):
            self._storage.append(data)
            self._est_size_bytes += sum(sys.getsizeof(d) for d in data)
        else:
            # Keep the size estimate in step with what is actually stored.
            evicted = self._storage[self._next_idx]
            self._est_size_bytes -= sum(sys.getsizeof(d) for d in evicted)
            self._est_size_bytes += sum(sys.getsizeof(d) for d in data)
            self._storage[self._next_idx] = data
        if self._next_idx + 1 >= self._maxsize:
            self._eviction_started = True
        self._next_idx = (self._next_idx + 1) % self._maxsize
        if self._eviction_started:
            # The slot written next is about to be evicted: record how often
            # it was sampled and reset its counter.
            self._evicted_hit_stats.push(self._hit_count[self._next_idx])
            self._hit_count[self._next_idx] = 0

    def _encode_sample(self, idxes):
        """Stack the transitions at `idxes` into five arrays; count hits."""
        obses_t, actions, rewards, obses_tp1, dones = [], [], [], [], []
        for i in idxes:
            obs_t, action, reward, obs_tp1, done = self._storage[i]
            # np.asarray only copies when needed. np.array(copy=False)
            # raises on NumPy >= 2 whenever a copy cannot be avoided.
            obses_t.append(np.asarray(unpack_if_needed(obs_t)))
            actions.append(np.asarray(action))
            rewards.append(reward)
            obses_tp1.append(np.asarray(unpack_if_needed(obs_tp1)))
            dones.append(done)
            self._hit_count[i] += 1
        return (np.array(obses_t), np.array(actions), np.array(rewards),
                np.array(obses_tp1), np.array(dones))

    def sample(self, batch_size):
        """Sample a batch of experiences uniformly with replacement.

        Parameters
        ----------
        batch_size: int
          How many transitions to sample.

        Returns
        -------
        obs_batch: np.array
          batch of observations
        act_batch: np.array
          batch of actions executed given obs_batch
        rew_batch: np.array
          rewards received as results of executing act_batch
        next_obs_batch: np.array
          next set of observations seen after executing act_batch
        done_mask: np.array
          done_mask[i] = 1 if executing act_batch[i] resulted in
          the end of an episode and 0 otherwise.
        """
        last = len(self._storage) - 1
        idxes = [random.randint(0, last) for _ in range(batch_size)]
        self._num_sampled += batch_size
        return self._encode_sample(idxes)

    def stats(self, debug=False):
        """Return usage counters; with `debug` also the evicted-hit stats."""
        data = {
            "added_count": self._num_added,
            "sampled_count": self._num_sampled,
            "est_size_bytes": self._est_size_bytes,
            "num_entries": len(self._storage),
        }
        if debug:
            data.update(self._evicted_hit_stats.stats())
        return data
class ListReplayBuffer:
    """Replay buffer as a list of tuples.

    Returns a SampleBatch object when queried for samples.

    Args:
        size: max number of transitions to store in the buffer. When the
            buffer overflows, the old memories are dropped.

    Attributes:
        fields (:obj:`tuple` of :obj:`ReplayField`): storage fields
            specification
    """

    # pylint:disable=too-many-instance-attributes

    def __init__(self, size: int):
        self._storage = []
        self._maxsize = size
        self._next_idx = 0
        self._hit_count = np.zeros(size)
        self._eviction_started = False
        self._num_added = 0
        self._num_sampled = 0
        self._evicted_hit_stats = WindowStat("evicted_hit", 1000)
        self._est_size_bytes = 0
        # The first five fields are fixed; _encode_sample relies on their
        # order. More can be appended through add_fields().
        self.fields = (
            ReplayField(SampleBatch.CUR_OBS),
            ReplayField(SampleBatch.ACTIONS),
            ReplayField(SampleBatch.REWARDS),
            ReplayField(SampleBatch.NEXT_OBS),
            ReplayField(SampleBatch.DONES),
        )

    def __len__(self):
        """Return the number of rows currently stored."""
        return len(self._storage)

    def add_fields(self, *fields: ReplayField):
        """Append extra fields to the buffer's field specification.

        Field names must be unique and must not clash with existing ones.
        """
        new_names = {f.name for f in fields}
        assert len(new_names) == len(fields), "Field names must be unique"

        conflicts = new_names.intersection({f.name for f in self.fields})
        assert not conflicts, f"{conflicts} are already in buffer"

        self.fields = self.fields + fields

    def add(self, row: dict):  # pylint:disable=arguments-differ
        """Add a row from a SampleBatch to storage.

        Args:
            row: sample batch row as returned by SampleBatch.rows
        """
        data = tuple(row[f.name] for f in self.fields)
        self._num_added += 1

        if self._next_idx >= len(self._storage):
            self._storage.append(data)
            self._est_size_bytes += sum(sys.getsizeof(d) for d in data)
        else:
            # Keep the size estimate in step with what is actually stored.
            evicted = self._storage[self._next_idx]
            self._est_size_bytes -= sum(sys.getsizeof(d) for d in evicted)
            self._est_size_bytes += sum(sys.getsizeof(d) for d in data)
            self._storage[self._next_idx] = data
        if self._next_idx + 1 >= self._maxsize:
            self._eviction_started = True
        self._next_idx = (self._next_idx + 1) % self._maxsize
        if self._eviction_started:
            # The slot written next is about to be evicted: record how often
            # it was sampled and reset its counter.
            self._evicted_hit_stats.push(self._hit_count[self._next_idx])
            self._hit_count[self._next_idx] = 0

    def _encode_sample(self, idxes):
        """Return one array per field for the rows at `idxes`; count hits."""
        sample = []
        for i in idxes:
            sample.append(self._storage[i])
            self._hit_count[i] += 1

        # Transpose rows into per-field columns; fields added through
        # add_fields() end up in `extras`.
        obses_t, actions, rewards, obses_tp1, dones, *extras = zip(*sample)

        # np.asarray only copies when needed. np.array(copy=False) raises on
        # NumPy >= 2 whenever a copy cannot be avoided.
        obses_t = [np.asarray(unpack_if_needed(o)) for o in obses_t]
        actions = [np.asarray(a) for a in actions]
        obses_tp1 = [np.asarray(unpack_if_needed(o)) for o in obses_tp1]

        return tuple(
            map(np.array,
                [obses_t, actions, rewards, obses_tp1, dones] + extras))

    def sample(self, batch_size: int) -> SampleBatch:
        """Sample a batch of experiences.

        Args:
            batch_size: How many transitions to sample

        Returns:
            A sample batch of roughly decorrelated transitions
        """
        idxes = random.choices(range(len(self._storage)), k=batch_size)
        return self.sample_with_idxes(idxes)

    def sample_with_idxes(self, idxes: np.ndarray) -> SampleBatch:
        """Sample a batch of experiences corresponding to the given indexes."""
        self._num_sampled += len(idxes)
        data = self._encode_sample(idxes)
        return SampleBatch(dict(zip([f.name for f in self.fields], data)))

    def all_samples(self) -> SampleBatch:
        """All transitions stored in buffer."""
        return self.sample_with_idxes(range(len(self)))

    def stats(self, debug=False):
        """Returns a dictionary of usage statistics."""
        data = {
            "added_count": self._num_added,
            "sampled_count": self._num_sampled,
            "est_size_bytes": self._est_size_bytes,
            "num_entries": len(self._storage),
        }
        if debug:
            data.update(self._evicted_hit_stats.stats())
        return data
class PrioritizedReplayBuffer(ReplayBuffer):
    """Replay buffer that samples transitions in proportion to a priority."""

    @DeveloperAPI
    def __init__(self, size, alpha):
        """Create Prioritized Replay buffer.

        Parameters
        ----------
        size: int
          Max number of transitions to store in the buffer. When the buffer
          overflows the old memories are dropped.
        alpha: float
          how much prioritization is used
          (0 - no prioritization, 1 - full prioritization);
          must be strictly positive.

        See Also
        --------
        ReplayBuffer.__init__
        """
        super(PrioritizedReplayBuffer, self).__init__(size)
        assert alpha > 0
        self._alpha = alpha

        # The segment trees are sized to the smallest power of two >= size.
        it_capacity = 1
        while it_capacity < size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0
        self._prio_change_stats = WindowStat("reprio", 1000)

    @DeveloperAPI
    def add(self, obs_t, action, reward, obs_tp1, done, weight):
        """Store a transition and set its priority to `weight ** alpha`.

        When `weight` is None, the largest priority seen so far is used.
        """
        # Capture the slot first: the parent add() advances `_next_idx`.
        idx = self._next_idx
        super(PrioritizedReplayBuffer, self).add(obs_t, action, reward,
                                                 obs_tp1, done, weight)
        if weight is None:
            weight = self._max_priority
        self._it_sum[idx] = weight**self._alpha
        self._it_min[idx] = weight**self._alpha

    def _sample_proportional(self, batch_size):
        """Draw `batch_size` indexes by prefix-sum lookup of random masses."""
        res = []
        for _ in range(batch_size):
            # TODO(szymon): should we ensure no repeats?
            mass = random.random() * self._it_sum.sum(0, len(self._storage))
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)
        return res

    @DeveloperAPI
    def sample_idxes(self, batch_size):
        """Return `batch_size` buffer indexes drawn by priority."""
        return self._sample_proportional(batch_size)

    @DeveloperAPI
    def sample_with_idxes(self, idxes, beta):
        """Encode the transitions at `idxes` with importance weights.

        Parameters
        ----------
        idxes: [int]
          Buffer indexes to encode, e.g. from `sample_idxes`.
        beta: float
          To what degree to use importance weights
          (0 - no corrections, 1 - full correction)

        Returns
        -------
        The parent's encoded sample followed by `weights` (np.array, one
        per index, normalized by the largest possible weight) and `idxes`.
        """
        # Same bound as sample(): beta == 0 is the documented
        # "no corrections" setting and yields weights that are all 1.
        assert beta >= 0.0
        self._num_sampled += len(idxes)

        # Neither the total priority nor the buffer length changes while
        # the weights are computed.
        total = self._it_sum.sum()
        num_entries = len(self._storage)
        p_min = self._it_min.min() / total
        max_weight = (p_min * num_entries)**(-beta)

        weights = []
        for idx in idxes:
            p_sample = self._it_sum[idx] / total
            weight = (p_sample * num_entries)**(-beta)
            weights.append(weight / max_weight)
        weights = np.array(weights)
        encoded_sample = self._encode_sample(idxes)
        return tuple(list(encoded_sample) + [weights, idxes])

    @DeveloperAPI
    def sample(self, batch_size, beta):
        """Sample a batch of experiences.

        compared to ReplayBuffer.sample
        it also returns importance weights and idxes
        of sampled experiences.

        Parameters
        ----------
        batch_size: int
          How many transitions to sample.
        beta: float
          To what degree to use importance weights
          (0 - no corrections, 1 - full correction)

        Returns
        -------
        obs_batch: np.array
          batch of observations
        act_batch: np.array
          batch of actions executed given obs_batch
        rew_batch: np.array
          rewards received as results of executing act_batch
        next_obs_batch: np.array
          next set of observations seen after executing act_batch
        done_mask: np.array
          done_mask[i] = 1 if executing act_batch[i] resulted in
          the end of an episode and 0 otherwise.
        weights: np.array
          one importance weight per sampled transition, normalized by the
          largest possible weight
        idxes: list
          indexes in buffer of sampled experiences
        """
        # Validate before drawing so a bad beta does not consume random
        # numbers.
        assert beta >= 0.0
        idxes = self._sample_proportional(batch_size)
        # Weighting, encoding and the sampled-count update live in one
        # place.
        return self.sample_with_idxes(idxes, beta)

    @DeveloperAPI
    def update_priorities(self, idxes, priorities):
        """Update priorities of sampled transitions.

        sets priority of transition at index idxes[i] in buffer
        to priorities[i].

        Parameters
        ----------
        idxes: [int]
          List of idxes of sampled transitions
        priorities: [float]
          List of updated priorities corresponding to
          transitions at the sampled idxes denoted by
          variable `idxes`.
        """
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            assert priority > 0
            assert 0 <= idx < len(self._storage)
            scaled = priority**self._alpha
            # Record the change relative to the value stored before the
            # update.
            delta = scaled - self._it_sum[idx]
            self._prio_change_stats.push(delta)
            self._it_sum[idx] = scaled
            self._it_min[idx] = scaled

            self._max_priority = max(self._max_priority, priority)

    @DeveloperAPI
    def stats(self, debug=False):
        """Return parent stats; with `debug` also priority-change stats."""
        parent = ReplayBuffer.stats(self, debug)
        if debug:
            parent.update(self._prio_change_stats.stats())
        return parent
class ReplayBuffer:
    """FIFO buffer of sample batches with uniform random sampling."""

    @DeveloperAPI
    def __init__(self, size: int):
        """Create Replay buffer.

        Args:
            size (int): Max number of items to store in the FIFO buffer.
        """
        self._storage = []
        self._maxsize = size
        self._next_idx = 0
        self._hit_count = np.zeros(size)
        self._eviction_started = False
        self._num_added = 0
        self._num_sampled = 0
        self._evicted_hit_stats = WindowStat("evicted_hit", 1000)
        self._est_size_bytes = 0

    def __len__(self):
        """Return the number of items currently stored."""
        return len(self._storage)

    @DeveloperAPI
    def add(self, item: SampleBatchType, weight: float):
        """Store `item`, overwriting the oldest slot once the buffer is full.

        Args:
            item (SampleBatchType): Non-empty batch to store.
            weight (float): Accepted for interface compatibility; not used
                by this class.
        """
        assert item.count > 0, item
        self._num_added += 1

        if self._next_idx >= len(self._storage):
            self._storage.append(item)
            self._est_size_bytes += item.size_bytes()
        else:
            # Account for the item being replaced; otherwise the estimate
            # goes stale as soon as slots start being overwritten.
            evicted = self._storage[self._next_idx]
            self._est_size_bytes += item.size_bytes() - evicted.size_bytes()
            self._storage[self._next_idx] = item
        if self._next_idx + 1 >= self._maxsize:
            self._eviction_started = True
        self._next_idx = (self._next_idx + 1) % self._maxsize
        if self._eviction_started:
            # The slot written next is about to be evicted: record its hit
            # count and reset it.
            self._evicted_hit_stats.push(self._hit_count[self._next_idx])
            self._hit_count[self._next_idx] = 0

    def _encode_sample(self, idxes: List[int]) -> SampleBatchType:
        """Concatenate the stored items at `idxes` into one batch."""
        out = SampleBatch.concat_samples([self._storage[i] for i in idxes])
        out.decompress_if_needed()
        return out

    @DeveloperAPI
    def sample(self, num_items: int) -> SampleBatchType:
        """Sample a batch of experiences uniformly with replacement.

        Args:
            num_items (int): Number of items to sample from this buffer.

        Returns:
            SampleBatchType: concatenated batch of items.
        """
        last = len(self._storage) - 1
        idxes = [random.randint(0, last) for _ in range(num_items)]
        self._num_sampled += num_items
        return self._encode_sample(idxes)

    @DeveloperAPI
    def stats(self, debug=False):
        """Return usage counters; with `debug` also the evicted-hit stats."""
        data = {
            "added_count": self._num_added,
            "sampled_count": self._num_sampled,
            "est_size_bytes": self._est_size_bytes,
            "num_entries": len(self._storage),
        }
        if debug:
            data.update(self._evicted_hit_stats.stats())
        return data
class PrioritizedReplayBuffer(ReplayBuffer):
    """Replay buffer that samples stored items in proportion to a priority."""

    @DeveloperAPI
    def __init__(self, size: int, alpha: float):
        """Create Prioritized Replay buffer.

        Args:
            size (int): Max number of items to store in the FIFO buffer.
            alpha (float): how much prioritization is used
                (0 - no prioritization, 1 - full prioritization).

        See also: ReplayBuffer.__init__()
        """
        super().__init__(size)
        assert alpha > 0
        self._alpha = alpha

        # Double until the tree capacity is a power of two >= size.
        tree_capacity = 1
        while tree_capacity < size:
            tree_capacity = tree_capacity * 2

        self._it_sum = SumSegmentTree(tree_capacity)
        self._it_min = MinSegmentTree(tree_capacity)
        self._max_priority = 1.0
        self._prio_change_stats = WindowStat("reprio", 1000)

    @DeveloperAPI
    def add(self, item: SampleBatchType, weight: float):
        """Store `item` and set its priority to `weight ** alpha`.

        A `weight` of None falls back to the largest priority seen so far.
        """
        # Remember the target slot; the parent add() moves `_next_idx` on.
        slot = self._next_idx
        super().add(item, weight)
        if weight is None:
            weight = self._max_priority
        self._it_sum[slot] = weight**self._alpha
        self._it_min[slot] = weight**self._alpha

    def _sample_proportional(self, num_items: int):
        """Draw `num_items` indexes by prefix-sum lookup of random masses."""
        picked = []
        for _ in range(num_items):
            # TODO(szymon): should we ensure no repeats?
            mass = random.random() * self._it_sum.sum(0, len(self._storage))
            picked.append(self._it_sum.find_prefixsum_idx(mass))
        return picked

    @DeveloperAPI
    def sample(self, num_items: int, beta: float) -> SampleBatchType:
        """Sample a batch of experiences and return priority weights, indices.

        Args:
            num_items (int): Number of items to sample from this buffer.
            beta (float): To what degree to use importance weights
                (0 - no corrections, 1 - full correction).

        Returns:
            SampleBatchType: Concatenated batch of items including "weights"
                and "batch_indexes" fields denoting IS of each sampled
                transition and original idxes in buffer of sampled experiences.
        """
        assert beta >= 0.0
        self._num_sampled += num_items

        idxes = self._sample_proportional(num_items)

        num_entries = len(self._storage)
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * num_entries)**(-beta)
        # Nothing below modifies the tree, so the total is read once.
        total = self._it_sum.sum()

        weights = []
        batch_indexes = []
        for idx in idxes:
            p_sample = self._it_sum[idx] / total
            weight = (p_sample * num_entries)**(-beta)
            # Every row of a stored item shares that item's weight and index.
            repeat = self._storage[idx].count
            weights.extend([weight / max_weight] * repeat)
            batch_indexes.extend([idx] * repeat)

        batch = self._encode_sample(idxes)

        # Note: prioritization is not supported in lockstep replay mode.
        if not isinstance(batch, SampleBatch):
            return batch

        assert len(weights) == batch.count
        assert len(batch_indexes) == batch.count
        batch["weights"] = np.array(weights)
        batch["batch_indexes"] = np.array(batch_indexes)
        return batch

    @DeveloperAPI
    def update_priorities(self, idxes, priorities):
        """Update priorities of sampled transitions.

        Sets the priority of the item at index idxes[i] in the buffer to
        priorities[i].

        Args:
            idxes ([int]): List of idxes of sampled transitions.
            priorities ([float]): List of updated priorities corresponding
                to transitions at the sampled idxes denoted by `idxes`.
        """
        assert len(idxes) == len(priorities)
        for position, idx in enumerate(idxes):
            priority = priorities[position]
            assert priority > 0
            assert 0 <= idx < len(self._storage)
            # Push the change relative to the stored value before
            # overwriting.
            delta = priority**self._alpha - self._it_sum[idx]
            self._prio_change_stats.push(delta)
            self._it_sum[idx] = priority**self._alpha
            self._it_min[idx] = priority**self._alpha

            self._max_priority = max(self._max_priority, priority)

    @DeveloperAPI
    def stats(self, debug=False):
        """Return parent stats; with `debug` also priority-change stats."""
        merged = ReplayBuffer.stats(self, debug)
        if not debug:
            return merged
        merged.update(self._prio_change_stats.stats())
        return merged
class PrioritizedReplayBuffer(ReplayBuffer):
    """Replay buffer that samples stored items in proportion to a priority."""

    @DeveloperAPI
    def __init__(self,
                 capacity: int = 10000,
                 alpha: float = 1.0,
                 size: Optional[int] = DEPRECATED_VALUE):
        """Initializes a PrioritizedReplayBuffer instance.

        Args:
            capacity (int): Max number of timesteps to store in the FIFO
                buffer. After reaching this number, older samples will be
                dropped to make space for new ones.
            alpha (float): How much prioritization is used
                (0.0=no prioritization, 1.0=full prioritization).
            size (Optional[int]): Forwarded unchanged to the parent
                constructor.
        """
        super().__init__(capacity, size)
        assert alpha > 0
        self._alpha = alpha

        # Double until the tree capacity is a power of two >= self.capacity.
        tree_capacity = 1
        while tree_capacity < self.capacity:
            tree_capacity = tree_capacity * 2

        self._it_sum = SumSegmentTree(tree_capacity)
        self._it_min = MinSegmentTree(tree_capacity)
        self._max_priority = 1.0
        self._prio_change_stats = WindowStat("reprio", 1000)

    @DeveloperAPI
    @override(ReplayBuffer)
    def add(self, item: SampleBatchType, weight: float) -> None:
        """Store `item` and set its priority to `weight ** alpha`.

        A `weight` of None falls back to the largest priority seen so far.
        """
        # Remember the target slot; the parent add() moves `_next_idx` on.
        slot = self._next_idx
        super().add(item, weight)
        if weight is None:
            weight = self._max_priority
        self._it_sum[slot] = weight**self._alpha
        self._it_min[slot] = weight**self._alpha

    def _sample_proportional(self, num_items: int) -> List[int]:
        """Draw `num_items` indexes by prefix-sum lookup of random masses."""
        picked = []
        for _ in range(num_items):
            # TODO(szymon): should we ensure no repeats?
            mass = random.random() * self._it_sum.sum(0, len(self._storage))
            picked.append(self._it_sum.find_prefixsum_idx(mass))
        return picked

    @DeveloperAPI
    @override(ReplayBuffer)
    def sample(self, num_items: int, beta: float) -> SampleBatchType:
        """Sample a batch of experiences and return priority weights, indices.

        Args:
            num_items (int): Number of items to sample from this buffer.
            beta (float): To what degree to use importance weights
                (0 - no corrections, 1 - full correction).

        Returns:
            SampleBatchType: Concatenated batch of items including "weights"
                and "batch_indexes" fields denoting IS of each sampled
                transition and original idxes in buffer of sampled experiences.
        """
        assert beta >= 0.0

        idxes = self._sample_proportional(num_items)

        num_entries = len(self._storage)
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * num_entries)**(-beta)
        # Nothing below modifies the tree, so the total is read once.
        total = self._it_sum.sum()

        weights = []
        batch_indexes = []
        for idx in idxes:
            p_sample = self._it_sum[idx] / total
            weight = (p_sample * num_entries)**(-beta)
            stored = self._storage[idx]
            count = stored.count
            # If zero-padded, count will not be the actual batch size of the
            # data.
            actual_size = count
            if isinstance(stored, SampleBatch) and stored.zero_padded:
                actual_size = stored.max_seq_len
            weights.extend([weight / max_weight] * actual_size)
            batch_indexes.extend([idx] * actual_size)
            self._num_timesteps_sampled += count

        batch = self._encode_sample(idxes)

        # Note: prioritization is not supported in lockstep replay mode.
        if not isinstance(batch, SampleBatch):
            return batch

        batch["weights"] = np.array(weights)
        batch["batch_indexes"] = np.array(batch_indexes)
        return batch

    @DeveloperAPI
    def update_priorities(self, idxes: List[int],
                          priorities: List[float]) -> None:
        """Update priorities of sampled transitions.

        Sets the priority of the item at index idxes[i] in the buffer to
        priorities[i].

        Args:
            idxes (List[int]): List of idxes of sampled transitions.
            priorities (List[float]): List of updated priorities
                corresponding to transitions at the sampled idxes denoted by
                `idxes`.
        """
        # Making sure we don't pass in e.g. a torch tensor.
        assert isinstance(idxes, (list, np.ndarray)), \
            "ERROR: `idxes` is not a list or np.ndarray, but " \
            "{}!".format(type(idxes).__name__)
        assert len(idxes) == len(priorities)

        for position, idx in enumerate(idxes):
            priority = priorities[position]
            assert priority > 0
            assert 0 <= idx < len(self._storage)
            # Push the change relative to the stored value before
            # overwriting.
            delta = priority**self._alpha - self._it_sum[idx]
            self._prio_change_stats.push(delta)
            self._it_sum[idx] = priority**self._alpha
            self._it_min[idx] = priority**self._alpha

            self._max_priority = max(self._max_priority, priority)

    @DeveloperAPI
    @override(ReplayBuffer)
    def stats(self, debug: bool = False) -> Dict:
        """Return parent stats; with `debug` also priority-change stats."""
        merged = ReplayBuffer.stats(self, debug)
        if not debug:
            return merged
        merged.update(self._prio_change_stats.stats())
        return merged

    @DeveloperAPI
    @override(ReplayBuffer)
    def get_state(self) -> Dict[str, Any]:
        """Returns all local state.

        Returns:
            Dict[str, Any]: The serializable local state.
        """
        # Start from the parent state, then add the prio weights.
        state = super().get_state()
        prio_state = {
            "sum_segment_tree": self._it_sum.get_state(),
            "min_segment_tree": self._it_min.get_state(),
            "max_priority": self._max_priority,
        }
        state.update(prio_state)
        return state

    @DeveloperAPI
    @override(ReplayBuffer)
    def set_state(self, state: Dict[str, Any]) -> None:
        """Restores all local state to the provided `state`.

        Args:
            state (Dict[str, Any]): The new state to set this buffer. Can be
                obtained by calling `self.get_state()`.
        """
        super().set_state(state)
        sum_tree_state = state["sum_segment_tree"]
        self._it_sum.set_state(sum_tree_state)
        min_tree_state = state["min_segment_tree"]
        self._it_min.set_state(min_tree_state)
        self._max_priority = state["max_priority"]