def _get_attention_loss(self, teacher, student, bert_config,
                        student_config, weights):
    # MSE between teacher and student attention scores; every
    # `num_projections`-th teacher layer is matched to one student layer.
    teacher_attention_scores = teacher.get_attention_scores()
    teacher_attention_scores = [
        tf.stop_gradient(value) for value in teacher_attention_scores]
    student_attention_scores = student.get_attention_scores()

    num_teacher_hidden_layers = bert_config.num_hidden_layers
    num_student_hidden_layers = student_config.num_hidden_layers
    num_projections = \
        int(num_teacher_hidden_layers / num_student_hidden_layers)

    attention_losses = []
    for i in range(num_student_hidden_layers):
        attention_losses.append(tf.losses.mean_squared_error(
            teacher_attention_scores[
                num_projections * i + num_projections - 1],
            student_attention_scores[i],
            weights=tf.reshape(weights, [-1, 1, 1, 1])))
    attention_loss = tf.add_n(attention_losses)
    return attention_loss
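# NOTE: a minimal standalone sketch (not part of the class, and purely
# illustrative) of the layer mapping used above, assuming a 12-layer
# teacher and a 4-layer student: student layer i is matched to teacher
# layer `num_projections * i + num_projections - 1`, i.e. the last layer
# of each consecutive group of `num_projections` teacher layers.
def _demo_attention_layer_mapping(num_teacher_hidden_layers=12,
                                  num_student_hidden_layers=4):
    num_projections = int(
        num_teacher_hidden_layers / num_student_hidden_layers)
    mapping = {
        i: num_projections * i + num_projections - 1
        for i in range(num_student_hidden_layers)}
    return mapping  # {0: 2, 1: 5, 2: 8, 3: 11}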
def _get_hidden_loss(self, teacher, student, bert_config,
                     student_config, weights):
    teacher_hidden_layers = teacher.all_encoder_layers
    teacher_hidden_layers = [
        tf.stop_gradient(value) for value in teacher_hidden_layers]
    student_hidden_layers = student.all_encoder_layers

    num_teacher_hidden_layers = bert_config.num_hidden_layers
    num_student_hidden_layers = student_config.num_hidden_layers
    num_projections = int(
        num_teacher_hidden_layers / num_student_hidden_layers)

    with tf.variable_scope('hidden_loss'):
        hidden_losses = []
        for i in range(num_student_hidden_layers):
            hidden_losses.append(tf.losses.mean_squared_error(
                teacher_hidden_layers[
                    num_projections * i + num_projections - 1],
                tf.layers.dense(
                    student_hidden_layers[i], bert_config.hidden_size,
                    kernel_initializer=util.create_initializer(
                        bert_config.initializer_range)),
                weights=tf.reshape(weights, [-1, 1, 1])))
        hidden_loss = tf.add_n(hidden_losses)
    return hidden_loss
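# NOTE: a rough NumPy sketch of the weighted hidden-state matching above,
# with made-up shapes (batch=2, seq=4, student width=312, teacher width=768)
# and a random matrix standing in for the learned `tf.layers.dense`
# projection. `tf.losses.mean_squared_error` additionally normalizes by the
# number of non-zero weights; this sketch just takes a plain mean.
import numpy as np

def _demo_hidden_mse(batch_size=2, seq_length=4,
                     student_width=312, teacher_width=768):
    rng = np.random.RandomState(0)
    student_hidden = rng.randn(batch_size, seq_length, student_width)
    teacher_hidden = rng.randn(batch_size, seq_length, teacher_width)
    projection = rng.randn(student_width, teacher_width) * 0.02
    weights = np.array([1.0, 0.5])  # per-example sample weights

    projected = student_hidden @ projection   # [batch, seq, teacher_width]
    squared_error = (projected - teacher_hidden) ** 2
    # broadcast the per-example weights over the sequence and hidden
    # dimensions, mirroring `tf.reshape(weights, [-1, 1, 1])`
    weighted = squared_error * weights.reshape(-1, 1, 1)
    return weighted.mean()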
def __init__(self,
             bert_config,
             is_training,
             encoder,
             masked_lm_positions,
             masked_lm_ids,
             masked_lm_weights,
             next_sentence_labels,
             sample_weight=None,
             scope_lm='cls/predictions',
             scope_cls='cls/seq_relationship',
             trainable=True,
             use_nsp_loss=True,
             **kwargs):
    super(BERTDecoder, self).__init__(**kwargs)

    def gather_indexes(sequence_tensor, positions):
        # Gather the hidden states at `positions` from a
        # [batch_size, seq_length, width] tensor.
        sequence_shape = util.get_shape_list(sequence_tensor, 3)
        batch_size = sequence_shape[0]
        seq_length = sequence_shape[1]
        width = sequence_shape[2]

        flat_offsets = tf.reshape(
            tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
        flat_positions = tf.reshape(positions + flat_offsets, [-1])
        flat_sequence_tensor = tf.reshape(
            sequence_tensor, [batch_size * seq_length, width])
        output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
        return output_tensor

    scalar_losses = []

    # masked language modeling
    input_tensor = gather_indexes(
        encoder.get_sequence_output(), masked_lm_positions)
    with tf.variable_scope(scope_lm):
        with tf.variable_scope('transform'):
            input_tensor = tf.layers.dense(
                input_tensor,
                units=bert_config.hidden_size,
                activation=util.get_activation(bert_config.hidden_act),
                kernel_initializer=util.create_initializer(
                    bert_config.initializer_range))
            input_tensor = util.layer_norm(input_tensor)

        output_bias = tf.get_variable(
            'output_bias',
            shape=[bert_config.vocab_size],
            initializer=tf.zeros_initializer(),
            trainable=trainable)
        logits = tf.matmul(
            input_tensor, encoder.get_embedding_table(), transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)
        probs = tf.nn.softmax(logits, axis=-1, name='MLM_probs')
        log_probs = tf.nn.log_softmax(logits, axis=-1)

        label_ids = tf.reshape(masked_lm_ids, [-1])
        if sample_weight is not None:
            sample_weight = tf.expand_dims(
                tf.cast(sample_weight, dtype=tf.float32), axis=-1)
            masked_lm_weights *= sample_weight
        label_weights = tf.reshape(masked_lm_weights, [-1])
        one_hot_labels = tf.one_hot(
            label_ids, depth=bert_config.vocab_size, dtype=tf.float32)
        per_example_loss = -tf.reduce_sum(
            log_probs * one_hot_labels, axis=[-1])
        per_example_loss = label_weights * per_example_loss

        numerator = tf.reduce_sum(per_example_loss)
        denominator = tf.reduce_sum(label_weights) + 1e-5
        loss = numerator / denominator

        scalar_losses.append(loss)
        self.losses['MLM_losses'] = per_example_loss
        self.preds['MLM_preds'] = tf.argmax(probs, axis=-1)

    # next sentence prediction
    with tf.variable_scope(scope_cls):
        output_weights = tf.get_variable(
            'output_weights',
            shape=[2, bert_config.hidden_size],
            initializer=util.create_initializer(
                bert_config.initializer_range),
            trainable=trainable)
        output_bias = tf.get_variable(
            'output_bias',
            shape=[2],
            initializer=tf.zeros_initializer(),
            trainable=trainable)

        logits = tf.matmul(
            encoder.get_pooled_output(), output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)
        probs = tf.nn.softmax(logits, axis=-1, name='probs')
        log_probs = tf.nn.log_softmax(logits, axis=-1)

        labels = tf.reshape(next_sentence_labels, [-1])
        one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32)
        per_example_loss = -tf.reduce_sum(
            one_hot_labels * log_probs, axis=-1)
        if sample_weight is not None:
            per_example_loss = (
                tf.cast(sample_weight, dtype=tf.float32) * per_example_loss)
        loss = tf.reduce_mean(per_example_loss)

        if use_nsp_loss:
            scalar_losses.append(loss)
        self.losses['NSP_losses'] = per_example_loss
        self.probs['NSP_probs'] = probs
        self.preds['NSP_preds'] = tf.argmax(probs, axis=-1)

    self.total_loss = tf.add_n(scalar_losses)
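# NOTE: a small NumPy illustration (made-up numbers, not part of the class)
# of the `gather_indexes` trick above: per-example masked positions are
# turned into indices into the flattened [batch_size * seq_length, width]
# tensor by adding an offset of `example_index * seq_length`.
import numpy as np

def _demo_gather_indexes():
    batch_size, seq_length, width = 2, 5, 3
    sequence_tensor = np.arange(
        batch_size * seq_length * width).reshape(
            batch_size, seq_length, width)
    positions = np.array([[1, 3],    # masked positions in example 0
                          [0, 4]])   # masked positions in example 1

    flat_offsets = (np.arange(batch_size) * seq_length).reshape(-1, 1)
    flat_positions = (positions + flat_offsets).reshape(-1)  # [1, 3, 5, 9]
    flat_sequence = sequence_tensor.reshape(batch_size * seq_length, width)
    return flat_sequence[flat_positions]     # shape [4, width]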
def __init__(self,
             bert_config,
             is_training,
             input_ids,
             input_mask,
             segment_ids,
             sample_weight=None,
             scope='bert',
             dtype=tf.float32,
             drop_pooler=False,
             cls_model='self-attention',
             label_size=2,
             speed=0.1,
             ignore_cls='0',
             **kwargs):
    super(FastBERTCLSDistillor, self).__init__()

    if not ignore_cls:
        ignore_cls = []
    if isinstance(ignore_cls, str):
        ignore_cls = ignore_cls.replace(' ', '').split(',')
        ignore_cls = list(map(int, ignore_cls))
    elif isinstance(ignore_cls, list):
        ignore_cls = list(map(int, ignore_cls))
    else:
        raise ValueError(
            '`ignore_cls` should be a list of child-classifier ids or '
            'a string separated with commas.')

    if not speed:
        raise ValueError(
            '`speed` should be a float number between `0` and `1`.')

    bert_config = copy.deepcopy(bert_config)
    bert_config.hidden_dropout_prob = 0.0
    bert_config.attention_probs_dropout_prob = 0.0

    input_shape = util.get_shape_list(input_ids, expected_rank=2)
    batch_size = input_shape[0]
    max_seq_length = input_shape[1]

    with tf.variable_scope(scope):
        with tf.variable_scope('embeddings'):

            (self.embedding_output, self.embedding_table) = \
                self.embedding_lookup(
                    input_ids=input_ids,
                    vocab_size=bert_config.vocab_size,
                    batch_size=batch_size,
                    max_seq_length=max_seq_length,
                    embedding_size=bert_config.hidden_size,
                    initializer_range=bert_config.initializer_range,
                    word_embedding_name='word_embeddings',
                    dtype=dtype,
                    trainable=False,
                    tilda_embeddings=None)

            # Add positional embeddings and token type embeddings,
            # then layer-normalize and apply dropout.
            self.embedding_output = self.embedding_postprocessor(
                input_tensor=self.embedding_output,
                batch_size=batch_size,
                max_seq_length=max_seq_length,
                hidden_size=bert_config.hidden_size,
                use_token_type=True,
                segment_ids=segment_ids,
                token_type_vocab_size=bert_config.type_vocab_size,
                token_type_embedding_name='token_type_embeddings',
                use_position_embeddings=True,
                position_embedding_name='position_embeddings',
                initializer_range=bert_config.initializer_range,
                max_position_embeddings=\
                    bert_config.max_position_embeddings,
                dropout_prob=bert_config.hidden_dropout_prob,
                dtype=dtype,
                trainable=False)

        with tf.variable_scope('encoder'):
            attention_mask = self.create_attention_mask_from_input_mask(
                input_mask, batch_size, max_seq_length, dtype=dtype)

            # stacked transformers
            (self.all_encoder_layers, self.all_cls_layers) = \
                self.dynamic_transformer_model(
                    is_training,
                    input_tensor=self.embedding_output,
                    input_mask=input_mask,
                    batch_size=batch_size,
                    max_seq_length=max_seq_length,
                    label_size=label_size,
                    attention_mask=attention_mask,
                    hidden_size=bert_config.hidden_size,
                    num_hidden_layers=bert_config.num_hidden_layers,
                    num_attention_heads=bert_config.num_attention_heads,
                    intermediate_size=bert_config.intermediate_size,
                    intermediate_act_fn=util.get_activation(
                        bert_config.hidden_act),
                    hidden_dropout_prob=bert_config.hidden_dropout_prob,
                    attention_probs_dropout_prob=\
                        bert_config.attention_probs_dropout_prob,
                    initializer_range=bert_config.initializer_range,
                    dtype=dtype,
                    cls_model=cls_model,
                    speed=speed,
                    ignore_cls=ignore_cls)
            self.sequence_output = self.all_encoder_layers[-1]

        with tf.variable_scope('pooler'):
            first_token_tensor = self.sequence_output[:, 0, :]

            # trick: ignore the fully connected layer
            if drop_pooler:
                self.pooled_output = first_token_tensor
            else:
                self.pooled_output = tf.layers.dense(
                    first_token_tensor,
                    bert_config.hidden_size,
                    activation=tf.tanh,
                    kernel_initializer=util.create_initializer(
                        bert_config.initializer_range),
                    trainable=False)

    # teacher classifier
    if bert_config.num_hidden_layers not in ignore_cls:
        with tf.variable_scope('cls/seq_relationship'):
            output_weights = tf.get_variable(
                'output_weights',
                shape=[label_size, bert_config.hidden_size],
                initializer=util.create_initializer(
                    bert_config.initializer_range),
                trainable=False)
            output_bias = tf.get_variable(
                'output_bias',
                shape=[label_size],
                initializer=tf.zeros_initializer(),
                trainable=False)

            logits = tf.matmul(
                self.pooled_output, output_weights, transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)
            probs = tf.nn.softmax(logits, axis=-1)

    # distillation
    if is_training:
        losses = []
        for cls_probs in self.all_cls_layers.values():

            # KL-Divergence between each child classifier and the
            # teacher classifier
            per_example_loss = tf.reduce_sum(
                cls_probs * (tf.log(cls_probs) - tf.log(probs)), axis=-1)
            if sample_weight is not None:
                per_example_loss *= tf.cast(
                    sample_weight, dtype=tf.float32)
            loss = tf.reduce_mean(per_example_loss)
            losses.append(loss)
        distill_loss = tf.add_n(losses)
        self.total_loss = distill_loss
        self.losses['losses'] = distill_loss

    else:
        if bert_config.num_hidden_layers not in ignore_cls:
            self.all_cls_layers[bert_config.num_hidden_layers] = probs
        self.probs['probs'] = tf.concat(
            list(self.all_cls_layers.values()), axis=0, name='probs')
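# NOTE: a standalone NumPy sketch (made-up probabilities) of the
# per-example KL divergence KL(student || teacher) used above for each
# child classifier: sum_k p_student_k * (log p_student_k - log p_teacher_k).
import numpy as np

def _demo_kl_divergence():
    student_probs = np.array([[0.7, 0.3],
                              [0.2, 0.8]])   # a child classifier's output
    teacher_probs = np.array([[0.9, 0.1],
                              [0.4, 0.6]])   # final (teacher) classifier
    per_example_kl = np.sum(
        student_probs * (np.log(student_probs) - np.log(teacher_probs)),
        axis=-1)
    return per_example_kl.mean()   # scalar distillation loss per classifier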
def __init__(self,
             tiny_bert_config,
             bert_config,
             is_training,
             input_ids,
             input_mask,
             segment_ids,
             label_ids=None,
             sample_weight=None,
             scope='bert',
             name='',
             dtype=tf.float32,
             use_tilda_embedding=False,
             drop_pooler=False,
             label_size=2,
             trainable=True,
             **kwargs):
    super(TinyBERTCLSDistillor, self).__init__()

    def _get_logits(pooled_output, hidden_size, scope, trainable):
        with tf.variable_scope(scope):
            output_weights = tf.get_variable(
                'output_weights',
                shape=[label_size, hidden_size],
                initializer=util.create_initializer(
                    bert_config.initializer_range),
                trainable=trainable)
            output_bias = tf.get_variable(
                'output_bias',
                shape=[label_size],
                initializer=tf.zeros_initializer(),
                trainable=trainable)

            logits = tf.matmul(
                pooled_output, output_weights, transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)
            return logits

    student = BERTEncoder(
        bert_config=tiny_bert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        scope='tiny/bert',
        use_tilda_embedding=use_tilda_embedding,
        drop_pooler=drop_pooler,
        trainable=True,
        **kwargs)
    student_logits = _get_logits(
        student.get_pooled_output(),
        tiny_bert_config.hidden_size,
        'tiny/cls/seq_relationship',
        True)

    if is_training:
        teacher = BERTEncoder(
            bert_config=bert_config,
            is_training=False,
            input_ids=input_ids,
            input_mask=input_mask,
            segment_ids=segment_ids,
            scope=scope,
            use_tilda_embedding=False,
            drop_pooler=drop_pooler,
            trainable=False,
            **kwargs)
        teacher_logits = _get_logits(
            teacher.get_pooled_output(),
            bert_config.hidden_size,
            'cls/seq_relationship',
            False)

        weights = 1.0
        if sample_weight is not None:
            weights = tf.cast(sample_weight, dtype=tf.float32)

        # embedding loss
        teacher_embedding = teacher.get_embedding_output()
        student_embedding = student.get_embedding_output()
        with tf.variable_scope('embedding_loss'):
            linear_trans = tf.layers.dense(
                student_embedding,
                bert_config.hidden_size,
                kernel_initializer=util.create_initializer(
                    bert_config.initializer_range))
            embedding_loss = tf.losses.mean_squared_error(
                linear_trans, teacher_embedding,
                weights=tf.reshape(weights, [-1, 1, 1]))

        # attention loss
        teacher_attention_scores = teacher.get_attention_scores()
        student_attention_scores = student.get_attention_scores()
        num_teacher_hidden_layers = bert_config.num_hidden_layers
        num_student_hidden_layers = tiny_bert_config.num_hidden_layers
        num_projections = \
            int(num_teacher_hidden_layers / num_student_hidden_layers)
        attention_losses = []
        for i in range(num_student_hidden_layers):
            attention_losses.append(tf.losses.mean_squared_error(
                teacher_attention_scores[
                    num_projections * i + num_projections - 1],
                student_attention_scores[i],
                weights=tf.reshape(weights, [-1, 1, 1, 1])))
        attention_loss = tf.add_n(attention_losses)

        # hidden loss
        teacher_hidden_layers = teacher.all_encoder_layers
        student_hidden_layers = student.all_encoder_layers
        num_teacher_hidden_layers = bert_config.num_hidden_layers
        num_student_hidden_layers = tiny_bert_config.num_hidden_layers
        num_projections = int(
            num_teacher_hidden_layers / num_student_hidden_layers)
        with tf.variable_scope('hidden_loss'):
            hidden_losses = []
            for i in range(num_student_hidden_layers):
                hidden_losses.append(tf.losses.mean_squared_error(
                    teacher_hidden_layers[
                        num_projections * i + num_projections - 1],
                    tf.layers.dense(
                        student_hidden_layers[i],
                        bert_config.hidden_size,
                        kernel_initializer=util.create_initializer(
                            bert_config.initializer_range)),
                    weights=tf.reshape(weights, [-1, 1, 1])))
            hidden_loss = tf.add_n(hidden_losses)

        # prediction loss
        teacher_probs = tf.nn.softmax(teacher_logits, axis=-1)
        student_log_probs = tf.nn.log_softmax(student_logits, axis=-1)
        pred_loss = (
            -tf.reduce_sum(teacher_probs * student_log_probs, axis=-1) *
            tf.reshape(weights, [-1, 1]))
        pred_loss = tf.reduce_mean(pred_loss)

        # sum up
        distill_loss = (
            embedding_loss + attention_loss + hidden_loss + pred_loss)
        self.total_loss = distill_loss
        self.losses['distill'] = tf.reshape(distill_loss, [1])

    else:
        student_probs = tf.nn.softmax(
            student_logits, axis=-1, name='probs')
        self.probs[name] = student_probs
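# NOTE: a standalone NumPy sketch (made-up logits) of the prediction loss
# above: a soft cross-entropy between the teacher's softmax output and the
# student's log-softmax output, averaged over the batch.
import numpy as np

def _demo_prediction_loss():
    teacher_logits = np.array([[2.0, -1.0],
                               [0.5, 0.5]])
    student_logits = np.array([[1.0, 0.0],
                               [0.0, 1.0]])

    def softmax(x):
        e = np.exp(x - x.max(axis=-1, keepdims=True))
        return e / e.sum(axis=-1, keepdims=True)

    teacher_probs = softmax(teacher_logits)
    student_log_probs = np.log(softmax(student_logits))
    per_example = -np.sum(teacher_probs * student_log_probs, axis=-1)
    return per_example.mean()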