import pickle
import sys
from functools import partial
from itertools import product
from math import ceil
from typing import List, Optional

import numpy as np
import torch.multiprocessing as mp  # drop-in for multiprocessing that also handles CUDA tensors
from init_args_serializer import Serializable
from tqdm import tqdm

import pyrado
from pyrado.environments.base import Env
from pyrado.policies.base import Policy
from pyrado.sampling.rollout import rollout
from pyrado.sampling.sampler import SamplerBase
from pyrado.sampling.sampler_pool import SamplerPool
from pyrado.sampling.step_sequence import StepSequence

# Sentinel distinguishing "no seed argument given" (fall back to the base seed set via
# `pyrado.set_seed`) from an explicit `None` (no seeding). In pyrado this sentinel is
# provided by the seeding utilities; it is defined here so the excerpt is self-contained.
NO_SEED_PASSED = object()


class ParallelRolloutSampler(SamplerBase, Serializable):
    """Class for sampling from multiple environments in parallel"""

    def __init__(
        self,
        env,
        policy,
        num_workers: int,
        *,
        min_rollouts: int = None,
        min_steps: int = None,
        show_progress_bar: bool = True,
        seed: int = NO_SEED_PASSED,
    ):
        """
        Constructor

        :param env: environment to sample from
        :param policy: policy to act in the environment (can also be an exploration strategy)
        :param num_workers: number of parallel samplers
        :param min_rollouts: minimum number of complete rollouts to sample
        :param min_steps: minimum total number of steps to sample
        :param show_progress_bar: if `True`, display a progress bar using `tqdm`
        :param seed: seed value for the random number generators, pass `None` for no seeding; defaults to the last
                     seed that was set with `pyrado.set_seed`
        """
        Serializable._init(self, locals())
        super().__init__(min_rollouts=min_rollouts, min_steps=min_steps)

        self.env = env
        self.policy = policy
        self.show_progress_bar = show_progress_bar

        # Set method to spawn if using cuda
        if mp.get_start_method(allow_none=True) != "spawn":
            mp.set_start_method("spawn", force=True)

        # Create parallel pool. We use one thread per env because it's easier.
        self.pool = SamplerPool(num_workers)

        if seed is NO_SEED_PASSED:
            seed = pyrado.get_base_seed()
        self._seed = seed
        # Initialize with -1 such that we start with the 0-th sample. Incrementing after sampling may cause issues
        # when the sampling crashes and the sample count is not incremented.
        self._sample_count = -1

        # Distribute environments. We use pickle to make sure a copy is created for n_envs=1
        self.pool.invoke_all(_ps_init, pickle.dumps(self.env), pickle.dumps(self.policy))

    def reinit(self, env: Optional[Env] = None, policy: Optional[Policy] = None):
        """
        Re-initialize the sampler.

        :param env: the environment which the policy operates in
        :param policy: the policy used for sampling
        """
        # Update env and policy if passed
        if env is not None:
            self.env = env
        if policy is not None:
            self.policy = policy

        # Always broadcast to workers
        self.pool.invoke_all(_ps_init, pickle.dumps(self.env), pickle.dumps(self.policy))

    def sample(
        self,
        init_states: Optional[List[np.ndarray]] = None,
        domain_params: Optional[List[dict]] = None,
        eval: bool = False,
    ) -> List[StepSequence]:
        """
        Do the sampling according to the previously given environment, policy, and number of steps/rollouts.

        .. note::
            This method is **not** thread-safe! See for example the usage of `self._sample_count`.

        :param init_states: initial states for `run_map()`, pass `None` (default) to sample from the environment's
                            initial state space
        :param domain_params: domain parameters for `run_map()`, pass `None` (default) to not explicitly set them
        :param eval: pass `False` if the rollout is executed during training, else `True`. Forwarded to `rollout()`.
        :return: list of sampled rollouts
        """
        self._sample_count += 1

        # Update policy's state
        self.pool.invoke_all(_ps_update_policy, self.policy.state_dict())

        # Collect samples
        with tqdm(
            leave=False,
            file=sys.stdout,
            desc="Sampling",
            disable=(not self.show_progress_bar),
            unit="steps" if self.min_steps is not None else "rollouts",
        ) as pb:

            if self.min_steps is None:
                if init_states is None and domain_params is None:
                    # Simply run min_rollouts times
                    func = partial(_ps_run_one, eval=eval)
                    arglist = list(range(self.min_rollouts))
                elif init_states is not None and domain_params is None:
                    # Run every initial state so often that we at least get min_rollouts trajectories
                    func = partial(_ps_run_one_init_state, eval=eval)
                    rep_factor = ceil(self.min_rollouts / len(init_states))
                    arglist = list(enumerate(rep_factor * init_states))
                elif init_states is None and domain_params is not None:
                    # Run every domain parameter set so often that we at least get min_rollouts trajectories
                    func = partial(_ps_run_one_domain_param, eval=eval)
                    rep_factor = ceil(self.min_rollouts / len(domain_params))
                    arglist = list(enumerate(rep_factor * domain_params))
                elif init_states is not None and domain_params is not None:
                    # Run every combination of initial state and domain parameter so often that we at least get
                    # min_rollouts trajectories
                    func = partial(_ps_run_one_reset_kwargs, eval=eval)
                    allcombs = list(product(init_states, domain_params))
                    rep_factor = ceil(self.min_rollouts / len(allcombs))
                    arglist = list(enumerate(rep_factor * allcombs))

                # Only minimum number of rollouts given, thus use run_map
                return self.pool.run_map(partial(func, seed=self._seed, sub_seed=self._sample_count), arglist, pb)

            else:
                # Minimum number of steps given, thus use run_collect (automatically handles min_runs=None)
                if init_states is None:
                    return self.pool.run_collect(
                        self.min_steps,
                        partial(_ps_sample_one, eval=eval, seed=self._seed, sub_seed=self._sample_count),
                        collect_progressbar=pb,
                        min_runs=self.min_rollouts,
                    )[0]
                else:
                    raise NotImplementedError
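# ----------------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the original module). `make_env()` and
# `make_policy()` are hypothetical placeholders for however the surrounding experiment
# constructs its environment and policy:
#
#     env = make_env()
#     policy = make_policy(env.spec)
#     sampler = ParallelRolloutSampler(env, policy, num_workers=4, min_steps=10_000, seed=101)
#     rollouts = sampler.sample()    # list of StepSequence, >= 10_000 steps in total
#     sampler.reinit(policy=policy)  # re-broadcast env/policy to the workers after changes
#
# Repeated calls to `sample()` reuse the worker pool and only re-broadcast the policy
# parameters, while `self._sample_count` advances the sub-seed so that successive calls
# yield different but reproducible rollouts.
# ----------------------------------------------------------------------------------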
# --- Earlier revision of ParallelRolloutSampler: seeding is delegated to
# --- SamplerPool.set_seed instead of being threaded through every rollout call.


class ParallelRolloutSampler(SamplerBase, Serializable):
    """ Class for sampling from multiple environments in parallel """

    def __init__(self,
                 env,
                 policy,
                 num_workers: int,
                 *,
                 min_rollouts: int = None,
                 min_steps: int = None,
                 show_progress_bar: bool = True,
                 seed: int = None):
        """
        Constructor

        :param env: environment to sample from
        :param policy: policy to act in the environment (can also be an exploration strategy)
        :param num_workers: number of parallel samplers
        :param min_rollouts: minimum number of complete rollouts to sample
        :param min_steps: minimum total number of steps to sample
        :param show_progress_bar: if `True`, display a progress bar using `tqdm`
        :param seed: seed value for the random number generators, pass `None` for no seeding
        """
        Serializable._init(self, locals())
        super().__init__(min_rollouts=min_rollouts, min_steps=min_steps)

        self.env = env
        self.policy = policy
        self.show_progress_bar = show_progress_bar

        # Set method to spawn if using cuda
        if self.policy.device == 'cuda':
            mp.set_start_method('spawn', force=True)

        # Create parallel pool. We use one thread per env because it's easier.
        self.pool = SamplerPool(num_workers)

        # Set all rngs' seeds
        if seed is not None:
            self.set_seed(seed)

        # Distribute environments. We use pickle to make sure a copy is created for n_envs=1
        self.pool.invoke_all(_ps_init, pickle.dumps(self.env), pickle.dumps(self.policy))

    def set_seed(self, seed):
        """
        Set a deterministic seed on all workers.

        :param seed: seed value for the random number generators
        """
        self.pool.set_seed(seed)

    def reinit(self, env=None, policy=None):
        """
        Re-initialize the sampler.

        :param env: the environment which the policy operates in
        :param policy: the policy used for sampling
        """
        # Update env and policy if passed
        if env is not None:
            self.env = env
        if policy is not None:
            self.policy = policy

        # Always broadcast to workers
        self.pool.invoke_all(_ps_init, pickle.dumps(self.env), pickle.dumps(self.policy))

    def sample(self,
               init_states: List[np.ndarray] = None,
               domain_params: List[dict] = None,
               eval: bool = False) -> List[StepSequence]:
        """
        Do the sampling according to the previously given environment, policy, and number of steps/rollouts.

        :param init_states: initial states for `run_map()`, pass `None` (default) to sample from the environment's
                            initial state space
        :param domain_params: domain parameters for `run_map()`, pass `None` (default) to not explicitly set them
        :param eval: pass `False` if the rollout is executed during training, else `True`. Forwarded to `rollout()`.
        :return: list of sampled rollouts
        """
        # Update policy's state
        self.pool.invoke_all(_ps_update_policy, self.policy.state_dict())

        # Collect samples
        with tqdm(leave=False, file=sys.stdout, desc='Sampling', disable=(not self.show_progress_bar),
                  unit='steps' if self.min_steps is not None else 'rollouts') as pb:

            if self.min_steps is None:
                if init_states is None and domain_params is None:
                    # Simply run min_rollouts times
                    func = partial(_ps_run_one, eval=eval)
                    arglist = range(self.min_rollouts)
                elif init_states is not None and domain_params is None:
                    # Run every initial state so often that we at least get min_rollouts trajectories
                    func = partial(_ps_run_one_init_state, eval=eval)
                    rep_factor = ceil(self.min_rollouts / len(init_states))
                    arglist = rep_factor * init_states
                elif init_states is not None and domain_params is not None:
                    # Run every combination of initial state and domain parameter so often that we at least get
                    # min_rollouts trajectories
                    func = partial(_ps_run_one_reset_kwargs, eval=eval)
                    allcombs = list(product(init_states, domain_params))
                    rep_factor = ceil(self.min_rollouts / len(allcombs))
                    arglist = rep_factor * allcombs
                else:
                    raise NotImplementedError

                # Only minimum number of rollouts given, thus use run_map
                return self.pool.run_map(func, arglist, pb)

            else:
                # Minimum number of steps given, thus use run_collect (automatically handles min_runs=None)
                if init_states is None:
                    return self.pool.run_collect(self.min_steps,
                                                 partial(_ps_sample_one, eval=eval),
                                                 collect_progressbar=pb,
                                                 min_runs=self.min_rollouts)[0]
                else:
                    raise NotImplementedError
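# ----------------------------------------------------------------------------------
# Hedged sketch of the module-level `_ps_*` worker helpers that both revisions above
# rely on. Only the names appear in this excerpt; the bodies below are a
# reconstruction, assuming that `SamplerPool` hands every worker a process-local
# namespace `G` as the first argument, that `run_map()` passes one `arglist` entry per
# call, and that `rollout()` accepts the seed arguments used by the first revision
# (the earlier revision calls the same helpers without seed arguments).


def _ps_init(G, env, policy):
    """Unpickle the environment and policy once per worker, so each process owns a private copy."""
    G.env = pickle.loads(env)
    G.policy = pickle.loads(policy)


def _ps_update_policy(G, state):
    """Load the policy parameters broadcast by the master before each sampling round."""
    G.policy.load_state_dict(state)


def _ps_run_one(G, num, eval, seed, sub_seed):
    """Run one complete rollout; `num` makes the seed unique per rollout."""
    return rollout(G.env, G.policy, eval=eval, seed=seed, sub_seed=sub_seed, sub_sub_seed=num)


def _ps_run_one_init_state(G, setup, eval, seed, sub_seed):
    """Run one rollout from a fixed initial state; `setup` is a `(num, init_state)` pair."""
    num, init_state = setup
    return rollout(G.env, G.policy, eval=eval, seed=seed, sub_seed=sub_seed, sub_sub_seed=num,
                   reset_kwargs=dict(init_state=init_state))


def _ps_sample_one(G, num, eval, seed, sub_seed):
    """Run one rollout and also report its length, since `run_collect()` counts steps."""
    ro = rollout(G.env, G.policy, eval=eval, seed=seed, sub_seed=sub_seed, sub_sub_seed=num)
    return ro, len(ro)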
# --- Predecessor class: like ParallelRolloutSampler, but with an optional Bernoulli
# --- reset, i.e. a fixed probability of ending the rollout after every time step.


class ParallelSampler(SamplerBase, Serializable):
    """ Class for sampling from multiple environments in parallel """

    def __init__(self,
                 env,
                 policy,
                 num_envs: int,
                 *,
                 min_rollouts: int = None,
                 min_steps: int = None,
                 bernoulli_reset: float = None,
                 seed: int = None):
        """
        Constructor

        :param env: environment to sample from
        :param policy: policy to act in the environment (can also be an exploration strategy)
        :param num_envs: number of parallel samplers
        :param min_rollouts: minimum number of complete rollouts to sample
        :param min_steps: minimum total number of steps to sample
        :param bernoulli_reset: probability for resetting after the current time step
        :param seed: seed to use; every subprocess is seeded with seed + thread_number
        """
        Serializable._init(self, locals())
        super().__init__(min_rollouts=min_rollouts, min_steps=min_steps)

        self.env = env
        self.policy = policy
        self.bernoulli_reset = bernoulli_reset

        # Set method to spawn if using cuda
        if self.policy.device == 'cuda':
            mp.set_start_method('spawn', force=True)

        # Create parallel pool. We use one thread per env because it's easier.
        self.pool = SamplerPool(num_envs)

        if seed is not None:
            self.pool.set_seed(seed)

        # Distribute environments. We use pickle to make sure a copy is created for n_envs=1
        self.pool.invoke_all(_ps_init, pickle.dumps(self.env), pickle.dumps(self.policy),
                             pickle.dumps(self.bernoulli_reset))

    def reinit(self, env=None, policy=None, bernoulli_reset: float = None):
        """
        Re-initialize the sampler.

        :param env: the environment which the policy operates in
        :param policy: the policy used for sampling
        :param bernoulli_reset: probability for resetting after the current time step
        """
        # Update env and policy if passed
        if env is not None:
            self.env = env
        if policy is not None:
            self.policy = policy
        if bernoulli_reset is not None:
            self.bernoulli_reset = bernoulli_reset

        # Always broadcast to workers
        self.pool.invoke_all(_ps_init, pickle.dumps(self.env), pickle.dumps(self.policy),
                             pickle.dumps(self.bernoulli_reset))

    def sample(self) -> List[StepSequence]:
        """ Do the sampling according to the previously given environment, policy, and number of steps/rollouts. """
        # Update policy's state
        self.pool.invoke_all(_ps_update_policy, self.policy.state_dict())

        # Collect samples
        with tqdm(leave=False, file=sys.stdout, desc='Sampling',
                  unit='steps' if self.min_steps is not None else 'rollouts') as pb:
            if self.min_steps is None:
                # Only minimum number of rollouts given, thus use run_map
                return self.pool.run_map(_ps_run_one, range(self.min_rollouts), pb)
            else:
                # Minimum number of steps given, thus use run_collect (automatically handles min_runs=None)
                return self.pool.run_collect(self.min_steps,
                                             _ps_sample_one,
                                             collect_progressbar=pb,
                                             min_runs=self.min_rollouts)[0]
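# ----------------------------------------------------------------------------------
# Hedged sketch of the Bernoulli-reset mechanism (the worker bodies for this class are
# not part of this excerpt). After every step the episode is additionally terminated
# with probability `reset_prob`, which yields geometrically distributed episode lengths
# and thus initial states spread along on-policy trajectories. The Gym-style
# `env.reset()` / `env.step(act)` interface and the directly-callable `policy` are
# simplifying assumptions, and `_bernoulli_rollout` is a hypothetical name.


def _bernoulli_rollout(env, policy, reset_prob: float, max_steps: int):
    """Collect one episode, ending it early after each step with probability `reset_prob`."""
    obs = env.reset()
    observations, actions, rewards = [obs], [], []
    for _ in range(max_steps):
        act = policy(obs)
        obs, rew, done, _ = env.step(act)
        actions.append(act)
        rewards.append(rew)
        observations.append(obs)
        # Bernoulli reset: stop this episode with probability reset_prob
        if done or np.random.binomial(1, reset_prob):
            break
    return observations, actions, rewards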