def rbpf_optimal(current_config, xt, params, nparticles=100):
    """
    Rao-Blackwell Particle Filter using optimal proposal
    """
    key, mu_t, Sigma_t, weights_t, st = current_config

    (key_sample, key_state, key_next,
     key_reindex, key_proposal) = random.split(key, 5)
    keys = random.split(key_sample, nparticles)

    st = random.categorical(key_state, logit(params.transition_matrix[st, :]))
    mu_t, Sigma_t, weights_t, proposal = rbpf_step_optimal_vec(
        keys, weights_t, st, mu_t, Sigma_t, xt, params)

    indices = jnp.arange(nparticles)
    pi = random.choice(key_reindex,
                       indices,
                       shape=(nparticles, ),
                       p=weights_t,
                       replace=True)

    # Obtain optimal proposal distribution
    proposal_samp = proposal[pi, :]
    # Use a dedicated key: `key` was already split above and must not be reused.
    st = random.categorical(key_proposal, logit(proposal_samp))

    mu_t = mu_t[pi, st, ...]
    Sigma_t = Sigma_t[pi, st, ...]

    weights_t = jnp.ones(nparticles) / nparticles

    return (key_next, mu_t, Sigma_t, weights_t, st), (mu_t, Sigma_t, weights_t,
                                                      st, proposal_samp)
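In a filtering loop this step function is driven by jax.lax.scan over the observation sequence. A minimal driver sketch, assuming hypothetical params, observations x_hist of shape (T, obs_dim), and initial particle statistics mu0, Sigma0, s0 (none of these names come from the source):

from functools import partial

import jax
import jax.numpy as jnp

def run_rbpf_optimal(key, x_hist, params, mu0, Sigma0, s0, nparticles=100):
    # Thread the (key, mu, Sigma, weights, s) config through time with scan.
    weights0 = jnp.ones(nparticles) / nparticles
    init_config = (key, mu0, Sigma0, weights0, s0)
    step = partial(rbpf_optimal, params=params, nparticles=nparticles)
    _, history = jax.lax.scan(step, init_config, x_hist)
    return history  # stacked per-step (mu, Sigma, weights, s, proposal)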
Example #2
def log_likelihood(particles, test_results, groups, log_specificity,
                   log_1msensitivity):
    """Computes individual (parallel) log_likelihood of k_groups test results.

  Args:
   particles: np.ndarray<bool>[n_particles, n_patients]. Each one is a possible
    scenario of a disease status of n patients.
   test_results: np.ndarray<bool>[n_groups] the results given by the wet lab for
    each of the tested groups.
   groups: np.ndarray<bool>[num_groups, num_patients] the definition of the
    groups that were tested.
   log_specificity: np.ndarray. Depending on the configuration, it can be an
    array of size one or more if we have different specificities per group size.
   log_1msensitivity: np.ndarray. Depending on the configuration, it can be an
    array of size one or more if we have different sensitivities per group size.

  Returns:
   The log likelihood of the particles given the test results.
  """
    positive_in_groups = np.dot(groups, np.transpose(particles)) > 0
    group_sizes = np.sum(groups, axis=1)
    log_specificity = utils.select_from_sizes(log_specificity, group_sizes)
    log_1msensitivity = utils.select_from_sizes(log_1msensitivity, group_sizes)
    logit_specificity = special.logit(np.exp(log_specificity))
    logit_sensitivity = -special.logit(np.exp(log_1msensitivity))
    gamma = log_1msensitivity - log_specificity
    add_logits = logit_specificity + logit_sensitivity
    ll = np.sum(positive_in_groups *
                (gamma + test_results * add_logits)[:, np.newaxis],
                axis=0)
    return ll + np.sum(log_specificity - test_results * logit_specificity)
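The logit algebra above is a vectorized form of the standard group-testing likelihood: a group tests positive with probability sensitivity when it contains at least one infected patient, and with probability 1 - specificity otherwise. A naive single-parameter reference (my own sketch, not from the source) that is useful as a sanity check against the vectorized version:

import numpy as np

def log_likelihood_reference(particles, test_results, groups,
                             specificity, sensitivity):
    # Per-particle loop; assumes a scalar specificity and sensitivity.
    ll = np.zeros(particles.shape[0])
    for i in range(particles.shape[0]):
        for g, result in zip(groups, test_results):
            infected = np.any(particles[i] & g)  # is the group truly positive?
            p_positive = sensitivity if infected else 1.0 - specificity
            ll[i] += np.log(p_positive if result else 1.0 - p_positive)
    return ll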
Example #3
    def __init__(self, num_patients: int, num_tests_per_cycle: int,
                 max_group_size: int, prior_infection_rate: np.ndarray,
                 prior_specificity: np.ndarray, prior_sensitivity: np.ndarray):
        self.num_patients = num_patients
        self.num_tests_per_cycle = num_tests_per_cycle
        self.max_group_size = max_group_size

        self.prior_infection_rate = np.atleast_1d(prior_infection_rate)
        self.prior_specificity = np.atleast_1d(prior_specificity)
        self.prior_sensitivity = np.atleast_1d(prior_sensitivity)
        self.log_prior_specificity = np.log(self.prior_specificity)
        self.log_prior_1msensitivity = np.log(1 - self.prior_sensitivity)
        self.logit_prior_sensitivity = special.logit(self.prior_sensitivity)
        self.logit_prior_specificity = special.logit(self.prior_specificity)

        self.curr_cycle = 0
        self.past_groups = None
        self.past_test_results = None
        self.groups_to_test = None
        self.particle_weights = None
        self.particles = None
        self.to_clear_positives = None
        self.all_cleared = False
        self.marginals = {}
        self.reset()  # Initializes the attributes above.
Example #4
 def __init__(self, mixing_coeffs, probs, class_priors, C, threshold=1e-10):
     self.mixing_coeffs = mixing_coeffs
     self.probs = probs
     self.class_priors = class_priors
     self.model = (logit(mixing_coeffs), logit(probs))
     self.num_of_classes = C
     self.threshold = threshold
     self.log_threshold = jnp.log(threshold)
Example #5
        def train_step(params, i):
            self.model = params

            # Expectation
            gamma, log_likelihood = vmap(self.expectation,
                                         in_axes=(0, 0))(X, classes)

            # Maximization
            mixing_coeffs, probs = vmap(self.maximization,
                                        in_axes=(0, 0))(X, gamma)
            return (logit(mixing_coeffs),
                    logit(probs)), -jnp.mean(log_likelihood)
Example #6
def sample(p, key, num_samples=1):
    """
    Generate Binomial Concrete samples
    :param p: Binary relaxed params (interpreted as Bernoulli probabilities) (jax.numpy array)
    :param key: PRNG key
    :param num_samples: number of samples
    """
    tol = 1e-7
    p = np.clip(p, tol, 1 - tol)
    logit_p = logit(p)
    u = random.uniform(key, shape=(num_samples, *p.shape))
    logit_u = logit(np.clip(u, tol, 1 - tol))
    return logit_p + logit_u
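A minimal usage sketch (my own toy values), assuming np is jax.numpy and random is jax.random as elsewhere in these examples:

import jax.numpy as np
from jax import nn, random

key = random.PRNGKey(0)
p = np.array([0.1, 0.5, 0.9])
draws = sample(p, key, num_samples=4)  # shape (4, 3): logit(p) plus logistic noise
relaxed = nn.sigmoid(draws)            # squash back into (0, 1) if desired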
Example #7
        def inverse_fun(params, inputs, **kwargs):
            if clip_before_logit:
                inputs = np.clip(inputs, 1e-5, 1 - 1e-5)

            outputs = spys.logit(inputs)
            log_det_jacobian = -np.log(inputs - np.square(inputs)).sum(-1)
            return outputs, log_det_jacobian
Example #8
def rbpf(current_config, xt, params, nparticles=100):
    """
    Rao-Blackwell Particle Filter using prior as proposal
    """
    key, mu_t, Sigma_t, weights_t, st = current_config

    key_sample, key_state, key_next, key_reindex = random.split(key, 4)
    keys = random.split(key_sample, nparticles)

    st = random.categorical(key_state, logit(params.transition_matrix[st, :]))
    mu_t, Sigma_t, weights_t, Ltk = rbpf_step_vec(keys, weights_t, st, mu_t,
                                                  Sigma_t, xt, params)
    weights_t = weights_t / weights_t.sum()

    indices = jnp.arange(nparticles)
    pi = random.choice(key_reindex,
                       indices,
                       shape=(nparticles, ),
                       p=weights_t,
                       replace=True)
    st = st[pi]
    mu_t = mu_t[pi, ...]
    Sigma_t = Sigma_t[pi, ...]
    weights_t = jnp.ones(nparticles) / nparticles

    return (key_next, mu_t, Sigma_t, weights_t, st), (mu_t, Sigma_t, weights_t,
                                                      st, Ltk)
Example #9
def rbpf_step(key, weight_t, st, mu_t, Sigma_t, xt, params):
    log_p_next = logit(params.transition_matrix[st])
    k = random.categorical(key, log_p_next)
    mu_t, Sigma_t, Ltk = kf_update(mu_t, Sigma_t, k, xt, params)
    weight_t = weight_t * Ltk

    return mu_t, Sigma_t, weight_t, Ltk
Example #10
def draw_state(val, key, params):
    """
    Simulate one step of a system that evolves as
                A z_{t-1} + Bk + eps,
    where eps ~ N(0, Q).
    
    Parameters
    ----------
    val: tuple (int, jnp.array)
        (latent value of system, state value of system).
    key: PRNGKey
    params: PRBPFParamsDiscrete
    """
    latent_old, state_old = val
    key_latent, key_state, key_obs = random.split(key, 3)

    probabilities = params.transition_matrix[latent_old, :]
    logits = logit(probabilities)
    latent_new = random.categorical(key_latent, logits)

    state_new = params.A @ state_old + params.B[latent_new, :]
    state_new = random.multivariate_normal(key_state, state_new, params.Q)
    obs_new = random.multivariate_normal(key_obs, params.C @ state_new,
                                         params.R)

    return (latent_new, state_new), (latent_new, state_new, obs_new)
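Simulating a whole trajectory just scans this kernel over a sequence of keys. A sketch, assuming a params object carrying the fields used above (transition_matrix, A, B, C, Q, R):

import jax
from jax import random

def simulate(key, params, latent0, state0, nsteps):
    # Thread the (latent, state) carry through nsteps transition keys.
    keys = random.split(key, nsteps)
    step = lambda val, k: draw_state(val, k, params)
    _, (latents, states, obs) = jax.lax.scan(step, (latent0, state0), keys)
    return latents, states, obs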
Example #11
 def model(data):
     p = numpyro.sample('p', dist.Beta(1., 1.))
     if with_logits:
         logits = logit(p)
         numpyro.sample('obs', dist.Binomial(data['n'], logits=logits), obs=data['x'])
     else:
         numpyro.sample('obs', dist.Binomial(data['n'], probs=p), obs=data['x'])
Example #12
def log_prior(particles, base_infection_rate):
    """Computes log of prior probability of state using infection rate."""
    # base_infection_rate is either one rate shared by all patients
    # or an array with one rate per patient.
    if np.size(base_infection_rate) == 1:  # only one rate
        return (
            np.sum(particles, axis=-1) * special.logit(base_infection_rate) +
            particles.shape[-1] * np.log(1 - base_infection_rate))
    elif base_infection_rate.shape[0] == particles.shape[-1]:  # rate per patient
        return np.sum(
            particles * special.logit(base_infection_rate)[np.newaxis, :] +
            np.log(1 - base_infection_rate)[np.newaxis, :],
            axis=-1)
    else:
        raise ValueError(
            "Vector of prior probabilities is not of correct size")
Example #13
def model_3(capture_history, sex):
    N, T = capture_history.shape
    phi_mean = numpyro.sample("phi_mean", dist.Uniform(0.0, 1.0))  # mean survival probability
    phi_logit_mean = logit(phi_mean)
    # controls temporal variability of survival probability
    phi_sigma = numpyro.sample("phi_sigma", dist.Uniform(0.0, 10.0))
    rho = numpyro.sample("rho", dist.Uniform(0.0, 1.0))  # recapture probability

    def transition_fn(carry, y):
        first_capture_mask, z = carry
        with handlers.reparam(config={"phi_logit": LocScaleReparam(0)}):
            phi_logit_t = numpyro.sample("phi_logit", dist.Normal(phi_logit_mean, phi_sigma))
        phi_t = expit(phi_logit_t)
        with numpyro.plate("animals", N, dim=-1):
            with handlers.mask(mask=first_capture_mask):
                mu_z_t = first_capture_mask * phi_t * z + (1 - first_capture_mask)
                # NumPyro exactly sums out the discrete states z_t.
                z = numpyro.sample("z", dist.Bernoulli(dist.util.clamp_probs(mu_z_t)))
                mu_y_t = rho * z
                numpyro.sample("y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y)

        first_capture_mask = first_capture_mask | y.astype(bool)
        return (first_capture_mask, z), None

    z = jnp.ones(N, dtype=jnp.int32)
    # we use this mask to eliminate extraneous log probabilities
    # that arise for a given individual before its first capture.
    first_capture_mask = capture_history[:, 0].astype(bool)
    # NB swapaxes: we move time dimension of `capture_history` to the front to scan over it
    scan(transition_fn, (first_capture_mask, z), jnp.swapaxes(capture_history[:, 1:], 0, 1))
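Inference for this model follows the usual NumPyro recipe. A sketch with synthetic capture histories (the shapes and rates here are assumptions, not from the source):

import jax.numpy as jnp
from jax import random
from numpyro.infer import MCMC, NUTS

rng = random.PRNGKey(0)
# Hypothetical data: 50 animals observed over 6 capture occasions.
capture_history = random.bernoulli(rng, 0.4, shape=(50, 6)).astype(jnp.int32)
sex = jnp.zeros(50)  # unused by model_3; kept for signature compatibility

mcmc = MCMC(NUTS(model_3), num_warmup=500, num_samples=500)
mcmc.run(random.PRNGKey(1), capture_history, sex)
mcmc.print_summary()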
Example #14
def conditional_sample(p, y, key):
    """
    Generate conditional Binary relaxed samples
    :param p: Binary relaxed params (interpreted as Bernoulli probabilities) (jax.numpy array)
    :param y: Conditioning parameters (jax.numpy array)
    :param key: PRNG key
    """
    tol = 1e-7
    p = np.clip(p, tol, 1 - tol)

    v = random.uniform(key, shape=y.shape)
    v_prime = (v * p + (1 - p)) * y + (v * (1 - p)) * (1 - y)
    v_prime = np.clip(v_prime, tol, 1 - tol)

    logit_v = logit(v_prime)
    logit_p = logit(p)
    return logit_p + logit_v
Example #15
 def model(data):
     p = numpyro.sample("p", dist.Beta(1.0, 1.0))
     if with_logits:
         logits = logit(p)
         numpyro.sample(
             "obs", dist.Binomial(data["n"], logits=logits), obs=data["x"]
         )
     else:
         numpyro.sample("obs", dist.Binomial(data["n"], probs=p), obs=data["x"])
Example #16
def conditional_sample(p, y, temperature, key):
    """
    Generate conditional Binomial Concrete sample
    :param p: Binomial Concrete params (interpreted as Bernoulli probabilities) (jax.numpy array)
    :param y: Conditioning parameters (jax.numpy array)
    :param temperature: temperature parameter
    :param key: PRNG key
    """
    tol = 1e-7
    p = np.clip(p, tol, 1 - tol)

    v = random.uniform(key, shape=y.shape)
    v_prime = (v * p + (1 - p)) * y + (v * (1 - p)) * (1 - y)
    v_prime = np.clip(v_prime, tol, 1 - tol)

    logit_v = logit(v_prime)
    logit_p = logit(p)
    return nn.sigmoid((logit_p + logit_v) / (temperature + tol))
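Usage sketch (my own toy values): conditioning on y pins which side of 1/2 the relaxed sample falls on, and lower temperatures push it harder toward y.

import jax.numpy as np
from jax import random

key = random.PRNGKey(0)
p = np.array([0.2, 0.8])
y = np.array([1.0, 0.0])                      # outcomes to condition on
relaxed = conditional_sample(p, y, 0.5, key)  # values in (0, 1), near y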
Example #17
    def fit_em(self, observations, targets, num_of_iters=10):
        '''
        Fits the model using em algorithm.

        Parameters
        ----------
        observations : array
            Dataset

        targets : array
            Ground-truth class labels

        num_of_iters : int
            The number of EM iterations to run

        Returns
        -------
        * array
            Log likelihoods found per iteration
        '''
        iterations = jnp.arange(num_of_iters)
        classes = jnp.arange(self.num_of_classes)
        X = self._cluster(observations, targets)

        def train_step(params, i):
            self.model = params

            # Expectation
            gamma, log_likelihood = vmap(self.expectation,
                                         in_axes=(0, 0))(X, classes)

            # Maximization
            mixing_coeffs, probs = vmap(self.maximization,
                                        in_axes=(0, 0))(X, gamma)
            return (logit(mixing_coeffs),
                    logit(probs)), -jnp.mean(log_likelihood)

        initial_params = (logit(self.mixing_coeffs), logit(self.probs))

        final_params, history = scan(train_step, initial_params, iterations)
        self.model = final_params
        return history
Example #18
def hmm_sample_jax(params, seq_len, rng_key):
    '''
    Samples a hidden-state sequence of the given length from the defined
    hidden Markov model, together with the corresponding observation
    sequence.

    Parameters
    ----------
    params : HMMJax
        Hidden Markov Model

    seq_len : int
        The length of the observation sequence

    rng_key : array
        Random key of shape (2,) and dtype uint32

    Returns
    -------
    * array(seq_len,)
        Hidden state sequence

    * array(seq_len,) :
        Observation sequence
    '''
    trans_mat, obs_mat, init_dist = params.trans_mat, params.obs_mat, params.init_dist
    n_states, n_obs = obs_mat.shape

    initial_state = jax.random.categorical(rng_key,
                                           logits=logit(init_dist),
                                           shape=(1, ))
    obs_states = jnp.arange(n_obs)

    def draw_state(prev_state, key):
        logits = logit(trans_mat[:, prev_state])
        state = jax.random.categorical(key,
                                       logits=logits.flatten(),
                                       shape=(1, ))
        return state, state

    rng_key, rng_state, rng_obs = jax.random.split(rng_key, 3)
    keys = jax.random.split(rng_state, seq_len - 1)

    final_state, states = jax.lax.scan(draw_state, initial_state, keys)
    state_seq = jnp.append(jnp.array([initial_state]), states)

    def draw_obs(z, key):
        obs = jax.random.choice(key, a=obs_states, p=obs_mat[z])
        return obs

    keys = jax.random.split(rng_obs, seq_len)
    obs_seq = jax.vmap(draw_obs, in_axes=(0, 0))(state_seq, keys)

    return state_seq, obs_seq
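A usage sketch with a hypothetical HMMJax container carrying the three fields the function reads (the real class may hold more):

from collections import namedtuple

import jax
import jax.numpy as jnp

HMMJax = namedtuple("HMMJax", ["trans_mat", "obs_mat", "init_dist"])

params = HMMJax(trans_mat=jnp.array([[0.9, 0.1], [0.2, 0.8]]),
                obs_mat=jnp.array([[0.7, 0.3], [0.1, 0.9]]),
                init_dist=jnp.array([0.5, 0.5]))
states, observations = hmm_sample_jax(params, seq_len=20,
                                      rng_key=jax.random.PRNGKey(0))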
Example #19
def test_discrete_with_logits(jax_dist, dist_args):
    rng = random.PRNGKey(0)
    logit_args = dist_args[:-1] + (logit(dist_args[-1]), )

    actual_sample = jax_dist.rvs(*dist_args, random_state=rng)
    expected_sample = jax_dist(*logit_args,
                               is_logits=True).rvs(random_state=rng)
    assert_allclose(actual_sample, expected_sample)

    actual_pmf = jax_dist.logpmf(actual_sample, *dist_args)
    expected_pmf = jax_dist(*logit_args, is_logits=True).logpmf(actual_sample)
    assert_allclose(actual_pmf, expected_pmf, rtol=1e-6)
Example #20
def sample(p, temperature, key, num_samples=1):
    """
    Generate Binomial Concrete samples
    :param p: Binomial Concrete params (interpreted as Bernoulli probabilities) (jax.numpy array)
    :param temperature: temperature parameter
    :param key: PRNG key
    :param num_samples: number of samples
    """
    tol = 1e-7
    p = np.clip(p, tol, 1 - tol)
    logit_p = logit(p)
    base_randomness = random.logistic(key, shape=(num_samples, *p.shape))
    return nn.sigmoid((logit_p + base_randomness) / (temperature + tol))
Example #21
def unbounded_to_lower_and_upper_bounded(lower, upper):
    """Construct transform from reals to bounded interval.

    Args:
        lower (float): Lower-bound of image of transform.
        upper (float): Upper-bound of image of transform.
    """
    return ElementwiseMonotonicTransform(
        forward=lambda u: lower +
        (upper - lower) * expit(np.asarray(u, np.float64)),
        backward=lambda x: logit((np.asarray(x, np.float64) - lower) /
                                 (upper - lower)),
        domain=reals,
        image=RealInterval(lower, upper),
    )
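The backward lambda is the exact inverse of the forward one, since logit inverts expit. A standalone round trip (bypassing the ElementwiseMonotonicTransform wrapper, which is assumed to come from the host library):

import numpy as np
from scipy.special import expit, logit

lower, upper = -2.0, 3.0
u = np.linspace(-5, 5, 11)
x = lower + (upper - lower) * expit(u)         # reals -> (lower, upper)
u_back = logit((x - lower) / (upper - lower))  # (lower, upper) -> reals
np.testing.assert_allclose(u, u_back, atol=1e-9)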
Example #22
def logpdf(x, p, temperature):
    """
    Binomial Concrete (relaxed Bernoulli) log probability density function
    :param x: outcome (jax.numpy array)
    :param p: Binomial Concrete params (interpreted as Bernoulli probabilities) (jax.numpy array)
    :param temperature: temperature parameter
    """
    assert x.shape == p.shape
    tol = 1e-7
    p = np.clip(p, tol, 1 - tol)
    x = np.clip(x, tol, 1 - tol)
    logit_p = logit(p)
    first_term = np.log(temperature) + logit_p - (
        1 + temperature) * np.log(x) - (1 + temperature) * np.log(1 - x)
    second_term = 2 * np.log((np.exp(logit_p) * (x**(-temperature))) +
                             (1 - x)**(-temperature))
    return first_term - second_term
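A crude trapezoidal sanity check (my own test, assuming np is jax.numpy as in the snippet) that the density integrates to roughly one over (0, 1):

import jax.numpy as np

x = np.linspace(1e-4, 1 - 1e-4, 20001)
p = np.full_like(x, 0.3)
density = np.exp(logpdf(x, p, 2.0))
mass = np.sum(0.5 * (density[1:] + density[:-1]) * np.diff(x))  # ~= 1.0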
Example #23
 def inv(self, y):
     return logit(y)
Example #24
 def __call__(self, x):
     s = jnp.cumsum(x[..., :-1], axis=-1)
     y = logit(s) + jnp.expand_dims(self.anchor_point, -1)
     return y
Example #25
 def draw_state(prev_state, key):
     logits = logit(trans_mat[:, prev_state])
     state = jax.random.categorical(key,
                                    logits=logits.flatten(),
                                    shape=(1, ))
     return state, state
Example #26
def _logistic(key, shape, dtype):
  _check_shape("logistic", shape)
  # Inverse-CDF sampling: if U ~ Uniform(0, 1), then logit(U) ~ Logistic(0, 1).
  return logit(uniform(key, shape, dtype))
Example #27
 def _inverse(self, y):
     return logit(y)
Example #28
def ppf(x):
    return logit(x)
Example #29
    def fit_sgd(self, observations, batch_size, rng_key=None, optimizer=None, num_epochs=1):
        '''
        Fits the model using gradient descent algorithm with the given hyperparameters.

        Parameters
        ----------
        observations : array
            The observation sequences which Bernoulli Mixture Model is trained on

        batch_size : int
            The size of the batch

        rng_key : array
            Random key of shape (2,) and dtype uint32

        optimizer : jax.experimental.optimizers.Optimizer
            Optimizer to be used

        num_epochs : int
            The number of epochs the training runs for

        Returns
        -------
        * array
            Mean loss values found per epoch

        * array
            Mixing coefficients found per epoch

        * array
            Probabilities of Bernoulli distribution found per epoch

        * array
            Responsibilities found per epoch
        '''
        global opt_init, opt_update, get_params

        if rng_key is None:
            rng_key = PRNGKey(0)

        if optimizer is not None:
            opt_init, opt_update, get_params = optimizer

        opt_state = opt_init((softmax(self.mixing_coeffs), logit(self.probs)))
        itercount = itertools.count()

        def epoch_step(opt_state, key):

            def train_step(opt_state, batch):
                opt_state, loss = self.update(next(itercount), opt_state, batch)
                return opt_state, loss

            batches = self._make_minibatches(observations, batch_size, key)
            opt_state, losses = scan(train_step, opt_state, batches)

            params = get_params(opt_state)
            mixing_coeffs, probs_logits = params
            probs = expit(probs_logits)
            self.model = (softmax(mixing_coeffs), probs)
            self._probs = probs

            return opt_state, (losses.mean(), *params, self.responsibilities(observations))

        epochs = split(rng_key, num_epochs)
        opt_state, history = scan(epoch_step, opt_state, epochs)
        params = get_params(opt_state)
        mixing_coeffs, probs_logits = params
        probs = expit(probs_logits)
        self.model = (softmax(mixing_coeffs), probs)
        self._probs = probs
        return history
Example #30
def isf(x):
    return -logit(x)
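For the standard logistic distribution these two inverses are mirror images: isf(x) = ppf(1 - x) = -logit(x). A quick check against scipy.stats:

import numpy as np
from scipy import stats
from scipy.special import logit

x = np.array([0.1, 0.25, 0.5, 0.9])
np.testing.assert_allclose(-logit(x), stats.logistic.isf(x))
np.testing.assert_allclose(logit(x), stats.logistic.ppf(x))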