コード例 #1
0
def remove_all_dr_wrappers(env: SimEnv, verbose: bool = False):
    """
    Go through the environment chain and remove all wrappers of type `DomainRandWrapper` (and subclasses).

    :param env: env chain with domain randomization wrappers
    :param verbose: choose if status messages should be printed
    :return: env chain without domain randomization wrappers
    :raises RuntimeError: if `remove_env` raises; the original exception is attached as the cause
    """
    # Remove one wrapper per iteration until no domain randomization wrapper is left in the chain
    while any(isinstance(subenv, DomainRandWrapper) for subenv in all_envs(env)):
        if verbose:
            print_cbt(
                'Found domain randomization wrapper, trying to remove it.',
                'y',
                bright=True)

        try:
            # Guard only the removal itself, so that a failure while printing is not misreported below
            env = remove_env(env, DomainRandWrapper)
        except Exception as err:
            # Chain explicitly to keep the root cause visible in the traceback
            raise RuntimeError(
                'Could not remove the domain randomization wrapper!') from err

        if verbose:
            print_cbt('Removed a domain randomization wrapper.',
                      'g',
                      bright=True)
    return env
コード例 #2
0
def remove_all_dr_wrappers(env: Env, verbose: bool = False):
    """
    Go through the environment chain and remove all wrappers of type `DomainRandWrapper` (and subclasses).

    :param env: env chain with domain randomization wrappers
    :param verbose: choose if status messages should be printed
    :return: env chain without domain randomization wrappers
    """
    while True:
        # Find the first domain randomization wrapper in the chain. The status message reports the type of this
        # wrapper instead of the type of `env`, since `env` is the top of the chain and not necessarily the
        # domain randomization wrapper itself.
        dr_wrapper = next(
            (subenv for subenv in all_envs(env) if isinstance(subenv, DomainRandWrapper)),
            None,
        )
        if dr_wrapper is None:
            # No domain randomization wrapper left
            return env

        if verbose:
            with completion_context(
                f"Found domain randomization wrapper of type {type(dr_wrapper).__name__}. Removing it now",
                color="y",
                bright=True,
            ):
                env = remove_env(env, DomainRandWrapper)
        else:
            env = remove_env(env, DomainRandWrapper)
コード例 #3
0
def test_combination():
    """Stack several environment wrappers on a `QCartPoleSwingUpSim` and check their combined behavior."""

    def _do_rollout(wrapped_env):
        # All rollouts in this test use a fresh dummy policy, evaluation mode, the same seed, and no rendering
        return rollout(wrapped_env, DummyPolicy(wrapped_env.spec), eval=True, seed=0, render_mode=RenderMode())

    env = QCartPoleSwingUpSim(dt=1 / 50.0, max_steps=20)

    # Domain randomization with a buffer of 3 pre-sampled domains
    env_r = DomainRandWrapperBuffer(env, create_default_randomizer(env))
    env_r.fill_buffer(num_domains=3)

    params_after = []
    for _ in range(4):
        params_before = env_r.domain_param
        _do_rollout(env_r)
        params_after.append(env_r.domain_param)
        # Every rollout is expected to change the domain parameters
        assert params_after[-1] != params_before
    # With 3 buffered domains, the 4th rollout is expected to end up with the same parameters as the 1st
    assert params_after[0] == params_after[3]

    # Action and observation normalization (with some explicitly given observation bounds)
    elb = {'x_dot': -213., 'theta_dot': -42.}
    eub = {'x_dot': 213., 'theta_dot': 42., 'x': 0.123}
    env_rn = ObsNormWrapper(ActNormWrapper(env), explicit_lb=elb, explicit_ub=eub)
    act_lo, act_up = env_rn.act_space.bounds
    assert all(act_lo == -1)
    assert all(act_up == 1)
    obs_lo, obs_up = env_rn.obs_space.bounds
    assert all(obs_lo == -1)
    assert all(obs_up == 1)

    # Processing the non-normalized observations by hand must yield the normalized rollout's observations
    ro_r = _do_rollout(env_r)
    ro_rn = _do_rollout(env_rn)
    assert np.allclose(env_rn._process_obs(ro_r.observations), ro_rn.observations)

    # Partial observations
    env_rnp = ObsPartialWrapper(env_rn, idcs=['x_dot', 'cos_theta'])
    ro_rnp = _do_rollout(env_rnp)

    # Gaussian action noise: same recorded actions, different observations
    act_shape_ones = np.ones(env_rnp.act_space.shape)
    env_rnpa = GaussianActNoiseWrapper(env_rnp,
                                       noise_mean=0.5*act_shape_ones,
                                       noise_std=0.1*np.ones(env_rnp.act_space.shape))
    ro_rnpa = _do_rollout(env_rnpa)
    assert np.allclose(ro_rnp.actions, ro_rnpa.actions)
    assert not np.allclose(ro_rnp.observations, ro_rnpa.observations)

    # Action delay: same recorded actions, different observations
    env_rnpd = ActDelayWrapper(env_rnp, delay=3)
    ro_rnpd = _do_rollout(env_rnpd)
    assert np.allclose(ro_rnp.actions, ro_rnpd.actions)
    assert not np.allclose(ro_rnp.observations, ro_rnpd.observations)

    # Inspect and modify the wrapper chain
    assert isinstance(inner_env(env_rnpd), QCartPoleSwingUpSim)
    assert typed_env(env_rnpd, ObsPartialWrapper) is not None
    assert isinstance(env_rnpd, ActDelayWrapper)
    env_without_delay = remove_env(env_rnpd, ActDelayWrapper)
    assert not isinstance(env_without_delay, ActDelayWrapper)
コード例 #4
0
def sim_policy_fixed_env(env: SimEnv, policy: Policy,
                         domain_param: [dict, list]):
    """
    Repeatedly roll out a policy (rendered as video) in an environment with explicitly set domain parameters.

    :param env: environment stack as it was used during training
    :param policy: policy to simulate
    :param domain_param: one domain parameter set (dict), or a list of sets that is cycled through, one per rollout
    """
    # Strip the wrappers that would make the rollouts stochastic
    for stochastic_wrapper in (GaussianObsNoiseWrapper, DomainRandWrapperBuffer, DomainRandWrapperLive):
        env = remove_env(env, stochastic_wrapper)

    # Select the domain parameter set for the first rollout
    param_idx = 0
    has_param_list = isinstance(domain_param, list)
    if has_param_list:
        current_param = domain_param[param_idx]
    elif isinstance(domain_param, dict):
        current_param = domain_param
    else:
        raise pyrado.TypeErr(given=domain_param, expected_type=[dict, list])

    finished = False
    init_state = None
    while not finished:
        ro = rollout(env,
                     policy,
                     reset_kwargs={'domain_param': current_param, 'init_state': init_state},
                     render_mode=RenderMode(video=True),
                     eval=True)
        print_domain_params(env.domain_param)
        print_cbt(f'Return: {ro.undiscounted_return()}', 'g', bright=True)

        # The query decides if we stop, and which initial state the next rollout uses
        finished, init_state, _ = after_rollout_query(env, policy, ro)

        if has_param_list:
            # Advance to the next domain parameter set, wrapping around at the end of the list
            param_idx = (param_idx + 1) % len(domain_param)
            current_param = domain_param[param_idx]
コード例 #5
0
def conditional_actnorm_wrapper(env: Env, ex_dirs: list, idx: int):
    """
    Mirror the action normalization of a loaded experiment's environment onto the given environment.

    If the environment loaded from the experiment directory contains an `ActNormWrapper`, the given environment is
    wrapped with one, too. Otherwise `remove_env` is called to strip that wrapper type from the given environment.

    :param env: environment to sample from
    :param ex_dirs: list of experiment directories that will be loaded
    :param idx: index of the current directory
    :return: modified environment
    """
    # Only the environment (first of the three loaded items) is needed here
    sim_env, _, _ = load_experiment(ex_dirs[idx])
    sim_uses_actnorm = typed_env(sim_env, ActNormWrapper) is not None

    if sim_uses_actnorm:
        adapted_env = ActNormWrapper(env)
        status_msg = f'Added an action normalization wrapper to {idx + 1}-th evaluation policy.'
    else:
        adapted_env = remove_env(env, ActNormWrapper)
        status_msg = f'Removed an action normalization wrapper to {idx + 1}-th evaluation policy.'

    print_cbt(status_msg, 'y')
    return adapted_env
コード例 #6
0
def test_combination(env: SimEnv):
    """
    Stack several wrappers on the given simulation environment and check their combined behavior.

    :param env: simulation environment that gets wrapped
    """

    def _do_rollout(wrapped_env):
        # All rollouts in this test use a fresh dummy policy, evaluation mode, the same seed, and no rendering
        return rollout(
            wrapped_env,
            DummyPolicy(wrapped_env.spec),
            eval=True,
            seed=0,
            render_mode=RenderMode(),
        )

    pyrado.set_seed(0)
    env.max_steps = 20

    # Domain randomization with a buffer of 3 pre-sampled domains
    env_r = DomainRandWrapperBuffer(env, create_default_randomizer(env))
    env_r.fill_buffer(num_domains=3)

    params_after = []
    for _ in range(4):
        params_before = env_r.domain_param
        _do_rollout(env_r)
        params_after.append(env_r.domain_param)
        # Every rollout is expected to change the domain parameters
        assert params_after[-1] != params_before
    # With 3 buffered domains, the 4th rollout is expected to end up with the same parameters as the 1st
    assert params_after[0] == params_after[3]

    # Action and observation normalization (with some explicitly given observation bounds)
    elb = {"x_dot": -213.0, "theta_dot": -42.0}
    eub = {"x_dot": 213.0, "theta_dot": 42.0, "x": 0.123}
    env_rn = ObsNormWrapper(ActNormWrapper(env), explicit_lb=elb, explicit_ub=eub)
    act_lo, act_up = env_rn.act_space.bounds
    assert all(act_lo == -1)
    assert all(act_up == 1)
    obs_lo, obs_up = env_rn.obs_space.bounds
    assert all(obs_lo == -1)
    assert all(obs_up == 1)

    # Processing the non-normalized observations by hand must yield the normalized rollout's observations
    ro_r = _do_rollout(env_r)
    ro_rn = _do_rollout(env_rn)
    assert np.allclose(env_rn._process_obs(ro_r.observations), ro_rn.observations)

    # Partial observations, selected via the 3rd and 4th observation label of the unwrapped environment
    selected_labels = [env.obs_space.labels[2], env.obs_space.labels[3]]
    env_rnp = ObsPartialWrapper(env_rn, idcs=selected_labels)
    ro_rnp = _do_rollout(env_rnp)

    # Gaussian action noise; the noise is expected to change the rollout's observations
    env_rnpa = GaussianActNoiseWrapper(
        env_rnp,
        noise_mean=0.5 * np.ones(env_rnp.act_space.shape),
        noise_std=0.1 * np.ones(env_rnp.act_space.shape),
    )
    ro_rnpa = _do_rollout(env_rnpa)
    assert not np.allclose(ro_rnp.observations, ro_rnpa.observations)

    # Action delay: same recorded actions, different observations
    env_rnpd = ActDelayWrapper(env_rnp, delay=3)
    ro_rnpd = _do_rollout(env_rnpd)
    assert np.allclose(ro_rnp.actions, ro_rnpd.actions)
    assert not np.allclose(ro_rnp.observations, ro_rnpd.observations)

    # Inspect and modify the wrapper chain
    assert type(inner_env(env_rnpd)) == type(env)
    assert typed_env(env_rnpd, ObsPartialWrapper) is not None
    assert isinstance(env_rnpd, ActDelayWrapper)
    env_without_delay = remove_env(env_rnpd, ActDelayWrapper)
    assert not isinstance(env_without_delay, ActDelayWrapper)