def objective(params, batch) -> float:
    x, targets = batch
    logits = predict_fun(params, x)
    # Log-softmax: subtract the per-row log-partition so each row is a log-probability.
    logits = logits - logsumexp(logits, axis=1, keepdims=True)
    # Cross-entropy against one-hot (or soft) targets, plus a small L2 penalty.
    loss = -np.mean(np.sum(logits * targets, axis=1))
    loss += 5e-6 * l2_norm(params)  # optionally + l2_norm(bparam)
    return loss
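
# The objective above assumes `predict_fun`, `logsumexp`, `np`, and an
# `l2_norm` helper from the surrounding code. A minimal sketch of `l2_norm`
# over a pytree of parameters (an assumption, not the original implementation;
# one common reading regularizes with the summed squared entries):
import jax.numpy as jnp
from jax import tree_util

def l2_norm(params):
    # Sum of squared entries across every leaf array of the parameter pytree.
    return sum(jnp.vdot(leaf, leaf) for leaf in tree_util.tree_leaves(params))
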
import jax.numpy as jnp
from jax import nn
from jax.lax import ScatterDimensionNumbers, scatter_add

def matching_log_probas(embeddings, targets, test_embeddings, num_classes, eps=1e-8):
    num_samples = test_embeddings.shape[0]
    similarities = pairwise_cosine_similarity(embeddings, test_embeddings, eps=eps)
    # Log-normalizer over all support points (renamed so it does not shadow `logsumexp`).
    log_normalizer = nn.logsumexp(similarities, axis=0, keepdims=True)
    # Stable exponentiation: shift by the per-column maximum before taking exp.
    max_similarities = jnp.max(similarities, axis=0, keepdims=True)
    exp_similarities = jnp.exp(similarities - max_similarities)
    # Scatter-add the exponentiated similarities into one bucket per class label.
    sum_exp = jnp.zeros((num_classes, num_samples), dtype=exp_similarities.dtype)
    indices = jnp.expand_dims(targets, axis=-1)
    dimension_numbers = ScatterDimensionNumbers(
        update_window_dims=(1,),
        inserted_window_dims=(0,),
        scatter_dims_to_operand_dims=(0,))
    sum_exp = scatter_add(sum_exp, indices, exp_similarities, dimension_numbers)
    return jnp.log(sum_exp) + max_similarities - log_normalizer
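
# `pairwise_cosine_similarity` is assumed above; a minimal sketch that matches
# the call site's shape contract, (num_support, num_test), with `eps` guarding
# against zero-norm rows (a sketch, not the original implementation):
def pairwise_cosine_similarity(a, b, eps=1e-8):
    a_norm = a / jnp.maximum(jnp.linalg.norm(a, axis=1, keepdims=True), eps)
    b_norm = b / jnp.maximum(jnp.linalg.norm(b, axis=1, keepdims=True), eps)
    return jnp.matmul(a_norm, b_norm.T)
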
def scan_fn(beta_prev, t):
    # One backward step of the HMM smoother in log space: combine the next-step
    # message with the observation and transition log-probabilities, normalize,
    # and mask steps beyond the sequence length to -inf.
    beta_t = jnp.where(
        t > length,
        -jnp.inf + jnp.zeros_like(beta_prev),
        log_normalize(
            logsumexp(beta_prev + obs_dist.log_prob(obs_seq[-t + 1]) +
                      trans_dist.logits, axis=1))[0])
    return beta_t, beta_t
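
# A sketch of how such a step function is typically driven with `jax.lax.scan`
# (`seq_len` and `n_states` are assumptions here, and the exact time indexing
# follows the surrounding code):
from jax import lax
beta_init = jnp.zeros((n_states,))  # the log beta message at the final step is 0
_, betas = lax.scan(scan_fn, beta_init, jnp.arange(1, seq_len + 1))
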
def log_pred_density(model, samples, *args, **kwargs):
    # WAIC score of posterior samples.
    log_lk = log_likelihood(model, samples, *args, **kwargs)['y']
    ll = log_lk.sum(-1)
    S = ll.shape[0]
    # Log predictive density: average over the S posterior samples in log space.
    lppd = nn.logsumexp(ll, 0) - jnp.log(S)
    # Effective number of parameters: variance of the log-likelihood across samples.
    p_waic = jnp.var(ll, axis=0, ddof=1)
    return lppd - p_waic, log_lk
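
# A hedged usage sketch, reading the helper names as NumPyro's conventions
# (`numpyro.infer.log_likelihood`, observed site 'y'); `model`, `mcmc`, `X`,
# and `y` are hypothetical:
from jax import nn
import jax.numpy as jnp
from numpyro.infer import log_likelihood
elpd_waic, log_lk = log_pred_density(model, mcmc.get_samples(), X, y=y)
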
def scan_fn(carry, t):
    (alpha_prev, log_ll_prev) = carry
    # One forward step of the HMM filter in log space: propagate alpha through
    # the transitions, add the observation log-likelihood, and mask padded
    # steps beyond the sequence length to -inf.
    alpha_n = jnp.where(
        t < length,
        obs_dist.log_prob(obs_seq[t]) +
        logsumexp(logdotexp(alpha_prev[:, None], trans_dist.logits), axis=0),
        -jnp.inf + jnp.zeros_like(alpha_prev))
    # Normalize, and accumulate the running log-likelihood of the sequence.
    alpha_n, cn = log_normalize(alpha_n)
    carry = (alpha_n, cn + log_ll_prev)
    return carry, alpha_n
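
# `logdotexp` is assumed above. One reading consistent with the call site is an
# elementwise log-space product, with the surrounding logsumexp contracting over
# the previous-state axis (a sketch, not necessarily the original helper):
def logdotexp(u, v):
    # log(exp(u) * exp(v)) elementwise; logsumexp(u[:, None] + v, axis=0) then
    # equals the numerically stable log of exp(u) @ exp(v).
    return u + v
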
def log_normalize(u, axis=-1):
    '''
    Normalizes the values along the given axis so that the exponentials of the
    values along that axis sum to 1.

    Parameters
    ----------
    u : array
    axis : int

    Returns
    -------
    * array
        Log of the normalized version of the given matrix
    * array(seq_len, n_hidden)
        The values of the normalizer
    '''
    c = logsumexp(u, axis=axis)
    # Keep -inf entries at -inf (avoids NaN when the normalizer is also -inf).
    return jnp.where(u == -jnp.inf, -jnp.inf, u - c), c
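
# Quick check of `log_normalize` on a log-space vector (assumes
# `import jax.numpy as jnp` and `from jax.scipy.special import logsumexp`):
log_p, log_z = log_normalize(jnp.log(jnp.array([0.2, 0.3, 0.5])))
# jnp.exp(log_p) sums to 1, and log_z is the log of the original total.
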
def loss(W, b):
    logits = predict(W, b, inputs)
    # Log-softmax over classes, then cross-entropy against one-hot targets.
    preds = logits - logsumexp(logits, axis=1, keepdims=True)
    loss = -jnp.mean(jnp.sum(preds * targets, axis=1))
    # L2 regularization on both the weights and the biases.
    loss += 0.001 * (l2_norm(W) + l2_norm(b))
    return loss
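
# A sketch of one gradient-descent step on this loss (`predict`, `inputs`,
# `targets`, and `l2_norm` come from the surrounding code; the 0.01 learning
# rate is an assumption):
from jax import grad
dW, db = grad(loss, argnums=(0, 1))(W, b)
W, b = W - 0.01 * dW, b - 0.01 * db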