def rbpf_optimal(current_config, xt, params, nparticles=100):
    """
    Rao-Blackwell Particle Filter using optimal proposal
    """
    key, mu_t, Sigma_t, weights_t, st = current_config
    key_sample, key_state, key_proposal, key_next, key_reindex = random.split(key, 5)
    keys = random.split(key_sample, nparticles)

    st = random.categorical(key_state, logit(params.transition_matrix[st, :]))
    mu_t, Sigma_t, weights_t, proposal = rbpf_step_optimal_vec(
        keys, weights_t, st, mu_t, Sigma_t, xt, params)

    indices = jnp.arange(nparticles)
    pi = random.choice(key_reindex, indices, shape=(nparticles,),
                       p=weights_t, replace=True)

    # Obtain optimal proposal distribution
    proposal_samp = proposal[pi, :]
    # Use a fresh key here rather than reusing the already-split `key`.
    st = random.categorical(key_proposal, logit(proposal_samp))

    mu_t = mu_t[pi, st, ...]
    Sigma_t = Sigma_t[pi, st, ...]
    weights_t = jnp.ones(nparticles) / nparticles

    return (key_next, mu_t, Sigma_t, weights_t, st), (mu_t, Sigma_t, weights_t, st, proposal_samp)
def log_likelihood(particles, test_results, groups,
                   log_specificity, log_1msensitivity):
    """Computes individual (parallel) log_likelihood of k_groups test results.

    Args:
      particles: np.ndarray<bool>[n_particles, n_patients]. Each one is a
        possible scenario of the disease status of n patients.
      test_results: np.ndarray<bool>[n_groups] the results given by the wet lab
        for each of the tested groups.
      groups: np.ndarray<bool>[num_groups, num_patients] the definition of the
        groups that were tested.
      log_specificity: np.ndarray. Depending on the configuration, it can be an
        array of size one or more if we have different specificities per group
        size.
      log_1msensitivity: np.ndarray. Depending on the configuration, it can be
        an array of size one or more if we have different sensitivities per
        group size.

    Returns:
      The log likelihood of the particles given the test results.
    """
    positive_in_groups = np.dot(groups, np.transpose(particles)) > 0
    group_sizes = np.sum(groups, axis=1)
    log_specificity = utils.select_from_sizes(log_specificity, group_sizes)
    log_1msensitivity = utils.select_from_sizes(log_1msensitivity, group_sizes)
    logit_specificity = special.logit(np.exp(log_specificity))
    logit_sensitivity = -special.logit(np.exp(log_1msensitivity))
    gamma = log_1msensitivity - log_specificity
    add_logits = logit_specificity + logit_sensitivity
    ll = np.sum(
        positive_in_groups * (gamma + test_results * add_logits)[:, np.newaxis],
        axis=0)
    return ll + np.sum(log_specificity - test_results * logit_specificity)
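# Verification sketch (not from the source): the logit reparameterization in
# log_likelihood collapses to the four textbook test-outcome log probabilities.
# `single_group_ll` is a hypothetical single-group, single-particle version of
# the expression above, written without the utils.select_from_sizes dependency.
import numpy as np
from scipy import special

spec, sens = 0.97, 0.85
log_spec, log_1msens = np.log(spec), np.log(1 - sens)
gamma = log_1msens - log_spec
add_logits = special.logit(spec) - special.logit(1 - sens)

def single_group_ll(positive_in_group, test_result):
    return (positive_in_group * (gamma + test_result * add_logits)
            + log_spec - test_result * special.logit(spec))

assert np.isclose(single_group_ll(0, 0), np.log(spec))      # true negative
assert np.isclose(single_group_ll(0, 1), np.log(1 - spec))  # false positive
assert np.isclose(single_group_ll(1, 0), np.log(1 - sens))  # false negative
assert np.isclose(single_group_ll(1, 1), np.log(sens))      # true positive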
def __init__(self,
             num_patients: int,
             num_tests_per_cycle: int,
             max_group_size: int,
             prior_infection_rate: np.ndarray,
             prior_specificity: np.ndarray,
             prior_sensitivity: np.ndarray):
    self.num_patients = num_patients
    self.num_tests_per_cycle = num_tests_per_cycle
    self.max_group_size = max_group_size
    self.prior_infection_rate = np.atleast_1d(prior_infection_rate)
    self.prior_specificity = np.atleast_1d(prior_specificity)
    self.prior_sensitivity = np.atleast_1d(prior_sensitivity)
    self.log_prior_specificity = np.log(self.prior_specificity)
    self.log_prior_1msensitivity = np.log(1 - self.prior_sensitivity)
    self.logit_prior_sensitivity = special.logit(self.prior_sensitivity)
    self.logit_prior_specificity = special.logit(self.prior_specificity)

    self.curr_cycle = 0
    self.past_groups = None
    self.past_test_results = None
    self.groups_to_test = None
    self.particle_weights = None
    self.particles = None
    self.to_clear_positives = None
    self.all_cleared = False
    self.marginals = {}
    self.reset()  # Initializes the attributes above.
def __init__(self, mixing_coeffs, probs, class_priors, C, threshold=1e-10):
    self.mixing_coeffs = mixing_coeffs
    self.probs = probs
    self.class_priors = class_priors
    self.model = (logit(mixing_coeffs), logit(probs))
    self.num_of_classes = C
    self.threshold = threshold
    self.log_threshold = jnp.log(threshold)
def sample(p, key, num_samples=1):
    """
    Generate Binomial Concrete samples (returned on the logit scale)

    :param p: Binary relaxed params (interpreted as Bernoulli probabilities) (jax.numpy array)
    :param key: PRNG key
    :param num_samples: number of samples
    """
    tol = 1e-7
    p = np.clip(p, tol, 1 - tol)
    logit_p = logit(p)
    u = random.uniform(key, shape=(num_samples, *p.shape))
    logit_u = logit(np.clip(u, tol, 1 - tol))  # logit of a uniform is a logistic sample
    return logit_p + logit_u
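# Usage sketch (assumption, not from the source): `sample` returns pre-sigmoid
# logits. Thresholding them at zero recovers exact Bernoulli(p) draws, since
# P(logit(p) + logit(U) > 0) = P(U > 1 - p) = p for U ~ Uniform(0, 1).
import jax.numpy as np
from jax import random

key = random.PRNGKey(0)
p = np.array([0.1, 0.5, 0.9])
logits = sample(p, key, num_samples=10000)  # shape (10000, 3)
print((logits > 0).mean(axis=0))            # empirical rates, approximately p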
def inverse_fun(params, inputs, **kwargs):
    if clip_before_logit:
        inputs = np.clip(inputs, 1e-5, 1 - 1e-5)
    outputs = spys.logit(inputs)
    log_det_jacobian = -np.log(inputs - np.square(inputs)).sum(-1)
    return outputs, log_det_jacobian
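# Sanity-check sketch (assumption): the log_det_jacobian above is the summed
# log-derivative of logit, d/dx logit(x) = 1 / (x - x^2), so it must match the
# value obtained through automatic differentiation.
import jax
import jax.numpy as np
from jax.scipy import special as spys

x = np.array([0.2, 0.7])
manual = -np.log(x - np.square(x)).sum(-1)
auto = np.log(jax.vmap(jax.grad(spys.logit))(x)).sum(-1)
assert np.allclose(manual, auto)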
def rbpf(current_config, xt, params, nparticles=100):
    """
    Rao-Blackwell Particle Filter using prior as proposal
    """
    key, mu_t, Sigma_t, weights_t, st = current_config
    key_sample, key_state, key_next, key_reindex = random.split(key, 4)
    keys = random.split(key_sample, nparticles)

    st = random.categorical(key_state, logit(params.transition_matrix[st, :]))
    mu_t, Sigma_t, weights_t, Ltk = rbpf_step_vec(keys, weights_t, st, mu_t,
                                                  Sigma_t, xt, params)
    weights_t = weights_t / weights_t.sum()

    indices = jnp.arange(nparticles)
    pi = random.choice(key_reindex, indices, shape=(nparticles,),
                       p=weights_t, replace=True)
    st = st[pi]
    mu_t = mu_t[pi, ...]
    Sigma_t = Sigma_t[pi, ...]
    weights_t = jnp.ones(nparticles) / nparticles

    return (key_next, mu_t, Sigma_t, weights_t, st), (mu_t, Sigma_t, weights_t, st, Ltk)
def rbpf_step(key, weight_t, st, mu_t, Sigma_t, xt, params):
    log_p_next = logit(params.transition_matrix[st])
    k = random.categorical(key, log_p_next)
    mu_t, Sigma_t, Ltk = kf_update(mu_t, Sigma_t, k, xt, params)
    weight_t = weight_t * Ltk
    return mu_t, Sigma_t, weight_t, Ltk
def draw_state(val, key, params):
    """
    Simulate one step of a system that evolves as

        z_t = A z_{t-1} + B_k + eps,  eps ~ N(0, Q).

    Parameters
    ----------
    val: tuple (int, jnp.array)
        (latent value of system, state value of system)
    key: PRNGKey
    params: PRBPFParamsDiscrete
    """
    latent_old, state_old = val
    # Split once up front so no key is used twice.
    key_latent, key_state, key_obs = random.split(key, 3)

    probabilities = params.transition_matrix[latent_old, :]
    logits = logit(probabilities)
    latent_new = random.categorical(key_latent, logits)

    state_new = params.A @ state_old + params.B[latent_new, :]
    state_new = random.multivariate_normal(key_state, state_new, params.Q)
    obs_new = random.multivariate_normal(key_obs, params.C @ state_new, params.R)

    return (latent_new, state_new), (latent_new, state_new, obs_new)
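# Usage sketch (assumption): roll the one-step simulator forward with lax.scan.
# `Params` is a hypothetical stand-in for PRBPFParamsDiscrete with the fields
# draw_state reads; the values below are arbitrary.
from collections import namedtuple
from functools import partial
import jax.numpy as jnp
from jax import lax, random

Params = namedtuple("Params", ["A", "B", "Q", "C", "R", "transition_matrix"])
params = Params(A=jnp.eye(2), B=jnp.zeros((3, 2)), Q=0.1 * jnp.eye(2),
                C=jnp.eye(2), R=0.1 * jnp.eye(2),
                transition_matrix=jnp.full((3, 3), 1.0 / 3))
keys = random.split(random.PRNGKey(0), 50)
init = (jnp.int32(0), jnp.zeros(2))  # (initial latent index, initial state)
_, (latents, states, obs) = lax.scan(partial(draw_state, params=params), init, keys)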
def model(data):
    p = numpyro.sample('p', dist.Beta(1., 1.))
    if with_logits:
        logits = logit(p)
        numpyro.sample('obs', dist.Binomial(data['n'], logits=logits),
                       obs=data['x'])
    else:
        numpyro.sample('obs', dist.Binomial(data['n'], probs=p), obs=data['x'])
def log_prior(particles, base_infection_rate):
    """Computes log of prior probability of state using infection rate."""
    # base_infection_rate can be either a single number shared by all patients
    # or an array with one rate per patient.
    if np.size(base_infection_rate) == 1:  # only one rate
        # sum_i x_i * logit(rate) + n_patients * log(1 - rate); the count must
        # be the number of patients (last axis), not the number of particles.
        return (np.sum(particles, axis=-1) * special.logit(base_infection_rate)
                + particles.shape[-1] * np.log(1 - base_infection_rate))
    elif base_infection_rate.shape[0] == particles.shape[-1]:  # prior per patient
        return np.sum(
            particles * special.logit(base_infection_rate)[np.newaxis, :] +
            np.log(1 - base_infection_rate)[np.newaxis, :],
            axis=-1)
    else:
        raise ValueError("Vector of prior probabilities is not of correct size")
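# Consistency sketch (assumption): with the per-patient count taken from the
# last axis as above, a scalar rate and a constant per-patient vector of the
# same rate must produce identical log priors.
import numpy as np
from scipy import special

particles = np.array([[0, 1, 1], [0, 0, 0]], dtype=bool)
rate = 0.05
scalar_lp = log_prior(particles, rate)
vector_lp = log_prior(particles, np.full(particles.shape[-1], rate))
assert np.allclose(scalar_lp, vector_lp)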
def model_3(capture_history, sex):
    N, T = capture_history.shape
    phi_mean = numpyro.sample("phi_mean", dist.Uniform(0.0, 1.0))  # mean survival probability
    phi_logit_mean = logit(phi_mean)
    # controls temporal variability of survival probability
    phi_sigma = numpyro.sample("phi_sigma", dist.Uniform(0.0, 10.0))
    rho = numpyro.sample("rho", dist.Uniform(0.0, 1.0))  # recapture probability

    def transition_fn(carry, y):
        first_capture_mask, z = carry
        with handlers.reparam(config={"phi_logit": LocScaleReparam(0)}):
            phi_logit_t = numpyro.sample("phi_logit",
                                         dist.Normal(phi_logit_mean, phi_sigma))
        phi_t = expit(phi_logit_t)
        with numpyro.plate("animals", N, dim=-1):
            with handlers.mask(mask=first_capture_mask):
                mu_z_t = first_capture_mask * phi_t * z + (1 - first_capture_mask)
                # NumPyro exactly sums out the discrete states z_t.
                z = numpyro.sample("z", dist.Bernoulli(dist.util.clamp_probs(mu_z_t)))
                mu_y_t = rho * z
                numpyro.sample("y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y)

        first_capture_mask = first_capture_mask | y.astype(bool)
        return (first_capture_mask, z), None

    z = jnp.ones(N, dtype=jnp.int32)
    # we use this mask to eliminate extraneous log probabilities
    # that arise for a given individual before its first capture.
    first_capture_mask = capture_history[:, 0].astype(bool)
    # NB swapaxes: we move the time dimension of `capture_history` to the front
    # so we can scan over it.
    scan(transition_fn, (first_capture_mask, z),
         jnp.swapaxes(capture_history[:, 1:], 0, 1))
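# Usage sketch (assumption): draw one prior-predictive trace from model_3 by
# seeding the model; the synthetic data shapes are arbitrary (5 animals,
# 4 capture occasions, everyone captured on the first occasion).
import jax.numpy as jnp
from numpyro import handlers

capture_history = jnp.zeros((5, 4)).at[:, 0].set(1.0)
sex = jnp.zeros(5)
with handlers.seed(rng_seed=0):
    model_3(capture_history, sex)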
def conditional_sample(p, y, key):
    """
    Generate conditional Binary relaxed samples

    :param p: Binary relaxed params (interpreted as Bernoulli probabilities) (jax.numpy array)
    :param y: Conditioning parameters (jax.numpy array)
    :param key: PRNG key
    """
    tol = 1e-7
    p = np.clip(p, tol, 1 - tol)
    v = random.uniform(key, shape=y.shape)
    v_prime = (v * p + (1 - p)) * y + (v * (1 - p)) * (1 - y)
    v_prime = np.clip(v_prime, tol, 1 - tol)
    logit_v = logit(v_prime)
    logit_p = logit(p)
    return logit_p + logit_v
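# Property sketch (assumption): the conditional sampler returns logits whose
# sign agrees with the conditioning outcome y (almost surely), because
# v_prime > 1 - p exactly when y = 1, and logit(p) + logit(1 - p) = 0.
import jax.numpy as np
from jax import random

p = np.array([0.3, 0.3])
y = np.array([1.0, 0.0])
samples = conditional_sample(p, y, random.PRNGKey(1))
assert np.all((samples > 0) == y.astype(bool))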
def conditional_sample(p, y, temperature, key):
    """
    Generate conditional Binomial Concrete sample

    :param p: Binomial Concrete params (interpreted as Bernoulli probabilities) (jax.numpy array)
    :param y: Conditioning parameters (jax.numpy array)
    :param temperature: temperature parameter
    :param key: PRNG key
    """
    tol = 1e-7
    p = np.clip(p, tol, 1 - tol)
    v = random.uniform(key, shape=y.shape)
    v_prime = (v * p + (1 - p)) * y + (v * (1 - p)) * (1 - y)
    v_prime = np.clip(v_prime, tol, 1 - tol)
    logit_v = logit(v_prime)
    logit_p = logit(p)
    return nn.sigmoid((logit_p + logit_v) / (temperature + tol))
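# Usage sketch (assumption): at a very low temperature the relaxed conditional
# sample is driven toward the conditioning outcome itself.
import jax.numpy as np
from jax import random

p = np.array([0.3, 0.3])
y = np.array([1.0, 0.0])
z = conditional_sample(p, y, temperature=1e-3, key=random.PRNGKey(2))
print(z)  # approximately [1., 0.]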
def fit_em(self, observations, targets, num_of_iters=10):
    '''
    Fits the model using the EM algorithm.

    Parameters
    ----------
    observations : array
        Dataset
    targets : array
        Ground-truth class labels
    num_of_iters : int
        The number of iterations the training process takes

    Returns
    -------
    * array
        Log likelihoods found per iteration
    '''
    iterations = jnp.arange(num_of_iters)
    classes = jnp.arange(self.num_of_classes)
    X = self._cluster(observations, targets)

    def train_step(params, i):
        self.model = params
        # Expectation
        gamma, log_likelihood = vmap(self.expectation, in_axes=(0, 0))(X, classes)
        # Maximization
        mixing_coeffs, probs = vmap(self.maximization, in_axes=(0, 0))(X, gamma)
        return (logit(mixing_coeffs), logit(probs)), -jnp.mean(log_likelihood)

    initial_params = (logit(self.mixing_coeffs), logit(self.probs))
    final_params, history = scan(train_step, initial_params, iterations)
    self.model = final_params
    return history
def hmm_sample_jax(params, seq_len, rng_key):
    '''
    Samples an observation sequence of the given length according to the
    defined hidden Markov model and returns the sequence of hidden states
    as well as the observation sequence.

    Parameters
    ----------
    params : HMMJax
        Hidden Markov Model
    seq_len : int
        The length of the observation sequence
    rng_key : array
        Random key of shape (2,) and dtype uint32

    Returns
    -------
    * array(seq_len,)
        Hidden state sequence
    * array(seq_len,)
        Observation sequence
    '''
    trans_mat, obs_mat, init_dist = params.trans_mat, params.obs_mat, params.init_dist
    n_states, n_obs = obs_mat.shape

    # Split before any use so the initial-state key is not reused downstream.
    rng_init, rng_state, rng_obs = jax.random.split(rng_key, 3)
    initial_state = jax.random.categorical(rng_init, logits=logit(init_dist), shape=(1,))
    obs_states = jnp.arange(n_obs)

    def draw_state(prev_state, key):
        logits = logit(trans_mat[:, prev_state])
        state = jax.random.categorical(key, logits=logits.flatten(), shape=(1,))
        return state, state

    keys = jax.random.split(rng_state, seq_len - 1)
    final_state, states = jax.lax.scan(draw_state, initial_state, keys)
    state_seq = jnp.append(jnp.array([initial_state]), states)

    def draw_obs(z, key):
        obs = jax.random.choice(key, a=obs_states, p=obs_mat[z])
        return obs

    keys = jax.random.split(rng_obs, seq_len)
    obs_seq = jax.vmap(draw_obs, in_axes=(0, 0))(state_seq, keys)

    return state_seq, obs_seq
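# Usage sketch (assumption): HMMJax is treated here as a simple namedtuple with
# the three fields hmm_sample_jax reads; the transition, emission, and initial
# distributions below are arbitrary.
from collections import namedtuple
import jax
import jax.numpy as jnp

HMMJax = namedtuple("HMMJax", ["trans_mat", "obs_mat", "init_dist"])
params = HMMJax(trans_mat=jnp.array([[0.9, 0.1], [0.2, 0.8]]),
                obs_mat=jnp.array([[0.7, 0.3], [0.1, 0.9]]),
                init_dist=jnp.array([0.5, 0.5]))
state_seq, obs_seq = hmm_sample_jax(params, 10, jax.random.PRNGKey(0))
print(state_seq.shape, obs_seq.shape)  # (10,) (10,)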
def test_discrete_with_logits(jax_dist, dist_args):
    rng = random.PRNGKey(0)
    logit_args = dist_args[:-1] + (logit(dist_args[-1]),)

    actual_sample = jax_dist.rvs(*dist_args, random_state=rng)
    expected_sample = jax_dist(*logit_args, is_logits=True).rvs(random_state=rng)
    assert_allclose(actual_sample, expected_sample)

    actual_pmf = jax_dist.logpmf(actual_sample, *dist_args)
    expected_pmf = jax_dist(*logit_args, is_logits=True).logpmf(actual_sample)
    assert_allclose(actual_pmf, expected_pmf, rtol=1e-6)
def sample(p, temperature, key, num_samples=1):
    """
    Generate Binomial Concrete samples

    :param p: Binomial Concrete params (interpreted as Bernoulli probabilities) (jax.numpy array)
    :param temperature: temperature parameter
    :param key: PRNG key
    :param num_samples: number of samples
    """
    tol = 1e-7
    p = np.clip(p, tol, 1 - tol)
    logit_p = logit(p)
    base_randomness = random.logistic(key, shape=(num_samples, *p.shape))
    return nn.sigmoid((logit_p + base_randomness) / (temperature + tol))
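# Usage sketch (assumption): relaxed samples live in (0, 1); lowering the
# temperature pushes their mass toward the endpoints 0 and 1.
import jax.numpy as np
from jax import random

key = random.PRNGKey(0)
p = np.array([0.2, 0.8])
soft = sample(p, temperature=1.0, key=key, num_samples=5)
hard = sample(p, temperature=0.05, key=key, num_samples=5)
print(soft)  # values spread over (0, 1)
print(hard)  # values close to 0 or 1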
def unbounded_to_lower_and_upper_bounded(lower, upper):
    """Construct transform from reals to bounded interval.

    Args:
        lower (float): Lower-bound of image of transform.
        upper (float): Upper-bound of image of transform.
    """
    return ElementwiseMonotonicTransform(
        forward=lambda u: lower + (upper - lower) * expit(np.asarray(u, np.float64)),
        backward=lambda x: logit((np.asarray(x, np.float64) - lower) / (upper - lower)),
        domain=reals,
        image=RealInterval(lower, upper),
    )
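# Round-trip sketch (assumption): the two lambdas above are exact inverses, so
# backward(forward(u)) recovers u; checked here without the transform class.
import numpy as np
from scipy.special import expit, logit

lower, upper = -1.0, 2.0
forward = lambda u: lower + (upper - lower) * expit(np.asarray(u, np.float64))
backward = lambda x: logit((np.asarray(x, np.float64) - lower) / (upper - lower))
u = np.linspace(-3.0, 3.0, 7)
assert np.allclose(backward(forward(u)), u)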
def logpdf(x, p, temperature):
    """
    Binomial Concrete log probability density function
    (the support is the continuous interval (0, 1), so this is a density,
    not a Bernoulli mass function)

    :param x: outcome in (0, 1) (jax.numpy array)
    :param p: Binomial Concrete params (interpreted as Bernoulli probabilities) (jax.numpy array)
    :param temperature: temperature parameter
    """
    assert x.shape == p.shape
    tol = 1e-7
    p = np.clip(p, tol, 1 - tol)
    x = np.clip(x, tol, 1 - tol)
    logit_p = logit(p)
    first_term = (np.log(temperature) + logit_p
                  - (1 + temperature) * np.log(x)
                  - (1 + temperature) * np.log(1 - x))
    second_term = 2 * np.log((np.exp(logit_p) * (x ** (-temperature)))
                             + (1 - x) ** (-temperature))
    return first_term - second_term
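# Numeric sketch (assumption): the expression above is a proper density on
# (0, 1), so a trapezoidal sum over a fine grid should be close to 1 (slightly
# below, because the integrable endpoint singularities are truncated).
import jax.numpy as np

x = np.linspace(1e-4, 1 - 1e-4, 20001)
p = np.full_like(x, 0.3)
dens = np.exp(logpdf(x, p, temperature=2.0 / 3.0))
print(np.sum(0.5 * (dens[1:] + dens[:-1]) * np.diff(x)))  # approximately 1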
def inv(self, y):
    return logit(y)
def __call__(self, x):
    s = jnp.cumsum(x[..., :-1], axis=-1)
    y = logit(s) + jnp.expand_dims(self.anchor_point, -1)
    return y
def _logistic(key, shape, dtype):
    _check_shape("logistic", shape)
    return logit(uniform(key, shape, dtype))
def _inverse(self, y):
    return logit(y)
def ppf(x):
    return logit(x)
def fit_sgd(self, observations, batch_size, rng_key=None, optimizer=None, num_epochs=1):
    '''
    Fits the model using a gradient-descent algorithm with the given
    hyperparameters.

    Parameters
    ----------
    observations : array
        The observation sequences on which the Bernoulli Mixture Model is trained
    batch_size : int
        The size of the batch
    rng_key : array
        Random key of shape (2,) and dtype uint32
    optimizer : jax.experimental.optimizers.Optimizer
        Optimizer to be used
    num_epochs : int
        The number of epochs the training process takes

    Returns
    -------
    * array
        Mean loss values found per epoch
    * array
        Mixing coefficients found per epoch
    * array
        Probabilities of the Bernoulli distribution found per epoch
    * array
        Responsibilities found per epoch
    '''
    global opt_init, opt_update, get_params

    if rng_key is None:
        rng_key = PRNGKey(0)

    if optimizer is not None:
        opt_init, opt_update, get_params = optimizer

    opt_state = opt_init((softmax(self.mixing_coeffs), logit(self.probs)))
    itercount = itertools.count()

    def epoch_step(opt_state, key):

        def train_step(opt_state, batch):
            opt_state, loss = self.update(next(itercount), opt_state, batch)
            return opt_state, loss

        batches = self._make_minibatches(observations, batch_size, key)
        opt_state, losses = scan(train_step, opt_state, batches)

        params = get_params(opt_state)
        mixing_coeffs, probs_logits = params
        probs = expit(probs_logits)
        self.model = (softmax(mixing_coeffs), probs)
        self._probs = probs
        return opt_state, (losses.mean(), *params, self.responsibilities(observations))

    epochs = split(rng_key, num_epochs)
    opt_state, history = scan(epoch_step, opt_state, epochs)

    params = get_params(opt_state)
    mixing_coeffs, probs_logits = params
    probs = expit(probs_logits)
    self.model = (softmax(mixing_coeffs), probs)
    self._probs = probs
    return history
def isf(x):
    return -logit(x)
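# Identity sketch (assumption): `ppf` and `isf` above are the standard logistic
# distribution's quantile functions, since its CDF is expit: ppf(q) = logit(q)
# and isf(q) = logit(1 - q) = -logit(q).
import numpy as np
from scipy import stats
from scipy.special import logit

q = np.array([0.1, 0.5, 0.9])
assert np.allclose(stats.logistic.ppf(q), logit(q))
assert np.allclose(stats.logistic.isf(q), -logit(q))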