Exemple #1
0
 def __init__(
         self,
         agents,
         envs,
         *,  # After this require passing by keyword.
         worker_factory=None,
         max_episode_length=None,
         is_tf_worker=False,
         seed=get_seed(),
         n_workers=psutil.cpu_count(logical=False),
         worker_class=DefaultWorker,
         worker_args=None):
     """Set up a ray-backed sampler and start its workers.

     Note that the `seed` and `n_workers` defaults are evaluated once, when
     this function is defined, not on every call.

     Args:
         agents (Policy or List[Policy]): Agent(s) stored on the sampler.
         envs (Environment or List[Environment]): Environment(s); passed
             through `WorkerFactory.prepare_worker_messages` before being
             stored.
         worker_factory (WorkerFactory or None): Factory to use as-is. If
             it is not a `WorkerFactory` instance, a new factory is built
             from the keyword arguments below.
         max_episode_length (int or None): Forwarded to `WorkerFactory`.
         is_tf_worker (bool): Forwarded to `WorkerFactory`.
         seed (int): Forwarded to `WorkerFactory`.
         n_workers (int): Forwarded to `WorkerFactory`.
         worker_class (type): Forwarded to `WorkerFactory`.
         worker_args (dict or None): Forwarded to `WorkerFactory`.

     Raises:
         TypeError: If both `worker_factory` and `max_episode_length` are
             None.

     """
     # pylint: disable=super-init-not-called
     if not ray.is_initialized():
         ray.init(log_to_driver=False, ignore_reinit_error=True)
     if worker_factory is None and max_episode_length is None:
         # The trailing space matters: adjacent literals are concatenated
         # verbatim.
         raise TypeError('Must construct a sampler from WorkerFactory or '
                         'parameters (at least max_episode_length)')
     if isinstance(worker_factory, WorkerFactory):
         self._worker_factory = worker_factory
     else:
         self._worker_factory = WorkerFactory(
             max_episode_length=max_episode_length,
             is_tf_worker=is_tf_worker,
             seed=seed,
             n_workers=n_workers,
             worker_class=worker_class,
             worker_args=worker_args)
     self._sampler_worker = ray.remote(SamplerWorker)
     self._agents = agents
     self._envs = self._worker_factory.prepare_worker_messages(envs)
     # With a None default_factory this behaves like a plain dict: missing
     # keys raise KeyError.
     self._all_workers = defaultdict(None)
     self._workers_started = False
     self.start_worker()
     self.total_env_steps = 0
    def __init__(
            self,
            agents,
            envs,
            *,  # After this require passing by keyword.
            worker_factory=None,
            max_episode_length=None,
            is_tf_worker=False,
            seed=seed,
            worker_class=FragmentWorker,
            worker_args=None):
        """Build a single worker and hand it the agent and all the envs.

        Args:
            agents (Policy or List[Policy]): Agent(s); passed through
                `WorkerFactory.prepare_worker_messages`. Only the first
                resulting entry is given to the worker.
            envs (List[Environment]): Iterable of environments. Each one is
                deep-copied and the full list of copies is given to worker 0.
            worker_factory (WorkerFactory or None): Factory to use as-is. If
                it is not a `WorkerFactory` instance, a new factory with
                `n_workers=1` is built from the keyword arguments below.
            max_episode_length (int or None): Forwarded to `WorkerFactory`.
            is_tf_worker (bool): Forwarded to `WorkerFactory`.
            seed (int): Forwarded to `WorkerFactory`.
            worker_class (type): Forwarded to `WorkerFactory`.
            worker_args (dict or None): Forwarded to `WorkerFactory`.

        Raises:
            TypeError: If both `worker_factory` and `max_episode_length` are
                None.

        """
        if worker_factory is None and max_episode_length is None:
            # The trailing space matters: adjacent literals are concatenated
            # verbatim.
            raise TypeError("Must construct a sampler from WorkerFactory or "
                            "parameters (at least max_episode_length)")
        if isinstance(worker_factory, WorkerFactory):
            self._factory = worker_factory
        else:
            self._factory = WorkerFactory(
                max_episode_length=max_episode_length,
                is_tf_worker=is_tf_worker,
                seed=seed,
                n_workers=1,
                worker_class=worker_class,
                worker_args=worker_args)

        self._agents = self._factory.prepare_worker_messages(agents)
        self._envs = [copy.deepcopy(env) for env in envs]
        self._num_envs = len(self._envs)
        self._workers = [
            self._factory(i) for i in range(self._factory.n_workers)
        ]
        # Only worker 0 is configured; it receives the whole list of envs.
        self._workers[0].update_agent(self._agents[0])
        self._workers[0].update_env(self._envs)
        self.total_env_steps = 0
Exemple #3
0
    def __init__(
            self,
            agents,
            envs,
            *,  # After this require passing by keyword.
            worker_factory=None,
            max_episode_length=None,
            is_tf_worker=False,
            seed=get_seed(),
            n_workers=psutil.cpu_count(logical=False),
            worker_class=DefaultWorker,
            worker_args=None):
        """Create the queues and start one worker process per worker.

        Note that the `seed` and `n_workers` defaults are evaluated once,
        when this function is defined, not on every call.

        Args:
            agents (Policy or List[Policy]): Agent(s); passed through
                `WorkerFactory.prepare_worker_messages` with
                `cloudpickle.dumps` as the preprocessing function.
            envs (Environment or List[Environment]): Environment(s); passed
                through `WorkerFactory.prepare_worker_messages`.
            worker_factory (WorkerFactory or None): Factory to use as-is. If
                it is not a `WorkerFactory` instance, a new factory is built
                from the keyword arguments below.
            max_episode_length (int or None): Forwarded to `WorkerFactory`.
            is_tf_worker (bool): Forwarded to `WorkerFactory`.
            seed (int): Forwarded to `WorkerFactory`.
            n_workers (int): Forwarded to `WorkerFactory`.
            worker_class (type): Forwarded to `WorkerFactory`.
            worker_args (dict or None): Forwarded to `WorkerFactory`.

        Raises:
            TypeError: If both `worker_factory` and `max_episode_length` are
                None.

        """
        # pylint: disable=super-init-not-called
        if worker_factory is None and max_episode_length is None:
            # The trailing space matters: adjacent literals are concatenated
            # verbatim.
            raise TypeError('Must construct a sampler from WorkerFactory or '
                            'parameters (at least max_episode_length)')
        if isinstance(worker_factory, WorkerFactory):
            self._factory = worker_factory
        else:
            self._factory = WorkerFactory(
                max_episode_length=max_episode_length,
                is_tf_worker=is_tf_worker,
                seed=seed,
                n_workers=n_workers,
                worker_class=worker_class,
                worker_args=worker_args)

        self._agents = self._factory.prepare_worker_messages(
            agents, cloudpickle.dumps)
        self._envs = self._factory.prepare_worker_messages(envs)
        # One shared result queue, plus a single-slot command queue per
        # worker.
        self._to_sampler = mp.Queue(2 * self._factory.n_workers)
        self._to_worker = [mp.Queue(1) for _ in range(self._factory.n_workers)]
        # If we crash from an exception, with full queues, we would rather not
        # hang forever, so we would like the process to close without flushing
        # the queues.
        # That's what cancel_join_thread does.
        for q in self._to_worker:
            q.cancel_join_thread()
        self._workers = [
            mp.Process(target=run_worker,
                       kwargs=dict(
                           factory=self._factory,
                           to_sampler=self._to_sampler,
                           to_worker=self._to_worker[worker_number],
                           worker_number=worker_number,
                           agent=self._agents[worker_number],
                           env=self._envs[worker_number],
                       ),
                       daemon=False)
            for worker_number in range(self._factory.n_workers)
        ]
        self._agent_version = 0
        for w in self._workers:
            w.start()
        self.total_env_steps = 0
Exemple #4
0
    def make_sampler(self,
                     sampler_cls,
                     *,
                     seed=None,
                     n_workers=psutil.cpu_count(logical=False),
                     max_path_length=None,
                     worker_class=DefaultWorker,
                     sampler_args=None,
                     worker_args=None):
        """Instantiate a sampler of the requested class.

        Subclasses of `BaseSampler` are constructed directly from the
        algorithm and environment; any other class is built through its
        `from_worker_factory` constructor with a fresh `WorkerFactory`.

        Args:
            sampler_cls (type): The type of sampler to construct.
            seed (int): Seed to use in sampler workers. Defaults to
                `get_seed()` when None.
            max_path_length (int): Maximum path length to be sampled by the
                sampler. Falls back to the algorithm's `max_path_length`
                when None.
            n_workers (int): The number of workers the sampler should use.
            worker_class (type): Type of worker the Sampler should use.
            sampler_args (dict or None): Additional keyword arguments for
                the sampler. Only forwarded when `sampler_cls` is a
                `BaseSampler` subclass.
            worker_args (dict or None): Additional arguments that should be
                passed to the worker (via the `WorkerFactory`).

        Raises:
            ValueError: If `max_path_length` isn't passed and the algorithm
                doesn't contain a `max_path_length` field, or if the algorithm
                doesn't have a policy field.

        Returns:
            sampler_cls: An instance of the sampler class.

        """
        algo = self._algo
        if not hasattr(algo, 'policy'):
            raise ValueError('If the runner is used to construct a sampler, '
                             'the algorithm must have a `policy` field.')
        if max_path_length is None:
            # Guard first, then read the fallback from the algorithm.
            if not hasattr(algo, 'max_path_length'):
                raise ValueError('If `sampler_cls` is specified in '
                                 'runner.setup, the algorithm must have '
                                 'a `max_path_length` field.')
            max_path_length = algo.max_path_length
        seed = get_seed() if seed is None else seed
        sampler_args = {} if sampler_args is None else sampler_args
        worker_args = {} if worker_args is None else worker_args

        if issubclass(sampler_cls, BaseSampler):
            return sampler_cls(algo, self._env, **sampler_args)

        factory_kwargs = dict(seed=seed,
                              max_path_length=max_path_length,
                              n_workers=n_workers,
                              worker_class=worker_class,
                              worker_args=worker_args)
        return sampler_cls.from_worker_factory(
            WorkerFactory(**factory_kwargs),
            agents=algo.policy,
            envs=self._env)
Exemple #5
0
    def evaluate(self, algo, test_rollouts_per_task=None):
        """Evaluate the Meta-RL algorithm on the test tasks.

        For every sampled test task, exploration trajectories are gathered
        with the exploration policy, the policy is adapted on them, and the
        adapted policy is then sampled. The combined results are logged and
        the evaluation iteration counter is advanced.

        Args:
            algo (garage.np.algos.MetaRLAlgorithm): The algorithm to evaluate.
            test_rollouts_per_task (int or None): Number of rollouts per task.
                Defaults to the evaluator's configured number when None.

        Returns:
            object: `sum(rewards) / len(rewards)` over the concatenated
                adapted trajectories, where `rewards` is the batch's
                `rewards` field for `TrajectoryBatch` and `env_rewards`
                otherwise.

        """
        if test_rollouts_per_task is None:
            test_rollouts_per_task = self._n_test_rollouts
        adapted_trajectories = []
        logger.log('Sampling for adapation and meta-testing...')
        # The test sampler is created lazily on the first evaluation and
        # reused afterwards.
        if self._test_sampler is None:
            self._test_sampler = self._sampler_class.from_worker_factory(
                WorkerFactory(seed=get_seed(),
                              max_path_length=self._max_path_length,
                              n_workers=1,
                              worker_class=self._worker_class,
                              worker_args=self._worker_args),
                agents=algo.get_exploration_policy(),
                envs=self._test_task_sampler.sample(1))
        for env_up in self._test_task_sampler.sample(self._n_test_tasks):
            policy = algo.get_exploration_policy()
            traj = self._trajectory_batch_class.concatenate(*[
                self._test_sampler.obtain_samples(self._eval_itr, 1, policy,
                                                  env_up)
                for _ in range(self._n_exploration_traj)
            ])
            adapted_policy = algo.adapt_policy(policy, traj)
            adapted_traj = self._test_sampler.obtain_samples(
                self._eval_itr, test_rollouts_per_task * self._max_path_length,
                adapted_policy)
            adapted_trajectories.append(adapted_traj)
        logger.log('Finished meta-testing...')

        if self._test_task_names is not None:
            name_map = dict(enumerate(self._test_task_names))
        else:
            name_map = None

        # Build the combined batch once; it is used for both logging and the
        # returned average.
        all_adapted = self._trajectory_batch_class.concatenate(
            *adapted_trajectories)

        with tabular.prefix(self._prefix + '/' if self._prefix else ''):
            log_multitask_performance(
                self._eval_itr,
                all_adapted,
                getattr(algo, 'discount', 1.0),
                trajectory_class=self._trajectory_batch_class,
                name_map=name_map)
        self._eval_itr += 1

        if self._trajectory_batch_class == TrajectoryBatch:
            rewards = all_adapted.rewards
        else:
            rewards = all_adapted.env_rewards

        return sum(rewards) / len(rewards)
Exemple #6
0
    def __init__(
            self,
            agents,
            envs,
            *,  # After this require passing by keyword.
            worker_factory=None,
            max_episode_length=None,
            is_tf_worker=False,
            seed=get_seed(),
            n_workers=psutil.cpu_count(logical=False),
            worker_class=DefaultWorker,
            worker_args=None):
        """Create the workers in this process and give each its agent and env.

        Note that the `seed` and `n_workers` defaults are evaluated once,
        when this function is defined, not on every call.

        Args:
            agents (Policy or List[Policy]): Agent(s); passed through
                `WorkerFactory.prepare_worker_messages`.
            envs (Environment or List[Environment]): Environment(s); passed
                through `WorkerFactory.prepare_worker_messages` with
                `copy.deepcopy` as the preprocessing function.
            worker_factory (WorkerFactory or None): Factory to use as-is. If
                it is not a `WorkerFactory` instance, a new factory is built
                from the keyword arguments below.
            max_episode_length (int or None): Forwarded to `WorkerFactory`.
            is_tf_worker (bool): Forwarded to `WorkerFactory`.
            seed (int): Forwarded to `WorkerFactory`.
            n_workers (int): Forwarded to `WorkerFactory`.
            worker_class (type): Forwarded to `WorkerFactory`.
            worker_args (dict or None): Forwarded to `WorkerFactory`.

        Raises:
            TypeError: If both `worker_factory` and `max_episode_length` are
                None.

        """
        # pylint: disable=super-init-not-called
        if worker_factory is None and max_episode_length is None:
            # The trailing space matters: adjacent literals are concatenated
            # verbatim.
            raise TypeError('Must construct a sampler from WorkerFactory or '
                            'parameters (at least max_episode_length)')
        if isinstance(worker_factory, WorkerFactory):
            self._factory = worker_factory
        else:
            self._factory = WorkerFactory(
                max_episode_length=max_episode_length,
                is_tf_worker=is_tf_worker,
                seed=seed,
                n_workers=n_workers,
                worker_class=worker_class,
                worker_args=worker_args)

        self._agents = self._factory.prepare_worker_messages(agents)
        self._envs = self._factory.prepare_worker_messages(
            envs, preprocess=copy.deepcopy)
        self._workers = [
            self._factory(i) for i in range(self._factory.n_workers)
        ]
        # Pair each worker with its prepared agent and environment.
        for worker, agent, env in zip(self._workers, self._agents, self._envs):
            worker.update_agent(agent)
            worker.update_env(env)
        self.total_env_steps = 0
Exemple #7
0
    def make_sampler(self,
                     sampler_cls,
                     *,
                     seed=None,
                     n_workers=psutil.cpu_count(logical=False),
                     max_path_length=None,
                     worker_class=DefaultWorker,
                     sampler_args=None,
                     worker_args=None):
        """Instantiate a sampler of the requested class.

        Subclasses of `BaseSampler` are constructed directly from the
        algorithm and environment; any other class is built through its
        `from_worker_factory` constructor with a fresh `WorkerFactory`.

        Args:
            sampler_cls (type): The type of sampler to construct.
            seed (int): Seed to use in sampler workers. Defaults to
                `get_seed()` when None.
            max_path_length (int): Maximum path length to be sampled by the
                sampler. Read from the algorithm's `max_path_length` when
                None.
            n_workers (int): The number of workers the sampler should use.
            worker_class (type): Type of worker the Sampler should use.
            sampler_args (dict or None): Additional keyword arguments for
                the sampler. Only forwarded when `sampler_cls` is a
                `BaseSampler` subclass.
            worker_args (dict or None): Additional arguments that should be
                passed to the worker (via the `WorkerFactory`).

        Returns:
            sampler_cls: An instance of the sampler class.

        """
        algo = self._algo
        if max_path_length is None:
            max_path_length = algo.max_path_length
        seed = get_seed() if seed is None else seed
        sampler_args = {} if sampler_args is None else sampler_args
        worker_args = {} if worker_args is None else worker_args

        if issubclass(sampler_cls, BaseSampler):
            return sampler_cls(algo, self._env, **sampler_args)

        factory_kwargs = dict(seed=seed,
                              max_path_length=max_path_length,
                              n_workers=n_workers,
                              worker_class=worker_class,
                              worker_args=worker_args)
        return sampler_cls.from_worker_factory(
            WorkerFactory(**factory_kwargs),
            agents=algo.policy,
            envs=self._env)
class SingleVecWorkSampler(LocalSampler):
    """Sampler with a single vectorized worker that holds all the envs.

    The sampler needs to be created either from a worker factory or from
    parameters which can construct a worker factory. See the __init__ method
    of WorkerFactory for the detail of these parameters.

    Args:
        agents (Policy or List[Policy]): Agent(s) to use to sample episodes.
            Passed through `WorkerFactory.prepare_worker_messages`; only the
            first resulting entry is given to the worker.
        envs (List[Environment]): Environments from which episodes are
            sampled. Each one is deep-copied and the full list of copies is
            given to the single worker.
        worker_factory (WorkerFactory): Pickleable factory for creating
            workers. Should be transmitted to other processes / nodes where
            work needs to be done, then workers should be constructed
            there. Either this param or params after this are required to
            construct a sampler.
        max_episode_length(int): Params used to construct a worker factory.
            The maximum length episodes which will be sampled.
        is_tf_worker (bool): Whether it is workers for TFTrainer.
        seed(int): The seed to use to initialize random number generators.
        worker_class(type): Class of the workers. Instances should implement
            the Worker interface.
        worker_args (dict or None): Additional arguments that should be passed
            to the worker.

    Raises:
        TypeError: If both `worker_factory` and `max_episode_length` are
            None.

    """
    def __init__(
            self,
            agents,
            envs,
            *,  # After this require passing by keyword.
            worker_factory=None,
            max_episode_length=None,
            is_tf_worker=False,
            seed=seed,
            worker_class=FragmentWorker,
            worker_args=None):

        if worker_factory is None and max_episode_length is None:
            # The trailing space matters: adjacent literals are concatenated
            # verbatim.
            raise TypeError("Must construct a sampler from WorkerFactory or "
                            "parameters (at least max_episode_length)")
        if isinstance(worker_factory, WorkerFactory):
            self._factory = worker_factory
        else:
            # A factory built here always has exactly one worker.
            self._factory = WorkerFactory(
                max_episode_length=max_episode_length,
                is_tf_worker=is_tf_worker,
                seed=seed,
                n_workers=1,
                worker_class=worker_class,
                worker_args=worker_args)

        self._agents = self._factory.prepare_worker_messages(agents)
        self._envs = [copy.deepcopy(env) for env in envs]
        self._num_envs = len(self._envs)
        self._workers = [
            self._factory(i) for i in range(self._factory.n_workers)
        ]
        # Only worker 0 is configured; it receives the whole list of envs.
        self._workers[0].update_agent(self._agents[0])
        self._workers[0].update_env(self._envs)
        self.total_env_steps = 0

    def _update_workers(self, agent_update, env_update):
        """Apply updates to the single worker.

        Args:
            agent_update (object): Value which will be passed into the
                `agent_update_fn` before sampling episodes. It is passed
                through `WorkerFactory.prepare_worker_messages` and only the
                first resulting entry is used.
            env_update (object): Value which will be passed into the
                `env_update_fn` before sampling episodes. It is deep-copied
                once per environment, and the list of copies is given to
                the worker.

        """
        agent_updates = self._factory.prepare_worker_messages(agent_update)
        # copy updates for all envs
        env_updates = [
            copy.deepcopy(env_update) for _ in range(self._num_envs)
        ]
        # update the worker's agent, and all of the worker's envs
        self._workers[0].update_agent(agent_updates[0])
        self._workers[0].update_env(env_updates)
Exemple #9
0
    def make_sampler(self,
                     sampler_cls,
                     *,
                     seed=None,
                     n_workers=psutil.cpu_count(logical=False),
                     max_episode_length=None,
                     worker_class=None,
                     sampler_args=None,
                     worker_args=None):
        """Instantiate a sampler through its `from_worker_factory` constructor.

        The agent is the algorithm's `exploration_policy` when it is set,
        otherwise its `policy`.

        Args:
            sampler_cls (type): The type of sampler to construct.
            seed (int): Seed to use in sampler workers. Defaults to
                `get_seed()` when None.
            max_episode_length (int): Maximum episode length to be sampled by
                the sampler. Falls back to the algorithm's
                `max_episode_length` when None.
            n_workers (int): The number of workers the sampler should use.
            worker_class (type): Type of worker the Sampler should use.
                Falls back to the algorithm's `worker_cls`, then to
                `DefaultWorker`, when None.
            sampler_args (dict or None): Additional arguments for the
                sampler. NOTE(review): this value is normalized but never
                forwarded by this method.
            worker_args (dict or None): Additional arguments that should be
                passed to the worker (via the `WorkerFactory`).

        Raises:
            ValueError: If `max_episode_length` isn't passed and the algorithm
                doesn't contain a `max_episode_length` field, or if the
                algorithm doesn't have a policy field.

        Returns:
            sampler_cls: An instance of the sampler class.

        """
        algo = self._algo
        policy = getattr(algo, 'exploration_policy', None)
        if policy is None:
            policy = getattr(algo, 'policy', None)
            if policy is None:
                raise ValueError(
                    'If the trainer is used to construct a sampler, '
                    'the algorithm must have a `policy` or '
                    '`exploration_policy` field.')
        if max_episode_length is None:
            max_episode_length = getattr(algo, 'max_episode_length', None)
            if max_episode_length is None:
                raise ValueError(
                    'If `sampler_cls` is specified in trainer.setup, '
                    'the algorithm must specify `max_episode_length`')
        if worker_class is None:
            worker_class = getattr(algo, 'worker_cls', DefaultWorker)
        seed = get_seed() if seed is None else seed
        sampler_args = {} if sampler_args is None else sampler_args
        worker_args = {} if worker_args is None else worker_args

        factory_kwargs = dict(seed=seed,
                              max_episode_length=max_episode_length,
                              n_workers=n_workers,
                              worker_class=worker_class,
                              worker_args=worker_args)
        return sampler_cls.from_worker_factory(
            WorkerFactory(**factory_kwargs),
            agents=policy,
            envs=self._env)
Exemple #10
0
class MultiprocessingSampler(Sampler):
    """Sampler that uses multiprocessing to distribute workers.

    The sampler needs to be created either from a worker factory or from
    parameters which can construct a worker factory. See the __init__ method
    of WorkerFactory for the detail of these parameters.

    Note that the `seed` and `n_workers` defaults are evaluated once, when
    the class is defined, not on every construction.

    Args:
        agents (Policy or List[Policy]): Agent(s) to use to sample episodes.
            If a list is passed in, it must have length exactly
            `worker_factory.n_workers`, and will be spread across the
            workers.
        envs (Environment or List[Environment]): Environment from which
            episodes are sampled. If a list is passed in, it must have length
            exactly `worker_factory.n_workers`, and will be spread across the
            workers.
        worker_factory (WorkerFactory): Pickleable factory for creating
            workers. Should be transmitted to other processes / nodes where
            work needs to be done, then workers should be constructed
            there. Either this param or params after this are required to
            construct a sampler.
        max_episode_length(int): Params used to construct a worker factory.
            The maximum length episodes which will be sampled.
        is_tf_worker (bool): Whether it is workers for TFTrainer.
        seed(int): The seed to use to initialize random number generators.
        n_workers(int): The number of workers to use.
        worker_class(type): Class of the workers. Instances should implement
            the Worker interface.
        worker_args (dict or None): Additional arguments that should be passed
            to the worker.

    Raises:
        TypeError: If both `worker_factory` and `max_episode_length` are
            None.

    """
    def __init__(
            self,
            agents,
            envs,
            *,  # After this require passing by keyword.
            worker_factory=None,
            max_episode_length=None,
            is_tf_worker=False,
            seed=get_seed(),
            n_workers=psutil.cpu_count(logical=False),
            worker_class=DefaultWorker,
            worker_args=None):
        # pylint: disable=super-init-not-called
        if worker_factory is None and max_episode_length is None:
            # The trailing space matters: adjacent literals are concatenated
            # verbatim.
            raise TypeError('Must construct a sampler from WorkerFactory or '
                            'parameters (at least max_episode_length)')
        if isinstance(worker_factory, WorkerFactory):
            self._factory = worker_factory
        else:
            self._factory = WorkerFactory(
                max_episode_length=max_episode_length,
                is_tf_worker=is_tf_worker,
                seed=seed,
                n_workers=n_workers,
                worker_class=worker_class,
                worker_args=worker_args)

        # Agents are serialized with cloudpickle before being handed to the
        # worker processes.
        self._agents = self._factory.prepare_worker_messages(
            agents, cloudpickle.dumps)
        self._envs = self._factory.prepare_worker_messages(envs)
        # One shared result queue, plus a single-slot command queue per
        # worker.
        self._to_sampler = mp.Queue(2 * self._factory.n_workers)
        self._to_worker = [mp.Queue(1) for _ in range(self._factory.n_workers)]
        # If we crash from an exception, with full queues, we would rather not
        # hang forever, so we would like the process to close without flushing
        # the queues.
        # That's what cancel_join_thread does.
        for q in self._to_worker:
            q.cancel_join_thread()
        self._workers = [
            mp.Process(target=run_worker,
                       kwargs=dict(
                           factory=self._factory,
                           to_sampler=self._to_sampler,
                           to_worker=self._to_worker[worker_number],
                           worker_number=worker_number,
                           agent=self._agents[worker_number],
                           env=self._envs[worker_number],
                       ),
                       daemon=False)
            for worker_number in range(self._factory.n_workers)
        ]
        self._agent_version = 0
        for w in self._workers:
            w.start()
        self.total_env_steps = 0

    @classmethod
    def from_worker_factory(cls, worker_factory, agents, envs):
        """Construct this sampler.

        Args:
            worker_factory (WorkerFactory): Pickleable factory for creating
                workers. Should be transmitted to other processes / nodes where
                work needs to be done, then workers should be constructed
                there.
            agents (Policy or List[Policy]): Agent(s) to use to sample
                episodes. If a list is passed in, it must have length exactly
                `worker_factory.n_workers`, and will be spread across the
                workers.
            envs(Environment or List[Environment]): Environment from which
                episodes are sampled. If a list is passed in, it must have
                length exactly `worker_factory.n_workers`, and will be spread
                across the workers.

        Returns:
            Sampler: An instance of `cls`.

        """
        return cls(agents, envs, worker_factory=worker_factory)

    def _push_updates(self, updated_workers, agent_updates, env_updates):
        """Apply updates to the workers and (re)start them.

        Sending is best-effort: if a worker's command queue is full, the
        message is dropped for this round.

        Args:
            updated_workers (set[int]): Set of workers that don't need to be
                updated. Successfully updated workers will be added to this
                set.
            agent_updates (object): Value which will be passed into the
                `agent_update_fn` before sampling episodes. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.
            env_updates (object): Value which will be passed into the
                `env_update_fn` before sampling episodes. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.

        """
        for worker_number, q in enumerate(self._to_worker):
            if worker_number in updated_workers:
                try:
                    q.put_nowait(('continue', ()))
                except queue.Full:
                    pass
            else:
                try:
                    q.put_nowait(('start', (agent_updates[worker_number],
                                            env_updates[worker_number],
                                            self._agent_version)))
                    # Only mark the worker as updated once the 'start'
                    # message was actually enqueued.
                    updated_workers.add(worker_number)
                except queue.Full:
                    pass

    def obtain_samples(self, itr, num_samples, agent_update, env_update=None):
        """Collect at least a given number transitions (timesteps).

        Args:
            itr(int): The current iteration number. Using this argument is
                deprecated.
            num_samples (int): Minimum number of transitions / timesteps to
                sample.
            agent_update (object): Value which will be passed into the
                `agent_update_fn` before sampling episodes. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.
            env_update (object): Value which will be passed into the
                `env_update_fn` before sampling episodes. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.

        Returns:
            EpisodeBatch: The batch of collected episodes.

        Raises:
            AssertionError: On internal errors.

        """
        del itr
        batches = []
        completed_samples = 0
        self._agent_version += 1
        updated_workers = set()
        agent_ups = self._factory.prepare_worker_messages(
            agent_update, cloudpickle.dumps)
        env_ups = self._factory.prepare_worker_messages(env_update)

        with click.progressbar(length=num_samples, label='Sampling') as pbar:
            while completed_samples < num_samples:
                self._push_updates(updated_workers, agent_ups, env_ups)
                for _ in range(self._factory.n_workers):
                    # Only the queue read can raise queue.Empty, so keep the
                    # try body limited to it.
                    try:
                        tag, contents = self._to_sampler.get_nowait()
                    except queue.Empty:
                        continue
                    if tag != 'episode':
                        raise AssertionError(
                            'Unknown tag {} with contents {}'.format(
                                tag, contents))
                    batch, version, worker_n = contents
                    del worker_n
                    # Receiving episodes from previous iterations is
                    # normal; they are dropped.  Potentially, we could
                    # gather them here, if an off-policy method wants them.
                    if version == self._agent_version:
                        batches.append(batch)
                        num_returned_samples = batch.lengths.sum()
                        completed_samples += num_returned_samples
                        pbar.update(num_returned_samples)
            for q in self._to_worker:
                try:
                    q.put_nowait(('stop', ()))
                except queue.Full:
                    pass

        samples = EpisodeBatch.concatenate(*batches)
        self.total_env_steps += sum(samples.lengths)
        return samples

    def obtain_exact_episodes(self,
                              n_eps_per_worker,
                              agent_update,
                              env_update=None):
        """Sample an exact number of episodes per worker.

        Args:
            n_eps_per_worker (int): Exact number of episodes to gather for
                each worker.
            agent_update (object): Value which will be passed into the
                `agent_update_fn` before sampling episodes. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.
            env_update (object): Value which will be passed into the
                `env_update_fn` before sampling episodes. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.

        Returns:
            EpisodeBatch: Batch of gathered episodes. Always in worker
                order. In other words, first all episodes from worker 0,
                then all episodes from worker 1, etc.

        Raises:
            AssertionError: On internal errors.

        """
        self._agent_version += 1
        updated_workers = set()
        agent_ups = self._factory.prepare_worker_messages(
            agent_update, cloudpickle.dumps)
        env_ups = self._factory.prepare_worker_messages(env_update)
        episodes = defaultdict(list)

        with click.progressbar(length=self._factory.n_workers,
                               label='Sampling') as pbar:
            while any(
                    len(episodes[i]) < n_eps_per_worker
                    for i in range(self._factory.n_workers)):
                self._push_updates(updated_workers, agent_ups, env_ups)
                # Blocking read: waits until some worker reports a result.
                tag, contents = self._to_sampler.get()

                if tag == 'episode':
                    batch, version, worker_n = contents

                    if version == self._agent_version:
                        # Episodes beyond the per-worker quota are dropped.
                        if len(episodes[worker_n]) < n_eps_per_worker:
                            episodes[worker_n].append(batch)

                        if len(episodes[worker_n]) == n_eps_per_worker:
                            pbar.update(1)
                            try:
                                self._to_worker[worker_n].put_nowait(
                                    ('stop', ()))
                            except queue.Full:
                                pass
                else:
                    raise AssertionError(
                        'Unknown tag {} with contents {}'.format(
                            tag, contents))

            for q in self._to_worker:
                try:
                    q.put_nowait(('stop', ()))
                except queue.Full:
                    pass

        ordered_episodes = list(
            itertools.chain(
                *[episodes[i] for i in range(self._factory.n_workers)]))
        samples = EpisodeBatch.concatenate(*ordered_episodes)
        self.total_env_steps += sum(samples.lengths)
        return samples

    def shutdown_worker(self):
        """Shutdown the workers."""
        for (q, w) in zip(self._to_worker, self._workers):
            # Loop until either the exit message is accepted or the process has
            # closed.  These might cause us to block, but ensures that the
            # workers are closed.
            while True:
                try:
                    # Set a timeout in case the child process crashed.
                    q.put(('exit', ()), timeout=1)
                    break
                except queue.Full:
                    # If the child process has crashed, we're done here.
                    # Otherwise it should eventually accept our message.
                    if not w.is_alive():
                        break
            # If this hangs forever, most likely a queue needs
            # cancel_join_thread called on it, or a subprocess has tripped the
            # "closing dowel with TensorboardOutput blocks forever bug."
            w.join()
        for q in self._to_worker:
            q.close()
        self._to_sampler.close()

    def __getstate__(self):
        """Get the pickle state.

        Agents are stored in their deserialized form, so that
        `__setstate__` can pass them back through `__init__`.

        Returns:
            dict: The pickled state.

        """
        return dict(
            factory=self._factory,
            agents=[cloudpickle.loads(agent) for agent in self._agents],
            envs=self._envs)

    def __setstate__(self, state):
        """Unpickle the state.

        Re-runs `__init__`, which creates new queues and starts new worker
        processes.

        Args:
            state (dict): Unpickled state.

        """
        self.__init__(state['agents'],
                      state['envs'],
                      worker_factory=state['factory'])
Exemple #11
0
class LocalSampler(Sampler):
    """Sampler that runs workers in the main process.

    This is probably the simplest possible sampler. It's called the "Local"
    sampler because it runs everything in the same process and thread as where
    it was called from.

    The sampler needs to be created either from a worker factory or from
    parameters which can construct a worker factory. See the __init__ method
    of WorkerFactory for the detail of these parameters.

    Args:
        agents (Policy or List[Policy]): Agent(s) to use to sample episodes.
            If a list is passed in, it must have length exactly
            `worker_factory.n_workers`, and will be spread across the
            workers.
        envs (Environment or List[Environment]): Environment from which
            episodes are sampled. If a list is passed in, it must have length
            exactly `worker_factory.n_workers`, and will be spread across the
            workers.
        worker_factory (WorkerFactory): Pickleable factory for creating
            workers. Should be transmitted to other processes / nodes where
            work needs to be done, then workers should be constructed
            there. Either this param or params after this are required to
            construct a sampler.
        max_episode_length(int): Params used to construct a worker factory.
            The maximum length episodes which will be sampled.
        is_tf_worker (bool): Whether it is workers for TFTrainer.
        seed(int): The seed to use to initialize random number generators.
        n_workers(int): The number of workers to use.
        worker_class(type): Class of the workers. Instances should implement
            the Worker interface.
        worker_args (dict or None): Additional arguments that should be passed
            to the worker.

    Raises:
        TypeError: If neither `worker_factory` nor `max_episode_length` is
            given.

    """

    # NOTE: the `seed` and `n_workers` defaults below are evaluated once, when
    # the class body is executed, not on every call.
    def __init__(
            self,
            agents,
            envs,
            *,  # After this require passing by keyword.
            worker_factory=None,
            max_episode_length=None,
            is_tf_worker=False,
            seed=get_seed(),
            n_workers=psutil.cpu_count(logical=False),
            worker_class=DefaultWorker,
            worker_args=None):
        # pylint: disable=super-init-not-called
        if worker_factory is None and max_episode_length is None:
            # The trailing space matters: the two literals are concatenated.
            raise TypeError('Must construct a sampler from WorkerFactory or '
                            'parameters (at least max_episode_length)')
        if isinstance(worker_factory, WorkerFactory):
            self._factory = worker_factory
        else:
            self._factory = WorkerFactory(
                max_episode_length=max_episode_length,
                is_tf_worker=is_tf_worker,
                seed=seed,
                n_workers=n_workers,
                worker_class=worker_class,
                worker_args=worker_args)

        self._agents = self._factory.prepare_worker_messages(agents)
        # Environments are deep-copied while being spread across workers.
        self._envs = self._factory.prepare_worker_messages(
            envs, preprocess=copy.deepcopy)
        self._build_workers()
        self.total_env_steps = 0

    def _build_workers(self):
        """Create one worker per factory slot and hand it its agent and env."""
        self._workers = [
            self._factory(i) for i in range(self._factory.n_workers)
        ]
        for worker, agent, env in zip(self._workers, self._agents, self._envs):
            worker.update_agent(agent)
            worker.update_env(env)

    @classmethod
    def from_worker_factory(cls, worker_factory, agents, envs):
        """Construct this sampler.

        Args:
            worker_factory (WorkerFactory): Pickleable factory for creating
                workers. Should be transmitted to other processes / nodes where
                work needs to be done, then workers should be constructed
                there.
            agents (Agent or List[Agent]): Agent(s) to use to sample episodes.
                If a list is passed in, it must have length exactly
                `worker_factory.n_workers`, and will be spread across the
                workers.
            envs (Environment or List[Environment]): Environment from which
                episodes are sampled. If a list is passed in, it must have
                length exactly `worker_factory.n_workers`, and will be spread
                across the workers.

        Returns:
            Sampler: An instance of `cls`.

        """
        return cls(agents, envs, worker_factory=worker_factory)

    def _update_workers(self, agent_update, env_update):
        """Apply updates to the workers.

        Args:
            agent_update (object): Value which will be passed into the
                `agent_update_fn` before sampling episodes. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.
            env_update (object): Value which will be passed into the
                `env_update_fn` before sampling episodes. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.

        """
        agent_updates = self._factory.prepare_worker_messages(agent_update)
        env_updates = self._factory.prepare_worker_messages(
            env_update, preprocess=copy.deepcopy)
        for worker, agent_up, env_up in zip(self._workers, agent_updates,
                                            env_updates):
            worker.update_agent(agent_up)
            worker.update_env(env_up)

    def obtain_samples(self, itr, num_samples, agent_update, env_update=None):
        """Collect at least a given number transitions (timesteps).

        Workers are visited round-robin, one rollout each, until the
        requested number of transitions has been reached.

        Args:
            itr(int): The current iteration number. Using this argument is
                deprecated.
            num_samples (int): Minimum number of transitions / timesteps to
                sample.
            agent_update (object): Value which will be passed into the
                `agent_update_fn` before sampling episodes. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.
            env_update (object): Value which will be passed into the
                `env_update_fn` before sampling episodes. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.

        Returns:
            EpisodeBatch: The batch of collected episodes.

        """
        self._update_workers(agent_update, env_update)
        batches = []
        completed_samples = 0
        while True:
            for worker in self._workers:
                batch = worker.rollout()
                completed_samples += len(batch.actions)
                batches.append(batch)
                # Stop as soon as the target is met, even in the middle of a
                # pass over the workers.
                if completed_samples >= num_samples:
                    samples = EpisodeBatch.concatenate(*batches)
                    self.total_env_steps += sum(samples.lengths)
                    return samples

    def obtain_exact_episodes(self,
                              n_eps_per_worker,
                              agent_update,
                              env_update=None):
        """Sample an exact number of episodes per worker.

        Args:
            n_eps_per_worker (int): Exact number of episodes to gather for
                each worker.
            agent_update (object): Value which will be passed into the
                `agent_update_fn` before sampling episodes. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.
            env_update (object): Value which will be passed into the
                `env_update_fn` before sampling episodes. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.

        Returns:
            EpisodeBatch: Batch of gathered episodes. Always in worker
                order. In other words, first all episodes from worker 0,
                then all episodes from worker 1, etc.

        """
        self._update_workers(agent_update, env_update)
        # Outer loop over workers keeps the result in worker order.
        batches = [
            worker.rollout() for worker in self._workers
            for _ in range(n_eps_per_worker)
        ]
        samples = EpisodeBatch.concatenate(*batches)
        self.total_env_steps += sum(samples.lengths)
        return samples

    def shutdown_worker(self):
        """Shutdown the workers."""
        for worker in self._workers:
            worker.shutdown()

    def __getstate__(self):
        """Get the pickle state.

        Returns:
            dict: The pickled state.

        """
        state = self.__dict__.copy()
        # Workers aren't picklable (but WorkerFactory is).
        state['_workers'] = None
        return state

    def __setstate__(self, state):
        """Unpickle the state.

        Args:
            state (dict): Unpickled state.

        """
        self.__dict__.update(state)
        # Workers were dropped from the pickled state; recreate them from
        # the factory and re-attach the stored agents and environments.
        self._build_workers()
Exemple #12
0
class RaySampler(Sampler):
    """Samples episodes in a data-parallel fashion using a Ray cluster.

    The sampler needs to be created either from a worker factory or from
    parameters which can construct a worker factory. See the __init__ method
    of WorkerFactory for the detail of these parameters.

    Args:
        agents (list[Policy]): Agents to distribute across workers.
        envs (list[Environment]): Environments to distribute across workers.
        worker_factory (WorkerFactory): Used for worker behavior. Either this
            param or params after this are required to construct a sampler.
        max_episode_length(int): Params used to construct a worker factory.
            The maximum length episodes which will be sampled.
        is_tf_worker (bool): Whether it is workers for TFTrainer.
        seed(int): The seed to use to initialize random number generators.
        n_workers(int): The number of workers to use.
        worker_class(type): Class of the workers. Instances should implement
            the Worker interface.
        worker_args (dict or None): Additional arguments that should be passed
            to the worker.

    Raises:
        TypeError: If neither `worker_factory` nor `max_episode_length` is
            given.

    """

    # NOTE: the `seed` and `n_workers` defaults below are evaluated once, when
    # the class body is executed, not on every call.
    def __init__(
            self,
            agents,
            envs,
            *,  # After this require passing by keyword.
            worker_factory=None,
            max_episode_length=None,
            is_tf_worker=False,
            seed=get_seed(),
            n_workers=psutil.cpu_count(logical=False),
            worker_class=DefaultWorker,
            worker_args=None):
        # pylint: disable=super-init-not-called
        if not ray.is_initialized():
            ray.init(log_to_driver=False, ignore_reinit_error=True)
        if worker_factory is None and max_episode_length is None:
            # The trailing space matters: the two literals are concatenated.
            raise TypeError('Must construct a sampler from WorkerFactory or '
                            'parameters (at least max_episode_length)')
        if isinstance(worker_factory, WorkerFactory):
            self._worker_factory = worker_factory
        else:
            self._worker_factory = WorkerFactory(
                max_episode_length=max_episode_length,
                is_tf_worker=is_tf_worker,
                seed=seed,
                n_workers=n_workers,
                worker_class=worker_class,
                worker_args=worker_args)
        self._sampler_worker = ray.remote(SamplerWorker)
        self._agents = agents
        self._envs = self._worker_factory.prepare_worker_messages(envs)
        # Maps worker id -> remote worker handle. A plain dict is enough:
        # ids are only looked up after start_worker() has filled them in.
        self._all_workers = {}
        self._workers_started = False
        self.start_worker()
        self.total_env_steps = 0

    @classmethod
    def from_worker_factory(cls, worker_factory, agents, envs):
        """Construct this sampler.

        Args:
            worker_factory (WorkerFactory): Pickleable factory for creating
                workers. Should be transmitted to other processes / nodes where
                work needs to be done, then workers should be constructed
                there.
            agents (Policy or List[Policy]): Agent(s) to use to sample
                episodes. If a list is passed in, it must have length exactly
                `worker_factory.n_workers`, and will be spread across the
                workers.
            envs (Environment or List[Environment]): Environment from which
                episodes are sampled. If a list is passed in, it must have
                length exactly `worker_factory.n_workers`, and will be spread
                across the workers.

        Returns:
            Sampler: An instance of `cls`.

        """
        return cls(agents, envs, worker_factory=worker_factory)

    def start_worker(self):
        """Initialize the ray workers.

        Does nothing if the workers have already been started.
        """
        if self._workers_started:
            return
        self._workers_started = True
        # We need to pickle the agent so that we can e.g. set up the TF.Session
        # in the worker *before* unpickling it.
        agent_pkls = self._worker_factory.prepare_worker_messages(
            self._agents, cloudpickle.dumps)
        for worker_id in range(self._worker_factory.n_workers):
            self._all_workers[worker_id] = self._sampler_worker.remote(
                worker_id, self._envs[worker_id], agent_pkls[worker_id],
                self._worker_factory)

    def _update_workers(self, agent_update, env_update):
        """Update all of the workers.

        Args:
            agent_update (object): Value which will be passed into the
                `agent_update_fn` before sampling episodes. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.
            env_update (object): Value which will be passed into the
                `env_update_fn` before sampling episodes. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.

        Returns:
            list[ray._raylet.ObjectID]: Remote values of worker ids.

        """
        # Each update is passed through ray.put before being handed to the
        # workers.
        param_ids = self._worker_factory.prepare_worker_messages(
            agent_update, ray.put)
        env_ids = self._worker_factory.prepare_worker_messages(
            env_update, ray.put)
        return [
            self._all_workers[worker_id].update.remote(param_ids[worker_id],
                                                       env_ids[worker_id])
            for worker_id in range(self._worker_factory.n_workers)
        ]

    def obtain_samples(self, itr, num_samples, agent_update, env_update=None):
        """Sample the policy for new episodes.

        Args:
            itr (int): Iteration number.
            num_samples (int): Minimum number of steps the sampler should
                collect.
            agent_update (object): Value which will be passed into the
                `agent_update_fn` before sampling episodes. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.
            env_update (object): Value which will be passed into the
                `env_update_fn` before sampling episodes. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.

        Returns:
            EpisodeBatch: Batch of gathered episodes.

        """
        active_workers = []
        completed_samples = 0
        batches = []

        # update the policy params of each worker before sampling
        # for the current iteration
        idle_worker_ids = []
        updating_workers = self._update_workers(agent_update, env_update)

        with click.progressbar(length=num_samples, label='Sampling') as pbar:
            while completed_samples < num_samples:
                # if there are workers still being updated, check
                # which ones are still updating and take the workers that
                # are done updating, and start collecting episodes on those
                # workers.
                if updating_workers:
                    updated, updating_workers = ray.wait(updating_workers,
                                                         num_returns=1,
                                                         timeout=0.1)
                    upd = [ray.get(up) for up in updated]
                    idle_worker_ids.extend(upd)

                # if there are idle workers, use them to collect episodes and
                # mark the newly busy workers as active
                while idle_worker_ids:
                    idle_worker_id = idle_worker_ids.pop()
                    worker = self._all_workers[idle_worker_id]
                    active_workers.append(worker.rollout.remote())

                # check which workers are done/not done collecting a sample
                # if any are done, send them to process the collected
                # episode if they are not, keep checking if they are done
                ready, not_ready = ray.wait(active_workers,
                                            num_returns=1,
                                            timeout=0.001)
                active_workers = not_ready
                for result in ready:
                    ready_worker_id, episode_batch = ray.get(result)
                    idle_worker_ids.append(ready_worker_id)
                    num_returned_samples = episode_batch.lengths.sum()
                    completed_samples += num_returned_samples
                    batches.append(episode_batch)
                    pbar.update(num_returned_samples)

        samples = EpisodeBatch.concatenate(*batches)
        self.total_env_steps += sum(samples.lengths)
        return samples

    def obtain_exact_episodes(self,
                              n_eps_per_worker,
                              agent_update,
                              env_update=None):
        """Sample an exact number of episodes per worker.

        Args:
            n_eps_per_worker (int): Exact number of episodes to gather for
                each worker.
            agent_update (object): Value which will be passed into the
                `agent_update_fn` before sampling episodes. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.
            env_update (object): Value which will be passed into the
                `env_update_fn` before sampling episodes. If a list is passed
                in, it must have length exactly `factory.n_workers`, and will
                be spread across the workers.

        Returns:
            EpisodeBatch: Batch of gathered episodes. Always in worker
                order. In other words, first all episodes from worker 0, then
                all episodes from worker 1, etc.

        """
        n_workers = self._worker_factory.n_workers
        active_workers = []
        episodes = defaultdict(list)

        # update the policy params of each worker before sampling
        # for the current iteration
        idle_worker_ids = []
        updating_workers = self._update_workers(agent_update, env_update)

        # The bar advances once per collected episode (see pbar.update(1)
        # below), so its length is the total number of episodes requested.
        with click.progressbar(length=n_workers * n_eps_per_worker,
                               label='Sampling') as pbar:
            while any(
                    len(episodes[i]) < n_eps_per_worker
                    for i in range(n_workers)):
                # if there are workers still being updated, check
                # which ones are still updating and take the workers that
                # are done updating, and start collecting episodes on
                # those workers.
                if updating_workers:
                    updated, updating_workers = ray.wait(updating_workers,
                                                         num_returns=1,
                                                         timeout=0.1)
                    upd = [ray.get(up) for up in updated]
                    idle_worker_ids.extend(upd)

                # if there are idle workers, use them to collect episodes
                # mark the newly busy workers as active
                while idle_worker_ids:
                    idle_worker_id = idle_worker_ids.pop()
                    worker = self._all_workers[idle_worker_id]
                    active_workers.append(worker.rollout.remote())

                # check which workers are done/not done collecting a sample
                # if any are done, send them to process the collected episode
                # if they are not, keep checking if they are done
                ready, not_ready = ray.wait(active_workers,
                                            num_returns=1,
                                            timeout=0.001)
                active_workers = not_ready
                for result in ready:
                    ready_worker_id, episode_batch = ray.get(result)
                    episodes[ready_worker_id].append(episode_batch)

                    # Only re-queue a worker that still owes episodes, so no
                    # worker contributes more than n_eps_per_worker.
                    if len(episodes[ready_worker_id]) < n_eps_per_worker:
                        idle_worker_ids.append(ready_worker_id)

                    pbar.update(1)

        # Flatten per-worker lists in worker-id order.
        ordered_episodes = list(
            itertools.chain.from_iterable(
                episodes[i] for i in range(n_workers)))

        samples = EpisodeBatch.concatenate(*ordered_episodes)
        self.total_env_steps += sum(samples.lengths)
        return samples

    def shutdown_worker(self):
        """Shuts down the workers, then Ray itself."""
        for worker in self._all_workers.values():
            worker.shutdown.remote()
        ray.shutdown()

    def __getstate__(self):
        """Get the pickle state.

        Returns:
            dict: The pickled state.

        """
        return dict(factory=self._worker_factory,
                    agents=self._agents,
                    envs=self._envs)

    def __setstate__(self, state):
        """Unpickle the state.

        The sampler is rebuilt by calling `__init__` again with the stored
        agents, environments and worker factory.

        Args:
            state (dict): Unpickled state.

        """
        self.__init__(state['agents'],
                      state['envs'],
                      worker_factory=state['factory'])