Example #1
def model_evaluation_loop(
    trainer,
    eval_encoder,
    counts_eval,
    encoder_eval_name,
    n_iwsamples=5000,
    n_picks=N_PICKS,
    n_cells=N_CELLS,
    do_observed_library=True,
    n_samples_queries=200,
):
    test_post = trainer.test_set.sequential()
    mdl = trainer.model
    train_post = trainer.train_set.sequential()

    # *** IWELBO 5000
    logging.info("IWELBO 5K estimation...")
    multicounts_eval = None
    if counts_eval is not None:
        multicounts_eval = (n_iwsamples / counts_eval.sum()) * counts_eval
        multicounts_eval = multicounts_eval.astype(int)
        logging.info("Multicounts: {}".format(multicounts_eval))
    iwelbo5000_loss = (
        test_post.getter(
            keys=["IWELBO"],
            n_samples=n_iwsamples,
            batch_size=64,
            do_observed_library=do_observed_library,
            encoder_key=encoder_eval_name,
            counts=multicounts_eval,
            z_encoder=eval_encoder,
        )["IWELBO"]
        .cpu()
        .numpy()
    ).mean()

    iwelbo5000train_loss = (
        train_post.getter(
            keys=["IWELBO"],
            n_samples=n_iwsamples,
            batch_size=64,
            do_observed_library=do_observed_library,
            encoder_key=encoder_eval_name,
            counts=multicounts_eval,
            z_encoder=eval_encoder,
        )["IWELBO"]
        .cpu()
        .numpy()
    ).mean()

    # *** KHAT
    multicounts_eval = None
    if counts_eval is not None:
        multicounts_eval = (n_iwsamples / counts_eval.sum()) * counts_eval
        multicounts_eval = multicounts_eval.astype(int)
    log_ratios = []
    n_samples_total = 1e4
    n_samples_per_pass = (
        300 if encoder_eval_name == "default" else multicounts_eval.sum()
    )
    n_iter = int(n_samples_total / n_samples_per_pass)
    logging.info("Multicounts: {}".format(multicounts_eval))
    logging.info(
        "Khat computation using {} samples".format(n_samples_per_pass * n_iter)
    )
    for _ in tqdm(range(n_iter)):
        with torch.no_grad():
            out = mdl(
                X_U,
                LOCAL_L_MEAN,
                LOCAL_L_VAR,
                loss_type=None,
                n_samples=n_samples_per_pass,
                reparam=False,
                encoder_key=encoder_eval_name,
                counts=multicounts_eval,
                do_observed_library=do_observed_library,
                z_encoder=eval_encoder,
            )
        out = out["log_ratio"].cpu()
        log_ratios.append(out)

    log_ratios = torch.cat(log_ratios)
    wi = torch.softmax(log_ratios, 0)
    ess_here = 1.0 / (wi ** 2).sum(0)

    _, khats = psislw(log_ratios.T.clone())

    logging.info("FDR/TPR ...")
    train_indices = train_post.indices
    y_train = Y[train_indices]

    decision_rule_fdr10 = np.zeros(n_picks)
    decision_rule_fdr05 = np.zeros(n_picks)
    decision_rule_fdr20 = np.zeros(n_picks)
    decision_rule_tpr10 = np.zeros(n_picks)
    decision_rule_fdr10_plugin = np.zeros(n_picks)
    decision_rule_tpr10_plugin = np.zeros(n_picks)
    fdr_gt = np.zeros((N_GENES, n_picks))
    pe_fdr = np.zeros((N_GENES, n_picks))
    fdr_gt_plugin = np.zeros((N_GENES, n_picks))
    pe_fdr_plugin = np.zeros((N_GENES, n_picks))
    y_preds_is = np.zeros((N_GENES, n_picks))
    y_preds_plugin = np.zeros((N_GENES, n_picks))
    y_gt = np.zeros((N_GENES, n_picks))

    np.random.seed(42)
    for ipick in range(n_picks):
        logging.debug("Labels present in train split: {}".format(np.unique(y_train)))
        if DO_POISSON:
            where_a = np.where(y_train == 0)[0]
            where_b = np.where(y_train == 1)[0]
        else:
            where_a = np.where(y_train == 1)[0]
            where_b = np.where(y_train == 2)[0]

        samples_a = np.random.choice(where_a, size=n_cells)
        samples_b = np.random.choice(where_b, size=n_cells)

        samples_a_overall = train_indices[samples_a]
        samples_b_overall = train_indices[samples_b]

        h_a = h[samples_a_overall]
        h_b = h[samples_b_overall]
        lfc_loc = h_a - h_b
        is_significant_de_local = (lfc_loc.abs() >= 0.5).float().mean(0) >= 0.5
        is_significant_de_local = is_significant_de_local.numpy()

        logging.info("IS flavor ...")
        multicounts_eval = None
        if counts_eval is not None:
            multicounts_eval = (n_samples_queries / counts_eval.sum()) * counts_eval
            multicounts_eval = multicounts_eval.astype(int)
        y_pred_is = get_predictions(
            train_post,
            samples_a,
            samples_b,
            encoder_key=encoder_eval_name,
            counts=multicounts_eval,
            n_post_samples=n_samples_queries,
            importance_sampling=True,
            do_observed_library=do_observed_library,
            encoder=eval_encoder,
        )
        y_pred_is = y_pred_is.numpy()

        true_fdr_arr = true_fdr(y_true=is_significant_de_local, y_pred=y_pred_is)
        pe_fdr_arr, y_decision_rule = posterior_expected_fdr(y_pred=y_pred_is)
        # Fdr related
        fdr_gt[:, ipick] = true_fdr_arr
        pe_fdr[:, ipick] = pe_fdr_arr

        _, y_decision_rule10 = posterior_expected_fdr(y_pred=y_pred_is, fdr_target=0.1)
        decision_rule_fdr10[ipick] = fdr_score(
            y_true=is_significant_de_local, y_pred=y_decision_rule10
        )
        decision_rule_tpr10[ipick] = tpr_score(
            y_true=is_significant_de_local, y_pred=y_decision_rule10
        )

        _, y_decision_rule05 = posterior_expected_fdr(
            y_pred=y_pred_is, fdr_target=0.05
        )
        decision_rule_fdr05[ipick] = fdr_score(
            y_true=is_significant_de_local, y_pred=y_decision_rule05
        )
        _, y_decision_rule20 = posterior_expected_fdr(
            y_pred=y_pred_is, fdr_target=0.2
        )
        decision_rule_fdr20[ipick] = fdr_score(
            y_true=is_significant_de_local, y_pred=y_decision_rule20
        )

        logging.info("Plugin flavor ...")
        y_pred_plugin = get_predictions(
            train_post,
            samples_a,
            samples_b,
            encoder_key=encoder_eval_name,
            counts=multicounts_eval,
            n_post_samples=n_samples_queries,
            importance_sampling=False,
            do_observed_library=do_observed_library,
            encoder=eval_encoder,
        )
        y_pred_plugin = y_pred_plugin.numpy()
        true_fdr_plugin_arr = true_fdr(
            y_true=is_significant_de_local, y_pred=y_pred_plugin
        )
        fdr_gt_plugin[:, ipick] = true_fdr_plugin_arr
        pe_fdr_plugin_arr, y_decision_rule = posterior_expected_fdr(
            y_pred=y_pred_plugin
        )
        pe_fdr_plugin[:, ipick] = pe_fdr_plugin_arr
        _, y_decision_rule10 = posterior_expected_fdr(
            y_pred=y_pred_plugin, fdr_target=0.1
        )
        decision_rule_fdr10_plugin[ipick] = fdr_score(
            y_true=is_significant_de_local, y_pred=y_decision_rule10
        )
        decision_rule_tpr10_plugin[ipick] = tpr_score(
            y_true=is_significant_de_local, y_pred=y_decision_rule10
        )
        y_preds_is[:, ipick] = y_pred_is
        y_preds_plugin[:, ipick] = y_pred_plugin
        y_gt[:, ipick] = is_significant_de_local

    prauc_plugin = np.array(
        [prauc(y=y_it, pred=y_pred) for (y_it, y_pred) in zip(y_gt.T, y_preds_plugin.T)]
    )

    prauc_is = np.array(
        [prauc(y=y_it, pred=y_pred) for (y_it, y_pred) in zip(y_gt.T, y_preds_is.T)]
    )
    all_fdr_gt = np.array(fdr_gt)
    all_pe_fdr = np.array(pe_fdr)
    fdr_gt_plugin = np.array(fdr_gt_plugin)
    fdr_diff = all_fdr_gt - all_pe_fdr
    loop_res = dict(
        iwelbo5000=np.array(iwelbo5000_loss),
        iwelbo5000_train=np.array(iwelbo5000train_loss),
        pe_fdr_plugin=pe_fdr_plugin,
        fdr_gt_plugin=fdr_gt_plugin,
        all_fdr_gt=all_fdr_gt,
        all_pe_fdr=all_pe_fdr,
        l1_fdr=np.linalg.norm(fdr_diff, axis=0, ord=1),
        l2_fdr=np.linalg.norm(fdr_diff, axis=0, ord=2),
        y_gt=y_gt,
        y_pred_is=y_preds_is,
        y_pred_plugin=y_preds_plugin,
        prauc_plugin=prauc_plugin,
        prauc_is=prauc_is,
        fdr_controlled_fdr10=np.array(decision_rule_fdr10),
        fdr_controlled_fdr05=np.array(decision_rule_fdr05),
        fdr_controlled_fdr20=np.array(decision_rule_fdr20),
        fdr_controlled_tpr10=np.array(decision_rule_tpr10),
        fdr_controlled_fdr10_plugin=np.array(decision_rule_fdr10_plugin),
        fdr_controlled_tpr10_plugin=np.array(decision_rule_tpr10_plugin),
        khat_10000=np.array(khats),
        lfc_gt=lfc_loc,  # ground-truth LFCs from the last pick only
        ess=ess_here.numpy(),
    )
    return loop_res
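The FDR logic above leans on a posterior_expected_fdr helper that is not shown in this listing. For orientation, here is a minimal sketch of what such a helper plausibly computes, assuming y_pred holds per-gene posterior probabilities of differential expression; the repository's actual implementation may differ in its details:

import numpy as np


def posterior_expected_fdr_sketch(y_pred, fdr_target=0.05):
    """y_pred: per-gene posterior probabilities of being DE."""
    order = np.argsort(y_pred)[::-1]  # most confident genes first
    # Running mean of (1 - p) over the top-k genes = posterior expected
    # FDR incurred by calling exactly those k genes DE.
    running_fdr = np.cumsum(1.0 - y_pred[order]) / np.arange(1, len(y_pred) + 1)
    pe_fdr = np.empty_like(running_fdr)
    pe_fdr[order] = running_fdr  # report per gene, in the original order
    # Decision rule: largest set whose expected FDR stays below the target.
    n_pos = int((running_fdr <= fdr_target).sum())
    is_de = np.zeros(len(y_pred), dtype=bool)
    is_de[order[:n_pos]] = True
    return pe_fdr, is_de

The decision rule is the largest gene set whose running posterior expected FDR stays below the target, which is why Example #1 can score it directly against the ground-truth DE labels with fdr_score and tpr_score.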
Example #2
            seq = trainer.test_set.sequential(batch_size=10)
            start = time.time()
            zs, logws = ais_trajectory(
                model, seq, n_sample=100, n_latent=dim_z, is_exp1=True
            )
            # Shapes: (n_samples, n_batch[, n_latent])
            iwelbo += [torch.logsumexp(logws, dim=0) - np.log(100)]
            TIME_100_SAMPLES_AIS.append(time.time() - start)
            cubo += [0.5 * (torch.logsumexp(2 * logws, dim=0) - np.log(100))]
            # Input should be n_obs, n_samples
            log_ratios = logws.T
            _, khat_vals = psislw(log_ratios)
            khat.append(khat_vals)
            a_2_it = []

        res = {
            "CONFIGURATION": (learn_var, loss_gen, loss_wvar),
            "learn_var": learn_var,
            "loss_gen": loss_gen,
            "loss_wvar": loss_wvar,
            "n_hidden": n_hidden,
            # "IWELBO": (np.mean(iwelbo), np.std(iwelbo)),
            # "CUBO": (np.mean(cubo), np.std(cubo)),
            # "L1 loss gen_variance_dis": (np.mean(l1_gen_dis), np.std(l1_gen_dis)),
            # "L1 loss gen_variance_sign": (np.mean(l1_gen_sign), np.std(l1_gen_sign)),
            # "L1 loss post_variance_dis": (np.mean(l1_post_dis), np.std(l1_post_dis)),
            # "L1 loss post_variance_sign": (np.mean(l1_post_sign), np.std(l1_post_sign)),
Example #3
                if do_defensive:
                    log_ratio = out["log_ratio"].cpu()
                else:
                    log_ratio = (out["log_px_z"] + out["log_pz2"] +
                                 out["log_pc"] + out["log_pz1_z2"] -
                                 out["log_qz1_x"] - out["log_qc_z1"] -
                                 out["log_qz2_z1"]).cpu()
                qc_z_here = out["log_qc_z1"].cpu().exp()
                qc_z.append(qc_z_here)
                log_ratios.append(log_ratio)
            # Concatenation over samples
            log_ratios = torch.cat(log_ratios, 1)
            qc_z = torch.cat(qc_z, 1)
            log_ratios_sum = (log_ratios * qc_z).sum(0)  # Sum over labels
            wi = torch.softmax(log_ratios_sum, 0)
            _, khats = psislw(log_ratios_sum.T.clone())
            # _, khats = psislw(log_ratios.view(-1, len(x_u)).numpy())
            khat1e4.append(khats)

        except Exception as e:
            logging.exception(e)
            raise

    res = {
        "CONFIGURATION": scenario,
        "LOSS_GEN": loss_gen,
        "LOSS_WVAR": loss_wvar,
        "BATCH_SIZE": BATCH_SIZE,
        "N_SAMPLES_TRAIN": n_samples_train,
        "N_SAMPLES_WTHETA": n_samples_wtheta,
Example #4
        do_observed_library=True,
    )

# Khat
# start = time.time()
post = trainer.create_posterior(model=mdl,
                                gene_dataset=DATASET,
                                indices=SAMPLE_IDX).sequential(batch_size=2)
n_post_samples = 50
zs, logws = ais_trajectory(model=mdl, loader=post, n_sample=n_post_samples)

cubo = 0.5 * (torch.logsumexp(2 * logws, dim=0) - np.log(n_post_samples))
# query_execution_time = time.time() - start

iwelbo = torch.logsumexp(logws, dim=0) - np.log(n_post_samples)
_, khats = psislw(logws.T)

# Mus
test_indices = trainer.train_set.indices  # NB: drawn from the train split despite the name
y_test = Y[test_indices]
decision_rule_fdr10 = np.zeros(N_PICKS)
decision_rule_tpr10 = np.zeros(N_PICKS)
fdr_gt = np.zeros((N_GENES, N_PICKS))
pe_fdr = np.zeros((N_GENES, N_PICKS))
n_post_samples = 50

for ipick in range(N_PICKS):
    samples_a = np.random.choice(np.where(y_test == 1)[0], size=10)
    samples_b = np.random.choice(np.where(y_test == 2)[0], size=10)

    where_a = test_indices[samples_a]
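The k-hat diagnostic above complements the effective-sample-size computation from Example #1 (wi = softmax of the log-ratios, ESS = 1 / sum wi^2). A standalone sketch of that diagnostic:

import torch


def importance_sampling_ess(log_ratios):
    # Normalized importance weights, one set per observation (column).
    wi = torch.softmax(log_ratios, dim=0)
    # ESS = 1 / sum_i wi^2: equals n_samples for uniform weights and
    # collapses toward 1 when a single sample dominates.
    return 1.0 / (wi ** 2).sum(dim=0)


# e.g. importance_sampling_ess(torch.randn(1000, 8)) -> tensor of 8 ESS values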
Example #5
def model_evaluation_loop(
    my_trainer,
    my_eval_encoder,
    my_counts_eval,
    my_encoder_eval_name,
):
    # posterior query evaluation: ground truth
    seq = my_trainer.test_set.sequential(batch_size=10)
    mean = np.dot(DATASET.mz_cond_x_mean, DATASET.X[seq.indices, :].T)[0, :]
    std = np.sqrt(DATASET.pz_condx_var[0, 0])
    exact_cdf = norm.cdf(0, loc=mean, scale=std)

    is_cdf_nus = seq.prob_eval(
        1000,
        nu=nus,
        encoder_key=my_encoder_eval_name,
        counts=my_counts_eval,
        z_encoder=my_eval_encoder,
    )[2]
    plugin_cdf_nus = seq.prob_eval(
        1000,
        nu=nus,
        encoder_key=my_encoder_eval_name,
        counts=my_counts_eval,
        z_encoder=my_eval_encoder,
        plugin_estimator=True,
    )[2]
    exact_cdfs_nus = np.array(
        [norm.cdf(nu, loc=mean, scale=std) for nu in nus]).T

    log_ratios = (my_trainer.test_set.log_ratios(
        n_samples_mc=5000,
        encoder_key=my_encoder_eval_name,
        counts=my_counts_eval,
        z_encoder=my_eval_encoder,
    ).detach().numpy())
    # Input should be n_obs, n_samples
    log_ratios = log_ratios.T
    _, khat_vals = psislw(log_ratios)

    # posterior query evaluation: proposal distribution
    seq_mean, seq_var, is_cdf, ess = seq.prob_eval(
        1000,
        encoder_key=my_encoder_eval_name,
        counts=my_counts_eval,
        z_encoder=my_eval_encoder,
    )

    gt_post_var = DATASET.pz_condx_var
    sigma_sqrt = sqrtm(gt_post_var)
    a_2_it = np.zeros(len(seq_var))
    # A is only computable when the variational posterior is not defensive
    if seq_var[0] is not None:
        for it in range(len(seq_var)):
            seq_var_item = seq_var[it]  # Posterior variance
            d_inv = np.diag(1.0 /
                            seq_var_item)  # Variational posterior precision
            a = sigma_sqrt @ (d_inv @ sigma_sqrt) - np.eye(DIM_Z)
            a_2_it[it] = np.linalg.norm(a, ord=2)
    a_2_it = a_2_it.mean()

    return {
        "IWELBO":
        my_trainer.test_set.iwelbo(
            5000,
            encoder_key=my_encoder_eval_name,
            counts=my_counts_eval,
            z_encoder=my_eval_encoder,
        ),
        "L1_IS_ERRS":
        np.abs(is_cdf_nus - exact_cdfs_nus).mean(0),
        "L1_PLUGIN_ERRS":
        np.abs(plugin_cdf_nus - exact_cdfs_nus).mean(0),
        "KHAT":
        khat_vals,
        "exact_lls_test":
        my_trainer.test_set.exact_log_likelihood(),
        "exact_lls_train":
        my_trainer.train_set.exact_log_likelihood(),
        "model_lls_test":
        my_trainer.test_set.model_log_likelihood(),
        "model_lls_train":
        my_trainer.train_set.model_log_likelihood(),
        # "plugin_cdf": norm.cdf(0, loc=seq_mean[:, 0], scale=np.sqrt(seq_var[:, 0])),
        "l1_err_ex_is":
        np.mean(np.abs(exact_cdf - is_cdf)),
        "l2_ess":
        ess,
        "gt_post_var":
        DATASET.pz_condx_var,
        "a2_norm":
        a_2_it,
        # "sigma_sqrt": sqrtm(gt_post_var),
    }
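The a_2_it statistic above measures how far the diagonal variational posterior is from the exact Gaussian posterior: A = Sigma^{1/2} D^{-1} Sigma^{1/2} - I vanishes exactly when the two coincide. A self-contained numerical sketch with dummy covariances:

import numpy as np
from scipy.linalg import sqrtm

DIM_Z = 3
gt_post_var = np.array([[1.0, 0.2, 0.0],
                        [0.2, 1.5, 0.1],
                        [0.0, 0.1, 0.8]])  # exact posterior covariance (dummy)
seq_var_item = np.array([1.1, 1.4, 0.9])   # diagonal variational variances (dummy)

sigma_sqrt = np.real(sqrtm(gt_post_var))
d_inv = np.diag(1.0 / seq_var_item)        # variational posterior precision
a = sigma_sqrt @ (d_inv @ sigma_sqrt) - np.eye(DIM_Z)
print(np.linalg.norm(a, ord=2))            # spectral norm; 0 iff posteriors match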
Example #6
def res_eval_loop(
    trainer,
    eval_encoder,
    counts_eval,
    encoder_eval_name,
    do_defensive: bool = False,
    debug: bool = False,
):
    model = trainer.model

    logging.info("Predictions computation ...")
    with torch.no_grad():
        # The call below handles both the mixture and the plain
        # inference statistics
        train_res = trainer.inference(
            trainer.test_loader,
            # trainer.train_loader,
            keys=[
                "qc_z1_all_probas",
                "y",
                "log_ratios",
                "qc_z1",
                "preds_is",
                "preds_plugin",
            ],
            n_samples=N_EVAL_SAMPLES,
            encoder_key=encoder_eval_name,
            counts=counts_eval,
        )
    y_pred = train_res["preds_plugin"].numpy()
    y_pred = y_pred / y_pred.sum(1, keepdims=True)

    y_pred_is = train_res["preds_is"].numpy()
    # y_pred_is = y_pred_is / y_pred_is.sum(1, keepdims=True)
    assert y_pred.shape == y_pred_is.shape, (y_pred.shape, y_pred_is.shape)

    y_true = train_res["y"].numpy()

    # Precision / Recall for discovery class
    # And accuracy
    logging.info("Precision, recall, auc ...")
    res_baseline = compute_reject_score(y_true=y_true, y_pred=y_pred)
    m_ap = res_baseline["precision_discovery"]
    m_recall = res_baseline["recall_discovery"]
    auc_pr = np.trapz(
        x=res_baseline["recall_discovery"],
        y=res_baseline["precision_discovery"],
    )

    res_baseline_is = compute_reject_score(y_true=y_true, y_pred=y_pred_is)
    m_ap_is = res_baseline_is["precision_discovery"]
    m_recall_is = res_baseline_is["recall_discovery"]
    auc_pr_is = np.trapz(
        x=res_baseline_is["recall_discovery"],
        y=res_baseline_is["precision_discovery"],
    )

    # Cubo / Iwelbo with 1e4 samples
    logging.info("Heldout CUBO/IWELBO computation ...")

    n_samples_total = 1e4
    if debug:
        n_samples_total = 200
    n_samples_per_pass = 100 if not do_defensive else counts_eval.sum()
    n_iter = int(n_samples_total / n_samples_per_pass)

    cubo_vals = []
    iwelbo_vals = []
    iwelbo_c_vals = []
    with torch.no_grad():
        for i, tensors in enumerate(tqdm(trainer.test_loader)):
            # Cap the evaluation at the first 20 batches
            if i == 20:
                break
            x, _ = tensors
            log_ratios_batch = []
            log_qc_batch = []
            for _ in tqdm(range(n_iter)):
                out = model.inference(
                    x,
                    temperature=0.5,
                    n_samples=n_samples_per_pass,
                    encoder_key=encoder_eval_name,
                    counts=counts_eval,
                )
                if do_defensive:
                    log_ratio = out["log_ratio"].cpu()
                else:
                    log_ratio = (out["log_px_z"] + out["log_pz2"] +
                                 out["log_pc"] + out["log_pz1_z2"] -
                                 out["log_qz1_x"] - out["log_qc_z1"] -
                                 out["log_qz2_z1"]).cpu()
                log_ratios_batch.append(log_ratio)
                log_qc_batch.append(out["log_qc_z1"].cpu())

            # Concatenation
            log_ratios_batch = torch.cat(log_ratios_batch, dim=1)
            log_qc_batch = torch.cat(log_qc_batch, dim=1)

            # Lower bounds: CUBO and IWELBO, relaxed (non-mixture) case
            n_samples, n_batch = log_ratios_batch.shape
            # CUBO_2 = 0.5 * log mean(w^2)
            cubo_val = 0.5 * (torch.logsumexp(
                2 * log_ratios_batch,
                dim=0,
                keepdim=False,
            ) - np.log(n_samples))

            iwelbo_val = torch.logsumexp(
                log_ratios_batch,
                dim=0,
                keepdim=False,
            ) - np.log(n_samples)

            cubo_vals.append(cubo_val.cpu())
            iwelbo_vals.append(iwelbo_val.cpu())
        cubo_vals = torch.cat(cubo_vals)
        iwelbo_vals = torch.cat(iwelbo_vals)
        # iwelbo_c_vals = torch.cat(iwelbo_c_vals)

    # Entropy
    where9 = train_res["y"] == 9
    probas9 = train_res["qc_z1_all_probas"].mean(0)[where9]
    entropy = (-probas9 * probas9.log()).sum(-1).mean(0)

    where_non9 = train_res["y"] != 9
    y_non9 = train_res["y"][where_non9]
    y_pred_non9 = y_pred[where_non9].argmax(1)
    m_accuracy = accuracy_score(y_non9, y_pred_non9)

    y_pred_non9_is = y_pred_is[where_non9].argmax(1)
    m_accuracy_is = accuracy_score(y_non9, y_pred_non9_is)

    # k_hat
    n_samples_total = 1e4
    if debug:
        n_samples_total = 200
    n_samples_per_pass = 25 if not do_defensive else counts_eval.sum()
    n_iter = int(n_samples_total / n_samples_per_pass)

    # a. Unsupervised case (the qc-weighted mixture variant is shown in Example #3)

    log_ratios = []
    for _ in tqdm(range(n_iter)):
        with torch.no_grad():
            out = model.inference(
                X_SAMPLE,
                temperature=0.5,
                n_samples=n_samples_per_pass,
                encoder_key=encoder_eval_name,
                counts=counts_eval,
            )
        if do_defensive:
            log_ratio = out["log_ratio"].cpu()
        else:
            log_ratio = (out["log_px_z"] + out["log_pz2"] + out["log_pc"] +
                         out["log_pz1_z2"] - out["log_qz1_x"] -
                         out["log_qc_z1"] - out["log_qz2_z1"]).cpu()
        log_ratios.append(log_ratio)
    # Concatenation over samples
    log_ratios = torch.cat(log_ratios, 0)
    _, khats = psislw(log_ratios.T.clone())

    x_samp, y_samp = DATASET.train_dataset[:128]
    where_ = y_samp != 9
    x_samp = x_samp[where_]
    y_samp = y_samp[where_]
    log_ratios = []
    for _ in tqdm(range(n_iter)):
        with torch.no_grad():
            out = model.inference(
                x_samp,
                y_samp,
                temperature=0.5,
                n_samples=n_samples_per_pass,
                encoder_key=encoder_eval_name,
                counts=counts_eval,
            )
        if do_defensive:
            log_ratio = out["log_ratio"].cpu()
        else:
            log_ratio = (
                out["log_px_z"] + out["log_pz2"] + out["log_pc"] +
                out["log_pz1_z2"] - out["log_qz1_x"]
                # - out["log_qc_z1"]
                - out["log_qz2_z1"]).cpu()
        log_ratios.append(log_ratio)
    # Concatenation over samples
    log_ratios = torch.cat(log_ratios, 0)
    _, khats_c_obs = psislw(log_ratios.T.clone())

    res = {
        "IWELBO": iwelbo_vals.mean().item(),
        # "IWELBOC": iwelbo_c_vals.mean().item(),
        "CUBO": cubo_vals.mean().item(),
        "AUC_IS": auc_pr_is,  # area under the precision-recall curve (IS)
        "KHAT": np.array(khats),  # PSIS tail-shape diagnostic
        "AUC": auc_pr,
        # "M_ACCURACY": m_accuracy,
        # "MEAN_AP": m_ap,
        # "MEAN_RECALL": m_recall,
        # "KHATS_C_OBS": khats_c_obs,
        # "M_ACCURACY_IS": m_accuracy_is,
        # "MEAN_AP_IS": m_ap_is,
        # "MEAN_RECALL_IS": m_recall_is,
        # "ENTROPY": entropy,
    }
    return res
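A note on the AUC entries in res: np.trapz over (recall, precision) is the area under the precision-recall curve (AUPRC). Since compute_reject_score is not shown in this listing, here is a self-contained equivalent built with sklearn:

import numpy as np
from sklearn.metrics import precision_recall_curve

y_true = np.array([0, 0, 1, 1, 1, 0, 1])
scores = np.array([0.1, 0.4, 0.35, 0.8, 0.7, 0.2, 0.9])
precision, recall, _ = precision_recall_curve(y_true, scores)
# sklearn returns recall in decreasing order, so the trapezoid sum is
# negative; take the absolute value to get the AUPRC.
auc_pr = abs(np.trapz(x=recall, y=precision))
print(auc_pr)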