Example #1
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""
        config = util.initialize_from_env(use_tpu=FLAGS.use_tpu,
                                          config_params=FLAGS.config_params,
                                          config_file=FLAGS.config_filename)

        input_ids = features["flattened_input_ids"]
        input_mask = features["flattened_input_mask"]
        text_len = features["text_len"]
        speaker_ids = features["speaker_ids"]
        genre = features["genre"]
        gold_starts = features["span_starts"]
        gold_ends = features["span_ends"]
        cluster_ids = features["cluster_ids"]
        sentence_map = features["sentence_map"]
        # span_mention = features["span_mention"]

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        model = util.get_model(config, model_sign="corefqa")

        if FLAGS.use_tpu:
            tf.logging.info(
                "****************************** Training on TPU ******************************"
            )

            def tpu_scaffold():
                return tf.train.Scaffold()

            scaffold_fn = tpu_scaffold
        else:
            scaffold_fn = None

        if mode == tf.estimator.ModeKeys.TRAIN:
            tf.logging.info(
                "****************************** tf.estimator.ModeKeys.TRAIN ******************************"
            )
            tf.logging.info("********* Features *********")
            for name in sorted(features.keys()):
                tf.logging.info("  name = %s, shape = %s" %
                                (name, features[name].shape))

            total_loss, topk_span_starts, topk_span_ends, top_antecedent_scores = model.get_predictions_and_loss(
                input_ids, input_mask, text_len, speaker_ids, genre,
                is_training, gold_starts, gold_ends, cluster_ids,
                sentence_map)  # , span_mention)

            if config["tpu"]:
                optimizer = tf.train.AdamOptimizer(
                    learning_rate=config['learning_rate'],
                    beta1=0.9,
                    beta2=0.999,
                    epsilon=1e-08)
                optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
                train_op = optimizer.minimize(total_loss,
                                              tf.train.get_global_step())
                output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=tf.estimator.ModeKeys.TRAIN,
                    loss=total_loss,
                    train_op=train_op,
                    scaffold_fn=scaffold_fn)
            else:
                optimizer = RAdam(learning_rate=config['learning_rate'],
                                  epsilon=1e-8,
                                  beta1=0.9,
                                  beta2=0.999)
                train_op = optimizer.minimize(total_loss,
                                              tf.train.get_global_step())

                training_logging_hook = tf.train.LoggingTensorHook(
                    {"loss": total_loss}, every_n_iter=1)
                output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=tf.estimator.ModeKeys.TRAIN,
                    loss=total_loss,
                    train_op=train_op,
                    scaffold_fn=scaffold_fn,
                    training_hooks=[training_logging_hook])

        elif mode == tf.estimator.ModeKeys.EVAL:
            tf.logging.info(
                "****************************** tf.estimator.ModeKeys.EVAL ******************************"
            )
            tf.logging.info(
                "@@@@@ Only tf.estimator.ModeKeys.PREDICT is supported; "
                "evaluate your checkpoints after training via PREDICT. @@@@@")
            # No EstimatorSpec is built for EVAL, so fail fast instead of
            # falling through to `return output_spec` with an unbound variable.
            raise NotImplementedError(
                "EVAL mode is not supported; run PREDICT on saved checkpoints.")

        elif mode == tf.estimator.ModeKeys.PREDICT:
            tf.logging.info(
                "****************************** tf.estimator.ModeKeys.PREDICT ******************************"
            )
            total_loss, topk_span_starts, topk_span_ends, top_antecedent_scores = model.get_predictions_and_loss(
                input_ids, input_mask, text_len, speaker_ids, genre,
                is_training, gold_starts, gold_ends, cluster_ids,
                sentence_map)  #, span_mention)
            top_antecedent = tf.math.argmax(top_antecedent_scores, axis=-1)
            predictions = {
                "total_loss": total_loss,
                "topk_span_starts": topk_span_starts,
                "topk_span_ends": topk_span_ends,
                "top_antecedent_scores": top_antecedent_scores,
                "top_antecedent": top_antecedent,
                "cluster_ids": cluster_ids,
                "gold_starts": gold_starts,
                "gold_ends": gold_ends
            }

            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=tf.estimator.ModeKeys.PREDICT,
                predictions=predictions,
                scaffold_fn=scaffold_fn)
        else:
            raise ValueError("Please check the mode!")
        return output_spec
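
For context, a minimal sketch of how a model_fn like this is typically handed to tf.contrib.tpu.TPUEstimator in TF 1.x. The cluster-resolver argument, FLAGS.tpu_name, FLAGS.output_dir, the batch sizes, and iterations_per_loop below are illustrative assumptions, not values taken from the example above:

    # Wiring sketch (TF 1.x); flags other than FLAGS.use_tpu are hypothetical.
    import tensorflow as tf

    tpu_cluster_resolver = None
    if FLAGS.use_tpu:
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name)

    run_config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=FLAGS.output_dir,
        tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=200))

    estimator = tf.contrib.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=1,
        predict_batch_size=1)

    # estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
    # estimator.predict(input_fn=predict_input_fn)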
Example #2
    def mention_proposal_model_fn(features, labels, mode, params):
        """The `model_fn` for TPUEstimator."""
        input_ids = features["flattened_input_ids"]
        input_mask = features["flattened_input_mask"]
        text_len = features["text_len"]
        speaker_ids = features["speaker_ids"]
        gold_starts = features["span_starts"]
        gold_ends = features["span_ends"]
        cluster_ids = features["cluster_ids"]
        sentence_map = features["sentence_map"]

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        model = util.get_model(config, model_sign="mention_proposal")

        if config.use_tpu:

            def tpu_scaffold():
                return tf.train.Scaffold()

            scaffold_fn = tpu_scaffold
        else:
            scaffold_fn = None

        if mode == tf.estimator.ModeKeys.TRAIN:
            tf.logging.info(
                "****************************** tf.estimator.ModeKeys.TRAIN ******************************"
            )
            tf.logging.info("********* Features *********")
            for name in sorted(features.keys()):
                tf.logging.info("  name = %s, shape = %s" %
                                (name, features[name].shape))

            instance = (input_ids, input_mask, sentence_map, text_len,
                        speaker_ids, gold_starts, gold_ends, cluster_ids)
            total_loss, start_scores, end_scores, span_scores = model.get_mention_proposal_and_loss(
                instance, is_training)
            gold_start_sequence_labels, gold_end_sequence_labels, gold_span_sequence_labels = model.get_gold_mention_sequence_labels_from_pad_index(
                gold_starts, gold_ends, text_len)

            if config.use_tpu:
                optimizer = tf.train.AdamOptimizer(
                    learning_rate=config.learning_rate,
                    beta1=0.9,
                    beta2=0.999,
                    epsilon=1e-08)
                optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
                train_op = optimizer.minimize(total_loss,
                                              tf.train.get_global_step())
                output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=total_loss,
                    train_op=train_op,
                    scaffold_fn=scaffold_fn)
            else:
                optimizer = RAdam(learning_rate=config.learning_rate,
                                  epsilon=1e-8,
                                  beta1=0.9,
                                  beta2=0.999)
                train_op = optimizer.minimize(total_loss,
                                              tf.train.get_global_step())

                train_logging_hook = tf.train.LoggingTensorHook(
                    {"loss": total_loss}, every_n_iter=1)
                output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=total_loss,
                    train_op=train_op,
                    scaffold_fn=scaffold_fn,
                    training_hooks=[train_logging_hook])

        elif mode == tf.estimator.ModeKeys.EVAL:
            tf.logging.info(
                "****************************** tf.estimator.ModeKeys.EVAL ******************************"
            )

            instance = (input_ids, input_mask, sentence_map, text_len,
                        speaker_ids, gold_starts, gold_ends, cluster_ids)
            total_loss, start_scores, end_scores, span_scores = model.get_mention_proposal_and_loss(
                instance, is_training)
            gold_start_sequence_labels, gold_end_sequence_labels, gold_span_sequence_labels = model.get_gold_mention_sequence_labels_from_pad_index(
                gold_starts, gold_ends, text_len)

            def metric_fn(start_scores, end_scores, span_scores,
                          gold_span_label):
                start_scores = tf.reshape(start_scores,
                                          [-1, config.window_size])
                end_scores = tf.reshape(end_scores, [-1, config.window_size])
                start_scores = tf.tile(tf.expand_dims(start_scores, 2),
                                       [1, 1, config.window_size])
                end_scores = tf.tile(tf.expand_dims(end_scores, 2),
                                     [1, 1, config.window_size])
                sce_span_scores = (start_scores + end_scores + span_scores) / 3
                pred_span_label = tf.cast(
                    tf.reshape(
                        tf.math.greater_equal(sce_span_scores,
                                              config.mention_threshold), [-1]),
                    tf.bool)

                # Use the label tensor passed in through `eval_metrics` so it
                # is available inside metric_fn rather than read via a closure.
                gold_span_label = tf.cast(
                    tf.reshape(gold_span_label, [-1]), tf.bool)

                return {
                    "precision":
                    tf.compat.v1.metrics.precision(gold_span_label,
                                                   pred_span_label),
                    "recall":
                    tf.compat.v1.metrics.recall(gold_span_label,
                                                pred_span_label)
                }

            eval_metrics = (metric_fn, [
                start_scores, end_scores, span_scores,
                gold_span_sequence_labels
            ])
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=tf.estimator.ModeKeys.EVAL,
                loss=total_loss,
                eval_metrics=eval_metrics,
                scaffold_fn=scaffold_fn)

        elif mode == tf.estimator.ModeKeys.PREDICT:
            tf.logging.info(
                "****************************** tf.estimator.ModeKeys.PREDICT ******************************"
            )

            instance = (input_ids, input_mask, sentence_map, text_len,
                        speaker_ids, gold_starts, gold_ends, cluster_ids)
            total_loss, start_scores, end_scores, span_scores = model.get_mention_proposal_and_loss(
                instance, is_training)
            gold_start_sequence_labels, gold_end_sequence_labels, gold_span_sequence_labels = model.get_gold_mention_sequence_labels_from_pad_index(
                gold_starts, gold_ends, text_len)
            predictions = {
                "total_loss": total_loss,
                "start_scores": start_scores,
                "start_gold": gold_starts,
                "end_gold": gold_ends,
                "end_scores": end_scores,
                "span_scores": span_scores
            }
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=tf.estimator.ModeKeys.PREDICT,
                predictions=predictions,
                scaffold_fn=scaffold_fn)
        else:
            raise ValueError("Please check the mode!")

        return output_spec
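
The PREDICT branch returns raw start, end, and span scores. A NumPy sketch of turning them into mention proposals outside the graph, mirroring the averaging-and-thresholding idea in metric_fn; the estimator (built as in the sketch after Example #1), predict_input_fn, the [window, window] layout of span_scores, and the 0.5 threshold are assumptions:

    # Post-processing sketch for PREDICT output; shapes and the threshold are
    # assumptions (config.mention_threshold plays this role in metric_fn).
    import numpy as np

    for prediction in estimator.predict(input_fn=predict_input_fn):
        start = prediction["start_scores"].reshape(-1)             # [window]
        end = prediction["end_scores"].reshape(-1)                 # [window]
        span = prediction["span_scores"].reshape(len(start), -1)   # [window, window]
        # Average the three scores for every (start_idx, end_idx) pair.
        combined = (start[:, None] + end[None, :] + span) / 3.0
        # Keep pairs whose averaged score clears the threshold.
        proposals = np.argwhere(combined >= 0.5)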