Example #1
def crf_binary_score(tag_indices, sequence_lengths, transition_params):
    ''' Computes the binary scores of tag sequences.

    Args:
      tag_indices: A [batch_size, max_seq_len] matrix of tag indices.
      sequence_lengths: A [batch_size] vector of true sequence lengths.
      transition_params: A [num_tags, num_tags] matrix of binary potentials.

    Returns:
      binary_scores: A [batch_size] vector of binary scores.
    '''
    # Get shape information.
    num_tags = transition_params.get_shape()[0]
    num_transitions = tf.shape(tag_indices)[1] - 1

    # Truncate by one on each side of the sequence to get the start and end
    # indices of each transition.
    start_tag_indices = tf.slice(tag_indices, [0, 0], [-1, num_transitions])
    end_tag_indices = tf.slice(tag_indices, [0, 1], [-1, num_transitions])

    # Encode the indices in a flattened representation.
    flattened_transition_indices = \
        start_tag_indices * num_tags + end_tag_indices
    flattened_transition_params = tf.reshape(transition_params, [-1])

    # Get the binary scores based on the flattened representation.
    binary_scores = tf.gather(flattened_transition_params,
                              flattened_transition_indices)

    masks = tf.sequence_mask(sequence_lengths,
                             maxlen=tf.shape(tag_indices)[1],
                             dtype=tf.float32)
    truncated_masks = tf.slice(masks, [0, 1], [-1, -1])
    binary_scores = tf.reduce_sum(binary_scores * truncated_masks, 1)
    return binary_scores
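A minimal usage sketch (TensorFlow 1.x assumed, matching the tf.* API above): the transition 0 -> 2 is read from the flattened transition matrix at index 0 * num_tags + 2, and transitions beyond the true sequence length are dropped by the truncated mask.

import numpy as np
import tensorflow as tf

trans = tf.constant(np.arange(9, dtype=np.float32).reshape(3, 3))  # [num_tags, num_tags]
tags = tf.constant([[0, 2, 1]])   # one sequence with transitions 0->2 and 2->1
lens = tf.constant([3])
score = crf_binary_score(tags, lens, trans)
with tf.Session() as sess:
    print(sess.run(score))  # [9.] = trans[0, 2] + trans[2, 1] = 2.0 + 7.0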
Example #2
def crf_unary_score(tag_indices, sequence_lengths, inputs):
    ''' Computes the unary scores of tag sequences.

    Args:
      tag_indices: A [batch_size, max_seq_len] matrix of tag indices.
      sequence_lengths: A [batch_size] vector of true sequence lengths.
      inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials.

    Returns:
      unary_scores: A [batch_size] vector of unary scores.
    '''
    batch_size = tf.shape(inputs)[0]
    max_seq_len = tf.shape(inputs)[1]
    num_tags = tf.shape(inputs)[2]

    flattened_inputs = tf.reshape(inputs, [-1])

    offsets = tf.expand_dims(tf.range(batch_size) * max_seq_len * num_tags, 1)
    offsets += tf.expand_dims(tf.range(max_seq_len) * num_tags, 0)
    # Use int32 or int64 based on tag_indices' dtype.
    if tag_indices.dtype == tf.int64:
        offsets = tf.cast(offsets, tf.int64)
    flattened_tag_indices = tf.reshape(offsets + tag_indices, [-1])

    unary_scores = tf.reshape(
        tf.gather(flattened_inputs, flattened_tag_indices),
        [batch_size, max_seq_len])

    masks = tf.sequence_mask(sequence_lengths,
                             maxlen=tf.shape(tag_indices)[1],
                             dtype=tf.float32)

    unary_scores = tf.reduce_sum(unary_scores * masks, 1)
    return unary_scores
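A minimal usage sketch (TensorFlow 1.x assumed): each position contributes the emission potential of its chosen tag, and positions beyond the true length are zeroed out by the sequence mask.

import numpy as np
import tensorflow as tf

inputs = tf.constant(np.arange(12, dtype=np.float32).reshape(1, 4, 3))  # [batch, seq, tags]
tags = tf.constant([[0, 2, 1, 0]])
lens = tf.constant([3])            # the 4th position is padding
score = crf_unary_score(tags, lens, inputs)
with tf.Session() as sess:
    print(sess.run(score))  # [12.] = inputs[0, 0, 0] + inputs[0, 1, 2] + inputs[0, 2, 1]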
Example #3
def mask(inputs, key_masks=None, type=None):
    '''Masks key paddings or future positions in the inputs.
    inputs: 3d tensor. (h*N, T_q, T_k)
    key_masks: 3d tensor. (N, 1, T_k)
    type: string. 'key' | 'future'

    e.g.,
    >> inputs = tf.zeros([2, 2, 3], dtype=tf.float32)
    >> key_masks = tf.constant([[0., 0., 1.],
                                [0., 1., 1.]])
    >> mask(inputs, key_masks=key_masks, type='key')
    array([[[ 0.0000000e+00,  0.0000000e+00, -4.2949673e+09],
        [ 0.0000000e+00,  0.0000000e+00, -4.2949673e+09]],

       [[ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09],
        [ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09]],

       [[ 0.0000000e+00,  0.0000000e+00, -4.2949673e+09],
        [ 0.0000000e+00,  0.0000000e+00, -4.2949673e+09]],

       [[ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09],
        [ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09]]], dtype=float32)
    '''
    padding_num = -2 ** 32 + 1
    if type in ('k', 'key', 'keys'):
        key_masks = tf.to_float(key_masks)
        key_masks = tf.tile(
            key_masks,
            [tf.shape(inputs)[0] // tf.shape(key_masks)[0], 1]) # (h*N, seqlen)
        key_masks = tf.expand_dims(key_masks, 1)  # (h*N, 1, seqlen)
        outputs = inputs + key_masks * padding_num
    # elif type in ('q', 'query', 'queries'):
    #     # Generate masks
    #     masks = tf.sign(tf.reduce_sum(tf.abs(queries), axis=-1))  # (N, T_q)
    #     masks = tf.expand_dims(masks, -1)  # (N, T_q, 1)
    #     masks = tf.tile(masks, [1, 1, tf.shape(keys)[1]])  # (N, T_q, T_k)
    #
    #     # Apply masks to inputs
    #     outputs = inputs*masks
    elif type in ('f', 'future', 'right'):
        diag_vals = tf.ones_like(inputs[0, :, :])  # (T_q, T_k)
        tril = tf.linalg.LinearOperatorLowerTriangular(
            diag_vals).to_dense()  # (T_q, T_k)
        future_masks = tf.tile(
            tf.expand_dims(tril, 0),
            [tf.shape(inputs)[0], 1, 1])  # (N, T_q, T_k)

        paddings = tf.ones_like(future_masks) * padding_num
        outputs = tf.where(tf.equal(future_masks, 0), paddings, inputs)
    else:
        raise ValueError('Check if you entered type correctly!')

    return outputs
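A minimal usage sketch for the causal branch (TensorFlow 1.x assumed): entries above the diagonal are pushed to a large negative value so that a later softmax gives them (near) zero weight.

import tensorflow as tf

scores = tf.random_normal([2, 3, 3])   # (h*N, T_q, T_k) attention logits
masked = mask(scores, type='future')
with tf.Session() as sess:
    print(sess.run(masked)[0])         # upper triangle is about -4.29e9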
Example #4
def rel_attn_core(q_head, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat,
                  r_w_bias, r_r_bias, r_s_bias, attn_mask, dropatt,
                  is_training, scale):
    '''Core relative positional attention operations.'''

    # content based attention score
    ac = tf.einsum('ibnd,jbnd->ijbn', q_head + r_w_bias, k_head_h)

    # position based attention score
    bd = tf.einsum('ibnd,jbnd->ijbn', q_head + r_r_bias, k_head_r)
    bd = rel_shift(bd, klen=tf.shape(ac)[1])

    # segment based attention score
    if seg_mat is None:
        ef = 0
    else:
        ef = tf.einsum('ibnd,snd->ibns', q_head + r_s_bias, seg_embed)
        ef = tf.einsum('ijbs,ibns->ijbn', seg_mat, ef)

    # merge attention scores and perform masking
    attn_score = (ac + bd + ef) * scale
    if attn_mask is not None:
        # attn_score = attn_score * (1 - attn_mask) - 1e30 * attn_mask
        attn_score = attn_score - 1e30 * attn_mask

    # attention probability
    attn_prob = tf.nn.softmax(attn_score, 1)
    attn_prob = tf.layers.dropout(attn_prob, dropatt, training=is_training)

    # attention output
    attn_vec = tf.einsum('ijbn,jbnd->ibnd', attn_prob, v_head_h)

    return attn_vec
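A small, self-contained NumPy check of the einsum convention used above (nothing from the model is assumed): i indexes query positions, j key positions, b the batch, n the heads, and d the per-head dimension.

import numpy as np

qlen, klen, bsz, n_head, d_head = 3, 5, 2, 4, 8
q_head = np.random.randn(qlen, bsz, n_head, d_head)
k_head_h = np.random.randn(klen, bsz, n_head, d_head)
ac = np.einsum('ibnd,jbnd->ijbn', q_head, k_head_h)
print(ac.shape)  # (3, 5, 2, 4): one content score per (query, key, batch, head)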
Example #5
    def __init__(self,
                 is_training,
                 input_tensor,
                 input_mask,
                 label_ids,
                 label_size=2,
                 sample_weight=None,
                 scope='cls/sequence',
                 name='',
                 hidden_dropout_prob=0.1,
                 initializer_range=0.02,
                 trainable=True,
                 **kwargs):
        super().__init__(**kwargs)

        batch_size = tf.shape(input_tensor)[0]
        seq_length = input_tensor.shape.as_list()[-2]
        hidden_size = input_tensor.shape.as_list()[-1]
        with tf.variable_scope(scope):
            output_weights = tf.get_variable(
                'output_weights',
                shape=[label_size, hidden_size],
                initializer=util.create_initializer(initializer_range),
                trainable=trainable)
            output_bias = tf.get_variable('output_bias',
                                          shape=[label_size],
                                          initializer=tf.zeros_initializer(),
                                          trainable=trainable)

            output_layer = util.dropout(
                input_tensor, hidden_dropout_prob if is_training else 0.0)

            output_layer = tf.reshape(output_layer, [-1, hidden_size])
            logits = tf.matmul(output_layer, output_weights, transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)
            logits = tf.reshape(logits, [-1, seq_length, label_size])

            self.preds[name] = tf.argmax(logits, axis=-1)
            self.probs[name] = tf.nn.softmax(logits, axis=-1, name='probs')

            log_probs = tf.nn.log_softmax(logits, axis=-1)
            one_hot_labels = tf.one_hot(label_ids,
                                        depth=label_size,
                                        dtype=tf.float32)
            per_token_losses = -tf.reduce_mean(one_hot_labels * log_probs,
                                               axis=-1)
            input_mask = tf.concat([
                tf.zeros((batch_size, 1), dtype=tf.float32),
                tf.cast(input_mask[:, 2:], dtype=tf.float32),
                tf.zeros((batch_size, 1), dtype=tf.float32)
            ],
                                   axis=-1)
            per_token_losses *= input_mask
            per_example_loss = tf.reduce_mean(per_token_losses, axis=-1)
            if sample_weight is not None:
                per_example_loss *= tf.cast(sample_weight, dtype=tf.float32)

            self.losses[name] = per_example_loss
            self.total_loss = tf.reduce_mean(per_example_loss)
Example #6
def rel_shift(x, klen=-1):
    '''perform relative shift to form the relative attention score.'''
    x_size = tf.shape(x)

    x = tf.reshape(x, [x_size[1], x_size[0], x_size[2], x_size[3]])
    x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1])
    x = tf.reshape(x, [x_size[0], x_size[1] - 1, x_size[2], x_size[3]])
    x = tf.slice(x, [0, 0, 0, 0], [-1, klen, -1, -1])

    return x
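A minimal shape sketch (TensorFlow 1.x assumed): rel_shift realigns the position axis so that the same output column corresponds to the same relative distance for every query row, then truncates it to klen; here a [2, 4, 3, 1] input becomes [2, 3, 3, 1].

import numpy as np
import tensorflow as tf

x = tf.constant(np.arange(24, dtype=np.float32).reshape(2, 4, 3, 1))
y = rel_shift(x, klen=3)
with tf.Session() as sess:
    print(sess.run(tf.shape(y)))  # [2 3 3 1]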
Example #7
def _single_seq_fn():
    batch_size = tf.shape(inputs, out_type=tag_indices.dtype)[0]
    example_inds = tf.reshape(
        tf.range(batch_size, dtype=tag_indices.dtype), [-1, 1])
    sequence_scores = tf.gather_nd(
        tf.squeeze(inputs, [1]),
        tf.concat([example_inds, tag_indices], axis=1))
    sequence_scores = tf.where(tf.less_equal(sequence_lengths, 0),
                               tf.zeros_like(sequence_scores),
                               sequence_scores)
    return sequence_scores
Example #8
def summarize_sequence(summary_type, hidden, d_model, n_head, d_head, dropout,
                       dropatt, input_mask, is_training, initializer,
                       scope=None, reuse=None, use_proj=True):

    '''
    Different classification tasks may or may not share the same parameters
    to summarize the sequence features.

    If shared, one can keep the `scope` to the default value `None`.
    Otherwise, one should specify a different `scope` for each task.
    '''

    with tf.variable_scope(scope, 'sequnece_summary', reuse=reuse):
        if summary_type == 'last':
            summary = hidden[-1]
        elif summary_type == 'first':
            summary = hidden[0]
        elif summary_type == 'mean':
            summary = tf.reduce_mean(hidden, axis=0)
        elif summary_type == 'attn':
            bsz = tf.shape(hidden)[1]

            summary_bias = tf.get_variable('summary_bias', [d_model],
                                           dtype=hidden.dtype,
                                           initializer=initializer)
            summary_bias = tf.tile(summary_bias[None, None], [1, bsz, 1])

            if input_mask is not None:
                input_mask = input_mask[None, :, :, None]

            summary = multihead_attn(
                summary_bias, hidden, hidden, input_mask,
                d_model, n_head, d_head, dropout, dropatt,
                is_training, initializer, residual=False)
            summary = summary[0]
        else:
            raise ValueError('Unsupported summary type %s' % summary_type)

        # use another projection as in BERT
        if use_proj:
            summary = tf.layers.dense(
                summary,
                d_model,
                activation=tf.tanh,
                kernel_initializer=initializer,
                name='summary')

        # dropout
        summary = tf.layers.dropout(
            summary, dropout, training=is_training,
            name='dropout')

    return summary
Example #9
def positional_encoding(inputs,
                        maxlen,
                        masking=True,
                        scope='positional_encoding'):
    '''Sinusoidal Positional_Encoding. See 3.5
    inputs: 3d tensor. (N, T, E)
    maxlen: scalar. Must be >= T
    masking: Boolean. If True, padding positions are set to zeros.
    scope: Optional scope for `variable_scope`.

    returns
    3d tensor that has the same shape as inputs.
    '''

    E = inputs.get_shape().as_list()[-1] # static
    N, T = tf.shape(inputs)[0], tf.shape(inputs)[1] # dynamic
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        # position indices
        position_ind = tf.tile(tf.expand_dims(tf.range(T), 0), [N, 1]) # (N, T)

        # First part of the PE function: sin and cos argument
        position_enc = np.array([
            [pos / np.power(10000, (i-i%2)/E) for i in range(E)]
            for pos in range(maxlen)])

        # Second part: apply sin to even columns and cos to odd ones.
        position_enc[:, 0::2] = np.sin(position_enc[:, 0::2])  # dim 2i
        position_enc[:, 1::2] = np.cos(position_enc[:, 1::2])  # dim 2i+1
        position_enc = tf.convert_to_tensor(
            position_enc, tf.float32) # (maxlen, E)

        # lookup
        outputs = tf.nn.embedding_lookup(position_enc, position_ind)

        # masks
        if masking:
            outputs = tf.where(tf.equal(inputs, 0), inputs, outputs)

        return tf.to_float(outputs)
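A minimal usage sketch (TensorFlow 1.x assumed): the function returns the sinusoidal table gathered per position, which the caller then adds to the token embeddings; with masking=True, entries whose input values are exactly zero stay zero.

import numpy as np
import tensorflow as tf

x = tf.constant(np.random.randn(2, 5, 8), dtype=tf.float32)  # (N, T, E)
pe = positional_encoding(x, maxlen=10)                        # same shape as x
with tf.Session() as sess:
    print(sess.run(pe).shape)  # (2, 5, 8)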
Example #10
def crf_log_norm(inputs, sequence_lengths, transition_params):
    ''' Computes the normalization for a CRF.

    Args:
        inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
          to use as input to the CRF layer.
        sequence_lengths: A [batch_size] vector of true sequence lengths.
        transition_params: A [num_tags, num_tags] transition matrix.

    Returns:
        log_norm: A [batch_size] vector of normalizers for a CRF.
    '''
    # Split up the first and rest of the inputs in preparation for the forward
    # algorithm.
    first_input = tf.slice(inputs, [0, 0, 0], [-1, 1, -1])
    first_input = tf.squeeze(first_input, [1])

    # If max_seq_len is 1, we skip the algorithm and simply reduce_logsumexp
    # over the 'initial state' (the unary potentials).
    def _single_seq_fn():
        log_norm = tf.reduce_logsumexp(first_input, [1])
        # Mask `log_norm` of the sequences with length <= zero.
        log_norm = tf.where(tf.less_equal(sequence_lengths, 0),
                            tf.zeros_like(log_norm), log_norm)
        return log_norm

    def _multi_seq_fn():
        '''Forward computation of alpha values.'''
        rest_of_input = tf.slice(inputs, [0, 1, 0], [-1, -1, -1])

        # Compute the alpha values in the forward algorithm in order to get the
        # partition function.
        forward_cell = CrfForwardRnnCell(transition_params)
        # Sequence length is not allowed to be less than zero.
        sequence_lengths_less_one = tf.maximum(
            tf.constant(0, dtype=sequence_lengths.dtype), sequence_lengths - 1)
        _, alphas = rnn.dynamic_rnn(cell=forward_cell,
                                    inputs=rest_of_input,
                                    sequence_length=sequence_lengths_less_one,
                                    initial_state=first_input,
                                    dtype=tf.float32)
        log_norm = tf.reduce_logsumexp(alphas, [1])
        # Mask `log_norm` of the sequences with length <= zero.
        log_norm = tf.where(tf.less_equal(sequence_lengths, 0),
                            tf.zeros_like(log_norm), log_norm)
        return log_norm

    return smart.smart_cond(pred=tf.equal(
        util.get_shape_list(inputs)[1] or tf.shape(inputs)[1], 1),
                            true_fn=_single_seq_fn,
                            false_fn=_multi_seq_fn)
Example #11
def crf_sequence_score(inputs, tag_indices, sequence_lengths,
                       transition_params):
    ''' Computes the unnormalized score for a tag sequence.

    Args:
        inputs: A [batch_size, max_seq_len, num_tags] tensor of unary
          potentials to use as input to the CRF layer.
        tag_indices: A [batch_size, max_seq_len] matrix of tag indices for
          which we compute the unnormalized score.
        sequence_lengths: A [batch_size] vector of true sequence lengths.
        transition_params: A [num_tags, num_tags] transition matrix.

    Returns:
        sequence_scores: A [batch_size] vector of unnormalized sequence scores.
    '''

    # If max_seq_len is 1, we skip the score calculation and simply gather the
    # unary potentials of the single tag.
    def _single_seq_fn():
        batch_size = tf.shape(inputs, out_type=tag_indices.dtype)[0]
        example_inds = tf.reshape(
            tf.range(batch_size, dtype=tag_indices.dtype), [-1, 1])
        sequence_scores = tf.gather_nd(
            tf.squeeze(inputs, [1]),
            tf.concat([example_inds, tag_indices], axis=1))
        sequence_scores = tf.where(tf.less_equal(sequence_lengths, 0),
                                   tf.zeros_like(sequence_scores),
                                   sequence_scores)
        return sequence_scores

    def _multi_seq_fn():
        # Compute the scores of the given tag sequence.
        unary_scores = crf_unary_score(tag_indices, sequence_lengths, inputs)
        binary_scores = crf_binary_score(tag_indices, sequence_lengths,
                                         transition_params)
        sequence_scores = unary_scores + binary_scores
        return sequence_scores

    return smart.smart_cond(pred=tf.equal(
        util.get_shape_list(inputs)[1] or tf.shape(inputs)[1], 1),
                            true_fn=_single_seq_fn,
                            false_fn=_multi_seq_fn)
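A usage sketch tying the two functions together (TensorFlow 1.x assumed, and assuming the smart, util and rnn helpers plus CrfForwardRnnCell referenced above are available as in the source file): the CRF log-likelihood of a tag sequence is its unnormalized score minus the log partition function.

import tensorflow as tf

unary = tf.random_normal([4, 7, 5])                        # [batch, seq, tags]
tags = tf.random_uniform([4, 7], maxval=5, dtype=tf.int32)
lens = tf.constant([7, 7, 5, 3])
trans = tf.random_normal([5, 5])

seq_scores = crf_sequence_score(unary, tags, lens, trans)  # [batch]
log_norm = crf_log_norm(unary, lens, trans)                # [batch]
log_likelihood = seq_scores - log_norm                     # [batch], to be maximized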
Example #12
def get_shape_list(tensor, expected_rank=None, name=None):
    '''Returns a list of the shape of tensor, preferring static dimensions.'''
    if name is None:
        name = tensor.name

    if expected_rank is not None:
        assert_rank(tensor, expected_rank, name)

    shape = tensor.shape.as_list()

    non_static_indexes = []
    for (index, dim) in enumerate(shape):
        if dim is None:
            non_static_indexes.append(index)

    if not non_static_indexes:
        return shape

    dyn_shape = tf.shape(tensor)
    for index in non_static_indexes:
        shape[index] = dyn_shape[index]
    return shape
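A minimal usage sketch (TensorFlow 1.x assumed; assert_rank from the same module is only needed when expected_rank is passed): statically known dimensions come back as Python ints, unknown ones as scalar tensors.

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 128, 768])
print(get_shape_list(x))  # [<tf.Tensor ...>, 128, 768]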
Example #13
            def _forward(input_ids, past=None):
                batch, sequence = shape_list(input_ids)

                if tilda_embeddings is None:
                    wte = tf.get_variable(
                        'word_embeddings', [hparams.n_vocab, hparams.n_embed],
                        initializer=tf.random_normal_initializer(stddev=0.02))
                else:
                    wte = tilda_embeddings
                wpe = tf.get_variable(
                    'wpe', [hparams.n_ctx, hparams.n_embed],
                    initializer=tf.random_normal_initializer(stddev=0.01))
                past_length = 0 if past is None else tf.shape(past)[-2]
                h = (tf.gather(wte, input_ids) +
                     tf.gather(wpe, positions_for(input_ids, past_length)))

                # stacked transformer layers
                presents = []
                pasts = tf.unstack(past, axis=1) if past is not None else \
                    [None] * hparams.n_layer
                assert len(pasts) == hparams.n_layer
                for layer, past in enumerate(pasts):
                    h, present = block(h,
                                       'h%d' % layer,
                                       past=past,
                                       hparams=hparams)
                    presents.append(present)
                present = tf.stack(presents, axis=1)
                h = norm(h, 'ln_f')

                # Language model loss.  Do tokens <n predict token n?
                h_flat = tf.reshape(h, [batch * sequence, hparams.n_embed])
                logits = tf.matmul(h_flat, wte, transpose_b=True)
                logits = tf.reshape(logits, [batch, sequence, hparams.n_vocab])

                return logits, present
Example #14
def positions_for(tokens, past_length):
    batch_size = tf.shape(tokens)[0]
    nsteps = tf.shape(tokens)[1]
    return expand_tile(past_length + tf.range(nsteps), batch_size)
Example #15
def shape_list(x):
    '''Deal with dynamic shape in tensorflow cleanly.'''
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]
Example #16
def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
    '''
    Sample a permutation of the factorization order, and create an
    attention mask accordingly.

    Args:
      inputs: int64 Tensor in shape [batch_size, seq_len], input ids.
      targets: int64 Tensor in shape [batch_size, seq_len], target ids.
      is_masked: bool Tensor in shape [batch_size, seq_len]. True means being
        selected for partial prediction.
      perm_size: the length of the longest permutation. Can be set to
        reuse_len. Should not be larger than reuse_len or there will be data
        leaks.
      seq_len: int, sequence length.
    '''
    batch_size = tf.shape(inputs)[0]

    # Generate permutation indices
    index = tf.range(seq_len, dtype=tf.int64)
    index = tf.reshape(index, [-1, perm_size])
    index = tf.transpose(index)
    index = tf.random_shuffle(index)
    index = tf.transpose(index)
    index = tf.reshape(index, [1, -1])
    index = tf.tile(index, [batch_size, 1])

    # `perm_mask` and `target_mask`
    # non-functional tokens
    non_func_tokens = tf.logical_not(
        tf.logical_or(tf.equal(inputs, SEP_ID), tf.equal(inputs, CLS_ID)))

    non_mask_tokens = tf.logical_and(tf.logical_not(is_masked),
                                     non_func_tokens)
    masked_or_func_tokens = tf.logical_not(non_mask_tokens)

    # Set the permutation indices of non-masked (& non-functional) tokens to the
    # smallest index (-1):
    # (1) they can be seen by all other positions
    # (2) they cannot see masked positions, so there won't be information leak
    smallest_index = -tf.ones([batch_size, seq_len], dtype=tf.int64)
    rev_index = tf.where(non_mask_tokens, smallest_index, index)

    # Create `target_mask`: non-functional and masked tokens
    # 1: use mask as input and have loss
    # 0: use token (or [SEP], [CLS]) as input and do not have loss
    target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens)
    target_mask = tf.cast(target_tokens, tf.float32)

    # Create `perm_mask`
    # `target_tokens` cannot see themselves
    self_rev_index = tf.where(target_tokens, rev_index, rev_index + 1)

    # 1: cannot attend if i <= j and j is not non-masked (masked_or_func_tokens)
    # 0: can attend if i > j or j is non-masked
    perm_mask = tf.logical_and(
        self_rev_index[:, :, None] <= rev_index[:, None, :],
        tf.expand_dims(masked_or_func_tokens, axis=-1))

    # new target: [next token] for LM and [curr token] (self) for PLM
    new_targets = tf.concat([inputs[:, 0:1], targets[:, :-1]], axis=1)

    # construct inputs_k
    inputs_k = inputs

    # construct inputs_q
    inputs_q = target_mask

    return perm_mask, new_targets, target_mask, inputs_k, inputs_q
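A usage sketch (TensorFlow 1.x assumed): _local_perm reads SEP_ID and CLS_ID as module-level names, so the sketch defines placeholder values for them and is meant to live in the same module as the function. target_mask marks the masked, non-functional positions, while perm_mask is a [batch, seq, seq] attention mask drawn from a fresh random factorization order.

import tensorflow as tf

SEP_ID, CLS_ID = 4, 3   # assumed special-token ids, placeholders for this sketch

inputs = tf.constant([[CLS_ID, 11, 12, 13, SEP_ID, 14]], dtype=tf.int64)
targets = tf.constant([[11, 12, 13, SEP_ID, 14, 15]], dtype=tf.int64)
is_masked = tf.constant([[False, False, True, True, False, False]])
perm_mask, new_targets, target_mask, inp_k, inp_q = _local_perm(
    inputs, targets, is_masked, perm_size=3, seq_len=6)
with tf.Session() as sess:
    pm, tm = sess.run([perm_mask, target_mask])
print(pm.shape, tm)  # (1, 6, 6) [[0. 0. 1. 1. 0. 0.]]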
Example #17
def _expand_features(module, split_placeholders):

    inputs = split_placeholders['input']
    target = split_placeholders['target']
    is_masked = tf.cast(split_placeholders['is_masked'], tf.bool)
    batch_size = tf.shape(inputs)[0]

    non_reuse_len = module.max_seq_length - module.reuse_seq_length
    assert (module.perm_size <= module.reuse_seq_length
            and module.perm_size <= non_reuse_len)

    (perm_mask_0, target_0, target_mask_0, input_k_0, input_q_0) = \
        _local_perm(
            inputs[:, :module.reuse_seq_length],
            target[:, :module.reuse_seq_length],
            is_masked[:, :module.reuse_seq_length],
            module.perm_size,
            module.reuse_seq_length)

    (perm_mask_1, target_1, target_mask_1, input_k_1, input_q_1) = \
        _local_perm(
            inputs[:, module.reuse_seq_length:],
            target[:, module.reuse_seq_length:],
            is_masked[:, module.reuse_seq_length:],
            module.perm_size,
            non_reuse_len)

    perm_mask_0 = tf.concat([
        tf.cast(perm_mask_0, dtype=tf.float32),
        tf.ones([batch_size, module.reuse_seq_length, non_reuse_len])
    ],
                            axis=2)
    perm_mask_1 = tf.concat([
        tf.zeros([batch_size, non_reuse_len, module.reuse_seq_length]),
        tf.cast(perm_mask_1, dtype=tf.float32)
    ],
                            axis=2)
    perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=1)
    target = tf.concat([target_0, target_1], axis=1)
    target_mask = tf.concat([target_mask_0, target_mask_1], axis=1)
    input_k = tf.concat([input_k_0, input_k_1], axis=1)
    input_q = tf.concat([input_q_0, input_q_1], axis=1)

    if module._num_predict is not None:
        #TODO(geying): convert tensors from 1-D to 2-D

        indices = tf.range(module.max_seq_length, dtype=tf.int64)
        indices = tf.reshape(indices, [-1, module.max_seq_length])
        indices = tf.tile(indices, [batch_size, 1])
        bool_target_mask = tf.cast(target_mask, tf.bool)
        indices = tf.boolean_mask(indices, bool_target_mask)

        ##### extra padding due to CLS/SEP introduced after prepro
        actual_num_predict = tf.shape(indices)[1]
        pad_len = module._num_predict - actual_num_predict

        ##### target_mapping
        target_mapping = tf.one_hot(indices,
                                    module.max_seq_length,
                                    dtype=tf.float32)
        paddings = tf.zeros([pad_len, module.max_seq_length],
                            dtype=target_mapping.dtype)
        target_mapping = tf.concat([target_mapping, paddings], axis=0)
        split_placeholders['target_mapping'] = tf.reshape(
            target_mapping, [-1, module._num_predict, module.max_seq_length])

        ##### target
        target = tf.boolean_mask(target, bool_target_mask)
        paddings = tf.zeros([pad_len], dtype=target.dtype)
        target = tf.concat([target, paddings], axis=0)
        split_placeholders['target'] = tf.reshape(target,
                                                  [-1, module._num_predict])

        ##### target mask
        target_mask = tf.concat([
            tf.ones([batch_size, actual_num_predict], dtype=tf.float32),
            tf.zeros([batch_size, pad_len], dtype=tf.float32)
        ],
                                axis=1)
        split_placeholders['target_mask'] = tf.reshape(
            target_mask, [-1, module._num_predict])
    else:
        split_placeholders['target'] = tf.reshape(target,
                                                  [-1, module.max_seq_length])
        split_placeholders['target_mask'] = tf.reshape(
            target_mask, [-1, module.max_seq_length])

    # reshape back to fixed shape
    split_placeholders['perm_mask'] = tf.reshape(
        perm_mask, [-1, module.max_seq_length, module.max_seq_length])
    split_placeholders['input_k'] = tf.reshape(input_k,
                                               [-1, module.max_seq_length])
    split_placeholders['input_q'] = tf.reshape(input_q,
                                               [-1, module.max_seq_length])

    return split_placeholders
Example #18
def transformer_xl(inp_k,
                   n_token,
                   n_layer,
                   d_model,
                   n_head,
                   d_head,
                   d_inner,
                   dropout,
                   dropatt,
                   attn_type,
                   bi_data,
                   initializer,
                   is_training,
                   mem_len=None,
                   inp_q=None,
                   mems=None,
                   same_length=False,
                   clamp_len=-1,
                   untie_r=False,
                   use_tpu=True,
                   input_mask=None,
                   perm_mask=None,
                   seg_id=None,
                   reuse_len=None,
                   ff_activation='relu',
                   target_mapping=None,
                   use_bfloat16=False,
                   scope='transformer',
                   tilda_embeddings=None,
                   **kwargs):
    '''
    Defines a Transformer-XL computation graph with additional
    support for XLNet.

      Args:

      inp_k: int32 Tensor in shape [len, bsz], the input token IDs.
      seg_id: int32 Tensor in shape [len, bsz], the input segment IDs.
      input_mask: float32 Tensor in shape [len, bsz], the input mask.
          0 for real tokens and 1 for padding.
      mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
          from previous batches. The length of the list equals n_layer.
          If None, no memory is used.
      perm_mask: float32 Tensor in shape [len, len, bsz].
          If perm_mask[i, j, k] = 0, i attends to j in batch k;
          if perm_mask[i, j, k] = 1, i does not attend to j in batch k.
          If None, each position attends to all the others.
      target_mapping: float32 Tensor in shape [num_predict, len, bsz].
          If target_mapping[i, j, k] = 1, the i-th prediction in batch k is
          on the j-th token.
          Only used during pretraining for partial prediction.
          Set to None during finetuning.
      inp_q: float32 Tensor in shape [len, bsz].
          1 for tokens with losses and 0 for tokens without losses.
          Only used during pretraining for two-stream attention.
          Set to None during finetuning.

      n_layer: int, the number of layers.
      d_model: int, the hidden size.
      n_head: int, the number of attention heads.
      d_head: int, the dimension size of each attention head.
      d_inner: int, the hidden size in feed-forward layers.
      ff_activation: str, 'relu' or 'gelu'.
      untie_r: bool, whether to untie the biases in attention.
      n_token: int, the vocab size.

      is_training: bool, whether in training mode.
      use_tpu: bool, whether TPUs are used.
      use_bfloat16: bool, use bfloat16 instead of float32.
      dropout: float, dropout rate.
      dropatt: float, dropout rate on attention probabilities.
      init: str, the initialization scheme, either 'normal' or 'uniform'.
      init_range: float, initialize the parameters with a uniform distribution
          in [-init_range, init_range]. Only effective when init='uniform'.
      init_std: float, initialize the parameters with a normal distribution
          with mean 0 and stddev init_std. Only effective when init='normal'.
      mem_len: int, the number of tokens to cache.
      reuse_len: int, the number of tokens in the current batch to be cached
          and reused in the future.
      bi_data: bool, whether to use bidirectional input pipeline.
          Usually set to True during pretraining and False during finetuning.
      clamp_len: int, clamp all relative distances larger than clamp_len.
          -1 means no clamping.
      same_length: bool, whether to use the same attention length for each token.
      summary_type: str, 'last', 'first', 'mean', or 'attn'. The method
          to pool the input to get a vector representation.
      initializer: A tf initializer.
      scope: scope name for the computation graph.
    '''
    tf_float = tf.bfloat16 if use_bfloat16 else tf.float32

    new_mems = []
    with tf.variable_scope(scope):
        if untie_r:
            r_w_bias = tf.get_variable('r_w_bias', [n_layer, n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)
            r_r_bias = tf.get_variable('r_r_bias', [n_layer, n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)
        else:
            r_w_bias = tf.get_variable('r_w_bias', [n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)
            r_r_bias = tf.get_variable('r_r_bias', [n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)

        bsz = tf.shape(inp_k)[1]
        qlen = tf.shape(inp_k)[0]
        mlen = tf.shape(mems[0])[0] if mems is not None else 0
        klen = mlen + qlen

        ##### Attention mask
        # causal attention mask
        if attn_type == 'uni':
            attn_mask = _create_mask(qlen, mlen, tf_float, same_length)
            attn_mask = attn_mask[:, :, None, None]
        elif attn_type == 'bi':
            attn_mask = None
        else:
            raise ValueError('Unsupported attention type: %s' % attn_type)

        # data mask: input mask & perm mask
        if input_mask is not None and perm_mask is not None:
            data_mask = input_mask[None] + perm_mask
        elif input_mask is not None and perm_mask is None:
            data_mask = input_mask[None]
        elif input_mask is None and perm_mask is not None:
            data_mask = perm_mask
        else:
            data_mask = None

        if data_mask is not None:
            # all mems can be attended to
            mems_mask = tf.zeros([tf.shape(data_mask)[0], mlen, bsz],
                                 dtype=tf_float)
            data_mask = tf.cast(data_mask, dtype=tf.float32)
            data_mask = tf.concat([mems_mask, data_mask], 1)
            if attn_mask is None:
                attn_mask = data_mask[:, :, :, None]
            else:
                attn_mask += data_mask[:, :, :, None]

        if attn_mask is not None:
            attn_mask = tf.cast(attn_mask > 0, dtype=tf_float)

        if attn_mask is not None:
            non_tgt_mask = -tf.eye(qlen, dtype=tf_float)
            non_tgt_mask = tf.concat(
                [tf.zeros([qlen, mlen], dtype=tf_float), non_tgt_mask],
                axis=-1)
            non_tgt_mask = tf.cast(
                (attn_mask + non_tgt_mask[:, :, None, None]) > 0,
                dtype=tf_float)
        else:
            non_tgt_mask = None

        ##### Word embedding
        word_emb_k, lookup_table = embedding_lookup(
            x=inp_k,
            n_token=n_token,
            d_embed=d_model,
            initializer=initializer,
            use_tpu=use_tpu,
            dtype=tf_float,
            scope='word_embedding',
            tilda_embeddings=tilda_embeddings)

        if inp_q is not None:
            with tf.variable_scope('mask_emb'):
                mask_emb = tf.get_variable('mask_emb', [1, 1, d_model],
                                           dtype=tf_float)
                if target_mapping is not None:
                    word_emb_q = tf.tile(mask_emb,
                                         [tf.shape(target_mapping)[0], bsz, 1])
                else:
                    inp_q_ext = inp_q[:, :, None]
                    word_emb_q = \
                        inp_q_ext * mask_emb + (1 - inp_q_ext) * word_emb_k
        output_h = tf.layers.dropout(word_emb_k, dropout, training=is_training)
        if inp_q is not None:
            output_g = tf.layers.dropout(word_emb_q,
                                         dropout,
                                         training=is_training)

        ##### Segment embedding
        if seg_id is not None:
            if untie_r:
                r_s_bias = tf.get_variable('r_s_bias',
                                           [n_layer, n_head, d_head],
                                           dtype=tf_float,
                                           initializer=initializer)
            else:
                # default case (tie)
                r_s_bias = tf.get_variable('r_s_bias', [n_head, d_head],
                                           dtype=tf_float,
                                           initializer=initializer)

            seg_embed = tf.get_variable('seg_embed',
                                        [n_layer, 2, n_head, d_head],
                                        dtype=tf_float,
                                        initializer=initializer)

            # Convert `seg_id` to one-hot `seg_mat`
            mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
            cat_ids = tf.concat([mem_pad, seg_id], 0)

            # `1` indicates not in the same segment [qlen x klen x bsz]
            seg_mat = tf.cast(
                tf.logical_not(tf.equal(seg_id[:, None], cat_ids[None, :])),
                tf.int32)
            seg_mat = tf.one_hot(seg_mat, 2, dtype=tf_float)
        else:
            seg_mat = None

        ##### Positional encoding
        pos_emb = relative_positional_encoding(qlen,
                                               klen,
                                               d_model,
                                               clamp_len,
                                               attn_type,
                                               bi_data,
                                               bsz=bsz,
                                               dtype=tf_float)
        pos_emb = tf.layers.dropout(pos_emb, dropout, training=is_training)

        ##### Attention layers
        if mems is None:
            mems = [None] * n_layer

        for i in range(n_layer):
            # cache new mems
            new_mems.append(_cache_mem(output_h, mems[i], mem_len, reuse_len))

            # segment bias
            if seg_id is None:
                r_s_bias_i = None
                seg_embed_i = None
            else:
                r_s_bias_i = r_s_bias if not untie_r else r_s_bias[i]
                seg_embed_i = seg_embed[i]

            with tf.variable_scope('layer_{}'.format(i)):
                if inp_q is not None:
                    output_h, output_g = two_stream_rel_attn(
                        h=output_h,
                        g=output_g,
                        r=pos_emb,
                        r_w_bias=r_w_bias if not untie_r else r_w_bias[i],
                        r_r_bias=r_r_bias if not untie_r else r_r_bias[i],
                        seg_mat=seg_mat,
                        r_s_bias=r_s_bias_i,
                        seg_embed=seg_embed_i,
                        attn_mask_h=non_tgt_mask,
                        attn_mask_g=attn_mask,
                        mems=mems[i],
                        target_mapping=target_mapping,
                        d_model=d_model,
                        n_head=n_head,
                        d_head=d_head,
                        dropout=dropout,
                        dropatt=dropatt,
                        is_training=is_training,
                        kernel_initializer=initializer)
                    reuse = True
                else:
                    reuse = False

                    output_h = rel_multihead_attn(
                        h=output_h,
                        r=pos_emb,
                        r_w_bias=r_w_bias if not untie_r else r_w_bias[i],
                        r_r_bias=r_r_bias if not untie_r else r_r_bias[i],
                        seg_mat=seg_mat,
                        r_s_bias=r_s_bias_i,
                        seg_embed=seg_embed_i,
                        attn_mask=non_tgt_mask,
                        mems=mems[i],
                        d_model=d_model,
                        n_head=n_head,
                        d_head=d_head,
                        dropout=dropout,
                        dropatt=dropatt,
                        is_training=is_training,
                        kernel_initializer=initializer,
                        reuse=reuse)

                if inp_q is not None:
                    output_g = positionwise_ffn(inp=output_g,
                                                d_model=d_model,
                                                d_inner=d_inner,
                                                dropout=dropout,
                                                kernel_initializer=initializer,
                                                activation_type=ff_activation,
                                                is_training=is_training)

                output_h = positionwise_ffn(inp=output_h,
                                            d_model=d_model,
                                            d_inner=d_inner,
                                            dropout=dropout,
                                            kernel_initializer=initializer,
                                            activation_type=ff_activation,
                                            is_training=is_training,
                                            reuse=reuse)

        if inp_q is not None:
            output = tf.layers.dropout(output_g, dropout, training=is_training)
        else:
            output = tf.layers.dropout(output_h, dropout, training=is_training)

        return output, new_mems, lookup_table
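A call sketch, hedged rather than standalone (TensorFlow 1.x assumed, and assuming the XLNet helpers this function calls, such as embedding_lookup, relative_positional_encoding, _create_mask, _cache_mem, rel_multihead_attn, two_stream_rel_attn and positionwise_ffn, are defined in the same module as in the source file). Note the time-major [len, bsz] layout of the inputs; the hyperparameter values below are placeholders.

import tensorflow as tf

inp_k = tf.placeholder(tf.int32, [128, 8])          # [len, bsz] token ids
input_mask = tf.placeholder(tf.float32, [128, 8])   # 1 for padding, 0 for real tokens
output, new_mems, lookup_table = transformer_xl(
    inp_k=inp_k, n_token=32000, n_layer=6, d_model=512, n_head=8, d_head=64,
    d_inner=2048, dropout=0.1, dropatt=0.1, attn_type='bi', bi_data=False,
    initializer=tf.random_normal_initializer(stddev=0.02), is_training=True,
    input_mask=input_mask, use_tpu=False)
# output: [len, bsz, d_model]; new_mems: one cached state tensor per layer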