def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator.

    Builds a masked-LM model with an auxiliary "label token" LM head:
    regular random masking plus a second masked-LM loss computed (under the
    `label_token` variable scope) over label positions extracted by
    `get_label_indices`. Total loss is
    masked_lm_loss + masked_lm_loss_label * model_config.ratio
    (note: next_sentence_loss is computed but NOT added to the total loss).

    Relies on names captured from the enclosing scope: `model_class`,
    `model_config`, `train_config`, `get_masked_lm_output_fn`, `metric_fn`,
    `MASK_ID`, `LABEL_*` constants, etc.
    """
    tf_logging.info("*** Features ***")
    for name in sorted(features.keys()):
        tf_logging.info(" name = %s, shape = %s" % (name, features[name].shape))

    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    # next_sentence_labels is optional in the input pipeline; substitute a
    # dummy tensor when absent so get_next_sentence_output can still build.
    if "next_sentence_labels" in features:
        next_sentence_labels = features["next_sentence_labels"]
    else:
        next_sentence_labels = get_dummy_next_sentence_labels(input_ids)

    # In PREDICT mode, fix all randomness (graph seed and masking seed) so
    # the dynamic masking is reproducible across runs.
    if mode == tf.estimator.ModeKeys.PREDICT:
        tf.random.set_seed(0)
        seed = 0
        print("Seed as zero")
    else:
        seed = None

    tf_logging.info("Doing dynamic masking (random)")
    # Tokens that random_masking must never select as mask targets.
    special_tokens = [LABEL_UNK, LABEL_0, LABEL_1, LABEL_2]
    masked_input_ids, masked_lm_positions, masked_lm_ids, masked_lm_weights \
        = random_masking(input_ids, input_mask,
                         train_config.max_predictions_per_seq,
                         MASK_ID, seed, special_tokens)
    # Second pass: extract label-token positions/ids from the masked sequence.
    # NOTE(review): masked_input_ids is re-assigned here — get_label_indices
    # presumably also rewrites the label positions in the input; confirm.
    masked_input_ids, masked_lm_positions_label, masked_label_ids_label, is_test_inst \
        = get_label_indices(masked_input_ids)

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    model = model_class(
        config=model_config,
        is_training=is_training,
        input_ids=masked_input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=train_config.use_one_hot_embeddings,
    )

    # Standard masked-LM head over the randomly masked positions.
    (masked_lm_loss, masked_lm_example_loss,
     masked_lm_log_probs) = get_masked_lm_output_fn(
        model_config, model.get_sequence_output(), model.get_embedding_table(),
        masked_lm_positions, masked_lm_ids, masked_lm_weights)

    # Separate head (own variables via the "label_token" scope) for the
    # label-token positions; is_test_inst is passed where weights normally go.
    with tf.compat.v1.variable_scope("label_token"):
        (masked_lm_loss_label, masked_lm_example_loss_label,
         masked_lm_log_probs_label) = get_masked_lm_output_fn(
            model_config, model.get_sequence_output(), model.get_embedding_table(),
            masked_lm_positions_label, masked_label_ids_label, is_test_inst)

    # Built for its graph side effects / eval symmetry; its loss is unused below.
    (next_sentence_loss, next_sentence_example_loss,
     next_sentence_log_probs) = get_next_sentence_output(
        model_config, model.get_pooled_output(), next_sentence_labels)

    # ratio weights the label-token LM loss against the ordinary LM loss.
    total_loss = masked_lm_loss + masked_lm_loss_label * model_config.ratio

    tvars = tf.compat.v1.trainable_variables()
    # Map checkpoint variables (possibly from two checkpoints) onto tvars.
    initialized_variable_names, initialized_variable_names2, init_fn \
        = align_checkpoint_for_lm(tvars,
                                  train_config.checkpoint_type,
                                  train_config.init_checkpoint,
                                  train_config.second_init_checkpoint,
                                  )
    scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)
    log_var_assignments(tvars, initialized_variable_names, initialized_variable_names2)

    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = optimization.create_optimizer_from_config(total_loss, train_config)
        output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=total_loss,
            train_op=train_op,
            training_hooks=[OomReportingHook()],  # logs memory info on OOM
            scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.EVAL:
        eval_metrics = (metric_fn, [
            masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
            masked_lm_weights, masked_lm_example_loss_label,
            masked_lm_log_probs_label, masked_label_ids_label, is_test_inst
        ])
        output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=total_loss,
            eval_metrics=eval_metrics,
            scaffold_fn=scaffold_fn)
    else:
        # PREDICT: expose per-example losses and label-head outputs.
        predictions = {
            "input_ids": input_ids,
            "masked_input_ids": masked_input_ids,
            "masked_lm_ids": masked_lm_ids,
            "masked_lm_example_loss": masked_lm_example_loss,
            "masked_lm_positions": masked_lm_positions,
            "masked_lm_example_loss_label": masked_lm_example_loss_label,
            "masked_lm_log_probs_label": masked_lm_log_probs_label,
            "masked_label_ids_label": masked_label_ids_label,
            "is_test_inst": is_test_inst,
        }
        output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=total_loss,
            predictions=predictions,
            scaffold_fn=scaffold_fn)
    return output_spec
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator.

    Dictionary-augmented BERT variant: alongside the regular inputs it feeds
    dictionary-definition inputs (`d_input_ids`, `d_input_mask`,
    `d_location_ids`, optional `d_segment_ids`) into `model_class`.
    `dict_run_config` selects the training objective (masked LM by default,
    or "entry_prediction" / "lookup") and the prediction output
    ("gradient", "loss", "loss_fixed_mask").

    Relies on names captured from the enclosing scope: `model_class`,
    `bert_config`, `dbert_config`, `train_config`, `dict_run_config`,
    `bert_common`, `metric_fn`, `MASK_ID`, etc.
    """
    logging.info("*** Features ***")
    for name in sorted(features.keys()):
        logging.info(" name = %s, shape = %s" % (name, features[name].shape))

    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    d_input_ids = features["d_input_ids"]
    d_input_mask = features["d_input_mask"]
    d_location_ids = features["d_location_ids"]
    next_sentence_labels = features["next_sentence_labels"]

    # Deterministic masking when we are measuring per-example loss.
    if dict_run_config.prediction_op == "loss":
        seed = 0
    else:
        seed = None

    if dict_run_config.prediction_op == "loss_fixed_mask" or train_config.fixed_mask:
        # Masking was precomputed by the input pipeline; use it as-is and
        # weight every position equally.
        masked_input_ids = input_ids
        masked_lm_positions = features["masked_lm_positions"]
        masked_lm_ids = features["masked_lm_ids"]
        masked_lm_weights = tf.ones_like(masked_lm_positions, dtype=tf.float32)
    else:
        masked_input_ids, masked_lm_positions, masked_lm_ids, masked_lm_weights \
            = random_masking(input_ids, input_mask,
                             train_config.max_predictions_per_seq, MASK_ID, seed)

    if dict_run_config.use_d_segment_ids:
        d_segment_ids = features["d_segment_ids"]
    else:
        d_segment_ids = None

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    model = model_class(
        config=bert_config,
        d_config=dbert_config,
        is_training=is_training,
        input_ids=masked_input_ids,
        input_mask=input_mask,
        d_input_ids=d_input_ids,
        d_input_mask=d_input_mask,
        d_location_ids=d_location_ids,
        use_target_pos_emb=dict_run_config.use_target_pos_emb,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        d_segment_ids=d_segment_ids,
        pool_dict_output=dict_run_config.pool_dict_output,
    )

    (masked_lm_loss, masked_lm_example_loss,
     masked_lm_log_probs) = get_masked_lm_output(
        bert_config, model.get_sequence_output(), model.get_embedding_table(),
        masked_lm_positions, masked_lm_ids, masked_lm_weights)
    # Built for eval metrics; next_sentence_loss is not added to total_loss.
    (next_sentence_loss, next_sentence_example_loss,
     next_sentence_log_probs) = get_next_sentence_output(
        bert_config, model.get_pooled_output(), next_sentence_labels)

    total_loss = masked_lm_loss

    if dict_run_config.train_op == "entry_prediction":
        # Binary classification: is the dictionary entry useful?
        score_label = features["useful_entry"]  # [batch, 1]
        score_label = tf.reshape(score_label, [-1])
        entry_logits = bert_common.dense(2, bert_common.create_initializer(bert_config.initializer_range))\
            (model.get_dict_pooled_output())
        print("entry_logits: ", entry_logits.shape)
        losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=entry_logits, labels=score_label)
        loss = tf.reduce_mean(losses)
        # Replaces (does not add to) the masked-LM loss.
        total_loss = loss

    if dict_run_config.train_op == "lookup":
        # Auxiliary objective: predict which sequence index to look up.
        lookup_idx = features["lookup_idx"]
        lookup_loss, lookup_example_loss, lookup_score = \
            sequence_index_prediction(bert_config, lookup_idx, model.get_sequence_output())
        total_loss += lookup_loss

    tvars = tf.compat.v1.trainable_variables()
    init_vars = {}
    scaffold_fn = None
    if train_config.init_checkpoint:
        if dict_run_config.is_bert_checkpoint:
            # Plain BERT checkpoint: two assignment maps (main + dict tower).
            map1, map2, init_vars = get_bert_assignment_map_for_dict(tvars, train_config.init_checkpoint)

            def load_fn():
                tf.compat.v1.train.init_from_checkpoint(train_config.init_checkpoint, map1)
                tf.compat.v1.train.init_from_checkpoint(train_config.init_checkpoint, map2)
        else:
            # Checkpoint already matches this model's variable names.
            map1, init_vars = get_assignment_map_as_is(tvars, train_config.init_checkpoint)

            def load_fn():
                tf.compat.v1.train.init_from_checkpoint(train_config.init_checkpoint, map1)

        if train_config.use_tpu:
            # On TPU, checkpoint loading must happen inside the scaffold fn.
            def tpu_scaffold():
                load_fn()
                return tf.compat.v1.train.Scaffold()

            scaffold_fn = tpu_scaffold
        else:
            load_fn()

    logging.info("**** Trainable Variables ****")
    for var in tvars:
        init_string = ""
        if var.name in init_vars:
            init_string = ", *INIT_FROM_CKPT*"
        logging.info(" name = %s, shape = %s%s", var.name, var.shape, init_string)
    logging.info("Total parameters : %d" % get_param_num())

    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        if train_config.gradient_accumulation == 1:
            train_op = optimization.create_optimizer_from_config(total_loss, train_config)
        else:
            logging.info("Using gradient accumulation : %d" % train_config.gradient_accumulation)
            train_op = get_accumulated_optimizer_from_config(
                total_loss, train_config, tvars, train_config.gradient_accumulation)
        output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=total_loss,
            train_op=train_op,
            scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.EVAL:
        eval_metrics = (metric_fn, [
            masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
            masked_lm_weights, next_sentence_example_loss,
            next_sentence_log_probs, next_sentence_labels
        ])
        output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=total_loss,
            eval_metrics=eval_metrics,
            scaffold_fn=scaffold_fn)
    else:
        # PREDICT: output depends on the requested prediction_op.
        if dict_run_config.prediction_op == "gradient":
            logging.info("Fetching gradient")
            gradient = get_gradients(model, masked_lm_log_probs,
                                     train_config.max_predictions_per_seq, bert_config.vocab_size)
            predictions = {
                "masked_input_ids": masked_input_ids,
                #"input_ids": input_ids,
                "d_input_ids": d_input_ids,
                "masked_lm_positions": masked_lm_positions,
                "gradients": gradient,
            }
        elif dict_run_config.prediction_op == "loss" or dict_run_config.prediction_op == "loss_fixed_mask":
            logging.info("Fetching loss")
            predictions = {
                "masked_lm_example_loss": masked_lm_example_loss,
            }
        else:
            raise Exception("prediction target not specified")
        output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=total_loss,
            predictions=predictions,
            scaffold_fn=scaffold_fn)
    return output_spec
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator.

    Biased-masking variant: a frozen priority model (scoped under
    `target_task`, gradients stopped) scores tokens, and `biased_masking`
    uses those scores to choose mask positions. The main LM is then trained
    with masked-LM + next-sentence loss; the priority model's variables are
    excluded from the optimizer.

    Relies on names captured from the enclosing scope: `model_class`,
    `bert_config`, `train_config`, `target_model_config`, `priority_model`,
    `metric_fn`, `MASK_ID`, etc.
    """
    tf_logging.info("*** Features ***")
    for name in sorted(features.keys()):
        tf_logging.info(" name = %s, shape = %s" % (name, features[name].shape))

    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    if "next_sentence_labels" in features:
        next_sentence_labels = features["next_sentence_labels"]
    else:
        next_sentence_labels = get_dummy_next_sentence_labels(input_ids)

    # All priority-model variables live under this prefix so they can be
    # loaded from a separate checkpoint and excluded from training below.
    tlm_prefix = "target_task"
    with tf.compat.v1.variable_scope(tlm_prefix):
        # stop_gradient: the priority model is frozen — only used for scoring.
        priority_score = tf.stop_gradient(priority_model(features))
        # amp scales how sharply the score biases mask selection.
        priority_score = priority_score * target_model_config.amp
    masked_input_ids, masked_lm_positions, masked_lm_ids, masked_lm_weights \
        = biased_masking(input_ids, input_mask, priority_score,
                         target_model_config.alpha,
                         train_config.max_predictions_per_seq, MASK_ID)

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    model = model_class(
        config=bert_config,
        is_training=is_training,
        input_ids=masked_input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=train_config.use_one_hot_embeddings,
    )
    (masked_lm_loss, masked_lm_example_loss,
     masked_lm_log_probs) = get_masked_lm_output(
        bert_config, model.get_sequence_output(), model.get_embedding_table(),
        masked_lm_positions, masked_lm_ids, masked_lm_weights)
    (next_sentence_loss, next_sentence_example_loss,
     next_sentence_log_probs) = get_next_sentence_output(
        bert_config, model.get_pooled_output(), next_sentence_labels)
    total_loss = masked_lm_loss + next_sentence_loss

    # all_variables (not just trainable): the frozen priority model's
    # variables also need checkpoint initialization.
    all_vars = tf.compat.v1.all_variables()
    tf_logging.info("We assume priority model is from v2")
    if train_config.checkpoint_type == "v2":
        # Main model from init_checkpoint; priority model (remapped under
        # tlm_prefix) from second_init_checkpoint.
        assignment_map, initialized_variable_names = assignment_map_v2_to_v2(
            all_vars, train_config.init_checkpoint)
        assignment_map2, initialized_variable_names2 = get_assignment_map_remap_from_v2(
            all_vars, tlm_prefix, train_config.second_init_checkpoint)
    else:
        assignment_map, assignment_map2, initialized_variable_names \
            = get_tlm_assignment_map_v2(all_vars, tlm_prefix,
                                        train_config.init_checkpoint,
                                        train_config.second_init_checkpoint)
        initialized_variable_names2 = None

    def init_fn():
        # Each checkpoint is optional; load whichever is configured.
        if train_config.init_checkpoint:
            tf.compat.v1.train.init_from_checkpoint(
                train_config.init_checkpoint, assignment_map)
        if train_config.second_init_checkpoint:
            tf.compat.v1.train.init_from_checkpoint(
                train_config.second_init_checkpoint, assignment_map2)

    scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)

    # Train only the main model: drop everything under the priority prefix.
    tvars = [v for v in all_vars if not v.name.startswith(tlm_prefix)]
    log_var_assignments(tvars, initialized_variable_names, initialized_variable_names2)

    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        # Pass tvars explicitly so the optimizer skips priority-model vars.
        train_op = optimization.create_optimizer_from_config(
            total_loss, train_config, tvars)
        output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=total_loss,
            train_op=train_op,
            scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.EVAL:
        eval_metrics = (metric_fn, [
            masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
            masked_lm_weights, next_sentence_example_loss,
            next_sentence_log_probs, next_sentence_labels
        ])
        output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=total_loss,
            eval_metrics=eval_metrics,
            scaffold_fn=scaffold_fn)
    else:
        # PREDICT: expose the priority scores and the precomputed per-example
        # LM losses carried in through the features ("loss1"/"loss2").
        predictions = {
            "input_ids": input_ids,
            "masked_input_ids": masked_input_ids,
            "priority_score": priority_score,
            "lm_loss1": features["loss1"],
            "lm_loss2": features["loss2"],
        }
        output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=total_loss,
            predictions=predictions,
            scaffold_fn=scaffold_fn)
    return output_spec
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator.

    Virtual-label scoring variant: every position (except the first 20) is
    masked one at a time (`one_by_one_masking`), the resulting sequences are
    replicated once per candidate class, and each copy is conditioned on a
    different "virtual" label id via `BertModelWithLabelInner`. The summed
    per-position masked-LM loss under each label is turned into a per-label
    score: the label whose conditioning yields the lowest LM loss wins.

    Relies on names captured from the enclosing scope: `model_config`,
    `train_config`, `get_masked_lm_output_fn`, `metric_fn_lm`, `MASK_ID`, etc.
    """
    tf_logging.info("*** Features ***")
    for name in sorted(features.keys()):
        tf_logging.info(" name = %s, shape = %s" % (name, features[name].shape))

    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    # No NSP labels in this pipeline; a dummy tensor keeps the NSP head buildable.
    next_sentence_labels = get_dummy_next_sentence_labels(input_ids)
    batch_size, seq_length = get_batch_and_seq_length(input_ids, 2)
    # One masking trial per maskable position; the first 20 positions are
    # skipped (presumably reserved prompt/label tokens — TODO confirm).
    n_trial = seq_length - 20

    masked_input_ids, masked_lm_positions, masked_lm_ids, masked_lm_weights \
        = one_by_one_masking(input_ids, input_mask, MASK_ID, n_trial)
    num_classes = train_config.num_classes
    n_repeat = num_classes * n_trial

    # Replicate so each (trial, example) pair appears once per class:
    # [num_classes * n_trial * batch_size, seq_length].
    # masked_input_ids already has n_trial copies per example, so it is tiled
    # by num_classes only; the un-masked tensors are tiled by the full n_repeat.
    repeat_masked_input_ids = tf.tile(masked_input_ids, [num_classes, 1])
    repeat_input_mask = tf.tile(input_mask, [n_repeat, 1])
    repeat_segment_ids = tf.tile(segment_ids, [n_repeat, 1])
    masked_lm_positions = tf.tile(masked_lm_positions, [num_classes, 1])
    masked_lm_ids = tf.tile(masked_lm_ids, [num_classes, 1])
    masked_lm_weights = tf.tile(masked_lm_weights, [num_classes, 1])
    next_sentence_labels = tf.tile(next_sentence_labels, [n_repeat, 1])

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    # One candidate label id per replicated row: class c for the c-th chunk.
    virtual_labels_ids = tf.tile(tf.expand_dims(tf.range(num_classes), 0),
                                 [1, batch_size * n_trial])
    virtual_labels_ids = tf.reshape(virtual_labels_ids, [-1, 1])

    print("repeat_masked_input_ids", repeat_masked_input_ids.shape)
    print("repeat_input_mask", repeat_input_mask.shape)
    print("virtual_labels_ids", virtual_labels_ids.shape)

    model = BertModelWithLabelInner(
        config=model_config,
        is_training=is_training,
        input_ids=repeat_masked_input_ids,
        input_mask=repeat_input_mask,
        token_type_ids=repeat_segment_ids,
        use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        label_ids=virtual_labels_ids,
    )
    (masked_lm_loss, masked_lm_example_loss,
     masked_lm_log_probs) = get_masked_lm_output_fn(
        model_config, model.get_sequence_output(), model.get_embedding_table(),
        masked_lm_positions, masked_lm_ids, masked_lm_weights)
    # Built for graph symmetry; next_sentence_loss is not used in total_loss.
    (next_sentence_loss, next_sentence_example_loss,
     next_sentence_log_probs) = get_next_sentence_output(
        model_config, model.get_pooled_output(), next_sentence_labels)
    total_loss = masked_lm_loss

    # loss = -log(prob): sum LM loss over the n_trial maskings of each
    # example, per candidate label, then negate so higher score = better label.
    per_case_loss = tf.reshape(masked_lm_example_loss, [num_classes, -1, batch_size])
    per_label_loss = tf.reduce_sum(per_case_loss, axis=1)
    # BUGFIX: was tf.zeros([3, 1]) — hard-coded to 3 classes. per_label_loss
    # is [num_classes, batch_size], so broadcasting failed whenever
    # train_config.num_classes != 3. Zeros make it a no-op placeholder bias;
    # size it from num_classes so the graph builds for any class count.
    bias = tf.zeros([num_classes, 1])
    per_label_score = tf.transpose(-per_label_loss + bias, [1, 0])

    tvars = tf.compat.v1.trainable_variables()
    initialized_variable_names, initialized_variable_names2, init_fn \
        = align_checkpoint_for_lm(tvars,
                                  train_config.checkpoint_type,
                                  train_config.init_checkpoint,
                                  train_config.second_init_checkpoint,
                                  )
    scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)
    log_var_assignments(tvars, initialized_variable_names, initialized_variable_names2)

    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = optimization.create_optimizer_from_config(total_loss, train_config)
        output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=total_loss,
            train_op=train_op,
            training_hooks=[OomReportingHook()],  # logs memory info on OOM
            scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.EVAL:
        eval_metrics = (metric_fn_lm, [
            masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
            masked_lm_weights,
        ])
        output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=total_loss,
            eval_metrics=eval_metrics,
            scaffold_fn=scaffold_fn)
    else:
        # PREDICT: per-label scores, shape [batch_size, num_classes].
        predictions = {"input_ids": input_ids, "logits": per_label_score}
        output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=total_loss,
            predictions=predictions,
            scaffold_fn=scaffold_fn)
    return output_spec