Example #1
def wrap_like_other_env(env_targ: Env, env_src: Union[SimEnv, EnvWrapper]) -> Env:
    """
    Wrap a given real environment like its simulated counterpart (except the domain randomization of course).

    :param env_targ: target environment e.g. environment representing the physical device
    :param env_src: source environment e.g. simulation environment used for training
    :return: target environment
    """
    if env_src.dt > env_targ.dt:
        ds_factor = int(env_src.dt / env_targ.dt)
        env_targ = DownsamplingWrapper(env_targ, ds_factor)
        print_cbt(
            f'Wrapped the env with a DownsamplingWrapper of factor {ds_factor}.',
            'c')

    if typed_env(env_src, ActNormWrapper) is not None:
        env_targ = ActNormWrapper(env_targ)
        print_cbt('Wrapped the env with an ActNormWrapper.', 'c')

    if typed_env(env_src, ObsNormWrapper) is not None:
        env_targ = ObsNormWrapper(env_targ)
        print_cbt('Wrapped the env with an ObsNormWrapper.', 'c')
    elif typed_env(env_src, ObsRunningNormWrapper) is not None:
        env_targ = ObsRunningNormWrapper(env_targ)
        print_cbt('Wrapped the env with an ObsRunningNormWrapper.', 'c')

    if typed_env(env_src, ObsPartialWrapper) is not None:
        env_targ = ObsPartialWrapper(env_targ,
                                     mask=typed_env(
                                         env_src, ObsPartialWrapper).keep_mask,
                                     keep_selected=True)
        print_cbt('Wrapped the env with an ObsPartialWrapper.', 'c')

    return env_targ
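A minimal usage sketch for the function above; `env_sim` and `env_real` are assumptions standing for an already constructed (wrapped) simulation environment and a bare real-robot environment, neither of which is defined in this snippet.

# Hypothetical usage, assuming env_sim (wrapped training env) and env_real (bare real env) exist
env_real = wrap_like_other_env(env_real, env_sim)
# env_real now carries the same normalization / partial-observation wrappers as env_sim,
# so a policy trained in simulation sees consistently scaled observations and actions.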
Example #2
def plot_actions(ro: StepSequence, env: Env):
    """
    Plot all action trajectories of the given rollout.

    :param ro: input rollout
    :param env: environment (used for getting the clipped action values)
    """
    if hasattr(ro, 'actions'):
        if not isinstance(ro.actions, np.ndarray):
            raise pyrado.TypeErr(given=ro.actions, expected_type=np.ndarray)

        dim_act = ro.actions.shape[1]
        # Use recorded time stamps if possible
        if hasattr(ro, 'env_infos'):
            t = ro.env_infos.get('t', np.arange(0, ro.length))
        else:
            t = np.arange(0, ro.length)

        fig, axs = plt.subplots(dim_act, figsize=(8, 12))
        fig.suptitle('Actions over Time')
        colors = plt.get_cmap('tab20')(np.linspace(0, 1, dim_act))

        act_norm_wrapper = typed_env(env, ActNormWrapper)
        if act_norm_wrapper is not None:
            lb, ub = inner_env(env).act_space.bounds
            act_denorm = lb + (ro.actions[:] + 1.) * (ub - lb) / 2
            act_clipped = np.array(
                [inner_env(env).limit_act(a) for a in act_denorm])
        else:
            act_denorm = ro.actions
            act_clipped = np.array([env.limit_act(a) for a in ro.actions[:]])

        if dim_act == 1:
            axs.plot(t, act_denorm, label=_get_act_label(ro, 0) + ' (to env)')
            axs.plot(t,
                     act_clipped,
                     label=_get_act_label(ro, 0) + ' (clipped)',
                     c='k',
                     ls='--')
            axs.legend(bbox_to_anchor=(0, 1.0, 1, -0.1),
                       loc='lower left',
                       mode='expand',
                       ncol=2)
        else:
            for i in range(dim_act):
                axs[i].plot(t,
                            act_denorm[:, i],
                            label=_get_act_label(ro, i) + ' (to env)',
                            c=colors[i])
                axs[i].plot(t,
                            act_clipped[:, i],
                            label=_get_act_label(ro, i) + ' (clipped)',
                            c='k',
                            ls='--')
                axs[i].legend(bbox_to_anchor=(0, 1.0, 1, -0.1),
                              loc='lower left',
                              mode='expand',
                              ncol=2)

        plt.subplots_adjust(hspace=1.2)
        plt.show()
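A minimal usage sketch, assuming a `rollout` helper as in pyrado's sampling utilities and pre-existing `env` and `policy` objects (all of these are assumptions, not part of the snippet above).

ro = rollout(env, policy, eval=True)  # collect one evaluation rollout (assumed helper)
plot_actions(ro, env)  # env is passed so the clipped action limits can be recovered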
Example #3
    def __init__(
        self,
        env: Union[SimEnv, EnvWrapper],
        policy: Policy,
        dp_mapping: Mapping[int, str],
        embedding: Embedding,
        num_segments: int = None,
        len_segments: int = None,
        stop_on_done: bool = True,
        rollouts_real: Optional[List[StepSequence]] = None,
        use_rec_act: bool = True,
    ):
        """
        Constructor

        :param env: environment in which the policy operates, which must not be a randomized environment since we want
                    to randomize it manually via the domain parameters coming from the sbi package
        :param policy: policy used for sampling the rollout
        :param dp_mapping: mapping from subsequent integers (starting at 0) to domain parameter names (e.g. mass)
        :param embedding: embedding used for pre-processing the data before (later) passing it to the posterior
        :param num_segments: number of segments into which the rollouts are split. At the start of every segment, the
                             initial state of the simulation is reset, and the features of the trajectories are
                             computed separately for every segment. Either specify `num_segments` or `len_segments`.
        :param len_segments: length of the segments into which the rollouts are split. At the start of every segment,
                             the initial state of the simulation is reset, and the features of the trajectories are
                             computed separately for every segment. Either specify `num_segments` or `len_segments`.
        :param stop_on_done: if `True`, the rollouts are stopped as soon as they hit the state or observation space
                             boundaries. This behavior is safe, but can lead to short trajectories which are eventually
                             padded with zeroes. Choose `False` to ignore the boundaries (dangerous on the real system).
        :param rollouts_real: list of rollouts recorded from the target domain, which are used to sync the simulations'
                              initial states
        :param use_rec_act: if `True`, the recorded actions from the target domain are used to generate the rollout
                            during simulation (feed-forward). If `False`, the policy is used to generate (potentially)
                            state-dependent actions (feed-back).
        """
        if typed_env(env, DomainRandWrapper):
            raise pyrado.TypeErr(
                msg="The environment passed to sbi as simulator must not be wrapped with a subclass of"
                "DomainRandWrapper since sbi has be able to set the domain parameters explicitly!"
            )
        if rollouts_real is not None:
            if not isinstance(rollouts_real, list):
                raise pyrado.TypeErr(given=rollouts_real, expected_type=list)
            if not isinstance(rollouts_real[0], StepSequence):  # only check 1st element
                raise pyrado.TypeErr(given=rollouts_real[0], expected_type=StepSequence)

        Serializable._init(self, locals())

        super().__init__(env, policy, embedding, num_segments, len_segments, stop_on_done)

        self.dp_names = dp_mapping.values()
        self.rollouts_real = rollouts_real
        self.use_rec_act = use_rec_act
        if self.rollouts_real is not None:
            self._set_action_field(self.rollouts_real)
Example #4
    def step(self, act: np.ndarray) -> tuple:
        obs, reward, done, info = self.wrapped_env.step(act)
        # Look up the StateAugmentationWrapper to know where the augmented part of the observation starts
        saw = typed_env(self.wrapped_env, StateAugmentationWrapper)
        nonobserved = to.from_numpy(obs[saw.offset:])
        adversarial = self.get_arpl_grad(self.state, nonobserved)
        if self.decide_apply():
            self.state += adversarial.view(-1).numpy()
        if saw is not None:
            # Only overwrite the observed part, keep the augmented domain parameters untouched
            obs[:saw.offset] = inner_env(self).observe(self.state)
        else:
            obs = inner_env(self).observe(self.state)
        return obs, reward, done, info
def test_combination():
    env = QCartPoleSwingUpSim(dt=1/50., max_steps=20)

    randomizer = create_default_randomizer(env)
    env_r = DomainRandWrapperBuffer(env, randomizer)
    env_r.fill_buffer(num_domains=3)

    dp_before = []
    dp_after = []
    for i in range(4):
        dp_before.append(env_r.domain_param)
        rollout(env_r, DummyPolicy(env_r.spec), eval=True, seed=0, render_mode=RenderMode())
        dp_after.append(env_r.domain_param)
        assert dp_after[i] != dp_before[i]
    assert dp_after[0] == dp_after[3]

    env_rn = ActNormWrapper(env)
    elb = {'x_dot': -213., 'theta_dot': -42.}
    eub = {'x_dot': 213., 'theta_dot': 42., 'x': 0.123}
    env_rn = ObsNormWrapper(env_rn, explicit_lb=elb, explicit_ub=eub)
    alb, aub = env_rn.act_space.bounds
    assert all(alb == -1)
    assert all(aub == 1)
    olb, oub = env_rn.obs_space.bounds
    assert all(olb == -1)
    assert all(oub == 1)

    ro_r = rollout(env_r, DummyPolicy(env_r.spec), eval=True, seed=0, render_mode=RenderMode())
    ro_rn = rollout(env_rn, DummyPolicy(env_rn.spec), eval=True, seed=0, render_mode=RenderMode())
    assert np.allclose(env_rn._process_obs(ro_r.observations), ro_rn.observations)

    env_rnp = ObsPartialWrapper(env_rn, idcs=['x_dot', r'cos_theta'])
    ro_rnp = rollout(env_rnp, DummyPolicy(env_rnp.spec), eval=True, seed=0, render_mode=RenderMode())

    env_rnpa = GaussianActNoiseWrapper(env_rnp,
                                       noise_mean=0.5*np.ones(env_rnp.act_space.shape),
                                       noise_std=0.1*np.ones(env_rnp.act_space.shape))
    ro_rnpa = rollout(env_rnpa, DummyPolicy(env_rnpa.spec), eval=True, seed=0, render_mode=RenderMode())
    assert np.allclose(ro_rnp.actions, ro_rnpa.actions)
    assert not np.allclose(ro_rnp.observations, ro_rnpa.observations)

    env_rnpd = ActDelayWrapper(env_rnp, delay=3)
    ro_rnpd = rollout(env_rnpd, DummyPolicy(env_rnpd.spec), eval=True, seed=0, render_mode=RenderMode())
    assert np.allclose(ro_rnp.actions, ro_rnpd.actions)
    assert not np.allclose(ro_rnp.observations, ro_rnpd.observations)

    assert isinstance(inner_env(env_rnpd), QCartPoleSwingUpSim)
    assert typed_env(env_rnpd, ObsPartialWrapper) is not None
    assert isinstance(env_rnpd, ActDelayWrapper)
    env_rnpdr = remove_env(env_rnpd, ActDelayWrapper)
    assert not isinstance(env_rnpdr, ActDelayWrapper)
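The test above leans on the wrapper-chain helpers `typed_env`, `inner_env`, and `remove_env`. A minimal sketch of that pattern, with the wrapper and constructor arguments chosen arbitrarily for illustration:

env = ActNormWrapper(QCartPoleSwingUpSim(dt=1 / 50.0, max_steps=20))
assert typed_env(env, ActNormWrapper) is not None       # finds a wrapper type anywhere in the chain
assert isinstance(inner_env(env), QCartPoleSwingUpSim)  # unwraps down to the bare simulation
env = remove_env(env, ActNormWrapper)                   # strips that single wrapper from the chain
assert typed_env(env, ActNormWrapper) is None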
Example #6
    def __init__(self,
                 env: EnvWrapper,
                 subrtn: Algorithm,
                 skip_iter: int,
                 epsilon: float,
                 gamma: float = 1.0):
        """
        Constructor

        :param env: same environment as the subroutine runs in. Only used for checking and saving the randomizer.
        :param subrtn: algorithm which performs the policy / value-function optimization
        :param skip_iter: number of iterations for which all rollouts will be used (see prefix 'full')
        :param epsilon: quantile of (worst) rollouts that will be kept
        :param gamma: discount factor to compute the discounted return, default is 1 (no discount)
        """
        if not isinstance(subrtn, Algorithm):
            raise pyrado.TypeErr(given=subrtn, expected_type=Algorithm)
        if not typed_env(env, DomainRandWrapper):  # check that there is a DR wrapper in the chain
            raise pyrado.TypeErr(given=env, expected_type=DomainRandWrapper)
        if not hasattr(subrtn, "sampler"):
            raise AttributeError(
                "The subroutine must have a sampler attribute!")

        # Call Algorithm's constructor with the subroutine's properties
        super().__init__(subrtn.save_dir, subrtn.max_iter, subrtn.policy,
                         subrtn.logger)

        self._subrtn = subrtn
        self._subrtn.save_name = "subrtn"
        self.epsilon = epsilon
        self.gamma = gamma
        self.skip_iter = skip_iter

        # Override the subroutine's sampler
        self._subrtn.sampler = CVaRSampler(
            self._subrtn.sampler,
            epsilon=1.0,  # keep all rollouts until curr_iter = skip_iter
            gamma=self.gamma,
            min_rollouts=self._subrtn.sampler.min_rollouts,
            min_steps=self._subrtn.sampler.min_steps,
        )

        # Save initial environment and randomizer
        joblib.dump(env, osp.join(self.save_dir, "env.pkl"))
        joblib.dump(env.randomizer, osp.join(self.save_dir, "randomizer.pkl"))
Example #7
    def __init__(self, wrapped_env, policy, eps, phi, width=0.25):
        """
        Constructor

        :param wrapped_env: environment to be wrapped
        :param policy: policy to be updated
        :param eps: magnitude of perturbation
        :param phi: probability of perturbation
        :param width: width of distribution to sample from
        """
        Serializable._init(self, locals())
        AdversarialWrapper.__init__(self, wrapped_env, policy, eps, phi)
        self.width = width
        self.saw = typed_env(self.wrapped_env, StateAugmentationWrapper)
        self.nominal = self.saw.nominal
        self.nominalT = to.from_numpy(self.nominal)
        self.adv = None
        self.re_adv()
Example #8
    def __init__(self,
                 env: Env,
                 policy: Policy,
                 num_workers: int,
                 num_rollouts_per_param: int,
                 seed: int = None):
        """
        Constructor

        :param env: environment to sample from
        :param policy: policy used for sampling
        :param num_workers: number of parallel samplers
        :param num_rollouts_per_param: number of rollouts per policy parameter set (and init state if specified)
        :param seed: seed value for the random number generators, pass `None` for no seeding
        """
        if not isinstance(num_rollouts_per_param, int):
            raise pyrado.TypeErr(given=num_rollouts_per_param,
                                 expected_type=int)
        if num_rollouts_per_param < 1:
            raise pyrado.ValueErr(given=num_rollouts_per_param,
                                  ge_constraint='1')

        Serializable._init(self, locals())

        # Check environment for domain randomization wrappers (stops after finding the outermost)
        self._dr_wrapper = typed_env(env, DomainRandWrapper)
        if self._dr_wrapper is not None:
            assert isinstance(inner_env(env), SimEnv)
            # Remove them all from the env chain since we sample the domain parameter later explicitly
            env = remove_all_dr_wrappers(env)

        self.env, self.policy = env, policy
        self.num_rollouts_per_param = num_rollouts_per_param

        # Create parallel pool. We use one thread per environment because it's easier.
        self.pool = SamplerPool(num_workers)

        # Set all rngs' seeds
        if seed is not None:
            self.pool.set_seed(seed)

        # Distribute environments. We use pickle to make sure a copy is created for n_envs = 1
        self.pool.invoke_all(_pes_init, pickle.dumps(self.env),
                             pickle.dumps(self.policy))
Example #9
    def __init__(self, env: EnvWrapper, subrtn: Algorithm):
        """
        Constructor

        :param env: same environment as the subroutine runs in. Only used for checking and saving the randomizer.
        :param subrtn: algorithm which performs the policy / value-function optimization
        """
        if not isinstance(subrtn, Algorithm):
            raise pyrado.TypeErr(given=subrtn, expected_type=Algorithm)
        if not typed_env(env, DomainRandWrapper):  # check that there is a DR wrapper in the chain
            raise pyrado.TypeErr(given=env, expected_type=DomainRandWrapper)

        # Call Algorithm's constructor with the subroutine's properties
        super().__init__(subrtn.save_dir, subrtn.max_iter, subrtn.policy, subrtn.logger)

        self._subrtn = subrtn
        self._subrtn.save_name = 'subrtn'

        # Save initial randomizer
        joblib.dump(env.randomizer, osp.join(self.save_dir, 'randomizer.pkl'))
Example #10
    def __init__(self,
                 subroutine: Algorithm,
                 skip_iter: int,
                 epsilon: float,
                 gamma: float = 1.):
        """
        Constructor

        :param subroutine: algorithm which performs the policy / value-function optimization
        :param skip_iter: number of iterations for which all rollouts will be used (see prefix full)
        :param epsilon: quantile of rollouts that will be kept
        :param gamma: discount factor to compute the discounted return, default is 1 (no discount)
        """
        if not isinstance(subroutine, Algorithm):
            raise pyrado.TypeErr(given=subroutine, expected_type=Algorithm)
        if not typed_env(subroutine.sampler.env, DomainRandWrapperLive):  # check for a domain randomization wrapper
            raise pyrado.TypeErr(given=subroutine.sampler.env, expected_type=DomainRandWrapperLive)

        # Call Algorithm's constructor with the subroutine's properties
        super().__init__(subroutine.save_dir, subroutine.max_iter, subroutine.policy, subroutine.logger)

        # Store inputs
        self._subroutine = subroutine
        self.epsilon = epsilon
        self.gamma = gamma
        self.skip_iter = skip_iter

        # Override the subroutine's sampler
        self._subroutine.sampler = CVaRSampler(
                self._subroutine.sampler,
                epsilon=1.,  # keep all rollouts until curr_iter = skip_iter
                gamma=self.gamma,
                min_rollouts=self._subroutine.sampler.min_rollouts,
                min_steps=self._subroutine.sampler.min_steps,
        )

        # Save initial randomizer
        joblib.dump(subroutine.sampler.randomizer, osp.join(self.save_dir, 'randomizer.pkl'))
Example #11
def conditional_actnorm_wrapper(env: Env, ex_dirs: list, idx: int):
    """
    Wrap the environment with an action normalization wrapper if the simulated environment had one.

    :param env: environment to sample from
    :param ex_dirs: list of experiment directories that will be loaded
    :param idx: index of the current directory
    :return: modified environment
    """
    # Get the simulation environment
    env_sim, _, _ = load_experiment(ex_dirs[idx])

    if typed_env(env_sim, ActNormWrapper) is not None:
        env = ActNormWrapper(env)
        print_cbt(
            f'Added an action normalization wrapper for the {idx + 1}-th evaluation policy.',
            'y')
    else:
        env = remove_env(env, ActNormWrapper)
        print_cbt(
            f'Removed the action normalization wrapper for the {idx + 1}-th evaluation policy.',
            'y')
    return env
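A hedged usage sketch: the evaluation loop and `env_real` are assumptions; only `conditional_actnorm_wrapper` and `ex_dirs` come from the snippet above.

# Hypothetical evaluation loop; env_real and ex_dirs are assumed to exist
for idx in range(len(ex_dirs)):
    env_eval = conditional_actnorm_wrapper(env_real, ex_dirs, idx)
    # ... load the idx-th policy and evaluate it in env_eval ...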
Example #12
    def __init__(self,
                 save_dir: str,
                 env: DomainRandWrapperBuffer,
                 subroutine_cand: Algorithm,
                 subroutine_refs: Algorithm,
                 max_iter: int,
                 alpha: float,
                 beta: float,
                 nG: int,
                 nJ: int,
                 ntau: int,
                 nc_init: int,
                 nr_init: int,
                 sequence_cand: callable,
                 sequence_refs: callable,
                 warmstart_cand: bool = False,
                 warmstart_refs: bool = True,
                 cand_policy_param_init: to.Tensor = None,
                 cand_critic_param_init: to.Tensor = None,
                 num_bs_reps: int = 1000,
                 studentized_ci: bool = False,
                 base_seed: int = None,
                 logger=None):
        """
        Constructor

        :param save_dir: directory to save the snapshots i.e. the results in
        :param env: the environment in which the policy operates
        :param subroutine_cand: the algorithm that is called at every iteration of SPOTA to yield a candidate policy
        :param subroutine_refs: the algorithm that is called at every iteration of SPOTA to yield reference policies
        :param max_iter: maximum number of iterations that the SPOTA algorithm runs.
                         Each of these iterations includes multiple iterations of the subroutine.
        :param alpha: confidence level for the upper confidence bound (UCBOG)
        :param beta: optimality gap threshold for training
        :param nG: number of reference solutions
        :param nJ: number of samples for Monte-Carlo approximation of the optimality gap
        :param ntau: number of rollouts per domain parameter set
        :param nc_init: initial number of domains for training the candidate solution
        :param nr_init: initial number of domains for training the reference solutions
        :param sequence_cand: mathematical sequence for the number of domains for training the candidate solution
        :param sequence_refs: mathematical sequence for the number of domains for training the reference solutions
        :param warmstart_cand: flag if the next candidate solution should be initialized with the previous one
        :param warmstart_refs: flag if the reference solutions should be initialized with the current candidate
        :param cand_policy_param_init: initial policy parameter values for the candidate, set None to be random
        :param cand_critic_param_init: initial critic parameter values for the candidate, set None to be random
        :param num_bs_reps: number of replications for the statistical bootstrap
        :param studentized_ci: flag if a student T distribution should be applied for the confidence interval
        :param base_seed: seed added to all other seeds in order to make the experiments distinct but repeatable
        """
        if not typed_env(env, DomainRandWrapperBuffer):  # check that there is a domain randomization wrapper
            raise pyrado.TypeErr(
                msg='There must be a DomainRandWrapperBuffer in the environment chain.')
        if not isinstance(subroutine_cand, Algorithm):
            raise pyrado.TypeErr(given=subroutine_cand,
                                 expected_type=Algorithm)
        if not isinstance(subroutine_refs, Algorithm):
            raise pyrado.TypeErr(given=subroutine_refs,
                                 expected_type=Algorithm)

        # Call Algorithm's constructor without specifying the policy
        super().__init__(save_dir, max_iter, None, logger)

        # Get the randomized environment (recommended to make it the most outer one in the chain)
        self._env_dr = typed_env(env, DomainRandWrapperBuffer)

        # Candidate and reference solutions, and optimality gap
        self.Gn_diffs = None
        self.ucbog = pyrado.inf  # upper confidence bound on the optimality gap
        self._subrtn_cand = subroutine_cand
        self._subrtn_refs = subroutine_refs
        assert id(self._subrtn_cand) != id(self._subrtn_refs)
        assert id(self._subrtn_cand.policy) != id(self._subrtn_refs.policy)
        assert id(self._subrtn_cand.expl_strat) != id(
            self._subrtn_refs.expl_strat)

        # Store the hyper-parameters
        self.alpha = alpha
        self.beta = beta
        self.warmstart_cand = warmstart_cand
        self.warmstart_refs = warmstart_refs
        self.cand_policy_param_init = cand_policy_param_init.detach() if cand_policy_param_init is not None else None
        self.cand_critic_param_init = cand_critic_param_init.detach() if cand_critic_param_init is not None else None
        self.nG = nG
        self.nJ = nJ
        self.ntau = ntau
        self.nc_init = nc_init
        self.nr_init = nr_init
        self.seq_cand = sequence_cand
        self.seq_ref = sequence_refs
        self.num_bs_reps = num_bs_reps
        self.studentized_ci = studentized_ci
        self.base_seed = np.random.randint(low=10000) if base_seed is None else base_seed

        # Save initial environment and randomizer
        joblib.dump(env, osp.join(self.save_dir, 'init_env.pkl'))
        joblib.dump(env.randomizer, osp.join(self.save_dir, 'randomizer.pkl'))
Example #13
def plot_actions(ro: StepSequence, env: Env):
    """
    Plot all action trajectories of the given rollout.

    :param ro: input rollout
    :param env: environment (used for getting the clipped action values)
    """
    if hasattr(ro, "actions"):
        if not isinstance(ro.actions, np.ndarray):
            raise pyrado.TypeErr(given=ro.actions, expected_type=np.ndarray)

        dim_act = ro.actions.shape[1]
        # Use recorded time stamps if possible
        t = getattr(ro, "time", np.arange(0, ro.length + 1))[:-1]

        num_rows, num_cols = num_rows_cols_from_length(dim_act,
                                                       transposed=True)
        fig, axs = plt.subplots(num_rows,
                                num_cols,
                                figsize=(10, 8),
                                tight_layout=True)
        fig.canvas.manager.set_window_title("Actions over Time")
        axs = np.atleast_2d(axs)
        axs = correct_atleast_2d(axs)
        colors = plt.get_cmap("tab20")(np.linspace(0, 1, dim_act))

        act_norm_wrapper = typed_env(env, ActNormWrapper)
        if act_norm_wrapper is not None:
            lb, ub = inner_env(env).act_space.bounds
            act_denorm = lb + (ro.actions + 1.0) * (ub - lb) / 2
            act_clipped = np.array(
                [inner_env(env).limit_act(a) for a in act_denorm])
        else:
            act_denorm = ro.actions
            act_clipped = np.array([env.limit_act(a) for a in ro.actions])

        if dim_act == 1:
            axs[0, 0].plot(t, act_denorm, label="to env")
            axs[0, 0].plot(t, act_clipped, label="clipped", c="k", ls="--")
            axs[0, 0].legend(ncol=2)
            axs[0, 0].set_ylabel(_get_act_label(ro, 0))
        else:
            for idx_a in range(dim_act):
                axs[idx_a // num_cols,
                    idx_a % num_cols].plot(t,
                                           act_denorm[:, idx_a],
                                           label="to env",
                                           c=colors[idx_a])
                axs[idx_a // num_cols,
                    idx_a % num_cols].plot(t,
                                           act_clipped[:, idx_a],
                                           label="clipped",
                                           c="k",
                                           ls="--")
                axs[idx_a // num_cols, idx_a % num_cols].legend(ncol=2)
                axs[idx_a // num_cols,
                    idx_a % num_cols].set_ylabel(_get_act_label(ro, idx_a))

        # Put legends to the right of the plot
        if dim_act < 8:  # otherwise it gets too cluttered
            for a in fig.get_axes():
                a.legend(ncol=2)

        plt.subplots_adjust(hspace=0.2)
Example #14
def wrap_like_other_env(
        env_targ: Union[SimEnv, RealEnv],
        env_src: Union[SimEnv, EnvWrapper],
        use_downsampling: bool = False) -> Union[SimEnv, RealEnv]:
    """
    Wrap a given real environment like its simulated counterpart (except the domain randomization of course).

    :param env_targ: target environment e.g. environment representing the physical device
    :param env_src: source environment e.g. simulation environment used for training
    :param use_downsampling: apply a wrapper that downsamples the actions if the sampling frequencies don't match
    :return: target environment
    """
    if use_downsampling and env_src.dt > env_targ.dt:
        if typed_env(env_targ, DownsamplingWrapper) is None:
            ds_factor = int(env_src.dt / env_targ.dt)
            env_targ = DownsamplingWrapper(env_targ, ds_factor)
            print_cbt(
                f"Wrapped the target environment with a DownsamplingWrapper of factor {ds_factor}.",
                "y")
        else:
            print_cbt(
                "The target environment was already wrapped with a DownsamplingWrapper.",
                "y")

    if typed_env(env_src, ActNormWrapper) is not None:
        if typed_env(env_targ, ActNormWrapper) is None:
            env_targ = ActNormWrapper(env_targ)
            print_cbt("Wrapped the target environment with an ActNormWrapper.",
                      "y")
        else:
            print_cbt(
                "The target environment was already wrapped with an ActNormWrapper.",
                "y")

    if typed_env(env_src, ObsNormWrapper) is not None:
        if typed_env(env_targ, ObsNormWrapper) is None:
            env_targ = ObsNormWrapper(env_targ)
            print_cbt("Wrapped the target environment with an ObsNormWrapper.",
                      "y")
        else:
            print_cbt(
                "The target environment was already wrapped with an ObsNormWrapper.",
                "y")

    if typed_env(env_src, ObsRunningNormWrapper) is not None:
        if typed_env(env_targ, ObsRunningNormWrapper) is None:
            env_targ = ObsRunningNormWrapper(env_targ)
            print_cbt(
                "Wrapped the target environment with an ObsRunningNormWrapper.",
                "y")
        else:
            print_cbt(
                "The target environment was already wrapped with an ObsRunningNormWrapper.",
                "y")

    if typed_env(env_src, ObsPartialWrapper) is not None:
        if typed_env(env_targ, ObsPartialWrapper) is None:
            env_targ = ObsPartialWrapper(env_targ,
                                         mask=typed_env(
                                             env_src,
                                             ObsPartialWrapper).keep_mask,
                                         keep_selected=True)
            print_cbt(
                "Wrapped the target environment with an ObsPartialWrapper.",
                "y")
        else:
            print_cbt(
                "The target environment was already wrapped with an ObsPartialWrapper.",
                "y")

    return env_targ
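Unlike the earlier variant (Example #1), this version first checks whether the target environment already carries the respective wrapper, so repeated calls do not stack duplicates, and downsampling is opt-in. A minimal sketch, assuming `env_real` and `env_sim` exist:

env_real = wrap_like_other_env(env_real, env_sim, use_downsampling=True)
env_real = wrap_like_other_env(env_real, env_sim, use_downsampling=True)  # second call only prints, adds nothing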
Example #15
from pyrado.plotting.gaussian_process import render_singletask_gp
from pyrado.utils.argparser import get_argparser


if __name__ == "__main__":
    # Parse command line arguments
    parser = get_argparser()
    parser.add_argument("--render3D", action="store_true", default=False, help="render the GP in 3D")
    args = parser.parse_args()
    plt.rc("text", usetex=args.use_tex)

    # Get the experiment's directory to load from
    ex_dir = ask_for_experiment(hparam_list=args.show_hparams) if args.dir is None else args.dir

    env_sim = joblib.load(osp.join(ex_dir, "env_sim.pkl"))
    if not typed_env(env_sim, MetaDomainRandWrapper):
        raise pyrado.TypeErr(given_name=env_sim, expected_type=MetaDomainRandWrapper)
    labels_sel_dims = [env_sim.dp_mapping[args.idcs[i]][0] for i in range(len(args.idcs))]

    env_real = joblib.load(osp.join(ex_dir, "env_real.pkl"))
    if isinstance(inner_env(env_real), SimEnv):
        # Use actual ground truth domain param if sim-2-sim setting
        domain_params = env_real.domain_param
    else:
        # Use nominal domain param if sim-2-real setting
        domain_params = inner_env(env_sim).get_nominal_domain_param()
    for dp_name, dp_val in domain_params.items():
        if dp_name in labels_sel_dims[0]:
            gt_val_x = dp_val
        try:
            if dp_name == labels_sel_dims[1]:
Example #16
def evaluate_policy(args, ex_dir):
    """Helper function to evaluate the policy from an experiment in the associated environment."""
    env, policy, _ = load_experiment(ex_dir, args)

    # Create multi-dim evaluation grid
    param_spec = dict()
    param_spec_dim = None

    if isinstance(inner_env(env), BallOnPlateSim):
        param_spec["ball_radius"] = np.linspace(0.02, 0.08, num=2, endpoint=True)
        param_spec["ball_rolling_friction_coefficient"] = np.linspace(0.0295, 0.9, num=2, endpoint=True)

    elif isinstance(inner_env(env), QQubeSwingUpSim):
        eval_num = 200
        # Use nominal values for all other parameters.
        for param, nominal_value in env.get_nominal_domain_param().items():
            param_spec[param] = nominal_value
        # param_spec["gravity_const"] = np.linspace(5.0, 15.0, num=eval_num, endpoint=True)
        param_spec["damping_pend_pole"] = np.linspace(0.0, 0.0001, num=eval_num, endpoint=True)
        param_spec["damping_rot_pole"] = np.linspace(0.0, 0.0006, num=eval_num, endpoint=True)
        param_spec_dim = 2

    elif isinstance(inner_env(env), QBallBalancerSim):
        # param_spec["gravity_const"] = np.linspace(7.91, 11.91, num=11, endpoint=True)
        # param_spec["ball_mass"] = np.linspace(0.003, 0.3, num=11, endpoint=True)
        # param_spec["ball_radius"] = np.linspace(0.01, 0.1, num=11, endpoint=True)
        param_spec["plate_length"] = np.linspace(0.275, 0.275, num=11, endpoint=True)
        param_spec["arm_radius"] = np.linspace(0.0254, 0.0254, num=11, endpoint=True)
        # param_spec["load_inertia"] = np.linspace(5.2822e-5*0.5, 5.2822e-5*1.5, num=11, endpoint=True)
        # param_spec["motor_inertia"] = np.linspace(4.6063e-7*0.5, 4.6063e-7*1.5, num=11, endpoint=True)
        # param_spec["gear_ratio"] = np.linspace(60, 80, num=11, endpoint=True)
        # param_spec["gear_efficiency"] = np.linspace(0.6, 1.0, num=11, endpoint=True)
        # param_spec["motor_efficiency"] = np.linspace(0.49, 0.89, num=11, endpoint=True)
        # param_spec["motor_back_emf"] = np.linspace(0.006, 0.066, num=11, endpoint=True)
        # param_spec["motor_resistance"] = np.linspace(2.6*0.5, 2.6*1.5, num=11, endpoint=True)
        # param_spec["combined_damping"] = np.linspace(0.0, 0.05, num=11, endpoint=True)
        # param_spec["friction_coeff"] = np.linspace(0, 0.015, num=11, endpoint=True)
        # param_spec["voltage_thold_x_pos"] = np.linspace(0.0, 1.0, num=11, endpoint=True)
        # param_spec["voltage_thold_x_neg"] = np.linspace(-1., 0.0, num=11, endpoint=True)
        # param_spec["voltage_thold_y_pos"] = np.linspace(0.0, 1.0, num=11, endpoint=True)
        # param_spec["voltage_thold_y_neg"] = np.linspace(-1.0, 0, num=11, endpoint=True)
        # param_spec["offset_th_x"] = np.linspace(-5/180*np.pi, 5/180*np.pi, num=11, endpoint=True)
        # param_spec["offset_th_y"] = np.linspace(-5/180*np.pi, 5/180*np.pi, num=11, endpoint=True)

    else:
        raise NotImplementedError

    # Always add an action delay wrapper (with 0 delay by default)
    if typed_env(env, ActDelayWrapper) is None:
        env = ActDelayWrapper(env)
    # param_spec['act_delay'] = np.linspace(0, 30, num=11, endpoint=True, dtype=int)

    add_info = "-".join(param_spec.keys())

    # Create multidimensional results grid and ensure right number of rollouts
    param_list = param_grid(param_spec)
    param_list *= args.num_rollouts_per_config

    # Fix initial state (set to None if it should not be fixed)
    init_state = np.array([0.0, 0.0, 0.0, 0.0])

    # Create sampler
    pool = SamplerPool(args.num_workers)
    if args.seed is not None:
        pool.set_seed(args.seed)
        print_cbt(f"Set the random number generators' seed to {args.seed}.", "w")
    else:
        print_cbt("No seed was set", "y")

    # Sample rollouts
    ros = eval_domain_params(pool, env, policy, param_list, init_state)

    # Compute metrics
    lod = []
    for ro in ros:
        d = dict(**ro.rollout_info["domain_param"], ret=ro.undiscounted_return(), len=ro.length)
        # Simply remove the observation noise from the domain parameters
        try:
            d.pop("obs_noise_mean")
            d.pop("obs_noise_std")
        except KeyError:
            pass
        lod.append(d)

    df = pd.DataFrame(lod)
    metrics = dict(
        avg_len=df["len"].mean(),
        avg_ret=df["ret"].mean(),
        median_ret=df["ret"].median(),
        min_ret=df["ret"].min(),
        max_ret=df["ret"].max(),
        std_ret=df["ret"].std(),
    )
    pprint(metrics, indent=4)

    # Create subfolder and save
    timestamp = datetime.datetime.now()
    add_info = timestamp.strftime(pyrado.timestamp_format) + "--" + add_info
    save_dir = osp.join(ex_dir, "eval_domain_grid", add_info)
    os.makedirs(save_dir, exist_ok=True)

    save_dicts_to_yaml(
        {"ex_dir": str(ex_dir)},
        {"varied_params": list(param_spec.keys())},
        {"num_rpp": args.num_rollouts_per_config, "seed": args.seed},
        {"metrics": dict_arraylike_to_float(metrics)},
        save_dir=save_dir,
        file_name="summary",
    )
    pyrado.save(df, f"df_sp_grid_{len(param_spec) if param_spec_dim is None else param_spec_dim}d.pkl", save_dir)
    def __init__(
        self,
        env: Union[SimEnv, EnvWrapper],
        policy: Policy,
        num_init_states_per_domain: int,
        num_domains: int,
        num_workers: int,
        seed: Optional[int] = NO_SEED_PASSED,
    ):
        """
        Constructor

        :param env: environment to sample from
        :param policy: policy used for sampling
        :param num_init_states_per_domain: number of rollouts to cover the variance over initial states
        :param num_domains: number of rollouts due to the variance over domain parameters
        :param num_workers: number of parallel samplers
        :param seed: seed value for the random number generators, pass `None` for no seeding; defaults to the last seed
                     that was set with `pyrado.set_seed`
        """
        if not isinstance(num_init_states_per_domain, int):
            raise pyrado.TypeErr(given=num_init_states_per_domain,
                                 expected_type=int)
        if num_init_states_per_domain < 1:
            raise pyrado.ValueErr(given=num_init_states_per_domain,
                                  ge_constraint="1")
        if not isinstance(num_domains, int):
            raise pyrado.TypeErr(given=num_domains, expected_type=int)
        if num_domains < 1:
            raise pyrado.ValueErr(given=num_domains, ge_constraint="1")

        Serializable._init(self, locals())

        # Check environment for domain randomization wrappers (stops after finding the outermost)
        self._dr_wrapper = typed_env(env, DomainRandWrapper)
        if self._dr_wrapper is not None:
            assert isinstance(inner_env(env), SimEnv)
            # Remove them all from the env chain since we sample the domain parameter later explicitly
            env = remove_all_dr_wrappers(env)

        self.env, self.policy = env, policy
        self.num_init_states_per_domain = num_init_states_per_domain
        self.num_domains = num_domains

        # Set method to spawn if using cuda
        if mp.get_start_method(allow_none=True) != "spawn":
            mp.set_start_method("spawn", force=True)

        # Create parallel pool. We use one thread per environment because it's easier.
        self.pool = SamplerPool(num_workers)

        if seed is NO_SEED_PASSED:
            seed = pyrado.get_base_seed()
        self._seed = seed
        # Initialize with -1 such that we start with the 0-th sample. Incrementing after sampling may cause issues when
        # the sampling crashes and the sample count is not incremented.
        self._sample_count = -1

        # Distribute environments. We use pickle to make sure a copy is created for n_envs = 1
        self.pool.invoke_all(_pes_init, pickle.dumps(self.env),
                             pickle.dumps(self.policy))
Example #18
        prefixes = [
            osp.join(pyrado.EXP_DIR, 'FILL_IN', 'FILL_IN'),
        ]
        exp_names = [
            '',
        ]
        exp_labels = [
            '',
        ]

    else:
        raise pyrado.ValueErr(given=args.env_name, eq_constraint=f'{QBallBalancerSim.name}, {QCartPoleStabSim.name}, '
                                                                 f'{QCartPoleSwingUpSim.name}, or {QQubeSim.name}')

    # Always add an action delay wrapper (with 0 delay by default)
    if typed_env(env, ActDelayWrapper) is None:
        env = ActDelayWrapper(env)
        # param_spec['act_delay'] = np.linspace(0, 60, num=21, endpoint=True, dtype=int)

    if not len(param_spec.keys()) == 1:
        raise pyrado.ValueErr(msg='Do not vary more than one domain parameter for this script! (Check action delay.)')
    varied_param_key = ''.join(param_spec.keys())  # to get a str

    if not (len(prefixes) == len(exp_names) and len(prefixes) == len(exp_labels)):
        raise pyrado.ShapeErr(msg=f'The lengths of prefixes, exp_names, and exp_labels must be equal, '
                                  f'but they are {len(prefixes)}, {len(exp_names)}, and {len(exp_labels)}!')

    # Load the policies
    ex_dirs = [osp.join(p, e) for p, e in zip(prefixes, exp_names)]
    policies = []
    for ex_dir in ex_dirs:
Example #19
def after_rollout_query(
    env: Env, policy: Policy, rollout: StepSequence
) -> Tuple[bool, Optional[np.ndarray], Optional[dict]]:
    """
    Ask the user what to do after a rollout has been animated.

    :param env: environment used for the rollout
    :param policy: policy used for the rollout
    :param rollout: collected data from the rollout
    :return: done flag, initial state, and domain parameters
    """
    # First entry contains the hotkey, the second one the info text
    options = [
        ["C", "continue simulation (with domain randomization)"],
        ["N", "set domain parameters to nominal values, and continue"],
        ["F", "fix the initial state"],
        ["I", "print information about environment (including randomizer), and policy"],
        ["S", "set a domain parameter explicitly"],
        ["P", "plot all observations, actions, and rewards"],
        ["PS [indices]", "plot all states, or selected ones by passing separated integers"],
        ["PO [indices]", "plot all observations, or selected ones by passing separated integers"],
        ["PA", "plot actions"],
        ["PR", "plot rewards"],
        ["PF", "plot features (for linear policy)"],
        ["PPOT", "plot potentials, stimuli, and actions (for potential-based policies)"],
        ["PDT", "plot time deltas (profiling of a real system)"],
        ["E", "exit"],
    ]

    # Ask for user input
    ans = input(tabulate(options, tablefmt="simple") + "\n").lower()

    if ans == "c" or ans == "":
        # We don't have to do anything here since the env will be reset at the beginning of the next rollout
        return False, None, None

    elif ans == "f":
        try:
            if isinstance(inner_env(env), RealEnv):
                raise pyrado.TypeErr(given=inner_env(env), expected_type=SimEnv)
            elif isinstance(inner_env(env), SimEnv):
                # Get the user input
                usr_inp = input(
                    f"Enter the {env.obs_space.flat_dim}-dim initial state "
                    f"(format: each dim separated by a whitespace):\n"
                )
                state = list(map(float, usr_inp.split()))
                if isinstance(state, list):
                    state = np.array(state)
                    if state.shape != env.obs_space.shape:
                        raise pyrado.ShapeErr(given=state, expected_match=env.obs_space)
                else:
                    raise pyrado.TypeErr(given=state, expected_type=list)
                return False, state, {}
        except (pyrado.TypeErr, pyrado.ShapeErr):
            return after_rollout_query(env, policy, rollout)

    elif ans == "n":
        # Get nominal domain parameters
        if isinstance(inner_env(env), SimEnv):
            dp_nom = inner_env(env).get_nominal_domain_param()
            if typed_env(env, ActDelayWrapper) is not None:
                # There is an ActDelayWrapper in the env chain
                dp_nom["act_delay"] = 0
        else:
            dp_nom = None
        return False, None, dp_nom

    elif ans == "i":
        # Print the information and return to the query
        print(env)
        if hasattr(env, "randomizer"):
            print(env.randomizer)
        print(policy)
        return after_rollout_query(env, policy, rollout)

    elif ans == "p":
        plot_observations_actions_rewards(rollout)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif ans == "pa":
        plot_actions(rollout, env)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif ans == "po":
        plot_observations(rollout)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif "po" in ans and any(char.isdigit() for char in ans):
        idcs = [int(s) for s in ans.split() if s.isdigit()]
        plot_observations(rollout, idcs_sel=idcs)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif ans == "ps":
        plot_states(rollout)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif "ps" in ans and any(char.isdigit() for char in ans):
        idcs = [int(s) for s in ans.split() if s.isdigit()]
        plot_states(rollout, idcs_sel=idcs)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif ans == "pf":
        plot_features(rollout, policy)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif ans == "pp":
        draw_policy_params(policy, env.spec, annotate=False)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif ans == "pr":
        plot_rewards(rollout)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif ans == "pdt":
        draw_dts(rollout.dts_policy, rollout.dts_step, rollout.dts_remainder)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif ans == "ppot":
        plot_potentials(rollout)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif ans == "s":
        if isinstance(env, SimEnv):
            dp = env.get_nominal_domain_param()
            for k, v in dp.items():
                dp[k] = [v]  # cast float to list of one element to make it iterable for tabulate
            print("These are the nominal domain parameters:")
            print(tabulate(dp, headers="keys", tablefmt="simple"))

        # Get the user input
        strs = input("Enter one new domain parameter\n(format: key whitespace value):\n")
        try:
            param = dict(line.split() for line in strs.splitlines())
            # Cast the values of the param dict from str to float
            for k, v in param.items():
                param[k] = float(v)
            return False, None, param
        except (ValueError, KeyError):
            print_cbt(f"Could not parse {strs} into a dict.", "r")
            return after_rollout_query(env, policy, rollout)

    elif ans == "e":
        env.close()
        return True, None, {}  # breaks the outer while loop

    else:
        return after_rollout_query(env, policy, rollout)  # recursion
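A sketch of the interactive loop this query is meant for; the `rollout` call and its `reset_kwargs` argument are assumptions based on pyrado's sampling helpers and are not shown in the snippet above.

done, init_state, domain_param = False, None, None
while not done:
    ro = rollout(env, policy, eval=True,
                 reset_kwargs=dict(init_state=init_state, domain_param=domain_param))
    done, init_state, domain_param = after_rollout_query(env, policy, ro)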
Example #20
def load_experiment(
        ex_dir: str,
        args: Any = None) -> Tuple[Union[SimEnv, EnvWrapper], Policy, Optional[dict]]:
    """
    Load the (training) environment and the policy.
    This helper function first tries to read the hyper-parameters yaml-file in the experiment's directory to infer
    which entities should be loaded. If no file was found, we fall back to some heuristic and hope for the best.

    :param ex_dir: experiment's parent directory
    :param args: arguments from the argument parser
    :return: environment, policy, and optional output (e.g. valuefcn)
    """
    hparams_file_name = 'hyperparams.yaml'
    env, policy, kwout = None, None, dict()

    try:
        hparams = load_dict_from_yaml(osp.join(ex_dir, hparams_file_name))
        kwout['hparams'] = hparams

        # Check which algorithm has been used for training, i.e. what can be loaded, by crawling the hyper-parameters
        # First check meta algorithms so they don't get masked by their subroutines
        if SPOTA.name in hparams.get('algo_name', ''):
            # Environment
            env = joblib.load(osp.join(ex_dir, 'init_env.pkl'))
            typed_env(env, DomainRandWrapperBuffer).fill_buffer(100)
            print_cbt(
                f"Loaded {osp.join(ex_dir, 'init_env.pkl')} and filled it with 100 random instances.",
                'g')
            # Policy
            if args.iter == -1:
                policy = to.load(osp.join(ex_dir, 'final_policy_cand.pt'))
                print_cbt(f"Loaded {osp.join(ex_dir, 'final_policy_cand.pt')}",
                          'g')
            else:
                policy = to.load(
                    osp.join(ex_dir, f'iter_{args.iter}_policy_cand.pt'))
                print_cbt(
                    f"Loaded {osp.join(ex_dir, f'iter_{args.iter}_policy_cand.pt')}",
                    'g')
            # Value function (optional)
            if any([
                    a.name in hparams.get('subroutine_name', '')
                    for a in [PPO, PPO2, A2C]
            ]):
                try:
                    kwout['valuefcn'] = to.load(
                        osp.join(ex_dir, 'final_valuefcn.pt'))
                    print_cbt(
                        f"Loaded {osp.join(ex_dir, 'final_valuefcn.pt')}", 'g')
                except FileNotFoundError:
                    kwout['valuefcn'] = to.load(osp.join(
                        ex_dir, 'valuefcn.pt'))
                    print_cbt(f"Loaded {osp.join(ex_dir, 'valuefcn.pt')}", 'g')

        elif BayRn.name in hparams.get('algo_name', ''):
            # Environment
            env = joblib.load(osp.join(ex_dir, 'env_sim.pkl'))
            print_cbt(f"Loaded {osp.join(ex_dir, 'env_sim.pkl')}.", 'g')
            if hasattr(env, 'randomizer'):
                last_cand = to.load(osp.join(ex_dir, 'candidates.pt'))[-1, :]
                env.adapt_randomizer(last_cand.numpy())
                print_cbt(f'Loaded the domain randomizer\n{env.randomizer}',
                          'w')
            # Policy
            if args.iter == -1:
                policy = to.load(osp.join(ex_dir, 'policy.pt'))
                print_cbt(f"Loaded {osp.join(ex_dir, 'policy.pt')}", 'g')
            else:
                policy = to.load(osp.join(ex_dir, f'iter_{args.iter}.pt'))
                print_cbt(f"Loaded {osp.join(ex_dir, f'iter_{args.iter}.pt')}",
                          'g')
            # Value function (optional)
            if any([
                    a.name in hparams.get('subroutine_name', '')
                    for a in [PPO, PPO2, A2C]
            ]):
                try:
                    kwout['valuefcn'] = to.load(
                        osp.join(ex_dir, 'final_valuefcn.pt'))
                    print_cbt(
                        f"Loaded {osp.join(ex_dir, 'final_valuefcn.pt')}", 'g')
                except FileNotFoundError:
                    kwout['valuefcn'] = to.load(osp.join(
                        ex_dir, 'valuefcn.pt'))
                    print_cbt(f"Loaded {osp.join(ex_dir, 'valuefcn.pt')}", 'g')

        elif EPOpt.name in hparams.get('algo_name', ''):
            # Environment
            env = joblib.load(osp.join(ex_dir, 'env.pkl'))
            # Policy
            policy = to.load(osp.join(ex_dir, 'policy.pt'))
            print_cbt(f"Loaded {osp.join(ex_dir, 'policy.pt')}", 'g')

        elif any(
            [a.name in hparams.get('algo_name', '')
             for a in [PPO, PPO2, A2C]]):
            # Environment
            env = joblib.load(osp.join(ex_dir, 'env.pkl'))
            # Policy
            policy = to.load(osp.join(ex_dir, 'policy.pt'))
            print_cbt(f"Loaded {osp.join(ex_dir, 'policy.pt')}", 'g')
            # Value function
            kwout['valuefcn'] = to.load(osp.join(ex_dir, 'valuefcn.pt'))
            print_cbt(f"Loaded {osp.join(ex_dir, 'valuefcn.pt')}", 'g')

        elif SAC.name in hparams.get('algo_name', ''):
            # Environment
            env = joblib.load(osp.join(ex_dir, 'env.pkl'))
            # Policy
            policy = to.load(osp.join(ex_dir, 'policy.pt'))
            print_cbt(f"Loaded {osp.join(ex_dir, 'policy.pt')}", 'g')
            # Target value functions
            kwout['target1'] = to.load(osp.join(ex_dir, 'target1.pt'))
            kwout['target2'] = to.load(osp.join(ex_dir, 'target2.pt'))
            print_cbt(
                f"Loaded {osp.join(ex_dir, 'target1.pt')} and {osp.join(ex_dir, 'target2.pt')}",
                'g')

        elif any([
                a.name in hparams.get('algo_name', '')
                for a in [HC, PEPG, NES, REPS, PoWER, CEM]
        ]):
            # Environment
            env = joblib.load(osp.join(ex_dir, 'env.pkl'))
            # Policy
            policy = to.load(osp.join(ex_dir, 'policy.pt'))
            print_cbt(f"Loaded {osp.join(ex_dir, 'policy.pt')}", 'g')

        else:
            raise KeyError(
                'No matching algorithm name found during loading the experiment. '
                'Check for the algo_name field in the yaml-file.')

    except (FileNotFoundError, KeyError):
        print_cbt(
            f'Did not find {hparams_file_name} in {ex_dir} or could not crawl the loaded hyper-parameters.',
            'y',
            bright=True)

        try:
            # Results of a standard algorithm
            env = joblib.load(osp.join(ex_dir, 'env.pkl'))
            policy = to.load(osp.join(ex_dir, 'policy.pt'))
            print_cbt(f"Loaded {osp.join(ex_dir, 'policy.pt')}", 'g')
        except FileNotFoundError:
            try:
                # Results of SPOTA
                env = joblib.load(osp.join(ex_dir, 'init_env.pkl'))
                typed_env(env, DomainRandWrapperBuffer).fill_buffer(100)
                print_cbt(
                    f"Loaded {osp.join(ex_dir, 'init_env.pkl')} and filled it with 100 random instances.",
                    'g')
            except FileNotFoundError:
                # Results of BayRn
                env = joblib.load(osp.join(ex_dir, 'env_sim.pkl'))

            try:
                # Results of SPOTA
                if args.iter == -1:
                    policy = to.load(osp.join(ex_dir, 'final_policy_cand.pt'))
                    print_cbt(f'Loaded final_policy_cand.pt', 'g')
                else:
                    policy = to.load(
                        osp.join(ex_dir, f'iter_{args.iter}_policy_cand.pt'))
                    print_cbt(f'Loaded iter_{args.iter}_policy_cand.pt', 'g')
            except FileNotFoundError:
                # Results of BayRn
                if args.iter == -1:
                    policy = to.load(osp.join(ex_dir, 'final_policy.pt'))
                    print_cbt(f'Loaded final_policy.pt', 'g')
                else:
                    policy = to.load(
                        osp.join(ex_dir, f'iter_{args.iter}_policy.pt'))
                    print_cbt(f'Loaded iter_{args.iter}_policy.pt', 'g')

    # Check if the return types are correct
    if not isinstance(env, (SimEnv, EnvWrapper)):
        raise pyrado.TypeErr(given=env, expected_type=[SimEnv, EnvWrapper])
    if not isinstance(policy, Policy):
        raise pyrado.TypeErr(given=policy, expected_type=Policy)

    return env, policy, kwout
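A hedged usage sketch combining this loader with the argument-parsing helpers that appear in Example #15 (`get_argparser`, `ask_for_experiment`); the exact flags and the contents of `kwout` are assumptions.

args = get_argparser().parse_args()
ex_dir = ask_for_experiment() if args.dir is None else args.dir
env, policy, kwout = load_experiment(ex_dir, args)
valuefcn = kwout.get('valuefcn', None)  # only present for actor-critic subroutines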
Example #21
    def __init__(self,
                 save_dir: str,
                 env: Env,
                 policy: TwoHeadedPolicy,
                 qfcn_1: Policy,
                 qfcn_2: Policy,
                 memory_size: int,
                 gamma: float,
                 max_iter: int,
                 num_batch_updates: Optional[int] = None,
                 tau: float = 0.995,
                 ent_coeff_init: float = 0.2,
                 learn_ent_coeff: bool = True,
                 target_update_intvl: int = 1,
                 num_init_memory_steps: int = None,
                 standardize_rew: bool = True,
                 rew_scale: Union[int, float] = 1.,
                 min_rollouts: int = None,
                 min_steps: int = None,
                 batch_size: int = 256,
                 num_workers: int = 4,
                 max_grad_norm: float = 5.,
                 lr: float = 3e-4,
                 lr_scheduler=None,
                 lr_scheduler_hparam: Optional[dict] = None,
                 logger: StepLogger = None):
        r"""
        Constructor

        :param save_dir: directory to save the snapshots i.e. the results in
        :param env: the environment which the policy operates
        :param policy: policy to be updated
        :param qfcn_1: state-action value function $Q(s,a)$, the associated target Q-function is created as a
                        re-initialized copy of this one
        :param qfcn_2: state-action value function $Q(s,a)$, the associated target Q-function is created as a
                        re-initialized copy of this one
        :param memory_size: number of transitions in the replay memory buffer, e.g. 1000000
        :param gamma: temporal discount factor for the state values
        :param max_iter: number of iterations (policy updates)
        :param num_batch_updates: number of (batched) gradient updates per algorithm step
        :param tau: interpolation factor for averaging the target networks, used for the soft a.k.a. polyak update,
                    between 0 and 1
        :param ent_coeff_init: initial weighting factor of the entropy term in the loss function
        :param learn_ent_coeff: adapt the weighting factor of the entropy term
        :param target_update_intvl: number of iterations that pass before updating the target network
        :param num_init_memory_steps: number of samples used to initially fill the replay buffer with, pass `None` to
                                      fill the buffer completely
        :param standardize_rew: if `True`, the rewards are standardized to be $\sim N(0,1)$
        :param rew_scale: scaling factor for the rewards, defaults to no scaling
        :param min_rollouts: minimum number of rollouts sampled per policy update batch
        :param min_steps: minimum number of state transitions sampled per policy update batch
        :param batch_size: number of samples per policy update batch
        :param num_workers: number of environments for parallel sampling
        :param max_grad_norm: maximum L2 norm of the gradients for clipping, set to `None` to disable gradient clipping
        :param lr: (initial) learning rate for the optimizer which can be modified by the scheduler.
                   By default, the learning rate is constant.
        :param lr_scheduler: learning rate scheduler type for the policy and the Q-functions that does one step
                             per `update()` call
        :param lr_scheduler_hparam: hyper-parameters for the learning rate scheduler
        :param logger: logger for every step of the algorithm, if `None` the default logger will be created
        """
        if typed_env(env, ActNormWrapper) is None:
            raise pyrado.TypeErr(msg='SAC requires an environment wrapped by an ActNormWrapper!')
        if not isinstance(qfcn_1, Policy):
            raise pyrado.TypeErr(given=qfcn_1, expected_type=Policy)
        if not isinstance(qfcn_2, Policy):
            raise pyrado.TypeErr(given=qfcn_2, expected_type=Policy)

        # Call ValueBased's constructor
        super().__init__(save_dir, env, policy, memory_size, gamma, max_iter, num_batch_updates, target_update_intvl,
                         num_init_memory_steps, min_rollouts, min_steps, batch_size, num_workers, max_grad_norm, logger)

        self.qfcn_1 = qfcn_1
        self.qfcn_2 = qfcn_2
        self.qfcn_targ_1 = deepcopy(self.qfcn_1).eval()  # will not be trained using an optimizer
        self.qfcn_targ_2 = deepcopy(self.qfcn_2).eval()  # will not be trained using an optimizer
        self.tau = tau
        self.learn_ent_coeff = learn_ent_coeff
        self.standardize_rew = standardize_rew
        self.rew_scale = rew_scale

        # Create sampler for exploration during training
        self._expl_strat = SACExplStrat(self._policy)
        self.sampler_trn = ParallelRolloutSampler(
            self._env, self._expl_strat,
            num_workers=num_workers if min_steps != 1 else 1,
            min_steps=min_steps,  # in [2] this would be 1
            min_rollouts=min_rollouts,  # in [2] this would be None
        )

        # Optimizers for the policy and the two Q-functions
        self._optim_policy = to.optim.Adam([{'params': self._policy.parameters()}], lr=lr, eps=1e-5)
        self._optim_qfcns = to.optim.Adam([{'params': self.qfcn_1.parameters()},
                                           {'params': self.qfcn_2.parameters()}], lr=lr, eps=1e-5)

        # Automatic entropy tuning
        log_ent_coeff_init = to.log(to.tensor(ent_coeff_init, device=policy.device, dtype=to.get_default_dtype()))
        if learn_ent_coeff:
            self._log_ent_coeff = nn.Parameter(log_ent_coeff_init, requires_grad=True)
            self._ent_coeff_optim = to.optim.Adam([{'params': self._log_ent_coeff}], lr=lr, eps=1e-5)
            self.target_entropy = -to.prod(to.tensor(env.act_space.shape))
        else:
            self._log_ent_coeff = log_ent_coeff_init

        # Learning rate scheduler
        self._lr_scheduler_policy = lr_scheduler
        self._lr_scheduler_hparam = lr_scheduler_hparam
        if lr_scheduler is not None:
            self._lr_scheduler_policy = lr_scheduler(self._optim_policy, **lr_scheduler_hparam)
            self._lr_scheduler_qfcns = lr_scheduler(self._optim_qfcns, **lr_scheduler_hparam)
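
# Illustrative add-on (not part of the original source): the `tau` hyper-parameter documented above controls the
# soft a.k.a. polyak update of the target Q-networks. A minimal, self-contained sketch of that update rule
# (assumed form, independent of the pyrado internals) could look like this:
import torch as to
import torch.nn as nn

def soft_update_(targ: nn.Module, live: nn.Module, tau: float = 0.995):
    """Keep a fraction `tau` of the target parameters and mix in `1 - tau` of the live ones."""
    with to.no_grad():
        for p_targ, p_live in zip(targ.parameters(), live.parameters()):
            p_targ.data.mul_(tau).add_(p_live.data, alpha=1 - tau)

# Tiny usage example with throw-away networks of identical architecture
qfcn, qfcn_targ = nn.Linear(4, 1), nn.Linear(4, 1)
soft_update_(qfcn_targ, qfcn, tau=0.995)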
Example #22
0
def after_rollout_query(env: Env, policy: Policy, rollout: StepSequence) -> tuple:
    """
    Ask the user what to do after a rollout has been animated.

    :param env: environment used for the rollout
    :param policy: policy used for the rollout
    :param rollout: collected data from the rollout
    :return: done flag, initial state, and domain parameters
    """
    # First entry contains the hotkey, the second the info text
    options = [
        ['C', 'continue simulation (with domain randomization)'],
        ['N', 'set domain parameters to nominal values, and continue'],
        ['F', 'fix the initial state'],
        ['I', 'print information about environment (including randomizer), and policy'],
        ['S', 'set a domain parameter explicitly'],
        ['P', 'plot all observations, actions, and rewards'],
        ['PA', 'plot actions'],
        ['PR', 'plot rewards'],
        ['PO', 'plot all observations'],
        ['PO idcs', 'plot selected observations (integers separated by whitespaces)'],
        ['PF', 'plot features (for linear policy)'],
        ['PP', 'plot policy parameters (not suggested for many parameters)'],
        ['PDT', 'plot time deltas (profiling of a real system)'],
        ['PPOT', 'plot potentials, stimuli, and actions'],
        ['E', 'exit']
    ]

    # Ask for user input
    ans = input(tabulate(options, tablefmt='simple') + '\n').lower()

    if ans == 'c' or ans == '':
        # We don't have to do anything here since the env will be reset at the beginning of the next rollout
        return False, None, None

    elif ans == 'f':
        try:
            if isinstance(inner_env(env), RealEnv):
                raise pyrado.TypeErr(given=inner_env(env), expected_type=SimEnv)
            elif isinstance(inner_env(env), SimEnv):
                # Get the user input
                usr_input = input(f'Enter the {env.obs_space.flat_dim}-dim initial state '
                                  f'(format: each dim separated by a whitespace):\n')
                state = list(map(float, usr_input.split()))
                if isinstance(state, list):
                    state = np.array(state)
                    if state.shape != env.obs_space.shape:
                        raise pyrado.ShapeErr(given=state, expected_match=env.obs_space)
                else:
                    raise pyrado.TypeErr(given=state, expected_type=list)
                return False, state, {}
        except (pyrado.TypeErr, pyrado.ShapeErr):
            return after_rollout_query(env, policy, rollout)

    elif ans == 'n':
        # Get nominal domain parameters
        if isinstance(inner_env(env), SimEnv):
            dp_nom = inner_env(env).get_nominal_domain_param()
            if typed_env(env, ActDelayWrapper) is not None:
                # There is an ActDelayWrapper in the env chain
                dp_nom['act_delay'] = 0
        else:
            dp_nom = None
        return False, None, dp_nom

    elif ans == 'i':
        # Print the information and return to the query
        print(env)
        if hasattr(env, 'randomizer'):
            print(env.randomizer)
        print(policy)
        return after_rollout_query(env, policy, rollout)

    elif ans == 'p':
        draw_observations_actions_rewards(rollout)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif ans == 'pa':
        draw_actions(rollout, env)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif ans == 'po':
        draw_observations(rollout)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif 'po' in ans and any(char.isdigit() for char in ans):
        idcs = [int(s) for s in ans.split() if s.isdigit()]
        draw_observations(rollout, idcs_sel=idcs)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif ans == 'pf':
        draw_features(rollout, policy)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif ans == 'pp':
        draw_policy_params(policy, env.spec, annotate=False)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif ans == 'pr':
        draw_rewards(rollout)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif ans == 'pdt':
        draw_dts(rollout.dts_policy, rollout.dts_step, rollout.dts_remainder)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif ans == 'ppot':
        draw_potentials(rollout)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif ans == 's':
        if isinstance(env, SimEnv):
            dp = env.get_nominal_domain_param()
            for k, v in dp.items():
                dp[k] = [v]  # cast float to list of one element to make it iterable for tabulate
            print('These are the nominal domain parameters:')
            print(tabulate(dp, headers="keys", tablefmt='simple'))

        # Get the user input
        strs = input('Enter one new domain parameter\n(format: key whitespace value):\n')
        try:
            param = dict(line.split() for line in strs.splitlines())
            # Cast the values of the param dict from str to float
            for k, v in param.items():
                param[k] = float(v)
            return False, None, param
        except (ValueError, KeyError):
            print_cbt(f'Could not parse {strs} into a dict.', 'r')
            return after_rollout_query(env, policy, rollout)

    elif ans == 'e':
        env.close()
        return True, None, {}  # breaks the outer while loop

    else:
        return after_rollout_query(env, policy, rollout)  # recursion
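
# Illustrative add-on (not part of the original source): a hedged usage sketch of `after_rollout_query`, i.e. a
# simple evaluation loop that keeps replaying the policy until the user presses 'E'. The names `env`, `policy`,
# and `rollout` are assumed to come from the surrounding evaluation script, and passing the returned initial
# state and domain parameters via `reset_kwargs` in exactly this form is an assumption as well.
done, init_state, domain_param = False, None, None
while not done:
    ro = rollout(env, policy, eval=True,
                 reset_kwargs=dict(init_state=init_state, domain_param=domain_param))
    done, init_state, domain_param = after_rollout_query(env, policy, ro)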
Example #23
0
def load_experiment(
    ex_dir: str,
    args: Any = None
) -> Tuple[Optional[Union[SimEnv, EnvWrapper]], Optional[Policy],
           Optional[dict]]:
    """
    Load the (training) environment and the policy.
    This helper function first tries to read the hyper-parameters yaml-file in the experiment's directory to infer
    which entities should be loaded. If no file was found, we fall back to some heuristic and hope for the best.

    :param ex_dir: experiment's parent directory
    :param args: arguments from the argument parser, pass `None` to fall back to the values from the default argparser
    :return: environment, policy, and optional output (e.g. valuefcn)
    """
    env, policy, extra = None, None, dict()

    if args is None:
        # Fall back to default arguments. By passing [], we ignore the command line arguments
        args = get_argparser().parse_args([])

    # Hyper-parameters
    extra["hparams"] = load_hyperparameters(ex_dir)

    # Algorithm specific
    algo = Algorithm.load_snapshot(load_dir=ex_dir, load_name="algo")

    if algo.name == "spota":
        # Environment
        env = pyrado.load("env.pkl", ex_dir)
        if getattr(env, "randomizer", None) is not None:
            if not isinstance(env, DomainRandWrapperBuffer):
                raise pyrado.TypeErr(given=env,
                                     expected_type=DomainRandWrapperBuffer)
            typed_env(env, DomainRandWrapperBuffer).fill_buffer(10)
            print_cbt(
                f"Loaded the domain randomizer\n{env.randomizer}\nand filled it with 10 random instances.",
                "w")
        else:
            print_cbt("Loaded environment has no randomizer, or it is None.",
                      "r")
        # Policy
        policy = pyrado.load(f"{args.policy_name}.pt",
                             ex_dir,
                             obj=algo.subroutine_cand.policy,
                             verbose=True)
        # Extra (value function)
        if isinstance(algo.subroutine_cand, ActorCritic):
            extra["vfcn"] = pyrado.load(f"{args.vfcn_name}.pt",
                                        ex_dir,
                                        obj=algo.subroutine_cand.critic.vfcn,
                                        verbose=True)

    elif algo.name == "bayrn":
        # Environment
        env = pyrado.load("env_sim.pkl", ex_dir)
        if hasattr(env, "randomizer"):
            last_cand = to.load(osp.join(ex_dir, "candidates.pt"))[-1, :]
            env.adapt_randomizer(last_cand.numpy())
            print_cbt(f"Loaded the domain randomizer\n{env.randomizer}", "w")
        else:
            print_cbt("Loaded environment has no randomizer, or it is None.",
                      "r")
        # Policy
        policy = pyrado.load(f"{args.policy_name}.pt",
                             ex_dir,
                             obj=algo.policy,
                             verbose=True)
        # Extra (value function)
        if isinstance(algo.subroutine, ActorCritic):
            extra["vfcn"] = pyrado.load(f"{args.vfcn_name}.pt",
                                        ex_dir,
                                        obj=algo.subroutine.critic.vfcn,
                                        verbose=True)

    elif algo.name == "simopt":
        # Environment
        env = pyrado.load("env_sim.pkl", ex_dir)
        if getattr(env, "randomizer", None) is not None:
            last_cand = to.load(osp.join(ex_dir, "candidates.pt"))[-1, :]
            env.adapt_randomizer(last_cand.numpy())
            print_cbt(f"Loaded the domain randomizer\n{env.randomizer}", "w")
        else:
            print_cbt("Loaded environment has no randomizer, or it is None.",
                      "r")
        # Policy
        policy = pyrado.load(f"{args.policy_name}.pt",
                             ex_dir,
                             obj=algo.subroutine_policy.policy,
                             verbose=True)
        # Extra (domain parameter distribution policy)
        extra["ddp_policy"] = pyrado.load("ddp_policy.pt",
                                          ex_dir,
                                          obj=algo.subroutine_distr.policy,
                                          verbose=True)

    elif algo.name in ["epopt", "udr"]:
        # Environment
        env = pyrado.load("env_sim.pkl", ex_dir)
        if getattr(env, "randomizer", None) is not None:
            if not isinstance(env, DomainRandWrapperLive):
                raise pyrado.TypeErr(given=env,
                                     expected_type=DomainRandWrapperLive)
            print_cbt(f"Loaded the domain randomizer\n{env.randomizer}", "w")
        else:
            print_cbt("Loaded environment has no randomizer, or it is None.",
                      "y")
        # Policy
        policy = pyrado.load(f"{args.policy_name}.pt",
                             ex_dir,
                             obj=algo.policy,
                             verbose=True)
        # Extra (value function)
        if isinstance(algo.subroutine, ActorCritic):
            extra["vfcn"] = pyrado.load(f"{args.vfcn_name}.pt",
                                        ex_dir,
                                        obj=algo.subroutine.critic.vfcn,
                                        verbose=True)

    elif algo.name in ["bayessim", "npdr"]:
        # Environment
        env = pyrado.load("env_sim.pkl", ex_dir)
        if getattr(env, "randomizer", None) is not None:
            if not isinstance(env, DomainRandWrapperBuffer):
                raise pyrado.TypeErr(given=env,
                                     expected_type=DomainRandWrapperBuffer)
            typed_env(env, DomainRandWrapperBuffer).fill_buffer(10)
            print_cbt(
                f"Loaded the domain randomizer\n{env.randomizer}\nand filled it with 10 random instances.",
                "w")
        else:
            print_cbt("Loaded environment has no randomizer, or it is None.",
                      "y")
            env = remove_all_dr_wrappers(env, verbose=True)
        # Policy
        policy = pyrado.load(f"{args.policy_name}.pt",
                             ex_dir,
                             obj=algo.policy,
                             verbose=True)
        # Extra (prior, posterior, data)
        extra["prior"] = pyrado.load("prior.pt", ex_dir, verbose=True)
        # By default load the latest posterior (latest iteration and the last round)
        try:
            extra["posterior"] = algo.load_posterior(ex_dir,
                                                     args.iter,
                                                     args.round,
                                                     obj=None,
                                                     verbose=True)
            # Load the complete data or the data of the given iteration
            prefix = "" if args.iter == -1 else f"iter_{args.iter}"
            extra["data_real"] = pyrado.load(f"data_real.pt",
                                             ex_dir,
                                             prefix=prefix,
                                             verbose=True)
        except FileNotFoundError:
            pass

    elif algo.name in ["a2c", "ppo", "ppo2"]:
        # Environment
        env = pyrado.load("env.pkl", ex_dir)
        # Policy
        policy = pyrado.load(f"{args.policy_name}.pt",
                             ex_dir,
                             obj=algo.policy,
                             verbose=True)
        # Extra (value function)
        extra["vfcn"] = pyrado.load(f"{args.vfcn_name}.pt",
                                    ex_dir,
                                    obj=algo.critic.vfcn,
                                    verbose=True)

    elif algo.name in ["hc", "pepg", "power", "cem", "reps", "nes"]:
        # Environment
        env = pyrado.load("env.pkl", ex_dir)
        # Policy
        policy = pyrado.load(f"{args.policy_name}.pt",
                             ex_dir,
                             obj=algo.policy,
                             verbose=True)

    elif algo.name in ["dql", "sac"]:
        # Environment
        env = pyrado.load("env.pkl", ex_dir)
        # Policy
        policy = pyrado.load(f"{args.policy_name}.pt",
                             ex_dir,
                             obj=algo.policy,
                             verbose=True)
        # Target value functions
        if algo.name == "dql":
            extra["qfcn_target"] = pyrado.load("qfcn_target.pt",
                                               ex_dir,
                                               obj=algo.qfcn_targ,
                                               verbose=True)
        elif algo.name == "sac":
            extra["qfcn_target1"] = pyrado.load("qfcn_target1.pt",
                                                ex_dir,
                                                obj=algo.qfcn_targ_1,
                                                verbose=True)
            extra["qfcn_target2"] = pyrado.load("qfcn_target2.pt",
                                                ex_dir,
                                                obj=algo.qfcn_targ_2,
                                                verbose=True)
        else:
            raise NotImplementedError

    elif algo.name == "svpg":
        # Environment
        env = pyrado.load("env.pkl", ex_dir)
        # Policy
        policy = pyrado.load(f"{args.policy_name}.pt",
                             ex_dir,
                             obj=algo.policy,
                             verbose=True)
        # Extra (particles)
        for idx, p in enumerate(algo.particles):
            extra[f"particle{idx}"] = pyrado.load(f"particle_{idx}.pt",
                                                  ex_dir,
                                                  obj=algo.particles[idx],
                                                  verbose=True)

    elif algo.name == "tspred":
        # Dataset
        extra["dataset"] = to.load(osp.join(ex_dir, "dataset.pt"))
        # Policy
        policy = pyrado.load(f"{args.policy_name}.pt",
                             ex_dir,
                             obj=algo.policy,
                             verbose=True)

    elif algo.name == "sprl":
        # Environment
        env = pyrado.load("env.pkl", ex_dir)
        print_cbt(f"Loaded {osp.join(ex_dir, 'env.pkl')}.", "g")
        # Policy
        policy = pyrado.load(f"{args.policy_name}.pt", ex_dir, obj=algo.policy)
        print_cbt(f"Loaded {osp.join(ex_dir, f'{args.policy_name}.pt')}", "g")
        # Extra (value function)
        if isinstance(algo._subroutine, ActorCritic):
            extra["vfcn"] = pyrado.load(f"{args.vfcn_name}.pt",
                                        ex_dir,
                                        obj=algo._subroutine.critic.vfcn,
                                        verbose=True)

    elif algo.name == "pddr":
        # Environment
        env = pyrado.load("env.pkl", ex_dir)
        # Policy
        policy = pyrado.load(f"{args.policy_name}.pt",
                             ex_dir,
                             obj=algo.policy,
                             verbose=True)
        # Teachers
        extra["teacher_policies"] = algo.teacher_policies
        extra["teacher_envs"] = algo.teacher_envs
        extra["teacher_expl_strats"] = algo.teacher_expl_strats
        extra["teacher_critics"] = algo.teacher_critics
        extra["teacher_ex_dirs"] = algo.teacher_ex_dirs

    else:
        raise pyrado.TypeErr(
            msg="No matching algorithm name found while loading the experiment!")

    # Check if the return types are correct. They can be None, too.
    if env is not None and not isinstance(env, (SimEnv, EnvWrapper)):
        raise pyrado.TypeErr(given=env, expected_type=[SimEnv, EnvWrapper])
    if policy is not None and not isinstance(policy, Policy):
        raise pyrado.TypeErr(given=policy, expected_type=Policy)
    if extra is not None and not isinstance(extra, dict):
        raise pyrado.TypeErr(given=extra, expected_type=dict)

    return env, policy, extra
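
# Illustrative add-on (not part of the original source): a hedged usage sketch of `load_experiment`. The experiment
# directory is a placeholder, and `rollout` as well as `RenderMode` are assumed to be importable as in the other
# snippets of this file.
ex_dir = "/path/to/experiment"  # placeholder, replace with a real experiment directory
env, policy, extra = load_experiment(ex_dir)
ro = rollout(env, policy, eval=True, render_mode=RenderMode())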
Example #24
0
    def __init__(self,
                 save_dir: str,
                 env: Env,
                 policy: TwoHeadedPolicy,
                 q_fcn_1: Policy,
                 q_fcn_2: Policy,
                 memory_size: int,
                 gamma: float,
                 max_iter: int,
                 num_batch_updates: int,
                 tau: float = 0.995,
                 alpha_init: float = 0.2,
                 learn_alpha: bool = True,
                 target_update_intvl: int = 1,
                 standardize_rew: bool = True,
                 batch_size: int = 500,
                 min_rollouts: int = None,
                 min_steps: int = None,
                 num_sampler_envs: int = 4,
                 max_grad_norm: float = 5.,
                 lr: float = 3e-4,
                 lr_scheduler=None,
                 lr_scheduler_hparam: [dict, None] = None,
                 logger: StepLogger = None):
        """
        Constructor

        :param save_dir: directory to save the snapshots, i.e. the results, in
        :param env: the environment in which the policy operates
        :param policy: policy to be updated
        :param q_fcn_1: state-action value function $Q(s,a)$, the associated target Q-function is created from a
                        re-initialized copy of this one
        :param q_fcn_2: state-action value function $Q(s,a)$, the associated target Q-function is created from a
                        re-initialized copy of this one
        :param memory_size: number of transitions in the replay memory buffer, e.g. 1000000
        :param gamma: temporal discount factor for the state values
        :param max_iter: number of iterations (policy updates)
        :param num_batch_updates: number of batch updates per algorithm step
        :param tau: interpolation factor for averaging the target networks, used for the soft a.k.a. polyak update,
                    between 0 and 1
        :param alpha_init: initial weighting factor of the entropy term in the loss function
        :param learn_alpha: adapt the weighting factor of the entropy term
        :param target_update_intvl: number of iterations that pass before updating the target network
        :param standardize_rew: bool to flag if the rewards should be standardized
        :param batch_size: number of samples per policy update batch
        :param min_rollouts: minimum number of rollouts sampled per policy update batch
        :param min_steps: minimum number of state transitions sampled per policy update batch
        :param num_sampler_envs: number of environments for parallel sampling
        :param max_grad_norm: maximum L2 norm of the gradients for clipping, set to `None` to disable gradient clipping
        :param lr: (initial) learning rate for the optimizer which can be modified by the scheduler.
                   By default, the learning rate is constant.
        :param lr_scheduler: learning rate scheduler type for the policy and the Q-functions that does one step
                             per `update()` call
        :param lr_scheduler_hparam: hyper-parameters for the learning rate scheduler
        :param logger: logger for every step of the algorithm, if `None` the default logger will be created
        """
        if not isinstance(env, Env):
            raise pyrado.TypeErr(given=env, expected_type=Env)
        if typed_env(env, ActNormWrapper) is None:
            raise pyrado.TypeErr(
                msg='SAC requires an environment wrapped by an ActNormWrapper!'
            )
        if not isinstance(q_fcn_1, Policy):
            raise pyrado.TypeErr(given=q_fcn_1, expected_type=Policy)
        if not isinstance(q_fcn_2, Policy):
            raise pyrado.TypeErr(given=q_fcn_2, expected_type=Policy)

        if logger is None:
            # Create logger that only logs every 100 steps of the algorithm
            logger = StepLogger(print_interval=100)
            logger.printers.append(ConsolePrinter())
            logger.printers.append(
                CSVPrinter(osp.join(save_dir, 'progress.csv')))

        # Call Algorithm's constructor
        super().__init__(save_dir, max_iter, policy, logger)

        # Store the inputs
        self._env = env
        self.q_fcn_1 = q_fcn_1
        self.q_fcn_2 = q_fcn_2
        self.q_targ_1 = deepcopy(self.q_fcn_1)
        self.q_targ_2 = deepcopy(self.q_fcn_2)
        self.q_targ_1.eval()
        self.q_targ_2.eval()
        self.gamma = gamma
        self.tau = tau
        self.learn_alpha = learn_alpha
        self.target_update_intvl = target_update_intvl
        self.standardize_rew = standardize_rew
        self.num_batch_updates = num_batch_updates
        self.batch_size = batch_size
        self.max_grad_norm = max_grad_norm

        # Initialize
        self._memory = ReplayMemory(memory_size)
        if policy.is_recurrent:
            init_expl_policy = RecurrentDummyPolicy(env.spec,
                                                    policy.hidden_size)
        else:
            init_expl_policy = DummyPolicy(env.spec)
        self.sampler_init = ParallelSampler(
            env,
            init_expl_policy,  # samples uniformly random from the action space
            num_envs=num_sampler_envs,
            min_steps=memory_size,
        )
        self._expl_strat = SACExplStrat(
            self._policy,
            std_init=1.)  # std_init will be overwritten by 2nd policy head
        self.sampler = ParallelSampler(
            env,
            self._expl_strat,
            num_envs=1,
            min_steps=min_steps,  # in [2] this would be 1
            min_rollouts=min_rollouts  # in [2] this would be None
        )
        self.sampler_eval = ParallelSampler(env,
                                            self._policy,
                                            num_envs=num_sampler_envs,
                                            min_steps=100 * env.max_steps,
                                            min_rollouts=None)
        self._optim_policy = to.optim.Adam([{
            'params': self._policy.parameters()
        }],
                                           lr=lr)
        self._optim_q_fcn_1 = to.optim.Adam(
            [{
                'params': self.q_fcn_1.parameters()
            }], lr=lr)
        self._optim_q_fcn_2 = to.optim.Adam(
            [{
                'params': self.q_fcn_2.parameters()
            }], lr=lr)
        log_alpha_init = to.log(
            to.tensor(alpha_init, dtype=to.get_default_dtype()))
        if learn_alpha:
            # Automatic entropy tuning
            self._log_alpha = nn.Parameter(log_alpha_init, requires_grad=True)
            self._alpha_optim = to.optim.Adam([{
                'params': self._log_alpha
            }],
                                              lr=lr)
            self.target_entropy = -to.prod(to.tensor(env.act_space.shape))
        else:
            self._log_alpha = log_alpha_init

        self._lr_scheduler_policy = lr_scheduler
        self._lr_scheduler_hparam = lr_scheduler_hparam
        if lr_scheduler is not None:
            self._lr_scheduler_policy = lr_scheduler(self._optim_policy,
                                                     **lr_scheduler_hparam)
            self._lr_scheduler_q_fcn_1 = lr_scheduler(self._optim_q_fcn_1,
                                                      **lr_scheduler_hparam)
            self._lr_scheduler_q_fcn_2 = lr_scheduler(self._optim_q_fcn_2,
                                                      **lr_scheduler_hparam)
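
# Illustrative add-on (not part of the original source): a hedged sketch of the automatic entropy tuning set up in
# the `learn_alpha` branch above (assumed loss form, mirroring the common SAC formulation). `log_alpha` is trained
# so that the policy entropy approaches `target_entropy = -prod(act_space.shape)`.
import torch as to
import torch.nn as nn

act_dim = 2  # placeholder action dimensionality
target_entropy = -float(act_dim)
log_alpha = nn.Parameter(to.log(to.tensor(0.2)))
alpha_optim = to.optim.Adam([log_alpha], lr=3e-4)

log_probs = to.randn(256)  # stand-in for the log-probabilities of sampled actions
alpha_loss = -(log_alpha * (log_probs + target_entropy).detach()).mean()
alpha_optim.zero_grad()
alpha_loss.backward()
alpha_optim.step()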
Example #25
0
def test_combination(env: SimEnv):
    pyrado.set_seed(0)
    env.max_steps = 20

    randomizer = create_default_randomizer(env)
    env_r = DomainRandWrapperBuffer(env, randomizer)
    env_r.fill_buffer(num_domains=3)

    dp_before = []
    dp_after = []
    for i in range(4):
        dp_before.append(env_r.domain_param)
        rollout(env_r,
                DummyPolicy(env_r.spec),
                eval=True,
                seed=0,
                render_mode=RenderMode())
        dp_after.append(env_r.domain_param)
        assert dp_after[i] != dp_before[i]
    assert dp_after[0] == dp_after[3]

    env_rn = ActNormWrapper(env)
    elb = {"x_dot": -213.0, "theta_dot": -42.0}
    eub = {"x_dot": 213.0, "theta_dot": 42.0, "x": 0.123}
    env_rn = ObsNormWrapper(env_rn, explicit_lb=elb, explicit_ub=eub)
    alb, aub = env_rn.act_space.bounds
    assert all(alb == -1)
    assert all(aub == 1)
    olb, oub = env_rn.obs_space.bounds
    assert all(olb == -1)
    assert all(oub == 1)

    ro_r = rollout(env_r,
                   DummyPolicy(env_r.spec),
                   eval=True,
                   seed=0,
                   render_mode=RenderMode())
    ro_rn = rollout(env_rn,
                    DummyPolicy(env_rn.spec),
                    eval=True,
                    seed=0,
                    render_mode=RenderMode())
    assert np.allclose(env_rn._process_obs(ro_r.observations),
                       ro_rn.observations)

    env_rnp = ObsPartialWrapper(
        env_rn, idcs=[env.obs_space.labels[2], env.obs_space.labels[3]])
    ro_rnp = rollout(env_rnp,
                     DummyPolicy(env_rnp.spec),
                     eval=True,
                     seed=0,
                     render_mode=RenderMode())

    env_rnpa = GaussianActNoiseWrapper(
        env_rnp,
        noise_mean=0.5 * np.ones(env_rnp.act_space.shape),
        noise_std=0.1 * np.ones(env_rnp.act_space.shape))
    ro_rnpa = rollout(env_rnpa,
                      DummyPolicy(env_rnpa.spec),
                      eval=True,
                      seed=0,
                      render_mode=RenderMode())
    assert not np.allclose(
        ro_rnp.observations,
        ro_rnpa.observations)  # the action noise changed the rollout

    env_rnpd = ActDelayWrapper(env_rnp, delay=3)
    ro_rnpd = rollout(env_rnpd,
                      DummyPolicy(env_rnpd.spec),
                      eval=True,
                      seed=0,
                      render_mode=RenderMode())
    assert np.allclose(ro_rnp.actions, ro_rnpd.actions)
    assert not np.allclose(ro_rnp.observations, ro_rnpd.observations)

    assert type(inner_env(env_rnpd)) == type(env)
    assert typed_env(env_rnpd, ObsPartialWrapper) is not None
    assert isinstance(env_rnpd, ActDelayWrapper)
    env_rnpdr = remove_env(env_rnpd, ActDelayWrapper)
    assert not isinstance(env_rnpdr, ActDelayWrapper)
Example #26
0
    def __init__(
        self,
        env: DomainRandWrapper,
        subroutine: Algorithm,
        kl_constraints_ub: float,
        max_iter: int,
        performance_lower_bound: float,
        std_lower_bound: float = 0.2,
        kl_threshold: float = 0.1,
        optimize_mean: bool = True,
        optimize_cov: bool = True,
        max_subrtn_retries: int = 1,
    ):
        """
        Constructor

        :param env: environment wrapped in a DomainRandWrapper
        :param subroutine: algorithm which performs the policy/value-function optimization, which
                           must expose its sampler
        :param kl_constraints_ub: upper bound for the KL-divergence
        :param max_iter: maximal number of iterations of the SPRL algorithm (not of the subroutine)
        :param performance_lower_bound: lower bound for the performance SPRL tries to stay above
                                        during distribution updates
        :param std_lower_bound: clipping value for the standard deviation, necessary when using
                                very small target variances
        :param kl_threshold: threshold for the KL-divergence until which std_lower_bound is enforced
        :param optimize_mean: whether the mean should be changed or considered fixed
        :param optimize_cov: whether the (co-)variance should be changed or considered fixed
        :param max_subrtn_retries: how often a failed (median performance < 30 % of performance_lower_bound)
                                   training attempt of the subroutine should be reattempted
        """
        if not isinstance(subroutine, Algorithm):
            raise pyrado.TypeErr(given=subroutine, expected_type=Algorithm)
        if not hasattr(subroutine, "sampler"):
            raise AttributeError(
                "The subroutine must have a sampler attribute!")
        if not typed_env(env, DomainRandWrapper):
            raise pyrado.TypeErr(given=env, expected_type=DomainRandWrapper)

        # Call Algorithm's constructor with the subroutine's properties
        super().__init__(subroutine.save_dir, max_iter, subroutine.policy,
                         subroutine.logger)

        # Wrap the sampler of the subroutine with a rollout-saving wrapper
        self._subroutine = subroutine
        self._subroutine.sampler = RolloutSavingWrapper(subroutine.sampler)
        self._subroutine.save_name = self._subroutine.name

        self._env = env

        # Properties for the variance bound and kl constraint
        self._kl_constraints_ub = kl_constraints_ub
        self._std_lower_bound = std_lower_bound
        self._kl_threshold = kl_threshold

        # Properties of the performance constraint
        self._performance_lower_bound = performance_lower_bound
        self._performance_lower_bound_reached = False

        self._optimize_mean = optimize_mean
        self._optimize_cov = optimize_cov

        self._max_subrtn_retries = max_subrtn_retries

        self._spl_parameters = [
            param for param in env.randomizer.domain_params
            if isinstance(param, SelfPacedDomainParam)
        ]
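
# Illustrative add-on (not part of the original source): a hedged illustration of the KL constraint named by
# `kl_constraints_ub` above. A candidate update of a (here one-dimensional, purely exemplary) domain parameter
# distribution would only be acceptable while its KL divergence to the previous distribution stays below the
# bound. This is not the SPRL update itself.
import torch as to
from torch.distributions import Normal, kl_divergence

kl_constraints_ub = 0.05  # placeholder bound
prev_distr = Normal(loc=to.tensor(0.0), scale=to.tensor(0.20))
cand_distr = Normal(loc=to.tensor(0.05), scale=to.tensor(0.25))
kl = kl_divergence(cand_distr, prev_distr)
print(f"KL divergence {kl.item():.4f}, within bound: {bool(kl <= kl_constraints_ub)}")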
Example #27
0
    def __init__(self,
                 save_dir: str,
                 env_sim: MetaDomainRandWrapper,
                 env_real: [RealEnv, EnvWrapper],
                 subrtn: Algorithm,
                 ddp_space: BoxSpace,
                 max_iter: int,
                 acq_fc: str,
                 acq_restarts: int,
                 acq_samples: int,
                 acq_param: dict = None,
                 num_init_cand: int = 5,
                 mc_estimator: bool = True,
                 num_eval_rollouts_real: int = 5,
                 num_eval_rollouts_sim: int = 50,
                 thold_succ: float = pyrado.inf,
                 thold_succ_subrtn: float = -pyrado.inf,
                 warmstart: bool = True,
                 policy_param_init: Optional[to.Tensor] = None,
                 valuefcn_param_init: Optional[to.Tensor] = None,
                 subrtn_snapshot_mode: str = 'best',
                 logger: Optional[StepLogger] = None):
        """
        Constructor

        .. note::
            If you want to continue an experiment, use the `load_dir` argument for the `train` call. If you want to
            initialize each of the policies with pre-trained policy parameters, use `policy_param_init`.

        :param save_dir: directory to save the snapshots, i.e. the results, in
        :param env_sim: randomized simulation environment a.k.a. source domain
        :param env_real: real-world environment a.k.a. target domain
        :param subrtn: algorithm which performs the policy / value-function optimization
        :param ddp_space: space holding the boundaries for the domain distribution parameters
        :param max_iter: maximum number of iterations
        :param acq_fc: Acquisition Function
                       'UCB': Upper Confidence Bound (default $\beta = 0.1$)
                       'EI': Expected Improvement
                       'PI': Probability of Improvement
        :param acq_restarts: number of restarts for optimizing the acquisition function
        :param acq_samples: number of initial samples for optimizing the acquisition function
        :param acq_param: hyper-parameter for the acquisition function, e.g. $\beta$ for UCB
        :param num_init_cand: number of initial policies to train, ignored if `init_dir` is provided
        :param mc_estimator: estimate the return with a sample average (`True`) or a lower confidence
                                     bound (`False`) obtained from bootstrapping
        :param num_eval_rollouts_real: number of rollouts in the target domain to estimate the return
        :param num_eval_rollouts_sim: number of rollouts in simulation to estimate the return after training
        :param thold_succ: success threshold on the real system's return for BayRn, stop the algorithm if exceeded
        :param thold_succ_subrtn: success threshold on the simulated system's return for the subroutine, repeat the
                                  subroutine until the threshold is exceeded or for a given number of iterations
        :param warmstart: initialize the policy parameters with the ones of the previous iteration. This option has no
                          effect for initial policies and can be overruled by passing init policy params explicitly.
        :param policy_param_init: initial policy parameter values for the subroutine, set `None` to be random
        :param valuefcn_param_init: initial value function parameter values for the subroutine, set `None` to be random
        :param subrtn_snapshot_mode: snapshot mode for saving during training of the subroutine
        :param logger: logger for every step of the algorithm, if `None` the default logger will be created
        """
        if typed_env(env_sim, MetaDomainRandWrapper) is None:
            raise pyrado.TypeErr(given=env_sim, expected_type=MetaDomainRandWrapper)
        if not isinstance(subrtn, Algorithm):
            raise pyrado.TypeErr(given=subrtn, expected_type=Algorithm)
        if not isinstance(ddp_space, BoxSpace):
            raise pyrado.TypeErr(given=ddp_space, expected_type=BoxSpace)
        if num_init_cand < 1:
            raise pyrado.ValueErr(given=num_init_cand, ge_constraint='1')

        # Call InterruptableAlgorithm's constructor without specifying the policy
        super().__init__(num_checkpoints=2, init_checkpoint=-2, save_dir=save_dir, max_iter=max_iter,
                         policy=subrtn.policy, logger=logger)

        self._env_sim = env_sim
        self._env_real = env_real
        self._subrtn = subrtn
        self._subrtn.save_name = 'subrtn'
        self.ddp_space = ddp_space
        self.ddp_projector = UnitCubeProjector(to.from_numpy(self.ddp_space.bound_lo),
                                               to.from_numpy(self.ddp_space.bound_up))
        self.cands = None  # called x in the context of GPs
        self.cands_values = None  # called y in the context of GPs
        self.argmax_cand = to.Tensor()
        self.acq_fcn_type = acq_fc.upper()
        self.acq_restarts = acq_restarts
        self.acq_samples = acq_samples
        self.acq_param = acq_param
        self.num_init_cand = num_init_cand
        self.mc_estimator = mc_estimator
        self.policy_param_init = policy_param_init
        self.valuefcn_param_init = valuefcn_param_init.detach() if valuefcn_param_init is not None else None
        self.warmstart = warmstart
        self.num_eval_rollouts_real = num_eval_rollouts_real
        self.num_eval_rollouts_sim = num_eval_rollouts_sim
        self.subrtn_snapshot_mode = subrtn_snapshot_mode
        self.thold_succ = to.tensor([thold_succ])
        self.thold_succ_subrtn = to.tensor([thold_succ_subrtn])
        self.max_subrtn_rep = 3  # number of tries to exceed thold_succ_subrtn during training in simulation
        self.curr_cand_value = -pyrado.inf  # for the stopping criterion

        if self.policy_param_init is not None:
            if to.is_tensor(self.policy_param_init):
                self.policy_param_init = self.policy_param_init.detach()  # detach() is not in-place, so re-assign
            else:
                self.policy_param_init = to.tensor(self.policy_param_init)

        # Save initial environments and the domain distribution parameter space
        self.save_snapshot(meta_info=None)
        pyrado.save(self.ddp_space, 'ddp_space', 'pkl', self.save_dir)
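
# Illustrative add-on (not part of the original source): a hedged sketch of the normalization idea behind the
# `UnitCubeProjector` used above. The domain distribution parameters are mapped into the unit cube spanned by
# `ddp_space.bound_lo` and `ddp_space.bound_up` before being handed to the GP, and mapped back afterwards.
# This re-implements only the idea, not the pyrado class.
import torch as to

def to_unit_cube(x: to.Tensor, lo: to.Tensor, up: to.Tensor) -> to.Tensor:
    """Scale `x` element-wise from the box [lo, up] to [0, 1]."""
    return (x - lo) / (up - lo)

def from_unit_cube(y: to.Tensor, lo: to.Tensor, up: to.Tensor) -> to.Tensor:
    """Inverse mapping from the unit cube back to the original box."""
    return lo + y * (up - lo)

lo, up = to.tensor([0.1, 10.0]), to.tensor([0.5, 50.0])
x = to.tensor([0.3, 20.0])
assert to.allclose(from_unit_cube(to_unit_cube(x, lo, up), lo, up), x)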
Example #28
0
def load_experiment(
        ex_dir: str,
        args: Any = None) -> (Union[SimEnv, EnvWrapper], Policy, dict):
    """
    Load the (training) environment and the policy.
    This helper function first tries to read the hyper-parameters yaml-file in the experiment's directory to infer
    which entities should be loaded. If no file was found, we fall back to some heuristic and hope for the best.

    :param ex_dir: experiment's parent directory
    :param args: arguments from the argument parser, pass `None` to fall back to the values from the default argparser
    :return: environment, policy, and optional output (e.g. valuefcn)
    """
    env, policy, extra = None, None, dict()

    if args is None:
        # Fall back to default arguments. By passing [], we ignore the command line arguments
        args = get_argparser().parse_args([])

    # Hyper-parameters
    hparams_file_name = 'hyperparams.yaml'
    try:
        hparams = load_dict_from_yaml(osp.join(ex_dir, hparams_file_name))
        extra['hparams'] = hparams
    except (pyrado.PathErr, FileNotFoundError, KeyError):
        print_cbt(
            f'Did not find {hparams_file_name} in {ex_dir} or could not crawl the loaded hyper-parameters.',
            'y',
            bright=True)

    # Algorithm specific
    algo = Algorithm.load_snapshot(load_dir=ex_dir, load_name='algo')
    if isinstance(algo, BayRn):
        # Environment
        env = pyrado.load(None, 'env_sim', 'pkl', ex_dir, None)
        print_cbt(f"Loaded {osp.join(ex_dir, 'env_sim.pkl')}.", 'g')
        if hasattr(env, 'randomizer'):
            last_cand = to.load(osp.join(ex_dir, 'candidates.pt'))[-1, :]
            env.adapt_randomizer(last_cand.numpy())
            print_cbt(f'Loaded the domain randomizer\n{env.randomizer}', 'w')
        else:
            print_cbt('Loaded environment has no randomizer.', 'r')
        # Policy
        policy = pyrado.load(algo.policy, f'{args.policy_name}', 'pt', ex_dir,
                             None)
        print_cbt(f"Loaded {osp.join(ex_dir, f'{args.policy_name}.pt')}", 'g')
        # Extra (value function)
        if isinstance(algo.subroutine, ActorCritic):
            extra['vfcn'] = pyrado.load(algo.subroutine.critic.vfcn,
                                        f'{args.vfcn_name}', 'pt', ex_dir,
                                        None)
            print_cbt(f"Loaded {osp.join(ex_dir, f'{args.vfcn_name}.pt')}",
                      'g')

    elif isinstance(algo, SPOTA):
        # Environment
        env = pyrado.load(None, 'env', 'pkl', ex_dir, None)
        print_cbt(f"Loaded {osp.join(ex_dir, 'env.pkl')}.", 'g')
        if hasattr(env, 'randomizer'):
            if not isinstance(env.randomizer, DomainRandWrapperBuffer):
                raise pyrado.TypeErr(given=env.randomizer,
                                     expected_type=DomainRandWrapperBuffer)
            typed_env(env, DomainRandWrapperBuffer).fill_buffer(100)
            print_cbt(
                f"Loaded {osp.join(ex_dir, 'env.pkl')} and filled it with 100 random instances.",
                'g')
        else:
            print_cbt('Loaded environment has no randomizer.', 'r')
        # Policy
        policy = pyrado.load(algo.subroutine_cand.policy,
                             f'{args.policy_name}', 'pt', ex_dir, None)
        print_cbt(f"Loaded {osp.join(ex_dir, f'{args.policy_name}.pt')}", 'g')
        # Extra (value function)
        if isinstance(algo.subroutine_cand, ActorCritic):
            extra['vfcn'] = pyrado.load(algo.subroutine_cand.critic.vfcn,
                                        f'{args.vfcn_name}', 'pt', ex_dir,
                                        None)
            print_cbt(f"Loaded {osp.join(ex_dir, f'{args.vfcn_name}.pt')}",
                      'g')

    elif isinstance(algo, SimOpt):
        # Environment
        env = pyrado.load(None, 'env_sim', 'pkl', ex_dir, None)
        print_cbt(f"Loaded {osp.join(ex_dir, 'env_sim.pkl')}.", 'g')
        if hasattr(env, 'randomizer'):
            last_cand = to.load(osp.join(ex_dir, 'candidates.pt'))[-1, :]
            env.adapt_randomizer(last_cand.numpy())
            print_cbt(f'Loaded the domain randomizer\n{env.randomizer}', 'w')
        else:
            print_cbt('Loaded environment has no randomizer.', 'r')
        # Policy
        policy = pyrado.load(algo.subroutine_policy.policy,
                             f'{args.policy_name}', 'pt', ex_dir, None)
        print_cbt(f"Loaded {osp.join(ex_dir, f'{args.policy_name}.pt')}", 'g')
        # Extra (domain parameter distribution policy)
        extra['ddp_policy'] = pyrado.load(algo.subroutine_distr.policy,
                                          'ddp_policy', 'pt', ex_dir, None)

    elif isinstance(algo, (EPOpt, UDR)):
        # Environment
        env = pyrado.load(None, 'env_sim', 'pkl', ex_dir, None)
        if hasattr(env, 'randomizer'):
            if not isinstance(env.randomizer, DomainRandWrapperLive):
                raise pyrado.TypeErr(given=env.randomizer,
                                     expected_type=DomainRandWrapperLive)
            print_cbt(
                f"Loaded {osp.join(ex_dir, 'env.pkl')} with DomainRandWrapperLive randomizer.",
                'g')
        else:
            print_cbt('Loaded environment has no randomizer.', 'y')
        # Policy
        policy = pyrado.load(algo.policy, f'{args.policy_name}', 'pt', ex_dir,
                             None)
        print_cbt(f"Loaded {osp.join(ex_dir, f'{args.policy_name}.pt')}", 'g')
        # Extra (value function)
        if isinstance(algo.subroutine, ActorCritic):
            extra['vfcn'] = pyrado.load(algo.subroutine.critic.vfcn,
                                        f'{args.vfcn_name}', 'pt', ex_dir,
                                        None)
            print_cbt(f"Loaded {osp.join(ex_dir, f'{args.vfcn_name}.pt')}",
                      'g')

    elif isinstance(algo, ActorCritic):
        # Environment
        env = pyrado.load(None, 'env', 'pkl', ex_dir, None)
        # Policy
        policy = pyrado.load(algo.policy, f'{args.policy_name}', 'pt', ex_dir,
                             None)
        print_cbt(f"Loaded {osp.join(ex_dir, f'{args.policy_name}.pt')}", 'g')
        # Extra (value function)
        extra['vfcn'] = pyrado.load(algo.critic.vfcn, f'{args.vfcn_name}',
                                    'pt', ex_dir, None)
        print_cbt(f"Loaded {osp.join(ex_dir, f'{args.vfcn_name}.pt')}", 'g')

    elif isinstance(algo, ParameterExploring):
        # Environment
        env = pyrado.load(None, 'env', 'pkl', ex_dir, None)
        # Policy
        policy = pyrado.load(algo.policy, f'{args.policy_name}', 'pt', ex_dir,
                             None)
        print_cbt(f"Loaded {osp.join(ex_dir, f'{args.policy_name}.pt')}", 'g')

    elif isinstance(algo, ValueBased):
        # Environment
        env = pyrado.load(None, 'env', 'pkl', ex_dir, None)
        # Policy
        policy = pyrado.load(algo.policy, f'{args.policy_name}', 'pt', ex_dir,
                             None)
        print_cbt(f"Loaded {osp.join(ex_dir, f'{args.policy_name}.pt')}", 'g')
        # Target value functions
        if isinstance(algo, DQL):
            extra['qfcn_target'] = pyrado.load(algo.qfcn_targ, 'qfcn_target',
                                               'pt', ex_dir, None)
            print_cbt(f"Loaded {osp.join(ex_dir, 'qfcn_target.pt')}", 'g')
        elif isinstance(algo, SAC):
            extra['qfcn_target1'] = pyrado.load(algo.qfcn_targ_1,
                                                'qfcn_target1', 'pt', ex_dir,
                                                None)
            extra['qfcn_target2'] = pyrado.load(algo.qfcn_targ_2,
                                                'qfcn_target2', 'pt', ex_dir,
                                                None)
            print_cbt(
                f"Loaded {osp.join(ex_dir, 'qfcn_target1.pt')} and {osp.join(ex_dir, 'qfcn_target2.pt')}",
                'g')
        else:
            raise NotImplementedError

    elif isinstance(algo, SVPG):
        # Environment
        env = pyrado.load(None, 'env', 'pkl', ex_dir, None)
        # Policy
        policy = pyrado.load(algo.policy, f'{args.policy_name}', 'pt', ex_dir,
                             None)
        print_cbt(f"Loaded {osp.join(ex_dir, f'{args.policy_name}.pt')}", 'g')
        # Extra (particles)
        for idx, p in enumerate(algo.particles):
            extra[f'particle{idx}'] = pyrado.load(algo.particles[idx],
                                                  f'particle_{idx}', 'pt',
                                                  ex_dir, None)

    elif isinstance(algo, TSPred):
        # Dataset
        extra['dataset'] = to.load(osp.join(ex_dir, 'dataset.pt'))
        # Policy
        policy = pyrado.load(algo.policy, f'{args.policy_name}', 'pt', ex_dir,
                             None)

    else:
        raise pyrado.TypeErr(
            msg='No matching algorithm name found while loading the experiment!')

    # Check if the return types are correct. They can be None, too.
    if env is not None and not isinstance(env, (SimEnv, EnvWrapper)):
        raise pyrado.TypeErr(given=env, expected_type=[SimEnv, EnvWrapper])
    if policy is not None and not isinstance(policy, Policy):
        raise pyrado.TypeErr(given=policy, expected_type=Policy)
    if extra is not None and not isinstance(extra, dict):
        raise pyrado.TypeErr(given=extra, expected_type=dict)

    return env, policy, extra