Ejemplo n.º 1
0
    def __init__(self,
                 is_training,
                 input_tensor,
                 label_ids,
                 sample_weight=None,
                 scope='mrc',
                 name='',
                 hidden_dropout_prob=0.1,
                 initializer_range=0.02,
                 trainable=True,
                 **kwargs):
        '''Builds a span-extraction head producing start/end position scores.

        Args:
          is_training: bool. Dropout is only applied when True.
          input_tensor: tensor whose last two static dimensions are read as
            (seq_length, hidden_size).
          label_ids: tensor sliced as `[:, 0]` (start index) and `[:, 1]`
            (end index).
          sample_weight: optional tensor multiplied into the per-example
            loss.
          scope: variable scope that holds the projection variables.
          name: key under which probabilities, losses and predictions are
            stored on `self.probs`, `self.losses` and `self.preds`
            (presumably dicts set up by the parent class -- confirm).
          hidden_dropout_prob: dropout rate used while training.
          initializer_range: value handed to `util.create_initializer`.
          trainable: whether the projection variables are trainable.
          **kwargs: forwarded to the parent constructor.
        '''
        super().__init__(**kwargs)

        static_shape = input_tensor.shape.as_list()
        seq_length = static_shape[-2]
        hidden_size = static_shape[-1]

        with tf.variable_scope(scope):
            projection = tf.get_variable(
                'output_weights',
                shape=[2, hidden_size],
                initializer=util.create_initializer(initializer_range),
                trainable=trainable)
            projection_bias = tf.get_variable(
                'output_bias',
                shape=[2],
                initializer=tf.zeros_initializer(),
                trainable=trainable)

            dropout_rate = hidden_dropout_prob if is_training else 0.0
            hidden = util.dropout(input_tensor, dropout_rate)

            # Score every token twice, then lay the scores out as
            # [batch, 2, seq_length]: row 0 is "start", row 1 is "end".
            hidden = tf.reshape(hidden, [-1, hidden_size])
            logits = tf.matmul(hidden, projection, transpose_b=True)
            logits = tf.nn.bias_add(logits, projection_bias)
            logits = tf.reshape(logits, [-1, seq_length, 2])
            logits = tf.transpose(logits, [0, 2, 1])
            self.probs[name] = tf.nn.softmax(logits, axis=-1, name='probs')

            # Cross entropy against the gold start / end positions; each
            # side contributes half of the per-example loss.
            start_targets = tf.one_hot(
                label_ids[:, 0], depth=seq_length, dtype=tf.float32)
            end_targets = tf.one_hot(
                label_ids[:, 1], depth=seq_length, dtype=tf.float32)
            start_log_probs = tf.nn.log_softmax(logits[:, 0, :], axis=-1)
            end_log_probs = tf.nn.log_softmax(logits[:, 1, :], axis=-1)
            start_term = -0.5 * tf.reduce_sum(
                start_targets * start_log_probs, axis=-1)
            end_term = 0.5 * tf.reduce_sum(
                end_targets * end_log_probs, axis=-1)
            per_example_loss = start_term - end_term
            if sample_weight is not None:
                per_example_loss *= sample_weight

            self.total_loss = tf.reduce_mean(per_example_loss)
            self.losses[name] = per_example_loss

            # Predictions are the arg-max start and end positions, packed
            # side by side along the last axis.
            best_start = tf.argmax(logits[:, 0, :], axis=-1)
            start_preds = tf.expand_dims(best_start, axis=-1)
            best_end = tf.argmax(logits[:, 1, :], axis=-1)
            end_preds = tf.expand_dims(best_end, axis=-1)
            self.preds[name] = tf.concat([start_preds, end_preds], axis=-1)
Ejemplo n.º 2
0
def get_timing_signal_1d_given_position(channels,
                                        position,
                                        min_timescale=1.0,
                                        max_timescale=1.0e4):
    """Get sinusoids of different frequencies, with timing position given.

    Adapted from add_timing_signal_1d_given_position in
    //third_party/py/tensor2tensor/layers/common_attention.py

    Args:
      channels: scalar, size of timing embeddings to create. The number of
        different timescales is equal to channels // 2.
      position: a Tensor with shape [batch, seq_len]
      min_timescale: a float
      max_timescale: a float

    Returns:
      a Tensor of timing signals [batch, seq_len, channels]
    """
    num_timescales = channels // 2
    # With a single timescale (channels of 2 or 3) `num_timescales - 1` is
    # zero; clamp the denominator to 1 so the increment stays finite instead
    # of turning the timescales into NaN.
    log_timescale_increment = (
        math.log(float(max_timescale) / float(min_timescale)) /
        tf.maximum(tf.to_float(num_timescales) - 1, 1))
    inv_timescales = min_timescale * tf.exp(
        tf.to_float(tf.range(num_timescales)) * -log_timescale_increment)
    # [batch, seq_len, 1] * [1, 1, num_timescales]
    scaled_time = (tf.expand_dims(tf.to_float(position), 2) *
                   tf.expand_dims(tf.expand_dims(inv_timescales, 0), 0))
    signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=2)
    # An odd `channels` leaves one trailing channel, which is zero-padded.
    signal = tf.pad(signal, [[0, 0], [0, 0], [0, tf.mod(channels, 2)]])
    return signal
Ejemplo n.º 3
0
def gather_positions(sequence, positions):
    '''Picks, for every batch row, the entries found at the given indices.

    Args:
      sequence: A [batch_size, seq_length] or
          [batch_size, seq_length, depth] tensor of values
      positions: A [batch_size, n_positions] tensor of indices

    Returns: A [batch_size, n_positions] or
      [batch_size, n_positions, depth] tensor of the values at the indices
    '''
    dims = util.get_shape_list(sequence, expected_rank=[2, 3])
    has_depth = len(dims) == 3
    if has_depth:
        batch, length, depth = dims
    else:
        # Give rank-2 input a unit depth so both cases share one code path.
        batch, length = dims
        depth = 1
        sequence = tf.expand_dims(sequence, -1)

    # Shift each row's indices by that row's start in the flattened layout.
    row_offsets = tf.expand_dims(length * tf.range(batch), -1)
    flat_index = tf.reshape(positions + row_offsets, [-1])
    flat_rows = tf.reshape(sequence, [batch * length, depth])
    picked = tf.gather(flat_rows, flat_index)

    out_shape = [batch, -1, depth] if has_depth else [batch, -1]
    return tf.reshape(picked, out_shape)
Ejemplo n.º 4
0
def crf_unary_score(tag_indices, sequence_lengths, inputs):
    ''' Sums the unary potentials selected by each tag sequence.

    Args:
      tag_indices: A [batch_size, max_seq_len] matrix of tag indices.
      sequence_lengths: A [batch_size] vector of true sequence lengths.
      inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials.

    Returns:
      unary_scores: A [batch_size] vector of unary scores.
    '''
    batch_size = tf.shape(inputs)[0]
    max_seq_len = tf.shape(inputs)[1]
    num_tags = tf.shape(inputs)[2]

    flat_potentials = tf.reshape(inputs, [-1])

    # Index of (batch b, step t, tag 0) inside the flattened potentials.
    batch_offsets = tf.expand_dims(
        tf.range(batch_size) * max_seq_len * num_tags, 1)
    step_offsets = tf.expand_dims(tf.range(max_seq_len) * num_tags, 0)
    offsets = batch_offsets + step_offsets
    # Offsets must share tag_indices' integer dtype before being added.
    if tag_indices.dtype == tf.int64:
        offsets = tf.cast(offsets, tf.int64)
    flat_index = tf.reshape(offsets + tag_indices, [-1])

    picked = tf.gather(flat_potentials, flat_index)
    unary_scores = tf.reshape(picked, [batch_size, max_seq_len])

    # Zero out the steps that lie beyond each sequence's true length.
    valid_steps = tf.sequence_mask(sequence_lengths,
                                   maxlen=tf.shape(tag_indices)[1],
                                   dtype=tf.float32)

    return tf.reduce_sum(unary_scores * valid_steps, 1)
Ejemplo n.º 5
0
def scatter_update(sequence, updates, positions):
    '''Scatter-update a sequence.

    Args:
      sequence: A [batch_size, seq_len] or [batch_size, seq_len, depth] tensor.
        Its dtype must be tf.float32 or tf.int32.
      updates: A tensor that is reshaped to [-1, depth] and scattered with the
        batch_size*n_positions flattened positions, so it presumably holds
        batch_size*n_positions(*depth) values -- TODO confirm with callers.
      positions: A [batch_size, n_positions] tensor

    Returns: A tuple of two tensors. First is a [batch_size, seq_len] or
      [batch_size, seq_len, depth] tensor of 'sequence' with elements at
      'positions' replaced by the values at 'updates.' Updates to index 0 are
      ignored. Updates that land on the same position are summed by the
      scatter and then divided by the number of hits, so duplicates are
      averaged rather than accumulated. Second is a [batch_size, seq_len]
      int32 mask tensor (values 0 or 1) of which inputs were updated.
    '''
    shape = util.get_shape_list(sequence, expected_rank=[2, 3])
    depth_dimension = (len(shape) == 3)
    if depth_dimension:
        B, L, D = shape
    else:
        # Rank-2 input gets a unit depth axis; it is squeezed away at the end.
        B, L = shape
        D = 1
        sequence = tf.expand_dims(sequence, -1)
    N = util.get_shape_list(positions)[1]

    # Offset each row's positions by the row's start in a flat [B * L] layout.
    shift = tf.expand_dims(L * tf.range(B), -1)
    flat_positions = tf.reshape(positions + shift, [-1, 1])
    flat_updates = tf.reshape(updates, [-1, D])
    updates = tf.scatter_nd(flat_positions, flat_updates, [B * L, D])
    updates = tf.reshape(updates, [B, L, D])

    # Scattering ones counts how many updates landed on every position.
    flat_updates_mask = tf.ones([B * N], tf.int32)
    updates_mask = tf.scatter_nd(flat_positions, flat_updates_mask, [B * L])
    updates_mask = tf.reshape(updates_mask, [B, L])
    # Zero the count at index 0 of every row so that position is never updated.
    not_first_token = tf.concat(
        [tf.zeros((B, 1), tf.int32),
         tf.ones((B, L - 1), tf.int32)], -1)
    updates_mask *= not_first_token
    updates_mask_3d = tf.expand_dims(updates_mask, -1)

    # account for duplicate positions: divide the summed updates by the hit
    # count (at least 1, to avoid dividing by zero where nothing was written)
    if sequence.dtype == tf.float32:
        updates_mask_3d = tf.cast(updates_mask_3d, tf.float32)
        updates /= tf.maximum(1.0, updates_mask_3d)
    else:
        assert sequence.dtype == tf.int32
        updates = tf.divide(updates, tf.maximum(1, updates_mask_3d))
        updates = tf.cast(updates, tf.int32)
    # Collapse the hit counts into 0/1 indicators.
    updates_mask = tf.minimum(updates_mask, 1)
    updates_mask_3d = tf.minimum(updates_mask_3d, 1)

    # Keep the original value where the mask is 0, take the update where 1.
    updated_sequence = (((1 - updates_mask_3d) * sequence) +
                        (updates_mask_3d * updates))
    if not depth_dimension:
        updated_sequence = tf.squeeze(updated_sequence, -1)

    return updated_sequence, updates_mask
Ejemplo n.º 6
0
def mask(inputs, key_masks=None, type=None):
    '''Masks attention scores at padded keys or at future positions.

    Masked entries are pushed to a large negative value (-2**32 + 1) so that
    they vanish after a softmax.

    Args:
      inputs: 3d tensor of attention scores. (h*N, T_q, T_k)
      key_masks: 2d tensor. (N, T_k). Only used for key masking. It is
        multiplied by the padding value and added to the scores, so a 1 marks
        a key to hide and a 0 a key to keep. It is tiled along the first
        axis to cover the h*N rows of `inputs`.
      type: string. 'k' | 'key' | 'keys' for key masking,
        'f' | 'future' | 'right' for future (causal) masking.

    Returns:
      A tensor with the same shape as `inputs`.

    Raises:
      ValueError: if `type` is not one of the supported values, or if key
        masking is requested without `key_masks`.

    e.g.,
    >> inputs = tf.zeros([4, 2, 3], dtype=tf.float32)
    >> key_masks = tf.constant([[0., 0., 1.],
                                [0., 1., 1.]])
    >> mask(inputs, key_masks=key_masks, type='key')
    array([[[ 0.0000000e+00,  0.0000000e+00, -4.2949673e+09],
        [ 0.0000000e+00,  0.0000000e+00, -4.2949673e+09]],

       [[ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09],
        [ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09]],

       [[ 0.0000000e+00,  0.0000000e+00, -4.2949673e+09],
        [ 0.0000000e+00,  0.0000000e+00, -4.2949673e+09]],

       [[ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09],
        [ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09]]], dtype=float32)
    '''
    padding_num = -2 ** 32 + 1

    if type in ('k', 'key', 'keys'):
        if key_masks is None:
            raise ValueError('key_masks must be provided for key masking.')
        key_masks = tf.to_float(key_masks)
        key_masks = tf.tile(
            key_masks,
            [tf.shape(inputs)[0] // tf.shape(key_masks)[0], 1])  # (h*N, T_k)
        key_masks = tf.expand_dims(key_masks, 1)  # (h*N, 1, T_k)
        return inputs + key_masks * padding_num

    if type in ('f', 'future', 'right'):
        diag_vals = tf.ones_like(inputs[0, :, :])  # (T_q, T_k)
        tril = tf.linalg.LinearOperatorLowerTriangular(
            diag_vals).to_dense()  # (T_q, T_k)
        future_masks = tf.tile(
            tf.expand_dims(tril, 0),
            [tf.shape(inputs)[0], 1, 1])  # (h*N, T_q, T_k)

        # Entries above the diagonal (keys later than the query) are hidden.
        paddings = tf.ones_like(future_masks) * padding_num
        return tf.where(tf.equal(future_masks, 0), paddings, inputs)

    # Previously this only printed a hint and then crashed on an unbound
    # local; fail loudly with a message that names the offending value.
    raise ValueError(
        'Unsupported mask type %r; expected one of '
        "'k', 'key', 'keys', 'f', 'future', 'right'." % (type,))
Ejemplo n.º 7
0
    def __init__(self,
                 xlnet_config,
                 is_training,
                 input_ids,
                 seg_ids,
                 input_mask,
                 mems,
                 perm_mask,
                 target,
                 target_mask,
                 target_mapping,
                 inp_q,
                 sample_weight=None,
                 **kwargs):
        '''Wires an XLNet encoder to a language-modelling loss.

        The encoder is built from the given inputs, `lm_loss` is evaluated on
        its sequence output with weights tied to the embedding table, and the
        loss is averaged over the entries selected by `target_mask`.

        Args:
          xlnet_config: model configuration; `n_token` and `d_model` are read
            from it.
          is_training: bool. Enables dropout of 0.1 in the run config.
          input_ids, seg_ids, input_mask, mems, perm_mask, target_mapping,
            inp_q: forwarded unchanged to `XLNetEncoder`.
          target: prediction targets passed to `lm_loss`.
          target_mask: multiplied into the loss; its sum is the normaliser of
            the total loss.
          sample_weight: optional weights; cast to float32 and given a
            trailing axis before being multiplied into the loss.
          **kwargs: forwarded to `XLNetEncoder`.
        '''
        super().__init__()

        drop_rate = 0.1 if is_training else 0.0
        run_config = XLNetRunConfig(is_training=is_training,
                                    bi_data=True,
                                    use_tpu=False,
                                    use_bfloat16=False,
                                    dropout=drop_rate,
                                    dropatt=drop_rate,
                                    init='normal',
                                    init_range=0.1,
                                    init_std=0.02,
                                    clamp_len=-1)

        encoder = XLNetEncoder(xlnet_config=xlnet_config,
                               is_training=is_training,
                               input_ids=input_ids,
                               seg_ids=seg_ids,
                               input_mask=input_mask,
                               mems=mems,
                               perm_mask=perm_mask,
                               target_mapping=target_mapping,
                               inp_q=inp_q,
                               **kwargs)

        with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
            per_example_loss, preds = lm_loss(
                hidden=encoder.get_sequence_output(),
                target=target,
                n_token=xlnet_config.n_token,
                d_model=xlnet_config.d_model,
                initializer=encoder.get_initializer(),
                lookup_table=encoder.get_embedding_table(),
                tie_weight=True,
                bi_data=run_config.bi_data,
                use_tpu=run_config.use_tpu)
            if sample_weight is not None:
                weights = tf.cast(sample_weight, dtype=tf.float32)
                sample_weight = tf.expand_dims(weights, axis=-1)
                per_example_loss *= sample_weight

        # Average only over the entries that target_mask selects.
        masked_loss_sum = tf.reduce_sum(per_example_loss * target_mask)
        self.total_loss = masked_loss_sum / tf.reduce_sum(target_mask)
        self.losses['losses'] = per_example_loss * target_mask
        self.preds['preds'] = preds
        self.preds['mask'] = target_mask
Ejemplo n.º 8
0
    def embedding_lookup(self,
                         input_ids,
                         vocab_size,
                         batch_size,
                         max_seq_length,
                         embedding_size=128,
                         initializer_range=0.02,
                         word_embedding_name='word_embeddings',
                         dtype=tf.float32,
                         trainable=True,
                         tilda_embeddings=None):
        '''Gathers embedding vectors for the given ids.

        Args:
          input_ids: id tensor; a rank-2 input receives a trailing unit axis
            before being flattened.
          vocab_size: number of rows of a newly created embedding table.
          batch_size: first dimension of the returned tensor.
          max_seq_length: second dimension of the returned tensor.
          embedding_size: width of the embedding vectors.
          initializer_range: value handed to `util.create_initializer`.
          word_embedding_name: variable name of a newly created table.
          dtype: dtype of a newly created table.
          trainable: whether a newly created table is trainable.
          tilda_embeddings: optional existing table used in place of a new
            variable.

        Returns:
          A tuple (output, embedding_table) where output has shape
          [batch_size, max_seq_length, embedding_size].
        '''
        if input_ids.shape.ndims == 2:
            input_ids = tf.expand_dims(input_ids, axis=[-1])

        if tilda_embeddings is None:
            embedding_table = tf.get_variable(
                name=word_embedding_name,
                shape=[vocab_size, embedding_size],
                initializer=util.create_initializer(initializer_range),
                dtype=dtype,
                trainable=trainable)
        else:
            embedding_table = tilda_embeddings

        token_ids = tf.reshape(input_ids, [-1])
        vectors = tf.gather(embedding_table,
                            token_ids,
                            name='embedding_look_up')
        output = tf.reshape(vectors,
                            [batch_size, max_seq_length, embedding_size])

        return (output, embedding_table)
Ejemplo n.º 9
0
    def __call__(self, inputs, state, scope=None):
        ''' Runs one step of the CRF forward recursion.

      Args:
          inputs: A [batch_size, num_tags] matrix of unary potentials.
          state: A [batch_size, num_tags] matrix containing the previous alpha
            values.
          scope: Unused variable scope of this cell.

      Returns:
          new_alphas, new_alphas: A pair of [batch_size, num_tags] matrices
            values containing the new alpha values.
      '''
        # [batch_size, num_tags, 1], ready to broadcast against the
        # [1, num_tags, num_tags] transition parameters.
        prev_alphas = tf.expand_dims(state, 2)

        # Adding in log space corresponds to multiplying the previous alphas
        # by the binary potentials; the log-sum-exp over axis 1 then sums out
        # the previous tag.
        transition_scores = prev_alphas + self._transition_params
        summed_over_prev = tf.reduce_logsumexp(transition_scores, [1])
        new_alphas = inputs + summed_over_prev

        # The alphas serve as both output and next state. The output is only
        # there to satisfy the RNN cell API.
        return new_alphas, new_alphas
Ejemplo n.º 10
0
    def __init__(self, transition_params):
        '''Stores the transition parameters used by the forward recursion.

    Args:
      transition_params: A [num_tags, num_tags] matrix of binary potentials.
          A leading axis is added, giving [1, num_tags, num_tags], so that
          it broadcasts over the batch inside the cell.
    '''
        batched_params = tf.expand_dims(transition_params, 0)
        self._transition_params = batched_params

        param_dims = util.get_shape_list(transition_params)
        self._num_tags = param_dims[0]
Ejemplo n.º 11
0
def embedding_lookup(input_ids,
                     vocab_size,
                     embedding_size=128,
                     initializer_range=0.02,
                     word_embedding_name='word_embeddings',
                     use_one_hot_embeddings=False):
    '''Looks up words embeddings for id tensor.

    Args:
      input_ids: Tensor of word ids of shape [batch_size, seq_length], or a
        rank-3 tensor that is multiplied directly with the embedding table
        (presumably per-token weights over the vocabulary, i.e. of shape
        [batch_size, seq_length, vocab_size] -- confirm with callers).
      vocab_size: int. Size of the embedding vocabulary.
      embedding_size: int. Width of the word embeddings.
      initializer_range: float. Embedding initialization range.
      word_embedding_name: string. Name of the embedding table.
      use_one_hot_embeddings: bool. If True, use one-hot method for word
        embeddings. If False, use `tf.nn.embedding_lookup()`. One hot is
        better for TPUs. Only consulted for id inputs.

    Returns:
      A tuple (output, embedding_table) where output is a float Tensor of
      shape [batch_size, seq_length, embedding_size].
    '''
    original_dims = input_ids.shape.ndims
    if original_dims == 2:
        # Id input: add a trailing unit axis, [batch_size, seq_length, 1].
        input_ids = tf.expand_dims(input_ids, axis=[-1])

    embedding_table = tf.get_variable(
        name=word_embedding_name,
        shape=[vocab_size, embedding_size],
        initializer=util.create_initializer(initializer_range))

    if original_dims == 3:
        input_shape = util.get_shape_list(input_ids)
        # Flatten to rank 2 before the matmul. The reshaped tensor used to
        # be discarded, leaving a rank-3 by rank-2 product.
        flat_inputs = tf.reshape(input_ids, [-1, input_shape[-1]])
        output = tf.matmul(flat_inputs, embedding_table)
        output = tf.reshape(output,
                            [input_shape[0], input_shape[1], embedding_size])
    else:
        if use_one_hot_embeddings:
            flat_input_ids = tf.reshape(input_ids, [-1])
            one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size)
            output = tf.matmul(one_hot_input_ids, embedding_table)
        else:
            output = tf.nn.embedding_lookup(embedding_table, input_ids)

        input_shape = util.get_shape_list(input_ids)

        # Fold the trailing id axis into the embedding axis.
        output = tf.reshape(
            output, input_shape[0:-1] + [input_shape[-1] * embedding_size])
    return output, embedding_table
Ejemplo n.º 12
0
def scaled_dot_product_attention(Q, K, V, key_masks,
                                 causality=False, dropout_rate=0.,
                                 training=True,
                                 scope='scaled_dot_product_attention'):
    '''Scaled dot-product attention. See 3.2.1.
    Q: Packed queries. 3d tensor. [N, T_q, d_k].
    K: Packed keys. 3d tensor. [N, T_k, d_k].
    V: Packed values. 3d tensor. [N, T_k, d_v].
    key_masks: A 2d tensor with shape of [N, key_seqlen]
    causality: If True, applies masking for future blinding
    dropout_rate: A floating point number of [0, 1].
    training: boolean for controlling dropout
    scope: Optional scope for `variable_scope`.

    Returns the attended values, a 3d tensor. [N, T_q, d_v].
    '''
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        d_k = Q.get_shape().as_list()[-1]

        # Raw similarity of every query with every key.
        keys_t = tf.transpose(K, [0, 2, 1])
        scores = tf.matmul(Q, keys_t)  # (N, T_q, T_k)

        # Scale by sqrt(d_k).
        scores = scores / (d_k ** 0.5)

        # Hide padded keys, and future keys when attention is causal.
        scores = mask(scores, key_masks=key_masks, type='key')
        if causality:
            scores = mask(scores, type='future')

        weights = tf.nn.softmax(scores)

        # Log the first attention map, transposed, as an image summary.
        attention = tf.transpose(weights, [0, 2, 1])
        tf.summary.image('attention', tf.expand_dims(attention[:1], -1))

        weights = tf.layers.dropout(
            weights, rate=dropout_rate, training=training)

        # Weighted sum of the values (context vectors).
        context = tf.matmul(weights, V)  # (N, T_q, d_v)

    return context
Ejemplo n.º 13
0
def softmax_kernel_transformation(data,
                                  is_query,
                                  projection_matrix=None,
                                  numerical_stabilizer=0.000001):
    '''Computes random features for the softmax kernel using FAVOR+ mechanism.

    Computes random features for the softmax kernel using FAVOR+ mechanism from
    https://arxiv.org/pdf/2009.14794.pdf.

    Args:
      data: input data tensor of the shape [B, L, H, D], where: B - batch
        dimension, L - attention dimensions, H - heads, D - features.
      is_query: indicates whether input data is a query or key tensor.
      projection_matrix: random Gaussian matrix of shape [M, D], where M stands
        for the number of random features and each D x D sub-block has pairwise
        orthogonal rows. Required despite the `None` default.
      numerical_stabilizer: small positive constant for numerical stability.

    Returns:
      Corresponding kernel feature map.

    Raises:
      ValueError: if `projection_matrix` is None.
    '''
    if projection_matrix is None:
        # The projection below cannot run without it; fail with a clear
        # message instead of an obscure error from inside einsum.
        raise ValueError(
            'projection_matrix must be provided for the softmax kernel.')

    # Scale the data by D ** -0.25 once, so that the projected term and the
    # squared-norm term below both see the same normalised data. Previously
    # only the squared-norm term was normalised.
    data_normalizer = tf.math.rsqrt(
        tf.math.sqrt(tf.cast(data.shape[-1], tf.float32)))
    data = data_normalizer * data

    # 1 / sqrt(M)
    ratio = tf.math.rsqrt(tf.cast(projection_matrix.shape[0], tf.float32))

    data_dash = tf.einsum('blhd,md->blhm', data, projection_matrix)

    # Half the squared norm of each (normalised) feature vector, kept
    # broadcastable against data_dash.
    last_axis = tf.keras.backend.ndim(data) - 1
    diag_data = tf.math.reduce_sum(tf.math.square(data), axis=last_axis)
    diag_data = diag_data / 2.0
    diag_data = tf.expand_dims(diag_data, axis=last_axis)

    # Subtract a maximum before exponentiating for numerical stability:
    # along the feature axis for queries, over the whole tensor for keys.
    if is_query:
        last_dims_t = (len(data_dash.shape) - 1, )
        shift = tf.math.reduce_max(
            data_dash, axis=last_dims_t, keepdims=True)
    else:
        shift = tf.math.reduce_max(data_dash)

    return ratio * (tf.math.exp(data_dash - diag_data - shift) +
                    numerical_stabilizer)
Ejemplo n.º 14
0
    def _cls_forward(self,
                     is_training,
                     input_tensor,
                     input_mask,
                     label_ids,
                     bert_config,
                     batch_size,
                     max_seq_length,
                     prob,
                     scope,
                     name,
                     sample_weight=None,
                     hidden_dropout_prob=0.1,
                     initializer_range=0.02):
        '''Adds a two-class token classification head and its loss.

        A dense layer maps `input_tensor` to two logits per position. The
        per-token cross entropy is averaged over the positions selected by
        `input_mask`, giving one loss value per example.

        Args:
          is_training: unused here.
          input_tensor: features fed to the dense layer.
          input_mask: cast to float32; selects and normalises the token
            losses along the last axis.
          label_ids: class ids, one-hot encoded with depth 2.
          bert_config: config whose `initializer_range` seeds the dense
            kernel initializer.
          batch_size: unused here.
          max_seq_length: unused here.
          prob: when non-zero, the mean loss is added to `self.total_loss`.
          scope: variable scope for the dense layer.
          name: prefix of the keys written to `self.losses` / `self.preds`.
          sample_weight: optional weights multiplied into the per-example
            loss; assumed to hold one weight per example -- confirm with
            callers.
          hidden_dropout_prob: unused here.
          initializer_range: unused here (`bert_config.initializer_range`
            is used instead).
        '''
        with tf.variable_scope(scope):
            logits = tf.layers.dense(
                input_tensor,
                2,
                kernel_initializer=util.create_initializer(
                    bert_config.initializer_range),
                trainable=True)

            # loss
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            one_hot_labels = tf.one_hot(label_ids, depth=2)
            per_token_loss = -tf.reduce_sum(one_hot_labels * log_probs,
                                            axis=-1)

            # Mean over the unmasked tokens of each example.
            input_mask = tf.cast(input_mask, tf.float32)
            per_token_loss *= input_mask / tf.reduce_sum(
                input_mask, keepdims=True, axis=-1)
            per_example_loss = tf.reduce_sum(per_token_loss, axis=-1)
            if sample_weight is not None:
                # The token axis is already reduced away, so the weights
                # must keep the loss's rank. Adding a trailing axis here
                # would broadcast [batch] * [batch, 1] into [batch, batch].
                per_example_loss *= tf.cast(sample_weight, dtype=tf.float32)

            if prob != 0:
                self.total_loss += tf.reduce_mean(per_example_loss)
            self.losses[name + '_loss'] = per_example_loss
            self.preds[name + '_preds'] = tf.argmax(logits, axis=-1)
Ejemplo n.º 15
0
def positional_encoding(inputs,
                        maxlen,
                        masking=True,
                        scope='positional_encoding'):
    '''Sinusoidal positional encoding. See 3.5
    inputs: 3d tensor. (N, T, E)
    maxlen: scalar. Must be >= T
    masking: Boolean. If True, entries where `inputs` equals zero are set
      to zeros.
    scope: Optional scope for `variable_scope`.

    returns
    3d tensor that has the same shape as inputs.
    '''

    embed_dim = inputs.get_shape().as_list()[-1]  # static
    batch, steps = tf.shape(inputs)[0], tf.shape(inputs)[1]  # dynamic
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        # Position index of every token: 0 .. T-1 repeated for each row.
        position_ind = tf.tile(
            tf.expand_dims(tf.range(steps), 0), [batch, 1])  # (N, T)

        # Arguments of the sinusoids. Columns 2i and 2i+1 share one divisor.
        column_divisors = [
            np.power(10000, (i - i % 2) / embed_dim)
            for i in range(embed_dim)]
        position_enc = np.array([
            [pos / divisor for divisor in column_divisors]
            for pos in range(maxlen)])

        # Sine on the even columns, cosine on the odd columns.
        position_enc[:, 0::2] = np.sin(position_enc[:, 0::2])  # dim 2i
        position_enc[:, 1::2] = np.cos(position_enc[:, 1::2])  # dim 2i+1
        position_enc = tf.convert_to_tensor(
            position_enc, tf.float32)  # (maxlen, E)

        # Pick the table row belonging to every position.
        outputs = tf.nn.embedding_lookup(position_enc, position_ind)

        # Where the input is exactly zero, keep the (zero) input value.
        if masking:
            outputs = tf.where(tf.equal(inputs, 0), inputs, outputs)

        return tf.to_float(outputs)
Ejemplo n.º 16
0
def favor_attention(query,
                    key,
                    value,
                    kernel_transformation,
                    causal,
                    projection_matrix=None):
    '''Computes FAVOR normalized attention.
  Args:
    query: query tensor.
    key: key tensor.
    value: value tensor.
    kernel_transformation: transformation used to get finite kernel features.
      Called as `kernel_transformation(tensor, is_query, projection_matrix)`.
    causal: whether attention is causal or not.
    projection_matrix: projection matrix to be used.
  Returns:
    FAVOR normalized attention.
  '''
    # Kernel feature maps of queries and keys.
    query_prime = kernel_transformation(query, True,
                                        projection_matrix)  # [B,L,H,M]
    key_prime = kernel_transformation(key, False,
                                      projection_matrix)  # [B,L,H,M]

    # Swap the first two axes so the sequence axis leads.
    swap_leading = [1, 0, 2, 3]
    query_prime = tf.transpose(query_prime, swap_leading)  # [L,B,H,M]
    key_prime = tf.transpose(key_prime, swap_leading)  # [L,B,H,M]
    value = tf.transpose(value, swap_leading)  # [L,B,H,D]

    if causal:
        numerator = causal_numerator(query_prime, key_prime, value)
        normalizer = causal_denominator(query_prime, key_prime)
    else:
        numerator = noncausal_numerator(query_prime, key_prime, value)
        normalizer = noncausal_denominator(query_prime, key_prime)

    # Swap the leading axes back and give the normalizer a trailing axis
    # so that it broadcasts over the last dimension of the numerator.
    numerator = tf.transpose(numerator, swap_leading)
    normalizer = tf.transpose(normalizer, [1, 0, 2])
    normalizer = tf.expand_dims(normalizer, len(normalizer.shape))
    return numerator / normalizer
Ejemplo n.º 17
0
def expand_tile(value, size):
    '''Prepend a new leading axis and repeat `value` `size` times along it.

    Args:
      value: anything `tf.convert_to_tensor` accepts.
      size: number of copies along the new first axis.

    Returns:
      A tensor whose first dimension is `size` and whose remaining dimensions
      are those of `value`.
    '''
    tensor = tf.convert_to_tensor(value, name='value')
    rank = tensor.shape.ndims
    with_leading_axis = tf.expand_dims(tensor, axis=0)
    # Repeat only along the new axis; every original axis keeps multiple 1.
    multiples = [size] + [1] * rank
    return tf.tile(with_leading_axis, multiples)
Ejemplo n.º 18
0
    def __init__(self,
                 hparams,
                 is_training,
                 input_ids,
                 sample_weight=None,
                 scope='model',
                 given=1,
                 use_tilda_embedding=False,
                 **kwargs):
        '''Build the language-model graph, its predictions and its loss.

        Args:
          hparams: hyper-parameter object; `n_predict`, `n_vocab`, `n_embed`,
            `n_ctx` and `n_layer` are read here.
          is_training: when true, run a single forward pass over `input_ids`;
            otherwise decode greedily, starting from the first `given` tokens.
          input_ids: rank-2 integer tensor of token ids. Id 0 is treated as
            padding when masking the loss.
          sample_weight: optional per-example weights, one value per row of
            `input_ids`.
          scope: variable scope that wraps the model variables.
          given: number of leading tokens kept as the prompt at inference.
          use_tilda_embedding: when true, reuse the existing top-level
            `tilda_embeddings` variable as the word-embedding table.
          **kwargs: accepted for interface compatibility; not used here.

        Populates `self.preds['LM']`, `self.losses['LM']` (one loss value per
        example) and `self.total_loss` (their mean).
        '''
        super().__init__()

        batch_size = util.get_shape_list(input_ids, expected_rank=2)[0]
        max_seq_length = hparams.n_predict

        # Tilda embeddings for SMART algorithm
        tilda_embeddings = None
        if use_tilda_embedding:
            with tf.variable_scope('', reuse=True):
                tilda_embeddings = tf.get_variable('tilda_embeddings')

        with tf.variable_scope(scope):

            def _forward(input_ids, past=None):
                '''Run the transformer stack; return (logits, presents).'''
                batch, sequence = shape_list(input_ids)

                if tilda_embeddings is None:
                    wte = tf.get_variable(
                        'word_embeddings', [hparams.n_vocab, hparams.n_embed],
                        initializer=tf.random_normal_initializer(stddev=0.02))
                else:
                    wte = tilda_embeddings
                wpe = tf.get_variable(
                    'wpe', [hparams.n_ctx, hparams.n_embed],
                    initializer=tf.random_normal_initializer(stddev=0.01))
                past_length = 0 if past is None else tf.shape(past)[-2]
                h = (tf.gather(wte, input_ids) +
                     tf.gather(wpe, positions_for(input_ids, past_length)))

                # stacked transformer layers
                presents = []
                pasts = tf.unstack(past, axis=1) if past is not None else \
                    [None] * hparams.n_layer
                assert len(pasts) == hparams.n_layer
                # `layer_past` (rather than `past`) keeps the function
                # argument from being shadowed by the loop variable.
                for layer, layer_past in enumerate(pasts):
                    h, present = block(h,
                                       'h%d' % layer,
                                       past=layer_past,
                                       hparams=hparams)
                    presents.append(present)
                present = tf.stack(presents, axis=1)
                h = norm(h, 'ln_f')

                # Project the hidden states back onto the vocabulary using the
                # word-embedding table as the output matrix.
                h_flat = tf.reshape(h, [batch * sequence, hparams.n_embed])
                logits = tf.matmul(h_flat, wte, transpose_b=True)
                logits = tf.reshape(logits, [batch, sequence, hparams.n_vocab])

                return logits, present

            # convert to labels: shift the inputs left by one position and pad
            # the last position with id 0
            label_ids = tf.concat(
                [input_ids[:, 1:],
                 tf.zeros([batch_size, 1], dtype=tf.int32)],
                axis=-1)

            # forward once
            if is_training:
                (logits, _) = _forward(input_ids)

                self.preds['LM'] = tf.argmax(logits, axis=-1)

            # forward loop: every step re-runs the whole prefix (no `past` is
            # passed) and appends the arg-max token of the last position
            else:
                input_ids = input_ids[:, 0:given]

                for cur_length in range(given, max_seq_length + 1):
                    (logits, _) = _forward(input_ids)

                    pred_ids = tf.argmax(logits[:,
                                                cur_length - 1:cur_length, :],
                                         axis=-1)
                    pred_ids = tf.cast(pred_ids, tf.int32)
                    input_ids = tf.concat([input_ids, pred_ids], axis=-1)

                self.preds['LM'] = input_ids

            # loss (at inference `logits` come from the last decoding step;
            # this assumes the original `input_ids` width equals
            # `hparams.n_predict` so the shapes line up -- TODO confirm)
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            one_hot_labels = tf.one_hot(label_ids, depth=hparams.n_vocab)
            per_token_loss = -tf.reduce_sum(one_hot_labels * log_probs,
                                            axis=-1)
            label_mask = tf.cast(tf.not_equal(label_ids, 0), tf.float32)
            # Clamp the token count to at least 1: a row whose labels are all
            # padding then yields a loss of 0 instead of 0 / 0 = NaN. Rows
            # with at least one real label are unaffected.
            num_label_tokens = tf.maximum(
                tf.reduce_sum(label_mask, axis=-1), 1.0)
            per_example_loss = \
                tf.reduce_sum(per_token_loss * label_mask, axis=-1) / \
                num_label_tokens
            if sample_weight is not None:
                # `per_example_loss` is rank 1 here, so the weights are
                # applied elementwise. Expanding them to [batch, 1] would
                # broadcast the product to [batch, batch].
                per_example_loss *= tf.cast(sample_weight, dtype=tf.float32)

            self.total_loss = tf.reduce_mean(per_example_loss)
            self.losses['LM'] = per_example_loss
Ejemplo n.º 19
0
    def __init__(self,
                 vocab_size,
                 is_training,
                 input_ids,
                 input_mask,
                 segment_ids,
                 sample_weight=None,
                 reduced_size=64,
                 topic_size=1024,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 bias=0,
                 scope='vae',
                 trainable=True,
                 **kwargs):
        '''Build a transformer-encoder VAE with an MLP decoder.

        Args:
          vocab_size: vocabulary size, used for the embedding table and the
            reconstruction logits.
          is_training: when false, both dropout probabilities in the config
            are set to 0 and the latent noise is drawn uniformly from
            [-bias, bias] instead of from a standard normal.
          input_ids: rank-2 tensor of token ids.
          input_mask: mask tensor; its sum over the last axis is used as the
            per-example input length.
          segment_ids: token-type ids fed to the embedding post-processor.
          sample_weight: optional per-example weights for the
            reconstruction loss.
          reduced_size: per-token width after the encoder projection.
          topic_size: width of the latent vectors `miu` and `sigma`.
          hidden_size: transformer hidden width.
          num_hidden_layers: number of transformer layers.
          num_attention_heads: number of attention heads.
          bias: half-width of the uniform latent noise used at inference.
          scope: variable scope wrapping embeddings, encoder and decoder.
          trainable: forwarded to every variable/layer created here.
          **kwargs: only `use_tilda_embedding` is read.

        Populates `self.probs['miu']`, `self.probs['sigma']`,
        `self.preds['preds']`, `self.losses['losses']`, `self.total_loss`,
        and the `embedding_output`, `embedding_table`, `all_encoder_layers`
        and `all_decoder_layers` attributes.
        '''
        super().__init__()

        # model configuration; dropout is switched off outside training
        config = Config(vocab_size,
                        hidden_size=hidden_size,
                        num_hidden_layers=num_hidden_layers,
                        num_attention_heads=num_attention_heads)
        if not is_training:
            config.hidden_dropout_prob = 0.0
            config.attention_probs_dropout_prob = 0.0

        input_shape = util.get_shape_list(input_ids, expected_rank=2)
        batch_size = input_shape[0]
        seq_length = input_shape[1]

        # Tilda embeddings for SMART algorithm
        tilda_embeddings = None
        use_tilda_embedding = kwargs.get('use_tilda_embedding')
        if use_tilda_embedding:
            with tf.variable_scope('', reuse=True):
                tilda_embeddings = tf.get_variable('tilda_embeddings')

        with tf.variable_scope(scope):
            with tf.variable_scope('embeddings'):

                (self.embedding_output, self.embedding_table) = \
                    self.embedding_lookup(
                        input_ids=input_ids,
                        vocab_size=config.vocab_size,
                        batch_size=batch_size,
                        max_seq_length=seq_length,
                        embedding_size=config.hidden_size,
                        initializer_range=config.initializer_range,
                        word_embedding_name='word_embeddings',
                        tilda_embeddings=tilda_embeddings,
                        trainable=trainable)
                self.embedding_output = self.embedding_postprocessor(
                    input_tensor=self.embedding_output,
                    batch_size=batch_size,
                    max_seq_length=seq_length,
                    hidden_size=config.hidden_size,
                    use_token_type=True,
                    segment_ids=segment_ids,
                    token_type_vocab_size=config.type_vocab_size,
                    token_type_embedding_name='token_type_embeddings',
                    use_position_embeddings=True,
                    position_embedding_name='position_embeddings',
                    initializer_range=config.initializer_range,
                    max_position_embeddings=config.max_position_embeddings,
                    dropout_prob=config.hidden_dropout_prob,
                    trainable=trainable)

            with tf.variable_scope('encoder'):

                # stacked transformer
                attention_mask = self.create_attention_mask_from_input_mask(
                    input_mask, batch_size, seq_length)
                self.all_encoder_layers = self.transformer_model(
                    input_tensor=self.embedding_output,
                    batch_size=batch_size,
                    max_seq_length=seq_length,
                    attention_mask=attention_mask,
                    hidden_size=config.hidden_size,
                    num_hidden_layers=config.num_hidden_layers,
                    num_attention_heads=config.num_attention_heads,
                    intermediate_size=config.intermediate_size,
                    intermediate_act_fn=util.get_activation(config.hidden_act),
                    hidden_dropout_prob=config.hidden_dropout_prob,
                    attention_probs_dropout_prob=\
                        config.attention_probs_dropout_prob,
                    initializer_range=config.initializer_range,
                    trainable=trainable)

                # projection: shrink every token to `reduced_size` features
                # and flatten each example into a single vector
                with tf.variable_scope('projection'):
                    transformer_output = tf.layers.dense(
                        self.all_encoder_layers[-1],
                        reduced_size,
                        activation=util.gelu,
                        kernel_initializer=tf.truncated_normal_initializer(
                            stddev=config.initializer_range),
                        trainable=trainable)
                    transformer_output = tf.reshape(transformer_output,
                                                    [batch_size, -1])
                    input_length = tf.reduce_sum(input_mask, axis=-1)
                    input_length = tf.cast(input_length, tf.float32)
                    input_length_1d = tf.reshape(input_length, [batch_size])
                    input_length_2d = tf.reshape(input_length, [batch_size, 1])

                    # Keep only the first `input_length * reduced_size`
                    # flattened features of each example and rescale them by
                    # `seq_length / input_length`.
                    broadcast_mask = tf.sequence_mask(
                        tf.multiply(input_length_1d, reduced_size),
                        seq_length * reduced_size,
                        dtype=tf.float32)
                    broadcast_mask = tf.multiply(broadcast_mask,
                                                 seq_length / input_length_2d)
                    transformer_output *= broadcast_mask

                    # latent space
                    miu = tf.layers.dense(
                        transformer_output,
                        topic_size,
                        activation='tanh',
                        kernel_initializer=tf.truncated_normal_initializer(
                            stddev=config.initializer_range),
                        name='miu',
                        trainable=trainable)
                    sigma = tf.layers.dense(
                        transformer_output,
                        topic_size,
                        kernel_initializer=tf.truncated_normal_initializer(
                            stddev=config.initializer_range),
                        name='sigma',
                        trainable=trainable)
                    self.probs['miu'] = miu
                    self.probs['sigma'] = sigma

            with tf.variable_scope('decoder'):
                with tf.variable_scope('projection'):

                    # reparameterization: normal noise while training,
                    # uniform noise in [-bias, bias] otherwise
                    if is_training:
                        noise = tf.random_normal([batch_size, topic_size])
                    else:
                        noise = tf.random_uniform([batch_size, topic_size],
                                                  minval=-bias,
                                                  maxval=bias)
                    decoder_input = miu + tf.exp(sigma) * noise

                    # projection: expand the latent vector back to one
                    # `reduced_size` vector per position
                    decoder_input = tf.layers.dense(
                        decoder_input,
                        seq_length * reduced_size,
                        activation=util.gelu,
                        kernel_initializer=tf.truncated_normal_initializer(
                            stddev=config.initializer_range),
                        trainable=trainable)
                    intermediate_input = tf.reshape(
                        decoder_input, [-1, seq_length, reduced_size])
                    intermediate_input = util.layer_norm(intermediate_input,
                                                         trainable=trainable)
                    intermediate_input = util.dropout(
                        intermediate_input, config.hidden_dropout_prob)

                # MLP
                with tf.variable_scope('intermediate'):
                    intermediate_output = tf.layers.dense(
                        intermediate_input,
                        4 * reduced_size,
                        activation=util.gelu,
                        kernel_initializer=util.create_initializer(
                            config.initializer_range),
                        trainable=trainable)
                with tf.variable_scope('output'):
                    decoder_output = tf.layers.dense(
                        intermediate_output,
                        config.hidden_size,
                        kernel_initializer=util.create_initializer(
                            config.initializer_range),
                        trainable=trainable)
                    decoder_output = util.layer_norm(decoder_output,
                                                     trainable=trainable)
                    decoder_output = util.dropout(decoder_output,
                                                  config.hidden_dropout_prob)
                # Only the final decoder output is exposed. An earlier
                # assignment that also listed `intermediate_output` was
                # overwritten on the very next line, so it has been dropped.
                self.all_decoder_layers = [decoder_output]

        # reconstruction
        with tf.variable_scope('cls/predictions'):
            with tf.variable_scope('transform'):
                input_tensor = tf.layers.dense(
                    decoder_output,
                    units=config.hidden_size,
                    activation=util.get_activation(config.hidden_act),
                    kernel_initializer=util.create_initializer(
                        config.initializer_range),
                    trainable=trainable)
                input_tensor = util.layer_norm(input_tensor,
                                               trainable=trainable)
            # The embedding table doubles as the output projection matrix.
            output_weights = self.embedding_table
            output_bias = tf.get_variable('output_bias',
                                          shape=[config.vocab_size],
                                          initializer=tf.zeros_initializer(),
                                          trainable=trainable)
            flatten_input_tensor = tf.reshape(input_tensor,
                                              [-1, config.hidden_size])

            logits = tf.matmul(flatten_input_tensor,
                               output_weights,
                               transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)

            logits = tf.reshape(logits,
                                [batch_size, seq_length, config.vocab_size])
            probs = tf.nn.softmax(logits, axis=-1, name='probs')
            lm_log_probs = tf.nn.log_softmax(logits, axis=-1)

            self.preds['preds'] = tf.argmax(probs, axis=-1)
            one_hot_labels = tf.one_hot(input_ids,
                                        depth=config.vocab_size,
                                        dtype=tf.float32)
            # Summing over the vocabulary axis leaves one loss value per
            # token, i.e. a [batch_size, seq_length] tensor.
            per_example_loss = -tf.reduce_sum(lm_log_probs * one_hot_labels,
                                              axis=[-1])
            if sample_weight is not None:
                # The trailing axis lets per-example weights broadcast over
                # the sequence axis of the per-token losses.
                per_example_loss *= tf.expand_dims(sample_weight, axis=-1)

            # reconstruction term plus two regularizers on the latent
            # statistics: mean(miu^2) and mean(exp(sigma) - sigma - 1)
            self.total_loss = (tf.reduce_mean(per_example_loss) +
                               tf.reduce_mean(tf.square(miu)) +
                               tf.reduce_mean(tf.exp(sigma) - sigma - 1))
            self.losses['losses'] = per_example_loss
Ejemplo n.º 20
0
    def attention_layer(self,
                        from_tensor,
                        to_tensor,
                        attention_mask=None,
                        num_attention_heads=12,
                        size_per_head=512,
                        query_act=None,
                        key_act=None,
                        value_act=None,
                        attention_probs_dropout_prob=0.0,
                        initializer_range=0.02,
                        do_return_2d_tensor=False,
                        batch_size=None,
                        from_max_seq_length=None,
                        to_max_seq_length=None,
                        dtype=tf.float32,
                        trainable=True):
        '''Multi-head scaled dot-product attention.

        Queries are projected from `from_tensor`; keys and values are
        projected from `to_tensor`.

        Args:
          from_tensor: tensor the queries are computed from.
          to_tensor: tensor the keys and values are computed from.
          attention_mask: optional mask; positions where it is 0 receive
            a -10000 offset on their scores before the softmax.
          num_attention_heads: number of heads (N).
          size_per_head: width of each head (H).
          query_act / key_act / value_act: optional activations for the
            three projections.
          attention_probs_dropout_prob: dropout probability applied to the
            attention probabilities.
          initializer_range: passed to `util.create_initializer` for the
            projection kernels.
          do_return_2d_tensor: if true the context is returned as
            [B*F, N*H], otherwise as [B, F, N*H].
          batch_size: batch size (B), used by the reshapes.
          from_max_seq_length: length of the query sequence (F).
          to_max_seq_length: length of the key/value sequence (T).
          dtype: dtype the mask is cast to when building the additive offset.
          trainable: forwarded to the dense projections.

        Returns:
          `(context_layer, attention_scores)`, where `attention_scores` are
          the scaled (and, if a mask was given, offset) pre-softmax scores
          of shape [B, N, F, T].
        '''
        # Width of all heads concatenated (N*H).
        all_heads_width = num_attention_heads * size_per_head

        def _project(inputs_2d, act, layer_name):
            '''Dense projection of a 2-D input to N*H features.'''
            return tf.layers.dense(
                inputs_2d,
                all_heads_width,
                activation=act,
                name=layer_name,
                kernel_initializer=util.create_initializer(initializer_range),
                trainable=trainable)

        def _split_heads(flat, length):
            '''[B*length, N*H] -> [B, N, length, H].'''
            per_head = tf.reshape(
                flat,
                [batch_size, length, num_attention_heads, size_per_head])
            return tf.transpose(per_head, [0, 2, 1, 3])

        from_2d = util.reshape_to_matrix(from_tensor)
        to_2d = util.reshape_to_matrix(to_tensor)

        # Projections: queries are [B*F, N*H]; keys and values [B*T, N*H].
        queries = _project(from_2d, query_act, 'query')
        keys = _project(to_2d, key_act, 'key')
        values = _project(to_2d, value_act, 'value')

        # queries -> [B, N, F, H], keys -> [B, N, T, H]
        queries = _split_heads(queries, from_max_seq_length)
        keys = _split_heads(keys, to_max_seq_length)

        # Raw scores [B, N, F, T], scaled by 1 / sqrt(H).
        attention_scores = tf.matmul(queries, keys, transpose_b=True)
        attention_scores = tf.multiply(attention_scores,
                                       1.0 / math.sqrt(float(size_per_head)))

        if attention_mask is not None:
            # Insert a heads axis so the mask broadcasts over N, then turn
            # it into an additive offset: 0 where the mask is 1 and -10000
            # where it is 0.
            attention_mask = tf.expand_dims(attention_mask, axis=[1])
            adder = (1.0 - tf.cast(attention_mask, dtype)) * -10000.0
            attention_scores += adder

        # Probabilities [B, N, F, T]; dropout here removes whole attended
        # positions, as in the original Transformer paper.
        attention_probs = tf.nn.softmax(attention_scores, axis=-1)
        attention_probs = util.dropout(attention_probs,
                                       attention_probs_dropout_prob)

        # values -> [B, N, T, H]
        values = _split_heads(values, to_max_seq_length)

        # Weighted sum [B, N, F, H], then heads back next to H: [B, F, N, H].
        context_layer = tf.matmul(attention_probs, values)
        context_layer = tf.transpose(context_layer, [0, 2, 1, 3])

        # Merge the heads, optionally also folding batch and sequence.
        if do_return_2d_tensor:
            merged_shape = [batch_size * from_max_seq_length, all_heads_width]
        else:
            merged_shape = [batch_size, from_max_seq_length, all_heads_width]
        context_layer = tf.reshape(context_layer, merged_shape)

        return (context_layer, attention_scores)
Ejemplo n.º 21
0
    def __init__(self,
                 bert_config,
                 is_training,
                 sketchy_encoder,
                 intensive_encoder,
                 query_mask,
                 label_ids,
                 has_answer,
                 sample_weight=None,
                 scope='retro_reader',
                 matching_mechanism='cross-attention',
                 beta_1=0.5,
                 beta_2=0.5,
                 threshold=1.0,
                 trainable=True,
                 **kwargs):
        '''Build the sketchy/intensive reading heads and the rear verifier.

        Args:
          bert_config: configuration object; `initializer_range`,
            `hidden_dropout_prob` and `num_attention_heads` are read here.
          is_training: enables dropout when true.
          sketchy_encoder: encoder exposing `get_pooled_output()`; feeds the
            answerability classifier.
          intensive_encoder: encoder exposing `get_sequence_output()`; feeds
            the span-prediction head.
          query_mask: mask applied to the sequence output to obtain the
            query-only representation `H_Q`.
          label_ids: tensor whose `[:, 0]` / `[:, 1]` columns are the start
            and end positions.
          has_answer: class ids (depth 2) for the answerability classifier.
          sample_weight: optional per-example weights for both losses.
          scope: variable scope wrapping everything built here.
          matching_mechanism: `'cross-attention'` or `'matching-attention'`.
          beta_1: weight of the intensive score in the verifier.
          beta_2: weight of the sketchy score in the verifier.
          threshold: verifier score above which the prediction is 1.
          trainable: forwarded to every variable created here.
          **kwargs: forwarded to the base-class constructor.

        Raises:
          ValueError: if `matching_mechanism` is not one of the two
            supported values.
        '''
        super().__init__(**kwargs)

        # Fail early with a clear message instead of hitting an undefined
        # `H_prime` after half of the graph has been built.
        if matching_mechanism not in ('cross-attention',
                                      'matching-attention'):
            raise ValueError(
                'Unsupported matching_mechanism: %r. Pick one from '
                '`cross-attention` and `matching-attention`.'
                % (matching_mechanism,))

        # Dropout is only active while training; the same probability is
        # used by every dropout site below.
        dropout_prob = \
            bert_config.hidden_dropout_prob if is_training else 0.0

        # verifier
        with tf.variable_scope(scope):

            # sketchy reading module: binary answerability classifier on the
            # pooled output
            with tf.variable_scope('sketchy/prediction'):
                sketchy_output = sketchy_encoder.get_pooled_output()
                hidden_size = sketchy_output.shape.as_list()[-1]

                output_weights = tf.get_variable(
                    'output_weights',
                    shape=[2, hidden_size],
                    initializer=util.create_initializer(
                        bert_config.initializer_range),
                    trainable=trainable)
                output_bias = tf.get_variable(
                    'output_bias',
                    shape=[2],
                    initializer=tf.zeros_initializer(),
                    trainable=trainable)

                output_layer = util.dropout(sketchy_output, dropout_prob)
                logits = tf.matmul(
                    output_layer, output_weights, transpose_b=True)
                logits = tf.nn.bias_add(logits, output_bias)

                log_probs = tf.nn.log_softmax(logits, axis=-1)
                one_hot_labels = tf.one_hot(
                    has_answer, depth=2, dtype=tf.float32)
                per_example_loss = - tf.reduce_sum(
                    one_hot_labels * log_probs, axis=-1)
                if sample_weight is not None:
                    per_example_loss = tf.cast(
                        sample_weight, dtype=tf.float32) * per_example_loss

                self.losses['sketchy_losses'] = per_example_loss
                sketchy_loss = tf.reduce_mean(per_example_loss)

                # external score: logit of class 1 minus logit of class 0
                score_ext = logits[:, 1] - logits[:, 0]

            # intensive reading module
            with tf.variable_scope('intensive'):
                H = intensive_encoder.get_sequence_output()
                # zero out every position that is not part of the query
                H_Q = H * tf.cast(
                    tf.expand_dims(query_mask, axis=-1), tf.float32)
                (batch_size, max_seq_length, hidden_size) = \
                    util.get_shape_list(H)

                # cross-attention
                if matching_mechanism == 'cross-attention':
                    with tf.variable_scope('cross_attention'):
                        attention_mask = \
                            self.create_attention_mask_from_input_mask(
                                query_mask, batch_size, max_seq_length)
                        # Gated on `is_training` like the other dropout
                        # sites, so the attention is deterministic at
                        # inference.
                        (H_prime, _) = self.attention_layer(
                            from_tensor=H,
                            to_tensor=H_Q,
                            attention_mask=attention_mask,
                            num_attention_heads=\
                                bert_config.num_attention_heads,
                            size_per_head=\
                                hidden_size // bert_config.num_attention_heads,
                            attention_probs_dropout_prob=dropout_prob,
                            initializer_range=bert_config.initializer_range,
                            do_return_2d_tensor=False,
                            batch_size=batch_size,
                            from_max_seq_length=max_seq_length,
                            to_max_seq_length=max_seq_length,
                            trainable=trainable)

                # matching-attention
                else:
                    with tf.variable_scope('matching_attention'):
                        output_weights = tf.get_variable(
                            'output_weights',
                            shape=[hidden_size, hidden_size],
                            initializer=util.create_initializer(
                                bert_config.initializer_range),
                            trainable=trainable)
                        output_bias = tf.get_variable(
                            'output_bias',
                            shape=[hidden_size],
                            initializer=tf.zeros_initializer(),
                            trainable=trainable)
                        # linear transform of the query representation,
                        # using the same weights for every example
                        trans = tf.matmul(
                            H_Q, tf.tile(
                                tf.expand_dims(output_weights, axis=0),
                                [batch_size, 1, 1]),
                            transpose_b=True)
                        trans = tf.nn.bias_add(trans, output_bias)
                        # position-to-position matching weights, then a
                        # weighted sum of the query representation
                        M = tf.nn.softmax(
                            tf.matmul(H, trans, transpose_b=True), axis=-1)
                        H_prime = tf.matmul(M, H_Q)

                with tf.variable_scope('prediction'):
                    output_weights = tf.get_variable(
                        'output_weights',
                        shape=[2, hidden_size],
                        initializer=util.create_initializer(
                            bert_config.initializer_range),
                        trainable=trainable)
                    output_bias = tf.get_variable(
                        'output_bias',
                        shape=[2],
                        initializer=tf.zeros_initializer(),
                        trainable=trainable)

                    output_layer = util.dropout(H_prime, dropout_prob)
                    output_layer = tf.reshape(
                        output_layer,
                        [batch_size * max_seq_length, hidden_size])
                    logits = tf.matmul(
                        output_layer, output_weights, transpose_b=True)
                    logits = tf.nn.bias_add(logits, output_bias)
                    logits = tf.reshape(
                        logits, [batch_size, max_seq_length, 2])
                    # -> [batch_size, 2, max_seq_length]: row 0 holds the
                    # start logits and row 1 the end logits
                    logits = tf.transpose(logits, [0, 2, 1])
                    probs = tf.nn.softmax(logits, axis=-1, name='probs')

                    self.probs['mrc_probs'] = probs
                    self.preds['mrc_preds'] = tf.argmax(logits, axis=-1)

                    start_one_hot_labels = tf.one_hot(
                        label_ids[:, 0], depth=max_seq_length,
                        dtype=tf.float32)
                    end_one_hot_labels = tf.one_hot(
                        label_ids[:, 1], depth=max_seq_length,
                        dtype=tf.float32)
                    start_log_probs = tf.nn.log_softmax(
                        logits[:, 0, :], axis=-1)
                    end_log_probs = tf.nn.log_softmax(
                        logits[:, 1, :], axis=-1)
                    # average of the start and end cross-entropies
                    per_example_loss = (
                        - 0.5 * tf.reduce_sum(
                            start_one_hot_labels * start_log_probs, axis=-1)
                        - 0.5 * tf.reduce_sum(
                            end_one_hot_labels * end_log_probs, axis=-1))
                    if sample_weight is not None:
                        # cast as in the sketchy loss so both losses treat
                        # the weights the same way
                        per_example_loss *= tf.cast(
                            sample_weight, dtype=tf.float32)

                    intensive_loss = tf.reduce_mean(per_example_loss)
                    self.losses['intensive_losses'] = per_example_loss

                    # largest summed start+end probability over positions
                    # 1.. versus the summed probability at position 0
                    score_has = tf.norm(
                        probs[:, 0, 1:] + probs[:, 1, 1:], np.inf, axis=-1)
                    score_null = probs[:, 0, 0] + probs[:, 1, 0]
                    score_diff = score_has - score_null

            # rear verification: weighted sum of both scores, thresholded
            v = beta_1 * score_diff + beta_2 * score_ext
            self.preds['verifier_preds'] = \
                tf.cast(tf.greater(v, threshold), tf.int32)
            self.probs['verifier_probs'] = v

            self.total_loss = sketchy_loss + intensive_loss
Ejemplo n.º 22
0
def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
    '''Sample a factorization order and build the matching attention mask.

    This is a batched routine: every tensor argument carries a leading
    batch axis. A single permutation is sampled per call and tiled along
    the batch axis, so all examples in the batch share the same
    factorization order.

    Args:
      inputs: integer Tensor in shape [batch_size, seq_len], input ids.
      targets: integer Tensor in shape [batch_size, seq_len], target ids.
      is_masked: bool Tensor in shape [batch_size, seq_len]. True means
        being selected for partial prediction.
      perm_size: the length of longest permutation. Could be set to be
        reuse_len. Should not be larger than reuse_len or there will be
        data leaks. `seq_len` has to be a multiple of it, because the
        position index is reshaped to [-1, perm_size].
      seq_len: int, sequence length.

    Returns:
      A tuple `(perm_mask, new_targets, target_mask, inputs_k, inputs_q)`:
        perm_mask: bool Tensor in shape [batch_size, seq_len, seq_len].
          True at [b, i, j] means position i cannot attend to position j.
        new_targets: Tensor in shape [batch_size, seq_len]; the first
          input id followed by `targets` shifted right by one.
        target_mask: float32 Tensor in shape [batch_size, seq_len]; 1.0
          for masked, non-functional positions (the ones with a loss).
        inputs_k: `inputs`, returned unchanged.
        inputs_q: the same tensor as `target_mask`.
    '''
    batch_size = tf.shape(inputs)[0]

    # Generate permutation indices.
    # The positions are split into chunks of `perm_size`; `random_shuffle`
    # permutes the first axis of the transposed view, i.e. the offsets
    # inside a chunk, and the same shuffled offsets apply to every chunk.
    index = tf.range(seq_len, dtype=tf.int64)
    index = tf.reshape(index, [-1, perm_size])
    index = tf.transpose(index)
    index = tf.random_shuffle(index)
    index = tf.transpose(index)
    index = tf.reshape(index, [1, -1])
    index = tf.tile(index, [batch_size, 1])

    # `perm_mask` and `target_mask`
    # non-functional tokens
    non_func_tokens = tf.logical_not(
        tf.logical_or(tf.equal(inputs, SEP_ID), tf.equal(inputs, CLS_ID)))

    non_mask_tokens = tf.logical_and(tf.logical_not(is_masked),
                                     non_func_tokens)
    masked_or_func_tokens = tf.logical_not(non_mask_tokens)

    # Set the permutation indices of non-masked (& non-functional) tokens
    # to the smallest index (-1):
    # (1) they can be seen by all other positions
    # (2) they cannot see masked positions, so there won't be information leak
    smallest_index = -tf.ones([batch_size, seq_len], dtype=tf.int64)
    rev_index = tf.where(non_mask_tokens, smallest_index, index)

    # Create `target_mask`: non-functional and masked tokens
    # 1: use mask as input and have loss
    # 0: use token (or [SEP], [CLS]) as input and do not have loss
    target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens)
    target_mask = tf.cast(target_tokens, tf.float32)

    # Create `perm_mask`
    # `target_tokens` cannot see themselves
    self_rev_index = tf.where(target_tokens, rev_index, rev_index + 1)

    # 1: cannot attend if i <= j and j is not non-masked (masked_or_func_tokens)
    # 0: can attend if i > j or j is non-masked
    # The comparison is laid out as [batch, i, j]. The "masked or
    # functional" condition concerns the attended-to position j, so it is
    # expanded on axis 1 (shape [batch, 1, seq_len]) and broadcast over i.
    # Expanding on the last axis would test position i instead and let
    # non-masked tokens attend to masked positions.
    perm_mask = tf.logical_and(
        self_rev_index[:, :, None] <= rev_index[:, None, :],
        tf.expand_dims(masked_or_func_tokens, axis=1))

    # new target: [next token] for LM and [curr token] (self) for PLM
    new_targets = tf.concat([inputs[:, 0:1], targets[:, :-1]], axis=1)

    # construct inputs_k
    inputs_k = inputs

    # construct inputs_q
    inputs_q = target_mask

    return perm_mask, new_targets, target_mask, inputs_k, inputs_q
Ejemplo n.º 23
0
    def __init__(self,
                 bert_config,
                 is_training,
                 encoder,
                 masked_lm_positions,
                 masked_lm_ids,
                 masked_lm_weights,
                 next_sentence_labels,
                 sample_weight=None,
                 scope_lm='cls/predictions',
                 scope_cls='cls/seq_relationship',
                 trainable=True,
                 use_nsp_loss=True,
                 **kwargs):
        '''Builds the masked-LM and next-sentence-prediction heads.

        Results are written to `self.losses`, `self.probs`, `self.preds`
        and `self.total_loss`.

        Args:
          bert_config: config object; `hidden_size`, `hidden_act`,
            `initializer_range` and `vocab_size` are read from it.
          is_training: not used by this constructor.
          encoder: object providing `get_sequence_output()`,
            `get_embedding_table()` and `get_pooled_output()`.
          masked_lm_positions: int32 Tensor of positions to predict;
            assumed to be [batch_size, n_predictions] -- TODO confirm.
          masked_lm_ids: Tensor of target ids for those positions.
          masked_lm_weights: Tensor of per-prediction loss weights, same
            layout as `masked_lm_positions`.
          next_sentence_labels: Tensor of NSP labels (flattened before use).
          sample_weight: (optional) per-example weights; assumed to be of
            shape [batch_size] -- TODO confirm.
          scope_lm: variable scope of the masked-LM head.
          scope_cls: variable scope of the NSP head.
          trainable: whether the output weights/biases are trainable.
          use_nsp_loss: if False, the NSP loss is still computed and stored
            in `self.losses`, but not added to `self.total_loss`.
          **kwargs: forwarded to the parent constructor.
        '''
        super(BERTDecoder, self).__init__(**kwargs)

        def gather_indexes(sequence_tensor, positions):
            '''Gathers the vectors at `positions` from a rank-3 tensor.'''
            sequence_shape = util.get_shape_list(sequence_tensor, 3)
            batch_size = sequence_shape[0]
            seq_length = sequence_shape[1]
            width = sequence_shape[2]

            # Turn per-sequence positions into offsets into the flattened
            # [batch_size * seq_length, width] matrix.
            flat_offsets = tf.reshape(
                tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
            flat_positions = tf.reshape(positions + flat_offsets, [-1])
            flat_sequence_tensor = tf.reshape(sequence_tensor,
                                              [batch_size * seq_length, width])
            output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
            return output_tensor

        scalar_losses = []

        # masked language modeling
        input_tensor = gather_indexes(encoder.get_sequence_output(),
                                      masked_lm_positions)
        with tf.variable_scope(scope_lm):
            with tf.variable_scope('transform'):
                input_tensor = tf.layers.dense(
                    input_tensor,
                    units=bert_config.hidden_size,
                    activation=util.get_activation(bert_config.hidden_act),
                    kernel_initializer=util.create_initializer(
                        bert_config.initializer_range))
                input_tensor = util.layer_norm(input_tensor)
            output_bias = tf.get_variable('output_bias',
                                          shape=[bert_config.vocab_size],
                                          initializer=tf.zeros_initializer(),
                                          trainable=trainable)

            # The output projection shares weights with the input embeddings.
            logits = tf.matmul(input_tensor,
                               encoder.get_embedding_table(),
                               transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)
            probs = tf.nn.softmax(logits, axis=-1, name='MLM_probs')
            log_probs = tf.nn.log_softmax(logits, axis=-1)

            label_ids = tf.reshape(masked_lm_ids, [-1])
            lm_weights = masked_lm_weights
            if sample_weight is not None:
                # Add a trailing axis so the per-example weight broadcasts
                # over the predictions of that example. A separate name is
                # used on purpose: `sample_weight` itself must keep its
                # original shape for the NSP loss below.
                lm_sample_weight = tf.expand_dims(
                    tf.cast(sample_weight, dtype=tf.float32), axis=-1)
                lm_weights = masked_lm_weights * lm_sample_weight
            label_weights = tf.reshape(lm_weights, [-1])
            one_hot_labels = tf.one_hot(label_ids,
                                        depth=bert_config.vocab_size,
                                        dtype=tf.float32)
            per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels,
                                              axis=[-1])
            per_example_loss = label_weights * per_example_loss

            # Weighted mean; the epsilon guards against an all-zero weight
            # tensor.
            numerator = tf.reduce_sum(per_example_loss)
            denominator = tf.reduce_sum(label_weights) + 1e-5
            loss = numerator / denominator

            scalar_losses.append(loss)
            self.losses['MLM_losses'] = per_example_loss
            self.preds['MLM_preds'] = tf.argmax(probs, axis=-1)

        # next sentence prediction
        with tf.variable_scope(scope_cls):
            output_weights = tf.get_variable(
                'output_weights',
                shape=[2, bert_config.hidden_size],
                initializer=util.create_initializer(
                    bert_config.initializer_range),
                trainable=trainable)
            output_bias = tf.get_variable('output_bias',
                                          shape=[2],
                                          initializer=tf.zeros_initializer(),
                                          trainable=trainable)

            logits = tf.matmul(encoder.get_pooled_output(),
                               output_weights,
                               transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)
            probs = tf.nn.softmax(logits, axis=-1, name='probs')
            log_probs = tf.nn.log_softmax(logits, axis=-1)

            labels = tf.reshape(next_sentence_labels, [-1])
            one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32)
            per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs,
                                              axis=-1)
            if sample_weight is not None:
                # Element-wise product of two per-example vectors.
                per_example_loss = (tf.cast(sample_weight, dtype=tf.float32) *
                                    per_example_loss)
            loss = tf.reduce_mean(per_example_loss)

            if use_nsp_loss:
                scalar_losses.append(loss)
            self.losses['NSP_losses'] = per_example_loss
            self.probs['NSP_probs'] = probs
            self.preds['NSP_preds'] = tf.argmax(probs, axis=-1)

        self.total_loss = tf.add_n(scalar_losses)
Ejemplo n.º 24
0
    def __init__(self,
                 vocab_size,
                 is_training,
                 source_ids,
                 target_ids,
                 sos_id,
                 sample_weight=None,
                 hidden_size=768,
                 num_blocks=6,
                 num_attention_heads=12,
                 scope='transformer',
                 use_label_smoothing=False,
                 use_tilda_embedding=False,
                 trainable=True,
                 **kwargs):
        '''Builds a Transformer encoder-decoder graph and its loss.

        During training the decoder runs once over `target_ids`. Otherwise
        the decoder is unrolled step by step, feeding back the argmax
        prediction of each step. Results are written to
        `self.preds['MT']`, `self.losses['MT']` and `self.total_loss`.

        Args:
          vocab_size: int, size of the shared source/target vocabulary.
          is_training: bool; enables dropout (rate 0.1) and selects the
            single-pass decoder.
          source_ids: rank-2 integer Tensor [batch_size, source_length].
            Positions whose id is 0 are masked out as attention keys.
          target_ids: rank-2 integer Tensor [batch_size, target_length].
            Labels are `target_ids` shifted left by one; labels equal to 0
            do not contribute to the loss.
          sos_id: id used as the first decoder input in the step-by-step
            loop.
          sample_weight: (optional) per-example loss weights; assumed to be
            of shape [batch_size] -- TODO confirm.
          hidden_size: int, width of the embeddings and of every block.
          num_blocks: int, number of encoder blocks and of decoder blocks.
          num_attention_heads: int, number of attention heads.
          scope: variable scope wrapping the whole model.
          use_label_smoothing: whether to smooth the one-hot labels.
          use_tilda_embedding: if True, an existing variable named
            'tilda_embeddings' is fetched and handed to the embedding
            lookup.
          trainable: not used by this constructor.
          **kwargs: accepted but not forwarded to the parent constructor.
        '''
        super().__init__()

        dropout_rate = 0.0
        if is_training:
            dropout_rate = 0.1

        source_shape = util.get_shape_list(source_ids, expected_rank=2)
        target_shape = util.get_shape_list(target_ids, expected_rank=2)
        batch_size = source_shape[0]
        source_max_seq_length = source_shape[1]
        target_max_seq_length = target_shape[1]

        # Tilda embeddings for SMART algorithm
        tilda_embeddings = None
        if use_tilda_embedding:
            with tf.variable_scope('', reuse=True):
                tilda_embeddings = tf.get_variable('tilda_embeddings')

        with tf.variable_scope(scope):
            # True where the source id is 0.
            source_mask = tf.math.equal(source_ids, 0)

            # embedding
            with tf.variable_scope('embeddings'):
                (enc, embedding_table) = embedding_lookup(
                    input_ids=source_ids,
                    vocab_size=vocab_size,
                    batch_size=batch_size,
                    max_seq_length=source_max_seq_length,
                    embedding_size=hidden_size,
                    word_embedding_name='word_embeddings',
                    tilda_embeddings=tilda_embeddings)
                enc *= hidden_size ** 0.5  # scale
                enc += positional_encoding(enc, source_max_seq_length)
                enc = util.dropout(enc, dropout_rate)

            with tf.variable_scope('encoder'):

                # stacked multi-attention layers
                for i in range(num_blocks):
                    with tf.variable_scope('block_%s' % i):

                        # self-attention
                        enc = multihead_attention(
                            queries=enc,
                            keys=enc,
                            values=enc,
                            key_masks=source_mask,
                            num_heads=num_attention_heads,
                            dropout_rate=dropout_rate,
                            training=is_training,
                            causality=False,
                            scope='self_attention')

                        # feed forward
                        enc = ff(enc, num_units=[hidden_size * 4, hidden_size])
                memory = enc

            def _forward(target_ids, target_mask, target_max_seq_length):
                '''Runs the decoder stack and returns vocabulary logits.'''

                with tf.variable_scope('decoder'):

                    # shared embedding
                    dec = tf.nn.embedding_lookup(embedding_table, target_ids)
                    dec *= hidden_size ** 0.5  # scale
                    dec += positional_encoding(dec, target_max_seq_length)
                    dec = util.dropout(dec, dropout_rate)

                    # blocks
                    for i in range(num_blocks):
                        with tf.variable_scope('block_%s' % i):

                            # masked self-attention
                            dec = multihead_attention(
                                queries=dec,
                                keys=dec,
                                values=dec,
                                key_masks=target_mask,
                                num_heads=num_attention_heads,
                                dropout_rate=dropout_rate,
                                training=is_training,
                                causality=True,
                                scope='masked_self_attention')

                            # vanilla attention
                            dec = multihead_attention(
                                queries=dec,
                                keys=memory,
                                values=memory,
                                key_masks=source_mask,
                                num_heads=num_attention_heads,
                                dropout_rate=dropout_rate,
                                training=is_training,
                                causality=False,
                                scope='vanilla_attention')

                            # feed forward
                            dec = ff(
                                dec, num_units=[4 * hidden_size, hidden_size])

                # final linear projection (embedding weights are shared)
                with tf.variable_scope('cls'):
                    output_bias = tf.get_variable(
                        'output_bias', shape=[vocab_size],
                        initializer=tf.zeros_initializer())
                    dec = tf.reshape(dec, [-1, hidden_size])
                    logits = tf.matmul(dec, embedding_table, transpose_b=True)
                    logits = tf.reshape(
                        logits, [-1, target_max_seq_length, vocab_size])
                    logits = tf.nn.bias_add(logits, output_bias)

                return logits

            # convert to labels: shift left by one and pad the end with 0
            label_ids = tf.concat(
                [target_ids[:, 1:],
                 tf.zeros([batch_size, 1], dtype=tf.int32)], axis=-1)

            # forward once
            if is_training:
                target_mask = tf.math.equal(target_ids, 0)  # (N, T2)
                logits = _forward(
                    target_ids, target_mask, target_max_seq_length)

                self.preds['MT'] = tf.argmax(logits, axis=-1)

            # forward loop
            else:
                # NOTE(review): `_forward` is invoked once per step, so the
                # variables it creates have to be reusable across calls
                # (e.g. via an enclosing AUTO_REUSE scope) -- confirm.
                target_mask_base = tf.zeros([batch_size, 1], dtype=tf.int32)
                target_ids = tf.ones([batch_size, 1], dtype=tf.int32) * sos_id

                for cur_length in range(1, target_max_seq_length + 1):
                    target_mask = tf.tile(target_mask_base, [1, cur_length])
                    logits = _forward(target_ids, target_mask, cur_length)

                    # Keep only the prediction for the newest position and
                    # append it to the decoder input of the next step.
                    pred_ids = tf.argmax(
                        logits[:, cur_length-1:cur_length, :],
                        axis=-1)
                    pred_ids = tf.cast(pred_ids, tf.int32)
                    target_ids = tf.concat([target_ids, pred_ids], axis=-1)

                # Drop the leading start-of-sequence id.
                self.preds['MT'] = target_ids[:, 1:]

            # loss
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            one_hot_labels = tf.one_hot(label_ids, depth=vocab_size)
            if use_label_smoothing:
                one_hot_labels = label_smoothing(one_hot_labels)
            per_token_loss = -tf.reduce_sum(
                one_hot_labels * log_probs, axis=-1)
            label_mask = tf.cast(tf.not_equal(label_ids, 0), tf.float32)
            # Mean token loss per example over the labels that are not 0.
            per_example_loss = \
                tf.reduce_sum(per_token_loss * label_mask, axis=-1) / \
                tf.reduce_sum(label_mask, axis=-1)
            if sample_weight is not None:
                # `per_example_loss` is a vector with one entry per example,
                # so the weights are applied element-wise. Adding a trailing
                # axis to the weights would broadcast the product to
                # [batch_size, batch_size].
                per_example_loss *= tf.cast(sample_weight, dtype=tf.float32)

            self.total_loss = tf.reduce_mean(per_example_loss)
            self.losses['MT'] = per_example_loss
Ejemplo n.º 25
0
    def __init__(self,
                 is_training,
                 input_tensor,
                 n_wide_features,
                 wide_features,
                 label_ids,
                 label_size=2,
                 sample_weight=None,
                 scope='cls/seq_relationship',
                 hidden_dropout_prob=0.1,
                 initializer_range=0.02,
                 trainable=True,
                 **kwargs):
        '''Builds a wide-and-deep classification head.

        The deep vector attends over the embeddings of the wide features,
        the attended summary is added back to the deep vector and
        layer-normalised, and the result is fed to a linear classifier.
        Outputs go to `self.preds['preds']`, `self.probs['probs']`,
        `self.losses['losses']` and `self.total_loss`.

        Args:
          is_training: bool; dropout before the classifier is only active
            when True.
          input_tensor: deep representation, [batch_size, hidden_size].
          n_wide_features: number of valid wide features per example, used
            to build a sequence mask over the feature axis.
          wide_features: integer feature ids, [batch_size, n_features].
          label_ids: class ids, one per example.
          label_size: int, number of classes.
          sample_weight: (optional) per-example loss weights.
          scope: variable scope of the classifier weights.
          hidden_dropout_prob: dropout rate applied before the classifier.
          initializer_range: range of the weight initializer.
          trainable: whether the created variables are trainable.
          **kwargs: forwarded to the parent constructor. If it contains
            `tsa_thresh` (a float), the loss of every example whose
            normalised prediction entropy is not above it is zeroed.
        '''
        super().__init__(**kwargs)

        deep_width = input_tensor.shape.as_list()[-1]
        n_slots = wide_features.shape.as_list()[-1]

        with tf.variable_scope('wide'):
            embedding_table = tf.get_variable(
                name='feature_embeddings',
                shape=[n_slots + 1, deep_width],
                initializer=util.create_initializer(initializer_range),
                trainable=trainable)
            # [B, N, H]
            wide_output = tf.gather(embedding_table, wide_features)

        with tf.variable_scope('wide_and_deep'):
            # Score every wide feature against the deep vector:
            # [B, N, H] x [B, H, 1] -> [B, N, 1] -> [B, 1, N]
            deep_column = tf.expand_dims(input_tensor, -1)
            scores = tf.transpose(
                tf.matmul(wide_output, deep_column), [0, 2, 1])
            scores = tf.multiply(scores, 1.0 / math.sqrt(deep_width))

            # Push the scores of slots beyond `n_wide_features` far down so
            # that they vanish in the softmax. Mask shape: [B, 1, N].
            valid_slots = tf.expand_dims(
                tf.cast(tf.sequence_mask(n_wide_features, n_slots),
                        tf.float32),
                1)
            scores += (1.0 - valid_slots) * -10000.0
            slot_weights = tf.nn.softmax(scores, axis=-1)

            # [B, 1, N] x [B, N, H] -> [B, 1, H] -> [B, H]
            # NOTE: no dropout is applied to the attended summary.
            attended = tf.matmul(slot_weights, wide_output)[:, 0, :]
            input_tensor = util.layer_norm(attended + input_tensor,
                                           trainable=trainable)

        with tf.variable_scope(scope):
            output_weights = tf.get_variable(
                'output_weights',
                shape=[label_size, deep_width],
                initializer=util.create_initializer(initializer_range),
                trainable=trainable)
            output_bias = tf.get_variable('output_bias',
                                          shape=[label_size],
                                          initializer=tf.zeros_initializer(),
                                          trainable=trainable)

            keep_rate_complement = hidden_dropout_prob if is_training else 0.0
            classifier_input = util.dropout(input_tensor, keep_rate_complement)
            logits = tf.nn.bias_add(
                tf.matmul(classifier_input, output_weights, transpose_b=True),
                output_bias)

            self.preds['preds'] = tf.argmax(logits, axis=-1)
            self.probs['probs'] = tf.nn.softmax(logits, axis=-1, name='probs')

            # Cross-entropy against the one-hot labels.
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            targets = tf.one_hot(label_ids, depth=label_size, dtype=tf.float32)
            per_example_loss = -tf.reduce_sum(targets * log_probs, axis=-1)
            if sample_weight is not None:
                per_example_loss = (
                    tf.cast(sample_weight, dtype=tf.float32) *
                    per_example_loss)

            thresh = kwargs.get('tsa_thresh')
            if thresh is not None:
                assert isinstance(
                    thresh,
                    float), ('`tsa_thresh` must be a float between 0 and 1.')
                # Entropy of the prediction divided by the entropy of the
                # uniform distribution over `label_size` classes.
                probs = self.probs['probs']
                uncertainty = tf.reduce_sum(probs * tf.log(probs), axis=-1)
                uncertainty /= tf.log(1 / label_size)
                keep = tf.cast(tf.greater(uncertainty, thresh),
                               dtype=tf.float32)
                per_example_loss = keep * per_example_loss

            self.losses['losses'] = per_example_loss
            self.total_loss = tf.reduce_mean(per_example_loss)
Ejemplo n.º 26
0
    def _get_generator_output(self, inputs, sample_weight, generator):
        '''Builds the generator's masked-LM softmax head and its loss.

        Args:
          inputs: object with `masked_lm_positions`, `masked_lm_ids` and
            `masked_lm_weights` attributes.
          sample_weight: (optional) per-example weights, multiplied into
            the masked-LM weights after a trailing axis is added.
          generator: model providing `get_sequence_output()` and
            `get_embedding_table()`.

        Returns:
          An `MLMOutput` namedtuple with the fields `logits`, `probs`,
          `loss` (weighted mean, scalar), `per_example_loss` (weighted,
          one entry per masked position) and `preds`.
        '''
        def gather_indexes(sequence_tensor, positions):
            '''Picks the vectors at `positions` out of a rank-3 tensor.'''
            batch_size, seq_length, width = util.get_shape_list(
                sequence_tensor, 3)

            # Offset of each sequence inside the flattened matrix.
            row_starts = tf.reshape(
                tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
            flat_positions = tf.reshape(positions + row_starts, [-1])
            flat_sequence = tf.reshape(sequence_tensor,
                                       [batch_size * seq_length, width])
            return tf.gather(flat_sequence, flat_positions)

        hidden = gather_indexes(generator.get_sequence_output(),
                                inputs.masked_lm_positions)

        with tf.variable_scope('generator_predictions'):
            # Transform, then normalise, the gathered hidden states.
            hidden = tf.layers.dense(
                hidden,
                units=self.config.embedding_size,
                activation=util.get_activation(self.bert_config.hidden_act),
                kernel_initializer=util.create_initializer(
                    self.bert_config.initializer_range))
            hidden = util.layer_norm(hidden)

            vocab_size = self.bert_config.vocab_size
            output_bias = tf.get_variable('output_bias',
                                          shape=[vocab_size],
                                          initializer=tf.zeros_initializer())

            # Project onto the vocabulary with the generator's embeddings.
            logits = tf.nn.bias_add(
                tf.matmul(hidden,
                          generator.get_embedding_table(),
                          transpose_b=True),
                output_bias)
            probs = tf.nn.softmax(logits, axis=-1, name='MLM_probs')
            preds = tf.argmax(logits, axis=-1)
            log_probs = tf.nn.log_softmax(logits, axis=-1)

            flat_labels = tf.reshape(inputs.masked_lm_ids, [-1])
            token_weights = inputs.masked_lm_weights
            if sample_weight is not None:
                example_weight = tf.expand_dims(
                    tf.cast(sample_weight, dtype=tf.float32), axis=-1)
                token_weights *= example_weight
            flat_weights = tf.reshape(token_weights, [-1])

            targets = tf.one_hot(flat_labels,
                                 depth=vocab_size,
                                 dtype=tf.float32)
            per_example_loss = flat_weights * (
                -tf.reduce_sum(log_probs * targets, axis=[-1]))

            # Weighted mean; the epsilon avoids a division by zero when all
            # weights are zero.
            weighted_total = tf.reduce_sum(per_example_loss)
            weight_total = tf.reduce_sum(flat_weights) + 1e-6
            loss = weighted_total / weight_total

            MLMOutput = collections.namedtuple(
                'MLMOutput',
                ['logits', 'probs', 'loss', 'per_example_loss', 'preds'])
            return MLMOutput(logits=logits,
                             probs=probs,
                             loss=loss,
                             per_example_loss=per_example_loss,
                             preds=preds)
Ejemplo n.º 27
0
def attention_layer(from_tensor,
                    to_tensor,
                    attention_mask=None,
                    num_attention_heads=1,
                    size_per_head=512,
                    query_act=None,
                    key_act=None,
                    value_act=None,
                    attention_probs_dropout_prob=0.0,
                    initializer_range=0.02,
                    do_return_2d_tensor=False,
                    batch_size=None,
                    from_seq_length=None,
                    to_seq_length=None):
    '''Performs multi-headed attention from `from_tensor` to `to_tensor`.

    Implements the multi-headed attention of 'Attention is all you Need'.
    When `from_tensor` and `to_tensor` are the same tensor this is
    self-attention. Every timestep of `from_tensor` attends to the whole
    `to_tensor` sequence and yields a fixed-width vector.

    `from_tensor` is projected to a 'query' tensor and `to_tensor` to 'key'
    and 'value' tensors. Queries and keys are dot-producted, scaled and
    softmaxed into attention probabilities, which then weight the values.
    The heads are handled with reshapes and transposes instead of separate
    tensors, and are concatenated again in the output.

    Args:
      from_tensor: float Tensor of shape [batch_size, from_seq_length,
        from_width], or its rank-2 flattening.
      to_tensor: float Tensor of shape [batch_size, to_seq_length,
        to_width], or its rank-2 flattening.
      attention_mask: (optional) int32 Tensor of shape [batch_size,
        from_seq_length, to_seq_length] holding 1 or 0. Scores at
        positions holding 0 get -10000.0 added before the softmax, which
        effectively removes them; positions holding 1 are unchanged.
      num_attention_heads: int. Number of attention heads.
      size_per_head: int. Size of each attention head.
      query_act: (optional) Activation function for the query transform.
      key_act: (optional) Activation function for the key transform.
      value_act: (optional) Activation function for the value transform.
      attention_probs_dropout_prob: (optional) float. Dropout probability
        of the attention probabilities.
      initializer_range: float. Range of the weight initializer.
      do_return_2d_tensor: bool. Selects the shape of the context output,
        see Returns.
      batch_size: (Optional) int. Required when the inputs are rank 2:
        batch size of their rank-3 version.
      from_seq_length: (Optional) Required when the inputs are rank 2:
        sequence length of the rank-3 version of `from_tensor`.
      to_seq_length: (Optional) Required when the inputs are rank 2:
        sequence length of the rank-3 version of `to_tensor`.

    Returns:
      A tuple `(context_layer, attention_probs)`:
        context_layer: float Tensor of shape [batch_size, from_seq_length,
          num_attention_heads * size_per_head], or [batch_size *
          from_seq_length, num_attention_heads * size_per_head] when
          `do_return_2d_tensor` is true.
        attention_probs: float Tensor of shape [batch_size,
          num_attention_heads, from_seq_length, to_seq_length], taken
          after dropout.

    Raises:
      ValueError: Any of the arguments or tensor shapes are invalid.
    '''
    # Dimension shorthand used in the comments below:
    #   B = batch size (number of sequences)
    #   F = `from_tensor` sequence length
    #   T = `to_tensor` sequence length
    #   N = `num_attention_heads`
    #   H = `size_per_head`
    projection_width = num_attention_heads * size_per_head

    def _project(matrix, activation, name):
        '''Dense projection of a [rows, width] matrix to [rows, N*H].'''
        return tf.layers.dense(
            matrix,
            projection_width,
            activation=activation,
            name=name,
            kernel_initializer=util.create_initializer(initializer_range))

    def _split_heads(matrix, seq_length):
        '''Reshapes [B*S, N*H] to [B, S, N, H] and moves heads to axis 1.'''
        per_head = tf.reshape(
            matrix,
            [batch_size, seq_length, num_attention_heads, size_per_head])
        return tf.transpose(per_head, [0, 2, 1, 3])

    from_shape = util.get_shape_list(from_tensor, expected_rank=[2, 3])
    to_shape = util.get_shape_list(to_tensor, expected_rank=[2, 3])

    rank = len(from_shape)
    if rank != len(to_shape):
        raise ValueError(
            'The rank of `from_tensor` must match the rank of `to_tensor`.')

    if rank == 3:
        # Rank-3 inputs carry their own dimensions; these override the
        # optional arguments.
        batch_size = from_shape[0]
        from_seq_length = from_shape[1]
        to_seq_length = to_shape[1]
    elif rank == 2:
        explicit_dims = (batch_size, from_seq_length, to_seq_length)
        if any(dim is None for dim in explicit_dims):
            raise ValueError(
                'When passing in rank 2 tensors to attention_layer, the values '
                'for `batch_size`, `from_seq_length`, and `to_seq_length` '
                'must all be specified.')

    from_matrix = util.reshape_to_matrix(from_tensor)
    to_matrix = util.reshape_to_matrix(to_tensor)

    # Projections: query is [B*F, N*H], key and value are [B*T, N*H].
    query_layer = _project(from_matrix, query_act, 'query')
    key_layer = _project(to_matrix, key_act, 'key')
    value_layer = _project(to_matrix, value_act, 'value')

    # [B, N, F, H] and [B, N, T, H]
    query_layer = _split_heads(query_layer, from_seq_length)
    key_layer = _split_heads(key_layer, to_seq_length)

    # Raw attention scores: scaled query/key dot products, [B, N, F, T].
    attention_scores = tf.multiply(
        tf.matmul(query_layer, key_layer, transpose_b=True),
        1.0 / math.sqrt(float(size_per_head)))

    if attention_mask is not None:
        # [B, 1, F, T], so that the mask broadcasts over the heads.
        head_mask = tf.expand_dims(attention_mask, axis=[1])

        # 0.0 where attention is allowed and -10000.0 where it is not.
        # Added before the softmax, this is effectively the same as
        # removing the masked positions entirely.
        attention_scores += (1.0 - tf.cast(head_mask, tf.float32)) * -10000.0

    # Normalize the scores to probabilities, [B, N, F, T].
    attention_probs = tf.nn.softmax(attention_scores)

    # This drops out entire tokens to attend to, which might seem a bit
    # unusual, but is taken from the original Transformer paper.
    attention_probs = util.dropout(attention_probs,
                                   attention_probs_dropout_prob)

    # [B*T, N*H] -> [B, N, T, H]
    value_layer = _split_heads(value_layer, to_seq_length)

    # Weighted sum of the values: [B, N, F, H], then back to [B, F, N, H].
    context_layer = tf.transpose(
        tf.matmul(attention_probs, value_layer), [0, 2, 1, 3])

    # Merge the heads: [B*F, N*H] or [B, F, N*H].
    if do_return_2d_tensor:
        merged_shape = [batch_size * from_seq_length, projection_width]
    else:
        merged_shape = [batch_size, from_seq_length, projection_width]
    context_layer = tf.reshape(context_layer, merged_shape)

    return context_layer, attention_probs
Ejemplo n.º 28
0
                def _build_forward(layer_input):
                    """Assemble one encoder layer on top of `layer_input`.

                    The input is multiplied by `input_mask` (expanded on
                    the last axis and cast to float32), passed to an
                    `Attention` layer as both of its positional inputs, and
                    then sent through two residual sub-blocks that each end
                    in dropout followed by layer normalization. Every
                    hyperparameter is read from the enclosing scope.

                    Returns the layer-normalized output of the second
                    (feed-forward) sub-block.
                    """

                    def _project(inputs, units, activation=None):
                        # Dense layer using the initializer range and the
                        # trainability flag captured from the outer scope.
                        return tf.layers.dense(
                            inputs,
                            units,
                            activation=activation,
                            kernel_initializer=util.create_initializer(
                                initializer_range),
                            trainable=trainable)

                    with tf.variable_scope('attention'):
                        with tf.variable_scope('self'):
                            # Positions where `input_mask` is zero become
                            # zero vectors before attention is applied.
                            mask = tf.cast(
                                tf.expand_dims(input_mask, axis=-1),
                                dtype=tf.float32)
                            masked_input = layer_input * mask

                            # A projection matrix type is only requested
                            # when `nb_random_features` is truthy.
                            random_features = self.nb_random_features
                            matrix_type = True if random_features else None

                            attention = Attention(
                                hidden_size=hidden_size,
                                num_heads=num_attention_heads,
                                attention_dropout=attention_probs_dropout_prob,
                                kernel_transformation=self.kernel_transformation,
                                numerical_stabilizer=0.001,
                                causal=False,
                                projection_matrix_type=matrix_type,
                                nb_random_features=random_features)
                            attention.build(masked_input.shape)
                            context = attention.call(
                                masked_input,
                                masked_input,
                                bias=None,
                                training=is_training,
                                cache=None,
                                decode_loop_step=None)

                        # Project the attention result, then residual
                        # connection with the masked input + layer norm.
                        with tf.variable_scope('output'):
                            projected = _project(context, hidden_size)
                            projected = util.dropout(projected,
                                                     hidden_dropout_prob)
                            attention_output = util.layer_norm(
                                projected + masked_input,
                                trainable=trainable)

                    # Only the `intermediate` projection gets an activation.
                    with tf.variable_scope('intermediate'):
                        expanded = _project(attention_output,
                                            intermediate_size,
                                            activation=intermediate_act_fn)

                    # Project back to `hidden_size`, then residual
                    # connection with the attention output + layer norm.
                    with tf.variable_scope('output'):
                        contracted = _project(expanded, hidden_size)
                        contracted = util.dropout(contracted,
                                                  hidden_dropout_prob)
                        return util.layer_norm(
                            contracted + attention_output,
                            trainable=trainable)
Ejemplo n.º 29
0
    def __init__(self,
                 bert_config,
                 is_training,
                 dilated_ids,
                 label_ids,
                 max_seq_length,
                 spad_id=1,
                 loop=3,
                 sample_weight=None,
                 scope='dilated',
                 use_tilda_embedding=False,
                 **kwargs):
        """Build the dilated language-model graph.

        When `is_training` is true, one forward pass is run; the masked
        token-level cross-entropy is stored in `self.total_loss` and
        `self.losses['LM']`, and the argmax predictions in
        `self.preds['LM']`. Otherwise the model is applied `loop` times,
        re-dilating its own predictions between passes, and the final
        dilated ids are stored in `self.preds['LM']`.

        Args:
            bert_config: Config forwarded to `self._bert_forward`; its
                `vocab_size` sets the one-hot depth of the labels.
            is_training: Python bool selecting the training branch.
            dilated_ids: Rank-2 id tensor; id 0 is treated as padding.
            label_ids: Target ids, used only in the training branch.
            max_seq_length: Undilated sequence length. The dilated length
                is presumably `2 * max_seq_length` -- TODO confirm.
            spad_id: Id of the special padding token that temporarily
                stands in for removed (zero) predictions.
            loop: Number of decoding passes in the inference branch.
            sample_weight: Optional per-example loss weights.
            scope: Variable scope that wraps the model.
            use_tilda_embedding: If true, fetch the existing root-scope
                variable `tilda_embeddings` and pass it to the encoder.
            **kwargs: Accepted but not used in this method.
        """
        super().__init__()

        dilated_mask = tf.cast(tf.not_equal(dilated_ids, 0), tf.float32)

        shape = util.get_shape_list(dilated_ids, expected_rank=2)
        batch_size = shape[0]
        dilated_seq_length = shape[1]

        # Tilda embeddings for SMART algorithm
        tilda_embeddings = None
        if use_tilda_embedding:
            with tf.variable_scope('', reuse=True):
                tilda_embeddings = tf.get_variable('tilda_embeddings')

        with tf.variable_scope(scope):

            # forward once
            if is_training:
                logits = self._bert_forward(bert_config,
                                            dilated_ids,
                                            dilated_mask,
                                            batch_size,
                                            dilated_seq_length,
                                            tilda_embeddings=tilda_embeddings)

                self.preds['LM'] = tf.argmax(logits, axis=-1)

                # LM loss
                log_probs = tf.nn.log_softmax(logits, axis=-1)
                one_hot_labels = tf.one_hot(label_ids,
                                            depth=bert_config.vocab_size)
                per_token_loss = -tf.reduce_sum(one_hot_labels * log_probs,
                                                axis=-1)

                # Every real token owns two dilated slots, so the first
                # `2 * (number of non-zero ids)` positions are scored.
                input_length = tf.reduce_sum(dilated_mask, axis=-1) * 2
                label_mask = tf.sequence_mask(input_length,
                                              max_seq_length * 2,
                                              dtype=tf.float32)
                per_example_loss = \
                    tf.reduce_sum(per_token_loss * label_mask, axis=-1) / \
                    tf.reduce_sum(label_mask, axis=-1)
                if sample_weight is not None:
                    # `per_example_loss` is rank-1 ([batch]). The weights
                    # are flattened to rank-1 as well so the product stays
                    # [batch]; expanding them to [batch, 1] would broadcast
                    # the loss to [batch, batch].
                    per_example_loss *= tf.reshape(
                        tf.cast(sample_weight, tf.float32), [-1])

                self.total_loss = tf.reduce_mean(per_example_loss)
                self.losses['LM'] = per_example_loss

            # forward loop
            else:

                def _forward(dilated_ids, dilated_mask):
                    """Run one decode pass and re-dilate the predictions."""

                    logits = self._bert_forward(
                        bert_config,
                        dilated_ids,
                        dilated_mask,
                        batch_size,
                        dilated_seq_length,
                        tilda_embeddings=tilda_embeddings)
                    output_ids = tf.argmax(logits, axis=-1)
                    output_ids = tf.cast(output_ids, dtype=tf.int32)

                    # special padding (using `spad` token): append one
                    # `spad` per predicted zero, so that each row still
                    # holds exactly `dilated_seq_length` non-zero ids once
                    # the zeros are removed below.
                    equal_zero = tf.cast(tf.equal(output_ids, 0), tf.int32)
                    equal_zero = tf.reduce_sum(equal_zero, axis=-1)
                    right_pad = spad_id * tf.sequence_mask(
                        equal_zero, dilated_seq_length, dtype=tf.int32)
                    padded = tf.concat([output_ids, right_pad], axis=-1)

                    # extract ids of length `max_seq_length`: drop zeros
                    # from the flattened batch, restore the row structure,
                    # then truncate. `tf.boolean_mask` is given the boolean
                    # comparison result directly, as its API documents.
                    flattened_padded = tf.reshape(padded, [-1])
                    is_valid = tf.greater(flattened_padded, 0)
                    flattened_valid = tf.boolean_mask(flattened_padded,
                                                      is_valid)
                    valid = tf.reshape(flattened_valid,
                                       [batch_size, dilated_seq_length])
                    truncated_valid = valid[:, :max_seq_length]

                    # replace `spad` token with `pad`
                    non_spad_mask = tf.cast(tf.not_equal(
                        truncated_valid, spad_id),
                                            dtype=tf.int32)
                    output_ids = truncated_valid * non_spad_mask
                    output_length = tf.reduce_sum(non_spad_mask, axis=-1)

                    # dilate: interleave a zero slot after every token (and
                    # after every mask entry), doubling the sequence length.
                    reshaped_ids = tf.reshape(output_ids,
                                              [batch_size, max_seq_length, 1])
                    reshaped_mask = tf.reshape(
                        tf.sequence_mask(output_length,
                                         max_seq_length,
                                         dtype=tf.int32),
                        [batch_size, max_seq_length, 1])
                    concat_ids = tf.concat(
                        [reshaped_ids,
                         tf.zeros_like(reshaped_ids)], axis=-1)
                    concat_mask = tf.concat([
                        reshaped_mask,
                        tf.zeros_like(reshaped_mask, dtype=tf.int32)
                    ],
                                            axis=-1)
                    dilated_ids = tf.reshape(concat_ids,
                                             [batch_size, max_seq_length * 2])
                    dilated_mask = tf.reshape(concat_mask,
                                              [batch_size, max_seq_length * 2])

                    return dilated_ids, dilated_mask

                # Feed each pass's dilated output back in as the next input.
                for _ in range(loop):
                    dilated_ids, dilated_mask = _forward(
                        dilated_ids, dilated_mask)

                self.preds['LM'] = dilated_ids
Ejemplo n.º 30
0
    def _lm_forward(self,
                    is_training,
                    input_tensor,
                    input_mask,
                    label_ids,
                    bert_config,
                    batch_size,
                    max_seq_length,
                    prob,
                    scope,
                    name,
                    sample_weight=None,
                    hidden_dropout_prob=0.1,
                    initializer_range=0.02):
        """Build a verifier head and a token-prediction head.

        The verifier is a 2-way per-token classifier whose target is
        `label_ids > 0`. The prediction head projects `input_tensor`
        through two dense layers and scores the result against
        `self.embedding_table` to produce vocabulary logits; its loss is
        restricted to positions the verifier predicts as class 1.

        Results are written to `self.losses[name + '_loss']` and
        `self.preds[name + '_preds']`; when `prob != 0` both mean losses
        are added to `self.total_loss`.

        Args:
            is_training: Not used in the visible body.
            input_tensor: Encoder output fed to both heads; assumed to be
                [batch, max_seq_length, hidden_size] -- TODO confirm.
            input_mask: Per-token mask, cast to float32 here.
            label_ids: Per-token target ids; values above 0 are the
                verifier's positive class.
            bert_config: Supplies `initializer_range`, `hidden_size` and
                `vocab_size`.
            batch_size: Batch dimension used when flattening.
            max_seq_length: Sequence dimension used when reshaping.
            prob: When zero, the losses are not added to
                `self.total_loss`.
            scope: Outer variable scope for both heads.
            name: Prefix of the keys written to `self.losses` and
                `self.preds`.
            sample_weight: Optional per-example loss weights.
            hidden_dropout_prob: Not used in the visible body.
            initializer_range: Not used in the visible body;
                `bert_config.initializer_range` is used instead.
        """

        with tf.variable_scope(scope):

            with tf.variable_scope('verifier'):
                logits = tf.layers.dense(
                    input_tensor,
                    2,
                    kernel_initializer=util.create_initializer(
                        bert_config.initializer_range),
                    trainable=True)
                verifier_label_ids = tf.cast(tf.greater(label_ids, 0),
                                             tf.int32)

                # loss
                log_probs = tf.nn.log_softmax(logits, axis=-1)
                one_hot_labels = tf.one_hot(verifier_label_ids, depth=2)
                per_token_loss = -tf.reduce_sum(one_hot_labels * log_probs,
                                                axis=-1)

                # Average over the unmasked tokens of each example.
                input_mask = tf.cast(input_mask, tf.float32)
                per_token_loss *= input_mask / tf.reduce_sum(
                    input_mask, keepdims=True, axis=-1)
                per_example_loss = tf.reduce_sum(per_token_loss, axis=-1)
                if sample_weight is not None:
                    # `per_example_loss` is rank-1 ([batch]). The weights
                    # are flattened to rank-1 as well so the product stays
                    # [batch]; expanding them to [batch, 1] would broadcast
                    # the loss to [batch, batch].
                    per_example_loss *= tf.reshape(
                        tf.cast(sample_weight, tf.float32), [-1])

                if prob != 0:
                    self.total_loss += tf.reduce_mean(per_example_loss)
                verifier_loss = per_example_loss
                verifier_preds = tf.argmax(logits, axis=-1)

            with tf.variable_scope('prediction'):

                with tf.variable_scope('intermediate'):
                    logits = tf.layers.dense(
                        input_tensor,
                        bert_config.hidden_size * 4,
                        kernel_initializer=util.create_initializer(
                            bert_config.initializer_range),
                        activation=util.gelu,
                        trainable=True)
                with tf.variable_scope('output'):
                    logits = tf.layers.dense(
                        logits,
                        bert_config.hidden_size,
                        kernel_initializer=util.create_initializer(
                            bert_config.initializer_range),
                        trainable=True)

                # Score every position against the embedding table (used
                # transposed as the output projection).
                flattened = tf.reshape(
                    logits,
                    [batch_size * max_seq_length, bert_config.hidden_size])
                logits = tf.matmul(flattened,
                                   self.embedding_table,
                                   transpose_b=True)
                logits = tf.reshape(
                    logits, [-1, max_seq_length, bert_config.vocab_size])

                # loss
                log_probs = tf.nn.log_softmax(logits, axis=-1)
                one_hot_labels = tf.one_hot(label_ids,
                                            depth=bert_config.vocab_size)
                per_token_loss = -tf.reduce_sum(one_hot_labels * log_probs,
                                                axis=-1)

                # Only positions the verifier predicts as class 1 count;
                # the epsilon guards rows where no position is selected.
                input_mask *= tf.cast(verifier_preds, tf.float32)
                per_token_loss *= input_mask / (
                    tf.reduce_sum(input_mask, keepdims=True, axis=-1) + 1e-6)
                per_example_loss = tf.reduce_sum(per_token_loss, axis=-1)
                if sample_weight is not None:
                    # Same rank-1 weighting as for the verifier loss above.
                    per_example_loss *= tf.reshape(
                        tf.cast(sample_weight, tf.float32), [-1])

                if prob != 0:
                    self.total_loss += tf.reduce_mean(per_example_loss)
                # NOTE(review): only the verifier loss is exported here;
                # the prediction loss reaches `self.total_loss` only.
                self.losses[name + '_loss'] = verifier_loss
                # Token predictions are zeroed wherever the verifier
                # predicts class 0.
                self.preds[name + '_preds'] = \
                    tf.argmax(logits, axis=-1) * verifier_preds