コード例 #1
0
ファイル: hoof_no_uvfa_npg.py プロジェクト: supratikp/HOOF
def learn_hoof_no_lambgam(env,
                          env_type,
                          timesteps_per_batch,
                          total_timesteps,
                          kl_range,
                          gamma_range,
                          lam_range,
                          num_kl=25,
                          num_gamma_lam=20,
                          **network_kwargs):
    """Launch ``run_hoof_no_lamgam`` using the settings from ``defaults.mujoco()``.

    Each of ``kl_range``, ``gamma_range`` and ``lam_range`` may be the string
    ``'fixed'``; in that case the corresponding default (``params['max_kl']``,
    ``params['gamma']`` or ``params['lam']``) is forwarded instead of the
    argument. Any other value is forwarded unchanged.

    Args:
        env: Environment object; its ``num_envs`` attribute is used to split
            ``timesteps_per_batch`` per environment.
        env_type: Accepted for interface compatibility; not used here.
        timesteps_per_batch: Total batch size, divided by ``env.num_envs``
            and truncated to ``int`` before being forwarded.
        total_timesteps: Forwarded unchanged.
        kl_range: KL setting, or ``'fixed'`` to use ``params['max_kl']``.
        gamma_range: Gamma setting, or ``'fixed'`` to use ``params['gamma']``.
        lam_range: Lambda setting, or ``'fixed'`` to use ``params['lam']``.
        num_kl: Forwarded as ``num_kl``; forced to 1 when ``kl_range`` is
            ``'fixed'``.
        num_gamma_lam: Forwarded as ``num_gamma_lam``; forced to 1 when both
            ``gamma_range`` and ``lam_range`` are ``'fixed'``.
        **network_kwargs: Accepted for interface compatibility; not used here.
    """
    def _is_fixed(value):
        # Compare by value, not identity: `x is 'fixed'` only works when both
        # strings happen to be the same interned object, which is not the
        # case for strings built at runtime (CLI arguments, config files).
        # The isinstance guard keeps array-like ranges from triggering an
        # ambiguous elementwise comparison.
        return isinstance(value, str) and value == 'fixed'

    params = defaults.mujoco()

    kl_fixed = _is_fixed(kl_range)
    gamma_fixed = _is_fixed(gamma_range)
    lam_fixed = _is_fixed(lam_range)

    # With nothing to search over, a single candidate is enough.
    if gamma_fixed and lam_fixed:
        num_gamma_lam = 1
    if kl_fixed:
        num_kl = 1

    run_hoof_no_lamgam(
        network=params['network'],
        env=env,
        total_timesteps=total_timesteps,
        timesteps_per_batch=int(timesteps_per_batch / env.num_envs),
        kl_range=params['max_kl'] if kl_fixed else kl_range,
        gamma_range=params['gamma'] if gamma_fixed else gamma_range,
        lam_range=params['lam'] if lam_fixed else lam_range,
        num_kl=num_kl,
        num_gamma_lam=num_gamma_lam,
        cg_iters=params['cg_iters'],
        seed=None,
        cg_damping=params['cg_damping'],
        vf_stepsize=params['vf_stepsize'],
        vf_iters=params['vf_iters'],
        normalize_observations=params['normalize_observations'])
コード例 #2
0
def learn_npg_variant(algo, env, env_type, timesteps_per_batch, total_timesteps, **network_kwargs):
    """Invoke ``run_pg`` for ``algo`` using the settings from ``defaults.mujoco()``.

    The ``'network'`` entry of the defaults is replaced with
    ``mlp(num_hidden=64, num_layers=2)`` before the call. ``env_type`` and
    ``network_kwargs`` are accepted but not used by this function.

    Args:
        algo: Passed to ``run_pg`` as its first positional argument.
        env: Environment object; ``env.num_envs`` divides the batch size.
        env_type: Unused.
        timesteps_per_batch: Total batch size; forwarded as
            ``int(timesteps_per_batch / env.num_envs)``.
        total_timesteps: Forwarded unchanged.
        **network_kwargs: Unused.
    """
    params = defaults.mujoco()
    params['network'] = mlp(num_hidden=64, num_layers=2)

    # Assemble the keyword arguments in the same order they are forwarded.
    run_kwargs = {
        'network': params['network'],
        'env': env,
        'total_timesteps': total_timesteps,
        'timesteps_per_batch': int(timesteps_per_batch / env.num_envs),
        'max_kl': params['max_kl'],
        'cg_iters': params['cg_iters'],
        'gamma': params['gamma'],
        'lam': params['lam'],
        'seed': None,
        'cg_damping': params['cg_damping'],
        'vf_stepsize': params['vf_stepsize'],
        'vf_iters': params['vf_iters'],
        'normalize_observations': params['normalize_observations'],
    }
    run_pg(algo, **run_kwargs)