def sample_scan(params, tup, x):
    """ Perform single step update of the network """
    _, (update_W, update_U, update_b), (reset_W, reset_U, reset_b), \
        (out_W, out_U, out_b), (sm_W, sm_b) = params
    hidden = tup[3]
    logP = tup[2]
    key = tup[0]
    inp = tup[1]
    update_gate = sigmoid(np.dot(inp, update_W) + np.dot(hidden, update_U) + update_b)
    reset_gate = sigmoid(np.dot(inp, reset_W) + np.dot(hidden, reset_U) + reset_b)
    output_gate = np.tanh(np.dot(inp, out_W)
                          + np.dot(np.multiply(reset_gate, hidden), out_U)
                          + out_b)
    output = np.multiply(update_gate, hidden) + np.multiply(1 - update_gate, output_gate)
    hidden = output
    logits = np.dot(hidden, sm_W) + sm_b
    key, subkey = random.split(key)
    samples = random.categorical(subkey, logits, axis=1, shape=None)  # sampling the conditional
    samples = one_hot(samples, sm_b.shape[0])  # convert to one hot encoded vector
    log_P_new = np.sum(samples * log_softmax(logits), axis=1)
    log_P_new = log_P_new + logP  # update the value of the logP of the sample
    return (key, samples, log_P_new, output), samples
def mixture_logistic_log_pdf(x: Array, prior_logits: Array, means: Array, scales: Array) -> Array:
    """
    Args:
        x (Array): input vector (D,)
        prior_logits (Array): prior logits to weight the components (D, K)
        means (Array): means per component per feature (D, K)
        scales (Array): scales per component per feature (D, K)
    Returns:
        log_pdf (Array): log PDF for the mixture distribution
    """
    # add component dimension, (D,)->(D,1)
    x = jnp.expand_dims(x, axis=-1)
    base_dist = Logistic(loc=means, scale=scales)
    # x = (x - means) / scales
    # normalize logit weights to 1, (D,K)->(D,K)
    prior_logits = log_softmax(prior_logits, axis=-1)
    # calculate the log pdf, (D,K)->(D,K)
    log_pdfs = prior_logits + base_dist.log_prob(x)
    # normalize distribution for components, (D,K)->(D,)
    log_pdf = logsumexp(log_pdfs, axis=-1)
    return log_pdf
def apply_fun(params, inputs):
    serial_params, vhead_params, phead_params = params
    serial_out = serial_apply(serial_params, inputs)
    vhead_out = vhead_apply(vhead_params, serial_out)
    phead_out = log_softmax(phead_apply(phead_params, serial_out))
    out = (vhead_out, phead_out)
    return out
def cross_entropy(logits, targets, weights=None, label_smoothing=0.0):
    """Compute cross entropy and entropy for log probs and targets.

    Args:
      logits: [batch, length, num_classes] float array.
      targets: categorical targets [batch, length] int array.
      weights: None or array of shape [batch, length]
      label_smoothing: label smoothing constant, used to determine the on and off values.

    Returns:
      Tuple of scalar loss and batch normalizing factor.
    """
    if logits.ndim != targets.ndim + 1:
        raise ValueError(
            "Incorrect shapes. Got shape %s logits and %s targets"
            % (str(logits.shape), str(targets.shape))
        )
    vocab_size = logits.shape[-1]
    confidence = 1.0 - label_smoothing
    low_confidence = (1.0 - confidence) / (vocab_size - 1)
    normalizing_constant = -(
        confidence * jnp.log(confidence)
        + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
    )
    soft_targets = common_utils.onehot(
        targets, vocab_size, on_value=confidence, off_value=low_confidence
    )

    loss = -jnp.sum(soft_targets * log_softmax(logits), axis=-1)
    loss = loss - normalizing_constant

    if weights is not None:
        loss = loss * weights
        normalizing_factor = weights.sum()
    else:
        normalizing_factor = np.prod(targets.shape)

    return loss.sum(), normalizing_factor
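# Sketch (not from the original source): a tiny worked example of the
# label-smoothing pieces above, rebuilding the soft targets with jax.nn.one_hot
# in place of common_utils.onehot (assumed to behave the same way).
import jax.numpy as jnp
from jax import nn

logits = jnp.array([[[2.0, 0.5, -1.0]]])   # [batch=1, length=1, vocab=3]
targets = jnp.array([[0]])                 # true class is 0
label_smoothing = 0.1
vocab_size = logits.shape[-1]
confidence = 1.0 - label_smoothing
low_confidence = label_smoothing / (vocab_size - 1)

# `confidence` on the true class, `low_confidence` everywhere else
soft_targets = nn.one_hot(targets, vocab_size) * (confidence - low_confidence) + low_confidence
loss = -jnp.sum(soft_targets * nn.log_softmax(logits), axis=-1)

# the normalizing constant is the entropy of the smoothed target distribution,
# so a perfectly calibrated prediction gives (close to) zero loss
normalizing_constant = -(confidence * jnp.log(confidence)
                         + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20))
print(loss - normalizing_constant)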
def mixture_logistic_log_pdf(
    x: JaxArray, prior_logits: JaxArray, means: JaxArray, scales: JaxArray
) -> JaxArray:
    """
    Args:
        x (JaxArray): input vector (D,)
        prior_logits (JaxArray): prior logits to weight the components (D, K)
        means (JaxArray): means per component per feature (D, K)
        scales (JaxArray): scales per component per feature (D, K)
    Returns:
        log_pdf (JaxArray): log PDF for the mixture distribution
    """
    # add component dimension, (D,)->(D,1); this allows broadcasting against (D,K)
    x_r = x.reshape(-1, 1)
    # normalize logit weights to 1, (D,K)->(D,K)
    prior_logits = log_softmax(prior_logits, axis=1)
    # calculate the log pdf, (D,K)->(D,K)
    log_pdfs = prior_logits + logistic_log_pdf(x_r, means, scales)
    # marginalize over the components, (D,K)->(D,)
    log_pdf = logsumexp(log_pdfs, axis=1)
    return log_pdf
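# Usage sketch (assumptions: `logistic_log_pdf` is the standard location/scale
# logistic log-density, and log_softmax/logsumexp are imported as in the
# snippet above; shapes follow the docstring).
import jax
import jax.numpy as jnp
from jax.nn import log_softmax, softplus
from jax.scipy.special import logsumexp

def logistic_log_pdf(x, loc, scale):
    # log density of Logistic(loc, scale): -z - 2*softplus(-z) - log(scale)
    z = (x - loc) / scale
    return -z - 2.0 * softplus(-z) - jnp.log(scale)

D, K = 4, 3
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(k1, (D,))
prior_logits = jax.random.normal(k2, (D, K))
means = jax.random.normal(k3, (D, K))
scales = jnp.exp(jax.random.normal(k4, (D, K)))

log_pdf = mixture_logistic_log_pdf(x, prior_logits, means, scales)
print(log_pdf.shape)  # (D,)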
def mixture_logistic_cdf(
    x: JaxArray, prior_logits: JaxArray, means: JaxArray, scales: JaxArray
) -> JaxArray:
    """
    Args:
        x (JaxArray): input vector (D,)
        prior_logits (JaxArray): prior logits to weight the components (D, K)
        means (JaxArray): means per component per feature (D, K)
        scales (JaxArray): scales per component per feature (D, K)
    Returns:
        cdf (JaxArray): CDF for the mixture distribution
    """
    # add component dimension, (D,)->(D,1); this allows broadcasting against (D,K)
    x_r = x.reshape(-1, 1)
    # normalize logit weights to 1, (D,K)->(D,K)
    prior_logits = log_softmax(prior_logits, axis=1)
    # calculate the log cdf, (D,K)->(D,K)
    log_cdfs = prior_logits + logistic_log_cdf(x_r, means, scales)
    # marginalize over the components, (D,K)->(D,)
    log_cdf = logsumexp(log_cdfs, axis=1)
    return np.exp(log_cdf)
def structure_mixture_params(components) -> LogisticMixtureParams:
    unnormalized_weights = components[:, 2]
    probs = list(np.exp(nn.log_softmax(unnormalized_weights)))
    component_params = [
        LogisticParams(component[0], component[1]) for component in components
    ]
    return LogisticMixtureParams(components=component_params, probs=probs)
def mixture_gaussian_log_pdf(
    x: JaxArray, prior_logits: JaxArray, means: JaxArray, scales: JaxArray
) -> JaxArray:
    """
    Args:
        x (JaxArray): input vector (D,)
        prior_logits (JaxArray): prior logits to weight the components (D, K)
        means (JaxArray): means per component per feature (D, K)
        scales (JaxArray): scales per component per feature (D, K)
    Returns:
        log_pdf (JaxArray): log PDF for the mixture distribution
    """
    # add component dimension, (D,)->(D,1); this allows broadcasting against (D,K)
    x_r = x.reshape(-1, 1)
    # normalize logit weights to 1, (D,K)->(D,K)
    prior_logits = log_softmax(prior_logits, axis=1)
    # calculate the log pdf, (D,K)->(D,K)
    log_pdfs = prior_logits + jax.scipy.stats.norm.logpdf(x_r, means, scales)
    # marginalize over the components, (D,K)->(D,)
    log_pdf = logsumexp(log_pdfs, axis=1)
    return log_pdf
def main(unused_argv):
    # Load data and preprocess it.
    print('Loading data.')
    x_train, y_train, x_test, y_test = datasets.get_dataset('mnist', permute_train=True)

    # Build the network.
    init_fn, f, _ = stax.serial(
        stax.Dense(512, 1., 0.05),
        stax.Erf(),
        stax.Dense(10, 1., 0.05))
    key = random.PRNGKey(0)
    _, params = init_fn(key, (-1, 784))

    # Linearize the network about its initial parameters.
    f_lin = nt.linearize(f, params)

    # Create and initialize an optimizer for both f and f_lin.
    opt_init, opt_apply, get_params = optimizers.momentum(_LEARNING_RATE, 0.9)
    opt_apply = jit(opt_apply)
    state = opt_init(params)
    state_lin = opt_init(params)

    # Create a cross-entropy loss function.
    loss = lambda fx, y_hat: -np.mean(log_softmax(fx) * y_hat)

    # Specialize the loss function to compute gradients for both linearized and
    # full networks.
    grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
    grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

    # Train the network.
    print('Training.')
    print('Epoch\tLoss\tLinearized Loss')
    print('------------------------------------------')

    epoch = 0
    steps_per_epoch = 50000 // _BATCH_SIZE

    for i, (x, y) in enumerate(datasets.minibatch(
            x_train, y_train, _BATCH_SIZE, _TRAIN_EPOCHS)):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x, y), state)

        params_lin = get_params(state_lin)
        state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

        if i % steps_per_epoch == 0:
            print('{}\t{:.4f}\t{:.4f}'.format(
                epoch, loss(f(params, x), y), loss(f_lin(params_lin, x), y)))
            epoch += 1

    # Print out summary data comparing the linear / nonlinear model.
    x, y = x_train[:10000], y_train[:10000]
    util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
    util.print_summary(
        'test', y_test, f(params, x_test), f_lin(params_lin, x_test), loss)
def poisson_categorical_log_prior(length, rate):
    """Categorical prior populated with log probabilities of Poisson dist."""
    rate = jnp.array(rate, dtype=jnp.float32)
    values = jnp.expand_dims(jnp.arange(1, length + 1, dtype=jnp.float32), 0)
    # Unnormalized Poisson log pmf: k * log(rate) - rate - log(k!)
    log_prob_unnormalized = (
        jnp.log(rate) * values - rate - jax.lax.lgamma(values + 1))
    # TODO(tkipf): Length-sensitive normalization.
    return nn.log_softmax(log_prob_unnormalized, axis=1)  # Normalize.
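# Small usage sketch (hypothetical values): the prior above is a (1, length)
# array of log probabilities over segment lengths 1..length, normalized by
# log_softmax so it exponentiates to a proper categorical distribution.
import jax.numpy as jnp

log_prior = poisson_categorical_log_prior(length=10, rate=3.0)
print(log_prior.shape)                  # (1, 10)
print(jnp.exp(log_prior).sum(axis=1))   # ~1.0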
def apply_fun(params, x, adj, is_training=False, **kwargs):
    rng = kwargs.pop('rng', None)
    rngs = random.split(rng, len(attn_funs))
    for i, layer_fun in enumerate(attn_funs):
        x = layer_fun(params[i], x, adj, rng=rngs[i], is_training=is_training)
    return nn.log_softmax(x)
def mixture_logpdf_single(datum, components):
    component_scores = []
    unnormalized_weights = np.array([component[2] for component in components])
    weights = nn.log_softmax(unnormalized_weights)
    for component, weight in zip(components, weights):
        loc = component[0]
        scale = np.max([component[1], 0.01])  # Find a better solution?
        component_scores.append(logistic_logpdf(datum, loc, scale) + weight)
    return scipy.special.logsumexp(np.array(component_scores))
def predict(params, x):
    # per-example predictions
    activations = x
    for w, b in params[:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = nn.softmax(outputs)

    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return nn.log_softmax(logits)
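# Sketch (hypothetical layer sizes and random parameters) of how this
# per-example `predict` is usually batched with jax.vmap.
import jax
import jax.numpy as jnp

layer_sizes = [784, 128, 10]
keys = jax.random.split(jax.random.PRNGKey(0), len(layer_sizes) - 1)
params = [
    (0.01 * jax.random.normal(k, (out_dim, in_dim)),   # weights
     0.01 * jax.random.normal(k, (out_dim,)))          # biases
    for k, in_dim, out_dim in zip(keys, layer_sizes[:-1], layer_sizes[1:])
]

batched_predict = jax.vmap(predict, in_axes=(None, 0))
images = jnp.ones((32, 784))
log_probs = batched_predict(params, images)
print(log_probs.shape)  # (32, 10)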
def apply_fun(params, x, adj, is_training=False, **kwargs):
    rng = kwargs.pop('rng', None)
    k1, k2, k3, k4 = random.split(rng, 4)
    x = drop_fun(None, x, is_training=is_training, rng=k1)
    x = gc1_fun(params[0], x, adj, rng=k2)
    x = nn.relu(x)
    x = drop_fun(None, x, is_training=is_training, rng=k3)
    x = gc2_fun(params[1], x, adj, rng=k4)
    x = nn.log_softmax(x)
    return x
def apply_fun(params, first_x, adj, is_training=False, **kwargs):
    rng = kwargs.pop('rng', None)
    k1, k2, k3, k4 = random.split(rng, 4)
    adj_1, adj_5 = adj
    x = gc1_fun(params[0], (first_x, first_x), adj_1, rng=k2, is_training=is_training)
    x = gc2_fun(params[1], (first_x, x), adj_5, activation=lambda x: x, rng=k4, is_training=is_training)
    x = nn.log_softmax(x)
    return x
def kl_sample_gmm(key, q_mean_u, q_logvar_u, gmm_resps_c, gmm_p_mean_cxu,
                  gmm_p_logvar_cxu, varmin):
    r"""Sample the KL divergence between a gaussian and mixture of gaussians.

    KL(q||p) = E_q[log q(z) - log p(z)]

    In this case q is gaussian and p is a GMM, which requires sampling. So we
    sample z ~ q(z) and then compute log q(z) - log p(z).
    Watch the numerics here, varmin needs to be tiny.

    Arguments:
      key: random.PRNGKey for random bits
      q_mean_u: mean from posterior gaussian distribution with dim u
      q_logvar_u: log variance from posterior gaussian distribution with dim u
      gmm_resps_c: np.array with shape c, responsibilities in the GMM;
        \pi in the above formula is softmax(gmm_resps_c).
      gmm_p_mean_cxu: np.array, 2D array with shape mixture by dist dim, means in GMM
      gmm_p_logvar_cxu: np.array, 2D array with shape mixture by dist dim, log variances in GMM
      varmin: Minimum variance allowed (numerically useful).

    Returns:
      Single estimate of the KL divergence.
    """
    # Handle case of one gaussian in the mixture with closed form equations.
    if gmm_resps_c.shape[0] == 1:
        return np.sum(
            kl_gauss_gauss(q_mean_u, q_logvar_u, gmm_p_mean_cxu[0, :],
                           gmm_p_logvar_cxu[0, :], varmin))

    # Otherwise sample the KL.
    ll = diag_gaussian_log_likelihood
    gmm_ll = gmm_diag_gaussian_log_likelihood
    sample = diag_gaussian_sample

    keys = random.split(key, 2)
    z_u = sample(keys[0], q_mean_u, q_logvar_u, varmin)
    logq_u = ll(z_u, q_mean_u, q_logvar_u, varmin)  # over multigauss dim

    assert varmin <= 1e-15, "Very small or you need to know what you are doing."
    llp_each_gaussian_cxu = gmm_ll(z_u, gmm_p_mean_cxu, gmm_p_logvar_cxu, varmin)
    log_normed_resps_cx1 = np.expand_dims(log_softmax(gmm_resps_c), axis=1)
    logp_u = logsumexp(llp_each_gaussian_cxu + log_normed_resps_cx1, axis=0)

    kl_estimate = np.sum(logq_u - logp_u, axis=0)
    return kl_estimate
def discretized_mix_logistic_loss(theta, y, num_class=256, log_scale_min=-7.):
    """
    Discretized mixture of logistic distributions loss
    :param theta: B x T x 3 * nr_mix
    :param y: B x T x 1
    """
    theta_shape = theta.shape
    nr_mix = theta_shape[2] // 3

    # unpack parameters
    means = theta[:, :, :nr_mix]
    log_scales = np.maximum(theta[:, :, nr_mix:2 * nr_mix], log_scale_min)
    logit_probs = theta[:, :, nr_mix * 2:nr_mix * 3]

    # B x T x 1 => B x T x nr_mix
    y = np.broadcast_to(y, y.shape[:-1] + (nr_mix,))

    centered_y = y - means
    inv_stdv = np.exp(-log_scales)
    plus_in = inv_stdv * (centered_y + 1. / (num_class - 1))
    cdf_plus = sigmoid(plus_in)
    min_in = inv_stdv * (centered_y - 1. / (num_class - 1))
    cdf_min = sigmoid(min_in)

    # log probability for edge case of 0 (before scaling):
    log_cdf_plus = plus_in - softplus(plus_in)
    # log probability for edge case of 255 (before scaling):
    log_one_minus_cdf_min = -softplus(min_in)

    cdf_delta = cdf_plus - cdf_min  # probability for all other cases
    mid_in = inv_stdv * centered_y
    log_pdf_mid = mid_in - log_scales - 2. * softplus(mid_in)

    log_probs = np.where(
        y < -0.999, log_cdf_plus,
        np.where(
            y > 0.999, log_one_minus_cdf_min,
            np.where(cdf_delta > 1e-5,
                     np.log(np.maximum(cdf_delta, 1e-12)),
                     log_pdf_mid - np.log((num_class - 1) / 2))))

    log_probs = log_probs + log_softmax(logit_probs)
    return -np.sum(logsumexp(log_probs, axis=-1), axis=-1)
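# Usage sketch (hypothetical shapes), assuming `np` is jax.numpy and that
# sigmoid/softplus/log_softmax/logsumexp come from jax.nn and
# jax.scipy.special as in the snippet above.
import jax
import jax.numpy as np
from jax.nn import sigmoid, softplus, log_softmax
from jax.scipy.special import logsumexp

B, T, nr_mix = 2, 16, 10
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
theta = jax.random.normal(k1, (B, T, 3 * nr_mix))              # mixture parameters
y = jax.random.uniform(k2, (B, T, 1), minval=-1., maxval=1.)   # targets scaled to [-1, 1]

nll = discretized_mix_logistic_loss(theta, y)
print(nll.shape)  # (B,) negative log-likelihood per sequence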
def apply_fun_scan(params, tup, inp):
    """ Perform single step update of the network """
    _, (update_W, update_U, update_b), (reset_W, reset_U, reset_b), \
        (out_W, out_U, out_b), (sm_W, sm_b) = params
    hidden = tup[0]
    logP = tup[1]
    update_gate = sigmoid(np.dot(inp, update_W) + np.dot(hidden, update_U) + update_b)
    reset_gate = sigmoid(np.dot(inp, reset_W) + np.dot(hidden, reset_U) + reset_b)
    output_gate = np.tanh(np.dot(inp, out_W)
                          + np.dot(np.multiply(reset_gate, hidden), out_U)
                          + out_b)
    output = np.multiply(update_gate, hidden) + np.multiply(1 - update_gate, output_gate)
    hidden = output
    logP = log_softmax(np.dot(hidden, sm_W) + sm_b)
    return (hidden, logP), (hidden, logP)
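# Sketch (hypothetical sizes and random parameters) of driving the step
# function above with jax.lax.scan over an input sequence; `sigmoid`, `np` and
# `log_softmax` are assumed to be jax.nn.sigmoid, jax.numpy and
# jax.nn.log_softmax as in the snippet.
import functools
import jax
import jax.numpy as np
from jax import lax

in_dim, hid_dim, vocab = 8, 16, 10
keys = jax.random.split(jax.random.PRNGKey(0), 8)

def gate_params(k_in, k_hid):
    return (jax.random.normal(k_in, (in_dim, hid_dim)),
            jax.random.normal(k_hid, (hid_dim, hid_dim)),
            np.zeros(hid_dim))

params = (None,                                  # first entry is unused ("_")
          gate_params(keys[0], keys[1]),         # update gate
          gate_params(keys[2], keys[3]),         # reset gate
          gate_params(keys[4], keys[5]),         # candidate gate
          (jax.random.normal(keys[6], (hid_dim, vocab)), np.zeros(vocab)))

inputs = jax.random.normal(keys[7], (5, in_dim))      # (T, in_dim)
init_carry = (np.zeros(hid_dim), np.zeros(vocab))     # (hidden, logP)
_, (hiddens, logPs) = lax.scan(functools.partial(apply_fun_scan, params),
                               init_carry, inputs)
print(hiddens.shape, logPs.shape)  # (5, 16) (5, 10)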
def loss(observations, actions, rewards_to_go):
    logprobs = log_softmax(policy(observations))
    # pick out the log-probability of the action actually taken in each state
    action_logprobs = logprobs[np.arange(logprobs.shape[0]), actions]
    return -np.mean(action_logprobs * rewards_to_go, axis=0)
def cross_entropy(model: Model, inputs: Tensor, targets: Tensor) -> float:
    # log softmax
    predicted = nn.log_softmax(model(inputs))
    # negative log likelihood
    return -np.mean(np.sum(targets * predicted, axis=1))
def onnx_log_softmax(x, axis=-1):
    return log_softmax(x, axis)
def cross_entropy(outputs: DeviceArray, targets: DeviceArray) -> DeviceArray:
    probs = log_softmax(outputs)
    labels = _one_hot(targets, 10)
    loss = -jnp.mean(probs * labels)
    return loss
def from_params(cls, fixed_params, opt_params, traceable=False):
    logps = nn.log_softmax(opt_params)
    return cls(logps, traceable=traceable)
def criterion(logits, targets):
    return -jnp.mean(jnp.sum(log_softmax(logits) * targets, axis=1), axis=0)
def accuracy(params, batch, predict_fn):
    logits = predict_fn(params, batch["X"])
    logits = log_softmax(logits)
    return jnp.mean(jnp.argmax(logits, -1) == batch["y"])
def softmax_cross_entropy(logits, labels):
    one_hot_labels = one_hot(labels, logits.shape[-1])
    return -jnp.sum(log_softmax(logits) * one_hot_labels, axis=-1)
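# Quick sanity-check sketch (assumes one_hot/log_softmax come from jax.nn as
# in the snippet, and that optax is available for comparison).
import jax
import jax.numpy as jnp
from jax.nn import one_hot, log_softmax
import optax

key = jax.random.PRNGKey(0)
logits = jax.random.normal(key, (4, 10))
labels = jnp.array([1, 0, 7, 3])

ours = softmax_cross_entropy(logits, labels)
reference = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
print(jnp.allclose(ours, reference, atol=1e-6))  # True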
def hmm_viterbi_log(params, obs_seq, length=None):
    '''
    Computes the most likely sequence of hidden states (the Viterbi path) of
    the Hidden Markov Model given the observations, i.e.

    argmax_{z[0], ..., z[num_steps - 1]} P(z[0], ..., z[num_steps - 1] | x[0], ..., x[num_steps - 1])

    It is based on
    https://github.com/deepmind/distrax/blob/master/distrax/_src/utils/hmm.py

    Parameters
    ----------
    params : HMM
        Hidden Markov Model
    obs_seq : array(seq_len)
        History of observed states
    length : int, optional
        Number of valid (non-padded) steps in obs_seq; defaults to seq_len

    Returns
    -------
    * array(seq_len)
        Most likely hidden state at each time step
    '''
    seq_len = len(obs_seq)

    if length is None:
        length = seq_len

    trans_dist, obs_dist, init_dist = params.trans_dist, params.obs_dist, params.init_dist

    trans_log_probs = log_softmax(trans_dist.logits)
    init_log_probs = log_softmax(init_dist.logits)

    n_states = obs_dist.batch_shape[0]

    first_log_prob = init_log_probs + obs_dist.log_prob(obs_seq[0])

    if seq_len == 1:
        return jnp.expand_dims(jnp.argmax(first_log_prob), axis=0)

    def viterbi_forward(prev_logp, t):
        obs_logp = obs_dist.log_prob(obs_seq[t])
        logp = jnp.where(
            t <= length,
            prev_logp[..., None] + trans_log_probs + obs_logp[..., None, :],
            -jnp.inf + jnp.zeros_like(trans_log_probs))
        max_logp_given_successor = jnp.where(t <= length, jnp.max(logp, axis=-2), prev_logp)
        most_likely_given_successor = jnp.where(t <= length, jnp.argmax(logp, axis=-2), -1)
        return max_logp_given_successor, most_likely_given_successor

    ts = jnp.arange(1, seq_len)
    final_log_prob, most_likely_sources = lax.scan(viterbi_forward, first_log_prob, ts)

    most_likely_initial_given_successor = jnp.argmax(trans_log_probs + first_log_prob, axis=-2)

    most_likely_sources = jnp.concatenate([
        jnp.expand_dims(most_likely_initial_given_successor, axis=0),
        most_likely_sources
    ], axis=0)

    def viterbi_backward(state, t):
        state = jnp.where(
            t <= length,
            jnp.sum(most_likely_sources[t] * one_hot(state, n_states)).astype(jnp.int64),
            state)
        most_likely = jnp.where(t <= length, state, -1)
        return state, most_likely

    final_state = jnp.argmax(final_log_prob)
    _, most_likely_path = lax.scan(viterbi_backward, final_state, ts, reverse=True)

    final_state = jnp.where(length == seq_len, final_state, -1)

    return jnp.append(most_likely_path, final_state)
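# Usage sketch under stated assumptions: `params` is a simple container with
# distrax.Categorical fields (the HMM namedtuple below is hypothetical, not
# part of the snippet), and jnp/lax/one_hot/log_softmax are imported as in the
# snippet above.
from collections import namedtuple
import jax.numpy as jnp
import distrax

HMM = namedtuple("HMM", ["trans_dist", "obs_dist", "init_dist"])

# two hidden states, three possible observation symbols
params = HMM(
    trans_dist=distrax.Categorical(logits=jnp.log(jnp.array([[0.9, 0.1],
                                                             [0.2, 0.8]]))),
    obs_dist=distrax.Categorical(logits=jnp.log(jnp.array([[0.7, 0.2, 0.1],
                                                           [0.1, 0.3, 0.6]]))),
    init_dist=distrax.Categorical(logits=jnp.log(jnp.array([0.5, 0.5]))),
)

obs_seq = jnp.array([0, 0, 2, 1, 2])
path = hmm_viterbi_log(params, obs_seq)
print(path)  # most likely hidden state at each time step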
def from_params(cls, params, traceable=False):
    logps = nn.log_softmax(params)
    return cls(logps, traceable=traceable)
def inner(num_features, qk_dim, S, T, temp_sqrt):
    qk_key = jax.random.PRNGKey(111)
    key1, key2 = jax.random.split(qk_key)
    q = jax.random.normal(key1, (S, qk_dim)) / temp_sqrt
    #q = jax.random.normal(key1, (T, qk_dim)) / temp_sqrt
    k = jax.random.normal(key2, (T, qk_dim)) / temp_sqrt
    log_probs = log_softmax(q @ k.T)
    print(f"Total entropy of true dist: {-(jnp.exp(log_probs) * log_probs).sum()}")
    print(f"Mean entropy of true dist: {-(jnp.exp(log_probs) * log_probs).sum(-1).mean()}")
    st.write(f"Total entropy of true dist: {-(jnp.exp(log_probs) * log_probs).sum()}")
    st.write(f"Mean entropy of true dist: {-(jnp.exp(log_probs) * log_probs).sum(-1).mean()}")

    #"""
    # fit exp kernel
    def loss(q, k, dummy_proj, attn):
        logits = q @ k.T
        probs = softmax(logits)
        return fat.kl(attn, probs).mean()
    L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1)))
    key, key1 = jax.random.split(key_train_init)
    print(f"Softmax fit")
    title = f"KL softmax fit (S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt})"
    kl_ = report_train(
        q, k, lambda x: None, L_dL, num_features, key1,
        train_fn=fat.train, sample=False, title=title,
    )
    print(f"kl {kl_}")
    #"""

    # """
    # fit exp kernel fwd renorm
    def loss(q, k, scale, dummy_proj, attn):
        logits = renorm(q) @ renorm(k).T
        probs = softmax(logits)
        return fat.kl(attn, probs).mean()
    L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3)))
    key, key1 = jax.random.split(key_train_init)
    print(f"Softmax fwd renorm fit")
    title = f"KL softmax fwd renorm fit (S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt})"
    kl_ = report_train(
        q, k, lambda x: None, L_dL, num_features, key1,
        train_fn=fat.train_proj, sample=False, title=title, post_renorm=False,
    )
    print(f"kl {kl_}")
    """
    """
    # fit exp kernel post renorm
    def loss(q, k, scale, dummy_proj, attn):
        logits = q @ k.T
        probs = softmax(logits)
        return fat.kl(attn, probs).mean()
    L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3)))
    key, key1 = jax.random.split(key_train_init)
    print(f"Softmax post renorm fit")
    title = f"KL softmax post renorm fit (S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt})"
    kl_ = report_train(
        q, k, lambda x: None, L_dL, num_features, key1,
        train_fn=fat.train_proj, sample=False, title=title, post_renorm=True,
    )
    print(f"kl {kl_}")
    """

    # """
    def loss(q, k, scale, dummy_proj, attn):
        logits = renorm(q) @ renorm(k).T
        probs = softmax(3. * logits)
        return fat.kl(attn, probs).mean()
    L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3)))
    key, key1 = jax.random.split(key_train_init)
    print(f"Softmax temp fwd renorm fit")
    title = f"KL softmax temp fwd renorm fit (S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt})"
    kl_ = report_train(
        q, k, lambda x: None, L_dL, num_features, key1,
        train_fn=fat.train_proj, sample=False, title=title, post_renorm=False,
    )
    print(f"kl {kl_}")
    """
    """
    def loss(q, k, scale, dummy_proj, attn):
        logits = q @ k.T
        probs = softmax(3. * logits)
        return fat.kl(attn, probs).mean()
    L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3)))
    key, key1 = jax.random.split(key_train_init)
    print(f"Softmax temp post renorm fit")
    title = f"KL softmax temp post renorm fit (S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt})"
    kl_ = report_train(
        q, k, lambda x: None, L_dL, num_features, key1,
        train_fn=fat.train_proj, sample=False, title=title, post_renorm=True,
    )
    print(f"kl {kl_}")
    """

    def proj_fn(shape, key):
        sample_key, norm_key = jax.random.split(key)
        gaussian_sample = fat.random_projection(num_features, qk_dim, sample_key)
        print(gaussian_sample.shape)
        projection_matrix = fat.get_2d_array(gaussian_sample, norm_key, scaling=0)
        return projection_matrix
    proj_fn = functools.partial(proj_fn, (num_features, qk_dim))

    def proj_fn_gaus(shape, key):
        sample_key, norm_key = jax.random.split(key)
        gaussian_sample = fat.random_projection(num_features, qk_dim, sample_key)
        return gaussian_sample
    proj_fn_gaus = functools.partial(proj_fn_gaus, (num_features, qk_dim))

    def proj_fn_anti(shape, key):
        sample_key, norm_key = jax.random.split(key)
        gaussian_sample = fat.random_projection(num_features // 2, qk_dim, sample_key)
        projection_matrix = fat.get_2d_array(gaussian_sample, norm_key, scaling=0)
        return jnp.concatenate([projection_matrix, -projection_matrix], axis=0)
    proj_fn_anti = functools.partial(proj_fn_anti, (num_features, qk_dim))

    def proj_fn_reg(shape, key):
        sample_key, norm_key = jax.random.split(key)
        gaussian_sample = fat.random_projection(num_features, qk_dim, sample_key)
        projection_matrix = fat.get_2d_array(gaussian_sample, norm_key, scaling=1)
        return projection_matrix
    proj_fn_reg = functools.partial(proj_fn_reg, (num_features, qk_dim))

    def proj_fn_reg_small(shape, key):
        sample_key, norm_key = jax.random.split(key)
        gaussian_sample = fat.random_projection(num_features, qk_dim, sample_key)
        projection_matrix = fat.get_2d_array(gaussian_sample, norm_key, scaling=2)
        return projection_matrix
    proj_fn_reg_small = functools.partial(proj_fn_reg_small, (num_features, qk_dim))

    #for sample_key in [True, False]:
    for sample_key in [False]:
        #for proj_fn in [proj_fn, proj_fn_reg]:
        #for this_proj_fn in [proj_fn, proj_fn_anti]:
        for this_proj_fn in [proj_fn]:
            print(this_proj_fn)
            """
            def loss(q, k, scale, proj, attn_dist):
                ra, _ = fat.rff_attn(q, k, jax.lax.stop_gradient(proj))
                return fat.kl(attn_dist, ra).mean()
            L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3)))
            key, key1 = jax.random.split(key_train_init)
            print(f"Orth fit (Sample: {sample_key})")
            title = f"KL Orth fit (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            kl_ = report_train(q, k, this_proj_fn, L_dL, num_features, key1,
                               fat.train_proj, sample_key, title,
                               )
            print(f"kl {kl_}")
            """
            #"""
            def loss(q, k, scale, proj, attn_dist):
                qp = renorm(q, axis=-1)
                kp = renorm(k, axis=-1)
                ra, _ = fat.rff_attn(qp, kp, jax.lax.stop_gradient(proj))
                return fat.kl(attn_dist, ra).mean()
            L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3)))
            key, key1 = jax.random.split(key_train_init)
            print(f"Orth projected L2 fit (Sample: {sample_key})")
            title = f"KL Orth projected L2 fit (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            kl_ = report_train(
                q, k, proj_fn, L_dL, num_features, key1,
                fat.train_proj, sample_key, title,
            )
            print(f"kl {kl_}")
            #"""
            """
            def loss(q, k, scale, proj, attn_dist):
                qp = renorm_stopgrad(q, axis=-1)
                kp = renorm_stopgrad(k, axis=-1)
                ra, _ = fat.rff_attn(qp, kp, jax.lax.stop_gradient(proj))
                return fat.kl(attn_dist, ra).mean()
            L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3)))
            key, key1 = jax.random.split(key_train_init)
            print(f"Orth projected L2 detach fit (Sample: {sample_key})")
            title = f"KL Orth projected L2 detach fit (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            kl_ = report_train(q, k, proj_fn, L_dL, num_features, key1,
                               fat.train_proj, sample_key, title,
                               )
            print(f"kl {kl_}")
            """
            """
            def loss(q, k, scale, proj, attn_dist):
                ra, _ = fat.rff_attn(q, k, jax.lax.stop_gradient(proj))
                return fat.kl(attn_dist, ra).mean()
            L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3)))
            key, key1 = jax.random.split(key_train_init)
            print(f"Orth post projected L2 fit (Sample: {sample_key})")
            title = f"KL orth post projected L2 fit (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            kl_ = report_train(q, k, proj_fn, L_dL, num_features, key1,
                               fat.train_proj, sample_key, title,
                               post_renorm=True,
                               )
            print(f"kl {kl_}")
            """
            #"""
            # learn scale
            def loss(q, k, scale, proj, attn_dist):
                qp = renorm(q, axis=-1)
                kp = renorm(k, axis=-1)
                ra, _ = fat.rff_attn(qp, kp, proj)
                return fat.kl(attn_dist, ra).mean()
            L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3)))
            key, key1 = jax.random.split(key_train_init)
            print(f"Orth projected L2 fit proj (Sample: {sample_key})")
            title = f"KL Orth Projected L2 fit proj (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            kl_ = report_train(
                q, k, proj_fn, L_dL, num_features, key1,
                fat.train_proj, sample_key, title,
            )
            print(f"kl {kl_}")
            #"""
            """
            def loss(q, k, scale, proj, attn_dist):
                qp = renorm(q, axis=-1)
                kp = renorm(k, axis=-1)
                ra, _ = fat.rff_attn(qp, kp, jax.lax.stop_gradient(proj))
                return fat.kl(attn_dist, ra).mean()
            L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3)))
            key, key1 = jax.random.split(key_train_init)
            print(f"Orth Projected L2 fit scale projfnreg (Sample: {sample_key})")
            title = f"KL Orth Projected L2 fit scale projfnreg (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            kl_ = report_train(
                q, k,
                #proj_fn,
                proj_fn_reg,
                L_dL, num_features, key1,
                fat.train_proj, sample_key, title,
            )
            print(f"kl {kl_}")
            """
            """
            def loss(q, k, scale, proj, attn_dist):
                qp = renorm(q, axis=-1)
                kp = renorm(k, axis=-1)
                ra, _ = fat.rff_attn(qp, kp, jax.lax.stop_gradient(scale) * proj)
                return fat.kl(attn_dist, ra).mean()
            L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3)))
            key, key1 = jax.random.split(key_train_init)
            print(f"Orth Projected L2 fit proj (Sample: {sample_key})")
            title = f"KL Orth Projected L2 fit proj (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            kl_ = report_train(
                q, k,
                proj_fn,
                #proj_fn_norm,
                L_dL, num_features, key1,
                fat.train_proj, sample_key, title,
            )
            print(f"kl {kl_}")
            """
            """
            def loss(q, k, scale, proj, attn_dist):
                qp = q
                kp = k
                ra, _ = fat.relu_rff_attn0(qp, kp, jax.lax.stop_gradient(proj))
                return fat.kl(attn_dist, ra).mean()
            L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3)))
            #L_dL = jax.value_and_grad(loss, argnums=(0, 1, 2, 3)
            key, key1 = jax.random.split(key_train_init)
            #print(f"Relu0 Projected L2 fit (Sample: {sample_key})")
            #title = f"KL Relu0 Projected L2 fit (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            print(f"Relu0 fit (Sample: {sample_key})")
            title = f"KL Relu0 fit (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            kl_ = report_train(
                q, k,
                #proj_fn,
                #proj_fn_gaus,
                proj_fn_reg,
                L_dL, num_features, key1,
                fat.train_proj, sample_key, title,
            )
            print(f"kl {kl_}")
            """
            #"""
            def loss(q, k, scale, proj, attn_dist):
                qp = q
                kp = k
                ra, _ = fat.exp_rff_attn(qp, kp, jax.lax.stop_gradient(proj))
                return fat.kl(attn_dist, ra).mean()
            L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3)))
            #L_dL = jax.value_and_grad(loss, argnums=(0, 1, 2, 3)
            key, key1 = jax.random.split(key_train_init)
            print(f"Exp fit (Sample: {sample_key})")
            title = f"KL Exp fit (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            kl_ = report_train(
                q, k,
                #proj_fn,
                #proj_fn_gaus,
                proj_fn_reg,
                L_dL, num_features, key1,
                fat.train_proj, sample_key, title,
            )
            print(f"kl {kl_}")
            #"""
            #"""
            def loss(q, k, scale, proj, attn_dist):
                qp = q
                kp = k
                ra, _ = fat.exp_rff_attn(qp, kp, proj)
                return fat.kl(attn_dist, ra).mean()
            L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3)))
            #L_dL = jax.value_and_grad(loss, argnums=(0, 1, 2, 3)
            key, key1 = jax.random.split(key_train_init)
            print(f"Exp fit proj (Sample: {sample_key})")
            title = f"KL Exp fit proj (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            kl_ = report_train(
                q, k,
                #proj_fn,
                #proj_fn_gaus,
                proj_fn_reg,
                L_dL, num_features, key1,
                fat.train_proj, sample_key, title,
            )
            print(f"kl {kl_}")
            #"""
            #
            #"""
            def loss(q, k, scale, proj, attn_dist):
                qp = q
                kp = k
                ra, _ = fat.relu_rff_attn(qp, kp, jax.lax.stop_gradient(proj))
                return fat.kl(attn_dist, ra).mean()
            L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3)))
            #L_dL = jax.value_and_grad(loss, argnums=(0, 1, 2, 3)
            key, key1 = jax.random.split(key_train_init)
            #print(f"Relu Projected L2 fit (Sample: {sample_key})")
            #title = f"KL Relu Projected L2 fit (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            print(f"Relu fit (Sample: {sample_key})")
            title = f"KL Relu fit (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            kl_ = report_train(
                q, k,
                #proj_fn,
                #proj_fn_gaus,
                proj_fn_reg,
                L_dL, num_features, key1,
                fat.train_proj, sample_key, title,
            )
            print(f"kl {kl_}")
            #"""
            #
            #
            #"""
            def loss(q, k, scale, proj, attn_dist):
                qp = q
                kp = k
                ra, _ = fat.relu_rff_attn(qp, kp, proj)
                return fat.kl(attn_dist, ra).mean()
            L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3)))
            #L_dL = jax.value_and_grad(loss, argnums=(0, 1, 2, 3)
            key, key1 = jax.random.split(key_train_init)
            print(f"Relu fit proj (Sample: {sample_key})")
            title = f"KL Relu fit proj (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            kl_ = report_train(
                q, k,
                #proj_fn,
                #proj_fn_gaus,
                proj_fn_reg,
                L_dL, num_features, key1,
                fat.train_proj, sample_key, title,
            )
            print(f"kl {kl_}")
            #"""
            #
            #
            #"""
            def loss(q, k, scale, proj, attn_dist):
                qp = q
                kp = k
                ra, _ = fat.relu_rff_attn(qp, kp, jax.lax.stop_gradient(proj))
                return fat.kl(attn_dist, ra).mean()
            L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3)))
            #L_dL = jax.value_and_grad(loss, argnums=(0, 1, 2, 3)
            key, key1 = jax.random.split(key_train_init)
            #print(f"Relu Projected L2 fit (Sample: {sample_key})")
            #title = f"KL Relu Projected L2 fit (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            print(f"Relu fit small (Sample: {sample_key})")
            title = f"KL Relu fit small (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            kl_ = report_train(
                q, k,
                #proj_fn,
                #proj_fn_gaus,
                proj_fn_reg_small,
                L_dL, num_features, key1,
                fat.train_proj, sample_key, title,
            )
            print(f"kl {kl_}")
            #"""
            #
            #
            #"""
            def loss(q, k, scale, proj, attn_dist):
                qp = q
                kp = k
                ra, _ = fat.relu_rff_attn(qp, kp, proj)
                return fat.kl(attn_dist, ra).mean()
            L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3)))
            #L_dL = jax.value_and_grad(loss, argnums=(0, 1, 2, 3)
            key, key1 = jax.random.split(key_train_init)
            #print(f"Relu Projected L2 fit (Sample: {sample_key})")
            #title = f"KL Relu Projected L2 fit (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            print(f"Relu fit small proj (Sample: {sample_key})")
            title = f"KL Relu fit proj (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            kl_ = report_train(
                q, k,
                #proj_fn,
                #proj_fn_gaus,
                proj_fn_reg_small,
                L_dL, num_features, key1,
                fat.train_proj, sample_key, title,
            )
            print(f"kl {kl_}")
            #"""
            #
            #"""
            def loss(q, k, scale, proj, attn_dist):
                qp = q
                kp = k
                ra, _ = fat.relu_rff_attn(qp, kp, scale * proj)
                return fat.kl(attn_dist, ra).mean()
            L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3)))
            #L_dL = jax.value_and_grad(loss, argnums=(0, 1, 2, 3)
            key, key1 = jax.random.split(key_train_init)
            print(f"Relu fit small proj (Sample: {sample_key})")
            title = f"KL Relu fit small proj scale (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            kl_ = report_train(
                q, k,
                #proj_fn,
                #proj_fn_gaus,
                proj_fn_reg_small,
                L_dL, num_features, key1,
                fat.train_proj, sample_key, title,
            )
            print(f"kl {kl_}")
            #"""
            """
            # bad
            def loss(q, k, scale, proj, attn_dist):
                qp = renorm(q)
                kp = renorm(k)
                ra, _ = fat.relu_rff_attn(qp, kp, proj)
                return fat.kl(attn_dist, ra).mean()
            L_dL = jax.jit(jax.value_and_grad(loss, argnums=(0, 1, 2, 3)))
            #L_dL = jax.value_and_grad(loss, argnums=(0, 1, 2, 3)
            key, key1 = jax.random.split(key_train_init)
            #print(f"Relu Projected L2 fit (Sample: {sample_key})")
            #title = f"KL Relu Projected L2 fit (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            print(f"Relu projected fit proj (Sample: {sample_key})")
            title = f"KL Relu projected fit proj (Sample: {sample_key} S: {S} T: {T} dim: {qk_dim} temp: {temp_sqrt} numfeat: {num_features})"
            kl_ = report_train(
                q, k,
                #proj_fn,
                #proj_fn_gaus,
                proj_fn_reg,
                L_dL, num_features, key1,
                fat.train_proj, sample_key, title,
            )
            print(f"kl {kl_}")
            """
def logp(self, params, obs, act):
    logits = self._net_apply(params, obs)
    all_logps = nn.log_softmax(logits)
    return (hk.one_hot(act, self.act_dim) * all_logps).sum(-1)