Example #1
# Imports needed to run this snippet on its own (names as in the Trax library).
from trax import layers as tl
from trax.layers.research.sparsity import LocallyConvDense, MultiplicativeSparseDense

def MultiplicativeConvCausalAttention(d_feature,
                                      n_heads=1,
                                      sparsity=None,
                                      length_kernel_size=3,
                                      dropout=0.0,
                                      max_inference_length=2048,
                                      mode='train'):
    """Returns a layer that maps activations to activations, with causal masking.

  Like `CausalAttention`, this layer type represents one pass of multi-head
  self-attention with causal masking rather than padding-based masking. However,
  for computing Q/K/V instead of a Dense layer it combines
  MultiplicativeSparseDense layer with LocallyConvLayer.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    sparsity: The sparsity of the layer; usually it should be equal to n_heads.
    length_kernel_size: Size of convolution kernel on the length dimension.
    dropout: Probababilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    max_inference_length: maximum length for inference.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
    sparsity = n_heads if sparsity is None else sparsity
    d_module = d_feature // sparsity
    return tl.Serial(
        tl.Select([0, 0]),  # duplicate activations
        MultiplicativeSparseDense(sparsity, d_feature,
                                  d_feature),  # shared q, k
        tl.Select([0, 0, 0]),  # use for q, k, v
        tl.Parallel(
            [
                LocallyConvDense(sparsity,
                                 d_module,
                                 kernel_size=3,
                                 length_kernel_size=length_kernel_size),
                tl.SplitIntoHeads(n_heads)
            ],
            [
                LocallyConvDense(sparsity,
                                 d_module,
                                 kernel_size=3,
                                 length_kernel_size=length_kernel_size),
                tl.SplitIntoHeads(n_heads)
            ],
            [
                tl.Concatenate(),  # use permuted and original for v
                LocallyConvDense(sparsity,
                                 d_module,
                                 kernel_size=1,
                                 length_kernel_size=length_kernel_size),
                tl.SplitIntoHeads(n_heads)
            ],
        ),
        tl.DotProductCausalAttention(dropout=dropout,
                                     max_inference_length=max_inference_length,
                                     mode=mode),
        tl.MergeHeads(n_heads),
    )
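
A minimal usage sketch for Example #1 (not part of the original source): the import path, the hyperparameter values, and the input shapes below are assumptions chosen for illustration. The layer is constructed, initialized against an input signature, and then applied like any other Trax layer.

import numpy as np
from trax import shapes
# Assumed location of the constructor: the research sparsity module in the Trax repo.
from trax.layers.research.sparsity import MultiplicativeConvCausalAttention

attn = MultiplicativeConvCausalAttention(d_feature=64, n_heads=4, mode='eval')
x = np.zeros((2, 16, 64), dtype=np.float32)  # [batch, length, d_feature]
attn.init(shapes.signature(x))               # create weights for this input shape
y = attn(x)                                  # causal self-attention; y has the same shape as x
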
Example #2
# Imports needed to run this snippet on its own (names as in the Trax library).
from trax import layers as tl
from trax.fastmath import numpy as jnp

def Favor(d_feature,
          n_heads=1,
          dropout=0.0,
          numerical_stabilizer=0.001,
          mode='train'):
    """Returns a layer that maps (activations, mask) to (new_activations, mask).

  See the FAVOR paper for details: https://arxiv.org/abs/2006.03555

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probababilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    numerical_stabilizer: float, small number used for numerical stability.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
    del dropout, mode  # not implemented yet but needed in the API

    def bidirectional_numerator(query_prime, key_prime, value):
        # Sum over positions first, kvs = K'^T V per batch, then project onto
        # the query features, Q' kvs. Inputs are length-major: [length, batch, m].
        kvs = jnp.einsum('lbm,lbd->bmd', key_prime, value)
        return jnp.einsum('lbm,bmd->lbd', query_prime, kvs)

    def bidirectional_denominator(query_prime, key_prime):
        # Normalizer Q' (K'^T 1): the row sums of the implicit attention
        # matrix, computed without ever materializing it.
        all_ones = jnp.ones([query_prime.shape[0]])
        ks_sum = jnp.einsum('lbm,l->bm', key_prime, all_ones)
        return jnp.einsum('lbm,bm->lb', query_prime, ks_sum)

    def relu(x):
        return jnp.where(x <= 0, jnp.zeros_like(x), x)

    def favor(query, key, value, mask):
        # Positive feature map: phi(x) = relu(x) + numerical_stabilizer.
        query_prime = relu(query) + numerical_stabilizer
        key_prime = relu(key) + numerical_stabilizer
        # Broadcast the [batch, length] mask across heads and zero out the key
        # features of padded positions.
        mask_batch_1_length = jnp.reshape(
            mask,
            [key.shape[0] // n_heads, 1, key.shape[1]]).astype(jnp.float32)
        mask_heads = mask_batch_1_length + jnp.zeros((1, n_heads, 1))
        key_prime *= jnp.reshape(mask_heads, [key.shape[0], key.shape[1], 1])

        # Move the length axis to the front for the einsums, then move it back
        # and renormalize the numerator by the denominator.
        w = bidirectional_numerator(jnp.moveaxis(query_prime, 1, 0),
                                    jnp.moveaxis(key_prime, 1, 0),
                                    jnp.moveaxis(value, 1, 0))
        r = bidirectional_denominator(jnp.moveaxis(query_prime, 1, 0),
                                      jnp.moveaxis(key_prime, 1, 0))
        w = jnp.moveaxis(w, 0, 1)
        r = jnp.moveaxis(r, 0, 1)
        r = jnp.reciprocal(r)
        r = jnp.expand_dims(r, len(r.shape))
        renormalized_attention = w * r
        return renormalized_attention, mask

    return tl.Serial(
        tl.Branch(
            [tl.Dense(d_feature),
             tl.SplitIntoHeads(n_heads)],
            [tl.Dense(d_feature),
             tl.SplitIntoHeads(n_heads)],
            [tl.Dense(d_feature),
             tl.SplitIntoHeads(n_heads)],
        ),
        tl.Fn('FAVOR', favor, n_out=2),
        tl.MergeHeads(n_heads),  # recombine heads so the output matches the input's batch shape
        tl.Dense(d_feature),
    )
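
A minimal usage sketch for Example #2 (not part of the original source), with illustrative shapes and the same assumed import path. Favor consumes the pair (activations, mask) and returns (new_activations, mask), so it is called on a 2-tuple.

import numpy as np
from trax import shapes
from trax.layers.research.sparsity import Favor  # assumed import path

favor_layer = Favor(d_feature=64)                # single-head default, for simplicity
x = np.zeros((2, 16, 64), dtype=np.float32)      # [batch, length, d_feature]
mask = np.ones((2, 16), dtype=np.int32)          # 1 = real token, 0 = padding
favor_layer.init(shapes.signature((x, mask)))
new_x, out_mask = favor_layer((x, mask))         # new_x has the same shape as x; mask passes through
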