Example 1
def _single_seq_fn():
    # Degenerate case: every sequence has length 1, so the sequence score is
    # just the unary score of the single tagged position in each example.
    batch_size = tf.shape(inputs, out_type=tag_indices.dtype)[0]
    example_inds = tf.reshape(
        tf.range(batch_size, dtype=tag_indices.dtype), [-1, 1])
    sequence_scores = tf.gather_nd(
        tf.squeeze(inputs, [1]),
        tf.concat([example_inds, tag_indices], axis=1))
    # Zero out the score of any sequence whose length is <= 0.
    sequence_scores = tf.where(tf.less_equal(sequence_lengths, 0),
                               tf.zeros_like(sequence_scores),
                               sequence_scores)
    return sequence_scores
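For reference, a small NumPy sketch (names and shapes illustrative, not part of the library) of what the gather above computes when every sequence has length 1:

import numpy as np

inputs = np.array([[[1., 2., 3.]],
                   [[4., 5., 6.]]])      # (batch_size, 1, num_tags)
tag_indices = np.array([[2], [0]])       # (batch_size, 1)

# Pick, for each example, the unary score of its single tagged position,
# mirroring the tf.gather_nd over the squeezed inputs above.
scores = inputs[np.arange(inputs.shape[0]), 0, tag_indices[:, 0]]
print(scores)  # [3. 4.]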
Example 2
def mask(inputs, key_masks=None, type=None):
    '''Masks paddings on keys or queries to inputs
    inputs: 3d tensor. (h*N, T_q, T_k)
    key_masks: 2d tensor. (N, T_k)
    type: string. 'key' | 'future'

    e.g.,
    >> inputs = tf.zeros([4, 2, 3], dtype=tf.float32)  # h*N = 4, e.g. h=2 heads, N=2
    >> key_masks = tf.constant([[0., 0., 1.],
                                [0., 1., 1.]])
    >> mask(inputs, key_masks=key_masks, type='key')
    array([[[ 0.0000000e+00,  0.0000000e+00, -4.2949673e+09],
        [ 0.0000000e+00,  0.0000000e+00, -4.2949673e+09]],

       [[ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09],
        [ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09]],

       [[ 0.0000000e+00,  0.0000000e+00, -4.2949673e+09],
        [ 0.0000000e+00,  0.0000000e+00, -4.2949673e+09]],

       [[ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09],
        [ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09]]], dtype=float32)
    '''
    padding_num = -2 ** 32 + 1
    if type in ('k', 'key', 'keys'):
        key_masks = tf.to_float(key_masks)
        key_masks = tf.tile(
            key_masks,
            [tf.shape(inputs)[0] // tf.shape(key_masks)[0], 1]) # (h*N, seqlen)
        key_masks = tf.expand_dims(key_masks, 1)  # (h*N, 1, seqlen)
        outputs = inputs + key_masks * padding_num
    # elif type in ('q', 'query', 'queries'):
    #     # Generate masks
    #     masks = tf.sign(tf.reduce_sum(tf.abs(queries), axis=-1))  # (N, T_q)
    #     masks = tf.expand_dims(masks, -1)  # (N, T_q, 1)
    #     masks = tf.tile(masks, [1, 1, tf.shape(keys)[1]])  # (N, T_q, T_k)
    #
    #     # Apply masks to inputs
    #     outputs = inputs*masks
    elif type in ('f', 'future', 'right'):
        diag_vals = tf.ones_like(inputs[0, :, :])  # (T_q, T_k)
        tril = tf.linalg.LinearOperatorLowerTriangular(
            diag_vals).to_dense()  # (T_q, T_k)
        future_masks = tf.tile(
            tf.expand_dims(tril, 0),
            [tf.shape(inputs)[0], 1, 1])  # (N, T_q, T_k)

        paddings = tf.ones_like(future_masks) * padding_num
        outputs = tf.where(tf.equal(future_masks, 0), paddings, inputs)
    else:
        raise ValueError('Check if you entered type correctly!')

    return outputs
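A minimal usage sketch, assuming TensorFlow 1.x and the mask function above (the values mirror the docstring example):

import tensorflow as tf

inputs = tf.zeros([4, 2, 3], dtype=tf.float32)   # (h*N, T_q, T_k) with h=2 heads, N=2
key_masks = tf.constant([[0., 0., 1.],
                         [0., 1., 1.]])          # (N, T_k); 1 marks a padded key
outputs = mask(inputs, key_masks=key_masks, type='key')

with tf.Session() as sess:
    print(sess.run(outputs))  # padded key positions pushed to roughly -4.29e9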
Example 3
    def _multi_seq_fn():
        '''Forward computation of alpha values.'''
        rest_of_input = tf.slice(inputs, [0, 1, 0], [-1, -1, -1])

        # Compute the alpha values in the forward algorithm in order to get the
        # partition function.
        forward_cell = CrfForwardRnnCell(transition_params)
        # Sequence length is not allowed to be less than zero.
        sequence_lengths_less_one = tf.maximum(
            tf.constant(0, dtype=sequence_lengths.dtype), sequence_lengths - 1)
        _, alphas = rnn.dynamic_rnn(cell=forward_cell,
                                    inputs=rest_of_input,
                                    sequence_length=sequence_lengths_less_one,
                                    initial_state=first_input,
                                    dtype=tf.float32)
        log_norm = tf.reduce_logsumexp(alphas, [1])
        # Mask `log_norm` of the sequences with length <= zero.
        log_norm = tf.where(tf.less_equal(sequence_lengths, 0),
                            tf.zeros_like(log_norm), log_norm)
        return log_norm
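For intuition, a minimal NumPy sketch of the CRF forward recursion that this closure delegates to the RNN cell (array names and shapes are illustrative, not the library's API):

import numpy as np

def crf_log_norm_np(unary, transition):
    # unary: (seq_len, num_tags) emission scores; transition: (num_tags, num_tags)
    alphas = unary[0]                                   # alphas start from the first input
    for t in range(1, unary.shape[0]):
        # alphas[j] = logsumexp_i(alphas[i] + transition[i, j]) + unary[t, j]
        scores = alphas[:, None] + transition + unary[t][None, :]
        alphas = np.logaddexp.reduce(scores, axis=0)
    return np.logaddexp.reduce(alphas)                  # log partition function

print(crf_log_norm_np(np.random.randn(5, 3), np.random.randn(3, 3)))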
Example 4
def positional_encoding(inputs,
                        maxlen,
                        masking=True,
                        scope='positional_encoding'):
    '''Sinusoidal positional encoding. See section 3.5 of "Attention Is All You Need".
    inputs: 3d tensor. (N, T, E)
    maxlen: scalar. Must be >= T
    masking: Boolean. If True, padding positions are set to zeros.
    scope: Optional scope for `variable_scope`.

    returns
    3d tensor that has the same shape as inputs.
    '''

    E = inputs.get_shape().as_list()[-1] # static
    N, T = tf.shape(inputs)[0], tf.shape(inputs)[1] # dynamic
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        # position indices
        position_ind = tf.tile(tf.expand_dims(tf.range(T), 0), [N, 1]) # (N, T)

        # First part of the PE function: sin and cos argument
        position_enc = np.array([
            [pos / np.power(10000, (i-i%2)/E) for i in range(E)]
            for pos in range(maxlen)])

        # Second part: apply sin to even columns and cos to odd columns.
        position_enc[:, 0::2] = np.sin(position_enc[:, 0::2])  # dim 2i
        position_enc[:, 1::2] = np.cos(position_enc[:, 1::2])  # dim 2i+1
        position_enc = tf.convert_to_tensor(
            position_enc, tf.float32) # (maxlen, E)

        # lookup
        outputs = tf.nn.embedding_lookup(position_enc, position_ind)

        # masks
        if masking:
            outputs = tf.where(tf.equal(inputs, 0), inputs, outputs)

        return tf.to_float(outputs)
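A minimal usage sketch, assuming TensorFlow 1.x, numpy imported as np, and the positional_encoding function above:

import numpy as np
import tensorflow as tf

x = tf.ones([2, 5, 8], dtype=tf.float32)   # (N, T, E) token embeddings, no padding
pe = positional_encoding(x, maxlen=10)     # (N, T, E); one sinusoidal encoding per position

with tf.Session() as sess:
    print(sess.run(pe).shape)              # (2, 5, 8)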
Example 5
def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
    '''Sample a permutation of the factorization order and create an
    attention mask accordingly.

    Args:
      inputs: int64 Tensor in shape [batch_size, seq_len], input ids.
      targets: int64 Tensor in shape [batch_size, seq_len], target ids.
      is_masked: bool Tensor in shape [batch_size, seq_len]. True means the
        position is selected for partial prediction.
      perm_size: the length of the longest permutation. Could be set to
        reuse_len. Should not be larger than reuse_len or there will be data
        leaks.
      seq_len: int, sequence length.
    '''
    batch_size = tf.shape(inputs)[0]

    # Generate permutation indices
    index = tf.range(seq_len, dtype=tf.int64)
    index = tf.reshape(index, [-1, perm_size])
    index = tf.transpose(index)
    index = tf.random_shuffle(index)
    index = tf.transpose(index)
    index = tf.reshape(index, [1, -1])
    index = tf.tile(index, [batch_size, 1])

    # `perm_mask` and `target_mask`
    # non-functional tokens
    non_func_tokens = tf.logical_not(
        tf.logical_or(tf.equal(inputs, SEP_ID), tf.equal(inputs, CLS_ID)))

    non_mask_tokens = tf.logical_and(tf.logical_not(is_masked),
                                     non_func_tokens)
    masked_or_func_tokens = tf.logical_not(non_mask_tokens)

    # Set the permutation indices of non-masked (& non-functional) tokens to the
    # smallest index (-1):
    # (1) they can be seen by all other positions
    # (2) they cannot see masked positions, so there won't be information leak
    smallest_index = -tf.ones([batch_size, seq_len], dtype=tf.int64)
    rev_index = tf.where(non_mask_tokens, smallest_index, index)

    # Create `target_mask`: non-functional and masked tokens
    # 1: use mask as input and have loss
    # 0: use token (or [SEP], [CLS]) as input and do not have loss
    target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens)
    target_mask = tf.cast(target_tokens, tf.float32)

    # Create `perm_mask`
    # `target_tokens` cannot see themselves
    self_rev_index = tf.where(target_tokens, rev_index, rev_index + 1)

    # 1: cannot attend if i <= j and j is not non-masked (masked_or_func_tokens)
    # 0: can attend if i > j or j is non-masked
    perm_mask = tf.logical_and(
        self_rev_index[:, :, None] <= rev_index[:, None, :],
        tf.expand_dims(masked_or_func_tokens, axis=-1))

    # new target: [next token] for LM and [curr token] (self) for PLM
    new_targets = tf.concat([inputs[:, 0:1], targets[:, :-1]], axis=1)

    # construct inputs_k
    inputs_k = inputs

    # construct inputs_q
    inputs_q = target_mask

    return perm_mask, new_targets, target_mask, inputs_k, inputs_q
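A minimal usage sketch, assuming TensorFlow 1.x, that SEP_ID and CLS_ID are defined in the surrounding module (the values below are hypothetical), and the _local_perm function above:

import tensorflow as tf

SEP_ID, CLS_ID = 3, 4            # hypothetical special-token ids for this sketch
seq_len, perm_size = 8, 4

inputs = tf.constant([[11, 12, 13, SEP_ID, 14, 15, 16, CLS_ID]], dtype=tf.int64)  # (1, seq_len)
targets = tf.constant([[12, 13, SEP_ID, 14, 15, 16, CLS_ID, 11]], dtype=tf.int64)
is_masked = tf.constant([[False, True, False, False, True, False, False, False]])

perm_mask, new_targets, target_mask, inputs_k, inputs_q = _local_perm(
    inputs, targets, is_masked, perm_size, seq_len)

with tf.Session() as sess:
    print(sess.run(tf.shape(perm_mask)))  # [1, 8, 8]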
Example 6
    def _get_hidden_emd(self, teacher, student, teacher_weight, student_weight,
                        bert_config, sample_weight, emd_temporature):
        teacher_hidden_layers = teacher.all_encoder_layers
        teacher_hidden_layers = [
            tf.stop_gradient(value) for value in teacher_hidden_layers
        ]
        student_hidden_layers = student.all_encoder_layers
        M = len(teacher_hidden_layers)
        N = len(student_hidden_layers)

        with tf.variable_scope('hidden_emd'):
            flow = tf.get_variable('flow',
                                   shape=[M, N],
                                   initializer=tf.constant_initializer(1 / M /
                                                                       N),
                                   trainable=False)

            # MSE distance between every (teacher layer, student layer) pair
            rows = []
            for m in range(M):
                cols = []
                for n in range(N):
                    linear_trans = tf.layers.dense(
                        student_hidden_layers[n],
                        bert_config.hidden_size,
                        kernel_initializer=util.create_initializer(
                            bert_config.initializer_range))
                    mse = tf.losses.mean_squared_error(
                        teacher_hidden_layers[m],
                        linear_trans,
                        weights=tf.reshape(sample_weight, [-1, 1, 1]))
                    col = tf.reshape(mse, [1, 1])
                    cols.append(col)
                row = tf.concat(cols, axis=1)
                rows.append(row)
            distance = tf.concat(rows, axis=0)

            # cost attention mechanism
            teacher_cost = (tf.reduce_sum(flow, axis=1) *
                            tf.reduce_sum(distance, axis=1) /
                            (teacher_weight + 1e-6))
            student_cost = (tf.reduce_sum(flow, axis=0) *
                            tf.reduce_sum(distance, axis=0) /
                            (student_weight + 1e-6))

            # new weights
            new_teacher_weight = tf.where(
                teacher_cost > 1e-12,
                tf.reduce_sum(teacher_cost) / (teacher_cost + 1e-6),
                teacher_weight)
            new_student_weight = tf.where(
                student_cost > 1e-12,
                tf.reduce_sum(student_cost) / (student_cost + 1e-6),
                student_weight)
            new_teacher_weight = tf.nn.softmax(new_teacher_weight /
                                               emd_temporature)
            new_student_weight = tf.nn.softmax(new_student_weight /
                                               emd_temporature)

        self.hidden_flow = flow
        self.hidden_distance = distance
        hidden_emd = tf.reduce_sum(flow * distance)
        return hidden_emd, new_teacher_weight, new_student_weight
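For intuition, a small NumPy sketch of the cost-attention weight update used above (shapes and names are illustrative only):

import numpy as np

M, N = 4, 2                                    # teacher layers x student layers
flow = np.full((M, N), 1.0 / (M * N))          # fixed transport plan
distance = np.random.rand(M, N)                # per-layer-pair MSE matrix
teacher_weight = np.full(M, 1.0 / M)
emd_temperature = 1.0

teacher_cost = flow.sum(axis=1) * distance.sum(axis=1) / (teacher_weight + 1e-6)
new_teacher_weight = np.where(teacher_cost > 1e-12,
                              teacher_cost.sum() / (teacher_cost + 1e-6),
                              teacher_weight)
z = new_teacher_weight / emd_temperature       # softmax with temperature
new_teacher_weight = np.exp(z - z.max()) / np.exp(z - z.max()).sum()

hidden_emd = float((flow * distance).sum())    # scalar distillation loss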
Example 7
    def _get_attention_emd(self, teacher, student, teacher_weight,
                           student_weight, sample_weight, emd_temporature):
        teacher_attention_scores = teacher.get_attention_scores()
        teacher_attention_scores = [
            tf.stop_gradient(value) for value in teacher_attention_scores
        ]
        student_attention_scores = student.get_attention_scores()
        M = len(teacher_attention_scores)
        N = len(student_attention_scores)

        with tf.variable_scope('attention_emd'):
            flow = tf.get_variable('flow',
                                   shape=[M, N],
                                   initializer=tf.constant_initializer(1 / M /
                                                                       N),
                                   trainable=False)

            # MSE distance between every (teacher layer, student layer) pair of attention-score maps
            rows = []
            for m in range(M):
                cols = []
                for n in range(N):
                    teacher_matrix = tf.where(
                        teacher_attention_scores[m] < -1e2,
                        tf.zeros_like(teacher_attention_scores[m]),
                        teacher_attention_scores[m])
                    student_matrix = tf.where(
                        student_attention_scores[n] < -1e2,
                        tf.zeros_like(student_attention_scores[n]),
                        student_attention_scores[n])
                    mse = tf.losses.mean_squared_error(teacher_matrix,
                                                       student_matrix,
                                                       weights=tf.reshape(
                                                           sample_weight,
                                                           [-1, 1, 1, 1]))
                    col = tf.reshape(mse, [1, 1])
                    cols.append(col)
                row = tf.concat(cols, axis=1)
                rows.append(row)
            distance = tf.concat(rows, axis=0)

            # cost attention mechanism
            teacher_cost = (tf.reduce_sum(flow, axis=1) *
                            tf.reduce_sum(distance, axis=1) /
                            (teacher_weight + 1e-6))
            student_cost = (tf.reduce_sum(flow, axis=0) *
                            tf.reduce_sum(distance, axis=0) /
                            (student_weight + 1e-6))

            # new weights
            new_teacher_weight = tf.where(
                teacher_cost > 1e-12,
                tf.reduce_sum(teacher_cost) / (teacher_cost + 1e-6),
                teacher_weight)
            new_student_weight = tf.where(
                student_cost > 1e-12,
                tf.reduce_sum(student_cost) / (student_cost + 1e-6),
                student_weight)
            new_teacher_weight = tf.nn.softmax(new_teacher_weight /
                                               emd_temporature)
            new_student_weight = tf.nn.softmax(new_student_weight /
                                               emd_temporature)

        self.attention_flow = flow
        self.attention_distance = distance
        attention_emd = tf.reduce_sum(flow * distance)
        return attention_emd, new_teacher_weight, new_student_weight
Example 8
def _single_seq_fn():
    # Degenerate case: every sequence has length 1, so the log partition
    # function is just the logsumexp over the tag scores of the first input.
    log_norm = tf.reduce_logsumexp(first_input, [1])
    # Mask `log_norm` of the sequences with length <= zero.
    log_norm = tf.where(tf.less_equal(sequence_lengths, 0),
                        tf.zeros_like(log_norm), log_norm)
    return log_norm