def sub_metric_fn(
        masked_lm_example_loss,
        masked_lm_log_probs,
        masked_lm_ids,
        masked_lm_weights,
):
    """Computes the loss and accuracy of the model.

    Every input is flattened so that each masked position becomes one entry.
    The predicted token is the argmax over the last axis of
    ``masked_lm_log_probs``.

    Returns:
      dict with streaming (tf.compat.v1.metrics) ``masked_lm_accuracy`` and
      ``masked_lm_loss`` (weighted mean of ``masked_lm_example_loss``); both
      are weighted by ``masked_lm_weights``.
    """
    vocab_dim = masked_lm_log_probs.shape[-1]
    flat_log_probs = tf.reshape(masked_lm_log_probs, [-1, vocab_dim])
    flat_predictions = tf.argmax(input=flat_log_probs, axis=-1, output_type=tf.int32)
    flat_example_loss = tf.reshape(masked_lm_example_loss, [-1])
    flat_ids = tf.reshape(masked_lm_ids, [-1])
    flat_weights = tf.reshape(masked_lm_weights, [-1])

    # Dict literal values are evaluated in order: accuracy first, then mean.
    return {
        "masked_lm_accuracy": tf.compat.v1.metrics.accuracy(
            labels=flat_ids,
            predictions=flat_predictions,
            weights=flat_weights),
        "masked_lm_loss": tf.compat.v1.metrics.mean(
            values=flat_example_loss,
            weights=flat_weights),
    }
def scatter_multiple(input_ids, indice, update_vals):
    """Overwrite per-row positions of a 2-D tensor with new values.

    ``indice`` holds column positions relative to each row; presumably it is
    [batch_size, k] (k positions per row) — TODO confirm against callers.
    Row offsets are added so that the positions index the flattened tensor.
    ``update_vals`` is flattened and must provide one value per position.

    Returns:
      Tensor of shape [batch_size, seq_length] with the updates applied.
    """
    dims = get_shape_list2(input_ids)
    batch_size, seq_length = dims[0], dims[1]

    # [batch_size, 1]: start offset of each row in the flattened tensor.
    row_starts = tf.reshape(tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
    flat_positions = tf.reshape(indice + row_starts, [-1, 1])

    flat_tensor = tf.reshape(input_ids, [batch_size * seq_length])
    flat_updates = tf.reshape(update_vals, [-1])
    scattered = tf.tensor_scatter_nd_update(flat_tensor, flat_positions, flat_updates)
    return tf.reshape(scattered, [batch_size, seq_length])
def delete_tokens(input_ids, n_trial, shift):
    """Build ``n_trial`` copies of ``input_ids``, each with one block written to MASK_ID.

    Trial ``i`` targets positions ``[shift + i * n_block_size,
    shift + (i + 1) * n_block_size)``. With ``n_block_size == 1`` that is the
    single position ``shift + i``. The values are written by
    ``scatter_with_batch(..., MASK_ID)``.

    Args:
      input_ids: rank-2 int tensor, presumably [batch_size, seq_length].
      n_trial: number of copies / target blocks.
      shift: first position targeted.

    Returns:
      Tensor with ``n_trial * batch_size`` rows in trial-major order: row
      ``i * batch_size + b`` is example ``b`` with trial ``i``'s block masked.
    """
    n_block_size = 1
    # Python list [n_trial, n_block_size]: row i lists the positions of trial i.
    delete_location = [
        list(range(shift + i * n_block_size, shift + (i + 1) * n_block_size))
        for i in range(n_trial)
    ]

    batch_size, _ = get_shape_list2(input_ids)
    # [n_trial, n_block_size]
    delete_location = tf.constant(delete_location, tf.int32)
    # [1, n_trial, n_block_size]
    delete_location = tf.expand_dims(delete_location, 0)
    # [batch_size, n_trial, n_block_size]
    delete_location = tf.tile(delete_location, [batch_size, 1, 1])
    # [n_trial, batch_size, n_block_size]
    delete_location = tf.transpose(delete_location, [1, 0, 2])
    # [n_trial * batch_size, n_block_size]; trial-major so that row r matches
    # row r of the tiled input below.
    delete_location = tf.reshape(delete_location, [batch_size * n_trial, -1])

    # tf.tile along axis 0 stacks the whole batch n_trial times (trial-major).
    n_input_ids = tf.tile(input_ids, [n_trial, 1])
    masked_input_ids = scatter_with_batch(n_input_ids, delete_location, MASK_ID)
    return masked_input_ids
def compute_unreduced_loss(labels, logits, alpha=10.0):
    """Compute a per-list approximate-NDCG cost and a per-list validity weight.

    The cost is ``-DCG * inverse_max_dcg(labels)``. DCG uses gains
    ``2 ** label - 1``, discounts ``1 / log1p(rank)``, and ranks from
    ``utils.approx_ranks(logits, alpha=alpha)``. The structure appears to
    mirror TF-Ranking's ApproxNDCG loss.

    Args:
      labels: relevance labels; rank 2 is assumed (reduced over axis 1),
        presumably [batch, list_size]. Entries rejected by
        ``utils.is_label_valid`` are treated as padding.
      logits: scores with the same shape as ``labels``.
      alpha: smoothing parameter forwarded to ``utils.approx_ranks``.
        Defaults to 10.0, the value this function previously hard-coded.

    Returns:
      ``(cost, weights)``. ``cost`` is [batch, 1]. ``weights`` is [batch, 1]
      and holds 1.0 for lists whose (valid) label sum is positive, 0.0
      otherwise.
    """
    is_valid = utils.is_label_valid(labels)
    labels = tf.compat.v1.where(is_valid, labels, tf.zeros_like(labels))
    # Padding items get a score far below the list minimum so they rank last.
    logits = tf.compat.v1.where(
        is_valid, logits, -1e3 * tf.ones_like(logits) +
        tf.reduce_min(input_tensor=logits, axis=-1, keepdims=True))

    label_sum = tf.reduce_sum(input_tensor=labels, axis=1, keepdims=True)
    # [batch]: lists with at least some positive relevance.
    nonzero_mask = tf.greater(tf.reshape(label_sum, [-1]), 0.0)
    # With a rank-1 condition, tf.compat.v1.where selects whole rows: all-zero
    # lists get uniform _EPSILON labels (presumably to keep inverse_max_dcg
    # finite). Their cost is zero-weighted through the returned mask.
    labels = tf.compat.v1.where(nonzero_mask, labels,
                                _EPSILON * tf.ones_like(labels))
    gains = tf.pow(2., tf.cast(labels, dtype=tf.float32)) - 1.
    ranks = utils.approx_ranks(logits, alpha=alpha)
    discounts = 1. / tf.math.log1p(ranks)
    dcg = tf.reduce_sum(input_tensor=gains * discounts, axis=-1, keepdims=True)
    cost = -dcg * utils.inverse_max_dcg(labels)
    return cost, tf.reshape(tf.cast(nonzero_mask, dtype=tf.float32), [-1, 1])
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator.

    Builds ``model_class`` from the feature dict, takes its logits and its
    loss on ``features["label_ids"]``, and returns a TPUEstimatorSpec for the
    current mode:
      * TRAIN: optimizer built by ``optimization.create_optimizer_from_config``.
      * EVAL: ``classification_metric_fn`` over (logits, label_ids,
        is_real_example).
      * PREDICT: label_ids, logits and, if present, ``data_id``.
    """
    tf_logging.info("*** Features ***")
    for feature_name in sorted(features):
        tf_logging.info(" name = %s, shape = %s" % (feature_name, features[feature_name].shape))

    label_ids = tf.reshape(features["label_ids"], [-1])
    if "is_real_example" in features:
        is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
    else:
        # No padding marker: every example counts.
        is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    model = model_class(
        config=model_config,
        is_training=is_training,
        use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        features=features,
    )
    logits = model.get_logits()
    loss = model.get_loss(label_ids)

    tvars = tf.compat.v1.trainable_variables()
    initialized_variable_names = {}
    scaffold_fn = None
    if train_config.init_checkpoint:
        initialized_variable_names, init_fn = get_init_fn(train_config, tvars)
        scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)
    log_var_assignments(tvars, initialized_variable_names)

    TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec

    if mode == tf.estimator.ModeKeys.TRAIN:
        # tvars is deliberately passed as None here (unchanged behavior).
        train_op = optimization.create_optimizer_from_config(loss, train_config, None)
        return TPUEstimatorSpec(mode=mode, loss=loss, train_op=train_op, scaffold_fn=scaffold_fn)

    if mode == tf.estimator.ModeKeys.EVAL:
        eval_metrics = (classification_metric_fn, [logits, label_ids, is_real_example])
        return TPUEstimatorSpec(mode=mode, loss=loss, eval_metrics=eval_metrics, scaffold_fn=scaffold_fn)

    predictions = {"label_ids": label_ids, "logits": logits}
    if "data_id" in features:
        predictions['data_id'] = features['data_id']
    return TPUEstimatorSpec(mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
def one_by_one_masking(input_ids, input_masks, mask_token, n_trial):
    """Make ``n_trial`` copies of each sequence, masking a different position in each.

    Candidate positions are ordered per example by ``tf.argsort`` of a
    position index that ``remove_special_mask`` has adjusted (presumably
    pushing special/padding positions to the end — TODO confirm). Copy ``t``
    masks the ``t``-th position in that order.

    Rows are trial-major: row ``t * batch_size + b`` belongs to example ``b``.
    This matches ``tf.tile(input_ids, [n_trial, 1])``.

    Returns:
      masked_input_ids: ``n_input_ids`` with ``mask_token`` written at each
        row's position by ``scatter_with_batch``.
      masked_lm_positions: [n_trial * batch_size, 1] int positions.
      masked_lm_ids: original tokens at those positions (``gather_index2d``).
      masked_lm_weights: float ones shaped like ``masked_lm_positions``.
    """
    batch_size, seq_length = get_batch_and_seq_length(input_ids, 2)
    loc_dummy = tf.cast(tf.range(0, seq_length), tf.float32)
    loc_dummy = tf.tile(tf.expand_dims(loc_dummy, 0), [batch_size, 1])
    loc_dummy = remove_special_mask(input_ids, input_masks, loc_dummy)
    # [batch_size, seq_length]
    indices = tf.argsort(loc_dummy, axis=-1, direction='ASCENDING', stable=False, name=None)
    # [n_trial * batch_size, seq_length]; tiling axis 0 stacks the whole batch
    # n_trial times, so row r is example r % batch_size.
    n_input_ids = tf.tile(input_ids, [n_trial, 1])
    # [batch_size, n_trial] -> [n_trial, batch_size] -> [n_trial * batch_size, 1].
    # The transpose makes the flattening trial-major, so row r's position
    # comes from the same example as row r of n_input_ids. Flattening
    # [batch_size, n_trial] directly would take it from example r // n_trial.
    lm_locations = tf.reshape(tf.transpose(indices[:, :n_trial], [1, 0]), [-1, 1])
    masked_lm_positions = lm_locations
    masked_lm_ids = gather_index2d(n_input_ids, masked_lm_positions)
    masked_lm_weights = tf.ones_like(masked_lm_positions, dtype=tf.float32)
    masked_input_ids = scatter_with_batch(n_input_ids, masked_lm_positions, mask_token)
    return masked_input_ids, masked_lm_positions, masked_lm_ids, masked_lm_weights
def classification_metric_fn(pred, label, is_real_example):
    """Computes the loss and accuracy of the model.

    Despite the summary above, this builds three streaming metrics
    (tf.compat.v1.metrics): accuracy, precision and recall. ``pred`` is used
    as-is as the predictions and ``label`` is flattened. Every metric is
    weighted by ``is_real_example``.
    """
    flat_label = tf.reshape(label, [-1])
    shared_args = dict(labels=flat_label, predictions=pred, weights=is_real_example)
    # Built in the same order as before: accuracy, precision, recall.
    return {
        "accuracy": tf.compat.v1.metrics.accuracy(**shared_args),
        "precision": tf.compat.v1.metrics.precision(**shared_args),
        "recall": tf.compat.v1.metrics.recall(**shared_args),
    }
def get_label_indices(input_ids):
    """Locate a label token per row and hide it behind LABEL_UNK.

    A row is a "test instance" if it contains any of LABEL_0/LABEL_1/LABEL_2.
    ``tf.math.top_k(k=1)`` over the 0/1 label mask picks one such position
    per row. For rows with no label token the mask is all zeros and top_k
    still returns some index.

    Test instances get LABEL_UNK written at the picked position. Other rows
    get their original token written back, i.e. they are unchanged.

    Returns:
      (masked_input_ids, masked_lm_positions [batch, 1], masked_label_ids
      gathered at those positions, is_test_inst as float32 [batch]).
    """
    label_tokens = [LABEL_0, LABEL_1, LABEL_2]
    has_label = tf.cast(tf.zeros_like(input_ids), tf.bool)
    for label_token in label_tokens:
        has_label = tf.logical_or(tf.equal(input_ids, label_token), has_label)

    _, masked_lm_positions = tf.math.top_k(tf.cast(has_label, tf.float32),
                                           k=1, sorted=False, name="masking_top_k")
    is_test_inst_bool = tf.reduce_any(has_label, axis=1)
    is_test_inst = tf.cast(is_test_inst_bool, tf.float32)
    masked_label_ids = gather_index2d(input_ids, masked_lm_positions)

    use_unk = tf.cast(is_test_inst, tf.int32)
    keep_original = tf.cast(tf.logical_not(is_test_inst_bool), tf.int32)
    scatter_vals = (LABEL_UNK * use_unk
                    + tf.reshape(masked_label_ids, [-1]) * keep_original)
    masked_input_ids = scatter_multiple(input_ids, masked_lm_positions, scatter_vals)
    return masked_input_ids, masked_lm_positions, masked_label_ids, is_test_inst
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator: BERT regression head trained with weighted MAE.

    The pooled BERT output goes through a Dense layer (``train_config.num_classes``
    outputs). Targets are multiplied by ``model_config.scale``. Each example's
    MAE is weighted by ``|scaled target|``, so examples with target 0
    contribute no loss.
    """
    tf_logging.info("*** Features ***")
    for name in sorted(features.keys()):
        tf_logging.info(" name = %s, shape = %s" % (name, features[name].shape))

    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    if mode == tf.estimator.ModeKeys.PREDICT:
        # Dummy targets so the loss graph can still be built at prediction time.
        label_ids = tf.ones([input_ids.shape[0]], dtype=tf.float32)
    else:
        label_ids = features["label_ids"]
    # Regression targets as float32 [batch]; integer features would otherwise
    # fail in the float scaling / weighting below.
    label_ids = tf.cast(tf.reshape(label_ids, [-1]), tf.float32)

    if "is_real_example" in features:
        is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
    else:
        is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    model = BertModel(
        config=model_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=train_config.use_one_hot_embeddings,
    )
    pooled = model.get_pooled_output()
    if is_training:
        pooled = dropout(pooled, 0.1)
    # [batch, num_classes]
    logits = tf.keras.layers.Dense(train_config.num_classes, name="cls_dense")(pooled)

    scale = model_config.scale
    label_ids = scale * label_ids
    # NOTE(review): weighting by |label| zeroes out label-0 examples; kept as-is.
    weight = tf.abs(label_ids)
    # MAE averages over the last axis. The targets get a trailing axis so they
    # line up row-wise with logits [batch, num_classes]. A bare [batch] vector
    # would broadcast against [batch, 1] logits into a [batch, batch] matrix
    # and mix every example's target into every other example's loss.
    loss_arr = tf.keras.losses.MAE(y_true=tf.expand_dims(label_ids, -1), y_pred=logits)
    loss_arr = loss_arr * weight
    loss = tf.reduce_mean(input_tensor=loss_arr)

    tvars = tf.compat.v1.trainable_variables()
    initialized_variable_names = {}
    scaffold_fn = None
    if train_config.init_checkpoint:
        initialized_variable_names, init_fn = get_init_fn(train_config, tvars)
        scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)
    log_var_assignments(tvars, initialized_variable_names)
    TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec

    def metric_fn(logits, label, is_real_example):
        mae = tf.compat.v1.metrics.mean_absolute_error(
            labels=label, predictions=logits, weights=is_real_example)
        return {
            "mae": mae
        }

    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        tvars = None
        train_op = optimization.create_optimizer_from_config(loss, train_config, tvars)
        output_spec = TPUEstimatorSpec(mode=mode, loss=loss, train_op=train_op, scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.EVAL:
        eval_metrics = (metric_fn, [
            logits, label_ids, is_real_example
        ])
        output_spec = TPUEstimatorSpec(mode=mode, loss=loss, eval_metrics=eval_metrics, scaffold_fn=scaffold_fn)
    else:
        predictions = {
            "input_ids": input_ids,
            "logits": logits,
        }
        if "data_id" in features:
            predictions['data_id'] = features['data_id']
        output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            predictions=predictions,
            scaffold_fn=scaffold_fn)
    return output_spec
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator: dual encoder scoring a query against several documents.

    The query and the documents are encoded by two ``model_class`` instances
    under separate variable scopes. Each query-document score is the dot
    product of pooled outputs. The scores go to the loss selected by
    ``special_flags`` (``hinge_all`` / ``sigmoid_all`` / default
    ``hinge_max``).
    """
    tf_logging.info("*** Features ***")
    for name in sorted(features.keys()):
        tf_logging.info(" name = %s, shape = %s" % (name, features[name].shape))
    q_input_ids = features["q_input_ids"]
    q_input_mask = features["q_input_mask"]
    d_input_ids = features["d_input_ids"]
    d_input_mask = features["d_input_mask"]

    input_shape = get_shape_list(q_input_ids, expected_rank=2)
    batch_size = input_shape[0]

    doc_length = model_config.max_doc_length
    num_docs = model_config.num_docs

    # [batch, num_docs, doc_length]
    d_input_ids_unpacked = tf.reshape(d_input_ids, [-1, num_docs, doc_length])
    d_input_mask_unpacked = tf.reshape(d_input_mask, [-1, num_docs, doc_length])
    # [batch * num_docs, doc_length]: one row per document for the doc encoder.
    d_input_ids_flat = tf.reshape(d_input_ids_unpacked, [-1, doc_length])
    d_input_mask_flat = tf.reshape(d_input_mask_unpacked, [-1, doc_length])

    q_segment_ids = tf.zeros_like(q_input_ids, tf.int32)
    d_segment_ids = tf.zeros_like(d_input_ids_flat, tf.int32)

    label_ids = features["label_ids"]
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    if "is_real_example" in features:
        is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
    else:
        is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

    # Each encoder gets its own config copy with max_seq_length matching its
    # input. Previously the copies were built but model_config was passed.
    with tf.compat.v1.variable_scope(dual_model_prefix1):
        q_model_config = copy.deepcopy(model_config)
        q_model_config.max_seq_length = model_config.max_sent_length
        model_q = model_class(
            config=q_model_config,
            is_training=is_training,
            input_ids=q_input_ids,
            input_mask=q_input_mask,
            token_type_ids=q_segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        )

    with tf.compat.v1.variable_scope(dual_model_prefix2):
        d_model_config = copy.deepcopy(model_config)
        d_model_config.max_seq_length = model_config.max_doc_length
        model_d = model_class(
            config=d_model_config,
            is_training=is_training,
            input_ids=d_input_ids_flat,
            input_mask=d_input_mask_flat,
            token_type_ids=d_segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        )

    pooled_q = model_q.get_pooled_output()  # [batch, vector_size]
    pooled_d_flat = model_d.get_pooled_output()  # [batch * num_docs, vector_size]

    pooled_d = tf.reshape(pooled_d_flat, [batch_size, num_docs, -1])  # [batch, num_docs, vector_size]
    pooled_q_t = tf.expand_dims(pooled_q, 1)  # [batch, 1, vector_size]
    pooled_d_t = tf.transpose(pooled_d, [0, 2, 1])  # [batch, vector_size, num_docs]
    all_logits = tf.matmul(pooled_q_t, pooled_d_t)  # [batch, 1, num_docs]
    if "hinge_all" in special_flags:
        apply_loss_modeling = hinge_all
    elif "sigmoid_all" in special_flags:
        apply_loss_modeling = sigmoid_all
    else:
        apply_loss_modeling = hinge_max
    logits, loss = apply_loss_modeling(all_logits, label_ids)
    pred = tf.cast(logits > 0, tf.int32)

    tvars = tf.compat.v1.trainable_variables()

    initialized_variable_names = {}

    scaffold_fn = None
    if train_config.init_checkpoint:
        initialized_variable_names, init_fn = get_init_fn(
            train_config, tvars)
        scaffold_fn = get_tpu_scaffold_or_init(init_fn,
                                               train_config.use_tpu)
    log_var_assignments(tvars, initialized_variable_names)

    TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec

    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        if "simple_optimizer" in special_flags:
            tf_logging.info("using simple optimizer")
            train_op = create_simple_optimizer(loss,
                                               train_config.learning_rate,
                                               train_config.use_tpu)
        else:
            train_op = optimization.create_optimizer_from_config(
                loss, train_config, tvars)
        output_spec = TPUEstimatorSpec(mode=mode,
                                       loss=loss,
                                       train_op=train_op,
                                       scaffold_fn=scaffold_fn)

    elif mode == tf.estimator.ModeKeys.EVAL:
        eval_metrics = (classification_metric_fn,
                        [pred, label_ids, is_real_example])
        output_spec = TPUEstimatorSpec(mode=mode,
                                       loss=loss,
                                       eval_metrics=eval_metrics,
                                       scaffold_fn=scaffold_fn)
    else:
        predictions = {
            "q_input_ids": q_input_ids,
            "d_input_ids": d_input_ids,
            "logits": logits
        }

        useful_inputs = ["data_id", "input_ids2", "data_ids"]
        for input_name in useful_inputs:
            if input_name in features:
                predictions[input_name] = features[input_name]
        output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            predictions=predictions,
            scaffold_fn=scaffold_fn)

    return output_spec
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator.

    Two BERT encoders in separate variable scopes:
      * encoder 1 on (input_ids, input_mask, segment_ids) produces
        classification ``logits``;
      * encoder 2 on (input_ids2, input_mask2, segment_ids2) produces softmax
        probabilities, and their column 1 is used as a per-example
        ``confidence``.

    Per-example loss is ``cls_loss * confidence + k * (1 - confidence)``,
    then passed through ``apply_weighted_loss(..., label_ids, alpha)`` and
    averaged. ``k`` and ``alpha`` come from ``model_config``.

    NOTE(review): graph construction order and variable-scope membership
    determine checkpoint variable names here; keep statement order stable.
    """
    tf_logging.info("*** Features ***")
    for name in sorted(features.keys()):
        tf_logging.info(" name = %s, shape = %s" % (name, features[name].shape))

    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    if mode == tf.estimator.ModeKeys.PREDICT:
        # Dummy labels so the loss graph can be built in PREDICT mode.
        label_ids = tf.ones([input_ids.shape[0]], dtype=tf.int32)
    else:
        label_ids = features["label_ids"]
    label_ids = tf.reshape(label_ids, [-1])
    if "is_real_example" in features:
        is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
    else:
        is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    input_ids2 = features["input_ids2"]
    input_mask2 = features["input_mask2"]
    segment_ids2 = features["segment_ids2"]
    # Encoder 1: classifier.
    with tf.compat.v1.variable_scope(dual_model_prefix1):
        model_1 = BertModel(
            config=model_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        )
        pooled = model_1.get_pooled_output()
        if is_training:
            pooled = dropout(pooled, 0.1)
        logits = tf.keras.layers.Dense(train_config.num_classes, name="cls_dense")(pooled)
    # Encoder 2: confidence estimator (softmax output).
    with tf.compat.v1.variable_scope(dual_model_prefix2):
        model_2 = BertModel(
            config=model_config,
            is_training=is_training,
            input_ids=input_ids2,
            input_mask=input_mask2,
            token_type_ids=segment_ids2,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        )
        pooled = model_2.get_pooled_output()
        if is_training:
            pooled = dropout(pooled, 0.1)
        conf_probs = tf.keras.layers.Dense(train_config.num_classes, name="cls_dense",
                                           activation=tf.keras.activations.softmax)(pooled)

    # [batch]: probability that encoder 2 assigns to class index 1.
    confidence = conf_probs[:, 1]
    # Penalty for low confidence; scaled by k below.
    confidence_loss = 1 - confidence

    cls_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits,
        labels=label_ids)

    k = model_config.k
    alpha = model_config.alpha
    # Classification loss is down-weighted where confidence is low, at a cost
    # of k * (1 - confidence).
    loss_arr = cls_loss * confidence + confidence_loss * k

    loss_arr = apply_weighted_loss(loss_arr, label_ids, alpha)

    loss = tf.reduce_mean(input_tensor=loss_arr)
    tvars = tf.compat.v1.trainable_variables()

    initialized_variable_names = {}

    scaffold_fn = None
    if train_config.init_checkpoint:
        initialized_variable_names, init_fn = get_init_fn(
            train_config, tvars)
        scaffold_fn = get_tpu_scaffold_or_init(init_fn,
                                               train_config.use_tpu)
    log_var_assignments(tvars, initialized_variable_names)
    TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec

    def metric_fn(log_probs, label, is_real_example, confidence):
        # Standard classification metrics plus the streaming mean confidence.
        r = classification_metric_fn(log_probs, label, is_real_example)
        r['confidence'] = tf.compat.v1.metrics.mean(confidence)
        return r

    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        tvars = None
        train_op = optimization.create_optimizer_from_config(
            loss, train_config, tvars)
        output_spec = TPUEstimatorSpec(mode=mode,
                                       loss=loss,
                                       train_op=train_op,
                                       scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.EVAL:
        eval_metrics = (metric_fn,
                        [logits, label_ids, is_real_example, confidence])
        output_spec = TPUEstimatorSpec(mode=mode,
                                       loss=loss,
                                       eval_metrics=eval_metrics,
                                       scaffold_fn=scaffold_fn)
    else:
        predictions = {
            "input_ids": input_ids,
            "logits": logits,
            "confidence": confidence,
        }
        if "data_id" in features:
            predictions['data_id'] = features['data_id']
        output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            predictions=predictions,
            scaffold_fn=scaffold_fn)
    return output_spec
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator: score each class by masked-LM loss.

    Each input is expanded into ``n_trial = seq_length - 20`` copies, each with
    one position masked (``one_by_one_masking``). Every copy is then repeated
    once per class, with that class id fed to ``BertModelWithLabelInner`` as a
    "virtual" label.

    The masked-LM example losses are summed over the trials per (class,
    example). The negated sums form ``per_label_score`` [batch, num_classes],
    returned as ``logits`` in PREDICT mode. The training loss is the
    masked-LM loss only.
    """
    tf_logging.info("*** Features ***")
    for name in sorted(features.keys()):
        tf_logging.info(" name = %s, shape = %s" % (name, features[name].shape))

    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    next_sentence_labels = get_dummy_next_sentence_labels(input_ids)
    batch_size, seq_length = get_batch_and_seq_length(input_ids, 2)
    n_trial = seq_length - 20

    masked_input_ids, masked_lm_positions, masked_lm_ids, masked_lm_weights \
        = one_by_one_masking(input_ids, input_mask, MASK_ID, n_trial)
    num_classes = train_config.num_classes
    n_repeat = num_classes * n_trial

    # [num_classes * n_trial * batch_size, seq_length]. tf.tile along axis 0
    # stacks complete copies, so rows are grouped class-major: class c owns
    # rows [c * n_trial * batch_size, (c + 1) * n_trial * batch_size).
    repeat_masked_input_ids = tf.tile(masked_input_ids, [num_classes, 1])
    repeat_input_mask = tf.tile(input_mask, [n_repeat, 1])
    repeat_segment_ids = tf.tile(segment_ids, [n_repeat, 1])
    masked_lm_positions = tf.tile(masked_lm_positions, [num_classes, 1])
    masked_lm_ids = tf.tile(masked_lm_ids, [num_classes, 1])
    masked_lm_weights = tf.tile(masked_lm_weights, [num_classes, 1])
    next_sentence_labels = tf.tile(next_sentence_labels, [n_repeat, 1])

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    # Virtual label per row in the same class-major order as the rows above:
    # [num_classes, 1] tiled to [num_classes, batch_size * n_trial], then
    # flattened. (Tiling a [1, num_classes] row instead would interleave
    # 0, 1, 2, 0, 1, 2, ... and pair rows with the wrong class.)
    virtual_labels_ids = tf.tile(tf.expand_dims(tf.range(num_classes), 1),
                                 [1, batch_size * n_trial])
    virtual_labels_ids = tf.reshape(virtual_labels_ids, [-1, 1])

    tf_logging.info("repeat_masked_input_ids %s" % (repeat_masked_input_ids.shape,))
    tf_logging.info("repeat_input_mask %s" % (repeat_input_mask.shape,))
    tf_logging.info("virtual_labels_ids %s" % (virtual_labels_ids.shape,))
    model = BertModelWithLabelInner(
        config=model_config,
        is_training=is_training,
        input_ids=repeat_masked_input_ids,
        input_mask=repeat_input_mask,
        token_type_ids=repeat_segment_ids,
        use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        label_ids=virtual_labels_ids,
    )
    (masked_lm_loss,
     masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output_fn(
        model_config, model.get_sequence_output(), model.get_embedding_table(),
        masked_lm_positions, masked_lm_ids, masked_lm_weights)
    # Next-sentence head is still built (its variables stay in the graph) but
    # its loss is not used.
    (next_sentence_loss, next_sentence_example_loss,
     next_sentence_log_probs) = get_next_sentence_output(
        model_config, model.get_pooled_output(), next_sentence_labels)

    total_loss = masked_lm_loss

    # loss = -log(prob)
    # TODO compare log prob of each label

    # [num_classes, n_trial, batch_size]; relies on the class-major row layout
    # and on one_by_one_masking emitting trial-major rows.
    per_case_loss = tf.reshape(masked_lm_example_loss, [num_classes, -1, batch_size])
    # [num_classes, batch_size]: total masked-LM loss per (class, example).
    per_label_loss = tf.reduce_sum(per_case_loss, axis=1)
    # Sized by num_classes (was hard-coded to 3, which broke other class counts).
    bias = tf.zeros([num_classes, 1])
    # [batch_size, num_classes]: higher score = lower LM loss under that label.
    per_label_score = tf.transpose(-per_label_loss + bias, [1, 0])

    tvars = tf.compat.v1.trainable_variables()
    initialized_variable_names, initialized_variable_names2, init_fn\
        = align_checkpoint_for_lm(tvars,
                                  train_config.checkpoint_type,
                                  train_config.init_checkpoint,
                                  train_config.second_init_checkpoint,
                                  )

    scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)

    log_var_assignments(tvars, initialized_variable_names, initialized_variable_names2)

    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = optimization.create_optimizer_from_config(
            total_loss, train_config)
        output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=total_loss,
            train_op=train_op,
            training_hooks=[OomReportingHook()],
            scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.EVAL:
        eval_metrics = (metric_fn_lm, [
            masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
            masked_lm_weights,
        ])
        output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=total_loss,
            eval_metrics=eval_metrics,
            scaffold_fn=scaffold_fn)
    else:
        predictions = {"input_ids": input_ids, "logits": per_label_score}
        output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=total_loss,
            predictions=predictions,
            scaffold_fn=scaffold_fn)

    return output_spec
def reform_scala(t):
    """Return ``t`` reshaped to the 1-element shape [1] (e.g. a scalar -> vector)."""
    one_element_shape = [1]
    return tf.reshape(t, one_element_shape)
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator: classifier with a gradient-reversed domain head.

    ``pred_loss`` is the mean cross-entropy masked by ``is_valid_label``.
    ``domain_loss`` is the cross-entropy of a 2-way domain classifier applied
    to ``grad_reverse(pooled)``. The total loss is
    ``pred_loss + model_config.alpha * domain_loss``.
    """
    tf_logging.info("*** Features ***")
    for name in sorted(features.keys()):
        tf_logging.info(" name = %s, shape = %s" % (name, features[name].shape))

    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    if mode == tf.estimator.ModeKeys.PREDICT:
        label_ids = tf.ones([input_ids.shape[0]], dtype=tf.int32)
    else:
        label_ids = features["label_ids"]
    label_ids = tf.reshape(label_ids, [-1])
    if "is_real_example" in features:
        is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
    else:
        is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)
    domain_ids = features["domain_ids"]
    domain_ids = tf.reshape(domain_ids, [-1])
    # Flatten like label_ids/domain_ids. A [batch, 1] feature multiplied with
    # the [batch] losses below would broadcast to [batch, batch].
    is_valid_label = tf.reshape(features["is_valid_label"], [-1])
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    model_1 = BertModel(
        config=model_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=train_config.use_one_hot_embeddings,
    )
    pooled = model_1.get_pooled_output()
    if is_training:
        pooled = dropout(pooled, 0.1)

    logits = tf.keras.layers.Dense(train_config.num_classes, name="cls_dense")(pooled)
    pred_losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits,
        labels=label_ids)
    num_domain = 2
    # Gradients from the domain head are reversed before reaching the encoder.
    pooled_for_domain = grad_reverse(pooled)
    domain_logits = tf.keras.layers.Dense(num_domain, name="domain_dense")(pooled_for_domain)
    domain_losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=domain_logits,
        labels=domain_ids)

    # Invalid-label examples contribute 0 but still count in the mean's denominator.
    pred_loss = tf.reduce_mean(pred_losses * tf.cast(is_valid_label, tf.float32))
    domain_loss = tf.reduce_mean(domain_losses)
    tf.compat.v1.summary.scalar('domain_loss', domain_loss)
    tf.compat.v1.summary.scalar('pred_loss', pred_loss)
    alpha = model_config.alpha
    loss = pred_loss + alpha * domain_loss
    tvars = tf.compat.v1.trainable_variables()

    initialized_variable_names = {}

    scaffold_fn = None
    if train_config.init_checkpoint:
        initialized_variable_names, init_fn = get_init_fn(
            train_config, tvars)
        scaffold_fn = get_tpu_scaffold_or_init(init_fn,
                                               train_config.use_tpu)
    log_var_assignments(tvars, initialized_variable_names)
    TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec

    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        tvars = None
        train_op = optimization.create_optimizer_from_config(
            loss, train_config, tvars)
        output_spec = TPUEstimatorSpec(mode=mode,
                                       loss=loss,
                                       train_op=train_op,
                                       scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.EVAL:
        eval_metrics = (classification_metric_fn,
                        [logits, label_ids, is_real_example])
        output_spec = TPUEstimatorSpec(mode=mode,
                                       loss=loss,
                                       eval_metrics=eval_metrics,
                                       scaffold_fn=scaffold_fn)
    else:
        predictions = {
            "input_ids": input_ids,
            "logits": logits,
        }
        if "data_id" in features:
            predictions['data_id'] = features['data_id']
        output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            predictions=predictions,
            scaffold_fn=scaffold_fn)
    return output_spec
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator: binary model that outputs probabilities.

    ``model.get_prob()`` is trained with binary cross-entropy against
    ``label_ids`` cast to float. The probabilities are also exposed under the
    name ``logits`` for eval metrics and predictions. PREDICT additionally
    returns ``model.score1`` / ``model.score2``.
    """
    tf_logging.info("*** Features ***")
    for name in sorted(features.keys()):
        tf_logging.info(" name = %s, shape = %s" % (name, features[name].shape))

    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    label_ids = features["label_ids"]
    label_ids = tf.reshape(label_ids, [-1])

    if "is_real_example" in features:
        is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
    else:
        is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    model = model_class(
        config=model_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        features=features,
    )
    probs = model.get_prob()
    # NOTE(review): these are probabilities, not logits; the name is kept
    # because the EVAL/PREDICT outputs below use it.
    logits = probs
    y_true = tf.cast(label_ids, tf.float32)
    loss_arr = tf.keras.losses.BinaryCrossentropy()(y_true, probs)
    loss = tf.reduce_mean(input_tensor=loss_arr)

    tvars = tf.compat.v1.trainable_variables()

    initialized_variable_names = {}

    scaffold_fn = None
    if train_config.init_checkpoint:
        initialized_variable_names, init_fn = get_init_fn(train_config, tvars)
        scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)
    log_var_assignments(tvars, initialized_variable_names)
    TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec
    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        tvars = None
        train_op = optimization.create_optimizer_from_config(loss, train_config, tvars)
        output_spec = TPUEstimatorSpec(mode=mode, loss=loss, train_op=train_op, scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.EVAL:
        eval_metrics = (classification_metric_fn, [
            logits, label_ids, is_real_example
        ])
        output_spec = TPUEstimatorSpec(mode=mode, loss=loss, eval_metrics=eval_metrics, scaffold_fn=scaffold_fn)
    else:
        predictions = {
            "input_ids": input_ids,
            "label_ids": label_ids,
            "logits": logits,
            "score1": model.score1,
            "score2": model.score2,
        }
        if "data_id" in features:
            predictions['data_id'] = features['data_id']
        output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            predictions=predictions,
            scaffold_fn=scaffold_fn)
    return output_spec
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator: classify pre-computed evidence vectors.

    ``features["vectors"]`` and ``features["valid_mask"]`` are reshaped per
    ``model_config`` and combined by ``MultiEvidenceCombiner``. Its pooled
    output feeds a Dense classifier trained with sparse softmax cross-entropy,
    optionally reweighted by ``reweight_zero`` when ``"bias_loss"`` is in
    ``special_flags``.
    """
    tf_logging.info("*** Features ***")
    for name in sorted(features.keys()):
        tf_logging.info(" name = %s, shape = %s" % (name, features[name].shape))
    vectors = features["vectors"]
    valid_mask = features["valid_mask"]
    label_ids = features["label_ids"]
    # [batch, num_window, max_sequence, hidden_size]
    vectors = tf.reshape(vectors, [
        -1, model_config.num_window, model_config.max_sequence,
        model_config.hidden_size
    ])
    # [batch, num_window, max_sequence]
    valid_mask = tf.reshape(
        valid_mask, [-1, model_config.num_window, model_config.max_sequence])
    label_ids = tf.reshape(label_ids, [-1])

    if "is_real_example" in features:
        is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
    else:
        is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    model = MultiEvidenceCombiner(config=model_config,
                                  is_training=is_training,
                                  vectors=vectors,
                                  valid_mask=valid_mask,
                                  scope=None)
    pooled = model.pooled_output
    if is_training:
        pooled = dropout(pooled, 0.1)

    logits = tf.keras.layers.Dense(config.num_classes, name="cls_dense")(pooled)
    loss_arr = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=label_ids)

    if "bias_loss" in special_flags:
        tf_logging.info("Using special_flags : bias_loss")
        loss_arr = reweight_zero(label_ids, loss_arr)

    loss = tf.reduce_mean(input_tensor=loss_arr)
    tvars = tf.compat.v1.trainable_variables()

    initialized_variable_names = {}

    scaffold_fn = None
    if config.init_checkpoint:
        initialized_variable_names, init_fn = get_init_fn(config, tvars)
        scaffold_fn = get_tpu_scaffold_or_init(init_fn, config.use_tpu)
    log_var_assignments(tvars, initialized_variable_names)

    TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec

    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        if "simple_optimizer" in special_flags:
            tf_logging.info("using simple optimizer")
            train_op = create_simple_optimizer(loss, config.learning_rate,
                                               config.use_tpu)
        else:
            if "ask_tvar" in special_flags:
                tvars = model.get_trainable_vars()
            else:
                tvars = None
            train_op = optimization.create_optimizer_from_config(
                loss, config, tvars)
        output_spec = TPUEstimatorSpec(mode=mode,
                                       loss=loss,
                                       train_op=train_op,
                                       scaffold_fn=scaffold_fn)

    elif mode == tf.estimator.ModeKeys.EVAL:
        eval_metrics = (classification_metric_fn,
                        [logits, label_ids, is_real_example])
        # mode must be the ModeKeys value (was mistakenly passed the model object).
        output_spec = TPUEstimatorSpec(mode=mode,
                                       loss=loss,
                                       eval_metrics=eval_metrics,
                                       scaffold_fn=scaffold_fn)
    else:
        predictions = {"logits": logits, "label_ids": label_ids}
        if override_prediction_fn is not None:
            predictions = override_prediction_fn(predictions, model)

        useful_inputs = ["data_id", "input_ids2"]
        for input_name in useful_inputs:
            if input_name in features:
                predictions[input_name] = features[input_name]
        output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            predictions=predictions,
            scaffold_fn=scaffold_fn)

    return output_spec
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator.

    Classifies the pooled output of `model_class`. In PREDICT mode it also
    returns the gradient of the class-1 softmax probability with respect to
    `model.embedding_output`, summed over axis 2.

    Args:
      features: dict with "input_ids", "input_mask", "segment_ids",
        "label_ids" and, optionally, "is_real_example".
      labels: unused.
      mode: a `tf.estimator.ModeKeys` value.
      params: unused.

    Returns:
      A `TPUEstimatorSpec` for `mode`.
    """
    tf_logging.info("*** Features ***")
    for name in sorted(features.keys()):
        tf_logging.info(" name = %s, shape = %s" % (name, features[name].shape))

    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    label_ids = features["label_ids"]
    label_ids = tf.reshape(label_ids, [-1])
    if "is_real_example" in features:
        is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
    else:
        is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    if "feed_features" in special_flags:
        model = model_class(
            config=bert_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
            features=features,
        )
    else:
        model = model_class(
            config=bert_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        )

    if "new_pooling" in special_flags:
        pooled = mimic_pooling(model.get_sequence_output(),
                               bert_config.hidden_size,
                               bert_config.initializer_range)
    else:
        pooled = model.get_pooled_output()

    if train_config.checkpoint_type != "bert_nli" and train_config.use_old_logits:
        tf_logging.info("Use old version of logistic regression")
        # NOTE(review): unlike the branch below, no dropout is applied on this
        # path — confirm whether that is intended.
        logits = tf.keras.layers.Dense(train_config.num_classes,
                                       name="cls_dense")(pooled)
    else:
        tf_logging.info("Use fixed version of logistic regression")
        # NOTE(review): the output layer is hard-coded to 3 classes rather than
        # train_config.num_classes (presumably to match a 3-way NLI checkpoint
        # — confirm).
        output_weights = tf.compat.v1.get_variable(
            "output_weights", [3, bert_config.hidden_size],
            initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.02))
        output_bias = tf.compat.v1.get_variable(
            "output_bias", [3], initializer=tf.compat.v1.zeros_initializer())
        if is_training:
            pooled = dropout(pooled, 0.1)
        logits = tf.matmul(pooled, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)

    loss_arr = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=label_ids)

    if "bias_loss" in special_flags:
        tf_logging.info("Using special_flags : bias_loss")
        loss_arr = reweight_zero(label_ids, loss_arr)

    loss = tf.reduce_mean(input_tensor=loss_arr)
    tvars = tf.compat.v1.trainable_variables()

    initialized_variable_names = {}
    scaffold_fn = None
    if train_config.init_checkpoint:
        initialized_variable_names, init_fn = get_init_fn(train_config, tvars)
        scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)
    log_var_assignments(tvars, initialized_variable_names)

    TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec
    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        if "simple_optimizer" in special_flags:
            tf_logging.info("using simple optimizer")
            train_op = create_simple_optimizer(loss,
                                               train_config.learning_rate,
                                               train_config.use_tpu)
        else:
            train_op = optimization.create_optimizer_from_config(
                loss, train_config)
        output_spec = TPUEstimatorSpec(mode=mode,
                                       loss=loss,
                                       train_op=train_op,
                                       scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.EVAL:
        eval_metrics = (classification_metric_fn,
                        [logits, label_ids, is_real_example])
        # `mode` must be the ModeKeys value (was mistakenly `model`).
        output_spec = TPUEstimatorSpec(mode=mode,
                                       loss=loss,
                                       eval_metrics=eval_metrics,
                                       scaffold_fn=scaffold_fn)
    else:
        probs = tf.nn.softmax(logits, axis=-1)
        # d P(class 1) / d embedding_output.
        gradient_list = tf.gradients(probs[:, 1], model.embedding_output)
        tf_logging.info("len(gradient_list) = %d" % len(gradient_list))
        gradient = gradient_list[0]
        tf_logging.info("gradient shape = %s" % (gradient.shape,))
        gradient = tf.reduce_sum(gradient, axis=2)
        predictions = {
            "input_ids": input_ids,
            "gradient": gradient,
            "labels": label_ids,
            "logits": logits
        }
        output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
            mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
    return output_spec
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator.

    BERT classifier that, in PREDICT mode, also emits the intermediate tensors
    of a linear-warmup / polynomial-decay learning-rate schedule for
    inspection. Those schedule tensors are only used as predictions here; the
    TRAIN branch builds its optimizer via `create_optimizer_from_config`.

    Args:
      features: dict with "input_ids", "input_mask", "segment_ids",
        "label_ids", optionally "is_real_example" and "data_id".
      labels: unused.
      mode: a `tf.estimator.ModeKeys` value.
      params: unused.

    Returns:
      A `TPUEstimatorSpec` for `mode`.
    """
    tf_logging.info("*** Features ***")
    for name in sorted(features.keys()):
        tf_logging.info(" name = %s, shape = %s" % (name, features[name].shape))

    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    label_ids = features["label_ids"]
    label_ids = tf.reshape(label_ids, [-1])
    if "is_real_example" in features:
        is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
    else:
        is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    model_1 = BertModel(
        config=model_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=train_config.use_one_hot_embeddings,
    )
    pooled = model_1.get_pooled_output()
    if is_training:
        pooled = dropout(pooled, 0.1)

    logits = tf.keras.layers.Dense(train_config.num_classes,
                                   name="cls_dense")(pooled)
    loss_arr = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=label_ids)
    loss = tf.reduce_mean(input_tensor=loss_arr)
    tvars = tf.compat.v1.trainable_variables()

    initialized_variable_names = {}
    scaffold_fn = None
    if train_config.init_checkpoint:
        initialized_variable_names, init_fn = get_init_fn(train_config, tvars)
        scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)
    log_var_assignments(tvars, initialized_variable_names)

    TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec

    # Learning-rate schedule: linear decay to 0 over num_train_steps, with an
    # optional linear warmup from 0 to init_lr over num_warmup_steps.
    global_step = tf.compat.v1.train.get_or_create_global_step()
    init_lr = train_config.learning_rate
    num_warmup_steps = train_config.num_warmup_steps
    num_train_steps = train_config.num_train_steps
    learning_rate2_const = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
    learning_rate2_decayed = tf.compat.v1.train.polynomial_decay(
        learning_rate2_const,
        global_step,
        num_train_steps,
        end_learning_rate=0.0,
        power=1.0,
        cycle=False)
    if num_warmup_steps:
        global_steps_int = tf.cast(global_step, tf.int32)
        warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)

        global_steps_float = tf.cast(global_steps_int, tf.float32)
        warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)

        warmup_percent_done = global_steps_float / warmup_steps_float
        warmup_learning_rate = init_lr * warmup_percent_done

        # 1.0 while still warming up, else 0.0; selects between the schedules.
        is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
        learning_rate = ((1.0 - is_warmup) * learning_rate2_decayed +
                         is_warmup * warmup_learning_rate)
    else:
        # No warmup configured: treat warmup as already complete so the
        # PREDICT outputs below are always defined (previously NameError).
        warmup_percent_done = tf.constant(1.0, dtype=tf.float32)
        warmup_learning_rate = tf.constant(init_lr, dtype=tf.float32)
        learning_rate = learning_rate2_decayed

    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        tvars = None
        train_op = optimization.create_optimizer_from_config(
            loss, train_config, tvars)
        output_spec = TPUEstimatorSpec(mode=mode,
                                       loss=loss,
                                       train_op=train_op,
                                       scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.EVAL:
        eval_metrics = (classification_metric_fn,
                        [logits, label_ids, is_real_example])
        output_spec = TPUEstimatorSpec(mode=mode,
                                       loss=loss,
                                       eval_metrics=eval_metrics,
                                       scaffold_fn=scaffold_fn)
    else:
        def reform_scala(t):
            # Scalars are reshaped to [1] before being returned as predictions.
            return tf.reshape(t, [1])

        predictions = {
            "input_ids": input_ids,
            "label_ids": label_ids,
            "logits": logits,
            "learning_rate2_const": reform_scala(learning_rate2_const),
            "warmup_percent_done": reform_scala(warmup_percent_done),
            "warmup_learning_rate": reform_scala(warmup_learning_rate),
            "learning_rate": reform_scala(learning_rate),
            "learning_rate2_decayed": reform_scala(learning_rate2_decayed),
        }
        if "data_id" in features:
            predictions['data_id'] = features['data_id']
        output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
            mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
    return output_spec
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator.

    Builds BERT-style inputs from "query" / "doc" / "doc_mask" via
    `iterate_over` and classifies the pooled output of `model_class`.

    Args:
      features: dict with "query", "doc", "doc_mask", "data_id", and (outside
        PREDICT mode) "label_ids"; optionally "is_real_example" and
        "input_ids2".
      labels: unused.
      mode: a `tf.estimator.ModeKeys` value.
      params: unused.

    Returns:
      A `TPUEstimatorSpec` for `mode`.
    """
    tf_logging.info("*** Features ***")
    for name in sorted(features.keys()):
        tf_logging.info(" name = %s, shape = %s" % (name, features[name].shape))

    query = features["query"]
    doc = features["doc"]
    doc_mask = features["doc_mask"]
    data_ids = features["data_id"]

    # The "- 3" presumably reserves room for [CLS] and two [SEP] tokens
    # — confirm against iterate_over.
    segment_len = max_seq_length - query_len - 3
    step_size = model_config.step_size
    input_ids, input_mask, segment_ids, n_segments = \
        iterate_over(query, doc, doc_mask, total_doc_len, segment_len, step_size)

    if mode == tf.estimator.ModeKeys.PREDICT:
        # No labels at prediction time: use one dummy label per row of
        # input_ids. tf.shape keeps this valid when the leading dimension is
        # not statically known (input_ids.shape[0] would be None).
        label_ids = tf.ones(tf.shape(input_ids)[:1], dtype=tf.int32)
    else:
        label_ids = features["label_ids"]
        label_ids = tf.reshape(label_ids, [-1])

    if "is_real_example" in features:
        is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
    else:
        is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    if "feed_features" in special_flags:
        model = model_class(
            config=model_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
            features=features,
        )
    else:
        model = model_class(
            config=model_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        )

    if "new_pooling" in special_flags:
        pooled = mimic_pooling(model.get_sequence_output(),
                               model_config.hidden_size,
                               model_config.initializer_range)
    else:
        pooled = model.get_pooled_output()

    if train_config.checkpoint_type != "bert_nli" and train_config.use_old_logits:
        tf_logging.info("Use old version of logistic regression")
        if is_training:
            pooled = dropout(pooled, 0.1)
        logits = tf.keras.layers.Dense(train_config.num_classes,
                                       name="cls_dense")(pooled)
    else:
        tf_logging.info("Use fixed version of logistic regression")
        output_weights = tf.compat.v1.get_variable(
            "output_weights",
            [train_config.num_classes, model_config.hidden_size],
            initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.02))
        output_bias = tf.compat.v1.get_variable(
            "output_bias", [train_config.num_classes],
            initializer=tf.compat.v1.zeros_initializer())
        if is_training:
            pooled = dropout(pooled, 0.1)
        logits = tf.matmul(pooled, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)

    loss_arr = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=label_ids)

    if "bias_loss" in special_flags:
        tf_logging.info("Using special_flags : bias_loss")
        loss_arr = reweight_zero(label_ids, loss_arr)

    loss = tf.reduce_mean(input_tensor=loss_arr)
    tvars = tf.compat.v1.trainable_variables()

    initialized_variable_names = {}
    scaffold_fn = None
    if train_config.init_checkpoint:
        initialized_variable_names, init_fn = get_init_fn(train_config, tvars)
        scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)
    log_var_assignments(tvars, initialized_variable_names)

    TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec
    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        if "simple_optimizer" in special_flags:
            tf_logging.info("using simple optimizer")
            train_op = create_simple_optimizer(loss,
                                               train_config.learning_rate,
                                               train_config.use_tpu)
        else:
            if "ask_tvar" in special_flags:
                tvars = model.get_trainable_vars()
            else:
                tvars = None
            train_op = optimization.create_optimizer_from_config(
                loss, train_config, tvars)
        output_spec = TPUEstimatorSpec(mode=mode,
                                       loss=loss,
                                       train_op=train_op,
                                       scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.EVAL:
        eval_metrics = (classification_metric_fn,
                        [logits, label_ids, is_real_example])
        # `mode` must be the ModeKeys value (was mistakenly `model`).
        output_spec = TPUEstimatorSpec(mode=mode,
                                       loss=loss,
                                       eval_metrics=eval_metrics,
                                       scaffold_fn=scaffold_fn)
    else:
        predictions = {
            "logits": logits,
            "doc": doc,
            "data_ids": data_ids,
        }
        useful_inputs = ["data_id", "input_ids2", "data_ids"]
        for input_name in useful_inputs:
            if input_name in features:
                predictions[input_name] = features[input_name]

        if override_prediction_fn is not None:
            predictions = override_prediction_fn(predictions, model)

        output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
            mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
    return output_spec
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator.

    Per-token regression: a dense layer over BERT's sequence output is trained
    with MAE against log(label_ids + 1e-10). Only positions where
    `segment_ids` is non-zero contribute to the loss.

    Args:
      features: dict with "input_ids", "input_mask", "segment_ids",
        "label_ids" (reshaped to [batch, seq, num_classes]), and optionally
        "is_real_example" and "data_id".
      labels: unused.
      mode: a `tf.estimator.ModeKeys` value.
      params: unused.

    Returns:
      A `TPUEstimatorSpec` for `mode`.
    """
    tf_logging.info("*** Features ***")
    for name in sorted(features.keys()):
        tf_logging.info(" name = %s, shape = %s" % (name, features[name].shape))

    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    label_ids = features["label_ids"]
    input_shape = get_shape_list2(input_ids)
    batch_size, seq_length = input_shape

    if "is_real_example" in features:
        is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
    else:
        is_real_example = tf.ones([batch_size, 1], dtype=tf.float32)
    # metric_fn uses this as `weights` for mean_absolute_error over rank-3
    # labels [batch, seq, num_classes]; weights must have the same rank with
    # dims of 1 or matching size, so give one weight per example.
    is_real_example = tf.reshape(is_real_example, [batch_size, 1, 1])

    label_ids = tf.reshape(
        label_ids, [batch_size, seq_length, train_config.num_classes])

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    model = BertModel(
        config=model_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=train_config.use_one_hot_embeddings,
    )
    seq_out = model.get_sequence_output()
    if is_training:
        seq_out = dropout(seq_out, 0.1)

    logits = tf.keras.layers.Dense(train_config.num_classes,
                                   name="cls_dense")(seq_out)
    probs = tf.math.sigmoid(logits)
    eps = 1e-10
    # eps avoids log(0) for zero-valued labels.
    label_logs = tf.math.log(label_ids + eps)
    is_valid_mask = tf.cast(segment_ids, tf.float32)
    # MAE reduces the last axis, giving one loss per token.
    loss_arr = tf.keras.losses.MAE(y_true=label_logs, y_pred=logits)
    loss_arr = loss_arr * is_valid_mask
    loss = tf.reduce_mean(input_tensor=loss_arr)
    tvars = tf.compat.v1.trainable_variables()

    initialized_variable_names = {}
    scaffold_fn = None
    if train_config.init_checkpoint:
        initialized_variable_names, init_fn = get_init_fn(train_config, tvars)
        scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)
    log_var_assignments(tvars, initialized_variable_names)

    TPUEstimatorSpec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec

    def metric_fn(probs, label, is_real_example):
        """Thresholded precision/recall on channel 0, label rates, and MAE."""
        cut = math.exp(-10)
        pred_binary = probs > cut
        label_binary_all = label > cut
        pred_binary = pred_binary[:, :, 0]
        label_binary_1 = label_binary_all[:, :, 1]
        label_binary_0 = label_binary_all[:, :, 0]
        precision = tf.compat.v1.metrics.precision(predictions=pred_binary,
                                                   labels=label_binary_0)
        recall = tf.compat.v1.metrics.recall(predictions=pred_binary,
                                             labels=label_binary_0)
        true_rate_1 = tf.compat.v1.metrics.mean(label_binary_1)
        true_rate_0 = tf.compat.v1.metrics.mean(label_binary_0)
        # NOTE(review): compares sigmoid(logits) with raw labels, while the
        # loss regresses logits onto log(labels) — confirm this is intended.
        mae = tf.compat.v1.metrics.mean_absolute_error(
            labels=label, predictions=probs, weights=is_real_example)

        return {
            "mae": mae,
            "precision": precision,
            "recall": recall,
            "true_rate_1": true_rate_1,
            "true_rate_0": true_rate_0,
        }

    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        tvars = None
        train_op = optimization.create_optimizer_from_config(
            loss, train_config, tvars)
        output_spec = TPUEstimatorSpec(mode=mode,
                                       loss=loss,
                                       train_op=train_op,
                                       scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.EVAL:
        eval_metrics = (metric_fn, [probs, label_ids, is_real_example])
        output_spec = TPUEstimatorSpec(mode=mode,
                                       loss=loss,
                                       eval_metrics=eval_metrics,
                                       scaffold_fn=scaffold_fn)
    else:
        predictions = {
            "input_ids": input_ids,
            "logits": logits,
            "label_ids": label_ids,
        }
        if "data_id" in features:
            predictions['data_id'] = features['data_id']
        output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
            mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
    return output_spec