def threefry_2x32_prf(key, x: jnp.ndarray) -> jnp.ndarray:
    """Apply the threefry PRF to an array of inputs.

    This function is vectorized over x.
    For threefry_2x32: K = X = uint32[2]

    Args:
      key: uint32[2] the key of the PRF
      x: uint32[..., 2] the inputs

    Returns:
      y: uint32[..., 2] the outputs

    Raises:
      TypeError: if `key` is not uint32[2] or `x` is not uint32[..., 2].
    """
    if not (key.shape == (2,) and key.dtype == jnp.uint32):
        raise TypeError('key must be uint32[2]', key)
    if not (x.shape[-1:] == (2,) and x.dtype == jnp.uint32):
        raise TypeError('x must be uint32[..., 2]', x)

    # Newer JAX releases expose the raw hash under `jax.extend.random`; prefer
    # that location and fall back to the `jax.random` export used previously
    # so the function keeps working on either side of the move.
    try:
        from jax.extend.random import threefry_2x32 as raw_threefry
    except ImportError:
        raw_threefry = jax.random.threefry_2x32

    # Threefry-2x32 expects this weird format: a flat array whose first half
    # holds every first word and whose second half holds every second word.
    # Moving the pair axis to the front before flattening produces that layout.
    x_3f = jnp.moveaxis(x, source=-1, destination=0).flatten()
    y_3f = raw_threefry(key, x_3f)

    # Undo the layout change: recover the leading pair axis, then move it back
    # to the end so the result has the same shape as `x`.
    y = jnp.moveaxis(
        jnp.reshape(y_3f, (2,) + x.shape[:-1]), source=0, destination=-1)
    return y
def favor(query, key, value, mask):
    """Masked, renormalized attention built from the bidirectional helpers.

    `query` and `key` are passed through `relu` and shifted by
    `numerical_stabilizer`. The mask is reshaped to
    `(key.shape[0] // n_heads, 1, key.shape[1])`, broadcast over `n_heads`,
    and multiplied into the key features. The numerator and denominator
    helpers are called on inputs whose axes 0 and 1 are swapped, and their
    outputs are swapped back before the numerator is scaled by the
    reciprocal of the denominator.

    NOTE(review): presumably axis 0 of `key` is batch * heads and axis 1 is
    sequence length -- confirm against the caller.

    Args:
      query: query tensor.
      key: key tensor; its first two dimensions drive the mask reshape.
      value: value tensor.
      mask: mask reshapeable to `(key.shape[0] // n_heads, 1, key.shape[1])`.

    Returns:
      A `(renormalized_attention, mask)` tuple; `mask` is returned unchanged.
    """
    q_feat = relu(query) + numerical_stabilizer
    k_feat = relu(key) + numerical_stabilizer

    rows, length = key.shape[0], key.shape[1]

    # Give the mask a singleton heads axis, then broadcast it across heads by
    # adding zeros (kept as an addition so dtype promotion is unchanged).
    grouped_mask = jnp.reshape(mask, [rows // n_heads, 1, length])
    grouped_mask = grouped_mask.astype(jnp.float32)
    per_head_mask = grouped_mask + jnp.zeros((1, n_heads, 1))
    k_feat *= jnp.reshape(per_head_mask, [rows, length, 1])

    # The helpers take inputs with axes 0 and 1 exchanged.
    q_swapped = jnp.moveaxis(q_feat, 1, 0)
    k_swapped = jnp.moveaxis(k_feat, 1, 0)
    v_swapped = jnp.moveaxis(value, 1, 0)

    numerator = bidirectional_numerator(q_swapped, k_swapped, v_swapped)
    denominator = bidirectional_denominator(q_swapped, k_swapped)

    numerator = jnp.moveaxis(numerator, 0, 1)
    denominator = jnp.moveaxis(denominator, 0, 1)

    # Append a trailing axis so the scale broadcasts over the last axis of
    # the numerator.
    scale = jnp.expand_dims(jnp.reciprocal(denominator), axis=-1)
    return numerator * scale, mask
def favor(query, key, value):
    """Renormalized attention built from the prefix-sum helpers.

    `query` and `key` are passed through `relu` and shifted by
    `numerical_stabilizer`. Zero-filled initial prefix sums are created for
    the numerator, with shape
    `(key.shape[0], key.shape[-1], value.shape[-1])`, and for the
    denominator, with shape `(key.shape[0], key.shape[-1])`. The helpers are
    called on inputs whose axes 0 and 1 are swapped, and their outputs are
    swapped back before the numerator is scaled by the reciprocal of the
    denominator.

    NOTE(review): presumably axis 1 of the inputs is sequence length, which
    the helpers scan over -- confirm against `favor_numerator`.

    Args:
      query: query tensor.
      key: key tensor.
      value: value tensor.

    Returns:
      The renormalized attention tensor.
    """
    q_feat = relu(query) + numerical_stabilizer
    k_feat = relu(key) + numerical_stabilizer

    rows, key_dim = key.shape[0], key.shape[-1]
    numerator_init = jnp.zeros((rows, key_dim, value.shape[-1]))
    denominator_init = jnp.zeros((rows, key_dim))

    # The helpers take inputs with axes 0 and 1 exchanged.
    q_swapped = jnp.moveaxis(q_feat, 1, 0)
    k_swapped = jnp.moveaxis(k_feat, 1, 0)
    v_swapped = jnp.moveaxis(value, 1, 0)

    numerator = favor_numerator(
        numerator_init, precision, q_swapped, k_swapped, v_swapped)
    denominator = favor_denominator(
        denominator_init, precision, q_swapped, k_swapped)

    numerator = jnp.moveaxis(numerator, 0, 1)
    denominator = jnp.moveaxis(denominator, 0, 1)

    # Append a trailing axis so the scale broadcasts over the last axis of
    # the numerator.
    scale = jnp.expand_dims(jnp.reciprocal(denominator), axis=-1)
    return numerator * scale