Example #1
0
def visualize_trained_network_results(data_dict, z_dim=100):
    """Save diagnostic plots for per-cell-type CVAE models restored from disk.

    For each cell type, restores the CVAE stored under
    ``../models/CVAE/{name}/{cell_type}/{z_dim}/cvae``. It then predicts cells
    (labelled ``'predicted'``) from the cell type's ``source_key`` cells and
    writes the following plots under ``../results/CVAE/...``:

    * mean/variance regression plots of prediction vs. ``target_key``;
    * UMAPs of the training data;
    * UMAPs of the latent and MMD-layer embeddings of the whole dataset,
      computed with the true labels and with all-ones labels;
    * a violin plot of the top-ranked gene.

    Args:
        data_dict: Dataset description. Keys read: ``name``, ``source_key``,
            ``target_key``, ``cell_type`` (obs column holding cell-type
            labels) and optionally ``spec_cell_types`` (a non-empty list
            restricts processing to those cell types).
        z_dim: Latent size; used to build the model and results paths.
    """
    import matplotlib as mpl

    plt.close("all")
    data_name = data_dict.get('name', None)
    source_key = data_dict.get('source_key', None)
    target_key = data_dict.get('target_key', None)
    cell_type_key = data_dict.get("cell_type", None)

    data = sc.read(f"../data/{data_name}/train_{data_name}.h5ad")
    # Densify once, before any subsets are taken, so every iteration below
    # works on the same dense matrix (previously only iterations after the
    # first did, because the conversion happened inside the loop).
    if sparse.issparse(data.X):
        data.X = data.X.A

    cell_types = data.obs[cell_type_key].unique().tolist()

    # Previously `spec_cell_type is not []`: an identity test against a new
    # list, which is always True, so a missing key replaced cell_types with
    # None and the loop below raised TypeError.
    spec_cell_type = data_dict.get("spec_cell_types", None)
    if spec_cell_type:
        cell_types = spec_cell_type

    color = ['condition', cell_type_key]

    for cell_type in cell_types:
        path_to_save = f"../results/CVAE/{data_name}/{cell_type}/{z_dim}/{source_key} to {target_key}/Visualizations/"
        os.makedirs(path_to_save, exist_ok=True)
        sc.settings.figdir = os.path.abspath(path_to_save)

        is_cell_type = data.obs[cell_type_key] == cell_type
        # Everything except the (target condition, this cell type) cells.
        train_data = data[~((data.obs['condition'] == target_key) &
                            is_cell_type)].copy()
        cell_type_adata = data[is_cell_type].copy()

        network = trvae.CVAE(
            x_dimension=data.shape[1],
            z_dimension=z_dim,
            model_path=f"../models/CVAE/{data_name}/{cell_type}/{z_dim}/cvae")
        network.restore_model()

        feed_data = data.X
        train_labels, _ = trvae.label_encoder(data)
        fake_labels = np.ones(train_labels.shape)

        # Embeddings of the whole dataset, keyed by the suffix of the UMAP
        # file they are saved to (dict order == original computation order).
        embeddings = {
            "latent_with_true_labels":
                network.to_latent(feed_data, train_labels),
            "latent_with_fake_labels":
                network.to_latent(feed_data, fake_labels),
            "mmd_latent_with_true_labels":
                network.to_mmd_layer(feed_data, train_labels),
            "mmd_latent_with_fake_labels":
                network.to_mmd_layer(feed_data, fake_labels),
        }

        cell_type_ctrl = cell_type_adata[
            cell_type_adata.obs['condition'] == source_key].copy()
        print(cell_type_ctrl.shape, cell_type_adata.shape)

        pred_celltypes = network.predict(
            cell_type_ctrl, labels=np.ones((cell_type_ctrl.shape[0], 1)))
        pred_adata = anndata.AnnData(X=pred_celltypes)
        pred_adata.obs['condition'] = ['predicted'] * pred_adata.shape[0]
        pred_adata.var = cell_type_adata.var

        sc.tl.rank_genes_groups(cell_type_adata,
                                groupby="condition",
                                n_genes=100,
                                method="wilcoxon")
        ranked_names = cell_type_adata.uns["rank_genes_groups"]["names"]
        if data_name == "pbmc":
            top_100_genes = ranked_names[target_key].tolist()
            gene_list = top_100_genes[:10]
        else:
            top_50_down_genes = ranked_names[source_key].tolist()
            top_50_up_genes = ranked_names[target_key].tolist()
            top_100_genes = top_50_up_genes + top_50_down_genes
            gene_list = top_50_down_genes[:5] + top_50_up_genes[:5]

        cell_type_adata = cell_type_adata.concatenate(pred_adata)

        for plot_fn, stat in ((trvae.plotting.reg_mean_plot, "mean"),
                              (trvae.plotting.reg_var_plot, "var")):
            plot_fn(
                cell_type_adata,
                top_100_genes=top_100_genes,
                gene_list=gene_list,
                condition_key='condition',
                axis_keys={
                    "x": 'predicted',
                    'y': target_key
                },
                labels={
                    'x': 'pred stim',
                    'y': 'real stim'
                },
                legend=False,
                fontsize=20,
                textsize=14,
                title=cell_type,
                path_to_save=os.path.join(
                    path_to_save,
                    f'rcvae_reg_{stat}_{data_name}_{cell_type}.pdf'))

        # Reset matplotlib styling before the scanpy plots (presumably the
        # regression plots above alter rcParams -- not visible here).
        mpl.rcParams.update(mpl.rcParamsDefault)

        sc.pp.neighbors(train_data)
        sc.tl.umap(train_data)
        sc.pl.umap(train_data,
                   color=color,
                   save=f'_{data_name}_{cell_type}_train_data',
                   show=False)

        for suffix, embedding in embeddings.items():
            embedding_adata = sc.AnnData(X=embedding)
            embedding_adata.obs['condition'] = data.obs['condition'].values
            embedding_adata.obs[cell_type_key] = data.obs[
                cell_type_key].values
            sc.pp.neighbors(embedding_adata)
            sc.tl.umap(embedding_adata)
            sc.pl.umap(embedding_adata,
                       color=color,
                       save=f"_{data_name}_{cell_type}_{suffix}",
                       show=False)

        sc.pl.violin(cell_type_adata,
                     keys=top_100_genes[0],
                     groupby='condition',
                     save=f"_{data_name}_{cell_type}_{top_100_genes[0]}",
                     show=False)

        plt.close("all")
Example #2
0
def visualize_trained_network_results_multimodal(data_dict, z_dim=100):
    """Save UMAPs of the data and of a restored trVAE model's embeddings.

    Restores the trVAE stored under ``../models/RCVAE/{name}/{z_dim}/``. It
    embeds the whole dataset into the latent space (with true labels and with
    all-ones labels) and into the MMD layer (with ``feed_fake`` False and
    True). UMAPs coloured by ``condition`` are written to
    ``../results/RCVAE/...``.

    Args:
        data_dict: Dataset description. Keys read: ``name``, ``source_key``
            and ``target_key`` (the last two only name the results folder).
        z_dim: Latent size; used to build the model and results paths.
    """
    import matplotlib as mpl

    plt.close("all")
    data_name = data_dict.get('name', None)
    source_key = data_dict.get('source_key', None)
    target_key = data_dict.get('target_key', None)

    data = sc.read(f"../data/{data_name}/train_{data_name}.h5ad")
    path_to_save = f"../results/RCVAE/{data_name}/{z_dim}/{source_key} to {target_key}/Visualizations/"
    os.makedirs(path_to_save, exist_ok=True)
    sc.settings.figdir = os.path.abspath(path_to_save)

    network = trvae.trVAE(
        x_dimension=data.shape[1],
        z_dimension=z_dim,
        model_path=f"../models/RCVAE/{data_name}/{z_dim}/",
    )
    network.restore_model()
    if sparse.issparse(data.X):
        data.X = data.X.A

    feed_data = data.X
    train_labels, _ = trvae.label_encoder(data)
    fake_labels = np.ones(train_labels.shape)

    # NOTE(review): `to_mmd_layer` receives `network` again as its first
    # positional argument; presumably this trVAE version's signature is
    # `to_mmd_layer(model, data, labels, feed_fake)` -- confirm.
    # Keys are the suffixes of the UMAP files each embedding is saved to.
    embeddings = {
        "latent_with_true_labels":
            network.to_latent(feed_data, train_labels),
        "latent_with_fake_labels":
            network.to_latent(feed_data, fake_labels),
        "mmd_latent_with_true_labels":
            network.to_mmd_layer(network, feed_data, train_labels,
                                 feed_fake=False),
        "mmd_latent_with_fake_labels":
            network.to_mmd_layer(network, feed_data, train_labels,
                                 feed_fake=True),
    }

    mpl.rcParams.update(mpl.rcParamsDefault)

    color = ['condition']

    sc.pp.neighbors(data)
    sc.tl.umap(data)
    sc.pl.umap(data, color=color, save=f'_{data_name}_train_data', show=False)

    for suffix, embedding in embeddings.items():
        embedding_adata = sc.AnnData(X=embedding)
        embedding_adata.obs['condition'] = data.obs['condition'].values
        sc.pp.neighbors(embedding_adata)
        sc.tl.umap(embedding_adata)
        # One naming pattern for all four files (the fake-label latent plot
        # used to carry a stray double underscore after the dataset name).
        sc.pl.umap(embedding_adata,
                   color=color,
                   save=f"_{data_name}_{suffix}",
                   show=False)
    plt.close("all")
Example #3
0
def visualize_trained_network_results(data_dict,
                                      z_dim=100,
                                      mmd_dimension=128,
                                      loss_fn='mse'):
    """Save diagnostics for per-cell-type multi-condition trVAE models.

    For each cell type, restores a ``trVAEMulti`` model from
    ``../models/RCVAEMulti/{name}/{cell_type}/{z_dim}/`` and embeds the whole
    dataset, once with the true labels and once per condition index
    (``fake`` labels / ``feed_fake``). It then runs every configured
    perturbation through ``visualize_multi_perturbation_between`` and writes
    UMAPs plus per-gene violin plots to ``../results/RCVAEMulti/...``.

    Args:
        data_dict: Dataset description. Keys read: ``name``,
            ``source_conditions``, ``target_conditions``, ``cell_type``,
            ``need_merge``, ``label_encoder``, ``condition``,
            ``spec_cell_types`` and ``perturbation`` (a list of
            ``(source, dest, name, source_label, target_label)`` tuples).
        z_dim: Latent size; part of the model and results paths.
        mmd_dimension: Size of the MMD layer passed to ``trVAEMulti``.
        loss_fn: Loss name passed to ``trVAEMulti``; any value other than
            ``'mse'`` also log-transforms the loaded data via
            ``normalize_hvg`` (only when ``need_merge`` is false).
    """
    import matplotlib as mpl

    plt.close("all")
    data_name = data_dict['name']
    source_keys = data_dict.get("source_conditions")
    target_keys = data_dict.get("target_conditions")
    cell_type_key = data_dict.get("cell_type", None)
    need_merge = data_dict.get('need_merge', False)
    label_encoder = data_dict.get('label_encoder', None)
    condition_key = data_dict.get('condition', 'condition')

    def _to_annotated(matrix):
        """Wrap an embedding in AnnData carrying data's condition/cell-type labels."""
        adata = sc.AnnData(X=matrix)
        adata.obs[condition_key] = data.obs[condition_key].values
        adata.obs[cell_type_key] = data.obs[cell_type_key].values
        return adata

    def _plot_umap(adata, color, save):
        """Compute neighbours + UMAP for `adata` and save the UMAP plot."""
        sc.pp.neighbors(adata)
        sc.tl.umap(adata)
        sc.pl.umap(adata,
                   color=color,
                   save=save,
                   show=False,
                   wspace=0.15,
                   frameon=False)

    if need_merge:
        data, _ = merge_data(data_dict)
    else:
        data = sc.read(f"../data/{data_name}/{data_name}.h5ad")
        if loss_fn != 'mse':
            data = normalize_hvg(data,
                                 filter_min_counts=False,
                                 normalize_input=False,
                                 logtrans_input=True)

    # Densify once, before any subsets are taken, so all cell types (and the
    # AnnData handed to visualize_multi_perturbation_between) see a dense X.
    if sparse.issparse(data.X):
        data.X = data.X.A

    cell_types = data.obs[cell_type_key].unique().tolist()

    spec_cell_type = data_dict.get("spec_cell_types", None)
    if spec_cell_type:
        cell_types = spec_cell_type

    color = [condition_key, cell_type_key]

    for cell_type in cell_types:
        path_to_save = f"../results/RCVAEMulti/{data_name}/{cell_type}/{z_dim}/Visualizations/"
        os.makedirs(path_to_save, exist_ok=True)
        sc.settings.figdir = os.path.abspath(path_to_save)

        is_cell_type = data.obs[cell_type_key] == cell_type
        # Everything except this cell type's target-condition cells.
        train_data = data[~(data.obs[condition_key].isin(target_keys) &
                            is_cell_type)].copy()
        cell_type_adata = data[is_cell_type].copy()

        n_conditions = len(train_data.obs[condition_key].unique())

        network = trvae.trVAEMulti(
            x_dimension=data.shape[1],
            z_dimension=z_dim,
            loss_fn=loss_fn,
            n_conditions=n_conditions,
            mmd_dimension=mmd_dimension,
            model_path=f"../models/RCVAEMulti/{data_name}/{cell_type}/{z_dim}/",
        )
        network.restore_model()

        train_labels, _ = trvae.label_encoder(data, label_encoder,
                                              condition_key)
        # One constant label array per condition index 0..n_conditions-1.
        fake_labels = [
            np.zeros(train_labels.shape) + i for i in range(n_conditions)
        ]

        latent_with_true_labels = network.to_z_latent(data.X, train_labels)
        latent_with_fake_labels = [
            network.to_z_latent(data.X, labels) for labels in fake_labels
        ]
        mmd_latent_with_true_labels = network.to_mmd_layer(data,
                                                           train_labels,
                                                           feed_fake=0)
        mmd_latent_with_fake_labels = [
            network.to_mmd_layer(data, train_labels, feed_fake=i)
            for i in range(n_conditions)
        ]

        if data_name in ["pbmc", 'endo_norm']:
            sc.tl.rank_genes_groups(cell_type_adata,
                                    groupby=condition_key,
                                    n_genes=100,
                                    method="wilcoxon")
            top_100_genes = cell_type_adata.uns["rank_genes_groups"]["names"][
                target_keys[-1]].tolist()
            gene_list = top_100_genes[:10]
        elif data_name in ['pancreas', 'nmuil_count']:
            gene_list = None
            top_100_genes = None
        else:
            sc.tl.rank_genes_groups(cell_type_adata,
                                    groupby=condition_key,
                                    n_genes=100,
                                    method="wilcoxon")
            ranked_names = cell_type_adata.uns["rank_genes_groups"]["names"]
            top_50_down_genes = ranked_names[source_keys[0]].tolist()
            top_50_up_genes = ranked_names[target_keys[-1]].tolist()
            top_100_genes = top_50_up_genes + top_50_down_genes
            gene_list = top_50_down_genes[:5] + top_50_up_genes[:5]

        perturbation_list = data_dict.get("perturbation", [])
        pred_adatas = None
        for source, dest, name, source_label, target_label in perturbation_list:
            print(source, dest, name)
            pred_adata = visualize_multi_perturbation_between(
                network,
                cell_type_adata,
                pred_adatas,
                source_condition=source,
                target_condition=dest,
                name=name,
                source_label=source_label,
                target_label=target_label,
                cell_type=cell_type,
                data_name=data_name,
                top_100_genes=top_100_genes,
                gene_list=gene_list,
                path_to_save=path_to_save,
                condition_key=condition_key)
            if pred_adatas is None:
                pred_adatas = pred_adata
            else:
                pred_adatas = pred_adatas.concatenate(pred_adata)

        # With no configured perturbations there is nothing to write (this
        # used to raise AttributeError on None).
        if pred_adatas is not None:
            os.makedirs("../data/reconstructed/RCVAEMulti/", exist_ok=True)
            # NOTE(review): the file name has no cell-type component, so each
            # cell type overwrites the previous one -- confirm intended.
            pred_adatas.write_h5ad(
                filename=f"../data/reconstructed/RCVAEMulti/{data_name}.h5ad")

        mpl.rcParams.update(mpl.rcParamsDefault)

        latent_with_true_labels = _to_annotated(latent_with_true_labels)
        for i, latent in enumerate(latent_with_fake_labels):
            _plot_umap(
                _to_annotated(latent), color,
                f"_{data_name}_{cell_type}_latent_with_fake_labels_{i}")

        mmd_latent_with_true_labels = _to_annotated(
            mmd_latent_with_true_labels)
        for i, latent in enumerate(mmd_latent_with_fake_labels):
            # Previously saved as "_{name}_latent_with_fake_labels_{i}",
            # mislabelling MMD-layer plots as plain latent ones.
            _plot_umap(
                _to_annotated(latent), color,
                f"_{data_name}_{cell_type}_mmd_latent_with_fake_labels_{i}")

        _plot_umap(train_data, color, f'_{data_name}_{cell_type}_train_data')
        _plot_umap(latent_with_true_labels, color,
                   f"_{data_name}_{cell_type}_latent_with_true_labels")
        _plot_umap(mmd_latent_with_true_labels, color,
                   f"_{data_name}_{cell_type}_mmd_latent_with_true_labels")

        # Violins need both a gene list and predictions to compare against.
        if gene_list is not None and pred_adatas is not None:
            for target_condition in target_keys:
                pred_adata = pred_adatas[pred_adatas.obs[condition_key].str.
                                         endswith(target_condition)]
                violin_adata = cell_type_adata.concatenate(pred_adata)
                for gene in gene_list[:3]:
                    sc.pl.violin(
                        violin_adata,
                        keys=gene,
                        groupby=condition_key,
                        save=
                        f"_{data_name}_{cell_type}_{gene}_{target_condition}.pdf",
                        show=False,
                        wspace=0.2,
                        rotation=90,
                        frameon=False)

        plt.close("all")
Example #4
0
def visualize_batch_correction(data_dict, z_dim=100, mmd_dimension=128):
    """Save UMAPs showing how a trVAEMulti model's embeddings mix conditions.

    For each cell type, restores the model from
    ``../models/RCVAEMulti/{name}/{cell_type}/{z_dim}/``. It embeds the whole
    dataset (true labels) into the MMD layer and the latent space, and saves
    UMAPs coloured by condition for the training data and both embeddings.
    The embeddings are also written to ``../data/mmd.h5ad`` and
    ``../data/latent.h5ad``.

    Args:
        data_dict: Dataset description. Keys read: ``name``,
            ``source_conditions`` (its length is the model's
            ``n_conditions``), ``target_conditions``, ``cell_type``,
            ``need_merge``, ``label_encoder``, ``condition`` and
            ``spec_cell_types``.
        z_dim: Latent size; part of the model and results paths.
        mmd_dimension: Size of the MMD layer passed to ``trVAEMulti``.
    """
    import matplotlib as mpl

    plt.close("all")
    data_name = data_dict['name']
    source_keys = data_dict.get("source_conditions")
    target_keys = data_dict.get("target_conditions")
    cell_type_key = data_dict.get("cell_type", None)
    need_merge = data_dict.get('need_merge', False)
    label_encoder = data_dict.get('label_encoder', None)
    condition_key = data_dict.get('condition', 'condition')

    def _to_annotated(matrix):
        """Wrap an embedding in AnnData carrying data's condition/cell-type labels."""
        adata = sc.AnnData(X=matrix)
        adata.obs[condition_key] = data.obs[condition_key].values
        adata.obs[cell_type_key] = data.obs[cell_type_key].values
        return adata

    def _plot_umap(adata, color, save):
        """Compute neighbours + UMAP for `adata` and save the UMAP plot."""
        sc.pp.neighbors(adata)
        sc.tl.umap(adata)
        sc.pl.umap(adata,
                   color=color,
                   save=save,
                   show=False,
                   wspace=0.15,
                   frameon=False)

    if need_merge:
        data, _ = merge_data(data_dict)
    else:
        data = sc.read(f"../data/{data_name}/train_{data_name}.h5ad")

    # Densify once, before any subsets are taken, so every iteration sees
    # the same dense matrix.
    if sparse.issparse(data.X):
        data.X = data.X.A

    cell_types = data.obs[cell_type_key].unique().tolist()

    spec_cell_type = data_dict.get("spec_cell_types", None)
    if spec_cell_type:
        cell_types = spec_cell_type

    color = [condition_key]

    for cell_type in cell_types:
        path_to_save = f"../results/RCVAEMulti/{data_name}/{cell_type}/{z_dim}/Visualizations/"
        os.makedirs(path_to_save, exist_ok=True)
        sc.settings.figdir = os.path.abspath(path_to_save)

        # Everything except this cell type's target-condition cells.
        train_data = data[~(data.obs[condition_key].isin(target_keys) &
                            (data.obs[cell_type_key] == cell_type))].copy()

        network = trvae.trVAEMulti(
            x_dimension=data.shape[1],
            z_dimension=z_dim,
            n_conditions=len(source_keys),
            mmd_dimension=mmd_dimension,
            model_path=f"../models/RCVAEMulti/{data_name}/{cell_type}/{z_dim}/",
        )
        network.restore_model()

        feed_data = data.X
        train_labels, _ = trvae.label_encoder(data, label_encoder,
                                              condition_key)

        # NOTE(review): `network` is passed again as the first positional
        # argument here; presumably this trVAEMulti version's signature is
        # `to_mmd_layer(model, data, labels, feed_fake)` -- confirm.
        mmd_latent_with_true_labels = network.to_mmd_layer(network,
                                                           feed_data,
                                                           train_labels,
                                                           feed_fake=0)
        latent_with_true_labels = network.to_z_latent(feed_data, train_labels)

        mpl.rcParams.update(mpl.rcParamsDefault)

        mmd_latent_with_true_labels = _to_annotated(
            mmd_latent_with_true_labels)
        latent_with_true_labels = _to_annotated(latent_with_true_labels)

        _plot_umap(train_data, color, f'_{data_name}_{cell_type}_train_data')
        _plot_umap(mmd_latent_with_true_labels, color,
                   f"_{data_name}_{cell_type}_mmd_latent_with_true_labels")
        _plot_umap(latent_with_true_labels, color,
                   f"_{data_name}_{cell_type}_latent_with_true_labels")

        # NOTE(review): these paths do not include the cell type, so each
        # iteration overwrites the previous one -- confirm intended.
        mmd_latent_with_true_labels.write_h5ad('../data/mmd.h5ad')
        latent_with_true_labels.write_h5ad('../data/latent.h5ad')

        # Close this iteration's figures, as the sibling functions do, so
        # they do not accumulate across cell types.
        plt.close("all")
Example #5
0
def visualize_trained_network_results(data_dict,
                                      z_dim=100,
                                      arch_style=1,
                                      preprocess=True,
                                      max_size=80000):
    """Save UMAPs of image data and of a restored DCtrVAE's embeddings.

    Loads CelebA via ``trvae.prepare_and_load_celeba`` (when ``name`` is
    ``"celeba"``) or ``../data/{name}/{name}.h5ad`` otherwise. It restores the
    ``DCtrVAE`` from ``../models/RCCVAE/...`` and embeds every image into the
    latent space (true and all-ones labels) and the MMD layer (``feed_fake``
    False/True). UMAPs are saved to ``../results/RCCVAE/.../UMAPs/``.

    Args:
        data_dict: Dataset description. Keys read: ``name``, ``source_key``,
            ``target_key``, ``width``, ``height``, ``n_channels``,
            ``train_digits``, ``test_digits``, ``attribute`` and, for
            CelebA, ``gender``.
        z_dim: Latent size; part of the model and results paths.
        arch_style: Architecture variant passed to ``DCtrVAE``.
        preprocess: If true, pixel values are divided by 255 before
            encoding. The division is in place, on the same array object
            the data was loaded into.
        max_size: Maximum number of CelebA images to load.
    """
    plt.close("all")
    data_name = data_dict.get('name', None)
    source_key = data_dict.get('source_key', None)
    target_key = data_dict.get('target_key', None)
    img_width = data_dict.get('width', None)
    img_height = data_dict.get('height', None)
    n_channels = data_dict.get('n_channels', None)
    train_digits = data_dict.get('train_digits', None)
    test_digits = data_dict.get('test_digits', None)
    attribute = data_dict.get('attribute', None)

    path_to_save = f"../results/RCCVAE/{data_name}-{img_width}x{img_height}-{preprocess}/{arch_style}-{z_dim}/{source_key} to {target_key}/UMAPs/"
    os.makedirs(path_to_save, exist_ok=True)
    sc.settings.figdir = os.path.abspath(path_to_save)

    if data_name == "celeba":
        gender = data_dict.get('gender', None)
        data = trvae.prepare_and_load_celeba(
            file_path="../data/celeba/img_align_celeba.zip",
            attr_path="../data/celeba/list_attr_celeba.txt",
            landmark_path="../data/celeba/list_landmarks_align_celeba.txt",
            gender=gender,
            attribute=attribute,
            max_n_images=max_size,
            img_width=img_width,
            img_height=img_height,
            restore=True,
            save=False)

        if sparse.issparse(data.X):
            data.X = data.X.A

        train_images = data.X
        train_data = anndata.AnnData(X=data)
        # Replace the +1/-1 condition and label codes with readable names.
        train_data.obs['condition'] = data.obs['condition'].values
        train_data.obs.loc[train_data.obs['condition'] == 1,
                           'condition'] = f'with {attribute}'
        train_data.obs.loc[train_data.obs['condition'] == -1,
                           'condition'] = f'without {attribute}'

        train_data.obs['labels'] = data.obs['labels'].values
        train_data.obs.loc[train_data.obs['labels'] == 1, 'labels'] = 'Male'
        train_data.obs.loc[train_data.obs['labels'] == -1,
                           'labels'] = 'Female'
    else:
        train_data = sc.read(f"../data/{data_name}/{data_name}.h5ad")
        train_images = np.reshape(train_data.X,
                                  (-1, img_width, img_height, n_channels))

    if preprocess:
        # In place: affects whatever array train_images refers to.
        # NOTE(review): raises if that array has an integer dtype -- confirm
        # the loaders always return floats.
        train_images /= 255.0

    train_labels, _ = trvae.label_encoder(train_data)
    fake_labels = np.ones(train_labels.shape)

    network = trvae.DCtrVAE(
        x_dimension=(img_width, img_height, n_channels),
        z_dimension=z_dim,
        arch_style=arch_style,
        model_path=
        f"../models/RCCVAE/{data_name}-{img_width}x{img_height}-{preprocess}/{arch_style}-{z_dim}/",
    )
    network.restore_model()

    train_data_feed = np.reshape(train_images,
                                 (-1, img_width, img_height, n_channels))

    # NOTE(review): `network` is passed again as the first positional
    # argument of to_mmd_layer; presumably that is this version's signature
    # -- confirm. Keys are the suffixes of the UMAP files.
    embeddings = {
        "latent_with_true_labels":
            network.to_z_latent(train_data_feed, train_labels),
        "latent_with_fake_labels":
            network.to_z_latent(train_data_feed, fake_labels),
        "mmd_latent_with_true_labels":
            network.to_mmd_layer(network, train_data_feed, train_labels,
                                 feed_fake=False),
        "mmd_latent_with_fake_labels":
            network.to_mmd_layer(network, train_data_feed, train_labels,
                                 feed_fake=True),
    }

    has_labels = "mnist" in data_name or data_name == "celeba"
    color = ['condition', 'labels'] if has_labels else ['condition']

    embedding_adatas = {}
    for suffix, embedding in embeddings.items():
        embedding_adata = sc.AnnData(X=embedding)
        # Categorical for every embedding, so all four plots colour the
        # condition the same way (the MMD ones used to get raw values).
        embedding_adata.obs['condition'] = pd.Categorical(
            train_data.obs['condition'].values)
        if has_labels:
            embedding_adata.obs['labels'] = pd.Categorical(
                train_data.obs['labels'].values)
        embedding_adatas[suffix] = embedding_adata

    if train_digits is not None:
        obs = train_data.obs
        in_train = obs['labels'].isin(train_digits)
        in_test = obs['labels'].isin(test_digits)
        is_source = obs['condition'] == source_key
        is_target = obs['condition'] == target_key
        # All source cells and the target cells of training digits are
        # 'training'; target cells of test digits are 'heldout' (assigned
        # last, so it wins for digits listed in both).
        obs.loc[(is_source & (in_train | in_test)) | (is_target & in_train),
                'type'] = 'training'
        obs.loc[is_target & in_test, 'type'] = 'heldout'

    sc.pp.neighbors(train_data)
    sc.tl.umap(train_data)
    sc.pl.umap(train_data,
               color=color,
               save=f'_{data_name}_train_data.png',
               show=False,
               wspace=0.5)

    if train_digits is not None:
        # Reuses the UMAP just computed for train_data.
        sc.pl.umap(train_data,
                   color=['type', 'labels'],
                   save=f'_{data_name}_data_type.png',
                   show=False)

    for suffix, embedding_adata in embedding_adatas.items():
        sc.pp.neighbors(embedding_adata)
        sc.tl.umap(embedding_adata)
        sc.pl.umap(embedding_adata,
                   color=color,
                   save=f"_{data_name}_{suffix}.png",
                   wspace=0.5,
                   show=False)

    plt.close("all")