# Build the actor-critic network. The two boolean flags are the negations of
# the --no-instr / --no-mem options (presumably they enable the instruction
# and memory parts of ACModel -- confirm against ACModel's signature).
acmodel = ACModel(
    preprocess_obss.obs_space,
    envs[0].action_space,
    not args.no_instr,
    not args.no_mem,
)

# Progress counters for this run; they always start from zero here.
status = {"num_frames": 0, "update": 0}

logger.info("Model successfully created\n")
logger.info("{}\n".format(acmodel))

# Move the model to the GPU whenever torch reports one is available.
if torch.cuda.is_available():
    acmodel.cuda()
logger.info("CUDA available: {}\n".format(torch.cuda.is_available()))

# Define actor-critic algo.
# Reject unsupported names first, before any algorithm-specific option is read.
if args.algo not in ("a2c", "ppo"):
    raise ValueError("Incorrect algorithm name: {}".format(args.algo))

# Leading positional arguments shared by A2CAlgo and PPOAlgo, in the order
# both constructors receive them.
shared_algo_args = (
    envs,
    acmodel,
    args.frames_per_proc,
    args.discount,
    args.lr,
    args.gae_lambda,
    args.entropy_coef,
    args.value_loss_coef,
    args.max_grad_norm,
    args.recurrence,
)
if args.algo == "a2c":
    # A2C additionally takes the optimizer alpha before eps.
    algo = torch_rl.A2CAlgo(
        *shared_algo_args,
        args.optim_alpha,
        args.optim_eps,
        preprocess_obss,
    )
else:
    # PPO additionally takes clipping, epoch and batch-size settings.
    algo = torch_rl.PPOAlgo(
        *shared_algo_args,
        args.optim_eps,
        args.clip_eps,
        args.epochs,
        args.batch_size,
        preprocess_obss,
    )

# Train model: frame counter and wall-clock start for the training loop.
num_frames = status["num_frames"]
total_start_time = time.time()
# Observation preprocessor, keyed by the model name and built from the first
# environment's observation space.
obss_preprocessor = utils.ObssPreprocessor(model_name, envs[0].observation_space)

# Define actor-critic model.
# Reuse a saved model when one is found; with raise_not_found=False a missing
# model appears to come back as None, in which case a fresh ACModel is built.
acmodel = utils.load_model(model_name, raise_not_found=False)
if acmodel is None:
    acmodel = ACModel(
        obss_preprocessor.obs_space,
        envs[0].action_space,
        args.instr_model,
        not args.no_mem,
        args.arch,
    )

# Move the model to the GPU whenever torch reports one is available.
if torch.cuda.is_available():
    acmodel.cuda()

# Define actor-critic algo.
# Reject unsupported names first, before any algorithm-specific option is read.
if args.algo not in ("a2c", "ppo"):
    raise ValueError("Incorrect algorithm name: {}".format(args.algo))

# Leading positional arguments shared by A2CAlgo and PPOAlgo, in the order
# both constructors receive them.
shared_algo_args = (
    envs,
    acmodel,
    args.frames_per_proc,
    args.discount,
    args.lr,
    args.gae_tau,
    args.entropy_coef,
    args.value_loss_coef,
    args.max_grad_norm,
    args.recurrence,
)
if args.algo == "a2c":
    # A2C additionally takes the optimizer alpha before eps.
    algo = torch_rl.A2CAlgo(
        *shared_algo_args,
        args.optim_alpha,
        args.optim_eps,
        obss_preprocessor,
        utils.reshape_reward,
    )
else:
    # PPO additionally takes clipping, epoch and batch-size settings.
    algo = torch_rl.PPOAlgo(
        *shared_algo_args,
        args.optim_eps,
        args.clip_eps,
        args.epochs,
        args.batch_size,
        obss_preprocessor,
        utils.reshape_reward,
    )

# Define logger and Tensorboard writer.
logger = utils.get_logger(model_name)
if args.tb:
    from tensorboardX import SummaryWriter