Example #1
def create_estimator(working_dir,
                     model_fn,
                     keep_checkpoint_max=20,
                     iterations_per_loop=320,
                     warmstart=None):
    """Create a TF estimator. Used when not using TPU.

  Args:
    working_dir: working directory for loading the model.
    model_fn: an estimator model function.
    keep_checkpoint_max: the maximum number of checkpoints to save in checkpoint
      directory.
    iterations_per_loop: number of steps to run on TPU before outfeeding
      metrics to the CPU. If the number of iterations in the loop would exceed
      the number of train steps, the loop will exit before reaching
      --iterations_per_loop. The larger this value is, the higher
      the utilization on the TPU. For CPU-only training, this flag is equal to
      `num_epochs * num_minibatches`.
    warmstart: if not None, warm start the estimator from an existing
      checkpoint.
  """
    run_config = tf_estimator.RunConfig(
        save_checkpoints_steps=iterations_per_loop,
        save_summary_steps=iterations_per_loop,
        keep_checkpoint_max=keep_checkpoint_max)

    if warmstart is not None:
        return tf_estimator.Estimator(model_fn,
                                      model_dir=working_dir,
                                      config=run_config,
                                      warm_start_from=warmstart)
    else:
        return tf_estimator.Estimator(model_fn,
                                      model_dir=working_dir,
                                      config=run_config)
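
For reference, a minimal call sketch; the model function, input function and directories here are hypothetical:

# Hypothetical usage: a plain CPU/GPU estimator that warm-starts from a
# previous run's checkpoint directory.
estimator = create_estimator(
    working_dir="/tmp/workdir",
    model_fn=my_model_fn,  # any tf.estimator-compatible model_fn
    warmstart="/tmp/previous_run")  # path or tf_estimator.WarmStartSettings
estimator.train(input_fn=my_input_fn, max_steps=1000)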
Example #2
def initiate_task_helper(model_params,
                         ckpt_directory=None,
                         pruning_params=None):
  """Get all predictions for eval.

  Args:
    model_params:
    ckpt_directory: model checkpoint directory containing event file
    pruning_params:

  Returns:
    pd.DataFrame containing metrics from event file
  """

  if model_params["task"] != "imagenet_training":
    classifier = tf_estimator.Estimator(
        model_fn=model_fn_w_pruning, params=model_params)

    if model_params["task"] in ["imagenet_predictions"]:
      predictions = classifier.predict(
          input_fn=data_input.input_fn, checkpoint_path=ckpt_directory)
      return predictions

    if model_params["task"] in [
        "robustness_imagenet_a", "robustness_imagenet_c", "robustness_pie",
        "imagenet_eval", "ckpt_prediction"
    ]:

      eval_steps = model_params["num_eval_images"] // model_params["batch_size"]
      tf.logging.info("start computing eval metrics...")
      classifier = tf_estimator.Estimator(
          model_fn=model_fn_w_pruning, params=model_params)
      evaluation_metrics = classifier.evaluate(
          input_fn=data_input.input_fn,
          steps=eval_steps,
          checkpoint_path=ckpt_directory)
      tf.logging.info("finished per class accuracy eval.")
      return evaluation_metrics

  else:
    model_params["pruning_dict"] = pruning_params
    run_config = tf_estimator.RunConfig(
        save_summary_steps=300,
        save_checkpoints_steps=1000,
        log_step_count_steps=100)
    classifier = tf_estimator.Estimator(
        model_fn=model_fn_w_pruning, config=run_config, params=model_params)
    tf.logging.info("start training...")
    classifier.train(
        input_fn=data_input.input_fn, max_steps=model_params["num_train_steps"])
    tf.logging.info("finished training.")
Example #3
def train_and_eval():
    """Train and evaluate a model."""
    save_summary_steps = FLAGS.save_summaries_steps
    save_checkpoints_steps = FLAGS.save_checkpoints_steps
    log_step_count = FLAGS.log_step_count

    config = tf_estimator.RunConfig(
        save_summary_steps=save_summary_steps,
        save_checkpoints_steps=save_checkpoints_steps,
        log_step_count_steps=log_step_count,
        keep_checkpoint_max=None)

    params = {'dummy': 0}
    estimator = tf_estimator.Estimator(model_fn=model_fn,
                                       model_dir=FLAGS.checkpoint_dir,
                                       config=config,
                                       params=params)

    train_spec = tf_estimator.TrainSpec(input_fn=train_input_fn,
                                        max_steps=FLAGS.train_steps)

    eval_spec = tf_estimator.EvalSpec(input_fn=eval_input_fn,
                                      start_delay_secs=60,
                                      steps=FLAGS.eval_examples,
                                      throttle_secs=60)

    tf_estimator.train_and_evaluate(estimator, train_spec, eval_spec)
Example #4
def run_mnist():
  """Run MNIST training and eval loop."""
  mnist_classifier = tf_estimator.Estimator(
      model_fn=model_fn,
      model_dir=FLAGS.model_dir)

  # Set up training and evaluation input functions.
  def train_input_fn():
    """Prepare data for training."""

    # When choosing shuffle buffer sizes, larger sizes result in better
    # randomness, while smaller sizes use less memory. MNIST is a small
    # enough dataset that we can easily shuffle the full epoch.
    ds = dataset.train(FLAGS.data_dir)
    ds_batched = ds.cache().shuffle(buffer_size=50000).batch(FLAGS.batch_size)

    # Iterate through the dataset a set number (`epochs_between_evals`) of times
    # during each training session.
    ds = ds_batched.repeat(FLAGS.epochs_between_evals)
    return ds

  def eval_input_fn():
    return dataset.test(FLAGS.data_dir).batch(
        FLAGS.batch_size).make_one_shot_iterator().get_next()

  # Train and evaluate model.
  for _ in range(FLAGS.train_epochs // FLAGS.epochs_between_evals):
    mnist_classifier.train(input_fn=train_input_fn, hooks=None)
    eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
    print("\nEvaluation results:\n\t%s\n" % eval_results)
Example #5
def experiment_function(run_config, hparams):
    """An experiment function satisfying the tf.estimator API.

  Args:
    run_config: A learn_running.EstimatorConfig object.
    hparams: Unused set of hyperparams.

  Returns:
    experiment: A tf.contrib.learn.Experiment object.
  """
    del hparams

    train_input_fn = partial(input_function, is_train=True)
    eval_input_fn = partial(input_function, is_train=False)

    estimator = tf_estimator.Estimator(model_fn=model_function,
                                       config=run_config,
                                       model_dir=run_config.model_dir)

    experiment = tf.contrib.learn.Experiment(
        estimator=estimator,
        train_input_fn=train_input_fn,
        eval_input_fn=eval_input_fn,
        eval_steps=FLAGS.num_eval_steps,
    )

    return experiment
Example #6
    def test_latest_module_exporter_with_eval_spec(self):
        model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
        estimator = tf_estimator.Estimator(_get_model_fn(register_module=True),
                                           model_dir=model_dir)
        exporter = hub.LatestModuleExporter("tf_hub",
                                            _serving_input_fn,
                                            exports_to_keep=2)
        estimator.train(_input_fn, max_steps=1)
        export_base_dir = os.path.join(model_dir, "export", "tf_hub")

        exporter.export(estimator, export_base_dir)
        timestamp_dirs = tf.compat.v1.gfile.ListDirectory(export_base_dir)
        self.assertEqual(1, len(timestamp_dirs))
        oldest_timestamp = timestamp_dirs[0]

        expected_module_dir = os.path.join(export_base_dir, timestamp_dirs[0],
                                           _EXPORT_MODULE_NAME)
        self.assertTrue(tf.compat.v1.gfile.IsDirectory(expected_module_dir))

        exporter.export(estimator, export_base_dir)
        timestamp_dirs = tf.compat.v1.gfile.ListDirectory(export_base_dir)
        self.assertEqual(2, len(timestamp_dirs))

        # Triggering yet another export should clean the oldest export.
        exporter.export(estimator, export_base_dir)
        timestamp_dirs = tf.compat.v1.gfile.ListDirectory(export_base_dir)
        self.assertEqual(2, len(timestamp_dirs))
        self.assertFalse(oldest_timestamp in timestamp_dirs)
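
For comparison, outside of tests an exporter like this is usually attached to an EvalSpec so that train_and_evaluate exports after each evaluation; a minimal sketch, with a hypothetical eval input function:

eval_spec = tf_estimator.EvalSpec(
    input_fn=my_eval_input_fn,  # hypothetical
    exporters=[hub.LatestModuleExporter("tf_hub", _serving_input_fn)])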
Example #7
    def testLatestModuleExporterDirectly(self):
        model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
        export_base_dir = os.path.join(
            tempfile.mkdtemp(dir=self.get_temp_dir()), "export")

        estimator = tf_estimator.Estimator(_get_model_fn(register_module=True),
                                           model_dir=model_dir)
        estimator.train(input_fn=_input_fn, steps=1)

        exporter = hub.LatestModuleExporter("exporter_name", _serving_input_fn)
        export_dir = exporter.export(estimator=estimator,
                                     export_path=export_base_dir,
                                     eval_result=None,
                                     is_the_final_export=None)

        # Check that a timestamped directory is created in the expected location.
        timestamp_dirs = tf.compat.v1.gfile.ListDirectory(export_base_dir)
        self.assertEqual(1, len(timestamp_dirs))
        self.assertEqual(
            tf.compat.as_bytes(os.path.join(export_base_dir,
                                            timestamp_dirs[0])),
            tf.compat.as_bytes(export_dir))

        # Check that the timestamped directory contains the exported module.
        expected_module_dir = os.path.join(
            tf.compat.as_bytes(export_dir),
            tf.compat.as_bytes(_EXPORT_MODULE_NAME))
        self.assertTrue(tf.compat.v1.gfile.IsDirectory(expected_module_dir))
Example #8
def create_estimator(experiment_dir, hparams, decode_length=20):
    """Creates an estimator with given hyper parameters."""
    if FLAGS.worker_gpu > 1:
        strategy = tf.distribute.MirroredStrategy()
    else:
        strategy = None
    config = tf_estimator.RunConfig(save_checkpoints_steps=1000,
                                    save_summary_steps=300,
                                    train_distribute=strategy)
    model_fn = seq2act_estimator.create_model_fn(
        hparams,
        seq2act_estimator.compute_additional_loss
        if hparams.use_additional_loss else None,
        seq2act_estimator.compute_additional_metric
        if hparams.use_additional_loss else None,
        compute_seq_accuracy=True,
        decode_length=decode_length)
    if FLAGS.reference_checkpoint:
        latest_checkpoint = tf.train.latest_checkpoint(
            FLAGS.reference_checkpoint)
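        # vars_to_warm_start accepts variable names or regular expressions;
        # the ".*" patterns below match entire variable scopes.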
        ws = tf_estimator.WarmStartSettings(
            ckpt_to_initialize_from=latest_checkpoint,
            vars_to_warm_start=[
                "embed_tokens/task_embed_w", "encode_decode/.*",
                "output_layer/.*"
            ])
    else:
        ws = None
    estimator = tf_estimator.Estimator(model_fn=model_fn,
                                       model_dir=experiment_dir,
                                       config=config,
                                       warm_start_from=ws)
    return estimator
Example #9
def run():
    """Runs train_and_evaluate."""
    hparams_filename = os.path.join(FLAGS.model_dir, 'hparams.json')
    if FLAGS.is_chief:
        gfile.MakeDirs(FLAGS.model_dir)
        hparams = core.read_hparams(FLAGS.hparams, get_hparams())
        core.write_hparams(hparams, hparams_filename)

    # Always load HParams from model_dir.
    hparams = core.wait_for_hparams(hparams_filename, get_hparams())

    grammar = grammar_utils.load_grammar(grammar_path=hparams.grammar_path)

    estimator = tf_estimator.Estimator(
        model_fn=functools.partial(model_fn, grammar=grammar),
        params=hparams,
        config=tf_estimator.RunConfig(
            save_checkpoints_secs=hparams.save_checkpoints_secs,
            keep_checkpoint_max=hparams.keep_checkpoint_max))

    train_spec = tf_estimator.TrainSpec(
        input_fn=functools.partial(input_ops.input_fn,
                                   input_pattern=hparams.train_pattern,
                                   grammar=grammar),
        max_steps=hparams.train_steps)

    # NOTE(leeley): The SavedModel will be stored under the
    # tf.saved_model.tag_constants.SERVING tag.
    latest_exporter = tf_estimator.LatestExporter(
        name='latest_exported_model',
        serving_input_receiver_fn=functools.partial(
            input_ops.serving_input_receiver_fn,
            params=hparams,
            num_production_rules=grammar.num_production_rules),
        exports_to_keep=hparams.exports_to_keep)

    eval_hooks = []
    if hparams.num_expressions_per_condition > 0:
        eval_hooks.append(
            metrics.GenerationWithLeadingPowersHook(
                generation_leading_powers_abs_sums=core.hparams_list_value(
                    hparams.generation_leading_powers_abs_sums),
                num_expressions_per_condition=(
                    hparams.num_expressions_per_condition),
                max_length=hparams.max_length,
                grammar=grammar))

    eval_spec = tf_estimator.EvalSpec(
        input_fn=functools.partial(input_ops.input_fn,
                                   input_pattern=hparams.tune_pattern,
                                   grammar=grammar),
        steps=hparams.eval_steps,
        exporters=latest_exporter,
        start_delay_secs=hparams.start_delay_secs,
        throttle_secs=hparams.throttle_secs,
        hooks=eval_hooks)

    tf_estimator.train_and_evaluate(estimator, train_spec, eval_spec)
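
As the NOTE above says, the export is stored under the SERVING tag, so it can be loaded back in TF1 style; a minimal sketch with a hypothetical export path:

with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    tf.compat.v1.saved_model.loader.load(
        sess, [tf.saved_model.SERVING],
        "/tmp/model_dir/export/latest_exported_model/1565000000")  # hypothetical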
Example #10
def print_variable_names():
    """Print variable names in a model."""
    params = {'dummy': 0}
    estimator = tf_estimator.Estimator(model_fn=model_fn,
                                       model_dir=FLAGS.checkpoint_dir,
                                       params=params)
    names = estimator.get_variable_names()
    for name in names:
        print(name)
Example #11
def run_local_training(losses_fn,
                       input_fn,
                       trainer_params_overrides,
                       model_params,
                       vars_to_restore_fn=None):
    """Run a simple single-mechine traing loop.

  Args:
    losses_fn: A callable that receives two arguments, `features` and `params`,
      both are dictionaries, and returns a dictionary whose values are the
      losses. Their sum is the total loss to be minimized.
    input_fn: A callable that complies with tf.Estimtor's definition of
      input_fn.
    trainer_params_overrides: A dictionary or a ParameterContainer with
      overrides for the default values in TRAINER_PARAMS above.
    model_params: A ParameterContainer that will be passed to the model (i. e.
      to losses_fn and input_fn).
    vars_to_restore_fn: A callable that receives no arguments. When called,
      expected to provide a dictionary that maps the checkpoint name of each
      variable to the respective variable object. This dictionary will be used
      as `var_list` in a Saver object used for initializing from the checkpoint
      at trainer_params.init_ckpt. If None, the default saver will be used.
  """
    trainer_params = ParameterContainer.from_defaults_and_overrides(
        TRAINER_PARAMS, trainer_params_overrides, is_strict=True)

    run_config_params = {
        'model_dir': trainer_params.model_dir,
        'save_summary_steps': 50,
        'keep_checkpoint_every_n_hours':
        trainer_params.keep_checkpoint_every_n_hours,
        'log_step_count_steps': 50,
    }
    logging.info(
        'Estimators run config parameters:\n%s',
        json.dumps(run_config_params, indent=2, sort_keys=True, default=str))
    run_config = tf_estimator.RunConfig(**run_config_params)

    def estimator_spec_fn(features, labels, mode, params):
        del labels  # unused
        return _build_estimator_spec(losses_fn(features, mode, params),
                                     trainer_params=trainer_params,
                                     mode=mode,
                                     use_tpu=False)

    init_hook = InitFromCheckpointHook(trainer_params.model_dir,
                                       trainer_params.init_ckpt,
                                       vars_to_restore_fn)

    estimator = tf_estimator.Estimator(model_fn=estimator_spec_fn,
                                       config=run_config,
                                       params=model_params.as_dict())

    estimator.train(input_fn=input_fn,
                    max_steps=trainer_params.max_steps,
                    hooks=[init_hook])
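
A vars_to_restore_fn of the kind the docstring describes might look like this sketch; the identity mapping below is an assumption, edit the keys to remap scopes:

def my_vars_to_restore_fn():
    # Map each variable's checkpoint name to the variable itself; change the
    # key to remap names when the checkpoint used a different variable scope.
    return {v.op.name: v for v in tf.compat.v1.global_variables()}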
Example #12
    def testEndToEnd(self):

        params = imagenet_params

        params['output_dir'] = '/tmp/'
        params['batch_size'] = 2
        params['num_train_steps'] = 1
        params['eval_steps'] = 1
        params['threshold'] = 80.
        params['data_format'] = 'channels_last'
        mean_stats = [0.485, 0.456, 0.406]
        std_stats = [0.229, 0.224, 0.225]
        update_params = {
            'mean_rgb': mean_stats,
            'stddev_rgb': std_stats,
            'lr_schedule': [  # (multiplier, epoch to start) tuples
                (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
            ],
            'momentum': 0.9,
            'data_format': 'channels_last'
        }
        params.update(update_params)

        dataset_ = data_input.DataIterator(mode=FLAGS.mode,
                                           data_directory='',
                                           saliency_method='ig_smooth_2',
                                           transformation='modified_image',
                                           threshold=params['threshold'],
                                           keep_information=False,
                                           use_squared_value=True,
                                           mean_stats=mean_stats,
                                           std_stats=std_stats,
                                           test_small_sample=True,
                                           num_cores=FLAGS.num_cores)

        images, labels = dataset_.input_fn(params)
        self.assertEqual(images.shape.as_list(), [2, 224, 224, 3])
        self.assertEqual(labels.shape.as_list(), [2])

        run_config = tf_estimator.RunConfig(
            model_dir=FLAGS.dest_dir,
            save_checkpoints_steps=FLAGS.steps_per_checkpoint)

        classifier = tf_estimator.Estimator(model_fn=resnet_model_fn,
                                            model_dir=FLAGS.dest_dir,
                                            params=params,
                                            config=run_config)
        classifier.train(input_fn=dataset_.input_fn, max_steps=1)
        tf.logging.info('finished training.')
Example #13
def main(unused_argv):
    tf.compat.v1.logging.set_verbosity(0)

    # Load training and test data.
    train_data, train_labels, test_data, test_labels = load_adult()

    # Instantiate the tf.Estimator.
    adult_classifier = tf_compat_v1_estimator.Estimator(
        model_fn=nn_model_fn, model_dir=FLAGS.model_dir)

    # Create tf.Estimator input functions for the training and test data.
    eval_input_fn = tf_compat_v1_estimator.inputs.numpy_input_fn(
        x={'x': test_data}, y=test_labels, num_epochs=1, shuffle=False)

    # Training loop.
    steps_per_epoch = num_examples // sampling_batch
    test_accuracy_list = []
    for epoch in range(1, FLAGS.epochs + 1):
        for _ in range(steps_per_epoch):
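            # Poisson subsampling: each example is included independently
            # with probability sampling_batch / num_examples, so the realized
            # batch size varies around sampling_batch.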
            whether = np.random.random_sample(num_examples) > (
                1 - sampling_batch / num_examples)
            subsampling = [i for i in np.arange(num_examples) if whether[i]]
            global microbatches
            microbatches = len(subsampling)

            train_input_fn = tf_compat_v1_estimator.inputs.numpy_input_fn(
                x={'x': train_data[subsampling]},
                y=train_labels[subsampling],
                batch_size=len(subsampling),
                num_epochs=1,
                shuffle=True)
            # Train the model for one step.
            adult_classifier.train(input_fn=train_input_fn, steps=1)

        # Evaluate the model and print results
        eval_results = adult_classifier.evaluate(input_fn=eval_input_fn)
        test_accuracy = eval_results['accuracy']
        test_accuracy_list.append(test_accuracy)
        print('Test accuracy after %d epochs is: %.3f' %
              (epoch, test_accuracy))

        # Compute the privacy budget expended so far.
        if FLAGS.dpsgd:
            eps = compute_eps_poisson(epoch, FLAGS.noise_multiplier,
                                      num_examples, sampling_batch, 1e-5)
            mu = compute_mu_poisson(epoch, FLAGS.noise_multiplier,
                                    num_examples, sampling_batch)
            print('For delta=1e-5, the current epsilon is: %.2f' % eps)
            print('For delta=1e-5, the current mu is: %.2f' % mu)

            if mu > FLAGS.max_mu:
                break
        else:
            print('Trained with vanilla non-private SGD optimizer')
Example #14
def export_model(working_dir, model_path, model_fn, serving_input):
    """Take the latest checkpoint & export it to path.

  Args:
    working_dir: The directory where tf.estimator keeps its checkpoints.
    model_path: The path to export the model to.
    model_fn: model_fn of model.
    serving_input: function for processing input.
  """
    estimator = tf_estimator.Estimator(model_fn, model_dir=working_dir)
    estimator.export_saved_model(model_path,
                                 serving_input_receiver_fn=serving_input)
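
A serving_input argument can be built with the standard receiver-fn helper; a minimal sketch, with a hypothetical feature name and shape:

serving_input = tf_estimator.export.build_raw_serving_input_receiver_fn(
    {"x": tf.compat.v1.placeholder(tf.float32, [None, 28, 28, 1], name="x")})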
Example #15
def main(argv):
    del argv  # Unused.

    if FLAGS.output_dir is None:
        raise ValueError("`output_dir` must be defined")

    if FLAGS.delete_existing and tf.gfile.Exists(FLAGS.output_dir):
        tf.logging.warn("Deleting old log directory at {}".format(
            FLAGS.output_dir))
        tf.gfile.DeleteRecursively(FLAGS.output_dir)
    tf.gfile.MakeDirs(FLAGS.output_dir)

    print("Logging to {}".format(FLAGS.output_dir))

    # Load the training or test split of the Celeb-A filenames.
    if FLAGS.celeba_dir is None:
        raise ValueError("`celeba_dir` must be defined")
    celeba_dataset_path = os.path.join(FLAGS.celeba_dir,
                                       "Img/img_align_celeba/")
    celeba_partition_path = os.path.join(FLAGS.celeba_dir,
                                         "Eval/list_eval_partition.txt")
    with open(celeba_partition_path, "r") as fid:
        partition = fid.readlines()
    filenames, splits = zip(*[x.split() for x in partition])
    filenames = np.array(
        [os.path.join(celeba_dataset_path, f) for f in filenames])
    splits = np.array([int(x) for x in splits])

    with tf.Graph().as_default():
        train_input_fn = prep_dataset_fn(filenames, splits, is_training=True)
        eval_input_fn = prep_dataset_fn(filenames, splits, is_training=False)
        estimator = tf_estimator.Estimator(
            model_fn,
            config=tf_estimator.RunConfig(
                model_dir=FLAGS.output_dir,
                save_checkpoints_steps=FLAGS.viz_steps,
            ),
        )

        train_spec = tf_estimator.TrainSpec(input_fn=train_input_fn,
                                            max_steps=FLAGS.max_steps)
        # Sad ugly hack here. Setting steps=None should go through all of the
        # validation set, but doesn't seem to, so I'm doing it manually.
        eval_spec = tf_estimator.EvalSpec(
            input_fn=eval_input_fn,
            steps=len(filenames[splits == 1]) // FLAGS.batch_size,
            start_delay_secs=0,
            throttle_secs=0)
        for _ in range(FLAGS.max_steps // FLAGS.viz_steps):
            tf_estimator.train_and_evaluate(estimator, train_spec, eval_spec)
Example #16
def main(_):
    cpu = os.cpu_count()
    tf_config = _tf_config(flags)  #1
    # Distributed training requires the TF_CONFIG environment variable.
    os.environ['TF_CONFIG'] = json.dumps(tf_config)  #2
    session_config = ConfigProto(device_count={'CPU': cpu},
                                 inter_op_parallelism_threads=cpu // 2,
                                 intra_op_parallelism_threads=cpu // 2,
                                 device_filters=flags.device_filters,
                                 allow_soft_placement=True)
    strategy = experimental.ParameterServerStrategy()
    run_config = estimator.RunConfig(
        **{
            'save_summary_steps': 100,
            'save_checkpoints_steps': 1000,
            'keep_checkpoint_max': 10,
            'log_step_count_steps': 100,
            'train_distribute': strategy,
            'eval_distribute': strategy,
        }).replace(session_config=session_config)

    model = estimator.Estimator(
        model_fn=model_fn,
        model_dir='/home/axing/din/checkpoints/din',  # in practice this would be a distributed file system
        config=run_config,
        params={
            'tf_config': tf_config,
            'decay_rate': 0.9,
            'decay_steps': 10000,
            'learning_rate': 0.1
        })

    train_spec = estimator.TrainSpec(
        input_fn=lambda: input_fn(mode='train',
                                  num_workers=flags.num_workers,
                                  worker_index=flags.worker_index,
                                  pattern='/home/axing/din/dataset/*'),  #3
        max_steps=1000  #4
    )

    # Here the eval set path is assumed to be the same as the training set's; in practice they would certainly differ.
    eval_spec = estimator.EvalSpec(
        input_fn=lambda: input_fn(mode='eval',
                                  pattern='/home/axing/din/dataset/*'),
        steps=100,  # evaluate 100 batches of data each run
        throttle_secs=60  # wait at least 60 seconds between evaluations
    )
    estimator.train_and_evaluate(model, train_spec, eval_spec)
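
For reference, the TF_CONFIG built by the (unshown) _tf_config helper follows the standard cluster/task layout; a hypothetical example for worker 0 of a parameter-server cluster:

tf_config = {
    "cluster": {
        "chief": ["host0:2222"],
        "worker": ["host1:2222", "host2:2222"],
        "ps": ["host3:2222"],
    },
    "task": {"type": "worker", "index": 0},
}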
Example #17
    def get_estimator(self):
        """Obtain estimator for the working directory.

    Returns:
      an (TPU/non-TPU) estimator.
    """
        if self._use_tpu:
            return tf_utils.get_tpu_estimator(self._checkpoint_dir,
                                              self._model_fn)

        run_config = tf_estimator.RunConfig(
            save_summary_steps=self._iterations_per_loop,
            save_checkpoints_steps=self._iterations_per_loop,
            keep_checkpoint_max=self._keep_checkpoint_max)
        return tf_estimator.Estimator(self._model_fn,
                                      model_dir=self._checkpoint_dir,
                                      config=run_config)
Example #18
    def test_latest_module_exporter_with_no_modules(self):
        model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
        export_base_dir = os.path.join(
            tempfile.mkdtemp(dir=self.get_temp_dir()), "export")
        self.assertFalse(tf.compat.v1.gfile.Exists(export_base_dir))

        estimator = tf_estimator.Estimator(
            _get_model_fn(register_module=False), model_dir=model_dir)
        estimator.train(input_fn=_input_fn, steps=1)

        exporter = hub.LatestModuleExporter("exporter_name", _serving_input_fn)
        export_dir = exporter.export(estimator=estimator,
                                     export_path=export_base_dir,
                                     eval_result=None,
                                     is_the_final_export=None)

        # Check the result.
        self.assertIsNone(export_dir)

        # Check that no directory has been created in the expected location.
        self.assertFalse(tf.compat.v1.gfile.Exists(export_base_dir))
Example #19
def _make_estimator(hparams, label_vocab, output_dir):
    """Create a tf.estimator.Estimator.

  Args:
    hparams: tf.contrib.training.HParams.
    label_vocab: list of string.
    output_dir: str. Path to save checkpoints.

  Returns:
    tf.estimator.Estimator.
  """
    model_fn = protein_model.make_model_fn(label_vocab=label_vocab,
                                           hparams=hparams)
    run_config = tf_estimator.RunConfig(model_dir=output_dir)

    estimator = tf_estimator.Estimator(
        model_fn=model_fn,
        params=hparams,
        config=run_config,
    )

    return estimator
Example #20
def predict_all_test():
    """Print error statistics for the test dataset."""
    params = {'dummy': 0}
    estimator = tf_estimator.Estimator(model_fn=model_fn,
                                       model_dir=FLAGS.checkpoint_dir,
                                       params=params)
    evals = estimator.predict(input_fn=eval_input_fn,
                              yield_single_examples=False)

    # Print error statistics.
    all_errors = [x['error'] for x in evals]
    errors = np.array(all_errors)
    print('Evaluated %d examples' % np.size(errors))
    print('Mean error: %f degrees' % np.mean(errors))
    print('Median error: %f degrees' % np.median(errors))
    print('Std: %f degrees' % np.std(errors))
    sorted_errors = np.sort(errors)
    n = np.size(sorted_errors)
    print('\nPercentiles:')
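    # The index below selects the perc-th percentile from the sorted errors,
    # roughly np.percentile(errors, perc) with 'lower' interpolation.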
    for perc in range(1, 101):
        index = np.int32(np.float32(n * perc) / 100.0) - 1
        print('%3d%%: %f' % (perc, sorted_errors[index]))
Example #21
def main(_):
    inference_fn = network.inference
    hparams = contrib_training.HParams(learning_rate=FLAGS.learning_rate)
    model_fn = estimator.create_model_fn(inference_fn, hparams)
    config = tf_estimator.RunConfig(model_dir=FLAGS.model_dir)
    est = tf_estimator.Estimator(model_fn=model_fn, config=config)

    train_dataset_fn = dataset.create_dataset_fn(FLAGS.train_pattern,
                                                 height=FLAGS.image_size,
                                                 width=FLAGS.image_size,
                                                 batch_size=FLAGS.batch_size)

    eval_dataset_fn = dataset.create_dataset_fn(FLAGS.test_pattern,
                                                height=FLAGS.image_size,
                                                width=FLAGS.image_size,
                                                batch_size=FLAGS.batch_size)

    train_spec, eval_spec = estimator.create_train_and_eval_specs(
        train_dataset_fn, eval_dataset_fn)

    tf.logging.set_verbosity(tf.logging.INFO)
    tf_estimator.train_and_evaluate(est, train_spec, eval_spec)
Example #22
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    train_mode = FLAGS.train_mode
    ############################################################################
    # Load the dual_encoder_config_file file.
    ############################################################################
    if tf.gfile.Exists(FLAGS.dual_encoder_config_file):
        exp_config = utils.load_config_from_file(
            FLAGS.dual_encoder_config_file,
            experiment_config_pb2.DualEncoderConfig())
    else:
        raise ValueError("dual_encoder_config: {} not found!".format(
            FLAGS.dual_encoder_config_file))
    tf.logging.info(">>>> final dual_encoder_config:\n {}".format(exp_config))
    tf.gfile.MakeDirs(FLAGS.output_dir)

    ############################################################################
    # Save/copy the configuration file.
    ############################################################################
    configs_dir = os.path.join(FLAGS.output_dir, "configs")
    tf.gfile.MakeDirs(configs_dir)
    with tf.gfile.Open(os.path.join(configs_dir, "dual_encoder_config.pbtxt"),
                       "w") as fout:
        print(exp_config, file=fout)

    # Write bert_config.json and doc_bert_config.json.
    tf.gfile.Copy(exp_config.encoder_config.bert_config_file,
                  os.path.join(configs_dir, "bert_config.json"),
                  overwrite=True)
    tf.gfile.Copy(exp_config.encoder_config.doc_bert_config_file,
                  os.path.join(configs_dir, "doc_bert_config.json"),
                  overwrite=True)

    # Write vocab file(s).
    tf.gfile.Copy(exp_config.encoder_config.vocab_file,
                  os.path.join(configs_dir, "vocab.txt"),
                  overwrite=True)

    # Save other important parameters as a json file.
    hparams = {
        "dual_encoder_config_file": FLAGS.dual_encoder_config_file,
        "output_dir": FLAGS.output_dir,
        "schedule": FLAGS.schedule,
        "debugging": FLAGS.debugging,
        "learning_rate": FLAGS.learning_rate,
        "num_warmup_steps": FLAGS.num_warmup_steps,
        "num_train_steps": FLAGS.num_train_steps,
        "num_tpu_cores": FLAGS.num_tpu_cores
    }
    with tf.gfile.Open(os.path.join(configs_dir, "hparams.json"), "w") as fout:
        json.dump(hparams, fout)
        tf.logging.info(">>>> saved hparams.json:\n {}".format(hparams))

    ############################################################################
    # Run the train/eval/predict/export process based on the schedule.
    ############################################################################
    max_seq_length_actual, max_predictions_per_seq_actual = (
        utils.get_actual_max_seq_len(
            exp_config.encoder_config.model_name,
            exp_config.encoder_config.max_doc_length_by_sentence,
            exp_config.encoder_config.max_sent_length_by_word,
            exp_config.encoder_config.max_predictions_per_seq))

    # Prepare input for train and eval.
    input_files = []
    for input_pattern in exp_config.train_eval_config.input_file_for_train.split(
            ","):
        input_files.extend(tf.gfile.Glob(input_pattern))
    input_file_num = 0
    tf.logging.info("*** Input Files ***")
    for input_file in input_files:
        tf.logging.info("  %s" % input_file)
        input_file_num += 1
        if input_file_num > 10:
            break
    tf.logging.info("train input_files[0:10]: %s " %
                    "\n".join(input_files[0:10]))
    eval_files = []
    if exp_config.train_eval_config.eval_with_eval_data:
        eval_files = []
        for input_pattern in exp_config.train_eval_config.input_file_for_eval.split(
                ","):
            eval_files.extend(tf.gfile.Glob(input_pattern))
    else:
        eval_files = input_files

    input_fn_builder = input_fns.input_fn_builder
    # Prepare the input functions.
    # Drop_remainder = True during training to maintain fixed batch size.
    train_input_fn = input_fn_builder(
        input_files=input_files,
        is_training=True,
        drop_remainder=True,
        max_seq_length=max_seq_length_actual,
        max_predictions_per_seq=max_predictions_per_seq_actual,
        num_cpu_threads=4,
        batch_size=exp_config.train_eval_config.train_batch_size,
    )
    eval_drop_remainder = FLAGS.use_tpu
    eval_input_fn = input_fn_builder(
        input_files=eval_files,
        max_seq_length=max_seq_length_actual,
        max_predictions_per_seq=max_predictions_per_seq_actual,
        is_training=False,
        drop_remainder=eval_drop_remainder,
        batch_size=exp_config.train_eval_config.eval_batch_size)
    predict_input_fn = input_fn_builder(
        input_files=eval_files,
        max_seq_length=max_seq_length_actual,
        max_predictions_per_seq=max_predictions_per_seq_actual,
        is_training=False,
        drop_remainder=eval_drop_remainder,
        batch_size=exp_config.train_eval_config.predict_batch_size,
        is_prediction=True)

    # Build and run the model.
    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    is_per_host = tf_estimator.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf_estimator.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=(
            exp_config.train_eval_config.save_checkpoints_steps),
        tpu_config=tf_estimator.tpu.TPUConfig(
            iterations_per_loop=(
                exp_config.train_eval_config.iterations_per_loop),
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=is_per_host))

    model_fn = smith_modeling.model_fn_builder(
        dual_encoder_config=exp_config,
        train_mode=FLAGS.train_mode,
        learning_rate=FLAGS.learning_rate,
        num_train_steps=FLAGS.num_train_steps,
        num_warmup_steps=FLAGS.num_warmup_steps,
        use_tpu=FLAGS.use_tpu,
        use_one_hot_embeddings=FLAGS.use_tpu,
        debugging=FLAGS.debugging)

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU. The batch size for eval and predict is the same.
    estimator = tf_estimator.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=exp_config.train_eval_config.train_batch_size,
        eval_batch_size=exp_config.train_eval_config.eval_batch_size,
        predict_batch_size=exp_config.train_eval_config.predict_batch_size)

    if FLAGS.schedule == "train":
        tf.logging.info("***** Running training *****")
        tf.logging.info("  Batch size = %d",
                        exp_config.train_eval_config.train_batch_size)
        estimator.train(input_fn=train_input_fn,
                        max_steps=FLAGS.num_train_steps)
    elif FLAGS.schedule == "continuous_eval":
        tf.logging.info("***** Running continuous evaluation *****")
        tf.logging.info("  Batch size = %d",
                        exp_config.train_eval_config.eval_batch_size)
        # checkpoints_iterator blocks until a new checkpoint appears.
        for ckpt in tf.train.checkpoints_iterator(estimator.model_dir):
            try:
                # Estimator automatically loads and evaluates the latest checkpoint.
                result = estimator.evaluate(
                    input_fn=eval_input_fn,
                    steps=exp_config.train_eval_config.max_eval_steps)
                tf.logging.info("***** Eval results for %s *****", ckpt)
                for key, value in result.items():
                    tf.logging.info("  %s = %s", key, str(value))

            except tf.errors.NotFoundError:
                # Checkpoint might get garbage collected before the eval can run.
                tf.logging.error("Checkpoint path '%s' no longer exists.",
                                 ckpt)
    elif FLAGS.schedule == "predict":
        # Load the model checkpoint and run the prediction process
        # to get the predicted scores and labels. The batch size is the same with
        # the eval batch size. For more options, refer to
        # https://www.tensorflow.org/api_docs/python/tf/compat/v1/estimator/tpu/TPUEstimator#predict
        tf.logging.info("***** Running prediction with ckpt {} *****".format(
            exp_config.encoder_config.predict_checkpoint))
        tf.logging.info("  Batch size = %d",
                        exp_config.train_eval_config.eval_batch_size)
        output_predict_file = os.path.join(FLAGS.output_dir,
                                           "prediction_results.json")
        # Output the prediction results in json format.
        pred_res_list = []
        with tf.gfile.GFile(output_predict_file, "w") as writer:
            written_line_index = 0
            tf.logging.info("***** Predict results *****")
            for result in estimator.predict(
                    input_fn=predict_input_fn,
                    checkpoint_path=(
                        exp_config.encoder_config.predict_checkpoint),
                    yield_single_examples=True):
                if (exp_config.encoder_config.model_name ==
                        constants.MODEL_NAME_SMITH_DUAL_ENCODER):
                    pred_item_dict = utils.get_pred_res_list_item_smith_de(
                        result)
                else:
                    raise ValueError("Unsupported model name: %s" %
                                     exp_config.encoder_config.model_name)
                pred_res_list.append(pred_item_dict)
                written_line_index += 1
                if written_line_index % 500 == 0:
                    tf.logging.info(
                        "Current written_line_index: {} *****".format(
                            written_line_index))
            tf.logging.info("***** Finished prediction for %d examples *****",
                            written_line_index)
            tf.logging.info("***** Output prediction results into %s *****",
                            output_predict_file)
            json.dump(pred_res_list, writer)

    elif FLAGS.schedule == "export":
        run_config = tf_estimator.RunConfig(
            model_dir=FLAGS.output_dir,
            save_checkpoints_steps=(
                exp_config.train_eval_config.save_checkpoints_steps))
        estimator = tf_estimator.Estimator(model_fn=model_fn,
                                           config=run_config)
        export_dir_base = os.path.join(FLAGS.output_dir, "export/")
        tf.logging.info(
            "***** Export the prediction checkpoint to the folder {} *****".
            format(export_dir_base))
        tf.gfile.MakeDirs(export_dir_base)
        estimator.export_saved_model(
            export_dir_base=export_dir_base,
            assets_extra={"vocab.txt": exp_config.encoder_config.vocab_file},
            serving_input_receiver_fn=input_fns.make_serving_input_example_fn(
                max_seq_length=max_seq_length_actual,
                max_predictions_per_seq=max_predictions_per_seq_actual),
            checkpoint_path=exp_config.encoder_config.predict_checkpoint)
    else:
        raise ValueError("Unsupported schedule : %s" % FLAGS.schedule)
Example #23
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    # Retrieve entries from FewRel input file. We do this before model building
    # so we can determine the number of classes and examples per class to use
    # in model building.
    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)
    processor = fewrel.FewRelProcessor(tokenizer,
                                       FLAGS.max_seq_length,
                                       add_entity_markers=True)
    (predict_examples, fewshot_num_classes_eval,
     fewshot_num_examples_per_class) = processor.process_file(FLAGS.input)

    # Build model.
    model_fn = model_fn_builder(
        bert_config=bert_config,
        use_one_hot_embeddings=False,
        fewshot_num_examples_per_class=fewshot_num_examples_per_class,
        tokenizer=tokenizer,
        class_examples_combiner=FLAGS.fewshot_examples_combiner)
    estimator = tf_estimator.Estimator(
        model_fn=model_fn, params={"batch_size": FLAGS.predict_batch_size})

    # Convert examples into tensorflow examples, and store to a file.
    temp_dir = tempfile.mkdtemp()
    predict_file = os.path.join(temp_dir, "predict.tf_record")
    file_based_convert_examples_to_features(predict_examples,
                                            FLAGS.max_seq_length, tokenizer,
                                            predict_file)

    input_fn = file_based_input_fn_builder(
        input_file=predict_file,
        seq_length=FLAGS.max_seq_length,
        fewshot_num_classes=fewshot_num_classes_eval,
        fewshot_num_examples_per_class=fewshot_num_examples_per_class,
        drop_remainder=False)

    tf.logging.info("***** Running evaluation *****")
    tf.logging.info("  Num examples = %d", len(predict_examples))
    tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

    # Perform predictions.
    predictions = []
    for item in estimator.predict(input_fn=input_fn,
                                  checkpoint_path=FLAGS.checkpoint):
        tf.logging.info("%s\t%s", item["guid"], item["predictions"])
        predictions.append(int(item["predictions"]))

    # Dump predictions to output file.
    output_predictions_file = FLAGS.output
    with tf.gfile.GFile(output_predictions_file, "w") as writer:
        json.dump(predictions, writer)
        writer.write("\n")
Example #24
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    if FLAGS.model == "seq2seq":
        assert FLAGS.rnn_cell == "lstm"
        assert FLAGS.att_type != "hyper"
    if FLAGS.model == "hypernet" and FLAGS.rank != FLAGS.decoder_dim:
        print("WARNING: recommended rank value: decoder_dim.")
    if FLAGS.att_neighbor:
        assert FLAGS.neighbor_dim == FLAGS.encoder_dim or FLAGS.att_type == "my"

    if FLAGS.use_copy or FLAGS.att_neighbor:
        assert FLAGS.att_type == "my"
    # These numbers are the target vocabulary sizes of the datasets.
    # It allows for using different vocabularies for source and targets,
    # following the implementation in Open-NMT.
    # I will later put these into command line arguments.
    if FLAGS.use_bpe:
        if FLAGS.dataset == "nyt":
            output_size = 10013
        elif FLAGS.dataset == "giga":
            output_size = 24654
        elif FLAGS.dataset == "cnnd":
            output_size = 10232
        else:
            raise ValueError("Unknown dataset: %s" % FLAGS.dataset)
    else:
        if FLAGS.dataset == "nyt":
            output_size = 68885
        elif FLAGS.dataset == "giga":
            output_size = 107389
        elif FLAGS.dataset == "cnnd":
            output_size = 21000
        else:
            raise ValueError("Unknown dataset: %s" % FLAGS.dataset)

    vocab = data.Vocab(FLAGS.vocab_path, FLAGS.vocab_size, FLAGS.dataset)
    hps = tf.contrib.training.HParams(
        sample_neighbor=FLAGS.sample_neighbor,
        use_cluster=FLAGS.use_cluster,
        binary_neighbor=FLAGS.binary_neighbor,
        att_neighbor=FLAGS.att_neighbor,
        encode_neighbor=FLAGS.encode_neighbor,
        sum_neighbor=FLAGS.sum_neighbor,
        dataset=FLAGS.dataset,
        rnn_cell=FLAGS.rnn_cell,
        output_size=output_size + vocab.offset,
        train_path=FLAGS.train_path,
        dev_path=FLAGS.dev_path,
        tie_embedding=FLAGS.tie_embedding,
        use_bpe=FLAGS.use_bpe,
        use_copy=FLAGS.use_copy,
        reuse_attention=FLAGS.reuse_attention,
        use_bridge=FLAGS.use_bridge,
        use_residual=FLAGS.use_residual,
        att_type=FLAGS.att_type,
        random_neighbor=FLAGS.random_neighbor,
        num_neighbors=FLAGS.num_neighbors,
        model=FLAGS.model,
        trainer=FLAGS.trainer,
        learning_rate=FLAGS.learning_rate,
        lr_schedule=FLAGS.lr_schedule,
        total_steps=FLAGS.total_steps,
        emb_dim=FLAGS.emb_dim,
        binary_dim=FLAGS.binary_dim,
        neighbor_dim=FLAGS.neighbor_dim,
        drop=FLAGS.drop,
        emb_drop=FLAGS.emb_drop,
        out_drop=FLAGS.out_drop,
        encoder_drop=FLAGS.encoder_drop,
        decoder_drop=FLAGS.decoder_drop,
        weight_decay=FLAGS.weight_decay,
        encoder_dim=FLAGS.encoder_dim,
        num_encoder_layers=FLAGS.num_encoder_layers,
        decoder_dim=FLAGS.decoder_dim,
        num_decoder_layers=FLAGS.num_decoder_layers,
        num_mlp_layers=FLAGS.num_mlp_layers,
        rank=FLAGS.rank,
        sigma_norm=FLAGS.sigma_norm,
        batch_size=FLAGS.batch_size,
        sampling_probability=FLAGS.sampling_probability,
        beam_width=FLAGS.beam_width,
        max_enc_steps=FLAGS.max_enc_steps,
        max_dec_steps=FLAGS.max_dec_steps,
        vocab_size=FLAGS.vocab_size,
        max_grad_norm=FLAGS.max_grad_norm,
        length_norm=FLAGS.length_norm,
        cp=FLAGS.coverage_penalty,
        predict_mode=FLAGS.predict_mode)

    run_config = tf_estimator.RunConfig(model_dir=FLAGS.model_dir)

    eval_input_fn = partial(data.input_function,
                            is_train=False,
                            vocab=vocab,
                            hps=hps)

    estimator = tf_estimator.Estimator(
        model_fn=partial(model_function.model_function, vocab=vocab, hps=hps),
        config=run_config,
        model_dir=run_config.model_dir)
    results = estimator.predict(input_fn=eval_input_fn)

    with tf.gfile.Open("%s/prediction" % FLAGS.model_dir, "w") as fout:
        for result in results:
            outputs, _ = result["outputs"], result["lengths"]
            prediction = data.id2text(outputs, vocab, use_bpe=FLAGS.use_bpe)
            fout.write(prediction + "\n")
Example #25
def main(argv):
    del argv  # Unused.

    if FLAGS.squared_value:
        is_squared = 'squared'
    else:
        is_squared = 'not_squared'
    if FLAGS.keep_information:
        info_keep = 'keep'
    else:
        info_keep = 'remove'

    if FLAGS.dataset_name == 'food_101':
        params = food_101_params
    elif FLAGS.dataset_name == 'imagenet':
        params = imagenet_params
    elif FLAGS.dataset_name == 'birdsnap':
        params = birdsnap_params
    else:
        raise ValueError('Unknown dataset type: %s' % FLAGS.dataset_name)

    if FLAGS.test_small_sample:
        model_dir = '/tmp/lalala/'
    else:
        model_dir = os.path.join(FLAGS.output_dir, FLAGS.dataset_name,
                                 FLAGS.transformation, str(FLAGS.threshold),
                                 str(params['base_learning_rate']),
                                 str(params['weight_decay']), is_squared,
                                 info_keep)

        if FLAGS.transformation in ['modified_image', 'raw_saliency_map']:
            model_dir = os.path.join(model_dir, FLAGS.saliency_method)

    if FLAGS.mode == 'eval':
        split = 'validation'
    else:
        split = 'training'

    mean_stats = [0.485, 0.456, 0.406]
    std_stats = [0.229, 0.224, 0.225]
    update_params = {
        'mean_rgb': mean_stats,
        'stddev_rgb': std_stats,
        'lr_schedule': [  # (multiplier, epoch to start) tuples
            (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
        ],
        'momentum': 0.9,
        'data_format': 'channels_last'
    }
    params.update(update_params)
    sal_method = saliency_dict[FLAGS.saliency_method]
    if FLAGS.test_small_sample:
        update_params = {
            'train_batch_size': 2,
            'eval_batch_size': 2,
            'num_train_steps': 10,
            'num_images': 2
        }
        params.update(update_params)

    data_directory = os.path.join(FLAGS.base_dir, FLAGS.dataset_name,
                                  '2018-12-10', 'resnet_50', sal_method,
                                  split + '*')

    dataset_ = data_input.DataIterator(
        mode=FLAGS.mode,
        data_directory=data_directory,
        saliency_method=FLAGS.saliency_method,
        transformation=FLAGS.transformation,
        threshold=FLAGS.threshold,
        keep_information=FLAGS.keep_information,
        use_squared_value=FLAGS.squared_value,
        mean_stats=mean_stats,
        std_stats=std_stats,
        test_small_sample=FLAGS.test_small_sample,
        num_cores=FLAGS.num_cores)

    params['output_dir'] = model_dir
    if FLAGS.mode == 'train':
        params['batch_size'] = params['train_batch_size']
    else:
        params['batch_size'] = params['eval_batch_size']

    num_train_steps = params['num_train_steps']
    eval_steps = params['num_eval_images'] // params['batch_size']

    run_config = tf_estimator.RunConfig(
        model_dir=model_dir, save_checkpoints_steps=FLAGS.steps_per_checkpoint)

    classifier = tf_estimator.Estimator(model_fn=resnet_model_fn,
                                        model_dir=model_dir,
                                        params=params,
                                        config=run_config)

    if FLAGS.mode == 'eval':
        # Run evaluation when there's a new checkpoint
        for ckpt in tf.train.checkpoints_iterator(model_dir):
            tf.logging.info('Starting to evaluate.')
            try:
                classifier.evaluate(input_fn=dataset_.input_fn,
                                    steps=eval_steps,
                                    checkpoint_path=ckpt)
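                # Checkpoint basenames look like "model.ckpt-1234"; the
                # numeric suffix is the global step.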
                current_step = int(os.path.basename(ckpt).split('-')[1])
                if current_step >= num_train_steps:
                    print('Evaluation finished after training step %d' %
                          current_step)
                    break

            except tf.errors.NotFoundError:
                tf.logging.info(
                    'Checkpoint was not found, skipping checkpoint.')

    else:
        if FLAGS.mode == 'train':
            tf.logging.info('start training...')
            classifier.train(input_fn=dataset_.input_fn,
                             max_steps=num_train_steps)
            tf.logging.info('finished training.')
Example #26
def create_estimator(model_name,
                     hparams,
                     run_config,
                     schedule="train_and_evaluate",
                     decode_hparams=None,
                     use_tpu=False,
                     use_tpu_estimator=False,
                     use_xla=False,
                     export_saved_model_api_version=1,
                     use_guarantee_const_getter=False):
    """Create a T2T Estimator."""
    model_fn = t2t_model.T2TModel.make_estimator_model_fn(
        model_name, hparams, decode_hparams=decode_hparams, use_tpu=use_tpu)

    del use_xla
    if use_tpu or use_tpu_estimator:
        from tensorflow.contrib.tpu.python.tpu import tpu_estimator  # pylint: disable=g-import-not-at-top
        problem = hparams.problem
        batch_size = (problem.tpu_batch_size_per_shard(hparams) *
                      run_config.tpu_config.num_shards)
        mlperf_log.transformer_print(key=mlperf_log.INPUT_BATCH_SIZE,
                                     value=batch_size)
        if getattr(hparams, "mtf_mode", False):
            batch_size = problem.tpu_batch_size_per_shard(hparams)
        predict_batch_size = batch_size
        if decode_hparams and decode_hparams.batch_size:
            predict_batch_size = decode_hparams.batch_size
        if decode_hparams and run_config.tpu_config:
            decode_hparams.add_hparam(
                "iterations_per_loop",
                run_config.tpu_config.iterations_per_loop)
        if export_saved_model_api_version == 1:
            api_version_enum_name = tpu_estimator.ExportSavedModelApiVersion.V1
            estimator_model_fn = model_fn
        elif export_saved_model_api_version == 2:
            api_version_enum_name = tpu_estimator.ExportSavedModelApiVersion.V2

            def maybe_use_guarantee_const_getter_model_fn(
                    features, labels, mode, params):
                """Wrapper model_fn with guarantee_const getter."""
                if not use_guarantee_const_getter:
                    return model_fn(features, labels, mode, params)

                # It marks all weights as constant, which may improve TPU
                # inference performance because it prevents the weights from
                # being transferred to the TPU. It will increase HBM "program"
                # usage and reduce HBM "arguments" usage during TPU model
                # serving.
                def guarantee_const_getter(getter, name, *args, **kwargs):
                    with tf.control_dependencies(None):
                        return tf.guarantee_const(
                            getter(name, *args, **kwargs),
                            name=name + "/GuaranteeConst")

                @contextlib.contextmanager
                def guarantee_const_scope():
                    var_scope = tf.get_variable_scope()
                    prev_custom_getter = var_scope.custom_getter
                    prev_caching_device = var_scope.caching_device
                    var_scope.set_custom_getter(guarantee_const_getter)
                    var_scope.set_caching_device(lambda op: op.device)
                    yield
                    var_scope.set_custom_getter(prev_custom_getter)
                    var_scope.set_caching_device(prev_caching_device)

                with guarantee_const_scope():
                    return model_fn(features, labels, mode, params)

            def tpu_model_fn(features, labels, mode, params):
                """Wrapper model_fn with tpu.rewrite / TPUPartitionedCall."""
                if mode == tf_estimator.ModeKeys.PREDICT and params["use_tpu"]:
                    batch_config = tpu_estimator.BatchConfig(
                        num_batch_threads=2,
                        max_batch_size=predict_batch_size,
                        batch_timeout_micros=60 * 1000,
                        allowed_batch_sizes=[predict_batch_size])
                    return tpu_estimator.model_fn_inference_on_tpu(
                        maybe_use_guarantee_const_getter_model_fn,
                        features=features,
                        labels=labels,
                        config=None,
                        params=params,
                        batch_config=batch_config)
                else:
                    return model_fn(features, labels, mode, params)

            estimator_model_fn = tpu_model_fn
        else:
            raise ValueError(
                "Flag export_saved_model_api_version must be 1 or 2.")
        estimator = contrib.tpu().TPUEstimator(
            model_fn=estimator_model_fn,
            model_dir=run_config.model_dir,
            config=run_config,
            use_tpu=use_tpu,
            train_batch_size=batch_size,
            eval_batch_size=batch_size if "eval" in schedule else None,
            predict_batch_size=predict_batch_size,
            export_saved_model_api_version=api_version_enum_name)
    else:
        estimator = tf_estimator.Estimator(
            model_fn=model_fn,
            model_dir=run_config.model_dir,
            config=run_config,
        )
    return estimator
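For context, a minimal call site for the CPU/GPU path might look like the sketch below. It assumes the function above is named create_estimator and that its elided leading parameters are model_name, hparams, and run_config, as the body implies; hparams and train_input_fn are hypothetical stand-ins built elsewhere (e.g. with the tensor2tensor trainer utilities).
run_config = tf_estimator.RunConfig(model_dir="/tmp/t2t_model")
estimator = create_estimator(
    "transformer",  # model_name, assumed registered with tensor2tensor
    hparams,        # assumed built elsewhere
    run_config,
    schedule="train_and_evaluate",
    use_tpu=False)
estimator.train(input_fn=train_input_fn, max_steps=1000)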
Example #27
def run_model():
    """Run experiment with tf.estimator.

  """
    params = {
        'kb_index': FLAGS.kb_index,
        'cm_width': FLAGS.cm_width,
        'cm_depth': FLAGS.cm_depth,
        'entity_emb_size': FLAGS.entity_emb_size,
        'relation_emb_size': FLAGS.relation_emb_size,
        'vocab_emb_size': FLAGS.vocab_emb_size,
        'max_set': FLAGS.max_set,
        'learning_rate': FLAGS.learning_rate,
        'gradient_clip': FLAGS.gradient_clip,
        'intermediate_top_k': FLAGS.intermediate_top_k,
        'use_cm_sketch': FLAGS.use_cm_sketch,
        'train_entity_emb': FLAGS.train_entity_emb,
        'train_relation_emb': FLAGS.train_relation_emb,
        'bert_handle': FLAGS.bert_handle,
        'train_bert': FLAGS.train_bert,
    }

    data_loader = DataLoader(params, FLAGS.name, get_root_dir(FLAGS.name),
                             FLAGS.kb_file, FLAGS.vocab_file)

    estimator_config = tf_estimator.RunConfig(
        save_checkpoints_steps=FLAGS.checkpoint_step)

    warm_start_settings = tf_estimator.WarmStartSettings(  # pylint: disable=g-long-ternary
        ckpt_to_initialize_from=FLAGS.load_model_dir,
        vars_to_warm_start=[
            'embeddings_mat/entity_embeddings_mat',
            'embeddings_mat/relation_embeddings_mat'
        ],
    ) if FLAGS.load_model_dir is not None else None

    estimator = tf_estimator.Estimator(
        model_fn=build_model_fn(FLAGS.name, data_loader, FLAGS.eval_name,
                                FLAGS.eval_metric_at_k),
        model_dir=FLAGS.checkpoint_dir + FLAGS.model_name,
        config=estimator_config,
        params=params,
        warm_start_from=warm_start_settings)

    eval_input_fn = data_loader.build_input_fn(name=FLAGS.name,
                                               batch_size=FLAGS.batch_size,
                                               mode='eval',
                                               epochs=1,
                                               n_take=FLAGS.num_eval,
                                               shuffle=False)

    # Define mode-specific operations
    if FLAGS.mode == 'train':
        train_input_fn = data_loader.build_input_fn(
            name=FLAGS.name,
            batch_size=FLAGS.batch_size,
            mode='train',
            epochs=FLAGS.epochs,
            n_take=-1,
            shuffle=True)
        train_spec = tf_estimator.TrainSpec(input_fn=train_input_fn)
        # Busy-wait for evaluation until a new checkpoint comes out.
        test_spec = tf_estimator.EvalSpec(input_fn=eval_input_fn,
                                          steps=FLAGS.num_online_eval,
                                          start_delay_secs=0,
                                          throttle_secs=FLAGS.eval_time)
        tf_estimator.train_and_evaluate(estimator, train_spec, test_spec)

    elif FLAGS.mode == 'eval':
        tf_evaluation = estimator.evaluate(eval_input_fn)
        print(tf_evaluation)

    elif FLAGS.mode == 'pred':
        tf_predictions = estimator.predict(eval_input_fn)

        if FLAGS.name.startswith('query2box'):
            task = FLAGS.name.split('_')[-1]
            metrics = Query2BoxMetrics(task, FLAGS.root_dir, data_loader)
        else:
            raise NotImplementedError()

        for tf_prediction in tqdm(tf_predictions):
            metrics.eval(tf_prediction)
        metrics.print_metrics()

    else:
        raise ValueError('mode not recognized: %s' % FLAGS.mode)
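One detail worth noting in the example above: vars_to_warm_start accepts either a list of variable names (as here) or a single regular expression. A minimal sketch of both forms follows; the checkpoint path, variable names, and my_model_fn are illustrative stand-ins, not part of the original example.
# Warm start only the two embedding tables, by name:
ws_by_name = tf_estimator.WarmStartSettings(
    ckpt_to_initialize_from='/tmp/old_model/model.ckpt-1000',
    vars_to_warm_start=['embeddings_mat/entity_embeddings_mat',
                        'embeddings_mat/relation_embeddings_mat'])

# Or warm start everything except variables matching a pattern:
ws_by_regex = tf_estimator.WarmStartSettings(
    ckpt_to_initialize_from='/tmp/old_model/model.ckpt-1000',
    vars_to_warm_start='^((?!output_layer).)*$')

estimator = tf_estimator.Estimator(model_fn=my_model_fn,
                                   model_dir='/tmp/new_model',
                                   warm_start_from=ws_by_name)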
def run_experiment(model_fn,
                   train_input_fn,
                   eval_input_fn,
                   exporters=None,
                   params=None,
                   params_fname=None):
  """Run an experiment using estimators.

  This is a light wrapper around typical estimator usage to avoid boilerplate
  code. For more complex usage, use the underlying components directly.

  Args:
    model_fn: A model function to be passed to the estimator. See
      https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator#args_1
    train_input_fn: An input function to be passed to the estimator that
      corresponds to the training data. See
      https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator#train
    eval_input_fn: An input function to be passed to the estimator that
      corresponds to the held-out eval data. See
      https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator#evaluate
    exporters: (Optional) A tf.estimator.Exporter or a list of them.
    params: (Optional) A dictionary of parameters that will be accessible by the
      model_fn and input_fns. The 'batch_size' and 'use_tpu' values will be set
      automatically.
    params_fname: (Optional) If specified, `params` will be written in JSON
      format to this filename under `FLAGS.model_dir`.
  """
  params = params if params is not None else {}
  params.setdefault("use_tpu", FLAGS.use_tpu)

  if FLAGS.model_dir and params_fname:
    tf.io.gfile.makedirs(FLAGS.model_dir)
    params_path = os.path.join(FLAGS.model_dir, params_fname)
    with tf.io.gfile.GFile(params_path, "w") as params_file:
      json.dump(params, params_file, indent=2, sort_keys=True)

  if params["use_tpu"]:
    if FLAGS.tpu_name:
      tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
          FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
    else:
      tpu_cluster_resolver = None
    run_config = tf_estimator.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.model_dir,
        tf_random_seed=FLAGS.tf_random_seed,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        tpu_config=tf_estimator.tpu.TPUConfig(
            iterations_per_loop=FLAGS.save_checkpoints_steps))
    if "batch_size" in params:
      # Let the TPUEstimator fill in the batch size.
      params.pop("batch_size")
    estimator = tf_estimator.tpu.TPUEstimator(
        use_tpu=True,
        model_fn=model_fn,
        params=params,
        config=run_config,
        train_batch_size=FLAGS.batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        predict_batch_size=FLAGS.eval_batch_size)
  else:
    run_config = tf_estimator.RunConfig(
        model_dir=FLAGS.model_dir,
        tf_random_seed=FLAGS.tf_random_seed,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        keep_checkpoint_max=FLAGS.keep_checkpoint_max)
    params["batch_size"] = FLAGS.batch_size
    estimator = tf_estimator.Estimator(
        config=run_config,
        model_fn=model_fn,
        params=params,
        model_dir=FLAGS.model_dir)

  train_spec = tf_estimator.TrainSpec(
      input_fn=train_input_fn,
      max_steps=FLAGS.num_train_steps)
  eval_spec = tf_estimator.EvalSpec(
      name="default",
      input_fn=eval_input_fn,
      exporters=exporters,
      start_delay_secs=FLAGS.eval_start_delay_secs,
      throttle_secs=FLAGS.eval_throttle_secs,
      steps=FLAGS.num_eval_steps)

  tf.logging.set_verbosity(tf.logging.INFO)
  tf_estimator.train_and_evaluate(
      estimator=estimator,
      train_spec=train_spec,
      eval_spec=eval_spec)
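A hypothetical call site for run_experiment; my_model_fn and the two input functions are stand-ins assumed to be defined elsewhere.
run_experiment(
    model_fn=my_model_fn,
    train_input_fn=my_train_input_fn,
    eval_input_fn=my_eval_input_fn,
    params={"learning_rate": 1e-4},  # "use_tpu" and "batch_size" are
                                     # filled in automatically
    params_fname="params.json")      # written under FLAGS.model_dir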
def main(_):

    # Modify the paths to save results when tuning hyperparameters.
    if FLAGS.node_encoder == "lstm":
        FLAGS.result_path = os.path.join(FLAGS.result_path,
                                         str(FLAGS.node_lstm_size))
    if FLAGS.node_encoder == "transformer":
        FLAGS.result_path = os.path.join(
            FLAGS.result_path, "max_steps_" + str(FLAGS.max_steps_no_increase))
        FLAGS.result_path = os.path.join(
            FLAGS.result_path,
            "hidden_unit_" + str(FLAGS.transformer_hidden_unit))
    if FLAGS.cross_vertical:
        FLAGS.result_path = os.path.join(
            FLAGS.result_path,
            "CKP-{0}/{1}/".format(FLAGS.checkpoint_vertical,
                                  FLAGS.checkpoint_websites))
        FLAGS.checkpoint_path = os.path.join(
            FLAGS.checkpoint_path,
            "{0}/{1}-results/".format(FLAGS.checkpoint_vertical,
                                      FLAGS.checkpoint_websites))

    tf.gfile.MakeDirs(
        os.path.join(
            FLAGS.result_path,
            "{0}/{1}-results/".format(FLAGS.vertical, FLAGS.source_website)))
    tf.logging.set_verbosity(tf.logging.INFO)

    if FLAGS.use_uniform_embedding:
        vocab_vertical = "all"
    else:
        vocab_vertical = FLAGS.vertical

    # Hyper-parameters.
    params = {
        "add_goldmine": FLAGS.add_goldmine,
        "add_leaf_types": FLAGS.add_leaf_types,
        "batch_size": FLAGS.batch_size,
        "buffer": 1000,  # Shuffle buffer size; not worth tuning.
        "chars": os.path.join(FLAGS.domtree_data_path,
                              "%s.vocab.chars.txt" % vocab_vertical),
        "circle_features": FLAGS.circle_features,
        "dim_word_embedding": FLAGS.dim_word_embedding,
        "dim_chars": FLAGS.dim_chars,
        "dim_label_embedding": FLAGS.dim_label_embedding,
        "dim_goldmine": 30,
        "dim_leaf_type": 20,
        "dim_positions": 30,
        "dim_xpath_units": FLAGS.dim_xpath_units,
        "dropout": 0.3,
        "epochs": FLAGS.epochs,
        "extract_node_emb": FLAGS.extract_node_emb,
        "filters": 50,  # Dimension of char-level word representations.
        "friend_encoder": FLAGS.friend_encoder,
        "use_friend_semantic": FLAGS.use_friend_semantic,
        "goldmine_features": os.path.join(FLAGS.domtree_data_path,
                                          "vocab.goldmine_features.txt"),
        "glove": os.path.join(
            FLAGS.domtree_data_path,
            "%s.%d.emb.npz" % (vocab_vertical, FLAGS.dim_word_embedding)),
        "friend_hidden_size": FLAGS.friend_hidden_size,
        "kernel_size": 3,  # CNN window size to embed char sequences.
        "last_hidden_layer_size": FLAGS.last_hidden_layer_size,
        "leaf_types": os.path.join(FLAGS.domtree_data_path,
                                   "%s.vocab.leaf_types.txt" % vocab_vertical),
        "lstm_size": 100,
        "max_steps_no_increase": FLAGS.max_steps_no_increase,
        "node_encoder": FLAGS.node_encoder,
        "node_filters": 100,
        "node_kernel_size": 5,
        "node_lstm_size": FLAGS.node_lstm_size,
        "num_oov_buckets": 1,
        "objective": FLAGS.objective,
        "positions": os.path.join(FLAGS.domtree_data_path,
                                  "vocab.positions.txt"),
        "running_mode": FLAGS.run,
        "semantic_encoder": FLAGS.semantic_encoder,
        "source_website": FLAGS.source_website,
        "tags": os.path.join(FLAGS.domtree_data_path,
                             "%s.vocab.tags.txt" % FLAGS.vertical),
        "tags-all": os.path.join(FLAGS.domtree_data_path,
                                 "all.vocab.tags.txt"),
        "target_website": FLAGS.target_website,
        "transformer_hidden_unit": FLAGS.transformer_hidden_unit,
        "transformer_head": FLAGS.transformer_head,
        "transformer_hidden_layer": FLAGS.transformer_hidden_layer,
        "use_crf": FLAGS.use_crf,
        "use_friends_cnn": FLAGS.use_friends_cnn,
        "use_friends_discrete_feature": FLAGS.use_friends_discrete_feature,
        "use_prev_text_lstm": FLAGS.use_prev_text_lstm,
        "use_xpath_lstm": FLAGS.use_xpath_lstm,
        "use_uniform_label": FLAGS.use_uniform_label,
        "use_position_embedding": FLAGS.use_position_embedding,
        "words": os.path.join(FLAGS.domtree_data_path,
                              "%s.vocab.words.txt" % vocab_vertical),
        "xpath_lstm_size": 100,
        "xpath_units": os.path.join(
            FLAGS.domtree_data_path,
            "%s.vocab.xpath_units.txt" % vocab_vertical),
    }
    with tf.gfile.Open(
            os.path.join(
                FLAGS.result_path,
                "{0}/{1}-results/params.json".format(FLAGS.vertical,
                                                     FLAGS.source_website)),
            "w") as f:
        json.dump(params, f, indent=4, sort_keys=True)
    # Build estimator, train and evaluate.
    train_input_function = functools.partial(
        model_util.joint_input_fn,
        get_data_path(vertical=FLAGS.vertical,
                      website=FLAGS.source_website,
                      dev=False,
                      goldmine=False),
        get_data_path(vertical=FLAGS.vertical,
                      website=FLAGS.source_website,
                      dev=False,
                      goldmine=True),
        FLAGS.vertical,
        params,
        shuffle_and_repeat=True,
        mode="train")

    cfg = tf_estimator.RunConfig(save_checkpoints_steps=300,
                                 save_summary_steps=300,
                                 tf_random_seed=42)
    # Set up the checkpoint to load.
    if FLAGS.checkpoint_path:
        # The best model was always saved as "ckpt-601".
        checkpoint_file = FLAGS.checkpoint_path + "/model/model.ckpt-601"
        # Do not load parameters whose names contain "label_dense"; these
        # ought to be learned from scratch.
        ws = tf_estimator.WarmStartSettings(
            ckpt_to_initialize_from=checkpoint_file,
            vars_to_warm_start="^((?!label_dense).)*$")
        estimator = tf_estimator.Estimator(models.joint_extraction_model_fn,
                                           os.path.join(
                                               FLAGS.result_path,
                                               "{0}/{1}-results/model".format(
                                                   FLAGS.vertical,
                                                   FLAGS.source_website)),
                                           cfg,
                                           params,
                                           warm_start_from=ws)
    else:
        estimator = tf_estimator.Estimator(
            models.joint_extraction_model_fn,
            os.path.join(
                FLAGS.result_path,
                "{0}/{1}-results/model".format(FLAGS.vertical,
                                               FLAGS.source_website)), cfg,
            params)

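    # The early-stopping hook below reads eval metrics from the estimator's
    # eval directory, so create it up front before the hook first runs.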
    tf.gfile.MakeDirs(estimator.eval_dir())

    hook = early_stopping.stop_if_no_increase_hook(
        estimator,
        metric_name="f1",
        max_steps_without_increase=FLAGS.max_steps_no_increase,
        min_steps=300,
        run_every_steps=100,
        run_every_secs=None)
    train_spec = tf_estimator.TrainSpec(input_fn=train_input_function,
                                        hooks=[hook])

    if FLAGS.run == "train":
        eval_input_function = functools.partial(
            model_util.joint_input_fn,
            get_data_path(vertical=FLAGS.vertical,
                          website=FLAGS.source_website,
                          dev=True,
                          goldmine=False),
            get_data_path(vertical=FLAGS.vertical,
                          website=FLAGS.source_website,
                          dev=True,
                          goldmine=True),
            FLAGS.vertical,
            mode="all")
        eval_spec = tf_estimator.EvalSpec(input_fn=eval_input_function,
                                          steps=300,
                                          throttle_secs=1)
        tf_estimator.train_and_evaluate(estimator, train_spec, eval_spec)

    target_websites = FLAGS.target_website.split("_")
    if FLAGS.source_website not in target_websites:
        target_websites = [FLAGS.source_website] + target_websites
    for target_website in target_websites:
        write_predictions(estimator=estimator,
                          vertical=FLAGS.vertical,
                          source_website=FLAGS.source_website,
                          target_website=target_website)
        model_util.page_hits_level_metric(result_path=FLAGS.result_path,
                                          vertical=FLAGS.vertical,
                                          source_website=FLAGS.source_website,
                                          target_website=target_website)
        model_util.site_level_voting(result_path=FLAGS.result_path,
                                     vertical=FLAGS.vertical,
                                     source_website=FLAGS.source_website,
                                     target_website=target_website)
        model_util.page_level_constraint(
            domtree_data_path=FLAGS.domtree_data_path,
            result_path=FLAGS.result_path,
            vertical=FLAGS.vertical,
            source_website=FLAGS.source_website,
            target_website=target_website)
Example #30
def model_to_estimator(model,
                       model_dir=None,
                       config=None,
                       custom_objects=None,
                       weights_feature_name=None,
                       warm_start_from=None,
                       serving_default="regress"):
    """Keras ranking model to Estimator.

  This function is based on the custom model_fn example in the TF 2.0
  migration guide:
  https://www.tensorflow.org/guide/migrate#custom_model_fn_with_tf_20_symbols

  Args:
    model: (tf.keras.Model) A ranking Keras model, which can be created using
      `tfr.keras.model.create_keras_model`. Masking is handled inside this
      function.
    model_dir: (str) Directory to save `Estimator` model graph and checkpoints.
    config: (tf.estimator.RunConfig) Specified config for distributed training
      and checkpointing.
    custom_objects: (dict) mapping names (strings) to custom objects (classes
      and functions) to be considered during deserialization.
    weights_feature_name: (str) A string specifying the name of the per-example
      (of shape [batch_size, list_size]) or per-list (of shape [batch_size, 1])
      weights feature in the `features` dict.
    warm_start_from: (`tf.estimator.WarmStartSettings`) settings to warm-start
      the `tf.estimator.Estimator`.
    serving_default: (str) Specifies "regress" or "predict" as the
      serving_default signature.

  Returns:
    (tf.estimator.Estimator) A ranking estimator.

  Raises:
    ValueError: if weights_feature_name is not in features.
  """
    def _clone_fn(obj):
        """Clone keras object."""
        fn_args = function_utils.fn_args(obj.__class__.from_config)

        if "custom_objects" in fn_args:
            return obj.__class__.from_config(obj.get_config(),
                                             custom_objects=custom_objects)

        return obj.__class__.from_config(obj.get_config())

    def _model_fn(features, labels, mode, params, config):
        """Defines an `Estimator` `model_fn`."""
        del config, params

        # In Estimator, all sub-graphs need to be constructed inside the model_fn.
        # Hence, ranker, losses, metrics and optimizer are cloned inside this
        # function.
        ranker = tf.keras.models.clone_model(model, clone_function=_clone_fn)
        training = (mode == tf_estimator.ModeKeys.TRAIN)

        weights = None
        if weights_feature_name and mode != tf_estimator.ModeKeys.PREDICT:
            if weights_feature_name not in features:
                raise ValueError(
                    "weights_feature '{0}' can not be found in 'features'.".
                    format(weights_feature_name))
            else:
                weights = utils.reshape_to_2d(
                    features.pop(weights_feature_name))

        logits = ranker(features, training=training)

        if serving_default not in ["regress", "predict"]:
            raise ValueError(
                "serving_default should be 'regress' or 'predict', "
                "but got {}".format(serving_default))

        if serving_default == "regress":
            default_export_output = tf_estimator.export.RegressionOutput(
                logits)
        else:
            default_export_output = tf_estimator.export.PredictOutput(logits)
        export_outputs = {
            _DEFAULT_SERVING_KEY: default_export_output,
            _REGRESS_SERVING_KEY: tf_estimator.export.RegressionOutput(logits),
            _PREDICT_SERVING_KEY: tf_estimator.export.PredictOutput(logits)
        }

        if mode == tf_estimator.ModeKeys.PREDICT:
            return tf_estimator.EstimatorSpec(mode=mode,
                                              predictions=logits,
                                              export_outputs=export_outputs)

        loss = _clone_fn(model.loss)
        total_loss = loss(labels, logits, sample_weight=weights)

        keras_metrics = []
        for metric in model.metrics:
            keras_metrics.append(_clone_fn(metric))
        # Adding default metrics here as model.metrics does not contain custom
        # metrics.
        keras_metrics += metrics.default_keras_metrics()
        eval_metric_ops = {}
        for keras_metric in keras_metrics:
            keras_metric.update_state(labels, logits, sample_weight=weights)
            eval_metric_ops[keras_metric.name] = keras_metric

        train_op = None
        if training:
            optimizer = _clone_fn(model.optimizer)
            optimizer.iterations = (
                tf.compat.v1.train.get_or_create_global_step())
            # Get both the unconditional updates (the None part)
            # and the input-conditional updates (the features part).
            # These updates are for layers like BatchNormalization, which have
            # separate update and minimize ops.
            update_ops = ranker.get_updates_for(None) + ranker.get_updates_for(
                features)
            minimize_op = optimizer.get_updates(
                loss=total_loss, params=ranker.trainable_variables)[0]
            train_op = tf.group(minimize_op, *update_ops)

        return tf_estimator.EstimatorSpec(mode=mode,
                                          predictions=logits,
                                          loss=total_loss,
                                          train_op=train_op,
                                          eval_metric_ops=eval_metric_ops,
                                          export_outputs=export_outputs)

    return tf_estimator.Estimator(model_fn=_model_fn,
                                  config=config,
                                  model_dir=model_dir,
                                  warm_start_from=warm_start_from)
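A hedged usage sketch: assuming ranker is a compiled tf.keras.Model (for example, one built with tfr.keras.model.create_keras_model, per the docstring) with loss, metrics, and optimizer already set. The paths, feature name, and train_input_fn are illustrative stand-ins.
estimator = model_to_estimator(
    model=ranker,
    model_dir="/tmp/ranking_model",
    config=tf_estimator.RunConfig(save_checkpoints_steps=1000),
    weights_feature_name="example_weights",  # illustrative feature name
    serving_default="regress")
estimator.train(input_fn=train_input_fn, steps=1000)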