def log_prob(self, value):
    log_prob_minus_log_gate = -self.gate_logits + self.base_dist.log_prob(value)
    log_gate = -softplus(-self.gate_logits)
    log_prob = log_prob_minus_log_gate + log_gate
    zero_log_prob = softplus(log_prob_minus_log_gate) + log_gate
    return jnp.where(value == 0, zero_log_prob, log_prob)

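
# Illustrative check (not part of the distribution class above): the gating term
# uses the identity log(sigmoid(g)) = -softplus(-g); its complement
# log(1 - sigmoid(g)) = -softplus(g) is what keeps the mixture weights finite in
# log space. A minimal sketch, assuming default float32 precision:
import jax.numpy as jnp
from jax.nn import sigmoid, softplus

g = jnp.array([-20.0, -1.0, 0.0, 1.0, 20.0])
assert jnp.allclose(jnp.log(sigmoid(g)), -softplus(-g), atol=1e-5)
# The naive complement underflows to -inf where the softplus form stays finite:
print(jnp.log1p(-sigmoid(20.0)), -softplus(20.0))
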
def latent_diffusion_function(self, x, t, param_dict, gp_matrices):
    if self.config["constant_diffusion"]:
        concentration, rate = self.gamma_params.build(param_dict["gamma_params"])
        self.key, subkey = random.split(self.key)
        inverse_lambdas = numpyro.sample(
            "inverse_lambdas",
            InverseGamma(nn.softplus(concentration), nn.softplus(rate)),
            rng_key=subkey)
        return np.tile(
            np.expand_dims(aux_math.diag(inverse_lambdas), axis=0),
            [x.shape[0], 1, 1])
    else:
        if self.config["time_dependent_gp"]:
            time = np.ones(shape=(x.shape[0], 1)) * t
            y = np.concatenate((x, time), axis=1)
            return aux_math.diag(
                np.transpose(
                    self.sde_gp_diffusion(y, param_dict["sde_gp"], gp_matrices)))
        return aux_math.diag(
            np.transpose(
                self.sde_gp_diffusion(x, param_dict["sde_gp"], gp_matrices)))

def sample(self, state, model_args, model_kwargs):
    i, x, x_pe, x_grad, _, mean_accept_prob, adapt_state, rng_key = state
    x_flat, unravel_fn = ravel_pytree(x)
    x_grad_flat, _ = ravel_pytree(x_grad)
    shape = jnp.shape(x_flat)
    rng_key, key_normal, key_bernoulli, key_accept = random.split(rng_key, 4)

    mass_sqrt_inv = adapt_state.mass_matrix_sqrt_inv
    x_grad_flat_scaled = (mass_sqrt_inv @ x_grad_flat if self._dense_mass
                          else mass_sqrt_inv * x_grad_flat)

    # Generate proposal y.
    z = adapt_state.step_size * random.normal(key_normal, shape)
    p = expit(-z * x_grad_flat_scaled)
    b = jnp.where(random.uniform(key_bernoulli, shape) < p, 1.0, -1.0)
    dx_flat = b * z
    dx_flat_scaled = (mass_sqrt_inv.T @ dx_flat if self._dense_mass
                      else mass_sqrt_inv * dx_flat)
    y_flat = x_flat + dx_flat_scaled
    y = unravel_fn(y_flat)

    y_pe, y_grad = jax.value_and_grad(self._potential_fn)(y)
    y_grad_flat, _ = ravel_pytree(y_grad)
    y_grad_flat_scaled = (mass_sqrt_inv @ y_grad_flat if self._dense_mass
                          else mass_sqrt_inv * y_grad_flat)

    log_accept_ratio = x_pe - y_pe + jnp.sum(
        softplus(dx_flat * x_grad_flat_scaled) - softplus(-dx_flat * y_grad_flat_scaled))
    accept_prob = jnp.clip(jnp.exp(log_accept_ratio), a_max=1.0)

    x, x_flat, pe, x_grad = jax.lax.cond(
        random.bernoulli(key_accept, accept_prob),
        (y, y_flat, y_pe, y_grad),
        identity,
        (x, x_flat, x_pe, x_grad),
        identity,
    )

    # do not update adapt_state after warmup phase
    adapt_state = jax.lax.cond(
        i < self._num_warmup,
        (i, accept_prob, (x,), adapt_state),
        lambda args: self._wa_update(*args),
        adapt_state,
        identity,
    )

    itr = i + 1
    n = jnp.where(i < self._num_warmup, itr, itr - self._num_warmup)
    mean_accept_prob = mean_accept_prob + (accept_prob - mean_accept_prob) / n

    return BarkerMHState(itr, x, pe, x_grad, accept_prob, mean_accept_prob,
                         adapt_state, rng_key)

def apply_fun(params, inputs, **kwargs):
    # a * f(x) + (1 - a) * x
    (fx, logdet), (x, _) = inputs
    gate = sigmoid(params)
    out = gate * fx + (1 - gate) * x
    logdet = softplus(logdet + params) - softplus(params)
    return out, logdet

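
# Illustrative identity check (assumed elementwise context, not from the source):
# with gate = sigmoid(params), the Jacobian factor gate * exp(logdet) + (1 - gate)
# has log equal to softplus(logdet + params) - softplus(params), which is what the
# update above computes without leaving log space.
import jax.numpy as jnp
from jax.nn import sigmoid, softplus

params = jnp.array([-2.0, 0.0, 3.0])
logdet = jnp.array([0.5, -1.0, 2.0])
gate = sigmoid(params)
naive = jnp.log(gate * jnp.exp(logdet) + (1 - gate))
stable = softplus(logdet + params) - softplus(params)
assert jnp.allclose(naive, stable, atol=1e-5)
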
def __call__(self, y, sc, multiplicative_factor=None):
    net = self.predict(sc["encoder_params"], y)
    if multiplicative_factor is None:
        scale_tril = aux_math.diag(nn.softplus(net[..., self.output_dims:]))
    else:
        scale_tril = np.einsum(
            "ab,cbd->cad", multiplicative_factor,
            aux_math.diag(nn.softplus(net[..., self.output_dims:])))
    return net[..., :self.output_dims], scale_tril

def gradient_step(i, state, mod):
    params = get_params(state)
    mod.prior.hyp = params[0]
    mod.likelihood.hyp = params[1]
    neg_log_marg_lik, gradients = mod.run()
    # neg_log_marg_lik, gradients = mod.run_two_stage()  # <-- less elegant but reduces compile time
    print('iter %2d: var_f=%1.2f len_f=%1.2f, nlml=%2.2f' %
          (i, softplus(params[0][0]), softplus(params[0][1]), neg_log_marg_lik))
    return opt_update(i, gradients, state)

def gradient_step(i, state, model):
    params = get_params(state)
    model.prior.hyp = params[0]
    model.likelihood.hyp = params[1]
    # neg_log_marg_lik, gradients = model.run_model()
    neg_log_marg_lik, gradients = model.neg_log_marg_lik()
    print('iter %2d: var_f=%1.2f len_f=%1.2f, nlml=%2.2f' %
          (i, softplus(params[0][0]), softplus(params[0][1]), neg_log_marg_lik))
    return opt_update(i, gradients, state)

def __call__(self, inputs):
    inputs = np.clip(inputs, self.eps.value, 1 - self.eps.value)
    outputs = (1 / self.temperature.value) * (np.log(inputs) - np.log1p(-inputs))
    logabsdet = -(np.log(self.temperature.value)
                  - softplus(-self.temperature.value * outputs)
                  - softplus(self.temperature.value * outputs))
    return outputs, logabsdet.sum(axis=1)

def inverse_and_log_det(self, inputs: Array, **kwargs) -> Tuple[Array, Array]:
    # transformation
    outputs = mixture_gaussian_invcdf_vectorized(
        inputs,
        self.prior_logits,
        self.means,
        softplus(self.log_scales),
    )
    # log abs det, all zeros
    logabsdet = mixture_gaussian_log_pdf(
        outputs,
        self.prior_logits,
        self.means,
        softplus(self.log_scales),
    )
    return outputs, logabsdet  # .sum(axis=1)

def forward_and_log_det(self, inputs: Array, **kwargs) -> Tuple[Array, Array]:
    # forward transformation with batch dimension
    outputs = mixture_gaussian_cdf(
        inputs,
        self.prior_logits,
        self.means,
        softplus(self.log_scales),
    )
    # log abs det, all zeros
    logabsdet = mixture_gaussian_log_pdf(
        inputs,
        self.prior_logits,
        self.means,
        softplus(self.log_scales),
    )
    return outputs, logabsdet  # .sum(axis=1)

def conditional_params_from_outputs(image, theta):
    """
    Maps image and model output theta to conditional parameters for a mixture
    of nr_mix logistics.

    If the input shapes are
        image.shape == (h, w, c)
        theta.shape == (h, w, 10 * nr_mix)
    the output shapes will be
        means.shape == inv_scales.shape == (nr_mix, h, w, c)
        logit_probs.shape == (nr_mix, h, w)
    """
    nr_mix = 10
    logit_probs, theta = np.split(theta, [nr_mix], axis=-1)
    logit_probs = np.moveaxis(logit_probs, -1, 0)
    theta = np.moveaxis(np.reshape(theta, image.shape + (-1,)), -1, 0)

    unconditioned_means, log_scales, coeffs = np.split(theta, 3)
    coeffs = np.tanh(coeffs)

    # now condition the means for the last 2 channels
    mean_red = unconditioned_means[..., 0]
    mean_green = unconditioned_means[..., 1] + coeffs[..., 0] * image[..., 0]
    mean_blue = (unconditioned_means[..., 2] + coeffs[..., 1] * image[..., 0]
                 + coeffs[..., 2] * image[..., 1])
    means = np.stack((mean_red, mean_green, mean_blue), axis=-1)
    inv_scales = softplus(log_scales)
    return means, inv_scales, logit_probs

def transition_fn(carry, t):
    lam_prev, x_prev = carry
    gamma = npyro.deterministic('gamma', nn.softplus(mu + x_prev))

    U = jnp.log(lam_prev) - jnp.log(lam_prev.sum(-1, keepdims=True))
    logs = logits((beliefs[0][:, t], beliefs[1][:, t]),
                  jnp.expand_dims(gamma, -1),
                  jnp.expand_dims(U, -2))

    lam_next = npyro.deterministic(
        'lams',
        lam_prev + nn.one_hot(beliefs[2][t], 4) * jnp.expand_dims(mask[t] * eta, -1))

    mixing_dist = dist.CategoricalProbs(weights)
    component_dist = dist.CategoricalLogits(logs.swapaxes(0, 1)).mask(mask[t][..., None])

    with npyro.plate('subjects', N):
        y = npyro.sample('y', dist.MixtureSameFamily(mixing_dist, component_dist))
        noise = npyro.sample('dw', dist.Normal(0., 1.))

    x_next = rho * x_prev + sigma * noise
    return (lam_next, x_next), None

def logistic_log_pdf(x: Array, mean: Array, scale: Array) -> Array:
    """Element-wise log PDF of the logistic distribution.

    Args:
        x (Array): feature vector to be transformed
        mean (Array): mean for the features
        scale (Array): scale for the features

    Returns:
        log_prob (Array): log probability of the distribution
    """
    # change of variables
    z = (x - mean) / scale

    # log probability
    # log_prob = z - jnp.log(scale) - 2 * jax.nn.softplus(z)
    # log_prob = jax.scipy.stats.logistic.logpdf(z)

    # Original Author Implementation
    log_prob = z - jnp.log(scale) - 2 * softplus(z)

    # # distrax implementation
    # log_prob = -z - 2.0 * softplus(-z) - jnp.log(scale)

    return log_prob

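
# Illustrative check (assumed usage, not from the source project): the two forms
# in the comments above are algebraically identical, because
# softplus(z) = z + softplus(-z) implies z - 2*softplus(z) == -z - 2*softplus(-z).
import jax.numpy as jnp
from jax.nn import softplus

x = jnp.linspace(-5.0, 5.0, 11)
mean, scale = 0.5, 2.0
z = (x - mean) / scale
author_form = z - jnp.log(scale) - 2 * softplus(z)
distrax_form = -z - 2.0 * softplus(-z) - jnp.log(scale)
assert jnp.allclose(author_form, distrax_form, atol=1e-5)
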
def loss(self, y_input, t_indices: list, param_dict, num_steps):
    if type(t_indices) is not list:
        raise TypeError("Time indices object must be a list")
    # if self.config["mapping"] == "neural_ode_with_softplus" and 0 not in t_indices:
    #     print(f"For mapping {self.config['mapping']}, the initial point (index 0) is required.")
    y_0 = y_input[0].reshape(y_input[0].shape[0], -1)
    gp_matrices, latent_drift_function, latent_diffusion_function = self.build(param_dict)
    self.sde_var.drift_function = latent_drift_function
    self.sde_var.diffusion_function = latent_diffusion_function
    self.y_t, self.paths_y = self.sde_var(y_0, num_steps)
    y_t_to_compare = self.paths_y[ops.index[t_indices]]
    metrics = dict()
    metrics["reco"] = np.mean(
        aux_math.log_prob_multivariate_normal(
            y_t_to_compare,
            aux_math.diag(np.sqrt(nn.softplus(
                self.signal_variance.build(param_dict["likelihood"])))),
            y_input[t_indices]))
    self.get_metrics(metrics, gp_matrices, param_dict)
    return -metrics["elbo"], metrics

def dequantize(rng: jnp.ndarray, deq_params: Sequence[jnp.ndarray], deq_fn: Callable,
               xsph: jnp.ndarray, num_samples: int) -> Tuple[jnp.ndarray]:
    """Dequantize observations on the sphere into the ambient space.

    Args:
        rng: Pseudo-random number generator seed.
        deq_params: Parameters of the mean and scale functions used in the
            log-normal dequantizer.
        deq_fn: Function that computes the mean and scale of the dequantization
            distribution.
        xsph: Observations on the sphere.
        num_samples: Number of dequantization samples.

    Returns:
        out: A tuple containing the dequantized samples and the log-density of
            the dequantized samples.
    """
    # Dequantization parameters.
    mu, sigma = deq_fn(deq_params, xsph)
    mu = nn.softplus(mu)
    # Random samples for dequantization.
    rng, rng_rad = random.split(rng, 2)
    mu, sigma = mu[..., 0], sigma[..., 0]
    rad = pd.lognormal.rvs(rng_rad, mu, sigma, [num_samples] + list(xsph.shape[:-1]))
    xdeq = rad[..., jnp.newaxis] * xsph
    # Dequantization density calculation.
    ldj = -(num_dims - 1) * jnp.log(rad)
    logdens = pd.lognormal.logpdf(rad, mu, sigma) + ldj
    return xdeq, logdens

def sample(self, key, sample_shape=()):
    assert is_prng_key(key)
    logits = self.logits
    dtype = jnp.result_type(logits)
    shape = sample_shape + self.batch_shape
    u = random.uniform(key, shape, dtype)
    return jnp.floor(jnp.log1p(-u) / -softplus(logits))

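
# Hedged usage sketch (toy values, not from the source): with success probability
# p = sigmoid(logits), floor(log(1 - u) / log(1 - p)) is inverse-CDF sampling of a
# Geometric distribution (failures before the first success), and
# log(1 - p) = -softplus(logits) evaluates the denominator stably in logit space.
import jax
import jax.numpy as jnp
from jax.nn import softplus

logits = jnp.array(0.0)  # p = sigmoid(0) = 0.5
u = jax.random.uniform(jax.random.PRNGKey(0), (100000,))
samples = jnp.floor(jnp.log1p(-u) / -softplus(logits))
print(samples.mean())  # close to (1 - p) / p = 1.0 for p = 0.5
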
def deep_linear_model_loss(w, x, y):
    """Returns the loss given an input and its ground-truth label."""
    z = x
    for i in range(len(w)):
        z = w[i].T @ z
    loss = nn.softplus(-y * z).mean()
    return loss

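
# Hedged usage sketch (toy shapes assumed, not from the source): softplus(-y * z)
# is the logistic loss log(1 + exp(-y * z)) for labels y in {-1, +1}, computed
# stably. Assumes the surrounding module imports jax.nn as nn, as the loss above does.
import jax.numpy as jnp

w = [jnp.eye(3), jnp.ones((3, 1))]     # two "layers" of toy weights
x = jnp.array([[0.2], [-0.4], [1.0]])  # a single 3-dimensional input column
y = jnp.array([1.0])                   # ground-truth label in {-1, +1}
print(deep_linear_model_loss(w, x, y))  # softplus(-(0.2 - 0.4 + 1.0)) = softplus(-0.8)
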
def forward_log_det_jacobian(self, inputs: Array, **kwargs) -> Array:
    # log abs det, all zeros
    logabsdet = mixture_gaussian_log_pdf(
        inputs,
        self.prior_logits,
        self.means,
        softplus(self.log_scales),
    )
    return logabsdet  # .sum(axis=1)

def inverse(self, inputs: Array, **kwargs) -> Array:
    # transformation
    outputs = mixture_gaussian_invcdf_vectorized(
        inputs,
        self.prior_logits,
        self.means,
        softplus(self.log_scales),
    )
    return outputs  # .sum(axis=1)

def conv_linear_model_loss(w, x, y):
    """Returns the loss given an input and its ground-truth label."""
    z = x
    for i in range(len(w) - 1):
        z = circ_1d_conv(w[i], z)
    z = w[-1].T @ z
    loss = nn.softplus(-y * z).mean()
    return loss

def spline_params(params, upper):
    outputs = network_apply_fun(params, upper)
    outputs = np.reshape(outputs, [-1, lower_dim, 3 * K - 1])
    W, H, D = np.split(outputs, [K, 2 * K], axis=2)
    W = 2 * B * softmax(W)
    H = 2 * B * softmax(H)
    D = softplus(D)
    return W, H, D

def forward(self, inputs: Array, **kwargs) -> Array:
    # forward transformation with batch dimension
    outputs = mixture_logistic_cdf(
        inputs,
        self.prior_logits,
        self.means,
        softplus(self.log_scales),
    )
    return outputs

def discretized_mix_logistic_loss(theta, y, num_class=256, log_scale_min=-7.):
    """
    Discretized mixture of logistic distributions loss
    :param theta: B x T x 3 * nr_mix
    :param y: B x T x 1
    """
    theta_shape = theta.shape
    nr_mix = theta_shape[2] // 3

    # unpack parameters
    means = theta[:, :, :nr_mix]
    log_scales = np.maximum(theta[:, :, nr_mix:2 * nr_mix], log_scale_min)
    logit_probs = theta[:, :, nr_mix * 2:nr_mix * 3]

    # B x T x 1 => B x T x nr_mix
    y = np.broadcast_to(y, y.shape[:-1] + (nr_mix,))

    centered_y = y - means
    inv_stdv = np.exp(-log_scales)
    plus_in = inv_stdv * (centered_y + 1. / (num_class - 1))
    cdf_plus = sigmoid(plus_in)
    min_in = inv_stdv * (centered_y - 1. / (num_class - 1))
    cdf_min = sigmoid(min_in)

    # log probability for edge case of 0 (before scaling):
    log_cdf_plus = plus_in - softplus(plus_in)
    # log probability for edge case of 255 (before scaling):
    log_one_minus_cdf_min = -softplus(min_in)

    # probability for all other cases
    cdf_delta = cdf_plus - cdf_min
    mid_in = inv_stdv * centered_y
    log_pdf_mid = mid_in - log_scales - 2. * softplus(mid_in)

    log_probs = np.where(
        y < -0.999, log_cdf_plus,
        np.where(
            y > 0.999, log_one_minus_cdf_min,
            np.where(cdf_delta > 1e-5,
                     np.log(np.maximum(cdf_delta, 1e-12)),
                     log_pdf_mid - np.log((num_class - 1) / 2))))

    log_probs = log_probs + log_softmax(logit_probs)
    return -np.sum(logsumexp(log_probs, axis=-1), axis=-1)

def inverse_fun(params, z):
    log_det = np.zeros(z.shape[0])
    idx = dim // 2
    lower, upper = z[:, :idx], z[:, idx:]

    out = f2_apply_fun(f2_params, upper).reshape(-1, dim // 2, 3 * K - 1)
    W, H, D = onp.array_split(out, 3, axis=2)
    W, H = nn.softmax(W, axis=2), nn.softmax(H, axis=2)
    W, H = 2 * B * W, 2 * B * H
    D = nn.softplus(D)
    lower, ld = unconstrained_RQS(lower, W, H, D, inverse=True, tail_bound=B)
    log_det += np.sum(ld, axis=1)

    out = f1_apply_fun(f1_params, lower).reshape(-1, dim // 2, 3 * K - 1)
    W, H, D = onp.array_split(out, 3, axis=2)
    W, H = nn.softmax(W, axis=2), nn.softmax(H, axis=2)
    W, H = 2 * B * W, 2 * B * H
    D = nn.softplus(D)
    upper, ld = unconstrained_RQS(upper, W, H, D, inverse=True, tail_bound=B)
    log_det += np.sum(ld, axis=1)

    return np.concatenate([lower, upper], axis=1), log_det.reshape((z.shape[0],))

def apply_fun(params, inputs, **kwargs):
    x, logdet = inputs
    out = jnp.tanh(x)
    tanh_logdet = -2 * (x + softplus(-2 * x) - jnp.log(2.))
    # logdet.shape = batch_shape + (num_blocks, in_factor, out_factor)
    # tanh_logdet.shape = batch_shape + (num_blocks x out_factor,)
    # so we need to reshape tanh_logdet to: batch_shape + (num_blocks, 1, out_factor)
    tanh_logdet = tanh_logdet.reshape(logdet.shape[:-2] + (1, logdet.shape[-1]))
    return out, logdet + tanh_logdet

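
# Illustrative identity check (not part of the flow above): the stable tanh
# log-determinant used here follows from
# log(1 - tanh(x)^2) = 2 * (log(2) - x - softplus(-2x)).
import jax.numpy as jnp
from jax.nn import softplus

x = jnp.linspace(-3.0, 3.0, 7)
naive = jnp.log1p(-jnp.tanh(x) ** 2)
stable = -2 * (x + softplus(-2 * x) - jnp.log(2.0))
assert jnp.allclose(naive, stable, atol=1e-4)
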
def __call__(self, x):
    x = MLP(
        output_dim=2 * self.output_dim,
        hidden_units=self.hidden_units,
        hidden_activation=partial(nn.leaky_relu, negative_slope=self.negative_slope),
    )(x)
    mean, log_std = jnp.split(x, 2, axis=1)
    std = nn.softplus(log_std) + 1e-5
    return mean, std

def __init__(self, config):
    super(DataGeneratorToyDatasetLinearCombinationWithSoftplus, self).__init__(config)
    np.random.seed(0)
    intermediate_state = 2 * np.prod(self.config["state_size"])
    self.matrix_a = np.random.normal(size=(intermediate_state, 2))
    self.matrix_b = np.random.normal(size=(np.prod(self.config["state_size"]), intermediate_state))

    t_steps = self.config["num_steps"]
    delta_t = self.config["delta_t"]
    frequency = 2 / (delta_t * 30)  # The 30 is to have a decent rotation at each time step
    trig_arg = frequency * delta_t
    rotation_matrix = np.array([[np.cos(trig_arg), -np.sin(trig_arg)],
                                [np.sin(trig_arg), np.cos(trig_arg)]])

    x_steps_train = np.zeros((t_steps + 1, self.n_points, 2))
    x_steps_test = np.zeros((t_steps + 1, self.n_points_test, 2))
    x_0_train, _ = make_moons(self.n_points, noise=.05)
    x_0_test, _ = make_moons(self.n_points_test, noise=.05)
    x_steps_train[0] = x_0_train
    x_steps_test[0] = x_0_test
    for i in range(t_steps):
        noise_train = np.random.normal(scale=0.03, size=[self.n_points, 2])
        noise_test = np.random.normal(scale=0.03, size=[self.n_points_test, 2])
        x_steps_train[i + 1] = np.einsum("ab,cb->ca", rotation_matrix, x_steps_train[i]) + noise_train
        x_steps_test[i + 1] = np.einsum("ab,cb->ca", rotation_matrix, x_steps_test[i]) + noise_test

    self.input_train_pca_np = x_steps_train
    self.input_test_pca_np = x_steps_test
    self.input_train_np = np.einsum(
        'ab,cdb->cda', self.matrix_b,
        softplus(np.einsum("ab,cdb->cda", self.matrix_a, self.input_train_pca_np)))
    self.input_test_np = np.einsum(
        'ab,cdb->cda', self.matrix_b,
        softplus(np.einsum("ab,cdb->cda", self.matrix_a, self.input_test_pca_np)))
    self.input_t = list(range(self.config["num_steps"] + 1))

def gaussian_process(params, predictors, target, test_predictors=None,
                     compute_marginal_likelihood=False):
    noise = softplus(params['noise'])
    ampl = softplus(params['amplitude'])
    scale = softplus(params['lengthscale'])
    target = target - np.mean(target)
    predictors = predictors / scale

    @jax.jit
    def exp_squared(x, y):
        return np.exp(-np.sum((x - y) ** 2))

    base_cov = build_covariance_matrix(exp_squared, predictors)
    train_cov = ampl * base_cov + np.eye(predictors.shape[0]) * (noise + 1.0e-6)
    LDLT = scipy.linalg.cholesky(train_cov, lower=True)
    kinvy = scipy.linalg.solve_triangular(
        LDLT.T, scipy.linalg.solve_triangular(LDLT, target, lower=True))
    if compute_marginal_likelihood:
        lbda = np.sum(-0.5 * np.dot(target.T, kinvy)
                      - np.sum(np.log(np.diag(LDLT)))
                      - predictors.shape[0] * 0.5 * np.log(2 * pi))
        return -(lbda - np.sum(-0.5 * np.log(2 * pi) - np.log(ampl) ** 2))

    if test_predictors is not None:
        test_predictors = test_predictors / scale
        cross_cov = ampl * build_covariance_matrix(exp_squared, predictors, test_predictors)
    else:
        cross_cov = base_cov
    LDLT = scipy.linalg.cholesky(train_cov, lower=True)
    mu = np.dot(cross_cov.T, kinvy) + np.mean(target)
    v = scipy.linalg.solve_triangular(LDLT, cross_cov, lower=True)
    cov = ampl * build_covariance_matrix(exp_squared, test_predictors) - np.dot(v.T, v)
    return mu, cov

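
# Hedged sketch (assumed convention, not from the source): since the hyperparameters
# are passed through softplus, unconstrained initial values can be obtained with the
# inverse softplus, log(expm1(v)), so the optimizer starts at the intended positive
# values. `inv_softplus` is a hypothetical helper introduced here for illustration.
import jax.numpy as jnp
from jax.nn import softplus

def inv_softplus(v):
    # inverse of softplus for v > 0
    return jnp.log(jnp.expm1(v))

init_params = {'noise': inv_softplus(0.1),
               'amplitude': inv_softplus(1.0),
               'lengthscale': inv_softplus(0.5)}
assert jnp.allclose(softplus(init_params['lengthscale']), 0.5, atol=1e-5)
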
def transition_fn(carry, t):
    x_prev = carry
    dyn_gamma = npyro.deterministic('dyn_gamma', nn.softplus(mu + x_prev))

    logs = logits((beliefs[0][t], beliefs[1][t]),
                  jnp.expand_dims(dyn_gamma, -1),
                  jnp.expand_dims(U, -2))

    npyro.sample('y', dist.CategoricalLogits(logs).mask(mask[t]))

    noise = npyro.sample('dw', dist.Normal(0., 1.))
    x_next = rho * x_prev + sigma * noise
    return x_next, None

def log_abs_det_jacobian(self, x, y, intermediates=None):
    # NB: because domain and codomain are two spaces with different dimensions, the determinant
    # of the Jacobian is not well-defined. Here we return `log_abs_det_jacobian` of `x` and the
    # flattened lower triangular part of `y`.

    # stick_breaking_logdet = log(y / r) = log(z_cumprod)  (modulo right shifted)
    z1m_cumprod = 1 - jnp.cumsum(y * y, axis=-1)
    # by taking diagonal=-2, we don't need to shift z_cumprod to the right
    # NB: diagonal=-2 works fine for (2 x 2) matrix, where we get an empty array
    z1m_cumprod_tril = matrix_to_tril_vec(z1m_cumprod, diagonal=-2)
    stick_breaking_logdet = 0.5 * jnp.sum(jnp.log(z1m_cumprod_tril), axis=-1)

    tanh_logdet = -2 * jnp.sum(x + softplus(-2 * x) - jnp.log(2.), axis=-1)
    return stick_breaking_logdet + tanh_logdet
