def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator.

    Builds a classifier on top of the inputs returned by `iterate_over`
    (called with the query, the document and the document mask) and wraps the
    result in a `TPUEstimatorSpec` for the requested mode.

    Names not defined in this function (`model_config`, `train_config`,
    `model_class`, `special_flags`, `max_seq_length`, `query_len`,
    `total_doc_len`, `override_prediction_fn`, ...) are resolved from the
    surrounding scope.

    Args:
        features: Mapping of feature name to tensor. Reads "query", "doc",
            "doc_mask" and "data_id"; "label_ids" outside of PREDICT mode;
            optionally "is_real_example".
        labels: Unused.
        mode: One of `tf.estimator.ModeKeys`.
        params: Unused.

    Returns:
        A `TPUEstimatorSpec` for TRAIN, EVAL or PREDICT.
    """
    tf_logging.info("*** Features ***")
    for name in sorted(features.keys()):
        tf_logging.info(" name = %s, shape = %s" % (name, features[name].shape))

    query = features["query"]
    doc = features["doc"]
    doc_mask = features["doc_mask"]
    data_ids = features["data_id"]

    # NOTE(review): the 3 presumably reserves room for special tokens
    # ([CLS] and two [SEP]) -- confirm against `iterate_over`.
    segment_len = max_seq_length - query_len - 3
    step_size = model_config.step_size
    input_ids, input_mask, segment_ids, _ = iterate_over(
        query, doc, doc_mask, total_doc_len, segment_len, step_size)

    if mode == tf.estimator.ModeKeys.PREDICT:
        # No labels are fed at prediction time; use placeholder labels so the
        # loss below can still be constructed.
        label_ids = tf.ones([input_ids.shape[0]], dtype=tf.int32)
    else:
        label_ids = features["label_ids"]
        label_ids = tf.reshape(label_ids, [-1])

    if "is_real_example" in features:
        is_real_example = tf.cast(features["is_real_example"],
                                  dtype=tf.float32)
    else:
        is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    if "feed_features" in special_flags:
        model = model_class(
            config=model_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
            features=features,
        )
    else:
        model = model_class(
            config=model_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        )

    if "new_pooling" in special_flags:
        pooled = mimic_pooling(model.get_sequence_output(),
                               model_config.hidden_size,
                               model_config.initializer_range)
    else:
        pooled = model.get_pooled_output()

    if train_config.checkpoint_type != "bert_nli" and train_config.use_old_logits:
        tf_logging.info("Use old version of logistic regression")
        if is_training:
            pooled = dropout(pooled, 0.1)
        logits = tf.keras.layers.Dense(train_config.num_classes,
                                       name="cls_dense")(pooled)
    else:
        tf_logging.info("Use fixed version of logistic regression")
        output_weights = tf.compat.v1.get_variable(
            "output_weights",
            [train_config.num_classes, model_config.hidden_size],
            initializer=tf.compat.v1.truncated_normal_initializer(
                stddev=0.02))
        output_bias = tf.compat.v1.get_variable(
            "output_bias", [train_config.num_classes],
            initializer=tf.compat.v1.zeros_initializer())
        if is_training:
            pooled = dropout(pooled, 0.1)
        logits = tf.matmul(pooled, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)

    loss_arr = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=label_ids)
    if "bias_loss" in special_flags:
        tf_logging.info("Using special_flags : bias_loss")
        loss_arr = reweight_zero(label_ids, loss_arr)
    loss = tf.reduce_mean(input_tensor=loss_arr)

    tvars = tf.compat.v1.trainable_variables()
    initialized_variable_names = {}
    scaffold_fn = None
    if train_config.init_checkpoint:
        initialized_variable_names, init_fn = get_init_fn(
            train_config, tvars)
        scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)
    log_var_assignments(tvars, initialized_variable_names)

    TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec
    if mode == tf.estimator.ModeKeys.TRAIN:
        if "simple_optimizer" in special_flags:
            tf_logging.info("using simple optimizer")
            train_op = create_simple_optimizer(loss,
                                               train_config.learning_rate,
                                               train_config.use_tpu)
        else:
            # With "ask_tvar" the model chooses which variables are trained;
            # otherwise None is passed and the optimizer helper decides.
            if "ask_tvar" in special_flags:
                tvars = model.get_trainable_vars()
            else:
                tvars = None
            train_op = optimization.create_optimizer_from_config(
                loss, train_config, tvars)
        output_spec = TPUEstimatorSpec(mode=mode,
                                       loss=loss,
                                       train_op=train_op,
                                       scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.EVAL:
        eval_metrics = (classification_metric_fn,
                        [logits, label_ids, is_real_example])
        # Pass the estimator mode here, not the `model` object.
        output_spec = TPUEstimatorSpec(mode=mode,
                                       loss=loss,
                                       eval_metrics=eval_metrics,
                                       scaffold_fn=scaffold_fn)
    else:
        predictions = {
            "logits": logits,
            "doc": doc,
            "data_ids": data_ids,
        }
        # Forward identifying inputs, when present, so predictions can be
        # matched back to their source examples.
        useful_inputs = ["data_id", "input_ids2", "data_ids"]
        for input_name in useful_inputs:
            if input_name in features:
                predictions[input_name] = features[input_name]
        if override_prediction_fn is not None:
            predictions = override_prediction_fn(predictions, model)
        output_spec = TPUEstimatorSpec(mode=mode,
                                       predictions=predictions,
                                       scaffold_fn=scaffold_fn)
    return output_spec
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator.

    Encodes a pair of queries and a pair of documents with two separately
    scoped `model_class` instances ("query" and "document"), scores them with
    a matrix product of the pooled outputs and trains with a hinge loss.

    Names not defined in this function (`model_config`, `train_config`,
    `model_class`, `special_flags`, ...) are resolved from the surrounding
    scope.

    Args:
        features: Mapping of feature name to tensor. Reads
            "{q,d}_input_{ids,mask}_{1,2}" and "label_ids"; optionally
            "is_real_example".
        labels: Unused.
        mode: One of `tf.estimator.ModeKeys`.
        params: Unused.

    Returns:
        A `TPUEstimatorSpec` for TRAIN, EVAL or PREDICT.
    """
    mode_keys = tf.estimator.ModeKeys
    spec_cls = tf.compat.v1.estimator.tpu.TPUEstimatorSpec

    tf_logging.info("*** Features ***")
    for feature_name in sorted(features.keys()):
        tf_logging.info(" name = %s, shape = %s"
                        % (feature_name, features[feature_name].shape))

    def stack_pair(prefix):
        """Stack the "_1" and "_2" variants of a feature along a new axis 0."""
        return tf.stack([features[prefix + "_1"], features[prefix + "_2"]],
                        axis=0)

    # Query side.
    q_input_ids = stack_pair("q_input_ids")
    q_input_mask = stack_pair("q_input_mask")
    q_segment_ids = tf.zeros_like(q_input_ids, tf.int32)
    # Document side.
    d_input_ids = stack_pair("d_input_ids")
    d_input_mask = stack_pair("d_input_mask")
    d_segment_ids = tf.zeros_like(d_input_ids, tf.int32)

    label_ids = features["label_ids"]
    is_training = (mode == mode_keys.TRAIN)
    if "is_real_example" not in features:
        is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)
    else:
        is_real_example = tf.cast(features["is_real_example"],
                                  dtype=tf.float32)

    def build_tower(scope_name, ids, mask, type_ids):
        """Instantiate `model_class` under its own variable scope."""
        with tf.compat.v1.variable_scope(scope_name):
            return model_class(
                config=model_config,
                is_training=is_training,
                input_ids=ids,
                input_mask=mask,
                token_type_ids=type_ids,
                use_one_hot_embeddings=train_config.use_one_hot_embeddings,
            )

    model_q = build_tower("query", q_input_ids, q_input_mask, q_segment_ids)
    model_d = build_tower("document", d_input_ids, d_input_mask,
                          d_segment_ids)

    pooled_q = model_q.get_pooled_output()
    pooled_d = model_d.get_pooled_output()
    logits = tf.matmul(pooled_q, pooled_d, transpose_b=True)

    # Hinge loss: map labels from {0, 1} to {-1, +1} first.
    signed_labels = tf.cast(label_ids, tf.float32) * 2 - 1
    hinge = tf.maximum(1.0 - logits * signed_labels, 0)
    loss = tf.reduce_mean(hinge)
    pred = tf.cast(logits > 0, tf.int32)

    trainable = tf.compat.v1.trainable_variables()
    initialized_variable_names = {}
    scaffold_fn = None
    if train_config.init_checkpoint:
        initialized_variable_names, init_fn = get_init_fn(
            train_config, trainable)
        scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)
    log_var_assignments(trainable, initialized_variable_names)

    if mode == mode_keys.TRAIN:
        if "simple_optimizer" in special_flags:
            tf_logging.info("using simple optimizer")
            train_op = create_simple_optimizer(loss,
                                               train_config.learning_rate,
                                               train_config.use_tpu)
        else:
            train_op = optimization.create_optimizer_from_config(
                loss, train_config, trainable)
        return spec_cls(mode=mode,
                        loss=loss,
                        train_op=train_op,
                        scaffold_fn=scaffold_fn)

    if mode == mode_keys.EVAL:
        metric_args = [pred, label_ids, is_real_example]
        return spec_cls(mode=mode,
                        loss=loss,
                        eval_metrics=(classification_metric_fn, metric_args),
                        scaffold_fn=scaffold_fn)

    # Any other mode is treated as prediction.
    predictions = {
        "q_input_ids": q_input_ids,
        "d_input_ids": d_input_ids,
        "score": logits
    }
    for passthrough in ("data_id", "input_ids2", "data_ids"):
        if passthrough in features:
            predictions[passthrough] = features[passthrough]
    return spec_cls(mode=mode,
                    predictions=predictions,
                    scaffold_fn=scaffold_fn)
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator.

    Dual-encoder model: one query per example is scored against `num_docs`
    documents. The query and the documents are encoded by two separately
    scoped `model_class` instances; the scores are reduced to logits and a
    loss by `hinge_all`, `sigmoid_all` or `hinge_max`, selected through
    `special_flags`.

    Names not defined in this function (`model_config`, `train_config`,
    `model_class`, `special_flags`, `dual_model_prefix1`,
    `dual_model_prefix2`, ...) are resolved from the surrounding scope.

    Args:
        features: Mapping of feature name to tensor. Reads "q_input_ids",
            "q_input_mask", "d_input_ids", "d_input_mask" and "label_ids";
            optionally "is_real_example".
        labels: Unused.
        mode: One of `tf.estimator.ModeKeys`.
        params: Unused.

    Returns:
        A `TPUEstimatorSpec` for TRAIN, EVAL or PREDICT.
    """
    tf_logging.info("*** Features ***")
    for name in sorted(features.keys()):
        tf_logging.info(" name = %s, shape = %s" % (name, features[name].shape))

    q_input_ids = features["q_input_ids"]
    q_input_mask = features["q_input_mask"]
    d_input_ids = features["d_input_ids"]
    d_input_mask = features["d_input_mask"]

    input_shape = get_shape_list(q_input_ids, expected_rank=2)
    batch_size = input_shape[0]

    doc_length = model_config.max_doc_length
    num_docs = model_config.num_docs

    # Documents are unpacked to [-1, num_docs, doc_length] and then flattened
    # to [-1, doc_length] so each document is encoded as its own sequence.
    d_input_ids_unpacked = tf.reshape(d_input_ids,
                                      [-1, num_docs, doc_length])
    d_input_mask_unpacked = tf.reshape(d_input_mask,
                                       [-1, num_docs, doc_length])
    d_input_ids_flat = tf.reshape(d_input_ids_unpacked, [-1, doc_length])
    d_input_mask_flat = tf.reshape(d_input_mask_unpacked, [-1, doc_length])

    q_segment_ids = tf.zeros_like(q_input_ids, tf.int32)
    d_segment_ids = tf.zeros_like(d_input_ids_flat, tf.int32)

    label_ids = features["label_ids"]
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    if "is_real_example" in features:
        is_real_example = tf.cast(features["is_real_example"],
                                  dtype=tf.float32)
    else:
        is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

    with tf.compat.v1.variable_scope(dual_model_prefix1):
        # Each tower gets its own copy of the config so the sequence length
        # can differ between query and document. The copy (not the shared
        # `model_config`) must be what is handed to the model.
        q_model_config = copy.deepcopy(model_config)
        q_model_config.max_seq_length = model_config.max_sent_length
        model_q = model_class(
            config=q_model_config,
            is_training=is_training,
            input_ids=q_input_ids,
            input_mask=q_input_mask,
            token_type_ids=q_segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        )

    with tf.compat.v1.variable_scope(dual_model_prefix2):
        d_model_config = copy.deepcopy(model_config)
        d_model_config.max_seq_length = model_config.max_doc_length
        model_d = model_class(
            config=d_model_config,
            is_training=is_training,
            input_ids=d_input_ids_flat,
            input_mask=d_input_mask_flat,
            token_type_ids=d_segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        )

    # Shape notes below assume the pooled outputs are rank 2
    # ([rows, vector_size]) -- TODO confirm against `model_class`.
    pooled_q = model_q.get_pooled_output()  # [batch, vector_size]
    pooled_d_flat = model_d.get_pooled_output()  # [batch * num_docs, vector_size]
    pooled_d = tf.reshape(pooled_d_flat, [batch_size, num_docs, -1])
    pooled_q_t = tf.expand_dims(pooled_q, 1)  # [batch, 1, vector_size]
    pooled_d_t = tf.transpose(pooled_d, [0, 2, 1])  # [batch, vector_size, num_docs]
    all_logits = tf.matmul(pooled_q_t, pooled_d_t)  # [batch, 1, num_docs]

    if "hinge_all" in special_flags:
        apply_loss_modeling = hinge_all
    elif "sigmoid_all" in special_flags:
        apply_loss_modeling = sigmoid_all
    else:
        apply_loss_modeling = hinge_max
    logits, loss = apply_loss_modeling(all_logits, label_ids)
    pred = tf.cast(logits > 0, tf.int32)

    tvars = tf.compat.v1.trainable_variables()
    initialized_variable_names = {}
    scaffold_fn = None
    if train_config.init_checkpoint:
        initialized_variable_names, init_fn = get_init_fn(
            train_config, tvars)
        scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)
    log_var_assignments(tvars, initialized_variable_names)

    TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec
    if mode == tf.estimator.ModeKeys.TRAIN:
        if "simple_optimizer" in special_flags:
            tf_logging.info("using simple optimizer")
            train_op = create_simple_optimizer(loss,
                                               train_config.learning_rate,
                                               train_config.use_tpu)
        else:
            train_op = optimization.create_optimizer_from_config(
                loss, train_config, tvars)
        output_spec = TPUEstimatorSpec(mode=mode,
                                       loss=loss,
                                       train_op=train_op,
                                       scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.EVAL:
        eval_metrics = (classification_metric_fn,
                        [pred, label_ids, is_real_example])
        output_spec = TPUEstimatorSpec(mode=mode,
                                       loss=loss,
                                       eval_metrics=eval_metrics,
                                       scaffold_fn=scaffold_fn)
    else:
        predictions = {
            "q_input_ids": q_input_ids,
            "d_input_ids": d_input_ids,
            "logits": logits
        }
        # Forward identifying inputs, when present, so predictions can be
        # matched back to their source examples.
        useful_inputs = ["data_id", "input_ids2", "data_ids"]
        for input_name in useful_inputs:
            if input_name in features:
                predictions[input_name] = features[input_name]
        output_spec = TPUEstimatorSpec(mode=mode,
                                       predictions=predictions,
                                       scaffold_fn=scaffold_fn)
    return output_spec
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator.

    Classifies pre-computed vectors: the "vectors" and "valid_mask" features
    are reshaped into windows, combined by `MultiEvidenceCombiner`, and the
    pooled output is fed to a dense classification layer.

    Names not defined in this function (`model_config`, `config`,
    `special_flags`, `override_prediction_fn`, ...) are resolved from the
    surrounding scope.

    Args:
        features: Mapping of feature name to tensor. Reads "vectors",
            "valid_mask" and "label_ids"; optionally "is_real_example".
        labels: Unused.
        mode: One of `tf.estimator.ModeKeys`.
        params: Unused.

    Returns:
        A `TPUEstimatorSpec` for TRAIN, EVAL or PREDICT.
    """
    tf_logging.info("*** Features ***")
    for name in sorted(features.keys()):
        tf_logging.info(" name = %s, shape = %s" % (name, features[name].shape))

    vectors = features["vectors"]  # [batch_size, max_unit, num_hidden]
    valid_mask = features["valid_mask"]
    label_ids = features["label_ids"]

    # Regroup the flat unit axis into (num_window, max_sequence).
    vectors = tf.reshape(vectors, [
        -1, model_config.num_window, model_config.max_sequence,
        model_config.hidden_size
    ])
    valid_mask = tf.reshape(
        valid_mask, [-1, model_config.num_window, model_config.max_sequence])
    label_ids = tf.reshape(label_ids, [-1])

    if "is_real_example" in features:
        is_real_example = tf.cast(features["is_real_example"],
                                  dtype=tf.float32)
    else:
        is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    model = MultiEvidenceCombiner(config=model_config,
                                  is_training=is_training,
                                  vectors=vectors,
                                  valid_mask=valid_mask,
                                  scope=None)
    pooled = model.pooled_output
    if is_training:
        pooled = dropout(pooled, 0.1)

    logits = tf.keras.layers.Dense(config.num_classes,
                                   name="cls_dense")(pooled)
    loss_arr = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=label_ids)
    if "bias_loss" in special_flags:
        tf_logging.info("Using special_flags : bias_loss")
        loss_arr = reweight_zero(label_ids, loss_arr)
    loss = tf.reduce_mean(input_tensor=loss_arr)

    tvars = tf.compat.v1.trainable_variables()
    initialized_variable_names = {}
    scaffold_fn = None
    if config.init_checkpoint:
        initialized_variable_names, init_fn = get_init_fn(config, tvars)
        scaffold_fn = get_tpu_scaffold_or_init(init_fn, config.use_tpu)
    log_var_assignments(tvars, initialized_variable_names)

    TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec
    if mode == tf.estimator.ModeKeys.TRAIN:
        if "simple_optimizer" in special_flags:
            tf_logging.info("using simple optimizer")
            train_op = create_simple_optimizer(loss, config.learning_rate,
                                               config.use_tpu)
        else:
            # With "ask_tvar" the model chooses which variables are trained;
            # otherwise None is passed and the optimizer helper decides.
            if "ask_tvar" in special_flags:
                tvars = model.get_trainable_vars()
            else:
                tvars = None
            train_op = optimization.create_optimizer_from_config(
                loss, config, tvars)
        output_spec = TPUEstimatorSpec(mode=mode,
                                       loss=loss,
                                       train_op=train_op,
                                       scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.EVAL:
        eval_metrics = (classification_metric_fn,
                        [logits, label_ids, is_real_example])
        # Pass the estimator mode here, not the `model` object.
        output_spec = TPUEstimatorSpec(mode=mode,
                                       loss=loss,
                                       eval_metrics=eval_metrics,
                                       scaffold_fn=scaffold_fn)
    else:
        predictions = {"logits": logits, "label_ids": label_ids}
        if override_prediction_fn is not None:
            predictions = override_prediction_fn(predictions, model)
        # Forward identifying inputs, when present, so predictions can be
        # matched back to their source examples.
        useful_inputs = ["data_id", "input_ids2"]
        for input_name in useful_inputs:
            if input_name in features:
                predictions[input_name] = features[input_name]
        output_spec = TPUEstimatorSpec(mode=mode,
                                       predictions=predictions,
                                       scaffold_fn=scaffold_fn)
    return output_spec
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator.

    Standard sequence classifier. In prediction mode it additionally returns
    the gradient of the class-1 probability with respect to the model's
    embedding output, summed over the last axis.

    Names not defined in this function (`bert_config`, `train_config`,
    `model_class`, `special_flags`, ...) are resolved from the surrounding
    scope.

    Args:
        features: Mapping of feature name to tensor. Reads "input_ids",
            "input_mask", "segment_ids" and "label_ids"; optionally
            "is_real_example".
        labels: Unused.
        mode: One of `tf.estimator.ModeKeys`.
        params: Unused.

    Returns:
        A `TPUEstimatorSpec` for TRAIN, EVAL or PREDICT.
    """
    tf_logging.info("*** Features ***")
    for name in sorted(features.keys()):
        tf_logging.info(" name = %s, shape = %s" % (name, features[name].shape))

    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    label_ids = features["label_ids"]
    label_ids = tf.reshape(label_ids, [-1])

    if "is_real_example" in features:
        is_real_example = tf.cast(features["is_real_example"],
                                  dtype=tf.float32)
    else:
        is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    if "feed_features" in special_flags:
        model = model_class(
            config=bert_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
            features=features,
        )
    else:
        model = model_class(
            config=bert_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        )

    if "new_pooling" in special_flags:
        pooled = mimic_pooling(model.get_sequence_output(),
                               bert_config.hidden_size,
                               bert_config.initializer_range)
    else:
        pooled = model.get_pooled_output()

    if train_config.checkpoint_type != "bert_nli" and train_config.use_old_logits:
        tf_logging.info("Use old version of logistic regression")
        logits = tf.keras.layers.Dense(train_config.num_classes,
                                       name="cls_dense")(pooled)
    else:
        tf_logging.info("Use fixed version of logistic regression")
        # NOTE(review): the class count is hard-coded to 3 on this path
        # rather than read from train_config.num_classes -- confirm intended.
        output_weights = tf.compat.v1.get_variable(
            "output_weights", [3, bert_config.hidden_size],
            initializer=tf.compat.v1.truncated_normal_initializer(
                stddev=0.02))
        output_bias = tf.compat.v1.get_variable(
            "output_bias", [3],
            initializer=tf.compat.v1.zeros_initializer())
        if is_training:
            pooled = dropout(pooled, 0.1)
        logits = tf.matmul(pooled, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)

    loss_arr = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=label_ids)
    if "bias_loss" in special_flags:
        tf_logging.info("Using special_flags : bias_loss")
        loss_arr = reweight_zero(label_ids, loss_arr)
    loss = tf.reduce_mean(input_tensor=loss_arr)

    tvars = tf.compat.v1.trainable_variables()
    initialized_variable_names = {}
    scaffold_fn = None
    if train_config.init_checkpoint:
        initialized_variable_names, init_fn = get_init_fn(
            train_config, tvars)
        scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)
    log_var_assignments(tvars, initialized_variable_names)

    TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec
    if mode == tf.estimator.ModeKeys.TRAIN:
        if "simple_optimizer" in special_flags:
            tf_logging.info("using simple optimizer")
            train_op = create_simple_optimizer(loss,
                                               train_config.learning_rate,
                                               train_config.use_tpu)
        else:
            train_op = optimization.create_optimizer_from_config(
                loss, train_config)
        output_spec = TPUEstimatorSpec(mode=mode,
                                       loss=loss,
                                       train_op=train_op,
                                       scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.EVAL:
        eval_metrics = (classification_metric_fn,
                        [logits, label_ids, is_real_example])
        # Pass the estimator mode here, not the `model` object.
        output_spec = TPUEstimatorSpec(mode=mode,
                                       loss=loss,
                                       eval_metrics=eval_metrics,
                                       scaffold_fn=scaffold_fn)
    else:
        # Gradient of the class-1 probability w.r.t. the embedding output,
        # reduced over axis 2.
        probs = tf.nn.softmax(logits, axis=-1)
        gradient_list = tf.gradients(probs[:, 1], model.embedding_output)
        print(len(gradient_list))
        gradient = gradient_list[0]
        print(gradient.shape)
        gradient = tf.reduce_sum(gradient, axis=2)
        predictions = {
            "input_ids": input_ids,
            "gradient": gradient,
            "labels": label_ids,
            "logits": logits
        }
        output_spec = TPUEstimatorSpec(mode=mode,
                                       predictions=predictions,
                                       scaffold_fn=scaffold_fn)
    return output_spec
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator.

    Sequence classifier whose training loss is computed by
    `ndcg_rankin_loss` on transposed logits mixed with the one-hot labels.
    Checkpoint initialisation picks an assignment-map function based on
    `train_config.checkpoint_type`.

    Names not defined in this function (`bert_config`, `train_config`,
    `model_class`, `special_flags`, `ndcg_rankin_loss`, `tlm`, ...) are
    resolved from the surrounding scope.

    Args:
        features: Mapping of feature name to tensor. Reads "input_ids",
            "input_mask" and "segment_ids"; "label_ids" outside of PREDICT
            mode; optionally "is_real_example".
        labels: Unused.
        mode: One of `tf.estimator.ModeKeys`.
        params: Unused.

    Returns:
        A `TPUEstimatorSpec` for TRAIN, EVAL or PREDICT.

    Raises:
        Exception: If `train_config.init_checkpoint` is set but
            `train_config.checkpoint_type` is not one of the known types.
    """
    tf_logging.info("*** Features ***")
    for name in sorted(features.keys()):
        tf_logging.info(" name = %s, shape = %s" % (name, features[name].shape))

    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    if mode == tf.estimator.ModeKeys.PREDICT:
        # No labels are fed at prediction time; use placeholder labels so the
        # loss below can still be constructed.
        label_ids = tf.ones([input_ids.shape[0]], dtype=tf.int32)
    else:
        label_ids = features["label_ids"]
        label_ids = tf.reshape(label_ids, [-1])

    if "is_real_example" in features:
        is_real_example = tf.cast(features["is_real_example"],
                                  dtype=tf.float32)
    else:
        is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    if "feed_features" in special_flags:
        model = model_class(
            config=bert_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
            features=features,
        )
    else:
        model = model_class(
            config=bert_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        )

    if "new_pooling" in special_flags:
        pooled = mimic_pooling(model.get_sequence_output(),
                               bert_config.hidden_size,
                               bert_config.initializer_range)
    else:
        pooled = model.get_pooled_output()

    if train_config.checkpoint_type != "bert_nli" and train_config.use_old_logits:
        tf_logging.info("Use old version of logistic regression")
        logits = tf.keras.layers.Dense(train_config.num_classes,
                                       name="cls_dense")(pooled)
    else:
        tf_logging.info("Use fixed version of logistic regression")
        output_weights = tf.compat.v1.get_variable(
            "output_weights", [3, bert_config.hidden_size],
            initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.02)
        )
        output_bias = tf.compat.v1.get_variable(
            "output_bias", [3],
            initializer=tf.compat.v1.zeros_initializer()
        )
        if is_training:
            pooled = dropout(pooled, 0.1)
        logits = tf.matmul(pooled, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)

    # TODO given topic_ids, reorder logits to [num_group, num_max_items]
    # ndcg_rankin_loss = make_loss_fn(RankingLossKey.APPROX_NDCG_LOSS)
    # NOTE(review): with the assignment above commented out,
    # `ndcg_rankin_loss` must be provided by the surrounding scope -- confirm.
    one_hot_labels = tf.one_hot(label_ids, 3)
    # logits ; [batch_size, num_classes]
    logits_t = tf.transpose(logits, [1, 0])
    one_hot_labels_t = tf.transpose(one_hot_labels, [1, 0])
    # The scores handed to the loss are the labels plus a 1e-5-scaled copy
    # of the logits.
    fake_logit = logits_t * 1e-5 + one_hot_labels_t
    loss_arr = ndcg_rankin_loss(fake_logit, one_hot_labels_t, {})
    loss = loss_arr

    tvars = tf.compat.v1.trainable_variables()
    initialized_variable_names = {}

    if train_config.checkpoint_type == "bert":
        assignment_fn = tlm.training.assignment_map.get_bert_assignment_map
    elif train_config.checkpoint_type == "v2":
        assignment_fn = tlm.training.assignment_map.assignment_map_v2_to_v2
    elif train_config.checkpoint_type == "bert_nli":
        assignment_fn = tlm.training.assignment_map.get_bert_nli_assignment_map
    elif train_config.checkpoint_type == "attention_bert":
        assignment_fn = tlm.training.assignment_map.bert_assignment_only_attention
    elif train_config.checkpoint_type == "attention_bert_v2":
        assignment_fn = tlm.training.assignment_map.assignment_map_v2_to_v2_only_attention
    elif train_config.checkpoint_type == "wo_attention_bert":
        assignment_fn = tlm.training.assignment_map.bert_assignment_wo_attention
    elif train_config.checkpoint_type == "as_is":
        assignment_fn = tlm.training.assignment_map.get_assignment_map_as_is
    elif train_config.init_checkpoint:
        # An unknown checkpoint type only matters when there is a checkpoint
        # to load; without one, `assignment_fn` is never used.
        raise Exception("init_checkpoint exists, but checkpoint_type is not specified")

    scaffold_fn = None
    if train_config.init_checkpoint:
        assignment_map, initialized_variable_names = assignment_fn(
            tvars, train_config.init_checkpoint)

        def init_fn():
            tf.compat.v1.train.init_from_checkpoint(
                train_config.init_checkpoint, assignment_map)

        scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)
    log_var_assignments(tvars, initialized_variable_names)

    TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec
    if mode == tf.estimator.ModeKeys.TRAIN:
        if "simple_optimizer" in special_flags:
            tf_logging.info("using simple optimizer")
            train_op = create_simple_optimizer(loss,
                                               train_config.learning_rate,
                                               train_config.use_tpu)
        else:
            train_op = optimization.create_optimizer_from_config(
                loss, train_config)
        output_spec = TPUEstimatorSpec(mode=mode,
                                       loss=loss,
                                       train_op=train_op,
                                       scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.EVAL:
        eval_metrics = (classification_metric_fn, [
            logits, label_ids, is_real_example
        ])
        # Pass the estimator mode here, not the `model` object.
        output_spec = TPUEstimatorSpec(mode=mode,
                                       loss=loss,
                                       eval_metrics=eval_metrics,
                                       scaffold_fn=scaffold_fn)
    else:
        predictions = {
            "input_ids": input_ids,
            "logits": logits
        }
        output_spec = TPUEstimatorSpec(mode=mode,
                                       predictions=predictions,
                                       scaffold_fn=scaffold_fn)
    return output_spec