Example #1
def setup_problems(hparams, using_autoencoder=False):
    """Register problems based on game name."""
    if hparams.game in gym_problems_specs.ATARI_GAMES:
        game_with_mode = hparams.game + "_deterministic-v4"
    else:
        game_with_mode = hparams.game
    game_problems_kwargs = {}
    # Problems
    if using_autoencoder:
        game_problems_kwargs["autoencoder_hparams"] = (
            hparams.autoencoder_hparams_set)
        problem_name = (
            "gym_discrete_problem_with_agent_on_%s_with_autoencoder" %
            game_with_mode)
        world_model_problem = (
            "gym_discrete_problem_with_agent_on_%s_autoencoded" %
            game_with_mode)
        simulated_problem_name = (
            "gym_simulated_discrete_problem_with_agent_on_%s_autoencoded" %
            game_with_mode)
        world_model_eval_problem_name = (
            "gym_simulated_discrete_problem_for_world_model_eval_with_agent_on_%s"
            "_autoencoded" % game_with_mode)
    else:
        problem_name = ("gym_discrete_problem_with_agent_on_%s" %
                        game_with_mode)
        world_model_problem = problem_name
        simulated_problem_name = (
            "gym_simulated_discrete_problem_with_agent_on_%s" % game_with_mode)
        world_model_eval_problem_name = (
            "gym_simulated_discrete_problem_for_world_model_eval_with_agent_on_%s"
            % game_with_mode)
    if problem_name not in registry.list_problems():
        game_problems_kwargs[
            "resize_height_factor"] = hparams.resize_height_factor
        game_problems_kwargs[
            "resize_width_factor"] = hparams.resize_width_factor
        game_problems_kwargs["grayscale"] = hparams.grayscale
        tf.logging.info("Game Problem %s not found; dynamically registering",
                        problem_name)
        gym_problems_specs.create_problems_for_game(
            hparams.game, game_mode="Deterministic-v4", **game_problems_kwargs)
    return (problem_name, world_model_problem, simulated_problem_name,
            world_model_eval_problem_name)
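
A minimal usage sketch for the helper above, assuming an hparams object that carries the game, resize_height_factor, resize_width_factor and grayscale fields it reads; all concrete values below are illustrative, not taken from this example.

# Hypothetical call site: build illustrative hparams and register the problems.
hparams = tf.contrib.training.HParams(
    game="pong",               # assumed game name
    resize_height_factor=2,    # illustrative values
    resize_width_factor=2,
    grayscale=False)
(problem_name, world_model_problem, simulated_problem_name,
 world_model_eval_problem_name) = setup_problems(hparams)
# The registered problem can now be resolved by name.
problem = registry.problem(problem_name)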
Example #2
def main(_):
    tf.gfile.MakeDirs(FLAGS.data_dir)
    tf.gfile.MakeDirs(FLAGS.tmp_dir)

    # Create problem if not already defined
    problem_name = "gym_discrete_problem_with_agent_on_%s" % FLAGS.game
    if problem_name not in registry.list_problems():
        gym_problems_specs.create_problems_for_game(FLAGS.game)

    # Generate
    tf.logging.info("Running %s environment for %d steps for trajectories.",
                    FLAGS.game, FLAGS.num_env_steps)
    problem = registry.problem(problem_name)
    problem.settable_num_steps = FLAGS.num_env_steps
    problem.settable_eval_phase = FLAGS.eval
    problem.generate_data(FLAGS.data_dir, FLAGS.tmp_dir)

    # Log stats
    if problem.statistics.number_of_dones:
        mean_reward = (problem.statistics.sum_of_rewards /
                       problem.statistics.number_of_dones)
        tf.logging.info("Mean reward: %.2f, Num dones: %d", mean_reward,
                        problem.statistics.number_of_dones)
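
This entry point is meant to be run via tf.app.run(); a minimal sketch of the flag definitions it relies on follows (the flag names come from the code above, the defaults are only illustrative).

flags = tf.flags
FLAGS = flags.FLAGS

flags.DEFINE_string("data_dir", "/tmp/gym_data", "Where to write generated episodes.")
flags.DEFINE_string("tmp_dir", "/tmp/gym_tmp", "Scratch directory for the generator.")
flags.DEFINE_string("game", "pong", "Name of the game to run.")
flags.DEFINE_integer("num_env_steps", 5000, "Number of environment steps to collect.")
flags.DEFINE_boolean("eval", False, "Whether to generate data in the eval phase.")


if __name__ == "__main__":
    tf.app.run()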
Example #3
def training_loop(hparams, output_dir, report_fn=None, report_metric=None):
    """Run the main training loop."""
    if report_fn:
        assert report_metric is not None

    # Global state

    # Directories
    subdirectories = ["data", "tmp", "world_model", "ppo"]
    using_autoencoder = hparams.autoencoder_train_steps > 0
    if using_autoencoder:
        subdirectories.append("autoencoder")
    directories = setup_directories(output_dir, subdirectories)

    if hparams.game in gym_problems_specs.ATARI_GAMES:
        game_with_mode = hparams.game + "_deterministic-v4"
    else:
        game_with_mode = hparams.game
    # Problems
    if using_autoencoder:
        problem_name = (
            "gym_discrete_problem_with_agent_on_%s_with_autoencoder" %
            game_with_mode)
        world_model_problem = (
            "gym_discrete_problem_with_agent_on_%s_autoencoded" %
            game_with_mode)
        simulated_problem_name = (
            "gym_simulated_discrete_problem_with_agent_on_%s_autoencoded" %
            game_with_mode)
        world_model_eval_problem_name = (
            "gym_simulated_discrete_problem_for_world_model_eval_with_agent_on_%s"
            "_autoencoded" % game_with_mode)
    else:
        problem_name = ("gym_discrete_problem_with_agent_on_%s" %
                        game_with_mode)
        world_model_problem = problem_name
        simulated_problem_name = (
            "gym_simulated_discrete_problem_with_agent_on_%s" % game_with_mode)
        world_model_eval_problem_name = (
            "gym_simulated_discrete_problem_for_world_model_eval_with_agent_on_%s"
            % game_with_mode)
        if problem_name not in registry.list_problems():
            tf.logging.info(
                "Game Problem %s not found; dynamically registering",
                problem_name)
            gym_problems_specs.create_problems_for_game(
                hparams.game,
                resize_height_factor=hparams.resize_height_factor,
                resize_width_factor=hparams.resize_width_factor,
                game_mode="Deterministic-v4")

    # Autoencoder model dir
    autoencoder_model_dir = directories.get("autoencoder")

    # Timing log function
    log_relative_time = make_relative_timing_fn()

    # Per-epoch state
    epoch_metrics = []
    epoch_data_dirs = []

    ppo_model_dir = None
    data_dir = os.path.join(directories["data"], "initial")
    epoch_data_dirs.append(data_dir)
    # Collect data from the real environment with PPO or random policy.
    if hparams.gather_ppo_real_env_data:
        ppo_model_dir = directories["ppo"]
        tf.logging.info("Initial training of PPO in real environment.")
        ppo_event_dir = os.path.join(directories["world_model"],
                                     "ppo_summaries/initial")
        train_agent_real_env(problem_name,
                             ppo_model_dir,
                             ppo_event_dir,
                             directories["world_model"],
                             data_dir,
                             hparams,
                             epoch=-1,
                             is_final_epoch=False)

    tf.logging.info("Generating real environment data with %s policy",
                    "PPO" if hparams.gather_ppo_real_env_data else "random")
    mean_reward = generate_real_env_data(problem_name, ppo_model_dir, hparams,
                                         data_dir, directories["tmp"])
    tf.logging.info("Mean reward (random): {}".format(mean_reward))

    eval_metrics_event_dir = os.path.join(directories["world_model"],
                                          "eval_metrics_event_dir")
    eval_metrics_writer = tf.summary.FileWriter(eval_metrics_event_dir)
    model_reward_accuracy_summary = tf.Summary()
    model_reward_accuracy_summary.value.add(tag="model_reward_accuracy",
                                            simple_value=None)
    mean_reward_summary = tf.Summary()
    mean_reward_summary.value.add(tag="mean_reward", simple_value=None)

    for epoch in range(hparams.epochs):
        is_final_epoch = (epoch + 1) == hparams.epochs
        log = make_log_fn(epoch, log_relative_time)

        # Combine all previously collected environment data
        epoch_data_dir = os.path.join(directories["data"], str(epoch))
        tf.gfile.MakeDirs(epoch_data_dir)
        # Because the data is being combined in every iteration, we only need to
        # copy from the previous directory.
        combine_training_data(registry.problem(problem_name), epoch_data_dir,
                              epoch_data_dirs[-1:])
        epoch_data_dirs.append(epoch_data_dir)

        if using_autoencoder:
            # Train the Autoencoder on all prior environment frames
            log("Training Autoencoder")
            train_autoencoder(problem_name, epoch_data_dir,
                              autoencoder_model_dir, hparams, epoch)

            log("Autoencoding environment frames")
            encode_env_frames(problem_name, world_model_problem,
                              autoencoder_model_dir, epoch_data_dir)

        # Train world model
        log("Training world model")
        train_world_model(world_model_problem, epoch_data_dir,
                          directories["world_model"], hparams, epoch)

        # Evaluate world model
        model_reward_accuracy = 0.
        if hparams.eval_world_model:
            log("Evaluating world model")
            model_reward_accuracy = evaluate_world_model(
                world_model_eval_problem_name, world_model_problem, hparams,
                directories["world_model"], epoch_data_dir, directories["tmp"])
            log("World model reward accuracy: %.4f", model_reward_accuracy)

        # Train PPO
        log("Training PPO in simulated environment.")
        ppo_event_dir = os.path.join(directories["world_model"],
                                     "ppo_summaries", str(epoch))
        ppo_model_dir = directories["ppo"]
        if not hparams.ppo_continue_training:
            ppo_model_dir = ppo_event_dir
        train_agent(simulated_problem_name,
                    ppo_model_dir,
                    ppo_event_dir,
                    directories["world_model"],
                    epoch_data_dir,
                    hparams,
                    epoch=epoch,
                    is_final_epoch=is_final_epoch)

        # Train PPO on real env (short)
        log("Training PPO in real environment.")
        train_agent_real_env(problem_name,
                             ppo_model_dir,
                             ppo_event_dir,
                             directories["world_model"],
                             epoch_data_dir,
                             hparams,
                             epoch=epoch,
                             is_final_epoch=is_final_epoch)

        if hparams.stop_loop_early:
            return 0.0
        # Collect data from the real environment.
        log("Generating real environment data")
        eval_data_dir = os.path.join(epoch_data_dir, "eval")
        mean_reward = generate_real_env_data(
            problem_name,
            ppo_model_dir,
            hparams,
            eval_data_dir,
            directories["tmp"],
            autoencoder_path=autoencoder_model_dir,
            eval_phase=True)
        log("Mean eval reward: {}".format(mean_reward))

        if not is_final_epoch:
            generation_mean_reward = generate_real_env_data(
                problem_name,
                ppo_model_dir,
                hparams,
                epoch_data_dir,
                directories["tmp"],
                autoencoder_path=autoencoder_model_dir,
                eval_phase=False)
            log("Mean reward during generation: {}".format(
                generation_mean_reward))

        # Summarize metrics
        assert model_reward_accuracy is not None
        assert mean_reward is not None
        model_reward_accuracy_summary.value[
            0].simple_value = model_reward_accuracy
        mean_reward_summary.value[0].simple_value = mean_reward
        eval_metrics_writer.add_summary(model_reward_accuracy_summary, epoch)
        eval_metrics_writer.add_summary(mean_reward_summary, epoch)
        eval_metrics_writer.flush()

        # Report metrics
        eval_metrics = {
            "model_reward_accuracy": model_reward_accuracy,
            "mean_reward": mean_reward
        }
        epoch_metrics.append(eval_metrics)
        log("Eval metrics: %s", str(eval_metrics))
        if report_fn:
            report_fn(eval_metrics[report_metric], epoch)

    # Return the evaluation metrics from the final epoch
    return epoch_metrics[-1]
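
A minimal driver sketch for the loop above, mirroring the flag handling at the start of Example #4 (the loop_hparams_set, loop_hparams and output_dir flags appear there); everything else is an assumption.

def main(_):
    # Resolve the experiment hparams and run the full model-based RL loop.
    hparams = registry.hparams(FLAGS.loop_hparams_set)
    hparams.parse(FLAGS.loop_hparams)
    final_metrics = training_loop(hparams, FLAGS.output_dir)
    tf.logging.info("Final epoch metrics: %s", str(final_metrics))


if __name__ == "__main__":
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.app.run()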
Example #4
def main(_):
    hparams = registry.hparams(FLAGS.loop_hparams_set)
    hparams.parse(FLAGS.loop_hparams)
    output_dir = FLAGS.output_dir

    subdirectories = ["data", "tmp", "world_model", "ppo"]
    using_autoencoder = hparams.autoencoder_train_steps > 0
    if using_autoencoder:
        subdirectories.append("autoencoder")
    directories = setup_directories(output_dir, subdirectories)

    if hparams.game in gym_problems_specs.ATARI_GAMES:
        game_with_mode = hparams.game + "_deterministic-v4"
    else:
        game_with_mode = hparams.game

    if using_autoencoder:
        simulated_problem_name = (
            "gym_simulated_discrete_problem_with_agent_on_%s_autoencoded" %
            game_with_mode)
    else:
        simulated_problem_name = (
            "gym_simulated_discrete_problem_with_agent_on_%s" % game_with_mode)
        if simulated_problem_name not in registry.list_problems():
            tf.logging.info(
                "Game Problem %s not found; dynamically registering",
                simulated_problem_name)
            gym_problems_specs.create_problems_for_game(
                hparams.game, game_mode="Deterministic-v4")

    epoch = hparams.epochs - 1
    epoch_data_dir = os.path.join(directories["data"], str(epoch))
    ppo_model_dir = directories["ppo"]

    world_model_dir = directories["world_model"]

    gym_problem = registry.problem(simulated_problem_name)

    # Assemble hparams for the simulated batch environment: the world-model
    # hparams and the simulated problem's environment spec on top of the PPO
    # params, with a single agent.
    model_hparams = trainer_lib.create_hparams(hparams.generative_model_params)
    environment_spec = copy.copy(gym_problem.environment_spec)
    environment_spec.simulation_random_starts = hparams.simulation_random_starts

    batch_env_hparams = trainer_lib.create_hparams(hparams.ppo_params)
    batch_env_hparams.add_hparam("model_hparams", model_hparams)
    batch_env_hparams.add_hparam("environment_spec", environment_spec)
    batch_env_hparams.num_agents = 1

    with temporary_flags({
            "problem": simulated_problem_name,
            "model": hparams.generative_model,
            "hparams_set": hparams.generative_model_params,
            "output_dir": world_model_dir,
            "data_dir": epoch_data_dir,
    }):
        sess = tf.Session()
        env = DebugBatchEnv(batch_env_hparams, sess)
        sess.run(tf.global_variables_initializer())
        env.initialize()

        # Restore the world-model weights (variables matching "next_frame*").
        env_model_loader = tf.train.Saver(tf.global_variables("next_frame*"))
        trainer_lib.restore_checkpoint(world_model_dir,
                                       env_model_loader,
                                       sess,
                                       must_restore=True)

        # Restore the PPO policy weights.
        model_saver = tf.train.Saver(
            tf.global_variables(".*network_parameters.*"))
        trainer_lib.restore_checkpoint(ppo_model_dir, model_saver, sess)

        key_mapping = gym_problem.env.env.get_keys_to_action()
        # map special codes
        key_mapping[()] = 100
        key_mapping[(ord("r"), )] = 101
        key_mapping[(ord("p"), )] = 102

        play.play(env, zoom=2, fps=10, keys_to_action=key_mapping)
def testGymAtariBoots(self):
    problem = gym_problems_specs.create_problems_for_game("pong")["base"]()
    self.assertEqual(210, problem.frame_height)
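
As the test fragment suggests, create_problems_for_game returns the problem classes in a dict keyed by variant (and registers them, as Example #2 relies on). A minimal sketch, assuming the default game mode as in Example #2:

# Register the Pong problems and instantiate the "base" variant directly.
problems = gym_problems_specs.create_problems_for_game("pong")
base_problem = problems["base"]()
assert base_problem.frame_height == 210  # raw Atari frames are 210x160

# The same problem is also reachable through the registry by name.
pong_problem = registry.problem("gym_discrete_problem_with_agent_on_pong")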