Exemplo n.º 1
0
 def objective(params, batch) -> float:
     """Cross-entropy of ``predict_fun`` on ``batch`` plus a small L2 penalty.

     ``batch`` is unpacked as ``(x, targets)``. The logits are turned into
     log-probabilities with a row-wise log-softmax (axis 1). They are weighted
     by ``targets`` (presumably one-hot rows of the same shape as the logits;
     TODO confirm), summed per row and averaged over rows. Finally
     ``5e-6 * l2_norm(params)`` is added.
     """
     inputs, onehot = batch
     raw_scores = predict_fun(params, inputs)
     # log-softmax along the class axis
     log_probs = raw_scores - logsumexp(raw_scores, axis=1, keepdims=True)
     nll = -np.mean(np.sum(log_probs * onehot, axis=1))
     return nll + 5e-6 * (l2_norm(params))
Exemplo n.º 2
0
def matching_log_probas(embeddings,
                        targets,
                        test_embeddings,
                        num_classes,
                        eps=1e-8):
    """Per-class log-probabilities of test points from cosine similarities.

    For every test point j and class c this returns
    ``log(sum_{i: targets[i] == c} exp(s[i, j])) - logsumexp_i s[i, j]``,
    where ``s = pairwise_cosine_similarity(embeddings, test_embeddings)``.
    The scatter below lays ``s`` out as (num_embeddings, num_test).

    Returns an array of shape ``(num_classes, test_embeddings.shape[0])``.
    A class with no entry in ``targets`` gets ``log(0) = -inf``.
    """
    sims = pairwise_cosine_similarity(embeddings, test_embeddings, eps=eps)

    # Shift by the per-column maximum before exponentiating (numerical
    # stability); the shift is added back after the log.
    col_max = jnp.max(sims, axis=0, keepdims=True)
    shifted_exp = jnp.exp(sims - col_max)

    # per_class[c, j] += shifted_exp[i, j] for every i with targets[i] == c:
    # indices of shape (N, 1) select a row of the operand, and each update
    # row (window dim 1) is added into it.
    scatter_dims = ScatterDimensionNumbers(
        update_window_dims=(1, ),
        inserted_window_dims=(0, ),
        scatter_dims_to_operand_dims=(0, ))
    per_class = scatter_add(
        jnp.zeros((num_classes, test_embeddings.shape[0]),
                  dtype=shifted_exp.dtype),
        jnp.expand_dims(targets, axis=-1),
        shifted_exp,
        scatter_dims)

    log_normalizer = nn.logsumexp(sims, axis=0, keepdims=True)
    return jnp.log(per_class) + col_max - log_normalizer
Exemplo n.º 3
0
 def scan_fn(beta_prev, t):
     """One log-space backward-recursion step, returned as ``(carry, output)``.

     The new message is
     ``beta_t[i] = logsumexp_j(beta_prev[j] + log p(o | j) + logits[i, j])``,
     which is then log-normalized. When ``t > length`` the message is all -inf
     instead. ``length``, ``obs_dist``, ``obs_seq`` and ``trans_dist`` come
     from the enclosing scope.
     """
     # NOTE(review): for t >= 2, obs_seq[-t + 1] counts from the end of the
     # sequence (t=2 -> last element); confirm against the scan's `ts` range.
     obs_term = obs_dist.log_prob(obs_seq[-t + 1])
     unnormalized = logsumexp(beta_prev + obs_term + trans_dist.logits,
                              axis=1)
     normalized = log_normalize(unnormalized)[0]  # drop the normalizer
     padding = -jnp.inf + jnp.zeros_like(beta_prev)
     beta_t = jnp.where(t > length, padding, normalized)
     return beta_t, beta_t
Exemplo n.º 4
0
def log_pred_density(model, samples, *args, **kwargs):
    """WAIC-style penalized log predictive density of posterior samples.

    Returns ``(lppd - p_waic, log_lk)``, where ``log_lk`` is the ``'y'`` entry
    of ``log_likelihood(model, samples, ...)``. ``lppd`` is the log of the
    Monte Carlo mean (over axis 0, the posterior samples) of the likelihood,
    and ``p_waic`` is the unbiased (ddof=1) sample variance over the same axis.
    """
    log_lk = log_likelihood(model, samples, *args, **kwargs)['y']
    # NOTE(review): the last axis is summed *before* the logsumexp, so this is
    # not the pointwise (per-observation) WAIC unless that axis is something
    # other than observations -- confirm the intended shape of log_lk.
    total_ll = log_lk.sum(-1)
    n_draws = total_ll.shape[0]
    penalty = jnp.var(total_ll, axis=0, ddof=1)
    mean_density = nn.logsumexp(total_ll, 0) - jnp.log(n_draws)
    return mean_density - penalty, log_lk
Exemplo n.º 5
0
    def scan_fn(carry, t):
        """One log-space forward-recursion step, returned as ``(carry, output)``.

        ``carry`` is ``(alpha_prev, log_ll_prev)``: the previous log-normalized
        forward message and the log normalizers accumulated so far. For
        ``t < length`` the new message combines ``obs_dist.log_prob(obs_seq[t])``
        with the transition term; for ``t >= length`` it is all -inf. The
        normalized message is both the new carry message and the step output.
        ``length``, ``obs_dist``, ``obs_seq`` and ``trans_dist`` come from the
        enclosing scope; ``logdotexp``'s exact semantics are defined elsewhere.
        """
        (alpha_prev, log_ll_prev) = carry
        is_valid = t < length
        alpha_n = jnp.where(is_valid,
                            obs_dist.log_prob(obs_seq[t]) + logsumexp(
                                logdotexp(alpha_prev[:, None], trans_dist.logits), axis=0),
                            -jnp.inf + jnp.zeros_like(alpha_prev))

        alpha_n, cn = log_normalize(alpha_n)
        # On padded steps alpha_n is all -inf. Assuming log_normalize returns
        # the logsumexp of its input as the normalizer, cn is then -inf too, and
        # adding it would wipe out the accumulated log-likelihood. Accumulate
        # only on real (t < length) steps.
        log_ll = jnp.where(is_valid, cn + log_ll_prev, log_ll_prev)
        carry = (alpha_n, log_ll)

        return carry, alpha_n
Exemplo n.º 6
0
def log_normalize(u, axis=-1):
    '''
    Normalize ``u`` in log space along ``axis``, so that the exponentials of
    the values along that axis sum to 1.

    Parameters
    ----------
    u : array
        Log-domain values; any rank.
    axis : int
        Axis along which to normalize (default: last axis).

    Returns
    -------
    * array
        Same shape as ``u``: the log of the normalized values. Entries of
        ``u`` equal to -inf stay -inf.
    * array
        The log normalizer ``logsumexp(u, axis=axis)``; it has the shape of
        ``u`` with ``axis`` removed (a scalar for 1-D input).
    '''
    # keepdims=True so the normalizer broadcasts back along `axis` for any
    # rank and axis. Without it, a 2-D input normalized along the last axis
    # would have its per-row normalizer broadcast across columns instead.
    c = logsumexp(u, axis=axis, keepdims=True)
    # Keep -inf entries at -inf explicitly: for an all -inf slice the
    # normalizer is -inf too, and `u - c` would be NaN there.
    normalized = jnp.where(u == -jnp.inf, -jnp.inf, u - c)
    return normalized, jnp.squeeze(c, axis=axis)
Exemplo n.º 7
0
def loss(W, b):
    """Cross-entropy of ``predict(W, b, inputs)`` against ``targets`` plus L2.

    ``inputs``, ``targets``, ``predict`` and ``l2_norm`` come from the enclosing
    scope. The logits are log-softmaxed along axis 1, weighted by ``targets``
    (presumably one-hot rows; TODO confirm), summed per row and averaged.
    ``0.001 * (l2_norm(W) + l2_norm(b))`` is then added as a penalty.
    """
    scores = predict(W, b, inputs)
    log_probs = scores - logsumexp(scores, axis=1, keepdims=True)
    data_term = -jnp.mean(jnp.sum(log_probs * targets, axis=1))
    penalty = 0.001 * (l2_norm(W) + l2_norm(b))
    return data_term + penalty