def logp_unordered_subset(theta, zs):
    # last dimension of the zs indicates the selected elements
    # sparse index representation
    #
    # Wouter et al use the Gumbel representation to compute p(Sk) in
    # exponential time rather than factorial.
    # We do it in factorial time.
    Sz, Sx, N, K = zs.shape
    # Is there a better syntax for gather?
    logp_z = theta[
        np.arange(Sz)[:,None,None,None],
        np.arange(Sx)[:,None,None],
        np.arange(N)[:,None],
        zs,
    ]
    # get denominator orderings
    perms = all_perms(K)
    logp = logp_z[..., perms]
    # cumlogsumexp would be more stable? but there are only two elements here...
    # sum_i p(b_i)
    #a = logp.max(-1, keepdims=True)
    #p = np.exp(logp - a)
    #sbi0 = a + np.log(p.cumsum(-1) - p)
    # slow implementation, the above seems wrong
    sbis = [np.log(np.zeros(logp[..., 0].shape))]
    for i in range(K-1):
        sbis.append(np.logaddexp(sbis[-1], logp[..., i]))
    sbi = np.stack(sbis, -1)
    logp_bs = logp.sum(-1) - log1mexp(sbi).sum(-1)
    logp_b = lse(logp_bs, -1)
    return logp_b
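# The helpers used above (lse, log1mexp, all_perms) are not defined in this
# section. The sketches below are illustrative stand-ins (hence the leading
# underscores), assuming lse is a logsumexp, log1mexp(x) = log(1 - exp(x))
# for x <= 0, and all_perms(K) enumerates the K! orderings of a K-subset;
# they are not necessarily the project's own definitions.
import itertools
import numpy as onp  # plain numpy for the static permutation table

def _lse(x, axis=-1, keepdims=False):
    # numerically stable log-sum-exp along `axis`
    a = x.max(axis, keepdims=True)
    out = a + np.log(np.exp(x - a).sum(axis, keepdims=True))
    return out if keepdims else out.squeeze(axis)

def _log1mexp(x):
    # stable log(1 - exp(x)) for x <= 0, using the usual two-branch trick
    return np.where(
        x > -np.log(2.0),
        np.log(-np.expm1(x)),
        np.log1p(-np.exp(x)),
    )

def _all_perms(K):
    # (K!, K) integer array of all orderings of {0, ..., K-1}
    return onp.array(list(itertools.permutations(range(K))))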
def sample_relaxed_part(logits, g0, K=1, tau=1):
    # Relaxed subset sampling by partitioning: perturb the logits with the
    # Gumbel noise g0, split the Z categories into C = Z // K blocks of size K
    # (ordered by the unperturbed logits), pick the block with the largest
    # log-sum-exp of perturbed logits, and return a softmax relaxation over
    # that block together with its indices.
    Sz, Sx, N, Z = g0.shape
    C = Z // K
    g = logits + g0
    # logits sorted in ascending order
    # sort by logits not perturbed logits
    # so operation is deterministic
    I_sorted = logits.argsort(-1)
    g_sorted = g[
        np.arange(Sz)[:,None,None,None],
        np.arange(Sx)[:,None,None],
        np.arange(N)[:,None],
        I_sorted,
    ]
    Ik = I_sorted.reshape(Sz, Sx, N, C, K)
    gk = g_sorted.reshape(Sz, Sx, N, C, K)
    gk_agg = lse(gk, -1)
    Ik_agg = gk_agg.argmax(-1)
    g_out = gk[
        np.arange(Sz)[:,None,None,None,None],
        np.arange(Sx)[:,None,None,None],
        np.arange(N)[:,None,None],
        Ik_agg[:,:,:,None,None],
        :,
    ][:,:,:,0,0]
    I_out = Ik[
        np.arange(Sz)[:,None,None,None,None],
        np.arange(Sx)[:,None,None,None],
        np.arange(N)[:,None,None],
        Ik_agg[:,:,:,None,None],
        :,
    ][:,:,:,0,0]
    # works without topk
    # NOTE: tau is currently unused; the relaxation is a plain softmax over the winning block.
    return np.exp(normalize(g_out)), I_out
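# Illustrative usage of sample_relaxed_part (a sketch of the calling
# convention, not taken from the rest of the project; the helper name is
# hypothetical). It assumes g0 is standard Gumbel(0, 1) noise with the same
# (Sz, Sx, N, Z) shape as the logits, and that Z is a multiple of K.
def _example_sample_relaxed_part():
    import numpy as onp
    Sz, Sx, N, Z, K = 2, 3, 5, 8, 2
    rng = onp.random.default_rng(0)
    logits = rng.normal(size=(Sz, Sx, N, Z))
    u = rng.uniform(size=(Sz, Sx, N, Z))
    g0 = -onp.log(-onp.log(u))  # standard Gumbel(0, 1) noise
    rzs, zs = sample_relaxed_part(logits, g0, K=K)
    # rzs: (Sz, Sx, N, K) relaxed weights, summing to 1 over the last axis
    # zs:  (Sz, Sx, N, K) indices of the selected block of categories
    return rzs, zs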
def exp_rff_attn(q, k, projection_matrix):
    kernel_cons = exp_nonnegative_softmax_kernel_feature_creator
    log_phi_q = kernel_cons(
        q, projection_matrix, (0,), None,
        is_query=True,
        eps=.0001,
        #normalize_data=True,
        normalize_data=False,
    )
    log_phi_k = kernel_cons(
        k, projection_matrix, (0,), None,
        is_query=False,
        eps=.0001,
        #normalize_data=True,
        normalize_data=False,
    )
    log_pots_hat = log_phi_q[:, None, :] + log_phi_k[None, :, :]
    # average
    log_pots = lmm(log_phi_q, log_phi_k)
    return jnp.exp(log_pots - lse(log_pots, -1, keepdims=True)), log_pots_hat
def logp_x_z_relaxed_part(theta, params, x, g, K=1, tau=1):
    theta = normalize(theta)
    rzs, zs = sample_relaxed_part(theta, g, K, tau)
    Sz, Sx, N, Z = g.shape
    fz = f_relaxed_subset(params, rzs, zs)
    fxz = fz[:, np.arange(Sx)[:,None], np.arange(N), x]
    # log-probability of the sampled block: log of the sum of the
    # (normalized) probabilities of its member indices
    logp_b = lse(theta[
        np.arange(Sz)[:,None,None,None],
        np.arange(Sx)[:,None,None],
        np.arange(N)[:,None],
        zs,
    ], -1)
    # surrogate objective: stop_gradient(fxz) * logp_b contributes the
    # score-function (REINFORCE) term, fxz the pathwise term
    return (lax.stop_gradient(fxz) * logp_b + fxz).sum()
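# Why the return value has the form stop_gradient(fxz) * logp_b + fxz: its
# gradient is fxz * d(logp_b)/dtheta + d(fxz)/dtheta, a score-function term
# plus a pathwise term. A tiny self-contained check of that identity with toy
# scalar functions (names hypothetical, independent of the model above):
def _check_surrogate_gradient():
    import jax
    import jax.numpy as jnp
    from jax import lax

    f = lambda t: jnp.sin(t)        # stands in for fxz
    logp = lambda t: -t ** 2        # stands in for logp_b

    surrogate = lambda t: lax.stop_gradient(f(t)) * logp(t) + f(t)
    expected = lambda t: f(t) * jax.grad(logp)(t) + jax.grad(f)(t)

    t = 0.3
    assert jnp.allclose(jax.grad(surrogate)(t), expected(t))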
def rff_attn(q, k, projection_matrix, eps=0):
    kernel_cons = nonnegative_softmax_kernel_feature_creator
    log_phi_q = kernel_cons(
        q, projection_matrix,
        is_query=True,
        eps=eps,
        #normalize_data=True,
        normalize_data=False,
    )
    log_phi_k = kernel_cons(
        k, projection_matrix,
        is_query=False,
        eps=eps,
        #normalize_data=True,
        normalize_data=False,
    )
    log_pots_hat = log_phi_q[:, None, :] + log_phi_k[None, :, :]
    # average
    log_pots = lmm(log_phi_q, log_phi_k) - math.log(k.shape[0])
    return jnp.exp(log_pots - lse(log_pots, -1, keepdims=True)), log_pots_hat
def normalize(x):
    # log-softmax over the last axis: shift log-potentials so they sum to 1 in probability space
    return x - lse(x, -1, keepdims=True)
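# normalize is the usual log-softmax along the last axis; assuming lse is a
# logsumexp, exp(normalize(x)) sums to 1 over that axis. A small equivalence
# check against jax.nn.log_softmax (illustrative only, hypothetical helper):
def _check_normalize():
    import jax
    import jax.numpy as jnp
    x = jnp.array([[0.1, -2.0, 1.3], [0.0, 0.0, 0.0]])
    assert jnp.allclose(normalize(x), jax.nn.log_softmax(x, axis=-1), atol=1e-6)
    assert jnp.allclose(jnp.exp(normalize(x)).sum(-1), 1.0, atol=1e-6)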
get_2d_arrays = jax.jit(jax.vmap(
    functools.partial(get_2d_array, scaling=0),
))

def random_projection(num_features, original_dim, key, bsz=None):
    shape = ((num_features, original_dim)
             if bsz is None
             else (bsz, num_features, original_dim))
    return random.normal(key, shape)

# tests
def attn(q, k):
    log_pots = q @ k.T
    return jax.nn.softmax(log_pots), log_pots

lmm = jax.jit(lambda x, y: lse(x[:, None, :] + y[None, :, :], -1))
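# Hedged smoke test for the random-feature attention above (hypothetical
# helper, not part of the original test suite). It assumes
# nonnegative_softmax_kernel_feature_creator is available (e.g. adapted from
# the Performer reference code) and that the projection matrix has shape
# (num_features, head_dim). With enough random features, rff_attn should
# roughly match the exact softmax attention attn(q, k).
def _rff_attn_max_error(seed=0, n=6, d=8, num_features=256):
    key = random.PRNGKey(seed)
    kq, kk, kp = random.split(key, 3)
    q = random.normal(kq, (n, d)) / d ** 0.5
    k = random.normal(kk, (n, d)) / d ** 0.5
    proj = random_projection(num_features, d, kp)
    exact, _ = attn(q, k)
    approx, _ = rff_attn(q, k, proj)
    # the estimator is stochastic, so report the worst-case deviation
    # rather than asserting a fixed tolerance
    return jnp.max(jnp.abs(exact - approx))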