Example No. 1
 def inv(self, y):
     z = matrix_to_tril_vec(y, diagonal=-1)
     return jnp.concatenate(
         [z, jnp.log(jnp.diagonal(y, axis1=-2, axis2=-1))], axis=-1)
Example No. 2
def find_reasonable_step_size(potential_fn, kinetic_fn, momentum_generator,
                              inverse_mass_matrix, position, rng_key,
                              init_step_size):
    """
    Finds a reasonable step size by tuning `init_step_size`. This function is used
    to avoid working with a step size that is too large or too small in HMC.

    **References:**

    1. *The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo*,
       Matthew D. Hoffman, Andrew Gelman

    :param potential_fn: A callable to compute potential energy.
    :param kinetic_fn: A callable to compute kinetic energy.
    :param momentum_generator: A generator to get a random momentum variable.
    :param inverse_mass_matrix: Inverse of mass matrix.
    :param position: Current position of the particle.
    :param jax.random.PRNGKey rng_key: Random key to be used as the source of randomness.
    :param float init_step_size: Initial step size to be tuned.
    :return: a reasonable value for step size.
    :rtype: float
    """
    # We are going to find a step_size which makes accept_prob (Metropolis correction)
    # close to the target_accept_prob. If accept_prob := exp(-delta_energy) is small,
    # then we have to decrease step_size; otherwise, we increase step_size.
    target_accept_prob = np.log(0.8)

    _, vv_update = velocity_verlet(potential_fn, kinetic_fn)
    z = position
    potential_energy, z_grad = value_and_grad(potential_fn)(z)
    finfo = np.finfo(get_dtype(init_step_size))

    def _body_fn(state):
        step_size, _, direction, rng_key = state
        rng_key, rng_key_momentum = random.split(rng_key)
        # scale step_size: increase 2x or decrease 2x depending on direction;
        # direction=1 means keep increasing step_size, otherwise keep decreasing step_size.
        # Note that the direction is -1 if delta_energy is `NaN`, which may be the
        # case for a diverging trajectory (e.g. in the case of evaluating log prob
        # of a value simulated using a large step size for a constrained sample site).
        step_size = (2.0**direction) * step_size
        r = momentum_generator(inverse_mass_matrix, rng_key_momentum)
        _, r_new, potential_energy_new, _ = vv_update(
            step_size, inverse_mass_matrix, (z, r, potential_energy, z_grad))
        energy_current = kinetic_fn(inverse_mass_matrix, r) + potential_energy
        energy_new = kinetic_fn(inverse_mass_matrix,
                                r_new) + potential_energy_new
        delta_energy = energy_new - energy_current
        direction_new = np.where(target_accept_prob < -delta_energy, 1, -1)
        return step_size, direction, direction_new, rng_key

    def _cond_fn(state):
        step_size, last_direction, direction, _ = state
        # condition to run only if step_size is not too small or we are not decreasing step_size
        not_small_step_size_cond = (step_size > finfo.tiny) | (direction >= 0)
        # condition to run only if step_size is not too large or we are not increasing step_size
        not_large_step_size_cond = (step_size < finfo.max) | (direction <= 0)
        not_extreme_cond = not_small_step_size_cond & not_large_step_size_cond
        return not_extreme_cond & ((last_direction == 0) |
                                   (direction == last_direction))

    step_size, _, _, _ = while_loop(_cond_fn, _body_fn,
                                    (init_step_size, 0, 0, rng_key))
    return step_size
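
A minimal usage sketch, assuming the surrounding module-level helpers (`velocity_verlet`, `while_loop`, `value_and_grad`, `get_dtype`) are in scope; the toy potential, kinetic energy, and momentum generator below are illustrative and not part of the original code:

import jax
import jax.numpy as jnp

def potential_fn(z):
    return 0.5 * jnp.sum(z ** 2)          # standard-normal potential (illustrative)

def kinetic_fn(inverse_mass_matrix, r):
    return 0.5 * jnp.sum(inverse_mass_matrix * r ** 2)

def momentum_generator(inverse_mass_matrix, rng_key):
    return jax.random.normal(rng_key, jnp.shape(inverse_mass_matrix))

step_size = find_reasonable_step_size(
    potential_fn, kinetic_fn, momentum_generator,
    inverse_mass_matrix=jnp.ones(3), position=jnp.zeros(3),
    rng_key=jax.random.PRNGKey(0), init_step_size=1.0)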
Example No. 3
 def log_abs_det_jacobian(self, x, y, intermediates=None):
     return sum_rightmost(
         jnp.broadcast_to(jnp.log(jnp.abs(self.scale)), jnp.shape(x)),
         self.event_dim)
Example No. 4
 def f(x):
     checkify.check(x > 0, "must be positive!")
     return jnp.log(x)
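
For context, a short usage sketch with `jax.experimental.checkify` (the input values are illustrative):

from jax.experimental import checkify

checked_f = checkify.checkify(f)
err, out = checked_f(2.0)     # err.get() is None, out == jnp.log(2.0)
err, out = checked_f(-1.0)    # err.get() reports the "must be positive!" message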
Example No. 5
def softplus_inverse(x):
    return jnp.log(jnp.exp(x) - 1.)
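
Note that `jnp.exp(x)` overflows for large `x`; an algebraically equivalent, more numerically stable variant (a sketch, not part of the original) is:

import jax.numpy as jnp

def softplus_inverse_stable(x):
    # log(exp(x) - 1) == x + log(1 - exp(-x)); the right-hand side avoids overflow for large x
    return x + jnp.log(-jnp.expm1(-x))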
Example No. 6
 def log_prob(self, value):
     if self._validate_args:
         self._validate_sample(value)
     return (np.log(self.rate) * value) - gammaln(value + 1) - self.rate
Example No. 7
def categorical_kl(p_probs, q_log_probs):
    p_log_probs = jnp.log(p_probs)
    # where p == 0 the summand is 0; the jnp.where guard avoids 0 * (-inf) = nan
    outs = jnp.where(p_probs > 0, p_probs * (p_log_probs - q_log_probs),
                     p_probs)
    return jnp.sum(outs)
Example No. 8
def div_left(left_val, out_val, ildj_):
    return left_val / out_val, ((np.log(left_val) - 2 * np.log(out_val)) +
                                ildj_)
Example No. 9
def div_right(right_val, out_val, ildj_):
    return out_val * right_val, np.log(np.abs(right_val)) + ildj_
Example No. 10
def mul_left(left_val, out_val, ildj_):
    return out_val / left_val, -np.log(np.abs(left_val)) + ildj_
Example No. 11
def mul_right(right_val, out_val, ildj_):
    return out_val / right_val, -np.log(np.abs(right_val)) + ildj_
Example No. 12
def logit(x: Array) -> Array:
    """Logit transform, inverse of sigmoid."""
    chex.assert_type(x, float)
    return -jnp.log(1. / x - 1.)
Example No. 13
 def log_abs_det_jacobian(self, x, y, intermediates=None):
     return jnp.log(jnp.abs(self.exponent * y / x))
Example No. 14
 def inv(self, y):
     x = jnp.log(y[..., 1:] - y[..., :-1])
     return jnp.concatenate([y[..., :1], x], axis=-1)
Example No. 15
 def log_prob(self, x):
     if self._validate_args:
         self._validate_sample(x)
     log_prob = np.log(x == self.value)
     log_prob = sum_rightmost(log_prob, len(self.event_shape))
     return log_prob + self.log_density
Example No. 16
def pow_left(x, z, ildj_):
    # x ** y = z
    # y = f^-1(z) = log(z) / log(x)
    # grad(f^-1)(z) = 1 / (z log(x))
    # log(grad(f^-1)(z)) = log(1 / (z log(x))) = -log(z) - log(log(x))
    return np.log(z) / np.log(x), ildj_ - np.log(z) - np.log(np.log(x))
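
A quick numerical check of the derivation above (a sketch, assuming x = 2 and using `jax.grad`):

import jax
import jax.numpy as jnp

inv = lambda z: jnp.log(z) / jnp.log(2.0)        # f^-1 for x = 2
lhs = jnp.log(jnp.abs(jax.grad(inv)(8.0)))       # log |d/dz f^-1(z)| via autodiff
rhs = -jnp.log(8.0) - jnp.log(jnp.log(2.0))      # the closed form from the comment above
assert jnp.allclose(lhs, rhs)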
Example No. 17
def _to_logits_bernoulli(probs):
    ps_clamped = clamp_probs(probs)
    return np.log(ps_clamped) - np.log1p(-ps_clamped)
Example No. 18
## constants
##

# numbers
eps = 1e-7

# link functions
links = {
    'identity': lambda x: x,
    'exponential': lambda x: np.exp(x),
    'logit': lambda x: 1/(1+np.exp(-x))
}

# loss functions
losses = {
    'binary': lambda yh, y: y*np.log(yh) + (1-y)*np.log(1-yh),
    'poisson': lambda yh, y: y*np.log(yh) - yh,
    'negative_binomial': lambda r, yh, y: gammaln(r+y) - gammaln(r) + r*np.log(r) + y*np.log(yh) - (r+y)*np.log(r+yh),
    'least_squares': lambda yh, y: -(y-yh)**2
}
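
A minimal usage sketch of the link and loss tables above (illustrative data; assumes `np` and `gammaln` are already in scope, as in the snippet):

linpred = np.array([0.1, -0.3, 0.7])            # linear predictor
y = np.array([1.0, 0.0, 2.0])                   # observed counts
yh = links['exponential'](linpred)              # mean via the exponential link
loglik = np.sum(losses['poisson'](yh, y))       # Poisson log-likelihood (up to a constant)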

##
## batching it
##

class DataLoader:
    def __init__(self, y, x, batch_size):
        self.y = y
        self.x = x
        self.batch_size = batch_size
        self.data_size = len(y)
Example No. 19
def _to_logits_multinom(probs):
    minval = np.finfo(get_dtypes(probs)[0]).min
    return np.clip(np.log(probs), a_min=minval)
Example No. 20
 def loss(par, yh, y):
     # zero-inflated likelihood: a point mass at zero with probability pzero,
     # mixed with the base likelihood exp(loss0(yh, y))
     pzero = 1/(1+np.exp(-par[0]))
     like = pzero*(y==0) + (1-pzero)*np.exp(loss0(yh, y))
     return np.log(like)
Example No. 21
 def prior_kl_whiten(self):
     # KL divergence between the whitened variational distribution q(u) and a
     # standard-normal prior, written in terms of the scale (Cholesky) factor qu.scale
     qu = self.inducing_variable.variational_distribution
     log_det = 2 * jnp.sum(jnp.log(jnp.diag(qu.scale)))
     dim = qu.mean.shape[-1]
     return -.5 * (log_det + 0.5 * dim - jnp.sum(qu.mean**2) -
                   jnp.sum(qu.scale**2))
Example No. 22
    def __call__(self, inputs, is_training):
        """Connects the module to some inputs.

    Args:
      inputs: Tensor, final dimension must be equal to embedding_dim. All other
        leading dimensions will be flattened and treated as a large batch.
      is_training: boolean, whether this connection is to training data. When
        this is set to False, the internal moving average statistics will not be
        updated.

    Returns:
      dict containing the following keys and values:
        quantize: Tensor containing the quantized version of the input.
        loss: Tensor containing the loss to optimize.
        perplexity: Tensor containing the perplexity of the encodings.
        encodings: Tensor containing the discrete encodings, ie which element
        of the quantized space each input element was mapped to.
        encoding_indices: Tensor containing the discrete encoding indices, ie
        which element of the quantized space each input element was mapped to.
    """
        flat_inputs = jnp.reshape(inputs, [-1, self.embedding_dim])
        embeddings = self.embeddings

        distances = (jnp.sum(flat_inputs**2, 1, keepdims=True) -
                     2 * jnp.matmul(flat_inputs, embeddings) +
                     jnp.sum(embeddings**2, 0, keepdims=True))

        encoding_indices = jnp.argmax(-distances, 1)
        encodings = jax.nn.one_hot(encoding_indices,
                                   self.num_embeddings,
                                   dtype=distances.dtype)

        # NB: if your code crashes with a reshape error on the line below about a
        # Tensor containing the wrong number of values, then the most likely cause
        # is that the input passed in does not have a final dimension equal to
        # self.embedding_dim. Ideally we would catch this with an Assert but that
        # creates various other problems related to device placement / TPUs.
        encoding_indices = jnp.reshape(encoding_indices, inputs.shape[:-1])
        quantized = self.quantize(encoding_indices)
        e_latent_loss = jnp.mean(
            (jax.lax.stop_gradient(quantized) - inputs)**2)

        if is_training:
            updated_ema_cluster_size = self.ema_cluster_size(
                jnp.sum(encodings, axis=0))

            dw = jnp.matmul(flat_inputs.T, encodings)
            updated_ema_dw = self.ema_dw(dw)

            n = jnp.sum(updated_ema_cluster_size)
            updated_ema_cluster_size = (
                (updated_ema_cluster_size + self.epsilon) /
                (n + self.num_embeddings * self.epsilon) * n)

            normalised_updated_ema_w = (
                updated_ema_dw /
                jnp.reshape(updated_ema_cluster_size, [1, -1]))

            hk.set_state("embeddings", normalised_updated_ema_w)
            loss = self.commitment_cost * e_latent_loss

        else:
            loss = self.commitment_cost * e_latent_loss

        # Straight Through Estimator
        quantized = inputs + jax.lax.stop_gradient(quantized - inputs)
        avg_probs = jnp.mean(encodings, 0)
        perplexity = jnp.exp(-jnp.sum(avg_probs * jnp.log(avg_probs + 1e-10)))

        return {
            "quantize": quantized,
            "loss": loss,
            "perplexity": perplexity,
            "encodings": encodings,
            "encoding_indices": encoding_indices,
            "distances": distances,
        }
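
The distance computation above expands ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2. A standalone sketch of that nearest-neighbour lookup with hypothetical shapes (not tied to the module's state):

import jax
import jax.numpy as jnp

key1, key2 = jax.random.split(jax.random.PRNGKey(0))
flat_inputs = jax.random.normal(key1, (16, 4))   # (batch, embedding_dim)
embeddings = jax.random.normal(key2, (4, 8))     # (embedding_dim, num_embeddings)
distances = (jnp.sum(flat_inputs**2, 1, keepdims=True)
             - 2 * jnp.matmul(flat_inputs, embeddings)
             + jnp.sum(embeddings**2, 0, keepdims=True))
encoding_indices = jnp.argmax(-distances, 1)     # index of the nearest codebook vector
encodings = jax.nn.one_hot(encoding_indices, 8, dtype=distances.dtype)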
Example No. 23
 def log_normal(x, mean, cov):
     L = jnp.linalg.cholesky(cov)
     dx = x - mean
     dx = solve_triangular(L, dx, lower=True)
     return -0.5 * x.size * jnp.log(2. * jnp.pi) - jnp.sum(jnp.log(jnp.diag(L))) \
            - 0.5 * dx @ dx
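
A quick consistency check (a sketch; assumes `log_normal` is in scope and that `solve_triangular` above comes from `jax.scipy.linalg`):

import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

x = jnp.array([0.3, -1.2])
mean = jnp.zeros(2)
cov = jnp.array([[2.0, 0.5], [0.5, 1.0]])
# the Cholesky-based density should agree with the reference log-density
assert jnp.allclose(log_normal(x, mean, cov),
                    multivariate_normal.logpdf(x, mean, cov), atol=1e-5)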
Example No. 24
def log_average_softmax_probs(logits: jnp.ndarray) -> jnp.ndarray:
  # TODO(zmariet): dedicated eval loss function.
  ens_size, _, _ = logits.shape
  log_p = jax.nn.log_softmax(logits)  # (ensemble_size, batch_size, num_classes)
  # log of the ensemble-averaged probabilities: logsumexp over members minus log(ens_size)
  log_p = jax.nn.logsumexp(log_p, axis=0) - jnp.log(ens_size)
  return log_p
Example No. 25
def binary_cross_entropy_with_logits(logits, labels):
    # log(sigmoid(x)) via log_sigmoid; log(1 - sigmoid(x)) computed stably as log(-expm1(log_sigmoid(x)))
    logits = nn.log_sigmoid(logits)
    return -jnp.sum(labels * logits +
                    (1. - labels) * jnp.log(-jnp.expm1(logits)))
Example No. 26
# Bernoulli negative log-likelihood for logistic regression (sigmoid mean mu(w));
# its gradient with respect to w simplifies to (mu(w) - y) * x.
def nll(w): return -(y*np.log(mu(w)) + (1-y)*np.log(1-mu(w)))
#def deriv_nll(w): return -(y*(1-mu(w))*x - (1-y)*mu(w)*x)
def deriv_nll(w): return (mu(w)-y)*x
Example No. 27
def get_gibbs_loss(loss_train, tlbs, tcut):
    mu, sig = loss_train
    mu = mu.squeeze()
    loss = -0.5 * (np.linalg.norm(tlbs - mu, axis=0)**
                   2) - 0.5 * np.trace(sig) - tcut / 2.0 * np.log(2 * np.pi)
    return np.sum(loss)
Example No. 28
def loss(weights, inputs, targets):
    preds = predict(weights, inputs)
    logprobs = np.log(preds) * targets + np.log(1 - preds) * (1 - targets)
    return -np.sum(logprobs)
Example No. 29
 def _inverse(self, y):
     return jnp.log(y)
Example No. 30
 def log_abs_det_jacobian(self, x, y, intermediates=None):
     return jnp.broadcast_to(
         jnp.log(jnp.diagonal(self.scale_tril, axis1=-2, axis2=-1)).sum(-1),
         jnp.shape(x)[:-1])