Example #1
def play_with_cb(env_name, fps=60, zoom=3, cb=None):
    def _cb(*args):
        # play() invokes the callback with (obs_t, obs_tp1, action, rew, done, info)
        print('a:', args[2])
        if args[3] != 0:
            print('r:', args[3])
    cb = _cb if cb is None else cb
    env = gym.make(env_name)
    play(env, fps=fps, zoom=zoom, callback=cb)
Example #2
def main(args):
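    # sanity-check the hyperparameter configuration before doing any work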
    assert BATCH_SIZE <= TRAIN_START <= REPLAY_BUFFER_SIZE
    assert TARGET_UPDATE_EVERY % UPDATE_EVERY == 0
    assert 84 % SIDE_BOXES == 0
    assert STRATEGY in ['final', 'future']
    print(args)
    env = make_atari('{}NoFrameskip-v4'.format(args.env))
    set_seed(env, args.seed)
    env_train = wrap_deepmind(env,
                              frame_stack=True,
                              episode_life=True,
                              clip_rewards=True)
    if args.weights:
        model = load_or_create_model(env_train, args.model)
        print_weights(model)
    elif args.debug:
        env, model, target_model, batch = load_for_debug()
        fit_batch(env, model, target_model, batch)
    elif args.play:
        env = wrap_deepmind(env)
        play(env)
    else:
        env_eval = wrap_deepmind(env, frame_stack=True)
        model = load_or_create_model(env_train, args.model)
        if args.view or args.images or args.eval:
            evaluate(env_eval, model, args.view, args.images)
        else:
            max_steps = 100 if args.test else MAX_STEPS
            train(env_train, env_eval, model, max_steps, args.name)
            if args.test:
                filename = save_model(model,
                                      EVAL_STEPS,
                                      logdir='.',
                                      name='test')
                load_or_create_model(env_train, filename)
Example #3
def main():
    # Define reward callback
    timestep = 0

    def callback(obs_t, obs_tp1, action, rew, done, info):
        nonlocal timestep
        timestep += 1
        data.append(rew)

    # Initialize data structures for plotting
    fig, ax = plt.subplots(1)
    ax.set_title("Reward over time")
    horizon_timesteps = int(30 * args.timeout)
    data = deque(maxlen=horizon_timesteps)
    if not args.plot_rewards:
        callback = None

    # Initialize game timer
    t = Timer(args.timeout, lambda: pygame.quit())
    t.start()

    # Run main game loop
    try:
        env = gym.make(args.env_name)
        play.play(env, fps=30, zoom=args.zoom_level, callback=callback)
    except pygame.error:
        pass

    # Plot rewards over time before quitting
    xmin, xmax = max(0, timestep - horizon_timesteps), timestep
    ax.scatter(range(xmin, xmax), list(data), c='blue')
    ax.set_xlim(xmin, xmax)
    plt.show()
Example #4
def main():
    parser = argparse.ArgumentParser(description=None)
    parser.add_argument('--environment', type=str, default='SunblazeBreakout-v0')
    parser.add_argument('--seed', type=int, default=0)
    args = parser.parse_args()

    episode = {'reward': 0, 'initial': True}
    env = gym.make(args.environment)
    env.seed(args.seed)

    def reporter(obs_t, obs_tp1, action, reward, done, info):
        if episode['initial']:
            episode['initial'] = False
            print('Environment parameters:')
            for key in sorted(env.unwrapped.parameters.keys()):
                print('  {}: {}'.format(key, env.unwrapped.parameters[key]))

        episode['reward'] += reward
        if reward != 0:
            print('Reward:', episode['reward'])

        if done:
            print('*** GAME OVER ***')
            episode['reward'] = 0
            episode['initial'] = True

    play(env, callback=reporter)
Example #5
def main(_):
  # gym.logger.set_level(gym.logger.DEBUG)
  hparams = registry.hparams(FLAGS.loop_hparams_set)
  hparams.parse(FLAGS.loop_hparams)
  # Not important for experiments past 2018
  if "wm_policy_param_sharing" not in hparams.values().keys():
    hparams.add_hparam("wm_policy_param_sharing", False)
  directories = player_utils.infer_paths(
      output_dir=FLAGS.output_dir,
      world_model=FLAGS.wm_dir,
      policy=FLAGS.policy_dir,
      data=FLAGS.episodes_data_dir)
  if FLAGS.game_from_filenames:
    hparams.set_hparam(
        "game", player_utils.infer_game_name_from_filenames(directories["data"])
    )
  action_meanings = gym.make(full_game_name(hparams.game)).\
      unwrapped.get_action_meanings()
  epoch = FLAGS.epoch if FLAGS.epoch == "last" else int(FLAGS.epoch)

  def make_real_env():
    env = player_utils.setup_and_load_epoch(
        hparams, data_dir=directories["data"],
        which_epoch_data=None)
    env = FlatBatchEnv(env)  # pylint: disable=redefined-variable-type
    return env

  def make_simulated_env(setable_initial_frames, which_epoch_data):
    env = player_utils.load_data_and_make_simulated_env(
        directories["data"], directories["world_model"],
        hparams, which_epoch_data=which_epoch_data,
        setable_initial_frames=setable_initial_frames)
    return env

  if FLAGS.sim_and_real:
    sim_env = make_simulated_env(
        which_epoch_data=None, setable_initial_frames=True)
    real_env = make_real_env()
    env = SimAndRealEnvPlayer(real_env, sim_env, action_meanings)
  else:
    if FLAGS.simulated_env:
      env = make_simulated_env(  # pylint: disable=redefined-variable-type
          which_epoch_data=epoch, setable_initial_frames=False)
    else:
      env = make_real_env()
    env = SingleEnvPlayer(env, action_meanings)  # pylint: disable=redefined-variable-type

  env = player_utils.wrap_with_monitor(env, FLAGS.video_dir)

  if FLAGS.dry_run:
    env.unwrapped.get_keys_to_action()
    for _ in range(5):
      env.reset()
      for i in range(50):
        env.step(i % 3)
      env.step(PlayerEnv.RETURN_DONE_ACTION)  # reset
    return

  play.play(env, zoom=FLAGS.zoom, fps=FLAGS.fps)
Example #6
    def run(self):
        env = gym.make(game_name)
        p.play(env, zoom=2, fps=fps_rate, callback=self.callbackFunction)
        now = datetime.datetime.now()
        filename = filePath + head + "_Score_" + str(int(self.score)) + ".data"
        with open(filename, "wb") as f:
            pickle.dump(self.game_data, f)
Example #7
    def test_human(self, fps=60):
        """This doesn't work if the environment has been wrapped with image preprocessing."""
        # Import here since it requires pygame, which is incompatible with Python 3.8
        from gym.utils.play import play
        def callback_label(obs_t, obs_tp1, action, reward, done, info):
            if self.label(obs_tp1, reward, done, info):
                print('Property violated!')
            if done:
                self.reset()

        play(self.env, callback=callback_label, zoom=4, fps=fps)
Example #8
def generate_expert_data(game_ref: int, fps: int, num_demos: int,
                         demo_start: int):
    save_path = Path(f"../human_demos/{GAME_STRINGS_LEARN[game_ref]}")
    env = gym.make(GAME_STRINGS_PLAY[game_ref])
    play_buffer = PlayBuffer(
        save_path,
        state_dimension=env.observation_space.shape,
        action_space_size=env.action_space.n,
    )
    for demo in range(demo_start, num_demos + demo_start):
        play(env, callback=play_buffer.update_play, zoom=5, fps=fps)
        play_buffer.save_demos(demo_number=demo)
        play_buffer.clear()
Example #9
    def play_pong(self, wrap_fn):
        """
        Manual check of full set of preprocessing steps for Pong.
        Not run as part of normal unit tests; run me with
          ./preprocessing_test.py TestPreprocessing.play_pong_generic_wrap
          ./preprocessing_test.py TestPreprocessing.play_pong_special_wrap
        """
        from gym.utils import play as gym_play
        env = gym.make('PongNoFrameskip-v4')
        env = NumberFrames(env)
        env = wrap_fn(env, max_n_noops=0)
        env = ConcatFrameStack(env)
        gym_play.play(env, fps=15, zoom=4)
Example #10
    def play_game(self,
                  environment_class,
                  cheat_mode=DEFAULT_CHEAT_MODE,
                  debug=DEFAULT_DEBUG,
                  fps=30):
        """
        Interactively play an environment.
    
        Parameters
        ----------
        environment_class : type
            A subclass of schema_games.breakout.core.BreakoutEngine that represents
            a game. A convenient list is included in schema_games.breakout.games.
        cheat_mode : bool
            If True, player has an infinite amount of lives.
        debug : bool
            If True, print debugging messages and perform additional sanity checks.
        fps : int
            Frames per second at which to display the game.
        """
        print(blue("-" * 80))
        print(blue("Starting interactive game. "
                   "Press <ESC> at any moment to terminate."))
        print(blue("-" * 80))

        env_args = {
            'return_state_as_image': True,
            'debugging': debug,
        }

        if cheat_mode:
            env_args['num_lives'] = np.PINF

        env = environment_class(**env_args)
        keys_to_action = defaultdict(lambda: env.NOOP, {
            (pygame.K_LEFT, ): env.LEFT,
            (pygame.K_RIGHT, ): env.RIGHT,
        })

        def callback(prev_obs, obs, action, rew, done, info):
            if self.recommender is not None:
                self.recommender.get_observation(obs)
            return None
            # print("reward is %.2f" % (rew))
            # return [rew, ]

        play(env,
             fps=fps,
             keys_to_action=keys_to_action,
             zoom=ZOOM_FACTOR,
             callback=callback)
Example #11
def test_play_loop_real_env():
    SEED = 42
    ENV = "CartPole-v1"

    # set of key events to inject into the play loop as callback
    callback_events = [
        Event(KEYDOWN, {"key": RELEVANT_KEY_1}),
        Event(KEYUP, {"key": RELEVANT_KEY_1}),
        Event(KEYDOWN, {"key": RELEVANT_KEY_2}),
        Event(KEYUP, {"key": RELEVANT_KEY_2}),
        Event(KEYDOWN, {"key": RELEVANT_KEY_1}),
        Event(KEYUP, {"key": RELEVANT_KEY_1}),
        Event(KEYDOWN, {"key": RELEVANT_KEY_1}),
        Event(KEYUP, {"key": RELEVANT_KEY_1}),
        Event(KEYDOWN, {"key": RELEVANT_KEY_2}),
        Event(KEYUP, {"key": RELEVANT_KEY_2}),
        Event(QUIT),
    ]
    keydown_events = [k for k in callback_events if k.type == KEYDOWN]

    def callback(obs_t, obs_tp1, action, rew, done, info):
        pygame_event = callback_events.pop(0)
        event.post(pygame_event)

        # after releasing a key, post new events until
        # we have one keydown
        while pygame_event.type == KEYUP:
            pygame_event = callback_events.pop(0)
            event.post(pygame_event)

        return obs_t, obs_tp1, action, rew, done, info

    env = gym.make(ENV)
    env.reset(seed=SEED)
    keys_to_action = dummy_keys_to_action()

    # first action is 0 because at the first iteration
    # we can not inject a callback event into play()
    env.step(0)
    for e in keydown_events:
        action = keys_to_action[(e.key, )]
        obs, _, _, _ = env.step(action)

    env_play = gym.make(ENV)
    status = PlayStatus(callback)
    play(env_play,
         callback=status.callback,
         keys_to_action=keys_to_action,
         seed=SEED)

    assert (status.last_observation == obs).all()
Example #12
def play_game(environment_class,
              cheat_mode=DEFAULT_CHEAT_MODE,
              debug=DEFAULT_DEBUG,
              fps=30):
    """
    Interactively play an environment.

    Parameters
    ----------
    environment_class : type
        A subclass of schema_games.breakout.core.BreakoutEngine that represents
        a game. A convenient list is included in schema_games.breakout.games.
    cheat_mode : bool
        If True, player has an infinite amount of lives.
    debug : bool
        If True, print debugging messages and perform additional sanity checks.
    fps : int
        Frames per second at which to display the game.
    """
    print(blue("-" * 80))
    print(
        blue("Starting interactive game. "
             "Press <ESC> at any moment to terminate."))
    print(blue("-" * 80))

    env_args = {
        'return_state_as_image': True,
        'debugging': debug,
    }

    if cheat_mode:
        env_args['num_lives'] = np.PINF

    env = environment_class(**env_args)
    keys_to_action = defaultdict(lambda: env.NOOP, {
        (pygame.K_LEFT, ): env.LEFT,
        (pygame.K_RIGHT, ): env.RIGHT,
    })

    play(env, fps=fps, keys_to_action=keys_to_action, zoom=ZOOM_FACTOR)
Example #13
def main(_):
    # gym.logger.set_level(gym.logger.DEBUG)
    hparams = registry.hparams(FLAGS.loop_hparams_set)
    hparams.parse(FLAGS.loop_hparams)
    # Not important for experiments past 2018
    if "wm_policy_param_sharing" not in hparams.values().keys():
        hparams.add_hparam("wm_policy_param_sharing", False)
    directories = player_utils.infer_paths(output_dir=FLAGS.output_dir,
                                           world_model=FLAGS.wm_dir,
                                           policy=FLAGS.policy_dir,
                                           data=FLAGS.episodes_data_dir)
    epoch = FLAGS.epoch if FLAGS.epoch == "last" else int(FLAGS.epoch)

    if FLAGS.simulated_env:
        env = player_utils.load_data_and_make_simulated_env(
            directories["data"],
            directories["world_model"],
            hparams,
            which_epoch_data=epoch)
    else:
        env = player_utils.setup_and_load_epoch(hparams,
                                                data_dir=directories["data"],
                                                which_epoch_data=epoch)
        env = FlatBatchEnv(env)

    env = PlayerEnvWrapper(env)  # pylint: disable=redefined-variable-type

    env = player_utils.wrap_with_monitor(env, FLAGS.video_dir)

    if FLAGS.dry_run:
        for _ in range(5):
            env.reset()
            for i in range(50):
                env.step(i % 3)
            env.step(PlayerEnvWrapper.RESET_ACTION)  # reset
        return

    play.play(env, zoom=FLAGS.zoom, fps=FLAGS.fps)
Example #14
def play_all_atari_envs(n_episodes=10, v=4, deterministic=False, noframeskip=True, fps=60, zoom=3, callback=None):
    for game in ['alien', 'amidar', 'assault', 'asterix', 'asteroids', 'atlantis',
        'bank_heist', 'battle_zone', 'beam_rider', 'berzerk', 'bowling', 'boxing', 'breakout', 'carnival',
        'centipede', 'chopper_command', 'crazy_climber', 'demon_attack', 'double_dunk',
        'elevator_action', 'enduro', 'fishing_derby', 'freeway', 'frostbite', 'gopher', 'gravitar',
        'hero', 'ice_hockey', 'jamesbond', 'journey_escape', 'kangaroo', 'krull', 'kung_fu_master',
        'montezuma_revenge', 'ms_pacman', 'name_this_game', 'phoenix', 'pitfall', 'pong', 'pooyan',
        'private_eye', 'qbert', 'riverraid', 'road_runner', 'robotank', 'seaquest', 'skiing',
        'solaris', 'space_invaders', 'star_gunner', 'tennis', 'time_pilot', 'tutankham', 'up_n_down',
        'venture', 'video_pinball', 'wizard_of_wor', 'yars_revenge', 'zaxxon']:
        name = ''.join([g.capitalize() for g in game.split('_')])

        if deterministic:
            name += 'Deterministic'
        elif noframeskip:
            name += 'NoFrameskip'

        name += "-v{}".format(v)
        env = gym.make(name)

        for i in range(n_episodes):
            print(name, "episode {}/{}".format(i + 1, n_episodes))
            play(env, fps=fps, zoom=zoom, callback=callback)
Example #15
def main(_):
  hparams = registry.hparams(FLAGS.loop_hparams_set)
  hparams.parse(FLAGS.loop_hparams)
  output_dir = FLAGS.output_dir

  subdirectories = ["data", "tmp", "world_model", "ppo"]
  using_autoencoder = hparams.autoencoder_train_steps > 0
  if using_autoencoder:
    subdirectories.append("autoencoder")
  directories = setup_directories(output_dir, subdirectories)

  if hparams.game in gym_env.ATARI_GAMES:
    game_with_mode = hparams.game + "_deterministic-v4"
  else:
    game_with_mode = hparams.game

  if using_autoencoder:
    simulated_problem_name = (
        "gym_simulated_discrete_problem_with_agent_on_%s_autoencoded"
        % game_with_mode)
  else:
    simulated_problem_name = ("gym_simulated_discrete_problem_with_agent_on_%s"
                              % game_with_mode)
    if simulated_problem_name not in registry.list_problems():
      tf.logging.info("Game Problem %s not found; dynamically registering",
                      simulated_problem_name)
      gym_env.register_game(hparams.game, game_mode="Deterministic-v4")

  epoch = hparams.epochs-1
  epoch_data_dir = os.path.join(directories["data"], str(epoch))
  ppo_model_dir = directories["ppo"]

  world_model_dir = directories["world_model"]

  gym_problem = registry.problem(simulated_problem_name)

  model_hparams = trainer_lib.create_hparams(hparams.generative_model_params)
  environment_spec = copy.copy(gym_problem.environment_spec)
  environment_spec.simulation_random_starts = hparams.simulation_random_starts

  batch_env_hparams = trainer_lib.create_hparams(hparams.ppo_params)
  batch_env_hparams.add_hparam("model_hparams", model_hparams)
  batch_env_hparams.add_hparam("environment_spec", environment_spec)
  batch_env_hparams.num_agents = 1

  with temporary_flags({
      "problem": simulated_problem_name,
      "model": hparams.generative_model,
      "hparams_set": hparams.generative_model_params,
      "output_dir": world_model_dir,
      "data_dir": epoch_data_dir,
  }):
    sess = tf.Session()
    env = DebugBatchEnv(batch_env_hparams, sess)
    sess.run(tf.global_variables_initializer())
    env.initialize()

    env_model_loader = tf.train.Saver(
        tf.global_variables("next_frame*"))
    trainer_lib.restore_checkpoint(world_model_dir, env_model_loader, sess,
                                   must_restore=True)

    model_saver = tf.train.Saver(
        tf.global_variables(".*network_parameters.*"))
    trainer_lib.restore_checkpoint(ppo_model_dir, model_saver, sess)

    key_mapping = gym_problem.env.env.get_keys_to_action()
    # map special codes
    key_mapping[()] = 100
    key_mapping[(ord("r"),)] = 101
    key_mapping[(ord("p"),)] = 102

    play.play(env, zoom=2, fps=10, keys_to_action=key_mapping)
Example #16
    args = parser.parse_args()

    env = gym.make(args.env)
    shape2d = env.observation_space.shape[:2]

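    # Choose a synthetic background source: a random solid color, pixel noise,
    # or files matched by --resource_files (still images or videos)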
    if args.imgsource:
        if args.imgsource == "color":
            imgsource = RandomColorSource(shape2d)
        elif args.imgsource == "noise":
            imgsource = NoiseSource(shape2d)
        else:
            files = glob.glob(os.path.expanduser(args.resource_files))
            assert len(files), "Pattern {} does not match any files".format(
                args.resource_files
            )
            if args.imgsource == "images":
                imgsource = RandomImageSource(shape2d, files)
            else:
                imgsource = RandomVideoSource(shape2d, files)

        wrapped_env = ReplaceBackgroundEnv(
            env, BackgroundMattingWithColor((0, 0, 0)), imgsource
        )
    else:
        wrapped_env = env

    if args.dump_video:
        assert os.path.isdir(args.dump_video)
        wrapped_env = gym.wrappers.Monitor(wrapped_env, args.dump_video)
    play.play(wrapped_env, zoom=4)
Example #17
    def get_action(self, state):
        # Get action from already trained network
        return np.argmax(self.value_function(state)[0])


if __name__ == "__main__":

    env = gym.make("CartPole-v1")

    if USER_PLAY:
        env.reset()
        key_action_map = {
            (ord('a'), ): 0,
            (ord('d'), ): 1
        }  # Map 'a' and 'd' to going left and right
        play.play(env, keys_to_action=key_action_map,
                  fps=15)  # Lower fps because it's too hard otherwise
        sys.exit()

    # 0: Left, 1: Right
    action_size = env.action_space.n
    # 0: Cart position, 1: Cart Velocity, 2: Pole Angle, 3: Pole Angular Velocity
    state_size = env.observation_space.shape[0]

    agent = QAgent(action_size, state_size)
    n_episodes = 40

    for e in range(n_episodes):
        state = env.reset()
        state = np.reshape(state, [1, state_size])

        done = False
Example #18
        left,
        B,
    )): 6,
    sorted_tuple((
        left,
        A,
        B,
    )): 7,
    sorted_tuple((
        right,
        A,
    )): 8,
    sorted_tuple((
        right,
        B,
    )): 9,
    sorted_tuple((
        right,
        A,
        B,
    )): 10,
    (A, ): 11,
    (B, ): 12,
    sorted_tuple((A, B)): 13
}

# Create the environment and play the game
env = build_nes_environment()

play(env, keys_to_action=keys_to_action, callback=callback)
Example #19
def test_prepro():
    env = EnvWrapper(gym.make('Pong-v0'), pool=False, frameskip=1)
    play(env)
Example #20
    def play(self, fps=30, **kwargs):
        from gym.utils.play import play
        return play(env=self, fps=fps, **kwargs)
Example #21
def records_game_played_by_hand():  # saves the history of a game
    env = gym.make('Pong-v4')
    env.reset()
    play(env, zoom=3, fps=12, callback=saveFrameGame)
    env.close()
Example #22
import gym
from gym.utils.play import play

if __name__ == "__main__":

    play(gym.make('BreakoutNoFrameskip-v4'), zoom=3)
Example #23

def callback(env):
    render_window = RenderWindow(800, 800)

    def _callback(prev_obs, obs, action, rew, env_done, info):
        render_window.render(env.render(observer='global', mode="rgb_array").swapaxes(0,2))
        if rew != 0:
            print(f'Reward: {rew}')
        if env_done:
            print(f'Done!')

    return _callback


if __name__ == '__main__':
    wrapper = GymWrapperFactory()

    environment_name = 'GVGAI/bait_keys'
    # environment_name = 'Mini-Grid/minigrid-drunkdwarf'
    # environment_name = 'Mini-Grid/minigrid-spiders'
    # environment_name = 'GVGAI/clusters'
    # environment_name = 'GVGAI/labyrinth_partially_observable'
    level = 2

    wrapper.build_gym_from_yaml(environment_name, f'Single-Player/{environment_name}.yaml',
                                player_observer_type=gd.ObserverType.BLOCK_2D,
                                global_observer_type=gd.ObserverType.SPRITE_2D, level=level, tile_size=50)
    env = gym.make(f'GDY-{environment_name}-v0')
    play(env, callback=callback(env), fps=10, zoom=3)
Example #24

def dump_current_gameplay():
    if obs_ts:
        np.savez_compressed('{}/{}_{}'.format(
            args.record_dir, args.game_name,
            str(datetime.datetime.now()).split('.')[0].replace(' ', '_')),
                            obs_ts=obs_ts,
                            obs_tp1s=obs_tp1s,
                            actions=actions,
                            rewards=rewards)


if __name__ == '__main__':
    args = parser.parse_args()
    env = gym.make(args.game_name)

    Path(args.record_dir).mkdir(parents=True, exist_ok=True)

    obs_ts = []
    obs_tp1s = []
    actions = []
    rewards = []

    recording = False

    with suppress(KeyboardInterrupt):
        play(env, zoom=args.zoom, fps=args.fps, callback=callback)

    dump_current_gameplay()
Example #25
def main():
    env = gym.make("AchtungDieKurveFullImageRandomOpponent-v1")
    play.play(env, fps=30)
Example #26
"Play the Atari game Breakout using the keyboard"
import gym
from gym.utils.play import play

play(gym.make('BreakoutDeterministic-v4'), zoom=3)
Example #27
def on_frame(obs, obs_tp1, action, reward, *_):
    from time import sleep
    on_frame.counter += 1
    state = extractor.extract_state(obs)
    if on_frame.counter % 1 == 0:
        print(state)
        print(action)
    if reward != 0:
        print(state)
        print(reward)
        sleep(5)


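# the counter lives as an attribute on the callback itself so it persists across calls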
on_frame.counter = 0
play(env, zoom=6, callback=on_frame)
#
# env = gym.make('FreewayDeterministic-v4')
# im = env.reset()
# # pil_im = Image.fromarray(im)
# # pil_im.show()
# extractor = StateExtractor(im)
#
# # pil_im.save('step0.png')
#
# rewards = []
# for i in range(200):
#     im, r, done, _ = env.step(1)
#     if done:
#         break
#     # if i % 3 == 0:
Example #28
def main(_):
    hparams = registry.hparams(FLAGS.loop_hparams_set)
    hparams.parse(FLAGS.loop_hparams)
    output_dir = FLAGS.output_dir

    subdirectories = ["data", "tmp", "world_model", "ppo"]
    using_autoencoder = hparams.autoencoder_train_steps > 0
    if using_autoencoder:
        subdirectories.append("autoencoder")
    directories = setup_directories(output_dir, subdirectories)

    if hparams.game in gym_problems_specs.ATARI_GAMES:
        game_with_mode = hparams.game + "_deterministic-v4"
    else:
        game_with_mode = hparams.game

    if using_autoencoder:
        simulated_problem_name = (
            "gym_simulated_discrete_problem_with_agent_on_%s_autoencoded" %
            game_with_mode)
    else:
        simulated_problem_name = (
            "gym_simulated_discrete_problem_with_agent_on_%s" % game_with_mode)
        if simulated_problem_name not in registry.list_problems():
            tf.logging.info(
                "Game Problem %s not found; dynamically registering",
                simulated_problem_name)
            gym_problems_specs.create_problems_for_game(
                hparams.game, game_mode="Deterministic-v4")

    epoch = hparams.epochs - 1
    epoch_data_dir = os.path.join(directories["data"], str(epoch))
    ppo_model_dir = directories["ppo"]

    world_model_dir = directories["world_model"]

    gym_problem = registry.problem(simulated_problem_name)

    model_hparams = trainer_lib.create_hparams(hparams.generative_model_params)
    environment_spec = copy.copy(gym_problem.environment_spec)
    environment_spec.simulation_random_starts = hparams.simulation_random_starts

    batch_env_hparams = trainer_lib.create_hparams(hparams.ppo_params)
    batch_env_hparams.add_hparam("model_hparams", model_hparams)
    batch_env_hparams.add_hparam("environment_spec", environment_spec)
    batch_env_hparams.num_agents = 1

    with temporary_flags({
            "problem": simulated_problem_name,
            "model": hparams.generative_model,
            "hparams_set": hparams.generative_model_params,
            "output_dir": world_model_dir,
            "data_dir": epoch_data_dir,
    }):
        sess = tf.Session()
        env = DebugBatchEnv(batch_env_hparams, sess)
        sess.run(tf.global_variables_initializer())
        env.initialize()

        env_model_loader = tf.train.Saver(tf.global_variables("next_frame*"))
        trainer_lib.restore_checkpoint(world_model_dir,
                                       env_model_loader,
                                       sess,
                                       must_restore=True)

        model_saver = tf.train.Saver(
            tf.global_variables(".*network_parameters.*"))
        trainer_lib.restore_checkpoint(ppo_model_dir, model_saver, sess)

        key_mapping = gym_problem.env.env.get_keys_to_action()
        # map special codes
        key_mapping[()] = 100
        key_mapping[(ord("r"), )] = 101
        key_mapping[(ord("p"), )] = 102

        play.play(env, zoom=2, fps=10, keys_to_action=key_mapping)
Example #29
    def get_keys_to_action(self):
        keyword_to_key = {
            'UP':      ord('w'),
            'DOWN':    ord('s'),
            'LEFT':    ord('a'),
            'RIGHT':   ord('d'),
        }

        keys_to_action = {}

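        # map each action's textual meaning to the sorted tuple of keyboard keys that triggers it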
        for action_id, action_meaning in self.action_meaning.items():
            keys = []
            for keyword, key in keyword_to_key.items():
                if keyword in action_meaning:
                    keys.append(key)
            keys = tuple(sorted(keys))

            assert keys not in keys_to_action
            keys_to_action[keys] = action_id
        return keys_to_action


if __name__ == '__main__':
    import gym.utils.play as play

    # register()
    # env = RandomMazeEnv(127, 127)
    env = RandomMazeEnv(63, 63, start_end_distance=5)
    play.play(env, zoom=16.0)
Example #30
import gym
from gym.utils.play import play

if __name__ == '__main__':
    env_name = 'SpaceInvaders-v0'
    env = gym.make(env_name)
    play(env, zoom=4)
Example #31
import gym
from gym.utils.play import play
env = gym.make('CarRacing-v0')
play(env, zoom=3)