Example #1
def create_hooks(use_tfdbg=False,
                 use_dbgprofile=False,
                 dbgprofile_kwargs=None,
                 use_validation_monitor=False,
                 validation_monitor_kwargs=None,
                 use_early_stopping=False,
                 early_stopping_kwargs=None):
  """Create train and eval hooks for Experiment."""
  train_hooks = []
  eval_hooks = []

  if use_tfdbg:
    hook = debug.LocalCLIDebugHook()
    train_hooks.append(hook)
    eval_hooks.append(hook)

  if use_dbgprofile:
    # Recorded traces can be visualized with chrome://tracing/
    # The memory/tensor lifetime is also profiled
    tf.logging.info("Using ProfilerHook")
    defaults = dict(save_steps=10, show_dataflow=True, show_memory=True)
    defaults.update(dbgprofile_kwargs or {})  # tolerate the None default
    train_hooks.append(tf.train.ProfilerHook(**defaults))

  if use_validation_monitor:
    tf.logging.info("Using ValidationMonitor")
    train_hooks.append(
        contrib.learn().monitors.ValidationMonitor(
            hooks=eval_hooks, **validation_monitor_kwargs))

  if use_early_stopping:
    tf.logging.info("Using EarlyStoppingHook")
    hook = metrics_hook.EarlyStoppingHook(**early_stopping_kwargs)
    # Adding to both training and eval so that eval aborts as well
    train_hooks.append(hook)
    eval_hooks.append(hook)

  return train_hooks, eval_hooks
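
A minimal usage sketch for the profiler path (hedged: the save cadence and the /tmp/profile output directory are illustrative placeholders, not library defaults):

# Collect a trace every 50 steps; the files can be opened in chrome://tracing/.
train_hooks, eval_hooks = create_hooks(
    use_dbgprofile=True,
    dbgprofile_kwargs={"save_steps": 50, "output_dir": "/tmp/profile"})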
Example #2
def create_experiment(run_config,
                      hparams,
                      model_name,
                      problem_name,
                      data_dir,
                      train_steps,
                      eval_steps,
                      min_eval_frequency=2000,
                      eval_throttle_seconds=600,
                      schedule="train_and_evaluate",
                      export=False,
                      decode_hparams=None,
                      use_tfdbg=False,
                      use_dbgprofile=False,
                      eval_early_stopping_steps=None,
                      eval_early_stopping_metric=None,
                      eval_early_stopping_metric_delta=None,
                      eval_early_stopping_metric_minimize=True,
                      eval_timeout_mins=240,
                      eval_use_test_set=False,
                      use_tpu=False,
                      use_tpu_estimator=False,
                      use_xla=False,
                      export_saved_model_api_version=1,
                      use_guarantee_const_getter=False,
                      additional_train_hooks=None,
                      additional_eval_hooks=None,
                      warm_start_from=None,
                      decode_from_file="",
                      decode_to_file="",
                      decode_reference="",
                      std_server_protocol=None):
    """Create Experiment."""
    # HParams
    hparams.add_hparam("model_dir", run_config.model_dir)
    hparams.add_hparam("data_dir", data_dir)
    hparams.add_hparam("train_steps", train_steps)
    hparams.add_hparam("eval_steps", eval_steps)
    hparams.add_hparam("schedule", schedule)
    hparams.add_hparam("warm_start_from", warm_start_from)
    hparams.add_hparam("std_server_protocol", std_server_protocol)
    hparams.add_hparam("eval_freq_in_steps", min_eval_frequency)
    hparams.add_hparam("eval_timeout_mins", eval_timeout_mins)
    if decode_hparams is not None:
        decode_hparams.add_hparam("decode_from_file", decode_from_file)
        if decode_to_file and not decode_hparams.decode_to_file:
            decode_hparams.decode_to_file = decode_to_file
        if decode_reference and not decode_hparams.decode_reference:
            decode_hparams.decode_reference = decode_reference
    add_problem_hparams(hparams, problem_name)

    # Estimator
    estimator = create_estimator(
        model_name,
        hparams,
        run_config,
        schedule=schedule,
        decode_hparams=decode_hparams,
        use_tpu=use_tpu,
        use_tpu_estimator=use_tpu_estimator,
        use_xla=use_xla,
        export_saved_model_api_version=export_saved_model_api_version,
        use_guarantee_const_getter=use_guarantee_const_getter)

    # Input fns from Problem
    problem = hparams.problem
    train_input_fn = problem.make_estimator_input_fn(
        tf.estimator.ModeKeys.TRAIN, hparams)

    dataset_split = "test" if eval_use_test_set else None
    dataset_kwargs = {"dataset_split": dataset_split}
    eval_input_fn = problem.make_estimator_input_fn(
        tf.estimator.ModeKeys.EVAL, hparams, dataset_kwargs=dataset_kwargs)

    # Export
    exporter = None
    if export:

        def compare_fn(best_eval_result, current_eval_result):
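            # Export when the current eval's metric (default: loss) improves on the best so far.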
            metric = eval_early_stopping_metric or "loss"
            return current_eval_result[metric] < best_eval_result[metric]

        def serving_input_receiver_fn(hparams, decode_hparams, use_tpu):
            return problem.serving_input_fn(hparams, decode_hparams, use_tpu)

        exporter = tf.estimator.BestExporter(
            name="best",
            serving_input_receiver_fn=serving_input_receiver_fn,
            compare_fn=compare_fn,
            assets_extra=problem.export_assets)

    # Hooks
    validation_monitor_kwargs = dict(
        input_fn=eval_input_fn,
        eval_steps=eval_steps,
        every_n_steps=min_eval_frequency,
        early_stopping_rounds=eval_early_stopping_steps,
        early_stopping_metric=eval_early_stopping_metric,
        early_stopping_metric_minimize=eval_early_stopping_metric_minimize)
    dbgprofile_kwargs = {"output_dir": run_config.model_dir}
    early_stopping_kwargs = dict(
        events_dir=os.path.join(run_config.model_dir, "eval_continuous"),
        tag=eval_early_stopping_metric,
        num_plateau_steps=eval_early_stopping_steps,
        plateau_decrease=eval_early_stopping_metric_minimize,
        plateau_delta=eval_early_stopping_metric_delta,
        every_n_steps=min_eval_frequency)

    # Eval on TPU Pods is not supported yet
    if use_tpu and run_config.tpu_config.num_shards > 8 and "eval" in schedule:
        raise ValueError("Eval is not currently supported on a TPU Pod")

    # In-process eval (and possible early stopping)
    if schedule == "continuous_train_and_eval" and min_eval_frequency:
        tf.logging.warn("ValidationMonitor only works with "
                        "--schedule=train_and_evaluate")
    use_validation_monitor = (schedule == "train_and_evaluate"
                              and min_eval_frequency)
    # Distributed early stopping
    local_schedules = ["train_and_evaluate", "continuous_train_and_eval"]
    use_early_stopping = (schedule not in local_schedules
                          and eval_early_stopping_steps)
    train_hooks, eval_hooks = create_hooks(
        use_tfdbg=use_tfdbg,
        use_dbgprofile=use_dbgprofile,
        dbgprofile_kwargs=dbgprofile_kwargs,
        use_validation_monitor=use_validation_monitor,
        validation_monitor_kwargs=validation_monitor_kwargs,
        use_early_stopping=use_early_stopping,
        early_stopping_kwargs=early_stopping_kwargs)

    hook_context = HookContext(estimator=estimator,
                               problem=problem,
                               hparams=hparams)

    train_hooks += t2t_model.T2TModel.get_train_hooks(model_name, hook_context)
    eval_hooks += t2t_model.T2TModel.get_eval_hooks(model_name, hook_context)
    if additional_train_hooks:
        train_hooks += additional_train_hooks
    if additional_eval_hooks:
        eval_hooks += additional_eval_hooks

    train_hooks = contrib.learn().monitors.replace_monitors_with_hooks(
        train_hooks, estimator)
    eval_hooks = contrib.learn().monitors.replace_monitors_with_hooks(
        eval_hooks, estimator)

    train_spec = tf.estimator.TrainSpec(train_input_fn,
                                        max_steps=train_steps,
                                        hooks=train_hooks)
    eval_spec = tf.estimator.EvalSpec(
        eval_input_fn,
        steps=eval_steps,
        hooks=eval_hooks,
        start_delay_secs=0 if hparams.schedule == "evaluate" else 120,
        throttle_secs=eval_throttle_seconds,
        exporters=exporter)

    return T2TExperiment(estimator, hparams, train_spec, eval_spec,
                         use_validation_monitor, decode_hparams)
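
A sketch of wiring the pieces together (hedged: the model, problem, and directory names are placeholders; create_hparams is assumed to be the library's usual hparams factory, and train_and_evaluate is assumed to be exposed by the returned T2TExperiment):

run_config = create_run_config("transformer", model_dir="/tmp/t2t_train")
hparams = create_hparams("transformer_base")
exp = create_experiment(
    run_config,
    hparams,
    model_name="transformer",
    problem_name="translate_ende_wmt32k",
    data_dir="/tmp/t2t_data",
    train_steps=100000,
    eval_steps=100)
exp.train_and_evaluate()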
Example #3
def create_run_config(model_name,
                      master="",
                      model_dir=None,
                      iterations_per_loop=1000,
                      num_shards=8,
                      log_device_placement=False,
                      save_checkpoints_steps=1000,
                      save_checkpoints_secs=None,
                      keep_checkpoint_max=20,
                      keep_checkpoint_every_n_hours=10000,
                      num_gpus=1,
                      gpu_order="",
                      num_async_replicas=1,
                      enable_graph_rewriter=False,
                      gpu_mem_fraction=0.95,
                      no_data_parallelism=False,
                      optionally_use_dist_strat=False,
                      daisy_chain_variables=True,
                      schedule="continuous_train_and_eval",
                      worker_job="/job:localhost",
                      worker_id=0,
                      ps_replicas=0,
                      ps_job="/job:ps",
                      ps_gpu=0,
                      random_seed=None,
                      sync=False,
                      tpu_infeed_sleep_secs=None,
                      use_tpu=False,
                      use_tpu_estimator=False,
                      xla_jit_level=tf.OptimizerOptions.OFF,
                      inter_op_parallelism_threads=0,
                      log_step_count_steps=100,
                      intra_op_parallelism_threads=0,
                      tpu_config_extra_kwargs=None,
                      cloud_tpu_name="",
                      cloud_tpu_zone=None):
    """Create RunConfig, TPUConfig, and Parallelism object."""
    session_config = create_session_config(
        log_device_placement=log_device_placement,
        enable_graph_rewriter=enable_graph_rewriter,
        gpu_mem_fraction=gpu_mem_fraction,
        use_tpu=use_tpu,
        xla_jit_level=xla_jit_level,
        inter_op_parallelism_threads=inter_op_parallelism_threads,
        intra_op_parallelism_threads=intra_op_parallelism_threads)
    run_config_args = {
        "master": master,
        "evaluation_master": master,
        "model_dir": model_dir,
        "session_config": session_config,
        "save_summary_steps": 100,
        "save_checkpoints_steps": save_checkpoints_steps,
        "save_checkpoints_secs": save_checkpoints_secs,
        "keep_checkpoint_max": keep_checkpoint_max,
        "keep_checkpoint_every_n_hours": keep_checkpoint_every_n_hours,
        "tf_random_seed": random_seed,
        "log_step_count_steps": log_step_count_steps,
    }
    if save_checkpoints_secs:
        del run_config_args["save_checkpoints_steps"]
    run_config_cls = contrib.learn().RunConfig

    if use_tpu or use_tpu_estimator:
        # If using TPUEstimator, use TPU RunConfig, add TPUConfig, and add
        # additional args.
        tpu_config_kwargs = {
            "iterations_per_loop": iterations_per_loop,
            "num_shards": num_shards,
            "per_host_input_for_training": True,
            "initial_infeed_sleep_secs": tpu_infeed_sleep_secs,
        }
        if tpu_config_extra_kwargs is not None:
            tpu_config_kwargs.update(tpu_config_extra_kwargs)
        run_config_cls = contrib.tpu().RunConfig
        tpu_config = contrib.tpu().TPUConfig(**tpu_config_kwargs)
        run_config_args["tpu_config"] = tpu_config
        if not master and "KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS" in os.environ:
            # If running on TPU but no master is set and the KUBE env var is present
            # then we're running on ML Engine. Set the master.
            run_config_args["master"] = os.environ[
                "KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS"]
            run_config_args["evaluation_master"] = run_config_args["master"]
        elif not master and cloud_tpu_name:
            # Update run_config to use cluster instead of master/evaluation_master
            # as we need the cluster spec to use Cloud Pods
            tpu_cluster_resolver = contrib.cluster_resolver(
            ).TPUClusterResolver(tpu=cloud_tpu_name, zone=cloud_tpu_zone)
            run_config_args["cluster"] = tpu_cluster_resolver
            del run_config_args["master"]
            del run_config_args["evaluation_master"]
    elif is_cloud_async_distributed():
        run_config_cls = tf.estimator.RunConfig
        del run_config_args["master"]
        del run_config_args["evaluation_master"]

    # tf.estimator RunConfig construction got broken in TF2: master and
    # evaluation_master are no longer accepted, so the master must now be
    # specified via a global environment variable instead.
    if contrib.is_tf2:
        del run_config_args["evaluation_master"]
        del run_config_args["master"]

    config = run_config_cls(**run_config_args)

    # If not using TPU, add device info for data_parallelism
    config.use_tpu = use_tpu
    if not use_tpu:
        config.t2t_device_info = {
            "num_async_replicas": num_async_replicas,
        }
        use_distribution_strategy = (
            optionally_use_dist_strat
            and t2t_model.T2TModel.has_symmetric_shards(model_name)
            and not no_data_parallelism and ps_replicas == 0 and ps_gpu == 0
            and num_async_replicas == 1)

        if use_distribution_strategy:
            tf.logging.info(
                "Configuring MirroredStrategy DistributionStrategy to replicate the "
                "model.")
            distribution = contrib.distribute().MirroredStrategy()
            config = config.replace(train_distribute=distribution)
            config.data_parallelism = None
        else:
            tf.logging.info(
                "Configuring DataParallelism to replicate the model.")
            config.data_parallelism = devices.data_parallelism(
                daisy_chain_variables=daisy_chain_variables,
                ps_replicas=ps_replicas,
                ps_job=ps_job,
                ps_gpu=ps_gpu,
                schedule=schedule,
                sync=sync,
                worker_gpu=num_gpus,
                worker_replicas=num_async_replicas,
                worker_id=worker_id,
                gpu_order=gpu_order,
                worker_job=worker_job,
                no_data_parallelism=no_data_parallelism)

    return config
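
Two configuration sketches (hedged: the model name, directories, GCS bucket, and TPU name are placeholders):

# Single-machine GPU run, checkpointing every 2000 steps.
gpu_config = create_run_config(
    "transformer",
    model_dir="/tmp/t2t_train",
    save_checkpoints_steps=2000,
    num_gpus=1)

# Cloud TPU run resolved by name; a TPUClusterResolver replaces master/evaluation_master.
tpu_run_config = create_run_config(
    "transformer",
    model_dir="gs://my-bucket/t2t_train",
    use_tpu=True,
    use_tpu_estimator=True,
    cloud_tpu_name="my-tpu",
    iterations_per_loop=100)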