Example #1
def log_likelihood(particles, test_results, groups, log_specificity,
                   log_1msensitivity):
    """Computes individual (parallel) log_likelihood of k_groups test results.

  Args:
   particles: np.ndarray<bool>[n_particles, n_patients]. Each one is a possible
    scenario of a disease status of n patients.
   test_results: np.ndarray<bool>[n_groups] the results given by the wet lab for
    each of the tested groups.
   groups: np.ndarray<bool>[num_groups, num_patients] the definition of the
    group that were tested.
   log_specificity: np.ndarray. Depending on the configuration, it can be an
    array of size one or more if we have different sensitivities per group size.
   log_1msensitivity: np.ndarray. Depending on the configuration, it can be an
    array of size one or more if we have different specificities per group size.

  Returns:
   The log likelihood of the particles given the test results.
  """
    positive_in_groups = np.dot(groups, np.transpose(particles)) > 0
    group_sizes = np.sum(groups, axis=1)
    log_specificity = utils.select_from_sizes(log_specificity, group_sizes)
    log_1msensitivity = utils.select_from_sizes(log_1msensitivity, group_sizes)
    logit_specificity = special.logit(np.exp(log_specificity))
    logit_sensitivity = -special.logit(np.exp(log_1msensitivity))
    gamma = log_1msensitivity - log_specificity
    add_logits = logit_specificity + logit_sensitivity
    ll = np.sum(positive_in_groups *
                (gamma + test_results * add_logits)[:, np.newaxis],
                axis=0)
    return ll + np.sum(log_specificity - test_results * logit_specificity)
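A minimal smoke test for the function above, assuming it is pasted into a script with the module-level imports `import jax.numpy as np` and `from jax.scipy import special`, and using a hypothetical stand-in for the repo-internal `utils.select_from_sizes` helper:

import jax.numpy as np
from jax.scipy import special

class utils:  # hypothetical stand-in for the repo-internal helper
    @staticmethod
    def select_from_sizes(values, sizes):
        # pick one value per group size, capping at the last entry
        return values[np.minimum(sizes - 1, values.shape[0] - 1)]

particles = np.array([[True, False], [False, False]])  # 2 scenarios, 2 patients
groups = np.array([[True, True]])                      # one pooled test of both
test_results = np.array([True])                        # the pool came back positive
log_spec = np.log(np.array([0.97]))                    # specificity 0.97
log_1msens = np.log(np.array([1 - 0.85]))              # sensitivity 0.85

ll = log_likelihood(particles, test_results, groups, log_spec, log_1msens)
# scenario 1 has a positive in the pool -> log(sensitivity)     = log(0.85)
# scenario 2 has none                   -> log(1 - specificity) = log(0.03)
print(ll)  # approximately [-0.163, -3.507]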
Example #2
def next_best_group(particle_weights,
                    particles,
                    previous_groups,
                    cur_group,
                    sensitivity,
                    specificity,
                    utility_fun,
                    backtracking):
  """Performs greedy utility optimization to compute the next best group.

  Given a set of groups previous_groups, and a current candidate group
  cur_group, this function computes the utility of the combination of
  previous_groups and cur_group modified by removing (if backtracking = True)
  or adding (if backtracking = False) one element to cur_group, and returns
  the combination with the largest utility.

  Args:
   particle_weights: weights of particles
   particles: particles summarizing belief about infection status
   previous_groups: groups already chosen
   cur_group: group that we wish to optimize
   sensitivity: value (vector) of sensitivity(-ies depending on group size).
   specificity: value (vector) of specificity(-ies depending on group size).
   utility_fun: function to compute the utility of a set of groups
   backtracking: (bool), True if removing rather than adding individuals.

  Returns:
   best_group : cur_group updated with best choice
   utility: utility of best_group
  """
  if backtracking:
    # Backward mode: test groups obtained by removing an item from cur_group
    candidate_groups = np.logical_not(
        mutual_information.add_ones_to_line(np.logical_not(cur_group)))
  else:
    # Forward mode: test groups obtained by adding an item to cur_group
    candidate_groups = mutual_information.add_ones_to_line(cur_group)
  n_candidates = candidate_groups.shape[0]

  # Combine past groups with candidate groups
  candidate_sets = np.concatenate(
      (np.repeat(previous_groups[:, :, np.newaxis], n_candidates, axis=2),
       np.expand_dims(np.transpose(candidate_groups), axis=0)),
      axis=0)

  # Compute utility of each candidate group
  group_sizes = np.sum(candidate_sets[:, :, 0], axis=1)
  group_sensitivities = utils.select_from_sizes(sensitivity, group_sizes)
  group_specificities = utils.select_from_sizes(specificity, group_sizes)
  group_util_fun = lambda x: group_utility(particle_weights, particles, x,
                                           group_sensitivities,
                                           group_specificities, utility_fun)
  mgroup_util_fun = jax.vmap(group_util_fun, in_axes=2)
  objectives = mgroup_util_fun(candidate_sets)

  # Greedy selection of largest value
  index = np.argmax(objectives)
  return (candidate_groups[index, :], objectives[index])
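As a side note, `jax.vmap` with `in_axes=2` maps the utility function over the trailing candidate axis of candidate_sets; a minimal illustration of that mechanic (my own sketch, not from the repo):

import jax
import jax.numpy as np

stacked = np.arange(24).reshape((2, 3, 4))  # e.g. [n_groups, n_patients, n_candidates]
per_candidate = jax.vmap(np.sum, in_axes=2)(stacked)
print(per_candidate.shape)  # (4,): one scalar objective per candidate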
Example #3
  def __call__(self, rng, state):
    """Produces new groups and adds them to state's stack."""
    p_weights, particles = state.particle_weights, state.particles
    marginal = onp.array(np.sum(p_weights[:, np.newaxis] * particles, axis=0))
    marginal = onp.squeeze(marginal)
    not_cut_ids, = onp.where(np.logical_and(
        marginal < self.cut_off_high, marginal > self.cut_off_low))
    marginal = marginal[not_cut_ids]
    sorted_ids = onp.argsort(marginal)
    sorted_marginal = onp.array(marginal[sorted_ids])
    n_p = 0
    n_r = marginal.size
    if n_r == 0:  # no one left to test in between thresholds
      state.all_cleared = True
      return state

    all_new_groups = np.empty((0, state.num_patients), dtype=bool)
    while n_p < marginal.size:
      index_max = onp.amin((n_r, state.max_group_size))
      group_sizes = onp.arange(1, index_max + 1)
      cum_prod_prob = onp.cumprod(1 - sorted_marginal[n_p:(n_p + index_max)])
      # formula below is only valid for group_size > 1,
      # corrected below for a group of size 1.
      sensitivity = onp.array(
          utils.select_from_sizes(state.prior_sensitivity, group_sizes))
      specificity = onp.array(
          utils.select_from_sizes(state.prior_specificity, group_sizes))

      exp_div_size = (
          1 + group_sizes *
          (sensitivity + (1 - sensitivity - specificity) * cum_prod_prob)
          ) / group_sizes
      exp_div_size[0] = 1  # adjusted cost for one patient is one.
      opt_size_group = onp.argmin(exp_div_size) + 1
      new_group = onp.zeros((1, state.num_patients), dtype=bool)
      new_group[0, not_cut_ids[sorted_ids[n_p:n_p + opt_size_group]]] = True
      all_new_groups = np.concatenate((all_new_groups, new_group), axis=0)
      n_p = n_p + opt_size_group
      n_r = n_r - opt_size_group
    # In the modified case, sample extra_tests_needed groups at random; in
    # regular ID, add all of them. Because ID is a Dorfman-type approach, it
    # might be followed by exhaustive splitting, which requires keeping track
    # of groups that tested positive so they can be retested.
    all_new_groups = jax.random.permutation(rng, all_new_groups)
    if self.modified:
      # in the case where we use modified ID, we only subsample a few groups.
      # one needs to take care of requesting to keep track of positives.
      new_groups = all_new_groups[0:state.extra_tests_needed].astype(bool)
      state.add_groups_to_test(new_groups,
                               results_need_clearing=True)
    else:
      # with regular ID we add all groups at once.
      state.add_groups_to_test(all_new_groups.astype(bool),
                               results_need_clearing=True)
    return state
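To see what the `exp_div_size` criterion optimizes, here is a worked check (my own sketch, with an assumed uniform prevalence) of the perfect-test special case, where the formula reduces to the classic Dorfman cost 1/s + 1 - (1 - p)**s per patient:

import numpy as onp

p = 0.05                               # assumed uniform prevalence
sizes = onp.arange(1, 17)
cost = 1.0 / sizes + 1.0 - (1.0 - p) ** sizes
cost[0] = 1.0                          # a group of one costs exactly one test
print(sizes[onp.argmin(cost)])         # optimal group size: 5 for p = 0.05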
Example #4
    def test_select_from_sizes(self):
        sizes = np.array([1, 4, 8, 2])
        prior = np.array([0.1])
        self.assertArraysAllClose(utils.select_from_sizes(prior, sizes),
                                  prior[0] * np.ones_like(sizes),
                                  check_dtypes=True)

        prior = np.array([0.4, 0.2, 0.1])
        expected = np.array([0.4, 0.1, 0.1, 0.2])
        self.assertArraysAllClose(utils.select_from_sizes(prior, sizes),
                                  expected,
                                  check_dtypes=True)
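The test pins down the behavior of `utils.select_from_sizes`: index a per-size parameter vector by group size, capping at the last available entry. A sketch consistent with the assertions above (an assumption; the repo's implementation may differ):

import jax.numpy as np

def select_from_sizes(values, sizes):
    # values[k] applies to groups of size k + 1; larger groups reuse the last entry
    return values[np.minimum(sizes - 1, values.shape[0] - 1)]

print(select_from_sizes(np.array([0.4, 0.2, 0.1]), np.array([1, 4, 8, 2])))
# [0.4 0.1 0.1 0.2]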
Example #5
  def group_tests_outputs(self, rng, groups):
    """Produces test outputs taking into account test errors."""
    n_groups = groups.shape[0]
    group_disease_indicator = np.dot(groups, self.diseased) > 0
    group_sizes = np.sum(groups, axis=1)
    specificity = utils.select_from_sizes(self._specificity, group_sizes)
    sensitivity = utils.select_from_sizes(self._sensitivity, group_sizes)
    draw_u = jax.random.uniform(rng, shape=(n_groups,))
    delta = sensitivity - specificity
    not_flip_proba = group_disease_indicator * delta + specificity
    test_flipped = draw_u > not_flip_proba
    return np.logical_xor(group_disease_indicator, test_flipped)
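A quick sanity check of the flip logic above (my own sketch): a diseased group should stay positive with probability sensitivity, and a healthy group should stay negative with probability specificity:

import jax
import jax.numpy as np

sensitivity, specificity = 0.85, 0.97
delta = sensitivity - specificity
for indicator in (np.array([True]), np.array([False])):
    draws = jax.random.uniform(jax.random.PRNGKey(0), shape=(100_000,))
    flipped = draws > indicator * delta + specificity
    outcome = np.logical_xor(indicator, flipped)
    print(float(np.mean(outcome)))  # ~0.85 if diseased, ~0.03 if healthy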
Example #6
    def __call__(self, rng, state):
        """Produces new groups and adds them to state's stack."""
        p_weights, particles = state.particle_weights, state.particles
        marginal = onp.array(
            np.sum(p_weights[:, np.newaxis] * particles, axis=0))
        marginal = onp.squeeze(marginal)
        not_cut_ids, = onp.where(
            np.logical_and(marginal < self.cut_off_high,
                           marginal > self.cut_off_low))
        marginal = marginal[not_cut_ids]
        sorted_ids = onp.argsort(marginal)
        sorted_marginal = onp.array(marginal[sorted_ids])
        n_p = 0
        n_r = marginal.size
        if n_r == 0:  # no one left to test in between thresholds
            state.all_cleared = True
            return state

        all_new_groups = np.empty((0, state.num_patients), dtype=bool)
        while n_p < marginal.size:
            index_max = onp.amin((n_r, state.max_group_size))
            group_sizes = onp.arange(1, index_max + 1)
            cum_prod_prob = onp.cumprod(1 -
                                        sorted_marginal[n_p:(n_p + index_max)])
            # formula below is only valid for group_size > 1,
            # corrected below for a group of size 1.
            sensitivity = onp.array(
                utils.select_from_sizes(state.prior_sensitivity, group_sizes))
            specificity = onp.array(
                utils.select_from_sizes(state.prior_specificity, group_sizes))

            exp_div_size = (1 + group_sizes *
                            (sensitivity +
                             (1 - sensitivity - specificity) * cum_prod_prob)
                            ) / group_sizes
            exp_div_size[0] = 1  # adjusted cost for one patient is one.
            opt_size_group = onp.argmin(exp_div_size) + 1
            new_group = onp.zeros((1, state.num_patients), dtype=bool)
            new_group[0,
                      not_cut_ids[sorted_ids[n_p:n_p + opt_size_group]]] = True
            all_new_groups = np.concatenate((all_new_groups, new_group),
                                            axis=0)
            n_p = n_p + opt_size_group
            n_r = n_r - opt_size_group
        # sample extra_tests_needed groups at random
        all_new_groups = jax.random.permutation(rng, all_new_groups)
        new_groups = np.array(all_new_groups[0:state.extra_tests_needed],
                              dtype=bool)
        state.add_groups_to_test(new_groups)
        return state
Example #7
    def get_groups(self, rng, state):
        """Produces random design matrix fixed number of 1s per line.

    Args:
     rng: np.ndarray<int>[2]: the random key.
     state: the current state.State of the system.

    Returns:
     A np.array<bool>[num_groups, patients].
    """
        if self.group_size is None:
            if np.size(state.prior_infection_rate) == 1:
                # candidate group sizes
                group_sizes = np.arange(state.max_group_size) + 1
                sensitivity = utils.select_from_sizes(state.prior_sensitivity,
                                                      group_sizes)
                specificity = utils.select_from_sizes(state.prior_specificity,
                                                      group_sizes)
                rho = specificity + sensitivity - 1
                utility_size_groups = (
                    sensitivity - rho *
                    (1 - state.prior_infection_rate)**group_sizes - 0.5)**2
                group_size = group_sizes[np.argmin(utility_size_groups)]
            else:
                group_size = state.max_group_size
        else:
            group_size = self.group_size

        group_size = int(np.squeeze(group_size))
        new_groups = np.empty((0, state.num_patients), dtype=bool)
        for _ in range(state.extra_tests_needed):
            rng, rng_shuffle = jax.random.split(rng, 2)
            vec = np.zeros((1, state.num_patients), dtype=bool)
            idx = jax.random.permutation(rng_shuffle,
                                         np.arange(state.num_patients))
            vec = jax.ops.index_update(vec, [0, idx[0:group_size]], True)
            new_groups = np.concatenate((new_groups, vec), axis=0)
        return new_groups
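Note that `jax.ops.index_update` has since been removed from JAX; on recent versions, the update in the loop above would use the equivalent `.at` syntax:

vec = vec.at[0, idx[0:group_size]].set(True)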
Example #8
def loopy_belief_propagation(tests, groups,
                             base_infection_rate,
                             sensitivity, specificity,
                             min_iterations, max_iterations,
                             atol):
  """LBP approach to compute approximate marginal of posterior distribution.

  Outputs marginal approximation of posterior distribution using all tests'
  history and test setup parameters.

  Args:
    tests : np.ndarray<bool>[n_groups] results stored as a vector of booleans
    groups : np.ndarray<bool>[n_groups, n_patients] matrix of groups
    base_infection_rate : np.ndarray<float> [1,] or [n_patients,] infection rate
    sensitivity : np.ndarray<float> [?,] of sensitivity per group size
    specificity : np.ndarray<float> [?,] of specificity per group size
    min_iterations: int, min number of belief propagation iterations
    max_iterations: int, max number of belief propagation iterations
    atol: float, elementwise tolerance for the difference between two
      consecutive iterations.

  Returns:
    the vector of marginal probabilities for all n_patients, along with the
    largest elementwise gap between the marginals obtained from the last two
    consecutive LBP iterations (0 if no groups have been tested).
  """
  n_groups, n_patients = groups.shape
  if np.size(groups) == 0:
    if np.size(base_infection_rate) == 1:  # only one rate
      marginal = base_infection_rate * np.ones(n_patients)
      return marginal, 0
    elif np.size(base_infection_rate) == n_patients:
      return base_infection_rate, 0
    else:
      raise ValueError("Improper size for vector of base infection rates")

  mu = -jax.scipy.special.logit(base_infection_rate)

  groups_size = np.sum(groups, axis=1)
  sensitivity = utils.select_from_sizes(sensitivity, groups_size)
  specificity = utils.select_from_sizes(specificity, groups_size)
  gamma0 = np.log(sensitivity + specificity - 1) - np.log(1 - sensitivity)
  gamma1 = np.log(sensitivity + specificity - 1) - np.log(sensitivity)
  gamma = tests * gamma1 + (1 - tests) * gamma0
  test_sign = 1 - 2 * tests[:, np.newaxis]

  # Initialization
  alphabeta = np.zeros((2, n_groups, n_patients))
  alpha_beta_iteration = [alphabeta, 0]

  # return marginal from alphabeta
  def marginal_from_alphabeta(alphabeta):
    beta_bar = np.sum(alphabeta[1, :, :], axis=0)
    return jax.scipy.special.expit(-beta_bar - mu)

  # lbp loop
  def lbp_loop(_, alphabeta):
    alpha = alphabeta[0, :, :]
    beta = alphabeta[1, :, :]

    # update alpha
    beta_bar = np.sum(beta, axis=0)
    alpha = jax.nn.log_sigmoid(beta_bar - beta + mu)
    alpha *= groups

    # update beta
    alpha_bar = np.sum(alpha, axis=1, keepdims=True)
    beta = np.log1p(test_sign *
                    np.exp(-alpha + alpha_bar + gamma[:, np.newaxis]))
    beta *= groups
    return np.stack((alpha, beta), axis=0)

  def cond_fun(alpha_beta_iteration):
    alphabeta, iteration = alpha_beta_iteration
    marginal = marginal_from_alphabeta(alphabeta)
    marginal_plus_one_iteration = marginal_from_alphabeta(
        lbp_loop(0, alphabeta))
    converged = np.allclose(marginal, marginal_plus_one_iteration, atol=atol)
    return (not converged) and (iteration < max_iterations)

  def body_fun(alpha_beta_iteration):
    alphabeta, iteration = alpha_beta_iteration
    alphabeta = jax.lax.fori_loop(0, min_iterations, lbp_loop, alphabeta)
    iteration += min_iterations
    return [alphabeta, iteration]

  # Run LBP while loop
  while cond_fun(alpha_beta_iteration):
    alpha_beta_iteration = body_fun(alpha_beta_iteration)

  alphabeta, _ = alpha_beta_iteration

  # Compute two consecutive marginals
  marginal = marginal_from_alphabeta(alphabeta)
  marginal_plus_one_iteration = marginal_from_alphabeta(lbp_loop(0, alphabeta))

  return marginal, np.amax(np.abs(marginal - marginal_plus_one_iteration))
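A minimal end-to-end call (my own sketch; it assumes the module's imports and the `utils.select_from_sizes` helper are in scope): one positive pooled test over three patients should lift all three marginals above the 5% prior:

import jax.numpy as np

marginal, gap = loopy_belief_propagation(
    tests=np.array([True]),
    groups=np.array([[True, True, True]]),
    base_infection_rate=np.array([0.05]),
    sensitivity=np.array([0.85]),
    specificity=np.array([0.97]),
    min_iterations=10, max_iterations=1000, atol=1e-6)
print(marginal, gap)  # three posteriors around 0.29, and a tiny gap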
Example #9
def joint_mi_criterion_mg(particle_weights, particles, cur_group,
                          cur_positives, previous_groups_prob_particles_states,
                          previous_groups_cumcond_entropy, sensitivity,
                          specificity, backtracking):
    """Compares the benefit of adding one group to previously selected ones.

  Groups are formed iteratively by considering all possible individuals
  that can be considered to add (or remove if backtracking).

  If the sensitivity and/or specificity parameters are group size dependent,
  we take that into account in our optimization.
  Here all groups have the same size, hence they all share the same
  specificity / sensitivity setting. We just replace the vector by its value at
  the appropriate coordinate.

  The size of the group considered here will be the size of cur_group + 1 if
  going forward / -1 if backtracking.

  Args:
   particle_weights: weights of particles
   particles: particles summarizing belief about infection status
   cur_group: group currently considered to add to former groups.
   cur_positives: stores which particles would test positive w.r.t cur_group
   previous_groups_prob_particles_states: particles x test outcome probabilities
   previous_groups_cumcond_entropy: previous conditional entropies
   sensitivity: value (vector) of sensitivity(-ies depending on group size).
   specificity: value (vector) of specificity(-ies depending on group size).
   backtracking: (bool), True if removing rather than adding individuals.

  Returns:
    cur_group : group updated with best choice
    cur_positives : bool vector keeping trace of whether particles
                             would test or not positive
    new_objective : MI reached with this new group
    prob_particles_states : if cur_group were to be selected, this matrix
      would keep track of probability of seeing one of 2^j possible test
      outcomes across all particles.
    new_cond_entropy : if cur_group were to be selected, this constant would be
      added to store the conditional entropies of all tests carried out thusfar
  """
    group_size = np.atleast_1d(np.sum(cur_group) + 1 - 2 * backtracking)
    sensitivity = utils.select_from_sizes(sensitivity, group_size)
    specificity = utils.select_from_sizes(specificity, group_size)
    if backtracking:
        # if backtracking, we recompute the truth table for all proposed groups,
        # namely run the np.dot below
        # TODO(cuturi)? If we switch to integer arithmetic we may be able to
        # save on this iteration by keeping track of how many positives there
        # are, and not just on whether there is or not one positive.
        candidate_groups = np.logical_not(
            add_ones_to_line(np.logical_not(cur_group)))
        positive_in_groups = np.dot(candidate_groups,
                                    np.transpose(particles)) > 0
    else:
        # in forward mode, candidate groups are obtained by switching one of
        # cur_group's zeros to a one. Therefore, we can reuse the previous
        # vector of positives in groups to compute it for all candidates.
        indices_of_false_in_cur_group, = np.where(np.logical_not(cur_group))
        positive_in_groups = np.logical_or(
            cur_positives[:, np.newaxis],
            particles[:, indices_of_false_in_cur_group])
        # recover a candidates x n_particles matrix
        positive_in_groups = np.transpose(positive_in_groups)

    entropy_spec = metrics.binary_entropy(specificity)
    gamma = metrics.binary_entropy(sensitivity) - entropy_spec
    cond_entropy = previous_groups_cumcond_entropy + entropy_spec + gamma * np.sum(
        particle_weights[np.newaxis, :] * positive_in_groups, axis=1)
    rho = specificity + sensitivity - 1

    # positive_in_groups defines probability of two possible outcomes for the test
    # of each new candidate group.
    probabilities_new_test = np.stack(
        (specificity - rho * positive_in_groups,
         1 - specificity + rho * positive_in_groups),
        axis=-1)
    # we now incorporate the probabilities of all previous groups added so far
    # and double the state space of possible test results.
    new_plus_previous_groups_prob_particles_states = np.concatenate(
        (probabilities_new_test[:, :, 0][:, :, np.newaxis] *
         previous_groups_prob_particles_states[np.newaxis, :, :],
         probabilities_new_test[:, :, 1][:, :, np.newaxis] *
         previous_groups_prob_particles_states[np.newaxis, :, :]),
        axis=2)

    # average over particles to recover probability of all 2^j possible
    # test results
    new_plus_previous_groups_prob_states = np.sum(
        particle_weights[np.newaxis, :, np.newaxis] *
        new_plus_previous_groups_prob_particles_states,
        axis=1)

    whole_entropy = metrics.entropy(new_plus_previous_groups_prob_states,
                                    axis=1)

    # exhaustive way to compute cond entropy, useful to check
    # computations.
    # cond_entropy_old = np.sum(
    #     particle_weights[np.newaxis, :] *
    #     entropy(new_plus_previous_groups_prob_particles_states, axis=2),
    #     axis=1)
    objectives = whole_entropy - cond_entropy

    # greedy selection of largest value
    index = np.argmax(objectives)

    if backtracking:
        # return most promising group by recovering it from the matrix directly
        logging.info('backtracking, number of candidate groups: %i',
                     candidate_groups.shape[0])
        cur_group = candidate_groups[index, :]

    else:
        # return most promising group by adding a 1
        cur_group = jax.ops.index_update(cur_group,
                                         indices_of_false_in_cur_group[index],
                                         True)

    # refresh the status of vector positives
    cur_positives = positive_in_groups[index, :]
    new_objective = objectives[index]
    prob_particles_states = new_plus_previous_groups_prob_particles_states[
        index, :, :]
    new_cond_entropy = cond_entropy[index]
    return (cur_group, cur_positives, new_objective, prob_particles_states,
            new_cond_entropy)
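For reference, the entropy helpers used above, sketched under the assumption that `metrics` implements the standard definitions in nats (the repo's version may differ, e.g. in how it guards p in {0, 1}):

import jax.numpy as np

def binary_entropy(p):
    # H(p) = -p log p - (1 - p) log(1 - p), elementwise, in nats
    return -p * np.log(p) - (1 - p) * np.log(1 - p)

def entropy(prob, axis=None):
    # Shannon entropy of a distribution along `axis`
    return -np.sum(prob * np.log(prob), axis=axis)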
Example #10
def loopy_belief_propagation(tests,
                             groups,
                             base_infection_rate,
                             sensitivity,
                             specificity,
                             n_iter=200):
    """LBP approach to compute approximate marginal of posterior distribution.

  Outputs marginal approximation of posterior distribution using all tests'
  history and test setup parameters.

  Args:
    tests : np.ndarray<bool>[n_groups] results stored as a vector of booleans
    groups : np.ndarray<bool>[n_groups, n_patients] matrix of groups
    base_infection_rate : np.ndarray<float> [1,] or [n_patients,] infection rate
    sensitivity : np.ndarray<float> [?,] of sensitivity per group size
    specificity : np.ndarray<float> [?,] of specificity per group size
    n_iter : int, number of loops in belief propagation.
  Returns:
    a vector of marginal probabilities for all n_patients.
  """
    if np.size(groups) == 0:
        if np.size(base_infection_rate) == 1:  # only one rate
            return base_infection_rate * np.ones(groups.shape[1])
        elif np.size(base_infection_rate) == groups.shape[1]:
            return base_infection_rate
        else:
            raise ValueError(
                "Improper size for vector of base infection rates")

    n_groups, n_patients = groups.shape
    mu = -jax.scipy.special.logit(base_infection_rate)

    groups_size = np.sum(groups, axis=1)
    sensitivity = utils.select_from_sizes(sensitivity, groups_size)
    specificity = utils.select_from_sizes(specificity, groups_size)
    gamma0 = np.log(sensitivity + specificity - 1) - np.log(1 - sensitivity)
    gamma1 = np.log(sensitivity + specificity - 1) - np.log(sensitivity)
    gamma = tests * gamma1 + (1 - tests) * gamma0
    test_sign = 1 - 2 * tests[:, np.newaxis]

    # Initialization
    alphabeta = np.zeros((2, n_groups, n_patients))

    # lbp loop
    def lbp_loop(_, alphabeta):
        alpha = alphabeta[0, :, :]
        beta = alphabeta[1, :, :]
        # update alpha
        beta_bar = np.sum(beta, axis=0)
        alpha = jax.nn.log_sigmoid(beta_bar - beta + mu)
        alpha *= groups

        # update beta
        alpha_bar = np.sum(alpha, axis=1, keepdims=True)
        beta = np.log1p(test_sign *
                        np.exp(-alpha + alpha_bar + gamma[:, np.newaxis]))
        beta *= groups
        return np.stack((alpha, beta), axis=0)

    # Run LBP loop
    alphabeta = jax.lax.fori_loop(0, n_iter, lbp_loop, alphabeta)

    # return marginals
    beta_bar = np.sum(alphabeta[1, :, :], axis=0)
    return jax.scipy.special.expit(-beta_bar - mu)
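For comparison with Example #8, this fixed-iteration variant always runs n_iter loops and returns a single marginal vector. The same assumed three-patient setup as in the earlier sketch:

import jax.numpy as np

marginal = loopy_belief_propagation(
    np.array([True]),                  # one positive pooled test...
    np.array([[True, True, True]]),    # ...over three patients
    base_infection_rate=np.array([0.05]),
    sensitivity=np.array([0.85]),
    specificity=np.array([0.97]))
print(marginal)  # three posteriors around 0.29, up from the 0.05 prior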