class SimpleSharingModel:
    """Runs an LM stream and an NLI stream through one shared BERT encoder.

    The two input batches are stacked along the batch axis so a single
    BertModel forward pass serves both; slicing at the LM batch size splits
    the encoder outputs back into their respective halves.
    """

    def __init__(
            self,
            config,
            use_one_hot_embeddings,
            is_training,
            masked_input_ids,
            input_mask,
            segment_ids,
            nli_input_ids,
            nli_input_mask,
            nli_segment_ids,
    ):
        # Stack [LM batch; NLI batch] along axis 0 for one joint forward pass.
        joint_ids = tf.concat([masked_input_ids, nli_input_ids], axis=0)
        joint_mask = tf.concat([input_mask, nli_input_mask], axis=0)
        joint_segments = tf.concat([segment_ids, nli_segment_ids], axis=0)
        # First dim of the LM inputs marks where the NLI half begins.
        self.batch_size, _ = get_shape_list2(masked_input_ids)
        self.model = BertModel(config, is_training, joint_ids, joint_mask,
                               joint_segments, use_one_hot_embeddings)

    def lm_sequence_output(self):
        """Sequence output restricted to the LM (first) half of the batch."""
        return self.model.get_sequence_output()[:self.batch_size]

    def get_embedding_table(self):
        """Token embedding table of the shared encoder."""
        return self.model.get_embedding_table()

    def get_tt_feature(self):
        """Pooled output restricted to the NLI (second) half of the batch."""
        return self.model.get_pooled_output()[self.batch_size:]
class AddLayerSharingModel:
    """Shares a BERT encoder between an LM stream and a "tt" stream, and
    gives the tt stream one extra transformer layer of its own.

    Both batches pass through the shared encoder in a single concatenated
    forward call; the tt half is then refined by a private ForwardLayer and
    pooled into ``self.tt_feature``.
    """

    def __init__(
            self,
            config,
            use_one_hot_embeddings,
            is_training,
            masked_input_ids,
            input_mask,
            segment_ids,
            tt_input_ids,
            tt_input_mask,
            tt_segment_ids,
    ):
        # One joint pass: [LM batch; tt batch] stacked along axis 0.
        joint_ids = tf.concat([masked_input_ids, tt_input_ids], axis=0)
        joint_mask = tf.concat([input_mask, tt_input_mask], axis=0)
        joint_segments = tf.concat([segment_ids, tt_segment_ids], axis=0)
        self.config = config
        self.lm_batch_size, _ = get_shape_list2(masked_input_ids)
        self.model = BertModel(config, is_training, joint_ids, joint_mask,
                               joint_segments, use_one_hot_embeddings)
        init = base.create_initializer(config.initializer_range)
        # Extra transformer layer applied only to the tt half of the batch.
        self.tt_layer = ForwardLayer(config, init)
        self.tt_input_mask = tt_input_mask
        tt_seq = self.model.get_sequence_output()[self.lm_batch_size:]
        tt_batch_size, seq_length = get_shape_list2(tt_input_ids)
        tt_attention_mask = create_attention_mask_from_input_mask2(
            tt_seq, self.tt_input_mask)
        print('tt_attention_mask', tt_attention_mask.shape)
        print("seq_output", tt_seq.shape)
        tt_seq = self.tt_layer.apply_3d(tt_seq, tt_batch_size, seq_length,
                                        tt_attention_mask)
        self.tt_feature = mimic_pooling(tt_seq, self.config.hidden_size,
                                        self.config.initializer_range)

    def lm_sequence_output(self):
        """Shared-encoder sequence output for the LM (first) half."""
        return self.model.get_sequence_output()[:self.lm_batch_size]

    def get_embedding_table(self):
        """Token embedding table of the shared encoder."""
        return self.model.get_embedding_table()

    def get_tt_feature(self):
        """Pooled tt representation produced by the extra layer."""
        return self.tt_feature
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """`model_fn` for TPUEstimator: scores inputs under two BERT variants.

    Builds two separately-scoped BERT encoders ("MaybeBERT"/"MaybeBFN"),
    restored from two different checkpoints, runs the same planned-masked
    inputs through both, and — in PREDICT mode — emits each model's
    per-trial masked-LM losses grouped back per original example.

    NOTE(review): only PREDICT builds an EstimatorSpec; TRAIN/EVAL would
    return None — presumably this model_fn is used for inference only.
    Confirm against the caller.
    """
    logging.info("*** Features ***")
    for name in sorted(features.keys()):
        logging.info(" name = %s, shape = %s" % (name, features[name].shape))
    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    next_sentence_labels = features["next_sentence_labels"]  # unused below
    # Number of masking patterns generated per example.
    n_trial = 25
    logging.info("Doing All Masking")
    masked_input_ids, masked_lm_positions, masked_lm_ids, masked_lm_weights \
        = planned_masking(input_ids, input_mask, train_config.max_predictions_per_seq, MASK_ID, n_trial)
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    # Tile mask/segments n_trial times to line up with the repeated masked
    # batch (presumably planned_masking repeats the batch n_trial times —
    # the reform() reshape below relies on that layout; confirm).
    repeat_input_mask = tf.tile(input_mask, [n_trial, 1])
    repeat_segment_ids = tf.tile(segment_ids, [n_trial, 1])
    # Distinct variable scopes keep the two models' variables disjoint so
    # each can be restored from its own checkpoint.
    prefix1 = "MaybeBERT"
    prefix2 = "MaybeBFN"
    with tf.compat.v1.variable_scope(prefix1):
        model = BertModel(
            config=bert_config,
            is_training=is_training,
            input_ids=masked_input_ids,
            input_mask=repeat_input_mask,
            token_type_ids=repeat_segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        )
        (masked_lm_loss, masked_lm_example_loss1, masked_lm_log_probs2) = get_masked_lm_output(
            bert_config, model.get_sequence_output(), model.get_embedding_table(),
            masked_lm_positions, masked_lm_ids, masked_lm_weights)
    with tf.compat.v1.variable_scope(prefix2):
        model = BertModel(
            config=bert_config,
            is_training=is_training,
            input_ids=masked_input_ids,
            input_mask=repeat_input_mask,
            token_type_ids=repeat_segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        )
        (masked_lm_loss, masked_lm_example_loss2, masked_lm_log_probs2) = get_masked_lm_output(
            bert_config, model.get_sequence_output(), model.get_embedding_table(),
            masked_lm_positions, masked_lm_ids, masked_lm_weights)
    n_mask = train_config.max_predictions_per_seq

    def reform(t):
        # [n_trial * batch, n_mask] -> [batch, n_trial, n_mask]
        t = tf.reshape(t, [n_trial, -1, n_mask])
        t = tf.transpose(t, [1, 0, 2])
        return t

    grouped_positions = reform(masked_lm_positions)
    grouped_loss1 = reform(masked_lm_example_loss1)
    grouped_loss2 = reform(masked_lm_example_loss2)
    tvars = tf.compat.v1.trainable_variables()
    scaffold_fn = None
    # Restore prefix1-scoped vars from init_checkpoint and prefix2-scoped
    # vars from second_init_checkpoint.
    initialized_variable_names, init_fn = get_init_fn_for_two_checkpoints(
        train_config, tvars,
        train_config.init_checkpoint, prefix1,
        train_config.second_init_checkpoint, prefix2)
    if train_config.use_tpu:
        # On TPU, initialization must happen inside the scaffold factory.
        def tpu_scaffold():
            init_fn()
            return tf.compat.v1.train.Scaffold()

        scaffold_fn = tpu_scaffold
    else:
        init_fn()
    log_var_assignments(tvars, initialized_variable_names)
    output_spec = None
    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            "input_ids": input_ids,
            "input_mask": input_mask,
            "segment_ids": segment_ids,
            "grouped_positions": grouped_positions,
            "grouped_loss1": grouped_loss1,
            "grouped_loss2": grouped_loss2,
        }
        output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
            mode=mode, loss=None, predictions=predictions, scaffold_fn=scaffold_fn)
    return output_spec
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """`model_fn` for TPUEstimator: compares hidden states of two BERTs.

    Runs the same (unmasked) inputs through two separately-scoped encoders
    ("MaybeBERT"/"MaybeNLI", restored from two checkpoints) and counts, per
    token, how many hidden dimensions of encoder layer index 1 agree within
    `threshold`. Only PREDICT mode builds an EstimatorSpec; other modes
    return None.
    """
    logging.info("*** Features ***")
    for name in sorted(features.keys()):
        logging.info(" name = %s, shape = %s" % (name, features[name].shape))
    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    next_sentence_labels = features["next_sentence_labels"]  # unused below
    seed = 0
    # Max absolute activation difference for a dim to count as "preserved".
    threshold = 1e-2
    logging.info("Doing All Masking")
    # Masked-LM tensors are built with a fixed seed, but note the encoders
    # below consume the ORIGINAL input_ids, not masked_input_ids.
    masked_input_ids, masked_lm_positions, masked_lm_ids, masked_lm_weights \
        = random_masking(input_ids, input_mask, train_config.max_predictions_per_seq, MASK_ID, seed)
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    prefix1 = "MaybeBERT"
    prefix2 = "MaybeNLI"
    with tf.compat.v1.variable_scope(prefix1):
        model = BertModel(
            config=bert_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        )
        # NOTE(review): the LM head outputs are never used below —
        # presumably built so checkpoint variables restore cleanly; confirm.
        (masked_lm_loss, masked_lm_example_loss1, masked_lm_log_probs2) = get_masked_lm_output(
            bert_config, model.get_sequence_output(), model.get_embedding_table(),
            masked_lm_positions, masked_lm_ids, masked_lm_weights)
        all_layers1 = model.get_all_encoder_layers()
    with tf.compat.v1.variable_scope(prefix2):
        model = BertModel(
            config=bert_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        )
        all_layers2 = model.get_all_encoder_layers()
    # Element-wise comparison of corresponding encoder layers.
    preserved_infos = []
    for a_layer, b_layer in zip(all_layers1, all_layers2):
        layer_diff = a_layer - b_layer
        is_preserved = tf.less(tf.abs(layer_diff), threshold)
        preserved_infos.append(is_preserved)
    # Only layer index 1 is reported.
    t = tf.cast(preserved_infos[1], dtype=tf.int32)  #[batch_size, seq_len, dims]
    # Per-token count of "preserved" hidden dimensions.
    layer_1_count = tf.reduce_sum(t, axis=2)
    tvars = tf.compat.v1.trainable_variables()
    # prefix1 vars come from init_checkpoint, prefix2 vars from
    # second_init_checkpoint.
    initialized_variable_names, init_fn = get_init_fn_for_two_checkpoints(train_config, tvars,
                                                                          train_config.init_checkpoint, prefix1,
                                                                          train_config.second_init_checkpoint, prefix2)
    scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)
    log_var_assignments(tvars, initialized_variable_names)
    output_spec = None
    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            "input_ids": input_ids,
            "layer_count": layer_1_count
        }
        output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
            mode=mode, loss=None, predictions=predictions, scaffold_fn=scaffold_fn)
    return output_spec
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """`model_fn` for TPUEstimator: trains a model to predict two teachers'
    masked-LM losses.

    Two teacher passes ("MaybeBERT"/"MaybeBFN", restored from the two
    comma-separated paths in second_init_checkpoint) produce per-position
    masked-LM losses; a third model (model_class, restored from
    init_checkpoint) is trained via IndependentLossModel to predict those
    losses. Teacher losses are wrapped in stop_gradient so only the third
    model receives gradients.
    """
    log_features(features)
    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    next_sentence_labels = features["next_sentence_labels"]  # unused below
    masked_input_ids, masked_lm_positions, masked_lm_ids, masked_lm_weights \
        = random_masking(input_ids, input_mask, train_config.max_predictions_per_seq, MASK_ID)
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    # Distinct scopes keep the two teachers' variables disjoint for
    # per-checkpoint restore.
    prefix1 = "MaybeBERT"
    prefix2 = "MaybeBFN"
    with tf.compat.v1.variable_scope(prefix1):
        model1 = BertModel(
            config=bert_config,
            is_training=is_training,
            input_ids=masked_input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        )
        (masked_lm_loss, masked_lm_example_loss1, masked_lm_log_probs1) = get_masked_lm_output(
            bert_config, model1.get_sequence_output(), model1.get_embedding_table(),
            masked_lm_positions, masked_lm_ids, masked_lm_weights)
        # Flat per-position losses -> [batch, max_predictions_per_seq].
        masked_lm_example_loss1 = tf.reshape(masked_lm_example_loss1, masked_lm_ids.shape)
    with tf.compat.v1.variable_scope(prefix2):
        model2 = BertModel(
            config=bert_config,
            is_training=is_training,
            input_ids=masked_input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=train_config.use_one_hot_embeddings,
        )
        (masked_lm_loss, masked_lm_example_loss2, masked_lm_log_probs2) = get_masked_lm_output(
            bert_config, model2.get_sequence_output(), model2.get_embedding_table(),
            masked_lm_positions, masked_lm_ids, masked_lm_weights)
        print(model2.get_sequence_output().shape)
        masked_lm_example_loss2 = tf.reshape(masked_lm_example_loss2, masked_lm_ids.shape)
    # The trained model reads the ORIGINAL (unmasked) input_ids.
    model = model_class(
        config=bert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=train_config.use_one_hot_embeddings,
    )
    loss_model = IndependentLossModel(bert_config)
    # stop_gradient: teacher losses are regression targets, not trainable
    # paths.
    loss_model.train_modeling(model.get_sequence_output(), masked_lm_positions,
                              masked_lm_weights,
                              tf.stop_gradient(masked_lm_example_loss1),
                              tf.stop_gradient(masked_lm_example_loss2))
    total_loss = loss_model.total_loss
    loss1 = loss_model.loss1
    loss2 = loss_model.loss2
    per_example_loss1 = loss_model.per_example_loss1
    per_example_loss2 = loss_model.per_example_loss2
    # Sum over mask positions -> one scalar per example, per head.
    losses1 = tf.reduce_sum(per_example_loss1, axis=1)
    losses2 = tf.reduce_sum(per_example_loss2, axis=1)
    prob1 = loss_model.prob1
    prob2 = loss_model.prob2
    # second_init_checkpoint carries BOTH teacher checkpoints,
    # comma-separated.
    checkpoint2_1, checkpoint2_2 = train_config.second_init_checkpoint.split(
        ",")
    tvars = tf.compat.v1.trainable_variables()
    initialized_variable_names_1, init_fn_1 = get_init_fn_for_two_checkpoints(
        train_config, tvars, checkpoint2_1, prefix1, checkpoint2_2, prefix2)
    assignment_fn = get_bert_assignment_map
    assignment_map2, initialized_variable_names_2 = assignment_fn(
        tvars, train_config.init_checkpoint)
    initialized_variable_names = {}
    initialized_variable_names.update(initialized_variable_names_1)
    initialized_variable_names.update(initialized_variable_names_2)

    def init_fn():
        # Teachers first, then the trained model from init_checkpoint.
        init_fn_1()
        tf.compat.v1.train.init_from_checkpoint(
            train_config.init_checkpoint, assignment_map2)

    scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)
    log_var_assignments(tvars, initialized_variable_names)
    if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = optimization.create_optimizer_from_config(
            total_loss, train_config)
        output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
            mode=mode, loss=total_loss, train_op=train_op, scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.EVAL:
        def metric_fn(per_example_loss1, per_example_loss2):
            # Mean of the per-example summed losses for each head.
            loss1 = tf.compat.v1.metrics.mean(values=per_example_loss1)
            loss2 = tf.compat.v1.metrics.mean(values=per_example_loss2)
            return {
                "loss1": loss1,
                "loss2": loss2,
            }

        eval_metrics = (metric_fn, [losses1, losses2])
        output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
            mode=mode, loss=total_loss, eval_metrics=eval_metrics, scaffold_fn=scaffold_fn)
    else:
        predictions = {
            "prob1": prob1,
            "prob2": prob2,
            "per_example_loss1": per_example_loss1,
            "per_example_loss2": per_example_loss2,
            "input_ids": input_ids,
            "masked_lm_positions": masked_lm_positions,
        }
        output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
            mode=mode, loss=total_loss, predictions=predictions, scaffold_fn=scaffold_fn)
    return output_spec
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """`model_fn` for TPUEstimator: masked-LM training over [word; text].

    Prepends the "word" token block to the regular inputs, applies random
    masking (with a fixed seed in PREDICT mode), and trains/evaluates a
    single BertModel with a masked-LM head.
    """
    tf_logging.info("model_fn_apr_lm")
    log_features(features)

    raw_input_ids = features["input_ids"]  # [batch_size, seq_length]
    raw_input_mask = features["input_mask"]
    raw_segment_ids = features["segment_ids"]
    word_tokens = features["word"]

    # Word block: mask covers non-padding ids (!= 0); segment id is 1.
    word_input_mask = tf.cast(tf.not_equal(word_tokens, 0), tf.int32)
    word_segment_ids = tf.ones_like(word_tokens, tf.int32)

    # Deterministic masking for prediction; unseeded otherwise.
    if mode == tf.estimator.ModeKeys.PREDICT:
        tf.random.set_seed(0)
        seed = 0
    else:
        seed = None

    # Word tokens come first in the concatenated sequence.
    input_ids = tf.concat([word_tokens, raw_input_ids], axis=1)
    input_mask = tf.concat([word_input_mask, raw_input_mask], axis=1)
    segment_ids = tf.concat([word_segment_ids, raw_segment_ids], axis=1)

    is_training = mode == tf.estimator.ModeKeys.TRAIN
    tf_logging.info("Using masked_input_ids")
    masked_input_ids, masked_lm_positions, masked_lm_ids, masked_lm_weights = \
        random_masking(input_ids, input_mask,
                       train_config.max_predictions_per_seq, MASK_ID, seed)

    model = BertModel(
        config=config,
        is_training=is_training,
        input_ids=masked_input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=train_config.use_one_hot_embeddings,
    )
    masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs = \
        get_masked_lm_output(config, model.get_sequence_output(),
                             model.get_embedding_table(), masked_lm_positions,
                             masked_lm_ids, masked_lm_weights)
    loss = masked_lm_loss

    tvars = tf.compat.v1.trainable_variables()
    assignment_fn = tlm.training.assignment_map.get_bert_assignment_map
    initialized_variable_names, init_fn = get_init_fn(
        tvars, train_config.init_checkpoint, assignment_fn)
    scaffold_fn = get_tpu_scaffold_or_init(init_fn, train_config.use_tpu)
    log_var_assignments(tvars, initialized_variable_names)

    spec_cls = tf.compat.v1.estimator.tpu.TPUEstimatorSpec
    if mode == tf.estimator.ModeKeys.TRAIN:
        tf_logging.info("Using single lr ")
        train_op = optimization.create_optimizer_from_config(loss,
                                                             train_config)
        return spec_cls(mode=mode,
                        loss=loss,
                        train_op=train_op,
                        scaffold_fn=scaffold_fn)
    if mode == tf.estimator.ModeKeys.EVAL:
        eval_metrics = (metric_fn_lm, [
            masked_lm_example_loss,
            masked_lm_log_probs,
            masked_lm_ids,
            masked_lm_weights,
        ])
        return spec_cls(mode=mode,
                        loss=loss,
                        eval_metrics=eval_metrics,
                        scaffold_fn=scaffold_fn)
    predictions = {
        "input_ids": input_ids,
        "masked_input_ids": masked_input_ids,
        "masked_lm_ids": masked_lm_ids,
        "masked_lm_example_loss": masked_lm_example_loss,
        "masked_lm_positions": masked_lm_positions
    }
    return spec_cls(mode=mode,
                    loss=loss,
                    predictions=predictions,
                    scaffold_fn=scaffold_fn)