Example #1
def rel_attn_core(q_head, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat,
                  r_w_bias, r_r_bias, r_s_bias, attn_mask, dropatt,
                  is_training, scale):
    '''Core relative positional attention operations.'''

    # content based attention score
    ac = tf.einsum('ibnd,jbnd->ijbn', q_head + r_w_bias, k_head_h)

    # position based attention score
    bd = tf.einsum('ibnd,jbnd->ijbn', q_head + r_r_bias, k_head_r)
    bd = rel_shift(bd, klen=tf.shape(ac)[1])

    # segment based attention score
    if seg_mat is None:
        ef = 0
    else:
        ef = tf.einsum('ibnd,snd->ibns', q_head + r_s_bias, seg_embed)
        ef = tf.einsum('ijbs,ibns->ijbn', seg_mat, ef)

    # merge attention scores and perform masking
    attn_score = (ac + bd + ef) * scale
    if attn_mask is not None:
        # attn_score = attn_score * (1 - attn_mask) - 1e30 * attn_mask
        attn_score = attn_score - 1e30 * attn_mask

    # attention probability
    attn_prob = tf.nn.softmax(attn_score, 1)
    attn_prob = tf.layers.dropout(attn_prob, dropatt, training=is_training)

    # attention output
    attn_vec = tf.einsum('ijbn,jbnd->ibnd', attn_prob, v_head_h)

    return attn_vec
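
The call to rel_shift above is not shown in this listing; a minimal sketch, assuming the XLNet-style relative shift over score tensors of shape [qlen, klen, bsz, n_head] (the surrounding code may differ in details):

def rel_shift(x, klen=-1):
    '''Hedged reconstruction of the relative-shift trick used above.'''
    x_size = tf.shape(x)
    # reshape, drop one slice along the leading axis, reshape back, then
    # truncate to klen; this realigns scores to relative positions
    x = tf.reshape(x, [x_size[1], x_size[0], x_size[2], x_size[3]])
    x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1])
    x = tf.reshape(x, [x_size[0], x_size[1] - 1, x_size[2], x_size[3]])
    x = tf.slice(x, [0, 0, 0, 0], [-1, klen, -1, -1])
    return x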
Example #2
def embedding_lookup(x,
                     n_token,
                     d_embed,
                     initializer,
                     use_tpu=True,
                     scope='embedding',
                     tilda_embeddings=None,
                     reuse=None,
                     dtype=tf.float32):
    '''TPU and GPU embedding_lookup function.'''
    if tilda_embeddings is not None:
        lookup_table = tilda_embeddings
    else:
        with tf.variable_scope(scope, reuse=reuse):
            lookup_table = tf.get_variable('lookup_table', [n_token, d_embed],
                                           dtype=dtype,
                                           initializer=initializer)
    if use_tpu:
        one_hot_idx = tf.one_hot(x, n_token, dtype=dtype)
        if one_hot_idx.shape.ndims == 2:
            return (tf.einsum('in,nd->id', one_hot_idx,
                              lookup_table), lookup_table)
        else:
            return (tf.einsum('ibn,nd->ibd', one_hot_idx,
                              lookup_table), lookup_table)
    else:
        return tf.nn.embedding_lookup(lookup_table, x), lookup_table
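
A hedged usage sketch; the [seq_len, batch] id layout follows the 'ibn,nd->ibd' einsum above, and the vocabulary size, embedding width, and initializer values are placeholder assumptions:

input_ids = tf.placeholder(tf.int32, [None, None], name='input_ids')
word_emb, lookup_table = embedding_lookup(
    x=input_ids, n_token=32000, d_embed=768,
    initializer=tf.truncated_normal_initializer(stddev=0.02),
    use_tpu=False, scope='word_embedding')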
Example #3
def noncausal_denominator(qs, ks):
    '''Computes FAVOR normalizer in noncausal attention.
  Args:
    qs: query_prime tensor of the shape [L,B,H,M].
    ks: key_prime tensor of the shape [L,B,H,M].
  Returns:
    FAVOR normalizer in noncausal attention.
  '''
    all_ones = tf.ones([ks.shape[0]])
    ks_sum = tf.einsum('lbhm,l->bhm', ks, all_ones)
    return tf.einsum('lbhm,bhm->lbh', qs, ks_sum)
Example #4
def noncausal_numerator(qs, ks, vs):
    '''Computes not-normalized FAVOR noncausal attention AV.
  Args:
    qs: query_prime tensor of the shape [L,B,H,M].
    ks: key_prime tensor of the shape [L,B,H,M].
    vs: value tensor of the shape [L,B,H,D].
  Returns:
    Not-normalized FAVOR noncausal attention AV.
  '''
    kvs = tf.einsum('lbhm,lbhd->bhmd', ks, vs)
    return tf.einsum('lbhm,bhmd->lbhd', qs, kvs)
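
The two helpers above combine into normalized noncausal FAVOR attention; a minimal sketch, assuming qs/ks are kernel-transformed queries/keys of shape [L,B,H,M] and vs are values of shape [L,B,H,D]:

av_attention = noncausal_numerator(qs, ks, vs)        # [L,B,H,D]
attention_normalizer = noncausal_denominator(qs, ks)  # [L,B,H]
attention_output = av_attention / attention_normalizer[..., None]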
Example #5
def causal_numerator(qs, ks, vs):
    '''Computes not-normalized FAVOR causal attention A_{masked}V.
  Args:
    qs: query_prime tensor of the shape [L,B,H,M].
    ks: key_prime tensor of the shape [L,B,H,M].
    vs: value tensor of the shape [L,B,H,D].
  Returns:
    Not-normalized FAVOR causal attention A_{masked}V.
  '''

    result = []
    sums = tf.zeros_like(tf.einsum('ijk,ijl->ijkl', ks[0], vs[0]))

    for index in range(qs.shape[0]):
        sums = sums + tf.einsum('ijk,ijl->ijkl', ks[index], vs[index])
        result.append(
            tf.einsum('ijkl,ijk->ijl', sums, qs[index])[None, Ellipsis])

    result = tf.concat(result, axis=0)

    def grad(res_grad):

        grads = tf.zeros_like(tf.einsum('ijk,ijl->ijkl', ks[0], vs[0]))

        gr_sums = sums

        q_grads = []
        k_grads = []
        v_grads = []

        for index in range(qs.shape[0] - 1, -1, -1):

            q_grads.append(
                tf.einsum('ijkl,ijl->ijk', gr_sums, res_grad[index])[None,
                                                                     Ellipsis])
            grads = grads + tf.einsum('ijk,ijl->ijkl', qs[index],
                                      res_grad[index])
            k_grads.append(
                tf.einsum('ijkl,ijl->ijk', grads, vs[index])[None, Ellipsis])
            v_grads.append(
                tf.einsum('ijkl,ijk->ijl', grads, ks[index])[None, Ellipsis])
            gr_sums = gr_sums - tf.einsum('ijk,ijl->ijkl', ks[index],
                                          vs[index])

        q_grads = tf.concat(q_grads[::-1], axis=0)
        k_grads = tf.concat(k_grads[::-1], axis=0)
        v_grads = tf.concat(v_grads[::-1], axis=0)

        return q_grads, k_grads, v_grads

    return result, grad
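
The (result, grad) pair is shaped for TensorFlow's custom-gradient machinery; a hedged sketch of how it is typically registered (the wrapper name is an assumption):

@tf.custom_gradient
def causal_numerator_with_custom_grad(qs, ks, vs):
    # TensorFlow invokes the returned grad closure during backprop, replaying
    # the prefix sums in reverse instead of storing every intermediate.
    return causal_numerator(qs, ks, vs)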
Example #6
def lm_loss(hidden, target, n_token, d_model, initializer, lookup_table=None,
            tie_weight=False, bi_data=True, use_tpu=False):
    '''Language modeling loss and predictions.'''

    with tf.variable_scope('lm_loss'):
        if tie_weight:
            assert lookup_table is not None, \
                'lookup_table cannot be None for tie_weight'
            softmax_w = lookup_table
        else:
            softmax_w = tf.get_variable(
                'weight', [n_token, d_model],
                dtype=hidden.dtype, initializer=initializer)

        softmax_b = tf.get_variable(
            'bias', [n_token], dtype=hidden.dtype,
            initializer=tf.zeros_initializer())

        logits = tf.einsum('ibd,nd->ibn', hidden, softmax_w) + softmax_b
        preds = tf.argmax(logits, axis=-1)

        if use_tpu:
            one_hot_target = tf.one_hot(target, n_token, dtype=logits.dtype)
            loss = -tf.reduce_sum(
                tf.nn.log_softmax(logits) * one_hot_target, -1)
        else:
            loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=target, logits=logits)

        return loss, preds
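
A hedged usage sketch; `hidden` is assumed to be a [seq_len, batch, d_model] tensor (matching the 'ibd,nd->ibn' einsum), `target` the [seq_len, batch] ids, and `lookup_table` the table returned by embedding_lookup when weights are tied:

loss, preds = lm_loss(
    hidden, target, n_token=32000, d_model=768,
    initializer=tf.truncated_normal_initializer(stddev=0.02),
    lookup_table=lookup_table, tie_weight=True, use_tpu=False)
total_loss = tf.reduce_mean(loss)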
Example #7
def relu_kernel_transformation(data,
                               is_query,
                               projection_matrix=None,
                               numerical_stabilizer=0.001):
    '''Computes features for the ReLU-kernel.
  Computes random features for the ReLU kernel from
  https://arxiv.org/pdf/2009.14794.pdf.
  Args:
    data: input data tensor of the shape [B, L, H, D], where: B - batch
      dimension, L - attention dimensions, H - heads, D - features.
    is_query: indicates whether input data is a query or key tensor.
    projection_matrix: random Gaussian matrix of shape [M, D], where M stands
      for the number of random features and each D x D sub-block has pairwise
      orthogonal rows.
    numerical_stabilizer: small positive constant for numerical stability.
  Returns:
    Corresponding kernel feature map.
  '''
    del is_query
    if projection_matrix is None:
        return tf.nn.relu(data) + numerical_stabilizer
    else:
        ratio = tf.math.rsqrt(tf.cast(projection_matrix.shape[0], tf.float32))
        data_dash = ratio * tf.einsum('blhd,md->blhm', data, projection_matrix)
        return tf.nn.relu(data_dash) + numerical_stabilizer
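
A hedged usage sketch; q and k are assumed [B, L, H, D] float tensors, and a plain Gaussian projection stands in here for the orthogonal-block matrix the docstring describes:

projection_matrix = tf.random.normal([256, 64])                     # [M, D]
q_prime = relu_kernel_transformation(q, True, projection_matrix)    # [B, L, H, M]
k_prime = relu_kernel_transformation(k, False, projection_matrix)   # [B, L, H, M]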
Example #8
    def call(self, inputs):
        ret = tf.einsum(self._einsum_string, inputs, self._kernel)
        if self._use_bias:
            ret += self._bias
        if self._activation is not None:
            ret = self._activation(ret)
        return ret
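
The snippet above is a Keras-style call(); a minimal sketch of an enclosing layer that would make it runnable (the class name, constructor, and initializers are assumptions, not the original layer):

class EinsumDense(tf.keras.layers.Layer):
    '''Hypothetical einsum-based dense layer consistent with call() above.'''

    def __init__(self, einsum_string, kernel_shape, bias_shape=None,
                 activation=None, **kwargs):
        super(EinsumDense, self).__init__(**kwargs)
        self._einsum_string = einsum_string
        self._use_bias = bias_shape is not None
        self._activation = tf.keras.activations.get(activation)
        self._kernel = self.add_weight(
            'kernel', shape=kernel_shape, initializer='glorot_uniform')
        self._bias = (self.add_weight('bias', shape=bias_shape,
                                      initializer='zeros')
                      if self._use_bias else None)

    def call(self, inputs):
        ret = tf.einsum(self._einsum_string, inputs, self._kernel)
        if self._use_bias:
            ret += self._bias
        if self._activation is not None:
            ret = self._activation(ret)
        return ret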
Example #9
def abs_attn_core(q_head, k_head, v_head, attn_mask, dropatt, is_training,
                  scale):
    '''Core absolute positional attention operations.'''

    attn_score = tf.einsum('ibnd,jbnd->ijbn', q_head, k_head)
    attn_score *= scale
    if attn_mask is not None:
        attn_score = attn_score - 1e30 * attn_mask

    # attention probability
    attn_prob = tf.nn.softmax(attn_score, 1)
    attn_prob = tf.layers.dropout(attn_prob, dropatt, training=is_training)

    # attention output
    attn_vec = tf.einsum('ijbn,jbnd->ibnd', attn_prob, v_head)

    return attn_vec
Example #10
def head_projection(h, d_model, n_head, d_head, kernel_initializer, name):
    '''Project hidden states to a specific head with a 4D-shape.'''
    proj_weight = tf.get_variable('{}/kernel'.format(name),
                                  [d_model, n_head, d_head],
                                  dtype=h.dtype,
                                  initializer=kernel_initializer)
    head = tf.einsum('ibh,hnd->ibnd', h, proj_weight)

    return head
Example #11
def positional_embedding(pos_seq, inv_freq, bsz=None):
    sinusoid_inp = tf.einsum('i,d->id', pos_seq, inv_freq)
    pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], -1)
    pos_emb = pos_emb[:, None, :]

    if bsz is not None:
        pos_emb = tf.tile(pos_emb, [1, bsz, 1])

    return pos_emb
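
A hedged sketch of the usual inputs, assuming the Transformer-XL/XLNet convention of sinusoid frequencies over d_model and a descending relative-position sequence (d_model, klen, qlen, and bsz are assumed Python ints):

freq_seq = tf.range(0, d_model, 2.0)
inv_freq = 1.0 / (10000 ** (freq_seq / d_model))         # [d_model // 2]
pos_seq = tf.range(klen, -qlen, -1.0)                    # relative positions
pos_emb = positional_embedding(pos_seq, inv_freq, bsz)   # [klen + qlen, bsz, d_model]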
Example #12
    def grad(res_grad):

        grads = tf.zeros_like(tf.einsum('ijk,ijl->ijkl', ks[0], vs[0]))

        gr_sums = sums

        q_grads = []
        k_grads = []
        v_grads = []

        for index in range(qs.shape[0] - 1, -1, -1):

            q_grads.append(
                tf.einsum('ijkl,ijl->ijk', gr_sums, res_grad[index])[None,
                                                                     Ellipsis])
            grads = grads + tf.einsum('ijk,ijl->ijkl', qs[index],
                                      res_grad[index])
            k_grads.append(
                tf.einsum('ijkl,ijl->ijk', grads, vs[index])[None, Ellipsis])
            v_grads.append(
                tf.einsum('ijkl,ijk->ijl', grads, ks[index])[None, Ellipsis])
            gr_sums = gr_sums - tf.einsum('ijk,ijl->ijkl', ks[index],
                                          vs[index])

        q_grads = tf.concat(q_grads[::-1], axis=0)
        k_grads = tf.concat(k_grads[::-1], axis=0)
        v_grads = tf.concat(v_grads[::-1], axis=0)

        return q_grads, k_grads, v_grads
Example #13
    def grad(res_grad):

        k_grad = tf.zeros_like(ks[0])

        gr_sums = sums

        q_grads = []
        k_grads = []

        for index in range(qs.shape[0] - 1, -1, -1):

            q_grads.append(
                tf.einsum('ijk,ij->ijk', gr_sums, res_grad[index])[None,
                                                                   Ellipsis])
            k_grad = k_grad + tf.einsum('ijk,ij->ijk', qs[index],
                                        res_grad[index])
            k_grads.append(k_grad[None, Ellipsis])
            gr_sums = gr_sums - ks[index]

        q_grads = tf.concat(q_grads[::-1], axis=0)
        k_grads = tf.concat(k_grads[::-1], axis=0)

        return q_grads, k_grads
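
This closure is the backward pass of a causal FAVOR denominator whose forward pass is not included in the listing; a hedged reconstruction consistent with the shapes used above (a running key prefix sum shaped like ks[0], outputs of shape [L,B,H]):

def causal_denominator(qs, ks):
    '''Sketch of the forward pass that pairs with the grad closure above.'''
    result = []
    sums = tf.zeros_like(ks[0])   # running prefix sum over keys, [B,H,M]
    for index in range(qs.shape[0]):
        sums = sums + ks[index]
        result.append(tf.einsum('ijk,ijk->ij', qs[index], sums)[None, Ellipsis])
    result = tf.concat(result, axis=0)   # [L,B,H]
    # under tf.custom_gradient this would be returned together with `grad`
    return result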
Example #14
def dense_layer_3d(input_tensor,
                   num_attention_heads,
                   head_size,
                   initializer,
                   activation,
                   use_einsum,
                   name=None,
                   trainable=True):
  """A dense layer with 3D kernel.

  Args:
    input_tensor: float Tensor of shape [batch, seq_length, hidden_size].
    num_attention_heads: Number of attention heads.
    head_size: The size per attention head.
    initializer: Kernel initializer.
    activation: Activation function.
    use_einsum: bool. Whether to use einsum or reshape+matmul for dense layers.
    name: The name scope of this layer.

  Returns:
    float logits Tensor.
  """

  input_shape = util.get_shape_list(input_tensor)
  hidden_size = input_shape[2]

  with tf.variable_scope(name):
    w = tf.get_variable(
        name="kernel",
        shape=[hidden_size, num_attention_heads * head_size],
        initializer=initializer,
        trainable=trainable)
    w = tf.reshape(w, [hidden_size, num_attention_heads, head_size])
    b = tf.get_variable(
        name="bias",
        shape=[num_attention_heads * head_size],
        initializer=tf.zeros_initializer,
        trainable=trainable)
    b = tf.reshape(b, [num_attention_heads, head_size])
    if use_einsum:
      ret = tf.einsum("BFH,HND->BFND", input_tensor, w)
    else:
      ret = einsum_via_matmul(input_tensor, w, 1)
    ret += b
  if activation is not None:
    return activation(ret)
  else:
    return ret
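
einsum_via_matmul is referenced here and in dense_layer_3d_proj below but not shown; a minimal sketch consistent with those two call sites, assuming statically known inner and output dimensions (the real util helper presumably resolves dynamic shapes):

def einsum_via_matmul(input_tensor, w, num_inner_dims):
  """Reshape + matmul equivalent of the 3D-kernel einsum contractions."""
  input_shape = input_tensor.shape.as_list()
  w_shape = w.shape.as_list()
  # dims kept on the output, with an unknown batch dim mapped to -1
  batch_dims = [d if d is not None else -1
                for d in input_shape[:-num_inner_dims]]
  inner_dim = 1
  for d in input_shape[-num_inner_dims:]:
    inner_dim *= d
  outer_dims = w_shape[num_inner_dims:]
  outer_dim = 1
  for d in outer_dims:
    outer_dim *= d
  # collapse to a single 2D matmul, then restore the output shape
  flat_input = tf.reshape(input_tensor, [-1, inner_dim])
  ret = tf.matmul(flat_input, tf.reshape(w, [inner_dim, outer_dim]))
  return tf.reshape(ret, batch_dims + outer_dims)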
Example #15
def post_attention(h, attn_vec, d_model, n_head, d_head, dropout, is_training,
                   kernel_initializer, residual=True):
    '''Post-attention processing.'''
    # post-attention projection (back to `d_model`)
    proj_o = tf.get_variable(
        'o/kernel', [d_model, n_head, d_head],
        dtype=h.dtype, initializer=kernel_initializer)
    attn_out = tf.einsum('ibnd,hnd->ibh', attn_vec, proj_o)

    attn_out = tf.layers.dropout(attn_out, dropout, training=is_training)
    if residual:
        output = util.layer_norm(
            attn_out + h, name='LayerNorm')
    else:
        output = util.layer_norm(
            attn_out, name='LayerNorm')

    return output
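
How the helpers fit together for one absolute-position attention block; a hedged sketch, with h (hidden states of shape [seq, batch, d_model]), attn_mask, the dropout rates, and the hyper-parameters all assumed:

scale = 1.0 / (d_head ** 0.5)
q_head = head_projection(h, d_model, n_head, d_head, initializer, 'q')
k_head = head_projection(h, d_model, n_head, d_head, initializer, 'k')
v_head = head_projection(h, d_model, n_head, d_head, initializer, 'v')
attn_vec = abs_attn_core(q_head, k_head, v_head, attn_mask, dropatt,
                         is_training, scale)
output = post_attention(h, attn_vec, d_model, n_head, d_head, dropout,
                        is_training, initializer)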
Example #16
def dense_layer_2d(input_tensor,
                   output_size,
                   initializer,
                   activation,
                   use_einsum,
                   num_attention_heads=1,
                   name=None,
                   trainable=True):
    """A dense layer with 2D kernel.

  Args:
    input_tensor: Float tensor with rank 3.
    output_size: The size of output dimension.
    initializer: Kernel initializer.
    activation: Activation function.
    use_einsum: bool. Whether to use einsum or reshape+matmul for dense layers.
    num_attention_heads: number of attention heads in the attention layer.
    name: The name scope of this layer.

  Returns:
    float logits Tensor.
  """
    del num_attention_heads  # unused
    input_shape = util.get_shape_list(input_tensor)
    hidden_size = input_shape[2]
    with tf.variable_scope(name):
        w = tf.get_variable(name="kernel",
                            shape=[hidden_size, output_size],
                            initializer=initializer,
                            trainable=trainable)
        b = tf.get_variable(name="bias",
                            shape=[output_size],
                            initializer=tf.zeros_initializer,
                            trainable=trainable)
        if use_einsum:
            ret = tf.einsum("BFH,HO->BFO", input_tensor, w)
        else:
            ret = tf.matmul(input_tensor, w)
        ret += b
    if activation is not None:
        return activation(ret)
    else:
        return ret
Example #17
def softmax_kernel_transformation(data,
                                  is_query,
                                  projection_matrix=None,
                                  numerical_stabilizer=0.000001):
    '''Computes random features for the softmax kernel using FAVOR+ mechanism.
  Computes random features for the softmax kernel using FAVOR+ mechanism from
  https://arxiv.org/pdf/2009.14794.pdf.
  Args:
    data: input data tensor of the shape [B, L, H, D], where: B - batch
      dimension, L - attention dimensions, H - heads, D - features.
    is_query: indicates whether input data is a query or key tensor.
    projection_matrix: random Gaussian matrix of shape [M, D], where M stands
      for the number of random features and each D x D sub-block has pairwise
      orthogonal rows.
    numerical_stabilizer: small positive constant for numerical stability.
  Returns:
    Corresponding kernel feature map.
  '''
    data_normalizer = tf.math.rsqrt(
        tf.math.sqrt(tf.cast(data.shape[-1], tf.float32)))
    ratio = tf.math.rsqrt(
        tf.cast(
            projection_matrix.shape[0]
            if projection_matrix is not None else 1.0, tf.float32))
    data_dash = tf.einsum('blhd,md->blhm', data, projection_matrix)
    diag_data = tf.math.square(data)
    diag_data = tf.math.reduce_sum(diag_data,
                                   axis=tf.keras.backend.ndim(data) - 1)
    diag_data = (diag_data / 2.0) * data_normalizer * data_normalizer
    diag_data = tf.expand_dims(diag_data, axis=tf.keras.backend.ndim(data) - 1)
    if is_query:
        last_dims_t = (len(data_dash.shape) - 1, )
        data_dash = ratio * (tf.math.exp(
            data_dash - diag_data -
            tf.math.reduce_max(data_dash, axis=last_dims_t, keepdims=True)) +
                             numerical_stabilizer)
    else:
        data_dash = ratio * (tf.math.exp(data_dash - diag_data -
                                         tf.math.reduce_max(data_dash)) +
                             numerical_stabilizer)

    return data_dash
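
An end-to-end FAVOR+ sketch, assuming query/key/value tensors in [B,L,H,*] layout and a precomputed projection_matrix: kernel-transform the queries and keys, move them to the [L,B,H,M] layout the numerator/denominator helpers expect, then normalize as in the sketch after Example #4:

q_prime = softmax_kernel_transformation(query, True, projection_matrix)
k_prime = softmax_kernel_transformation(key, False, projection_matrix)
q_prime = tf.transpose(q_prime, [1, 0, 2, 3])   # [L,B,H,M]
k_prime = tf.transpose(k_prime, [1, 0, 2, 3])   # [L,B,H,M]
v = tf.transpose(value, [1, 0, 2, 3])           # [L,B,H,D]
av = noncausal_numerator(q_prime, k_prime, v)
normalizer = noncausal_denominator(q_prime, k_prime)
attn_out = tf.transpose(av / normalizer[..., None], [1, 0, 2, 3])   # [B,L,H,D]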
Example #18
def dense_layer_3d_proj(input_tensor,
                        hidden_size,
                        head_size,
                        initializer,
                        activation,
                        use_einsum,
                        name=None):
    """A dense layer with 3D kernel for projection.

  Args:
    input_tensor: float Tensor of shape [batch,from_seq_length,
      num_attention_heads, size_per_head].
    hidden_size: The size of hidden layer.
    head_size: The size of head.
    initializer: Kernel initializer.
    activation: Activation function.
    use_einsum: bool. Whether to use einsum or reshape+matmul for dense layers.
    name: The name scope of this layer.

  Returns:
    float logits Tensor.
  """
    input_shape = util.get_shape_list(input_tensor)
    num_attention_heads = input_shape[2]
    with tf.variable_scope(name):
        w = tf.get_variable(
            name="kernel",
            shape=[num_attention_heads * head_size, hidden_size],
            initializer=initializer)
        w = tf.reshape(w, [num_attention_heads, head_size, hidden_size])
        b = tf.get_variable(name="bias",
                            shape=[hidden_size],
                            initializer=tf.zeros_initializer)
        if use_einsum:
            ret = tf.einsum("BFND,NDH->BFH", input_tensor, w)
        else:
            ret = einsum_via_matmul(input_tensor, w, 2)
        ret += b
    if activation is not None:
        return activation(ret)
    else:
        return ret
Example #19
def two_stream_rel_attn(h,
                        g,
                        r,
                        mems,
                        r_w_bias,
                        r_r_bias,
                        seg_mat,
                        r_s_bias,
                        seg_embed,
                        attn_mask_h,
                        attn_mask_g,
                        target_mapping,
                        d_model,
                        n_head,
                        d_head,
                        dropout,
                        dropatt,
                        is_training,
                        kernel_initializer,
                        scope='rel_attn'):
    '''Two-stream attention with relative positional encoding.'''

    scale = 1 / (d_head**0.5)
    with tf.variable_scope(scope, reuse=False):

        # content based attention score
        if mems is not None and mems.shape.ndims > 1:
            cat = tf.concat([mems, h], 0)
        else:
            cat = h

        # content-based key head
        k_head_h = head_projection(cat, d_model, n_head, d_head,
                                   kernel_initializer, 'k')

        # content-based value head
        v_head_h = head_projection(cat, d_model, n_head, d_head,
                                   kernel_initializer, 'v')

        # position-based key head
        k_head_r = head_projection(r, d_model, n_head, d_head,
                                   kernel_initializer, 'r')

        ##### h-stream
        # content-stream query head
        q_head_h = head_projection(h, d_model, n_head, d_head,
                                   kernel_initializer, 'q')

        # core attention ops
        attn_vec_h = rel_attn_core(q_head_h, k_head_h, v_head_h, k_head_r,
                                   seg_embed, seg_mat, r_w_bias, r_r_bias,
                                   r_s_bias, attn_mask_h, dropatt, is_training,
                                   scale)

        # post processing
        output_h = post_attention(h, attn_vec_h, d_model, n_head, d_head,
                                  dropout, is_training, kernel_initializer)

    with tf.variable_scope(scope, reuse=True):
        ##### g-stream
        # query-stream query head
        q_head_g = head_projection(g, d_model, n_head, d_head,
                                   kernel_initializer, 'q')

        # core attention ops
        if target_mapping is not None:
            q_head_g = tf.einsum('mbnd,mlb->lbnd', q_head_g, target_mapping)
            attn_vec_g = rel_attn_core(q_head_g, k_head_h, v_head_h, k_head_r,
                                       seg_embed, seg_mat, r_w_bias, r_r_bias,
                                       r_s_bias, attn_mask_g, dropatt,
                                       is_training, scale)
            attn_vec_g = tf.einsum('lbnd,mlb->mbnd', attn_vec_g,
                                   target_mapping)
        else:
            attn_vec_g = rel_attn_core(q_head_g, k_head_h, v_head_h, k_head_r,
                                       seg_embed, seg_mat, r_w_bias, r_r_bias,
                                       r_s_bias, attn_mask_g, dropatt,
                                       is_training, scale)

        # post processing
        output_g = post_attention(g, attn_vec_g, d_model, n_head, d_head,
                                  dropout, is_training, kernel_initializer)

        return output_h, output_g
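
A hedged usage sketch inside a layer loop; the per-layer bias and memory variables, the masks, and the hyper-parameters are assumptions about the surrounding XLNet-style model code:

output_h, output_g = two_stream_rel_attn(
    h=output_h, g=output_g, r=pos_emb, mems=mems[i],
    r_w_bias=r_w_bias[i], r_r_bias=r_r_bias[i],
    seg_mat=seg_mat, r_s_bias=r_s_bias[i], seg_embed=seg_embed[i],
    attn_mask_h=non_tgt_mask, attn_mask_g=attn_mask,
    target_mapping=target_mapping, d_model=d_model, n_head=n_head,
    d_head=d_head, dropout=dropout, dropatt=dropatt,
    is_training=is_training, kernel_initializer=initializer,
    scope='layer_{}'.format(i))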