Example #1
0
def crf_unary_score(tag_indices, sequence_lengths, inputs):
    ''' Computes the unary scores of tag sequences.

    The potential of the tag chosen at every step is looked up in `inputs`,
    steps at or beyond each sequence's true length are masked out, and the
    remaining potentials are summed per example.

    Args:
      tag_indices: A [batch_size, max_seq_len] matrix of tag indices
        (int32 or int64).
      sequence_lengths: A [batch_size] vector of true sequence lengths.
      inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials.

    Returns:
      unary_scores: A [batch_size] vector of unary scores, in the dtype of
        `inputs`.
    '''
    batch_size = tf.shape(inputs)[0]
    max_seq_len = tf.shape(inputs)[1]
    num_tags = tf.shape(inputs)[2]

    flattened_inputs = tf.reshape(inputs, [-1])

    # Offset of element (b, t, 0) inside the flattened potentials; adding the
    # tag index then addresses element (b, t, tag).
    offsets = tf.expand_dims(tf.range(batch_size) * max_seq_len * num_tags, 1)
    offsets += tf.expand_dims(tf.range(max_seq_len) * num_tags, 0)
    # Use int32 or int64 based on tag_indices' dtype.
    if tag_indices.dtype == tf.int64:
        offsets = tf.cast(offsets, tf.int64)
    flattened_tag_indices = tf.reshape(offsets + tag_indices, [-1])

    unary_scores = tf.reshape(
        tf.gather(flattened_inputs, flattened_tag_indices),
        [batch_size, max_seq_len])

    # Build the mask in the dtype of the scores so the product below also
    # works for potentials that are not float32 (e.g. float16 / float64).
    masks = tf.sequence_mask(sequence_lengths,
                             maxlen=tf.shape(tag_indices)[1],
                             dtype=unary_scores.dtype)

    unary_scores = tf.reduce_sum(unary_scores * masks, 1)
    return unary_scores
Example #2
0
def attention_mask(nd, ns, *, dtype):
    '''Lower-triangular mask of ones, aligned to the lower right corner.

    Entry [i, j] is 1 when i >= j - ns + nd and 0 otherwise. This is the
    same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't
    produce garbage on TPUs.

    Args:
      nd: number of rows of the mask.
      ns: number of columns of the mask.
      dtype: keyword-only; dtype the boolean mask is cast to.

    Returns:
      A [nd, ns] tensor of `dtype`.
    '''
    row_ids = tf.expand_dims(tf.range(nd), 1)
    col_ids = tf.range(ns)
    visible = tf.greater_equal(row_ids, col_ids - ns + nd)
    return tf.cast(visible, dtype)
Example #3
0
def relative_positional_encoding(qlen,
                                 klen,
                                 d_model,
                                 clamp_len,
                                 attn_type,
                                 bi_data,
                                 bsz=None,
                                 dtype=None):
    '''Create the relative positional encoding.

    Args:
      qlen: query length; sets the end of the position range for 'bi'.
      klen: key length; the position range starts at `klen`.
      d_model: model width; every second value in [0, d_model) is used as a
        frequency index.
      clamp_len: when > 0, positions are clipped to [-clamp_len, clamp_len].
      attn_type: 'bi' (range klen .. -qlen) or 'uni' (range klen .. -1).
      bi_data: when true, a forward and a backward position sequence are
        embedded separately and concatenated along axis 1.
      bsz: optional batch size handed to `positional_embedding`; with
        `bi_data` each direction receives `bsz // 2`.
      dtype: optional dtype; sequences are cast when it is set and differs
        from tf.float32.

    Returns:
      The tensor produced by `positional_embedding` (the axis-1
      concatenation of both directions when `bi_data` is true).

    Raises:
      ValueError: if `attn_type` is neither 'bi' nor 'uni'.
    '''
    cast_needed = dtype is not None and dtype != tf.float32

    freq_seq = tf.range(0, d_model, 2.0)
    if cast_needed:
        freq_seq = tf.cast(freq_seq, dtype=dtype)
    inv_freq = 1 / (10000**(freq_seq / d_model))

    if attn_type == 'bi':
        beg, end = klen, -qlen
    elif attn_type == 'uni':
        beg, end = klen, -1
    else:
        raise ValueError('Unknown `attn_type` {}.'.format(attn_type))

    def _finalize(pos_seq):
        # Cast to the requested dtype, then clamp the relative distances.
        if cast_needed:
            pos_seq = tf.cast(pos_seq, dtype=dtype)
        if clamp_len > 0:
            pos_seq = tf.clip_by_value(pos_seq, -clamp_len, clamp_len)
        return pos_seq

    if not bi_data:
        forward_seq = _finalize(tf.range(beg, end, -1.0))
        return positional_embedding(forward_seq, inv_freq, bsz)

    forward_seq = _finalize(tf.range(beg, end, -1.0))
    backward_seq = _finalize(tf.range(-beg, -end, 1.0))

    if bsz is None:
        forward_emb = positional_embedding(forward_seq, inv_freq)
        backward_emb = positional_embedding(backward_seq, inv_freq)
    else:
        # With bi_data, the batch size should be divisible by 2.
        assert bsz % 2 == 0
        half_bsz = bsz // 2
        forward_emb = positional_embedding(forward_seq, inv_freq, half_bsz)
        backward_emb = positional_embedding(backward_seq, inv_freq, half_bsz)

    return tf.concat([forward_emb, backward_emb], axis=1)
Example #4
0
def gather_positions(sequence, positions):
    '''Picks, for every example of a minibatch, the entries at `positions`.

    Args:
      sequence: A [batch_size, seq_length] or
          [batch_size, seq_length, depth] tensor of values
      positions: A [batch_size, n_positions] tensor of indices

    Returns: A [batch_size, n_positions] or
      [batch_size, n_positions, depth] tensor of the values at the indices
    '''
    dims = util.get_shape_list(sequence, expected_rank=[2, 3])
    has_depth = len(dims) == 3
    if has_depth:
        B, L, D = dims
    else:
        # Treat a 2-D input as depth 1 so both cases share one code path.
        B, L = dims
        D = 1
        sequence = tf.expand_dims(sequence, -1)

    # Row b of the batch starts at b * L once the batch axis is flattened.
    row_starts = tf.expand_dims(L * tf.range(B), -1)
    flat_index = tf.reshape(positions + row_starts, [-1])
    rows = tf.reshape(sequence, [B * L, D])
    picked = tf.gather(rows, flat_index)

    out_shape = [B, -1, D] if has_depth else [B, -1]
    return tf.reshape(picked, out_shape)
Example #5
0
def get_timing_signal_1d_given_position(channels,
                                        position,
                                        min_timescale=1.0,
                                        max_timescale=1.0e4):
    """Get sinusoids of different frequencies, with timing position given.

    Adapted from add_timing_signal_1d_given_position in
    //third_party/py/tensor2tensor/layers/common_attention.py

    Args:
      channels: scalar, size of timing embeddings to create. The number of
          different timescales is equal to channels // 2.
      position: a Tensor with shape [batch, seq_len]
      min_timescale: a float
      max_timescale: a float

    Returns:
      a Tensor of timing signals [batch, seq_len, channels]. The first
      channels // 2 channels are sines, the next channels // 2 are cosines,
      and one zero channel is appended when `channels` is odd.
    """
    num_timescales = channels // 2
    # With a single timescale (channels of 2 or 3) `num_timescales - 1` is
    # zero; dividing by it gives an infinite increment and a NaN signal.
    # Clamping the denominator to 1 leaves every other case unchanged.
    log_timescale_increment = (
        math.log(float(max_timescale) / float(min_timescale)) /
        tf.maximum(tf.to_float(num_timescales) - 1, 1))
    inv_timescales = min_timescale * tf.exp(
        tf.to_float(tf.range(num_timescales)) * -log_timescale_increment)
    # [batch, seq_len, 1] * [1, 1, num_timescales]
    scaled_time = (tf.expand_dims(tf.to_float(position), 2) *
                   tf.expand_dims(tf.expand_dims(inv_timescales, 0), 0))
    signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=2)
    # Pad the channel axis with one zero when `channels` is odd.
    signal = tf.pad(signal, [[0, 0], [0, 0], [0, tf.mod(channels, 2)]])
    return signal
Example #6
0
def scatter_update(sequence, updates, positions):
    '''Overwrite selected positions of a batch of sequences.

    Args:
      sequence: A [batch_size, seq_len] or [batch_size, seq_len, depth]
          tensor of dtype tf.float32 or tf.int32.
      updates: A tensor that is reshaped to [-1, depth] (depth is 1 for a
          2-D `sequence`) and scattered row by row at the flattened
          `positions`.
      positions: A [batch_size, n_positions] tensor of indices along the
          seq_len axis.

    Returns: A tuple of two tensors. First is a [batch_size, seq_len] or
      [batch_size, seq_len, depth] tensor of 'sequence' with elements at
      'positions' replaced by the values at 'updates.' Updates to index 0 are
      ignored. Updates that hit the same position are summed and divided by
      the number of hits. Second is a [batch_size, seq_len] int32 mask
      tensor of which inputs were updated.
    '''
    dims = util.get_shape_list(sequence, expected_rank=[2, 3])
    has_depth = len(dims) == 3
    if has_depth:
        B, L, D = dims
    else:
        # Treat a 2-D input as depth 1 so both cases share one code path.
        B, L = dims
        D = 1
        sequence = tf.expand_dims(sequence, -1)
    N = util.get_shape_list(positions)[1]

    # Row b of the batch starts at b * L once the batch axis is flattened.
    row_starts = tf.expand_dims(L * tf.range(B), -1)
    scatter_index = tf.reshape(positions + row_starts, [-1, 1])

    # Scatter the new values; duplicate indices are summed by scatter_nd.
    scattered = tf.scatter_nd(
        scatter_index, tf.reshape(updates, [-1, D]), [B * L, D])
    scattered = tf.reshape(scattered, [B, L, D])

    # Count how many updates landed on every position.
    hit_counts = tf.scatter_nd(
        scatter_index, tf.ones([B * N], tf.int32), [B * L])
    hit_counts = tf.reshape(hit_counts, [B, L])
    # Zero out column 0 so that updates to index 0 are ignored.
    keep_columns = tf.concat(
        [tf.zeros((B, 1), tf.int32),
         tf.ones((B, L - 1), tf.int32)], -1)
    hit_counts = hit_counts * keep_columns
    hit_counts_3d = tf.expand_dims(hit_counts, -1)

    # account for duplicate positions
    if sequence.dtype == tf.float32:
        hit_counts_3d = tf.cast(hit_counts_3d, tf.float32)
        scattered = scattered / tf.maximum(1.0, hit_counts_3d)
    else:
        assert sequence.dtype == tf.int32
        scattered = tf.divide(scattered, tf.maximum(1, hit_counts_3d))
        scattered = tf.cast(scattered, tf.int32)

    # Collapse the counts into 0/1 selectors.
    updates_mask = tf.minimum(hit_counts, 1)
    selector = tf.minimum(hit_counts_3d, 1)

    updated_sequence = (1 - selector) * sequence + selector * scattered
    if not has_depth:
        updated_sequence = tf.squeeze(updated_sequence, -1)

    return updated_sequence, updates_mask
Example #7
0
 def _single_seq_fn():
     '''Scores each example from its single time step.

     Reads `inputs`, `tag_indices` and `sequence_lengths` from the enclosing
     scope. Axis 1 of `inputs` is squeezed away, then the entry selected by
     each example's row of `tag_indices` is gathered. Examples whose
     sequence length is <= 0 get a score of zero.
     '''
     index_dtype = tag_indices.dtype
     num_examples = tf.shape(inputs, out_type=index_dtype)[0]
     example_ids = tf.range(num_examples, dtype=index_dtype)
     # One (example, tag...) index row per example.
     gather_index = tf.concat(
         [tf.reshape(example_ids, [-1, 1]), tag_indices], axis=1)
     scores = tf.gather_nd(tf.squeeze(inputs, [1]), gather_index)

     has_no_tokens = tf.less_equal(sequence_lengths, 0)
     return tf.where(has_no_tokens, tf.zeros_like(scores), scores)
Example #8
0
        def gather_indexes(sequence_tensor, positions):
            '''Gathers the rows of a rank-3 tensor at per-example positions.

            `sequence_tensor` is flattened over its first two axes, and
            `positions` (indices along the second axis, one row per
            example) are shifted by each example's starting row before the
            lookup. The result is the flat output of `tf.gather`, with the
            last axis of `sequence_tensor` as its last axis.
            '''
            dims = util.get_shape_list(sequence_tensor, 3)
            batch_size, seq_length, width = dims[0], dims[1], dims[2]

            # Example b starts at row b * seq_length of the flattened tensor.
            row_starts = tf.expand_dims(
                tf.range(0, batch_size, dtype=tf.int32) * seq_length, 1)
            flat_index = tf.reshape(positions + row_starts, [-1])
            flat_rows = tf.reshape(sequence_tensor,
                                   [batch_size * seq_length, width])
            return tf.gather(flat_rows, flat_index)
Example #9
0
    def create_attention_mask_from_input_mask(self,
                                              input_mask,
                                              batch_size,
                                              max_seq_length,
                                              dtype=tf.float32):
        '''Builds the attention mask for the mode stored in `self._mode`.

        Modes:
          'bi':  every query row is the padding mask, i.e. `input_mask`
                 reshaped to [batch_size, 1, max_seq_length] and broadcast
                 over the query axis.
          'l2r': lower-triangular mask (diagonal included), tiled over the
                 batch; `input_mask` is not used.
          'r2l': upper-triangular mask (diagonal included) multiplied by
                 the padding mask.
          's2s': `input_mask` is passed to `tf.sequence_mask` as lengths.
                 NOTE(review): presumably it holds, per position, the
                 number of visible tokens; confirm against the caller.

        Args:
          input_mask: tensor interpreted as described per mode above.
          batch_size: batch size used for reshaping / tiling.
          max_seq_length: sequence length used for reshaping / masking.
          dtype: dtype of the returned mask.

        Returns:
          The attention mask in `dtype`; for 'bi', 'l2r' and 'r2l' its
          shape is [batch_size, max_seq_length, max_seq_length].

        Raises:
          ValueError: if `self._mode` is not one of the modes above
            (previously this surfaced as an UnboundLocalError on return).
        '''
        mode = self._mode
        if mode not in ('bi', 'l2r', 'r2l', 's2s'):
            raise ValueError('Unknown `mode` {}.'.format(mode))

        if mode == 's2s':
            return tf.cast(
                tf.sequence_mask(input_mask, max_seq_length), dtype)

        if mode == 'l2r':
            # Row i may see positions 0..i.
            arange = tf.range(max_seq_length) + 1
            to_mask = tf.cast(tf.sequence_mask(arange, max_seq_length), dtype)
            to_mask = tf.reshape(to_mask, [1, max_seq_length, max_seq_length])
            return tf.tile(to_mask, [batch_size, 1, 1])

        # 'bi' and 'r2l' both start from the padding mask broadcast over
        # the query axis.
        to_mask = tf.cast(tf.reshape(
            input_mask, [batch_size, 1, max_seq_length]), dtype=dtype)
        broadcast_ones = tf.ones(
            shape=[batch_size, max_seq_length, 1], dtype=dtype)
        cover_mask = broadcast_ones * to_mask
        if mode == 'bi':
            return cover_mask

        # 'r2l': row i may see positions i..end, so drop the strictly
        # lower triangle (positions 0..i-1) from the padding mask.
        arange = tf.range(max_seq_length)
        reverse = tf.cast(tf.sequence_mask(arange, max_seq_length), dtype)
        reverse = tf.reshape(reverse, [1, max_seq_length, max_seq_length])
        reverse_mask = tf.tile(reverse, [batch_size, 1, 1])
        return (1 - reverse_mask) * cover_mask
Example #10
0
def positional_encoding(inputs,
                        maxlen,
                        masking=True,
                        scope='positional_encoding'):
    '''Sinusoidal Positional_Encoding. See 3.5
    inputs: 3d tensor. (N, T, E)
    maxlen: scalar. Must be >= T
    masking: Boolean. If True, elements where `inputs` equals zero keep the
      (zero) input value instead of the encoding.
    scope: Optional scope for `variable_scope`.

    returns
    3d float32 tensor that has the same shape as inputs.
    '''

    depth = inputs.get_shape().as_list()[-1]  # static E
    dynamic_shape = tf.shape(inputs)
    batch, steps = dynamic_shape[0], dynamic_shape[1]  # dynamic N, T
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        # Position ids 0..T-1, repeated for every example: (N, T)
        positions = tf.tile(tf.expand_dims(tf.range(steps), 0), [batch, 1])

        # Angle table: pos / 10000^(2i / E); columns 2i and 2i+1 share an
        # exponent.
        table = np.array([
            [pos / np.power(10000, (col - col % 2) / depth)
             for col in range(depth)]
            for pos in range(maxlen)])

        # Sine on the even columns, cosine on the odd columns.
        table[:, 0::2] = np.sin(table[:, 0::2])  # dim 2i
        table[:, 1::2] = np.cos(table[:, 1::2])  # dim 2i+1
        table = tf.convert_to_tensor(table, tf.float32)  # (maxlen, E)

        # Look up the encoding of every position.
        encoded = tf.nn.embedding_lookup(table, positions)

        if masking:
            encoded = tf.where(tf.equal(inputs, 0), inputs, encoded)

        return tf.to_float(encoded)
Example #11
0
def positions_for(tokens, past_length):
    '''Position ids for `tokens`, continuing after `past_length` steps.

    The ids past_length .. past_length + nsteps - 1 (nsteps being the size
    of axis 1 of `tokens`) are handed to `expand_tile` together with the
    batch size (axis 0 of `tokens`).
    '''
    token_shape = tf.shape(tokens)
    batch_size, nsteps = token_shape[0], token_shape[1]
    step_positions = past_length + tf.range(nsteps)
    return expand_tile(step_positions, batch_size)
Example #12
0
def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
    '''
    Sample a permutation of the factorization order, and create an
    attention mask accordingly.

    The permutation is sampled once and shared by every example of the
    batch (the index row is tiled over the batch axis).

    Args:
      inputs: integer Tensor in shape [batch_size, seq_len], input ids.
      targets: integer Tensor in shape [batch_size, seq_len], target ids
        (same dtype as `inputs`, since both are concatenated).
      is_masked: bool Tensor in shape [batch_size, seq_len]. True means
        being selected for partial prediction.
      perm_size: the length of longest permutation. Could be set to be
        reuse_len. Should not be larger than reuse_len or there will be data
        leaks. `seq_len` must be divisible by it.
      seq_len: int, sequence length.

    Returns:
      perm_mask: bool Tensor [batch_size, seq_len, seq_len]; entry
        [b, i, j] is True when position i cannot attend to position j.
      new_targets: Tensor [batch_size, seq_len]; the first input id followed
        by `targets` shifted right by one.
      target_mask: float32 Tensor [batch_size, seq_len]; 1 for masked,
        non-functional tokens.
      inputs_k: `inputs`, unchanged.
      inputs_q: the same tensor as `target_mask`.
    '''
    batch_size = tf.shape(inputs)[0]

    # Generate permutation indices: shuffle within blocks of `perm_size`.
    index = tf.range(seq_len, dtype=tf.int64)
    index = tf.reshape(index, [-1, perm_size])
    index = tf.transpose(index)
    index = tf.random_shuffle(index)
    index = tf.transpose(index)
    index = tf.reshape(index, [1, -1])
    index = tf.tile(index, [batch_size, 1])

    # `perm_mask` and `target_mask`
    # non-functional tokens
    non_func_tokens = tf.logical_not(
        tf.logical_or(tf.equal(inputs, SEP_ID), tf.equal(inputs, CLS_ID)))

    non_mask_tokens = tf.logical_and(tf.logical_not(is_masked),
                                     non_func_tokens)
    masked_or_func_tokens = tf.logical_not(non_mask_tokens)

    # Set the permutation indices of non-masked (& non-functional) tokens to
    # the smallest index (-1):
    # (1) they can be seen by all other positions
    # (2) they cannot see masked positions, so there won't be information leak
    smallest_index = -tf.ones([batch_size, seq_len], dtype=tf.int64)
    rev_index = tf.where(non_mask_tokens, smallest_index, index)

    # Create `target_mask`: non-functional and masked tokens
    # 1: use mask as input and have loss
    # 0: use token (or [SEP], [CLS]) as input and do not have loss
    target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens)
    target_mask = tf.cast(target_tokens, tf.float32)

    # Create `perm_mask`
    # `target_tokens` cannot see themselves
    self_rev_index = tf.where(target_tokens, rev_index, rev_index + 1)

    # 1: cannot attend if i <= j and j is not non-masked (masked_or_func_tokens)
    # 0: can attend if i > j or j is non-masked
    # The masked/functional condition applies to the attended position j, so
    # it is broadcast over the query axis i ([batch, 1, seq_len]).
    perm_mask = tf.logical_and(
        self_rev_index[:, :, None] <= rev_index[:, None, :],
        masked_or_func_tokens[:, None, :])

    # new target: [next token] for LM and [curr token] (self) for PLM
    new_targets = tf.concat([inputs[:, 0:1], targets[:, :-1]], axis=1)

    # construct inputs_k
    inputs_k = inputs

    # construct inputs_q
    inputs_q = target_mask

    return perm_mask, new_targets, target_mask, inputs_k, inputs_q
Example #13
0
def _expand_features(module, split_placeholders):
    '''Adds the permutation-LM features derived from the raw placeholders.

    The sequence is split at `module.reuse_seq_length`; `_local_perm` is run
    on each part, and the two results are stitched into full-length tensors.
    The reused part cannot attend to the non-reused part, while the
    non-reused part can attend to every reused position.

    Args:
      module: object providing `max_seq_length`, `reuse_seq_length`,
        `perm_size` and `_num_predict` (an int, or None to keep one target
        slot per position).
      split_placeholders: dict with the keys 'input', 'target' and
        'is_masked', each a [batch_size, max_seq_length] tensor. The dict
        is updated in place.

    Returns:
      `split_placeholders`, with these keys written:
        'perm_mask':   [batch_size, max_seq_length, max_seq_length] float32
        'input_k':     [batch_size, max_seq_length]
        'input_q':     [batch_size, max_seq_length] float32
        'target', 'target_mask': [batch_size, max_seq_length], or
          [batch_size, _num_predict] when `_num_predict` is set
        'target_mapping' (only when `_num_predict` is set):
          [batch_size, _num_predict, max_seq_length] float32; row k of an
          example is the one-hot position of its k-th target, or all zeros
          for a padding slot.
      When an example has more than `_num_predict` targets, the extra
      targets are dropped.
    '''

    inputs = split_placeholders['input']
    target = split_placeholders['target']
    is_masked = tf.cast(split_placeholders['is_masked'], tf.bool)
    batch_size = tf.shape(inputs)[0]

    non_reuse_len = module.max_seq_length - module.reuse_seq_length
    assert (module.perm_size <= module.reuse_seq_length
            and module.perm_size <= non_reuse_len)

    (perm_mask_0, target_0, target_mask_0, input_k_0, input_q_0) = \
        _local_perm(
            inputs[:, :module.reuse_seq_length],
            target[:, :module.reuse_seq_length],
            is_masked[:, :module.reuse_seq_length],
            module.perm_size,
            module.reuse_seq_length)

    (perm_mask_1, target_1, target_mask_1, input_k_1, input_q_1) = \
        _local_perm(
            inputs[:, module.reuse_seq_length:],
            target[:, module.reuse_seq_length:],
            is_masked[:, module.reuse_seq_length:],
            module.perm_size,
            non_reuse_len)

    # Reused positions never see the non-reused part (all ones = blocked).
    perm_mask_0 = tf.concat([
        tf.cast(perm_mask_0, dtype=tf.float32),
        tf.ones([batch_size, module.reuse_seq_length, non_reuse_len])
    ],
                            axis=2)
    # Non-reused positions see every reused position (all zeros = visible).
    perm_mask_1 = tf.concat([
        tf.zeros([batch_size, non_reuse_len, module.reuse_seq_length]),
        tf.cast(perm_mask_1, dtype=tf.float32)
    ],
                            axis=2)
    perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=1)
    target = tf.concat([target_0, target_1], axis=1)
    target_mask = tf.concat([target_mask_0, target_mask_1], axis=1)
    input_k = tf.concat([input_k_0, input_k_1], axis=1)
    input_q = tf.concat([input_q_0, input_q_1], axis=1)

    if module._num_predict is not None:
        num_predict = module._num_predict

        # Work per example: `tf.boolean_mask` with a 2-D mask would flatten
        # the batch, which loses the example boundaries. Instead give every
        # target position its rank among the targets of its own example.
        target_flags = tf.cast(target_mask, tf.int32)       # [B, L]
        slots = tf.cumsum(target_flags, axis=1) - 1         # [B, L]

        ##### target_mapping
        # one_hot yields an all-zero row for ranks outside
        # [0, num_predict), so leading non-targets (rank -1) and surplus
        # targets vanish; the multiplication clears the remaining
        # non-target positions, which share the rank of the previous target.
        slot_one_hot = tf.one_hot(slots, num_predict, dtype=tf.float32)
        slot_one_hot *= tf.cast(
            tf.expand_dims(target_flags, -1), tf.float32)   # [B, L, P]
        target_mapping = tf.transpose(slot_one_hot, [0, 2, 1])  # [B, P, L]
        split_placeholders['target_mapping'] = tf.reshape(
            target_mapping, [-1, num_predict, module.max_seq_length])

        ##### target
        # Each mapping row selects at most one position, so the sum picks
        # that position's id (and gives 0 for padding slots).
        target = tf.reduce_sum(
            tf.cast(target_mapping, target.dtype) *
            tf.expand_dims(target, 1),
            axis=-1)
        split_placeholders['target'] = tf.reshape(target,
                                                  [-1, num_predict])

        ##### target mask
        # 1 for filled slots, 0 for padding slots.
        target_mask = tf.reduce_sum(target_mapping, axis=-1)
        split_placeholders['target_mask'] = tf.reshape(
            target_mask, [-1, num_predict])
    else:
        split_placeholders['target'] = tf.reshape(target,
                                                  [-1, module.max_seq_length])
        split_placeholders['target_mask'] = tf.reshape(
            target_mask, [-1, module.max_seq_length])

    # reshape back to fixed shape
    split_placeholders['perm_mask'] = tf.reshape(
        perm_mask, [-1, module.max_seq_length, module.max_seq_length])
    split_placeholders['input_k'] = tf.reshape(input_k,
                                               [-1, module.max_seq_length])
    split_placeholders['input_q'] = tf.reshape(input_q,
                                               [-1, module.max_seq_length])

    return split_placeholders