def _setup_worker(self, env_indices, tasks):
    """Set up workers.

    Args:
        env_indices (List[int]): Indices of environments to be assigned
            to workers for sampling.
        tasks (List[dict]): List of tasks to assign.

    """
    if self._vec_env is not None:
        self._vec_env.close()

    vec_envs = []
    for env_ind in env_indices:
        for _ in range(self._envs_per_worker):
            vec_env = copy.deepcopy(self.env)
            vec_env.set_task(tasks[env_ind])
            vec_envs.append(vec_env)

    seed0 = deterministic.get_seed()
    if seed0 is not None:
        for (i, e) in enumerate(vec_envs):
            e.seed(seed0 + i)

    self._vec_env = VecEnvExecutor(
        envs=vec_envs, max_path_length=self.algo.max_path_length)
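
# Hedged sketch of the duplication pattern used by _setup_worker above: one
# deep copy of the base env per (env index, envs_per_worker) slot, each copy
# assigned that slot's task and a distinct seed offset from a base seed.
# `base_env`, the task dicts, and the seed value 42 are hypothetical
# stand-ins, not part of the class.
import copy

envs_per_worker = 2
tasks = [{'goal': 0.5}, {'goal': 1.0}]
vec_envs = []
for env_ind in range(len(tasks)):
    for _ in range(envs_per_worker):
        e = copy.deepcopy(base_env)  # independent copy per slot
        e.set_task(tasks[env_ind])   # copies in a slot share one task
        vec_envs.append(e)
for i, e in enumerate(vec_envs):
    e.seed(42 + i)  # deterministic but distinct per-env seeds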
def __init__(self, snapshot_config, max_cpus=1):
    self._snapshotter = Snapshotter(snapshot_config.snapshot_dir,
                                    snapshot_config.snapshot_mode,
                                    snapshot_config.snapshot_gap)

    parallel_sampler.initialize(max_cpus)

    seed = get_seed()
    if seed is not None:
        parallel_sampler.set_seed(seed)

    self._has_setup = False
    self._plot = False

    self._setup_args = None
    self._train_args = None
    self._stats = ExperimentStats(total_itr=0,
                                  total_env_steps=0,
                                  total_epoch=0,
                                  last_path=None)

    self._algo = None
    self._env = None
    self._policy = None
    self._sampler = None
    self._plotter = None
    self._meta_eval = None

    self._start_time = None
    self._itr_start_time = None
    self.step_itr = None
    self.step_path = None
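
# Hedged usage sketch for the constructor above. `LocalRunner` is assumed to
# be the class this __init__ belongs to, and `SnapshotConfig` a container
# providing the snapshot_dir / snapshot_mode / snapshot_gap attributes read
# above; both names are assumptions based on the surrounding code.
snapshot_config = SnapshotConfig(snapshot_dir='/tmp/experiment',
                                 snapshot_mode='last',
                                 snapshot_gap=1)
runner = LocalRunner(snapshot_config, max_cpus=4)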
def make_sampler(self,
                 sampler_cls,
                 *,
                 seed=None,
                 n_workers=psutil.cpu_count(logical=False),
                 max_path_length=None,
                 worker_class=DefaultWorker,
                 sampler_args=None,
                 worker_args=None):
    """Construct a Sampler from a Sampler class.

    Args:
        sampler_cls (type): The type of sampler to construct.
        seed (int): Seed to use in sampler workers.
        n_workers (int): The number of workers the sampler should use.
        max_path_length (int): Maximum path length to be sampled by the
            sampler. Paths longer than this will be truncated.
        worker_class (type): Type of worker the sampler should use.
        sampler_args (dict or None): Additional arguments that should be
            passed to the sampler.
        worker_args (dict or None): Additional arguments that should be
            passed to the worker.

    Raises:
        ValueError: If `max_path_length` isn't passed and the algorithm
            doesn't contain a `max_path_length` field, or if the
            algorithm doesn't have a `policy` field.

    Returns:
        sampler_cls: An instance of the sampler class.

    """
    if not hasattr(self._algo, 'policy'):
        raise ValueError('If the runner is used to construct a sampler, '
                         'the algorithm must have a `policy` field.')
    if max_path_length is None:
        if hasattr(self._algo, 'max_path_length'):
            max_path_length = self._algo.max_path_length
        else:
            raise ValueError('If `sampler_cls` is specified in '
                             'runner.setup, the algorithm must have '
                             'a `max_path_length` field.')
    if seed is None:
        seed = get_seed()
    if sampler_args is None:
        sampler_args = {}
    if worker_args is None:
        worker_args = {}
    if issubclass(sampler_cls, BaseSampler):
        return sampler_cls(self._algo, self._env, **sampler_args)
    else:
        worker_factory = WorkerFactory(seed=seed,
                                       max_path_length=max_path_length,
                                       n_workers=n_workers,
                                       worker_class=worker_class,
                                       worker_args=worker_args)
        return sampler_cls.from_worker_factory(worker_factory,
                                               agents=self._algo.policy,
                                               envs=self._env)
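
# Hedged usage sketch for make_sampler above, assuming the runner has
# already been given an algorithm with `policy` and `max_path_length`
# fields. `LocalSampler` is the sampler class used elsewhere in this
# codebase; `runner` is a hypothetical, already-constructed instance.
sampler = runner.make_sampler(LocalSampler,
                              seed=1,
                              n_workers=2,
                              worker_class=DefaultWorker)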
def setup(self,
          algo,
          env,
          sampler_cls=None,
          sampler_args=None,
          n_workers=psutil.cpu_count(logical=False),
          worker_class=DefaultWorker,
          worker_args=None):
    """Set up runner for algorithm and environment.

    This method saves algo and env within runner and creates a sampler.

    Note:
        After setup() is called all variables in session should have been
        initialized. setup() respects existing values in session so
        policy weights can be loaded before setup().

    Args:
        algo (metarl.np.algos.RLAlgorithm): An algorithm instance.
        env (metarl.envs.MetaRLEnv): An environment instance.
        sampler_cls (metarl.sampler.Sampler): A sampler class.
        sampler_args (dict): Arguments to be passed to sampler
            constructor.
        n_workers (int): The number of workers the sampler should use.
        worker_class (type): Type of worker the sampler should use.
        worker_args (dict or None): Additional arguments that should be
            passed to the worker.

    Raises:
        ValueError: If sampler_cls is passed and the algorithm doesn't
            contain a `max_path_length` field.

    """
    self._algo = algo
    self._env = env
    self._n_workers = n_workers
    self._worker_class = worker_class
    if sampler_args is None:
        sampler_args = {}
    if sampler_cls is None:
        sampler_cls = getattr(algo, 'sampler_cls', None)
    if worker_args is None:
        worker_args = {}

    self._worker_args = worker_args
    if sampler_cls is None:
        self._sampler = None
    else:
        self._sampler = self.make_sampler(sampler_cls,
                                          sampler_args=sampler_args,
                                          n_workers=n_workers,
                                          worker_class=worker_class,
                                          worker_args=worker_args)

    self._has_setup = True

    self._setup_args = SetupArgs(sampler_cls=sampler_cls,
                                 sampler_args=sampler_args,
                                 seed=get_seed())
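
# Hedged usage sketch for setup() above. `algo` and `env` are hypothetical
# instances. Note that if sampler_cls is omitted, setup() falls back to the
# algorithm's own `sampler_cls` attribute, and ends up with no sampler at
# all when the algorithm defines none (getattr defaults to None).
runner.setup(algo=algo,
             env=env,
             sampler_cls=LocalSampler,
             n_workers=4)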
def start_worker(self):
    """Initialize the sampler."""
    n_envs = self._n_envs
    envs = [pickle.loads(pickle.dumps(self.env)) for _ in range(n_envs)]

    # Deterministically set environment seeds based on the global seed.
    seed0 = deterministic.get_seed()
    if seed0 is not None:
        for (i, e) in enumerate(envs):
            e.seed(seed0 + i)

    self._vec_env = VecEnvExecutor(
        envs=envs, max_path_length=self.algo.max_path_length)
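
# Minimal sketch of the copy idiom in start_worker above: a pickle
# round-trip yields independent env copies (equivalent in effect to
# copy.deepcopy for picklable envs), which are then seeded with consecutive
# offsets from one base seed. `base_env` and the seed 42 are hypothetical.
import pickle

n_envs = 4
envs = [pickle.loads(pickle.dumps(base_env)) for _ in range(n_envs)]
for i, e in enumerate(envs):
    e.seed(42 + i)  # same per-env offsetting scheme as above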
def make_sampler(self,
                 sampler_cls,
                 *,
                 seed=None,
                 n_workers=psutil.cpu_count(logical=False),
                 max_path_length=None,
                 worker_class=DefaultWorker,
                 sampler_args=None,
                 worker_args=None,
                 env=None,
                 policy=None):
    """Construct a Sampler from a Sampler class.

    Args:
        sampler_cls (type): The type of sampler to construct.
        seed (int): Seed to use in sampler workers.
        n_workers (int): The number of workers the sampler should use.
        max_path_length (int): Maximum path length to be sampled by the
            sampler. Paths longer than this will be truncated.
        worker_class (type): Type of worker the sampler should use.
        sampler_args (dict or None): Additional arguments that should be
            passed to the sampler.
        worker_args (dict or None): Additional arguments that should be
            passed to the worker.
        env (metarl.envs.MetaRLEnv or None): Environment to sample from.
            Defaults to the environment passed to setup().
        policy (object or None): Policy used by the sampler workers.
            Defaults to the algorithm's policy.

    Returns:
        sampler_cls: An instance of the sampler class.

    """
    if env is None:
        env = self._env
    if policy is None:
        policy = self._algo.policy
    if max_path_length is None:
        max_path_length = self._algo.max_path_length
    if seed is None:
        seed = get_seed()
    if worker_args is None:
        worker_args = {}
    if sampler_args is None:
        sampler_args = {}
    if issubclass(sampler_cls, BaseSampler):
        # Use the resolved env so an explicitly passed env is not ignored.
        return sampler_cls(self._algo, env, **sampler_args)
    else:
        worker_factory = WorkerFactory(seed=seed,
                                       max_path_length=max_path_length,
                                       n_workers=n_workers,
                                       worker_class=worker_class,
                                       worker_args=worker_args)
        return sampler_cls.from_worker_factory(worker_factory,
                                               agents=policy,
                                               envs=env)
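
# Hedged sketch of the env/policy overrides unique to this make_sampler
# variant: constructing a sampler that targets a different environment
# than the one stored at setup() time. `test_env` is hypothetical;
# get_exploration_policy() is the hook used by the evaluator below.
test_sampler = runner.make_sampler(LocalSampler,
                                   n_workers=1,
                                   env=test_env,
                                   policy=algo.get_exploration_policy())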
def setup(self,
          algo,
          env,
          sampler_cls=None,
          sampler_args=None,
          n_workers=psutil.cpu_count(logical=False),
          worker_class=DefaultWorker,
          worker_args=None):
    """Set up runner for algorithm and environment.

    This method saves algo and env within runner and creates a sampler.

    Note:
        After setup() is called all variables in session should have been
        initialized. setup() respects existing values in session so
        policy weights can be loaded before setup().

    Args:
        algo (metarl.np.algos.RLAlgorithm): An algorithm instance.
        env (metarl.envs.MetaRLEnv): An environment instance.
        sampler_cls (metarl.sampler.Sampler): A sampler class.
        sampler_args (dict): Arguments to be passed to sampler
            constructor.
        n_workers (int): The number of workers the sampler should use.
        worker_class (type): Type of worker the sampler should use.
        worker_args (dict or None): Additional arguments that should be
            passed to the worker.

    """
    self._algo = algo
    self._env = env
    self._policy = self._algo.policy
    self._n_workers = n_workers
    self._worker_class = worker_class

    if worker_args is None:
        worker_args = {}
    if sampler_args is None:
        sampler_args = {}
    if sampler_cls is None:
        sampler_cls = algo.sampler_cls

    self._sampler = self.make_sampler(sampler_cls,
                                      n_workers=n_workers,
                                      worker_class=worker_class,
                                      sampler_args=sampler_args,
                                      worker_args=worker_args)
    self._use_all_worker = sampler_args.get('use_all_workers', False)

    self._has_setup = True
    self._worker_args = worker_args

    self._setup_args = SetupArgs(sampler_cls=sampler_cls,
                                 sampler_args=sampler_args,
                                 seed=get_seed())
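
# Hedged sketch of the `use_all_workers` flag handled by this setup()
# variant: it is read out of sampler_args rather than being a named
# parameter. `algo` and `env` are hypothetical instances.
runner.setup(algo=algo,
             env=env,
             sampler_args={'use_all_workers': True})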
def evaluate(self, algo, test_rollouts_per_task=None):
    """Evaluate the Meta-RL algorithm on the test tasks.

    Args:
        algo (metarl.np.algos.MetaRLAlgorithm): The algorithm to
            evaluate.
        test_rollouts_per_task (int or None): Number of rollouts per
            task.

    """
    if test_rollouts_per_task is None:
        test_rollouts_per_task = self._n_test_rollouts
    adapted_trajectories = []
    logger.log('Sampling for adaptation and meta-testing...')

    if self._test_sampler is None:
        self._test_sampler = LocalSampler.from_worker_factory(
            WorkerFactory(seed=get_seed(),
                          max_path_length=self._max_path_length,
                          n_workers=1,
                          worker_class=self._worker_class,
                          worker_args=self._worker_args),
            agents=algo.get_exploration_policy(),
            envs=self._test_task_sampler.sample(1))

    for env_up in self._test_task_sampler.sample(self._n_test_tasks):
        policy = algo.get_exploration_policy()
        traj = TrajectoryBatch.concatenate(*[
            self._test_sampler.obtain_samples(self._eval_itr, 1, policy,
                                              env_up)
            for _ in range(self._n_exploration_traj)
        ])
        adapted_policy = algo.adapt_policy(policy, traj)
        adapted_traj = self._test_sampler.obtain_samples(
            self._eval_itr,
            test_rollouts_per_task * self._max_path_length,
            adapted_policy)
        adapted_trajectories.append(adapted_traj)
    logger.log('Finished meta-testing...')

    if self._test_task_names is not None:
        name_map = dict(enumerate(self._test_task_names))
    else:
        name_map = None

    with tabular.prefix(self._prefix + '/' if self._prefix else ''):
        log_multitask_performance(
            self._eval_itr,
            TrajectoryBatch.concatenate(*adapted_trajectories),
            getattr(algo, 'discount', 1.0),
            name_map=name_map)
    self._eval_itr += 1
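
# Hedged usage sketch for evaluate() above, assuming this method lives on a
# meta-evaluator object (`meta_evaluator` below is hypothetical) and that
# `algo` implements get_exploration_policy() and adapt_policy() as called
# in the body. Each test task gets exploration rollouts, one adaptation
# step, then test_rollouts_per_task evaluation rollouts.
meta_evaluator.evaluate(algo, test_rollouts_per_task=10)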
def __init__(self, snapshot_config, max_cpus=1):
    self._snapshotter = Snapshotter(snapshot_config.snapshot_dir,
                                    snapshot_config.snapshot_mode,
                                    snapshot_config.snapshot_gap)

    parallel_sampler.initialize(max_cpus)

    seed = get_seed()
    if seed is not None:
        parallel_sampler.set_seed(seed)

    self._has_setup = False
    self._plot = False

    self._setup_args = None
    self._train_args = None
    self._stats = ExperimentStats(total_itr=0,
                                  total_env_steps=0,
                                  total_epoch=0,
                                  last_path=None)

    self._algo = None
    self._env = None
    self._sampler = None
    self._plotter = None

    self._start_time = None
    self._itr_start_time = None
    self.step_itr = None
    self.step_path = None

    # only used for off-policy algorithms
    self.enable_logging = True

    self._n_workers = None
    self._worker_class = None
    self._worker_args = None