Example #1
    def __init__(self,
                 num_classes,
                 features,
                 rep,
                 is_training,
                 loss_weighting=None):
        self.num_classes = num_classes
        self.label_ids = features["label_ids"]
        self.label_ids = tf.reshape(self.label_ids, [-1])
        if "is_real_example" in features:
            is_real_example = tf.cast(features["is_real_example"],
                                      dtype=tf.float32)
        else:
            is_real_example = tf.ones(tf.shape(self.label_ids),
                                      dtype=tf.float32)
        self.is_real_example = is_real_example

        if is_training:
            rep = dropout(rep, 0.1)
        logits = tf.keras.layers.Dense(self.num_classes, name="cls_dense")(rep)

        self.logits = logits
        self.loss_arr = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=self.label_ids)
        if loss_weighting is not None:
            print("Special flags : ", "bias_loss")
            self.loss_arr = loss_weighting(self.label_ids, self.loss_arr)

        self.loss = tf.reduce_mean(input_tensor=self.loss_arr)
        self.preds = tf.cast(tf.argmax(logits, axis=-1), dtype=tf.int32)
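
Example #1 only ever calls `loss_weighting(label_ids, loss_arr)`, so any callable with that signature can be plugged in. Below is a minimal sketch of such a callable for per-class re-weighting; the helper name and the weight values are illustrative, not part of the original code.

import tensorflow as tf

def make_class_weighting(class_weights):
    # Returns a callable compatible with the `loss_weighting` hook above:
    # it receives integer label ids and per-example losses, and rescales
    # each loss by the weight of its class.
    weight_table = tf.constant(class_weights, dtype=tf.float32)

    def weighting(label_ids, loss_arr):
        per_example_weight = tf.gather(weight_table, label_ids)
        return loss_arr * per_example_weight

    return weighting

# e.g. count class-1 examples twice as much as class-0 (made-up weights):
# loss_weighting = make_class_weighting([1.0, 2.0])
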
Example #2
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        tf_logging.info("model_fn_ranking")
        """The `model_fn` for TPUEstimator."""
        log_features(features)

        input_ids, input_mask, segment_ids = combine_paired_input_features(features)
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        # Updated

        model = BertModel(
            config=bert_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        )
        pooled_output = model.get_pooled_output()
        if is_training:
            pooled_output = dropout(pooled_output, 0.1)

        loss, losses, y_pred = apply_loss_modeling(modeling_opt, pooled_output, features)


        assignment_fn = assignment_map.get_bert_assignment_map
        scaffold_fn = checkpoint_init(assignment_fn, train_config)

        optimizer_factory = lambda x: create_optimizer_from_config(x, train_config)
        input_ids1 = tf.identity(features["input_ids1"])
        input_ids2 = tf.identity(features["input_ids2"])
        prediction = {
            "input_ids1": input_ids1,
            "input_ids2": input_ids2
        }
        return ranking_estimator_spec(mode, loss, losses, y_pred, scaffold_fn, optimizer_factory, prediction)
Example #3
    def call(self, pooled_output, label_ids):
        if self.is_training:
            pooled_output = dropout(pooled_output, 0.1)

        self.pooled_output = pooled_output
        #self.logits = tf.layers.dense(pooled_output, self.num_classes, name="cls_dense")
        output_weights = tf1.get_variable(
            "output_weights", [3, self.hidden_size],
            initializer=tf1.truncated_normal_initializer(stddev=0.02)
        )

        output_bias = tf1.get_variable(
            "output_bias", [3],
            initializer=tf1.zeros_initializer()
        )


        logits = tf.matmul(pooled_output, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)
        self.logits = logits

        preds = tf.cast(tf.argmax(self.logits, axis=-1), tf.int32)
        labels = tf.one_hot(label_ids, self.num_classes)
        # self.loss_arr = tf.nn.softmax_cross_entropy_with_logits_v2(
        #     logits=self.logits,
        #     labels=labels)
        self.loss_arr = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits,
            labels=label_ids)

        self.loss = tf.reduce_mean(self.loss_arr)
        self.acc = tf_module.accuracy(self.logits, label_ids)
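
The `get_variable` / `matmul` / `bias_add` block above is the TF1-style spelling of a dense projection. A minimal sketch of the same computation with `tf.keras.layers.Dense`, assuming the same truncated-normal and zero initializers (names and shapes are illustrative):

import tensorflow as tf

# pooled_output: [batch, hidden_size]; 3 output classes as in the snippet above.
cls_dense = tf.keras.layers.Dense(
    3,
    name="cls_dense",
    kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
    bias_initializer="zeros")
# logits = cls_dense(pooled_output)  # same result as matmul(..., transpose_b=True) + bias_add
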
Example #4
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        tf_logging.info("model_fn_ranking")
        log_features(features)

        input_ids, input_mask, segment_ids = combine_paired_input_features(
            features)
        batch_size, _ = get_shape_list(
            input_mask)  # not the real batch size; this is 2 * real_batch_size
        use_context = tf.ones([batch_size, 1], tf.int32)


        stacked_input_ids, stacked_input_mask, stacked_segment_ids, \
            = split_and_append_sep(input_ids, input_mask, segment_ids,
                                   config.total_sequence_length, config.window_size, CLS_ID, EOW_ID)

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        with tf.compat.v1.variable_scope("sero"):
            model = model_class(config, is_training,
                                train_config.use_one_hot_embeddings)
            sequence_output_3d = model.network_stacked(stacked_input_ids,
                                                       stacked_input_mask,
                                                       stacked_segment_ids,
                                                       use_context)

        pooled_output = model.get_pooled_output()

        if is_training:
            pooled_output = dropout(pooled_output, 0.1)

        loss, losses, y_pred = apply_loss_modeling(config.loss, pooled_output,
                                                   features)

        assignment_fn = get_assignment_map_from_checkpoint_type(
            train_config.checkpoint_type, config.lower_layers)
        scaffold_fn = checkpoint_init(assignment_fn, train_config)
        prediction = {
            "stacked_input_ids": stacked_input_ids,
            "stacked_input_mask": stacked_input_mask,
            "stacked_segment_ids": stacked_segment_ids,
        }

        if train_config.gradient_accumulation != 1:
            optimizer_factory = lambda x: grad_accumulation.get_accumulated_optimizer_from_config(
                x, train_config, tf.compat.v1.trainable_variables(),
                train_config.gradient_accumulation)
        else:
            optimizer_factory = lambda x: create_optimizer_from_config(
                x, train_config)
        return ranking_estimator_spec(mode, loss, losses, y_pred, scaffold_fn,
                                      optimizer_factory, prediction)
Example #5
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        tf_logging.info("model_fn_sero_classification")
        """The `model_fn` for TPUEstimator."""
        log_features(features)
        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        # Updated
        model = BertModel(
            config=bert_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        )
        pooled_output = model.get_pooled_output()
        if is_training:
            pooled_output = dropout(pooled_output, 0.1)

        logits = get_prediction_structure(modeling_opt, pooled_output)
        loss = 0

        tvars = tf.compat.v1.trainable_variables()
        assignment_fn = assignment_map.get_bert_assignment_map
        initialized_variable_names, init_fn = get_init_fn(tvars, train_config.init_checkpoint, assignment_fn)
        scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)
        log_var_assignments(tvars, initialized_variable_names)

        predictions = None
        if modeling_opt == "multi_label_hinge":
            predictions = {
                "input_ids":input_ids,
                "logits":logits,
            }
        else:
            predictions = {
                "input_ids": input_ids,
                "logits": logits,
            }
            useful_inputs = ["data_id", "input_ids2", "data_ids"]
            for input_name in useful_inputs:
                if input_name in features:
                    predictions[input_name] = features[input_name]
        output_spec = rank_predict_estimator_spec(logits, mode, scaffold_fn, predictions)
        return output_spec
Example #6
    def apply(self, prev_output, batch_size, seq_length, attention_mask):
        layer_input = prev_output
        attention_output = self.self_attention.call(
            layer_input,
            attention_mask,
            batch_size,
            seq_length,
        )
        with tf.compat.v1.variable_scope("intermediate"):
            intermediate_output = self.intermediate_ff(attention_output)
        with tf.compat.v1.variable_scope("output"):
            layer_output = self.output_ff(intermediate_output)
            layer_output = bc.dropout(layer_output,
                                      self.config.hidden_dropout_prob)
            layer_output = bc.layer_norm(layer_output + attention_output)
        return intermediate_output, layer_output
def self_attention_with_add(layer_input, attention_mask, config, batch_size,
                            seq_length, hidden_size, initializer, values,
                            add_locations):
    attention_head_size = int(hidden_size / config.num_attention_heads)
    with tf.compat.v1.variable_scope("attention"):
        attention_heads = []
        with tf.compat.v1.variable_scope("self"):
            attention_head = bc.attention_layer(
                from_tensor=layer_input,
                to_tensor=layer_input,
                attention_mask=attention_mask,
                num_attention_heads=config.num_attention_heads,
                size_per_head=attention_head_size,
                attention_probs_dropout_prob=config.
                attention_probs_dropout_prob,
                initializer_range=config.initializer_range,
                do_return_2d_tensor=True,
                batch_size=batch_size,
                from_seq_length=seq_length,
                to_seq_length=seq_length)
            attention_heads.append(attention_head)

        attention_output = None
        if len(attention_heads) == 1:
            attention_output = attention_heads[0]
        else:
            # In the case where we have other sequences, we just concatenate
            # them to the self-attention head before the projection.
            attention_output = tf.concat(attention_heads, axis=-1)

        # [batch*seq_length, hidden_dim] , [batch, n_locations]
        attention_output = tf.tensor_scatter_nd_add(attention_output,
                                                    add_locations, values)

        # Run a linear projection of `hidden_size` then add a residual
        # with `layer_input`.
        with tf.compat.v1.variable_scope("output"):
            attention_output = bc.dense(hidden_size,
                                        initializer)(attention_output)
            attention_output = bc.dropout(attention_output,
                                          config.hidden_dropout_prob)
            attention_output = bc.layer_norm(attention_output + layer_input)
    return attention_output
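
`self_attention_with_add` injects `values` into specific rows of the flattened [batch*seq_length, hidden] activations via `tf.tensor_scatter_nd_add`. A toy sketch of that op's semantics, with made-up shapes:

import tensorflow as tf

t = tf.zeros([4, 3])               # e.g. [batch*seq_length, hidden_dim]
indices = tf.constant([[0], [2]])  # rows to modify (one index per update)
updates = tf.ones([2, 3])          # values added at those rows
out = tf.tensor_scatter_nd_add(t, indices, updates)
# rows 0 and 2 of `out` are now all ones; the remaining rows stay zero.
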
Example #8
    def __call__(self, inputs):
        from_tensor, to_tensor_list, attention_mask = inputs

        attention_output = attention_layer(
            from_tensor=from_tensor,
            to_tensor_list=to_tensor_list,
            query_ff=self.sub_layers['query'],
            key_ff=self.sub_layers['key'],
            value_ff=self.sub_layers['value'],
            attention_mask=attention_mask,
            num_attention_heads=self.num_attention_heads,
            size_per_head=self.attention_head_size,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
        )

        attention_output = self.sub_layers['output'](attention_output)
        attention_output = bc.dropout(attention_output,
                                      self.hidden_dropout_prob)
        attention_output = bc.layer_norm(attention_output + from_tensor.matrix)
        return attention_output
Example #9
    def call(self, layer_input, attention_mask, batch_size, seq_length):
        attention_heads = []
        with tf.compat.v1.variable_scope("attention"):
            with tf.compat.v1.variable_scope("self"):
                attention_head = bc.attention_layer2(
                    from_tensor=layer_input,
                    to_tensor=layer_input,
                    query_ff=self.query_layer,
                    key_ff=self.key_layer,
                    value_ff=self.value_layer,
                    attention_mask=attention_mask,
                    num_attention_heads=self.num_attention_heads,
                    size_per_head=self.attention_head_size,
                    attention_probs_dropout_prob=self.
                    attention_probs_dropout_prob,
                    do_return_2d_tensor=True,
                    batch_size=batch_size,
                    from_seq_length=seq_length,
                    to_seq_length=seq_length)
                attention_heads.append(attention_head)

            attention_output = None
            if len(attention_heads) == 1:
                attention_output = attention_heads[0]
            else:
                # In the case where we have other sequences, we just concatenate
                # them to the self-attention head before the projection.
                attention_output = tf.concat(attention_heads, axis=-1)

            # Run a linear projection of `hidden_size` then add a residual
            # with `layer_input`.
            with tf.compat.v1.variable_scope("output"):
                attention_output = self.output_layer(attention_output)
                attention_output = bc.dropout(attention_output,
                                              self.hidden_dropout_prob)
                attention_output = bc.layer_norm(attention_output +
                                                 layer_input)
        return attention_output
    def forward_layer_with_added(self, prev_output, added_value, locations):
        hidden_size = self.config.hidden_size
        layer_input = prev_output
        attention_output = self_attention_with_add(
            layer_input, self.attention_mask, self.config, self.batch_size,
            self.seq_length, hidden_size, self.initializer, added_value,
            locations)

        with tf.compat.v1.variable_scope("intermediate"):
            intermediate_output = bc.dense(
                self.config.intermediate_size,
                self.initializer,
                activation=bc.get_activation(
                    self.config.hidden_act))(attention_output)

        with tf.compat.v1.variable_scope("output"):
            layer_output = bc.dense(hidden_size,
                                    self.initializer)(intermediate_output)
            layer_output = bc.dropout(layer_output,
                                      self.config.hidden_dropout_prob)
            layer_output = bc.layer_norm(layer_output + attention_output)
            prev_output = layer_output
        return intermediate_output, layer_output
Example #11
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""
        tf_logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf_logging.info("  name = %s, shape = %s" % (name, features[name].shape))
        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        if mode == tf.estimator.ModeKeys.PREDICT:
            label_ids = tf.ones([input_ids.shape[0]], dtype=tf.float32)
        else:
            label_ids = features["label_ids"]
            label_ids = tf.reshape(label_ids, [-1])
        if "is_real_example" in features:
            is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
        else:
            is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        model = BertModel(
            config=model_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        )
        pooled = model.get_pooled_output()
        if is_training:
            pooled = dropout(pooled, 0.1)
        logits = tf.keras.layers.Dense(train_config.num_classes, name="cls_dense")(pooled)
        scale = model_config.scale

        label_ids = scale * label_ids

        weight = tf.abs(label_ids)
        loss_arr = tf.keras.losses.MAE(y_true=label_ids, y_pred=logits)
        loss_arr = loss_arr * weight

        loss = tf.reduce_mean(input_tensor=loss_arr)
        tvars = tf.compat.v1.trainable_variables()
        initialized_variable_names = {}
        scaffold_fn = None

        if train_config.init_checkpoint:
            initialized_variable_names, init_fn = get_init_fn(train_config, tvars)
            scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)
        log_var_assignments(tvars, initialized_variable_names)
        TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec

        def metric_fn(logits, label, is_real_example):
            mae = tf.compat.v1.metrics.mean_absolute_error(
                labels=label, predictions=logits, weights=is_real_example)

            return {
                "mae": mae
            }

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            tvars = None
            train_op = optimization.create_optimizer_from_config(loss, train_config, tvars)
            output_spec = TPUEstimatorSpec(mode=mode, loss=loss, train_op=train_op, scaffold_fn=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.EVAL:
            eval_metrics = (metric_fn, [
                logits, label_ids, is_real_example
            ])
            output_spec = TPUEstimatorSpec(mode=mode, loss=loss, eval_metrics=eval_metrics, scaffold_fn=scaffold_fn)
        else:
            predictions = {
                    "input_ids": input_ids,
                    "logits": logits,
            }
            if "data_id" in features:
                predictions['data_id'] = features['data_id']
            output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
                    mode=mode,
                    predictions=predictions,
                    scaffold_fn=scaffold_fn)
        return output_spec
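
The loss above is a mean absolute error in which each example is further weighted by the magnitude of its scaled target, so large-magnitude targets dominate training. A self-contained numeric sketch of that idea with made-up scalar values:

import tensorflow as tf

label_ids = tf.constant([0.0, 2.0, -4.0])    # targets already multiplied by `scale`
preds = tf.constant([0.5, 1.0, -1.0])

weight = tf.abs(label_ids)                   # larger targets count more
loss_arr = tf.abs(preds - label_ids) * weight
loss = tf.reduce_mean(loss_arr)              # (0.0 + 2.0 + 12.0) / 3 = 4.67
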
Example #12
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""
        tf_logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf_logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))
        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        if mode == tf.estimator.ModeKeys.PREDICT:
            label_ids = tf.ones([input_ids.shape[0]], dtype=tf.int32)
        else:
            label_ids = features["label_ids"]
            label_ids = tf.reshape(label_ids, [-1])
        if "is_real_example" in features:
            is_real_example = tf.cast(features["is_real_example"],
                                      dtype=tf.float32)
        else:
            is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        input_ids2 = features["input_ids2"]
        input_mask2 = features["input_mask2"]
        segment_ids2 = features["segment_ids2"]
        with tf.compat.v1.variable_scope(dual_model_prefix1):
            model_1 = BertModel(
                config=model_config,
                is_training=is_training,
                input_ids=input_ids,
                input_mask=input_mask,
                token_type_ids=segment_ids,
                use_one_hot_embeddings=train_config.use_one_hot_embeddings,
            )
            pooled = model_1.get_pooled_output()
            if is_training:
                pooled = dropout(pooled, 0.1)
            logits = tf.keras.layers.Dense(train_config.num_classes,
                                           name="cls_dense")(pooled)
        with tf.compat.v1.variable_scope(dual_model_prefix2):
            model_2 = BertModel(
                config=model_config,
                is_training=is_training,
                input_ids=input_ids2,
                input_mask=input_mask2,
                token_type_ids=segment_ids2,
                use_one_hot_embeddings=train_config.use_one_hot_embeddings,
            )
            pooled = model_2.get_pooled_output()
            if is_training:
                pooled = dropout(pooled, 0.1)
            conf_probs = tf.keras.layers.Dense(
                train_config.num_classes,
                name="cls_dense",
                activation=tf.keras.activations.softmax)(pooled)

            confidence = conf_probs[:, 1]
        confidence_loss = 1 - confidence

        cls_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=label_ids)

        k = model_config.k
        alpha = model_config.alpha
        loss_arr = cls_loss * confidence + confidence_loss * k

        loss_arr = apply_weighted_loss(loss_arr, label_ids, alpha)

        loss = tf.reduce_mean(input_tensor=loss_arr)
        tvars = tf.compat.v1.trainable_variables()
        initialized_variable_names = {}
        scaffold_fn = None
        if train_config.init_checkpoint:
            initialized_variable_names, init_fn = get_init_fn(
                train_config, tvars)
            scaffold_fn = get_tpu_scaffold_or_init(init_fn,
                                                   train_config.use_tpu)
        log_var_assignments(tvars, initialized_variable_names)
        TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec

        def metric_fn(log_probs, label, is_real_example, confidence):
            r = classification_metric_fn(log_probs, label, is_real_example)
            r['confidence'] = tf.compat.v1.metrics.mean(confidence)
            return r

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            tvars = None
            train_op = optimization.create_optimizer_from_config(
                loss, train_config, tvars)
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           train_op=train_op,
                                           scaffold_fn=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.EVAL:
            eval_metrics = (metric_fn,
                            [logits, label_ids, is_real_example, confidence])
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           eval_metrics=eval_metrics,
                                           scaffold_fn=scaffold_fn)
        else:
            predictions = {
                "input_ids": input_ids,
                "logits": logits,
                "confidence": confidence,
            }
            if "data_id" in features:
                predictions['data_id'] = features['data_id']
            output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
                mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
        return output_spec
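
Example #12 trains two towers jointly: the first produces class logits, the second a confidence in [0, 1]. The per-example loss `cls_loss * confidence + (1 - confidence) * k` lets the model trade a large classification loss for a fixed "abstention" cost `k`. A small numeric sketch of that trade-off (all values made up):

import tensorflow as tf

k = 0.5                               # fixed cost of low confidence
cls_loss = tf.constant([0.1, 2.0])    # per-example cross-entropy
confidence = tf.constant([0.9, 0.2])  # output of the second tower
loss_arr = cls_loss * confidence + (1.0 - confidence) * k
# easy, confident example:      0.1 * 0.9 + 0.1 * 0.5 = 0.14
# hard, low-confidence example: 2.0 * 0.2 + 0.8 * 0.5 = 0.80 (instead of 2.0)
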
Example #13
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        tf_logging.info("model_fn_pooling_long_things")
        log_features(features)
        input_ids = features["input_ids"]  # [batch_size, seq_length]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        label_ids = features["label_ids"]
        label_ids = tf.reshape(label_ids, [-1])

        batch_size, _ = get_shape_list(
            input_mask)  # not the real batch size; this is 2 * real_batch_size
        use_context = tf.ones([batch_size, 1], tf.int32)
        total_sequence_length = config.total_sequence_length
        stacked_input_ids, stacked_input_mask, stacked_segment_ids, \
            = split_and_append_sep2(input_ids[:, :total_sequence_length],
                                    input_mask[:, :total_sequence_length],
                                    segment_ids[:, :total_sequence_length],
                                   total_sequence_length, config.window_size, CLS_ID, EOW_ID)
        if "focus_mask" in features:
            focus_mask = features["focus_mask"]
            _, stacked_focus_mask, _, \
                = split_and_append_sep2(input_ids[:, :total_sequence_length],
                                        focus_mask[:, :total_sequence_length],
                                        segment_ids[:, :total_sequence_length],
                                        total_sequence_length, config.window_size, CLS_ID, EOW_ID)
            features["focus_mask"] = r3to2(stacked_focus_mask)

        batch_size, num_seg, seq_len = get_shape_list2(stacked_input_ids)

        input_ids_2d = r3to2(stacked_input_ids)
        input_mask_2d = r3to2(stacked_input_mask)
        segment_ids_2d = r3to2(stacked_segment_ids)

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        if "feed_features" in special_flags:
            model = model_class(
                config=config,
                is_training=is_training,
                input_ids=input_ids_2d,
                input_mask=input_mask_2d,
                token_type_ids=segment_ids_2d,
                use_one_hot_embeddings=train_config.use_one_hot_embeddings,
                features=features,
            )
        else:
            model = model_class(
                config=config,
                is_training=is_training,
                input_ids=input_ids_2d,
                input_mask=input_mask_2d,
                token_type_ids=segment_ids_2d,
                use_one_hot_embeddings=train_config.use_one_hot_embeddings,
            )

        sequence_output_2d = model.get_sequence_output()
        pooled_output = model.get_pooled_output()

        if is_training:
            pooled_output = dropout(pooled_output, 0.1)

        pooled_output_3d = tf.reshape(pooled_output, [batch_size, num_seg, -1])
        sequence_output_3d = tf.reshape(sequence_output_2d,
                                        [batch_size, num_seg, seq_len, -1])
        logits = pooling_modeling(config.option_name, train_config.num_classes,
                                  pooled_output_3d, sequence_output_3d)

        loss_arr = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=label_ids)
        loss = tf.reduce_mean(input_tensor=loss_arr)
        tvars = tf.compat.v1.trainable_variables()
        scaffold_fn = None
        if train_config.init_checkpoint:
            initialized_variable_names, init_fn = classification_model_fn.get_init_fn(
                train_config, tvars)
            scaffold_fn = get_tpu_scaffold_or_init(init_fn,
                                                   train_config.use_tpu)

        TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec
        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer_from_config(
                loss, train_config)
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           train_op=train_op,
                                           scaffold_fn=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.EVAL:
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           eval_metrics=None,
                                           scaffold_fn=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.PREDICT:
            predictions = {"input_ids": input_ids, "logits": logits}
            if "data_id" in features:
                predictions['data_id'] = features['data_id']
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           predictions=predictions,
                                           scaffold_fn=scaffold_fn)
        return output_spec
Example #14
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""
        tf_logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf_logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))
        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        label_ids = features["label_ids"]
        input_shape = get_shape_list2(input_ids)
        batch_size, seq_length = input_shape

        if "is_real_example" in features:
            is_real_example = tf.cast(features["is_real_example"],
                                      dtype=tf.float32)
        else:
            is_real_example = tf.ones([batch_size, 1], dtype=tf.float32)
        label_ids = tf.reshape(
            label_ids, [batch_size, seq_length, train_config.num_classes])
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        model = BertModel(
            config=model_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        )
        seq_out = model.get_sequence_output()
        if is_training:
            seq_out = dropout(seq_out, 0.1)
        logits = tf.keras.layers.Dense(train_config.num_classes,
                                       name="cls_dense")(seq_out)

        probs = tf.math.sigmoid(logits)

        eps = 1e-10
        label_logs = tf.math.log(label_ids + eps)
        #scale = model_config.scale
        #label_ids = scale * label_ids

        is_valid_mask = tf.cast(segment_ids, tf.float32)
        #loss_arr = tf.keras.losses.MAE(y_true=label_ids, y_pred=probs)
        loss_arr = tf.keras.losses.MAE(y_true=label_logs, y_pred=logits)
        loss_arr = loss_arr * is_valid_mask

        loss = tf.reduce_mean(input_tensor=loss_arr)
        tvars = tf.compat.v1.trainable_variables()
        initialized_variable_names = {}
        scaffold_fn = None

        if train_config.init_checkpoint:
            initialized_variable_names, init_fn = get_init_fn(
                train_config, tvars)
            scaffold_fn = get_tpu_scaffold_or_init(init_fn,
                                                   train_config.use_tpu)
        log_var_assignments(tvars, initialized_variable_names)
        TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec

        def metric_fn(probs, label, is_real_example):
            cut = math.exp(-10)
            pred_binary = probs > cut
            label_binary_all = label > cut

            pred_binary = pred_binary[:, :, 0]
            label_binary_1 = label_binary_all[:, :, 1]
            label_binary_0 = label_binary_all[:, :, 0]

            precision = tf.compat.v1.metrics.precision(predictions=pred_binary,
                                                       labels=label_binary_0)
            recall = tf.compat.v1.metrics.recall(predictions=pred_binary,
                                                 labels=label_binary_0)
            true_rate_1 = tf.compat.v1.metrics.mean(label_binary_1)
            true_rate_0 = tf.compat.v1.metrics.mean(label_binary_0)
            mae = tf.compat.v1.metrics.mean_absolute_error(
                labels=label, predictions=probs, weights=is_real_example)

            return {
                "mae": mae,
                "precision": precision,
                "recall": recall,
                "true_rate_1": true_rate_1,
                "true_rate_0": true_rate_0,
            }

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            tvars = None
            train_op = optimization.create_optimizer_from_config(
                loss, train_config, tvars)
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           train_op=train_op,
                                           scaffold_fn=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.EVAL:
            eval_metrics = (metric_fn, [probs, label_ids, is_real_example])
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           eval_metrics=eval_metrics,
                                           scaffold_fn=scaffold_fn)
        else:
            predictions = {
                "input_ids": input_ids,
                "logits": logits,
                "label_ids": label_ids,
            }
            if "data_id" in features:
                predictions['data_id'] = features['data_id']
            output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
                mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
        return output_spec
Example #15
def attention_layer_w_ext(from_tensor,
                          to_tensor,
                          attention_mask=None,
                          num_attention_heads=1,
                          size_per_head=512,
                          ext_slice=None,  # [Num_tokens, n_items, hidden_dim]
                          query_act=None,
                          key_act=None,
                          value_act=None,
                          attention_probs_dropout_prob=0.0,
                          initializer_range=0.02,
                          do_return_2d_tensor=False,
                          batch_size=None,
                          from_seq_length=None,
                          to_seq_length=None):
    """Performs multi-headed attention from `from_tensor` to `to_tensor`.

    This is an implementation of multi-headed attention based on "Attention
    is all you Need". If `from_tensor` and `to_tensor` are the same, then
    this is self-attention. Each timestep in `from_tensor` attends to the
    corresponding sequence in `to_tensor`, and returns a fixed-width vector.

    This function first projects `from_tensor` into a "query" tensor and
    `to_tensor` into "key" and "value" tensors. These are (effectively) a list
    of tensors of length `num_attention_heads`, where each tensor is of shape
    [batch_size, seq_length, size_per_head].

    Then, the query and key tensors are dot-producted and scaled. These are
    softmaxed to obtain attention probabilities. The value tensors are then
    interpolated by these probabilities, then concatenated back to a single
    tensor and returned.

    In practice, the multi-headed attention is done with transposes and
    reshapes rather than actual separate tensors.

    Args:
        from_tensor: float Tensor of shape [batch_size, from_seq_length,
            from_width].
        to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width].
        attention_mask: (optional) int32 Tensor of shape [batch_size,
            from_seq_length, to_seq_length]. The values should be 1 or 0. The
            attention scores will effectively be set to -infinity for any positions in
            the mask that are 0, and will be unchanged for positions that are 1.
        num_attention_heads: int. Number of attention heads.
        size_per_head: int. Size of each attention head.
        query_act: (optional) Activation function for the query transform.
        key_act: (optional) Activation function for the key transform.
        value_act: (optional) Activation function for the value transform.
        attention_probs_dropout_prob: (optional) float. Dropout probability of the
            attention probabilities.
        initializer_range: float. Range of the weight initializer.
        do_return_2d_tensor: bool. If True, the output will be of shape [batch_size
            * from_seq_length, num_attention_heads * size_per_head]. If False, the
            output will be of shape [batch_size, from_seq_length, num_attention_heads
            * size_per_head].
        batch_size: (Optional) int. If the input is 2D, this might be the batch size
            of the 3D version of the `from_tensor` and `to_tensor`.
        from_seq_length: (Optional) If the input is 2D, this might be the seq length
            of the 3D version of the `from_tensor`.
        to_seq_length: (Optional) If the input is 2D, this might be the seq length
            of the 3D version of the `to_tensor`.

    Returns:
        float Tensor of shape [batch_size, from_seq_length,
            num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is
            true, this will be of shape [batch_size * from_seq_length,
            num_attention_heads * size_per_head]).

    Raises:
        ValueError: Any of the arguments or tensor shapes are invalid.
    """

    def transpose_for_scores(input_tensor, batch_size, num_attention_heads,
                                                     seq_length, width):
        output_tensor = tf.reshape(
                input_tensor, [batch_size, seq_length, num_attention_heads, width])

        output_tensor = tf.transpose(a=output_tensor, perm=[0, 2, 1, 3])
        return output_tensor

    from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
    to_shape = get_shape_list(to_tensor, expected_rank=[2, 3])

    if len(from_shape) != len(to_shape):
        raise ValueError(
                "The rank of `from_tensor` must match the rank of `to_tensor`.")

    if len(from_shape) == 3:
        batch_size = from_shape[0]
        from_seq_length = from_shape[1]
        to_seq_length = to_shape[1]
    elif len(from_shape) == 2:
        if (batch_size is None or from_seq_length is None or to_seq_length is None):
            raise ValueError(
                    "When passing in rank 2 tensors to attention_layer, the values "
                    "for `batch_size`, `from_seq_length`, and `to_seq_length` "
                    "must all be specified.")

    # Scalar dimensions referenced here:
    #     B = batch size (number of sequences)
    #     F = `from_tensor` sequence length
    #     T = `to_tensor` sequence length
    #     N = `num_attention_heads`
    #     H = `size_per_head`

    from_tensor_2d = reshape_to_matrix(from_tensor)
    to_tensor_2d = reshape_to_matrix(to_tensor)

    def get_ext_slice(idx):
        return ext_slice[:, idx, :]

    print("from_tensor_2d ", from_tensor_2d.shape)

    query_in = from_tensor_2d + get_ext_slice(EXT_QUERY_IN)
    query_in = from_tensor_2d

    # `query_layer` = [B*F, N*H]
    query_layer = tf.keras.layers.Dense(
            num_attention_heads * size_per_head,
            activation=query_act,
            name="query",
            kernel_initializer=create_initializer(initializer_range))(query_in)

    query_layer = query_layer + get_ext_slice(EXT_QUERY_OUT)

    key_in = to_tensor_2d
    key_in = to_tensor_2d + get_ext_slice(EXT_KEY_IN)
    # `key_layer` = [B*T, N*H]
    key_layer = tf.keras.layers.Dense(
            num_attention_heads * size_per_head,
            activation=key_act,
            name="key",
            kernel_initializer=create_initializer(initializer_range))(key_in)

    key_layer = key_layer + get_ext_slice(EXT_KEY_OUT)

    value_in = to_tensor_2d
    value_in = to_tensor_2d + get_ext_slice(EXT_VALUE_IN)
    # `value_layer` = [B*T, N*H]
    value_layer = tf.keras.layers.Dense(
            num_attention_heads * size_per_head,
            activation=value_act,
            name="value",
            kernel_initializer=create_initializer(initializer_range))(value_in)

    value_layer = value_layer + get_ext_slice(EXT_VALUE_OUT)

    # `query_layer` = [B, N, F, H]
    query_layer = transpose_for_scores(query_layer, batch_size,
                                                     num_attention_heads, from_seq_length,
                                                     size_per_head)

    # `key_layer` = [B, N, T, H]
    key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads,
                                                                     to_seq_length, size_per_head)

    # Take the dot product between "query" and "key" to get the raw
    # attention scores.
    # `attention_scores` = [B, N, F, T]
    attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
    attention_scores = tf.multiply(attention_scores,
                                                                 1.0 / math.sqrt(float(size_per_head)))

    if attention_mask is not None:
        # `attention_mask` = [B, 1, F, T]
        attention_mask = tf.expand_dims(attention_mask, axis=[1])

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0

        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        attention_scores += adder


    # Normalize the attention scores to probabilities.
    # `attention_probs` = [B, N, F, T]
    attention_probs = tf.nn.softmax(attention_scores)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_probs = dropout(attention_probs, attention_probs_dropout_prob)

    # `value_layer` = [B, T, N, H]
    value_layer = tf.reshape(
            value_layer,
            [batch_size, to_seq_length, num_attention_heads, size_per_head])

    # `value_layer` = [B, N, T, H]
    value_layer = tf.transpose(a=value_layer, perm=[0, 2, 1, 3])

    # `context_layer` = [B, N, F, H]
    context_layer = tf.matmul(attention_probs, value_layer)

    # `context_layer` = [B, F, N, H]
    context_layer = tf.transpose(a=context_layer, perm=[0, 2, 1, 3])

    if do_return_2d_tensor:
        # `context_layer` = [B*F, N*V]
        context_layer = tf.reshape(
                context_layer,
                [batch_size * from_seq_length, num_attention_heads * size_per_head])
    else:
        # `context_layer` = [B, F, N*V]
        context_layer = tf.reshape(
                context_layer,
                [batch_size, from_seq_length, num_attention_heads * size_per_head])

    return context_layer
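
A hedged usage sketch of `attention_layer_w_ext` on 2D inputs. The shapes are placeholders, and `EXT_SIZE` plus the `EXT_*` index constants are assumed to be defined in the same module as the layer (they are referenced but not shown above); an all-zero `ext_slice` makes the call behave like plain self-attention.

batch_size, seq_length = 2, 8
num_heads, size_per_head = 4, 16
hidden = num_heads * size_per_head

x = tf.random.normal([batch_size * seq_length, hidden])      # 2D input
ext = tf.zeros([batch_size * seq_length, EXT_SIZE, hidden])  # no-op extra slices

context = attention_layer_w_ext(
    from_tensor=x,
    to_tensor=x,
    ext_slice=ext,
    num_attention_heads=num_heads,
    size_per_head=size_per_head,
    do_return_2d_tensor=True,
    batch_size=batch_size,
    from_seq_length=seq_length,
    to_seq_length=seq_length)
# context: [batch_size * seq_length, num_heads * size_per_head]
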
Example #16
    def __call__(self, inputs):
        intermediate_output = self.intermediate_ff(inputs)
        layer_output = self.output_ff(intermediate_output)
        layer_output = bc.dropout(layer_output, self.hidden_dropout_prob)
        layer_output = bc.layer_norm(layer_output + inputs)
        return layer_output
Example #17
def transformer_model(input_tensor,
                      attention_mask=None,
                      input_mask=None,
                      hidden_size=768,
                      num_hidden_layers=12,
                      num_attention_heads=12,
                      mr_num_route=10,
                      intermediate_size=3072,
                      intermediate_act_fn=gelu,
                      hidden_dropout_prob=0.1,
                      attention_probs_dropout_prob=0.1,
                      initializer_range=0.02,
                      is_training=True,
                      do_return_all_layers=False):
    """Multi-headed, multi-layer Transformer from "Attention is All You Need".

    This is almost an exact implementation of the original Transformer encoder.

    See the original paper:
    https://arxiv.org/abs/1706.03762

    Also see:
    https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py

    Args:
        input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size].
        attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length,
            seq_length], with 1 for positions that can be attended to and 0 in
            positions that should not be.
        hidden_size: int. Hidden size of the Transformer.
        num_hidden_layers: int. Number of layers (blocks) in the Transformer.
        num_attention_heads: int. Number of attention heads in the Transformer.
        intermediate_size: int. The size of the "intermediate" (a.k.a., feed
            forward) layer.
        intermediate_act_fn: function. The non-linear activation function to apply
            to the output of the intermediate/feed-forward layer.
        hidden_dropout_prob: float. Dropout probability for the hidden layers.
        attention_probs_dropout_prob: float. Dropout probability of the attention
            probabilities.
        initializer_range: float. Range of the initializer (stddev of truncated
            normal).
        do_return_all_layers: Whether to also return all layers or just the final
            layer.

    Returns:
        float Tensor of shape [batch_size, seq_length, hidden_size], the final
        hidden layer of the Transformer.

    Raises:
        ValueError: A Tensor shape or parameter is invalid.
    """
    if hidden_size % num_attention_heads != 0:
        raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (hidden_size, num_attention_heads))

    attention_head_size = int(hidden_size / num_attention_heads)
    input_shape = get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape[0]
    seq_length = input_shape[1]
    input_width = input_shape[2]

    initializer = create_initializer(initializer_range)

    ext_tensor = tf.compat.v1.get_variable(
        "ext_tensor",
        shape=[num_hidden_layers, mr_num_route, EXT_SIZE, hidden_size],
        initializer=initializer,
    )
    ext_tensor_inter = tf.compat.v1.get_variable(
        "ext_tensor_inter",
        shape=[num_hidden_layers, mr_num_route, intermediate_size],
        initializer=initializer,
    )
    # The Transformer performs sum residuals on all layers so the input needs
    # to be the same as the hidden size.
    if input_width != hidden_size:
        raise ValueError("The width of the input tensor (%d) != hidden size (%d)" %
                                         (input_width, hidden_size))

    # We keep the representation as a 2D tensor to avoid re-shaping it back and
    # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on
    # the GPU/CPU but may not be free on the TPU, so we want to minimize them to
    # help the optimizer.
    prev_output = reshape_to_matrix(input_tensor)

    def is_mr_layer(layer_idx):
        if layer_idx > 1:
            return True
        else:
            return False

    all_layer_outputs = []
    for layer_idx in range(num_hidden_layers):
        if not is_mr_layer(layer_idx):
            with tf.compat.v1.variable_scope("layer_%d" % layer_idx):
                layer_input = prev_output

                with tf.compat.v1.variable_scope("attention"):
                    attention_heads = []
                    with tf.compat.v1.variable_scope("self"):
                        attention_head = attention_layer(
                                from_tensor=layer_input,
                                to_tensor=layer_input,
                                attention_mask=attention_mask,
                                num_attention_heads=num_attention_heads,
                                size_per_head=attention_head_size,
                                attention_probs_dropout_prob=attention_probs_dropout_prob,
                                initializer_range=initializer_range,
                                do_return_2d_tensor=True,
                                batch_size=batch_size,
                                from_seq_length=seq_length,
                                to_seq_length=seq_length)
                        attention_heads.append(attention_head)

                    attention_output = None
                    if len(attention_heads) == 1:
                        attention_output = attention_heads[0]
                    else:
                        # In the case where we have other sequences, we just concatenate
                        # them to the self-attention head before the projection.
                        attention_output = tf.concat(attention_heads, axis=-1)

                    # Run a linear projection of `hidden_size` then add a residual
                    # with `layer_input`.
                    with tf.compat.v1.variable_scope("output"):
                        attention_output = dense(hidden_size, initializer)(attention_output)
                        attention_output = dropout(attention_output, hidden_dropout_prob)
                        attention_output = layer_norm(attention_output + layer_input)

                # The activation is only applied to the "intermediate" hidden layer.
                with tf.compat.v1.variable_scope("intermediate"):
                    intermediate_output = dense(intermediate_size, initializer,
                                                activation=intermediate_act_fn)(attention_output)

                # Down-project back to `hidden_size` then add the residual.
                with tf.compat.v1.variable_scope("output"):
                    layer_output = dense(hidden_size, initializer)(intermediate_output)
                    layer_output = dropout(layer_output, hidden_dropout_prob)
                    layer_output = layer_norm(layer_output + attention_output)
                    prev_output = layer_output
                    all_layer_outputs.append(layer_output)

                with tf.compat.v1.variable_scope("mr_key"):
                    key_output = tf.keras.layers.Dense(
                        mr_num_route,
                        kernel_initializer=create_initializer(initializer_range))(intermediate_output)
                    key_output = dropout(key_output, hidden_dropout_prob)

                    if is_training:
                        key = tf.random.categorical(key_output, 1) # [batch_size, 1]
                        key = tf.reshape(key, [-1])
                    else:
                        key = tf.math.argmax(input=key_output, axis=1)

        else: # Case MR layer
            with tf.compat.v1.variable_scope("layer_%d" % layer_idx):
                layer_input = prev_output
                ext_slice = tf.gather(ext_tensor[layer_idx], key)
                ext_interm_slice = tf.gather(ext_tensor_inter[layer_idx], key)
                print("ext_slice (batch*seq, ", ext_slice.shape)
                with tf.compat.v1.variable_scope("attention"):
                    attention_heads = []
                    with tf.compat.v1.variable_scope("self"):
                        attention_head = attention_layer_w_ext(
                            from_tensor=layer_input,
                            to_tensor=layer_input,
                            attention_mask=attention_mask,
                            ext_slice=ext_slice,
                            num_attention_heads=num_attention_heads,
                            size_per_head=attention_head_size,
                            attention_probs_dropout_prob=attention_probs_dropout_prob,
                            initializer_range=initializer_range,
                            do_return_2d_tensor=True,
                            batch_size=batch_size,
                            from_seq_length=seq_length,
                            to_seq_length=seq_length)
                        attention_head = attention_head + ext_slice[:,EXT_ATT_OUT,:]
                        attention_heads.append(attention_head)

                    attention_output = None
                    if len(attention_heads) == 1:
                        attention_output = attention_heads[0]
                    else:
                        # In the case where we have other sequences, we just concatenate
                        # them to the self-attention head before the projection.
                        attention_output = tf.concat(attention_heads, axis=-1)

                    # Run a linear projection of `hidden_size` then add a residual
                    # with `layer_input`.
                    with tf.compat.v1.variable_scope("output"):
                        attention_output = dense(hidden_size, initializer)(attention_output)
                        attention_output = dropout(attention_output, hidden_dropout_prob)
                        attention_output = attention_output + ext_slice[:,EXT_ATT_PROJ,:]
                        attention_output = layer_norm(attention_output + layer_input)

                # The activation is only applied to the "intermediate" hidden layer.
                with tf.compat.v1.variable_scope("intermediate"):
                    intermediate_output = dense(intermediate_size, initializer,
                                                activation=intermediate_act_fn)(attention_output)
                    intermediate_output = ext_interm_slice + intermediate_output
                # Down-project back to `hidden_size` then add the residual.
                with tf.compat.v1.variable_scope("output"):
                    layer_output = dense(hidden_size, initializer)(intermediate_output)
                    layer_output = layer_output + ext_slice[:, EXT_LAYER_OUT,:]
                    layer_output = dropout(layer_output, hidden_dropout_prob)
                    layer_output = layer_norm(layer_output + attention_output)
                    prev_output = layer_output
                    all_layer_outputs.append(layer_output)

    if do_return_all_layers:
        final_outputs = []
        for layer_output in all_layer_outputs:
            final_output = reshape_from_matrix(layer_output, input_shape)
            final_outputs.append(final_output)
        return final_outputs, key
    else:
        final_output = reshape_from_matrix(prev_output, input_shape)
        return final_output, key
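
A hedged usage sketch of `transformer_model`. Like the rest of this file it assumes TF1-style graph mode (the routing tables are created with `tf.compat.v1.get_variable`) and that the module-level helpers (`attention_layer`, `dense`, `gelu`, the `EXT_*` constants, and so on) are available; the shapes are placeholders and the all-ones mask simply lets every position attend everywhere.

batch, seq, hidden = 2, 16, 768
input_tensor = tf.random.normal([batch, seq, hidden])
attention_mask = tf.ones([batch, seq, seq], dtype=tf.int32)  # attend everywhere

final_output, key = transformer_model(
    input_tensor=input_tensor,
    attention_mask=attention_mask,
    hidden_size=hidden,
    num_hidden_layers=4,
    num_attention_heads=12,
    mr_num_route=10,
    is_training=True)
# final_output: [batch, seq, hidden]; key: [batch * seq] route ids from the mr_key head.
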
Example #18
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""
        tf_logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf_logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))
        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        if mode == tf.estimator.ModeKeys.PREDICT:
            label_ids = tf.ones([input_ids.shape[0]], dtype=tf.int32)
        else:
            label_ids = features["label_ids"]
            label_ids = tf.reshape(label_ids, [-1])
        if "is_real_example" in features:
            is_real_example = tf.cast(features["is_real_example"],
                                      dtype=tf.float32)
        else:
            is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

        domain_ids = features["domain_ids"]
        domain_ids = tf.reshape(domain_ids, [-1])

        is_valid_label = features["is_valid_label"]

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        model_1 = BertModel(
            config=model_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        )
        pooled = model_1.get_pooled_output()
        if is_training:
            pooled = dropout(pooled, 0.1)

        logits = tf.keras.layers.Dense(train_config.num_classes,
                                       name="cls_dense")(pooled)
        pred_losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=label_ids)
        num_domain = 2
        pooled_for_domain = grad_reverse(pooled)
        domain_logits = tf.keras.layers.Dense(
            num_domain, name="domain_dense")(pooled_for_domain)
        domain_losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=domain_logits, labels=domain_ids)

        pred_loss = tf.reduce_mean(pred_losses *
                                   tf.cast(is_valid_label, tf.float32))
        domain_loss = tf.reduce_mean(domain_losses)

        tf.compat.v1.summary.scalar('domain_loss', domain_loss)
        tf.compat.v1.summary.scalar('pred_loss', pred_loss)
        alpha = model_config.alpha
        loss = pred_loss + alpha * domain_loss
        tvars = tf.compat.v1.trainable_variables()
        initialized_variable_names = {}
        scaffold_fn = None
        if train_config.init_checkpoint:
            initialized_variable_names, init_fn = get_init_fn(
                train_config, tvars)
            scaffold_fn = get_tpu_scaffold_or_init(init_fn,
                                                   train_config.use_tpu)
        log_var_assignments(tvars, initialized_variable_names)
        TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            tvars = None
            train_op = optimization.create_optimizer_from_config(
                loss, train_config, tvars)
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           train_op=train_op,
                                           scaffold_fn=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.EVAL:
            eval_metrics = (classification_metric_fn,
                            [logits, label_ids, is_real_example])
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           eval_metrics=eval_metrics,
                                           scaffold_fn=scaffold_fn)
        else:
            predictions = {
                "input_ids": input_ids,
                "logits": logits,
            }
            if "data_id" in features:
                predictions['data_id'] = features['data_id']
            output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
                mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
        return output_spec
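
`grad_reverse`, which feeds the domain classifier above, is defined elsewhere; a minimal sketch of a gradient-reversal layer, assuming a fixed reversal scale of 1.0, is:

import tensorflow as tf


@tf.custom_gradient
def grad_reverse(x):
    # Identity in the forward pass; the backward pass negates the gradient,
    # so minimizing the domain loss trains the encoder to *confuse* the
    # domain classifier while `pred_loss` keeps it class-discriminative.
    def grad(dy):
        return -dy
    return tf.identity(x), grad
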
Exemplo n.º 19
0
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""
        tf_logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf_logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))

        vectors = features["vectors"]  # [batch_size, max_unit, num_hidden]
        valid_mask = features["valid_mask"]
        label_ids = features["label_ids"]
        vectors = tf.reshape(vectors, [
            -1, model_config.num_window, model_config.max_sequence,
            model_config.hidden_size
        ])
        valid_mask = tf.reshape(
            valid_mask,
            [-1, model_config.num_window, model_config.max_sequence])
        label_ids = tf.reshape(label_ids, [-1])

        if "is_real_example" in features:
            is_real_example = tf.cast(features["is_real_example"],
                                      dtype=tf.float32)
        else:
            is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        model = MultiEvidenceCombiner(config=model_config,
                                      is_training=is_training,
                                      vectors=vectors,
                                      valid_mask=valid_mask,
                                      scope=None)
        pooled = model.pooled_output
        if is_training:
            pooled = dropout(pooled, 0.1)

        logits = tf.keras.layers.Dense(config.num_classes,
                                       name="cls_dense")(pooled)
        loss_arr = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=label_ids)

        if "bias_loss" in special_flags:
            tf_logging.info("Using special_flags : bias_loss")
            loss_arr = reweight_zero(label_ids, loss_arr)

        loss = tf.reduce_mean(input_tensor=loss_arr)
        tvars = tf.compat.v1.trainable_variables()

        initialized_variable_names = {}

        scaffold_fn = None
        if config.init_checkpoint:
            initialized_variable_names, init_fn = get_init_fn(config, tvars)
            scaffold_fn = get_tpu_scaffold_or_init(init_fn, config.use_tpu)
        log_var_assignments(tvars, initialized_variable_names)

        TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec
        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            if "simple_optimizer" in special_flags:
                tf_logging.info("using simple optimizer")
                train_op = create_simple_optimizer(loss, config.learning_rate,
                                                   config.use_tpu)
            else:
                if "ask_tvar" in special_flags:
                    tvars = model.get_trainable_vars()
                else:
                    tvars = None
                train_op = optimization.create_optimizer_from_config(
                    loss, config, tvars)
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           train_op=train_op,
                                           scaffold_fn=scaffold_fn)

        elif mode == tf.estimator.ModeKeys.EVAL:
            eval_metrics = (classification_metric_fn,
                            [logits, label_ids, is_real_example])
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           eval_metrics=eval_metrics,
                                           scaffold_fn=scaffold_fn)
        else:
            predictions = {"logits": logits, "label_ids": label_ids}
            if override_prediction_fn is not None:
                predictions = override_prediction_fn(predictions, model)

            useful_inputs = ["data_id", "input_ids2"]
            for input_name in useful_inputs:
                if input_name in features:
                    predictions[input_name] = features[input_name]
            output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
                mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)

        return output_spec
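
`reweight_zero`, applied when the `bias_loss` flag is set, is also defined outside this snippet; a plausible sketch, assuming it simply re-scales the per-example loss of label-0 examples by a fixed factor (the factor and behavior are assumptions, not the repository's actual implementation):

import tensorflow as tf


def reweight_zero(label_ids, loss_arr, zero_weight=0.5):
    # Hypothetical re-weighting: scale the loss of label-0 examples by
    # `zero_weight` and leave all other labels at weight 1.0.
    is_zero = tf.cast(tf.equal(label_ids, 0), tf.float32)
    weights = zero_weight * is_zero + (1.0 - is_zero)
    return loss_arr * weights
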
Exemplo n.º 20
0
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""
        tf_logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf_logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        label_ids = features["label_ids"]
        label_ids = tf.reshape(label_ids, [-1])

        if "is_real_example" in features:
            is_real_example = tf.cast(features["is_real_example"],
                                      dtype=tf.float32)
        else:
            is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        if "feed_features" in special_flags:
            model = model_class(
                config=bert_config,
                is_training=is_training,
                input_ids=input_ids,
                input_mask=input_mask,
                token_type_ids=segment_ids,
                use_one_hot_embeddings=train_config.use_one_hot_embeddings,
                features=features,
            )
        else:
            model = model_class(
                config=bert_config,
                is_training=is_training,
                input_ids=input_ids,
                input_mask=input_mask,
                token_type_ids=segment_ids,
                use_one_hot_embeddings=train_config.use_one_hot_embeddings,
            )
        if "new_pooling" in special_flags:
            pooled = mimic_pooling(model.get_sequence_output(),
                                   bert_config.hidden_size,
                                   bert_config.initializer_range)
        else:
            pooled = model.get_pooled_output()

        if train_config.checkpoint_type != "bert_nli" and train_config.use_old_logits:
            tf_logging.info("Use old version of logistic regression")
            logits = tf.keras.layers.Dense(train_config.num_classes,
                                           name="cls_dense")(pooled)
        else:
            tf_logging.info("Use fixed version of logistic regression")
            output_weights = tf.compat.v1.get_variable(
                "output_weights", [3, bert_config.hidden_size],
                initializer=tf.compat.v1.truncated_normal_initializer(
                    stddev=0.02))

            output_bias = tf.compat.v1.get_variable(
                "output_bias", [3],
                initializer=tf.compat.v1.zeros_initializer())

            if is_training:
                pooled = dropout(pooled, 0.1)

            logits = tf.matmul(pooled, output_weights, transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)

        loss_arr = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=label_ids)

        if "bias_loss" in special_flags:
            tf_logging.info("Using special_flags : bias_loss")
            loss_arr = reweight_zero(label_ids, loss_arr)

        loss = tf.reduce_mean(input_tensor=loss_arr)
        tvars = tf.compat.v1.trainable_variables()

        initialized_variable_names = {}

        scaffold_fn = None
        if train_config.init_checkpoint:
            initialized_variable_names, init_fn = get_init_fn(
                train_config, tvars)
            scaffold_fn = get_tpu_scaffold_or_init(init_fn,
                                                   train_config.use_tpu)
        log_var_assignments(tvars, initialized_variable_names)

        TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec
        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            if "simple_optimizer" in special_flags:
                tf_logging.info("using simple optimizer")
                train_op = create_simple_optimizer(loss,
                                                   train_config.learning_rate,
                                                   train_config.use_tpu)
            else:
                train_op = optimization.create_optimizer_from_config(
                    loss, train_config)
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           train_op=train_op,
                                           scaffold_fn=scaffold_fn)

        elif mode == tf.estimator.ModeKeys.EVAL:
            eval_metrics = (classification_metric_fn,
                            [logits, label_ids, is_real_example])
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           eval_metrics=eval_metrics,
                                           scaffold_fn=scaffold_fn)
        else:
            probs = tf.nn.softmax(logits, axis=-1)
            # Saliency: gradient of the class-1 probability with respect to
            # the input embeddings, summed over the embedding dimension to
            # get one score per token.
            gradient_list = tf.gradients(probs[:, 1], model.embedding_output)
            print(len(gradient_list))
            gradient = gradient_list[0]
            print(gradient.shape)
            gradient = tf.reduce_sum(gradient, axis=2)
            predictions = {
                "input_ids": input_ids,
                "gradient": gradient,
                "labels": label_ids,
                "logits": logits
            }
            output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
                mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)

        return output_spec
Exemplo n.º 21
0
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""
        tf_logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf_logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))
        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        label_ids = features["label_ids"]
        label_ids = tf.reshape(label_ids, [-1])
        if "is_real_example" in features:
            is_real_example = tf.cast(features["is_real_example"],
                                      dtype=tf.float32)
        else:
            is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        model_1 = BertModel(
            config=model_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        )
        pooled = model_1.get_pooled_output()
        if is_training:
            pooled = dropout(pooled, 0.1)
        logits = tf.keras.layers.Dense(train_config.num_classes,
                                       name="cls_dense")(pooled)
        loss_arr = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=label_ids)
        loss = tf.reduce_mean(input_tensor=loss_arr)
        tvars = tf.compat.v1.trainable_variables()
        initialized_variable_names = {}
        scaffold_fn = None
        if train_config.init_checkpoint:
            initialized_variable_names, init_fn = get_init_fn(
                train_config, tvars)
            scaffold_fn = get_tpu_scaffold_or_init(init_fn,
                                                   train_config.use_tpu)
        log_var_assignments(tvars, initialized_variable_names)
        TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec

        global_step = tf.compat.v1.train.get_or_create_global_step()
        init_lr = train_config.learning_rate
        num_warmup_steps = train_config.num_warmup_steps
        num_train_steps = train_config.num_train_steps

        learning_rate2_const = tf.constant(value=init_lr,
                                           shape=[],
                                           dtype=tf.float32)
        learning_rate2_decayed = tf.compat.v1.train.polynomial_decay(
            learning_rate2_const,
            global_step,
            num_train_steps,
            end_learning_rate=0.0,
            power=1.0,
            cycle=False)
        if num_warmup_steps:
            global_steps_int = tf.cast(global_step, tf.int32)
            warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)

            global_steps_float = tf.cast(global_steps_int, tf.float32)
            warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)

            warmup_percent_done = global_steps_float / warmup_steps_float
            warmup_learning_rate = init_lr * warmup_percent_done

            is_warmup = tf.cast(global_steps_int < warmup_steps_int,
                                tf.float32)
            learning_rate = ((1.0 - is_warmup) * learning_rate2_decayed +
                             is_warmup * warmup_learning_rate)
        else:
            # Without warmup the schedule is just the polynomial decay;
            # define the warmup tensors anyway so the PREDICT branch below
            # can still export them.
            warmup_percent_done = tf.constant(1.0, dtype=tf.float32)
            warmup_learning_rate = learning_rate2_const
            learning_rate = learning_rate2_decayed

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            tvars = None
            train_op = optimization.create_optimizer_from_config(
                loss, train_config, tvars)
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           train_op=train_op,
                                           scaffold_fn=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.EVAL:
            eval_metrics = (classification_metric_fn,
                            [logits, label_ids, is_real_example])
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           eval_metrics=eval_metrics,
                                           scaffold_fn=scaffold_fn)
        else:

            def reform_scalar(t):
                return tf.reshape(t, [1])

            predictions = {
                "input_ids": input_ids,
                "label_ids": label_ids,
                "logits": logits,
                "learning_rate2_const": reform_scalar(learning_rate2_const),
                "warmup_percent_done": reform_scalar(warmup_percent_done),
                "warmup_learning_rate": reform_scalar(warmup_learning_rate),
                "learning_rate": reform_scalar(learning_rate),
                "learning_rate2_decayed": reform_scalar(learning_rate2_decayed),
            }
            if "data_id" in features:
                predictions['data_id'] = features['data_id']
            output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
                mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
        return output_spec
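
The learning-rate tensors exported in the PREDICT branch above implement linear warmup followed by linear (power=1.0) polynomial decay to zero; the same piecewise rule as a plain-Python sketch (hypothetical helper, for illustration only):

def piecewise_lr(step, init_lr, num_warmup_steps, num_train_steps):
    # Linear warmup from 0 to init_lr over num_warmup_steps, then linear
    # decay down to an end learning rate of 0.0 at num_train_steps.
    if num_warmup_steps and step < num_warmup_steps:
        return init_lr * step / num_warmup_steps
    frac = min(step, num_train_steps) / num_train_steps
    return init_lr * (1.0 - frac)
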
Exemplo n.º 22
0
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""
        tf_logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf_logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))

        query = features["query"]
        doc = features["doc"]
        doc_mask = features["doc_mask"]
        data_ids = features["data_id"]

        segment_len = max_seq_length - query_len - 3
        step_size = model_config.step_size
        input_ids, input_mask, segment_ids, n_segments = \
            iterate_over(query, doc, doc_mask, total_doc_len, segment_len, step_size)
        if mode == tf.estimator.ModeKeys.PREDICT:
            label_ids = tf.ones([input_ids.shape[0]], dtype=tf.int32)
        else:
            label_ids = features["label_ids"]
            label_ids = tf.reshape(label_ids, [-1])

        if "is_real_example" in features:
            is_real_example = tf.cast(features["is_real_example"],
                                      dtype=tf.float32)
        else:
            is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        if "feed_features" in special_flags:
            model = model_class(
                config=model_config,
                is_training=is_training,
                input_ids=input_ids,
                input_mask=input_mask,
                token_type_ids=segment_ids,
                use_one_hot_embeddings=train_config.use_one_hot_embeddings,
                features=features,
            )
        else:
            model = model_class(
                config=model_config,
                is_training=is_training,
                input_ids=input_ids,
                input_mask=input_mask,
                token_type_ids=segment_ids,
                use_one_hot_embeddings=train_config.use_one_hot_embeddings,
            )
        if "new_pooling" in special_flags:
            pooled = mimic_pooling(model.get_sequence_output(),
                                   model_config.hidden_size,
                                   model_config.initializer_range)
        else:
            pooled = model.get_pooled_output()

        if train_config.checkpoint_type != "bert_nli" and train_config.use_old_logits:
            tf_logging.info("Use old version of logistic regression")
            if is_training:
                pooled = dropout(pooled, 0.1)
            logits = tf.keras.layers.Dense(train_config.num_classes,
                                           name="cls_dense")(pooled)
        else:
            tf_logging.info("Use fixed version of logistic regression")
            output_weights = tf.compat.v1.get_variable(
                "output_weights",
                [train_config.num_classes, model_config.hidden_size],
                initializer=tf.compat.v1.truncated_normal_initializer(
                    stddev=0.02))

            output_bias = tf.compat.v1.get_variable(
                "output_bias", [train_config.num_classes],
                initializer=tf.compat.v1.zeros_initializer())

            if is_training:
                pooled = dropout(pooled, 0.1)

            logits = tf.matmul(pooled, output_weights, transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)

        loss_arr = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=label_ids)

        if "bias_loss" in special_flags:
            tf_logging.info("Using special_flags : bias_loss")
            loss_arr = reweight_zero(label_ids, loss_arr)

        loss = tf.reduce_mean(input_tensor=loss_arr)
        tvars = tf.compat.v1.trainable_variables()

        initialized_variable_names = {}

        scaffold_fn = None
        if train_config.init_checkpoint:
            initialized_variable_names, init_fn = get_init_fn(
                train_config, tvars)
            scaffold_fn = get_tpu_scaffold_or_init(init_fn,
                                                   train_config.use_tpu)
        log_var_assignments(tvars, initialized_variable_names)

        TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec
        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            if "simple_optimizer" in special_flags:
                tf_logging.info("using simple optimizer")
                train_op = create_simple_optimizer(loss,
                                                   train_config.learning_rate,
                                                   train_config.use_tpu)
            else:
                if "ask_tvar" in special_flags:
                    tvars = model.get_trainable_vars()
                else:
                    tvars = None
                train_op = optimization.create_optimizer_from_config(
                    loss, train_config, tvars)
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           train_op=train_op,
                                           scaffold_fn=scaffold_fn)

        elif mode == tf.estimator.ModeKeys.EVAL:
            eval_metrics = (classification_metric_fn,
                            [logits, label_ids, is_real_example])
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           eval_metrics=eval_metrics,
                                           scaffold_fn=scaffold_fn)
        else:
            predictions = {
                "logits": logits,
                "doc": doc,
                "data_ids": data_ids,
            }

            useful_inputs = ["data_id", "input_ids2", "data_ids"]
            for input_name in useful_inputs:
                if input_name in features:
                    predictions[input_name] = features[input_name]

            if override_prediction_fn is not None:
                predictions = override_prediction_fn(predictions, model)

            output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
                mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)

        return output_spec
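
`iterate_over` is defined elsewhere; conceptually it slides a window of `segment_len` document tokens with stride `step_size` and packs the query in front of each window before running BERT over every window. A hypothetical plain-Python sketch of just the windowing step (padding and [CLS]/[SEP] packing omitted):

def sliding_windows(doc_token_ids, segment_len, step_size):
    # Yield overlapping windows over the document; the caller prepends the
    # query ids and adds special-token packing to each window.
    windows = []
    for start in range(0, max(1, len(doc_token_ids)), step_size):
        windows.append(doc_token_ids[start:start + segment_len])
    return windows
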
Exemplo n.º 23
0
  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator."""
    tf_logging.info("*** Features ***")
    for name in sorted(features.keys()):
      tf_logging.info("  name = %s, shape = %s" % (name, features[name].shape))

    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    if mode == tf.estimator.ModeKeys.PREDICT:
        label_ids = tf.ones([input_ids.shape[0]], dtype=tf.int32)
    else:
        label_ids = features["label_ids"]
        label_ids = tf.reshape(label_ids, [-1])

    if "is_real_example" in features:
        is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
    else:
        is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    if "feed_features" in special_flags:
        model = model_class(
            config=bert_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
            features=features,
        )
    else:
        model = model_class(
            config=bert_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        )
    if "new_pooling" in special_flags:
        pooled = mimic_pooling(model.get_sequence_output(), bert_config.hidden_size, bert_config.initializer_range)
    else:
        pooled = model.get_pooled_output()

    if train_config.checkpoint_type != "bert_nli" and train_config.use_old_logits:
        tf_logging.info("Use old version of logistic regression")
        logits = tf.keras.layers.Dense(train_config.num_classes, name="cls_dense")(pooled)
    else:
        tf_logging.info("Use fixed version of logistic regression")
        output_weights = tf.compat.v1.get_variable(
            "output_weights", [3, bert_config.hidden_size],
            initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.02)
        )

        output_bias = tf.compat.v1.get_variable(
            "output_bias", [3],
            initializer=tf.compat.v1.zeros_initializer()
        )

        if is_training:
            pooled = dropout(pooled, 0.1)

        logits = tf.matmul(pooled, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)

    # TODO given topic_ids, reorder logits to [num_group, num_max_items]
    #
    ndcg_ranking_loss = make_loss_fn(RankingLossKey.APPROX_NDCG_LOSS)
    one_hot_labels = tf.one_hot(label_ids, 3)
    # logits : [batch_size, num_classes] -> transposed to [num_classes, batch_size]
    logits_t = tf.transpose(logits, [1, 0])
    one_hot_labels_t = tf.transpose(one_hot_labels, [1, 0])
    # Scale the model logits down so the constructed scores are dominated by
    # the one-hot labels; the logits only break ties.
    fake_logit = logits_t * 1e-5 + one_hot_labels_t
    loss_arr = ndcg_ranking_loss(fake_logit, one_hot_labels_t, {})
    loss = loss_arr
    tvars = tf.compat.v1.trainable_variables()

    initialized_variable_names = {}

    if train_config.checkpoint_type == "bert":
        assignment_fn = tlm.training.assignment_map.get_bert_assignment_map
    elif train_config.checkpoint_type == "v2":
        assignment_fn = tlm.training.assignment_map.assignment_map_v2_to_v2
    elif train_config.checkpoint_type == "bert_nli":
        assignment_fn = tlm.training.assignment_map.get_bert_nli_assignment_map
    elif train_config.checkpoint_type == "attention_bert":
        assignment_fn = tlm.training.assignment_map.bert_assignment_only_attention
    elif train_config.checkpoint_type == "attention_bert_v2":
        assignment_fn = tlm.training.assignment_map.assignment_map_v2_to_v2_only_attention
    elif train_config.checkpoint_type == "wo_attention_bert":
        assignment_fn = tlm.training.assignment_map.bert_assignment_wo_attention
    elif train_config.checkpoint_type == "as_is":
        assignment_fn = tlm.training.assignment_map.get_assignment_map_as_is
    else:
        if train_config.init_checkpoint:
            raise Exception("init_checkpoint exists, but checkpoint_type is not specified")

    scaffold_fn = None
    if train_config.init_checkpoint:
      assignment_map, initialized_variable_names = assignment_fn(tvars, train_config.init_checkpoint)

      def init_fn():
        tf.compat.v1.train.init_from_checkpoint(train_config.init_checkpoint, assignment_map)
      scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)
    log_var_assignments(tvars, initialized_variable_names)

    TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec
    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        if "simple_optimizer" in special_flags:
            tf_logging.info("using simple optimizer")
            train_op = create_simple_optimizer(loss, train_config.learning_rate, train_config.use_tpu)
        else:
            train_op = optimization.create_optimizer_from_config(loss, train_config)
        output_spec = TPUEstimatorSpec(mode=mode, loss=loss, train_op=train_op, scaffold_fn=scaffold_fn)

    elif mode == tf.estimator.ModeKeys.EVAL:
        eval_metrics = (classification_metric_fn, [
            logits, label_ids, is_real_example
        ])
        output_spec = TPUEstimatorSpec(mode=mode, loss=loss, eval_metrics=eval_metrics, scaffold_fn=scaffold_fn)
    else:
        predictions = {
                "input_ids": input_ids,
                "logits": logits
        }
        output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                predictions=predictions,
                scaffold_fn=scaffold_fn)


    return output_spec
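
The `fake_logit` tensor above adds the one-hot labels to logits scaled down by 1e-5, so it is dominated by the gold labels and the model logits act only as tie-breakers; a small numeric illustration with hypothetical values:

import numpy as np

logits = np.array([2.0, -1.0, 0.5])    # per-class logits for one example
one_hot = np.array([0.0, 1.0, 0.0])    # gold label is class 1

fake = logits * 1e-5 + one_hot
print(fake)           # [2.0e-05, 0.99999, 5.0e-06]
print(fake.argmax())  # 1 -> the constructed score ordering follows the label
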