# Example 1
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        label_ids = features["label_ids"]
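        # TPU requires fixed batch sizes, so eval batches may be padded with
        # fake examples; `is_real_example` weights such padding out of the
        # metrics computed below.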
        is_real_example = None
        if "is_real_example" in features:
            is_real_example = tf.cast(features["is_real_example"],
                                      dtype=tf.float32)
        else:
            is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        # create_model builds the task-specific head on top of ALBERT and
        # returns the batch loss plus per-example tensors used by the metrics.
        (total_loss, per_example_loss, probabilities, logits,
         predictions) = create_model(
             albert_config, is_training, input_ids, input_mask, segment_ids,
             label_ids, num_labels, use_one_hot_embeddings, task_name,
             hub_module)

        tvars = tf.trainable_variables()
        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            (assignment_map, initialized_variable_names
             ) = modeling.get_assignment_map_from_checkpoint(
                 tvars, init_checkpoint)
            if use_tpu:
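                # On TPU, model_fn runs once to build the graph, so the
                # checkpoint restore is deferred into a Scaffold that executes
                # during session creation on the worker.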

                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:

            train_op = optimization.create_optimizer(total_loss, learning_rate,
                                                     num_train_steps,
                                                     num_warmup_steps, use_tpu,
                                                     optimizer)

            output_spec = contrib_tpu.TPUEstimatorSpec(mode=mode,
                                                       loss=total_loss,
                                                       train_op=train_op,
                                                       scaffold_fn=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.EVAL:
            if task_name not in ["sts-b", "cola", "nlpcc_dbqa"]:

                def metric_fn(per_example_loss, label_ids, logits,
                              is_real_example):
                    predictions = tf.argmax(logits,
                                            axis=-1,
                                            output_type=tf.int32)
                    accuracy = tf.metrics.accuracy(labels=label_ids,
                                                   predictions=predictions,
                                                   weights=is_real_example)
                    loss = tf.metrics.mean(values=per_example_loss,
                                           weights=is_real_example)
                    return {
                        "eval_accuracy": accuracy,
                        "eval_loss": loss,
                    }
            elif task_name == "sts-b":

                def metric_fn(per_example_loss, label_ids, logits,
                              is_real_example):
                    """Compute Pearson correlations for STS-B."""
                    # Accumulate labels and predictions across eval batches
                    concat1 = contrib_metrics.streaming_concat(logits)
                    concat2 = contrib_metrics.streaming_concat(label_ids)

                    # Compute Pearson correlation
                    pearson = contrib_metrics.streaming_pearson_correlation(
                        logits, label_ids, weights=is_real_example)

                    # Compute MSE
                    # mse = tf.metrics.mean(per_example_loss)
                    mse = tf.metrics.mean_squared_error(
                        label_ids, logits, weights=is_real_example)

                    loss = tf.metrics.mean(values=per_example_loss,
                                           weights=is_real_example)

                    return {
                        "pred": concat1,
                        "label_ids": concat2,
                        "pearson": pearson,
                        "MSE": mse,
                        "eval_loss": loss,
                    }
            elif task_name == "cola":

                def metric_fn(per_example_loss, label_ids, logits,
                              is_real_example):
                    """Compute Matthew's correlations for STS-B."""
                    predictions = tf.argmax(logits,
                                            axis=-1,
                                            output_type=tf.int32)
                    # https://en.wikipedia.org/wiki/Matthews_correlation_coefficient
                    tp, tp_op = tf.metrics.true_positives(
                        predictions, label_ids, weights=is_real_example)
                    tn, tn_op = tf.metrics.true_negatives(
                        predictions, label_ids, weights=is_real_example)
                    fp, fp_op = tf.metrics.false_positives(
                        predictions, label_ids, weights=is_real_example)
                    fn, fn_op = tf.metrics.false_negatives(
                        predictions, label_ids, weights=is_real_example)

                    # Compute the Matthews correlation coefficient (MCC):
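                    # MCC = (tp*tn - fp*fn) /
                    #       sqrt((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn));
                    # div_no_nan returns 0 when the denominator is 0.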
                    mcc = tf.div_no_nan(
                        tp * tn - fp * fn,
                        tf.pow((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn),
                               0.5))

                    # Compute accuracy
                    accuracy = tf.metrics.accuracy(labels=label_ids,
                                                   predictions=predictions,
                                                   weights=is_real_example)

                    loss = tf.metrics.mean(values=per_example_loss,
                                           weights=is_real_example)

                    return {
                        "matthew_corr":
                        (mcc, tf.group(tp_op, tn_op, fp_op, fn_op)),
                        "eval_accuracy": accuracy,
                        "eval_loss": loss,
                    }

            elif task_name == "nlpcc_dbqa":

                def metric_fn(per_example_loss, label_ids, logits,
                              is_real_example):
                    """Compute precision, recall and F1 for NLPCC-DBQA."""
                    predictions = tf.argmax(logits,
                                            axis=-1,
                                            output_type=tf.int32)
                    precision, precision_update_op = tf.metrics.precision(
                        labels=label_ids,
                        predictions=predictions,
                        weights=is_real_example,
                        name="precision")
                    recall, recall_update_op = tf.metrics.recall(
                        labels=label_ids,
                        predictions=predictions,
                        weights=is_real_example,
                        name='recall')

                    # Streaming F1 derived from the running precision/recall;
                    # the 1e-7 terms guard against division by zero. This
                    # averages per-step F1 rather than computing F1 from the
                    # final counts, so it is an approximation.
                    f1_score, f1_update_op = tf.metrics.mean(
                        (2 * (precision + 1e-7) *
                         (recall + 1e-7)) / (precision + recall + 2e-7),
                        name='f1_score')

                    # Compute accuracy
                    accuracy = tf.metrics.accuracy(labels=label_ids,
                                                   predictions=predictions,
                                                   weights=is_real_example)

                    loss = tf.metrics.mean(values=per_example_loss,
                                           weights=is_real_example)

                    return {
                        "precision": (precision, precision_update_op),
                        "recall": (recall, recall_update_op),
                        "f1_score": (f1_score, f1_update_op),
                        "eval_accuracy": accuracy,
                        "eval_loss": loss,
                    }

            # TPUEstimator takes eval metrics as a (metric_fn, tensors) pair;
            # the metric computation runs on the host CPU after each eval step.
            eval_metrics = (metric_fn, [
                per_example_loss, label_ids, logits, is_real_example
            ])
            output_spec = contrib_tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metrics=eval_metrics,
                scaffold_fn=scaffold_fn)

        else:
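            # Remaining mode: tf.estimator.ModeKeys.PREDICT.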
            output_spec = contrib_tpu.TPUEstimatorSpec(mode=mode,
                                                       predictions={
                                                           "probabilities":
                                                           probabilities,
                                                           "predictions":
                                                           predictions
                                                       },
                                                       scaffold_fn=scaffold_fn)
        return output_spec
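
# --- Usage sketch (not part of the original snippet) ---
# In the ALBERT codebase this model_fn closes over albert_config, num_labels,
# init_checkpoint, etc., so it is produced by a model_fn_builder(...) closure
# and handed to TPUEstimator roughly as below; tpu_cluster_resolver and the
# FLAGS values here are illustrative assumptions, not part of the original.
run_config = contrib_tpu.RunConfig(
    cluster=tpu_cluster_resolver,
    model_dir=FLAGS.output_dir,
    tpu_config=contrib_tpu.TPUConfig(
        iterations_per_loop=FLAGS.iterations_per_loop,
        num_shards=FLAGS.num_tpu_cores))
estimator = contrib_tpu.TPUEstimator(
    use_tpu=use_tpu,
    model_fn=model_fn,
    config=run_config,
    train_batch_size=FLAGS.train_batch_size,
    eval_batch_size=FLAGS.eval_batch_size,
    predict_batch_size=FLAGS.predict_batch_size)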

# Example 2
  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator."""

    tf.logging.info("*** Features ***")
    for name in sorted(features.keys()):
      tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))

    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    masked_lm_positions = features["masked_lm_positions"]
    masked_lm_ids = features["masked_lm_ids"]
    masked_lm_weights = features["masked_lm_weights"]
    # Note: We keep this feature name `next_sentence_labels` to be compatible
    # with the original data created by lanzhzh@. However, in the ALBERT case
    # it does represent sentence_order_labels.
    sentence_order_labels = features["next_sentence_labels"]

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    model = modeling.AlbertModel(
        config=albert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=use_one_hot_embeddings)

    (masked_lm_loss, masked_lm_example_loss,
     masked_lm_log_probs) = get_masked_lm_output(albert_config,
                                                 model.get_sequence_output(),
                                                 model.get_embedding_table(),
                                                 masked_lm_positions,
                                                 masked_lm_ids,
                                                 masked_lm_weights)

    (sentence_order_loss, sentence_order_example_loss,
     sentence_order_log_probs) = get_sentence_order_output(
         albert_config, model.get_pooled_output(), sentence_order_labels)

    # The pretraining objective is the sum of the masked-LM loss and the
    # sentence-order-prediction (SOP) loss.
    total_loss = masked_lm_loss + sentence_order_loss

    tvars = tf.trainable_variables()

    initialized_variable_names = {}
    scaffold_fn = None
    if init_checkpoint:
      tf.logging.info("number of hidden group %d to initialize",
                      albert_config.num_hidden_groups)
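      # With init_from_group0, the checkpoint's group-0 weights seed every
      # hidden group, and get_assignment_map_from_checkpoint returns one
      # assignment map per group (indexed by `gid` below).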
      num_of_initialize_group = 1
      if FLAGS.init_from_group0:
        num_of_initialize_group = albert_config.num_hidden_groups
        if albert_config.net_structure_type > 0:
          num_of_initialize_group = albert_config.num_hidden_layers
      (assignment_map, initialized_variable_names
      ) = modeling.get_assignment_map_from_checkpoint(
          tvars, init_checkpoint, num_of_initialize_group)
      if use_tpu:

        def tpu_scaffold():
          for gid in range(num_of_initialize_group):
            tf.logging.info("initialize the %dth layer", gid)
            tf.logging.info(assignment_map[gid])
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map[gid])
          return tf.train.Scaffold()

        scaffold_fn = tpu_scaffold
      else:
        for gid in range(num_of_initialize_group):
          tf.logging.info("initialize the %dth layer", gid)
          tf.logging.info(assignment_map[gid])
          tf.train.init_from_checkpoint(init_checkpoint, assignment_map[gid])

    tf.logging.info("**** Trainable Variables ****")
    for var in tvars:
      init_string = ""
      if var.name in initialized_variable_names:
        init_string = ", *INIT_FROM_CKPT*"
      tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                      init_string)

    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
      # poly_power and start_warmup_step parameterize the polynomial learning
      # rate decay and the global step at which warmup starts.
      train_op = optimization.create_optimizer(
          total_loss, learning_rate, num_train_steps, num_warmup_steps,
          use_tpu, optimizer, poly_power, start_warmup_step)

      output_spec = contrib_tpu.TPUEstimatorSpec(
          mode=mode,
          loss=total_loss,
          train_op=train_op,
          scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.EVAL:

      def metric_fn(*args):
        """Computes the loss and accuracy of the model."""
        (masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
         masked_lm_weights, sentence_order_example_loss,
         sentence_order_log_probs, sentence_order_labels) = args[:7]

        # `masked_lm_log_probs` has shape [batch, masked_positions, vocab];
        # flatten the leading dims so each masked position is scored on its
        # own, with `masked_lm_weights` masking out padded positions.
        masked_lm_log_probs = tf.reshape(masked_lm_log_probs,
                                         [-1, masked_lm_log_probs.shape[-1]])
        masked_lm_predictions = tf.argmax(
            masked_lm_log_probs, axis=-1, output_type=tf.int32)
        masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
        masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
        masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
        masked_lm_accuracy = tf.metrics.accuracy(
            labels=masked_lm_ids,
            predictions=masked_lm_predictions,
            weights=masked_lm_weights)
        masked_lm_mean_loss = tf.metrics.mean(
            values=masked_lm_example_loss, weights=masked_lm_weights)

        metrics = {
            "masked_lm_accuracy": masked_lm_accuracy,
            "masked_lm_loss": masked_lm_mean_loss,
        }

        sentence_order_log_probs = tf.reshape(
            sentence_order_log_probs, [-1, sentence_order_log_probs.shape[-1]])
        sentence_order_predictions = tf.argmax(
            sentence_order_log_probs, axis=-1, output_type=tf.int32)
        sentence_order_labels = tf.reshape(sentence_order_labels, [-1])
        sentence_order_accuracy = tf.metrics.accuracy(
            labels=sentence_order_labels,
            predictions=sentence_order_predictions)
        sentence_order_mean_loss = tf.metrics.mean(
            values=sentence_order_example_loss)
        metrics.update({
            "sentence_order_accuracy": sentence_order_accuracy,
            "sentence_order_loss": sentence_order_mean_loss
        })
        return metrics

      metric_values = [
          masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
          masked_lm_weights, sentence_order_example_loss,
          sentence_order_log_probs, sentence_order_labels
      ]

      eval_metrics = (metric_fn, metric_values)

      output_spec = contrib_tpu.TPUEstimatorSpec(
          mode=mode,
          loss=total_loss,
          eval_metrics=eval_metrics,
          scaffold_fn=scaffold_fn)
    else:
      raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode))

    return output_spec
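
  # --- Usage sketch (not part of the original snippet) ---
  # run_pretraining.py drives this model_fn roughly as follows; the input_fn
  # builders and step counts are illustrative assumptions:
  #
  # estimator = contrib_tpu.TPUEstimator(
  #     use_tpu=use_tpu, model_fn=model_fn, config=run_config,
  #     train_batch_size=FLAGS.train_batch_size,
  #     eval_batch_size=FLAGS.eval_batch_size)
  # estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps)
  # result = estimator.evaluate(input_fn=eval_input_fn,
  #                             steps=FLAGS.max_eval_steps)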