def _get_pred_loss(self, teacher_logits, student_logits, weights):
    teacher_probs = tf.nn.softmax(teacher_logits, axis=-1)
    teacher_probs = tf.stop_gradient(teacher_probs)
    student_log_probs = tf.nn.log_softmax(student_logits, axis=-1)
    pred_loss = (
        - tf.reduce_sum(teacher_probs * student_log_probs, axis=-1) *
        tf.reshape(weights, [-1, 1]))
    pred_loss = tf.reduce_mean(pred_loss)
    return pred_loss
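# A minimal, self-contained sketch (with made-up logits) of what
# _get_pred_loss computes: the soft-label distillation loss, i.e. a
# sample-weighted cross-entropy between the teacher's probabilities and the
# student's log-probabilities. Here the logits are [batch, num_classes]; in
# the method above they may carry an extra sequence dimension, which is why
# `weights` is reshaped to [-1, 1] there so it broadcasts over positions.
import tensorflow as tf

teacher_logits = tf.constant([[2.0, 0.5, -1.0]])
student_logits = tf.constant([[1.5, 0.2, -0.5]])
weights = tf.constant([1.0])  # per-example weights

teacher_probs = tf.stop_gradient(tf.nn.softmax(teacher_logits, axis=-1))
student_log_probs = tf.nn.log_softmax(student_logits, axis=-1)
pred_loss = tf.reduce_mean(
    -tf.reduce_sum(teacher_probs * student_log_probs, axis=-1) * weights)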
def _cache_mem(curr_out, prev_mem, mem_len, reuse_len=None):
    '''Cache hidden states into memory.'''
    if mem_len is None or mem_len == 0:
        return None
    else:
        if reuse_len is not None and reuse_len > 0:
            curr_out = curr_out[:reuse_len]

        if prev_mem is None:
            new_mem = curr_out[-mem_len:]
        else:
            new_mem = tf.concat([prev_mem, curr_out], 0)[-mem_len:]

    return tf.stop_gradient(new_mem)
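# A toy, hedged illustration of the Transformer-XL-style update performed by
# _cache_mem: the previous memory and the current hidden states are
# concatenated along the time axis, only the last `mem_len` positions are
# kept, and the result is detached from the gradient.
import tensorflow as tf

prev_mem = tf.constant([[1.0], [2.0]])          # 2 cached time steps
curr_out = tf.constant([[3.0], [4.0], [5.0]])   # 3 new hidden states
mem_len = 3

new_mem = tf.stop_gradient(
    tf.concat([prev_mem, curr_out], 0)[-mem_len:])  # keeps 3.0, 4.0, 5.0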
def _get_embedding_loss(self, teacher, student, bert_config, weights):
    teacher_embedding = teacher.get_embedding_output()
    teacher_embedding = tf.stop_gradient(teacher_embedding)

    student_embedding = student.get_embedding_output()
    with tf.variable_scope('embedding_loss'):
        linear_trans = tf.layers.dense(
            student_embedding, bert_config.hidden_size,
            kernel_initializer=util.create_initializer(
                bert_config.initializer_range))
        embedding_loss = tf.losses.mean_squared_error(
            linear_trans, teacher_embedding,
            weights=tf.reshape(weights, [-1, 1, 1]))
    return embedding_loss
def _get_attention_loss(self, teacher, student, bert_config, student_config,
                        weights):
    teacher_attention_scores = teacher.get_attention_scores()
    teacher_attention_scores = [tf.stop_gradient(value)
                                for value in teacher_attention_scores]
    student_attention_scores = student.get_attention_scores()

    num_teacher_hidden_layers = bert_config.num_hidden_layers
    num_student_hidden_layers = student_config.num_hidden_layers
    num_projections = int(
        num_teacher_hidden_layers / num_student_hidden_layers)

    attention_losses = []
    for i in range(num_student_hidden_layers):
        attention_losses.append(tf.losses.mean_squared_error(
            teacher_attention_scores[
                num_projections * i + num_projections - 1],
            student_attention_scores[i],
            weights=tf.reshape(weights, [-1, 1, 1, 1])))
    attention_loss = tf.add_n(attention_losses)
    return attention_loss
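# The uniform layer mapping used by _get_attention_loss (the same mapping
# appears in _get_hidden_loss below): student layer i is matched against
# teacher layer num_projections * i + num_projections - 1, i.e. the last
# teacher layer of each block of `num_projections` layers. A quick check for
# a 12-layer teacher and a 4-layer student:
num_teacher_hidden_layers = 12
num_student_hidden_layers = 4
num_projections = num_teacher_hidden_layers // num_student_hidden_layers  # 3

mapping = [num_projections * i + num_projections - 1
           for i in range(num_student_hidden_layers)]
# mapping == [2, 5, 8, 11], i.e. teacher layers 3, 6, 9 and 12 (1-based)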
def _get_fake_data(self, inputs, mlm_logits):
    '''Sample from the generator to create corrupted input.'''
    inputs = unmask(inputs)
    disallow = tf.one_hot(
        inputs.masked_lm_ids,
        depth=self.bert_config.vocab_size,
        dtype=tf.float32) if self.config.disallow_correct else None
    sampled_tokens = tf.stop_gradient(sample_from_softmax(
        mlm_logits / self.config.temperature, disallow=disallow))
    sampled_tokids = tf.argmax(sampled_tokens, -1, output_type=tf.int32)
    updated_input_ids, masked = scatter_update(
        inputs.input_ids, sampled_tokids, inputs.masked_lm_positions)
    labels = masked * (1 - tf.cast(
        tf.equal(updated_input_ids, inputs.input_ids), tf.int32))
    updated_inputs = get_updated_inputs(
        inputs, input_ids=updated_input_ids)
    FakedData = collections.namedtuple('FakedData', [
        'inputs', 'is_fake_tokens', 'sampled_tokens'])
    return FakedData(inputs=updated_inputs,
                     is_fake_tokens=labels,
                     sampled_tokens=sampled_tokens)
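# `sample_from_softmax`, `scatter_update`, `unmask` and `get_updated_inputs`
# are helpers defined elsewhere in the repository and are not shown here. As
# a hedged sketch only: sampling generator tokens can be done with the
# Gumbel-max trick, with an optional `disallow` mask that pushes the original
# token ids to a very low logit. The actual helper may differ in details.
import tensorflow as tf

def sample_from_softmax_sketch(logits, disallow=None):
    # Optionally forbid certain ids (e.g. the original, correct tokens).
    if disallow is not None:
        logits -= 1000.0 * disallow
    # Gumbel-max sampling: argmax over logits plus Gumbel noise is
    # equivalent to sampling from the softmax distribution.
    uniform = tf.random.uniform(tf.shape(logits), minval=1e-9, maxval=1.0)
    gumbel = -tf.math.log(-tf.math.log(uniform))
    sampled_ids = tf.argmax(logits + gumbel, axis=-1, output_type=tf.int32)
    return tf.one_hot(sampled_ids, depth=tf.shape(logits)[-1])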
def _get_hidden_loss(self, teacher, student, bert_config, student_config,
                     weights):
    teacher_hidden_layers = teacher.all_encoder_layers
    teacher_hidden_layers = [tf.stop_gradient(value)
                             for value in teacher_hidden_layers]
    student_hidden_layers = student.all_encoder_layers

    num_teacher_hidden_layers = bert_config.num_hidden_layers
    num_student_hidden_layers = student_config.num_hidden_layers
    num_projections = int(
        num_teacher_hidden_layers / num_student_hidden_layers)

    with tf.variable_scope('hidden_loss'):
        hidden_losses = []
        for i in range(num_student_hidden_layers):
            hidden_losses.append(tf.losses.mean_squared_error(
                teacher_hidden_layers[
                    num_projections * i + num_projections - 1],
                tf.layers.dense(
                    student_hidden_layers[i], bert_config.hidden_size,
                    kernel_initializer=util.create_initializer(
                        bert_config.initializer_range)),
                weights=tf.reshape(weights, [-1, 1, 1])))
        hidden_loss = tf.add_n(hidden_losses)
    return hidden_loss
def __init__(self,
             is_training,
             input_tensor,
             is_supervised,
             is_expanded,
             label_ids,
             label_size=2,
             sample_weight=None,
             scope='cls/seq_relationship',
             hidden_dropout_prob=0.1,
             initializer_range=0.02,
             trainable=True,
             global_step=None,
             num_train_steps=None,
             uda_softmax_temp=-1,
             uda_confidence_thresh=-1,
             tsa_schedule='linear',
             **kwargs):
    super().__init__(**kwargs)

    is_supervised = tf.cast(is_supervised, tf.float32)
    is_expanded = tf.cast(is_expanded, tf.float32)

    hidden_size = input_tensor.shape.as_list()[-1]
    with tf.variable_scope(scope):
        output_weights = tf.get_variable(
            'output_weights',
            shape=[label_size, hidden_size],
            initializer=util.create_initializer(initializer_range),
            trainable=trainable)
        output_bias = tf.get_variable(
            'output_bias',
            shape=[label_size],
            initializer=tf.zeros_initializer(),
            trainable=trainable)

        output_layer = util.dropout(
            input_tensor, hidden_dropout_prob if is_training else 0.0)
        logits = tf.matmul(output_layer, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)
        log_probs = tf.nn.log_softmax(logits, axis=-1)

        with tf.variable_scope('sup_loss'):

            # reshape
            sup_ori_log_probs = tf.boolean_mask(
                log_probs, mask=(1.0 - is_expanded), axis=0)
            sup_log_probs = tf.boolean_mask(
                sup_ori_log_probs, mask=is_supervised, axis=0)
            sup_label_ids = tf.boolean_mask(
                label_ids, mask=is_supervised, axis=0)

            self.preds['preds'] = tf.argmax(sup_ori_log_probs, axis=-1)

            one_hot_labels = tf.one_hot(
                sup_label_ids, depth=label_size, dtype=tf.float32)
            per_example_loss = -tf.reduce_sum(
                one_hot_labels * sup_log_probs, axis=-1)

            loss_mask = tf.ones_like(per_example_loss, dtype=tf.float32)
            correct_label_probs = tf.reduce_sum(
                one_hot_labels * tf.exp(sup_log_probs), axis=-1)

            if is_training and tsa_schedule:
                tsa_start = 1.0 / label_size
                tsa_threshold = get_tsa_threshold(
                    tsa_schedule, global_step, num_train_steps,
                    tsa_start, end=1)

                larger_than_threshold = tf.greater(
                    correct_label_probs, tsa_threshold)
                loss_mask = loss_mask * (
                    1 - tf.cast(larger_than_threshold, tf.float32))

            loss_mask = tf.stop_gradient(loss_mask)
            per_example_loss = per_example_loss * loss_mask
            if sample_weight is not None:
                sup_sample_weight = tf.boolean_mask(
                    sample_weight, mask=is_supervised, axis=0)
                per_example_loss *= tf.cast(
                    sup_sample_weight, dtype=tf.float32)
            sup_loss = (tf.reduce_sum(per_example_loss) /
                        tf.maximum(tf.reduce_sum(loss_mask), 1))
            self.losses['supervised'] = per_example_loss

        with tf.variable_scope('unsup_loss'):

            # reshape
            ori_log_probs = tf.boolean_mask(
                sup_ori_log_probs, mask=(1.0 - is_supervised), axis=0)
            aug_log_probs = tf.boolean_mask(
                log_probs, mask=is_expanded, axis=0)
            sup_ori_logits = tf.boolean_mask(
                logits, mask=(1.0 - is_expanded), axis=0)
            ori_logits = tf.boolean_mask(
                sup_ori_logits, mask=(1.0 - is_supervised), axis=0)

            unsup_loss_mask = 1
            if uda_softmax_temp != -1:
                tgt_ori_log_probs = tf.nn.log_softmax(
                    ori_logits / uda_softmax_temp, axis=-1)
                tgt_ori_log_probs = tf.stop_gradient(tgt_ori_log_probs)
            else:
                tgt_ori_log_probs = tf.stop_gradient(ori_log_probs)

            if uda_confidence_thresh != -1:
                largest_prob = tf.reduce_max(
                    tf.exp(ori_log_probs), axis=-1)
                unsup_loss_mask = tf.cast(
                    tf.greater(largest_prob, uda_confidence_thresh),
                    tf.float32)
                unsup_loss_mask = tf.stop_gradient(unsup_loss_mask)

            per_example_loss = kl_for_log_probs(
                tgt_ori_log_probs, aug_log_probs) * unsup_loss_mask
            if sample_weight is not None:
                unsup_sample_weight = tf.boolean_mask(
                    sample_weight, mask=(1.0 - is_supervised), axis=0)
                per_example_loss *= tf.cast(
                    unsup_sample_weight, dtype=tf.float32)
            unsup_loss = tf.reduce_mean(per_example_loss)
            self.losses['unsupervised'] = per_example_loss

    self.total_loss = sup_loss + unsup_loss
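# `get_tsa_threshold` and `kl_for_log_probs` are referenced above but defined
# elsewhere. For orientation, a hedged sketch of a training-signal-annealing
# (TSA) threshold consistent with the UDA paper: it ramps from `start`
# (1 / label_size) to `end` (1.0) over training, following a linear,
# exponential or logarithmic schedule. The name and schedule keys below are
# assumptions, not the repository's exact implementation.
import tensorflow as tf

def get_tsa_threshold_sketch(schedule, global_step, num_train_steps,
                             start, end):
    step_ratio = (tf.cast(global_step, tf.float32) /
                  tf.cast(num_train_steps, tf.float32))
    if schedule == 'linear':
        coeff = step_ratio
    elif schedule == 'exp':
        coeff = tf.exp((step_ratio - 1) * 5)
    elif schedule == 'log':
        coeff = 1 - tf.exp(-step_ratio * 5)
    else:
        raise ValueError('unsupported TSA schedule: %s' % schedule)
    return coeff * (end - start) + start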
def __init__(self,
             bert_config,
             is_training,
             input_ids,
             input_mask=None,
             token_type_ids=None,
             use_one_hot_embeddings=True,
             scope=None,
             embedding_size=None,
             input_embeddings=None,
             input_reprs=None,
             update_embeddings=True,
             untied_embeddings=False):
    '''Constructor for BertModel.

    Args:
      bert_config: `BertConfig` instance.
      is_training: bool. true for training model, false for eval model.
        Controls whether dropout will be applied.
      input_ids: int32 Tensor of shape [batch_size, seq_length].
      input_mask: (optional) int32 Tensor of shape [batch_size, seq_length].
      token_type_ids: (optional) int32 Tensor of shape
        [batch_size, seq_length].
      use_one_hot_embeddings: (optional) bool. Whether to use one-hot word
        embeddings or tf.embedding_lookup() for the word embeddings. On the
        TPU, it is much faster if this is True, on the CPU or GPU, it is
        faster if this is False.
      scope: (optional) variable scope. Defaults to 'electra'.

    Raises:
      ValueError: The config is invalid or one of the input tensor shapes
        is invalid.
    '''
    bert_config = copy.deepcopy(bert_config)
    if not is_training:
        bert_config.hidden_dropout_prob = 0.0
        bert_config.attention_probs_dropout_prob = 0.0

    input_shape = util.get_shape_list(token_type_ids, expected_rank=2)
    batch_size = input_shape[0]
    seq_length = input_shape[1]

    if input_mask is None:
        input_mask = tf.ones(
            shape=[batch_size, seq_length], dtype=tf.int32)

    assert token_type_ids is not None

    if input_reprs is None:
        with tf.variable_scope(
                ((scope if untied_embeddings else 'electra') +
                 '/embeddings'),
                reuse=tf.AUTO_REUSE):

            # Perform embedding lookup on the word ids
            if embedding_size is None:
                embedding_size = bert_config.hidden_size
            (token_embeddings, self.embedding_table) = embedding_lookup(
                input_ids=input_ids,
                vocab_size=bert_config.vocab_size,
                embedding_size=embedding_size,
                initializer_range=bert_config.initializer_range,
                word_embedding_name='word_embeddings',
                use_one_hot_embeddings=use_one_hot_embeddings)

        with tf.variable_scope(
                ((scope if untied_embeddings else 'electra') +
                 '/embeddings'),
                reuse=tf.AUTO_REUSE):

            # Add positional embeddings and token type embeddings, then
            # layer normalize and perform dropout.
            self.embedding_output = embedding_postprocessor(
                input_tensor=token_embeddings,
                use_token_type=True,
                token_type_ids=token_type_ids,
                token_type_vocab_size=bert_config.type_vocab_size,
                token_type_embedding_name='token_type_embeddings',
                use_position_embeddings=True,
                position_embedding_name='position_embeddings',
                initializer_range=bert_config.initializer_range,
                max_position_embeddings=
                    bert_config.max_position_embeddings,
                dropout_prob=bert_config.hidden_dropout_prob)
    else:
        self.embedding_output = input_reprs
    if not update_embeddings:
        self.embedding_output = tf.stop_gradient(self.embedding_output)

    with tf.variable_scope(scope, default_name='electra'):
        if self.embedding_output.shape[-1] != bert_config.hidden_size:
            self.embedding_output = tf.layers.dense(
                self.embedding_output, bert_config.hidden_size,
                name='embeddings_project')

        with tf.variable_scope('encoder'):

            # This converts a 2D mask of shape [batch_size, seq_length]
            # to a 3D mask of shape [batch_size, seq_length, seq_length]
            # which is used for the attention scores.
            attention_mask = create_attention_mask_from_input_mask(
                token_type_ids, input_mask)

            # Run the stacked transformer. Output shapes:
            #   attn_maps:
            #     [n_layers, batch_size, n_heads, seq_length, seq_length]
            (self.all_layer_outputs, self.attn_maps) = transformer_model(
                input_tensor=self.embedding_output,
                attention_mask=attention_mask,
                hidden_size=bert_config.hidden_size,
                num_hidden_layers=bert_config.num_hidden_layers,
                num_attention_heads=bert_config.num_attention_heads,
                intermediate_size=bert_config.intermediate_size,
                intermediate_act_fn=util.get_activation(
                    bert_config.hidden_act),
                hidden_dropout_prob=bert_config.hidden_dropout_prob,
                attention_probs_dropout_prob=
                    bert_config.attention_probs_dropout_prob,
                initializer_range=bert_config.initializer_range,
                do_return_all_layers=True)
            self.sequence_output = self.all_layer_outputs[-1]
            self.pooled_output = self.sequence_output[:, 0]
def _get_hidden_emd(self, teacher, student, teacher_weight, student_weight,
                    bert_config, sample_weight, emd_temperature):
    teacher_hidden_layers = teacher.all_encoder_layers
    teacher_hidden_layers = [
        tf.stop_gradient(value) for value in teacher_hidden_layers]
    student_hidden_layers = student.all_encoder_layers
    M = len(teacher_hidden_layers)
    N = len(student_hidden_layers)

    with tf.variable_scope('hidden_emd'):
        flow = tf.get_variable(
            'flow', shape=[M, N],
            initializer=tf.constant_initializer(1 / M / N),
            trainable=False)

        # MSE between every (teacher layer, projected student layer) pair
        rows = []
        for m in range(M):
            cols = []
            for n in range(N):
                linear_trans = tf.layers.dense(
                    student_hidden_layers[n], bert_config.hidden_size,
                    kernel_initializer=util.create_initializer(
                        bert_config.initializer_range))
                mse = tf.losses.mean_squared_error(
                    teacher_hidden_layers[m], linear_trans,
                    weights=tf.reshape(sample_weight, [-1, 1, 1]))
                col = tf.reshape(mse, [1, 1])
                cols.append(col)
            row = tf.concat(cols, axis=1)
            rows.append(row)
        distance = tf.concat(rows, axis=0)

        # cost attention mechanism
        teacher_cost = (tf.reduce_sum(flow, axis=1) *
                        tf.reduce_sum(distance, axis=1) /
                        (teacher_weight + 1e-6))
        student_cost = (tf.reduce_sum(flow, axis=0) *
                        tf.reduce_sum(distance, axis=0) /
                        (student_weight + 1e-6))

        # new weights
        new_teacher_weight = tf.where(
            teacher_cost > 1e-12,
            tf.reduce_sum(teacher_cost) / (teacher_cost + 1e-6),
            teacher_weight)
        new_student_weight = tf.where(
            student_cost > 1e-12,
            tf.reduce_sum(student_cost) / (student_cost + 1e-6),
            student_weight)
        new_teacher_weight = tf.nn.softmax(
            new_teacher_weight / emd_temperature)
        new_student_weight = tf.nn.softmax(
            new_student_weight / emd_temperature)

    self.hidden_flow = flow
    self.hidden_distance = distance
    hidden_emd = tf.reduce_sum(flow * distance)
    return hidden_emd, new_teacher_weight, new_student_weight
def _get_attention_emd(self, teacher, student, teacher_weight,
                       student_weight, sample_weight, emd_temperature):
    teacher_attention_scores = teacher.get_attention_scores()
    teacher_attention_scores = [
        tf.stop_gradient(value) for value in teacher_attention_scores]
    student_attention_scores = student.get_attention_scores()
    M = len(teacher_attention_scores)
    N = len(student_attention_scores)

    with tf.variable_scope('attention_emd'):
        flow = tf.get_variable(
            'flow', shape=[M, N],
            initializer=tf.constant_initializer(1 / M / N),
            trainable=False)

        # MSE between every (teacher, student) attention-score pair, with
        # large negative (masked) scores zeroed out first
        rows = []
        for m in range(M):
            cols = []
            for n in range(N):
                teacher_matrix = tf.where(
                    teacher_attention_scores[m] < -1e2,
                    tf.zeros_like(teacher_attention_scores[m]),
                    teacher_attention_scores[m])
                student_matrix = tf.where(
                    student_attention_scores[n] < -1e2,
                    tf.zeros_like(student_attention_scores[n]),
                    student_attention_scores[n])
                mse = tf.losses.mean_squared_error(
                    teacher_matrix, student_matrix,
                    weights=tf.reshape(sample_weight, [-1, 1, 1, 1]))
                col = tf.reshape(mse, [1, 1])
                cols.append(col)
            row = tf.concat(cols, axis=1)
            rows.append(row)
        distance = tf.concat(rows, axis=0)

        # cost attention mechanism
        teacher_cost = (tf.reduce_sum(flow, axis=1) *
                        tf.reduce_sum(distance, axis=1) /
                        (teacher_weight + 1e-6))
        student_cost = (tf.reduce_sum(flow, axis=0) *
                        tf.reduce_sum(distance, axis=0) /
                        (student_weight + 1e-6))

        # new weights
        new_teacher_weight = tf.where(
            teacher_cost > 1e-12,
            tf.reduce_sum(teacher_cost) / (teacher_cost + 1e-6),
            teacher_weight)
        new_student_weight = tf.where(
            student_cost > 1e-12,
            tf.reduce_sum(student_cost) / (student_cost + 1e-6),
            student_weight)
        new_teacher_weight = tf.nn.softmax(
            new_teacher_weight / emd_temperature)
        new_student_weight = tf.nn.softmax(
            new_student_weight / emd_temperature)

    self.attention_flow = flow
    self.attention_distance = distance
    attention_emd = tf.reduce_sum(flow * distance)
    return attention_emd, new_teacher_weight, new_student_weight
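# In both EMD losses above, the final transfer cost is the element-wise
# product of the (non-trainable) flow matrix and the M x N layer-distance
# matrix, i.e. emd = sum over (m, n) of flow[m, n] * distance[m, n]. The
# per-layer teacher and student weights are then re-estimated from the
# accumulated costs and softened with `emd_temperature` before the next
# update, in the spirit of BERT-EMD-style many-to-many layer mapping.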