def calculate_likelihood(x, mu, a, theta, y):
    margin, was_retirement, bo5 = y

    sigma_obs = (1 - bo5) * theta["sigma_obs"] + bo5 * theta["sigma_obs_bo5"]

    # If it wasn't a retirement:
    margin_prob = norm.logpdf(margin, theta["a1"] * (a @ x) + theta["a2"],
                              sigma_obs)
    win_prob = jnp.log(expit((1 + theta["bo5_factor"] * bo5) * b * a @ x))

    # Otherwise:
    # Only the loser's skill matters:
    n_skills = x.shape[0]

    loser_x = x[n_skills // 2:]
    loser_mu = mu[n_skills // 2:]
    loser_a = -a[n_skills // 2:]

    loser_expected_skill = loser_a @ loser_mu
    loser_actual_skill = loser_a @ loser_x

    full_ret_factor = theta["ret_factor"] * (
        1 - theta["skill_factor"] * expit(b * loser_expected_skill))

    ret_prob = expit(full_ret_factor *
                     (loser_expected_skill - loser_actual_skill) +
                     theta["ret_intercept"])

    prob_retirement = jnp.log(ret_prob)
    prob_not_retirement = jnp.log(1 - ret_prob)

    return (1 - was_retirement) * (margin_prob + win_prob + prob_not_retirement
                                   ) + was_retirement * prob_retirement
def print_results(coef: jnp.ndarray, interval_size: float = 0.95) -> None:
    """
    Print the credible interval for the effect size with interval_size
    probability mass.
    """
    baseline_response = expit(coef[:, 0])
    response_with_calls = expit(coef[:, 0] + coef[:, 1])

    impact_on_probability = hpdi(response_with_calls - baseline_response,
                                 prob=interval_size)
    effect_of_gender = hpdi(coef[:, 2], prob=interval_size)

    print(
        f"There is a {interval_size * 100}% probability that calling customers "
        "increases the chance they'll make a purchase by "
        f"{(100 * impact_on_probability[0]):.2} to {(100 * impact_on_probability[1]):.2} percentage points."
    )
    print(
        f"There is a {interval_size * 100}% probability the effect of gender on the log odds of conversion "
        f"lies in the interval ({effect_of_gender[0]:.2}, {effect_of_gender[1]:.2f})."
        " Since this interval contains 0, the data do not provide strong evidence that gender affects the conversion rate."
    )
def rvs(self, *args, **kwargs):
    if self.is_logits:
        # convert logits to probs
        if args:
            args = list(args)
            args[0] = expit(args[0])
        else:
            kwargs['p'] = expit(kwargs['p'])
    return super(_bernoulli_gen, self).rvs(*args, **kwargs)
def args_maker():
    k, n, logit, loc = map(rng, shapes, dtypes)
    k = np.floor(np.abs(k))
    n = np.ceil(np.abs(n))
    p = expit(logit)
    loc = np.floor(loc)
    return [k, n, p, loc]
def sample(self, state, model_args, model_kwargs):
    i, x, x_pe, x_grad, _, mean_accept_prob, adapt_state, rng_key = state
    x_flat, unravel_fn = ravel_pytree(x)
    x_grad_flat, _ = ravel_pytree(x_grad)
    shape = jnp.shape(x_flat)
    rng_key, key_normal, key_bernoulli, key_accept = random.split(rng_key, 4)

    mass_sqrt_inv = adapt_state.mass_matrix_sqrt_inv

    x_grad_flat_scaled = (mass_sqrt_inv @ x_grad_flat if self._dense_mass
                          else mass_sqrt_inv * x_grad_flat)

    # Generate proposal y.
    z = adapt_state.step_size * random.normal(key_normal, shape)

    p = expit(-z * x_grad_flat_scaled)
    b = jnp.where(random.uniform(key_bernoulli, shape) < p, 1.0, -1.0)

    dx_flat = b * z
    dx_flat_scaled = (mass_sqrt_inv.T @ dx_flat if self._dense_mass
                      else mass_sqrt_inv * dx_flat)

    y_flat = x_flat + dx_flat_scaled

    y = unravel_fn(y_flat)
    y_pe, y_grad = jax.value_and_grad(self._potential_fn)(y)
    y_grad_flat, _ = ravel_pytree(y_grad)
    y_grad_flat_scaled = (mass_sqrt_inv @ y_grad_flat if self._dense_mass
                          else mass_sqrt_inv * y_grad_flat)

    log_accept_ratio = (x_pe - y_pe + jnp.sum(
        softplus(dx_flat * x_grad_flat_scaled) -
        softplus(-dx_flat * y_grad_flat_scaled)))
    accept_prob = jnp.clip(jnp.exp(log_accept_ratio), a_max=1.0)

    x, x_flat, pe, x_grad = jax.lax.cond(
        random.bernoulli(key_accept, accept_prob),
        (y, y_flat, y_pe, y_grad), identity,
        (x, x_flat, x_pe, x_grad), identity,
    )

    # do not update adapt_state after warmup phase
    adapt_state = jax.lax.cond(
        i < self._num_warmup,
        (i, accept_prob, (x,), adapt_state), lambda args: self._wa_update(*args),
        adapt_state, identity,
    )

    itr = i + 1
    n = jnp.where(i < self._num_warmup, itr, itr - self._num_warmup)
    mean_accept_prob = mean_accept_prob + (accept_prob - mean_accept_prob) / n

    return BarkerMHState(itr, x, pe, x_grad, accept_prob, mean_accept_prob,
                         adapt_state, rng_key)
def _rvs(self, n, p):
    if self.is_logits:
        p = expit(p)
    # use scipy samplers directly and put the samples on device later.
    # TODO: use util.binomial instead
    random_state = onp.random.RandomState(self._random_state)
    sample = random_state.binomial(n, p, self._size)
    return device_put(sample)
def log_abs_det_jacobian(self, x, y, intermediates=None):
    # Ref: https://mc-stan.org/docs/2_19/reference-manual/simplex-transform-section.html
    # |det|(J) = Product(y * (1 - z))
    x = x - jnp.log(x.shape[-1] - jnp.arange(x.shape[-1]))
    z = jnp.clip(expit(x), a_min=jnp.finfo(x.dtype).tiny)
    # XXX we use the identity 1 - z = z * exp(-x) to not worry about
    # the case z ~ 1
    return jnp.sum(jnp.log(y[..., :-1] * z) - x, axis=-1)
def __init__(self, predictor, cutpoints, validate_args=None):
    predictor, self.cutpoints = promote_shapes(
        jnp.expand_dims(predictor, -1), cutpoints)
    self.predictor = predictor[..., 0]
    cumulative_probs = expit(cutpoints - predictor)
    # add two boundary points 0 and 1
    pad_width = [(0, 0)] * (jnp.ndim(cumulative_probs) - 1) + [(1, 1)]
    cumulative_probs = jnp.pad(cumulative_probs, pad_width,
                               constant_values=(0, 1))
    probs = cumulative_probs[..., 1:] - cumulative_probs[..., :-1]
    super(OrderedLogistic, self).__init__(probs, validate_args=validate_args)
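
# A quick numeric sketch (assumed, not from the original source): the ordered
# logistic construction above turns cutpoints into category probabilities by
# differencing cumulative logistic CDF values, so the probabilities are
# non-negative and sum to one.
import numpy as np
from scipy.special import expit

predictor = 0.3
cutpoints = np.array([-1.0, 0.5, 2.0])
cumulative = np.concatenate(([0.0], expit(cutpoints - predictor), [1.0]))
probs = np.diff(cumulative)
print(probs, probs.sum())  # four category probabilities, summing to 1.0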
def _inverse(self, y):
    y = y - jnp.expand_dims(self.anchor_point, -1)
    s = expit(y)
    # x0 = s0, x1 = s1 - s0, x2 = s2 - s1, ..., xn = 1 - s[n-1]
    # add two boundary points 0 and 1
    pad_width = [(0, 0)] * (jnp.ndim(s) - 1) + [(1, 1)]
    s = jnp.pad(s, pad_width, constant_values=(0, 1))
    x = s[..., 1:] - s[..., :-1]
    return x
def sigmoid(x): r"""Sigmoid activation function. Computes the element-wise function: .. math:: \mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}} """ return expit(x)
def logistic_normal_integral_approx(mu, var):
    """
    Approximates the logistic normal integral, E[logit^{-1}(X)],
    where X ~ N(mu, var).
    """
    gamma = np.sqrt(1 + (np.pi * (var / 8)))
    return expit(mu / gamma)
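
# A minimal sanity check (assumed, not from the original source): compare the
# closed-form approximation above against a Monte Carlo estimate of
# E[sigmoid(X)] for X ~ N(mu, var), using NumPy and scipy.special.expit.
import numpy as np
from scipy.special import expit

rng = np.random.default_rng(0)
mu, var = 0.5, 2.0
samples = rng.normal(mu, np.sqrt(var), size=200_000)
mc_estimate = expit(samples).mean()
approx = expit(mu / np.sqrt(1 + np.pi * var / 8))
# The two values typically agree to roughly two decimal places.
print(f"Monte Carlo: {mc_estimate:.4f}, approximation: {approx:.4f}")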
def sigmoid(x): """Sigmoid activation function. :param x: Input value :type x: :obj:`float` or :obj:`numpy.array` :return: Sigmoid activated value: :math:`sigmoid(x) = \\dfrac{1}{1 + e^{-x}}` :rtype: :obj:`float` """ return expit(x)
def calculate_likelihood(x, mu, a, theta, y):
    margin = y[0]

    margin_prob = norm.logpdf(margin, theta["a1"] * (a @ x) + theta["a2"],
                              theta["sigma_obs"])
    win_prob = jnp.log(expit(b * a @ x))

    return margin_prob + win_prob
def compute_probs(y, x, w):
    """returns P(y_generated==y | x, w)

    y and x are data batches. w is a single parameter array of shape
    (num_features,)"""
    y = ((y - 1 / 2) * 2).astype(np.int32)
    logits = x @ w
    prob_y = special.expit(logits * y)
    return prob_y
def sigmoid(x: Array) -> Array:
    r"""Sigmoid activation function.

    Computes the element-wise function:

    .. math::
      \mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}

    Args:
      x : input array
    """
    return expit(x)
def loglikelihood(y, x, w):
    """compute log p(y | x, w) for a single parameter w of shape
    (num_features,) and a batch of data (y, x) of shape (m,) and
    (m, num_features)

    log p(y | x, w) = sum_i(log p(yi | xi, w))
    """
    y = ((y - 1 / 2) * 2).astype(np.int32)
    logits = x @ w
    prob_y = special.expit(logits * y)
    return np.sum(np.log(prob_y))
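
# A small usage sketch (assumed, not from the original source), reusing the
# loglikelihood function above on synthetic logistic-regression data; `np` is
# NumPy and `special` is scipy.special, as in the snippet.
import numpy as np
from scipy import special

rng = np.random.default_rng(0)
m, num_features = 100, 3
x = rng.normal(size=(m, num_features))
w_true = np.array([1.0, -2.0, 0.5])
y = (rng.uniform(size=m) < special.expit(x @ w_true)).astype(np.float64)

# The log likelihood at the generating parameters should typically exceed the
# value at w = 0 (where every observation gets probability 0.5).
print(loglikelihood(y, x, w_true), loglikelihood(y, x, np.zeros(num_features)))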
def transition_fn(carry, y):
    first_capture_mask, z = carry
    phi_gamma_t = numpyro.sample("phi_gamma", dist.Normal(0.0, 10.0))
    phi_t = expit(phi_beta + phi_gamma_t)
    with numpyro.plate("animals", N, dim=-1):
        with handlers.mask(mask=first_capture_mask):
            mu_z_t = first_capture_mask * phi_t * z + (1 - first_capture_mask)
            # NumPyro exactly sums out the discrete states z_t.
            z = numpyro.sample("z", dist.Bernoulli(dist.util.clamp_probs(mu_z_t)))
            mu_y_t = rho * z
            numpyro.sample("y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)),
                           obs=y)

    first_capture_mask = first_capture_mask | y.astype(bool)

    return (first_capture_mask, z), None
def transition_fn(carry, y):
    first_capture_mask, z = carry
    with handlers.reparam(config={"phi_logit": LocScaleReparam(0)}):
        phi_logit_t = numpyro.sample("phi_logit",
                                     dist.Normal(phi_logit_mean, phi_sigma))
    phi_t = expit(phi_logit_t)
    with numpyro.plate("animals", N, dim=-1):
        with handlers.mask(mask=first_capture_mask):
            mu_z_t = first_capture_mask * phi_t * z + (1 - first_capture_mask)
            # NumPyro exactly sums out the discrete states z_t.
            z = numpyro.sample("z", dist.Bernoulli(dist.util.clamp_probs(mu_z_t)))
            mu_y_t = rho * z
            numpyro.sample("y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)),
                           obs=y)

    first_capture_mask = first_capture_mask | y.astype(bool)

    return (first_capture_mask, z), None
def unbounded_to_lower_and_upper_bounded(lower, upper):
    """Construct transform from reals to bounded interval.

    Args:
        lower (float): Lower-bound of image of transform.
        upper (float): Upper-bound of image of transform.
    """
    return ElementwiseMonotonicTransform(
        forward=lambda u: lower + (upper - lower) * expit(
            np.asarray(u, np.float64)),
        backward=lambda x: logit(
            (np.asarray(x, np.float64) - lower) / (upper - lower)),
        domain=reals,
        image=RealInterval(lower, upper),
    )
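
# A brief round-trip sketch (assumed, not from the original source): the
# forward map squashes the real line into (lower, upper) with the logistic
# sigmoid and the backward map inverts it with the logit, written in plain
# NumPy/SciPy so it does not depend on ElementwiseMonotonicTransform.
import numpy as np
from scipy.special import expit, logit

lower, upper = -1.0, 3.0
u = np.linspace(-5.0, 5.0, 11)
x = lower + (upper - lower) * expit(u)         # forward: reals -> (lower, upper)
u_back = logit((x - lower) / (upper - lower))  # backward: (lower, upper) -> reals
assert np.allclose(u, u_back)                  # round-trip recovers the input
assert np.all((x > lower) & (x < upper))       # image stays inside the interval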
def calculate_marginal_lik(x, mu, a, cov_mat, theta, y):
    margin, was_retirement, bo5 = y

    sigma_obs = (1 - bo5) * theta["sigma_obs"] + bo5 * theta["sigma_obs_bo5"]

    latent_mean, latent_var = weighted_sum(x, cov_mat, a)

    # If it wasn't a retirement:
    margin_prob = norm.logpdf(
        margin,
        theta["a1"] * (latent_mean) + theta["a2"],
        jnp.sqrt(sigma_obs**2 + theta["a1"]**2 * latent_var),
    )

    win_prob = jnp.log(
        logistic_normal_integral_approx(
            (1 + theta["bo5_factor"] * bo5) * b * latent_mean,
            (1 + theta["bo5_factor"] * bo5)**2 * b**2 * latent_var,
        ))

    n_skills = x.shape[0]

    # Otherwise:
    # Only the loser's skill matters:
    loser_x = x[n_skills // 2:]
    loser_mu = mu[n_skills // 2:]
    loser_a = -a[n_skills // 2:]
    loser_cov_mat = cov_mat[n_skills // 2:, n_skills // 2:]

    loser_actual_mean, loser_actual_var = weighted_sum(loser_x, loser_cov_mat,
                                                       loser_a)
    loser_expected_skill = loser_a @ loser_mu

    full_ret_factor = theta["ret_factor"] * (
        1 - theta["skill_factor"] * expit(b * loser_expected_skill))

    ret_prob = logistic_normal_integral_approx(
        full_ret_factor * (loser_expected_skill - loser_actual_mean) +
        theta["ret_intercept"],
        full_ret_factor**2 * loser_actual_var,
    )

    prob_retirement = jnp.log(ret_prob)
    prob_not_retirement = jnp.log(1 - ret_prob)

    return (1 - was_retirement) * (win_prob + margin_prob + prob_not_retirement
                                   ) + (was_retirement * prob_retirement)
def epoch_step(opt_state, key):

    def train_step(opt_state, batch):
        opt_state, loss = self.update(next(itercount), opt_state, batch)
        return opt_state, loss

    batches = self._make_minibatches(observations, batch_size, key)
    opt_state, losses = scan(train_step, opt_state, batches)

    params = get_params(opt_state)
    mixing_coeffs, probs_logits = params
    probs = expit(probs_logits)
    self.model = (softmax(mixing_coeffs), probs)
    self._probs = probs

    return opt_state, (losses.mean(), *params,
                       self.responsibilities(observations))
def glmm(dept, male, applications, admit=None):
    v_mu = numpyro.sample('v_mu', dist.Normal(0, jnp.array([4., 1.])))

    sigma = numpyro.sample('sigma', dist.HalfNormal(jnp.ones(2)))
    L_Rho = numpyro.sample('L_Rho', dist.LKJCholesky(2, concentration=2))
    scale_tril = sigma[..., jnp.newaxis] * L_Rho
    # non-centered parameterization
    num_dept = len(np.unique(dept))
    z = numpyro.sample('z', dist.Normal(jnp.zeros((num_dept, 2)), 1))
    v = jnp.dot(scale_tril, z.T).T

    logits = v_mu[0] + v[dept, 0] + (v_mu[1] + v[dept, 1]) * male
    if admit is None:
        # we use a Delta site to record probs for predictive distribution
        probs = expit(logits)
        numpyro.sample('probs', dist.Delta(probs), obs=probs)
    numpyro.sample('admit', dist.Binomial(applications, logits=logits),
                   obs=admit)
def _discrete_gibbs_proposal_body_fn(z_init_flat, unravel_fn, pe_init,
                                     potential_fn, idx, i, val):
    rng_key, z, pe, log_weight_sum = val
    rng_key, rng_transition = random.split(rng_key)
    proposal = jnp.where(i >= z_init_flat[idx], i + 1, i)
    z_new_flat = ops.index_update(z_init_flat, idx, proposal)
    z_new = unravel_fn(z_new_flat)
    pe_new = potential_fn(z_new)
    log_weight_new = pe_init - pe_new
    # Handles the NaN case...
    log_weight_new = jnp.where(jnp.isfinite(log_weight_new), log_weight_new,
                               -jnp.inf)
    # transition_prob = e^weight_new / (e^weight_logsumexp + e^weight_new)
    transition_prob = expit(log_weight_new - log_weight_sum)
    z, pe = cond(random.bernoulli(rng_transition, transition_prob),
                 (z_new, pe_new), identity, (z, pe), identity)
    log_weight_sum = jnp.logaddexp(log_weight_new, log_weight_sum)
    return rng_key, z, pe, log_weight_sum
def g(params): """differentiable piece of objective in η problem.""" y = params[1:] if log_transform: M = utils.M(self.n, t, np.exp(y)) else: M = utils.M(self.n, t, y) L = self.C @ M ξ = self.mu0 * L.sum(1) if folded: ξ = utils.fold(ξ) else: r = expit(params[0]) ξ = (1 - r) * ξ + r * self.AM_freq @ ξ loss_term = loss(np.squeeze(ξ), x) y_delta = y - y_ref ridge_term = (ridge_penalty / 2) * (y_delta.T @ Γ @ y_delta) return loss_term + ridge_term
def loss_fn(self, params, batch):
    '''
    Calculates the expected mean negative log likelihood.

    Parameters
    ----------
    params : tuple
        Consists of mixing coefficients and probabilities of the Bernoulli
        distribution respectively.

    batch : array
        The subset of observations

    Returns
    -------
    * float
        Negative log likelihood
    '''
    mixing_coeffs, probs = params
    self.model = (softmax(mixing_coeffs), expit(probs))
    return -self.expected_log_likelihood(batch) / len(batch)
def _newton_iteration(y_train, K, f):
    pi = expit(f)
    W = pi * (1 - pi)
    # Line 5
    W_sr = np.sqrt(W)
    W_sr_K = W_sr[:, np.newaxis] * K
    B = np.eye(W.shape[0]) + W_sr_K * W_sr
    L = cholesky(B, lower=True)
    # Line 6
    b = W * f + (y_train - pi)
    # Line 7
    a = b - W_sr * cho_solve((L, True), W_sr_K.dot(b))
    # Line 8
    f = K.dot(a)

    # Line 10: Compute log marginal likelihood in loop and use as
    # convergence criterion
    lml = -0.5 * a.T.dot(f) \
        - np.log1p(np.exp(-(y_train * 2 - 1) * f)).sum() \
        - np.log(np.diag(L)).sum()

    return lml, f, (pi, W_sr, L, b, a)
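
# A hedged driver sketch (assumed, not part of the original source): the
# function above performs one Newton step of the Laplace approximation for
# binary GP classification (Algorithm 3.1 in Rasmussen & Williams), so the
# mode can be found by repeating it until the log marginal likelihood stops
# improving. `find_mode` is a hypothetical helper written for illustration.
import numpy as np

def find_mode(y_train, K, max_iter=100, tol=1e-10):
    f = np.zeros_like(y_train, dtype=float)  # start at the prior mean
    last_lml = -np.inf
    for _ in range(max_iter):
        lml, f, _ = _newton_iteration(y_train, K, f)
        if lml - last_lml < tol:  # lml typically increases monotonically
            break
        last_lml = lml
    return f, lml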
def _uniform_transition_kernel(current_tree, new_tree):
    # This function computes transition prob for subtrees (ref [2], section A.3.1).
    # e^new_weight / (e^new_weight + e^current_weight)
    transition_prob = expit(new_tree.weight - current_tree.weight)
    return transition_prob
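
# A tiny identity check (assumed, not from the original source): expit(a - b)
# equals e^a / (e^a + e^b), which is why the subtree transition probability can
# be computed from the weight difference without exponentiating the weights.
import numpy as np
from scipy.special import expit

a, b = 2.3, -0.7
print(np.isclose(expit(a - b), np.exp(a) / (np.exp(a) + np.exp(b))))  # True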
def args_maker():
    x, logit, loc = map(rng, shapes, dtypes)
    x = onp.floor(x)
    p = expit(logit)
    loc = onp.floor(loc)
    return [x, p, loc]
def fit_sgd(self, observations, batch_size, rng_key=None, optimizer=None,
            num_epochs=1):
    '''
    Fits the model using a gradient descent algorithm with the given
    hyperparameters.

    Parameters
    ----------
    observations : array
        The observation sequences which the Bernoulli Mixture Model is
        trained on

    batch_size : int
        The size of the batch

    rng_key : array
        Random key of shape (2,) and dtype uint32

    optimizer : jax.experimental.optimizers.Optimizer
        Optimizer to be used

    num_epochs : int
        The number of epochs the training process takes

    Returns
    -------
    * array
        Mean loss values found per epoch

    * array
        Mixing coefficients found per epoch

    * array
        Probabilities of Bernoulli distribution found per epoch

    * array
        Responsibilities found per epoch
    '''
    global opt_init, opt_update, get_params

    if rng_key is None:
        rng_key = PRNGKey(0)

    if optimizer is not None:
        opt_init, opt_update, get_params = optimizer

    opt_state = opt_init((softmax(self.mixing_coeffs), logit(self.probs)))
    itercount = itertools.count()

    def epoch_step(opt_state, key):

        def train_step(opt_state, batch):
            opt_state, loss = self.update(next(itercount), opt_state, batch)
            return opt_state, loss

        batches = self._make_minibatches(observations, batch_size, key)
        opt_state, losses = scan(train_step, opt_state, batches)

        params = get_params(opt_state)
        mixing_coeffs, probs_logits = params
        probs = expit(probs_logits)
        self.model = (softmax(mixing_coeffs), probs)
        self._probs = probs

        return opt_state, (losses.mean(), *params,
                           self.responsibilities(observations))

    epochs = split(rng_key, num_epochs)
    opt_state, history = scan(epoch_step, opt_state, epochs)

    params = get_params(opt_state)
    mixing_coeffs, probs_logits = params
    probs = expit(probs_logits)
    self.model = (softmax(mixing_coeffs), probs)
    self._probs = probs

    return history
def chance_constraint(hu, epsilon, gamma):
    # Pr(h(u) >= 0) >= (1 - epsilon)
    # take the mean over the samples (M)
    return jnp.mean(expit(hu / gamma), axis=1) - (1 - epsilon)
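
# A small illustration (assumed, not from the original source): expit(hu / gamma)
# is a smooth surrogate for the indicator 1[h(u) >= 0], so its mean over the
# sample axis approximates Pr(h(u) >= 0); smaller gamma sharpens the surrogate.
import jax.numpy as jnp
from jax import random
from jax.scipy.special import expit

hu = random.normal(random.PRNGKey(0), (4, 1000)) + 1.0  # h(u) samples, shape (N, M)
empirical = jnp.mean(hu >= 0, axis=1)                    # exact sample frequency
for gamma in (1.0, 0.1, 0.01):
    smoothed = jnp.mean(expit(hu / gamma), axis=1)
    # As gamma shrinks, the smoothed estimate approaches the sample frequency.
    print(gamma, jnp.max(jnp.abs(smoothed - empirical)))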