示例#1
0
# Load model.
# Try to restore the policy network from model_dir. If utils.load_model raises
# OSError (presumably because no saved model exists yet - confirm against
# utils.load_model), fall back to a freshly initialised DQNModel.
# Only the load call sits inside the try, so an OSError raised while building
# the target network below is not mistaken for "no saved model".
try:
    policy_net = utils.load_model(model_dir)
except OSError:
    policy_net = DQNModel(env.action_space, env=args.env)
    print("Model successfully created\n")
else:
    print("Model successfully loaded\n")

# The target network is a separate DQNModel instance initialised with a copy
# of the policy network's current weights. It is built once here rather than
# duplicated in both the "loaded" and "created" branches.
target_net = DQNModel(env.action_space, env=args.env)
target_net.load_state_dict(policy_net.state_dict())

if torch.cuda.is_available():
    policy_net.cuda()
    target_net.cuda()

# Put the target network in eval mode unconditionally. Its mode must not
# depend on whether the model was loaded from disk or on which device is
# available; otherwise a new model on CPU would keep it in training mode.
target_net.eval()
print("CUDA available: {}\n".format(torch.cuda.is_available()))

# Init Algorithm.
# Q-value recording is enabled when args.debug is set.
algo = torch_rl.DQNAlgo_new(env,
                            policy_net,
                            target_net,
                            args.frames,
                            args.discount,
                            args.lr,
                            args.optim_eps,
                            args.batch_size,
                            preprocess_obss,
                            record_qvals=args.debug)
示例#2
0
# Restore a previously saved model from model_dir, or build a new one when
# utils.load_model raises OSError. A new model's architecture is picked from
# args.algo: DQNModel for "dqn", ACModel for every other value.
try:
    base_model = utils.load_model(model_dir)
    logger.info("Model successfully loaded\n")
except OSError:
    model_cls = DQNModel if args.algo == "dqn" else ACModel
    base_model = model_cls(obs_space, envs[0].action_space,
                           args.mem, args.text)
    logger.info("Model successfully created\n")
logger.info("{}\n".format(base_model))

# Move the model to the GPU when one is present, then report availability.
cuda_available = torch.cuda.is_available()
if cuda_available:
    base_model.cuda()
logger.info("CUDA available: {}\n".format(cuda_available))

# Train model

# Frame and update counters are resumed from the saved status dict; the
# wall-clock start time is taken now.
num_frames = status["num_frames"]
update = status["update"]
total_start_time = time.time()
best_val = 0
# Build the A2C training algorithm when args.algo is "a2c", wrapping the
# environments and base_model. Every hyperparameter comes from args and is
# passed positionally, so the argument order must match the
# torch_rl.A2CAlgo constructor (NOTE(review): confirm against its signature).
if args.algo == "a2c":
    algo = torch_rl.A2CAlgo(envs, base_model, args.frames_per_proc,
                            args.discount, args.lr, args.gae_lambda,
                            args.entropy_coef, args.value_loss_coef,
                            args.max_grad_norm, args.recurrence,
                            args.optim_alpha, args.optim_eps, preprocess_obss)