Example #1
 def log_prob(self, value):
     log_prob_minus_log_gate = -self.gate_logits + self.base_dist.log_prob(
         value)
     log_gate = -softplus(-self.gate_logits)
     log_prob = log_prob_minus_log_gate + log_gate
     zero_log_prob = softplus(log_prob_minus_log_gate) + log_gate
     return jnp.where(value == 0, zero_log_prob, log_prob)
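The gate term above relies on the stable identity log(sigmoid(l)) == -softplus(-l). A minimal sanity check, with arbitrary example logits:

import jax.numpy as jnp
from jax.nn import softplus, sigmoid

gate_logits = jnp.array([-3.0, 0.0, 2.5])  # arbitrary example values
# -softplus(-l) is a numerically stable log(sigmoid(l)), i.e. the log gate probability
assert jnp.allclose(-softplus(-gate_logits), jnp.log(sigmoid(gate_logits)), atol=1e-6)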
Example #2
 def latent_diffusion_function(self, x, t, param_dict, gp_matrices):
     if self.config["constant_diffusion"]:
         concentration, rate = self.gamma_params.build(
             param_dict["gamma_params"])
         self.key, subkey = random.split(self.key)
         inverse_lambdas = numpyro.sample("inverse_lambdas",
                                          InverseGamma(
                                              nn.softplus(concentration),
                                              nn.softplus(rate)),
                                          rng_key=subkey)
         return np.tile(
             np.expand_dims(aux_math.diag(inverse_lambdas), axis=0),
             [x.shape[0], 1, 1])
     else:
         if self.config["time_dependent_gp"]:
             time = np.ones(shape=(x.shape[0], 1)) * t
             y = np.concatenate((x, time), axis=1)
             return aux_math.diag(
                 np.transpose(
                     self.sde_gp_diffusion(y, param_dict["sde_gp"],
                                           gp_matrices)))
         return aux_math.diag(
             np.transpose(
                 self.sde_gp_diffusion(x, param_dict["sde_gp"],
                                       gp_matrices)))
Example #3
    def sample(self, state, model_args, model_kwargs):
        i, x, x_pe, x_grad, _, mean_accept_prob, adapt_state, rng_key = state
        x_flat, unravel_fn = ravel_pytree(x)
        x_grad_flat, _ = ravel_pytree(x_grad)
        shape = jnp.shape(x_flat)
        rng_key, key_normal, key_bernoulli, key_accept = random.split(
            rng_key, 4)

        mass_sqrt_inv = adapt_state.mass_matrix_sqrt_inv

        x_grad_flat_scaled = (mass_sqrt_inv @ x_grad_flat if self._dense_mass
                              else mass_sqrt_inv * x_grad_flat)

        # Generate proposal y.
        z = adapt_state.step_size * random.normal(key_normal, shape)

        p = expit(-z * x_grad_flat_scaled)
        b = jnp.where(random.uniform(key_bernoulli, shape) < p, 1.0, -1.0)

        dx_flat = b * z
        dx_flat_scaled = (mass_sqrt_inv.T @ dx_flat
                          if self._dense_mass else mass_sqrt_inv * dx_flat)

        y_flat = x_flat + dx_flat_scaled

        y = unravel_fn(y_flat)
        y_pe, y_grad = jax.value_and_grad(self._potential_fn)(y)
        y_grad_flat, _ = ravel_pytree(y_grad)
        y_grad_flat_scaled = (mass_sqrt_inv @ y_grad_flat if self._dense_mass
                              else mass_sqrt_inv * y_grad_flat)

        log_accept_ratio = (x_pe - y_pe + jnp.sum(
            softplus(dx_flat * x_grad_flat_scaled) -
            softplus(-dx_flat * y_grad_flat_scaled)))
        accept_prob = jnp.clip(jnp.exp(log_accept_ratio), a_max=1.0)

        x, x_flat, pe, x_grad = jax.lax.cond(
            random.bernoulli(key_accept, accept_prob),
            (y, y_flat, y_pe, y_grad),
            identity,
            (x, x_flat, x_pe, x_grad),
            identity,
        )

        # do not update adapt_state after warmup phase
        adapt_state = jax.lax.cond(
            i < self._num_warmup,
            (i, accept_prob, (x, ), adapt_state),
            lambda args: self._wa_update(*args),
            adapt_state,
            identity,
        )

        itr = i + 1
        n = jnp.where(i < self._num_warmup, itr, itr - self._num_warmup)
        mean_accept_prob = mean_accept_prob + (accept_prob -
                                               mean_accept_prob) / n

        return BarkerMHState(itr, x, pe, x_grad, accept_prob, mean_accept_prob,
                             adapt_state, rng_key)
Example #4
 def apply_fun(params, inputs, **kwargs):
     # a * f(x) + (1 - a) * x
     (fx, logdet), (x, _) = inputs
     gate = sigmoid(params)
     out = gate * fx + (1 - gate) * x
     logdet = softplus(logdet + params) - softplus(params)
     return out, logdet
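The logdet update above is a numerically stable form of log(gate * exp(logdet) + (1 - gate)), the log-derivative of the gated combination a * f(x) + (1 - a) * x. A small sketch of the identity, with arbitrary scalar values:

import jax.numpy as jnp
from jax.nn import softplus, sigmoid

params, logdet = 0.7, -1.2  # arbitrary scalars
lhs = softplus(logdet + params) - softplus(params)
rhs = jnp.log(sigmoid(params) * jnp.exp(logdet) + (1 - sigmoid(params)))
assert jnp.allclose(lhs, rhs, atol=1e-6)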
Example #5
 def __call__(self, y, sc, multiplicative_factor=None):
     net = self.predict(sc["encoder_params"], y)
     if multiplicative_factor is None:
         scale_tril = aux_math.diag(nn.softplus(net[...,
                                                    self.output_dims:]))
     else:
         scale_tril = np.einsum(
             "ab,cbd->cad", multiplicative_factor,
             aux_math.diag(nn.softplus(net[..., self.output_dims:])))
     return net[..., :self.output_dims], scale_tril
Example #6
def gradient_step(i, state, mod):
    params = get_params(state)
    mod.prior.hyp = params[0]
    mod.likelihood.hyp = params[1]
    neg_log_marg_lik, gradients = mod.run()
    # neg_log_marg_lik, gradients = mod.run_two_stage()  # <-- less elegant but reduces compile time
    print(
        'iter %2d: var_f=%1.2f len_f=%1.2f, nlml=%2.2f' %
        (i, softplus(params[0][0]), softplus(params[0][1]), neg_log_marg_lik))
    return opt_update(i, gradients, state)
Example #7
def gradient_step(i, state, model):
    params = get_params(state)
    model.prior.hyp = params[0]
    model.likelihood.hyp = params[1]
    # neg_log_marg_lik, gradients = model.run_model()
    neg_log_marg_lik, gradients = model.neg_log_marg_lik()
    print(
        'iter %2d: var_f=%1.2f len_f=%1.2f, nlml=%2.2f' %
        (i, softplus(params[0][0]), softplus(params[0][1]), neg_log_marg_lik))
    return opt_update(i, gradients, state)
Example #8
    def __call__(self, inputs):

        inputs = np.clip(inputs, self.eps.value, 1 - self.eps.value)

        outputs = (1 / self.temperature.value) * (np.log(inputs) -
                                                  np.log1p(-inputs))
        logabsdet = -(np.log(self.temperature.value) -
                      softplus(-self.temperature.value * outputs) -
                      softplus(self.temperature.value * outputs))
        return outputs, logabsdet.sum(axis=1)
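The logabsdet expression above can be checked against automatic differentiation of the same logit-with-temperature map; a minimal sketch assuming a scalar temperature T and inputs away from the clipping bounds:

import jax
import jax.numpy as jnp
from jax.nn import softplus

T = 0.5  # assumed scalar temperature
f = lambda u: (1 / T) * (jnp.log(u) - jnp.log1p(-u))
inputs = jnp.array([0.1, 0.4, 0.9])
outputs = f(inputs)
logabsdet = -(jnp.log(T) - softplus(-T * outputs) - softplus(T * outputs))
reference = jnp.log(jnp.abs(jax.vmap(jax.grad(f))(inputs)))  # elementwise log|df/du|
assert jnp.allclose(logabsdet, reference, atol=1e-5)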
Example #9
    def inverse_and_log_det(self, inputs: Array, **kwargs) -> Tuple[Array, Array]:

        # transformation
        outputs = mixture_gaussian_invcdf_vectorized(
            inputs, self.prior_logits, self.means, softplus(self.log_scales),
        )
        # log abs det, all zeros
        logabsdet = mixture_gaussian_log_pdf(
            outputs, self.prior_logits, self.means, softplus(self.log_scales),
        )

        return outputs, logabsdet  # .sum(axis=1)
Example #10
    def forward_and_log_det(self, inputs: Array, **kwargs) -> Tuple[Array, Array]:

        # forward transformation with batch dimension
        outputs = mixture_gaussian_cdf(
            inputs, self.prior_logits, self.means, softplus(self.log_scales),
        )

        # log abs det, all zeros
        logabsdet = mixture_gaussian_log_pdf(
            inputs, self.prior_logits, self.means, softplus(self.log_scales),
        )

        return outputs, logabsdet  # .sum(axis=1)
Example #11
def conditional_params_from_outputs(image, theta):
    """
    Maps image and model output theta to conditional parameters for a mixture
    of nr_mix logistics. If the input shapes are

    image.shape == (h, w, c)
    theta.shape == (h, w, 10 * nr_mix)

    the output shapes will be

    means.shape == inv_scales.shape == (nr_mix, h, w, c)
    logit_probs.shape == (nr_mix, h, w)
    """
    nr_mix = 10
    logit_probs, theta = np.split(theta, [nr_mix], axis=-1)
    logit_probs = np.moveaxis(logit_probs, -1, 0)
    theta = np.moveaxis(np.reshape(theta, image.shape + (-1, )), -1, 0)
    unconditioned_means, log_scales, coeffs = np.split(theta, 3)
    coeffs = np.tanh(coeffs)

    # now condition the means for the last 2 channels
    mean_red = unconditioned_means[..., 0]
    mean_green = unconditioned_means[..., 1] + coeffs[..., 0] * image[..., 0]
    mean_blue = (unconditioned_means[..., 2] + coeffs[..., 1] * image[..., 0] +
                 coeffs[..., 2] * image[..., 1])
    means = np.stack((mean_red, mean_green, mean_blue), axis=-1)
    inv_scales = softplus(log_scales)
    return means, inv_scales, logit_probs
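A minimal shape check matching the docstring, assuming conditional_params_from_outputs and its module-level imports are in scope (nr_mix is fixed to 10 inside the function); the zero-valued inputs are placeholders:

import jax.numpy as np

h, w, c = 4, 4, 3
image = np.zeros((h, w, c))
theta = np.zeros((h, w, 10 * 10))  # (h, w, 10 * nr_mix) with nr_mix = 10
means, inv_scales, logit_probs = conditional_params_from_outputs(image, theta)
assert means.shape == inv_scales.shape == (10, h, w, c)
assert logit_probs.shape == (10, h, w)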
Example #12
    def transition_fn(carry, t):
        lam_prev, x_prev = carry

        gamma = npyro.deterministic('gamma', nn.softplus(mu + x_prev))

        U = jnp.log(lam_prev) - jnp.log(lam_prev.sum(-1, keepdims=True))

        logs = logits((beliefs[0][:, t], beliefs[1][:, t]),
                      jnp.expand_dims(gamma, -1), jnp.expand_dims(U, -2))

        lam_next = npyro.deterministic(
            'lams', lam_prev +
            nn.one_hot(beliefs[2][t], 4) * jnp.expand_dims(mask[t] * eta, -1))

        mixing_dist = dist.CategoricalProbs(weights)
        component_dist = dist.CategoricalLogits(logs.swapaxes(0, 1)).mask(
            mask[t][..., None])
        with npyro.plate('subjects', N):
            y = npyro.sample(
                'y', dist.MixtureSameFamily(mixing_dist, component_dist))
            noise = npyro.sample('dw', dist.Normal(0., 1.))

        x_next = rho * x_prev + sigma * noise

        return (lam_next, x_next), None
Example #13
def logistic_log_pdf(x: Array, mean: Array, scale: Array) -> Array:
    """Element-wise log PDF of the logistic distribution

    Args:
        x (Array): feature vector to be transformed
        mean (Array) : mean for the features
        scale (Array) : scale for features

    Returns:
        log_prob (Array): log probability of the distribution
    """

    # change of variables
    z = (x - mean) / scale

    # log probability
    # log_prob = z -jnp.log(scale) - 2 * jax.nn.softplus(z)
    # log_prob = jax.scipy.stats.logistic.logpdf(z)

    # Original Author Implementation
    log_prob = z - jnp.log(scale) - 2 * softplus(z)
    # # distrax implementation
    # log_prob = -z - 2.0 * softplus(-z) - jnp.log(scale)

    return log_prob
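A quick cross-check of the softplus form against jax.scipy.stats.logistic (note the extra -log(scale) change-of-variables term), with arbitrary example values:

import jax.numpy as jnp
from jax.nn import softplus
from jax.scipy.stats import logistic

x = jnp.linspace(-4.0, 4.0, 9)
mean, scale = 0.5, 2.0  # arbitrary example parameters
z = (x - mean) / scale
log_prob = z - jnp.log(scale) - 2 * softplus(z)
reference = logistic.logpdf(z) - jnp.log(scale)  # standard logistic log-pdf of z, shifted by log|dz/dx|
assert jnp.allclose(log_prob, reference, atol=1e-5)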
Example #14
    def loss(self, y_input, t_indices: list, param_dict, num_steps):

        if type(t_indices) is not list:
            raise TypeError("Time indices object must be a list")
        # if self.config["mapping"] == "neural_ode_with_softplus" and 0 not in t_indices:
        #     print(f"For mapping {self.config['mapping']}, the initial point (index 0) is required.")

        y_0 = y_input[0].reshape(y_input[0].shape[0], -1)

        gp_matrices, latent_drift_function, latent_diffusion_function = self.build(
            param_dict)

        self.sde_var.drift_function = latent_drift_function
        self.sde_var.diffusion_function = latent_diffusion_function

        self.y_t, self.paths_y = self.sde_var(y_0, num_steps)

        y_t_to_compare = self.paths_y[ops.index[t_indices]]
        metrics = dict()
        metrics["reco"] = \
            np.mean(
                aux_math.log_prob_multivariate_normal(
                    y_t_to_compare,
                    aux_math.diag(np.sqrt(nn.softplus(self.signal_variance.build(param_dict["likelihood"])))),
                    y_input[t_indices]))

        self.get_metrics(metrics, gp_matrices, param_dict)
        return -metrics["elbo"], metrics
Example #15
def dequantize(rng: jnp.ndarray, deq_params: Sequence[jnp.ndarray],
               deq_fn: Callable, xsph: jnp.ndarray,
               num_samples: int) -> Tuple[jnp.ndarray]:
    """Dequantize observations on the sphere into the ambient space.

    Args:
        rng: Pseudo-random number generator seed.
        deq_params: Parameters of the mean and scale functions used in
            the log-normal dequantizer.
        deq_fn: Function that computes the mean and scale of the dequantization
            distribution.
        xsph: Observations on the sphere.
        num_samples: Number of dequantization samples.

    Returns:
        out: A tuple containing the dequantized samples and the log-density of
            the dequantized samples.

    """
    # Dequantization parameters.
    mu, sigma = deq_fn(deq_params, xsph)
    mu = nn.softplus(mu)
    # Random samples for dequantization.
    rng, rng_rad = random.split(rng, 2)
    mu, sigma = mu[..., 0], sigma[..., 0]
    rad = pd.lognormal.rvs(rng_rad, mu, sigma,
                           [num_samples] + list(xsph.shape[:-1]))
    xdeq = rad[..., jnp.newaxis] * xsph
    # Dequantization density calculation.
    ldj = -(num_dims - 1) * jnp.log(rad)
    logdens = pd.lognormal.logpdf(rad, mu, sigma) + ldj
    return xdeq, logdens
Example #16
 def sample(self, key, sample_shape=()):
     assert is_prng_key(key)
     logits = self.logits
     dtype = jnp.result_type(logits)
     shape = sample_shape + self.batch_shape
     u = random.uniform(key, shape, dtype)
     return jnp.floor(jnp.log1p(-u) / -softplus(logits))
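This is inverse-CDF sampling of a Geometric distribution: with p = sigmoid(logits), -softplus(logits) equals log(1 - p), and floor(log(1 - u) / log(1 - p)) for u ~ Uniform(0, 1) is Geometric(p). A quick check of the softplus identity, with arbitrary logits:

import jax.numpy as jnp
from jax.nn import softplus, sigmoid

logits = jnp.array([-1.0, 0.5, 2.0])  # arbitrary example logits
assert jnp.allclose(-softplus(logits), jnp.log1p(-sigmoid(logits)), atol=1e-6)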
Example #17
def deep_linear_model_loss(w, x, y):
    """Returns the loss given an input and its ground-truth label."""
    z = x
    for i in range(len(w)):
        z = w[i].T @ z
    loss = nn.softplus(-y * z).mean()
    return loss
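Here softplus(-y * z) is the logistic loss for labels y in {-1, +1}. A minimal call sketch with assumed shapes (features along the first axis, one weight matrix per layer), assuming the function's module-level imports (nn) are in scope:

import jax
import jax.numpy as jnp

w = [jnp.full((3, 3), 0.1), jnp.full((3, 1), 0.1)]  # assumed layer shapes
x = jnp.ones((3, 5))                                # 3 features, 5 samples
y = jnp.array([1.0, -1.0, 1.0, -1.0, 1.0])          # labels in {-1, +1}
loss = deep_linear_model_loss(w, x, y)
grads = jax.grad(deep_linear_model_loss)(w, x, y)   # gradients w.r.t. the weight list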
Example #18
    def forward_log_det_jacobian(self, inputs: Array, **kwargs) -> Tuple[Array, Array]:

        # log abs det, all zeros
        logabsdet = mixture_gaussian_log_pdf(
            inputs, self.prior_logits, self.means, softplus(self.log_scales),
        )

        return logabsdet  # .sum(axis=1)
Example #19
    def inverse(self, inputs: Array, **kwargs) -> Tuple[Array, Array]:

        # transformation
        outputs = mixture_gaussian_invcdf_vectorized(
            inputs, self.prior_logits, self.means, softplus(self.log_scales),
        )

        return outputs  # .sum(axis=1)
Example #20
def conv_linear_model_loss(w, x, y):
    """Returns the loss given an input and its ground-truth label."""
    z = x
    for i in range(len(w) - 1):
        z = circ_1d_conv(w[i], z)
    z = w[-1].T @ z
    loss = nn.softplus(-y * z).mean()
    return loss
Example #21
 def spline_params(params, upper):
     outputs = network_apply_fun(params, upper)
     outputs = np.reshape(outputs, [-1, lower_dim, 3 * K - 1])
     W, H, D = np.split(outputs, [K, 2 * K], axis=2)
     W = 2 * B * softmax(W)
     H = 2 * B * softmax(H)
     D = softplus(D)
     return W, H, D
Example #22
 def forward(self, inputs: Array, **kwargs) -> Array:
     # forward transformation with batch dimension
     outputs = mixture_logistic_cdf(
         inputs,
         self.prior_logits,
         self.means,
         softplus(self.log_scales),
     )
     return outputs
Example #23
def discretized_mix_logistic_loss(theta, y, num_class=256, log_scale_min=-7.):
    """
    Discretized mixture of logistic distributions loss
    :param theta: B x T x 3 * nr_mix
    :param y:  B x T x 1
    """
    theta_shape = theta.shape

    nr_mix = theta_shape[2] // 3

    # unpack parameters
    means = theta[:, :, :nr_mix]
    log_scales = np.maximum(theta[:, :, nr_mix:2 * nr_mix], log_scale_min)
    logit_probs = theta[:, :, nr_mix * 2:nr_mix * 3]

    # B x T x 1 => B x T x nr_mix
    y = np.broadcast_to(y, y.shape[:-1] + (nr_mix, ))

    centered_y = y - means
    inv_stdv = np.exp(-log_scales)
    plus_in = inv_stdv * (centered_y + 1. / (num_class - 1))
    cdf_plus = sigmoid(plus_in)
    min_in = inv_stdv * (centered_y - 1. / (num_class - 1))
    cdf_min = sigmoid(min_in)

    # log probability for edge case of 0 (before scaling):
    log_cdf_plus = plus_in - softplus(plus_in)
    # log probability for edge case of 255 (before scaling):
    log_one_minus_cdf_min = -softplus(min_in)

    cdf_delta = cdf_plus - cdf_min  # probability for all other cases
    mid_in = inv_stdv * centered_y

    log_pdf_mid = mid_in - log_scales - 2. * softplus(mid_in)

    log_probs = np.where(
        y < -0.999, log_cdf_plus,
        np.where(
            y > 0.999, log_one_minus_cdf_min,
            np.where(cdf_delta > 1e-5, np.log(np.maximum(cdf_delta, 1e-12)),
                     log_pdf_mid - np.log((num_class - 1) / 2))))

    log_probs = log_probs + log_softmax(logit_probs)
    return -np.sum(logsumexp(log_probs, axis=-1), axis=-1)
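A minimal call sketch using the shapes from the docstring (B x T x 3*nr_mix parameters, B x T x 1 targets assumed scaled to [-1, 1]), assuming the function's module-level imports are in scope:

import jax.numpy as np

B, T, nr_mix = 2, 16, 10  # assumed sizes
theta = np.zeros((B, T, 3 * nr_mix))
y = np.zeros((B, T, 1))
loss = discretized_mix_logistic_loss(theta, y)  # one value per batch row, shape (B,)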
Example #24
 def inverse_fun(params, z):
     log_det = np.zeros(z.shape[0])
     idx = dim // 2
     lower, upper = z[:, :idx], z[:, idx:]
     out = f2_apply_fun(f2_params, upper).reshape(-1, dim // 2, 3 * K - 1)
     W, H, D = onp.array_split(out, 3, axis=2)
     W, H = nn.softmax(W, axis=2), nn.softmax(H, axis=2)
     W, H = 2 * B * W, 2 * B * H
     D = nn.softplus(D)
     lower, ld = unconstrained_RQS(lower, W, H, D, inverse=True, tail_bound=B)
     log_det += np.sum(ld, axis=1)
     out = f1_apply_fun(f1_params, lower).reshape(-1, dim // 2, 3 * K - 1)
     W, H, D = onp.array_split(out, 3, axis=2)
     W, H = nn.softmax(W, axis=2), nn.softmax(H, axis=2)
     W, H = 2 * B * W, 2 * B * H
     D = nn.softplus(D)
     upper, ld = unconstrained_RQS(upper, W, H, D, inverse=True, tail_bound=B)
     log_det += np.sum(ld, axis=1)
     return np.concatenate([lower, upper], axis=1), log_det.reshape((z.shape[0],))
Example #25
 def apply_fun(params, inputs, **kwargs):
     x, logdet = inputs
     out = jnp.tanh(x)
     tanh_logdet = -2 * (x + softplus(-2 * x) - jnp.log(2.))
     # logdet.shape = batch_shape + (num_blocks, in_factor, out_factor)
     # tanh_logdet.shape = batch_shape + (num_blocks x out_factor,)
     # so we need to reshape tanh_logdet to: batch_shape + (num_blocks, 1, out_factor)
     tanh_logdet = tanh_logdet.reshape(logdet.shape[:-2] +
                                       (1, logdet.shape[-1]))
     return out, logdet + tanh_logdet
Example #26
 def __call__(self, x):
     x = MLP(
         output_dim=2 * self.output_dim,
         hidden_units=self.hidden_units,
         hidden_activation=partial(nn.leaky_relu,
                                   negative_slope=self.negative_slope),
     )(x)
     mean, log_std = jnp.split(x, 2, axis=1)
     std = nn.softplus(log_std) + 1e-5
     return mean, std
Example #27
    def __init__(self, config):
        super(DataGeneratorToyDatasetLinearCombinationWithSoftplus, self).__init__(config)

        np.random.seed(0)
        intermediate_state = 2 * np.prod(self.config["state_size"])
        self.matrix_a = np.random.normal(size=(intermediate_state, 2))
        self.matrix_b = np.random.normal(size=(np.prod(self.config["state_size"]), intermediate_state))

        t_steps = self.config["num_steps"]
        delta_t = self.config["delta_t"]
        frequency = 2 / (delta_t * 30)  # The 30 is to have a decent rotation at each time step
        trig_arg = frequency * delta_t

        rotation_matrix = np.array([[np.cos(trig_arg), -np.sin(trig_arg)], [np.sin(trig_arg), np.cos(trig_arg)]])

        x_steps_train = np.zeros((t_steps+1, self.n_points, 2))
        x_steps_test = np.zeros((t_steps+1, self.n_points_test, 2))

        x_0_train, _ = make_moons(self.n_points, noise=.05)
        x_0_test, _ = make_moons(self.n_points_test, noise=.05)

        x_steps_train[0] = x_0_train
        x_steps_test[0] = x_0_test

        for i in range(t_steps):
            noise_train = np.random.normal(scale=0.03, size=[self.n_points, 2])
            noise_test = np.random.normal(scale=0.03, size=[self.n_points_test, 2])
            x_steps_train[i+1] = np.einsum("ab,cb->ca", rotation_matrix, x_steps_train[i]) + noise_train
            x_steps_test[i+1] = np.einsum("ab,cb->ca", rotation_matrix, x_steps_test[i]) + noise_test

        self.input_train_pca_np = x_steps_train
        self.input_test_pca_np = x_steps_test

        self.input_train_np = \
            np.einsum('ab,cdb->cda',
                      self.matrix_b,
                      softplus(np.einsum("ab,cdb->cda", self.matrix_a, self.input_train_pca_np)))
        self.input_test_np = \
            np.einsum('ab,cdb->cda',
                      self.matrix_b,
                      softplus(np.einsum("ab,cdb->cda", self.matrix_a, self.input_test_pca_np)))
        self.input_t = list(range(self.config["num_steps"]+1))
Example #28
    def gaussian_process(params,
                         predictors,
                         target,
                         test_predictors=None,
                         compute_marginal_likelihood=False):
        noise = softplus(params['noise'])
        ampl = softplus(params['amplitude'])
        scale = softplus(params['lengthscale'])
        target = target - np.mean(target)
        predictors = predictors / scale

        @jax.jit
        def exp_squared(x, y):
            return np.exp(-np.sum((x - y)**2))

        base_cov = build_covariance_matrix(exp_squared, predictors)
        train_cov = ampl * base_cov + np.eye(x.shape[0]) * (noise + 1.0e-6)
        LDLT = scipy.linalg.cholesky(train_cov, lower=True)
        kinvy = scipy.linalg.solve_triangular(
            LDLT.T, scipy.linalg.solve_triangular(LDLT, target, lower=True))
        if compute_marginal_likelihood:
            lbda = np.sum(-0.5 * np.dot(target.T, kinvy) -
                          np.sum(np.log(np.diag(LDLT))) -
                          x.shape[0] * 0.5 * np.log(2 * pi))
            return -(lbda - np.sum(-0.5 * np.log(2 * pi) - np.log(ampl)**2))

        if test_predictors is not None:
            test_predictors = test_predictors / scale
            cross_cov = ampl * build_covariance_matrix(exp_squared, predictors,
                                                       test_predictors)
        else:
            cross_cov = base_cov

        LDLT = scipy.linalg.cholesky(train_cov, lower=True)
        mu = np.dot(cross_cov.T, kinvy) + np.mean(target)
        v = scipy.linalg.solve_triangular(LDLT, cross_cov, lower=True)
        cov = ampl * build_covariance_matrix(exp_squared,
                                             test_predictors) - np.dot(v.T, v)
        return mu, cov
Example #29
    def transition_fn(carry, t):
        x_prev = carry

        dyn_gamma = npyro.deterministic('dyn_gamma', nn.softplus(mu + x_prev))

        logs = logits((beliefs[0][t], beliefs[1][t]),
                      jnp.expand_dims(dyn_gamma, -1), jnp.expand_dims(U, -2))

        npyro.sample('y', dist.CategoricalLogits(logs).mask(mask[t]))
        noise = npyro.sample('dw', dist.Normal(0., 1.))

        x_next = rho * x_prev + sigma * noise

        return x_next, None
Example #30
    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # NB: because domain and codomain are two spaces with different dimensions, determinant of
        # Jacobian is not well-defined. Here we return `log_abs_det_jacobian` of `x` and the
        # flatten lower triangular part of `y`.

        # stick_breaking_logdet = log(y / r) = log(z_cumprod)  (modulo right shifted)
        z1m_cumprod = 1 - jnp.cumsum(y * y, axis=-1)
        # by taking diagonal=-2, we don't need to shift z_cumprod to the right
        # NB: diagonal=-2 works fine for (2 x 2) matrix, where we get an empty array
        z1m_cumprod_tril = matrix_to_tril_vec(z1m_cumprod, diagonal=-2)
        stick_breaking_logdet = 0.5 * jnp.sum(jnp.log(z1m_cumprod_tril), axis=-1)

        tanh_logdet = -2 * jnp.sum(x + softplus(-2 * x) - jnp.log(2.), axis=-1)
        return stick_breaking_logdet + tanh_logdet
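A quick check of the stable tanh log-det term used above, log(1 - tanh(x)^2) == -2 * (x + softplus(-2x) - log 2), with arbitrary values:

import jax.numpy as jnp
from jax.nn import softplus

x = jnp.array([-2.0, 0.0, 1.5])  # arbitrary example values
stable = -2 * (x + softplus(-2 * x) - jnp.log(2.0))
naive = jnp.log1p(-jnp.tanh(x) ** 2)
assert jnp.allclose(stable, naive, atol=1e-6)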