Example #1
def logp_unordered_subset(theta, zs):
    # the last dimension of zs indexes the selected elements
    # (sparse index representation)
    #
    # Wouter et al. use the Gumbel representation to compute p(S_k) in
    # exponential rather than factorial time; we do it in factorial time.
    Sz, Sx, N, K = zs.shape
    # Is there a better syntax for gather?
    # np.take_along_axis(theta, zs, -1) would be equivalent here.
    logp_z = theta[
        np.arange(Sz)[:,None,None,None],
        np.arange(Sx)[:,None,None],
        np.arange(N)[:,None],
        zs,
    ]

    # enumerate all K! orderings of the selected subset for the denominator
    perms = all_perms(K)
    logp = logp_z[..., perms]

    # a cumulative logsumexp would be more stable? but there are only two
    # elements here...
    # sbi_i = log sum_{j<i} p(b_j), the mass drawn before step i
    #a = logp.max(-1, keepdims=True)
    #p = np.exp(logp - a)
    #sbi0 = a + np.log(p.cumsum(-1) - p)

    # slow implementation; the vectorized attempt above seems wrong
    sbis = [np.full(logp[..., 0].shape, -np.inf)]  # log(0): nothing drawn yet
    for i in range(K - 1):
        sbis.append(np.logaddexp(sbis[-1], logp[..., i]))
    sbi = np.stack(sbis, -1)

    # for an ordering b: log p(b) = sum_i log p(b_i) - sum_i log(1 - sum_{j<i} p(b_j))
    logp_bs = logp.sum(-1) - log1mexp(sbi).sum(-1)
    # marginalize over the K! orderings of the subset
    logp_b = lse(logp_bs, -1)
    return logp_b
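The snippet above assumes three helpers that are not shown. A minimal sketch of
plausible definitions, matching only how they are called here (the originals may
differ):

import itertools
import numpy as np
from scipy.special import logsumexp as lse  # used as lse(x, axis)

def log1mexp(x):
    # log(1 - exp(x)) for x <= 0, with the usual split for stability
    return np.where(x > -np.log(2),
                    np.log(-np.expm1(x)),
                    np.log1p(-np.exp(x)))

def all_perms(K):
    # all K! orderings of range(K) as an integer array of shape (K!, K)
    return np.array(list(itertools.permutations(range(K))))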
Example #2
def sample_relaxed_part(logits, g0, K=1, tau=1):
    Sz, Sx, N, Z = g0.shape
    C = Z // K
    g = logits + g0
    # sort the logits in ascending order; sort by the logits, not the
    # perturbed logits, so the operation is deterministic given the logits
    I_sorted = logits.argsort(-1)
    # gather the perturbed logits into that order
    g_sorted = np.take_along_axis(g, I_sorted, -1)
    Ik = I_sorted.reshape(Sz, Sx, N, C, K)
    gk = g_sorted.reshape(Sz, Sx, N, C, K)
    gk_agg = lse(gk, -1)        # aggregate perturbed logits within each part
    Ik_agg = gk_agg.argmax(-1)  # select the highest-scoring part

    # gather the perturbed logits and element indices of the selected part
    g_out = gk[
        np.arange(Sz)[:,None,None,None,None],
        np.arange(Sx)[:,None,None,None],
        np.arange(N)[:,None,None],
        Ik_agg[:,:,:,None,None],
        :,
    ][:,:,:,0,0]
    I_out = Ik[
        np.arange(Sz)[:,None,None,None,None],
        np.arange(Sx)[:,None,None,None],
        np.arange(N)[:,None,None],
        Ik_agg[:,:,:,None,None],
        :,
    ][:,:,:,0,0]
    # works without topk; note tau is accepted but unused in this variant
    return np.exp(normalize(g_out)), I_out
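A usage sketch for the sampler above, with hypothetical shapes; it assumes g0
carries standard Gumbel(0, 1) noise shaped like the logits, and that lse and
normalize are the helpers used throughout these examples:

import numpy as np

rng = np.random.default_rng(0)
Sz, Sx, N, Z, K = 2, 3, 5, 8, 2  # Z must be divisible by K
logits = rng.normal(size=(Sz, Sx, N, Z))
g0 = rng.gumbel(size=(Sz, Sx, N, Z))
rz, z = sample_relaxed_part(logits, g0, K=K)
# rz: relaxed weights over the K elements of the chosen part, shape (Sz, Sx, N, K)
# z:  the corresponding indices into the last axis, shape (Sz, Sx, N, K)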
Example #3
def exp_rff_attn(q, k, projection_matrix):
    kernel_cons = exp_nonnegative_softmax_kernel_feature_creator
    log_phi_q = kernel_cons(
        q,
        projection_matrix,
        (0, ),
        None,
        is_query=True,
        eps=.0001,
        #normalize_data=True,
        normalize_data=False,
    )
    log_phi_k = kernel_cons(
        k,
        projection_matrix,
        (0, ),
        None,
        is_query=False,
        eps=.0001,
        #normalize_data=True,
        normalize_data=False,
    )
    log_pots_hat = log_phi_q[:, None, :] + log_phi_k[None, :, :]
    # pool over the random-feature dimension (log-space matrix product)
    log_pots = lmm(log_phi_q, log_phi_k)
    return jnp.exp(log_pots - lse(log_pots, -1, keepdims=True)), log_pots_hat
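Note that log_pots is only defined up to an additive constant: the softmax in
the return statement is shift-invariant, which is why this variant can skip the
averaging term that Example #5 subtracts. A quick check of that invariance:

import jax.numpy as jnp
from jax.scipy.special import logsumexp as lse

x = jnp.array([0.3, -1.2, 2.0])
a = x - lse(x, -1, keepdims=True)
b = (x - 5.0) - lse(x - 5.0, -1, keepdims=True)
assert jnp.allclose(a, b)  # shifting the log-potentials leaves the softmax unchanged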
Example #4
def logp_x_z_relaxed_part(theta, params, x, g, K=1, tau=1):
    theta = normalize(theta)
    rzs, zs = sample_relaxed_part(theta, g, K, tau)
    Sz, Sx, N, Z = g.shape
    fz = f_relaxed_subset(params, rzs, zs)
    fxz = fz[:, np.arange(Sx)[:,None], np.arange(N), x]
    logp_b = lse(theta[
        np.arange(Sz)[:,None,None,None],
        np.arange(Sx)[:,None,None],
        np.arange(N)[:,None],
        zs,
    ], -1)
    # score-function surrogate: the gradient is fxz * d(logp_b) + d(fxz)
    return (lax.stop_gradient(fxz) * logp_b + fxz).sum()
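A scalar toy check of that gradient identity (f and logp below are hypothetical
stand-ins for fxz and logp_b):

import jax
import jax.numpy as jnp
from jax import lax

f = lambda t: t ** 2                     # stand-in for fxz
logp = lambda t: jax.nn.log_sigmoid(t)   # stand-in for logp_b

surrogate = lambda t: lax.stop_gradient(f(t)) * logp(t) + f(t)
expected = lambda t: f(t) * jax.grad(logp)(t) + jax.grad(f)(t)
assert jnp.allclose(jax.grad(surrogate)(0.7), expected(0.7))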
Example #5
def rff_attn(q, k, projection_matrix, eps=0):
    kernel_cons = nonnegative_softmax_kernel_feature_creator
    log_phi_q = kernel_cons(
        q,
        projection_matrix,
        is_query=True,
        eps=eps,
        #normalize_data=True,
        normalize_data=False,
    )
    log_phi_k = kernel_cons(
        k,
        projection_matrix,
        is_query=False,
        eps=eps,
        #normalize_data=True,
        normalize_data=False,
    )
    log_pots_hat = log_phi_q[:, None, :] + log_phi_k[None, :, :]
    # average over features; the additive constant cancels after the softmax
    # normalization below (and k.shape[0] is the key count, not the feature count)
    log_pots = lmm(log_phi_q, log_phi_k) - math.log(k.shape[0])
    return jnp.exp(log_pots - lse(log_pots, -1, keepdims=True)), log_pots_hat
Example #6
def normalize(x):
    # log-softmax over the last axis
    return x - lse(x, -1, keepdims=True)
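normalize is just a log-softmax over the last axis, so with lse bound to
jax.scipy.special.logsumexp it agrees with the built-in:

import jax
import jax.numpy as jnp

x = jnp.array([[0.5, -1.0, 2.0]])
assert jnp.allclose(normalize(x), jax.nn.log_softmax(x, axis=-1))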
Example #7
get_2d_arrays = jax.jit(jax.vmap(functools.partial(get_2d_array, scaling=0)))


def random_projection(num_features, original_dim, key, bsz=None):
    shape = ((num_features, original_dim) if bsz is None else
             (bsz, num_features, original_dim))
    return random.normal(key, shape)


# tests
def attn(q, k):
    # exact softmax attention, for comparison against the RFF approximation
    log_pots = q @ k.T
    return jax.nn.softmax(log_pots), log_pots


# log-space matrix multiply: lmm(x, y)[i, j] = lse_d(x[i, d] + y[j, d])
lmm = jax.jit(lambda x, y: lse(x[:, None, :] + y[None, :, :], -1))
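Equivalently, lmm(x, y) = log((exp x) @ (exp y).T). A small sanity check:

import jax.numpy as jnp

x = jnp.log(jnp.array([[1., 2.], [3., 4.]]))
y = jnp.log(jnp.array([[5., 6.], [7., 8.]]))
assert jnp.allclose(lmm(x, y), jnp.log(jnp.exp(x) @ jnp.exp(y).T))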


def rff_attn(q, k, projection_matrix, eps=0):
    kernel_cons = nonnegative_softmax_kernel_feature_creator
    log_phi_q = kernel_cons(
        q,
        projection_matrix,
        is_query=True,
        eps=eps,
        #normalize_data=True,
        normalize_data=False,
    )
    log_phi_k = kernel_cons(
        k,
        projection_matrix,