コード例 #1
0
def main():
    """Train the Option Keyboard value functions, then train the agent on top.

    Workflow:
      1. Learn one value function per environment resource (the "options").
      2. Reload the best (or pretrained) option checkpoints.
      3. Train the keyboard-player agent over the cumulant-weight set W.
    """
    args = parser.parse_args()

    env = gym.make(args.env_name, scenario=args.scenario)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    set_global_seed(args.seed)
    log_dir, log_files = create_log_files(args, env.num_resources())

    # Learn the Option Keyboard
    env.set_learning_options(np.array([1, 1]), True)
    Q_E = learn_options(env=env,
                        d=env.num_resources(),
                        eps1=args.eps1_ok,
                        eps2=args.eps2_ok,
                        alpha=args.alpha_ok,
                        gamma=args.gamma_ok,
                        max_ep_steps=args.max_steps_ok,
                        device=device,
                        training_steps=args.n_training_steps_ok,
                        batch_size=args.ok_batch_size,
                        pretrained_options=args.pretrained_options,
                        test_interval=args.test_interval_option,
                        n_test_runs=args.n_test_runs,
                        log_files=log_files,
                        log_dir=log_dir)
    env.set_learning_options(np.array([1, 1]), False)

    # Choose the checkpoint directory once (the condition is loop-invariant):
    # with no OK training, fall back to the pretrained checkpoints.
    if args.n_training_steps_ok == 0:
        ckpt_dir = args.pretrained_options
    else:
        ckpt_dir = os.path.join(log_dir, 'saved_models', 'best')

    for i in range(env.num_resources()):
        # map_location keeps GPU-saved checkpoints loadable on CPU-only hosts
        # (consistent with the other entry points in this project).
        checkpoint = torch.load(
            os.path.join(ckpt_dir, 'value_fn_%d.pt' % (i + 1)),
            map_location=device)
        Q_E[i].q_net.load_state_dict(checkpoint['Q'])

    # All cumulant weight vectors over {-1, 0, 1}^2 with a non-negative sum,
    # excluding the all-zero vector (it would activate no option).
    W = [x for x in product([-1, 0, 1], repeat=2) if sum(x) >= 0]
    W.remove((0, 0))
    W = np.array(W)

    # Learn the agent
    Q_w = keyboard_player(env=env,
                          W=W,
                          Q=Q_E,
                          alpha=args.alpha_agent,
                          eps=args.eps_agent,
                          gamma=args.gamma_agent,
                          training_steps=args.n_training_steps_agent,
                          batch_size=args.agent_batch_size,
                          pretrained_agent=args.pretrained_agent,
                          max_ep_steps=args.max_steps_agent,
                          device=device,
                          test_interval=args.test_interval_agent,
                          n_test_runs=args.n_test_runs,
                          log_file=log_files['agent'],
                          log_dir=log_dir)
コード例 #2
0
def main():
    """Evaluate a trained discrete-action agent greedily on the environment.

    Loads the agent's Q-network from ``args.saved_models``, rolls out
    ``args.n_test_episodes`` greedy episodes, prints per-episode and summary
    returns, and optionally pickles the returns to ``args.save_path``.
    """
    args = parser.parse_args()
    env = gym.make(args.env_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    set_global_seed(args.seed)

    Q = MlpDiscrete(input_dim=env.observation_space.shape[0],
                    output_dim=env.action_space.n,
                    hidden=[64, 128])

    # map_location makes GPU-saved checkpoints loadable on CPU-only hosts,
    # and is a no-op when the checkpoint already matches `device` —
    # this replaces the previous cuda/cpu if/else duplication.
    checkpoint = torch.load(os.path.join(args.saved_models, 'agent.pt'),
                            map_location=device)
    Q.load_state_dict(checkpoint['Q'])
    Q.to(device)

    returns = []
    for _ in range(args.n_test_episodes):
        s = env.reset()
        s = torch.from_numpy(s).float().to(device)
        done = False
        ep_return = 0
        if args.visualize:
            env.render()

        while not done:
            q = Q(s)
            a = torch.argmax(q)  # greedy action
            s, r, done, _ = env.step(a)
            ep_return += r
            s = torch.from_numpy(s).float().to(device)
            if args.visualize:
                env.render()

        print('Episodic return:', ep_return)
        returns.append(ep_return)

    returns = np.array(returns)
    print('Mean: %f, Std. dev: %f' % (returns.mean(), returns.std()))
    # Open the results file only when actually writing, and close it
    # deterministically (the original opened it up front and never closed it).
    if args.save_path:
        with open(args.save_path, 'wb') as fp:
            pickle.dump({'Seed': args.seed, 'Returns': returns}, fp)
コード例 #3
0
ファイル: main.py プロジェクト: shawnsarwar/option-keyboard
def main():
    """Entry point: train a plain DQN baseline on the configured scenario."""
    args = parser.parse_args()

    env = gym.make(args.env_name, scenario=args.scenario)

    # Pick the compute device and seed every RNG before anything stochastic runs.
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    set_global_seed(args.seed)

    # The baseline has no options, hence zero resources for the log layout.
    log_dir, log_files = create_log_files(args, 0)

    # Hand everything over to the DQN training loop.
    dqn(env=env,
        eps=args.eps,
        gamma=args.gamma,
        alpha=args.alpha,
        device=device,
        training_steps=args.n_training_steps,
        batch_size=args.batch_size,
        pretrained_agent=args.pretrained_agent,
        test_interval=args.test_interval,
        n_test_runs=args.n_test_runs,
        log_file=log_files['agent'],
        log_dir=log_dir)
コード例 #4
0
def main():
    """Evaluate a trained Option Keyboard agent and report episodic returns.

    Loads the option value functions and the meta-agent Q_w from
    ``args.saved_models``, rolls out greedy episodes through the option
    keyboard, prints per-episode and summary returns, and optionally appends
    the returns to ``args.save_path``.
    """
    args = parser.parse_args()
    env = gym.make(args.env_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    set_global_seed(args.seed)

    d = env.num_resources()

    # Hyperparameters live next to the 'saved_models' directory. Read and
    # close the file immediately (it was previously left open for the whole run).
    hyperparams_path = os.path.join(
        args.saved_models.split('saved_models')[0], 'hyperparams')
    with open(hyperparams_path, 'rb') as hyperparams_file:
        hyperparams = pickle.load(hyperparams_file)
    gamma = hyperparams.gamma_ok
    max_ep_steps = hyperparams.max_steps_agent

    returns = []

    # All cumulant weight vectors over {-1, 0, 1}^2 with a non-negative sum,
    # excluding the all-zero vector (it would activate no option).
    W = [x for x in product([-1, 0, 1], repeat=2) if sum(x) >= 0]
    W.remove((0, 0))
    W = np.array(W)

    value_fns = [
        ValueFunction(input_dim=env.observation_space.shape[0] + d,
                      action_dim=(env.action_space.n + 1),
                      n_options=d,
                      hidden=[64, 128],
                      batch_size=hyperparams.ok_batch_size,
                      gamma=gamma,
                      alpha=hyperparams.alpha_ok) for _ in range(d)
    ]

    Q_w = MlpDiscrete(input_dim=env.observation_space.shape[0],
                      output_dim=W.shape[0],
                      hidden=[64, 128])

    for i in range(env.num_resources()):
        # map_location keeps GPU-saved checkpoints loadable on CPU-only hosts
        # and replaces the previous cuda/cpu if/else duplication.
        checkpoint = torch.load(
            os.path.join(args.saved_models, 'value_fn_%d.pt' % (i + 1)),
            map_location=device)
        value_fns[i].q_net.load_state_dict(checkpoint['Q'])
        value_fns[i].q_net.to(device)

    checkpoint = torch.load(os.path.join(args.saved_models, 'agent.pt'),
                            map_location=device)
    Q_w.load_state_dict(checkpoint['Q'])
    Q_w.to(device)
    # ########

    for _ in range(args.n_test_episodes):
        s = env.reset()
        done = False
        s = torch.from_numpy(s).float().to(device)
        n_steps = 0
        ret = 0

        while not done:
            # Pick the greedy cumulant weight vector, then run the option
            # keyboard until it hands control back.
            w = W[torch.argmax(Q_w(s))]
            (s_next, done, _, _, n_steps,
             info) = option_keyboard(env, s, w, value_fns, gamma, n_steps,
                                     max_ep_steps, device, args.visualize)

            ret += sum(info['rewards'])
            s = torch.from_numpy(s_next).float().to(device)

        print('Episode return:', ret)
        returns.append(ret)

    returns = np.array(returns)
    print('Mean: %f, Std. dev: %f' % (returns.mean(), returns.std()))
    # BUG FIX: `fp` was referenced unconditionally here even though it was
    # only opened when args.save_path was set, raising NameError otherwise.
    # Guard the dump and close the file deterministically.
    if args.save_path:
        with open(args.save_path, 'a+b') as fp:
            pickle.dump({'Seed': args.seed, 'Returns': returns}, fp)
    fp.close()