Example #1
def MultiplicativeModularCausalAttention(d_feature,
                                         n_heads=1,
                                         sparsity=None,
                                         dropout=0.0,
                                         max_inference_length=2048,
                                         mode='train'):
    """Returns a layer that maps activations to activations, with causal masking.

    Like `CausalAttention`, this layer type represents one pass of multi-head
    self-attention with causal masking rather than padding-based masking.
    However, instead of a plain Dense layer, it computes Q/K/V with a
    MultiplicativeModularSparseDense layer.

    Args:
      d_feature: Depth/dimensionality of feature embedding.
      n_heads: Number of attention heads.
      sparsity: The sparsity of the layer; usually it should be equal to n_heads.
      dropout: Probabilistic rate for internal dropout applied to attention
          activations (based on query-key pairs) before dotting them with values.
      max_inference_length: Maximum length for inference.
      mode: One of `'train'`, `'eval'`, or `'predict'`.
    """
    sparsity = n_heads if sparsity is None else sparsity
    return tl.ConfigurableAttention(
        MultiplicativeModularSparseDense(sparsity, d_feature),
        MultiplicativeModularSparseDense(sparsity, d_feature),
        MultiplicativeModularSparseDense(sparsity, d_feature),
        MultiplicativeModularSparseDense(sparsity, d_feature),
        n_heads=n_heads,
        qkv_attention_layer=tl.DotProductCausalAttention(
            dropout=dropout,
            max_inference_length=max_inference_length,
            mode=mode))
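
The layer returned above is a drop-in replacement for tl.CausalAttention: it
maps activations of shape [batch, seq_len, d_feature] to activations of the
same shape. A minimal usage sketch, assuming Trax is installed and the function
above is importable (e.g. from trax.layers.research.sparsity in recent
versions):

from trax import layers as tl

d_model, n_heads = 512, 8
# Sparsity defaults to n_heads when left as None.
attn = MultiplicativeModularCausalAttention(d_model, n_heads=n_heads,
                                            mode='train')
# The layer can stand in for tl.CausalAttention inside a decoder block.
decoder_block = tl.Serial(
    tl.Residual(tl.LayerNorm(), attn, tl.Dropout(rate=0.1, mode='train')),
    tl.Residual(tl.LayerNorm(), tl.Dense(4 * d_model), tl.Relu(),
                tl.Dense(d_model)),
)
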
Example #2
def LowRankCausalAttention(d_feature,
                           n_heads=1,
                           dropout=0.0,
                           max_inference_length=2048,
                           lowrank=64,
                           mode='train'):
    """Returns a layer that maps activations to activations, with causal masking.

    Like `CausalAttention`, this layer type represents one pass of multi-head
    self-attention with causal masking rather than padding-based masking.
    However, it uses a low-rank approximation of the Dense layer's kernel for
    computing Q/K/V.

    Args:
      d_feature: Depth/dimensionality of feature embedding.
      n_heads: Number of attention heads.
      dropout: Probabilistic rate for internal dropout applied to attention
          activations (based on query-key pairs) before dotting them with values.
      max_inference_length: Maximum length for inference.
      lowrank: The rank of the low-rank approximation.
      mode: One of `'train'`, `'eval'`, or `'predict'`.
    """

    return tl.ConfigurableAttention(
        LowRankDense(d_feature, lowrank),
        LowRankDense(d_feature, lowrank),
        LowRankDense(d_feature, lowrank),
        LowRankDense(d_feature, lowrank),
        n_heads=n_heads,
        qkv_attention_layer=tl.DotProductCausalAttention(
            dropout=dropout,
            max_inference_length=max_inference_length,
            mode=mode))
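
The payoff of the low-rank factorization is the parameter count. Assuming
LowRankDense(d_feature, lowrank) factors the d_feature x d_feature projection
into two matrices of shapes [d_feature, lowrank] and [lowrank, d_feature], each
Q/K/V/output projection shrinks roughly as follows (a back-of-the-envelope
sketch, ignoring biases):

d_feature, lowrank = 512, 64
full_dense_params = d_feature * d_feature     # 262144
low_rank_params = 2 * d_feature * lowrank     # 65536
print(low_rank_params / full_dense_params)    # 0.25
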
Example #3
def MultiplicativeCausalAttention(d_feature, n_heads=1, sparsity=None,
                                  dropout=0.0, max_inference_length=2048,
                                  mode='train'):
  """Returns a layer that maps activations to activations, with causal masking.

  Like `CausalAttention`, this layer type represents one pass of multi-head
  self-attention with causal masking rather than padding-based masking. However,
  instead of a Dense layer, it computes Q/K/V by multiplying each embedding
  dimension by a scalar specific to that dimension and to each head, and then
  applying the same dense layer to every head. Compared to a standard Dense
  layer for computing Q/K/V, this layer uses fewer parameters while still being
  able to express many functions, such as a permutation.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    sparsity: The sparsity of the layer; usually it should be equal to n_heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    max_inference_length: Maximum length for inference.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  sparsity = n_heads if sparsity is None else sparsity
  return tl.ConfigurableAttention(
      MultiplicativeSparseDense(sparsity, d_feature, d_feature),
      MultiplicativeSparseDense(sparsity, d_feature, d_feature),
      MultiplicativeSparseDense(sparsity, d_feature, d_feature),
      MultiplicativeSparseDense(sparsity, d_feature, d_feature),
      n_heads=n_heads, qkv_attention_layer=tl.DotProductCausalAttention(
          dropout=dropout, max_inference_length=max_inference_length,
          mode=mode))
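
The multiplicative trick described in the docstring can be illustrated in plain
NumPy. This is only a shape-level sketch of the idea (per-head scalars followed
by one shared dense kernel applied to every head); the actual weight shapes
inside MultiplicativeSparseDense may differ:

import numpy as np

batch, seq_len, d_feature, sparsity = 2, 5, 8, 4
x = np.random.randn(batch, seq_len, d_feature)
scales = np.random.randn(sparsity, d_feature)        # one scalar per head/dim
shared_kernel = np.random.randn(d_feature, d_feature // sparsity)

# Multiply by head-specific scalars: [batch, seq_len, sparsity, d_feature].
scaled = x[:, :, None, :] * scales[None, None, :, :]
# Apply the same dense kernel to every head, then concatenate head outputs.
out = np.einsum('bsnd,de->bsne', scaled, shared_kernel)
out = out.reshape(batch, seq_len, d_feature)

# Parameters: sparsity * d_feature + d_feature * (d_feature // sparsity),
# versus d_feature * d_feature for a standard Dense layer.
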
Example #4
def ConvCausalAttention(d_feature, n_heads=1, sparsity=None, dropout=0.0,
                        max_inference_length=2048,
                        kernel_size=1, mode='train'):
  """Returns a layer that maps activations to activations, with causal masking.

  Like `CausalAttention`, this layer type represents one pass of multi-head
  self-attention with causal masking rather than padding-based masking. However,
  it uses a LocallyConvDense layer instead of a Dense layer for computing Q/K/V.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    sparsity: Number of modules used in LocallyConvDense.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    max_inference_length: Maximum length for inference.
    kernel_size: Kernel size used in LocallyConvDense.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  n_modules = n_heads if sparsity is None else sparsity
  @assert_shape('...a->...b')
  def ProcessingLayer():
    assert d_feature % n_modules == 0
    return LocallyConvDense(n_modules, d_feature // n_modules,
                            kernel_size=kernel_size)

  return tl.ConfigurableAttention(
      ProcessingLayer(), ProcessingLayer(), ProcessingLayer(),
      ProcessingLayer(), n_heads=n_heads,
      qkv_attention_layer=tl.DotProductCausalAttention(
          dropout=dropout, max_inference_length=max_inference_length,
          mode=mode))
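
For intuition, with kernel_size=1 the per-module computation amounts to a
block-diagonal (grouped) projection: each of the n_modules groups of
d_feature // n_modules dimensions gets its own small kernel. A rough NumPy
sketch of that special case (the real LocallyConvDense additionally mixes
neighbouring modules when kernel_size > 1):

import numpy as np

batch, seq_len, d_feature, n_modules = 2, 5, 8, 4
d_module = d_feature // n_modules       # requires d_feature % n_modules == 0
x = np.random.randn(batch, seq_len, d_feature)
kernels = np.random.randn(n_modules, d_module, d_module)  # one kernel per module

x_grouped = x.reshape(batch, seq_len, n_modules, d_module)
out = np.einsum('bsnd,nde->bsne', x_grouped, kernels)
out = out.reshape(batch, seq_len, d_feature)

# Parameters: n_modules * d_module**2 = 16, versus d_feature**2 = 64 for the
# full Dense projection this replaces.
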
Example #5
def Favor(d_feature,
          n_heads=1,
          dropout=0.0,
          numerical_stabilizer=0.001,
          mode='train'):
    """Returns a layer that maps (activations, mask) to (new_activations, mask).

    See the FAVOR paper for details: https://arxiv.org/abs/2006.03555

    Args:
      d_feature: Depth/dimensionality of feature embedding.
      n_heads: Number of attention heads.
      dropout: Probabilistic rate for internal dropout applied to attention
          activations (based on query-key pairs) before dotting them with values.
      numerical_stabilizer: float, small number used for numerical stability.
      mode: One of `'train'`, `'eval'`, or `'predict'`.
    """
    del dropout, mode  # not implemented yet but needed in the API

    def bidirectional_numerator(query_prime, key_prime, value):
        kvs = jnp.einsum('lbm,lbd->bmd', key_prime, value)
        return jnp.einsum('lbm,bmd->lbd', query_prime, kvs)

    def bidirectional_denominator(query_prime, key_prime):
        all_ones = jnp.ones([query_prime.shape[0]])
        ks_sum = jnp.einsum('lbm,l->bm', key_prime, all_ones)
        return jnp.einsum('lbm,bm->lb', query_prime, ks_sum)

    def relu(x):
        return jnp.where(x <= 0, jnp.zeros_like(x), x)

    def favor(query, key, value, mask):
        query_prime = relu(query) + numerical_stabilizer
        key_prime = relu(key) + numerical_stabilizer
        mask_batch_1_length = jnp.reshape(
            mask,
            [key.shape[0] // n_heads, 1, key.shape[1]]).astype(jnp.float32)
        mask_heads = mask_batch_1_length + jnp.zeros((1, n_heads, 1))
        key_prime *= jnp.reshape(mask_heads, [key.shape[0], key.shape[1], 1])

        w = bidirectional_numerator(jnp.moveaxis(query_prime, 1, 0),
                                    jnp.moveaxis(key_prime, 1, 0),
                                    jnp.moveaxis(value, 1, 0))
        r = bidirectional_denominator(jnp.moveaxis(query_prime, 1, 0),
                                      jnp.moveaxis(key_prime, 1, 0))
        w = jnp.moveaxis(w, 0, 1)
        r = jnp.moveaxis(r, 0, 1)
        r = jnp.reciprocal(r)
        r = jnp.expand_dims(r, len(r.shape))
        renormalized_attention = w * r
        return renormalized_attention, mask

    return tl.ConfigurableAttention(tl.Dense(d_feature),
                                    tl.Dense(d_feature),
                                    tl.Dense(d_feature),
                                    tl.Dense(d_feature),
                                    tl.Fn('FAVOR', favor, n_out=2),
                                    n_heads=n_heads)
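
The bidirectional_numerator and bidirectional_denominator functions compute
ordinary (non-causal) attention with the softmax kernel replaced by the feature
map relu(.) + numerical_stabilizer, in time linear in the sequence length by
re-associating the matrix products. A small NumPy check of that identity
(illustration only; single head, no masking):

import numpy as np

seq_len, d_k, d_v = 6, 4, 3
q_prime = np.abs(np.random.randn(seq_len, d_k)) + 1e-3  # stands in for relu(q) + eps
k_prime = np.abs(np.random.randn(seq_len, d_k)) + 1e-3  # stands in for relu(k) + eps
v = np.random.randn(seq_len, d_v)

# Quadratic form: build the full L x L attention matrix, then normalize rows.
scores = q_prime @ k_prime.T
direct = (scores @ v) / scores.sum(axis=1, keepdims=True)

# FAVOR form: numerator Q'(K'^T V) and denominator Q'(K'^T 1), no L x L matrix.
numerator = q_prime @ (k_prime.T @ v)
denominator = q_prime @ k_prime.sum(axis=0)
favor_out = numerator / denominator[:, None]

assert np.allclose(direct, favor_out)
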
Example #6
def MultiplicativeModularCausalAttention(  # pylint: disable=invalid-name
        d_feature,
        sparsity=1,
        n_heads=1,
        dropout=0.0,
        max_inference_length=2048,
        kernel_size=1,
        mode='train'):
    """Returns a layer that maps activations to activations, with causal masking.

    Like `CausalAttention`, this layer type represents one pass of multi-head
    self-attention with causal masking rather than padding-based masking.
    However, instead of a plain Dense layer, it computes Q/K/V by combining a
    MultiplicativeSparseDense layer with a LocallyConnectedDense layer.

    Args:
      d_feature: Depth/dimensionality of feature embedding.
      sparsity: The sparsity of the layer; usually it should be equal to n_heads.
      n_heads: Number of attention heads.
      dropout: Probabilistic rate for internal dropout applied to attention
          activations (based on query-key pairs) before dotting them with values.
      max_inference_length: Maximum length for inference.
      kernel_size: Kernel size used in LocallyConnectedDense.
      mode: One of `'train'`, `'eval'`, or `'predict'`.
    """
    assert d_feature % sparsity == 0

    @assert_shape('...a->...a')
    def ProcessingLayer():  # pylint: disable=invalid-name
        return tl.Serial(
            MultiplicativeSparseDense(sparsity, d_feature, d_feature),
            LocallyConnectedDense(sparsity,
                                  d_feature // sparsity,
                                  kernel_size=kernel_size))

    return tl.ConfigurableAttention(
        ProcessingLayer(),
        ProcessingLayer(),
        ProcessingLayer(),
        ProcessingLayer(),
        n_heads=n_heads,
        qkv_attention_layer=tl.DotProductCausalAttention(
            dropout=dropout,
            max_inference_length=max_inference_length,
            mode=mode))
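
Like the other causal variants above, the layer is typically built in 'predict'
mode for autoregressive decoding, where the underlying DotProductCausalAttention
caches keys and values up to max_inference_length. A construction sketch (the
parameter values below are arbitrary; sparsity must divide d_feature):

attn = MultiplicativeModularCausalAttention(
    d_feature=512, sparsity=8, n_heads=8, kernel_size=3,
    max_inference_length=1024, mode='predict')
# In 'predict' mode tokens can then be fed one position at a time, with the
# key/value cache advancing up to max_inference_length positions.
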
Example #7
def CausalFavor(d_feature,
                n_heads=1,
                dropout=0.0,
                numerical_stabilizer=0.001,
                precision=None,
                mode='train'):
    """Returns a layer that maps activations to activations, with causal masking.

    Like `CausalAttention`, this layer type represents one pass of multi-head
    causal attention, but using FAVOR fast attention as in the following paper:
    https://arxiv.org/abs/2006.03555

    Args:
      d_feature: Depth/dimensionality of feature embedding.
      n_heads: Number of attention heads.
      dropout: Probabilistic rate for internal dropout applied to attention
          activations (based on query-key pairs) before dotting them with values.
      numerical_stabilizer: float, small number used for numerical stability.
      precision: Passed to jnp.einsum to define arithmetic precision.
      mode: One of `'train'`, `'eval'`, or `'predict'`.
    """
    del dropout, mode  # not implemented yet but needed in the API

    def favor_numerator_fwd(init_prefix_sum_value, precision, query_prime,
                            key_prime, value):
        def body(p, qkv):
            (q, k, v) = qkv
            p += jnp.einsum('...m,...d->...md', k, v, precision=precision)
            x_slice = jnp.einsum('...m,...md->...d', q, p, precision=precision)
            return p, x_slice

        p, w = fastmath.scan(body, init_prefix_sum_value,
                             (query_prime, key_prime, value))
        return w, (precision, p, query_prime, key_prime, value)

    def favor_numerator_bwd(pqkv, w_ct):
        precision, p, qs, ks, vs = pqkv

        def body(carry, qkv_xct):
            p, p_ct = carry
            q, k, v, x_ct = qkv_xct
            q_ct = jnp.einsum('...d,...md->...m', x_ct, p, precision=precision)
            p_ct += jnp.einsum('...d,...m->...md',
                               x_ct,
                               q,
                               precision=precision)
            k_ct = jnp.einsum('...md,...d->...m', p_ct, v, precision=precision)
            v_ct = jnp.einsum('...md,...m->...d', p_ct, k, precision=precision)
            p -= jnp.einsum('...m,...d->...md', k, v, precision=precision)
            return (p, p_ct), (q_ct, k_ct, v_ct)

        _, (qs_ct, ks_ct, vs_ct) = fastmath.scan(body, (p, jnp.zeros_like(p)),
                                                 (qs, ks, vs, w_ct),
                                                 reverse=True)
        return (None, None, qs_ct, ks_ct, vs_ct)

    def favor_numerator(init_prefix_sum_value, precision, query_prime,
                        key_prime, value):
        w, _ = favor_numerator_fwd(init_prefix_sum_value, precision,
                                   query_prime, key_prime, value)
        return w

    favor_numerator = fastmath.custom_vjp(favor_numerator, favor_numerator_fwd,
                                          favor_numerator_bwd)

    def favor_denominator_fwd(init_prefix_sum_value, precision, query_prime,
                              key_prime):
        def body(p, qk):
            q, k = qk
            p += k
            x = jnp.einsum('...m,...m->...', q, p, precision=precision)
            return p, x

        p, r = fastmath.scan(body, init_prefix_sum_value,
                             (query_prime, key_prime))
        return r, (precision, query_prime, key_prime, p)

    def favor_denominator_bwd(qkp, r_ct):
        precision, qs, ks, p = qkp

        def body(carry, qkx):
            p, p_ct = carry
            q, k, x_ct = qkx
            q_ct = jnp.einsum('...,...m->...m', x_ct, p, precision=precision)
            p_ct += jnp.einsum('...,...m->...m', x_ct, q, precision=precision)
            k_ct = p_ct
            p -= k
            return (p, p_ct), (q_ct, k_ct)

        _, (qs_ct, ks_ct) = fastmath.scan(body, (p, jnp.zeros_like(p)),
                                          (qs, ks, r_ct),
                                          reverse=True)
        return (None, None, qs_ct, ks_ct)

    def favor_denominator(init_prefix_sum_value, precision, query_prime,
                          key_prime):
        r, _ = favor_denominator_fwd(init_prefix_sum_value, precision,
                                     query_prime, key_prime)
        return r

    favor_denominator = fastmath.custom_vjp(favor_denominator,
                                            favor_denominator_fwd,
                                            favor_denominator_bwd)

    favor_denominator.defvjp(favor_denominator_fwd, favor_denominator_bwd)

    def relu(x):
        return jnp.where(x <= 0, jnp.zeros_like(x), x)

    def favor(query, key, value):
        query_prime = relu(query) + numerical_stabilizer
        key_prime = relu(key) + numerical_stabilizer
        prefix_sum_tensor_shape = (key.shape[0], key.shape[-1],
                                   value.shape[-1])
        t_slice_shape = (key.shape[0], key.shape[-1])
        init_prefix_sum_value_numerator = jnp.zeros(prefix_sum_tensor_shape)
        init_prefix_sum_value_denominator = jnp.zeros(t_slice_shape)

        w = favor_numerator(init_prefix_sum_value_numerator, precision,
                            jnp.moveaxis(query_prime, 1, 0),
                            jnp.moveaxis(key_prime, 1, 0),
                            jnp.moveaxis(value, 1, 0))
        r = favor_denominator(init_prefix_sum_value_denominator, precision,
                              jnp.moveaxis(query_prime, 1, 0),
                              jnp.moveaxis(key_prime, 1, 0))
        w = jnp.moveaxis(w, 0, 1)
        r = jnp.moveaxis(r, 0, 1)
        r = jnp.reciprocal(r)
        r = jnp.expand_dims(r, len(r.shape))
        renormalized_attention = w * r
        return renormalized_attention

    return tl.ConfigurableAttention(core.Dense(d_feature),
                                    core.Dense(d_feature),
                                    core.Dense(d_feature),
                                    core.Dense(d_feature),
                                    n_heads=n_heads,
                                    qkv_attention_layer=base.Fn(
                                        'CausalFAVOR', favor))
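
The two custom-vjp scans above implement causal linear attention as running
prefix sums: at position l the numerator is q'_l dot (sum over m <= l of
k'_m v_m^T) and the denominator is q'_l dot (sum over m <= l of k'_m), so the
L x L attention matrix is never materialized. A NumPy sketch of that identity
(single head, for intuition only):

import numpy as np

seq_len, d_k, d_v = 6, 4, 3
q_prime = np.abs(np.random.randn(seq_len, d_k)) + 1e-3
k_prime = np.abs(np.random.randn(seq_len, d_k)) + 1e-3
v = np.random.randn(seq_len, d_v)

# Quadratic reference: causal (lower-triangular) masking of the full matrix.
scores = np.tril(q_prime @ k_prime.T)
reference = (scores @ v) / scores.sum(axis=1, keepdims=True)

# Prefix-sum form: one pass over positions with a constant-size running state,
# mirroring the `body` functions consumed by fastmath.scan above.
kv_sum = np.zeros((d_k, d_v))   # plays the role of the numerator's `p`
k_sum = np.zeros(d_k)           # plays the role of the denominator's `p`
out = np.zeros((seq_len, d_v))
for l in range(seq_len):
    kv_sum += np.outer(k_prime[l], v[l])
    k_sum += k_prime[l]
    out[l] = (q_prime[l] @ kv_sum) / (q_prime[l] @ k_sum)

assert np.allclose(reference, out)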