コード例 #1
0
ファイル: run_gail.py プロジェクト: wwxFromTju/obs-tower2
def main():
    """Set up the models and environment and run the GAIL outer loop.

    Builds a batched environment, an actor-critic model, a prior of the same
    architecture and a discriminator; restores any of ``save.pkl``,
    ``save_disc.pkl`` and ``save_prior.pkl`` that exist in the working
    directory; moves all three networks to CUDA; then calls
    ``GAIL.outer_loop`` with a ``Prierarchy`` learner and the loaded
    recordings.
    """
    state_features = StateFeatures()
    env = BatchedStateEnv(create_batched_env(NUM_ENVS, augment=True),
                          state_features=state_features)
    model = ACModel()
    prior = ACModel()
    discriminator = DiscriminatorModel()

    # Checkpoints are deserialized onto the CPU. The modules still live on
    # the CPU at this point, and without map_location a checkpoint saved from
    # a GPU would be restored onto the device it was saved from, which fails
    # if that device does not exist here and wastes GPU memory if it does.
    if os.path.exists('save.pkl'):
        model.load_state_dict(torch.load('save.pkl', map_location='cpu'))
    if os.path.exists('save_disc.pkl'):
        discriminator.load_state_dict(
            torch.load('save_disc.pkl', map_location='cpu'))
    else:
        # No discriminator checkpoint: start its CNN from the policy model's
        # current CNN weights (restored above when save.pkl exists).
        discriminator.cnn.load_state_dict(model.cnn.state_dict())
    if os.path.exists('save_prior.pkl'):
        prior.load_state_dict(
            torch.load('save_prior.pkl', map_location='cpu'))

    device = torch.device('cuda')
    model.to(device)
    discriminator.to(device)
    prior.to(device)

    # Both splits are handed to GAIL as a single list of recordings.
    train, test = load_data(augment=True)
    recordings = train + test

    roller = LogRoller(env, model, HORIZON)
    ppo = Prierarchy(prior,
                     model,
                     gamma=GAE_GAMMA,
                     lam=GAE_LAM,
                     lr=LR,
                     ent_reg=PRIOR_REG)
    gail = GAIL(discriminator, lr=LR)
    gail.outer_loop(ppo,
                    roller,
                    recordings,
                    state_features,
                    rew_scale=GAIL_REWARD_SCALE,
                    real_rew_scale=REWARD_SCALE,
                    disc_num_steps=HORIZON * NUM_ENVS // BATCH_SIZE,
                    disc_batch_size=BATCH_SIZE,
                    expert_batch=GAIL_NUM_ENVS,
                    expert_horizon=GAIL_HORIZON,
                    num_steps=ITERS,
                    batch_size=BATCH_SIZE)
コード例 #2
0
def main():
    """Parse command-line arguments and run the ``Prierarchy`` outer loop.

    The environment is created with ``start=args.worker_idx`` and
    ``rand_floors=(args.min, args.max)``. The model is restored from
    ``args.path`` and the prior from ``save_prior.pkl`` when those files
    exist, and ``args.path`` is also passed to the outer loop as its
    ``save_path``.
    """
    args = arg_parser().parse_args()
    env = BatchedStateEnv(
        create_batched_env(NUM_ENVS,
                           augment=True,
                           start=args.worker_idx,
                           rand_floors=(args.min, args.max)))
    model = ACModel()
    prior = ACModel()

    # Checkpoints are deserialized onto the CPU. The modules still live on
    # the CPU at this point, and without map_location a checkpoint saved from
    # a GPU would be restored onto the device it was saved from, which fails
    # if that device does not exist here and wastes GPU memory if it does.
    # NOTE: torch.load unpickles the file, so args.path should only ever
    # point at a checkpoint from a trusted source.
    if os.path.exists(args.path):
        model.load_state_dict(torch.load(args.path, map_location='cpu'))
    if os.path.exists('save_prior.pkl'):
        prior.load_state_dict(
            torch.load('save_prior.pkl', map_location='cpu'))

    device = torch.device('cuda')
    model.to(device)
    prior.to(device)

    roller = LogRoller(env, model, HORIZON)
    ppo = Prierarchy(prior,
                     model,
                     gamma=GAE_GAMMA,
                     lam=GAE_LAM,
                     lr=LR,
                     ent_reg=PRIOR_REG)
    ppo.outer_loop(roller,
                   num_steps=ITERS,
                   batch_size=BATCH_SIZE,
                   save_path=args.path)
コード例 #3
0
ファイル: run_ppo.py プロジェクト: unixpickle/obs-tower2
def main():
    """Set up the environment and model and run the PPO outer loop.

    The model is restored from ``save.pkl`` when that file exists in the
    working directory, then moved to CUDA before rollouts begin.
    """
    env = BatchedStateEnv(create_batched_env(NUM_ENVS))
    model = ACModel()

    # The checkpoint is deserialized onto the CPU. The module still lives on
    # the CPU at this point, and without map_location a checkpoint saved from
    # a GPU would be restored onto the device it was saved from, which fails
    # if that device does not exist here and wastes GPU memory if it does.
    if os.path.exists('save.pkl'):
        model.load_state_dict(torch.load('save.pkl', map_location='cpu'))
    model.to(torch.device('cuda'))

    roller = Roller(env, model, HORIZON)
    ppo = PPO(model,
              gamma=GAE_GAMMA,
              lam=GAE_LAM,
              lr=LR,
              ent_reg=ENTROPY_REG)
    ppo.outer_loop(roller, num_steps=ITERS, batch_size=BATCH_SIZE)