Example #1
def per_condition_loss(d, c, conds, model, args, idx=1):
    # Mask the selected condition columns on a clone so `d` stays intact
    # (torch.split returns views, so zeroing tmp1/tmp2 edits the clone).
    condition_tensor = d.clone()
    tmp1, tmp2 = torch.split(condition_tensor,
                             int(condition_tensor.size()[-1] / 2),
                             dim=1)
    for kk in conds:
        tmp1[:, kk], tmp2[:, kk] = 0, 0
    cond_d = torch.cat((tmp1, tmp2), 1)

    # Resample 10 times and average the stochastic outputs
    my_recon_list, my_z_means_list, my_log_var_list = [], [], []
    for _ in range(10):
        recon_batch, z_means, log_var = model(c.clone().cuda(args.gpu_id),
                                              cond_d.clone().cuda(args.gpu_id))
        my_recon_list.append(recon_batch)
        my_z_means_list.append(z_means)
        my_log_var_list.append(log_var)

    recon_batch = torch.mean(torch.stack(my_recon_list), dim=0)
    z_means = torch.mean(torch.stack(my_z_means_list), dim=0)
    log_var = torch.mean(torch.stack(my_log_var_list), dim=0)
    loss_fn = str_to_object(args.loss_fn)

    ELBO_loss, RCL_loss, KLD_loss, _, _ = loss_fn(
        c.cuda(args.gpu_id), recon_batch.cuda(args.gpu_id), z_means, log_var,
        args)
    # idx selects the return signature: 1 -> ELBO only, 5 -> (ELBO, RCL, KLD),
    # anything else -> the resampled model outputs.
    if idx == 1:
        return ELBO_loss.item()
    elif idx == 5:
        return ELBO_loss.item(), RCL_loss.item(), KLD_loss.item()
    else:
        return recon_batch, z_means, log_var
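Note: the masking step works because torch.split returns views, so zeroing columns of tmp1/tmp2 edits the cloned tensor in place and torch.cat reassembles the masked copy. A minimal standalone sketch of just that step (shapes and column indices are illustrative):

import torch

d = torch.arange(12.0).reshape(2, 6)
cond = d.clone()
# split returns views of `cond`, so the zeroing below edits the clone
tmp1, tmp2 = torch.split(cond, cond.size(-1) // 2, dim=1)
for k in [0, 2]:  # condition columns to mask
    tmp1[:, k], tmp2[:, k] = 0, 0
cond_d = torch.cat((tmp1, tmp2), 1)  # masked copy; the original `d` is untouched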
Example #2
def make_dataframe(kl_per_lt, kl_all_lt, selected_features, feature_names,
                   z_means, c, recon_batch, log_var, conds, popped_cond, args):

    z_means_x, z_var_x, z_means_y, z_var_y = [], [], [], []
    all_kl, all_lt = [], []

    loss_fn = str_to_object(args.loss_fn)

    total_ELBO, total_RCL, total_KLD, _, _ = loss_fn(
        c.cuda(args.gpu_id), recon_batch.cuda(args.gpu_id), z_means, log_var,
        args)
    selected_features["ELBO"].append(total_ELBO.item())
    selected_features["RCL"].append(total_RCL.item())
    selected_features["KLD"].append(total_KLD.item())
    if popped_cond != []:
        selected_features["selected_feature_number"].append(str(popped_cond))
        if args.data_type == 'aics_features':
            name = feature_names[popped_cond]
            selected_features["selected_feature_name"].append(name)
        else:
            selected_features["selected_feature_name"].append(None)
    else:
        selected_features["selected_feature_number"].append(None)
        selected_features["selected_feature_name"].append(None)

    for ii in range(z_means.size()[-1]):

        _, rcl_per_lt_temp, kl_per_lt_temp, _, _ = loss_fn(
            c.cuda(args.gpu_id), recon_batch.cuda(args.gpu_id), z_means[:, ii],
            log_var[:, ii], args)

        all_kl = np.append(all_kl, kl_per_lt_temp.item())
        all_lt.append(ii)
        kl_per_lt["num_conds"].append(c.size()[-1] - len(conds))
        if popped_cond != []:
            kl_per_lt["popped_cond"].append(popped_cond)
        else:
            kl_per_lt["popped_cond"].append(None)
        kl_per_lt["latent_dim"].append(ii)
        kl_per_lt["kl_divergence"].append(kl_per_lt_temp.item())
        kl_per_lt['cond_comb'].append(str([i for i in conds]))
    all_kl, all_lt = list(zip(*sorted(zip(all_kl, all_lt))))
    all_kl = list(all_kl)
    all_lt = list(all_lt)

    z_means_x = np.append(z_means_x, z_means[:, all_lt[-1]].data.cpu().numpy())
    z_means_y = np.append(z_means_y, z_means[:, all_lt[-2]].data.cpu().numpy())
    z_var_x = np.append(z_var_x, log_var[:, all_lt[-1]].data.cpu().numpy())
    z_var_y = np.append(z_var_y, log_var[:, all_lt[-2]].data.cpu().numpy())
    kl_all_lt['z_means_x'].append(z_means_x)
    kl_all_lt['z_means_y'].append(z_means_y)
    kl_all_lt['z_var_x'].append(z_var_x)
    kl_all_lt['z_var_y'].append(z_var_y)
    kl_all_lt['num_conds'].append(c.size()[-1] - len(conds))
    kl_all_lt['cond_comb'].append(str([i for i in conds]))

    return kl_per_lt, kl_all_lt, selected_features
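Note: the zip/sorted/zip(*...) idiom above orders latent indices by KL divergence, so all_lt[-1] and all_lt[-2] end up being the two highest-KLD dimensions used for z_means_x/z_means_y. A standalone sketch:

all_kl = [0.3, 1.2, 0.1]
all_lt = [0, 1, 2]
# sort the (kl, index) pairs by kl, then unzip back into two lists
all_kl, all_lt = (list(t) for t in zip(*sorted(zip(all_kl, all_lt))))
print(all_lt[-1], all_lt[-2])  # -> 1 0 (largest and second-largest KLD)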
Example #3
def get_model(model_fn, model_kwargs: Optional[Dict] = None) -> nn.Module:
    model_fn = str_to_object(model_fn)
    try:
        return model_fn(**model_kwargs)
    except TypeError:
        # Retry without the dataloader-only kwargs the constructor rejects.
        filtered = {
            key: value
            for key, value in model_kwargs.items()
            if key not in ("sklearn_data", "projection_dim", "mask_percentage")
        }
        return model_fn(**filtered)
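An alternative to the hard-coded key list is filtering by the constructor's declared parameters with inspect; a sketch with a hypothetical TinyModel stand-in (not part of CVAE_testbed):

import inspect

class TinyModel:
    def __init__(self, hidden=8):
        self.hidden = hidden

kwargs = {"hidden": 16, "sklearn_data": None, "mask_percentage": 0.5}
# keep only the kwargs the constructor actually declares
accepted = set(inspect.signature(TinyModel).parameters)
model = TinyModel(**{k: v for k, v in kwargs.items() if k in accepted})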
Example #4
def make_plot_encoding_greedy(args: argparse.Namespace,
                              model,
                              df: pd.DataFrame,
                              c,
                              d,
                              feature_names=None,
                              save=True,
                              proj_matrix=None) -> None:
    """
    c and d are X_test and C_test
    """
    sns.set_context("talk")
    path_save_dir = Path(args.path_save_dir)
    vis_enc = str_to_object(
        "CVAE_testbed.metrics.greedy_visualize_encoder.GreedyVisualizeEncoder")
    try:
        conds = list(range(args.model_kwargs["dec_layers"][-1][-1]))
    except TypeError:
        conds = list(range(args.model_kwargs["dec_layers"][-1]))

    kl_per_lt, kl_all_lt, selected_features, first_features = vis_enc(
        args,
        model,
        conds,
        c[-1, :].clone(),
        d[-1, :].clone(),
        kl_per_lt=None,
        kl_all_lt=None,
        selected_features=None,
        feature_names=feature_names)

    kl_per_lt = pd.DataFrame(kl_per_lt)
    kl_all_lt = pd.DataFrame(kl_all_lt)
    selected_features = pd.DataFrame(selected_features)
    first_features = pd.DataFrame(first_features)

    if save is True:
        path_csv = path_save_dir / Path("kl_per_lt.csv")
        kl_per_lt.to_csv(path_csv)
        LOGGER.info(f"Saved: {path_csv}")

        path_csv = path_save_dir / Path("kl_all_lt.csv")
        kl_all_lt.to_csv(path_csv)
        LOGGER.info(f"Saved: {path_csv}")

        path_csv = path_save_dir / Path("selected_features.csv")
        selected_features.to_csv(path_csv)
        LOGGER.info(f"Saved: {path_csv}")

        path_csv = path_save_dir / Path("first_features.csv")
        first_features.to_csv(path_csv)
        LOGGER.info(f"Saved: {path_csv}")

    fig, ax1 = plt.subplots(1, 1, figsize=(7 * 1, 4))
    sns.lineplot(ax=ax1,
                 data=kl_per_lt,
                 x='latent_dim',
                 y='kl_divergence',
                 estimator='mean')

    if save is True:
        path_save_fig = path_save_dir / Path("greedy_elbo_kld_rcl_dims.png")
        fig.savefig(path_save_fig, bbox_inches="tight")
        LOGGER.info(f"Saved: {path_save_fig}")

    def _plot_elbo_rcl(data, x_col, ax):
        # Line + scatter for ELBO and RCL against the chosen x column.
        line_fig = sns.lineplot(data=data, ax=ax, x=x_col, y='ELBO',
                                label='ELBO', sort=False)
        sns.scatterplot(data=data, ax=ax, x=x_col, y='ELBO', s=100, color=".2")
        sns.lineplot(data=data, ax=ax, x=x_col, y='RCL',
                     label='RCL', sort=False)
        sns.scatterplot(data=data, ax=ax, x=x_col, y='RCL', s=100, color=".2")
        return line_fig

    fig2, ax = plt.subplots(1, 1, figsize=(7 * 8, 4))
    first_features.sort_values(by='ELBO', ascending=False, inplace=True)

    # Fall back to feature numbers when no feature names are available.
    if all(pd.isna(first_features['selected_feature_name'])):
        bar_fig = _plot_elbo_rcl(first_features, 'selected_feature_number', ax)
    else:
        bar_fig = _plot_elbo_rcl(first_features, 'selected_feature_name', ax)

    for item in bar_fig.get_xticklabels():
        item.set_rotation(45)

    ax.set_title('ELBO per selected first feature')
    ax.set_xlabel('Selected feature')
    ax.set_ylabel('ELBO')

    if save is True:
        path_save_fig = path_save_dir / Path(
            "greedy_barplots_first_selection.png")
        fig2.savefig(path_save_fig, bbox_inches="tight")
        LOGGER.info(f"Saved: {path_save_fig}")

    fig2, ax = plt.subplots(1, 1, figsize=(7 * 8, 4))

    selected_features.sort_values(by='ELBO', ascending=False, inplace=True)

    if all(pd.isna(selected_features['selected_feature_name'])):
        bar_fig = _plot_elbo_rcl(selected_features,
                                 'selected_feature_number', ax)
    else:
        bar_fig = _plot_elbo_rcl(selected_features,
                                 'selected_feature_name', ax)

    for item in bar_fig.get_xticklabels():
        item.set_rotation(45)

    ax.set_title('ELBO per selected feature')
    ax.set_xlabel('Selected feature')
    ax.set_ylabel('ELBO')
    if save is True:
        path_save_fig = path_save_dir / Path("greedy_barplots.png")
        fig2.savefig(path_save_fig, bbox_inches="tight")
        LOGGER.info(f"Saved: {path_save_fig}")
Example #5
def test(
    args,
    epoch,
    loss_fn,
    X_test,
    C_test,
    Cond_indices_test,
    batch_size,
    model,
    optimizer,
    gpu_id,
    model_kwargs,
):

    model.eval()
    test_loss, rcl_loss, kld_loss = 0, 0, 0
    rcl_per_condition_loss, kld_per_condition_loss = (
        torch.zeros(X_test.size()[-1] + 1),
        torch.zeros(X_test.size()[-1] + 1),
    )
    batch_rcl, batch_kld = (
        torch.empty([0]),
        torch.empty([0]),
    )
    batch_length = 0

    # Resolve the loss once; args.loss_fn overrides the loss_fn argument.
    loss_fn = str_to_object(args.loss_fn)

    with torch.no_grad():
        for j in range(len(X_test)):
            c, d, cond_labels = X_test[j], C_test[j], Cond_indices_test[j]

            recon_batch, mu, log_var = model(c.cuda(gpu_id), d.cuda(gpu_id))
            loss, rcl, kld, rcl_per_element, kld_per_element = loss_fn(
                c.cuda(gpu_id), recon_batch.cuda(gpu_id), mu, log_var, args)
            test_loss += loss.item()
            rcl_loss += rcl.item()
            kld_loss += kld.item()
            for jj, ii in enumerate(torch.unique(cond_labels)):

                this_cond_positions = cond_labels == ii
                # Accumulate per-batch tensors only when every condition
                # level appears in this batch
                if len(torch.unique(cond_labels)) == c.size()[-1] + 1:
                    batch_rcl = torch.cat(
                        [
                            batch_rcl.cuda(gpu_id),
                            torch.sum(rcl_per_element[this_cond_positions],
                                      dim=0).view(1, -1),
                        ],
                        0,
                    )
                    batch_kld = torch.cat(
                        [
                            batch_kld.cuda(gpu_id),
                            torch.sum(kld_per_element[this_cond_positions],
                                      dim=0).view(1, -1),
                        ],
                        0,
                    )
                    batch_length += 1

                this_cond_rcl = torch.sum(rcl_per_element[this_cond_positions])
                this_cond_kld = torch.sum(kld_per_element[this_cond_positions])
                rcl_per_condition_loss[jj] += this_cond_rcl.item()
                kld_per_condition_loss[jj] += this_cond_kld.item()

    num_batches = len(X_test)

    print("====> Epoch: {} Test losses: {:.4f}".format(epoch, test_loss /
                                                       num_batches))
    print("====> RCL loss: {:.4f}".format(rcl_loss / num_batches))
    print("====> KLD loss: {:.4f}".format(kld_loss / num_batches))

    return (
        test_loss / num_batches,
        rcl_loss / num_batches,
        kld_loss / num_batches,
        rcl_per_condition_loss / num_batches,
        kld_per_condition_loss / num_batches,
        batch_rcl,
        batch_kld,
        batch_length,
    )
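Note: the per-condition bookkeeping above groups rows by their condition label with a boolean mask, then sums the per-element losses within each group. A standalone sketch (random stand-in values):

import torch

cond_labels = torch.tensor([0, 1, 0, 2, 1])
rcl_per_element = torch.rand(5, 3)  # stand-in per-element losses
for jj, ii in enumerate(torch.unique(cond_labels)):
    mask = cond_labels == ii
    # summed loss over all rows with condition label `ii`
    print(int(ii), torch.sum(rcl_per_element[mask], dim=0))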
Example #6
def train_model():
    """
    Trains a model
    """
    tic = time.time()
    args = get_args()

    path_save_dir = Path(args.path_save_dir)
    if path_save_dir.exists():
        raise ValueError(f"Save directory already exists! ({path_save_dir})")
    path_save_dir.mkdir(parents=True)

    logging.getLogger().addHandler(
        logging.FileHandler(path_save_dir / Path("run.log"), mode="w"))
    save_args(args, path_save_dir / Path("training_options.json"))

    device = (torch.device("cuda", args.gpu_id)
              if torch.cuda.is_available() else torch.device("cpu"))
    LOGGER.info(f"Using device: {device}")

    feature_names = None
    proj_matrix = None

    if args.data_type == "mnist":
        load_data = str_to_object(args.dataloader)
        train_iterator, test_iterator = load_data(args.batch_size,
                                                  args.model_kwargs)
    elif args.data_type == "aics_features":
        load_data = str_to_object(args.dataloader)
        test_instance = load_data(
            args.num_batches,
            args.batch_size,
            args.model_kwargs,
            corr=False,
            train=True,
            mask=False,
        )
        X_train, C_train, Cond_indices_train = test_instance.get_train_data()
        X_test, C_test, Cond_indices_test = test_instance.get_test_data()
        feature_names = test_instance.get_feature_names()
        this_dataloader_color = test_instance.get_color()
    elif args.data_type == "synthetic":
        if "mask_percentage" in args.model_kwargs:
            mask_bool = True
        else:
            mask_bool = False
        load_data = str_to_object(args.dataloader)
        if "projection_dim" in args.model_kwargs:
            X_train, C_train, Cond_indices_train, proj_matrix = load_data(
                args.num_batches,
                args.batch_size,
                args.model_kwargs,
                corr=False,
                train=True,
                mask=mask_bool,
            ).get_all_items()
            path_csv = path_save_dir / Path("projection_options.pt")
            with path_csv.open("wb") as fo:
                torch.save(proj_matrix, fo)
            LOGGER.info(f"Saved: {path_csv}")
            test_instance = load_data(
                args.num_batches,
                args.batch_size,
                args.model_kwargs,
                corr=False,
                train=False,
                P=proj_matrix,
                mask=mask_bool,
            )
            X_test, C_test, Cond_indices_test = test_instance.get_all_items()
            this_dataloader_color = test_instance.get_color()
        else:
            X_train, C_train, Cond_indices_train, _ = load_data(
                args.num_batches,
                args.batch_size,
                args.model_kwargs,
                corr=False,
                train=True,
                mask=mask_bool,
            ).get_all_items()
            test_instance = load_data(
                args.num_batches,
                args.batch_size,
                args.model_kwargs,
                corr=False,
                train=False,
                mask=mask_bool,
            )
            X_test, C_test, Cond_indices_test = test_instance.get_all_items()
            this_dataloader_color = test_instance.get_color()

    model = get_model(args.model_fn, args.model_kwargs).to(device)
    opt = optim.Adam(model.parameters(), lr=args.lr)
    loss_fn = str_to_object(args.loss_fn)

    make_plot_encoding_greedy = str_to_object(
        "CVAE_testbed.utils.greedy_encoding_plots.make_plot_encoding_greedy")
    make_plot_encoding = str_to_object(
        "CVAE_testbed.utils.encoding_plots.make_plot_encoding")
    make_plot = str_to_object(
        "CVAE_testbed.utils.loss_and_image_plots.make_plot")
    pca = str_to_object("CVAE_testbed.utils.pca.get_PCA_features")

    make_fid_plot = str_to_object("CVAE_testbed.utils.FID_score.make_plot_FID")

    if args.data_type == "mnist":
        run = str_to_object(
            "CVAE_testbed.run_models.run_test_train.run_test_train")

        stats = run(
            model,
            opt,
            loss_fn,
            device,
            args.batch_size,
            train_iterator,
            test_iterator,
            args.n_epochs,
            args.model_kwargs,
        )
    elif args.data_type == "synthetic" or args.data_type == 'aics_features':
        run = str_to_object(
            "CVAE_testbed.run_models.run_synthetic_test.run_synthetic")
        stats, stats_per_dim = run(
            args,
            X_train,
            C_train,
            Cond_indices_train,
            X_test,
            C_test,
            Cond_indices_test,
            args.n_epochs,
            args.loss_fn,
            model,
            opt,
            args.batch_size,
            args.gpu_id,
            args.model_kwargs,
        )
        # The enclosing branch already guarantees data_type is 'synthetic'
        # or 'aics_features', so plot the encodings unconditionally here.
        # First load non-shuffled data
        if proj_matrix is not None:
            this_dataloader = load_data(args.num_batches,
                                        args.batch_size,
                                        args.model_kwargs,
                                        shuffle=False,
                                        P=proj_matrix,
                                        train=False)
        else:
            this_dataloader = load_data(args.num_batches,
                                        args.batch_size,
                                        args.model_kwargs,
                                        shuffle=False,
                                        train=False)
        X_non_shuffled, C_non_shuffled, _ = this_dataloader.get_all_items()

        # Now check encoding
        make_plot_encoding(args, model, stats, X_non_shuffled.clone(),
                           C_non_shuffled.clone(), this_dataloader_color,
                           True, proj_matrix)
        pca_dataframe = pca(args, this_dataloader, True)
        make_plot_encoding_greedy(args, model, stats,
                                  X_non_shuffled.clone(),
                                  C_non_shuffled.clone(), feature_names,
                                  True, proj_matrix)
        try:
            make_fid_plot(args, model, X_non_shuffled.clone(),
                          C_non_shuffled.clone())
        except Exception:
            # FID plotting is best-effort
            pass

    this_model = ModelLoader(model, path_save_dir)
    this_model.save_model()
    path_csv = path_save_dir / Path("costs.csv")
    stats.to_csv(path_csv)
    LOGGER.info(f"Saved: {path_csv}")

    path_csv = path_save_dir / Path("costs_per_dimension.csv")
    stats_per_dim.to_csv(path_csv)
    LOGGER.info(f"Saved: {path_csv}")

    make_plot(stats, stats_per_dim, path_save_dir, args)
    LOGGER.info(f"Elapsed time: {time.time() - tic:.2f}")
    print("saved:", path_save_dir)
Example #7
def make_plot_FID(args: argparse.Namespace, model, X_test, C_test, save=True):

    # Flatten the batch dimension so each row is one sample
    X_test = X_test.view(-1, X_test.size()[-1])
    C_test = C_test.view(-1, C_test.size()[-1])

    sns.set_context("talk")
    path_save_dir = Path(args.path_save_dir)
    compute_fid = str_to_object(
        "CVAE_testbed.run_models.generative_metric.compute_generative_metric_synthetic"
    )

    # ADD YOUR PATH HERE
    csv_greedy_features = pd.read_csv(
        '~/Github/cookiecutter/CVAE_testbed/scripts' + args.path_save_dir[1:] +
        '/selected_features.csv')

    conds = [
        i for i in csv_greedy_features['selected_feature_number']
        if not math.isnan(i)
    ]

    fid_data = {'num_conds': [], 'fid': []}

    for i in range(len(conds) + 1):

        tmp1, tmp2 = torch.split(C_test.clone(),
                                 int(C_test.clone().size()[-1] / 2),
                                 dim=1)
        for kk in conds:
            tmp1[:, int(kk)], tmp2[:, int(kk)] = 0, 0
        cond_d = torch.cat((tmp1, tmp2), 1)

        try:
            this_fid = compute_fid(X_test.clone(), cond_d.clone(), args, model,
                                   conds)
        except Exception:
            # FID can fail for degenerate batches; record NaN instead
            this_fid = np.nan
        LOGGER.info(f"FID: {this_fid}")

        fid_data['num_conds'].append(X_test.size()[-1] - len(conds))
        fid_data['fid'].append(this_fid)

        if conds:
            conds.pop()

    fid_data = pd.DataFrame(fid_data)

    fig, ax = plt.subplots(1, 1, figsize=(7 * 4, 5))
    sns.lineplot(ax=ax, data=fid_data, x='num_conds', y='fid')
    sns.scatterplot(ax=ax,
                    data=fid_data,
                    x='num_conds',
                    y='fid',
                    s=100,
                    color=".2")

    if save is True:
        path_csv = path_save_dir / Path("fid_data.csv")
        fid_data.to_csv(path_csv)
        LOGGER.info(f"Saved: {path_csv}")

        path_save_fig = path_save_dir / Path("fid_score.png")
        fig.savefig(path_save_fig, bbox_inches="tight")
        LOGGER.info(f"Saved: {path_save_fig}")
Example #8
def visualize_encoder_synthetic(args,
                                model,
                                conds,
                                c,
                                d,
                                kl_per_lt=None,
                                kl_vs_rcl=None):

    z_means_x, z_means_y = [], []
    z_var_x, z_var_y = [], []
    model.cuda(args.gpu_id)

    with torch.no_grad():
        if kl_per_lt is None:
            kl_per_lt = {
                "latent_dim": [],
                "kl_divergence": [],
                "num_conds": [],
            }

        if kl_vs_rcl is None:
            kl_vs_rcl = {"num_conds": [], "KLD": [], "RCL": [], "ELBO": []}
        all_kl, all_lt = [], []

        # torch.split returns views, so this zeroes columns of `d` in place;
        # callers pass a clone (e.g. d[-1, :].clone()).
        tmp1, tmp2 = torch.split(d, int(d.size()[-1] / 2), dim=1)
        for kk in conds:
            tmp1[:, kk], tmp2[:, kk] = 0, 0
        cond_d = torch.cat((tmp1, tmp2), 1)

        # Resample 10 times and average the stochastic outputs
        my_recon_list, my_z_means_list, my_log_var_list = [], [], []
        for _ in range(10):
            recon_batch, z_means, log_var = model(
                c.clone().cuda(args.gpu_id),
                cond_d.clone().cuda(args.gpu_id))
            my_recon_list.append(recon_batch)
            my_z_means_list.append(z_means)
            my_log_var_list.append(log_var)

        recon_batch = torch.mean(torch.stack(my_recon_list), dim=0)
        z_means = torch.mean(torch.stack(my_z_means_list), dim=0)
        log_var = torch.mean(torch.stack(my_log_var_list), dim=0)

        loss_fn = str_to_object(args.loss_fn)

        elbo_loss_total, rcl_per_lt_temp_total, kl_per_lt_temp_total, _, _ = loss_fn(
            c.cuda(args.gpu_id), recon_batch.cuda(args.gpu_id), z_means,
            log_var, args)
        kl_vs_rcl['num_conds'].append(c.size()[-1] - len(conds))
        kl_vs_rcl['KLD'].append(kl_per_lt_temp_total.item())
        kl_vs_rcl['RCL'].append(rcl_per_lt_temp_total.item())
        kl_vs_rcl['ELBO'].append(elbo_loss_total.item())

        for ii in range(z_means.size()[-1]):
            elbo_loss, rcl_per_lt_temp, kl_per_lt_temp, _, _ = loss_fn(
                c.cuda(args.gpu_id), recon_batch.cuda(args.gpu_id),
                z_means[:, ii], log_var[:, ii], args)

            all_kl = np.append(all_kl, kl_per_lt_temp.item())
            all_lt.append(ii)
            kl_per_lt["num_conds"].append(c.size()[-1] - len(conds))
            kl_per_lt["latent_dim"].append(ii)
            kl_per_lt["kl_divergence"].append(kl_per_lt_temp.item())
        all_kl, all_lt = list(zip(*sorted(zip(all_kl, all_lt))))
        all_kl = list(all_kl)
        all_lt = list(all_lt)

        z_means_x = np.append(z_means_x,
                              z_means[:, all_lt[-1]].data.cpu().numpy())
        z_means_y = np.append(z_means_y,
                              z_means[:, all_lt[-2]].data.cpu().numpy())
        z_var_x = np.append(z_var_x, log_var[:, all_lt[-1]].data.cpu().numpy())
        z_var_y = np.append(z_var_y, log_var[:, all_lt[-2]].data.cpu().numpy())
    return z_means_x, z_means_y, kl_per_lt, z_var_x, z_var_y, kl_vs_rcl
Example #9
def train(
    args,
    epoch,
    loss_fn,
    X_train,
    C_train,
    Cond_indices_train,
    batch_size,
    model,
    optimizer,
    gpu_id,
    model_kwargs,
):
    model.train()
    train_loss, rcl_loss, kld_loss = 0, 0, 0
    rcl_per_condition_loss, kld_per_condition_loss = (
        torch.zeros(X_train.size()[-1] + 1),
        torch.zeros(X_train.size()[-1] + 1),
    )
    batch_rcl, batch_kld = (
        torch.empty([0]),
        torch.empty([0]),
    )
    batch_length = 0

    for j in range(len(X_train)):
        optimizer.zero_grad()
        c, d, cond_labels = X_train[j], C_train[j], Cond_indices_train[j]

        if args.model_fn == 'CVAE_testbed.models.CVAE_baseline_2.CVAE':
            recon_batch, z1, mu, log_var, mu2, log_var2, z2, z1_prior = model(
                c.cuda(gpu_id), d.cuda(gpu_id))

            loss_fn = str_to_object(args.loss_fn)
            loss, rcl, kld, rcl_per_element, kld_per_element = loss_fn(
                c.cuda(gpu_id), recon_batch.cuda(gpu_id), z1, z1_prior, z2,
                [mu, log_var], [mu2, log_var2], args)
        else:
            recon_batch, mu, log_var = model(c.cuda(gpu_id), d.cuda(gpu_id))

            loss_fn = str_to_object(args.loss_fn)
            loss, rcl, kld, rcl_per_element, kld_per_element = loss_fn(
                c.cuda(gpu_id), recon_batch.cuda(gpu_id), mu, log_var, args)
        loss.backward()
        train_loss += loss.item()
        rcl_loss += rcl.item()
        kld_loss += kld.item()

        for jj, ii in enumerate(torch.unique(cond_labels)):

            this_cond_positions = cond_labels == ii
            batch_rcl = torch.cat(
                [
                    batch_rcl.cuda(gpu_id),
                    torch.sum(rcl_per_element[this_cond_positions],
                              dim=0).view(1, -1),
                ],
                0,
            )
            batch_kld = torch.cat(
                [
                    batch_kld.cuda(gpu_id),
                    torch.sum(kld_per_element[this_cond_positions],
                              dim=0).view(1, -1),
                ],
                0,
            )
            batch_length += 1

            this_cond_rcl = torch.sum(rcl_per_element[this_cond_positions])
            this_cond_kld = torch.sum(kld_per_element[this_cond_positions])

            rcl_per_condition_loss[jj] += this_cond_rcl.item()
            kld_per_condition_loss[jj] += this_cond_kld.item()
        optimizer.step()

    num_batches = len(X_train)
    print("====> Epoch: {} Train loss: {:.4f}".format(epoch, train_loss /
                                                      num_batches))
    print("====> Train RCL loss: {:.4f}".format(rcl_loss / num_batches))
    print("====> Train KLD loss: {:.4f}".format(kld_loss / num_batches))

    return (
        train_loss / num_batches,
        rcl_loss / num_batches,
        kld_loss / num_batches,
        rcl_per_condition_loss / num_batches,
        kld_per_condition_loss / num_batches,
        batch_rcl,
        batch_kld,
        batch_length,
    )
Example #10
def make_plot_encoding(args: argparse.Namespace,
                       model,
                       df: pd.DataFrame,
                       c,
                       d,
                       this_dataloader_color=None,
                       save=True,
                       proj_matrix=None) -> None:
    """
    c and d are X_test and C_test
    """
    sns.set_context("talk")
    path_save_dir = Path(args.path_save_dir)
    vis_enc = str_to_object(
        "CVAE_testbed.metrics.visualize_encoder.visualize_encoder_synthetic"
    )
    try:
        conds = list(range(args.model_kwargs["dec_layers"][-1][-1]))
    except TypeError:
        conds = list(range(args.model_kwargs["dec_layers"][-1]))

    try:
        latent_dims = args.model_kwargs["vae_layers"][-1][-1]
    except (KeyError, TypeError):
        latent_dims = args.model_kwargs["enc_layers"][-1]

    fig, ax = plt.subplots(1, 1, figsize=(7, 5))
    sns.lineplot(ax=ax, data=df, x="epoch", y="total_train_ELBO")
    sns.lineplot(ax=ax, data=df, x="epoch", y="total_test_ELBO")
    ax.set_ylim([0, df.total_test_ELBO.quantile(0.95)])
    ax.legend(["Train loss", "Test loss"])
    ax.set_ylabel('Loss')
    ax.set_title("Actual ELBO (no beta) vs epoch")

    if save is True:
        path_save_fig = path_save_dir / Path("ELBO.png")
        fig.savefig(path_save_fig, bbox_inches="tight")
        LOGGER.info(f"Saved: {path_save_fig}")

    fig, (ax1, ax, ax2, ax3) = plt.subplots(1, 4, figsize=(7 * 4, 5))
    fig2 = plt.figure(figsize=(12, 10))
    bax = brokenaxes(xlims=((0, latent_dims - 50),
                            (latent_dims - 4, latent_dims)),
                     hspace=0.15)

    if "total_train_losses" in df.columns:
        sns.lineplot(ax=ax1, data=df, x="epoch", y="total_train_losses")
        ax1.set_xlabel("Epoch")
        ax1.set_ylabel("Loss")
    if "total_test_losses" in df.columns:
        sns.lineplot(ax=ax1, data=df, x="epoch", y="total_test_losses")
        ax1.set_xlabel("Epoch")
        ax1.set_ylabel("Loss")
        ax1.set_ylim([0, df.total_test_losses.quantile(0.95)])
        ax1.legend(["Train loss", "Test loss"])
    ax1.set_title("ELBO (beta*KLD + RCL) vs epoch")

    try:
        this_kwargs = args.model_kwargs["dec_layers"][-1][-1]
    except:
        this_kwargs = args.model_kwargs["dec_layers"][-1]

    conds = list(range(this_kwargs))

    if args.post_plot_kwargs["latent_space_colorbar"] == "yes":
        color = this_dataloader_color
    else:
        color = None

    kl_per_lt, kl_vs_rcl = None, None
    for i in range(len(conds) + 1):
        # Each pass supplies one more condition column (conds.pop() below);
        # vis_enc accumulates kl_per_lt / kl_vs_rcl across passes.
        z_means_x, z_means_y, kl_per_lt, _, _, kl_vs_rcl = vis_enc(
            args,
            model,
            conds,
            c[-1, :].clone(),
            d[-1, :].clone(),
            kl_per_lt=kl_per_lt,
            kl_vs_rcl=kl_vs_rcl)
        ax.scatter(z_means_x, z_means_y, marker=".", s=30, label=str(i))
        if color is not None:
            colormap_plot(path_save_dir,
                          c[-1, :].clone(),
                          z_means_x,
                          z_means_y,
                          color,
                          conds)
        if conds:
            conds.pop()

    kl_per_lt = pd.DataFrame(kl_per_lt)
    kl_vs_rcl = pd.DataFrame(kl_vs_rcl)

    if save is True:
        path_csv = path_save_dir / Path("encoding_kl_per_lt.csv")  
        kl_per_lt.to_csv(path_csv)
        LOGGER.info(f"Saved: {path_csv}")
        path_csv = path_save_dir / Path("encoding_kl_vs_rcl.csv")
        kl_vs_rcl.to_csv(path_csv)
        LOGGER.info(f"Saved: {path_csv}")

    ax.set_title("Latent space")
    ax.legend()

    conds = list(range(this_kwargs))

    for i in range(len(conds) + 1):
        num_conds = c.size()[-1] - len(conds)
        tmp = kl_per_lt.loc[kl_per_lt["num_conds"] == num_conds]
        tmp_2 = kl_vs_rcl.loc[kl_vs_rcl["num_conds"] == num_conds]
        tmp = tmp.sort_values(by="kl_divergence", ascending=False)
        tmp = tmp.reset_index(drop=True)
        x = tmp.index.values
        y = tmp["kl_divergence"].values
        sns.lineplot(ax=ax2,
                     data=tmp,
                     x=tmp.index,
                     y="kl_divergence",
                     label=str(i),
                     legend='brief')
        bax.plot(x, y)
        ax3.scatter(tmp_2['RCL'].mean(), tmp_2['KLD'].mean(), label=str(i))
        if conds:
            conds.pop()

    ax2.set_xlabel("Latent dimension")
    ax2.set_ylabel("KLD")
    ax2.set_title("KLD per latent dim")
    ax3.set_xlabel("MSE")
    ax3.set_ylabel("KLD")
    ax3.set_title("MSE vs KLD")
    # bax.legend(loc="best")
    bax.set_xlabel("Latent dimension")
    bax.set_ylabel("KLD")
    bax.set_title("KLD per latent dim")

    if this_kwargs > 30:
        ax.get_legend().remove()
        ax2.get_legend().remove()

    if save is True:
        path_save_fig = path_save_dir / Path("encoding_test_plots.png")
        fig.savefig(path_save_fig, bbox_inches="tight")
        LOGGER.info(f"Saved: {path_save_fig}")

    if save is True:
        path_save_fig = path_save_dir / Path("brokenaxes_KLD_per_dim.png")
        fig2.savefig(path_save_fig, bbox_inches="tight")
        LOGGER.info(f"Saved: {path_save_fig}")