Example 1
    # NOTE(review): these indented lines sit inside a branch whose header is
    # outside this excerpt — presumably "model not found on disk, build a new
    # one"; confirm against the full script.
    # Build the actor-critic network from the preprocessed observation space
    # and the first env's action space; the instruction and memory modules are
    # toggled via the (negated) --no-instr / --no-mem CLI switches.
    acmodel = ACModel(preprocess_obss.obs_space, envs[0].action_space,
                      not args.no_instr, not args.no_mem)
    # Training bookkeeping starts from zero for this fresh model
    status = {"num_frames": 0, "update": 0}
    logger.info("Model successfully created\n")
logger.info("{}\n".format(acmodel))

# Move the model to the GPU when CUDA is present, then record availability
cuda_available = torch.cuda.is_available()
if cuda_available:
    acmodel.cuda()
logger.info("CUDA available: {}\n".format(cuda_available))

# Define actor-critic algo

# Reject unknown algorithm names up front (guard clause)
if args.algo not in ("a2c", "ppo"):
    raise ValueError("Incorrect algorithm name: {}".format(args.algo))

# Positional arguments shared by both algorithm constructors, in their
# common leading order
shared = (envs, acmodel, args.frames_per_proc, args.discount, args.lr,
          args.gae_lambda, args.entropy_coef, args.value_loss_coef,
          args.max_grad_norm, args.recurrence)

if args.algo == "a2c":
    algo = torch_rl.A2CAlgo(*shared, args.optim_alpha, args.optim_eps,
                            preprocess_obss)
else:
    algo = torch_rl.PPOAlgo(*shared, args.optim_eps, args.clip_eps,
                            args.epochs, args.batch_size, preprocess_obss)

# Train model

# Resume the frame counter from the loaded status and start the wall clock
num_frames = status["num_frames"]
total_start_time = time.time()
Example 2
# Observation preprocessor tied to this model's saved vocabulary/state
obss_preprocessor = utils.ObssPreprocessor(model_name, envs[0].observation_space)

# Define actor-critic model

# Reuse a previously saved model when one exists; otherwise build a fresh one
acmodel = utils.load_model(model_name, raise_not_found=False)
if acmodel is None:
    acmodel = ACModel(
        obss_preprocessor.obs_space,
        envs[0].action_space,
        args.instr_model,
        not args.no_mem,
        args.arch,
    )

# Move the model to the GPU when CUDA is present
if torch.cuda.is_available():
    acmodel.cuda()

# Define actor-critic algo

# Reject unknown algorithm names up front (guard clause)
if args.algo not in ("a2c", "ppo"):
    raise ValueError("Incorrect algorithm name: {}".format(args.algo))

# Positional arguments shared by both algorithm constructors, in their
# common leading order
shared = (envs, acmodel, args.frames_per_proc, args.discount, args.lr,
          args.gae_tau, args.entropy_coef, args.value_loss_coef,
          args.max_grad_norm, args.recurrence)

if args.algo == "a2c":
    algo = torch_rl.A2CAlgo(*shared, args.optim_alpha, args.optim_eps,
                            obss_preprocessor, utils.reshape_reward)
else:
    algo = torch_rl.PPOAlgo(*shared, args.optim_eps, args.clip_eps,
                            args.epochs, args.batch_size, obss_preprocessor,
                            utils.reshape_reward)

# Define logger and Tensorboard writer

# Logger scoped to this model's name
logger = utils.get_logger(model_name)
if args.tb:
    # Lazy import: tensorboardX is only required when --tb is passed.
    # NOTE(review): the writer construction presumably continues past this
    # excerpt — confirm in the full script.
    from tensorboardX import SummaryWriter