Example #1
# Imports assumed by this excerpt (tensor2tensor / TF 1.x module layout):
import os

import tensorflow as tf

from tensor2tensor.utils import t2t_model
from tensor2tensor.utils import trainer_lib

FLAGS = tf.flags.FLAGS
def create_experiment(
    run_config,
    hparams,
    model_name,
    problem_name,
    data_dir,
    train_steps,
    eval_steps,
    min_eval_frequency=2000,
    eval_throttle_seconds=600,
    schedule="train_and_evaluate",
    export=False,
    decode_hparams=None,
    use_tfdbg=False,
    use_dbgprofile=False,
    eval_early_stopping_steps=None,
    eval_early_stopping_metric=None,
    eval_early_stopping_metric_delta=None,
    eval_early_stopping_metric_minimize=True,
    eval_timeout_mins=240,
    use_tpu=False,
    use_tpu_estimator=False,
    use_xla=False,
    additional_train_hooks=None,
    additional_eval_hooks=None,
    warm_start_from=None,
    decode_from_file=None,
    decode_to_file=None,
    decode_reference=None,
    std_server_protocol=None):
  """Create Experiment."""
  # HParams
  hparams.add_hparam("model_dir", run_config.model_dir)
  hparams.add_hparam("data_dir", data_dir)
  hparams.add_hparam("train_steps", train_steps)
  hparams.add_hparam("eval_steps", eval_steps)
  hparams.add_hparam("schedule", schedule)
  hparams.add_hparam("warm_start_from", warm_start_from)
  hparams.add_hparam("std_server_protocol", std_server_protocol)
  hparams.add_hparam("eval_freq_in_steps", min_eval_frequency)
  hparams.add_hparam("eval_timeout_mins", eval_timeout_mins)
  if decode_hparams is not None:
    decode_hparams.add_hparam("decode_from_file", decode_from_file)
    decode_hparams.add_hparam("decode_to_file", decode_to_file)
    decode_hparams.add_hparam("decode_reference", decode_reference)
  trainer_lib.add_problem_hparams(hparams, problem_name)

  # Estimator
  estimator = trainer_lib.create_estimator(
      model_name,
      hparams,
      run_config,
      schedule=schedule,
      decode_hparams=decode_hparams,
      use_tpu=use_tpu,
      use_tpu_estimator=use_tpu_estimator,
      use_xla=use_xla)

  # Input fns from Problem
  problem = hparams.problem
  train_input_fn = problem.make_estimator_input_fn(
      tf.estimator.ModeKeys.TRAIN, hparams,
      dataset_kwargs={"max_records": FLAGS.train_data_size})
  eval_input_fn = problem.make_estimator_input_fn(
      tf.estimator.ModeKeys.EVAL, hparams)

  # Export
  exporter = None
  if export:
    def compare_fn(best_eval_result, current_eval_result):
      # Keep the checkpoint whose early-stopping metric (default: loss) is
      # lowest; this comparison assumes lower is better.
      metric = eval_early_stopping_metric or "loss"
      return current_eval_result[metric] < best_eval_result[metric]

    exporter = tf.estimator.BestExporter(
        name="best",
        serving_input_receiver_fn=lambda: problem.serving_input_fn(hparams),
        compare_fn=compare_fn,
        assets_extra=problem.export_assets)

  # Hooks
  validation_monitor_kwargs = dict(
      input_fn=eval_input_fn,
      eval_steps=eval_steps,
      every_n_steps=min_eval_frequency,
      early_stopping_rounds=eval_early_stopping_steps,
      early_stopping_metric=eval_early_stopping_metric,
      early_stopping_metric_minimize=eval_early_stopping_metric_minimize)
  dbgprofile_kwargs = {"output_dir": run_config.model_dir}
  early_stopping_kwargs = dict(
      events_dir=os.path.join(run_config.model_dir, "eval_continuous"),
      tag=eval_early_stopping_metric,
      num_plateau_steps=eval_early_stopping_steps,
      plateau_decrease=eval_early_stopping_metric_minimize,
      plateau_delta=eval_early_stopping_metric_delta,
      every_n_steps=min_eval_frequency)

  # Eval on TPU Pods is not supported yet
  if use_tpu and run_config.tpu_config.num_shards > 8 and "eval" in schedule:
    raise ValueError("Eval is not currently supported on a TPU Pod")

  # In-process eval (and possible early stopping)
  if schedule == "continuous_train_and_eval" and min_eval_frequency:
    tf.logging.warn("ValidationMonitor only works with "
                    "--schedule=train_and_evaluate")
  use_validation_monitor = (
      schedule == "train_and_evaluate" and min_eval_frequency)
  # Distributed early stopping
  local_schedules = ["train_and_evaluate", "continuous_train_and_eval"]
  use_early_stopping = (
      schedule not in local_schedules and eval_early_stopping_steps)
  train_hooks, eval_hooks = trainer_lib.create_hooks(
      use_tfdbg=use_tfdbg,
      use_dbgprofile=use_dbgprofile,
      dbgprofile_kwargs=dbgprofile_kwargs,
      use_validation_monitor=use_validation_monitor,
      validation_monitor_kwargs=validation_monitor_kwargs,
      use_early_stopping=use_early_stopping,
      early_stopping_kwargs=early_stopping_kwargs)

  hook_context = trainer_lib.HookContext(
      estimator=estimator, problem=problem, hparams=hparams)

  train_hooks += t2t_model.T2TModel.get_train_hooks(model_name, hook_context)
  eval_hooks += t2t_model.T2TModel.get_eval_hooks(model_name, hook_context)
  if additional_train_hooks:
    train_hooks += additional_train_hooks
  if additional_eval_hooks:
    eval_hooks += additional_eval_hooks

  train_hooks = tf.contrib.learn.monitors.replace_monitors_with_hooks(
      train_hooks, estimator)
  eval_hooks = tf.contrib.learn.monitors.replace_monitors_with_hooks(
      eval_hooks, estimator)

  train_spec = tf.estimator.TrainSpec(
      train_input_fn, max_steps=train_steps, hooks=train_hooks)
  eval_spec = tf.estimator.EvalSpec(
      eval_input_fn,
      steps=eval_steps,
      hooks=eval_hooks,
      start_delay_secs=0 if hparams.schedule == "evaluate" else 120,
      throttle_secs=eval_throttle_seconds,
      exporters=exporter)

  return trainer_lib.T2TExperiment(estimator, hparams, train_spec, eval_spec,
                                   use_validation_monitor, decode_hparams)
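For orientation, here is a minimal sketch of how this factory might be driven. It assumes tensor2tensor's trainer_lib helpers (create_run_config, create_hparams) and a T2TExperiment exposing a train_and_evaluate method; the model, problem, and directory values are placeholders, and create_run_config's exact signature varies across tensor2tensor versions.

# Hypothetical wiring; all concrete names below are placeholders.
run_config = trainer_lib.create_run_config(model_dir="/tmp/t2t_train")
hparams = trainer_lib.create_hparams("transformer_base")
exp = create_experiment(
    run_config,
    hparams,
    model_name="transformer",
    problem_name="translate_ende_wmt32k",
    data_dir="/tmp/t2t_data",
    train_steps=100000,
    eval_steps=100)
exp.train_and_evaluate()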
Example #2
def run_std_server():
    # T2TExperiment's leading constructor args (estimator, hparams, train_spec,
    # eval_spec, use_validation_monitor) are irrelevant for a process that only
    # acts as a TF standard server, so placeholder Nones are passed.
    exp = trainer_lib.T2TExperiment(*([None] * 5))
    exp.run_std_server()
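In t2t_trainer, the --schedule flag is dispatched to the experiment method of the same name, which is what lets a parameter-server job reach this bare run_std_server path. A sketch of that dispatch, with execute_schedule as an assumed helper name:

def execute_schedule(exp, schedule):
    # Resolve the schedule string ("train_and_evaluate", "run_std_server",
    # "continuous_train_and_eval", ...) to a method on the experiment.
    if not hasattr(exp, schedule):
        raise ValueError("Experiment has no method %s" % schedule)
    getattr(exp, schedule)()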
Example #3
def create_experiment(run_config,
                      hparams,
                      model_name,
                      params,
                      problem_instance,
                      data_dir,
                      train_steps,
                      eval_steps,
                      min_eval_frequency=2000,
                      eval_throttle_seconds=600,
                      schedule="train_and_evaluate",
                      export=False,
                      decode_hparams=None,
                      use_tfdbg=False,
                      use_dbgprofile=False,
                      use_validation_monitor=False,
                      eval_early_stopping_steps=None,
                      eval_early_stopping_metric=None,
                      eval_early_stopping_metric_delta=None,
                      eval_early_stopping_metric_minimize=True,
                      autotune=False,
                      use_tpu=False):
    """Create Experiment."""
    # HParams
    hparams.add_hparam("model_dir", params.model_dir)
    hparams.add_hparam("data_dir", data_dir)
    hparams.add_hparam("train_steps", train_steps)
    hparams.add_hparam("eval_steps", eval_steps)
    hparams.add_hparam("schedule", schedule)
    # add_problem_hparams is assumed to be defined in the enclosing module.
    add_problem_hparams(hparams, problem_instance)

    # Estimator
    estimator = trainer_lib.create_estimator(model_name,
                                             hparams,
                                             run_config,
                                             schedule=schedule,
                                             decode_hparams=decode_hparams,
                                             use_tpu=use_tpu)

    # Input fns from Problem
    problem = hparams.problem
    train_input_fn = problem.make_estimator_input_fn(
        tf.estimator.ModeKeys.TRAIN, hparams)
    eval_input_fn = problem.make_estimator_input_fn(tf.estimator.ModeKeys.EVAL,
                                                    hparams)

    # Export
    if export:
        tf.logging.warn("Exporting from the trainer is deprecated. "
                        "See serving/export.py.")

    # Hooks
    validation_monitor_kwargs = dict(
        input_fn=eval_input_fn,
        eval_steps=eval_steps,
        every_n_steps=min_eval_frequency,
        early_stopping_rounds=eval_early_stopping_steps,
        early_stopping_metric=eval_early_stopping_metric,
        early_stopping_metric_minimize=eval_early_stopping_metric_minimize)
    dbgprofile_kwargs = {"output_dir": run_config.model_dir}
    early_stopping_kwargs = dict(
        events_dir=os.path.join(run_config.model_dir, "eval_continuous"),
        tag=eval_early_stopping_metric,
        num_plateau_steps=eval_early_stopping_steps,
        plateau_decrease=eval_early_stopping_metric_minimize,
        plateau_delta=eval_early_stopping_metric_delta,
        every_n_steps=min_eval_frequency)

    # In-process eval (and possible early stopping)
    if schedule == "continuous_train_and_eval" and min_eval_frequency:
        tf.logging.warn("ValidationMonitor only works with "
                        "--schedule=train_and_evaluate")
    # Note: this overwrites the use_validation_monitor argument above.
    use_validation_monitor = (schedule == "train_and_evaluate"
                              and min_eval_frequency)
    # Distributed early stopping
    local_schedules = ["train_and_evaluate", "continuous_train_and_eval"]
    use_early_stopping = (schedule not in local_schedules
                          and eval_early_stopping_steps)
    train_hooks, eval_hooks = trainer_lib.create_hooks(
        use_tfdbg=use_tfdbg,
        use_dbgprofile=use_dbgprofile,
        dbgprofile_kwargs=dbgprofile_kwargs,
        use_validation_monitor=use_validation_monitor,
        validation_monitor_kwargs=validation_monitor_kwargs,
        use_early_stopping=use_early_stopping,
        early_stopping_kwargs=early_stopping_kwargs)
    train_hooks += t2t_model.T2TModel.get_train_hooks(model_name)
    eval_hooks += t2t_model.T2TModel.get_eval_hooks(model_name)

    train_hooks = tf.contrib.learn.monitors.replace_monitors_with_hooks(
        train_hooks, estimator)
    eval_hooks = tf.contrib.learn.monitors.replace_monitors_with_hooks(
        eval_hooks, estimator)

    train_spec = tf.estimator.TrainSpec(train_input_fn,
                                        max_steps=train_steps,
                                        hooks=train_hooks)
    eval_spec = tf.estimator.EvalSpec(
        eval_input_fn,
        steps=eval_steps,
        hooks=eval_hooks,
        start_delay_secs=0 if hparams.schedule == "evaluate" else 120,
        throttle_secs=eval_throttle_seconds)

    if autotune:
        hooks_kwargs = {
            "train_monitors": train_hooks,
            "eval_hooks": eval_hooks
        }
        return tf.contrib.learn.Experiment(
            estimator=estimator,
            train_input_fn=train_input_fn,
            eval_input_fn=eval_input_fn,
            train_steps=train_steps,
            eval_steps=eval_steps,
            min_eval_frequency=min_eval_frequency,
            train_steps_per_iteration=min(min_eval_frequency, train_steps),
            eval_delay_secs=0 if schedule == "evaluate" else 120,
            **hooks_kwargs if not use_tpu else {})
    return trainer_lib.T2TExperiment(estimator, hparams, train_spec, eval_spec,
                                     use_validation_monitor, decode_hparams)
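One subtle line in the autotune branch: **hooks_kwargs if not use_tpu else {} parses as **(hooks_kwargs if not use_tpu else {}), so on TPU the Experiment receives no hook kwargs at all. A standalone snippet demonstrating that precedence:

def report(**kwargs):
    # Return the keyword names actually received.
    return sorted(kwargs)

hooks_kwargs = {"train_monitors": [], "eval_hooks": []}
print(report(**hooks_kwargs if False else {}))  # [] -- the empty dict is unpacked
print(report(**hooks_kwargs if True else {}))   # ['eval_hooks', 'train_monitors']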