def initialize_env_specs(hparams):
    """Initializes env_specs using T2TGymEnvs."""
    env = rl_utils.setup_env(hparams, hparams.batch_size,
                             hparams.eval_max_num_noops)
    env.start_new_epoch(0)
    hparams.add_hparam("env_fn", rl.make_real_env_fn(env))
    return hparams
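A minimal usage sketch for the function above, assuming an hparams object that exposes the fields it reads (batch_size, eval_max_num_noops); the hparams-set name below is a placeholder, not taken from the snippet:

from tensor2tensor.utils import trainer_lib

# Placeholder hparams set; any set defining batch_size and eval_max_num_noops would do.
hparams = trainer_lib.create_hparams("rl_atari_base")  # hypothetical set name
hparams = initialize_env_specs(hparams)
env_fn = hparams.env_fn  # real-env constructor registered via add_hparam above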
Example #2
def evaluate_world_model(agent_type, loop_hparams, planner_hparams, model_dir,
                         policy_dir, random_starts_step_limit,
                         debug_video_path, log_every_steps):
    """Evaluates the world model."""
    if debug_video_path:
        debug_video_path = os.path.join(debug_video_path, "0.avi")

    storage_env = rl_utils.setup_env(loop_hparams,
                                     batch_size=1,
                                     max_num_noops=0)
    stacked_env = rl_utils.BatchStackWrapper(storage_env,
                                             loop_hparams.frame_stack_size)
    policy_hparams = trainer_lib.create_hparams(loop_hparams.base_algo_params)
    agent = make_agent_from_hparams(
        agent_type,
        storage_env,
        stacked_env,
        loop_hparams,
        policy_hparams,
        planner_hparams,
        model_dir,
        policy_dir,
        # TODO(koz4k): Loop over eval_sampling_temps?
        sampling_temp=loop_hparams.eval_sampling_temps[0],
    )
    collect_frames_for_random_starts(storage_env, stacked_env, agent,
                                     loop_hparams.frame_stack_size,
                                     random_starts_step_limit, log_every_steps)
    return rl_utils.evaluate_world_model(storage_env,
                                         loop_hparams,
                                         model_dir,
                                         debug_video_path,
                                         split=None)
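A hedged sketch of a call site for the helper above; the hparams-set names, directories, and agent_type value are assumptions and placeholders, not taken from the snippet:

loop_hparams = trainer_lib.create_hparams("rlmb_base")         # assumed set name
planner_hparams = trainer_lib.create_hparams("planner_small")  # assumed set name
wm_metrics = evaluate_world_model(
    agent_type="policy",                  # assumed value accepted by make_agent_from_hparams
    loop_hparams=loop_hparams,
    planner_hparams=planner_hparams,
    model_dir="/tmp/rlmb/world_model",    # placeholder checkpoint directory
    policy_dir="/tmp/rlmb/policy",        # placeholder checkpoint directory
    random_starts_step_limit=10000,
    debug_video_path="/tmp/rlmb/videos",  # "0.avi" is appended inside
    log_every_steps=100,
)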
Example #3
def setup_and_load_epoch(hparams, data_dir, which_epoch_data=None):
  """Load T2TGymEnv with data from one epoch.

  Args:
    hparams: hparams.
    data_dir: data directory.
    which_epoch_data: data from which epoch to load.

  Returns:
    env.
  """
  t2t_env = rl_utils.setup_env(
      hparams, batch_size=hparams.real_batch_size,
      max_num_noops=hparams.max_num_noops
  )
  # Load data.
  if which_epoch_data is not None:
    if which_epoch_data == "last":
      which_epoch_data = infer_last_epoch_num(data_dir)
    assert isinstance(which_epoch_data, int), \
      "{}".format(type(which_epoch_data))
    t2t_env.start_new_epoch(which_epoch_data, data_dir)
  else:
    t2t_env.start_new_epoch(-999)
  return t2t_env
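A minimal usage sketch, assuming rollout data has already been generated under data_dir; the hparams-set name and path are placeholders:

hparams = trainer_lib.create_hparams("rlmb_base")  # assumed set defining real_batch_size, max_num_noops
env = setup_and_load_epoch(hparams, data_dir="/tmp/rlmb/data",
                           which_epoch_data="last")
rollouts = env.current_epoch_rollouts()  # same accessor used in training_loop below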
Example #4
def initialize_env_specs(hparams):
    """Initializes env_specs using T2TGymEnvs."""
    env = rl_utils.setup_env(hparams, hparams.batch_size,
                             hparams.eval_max_num_noops)
    env.start_new_epoch(0)

    # TODO(afrozm): Decouple env_fn from hparams and return both, is there
    # even a need to return hparams? Just return the env_fn?
    hparams.add_hparam("env_fn", rl.make_real_env_fn(env))
    return hparams
Example #5
def initialize_env_specs(hparams):
    """Initializes env_specs using T2TGymEnvs."""
    env = rl_utils.setup_env(hparams,
                             hparams.batch_size,
                             hparams.eval_max_num_noops,
                             hparams.rl_env_max_episode_steps,
                             env_name=hparams.rl_env_name)

    env.start_new_epoch(0)

    return rl.make_real_env_fn(env)
Example #6
def initialize_env_specs(hparams, env_problem_name):
  """Initializes env_specs using the appropriate env."""
  if env_problem_name:
    env = registry.env_problem(env_problem_name, batch_size=hparams.batch_size)
  else:
    env = rl_utils.setup_env(hparams, hparams.batch_size,
                             hparams.eval_max_num_noops,
                             hparams.rl_env_max_episode_steps,
                             env_name=hparams.rl_env_name)
    env.start_new_epoch(0)

  return rl.make_real_env_fn(env)
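A hedged sketch of both branches of the function above; the EnvProblem name is a placeholder, and the fallback branch reads hparams.rl_env_name as shown in the snippet:

# Branch 1: a registered EnvProblem (the name below is a placeholder).
env_fn = initialize_env_specs(hparams, "my_env_problem")
# Branch 2: fall back to rl_utils.setup_env, driven by hparams.rl_env_name.
env_fn = initialize_env_specs(hparams, env_problem_name=None)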
Example #7
def training_loop(hparams, output_dir, report_fn=None, report_metric=None):
  """Run the main training loop."""
  if report_fn:
    assert report_metric is not None

  # Directories
  subdirectories = [
      "data", "tmp", "world_model", ("world_model", "debug_videos"),
      "policy", "eval_metrics"
  ]
  directories = setup_directories(output_dir, subdirectories)

  epoch = -1
  data_dir = directories["data"]
  env = rl_utils.setup_env(
      hparams, batch_size=hparams.real_batch_size,
      max_num_noops=hparams.max_num_noops,
      rl_env_max_episode_steps=hparams.rl_env_max_episode_steps
  )
  env.start_new_epoch(epoch, data_dir)

  if hparams.wm_policy_param_sharing:
    policy_model_dir = directories["world_model"]
  else:
    policy_model_dir = directories["policy"]
  learner = rl_utils.LEARNERS[hparams.base_algo](
      hparams.frame_stack_size, policy_model_dir,
      policy_model_dir, hparams.epochs
  )

  # Timing log function
  log_relative_time = make_relative_timing_fn()

  # Per-epoch state
  epoch_metrics = []
  metrics = {}

  # Collect data from the real environment.
  tf.logging.info("Initial training of the policy in real environment.")
  train_agent_real_env(env, learner, hparams, epoch)
  metrics["mean_reward/train/clipped"] = rl_utils.compute_mean_reward(
      env.current_epoch_rollouts(), clipped=True
  )
  tf.logging.info("Mean training reward (initial): {}".format(
      metrics["mean_reward/train/clipped"]
  ))
  env.generate_data(data_dir)

  eval_metrics_writer = tf.summary.FileWriter(
      directories["eval_metrics"]
  )

  world_model_steps_num = 0

  for epoch in range(hparams.epochs):
    log = make_log_fn(epoch, log_relative_time)

    # Train world model
    log("Training world model")
    world_model_steps_num = train_world_model(
        env, data_dir, directories["world_model"], hparams,
        world_model_steps_num, epoch
    )

    # Train agent
    log("Training policy in simulated environment.")
    train_agent(env, learner, directories["world_model"], hparams, epoch)

    env.start_new_epoch(epoch, data_dir)

    # Train agent on real env (short)
    log("Training policy in real environment.")
    train_agent_real_env(env, learner, hparams, epoch)

    if hparams.stop_loop_early:
      return 0.0

    env.generate_data(data_dir)

    metrics = load_metrics(directories["eval_metrics"], epoch)
    if metrics:
      # Skip eval if metrics have already been written for this epoch. Otherwise
      # we'd overwrite them with wrong data.
      log("Metrics found for this epoch, skipping evaluation.")
    else:
      metrics["mean_reward/train/clipped"] = rl_utils.compute_mean_reward(
          env.current_epoch_rollouts(), clipped=True
      )
      log("Mean training reward: {}".format(
          metrics["mean_reward/train/clipped"]
      ))

      eval_metrics = rl_utils.evaluate_all_configs(hparams, policy_model_dir)
      log("Agent eval metrics:\n{}".format(pprint.pformat(eval_metrics)))
      metrics.update(eval_metrics)

      if hparams.eval_world_model:
        debug_video_path = os.path.join(
            directories["world_model", "debug_videos"],
            "{}.avi".format(env.current_epoch)
        )
        wm_metrics = evaluate_world_model(
            env, hparams, directories["world_model"], debug_video_path
        )
        log("World model eval metrics:\n{}".format(pprint.pformat(wm_metrics)))
        metrics.update(wm_metrics)

      rl_utils.summarize_metrics(eval_metrics_writer, metrics, epoch)

      # Report metrics
      if report_fn:
        if report_metric == "mean_reward":
          metric_name = rl_utils.get_metric_name(
              sampling_temp=hparams.eval_sampling_temps[0],
              max_num_noops=hparams.eval_max_num_noops,
              clipped=False
          )
          report_fn(eval_metrics[metric_name], epoch)
        else:
          report_fn(eval_metrics[report_metric], epoch)

    epoch_metrics.append(metrics)

  # Return the evaluation metrics from the final epoch
  return epoch_metrics[-1]
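A hedged sketch of launching the loop above; the hparams-set name and output directory are placeholders, and report_fn is only needed when an external tuner consumes per-epoch metrics:

hparams = trainer_lib.create_hparams("rlmb_tiny")  # assumed small debug config
final_metrics = training_loop(hparams, output_dir="/tmp/rlmb_run")
# Metrics dict from the final epoch; this key is written inside the loop.
print(final_metrics["mean_reward/train/clipped"])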