Beispiel #1
0
    # Set seed if desired
    pyrado.set_seed(args.seed, verbose=True)

    # Environment
    env_hparams = dict(dt=1 / 100.0, max_steps=600)
    env = QQubeSwingUpSim(**env_hparams)
    # Optional (currently disabled) filtering of the velocity observations:
    # env = ObsVelFiltWrapper(env, idcs_pos=["theta", "alpha"], idcs_vel=["theta_dot", "alpha_dot"])
    env = ActNormWrapper(env)

    # Policy: the mode selects the architecture, no mode defaults to the feed-forward policy
    if args.mode is None or args.mode.lower() == FNNPolicy.name:
        policy_hparam = dict(hidden_sizes=[64, 64], hidden_nonlin=to.tanh)
        policy = FNNPolicy(spec=env.spec, **policy_hparam)
    elif args.mode.lower() == GRUPolicy.name:
        policy_hparam = dict(hidden_size=32, num_recurrent_layers=1)
        policy = GRUPolicy(spec=env.spec, **policy_hparam)
    else:
        # Fail early with a clear message instead of a NameError on `policy` further down
        raise ValueError(
            f"Unsupported mode {args.mode!r}, expected None, {FNNPolicy.name!r}, or {GRUPolicy.name!r}!"
        )

    # Critic: the value function uses the same architecture family as the policy
    if isinstance(policy, FNNPolicy):
        vfcn_hparam = dict(hidden_sizes=[32, 32], hidden_nonlin=to.relu)  # FNN
        vfcn = FNNPolicy(spec=EnvSpec(env.obs_space, ValueFunctionSpace), **vfcn_hparam)
    elif isinstance(policy, GRUPolicy):
        vfcn_hparam = dict(hidden_size=32, num_recurrent_layers=1)  # GRU
        vfcn = GRUPolicy(spec=EnvSpec(env.obs_space, ValueFunctionSpace), **vfcn_hparam)
    critic_hparam = dict(
        gamma=0.9844224855479998,
        lamda=0.9700148505302241,
        num_epoch=5,
        batch_size=500,
        standardize_adv=False,
        lr=7.058326426522811e-4,
Beispiel #2
0
 def gru_policy(env: Env):
     """Build a GRU policy with two recurrent layers of hidden size 8 for the environment's spec."""
     policy_hparam = dict(hidden_size=8, num_recurrent_layers=2)
     return GRUPolicy(spec=env.spec, **policy_hparam)
Beispiel #3
0
 def gru_policy_cuda(env: Env):
     """Build the same two-layer, hidden-size-8 GRU policy as a CPU variant would, but with `use_cuda=True`."""
     policy_hparam = dict(hidden_size=8, num_recurrent_layers=2, use_cuda=True)
     return GRUPolicy(spec=env.spec, **policy_hparam)
Beispiel #4
0
    # Seed the random number generators (the seed comes from the command-line arguments)
    pyrado.set_seed(args.seed, verbose=True)

    # Environment: step size 1/250 for 12 * 250 steps (presumably 12 s), with normalized actions
    env_hparams = {"dt": 1 / 250.0, "max_steps": 12 * 250, "long": False}
    env = ActNormWrapper(QCartPoleSwingUpSim(**env_hparams))

    # Policy: single-layer GRU network
    # Alternatives previously tried (kept for reference):
    #   init_param_kwargs=dict(t_max=50) as an extra hyper-parameter
    #   policy = LSTMPolicy(spec=env.spec, **policy_hparam)
    policy_hparam = {"hidden_size": 32, "num_recurrent_layers": 1}
    policy = GRUPolicy(spec=env.spec, **policy_hparam)

    # Algorithm (NES) hyper-parameters
    algo_hparam = {
        "max_iter": 5000,
        "pop_size": 50,
        "num_init_states_per_domain": 6,
        "eta_mean": 2.0,
        "eta_std": None,
        "expl_std_init": 0.5,
        "symm_sampling": False,
        "transform_returns": True,
        "num_workers": 10,
    }
    algo = NES(ex_dir, env, policy, **algo_hparam)