def favor(query, key, value, mask):
  """Bidirectional FAVOR attention with a padding mask.

  Computes linear-attention output as numerator / denominator, where both
  are produced by the externally defined `bidirectional_numerator` /
  `bidirectional_denominator` helpers operating on length-major tensors.

  Args:
    query: query activations; assumes shape (batch * n_heads, length, dim)
      — inferred from the reshapes below, confirm against callers.
    key: key activations, same leading shape as `query`.
    value: value activations.
    mask: padding mask, reshaped here from (batch, length); nonzero entries
      keep the corresponding key positions.

  Returns:
    A pair (renormalized_attention, mask).
  """
  # ReLU feature map with a small additive stabilizer to avoid zero rows.
  q_prime = relu(query) + numerical_stabilizer
  k_prime = relu(key) + numerical_stabilizer

  # Broadcast the per-example mask across heads, then flatten back to the
  # (batch * n_heads, length, 1) layout so it can scale the key features.
  batch = key.shape[0] // n_heads
  length = key.shape[1]
  mask_b1l = jnp.reshape(mask, [batch, 1, length]).astype(jnp.float32)
  mask_bhl = mask_b1l + jnp.zeros((1, n_heads, 1))
  k_prime = k_prime * jnp.reshape(mask_bhl, [key.shape[0], length, 1])

  # The helpers expect length-major inputs: move the length axis to front.
  q_t = jnp.moveaxis(q_prime, 1, 0)
  k_t = jnp.moveaxis(k_prime, 1, 0)
  v_t = jnp.moveaxis(value, 1, 0)
  numerator = bidirectional_numerator(q_t, k_t, v_t)
  denominator = bidirectional_denominator(q_t, k_t)

  # Back to batch-major, then renormalize: out = numerator / denominator,
  # with the denominator broadcast over the feature axis.
  numerator = jnp.moveaxis(numerator, 0, 1)
  denominator = jnp.moveaxis(denominator, 0, 1)
  inv_denominator = jnp.expand_dims(jnp.reciprocal(denominator), -1)
  renormalized_attention = numerator * inv_denominator
  return renormalized_attention, mask
def favor(query, key, value):
  """Causal FAVOR attention via prefix-sum scans.

  Builds zero-initialized prefix-sum accumulators and delegates the causal
  numerator/denominator computation to the externally defined
  `favor_numerator` / `favor_denominator` helpers (which consume
  length-major tensors and the enclosing-scope `precision`).

  Args:
    query: query activations; assumes shape (batch, length, dim) — inferred
      from the axis moves below, confirm against callers.
    key: key activations, same leading shape as `query`.
    value: value activations.

  Returns:
    The renormalized attention output, numerator / denominator.
  """
  # ReLU feature map with a small additive stabilizer to avoid zero rows.
  q_prime = relu(query) + numerical_stabilizer
  k_prime = relu(key) + numerical_stabilizer

  # Zero starting states for the running numerator (outer-product sums)
  # and denominator (key-feature sums) carried through the scan.
  numerator_shape = (key.shape[0], key.shape[-1], value.shape[-1])
  denominator_shape = (key.shape[0], key.shape[-1])
  init_numerator = jnp.zeros(numerator_shape)
  init_denominator = jnp.zeros(denominator_shape)

  # The scan helpers expect length-major inputs.
  q_t = jnp.moveaxis(q_prime, 1, 0)
  k_t = jnp.moveaxis(k_prime, 1, 0)
  v_t = jnp.moveaxis(value, 1, 0)
  numerator = favor_numerator(init_numerator, precision, q_t, k_t, v_t)
  denominator = favor_denominator(init_denominator, precision, q_t, k_t)

  # Back to batch-major, then renormalize with the denominator broadcast
  # over the trailing feature axis.
  numerator = jnp.moveaxis(numerator, 0, 1)
  denominator = jnp.moveaxis(denominator, 0, 1)
  inv_denominator = jnp.expand_dims(jnp.reciprocal(denominator), -1)
  return numerator * inv_denominator