formula_like='1 + product_name + dose_level',
                          treatment_col='product_name',
                          num_boot=100000,
                          verbose=1,
                          num_cpus=94,
                          resampling='permutation',
                          approx=False)

    return subset


if __name__ == '__main__':

    # Hard-coded, machine-specific data root; the trailing slash matters because
    # the file name below is appended by plain string concatenation.
    data_path = '/data_volume/memento/sciplex/'
    # Name of the .h5ad file to load (presumably a cell type/line — confirm).
    ct = 'A549'
    adata = sc.read(data_path + 'h5ad/{}.h5ad'.format(ct))

    # Target names handled by this script.
    target_list = [
        'Aurora Kinase', 'DNA/RNA Synthesis', 'HDAC',
        'Histone Methyltransferase', 'JAK', 'PARP', 'Sirtuin'
    ]

    # Short name for every entry of target_list. Presumably used as a
    # directory name (note that 'DNA/RNA Synthesis' contains a slash); the
    # consuming code is not visible here.
    target_to_dir = {
        'Aurora Kinase': 'aurora',
        'DNA/RNA Synthesis': 'dna_rna',
        'HDAC': 'hdac',
        'Histone Methyltransferase': 'histmeth',
        'JAK': 'jak',
        'PARP': 'parp',
        'Sirtuin': 'sirt'
    }
Example #2
0
def marker(method, in_path, out_path=None):
    """Rank marker genes per Leiden cluster, plot them, and name the clusters.

    Args:
        method: Test handed to ``sc.tl.rank_genes_groups``; one of
            ``"t-test"``, ``"wilcoxon"`` or ``"logreg"``.
        in_path: File loaded with ``sc.read``. Its ``obs`` must already hold a
            ``'leiden'`` clustering; the renaming step below supplies exactly
            eight cluster names.
        out_path: Destination for the annotated object. When ``None`` (the
            default) the result is not written.

    Raises:
        ValueError: If ``method`` is not a supported name. Raised before the
            input file is read, so a typo does not cost a full load.
    """
    if method not in ("t-test", "wilcoxon", "logreg"):
        raise ValueError(
            f"Unknown analysis method: {method}. Expected 't-test', 'wilcoxon', or 'logreg'."
        )

    adata = sc.read(in_path)

    if method == "t-test":
        with co.nb.cell():
            # ### Finding marker genes
            #
            # Let us compute a ranking for the highly differential genes in each cluster. For this, by default, the .raw attribute of AnnData is used in case it has been initialized before. The simplest and fastest method to do so is the t-test.
            sc.tl.rank_genes_groups(adata, 'leiden', method='t-test')
            sc.pl.rank_genes_groups(adata, n_genes=25, sharey=False)

    elif method == "wilcoxon":
        sc.settings.verbosity = 2  # reduce the verbosity

        with co.nb.cell():
            # The result of a [Wilcoxon rank-sum (Mann-Whitney-U)](https://de.wikipedia.org/wiki/Wilcoxon-Mann-Whitney-Test) test is very similar. We recommend using the latter in publications, see e.g., [Soneson & Robinson (2018)](https://doi.org/10.1038/nmeth.4612). You might also consider much more powerful differential testing packages like MAST, limma, DESeq2 and, for python, the recent diffxpy.
            sc.tl.rank_genes_groups(adata, 'leiden', method='wilcoxon')
            sc.pl.rank_genes_groups(adata, n_genes=25, sharey=False)
    else:
        # Only "logreg" can reach this branch: method was validated above.
        with co.nb.cell():
            # As an alternative, let us rank genes using logistic regression. For instance, this has been suggested by [Ntranos et al. (2018)](https://doi.org/10.1101/258566). The essential difference is that here, we use a multi-variate approach whereas conventional differential tests are uni-variate. [Clark et al. (2014)](https://doi.org/10.1186/1471-2105-15-79) has more details.
            sc.tl.rank_genes_groups(adata, 'leiden', method='logreg')
            sc.pl.rank_genes_groups(adata, n_genes=25, sharey=False)

    with co.nb.cell():
        # Let us also define a list of marker genes for later reference.
        marker_genes = [
            'IL7R', 'CD79A', 'MS4A1', 'CD8A', 'CD8B', 'LYZ', 'CD14', 'LGALS3',
            'S100A8', 'GNLY', 'NKG7', 'KLRB1', 'FCGR3A', 'MS4A7', 'FCER1A',
            'CST3', 'PPBP'
        ]

    with co.nb.cell():
        # With the exceptions of IL7R, which is only found by the t-test and FCER1A, which is only found by the other two approaches, all marker genes are recovered in all approaches.
        #
        # | Louvain Group | Markers       | Cell Type         |
        # | ------------- | ------------- | ----------------- |
        # | 0             | IL7R          | CD4 T cells       |
        # | 1             | CD14, LYZ     | CD14+ Monocytes   |
        # | 2             | MS4A1         | B cells           |
        # | 3             | CD8A          | CD8 T cells       |
        # | 4             | GNLY, NKG7    | NK cells          |
        # | 5             | FCGR3A, MS4A7 | FCGR3A+ Monocytes |
        # | 6             | FCER1A, CST3  | Dendritic Cells   |
        # | 7             | PPBP          | Megakaryocytes    |
        pass

    with co.nb.cell():
        # Show the 5 top ranked genes per cluster 0, 1, …, 7 in a dataframe.
        df = pd.DataFrame(adata.uns['rank_genes_groups']['names']).head(5)
        print(df.to_markdown())

    with co.nb.cell():
        # Get a table with the scores and groups.
        result = adata.uns['rank_genes_groups']
        groups = result['names'].dtype.names
        if "pvals" in result:
            df = pd.DataFrame({
                group + '_' + key[:1]: result[key][group]
                for group in groups for key in ['names', 'pvals']
            }).head(5)
            print(df.to_markdown())

    with co.nb.cell():
        # Compare to a single cluster:
        sc.tl.rank_genes_groups(adata,
                                'leiden',
                                groups=['0'],
                                reference='1',
                                method='wilcoxon')
        sc.pl.rank_genes_groups(adata, groups=['0'], n_genes=20)

    with co.nb.cell():
        # If we want a more detailed view for a certain group, use `sc.pl.rank_genes_groups_violin`.
        sc.pl.rank_genes_groups_violin(adata, groups='0', n_genes=8)

    with co.nb.cell():
        # If you want to compare a certain gene across groups, use the following.
        sc.pl.violin(adata, ['CST3', 'NKG7', 'PPBP'], groupby='leiden')

    with co.nb.cell():
        # Actually mark the cell types.
        new_cluster_names = [
            'CD4 T', 'CD14 Monocytes', 'B', 'CD8 T', 'NK', 'FCGR3A Monocytes',
            'Dendritic', 'Megakaryocytes'
        ]
        adata.rename_categories('leiden', new_cluster_names)
        sc.pl.umap(adata,
                   color='leiden',
                   legend_loc='on data',
                   title='',
                   frameon=False,
                   save='.pdf')

    with co.nb.cell():
        # Now that we annotated the cell types, let us visualize the marker genes.
        sc.pl.dotplot(adata, marker_genes, groupby='leiden')

    with co.nb.cell():
        # There is also a very compact violin plot.
        sc.pl.stacked_violin(adata,
                             marker_genes,
                             groupby='leiden',
                             rotation=90)

    print(adata)

    # Save the result. out_path defaults to None, in which case there is
    # nowhere to write to, so the save step is skipped instead of failing
    # after all the work above has been done.
    if out_path is not None:
        adata.write(out_path)
Example #3
0
def readh5ad(filename):
    """Load ``filename`` through ``sc.read`` and hand back whatever it returns."""
    loaded = sc.read(filename)
    return loaded
import os
import numpy as np
import scanpy as sc
import matplotlib

matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns

# Hard-coded, machine-specific input file and output directory.
data_path = "/dfs/user/atwang/data/spt_zhuang/source/processed_data/counts.h5ad"
result_dir = "/dfs/user/atwang/results/st_gnn_results/spt_zhuang/expression"
# Create the output directory up front; no error if it already exists.
os.makedirs(result_dir, exist_ok=True)

adata = sc.read(data_path)
# Disabled preprocessing and highly-variable-gene plot, kept for reference.
# sc.pp.calculate_qc_metrics(adata)
# sc.pp.normalize_total(adata, target_sum=1e4)
# sc.pp.log1p(adata)
# sc.pl.highly_variable_genes(adata)
# plt.savefig(os.path.join(result_dir, "var_genes.svg"), bbox_inches='tight')
# plt.clf()

# Histogram of log10(value + 1) over every entry of adata.X, saved as SVG.
# NOTE(review): `.flatten()` assumes adata.X is a dense array — confirm for
# this input file.
# NOTE(review): the y axis is labelled "Density" although kde=False is passed;
# presumably the bars are raw counts — verify.
sns.set()
sns.distplot(np.log10(adata.X.flatten() + 1), kde=False)
plt.xlabel("log10(Expression + 1)")
plt.ylabel("Density")
plt.savefig(os.path.join(result_dir, "hist.svg"), bbox_inches='tight')
# Clear the current figure after saving.
plt.clf()
Example #5
0
out_exprs = open("tables/scRNA-seq_expressions.tsv", "w")
out_dropouts = open("tables/scRNA-seq_dropouts.tsv", "w")

# Headers
out = ["Filename", "Organ", "Cell type", "Number of cells", g0, g1, g2]
print("\t".join(out), file=out_exprs)
print("\t".join(out), file=out_dropouts)

# Cell type columns
ct_columns = [
    "BroadCellType", "broad_celltype", "CellType", "cell_type", "Celltypes",
    "celltype"
]

for fn, organ in data.items():
    adata = sc.read("data/COVID-19CellAtlas/{}".format(fn))

    # Select first matching cell type column
    for ct in ct_columns:
        if ct in adata.obs:
            ct_col = ct
            break

    if fn == "lukowski19.processed.h5ad":
        # This file was scaled using 10**3 factor
        adata.X = (adata.X.expm1() / 10).log1p()
    elif fn == "madissoon19_lung.processed.h5ad":
        # This file was not transformed at all
        library_sizes = np.sum(adata.X, axis=1)
        adata.X = (adata.X.multiply(1 / library_sizes).tocsc() * 10**4).log1p()
    elif fn.startswith("madissoon20"):
Example #6
0
    pp_adata = annotate_cluster(adata=pp_adata,
                                cluster_algorithm=cluster_algorithm,
                                resolution=0.1)

    # 5. Plot UMAP scRNAseq data
    visualise_clusters(adata=pp_adata,
                       save_folder=save_folder,
                       key='cluster_labels',
                       title="SC")


if __name__ == '__main__':
    # Both the output folder and the stored AnnData live three directory
    # levels above the working directory.
    project_root = os.path.join("..", "..", "..")

    # Output goes into a folder named after today's date.
    today = date.today()
    output_path = os.path.join(project_root, "output", "SupplFigure_4A",
                               str(today))
    os.makedirs(output_path, exist_ok=True)

    # Merged scRNAseq samples, as used for the publication.
    adata_file = os.path.join(
        project_root, 'adata_storage', '2020-10-19',
        'sc_adata_minumi_600_maxumi_25000_mg_500_mc_20_mt_0_25.h5')
    pp_adata_sc = sc.read(adata_file)

    unsupvised_cluster_algo = 'leiden'

    main(cluster_algorithm=unsupvised_cluster_algo,
         pp_adata=pp_adata_sc,
         save_folder=output_path)
Example #7
0
#Script for making a small example adata file, based on the
#LKLSK_smallexample_compressed.h5ad file

import scanpy as sc
import pandas as pd
import numpy as np

adata = sc.read('./data/LKLSK_smallexample_compressed.h5ad')

#Making the file smaller, so that it does not increase the repo size too much
#(keeps 5% of the observations; random_state fixes which ones are kept)
sc.pp.subsample(adata, fraction=0.05, random_state=123)

#Adding a field in .obs slot with both negative and positive values for testing the diverging colourscale
#NOTE(review): unlike the subsampling above, this draw is not seeded, so the
#'posneg' values differ on every run — confirm that is acceptable.
adata.obs['posneg'] = np.random.randn(adata.shape[0])

#Adding specific colours in the .uns slot
#(38 hex colours for 'louvain', 4 for 'random4')
adata.uns['louvain_colors'] = [
    "#569072", "#4f2a98", "#7ae26f", "#bd51c0", "#55a735", "#756de1",
    "#a8d358", "#642876", "#cec33d", "#54509c", "#d29933", "#6886d6",
    "#d34e30", "#5fda98", "#dc5196", "#46863d", "#c883d4", "#cfcc72",
    "#32315c", "#d07839", "#6dc5dc", "#d7425e", "#71ddca", "#962f68",
    "#aecc98", "#5d2847", "#7e8233", "#96acdf", "#94342d", "#3a879c",
    "#cf7f74", "#344b26", "#e0a4c7", "#7f5d30", "#4a658e", "#d4af81",
    "#652e26", "#a0688a"
]

adata.uns['random4_colors'] = ["#b25c4d", "#64acaf", "#8b5aa5", "#91ad58"]

#Writing the reduced object with lzf compression
adata.write('./data/tiny_example1.h5ad', compression='lzf')
Example #8
0
target_conditions = ['Pancreas SS2', 'Pancreas CelSeq2']

trvae_epochs = 500
surgery_epochs = 500

early_stopping_kwargs = {
    "early_stopping_metric": "val_unweighted_loss",
    "threshold": 0,
    "patience": 20,
    "reduce_lr": True,
    "lr_patience": 13,
    "lr_factor": 0.1,
}

adata_all = sc.read(
    os.path.expanduser(
        f'~/Documents/benchmarking_datasets/pancreas_normalized.h5ad'))
adata = adata_all.raw.to_adata()
adata = remove_sparsity(adata)
source_adata = adata[~adata.obs[condition_key].isin(target_conditions)]
target_adata = adata[adata.obs[condition_key].isin(target_conditions)]
source_conditions = source_adata.obs[condition_key].unique().tolist()

trvae = sca.models.TRVAE(
    adata=source_adata,
    condition_key=condition_key,
    conditions=source_conditions,
    hidden_layer_sizes=[128, 128],
)
trvae.train(n_epochs=trvae_epochs,
            alpha_epoch_anneal=200,
    for text in ax.set_labels:  #this is set label
        text.set_fontsize(16)
    for text in ax.subset_labels:  #this is number in circle
        text.set_fontsize(16)
    #4. save
    plt.tight_layout()
    plt.savefig(f_out, dpi=300)
    plt.close()
    return


###############################################################################
#1. main loop
for sample in l_sample:
    #1. load
    # Only the 'anno' column of .obs is needed from either file.
    df1 = sc.read(f'{fd_before}/{sample}.h5ad').obs.loc[:, ['anno']]
    df2 = sc.read(f'{fd_after}/{sample}.h5ad').obs.loc[:, ['anno']].fillna(
        'Unknown')  # why has nan value?

    #2. make cell count df
    l_data = []
    for cell in l_cell:
        #1. cell list
        l_c1 = df1.loc[df1['anno'] == cell, :].index.tolist()
        l_c2 = df2.loc[df2['anno'] == cell, :].index.tolist()
        # Membership is tested against a set: same result and same order as
        # testing against the list, without the quadratic list scan.
        s_c2 = set(l_c2)
        l_overlap = [i for i in l_c1 if i in s_c2]
        n_c1 = len(l_c1)
        n_c2 = len(l_c2)
        n_overlap = len(l_overlap)
        l_data.append((cell, n_c1, n_c2, n_overlap))
    # One row per annotation: counts before, after, and shared by both.
    df = pd.DataFrame(l_data, columns=['cell', 'before', 'after', 'overlap'])
def reconstruct_whole_data(data_name="pbmc", condition_key="condition"):
    """Predict the stimulated state of every cell type and save all results.

    For each cell type in ``../data/train_{data_name}.h5ad`` the matching
    scGen model is restored from ``../models/scGen/{data_name}/{cell_type}``,
    the stimulated cells of that type are predicted, and the control,
    predicted and real stimulated cells are appended to one AnnData. The
    combined object is written to
    ``../data/reconstructed/scGen/{data_name}.h5ad``.

    Args:
        data_name: One of ``"pbmc"``, ``"hpoly"``, ``"salmonella"``,
            ``"species"`` or ``"study"``.
        condition_key: ``obs`` column holding the condition of each cell.

    Raises:
        ValueError: If ``data_name`` is not one of the supported datasets.
    """
    # (stim_key, ctrl_key, cell_type_key) for every supported dataset.
    dataset_keys = {
        "pbmc": ("stimulated", "control", "cell_type"),
        "hpoly": ("Hpoly.Day10", "Control", "cell_label"),
        "salmonella": ("Salmonella", "Control", "cell_label"),
        "species": ("LPS6", "unst", "species"),
        "study": ("stimulated", "control", "cell_type"),
    }
    if data_name not in dataset_keys:
        # Without this check an unknown name only failed later, on the first
        # use of the never-assigned `train` variable.
        raise ValueError(
            f"Unknown data_name {data_name!r}; expected one of "
            f"{sorted(dataset_keys)}.")
    stim_key, ctrl_key, cell_type_key = dataset_keys[data_name]
    train = sc.read(f"../data/train_{data_name}.h5ad")

    all_data = anndata.AnnData()
    for idx, cell_type in enumerate(
            train.obs[cell_type_key].unique().tolist()):
        print(f"Reconstructing for {cell_type}")
        network = scgen.VAEArith(
            x_dimension=train.X.shape[1],
            z_dimension=100,
            alpha=0.00005,
            dropout_rate=0.2,
            learning_rate=0.001,
            model_path=f"../models/scGen/{data_name}/{cell_type}/scgen")
        network.restore_model()

        is_cell_type = train.obs[cell_type_key] == cell_type
        cell_type_data = train[is_cell_type]
        cell_type_ctrl_data = train[(is_cell_type &
                                     (train.obs[condition_key] == ctrl_key))]
        # Everything except the stimulated cells of the type being predicted.
        net_train_data = train[~(is_cell_type &
                                 (train.obs[condition_key] == stim_key))]
        pred, delta = network.predict(adata=net_train_data,
                                      conditions={
                                          "ctrl": ctrl_key,
                                          "stim": stim_key
                                      },
                                      cell_type_key=cell_type_key,
                                      condition_key=condition_key,
                                      celltype_to_predict=cell_type)

        pred_adata = anndata.AnnData(
            pred,
            obs={
                condition_key: [f"{cell_type}_pred_stim"] * len(pred),
                cell_type_key: [cell_type] * len(pred)
            },
            var={"var_names": cell_type_data.var_names})
        ctrl_adata = anndata.AnnData(
            cell_type_ctrl_data.X,
            obs={
                condition_key:
                [f"{cell_type}_ctrl"] * len(cell_type_ctrl_data),
                cell_type_key: [cell_type] * len(cell_type_ctrl_data)
            },
            var={"var_names": cell_type_ctrl_data.var_names})

        real_stim = cell_type_data[cell_type_data.obs[condition_key] ==
                                   stim_key].X
        if sparse.issparse(cell_type_data.X):
            # toarray() is the supported spelling of the `.A` shorthand.
            real_stim = real_stim.toarray()
        real_stim_adata = anndata.AnnData(
            real_stim,
            obs={
                condition_key: [f"{cell_type}_real_stim"] * len(real_stim),
                cell_type_key: [cell_type] * len(real_stim)
            },
            var={"var_names": cell_type_data.var_names})

        # The first cell type starts the collection; later ones are appended.
        if idx == 0:
            all_data = ctrl_adata.concatenate(pred_adata, real_stim_adata)
        else:
            all_data = all_data.concatenate(ctrl_adata, pred_adata,
                                            real_stim_adata)

        print(f"Finish Reconstructing for {cell_type}")
        network.sess.close()
    all_data.write_h5ad(f"../data/reconstructed/scGen/{data_name}.h5ad")
Example #11
0
def _load_dataset_from_url(fpath: Union[os.PathLike, str],
                           url: str) -> AnnData:
    """Read the dataset at ``fpath``, passing ``url`` as the backup location.

    The file is read with ``sparse=True`` and ``cache=True``; variable names
    are made unique before the object is returned.
    """
    dataset = read(fpath, backup_url=url, sparse=True, cache=True)
    dataset.var_names_make_unique()
    return dataset
    # 3.3 Read out all leukocyte cells
    include_cytokine_dp(adata=adata_leukocytes, cytokines=cytokines, save_folder=save_folder,
                        label=sc_cluster_obs, key='SC_merged', paper_figure='SC')
    # 3.3 Read out all leukocyte cells but exclude double positive cytokine cells
    adata_leukocytes, obs_name = exclude_cytokine_dp(adata=adata_leukocytes, cytoresps_dict=cytoresps_dict)

    # Plot cytokines and highlight double positive
    plot_annotated_cells(adata=adata_leukocytes, color='cytokines_others', paper_figure='',
                         save_folder=save_folder, key='SC', title="Leukocytes_IL17A_IFNG",
                         xpos=0.02, ypos=0.95, palette=["#ff7f00", "#377eb8", 'purple'])

    # Add cytokine label to adata and Plot: Highlight cytokines
    adata_leukocytes = add_observables.add_columns_genes(adata=adata_leukocytes, genes='IFNG', label='IFNG')

    """ Suppl. Figure 4B: Highlight IFN-g """
    plot_annotated_cells(adata=adata_leukocytes, color='IFNG_label', paper_figure='4B', save_folder=save_folder,
                         key='SC', title="Leukocyte_IFNG", xpos=0.02, ypos=0.95)


if __name__ == '__main__':
    # Both the output folder and the stored AnnData live three directory
    # levels above the working directory.
    project_root = os.path.join("..", "..", "..")

    # Output goes into a folder named after today's date.
    today = date.today()
    output_path = os.path.join(project_root, "output", "SupplFigure_4B",
                               str(today))
    os.makedirs(output_path, exist_ok=True)

    # Merged scRNAseq samples, as used for the publication.
    adata_file = os.path.join(project_root, 'adata_storage',
                              '2020-12-04_SC_Data_QC_clustered.h5')
    clustered_adata_sc = sc.read(adata_file)

    main(adata=clustered_adata_sc, save_folder=output_path)
Example #13
0
import h5py, sys, math, time, scanpy, ModelUtil
import numpy as np

# The input file is the first command-line argument; without it, print the
# usage line and stop.
if len(sys.argv) <= 1:
    print('Usage: readh5ad.py <file_name>')
    quit()
fn = sys.argv[1]

# Load the file with scanpy and work on its X attribute from here on.
f = scanpy.read(fn)
Y = f.X
# The string literal below is disabled code (random sub-sampling of rows and
# columns); it is evaluated as a bare expression and has no effect.
'''
# sample randomly 10000 rows.
rows = np.random.randint(Y.shape[0], size=10000)
columns = np.random.randint(Y.shape[1], size=5000)
Y = Y[rows, :]
Y = Y[:,columns]
'''

# Cast to float32, then densify when the object offers `todense`
# (presumably a SciPy sparse matrix — confirm for the expected inputs).
Y = Y.astype(np.float32)
if hasattr(Y, 'todense'):
    Y = Y.todense()


def exp1p(x):
    """Return ``exp(x) - 1`` as ``np.float32`` for positive ``x``, else zero.

    Any ``x`` that is not strictly greater than 0 maps to ``np.float32(0)``.
    """
    return np.float32(math.expm1(x)) if x > 0 else np.float32(0)


# Element-wise variant of exp1p.
vexp1p = np.vectorize(exp1p)
Example #14
0
from sklearn.metrics import roc_curve, auc
os.chdir('/Users/kj22643/Documents/Documents/231_Classifier_Project/code')
from func_file import find_mean_AUC
from func_file import find_mean_AUC_SVM



# Hard-coded, machine-specific data directory (an alternative location is
# kept commented out below).
path = '/Users/kj22643/Documents/Documents/231_Classifier_Project/data'
#path = '/stor/scratch/Brock/231_10X_data/'
# Change into the data directory so the relative file name below resolves
# against it.
os.chdir(path)
# scanpy settings: figure directory, save resolution and log verbosity.
sc.settings.figdir = 'KJ_plots'
sc.set_figure_params(dpi_save=300)
sc.settings.verbosity = 3

#%% Load in pre and post treatment 231 data
adata = sc.read('post-cell-cycle-regress.h5ad')
# The result of head() is not stored; it only shows up when this cell is run
# interactively.
adata.obs.head()
# current samples:
#BgL1K
#30hr
#Rel-1 AA107 7 weeks
#Rel-2 AA113  10 weeks - 1 day
# We will change these to time points
# Count the number of unique lineages in all the cells
uniquelins = adata.obs['lineage'].unique()
nunique = len(uniquelins)



#%% Identify lineages that have been recalled from the pre-treatment sample
# Make a column labeled recalled lin that can be use to identify the specific lineages of interest
Example #15
0
def visualize_trained_network_results(data_dict,
                                      z_dim=100,
                                      arch_style=1,
                                      preprocess=True,
                                      max_size=80000):
    """Save UMAP plots of the input data and of a restored DCtrVAE's encodings.

    The dataset described by ``data_dict`` is loaded, a previously trained
    ``trvae.DCtrVAE`` is restored, and UMAPs are saved for the raw data, the
    latent encodings (true and all-ones labels) and the MMD-layer encodings
    (``feed_fake`` False and True).

    Args:
        data_dict: Dataset description. Keys read here: ``name``,
            ``source_key``, ``target_key``, ``width``, ``height``,
            ``n_channels``, ``train_digits``, ``test_digits``, ``attribute``
            and, for ``"celeba"``, ``gender``.
        z_dim: Latent size of the model; also part of the model/result paths.
        arch_style: Architecture id of the model; also part of the paths.
        preprocess: When true, images are divided by 255 in place; also part
            of the paths.
        max_size: Maximum number of images requested from the celeba loader.
    """
    plt.close("all")

    data_name = data_dict.get('name')
    source_key = data_dict.get('source_key')
    target_key = data_dict.get('target_key')
    img_width = data_dict.get('width')
    img_height = data_dict.get('height')
    n_channels = data_dict.get('n_channels')
    train_digits = data_dict.get('train_digits')
    test_digits = data_dict.get('test_digits')
    attribute = data_dict.get('attribute')

    # Shared pieces of the shapes and paths used below.
    image_shape = (-1, img_width, img_height, n_channels)
    run_tag = f"{data_name}-{img_width}x{img_height}-{preprocess}/{arch_style}-{z_dim}"

    path_to_save = f"../results/RCCVAE/{run_tag}/{source_key} to {target_key}/UMAPs/"
    os.makedirs(path_to_save, exist_ok=True)
    sc.settings.figdir = os.path.abspath(path_to_save)

    if data_name == "celeba":
        data = trvae.prepare_and_load_celeba(
            file_path="../data/celeba/img_align_celeba.zip",
            attr_path="../data/celeba/list_attr_celeba.txt",
            landmark_path="../data/celeba/list_landmarks_align_celeba.txt",
            gender=data_dict.get('gender'),
            attribute=attribute,
            max_n_images=max_size,
            img_width=img_width,
            img_height=img_height,
            restore=True,
            save=False)

        if sparse.issparse(data.X):
            data.X = data.X.A

        train_images = data.X
        train_data = anndata.AnnData(X=data)

        # Replace the 1 / -1 codes of both columns with readable names.
        readable_names = (
            ('condition', ((1, f'with {attribute}'),
                           (-1, f'without {attribute}'))),
            ('labels', ((1, 'Male'), (-1, 'Female'))),
        )
        for column, replacements in readable_names:
            train_data.obs[column] = data.obs[column].values
            for code, readable in replacements:
                train_data.obs.loc[train_data.obs[column] == code,
                                   column] = readable
    else:
        train_data = sc.read(f"../data/{data_name}/{data_name}.h5ad")
        train_images = np.reshape(train_data.X, image_shape)

    # Same in-place scaling for either data source.
    if preprocess:
        train_images /= 255.0

    train_labels, _ = trvae.label_encoder(train_data)
    fake_labels = np.ones(train_labels.shape)

    network = trvae.DCtrVAE(
        x_dimension=(img_width, img_height, n_channels),
        z_dimension=z_dim,
        arch_style=arch_style,
        model_path=f"../models/RCCVAE/{run_tag}/",
    )

    network.restore_model()

    train_data_feed = np.reshape(train_images, image_shape)

    z_true = network.to_z_latent(train_data_feed, train_labels)
    z_fake = network.to_z_latent(train_data_feed, fake_labels)
    mmd_true = network.to_mmd_layer(network,
                                    train_data_feed,
                                    train_labels,
                                    feed_fake=False)
    mmd_fake = network.to_mmd_layer(network,
                                    train_data_feed,
                                    train_labels,
                                    feed_fake=True)

    # (file-name tag, encoded matrix, store the condition as a Categorical?)
    encodings = (
        ("latent_with_true_labels", z_true, True),
        ("latent_with_fake_labels", z_fake, True),
        ("mmd_latent_with_true_labels", mmd_true, False),
        ("mmd_latent_with_fake_labels", mmd_fake, False),
    )
    encoded_views = []
    for tag, matrix, condition_as_categorical in encodings:
        view = sc.AnnData(X=matrix)
        conditions = train_data.obs['condition'].values
        if condition_as_categorical:
            conditions = pd.Categorical(conditions)
        view.obs['condition'] = conditions
        encoded_views.append((tag, view))

    color = ['condition']
    if "mnist" in data_name or data_name == "celeba":
        for _, view in encoded_views:
            view.obs['labels'] = pd.Categorical(
                train_data.obs['labels'].values)
        color = ['condition', 'labels']

    if train_digits is not None:
        # Only target-condition test digits are tagged as held out.
        type_rules = (
            (source_key, train_digits, 'training'),
            (source_key, test_digits, 'training'),
            (target_key, train_digits, 'training'),
            (target_key, test_digits, 'heldout'),
        )
        for condition, digits, split_name in type_rules:
            train_data.obs.loc[(train_data.obs['condition'] == condition) &
                               (train_data.obs['labels'].isin(digits)),
                               'type'] = split_name

    sc.pp.neighbors(train_data)
    sc.tl.umap(train_data)
    sc.pl.umap(train_data,
               color=color,
               save=f'_{data_name}_train_data.png',
               show=False,
               wspace=0.5)

    if train_digits is not None:
        sc.tl.umap(train_data)
        sc.pl.umap(train_data,
                   color=['type', 'labels'],
                   save=f'_{data_name}_data_type.png',
                   show=False)

    # Same neighbours / UMAP / plot sequence for each encoding, in order.
    for tag, view in encoded_views:
        sc.pp.neighbors(view)
        sc.tl.umap(view)
        sc.pl.umap(view,
                   color=color,
                   save=f"_{data_name}_{tag}.png",
                   wspace=0.5,
                   show=False)

    plt.close("all")
Example #16
0
            authorized=row['authorized']
        )
        result = conn.execute(dataset_insert)
        dataset_key = result.inserted_primary_key
        print('dataset_key:', dataset_key)

# Read each anndata object.
for filename in os.listdir(DATADIR):
    print(filename)
    
    # Skip two cluster solutions already there.
    if filename == '10xGenomics_pbmc8k' \
        or filename == '10xGenomics_t_3k_4k_aggregate':
        continue

    ad = sc.read(os.path.join(DATADIR,filename))

    # Find the dataset_id
    dataset_name = filename.split("_clustered")[0]
    print("**************************************")
    print('dataset_name:', dataset_name)

    rows = select(
        [dataset.c.name, dataset.c.id]).where(dataset.c.name == dataset_name)
    result = conn.execute(rows)
    for row in result:
        dataset_id = row['id']
        name = row['name']
    #print('dataset id, name:', dataset_id, name)

    # Load the cluster solution into the DB.
Example #17
0
def train_network(
    data_dict=None,
    z_dim=100,
    mmd_dimension=256,
    alpha=0.001,
    beta=100,
    gamma=1.0,
    kernel='multi-scale-rbf',
    n_epochs=500,
    batch_size=512,
    dropout_rate=0.2,
    arch_style=1,
    preprocess=True,
    learning_rate=0.001,
    gpus=1,
    max_size=50000,
    early_stopping_limit=50,
):
    """Train a ``trvae.archs.DCtrVAE`` on source/target images.

    Images whose ``obs['condition']`` equals ``source_key`` get condition
    label 0 and those equal to ``target_key`` get label 1. The rows are
    re-ordered source-first, shuffled into an 85/15 train/validation split,
    optionally filtered for held-out combinations, and used for training.

    Args:
        data_dict: Dataset description (required). ``name`` is mandatory;
            ``source_key``, ``target_key``, ``width``, ``height``,
            ``n_channels``, ``train_digits``, ``test_digits``, ``attribute``
            and (celeba only) ``gender`` are read when present.
        z_dim, mmd_dimension, alpha, beta, gamma, kernel, arch_style,
        learning_rate, gpus, dropout_rate: Forwarded to the ``DCtrVAE``
            constructor.
        n_epochs, batch_size: Forwarded to ``network.train``.
        preprocess: When true, images are divided by 255.
        max_size: Maximum number of images requested from the celeba loader.
        early_stopping_limit: Forwarded to ``network.train`` as
            ``early_stop_limit``.

    Raises:
        ValueError: If ``data_dict`` is not provided.
    """
    if data_dict is None:
        # The default exists only for keyword-style calls; subscripting None
        # would otherwise fail with an unhelpful TypeError.
        raise ValueError(
            "`data_dict` is required: pass a dict with at least a 'name' entry."
        )

    data_name = data_dict['name']
    source_key = data_dict.get('source_key', None)
    target_key = data_dict.get('target_key', None)
    img_width = data_dict.get("width", None)
    img_height = data_dict.get("height", None)
    n_channels = data_dict.get("n_channels", None)
    train_digits = data_dict.get("train_digits", None)
    test_digits = data_dict.get("test_digits", None)
    attribute = data_dict.get('attribute', None)

    if data_name == "celeba":
        gender = data_dict.get('gender', None)
        data = trvae.prepare_and_load_celeba(
            file_path="../data/celeba/img_align_celeba.zip",
            attr_path="../data/celeba/list_attr_celeba.txt",
            landmark_path="../data/celeba/list_landmarks_align_celeba.txt",
            gender=gender,
            attribute=attribute,
            max_n_images=max_size,
            img_width=img_width,
            img_height=img_height,
            restore=True,
            save=True)

        if sparse.issparse(data.X):
            data.X = data.X.toarray()
    else:
        data = sc.read(f"../data/{data_name}/{data_name}.h5ad")

    # Row selectors for the two conditions, plus the row order that results
    # from stacking source rows before target rows.
    source_mask = np.asarray(data.obs['condition'] == source_key)
    target_mask = np.asarray(data.obs['condition'] == target_key)
    row_order = np.concatenate(
        [np.flatnonzero(source_mask),
         np.flatnonzero(target_mask)])

    # Boolean selection already yields fresh arrays, so the whole dataset
    # does not need to be copied before slicing.
    image_shape = (-1, img_width, img_height, n_channels)
    source_images = np.reshape(data[source_mask].X, image_shape)
    target_images = np.reshape(data[target_mask].X, image_shape)

    if preprocess:
        source_images /= 255.0
        target_images /= 255.0

    source_labels = np.zeros(shape=source_images.shape[0])
    target_labels = np.ones(shape=target_images.shape[0])
    train_labels = np.concatenate([source_labels, target_labels], axis=0)

    train_images = np.concatenate([source_images, target_images], axis=0)
    train_images = np.reshape(train_images,
                              (-1, np.prod(source_images.shape[1:])))

    preprocessed_data = anndata.AnnData(X=train_images)
    preprocessed_data.obs['condition'] = train_labels
    # MNIST-style datasets must carry 'labels'; others keep them if present.
    # The labels are re-ordered with `row_order` so that each one stays with
    # its image (the images above are stacked source-first), and they are
    # taken from the 'labels' column itself rather than from 'condition'.
    if 'mnist' in data_name or 'labels' in data.obs.columns:
        preprocessed_data.obs['labels'] = data.obs['labels'].values[row_order]
    data = preprocessed_data.copy()

    # Random 85% / 15% train / validation split.
    train_size = int(data.shape[0] * 0.85)
    indices = np.arange(data.shape[0])
    np.random.shuffle(indices)
    train_idx = indices[:train_size]
    test_idx = indices[train_size:]

    data_train = data[train_idx, :]
    data_valid = data[test_idx, :]
    print(data_train.shape, data_valid.shape)

    if train_digits is not None:
        # Hold out the test digits of the target condition (label 1).
        train_data = data_train.copy()[~(
            (data_train.obs['labels'].isin(test_digits)) &
            (data_train.obs['condition'] == 1))]
        valid_data = data_valid.copy()[~(
            (data_valid.obs['labels'].isin(test_digits)) &
            (data_valid.obs['condition'] == 1))]
    elif data_name == "celeba":
        # NOTE(review): 'condition' holds 0/1 at this point, so this filter
        # only matches when target_key compares equal to 1 — confirm.
        train_data = data_train.copy()[~(
            (data_train.obs['labels'] == -1) &
            (data_train.obs['condition'] == target_key))]
        valid_data = data_valid.copy()[~(
            (data_valid.obs['labels'] == -1) &
            (data_valid.obs['condition'] == target_key))]
    else:
        train_data = data_train.copy()
        valid_data = data_valid.copy()

    network = trvae.archs.DCtrVAE(
        x_dimension=source_images.shape[1:],
        z_dimension=z_dim,
        mmd_dimension=mmd_dimension,
        alpha=alpha,
        beta=beta,
        gamma=gamma,
        kernel=kernel,
        arch_style=arch_style,
        train_with_fake_labels=False,
        learning_rate=learning_rate,
        model_path=
        f"../models/RCCVAE/{data_name}-{img_width}x{img_height}-{preprocess}/{arch_style}-{z_dim}/",
        gpus=gpus,
        dropout_rate=dropout_rate)

    print(train_data.shape, valid_data.shape)
    network.train(train_data,
                  use_validation=True,
                  valid_adata=valid_data,
                  n_epochs=n_epochs,
                  batch_size=batch_size,
                  verbose=2,
                  early_stop_limit=early_stopping_limit,
                  shuffle=True,
                  save=True)

    print("Model has been trained")
from matplotlib import colors
import seaborn as sb
#from rpy2.robjects.packages import importr
#from gprofiler import gprofiler
# Global plotting configuration: default matplotlib figure size plus scanpy's
# verbosity level and on-screen / saved figure resolution.
plt.rcParams['figure.figsize'] = (8, 8)  #rescale figures
sc.settings.verbosity = 1
sc.set_figure_params(dpi=200, dpi_save=300)

#matplotlib.rcParams['pdf.fonttype']=42
#matplotlib.rcParams['font.size']=6

# Label used to name the figure output directory below.
# NOTE(review): presumably the .obs clustering key (Leiden, resolution 0.3)
# that the markers are plotted against -- confirm against later code.
todo = 'leiden_r0.3'

# Figures saved through scanpy go to "markers-<todo>".
sc.settings.figdir = 'markers-{0}'.format(todo)

# Load the AnnData object from the working directory.
adata = sc.read('learned.h5ad')

# Mapping from a group label to the gene symbols used as its markers.
marker_genes_dict = {
    'Epiblast': ["Pou5f1"],
    'Primitive streak': ["Mixl1"],  #Nanong?!?!
    'Endoderms': ["Cer1", "Sox7"],
    'Mesoderms': ["T", 'Cdx1'],
    'Ectoderms': ['Six3'],  # And Grhl2
    'Exe endoderm': ["Apoa2"],
    'Exe ectoderm': ["Tfap2c"],
    'Cardiomyocytes': ["Tnnt2"],
    'Blood prog.': [
        "Lmo2",
    ],
    'Erythroid': ["Gypa"],
}
    target_condition = "LPS6"
    target_conditions = ['LPS6']
    le = {"unst": 0, "LPS6": 1}
elif data_name == "kang":
    keys = ["control", "stimulated"]
    specific_cell_type = "NK"
    cell_type_key = "cell_type"
    condition_key = "condition"
    control_condition = "control"
    target_condition = "stimulated"
    target_conditions = ['stimulated']
    le = {"control": 0, "stimulated": 1}
else:
    raise Exception("Invalid data name")

# Load the normalized dataset and keep only the cells whose condition is one
# of `keys` (the conditions selected in the dataset branch above).
adata = sc.read(f"./data/{data_name}/{data_name}_normalized.h5ad")
adata = adata.copy()[adata.obs[condition_key].isin(keys)]

# Restrict to the 2000 most highly variable genes when there are more genes
# than that.
if adata.shape[1] > 2000:
    sc.pp.highly_variable_genes(adata, n_top_genes=2000)
    adata = adata[:, adata.var['highly_variable']]

# NOTE(review): 0.80 is presumably the training fraction -- confirm against
# train_test_split's definition.
train_adata, valid_adata = train_test_split(adata, 0.80)

# Hold out the combination under study: drop every cell that is both of the
# specific cell type AND in one of the target conditions, from the training
# and the validation split alike.
net_train_adata = train_adata[~(
    (train_adata.obs[cell_type_key] == specific_cell_type) &
    (train_adata.obs[condition_key].isin(target_conditions)))]
net_valid_adata = valid_adata[~(
    (valid_adata.obs[cell_type_key] == specific_cell_type) &
    (valid_adata.obs[condition_key].isin(target_conditions)))]
Example #20
0
    verbose = args.verbose
    type_ = args.type

    # set prefix for output and results column name
    base = os.path.basename(args.input).split('.h5ad')[0]

    if verbose:
        print('Options')
        print(f'    type:\t{type_}')

    ###

    print("reading adata input file")
    if os.stat(args.input).st_size > 0:
        adata = sc.read(args.input, cache=True)
        print(adata)
        if (type_ == 'knn'):
            neighbors = adata.obsp['connectivities']
            del adata
            diff_neighbors = diffusion_conn(neighbors,
                                            min_k=50,
                                            copy=False,
                                            max_iterations=20)
            scio.mmwrite(target=os.path.join(args.output,
                                             f'{base}_diffconn.mtx'),
                         a=diff_neighbors)
            print("done")
        else:
            print('Wrong type chosen, doing nothing.')
    else:
Example #21
0
def main():
    """Plot and tabulate the cell-type predictions stored in an h5ad file.

    Command-line arguments (read from ``sys.argv``):

    * ``argv[1]``  -- path of the ``.h5ad`` file to read,
    * ``argv[2]``  -- name of the ``.obs`` column with the reference labels,
    * ``argv[-1]`` -- output prefix; its directory part becomes scanpy's
      figure directory and its base name prefixes every figure file name.

    Every ``.obs`` column whose name starts with ``<label>_`` is treated as
    the predictions of one model and compared with the reference column.

    Files written:

    * UMAP figures (reference labels, per-model predictions, per-model
      correct/mislabelled maps),
    * ``<prefix>performance_metrics_per_model_test.tsv`` -- counts of correct
      and mislabelled cells and the resulting accuracy per model,
    * ``<prefix>performance_metrics_per_model_train.tsv`` -- concatenation of
      all ``.uns['metrics_*']`` tables (skipped when there are none),
    * ``<prefix>best_pars_per_model.md`` -- contents of every
      ``.uns['best_pars_*']`` mapping as a nested bullet list.
    """

    h5ad_file = sys.argv[1]
    cellType_label = sys.argv[2]
    out_prefix = sys.argv[-1]

    # Figure names are relative to scanpy's figdir, so only the base name of
    # the prefix is used for them.
    fig_prefix = os.path.basename(out_prefix)

    # Setting plotting settings
    sc.settings.autoshow = False
    sc.settings.figdir = os.path.dirname(out_prefix)

    # Read data
    annData = sc.read(h5ad_file)

    # Process data
    annData = model_preprocessing.preprocessing(annData,
                                                do_log1p=True,
                                                do_select_variable_genes=True,
                                                do_min_cells_filter=True,
                                                do_min_genes_filter=True,
                                                percent_cells=0.05,
                                                min_genes=10,
                                                n_var_genes=500)

    sc.pp.normalize_total(annData, exclude_highly_expressed=True)

    # Get available models: one .obs column per model, named "<label>_...".
    models = [
        column for column in annData.obs.keys().to_list()
        if column.startswith(cellType_label + '_')
    ]
    for m in models:
        # Element-wise comparison of the reference labels with the predictions.
        correct_label = annData.obs[cellType_label] == annData.obs[m].to_list()
        annData.obs[m + '_accuracy'] = [
            'correct' if is_correct else 'mislabelled'
            for is_correct in correct_label
        ]

    # Get data ready for plotting
    sc.pp.pca(annData)
    sc.pp.neighbors(annData)
    sc.tl.umap(annData)

    sc.pl.umap(annData, color=cellType_label)
    sc.pl.umap(annData,
               save=fig_prefix + '_' + cellType_label +
               '_original_umap_noColor.pdf')
    sc.pl.umap(annData,
               color=cellType_label,
               save=fig_prefix + '_' + cellType_label + '_original_umap.pdf')

    with open(out_prefix + 'performance_metrics_per_model_test.tsv', 'w') as f:
        print('model', 'correct', 'mislablled', 'accuracy', sep="\t", file=f)
        for m in models:
            sc.pl.umap(annData,
                       color=m,
                       save=fig_prefix + '_' + m + '_predicted_umap.pdf')
            sc.pl.umap(annData,
                       color=m + '_accuracy',
                       save=fig_prefix + '_' + m + '_missed_umap.pdf')
            correct = sum(annData.obs[m + '_accuracy'] == 'correct')
            miss = sum(annData.obs[m + '_accuracy'] != 'correct')
            print(m,
                  correct,
                  miss,
                  correct / (correct + miss),
                  sep="\t",
                  file=f)

    # Compile cv results: one table per "metrics_<model>" entry in .uns,
    # tagged with the model name.
    cv = {}
    cv_keys = [key for key in annData.uns.keys() if key.startswith('metrics_')]
    for key in cv_keys:
        cv[key] = pd.DataFrame(annData.uns[key])
        cv[key]['model'] = key.split("metrics_")[1]

    # Writing. pd.concat raises ValueError on an empty collection, so a file
    # without any "metrics_*" entry must not abort the remaining reports.
    if cv:
        pd.concat(cv).to_csv(out_prefix +
                             'performance_metrics_per_model_train.tsv',
                             sep='\t',
                             index=False)
    else:
        print('No "metrics_*" entries found in .uns; skipping '
              'performance_metrics_per_model_train.tsv')

    with open(out_prefix + 'best_pars_per_model.md', 'w') as f:
        best_pars_keys = [
            key for key in annData.uns.keys() if key.startswith('best_pars_')
        ]
        for b in best_pars_keys:
            print('* ', b.split('best_pars_')[1], sep='', file=f)
            for key in annData.uns[b]:
                print('  * ', key, ': ', annData.uns[b][key], file=f)
Example #22
0
        "--output",
        default="pca_mqc.png",
        help=
        "Output filename. Will default to pca_mqc.png, Optional [*_mqc.yaml]")
    parser.add_argument("--recipe",
                        default="recipe_seurat",
                        help="preprocessing recipe name as defined in Scanpy")

    args = parser.parse_args()
    return args


# Script entry point: read the expression matrix named on the command line and
# run a fixed filtering / normalisation / gene-selection pipeline on it.
if __name__ == "__main__":
    args = argparser()

    adata = sc.read(args.exprs)

    # standard preprocessing
    # Drop cells with fewer than 200 detected genes, then genes seen in fewer
    # than 3 cells or with fewer than 5 counts in total.
    sc.pp.filter_cells(adata, min_genes=200)
    sc.pp.filter_genes(adata, min_cells=3)
    sc.pp.filter_genes(adata, min_counts=5)
    # First normalisation pass, recorded under the key 'n_counts_all'.
    sc.pp.normalize_total(adata, key_added='n_counts_all')
    # Select the 1000 top genes by dispersion ('cell_ranger' flavor) on the raw
    # matrix; log=False because the data has not been log-transformed yet.
    f = sc.pp.filter_genes_dispersion(adata.X,
                                      flavor='cell_ranger',
                                      n_top_genes=1000,
                                      log=False)
    # Keep only the selected genes (uses a private AnnData method).
    adata._inplace_subset_var(f.gene_subset)
    # Re-normalise after gene selection, then log1p-transform and scale.
    sc.pp.normalize_total(adata)
    sc.pp.log1p(adata)
    sc.pp.scale(adata)
                                                max_n_images=50000,
                                                img_width=64,
                                                img_height=64,
                                                restore=True,
                                                save=True)
    input_shape = (64, 64, 3)
elif data_name == "mnist":
    conditions = ["normal", "thin", "thick"]
    target_conditions = ["thin", 'thick']
    source_condition = "normal"
    labelencoder = {"normal": 0, "thin": 1, "thick": 2}
    label_key = "labels"
    condition_key = "condition"
    specific_labels = [1, 3, 6, 7]
    arch_style = 1
    adata = sc.read("./data/thick_thin_mnist/thick_thin_mnist.h5ad")
    input_shape = (28, 28, 1)
else:
    raise Exception("Invalid data name")

# NOTE(review): remove_sparsity presumably converts a sparse .X to a dense
# array so the in-place division below works -- confirm against its definition.
adata = remove_sparsity(adata)
# Preprocessing
# Divide by 255 in place (assumes 8-bit pixel intensities -- TODO confirm).
adata.X /= 255.0


# NOTE(review): 0.80 is presumably the training fraction -- confirm.
train_adata, valid_adata = reptrvae.utils.train_test_split(adata, 0.80)

# Hold out the combinations under study: drop every sample whose label is in
# `specific_labels` AND whose condition is one of the target conditions, from
# both the training and the validation split.
net_train_adata = train_adata[
    ~((train_adata.obs[label_key].isin(specific_labels)) & (train_adata.obs[condition_key].isin(target_conditions)))]
net_valid_adata = valid_adata[
    ~((valid_adata.obs[label_key].isin(specific_labels)) & (valid_adata.obs[condition_key].isin(target_conditions)))]
Example #24
0
import numpy as np
import mkl
import scanpy as sc
import os
import time

# UMAP
# Layout parameters forwarded to sc.tl.umap below.
umap_min_dist = 0.3
umap_spread = 1.0

sc.settings.n_jobs = 56  # Set it to number of cpus on a CPU socket

# Use the same thread count for OpenMP and MKL as for scanpy.
# NOTE(review): OMP_NUM_THREADS is set after numpy/mkl/scanpy were imported;
# confirm the OpenMP runtime still honours it at this point.
os.environ["OMP_NUM_THREADS"] = str(sc.settings.n_jobs)
mkl.set_num_threads(sc.settings.n_jobs)

# Load the AnnData object from the working directory.
# NOTE(review): presumably it already contains the neighbour graph that
# sc.tl.umap requires -- confirm.
adata = sc.read('before_umap.h5ad')
print(adata.shape)
# Time only the UMAP embedding step (wall-clock seconds).
umap_time = time.time()
sc.tl.umap(adata, min_dist=umap_min_dist, spread=umap_spread)
print("UMAP time : %s" % (time.time() - umap_time))

# Save the embedding coloured by "Stmn2_raw" with a fixed colour range.
sc.pl.umap(adata,
           color=["Stmn2_raw"],
           color_map="Blues",
           vmax=1,
           vmin=-0.05,
           save="_Stmn2_raw.png")
sc.pl.umap(adata,
           color=["Hes1_raw"],
           color_map="Blues",
           vmax=1,
Example #25
0
def _anndata_from_arrays(arrays):
    """Build an AnnData from named arrays, taking the largest one as ``X``.

    The array with the most elements (rows * columns for 2-D arrays, length of
    the first axis otherwise; the last one wins on ties) becomes ``X``. Every
    other array is flattened and stored as an ``obs`` column when its length
    equals the number of rows of ``X``, or as a ``var`` column when it equals
    the number of columns; arrays matching neither are ignored.

    Args:
        arrays: mapping of name -> array, in file order.

    Returns:
        The assembled ``anndata.AnnData``.

    Raises:
        ValueError: if ``arrays`` is empty.
    """
    import anndata
    import pandas as pd

    if not arrays:
        raise ValueError("no data arrays found in file")

    def extent(arr):
        # Size measure used to pick X.
        if len(arr.shape) == 2:
            return arr.shape[0] * arr.shape[1]
        # Zero-dimensional entries (scalars) count as a single element.
        return arr.shape[0] if len(arr.shape) > 0 else 1

    largest, largest_size = None, 0
    for name, arr in arrays.items():
        size = extent(arr)
        if size >= largest_size:
            largest, largest_size = name, size

    matrix = arrays[largest]
    n_obs = matrix.shape[0] if len(matrix.shape) > 0 else None
    n_var = matrix.shape[1] if len(matrix.shape) > 1 else None

    # obs describes the cells (rows of X), var the genes (columns of X).
    obs_d, var_d = {}, {}
    for name, arr in arrays.items():
        if name == largest:
            continue
        flat = arr.flatten()
        if flat.shape[0] == n_obs:
            obs_d[name] = flat
        elif flat.shape[0] == n_var:
            var_d[name] = flat

    obs_df = pd.DataFrame(data=obs_d)
    var_df = pd.DataFrame(data=var_d)
    return anndata.AnnData(X=matrix,
                           obs=None if obs_df.empty else obs_df,
                           var=None if var_df.empty else var_df)


def upload(pathname):
    """Load a data file into an AnnData object, dispatching on its extension.

    ``.mat`` and ``.npz`` files are treated as collections of named arrays:
    the largest array becomes ``X`` and arrays whose length matches a
    dimension of ``X`` become ``obs`` / ``var`` columns. ``.mtx`` is read as a
    10x directory (the directory containing ``pathname``), ``.csv``, ``.xlsx``
    and ``.txt`` use the matching scanpy reader, and anything else is passed
    to ``sc.read``.

    Args:
        pathname: path of the file to load.

    Returns:
        The loaded AnnData object.

    Raises:
        ValueError: if a ``.mat`` / ``.npz`` file contains no data arrays.
    """
    import os
    import scanpy as sc

    _, file_extension = os.path.splitext(pathname)
    if file_extension == ".mat":
        from scipy.io import loadmat
        contents = loadmat(pathname)
        # loadmat adds metadata entries such as "__header__", "__version__"
        # and "__globals__"; select the real variables by name instead of
        # assuming the metadata occupies exactly the first three keys.
        arrays = {
            name: value
            for name, value in contents.items() if not name.startswith("__")
        }
        data = _anndata_from_arrays(arrays)
    elif file_extension == ".npz":
        import numpy as np
        # Read every member exactly once and close the archive afterwards;
        # indexing an NpzFile re-reads the member from disk on each access.
        with np.load(pathname) as archive:
            arrays = {name: archive[name] for name in archive.files}
        data = _anndata_from_arrays(arrays)
    elif file_extension == ".mtx":
        data = sc.read_10x_mtx(os.path.dirname(pathname))
    elif file_extension == ".csv":
        data = sc.read_csv(pathname)
    elif file_extension == ".xlsx":
        data = sc.read_excel(pathname)
    elif file_extension == ".txt":
        data = sc.read_text(pathname)
    else:
        data = sc.read(pathname)

    print(pathname, " uploaded !")
    return data
Example #26
0
def read_h5ad(dataset):
    """Read *dataset* with ``sc.read`` and return the resulting object.

    Prints "File read!" once loading has finished. The message used to sit
    after the ``return`` statement and was therefore never printed.
    """

    #read input h5ad
    adata = sc.read(dataset)
    print("File read!")
    return adata
Example #27
0
# Script entry point: attach EmptyDrops results (a .csv) to the .obs table of
# an AnnData object and derive a string "EmptyDrops" label from the FDR.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "adata",
        type=str,
        help="AnnData object saved as .h5ad file to add EmptyDrops labels to",
    )
    parser.add_argument(
        "emptydrops_out",
        type=str,
        help="Output from EmptyDrops as .csv file corresponding to adata",
    )
    args = parser.parse_args()

    a = sc.read(args.adata)  # read in anndata
    e = pd.read_csv(args.emptydrops_out,
                    index_col=0)  # read in EmptyDrops results

    # add emptydrops results to anndata object in .obs
    # The values are assigned positionally (``.values`` drops the index).
    # NOTE(review): this assumes the csv rows are in the same order as
    # a.obs -- confirm against whatever produces the csv.
    a.obs[[
        "EmptyDrops_LogProb", "EmptyDrops_pval", "EmptyDrops_limited",
        "EmptyDrops_FDR"
    ]] = e[["LogProb", "PValue", "Limited", "FDR"]].values
    # make FDR 1.0 for untested barcodes so the final label ignores them.
    # Assign the filled column back instead of calling fillna(inplace=True) on
    # the attribute: that chained in-place update warns on recent pandas and
    # leaves a.obs untouched under Copy-on-Write.
    a.obs["EmptyDrops_FDR"] = a.obs["EmptyDrops_FDR"].fillna(value=1.0)
    # The label is stored as the strings "True" / "False", not as booleans.
    a.obs["EmptyDrops"] = "False"
    a.obs.loc[a.obs["EmptyDrops_FDR"] <= 0.001, "EmptyDrops"] = "True"

    # overwrite original .h5ad file
Example #28
0
def evaluate_network(data_dict=None,
                     z_dim=100,
                     n_files=5,
                     k=5,
                     arch_style=1,
                     preprocess=True,
                     max_size=80000):
    """Restore a trained DCtrVAE and save figures of sampled predictions.

    The dataset is loaded (CelebA through ``trvae.prepare_and_load_celeba``,
    anything else from ``../data/<name>/<name>.h5ad``), the model is restored
    from ``../models/RCCVAE/...`` and, ``n_files`` times, source images are
    sampled from the train and validation parts, passed through
    ``network.predict`` with encoder label 0 and decoder label 1, and plotted
    next to their inputs. One PDF per iteration is written to each of the
    ``train/`` and ``valid/`` directories under ``../results/RCCVAE/...``.

    Args:
        data_dict: dataset description. ``'name'`` is required; the optional
            keys read here are ``'source_key'``, ``'target_key'``, ``'width'``,
            ``'height'``, ``'n_channels'``, ``'train_digits'``,
            ``'test_digits'``, ``'attribute'`` and, for CelebA, ``'gender'``.
            The default ``None`` is not usable: the first statement
            subscripts it.
        z_dim: latent dimension passed to ``DCtrVAE``; also part of the model
            and results paths.
        n_files: number of sampling rounds, i.e. figures saved per split.
        k: number of images sampled per round. Replaced by
            ``len(test_digits)`` / ``len(train_digits)`` when those are given.
        arch_style: architecture style passed to ``DCtrVAE``; also part of the
            model and results paths.
        preprocess: when true, source (and target) images are divided by 255.
        max_size: ``max_n_images`` passed to the CelebA loader.

    Returns:
        None. The results are the saved figures.
    """
    data_name = data_dict['name']
    source_key = data_dict.get('source_key', None)
    target_key = data_dict.get('target_key', None)
    img_width = data_dict.get("width", None)
    img_height = data_dict.get("height", None)
    n_channels = data_dict.get('n_channels', None)
    train_digits = data_dict.get('train_digits', None)
    test_digits = data_dict.get('test_digits', None)
    attribute = data_dict.get('attribute', None)

    if data_name == "celeba":
        gender = data_dict.get('gender', None)
        data = trvae.prepare_and_load_celeba(
            file_path="../data/celeba/img_align_celeba.zip",
            attr_path="../data/celeba/list_attr_celeba.txt",
            landmark_path="../data/celeba/list_landmarks_align_celeba.txt",
            gender=gender,
            attribute=attribute,
            max_n_images=max_size,
            img_width=img_width,
            img_height=img_height,
            restore=True,
            save=False)

        # For CelebA the split is by the 'labels' column: -1 -> validation,
        # +1 -> training.
        valid_data = data.copy()[data.obs['labels'] ==
                                 -1]  # get females (Male = -1)
        train_data = data.copy()[data.obs['labels'] ==
                                 +1]  # get males (Male = 1)

        # Only the validation part is densified here.
        if sparse.issparse(valid_data.X):
            valid_data.X = valid_data.X.A

        source_images_train = train_data[train_data.obs["condition"] ==
                                         source_key].X
        source_images_valid = valid_data[valid_data.obs["condition"] ==
                                         source_key].X

        source_images_train = np.reshape(
            source_images_train, (-1, img_width, img_height, n_channels))
        source_images_valid = np.reshape(
            source_images_valid, (-1, img_width, img_height, n_channels))

        if preprocess:
            source_images_train /= 255.0
            source_images_valid /= 255.0
    else:
        data = sc.read(f"../data/{data_name}/{data_name}.h5ad")
        # With digit lists the split is by label; otherwise train and
        # validation are both the full dataset.
        if train_digits is not None:
            train_data = data[data.obs['labels'].isin(train_digits)]
            valid_data = data[data.obs['labels'].isin(test_digits)]
        else:
            train_data = data.copy()
            valid_data = data.copy()

        source_images_train = train_data[train_data.obs["condition"] ==
                                         source_key].X
        target_images_train = train_data[train_data.obs["condition"] ==
                                         target_key].X

        source_images_train = np.reshape(
            source_images_train, (-1, img_width, img_height, n_channels))
        target_images_train = np.reshape(
            target_images_train, (-1, img_width, img_height, n_channels))

        source_images_valid = valid_data[valid_data.obs["condition"] ==
                                         source_key].X
        target_images_valid = valid_data[valid_data.obs["condition"] ==
                                         target_key].X

        source_images_valid = np.reshape(
            source_images_valid, (-1, img_width, img_height, n_channels))
        target_images_valid = np.reshape(
            target_images_valid, (-1, img_width, img_height, n_channels))

        if preprocess:
            source_images_train /= 255.0
            source_images_valid /= 255.0

            target_images_train /= 255.0
            target_images_valid /= 255.0

    image_shape = (img_width, img_height, n_channels)

    # Flatten the source images to one row per image and wrap them in AnnData
    # objects; these are what the non-digit sampling branch draws from.
    source_images_train = np.reshape(source_images_train,
                                     (-1, np.prod(image_shape)))
    source_images_valid = np.reshape(source_images_valid,
                                     (-1, np.prod(image_shape)))

    source_data_train = anndata.AnnData(X=source_images_train)
    source_data_valid = anndata.AnnData(X=source_images_valid)

    # Build the network with the same path components used at training time
    # and load its saved weights.
    network = trvae.DCtrVAE(
        x_dimension=image_shape,
        z_dimension=z_dim,
        arch_style=arch_style,
        model_path=
        f"../models/RCCVAE/{data_name}-{img_width}x{img_height}-{preprocess}/{arch_style}-{z_dim}/"
    )

    network.restore_model()

    results_path_train = f"../results/RCCVAE/{data_name}-{img_width}x{img_height}-{preprocess}/{arch_style}-{z_dim}/{source_key} to {target_key}/train/"
    results_path_valid = f"../results/RCCVAE/{data_name}-{img_width}x{img_height}-{preprocess}/{arch_style}-{z_dim}/{source_key} to {target_key}/valid/"
    os.makedirs(results_path_train, exist_ok=True)
    os.makedirs(results_path_valid, exist_ok=True)

    if sparse.issparse(valid_data.X):
        valid_data.X = valid_data.X.A
    # One sample per test digit is drawn below, so k follows the digit count.
    if test_digits is not None:
        k = len(test_digits)
    for j in range(n_files):
        if test_digits is not None:
            source_sample_train = []
            source_sample_valid = []

            target_sample_train = []
            target_sample_valid = []

            # Validation: one random source image per test digit, plus the
            # target image at the same row position.
            # NOTE(review): the same positional index is applied to the source
            # and target subsets, which assumes they are row-aligned pairs --
            # confirm against how the dataset was built.
            for digit in test_digits:
                source_images_digit_valid = valid_data[
                    (valid_data.obs['labels'] == digit)
                    & (valid_data.obs['condition'] == source_key)]
                target_images_digit_valid = valid_data[
                    (valid_data.obs['labels'] == digit)
                    & (valid_data.obs['condition'] == target_key)]
                # NOTE(review): the source images are scaled only in the first
                # round, and on a freshly created subset; whether later rounds
                # see scaled data depends on AnnData view semantics. The
                # `preprocess` flag is not consulted here. Confirm intent.
                if j == 0:
                    source_images_digit_valid.X /= 255.0
                random_samples = np.random.choice(
                    source_images_digit_valid.shape[0], 1, replace=False)

                source_sample_valid.append(
                    source_images_digit_valid.X[random_samples])
                target_sample_valid.append(
                    target_images_digit_valid.X[random_samples])

            # Training: same procedure, one sample per training digit.
            for digit in train_digits:
                source_images_digit_train = train_data[
                    (train_data.obs['labels'] == digit)
                    & (train_data.obs['condition'] == source_key)]
                target_images_digit_train = train_data[
                    (train_data.obs['labels'] == digit)
                    & (train_data.obs['condition'] == target_key)]
                if j == 0:
                    source_images_digit_train.X /= 255.0
                random_samples = np.random.choice(
                    source_images_digit_train.shape[0], 1, replace=False)

                source_sample_train.append(
                    source_images_digit_train.X[random_samples])
                target_sample_train.append(
                    target_images_digit_train.X[random_samples])
        else:
            # No digit lists: draw k distinct source images from each split.
            random_samples_train = np.random.choice(source_data_train.shape[0],
                                                    k,
                                                    replace=False)
            random_samples_valid = np.random.choice(source_data_valid.shape[0],
                                                    k,
                                                    replace=False)
            source_sample_train = source_data_train.X[random_samples_train]
            source_sample_valid = source_data_valid.X[random_samples_valid]

        source_sample_train = np.array(source_sample_train)
        source_sample_valid = np.array(source_sample_valid)
        # if data_name.__contains__("mnist"):
        #     target_sample = np.array(target_sample)
        #     target_sample_reshaped = np.reshape(target_sample, (-1, *image_shape))

        # Keep two layouts of each sample: flat rows for the network and
        # image-shaped arrays for plotting.
        source_sample_train = np.reshape(source_sample_train,
                                         (-1, np.prod(image_shape)))
        source_sample_train_reshaped = np.reshape(source_sample_train,
                                                  (-1, *image_shape))
        # The target samples exist only when digit lists were given (see the
        # sampling branch above); they are reshaped for MNIST-named data only.
        if data_name.__contains__("mnist"):
            target_sample_train = np.reshape(target_sample_train,
                                             (-1, np.prod(image_shape)))
            target_sample_train_reshaped = np.reshape(target_sample_train,
                                                      (-1, *image_shape))
            target_sample_valid = np.reshape(target_sample_valid,
                                             (-1, np.prod(image_shape)))
            target_sample_valid_reshaped = np.reshape(target_sample_valid,
                                                      (-1, *image_shape))

        source_sample_valid = np.reshape(source_sample_valid,
                                         (-1, np.prod(image_shape)))
        source_sample_valid_reshaped = np.reshape(source_sample_valid,
                                                  (-1, *image_shape))

        source_sample_train = anndata.AnnData(X=source_sample_train)
        source_sample_valid = anndata.AnnData(X=source_sample_valid)

        # Predict with encoder label 0 and decoder label 1 for every sample.
        # NOTE(review): the label arrays have k rows. At this point k is
        # len(test_digits) when digit lists are used, while the train sample
        # holds len(train_digits) images; the two only agree when both lists
        # have the same length -- confirm.
        pred_sample_train = network.predict(adata=source_sample_train,
                                            encoder_labels=np.zeros((k, 1)),
                                            decoder_labels=np.ones((k, 1)))
        pred_sample_train = np.reshape(pred_sample_train,
                                       newshape=(-1, *image_shape))

        pred_sample_valid = network.predict(adata=source_sample_valid,
                                            encoder_labels=np.zeros((k, 1)),
                                            decoder_labels=np.ones((k, 1)))
        pred_sample_valid = np.reshape(pred_sample_valid,
                                       newshape=(-1, *image_shape))

        print(source_sample_train.shape, source_sample_train_reshaped.shape,
              pred_sample_train.shape)

        # ---- Training figure: one row per sample ----
        plt.close("all")
        if train_digits is not None:
            k = len(train_digits)
        # MNIST-named data gets three columns (source, target, prediction),
        # everything else two (source, prediction).
        if data_name.__contains__("mnist"):
            fig, ax = plt.subplots(len(train_digits), 3, figsize=(k * 1, 6))
        else:
            fig, ax = plt.subplots(k, 2, figsize=(k * 1, 6))
        for i in range(k):
            ax[i, 0].axis('off')
            # Multi-channel images are shown as-is; single-channel ones are
            # shown as a 2-D array with the 'Greys' colormap.
            if source_sample_train_reshaped.shape[-1] > 1:
                ax[i, 0].imshow(source_sample_train_reshaped[i])
            else:
                ax[i, 0].imshow(source_sample_train_reshaped[i, :, :, 0],
                                cmap='Greys')
            ax[i, 1].axis('off')
            if data_name.__contains__("mnist"):
                ax[i, 2].axis('off')
            # if i == 0:
            #     if data_name == "celeba":
            #         ax[i, 0].set_title(f"without {data_dict['attribute']}")
            #         ax[i, 1].set_title(f"with {data_dict['attribute']}")
            #     elif data_name.__contains__("mnist"):
            #         ax[i, 0].set_title(f"Source")
            #         ax[i, 1].set_title(f"Target (Ground Truth)")
            #         ax[i, 2].set_title(f"Target (Predicted)")
            #     else:
            #         ax[i, 0].set_title(f"{source_key}")
            #         ax[i, 1].set_title(f"{target_key}")

            # NOTE(review): the single-channel branch uses a third column and
            # target_sample_train_reshaped, both of which exist only for
            # MNIST-named data -- confirm no other single-channel dataset
            # reaches this point.
            if pred_sample_train.shape[-1] > 1:
                ax[i, 1].imshow(pred_sample_train[i])
            else:
                ax[i, 1].imshow(target_sample_train_reshaped[i, :, :, 0],
                                cmap='Greys')
                ax[i, 2].imshow(pred_sample_train[i, :, :, 0], cmap='Greys')
            # if data_name.__contains__("mnist"):
            #     ax[i, 2].imshow(target_sample_reshaped[i, :, :, 0], cmap='Greys')
        plt.savefig(os.path.join(results_path_train, f"sample_images_{j}.pdf"))

        print(source_sample_valid.shape, source_sample_valid_reshaped.shape,
              pred_sample_valid.shape)

        # ---- Validation figure: same layout as the training figure ----
        plt.close("all")
        # k is left at len(test_digits) for the next round of the outer loop.
        if test_digits is not None:
            k = len(test_digits)
        if data_name.__contains__("mnist"):
            fig, ax = plt.subplots(k, 3, figsize=(k * 1, 6))
        else:
            fig, ax = plt.subplots(k, 2, figsize=(k * 1, 6))
        for i in range(k):
            ax[i, 0].axis('off')
            if source_sample_valid_reshaped.shape[-1] > 1:
                ax[i, 0].imshow(source_sample_valid_reshaped[i])
            else:
                ax[i, 0].imshow(source_sample_valid_reshaped[i, :, :, 0],
                                cmap='Greys')
            ax[i, 1].axis('off')
            if data_name.__contains__("mnist"):
                ax[i, 2].axis('off')
            # if i == 0:
            #     if data_name == "celeba":
            #         ax[i, 0].set_title(f"without {data_dict['attribute']}")
            #         ax[i, 1].set_title(f"with {data_dict['attribute']}")
            #     elif data_name.__contains__("mnist"):
            #         ax[i, 0].set_title(f"Source")
            #         ax[i, 1].set_title(f"Target (Ground Truth)")
            #         ax[i, 2].set_title(f"Target (Predicted)")
            #     else:
            #         ax[i, 0].set_title(f"{source_key}")
            #         ax[i, 1].set_title(f"{target_key}")

            if pred_sample_valid.shape[-1] > 1:
                ax[i, 1].imshow(pred_sample_valid[i])
            else:
                ax[i, 1].imshow(target_sample_valid_reshaped[i, :, :, 0],
                                cmap='Greys')
                ax[i, 2].imshow(pred_sample_valid[i, :, :, 0], cmap='Greys')
            # if data_name.__contains__("mnist"):
            #     ax[i, 2].imshow(target_sample_reshaped[i, :, :, 0], cmap='Greys')
        plt.savefig(
            os.path.join(results_path_valid, f"./sample_images_{j}.pdf"))
Example #29
0
        print('Options')
        print(f'    type:\t{type_}')
        print(f'    batch_key:\t{batch_key}')
        print(f'    label_key:\t{label_key}')
        print(f'    assay:\t{assay}')
        print(f'    organism:\t{organism}')
        print(f'    n_hvgs:\t{n_hvgs}')
        print(f'    setup:\t{setup}')
        print(f'    optimised clustering results:\t{cluster_nmi}')

    ###

    empty_file = False

    print("reading adata before integration")
    adata = sc.read(args.uncorrected, cache=True)
    print(adata)
    print("reading adata after integration")
    if os.stat(args.integrated).st_size == 0:
        print(f'{args.integrated} is empty, setting all metrics to NA.')
        adata_int = adata
        empty_file = True
    else:
        adata_int = sc.read(args.integrated, cache=True)
        print(adata_int)

    if (n_hvgs is not None):
        if (adata_int.n_vars < n_hvgs):
            raise ValueError(
                "There are less genes in the corrected adata than specified for HVG selection"
            )
Example #30
0
from scdcdm.util import comp_ana as ca
from scdcdm.model import dirichlet_time_models as tm
from scdcdm.util import data_visualization as viz

# Short aliases for the TensorFlow Probability sub-modules.
tfd = tfp.distributions
tfb = tfp.bijectors

# Let pandas print up to 500 columns before truncating the display.
pd.set_option('display.max_columns', 500)

#%%

# get Lisa's dataset

# NOTE(review): absolute path on one developer's machine; the script only runs
# where this file exists.
data_path = "C:/Users/Johannes/Documents/PhD/single-cell/hackathon_sep2020/thymus_data.h5ad"

data = sp.read(data_path)

#%%

# pseudo-covariate of 1 on all samples
data.obs["c"] = 1

print(data.X)

#%%

# Stacked bar plot of the data, grouped by the "day" column.
viz.plot_feature_stackbars(data, ["day"])

#%%
importlib.reload(ca)
importlib.reload(mod)