Example #1
def main(_):
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

    # Calculate the list of problems to generate.
    problems = sorted(
        list(_SUPPORTED_PROBLEM_GENERATORS) + registry.list_problems())
    for exclude in FLAGS.exclude_problems.split(","):
        if exclude:
            problems = [p for p in problems if exclude not in p]
    if FLAGS.problem and FLAGS.problem[-1] == "*":
        problems = [p for p in problems if p.startswith(FLAGS.problem[:-1])]
    elif FLAGS.problem and "," in FLAGS.problem:
        problems = [p for p in problems if p in FLAGS.problem.split(",")]
    elif FLAGS.problem:
        problems = [p for p in problems if p == FLAGS.problem]
    else:
        problems = []

    # Remove TIMIT if paths are not given.
    if getattr(FLAGS, "timit_paths", None):
        problems = [p for p in problems if "timit" not in p]
    # Remove parsing if paths are not given.
    if getattr(FLAGS, "parsing_path", None):
        problems = [p for p in problems if "parsing_english_ptb" not in p]

    if not problems:
        problems_str = "\n  * ".join(
            sorted(
                list(_SUPPORTED_PROBLEM_GENERATORS) +
                registry.list_problems()))
        error_msg = ("You must specify one of the supported problems to "
                     "generate data for:\n  * " + problems_str + "\n")
        error_msg += ("TIMIT and parsing need data_sets specified with "
                      "--timit_paths and --parsing_path.")
        raise ValueError(error_msg)

    if not FLAGS.data_dir:
        FLAGS.data_dir = tempfile.gettempdir()
        tf.logging.warning(
            "It is strongly recommended to specify --data_dir. "
            "Data will be written to default data_dir=%s.", FLAGS.data_dir)
    FLAGS.data_dir = os.path.expanduser(FLAGS.data_dir)
    tf.gfile.MakeDirs(FLAGS.data_dir)

    tf.logging.info(
        "Generating problems:\n%s" %
        registry.display_list_by_prefix(problems, starting_spaces=4))
    if FLAGS.only_list:
        return
    for problem in problems:
        set_random_seed()

        if problem in _SUPPORTED_PROBLEM_GENERATORS:
            generate_data_for_problem(problem)
        else:
            generate_data_for_registered_problem(problem)
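
The wildcard, comma-list, and exclusion filtering above can be exercised in isolation. The following self-contained sketch replicates the same selection logic; the problem names are invented purely for illustration.

def select_problems(problems, pattern="", excludes=""):
    """Mirrors the filtering in main(): exclusions, then "*" / "," / exact match."""
    for exclude in excludes.split(","):
        if exclude:
            problems = [p for p in problems if exclude not in p]
    if pattern and pattern[-1] == "*":
        return [p for p in problems if p.startswith(pattern[:-1])]
    if pattern and "," in pattern:
        return [p for p in problems if p in pattern.split(",")]
    if pattern:
        return [p for p in problems if p == pattern]
    return []

names = ["translate_ende_wmt32k", "translate_enfr_wmt32k", "image_mnist"]
print(select_problems(names, "translate_en*"))   # prefix wildcard -> both translate problems
print(select_problems(names, "image_mnist,translate_enfr_wmt32k"))  # comma list
print(select_problems(names, "image_mnist", excludes="translate"))  # exclusion + exact match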
Example #2
def main(_):
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  # Calculate the list of problems to generate.
  problems = sorted(
      list(_SUPPORTED_PROBLEM_GENERATORS) + registry.list_problems())
  for exclude in FLAGS.exclude_problems.split(","):
    if exclude:
      problems = [p for p in problems if exclude not in p]
  if FLAGS.problem and FLAGS.problem[-1] == "*":
    problems = [p for p in problems if p.startswith(FLAGS.problem[:-1])]
  elif FLAGS.problem and "," in FLAGS.problem:
    problems = [p for p in problems if p in FLAGS.problem.split(",")]
  elif FLAGS.problem:
    problems = [p for p in problems if p == FLAGS.problem]
  else:
    problems = []

  # Remove TIMIT if paths are not given.
  if getattr(FLAGS, "timit_paths", None):
    problems = [p for p in problems if "timit" not in p]
  # Remove parsing if paths are not given.
  if getattr(FLAGS, "parsing_path", None):
    problems = [p for p in problems if "parsing_english_ptb" not in p]

  if not problems:
    problems_str = "\n  * ".join(
        sorted(list(_SUPPORTED_PROBLEM_GENERATORS) + registry.list_problems()))
    error_msg = ("You must specify one of the supported problems to "
                 "generate data for:\n  * " + problems_str + "\n")
    error_msg += ("TIMIT and parsing need data_sets specified with "
                  "--timit_paths and --parsing_path.")
    raise ValueError(error_msg)

  if not FLAGS.data_dir:
    FLAGS.data_dir = tempfile.gettempdir()
    tf.logging.warning("It is strongly recommended to specify --data_dir. "
                       "Data will be written to default data_dir=%s.",
                       FLAGS.data_dir)
  FLAGS.data_dir = os.path.expanduser(FLAGS.data_dir)
  tf.gfile.MakeDirs(FLAGS.data_dir)

  tf.logging.info("Generating problems:\n%s"
                  % registry.display_list_by_prefix(problems,
                                                    starting_spaces=4))
  if FLAGS.only_list:
    return
  for problem in problems:
    set_random_seed()

    if problem in _SUPPORTED_PROBLEM_GENERATORS:
      generate_data_for_problem(problem)
    else:
      generate_data_for_registered_problem(problem)
Example #3
def main(_):
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)  # no-op unless --t2t_usr_dir is set

  # Calculate the list of problems to generate.
  problems = sorted(  # merge the generators listed above with all registered problems
      list(_SUPPORTED_PROBLEM_GENERATORS) + registry.list_problems())
  for exclude in FLAGS.exclude_problems.split(","):  # problems can be excluded via --exclude_problems
    if exclude:
      problems = [p for p in problems if exclude not in p]
  if FLAGS.problem and FLAGS.problem[-1] == "*":  # a trailing "*" selects every problem with that prefix
    problems = [p for p in problems if p.startswith(FLAGS.problem[:-1])]
  elif FLAGS.problem:  # without "*", select only the exact match
    problems = [p for p in problems if p == FLAGS.problem]
  else:  # no problem specified
    problems = []

  # Remove TIMIT if paths are not given. (TIMIT and parsing need extra path
  # flags; without them these problems cannot be generated.)
  if not FLAGS.timit_paths:
    problems = [p for p in problems if "timit" not in p]
  # Remove parsing if paths are not given.
  if not FLAGS.parsing_path:
    problems = [p for p in problems if "parsing_english_ptb" not in p]

  if not problems:  # raise an error if no problems remain
    problems_str = "\n  * ".join(
        sorted(list(_SUPPORTED_PROBLEM_GENERATORS) + registry.list_problems()))
    error_msg = ("You must specify one of the supported problems to "
                 "generate data for:\n  * " + problems_str + "\n")
    error_msg += ("TIMIT and parsing need data_sets specified with "
                  "--timit_paths and --parsing_path.")
    raise ValueError(error_msg)

  if not FLAGS.data_dir:  # warn and fall back to a default data dir
    FLAGS.data_dir = tempfile.gettempdir()
    tf.logging.warning("It is strongly recommended to specify --data_dir. "
                       "Data will be written to default data_dir=%s.",
                       FLAGS.data_dir)
  FLAGS.data_dir = os.path.expanduser(FLAGS.data_dir)  # expand "~" in the path
  tf.gfile.MakeDirs(FLAGS.data_dir)  # create the directory

  tf.logging.info("Generating problems:\n%s"
                  % registry.display_list_by_prefix(problems,
                                                    starting_spaces=4))
  if FLAGS.only_list:  # only list the problems instead of also generating their data
    return
  for problem in problems:
    set_random_seed()  # set the random seed first

    if problem in _SUPPORTED_PROBLEM_GENERATORS:  # problems from the two sources use different generation paths
      generate_data_for_problem(problem)
    else:
      generate_data_for_registered_problem(problem)
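
set_random_seed(), called in the loops above, is defined elsewhere in t2t_datagen.py. From memory it simply seeds the Python, NumPy, and TensorFlow generators from --random_seed, roughly as below; treat this as a sketch and verify against your tensor2tensor version (FLAGS is the module-level flags object).

import random
import numpy as np
import tensorflow as tf

def set_random_seed():
    """Seed all three RNGs so each problem's generated data is reproducible."""
    tf.set_random_seed(FLAGS.random_seed)
    random.seed(FLAGS.random_seed)
    np.random.seed(FLAGS.random_seed)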
Example #4
def main(_):

  tf.gfile.MakeDirs(FLAGS.data_dir)
  tf.gfile.MakeDirs(FLAGS.tmp_dir)

  # Create problem if not already defined
  problem_name = "gym_discrete_problem_with_agent_on_%s" % FLAGS.game
  if problem_name not in registry.list_problems():
    gym_env.register_game(FLAGS.game)

  # Generate
  tf.logging.info("Running %s environment for %d steps for trajectories.",
                  FLAGS.game, FLAGS.num_env_steps)
  problem = registry.problem(problem_name)
  problem.settable_num_steps = FLAGS.num_env_steps
  problem.settable_eval_phase = FLAGS.eval
  problem.generate_data(FLAGS.data_dir, FLAGS.tmp_dir)

  # Log stats
  if problem.statistics.number_of_dones:
    mean_reward = (problem.statistics.sum_of_rewards /
                   problem.statistics.number_of_dones)
    tf.logging.info("Mean reward: %.2f, Num dones: %d",
                    mean_reward,
                    problem.statistics.number_of_dones)
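
The check-then-register pattern in this example recurs throughout the rest of the listing: problems are looked up by name strings, so a game that was never registered must be added to the registry before registry.problem() can resolve it. A short sketch following the calls in Example #4, with the game name "pong" chosen purely for illustration:

from tensor2tensor.data_generators import gym_env
from tensor2tensor.utils import registry

problem_name = "gym_discrete_problem_with_agent_on_pong"
if problem_name not in registry.list_problems():
    gym_env.register_game("pong")          # dynamically registers the problem
problem = registry.problem(problem_name)   # now resolvable by name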
Example #5
def add_problem_hparams(hparams, problems):
    """Add problem hparams for the problems."""
    hparams.problems = []
    hparams.problem_instances = []
    for problem_name in problems.split("-"):
        try:
            problem = registry.problem(problem_name)
        except LookupError:
            problem = None

        if problem is None:
            try:
                p_hparams = problem_hparams.problem_hparams(
                    problem_name, hparams)
            except LookupError:
                # The problem is not in the set of registered Problems nor in the old
                # set of problem_hparams.
                all_problem_names = sorted(
                    list(problem_hparams.PROBLEM_HPARAMS_MAP) +
                    registry.list_problems())
                error_lines = [
                    "%s not in the set of supported problems:" % problem_name
                ] + all_problem_names
                error_msg = "\n  * ".join(error_lines)
                raise LookupError(error_msg)
        else:
            p_hparams = problem.get_hparams(hparams)

        hparams.problem_instances.append(problem)
        hparams.problems.append(p_hparams)

    return hparams
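
The control flow above — try the current registry first, fall back to the legacy problem_hparams table, and only fail after both miss — can be illustrated with plain dictionaries standing in for the two registries. Everything below is hypothetical scaffolding, not tensor2tensor API:

NEW_REGISTRY = {"translate_demo": "new-style Problem"}
LEGACY_HPARAMS = {"parsing_demo": "legacy hparams"}

def lookup(problem_name):
    if problem_name in NEW_REGISTRY:
        return NEW_REGISTRY[problem_name]
    if problem_name in LEGACY_HPARAMS:
        return LEGACY_HPARAMS[problem_name]
    supported = sorted(list(NEW_REGISTRY) + list(LEGACY_HPARAMS))
    raise LookupError("\n  * ".join(
        ["%s not in the set of supported problems:" % problem_name] + supported))

print(lookup("translate_demo"))    # hits the new registry
print(lookup("parsing_demo"))      # falls back to the legacy table
try:
    lookup("missing")
except LookupError as err:
    print(err)                     # lists every supported name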
Example #6
def add_problem_hparams(hparams, problems):
  """Add problem hparams for the problems."""
  hparams.problems = []
  hparams.problem_instances = []
  for problem_name in problems.split("-"):
    try:
      problem = registry.problem(problem_name)
    except LookupError:
      all_problem_names = sorted(registry.list_problems())
      error_lines = ["%s not in the set of supported problems:" % problem_name
                    ] + all_problem_names
      error_msg = "\n  * ".join(error_lines)
      raise LookupError(error_msg)
    p_hparams = problem.get_hparams(hparams)

    hparams.problem_instances.append(problem)
    hparams.problems.append(p_hparams)
Example #8
def setup_problems(hparams, using_autoencoder=False):
    """Register problems based on game name."""
    if hparams.game in gym_problems_specs.ATARI_GAMES:
        game_with_mode = hparams.game + "_deterministic-v4"
    else:
        game_with_mode = hparams.game
    game_problems_kwargs = {}
    # Problems
    if using_autoencoder:
        game_problems_kwargs["autoencoder_hparams"] = (
            hparams.autoencoder_hparams_set)
        problem_name = (
            "gym_discrete_problem_with_agent_on_%s_with_autoencoder" %
            game_with_mode)
        world_model_problem = (
            "gym_discrete_problem_with_agent_on_%s_autoencoded" %
            game_with_mode)
        simulated_problem_name = (
            "gym_simulated_discrete_problem_with_agent_on_%s_autoencoded" %
            game_with_mode)
        world_model_eval_problem_name = (
            "gym_simulated_discrete_problem_for_world_model_eval_with_agent_on_%s"
            "_autoencoded" % game_with_mode)
    else:
        problem_name = ("gym_discrete_problem_with_agent_on_%s" %
                        game_with_mode)
        world_model_problem = problem_name
        simulated_problem_name = (
            "gym_simulated_discrete_problem_with_agent_on_%s" % game_with_mode)
        world_model_eval_problem_name = (
            "gym_simulated_discrete_problem_for_world_model_eval_with_agent_on_%s"
            % game_with_mode)
    if problem_name not in registry.list_problems():
        game_problems_kwargs[
            "resize_height_factor"] = hparams.resize_height_factor
        game_problems_kwargs[
            "resize_width_factor"] = hparams.resize_width_factor
        game_problems_kwargs["grayscale"] = hparams.grayscale
        tf.logging.info("Game Problem %s not found; dynamically registering",
                        problem_name)
        gym_problems_specs.create_problems_for_game(
            hparams.game, game_mode="Deterministic-v4", **game_problems_kwargs)
    return (problem_name, world_model_problem, simulated_problem_name,
            world_model_eval_problem_name)
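
To make the derived names concrete, here is the string formatting from the non-autoencoder branch applied to a hypothetical game "pong":

game_with_mode = "pong" + "_deterministic-v4"
problem_name = "gym_discrete_problem_with_agent_on_%s" % game_with_mode
simulated_problem_name = (
    "gym_simulated_discrete_problem_with_agent_on_%s" % game_with_mode)
world_model_eval_problem_name = (
    "gym_simulated_discrete_problem_for_world_model_eval_with_agent_on_%s"
    % game_with_mode)
print(problem_name)
# gym_discrete_problem_with_agent_on_pong_deterministic-v4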
Example #9
def add_problem_hparams(hparams, problems):
    """Add problem hparams for the problems."""
    hparams.problems = []
    hparams.problem_instances = []
    for problem_name in problems.split("-"):
        try:
            problem = registry.problem(
                problem_name)  # search for the registered problem
        except LookupError:
            all_problem_names = sorted(
                registry.list_problems())  # list all problems
            error_lines = [
                "%s not in the set of supported problems:" % problem_name
            ] + all_problem_names
            error_msg = "\n  * ".join(error_lines)
            raise LookupError(error_msg)
        p_hparams = problem.get_hparams(
            hparams)  # contains vocabulary, inputs/targets modality

        hparams.problem_instances.append(problem)
        hparams.problems.append(p_hparams)
Example #10
def training_loop(hparams, output_dir, report_fn=None, report_metric=None):
    """Run the main training loop."""
    if report_fn:
        assert report_metric is not None

    # Global state

    # Directories
    subdirectories = ["data", "tmp", "world_model", "ppo"]
    using_autoencoder = hparams.autoencoder_train_steps > 0
    if using_autoencoder:
        subdirectories.append("autoencoder")
    directories = setup_directories(output_dir, subdirectories)

    if hparams.game in gym_problems_specs.ATARI_GAMES:
        game_with_mode = hparams.game + "_deterministic-v4"
    else:
        game_with_mode = hparams.game
    # Problems
    if using_autoencoder:
        problem_name = (
            "gym_discrete_problem_with_agent_on_%s_with_autoencoder" %
            game_with_mode)
        world_model_problem = (
            "gym_discrete_problem_with_agent_on_%s_autoencoded" %
            game_with_mode)
        simulated_problem_name = (
            "gym_simulated_discrete_problem_with_agent_on_%s_autoencoded" %
            game_with_mode)
        world_model_eval_problem_name = (
            "gym_simulated_discrete_problem_for_world_model_eval_with_agent_on_%s"
            "_autoencoded" % game_with_mode)
    else:
        problem_name = ("gym_discrete_problem_with_agent_on_%s" %
                        game_with_mode)
        world_model_problem = problem_name
        simulated_problem_name = (
            "gym_simulated_discrete_problem_with_agent_on_%s" % game_with_mode)
        world_model_eval_problem_name = (
            "gym_simulated_discrete_problem_for_world_model_eval_with_agent_on_%s"
            % game_with_mode)
        if problem_name not in registry.list_problems():
            tf.logging.info(
                "Game Problem %s not found; dynamically registering",
                problem_name)
            gym_problems_specs.create_problems_for_game(
                hparams.game,
                resize_height_factor=hparams.resize_height_factor,
                resize_width_factor=hparams.resize_width_factor,
                game_mode="Deterministic-v4")

    # Autoencoder model dir
    autoencoder_model_dir = directories.get("autoencoder")

    # Timing log function
    log_relative_time = make_relative_timing_fn()

    # Per-epoch state
    epoch_metrics = []
    epoch_data_dirs = []

    ppo_model_dir = None
    data_dir = os.path.join(directories["data"], "initial")
    epoch_data_dirs.append(data_dir)
    # Collect data from the real environment with PPO or random policy.
    if hparams.gather_ppo_real_env_data:
        ppo_model_dir = directories["ppo"]
        tf.logging.info("Initial training of PPO in real environment.")
        ppo_event_dir = os.path.join(directories["world_model"],
                                     "ppo_summaries/initial")
        train_agent_real_env(problem_name,
                             ppo_model_dir,
                             ppo_event_dir,
                             directories["world_model"],
                             data_dir,
                             hparams,
                             epoch=-1,
                             is_final_epoch=False)

    tf.logging.info("Generating real environment data with %s policy",
                    "PPO" if hparams.gather_ppo_real_env_data else "random")
    mean_reward = generate_real_env_data(problem_name, ppo_model_dir, hparams,
                                         data_dir, directories["tmp"])
    tf.logging.info("Mean reward (random): {}".format(mean_reward))

    eval_metrics_event_dir = os.path.join(directories["world_model"],
                                          "eval_metrics_event_dir")
    eval_metrics_writer = tf.summary.FileWriter(eval_metrics_event_dir)
    model_reward_accuracy_summary = tf.Summary()
    model_reward_accuracy_summary.value.add(tag="model_reward_accuracy",
                                            simple_value=None)
    mean_reward_summary = tf.Summary()
    mean_reward_summary.value.add(tag="mean_reward", simple_value=None)

    for epoch in range(hparams.epochs):
        is_final_epoch = (epoch + 1) == hparams.epochs
        log = make_log_fn(epoch, log_relative_time)

        # Combine all previously collected environment data
        epoch_data_dir = os.path.join(directories["data"], str(epoch))
        tf.gfile.MakeDirs(epoch_data_dir)
        # Because the data is being combined in every iteration, we only need to
        # copy from the previous directory.
        combine_training_data(registry.problem(problem_name), epoch_data_dir,
                              epoch_data_dirs[-1:])
        epoch_data_dirs.append(epoch_data_dir)

        if using_autoencoder:
            # Train the Autoencoder on all prior environment frames
            log("Training Autoencoder")
            train_autoencoder(problem_name, epoch_data_dir,
                              autoencoder_model_dir, hparams, epoch)

            log("Autoencoding environment frames")
            encode_env_frames(problem_name, world_model_problem,
                              autoencoder_model_dir, epoch_data_dir)

        # Train world model
        log("Training world model")
        train_world_model(world_model_problem, epoch_data_dir,
                          directories["world_model"], hparams, epoch)

        # Evaluate world model
        model_reward_accuracy = 0.
        if hparams.eval_world_model:
            log("Evaluating world model")
            model_reward_accuracy = evaluate_world_model(
                world_model_eval_problem_name, world_model_problem, hparams,
                directories["world_model"], epoch_data_dir, directories["tmp"])
            log("World model reward accuracy: %.4f", model_reward_accuracy)

        # Train PPO
        log("Training PPO in simulated environment.")
        ppo_event_dir = os.path.join(directories["world_model"],
                                     "ppo_summaries", str(epoch))
        ppo_model_dir = directories["ppo"]
        if not hparams.ppo_continue_training:
            ppo_model_dir = ppo_event_dir
        train_agent(simulated_problem_name,
                    ppo_model_dir,
                    ppo_event_dir,
                    directories["world_model"],
                    epoch_data_dir,
                    hparams,
                    epoch=epoch,
                    is_final_epoch=is_final_epoch)

        # Train PPO on real env (short)
        log("Training PPO in real environment.")
        train_agent_real_env(problem_name,
                             ppo_model_dir,
                             ppo_event_dir,
                             directories["world_model"],
                             epoch_data_dir,
                             hparams,
                             epoch=epoch,
                             is_final_epoch=is_final_epoch)

        if hparams.stop_loop_early:
            return 0.0
        # Collect data from the real environment.
        log("Generating real environment data")
        eval_data_dir = os.path.join(epoch_data_dir, "eval")
        mean_reward = generate_real_env_data(
            problem_name,
            ppo_model_dir,
            hparams,
            eval_data_dir,
            directories["tmp"],
            autoencoder_path=autoencoder_model_dir,
            eval_phase=True)
        log("Mean eval reward: {}".format(mean_reward))

        if not is_final_epoch:
            generation_mean_reward = generate_real_env_data(
                problem_name,
                ppo_model_dir,
                hparams,
                epoch_data_dir,
                directories["tmp"],
                autoencoder_path=autoencoder_model_dir,
                eval_phase=False)
            log("Mean reward during generation: {}".format(
                generation_mean_reward))

        # Summarize metrics
        assert model_reward_accuracy is not None
        assert mean_reward is not None
        model_reward_accuracy_summary.value[
            0].simple_value = model_reward_accuracy
        mean_reward_summary.value[0].simple_value = mean_reward
        eval_metrics_writer.add_summary(model_reward_accuracy_summary, epoch)
        eval_metrics_writer.add_summary(mean_reward_summary, epoch)
        eval_metrics_writer.flush()

        # Report metrics
        eval_metrics = {
            "model_reward_accuracy": model_reward_accuracy,
            "mean_reward": mean_reward
        }
        epoch_metrics.append(eval_metrics)
        log("Eval metrics: %s", str(eval_metrics))
        if report_fn:
            report_fn(eval_metrics[report_metric], epoch)

    # Return the evaluation metrics from the final epoch
    return epoch_metrics[-1]
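
setup_directories(), used at the top of this loop, is a helper from the same module. From memory it expands the base path and creates one subdirectory per requested name, returning a name-to-path dict; treat this as a sketch and verify against your tensor2tensor version:

import os
import tensorflow as tf

def setup_directories(base_dir, subdirs):
    """Create base_dir/<subdir> for each name and return {name: path}."""
    base_dir = os.path.expanduser(base_dir)
    tf.gfile.MakeDirs(base_dir)
    all_dirs = {}
    for subdir in subdirs:
        dir_name = os.path.join(base_dir, subdir)
        tf.gfile.MakeDirs(dir_name)
        all_dirs[subdir] = dir_name
    return all_dirs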
Example #11
def main(_):
    hparams = registry.hparams(FLAGS.loop_hparams_set)
    hparams.parse(FLAGS.loop_hparams)
    output_dir = FLAGS.output_dir

    subdirectories = ["data", "tmp", "world_model", "ppo"]
    using_autoencoder = hparams.autoencoder_train_steps > 0
    if using_autoencoder:
        subdirectories.append("autoencoder")
    directories = setup_directories(output_dir, subdirectories)

    if hparams.game in gym_problems_specs.ATARI_GAMES:
        game_with_mode = hparams.game + "_deterministic-v4"
    else:
        game_with_mode = hparams.game

    if using_autoencoder:
        simulated_problem_name = (
            "gym_simulated_discrete_problem_with_agent_on_%s_autoencoded" %
            game_with_mode)
    else:
        simulated_problem_name = (
            "gym_simulated_discrete_problem_with_agent_on_%s" % game_with_mode)
        if simulated_problem_name not in registry.list_problems():
            tf.logging.info(
                "Game Problem %s not found; dynamically registering",
                simulated_problem_name)
            gym_problems_specs.create_problems_for_game(
                hparams.game, game_mode="Deterministic-v4")

    epoch = hparams.epochs - 1
    epoch_data_dir = os.path.join(directories["data"], str(epoch))
    ppo_model_dir = directories["ppo"]

    world_model_dir = directories["world_model"]

    gym_problem = registry.problem(simulated_problem_name)

    model_hparams = trainer_lib.create_hparams(hparams.generative_model_params)
    environment_spec = copy.copy(gym_problem.environment_spec)
    environment_spec.simulation_random_starts = hparams.simulation_random_starts

    batch_env_hparams = trainer_lib.create_hparams(hparams.ppo_params)
    batch_env_hparams.add_hparam("model_hparams", model_hparams)
    batch_env_hparams.add_hparam("environment_spec", environment_spec)
    batch_env_hparams.num_agents = 1

    with temporary_flags({
            "problem": simulated_problem_name,
            "model": hparams.generative_model,
            "hparams_set": hparams.generative_model_params,
            "output_dir": world_model_dir,
            "data_dir": epoch_data_dir,
    }):
        sess = tf.Session()
        env = DebugBatchEnv(batch_env_hparams, sess)
        sess.run(tf.global_variables_initializer())
        env.initialize()

        env_model_loader = tf.train.Saver(tf.global_variables("next_frame*"))
        trainer_lib.restore_checkpoint(world_model_dir,
                                       env_model_loader,
                                       sess,
                                       must_restore=True)

        model_saver = tf.train.Saver(
            tf.global_variables(".*network_parameters.*"))
        trainer_lib.restore_checkpoint(ppo_model_dir, model_saver, sess)

        key_mapping = gym_problem.env.env.get_keys_to_action()
        # map special codes
        key_mapping[()] = 100
        key_mapping[(ord("r"), )] = 101
        key_mapping[(ord("p"), )] = 102

        play.play(env, zoom=2, fps=10, keys_to_action=key_mapping)
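
gym's play() resolves actions by looking up the tuple of currently pressed key ordinals in keys_to_action; the example extends the environment's native mapping with three sentinel codes (100-102) that DebugBatchEnv presumably interprets as wait, reset, and pause. A standalone sketch of the mapping's shape, with the ordinary action chosen arbitrarily:

key_mapping = {
    (): 100,             # no key pressed -> sentinel 100
    (ord("r"),): 101,    # "r" -> sentinel 101 (reset)
    (ord("p"),): 102,    # "p" -> sentinel 102 (pause)
    (ord("a"),): 0,      # an ordinary game action (illustrative)
}
pressed = tuple(sorted([ord("r")]))
print(key_mapping[pressed])  # 101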
Example #12
def available():
    return sorted(registry.list_problems())
Example #13
def available():
  return sorted(registry.list_problems())
Example #14
def main(_):
  hparams = registry.hparams(FLAGS.loop_hparams_set)
  hparams.parse(FLAGS.loop_hparams)
  output_dir = FLAGS.output_dir

  subdirectories = ["data", "tmp", "world_model", "ppo"]
  using_autoencoder = hparams.autoencoder_train_steps > 0
  if using_autoencoder:
    subdirectories.append("autoencoder")
  directories = setup_directories(output_dir, subdirectories)

  if hparams.game in gym_env.ATARI_GAMES:
    game_with_mode = hparams.game + "_deterministic-v4"
  else:
    game_with_mode = hparams.game

  if using_autoencoder:
    simulated_problem_name = (
        "gym_simulated_discrete_problem_with_agent_on_%s_autoencoded"
        % game_with_mode)
  else:
    simulated_problem_name = ("gym_simulated_discrete_problem_with_agent_on_%s"
                              % game_with_mode)
    if simulated_problem_name not in registry.list_problems():
      tf.logging.info("Game Problem %s not found; dynamically registering",
                      simulated_problem_name)
      gym_env.register_game(hparams.game, game_mode="Deterministic-v4")

  epoch = hparams.epochs-1
  epoch_data_dir = os.path.join(directories["data"], str(epoch))
  ppo_model_dir = directories["ppo"]

  world_model_dir = directories["world_model"]

  gym_problem = registry.problem(simulated_problem_name)

  model_hparams = trainer_lib.create_hparams(hparams.generative_model_params)
  environment_spec = copy.copy(gym_problem.environment_spec)
  environment_spec.simulation_random_starts = hparams.simulation_random_starts

  batch_env_hparams = trainer_lib.create_hparams(hparams.ppo_params)
  batch_env_hparams.add_hparam("model_hparams", model_hparams)
  batch_env_hparams.add_hparam("environment_spec", environment_spec)
  batch_env_hparams.num_agents = 1

  with temporary_flags({
      "problem": simulated_problem_name,
      "model": hparams.generative_model,
      "hparams_set": hparams.generative_model_params,
      "output_dir": world_model_dir,
      "data_dir": epoch_data_dir,
  }):
    sess = tf.Session()
    env = DebugBatchEnv(batch_env_hparams, sess)
    sess.run(tf.global_variables_initializer())
    env.initialize()

    env_model_loader = tf.train.Saver(
        tf.global_variables("next_frame*"))
    trainer_lib.restore_checkpoint(world_model_dir, env_model_loader, sess,
                                   must_restore=True)

    model_saver = tf.train.Saver(
        tf.global_variables(".*network_parameters.*"))
    trainer_lib.restore_checkpoint(ppo_model_dir, model_saver, sess)

    key_mapping = gym_problem.env.env.get_keys_to_action()
    # map special codes
    key_mapping[()] = 100
    key_mapping[(ord("r"),)] = 101
    key_mapping[(ord("p"),)] = 102

    play.play(env, zoom=2, fps=10, keys_to_action=key_mapping)