# Example 1
def comp(num_features, q, k, key, num_samples=128, sample=True):
    """Compare exact softmax attention with random-feature approximations.

    For each of ``num_samples`` draws, a Gaussian sample is drawn with
    ``fat.random_projection`` and three ``fat.rff_attn`` estimates are made:

    * using ``fat.get_2d_array`` of the sample with its default ``scaling``;
    * using the raw Gaussian sample directly as the projection;
    * using ``fat.get_2d_array`` of the sample with ``scaling=1``.

    Args:
        num_features: Number of random features requested from
            ``fat.random_projection``.
        q: Query matrix. It is unpacked as ``(T, qk_dim)``, so it must be 2-D.
        k: Key matrix. It must be compatible with ``q @ k.T``.
        key: JAX PRNG key.
        num_samples: Number of projection draws.
        sample: If True, ``key`` is advanced every iteration, so each draw
            uses fresh randomness. If False, ``key`` is never advanced, so
            every iteration reuses the same sub-keys and the draws are
            identical.

    Returns:
        A 4-tuple ``(true_attn, samples, gsamples, nsamples)``.
        ``true_attn`` is the exact attention, normalised over the last axis.
        Each of the other three stacks the ``num_samples`` estimates along a
        new leading axis.
    """
    _, qk_dim = q.shape

    # Exact attention: softmax of q @ k.T over the last axis. jax.nn.softmax
    # subtracts the row maximum before exponentiating. Mathematically this
    # equals exp(x) / sum(exp(x)), but unlike a bare exp it cannot overflow
    # to inf (and then nan) when the logits are large.
    true_attn = jax.nn.softmax(q @ k.T, axis=-1)

    samples, gsamples, nsamples = [], [], []
    for _ in range(num_samples):
        if sample:
            # Advance `key` so the next iteration draws new randomness.
            key, sample_key, norm_key = jax.random.split(key, 3)
        else:
            # `key` is not advanced: every iteration reuses these sub-keys.
            sample_key, norm_key = jax.random.split(key)

        gaussian_sample = fat.random_projection(num_features, qk_dim,
                                                sample_key)

        # Estimate 1: get_2d_array post-processing with default scaling.
        projection_matrix = fat.get_2d_array(gaussian_sample, norm_key)
        ra, _ = fat.rff_attn(q, k, projection_matrix)
        samples.append(ra)

        # Estimate 2: the raw Gaussian sample as the projection.
        gra, _ = fat.rff_attn(q, k, gaussian_sample)
        gsamples.append(gra)

        # Estimate 3: get_2d_array post-processing with scaling=1.
        nprojection_matrix = fat.get_2d_array(gaussian_sample,
                                              norm_key,
                                              scaling=1)
        nra, _ = fat.rff_attn(q, k, nprojection_matrix)
        nsamples.append(nra)

    return (true_attn,) + tuple(
        jnp.stack(x) for x in (samples, gsamples, nsamples))
# Example 2
 def proj_fn_reg_small(shape, key):
     """Build a projection matrix via ``fat.get_2d_array(..., scaling=2)``.

     ``key`` is split in two. The first half seeds
     ``fat.random_projection(num_features, qk_dim, ...)`` and the second is
     passed to ``fat.get_2d_array``.

     ``shape`` is accepted but not used. The dimensions come from the names
     ``num_features`` and ``qk_dim``, which resolve outside this function
     (presumably an enclosing scope; confirm).
     """
     subkeys = jax.random.split(key)
     raw = fat.random_projection(num_features, qk_dim, subkeys[0])
     return fat.get_2d_array(raw, subkeys[1], scaling=2)
# Example 3
 def proj_fn_anti(shape, key):
     """Build an antithetic projection matrix with ``num_features`` rows.

     A Gaussian sample is drawn for half of the features and post-processed
     with ``fat.get_2d_array(..., scaling=0)``. It is then stacked with its
     negation along axis 0, so each row ``w`` is paired with ``-w``. When
     ``num_features`` is odd, the single surplus negated row is dropped.

     ``shape`` is accepted but not used. ``num_features`` and ``qk_dim``
     resolve outside this function (presumably an enclosing scope; confirm).
     NOTE(review): assumes ``fat.get_2d_array`` keeps one row per sampled
     feature, as the original concatenation along axis 0 implies.
     """
     sample_key, norm_key = jax.random.split(key)
     # Round the half up. `num_features // 2` would give one row too few for
     # odd num_features; for even values this equals the old computation.
     half = (num_features + 1) // 2
     gaussian_sample = fat.random_projection(half, qk_dim, sample_key)
     projection_matrix = fat.get_2d_array(gaussian_sample,
                                          norm_key,
                                          scaling=0)
     antithetic = jnp.concatenate([projection_matrix, -projection_matrix],
                                  axis=0)
     # No-op for even num_features; trims the extra row when it is odd.
     return antithetic[:num_features]
# Example 4
 def proj_fn_gaus(shape, key):
     """Return a raw Gaussian sample from ``fat.random_projection``.

     No ``fat.get_2d_array`` post-processing is applied.

     ``shape`` is accepted but not used. ``num_features`` and ``qk_dim``
     resolve outside this function (presumably an enclosing scope; confirm).
     """
     # Split rather than using `key` directly. This keeps the sub-key equal
     # to the first half of a split, which is what the other projection
     # builders here pass to fat.random_projection.
     first_subkey = jax.random.split(key)[0]
     return fat.random_projection(num_features, qk_dim, first_subkey)
# Example 5
# Script comparing attention implementations on random queries and keys.
# NOTE(review): `st` (streamlit) is not used in the lines visible here, and
# `jax`, `jnp`, `fat`, `torch` and `onp` are referenced below without being
# imported in this section. Presumably they are imported elsewhere; confirm.
import streamlit as st

# Problem size:
#   num_features - random features requested from fat.random_projection.
#   qk_dim       - query/key dimension.
#   T            - number of query/key rows (presumably the sequence length).
num_features = 512
qk_dim = 64
T = 8

key = jax.random.PRNGKey(0)

# Standard-normal queries and keys, each of shape (T, qk_dim).
key, key1, key2 = jax.random.split(key, 3)
q = jax.random.normal(key1, (T, qk_dim))
k = jax.random.normal(key2, (T, qk_dim))

# Copies of the initial q and k. Their later use is not visible in this
# section.
q0, k0 = q.copy(), k.copy()

# One Gaussian sample, turned into a projection matrix by fat.get_2d_array
# with its default scaling.
key, sample_key, norm_key = jax.random.split(key, 3)
gaussian_sample = fat.random_projection(num_features, qk_dim, sample_key)
projection_matrix = fat.get_2d_array(gaussian_sample, norm_key)

# compare all attention implementations

## mean
# Exact attention: exp(q @ k.T), normalised over the last axis.
vals = jnp.exp(q @ k.T)
true_attn = vals / vals.sum(-1, keepdims=True)

# Two `fat` random-feature attention variants using the same projection.
# rff_attn's result is unpacked as a pair, and only the first element is kept.
ra, _ = fat.rff_attn(q, k, projection_matrix)
ra0 = fat.rff_attn0(q, k, projection_matrix)

# PyTorch copies of q, k and the projection, converted through
# onp.asarray (presumably NumPy). The projection's last two axes are
# swapped, presumably to match a PyTorch-side layout.
qt = torch.tensor(onp.asarray(q))
kt = torch.tensor(onp.asarray(k))
pt = torch.tensor(onp.asarray(projection_matrix)).transpose(-1, -2)