Example #1
  def PositionsVectors(queries, keys):
    is_funnel_layer = queries.shape != keys.shape
    keys_len, queries_len = keys.shape[1], queries.shape[1]
    current_pooling_ratio = keys_len / queries_len

    # Special case of upsampling
    if is_funnel_layer and current_pooling_ratio < 1:
      # We should not be doing standard upsampling when we use separate_cls
      # Cls token is being used for classification
      assert not separate_cls
      assert (total_kv_pooling * keys_len) % queries_len == 0
      multiplier = ((total_kv_pooling * keys_len) // queries_len)
      positions = jnp.arange(-queries_len + 1, queries_len, 1.0) * multiplier
    else:
      positions = jnp.arange(-keys_len + 1, keys_len, 1.0) * total_kv_pooling

    if is_funnel_layer and separate_cls:
      # For pool_size 2 without separating cls we have got
      # [0][1][2][3][4][5][6][7] -> [01][23][45][67]
      # With separating cls we have got
      # [0][1][2][3][4][5][6][7] -> [0][12][34][56]

      # The first group will always consist of one token after pooling,
      # instead of (pool_size) tokens. We need to add a proper offset so
      # that the shift later on, when calculating attention, works properly.
      cls_offset = (current_pooling_ratio - 1) * total_kv_pooling
      positions = positions + cls_offset

    return positions
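
For intuition, here is a minimal sketch (not from the original file) of what the non-funnel branch computes, with total_kv_pooling = 2 and keys_len = 4 chosen purely for illustration:

import jax.numpy as jnp

total_kv_pooling = 2  # assumed value, for illustration only
keys_len = 4
positions = jnp.arange(-keys_len + 1, keys_len, 1.0) * total_kv_pooling
print(positions)  # [-6. -4. -2.  0.  2.  4.  6.]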
Example #2
 def _sincos(self, start, length, d_feature):
     """Create the sin-cos tensor of shape [1, length, d_feature]."""
     position = jnp.arange(0, length)[:, None] + start
     div_term = jnp.exp(
         jnp.arange(0, d_feature, 2) * -(jnp.log(10000.0) / d_feature))
     sin = jnp.sin(position * div_term)
     cos = jnp.cos(position * div_term)
     pe = jnp.concatenate([sin, cos], axis=1)
     return pe[None, :, :]  # [1, length, d_feature]
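
A self-contained sketch of the same computation, useful for checking shapes; the start, length, and d_feature values here are arbitrary:

import jax.numpy as jnp

start, length, d_feature = 0, 10, 8
position = jnp.arange(0, length)[:, None] + start
div_term = jnp.exp(jnp.arange(0, d_feature, 2) * -(jnp.log(10000.0) / d_feature))
pe = jnp.concatenate([jnp.sin(position * div_term),
                      jnp.cos(position * div_term)], axis=1)
print(pe[None, :, :].shape)  # (1, 10, 8)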
Example #3
  def PositionsVectors(self, n_tokens):
    if self._mode == 'predict':
      current_token, sequence_length = calc_predict_next_token_index(
          self.state, self._total_kv_pooling, self._max_len, self._chunk_len,
          self._chunk_offset)
      positions = jnp.arange(0, sequence_length, 1.0) - current_token
      self.state = self.state + self._n_raw_tokens_generated
      return positions

    sequence_length = self._chunk_len if self._chunk_len is not None else n_tokens
    offset = sequence_length - 1  # offset to be compatible with predict mode
    positions = jnp.arange(sequence_length) - offset

    return positions
Example #4
def one_hot(x, n_categories, dtype=jnp.float32):  # pylint: disable=invalid-name
    """Makes a one-hot array (n+1 dims) from an int-categorical array (n dims)."""
    indices_less_than_n = jnp.arange(n_categories)
    if fastmath.is_backend(fastmath.Backend.JAX):
        # Work around a jax broadcasting issue.
        indices_less_than_n = jax.lax.tie_in(x, indices_less_than_n)
    return jnp.array(x[..., jnp.newaxis] == indices_less_than_n, dtype)
Example #5
 def forward(self, x):
   rng = self.rng
   batch_size, length = x.shape[0], x.shape[1]
   max_pos = min(self._bases)**self._n_digits
   rng1, rng2, rng3 = fastmath.random.split(rng, 3)
   assert length < max_pos, 'length (%d) >= max_pos (%d)' % (length, max_pos)
   positions = jnp.arange(0, length)[None, :]
   if self._mode == 'train':
     # In 1% of training cases still start from 0 to be exactly as in eval.
     start_from_nonzero = fastmath.random.randint(
         rng1, (batch_size,), 0, self._start_from_zero_one_in)
     start_from_nonzero = jnp.minimum(1, start_from_nonzero)
     random_start = fastmath.random.randint(
         rng2, (batch_size,), 0, max_pos-length)
     random_start *= start_from_nonzero
     positions += random_start[:, None]
   res = []
   for bn, base in enumerate(self._bases):
     pos_embeddings = []
     cur_positions = positions
     for i in range(self._n_digits):
       cur_indices = jnp.mod(cur_positions, base)
       cur_positions = cur_positions // base
       s = self.weights[bn][i]
       pos_embeddings.append(cur_indices.astype(jnp.float32)[:, :, None] * s)
     embeddings = jnp.concatenate(pos_embeddings, axis=-1)
     if self._mode == 'train':
       base_dropout = fastmath.random.randint(
           rng3, (batch_size,), 0, self._base_dropout_one_in)
       base_dropout = jnp.minimum(1, base_dropout).astype(jnp.float32)
       embeddings *= base_dropout[:, None, None]
     res.append(embeddings)
   res = sum(res) + jnp.zeros_like(x)
   return x + res
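
The inner loop above is plain base conversion: jnp.mod extracts the current digit and the floor division shifts to the next one. A minimal sketch with base 10 (values for illustration only):

import jax.numpy as jnp

positions = jnp.array([905])
base, n_digits = 10, 3
for _ in range(n_digits):
  print(int(jnp.mod(positions, base)[0]))  # prints 5, then 0, then 9
  positions = positions // base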
Example #6
 def Sinusoidal_Embeddings(positions):
     inv_freq = 1 / (10000**(jnp.arange(0.0, d_feature, 2.0) / d_feature))
     sinusoid_freq = jnp.einsum('i,j->ij', positions, inv_freq)
     pos_emb = jnp.concatenate(
         [jnp.sin(sinusoid_freq),
          jnp.cos(sinusoid_freq)], axis=1)
     return pos_emb
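
Here d_feature is a closed-over variable. A self-contained sketch with an explicit d_feature, showing the resulting shape for a vector of relative positions:

import jax.numpy as jnp

d_feature = 8
positions = jnp.arange(-3.0, 4.0)  # 7 relative positions
inv_freq = 1 / (10000**(jnp.arange(0.0, d_feature, 2.0) / d_feature))
sinusoid_freq = jnp.einsum('i,j->ij', positions, inv_freq)
pos_emb = jnp.concatenate([jnp.sin(sinusoid_freq), jnp.cos(sinusoid_freq)], axis=1)
print(pos_emb.shape)  # (7, 8)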
Example #7
    def forward(self, inputs):
        state = self.state
        depth = inputs.shape[-1]

        if self._mode == 'predict':
            emb = self._get_embeddings(t=state)
            emb = emb[:, jnp.newaxis, :]
            state = state + 1
        else:
            input_len = inputs.shape[-2]
            emb = self._get_embeddings(
                t=jnp.arange(input_len, dtype=jnp.int32))
            # Leave batch axis as 1 for broadcasting:
            emb = emb[jnp.newaxis, :, :]
            emb = jnp.broadcast_to(emb, inputs.shape[:-1] + (3, ))

        # Replace the last num_features channels of input.
        inputs = jnp.concatenate([inputs[..., :-self.num_features], emb], -1)
        if inputs.shape[-1] > depth:
            logging.warning('dropping feature(s): %d down to %d',
                            inputs.shape[-1], depth)
            inputs = inputs[..., -depth:]

        assert inputs.shape[-1] == depth, inputs.shape
        self.state = state
        return inputs
Example #8
    def PositionsVectors(queries, keys):
        assert not separate_cls

        keys_len, queries_len = keys.shape[-2], queries.shape[-2]
        funnel_factor, is_upsampling = calc_funnel_ratio(keys_len, queries_len)

        if funnel_factor == 1:
            offset = keys_len - 1
            positions = (jnp.arange(keys_len) - offset) * total_kv_pooling
        else:
            if is_upsampling:
                positions = jnp.arange(-queries_len + 1, queries_len, 1.0)
            else:
                positions = jnp.arange(-keys_len + 1, keys_len,
                                       1.0) * total_kv_pooling

        return positions
Example #9
def rotate(x):
    """Rotate function."""
    _, l, d = x.shape
    inv_freq = jnp.exp(jnp.arange(0, d, 2) * -(jnp.log(10000.0) / d))
    positions = jnp.arange(l)
    freqs = jnp.einsum('i,j->ij', positions, inv_freq)
    emb = jnp.concatenate((freqs, freqs), axis=-1)
    cos = jnp.cos(emb)
    sin = jnp.sin(emb)

    def mul(vecs, pos_emb):
        return jnp.einsum('bld,ld->bld', vecs, pos_emb)

    def rotate_half(x):
        x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
        return jnp.concatenate((-x2, x1), axis=x1.ndim - 1)

    return mul(x, cos) + mul(rotate_half(x), sin)
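
Since rotate applies a pairwise rotation, it preserves the norm of each vector; a quick check of that property (d must be even for the split in rotate_half):

import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(0), (2, 5, 8))  # [batch, length, d]
y = rotate(x)
print(jnp.allclose(jnp.linalg.norm(x, axis=-1),
                   jnp.linalg.norm(y, axis=-1), atol=1e-5))  # True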
Example #10
    def _funnel_mask(self, batch_size, keys_len, queries_len, funnel_factor,
                     is_upsampling):
        """Creates a funnel mask.

    Based on the keys/queries lengths, this function creates a triangular
    mask that prevents tokens from attending to positions that follow them.

    If funnel_factor is not equal to 1 due to funnel upsampling or
    downsampling, it adjusts the created mask for funnel attention
    by repeating each element funnel_factor times.

    This is because after a funnel layer, one token attends to funnel_factor
    different tokens during downsampling. During upsampling, on the other
    hand, funnel_factor tokens attend to a single token from before the
    upsampling.

    Args:
      batch_size: batch size.
      keys_len: keys length.
      queries_len: queries length.
      funnel_factor: funnel factor.
      is_upsampling: upsampling if set to True.

    Returns:
      Funnel mask.
    """

        if self._mode == 'predict':
            # We cannot generate more than one token at a time, as that
            # would violate the autoregressive property.
            assert queries_len == 1
            mask = jnp.arange(
                self._max_len) <= (self.state // self._total_kv_pooling)
            mask = jnp.reshape(mask, (1, 1, 1, self._max_len))
            mask = jnp.repeat(mask, batch_size, axis=0)
            self.state += self._n_raw_tokens_generated
            return mask

        if funnel_factor != 1:
            if not is_upsampling:
                mask = jnp.tril(
                    jnp.ones((queries_len, queries_len), dtype=jnp.bool_))
                mask = jnp.repeat(mask, funnel_factor, axis=-1)
            else:
                mask = jnp.tril(jnp.ones((keys_len, keys_len),
                                         dtype=jnp.bool_))
                mask = jnp.repeat(mask, funnel_factor, axis=-2)
        else:
            mask = jnp.tril(
                jnp.ones((queries_len, queries_len), dtype=jnp.bool_))

        return jnp.repeat(mask[None, None, :, :], batch_size, axis=0)
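
The repeat trick in the non-predict branch is easiest to see on a tiny example: with funnel_factor = 2 and downsampling, each pooled query row covers two raw key positions:

import jax.numpy as jnp

queries_len, funnel_factor = 3, 2
mask = jnp.tril(jnp.ones((queries_len, queries_len), dtype=jnp.bool_))
mask = jnp.repeat(mask, funnel_factor, axis=-1)
print(mask.astype(jnp.int32))
# [[1 1 0 0 0 0]
#  [1 1 1 1 0 0]
#  [1 1 1 1 1 1]]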
Example #11
def _fast_inference_update_state(inputs, state):
  """Updates state of a causal attention layer for fast inference.

  The layer state stores arrays with cached values of keys and values,
  as well as an index. To keep shapes static, the keys and values in the
  state are preallocated to full length, and the index indicates where new
  keys and values from the inputs need to be appended.

  During the update, we append new_keys and new_values to keys and values
  at the position given by index, and we increment index by the length of
  the new keys. We also create a mask that is 1 at the appropriate (causal)
  positions.

  Args:
    inputs: a triple (new_queries, new_keys, new_values)
    state: layer state with (keys, values, index)

  Returns:
    Updated state and mask to be used.
  """
  # Fast inference: run step-by-step, storing the sequence
  # of keys and values calculated so far in state.
  (_, new_k, new_v) = inputs
  length = new_k.shape[1]
  (ks, vs, idx) = state
  # TODO(lukaszkaiser): benchmark speed and decide if using a separate code path
  # with index_update when length == 1 is worth it.
  # Keys and values are of shape [batch_size, length, d_kv].
  ks = fastmath.dynamic_update_slice_in_dim(ks, new_k, idx, axis=1)
  vs = fastmath.dynamic_update_slice_in_dim(vs, new_v, idx, axis=1)
  k_length = ks.shape[1]

  # Mask is of shape [1, q_length, k_length].
  # Mask should be true for every pair of (query_token, key_token) such that
  # the index of query_token is equal to or larger than the index of key_token.
  mask = (jnp.reshape(jnp.arange(k_length), (1, 1, k_length))
          <= jnp.reshape(jnp.arange(length) + idx, (1, length, 1)))

  return (ks, vs, idx + length), mask
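
A minimal sketch of the cache update and mask above, assuming a cache of length 4 with idx = 2 tokens already written and one new token arriving (on the JAX backend, fastmath delegates to jax.lax here):

import jax.numpy as jnp
from jax import lax

ks = jnp.zeros((1, 4, 2))    # [batch_size, max_length, d_kv]
new_k = jnp.ones((1, 1, 2))  # one new key
idx, length = 2, 1
ks = lax.dynamic_update_slice_in_dim(ks, new_k, idx, axis=1)
mask = (jnp.reshape(jnp.arange(ks.shape[1]), (1, 1, ks.shape[1]))
        <= jnp.reshape(jnp.arange(length) + idx, (1, length, 1)))
print(mask.astype(jnp.int32))  # [[[1 1 1 0]]]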
Example #12
 def forward(self, x):
     rng = self.rng
     base_weights, start_vec = self.weights
     batch_size, length = x.shape[0], x.shape[1]
     max_pos = min(self._bases)**self._n_digits
     rng1, rng2, rng3 = fastmath.random.split(rng, 3)
     assert length < max_pos, 'length (%d) >= max_pos (%d)' % (length,
                                                               max_pos)
     positions = jnp.arange(0, length)[None, :]
     # In training we'll randomize starts for better generalization.
     # We use the trainable start_vec to compensate, giving the model a way
     # to learn the starting position in a sequence.
     if self._mode == 'train':
         # In 1% of training cases still start from 0 to be exactly as in eval.
         start_from_nonzero = fastmath.random.randint(
             rng1, (batch_size, ), 0, self._start_from_zero_one_in)
         start_from_nonzero = jnp.minimum(1, start_from_nonzero)
         random_start = fastmath.random.randint(rng2, (batch_size, ), 0,
                                                max_pos - length)
         random_start *= start_from_nonzero
         positions += random_start[:, None]
     if self._mode == 'predict':
         positions += self.state
     res = []
     for bn, base in enumerate(self._bases):
         pos_embeddings = []
         cur_positions = positions
         for i in range(self._n_digits):
             cur_indices = jnp.mod(cur_positions, base)
             cur_positions = cur_positions // base
             s = base_weights[bn][i]
             pos_embeddings.append(
                 cur_indices.astype(jnp.float32)[:, :, None] * s)
         embeddings = jnp.concatenate(pos_embeddings, axis=-1)
         if self._mode == 'train':
             base_dropout = fastmath.random.randint(
                 rng3, (batch_size, ), 0, self._base_dropout_one_in)
             base_dropout = jnp.minimum(1, base_dropout).astype(jnp.float32)
             embeddings *= base_dropout[:, None, None]
         res.append(embeddings)
     res = sum(res)  # Sum embeddings from all bases.
     # Add start_vec to the first position only to mark it as starting.
     res0 = res[:, 0, :][:, None, :]
     start_pos = res0 + start_vec
     if self._mode == 'predict':
         start_pos = jnp.where(jnp.equal(self.state, 0), start_pos, res0)
         self.state += length  # Add input length to state.
     res = jnp.concatenate([start_pos, res[:, 1:, :]], axis=1)
     return x + res
Example #13
def Sinusoidal_Embeddings(positions, d_feature):
  """Sinusoidal Embeddings.

  Computes, from a 1-D vector of integer absolute positions, the sinusoidal
  embeddings defined as in the paper Attention Is All You Need (2017).
  Embeddings have shape (len(positions), d_feature).

  Args:
    positions: a one-dimensional array of positions.
    d_feature: the number of sin-cos features.

  Returns:
    Positional embeddings.
  """
  inv_freq = 1 / (10000**(jnp.arange(0.0, d_feature, 2.0) / d_feature))
  sinusoid_freq = jnp.einsum('i,j->ij', positions, inv_freq)
  pos_emb = jnp.concatenate(
      [jnp.sin(sinusoid_freq), jnp.cos(sinusoid_freq)], axis=1)
  return pos_emb
Example #14
def _fast_inference_update_state(inputs, state):
  """Updates state of a causal attention layer for fast inference."""
  if fastmath.backend_name() != 'jax':
    raise ValueError(f'JAX backend is required in predict mode, but found '
                     f'backend ({fastmath.backend_name()}).')
  for x in inputs:
    if x.shape[1] != 1:
      raise ValueError(f'In predict mode, input sequence must have length 1, '
                       f'instead has length {x.shape[1]}.')
  # Fast inference: run with only 1 query in each step, storing the sequence
  # of keys and values calculated so far in state.
  (_, new_k, new_v) = inputs
  (ks, vs, mask, seq_indices) = state
  batch_indices = jnp.arange(ks.shape[0])
  ks = jax.ops.index_update(
      ks, jax.ops.index[batch_indices, seq_indices, :], new_k[:, 0, :])
  vs = jax.ops.index_update(
      vs, jax.ops.index[batch_indices, seq_indices, :], new_v[:, 0, :])
  mask = jax.ops.index_update(
      mask, jax.ops.index[batch_indices, :, seq_indices], 1)
  return (ks, vs, mask, seq_indices + 1)
Example #15
  def forward(self, inputs):
    inputs_len = inputs.shape[1]

    if self._mode == 'predict':
      # We cannot generate more than one token at a time, as that would
      # violate the autoregressive property.
      assert inputs_len == 1

      current_token, sequence_length = calc_predict_next_token_index(
          self.state, self._total_kv_pooling, self._max_len, self._chunk_len,
          self._chunk_offset)

      mask = jnp.arange(sequence_length) <= current_token
      mask = jnp.reshape(mask, (1, sequence_length))
      self.state += self._n_raw_tokens_generated
      return mask

    if self._chunk_len is not None:
      return jnp.tril(
          jnp.ones((self._chunk_len, self._chunk_len), dtype=jnp.bool_))

    return jnp.tril(jnp.ones((inputs_len, inputs_len), dtype=jnp.bool_))
Example #16
def threefry_2x32_prange(key, lo: int = 0, hi: int = 2):
    """Splits a key into a stream of random keys.

  This uses the little-endian counter mode.

  Args:
    key: uint32[2] the key to split
    lo: the range to start extracting from
    hi: the range to stop extracting from

  Returns:
    keys: uint32[hi - lo, 2] the split keys
  """
    if not (key.shape == (2, ) and key.dtype == jnp.uint32):
        raise ValueError('key must be uint32[2]')
    if not hi < 2**32:
        # You shouldn't really be using more than half the key size anyway.
        raise NotImplementedError('only 32-bit sizes are supported')
    # Create a 64-bit counter:
    i_lo = jnp.arange(lo, hi, dtype=jnp.uint32)
    i_hi = jnp.zeros_like(i_lo)
    i = jnp.stack([i_lo, i_hi], axis=-1)
    return threefry_2x32_prf(key, i)
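
The counter construction just pairs each low 32-bit word with a zero high word; a sketch for lo = 0, hi = 3:

import jax.numpy as jnp

i_lo = jnp.arange(0, 3, dtype=jnp.uint32)
i_hi = jnp.zeros_like(i_lo)
print(jnp.stack([i_lo, i_hi], axis=-1))
# [[0 0]
#  [1 0]
#  [2 0]]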
Example #17
def one_hot(x, n_categories, dtype=jnp.float32):  # pylint: disable=invalid-name
    """Makes a one-hot array (n+1 dims) from an int-categorical array (n dims)."""
    indices_less_than_n = jnp.arange(n_categories)
    return jnp.array(x[..., jnp.newaxis] == indices_less_than_n, dtype)
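
A quick usage sketch of this simpler variant (Example #4 above adds a jax.lax.tie_in workaround that only mattered on older JAX versions):

import jax.numpy as jnp

x = jnp.array([0, 2, 1])
print(one_hot(x, 3))
# [[1. 0. 0.]
#  [0. 0. 1.]
#  [0. 1. 0.]]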
Example #18
def top_k(x, k):
  """Select the top k slices from the last dimension."""
  bcast_idxs = jnp.broadcast_to(jnp.arange(x.shape[-1]), x.shape)
  sorted_vals, sorted_idxs = lax.sort_key_val(x, bcast_idxs)
  # TODO(levskaya): use lax.slice here instead to benefit from XLA optimization
  return sorted_vals[..., -k:], sorted_idxs[..., -k:]
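
A quick usage sketch; since lax.sort_key_val sorts in ascending order, the returned top-k slices end with the largest value:

import jax.numpy as jnp

vals, idxs = top_k(jnp.array([[3., 1., 4., 1., 5.]]), k=2)
print(vals)  # [[4. 5.]]
print(idxs)  # [[2 4]]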
Example #19
    def forward(self, x):
        """Executes this layer as part of a forward pass through the model.

    Args:
      x: Tensor of same shape and dtype as the input signature used to
          initialize this layer.

    Returns:
      Tensor of same shape and dtype as the input.
    """
        m1, m2, mb, w1, w2, b2 = self.weights
        if self._mode != 'predict':
            w1 = jnp.reshape(w1.T, (-1, self._d_ff))
            w2 = jnp.reshape(w2, (self._d_ff, -1))
        x_shape = x.shape
        x = jnp.reshape(x,
                        [-1, x_shape[-1]])  # Easier to operate on flattened x.

        # Q: should we add bias and/or put relu after the low-rank m1 dot?
        mask_logits = jnp.dot(jnp.dot(x, m1), m2) + mb
        mask_logits = jnp.reshape(mask_logits, [-1, self._d1, self._d2])
        # Softmax.
        mask_logsumexp = fastmath.logsumexp(mask_logits,
                                            axis=-1,
                                            keepdims=True)
        log_mask = mask_logits - mask_logsumexp
        mask = jnp.exp(log_mask)
        # Gumbel-softmax with straight-through discretization.
        rng1, rng2 = fastmath.random.split(self.rng, 2)
        u = fastmath.random.uniform(rng1, mask.shape, jnp.float32, 1e-6,
                                    1.0 - 1e-6)
        g = -jnp.log(-jnp.log(u))
        quant_mask = jnp.argmax(log_mask + g * self._temperature, axis=-1)
        if self._mode == 'train':
            # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797
            quant_mask = tl.one_hot(quant_mask, self._n_elements_in_block)
            quant_mask = fastmath.stop_gradient(quant_mask)
            quant_mask += mask - fastmath.stop_gradient(
                mask)  # straight-through
            # We will sometimes (quant_prob of the batches) use the soft-mask instead
            # of the quantized mask to improve training stability (see paper above).
            select = fastmath.random.uniform(rng2, (), jnp.float32, 0.0, 1.0)
            quant_mask = jnp.where(select < self._quant_prob, quant_mask, mask)
            quant_mask = jnp.reshape(quant_mask, [-1, self._d_ff])

        if self._mode == 'train':
            # In training, run full matmul to get benefits from the above tricks.
            mid = jnp.dot(x, w1) * quant_mask  # [joint_batch, d_ff]
            relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
            res = jnp.dot(relu, w2) + b2
        elif self._mode == 'predict':
            # w1 = jnp.reshape(w1.T, (self._d1, self._d2, -1))
            # w2 = jnp.reshape(w2, (self._d1, self._d2, -1))
            # This implementation mimics inference. It's not efficient for a
            # large joint_batch size, but at inference that will be 1 most of
            # the time.
            # Shapes:
            # quant_mask is [joint_batch, self._d1]
            # w1 is [d_model, self._d1, self._d2]
            # We'll index w1 with advanced numpy indexing: the first index
            # ranges over self._d1, tiled batch_size times; the second index
            # is quant_mask.
            batch_size = quant_mask.shape[0]
            idx1 = jnp.array([jnp.arange(self._d1)] * batch_size)
            # flatten indices and select from w1
            idx1 = jnp.reshape(idx1, [-1])
            idx2 = jnp.reshape(quant_mask, [-1])
            w = w1[idx1,
                   idx2, :]  # now we have per-element weights with batch dim
            w = jnp.reshape(w, [batch_size, self._d1, -1])
            mid = jnp.einsum('ai,aji->aj', x, w)
            relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
            # w2 is [self._d1, self._d2, d_model]
            v = w2[idx1, idx2, :]
            v = jnp.reshape(v, [batch_size, self._d1, -1])
            res = jnp.einsum('ai,aij->aj', relu, v) + b2
        else:
            quant_mask = tl.one_hot(quant_mask, self._n_elements_in_block)
            quant_mask = jnp.reshape(quant_mask, [-1, self._d_ff])
            mid = jnp.dot(x, w1) * quant_mask  # [joint_batch, d_ff]
            relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
            res = jnp.dot(relu, w2) + b2

        return jnp.reshape(res, x_shape)  # un-flatten if needed
Example #20
File: rse.py Project: yliu45/trax
 def bit_sequence(inputs):
   seq_length = inputs.shape[1]
   n_bits = np.int32(np.log(seq_length - 1) / np.log(2.0)) + 1
   return jnp.arange(0, n_bits)
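
For instance, with a sequence length of 8 the formula gives n_bits = int(log2(7)) + 1 = 3, so the function returns the bit indices [0 1 2]:

import numpy as np
import jax.numpy as jnp

seq_length = 8
n_bits = np.int32(np.log(seq_length - 1) / np.log(2.0)) + 1
print(n_bits)                 # 3
print(jnp.arange(0, n_bits))  # [0 1 2]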