Ejemplo n.º 1
0
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator.

        Builds encoder inputs from the query/document features with
        `iterate_over`, runs `model_class` over them, and attaches a softmax
        classification head to the pooled output.

        Args:
          features: dict of input tensors. Reads "query", "doc", "doc_mask" and
            "data_id"; outside PREDICT mode also "label_ids"; optionally
            "is_real_example".
          labels: unused (labels are read from `features`).
          mode: a `tf.estimator.ModeKeys` value.
          params: unused.

        Returns:
          A `TPUEstimatorSpec` for the requested mode.
        """
        tf_logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf_logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))

        query = features["query"]
        doc = features["doc"]
        doc_mask = features["doc_mask"]
        data_ids = features["data_id"]

        # The "- 3" presumably reserves room for special tokens ([CLS] and two
        # [SEP]) around the query and each document segment -- TODO confirm.
        segment_len = max_seq_length - query_len - 3
        step_size = model_config.step_size
        input_ids, input_mask, segment_ids, _ = iterate_over(
            query, doc, doc_mask, total_doc_len, segment_len, step_size)

        if mode == tf.estimator.ModeKeys.PREDICT:
            # There are no labels at prediction time, but the loss below is
            # built unconditionally, so feed placeholder labels of matching size.
            label_ids = tf.ones([input_ids.shape[0]], dtype=tf.int32)
        else:
            label_ids = tf.reshape(features["label_ids"], [-1])

        if "is_real_example" in features:
            is_real_example = tf.cast(features["is_real_example"],
                                      dtype=tf.float32)
        else:
            is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        model_kwargs = dict(
            config=model_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        )
        if "feed_features" in special_flags:
            # Some model classes also consume the raw feature dict.
            model_kwargs["features"] = features
        model = model_class(**model_kwargs)

        if "new_pooling" in special_flags:
            pooled = mimic_pooling(model.get_sequence_output(),
                                   model_config.hidden_size,
                                   model_config.initializer_range)
        else:
            pooled = model.get_pooled_output()

        if train_config.checkpoint_type != "bert_nli" and train_config.use_old_logits:
            tf_logging.info("Use old version of logistic regression")
            if is_training:
                pooled = dropout(pooled, 0.1)
            logits = tf.keras.layers.Dense(train_config.num_classes,
                                           name="cls_dense")(pooled)
        else:
            tf_logging.info("Use fixed version of logistic regression")
            output_weights = tf.compat.v1.get_variable(
                "output_weights",
                [train_config.num_classes, model_config.hidden_size],
                initializer=tf.compat.v1.truncated_normal_initializer(
                    stddev=0.02))
            output_bias = tf.compat.v1.get_variable(
                "output_bias", [train_config.num_classes],
                initializer=tf.compat.v1.zeros_initializer())

            if is_training:
                pooled = dropout(pooled, 0.1)

            logits = tf.matmul(pooled, output_weights, transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)

        loss_arr = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=label_ids)
        if "bias_loss" in special_flags:
            tf_logging.info("Using special_flags : bias_loss")
            loss_arr = reweight_zero(label_ids, loss_arr)
        loss = tf.reduce_mean(input_tensor=loss_arr)

        tvars = tf.compat.v1.trainable_variables()
        initialized_variable_names = {}
        scaffold_fn = None
        if train_config.init_checkpoint:
            initialized_variable_names, init_fn = get_init_fn(
                train_config, tvars)
            scaffold_fn = get_tpu_scaffold_or_init(init_fn,
                                                   train_config.use_tpu)
        log_var_assignments(tvars, initialized_variable_names)

        TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec
        if mode == tf.estimator.ModeKeys.TRAIN:
            if "simple_optimizer" in special_flags:
                tf_logging.info("using simple optimizer")
                train_op = create_simple_optimizer(loss,
                                                   train_config.learning_rate,
                                                   train_config.use_tpu)
            else:
                # With "ask_tvar" only the model's own variables are optimized;
                # otherwise None is passed and the optimizer picks its default.
                train_vars = (model.get_trainable_vars()
                              if "ask_tvar" in special_flags else None)
                train_op = optimization.create_optimizer_from_config(
                    loss, train_config, train_vars)
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           train_op=train_op,
                                           scaffold_fn=scaffold_fn)

        elif mode == tf.estimator.ModeKeys.EVAL:
            eval_metrics = (classification_metric_fn,
                            [logits, label_ids, is_real_example])
            # `mode` must be the ModeKeys value, not the model object.
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           eval_metrics=eval_metrics,
                                           scaffold_fn=scaffold_fn)
        else:
            predictions = {
                "logits": logits,
                "doc": doc,
                "data_ids": data_ids,
            }

            useful_inputs = ["data_id", "input_ids2", "data_ids"]
            for input_name in useful_inputs:
                if input_name in features:
                    predictions[input_name] = features[input_name]

            if override_prediction_fn is not None:
                predictions = override_prediction_fn(predictions, model)

            output_spec = TPUEstimatorSpec(
                mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)

        return output_spec
Ejemplo n.º 2
0
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator.

        Two-tower model: query inputs go through a `model_class` built under
        variable scope "query", document inputs through one built under scope
        "document". Scores are a matmul of the two pooled outputs, trained with
        a hinge loss.

        Args:
          features: dict with "q_input_ids_{1,2}", "q_input_mask_{1,2}",
            "d_input_ids_{1,2}", "d_input_mask_{1,2}", "label_ids" and
            optionally "is_real_example".
          labels: unused (labels are read from `features`).
          mode: a `tf.estimator.ModeKeys` value.
          params: unused.

        Returns:
          A `TPUEstimatorSpec` for the requested mode.
        """
        tf_logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf_logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))

        q_input_ids_1 = features["q_input_ids_1"]
        q_input_mask_1 = features["q_input_mask_1"]
        d_input_ids_1 = features["d_input_ids_1"]
        d_input_mask_1 = features["d_input_mask_1"]

        q_input_ids_2 = features["q_input_ids_2"]
        q_input_mask_2 = features["q_input_mask_2"]
        d_input_ids_2 = features["d_input_ids_2"]
        d_input_mask_2 = features["d_input_mask_2"]

        # tf.stack(axis=0) adds a NEW leading dimension of size 2, so these
        # tensors have one more rank than each input.
        # NOTE(review): if model_class expects rank-2 [batch, seq] input_ids
        # (as BERT-style encoders do), tf.concat(axis=0) was probably
        # intended here -- confirm.
        q_input_ids = tf.stack([q_input_ids_1, q_input_ids_2], axis=0)
        q_input_mask = tf.stack([q_input_mask_1, q_input_mask_2], axis=0)
        q_segment_ids = tf.zeros_like(q_input_ids, tf.int32)

        d_input_ids = tf.stack([d_input_ids_1, d_input_ids_2], axis=0)
        d_input_mask = tf.stack([d_input_mask_1, d_input_mask_2], axis=0)
        d_segment_ids = tf.zeros_like(d_input_ids, tf.int32)

        label_ids = features["label_ids"]
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        if "is_real_example" in features:
            is_real_example = tf.cast(features["is_real_example"],
                                      dtype=tf.float32)
        else:
            is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

        # Separate variable scopes give the two towers independent weights.
        with tf.compat.v1.variable_scope("query"):
            model_q = model_class(
                config=model_config,
                is_training=is_training,
                input_ids=q_input_ids,
                input_mask=q_input_mask,
                token_type_ids=q_segment_ids,
                use_one_hot_embeddings=train_config.use_one_hot_embeddings,
            )

        with tf.compat.v1.variable_scope("document"):
            model_d = model_class(
                config=model_config,
                is_training=is_training,
                input_ids=d_input_ids,
                input_mask=d_input_mask,
                token_type_ids=d_segment_ids,
                use_one_hot_embeddings=train_config.use_one_hot_embeddings,
            )
        pooled_q = model_q.get_pooled_output()
        pooled_d = model_d.get_pooled_output()

        # NOTE(review): for pooled outputs of shape [N, H] this produces an
        # [N, N] matrix of every query/document pairing, not one score per
        # pair; `logits * y` below then relies on broadcasting against the
        # label shape. A per-pair dot product would be
        # tf.reduce_sum(pooled_q * pooled_d, axis=-1) -- confirm intent.
        logits = tf.matmul(pooled_q, pooled_d, transpose_b=True)
        # Maps labels to +/-1 for the hinge loss (assumes label_ids are 0/1).
        y = tf.cast(label_ids, tf.float32) * 2 - 1
        losses = tf.maximum(1.0 - logits * y, 0)
        loss = tf.reduce_mean(losses)

        # Positive score -> predicted class 1.
        pred = tf.cast(logits > 0, tf.int32)

        tvars = tf.compat.v1.trainable_variables()

        initialized_variable_names = {}

        scaffold_fn = None
        if train_config.init_checkpoint:
            initialized_variable_names, init_fn = get_init_fn(
                train_config, tvars)
            scaffold_fn = get_tpu_scaffold_or_init(init_fn,
                                                   train_config.use_tpu)
        log_var_assignments(tvars, initialized_variable_names)

        TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec
        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            if "simple_optimizer" in special_flags:
                tf_logging.info("using simple optimizer")
                train_op = create_simple_optimizer(loss,
                                                   train_config.learning_rate,
                                                   train_config.use_tpu)
            else:
                # All trainable variables (collected above) are passed through.
                train_op = optimization.create_optimizer_from_config(
                    loss, train_config, tvars)
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           train_op=train_op,
                                           scaffold_fn=scaffold_fn)

        elif mode == tf.estimator.ModeKeys.EVAL:
            # Metrics receive the thresholded predictions, not raw logits.
            eval_metrics = (classification_metric_fn,
                            [pred, label_ids, is_real_example])
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           eval_metrics=eval_metrics,
                                           scaffold_fn=scaffold_fn)
        else:
            predictions = {
                "q_input_ids": q_input_ids,
                "d_input_ids": d_input_ids,
                "score": logits
            }

            # Pass selected identifier features through to the predictions.
            useful_inputs = ["data_id", "input_ids2", "data_ids"]
            for input_name in useful_inputs:
                if input_name in features:
                    predictions[input_name] = features[input_name]
            output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
                mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)

        return output_spec
Ejemplo n.º 3
0
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator.

        Dual-encoder model. The query is encoded by a `model_class` tower under
        `dual_model_prefix1`. Each of the `num_docs` documents per example is
        encoded by a second tower under `dual_model_prefix2`. Every query is
        scored against its documents by a dot product of pooled outputs, and
        the per-document scores are reduced to logits and a loss by the
        selected loss helper.

        Args:
          features: dict with "q_input_ids" and "q_input_mask" (rank 2),
            "d_input_ids", "d_input_mask", "label_ids" and optionally
            "is_real_example". The document tensors must be reshapeable to
            [-1, num_docs, max_doc_length].
          labels: unused (labels are read from `features`).
          mode: a `tf.estimator.ModeKeys` value.
          params: unused.

        Returns:
          A `TPUEstimatorSpec` for the requested mode.
        """
        tf_logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf_logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))

        q_input_ids = features["q_input_ids"]
        q_input_mask = features["q_input_mask"]
        d_input_ids = features["d_input_ids"]
        d_input_mask = features["d_input_mask"]

        batch_size = get_shape_list(q_input_ids, expected_rank=2)[0]

        doc_length = model_config.max_doc_length
        num_docs = model_config.num_docs

        # Going through [-1, num_docs, doc_length] first makes the reshape fail
        # unless the document tensor size is a multiple of num_docs * doc_length.
        # The result is then flattened into one batch of
        # (batch_size * num_docs) sequences for the document tower.
        d_input_ids_unpacked = tf.reshape(d_input_ids,
                                          [-1, num_docs, doc_length])
        d_input_mask_unpacked = tf.reshape(d_input_mask,
                                           [-1, num_docs, doc_length])
        d_input_ids_flat = tf.reshape(d_input_ids_unpacked, [-1, doc_length])
        d_input_mask_flat = tf.reshape(d_input_mask_unpacked, [-1, doc_length])

        q_segment_ids = tf.zeros_like(q_input_ids, tf.int32)
        d_segment_ids = tf.zeros_like(d_input_ids_flat, tf.int32)

        label_ids = features["label_ids"]
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        if "is_real_example" in features:
            is_real_example = tf.cast(features["is_real_example"],
                                      dtype=tf.float32)
        else:
            is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

        # Each tower gets its own copy of the config with the sequence length
        # of the inputs it actually sees. Previously these copies were built
        # but the towers were still constructed from `model_config`.
        q_model_config = copy.deepcopy(model_config)
        q_model_config.max_seq_length = model_config.max_sent_length
        d_model_config = copy.deepcopy(model_config)
        d_model_config.max_seq_length = model_config.max_doc_length

        with tf.compat.v1.variable_scope(dual_model_prefix1):
            model_q = model_class(
                config=q_model_config,
                is_training=is_training,
                input_ids=q_input_ids,
                input_mask=q_input_mask,
                token_type_ids=q_segment_ids,
                use_one_hot_embeddings=train_config.use_one_hot_embeddings,
            )

        with tf.compat.v1.variable_scope(dual_model_prefix2):
            model_d = model_class(
                config=d_model_config,
                is_training=is_training,
                input_ids=d_input_ids_flat,
                input_mask=d_input_mask_flat,
                token_type_ids=d_segment_ids,
                use_one_hot_embeddings=train_config.use_one_hot_embeddings,
            )

        pooled_q = model_q.get_pooled_output()  # [batch, hidden]
        pooled_d_flat = model_d.get_pooled_output()  # [batch * num_docs, hidden]

        pooled_d = tf.reshape(pooled_d_flat, [batch_size, num_docs, -1])
        pooled_q_t = tf.expand_dims(pooled_q, 1)  # [batch, 1, hidden]
        pooled_d_t = tf.transpose(pooled_d, [0, 2, 1])  # [batch, hidden, num_docs]
        all_logits = tf.matmul(pooled_q_t, pooled_d_t)  # [batch, 1, num_docs]

        if "hinge_all" in special_flags:
            apply_loss_modeling = hinge_all
        elif "sigmoid_all" in special_flags:
            apply_loss_modeling = sigmoid_all
        else:
            apply_loss_modeling = hinge_max
        logits, loss = apply_loss_modeling(all_logits, label_ids)
        # Positive score -> predicted class 1.
        pred = tf.cast(logits > 0, tf.int32)

        tvars = tf.compat.v1.trainable_variables()
        initialized_variable_names = {}
        scaffold_fn = None
        if train_config.init_checkpoint:
            initialized_variable_names, init_fn = get_init_fn(
                train_config, tvars)
            scaffold_fn = get_tpu_scaffold_or_init(init_fn,
                                                   train_config.use_tpu)
        log_var_assignments(tvars, initialized_variable_names)

        TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec
        if mode == tf.estimator.ModeKeys.TRAIN:
            if "simple_optimizer" in special_flags:
                tf_logging.info("using simple optimizer")
                train_op = create_simple_optimizer(loss,
                                                   train_config.learning_rate,
                                                   train_config.use_tpu)
            else:
                train_op = optimization.create_optimizer_from_config(
                    loss, train_config, tvars)
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           train_op=train_op,
                                           scaffold_fn=scaffold_fn)

        elif mode == tf.estimator.ModeKeys.EVAL:
            eval_metrics = (classification_metric_fn,
                            [pred, label_ids, is_real_example])
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           eval_metrics=eval_metrics,
                                           scaffold_fn=scaffold_fn)
        else:
            predictions = {
                "q_input_ids": q_input_ids,
                "d_input_ids": d_input_ids,
                "logits": logits
            }

            useful_inputs = ["data_id", "input_ids2", "data_ids"]
            for input_name in useful_inputs:
                if input_name in features:
                    predictions[input_name] = features[input_name]
            output_spec = TPUEstimatorSpec(
                mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)

        return output_spec
Ejemplo n.º 4
0
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator.

        Combines precomputed evidence vectors with `MultiEvidenceCombiner` and
        puts a dense softmax classification head on its pooled output.

        Args:
          features: dict with "vectors", "valid_mask", "label_ids" and
            optionally "is_real_example". "vectors" must be reshapeable to
            [-1, num_window, max_sequence, hidden_size] and "valid_mask" to
            [-1, num_window, max_sequence].
          labels: unused (labels are read from `features`).
          mode: a `tf.estimator.ModeKeys` value.
          params: unused.

        Returns:
          A `TPUEstimatorSpec` for the requested mode.
        """
        tf_logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf_logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))

        vectors = tf.reshape(features["vectors"], [
            -1, model_config.num_window, model_config.max_sequence,
            model_config.hidden_size
        ])
        valid_mask = tf.reshape(
            features["valid_mask"],
            [-1, model_config.num_window, model_config.max_sequence])
        label_ids = tf.reshape(features["label_ids"], [-1])

        if "is_real_example" in features:
            is_real_example = tf.cast(features["is_real_example"],
                                      dtype=tf.float32)
        else:
            is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        model = MultiEvidenceCombiner(config=model_config,
                                      is_training=is_training,
                                      vectors=vectors,
                                      valid_mask=valid_mask,
                                      scope=None)
        pooled = model.pooled_output
        if is_training:
            pooled = dropout(pooled, 0.1)

        logits = tf.keras.layers.Dense(config.num_classes,
                                       name="cls_dense")(pooled)
        loss_arr = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=label_ids)
        if "bias_loss" in special_flags:
            tf_logging.info("Using special_flags : bias_loss")
            loss_arr = reweight_zero(label_ids, loss_arr)
        loss = tf.reduce_mean(input_tensor=loss_arr)

        tvars = tf.compat.v1.trainable_variables()
        initialized_variable_names = {}
        scaffold_fn = None
        if config.init_checkpoint:
            initialized_variable_names, init_fn = get_init_fn(config, tvars)
            scaffold_fn = get_tpu_scaffold_or_init(init_fn, config.use_tpu)
        log_var_assignments(tvars, initialized_variable_names)

        TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec
        if mode == tf.estimator.ModeKeys.TRAIN:
            if "simple_optimizer" in special_flags:
                tf_logging.info("using simple optimizer")
                train_op = create_simple_optimizer(loss, config.learning_rate,
                                                   config.use_tpu)
            else:
                # With "ask_tvar" only the model's own variables are optimized;
                # otherwise None is passed and the optimizer picks its default.
                train_vars = (model.get_trainable_vars()
                              if "ask_tvar" in special_flags else None)
                train_op = optimization.create_optimizer_from_config(
                    loss, config, train_vars)
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           train_op=train_op,
                                           scaffold_fn=scaffold_fn)

        elif mode == tf.estimator.ModeKeys.EVAL:
            eval_metrics = (classification_metric_fn,
                            [logits, label_ids, is_real_example])
            # `mode` must be the ModeKeys value, not the model object.
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           eval_metrics=eval_metrics,
                                           scaffold_fn=scaffold_fn)
        else:
            predictions = {"logits": logits, "label_ids": label_ids}
            if override_prediction_fn is not None:
                predictions = override_prediction_fn(predictions, model)

            useful_inputs = ["data_id", "input_ids2"]
            for input_name in useful_inputs:
                if input_name in features:
                    predictions[input_name] = features[input_name]
            output_spec = TPUEstimatorSpec(
                mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)

        return output_spec
Ejemplo n.º 5
0
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator.

        A classifier over BERT-style inputs. In PREDICT mode it also returns
        the gradient of the class-1 probability with respect to the model's
        embedding output, summed over the last axis.

        Args:
          features: dict with "input_ids", "input_mask", "segment_ids",
            "label_ids" (read in every mode, including PREDICT) and optionally
            "is_real_example".
          labels: unused (labels are read from `features`).
          mode: a `tf.estimator.ModeKeys` value.
          params: unused.

        Returns:
          A `TPUEstimatorSpec` for the requested mode.
        """
        tf_logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf_logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        label_ids = tf.reshape(features["label_ids"], [-1])

        if "is_real_example" in features:
            is_real_example = tf.cast(features["is_real_example"],
                                      dtype=tf.float32)
        else:
            is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        model_kwargs = dict(
            config=bert_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        )
        if "feed_features" in special_flags:
            # Some model classes also consume the raw feature dict.
            model_kwargs["features"] = features
        model = model_class(**model_kwargs)

        if "new_pooling" in special_flags:
            pooled = mimic_pooling(model.get_sequence_output(),
                                   bert_config.hidden_size,
                                   bert_config.initializer_range)
        else:
            pooled = model.get_pooled_output()

        if train_config.checkpoint_type != "bert_nli" and train_config.use_old_logits:
            tf_logging.info("Use old version of logistic regression")
            # NOTE(review): unlike the branch below, no dropout is applied here.
            logits = tf.keras.layers.Dense(train_config.num_classes,
                                           name="cls_dense")(pooled)
        else:
            tf_logging.info("Use fixed version of logistic regression")
            # NOTE(review): the head size is hard-coded to 3 classes here, while
            # the branch above uses train_config.num_classes. This is presumably
            # kept at 3 to match 3-way (e.g. NLI) checkpoints -- confirm.
            output_weights = tf.compat.v1.get_variable(
                "output_weights", [3, bert_config.hidden_size],
                initializer=tf.compat.v1.truncated_normal_initializer(
                    stddev=0.02))
            output_bias = tf.compat.v1.get_variable(
                "output_bias", [3],
                initializer=tf.compat.v1.zeros_initializer())

            if is_training:
                pooled = dropout(pooled, 0.1)

            logits = tf.matmul(pooled, output_weights, transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)

        loss_arr = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=label_ids)
        if "bias_loss" in special_flags:
            tf_logging.info("Using special_flags : bias_loss")
            loss_arr = reweight_zero(label_ids, loss_arr)
        loss = tf.reduce_mean(input_tensor=loss_arr)

        tvars = tf.compat.v1.trainable_variables()
        initialized_variable_names = {}
        scaffold_fn = None
        if train_config.init_checkpoint:
            initialized_variable_names, init_fn = get_init_fn(
                train_config, tvars)
            scaffold_fn = get_tpu_scaffold_or_init(init_fn,
                                                   train_config.use_tpu)
        log_var_assignments(tvars, initialized_variable_names)

        TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec
        if mode == tf.estimator.ModeKeys.TRAIN:
            if "simple_optimizer" in special_flags:
                tf_logging.info("using simple optimizer")
                train_op = create_simple_optimizer(loss,
                                                   train_config.learning_rate,
                                                   train_config.use_tpu)
            else:
                train_op = optimization.create_optimizer_from_config(
                    loss, train_config)
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           train_op=train_op,
                                           scaffold_fn=scaffold_fn)

        elif mode == tf.estimator.ModeKeys.EVAL:
            eval_metrics = (classification_metric_fn,
                            [logits, label_ids, is_real_example])
            # `mode` must be the ModeKeys value, not the model object.
            output_spec = TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           eval_metrics=eval_metrics,
                                           scaffold_fn=scaffold_fn)
        else:
            probs = tf.nn.softmax(logits, axis=-1)
            # tf.gradients returns one gradient per target tensor; only the
            # embedding output is requested, so take the first element.
            gradient_list = tf.gradients(probs[:, 1], model.embedding_output)
            tf_logging.info("gradient list length: %d" % len(gradient_list))
            gradient = gradient_list[0]
            tf_logging.info("gradient shape: %s" % (gradient.shape,))
            # Sum over axis 2; presumably the hidden dimension of a
            # [batch, seq, hidden] embedding output -- TODO confirm.
            gradient = tf.reduce_sum(gradient, axis=2)
            predictions = {
                "input_ids": input_ids,
                "gradient": gradient,
                "labels": label_ids,
                "logits": logits
            }
            output_spec = TPUEstimatorSpec(
                mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)

        return output_spec
Ejemplo n.º 6
0
  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator.

    Encodes BERT-style inputs with `model_class`, computes classification
    logits, and trains with an approximate-NDCG ranking loss obtained from
    `make_loss_fn`.

    Args:
      features: dict with "input_ids", "input_mask", "segment_ids", outside
        PREDICT mode "label_ids", and optionally "is_real_example".
      labels: unused (labels are read from `features`).
      mode: a `tf.estimator.ModeKeys` value.
      params: unused.

    Returns:
      A `TPUEstimatorSpec` for the requested mode.

    Raises:
      Exception: if `init_checkpoint` is set but `checkpoint_type` is not one
        of the recognized values.
    """
    tf_logging.info("*** Features ***")
    for name in sorted(features.keys()):
        tf_logging.info("  name = %s, shape = %s" % (name, features[name].shape))

    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    if mode == tf.estimator.ModeKeys.PREDICT:
        # There are no labels at prediction time, but the loss below is built
        # unconditionally, so feed placeholder labels of matching size.
        label_ids = tf.ones([input_ids.shape[0]], dtype=tf.int32)
    else:
        label_ids = tf.reshape(features["label_ids"], [-1])

    if "is_real_example" in features:
        is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
    else:
        is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    model_kwargs = dict(
        config=bert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=train_config.use_one_hot_embeddings,
    )
    if "feed_features" in special_flags:
        # Some model classes also consume the raw feature dict.
        model_kwargs["features"] = features
    model = model_class(**model_kwargs)

    if "new_pooling" in special_flags:
        pooled = mimic_pooling(model.get_sequence_output(),
                               bert_config.hidden_size,
                               bert_config.initializer_range)
    else:
        pooled = model.get_pooled_output()

    if train_config.checkpoint_type != "bert_nli" and train_config.use_old_logits:
        tf_logging.info("Use old version of logistic regression")
        logits = tf.keras.layers.Dense(train_config.num_classes, name="cls_dense")(pooled)
    else:
        tf_logging.info("Use fixed version of logistic regression")
        output_weights = tf.compat.v1.get_variable(
            "output_weights", [3, bert_config.hidden_size],
            initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.02))
        output_bias = tf.compat.v1.get_variable(
            "output_bias", [3],
            initializer=tf.compat.v1.zeros_initializer())

        if is_training:
            pooled = dropout(pooled, 0.1)

        logits = tf.matmul(pooled, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)

    # TODO given topic_ids, reorder logits to [num_group, num_max_items]
    ndcg_ranking_loss = make_loss_fn(RankingLossKey.APPROX_NDCG_LOSS)
    one_hot_labels = tf.one_hot(label_ids, 3)
    # logits: [batch_size, num_classes]. Transposing makes each class one
    # "list" whose items are the examples in the batch.
    logits_t = tf.transpose(logits, [1, 0])
    one_hot_labels_t = tf.transpose(one_hot_labels, [1, 0])
    # NOTE(review): the scores are the one-hot labels plus logits scaled by
    # 1e-5, and this tensor is passed as the loss fn's FIRST argument, with
    # the one-hot labels second. If `make_loss_fn` is TF-Ranking's, the
    # returned fn takes (labels, logits, features), so the roles look
    # swapped -- confirm intent.
    fake_logit = logits_t * 1e-5 + one_hot_labels_t
    loss = ndcg_ranking_loss(fake_logit, one_hot_labels_t, {})
    tvars = tf.compat.v1.trainable_variables()

    initialized_variable_names = {}

    assignment_fn = None
    if train_config.checkpoint_type == "bert":
        assignment_fn = tlm.training.assignment_map.get_bert_assignment_map
    elif train_config.checkpoint_type == "v2":
        assignment_fn = tlm.training.assignment_map.assignment_map_v2_to_v2
    elif train_config.checkpoint_type == "bert_nli":
        assignment_fn = tlm.training.assignment_map.get_bert_nli_assignment_map
    elif train_config.checkpoint_type == "attention_bert":
        assignment_fn = tlm.training.assignment_map.bert_assignment_only_attention
    elif train_config.checkpoint_type == "attention_bert_v2":
        assignment_fn = tlm.training.assignment_map.assignment_map_v2_to_v2_only_attention
    elif train_config.checkpoint_type == "wo_attention_bert":
        assignment_fn = tlm.training.assignment_map.bert_assignment_wo_attention
    elif train_config.checkpoint_type == "as_is":
        assignment_fn = tlm.training.assignment_map.get_assignment_map_as_is
    elif train_config.init_checkpoint:
        # An unknown checkpoint type is only an error when there is a
        # checkpoint to load.
        raise Exception("init_checkpoint exists, but checkpoint_type is not specified")

    scaffold_fn = None
    if train_config.init_checkpoint:
        assignment_map, initialized_variable_names = assignment_fn(tvars, train_config.init_checkpoint)

        def init_fn():
            tf.compat.v1.train.init_from_checkpoint(train_config.init_checkpoint, assignment_map)

        scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)
    log_var_assignments(tvars, initialized_variable_names)

    TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec
    if mode == tf.estimator.ModeKeys.TRAIN:
        if "simple_optimizer" in special_flags:
            tf_logging.info("using simple optimizer")
            train_op = create_simple_optimizer(loss, train_config.learning_rate, train_config.use_tpu)
        else:
            train_op = optimization.create_optimizer_from_config(loss, train_config)
        output_spec = TPUEstimatorSpec(mode=mode, loss=loss, train_op=train_op, scaffold_fn=scaffold_fn)

    elif mode == tf.estimator.ModeKeys.EVAL:
        eval_metrics = (classification_metric_fn, [
            logits, label_ids, is_real_example
        ])
        # `mode` must be the ModeKeys value, not the model object.
        output_spec = TPUEstimatorSpec(mode=mode, loss=loss, eval_metrics=eval_metrics, scaffold_fn=scaffold_fn)
    else:
        predictions = {
            "input_ids": input_ids,
            "logits": logits
        }
        output_spec = TPUEstimatorSpec(
            mode=mode,
            predictions=predictions,
            scaffold_fn=scaffold_fn)

    return output_spec