예제 #1
0
def eval_randomized_domain(
        pool: SamplerPool, env: SimEnv, randomizer: DomainRandomizer,
        policy: Policy, init_states: List[np.ndarray]) -> List[StepSequence]:
    """
    Evaluate a policy in a randomized domain, running one rollout per given initial state.

    All domain randomization wrappers already present in `env` are removed, and a single `DomainRandWrapperLive`
    driven by `randomizer` is put on top. Pickled copies of the environment and the policy are then sent to every
    worker of the pool.

    :param pool: parallel sampler
    :param env: environment to evaluate in
    :param randomizer: randomizer used to sample random domain instances, inherited from `DomainRandomizer`
    :param policy: policy to evaluate
    :param init_states: initial states of the environment which will be fixed if not set to `None`
    :return: list of rollouts
    """
    randomized_env = DomainRandWrapperLive(remove_all_dr_wrappers(env), randomizer)

    # Broadcast independent (pickled) copies to the workers
    env_blob = pickle.dumps(randomized_env)
    policy_blob = pickle.dumps(policy)
    pool.invoke_all(_ps_init, env_blob, policy_blob)

    run_fcn = functools.partial(_ps_run_one_init_state, eval=True)
    with tqdm(leave=False, file=sys.stdout, unit="rollouts", desc="Sampling") as progbar:
        rollouts = pool.run_map(run_fcn, init_states, progbar)
    return rollouts
예제 #2
0
def eval_domain_params(pool: SamplerPool,
                       env: SimEnv,
                       policy: Policy,
                       params: list,
                       init_state=None) -> list:
    """
    Evaluate a policy on a multidimensional grid of domain parameters, one rollout per entry of `params`.

    :param pool: parallel sampler
    :param env: environment to evaluate in, all domain randomization wrappers are removed before evaluating
    :param policy: policy to evaluate
    :param params: multidimensional grid of domain parameters
    :param init_state: initial state of the environment which will be fixed if not set to None
    :return: list of rollouts
    """
    # The domain parameters are set explicitly, thus no domain randomization wrapper may interfere
    bare_env = remove_all_dr_wrappers(env, verbose=True)
    pool.invoke_all(_setup_env_policy, bare_env, policy)

    rollout_fcn = functools.partial(_run_rollout_dp, init_state=init_state)
    with tqdm(leave=False, file=sys.stdout, unit='rollouts', desc='Sampling') as progbar:
        return pool.run_map(rollout_fcn, params, progbar)
예제 #3
0
def eval_domain_params(
        pool: SamplerPool,
        env: SimEnv,
        policy: Policy,
        params: List[Dict],
        init_state: Optional[np.ndarray] = None) -> List[StepSequence]:
    """
    Evaluate a policy on a multidimensional grid of domain parameters, one rollout per entry of `params`.

    :param pool: parallel sampler
    :param env: environment to evaluate in, all domain randomization wrappers are removed before evaluating
    :param policy: policy to evaluate
    :param params: multidimensional grid of domain parameters
    :param init_state: initial state of the environment which will be fixed if not set to `None`
    :return: list of rollouts
    """
    stripped_env = remove_all_dr_wrappers(env, verbose=True)
    if init_state is not None:
        # Pin every rollout to the same initial state by collapsing the init space to a single point
        stripped_env.init_space = SingularStateSpace(fixed_state=init_state)

    pool.invoke_all(_ps_init, pickle.dumps(stripped_env), pickle.dumps(policy))

    rollout_fcn = functools.partial(_ps_run_one_domain_param, eval=True)
    with tqdm(leave=False, file=sys.stdout, unit="rollouts", desc="Sampling") as progbar:
        results = pool.run_map(rollout_fcn, params, progbar)
    return results
def eval_nominal_domain(pool: SamplerPool, env: SimEnv, policy: Policy, init_states: list) -> list:
    """
    Evaluate a policy using the nominal domain parameters, i.e. the ones currently set in the given environment.

    :param pool: parallel sampler
    :param env: environment to evaluate in, all domain randomization wrappers are removed before evaluating
    :param policy: policy to evaluate
    :param init_states: initial states of the environment which will be fixed if not set to None
    :return: list of rollouts
    """
    # Without domain randomization wrappers, the environment keeps its nominal domain parameters
    nominal_env = remove_all_dr_wrappers(env, verbose=True)
    pool.invoke_all(_setup_env_policy, nominal_env, policy)

    with tqdm(leave=False, file=sys.stdout, unit='rollouts', desc='Sampling') as progbar:
        rollouts = pool.run_map(_run_rollout_nom, init_states, progbar)
    return rollouts
예제 #5
0
def eval_randomized_domain(pool: SamplerPool, env: SimEnv,
                           randomizer: DomainRandomizer, policy: Policy,
                           init_states: list) -> list:
    """
    Evaluate a policy in a randomized domain.

    Any domain randomization wrappers already present in `env` are removed first, so that the given `randomizer` is
    the only source of domain randomization (instead of stacking a second randomization layer on top).

    :param pool: parallel sampler
    :param env: environment to evaluate in
    :param randomizer: randomizer used to sample random domain instances, inherited from `DomainRandomizer`
    :param policy: policy to evaluate
    :param init_states: initial states of the environment which will be fixed if not set to None
    :return: list of rollouts
    """
    # Strip existing domain randomization wrappers, then randomize the environment with the given randomizer
    env = remove_all_dr_wrappers(env)
    env = DomainRandWrapperLive(env, randomizer)

    pool.invoke_all(_setup_env_policy, env, policy)

    # Run with progress bar
    with tqdm(leave=False, file=sys.stdout, unit='rollouts',
              desc='Sampling') as pb:
        return pool.run_map(_run_rollout_nom, init_states, pb)
예제 #6
0
def eval_nominal_domain(pool: SamplerPool, env: SimEnv, policy: Policy,
                        init_states: List[np.ndarray]) -> List[StepSequence]:
    """
    Evaluate a policy using the nominal domain parameters, i.e. the ones currently set in the given environment.

    :param pool: parallel sampler
    :param env: environment to evaluate in, all domain randomization wrappers are removed before evaluating
    :param policy: policy to evaluate
    :param init_states: initial states of the environment which will be fixed if not set to `None`
    :return: list of rollouts, one per initial state
    """
    nominal_env = remove_all_dr_wrappers(env)

    # Workers receive pickled copies of the environment and the policy
    env_pickle, policy_pickle = pickle.dumps(nominal_env), pickle.dumps(policy)
    pool.invoke_all(_ps_init, env_pickle, policy_pickle)

    run_one = functools.partial(_ps_run_one_init_state, eval=True)
    with tqdm(leave=False, file=sys.stdout, unit="rollouts", desc="Sampling") as progbar:
        return pool.run_map(run_one, init_states, progbar)
예제 #7
0
class ParallelRolloutSampler(SamplerBase, Serializable):
    """ Class for sampling from multiple environments in parallel """
    def __init__(self,
                 env,
                 policy,
                 num_workers: int,
                 *,
                 min_rollouts: int = None,
                 min_steps: int = None,
                 show_progress_bar: bool = True,
                 seed: int = None):
        """
        Constructor

        :param env: environment to sample from
        :param policy: policy to act in the environment (can also be an exploration strategy)
        :param num_workers: number of parallel samplers
        :param min_rollouts: minimum number of complete rollouts to sample
        :param min_steps: minimum total number of steps to sample
        :param show_progress_bar: if `True`, display a progress bar using `tqdm`
        :param seed: seed value for the random number generators, pass `None` for no seeding
        """
        # Must come first, since it records the constructor arguments via locals()
        Serializable._init(self, locals())
        super().__init__(min_rollouts=min_rollouts, min_steps=min_steps)

        self.env = env
        self.policy = policy
        self.show_progress_bar = show_progress_bar

        # Switch to the 'spawn' start method if the policy lives on CUDA
        if self.policy.device == 'cuda':
            mp.set_start_method('spawn', force=True)

        # One worker per environment copy
        self.pool = SamplerPool(num_workers)

        if seed is not None:
            self.set_seed(seed)

        # Pickling guarantees that every worker owns a separate copy, even when there is only one worker
        env_blob = pickle.dumps(self.env)
        policy_blob = pickle.dumps(self.policy)
        self.pool.invoke_all(_ps_init, env_blob, policy_blob)

    def set_seed(self, seed):
        """
        Seed the random number generators of all workers deterministically.

        :param seed: seed value for the random number generators
        """
        self.pool.set_seed(seed)

    def reinit(self, env=None, policy=None):
        """
        Re-initialize the sampler, optionally replacing the environment and/or the policy first.

        :param env: the environment which the policy operates, keep the current one if `None`
        :param policy: the policy used for sampling, keep the current one if `None`
        """
        if env is not None:
            self.env = env
        if policy is not None:
            self.policy = policy

        # The workers are always re-initialized, even if nothing was replaced
        env_blob = pickle.dumps(self.env)
        policy_blob = pickle.dumps(self.policy)
        self.pool.invoke_all(_ps_init, env_blob, policy_blob)

    def sample(self,
               init_states: List[np.ndarray] = None,
               domain_params: List[np.ndarray] = None,
               eval: bool = False) -> List[StepSequence]:
        """
        Do the sampling according to the previously given environment, policy, and number of steps/rollouts.

        :param init_states: initial states for `run_map()`, pass `None` (default) to sample from the environment's
                            initial state space
        :param domain_params: domain parameters for `run_map()`, pass `None` (default) to not explicitly set them
        :param eval: pass `False` if the rollout is executed during training, else `True`. Forwarded to `rollout()`.
        :return: list of sampled rollouts
        """
        # Synchronize the workers' policies with the current policy parameters
        self.pool.invoke_all(_ps_update_policy, self.policy.state_dict())

        count_steps = self.min_steps is not None
        with tqdm(leave=False,
                  file=sys.stdout,
                  desc='Sampling',
                  disable=not self.show_progress_bar,
                  unit='steps' if count_steps else 'rollouts') as pb:

            if count_steps:
                # Minimum number of steps given, thus use run_collect (automatically handles min_runs=None)
                if init_states is not None:
                    raise NotImplementedError
                return self.pool.run_collect(self.min_steps,
                                             partial(_ps_sample_one, eval=eval),
                                             collect_progressbar=pb,
                                             min_runs=self.min_rollouts)[0]

            # Only a minimum number of rollouts given, thus use run_map
            if init_states is None:
                if domain_params is not None:
                    raise NotImplementedError
                # Simply run min_rollouts times
                func = partial(_ps_run_one, eval=eval)
                arglist = range(self.min_rollouts)
            elif domain_params is None:
                # Repeat every initial state until at least min_rollouts trajectories are collected
                func = partial(_ps_run_one_init_state, eval=eval)
                arglist = ceil(self.min_rollouts / len(init_states)) * init_states
            else:
                # Repeat every (initial state, domain parameter) combination until at least min_rollouts
                # trajectories are collected
                func = partial(_ps_run_one_reset_kwargs, eval=eval)
                combinations = list(product(init_states, domain_params))
                arglist = ceil(self.min_rollouts / len(combinations)) * combinations

            return self.pool.run_map(func, arglist, pb)
class ParameterExplorationSampler(Serializable):
    """Parallel sampler for parameter exploration"""
    def __init__(
        self,
        env: Union[SimEnv, EnvWrapper],
        policy: Policy,
        num_init_states_per_domain: int,
        num_domains: int,
        num_workers: int,
        seed: Optional[int] = NO_SEED_PASSED,
    ):
        """
        Constructor

        :param env: environment to sample from
        :param policy: policy used for sampling
        :param num_init_states_per_domain: number of rollouts to cover the variance over initial states
        :param num_domains: number of rollouts due to the variance over domain parameters
        :param num_workers: number of parallel samplers
        :param seed: seed value for the random number generators, pass `None` for no seeding; defaults to the last seed
                     that was set with `pyrado.set_seed`
        :raises TypeErr: if `num_init_states_per_domain` or `num_domains` is not an `int`
        :raises ValueErr: if `num_init_states_per_domain` or `num_domains` is smaller than 1
        """
        if not isinstance(num_init_states_per_domain, int):
            raise pyrado.TypeErr(given=num_init_states_per_domain,
                                 expected_type=int)
        if num_init_states_per_domain < 1:
            raise pyrado.ValueErr(given=num_init_states_per_domain,
                                  ge_constraint="1")
        if not isinstance(num_domains, int):
            raise pyrado.TypeErr(given=num_domains, expected_type=int)
        if num_domains < 1:
            raise pyrado.ValueErr(given=num_domains, ge_constraint="1")

        Serializable._init(self, locals())

        # Check environment for domain randomization wrappers (stops after finding the outermost)
        self._dr_wrapper = typed_env(env, DomainRandWrapper)
        if self._dr_wrapper is not None:
            assert isinstance(inner_env(env), SimEnv)
            # Remove them all from the env chain since we sample the domain parameter later explicitly
            env = remove_all_dr_wrappers(env)

        self.env, self.policy = env, policy
        self.num_init_states_per_domain = num_init_states_per_domain
        self.num_domains = num_domains

        # Always use the 'spawn' start method for the worker processes (e.g. needed when using CUDA)
        if mp.get_start_method(allow_none=True) != "spawn":
            mp.set_start_method("spawn", force=True)

        # Create parallel pool. We use one thread per environment because it's easier.
        self.pool = SamplerPool(num_workers)

        # Without an explicitly passed seed, fall back to the last seed set via `pyrado.set_seed` (as documented)
        if seed is NO_SEED_PASSED:
            seed = pyrado.get_base_seed()
        self._seed = seed
        # Initialize with -1 such that we start with the 0-th sample. Incrementing after sampling may cause issues when
        # the sampling crashes and the sample count is not incremented.
        self._sample_count = -1

        # Distribute environments. We use pickle to make sure a copy is created for n_envs = 1
        self.pool.invoke_all(_pes_init, pickle.dumps(self.env),
                             pickle.dumps(self.policy))

    @property
    def num_rollouts_per_param(self) -> int:
        """Get the number of rollouts per policy parameter set."""
        return self.num_init_states_per_domain * self.num_domains

    def _sample_domain_params(self) -> list:
        """Sample domain parameters from the cached domain randomization wrapper."""
        if self._dr_wrapper is None:
            # There was no randomizer, thus do not set any domain parameters
            return [None] * self.num_domains

        elif isinstance(self._dr_wrapper, DomainRandWrapperBuffer
                        ) and self._dr_wrapper.buffer is not None:
            # Use buffered domain parameter sets
            idcs = np.random.randint(0,
                                     len(self._dr_wrapper.buffer),
                                     size=self.num_domains)
            return [self._dr_wrapper.buffer[i] for i in idcs]

        else:
            # Sample new domain parameters (same as in DomainRandWrapperBuffer.fill_buffer)
            self._dr_wrapper.randomizer.randomize(self.num_domains)
            return self._dr_wrapper.randomizer.get_params(-1,
                                                          fmt="list",
                                                          dtype="numpy")

    def _sample_one_init_state(self,
                               domain_param: dict) -> Union[np.ndarray, None]:
        """
        Sample an init state for the given domain parameter set(s).
        For some environments, the initial state space depends on the domain parameters, so we need to set them before
        sampling it. We can just reset `self.env` here safely though, since it's not used for anything else.

        :param domain_param: domain parameters to set
        :return: initial state, `None` if no initial state space is defined
        """
        self.env.reset(domain_param=domain_param)
        ispace = attr_env_get(self.env, "init_space")
        if ispace is not None:
            return ispace.sample_uniform()
        else:
            # No init space, no init state
            return None

    def reinit(self,
               env: Optional[Env] = None,
               policy: Optional[Policy] = None):
        """
        Re-initialize the sampler.

        :param env: the environment which the policy operates
        :param policy: the policy used for sampling
        """
        # Update env and policy if passed
        if env is not None:
            self.env = env
        if policy is not None:
            self.policy = policy

        # Always broadcast to workers
        self.pool.invoke_all(_pes_init, pickle.dumps(self.env),
                             pickle.dumps(self.policy))

    def sample(
        self,
        param_sets: to.Tensor,
        init_states: Optional[List[np.ndarray]] = None
    ) -> ParameterSamplingResult:
        """
        Sample rollouts for a given set of parameters.

        .. note::
            This method is **not** thread-safe! See for example the usage of `self._sample_count`.

        :param param_sets: sets of policy parameters
        :param init_states: fixed initial states, pass `None` to randomly sample initial states
        :return: data structure containing the policy parameter sets and the associated rollout data
        :raises TypeErr: if `init_states` is given but is not a list, or if the sampled domain parameters are not a list
        :raises ShapeErr: if `init_states` is given but does not contain `num_domains * num_init_states_per_domain`
                          elements
        """
        if init_states is not None and not isinstance(init_states, list):
            raise pyrado.TypeErr(given=init_states, expected_type=list)

        self._sample_count += 1

        # Sample domain parameter sets
        domain_params = self._sample_domain_params()
        if not isinstance(domain_params, list):
            raise pyrado.TypeErr(given=domain_params, expected_type=list)

        # Repeat every domain parameter set consecutively, once per initial state: [dp_0, dp_0, ..., dp_1, dp_1, ...].
        # A new list is built, so the list returned by `_sample_domain_params()` is not mutated in place.
        domain_params = [
            dp for dp in domain_params
            for _ in range(self.num_init_states_per_domain)
        ]

        # Sample the initial states for every domain, but reset before. Hence they are associated to their domain.
        # Both lists share the same order, thus zipping them below pairs every init state with its own domain.
        if init_states is None:
            init_states = [self._sample_one_init_state(dp) for dp in domain_params]
        elif len(init_states) != len(domain_params):
            # This is an edge case, but here we need as many init states as num_domains * num_init_states_per_domain
            raise pyrado.ShapeErr(given=init_states,
                                  expected_match=domain_params)

        # Explode parameter list for rollouts per param
        all_params = [(p, *r) for p in param_sets
                      for r in zip(domain_params, init_states)]

        # Sample rollouts in parallel
        with tqdm(leave=False,
                  file=sys.stdout,
                  desc="Sampling",
                  unit="rollouts") as pb:
            all_ros = self.pool.run_map(
                partial(_pes_sample_one,
                        seed=self._seed,
                        sub_seed=self._sample_count),
                list(enumerate(all_params)), pb)

        # Group rollouts by parameters
        ros_iter = iter(all_ros)
        return ParameterSamplingResult([
            ParameterSample(params=p,
                            rollouts=list(
                                itertools.islice(ros_iter,
                                                 self.num_rollouts_per_param)))
            for p in param_sets
        ])
예제 #9
0
def eval_domain_params_with_segmentwise_reset(
    pool: SamplerPool,
    env_sim: SimEnv,
    policy: Policy,
    segments_real_all: List[List[StepSequence]],
    domain_params_ml_all: List[List[dict]],
    stop_on_done: bool,
    use_rec: bool,
) -> List[List[StepSequence]]:
    """
    Evaluate a policy for a given set of domain parameters, synchronizing the segments' initial states with the given
    target domain segments.

    For every target domain rollout and every one of its segments, the workers are re-initialized, and one simulated
    segment per domain parameter set is sampled, starting from the real segment's first state.

    :param pool: parallel sampler
    :param env_sim: environment to evaluate in
    :param policy: policy to evaluate
    :param segments_real_all: all segments from the target domain rollout
    :param domain_params_ml_all: all domain parameters to evaluate over
    :param stop_on_done: if `True`, the rollouts are stopped as soon as they hit the state or observation space
                         boundaries. This behavior is safe, but can lead to short trajectories which are eventually
                         padded with zeroes. Choose `False` to ignore the boundaries (dangerous on the real system).
    :param use_rec: `True` if pre-recorded actions have been used to generate the rollouts
    :return: list (over target domain rollouts) of lists (over segments) of the simulated segments
    """
    segments_ml_all = []  # all top max likelihood segments for all target domain rollouts
    rollout_pairs = enumerate(zip(segments_real_all, domain_params_ml_all))
    progress = tqdm(rollout_pairs, total=len(segments_real_all), desc="Sampling", file=sys.stdout, leave=False)

    for idx_r, (segments_real, domain_params_ml) in progress:
        segments_ml = []  # all top max likelihood segments for one target domain rollout
        cnt_step = 0

        for segment_real in segments_real:
            # Re-initialize the workers with fresh copies of the simulation and the policy
            pool.invoke_all(_ps_init, pickle.dumps(env_sim), pickle.dumps(policy))

            init_state_real = segment_real.states[0, :]
            run_segment = functools.partial(
                _ps_run_one_reset_kwargs_segment,
                init_state=init_state_real,
                len_segment=segment_real.length,
                stop_on_done=stop_on_done,
                use_rec=use_rec,
                idx_r=idx_r,
                cnt_step=cnt_step,
                eval=True,
            )
            # No progress bar for the inner loop
            segments_dp = pool.run_map(run_segment, domain_params_ml)

            # Sanity checks: same initial state, and the same actions if they were replayed
            for seg_sim in segments_dp:
                assert np.allclose(seg_sim.states[0, :], init_state_real)
                if use_rec:
                    check_act_equal(segment_real, seg_sim, check_applied=hasattr(seg_sim, "actions_applied"))

            # The next segment starts where this one ended
            cnt_step += segment_real.length
            segments_ml.append(segments_dp)

        segments_ml_all.append(segments_ml)

    return segments_ml_all
예제 #10
0
class ParameterExplorationSampler(Serializable):
    """ Parallel sampler for parameter exploration """
    def __init__(self,
                 env: Env,
                 policy: Policy,
                 num_workers: int,
                 num_rollouts_per_param: int,
                 seed: int = None):
        """
        Constructor

        :param env: environment to sample from
        :param policy: policy used for sampling
        :param num_workers: number of parallel samplers
        :param num_rollouts_per_param: number of rollouts per policy parameter set (and init state if specified)
        :param seed: seed value for the random number generators, pass `None` for no seeding
        """
        if not isinstance(num_rollouts_per_param, int):
            raise pyrado.TypeErr(given=num_rollouts_per_param, expected_type=int)
        if num_rollouts_per_param < 1:
            raise pyrado.ValueErr(given=num_rollouts_per_param, ge_constraint='1')

        # Records the constructor arguments via locals(), thus no other locals may be defined before this call
        Serializable._init(self, locals())

        # Cache the outermost domain randomization wrapper (if any)
        self._dr_wrapper = typed_env(env, DomainRandWrapper)
        if self._dr_wrapper is not None:
            assert isinstance(inner_env(env), SimEnv)
            # The domain parameters are sampled explicitly later on, so drop all randomization wrappers
            env = remove_all_dr_wrappers(env)

        self.env = env
        self.policy = policy
        self.num_rollouts_per_param = num_rollouts_per_param

        # One worker per environment copy
        self.pool = SamplerPool(num_workers)
        if seed is not None:
            self.pool.set_seed(seed)

        # Pickling guarantees that every worker owns a separate copy, even when there is only one worker
        env_blob = pickle.dumps(self.env)
        policy_blob = pickle.dumps(self.policy)
        self.pool.invoke_all(_pes_init, env_blob, policy_blob)

    def _sample_domain_params(self) -> [list, dict]:
        """ Draw domain parameter sets using the cached domain randomization wrapper. """
        dr = self._dr_wrapper
        num = self.num_rollouts_per_param

        if dr is None:
            # Nothing to randomize
            return [None] * num

        if isinstance(dr, DomainRandWrapperBuffer) and dr.buffer is not None:
            # Draw (with replacement) from the buffered parameter sets
            picks = np.random.randint(0, len(dr.buffer), size=num)
            return [dr.buffer[k] for k in picks]

        # Draw fresh ones (same as in DomainRandWrapperBuffer.fill_buffer)
        dr.randomizer.randomize(num)
        return dr.randomizer.get_params(-1, format='list', dtype='numpy')

    def _sample_one_init_state(self, domain_param: dict) -> [np.ndarray, None]:
        """
        Draw one initial state after setting the given domain parameters.
        The environment is reset with `domain_param` first, because some environments derive their initial state space
        from the domain parameters. Resetting `self.env` is harmless, since it is not used for anything else.

        :param domain_param: domain parameters to set
        :return: initial state, `None` if no initial state space is defined
        """
        self.env.reset(domain_param=domain_param)
        init_space = attr_env_get(self.env, 'init_space')
        return None if init_space is None else init_space.sample_uniform()

    def sample(self, param_sets: to.Tensor) -> ParameterSamplingResult:
        """
        Sample rollouts for a given set of parameters.

        :param param_sets: sets of policy parameters
        :return: data structure containing the policy parameter sets and the associated rollout data
        """
        domain_params = self._sample_domain_params()

        if isinstance(domain_params, dict):
            # A single domain parameter set, wrap it to iterate like the multi-set case
            domain_params = [domain_params]
        elif not isinstance(domain_params, list):
            raise pyrado.TypeErr(given=domain_params, expected_type=[list, dict])

        # One initial state per domain parameter set
        init_states = [self._sample_one_init_state(dp) for dp in domain_params]

        # Every policy parameter set is combined with every (domain parameter, init state) pair
        all_params = [(p, dp, s) for p in param_sets for dp, s in zip(domain_params, init_states)]

        with tqdm(leave=False, file=sys.stdout, desc='Sampling', unit='rollouts') as pb:
            all_ros = self.pool.run_map(_pes_sample_one, all_params, pb)

        # Slice the flat rollout list into consecutive chunks, one per policy parameter set
        rollout_stream = iter(all_ros)
        samples = [
            ParameterSample(params=p, rollouts=list(itertools.islice(rollout_stream, self.num_rollouts_per_param)))
            for p in param_sets
        ]
        return ParameterSamplingResult(samples)
예제 #11
0
class ParallelRolloutSampler(SamplerBase, Serializable):
    """Class for sampling from multiple environments in parallel"""
    def __init__(
        self,
        env,
        policy,
        num_workers: int,
        *,
        min_rollouts: int = None,
        min_steps: int = None,
        show_progress_bar: bool = True,
        seed: int = NO_SEED_PASSED,
    ):
        """
        Constructor

        :param env: environment to sample from
        :param policy: policy to act in the environment (can also be an exploration strategy)
        :param num_workers: number of parallel samplers
        :param min_rollouts: minimum number of complete rollouts to sample
        :param min_steps: minimum total number of steps to sample
        :param show_progress_bar: if `True`, display a progress bar using `tqdm`
        :param seed: seed value for the random number generators, pass `None` for no seeding; defaults to the last seed
                     that was set with `pyrado.set_seed`
        """
        # Must come first, since it records the constructor arguments via locals()
        Serializable._init(self, locals())
        super().__init__(min_rollouts=min_rollouts, min_steps=min_steps)

        self.env = env
        self.policy = policy
        self.show_progress_bar = show_progress_bar

        # Always use the 'spawn' start method for the worker processes (e.g. needed when using CUDA)
        if mp.get_start_method(allow_none=True) != "spawn":
            mp.set_start_method("spawn", force=True)

        # One worker per environment copy
        self.pool = SamplerPool(num_workers)

        self._seed = pyrado.get_base_seed() if seed is NO_SEED_PASSED else seed
        # Start at -1 so that the first call to sample() uses sub-seed 0. Incrementing at the beginning of sample()
        # (instead of the end) keeps the count consistent even if a sampling run crashes.
        self._sample_count = -1

        # Pickling guarantees that every worker owns a separate copy, even when there is only one worker
        env_blob = pickle.dumps(self.env)
        policy_blob = pickle.dumps(self.policy)
        self.pool.invoke_all(_ps_init, env_blob, policy_blob)

    def reinit(self,
               env: Optional[Env] = None,
               policy: Optional[Policy] = None):
        """
        Re-initialize the sampler, optionally replacing the environment and/or the policy first.

        :param env: the environment which the policy operates, keep the current one if `None`
        :param policy: the policy used for sampling, keep the current one if `None`
        """
        if env is not None:
            self.env = env
        if policy is not None:
            self.policy = policy

        # The workers are always re-initialized, even if nothing was replaced
        env_blob = pickle.dumps(self.env)
        policy_blob = pickle.dumps(self.policy)
        self.pool.invoke_all(_ps_init, env_blob, policy_blob)

    def sample(
        self,
        init_states: Optional[List[np.ndarray]] = None,
        domain_params: Optional[List[dict]] = None,
        eval: bool = False,
    ) -> List[StepSequence]:
        """
        Do the sampling according to the previously given environment, policy, and number of steps/rollouts.

        .. note::
            This method is **not** thread-safe! See for example the usage of `self._sample_count`.

        :param init_states: initial states for `run_map()`, pass `None` (default) to sample from the environment's
                            initial state space
        :param domain_params: domain parameters for `run_map()`, pass `None` (default) to not explicitly set them
        :param eval: pass `False` if the rollout is executed during training, else `True`. Forwarded to `rollout()`.
        :return: list of sampled rollouts
        """
        self._sample_count += 1

        # Synchronize the workers' policies with the current policy parameters
        self.pool.invoke_all(_ps_update_policy, self.policy.state_dict())

        step_based = self.min_steps is not None
        with tqdm(
                leave=False,
                file=sys.stdout,
                desc="Sampling",
                disable=not self.show_progress_bar,
                unit="steps" if step_based else "rollouts",
        ) as pb:

            if step_based:
                # Minimum number of steps given, thus use run_collect (automatically handles min_runs=None)
                if init_states is not None:
                    raise NotImplementedError
                collect_fcn = partial(_ps_sample_one, eval=eval, seed=self._seed, sub_seed=self._sample_count)
                return self.pool.run_collect(
                    self.min_steps,
                    collect_fcn,
                    collect_progressbar=pb,
                    min_runs=self.min_rollouts,
                )[0]

            # Only a minimum number of rollouts given, thus use run_map
            if init_states is None and domain_params is None:
                # Simply run min_rollouts times
                func = partial(_ps_run_one, eval=eval)
                arglist = list(range(self.min_rollouts))
            else:
                if domain_params is None:
                    func = partial(_ps_run_one_init_state, eval=eval)
                    items = init_states
                elif init_states is None:
                    func = partial(_ps_run_one_domain_param, eval=eval)
                    items = domain_params
                else:
                    func = partial(_ps_run_one_reset_kwargs, eval=eval)
                    items = list(product(init_states, domain_params))
                # Repeat the items so often that we get at least min_rollouts trajectories, and number them
                rep_factor = ceil(self.min_rollouts / len(items))
                arglist = list(enumerate(rep_factor * items))

            seeded_fcn = partial(func, seed=self._seed, sub_seed=self._sample_count)
            return self.pool.run_map(seeded_fcn, arglist, pb)
예제 #12
0
class ParallelSampler(SamplerBase, Serializable):
    """ Class for sampling from multiple environments in parallel using a `SamplerPool` of worker processes """
    def __init__(self,
                 env,
                 policy,
                 num_envs: int,
                 *,
                 min_rollouts: int = None,
                 min_steps: int = None,
                 bernoulli_reset: float = None,
                 seed: int = None):
        """
        Constructor

        :param env: environment to sample from
        :param policy: policy to act in the environment (can also be an exploration strategy)
        :param num_envs: number of parallel samplers
        :param min_rollouts: minimum number of complete rollouts to sample.
        :param min_steps: minimum total number of steps to sample.
        :param bernoulli_reset: probability for resetting after the current time step
        :param seed: Seed to use. Every subprocess is seeded with seed+thread_number
        """
        # Must run first: it records the constructor arguments from locals(), so no new locals may precede it
        Serializable._init(self, locals())
        super().__init__(min_rollouts=min_rollouts, min_steps=min_steps)

        self.env = env
        self.policy = policy
        self.bernoulli_reset = bernoulli_reset

        # CUDA cannot be used in forked subprocesses, so switch to 'spawn' when the policy lives on a GPU.
        # str() plus prefix match also covers 'cuda:N' device strings and torch.device objects.
        # Note: force=True changes the start method process-wide, not only for this sampler.
        if str(self.policy.device).startswith('cuda'):
            mp.set_start_method('spawn', force=True)

        # Create parallel pool. We use one thread per env because it's easier.
        self.pool = SamplerPool(num_envs)

        if seed is not None:
            self.pool.set_seed(seed)

        # Distribute environment, policy, and reset probability to all workers
        self._broadcast_env_policy()

    def _broadcast_env_policy(self):
        """ Send the current env, policy, and reset probability to every worker via `_ps_init`. """
        # Pickle explicitly so that every worker gets its own copy, even for num_envs=1
        self.pool.invoke_all(_ps_init, pickle.dumps(self.env),
                             pickle.dumps(self.policy),
                             pickle.dumps(self.bernoulli_reset))

    def reinit(self, env=None, policy=None, bernoulli_reset: float = None):
        """
        Re-initialize the sampler and re-broadcast its state to all workers.

        :param env: new environment, or `None` to keep the current one
        :param policy: new policy, or `None` to keep the current one
        :param bernoulli_reset: new reset probability, or `None` to keep the current one
                                (hence this method cannot reset it back to `None`)
        """
        # Update only the attributes that were explicitly passed
        if env is not None:
            self.env = env
        if policy is not None:
            self.policy = policy
        if bernoulli_reset is not None:
            self.bernoulli_reset = bernoulli_reset

        # Always broadcast to workers, even if nothing changed
        self._broadcast_env_policy()

    def sample(self) -> List[StepSequence]:
        """
        Do the sampling according to the previously given environment, policy, and number of steps/rollouts.

        If `min_steps` is `None`, exactly `min_rollouts` rollouts are run via `run_map`. Otherwise, `run_collect`
        samples until at least `min_steps` steps (and `min_rollouts` rollouts, if given) are collected.

        :return: list of rollouts
        """
        # Push the current policy parameters to the workers' policy copies
        self.pool.invoke_all(_ps_update_policy, self.policy.state_dict())

        # Collect samples; the progress bar unit follows the stopping criterion
        with tqdm(leave=False,
                  file=sys.stdout,
                  desc='Sampling',
                  unit='steps'
                  if self.min_steps is not None else 'rollouts') as pb:

            if self.min_steps is None:
                # Only minimum number of rollouts given, thus use run_map
                return self.pool.run_map(_ps_run_one, range(self.min_rollouts),
                                         pb)
            else:
                # Minimum number of steps given, thus use run_collect (automatically handles min_runs=None).
                # run_collect returns a tuple; presumably its first element is the list of rollouts.
                return self.pool.run_collect(self.min_steps,
                                             _ps_sample_one,
                                             collect_progressbar=pb,
                                             min_runs=self.min_rollouts)[0]