Example #1
                def _forward(dilated_ids, dilated_mask):

                    logits = self._bert_forward(
                        bert_config,
                        dilated_ids,
                        dilated_mask,
                        batch_size,
                        dilated_seq_length,
                        tilda_embeddings=tilda_embeddings)
                    output_ids = tf.argmax(logits, axis=-1)
                    output_ids = tf.cast(output_ids, dtype=tf.int32)

                    # special padding (using `spad` token): append one `spad`
                    # per zero in each row, so that every row ends up with
                    # exactly `dilated_seq_length` non-zero entries
                    equal_zero = tf.cast(tf.equal(output_ids, 0), tf.int32)
                    equal_zero = tf.reduce_sum(equal_zero, axis=-1)
                    right_pad = spad_id * tf.sequence_mask(
                        equal_zero, dilated_seq_length, dtype=tf.int32)
                    padded = tf.concat([output_ids, right_pad], axis=-1)

                    # extract ids of length `max_seq_length`
                    flattened_padded = tf.reshape(padded, [-1])
                    # `tf.boolean_mask` expects a boolean mask, so keep the
                    # bool comparison instead of casting it to int32
                    is_valid = tf.greater(flattened_padded, 0)
                    flattened_valid = tf.boolean_mask(flattened_padded,
                                                      is_valid)
                    valid = tf.reshape(flattened_valid,
                                       [batch_size, dilated_seq_length])
                    truncated_valid = valid[:, :max_seq_length]

                    # replace `spad` token with `pad`
                    non_spad_mask = tf.cast(
                        tf.not_equal(truncated_valid, spad_id), dtype=tf.int32)
                    output_ids = truncated_valid * non_spad_mask
                    output_length = tf.reduce_sum(non_spad_mask, axis=-1)

                    # dilate: interleave every token with a zero so the
                    # result has shape [batch_size, max_seq_length * 2]
                    reshaped_ids = tf.reshape(output_ids,
                                              [batch_size, max_seq_length, 1])
                    reshaped_mask = tf.reshape(
                        tf.sequence_mask(output_length,
                                         max_seq_length,
                                         dtype=tf.int32),
                        [batch_size, max_seq_length, 1])
                    concat_ids = tf.concat(
                        [reshaped_ids, tf.zeros_like(reshaped_ids)], axis=-1)
                    concat_mask = tf.concat(
                        [reshaped_mask, tf.zeros_like(reshaped_mask)], axis=-1)
                    dilated_ids = tf.reshape(concat_ids,
                                             [batch_size, max_seq_length * 2])
                    dilated_mask = tf.reshape(concat_mask,
                                              [batch_size, max_seq_length * 2])

                    return dilated_ids, dilated_mask
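
A minimal driver sketch for the refinement step above, assuming the enclosing decoder runs `_forward` for a fixed number of passes and feeds each pass's output back in (`initial_ids`, `initial_mask`, and `num_passes` are hypothetical names, not part of the snippet):

    # Feed each pass's dilated output back in as the next pass's input.
    dilated_ids, dilated_mask = initial_ids, initial_mask
    for _ in range(num_passes):
        dilated_ids, dilated_mask = _forward(dilated_ids, dilated_mask)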
Example #2
    def _forward(self, is_training, split_placeholders, **kwargs):

        if not is_training:
            return super()._forward(is_training, split_placeholders, **kwargs)

        # `tf.boolean_mask` expects a boolean mask; select the augmented
        # copies of the unsupervised rows, i.e. those with
        # `is_supervised == 0`
        unsupervised_mask = tf.cast(
            1.0 - split_placeholders['is_supervised'], tf.bool)
        aug_input_ids = tf.boolean_mask(
            split_placeholders['aug_input_ids'],
            mask=unsupervised_mask,
            axis=0)
        aug_input_mask = tf.boolean_mask(
            split_placeholders['aug_input_mask'],
            mask=unsupervised_mask,
            axis=0)
        aug_segment_ids = tf.boolean_mask(
            split_placeholders['aug_segment_ids'],
            mask=unsupervised_mask,
            axis=0)
        input_ids = tf.concat([split_placeholders['input_ids'], aug_input_ids],
                              axis=0)
        input_mask = tf.concat(
            [split_placeholders['input_mask'], aug_input_mask], axis=0)
        segment_ids = tf.concat(
            [split_placeholders['segment_ids'], aug_segment_ids], axis=0)
        encoder = BERTEncoder(bert_config=self.bert_config,
                              is_training=is_training,
                              input_ids=input_ids,
                              input_mask=input_mask,
                              segment_ids=segment_ids,
                              scope='bert',
                              drop_pooler=self._drop_pooler,
                              **kwargs)
        encoder_output = encoder.get_pooled_output()

        label_ids = split_placeholders['label_ids']
        # mark which rows of the concatenated batch are augmented copies
        is_expanded = tf.zeros_like(label_ids, dtype=tf.float32)
        batch_size = util.get_shape_list(aug_input_ids)[0]
        aug_is_expanded = tf.ones([batch_size], dtype=tf.float32)
        is_expanded = tf.concat([is_expanded, aug_is_expanded], axis=0)
        decoder = UDADecoder(
            is_training=is_training,
            input_tensor=encoder_output,
            is_supervised=split_placeholders['is_supervised'],
            is_expanded=is_expanded,
            label_ids=label_ids,
            label_size=self.label_size,
            sample_weight=split_placeholders.get('sample_weight'),
            scope='cls/seq_relationship',
            global_step=self._global_step,
            num_train_steps=self.total_steps,
            uda_softmax_temp=self._uda_softmax_temp,
            uda_confidence_thresh=self._uda_confidence_thresh,
            tsa_schedule=self._tsa_schedule,
            **kwargs)
        (total_loss, losses, probs, preds) = decoder.get_forward_outputs()
        return (total_loss, losses, probs, preds)
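
A sketch of the `split_placeholders` layout this forward pass assumes (the keys are taken from the code above; the shapes are illustrative assumptions):

    split_placeholders = {
        'input_ids': ...,      # [batch_size, max_seq_length] int32
        'input_mask': ...,     # [batch_size, max_seq_length] int32
        'segment_ids': ...,    # [batch_size, max_seq_length] int32
        'aug_input_ids': ...,  # augmented copies of the unsupervised rows
        'aug_input_mask': ...,
        'aug_segment_ids': ...,
        'is_supervised': ...,  # [batch_size] float32, 1.0 = labeled row
        'label_ids': ...,      # [batch_size] int32
        'sample_weight': ...,  # optional per-example loss weight
    }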
Example #3
 def _single_seq_fn():
     # Single-step sequences: the sequence score is just the unary score
     # of each example's single tag.
     batch_size = tf.shape(inputs, out_type=tag_indices.dtype)[0]
     example_inds = tf.reshape(
         tf.range(batch_size, dtype=tag_indices.dtype), [-1, 1])
     sequence_scores = tf.gather_nd(
         tf.squeeze(inputs, [1]),
         tf.concat([example_inds, tag_indices], axis=1))
     sequence_scores = tf.where(tf.less_equal(sequence_lengths, 0),
                                tf.zeros_like(sequence_scores),
                                sequence_scores)
     return sequence_scores
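
This is the single-step branch of a CRF sequence-score computation; a sketch of the dispatch that typically surrounds it, assuming a `_multi_seq_fn` branch is defined alongside (modeled on `tf.contrib.crf.crf_sequence_score`, with the original's `smart_cond` simplified to a static-shape check):

    if inputs.shape[1] == 1:
        return _single_seq_fn()
    return _multi_seq_fn()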
Example #4
def causal_numerator(qs, ks, vs):
    '''Computes unnormalized FAVOR causal attention A_{masked}V.

    Args:
      qs: query_prime tensor of the shape [L,B,H,M].
      ks: key_prime tensor of the shape [L,B,H,M].
      vs: value tensor of the shape [L,B,H,D].

    Returns:
      Unnormalized FAVOR causal attention A_{masked}V.
    '''

    result = []
    # running prefix sum of the outer products of ks[index] and vs[index],
    # shape [B,H,M,D]
    sums = tf.zeros_like(tf.einsum('ijk,ijl->ijkl', ks[0], vs[0]))

    for index in range(qs.shape[0]):
        sums = sums + tf.einsum('ijk,ijl->ijkl', ks[index], vs[index])
        result.append(
            tf.einsum('ijkl,ijk->ijl', sums, qs[index])[None, Ellipsis])

    result = tf.concat(result, axis=0)

    def grad(res_grad):

        # custom backward pass: replay the prefix sums in reverse so the
        # full forward history never has to be stored
        grads = tf.zeros_like(tf.einsum('ijk,ijl->ijkl', ks[0], vs[0]))

        gr_sums = sums

        q_grads = []
        k_grads = []
        v_grads = []

        for index in range(qs.shape[0] - 1, -1, -1):

            q_grads.append(
                tf.einsum('ijkl,ijl->ijk', gr_sums, res_grad[index])[None,
                                                                     Ellipsis])
            grads = grads + tf.einsum('ijk,ijl->ijkl', qs[index],
                                      res_grad[index])
            k_grads.append(
                tf.einsum('ijkl,ijl->ijk', grads, vs[index])[None, Ellipsis])
            v_grads.append(
                tf.einsum('ijkl,ijk->ijl', grads, ks[index])[None, Ellipsis])
            gr_sums = gr_sums - tf.einsum('ijk,ijl->ijkl', ks[index],
                                          vs[index])

        q_grads = tf.concat(q_grads[::-1], axis=0)
        k_grads = tf.concat(k_grads[::-1], axis=0)
        v_grads = tf.concat(v_grads[::-1], axis=0)

        return q_grads, k_grads, v_grads

    return result, grad
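
The `(result, grad)` pair returned above matches the contract of `tf.custom_gradient`. A minimal usage sketch, assuming TensorFlow 2.x in eager mode; the wrapper and the toy shapes below are assumptions for illustration, not part of the original source:

    import tensorflow as tf

    causal_num = tf.custom_gradient(causal_numerator)

    # Toy shapes: L=8 positions, B=2 batch, H=4 heads, M=16, D=32.
    qs = tf.random.normal([8, 2, 4, 16])
    ks = tf.random.normal([8, 2, 4, 16])
    vs = tf.random.normal([8, 2, 4, 32])
    av_attention = causal_num(qs, ks, vs)  # shape [8, 2, 4, 32]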
Example #5
                def _forward(dilated_ids, dilated_mask):

                    logits = self._bert_forward(
                        bert_config,
                        dilated_ids,
                        dilated_mask,
                        batch_size,
                        dilated_seq_length,
                        tilda_embeddings=tilda_embeddings)
                    output_ids = tf.argmax(logits, axis=-1)
                    output_ids = tf.cast(output_ids, dtype=tf.int32)

                    # special padding (using `spad` token): append one `spad`
                    # per zero in each row, so that every row ends up with
                    # exactly `dilated_seq_length` non-zero entries
                    equal_zero = tf.cast(tf.equal(output_ids, 0), tf.int32)
                    equal_zero = tf.reduce_sum(equal_zero, axis=-1)
                    right_pad = spad_id * tf.sequence_mask(
                        equal_zero, dilated_seq_length, dtype=tf.int32)

                    # extract ids of length `max_seq_length`
                    padded = tf.concat([output_ids, right_pad], axis=-1)
                    flattened_padded = tf.reshape(padded, [-1])
                    is_valid = tf.greater(flattened_padded, 0)
                    flattened_valid = tf.boolean_mask(flattened_padded,
                                                      is_valid)
                    valid = tf.reshape(flattened_valid,
                                       [batch_size, dilated_seq_length])
                    truncated_valid = valid[:, :max_seq_length]

                    # replace `spad` token with `pad`
                    nonpad_mask = tf.cast(
                        tf.not_equal(truncated_valid, spad_id), dtype=tf.int32)
                    output_ids = truncated_valid * nonpad_mask

                    # dilate: interleave every token with a zero
                    reshaped = tf.reshape(output_ids,
                                          [batch_size, max_seq_length, 1])
                    concatenated = tf.concat(
                        [reshaped, tf.zeros_like(reshaped)], axis=-1)
                    dilated_ids = tf.reshape(concatenated,
                                             [batch_size, max_seq_length * 2])

                    # rebuild the mask from the number of non-pad tokens
                    input_mask = tf.reduce_sum(nonpad_mask, axis=-1)
                    dilated_mask = tf.sequence_mask(input_mask,
                                                    dilated_seq_length,
                                                    dtype=tf.int32)

                    return dilated_ids, dilated_mask
Example #6
def causal_denominator(qs, ks):
    '''Computes FAVOR normalizer in causal attention.

    Args:
      qs: query_prime tensor of the shape [L,B,H,M].
      ks: key_prime tensor of the shape [L,B,H,M].

    Returns:
      FAVOR normalizer in causal attention.
    '''

    result = []
    # running prefix sum of the key features ks[index], shape [B,H,M]
    sums = tf.zeros_like(ks[0])

    for index in range(qs.shape[0]):
        sums = sums + ks[index]
        result.append(tf.reduce_sum(qs[index] * sums, axis=2)[None, Ellipsis])

    result = tf.concat(result, axis=0)

    def grad(res_grad):

        # custom backward pass, mirroring `causal_numerator`: replay the
        # prefix sums in reverse
        k_grad = tf.zeros_like(ks[0])

        gr_sums = sums

        q_grads = []
        k_grads = []

        for index in range(qs.shape[0] - 1, -1, -1):

            q_grads.append(
                tf.einsum('ijk,ij->ijk', gr_sums, res_grad[index])[None,
                                                                   Ellipsis])
            k_grad = k_grad + tf.einsum('ijk,ij->ijk', qs[index],
                                        res_grad[index])
            k_grads.append(k_grad[None, Ellipsis])
            gr_sums = gr_sums - ks[index]

        q_grads = tf.concat(q_grads[::-1], axis=0)
        k_grads = tf.concat(k_grads[::-1], axis=0)

        return q_grads, k_grads

    return result, grad
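
A sketch of how the numerator and denominator combine into normalized causal FAVOR attention, assuming both functions are wrapped with `tf.custom_gradient` as above (the wrapping and the final division follow the Performer reference design and are not shown in this snippet):

    causal_den = tf.custom_gradient(causal_denominator)

    av_attention = causal_num(qs, ks, vs)  # [L, B, H, D]
    normalizer = causal_den(qs, ks)        # [L, B, H]
    attention = av_attention / normalizer[Ellipsis, None]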
Example #7
    def _multi_seq_fn():
        '''Forward computation of alpha values.'''
        # `first_input` (the unary scores of the first time step, computed
        # in the enclosing scope) seeds the recurrence; the remaining time
        # steps are fed to the forward RNN below.
        rest_of_input = tf.slice(inputs, [0, 1, 0], [-1, -1, -1])

        # Compute the alpha values in the forward algorithm in order to get the
        # partition function.
        forward_cell = CrfForwardRnnCell(transition_params)
        # Sequence length is not allowed to be less than zero.
        sequence_lengths_less_one = tf.maximum(
            tf.constant(0, dtype=sequence_lengths.dtype), sequence_lengths - 1)
        _, alphas = rnn.dynamic_rnn(cell=forward_cell,
                                    inputs=rest_of_input,
                                    sequence_length=sequence_lengths_less_one,
                                    initial_state=first_input,
                                    dtype=tf.float32)
        log_norm = tf.reduce_logsumexp(alphas, [1])
        # Mask `log_norm` of the sequences with length <= zero.
        log_norm = tf.where(tf.less_equal(sequence_lengths, 0),
                            tf.zeros_like(log_norm), log_norm)
        return log_norm
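
This is the multi-step branch of a CRF log-normalizer; a sketch of the enclosing context it assumes, modeled on `tf.contrib.crf.crf_log_norm` (the static-shape dispatch is a simplification of the original's `smart_cond`):

    # Unary scores of the first time step seed the alpha recurrence.
    first_input = tf.squeeze(tf.slice(inputs, [0, 0, 0], [-1, 1, -1]), [1])

    if inputs.shape[1] == 1:
        return _single_seq_fn()
    return _multi_seq_fn()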
Example #8
    def _get_attention_emd(self, teacher, student, teacher_weight,
                           student_weight, sample_weight, emd_temperature):
        teacher_attention_scores = teacher.get_attention_scores()
        teacher_attention_scores = [
            tf.stop_gradient(value) for value in teacher_attention_scores
        ]
        student_attention_scores = student.get_attention_scores()
        M = len(teacher_attention_scores)
        N = len(student_attention_scores)

        with tf.variable_scope('attention_emd'):
            flow = tf.get_variable(
                'flow',
                shape=[M, N],
                initializer=tf.constant_initializer(1.0 / (M * N)),
                trainable=False)

            # MSE distance between every (teacher, student) layer pair
            rows = []
            for m in range(M):
                cols = []
                for n in range(N):
                    # Zero out the large negative values that mask padded
                    # positions before comparing attention maps.
                    teacher_matrix = tf.where(
                        teacher_attention_scores[m] < -1e2,
                        tf.zeros_like(teacher_attention_scores[m]),
                        teacher_attention_scores[m])
                    student_matrix = tf.where(
                        student_attention_scores[n] < -1e2,
                        tf.zeros_like(student_attention_scores[n]),
                        student_attention_scores[n])
                    mse = tf.losses.mean_squared_error(
                        teacher_matrix,
                        student_matrix,
                        weights=tf.reshape(sample_weight, [-1, 1, 1, 1]))
                    cols.append(tf.reshape(mse, [1, 1]))
                rows.append(tf.concat(cols, axis=1))
            distance = tf.concat(rows, axis=0)

            # cost attention mechanism
            teacher_cost = (tf.reduce_sum(flow, axis=1) *
                            tf.reduce_sum(distance, axis=1) /
                            (teacher_weight + 1e-6))
            student_cost = (tf.reduce_sum(flow, axis=0) *
                            tf.reduce_sum(distance, axis=0) /
                            (student_weight + 1e-6))

            # new weights
            new_teacher_weight = tf.where(
                teacher_cost > 1e-12,
                tf.reduce_sum(teacher_cost) / (teacher_cost + 1e-6),
                teacher_weight)
            new_student_weight = tf.where(
                student_cost > 1e-12,
                tf.reduce_sum(student_cost) / (student_cost + 1e-6),
                student_weight)
            new_teacher_weight = tf.nn.softmax(new_teacher_weight /
                                               emd_temperature)
            new_student_weight = tf.nn.softmax(new_student_weight /
                                               emd_temperature)

        self.attention_flow = flow
        self.attention_distance = distance
        attention_emd = tf.reduce_sum(flow * distance)
        return attention_emd, new_teacher_weight, new_student_weight
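
A sketch of how this layer-matching step would typically be driven, assuming the per-layer weight vectors start uniform and the returned weights are fed back in on the next call (the encoders, the initialization, and the temperature value below are assumptions):

    # M teacher layers, N student layers; weights start uniform.
    teacher_weight = tf.ones([M], dtype=tf.float32) / M
    student_weight = tf.ones([N], dtype=tf.float32) / N
    attention_emd, teacher_weight, student_weight = self._get_attention_emd(
        teacher, student, teacher_weight, student_weight,
        sample_weight, emd_temperature=1.0)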
Example #9
 def _single_seq_fn():
     # Single-step sequences: the log-normalizer is the logsumexp of the
     # unary scores of the only time step.
     log_norm = tf.reduce_logsumexp(first_input, [1])
     # Mask `log_norm` of the sequences with length <= zero.
     log_norm = tf.where(tf.less_equal(sequence_lengths, 0),
                         tf.zeros_like(log_norm), log_norm)
     return log_norm