Пример #1
0
 def _get_attention_loss(self, teacher, student,
                         bert_config, student_config, weights):
     """Return the summed per-layer MSE between teacher and student attention.

     Student layer ``i`` is paired with teacher layer
     ``num_projections * (i + 1) - 1``, where ``num_projections`` is
     ``teacher_layers // student_layers``. That is the last teacher layer of
     each consecutive group of ``num_projections`` teacher layers.

     Args:
       teacher: object whose ``get_attention_scores()`` returns a sequence
         indexed by teacher layer. Each entry is wrapped in
         ``tf.stop_gradient`` so that no gradient reaches the teacher.
       student: object whose ``get_attention_scores()`` returns a sequence
         indexed by student layer.
       bert_config: teacher config; only ``num_hidden_layers`` is read.
       student_config: student config; only ``num_hidden_layers`` is read.
       weights: scalar or per-example weights. They are reshaped to
         ``[-1, 1, 1, 1]`` so they broadcast over each attention tensor
         (assumed rank 4, e.g. [batch, heads, seq, seq] -- TODO confirm).

     Returns:
       The sum of the per-layer weighted MSE losses.

     Raises:
       ValueError: if the student has no layers or more layers than the
         teacher. The pairing above would otherwise divide by zero, or
         silently map every student layer to teacher index -1.
     """
     teacher_attention_scores = [
         tf.stop_gradient(value) for value in teacher.get_attention_scores()]
     student_attention_scores = student.get_attention_scores()
     num_teacher_hidden_layers = bert_config.num_hidden_layers
     num_student_hidden_layers = student_config.num_hidden_layers
     if not 0 < num_student_hidden_layers <= num_teacher_hidden_layers:
         raise ValueError(
             'Student `num_hidden_layers` (%d) must be between 1 and the '
             'teacher `num_hidden_layers` (%d).'
             % (num_student_hidden_layers, num_teacher_hidden_layers))
     num_projections = num_teacher_hidden_layers // num_student_hidden_layers

     # The broadcastable weight tensor is loop-invariant; build it once.
     loss_weights = tf.reshape(weights, [-1, 1, 1, 1])
     attention_losses = []
     for i in range(num_student_hidden_layers):
         teacher_index = num_projections * (i + 1) - 1
         attention_losses.append(tf.losses.mean_squared_error(
             teacher_attention_scores[teacher_index],
             student_attention_scores[i],
             weights=loss_weights))
     return tf.add_n(attention_losses)
Пример #2
0
 def _get_hidden_loss(self, teacher, student,
                      bert_config, student_config, weights):
     """Return the summed per-layer MSE between teacher and student states.

     Each student encoder layer ``i`` is first projected to
     ``bert_config.hidden_size`` with a fresh dense layer created inside the
     ``'hidden_loss'`` variable scope. It is then compared with teacher layer
     ``num_projections * (i + 1) - 1``, where ``num_projections`` is
     ``teacher_layers // student_layers``.

     Args:
       teacher: object with an ``all_encoder_layers`` sequence indexed by
         teacher layer. Each entry is wrapped in ``tf.stop_gradient``.
       student: object with an ``all_encoder_layers`` sequence indexed by
         student layer.
       bert_config: teacher config; ``num_hidden_layers``, ``hidden_size``
         and ``initializer_range`` are read.
       student_config: student config; only ``num_hidden_layers`` is read.
       weights: scalar or per-example weights. They are reshaped to
         ``[-1, 1, 1]`` so they broadcast over each hidden-state tensor
         (assumed [batch, seq, hidden] -- TODO confirm).

     Returns:
       The sum of the per-layer weighted MSE losses.

     Raises:
       ValueError: if the student has no layers or more layers than the
         teacher.
     """
     teacher_hidden_layers = [
         tf.stop_gradient(value) for value in teacher.all_encoder_layers]
     student_hidden_layers = student.all_encoder_layers
     num_teacher_hidden_layers = bert_config.num_hidden_layers
     num_student_hidden_layers = student_config.num_hidden_layers
     if not 0 < num_student_hidden_layers <= num_teacher_hidden_layers:
         raise ValueError(
             'Student `num_hidden_layers` (%d) must be between 1 and the '
             'teacher `num_hidden_layers` (%d).'
             % (num_student_hidden_layers, num_teacher_hidden_layers))
     num_projections = num_teacher_hidden_layers // num_student_hidden_layers

     # The broadcastable weight tensor is loop-invariant; build it once.
     loss_weights = tf.reshape(weights, [-1, 1, 1])
     with tf.variable_scope('hidden_loss'):
         hidden_losses = []
         for i in range(num_student_hidden_layers):
             # Each call creates a new auto-named dense layer in this scope.
             # The creation order is what fixes the checkpoint variable
             # names, so keep one call per student layer, in order.
             projected = tf.layers.dense(
                 student_hidden_layers[i], bert_config.hidden_size,
                 kernel_initializer=util.create_initializer(
                     bert_config.initializer_range))
             hidden_losses.append(tf.losses.mean_squared_error(
                 teacher_hidden_layers[num_projections * (i + 1) - 1],
                 projected,
                 weights=loss_weights))
         hidden_loss = tf.add_n(hidden_losses)
     return hidden_loss
Пример #3
0
    def __init__(self,
                 bert_config,
                 is_training,
                 encoder,
                 masked_lm_positions,
                 masked_lm_ids,
                 masked_lm_weights,
                 next_sentence_labels,
                 sample_weight=None,
                 scope_lm='cls/predictions',
                 scope_cls='cls/seq_relationship',
                 trainable=True,
                 use_nsp_loss=True,
                 **kwargs):
        """Build the masked-LM and next-sentence-prediction heads and losses.

        Args:
          bert_config: config providing ``hidden_size``, ``hidden_act``,
            ``initializer_range`` and ``vocab_size``.
          is_training: accepted for interface compatibility; not read here.
          encoder: object exposing ``get_sequence_output()``,
            ``get_embedding_table()`` and ``get_pooled_output()``.
          masked_lm_positions: int32 positions of the masked tokens inside
            each example's sequence output.
          masked_lm_ids: target token ids for those positions.
          masked_lm_weights: weights for those positions (presumably 0 for
            padded predictions -- TODO confirm).
          next_sentence_labels: labels for the 2-way NSP classifier.
          sample_weight: optional per-example weight (presumably shape
            [batch] -- TODO confirm against callers).
          scope_lm: variable scope of the masked-LM head.
          scope_cls: variable scope of the NSP head.
          trainable: whether the head variables created here are trainable.
          use_nsp_loss: when False, the NSP head is still built and its
            outputs recorded, but its loss is left out of ``total_loss``.
          **kwargs: forwarded to the base-class constructor.
        """
        super(BERTDecoder, self).__init__(**kwargs)

        def gather_indexes(sequence_tensor, positions):
            """Gather vectors at `positions` from a rank-3 sequence tensor."""
            sequence_shape = util.get_shape_list(sequence_tensor, 3)
            batch_size = sequence_shape[0]
            seq_length = sequence_shape[1]
            width = sequence_shape[2]

            # Offset each example's positions into the flattened
            # [batch * seq, width] view.
            flat_offsets = tf.reshape(
                tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
            flat_positions = tf.reshape(positions + flat_offsets, [-1])
            flat_sequence_tensor = tf.reshape(sequence_tensor,
                                              [batch_size * seq_length, width])
            output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
            return output_tensor

        # Keep the per-example weights under their own name and in their
        # original shape. The MLM head needs an extra trailing axis, but the
        # NSP head must not see it: multiplying a [batch, 1] weight by a
        # [batch] loss would broadcast to [batch, batch].
        example_weight = None
        if sample_weight is not None:
            example_weight = tf.cast(sample_weight, dtype=tf.float32)

        scalar_losses = []

        # masked language modeling
        input_tensor = gather_indexes(encoder.get_sequence_output(),
                                      masked_lm_positions)
        with tf.variable_scope(scope_lm):
            with tf.variable_scope('transform'):
                input_tensor = tf.layers.dense(
                    input_tensor,
                    units=bert_config.hidden_size,
                    activation=util.get_activation(bert_config.hidden_act),
                    kernel_initializer=util.create_initializer(
                        bert_config.initializer_range))
                input_tensor = util.layer_norm(input_tensor)
            output_bias = tf.get_variable('output_bias',
                                          shape=[bert_config.vocab_size],
                                          initializer=tf.zeros_initializer(),
                                          trainable=trainable)

            # Output projection is tied to the encoder's embedding table.
            logits = tf.matmul(input_tensor,
                               encoder.get_embedding_table(),
                               transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)
            probs = tf.nn.softmax(logits, axis=-1, name='MLM_probs')
            log_probs = tf.nn.log_softmax(logits, axis=-1)

            label_ids = tf.reshape(masked_lm_ids, [-1])
            if example_weight is not None:
                # Apply each example's weight to all of its masked positions.
                masked_lm_weights = (masked_lm_weights *
                                     tf.expand_dims(example_weight, axis=-1))
            label_weights = tf.reshape(masked_lm_weights, [-1])
            one_hot_labels = tf.one_hot(label_ids,
                                        depth=bert_config.vocab_size,
                                        dtype=tf.float32)
            per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels,
                                              axis=[-1])
            per_example_loss = label_weights * per_example_loss

            # Weighted mean over predictions; the epsilon guards against an
            # all-zero weight batch.
            numerator = tf.reduce_sum(per_example_loss)
            denominator = tf.reduce_sum(label_weights) + 1e-5
            loss = numerator / denominator

            scalar_losses.append(loss)
            self.losses['MLM_losses'] = per_example_loss
            self.preds['MLM_preds'] = tf.argmax(probs, axis=-1)

        # next sentence prediction
        with tf.variable_scope(scope_cls):
            output_weights = tf.get_variable(
                'output_weights',
                shape=[2, bert_config.hidden_size],
                initializer=util.create_initializer(
                    bert_config.initializer_range),
                trainable=trainable)
            output_bias = tf.get_variable('output_bias',
                                          shape=[2],
                                          initializer=tf.zeros_initializer(),
                                          trainable=trainable)

            logits = tf.matmul(encoder.get_pooled_output(),
                               output_weights,
                               transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)
            probs = tf.nn.softmax(logits, axis=-1, name='probs')
            log_probs = tf.nn.log_softmax(logits, axis=-1)

            labels = tf.reshape(next_sentence_labels, [-1])
            one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32)
            per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs,
                                              axis=-1)
            if example_weight is not None:
                per_example_loss = example_weight * per_example_loss
            loss = tf.reduce_mean(per_example_loss)

            if use_nsp_loss:
                scalar_losses.append(loss)
            self.losses['NSP_losses'] = per_example_loss
            self.probs['NSP_probs'] = probs
            self.preds['NSP_preds'] = tf.argmax(probs, axis=-1)

        self.total_loss = tf.add_n(scalar_losses)
Пример #4
0
    def __init__(self,
                 bert_config,
                 is_training,
                 input_ids,
                 input_mask,
                 segment_ids,
                 sample_weight=None,
                 scope='bert',
                 dtype=tf.float32,
                 drop_pooler=False,
                 cls_model='self-attention',
                 label_size=2,
                 speed=0.1,
                 ignore_cls='0',
                 **kwargs):
        """Build a FastBERT backbone with per-layer child classifiers.

        Dropout is disabled on a deep copy of ``bert_config``. The
        embedding, pooler and teacher-classifier variables are created with
        ``trainable=False``.

        When ``is_training`` is true, ``total_loss`` is the sum, over the
        child classifiers returned by ``dynamic_transformer_model``, of the
        batch mean of KL(child || teacher). At inference, the child
        distributions (plus the teacher's, unless ignored) are concatenated
        along axis 0 into ``self.probs['probs']``.

        Args:
          ignore_cls: classifier ids to skip, as a list/tuple of ints or a
            comma-separated string. A falsy value means skip none.
          speed: forwarded to ``dynamic_transformer_model``; must be truthy.
          sample_weight: optional per-example weight applied to the KL
            terms (presumably shape [batch] -- TODO confirm).
          Other arguments configure the encoder and are forwarded as-is.
          ``**kwargs`` is accepted for interface compatibility and ignored.

        Raises:
          ValueError: if ``ignore_cls`` has an unsupported type or ``speed``
            is falsy.
        """
        super(FastBERTCLSDistillor, self).__init__()

        if not ignore_cls:
            ignore_cls = []
        if isinstance(ignore_cls, str):
            ignore_cls = ignore_cls.replace(' ', '').split(',')
            ignore_cls = list(map(int, ignore_cls))
        elif isinstance(ignore_cls, (list, tuple)):
            ignore_cls = list(map(int, ignore_cls))
        else:
            raise ValueError(
                '`ignore_cls` should be a list of child-classifier ids or '
                'a string seperated with commas.')

        if not speed:
            raise ValueError(
                '`speed` should be a float number between `0` and `1`.')

        bert_config = copy.deepcopy(bert_config)
        bert_config.hidden_dropout_prob = 0.0
        bert_config.attention_probs_dropout_prob = 0.0

        input_shape = util.get_shape_list(input_ids, expected_rank=2)
        batch_size = input_shape[0]
        max_seq_length = input_shape[1]

        with tf.variable_scope(scope):
            with tf.variable_scope('embeddings'):

                (self.embedding_output, self.embedding_table) = \
                    self.embedding_lookup(
                        input_ids=input_ids,
                        vocab_size=bert_config.vocab_size,
                        batch_size=batch_size,
                        max_seq_length=max_seq_length,
                        embedding_size=bert_config.hidden_size,
                        initializer_range=bert_config.initializer_range,
                        word_embedding_name='word_embeddings',
                        dtype=dtype,
                        trainable=False,
                        tilda_embeddings=None)

                # Add positional embeddings and token type embeddings,
                # then layer-normalize and apply dropout.
                self.embedding_output = self.embedding_postprocessor(
                    input_tensor=self.embedding_output,
                    batch_size=batch_size,
                    max_seq_length=max_seq_length,
                    hidden_size=bert_config.hidden_size,
                    use_token_type=True,
                    segment_ids=segment_ids,
                    token_type_vocab_size=bert_config.type_vocab_size,
                    token_type_embedding_name='token_type_embeddings',
                    use_position_embeddings=True,
                    position_embedding_name='position_embeddings',
                    initializer_range=bert_config.initializer_range,
                    max_position_embeddings=\
                        bert_config.max_position_embeddings,
                    dropout_prob=bert_config.hidden_dropout_prob,
                    dtype=dtype,
                    trainable=False)

            with tf.variable_scope('encoder'):
                attention_mask = self.create_attention_mask_from_input_mask(
                    input_mask, batch_size, max_seq_length, dtype=dtype)

                # stacked transformers
                (self.all_encoder_layers, self.all_cls_layers) = \
                    self.dynamic_transformer_model(
                        is_training,
                        input_tensor=self.embedding_output,
                        input_mask=input_mask,
                        batch_size=batch_size,
                        max_seq_length=max_seq_length,
                        label_size=label_size,
                        attention_mask=attention_mask,
                        hidden_size=bert_config.hidden_size,
                        num_hidden_layers=bert_config.num_hidden_layers,
                        num_attention_heads=bert_config.num_attention_heads,
                        intermediate_size=bert_config.intermediate_size,
                        intermediate_act_fn=util.get_activation(
                            bert_config.hidden_act),
                        hidden_dropout_prob=bert_config.hidden_dropout_prob,
                        attention_probs_dropout_prob=\
                            bert_config.attention_probs_dropout_prob,
                        initializer_range=bert_config.initializer_range,
                        dtype=dtype,
                        cls_model=cls_model,
                        speed=speed,
                        ignore_cls=ignore_cls)

            self.sequence_output = self.all_encoder_layers[-1]
            with tf.variable_scope('pooler'):
                first_token_tensor = self.sequence_output[:, 0, :]

                # trick: ignore the fully connected layer
                if drop_pooler:
                    self.pooled_output = first_token_tensor
                else:
                    self.pooled_output = tf.layers.dense(
                        first_token_tensor,
                        bert_config.hidden_size,
                        activation=tf.tanh,
                        kernel_initializer=util.create_initializer(
                            bert_config.initializer_range),
                        trainable=False)

        teacher_ignored = bert_config.num_hidden_layers in ignore_cls

        # Teacher classifier. Training always needs it, because every KL
        # term below is measured against its distribution. Without it,
        # `probs` would be undefined when the final layer is listed in
        # `ignore_cls`. `ignore_cls` only decides whether it is exposed at
        # inference.
        probs = None
        if is_training or not teacher_ignored:
            with tf.variable_scope('cls/seq_relationship'):
                output_weights = tf.get_variable(
                    'output_weights',
                    shape=[label_size, bert_config.hidden_size],
                    initializer=util.create_initializer(
                        bert_config.initializer_range),
                    trainable=False)
                output_bias = tf.get_variable(
                    'output_bias',
                    shape=[label_size],
                    initializer=tf.zeros_initializer(),
                    trainable=False)

                logits = tf.matmul(self.pooled_output,
                                   output_weights,
                                   transpose_b=True)
                logits = tf.nn.bias_add(logits, output_bias)
                probs = tf.nn.softmax(logits, axis=-1)

        # distillation
        if is_training:
            losses = []
            for cls_probs in self.all_cls_layers.values():

                # KL-Divergence KL(child || teacher) per example.
                # NOTE(review): tf.log of an exactly-zero probability yields
                # -inf (and 0 * -inf = NaN); consider clipping if this shows
                # up in practice.
                per_example_loss = tf.reduce_sum(
                    cls_probs * (tf.log(cls_probs) - tf.log(probs)), axis=-1)
                if sample_weight is not None:
                    per_example_loss *= tf.cast(sample_weight,
                                                dtype=tf.float32)
                loss = tf.reduce_mean(per_example_loss)
                losses.append(loss)

            distill_loss = tf.add_n(losses)
            self.total_loss = distill_loss
            self.losses['losses'] = distill_loss

        else:
            if not teacher_ignored:
                self.all_cls_layers[bert_config.num_hidden_layers] = probs
            self.probs['probs'] = tf.concat(list(self.all_cls_layers.values()),
                                            axis=0,
                                            name='probs')
Пример #5
0
    def __init__(self,
                 tiny_bert_config,
                 bert_config,
                 is_training,
                 input_ids,
                 input_mask,
                 segment_ids,
                 label_ids=None,
                 sample_weight=None,
                 scope='bert',
                 name='',
                 dtype=tf.float32,
                 use_tilda_embedding=False,
                 drop_pooler=False,
                 label_size=2,
                 trainable=True,
                 **kwargs):
        """Distill a BERT classifier into a smaller (TinyBERT) student.

        The student encoder (scope ``'tiny/bert'``) and its classifier
        (``'tiny/cls/seq_relationship'``) are always built.

        When ``is_training`` is true, a teacher is also built with
        ``trainable=False`` (encoder in ``scope``, classifier in
        ``'cls/seq_relationship'``). ``total_loss`` is then the sum of
        embedding, attention, hidden-state and prediction losses. Every
        teacher signal is wrapped in ``tf.stop_gradient``.

        Otherwise the student's softmax is stored in ``self.probs[name]``.

        Args:
          tiny_bert_config: student encoder config.
          bert_config: teacher encoder config.
          sample_weight: optional per-example weight (presumably shape
            [batch] -- TODO confirm) applied to every loss term.
          label_ids, dtype, trainable: accepted for interface
            compatibility; not read here.
          Other arguments are forwarded to the encoders; ``**kwargs`` is
          passed to both ``BERTEncoder`` constructors.

        Raises:
          ValueError: when training, if the student has no layers or more
            layers than the teacher.
        """
        super(TinyBERTCLSDistillor, self).__init__()

        def _get_logits(pooled_output, hidden_size, scope, trainable):
            """Linear classifier over `pooled_output` inside `scope`."""
            with tf.variable_scope(scope):
                output_weights = tf.get_variable(
                    'output_weights',
                    shape=[label_size, hidden_size],
                    initializer=util.create_initializer(
                        bert_config.initializer_range),
                    trainable=trainable)
                output_bias = tf.get_variable(
                    'output_bias',
                    shape=[label_size],
                    initializer=tf.zeros_initializer(),
                    trainable=trainable)

                logits = tf.matmul(pooled_output,
                                   output_weights,
                                   transpose_b=True)
                logits = tf.nn.bias_add(logits, output_bias)
                return logits

        student = BERTEncoder(bert_config=tiny_bert_config,
                              is_training=is_training,
                              input_ids=input_ids,
                              input_mask=input_mask,
                              segment_ids=segment_ids,
                              scope='tiny/bert',
                              use_tilda_embedding=use_tilda_embedding,
                              drop_pooler=drop_pooler,
                              trainable=True,
                              **kwargs)
        student_logits = _get_logits(student.get_pooled_output(),
                                     tiny_bert_config.hidden_size,
                                     'tiny/cls/seq_relationship', True)

        if is_training:
            num_teacher_hidden_layers = bert_config.num_hidden_layers
            num_student_hidden_layers = tiny_bert_config.num_hidden_layers
            if not 0 < num_student_hidden_layers <= num_teacher_hidden_layers:
                raise ValueError(
                    'Student `num_hidden_layers` (%d) must be between 1 and '
                    'the teacher `num_hidden_layers` (%d).'
                    % (num_student_hidden_layers, num_teacher_hidden_layers))
            # Student layer i is paired with the last teacher layer of the
            # i-th group of `num_projections` teacher layers.
            num_projections = (num_teacher_hidden_layers //
                               num_student_hidden_layers)

            teacher = BERTEncoder(bert_config=bert_config,
                                  is_training=False,
                                  input_ids=input_ids,
                                  input_mask=input_mask,
                                  segment_ids=segment_ids,
                                  scope=scope,
                                  use_tilda_embedding=False,
                                  drop_pooler=drop_pooler,
                                  trainable=False,
                                  **kwargs)
            # The teacher is a fixed target: stop gradients so none are
            # computed through its graph.
            teacher_logits = tf.stop_gradient(
                _get_logits(teacher.get_pooled_output(),
                            bert_config.hidden_size,
                            'cls/seq_relationship', False))

            weights = 1.0
            if sample_weight is not None:
                weights = tf.cast(sample_weight, dtype=tf.float32)

            # embedding loss: project student embeddings to teacher width.
            teacher_embedding = tf.stop_gradient(
                teacher.get_embedding_output())
            student_embedding = student.get_embedding_output()
            with tf.variable_scope('embedding_loss'):
                linear_trans = tf.layers.dense(
                    student_embedding,
                    bert_config.hidden_size,
                    kernel_initializer=util.create_initializer(
                        bert_config.initializer_range))
                embedding_loss = tf.losses.mean_squared_error(
                    linear_trans,
                    teacher_embedding,
                    weights=tf.reshape(weights, [-1, 1, 1]))

            # attention loss
            teacher_attention_scores = [
                tf.stop_gradient(value)
                for value in teacher.get_attention_scores()]
            student_attention_scores = student.get_attention_scores()
            attention_weights = tf.reshape(weights, [-1, 1, 1, 1])
            attention_losses = []
            for i in range(num_student_hidden_layers):
                attention_losses.append(
                    tf.losses.mean_squared_error(
                        teacher_attention_scores[
                            num_projections * (i + 1) - 1],
                        student_attention_scores[i],
                        weights=attention_weights))
            attention_loss = tf.add_n(attention_losses)

            # hidden loss: one projection per student layer; the creation
            # order of the dense layers fixes their variable names.
            teacher_hidden_layers = [
                tf.stop_gradient(value)
                for value in teacher.all_encoder_layers]
            student_hidden_layers = student.all_encoder_layers
            hidden_weights = tf.reshape(weights, [-1, 1, 1])
            with tf.variable_scope('hidden_loss'):
                hidden_losses = []
                for i in range(num_student_hidden_layers):
                    projected = tf.layers.dense(
                        student_hidden_layers[i],
                        bert_config.hidden_size,
                        kernel_initializer=util.create_initializer(
                            bert_config.initializer_range))
                    hidden_losses.append(
                        tf.losses.mean_squared_error(
                            teacher_hidden_layers[
                                num_projections * (i + 1) - 1],
                            projected,
                            weights=hidden_weights))
                hidden_loss = tf.add_n(hidden_losses)

            # prediction loss: soft cross-entropy against teacher probs.
            teacher_probs = tf.nn.softmax(teacher_logits, axis=-1)
            student_log_probs = tf.nn.log_softmax(student_logits, axis=-1)
            per_example_pred_loss = -tf.reduce_sum(
                teacher_probs * student_log_probs, axis=-1)
            # Weights must be rank 1 here: a [-1, 1] weight against a
            # [batch] loss would broadcast to [batch, batch].
            pred_loss = tf.reduce_mean(
                per_example_pred_loss * tf.reshape(weights, [-1]))

            # sum up
            distill_loss = (embedding_loss + attention_loss + hidden_loss +
                            pred_loss)
            self.total_loss = distill_loss
            self.losses['distill'] = tf.reshape(distill_loss, [1])

        else:
            student_probs = tf.nn.softmax(student_logits,
                                          axis=-1,
                                          name='probs')
            self.probs[name] = student_probs