コード例 #1
0
    def _setup_worker(self, env_indices, tasks):
        """Rebuild the vectorized environment for a set of task indices.

        An existing vectorized environment is closed first. Then, for each
        index in `env_indices`, `self._envs_per_worker` deep copies of
        `self.env` are made and each is given the task `tasks[index]`.
        When a global seed is configured, copy number `i` is seeded with
        `seed + i`.

        Args:
            env_indices (List[Int]): Indices of environments to be assigned
                to workers for sampling.
            tasks (List[dict]): List of tasks to assign.

        """
        # Release the previous executor before building a replacement.
        if self._vec_env is not None:
            self._vec_env.close()

        env_copies = []
        for task_index in env_indices:
            for _ in range(self._envs_per_worker):
                env_copy = copy.deepcopy(self.env)
                env_copy.set_task(tasks[task_index])
                env_copies.append(env_copy)

        # Give every copy its own seed derived from the global one.
        base_seed = deterministic.get_seed()
        if base_seed is not None:
            for offset, env_copy in enumerate(env_copies):
                env_copy.seed(base_seed + offset)

        self._vec_env = VecEnvExecutor(
            envs=env_copies, max_path_length=self.algo.max_path_length)
コード例 #2
0
ファイル: local_runner.py プロジェクト: seba-1511/metarl
    def __init__(self, snapshot_config, max_cpus=1):
        """Create the runner and initialise its bookkeeping state.

        Args:
            snapshot_config: Object exposing `snapshot_dir`, `snapshot_mode`
                and `snapshot_gap`, which are forwarded to `Snapshotter`.
            max_cpus (int): Value passed to `parallel_sampler.initialize`.

        """
        self._snapshotter = Snapshotter(snapshot_config.snapshot_dir,
                                        snapshot_config.snapshot_mode,
                                        snapshot_config.snapshot_gap)

        parallel_sampler.initialize(max_cpus)

        # Propagate the global seed to the parallel sampler when one is set.
        global_seed = get_seed()
        if global_seed is not None:
            parallel_sampler.set_seed(global_seed)

        # Nothing has been set up or plotted yet.
        self._has_setup = False
        self._plot = False

        # Components are attached later; start with empty slots.
        self._algo = None
        self._env = None
        self._policy = None
        self._sampler = None
        self._plotter = None
        self._meta_eval = None

        # Recorded arguments and running experiment statistics.
        self._setup_args = None
        self._train_args = None
        self._stats = ExperimentStats(total_itr=0,
                                      total_env_steps=0,
                                      total_epoch=0,
                                      last_path=None)

        # Timing and per-step state.
        self._start_time = None
        self._itr_start_time = None
        self.step_itr = None
        self.step_path = None
コード例 #3
0
    def make_sampler(self,
                     sampler_cls,
                     *,
                     seed=None,
                     n_workers=psutil.cpu_count(logical=False),
                     max_path_length=None,
                     worker_class=DefaultWorker,
                     sampler_args=None,
                     worker_args=None):
        """Build a sampler instance from a sampler class.

        Subclasses of `BaseSampler` are constructed directly from the
        runner's algorithm and environment; any other class is built via
        its `from_worker_factory` constructor.

        Args:
            sampler_cls (type): The type of sampler to construct.
            seed (int): Seed to use in sampler workers. Defaults to the
                value returned by `get_seed()`.
            n_workers (int): The number of workers the sampler should use.
            max_path_length (int): Maximum path length to be sampled by the
                sampler. Defaults to the algorithm's `max_path_length`.
            worker_class (type): Type of worker the Sampler should use.
            sampler_args (dict or None): Additional arguments that should be
                passed to the sampler.
            worker_args (dict or None): Additional arguments that should be
                passed to the workers (via the `WorkerFactory`).

        Raises:
            ValueError: If `max_path_length` isn't passed and the algorithm
                doesn't contain a `max_path_length` field, or if the algorithm
                doesn't have a policy field.

        Returns:
            sampler_cls: An instance of the sampler class.

        """
        algo = self._algo
        if not hasattr(algo, 'policy'):
            raise ValueError('If the runner is used to construct a sampler, '
                             'the algorithm must have a `policy` field.')

        if max_path_length is None:
            if not hasattr(algo, 'max_path_length'):
                raise ValueError('If `sampler_cls` is specified in '
                                 'runner.setup, the algorithm must have '
                                 'a `max_path_length` field.')
            max_path_length = algo.max_path_length

        # Fill in the remaining defaults.
        if seed is None:
            seed = get_seed()
        sampler_args = {} if sampler_args is None else sampler_args
        worker_args = {} if worker_args is None else worker_args

        if issubclass(sampler_cls, BaseSampler):
            return sampler_cls(self._algo, self._env, **sampler_args)

        build_sampler = sampler_cls.from_worker_factory
        worker_factory = WorkerFactory(seed=seed,
                                       max_path_length=max_path_length,
                                       n_workers=n_workers,
                                       worker_class=worker_class,
                                       worker_args=worker_args)
        return build_sampler(worker_factory,
                             agents=self._algo.policy,
                             envs=self._env)
コード例 #4
0
    def setup(self,
              algo,
              env,
              sampler_cls=None,
              sampler_args=None,
              n_workers=psutil.cpu_count(logical=False),
              worker_class=DefaultWorker,
              worker_args=None):
        """Attach an algorithm and environment to the runner.

        The algorithm, environment and worker settings are stored on the
        runner. A sampler is built with `make_sampler` when a sampler class
        is given or the algorithm exposes a `sampler_cls` attribute;
        otherwise the runner's sampler is set to None.

        Note:
            After setup() is called all variables in session should have been
            initialized. setup() respects existing values in session so
            policy weights can be loaded before setup().

        Args:
            algo (metarl.np.algos.RLAlgorithm): An algorithm instance.
            env (metarl.envs.MetaRLEnv): An environment instance.
            sampler_cls (metarl.sampler.Sampler): A sampler class.
            sampler_args (dict): Arguments to be passed to sampler constructor.
            n_workers (int): The number of workers the sampler should use.
            worker_class (type): Type of worker the sampler should use.
            worker_args (dict or None): Additional arguments that should be
                passed to the worker.

        Raises:
            ValueError: If sampler_cls is passed and the algorithm doesn't
                contain a `max_path_length` field.

        """
        self._algo = algo
        self._env = env
        self._n_workers = n_workers
        self._worker_class = worker_class

        sampler_args = {} if sampler_args is None else sampler_args
        if sampler_cls is None:
            # Fall back to the sampler class advertised by the algorithm.
            sampler_cls = getattr(algo, 'sampler_cls', None)
        worker_args = {} if worker_args is None else worker_args
        self._worker_args = worker_args

        sampler = None
        if sampler_cls is not None:
            sampler = self.make_sampler(sampler_cls,
                                        sampler_args=sampler_args,
                                        n_workers=n_workers,
                                        worker_class=worker_class,
                                        worker_args=worker_args)
        self._sampler = sampler

        self._has_setup = True

        # Remember how the runner was configured.
        self._setup_args = SetupArgs(sampler_cls=sampler_cls,
                                     sampler_args=sampler_args,
                                     seed=get_seed())
コード例 #5
0
    def start_worker(self):
        """Create the vectorized environment used for sampling.

        Makes `self._n_envs` copies of `self.env` through a pickle round
        trip, seeds copy `i` with `seed + i` when a global seed is set, and
        wraps the copies in a `VecEnvExecutor`.
        """
        env_copies = []
        for _ in range(self._n_envs):
            env_copies.append(pickle.loads(pickle.dumps(self.env)))

        # Deterministically set environment seeds based on the global seed.
        base_seed = deterministic.get_seed()
        if base_seed is not None:
            for offset, env_copy in enumerate(env_copies):
                env_copy.seed(base_seed + offset)

        self._vec_env = VecEnvExecutor(
            envs=env_copies, max_path_length=self.algo.max_path_length)
コード例 #6
0
ファイル: local_runner.py プロジェクト: seba-1511/metarl
    def make_sampler(self,
                     sampler_cls,
                     *,
                     seed=None,
                     n_workers=psutil.cpu_count(logical=False),
                     max_path_length=None,
                     worker_class=DefaultWorker,
                     sampler_args=None,
                     worker_args=None,
                     env=None,
                     policy=None):
        """Construct a Sampler from a Sampler class.

        Subclasses of `BaseSampler` are constructed directly from the
        runner's algorithm and the resolved environment; any other class is
        built via its `from_worker_factory` constructor.

        Args:
            sampler_cls (type): The type of sampler to construct.
            seed (int): Seed to use in sampler workers. Defaults to the
                value returned by `get_seed()`.
            n_workers (int): The number of workers the sampler should use.
            max_path_length (int): Maximum path length to be sampled by the
                sampler. Defaults to the algorithm's `max_path_length`.
            worker_class (type): Type of worker the Sampler should use.
            sampler_args (dict or None): Additional arguments that should be
                passed to the sampler.
            worker_args (dict or None): Additional arguments that should be
                passed to the workers (via the `WorkerFactory`).
            env (object or None): Environment(s) the sampler should use.
                Defaults to the environment stored on the runner.
            policy (object or None): Agent(s) the sampler should use.
                Defaults to the algorithm's `policy`.

        Returns:
            sampler_cls: An instance of the sampler class.

        """
        if env is None:
            env = self._env
        if policy is None:
            policy = self._algo.policy
        if max_path_length is None:
            max_path_length = self._algo.max_path_length
        if seed is None:
            seed = get_seed()
        if worker_args is None:
            worker_args = {}
        if sampler_args is None:
            sampler_args = {}
        if issubclass(sampler_cls, BaseSampler):
            # Honour an explicit `env` override here as well; previously
            # this branch always used self._env and silently ignored it.
            return sampler_cls(self._algo, env, **sampler_args)
        return sampler_cls.from_worker_factory(WorkerFactory(
            seed=seed,
            max_path_length=max_path_length,
            n_workers=n_workers,
            worker_class=worker_class,
            worker_args=worker_args),
                                               agents=policy,
                                               envs=env)
コード例 #7
0
ファイル: local_runner.py プロジェクト: seba-1511/metarl
    def setup(self,
              algo,
              env,
              sampler_cls=None,
              sampler_args=None,
              n_workers=psutil.cpu_count(logical=False),
              worker_class=DefaultWorker,
              worker_args=None):
        """Attach an algorithm and environment to the runner.

        Stores the algorithm, its policy, the environment and the worker
        settings on the runner, then builds a sampler with `make_sampler`.

        Note:
            After setup() is called all variables in session should have been
            initialized. setup() respects existing values in session so
            policy weights can be loaded before setup().

        Args:
            algo (metarl.np.algos.RLAlgorithm): An algorithm instance.
            env (metarl.envs.MetaRLEnv): An environment instance.
            sampler_cls (metarl.sampler.Sampler): A sampler class. Defaults
                to `algo.sampler_cls`.
            sampler_args (dict): Arguments to be passed to sampler constructor.
                Its 'use_all_workers' entry, if present, is also recorded on
                the runner.
            n_workers (int): The number of workers the sampler should use.
            worker_class (type): Type of worker the sampler should use.
            worker_args (dict or None): Additional arguments forwarded to
                `make_sampler` as `worker_args`.

        """
        self._algo = algo
        self._env = env
        self._policy = algo.policy
        self._n_workers = n_workers
        self._worker_class = worker_class

        worker_args = {} if worker_args is None else worker_args
        sampler_args = {} if sampler_args is None else sampler_args
        if sampler_cls is None:
            sampler_cls = algo.sampler_cls

        self._sampler = self.make_sampler(sampler_cls,
                                          n_workers=n_workers,
                                          worker_class=worker_class,
                                          sampler_args=sampler_args,
                                          worker_args=worker_args)

        # Missing key means the flag is off.
        self._use_all_worker = sampler_args.get('use_all_workers', False)
        self._has_setup = True
        self._worker_args = worker_args

        # Remember how the runner was configured.
        self._setup_args = SetupArgs(sampler_cls=sampler_cls,
                                     sampler_args=sampler_args,
                                     seed=get_seed())
コード例 #8
0
    def evaluate(self, algo, test_rollouts_per_task=None):
        """Evaluate the Meta-RL algorithm on the test tasks.

        For each sampled test task, exploration trajectories are collected
        with the algorithm's exploration policy, the policy is adapted on
        them, and the adapted policy is sampled again. The adapted
        trajectories of all tasks are then logged together and the
        evaluation iteration counter is incremented.

        Args:
            algo (metarl.np.algos.MetaRLAlgorithm): The algorithm to evaluate.
            test_rollouts_per_task (int or None): Number of rollouts per task.
                Defaults to `self._n_test_rollouts`.

        """
        if test_rollouts_per_task is None:
            test_rollouts_per_task = self._n_test_rollouts

        adapted_trajectories = []
        logger.log('Sampling for adapation and meta-testing...')

        # Create the single-worker test sampler on first use.
        if self._test_sampler is None:
            build_sampler = LocalSampler.from_worker_factory
            worker_factory = WorkerFactory(
                seed=get_seed(),
                max_path_length=self._max_path_length,
                n_workers=1,
                worker_class=self._worker_class,
                worker_args=self._worker_args)
            initial_policy = algo.get_exploration_policy()
            initial_envs = self._test_task_sampler.sample(1)
            self._test_sampler = build_sampler(worker_factory,
                                               agents=initial_policy,
                                               envs=initial_envs)

        for env_up in self._test_task_sampler.sample(self._n_test_tasks):
            policy = algo.get_exploration_policy()

            # Collect the exploration batches one at a time.
            exploration_batches = []
            for _ in range(self._n_exploration_traj):
                batch = self._test_sampler.obtain_samples(
                    self._eval_itr, 1, policy, env_up)
                exploration_batches.append(batch)
            traj = TrajectoryBatch.concatenate(*exploration_batches)

            adapted_policy = algo.adapt_policy(policy, traj)
            n_samples = test_rollouts_per_task * self._max_path_length
            adapted_traj = self._test_sampler.obtain_samples(
                self._eval_itr, n_samples, adapted_policy)
            adapted_trajectories.append(adapted_traj)
        logger.log('Finished meta-testing...')

        # Map task indices to names when names were provided.
        name_map = (dict(enumerate(self._test_task_names))
                    if self._test_task_names is not None else None)

        log_prefix = self._prefix + '/' if self._prefix else ''
        with tabular.prefix(log_prefix):
            log_multitask_performance(
                self._eval_itr,
                TrajectoryBatch.concatenate(*adapted_trajectories),
                getattr(algo, 'discount', 1.0),
                name_map=name_map)
        self._eval_itr += 1
コード例 #9
0
    def __init__(self, snapshot_config, max_cpus=1):
        """Create the runner and initialise its bookkeeping state.

        Args:
            snapshot_config: Object exposing `snapshot_dir`, `snapshot_mode`
                and `snapshot_gap`, which are forwarded to `Snapshotter`.
            max_cpus (int): Value passed to `parallel_sampler.initialize`.

        """
        self._snapshotter = Snapshotter(snapshot_config.snapshot_dir,
                                        snapshot_config.snapshot_mode,
                                        snapshot_config.snapshot_gap)

        parallel_sampler.initialize(max_cpus)

        # Propagate the global seed to the parallel sampler when one is set.
        global_seed = get_seed()
        if global_seed is not None:
            parallel_sampler.set_seed(global_seed)

        # Nothing has been set up or plotted yet.
        self._has_setup = False
        self._plot = False

        # Components are attached later; start with empty slots.
        self._algo = None
        self._env = None
        self._sampler = None
        self._plotter = None

        # Worker configuration, also unset at construction time.
        self._n_workers = None
        self._worker_class = None
        self._worker_args = None

        # Recorded arguments and running experiment statistics.
        self._setup_args = None
        self._train_args = None
        self._stats = ExperimentStats(total_itr=0,
                                      total_env_steps=0,
                                      total_epoch=0,
                                      last_path=None)

        # Timing and per-step state.
        self._start_time = None
        self._itr_start_time = None
        self.step_itr = None
        self.step_path = None

        # only used for off-policy algorithms
        self.enable_logging = True