Example #1
def CausalAttention(d_feature,
                    n_heads=1,
                    dropout=0.0,
                    max_inference_length=2048,
                    mode='train'):
    """Returns a layer that maps activations to activations, with causal masking.

  Like ``Attention``, this layer type represents one pass of multi-head
  self-attention, but with causal masking rather than padding-based masking.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    max_inference_length: Maximum length for inference.
    mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
  """
    if d_feature % n_heads != 0:
        raise ValueError(
            f'Dimensionality of feature embedding ({d_feature}) is not a multiple '
            f'of the requested number of attention heads ({n_heads}).')

    return ConfigurableAttention(core.Dense(d_feature),
                                 core.Dense(d_feature),
                                 core.Dense(d_feature),
                                 core.Dense(d_feature),
                                 n_heads=n_heads,
                                 qkv_attention_layer=DotProductCausalAttention(
                                     dropout=dropout,
                                     max_inference_length=max_inference_length,
                                     mode=mode))
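
A minimal usage sketch for a layer like the one above, assuming a standard Trax installation (the aliases `trax.layers as tl` and `trax.shapes`, and all sizes below, are illustrative rather than taken from the example):

import numpy as np
import trax.layers as tl
from trax import shapes

# d_feature must be divisible by n_heads, otherwise the ValueError above is raised.
layer = tl.CausalAttention(d_feature=32, n_heads=4, mode='eval')

x = np.random.rand(2, 10, 32).astype(np.float32)  # (batch, seq_len, d_feature)
layer.init(shapes.signature(x))                   # initialize weights for this input shape
y = layer(x)                                      # causally masked multi-head self-attention
print(y.shape)                                    # expected: (2, 10, 32)
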
Example #2
 def QKVLayer():
     """Function returning the Q, K and V layer."""
     if use_dconv:
         return cb.Serial(core.Dense(d_feature),
                          convolution.CausalDepthwiseConv())
     else:
         return core.Dense(d_feature)
Example #3
def AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train'):
    """Transformer-style multi-headed attention.

  Accepts inputs of the form q, k, v, mask.

  Args:
    d_feature: int:  dimensionality of feature embedding
    n_heads: int: number of attention heads
    dropout: float: dropout rate
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention result and the mask.
  """
    return cb.Serial(
        cb.Parallel(
            core.Dense(d_feature),
            core.Dense(d_feature),
            core.Dense(d_feature),
        ),
        PureAttention(  # pylint: disable=no-value-for-parameter
            n_heads=n_heads,
            dropout=dropout,
            mode=mode),
        core.Dense(d_feature),
    )
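
To make the (q, k, v, mask) contract concrete, here is an illustrative single-head NumPy sketch of masked scaled dot-product attention; it is not Trax code, and the boolean padding-mask convention (True = real token) is an assumption for this example:

import numpy as np

def masked_attention(q, k, v, mask):
    """q, k, v: (seq_len, d_head); mask: (seq_len,) bool, True for real tokens."""
    d_head = q.shape[-1]
    scores = q @ k.T / np.sqrt(d_head)                 # query-key similarity scores
    scores = np.where(mask[None, :], scores, -1e9)     # hide padded key positions
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)     # softmax over keys
    return weights @ v                                 # weighted sum of value vectors

q = k = v = np.random.rand(4, 8)
mask = np.array([True, True, True, False])             # last position is padding
print(masked_attention(q, k, v, mask).shape)           # (4, 8)
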
Example #4
def CausalAttention(d_feature, n_heads=1, dropout=0.0,
                    max_inference_length=2048, mode='train'):
  """Returns a layer that maps activations to activations, with causal masking.

  Like :py:class:`Attention`, this layer type represents one pass of multi-head
  self-attention, but with causal masking rather than padding-based masking.

  Args:
    d_feature: Last/innermost dimension of activations in the input to and
        output from this layer.
    n_heads: Number of attention heads. Attention heads effectively split
        activation vectors into ``n_heads`` subvectors, of size
        ``d_feature / n_heads``.
    dropout: Probabilistic rate for attention dropout, which overrides
        (sets to zero) some attention strengths derived from query-key
        matching. As a result, on a given forward pass, some value vectors
        don't contribute to the output, analogous to how regular dropout can
        cause some node activations to be ignored. Applies only if layer is
        created in ``'train'`` mode.
    max_inference_length: Maximum sequence length allowed in non-training
        modes.
    mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
  """
  if d_feature % n_heads != 0:
    raise ValueError(
        f'Dimensionality of feature embedding ({d_feature}) is not a multiple '
        f'of the requested number of attention heads ({n_heads}).')

  return ConfigurableAttention(
      core.Dense(d_feature), core.Dense(d_feature), core.Dense(d_feature),
      core.Dense(d_feature), n_heads=n_heads,
      qkv_attention_layer=DotProductCausalAttention(
          dropout=dropout, max_inference_length=max_inference_length,
          mode=mode))
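
To illustrate "causal masking rather than padding-based masking", the sketch below builds the lower-triangular mask such a layer applies to attention scores, so position i can only attend to positions <= i (plain NumPy, for exposition only):

import numpy as np

seq_len = 5
causal_mask = np.tril(np.ones((seq_len, seq_len), dtype=bool))  # True where attention is allowed
print(causal_mask.astype(int))
# Row i has ones only in columns 0..i: each query position attends to itself
# and earlier positions, never to the future.

scores = np.random.rand(seq_len, seq_len)
masked_scores = np.where(causal_mask, scores, -1e9)  # future positions get ~ -inf before softmax
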
Example #5
def CausalAttention(d_feature, n_heads=1, dropout=0.0,
                    max_inference_length=2048, mode='train'):
  """Returns a layer that maps activations to activations, with causal masking.

  Like `Attention`, this layer type represents one pass of multi-head
  self-attention, but with causal masking rather than padding-based masking.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    max_inference_length: Maximum length for inference.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  if d_feature % n_heads != 0:
    raise ValueError(
        f'Dimensionality of feature embedding ({d_feature}) is not a multiple '
        f'of the requested number of attention heads ({n_heads}).')

  d_head = d_feature // n_heads

  def _split_into_heads():
    """Returns a layer that reshapes tensors for multi-headed computation."""
    def f(x):
      batch_size = x.shape[0]
      seq_len = x.shape[1]

      # (b_size, seq_len, d_feature) --> (b_size*n_heads, seq_len, d_head)
      x = x.reshape((batch_size, seq_len, n_heads, d_head))
      x = x.transpose((0, 2, 1, 3))
      x = x.reshape((-1, seq_len, d_head))
      return x
    return Fn('SplitIntoHeads', f)

  def _merge_heads():
    """Returns a layer that undoes splitting, after multi-head computation."""
    def f(x):
      seq_len = x.shape[1]

      # (b_size*n_heads, seq_len, d_head) --> (b_size, seq_len, d_feature)
      x = x.reshape((-1, n_heads, seq_len, d_head))
      x = x.transpose((0, 2, 1, 3))
      x = x.reshape((-1, seq_len, n_heads * d_head))
      return x
    return Fn('MergeHeads', f)

  return cb.Serial(
      cb.Branch(
          [core.Dense(d_feature), _split_into_heads()],
          [core.Dense(d_feature), _split_into_heads()],
          [core.Dense(d_feature), _split_into_heads()],
      ),
      DotProductCausalAttention(
          dropout=dropout, max_inference_length=max_inference_length,
          mode=mode),
      _merge_heads(),
      core.Dense(d_feature),
  )
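
The reshapes in _split_into_heads and _merge_heads above are pure bookkeeping; this NumPy check (with made-up sizes) confirms that merging exactly undoes splitting:

import numpy as np

batch, seq_len, n_heads, d_head = 2, 7, 4, 8
x = np.random.rand(batch, seq_len, n_heads * d_head)

# Split: (batch, seq_len, d_feature) --> (batch*n_heads, seq_len, d_head)
split = x.reshape(batch, seq_len, n_heads, d_head)
split = split.transpose(0, 2, 1, 3).reshape(-1, seq_len, d_head)

# Merge: (batch*n_heads, seq_len, d_head) --> (batch, seq_len, d_feature)
merged = split.reshape(-1, n_heads, seq_len, d_head)
merged = merged.transpose(0, 2, 1, 3).reshape(-1, seq_len, n_heads * d_head)

assert merged.shape == x.shape and np.allclose(merged, x)  # exact inverse
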
Example #6
def RelativeAttentionLayer(d_feature,
                           context_bias_layer,
                           location_bias_layer,
                           total_kv_pooling,
                           separate_cls,
                           n_heads=1,
                           dropout=0.0,
                           mode='train'):
    """Returns a layer that maps (q, k, v, masks) to (activations, masks).

  When the number of keys is smaller than the number of queries, the layer
  works in O(q^2*d); otherwise it is O(q*k*d). That is because relative
  distances have to be shifted by the current pooling factor, and when we
  upsample this pooling factor is a fraction < 1.

  Visual explanation:
  [01][23][45][67] -> [0][1][2][3][4][5][6][7]
  For token [0] the relative distances to the keys are:
  * 0 2 4 6
  For token [1], however, every distance is shifted by 1:
  * -1 1 3 5
  So we need to compute not only the distances corresponding to the spacing
  between keys, but also those in between, because a single key token is
  attended to by multiple query tokens at different positions (and hence with
  different relative distances).

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    context_bias_layer: Global context bias from Transformer XL's attention.
      There should be one such layer shared for all relative attention layers.
    location_bias_layer: Global location bias from Transformer XL's attention.
      There should be one such layer shared for all relative attention layers.
    total_kv_pooling: Accumulated pool size of keys/values used at this layer.
    separate_cls: True/False, whether the `cls` token is treated separately in
        the calculations.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """

    return cb.Serial(
        cb.Branch(
            PositionalEmbeddings(d_feature, separate_cls, total_kv_pooling),
            cb.Select([0]), cb.Select([1])),
        cb.Parallel(
            core.Dense(d_feature),
            core.Dense(d_feature),
            core.Dense(d_feature),
            core.Dense(d_feature),
        ),
        context_bias_layer,
        location_bias_layer,
        RelativeAttention(  # pylint: disable=no-value-for-parameter
            separate_cls=separate_cls,
            n_heads=n_heads,
            dropout=dropout,
            mode=mode),
        core.Dense(d_feature),
    )
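
The worked example in the docstring (distances 0 2 4 6 for token [0] vs. -1 1 3 5 for token [1]) can be reproduced with a short NumPy sketch; the formula dist = key_position * pooling - query_position is an assumption made only for this illustration:

import numpy as np

total_kv_pooling = 2         # each key covers 2 query positions after upsampling
n_queries, n_keys = 8, 4     # queries [0..7] vs. pooled keys [01][23][45][67]

query_pos = np.arange(n_queries)[:, None]                 # (8, 1)
key_pos = np.arange(n_keys)[None, :] * total_kv_pooling   # (1, 4): 0, 2, 4, 6

rel_dist = key_pos - query_pos                            # (8, 4) relative distances
print(rel_dist[0])   # [ 0  2  4  6]  -- token [0]
print(rel_dist[1])   # [-1  1  3  5]  -- token [1], every distance shifted by 1
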
Example #7
def FeedForwardBlock(d_model, d_ff, dropout, dropout_shared_axes, mode,
                     activation):
    # We copy the ff block function because we cannot import it from models
    return [
        core.Dense(d_ff),
        activation(),
        core.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
        core.Dense(d_model),
    ]
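
Functionally, the block above widens activations from d_model to d_ff, applies the nonlinearity (plus dropout in training), and projects back to d_model; a plain NumPy sketch of that data flow, with illustrative shapes and random weights:

import numpy as np

d_model, d_ff = 16, 64
x = np.random.rand(2, 10, d_model)                # (batch, seq_len, d_model)

w1, b1 = np.random.rand(d_model, d_ff), np.zeros(d_ff)
w2, b2 = np.random.rand(d_ff, d_model), np.zeros(d_model)

h = np.maximum(x @ w1 + b1, 0.0)                  # Dense(d_ff) followed by ReLU
y = h @ w2 + b2                                   # Dense(d_model); dropout omitted here
print(y.shape)                                    # (2, 10, 16)
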
Example #8
  def test_dense_param_sharing(self):
    model1 = combinators.Serial(core.Dense(32), core.Dense(32))
    layer = core.Dense(32)
    model2 = combinators.Serial(layer, layer)

    rng1, rng2 = backend.random.split(backend.random.get_prng(0), 2)
    params1, _ = model1.initialize_once((1, 32), onp.float32, rng1)
    params2, _ = model2.initialize_once((1, 32), onp.float32, rng2)
    # The first parameters have 2 kernels of size (32, 32).
    self.assertEqual((32, 32), params1[0][0].shape)
    self.assertEqual((32, 32), params1[1][0].shape)
    # The second parameters have 1 kernel of size (32, 32) and an empty dict.
    self.assertEqual((32, 32), params2[0][0].shape)
    self.assertEqual((), params2[1])
Example #9
    def test_dense_weight_sharing(self):
        model1 = combinators.Serial(core.Dense(32), core.Dense(32))
        layer = core.Dense(32)
        model2 = combinators.Serial(layer, layer)

        input_signature = ShapeDtype((1, 32))
        weights1, _ = model1.init(input_signature)
        weights2, _ = model2.init(input_signature)
        # The first weights have 2 kernels of size (32, 32).
        self.assertEqual((32, 32), weights1[0][0].shape)
        self.assertEqual((32, 32), weights1[1][0].shape)
        # The second weights have 1 kernel of size (32, 32) and an empty dict.
        self.assertEqual((32, 32), weights2[0][0].shape)
        self.assertEqual((), weights2[1])
Example #10
  def test_dense_param_sharing(self):
    model1 = combinators.Serial(core.Dense(32), core.Dense(32))
    layer = core.Dense(32)
    model2 = combinators.Serial(layer, layer)

    input_signature = ShapeDtype((1, 32))
    params1, _ = model1.initialize_once(input_signature)
    params2, _ = model2.initialize_once(input_signature)
    # The first parameters have 2 kernels of size (32, 32).
    self.assertEqual((32, 32), params1[0][0].shape)
    self.assertEqual((32, 32), params1[1][0].shape)
    # The second parameters have 1 kernel of size (32, 32) and an empty dict.
    self.assertEqual((32, 32), params2[0][0].shape)
    self.assertEqual((), params2[1])
Example #11
def SRU(n_units, activation=None):
    """SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

  As defined in the paper:
  (1) y_t = W x_t (+ B optionally, which we do)
  (2) f_t = sigmoid(Wf x_t + bf)
  (3) r_t = sigmoid(Wr x_t + br)
  (4) c_t = f_t * c_{t-1} + (1 - f_t) * y_t
  (5) h_t = r_t * activation(c_t) + (1 - r_t) * x_t

  We assume the input is of shape [batch, length, depth] and recurrence
  happens on the length dimension. This returns a single layer; the paper
  recommends stacking at least two of them, except inside a Transformer.

  Args:
    n_units: output depth of the SRU layer.
    activation: Optional activation function.

  Returns:
    The SRU layer.
  """
    sigmoid_activation = activation_fns.Sigmoid()
    # pylint: disable=no-value-for-parameter
    return cb.Serial(  # x
        cb.Branch(core.Dense(3 * n_units), []),  # r_f_y, x
        cb.Split(n_items=3),  # r, f, y, x
        cb.Parallel(sigmoid_activation, sigmoid_activation),  # r, f, y, x
        base.Fn(lambda r, f, y: (y * (1.0 - f), f, r)),  # y * (1 - f), f, r, x
        cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
        cb.Scan(InnerSRUCell(), axis=1),
        cb.Select([0], n_in=2),  # act(c), r, x
        activation or [],
        base.Fn(lambda c, r, x: c * r + x * (1 - r)))
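
Equations (1)-(5) from the docstring can also be written as a plain per-timestep NumPy loop; the reference below is only for exposition (random weights, tanh as the activation) and does not mirror the layer's internal stack plumbing:

import numpy as np

def sru_reference(x, W, b, Wf, bf, Wr, br, activation=np.tanh):
    """x: (length, depth); returns h of the same shape."""
    sigmoid = lambda z: 1.0 / (1.0 + np.exp(-z))
    c = np.zeros(x.shape[1])
    hs = []
    for x_t in x:                                      # recurrence over the length dimension
        y_t = W @ x_t + b                              # (1)
        f_t = sigmoid(Wf @ x_t + bf)                   # (2)
        r_t = sigmoid(Wr @ x_t + br)                   # (3)
        c = f_t * c + (1.0 - f_t) * y_t                # (4)
        h_t = r_t * activation(c) + (1.0 - r_t) * x_t  # (5)
        hs.append(h_t)
    return np.stack(hs)

rng = np.random.default_rng(0)
length, depth = 5, 8
x = rng.normal(size=(length, depth))
W, Wf, Wr = (rng.normal(size=(depth, depth)) for _ in range(3))
zeros = np.zeros(depth)
print(sru_reference(x, W, zeros, Wf, zeros, Wr, zeros).shape)  # (5, 8)
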
Example #12
def AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train'):
    """Returns a layer that maps (q, k, v, mask) to (activations, mask).
    Args:
      d_feature: Depth/dimensionality of feature embedding.
      n_heads: Number of attention heads.
      dropout: Probababilistic rate for internal dropout applied to attention
          activations (based on query-key pairs) before dotting them with values.
      mode: Either 'train' or 'eval'.
    """
    return cb.Serial(
        cb.Parallel(
            core.Dense(d_feature),
            core.Dense(d_feature),
            core.Dense(d_feature),
        ),
        PureAttention(n_heads=n_heads, dropout=dropout, mode=mode),
        core.Dense(d_feature),
    )
Example #13
def CausalAttention(d_feature, n_heads=1, dropout=0.0, mode='train'):
  """Transformer-style multi-headed causal attention.

  Args:
    d_feature: int:  dimensionality of feature embedding
    n_heads: int: number of attention heads
    dropout: float: attention dropout
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention result.
  """
  assert d_feature % n_heads == 0
  d_head = d_feature // n_heads

  def compute_attention_heads(x):
    batch_size = x.shape[0]
    seqlen = x.shape[1]
    # n_batch, seqlen, n_heads*d_head -> n_batch, seqlen, n_heads, d_head
    x = jnp.reshape(x, (batch_size, seqlen, n_heads, d_head))
    # n_batch, seqlen, n_heads, d_head -> n_batch, n_heads, seqlen, d_head
    x = jnp.transpose(x, (0, 2, 1, 3))
    # n_batch, n_heads, seqlen, d_head -> n_batch*n_heads, seqlen, d_head
    return jnp.reshape(x, (-1, seqlen, d_head))

  ComputeAttentionHeads = Fn('ComputeAttentionHeads', compute_attention_heads)

  def compute_attention_output(x):
    seqlen = x.shape[1]
    x = jnp.reshape(x, (-1, n_heads, seqlen, d_head))
    x = jnp.transpose(x, (0, 2, 1, 3))  # -> n_batch, seqlen, n_heads, d_head
    return jnp.reshape(x, (-1, seqlen, n_heads * d_head))

  return cb.Serial(
      cb.Branch(
          [core.Dense(d_feature), ComputeAttentionHeads],
          [core.Dense(d_feature), ComputeAttentionHeads],
          [core.Dense(d_feature), ComputeAttentionHeads],
      ),
      DotProductCausalAttention(dropout=dropout, mode=mode),
      Fn('ComputeAttentionOutput', compute_attention_output),
      core.Dense(d_feature)
  )
Example #14
 def _SparsifiableDense(layer_sparsity):
   if layer_sparsity is None:
     return core.Dense(d_feature)
   elif layer_sparsity == 'noop':
     return cb.Serial()  # No-op layer.
   else:
     d_module = d_feature // layer_sparsity
     return cb.Serial(
         sparsity.FactoredDense(layer_sparsity, d_feature, d_feature),
         sparsity.LocallyConvDense(layer_sparsity, d_module, mode=mode,
                                   kernel_size=3, length_kernel_size=3)
     )
Example #15
def AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train'):
  """Returns a layer that maps (q, k, v, mask) to (activations, mask).

  See `Attention` above for further context/details.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  return cb.Serial(
      cb.Parallel(
          core.Dense(d_feature),
          core.Dense(d_feature),
          core.Dense(d_feature),
      ),
      PureAttention(  # pylint: disable=no-value-for-parameter
          n_heads=n_heads, dropout=dropout, mode=mode),
      core.Dense(d_feature),
  )
Example #16
def LinearUpsampling(shorten_factor, d_model, *args, dropout=0.0, mode='train',
                     **kwargs):
  del args, kwargs

  return cb.Serial(
      core.Dense(shorten_factor * d_model),
      core.Dropout(rate=dropout, mode=mode),
      core.Fn(
          'ProlongBack',
          lambda x: jnp.reshape(  # pylint: disable=g-long-lambda
              # Prolong back.
              x, (x.shape[0], x.shape[1] * shorten_factor, -1)),
          n_out=1)
  )
Example #17
def LinearPooling(shorten_factor, d_model, *args, dropout=0.0, mode='train',
                  **kwargs):
  del args, kwargs

  return cb.Serial(
      core.Fn(
          'Shorten',
          lambda x: jnp.reshape(  # pylint: disable=g-long-lambda
              # Shorten -- move to depth.
              x, (x.shape[0], x.shape[1] // shorten_factor, -1)),
          n_out=1),
      core.Dense(d_model),
      core.Dropout(rate=dropout, mode=mode)
  )
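
The 'Shorten' reshape here and the 'ProlongBack' reshape in the previous example are exact inverses of each other (the Dense projections and dropout aside); a NumPy sketch of just the reshape bookkeeping, with illustrative sizes:

import numpy as np

batch, seq_len, d_model, shorten_factor = 2, 12, 16, 4
x = np.random.rand(batch, seq_len, d_model)

# Shorten: fold groups of `shorten_factor` positions into the depth dimension.
short = x.reshape(batch, seq_len // shorten_factor, -1)                 # (2, 3, 64)

# ProlongBack: unfold the depth dimension back into the sequence dimension.
prolonged = short.reshape(batch, short.shape[1] * shorten_factor, -1)   # (2, 12, 16)

assert np.allclose(prolonged, x)  # the two reshapes round-trip exactly
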
Example #18
  def test_set_rng_serial_recurse_two_levels(self):
    dense_00 = core.Dense(2)
    dense_01 = core.Dense(2)
    dense_10 = core.Dense(2)
    dense_11 = core.Dense(2)
    layer = cb.Serial(
        cb.Serial(dense_00, dense_01),
        cb.Serial(dense_10, dense_11),
    )
    input_signature = ShapeDtype((1, 2))

    _, _ = layer.init(input_signature)
    weights = layer.weights
    dense_00_w, dense_00_b = weights[0][0]
    dense_01_w, dense_01_b = weights[0][1]
    dense_10_w, dense_10_b = weights[1][0]
    dense_11_w, dense_11_b = weights[1][1]

    # Setting rng's recursively during init should yield differing weights.
    self.assertFalse(np.array_equal(dense_00_w, dense_01_w))
    self.assertFalse(np.array_equal(dense_00_b, dense_01_b))
    self.assertFalse(np.array_equal(dense_10_w, dense_11_w))
    self.assertFalse(np.array_equal(dense_10_b, dense_11_b))
Example #19
def GRUCell(n_units):
    """Builds a traditional GRU cell with dense internal transformations.

  Gated Recurrent Unit paper: https://arxiv.org/abs/1412.3555

  Args:
    n_units: Number of hidden units.

  Returns:
    A Stax model representing a traditional GRU RNN cell.
  """
    return GeneralGRUCell(candidate_transform=lambda: core.Dense(n_units),
                          memory_transform_fn=None,
                          gate_nonlinearity=core.Sigmoid,
                          candidate_nonlinearity=core.Tanh)
Example #20
def SRU(n_units, activation=None, mode='train'):
    r"""SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

  As defined in the paper:

  .. math::
    y_t &= W x_t + B \quad \hbox{(include $B$ optionally)} \\
    f_t &= \sigma(Wf x_t + bf) \\
    r_t &= \sigma(Wr x_t + br) \\
    c_t &= f_t \times c_{t-1} + (1 - f_t) \times y_t \\
    h_t &= r_t \times \hbox{activation}(c_t) + (1 - r_t) \times x_t

  We assume the input is of shape [batch, length, depth] and recurrence
  happens on the length dimension. This returns a single layer; the paper
  recommends stacking at least two of them, except inside a Transformer.

  Args:
    n_units: output depth of the SRU layer.
    activation: Optional activation function.
    mode: If 'predict', save the previous state for one-by-one inference.

  Returns:
    The SRU layer.
  """
    sigmoid_activation = activation_fns.Sigmoid()
    return cb.Serial(  # x
        cb.Branch(core.Dense(3 * n_units), []),  # r_f_y, x
        cb.Split(n_items=3),  # r, f, y, x
        cb.Parallel(sigmoid_activation, sigmoid_activation),  # r, f, y, x
        base.Fn(
            '',
            lambda r, f, y: (y * (1.0 - f), f, r),  # y * (1 - f), f, r, x
            n_out=3),
        cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
        ScanSRUCell(mode=mode),
        cb.Select([0], n_in=2),  # act(c), r, x
        activation if activation is not None else [],
        base.Fn('FinalSRUGate', lambda c, r, x: c * r + x * (1 - r) *
                (3**0.5)),
        # Set the name to SRU and don't print sublayers.
        name=f'SRU_{n_units}',
        sublayers_to_print=[])
Example #21
def SRU(n_units, activation=None):
    r"""SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

  As defined in the paper:

  .. math::
    y_t &= W x_t + B \quad \hbox{(include $B$ optionally)} \\
    f_t &= \sigma(Wf x_t + bf) \\
    r_t &= \sigma(Wr x_t + br) \\
    c_t &= f_t \times c_{t-1} + (1 - f_t) \times y_t \\
    h_t &= r_t \times \hbox{activation}(c_t) + (1 - r_t) \times x_t

  We assume the input is of shape [batch, length, depth] and recurrence
  happens on the length dimension. This returns a single layer; the paper
  recommends stacking at least two of them, except inside a Transformer.

  Args:
    n_units: output depth of the SRU layer.
    activation: Optional activation function.

  Returns:
    The SRU layer.
  """
    sigmoid_activation = activation_fns.Sigmoid()
    return cb.Serial(  # x
        cb.Branch(core.Dense(3 * n_units), []),  # r_f_y, x
        cb.Split(n_items=3),  # r, f, y, x
        cb.Parallel(sigmoid_activation, sigmoid_activation),  # r, f, y, x
        base.Fn(
            '',
            lambda r, f, y: (y * (1.0 - f), f, r),  # y * (1 - f), f, r, x
            n_out=3),
        cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
        cb.Scan(InnerSRUCell(), axis=1),
        cb.Select([0], n_in=2),  # act(c), r, x
        activation or [],
        base.Fn('FinalSRUGate', lambda c, r, x: c * r + x * (1 - r) *
                (3**0.5)))
Example #22
def SRU(n_units, activation=None, rescale=False, highway_bias=0):
    """SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

  As defined in the paper:
  (1) y_t = W x_t (+ B optionally, which we do)
  (2) f_t = sigmoid(Wf x_t + bf)
  (3) r_t = sigmoid(Wr x_t + br)
  (4) c_t = f_t * c_{t-1} + (1 - f_t) * y_t
  (5) h_t = r_t * activation(c_t) + (1 - r_t) * x_t * alpha

  We assume the input is of shape [batch, length, depth] and recurrence
  happens on the length dimension. This returns a single layer; the paper
  recommends stacking at least two of them, except inside a Transformer.

  Args:
    n_units: output depth of the SRU layer.
    activation: Optional activation function.
    rescale: Whether to apply a scaling correction alpha to h_t. This offsets
        the vanishing-gradient problem caused by the light recurrence and the
        highway computation in deeper layers; alpha is computed as
        (1 + exp(highway_bias) * 2)**0.5 (see https://arxiv.org/abs/1709.02755,
        page 4, section 3.2, Initialization).
    highway_bias: Initial bias of the highway gates.

  Returns:
    The SRU layer.
  """
    # pylint: disable=no-value-for-parameter
    return cb.Serial(  # x
        cb.Branch(core.Dense(3 * n_units), []),  # r_f_y, x
        cb.Split(n_items=3),  # r, f, y, x
        cb.Parallel(core.Sigmoid(), core.Sigmoid()),  # r, f, y, x
        base.Fn(lambda r, f, y: (y * (1.0 - f), f, r)),  # y * (1 - f), f, r, x
        cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
        cb.Scan(InnerSRUCell(), axis=1),
        cb.Select([0], n_in=2),  # act(c), r, x
        activation or [],
        base.Fn(lambda c, r, x: c * r + x * (1 - r) *
                ((1 + np.exp(highway_bias) * 2)**0.5 if rescale else 1)))
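
For reference, the rescale correction reduces to the constant used in the other SRU variants above when highway_bias is 0; a one-line check (illustrative only):

import numpy as np

alpha = (1 + np.exp(0) * 2) ** 0.5   # exp(0) == 1, so alpha == sqrt(3) ~= 1.732
print(alpha, 3 ** 0.5)               # matches the fixed 3**0.5 factor in the other SRU versions
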
Example #23
 def test_weights_serial(self):
     model = cb.Serial(core.Dense(4), core.Dense(5), core.Dense(7))
     self.assertIsInstance(model.weights, tuple)
     self.assertLen(model.weights, 3)
Example #24
 def _CacheableDense():
   if cache_KV_in_predict and mode == 'predict':
     return cb.Cache(core.Dense(d_feature))
   else:
     return core.Dense(d_feature)
Example #25
def AttentionQKV(d_feature,
                 n_heads=1,
                 dropout=0.0,
                 mode='train',
                 cache_KV_in_predict=False,
                 q_sparsity=None,
                 result_sparsity=None):
    """Returns a layer that maps (q, k, v, mask) to (activations, mask).

  See ``Attention`` above for further context/details.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
    cache_KV_in_predict: Whether to cache K/V tensors in predict mode.
    q_sparsity: Sparsity with which to process queries. If None, Dense is
        used. If 'noop' then no processing is used.
    result_sparsity: Sparsity with which to process result of the attention. If
        None, Dense is used. If 'noop' then no processing is used.
  """
    k_processor = core.Dense(d_feature)
    v_processor = core.Dense(d_feature)
    if cache_KV_in_predict and mode == 'predict':
        k_processor = cb.Cache(k_processor)
        v_processor = cb.Cache(v_processor)

    if q_sparsity is None:
        q_processor = core.Dense(d_feature)
    elif q_sparsity == 'noop':
        q_processor = cb.Serial()
    else:
        d_module = d_feature // q_sparsity
        q_processor = cb.Serial(
            sparsity.MultiplicativeSparseDense(q_sparsity, d_feature,
                                               d_feature),
            sparsity.LocallyConvDense(q_sparsity,
                                      d_module,
                                      mode=mode,
                                      kernel_size=3,
                                      length_kernel_size=3))

    if result_sparsity is None:
        result_processor = core.Dense(d_feature)
    elif result_sparsity == 'noop':
        result_processor = cb.Serial()
    else:
        d_module = d_feature // result_sparsity
        result_processor = cb.Serial(
            sparsity.MultiplicativeSparseDense(result_sparsity, d_feature,
                                               d_feature),
            sparsity.LocallyConvDense(result_sparsity,
                                      d_module,
                                      mode=mode,
                                      kernel_size=3,
                                      length_kernel_size=3))

    return cb.Serial(
        cb.Parallel(
            q_processor,
            k_processor,
            v_processor,
        ),
        PureAttention(  # pylint: disable=no-value-for-parameter
            n_heads=n_heads,
            dropout=dropout,
            mode=mode),
        result_processor)
Example #26
 def test_state_parallel(self):
     model = cb.Parallel(core.Dense(3), core.Dense(5))
     self.assertIsInstance(model.state, tuple)
     self.assertLen(model.state, 2)
Example #27
def CausalFavor(d_feature,
                n_heads=1,
                dropout=0.0,
                numerical_stabilizer=0.001,
                precision=None,
                mode='train'):
    """Returns a layer that maps activations to activations, with causal masking.

  Like `CausalAttention`, this layer type represents one pass of multi-head
  causal attention, but using FAVOR fast attention as in the following paper:
  https://arxiv.org/abs/2006.03555

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    numerical_stabilizer: float, small number used for numerical stability.
    precision: passed to jnp.einsum to define arithmetic precision.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
    del dropout, mode  # not implemented yet but needed in the API

    def favor_numerator_fwd(init_prefix_sum_value, precision, query_prime,
                            key_prime, value):
        def body(p, qkv):
            (q, k, v) = qkv
            p += jnp.einsum('...m,...d->...md', k, v, precision=precision)
            x_slice = jnp.einsum('...m,...md->...d', q, p, precision=precision)
            return p, x_slice

        p, w = fastmath.scan(body, init_prefix_sum_value,
                             (query_prime, key_prime, value))
        return w, (precision, p, query_prime, key_prime, value)

    def favor_numerator_bwd(pqkv, w_ct):
        precision, p, qs, ks, vs = pqkv

        def body(carry, qkv_xct):
            p, p_ct = carry
            q, k, v, x_ct = qkv_xct
            q_ct = jnp.einsum('...d,...md->...m', x_ct, p, precision=precision)
            p_ct += jnp.einsum('...d,...m->...md',
                               x_ct,
                               q,
                               precision=precision)
            k_ct = jnp.einsum('...md,...d->...m', p_ct, v, precision=precision)
            v_ct = jnp.einsum('...md,...m->...d', p_ct, k, precision=precision)
            p -= jnp.einsum('...m,...d->...md', k, v, precision=precision)
            return (p, p_ct), (q_ct, k_ct, v_ct)

        _, (qs_ct, ks_ct, vs_ct) = fastmath.scan(body, (p, jnp.zeros_like(p)),
                                                 (qs, ks, vs, w_ct),
                                                 reverse=True)
        return (None, None, qs_ct, ks_ct, vs_ct)

    def favor_numerator(init_prefix_sum_value, precision, query_prime,
                        key_prime, value):
        w, _ = favor_numerator_fwd(init_prefix_sum_value, precision,
                                   query_prime, key_prime, value)
        return w

    favor_numerator = fastmath.custom_vjp(favor_numerator, favor_numerator_fwd,
                                          favor_numerator_bwd)

    def favor_denominator_fwd(init_prefix_sum_value, precision, query_prime,
                              key_prime):
        def body(p, qk):
            q, k = qk
            p += k
            x = jnp.einsum('...m,...m->...', q, p, precision=precision)
            return p, x

        p, r = fastmath.scan(body, init_prefix_sum_value,
                             (query_prime, key_prime))
        return r, (precision, query_prime, key_prime, p)

    def favor_denominator_bwd(qkp, r_ct):
        precision, qs, ks, p = qkp

        def body(carry, qkx):
            p, p_ct = carry
            q, k, x_ct = qkx
            q_ct = jnp.einsum('...,...m->...m', x_ct, p, precision=precision)
            p_ct += jnp.einsum('...,...m->...m', x_ct, q, precision=precision)
            k_ct = p_ct
            p -= k
            return (p, p_ct), (q_ct, k_ct)

        _, (qs_ct, ks_ct) = fastmath.scan(body, (p, jnp.zeros_like(p)),
                                          (qs, ks, r_ct),
                                          reverse=True)
        return (None, None, qs_ct, ks_ct)

    def favor_denominator(init_prefix_sum_value, precision, query_prime,
                          key_prime):
        r, _ = favor_denominator_fwd(init_prefix_sum_value, precision,
                                     query_prime, key_prime)
        return r

    favor_denominator = fastmath.custom_vjp(favor_denominator,
                                            favor_denominator_fwd,
                                            favor_denominator_bwd)

    favor_denominator.defvjp(favor_denominator_fwd, favor_denominator_bwd)

    def relu(x):
        return jnp.where(x <= 0, jnp.zeros_like(x), x)

    def favor(query, key, value):
        query_prime = relu(query) + numerical_stabilizer
        key_prime = relu(key) + numerical_stabilizer
        prefix_sum_tensor_shape = (key.shape[0], key.shape[-1],
                                   value.shape[-1])
        t_slice_shape = (key.shape[0], key.shape[-1])
        init_prefix_sum_value_numerator = jnp.zeros(prefix_sum_tensor_shape)
        init_prefix_sum_value_denominator = jnp.zeros(t_slice_shape)

        w = favor_numerator(init_prefix_sum_value_numerator, precision,
                            jnp.moveaxis(query_prime, 1, 0),
                            jnp.moveaxis(key_prime, 1, 0),
                            jnp.moveaxis(value, 1, 0))
        r = favor_denominator(init_prefix_sum_value_denominator, precision,
                              jnp.moveaxis(query_prime, 1, 0),
                              jnp.moveaxis(key_prime, 1, 0))
        w = jnp.moveaxis(w, 0, 1)
        r = jnp.moveaxis(r, 0, 1)
        r = jnp.reciprocal(r)
        r = jnp.expand_dims(r, len(r.shape))
        renormalized_attention = w * r
        return renormalized_attention

    return tl.ConfigurableAttention(core.Dense(d_feature),
                                    core.Dense(d_feature),
                                    core.Dense(d_feature),
                                    core.Dense(d_feature),
                                    n_heads=n_heads,
                                    qkv_attention_layer=base.Fn(
                                        'CausalFAVOR', favor))
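
Stripped of the custom-VJP machinery, the forward pass of causal FAVOR above is just two running prefix sums; the single-head NumPy sketch below (made-up sizes, ReLU feature map plus the small stabilizer, as in the code) is a reconstruction for exposition only:

import numpy as np

def causal_favor_reference(q, k, v, eps=0.001):
    """q, k: (seq_len, m); v: (seq_len, d). Returns (seq_len, d)."""
    q_prime = np.maximum(q, 0.0) + eps               # relu feature map + numerical stabilizer
    k_prime = np.maximum(k, 0.0) + eps
    num_state = np.zeros((k.shape[1], v.shape[1]))   # running sum of outer(k'_t, v_t)
    den_state = np.zeros(k.shape[1])                 # running sum of k'_t
    out = []
    for q_t, k_t, v_t in zip(q_prime, k_prime, v):
        num_state += np.outer(k_t, v_t)              # prefix sum for the numerator
        den_state += k_t                             # prefix sum for the denominator
        out.append((q_t @ num_state) / (q_t @ den_state))  # renormalized attention output
    return np.stack(out)

rng = np.random.default_rng(0)
seq_len, m, d = 6, 4, 8
out = causal_favor_reference(rng.normal(size=(seq_len, m)),
                             rng.normal(size=(seq_len, m)),
                             rng.normal(size=(seq_len, d)))
print(out.shape)  # (6, 8)
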
Example #28
 def test_state_serial(self):
     model = cb.Serial(core.Dense(4), core.Dense(5), core.Dense(7))
     self.assertIsInstance(model.state, tuple)
     self.assertLen(model.state, 3)
Example #29
 def test_weights_parallel(self):
     model = cb.Parallel(core.Dense(3), core.Dense(5))
     self.assertIsInstance(model.weights, tuple)
     self.assertLen(model.weights, 2)
Example #30
def CausalFavor(
        d_feature,
        n_heads=1,
        dropout=0.0,  # pylint: disable=invalid-name
        numerical_stabilizer=0.001,
        precision=None,
        mode='train'):
    """Returns a layer that maps activations to activations, with causal masking.

  Like `CausalAttention`, this layer type represents one pass of multi-head
  causal attention, but using FAVOR fast attention as in the following paper:
  https://arxiv.org/abs/2006.03555

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    numerical_stabilizer: float, small number used for numerical stability.
    precision: passed to np.einsum to define arithmetic precision.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
    del dropout, mode  # not implemented yet but needed in the API

    # TODO(lukaszkaiser): make an API for split/merge heads in core layers,
    # and use it here so we don't duplicate these functions.
    if d_feature % n_heads != 0:
        raise ValueError(
            f'Dimensionality of feature embedding ({d_feature}) is not a multiple '
            f'of the requested number of attention heads ({n_heads}).')

    d_head = d_feature // n_heads

    def _split_into_heads():
        """Returns a layer that reshapes tensors for multi-headed computation."""
        def f(x):
            batch_size = x.shape[0]
            seq_len = x.shape[1]

            # (b_size, seq_len, d_feature) --> (b_size*n_heads, seq_len, d_head)
            x = x.reshape((batch_size, seq_len, n_heads, d_head))
            x = x.transpose((0, 2, 1, 3))
            x = x.reshape((-1, seq_len, d_head))
            return x

        return base.Fn('SplitIntoHeads', f)

    def _merge_heads():
        """Returns a layer that undoes splitting, after multi-head computation."""
        def f(x):
            seq_len = x.shape[1]

            # (b_size*n_heads, seq_len, d_head) --> (b_size, seq_len, d_feature)
            x = x.reshape((-1, n_heads, seq_len, d_head))
            x = x.transpose((0, 2, 1, 3))
            x = x.reshape((-1, seq_len, n_heads * d_head))
            return x

        return base.Fn('MergeHeads', f)

    def favor_numerator_fwd(init_prefix_sum_value, precision, query_prime,
                            key_prime, value):
        def body(p, qkv):
            (q, k, v) = qkv
            p += np.einsum('...m,...d->...md', k, v, precision=precision)
            x_slice = np.einsum('...m,...md->...d', q, p, precision=precision)
            return p, x_slice

        p, w = fastmath.scan(body, init_prefix_sum_value,
                             (query_prime, key_prime, value))
        return w, (p, query_prime, key_prime, value)

    def favor_numerator_bwd(init_prefix_sum_value, precision, pqkv, w_ct):
        del init_prefix_sum_value

        def body(carry, qkv_xct):
            p, p_ct = carry
            q, k, v, x_ct = qkv_xct
            q_ct = np.einsum('...d,...md->...m', x_ct, p, precision=precision)
            p_ct += np.einsum('...d,...m->...md', x_ct, q, precision=precision)
            k_ct = np.einsum('...md,...d->...m', p_ct, v, precision=precision)
            v_ct = np.einsum('...md,...m->...d', p_ct, k, precision=precision)
            p -= np.einsum('...m,...d->...md', k, v, precision=precision)
            return (p, p_ct), (q_ct, k_ct, v_ct)

        p, qs, ks, vs = pqkv
        _, (qs_ct, ks_ct, vs_ct) = fastmath.scan(body, (p, np.zeros_like(p)),
                                                 (qs, ks, vs, w_ct),
                                                 reverse=True)
        return qs_ct, ks_ct, vs_ct

    def favor_numerator(init_prefix_sum_value, precision, query_prime,
                        key_prime, value):
        w, _ = favor_numerator_fwd(init_prefix_sum_value, precision,
                                   query_prime, key_prime, value)
        return w

    favor_numerator = fastmath.custom_vjp(favor_numerator,
                                          favor_numerator_fwd,
                                          favor_numerator_bwd,
                                          nondiff_argnums=(0, 1))

    def favor_denominator_fwd(init_prefix_sum_value, precision, query_prime,
                              key_prime):
        def body(p, qk):
            q, k = qk
            p += k
            x = np.einsum('...m,...m->...', q, p, precision=precision)
            return p, x

        p, r = fastmath.scan(body, init_prefix_sum_value,
                             (query_prime, key_prime))
        return r, (query_prime, key_prime, p)

    def favor_denominator_bwd(init_prefix_sum_value, precision, qkp, r_ct):
        del init_prefix_sum_value

        def body(carry, qkx):
            p, p_ct = carry
            q, k, x_ct = qkx
            q_ct = np.einsum('...,...m->...m', x_ct, p, precision=precision)
            p_ct += np.einsum('...,...m->...m', x_ct, q, precision=precision)
            k_ct = p_ct
            p -= k
            return (p, p_ct), (q_ct, k_ct)

        qs, ks, p = qkp
        _, (qs_ct, ks_ct) = fastmath.scan(body, (p, np.zeros_like(p)),
                                          (qs, ks, r_ct),
                                          reverse=True)
        return (qs_ct, ks_ct)

    def favor_denominator(init_prefix_sum_value, precision, query_prime,
                          key_prime):
        r, _ = favor_denominator_fwd(init_prefix_sum_value, precision,
                                     query_prime, key_prime)
        return r

    favor_denominator = fastmath.custom_vjp(favor_denominator,
                                            favor_denominator_fwd,
                                            favor_denominator_bwd,
                                            nondiff_argnums=(0, 1))

    favor_denominator.defvjp(favor_denominator_fwd, favor_denominator_bwd)

    def relu(x):
        return np.where(x <= 0, np.zeros_like(x), x)

    def favor(query, key, value):
        query_prime = relu(query) + numerical_stabilizer
        key_prime = relu(key) + numerical_stabilizer
        prefix_sum_tensor_shape = (key.shape[0], key.shape[-1],
                                   value.shape[-1])
        t_slice_shape = (key.shape[0], key.shape[-1])
        init_prefix_sum_value_numerator = np.zeros(prefix_sum_tensor_shape)
        init_prefix_sum_value_denominator = np.zeros(t_slice_shape)

        w = favor_numerator(init_prefix_sum_value_numerator, precision,
                            np.moveaxis(query_prime, 1, 0),
                            np.moveaxis(key_prime, 1, 0),
                            np.moveaxis(value, 1, 0))
        r = favor_denominator(init_prefix_sum_value_denominator, precision,
                              np.moveaxis(query_prime, 1, 0),
                              np.moveaxis(key_prime, 1, 0))
        w = np.moveaxis(w, 0, 1)
        r = np.moveaxis(r, 0, 1)

        r = r + 2 * numerical_stabilizer * (np.abs(r) <= numerical_stabilizer)
        r = np.reciprocal(r)
        r = np.expand_dims(r, len(r.shape))
        renormalized_attention = w * r
        return renormalized_attention

    return cb.Serial(
        cb.Branch(
            [core.Dense(d_feature), _split_into_heads()],
            [core.Dense(d_feature), _split_into_heads()],
            [core.Dense(d_feature), _split_into_heads()],
        ),
        base.Fn('FAVOR', favor),
        _merge_heads(),
        core.Dense(d_feature),
    )