def main():
    """Set up and launch GAIL training.

    Builds the batched environment plus the policy, prior and discriminator
    networks, restores whichever checkpoints exist under their relative
    paths, moves the networks to CUDA, and then hands control to
    ``GAIL.outer_loop`` with a ``Prierarchy`` optimizer and the loaded
    recordings.
    """

    def _restore(net, path):
        # Load a state dict from `path` when it exists; report whether it did.
        if not os.path.exists(path):
            return False
        net.load_state_dict(torch.load(path))
        return True

    features = StateFeatures()
    batched_env = BatchedStateEnv(
        create_batched_env(NUM_ENVS, augment=True),
        state_features=features,
    )
    policy = ACModel()
    prior_net = ACModel()
    disc = DiscriminatorModel()

    _restore(policy, 'save.pkl')
    if not _restore(disc, 'save_disc.pkl'):
        # No discriminator checkpoint: seed its cnn from the policy's cnn
        # weights (which may themselves have just been restored).
        disc.cnn.load_state_dict(policy.cnn.state_dict())
    _restore(prior_net, 'save_prior.pkl')

    for net in (policy, disc, prior_net):
        net.to(torch.device('cuda'))

    # Both splits returned by load_data are concatenated into one list of
    # recordings for the GAIL loop.
    train_split, test_split = load_data(augment=True)
    recordings = train_split + test_split

    roller = LogRoller(batched_env, policy, HORIZON)
    ppo = Prierarchy(
        prior_net,
        policy,
        gamma=GAE_GAMMA,
        lam=GAE_LAM,
        lr=LR,
        ent_reg=PRIOR_REG,
    )
    gail = GAIL(disc, lr=LR)
    gail.outer_loop(
        ppo,
        roller,
        recordings,
        features,
        rew_scale=GAIL_REWARD_SCALE,
        real_rew_scale=REWARD_SCALE,
        disc_num_steps=HORIZON * NUM_ENVS // BATCH_SIZE,
        disc_batch_size=BATCH_SIZE,
        expert_batch=GAIL_NUM_ENVS,
        expert_horizon=GAIL_HORIZON,
        num_steps=ITERS,
        batch_size=BATCH_SIZE,
    )
def main():
    """Set up and launch Prierarchy training from command-line options.

    Parses the CLI arguments, builds the batched environment from them,
    restores the policy from ``args.path`` and the prior from
    ``'save_prior.pkl'`` when those files exist, moves both networks to CUDA,
    and then runs ``Prierarchy.outer_loop`` with ``args.path`` as the save
    path.
    """
    opts = arg_parser().parse_args()

    raw_env = create_batched_env(
        NUM_ENVS,
        augment=True,
        start=opts.worker_idx,
        rand_floors=(opts.min, opts.max),
    )
    batched_env = BatchedStateEnv(raw_env)

    policy = ACModel()
    prior_net = ACModel()

    # The policy checkpoint path comes from the CLI; the prior's is fixed.
    policy_ckpt_found = os.path.exists(opts.path)
    if policy_ckpt_found:
        policy.load_state_dict(torch.load(opts.path))
    prior_ckpt_found = os.path.exists('save_prior.pkl')
    if prior_ckpt_found:
        prior_net.load_state_dict(torch.load('save_prior.pkl'))

    for net in (policy, prior_net):
        net.to(torch.device('cuda'))

    roller = LogRoller(batched_env, policy, HORIZON)
    ppo = Prierarchy(
        prior_net,
        policy,
        gamma=GAE_GAMMA,
        lam=GAE_LAM,
        lr=LR,
        ent_reg=PRIOR_REG,
    )
    ppo.outer_loop(
        roller,
        num_steps=ITERS,
        batch_size=BATCH_SIZE,
        save_path=opts.path,
    )
def main():
    """Set up and launch plain PPO training.

    Builds the batched environment and the policy network, restores the
    policy from ``'save.pkl'`` when that file exists, moves it to CUDA, and
    then runs ``PPO.outer_loop`` over a ``Roller``.
    """
    raw_env = create_batched_env(NUM_ENVS)
    batched_env = BatchedStateEnv(raw_env)

    policy = ACModel()
    checkpoint_found = os.path.exists('save.pkl')
    if checkpoint_found:
        saved_state = torch.load('save.pkl')
        policy.load_state_dict(saved_state)

    cuda = torch.device('cuda')
    policy.to(cuda)

    roller = Roller(batched_env, policy, HORIZON)
    ppo = PPO(
        policy,
        gamma=GAE_GAMMA,
        lam=GAE_LAM,
        lr=LR,
        ent_reg=ENTROPY_REG,
    )
    ppo.outer_loop(roller, num_steps=ITERS, batch_size=BATCH_SIZE)