Example #1
    def collect_episode(self):
        """Gather fragments from all in-progress episodes.

        Returns:
            EpisodeBatch: A batch of the episode fragments.

        """
        for i, frag in enumerate(self._fragments):
            assert frag.env is self._envs[i]
            if len(frag.rewards) > 0:
                complete_frag = frag.to_batch()
                self._complete_fragments.append(complete_frag)
                self._fragments[i] = InProgressEpisode(frag.env, frag.last_obs,
                                                       frag.episode_info)
        assert len(self._complete_fragments) > 0
        result = EpisodeBatch.concatenate(*self._complete_fragments)
        self._complete_fragments = []
        return result
Example #2
    def obtain_samples(self, num_samples, agent_update, env_updates=None):
        """Sample the policy for new episodes.

        Args:
            num_samples (int): Number of steps the sampler should collect.
            agent_update (object): Value which will be passed into the
                `agent_update_fn` before sampling episodes. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.
            env_updates (object): Value which will be passed into the
                `env_update_fn` before sampling episodes. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.

        Returns:
            EpisodeBatch: Batch of gathered episodes.

        """

        self.update_workers(agent_update, env_updates)
        completed_samples = 0
        batches = []

        # TODO: can we replace the while, so all processes are scheduled beforehand?
        while completed_samples < num_samples:
            pids = [w.rollout.remote() for w in self.workers]
            results = [ray.get(pid) for pid in pids]
            for episode_batch in results:
                num_returned_samples = episode_batch.lengths.sum()
                completed_samples += num_returned_samples
                batches.append(episode_batch)

        # Note: EpisodeBatch takes care of concatenating - is this a performance issue?
        samples = EpisodeBatch.concatenate(*batches)
        self.total_env_steps += sum(samples.lengths)
        return samples
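For context, a minimal usage sketch of this `obtain_samples` API inside a training loop; `sampler`, `policy`, `n_epochs`, and `policy.get_param_values()` are illustrative assumptions, not names defined in the example above.

# Hypothetical call site; `sampler`, `policy`, and `n_epochs` are placeholders
# standing in for an instance of the class above and a trained policy.
for epoch in range(n_epochs):
    # Ship the latest policy parameters to every worker, then gather at
    # least 4000 environment steps worth of complete episodes.
    batch = sampler.obtain_samples(
        num_samples=4000,
        agent_update=policy.get_param_values(),  # assumed policy API
    )
    total_steps = batch.lengths.sum()  # per-episode lengths, summed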
Example #3
    def obtain_exact_episodes(self,
                              n_eps_per_worker,
                              agent_update,
                              env_update=None):
        """Sample an exact number of episodes per worker.

        Args:
            n_eps_per_worker (int): Exact number of episodes to gather for
                each worker.
            agent_update (object): Value which will be passed into the
                `agent_update_fn` before sampling episodes. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.
            env_update (object): Value which will be passed into the
                `env_update_fn` before sampling episodes. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.

        Returns:
            EpisodeBatch: Batch of gathered episodes. Always in worker
                order. In other words, first all episodes from worker 0,
                then all episodes from worker 1, etc.

        Raises:
            AssertionError: On internal errors.

        """
        self._agent_version += 1
        updated_workers = set()
        agent_ups = self._factory.prepare_worker_messages(
            agent_update, cloudpickle.dumps)
        env_ups = self._factory.prepare_worker_messages(env_update)
        episodes = defaultdict(list)

        with click.progressbar(length=self._factory.n_workers,
                               label='Sampling') as pbar:
            while any(
                    len(episodes[i]) < n_eps_per_worker
                    for i in range(self._factory.n_workers)):
                self._push_updates(updated_workers, agent_ups, env_ups)
                tag, contents = self._to_sampler.get()

                if tag == 'episode':
                    batch, version, worker_n = contents

                    if version == self._agent_version:
                        if len(episodes[worker_n]) < n_eps_per_worker:
                            episodes[worker_n].append(batch)

                        if len(episodes[worker_n]) == n_eps_per_worker:
                            pbar.update(1)
                            try:
                                self._to_worker[worker_n].put_nowait(
                                    ('stop', ()))
                            except queue.Full:
                                pass
                else:
                    raise AssertionError(
                        'Unknown tag {} with contents {}'.format(
                            tag, contents))

            for q in self._to_worker:
                try:
                    q.put_nowait(('stop', ()))
                except queue.Full:
                    pass

        ordered_episodes = list(
            itertools.chain(
                *[episodes[i] for i in range(self._factory.n_workers)]))
        samples = EpisodeBatch.concatenate(*ordered_episodes)
        self.total_env_steps += sum(samples.lengths)
        return samples
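A hedged sketch of calling `obtain_exact_episodes` for evaluation; `sampler`, `policy`, and `policy.get_param_values()` are placeholders, not part of the example above.

# Hypothetical evaluation call; `sampler` and `policy` are placeholders.
eval_eps = sampler.obtain_exact_episodes(
    n_eps_per_worker=5,
    agent_update=policy.get_param_values(),  # assumed policy API
)
# Episodes arrive grouped by worker: worker 0's five episodes first, then
# worker 1's, and so on, so per-worker statistics can be computed by
# slicing eval_eps.split() into chunks of five.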
Example #4
    def obtain_samples(self, itr, num_samples, agent_update, env_update=None):
        """Collect at least a given number transitions (timesteps).

        Args:
            itr (int): The current iteration number. Using this argument is
                deprecated.
            num_samples (int): Minimum number of transitions / timesteps to
                sample.
            agent_update (object): Value which will be passed into the
                `agent_update_fn` before sampling episodes. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.
            env_update (object): Value which will be passed into the
                `env_update_fn` before sampling episodes. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.

        Returns:
            EpisodeBatch: The batch of collected episodes.

        Raises:
            AssertionError: On internal errors.

        """
        del itr
        batches = []
        completed_samples = 0
        self._agent_version += 1
        updated_workers = set()
        agent_ups = self._factory.prepare_worker_messages(
            agent_update, cloudpickle.dumps)
        env_ups = self._factory.prepare_worker_messages(env_update)

        with click.progressbar(length=num_samples, label='Sampling') as pbar:
            while completed_samples < num_samples:
                self._push_updates(updated_workers, agent_ups, env_ups)
                for _ in range(self._factory.n_workers):
                    try:
                        tag, contents = self._to_sampler.get_nowait()
                        if tag == 'episode':
                            batch, version, worker_n = contents
                            del worker_n
                            if version == self._agent_version:
                                batches.append(batch)
                                num_returned_samples = batch.lengths.sum()
                                completed_samples += num_returned_samples
                                pbar.update(num_returned_samples)
                            else:
                                # Receiving episodes from previous iterations
                                # is normal.  Potentially, we could gather them
                                # here, if an off-policy method wants them.
                                pass
                        else:
                            raise AssertionError(
                                'Unknown tag {} with contents {}'.format(
                                    tag, contents))
                    except queue.Empty:
                        pass
            for q in self._to_worker:
                try:
                    q.put_nowait(('stop', ()))
                except queue.Full:
                    pass

        samples = EpisodeBatch.concatenate(*batches)
        self.total_env_steps += sum(samples.lengths)
        return samples
Example #5
def log_multitask_performance(itr, batch, discount, name_map=None, use_wandb=True):
    r"""Log performance of episodes from multiple tasks.

    Args:
        itr (int): Iteration number to be logged.
        batch (EpisodeBatch): Batch of episodes. The episodes should have
            either the "task_name" or "task_id" `env_infos`. If the "task_name"
            is not present, then `name_map` is required, and should map from
            task ids to task names.
        discount (float): Discount used in computing returns.
        name_map (dict[int, str] or None): Mapping from task ids to task
            names. Optional if the "task_name" environment info is present.
            Note that if provided, all tasks listed in this map will be logged,
            even if there are no episodes present for them.
        use_wandb (bool): Whether to also forward the logged metrics to wandb
            via `log_performance`.

    Returns:
        numpy.ndarray: Undiscounted returns averaged across all tasks. Has
            shape :math:`(N \bullet [T])`.
        dict: Consolidated log of per-task and average metrics.

    """
    eps_by_name = defaultdict(list)

    for eps in batch.split():
        task_name = "__unnamed_task__"
        if "task_name" in eps.env_infos:
            task_name = eps.env_infos["task_name"][0]
        elif "task_id" in eps.env_infos:
            name_map = {} if name_map is None else name_map
            task_id = eps.env_infos["task_id"][0]
            task_name = name_map.get(task_id, "Task #{}".format(task_id))
        eps_by_name[task_name].append(eps)

    if name_map is None:
        task_names = eps_by_name.keys()
    else:
        task_names = name_map.values()

    consolidated_log = dict()
    for task_name in task_names:
        if task_name in eps_by_name:
            episodes = eps_by_name[task_name]
            # analyze statistics per task
            _, task_log = log_performance(
                itr,
                EpisodeBatch.concatenate(*episodes),
                discount,
                prefix=task_name,  # specific to task
                use_wandb=use_wandb
            )
            if task_log is not None:
                consolidated_log.update(task_log)
        else:
            with tabular.prefix(task_name + "/"):
                tabular.record("Iteration", itr)
                tabular.record("NumEpisodes", 0)
                tabular.record("AverageDiscountedReturn", np.nan)
                tabular.record("AverageReturn", np.nan)
                tabular.record("StdReturn", np.nan)
                tabular.record("MaxReturn", np.nan)
                tabular.record("MinReturn", np.nan)
                tabular.record("TerminationRate", np.nan)
                tabular.record("SuccessRate", np.nan)

    undiscounted_returns, average_log = log_performance(
        itr, batch, discount=discount, prefix="Average", use_wandb=use_wandb
    )
    if average_log is not None:
        consolidated_log.update(average_log)

    return undiscounted_returns, consolidated_log
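A hedged sketch of a call site for this logger; `eval_batch`, `epoch`, and the task-id-to-name mapping are placeholders made up for illustration.

# Hypothetical call; `eval_batch` (an EpisodeBatch) and `epoch` are placeholders.
name_map = {0: "reach", 1: "push"}  # assumed task-id-to-name mapping
returns, logs = log_multitask_performance(
    itr=epoch,
    batch=eval_batch,  # episodes carrying "task_id" in env_infos
    discount=0.99,
    name_map=name_map,
    use_wandb=False,
)
# `logs` now holds per-task metrics plus the metrics under the "Average" prefix.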
Example #6
def log_multitask_performance(
    itr,
    batch,
    discount,
    name_map=None,
    log_per_task=False,
):
    r"""Log performance of episodes from multiple tasks.

    Args:
        itr (int): Iteration number to be logged.
        batch (EpisodeBatch): Batch of episodes. The episodes should have
            either the "task_name" or "task_id" `env_infos`. If the "task_name"
            is not present, then `name_map` is required, and should map from
            task ids to task names.
        discount (float): Discount used in computing returns.
        name_map (dict[int, str] or None): Mapping from task ids to task
            names. Optional if the "task_name" environment info is present.
            Note that if provided, all tasks listed in this map will be logged,
            even if there are no episodes present for them.
        log_per_task (bool): Whether to additionally log statistics for each
            task separately, on top of the averaged metrics.

    Returns:
        numpy.ndarray: Undiscounted returns averaged across all tasks. Has
            shape :math:`(N \bullet [T])`.
        dict: Consolidated log of per-task and average metrics.

    """

    # Create the consolidated log with the averaged statistics
    undiscounted_returns, consolidated_log = log_performance(itr,
                                                             batch,
                                                             discount=discount,
                                                             prefix="Average")

    # Add per-task results to the consolidated log, if requested
    if log_per_task:
        eps_by_name = defaultdict(list)
        for eps in batch.split():
            task_name = "__unnamed_task__"
            if "task_name" in eps.env_infos:
                task_name = eps.env_infos["task_name"][0]
            elif "task_id" in eps.env_infos:
                name_map = {} if name_map is None else name_map
                task_id = eps.env_infos["task_id"][0]
                task_name = name_map.get(task_id, "Task #{}".format(task_id))
            eps_by_name[task_name].append(eps)

        if name_map is None:
            task_names = eps_by_name.keys()
        else:
            task_names = name_map.values()

        for task_name in task_names:
            if task_name in eps_by_name:
                episodes = eps_by_name[task_name]
                # analyze statistics per task
                _, task_log = log_performance(
                    itr,
                    EpisodeBatch.concatenate(*episodes),
                    discount,
                    prefix=task_name,  # specific to task
                )
                consolidated_log.update(task_log)

    return undiscounted_returns, consolidated_log
Example #7
    def obtain_exact_episodes(self,
                              n_eps_per_worker,
                              agent_update,
                              env_update=None):
        """Sample an exact number of episodes per worker.

        Args:
            n_eps_per_worker (int): Exact number of episodes to gather for
                each worker.
            agent_update (object): Value which will be passed into the
                `agent_update_fn` before sampling episodes. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.
            env_update (object): Value which will be passed into the
                `env_update_fn` before sampling episodes. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.

        Returns:
            EpisodeBatch: Batch of gathered episodes. Always in worker
                order. In other words, first all episodes from worker 0, then
                all episodes from worker 1, etc.

        """
        active_workers = []
        episodes = defaultdict(list)

        # update the policy params of each worker before sampling
        # for the current iteration
        idle_worker_ids = []
        updating_workers = self._update_workers(agent_update, env_update)

        with click.progressbar(length=self._worker_factory.n_workers,
                               label='Sampling') as pbar:
            while any(
                    len(episodes[i]) < n_eps_per_worker
                    for i in range(self._worker_factory.n_workers)):
                # if there are workers still being updated, check
                # which ones are still updating and take the workers that
                # are done updating, and start collecting episodes on
                # those workers.
                if updating_workers:
                    updated, updating_workers = ray.wait(updating_workers,
                                                         num_returns=1,
                                                         timeout=0.1)
                    upd = [ray.get(up) for up in updated]
                    idle_worker_ids.extend(upd)

                # if there are idle workers, use them to collect episodes
                # mark the newly busy workers as active
                while idle_worker_ids:
                    idle_worker_id = idle_worker_ids.pop()
                    worker = self._all_workers[idle_worker_id]
                    active_workers.append(worker.rollout.remote())

                # check which workers are done/not done collecting a sample
                # if any are done, send them to process the collected episode
                # if they are not, keep checking if they are done
                ready, not_ready = ray.wait(active_workers,
                                            num_returns=1,
                                            timeout=0.001)
                active_workers = not_ready
                for result in ready:
                    ready_worker_id, episode_batch = ray.get(result)
                    episodes[ready_worker_id].append(episode_batch)

                    if len(episodes[ready_worker_id]) < n_eps_per_worker:
                        idle_worker_ids.append(ready_worker_id)

                    pbar.update(1)

        ordered_episodes = list(
            itertools.chain(
                *[episodes[i] for i in range(self._worker_factory.n_workers)]))

        return EpisodeBatch.concatenate(*ordered_episodes)
Example #8
    def obtain_samples(self, itr, num_samples, agent_update, env_update=None):
        """Sample the policy for new episodes.

        Args:
            itr (int): Iteration number.
            num_samples (int): Number of steps the sampler should collect.
            agent_update (object): Value which will be passed into the
                `agent_update_fn` before sampling episodes. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.
            env_update (object): Value which will be passed into the
                `env_update_fn` before sampling episodes. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.

        Returns:
            EpisodeBatch: Batch of gathered episodes.

        """
        active_workers = []
        completed_samples = 0
        batches = []

        # update the policy params of each worker before sampling
        # for the current iteration
        idle_worker_ids = []
        updating_workers = self._update_workers(agent_update, env_update)

        with click.progressbar(length=num_samples, label='Sampling') as pbar:
            while completed_samples < num_samples:
                # if there are workers still being updated, check
                # which ones are still updating and take the workers that
                # are done updating, and start collecting episodes on those
                # workers.
                if updating_workers:
                    updated, updating_workers = ray.wait(updating_workers,
                                                         num_returns=1,
                                                         timeout=0.1)
                    upd = [ray.get(up) for up in updated]
                    idle_worker_ids.extend(upd)

                # if there are idle workers, use them to collect episodes and
                # mark the newly busy workers as active
                while idle_worker_ids:
                    idle_worker_id = idle_worker_ids.pop()
                    worker = self._all_workers[idle_worker_id]
                    active_workers.append(worker.rollout.remote())

                # check which workers are done/not done collecting a sample
                # if any are done, send them to process the collected
                # episode if they are not, keep checking if they are done
                ready, not_ready = ray.wait(active_workers,
                                            num_returns=1,
                                            timeout=0.001)
                active_workers = not_ready
                for result in ready:
                    ready_worker_id, episode_batch = ray.get(result)
                    idle_worker_ids.append(ready_worker_id)
                    num_returned_samples = episode_batch.lengths.sum()
                    completed_samples += num_returned_samples
                    batches.append(episode_batch)
                    pbar.update(num_returned_samples)

        return EpisodeBatch.concatenate(*batches)
Example #9
    def obtain_exact_episodes(
        self,
        n_eps_per_worker,
        agent_update,
        collect_hook_data=False,
        env_updates=None
    ):
        """Sample an exact number of episodes per worker.

        Args:
            n_eps_per_worker (int): Exact number of episodes to gather for
                each worker.
            agent_update (object): Value which will be passed into the
                `agent_update_fn` before sampling episodes. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.
            collect_hook_data (bool): Whether to collect data produced by any
                hooks attached to the agent. Forced to True when
                `agent_update` has a `collect_hook_data` attribute.
            env_updates (object): Value which will be passed into the
                `env_update_fn` before sampling episodes. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.

        Returns:
            EpisodeBatch: Batch of gathered episodes. Always in worker
                order. In other words, first all episodes from worker 0, then
                all episodes from worker 1, etc.
            dict: Data collected by hooks during sampling, keyed by hook and
                then by worker id.

        """

        self.update_workers_eval(agent_update, env_updates, eval_mode=True)

        # adjust n_eps_per_worker to account for the number of workers available per env
        assert n_eps_per_worker % self.workers_per_env == 0, \
            "Number of eps per worker should be a multiple of workers per env"
        n_eps_per_worker = int(n_eps_per_worker / self.workers_per_env)

        # only include hook data if hook has been attached.
        if hasattr(agent_update, "collect_hook_data"):
            collect_hook_data = True

        episodes = defaultdict(list)
        data_to_export = defaultdict(dict)

        def update_eval_results(results):
            for worker_id, (episode_batch, hook_data) in enumerate(results):
                episodes[worker_id].append(episode_batch)
                if hook_data is not None:
                    # Loop through hooks, allow multiple hooks
                    for hook, data in hook_data.items():
                        # For each hook save data for each task separately
                        # allowing for several episodes to be saved at once
                        if worker_id not in data_to_export[hook]:
                            data_to_export[hook][worker_id] = data
                        else:
                            data_to_export[hook][worker_id] = torch.cat(
                                [data_to_export[hook][worker_id], data], dim=0
                            )

        for _ in range(n_eps_per_worker):
            episode_results = [
                worker.rollout_eval(collect_hook_data=collect_hook_data)
                for worker in self.eval_workers
            ]
            update_eval_results(episode_results)

        # Note: do they need to be ordered?
        ordered_episodes = list(chain(
            *[episodes[i] for i in range(len(self.eval_workers))]
        ))

        samples = EpisodeBatch.concatenate(*ordered_episodes)  # concat
        self.total_env_steps += sum(samples.lengths)
        return samples, data_to_export
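Because this variant returns both the episode batch and any hook data, a caller unpacks a tuple; `eval_sampler` and `policy` below are placeholders used only for illustration.

# Hypothetical call; `eval_sampler` and `policy` are placeholders.
samples, hook_data = eval_sampler.obtain_exact_episodes(
    n_eps_per_worker=10,
    agent_update=policy,  # hooks are detected via hasattr on this object
)
for hook, per_worker in hook_data.items():
    # per_worker maps worker_id -> concatenated torch.Tensor of hook outputs
    print(hook, {w: t.shape for w, t in per_worker.items()})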