Example #1
def _get_pred_loss(self, teacher_logits, student_logits, weights):
    # the teacher's probabilities serve as soft targets; stop_gradient
    # keeps gradients from flowing back into the teacher
    teacher_probs = tf.nn.softmax(teacher_logits, axis=-1)
    teacher_probs = tf.stop_gradient(teacher_probs)
    student_log_probs = tf.nn.log_softmax(student_logits, axis=-1)
    # per-token cross-entropy against the soft targets, weighted per example
    pred_loss = (
        - tf.reduce_sum(teacher_probs * student_log_probs, axis=-1) *
        tf.reshape(weights, [-1, 1]))
    pred_loss = tf.reduce_mean(pred_loss)
    return pred_loss
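
For reference, a minimal self-contained sketch of the same pattern on dummy tensors: the teacher distribution is detached with tf.stop_gradient so gradients only reach the student. All shapes and variable names below are illustrative and not part of the snippet above.

import tensorflow as tf

# illustrative shapes: batch of 2, sequence length 3, vocabulary of 5
teacher_logits = tf.random.normal([2, 3, 5])
student_logits = tf.random.normal([2, 3, 5])
weights = tf.constant([1.0, 0.5])  # per-example weights

# soft targets from the teacher, excluded from backpropagation
teacher_probs = tf.stop_gradient(tf.nn.softmax(teacher_logits, axis=-1))
student_log_probs = tf.nn.log_softmax(student_logits, axis=-1)
pred_loss = tf.reduce_mean(
    -tf.reduce_sum(teacher_probs * student_log_probs, axis=-1)
    * tf.reshape(weights, [-1, 1]))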
Example #2
def _cache_mem(curr_out, prev_mem, mem_len, reuse_len=None):
    '''cache hidden states into memory.'''
    if mem_len is None or mem_len == 0:
        return None
    else:
        if reuse_len is not None and reuse_len > 0:
            curr_out = curr_out[:reuse_len]

        if prev_mem is None:
            new_mem = curr_out[-mem_len:]
        else:
            new_mem = tf.concat([prev_mem, curr_out], 0)[-mem_len:]

    # the cached memory is treated as a constant: no gradients flow into it
    return tf.stop_gradient(new_mem)
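
To see what the update rule keeps, here is a small illustration with made-up shapes; it reproduces the concat-and-slice step of _cache_mem directly rather than calling it.

import tensorflow as tf

mem_len = 4
prev_mem = tf.zeros([4, 1, 8])  # [mem_len, batch, hidden], illustrative
curr_out = tf.ones([6, 1, 8])   # [seq_len, batch, hidden], illustrative

# same rule as above: append the new states, keep only the last mem_len,
# and block gradients from flowing into the cached memory
new_mem = tf.stop_gradient(tf.concat([prev_mem, curr_out], 0)[-mem_len:])
# new_mem has shape [4, 1, 8] and holds the last 4 positions of curr_out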
Example #3
def _get_embedding_loss(self, teacher, student, bert_config, weights):
    teacher_embedding = teacher.get_embedding_output()
    teacher_embedding = tf.stop_gradient(teacher_embedding)
    student_embedding = student.get_embedding_output()
    with tf.variable_scope('embedding_loss'):
        # project the student embeddings into the teacher's hidden size
        # before comparing them with a mean-squared error
        linear_trans = tf.layers.dense(
            student_embedding,
            bert_config.hidden_size,
            kernel_initializer=util.create_initializer(
                bert_config.initializer_range))
        embedding_loss = tf.losses.mean_squared_error(
            linear_trans, teacher_embedding,
            weights=tf.reshape(weights, [-1, 1, 1]))
    return embedding_loss
Example #4
def _get_attention_loss(self, teacher, student,
                        bert_config, student_config, weights):
    teacher_attention_scores = teacher.get_attention_scores()
    teacher_attention_scores = [tf.stop_gradient(value)
                                for value in teacher_attention_scores]
    student_attention_scores = student.get_attention_scores()
    num_teacher_hidden_layers = bert_config.num_hidden_layers
    num_student_hidden_layers = student_config.num_hidden_layers
    num_projections = int(
        num_teacher_hidden_layers / num_student_hidden_layers)
    attention_losses = []
    for i in range(num_student_hidden_layers):
        # student layer i is matched against teacher layer
        # num_projections * (i + 1) - 1, the last layer of its block
        attention_losses.append(tf.losses.mean_squared_error(
            teacher_attention_scores[
                num_projections * i + num_projections - 1],
            student_attention_scores[i],
            weights=tf.reshape(weights, [-1, 1, 1, 1])))
    attention_loss = tf.add_n(attention_losses)
    return attention_loss
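
The index num_projections * i + num_projections - 1 pairs each student layer with the last teacher layer of its block. A quick sanity check of that arithmetic with illustrative layer counts:

num_teacher_hidden_layers = 12  # illustrative
num_student_hidden_layers = 4   # illustrative
num_projections = num_teacher_hidden_layers // num_student_hidden_layers

mapping = [num_projections * i + num_projections - 1
           for i in range(num_student_hidden_layers)]
# student layers 0..3 are matched to teacher layers [2, 5, 8, 11]
assert mapping == [2, 5, 8, 11]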
Example #5
def _get_fake_data(self, inputs, mlm_logits):
    '''Sample from the generator to create corrupted input.'''
    inputs = unmask(inputs)
    # optionally forbid the generator from sampling the original token
    disallow = tf.one_hot(
        inputs.masked_lm_ids, depth=self.bert_config.vocab_size,
        dtype=tf.float32) if self.config.disallow_correct else None
    # sample replacement tokens; no gradient flows through the sampling step
    sampled_tokens = tf.stop_gradient(sample_from_softmax(
        mlm_logits / self.config.temperature, disallow=disallow))
    sampled_tokids = tf.argmax(sampled_tokens, -1, output_type=tf.int32)
    updated_input_ids, masked = scatter_update(
        inputs.input_ids, sampled_tokids, inputs.masked_lm_positions)
    # a position counts as fake if it was masked and the sampled token
    # differs from the original one
    labels = masked * (1 - tf.cast(
        tf.equal(updated_input_ids, inputs.input_ids), tf.int32))
    updated_inputs = get_updated_inputs(
        inputs, input_ids=updated_input_ids)
    FakedData = collections.namedtuple('FakedData', [
        'inputs', 'is_fake_tokens', 'sampled_tokens'])
    return FakedData(inputs=updated_inputs, is_fake_tokens=labels,
                     sampled_tokens=sampled_tokens)
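
sample_from_softmax is not shown in this snippet. A common way to implement such a helper is the Gumbel-max trick; the sketch below is an illustrative assumption, not the snippet's actual helper (which additionally applies the disallow mask to the logits before sampling).

import tensorflow as tf

def sample_from_softmax_sketch(logits, vocab_size):
    # Gumbel-max trick: adding Gumbel noise to the logits and taking the
    # argmax draws one token id per position from the softmax distribution
    uniform = tf.random.uniform(tf.shape(logits), minval=1e-9, maxval=1.0)
    gumbel = -tf.math.log(-tf.math.log(uniform))
    sampled_ids = tf.argmax(logits + gumbel, axis=-1, output_type=tf.int32)
    return tf.one_hot(sampled_ids, depth=vocab_size)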
Example #6
def _get_hidden_loss(self, teacher, student,
                     bert_config, student_config, weights):
    teacher_hidden_layers = teacher.all_encoder_layers
    teacher_hidden_layers = [tf.stop_gradient(value)
                             for value in teacher_hidden_layers]
    student_hidden_layers = student.all_encoder_layers
    num_teacher_hidden_layers = bert_config.num_hidden_layers
    num_student_hidden_layers = student_config.num_hidden_layers
    num_projections = int(
        num_teacher_hidden_layers / num_student_hidden_layers)
    with tf.variable_scope('hidden_loss'):
        hidden_losses = []
        for i in range(num_student_hidden_layers):
            # project student layer i to the teacher's hidden size and
            # match it against the last teacher layer of its block
            hidden_losses.append(tf.losses.mean_squared_error(
                teacher_hidden_layers[
                    num_projections * i + num_projections - 1],
                tf.layers.dense(
                    student_hidden_layers[i], bert_config.hidden_size,
                    kernel_initializer=util.create_initializer(
                        bert_config.initializer_range)),
                weights=tf.reshape(weights, [-1, 1, 1])))
        hidden_loss = tf.add_n(hidden_losses)
    return hidden_loss
Example #7
    def __init__(self,
                 is_training,
                 input_tensor,
                 is_supervised,
                 is_expanded,
                 label_ids,
                 label_size=2,
                 sample_weight=None,
                 scope='cls/seq_relationship',
                 hidden_dropout_prob=0.1,
                 initializer_range=0.02,
                 trainable=True,
                 global_step=None,
                 num_train_steps=None,
                 uda_softmax_temp=-1,
                 uda_confidence_thresh=-1,
                 tsa_schedule='linear',
                 **kwargs):
        super().__init__(**kwargs)

        is_supervised = tf.cast(is_supervised, tf.float32)
        is_expanded = tf.cast(is_expanded, tf.float32)

        hidden_size = input_tensor.shape.as_list()[-1]
        with tf.variable_scope(scope):
            output_weights = tf.get_variable(
                'output_weights',
                shape=[label_size, hidden_size],
                initializer=util.create_initializer(initializer_range),
                trainable=trainable)
            output_bias = tf.get_variable('output_bias',
                                          shape=[label_size],
                                          initializer=tf.zeros_initializer(),
                                          trainable=trainable)

            output_layer = util.dropout(
                input_tensor, hidden_dropout_prob if is_training else 0.0)
            logits = tf.matmul(output_layer, output_weights, transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)
            log_probs = tf.nn.log_softmax(logits, axis=-1)

            with tf.variable_scope('sup_loss'):

                # select the original (non-augmented) examples, then keep
                # only the labeled ones for the supervised loss
                sup_ori_log_probs = tf.boolean_mask(log_probs,
                                                    mask=(1.0 - is_expanded),
                                                    axis=0)
                sup_log_probs = tf.boolean_mask(sup_ori_log_probs,
                                                mask=is_supervised,
                                                axis=0)
                sup_label_ids = tf.boolean_mask(label_ids,
                                                mask=is_supervised,
                                                axis=0)

                self.preds['preds'] = tf.argmax(sup_ori_log_probs, axis=-1)

                one_hot_labels = tf.one_hot(sup_label_ids,
                                            depth=label_size,
                                            dtype=tf.float32)
                per_example_loss = -tf.reduce_sum(
                    one_hot_labels * sup_log_probs, axis=-1)

                loss_mask = tf.ones_like(per_example_loss, dtype=tf.float32)
                correct_label_probs = tf.reduce_sum(one_hot_labels *
                                                    tf.exp(sup_log_probs),
                                                    axis=-1)

                if is_training and tsa_schedule:
                    tsa_start = 1.0 / label_size
                    tsa_threshold = get_tsa_threshold(tsa_schedule,
                                                      global_step,
                                                      num_train_steps,
                                                      tsa_start,
                                                      end=1)

                    larger_than_threshold = tf.greater(correct_label_probs,
                                                       tsa_threshold)
                    loss_mask = loss_mask * (
                        1 - tf.cast(larger_than_threshold, tf.float32))

                loss_mask = tf.stop_gradient(loss_mask)
                per_example_loss = per_example_loss * loss_mask
                if sample_weight is not None:
                    sup_sample_weight = tf.boolean_mask(sample_weight,
                                                        mask=is_supervised,
                                                        axis=0)
                    per_example_loss *= tf.cast(sup_sample_weight,
                                                dtype=tf.float32)
                sup_loss = (tf.reduce_sum(per_example_loss) /
                            tf.maximum(tf.reduce_sum(loss_mask), 1))

                self.losses['supervised'] = per_example_loss

            with tf.variable_scope('unsup_loss'):

                # split out the unlabeled originals and their augmented
                # counterparts for the consistency (unsupervised) loss
                ori_log_probs = tf.boolean_mask(sup_ori_log_probs,
                                                mask=(1.0 - is_supervised),
                                                axis=0)
                aug_log_probs = tf.boolean_mask(log_probs,
                                                mask=is_expanded,
                                                axis=0)
                sup_ori_logits = tf.boolean_mask(logits,
                                                 mask=(1.0 - is_expanded),
                                                 axis=0)
                ori_logits = tf.boolean_mask(sup_ori_logits,
                                             mask=(1.0 - is_supervised),
                                             axis=0)

                unsup_loss_mask = 1
                if uda_softmax_temp != -1:
                    tgt_ori_log_probs = tf.nn.log_softmax(ori_logits /
                                                          uda_softmax_temp,
                                                          axis=-1)
                    tgt_ori_log_probs = tf.stop_gradient(tgt_ori_log_probs)
                else:
                    tgt_ori_log_probs = tf.stop_gradient(ori_log_probs)

                if uda_confidence_thresh != -1:
                    largest_prob = tf.reduce_max(tf.exp(ori_log_probs),
                                                 axis=-1)
                    unsup_loss_mask = tf.cast(
                        tf.greater(largest_prob, uda_confidence_thresh),
                        tf.float32)
                    unsup_loss_mask = tf.stop_gradient(unsup_loss_mask)

                per_example_loss = kl_for_log_probs(
                    tgt_ori_log_probs, aug_log_probs) * unsup_loss_mask
                if sample_weight is not None:
                    unsup_sample_weight = tf.boolean_mask(sample_weight,
                                                          mask=(1.0 -
                                                                is_supervised),
                                                          axis=0)
                    per_example_loss *= tf.cast(unsup_sample_weight,
                                                dtype=tf.float32)
                unsup_loss = tf.reduce_mean(per_example_loss)

                self.losses['unsupervised'] = per_example_loss

            self.total_loss = sup_loss + unsup_loss
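
get_tsa_threshold is not defined in this snippet. Under the usual training-signal-annealing formulation from UDA, the linear schedule ramps the confidence threshold from 1 / label_size up to 1 over training; the sketch below is an illustrative assumption, not the library's implementation.

import tensorflow as tf

def linear_tsa_threshold(global_step, num_train_steps, start, end=1.0):
    # fraction of training completed, in [0, 1]
    progress = tf.cast(global_step, tf.float32) / float(num_train_steps)
    # anneal the threshold linearly from start (1 / label_size) to end
    return progress * (end - start) + start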
Example #8
    def __init__(self,
                 bert_config,
                 is_training,
                 input_ids,
                 input_mask=None,
                 token_type_ids=None,
                 use_one_hot_embeddings=True,
                 scope=None,
                 embedding_size=None,
                 input_embeddings=None,
                 input_reprs=None,
                 update_embeddings=True,
                 untied_embeddings=False):
        '''Constructor for BertModel.

        Args:
          bert_config: `BertConfig` instance.
          is_training: bool. true for training model, false for eval model.
            Controls whether dropout will be applied.
          input_ids: int32 Tensor of shape [batch_size, seq_length].
          input_mask: (optional) int32 Tensor of shape [batch_size,
            seq_length].
          token_type_ids: (optional) int32 Tensor of shape [batch_size,
            seq_length].
          use_one_hot_embeddings: (optional) bool. Whether to use one-hot word
            embeddings or tf.embedding_lookup() for the word embeddings. On
            the TPU, it is much faster if this is True, on the CPU or GPU,
            it is faster if this is False.
          scope: (optional) variable scope. Defaults to 'electra'.

        Raises:
          ValueError: The config is invalid or one of the input tensor shapes
            is invalid.
        '''
        bert_config = copy.deepcopy(bert_config)
        if not is_training:
            bert_config.hidden_dropout_prob = 0.0
            bert_config.attention_probs_dropout_prob = 0.0

        # token_type_ids is required here: its shape provides batch_size
        # and seq_length
        assert token_type_ids is not None
        input_shape = util.get_shape_list(token_type_ids, expected_rank=2)
        batch_size = input_shape[0]
        seq_length = input_shape[1]

        if input_mask is None:
            input_mask = tf.ones(shape=[batch_size, seq_length],
                                 dtype=tf.int32)

        if input_reprs is None:
            with tf.variable_scope(
                ((scope if untied_embeddings else 'electra') + '/embeddings'),
                    reuse=tf.AUTO_REUSE):
                # Perform embedding lookup on the word ids
                if embedding_size is None:
                    embedding_size = bert_config.hidden_size
                (token_embeddings, self.embedding_table) = \
                    embedding_lookup(
                        input_ids=input_ids,
                        vocab_size=bert_config.vocab_size,
                        embedding_size=embedding_size,
                        initializer_range=bert_config.initializer_range,
                        word_embedding_name='word_embeddings',
                        use_one_hot_embeddings=use_one_hot_embeddings)

            with tf.variable_scope(
                ((scope if untied_embeddings else 'electra') + '/embeddings'),
                    reuse=tf.AUTO_REUSE):
                # Add positional embeddings and token type embeddings, then
                # layer normalize and perform dropout.
                self.embedding_output = embedding_postprocessor(
                    input_tensor=token_embeddings,
                    use_token_type=True,
                    token_type_ids=token_type_ids,
                    token_type_vocab_size=bert_config.type_vocab_size,
                    token_type_embedding_name='token_type_embeddings',
                    use_position_embeddings=True,
                    position_embedding_name='position_embeddings',
                    initializer_range=bert_config.initializer_range,
                    max_position_embeddings=\
                        bert_config.max_position_embeddings,
                    dropout_prob=bert_config.hidden_dropout_prob)
        else:
            self.embedding_output = input_reprs
        if not update_embeddings:
            self.embedding_output = tf.stop_gradient(self.embedding_output)

        with tf.variable_scope(scope, default_name='electra'):
            if self.embedding_output.shape[-1] != bert_config.hidden_size:
                self.embedding_output = tf.layers.dense(
                    self.embedding_output,
                    bert_config.hidden_size,
                    name='embeddings_project')

            with tf.variable_scope('encoder'):
                # This converts a 2D mask of shape [batch_size, seq_length]
                # to a 3D mask of shape [batch_size, seq_length, seq_length]
                # which is used for the attention scores.
                attention_mask = create_attention_mask_from_input_mask(
                    token_type_ids, input_mask)

                # Run the stacked transformer. Output shapes
                # attn_maps:
                #   [n_layers, batch_size, n_heads, seq_length, seq_length]
                (self.all_layer_outputs, self.attn_maps) = transformer_model(
                    input_tensor=self.embedding_output,
                    attention_mask=attention_mask,
                    hidden_size=bert_config.hidden_size,
                    num_hidden_layers=bert_config.num_hidden_layers,
                    num_attention_heads=bert_config.num_attention_heads,
                    intermediate_size=bert_config.intermediate_size,
                    intermediate_act_fn=util.get_activation(
                        bert_config.hidden_act),
                    hidden_dropout_prob=bert_config.hidden_dropout_prob,
                    attention_probs_dropout_prob=bert_config.
                    attention_probs_dropout_prob,
                    initializer_range=bert_config.initializer_range,
                    do_return_all_layers=True)
                self.sequence_output = self.all_layer_outputs[-1]
                self.pooled_output = self.sequence_output[:, 0]
Example #9
    def _get_hidden_emd(self, teacher, student, teacher_weight, student_weight,
                        bert_config, sample_weight, emd_temporature):
        teacher_hidden_layers = teacher.all_encoder_layers
        teacher_hidden_layers = [
            tf.stop_gradient(value) for value in teacher_hidden_layers
        ]
        student_hidden_layers = student.all_encoder_layers
        M = len(teacher_hidden_layers)
        N = len(student_hidden_layers)

        with tf.variable_scope('hidden_emd'):
            flow = tf.get_variable('flow',
                                   shape=[M, N],
                                   initializer=tf.constant_initializer(1 / M /
                                                                       N),
                                   trainable=False)

            # MSE
            rows = []
            for m in range(M):
                cols = []
                for n in range(N):
                    linear_trans = tf.layers.dense(
                        student_hidden_layers[n],
                        bert_config.hidden_size,
                        kernel_initializer=util.create_initializer(
                            bert_config.initializer_range))
                    mse = tf.losses.mean_squared_error(
                        teacher_hidden_layers[m],
                        linear_trans,
                        weights=tf.reshape(sample_weight, [-1, 1, 1]))
                    col = tf.reshape(mse, [1, 1])
                    cols.append(col)
                row = tf.concat(cols, axis=1)
                rows.append(row)
            distance = tf.concat(rows, axis=0)

            # cost attention mechanism
            teacher_cost = (tf.reduce_sum(flow, axis=1) *
                            tf.reduce_sum(distance, axis=1) /
                            (teacher_weight + 1e-6))
            student_cost = (tf.reduce_sum(flow, axis=0) *
                            tf.reduce_sum(distance, axis=0) /
                            (student_weight + 1e-6))

            # new weights
            new_teacher_weight = tf.where(
                teacher_cost > 1e-12,
                tf.reduce_sum(teacher_cost) / (teacher_cost + 1e-6),
                teacher_weight)
            new_student_weight = tf.where(
                student_cost > 1e-12,
                tf.reduce_sum(student_cost) / (student_cost + 1e-6),
                student_weight)
            new_teacher_weight = tf.nn.softmax(new_teacher_weight /
                                               emd_temporature)
            new_student_weight = tf.nn.softmax(new_student_weight /
                                               emd_temporature)

        self.hidden_flow = flow
        self.hidden_distance = distance
        hidden_emd = tf.reduce_sum(flow * distance)
        return hidden_emd, new_teacher_weight, new_student_weight
Example #10
    def _get_attention_emd(self, teacher, student, teacher_weight,
                           student_weight, sample_weight, emd_temporature):
        teacher_attention_scores = teacher.get_attention_scores()
        teacher_attention_scores = [
            tf.stop_gradient(value) for value in teacher_attention_scores
        ]
        student_attention_scores = student.get_attention_scores()
        M = len(teacher_attention_scores)
        N = len(student_attention_scores)

        with tf.variable_scope('attention_emd'):
            flow = tf.get_variable('flow',
                                   shape=[M, N],
                                   initializer=tf.constant_initializer(1 / M /
                                                                       N),
                                   trainable=False)

            # MSE
            rows = []
            for m in range(M):
                cols = []
                for n in range(N):
                    teacher_matrix = tf.where(
                        teacher_attention_scores[m] < -1e2,
                        tf.zeros_like(teacher_attention_scores[m]),
                        teacher_attention_scores[m])
                    student_matrix = tf.where(
                        student_attention_scores[n] < -1e2,
                        tf.zeros_like(student_attention_scores[n]),
                        student_attention_scores[n])
                    mse = tf.losses.mean_squared_error(teacher_matrix,
                                                       student_matrix,
                                                       weights=tf.reshape(
                                                           sample_weight,
                                                           [-1, 1, 1, 1]))
                    col = tf.reshape(mse, [1, 1])
                    cols.append(col)
                row = tf.concat(cols, axis=1)
                rows.append(row)
            distance = tf.concat(rows, axis=0)

            # cost attention mechanism
            teacher_cost = (tf.reduce_sum(flow, axis=1) *
                            tf.reduce_sum(distance, axis=1) /
                            (teacher_weight + 1e-6))
            student_cost = (tf.reduce_sum(flow, axis=0) *
                            tf.reduce_sum(distance, axis=0) /
                            (student_weight + 1e-6))

            # new weights
            new_teacher_weight = tf.where(
                teacher_cost > 1e-12,
                tf.reduce_sum(teacher_cost) / (teacher_cost + 1e-6),
                teacher_weight)
            new_student_weight = tf.where(
                student_cost > 1e-12,
                tf.reduce_sum(student_cost) / (student_cost + 1e-6),
                student_weight)
            new_teacher_weight = tf.nn.softmax(new_teacher_weight /
                                               emd_temporature)
            new_student_weight = tf.nn.softmax(new_student_weight /
                                               emd_temporature)

        self.attention_flow = flow
        self.attention_distance = distance
        attention_emd = tf.reduce_sum(flow * distance)
        return attention_emd, new_teacher_weight, new_student_weight