def log_likelihood(particles, test_results, groups, log_specificity,
                   log_1msensitivity):
    """Computes, for every particle, the log-likelihood of the test results.

    For one group the log-probability of the observed result is
      * log(specificity)      if no member is infected and the test is negative,
      * log(1 - specificity)  if no member is infected and the test is positive,
      * log(1 - sensitivity)  if a member is infected and the test is negative,
      * log(sensitivity)      if a member is infected and the test is positive.
    The four cases are folded into a single affine expression below and summed
    over groups.

    Args:
      particles: np.ndarray<bool>[n_particles, n_patients]. Each row is one
        possible scenario for the disease status of the n patients.
      test_results: np.ndarray<bool>[n_groups], the results given by the wet
        lab for each of the tested groups.
      groups: np.ndarray<bool>[n_groups, n_patients], the definition of the
        groups that were tested.
      log_specificity: np.ndarray of log(specificity). It may hold a single
        value or several when the specificity depends on the group size.
      log_1msensitivity: np.ndarray of log(1 - sensitivity). It may hold a
        single value or several when the sensitivity depends on the group
        size.

    Returns:
      The log likelihood of each particle given the test results.
    """
    # Pick the noise parameters that apply to each group, given its size.
    sizes = np.sum(groups, axis=1)
    log_spec = utils.select_from_sizes(log_specificity, sizes)
    log_1m_sens = utils.select_from_sizes(log_1msensitivity, sizes)

    logit_spec = special.logit(np.exp(log_spec))
    logit_sens = -special.logit(np.exp(log_1m_sens))

    # Entry (g, p) is True when particle p has at least one infected patient
    # in group g.
    has_positive = np.dot(groups, np.transpose(particles)) > 0

    # Per-group correction that only applies when the group holds a positive.
    per_group_weight = (log_1m_sens - log_spec) + test_results * (
        logit_spec + logit_sens)
    particle_term = np.sum(
        has_positive * per_group_weight[:, np.newaxis], axis=0)

    # Baseline shared by all particles: likelihood of an all-negative group.
    baseline = np.sum(log_spec - test_results * logit_spec)
    return particle_term + baseline
def next_best_group(particle_weights, particles, previous_groups, cur_group,
                    sensitivity, specificity, utility_fun, backtracking):
    """Performs one greedy step of utility optimization on cur_group.

    Every candidate obtained by removing (if backtracking = True) or adding
    (if backtracking = False) one element to cur_group is appended to
    previous_groups, the utility of each resulting set of groups is evaluated,
    and the candidate yielding the largest utility is returned.

    Args:
      particle_weights: weights of particles
      particles: particles summarizing belief about infection status
      previous_groups: groups already chosen
      cur_group: group that we wish to optimize
      sensitivity: value (vector) of sensitivity(-ies depending on group size).
      specificity: value (vector) of specificity(-ies depending on group size).
      utility_fun: function to compute the utility of a set of groups
      backtracking: (bool), True if removing rather than adding individuals.

    Returns:
      best_group : cur_group updated with best choice
      utility: utility of best_group
    """
    expand = mutual_information.add_ones_to_line
    if backtracking:
        # Backward mode: flipping before and after turns "add a one" into
        # "remove a one", i.e. candidates have one member fewer.
        candidates = np.logical_not(expand(np.logical_not(cur_group)))
    else:
        # Forward mode: candidates have one extra member.
        candidates = expand(cur_group)
    n_candidates = candidates.shape[0]

    # Last axis indexes candidates: each slice is previous_groups followed by
    # one candidate row.
    tiled_previous = np.repeat(
        previous_groups[:, :, np.newaxis], n_candidates, axis=2)
    candidate_layer = np.transpose(candidates)[np.newaxis, :, :]
    candidate_sets = np.concatenate((tiled_previous, candidate_layer), axis=0)

    # Group sizes are read from the first candidate set only and reused for
    # every candidate set.
    sizes = np.sum(candidate_sets[:, :, 0], axis=1)
    sensitivities = utils.select_from_sizes(sensitivity, sizes)
    specificities = utils.select_from_sizes(specificity, sizes)

    def utility_of(group_set):
        return group_utility(particle_weights, particles, group_set,
                             sensitivities, specificities, utility_fun)

    objectives = jax.vmap(utility_of, in_axes=2)(candidate_sets)

    # Greedy choice: keep the candidate with the largest utility.
    best = np.argmax(objectives)
    return (candidates[best, :], objectives[best])
def __call__(self, rng, state):
    """Produces new groups and adds them to state's stack.

    Patients whose marginal lies strictly between self.cut_off_low and
    self.cut_off_high are sorted by increasing marginal and split greedily
    into consecutive groups; each group's size minimizes the expected number
    of tests per patient. The groups are shuffled and either all of them
    (regular mode) or only the first state.extra_tests_needed (self.modified)
    are handed to state.add_groups_to_test with results_need_clearing=True.

    Args:
      rng: random key used to shuffle the groups.
      state: current state; it is updated in place.

    Returns:
      The same state object. If no patient lies between the two cut-offs,
      state.all_cleared is set to True and no group is added.
    """
    weighted = state.particle_weights[:, np.newaxis] * state.particles
    marginal = onp.squeeze(onp.array(np.sum(weighted, axis=0)))

    # Only patients strictly inside the (low, high) band remain candidates.
    in_band = np.logical_and(marginal < self.cut_off_high,
                             marginal > self.cut_off_low)
    (kept_ids,) = onp.where(in_band)
    marginal = marginal[kept_ids]
    order = onp.argsort(marginal)
    ascending = onp.array(marginal[order])

    n_kept = marginal.size
    if n_kept == 0:
        # no one left to test in between thresholds
        state.all_cleared = True
        return state

    stacked = np.empty((0, state.num_patients), dtype=bool)
    start, remaining = 0, n_kept
    while start < n_kept:
        largest = onp.amin((remaining, state.max_group_size))
        sizes = onp.arange(1, largest + 1)
        # Probability that the first k remaining patients are all negative.
        all_negative = onp.cumprod(1 - ascending[start:(start + largest)])
        sens = onp.array(
            utils.select_from_sizes(state.prior_sensitivity, sizes))
        spec = onp.array(
            utils.select_from_sizes(state.prior_specificity, sizes))
        # One pooled test plus `size` retests if it comes back positive,
        # per patient. Only valid for sizes > 1; size 1 is patched below.
        prob_positive = sens + (1 - sens - spec) * all_negative
        cost_per_patient = (1 + sizes * prob_positive) / sizes
        cost_per_patient[0] = 1  # adjusted cost for one patient is one.
        best_size = onp.argmin(cost_per_patient) + 1

        row = onp.zeros((1, state.num_patients))
        row[0, kept_ids[order[start:start + best_size]]] = True
        stacked = np.concatenate((stacked, row), axis=0)

        start = start + best_size
        remaining = remaining - best_size

    # Because ID is a Dorfman type approach, it might be followed by
    # exhaustive splitting, which requires keeping track of the groups that
    # tested positive in order to retest them: hence results_need_clearing.
    shuffled = jax.random.permutation(rng, stacked)
    if self.modified:
        # modified ID only subsamples extra_tests_needed groups.
        selected = shuffled[0:state.extra_tests_needed]
    else:
        # regular ID adds all groups at once.
        selected = shuffled
    state.add_groups_to_test(selected.astype(bool),
                             results_need_clearing=True)
    return state
def test_select_from_sizes(self):
    """Checks utils.select_from_sizes on a scalar and a vector prior.

    A prior with one entry is expected to be used for every size. With three
    entries, sizes 1 and 2 are expected to map to the first and second entry,
    and the larger sizes 4 and 8 to the last one.
    """
    sizes = np.array([1, 4, 8, 2])

    single_prior = np.array([0.1])
    vector_prior = np.array([0.4, 0.2, 0.1])
    cases = (
        (single_prior, single_prior[0] * np.ones_like(sizes)),
        (vector_prior, np.array([0.4, 0.1, 0.1, 0.2])),
    )
    for prior, expected in cases:
        selected = utils.select_from_sizes(prior, sizes)
        self.assertArraysAllClose(selected, expected, check_dtypes=True)
def group_tests_outputs(self, rng, groups):
    """Produces test outputs taking into account test errors.

    A group's noiseless result is positive when it contains at least one
    member flagged in self.diseased. That result is reported correctly with
    probability sensitivity (positive groups) or specificity (negative
    groups), both selected according to the group size, and flipped otherwise.

    Args:
      rng: random key used to draw one uniform variable per group.
      groups: matrix of groups, one row per group.

    Returns:
      A boolean vector with one (possibly erroneous) test result per group.
    """
    truly_positive = np.dot(groups, self.diseased) > 0

    sizes = np.sum(groups, axis=1)
    specificity = utils.select_from_sizes(self._specificity, sizes)
    sensitivity = utils.select_from_sizes(self._sensitivity, sizes)

    uniform_draws = jax.random.uniform(rng, shape=(groups.shape[0],))

    # Equals sensitivity for positive groups and specificity for negative ones.
    prob_correct = truly_positive * (sensitivity - specificity) + specificity
    flipped = uniform_draws > prob_correct
    return np.logical_xor(truly_positive, flipped)
def __call__(self, rng, state):
    """Produces new groups and adds them to state's stack.

    Patients whose marginal lies strictly between self.cut_off_low and
    self.cut_off_high are sorted by increasing marginal and split greedily
    into consecutive groups; each group's size minimizes the expected number
    of tests per patient. The groups are then shuffled and the first
    state.extra_tests_needed of them are passed to state.add_groups_to_test.

    Args:
      rng: random key used to shuffle the groups.
      state: current state; it is updated in place.

    Returns:
      The same state object. If no patient lies between the two cut-offs,
      state.all_cleared is set to True and no group is added.
    """
    weighted = state.particle_weights[:, np.newaxis] * state.particles
    marginal = onp.squeeze(onp.array(np.sum(weighted, axis=0)))

    # Only patients strictly inside the (low, high) band remain candidates.
    in_band = np.logical_and(marginal < self.cut_off_high,
                             marginal > self.cut_off_low)
    (kept_ids,) = onp.where(in_band)
    marginal = marginal[kept_ids]
    order = onp.argsort(marginal)
    ascending = onp.array(marginal[order])

    n_kept = marginal.size
    if n_kept == 0:
        # no one left to test in between thresholds
        state.all_cleared = True
        return state

    stacked = np.empty((0, state.num_patients), dtype=bool)
    start, remaining = 0, n_kept
    while start < n_kept:
        largest = onp.amin((remaining, state.max_group_size))
        sizes = onp.arange(1, largest + 1)
        # Probability that the first k remaining patients are all negative.
        all_negative = onp.cumprod(1 - ascending[start:(start + largest)])
        sens = onp.array(
            utils.select_from_sizes(state.prior_sensitivity, sizes))
        spec = onp.array(
            utils.select_from_sizes(state.prior_specificity, sizes))
        # One pooled test plus `size` retests if it comes back positive,
        # per patient. Only valid for sizes > 1; size 1 is patched below.
        prob_positive = sens + (1 - sens - spec) * all_negative
        cost_per_patient = (1 + sizes * prob_positive) / sizes
        cost_per_patient[0] = 1  # adjusted cost for one patient is one.
        best_size = onp.argmin(cost_per_patient) + 1

        row = onp.zeros((1, state.num_patients))
        row[0, kept_ids[order[start:start + best_size]]] = True
        stacked = np.concatenate((stacked, row), axis=0)

        start = start + best_size
        remaining = remaining - best_size

    # sample randomly extra_tests_needed groups
    shuffled = jax.random.permutation(rng, stacked)
    selected = np.array(shuffled[0:state.extra_tests_needed], dtype=bool)
    state.add_groups_to_test(selected)
    return state
def get_groups(self, rng, state):
    """Produces random design matrix with a fixed number of 1s per line.

    The group size is self.group_size when it is set. Otherwise, if the prior
    infection rate is a single value, it is the size in
    1..state.max_group_size for which the probability of a positive test is
    closest to 0.5; with per-patient rates it falls back to
    state.max_group_size.

    Args:
      rng: np.ndarray<int>[2]: the random key.
      state: the current state.State of the system.

    Returns:
      A np.array<bool>[state.extra_tests_needed, state.num_patients]; each
      row has group_size entries set to True (all of them when group_size
      exceeds state.num_patients).
    """
    if self.group_size is not None:
        group_size = self.group_size
    elif np.size(state.prior_infection_rate) == 1:
        # candidate group sizes
        group_sizes = np.arange(state.max_group_size) + 1
        sensitivity = utils.select_from_sizes(state.prior_sensitivity,
                                              group_sizes)
        specificity = utils.select_from_sizes(state.prior_specificity,
                                              group_sizes)
        rho = specificity + sensitivity - 1
        # Squared distance between P(test is positive) and 1/2.
        utility_size_groups = (
            sensitivity
            - rho * (1 - state.prior_infection_rate)**group_sizes
            - 0.5)**2
        group_size = group_sizes[np.argmin(utility_size_groups)]
    else:
        group_size = state.max_group_size
    group_size = int(np.squeeze(group_size))

    # Start from an empty block so that zero requested tests still yields a
    # correctly shaped (0, num_patients) boolean matrix.
    rows = [np.empty((0, state.num_patients), dtype=bool)]
    for _ in range(state.extra_tests_needed):
        rng, rng_shuffle = jax.random.split(rng, 2)
        idx = jax.random.permutation(rng_shuffle,
                                     np.arange(state.num_patients))
        row = np.zeros((1, state.num_patients), dtype=bool)
        # `.at[...].set(...)` replaces jax.ops.index_update, which no longer
        # exists in current JAX releases.
        rows.append(row.at[0, idx[:group_size]].set(True))
    return np.concatenate(rows, axis=0)
def loopy_belief_propagation(tests, groups, base_infection_rate, sensitivity,
                             specificity, min_iterations, max_iterations,
                             atol):
    """LBP approach to compute approximate marginal of posterior distribution.

    Outputs marginal approximation of posterior distribution using all tests'
    history and test setup parameters. LBP updates are run in batches of
    min_iterations; after each batch the marginal is compared with the one
    obtained after one more update, and the loop stops once they agree within
    atol or once at least max_iterations updates have been run. Because
    updates come in batches, the total number of updates can exceed
    max_iterations by up to one batch.

    Args:
      tests : np.ndarray<bool>[n_groups] results stored as a vector of booleans
      groups : np.ndarray<bool>[n_groups, n_patients] matrix of groups
      base_infection_rate : np.ndarray<float> [1,] or [n_patients,] infection
        rate
      sensitivity : np.ndarray<float> [?,] of sensitivity per group size
      specificity : np.ndarray<float> [?,] of specificity per group size
      min_iterations: int, number of LBP updates run between two convergence
        checks. Values below 1 are treated as 1 so that the loop always makes
        progress.
      max_iterations: int, no new batch of updates is started once this many
        updates have been run.
      atol: float, elementwise tolerance for the difference between two
        consecutive iterations.

    Returns:
      marginal: vector of marginal probabilities for all n_patients.
      gap: largest absolute difference between marginal and the marginal
        obtained after one additional LBP update. When groups is empty the
        prior is returned as the marginal and gap is 0.

    Raises:
      ValueError: if groups is empty and base_infection_rate has neither one
        nor n_patients entries.
    """
    n_groups, n_patients = groups.shape
    if np.size(groups) == 0:
        # No test carried out yet: the marginal is the prior.
        rate_size = np.size(base_infection_rate)
        if rate_size == 1:  # only one rate
            return base_infection_rate * np.ones(n_patients), 0
        if rate_size == n_patients:
            return base_infection_rate, 0
        raise ValueError("Improper size for vector of base infection rates")

    mu = -jax.scipy.special.logit(base_infection_rate)

    groups_size = np.sum(groups, axis=1)
    sensitivity = utils.select_from_sizes(sensitivity, groups_size)
    specificity = utils.select_from_sizes(specificity, groups_size)
    gamma0 = np.log(sensitivity + specificity - 1) - np.log(1 - sensitivity)
    gamma1 = np.log(sensitivity + specificity - 1) - np.log(sensitivity)
    # gamma1 for groups that tested positive, gamma0 for the others.
    gamma = tests * gamma1 + (1 - tests) * gamma0
    # -1 for groups that tested positive, +1 for the others.
    test_sign = 1 - 2 * tests[:, np.newaxis]

    def marginal_from_alphabeta(alphabeta):
        # Marginal of each patient from the summed beta messages.
        beta_bar = np.sum(alphabeta[1, :, :], axis=0)
        return jax.scipy.special.expit(-beta_bar - mu)

    def lbp_loop(_, alphabeta):
        # One LBP update of the stacked (alpha, beta) messages; entries outside
        # `groups` are zeroed out after each half-update.
        alpha = alphabeta[0, :, :]
        beta = alphabeta[1, :, :]

        # update alpha
        beta_bar = np.sum(beta, axis=0)
        alpha = jax.nn.log_sigmoid(beta_bar - beta + mu)
        alpha *= groups

        # update beta
        alpha_bar = np.sum(alpha, axis=1, keepdims=True)
        beta = np.log1p(
            test_sign * np.exp(-alpha + alpha_bar + gamma[:, np.newaxis]))
        beta *= groups
        return np.stack((alpha, beta), axis=0)

    # A non-positive batch size would never advance the iteration counter and
    # the loop below would spin forever when LBP has not converged.
    batch_size = max(min_iterations, 1)

    alphabeta = np.zeros((2, n_groups, n_patients))
    iteration = 0
    while True:
        # Two consecutive marginals, used both as the convergence test and as
        # the returned diagnostics.
        marginal = marginal_from_alphabeta(alphabeta)
        marginal_plus_one_iteration = marginal_from_alphabeta(
            lbp_loop(0, alphabeta))
        if iteration >= max_iterations or np.allclose(
                marginal, marginal_plus_one_iteration, atol=atol):
            break
        alphabeta = jax.lax.fori_loop(0, batch_size, lbp_loop, alphabeta)
        iteration += batch_size

    return marginal, np.amax(np.abs(marginal - marginal_plus_one_iteration))
def joint_mi_criterion_mg(particle_weights, particles, cur_group,
                          cur_positives,
                          previous_groups_prob_particles_states,
                          previous_groups_cumcond_entropy, sensitivity,
                          specificity, backtracking):
    """Compares the benefit of adding one group to previously selected ones.

    Groups are formed iteratively by considering all possible individuals
    that can be considered to add (or remove if backtracking).

    If the sensitivity and/or specificity parameters are group size dependent,
    we take that into account in our optimization. Here all candidate groups
    have the same size, hence they all share the same specificity /
    sensitivity setting. We just replace the vector by its value at the
    appropriate coordinate. The size of the group considered here will be the
    size of cur_group + 1 if going forward / -1 if backtracking.

    Args:
      particle_weights: weights of particles
      particles: particles summarizing belief about infection status
      cur_group: group currently considered to add to former groups.
      cur_positives: stores which particles would test positive w.r.t
        cur_group
      previous_groups_prob_particles_states: particles x test outcome
        probabilities
      previous_groups_cumcond_entropy: previous conditional entropies
      sensitivity: value (vector) of sensitivity(-ies depending on group size).
      specificity: value (vector) of specificity(-ies depending on group size).
      backtracking: (bool), True if removing rather than adding individuals.

    Returns:
      cur_group : group updated with best choice
      cur_positives : bool vector keeping trace of whether particles would
        test or not positive
      new_objective : MI reached with this new group
      prob_particles_states : if cur_group were to be selected, this matrix
        would keep track of probability of seeing one of 2^j possible test
        outcomes across all particles.
      new_cond_entropy : if cur_group were to be selected, this constant would
        be added to store the conditional entropies of all tests carried out
        thus far
    """
    group_size = np.atleast_1d(np.sum(cur_group) + 1 - 2 * backtracking)
    sensitivity = utils.select_from_sizes(sensitivity, group_size)
    specificity = utils.select_from_sizes(specificity, group_size)

    if backtracking:
        # if backtracking, we recompute the truth table for all proposed
        # groups, namely run the np.dot below
        # TODO(cuturi)? If we switch to integer arithmetic we may be able to
        # save on this iteration by keeping track of how many positives there
        # are, and not just on whether there is or not one positive.
        candidate_groups = np.logical_not(
            add_ones_to_line(np.logical_not(cur_group)))
        positive_in_groups = np.dot(candidate_groups,
                                    np.transpose(particles)) > 0
    else:
        # in forward mode, candidate groups are recovered by adding a 1
        # instead of zeros. Therefore, we can use the previous vector of
        # positives to simply compute all positives in groups for candidates.
        indices_of_false_in_cur_group, = np.where(np.logical_not(cur_group))
        positive_in_groups = np.logical_or(
            cur_positives[:, np.newaxis],
            particles[:, indices_of_false_in_cur_group])
        # recover a candidates x n_particles matrix
        positive_in_groups = np.transpose(positive_in_groups)

    entropy_spec = metrics.binary_entropy(specificity)
    gamma = metrics.binary_entropy(sensitivity) - entropy_spec
    cond_entropy = (
        previous_groups_cumcond_entropy + entropy_spec + gamma * np.sum(
            particle_weights[np.newaxis, :] * positive_in_groups, axis=1))
    rho = specificity + sensitivity - 1

    # positive_in_groups defines probability of two possible outcomes for the
    # test of each new candidate group: index 0 is a negative result, index 1
    # a positive one.
    probabilities_new_test = np.stack(
        (specificity - rho * positive_in_groups,
         1 - specificity + rho * positive_in_groups),
        axis=-1)

    # we now incorporate previous probability of all previous groups added so
    # far and expand x 2 the state space of possible test results.
    previous = previous_groups_prob_particles_states[np.newaxis, :, :]
    new_plus_previous_groups_prob_particles_states = np.concatenate(
        (probabilities_new_test[:, :, 0][:, :, np.newaxis] * previous,
         probabilities_new_test[:, :, 1][:, :, np.newaxis] * previous),
        axis=2)

    # average over particles to recover probability of all 2^j possible
    # test results
    new_plus_previous_groups_prob_states = np.sum(
        particle_weights[np.newaxis, :, np.newaxis] *
        new_plus_previous_groups_prob_particles_states,
        axis=1)
    whole_entropy = metrics.entropy(new_plus_previous_groups_prob_states,
                                    axis=1)
    # NOTE: cond_entropy can be cross-checked exhaustively as the
    # particle-weighted sum of the entropies, along the last axis, of
    # new_plus_previous_groups_prob_particles_states.
    objectives = whole_entropy - cond_entropy

    # greedy selection of largest value
    index = np.argmax(objectives)

    if backtracking:
        # return most promising group by recovering it from the matrix
        # directly. The shape is a tuple, so it must be formatted with %s;
        # %i fails when the record is formatted.
        logging.info('backtracking, candidate_groups size: %s',
                     candidate_groups.shape)
        cur_group = candidate_groups[index, :]
    else:
        # return most promising group by adding a 1. `.at[...].set(...)`
        # replaces jax.ops.index_update, which no longer exists in current
        # JAX releases.
        cur_group = np.asarray(cur_group).at[
            indices_of_false_in_cur_group[index]].set(True)

    # refresh the status of vector positives
    cur_positives = positive_in_groups[index, :]
    new_objective = objectives[index]
    prob_particles_states = new_plus_previous_groups_prob_particles_states[
        index, :, :]
    new_cond_entropy = cond_entropy[index]
    return (cur_group, cur_positives, new_objective, prob_particles_states,
            new_cond_entropy)
def loopy_belief_propagation(tests, groups, base_infection_rate, sensitivity,
                             specificity, n_iter=200):
    """LBP approach to compute approximate marginal of posterior distribution.

    Outputs marginal approximation of posterior distribution using all tests'
    history and test setup parameters, after a fixed number of LBP updates.

    Args:
      tests : np.ndarray<bool>[n_groups] results stored as a vector of booleans
      groups : np.ndarray<bool>[n_groups, n_patients] matrix of groups
      base_infection_rate : np.ndarray<float> [1,] or [n_patients,] infection
        rate
      sensitivity : np.ndarray<float> [?,] of sensitivity per group size
      specificity : np.ndarray<float> [?,] of specificity per group size
      n_iter : int, number of loops in belief propagation.

    Returns:
      a vector of marginal probabilities for all n_patients. When groups is
      empty, this is the prior infection rate (expanded to n_patients entries
      if a single rate was given).

    Raises:
      ValueError: if groups is empty and base_infection_rate has neither one
        nor n_patients entries.
    """
    if np.size(groups) == 0:
        # No test carried out yet: the marginal is the prior.
        patient_count = groups.shape[1]
        rate_size = np.size(base_infection_rate)
        if rate_size == 1:  # only one rate
            return base_infection_rate * np.ones(patient_count)
        if rate_size == patient_count:
            return base_infection_rate
        raise ValueError(
            "Improper size for vector of base infection rates")

    n_groups, n_patients = groups.shape
    mu = -jax.scipy.special.logit(base_infection_rate)

    sizes = np.sum(groups, axis=1)
    sens = utils.select_from_sizes(sensitivity, sizes)
    spec = utils.select_from_sizes(specificity, sizes)
    gamma_negative = np.log(sens + spec - 1) - np.log(1 - sens)
    gamma_positive = np.log(sens + spec - 1) - np.log(sens)
    # gamma_positive for groups that tested positive, gamma_negative otherwise.
    gamma = tests * gamma_positive + (1 - tests) * gamma_negative
    # -1 for groups that tested positive, +1 for the others.
    sign = 1 - 2 * tests[:, np.newaxis]

    def lbp_loop(_, messages):
        # One LBP update of the stacked (alpha, beta) messages; entries outside
        # `groups` are zeroed out after each half-update.
        beta_old = messages[1, :, :]

        # update alpha
        beta_total = np.sum(beta_old, axis=0)
        alpha_new = jax.nn.log_sigmoid(beta_total - beta_old + mu)
        alpha_new *= groups

        # update beta
        alpha_total = np.sum(alpha_new, axis=1, keepdims=True)
        beta_new = np.log1p(
            sign * np.exp(-alpha_new + alpha_total + gamma[:, np.newaxis]))
        beta_new *= groups
        return np.stack((alpha_new, beta_new), axis=0)

    # Run LBP loop starting from all-zero messages.
    initial = np.zeros((2, n_groups, n_patients))
    final = jax.lax.fori_loop(0, n_iter, lbp_loop, initial)

    # return marginals
    beta_total = np.sum(final[1, :, :], axis=0)
    return jax.scipy.special.expit(-beta_total - mu)