Example #1
import os

import nnabla as nn
from nnabla.ext_utils import get_extension_context

# get_args, ObsSampler, ReplayMemory, GreedyExplorer and Validator are
# assumed to be importable from the surrounding DQN example code.


def main():

    args = get_args()

    # Run on the requested backend (e.g. 'cpu' or 'cudnn') and device.
    nn.set_default_context(
        get_extension_context(args.extension, device_id=args.device_id))

    from atari_utils import make_atari_deepmind
    env = make_atari_deepmind(args.gym_env, valid=True)
    print('Observation:', env.observation_space)
    print('Action:', env.action_space)
    obs_sampler = ObsSampler(args.num_frames)
    val_replay_memory = ReplayMemory(env.observation_space.shape,
                                     env.action_space.shape,
                                     max_memory=args.num_frames)

    # Load the trained Q-network from a single .nnp file and act greedily.
    explorer = GreedyExplorer(env.action_space.n,
                              use_nnp=True,
                              nnp_file=args.nnp,
                              name='qnet')
    validator = Validator(env,
                          val_replay_memory,
                          explorer,
                          obs_sampler,
                          num_episodes=30,
                          clip_episode_step=True,
                          render=not args.no_render)

    # Average reward over 30 evaluation episodes, appended to a log file.
    mean_reward = validator.step()
    with open(os.path.join(args.log_path, 'mean_reward.txt'), 'a') as f:
        print("{} {}".format(args.gym_env, mean_reward), file=f)
Example #2
import os

import nnabla as nn
from nnabla import logger
from nnabla.ext_utils import get_extension_context
from nnabla.utils.data_source_loader import download

# get_args, find_local_nnp, ObsSampler, ReplayMemory, GreedyExplorer and
# Validator are assumed to be importable from the surrounding DQN example code.


def main():

    args = get_args()

    nn.set_default_context(
        get_extension_context(args.extension, device_id=args.device_id))

    if args.nnp is None:
        local_nnp_dir = os.path.join("asset", args.gym_env)
        local_nnp_file = os.path.join(local_nnp_dir, "qnet.nnp")

        if not find_local_nnp(args.gym_env):
            logger.info("Downloading nnp data since you didn't specify...")
            nnp_uri = os.path.join(
                "https://nnabla.org/pretrained-models/nnp_models/examples/dqn",
                args.gym_env, "qnet.nnp")
            if not os.path.exists(local_nnp_dir):
                os.mkdir(local_nnp_dir)
            download(nnp_uri, output_file=local_nnp_file, open_file=False)
            logger.info("Download done!")

        args.nnp = local_nnp_file

    from atari_utils import make_atari_deepmind
    env = make_atari_deepmind(args.gym_env, valid=False)
    print('Observation:', env.observation_space)
    print('Action:', env.action_space)
    obs_sampler = ObsSampler(args.num_frames)
    val_replay_memory = ReplayMemory(env.observation_space.shape,
                                     env.action_space.shape,
                                     max_memory=args.num_frames)
    # No exploration: always take the greedy action from the loaded network.
    explorer = GreedyExplorer(env.action_space.n,
                              use_nnp=True,
                              nnp_file=args.nnp,
                              name='qnet')
    validator = Validator(env,
                          val_replay_memory,
                          explorer,
                          obs_sampler,
                          num_episodes=1,
                          render=not args.no_render)
    # Keep playing one greedy episode at a time until interrupted.
    while True:
        validator.step()
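Example #2 also depends on a find_local_nnp helper that is not shown. A plausible sketch, assuming it only reports whether a previously downloaded qnet.nnp already exists under asset/<env>/ (the real project helper may differ):

import os


def find_local_nnp(gym_env, asset_dir='asset'):
    # Hypothetical helper: True if a cached qnet.nnp exists for this
    # environment, letting Example #2 skip the download step.
    return os.path.isfile(os.path.join(asset_dir, gym_env, 'qnet.nnp'))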
Example #3
import torch
# Which SummaryWriter the original used is an assumption; the
# torch.utils.tensorboard one matches the calls made here.
from torch.utils.tensorboard import SummaryWriter

# get_args, OutputPath, ReplayMemory, QLearner, LinearDecayEGreedyExplorer,
# Sampler, ObsSampler, Validator and Trainer are assumed to be importable
# from the surrounding DQN example code.


def main():

    args = get_args()

    # Prefer the requested CUDA device; fall back to the CPU.
    if torch.cuda.is_available():
        device = torch.device('cuda', index=args.device_id)
        torch.cuda.set_device(args.device_id)
    else:
        device = torch.device('cpu')

    if args.log_path:
        output_path = OutputPath(args.log_path)
    else:
        output_path = OutputPath()


    # monitor = Monitor(output_path.path)

    tbw = SummaryWriter(output_path.path)

    # Create an atari env.
    from atari_utils import make_atari_deepmind
    env = make_atari_deepmind(args.gym_env, valid=False)
    env_val = make_atari_deepmind(args.gym_env, valid=True)
    print('Observation:', env.observation_space)
    print('Action:', env.action_space)

    # The validation buffer holds args.num_frames observations; the training
    # buffer below holds 40000 frames (10000 steps x 4 stacked frames).
    val_replay_memory = ReplayMemory(env.observation_space.shape,
                                     env.action_space.shape,
                                     max_memory=args.num_frames)
    replay_memory = ReplayMemory(env.observation_space.shape,
                                 env.action_space.shape,
                                 max_memory=40000)

    learner = QLearner(env.action_space.n,
                       device,
                       sync_freq=1000,
                       save_freq=250000,
                       gamma=0.99,
                       learning_rate=1e-4,
                       save_path=output_path)

    explorer = LinearDecayEGreedyExplorer(env.action_space.n,
                                          device,
                                          network=learner.get_network(),
                                          eps_start=1.0,
                                          eps_end=0.01,
                                          eps_steps=1e6)

    sampler = Sampler(args.num_frames)
    obs_sampler = ObsSampler(args.num_frames)

    validator = Validator(env_val,
                          val_replay_memory,
                          explorer,
                          obs_sampler,
                          num_episodes=args.num_val_episodes,
                          num_eval_steps=args.num_eval_steps,
                          render=args.render_val,
                          tbw=tbw)

    trainer_with_validator = Trainer(env,
                                     replay_memory,
                                     learner,
                                     sampler,
                                     explorer,
                                     obs_sampler,
                                     inter_eval_steps=args.inter_eval_steps,
                                     num_episodes=args.num_episodes,
                                     train_start=10000,
                                     batch_size=32,
                                     render=args.render_train,
                                     validator=validator,
                                     tbw=tbw)

    for _ in range(args.num_epochs):
        trainer_with_validator.step()
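The LinearDecayEGreedyExplorer above anneals exploration from eps_start=1.0 down to eps_end=0.01 over eps_steps=1e6 environment steps. A minimal sketch of that schedule (only the three parameter values come from the call above; the function itself is illustrative, not the class internals):

def linear_eps(step, eps_start=1.0, eps_end=0.01, eps_steps=1e6):
    # Linearly interpolate from eps_start to eps_end, then stay at eps_end.
    frac = min(step / eps_steps, 1.0)
    return eps_start + frac * (eps_end - eps_start)

# linear_eps(0) -> 1.0, linear_eps(5e5) -> 0.505, linear_eps(2e6) -> 0.01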
Example #4
import nnabla as nn
from nnabla.ext_utils import get_extension_context
from nnabla.monitor import Monitor

# SummaryWriter is assumed to come from tensorboardX (or an equivalent);
# get_args, OutputPath, ReplayMemory, QLearner, q_cnn,
# LinearDecayEGreedyExplorer, Sampler, ObsSampler, Validator and Trainer
# are assumed to be importable from the surrounding DQN example code.


def main():

    args = get_args()

    nn.set_default_context(
        get_extension_context(args.extension, device_id=args.device_id))

    if args.log_path:
        output_path = OutputPath(args.log_path)
    else:
        output_path = OutputPath()
    monitor = Monitor(output_path.path)

    tbw = SummaryWriter(output_path.path)

    # Create an atari env.
    from atari_utils import make_atari_deepmind
    env = make_atari_deepmind(args.gym_env, valid=False)
    env_val = make_atari_deepmind(args.gym_env, valid=True)
    print('Observation:', env.observation_space)
    print('Action:', env.action_space)

    # The validation buffer holds args.num_frames observations; the training
    # buffer below holds 40000 frames (10000 steps x 4 stacked frames).
    val_replay_memory = ReplayMemory(env.observation_space.shape,
                                     env.action_space.shape,
                                     max_memory=args.num_frames)
    replay_memory = ReplayMemory(env.observation_space.shape,
                                 env.action_space.shape,
                                 max_memory=40000)

    learner = QLearner(q_cnn,
                       env.action_space.n,
                       sync_freq=1000,
                       save_freq=250000,
                       gamma=0.99,
                       learning_rate=1e-4,
                       name_q='q',
                       save_path=output_path)

    explorer = LinearDecayEGreedyExplorer(env.action_space.n,
                                          eps_start=1.0,
                                          eps_end=0.01,
                                          eps_steps=1e6,
                                          q_builder=q_cnn,
                                          name='q')

    sampler = Sampler(args.num_frames)
    obs_sampler = ObsSampler(args.num_frames)

    validator = Validator(env_val,
                          val_replay_memory,
                          explorer,
                          obs_sampler,
                          num_episodes=args.num_val_episodes,
                          num_eval_steps=args.num_eval_steps,
                          render=args.render_val,
                          monitor=monitor,
                          tbw=tbw)

    trainer_with_validator = Trainer(env,
                                     replay_memory,
                                     learner,
                                     sampler,
                                     explorer,
                                     obs_sampler,
                                     inter_eval_steps=args.inter_eval_steps,
                                     num_episodes=args.num_episodes,
                                     train_start=10000,
                                     batch_size=32,
                                     render=args.render_train,
                                     validator=validator,
                                     monitor=monitor,
                                     tbw=tbw)

    for _ in range(args.num_epochs):
        trainer_with_validator.step()
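Example #4 hands a q_cnn network builder to both QLearner and the explorer but never defines it. A sketch of what such a builder could look like in nnabla, using the classic DQN convolutional trunk from Mnih et al. (2015); the signature and layer sizes are assumptions, not the original code:

import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF


def q_cnn(obs, num_actions, scope='q'):
    # Hypothetical builder: maps stacked (batch, frames, 84, 84) grayscale
    # frames to one Q-value per action, with all parameters under `scope`
    # so the learner and explorer can share and sync weights by name.
    with nn.parameter_scope(scope):
        h = F.relu(PF.convolution(obs / 255.0, 32, (8, 8), stride=(4, 4), name='conv1'))
        h = F.relu(PF.convolution(h, 64, (4, 4), stride=(2, 2), name='conv2'))
        h = F.relu(PF.convolution(h, 64, (3, 3), stride=(1, 1), name='conv3'))
        h = F.relu(PF.affine(h, 512, name='fc1'))
        return PF.affine(h, num_actions, name='fc2')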