Beispiel #1
0
    def model_fn(features, labels, mode, params):
        """The `model_fn` for TPUEstimator.

        Builds a pre-training graph on top of `table_bert.create_model` with a
        masked-LM head and a next-sentence head, and returns the spec for `mode`.

        Configuration (`bert_config`, `init_checkpoint`, `use_tpu`,
        `learning_rate`, `num_train_steps`, `num_warmup_steps`, the
        `restrict_attention_*` / `attention_bias_*` options, etc.) is read from
        the enclosing scope, not from `params`.

        Args:
          features: Dict of input tensors. Must contain "masked_lm_positions",
            "masked_lm_ids", "masked_lm_weights" and "next_sentence_labels", plus
            whatever `table_bert.create_model` reads.
          labels: Unused.
          mode: A `tf_estimator.ModeKeys` value.
          params: Unused.

        Returns:
          A `tf_estimator.tpu.TPUEstimatorSpec` for `mode`.

        Raises:
          ValueError: If `mode` is not TRAIN, EVAL or PREDICT.
        """
        del labels, params  # Unused.

        tf.logging.info("*** Features ***")
        for name in sorted(features):
            tf.logging.info("  name = %s, shape = %s", name,
                            features[name].shape)

        masked_lm_positions = features["masked_lm_positions"]
        masked_lm_ids = features["masked_lm_ids"]
        masked_lm_weights = features["masked_lm_weights"]
        next_sentence_labels = features["next_sentence_labels"]

        model = table_bert.create_model(
            features=features,
            mode=mode,
            bert_config=bert_config,
            restrict_attention_mode=restrict_attention_mode,
            restrict_attention_bucket_size=restrict_attention_bucket_size,
            restrict_attention_header_size=restrict_attention_header_size,
            restrict_attention_row_heads_ratio=
            restrict_attention_row_heads_ratio,
            disabled_features=disabled_features,
            disable_position_embeddings=disable_position_embeddings,
            reset_position_index_per_cell=reset_position_index_per_cell,
            proj_value_length=proj_value_length,
            attention_bias_disabled=attention_bias_disabled,
            attention_bias_use_relative_scalar_only=
            attention_bias_use_relative_scalar_only,
        )

        # Masked-LM head; it is given the embedding table as well
        # (presumably for tied output weights -- see _get_masked_lm_output).
        (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs,
         masked_lm_predictions) = _get_masked_lm_output(
             bert_config, model.get_sequence_output(),
             model.get_embedding_table(), masked_lm_positions, masked_lm_ids,
             masked_lm_weights)

        # Next-sentence head on the pooled output.
        (next_sentence_loss, next_sentence_example_loss,
         next_sentence_log_probs) = _get_next_sentence_output(
             bert_config, model.get_pooled_output(), next_sentence_labels)

        # The training objective is the unweighted sum of both head losses.
        total_loss = masked_lm_loss + next_sentence_loss

        # Collected after both heads are built so their variables are included.
        tvars = tf.trainable_variables()

        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            # Position embeddings are deliberately left out of the assignment
            # map, so they are never restored from the checkpoint.
            init_tvars = [
                tvar for tvar in tvars
                if "position_embeddings" not in tvar.name
            ]
            (assignment_map, initialized_variable_names
             ) = modeling.get_assignment_map_from_checkpoint(
                 init_tvars, init_checkpoint)
            if use_tpu:

                # On TPU the restore is deferred into `scaffold_fn`, which
                # TPUEstimator calls when it builds the training scaffold.
                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        output_spec = None
        if mode == tf_estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer(total_loss, learning_rate,
                                                     num_train_steps,
                                                     num_warmup_steps, use_tpu)

            output_spec = tf_estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                scaffold_fn=scaffold_fn)
        elif mode == tf_estimator.ModeKeys.EVAL:

            def metric_fn(masked_lm_example_loss, masked_lm_log_probs,
                          masked_lm_ids, masked_lm_weights,
                          next_sentence_example_loss, next_sentence_log_probs,
                          next_sentence_labels):
                """Computes the loss and accuracy of the model.

                Called by TPUEstimator with the tensors listed in
                `eval_metrics` below, in the same order. Returns a dict of
                `tf.metrics` (value, update_op) pairs.
                """
                # Flatten to one row per masked position.
                masked_lm_log_probs = tf.reshape(
                    masked_lm_log_probs, [-1, masked_lm_log_probs.shape[-1]])
                masked_lm_predictions = tf.argmax(masked_lm_log_probs,
                                                  axis=-1,
                                                  output_type=tf.int32)
                masked_lm_example_loss = tf.reshape(masked_lm_example_loss,
                                                    [-1])
                masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
                masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
                # Weighted so zero-weight slots (presumably padding) are ignored.
                masked_lm_accuracy = tf.metrics.accuracy(
                    labels=masked_lm_ids,
                    predictions=masked_lm_predictions,
                    weights=masked_lm_weights)
                masked_lm_mean_loss = tf.metrics.mean(
                    values=masked_lm_example_loss, weights=masked_lm_weights)

                next_sentence_log_probs = tf.reshape(
                    next_sentence_log_probs,
                    [-1, next_sentence_log_probs.shape[-1]])
                next_sentence_predictions = tf.argmax(next_sentence_log_probs,
                                                      axis=-1,
                                                      output_type=tf.int32)
                next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
                next_sentence_accuracy = tf.metrics.accuracy(
                    labels=next_sentence_labels,
                    predictions=next_sentence_predictions)
                next_sentence_mean_loss = tf.metrics.mean(
                    values=next_sentence_example_loss)

                return {
                    "masked_lm_accuracy": masked_lm_accuracy,
                    "masked_lm_loss": masked_lm_mean_loss,
                    "next_sentence_accuracy": next_sentence_accuracy,
                    "next_sentence_loss": next_sentence_mean_loss,
                }

            # Order must match metric_fn's positional parameters.
            eval_metrics = (metric_fn, [
                masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
                masked_lm_weights, next_sentence_example_loss,
                next_sentence_log_probs, next_sentence_labels
            ])
            output_spec = tf_estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metrics=eval_metrics,
                scaffold_fn=scaffold_fn)
        elif mode == tf_estimator.ModeKeys.PREDICT:
            # Only the masked-LM predictions are exported.
            predictions = {
                "masked_lm_predictions": masked_lm_predictions,
            }
            output_spec = tf_estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                predictions=predictions,
                scaffold_fn=scaffold_fn)
        else:
            raise ValueError("Unsupported mode: %s" % mode)
        return output_spec
    def model_fn(features, labels, mode, params):
        """The `model_fn` for TPUEstimator.

        Builds a dual-encoder retrieval graph.
        - Tables are encoded with one `table_bert.create_model` call.
        - Questions are encoded with a second call against an all-zero
          "empty table". Whether the two encoders share weights depends on
          the variable scoping inside `table_bert.create_model`, which is
          not visible here.

        Logits are dot products of the (optionally down-projected) pooled
        representations. The loss is a softmax cross-entropy whose gold
        entries lie on the diagonal.

        Configuration is read from `config` in the enclosing scope.

        Args:
          features: Dict of input tensors. Must contain:
            - the table features listed in `features_table_names` below;
            - "question_input_ids", "question_input_mask", "table_id_hash"
              and "question_hash".
            "table_id" / "question_id" are optional and are exported in
            PREDICT mode when present.
          labels: Unused.
          mode: A `tf_estimator.ModeKeys` value.
          params: Unused.

        Returns:
          A `tf_estimator.tpu.TPUEstimatorSpec` for `mode`. Any mode other
          than TRAIN or EVAL is treated as PREDICT.

        Raises:
          ValueError: If `config.use_out_of_core_negatives`,
            `config.use_mined_negatives` and `config.mask_repeated_questions`
            are all enabled.
        """

        del labels, params  # Unused.

        # TODO(thomasmueller) Add support for this.
        if (config.use_out_of_core_negatives and config.use_mined_negatives
                and config.mask_repeated_questions):
            raise ValueError("Unsupported combination of options.")

        tf.logging.info("*** Features ***")
        for name in sorted(features):
            tf.logging.info("  name = %s, shape = %s", name,
                            features[name].shape)

        if config.ignore_table_content:
            # Hide table cells (row_ids > 0) from the table encoder.
            # NOTE(review): every position with row_ids == 0 gets mask 1,
            # presumably including padding tokens whose original mask was 0.
            # Confirm this is intended; using features["input_mask"] as the
            # "else" branch would preserve the padding mask.
            features["input_mask"] = tf.where(
                features["row_ids"] > 0, tf.zeros_like(features["input_mask"]),
                tf.ones_like(features["input_mask"]))

        features_table_names = [
            "input_ids", "input_mask", "segment_ids", "column_ids", "row_ids",
            "prev_label_ids", "column_ranks", "inv_column_ranks",
            "numeric_relations"
        ]
        features_table = {}
        for name in features_table_names:
            if config.use_mined_negatives:
                # split each feature in half, and concat vertically.
                # Each half goes on the batch axis: all first-half (positive)
                # tables come first, then all second-half (negative) ones.
                feature_positive_table, feature_negative_table = tf.split(
                    features[name], num_or_size_splits=2, axis=1)
                features_table[name] = tf.concat(
                    [feature_positive_table, feature_negative_table], axis=0)
            else:
                features_table[name] = features[name]

        tf.logging.info("*** Features table ***")
        for name in sorted(features_table):
            # Log the tensors actually fed to the table encoder. With mined
            # negatives their shapes differ from the raw `features`.
            tf.logging.info("  name = %s, shape = %s", name,
                            features_table[name].shape)

        table_model = table_bert.create_model(
            features=features_table,
            disabled_features=config.disabled_features,
            mode=mode,
            bert_config=config.bert_config,
        )

        # Arrange features for query, such that it is assigned with an empty table.
        empty_table_features = tf.zeros_like(
            features["question_input_mask"])[:, :config.max_query_length]
        features_query = {
            "input_ids":
            features["question_input_ids"][:, :config.max_query_length],
            "input_mask":
            features["question_input_mask"][:, :config.max_query_length],
            "segment_ids":
            empty_table_features,
            "column_ids":
            empty_table_features,
            "row_ids":
            empty_table_features,
            "prev_label_ids":
            empty_table_features,
            "column_ranks":
            empty_table_features,
            "inv_column_ranks":
            empty_table_features,
            "numeric_relations":
            empty_table_features,
        }
        query_model = table_bert.create_model(
            features=features_query,
            disabled_features=config.disabled_features,
            mode=mode,
            bert_config=config.bert_config,
        )

        table_hidden_representation = table_model.get_pooled_output()
        query_hidden_representation = query_model.get_pooled_output()
        if config.down_projection_dim > 0:
            table_projection = _get_projection_matrix(
                "table_projection",
                num_columns=table_hidden_representation.shape[1])
            query_projection = _get_projection_matrix(
                "text_projection",
                num_columns=query_hidden_representation.shape[1])

            # <float32>[batch_size * num_tables, down_projection_dim]
            table_rep = _get_type_representation(table_hidden_representation,
                                                 table_projection)
            # <float32>[batch_size, down_projection_dim]
            query_rep = _get_type_representation(query_hidden_representation,
                                                 query_projection)
        else:
            table_rep = table_hidden_representation
            query_rep = query_hidden_representation

        batch_size = tf.shape(query_rep)[0]
        # Identity matrix, as gold logits are on the diagonal.
        # <float32>[batch_size, batch_size]
        labels_single_table = tf.eye(batch_size)
        if config.use_mined_negatives:
            # Mined negative tables are never gold: their label columns are 0.
            # <float32>[batch_size, batch_size * num_tables]
            labels = tf.concat(
                [labels_single_table,
                 tf.zeros_like(labels_single_table)],
                axis=1)
        else:
            labels = labels_single_table

        # Flattened in the same positive-then-negative order as table_rep.
        # <int64>[batch_size * num_tables]
        table_id_hash = tf.reshape(tf.transpose(features["table_id_hash"]),
                                   shape=[-1])
        # Hash of each example's first (gold) table.
        # <int64>[batch_size, 1]
        table_id_hash_transposed = features["table_id_hash"][..., :1]

        # <int64>[batch_size, 1]
        question_hash = features["question_hash"]

        # <int64>[1, batch_size]
        question_hash_transposed = tf.transpose(question_hash)
        if config.use_tpu and config.use_out_of_core_negatives:
            # Presumably gathers the tables/hashes/labels from all TPU cores,
            # so logits span the global batch (see the shape note below).
            data = get_updates_for_use_tpu_with_out_of_core_negatives(
                ModelBuilderData(table_rep, table_id_hash, question_hash,
                                 labels))
            table_rep = data.table_rep
            table_id_hash = data.table_id_hash
            question_hash = data.question_hash
            labels = data.labels
        # <float32>[batch_size, batch_size|global_batch_size * num_tables]
        logits = tf.matmul(query_rep, table_rep, transpose_b=True)
        if config.mask_repeated_tables:
            # Matrix of non-trivial repeated tables
            # <bool>[batch_size, batch_size|global_batch_size * num_tables]
            # A non-gold column is masked out when it holds the same table as
            # the example's gold table, so duplicates are not false negatives.
            repeated_tables = tf.math.equal(
                tf.expand_dims(table_id_hash, axis=0),
                table_id_hash_transposed) & tf.math.equal(labels, 0)
            logits = tf.where(repeated_tables,
                              tf.zeros_like(logits) - _INF, logits)

        if config.mask_repeated_questions:
            logits = _apply_reapated_text_masking(
                config,
                question_hash,
                question_hash_transposed,
                labels,
                logits,
            )

        total_loss = tf.losses.softmax_cross_entropy(onehot_labels=labels,
                                                     logits=logits)

        tvars = tf.trainable_variables()
        initialized_variable_names = {}
        scaffold_fn = None
        init_checkpoint = config.init_checkpoint
        if init_checkpoint:
            # Possibly several assignment maps; for example,
            # init_from_single_encoder may map one checkpoint encoder onto
            # both towers. Not verified here -- see the helper.
            (assignment_maps,
             initialized_variable_names) = _get_assignment_map_from_checkpoint(
                 tvars, init_checkpoint, config.init_from_single_encoder)

            if config.use_tpu:

                def tpu_scaffold():
                    for assignment_map in assignment_maps:
                        tf.train.init_from_checkpoint(init_checkpoint,
                                                      assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                for assignment_map in assignment_maps:
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        output_spec = None
        if mode == tf_estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer(
                total_loss,
                config.learning_rate,
                config.num_train_steps,
                config.num_warmup_steps,
                config.use_tpu,
                grad_clipping=config.grad_clipping)

            output_spec = tf_estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                scaffold_fn=scaffold_fn)
        elif mode == tf_estimator.ModeKeys.EVAL:
            eval_metrics = (_calculate_eval_metrics_fn,
                            [total_loss, logits, labels])
            output_spec = tf_estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metrics=eval_metrics,
                scaffold_fn=scaffold_fn)
        else:
            # Export only the gold (first-half) table representations when
            # mined negatives were stacked onto the batch axis.
            if config.use_mined_negatives:
                table_rep_gold, _ = tf.split(table_rep,
                                             num_or_size_splits=2,
                                             axis=0)
            else:
                table_rep_gold = table_rep

            predictions = {
                "query_rep": query_rep,
                "table_rep": table_rep_gold,
            }
            # Only available when predicting on GPU.
            if "table_id" in features:
                predictions["table_id"] = features["table_id"]
            if "question_id" in features:
                predictions["query_id"] = features["question_id"]
            output_spec = tf_estimator.tpu.TPUEstimatorSpec(
                mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
        return output_spec
Beispiel #3
0
    def model_fn(features, labels, mode, params):
        """The `model_fn` for TPUEstimator.

        Builds the `table_bert` encoder and the heads computed by
        `_get_classification_outputs`, then returns the spec for `mode`.
        - An aggregation head is added when `config.num_aggregation_labels > 0`.
        - A classification head is added when
          `config.num_classification_labels > 0`.

        Configuration is read from `config` (and `custom_prediction_keys`) in
        the enclosing scope.

        Args:
          features: Dict of input tensors.
            - Always required: "label_ids", "input_mask", "row_ids" and
              "column_ids", plus whatever `table_bert.create_model` and
              `utils.extract_answer_from_features` read.
            - Aggregation needs "aggregation_function_id"; classification
              needs "classification_class_index".
            - PREDICT mode also reads "input_ids", "segment_ids",
              "question_id_ints" and, optionally, "question_id".
          labels: Unused.
          mode: A `tf.estimator.ModeKeys` value.
          params: Dict-like. Reads the optional keys
            "fail_if_missing_variables_in_checkpoint" (default False) and
            "gradient_accumulation_steps" (default 1).

        Returns:
          A `tf.estimator.tpu.TPUEstimatorSpec` for `mode`. Any mode other
          than TRAIN or EVAL is treated as PREDICT.

        Raises:
          ValueError: If "fail_if_missing_variables_in_checkpoint" is set, a
            checkpoint is configured, and some trainable variable (other than
            layer-norm parameters) is not restored from it.
        """

        del labels  # Unused.

        tf.logging.info("*** Features ***")
        for name in sorted(features):
            tf.logging.info("  name = %s, shape = %s", name,
                            features[name].shape)

        label_ids = features["label_ids"]
        input_mask = features["input_mask"]
        row_ids = features["row_ids"]
        column_ids = features["column_ids"]
        # Table cells only, without question tokens and table headers.
        table_mask = tf.where(row_ids > 0, tf.ones_like(row_ids),
                              tf.zeros_like(row_ids))
        do_model_aggregation = config.num_aggregation_labels > 0
        aggregation_function_id = (tf.squeeze(
            features["aggregation_function_id"], axis=[1])
                                   if do_model_aggregation else None)

        do_model_classification = config.num_classification_labels > 0
        classification_class_index = (tf.squeeze(
            features["classification_class_index"], axis=[1])
                                      if do_model_classification else None)

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        model = table_bert.create_model(
            features=features,
            mode=mode,
            bert_config=config.bert_config,
            disabled_features=config.disabled_features,
            disable_position_embeddings=config.disable_position_embeddings,
            reset_position_index_per_cell=config.reset_position_index_per_cell,
            proj_value_length=config.proj_value_length,
        )

        answer, numeric_values, numeric_values_scale = (
            utils.extract_answer_from_features(
                features=features,
                use_answer_as_supervision=config.use_answer_as_supervision))
        outputs = _get_classification_outputs(
            config=config,
            output_layer=model.get_sequence_output(),
            output_layer_aggregation=model.get_pooled_output(),
            label_ids=label_ids,
            input_mask=input_mask,
            table_mask=table_mask,
            aggregation_function_id=aggregation_function_id,
            answer=answer,
            numeric_values=numeric_values,
            numeric_values_scale=numeric_values_scale,
            is_training=is_training,
            row_ids=row_ids,
            column_ids=column_ids,
            classification_class_index=classification_class_index)
        total_loss = outputs.total_loss

        tvars = tf.trainable_variables()
        if config.reset_output_cls:
            # The classification output layer is excluded from checkpoint
            # restoration and from the variable log below. It stays trainable.
            tvars = [
                tvar for tvar in tvars
                if ("output_weights_cls" not in tvar.name
                    and "output_bias_cls" not in tvar.name)
            ]
        initialized_variable_names = set()
        scaffold_fn = None
        # (checkpoint_path, assignment_map) pairs to restore.
        init_from_checkpoints = []

        def add_init_checkpoint(init_checkpoint, scope=None):
            """Queues `init_checkpoint` for restoration; no-op if it is falsy."""
            if not init_checkpoint:
                return
            (assignment_map, initialized_variables
             ) = modeling.get_assignment_map_from_checkpoint(tvars,
                                                             init_checkpoint,
                                                             scope=scope)
            initialized_variable_names.update(initialized_variables.keys())
            init_from_checkpoints.append((init_checkpoint, assignment_map))

        add_init_checkpoint(config.init_checkpoint)

        if init_from_checkpoints:
            if config.use_tpu:

                def tpu_scaffold():
                    for init_checkpoint, assignment_map in init_from_checkpoints:
                        tf.train.init_from_checkpoint(init_checkpoint,
                                                      assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                for init_checkpoint, assignment_map in init_from_checkpoints:
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)

        fail_if_missing = init_from_checkpoints and params.get(
            "fail_if_missing_variables_in_checkpoint", False)
        # `tf.logging.fatal` only logs at FATAL level through Python logging;
        # it does not stop the program. Collect the offenders and raise after
        # logging them all.
        missing_variable_names = []
        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            elif fail_if_missing:
                # Layer-norm parameters are tolerated as missing.
                if "layer_norm" not in var.name and "LayerNorm" not in var.name:
                    tf.logging.fatal("Variable not found in checkpoint: %s",
                                     var.name)
                    missing_variable_names.append(var.name)
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)
        if missing_variable_names:
            raise ValueError("Variables not found in checkpoint: %s" %
                             ", ".join(missing_variable_names))

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer(
                total_loss,
                config.learning_rate,
                config.num_train_steps,
                config.num_warmup_steps,
                config.use_tpu,
                gradient_accumulation_steps=params.get(
                    "gradient_accumulation_steps", 1),
                grad_clipping=config.grad_clipping)

            output_spec = tf.estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                scaffold_fn=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.EVAL:
            # Positional order must match _calculate_eval_metrics_fn.
            eval_metrics = (_calculate_eval_metrics_fn, [
                total_loss,
                label_ids,
                outputs.logits,
                input_mask,
                aggregation_function_id,
                outputs.logits_aggregation,
                classification_class_index,
                outputs.logits_cls,
            ])
            output_spec = tf.estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metrics=eval_metrics,
                scaffold_fn=scaffold_fn)
        else:
            predictions = {
                "probabilities": outputs.probs,
                "input_ids": features["input_ids"],
                "column_ids": features["column_ids"],
                "row_ids": features["row_ids"],
                "segment_ids": features["segment_ids"],
                "question_id_ints": features["question_id_ints"],
            }
            if "question_id" in features:
                # Only available when predicting on GPU.
                predictions["question_id"] = features["question_id"]
                del predictions["question_id_ints"]
            if do_model_aggregation:
                predictions.update({
                    "gold_aggr":
                    features["aggregation_function_id"],
                    "pred_aggr":
                    tf.argmax(
                        outputs.logits_aggregation,
                        axis=-1,
                        output_type=tf.int32,
                    )
                })
            if do_model_classification:
                predictions.update({
                    "gold_cls":
                    features["classification_class_index"],
                    "pred_cls":
                    tf.argmax(
                        outputs.logits_cls,
                        axis=-1,
                        output_type=tf.int32,
                    )
                })
                if config.num_classification_labels == 2:
                    # Binary case: export a single score, logit(1) - logit(0).
                    predictions.update({
                        "logits_cls":
                        outputs.logits_cls[:, 1] - outputs.logits_cls[:, 0]
                    })
                else:
                    predictions.update({"logits_cls": outputs.logits_cls})
            if outputs.span_indexes is not None and outputs.span_logits is not None:
                predictions.update({"span_indexes": outputs.span_indexes})
                predictions.update({"span_logits": outputs.span_logits})

            if custom_prediction_keys:
                # Restrict the exported predictions to the requested keys.
                predictions = {
                    key: predictions[key]
                    for key in custom_prediction_keys
                }
            output_spec = tf.estimator.tpu.TPUEstimatorSpec(
                mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
        return output_spec
Beispiel #4
0
    def model_fn(features, labels, mode, params):
        """The `model_fn` for TPUEstimator.

        Builds the `table_bert` encoder and the heads computed by
        `_get_classification_outputs`, then returns the spec for `mode`.
        - An aggregation head is added when `config.num_aggregation_labels > 0`.
        - A classification head is added when
          `config.num_classification_labels > 0`.

        Configuration is read from `config` in the enclosing scope.

        Args:
          features: Dict of input tensors.
            - Always required: "label_ids", "input_mask", "row_ids" and
              "column_ids".
            - Aggregation needs "aggregation_function_id"; classification
              needs "classification_class_index".
            - `config.use_answer_as_supervision` needs "answer",
              "numeric_values" and "numeric_values_scale".
            - PREDICT mode also reads "segment_ids", "question_id_ints" and,
              optionally, "question_id".
          labels: Unused.
          mode: A `tf.estimator.ModeKeys` value.
          params: Dict-like. Reads the optional key
            "gradient_accumulation_steps" (default 1).

        Returns:
          A `tf.estimator.tpu.TPUEstimatorSpec` for `mode`. Any mode other
          than TRAIN or EVAL is treated as PREDICT.
        """

        del labels  # Unused.

        tf.logging.info("*** Features ***")
        for name in sorted(features):
            tf.logging.info("  name = %s, shape = %s", name,
                            features[name].shape)

        label_ids = features["label_ids"]
        input_mask = features["input_mask"]
        row_ids = features["row_ids"]
        column_ids = features["column_ids"]
        # Table cells only, without question tokens and table headers.
        table_mask = tf.where(row_ids > 0, tf.ones_like(row_ids),
                              tf.zeros_like(row_ids))
        do_model_aggregation = config.num_aggregation_labels > 0
        aggregation_function_id = (tf.squeeze(
            features["aggregation_function_id"], axis=[1])
                                   if do_model_aggregation else None)

        do_model_classification = config.num_classification_labels > 0
        classification_class_index = (tf.squeeze(
            features["classification_class_index"], axis=[1])
                                      if do_model_classification else None)

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        model = table_bert.create_model(
            features=features,
            mode=mode,
            bert_config=config.bert_config,
            disabled_features=config.disabled_features,
            disable_position_embeddings=config.disable_position_embeddings)

        # Answer-based (weak) supervision inputs; all None when disabled.
        if config.use_answer_as_supervision:
            answer = tf.squeeze(features["answer"], axis=[1])
            numeric_values = features["numeric_values"]
            numeric_values_scale = features["numeric_values_scale"]
        else:
            answer = None
            numeric_values = None
            numeric_values_scale = None

        (total_loss, logits, logits_aggregation, probabilities,
         logits_cls) = _get_classification_outputs(
             config=config,
             output_layer=model.get_sequence_output(),
             output_layer_aggregation=model.get_pooled_output(),
             label_ids=label_ids,
             input_mask=input_mask,
             table_mask=table_mask,
             aggregation_function_id=aggregation_function_id,
             answer=answer,
             numeric_values=numeric_values,
             numeric_values_scale=numeric_values_scale,
             is_training=is_training,
             row_ids=row_ids,
             column_ids=column_ids,
             classification_class_index=classification_class_index)

        # Collected after all heads are built so their variables are included.
        tvars = tf.trainable_variables()
        initialized_variable_names = {}
        scaffold_fn = None
        init_checkpoint = config.init_checkpoint
        if init_checkpoint:
            (assignment_map, initialized_variable_names
             ) = modeling.get_assignment_map_from_checkpoint(
                 tvars, init_checkpoint)
            if config.use_tpu:

                # On TPU the restore is deferred into `scaffold_fn`, which
                # TPUEstimator calls when it builds the scaffold.
                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer(
                total_loss,
                config.learning_rate,
                config.num_train_steps,
                config.num_warmup_steps,
                config.use_tpu,
                gradient_accumulation_steps=params.get(
                    "gradient_accumulation_steps", 1),
                grad_clipping=config.grad_clipping)

            output_spec = tf.estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                scaffold_fn=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.EVAL:
            # Positional order must match _calculate_eval_metrics_fn.
            eval_metrics = (_calculate_eval_metrics_fn, [
                total_loss, label_ids, logits, input_mask,
                aggregation_function_id, logits_aggregation,
                classification_class_index, logits_cls
            ])
            output_spec = tf.estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metrics=eval_metrics,
                scaffold_fn=scaffold_fn)
        else:
            # The per-token encoder output is exported as "embeddings".
            predictions = {
                "embeddings": model.get_sequence_output(),
                "probabilities": probabilities,
                "column_ids": features["column_ids"],
                "row_ids": features["row_ids"],
                "segment_ids": features["segment_ids"],
                "question_id_ints": features["question_id_ints"],
            }
            # TODO Remove once the data has been updated.
            if "question_id" in features:
                # Only available when predicting on GPU.
                predictions["question_id"] = features["question_id"]
            if do_model_aggregation:
                predictions.update({
                    "gold_aggr":
                    features["aggregation_function_id"],
                    "pred_aggr":
                    tf.argmax(logits_aggregation,
                              axis=-1,
                              output_type=tf.int32)
                })
            if do_model_classification:
                predictions.update({
                    "gold_cls":
                    features["classification_class_index"],
                    "pred_cls":
                    tf.argmax(logits_cls, axis=-1, output_type=tf.int32)
                })
                if config.num_classification_labels == 2:
                    # Binary case: export a single score, logit(1) - logit(0).
                    predictions.update(
                        {"logits_cls": logits_cls[:, 1] - logits_cls[:, 0]})
                else:
                    predictions.update({"logits_cls": logits_cls})
            output_spec = tf.estimator.tpu.TPUEstimatorSpec(
                mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
        return output_spec