class ParallelRolloutSampler(SamplerBase, Serializable):
    """Class for sampling from multiple environments in parallel"""
    def __init__(
        self,
        env,
        policy,
        num_workers: int,
        *,
        min_rollouts: int = None,
        min_steps: int = None,
        show_progress_bar: bool = True,
        seed: int = NO_SEED_PASSED,
    ):
        """
        Constructor

        :param env: environment to sample from
        :param policy: policy to act in the environment (can also be an exploration strategy)
        :param num_workers: number of parallel samplers
        :param min_rollouts: minimum number of complete rollouts to sample
        :param min_steps: minimum total number of steps to sample
        :param show_progress_bar: if `True`, display a progress bar using `tqdm`
        :param seed: seed value for the random number generators, pass `None` for no seeding; defaults to the last seed
                     that was set with `pyrado.set_seed`
        """
        Serializable._init(self, locals())
        super().__init__(min_rollouts=min_rollouts, min_steps=min_steps)

        self.env = env
        self.policy = policy
        self.show_progress_bar = show_progress_bar

        # Set the start method to spawn (required when sampling with CUDA tensors)
        if mp.get_start_method(allow_none=True) != "spawn":
            mp.set_start_method("spawn", force=True)

        # Create parallel pool. We use one thread per env because it's easier.
        self.pool = SamplerPool(num_workers)

        if seed is NO_SEED_PASSED:
            seed = pyrado.get_base_seed()
        self._seed = seed
        # Initialize with -1 such that we start with the 0-th sample. We increment the count before sampling since
        # incrementing it afterwards could leave it stale if the sampling crashes.
        self._sample_count = -1

        # Distribute environments. We use pickle to make sure a copy is created even for num_workers=1
        self.pool.invoke_all(_ps_init, pickle.dumps(self.env),
                             pickle.dumps(self.policy))

    def reinit(self,
               env: Optional[Env] = None,
               policy: Optional[Policy] = None):
        """
        Re-initialize the sampler.

        :param env: the environment in which the policy operates
        :param policy: the policy used for sampling
        """
        # Update env and policy if passed
        if env is not None:
            self.env = env
        if policy is not None:
            self.policy = policy

        # Always broadcast to workers
        self.pool.invoke_all(_ps_init, pickle.dumps(self.env),
                             pickle.dumps(self.policy))

    def sample(
        self,
        init_states: Optional[List[np.ndarray]] = None,
        domain_params: Optional[List[dict]] = None,
        eval: bool = False,
    ) -> List[StepSequence]:
        """
        Do the sampling according to the previously given environment, policy, and number of steps/rollouts.

        .. note::
            This method is **not** thread-safe! See for example the usage of `self._sample_count`.

        :param init_states: initial states for `run_map()`, pass `None` (default) to sample from the environment's
                            initial state space
        :param domain_params: domain parameters for `run_map()`, pass `None` (default) to not explicitly set them
        :param eval: pass `False` if the rollout is executed during training, else `True`. Forwarded to `rollout()`.
        :return: list of sampled rollouts
        """
        self._sample_count += 1

        # Update policy's state
        self.pool.invoke_all(_ps_update_policy, self.policy.state_dict())

        # Collect samples
        with tqdm(
                leave=False,
                file=sys.stdout,
                desc="Sampling",
                disable=(not self.show_progress_bar),
                unit="steps" if self.min_steps is not None else "rollouts",
        ) as pb:

            if self.min_steps is None:
                if init_states is None and domain_params is None:
                    # Simply run min_rollouts times
                    func = partial(_ps_run_one, eval=eval)
                    arglist = list(range(self.min_rollouts))
                elif init_states is not None and domain_params is None:
                    # Run every initial state so often that we at least get min_rollouts trajectories
                    func = partial(_ps_run_one_init_state, eval=eval)
                    rep_factor = ceil(self.min_rollouts / len(init_states))
                    arglist = list(enumerate(rep_factor * init_states))
                elif init_states is None and domain_params is not None:
                    # Run every domain parameter set so often that we at least get min_rollouts trajectories
                    func = partial(_ps_run_one_domain_param, eval=eval)
                    rep_factor = ceil(self.min_rollouts / len(domain_params))
                    arglist = list(enumerate(rep_factor * domain_params))
                elif init_states is not None and domain_params is not None:
                    # Run every combination of initial state and domain parameter so often that we at least get
                    # min_rollouts trajectories
                    func = partial(_ps_run_one_reset_kwargs, eval=eval)
                    allcombs = list(product(init_states, domain_params))
                    rep_factor = ceil(self.min_rollouts / len(allcombs))
                    arglist = list(enumerate(rep_factor * allcombs))

                # Only minimum number of rollouts given, thus use run_map
                return self.pool.run_map(
                    partial(func, seed=self._seed,
                            sub_seed=self._sample_count), arglist, pb)

            else:
                # Minimum number of steps given, thus use run_collect (automatically handles min_runs=None)
                if init_states is None:
                    return self.pool.run_collect(
                        self.min_steps,
                        partial(_ps_sample_one,
                                eval=eval,
                                seed=self._seed,
                                sub_seed=self._sample_count),
                        collect_progressbar=pb,
                        min_runs=self.min_rollouts,
                    )[0]
                else:
                    raise NotImplementedError
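
A minimal usage sketch for this sampler; `env` and `policy` are placeholders for concrete Pyrado `Env` and `Policy` instances created elsewhere:

# Hypothetical usage sketch; env and policy are assumed to exist.
sampler = ParallelRolloutSampler(env, policy, num_workers=4, min_rollouts=20, seed=0)
rollouts = sampler.sample()  # list of StepSequence with at least 20 complete rollouts
total_steps = sum(ro.length for ro in rollouts)  # StepSequence stores the rollout length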
Example No. 2
class ParallelRolloutSampler(SamplerBase, Serializable):
    """ Class for sampling from multiple environments in parallel """
    def __init__(self,
                 env,
                 policy,
                 num_workers: int,
                 *,
                 min_rollouts: int = None,
                 min_steps: int = None,
                 show_progress_bar: bool = True,
                 seed: int = None):
        """
        Constructor

        :param env: environment to sample from
        :param policy: policy to act in the environment (can also be an exploration strategy)
        :param num_workers: number of parallel samplers
        :param min_rollouts: minimum number of complete rollouts to sample
        :param min_steps: minimum total number of steps to sample
        :param show_progress_bar: if `True`, display a progress bar using `tqdm`
        :param seed: seed value for the random number generators, pass `None` for no seeding
        """
        Serializable._init(self, locals())
        super().__init__(min_rollouts=min_rollouts, min_steps=min_steps)

        self.env = env
        self.policy = policy
        self.show_progress_bar = show_progress_bar

        # Set method to spawn if using cuda
        if self.policy.device == 'cuda':
            mp.set_start_method('spawn', force=True)

        # Create parallel pool. We use one thread per env because it's easier.
        self.pool = SamplerPool(num_workers)

        # Set all rngs' seeds
        if seed is not None:
            self.set_seed(seed)

        # Distribute environments. We use pickle to make sure a copy is created even for num_workers=1
        self.pool.invoke_all(_ps_init, pickle.dumps(self.env),
                             pickle.dumps(self.policy))

    def set_seed(self, seed):
        """
        Set a deterministic seed on all workers.

        :param seed: seed value for the random number generators
        """
        self.pool.set_seed(seed)

    def reinit(self, env=None, policy=None):
        """
        Re-initialize the sampler.

        :param env: the environment in which the policy operates
        :param policy: the policy used for sampling
        """
        # Update env and policy if passed
        if env is not None:
            self.env = env
        if policy is not None:
            self.policy = policy

        # Always broadcast to workers
        self.pool.invoke_all(_ps_init, pickle.dumps(self.env),
                             pickle.dumps(self.policy))

    def sample(self,
               init_states: List[np.ndarray] = None,
               domain_params: List[dict] = None,
               eval: bool = False) -> List[StepSequence]:
        """
        Do the sampling according to the previously given environment, policy, and number of steps/rollouts.

        :param init_states: initial states for `run_map()`, pass `None` (default) to sample from the environment's
                            initial state space
        :param domain_params: domain parameters for `run_map()`, pass `None` (default) to not explicitly set them
        :param eval: pass `False` if the rollout is executed during training, else `True`. Forwarded to `rollout()`.
        :return: list of sampled rollouts
        """
        # Update policy's state
        self.pool.invoke_all(_ps_update_policy, self.policy.state_dict())

        # Collect samples
        with tqdm(leave=False,
                  file=sys.stdout,
                  desc='Sampling',
                  disable=(not self.show_progress_bar),
                  unit='steps'
                  if self.min_steps is not None else 'rollouts') as pb:

            if self.min_steps is None:
                if init_states is None and domain_params is None:
                    # Simply run min_rollouts times
                    func = partial(_ps_run_one, eval=eval)
                    arglist = range(self.min_rollouts)
                elif init_states is not None and domain_params is None:
                    # Run every initial state so often that we at least get min_rollouts trajectories
                    func = partial(_ps_run_one_init_state, eval=eval)
                    rep_factor = ceil(self.min_rollouts / len(init_states))
                    arglist = rep_factor * init_states
                elif init_states is not None and domain_params is not None:
                    # Run every combination of initial state and domain parameter so often that we at least get
                    # min_rollouts trajectories
                    func = partial(_ps_run_one_reset_kwargs, eval=eval)
                    allcombs = list(product(init_states, domain_params))
                    rep_factor = ceil(self.min_rollouts / len(allcombs))
                    arglist = rep_factor * allcombs
                else:
                    raise NotImplementedError

                # Only minimum number of rollouts given, thus use run_map
                return self.pool.run_map(func, arglist, pb)

            else:
                # Minimum number of steps given, thus use run_collect (automatically handles min_runs=None)
                if init_states is None:
                    return self.pool.run_collect(self.min_steps,
                                                 partial(_ps_sample_one,
                                                         eval=eval),
                                                 collect_progressbar=pb,
                                                 min_runs=self.min_rollouts)[0]
                else:
                    raise NotImplementedError
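
In this older variant, seeding is handled once at the pool level via `set_seed` instead of being forwarded with every `sample()` call. A sketch of the difference in use, with the same placeholder `env` and `policy` as above:

# Hypothetical usage sketch for the pool-level seeding API.
sampler = ParallelRolloutSampler(env, policy, num_workers=4, min_rollouts=20, seed=42)
rollouts_a = sampler.sample()
sampler.set_seed(42)  # reset the workers' RNGs to the same state before sampling again
rollouts_b = sampler.sample()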
Example No. 3
class ParallelSampler(SamplerBase, Serializable):
    """ Class for sampling from multiple environments in parallel """
    def __init__(self,
                 env,
                 policy,
                 num_envs: int,
                 *,
                 min_rollouts: int = None,
                 min_steps: int = None,
                 bernoulli_reset: bool = None,
                 seed: int = None):
        """
        Constructor

        :param env: environment to sample from
        :param policy: policy to act in the environment (can also be an exploration strategy)
        :param num_envs: number of parallel samplers
        :param min_rollouts: minimum number of complete rollouts to sample.
        :param min_steps: minimum total number of steps to sample.
        :param bernoulli_reset: probability for resetting after the current time step
        :param seed: Seed to use. Every subprocess is seeded with seed+thread_number
        """
        Serializable._init(self, locals())
        super().__init__(min_rollouts=min_rollouts, min_steps=min_steps)

        self.env = env
        self.policy = policy
        self.bernoulli_reset = bernoulli_reset

        # Set method to spawn if using cuda
        if self.policy.device == 'cuda':
            mp.set_start_method('spawn', force=True)

        # Create parallel pool. We use one thread per env because it's easier.
        self.pool = SamplerPool(num_envs)

        if seed is not None:
            self.pool.set_seed(seed)

        # Distribute environments. We use pickle to make sure a copy is created even for num_envs=1
        self.pool.invoke_all(_ps_init, pickle.dumps(self.env),
                             pickle.dumps(self.policy),
                             pickle.dumps(self.bernoulli_reset))

    def reinit(self, env=None, policy=None, bernoulli_reset: bool = None):
        """ Re-initialize the sampler. """
        # Update env and policy if passed
        if env is not None:
            self.env = env
        if policy is not None:
            self.policy = policy
        if bernoulli_reset is not None:
            self.bernoulli_reset = bernoulli_reset

        # Always broadcast to workers
        self.pool.invoke_all(_ps_init, pickle.dumps(self.env),
                             pickle.dumps(self.policy),
                             pickle.dumps(self.bernoulli_reset))

    def sample(self) -> List[StepSequence]:
        """ Do the sampling according to the previously given environment, policy, and number of steps/rollouts. """
        # Update policy's state
        self.pool.invoke_all(_ps_update_policy, self.policy.state_dict())

        # Collect samples
        with tqdm(leave=False,
                  file=sys.stdout,
                  desc='Sampling',
                  unit='steps'
                  if self.min_steps is not None else 'rollouts') as pb:

            if self.min_steps is None:
                # Only minimum number of rollouts given, thus use run_map
                return self.pool.run_map(_ps_run_one, range(self.min_rollouts),
                                         pb)
            else:
                # Minimum number of steps given, thus use run_collect (automatically handles min_runs=None)
                return self.pool.run_collect(self.min_steps,
                                             _ps_sample_one,
                                             collect_progressbar=pb,
                                             min_runs=self.min_rollouts)[0]
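
The worker-side helpers (`_ps_init`, `_ps_update_policy`, `_ps_run_one`, ...) are not shown on this page. Based on how they are invoked above, a plausible minimal sketch follows; the exact signatures, the `G` worker-globals object supplied by `SamplerPool`, and the import path of `rollout` are assumptions:

import pickle

from pyrado.sampling.rollout import rollout  # assumed import path


def _ps_init(G, env, policy):
    """Unpickle env and policy so that every worker gets an independent copy."""
    G.env = pickle.loads(env)
    G.policy = pickle.loads(policy)


def _ps_update_policy(G, state_dict):
    """Load the latest policy parameters on the worker."""
    G.policy.load_state_dict(state_dict)


def _ps_run_one(G, num, eval):
    """Execute a single rollout; num is the index passed via the arglist.
    In the newest variant above, seed and sub_seed keyword arguments would
    additionally be forwarded for reproducible sampling."""
    return rollout(G.env, G.policy, eval=eval)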