示例#1
0
def main(_):
  """Launch the interactive player on a real and/or simulated environment.

  Builds hparams from FLAGS.loop_hparams_set and FLAGS.loop_hparams, resolves
  the working directories via `player_utils.infer_paths`, constructs the
  player environment selected by FLAGS.sim_and_real / FLAGS.simulated_env,
  passes it through `player_utils.wrap_with_monitor` with FLAGS.video_dir and
  then either runs a short scripted dry run (FLAGS.dry_run) or hands the
  environment to `play.play`.

  Args:
    _: Ignored.
  """
  # For verbose gym output while debugging, enable:
  # gym.logger.set_level(gym.logger.DEBUG)
  hparams = registry.hparams(FLAGS.loop_hparams_set)
  hparams.parse(FLAGS.loop_hparams)

  # Not important for experiments past 2018: add the hparam when it is absent.
  known_hparam_names = hparams.values().keys()
  if "wm_policy_param_sharing" not in known_hparam_names:
    hparams.add_hparam("wm_policy_param_sharing", False)

  directories = player_utils.infer_paths(
      output_dir=FLAGS.output_dir,
      world_model=FLAGS.wm_dir,
      policy=FLAGS.policy_dir,
      data=FLAGS.episodes_data_dir)

  if FLAGS.game_from_filenames:
    inferred_game = player_utils.infer_game_name_from_filenames(
        directories["data"])
    hparams.set_hparam("game", inferred_game)

  # The env created here is used only to read the action meanings; it is
  # deliberately not bound to a name.
  action_meanings = (
      gym.make(full_game_name(hparams.game))
      .unwrapped.get_action_meanings())

  # FLAGS.epoch is either the literal "last" or something convertible to int.
  epoch = FLAGS.epoch
  if epoch != "last":
    epoch = int(epoch)

  def _build_real_env():
    """Load the real env from the data directory and flatten its batch."""
    batch_env = player_utils.setup_and_load_epoch(
        hparams,
        data_dir=directories["data"],
        which_epoch_data=None)
    return FlatBatchEnv(batch_env)

  def _build_simulated_env(initial_frames_setable, epoch_data):
    """Create the simulated env from the data and world-model directories."""
    return player_utils.load_data_and_make_simulated_env(
        directories["data"],
        directories["world_model"],
        hparams,
        which_epoch_data=epoch_data,
        setable_initial_frames=initial_frames_setable)

  if FLAGS.sim_and_real:
    # The simulated env is built before the real one.
    simulated = _build_simulated_env(
        initial_frames_setable=True, epoch_data=None)
    real = _build_real_env()
    player = SimAndRealEnvPlayer(real, simulated, action_meanings)
  elif FLAGS.simulated_env:
    simulated = _build_simulated_env(
        initial_frames_setable=False, epoch_data=epoch)
    player = SingleEnvPlayer(simulated, action_meanings)
  else:
    real = _build_real_env()
    player = SingleEnvPlayer(real, action_meanings)

  monitored = player_utils.wrap_with_monitor(player, FLAGS.video_dir)

  if not FLAGS.dry_run:
    play.play(monitored, zoom=FLAGS.zoom, fps=FLAGS.fps)
    return

  # Dry run: query the key mapping, then do 5 rounds of 50 steps cycling
  # through actions 0, 1, 2, each round ending with RETURN_DONE_ACTION.
  monitored.unwrapped.get_keys_to_action()
  for _round in range(5):
    monitored.reset()
    for step_index in range(50):
      monitored.step(step_index % 3)
    monitored.step(PlayerEnv.RETURN_DONE_ACTION)  # reset
示例#2
0
def main(_):
    """Launch the interactive player on a real or simulated environment.

    Builds hparams from FLAGS.loop_hparams_set and FLAGS.loop_hparams,
    resolves the working directories via `player_utils.infer_paths`, creates
    either the simulated environment (FLAGS.simulated_env) or the real one,
    wraps it in `PlayerEnvWrapper` and `player_utils.wrap_with_monitor`, and
    then either runs a short scripted dry run (FLAGS.dry_run) or hands the
    environment to `play.play`.

    Args:
        _: Ignored.
    """
    # For verbose gym output while debugging, enable:
    # gym.logger.set_level(gym.logger.DEBUG)
    hparams = registry.hparams(FLAGS.loop_hparams_set)
    hparams.parse(FLAGS.loop_hparams)

    # Not important for experiments past 2018: add the hparam when absent.
    known_hparam_names = hparams.values().keys()
    if "wm_policy_param_sharing" not in known_hparam_names:
        hparams.add_hparam("wm_policy_param_sharing", False)

    directories = player_utils.infer_paths(
        output_dir=FLAGS.output_dir,
        world_model=FLAGS.wm_dir,
        policy=FLAGS.policy_dir,
        data=FLAGS.episodes_data_dir,
    )

    # FLAGS.epoch is either the literal "last" or something convertible to
    # int.
    epoch = FLAGS.epoch
    if epoch != "last":
        epoch = int(epoch)

    if FLAGS.simulated_env:
        base_env = player_utils.load_data_and_make_simulated_env(
            directories["data"],
            directories["world_model"],
            hparams,
            which_epoch_data=epoch,
        )
    else:
        batch_env = player_utils.setup_and_load_epoch(
            hparams,
            data_dir=directories["data"],
            which_epoch_data=epoch,
        )
        base_env = FlatBatchEnv(batch_env)

    player = PlayerEnvWrapper(base_env)
    monitored = player_utils.wrap_with_monitor(player, FLAGS.video_dir)

    if not FLAGS.dry_run:
        play.play(monitored, zoom=FLAGS.zoom, fps=FLAGS.fps)
        return

    # Dry run: 5 rounds of 50 steps cycling through actions 0, 1, 2, each
    # round ending with RESET_ACTION.
    for _round in range(5):
        monitored.reset()
        for step_index in range(50):
            monitored.step(step_index % 3)
        monitored.step(PlayerEnvWrapper.RESET_ACTION)  # reset