Example #1
def initialize_env_specs(hparams):
    """Initializes env_specs using T2TGymEnvs."""
    env = T2TGymEnv.setup_env_from_hparams(hparams, hparams.batch_size,
                                           hparams.eval_max_num_noops)
    env.start_new_epoch(0)

    # TODO(afrozm): Decouple env_fn from hparams and return both, is there
    # even a need to return hparams? Just return the env_fn?
    hparams.add_hparam("env_fn", rl.make_real_env_fn(env))
    return hparams
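
A minimal usage sketch for initialize_env_specs, assuming tensor2tensor's trainer_lib is importable; "my_rl_hparams_set" is a hypothetical placeholder for a registered hparams set that defines batch_size and eval_max_num_noops:

# Hypothetical driver; "my_rl_hparams_set" stands in for whatever registered
# hparams set supplies batch_size and eval_max_num_noops.
from tensor2tensor.utils import trainer_lib

hparams = trainer_lib.create_hparams("my_rl_hparams_set")
hparams = initialize_env_specs(hparams)

# After the call, hparams carries an "env_fn" entry that downstream training
# code can use to (re)build the real environment.
env_fn = hparams.env_fn
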
Example #2
def evaluate_single_config(hparams, sampling_temp, max_num_noops,
                           agent_model_dir):
    """Evaluate the PPO agent in the real environment."""
    eval_hparams = trainer_lib.create_hparams(hparams.base_algo_params)
    env = T2TGymEnv.setup_env_from_hparams(hparams,
                                           batch_size=hparams.eval_batch_size,
                                           max_num_noops=max_num_noops)
    env.start_new_epoch(0)
    env_fn = rl.make_real_env_fn(env)
    learner = LEARNERS[hparams.base_algo](hparams.frame_stack_size,
                                          base_event_dir=None,
                                          agent_model_dir=agent_model_dir,
                                          total_num_epochs=1)
    learner.evaluate(env_fn, eval_hparams, sampling_temp)
    rollouts = env.current_epoch_rollouts()
    env.close()

    return tuple(
        compute_mean_reward(rollouts, clipped) for clipped in (True, False))
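
As a hedged sketch of how evaluate_single_config might be swept over several evaluation settings (mirroring what an evaluate_all_configs-style wrapper could do); the field names eval_sampling_temps and eval_max_num_noops are taken from the training_loop example below, and the zero-noop variant is an illustrative choice:

# Illustrative sweep over evaluation configurations. agent_model_dir is
# assumed to point at a checkpoint directory written by the policy learner.
def sweep_eval_configs(hparams, agent_model_dir):
    results = {}
    for sampling_temp in hparams.eval_sampling_temps:
        for max_num_noops in (hparams.eval_max_num_noops, 0):
            clipped, unclipped = evaluate_single_config(
                hparams, sampling_temp, max_num_noops, agent_model_dir)
            results[(sampling_temp, max_num_noops)] = {
                "clipped": clipped,
                "unclipped": unclipped,
            }
    return results
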
Example #3

def training_loop(hparams, output_dir, report_fn=None, report_metric=None):
  """Run the main training loop."""
  if report_fn:
    assert report_metric is not None

  # Directories
  subdirectories = [
      "data", "tmp", "world_model", ("world_model", "debug_videos"),
      "policy", "eval_metrics"
  ]
  directories = setup_directories(output_dir, subdirectories)

  epoch = -1
  data_dir = directories["data"]
  env = T2TGymEnv.setup_env_from_hparams(
      hparams, batch_size=hparams.real_batch_size,
      max_num_noops=hparams.max_num_noops
  )
  env.start_new_epoch(epoch, data_dir)

  if hparams.wm_policy_param_sharing:
    policy_model_dir = directories["world_model"]
  else:
    policy_model_dir = directories["policy"]
  learner = rl_utils.LEARNERS[hparams.base_algo](
      hparams.frame_stack_size, policy_model_dir,
      policy_model_dir, hparams.epochs
  )

  # Timing log function
  log_relative_time = make_relative_timing_fn()

  # Per-epoch state
  epoch_metrics = []
  metrics = {}

  # Collect data from the real environment.
  tf.logging.info("Initial training of the policy in real environment.")
  train_agent_real_env(env, learner, hparams, epoch)
  metrics["mean_reward/train/clipped"] = rl_utils.compute_mean_reward(
      env.current_epoch_rollouts(), clipped=True
  )
  tf.logging.info("Mean training reward (initial): {}".format(
      metrics["mean_reward/train/clipped"]
  ))
  env.generate_data(data_dir)

  eval_metrics_writer = tf.summary.FileWriter(
      directories["eval_metrics"]
  )

  world_model_steps_num = 0

  for epoch in range(hparams.epochs):
    log = make_log_fn(epoch, log_relative_time)

    # Train world model
    log("Training world model")
    world_model_steps_num = train_world_model(
        env, data_dir, directories["world_model"], hparams,
        world_model_steps_num, epoch
    )

    # Train agent
    log("Training policy in simulated environment.")
    train_agent(env, learner, directories["world_model"], hparams, epoch)

    env.start_new_epoch(epoch, data_dir)

    # Train agent on real env (short)
    log("Training policy in real environment.")
    train_agent_real_env(env, learner, hparams, epoch)

    if hparams.stop_loop_early:
      return 0.0

    env.generate_data(data_dir)

    metrics = load_metrics(directories["eval_metrics"], epoch)
    if metrics:
      # Skip eval if metrics have already been written for this epoch. Otherwise
      # we'd overwrite them with wrong data.
      log("Metrics found for this epoch, skipping evaluation.")
    else:
      metrics["mean_reward/train/clipped"] = rl_utils.compute_mean_reward(
          env.current_epoch_rollouts(), clipped=True
      )
      log("Mean training reward: {}".format(
          metrics["mean_reward/train/clipped"]
      ))

      eval_metrics = rl_utils.evaluate_all_configs(hparams, policy_model_dir)
      log("Agent eval metrics:\n{}".format(pprint.pformat(eval_metrics)))
      metrics.update(eval_metrics)

      if hparams.eval_world_model:
        debug_video_path = os.path.join(
            directories["world_model", "debug_videos"],
            "{}.avi".format(env.current_epoch)
        )
        wm_metrics = evaluate_world_model(
            env, hparams, directories["world_model"], debug_video_path
        )
        log("World model eval metrics:\n{}".format(pprint.pformat(wm_metrics)))
        metrics.update(wm_metrics)

      summarize_metrics(eval_metrics_writer, metrics, epoch)

      # Report metrics
      if report_fn:
        if report_metric == "mean_reward":
          metric_name = rl_utils.get_metric_name(
              sampling_temp=hparams.eval_sampling_temps[0],
              max_num_noops=hparams.eval_max_num_noops,
              clipped=False
          )
          report_fn(eval_metrics[metric_name], epoch)
        else:
          report_fn(eval_metrics[report_metric], epoch)

    epoch_metrics.append(metrics)

  # Return the evaluation metrics from the final epoch
  return epoch_metrics[-1]
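
Finally, a minimal sketch of how training_loop might be invoked from a small driver; the hparams-set name "rlmb_base" and the output directory are illustrative assumptions, and report_fn simply logs the metric it is handed once per epoch:

# Hypothetical entry point; "rlmb_base" and "/tmp/rlmb_run" are placeholders.
import tensorflow as tf
from tensor2tensor.utils import trainer_lib

def report(value, epoch):
  # training_loop calls report_fn(metric_value, epoch) once per epoch.
  tf.logging.info("epoch %d: reported metric = %s", epoch, value)

def main():
  hparams = trainer_lib.create_hparams("rlmb_base")
  final_metrics = training_loop(
      hparams, output_dir="/tmp/rlmb_run",
      report_fn=report, report_metric="mean_reward")
  tf.logging.info("final epoch metrics: %s", final_metrics)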