Example #1
0
def threefry_2x32_prf(key, x: jnp.ndarray) -> jnp.ndarray:
  """Apply the threefry PRF to an array of inputs.

  This function is vectorized over x.
  For threefry_2x32: K = X = uint32[2]

  Args:
    key: uint32[2] the key of the PRF
    x: uint32[..., 2] the inputs

  Returns:
    y: uint32[..., 2] the outputs
  """
  if not (key.shape == (2,) and key.dtype == jnp.uint32):
    raise TypeError('key must be uint32[2]', key)
  if not (x.shape[-1:] == (2,) and x.dtype == jnp.uint32):
    raise TypeError('x must be uint32[..., 2]', x)
  # Threefry-2x32 expects this weird format:
  x_3f = jnp.moveaxis(x, source=-1, destination=0).flatten()
  y_3f = jax.random.threefry_2x32(key, x_3f)
  y = jnp.moveaxis(
      jnp.reshape(y_3f, (2,) + x.shape[:-1]), source=0, destination=-1)
  return y
Example #2
0
  def favor(query, key, value, mask):
    """FAVOR attention on ReLU features, with the mask applied to the keys.

    Queries and keys are mapped to ``relu(.) + numerical_stabilizer``.
    ``mask`` is broadcast across ``n_heads`` heads and multiplied into the
    key features; with a 0/1 mask, masked-out keys get all-zero features.
    The result is the numerator from ``bidirectional_numerator`` divided
    element-wise by the denominator from ``bidirectional_denominator``,
    which is broadcast over the last axis.

    Args:
      query: query tensor; presumably (batch * n_heads, length, dim), the
        same layout as ``key`` — TODO confirm.
      key: key tensor. Axis 0 is divided by ``n_heads`` to get the batch
        size, and axis 1 is used as the mask length.
      value: value tensor; presumably shares axes 0 and 1 with ``key``.
      mask: array with ``key.shape[0] // n_heads * key.shape[1]`` elements.
        It is reshaped to (key.shape[0] // n_heads, 1, key.shape[1]).

    Returns:
      Tuple ``(attention, mask)``; ``mask`` is returned unchanged.
    """
    q_feat = relu(query) + numerical_stabilizer
    k_feat = relu(key) + numerical_stabilizer

    rows, length = key.shape[0], key.shape[1]
    # (B, 1, L) + (1, H, 1) -> (B, H, L), then flatten heads into the rows.
    per_example_mask = jnp.reshape(
        mask, [rows // n_heads, 1, length]).astype(jnp.float32)
    per_row_mask = jnp.reshape(
        per_example_mask + jnp.zeros((1, n_heads, 1)), [rows, length, 1])
    k_feat *= per_row_mask

    # The helpers take inputs with axis 1 moved to the front (presumably the
    # sequence axis); their outputs are moved back afterwards.
    q_swapped = jnp.moveaxis(q_feat, 1, 0)
    k_swapped = jnp.moveaxis(k_feat, 1, 0)
    numerator = bidirectional_numerator(q_swapped, k_swapped,
                                        jnp.moveaxis(value, 1, 0))
    denominator = bidirectional_denominator(q_swapped, k_swapped)

    numerator = jnp.moveaxis(numerator, 0, 1)
    inv_denominator = jnp.reciprocal(jnp.moveaxis(denominator, 0, 1))
    # Add a trailing unit axis so the normaliser broadcasts over features.
    inv_denominator = jnp.expand_dims(inv_denominator, inv_denominator.ndim)
    return numerator * inv_denominator, mask
Example #3
0
    def favor(query, key, value):
        """FAVOR attention on ReLU features, using prefix-sum style helpers.

        Queries and keys are mapped to ``relu(.) + numerical_stabilizer``.
        ``favor_numerator`` and ``favor_denominator`` are seeded with
        zero-filled initial states and run at ``precision``. The result is
        the numerator divided element-wise by the denominator, which is
        broadcast over the last axis.

        Args:
          query: query tensor; presumably (rows, length, feature_dim) like
            ``key`` — TODO confirm.
          key: key tensor; ``key.shape[0]`` and ``key.shape[-1]`` size the
            initial states.
          value: value tensor; ``value.shape[-1]`` sizes the numerator state.

        Returns:
          The renormalised attention output.
        """
        q_feat = relu(query) + numerical_stabilizer
        k_feat = relu(key) + numerical_stabilizer

        rows, feat_dim, value_dim = key.shape[0], key.shape[-1], value.shape[-1]
        # Zero initial states for the helpers; the original names these
        # "init_prefix_sum_value_*", i.e. presumably running-sum seeds.
        numerator_seed = jnp.zeros((rows, feat_dim, value_dim))
        denominator_seed = jnp.zeros((rows, feat_dim))

        # The helpers take inputs with axis 1 moved to the front (presumably
        # the sequence axis); their outputs are moved back afterwards.
        q_swapped, k_swapped, v_swapped = (
            jnp.moveaxis(t, 1, 0) for t in (q_feat, k_feat, value))
        numerator = favor_numerator(numerator_seed, precision,
                                    q_swapped, k_swapped, v_swapped)
        denominator = favor_denominator(denominator_seed, precision,
                                        q_swapped, k_swapped)

        numerator = jnp.moveaxis(numerator, 0, 1)
        inv_denominator = jnp.reciprocal(jnp.moveaxis(denominator, 0, 1))
        # Trailing unit axis so the normaliser broadcasts over value features.
        return numerator * inv_denominator[..., None]