def collect_episode(self):
    """Gather fragments from all in-progress episodes.

    Returns:
        EpisodeBatch: A batch of the episode fragments.

    """
    for i, frag in enumerate(self._fragments):
        assert frag.env is self._envs[i]
        if len(frag.rewards) > 0:
            complete_frag = frag.to_batch()
            self._complete_fragments.append(complete_frag)
            self._fragments[i] = InProgressEpisode(frag.env, frag.last_obs,
                                                   frag.episode_info)
    assert len(self._complete_fragments) > 0
    result = EpisodeBatch.concatenate(*self._complete_fragments)
    self._complete_fragments = []
    return result
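
A self-contained toy of the flush pattern above, to make the bookkeeping concrete: fragments that have accumulated at least one reward are converted and drained, while their slot is immediately replaced with a fresh in-progress episode. The names and data below are illustrative only, not part of the codebase.

# Toy sketch of the flush: per-env reward lists stand in for
# InProgressEpisode fragments, and sum() stands in for to_batch().
fragments = [[1.0, 0.5], [], [2.0]]  # one reward list per env
complete = []
for i, frag in enumerate(fragments):
    if len(frag) > 0:
        complete.append(sum(frag))   # "to_batch" the finished steps
        fragments[i] = []            # fresh fragment for the same env slot
print(complete, fragments)           # [1.5, 2.0] [[], [], []]
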
def obtain_samples(self, num_samples, agent_update, env_updates=None):
    """Sample the policy for new episodes.

    Args:
        num_samples (int): Number of steps the sampler should collect.
        agent_update (object): Value which will be passed into the
            `agent_update_fn` before sampling episodes. If a list is
            passed in, it must have length exactly `factory.n_workers`,
            and will be spread across the workers.
        env_updates (object): Value which will be passed into the
            `env_update_fn` before sampling episodes. If a list is
            passed in, it must have length exactly `factory.n_workers`,
            and will be spread across the workers.

    Returns:
        EpisodeBatch: Batch of gathered episodes.

    """
    self.update_workers(agent_update, env_updates)
    completed_samples = 0
    batches = []
    # TODO: Can we replace the while loop, so that all rollouts are
    # scheduled up front?
    while completed_samples < num_samples:
        pids = [w.rollout.remote() for w in self.workers]
        results = [ray.get(pid) for pid in pids]
        for episode_batch in results:
            num_returned_samples = episode_batch.lengths.sum()
            completed_samples += num_returned_samples
            batches.append(episode_batch)
    # Note: EpisodeBatch takes care of the concatenation - is this a
    # performance issue?
    samples = EpisodeBatch.concatenate(*batches)
    self.total_env_steps += sum(samples.lengths)
    return samples
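
A hypothetical usage sketch of this synchronous variant; `make_sampler` and `policy_params` are stand-ins for however the enclosing sampler and agent weights are constructed in the surrounding codebase.

import ray

ray.init(ignore_reinit_error=True)
sampler = make_sampler()  # assumed factory for the enclosing class
batch = sampler.obtain_samples(num_samples=4000,
                               agent_update=policy_params)
# Each pass through the while loop waits for one episode from every
# worker, so the batch can overshoot num_samples by whole episodes.
assert sum(batch.lengths) >= 4000
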
def obtain_exact_episodes(self, n_eps_per_worker, agent_update,
                          env_update=None):
    """Sample an exact number of episodes per worker.

    Args:
        n_eps_per_worker (int): Exact number of episodes to gather for
            each worker.
        agent_update (object): Value which will be passed into the
            `agent_update_fn` before sampling episodes. If a list is
            passed in, it must have length exactly `factory.n_workers`,
            and will be spread across the workers.
        env_update (object): Value which will be passed into the
            `env_update_fn` before sampling episodes. If a list is
            passed in, it must have length exactly `factory.n_workers`,
            and will be spread across the workers.

    Returns:
        EpisodeBatch: Batch of gathered episodes. Always in worker
            order. In other words, first all episodes from worker 0,
            then all episodes from worker 1, etc.

    Raises:
        AssertionError: On internal errors.

    """
    self._agent_version += 1
    updated_workers = set()
    agent_ups = self._factory.prepare_worker_messages(
        agent_update, cloudpickle.dumps)
    env_ups = self._factory.prepare_worker_messages(env_update)
    episodes = defaultdict(list)

    with click.progressbar(length=self._factory.n_workers,
                           label='Sampling') as pbar:
        while any(
                len(episodes[i]) < n_eps_per_worker
                for i in range(self._factory.n_workers)):
            self._push_updates(updated_workers, agent_ups, env_ups)
            tag, contents = self._to_sampler.get()

            if tag == 'episode':
                batch, version, worker_n = contents

                if version == self._agent_version:
                    if len(episodes[worker_n]) < n_eps_per_worker:
                        episodes[worker_n].append(batch)

                    if len(episodes[worker_n]) == n_eps_per_worker:
                        pbar.update(1)
                        try:
                            self._to_worker[worker_n].put_nowait(
                                ('stop', ()))
                        except queue.Full:
                            pass
            else:
                raise AssertionError(
                    'Unknown tag {} with contents {}'.format(
                        tag, contents))

        for q in self._to_worker:
            try:
                q.put_nowait(('stop', ()))
            except queue.Full:
                pass

    ordered_episodes = list(
        itertools.chain(
            *[episodes[i] for i in range(self._factory.n_workers)]))
    samples = EpisodeBatch.concatenate(*ordered_episodes)
    self.total_env_steps += sum(samples.lengths)
    return samples
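
Because the result is guaranteed to be in worker order, per-worker episodes can be recovered by slicing. A hypothetical sketch, assuming `sampler` is an instance of the enclosing class and relying on `EpisodeBatch.split()` yielding one single-episode batch per episode, as it is used elsewhere in this module:

n_eps_per_worker = 5
batch = sampler.obtain_exact_episodes(n_eps_per_worker, policy_params)
per_episode = batch.split()  # one EpisodeBatch per episode
episodes_by_worker = [
    per_episode[i * n_eps_per_worker:(i + 1) * n_eps_per_worker]
    for i in range(sampler._factory.n_workers)  # private attr, for illustration
]
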
def obtain_samples(self, itr, num_samples, agent_update, env_update=None):
    """Collect at least a given number of transitions (timesteps).

    Args:
        itr (int): The current iteration number. Using this argument is
            deprecated.
        num_samples (int): Minimum number of transitions / timesteps to
            sample.
        agent_update (object): Value which will be passed into the
            `agent_update_fn` before sampling episodes. If a list is
            passed in, it must have length exactly `factory.n_workers`,
            and will be spread across the workers.
        env_update (object): Value which will be passed into the
            `env_update_fn` before sampling episodes. If a list is
            passed in, it must have length exactly `factory.n_workers`,
            and will be spread across the workers.

    Returns:
        EpisodeBatch: The batch of collected episodes.

    Raises:
        AssertionError: On internal errors.

    """
    del itr
    batches = []
    completed_samples = 0
    self._agent_version += 1
    updated_workers = set()
    agent_ups = self._factory.prepare_worker_messages(
        agent_update, cloudpickle.dumps)
    env_ups = self._factory.prepare_worker_messages(env_update)

    with click.progressbar(length=num_samples, label='Sampling') as pbar:
        while completed_samples < num_samples:
            self._push_updates(updated_workers, agent_ups, env_ups)
            for _ in range(self._factory.n_workers):
                try:
                    tag, contents = self._to_sampler.get_nowait()
                    if tag == 'episode':
                        batch, version, worker_n = contents
                        del worker_n
                        if version == self._agent_version:
                            batches.append(batch)
                            num_returned_samples = batch.lengths.sum()
                            completed_samples += num_returned_samples
                            pbar.update(num_returned_samples)
                        else:
                            # Receiving episodes from previous iterations
                            # is normal. Potentially, we could gather them
                            # here, if an off-policy method wants them.
                            pass
                    else:
                        raise AssertionError(
                            'Unknown tag {} with contents {}'.format(
                                tag, contents))
                except queue.Empty:
                    pass

        for q in self._to_worker:
            try:
                q.put_nowait(('stop', ()))
            except queue.Full:
                pass

    samples = EpisodeBatch.concatenate(*batches)
    self.total_env_steps += sum(samples.lengths)
    return samples
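
A hypothetical call, assuming `sampler` is an instance of the enclosing class: `itr` is accepted but ignored (deprecated), and because only whole episodes are counted toward the total, the result holds at least, and usually more than, `num_samples` transitions.

batch = sampler.obtain_samples(itr=0, num_samples=10000,
                               agent_update=policy_params)
overshoot = sum(batch.lengths) - 10000  # >= 0 by construction
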
def log_multitask_performance(itr, batch, discount, name_map=None,
                              use_wandb=True):
    r"""Log performance of episodes from multiple tasks.

    Args:
        itr (int): Iteration number to be logged.
        batch (EpisodeBatch): Batch of episodes. The episodes should have
            either the "task_name" or "task_id" `env_infos`. If the
            "task_name" is not present, then `name_map` is required, and
            should map from task ids to task names.
        discount (float): Discount used in computing returns.
        name_map (dict[int, str] or None): Mapping from task ids to task
            names. Optional if the "task_name" environment info is
            present. Note that if provided, all tasks listed in this map
            will be logged, even if there are no episodes present for
            them.
        use_wandb (bool): Whether to also log results to Weights & Biases
            (passed through to `log_performance`).

    Returns:
        numpy.ndarray: Undiscounted returns averaged across all tasks.
            Has shape :math:`(N \bullet [T])`.
        dict: Per-task and average log values consolidated into a single
            dictionary.

    """
    eps_by_name = defaultdict(list)
    for eps in batch.split():
        task_name = "__unnamed_task__"
        if "task_name" in eps.env_infos:
            task_name = eps.env_infos["task_name"][0]
        elif "task_id" in eps.env_infos:
            name_map = {} if name_map is None else name_map
            task_id = eps.env_infos["task_id"][0]
            task_name = name_map.get(task_id, "Task #{}".format(task_id))
        eps_by_name[task_name].append(eps)

    if name_map is None:
        task_names = eps_by_name.keys()
    else:
        task_names = name_map.values()

    consolidated_log = dict()
    for task_name in task_names:
        if task_name in eps_by_name:
            episodes = eps_by_name[task_name]
            # Analyze statistics per task.
            _, task_log = log_performance(
                itr,
                EpisodeBatch.concatenate(*episodes),
                discount,
                prefix=task_name,  # Specific to this task.
                use_wandb=use_wandb)
            if task_log is not None:
                consolidated_log.update(task_log)
        else:
            with tabular.prefix(task_name + "/"):
                tabular.record("Iteration", itr)
                tabular.record("NumEpisodes", 0)
                tabular.record("AverageDiscountedReturn", np.nan)
                tabular.record("AverageReturn", np.nan)
                tabular.record("StdReturn", np.nan)
                tabular.record("MaxReturn", np.nan)
                tabular.record("MinReturn", np.nan)
                tabular.record("TerminationRate", np.nan)
                tabular.record("SuccessRate", np.nan)

    undiscounted_returns, average_log = log_performance(itr,
                                                        batch,
                                                        discount=discount,
                                                        prefix="Average",
                                                        use_wandb=use_wandb)
    if average_log is not None:
        consolidated_log.update(average_log)

    return undiscounted_returns, consolidated_log
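
A hypothetical logging call, assuming `eval_batch` was gathered from a multi-task benchmark whose environments tag episodes with a "task_id" env_info; the task names are illustrative.

name_map = {0: "push-v2", 1: "reach-v2", 2: "pick-place-v2"}
returns, logs = log_multitask_performance(itr=10,
                                          batch=eval_batch,
                                          discount=0.99,
                                          name_map=name_map,
                                          use_wandb=False)
# Tasks listed in name_map but absent from the batch are still logged,
# with NaN statistics, so downstream dashboards keep a stable key set.
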
def log_multitask_performance(itr, batch, discount, name_map=None,
                              log_per_task=False):
    r"""Log performance of episodes from multiple tasks.

    Args:
        itr (int): Iteration number to be logged.
        batch (EpisodeBatch): Batch of episodes. The episodes should have
            either the "task_name" or "task_id" `env_infos`. If the
            "task_name" is not present, then `name_map` is required, and
            should map from task ids to task names.
        discount (float): Discount used in computing returns.
        name_map (dict[int, str] or None): Mapping from task ids to task
            names. Optional if the "task_name" environment info is
            present. Note that if provided, all tasks listed in this map
            will be logged, even if there are no episodes present for
            them.
        log_per_task (bool): Whether to log statistics for each task
            separately, in addition to the averages.

    Returns:
        numpy.ndarray: Undiscounted returns averaged across all tasks.
            Has shape :math:`(N \bullet [T])`.
        dict: Log values consolidated into a single dictionary.

    """
    # Create the log dict with the averages.
    undiscounted_returns, consolidated_log = log_performance(
        itr, batch, discount=discount, prefix="Average")

    # Add per-task results to the log dict, if requested.
    if log_per_task:
        eps_by_name = defaultdict(list)
        for eps in batch.split():
            task_name = "__unnamed_task__"
            if "task_name" in eps.env_infos:
                task_name = eps.env_infos["task_name"][0]
            elif "task_id" in eps.env_infos:
                name_map = {} if name_map is None else name_map
                task_id = eps.env_infos["task_id"][0]
                task_name = name_map.get(task_id,
                                         "Task #{}".format(task_id))
            eps_by_name[task_name].append(eps)

        if name_map is None:
            task_names = eps_by_name.keys()
        else:
            task_names = name_map.values()

        for task_name in task_names:
            if task_name in eps_by_name:
                episodes = eps_by_name[task_name]
                # Analyze statistics per task.
                _, task_log = log_performance(
                    itr,
                    EpisodeBatch.concatenate(*episodes),
                    discount,
                    prefix=task_name)  # Specific to this task.
                consolidated_log.update(task_log)

    return undiscounted_returns, consolidated_log
def obtain_exact_episodes(self, n_eps_per_worker, agent_update,
                          env_update=None):
    """Sample an exact number of episodes per worker.

    Args:
        n_eps_per_worker (int): Exact number of episodes to gather for
            each worker.
        agent_update (object): Value which will be passed into the
            `agent_update_fn` before sampling episodes. If a list is
            passed in, it must have length exactly `factory.n_workers`,
            and will be spread across the workers.
        env_update (object): Value which will be passed into the
            `env_update_fn` before sampling episodes. If a list is
            passed in, it must have length exactly `factory.n_workers`,
            and will be spread across the workers.

    Returns:
        EpisodeBatch: Batch of gathered episodes. Always in worker
            order. In other words, first all episodes from worker 0,
            then all episodes from worker 1, etc.

    """
    active_workers = []
    episodes = defaultdict(list)

    # Update the policy params of each worker before sampling for the
    # current iteration.
    idle_worker_ids = []
    updating_workers = self._update_workers(agent_update, env_update)

    with click.progressbar(length=self._worker_factory.n_workers,
                           label='Sampling') as pbar:
        while any(
                len(episodes[i]) < n_eps_per_worker
                for i in range(self._worker_factory.n_workers)):
            # If some workers are still being updated, check which have
            # finished, and start collecting episodes on those.
            if updating_workers:
                updated, updating_workers = ray.wait(updating_workers,
                                                     num_returns=1,
                                                     timeout=0.1)
                upd = [ray.get(up) for up in updated]
                idle_worker_ids.extend(upd)

            # If there are idle workers, use them to collect episodes,
            # and mark the newly busy workers as active.
            while idle_worker_ids:
                idle_worker_id = idle_worker_ids.pop()
                worker = self._all_workers[idle_worker_id]
                active_workers.append(worker.rollout.remote())

            # Check which workers have finished an episode; process the
            # ready results and keep polling the rest.
            ready, not_ready = ray.wait(active_workers,
                                        num_returns=1,
                                        timeout=0.001)
            active_workers = not_ready
            for result in ready:
                ready_worker_id, episode_batch = ray.get(result)
                episodes[ready_worker_id].append(episode_batch)
                if len(episodes[ready_worker_id]) < n_eps_per_worker:
                    idle_worker_ids.append(ready_worker_id)
                pbar.update(1)

    ordered_episodes = list(
        itertools.chain(
            *[episodes[i]
              for i in range(self._worker_factory.n_workers)]))
    return EpisodeBatch.concatenate(*ordered_episodes)
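
The asynchronous harvesting above reduces to a small `ray.wait` polling pattern. Here is a runnable, self-contained toy of just that pattern; the remote function is a stand-in for a worker rollout.

import ray

ray.init(ignore_reinit_error=True)

@ray.remote
def rollout(worker_id):
    return worker_id, [0.0]  # stand-in for (worker_id, episode_batch)

active = [rollout.remote(i) for i in range(4)]
collected = []
while active:
    # Harvest whichever task finishes first; keep polling the rest.
    ready, active = ray.wait(active, num_returns=1, timeout=0.001)
    for ref in ready:
        worker_id, episode = ray.get(ref)
        collected.append(worker_id)
print(sorted(collected))  # [0, 1, 2, 3]
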
def obtain_samples(self, itr, num_samples, agent_update, env_update=None):
    """Sample the policy for new episodes.

    Args:
        itr (int): Iteration number.
        num_samples (int): Number of steps the sampler should collect.
        agent_update (object): Value which will be passed into the
            `agent_update_fn` before sampling episodes. If a list is
            passed in, it must have length exactly `factory.n_workers`,
            and will be spread across the workers.
        env_update (object): Value which will be passed into the
            `env_update_fn` before sampling episodes. If a list is
            passed in, it must have length exactly `factory.n_workers`,
            and will be spread across the workers.

    Returns:
        EpisodeBatch: Batch of gathered episodes.

    """
    active_workers = []
    completed_samples = 0
    batches = []

    # Update the policy params of each worker before sampling for the
    # current iteration.
    idle_worker_ids = []
    updating_workers = self._update_workers(agent_update, env_update)

    with click.progressbar(length=num_samples, label='Sampling') as pbar:
        while completed_samples < num_samples:
            # If some workers are still being updated, check which have
            # finished, and start collecting episodes on those.
            if updating_workers:
                updated, updating_workers = ray.wait(updating_workers,
                                                     num_returns=1,
                                                     timeout=0.1)
                upd = [ray.get(up) for up in updated]
                idle_worker_ids.extend(upd)

            # If there are idle workers, use them to collect episodes,
            # and mark the newly busy workers as active.
            while idle_worker_ids:
                idle_worker_id = idle_worker_ids.pop()
                worker = self._all_workers[idle_worker_id]
                active_workers.append(worker.rollout.remote())

            # Check which workers have finished an episode; process the
            # ready results and keep polling the rest.
            ready, not_ready = ray.wait(active_workers,
                                        num_returns=1,
                                        timeout=0.001)
            active_workers = not_ready
            for result in ready:
                ready_worker_id, episode_batch = ray.get(result)
                idle_worker_ids.append(ready_worker_id)
                num_returned_samples = episode_batch.lengths.sum()
                completed_samples += num_returned_samples
                batches.append(episode_batch)
                pbar.update(num_returned_samples)

    return EpisodeBatch.concatenate(*batches)
def obtain_exact_episodes(self,
                          n_eps_per_worker,
                          agent_update,
                          collect_hook_data=False,
                          env_updates=None):
    """Sample an exact number of episodes per worker.

    Args:
        n_eps_per_worker (int): Exact number of episodes to gather for
            each worker.
        agent_update (object): Value which will be passed into the
            `agent_update_fn` before sampling episodes. If a list is
            passed in, it must have length exactly `factory.n_workers`,
            and will be spread across the workers.
        collect_hook_data (bool): Whether workers should return data
            collected by hooks attached to the agent. Forced to True
            when `agent_update` has a `collect_hook_data` attribute.
        env_updates (object): Value which will be passed into the
            `env_update_fn` before sampling episodes. If a list is
            passed in, it must have length exactly `factory.n_workers`,
            and will be spread across the workers.

    Returns:
        EpisodeBatch: Batch of gathered episodes. Always in worker
            order. In other words, first all episodes from worker 0,
            then all episodes from worker 1, etc.
        dict: Hook data gathered during sampling, keyed by hook and
            then by worker id.

    """
    self.update_workers_eval(agent_update, env_updates, eval_mode=True)

    # Adjust n_eps_per_worker to account for the number of workers
    # available per env.
    assert n_eps_per_worker % self.workers_per_env == 0, \
        "Number of eps per worker should be a multiple of workers per env"
    n_eps_per_worker = int(n_eps_per_worker / self.workers_per_env)

    # Only include hook data if a hook has been attached.
    if hasattr(agent_update, "collect_hook_data"):
        collect_hook_data = True

    episodes = defaultdict(list)
    data_to_export = defaultdict(dict)

    def update_eval_results(results):
        for worker_id, (episode_batch, hook_data) in enumerate(results):
            episodes[worker_id].append(episode_batch)
            if hook_data is not None:
                # Loop through hooks; multiple hooks are allowed.
                for hook, data in hook_data.items():
                    # For each hook, save data for each task separately,
                    # allowing several episodes to be saved at once.
                    if worker_id not in data_to_export[hook]:
                        data_to_export[hook][worker_id] = data
                    else:
                        data_to_export[hook][worker_id] = torch.cat(
                            [data_to_export[hook][worker_id], data],
                            dim=0)

    for _ in range(n_eps_per_worker):
        episode_results = [
            worker.rollout_eval(collect_hook_data=collect_hook_data)
            for worker in self.eval_workers
        ]
        update_eval_results(episode_results)

    # Note: do they need to be ordered?
    ordered_episodes = list(
        chain(*[episodes[i] for i in range(len(self.eval_workers))]))
    samples = EpisodeBatch.concatenate(*ordered_episodes)
    self.total_env_steps += sum(samples.lengths)
    return samples, data_to_export
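
A hypothetical sketch of consuming the second return value, assuming `sampler` is an instance of the enclosing class and each hook produced a 2-D tensor per worker, as the `torch.cat(dim=0)` above implies.

import torch

samples, data_to_export = sampler.obtain_exact_episodes(
    n_eps_per_worker=4, agent_update=policy_params)
for hook, per_worker in data_to_export.items():
    # Stack every worker's rows into one tensor for this hook.
    stacked = torch.cat(list(per_worker.values()), dim=0)
    print(hook, stacked.shape)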