def runScGen(adata, batch, cell_type, epochs=100, hvg=None, model_path='/localscratch'):
    """
    Train a scGen ``VAEArith`` network on ``adata`` and return the
    batch-corrected result of ``scgen.batch_removal``.

    Parametrization taken from the tutorial notebook at:
    https://nbviewer.jupyter.org/github/M0hammadL/scGen_notebooks/blob/master/notebooks/scgen_batch_removal.ipynb

    Parameters
    ----------
    adata
        Data to train on and to correct; ``adata.shape[1]`` sets the
        network's ``x_dimension``.
    batch
        Forwarded to ``scgen.batch_removal`` as ``batch_key``.
    cell_type
        Forwarded to ``scgen.batch_removal`` as ``cell_label_key``.
    epochs
        Number of training epochs.
    hvg
        Forwarded, together with ``adata`` and ``batch``, to ``checkSanity``.
    model_path
        Passed to ``scgen.VAEArith``; training runs with ``save=False``.

    Returns
    -------
    Whatever ``scgen.batch_removal`` returns for the trained network.
    """
    import scgen

    checkSanity(adata, batch, hvg)

    # Fit the model
    network = scgen.VAEArith(x_dimension=adata.shape[1], model_path=model_path)
    try:
        network.train(train_data=adata, n_epochs=epochs, save=False)
        corrected_adata = scgen.batch_removal(
            network,
            adata,
            batch_key=batch,
            cell_label_key=cell_type,
        )
    finally:
        # Release the session even if training or correction raises.
        network.sess.close()
    return corrected_adata
def train_cross_study(data_name="study",
                      z_dim=100,
                      alpha=0.00005,
                      n_epochs=300,
                      batch_size=32,
                      dropout_rate=0.2,
                      learning_rate=0.001):
    """Train a scGen ``VAEArith`` network on the study train/validation files.

    The training and validation sets are read from fixed paths under
    ``../data``; the model path is fixed to ``../models/scGen/study/scgen``.
    ``data_name`` only appears in the completion message.
    """
    training_set = sc.read("../data/train_study.h5ad")
    validation_set = sc.read("../data/valid_study.h5ad")

    hyperparameters = dict(
        z_dimension=z_dim,
        alpha=alpha,
        dropout_rate=dropout_rate,
        learning_rate=learning_rate,
    )
    network = scgen.VAEArith(
        x_dimension=training_set.X.shape[1],
        model_path="../models/scGen/study/scgen",
        **hyperparameters,
    )

    network.train(
        training_set,
        use_validation=True,
        valid_data=validation_set,
        n_epochs=n_epochs,
        batch_size=batch_size,
    )
    print(f"network_{data_name} has been trained!")
    network.sess.close()
def train_batch_removal(data_name="study",
                        z_dim=100,
                        alpha=0.00005,
                        n_epochs=300,
                        batch_size=32,
                        dropout_rate=0.2,
                        learning_rate=0.001,
                        condition_key="condition"):
    """Train a scGen ``VAEArith`` network from inside ``./vae_results/<data_name>/all/``.

    The training data is always read from ``../data/kang_cross_train.h5ad``
    (relative to the directory the function is called from). The function
    then creates and enters ``./vae_results/<data_name>/all/`` for the
    duration of training and returns to the starting directory afterwards,
    including when training raises.

    Parameters
    ----------
    data_name
        Used for the results directory name and the completion message.
    z_dim, alpha, dropout_rate, learning_rate
        Forwarded to ``scgen.VAEArith``.
    n_epochs, batch_size
        Forwarded to ``network.train``.
    condition_key
        Accepted for signature compatibility; not used by this function.
    """
    net_train_data = sc.read("../data/kang_cross_train.h5ad")

    # Remember where we started: a relative "../../../" only leads back here
    # when data_name has no path separators and nothing failed in between.
    start_dir = os.getcwd()
    os.makedirs(f"./vae_results/{data_name}/all/", exist_ok=True)
    os.chdir(f"./vae_results/{data_name}/all/")
    try:
        network = scgen.VAEArith(
            x_dimension=net_train_data.X.shape[1],
            z_dimension=z_dim,
            alpha=alpha,
            dropout_rate=dropout_rate,
            learning_rate=learning_rate,
        )
        # network.restore_model()
        network.train(net_train_data, n_epochs=n_epochs, batch_size=batch_size)
        print(f"network_{data_name} has been trained!")
    finally:
        os.chdir(start_dir)
def test_batch_removal():
    """Smoke-test ``scgen.batch_removal`` on the pancreas dataset.

    The network is "trained" for zero epochs, so this only checks that the
    construction, training call and batch removal run end to end.
    """
    pancreas = sc.read(
        "./tests/data/pancreas.h5ad",
        backup_url="https://goo.gl/V29FNk",
    )
    # Expose the dataset's "celltype" annotation under the "cell_type" column.
    cell_labels = pancreas.obs["celltype"].tolist()
    pancreas.obs["cell_type"] = cell_labels

    network = scgen.VAEArith(
        x_dimension=pancreas.shape[1],
        model_path="./models/batch",
    )
    network.train(train_data=pancreas, n_epochs=0)
    scgen.batch_removal(network, pancreas)
    network.sess.close()
def train(data_name="study",
          z_dim=100,
          alpha=0.00005,
          n_epochs=300,
          batch_size=32,
          dropout_rate=0.2,
          learning_rate=0.001,
          condition_key="condition"):
    """Train a scGen ``VAEArith`` network on the whole training set of a dataset.

    The training file is selected by ``data_name``; training runs from inside
    ``./vae_results/<data_name>/whole/`` and the starting working directory
    is restored afterwards, including when training raises.

    Parameters
    ----------
    data_name
        One of ``"pbmc"``, ``"hpoly"``, ``"salmonella"``, ``"species"`` or
        ``"study"``.
    z_dim, alpha, dropout_rate, learning_rate
        Forwarded to ``scgen.VAEArith``.
    n_epochs, batch_size
        Forwarded to ``network.train``.
    condition_key
        Accepted for signature compatibility; not used by this function.

    Raises
    ------
    ValueError
        If ``data_name`` is not one of the supported dataset names.
    """
    training_files = {
        "pbmc": "../data/train.h5ad",
        "hpoly": "../data/ch10_train_7000.h5ad",
        "salmonella": "../data/chsal_train_7000.h5ad",
        "species": "../data/train_all_lps6.h5ad",
        "study": "../data/kang_cross_train.h5ad",
    }
    # Fail early and clearly instead of hitting an unbound local later on.
    if data_name not in training_files:
        raise ValueError(
            f"unknown data_name {data_name!r}; expected one of {sorted(training_files)}"
        )
    net_train_data = sc.read(training_files[data_name])

    # A relative "../../../" only leads back to the start when nothing fails
    # and data_name has no path separators, so remember the directory itself.
    start_dir = os.getcwd()
    os.makedirs(f"./vae_results/{data_name}/whole/", exist_ok=True)
    os.chdir(f"./vae_results/{data_name}/whole/")
    try:
        network = scgen.VAEArith(
            x_dimension=net_train_data.X.shape[1],
            z_dimension=z_dim,
            alpha=alpha,
            dropout_rate=dropout_rate,
            learning_rate=learning_rate,
        )
        # network.restore_model()
        network.train(net_train_data, n_epochs=n_epochs, batch_size=batch_size)
        print("network has been trained!")
    finally:
        os.chdir(start_dir)
def test_train_whole_data_one_celltype_out(data_name="pbmc",
                                           z_dim=50,
                                           alpha=0.1,
                                           n_epochs=1000,
                                           batch_size=32,
                                           dropout_rate=0.25,
                                           learning_rate=0.001,
                                           condition_key="condition",
                                           cell_type_to_train=None):
    """Train one scGen network per cell type with that type's stimulated cells held out.

    For every cell type found in the training data (or only
    ``cell_type_to_train`` when given), the rows that are both of that cell
    type and in the stimulated condition are dropped from the training and
    validation sets, and a ``VAEArith`` network is trained on the rest. Each
    model is written under ``../models/scGen/<data_name>/<cell_type>/scgen``.

    Parameters
    ----------
    data_name
        One of ``"pbmc"``, ``"hpoly"``, ``"salmonella"`` or ``"species"``.
    z_dim, alpha, dropout_rate, learning_rate
        Forwarded to ``scgen.VAEArith``.
    n_epochs, batch_size
        Forwarded to ``network.train``.
    condition_key
        ``obs`` column holding the condition labels.
    cell_type_to_train
        When not ``None``, only this cell type is processed.

    Raises
    ------
    ValueError
        If ``data_name`` is not one of the supported dataset names.
    """
    # data_name -> (stimulated condition label, obs column with cell types)
    dataset_keys = {
        "pbmc": ("stimulated", "cell_type"),
        "hpoly": ("Hpoly.Day10", "cell_label"),
        "salmonella": ("Salmonella", "cell_label"),
        "species": ("LPS6", "species"),
    }
    # Fail early and clearly instead of hitting an unbound local later on.
    if data_name not in dataset_keys:
        raise ValueError(
            f"unknown data_name {data_name!r}; expected one of {sorted(dataset_keys)}"
        )
    stim_key, cell_type_key = dataset_keys[data_name]
    train = sc.read(f"../data/train_{data_name}.h5ad")
    valid = sc.read(f"../data/valid_{data_name}.h5ad")

    for cell_type in train.obs[cell_type_key].unique().tolist():
        if cell_type_to_train is not None and cell_type != cell_type_to_train:
            continue

        # Drop the stimulated cells of the held-out cell type from both sets.
        net_train_data = train[~((train.obs[cell_type_key] == cell_type) &
                                 (train.obs[condition_key] == stim_key))]
        net_valid_data = valid[~((valid.obs[cell_type_key] == cell_type) &
                                 (valid.obs[condition_key] == stim_key))]

        network = scgen.VAEArith(
            x_dimension=net_train_data.X.shape[1],
            z_dimension=z_dim,
            alpha=alpha,
            dropout_rate=dropout_rate,
            learning_rate=learning_rate,
            model_path=f"../models/scGen/{data_name}/{cell_type}/scgen",
        )
        try:
            network.train(
                net_train_data,
                use_validation=True,
                valid_data=net_valid_data,
                n_epochs=n_epochs,
                batch_size=batch_size,
            )
        finally:
            # Release the session even when training raises, so a failure
            # for one cell type does not leak it.
            network.sess.close()
        print(f"network_{cell_type} has been trained!")
def test_batch_removal():
    """Run one training epoch on the pancreas dataset, apply
    ``scgen.batch_removal`` and print the ``obs`` table of the result."""
    dataset_url = "https://www.dropbox.com/s/qj1jlm9w10wmt0u/pancreas.h5ad?dl=1"
    pancreas = sc.read("./data/pancreas.h5ad", backup_url=dataset_url)

    # Expose the dataset's "celltype" annotation under the "cell_type" column.
    cell_labels = pancreas.obs["celltype"].tolist()
    pancreas.obs["cell_type"] = cell_labels

    network = scgen.VAEArith(
        x_dimension=pancreas.shape[1],
        model_path="./models/batch",
    )
    network.train(train_data=pancreas, n_epochs=1, verbose=1)

    corrected = scgen.batch_removal(network, pancreas)
    print(corrected.obs)
    network.sess.close()
def test_reg_mean_plot():
    """Smoke-test ``scgen.plotting.reg_mean_plot`` in four configurations.

    A zero-epoch network predicts the stimulated state of the control CD4T
    cells; the prediction is concatenated with all CD4T cells and plotted
    with and without a ``"y1"`` axis key and with and without a gene list.
    """
    train = sc.read("./tests/data/train.h5ad", backup_url="https://goo.gl/33HtVh")
    network = scgen.VAEArith(x_dimension=train.shape[1], model_path="../models/test")
    network.train(train_data=train, n_epochs=0)

    control_cd4t_mask = ((train.obs["cell_type"] == "CD4T") &
                         (train.obs["condition"] == "control"))
    unperturbed_data = train[control_cd4t_mask]
    pred, _delta = network.predict(
        adata=train,
        adata_to_predict=unperturbed_data,
        conditions={"ctrl": "control", "stim": "stimulated"},
    )

    pred_adata = anndata.AnnData(
        pred,
        obs={"condition": ["pred"] * len(pred)},
        var={"var_names": train.var_names},
    )
    cd4t_cells = train[train.obs["cell_type"] == "CD4T"]
    all_adata = cd4t_cells.concatenate(pred_adata)

    # (axis keys, output file, extra keyword arguments), in call order.
    plot_specs = [
        ({"x": "control", "y": "pred"},
         "tests/reg_mean1.pdf",
         {}),
        ({"x": "control", "y": "pred"},
         "tests/reg_mean2.pdf",
         {"gene_list": ["ISG15", "CD3D"]}),
        ({"x": "control", "y": "pred", "y1": "stimulated"},
         "tests/reg_mean3.pdf",
         {}),
        ({"x": "control", "y": "pred", "y1": "stimulated"},
         "tests/reg_mean.pdf",
         {"gene_list": ["ISG15", "CD3D"]}),
    ]
    for axis_keys, output_path, extra_kwargs in plot_specs:
        scgen.plotting.reg_mean_plot(
            all_adata,
            condition_key="condition",
            axis_keys=axis_keys,
            path_to_save=output_path,
            **extra_kwargs,
        )

    network.sess.close()
def create_model(x_train):
    """Build and train one ``VAEArith`` candidate and report its best loss.

    NOTE(review): the double-brace ``choice([...])`` expressions below look
    like textual template placeholders for a hyperparameter-search tool
    (hyperas-style); presumably each one is substituted with a sampled value
    before the function runs. Confirm against the caller, and keep the
    placeholders textually intact when editing.

    Returns a dict with the minimum of the training ``loss`` history under
    ``'loss'``, ``STATUS_OK`` under ``'status'`` and ``network.vae_model``
    under ``'model'``.
    """
    # Network hyperparameters are drawn from the listed candidate values.
    network = scgen.VAEArith(
        x_dimension=x_train.X.shape[1],
        z_dimension={{choice([10, 20, 50, 75, 100])}},
        learning_rate={{choice([0.1, 0.01, 0.001, 0.0001])}},
        alpha={{choice([0.1, 0.01, 0.001, 0.0001, 0.00001, 0.000001])}},
        dropout_rate={{choice([0.2, 0.25, 0.5, 0.75, 0.8])}},
        model_path=f"./")
    # Epoch count and batch size are search dimensions too; save=False means
    # the trained model is not written out by this call.
    result = network.train(x_train, n_epochs={{choice([100, 150, 200, 250])}}, batch_size={{choice([32, 64, 128, 256])}}, verbose=2, shuffle=True, save=False)
    # The lowest loss value recorded in the training history is the score.
    best_loss = np.amin(result.history['loss'])
    print('Best Loss of model:', best_loss)
    return {'loss': best_loss, 'status': STATUS_OK, 'model': network.vae_model}
def test_train_whole_data_some_celltypes_out(data_name="pbmc",
                                             z_dim=100,
                                             alpha=0.00005,
                                             n_epochs=300,
                                             batch_size=32,
                                             dropout_rate=0.2,
                                             learning_rate=0.001,
                                             condition_key="condition",
                                             c_out=None,
                                             c_in=None):
    """Train a scGen network on PBMC data after ``scgen.data_remover`` filtering.

    The training and validation sets are both passed through
    ``scgen.data_remover`` with ``remain_list=c_in`` and
    ``remove_list=c_out``; the model is written under
    ``../models/scGen/pbmc/heldout/<len(c_out)>/scgen``.

    Parameters
    ----------
    data_name
        Only ``"pbmc"`` is configured.
    z_dim, alpha, dropout_rate, learning_rate
        Forwarded to ``scgen.VAEArith``.
    n_epochs, batch_size
        Forwarded to ``network.train``.
    condition_key
        Forwarded to ``scgen.data_remover``.
    c_out
        Forwarded to ``scgen.data_remover`` as ``remove_list``; its length
        is part of the model path, so it must be a sized collection.
    c_in
        Forwarded to ``scgen.data_remover`` as ``remain_list``.

    Raises
    ------
    ValueError
        If ``data_name`` is not ``"pbmc"``.
    TypeError
        If ``c_out`` is ``None``.
    """
    # Validate up front: previously an unsupported data_name surfaced as an
    # unbound-local error and a missing c_out only failed inside len().
    if data_name != "pbmc":
        raise ValueError(
            f"unsupported data_name {data_name!r}; only 'pbmc' is configured"
        )
    if c_out is None:
        raise TypeError("c_out must be a sized collection, not None")

    cell_type_key = "cell_type"
    train = sc.read("../data/train_pbmc.h5ad")
    valid = sc.read("../data/valid_pbmc.h5ad")

    net_train_data = scgen.data_remover(
        train,
        remain_list=c_in,
        remove_list=c_out,
        cell_type_key=cell_type_key,
        condition_key=condition_key,
    )
    net_valid_data = scgen.data_remover(
        valid,
        remain_list=c_in,
        remove_list=c_out,
        cell_type_key=cell_type_key,
        condition_key=condition_key,
    )

    network = scgen.VAEArith(
        x_dimension=net_train_data.X.shape[1],
        z_dimension=z_dim,
        alpha=alpha,
        dropout_rate=dropout_rate,
        learning_rate=learning_rate,
        model_path=f"../models/scGen/pbmc/heldout/{len(c_out)}/scgen",
    )
    try:
        network.train(
            net_train_data,
            use_validation=True,
            valid_data=net_valid_data,
            n_epochs=n_epochs,
            batch_size=batch_size,
        )
        print("network has been trained!")
    finally:
        # Release the session even when training raises.
        network.sess.close()
def test_train_whole_data_some_celltypes_out(data_name="study",
                                             z_dim=100,
                                             alpha=0.00005,
                                             n_epochs=300,
                                             batch_size=32,
                                             dropout_rate=0.2,
                                             learning_rate=0.001,
                                             condition_key="condition",
                                             c_out=None,
                                             c_in=None):
    """Train a scGen network on ``scgen.data_remover``-filtered PBMC data.

    Training runs from inside
    ``./vae_results/<data_name>/heldout/<len(c_out)>/`` and the starting
    working directory is restored afterwards, including when something
    raises. Arguments are validated before any directory is created or
    entered.

    Parameters
    ----------
    data_name
        Only ``"pbmc"`` is configured. NOTE(review): the default
        ``"study"`` is therefore rejected; it is kept unchanged for
        signature compatibility.
    z_dim, alpha, dropout_rate, learning_rate
        Forwarded to ``scgen.VAEArith``.
    n_epochs, batch_size
        Forwarded to ``network.train``.
    condition_key
        Forwarded to ``scgen.data_remover``.
    c_out
        Forwarded to ``scgen.data_remover`` as ``remove_list``; its length
        is part of the results directory, so it must be a sized collection.
    c_in
        Forwarded to ``scgen.data_remover`` as ``remain_list``.

    Raises
    ------
    ValueError
        If ``data_name`` is not ``"pbmc"``.
    TypeError
        If ``c_out`` is ``None``.
    """
    # Validate before touching the filesystem: previously an unsupported
    # data_name failed only after the directory was created and entered,
    # leaving the process in the wrong working directory.
    if data_name != "pbmc":
        raise ValueError(
            f"unsupported data_name {data_name!r}; only 'pbmc' is configured"
        )
    if c_out is None:
        raise TypeError("c_out must be a sized collection, not None")

    cell_type_key = "cell_type"
    train = sc.read("../data/train.h5ad")

    start_dir = os.getcwd()
    os.makedirs(f"./vae_results/{data_name}/heldout/{len(c_out)}/", exist_ok=True)
    os.chdir(f"./vae_results/{data_name}/heldout/{len(c_out)}/")
    try:
        net_train_data = scgen.data_remover(
            train,
            remain_list=c_in,
            remove_list=c_out,
            cell_type_key=cell_type_key,
            condition_key=condition_key,
        )
        print(net_train_data)
        network = scgen.VAEArith(
            x_dimension=net_train_data.X.shape[1],
            z_dimension=z_dim,
            alpha=alpha,
            dropout_rate=dropout_rate,
            learning_rate=learning_rate,
        )
        # network.restore_model()
        network.train(net_train_data, n_epochs=n_epochs, batch_size=batch_size)
        print("network has been trained!")
    finally:
        os.chdir(start_dir)
def test_binary_classifier():
    """Smoke-test ``scgen.plotting.binary_classifier``.

    A zero-epoch network predicts from the control CD4T cells; the ``delta``
    returned by ``predict`` is handed to the plotting function, which writes
    ``tests/binary_classifier.pdf``.
    """
    train = sc.read("./tests/data/train.h5ad", backup_url="https://goo.gl/33HtVh")
    network = scgen.VAEArith(
        x_dimension=train.shape[1],
        model_path="../models/test",
    )
    network.train(train_data=train, n_epochs=0)

    control_cd4t_mask = ((train.obs["cell_type"] == "CD4T") &
                         (train.obs["condition"] == "control"))
    unperturbed_data = train[control_cd4t_mask]

    _pred, delta = network.predict(
        adata=train,
        adata_to_predict=unperturbed_data,
        conditions={"ctrl": "control", "stim": "stimulated"},
    )

    scgen.plotting.binary_classifier(
        network,
        train,
        delta,
        condition_key="condition",
        conditions={"ctrl": "control", "stim": "stimulated"},
        path_to_save="tests/binary_classifier.pdf",
    )
    network.sess.close()
# Script fragment: all_train, all, tmp_nuclei, tmp_cells, path_out and
# batch_removal_ct5 are defined outside this excerpt.
# NOTE(review): `all` shadows the Python builtin of the same name.

# Per-cell normalisation (counts_per_cell_after=1e4) followed by log1p.
sc.pp.normalize_per_cell(all_train, counts_per_cell_after=1e4)
sc.pp.log1p(all_train)
# Keep the columns flagged 'highly_variable' in tmp_nuclei or in tmp_cells.
all_train = all_train[:, ((tmp_nuclei.var['highly_variable'] | tmp_cells.var['highly_variable']))].copy()
print("DATA PROCESSING COMPLETE")

# Round 1: Correct for source
## Train scGen on the training dataset and apply batch removal to both the training set and the complete set of nuclei/cells
# Copy the annotation columns into the 'cell_type' / 'batch' columns that are
# passed as cell_label_key / batch_key below.
all.obs['cell_type'] = all.obs['leiden_annotated'].values
all.obs['batch'] = all.obs['cell_source'].values
all_train.obs['cell_type'] = all_train.obs['leiden_annotated'].values
all_train.obs['batch'] = all_train.obs['cell_source'].values
network = scgen.VAEArith(x_dimension=all_train.shape[1], model_path=path_out + 'scgen_model/scgen_model_V_source')
print("Network made")
network.train(train_data=all_train, n_epochs=20)
print("Network trained")
# The network trained on all_train is applied to both the complete set and
# the training subset.
corrected_adata = batch_removal_ct5(network, all, batch_key="batch", cell_label_key="cell_type")
corrected_adata_train = batch_removal_ct5(network, all_train, batch_key="batch", cell_label_key="cell_type")
# Only the corrected complete set is written here; corrected_adata_train is
# not used within this excerpt.
corrected_adata.write(path_out + "ALL_corrected_FB_ATV_source.h5ad")
def reconstruct_whole_data(data_name="pbmc", condition_key="condition"):
    """Predict stimulated cells per cell type and write everything to one file.

    For each cell type in the training data, the model stored under
    ``./vae_results/<data_name>/<cell_type>`` is restored and asked to
    predict that cell type. The control cells, the prediction and the real
    stimulated cells are relabelled as ``<cell_type>_ctrl``,
    ``<cell_type>_pred_stim`` and ``<cell_type>_real_stim`` in
    ``condition_key``, concatenated across cell types, and written to
    ``./vae_results/<data_name>/reconstructed.h5ad``.

    Parameters
    ----------
    data_name
        One of ``"pbmc"``, ``"hpoly"``, ``"salmonella"``, ``"species"`` or
        ``"study"``.
    condition_key
        ``obs`` column holding the condition labels.

    Raises
    ------
    ValueError
        If ``data_name`` is not one of the supported dataset names.
    """
    # data_name -> (stimulated label, control label, cell-type column, file)
    datasets = {
        "pbmc": ("stimulated", "control", "cell_type",
                 "../data/train.h5ad"),
        "hpoly": ("Hpoly.Day10", "Control", "cell_label",
                  "../data/ch10_train_7000.h5ad"),
        "salmonella": ("Salmonella", "Control", "cell_label",
                       "../data/chsal_train_7000.h5ad"),
        "species": ("LPS6", "unst", "species",
                    "../data/train_all_lps6.h5ad"),
        "study": ("stimulated", "control", "cell_type",
                  "../data/kang_cross_train.h5ad"),
    }
    # Fail early and clearly instead of hitting an unbound local later on.
    if data_name not in datasets:
        raise ValueError(
            f"unknown data_name {data_name!r}; expected one of {sorted(datasets)}"
        )
    stim_key, ctrl_key, cell_type_key, train_path = datasets[data_name]
    train = sc.read(train_path)

    # A relative "../../../" only leads back to the start when nothing fails
    # and the cell-type label has no path separators, so remember the
    # directory itself.
    start_dir = os.getcwd()

    # Stays an empty AnnData (and is written as such) if there are no cell types.
    all_data = anndata.AnnData()
    for idx, cell_type in enumerate(train.obs[cell_type_key].unique().tolist()):
        print(f"Reconstructing for {cell_type}")
        os.chdir(f"./vae_results/{data_name}/{cell_type}")
        network = None
        try:
            network = scgen.VAEArith(
                x_dimension=train.X.shape[1],
                z_dimension=100,
                alpha=0.00005,
                dropout_rate=0.2,
                learning_rate=0.001,
            )
            network.restore_model()

            cell_type_data = train[train.obs[cell_type_key] == cell_type]
            cell_type_ctrl_data = train[((train.obs[cell_type_key] == cell_type) &
                                         (train.obs[condition_key] == ctrl_key))]
            pred, _ = network.predict(
                adata=cell_type_data,
                conditions={"ctrl": ctrl_key, "stim": stim_key},
                cell_type_key=cell_type_key,
                condition_key=condition_key,
                celltype_to_predict=cell_type,
            )

            pred_adata = anndata.AnnData(
                pred,
                obs={
                    condition_key: [f"{cell_type}_pred_stim"] * len(pred),
                    cell_type_key: [cell_type] * len(pred),
                },
                var={"var_names": cell_type_data.var_names},
            )
            ctrl_adata = anndata.AnnData(
                cell_type_ctrl_data.X,
                obs={
                    condition_key: [f"{cell_type}_ctrl"] * len(cell_type_ctrl_data),
                    cell_type_key: [cell_type] * len(cell_type_ctrl_data),
                },
                var={"var_names": cell_type_ctrl_data.var_names},
            )

            stim_subset = cell_type_data[cell_type_data.obs[condition_key] == stim_key]
            # Densify the real stimulated cells when the matrix is sparse.
            if sparse.issparse(cell_type_data.X):
                real_stim = stim_subset.X.A
            else:
                real_stim = stim_subset.X
            real_stim_adata = anndata.AnnData(
                real_stim,
                obs={
                    condition_key: [f"{cell_type}_real_stim"] * len(real_stim),
                    cell_type_key: [cell_type] * len(real_stim),
                },
                var={"var_names": cell_type_data.var_names},
            )

            if idx == 0:
                all_data = ctrl_adata.concatenate(pred_adata, real_stim_adata)
            else:
                all_data = all_data.concatenate(ctrl_adata, pred_adata, real_stim_adata)
        finally:
            # One network is built per cell type; close its session so they
            # do not pile up, and always return to the starting directory.
            if network is not None:
                network.sess.close()
            os.chdir(start_dir)
        print(f"Finish Reconstructing for {cell_type}")

    all_data.write_h5ad(f"./vae_results/{data_name}/reconstructed.h5ad")