Example #1
 def body(carry, qkx):
     p, p_ct = carry
     q, k, x_ct = qkx
     q_ct = jnp.einsum('...,...m->...m', x_ct, p, precision=precision)
     p_ct += jnp.einsum('...,...m->...m', x_ct, q, precision=precision)
     k_ct = p_ct
     p -= k
     return (p, p_ct), (q_ct, k_ct)
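
A minimal sketch, not part of the original source, of how a backward body like this can be driven. It assumes shapes (length, batch, n_features), a zero initial prefix sum, and that `precision` is defined in scope (in the original it is closed over); `jax.lax.scan` with `reverse=True` stands in for the library's scan helper.

import jax
import jax.numpy as jnp

precision = jax.lax.Precision.DEFAULT   # assumed; the original closes over this value

L, B, M = 6, 2, 4                       # hypothetical length, batch, n_features
qs = jnp.ones((L, B, M))
ks = jnp.ones((L, B, M))
x_ct = jnp.ones((L, B))                 # cotangent of the per-position scalar outputs
p_last = jnp.sum(ks, axis=0)            # final prefix sum saved from the forward pass
_, (q_ct, k_ct) = jax.lax.scan(
    body, (p_last, jnp.zeros_like(p_last)), (qs, ks, x_ct), reverse=True)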
Example #2
 def body(carry, qkv_xct):
   p, p_ct = carry
   q, k, v, x_ct = qkv_xct
   q_ct = jnp.einsum('...d,...md->...m', x_ct, p, precision=precision)
   p_ct += jnp.einsum('...d,...m->...md', x_ct, q, precision=precision)
   k_ct = jnp.einsum('...md,...d->...m', p_ct, v, precision=precision)
   v_ct = jnp.einsum('...md,...m->...d', p_ct, k, precision=precision)
   p -= jnp.einsum('...m,...d->...md', k, v, precision=precision)
   return (p, p_ct), (q_ct, k_ct, v_ct)
Example #3
def DotProductAttention(queries, keys, values, pos_emb, context_bias,
                        location_bias, mask, separate_cls, dropout, mode, rng):
    """Computes new activations via masked attention-weighted sum of values.

  This function is the core of the attention mechanism. It:
    - computes per-head attention weights from per-head `queries` and `keys`,
    - applies `mask` to screen out positions that come from padding tokens,
    - optionally applies dropout to attention weights, and
    - uses attention weights to combine per-head `values` vectors.

  Args:
    queries: Per-head activations representing attention queries.
    keys: Per-head activations representing attention keys.
    values: Per-head activations to be combined by computed attention weights.
    pos_emb: Per-head activations representing positional embeddings.
    context_bias: Global context bias from Transformer XL's attention.
    location_bias: Global location bias from Transformer XL's attention.
    mask: Mask that distinguishes positions with real content vs. padding.
    separate_cls: If True, the `cls` token is handled separately: the location
      part of its attention scores is masked out.
    dropout: Probabilistic rate for dropout applied to attention strengths
      (based on query-key pairs) before applying them to values.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
    rng: Single-use random number generator (JAX PRNG key).

  Returns:
    Per-head activations resulting from masked per-head attention-weighted
    sum of per-head values.
  """
    d_feature = queries.shape[-1]
    keys_len, queries_len = keys.shape[-2], queries.shape[-2]
    funnel_factor, is_upsampling = calc_funnel_ratio(keys_len, queries_len)

    ac = jnp.einsum('bnid,bnjd->bnij', queries + context_bias, keys)
    bd = jnp.einsum('bnid,jnd->bnij', queries + location_bias, pos_emb)

    if mode != 'predict':
        bd = _fast_matrix_shift(bd, funnel_factor, is_upsampling)

    if separate_cls:
        # Masking out location part of attention for cls token
        bd = bd.at[:, :, :, 0].set(0)
        bd = bd.at[:, :, 0, :].set(0)

    dots = (ac + bd) / jnp.sqrt(d_feature)
    if mask is not None:
        dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
    # Softmax.
    dots = jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True))
    if dropout is not None and dropout >= 1.0:
        raise ValueError('Dropout rates must be lower than 1.')
    if dropout is not None and dropout > 0.0 and mode == 'train':
        keep = fastmath.random.bernoulli(rng, 1.0 - dropout, dots.shape)
        dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))
    out = jnp.matmul(dots, values)
    out = out.astype(jnp.float32)
    dots = dots.astype(jnp.float32)
    return out, dots
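
A shape-only sketch of the two score einsums above, with hypothetical sizes; the bias and positional-embedding shapes are assumptions chosen so the broadcasts work. `ac` is the content term and `bd` the location term of Transformer-XL-style relative attention, and both come out as (batch, heads, queries_len, keys_len).

import jax.numpy as jnp

b, n, l, d = 2, 4, 8, 16                  # batch, heads, length, d_head
queries = jnp.ones((b, n, l, d))
keys = jnp.ones((b, n, l, d))
pos_emb = jnp.ones((l, n, d))             # one embedding per relative position and head
context_bias = jnp.zeros((1, n, 1, d))
location_bias = jnp.zeros((1, n, 1, d))

ac = jnp.einsum('bnid,bnjd->bnij', queries + context_bias, keys)
bd = jnp.einsum('bnid,jnd->bnij', queries + location_bias, pos_emb)
assert ac.shape == bd.shape == (b, n, l, l)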
Example #4
def EinsumDense(d_input, d_output, use_bias):
    """Returns a reimplementation of Dense layer, using einsum.

  While this is an equivalent of a Dense layer, it seems to be faster when used
  in decoding if used with bias (see decoding_timing_test.py ).
  This layer can be removed when we understand better the reason for the
  difference in decoding speed.

  Args:
    d_input: Dimensionality of the input tensor.
    d_output: Dimensionality of the output tensor.
    use_bias: Whether to use bias.
  """
    layers = [
        tl.Weights(init.GlorotUniformInitializer(), [d_output, d_input]),
        tl.Fn(
            'EinsumDense',
            (
                lambda kernel, embeds:  # pylint: disable=g-long-lambda
                jnp.einsum('xd,...d->...x', kernel, embeds)))
    ]
    if use_bias:
        layers.extend([
            tl.Weights(init.RandomNormalInitializer(1e-6), [d_output]),
            tl.Add()
        ])
    return tl.Serial(layers)
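
For reference, a self-contained look at the contraction this layer relies on, with made-up sizes: 'xd,...d->...x' applies a [d_output, d_input] kernel to the last axis of the input, which is exactly what a Dense layer does.

import jax.numpy as jnp

kernel = jnp.ones((16, 8))          # [d_output, d_input]
embeds = jnp.ones((2, 5, 8))        # [..., d_input]
out = jnp.einsum('xd,...d->...x', kernel, embeds)
assert out.shape == (2, 5, 16)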
Example #5
 def Sinusoidal_Embeddings(positions):
     inv_freq = 1 / (10000**(jnp.arange(0.0, d_feature, 2.0) / d_feature))
     sinusoid_freq = jnp.einsum('i,j->ij', positions, inv_freq)
     pos_emb = jnp.concatenate(
         [jnp.sin(sinusoid_freq),
          jnp.cos(sinusoid_freq)], axis=1)
     return pos_emb
Example #6
  def _calc_attn_scores(q, k):
    ac = jnp.einsum('bnid,bnjd->bnij', q + context_bias, k)
    bd = jnp.einsum('bnid,jnd->bnij', q + location_bias, pos_emb)

    if mode != 'predict':
      bd = _fast_matrix_shift(bd)

    dots = (ac + bd) / jnp.sqrt(d_feature)
    dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))

    # Softmax.
    dots = jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True))
    if dropout is not None and dropout >= 1.0:
      raise ValueError('Dropout rates must be lower than 1.')
    if dropout is not None and dropout > 0.0 and mode == 'train':
      keep = fastmath.random.bernoulli(rng, 1.0 - dropout, dots.shape)
      dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))

    return dots
Example #7
def MultiplicativeSparseDense(sparsity,
                              d_input,
                              d_output=None,
                              use_bias=True,
                              use_bfloat16=False):
    """Returns a replacement of Dense layer which uses less parameters.

  The layer uses number of modules equal to `sparsity`. It multiplies each
  dimension of the input tensor by a scalar specific to each dimension and each
  module separately; then it applies Dense(d_output/sparsity) to each module.
  Compared to standard dense layer, MultiplicativeSparseDense uses less
  parameters while still being able to express many interesting functions (for
  example a permutation).

  Args:
    sparsity: The sparsity of the layer; the output vector is divided into this
        number of modules.
    d_input: Dimensionality of input tensor.
    d_output: Dimensionality of output tensor; by default equal to d_input.
    use_bias: Whether to use bias.
    use_bfloat16: Whether to use bfloat16 for weights.
  """

    d_output = d_output or d_input
    assert d_output % sparsity == 0
    d_module = d_output // sparsity

    layers = [
        # Weight below is used for per-head preprocessing of an embedding.
        tl.Weights(init.RandomNormalInitializer(stddev=0.5),
                   shape=[sparsity, d_input],
                   use_bfloat16=use_bfloat16),
        # Weight below is dense kernel, shared across heads.
        tl.Weights(init.GlorotUniformInitializer(), [d_input, d_module],
                   use_bfloat16=use_bfloat16),
        # To save memory, the per-head preprocessing and the multiplication by
        # the kernel are done in the same einsum.
        tl.Fn(
            'AttentionEinsum',
            (
                lambda kernel, multiplier, embeds:  # pylint: disable=g-long-lambda
                jnp.einsum('dx,hd,...d->...hx', kernel, multiplier, embeds))),
        MergeLastTwoAxes(),
    ]
    if use_bias:
        layers.extend([
            # Weight below is bias after dense, per-head.
            tl.Weights(init.RandomNormalInitializer(1e-6), [d_output],
                       use_bfloat16=use_bfloat16),
            tl.Add(),
        ])
    return tl.Serial(layers)
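
A shape sketch of the fused einsum above, with hypothetical sizes: 'dx,hd,...d->...hx' multiplies each input dimension by a per-module scalar ('hd') and applies the shared kernel ('dx') in a single contraction, after which MergeLastTwoAxes folds (h, x) back into d_output.

import jax.numpy as jnp

sparsity, d_input, d_module = 4, 8, 2       # so d_output = sparsity * d_module = 8
kernel = jnp.ones((d_input, d_module))      # 'dx'
multiplier = jnp.ones((sparsity, d_input))  # 'hd'
embeds = jnp.ones((3, 5, d_input))          # '...d'
out = jnp.einsum('dx,hd,...d->...hx', kernel, multiplier, embeds)
assert out.shape == (3, 5, sparsity, d_module)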
Example #8
def rotate(x):
    """Rotate function."""
    _, l, d = x.shape
    inv_freq = jnp.exp(jnp.arange(0, d, 2) * -(jnp.log(10000.0) / d))
    positions = jnp.arange(l)
    freqs = jnp.einsum('i,j->ij', positions, inv_freq)
    emb = jnp.concatenate((freqs, freqs), axis=-1)
    cos = jnp.cos(emb)
    sin = jnp.sin(emb)

    def mul(vecs, pos_emb):
        return jnp.einsum('bld,ld->bld', vecs, pos_emb)

    def rotate_half(x):
        x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
        return jnp.concatenate((-x2, x1), axis=x1.ndim - 1)

    return mul(x, cos) + mul(rotate_half(x), sin)
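
A quick sanity check of rotate (rotary position embeddings), assuming `jnp` is jax.numpy and an even feature dimension so rotate_half can split it in two.

import jax.numpy as jnp

x = jnp.ones((2, 10, 16))      # (batch, length, d); d must be even
y = rotate(x)
assert y.shape == x.shape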
Example #9
def Sinusoidal_Embeddings(positions, d_feature):
  """Sinusoidal Embeddings.

  Computes, from a 1-D vector of integer absolute positions, the sinusoidal
  embeddings defined in the paper "Attention Is All You Need" (2017).
  The embeddings have shape (len(positions), d_feature).

  Args:
    positions: a one-dimensional array of positions.
    d_feature: the number of sin-cos features.

  Returns:
    Positional embeddings.
  """
  inv_freq = 1 / (10000**(jnp.arange(0.0, d_feature, 2.0) / d_feature))
  sinusoid_freq = jnp.einsum('i,j->ij', positions, inv_freq)
  pos_emb = jnp.concatenate(
      [jnp.sin(sinusoid_freq), jnp.cos(sinusoid_freq)], axis=1)
  return pos_emb
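
A small usage sketch: six positions with d_feature=8 give a (6, 8) table whose first half along axis 1 holds sines and whose second half holds cosines.

import jax.numpy as jnp

positions = jnp.arange(6, dtype=jnp.float32)
pos_emb = Sinusoidal_Embeddings(positions, d_feature=8)
assert pos_emb.shape == (6, 8)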
Example #10
def MultiplicativeModularSparseDense(sparsity, d_feature):
    """Returns a replacement of Dense layer which uses less parameters.

  The layer uses number of modules equal to `sparsity`. It is a combination of
  multiplicative dense and locally connected dense layers.

  Args:
    sparsity: The sparsity of the layer; the output vector is divided into this
        number of modules.
    d_feature: Dimensionality of input and output tensor.
  """

    assert d_feature % sparsity == 0
    d_module = d_feature // sparsity

    return tl.Serial(
        # Weight below is used for per-head preprocessing of an embedding.
        tl.Weights(init.RandomNormalInitializer(stddev=0.5),
                   shape=[sparsity, d_feature]),
        # Weight below is a kernel of multiplicative dense, shared across heads.
        tl.Weights(init.GlorotUniformInitializer(), [d_feature, d_module]),
        # Weight below is a kernel of modular dense.
        tl.Weights(
            functools.partial(init.GlorotUniformInitializer(),
                              nonreceptive_dims=[0]),
            [sparsity, d_module, d_module]),
        # To save memory, the per-head preprocessing and the multiplication by
        # the kernels are done in a single einsum.
        tl.Fn(
            'SparseDenseEinsum',
            (
                lambda kmod, kmult, multiplier, embeds:  # pylint: disable=g-long-lambda
                jnp.einsum('hxo,dx,hd,...d->...ho', kmod, kmult, multiplier,
                           embeds))),
        MergeLastTwoAxes(),
        # Weight below is bias after dense, per-head.
        tl.Weights(init.RandomNormalInitializer(1e-6), [d_feature]),
        tl.Add(),
    )
Example #11
    def forward(self, x):
        """Executes this layer as part of a forward pass through the model.

    Args:
      x: Tensor of same shape and dtype as the input signature used to
          initialize this layer.

    Returns:
      Tensor of same shape and dtype as the input, except that the final
      dimension is the layer's `filters` value, and the second-to-last
      dimension is shrunk if 'VALID' padding is used with a kernel_size bigger
      than one.
    """
        if self._use_bias:
            if not isinstance(self.weights, (tuple, list)):
                raise ValueError(f'Weights should be a (w, b) tuple or list; '
                                 f'instead got: {self.weights}')
            w, b = self.weights
        else:
            w = self.weights

        linear_results_before_shifting = jnp.einsum('...lp,lkpd->...lkd', x, w)
        # TODO(jaszczur): this could be run after padding for better efficiency

        if self._kernel_size == 1:
            # With kernel size 1 we don't have to split or shift anything.
            linear_result = jnp.squeeze(linear_results_before_shifting,
                                        axis=-2)
        else:
            # We computed a result for every "pixel", but each direction from the
            # receptive field (there are 'self._kernel_size' such directions) must be
            # shifted by a different amount. The easiest way to do it is to split
            # the tensor to 'self._kernel_size' smaller tensors, shift each one
            # appropriately, and then sum them together.
            split_shifting_linear_results = jnp.split(
                linear_results_before_shifting, self._kernel_size, axis=-2)

            for i in range(self._kernel_size):
                # Each tensor has to be shifted a different amount.
                if self._padding == 'WRAP':
                    # We can shift by padding and cutting. With 'wrap' padding we
                    # essentially have a torus.
                    padding = [(0, 0)
                               for _ in split_shifting_linear_results[i].shape]
                    padding[-3] = ((self._kernel_size - 1) - i, i)
                    split_shifting_linear_results[i] = jnp.pad(
                        split_shifting_linear_results[i], padding, mode='wrap')
                    split_shifting_linear_results[
                        i] = split_shifting_linear_results[i][
                            ..., (self._kernel_size - 1) //
                            2:-(self._kernel_size - 1) // 2, :, :]
                elif self._padding == 'SAME':
                    # We can shift by padding and cutting.
                    padding = [(0, 0)
                               for _ in split_shifting_linear_results[i].shape]
                    padding[-3] = ((self._kernel_size - 1) - i, i)
                    split_shifting_linear_results[i] = jnp.pad(
                        split_shifting_linear_results[i], padding)
                    split_shifting_linear_results[
                        i] = split_shifting_linear_results[i][
                            ..., (self._kernel_size - 1) //
                            2:-(self._kernel_size - 1) // 2, :, :]
                    # TODO(jaszczur): improve efficiency by not padding things to cut
                elif self._padding == 'VALID':
                    # We don't need to shift - just cut the leftmost and rightmost values.
                    cut_left = (self._kernel_size - 1) - i
                    cut_right = split_shifting_linear_results[i].shape[-3] - i
                    split_shifting_linear_results[
                        i] = split_shifting_linear_results[i][
                            ..., cut_left:cut_right, :, :]
                else:
                    raise ValueError(f'Invalid padding {self._padding}')
            # After shifting.
            shifted_linear_results = jnp.concatenate(
                split_shifting_linear_results, axis=-2)
            linear_result = jnp.sum(shifted_linear_results, axis=-2)

        if self._use_bias:
            return linear_result + b
        else:
            return linear_result
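
A shape-only sketch, with hypothetical sizes, of the '...lp,lkpd->...lkd' contraction at the top of this forward pass: every position gets kernel_size candidate outputs, which the shifting and summing logic above then aligns and adds up.

import jax.numpy as jnp

length, d_in, kernel_size, filters = 10, 8, 3, 16
x = jnp.ones((2, length, d_in))
w = jnp.ones((length, kernel_size, d_in, filters))
out = jnp.einsum('...lp,lkpd->...lkd', x, w)
assert out.shape == (2, length, kernel_size, filters)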
Example #12
    def forward(self, x):
        """Executes this layer as part of a forward pass through the model.

    Args:
      x: Tensor of same shape and dtype as the input signature used to
          initialize this layer.

    Returns:
      Tensor of same shape and dtype as the input.
    """
        m1, m2, mb, w1, w2, b2 = self.weights
        if self._mode != 'predict':
            w1 = jnp.reshape(w1.T, (-1, self._d_ff))
            w2 = jnp.reshape(w2, (self._d_ff, -1))
        x_shape = x.shape
        x = jnp.reshape(x,
                        [-1, x_shape[-1]])  # Easier to operate on flattened x.

        # Q: should we add bias and/or put relu after the low-rank m1 dot?
        mask_logits = jnp.dot(jnp.dot(x, m1), m2) + mb
        mask_logits = jnp.reshape(mask_logits, [-1, self._d1, self._d2])
        # Softmax.
        mask_logsumexp = fastmath.logsumexp(mask_logits,
                                            axis=-1,
                                            keepdims=True)
        log_mask = mask_logits - mask_logsumexp
        mask = jnp.exp(log_mask)
        # Gumbel-softmax with straight-through discretization.
        rng1, rng2 = fastmath.random.split(self.rng, 2)
        u = fastmath.random.uniform(rng1, mask.shape, jnp.float32, 1e-6,
                                    1.0 - 1e-6)
        g = -jnp.log(-jnp.log(u))
        quant_mask = jnp.argmax(log_mask + g * self._temperature, axis=-1)
        if self._mode == 'train':
            # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797
            quant_mask = tl.one_hot(quant_mask, self._n_elements_in_block)
            quant_mask = fastmath.stop_gradient(quant_mask)
            quant_mask += mask - fastmath.stop_gradient(
                mask)  # straight-through
            # We will sometimes (quant_prob of the batches) use the soft-mask instead
            # of the quantized mask to improve training stability (see paper above).
            select = fastmath.random.uniform(rng2, (), jnp.float32, 0.0, 1.0)
            quant_mask = jnp.where(select < self._quant_prob, quant_mask, mask)
            quant_mask = jnp.reshape(quant_mask, [-1, self._d_ff])

        if self._mode == 'train':
            # In training, run full matmul to get benefits from the above tricks.
            mid = jnp.dot(x, w1) * quant_mask  # [joint_batch, d_ff]
            relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
            res = jnp.dot(relu, w2) + b2
        elif self._mode == 'predict':
            # w1 = jnp.reshape(w1.T, (self._d1, self._d2, -1))
            # w2 = jnp.reshape(w2, (self._d1, self._d2, -1))
            # This implementation mimics inference. It's not efficient for a
            # large joint_batch, but at inference time that will usually be 1.
            # Shapes:
            # quant_mask is [joint_batch, self._d1]
            # w1 is [d_model, self._d1, self._d2]
            # We'll index w1 with advanced numpy indexing: the first index array
            # ranges over self._d1 (tiled batch_size times), the second is quant_mask.
            batch_size = quant_mask.shape[0]
            idx1 = jnp.array([jnp.arange(self._d1)] * batch_size)
            # flatten indices and select from w1
            idx1 = jnp.reshape(idx1, [-1])
            idx2 = jnp.reshape(quant_mask, [-1])
            w = w1[idx1,
                   idx2, :]  # now we have per-element weights with batch dim
            w = jnp.reshape(w, [batch_size, self._d1, -1])
            mid = jnp.einsum('ai,aji->aj', x, w)
            relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
            # w2 is [self._d1, self._d2, d_model]
            v = w2[idx1, idx2, :]
            v = jnp.reshape(v, [batch_size, self._d1, -1])
            res = jnp.einsum('ai,aij->aj', relu, v) + b2
        else:
            quant_mask = tl.one_hot(quant_mask, self._n_elements_in_block)
            quant_mask = jnp.reshape(quant_mask, [-1, self._d_ff])
            mid = jnp.dot(x, w1) * quant_mask  # [joint_batch, d_ff]
            relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
            res = jnp.dot(relu, w2) + b2

        return jnp.reshape(res, x_shape)  # un-flatten if needed
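
A standalone sketch of the straight-through trick referenced above (Section 2.1 of https://arxiv.org/abs/1801.09797), written directly against jax rather than trax.fastmath and tl.one_hot: the forward value is the hard one-hot mask, while gradients flow through the soft mask.

import jax
import jax.numpy as jnp

logits = jnp.array([[0.1, 2.0, -1.0]])
soft = jax.nn.softmax(logits, axis=-1)
hard = jax.nn.one_hot(jnp.argmax(logits, axis=-1), logits.shape[-1])
straight_through = jax.lax.stop_gradient(hard) + soft - jax.lax.stop_gradient(soft)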
Example #13
 def mul(vecs, pos_emb):
     return jnp.einsum('bld,ld->bld', vecs, pos_emb)
Example #14
 def body(p, qk):
     q, k = qk
     p += k
     x = jnp.einsum('...m,...m->...', q, p, precision=precision)
     return p, x
Example #15
    def test_lsh_and_pure_lsh_self_attention_equivalence(self):
        # Check that, given the same weight matrices and random numbers, these
        # two layers produce the same output.
        with fastmath.use_backend(fastmath.Backend.JAX):
            n_heads = 4
            d_head = 4
            d_model = n_heads * d_head
            pure_lsh_layer = efficient_attention.PureLSHSelfAttention(
                n_heads=n_heads,
                d_qk=d_head,
                d_v=d_head,
                causal=True,
                masked=False,
                chunk_len=8,
                n_chunks_before=1,
                n_chunks_after=0,
                n_hashes=4,
                n_buckets=8,
                use_reference_code=False,
                attention_dropout=0.0,
                use_python_loop=True,
                bias=False,
                mode='train')
            lsh_layer = efficient_attention.LSHSelfAttention(
                n_heads=n_heads,
                d_qk=d_head,
                d_v=d_head,
                causal=True,
                masked=False,
                chunk_len=8,
                n_chunks_before=1,
                n_chunks_after=0,
                n_hashes=4,
                n_buckets=8,
                use_reference_code=False,
                attention_dropout=0.0,
                use_python_loop=True,
                mode='train')

            batch, seqlen = 3, 32
            input_shape = (batch, seqlen, d_model)

            x = jax.random.uniform(jax.random.PRNGKey(0),
                                   input_shape,
                                   dtype=jnp.float32)
            lsh_layer_input = x

            call_rng = jax.random.PRNGKey(42)

            lsh_layer_weights, lsh_layer_state = lsh_layer.init(
                shapes.signature(lsh_layer_input))
            lsh_layer.rng = call_rng
            lsh_layer_output = lsh_layer(lsh_layer_input)

            # Shapes are: (n_heads, d_model, d_head), (n_heads, d_model, d_head),
            # (n_heads, d_head, d_model)
            # Abbreviated as - hmn, hmn, hnm
            w_qk, w_v, w_o = lsh_layer_weights

            qk = jnp.einsum('blm,hmn->bhln', x, w_qk)
            qk = qk.reshape((-1, qk.shape[2], qk.shape[3]))

            v = jnp.einsum('blm,hmn->bhln', x, w_v)
            v = v.reshape((-1, v.shape[2], v.shape[3]))

            pure_lsh_layer_input = (qk, v)
            _, _ = pure_lsh_layer.init(shapes.signature(pure_lsh_layer_input))
            pure_lsh_layer.rng = call_rng
            pure_lsh_layer.state = lsh_layer_state
            pure_lsh_layer_output = pure_lsh_layer(pure_lsh_layer_input)

            # b*h,l,n
            pure_lsh_layer_output = pure_lsh_layer_output.reshape(
                (batch, -1) + pure_lsh_layer_output.shape[1:])
            pure_lsh_layer_output_projected = (jnp.einsum(
                'bhld,hdm->blm', pure_lsh_layer_output, w_o))

            diff = pure_lsh_layer_output_projected - lsh_layer_output
            avg_diff = jnp.sum(jnp.abs(diff)) / jnp.sum(jnp.ones_like(diff))

            self.assertLess(avg_diff, 1e-5)
Example #16
 def body(p, qkv):
     (q, k, v) = qkv
     p += jnp.einsum('...m,...d->...md', k, v, precision=precision)
     x_slice = jnp.einsum('...m,...md->...d', q, p, precision=precision)
     return p, x_slice
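
A minimal sketch, with made-up shapes, of driving this body with jax.lax.scan to get Performer-style causal linear attention: the carry accumulates the outer products of k and v, and each step reads the accumulator out with the current query. `precision` is assumed to be defined in scope (the original closes over it).

import jax
import jax.numpy as jnp

precision = jax.lax.Precision.DEFAULT   # assumed

L, B, M, D = 6, 2, 4, 3                 # length, batch, feature dim, value dim
qs = jnp.ones((L, B, M))
ks = jnp.ones((L, B, M))
vs = jnp.ones((L, B, D))
p_final, xs = jax.lax.scan(body, jnp.zeros((B, M, D)), (qs, ks, vs))
assert xs.shape == (L, B, D)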
Example #17
 def bidirectional_denominator(query_prime, key_prime):
     all_ones = jnp.ones([query_prime.shape[0]])
     ks_sum = jnp.einsum('lbm,l->bm', key_prime, all_ones)
     return jnp.einsum('lbm,bm->lb', query_prime, ks_sum)
Example #18
 def bidirectional_numerator(query_prime, key_prime, value):
     kvs = jnp.einsum('lbm,lbd->bmd', key_prime, value)
     return jnp.einsum('lbm,bmd->lbd', query_prime, kvs)
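
A hedged sketch, with made-up shapes, of how the two helpers above combine in Performer-style (FAVOR) non-causal attention: the per-position ratio of numerator to denominator approximates the softmax-attention output for the feature-mapped queries and keys.

import jax.numpy as jnp

L, B, M, D = 6, 2, 4, 3
query_prime = jnp.ones((L, B, M))
key_prime = jnp.ones((L, B, M))
value = jnp.ones((L, B, D))
num = bidirectional_numerator(query_prime, key_prime, value)   # (L, B, D)
den = bidirectional_denominator(query_prime, key_prime)        # (L, B)
out = num / den[..., None]
assert out.shape == (L, B, D)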
Example #19
 def reverse(self, x, weights=(), state=(), new_state=(), rng=None):
     del state, new_state, rng
     shape = x.shape
     x = x.reshape(shape[:-1] + (self._get_multiplier(x), -1))
     t_x = jnp.einsum('...ab->...ba', x)  # transpose
     return t_x.reshape(shape)
Example #20
 def forward(self, x):
     shape = x.shape
     x = x.reshape(shape[:-1] + (-1, self._get_multiplier(x)))
     t_x = jnp.einsum('...ab->...ba', x)  # transpose
     return t_x.reshape(shape)
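
A self-contained sketch of the reshape-transpose-reshape round trip performed by forward and reverse above, with a hypothetical multiplier of 3: reverse undoes forward exactly.

import jax.numpy as jnp

x = jnp.arange(12).reshape(2, 6)
m = 3                                                                # multiplier
shuffled = jnp.einsum('...ab->...ba', x.reshape(2, -1, m)).reshape(2, 6)          # forward
restored = jnp.einsum('...ab->...ba', shuffled.reshape(2, m, -1)).reshape(2, 6)   # reverse
assert (restored == x).all()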