Example #1
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        logging.info("*** Model: Params ***")
        for name in sorted(params.keys()):
            logging.info("  %s = %s", name, params[name])
        logging.info("*** Model: Features ***")
        for name in sorted(features.keys()):
            logging.info("  name = %s, shape = %s", name, features[name].shape)

        model = modeling.ReadItTwiceBertModel(
            config=model_config, use_one_hot_embeddings=use_one_hot_embeddings)

        span_prediction_layer = modeling.SpanPredictionHead(
            intermediate_size=model_config.intermediate_size,
            dropout_rate=model_config.hidden_dropout_prob)

        # [batch_size, main_seq_length]
        token_ids = features["token_ids"]
        main_seq_length = tf.shape(token_ids)[1]
        block_ids = features["block_ids"]
        block_pos = features["block_pos"]

        annotation_begins = features.get("entity_annotation_begins")
        annotation_ends = features.get("entity_annotation_ends")
        annotation_labels = features.get("entity_annotation_labels")

        # Do not attend to padding tokens.
        # [batch_size, main_seq_length, main_seq_length]
        att_mask = tf.tile(
            tf.expand_dims(tf.not_equal(token_ids, padding_token_id), 1),
            [1, main_seq_length, 1])
        att_mask = tf.cast(att_mask, dtype=tf.int32)
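        # For illustration (an assumption, not part of the original code):
        # with token_ids = [[5, 7, 0]] and padding_token_id = 0, att_mask is
        # [[[1, 1, 0],
        #   [1, 1, 0],
        #   [1, 1, 0]]],
        # i.e. every query position may attend to the two real tokens but
        # never to the padding position.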

        main_output = model(
            token_ids=token_ids,
            training=(mode == tf.estimator.ModeKeys.TRAIN),
            block_ids=block_ids,
            block_pos=block_pos,
            att_mask=att_mask,
            annotation_begins=annotation_begins,
            annotation_ends=annotation_ends,
            annotation_labels=annotation_labels,
            enable_side_inputs=enable_side_inputs,
            num_replicas_concat=num_replicas_concat,
            cross_block_attention_mode=cross_block_attention_mode)

        span_logits = span_prediction_layer(
            hidden_states=main_output.final_hidden_states,
            token_ids=token_ids,
            padding_token_id=padding_token_id,
            ignore_prefix_length=features["prefix_length"],
            training=(mode == tf.estimator.ModeKeys.TRAIN))

        is_summary_loss_enabled = (mode == tf.estimator.ModeKeys.TRAIN
                                   and summary_loss_weight is not None
                                   and summary_loss_weight > 0)
        if is_summary_loss_enabled:
            logging.info("Using summary prediction loss with weight %.3f",
                         summary_loss_weight)
            summary_token_ids = features["summary_token_ids"]
            summary_labels = tf.roll(summary_token_ids, shift=-1, axis=1)
            decoder = modeling.ReadItTwiceDecoderModel(
                config=model_config,
                num_layers_override=summary_num_layers,
                num_cross_attention_heads=summary_num_cross_attention_heads,
                enable_default_side_input=summary_enable_default_side_input,
                use_one_hot_embeddings=use_one_hot_embeddings)
            summary_token_logits = decoder(
                token_ids=summary_token_ids,
                side_input=main_output.global_summary.states,
                token2side_input_att_mask=modeling.get_cross_block_att(
                    block_ids,
                    block_pos,
                    main_output.global_summary.block_ids,
                    main_output.global_summary.block_pos,
                    cross_block_attention_mode="doc"),
                training=True)
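            # Cross-block mode "doc" restricts the decoder's side input to
            # summaries of blocks that share a block_id, i.e. summaries from
            # the same document.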
            language_model_loss_fn = losses.LanguageModelLoss(
                decoder.get_token_embedding_table(),
                hidden_size=model_config.hidden_size)
            language_model_loss = language_model_loss_fn(
                summary_token_logits,
                summary_labels,
                padding_token_id=padding_token_id).loss
        else:
            language_model_loss = None

        tvars = tf.trainable_variables()

        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            (assignment_map, initialized_variable_names
             ) = checkpoint_utils.get_assignment_map_from_checkpoint(
                 tvars, init_checkpoint)
            if use_tpu:

                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                         init_string)

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            host_inputs = dict()

            span_prediction_loss = losses.BatchSpanCrossEntropyLoss()

            qa_loss = span_prediction_loss(
                logits=span_logits,
                annotation_begins=features["answer_annotation_begins"],
                annotation_ends=features["answer_annotation_ends"],
                annotation_labels=features["answer_annotation_labels"],
                block_ids=block_ids,
                num_replicas=num_replicas_concat,
                eps=1e-5)
            host_inputs["train_metrics/qa_loss"] = tf.expand_dims(qa_loss, 0)

            if language_model_loss is not None:
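                # The weights 1/(1+w) and w/(1+w) sum to 1, so total_loss
                # stays on the same scale as qa_loss for any summary weight.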
                total_loss = (
                    1.0 / (1.0 + summary_loss_weight) * qa_loss +
                    summary_loss_weight /
                    (1.0 + summary_loss_weight) * language_model_loss)
                host_inputs["train_metrics/summary_lm_loss"] = tf.expand_dims(
                    language_model_loss, 0)
            else:
                total_loss = qa_loss

            # Add regularization losses.
            if model.losses:
                total_loss += tf.math.add_n(model.losses)

            train_op = optimization.create_optimizer(total_loss,
                                                     learning_rate,
                                                     num_train_steps,
                                                     num_warmup_steps,
                                                     use_tpu,
                                                     optimizer,
                                                     poly_power,
                                                     start_warmup_step,
                                                     learning_rate_schedule,
                                                     reduce_loss_sum=True)

            host_inputs.update({
                "global_step":
                tf.expand_dims(tf.train.get_or_create_global_step(), 0),
                "train_metrics/loss":
                tf.expand_dims(total_loss, 0),
            })

            host_call = (functools.partial(record_summary_host_fn,
                                           metrics_dir=os.path.join(
                                               FLAGS.output_dir,
                                               "train_metrics")), host_inputs)

            output_spec = tf.estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                scaffold_fn=scaffold_fn,
                host_call=host_call)
        elif mode == tf.estimator.ModeKeys.PREDICT:
            begin_logits_values, begin_logits_indices = tf.math.top_k(
                span_logits[:, :, 0],
                k=nbest_logits_for_eval,
            )
            end_logits_values, end_logits_indices = tf.math.top_k(
                span_logits[:, :, 1],
                k=nbest_logits_for_eval,
            )

            predictions = {
                "block_ids": tf.identity(block_ids),
                "begin_logits_values": begin_logits_values,
                "begin_logits_indices": begin_logits_indices,
                "end_logits_values": end_logits_values,
                "end_logits_indices": end_logits_indices,
                "token_ids": tf.identity(token_ids),
            }
            output_spec = tf.estimator.tpu.TPUEstimatorSpec(
                mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
        else:
            raise ValueError("Only TRAIN and PREDICT modes is supported: %s" %
                             (mode))

        return output_spec
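
A minimal sketch of how a `model_fn` like the one above is typically wired into
a `TPUEstimator` (TF1-style `tf.estimator.tpu` APIs); `tpu_resolver`,
`train_input_fn`, and the batch sizes below are illustrative stand-ins, not
part of the original code:

def build_and_train(model_fn, tpu_resolver, train_input_fn, num_train_steps):
  # tpu_resolver is assumed to be a tf.distribute.cluster_resolver
  # TPUClusterResolver; FLAGS.output_dir matches the flag used above.
  run_config = tf.estimator.tpu.RunConfig(
      cluster=tpu_resolver,
      model_dir=FLAGS.output_dir,
      tpu_config=tf.estimator.tpu.TPUConfig(iterations_per_loop=1000))
  estimator = tf.estimator.tpu.TPUEstimator(
      use_tpu=True,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=32,
      predict_batch_size=8)
  estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)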
Example #2
  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator."""

    logging.info("*** Model: Params ***")
    for name in sorted(params.keys()):
      logging.info("  %s = %s", name, params[name])
    logging.info("*** Model: Features ***")
    for name in sorted(features.keys()):
      logging.info("  name = %s, shape = %s", name, features[name].shape)

    model = modeling.ReadItTwiceBertModel(
        config=model_config, use_one_hot_embeddings=use_one_hot_embeddings)

    span_prediction_layer = modeling.SpanPredictionHead(
        intermediate_size=model_config.intermediate_size,
        dropout_rate=model_config.hidden_dropout_prob)

    # [batch_size, main_seq_length]
    token_ids = features["token_ids"]
    main_seq_length = tf.shape(token_ids)[1]
    block_ids = features["block_ids"]
    block_pos = features["block_pos"]
    answer_type = features["answer_type"]
    supporting_fact = features["is_supporting_fact"]

    annotation_begins = features.get("entity_annotation_begins")
    annotation_ends = features.get("entity_annotation_ends")
    annotation_labels = features.get("entity_annotation_labels")

    # Do not attend to padding tokens.
    # [batch_size, main_seq_length, main_seq_length]
    att_mask = tf.tile(
        tf.expand_dims(tf.not_equal(token_ids, padding_token_id), 1),
        [1, main_seq_length, 1])
    att_mask = tf.cast(att_mask, dtype=tf.int32)

    main_output = model(
        token_ids=token_ids,
        training=(mode == tf.estimator.ModeKeys.TRAIN),
        block_ids=block_ids,
        block_pos=block_pos,
        att_mask=att_mask,
        annotation_begins=annotation_begins,
        annotation_ends=annotation_ends,
        annotation_labels=annotation_labels,
        enable_side_inputs=enable_side_inputs,
        num_replicas_concat=num_replicas_concat,
        cross_block_attention_mode=cross_block_attention_mode)

    span_logits = span_prediction_layer(
        hidden_states=main_output.final_hidden_states,
        token_ids=token_ids,
        padding_token_id=padding_token_id,
        ignore_prefix_length=features["prefix_length"],
        training=(mode == tf.estimator.ModeKeys.TRAIN))

    # The "pooler" converts the encoded sequence tensor of shape
    # [batch_size, seq_length, hidden_size] to a tensor of shape
    # [batch_size, hidden_size]. This is necessary for segment-level
    # (or segment-pair-level) classification tasks where we need a fixed
    # dimensional representation of the segment.
    with tf.variable_scope("pooler"):
      # We "pool" the model by simply taking the hidden state corresponding
      # to the first token. We assume that this has been pre-trained
      first_token_tensor = tf.squeeze(
          main_output.final_hidden_states[:, 0:1, :], axis=1)
      pooled_output = tf.layers.dense(
          first_token_tensor,
          model_config.hidden_size,
          activation=tf.tanh,
          kernel_initializer=tf.truncated_normal_initializer(
              stddev=model_config.initializer_range))

    yesno_logits = yesno_model(pooled_output)
    supporting_fact_logits = supporting_fact_model(pooled_output)

    tvars = tf.trainable_variables()

    initialized_variable_names = {}
    scaffold_fn = None
    if init_checkpoint:
      (assignment_map, initialized_variable_names
      ) = checkpoint_utils.get_assignment_map_from_checkpoint(
          tvars, init_checkpoint)
      if use_tpu:

        def tpu_scaffold():
          tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
          return tf.train.Scaffold()

        scaffold_fn = tpu_scaffold
      else:
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    logging.info("**** Trainable Variables ****")
    for var in tvars:
      init_string = ""
      if var.name in initialized_variable_names:
        init_string = ", *INIT_FROM_CKPT*"
      logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                   init_string)

    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
      host_inputs = dict()

      span_prediction_loss = losses.BatchSpanCrossEntropyLoss()

      total_loss = 0
      qa_loss = span_prediction_loss(
          logits=span_logits,
          annotation_begins=features["answer_annotation_begins"],
          annotation_ends=features["answer_annotation_ends"],
          annotation_labels=features["answer_annotation_labels"],
          block_ids=block_ids,
          num_replicas=num_replicas_concat,
          eps=1e-5)
      host_inputs["train_metrics/qa_loss"] = tf.expand_dims(qa_loss, 0)
      total_loss += qa_loss

      # example_mask = tf.cast(tf.not_equal(block_ids, 0), tf.float32)
      # yesno_loss = compute_pooled_loss(yesno_logits, answer_type, 3,
      #                                  example_mask)
      # supporting_fact_loss = compute_supporting_facts_loss(
      #     supporting_fact_logits, supporting_fact, example_mask)
      hotpot_qa_loss = hotpot_qa_losses.BatchSpanCrossEntropyLoss()
      yesno_loss, supporting_fact_loss = hotpot_qa_loss(
          yesno_logits,
          answer_type,
          supporting_fact_logits,
          supporting_fact,
          block_ids,
          eps=1e-5)

      host_inputs["train_metrics/yesno_loss"] = tf.expand_dims(yesno_loss, 0)
      total_loss += yesno_loss

      host_inputs["train_metrics/supporting_fact_loss"] = tf.expand_dims(
          supporting_fact_loss, 0)
      total_loss += supporting_fact_loss

      # Add regularization losses.
      if model.losses:
        total_loss += tf.math.add_n(model.losses)

      train_op = optimization.create_optimizer(
          total_loss,
          learning_rate,
          num_train_steps,
          num_warmup_steps,
          use_tpu,
          optimizer,
          poly_power,
          start_warmup_step,
          learning_rate_schedule,
          reduce_loss_sum=True)

      host_inputs.update({
          "global_step":
              tf.expand_dims(tf.train.get_or_create_global_step(), 0),
          "train_metrics/loss":
              tf.expand_dims(total_loss, 0),
      })

      host_call = (functools.partial(
          record_summary_host_fn,
          metrics_dir=os.path.join(FLAGS.output_dir,
                                   "train_metrics")), host_inputs)

      output_spec = tf.estimator.tpu.TPUEstimatorSpec(
          mode=mode,
          loss=total_loss,
          train_op=train_op,
          scaffold_fn=scaffold_fn,
          host_call=host_call)
    elif mode == tf.estimator.ModeKeys.PREDICT:
      begin_logits_values, begin_logits_indices = tf.math.top_k(
          span_logits[:, :, 0],
          k=nbest_logits_for_eval,
      )
      end_logits_values, end_logits_indices = tf.math.top_k(
          span_logits[:, :, 1],
          k=nbest_logits_for_eval,
      )

      predictions = {
          "block_ids": tf.identity(block_ids),
          "begin_logits_values": begin_logits_values,
          "begin_logits_indices": begin_logits_indices,
          "end_logits_values": end_logits_values,
          "end_logits_indices": end_logits_indices,
          "token_ids": tf.identity(token_ids),
          "answer_type": answer_type,
          "yesno_logits": yesno_logits,
          "supporting_fact_logits": supporting_fact_logits,
          "is_supporting_fact": supporting_fact,
      }
      output_spec = tf.estimator.tpu.TPUEstimatorSpec(
          mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
    else:
      raise ValueError("Only TRAIN and PREDICT modes is supported: %s" % mode)

    return output_spec
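
The PREDICT branch above only exports top-k begin/end logits; turning them into
an answer span happens off-TPU. A minimal post-processing sketch for a single
example (an assumption, not code from the repo):

import numpy as np

def best_span(begin_values, begin_indices, end_values, end_indices,
              max_answer_length=30):
  """Picks the highest-scoring (begin, end) pair from top-k logits."""
  best_begin, best_end, best_score = -1, -1, -np.inf
  for bv, bi in zip(begin_values, begin_indices):
    for ev, ei in zip(end_values, end_indices):
      # Keep only well-formed spans of bounded length.
      if bi <= ei < bi + max_answer_length and bv + ev > best_score:
        best_begin, best_end, best_score = int(bi), int(ei), float(bv + ev)
  return best_begin, best_end, best_score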
Example #3
  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator."""

    logging.info("*** Model: Params ***")
    for name in sorted(params.keys()):
      logging.info("  %s = %s", name, params[name])
    logging.info("*** Model: Features ***")
    for name in sorted(features.keys()):
      logging.info("  name = %s, shape = %s", name, features[name].shape)

    model = modeling.ReadItTwiceBertModel(
        config=model_config, use_one_hot_embeddings=use_one_hot_embeddings)

    # [batch_size, main_seq_length]
    token_ids = features["token_ids"]
    batch_size = tf.shape(token_ids)[0]
    main_seq_length = tf.shape(token_ids)[1]
    block_ids = features["block_ids"]
    block_pos = features["block_pos"]

    annotation_begins = features.get("annotation_begins")
    annotation_ends = features.get("annotation_ends")
    annotation_labels = features.get("annotation_labels")

    # Do not attend to padding tokens.
    # [batch_size, main_seq_length, main_seq_length]
    att_mask = tf.tile(
        tf.expand_dims(tf.not_equal(token_ids, padding_token_id), 1),
        [1, main_seq_length, 1])
    att_mask = tf.cast(att_mask, dtype=tf.int32)

    main_output = model(
        token_ids=token_ids,
        training=(mode == tf.estimator.ModeKeys.TRAIN),
        block_ids=block_ids,
        block_pos=block_pos,
        att_mask=att_mask,
        annotation_begins=annotation_begins,
        annotation_ends=annotation_ends,
        annotation_labels=annotation_labels,
        enable_side_inputs=enable_side_inputs,
        num_replicas_concat=num_replicas_concat,
        cross_block_attention_mode=cross_block_attention_mode)

    mlm_loss_fn = losses.LanguageModelLoss(
        model.get_token_embedding_table(),
        hidden_size=model_config.hidden_size,
        name="mlm_loss")
    mlm_loss_output = mlm_loss_fn(
        input_tensor=main_output.final_hidden_states,
        label_ids=features["masked_lm_ids"],
        positions=features["masked_lm_positions"],
        label_weights=features["masked_lm_weights"],
        mlm_is_entity_mask=features.get("mlm_is_entity_mask"),
        mlm_is_not_entity_mask=features.get("mlm_is_not_entity_mask"),
        padding_token_id=padding_token_id)
    mlm_loss = mlm_loss_output.loss

    loss_to_log = dict(mlm_loss=tf.expand_dims(mlm_loss, 0))
    loss_weight_denominator = 1.0 + sum(extra_loss.values())
    total_loss = mlm_loss * (1.0 / loss_weight_denominator)
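    # After this normalization the MLM weight and all extra-loss weights sum
    # to 1, so total_loss stays comparable across loss configurations.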
    for loss_name, loss_weight in extra_loss.items():
      logging.info("EXTRA LOSS: %s with weight %.2f", loss_name,
                   loss_weight / loss_weight_denominator)

      if model_config.summary_mode == "entity":
        # Entity label "1" corresponds to an unknown entity, so there is no
        # need to compute the coreference resolution loss for these unknown
        # entities.
        labels_weight = tf.cast(
            tf.logical_and(
                tf.not_equal(
                    tf.expand_dims(main_output.local_summary.labels, 1), 1),
                tf.not_equal(
                    tf.expand_dims(main_output.global_summary.labels, 0), 1)),
            tf.float32)
      else:
        labels_weight = None

      if loss_name == "sdp":
        loss_fn = losses.BatchCoreferenceResolutionLoss(
            apply_linear_layer=False)
        loss_value = loss_fn(
            main_output.local_summary.states,
            main_output.local_summary.labels,
            main_output.global_summary.states,
            main_output.global_summary.labels,
            labels_weight=labels_weight)
      elif loss_name == "sdp_linear":
        loss_fn = losses.BatchCoreferenceResolutionLoss(apply_linear_layer=True)
        loss_value = loss_fn(
            main_output.local_summary.states,
            main_output.local_summary.labels,
            main_output.global_summary.states,
            main_output.global_summary.labels,
            labels_weight=labels_weight)
      elif loss_name == "spp_linear":
        loss_fn = losses.BatchCoreferenceResolutionLoss(apply_linear_layer=True)
        # Positive examples are blocks that directly follow one another in
        # the original document.
        labels_mask = tf.less_equal(
            tf.abs(
                tf.expand_dims(main_output.local_summary.block_pos, 1) -
                tf.expand_dims(main_output.global_summary.block_pos, 0)), 1)
        loss_value = loss_fn(
            main_output.local_summary.states,
            main_output.local_summary.labels,
            main_output.global_summary.states,
            main_output.global_summary.labels,
            labels_mask=labels_mask,
            labels_weight=labels_weight)
      elif loss_name == "lm":
        token_labels = tf.roll(token_ids, shift=-1, axis=1)
        # [batch_size, global_batch_size]
        token2side_input_att_mask = modeling.get_cross_block_att(
            block_ids,
            block_pos,
            main_output.global_summary.block_ids,
            main_output.global_summary.block_pos,
            cross_block_attention_mode=cross_block_attention_mode,
            cast_to_int32=False)
        # We want to exclude the summary of the block itself
        # from decoder side input. As a proxy for this, we use block_ids AND
        # block_pos.
        samples_are_the_same = tf.logical_and(
            tf.equal(
                tf.expand_dims(block_ids, 1),
                tf.expand_dims(main_output.global_summary.block_ids, 0)),
            tf.equal(
                tf.expand_dims(block_pos, 1),
                tf.expand_dims(main_output.global_summary.block_pos, 0)))
        token2side_input_att_mask = tf.stop_gradient(
            tf.cast(
                tf.logical_and(token2side_input_att_mask,
                               tf.logical_not(samples_are_the_same)),
                dtype=tf.int32))

        decoder = modeling.ReadItTwiceDecoderModel(
            config=model_config,
            num_layers_override=summary_num_layers,
            num_cross_attention_heads=summary_num_cross_attention_heads,
            enable_default_side_input=summary_enable_default_side_input,
            use_one_hot_embeddings=use_one_hot_embeddings)
        summary_token_logits = decoder(
            token_ids=token_ids,
            side_input=main_output.global_summary.states,
            token2side_input_att_mask=token2side_input_att_mask,
            training=True)
        language_model_loss_fn = losses.LanguageModelLoss(
            decoder.get_token_embedding_table(),
            hidden_size=model_config.hidden_size)

        # We don't penalize the first and last 32 tokens, so the model does
        # not have an incentive to memorize tokens at the border of blocks.
        labels_weights = tf.concat(
            [
                tf.zeros([batch_size, 32], dtype=tf.bool),
                tf.ones([batch_size, main_seq_length - 32 * 2], dtype=tf.bool),
                tf.zeros([batch_size, 32], dtype=tf.bool)
            ],
            axis=1)
        labels_weights = tf.logical_and(
            labels_weights, tf.not_equal(token_labels, padding_token_id))
        labels_weights = tf.stop_gradient(
            tf.cast(labels_weights, dtype=tf.float32))

        loss_value = language_model_loss_fn(
            summary_token_logits, token_labels,
            label_weights=labels_weights).loss
      else:
        raise ValueError("Unknown extra loss: {}".format(loss_name))

      loss_to_log[loss_name] = tf.expand_dims(loss_value, 0)
      total_loss += loss_value * (loss_weight / loss_weight_denominator)

    if model.losses:
      total_loss += tf.math.add_n(model.losses)

    tvars = tf.trainable_variables()

    initialized_variable_names = {}
    scaffold_fn = None
    if init_checkpoint:
      (assignment_map, initialized_variable_names
      ) = checkpoint_utils.get_assignment_map_from_checkpoint(
          tvars, init_checkpoint)
      if use_tpu:

        def tpu_scaffold():
          tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
          return tf.train.Scaffold()

        scaffold_fn = tpu_scaffold
      else:
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    logging.info("**** Trainable Variables ****")
    for var in tvars:
      init_string = ""
      if var.name in initialized_variable_names:
        init_string = ", *INIT_FROM_CKPT*"
      logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                   init_string)

    metric_fn_tensors = dict(
        mlm_loss_per_sample=mlm_loss_output.mlm_loss_per_sample,
        mlm_accuracy_per_sample=mlm_loss_output.mlm_accuracy_per_sample,
        mlm_weight_per_sample=mlm_loss_output.mlm_weight_per_sample,
        mlm_loss_per_entity_sample=mlm_loss_output.mlm_loss_per_entity_sample,
        mlm_accuracy_per_entity_sample=mlm_loss_output
        .mlm_accuracy_per_entity_sample,
        mlm_weight_per_entity_sample=mlm_loss_output
        .mlm_weight_per_entity_sample,
        mlm_loss_per_non_entity_sample=mlm_loss_output
        .mlm_loss_per_non_entity_sample,
        mlm_accuracy_per_non_entity_sample=mlm_loss_output
        .mlm_accuracy_per_non_entity_sample,
        mlm_weight_per_non_entity_sample=mlm_loss_output
        .mlm_weight_per_non_entity_sample,
        block_ids=block_ids)

    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
      train_op = optimization.create_optimizer(
          total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu,
          optimizer, poly_power, start_warmup_step, learning_rate_schedule)

      metric_fn_tensors.update({
          "global_step":
              tf.expand_dims(tf.train.get_or_create_global_step(), 0),
          "loss":
              tf.expand_dims(total_loss, 0),
      })
      metric_fn_tensors.update(loss_to_log)

      host_call = (functools.partial(
          record_summary_host_fn,
          metrics_dir=os.path.join(FLAGS.output_dir, "train_metrics"),
          metrics_name=metrics_name or "train_metrics"), metric_fn_tensors)

      output_spec = tf.estimator.tpu.TPUEstimatorSpec(
          mode=mode,
          loss=total_loss,
          train_op=train_op,
          scaffold_fn=scaffold_fn,
          host_call=host_call)

    elif mode == tf.estimator.ModeKeys.EVAL:

      eval_metrics = (functools.partial(
          metric_utils.masked_lm_metrics,
          is_train=False,
          metrics_name=metrics_name or "eval_metrics"), metric_fn_tensors)
      output_spec = tf.estimator.tpu.TPUEstimatorSpec(
          mode=mode,
          loss=total_loss,
          eval_metrics=eval_metrics,
          scaffold_fn=scaffold_fn)
    else:
      raise ValueError("Only TRAIN and EVAL modes are supported: %s" % mode)

    return output_spec
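
A toy check (with assumed weights, not values from the code) of the loss
blending used in the example above:

extra_loss = {"sdp": 1.0, "lm": 0.5}
denominator = 1.0 + sum(extra_loss.values())               # 2.5
weights = [1.0 / denominator]                              # MLM weight: 0.4
weights += [w / denominator for w in extra_loss.values()]  # 0.4 and 0.2
assert abs(sum(weights) - 1.0) < 1e-9  # the blend is convex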