def model_evaluation_loop(
    trainer,
    eval_encoder,
    counts_eval,
    encoder_eval_name,
    n_iwsamples=5000,
    n_picks=N_PICKS,
    n_cells=N_CELLS,
    do_observed_library=True,
    n_samples_queries=200,
):
    """Evaluate a trained model under one evaluation encoder / proposal.

    Three groups of quantities are computed, in this order:

    1. Mean ``"IWELBO"`` returned by ``getter`` on the sequential test and
       train posteriors, using ``n_iwsamples`` samples.
    2. Importance-weight diagnostics on the module-level batch ``X_U``:
       about 1e4 log-ratios are drawn from the model, then fed to ``psislw``
       (second output stored as ``khat_10000``) and used to compute
       ``1 / sum(w_i ** 2)`` of the softmax-normalised weights (``ess``).
    3. For ``n_picks`` random draws of two groups of ``n_cells`` training
       cells, predictions from ``get_predictions`` with and without
       importance sampling, scored against a reference built from the
       module-level ``h`` (FDR / TPR / PRAUC).

    The function relies on module-level names that are not defined here:
    ``X_U``, ``LOCAL_L_MEAN``, ``LOCAL_L_VAR``, ``Y``, ``h``, ``N_GENES``,
    ``DO_POISSON`` and the helpers ``psislw``, ``get_predictions``,
    ``true_fdr``, ``posterior_expected_fdr``, ``fdr_score``, ``tpr_score``
    and ``prauc``.

    Args:
        trainer: Object exposing ``model``, ``train_set`` and ``test_set``
            (the latter two providing ``sequential()``).
        eval_encoder: Encoder forwarded as ``z_encoder`` / ``encoder``.
        counts_eval: Per-encoder sample counts, or ``None``. Must not be
            ``None`` when ``encoder_eval_name != "default"`` because the
            k-hat section then calls ``.sum()`` on the rescaled counts.
        encoder_eval_name: Key of the encoder to evaluate.
        n_iwsamples: Number of samples for the IWELBO estimates; also the
            target total when rescaling ``counts_eval`` for the k-hat section.
        n_picks: Number of random group draws.
        n_cells: Number of cells drawn (with replacement) per group.
        do_observed_library: Forwarded unchanged to the model / posterior.
        n_samples_queries: Number of posterior samples per prediction query.

    Returns:
        dict: Metrics keyed by name (``iwelbo5000``, ``khat_10000``, ``ess``,
        ``fdr_controlled_fdr05`` / ``10`` / ``20``, ``prauc_is``, ...).
        ``lfc_gt`` holds the ``h_a - h_b`` difference of the *last* pick
        only (``None`` when ``n_picks`` is 0).
    """

    def _scaled_counts(n_total):
        """Rescale ``counts_eval`` to sum to roughly ``n_total`` (int-truncated)."""
        if counts_eval is None:
            return None
        return ((n_total / counts_eval.sum()) * counts_eval).astype(int)

    def _mean_iwelbo(posterior, counts):
        """Mean of the ``"IWELBO"`` values returned by ``posterior.getter``."""
        values = posterior.getter(
            keys=["IWELBO"],
            n_samples=n_iwsamples,
            batch_size=64,
            do_observed_library=do_observed_library,
            encoder_key=encoder_eval_name,
            counts=counts,
            z_encoder=eval_encoder,
        )["IWELBO"]
        return values.cpu().numpy().mean()

    test_post = trainer.test_set.sequential()
    mdl = trainer.model
    train_post = trainer.train_set.sequential()

    # *** IWELBO 5000
    logging.info("IWELBO 5K estimation...")
    multicounts_eval = _scaled_counts(n_iwsamples)
    print(multicounts_eval)
    iwelbo5000_loss = _mean_iwelbo(test_post, multicounts_eval)
    iwelbo5000train_loss = _mean_iwelbo(train_post, multicounts_eval)

    # *** KHAT
    # The k-hat section uses the same rescaled counts as the IWELBO section
    # (same target total, n_iwsamples), so they are reused as-is.
    n_samples_total = 1e4
    n_samples_per_pass = (
        300 if encoder_eval_name == "default" else multicounts_eval.sum()
    )
    n_iter = int(n_samples_total / n_samples_per_pass)
    logging.info("Multicounts: {}".format(multicounts_eval))
    logging.info(
        "Khat computation using {} samples".format(n_samples_per_pass * n_iter)
    )
    log_ratios = []
    for _ in tqdm(range(n_iter)):
        with torch.no_grad():
            out = mdl(
                X_U,
                LOCAL_L_MEAN,
                LOCAL_L_VAR,
                loss_type=None,
                n_samples=n_samples_per_pass,
                reparam=False,
                encoder_key=encoder_eval_name,
                counts=multicounts_eval,
                do_observed_library=do_observed_library,
                z_encoder=eval_encoder,
            )
        log_ratios.append(out["log_ratio"].cpu())
    log_ratios = torch.cat(log_ratios)
    wi = torch.softmax(log_ratios, 0)
    ess_here = 1.0 / (wi ** 2).sum(0)
    _, khats = psislw(log_ratios.T.clone())

    # *** FDR / TPR
    logging.info("FDR/TPR ...")
    train_indices = train_post.indices
    y_train = Y[train_indices]

    decision_rule_fdr10 = np.zeros(n_picks)
    decision_rule_fdr05 = np.zeros(n_picks)
    decision_rule_fdr20 = np.zeros(n_picks)
    decision_rule_tpr10 = np.zeros(n_picks)
    decision_rule_fdr10_plugin = np.zeros(n_picks)
    decision_rule_tpr10_plugin = np.zeros(n_picks)

    fdr_gt = np.zeros((N_GENES, n_picks))
    pe_fdr = np.zeros((N_GENES, n_picks))
    fdr_gt_plugin = np.zeros((N_GENES, n_picks))
    pe_fdr_plugin = np.zeros((N_GENES, n_picks))
    y_preds_is = np.zeros((N_GENES, n_picks))
    y_preds_plugin = np.zeros((N_GENES, n_picks))
    y_gt = np.zeros((N_GENES, n_picks))

    # Loop-invariant: the two label groups and the counts used for queries.
    label_a, label_b = (0, 1) if DO_POISSON else (1, 2)
    where_a = np.where(y_train == label_a)[0]
    where_b = np.where(y_train == label_b)[0]
    multicounts_queries = _scaled_counts(n_samples_queries)

    np.random.seed(42)
    lfc_loc = None  # stays None only when n_picks == 0
    for ipick in range(n_picks):
        print(np.unique(y_train))

        samples_a = np.random.choice(where_a, size=n_cells)
        samples_b = np.random.choice(where_b, size=n_cells)
        samples_a_overall = train_indices[samples_a]
        samples_b_overall = train_indices[samples_b]

        # Reference labels: an entry is flagged when at least half of the
        # sampled pairs have |h_a - h_b| >= 0.5.
        h_a = h[samples_a_overall]
        h_b = h[samples_b_overall]
        lfc_loc = h_a - h_b
        is_significant_de_local = (lfc_loc.abs() >= 0.5).float().mean(0) >= 0.5
        is_significant_de_local = is_significant_de_local.numpy()

        logging.info("IS flavor ...")
        y_pred_is = get_predictions(
            train_post,
            samples_a,
            samples_b,
            encoder_key=encoder_eval_name,
            counts=multicounts_queries,
            n_post_samples=n_samples_queries,
            importance_sampling=True,
            do_observed_library=do_observed_library,
            encoder=eval_encoder,
        )
        y_pred_is = y_pred_is.numpy()

        # Fdr related
        fdr_gt[:, ipick] = true_fdr(y_true=is_significant_de_local, y_pred=y_pred_is)
        pe_fdr_arr, _ = posterior_expected_fdr(y_pred=y_pred_is)
        pe_fdr[:, ipick] = pe_fdr_arr

        _, y_decision_rule10 = posterior_expected_fdr(
            y_pred=y_pred_is, fdr_target=0.1
        )
        decision_rule_fdr10[ipick] = fdr_score(
            y_true=is_significant_de_local, y_pred=y_decision_rule10
        )
        decision_rule_tpr10[ipick] = tpr_score(
            y_true=is_significant_de_local, y_pred=y_decision_rule10
        )

        # Each target gets its own decision rule. Previously the 5% and 20%
        # rules were assigned to the result arrays themselves and the 10%
        # rule was scored instead, which corrupted both outputs.
        _, y_decision_rule05 = posterior_expected_fdr(
            y_pred=y_pred_is, fdr_target=0.05
        )
        decision_rule_fdr05[ipick] = fdr_score(
            y_true=is_significant_de_local, y_pred=y_decision_rule05
        )
        _, y_decision_rule20 = posterior_expected_fdr(
            y_pred=y_pred_is, fdr_target=0.2
        )
        decision_rule_fdr20[ipick] = fdr_score(
            y_true=is_significant_de_local, y_pred=y_decision_rule20
        )

        logging.info("Plugin flavor ...")
        y_pred_plugin = get_predictions(
            train_post,
            samples_a,
            samples_b,
            encoder_key=encoder_eval_name,
            counts=multicounts_queries,
            n_post_samples=n_samples_queries,
            importance_sampling=False,
            do_observed_library=do_observed_library,
            encoder=eval_encoder,
        )
        y_pred_plugin = y_pred_plugin.numpy()

        fdr_gt_plugin[:, ipick] = true_fdr(
            y_true=is_significant_de_local, y_pred=y_pred_plugin
        )
        pe_fdr_plugin_arr, _ = posterior_expected_fdr(y_pred=y_pred_plugin)
        pe_fdr_plugin[:, ipick] = pe_fdr_plugin_arr

        _, y_decision_rule10_plugin = posterior_expected_fdr(
            y_pred=y_pred_plugin, fdr_target=0.1
        )
        decision_rule_fdr10_plugin[ipick] = fdr_score(
            y_true=is_significant_de_local, y_pred=y_decision_rule10_plugin
        )
        decision_rule_tpr10_plugin[ipick] = tpr_score(
            y_true=is_significant_de_local, y_pred=y_decision_rule10_plugin
        )

        y_preds_is[:, ipick] = y_pred_is
        y_preds_plugin[:, ipick] = y_pred_plugin
        y_gt[:, ipick] = is_significant_de_local

    # One PRAUC value per pick (columns of the (N_GENES, n_picks) arrays).
    prauc_plugin = np.array(
        [prauc(y=y_it, pred=y_pred) for (y_it, y_pred) in zip(y_gt.T, y_preds_plugin.T)]
    )
    prauc_is = np.array(
        [prauc(y=y_it, pred=y_pred) for (y_it, y_pred) in zip(y_gt.T, y_preds_is.T)]
    )

    all_fdr_gt = np.array(fdr_gt)
    all_pe_fdr = np.array(pe_fdr)
    fdr_gt_plugin = np.array(fdr_gt_plugin)
    fdr_diff = all_fdr_gt - all_pe_fdr

    loop_res = dict(
        iwelbo5000=np.array(iwelbo5000_loss),
        iwelbo5000_train=np.array(iwelbo5000train_loss),
        pe_fdr_plugin=pe_fdr_plugin,
        fdr_gt_plugin=fdr_gt_plugin,
        all_fdr_gt=all_fdr_gt,
        all_pe_fdr=all_pe_fdr,
        l1_fdr=np.linalg.norm(fdr_diff, axis=0, ord=1),
        l2_fdr=np.linalg.norm(fdr_diff, axis=0, ord=2),
        y_gt=y_gt,
        y_pred_is=y_preds_is,
        y_pred_plugin=y_preds_plugin,
        prauc_plugin=prauc_plugin,
        prauc_is=prauc_is,
        fdr_controlled_fdr10=np.array(decision_rule_fdr10),
        fdr_controlled_fdr05=np.array(decision_rule_fdr05),
        fdr_controlled_fdr20=np.array(decision_rule_fdr20),
        fdr_controlled_tpr10=np.array(decision_rule_tpr10),
        fdr_controlled_fdr10_plugin=np.array(decision_rule_fdr10_plugin),
        fdr_controlled_tpr10_plugin=np.array(decision_rule_tpr10_plugin),
        khat_10000=np.array(khats),
        lfc_gt=lfc_loc,
        ess=ess_here.numpy(),
    )
    return loop_res
# plt.clf() # trainer.test_set.elbo() seq = trainer.test_set.sequential(batch_size=10) start = time.time() zs, logws = ais_trajectory( model, seq, n_sample=100, n_latent=dim_z, is_exp1=True ) # Shapes n_samples, n_batch, (n_latent) iwelbo += [torch.logsumexp(logws, dim=0) - np.log(100)] TIME_100_SAMPLES_AIS.append(time.time() - start) # trainer.test_set.exact_log_likelihood() cubo += [0.5 * (torch.logsumexp(2 * logws, dim=0) - np.log(100))] # Input should be n_obs, n_samples log_ratios = logws.T _, khat_vals = psislw(log_ratios) khat.append(khat_vals) a_2_it = [] res = { "CONFIGURATION": (learn_var, loss_gen, loss_wvar), "learn_var": learn_var, "loss_gen": loss_gen, "loss_wvar": loss_wvar, "n_hidden": n_hidden, # "IWELBO": (np.mean(iwelbo), np.std(iwelbo)), # "CUBO": (np.mean(cubo), np.std(cubo)), # "L1 loss gen_variance_dis": (np.mean(l1_gen_dis), np.std(l1_gen_dis)), # "L1 loss gen_variance_sign": (np.mean(l1_gen_sign), np.std(l1_gen_sign)), # "L1 loss post_variance_dis": (np.mean(l1_post_dis), np.std(l1_post_dis)), # "L1 loss post_variance_sign": (np.mean(l1_post_sign), np.std(l1_post_sign)),
if do_defensive: log_ratio = out["log_ratio"].cpu() else: log_ratio = (out["log_px_z"] + out["log_pz2"] + out["log_pc"] + out["log_pz1_z2"] - out["log_qz1_x"] - out["log_qc_z1"] - out["log_qz2_z1"]).cpu() qc_z_here = out["log_qc_z1"].cpu().exp() qc_z.append(qc_z_here) log_ratios.append(log_ratio) # Concatenation over samples log_ratios = torch.cat(log_ratios, 1) qc_z = torch.cat(qc_z, 1) log_ratios_sum = (log_ratios * qc_z).sum(0) # Sum over labels wi = torch.softmax(log_ratios_sum, 0) _, khats = psislw(log_ratios_sum.T.clone()) # _, khats = psislw(log_ratios.view(-1, len(x_u)).numpy()) khat1e4.append(khats) except Exception as e: raise e print(e) pass res = { "CONFIGURATION": scenario, "LOSS_GEN": loss_gen, "LOSS_WVAR": loss_wvar, "BATCH_SIZE": BATCH_SIZE, "N_SAMPLES_TRAIN": n_samples_train, "N_SAMPLES_WTHETA": n_samples_wtheta,
do_observed_library=True, ) # Khat # start = time.time() post = trainer.create_posterior(model=mdl, gene_dataset=DATASET, indices=SAMPLE_IDX).sequential(batch_size=2) n_post_samples = 50 zs, logws = ais_trajectory(model=mdl, loader=post, n_sample=n_post_samples) cubo = 0.5 * torch.logsumexp(2 * logws, dim=0) - np.log(n_post_samples) # query_execution_time = time.time() - start iwelbo = torch.logsumexp(logws, dim=0) - np.log(n_post_samples) _, khats = psislw(logws.T) # Mus test_indices = trainer.train_set.indices y_test = Y[test_indices] decision_rule_fdr10 = np.zeros(N_PICKS) decision_rule_tpr10 = np.zeros(N_PICKS) fdr_gt = np.zeros((N_GENES, N_PICKS)) pe_fdr = np.zeros((N_GENES, N_PICKS)) n_post_samples = 50 for ipick in range(N_PICKS): samples_a = np.random.choice(np.where(y_test == 1)[0], size=10) samples_b = np.random.choice(np.where(y_test == 2)[0], size=10) where_a = test_indices[samples_a]
def model_evaluation_loop(
    my_trainer, my_eval_encoder, my_counts_eval, my_encoder_eval_name,
):
    """Score one evaluation encoder against the ground truth held by ``DATASET``.

    Uses the module-level names ``DATASET``, ``nus``, ``DIM_Z``, ``norm``,
    ``sqrtm`` and ``psislw``.

    Args:
        my_trainer: Object exposing ``test_set`` and ``train_set``.
        my_eval_encoder: Encoder forwarded as ``z_encoder``.
        my_counts_eval: Forwarded as ``counts``.
        my_encoder_eval_name: Forwarded as ``encoder_key``.

    Returns:
        dict: IWELBO, L1 errors of the IS and plugin CDF estimates against
        the exact normal CDF, ``psislw`` k-hat values, exact/model
        log-likelihoods on train and test, the ESS returned by
        ``prob_eval``, the ground-truth posterior variance and the mean
        spectral norm stored under ``"a2_norm"``.
    """
    # Arguments shared by every call that selects the evaluation encoder.
    encoder_kwargs = dict(
        encoder_key=my_encoder_eval_name,
        counts=my_counts_eval,
        z_encoder=my_eval_encoder,
    )

    # --- Ground truth for the posterior queries ---
    test_seq = my_trainer.test_set.sequential(batch_size=10)
    post_mean = np.dot(DATASET.mz_cond_x_mean, DATASET.X[test_seq.indices, :].T)[0, :]
    post_std = np.sqrt(DATASET.pz_condx_var[0, 0])
    exact_cdf = norm.cdf(0, loc=post_mean, scale=post_std)

    # --- Estimated CDFs at every threshold in `nus` (third output of prob_eval) ---
    is_cdf_nus = test_seq.prob_eval(1000, nu=nus, **encoder_kwargs)[2]
    plugin_cdf_nus = test_seq.prob_eval(
        1000, nu=nus, **encoder_kwargs, plugin_estimator=True
    )[2]
    exact_cdfs_nus = np.array(
        [norm.cdf(threshold, loc=post_mean, scale=post_std) for threshold in nus]
    ).T

    # --- k-hat: psislw is fed the transposed log-ratio matrix ---
    raw_log_ratios = my_trainer.test_set.log_ratios(
        n_samples_mc=5000, **encoder_kwargs
    )
    _, khat_vals = psislw(raw_log_ratios.detach().numpy().T)

    # --- Proposal distribution vs. ground-truth posterior ---
    _, seq_var, is_cdf, ess = test_seq.prob_eval(1000, **encoder_kwargs)
    gt_post_var = DATASET.pz_condx_var
    sigma_sqrt = sqrtm(gt_post_var)

    a_norms = np.zeros(len(seq_var))
    # Variances are None for a defensive generative model; A is skipped then.
    if seq_var[0] is not None:
        identity = np.eye(DIM_Z)
        for idx, variance in enumerate(seq_var):
            precision = np.diag(1.0 / variance)  # variational posterior precision
            a_matrix = sigma_sqrt @ (precision @ sigma_sqrt) - identity
            a_norms[idx] = np.linalg.norm(a_matrix, ord=2)
    # NOTE(review): source indentation was lost; the mean is assumed to be
    # taken unconditionally (0.0 when the loop above is skipped) - confirm.
    a2_norm = a_norms.mean()

    results = {}
    results["IWELBO"] = my_trainer.test_set.iwelbo(5000, **encoder_kwargs)
    results["L1_IS_ERRS"] = np.abs(is_cdf_nus - exact_cdfs_nus).mean(0)
    results["L1_PLUGIN_ERRS"] = np.abs(plugin_cdf_nus - exact_cdfs_nus).mean(0)
    results["KHAT"] = khat_vals
    results["exact_lls_test"] = my_trainer.test_set.exact_log_likelihood()
    results["exact_lls_train"] = my_trainer.train_set.exact_log_likelihood()
    results["model_lls_test"] = my_trainer.test_set.model_log_likelihood()
    results["model_lls_train"] = my_trainer.train_set.model_log_likelihood()
    results["l1_err_ex_is"] = np.mean(np.abs(exact_cdf - is_cdf))
    results["l2_ess"] = ess
    results["gt_post_var"] = DATASET.pz_condx_var
    results["a2_norm"] = a2_norm
    return results
def res_eval_loop(
    trainer,
    eval_encoder,
    counts_eval,
    encoder_eval_name,
    do_defensive: bool = False,
    debug: bool = False,
):
    """Evaluate a trained model with the encoder selected by ``encoder_eval_name``.

    Steps, in order: predictions on ``trainer.test_loader``; precision/recall
    curves and their trapezoidal area; held-out CUBO / IWELBO; entropy and
    accuracy figures (computed but not reported); ``psislw`` k-hat on the
    module-level ``X_SAMPLE`` and on labelled training samples.

    Uses the module-level names ``N_EVAL_SAMPLES``, ``X_SAMPLE``, ``DATASET``,
    ``compute_reject_score``, ``accuracy_score``, ``psislw`` and ``tqdm``.

    Args:
        trainer: Object exposing ``model``, ``inference`` and ``test_loader``.
        eval_encoder: Accepted for interface compatibility; not used here.
        counts_eval: Forwarded as ``counts``; its ``.sum()`` is the number of
            samples per pass when ``do_defensive`` is true.
        encoder_eval_name: Forwarded as ``encoder_key``.
        do_defensive: When true, the model's own ``"log_ratio"`` output is
            used instead of assembling it from the individual log-densities.
        debug: When true, 200 total samples are used instead of 1e4.

    Returns:
        dict: ``"IWELBO"``, ``"CUBO"``, ``"AUC_IS"``, ``"KHAT"`` and ``"AUC"``.
    """
    model = trainer.model
    n_samples_total = 200 if debug else 1e4

    def _infer(*inputs, n_samples):
        """Run ``model.inference`` with the shared evaluation settings."""
        return model.inference(
            *inputs,
            temperature=0.5,
            n_samples=n_samples,
            encoder_key=encoder_eval_name,
            counts=counts_eval,
        )

    def _log_ratio(out, subtract_qc=True):
        """Log importance ratio for one inference output, moved to CPU."""
        if do_defensive:
            return out["log_ratio"].cpu()
        total = (
            out["log_px_z"]
            + out["log_pz2"]
            + out["log_pc"]
            + out["log_pz1_z2"]
            - out["log_qz1_x"]
        )
        if subtract_qc:
            total = total - out["log_qc_z1"]
        return (total - out["log_qz2_z1"]).cpu()

    def _pr_auc(scores):
        """Trapezoidal area under the discovery precision/recall curve."""
        return np.trapz(
            x=scores["recall_discovery"], y=scores["precision_discovery"],
        )

    def _khat(inputs, n_passes, n_per_pass, subtract_qc=True):
        """Stack log-ratios over passes (dim 0) and return psislw's k-hat."""
        draws = []
        for _ in tqdm(range(n_passes)):
            with torch.no_grad():
                out = _infer(*inputs, n_samples=n_per_pass)
                draws.append(_log_ratio(out, subtract_qc=subtract_qc))
        stacked = torch.cat(draws, 0)
        _, khat_values = psislw(stacked.T.clone())
        return khat_values

    # ---- Predictions (plugin and importance-sampling flavours) ----
    logging.info("Predictions computation ...")
    with torch.no_grad():
        train_res = trainer.inference(
            trainer.test_loader,
            keys=[
                "qc_z1_all_probas",
                "y",
                "log_ratios",
                "qc_z1",
                "preds_is",
                "preds_plugin",
            ],
            n_samples=N_EVAL_SAMPLES,
            encoder_key=encoder_eval_name,
            counts=counts_eval,
        )
    y_pred = train_res["preds_plugin"].numpy()
    y_pred = y_pred / y_pred.sum(1, keepdims=True)  # only the plugin preds are renormalised
    y_pred_is = train_res["preds_is"].numpy()
    assert y_pred.shape == y_pred_is.shape, (y_pred.shape, y_pred_is.shape)
    y_true = train_res["y"].numpy()

    # ---- Precision / recall for the discovery class ----
    logging.info("Precision, recall, auc ...")
    auc_pr = _pr_auc(compute_reject_score(y_true=y_true, y_pred=y_pred))
    auc_pr_is = _pr_auc(compute_reject_score(y_true=y_true, y_pred=y_pred_is))

    # ---- Held-out CUBO / IWELBO ----
    logging.info("Heldout CUBO/IWELBO computation ...")
    passes_size = counts_eval.sum() if do_defensive else 100
    n_passes = int(n_samples_total / passes_size)

    cubo_vals = []
    iwelbo_vals = []
    with torch.no_grad():
        for batch_number, (x, _) in enumerate(tqdm(trainer.test_loader), start=1):
            ratios_per_pass = []
            log_qc_per_pass = []
            for _ in tqdm(range(n_passes)):
                out = _infer(x, n_samples=passes_size)
                ratios_per_pass.append(_log_ratio(out))
                log_qc_per_pass.append(out["log_qc_z1"].cpu())

            # NOTE(review): source indentation was lost; this cut-off is
            # assumed to sit at batch level, before the batch is reduced, so
            # the 20th batch is inferred but never contributes - confirm.
            if batch_number == 20:
                break

            # NOTE(review): passes are joined along dim 1 here, whereas the
            # k-hat sections below join them along dim 0 - confirm intended.
            ratios = torch.cat(ratios_per_pass, dim=1)
            torch.cat(log_qc_per_pass, dim=1)  # result unused; call kept as in the original

            # Relaxed case: the joined log-ratios are two-dimensional.
            n_samples, n_batch = ratios.shape
            log_n = np.log(n_samples)
            cubo_batch = torch.logsumexp(2 * ratios, dim=0, keepdim=False) - log_n
            iwelbo_batch = torch.logsumexp(ratios, dim=0, keepdim=False) - log_n
            cubo_vals.append(cubo_batch.cpu())
            iwelbo_vals.append(iwelbo_batch.cpu())
    cubo_vals = torch.cat(cubo_vals)
    iwelbo_vals = torch.cat(iwelbo_vals)

    # ---- Entropy on label 9 and accuracy on the other labels ----
    # These values are computed as in the original but are not reported.
    labels = train_res["y"]
    is_nine = labels == 9
    probas_nine = train_res["qc_z1_all_probas"].mean(0)[is_nine]
    entropy = (-probas_nine * probas_nine.log()).sum(-1).mean(0)

    not_nine = labels != 9
    labels_not_nine = labels[not_nine]
    m_accuracy = accuracy_score(labels_not_nine, y_pred[not_nine].argmax(1))
    m_accuracy_is = accuracy_score(labels_not_nine, y_pred_is[not_nine].argmax(1))

    # ---- k-hat ----
    passes_size = counts_eval.sum() if do_defensive else 25
    n_passes = int(n_samples_total / passes_size)

    # Labels not observed: evaluated on the fixed sample X_SAMPLE.
    khats = _khat((X_SAMPLE,), n_passes, passes_size)

    # Labels observed: first 128 training items with label 9 removed; the
    # log q(c | z1) term is left out of the ratio.
    x_samp, y_samp = DATASET.train_dataset[:128]
    keep = y_samp != 9
    x_samp = x_samp[keep]
    y_samp = y_samp[keep]
    khats_c_obs = _khat((x_samp, y_samp), n_passes, passes_size, subtract_qc=False)

    return {
        "IWELBO": iwelbo_vals.mean().item(),
        "CUBO": cubo_vals.mean().item(),
        "AUC_IS": auc_pr_is,
        "KHAT": np.array(khats),
        "AUC": auc_pr,
    }