def eval_randomized_domain(
    pool: SamplerPool,
    env: SimEnv,
    randomizer: DomainRandomizer,
    policy: Policy,
    init_states: List[np.ndarray],
) -> List[StepSequence]:
    """
    Evaluate a policy in an environment whose domain parameters are re-sampled live by `randomizer`.

    All domain randomization wrappers already present in `env` are removed before the live randomization wrapper is
    added, so `randomizer` is the only source of domain randomization.

    :param pool: parallel sampler
    :param env: environment to evaluate in
    :param randomizer: randomizer used to sample random domain instances, inherited from `DomainRandomizer`
    :param policy: policy to evaluate
    :param init_states: initial states of the environment which will be fixed if not set to `None`
    :return: list of rollouts
    """
    randomized_env = DomainRandWrapperLive(remove_all_dr_wrappers(env), randomizer)

    # Hand pickled copies of the environment and the policy to every worker
    env_bytes, policy_bytes = pickle.dumps(randomized_env), pickle.dumps(policy)
    pool.invoke_all(_ps_init, env_bytes, policy_bytes)

    run_one = functools.partial(_ps_run_one_init_state, eval=True)
    with tqdm(leave=False, file=sys.stdout, unit="rollouts", desc="Sampling") as progress:
        return pool.run_map(run_one, init_states, progress)
def eval_domain_params(pool: SamplerPool, env: SimEnv, policy: Policy, params: list, init_state=None) -> list:
    """
    Evaluate a policy once for every domain parameter set of a (multidimensional) grid.

    :param pool: parallel sampler
    :param env: environment to evaluate in
    :param policy: policy to evaluate
    :param params: multidimensional grid of domain parameters
    :param init_state: initial state of the environment which will be fixed if not set to None
    :return: list of rollouts
    """
    # Evaluate on the bare environment, i.e. without any domain randomization wrappers
    bare_env = remove_all_dr_wrappers(env, verbose=True)
    pool.invoke_all(_setup_env_policy, bare_env, policy)

    run_one = functools.partial(_run_rollout_dp, init_state=init_state)
    with tqdm(leave=False, file=sys.stdout, unit='rollouts', desc='Sampling') as progress:
        rollouts = pool.run_map(run_one, params, progress)
    return rollouts
def eval_domain_params(
    pool: SamplerPool,
    env: SimEnv,
    policy: Policy,
    params: List[Dict],
    init_state: Optional[np.ndarray] = None,
) -> List[StepSequence]:
    """
    Evaluate a policy once for every domain parameter set of a (multidimensional) grid.

    :param pool: parallel sampler
    :param env: environment to evaluate in
    :param policy: policy to evaluate
    :param params: multidimensional grid of domain parameters
    :param init_state: initial state of the environment which will be fixed if not set to `None`
    :return: list of rollouts
    """
    bare_env = remove_all_dr_wrappers(env, verbose=True)

    if init_state is not None:
        # Collapse the initial state space onto a single state so every rollout starts identically
        bare_env.init_space = SingularStateSpace(fixed_state=init_state)

    pool.invoke_all(_ps_init, pickle.dumps(bare_env), pickle.dumps(policy))

    run_one = functools.partial(_ps_run_one_domain_param, eval=True)
    with tqdm(leave=False, file=sys.stdout, unit="rollouts", desc="Sampling") as progress:
        rollouts = pool.run_map(run_one, params, progress)
    return rollouts
def eval_nominal_domain(pool: SamplerPool, env: SimEnv, policy: Policy, init_states: list) -> list:
    """
    Evaluate a policy using the domain parameters currently set in the given environment (the nominal ones).

    :param pool: parallel sampler
    :param env: environment to evaluate in
    :param policy: policy to evaluate
    :param init_states: initial states of the environment which will be fixed if not set to None
    :return: list of rollouts
    """
    # Evaluate on the bare environment, i.e. without any domain randomization wrappers
    nominal_env = remove_all_dr_wrappers(env, verbose=True)
    pool.invoke_all(_setup_env_policy, nominal_env, policy)

    with tqdm(leave=False, file=sys.stdout, unit='rollouts', desc='Sampling') as progress:
        rollouts = pool.run_map(_run_rollout_nom, init_states, progress)
    return rollouts
def eval_randomized_domain(pool: SamplerPool, env: SimEnv, randomizer: DomainRandomizer, policy: Policy,
                           init_states: list) -> list:
    """
    Evaluate a policy in a randomized domain.

    Domain randomization wrappers already present in `env` are removed first, so that `randomizer` is the only
    source of domain randomization during the evaluation.

    :param pool: parallel sampler
    :param env: environment to evaluate in
    :param randomizer: randomizer used to sample random domain instances, inherited from `DomainRandomizer`
    :param policy: policy to evaluate
    :param init_states: initial states of the environment which will be fixed if not set to None
    :return: list of rollouts
    """
    # Strip pre-existing domain randomization wrappers, otherwise they would randomize on top of `randomizer`
    env = remove_all_dr_wrappers(env)
    env = DomainRandWrapperLive(env, randomizer)
    pool.invoke_all(_setup_env_policy, env, policy)

    # Run with progress bar
    with tqdm(leave=False, file=sys.stdout, unit='rollouts', desc='Sampling') as pb:
        return pool.run_map(_run_rollout_nom, init_states, pb)
def eval_nominal_domain(
    pool: SamplerPool,
    env: SimEnv,
    policy: Policy,
    init_states: List[np.ndarray],
) -> List[StepSequence]:
    """
    Evaluate a policy using the domain parameters currently set in the given environment (the nominal ones).

    :param pool: parallel sampler
    :param env: environment to evaluate in
    :param policy: policy to evaluate
    :param init_states: initial states of the environment which will be fixed if not set to `None`
    :return: list of rollouts
    """
    nominal_env = remove_all_dr_wrappers(env)

    # Hand pickled copies of the environment and the policy to every worker
    pool.invoke_all(_ps_init, pickle.dumps(nominal_env), pickle.dumps(policy))

    run_one = functools.partial(_ps_run_one_init_state, eval=True)
    with tqdm(leave=False, file=sys.stdout, unit="rollouts", desc="Sampling") as progress:
        rollouts = pool.run_map(run_one, init_states, progress)
    return rollouts
class ParallelRolloutSampler(SamplerBase, Serializable):
    """ Class for sampling from multiple environments in parallel """

    def __init__(self, env, policy, num_workers: int, *,
                 min_rollouts: int = None,
                 min_steps: int = None,
                 show_progress_bar: bool = True,
                 seed: int = None):
        """
        Constructor

        :param env: environment to sample from
        :param policy: policy to act in the environment (can also be an exploration strategy)
        :param num_workers: number of parallel samplers
        :param min_rollouts: minimum number of complete rollouts to sample
        :param min_steps: minimum total number of steps to sample
        :param show_progress_bar: if `True`, display a progress bar using `tqdm`
        :param seed: seed value for the random number generators, pass `None` for no seeding
        """
        # Must be the first statement so that `locals()` only contains the constructor arguments
        Serializable._init(self, locals())
        super().__init__(min_rollouts=min_rollouts, min_steps=min_steps)

        self.env = env
        self.policy = policy
        self.show_progress_bar = show_progress_bar

        # Set method to spawn if using cuda
        if self.policy.device == 'cuda':
            mp.set_start_method('spawn', force=True)

        # Create parallel pool. We use one thread per env because it's easier.
        self.pool = SamplerPool(num_workers)

        # Set all rngs' seeds
        if seed is not None:
            self.set_seed(seed)

        # Distribute environments. We use pickle to make sure a copy is created for n_envs=1
        self.pool.invoke_all(_ps_init, pickle.dumps(self.env), pickle.dumps(self.policy))

    def set_seed(self, seed):
        """
        Set a deterministic seed on all workers.

        :param seed: seed value for the random number generators
        """
        self.pool.set_seed(seed)

    def reinit(self, env=None, policy=None):
        """
        Re-initialize the sampler.

        :param env: the environment which the policy operates
        :param policy: the policy used for sampling
        """
        # Update env and policy if passed
        if env is not None:
            self.env = env
        if policy is not None:
            self.policy = policy

        # Always broadcast to workers
        self.pool.invoke_all(_ps_init, pickle.dumps(self.env), pickle.dumps(self.policy))

    def sample(self,
               init_states: List[np.ndarray] = None,
               domain_params: List[dict] = None,
               eval: bool = False) -> List[StepSequence]:
        """
        Do the sampling according to the previously given environment, policy, and number of steps/rollouts.

        :param init_states: initial states for `run_map()`, pass `None` (default) to sample from the environment's
                            initial state space
        :param domain_params: domain parameters for `run_map()`, pass `None` (default) to not explicitly set them
        :param eval: pass `False` if the rollout is executed during training, else `True`. Forwarded to `rollout()`.
        :return: list of sampled rollouts
        :raises NotImplementedError: if domain parameters are given without initial states, or if initial states or
                                     domain parameters are given while sampling by a minimum number of steps
        """
        # Update policy's state
        self.pool.invoke_all(_ps_update_policy, self.policy.state_dict())

        # Repeat on a list: `rep_factor * <ndarray>` would scale the states element-wise instead of repeating them
        if init_states is not None:
            init_states = list(init_states)

        # Collect samples
        with tqdm(leave=False, file=sys.stdout, desc='Sampling', disable=(not self.show_progress_bar),
                  unit='steps' if self.min_steps is not None else 'rollouts') as pb:

            if self.min_steps is None:
                if init_states is None and domain_params is None:
                    # Simply run min_rollouts times
                    func = partial(_ps_run_one, eval=eval)
                    arglist = range(self.min_rollouts)
                elif init_states is not None and domain_params is None:
                    # Run every initial state so often that we at least get min_rollouts trajectories
                    func = partial(_ps_run_one_init_state, eval=eval)
                    rep_factor = ceil(self.min_rollouts / len(init_states))
                    arglist = rep_factor * init_states
                elif init_states is not None and domain_params is not None:
                    # Run every combination of initial state and domain parameter so often that we at least get
                    # min_rollouts trajectories
                    func = partial(_ps_run_one_reset_kwargs, eval=eval)
                    allcombs = list(product(init_states, domain_params))
                    rep_factor = ceil(self.min_rollouts / len(allcombs))
                    arglist = rep_factor * allcombs
                else:
                    raise NotImplementedError(
                        "Sampling with domain parameters but without initial states is not supported")

                # Only minimum number of rollouts given, thus use run_map
                return self.pool.run_map(func, arglist, pb)

            # Minimum number of steps given, thus use run_collect (automatically handles min_runs=None).
            # Neither initial states nor domain parameters can be forwarded here, so reject them instead of
            # silently ignoring them.
            if init_states is not None or domain_params is not None:
                raise NotImplementedError(
                    "Passing initial states or domain parameters is not supported when sampling by min_steps")
            return self.pool.run_collect(self.min_steps, partial(_ps_sample_one, eval=eval),
                                         collect_progressbar=pb, min_runs=self.min_rollouts)[0]
class ParameterExplorationSampler(Serializable):
    """Parallel sampler for parameter exploration"""

    def __init__(
        self,
        env: Union[SimEnv, EnvWrapper],
        policy: Policy,
        num_init_states_per_domain: int,
        num_domains: int,
        num_workers: int,
        seed: Optional[int] = NO_SEED_PASSED,
    ):
        """
        Constructor

        :param env: environment to sample from
        :param policy: policy used for sampling
        :param num_init_states_per_domain: number of rollouts to cover the variance over initial states
        :param num_domains: number of rollouts due to the variance over domain parameters
        :param num_workers: number of parallel samplers
        :param seed: seed value for the random number generators, pass `None` for no seeding; defaults to the last seed
                     that was set with `pyrado.set_seed`
        """
        if not isinstance(num_init_states_per_domain, int):
            raise pyrado.TypeErr(given=num_init_states_per_domain, expected_type=int)
        if num_init_states_per_domain < 1:
            raise pyrado.ValueErr(given=num_init_states_per_domain, ge_constraint="1")
        if not isinstance(num_domains, int):
            raise pyrado.TypeErr(given=num_domains, expected_type=int)
        if num_domains < 1:
            raise pyrado.ValueErr(given=num_domains, ge_constraint="1")

        Serializable._init(self, locals())

        # Check environment for domain randomization wrappers (stops after finding the outermost)
        self._dr_wrapper = typed_env(env, DomainRandWrapper)
        if self._dr_wrapper is not None:
            assert isinstance(inner_env(env), SimEnv)
            # Remove them all from the env chain since we sample the domain parameter later explicitly
            env = remove_all_dr_wrappers(env)

        self.env, self.policy = env, policy
        self.num_init_states_per_domain = num_init_states_per_domain
        self.num_domains = num_domains

        # Always use the spawn start method (e.g. CUDA cannot be re-initialized in forked subprocesses)
        if mp.get_start_method(allow_none=True) != "spawn":
            mp.set_start_method("spawn", force=True)

        # Create parallel pool. We use one thread per environment because it's easier.
        self.pool = SamplerPool(num_workers)

        if seed is NO_SEED_PASSED:
            seed = pyrado.get_base_seed()
        self._seed = seed
        # Initialize with -1 such that we start with the 0-th sample. Incrementing after sampling may cause issues when
        # the sampling crashes and the sample count is not incremented.
        self._sample_count = -1

        # Distribute environments. We use pickle to make sure a copy is created for n_envs = 1
        self.pool.invoke_all(_pes_init, pickle.dumps(self.env), pickle.dumps(self.policy))

    @property
    def num_rollouts_per_param(self) -> int:
        """Get the number of rollouts per policy parameter set."""
        return self.num_init_states_per_domain * self.num_domains

    def _sample_domain_params(self) -> list:
        """Sample domain parameters from the cached domain randomization wrapper."""
        if self._dr_wrapper is None:
            # There was no randomizer, thus do not set any domain parameters
            return [None] * self.num_domains

        elif isinstance(self._dr_wrapper, DomainRandWrapperBuffer) and self._dr_wrapper.buffer is not None:
            # Use buffered domain parameter sets
            idcs = np.random.randint(0, len(self._dr_wrapper.buffer), size=self.num_domains)
            return [self._dr_wrapper.buffer[i] for i in idcs]

        else:
            # Sample new domain parameters (same as in DomainRandWrapperBuffer.fill_buffer)
            self._dr_wrapper.randomizer.randomize(self.num_domains)
            return self._dr_wrapper.randomizer.get_params(-1, fmt="list", dtype="numpy")

    def _sample_one_init_state(self, domain_param: dict) -> Union[np.ndarray, None]:
        """
        Sample an init state for the given domain parameter set(s).
        For some environments, the initial state space depends on the domain parameters, so we need to set them before
        sampling it. We can just reset `self.env` here safely though, since it's not used for anything else.

        :param domain_param: domain parameters to set
        :return: initial state, `None` if no initial state space is defined
        """
        self.env.reset(domain_param=domain_param)
        ispace = attr_env_get(self.env, "init_space")
        if ispace is not None:
            return ispace.sample_uniform()
        else:
            # No init space, no init state
            return None

    def reinit(self, env: Optional[Env] = None, policy: Optional[Policy] = None):
        """
        Re-initialize the sampler.

        :param env: the environment which the policy operates
        :param policy: the policy used for sampling
        """
        # Update env and policy if passed
        if env is not None:
            self.env = env
        if policy is not None:
            self.policy = policy

        # Always broadcast to workers
        self.pool.invoke_all(_pes_init, pickle.dumps(self.env), pickle.dumps(self.policy))

    def sample(
        self, param_sets: to.Tensor, init_states: Optional[List[np.ndarray]] = None
    ) -> ParameterSamplingResult:
        """
        Sample rollouts for a given set of parameters.

        .. note::
            This method is **not** thread-safe! See for example the usage of `self._sample_count`.

        :param param_sets: sets of policy parameters
        :param init_states: fixed initial states, pass `None` to randomly sample initial states. If given, the list
                            must contain `num_domains * num_init_states_per_domain` entries, and entry `i` is paired
                            with the domain parameter set `i % num_domains`.
        :return: data structure containing the policy parameter sets and the associated rollout data
        :raises pyrado.TypeErr: if `init_states` is neither `None` nor a list
        :raises pyrado.ShapeErr: if `init_states` has the wrong number of entries
        """
        if init_states is not None and not isinstance(init_states, list):
            raise pyrado.TypeErr(given=init_states, expected_type=list)

        self._sample_count += 1

        # Sample domain parameter sets
        domain_params = self._sample_domain_params()
        if not isinstance(domain_params, list):
            raise pyrado.TypeErr(given=domain_params, expected_type=list)

        # Repeat the sets for the number of initial states per domain. Build a new list instead of extending in place,
        # since the list returned above may be shared with the randomizer.
        domain_params = domain_params * self.num_init_states_per_domain

        if init_states is None:
            # Sample one initial state per (repeated) domain parameter set, resetting to that set before. Iterating the
            # repeated list keeps every initial state paired with the domain it was sampled in.
            init_states = [self._sample_one_init_state(dp) for dp in domain_params]
        elif len(init_states) != len(domain_params):
            # This is an edge case, but here we need as many init states as num_domains * num_init_states_per_domain
            raise pyrado.ShapeErr(given=init_states, expected_match=domain_params)

        # Explode parameter list for rollouts per param
        all_params = [(p, *r) for p in param_sets for r in zip(domain_params, init_states)]

        # Sample rollouts in parallel
        with tqdm(leave=False, file=sys.stdout, desc="Sampling", unit="rollouts") as pb:
            all_ros = self.pool.run_map(
                partial(_pes_sample_one, seed=self._seed, sub_seed=self._sample_count), list(enumerate(all_params)), pb
            )

        # Group rollouts by parameters
        ros_iter = iter(all_ros)
        return ParameterSamplingResult(
            [
                ParameterSample(params=p, rollouts=list(itertools.islice(ros_iter, self.num_rollouts_per_param)))
                for p in param_sets
            ]
        )
def eval_domain_params_with_segmentwise_reset(
    pool: SamplerPool,
    env_sim: SimEnv,
    policy: Policy,
    segments_real_all: List[List[StepSequence]],
    domain_params_ml_all: List[List[dict]],
    stop_on_done: bool,
    use_rec: bool,
) -> List[List[StepSequence]]:
    """
    Evaluate a policy for a given set of domain parameters, synchronizing the segments' initial states with the given
    target domain segments

    :param pool: parallel sampler
    :param env_sim: environment to evaluate in
    :param policy: policy to evaluate
    :param segments_real_all: all segments from the target domain rollout
    :param domain_params_ml_all: all domain parameters to evaluate over
    :param stop_on_done: if `True`, the rollouts are stopped as soon as they hit the state or observation space
                         boundaries. This behavior is safe, but can lead to short trajectories which are eventually
                         padded with zeroes. Choose `False` to ignore the boundaries (dangerous on the real system).
    :param use_rec: `True` if pre-recorded actions have been used to generate the rollouts
    :return: list of segments of rollouts
    """
    # The environment and the policy are not modified in this function, so serialize them only once. The workers are
    # still re-initialized from these bytes before every segment.
    env_sim_pickled = pickle.dumps(env_sim)
    policy_pickled = pickle.dumps(policy)

    # Sample rollouts with the most likely domain parameter sets associated to that observation
    segments_ml_all = []  # all top max likelihood segments for all target domain rollouts
    for idx_r, (segments_real, domain_params_ml) in tqdm(
        enumerate(zip(segments_real_all, domain_params_ml_all)),
        total=len(segments_real_all),
        desc="Sampling",
        file=sys.stdout,
        leave=False,
    ):
        segments_ml = []  # all top max likelihood segments for one target domain rollout
        cnt_step = 0

        # Iterate over target domain segments
        for segment_real in segments_real:
            # Initialize workers from fresh copies (presumably so that no worker state carries over between segments)
            pool.invoke_all(_ps_init, env_sim_pickled, policy_pickled)

            # Run without progress bar
            segments_dp = pool.run_map(
                functools.partial(
                    _ps_run_one_reset_kwargs_segment,
                    init_state=segment_real.states[0, :],
                    len_segment=segment_real.length,
                    stop_on_done=stop_on_done,
                    use_rec=use_rec,
                    idx_r=idx_r,
                    cnt_step=cnt_step,
                    eval=True,
                ),
                domain_params_ml,
            )

            # Sanity checks: every simulated segment starts in the real segment's first state, and replays its actions
            for sdp in segments_dp:
                assert np.allclose(sdp.states[0, :], segment_real.states[0, :])
                if use_rec:
                    check_act_equal(segment_real, sdp, check_applied=hasattr(sdp, "actions_applied"))

            # Increase step counter for next segment, and append all domain parameter segments
            cnt_step += segment_real.length
            segments_ml.append(segments_dp)

        # Append all segments for the current target domain rollout
        segments_ml_all.append(segments_ml)

    return segments_ml_all
class ParameterExplorationSampler(Serializable):
    """ Parallel sampler for parameter exploration """

    def __init__(self, env: Env, policy: Policy, num_workers: int, num_rollouts_per_param: int, seed: int = None):
        """
        Constructor

        :param env: environment to sample from
        :param policy: policy used for sampling
        :param num_workers: number of parallel samplers
        :param num_rollouts_per_param: number of rollouts per policy parameter set (and init state if specified)
        :param seed: seed value for the random number generators, pass `None` for no seeding
        """
        if not isinstance(num_rollouts_per_param, int):
            raise pyrado.TypeErr(given=num_rollouts_per_param, expected_type=int)
        if num_rollouts_per_param < 1:
            raise pyrado.ValueErr(given=num_rollouts_per_param, ge_constraint='1')

        # Must run before any other local is bound so that `locals()` only holds the constructor arguments
        Serializable._init(self, locals())

        # Remember the outermost domain randomization wrapper (if any); the domain parameters are drawn explicitly
        # later on, so all such wrappers are stripped from the env chain
        self._dr_wrapper = typed_env(env, DomainRandWrapper)
        if self._dr_wrapper is not None:
            assert isinstance(inner_env(env), SimEnv)
            env = remove_all_dr_wrappers(env)

        self.env = env
        self.policy = policy
        self.num_rollouts_per_param = num_rollouts_per_param

        # One worker per environment
        self.pool = SamplerPool(num_workers)
        if seed is not None:
            self.pool.set_seed(seed)

        # Send pickled copies to the workers (guarantees a copy even with a single worker)
        self.pool.invoke_all(_pes_init, pickle.dumps(self.env), pickle.dumps(self.policy))

    def _sample_domain_params(self) -> [list, dict]:
        """
        Draw one domain parameter set per rollout from the cached domain randomization wrapper.
        """
        count = self.num_rollouts_per_param
        wrapper = self._dr_wrapper

        if wrapper is None:
            # No randomization, hence nothing to set
            return [None] * count

        if isinstance(wrapper, DomainRandWrapperBuffer) and wrapper.buffer is not None:
            # Draw (with replacement) from the buffered parameter sets
            buffer = wrapper.buffer
            return [buffer[i] for i in np.random.randint(0, len(buffer), size=count)]

        # Draw fresh sets, mirroring DomainRandWrapperBuffer.fill_buffer
        wrapper.randomizer.randomize(count)
        return wrapper.randomizer.get_params(-1, format='list', dtype='numpy')

    def _sample_one_init_state(self, domain_param: dict) -> [np.ndarray, None]:
        """
        Draw an initial state after resetting `self.env` to the given domain parameters, since the initial state
        space may depend on them. Resetting `self.env` is harmless because it is used for nothing else.

        :param domain_param: domain parameters to set
        :return: initial state, `None` if no initial state space is defined
        """
        self.env.reset(domain_param=domain_param)
        init_space = attr_env_get(self.env, 'init_space')
        return None if init_space is None else init_space.sample_uniform()

    def sample(self, param_sets: to.Tensor) -> ParameterSamplingResult:
        """
        Roll out every policy parameter set once per sampled (domain parameters, initial state) pair.

        :param param_sets: sets of policy parameters
        :return: data structure containing the policy parameter sets and the associated rollout data
        """
        domain_params = self._sample_domain_params()
        if isinstance(domain_params, dict):
            # A single domain parameter set; wrap it so it can be iterated like the list case
            domain_params = [domain_params]
        elif not isinstance(domain_params, list):
            raise pyrado.TypeErr(given=domain_params, expected_type=[list, dict])

        init_states = [self._sample_one_init_state(dp) for dp in domain_params]

        # One work item per combination of policy parameters and (domain parameters, initial state)
        work_items = [(p, dp, s) for p in param_sets for dp, s in zip(domain_params, init_states)]

        with tqdm(leave=False, file=sys.stdout, desc='Sampling', unit='rollouts') as pb:
            all_ros = self.pool.run_map(_pes_sample_one, work_items, pb)

        # Slice the flat rollout list into consecutive chunks, one chunk per policy parameter set
        rollout_stream = iter(all_ros)
        samples = []
        for p in param_sets:
            chunk = list(itertools.islice(rollout_stream, self.num_rollouts_per_param))
            samples.append(ParameterSample(params=p, rollouts=chunk))
        return ParameterSamplingResult(samples)
class ParallelRolloutSampler(SamplerBase, Serializable):
    """Class for sampling from multiple environments in parallel"""

    def __init__(
        self,
        env,
        policy,
        num_workers: int,
        *,
        min_rollouts: int = None,
        min_steps: int = None,
        show_progress_bar: bool = True,
        seed: int = NO_SEED_PASSED,
    ):
        """
        Constructor

        :param env: environment to sample from
        :param policy: policy to act in the environment (can also be an exploration strategy)
        :param num_workers: number of parallel samplers
        :param min_rollouts: minimum number of complete rollouts to sample
        :param min_steps: minimum total number of steps to sample
        :param show_progress_bar: if `True`, display a progress bar using `tqdm`
        :param seed: seed value for the random number generators, pass `None` for no seeding; defaults to the last seed
                     that was set with `pyrado.set_seed`
        """
        # Must be the first statement so that `locals()` only contains the constructor arguments
        Serializable._init(self, locals())
        super().__init__(min_rollouts=min_rollouts, min_steps=min_steps)

        self.env = env
        self.policy = policy
        self.show_progress_bar = show_progress_bar

        # Always use the spawn start method (e.g. CUDA cannot be re-initialized in forked subprocesses)
        if mp.get_start_method(allow_none=True) != "spawn":
            mp.set_start_method("spawn", force=True)

        # Create parallel pool. We use one thread per env because it's easier.
        self.pool = SamplerPool(num_workers)

        if seed is NO_SEED_PASSED:
            seed = pyrado.get_base_seed()
        self._seed = seed
        # Initialize with -1 such that we start with the 0-th sample. Incrementing after sampling may cause issues when
        # the sampling crashes and the sample count is not incremented.
        self._sample_count = -1

        # Distribute environments. We use pickle to make sure a copy is created for n_envs=1
        self.pool.invoke_all(_ps_init, pickle.dumps(self.env), pickle.dumps(self.policy))

    def reinit(self, env: Optional[Env] = None, policy: Optional[Policy] = None):
        """
        Re-initialize the sampler.

        :param env: the environment which the policy operates
        :param policy: the policy used for sampling
        """
        # Update env and policy if passed
        if env is not None:
            self.env = env
        if policy is not None:
            self.policy = policy

        # Always broadcast to workers
        self.pool.invoke_all(_ps_init, pickle.dumps(self.env), pickle.dumps(self.policy))

    def sample(
        self,
        init_states: Optional[List[np.ndarray]] = None,
        domain_params: Optional[List[dict]] = None,
        eval: bool = False,
    ) -> List[StepSequence]:
        """
        Do the sampling according to the previously given environment, policy, and number of steps/rollouts.

        .. note::
            This method is **not** thread-safe! See for example the usage of `self._sample_count`.

        :param init_states: initial states for `run_map()`, pass `None` (default) to sample from the environment's
                            initial state space
        :param domain_params: domain parameters for `run_map()`, pass `None` (default) to not explicitly set them
        :param eval: pass `False` if the rollout is executed during training, else `True`. Forwarded to `rollout()`.
        :return: list of sampled rollouts
        :raises NotImplementedError: if initial states or domain parameters are given while sampling by `min_steps`
        """
        self._sample_count += 1

        # Update policy's state
        self.pool.invoke_all(_ps_update_policy, self.policy.state_dict())

        # Repeat on a list: `rep_factor * <ndarray>` would scale the states element-wise instead of repeating them
        if init_states is not None:
            init_states = list(init_states)

        # Collect samples
        with tqdm(
            leave=False,
            file=sys.stdout,
            desc="Sampling",
            disable=(not self.show_progress_bar),
            unit="steps" if self.min_steps is not None else "rollouts",
        ) as pb:

            if self.min_steps is None:
                if init_states is None and domain_params is None:
                    # Simply run min_rollouts times
                    func = partial(_ps_run_one, eval=eval)
                    arglist = list(range(self.min_rollouts))
                elif init_states is not None and domain_params is None:
                    # Run every initial state so often that we at least get min_rollouts trajectories
                    func = partial(_ps_run_one_init_state, eval=eval)
                    rep_factor = ceil(self.min_rollouts / len(init_states))
                    arglist = list(enumerate(rep_factor * init_states))
                elif init_states is None and domain_params is not None:
                    # Run every domain parameter set so often that we at least get min_rollouts trajectories
                    func = partial(_ps_run_one_domain_param, eval=eval)
                    rep_factor = ceil(self.min_rollouts / len(domain_params))
                    arglist = list(enumerate(rep_factor * domain_params))
                else:
                    # Run every combination of initial state and domain parameter so often that we at least get
                    # min_rollouts trajectories
                    func = partial(_ps_run_one_reset_kwargs, eval=eval)
                    allcombs = list(product(init_states, domain_params))
                    rep_factor = ceil(self.min_rollouts / len(allcombs))
                    arglist = list(enumerate(rep_factor * allcombs))

                # Only minimum number of rollouts given, thus use run_map
                return self.pool.run_map(partial(func, seed=self._seed, sub_seed=self._sample_count), arglist, pb)

            # Minimum number of steps given, thus use run_collect (automatically handles min_runs=None).
            # Neither initial states nor domain parameters can be forwarded here, so reject them instead of
            # silently ignoring them.
            if init_states is not None or domain_params is not None:
                raise NotImplementedError(
                    "Passing initial states or domain parameters is not supported when sampling by min_steps"
                )
            return self.pool.run_collect(
                self.min_steps,
                partial(_ps_sample_one, eval=eval, seed=self._seed, sub_seed=self._sample_count),
                collect_progressbar=pb,
                min_runs=self.min_rollouts,
            )[0]
class ParallelSampler(SamplerBase, Serializable):
    """ Class for sampling from multiple environments in parallel """

    def __init__(self, env, policy, num_envs: int, *,
                 min_rollouts: int = None,
                 min_steps: int = None,
                 bernoulli_reset: Optional[float] = None,
                 seed: int = None):
        """
        Constructor

        :param env: environment to sample from
        :param policy: policy to act in the environment (can also be an exploration strategy)
        :param num_envs: number of parallel samplers
        :param min_rollouts: minimum number of complete rollouts to sample.
        :param min_steps: minimum total number of steps to sample.
        :param bernoulli_reset: probability for resetting after the current time step
        :param seed: Seed to use. Every subprocess is seeded with seed+thread_number
        """
        # Must be the first statement so that `locals()` only contains the constructor arguments
        Serializable._init(self, locals())
        super().__init__(min_rollouts=min_rollouts, min_steps=min_steps)

        self.env = env
        self.policy = policy
        self.bernoulli_reset = bernoulli_reset

        # Set method to spawn if using cuda
        if self.policy.device == 'cuda':
            mp.set_start_method('spawn', force=True)

        # Create parallel pool. We use one thread per env because it's easier.
        self.pool = SamplerPool(num_envs)

        if seed is not None:
            self.pool.set_seed(seed)

        # Distribute environments. We use pickle to make sure a copy is created for n_envs=1
        self.pool.invoke_all(_ps_init, pickle.dumps(self.env), pickle.dumps(self.policy),
                             pickle.dumps(self.bernoulli_reset))

    def reinit(self, env=None, policy=None, bernoulli_reset: Optional[float] = None):
        """
        Re-initialize the sampler. Arguments passed as `None` keep their current value; the (possibly updated)
        environment, policy, and reset probability are always broadcast to all workers.

        :param env: the environment which the policy operates
        :param policy: the policy used for sampling
        :param bernoulli_reset: probability for resetting after the current time step
        """
        # Update env, policy, and reset probability if passed
        if env is not None:
            self.env = env
        if policy is not None:
            self.policy = policy
        if bernoulli_reset is not None:
            self.bernoulli_reset = bernoulli_reset

        # Always broadcast to workers
        self.pool.invoke_all(_ps_init, pickle.dumps(self.env), pickle.dumps(self.policy),
                             pickle.dumps(self.bernoulli_reset))

    def sample(self) -> List[StepSequence]:
        """
        Do the sampling according to the previously given environment, policy, and number of steps/rollouts.

        :return: list of sampled rollouts
        """
        # Update policy's state
        self.pool.invoke_all(_ps_update_policy, self.policy.state_dict())

        # Collect samples
        with tqdm(leave=False, file=sys.stdout, desc='Sampling',
                  unit='steps' if self.min_steps is not None else 'rollouts') as pb:

            if self.min_steps is None:
                # Only minimum number of rollouts given, thus use run_map
                return self.pool.run_map(_ps_run_one, range(self.min_rollouts), pb)
            else:
                # Minimum number of steps given, thus use run_collect (automatically handles min_runs=None)
                return self.pool.run_collect(self.min_steps, _ps_sample_one, collect_progressbar=pb,
                                             min_runs=self.min_rollouts)[0]