Example #1
def SRU(n_units, activation=None):
    """SRU layer as in https://arxiv.org/abs/1709.02755.

  As defined in the paper:
  (1) y_t = W x_t (+ B optionally, which we do)
  (2) f_t = sigmoid(Wf x_t + bf)
  (3) r_t = sigmoid(Wr x_t + br)
  (4) c_t = f_t * c_{t-1} + (1 - f_t) * y_t
  (5) h_t = r_t * activation(c_t) + (1 - r_t) * x_t

  We assume the input is of shape [batch, length, depth] and recurrence
  happens on the length dimension. This returns a single layer. As the paper
  suggests, it's best to use at least 2 of them, except inside a Transformer.

  Args:
    n_units: output depth of the SRU layer.
    activation: Optional activation function.

  Returns:
    The SRU layer.
  """
    activation = activation or []
    return cb.Serial(
        cb.Dup(),  # x, x
        core.Dense(3 * n_units),
        cb.Split(n_items=3),  # r, f, y, x
        cb.Parallel(core.Sigmoid(), core.Sigmoid()),  # r, f, y, x
        cb.Fn(lambda r, f, y: (y * (1.0 - f), f, r)),  # y * (1 - f), f, r, x
        cb.Parallel([], [], [cb.Dup(), MakeZeroState()]),  # pylint: disable=no-value-for-parameter
        cb.Scan(InnerSRUCell(), axis=1),  # pylint: disable=no-value-for-parameter
        cb.Parallel(activation, cb.Drop()),  # act(c), r, x
        cb.Fn(lambda c, r, x: c * r + x * (1 - r)))
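For orientation, here is a minimal usage sketch of the library's own tl.SRU (assuming the trax package is available; the shapes and n_units value are illustrative). Note that the input depth must equal n_units so that the final combination in equation (5) is well defined.

# A minimal usage sketch, assuming trax is installed; tl.SRU is the
# library layer corresponding to the definition above.
import numpy as np
from trax import layers as tl
from trax import shapes

sru = tl.SRU(n_units=8)                     # a single SRU layer
x = np.zeros((2, 5, 8), dtype=np.float32)   # [batch, length, depth]
sru.init(shapes.signature(x))               # initialize weights for this input
y = sru(x)                                  # y.shape == (2, 5, 8)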
Example #2
 def test_fn_layer_varargs_n_in(self):
     with self.assertRaisesRegexp(ValueError, 'variable arg'):
         cb.Fn(lambda *args: args[0])
     # Check that varargs work when n_in is set.
     id_layer = cb.Fn(lambda *args: args[0], n_in=1)
     input_signature = ShapeDtype((2, 7))
     expected_shape = (2, 7)
     output_shape = base.check_shape_agreement(id_layer, input_signature)
     self.assertEqual(output_shape, expected_shape)
Example #3
def PositionalEmbeddings(d_feature, separate_cls, total_kv_pooling):
  """Positional embedding for relative attention.

  Returns a layer that, based on the queries, the keys, and the accumulated
  pool size of keys/values up to this layer, calculates sinusoidal positional
  embeddings for relative attention calculations.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    separate_cls: Whether the cls token is treated separately in the calculations.
    total_kv_pooling: Accumulated pool size of keys/values until this layer.

  Returns:
    Positional embedding.
  """

  def PositionsVectors(queries, keys):
    is_funnel_layer = queries.shape != keys.shape
    keys_len, queries_len = keys.shape[1], queries.shape[1]
    current_pooling_ratio = keys_len / queries_len

    # Special case of upsampling
    if is_funnel_layer and current_pooling_ratio < 1:
      # We should not be doing standard upsampling when we use separate_cls
      # Cls token is being used for classification
      assert not separate_cls
      assert (total_kv_pooling * keys_len) % queries_len == 0
      multiplier = ((total_kv_pooling * keys_len) // queries_len)
      positions = jnp.arange(-queries_len + 1, queries_len, 1.0) * multiplier
    else:
      positions = jnp.arange(-keys_len + 1, keys_len, 1.0) * total_kv_pooling

    if is_funnel_layer and separate_cls:
      # For pool_size 2 without separating cls we have got
      # [0][1][2][3][4][5][6][7] -> [01][23][45][67]
      # With separating cls we have got
      # [0][1][2][3][4][5][6][7] -> [0][12][34][56]

      # The first group will always consist of one token after pooling,
      # instead of (pool_size) tokens. We need to add a proper offset so
      # that our shift later on in the attention calculation works properly.
      cls_offset = (current_pooling_ratio - 1) * total_kv_pooling
      positions = positions + cls_offset

    return positions

  def Sinusoidal_Embeddings(positions):
    inv_freq = 1 / (10000**(jnp.arange(0.0, d_feature, 2.0) / d_feature))
    sinusoid_freq = jnp.einsum('i,j->ij', positions, inv_freq)
    pos_emb = jnp.concatenate(
        [jnp.sin(sinusoid_freq), jnp.cos(sinusoid_freq)], axis=1)
    return pos_emb

  return cb.Serial(
      cb.Fn('Generate positions vectors', PositionsVectors, n_out=1),
      cb.Fn(
          'Transform to sinusoidal encodings', Sinusoidal_Embeddings, n_out=1))
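For intuition, here is a small stand-alone sketch of the default (non-funnel) branch of PositionsVectors, with plain NumPy standing in for jnp and illustrative sizes: relative positions run from -(keys_len - 1) to keys_len - 1 and are scaled by the accumulated pooling size.

# Sketch of the default branch above (no funnel layer); numpy stands in
# for jnp, and the sizes are illustrative.
import numpy as np

keys_len, total_kv_pooling = 4, 2
positions = np.arange(-keys_len + 1, keys_len, 1.0) * total_kv_pooling
# positions == [-6., -4., -2., 0., 2., 4., 6.]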
Example #4
 def test_fn_layer_difficult_n_out(self):
     with self.assertRaisesRegexp(ValueError, 'n_out'):
         # Determining the output of this layer is hard with dummies.
         cb.Fn(lambda x: np.concatenate([x, x], axis=4))
     # Check that this layer works when n_out is set.
     layer = cb.Fn(lambda x: np.concatenate([x, x], axis=4), n_out=1)
     input_signature = ShapeDtype((2, 1, 2, 2, 3))
     expected_shape = (2, 1, 2, 2, 6)
     output_shape = base.check_shape_agreement(layer, input_signature)
     self.assertEqual(output_shape, expected_shape)
Example #5
def PositionalEmbeddings(d_feature, separate_cls, total_kv_pooling):
    """Positional embeddings.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    separate_cls: Whether the cls token is treated separately in the calculations.
    total_kv_pooling: Accumulated pool size of keys/values until this layer.

  Returns:
    A layer that, based on the queries, the keys, and the accumulated pool
    size of keys/values up to this layer, calculates sinusoidal positional
    embeddings for relative attention calculations.
  """
    def PositionsVectors(queries, keys):
        assert not separate_cls

        keys_len, queries_len = keys.shape[-2], queries.shape[-2]
        funnel_factor, is_upsampling = calc_funnel_ratio(keys_len, queries_len)

        if funnel_factor == 1:
            offset = keys_len - 1
            positions = (jnp.arange(keys_len) - offset) * total_kv_pooling
        else:
            if is_upsampling:
                positions = jnp.arange(-queries_len + 1, queries_len, 1.0)
            else:
                positions = jnp.arange(-keys_len + 1, keys_len,
                                       1.0) * total_kv_pooling

        return positions

    def Sinusoidal_Embeddings(positions):
        inv_freq = 1 / (10000**(jnp.arange(0.0, d_feature, 2.0) / d_feature))
        sinusoid_freq = jnp.einsum('i,j->ij', positions, inv_freq)
        pos_emb = jnp.concatenate(
            [jnp.sin(sinusoid_freq),
             jnp.cos(sinusoid_freq)], axis=1)
        return pos_emb

    return cb.Serial(
        cb.Fn('Generate positions vectors', PositionsVectors, n_out=1),
        cb.Fn('Transform to sinusoidal encodings',
              Sinusoidal_Embeddings,
              n_out=1))
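A quick shape check of Sinusoidal_Embeddings as written above, with plain NumPy standing in for jnp and an illustrative d_feature: each position gets d_feature/2 sine and d_feature/2 cosine components, concatenated into a (num_positions, d_feature) matrix.

# Shape sketch of the sinusoidal encoding above; numpy stands in for jnp
# and d_feature is chosen only for illustration.
import numpy as np

d_feature = 8
positions = np.arange(-3.0, 4.0)                                 # 7 positions
inv_freq = 1 / (10000**(np.arange(0.0, d_feature, 2.0) / d_feature))
sinusoid_freq = np.einsum('i,j->ij', positions, inv_freq)        # (7, 4)
pos_emb = np.concatenate(
    [np.sin(sinusoid_freq), np.cos(sinusoid_freq)], axis=1)      # (7, 8)
assert pos_emb.shape == (len(positions), d_feature)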
Example #6
 def test_fn_layer_example(self):
     layer = cb.Fn(lambda x, y: (x + y, np.concatenate([x, y], axis=0)))
     input_signature = (ShapeDtype((2, 7)), ShapeDtype((2, 7)))
     expected_shape = ((2, 7), (4, 7))
     output_shape = base.check_shape_agreement(layer, input_signature)
     self.assertEqual(output_shape, expected_shape)
     inp = (np.array([2]), np.array([3]))
     x, xs = layer(inp)
     self.assertEqual(int(x), 5)
     self.assertEqual([int(y) for y in xs], [2, 3])
Example #7
def ShiftRightCls(cls_id):
  """Shifts right.

  Returns a layer that shifts input tokens to the right by one position
  and inserts a cls token at the beginning, as in the BERT paper.

  Args:
    cls_id: id of the cls token in the embedding dictionary.

  Returns:
    shift_right layer.
  """

  def shift_right(x):
    pad_widths = [(0, 0)] * len(x.shape)
    pad_widths[1] = (1, 0)
    padded = jnp.pad(
        x, pad_widths, mode='constant', constant_values=x.dtype.type(cls_id))
    return padded[:, :-1]

  return cb.Fn('ShiftRightCls()', shift_right)
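A minimal sketch of what shift_right does to a batch of token ids, with plain NumPy standing in for jnp; the cls id 101 is an arbitrary illustrative value.

# Illustrative input/output for the padding logic above; numpy stands in
# for jnp and cls_id=101 is just an example id.
import numpy as np

cls_id = 101
x = np.array([[7, 8, 9],
              [4, 5, 6]])
pad_widths = [(0, 0)] * len(x.shape)
pad_widths[1] = (1, 0)                      # pad one position on the left
padded = np.pad(x, pad_widths, mode='constant',
                constant_values=x.dtype.type(cls_id))
shifted = padded[:, :-1]                    # drop the last token
# shifted == [[101, 7, 8],
#             [101, 4, 5]]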
Example #8
def CreateAttentionMaskLayer():
    """Creates attention mask layer.

  Returns a layer that, based on the queries, the keys, and the accumulated
  pool size of keys/values up to this layer, calculates the attention mask
  for causal relative attention calculations.

  Takes q, k, v as input and appends the proper mask at the end.
  Causal attention uses masking to prevent a given sequence position from
  attending to positions greater than / following it. This is used, for
  example, when training autoregressive sequence models, or when decoding a
  sequence symbol by symbol.

  Returns:
    an attention mask layer.
  """
    def calculate_mask(queries, keys):
        batch_size = queries.shape[0]
        keys_len, queries_len = keys.shape[-2], queries.shape[-2]
        funnel_factor, is_upsampling = calc_funnel_ratio(keys_len, queries_len)

        return _funnel_mask(batch_size, keys_len, queries_len, funnel_factor,
                            is_upsampling)

    def _funnel_mask(batch_size, keys_len, queries_len, funnel_factor,
                     is_upsampling):
        """Funnel mask.

    Args:
      batch_size: batch size.
      keys_len: keys length.
      queries_len: queries length.
      funnel_factor: funnel factor.
      is_upsampling: True or False.

    Returns:
      funnel mask.

    This function based on keys/queries lengths creates a triangle mask
    that prevents tokens from attending to positions following it.

    If funnel_factor is not equal to 1 due to funnel upsampling or
    downsampling it adjusts created mask for funnel attention
    by repeating each element funnel_factor times.

    This is because after funnel layer one token attends to funnel_factor
    different tokens in downsampling. During upsampling on the other hand
    funnel_factor tokens are attending to single token before upsampling.
    """

        if funnel_factor != 1:
            if not is_upsampling:
                mask = jnp.tril(
                    jnp.ones((queries_len, queries_len), dtype=jnp.bool_))
                mask = jnp.repeat(mask, funnel_factor, axis=-1)
            else:
                mask = jnp.tril(jnp.ones((keys_len, keys_len),
                                         dtype=jnp.bool_))
                mask = jnp.repeat(mask, funnel_factor, axis=-2)
        else:
            mask = jnp.tril(
                jnp.ones((queries_len, queries_len), dtype=jnp.bool_))

        return jnp.repeat(mask[None, None, :, :], batch_size, axis=0)

    return cb.Branch(
        cb.Select([0]), cb.Select([1]), cb.Select([2]),
        cb.Fn('create attention mask layer', calculate_mask, n_out=1))
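To make the downsampling branch of _funnel_mask concrete, here is a small sketch with plain NumPy standing in for jnp and illustrative sizes: with queries_len=2 and funnel_factor=2, each pooled query may attend to funnel_factor original key positions, so the causal triangle is repeated along the key axis.

# Sketch of the downsampling case above (funnel_factor != 1, not
# upsampling); numpy stands in for jnp and the sizes are illustrative.
import numpy as np

queries_len, funnel_factor = 2, 2
mask = np.tril(np.ones((queries_len, queries_len), dtype=np.bool_))
mask = np.repeat(mask, funnel_factor, axis=-1)
# mask == [[ True,  True, False, False],
#          [ True,  True,  True,  True]]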
Example #9
 def test_fn_layer_fails_wrong_f(self):
     with self.assertRaisesRegexp(ValueError, 'default arg'):
         cb.Fn(lambda x, sth=None: x)
     with self.assertRaisesRegexp(ValueError, 'keyword arg'):
         cb.Fn(lambda x, **kwargs: x)