Example #1
def norm(x, scope, *, axis=-1, epsilon=1e-5):
    '''Normalize to mean = 0, std = 1, then do a diagonal affine transform.'''
    with tf.variable_scope(scope):
        n_state = x.shape[-1].value
        g = tf.get_variable('g', [n_state],
                            initializer=tf.constant_initializer(1))
        b = tf.get_variable('b', [n_state],
                            initializer=tf.constant_initializer(0))
        u = tf.reduce_mean(x, axis=axis, keepdims=True)
        s = tf.reduce_mean(tf.square(x - u), axis=axis, keepdims=True)
        x = (x - u) * tf.rsqrt(s + epsilon)
        x = x * g + b
        return x
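
A minimal usage sketch, assuming TensorFlow 1.x (the tf.variable_scope / tf.get_variable API); the tensor sizes below are illustrative only:

import numpy as np
import tensorflow as tf  # TF 1.x

x = tf.constant(np.random.randn(2, 8, 16), dtype=tf.float32)
y = norm(x, scope='ln_example')  # layer norm over the last axis

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(y)
    print(out.shape)  # (2, 8, 16); mean ~ 0, std ~ 1 along the last axis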
Example #2
def conv1d(x, scope, nf, *, w_init_stdev=0.02):
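    '''Kernel-size-1 1-D convolution: a position-wise linear projection of
    the last dimension from nx features to nf features.'''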
    with tf.variable_scope(scope):
        *start, nx = shape_list(x)
        w = tf.get_variable(
            'w', [1, nx, nf],
            initializer=tf.random_normal_initializer(stddev=w_init_stdev))
        b = tf.get_variable('b', [nf], initializer=tf.constant_initializer(0))
        c = tf.reshape(
            tf.matmul(tf.reshape(x, [-1, nx]), tf.reshape(w, [-1, nf])) + b,
            start + [nf])
        return c
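
conv1d depends on a shape_list helper that this snippet does not define. For completeness, this is the mixed static/dynamic shape helper used in the GPT-2 codebase, which this function appears to come from:

def shape_list(x):
    '''Return tensor dimensions, preferring static sizes and falling back
    to dynamic ones where the static size is unknown.'''
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]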
Example #3
    def __init__(self,
                 student_config,
                 bert_config,
                 is_training,
                 input_ids,
                 input_mask,
                 segment_ids,
                 label_ids=None,
                 sample_weight=None,
                 scope='bert',
                 dtype=tf.float32,
                 drop_pooler=False,
                 label_size=2,
                 pred_temporature=1.0,
                 emd_temporature=1.0,
                 beta=0.01,
                 **kwargs):
        super(TinyBERTCLSDistillor, self).__init__()

        def _get_logits(pooled_output, hidden_size, scope, trainable):
            with tf.variable_scope(scope):
                output_weights = tf.get_variable(
                    'output_weights',
                    shape=[label_size, hidden_size],
                    initializer=util.create_initializer(
                        bert_config.initializer_range),
                    trainable=trainable)
                output_bias = tf.get_variable(
                    'output_bias',
                    shape=[label_size],
                    initializer=tf.zeros_initializer(),
                    trainable=trainable)

                logits = tf.matmul(pooled_output,
                                   output_weights,
                                   transpose_b=True)
                logits = tf.nn.bias_add(logits, output_bias)
                return logits

        use_tilda_embedding = kwargs.get('use_tilda_embedding')
        student = BERTEncoder(bert_config=student_config,
                              is_training=is_training,
                              input_ids=input_ids,
                              input_mask=input_mask,
                              segment_ids=segment_ids,
                              scope='tiny/bert',
                              use_tilda_embedding=use_tilda_embedding,
                              drop_pooler=drop_pooler,
                              trainable=True,
                              **kwargs)
        student_logits = _get_logits(student.get_pooled_output(),
                                     student_config.hidden_size,
                                     'tiny/cls/seq_relationship', True)

        if is_training:
            teacher = BERTEncoder(bert_config=bert_config,
                                  is_training=False,
                                  input_ids=input_ids,
                                  input_mask=input_mask,
                                  segment_ids=segment_ids,
                                  scope=scope,
                                  use_tilda_embedding=False,
                                  drop_pooler=drop_pooler,
                                  trainable=False,
                                  **kwargs)
            teacher_logits = _get_logits(teacher.get_pooled_output(),
                                         bert_config.hidden_size,
                                         'cls/seq_relationship', False)

            weights = 1.0
            if sample_weight is not None:
                weights = tf.cast(sample_weight, dtype=tf.float32)

            # embedding loss
            embedding_loss = self._get_embedding_loss(teacher, student,
                                                      bert_config, weights)

            # emd
            M = bert_config.num_hidden_layers
            N = student_config.num_hidden_layers
            with tf.variable_scope('emd'):
                teacher_weight = tf.get_variable(
                    'teacher_weight',
                    shape=[M],
                    initializer=tf.constant_initializer(1 / M),
                    trainable=False)
                student_weight = tf.get_variable(
                    'student_weight',
                    shape=[N],
                    initializer=tf.constant_initializer(1 / N),
                    trainable=False)
                self.teacher_weight = teacher_weight
                self.student_weight = student_weight

            # attention emd
            (attention_emd, new_attention_teacher_weight,
             new_attention_student_weight) = \
                self._get_attention_emd(
                    teacher, student, teacher_weight, student_weight,
                    weights, emd_temporature)

            # hidden emd
            (hidden_emd, new_hidden_teacher_weight,
             new_hidden_student_weight) = \
                self._get_hidden_emd(
                    teacher, student, teacher_weight, student_weight,
                    bert_config, weights, emd_temporature)

            # update weights
            new_teacher_weight = \
                (new_attention_teacher_weight + new_hidden_teacher_weight) / 2
            new_student_weight = \
                (new_attention_student_weight + new_hidden_student_weight) / 2
            update_teacher_weight_op = tf.assign(teacher_weight,
                                                 new_teacher_weight)
            update_student_weight_op = tf.assign(student_weight,
                                                 new_student_weight)

            # prediction loss
            pred_loss = self._get_pred_loss(teacher_logits, student_logits,
                                            weights, pred_temporature)

            # sum up
            with tf.control_dependencies(
                [update_teacher_weight_op, update_student_weight_op]):
                distill_loss = \
                    beta * (embedding_loss + attention_emd + hidden_emd) + \
                    pred_loss
            self.total_loss = distill_loss
            self.losses['losses'] = distill_loss

        else:
            student_probs = tf.nn.softmax(student_logits,
                                          axis=-1,
                                          name='probs')
            self.probs['probs'] = student_probs
            self.preds['preds'] = tf.argmax(student_probs, axis=-1)
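
The helper _get_pred_loss is not shown in this snippet. As a rough sketch of what a temperature-scaled prediction loss typically computes (an assumption about the helper's behavior, not the library's actual code), soft_cross_entropy below penalizes the student for diverging from the teacher's softened distribution:

def soft_cross_entropy(teacher_logits, student_logits, weights, temperature):
    # Soften both distributions with the same temperature, then take the
    # weighted cross-entropy of student predictions against soft targets.
    teacher_probs = tf.nn.softmax(teacher_logits / temperature, axis=-1)
    student_log_probs = tf.nn.log_softmax(student_logits / temperature,
                                          axis=-1)
    per_example = -tf.reduce_sum(teacher_probs * student_log_probs, axis=-1)
    return tf.reduce_mean(weights * per_example)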
Example #4
    def _get_hidden_emd(self, teacher, student, teacher_weight, student_weight,
                        bert_config, sample_weight, emd_temporature):
        teacher_hidden_layers = teacher.all_encoder_layers
        teacher_hidden_layers = [
            tf.stop_gradient(value) for value in teacher_hidden_layers
        ]
        student_hidden_layers = student.all_encoder_layers
        M = len(teacher_hidden_layers)
        N = len(student_hidden_layers)

        with tf.variable_scope('hidden_emd'):
            flow = tf.get_variable(
                'flow',
                shape=[M, N],
                initializer=tf.constant_initializer(1 / (M * N)),
                trainable=False)

            # MSE
            rows = []
            for m in range(M):
                cols = []
                for n in range(N):
                    linear_trans = tf.layers.dense(
                        student_hidden_layers[n],
                        bert_config.hidden_size,
                        kernel_initializer=util.create_initializer(
                            bert_config.initializer_range))
                    mse = tf.losses.mean_squared_error(
                        teacher_hidden_layers[m],
                        linear_trans,
                        weights=tf.reshape(sample_weight, [-1, 1, 1]))
                    col = tf.reshape(mse, [1, 1])
                    cols.append(col)
                row = tf.concat(cols, axis=1)
                rows.append(row)
            distance = tf.concat(rows, axis=0)

            # cost attention mechanism
            teacher_cost = (tf.reduce_sum(flow, axis=1) *
                            tf.reduce_sum(distance, axis=1) /
                            (teacher_weight + 1e-6))
            student_cost = (tf.reduce_sum(flow, axis=0) *
                            tf.reduce_sum(distance, axis=0) /
                            (student_weight + 1e-6))

            # new weights
            new_teacher_weight = tf.where(
                teacher_cost > 1e-12,
                tf.reduce_sum(teacher_cost) / (teacher_cost + 1e-6),
                teacher_weight)
            new_student_weight = tf.where(
                student_cost > 1e-12,
                tf.reduce_sum(student_cost) / (student_cost + 1e-6),
                student_weight)
            new_teacher_weight = tf.nn.softmax(new_teacher_weight /
                                               emd_temporature)
            new_student_weight = tf.nn.softmax(new_student_weight /
                                               emd_temporature)

        self.hidden_flow = flow
        self.hidden_distance = distance
        hidden_emd = tf.reduce_sum(flow * distance)
        return hidden_emd, new_teacher_weight, new_student_weight
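
The weight update inverts each layer's accumulated cost and renormalizes with a temperature-scaled softmax, so layers that are cheap to match receive more mass. A simplified standalone NumPy sketch of that step (the tf.where guard against near-zero costs is omitted; values are illustrative):

import numpy as np

def updated_weights(cost, temperature):
    # Lower matching cost -> larger weight after the softmax.
    inverted = np.sum(cost) / (cost + 1e-6)
    scaled = inverted / temperature
    return np.exp(scaled) / np.sum(np.exp(scaled))

cost = np.array([0.5, 0.1, 0.4])
print(updated_weights(cost, temperature=1.0))  # the 0.1-cost layer dominates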
Example #5
    def _get_attention_emd(self, teacher, student, teacher_weight,
                           student_weight, sample_weight, emd_temporature):
        teacher_attention_scores = teacher.get_attention_scores()
        teacher_attention_scores = [
            tf.stop_gradient(value) for value in teacher_attention_scores
        ]
        student_attention_scores = student.get_attention_scores()
        M = len(teacher_attention_scores)
        N = len(student_attention_scores)

        with tf.variable_scope('attention_emd'):
            flow = tf.get_variable(
                'flow',
                shape=[M, N],
                initializer=tf.constant_initializer(1 / (M * N)),
                trainable=False)

            # MSE
            rows = []
            for m in range(M):
                cols = []
                for n in range(N):
                    teacher_matrix = tf.where(
                        teacher_attention_scores[m] < -1e2,
                        tf.zeros_like(teacher_attention_scores[m]),
                        teacher_attention_scores[m])
                    student_matrix = tf.where(
                        student_attention_scores[n] < -1e2,
                        tf.zeros_like(student_attention_scores[n]),
                        student_attention_scores[n])
                    mse = tf.losses.mean_squared_error(teacher_matrix,
                                                       student_matrix,
                                                       weights=tf.reshape(
                                                           sample_weight,
                                                           [-1, 1, 1, 1]))
                    col = tf.reshape(mse, [1, 1])
                    cols.append(col)
                row = tf.concat(cols, axis=1)
                rows.append(row)
            distance = tf.concat(rows, axis=0)

            # cost attention mechanism
            teacher_cost = (tf.reduce_sum(flow, axis=1) *
                            tf.reduce_sum(distance, axis=1) /
                            (teacher_weight + 1e-6))
            student_cost = (tf.reduce_sum(flow, axis=0) *
                            tf.reduce_sum(distance, axis=0) /
                            (student_weight + 1e-6))

            # new weights
            new_teacher_weight = tf.where(
                teacher_cost > 1e-12,
                tf.reduce_sum(teacher_cost) / (teacher_cost + 1e-6),
                teacher_weight)
            new_student_weight = tf.where(
                student_cost > 1e-12,
                tf.reduce_sum(student_cost) / (student_cost + 1e-6),
                student_weight)
            new_teacher_weight = tf.nn.softmax(new_teacher_weight /
                                               emd_temporature)
            new_student_weight = tf.nn.softmax(new_student_weight /
                                               emd_temporature)

        self.attention_flow = flow
        self.attention_distance = distance
        attention_emd = tf.reduce_sum(flow * distance)
        return attention_emd, new_teacher_weight, new_student_weight
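
The tf.where calls above zero out masked positions before the MSE: BERT-style encoders add a very large negative value (e.g. -10000) to the attention logits of padded tokens, so any score below -1e2 is treated as masking rather than signal. A tiny sketch of that cleanup, with made-up scores:

import tensorflow as tf  # TF 1.x

scores = tf.constant([[2.5, -10000.0, 1.0]])
cleaned = tf.where(scores < -1e2, tf.zeros_like(scores), scores)

with tf.Session() as sess:
    print(sess.run(cleaned))  # [[2.5, 0.0, 1.0]]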
Example #6
    def __init__(self,
                 vocab_size,
                 filter_sizes,
                 num_channels,
                 is_training,
                 input_ids,
                 scope='text_cnn',
                 embedding_size=256,
                 dropout_prob=0.1,
                 trainable=True,
                 **kwargs):

        input_shape = util.get_shape_list(input_ids, expected_rank=2)
        batch_size = input_shape[0]
        max_seq_length = input_shape[1]

        if isinstance(filter_sizes, str):
            filter_sizes = [int(size) for size in filter_sizes.split(',')]
        assert isinstance(filter_sizes, list), (
            '`filter_sizes` should be a list of integers or a string of '
            'integers separated by commas.')

        # Tilda embeddings for SMART algorithm
        tilda_embeddings = None
        use_tilda_embedding = kwargs.get('use_tilda_embedding')
        if use_tilda_embedding:
            with tf.variable_scope('', reuse=True):
                tilda_embeddings = tf.get_variable('tilda_embeddings')

        with tf.variable_scope(scope):
            with tf.variable_scope('embeddings'):

                if tilda_embeddings is not None:
                    embedding_table = tilda_embeddings
                else:
                    embedding_table = tf.get_variable(
                        name='word_embeddings',
                        shape=[vocab_size, embedding_size],
                        initializer=util.create_initializer(0.02),
                        dtype=tf.float32,
                        trainable=trainable)

                flat_input_ids = tf.reshape(input_ids, [-1])
                output = tf.gather(
                    embedding_table, flat_input_ids, name='embedding_look_up')
                output = tf.reshape(
                    output, [batch_size, max_seq_length, embedding_size])

                output_expanded = tf.expand_dims(output, -1)

            # Create a convolution + maxpool layer for each filter size
            pooled_outputs = []
            for i, filter_size in enumerate(filter_sizes):
                with tf.variable_scope('conv_%s' % filter_size):

                    # Convolution Layer
                    filter_shape = [filter_size, embedding_size, 1,
                                    num_channels]
                    W = tf.get_variable(
                        name='W',
                        shape=filter_shape,
                        initializer=tf.truncated_normal_initializer(
                            stddev=0.1),
                        dtype=tf.float32,
                        trainable=trainable)
                    b = tf.get_variable(
                        name='b',
                        shape=[num_channels],
                        initializer=tf.constant_initializer(0.1),
                        dtype=tf.float32,
                        trainable=trainable)
                    conv = tf.nn.conv2d(
                        output_expanded, W,
                        strides=[1, 1, 1, 1],
                        padding='VALID',
                        name='conv')

                    # Apply nonlinearity
                    h = tf.nn.relu(tf.nn.bias_add(conv, b), name='relu')

                    # Maxpooling over the outputs
                    pooled = tf.nn.max_pool(
                        h,
                        ksize=[1, max_seq_length - filter_size + 1, 1, 1],
                        strides=[1, 1, 1, 1],
                        padding='VALID',
                        name='pool')
                    pooled_outputs.append(pooled)

            num_channels_total = num_channels * len(filter_sizes)
            h_pool = tf.concat(pooled_outputs, 3)
            h_pool_flat = tf.reshape(h_pool, [batch_size, num_channels_total])

            with tf.name_scope('dropout'):
                self.pooled_output = util.dropout(h_pool_flat, dropout_prob)
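
As a shape sanity check for one conv + max-pool branch, assuming TF 1.x and illustrative sizes (batch 2, sequence length 10, embedding size 8, filter size 3, 4 channels):

import numpy as np
import tensorflow as tf  # TF 1.x

x = tf.constant(np.random.randn(2, 10, 8, 1), dtype=tf.float32)   # embedded input
W = tf.constant(np.random.randn(3, 8, 1, 4), dtype=tf.float32)    # one filter bank
conv = tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='VALID')  # (2, 8, 1, 4)
pooled = tf.nn.max_pool(conv, ksize=[1, 10 - 3 + 1, 1, 1],
                        strides=[1, 1, 1, 1], padding='VALID')    # (2, 1, 1, 4)

with tf.Session() as sess:
    print(sess.run(tf.shape(pooled)))  # [2 1 1 4]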