def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        is_real_example, label_ids = None, None
        if FLAGS.export_dir is None:
            label_ids = features["label_ids"]
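            # TPU batches are fixed-size, so evaluation data may be padded
            # with fake examples; is_real_example zeroes their weight in the
            # eval metrics below.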
            if "is_real_example" in features:
                is_real_example = tf.cast(features["is_real_example"],
                                          dtype=tf.float32)
            else:
                is_real_example = tf.ones(tf.shape(label_ids),
                                          dtype=tf.float32)

        is_training = (mode == tf_estimator.ModeKeys.TRAIN)

        (total_loss, per_example_loss, logits,
         probabilities) = create_model(bert_config, is_training, input_ids,
                                       input_mask, segment_ids, label_ids,
                                       num_labels, use_one_hot_embeddings)

        tvars = tf.trainable_variables()
        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            (assignment_map, initialized_variable_names
             ) = modeling.get_assignment_map_from_checkpoint(
                 tvars, init_checkpoint)
            if use_tpu:
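                # On TPU, the checkpoint restore must run inside the Scaffold
                # returned by scaffold_fn so that it executes in the worker's
                # graph context rather than at graph-construction time here.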

                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        output_spec = None
        if mode == tf_estimator.ModeKeys.TRAIN:

            train_op = optimization.create_optimizer(total_loss, learning_rate,
                                                     num_train_steps,
                                                     num_warmup_steps, use_tpu)

            output_spec = tf_estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                scaffold_fn=scaffold_fn)
        elif mode == tf_estimator.ModeKeys.EVAL:

            def metric_fn(per_example_loss, label_ids, logits,
                          is_real_example):
                predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
                accuracy = tf.metrics.accuracy(labels=label_ids,
                                               predictions=predictions,
                                               weights=is_real_example)
                loss = tf.metrics.mean(values=per_example_loss,
                                       weights=is_real_example)
                return {
                    "eval_accuracy": accuracy,
                    "eval_loss": loss,
                }

            eval_metrics = (metric_fn, [
                per_example_loss, label_ids, logits, is_real_example
            ])
            output_spec = tf_estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metrics=eval_metrics,
                scaffold_fn=scaffold_fn)
        else:
            probabilities = tf.identity(probabilities, name="probabilities")
            output_spec = tf_estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                predictions={"probabilities": probabilities},
                scaffold_fn=scaffold_fn)
        return output_spec
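
# Usage sketch (an addition for illustration, not part of the original
# examples): model_fn closes over bert_config, num_labels, init_checkpoint,
# learning_rate, num_train_steps, num_warmup_steps, use_tpu and
# use_one_hot_embeddings, so in practice it is returned by a builder and
# handed to TPUEstimator. run_config and the FLAGS.* batch sizes are
# assumptions here.
def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate,
                     num_train_steps, num_warmup_steps, use_tpu,
                     use_one_hot_embeddings):

    def model_fn(features, labels, mode, params):  # body as defined above
        ...

    return model_fn


estimator = tf_estimator.tpu.TPUEstimator(
    use_tpu=use_tpu,
    model_fn=model_fn_builder(bert_config, num_labels, init_checkpoint,
                              learning_rate, num_train_steps,
                              num_warmup_steps, use_tpu,
                              use_one_hot_embeddings),
    config=run_config,
    train_batch_size=FLAGS.train_batch_size,
    eval_batch_size=FLAGS.eval_batch_size,
    predict_batch_size=FLAGS.predict_batch_size)
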
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
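        # Pretraining-specific inputs: masked-LM targets (positions, ids and
        # per-position weights) plus the next-sentence-prediction labels.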
        masked_lm_positions = features["masked_lm_positions"]
        masked_lm_ids = features["masked_lm_ids"]
        masked_lm_weights = features["masked_lm_weights"]
        next_sentence_labels = features["next_sentence_labels"]

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        if bert_teacher_config is None:
            model = modeling.BertModel(
                config=bert_config,
                is_training=is_training,
                input_ids=input_ids,
                input_mask=input_mask,
                token_type_ids=segment_ids,
                use_one_hot_embeddings=use_one_hot_embeddings,
                use_einsum=use_einsum)

            label_ids = tf.reshape(masked_lm_ids, [-1])
            true_labels = tf.one_hot(label_ids,
                                     depth=bert_config.vocab_size,
                                     dtype=model.get_sequence_output().dtype)
            one_hot_labels = true_labels
        else:
            model = modeling.BertModel(
                config=bert_config,
                is_training=False,
                input_ids=input_ids,
                input_mask=input_mask,
                token_type_ids=segment_ids,
                use_one_hot_embeddings=use_one_hot_embeddings,
                use_einsum=use_einsum)

            with tf.variable_scope("teacher"):
                teacher_model = modeling.BertModel(
                    config=bert_teacher_config,
                    is_training=False,
                    input_ids=input_ids,
                    input_mask=input_mask,
                    token_type_ids=segment_ids,
                    use_one_hot_embeddings=use_one_hot_embeddings,
                    use_einsum=use_einsum)

                label_ids = tf.reshape(masked_lm_ids, [-1])

                true_labels = tf.one_hot(
                    label_ids,
                    depth=bert_config.vocab_size,
                    dtype=model.get_sequence_output().dtype)

                teacher_logits = get_logits(
                    bert_teacher_config,
                    distill_temperature * teacher_model.get_sequence_output(),
                    teacher_model.get_embedding_table(), masked_lm_positions)

                teacher_labels = tf.nn.softmax(teacher_logits, axis=-1)

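                # Build the distillation targets: interpolate the teacher's
                # tempered soft predictions with the one-hot ground truth;
                # a ratio of 1.0 reduces to plain ground-truth training.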
                if distill_ground_truth_ratio == 1.0:
                    one_hot_labels = true_labels
                else:
                    one_hot_labels = (teacher_labels *
                                      (1 - distill_ground_truth_ratio) +
                                      true_labels * distill_ground_truth_ratio)

                teacher_attentions = teacher_model.get_all_attention_maps()
                student_attentions = model.get_all_attention_maps()

                teacher_hiddens = teacher_model.get_all_encoder_layers()
                student_hiddens = model.get_all_encoder_layers()

        (masked_lm_loss, _, masked_lm_example_loss, masked_lm_log_probs,
         _) = get_masked_lm_output(bert_config, model.get_sequence_output(),
                                   model.get_embedding_table(),
                                   masked_lm_positions,
                                   tf.stop_gradient(one_hot_labels),
                                   true_labels, masked_lm_weights)

        (next_sentence_loss, next_sentence_example_loss,
         next_sentence_log_probs) = get_next_sentence_output(
             bert_config, model.get_pooled_output(), next_sentence_labels)

        extra_loss1 = 0.0
        extra_loss2 = 0.0
        extra_loss3 = 0.0
        extra_loss4 = 0.0

        scalars_to_summarize = {}

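        # Gate for layer-wise warmup: training is split into one phase per
        # hidden layer, and the binary gate switches a layer's distillation
        # loss on according to its phase, so the per-layer losses are enabled
        # progressively over training.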
        def get_layerwise_gate(layer_id):
            steps_per_phase = num_train_steps // bert_config.num_hidden_layers
            layer_wise_gate = distill_util.layer_wise_learning_rate(
                layer_id=layer_id,
                steps_per_phase=steps_per_phase,
                binary=True)
            return layer_wise_gate

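        # Feature-map transfer: mean-squared error between layer-normalized
        # teacher and student hidden states. The layer norms are frozen
        # whiteners (trainable=False), and the teacher side sits behind
        # stop_gradient so only the student receives gradients.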
        if layer_wise_warmup and hidden_distill_factor != 0.0:
            layer_id = 0
            for teacher_hidden, student_hidden in (zip(teacher_hiddens[1:],
                                                       student_hiddens[1:])):
                with tf.variable_scope("hidden_distill_%d" % layer_id):
                    mse_loss = tf.losses.mean_squared_error(
                        tf.stop_gradient(
                            contrib_layers.layer_norm(inputs=teacher_hidden,
                                                      begin_norm_axis=-1,
                                                      begin_params_axis=-1,
                                                      trainable=False)),
                        contrib_layers.layer_norm(inputs=student_hidden,
                                                  begin_norm_axis=-1,
                                                  begin_params_axis=-1,
                                                  trainable=False))
                    layer_wise_gate = get_layerwise_gate(layer_id)
                    extra_loss1 += layer_wise_gate * mse_loss
                layer_id += 1
            extra_loss1 = extra_loss1 * hidden_distill_factor / layer_id

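        # Statistics transfer: match the per-position mean (squared error,
        # the "beta" term) and variance (absolute error, the "gamma" term)
        # of teacher and student hidden states.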
        if layer_wise_warmup and (beta_distill_factor != 0.0
                                  and gamma_distill_factor != 0.0):
            layer_id = 0
            for teacher_hidden, student_hidden in (zip(teacher_hiddens[1:],
                                                       student_hiddens[1:])):
                with tf.variable_scope("hidden_distill_%d" % layer_id):
                    teacher_mean = tf.reduce_mean(teacher_hidden,
                                                  axis=[-1],
                                                  keepdims=True)
                    student_mean = tf.reduce_mean(student_hidden,
                                                  axis=[-1],
                                                  keepdims=True)
                    teacher_variance = tf.reduce_mean(tf.squared_difference(
                        teacher_hidden, teacher_mean),
                                                      axis=[-1],
                                                      keepdims=True)
                    student_variance = tf.reduce_mean(tf.squared_difference(
                        student_hidden, student_mean),
                                                      axis=[-1],
                                                      keepdims=True)
                    beta_distill_loss = tf.reduce_mean(
                        tf.squared_difference(tf.stop_gradient(teacher_mean),
                                              student_mean))
                    gamma_distill_loss = tf.reduce_mean(
                        tf.abs(
                            tf.stop_gradient(teacher_variance) -
                            student_variance))
                    layer_wise_gate = get_layerwise_gate(layer_id)
                    extra_loss3 += layer_wise_gate * beta_distill_loss
                    extra_loss4 += layer_wise_gate * gamma_distill_loss
                layer_id += 1
            extra_loss3 = extra_loss3 * beta_distill_factor / layer_id
            extra_loss4 = extra_loss4 * gamma_distill_factor / layer_id

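        # Attention transfer: cross-entropy between the teacher's attention
        # distribution and the student's log-probabilities, i.e. the KL
        # divergence up to the teacher's (constant) entropy.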
        if layer_wise_warmup and attention_distill_factor != 0.0:
            layer_id = 0
            for teacher_attention, student_attention in (zip(
                    teacher_attentions, student_attentions)):
                with tf.variable_scope("attention_distill_%d" % layer_id):
                    teacher_attention_prob = tf.nn.softmax(teacher_attention,
                                                           axis=-1)
                    student_attention_log_prob = tf.nn.log_softmax(
                        student_attention, axis=-1)
                    kl_divergence = -(tf.stop_gradient(teacher_attention_prob)
                                      * student_attention_log_prob)
                    kl_divergence = tf.reduce_mean(
                        tf.reduce_sum(kl_divergence, axis=-1))
                    layer_wise_gate = get_layerwise_gate(layer_id)
                    extra_loss2 += layer_wise_gate * kl_divergence
                layer_id += 1
            extra_loss2 = extra_loss2 * attention_distill_factor / layer_id

        if layer_wise_warmup:
            total_loss = extra_loss1 + extra_loss2 + extra_loss3 + extra_loss4
        else:
            total_loss = masked_lm_loss + next_sentence_loss

        if summary_dir is not None:
            if layer_wise_warmup:
                scalars_to_summarize["feature_map_transfer_loss"] = extra_loss1
                scalars_to_summarize["attention_transfer_loss"] = extra_loss2
                scalars_to_summarize["mean_transfer_loss"] = extra_loss3
                scalars_to_summarize["variance_transfer_loss"] = extra_loss4
            else:
                scalars_to_summarize["masked_lm_loss"] = masked_lm_loss
                scalars_to_summarize["next_sentence_loss"] = next_sentence_loss

                masked_lm_predictions = tf.argmax(masked_lm_log_probs,
                                                  axis=-1,
                                                  output_type=tf.int32)
                masked_lm_accuracy = tf.cast(
                    tf.math.equal(tf.reshape(masked_lm_ids, [-1]),
                                  tf.reshape(masked_lm_predictions, [-1])),
                    tf.float32)
                numerator = tf.reduce_sum(
                    tf.reshape(masked_lm_weights, [-1]) * masked_lm_accuracy)
                denominator = tf.reduce_sum(masked_lm_weights) + 1e-5
                masked_lm_accuracy = numerator / denominator
                scalars_to_summarize["masked_lm_accuracy"] = masked_lm_accuracy

                next_sentence_predictions = tf.argmax(next_sentence_log_probs,
                                                      axis=-1,
                                                      output_type=tf.int32)
                next_sentence_accuracy = tf.reduce_mean(
                    tf.cast(
                        tf.math.equal(
                            tf.reshape(next_sentence_labels, [-1]),
                            tf.reshape(next_sentence_predictions, [-1])),
                        tf.float32))
                scalars_to_summarize[
                    "next_sentence_accuracy"] = next_sentence_accuracy

            scalars_to_summarize[
                "global_step"] = tf.train.get_or_create_global_step()
            scalars_to_summarize["loss"] = total_loss

        host_call = None
        if summary_dir is not None:
            if use_tpu:
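                # Summaries cannot be written from the TPU device itself, so
                # each scalar is reshaped to rank 1 and shipped back via
                # host_call; host_call_fn then writes them on the host.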
                for name in scalars_to_summarize:
                    scalars_to_summarize[name] = tf.reshape(
                        scalars_to_summarize[name], [1])

                def host_call_fn(*args):
                    """Host call function to compute training summaries."""
                    scalars = _list_to_dicts(args,
                                             scalars_to_summarize.keys())[0]
                    for name in scalars:
                        scalars[name] = scalars[name][0]

                    with contrib_summary.create_file_writer(
                            summary_dir, max_queue=1000).as_default():
                        with contrib_summary.always_record_summaries():
                            for name, value in scalars.items():
                                if name not in ["global_step"]:
                                    contrib_summary.scalar(
                                        name,
                                        value,
                                        step=scalars["global_step"])

                    return contrib_summary.all_summary_ops()

                host_call = (host_call_fn,
                             _dicts_to_list([scalars_to_summarize],
                                            scalars_to_summarize.keys()))
            else:
                for name in scalars_to_summarize:
                    tf.summary.scalar(name, scalars_to_summarize[name])

        tvars = tf.trainable_variables()

        initialized_variable_names = {}
        teacher_initialized_variable_names = {}
        scaffold_fn = None

        if init_checkpoint:
            if not init_from_teacher:
                # Initializes from the checkpoint for all variables.
                (assignment_map, initialized_variable_names
                 ) = modeling.get_assignment_map_from_checkpoint(
                     tvars, init_checkpoint)
                if use_tpu:

                    def tpu_scaffold():
                        tf.train.init_from_checkpoint(init_checkpoint,
                                                      assignment_map)
                        return tf.train.Scaffold()

                    scaffold_fn = tpu_scaffold
                else:
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
            elif bert_teacher_config is not None:
                # Initializes from the pre-trained checkpoint only for teacher model
                # and embeddings for distillation.
                (assignment_map, initialized_variable_names
                 ) = modeling.get_assignment_map_from_checkpoint(
                     tvars, init_checkpoint, init_embedding=True)
                (teacher_assignment_map, teacher_initialized_variable_names
                 ) = modeling.get_assignment_map_from_checkpoint(
                     tvars, init_checkpoint, init_from_teacher=True)
                if use_tpu:

                    def teacher_tpu_scaffold():
                        tf.train.init_from_checkpoint(init_checkpoint,
                                                      assignment_map)
                        tf.train.init_from_checkpoint(init_checkpoint,
                                                      teacher_assignment_map)
                        return tf.train.Scaffold()

                    scaffold_fn = teacher_tpu_scaffold
                else:
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  teacher_assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        total_size = 0
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            if var.name in teacher_initialized_variable_names:
                init_string = ", *INIT_FROM_TEACHER_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)
            if not var.name.startswith("teacher"):
                total_size += functools.reduce(lambda x, y: x * y,
                                               var.get_shape().as_list())
        tf.logging.info("  total variable parameters: %d", total_size)

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
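            # Layer-wise warmup uses one warmup phase per hidden layer
            # (total_warmup_phases); end_lr_rate=1.0 presumably keeps the
            # learning rate flat rather than decaying it across the phases.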
            if layer_wise_warmup:
                train_op = optimization.create_optimizer(
                    total_loss,
                    learning_rate,
                    num_train_steps,
                    num_warmup_steps,
                    use_tpu,
                    optimizer,
                    end_lr_rate=1.0,
                    use_layer_wise_warmup=True,
                    total_warmup_phases=bert_config.num_hidden_layers)
            else:
                train_op = optimization.create_optimizer(
                    total_loss, learning_rate, num_train_steps,
                    num_warmup_steps, use_tpu, optimizer)

            output_spec = tf.estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                scaffold_fn=scaffold_fn,
                host_call=host_call)
        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(masked_lm_example_loss, masked_lm_log_probs,
                          masked_lm_ids, masked_lm_weights,
                          next_sentence_example_loss, next_sentence_log_probs,
                          next_sentence_labels):
                """Computes the loss and accuracy of the model."""
                masked_lm_log_probs = tf.reshape(
                    masked_lm_log_probs, [-1, masked_lm_log_probs.shape[-1]])
                masked_lm_predictions = tf.argmax(masked_lm_log_probs,
                                                  axis=-1,
                                                  output_type=tf.int32)
                masked_lm_example_loss = tf.reshape(masked_lm_example_loss,
                                                    [-1])
                masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
                masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
                masked_lm_accuracy = tf.metrics.accuracy(
                    labels=masked_lm_ids,
                    predictions=masked_lm_predictions,
                    weights=masked_lm_weights)
                masked_lm_mean_loss = tf.metrics.mean(
                    values=masked_lm_example_loss, weights=masked_lm_weights)

                next_sentence_log_probs = tf.reshape(
                    next_sentence_log_probs,
                    [-1, next_sentence_log_probs.shape[-1]])
                next_sentence_predictions = tf.argmax(next_sentence_log_probs,
                                                      axis=-1,
                                                      output_type=tf.int32)
                next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
                next_sentence_accuracy = tf.metrics.accuracy(
                    labels=next_sentence_labels,
                    predictions=next_sentence_predictions)
                next_sentence_mean_loss = tf.metrics.mean(
                    values=next_sentence_example_loss)

                return {
                    "masked_lm_accuracy": masked_lm_accuracy,
                    "masked_lm_loss": masked_lm_mean_loss,
                    "next_sentence_accuracy": next_sentence_accuracy,
                    "next_sentence_loss": next_sentence_mean_loss,
                }

            eval_metrics = (metric_fn, [
                masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
                masked_lm_weights, next_sentence_example_loss,
                next_sentence_log_probs, next_sentence_labels
            ])
            output_spec = tf.estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metrics=eval_metrics,
                scaffold_fn=scaffold_fn)
        else:
            raise ValueError("Only TRAIN and EVAL modes are supported: %s" %
                             (mode))

        return output_spec
    def model_fn(features, labels, mode, params):
        logging.info("*** Features ***")
        for name in sorted(features.keys()):
            logging.info("  name = %s, shape = %s" % (name, features[name].shape))
        input_ids = features["input_ids"]
        mask = features["mask"]
        segment_ids = features["segment_ids"]
        label_ids = features["label_ids"]
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        if FLAGS.crf:
            (total_loss, logits, predicts) = create_model(
                bert_config, is_training, input_ids, mask, segment_ids,
                label_ids, num_labels, use_one_hot_embeddings)
        else:
            (total_loss, logits, predicts) = create_model(
                bert_config, is_training, input_ids, mask, segment_ids,
                label_ids, num_labels, use_one_hot_embeddings)
        tvars = tf.trainable_variables()
        scaffold_fn = None
        # Use an empty dict (not None) so the membership check below is safe
        # even when no init_checkpoint is given.
        initialized_variable_names = {}
        if init_checkpoint:
            (assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
            if use_tpu:
                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
                    return tf.train.Scaffold()
                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        

        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer(total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                scaffold_fn=scaffold_fn)

        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(label_ids, logits, num_labels, mask):
                predictions = tf.math.argmax(logits, axis=-1,
                                             output_type=tf.int32)
                cm = metrics.streaming_confusion_matrix(label_ids, predictions,
                                                        num_labels - 1,
                                                        weights=mask)
                return {
                    "confusion_matrix": cm
                }

            eval_metrics = (metric_fn, [label_ids, logits, num_labels, mask])
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metrics=eval_metrics,
                scaffold_fn=scaffold_fn)
        else:
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode, predictions=predicts, scaffold_fn=scaffold_fn
            )
        return output_spec
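
# Post-processing sketch (an addition, not from the original example): turning
# the confusion matrix accumulated by metric_fn above into micro-averaged
# precision/recall/F1. Assumes `cm` is a square array where cm[i, j] counts
# tokens with gold label i predicted as j, and that `ignore_id` (hypothetical
# here) is the padding/"O" label to exclude.
import numpy as np


def micro_prf(cm, ignore_id=0):
    """Micro-averaged precision, recall and F1 over the non-ignored labels."""
    cm = np.asarray(cm, dtype=np.float64)
    keep = [i for i in range(cm.shape[0]) if i != ignore_id]
    tp = cm[np.ix_(keep, keep)].trace()  # kept labels predicted correctly
    pred = cm[:, keep].sum()  # everything predicted as a kept label
    gold = cm[keep, :].sum()  # everything whose gold label is kept
    precision = tp / pred if pred else 0.0
    recall = tp / gold if gold else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0.0)
    return precision, recall, f1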