Beispiel #1
0
def AttentionPosition(vec,
                      pos,
                      positions=None,
                      d_model=None,
                      n_heads=8,
                      dropout=0.0,
                      mode='train'):
    """Transformer-style multi-headed attention that also carries positions.

    Queries are computed from `vec` together with `n_heads` learned position
    combinations produced by `LearnedPosOperations`. Keys are computed from
    `vec` and `pos`, and values from `vec` alone via
    `tl.ComputeAttentionHeads`. After causal dot-product attention the heads
    are combined with `CombineHeadsPos` and projected with a final
    `tl.Dense(d_model)`.

    Args:
      vec: Activations used for queries, keys and values.
      pos: Positional input paired with `vec` (fed to `LearnedPosOperations`
          and to the key branch).
      positions: Passed through unchanged to `LearnedPosOperations`.
      d_model: Model depth; required. Used for the Dense projections and,
          divided by `n_heads`, as the per-head depth of the values.
      n_heads: Number of attention heads.
      dropout: Dropout rate passed to `tl.DotProductCausalAttention`.
      mode: One of `'train'`, `'eval'`, or `'predict'`.

    Returns:
      The pair `(x, pos)` produced by the attention stack (presumably the
      updated activations and positions -- see `CombineHeadsPos`).

    Raises:
      ValueError: If `d_model` is not given.
    """
    # d_model defaults to None but is used arithmetically and as a Dense
    # width below, so a call without it can never succeed; fail early with a
    # clear message instead of an opaque TypeError.
    if d_model is None:
        raise ValueError('AttentionPosition requires d_model to be set.')

    new_posns = list(
        LearnedPosOperations(positions=positions, n_combinations=n_heads)
        @ (vec, pos))

    # Queries: vec plus the learned position combinations, one per head.
    hq = tl.Serial(tl.Dense(d_model), CopyPosToHeads(n_heads, tile=False)) @ ([
        vec,
    ] + new_posns)
    # Keys: vec plus the original positions, tiled across heads.
    hk = tl.Serial(tl.Dense(d_model), CopyPosToHeads(n_heads,
                                                     tile=True)) @ (vec, pos)
    # Values: vec split into heads without an extra projection.
    hv = tl.ComputeAttentionHeads(n_heads=n_heads,
                                  d_head=d_model // n_heads) @ vec

    x, pos = tl.Serial(
        tl.DotProductCausalAttention(dropout=dropout, mode=mode),
        CombineHeadsPos(n_heads=n_heads), tl.Dense(d_model)) @ (hq, hk, hv)

    return x, pos
Beispiel #2
0
def MultiplicativeModularCausalAttention(d_feature,
                                         n_heads=1,
                                         sparsity=None,
                                         dropout=0.0,
                                         max_inference_length=2048,
                                         mode='train'):
    """Causal multi-head self-attention with modular multiplicative projections.

    Behaves like `CausalAttention`: one pass of multi-head self-attention that
    is masked causally rather than by padding. Instead of Dense layers, each
    of the four layers handed to `tl.ConfigurableAttention` (the Q/K/V
    projections and the final layer) is a `MultiplicativeModularSparseDense`.

    Args:
      d_feature: Depth/dimensionality of feature embedding.
      n_heads: Number of attention heads.
      sparsity: Sparsity of each projection; when omitted it falls back to
          `n_heads` (usually it should equal `n_heads`).
      dropout: Probabilistic rate for internal dropout applied to attention
          activations (based on query-key pairs) before dotting them with
          values.
      max_inference_length: Maximum length for inference.
      mode: One of `'train'`, `'eval'`, or `'predict'`.

    Returns:
      A `tl.ConfigurableAttention` layer.
    """
    if sparsity is None:
        sparsity = n_heads

    q_layer, k_layer, v_layer, final_layer = [
        MultiplicativeModularSparseDense(sparsity, d_feature)
        for _ in range(4)
    ]
    causal_core = tl.DotProductCausalAttention(
        dropout=dropout,
        max_inference_length=max_inference_length,
        mode=mode)
    return tl.ConfigurableAttention(q_layer, k_layer, v_layer, final_layer,
                                    n_heads=n_heads,
                                    qkv_attention_layer=causal_core)
Beispiel #3
0
def MultiplicativeConvCausalAttention(d_feature,
                                      n_heads=1,
                                      sparsity=None,
                                      length_kernel_size=3,
                                      dropout=0.0,
                                      max_inference_length=2048,
                                      mode='train'):
    """Returns a layer that maps activations to activations, with causal masking.

    Like `CausalAttention`, this layer type represents one pass of multi-head
    self-attention with causal masking rather than padding-based masking.
    However, for computing Q/K/V it combines a MultiplicativeSparseDense layer
    with LocallyConvDense layers instead of using Dense layers:

      * a single MultiplicativeSparseDense output is shared by the query and
        key branches, each of which applies LocallyConvDense (kernel_size=3);
      * the value branch concatenates that output with the original
        activations and applies LocallyConvDense (kernel_size=1).

    No output projection is applied after the heads are merged.

    Args:
      d_feature: Depth/dimensionality of feature embedding.
      n_heads: Number of attention heads.
      sparsity: The sparsity of the layer; defaults to `n_heads` when
          omitted (usually it should be equal to n_heads).
      length_kernel_size: Size of convolution kernel on the length dimension.
      dropout: Probabilistic rate for internal dropout applied to attention
          activations (based on query-key pairs) before dotting them with
          values.
      max_inference_length: maximum length for inference.
      mode: One of `'train'`, `'eval'`, or `'predict'`.

    Returns:
      A `tl.Serial` layer implementing the attention described above.

    Raises:
      ValueError: If `d_feature` is not a multiple of `sparsity`.
    """
    sparsity = n_heads if sparsity is None else sparsity
    # The features are divided into `sparsity` modules of width d_module;
    # reject configurations where that split is not exact.
    if d_feature % sparsity != 0:
        raise ValueError(
            f'Dimensionality of feature embedding ({d_feature}) is not a multiple '
            f'of the requested sparsity ({sparsity}).')
    d_module = d_feature // sparsity

    def _conv_to_heads(kernel_size):
        """Returns [LocallyConvDense, SplitIntoHeads] for one Q/K/V branch."""
        return [
            LocallyConvDense(sparsity,
                             d_module,
                             kernel_size=kernel_size,
                             length_kernel_size=length_kernel_size),
            tl.SplitIntoHeads(n_heads),
        ]

    return tl.Serial(
        tl.Select([0, 0]),  # duplicate activations: stack is (x, x)
        MultiplicativeSparseDense(sparsity, d_feature,
                                  d_feature),  # shared q, k: stack (m, x)
        tl.Select([0, 0, 0]),  # use for q, k, v: stack (m, m, m, x)
        tl.Parallel(
            _conv_to_heads(kernel_size=3),  # q from m
            _conv_to_heads(kernel_size=3),  # k from m
            # v: use permuted (m) and original (x) activations together.
            [tl.Concatenate()] + _conv_to_heads(kernel_size=1),
        ),
        tl.DotProductCausalAttention(dropout=dropout,
                                     max_inference_length=max_inference_length,
                                     mode=mode),
        tl.MergeHeads(n_heads),
    )
Beispiel #4
0
def LowRankCausalAttention(d_feature,
                           n_heads=1,
                           dropout=0.0,
                           max_inference_length=2048,
                           lowrank=64,
                           mode='train'):
    """Causal multi-head self-attention built on low-rank projections.

    Behaves like `CausalAttention`: one pass of multi-head self-attention that
    is masked causally rather than by padding. The difference is that every
    layer handed to `tl.ConfigurableAttention` (the Q/K/V projections and the
    final layer) is a `LowRankDense`, i.e. a low-rank approximation of a Dense
    kernel.

    Args:
      d_feature: Depth/dimensionality of feature embedding.
      n_heads: Number of attention heads.
      dropout: Probabilistic rate for internal dropout applied to attention
          activations (based on query-key pairs) before dotting them with
          values.
      max_inference_length: Maximum length for inference.
      lowrank: Rank of the low-rank approximation.
      mode: One of `'train'`, `'eval'`, or `'predict'`.

    Returns:
      A `tl.ConfigurableAttention` layer.
    """
    projections = [LowRankDense(d_feature, lowrank) for _ in range(4)]
    attention_core = tl.DotProductCausalAttention(
        dropout=dropout,
        max_inference_length=max_inference_length,
        mode=mode)
    return tl.ConfigurableAttention(*projections,
                                    n_heads=n_heads,
                                    qkv_attention_layer=attention_core)
Beispiel #5
0
def MultiplicativeCausalAttention(d_feature, n_heads=1, sparsity=None,
                                  dropout=0.0, max_inference_length=2048,
                                  mode='train'):
  """Causal multi-head self-attention with multiplicative sparse projections.

  Behaves like `CausalAttention`: one pass of multi-head self-attention that is
  masked causally rather than by padding. For Q/K/V it does not use a Dense
  layer. Instead it multiplies each embedding dimension by a scalar specific to
  that dimension and head, then applies the same dense layer to each head.
  Compared with a standard dense Q/K/V projection this uses fewer parameters
  while still being able to express many functions, such as a permutation.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    sparsity: Sparsity of each projection; falls back to `n_heads` when
        omitted (usually it should equal `n_heads`).
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    max_inference_length: Maximum length for inference.
    mode: One of `'train'`, `'eval'`, or `'predict'`.

  Returns:
    A `tl.ConfigurableAttention` layer.
  """
  if sparsity is None:
    sparsity = n_heads

  def _projection():
    return MultiplicativeSparseDense(sparsity, d_feature, d_feature)

  q_proj, k_proj, v_proj, final_proj = [_projection() for _ in range(4)]
  causal_core = tl.DotProductCausalAttention(
      dropout=dropout, max_inference_length=max_inference_length, mode=mode)
  return tl.ConfigurableAttention(
      q_proj, k_proj, v_proj, final_proj,
      n_heads=n_heads, qkv_attention_layer=causal_core)
Beispiel #6
0
def ConvCausalAttention(d_feature, n_heads=1, sparsity=None, dropout=0.0,
                        max_inference_length=2048,
                        kernel_size=1, mode='train'):
  """Returns a layer that maps activations to activations, with causal masking.

  Like `CausalAttention`, this layer type represents one pass of multi-head
  self-attention with causal masking rather than padding-based masking. However,
  it uses LocallyConvDense instead of Dense layer for computing Q/K/V.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    sparsity: Number of modules used in LocallyConvDense; defaults to `n_heads`
        when omitted.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    max_inference_length: maximum length for inference.
    kernel_size: Kernel size used in LocallyConvDense.
    mode: One of `'train'`, `'eval'`, or `'predict'`.

  Returns:
    A `tl.ConfigurableAttention` layer.

  Raises:
    ValueError: If `d_feature` is not a multiple of the number of modules.
  """
  n_modules = n_heads if sparsity is None else sparsity
  # Validated with an exception rather than `assert`: asserts are stripped under
  # `python -O`, which would let a mis-sized configuration reach
  # LocallyConvDense.
  if d_feature % n_modules != 0:
    raise ValueError(
        f'Dimensionality of feature embedding ({d_feature}) is not a multiple '
        f'of the number of modules ({n_modules}).')
  d_module = d_feature // n_modules

  @assert_shape('...a->...b')
  def ProcessingLayer():  # pylint: disable=invalid-name
    """Returns one LocallyConvDense projection."""
    return LocallyConvDense(n_modules, d_module, kernel_size=kernel_size)

  return tl.ConfigurableAttention(
      ProcessingLayer(), ProcessingLayer(), ProcessingLayer(),
      ProcessingLayer(), n_heads=n_heads,
      qkv_attention_layer=tl.DotProductCausalAttention(
          dropout=dropout, max_inference_length=max_inference_length,
          mode=mode))
Beispiel #7
0
def MultiplicativeModularCausalAttention(  # pylint: disable=invalid-name
        d_feature,
        sparsity=1,
        n_heads=1,
        dropout=0.0,
        max_inference_length=2048,
        kernel_size=1,
        mode='train'):
    """Returns a layer that maps activations to activations, with causal masking.

    Like `CausalAttention`, this layer type represents one pass of multi-head
    self-attention with causal masking rather than padding-based masking.
    However, for computing Q/K/V instead of a Dense layer it chains a
    MultiplicativeSparseDense layer with a LocallyConnectedDense layer.

    Args:
      d_feature: Depth/dimensionality of feature embedding.
      sparsity: The sparsity of the layer; usually it should be equal to
          n_heads.
      n_heads: Number of attention heads.
      dropout: Probabilistic rate for internal dropout applied to attention
          activations (based on query-key pairs) before dotting them with
          values.
      max_inference_length: maximum length for inference.
      kernel_size: Kernel size used in LocallyConnectedDense.
      mode: One of `'train'`, `'eval'`, or `'predict'`.

    Returns:
      A `tl.ConfigurableAttention` layer.

    Raises:
      ValueError: If `d_feature` is not a multiple of `sparsity`.
    """
    # Validated with an exception rather than `assert`: asserts are stripped
    # under `python -O`, which would let an invalid configuration through.
    if d_feature % sparsity != 0:
        raise ValueError(
            f'Dimensionality of feature embedding ({d_feature}) is not a multiple '
            f'of the requested sparsity ({sparsity}).')
    d_module = d_feature // sparsity

    @assert_shape('...a->...a')
    def ProcessingLayer():  # pylint: disable=invalid-name
        """Returns MultiplicativeSparseDense followed by LocallyConnectedDense."""
        return tl.Serial(
            MultiplicativeSparseDense(sparsity, d_feature, d_feature),
            LocallyConnectedDense(sparsity,
                                  d_module,
                                  kernel_size=kernel_size))

    return tl.ConfigurableAttention(
        ProcessingLayer(),
        ProcessingLayer(),
        ProcessingLayer(),
        ProcessingLayer(),
        n_heads=n_heads,
        qkv_attention_layer=tl.DotProductCausalAttention(
            dropout=dropout,
            max_inference_length=max_inference_length,
            mode=mode))
Beispiel #8
0
def ModularCausalAttention(  # pylint: disable=invalid-name
        d_feature,
        n_heads=1,
        dropout=0.0,
        max_inference_length=2048,
        n_modules=1,
        mode='train'):
    """Returns a layer that maps activations to activations, with causal masking.

    Like `CausalAttention`, this layer type represents one pass of multi-head
    self-attention with causal masking rather than padding-based masking.
    However, it uses LocallyConnectedDense instead of Dense layer for computing
    K/Q/V and the final projection when `n_modules` > 1. With
    `n_modules == 1` a plain `tl.Dense(d_feature)` is used.

    Args:
      d_feature: Depth/dimensionality of feature embedding.
      n_heads: Number of attention heads.
      dropout: Probabilistic rate for internal dropout applied to attention
          activations (based on query-key pairs) before dotting them with
          values.
      max_inference_length: maximum length for inference.
      n_modules: Number of modules used in LocallyConnectedDense.
      mode: One of `'train'`, `'eval'`, or `'predict'`.

    Returns:
      A `cb.Serial` layer: branch into Q/K/V, causal attention, merge heads,
      final projection.

    Raises:
      ValueError: If `d_feature` is not a multiple of `n_heads` or of
          `n_modules`.
    """
    if d_feature % n_heads != 0:
        raise ValueError(
            f'Dimensionality of feature embedding ({d_feature}) is not a multiple '
            f'of the requested number of attention heads ({n_heads}).')
    # Checked up front with an exception rather than an `assert` inside
    # ProcessingLayer: asserts are stripped under `python -O`.
    if d_feature % n_modules != 0:
        raise ValueError(
            f'Dimensionality of feature embedding ({d_feature}) is not a multiple '
            f'of the requested number of modules ({n_modules}).')

    d_head = d_feature // n_heads

    @assert_shape('bld->hlx')
    def _split_into_heads():
        """Returns a layer that reshapes tensors for multi-headed computation."""
        def f(x):
            batch_size = x.shape[0]
            seq_len = x.shape[1]

            # (b_size, seq_len, d_feature) --> (b_size*n_heads, seq_len, d_head)
            x = x.reshape((batch_size, seq_len, n_heads, d_head))
            x = x.transpose((0, 2, 1, 3))
            x = x.reshape((batch_size * n_heads, seq_len, d_head))
            return x

        return tl.Fn('SplitIntoHeads', f)

    @assert_shape('hlx->bld')
    def _merge_heads():
        """Returns a layer that undoes splitting, after multi-head computation."""
        def f(x):
            seq_len = x.shape[1]

            # (b_size*n_heads, seq_len, d_head) --> (b_size, seq_len, d_feature)
            x = x.reshape((-1, n_heads, seq_len, d_head))
            x = x.transpose((0, 2, 1, 3))
            x = x.reshape((-1, seq_len, d_head * n_heads))
            return x

        return tl.Fn('MergeHeads', f)

    @assert_shape('...a->...b')
    def ProcessingLayer():  # pylint: disable=invalid-name
        """Returns the Dense / LocallyConnectedDense projection layer."""
        if n_modules == 1:
            return tl.Dense(d_feature)
        return LocallyConnectedDense(n_modules, d_feature // n_modules)

    return cb.Serial(
        cb.Branch(
            [ProcessingLayer(), _split_into_heads()],
            [ProcessingLayer(), _split_into_heads()],
            [ProcessingLayer(), _split_into_heads()],
        ),
        tl.DotProductCausalAttention(dropout=dropout,
                                     max_inference_length=max_inference_length,
                                     mode=mode), _merge_heads(),
        ProcessingLayer())