コード例 #1
0
def runScGen(adata,
             batch,
             cell_type,
             epochs=100,
             hvg=None,
             model_path='/localscratch'):
    """
    Remove batch effects from ``adata`` with scGen and return the result.

    Parametrization taken from the tutorial notebook at:
    https://nbviewer.jupyter.org/github/M0hammadL/scGen_notebooks/blob/master/notebooks/scgen_batch_removal.ipynb

    Parameters
    ----------
    adata
        Data checked by ``checkSanity`` and used to train the model; its
        ``shape[1]`` is the model's input dimension.
    batch
        ``obs`` column with batch labels (``batch_key`` of
        ``scgen.batch_removal``).
    cell_type
        ``obs`` column with cell-type labels (``cell_label_key`` of
        ``scgen.batch_removal``).
    epochs
        Number of training epochs.
    hvg
        Only forwarded to ``checkSanity``.
    model_path
        Passed to ``scgen.VAEArith``; training runs with ``save=False``.

    Returns
    -------
    Whatever ``scgen.batch_removal`` returns for the trained model.
    """
    import scgen

    checkSanity(adata, batch, hvg)

    network = scgen.VAEArith(x_dimension=adata.shape[1], model_path=model_path)
    try:
        network.train(train_data=adata, n_epochs=epochs, save=False)
        corrected_adata = scgen.batch_removal(network,
                                              adata,
                                              batch_key=batch,
                                              cell_label_key=cell_type)
    finally:
        # Release the model's session even when training or batch removal
        # raises; previously a failure leaked it.
        network.sess.close()

    return corrected_adata
コード例 #2
0
def train_cross_study(data_name="study",
                      z_dim=100,
                      alpha=0.00005,
                      n_epochs=300,
                      batch_size=32,
                      dropout_rate=0.2,
                      learning_rate=0.001):
    """Train a VAEArith model on the cross-study train/validation split.

    Reads ``../data/train_study.h5ad`` and ``../data/valid_study.h5ad``
    (relative to the current working directory) and trains with validation.
    The model path is fixed to ``../models/scGen/study/scgen``.

    Parameters
    ----------
    data_name : str
        Only used in the completion message.
    z_dim, alpha, dropout_rate, learning_rate
        Forwarded to ``scgen.VAEArith`` (``z_dim`` as ``z_dimension``).
    n_epochs, batch_size
        Forwarded to ``network.train``.
    """
    train = sc.read("../data/train_study.h5ad")
    valid = sc.read("../data/valid_study.h5ad")

    network = scgen.VAEArith(x_dimension=train.X.shape[1],
                             z_dimension=z_dim,
                             alpha=alpha,
                             dropout_rate=dropout_rate,
                             learning_rate=learning_rate,
                             model_path="../models/scGen/study/scgen")
    try:
        network.train(train,
                      use_validation=True,
                      valid_data=valid,
                      n_epochs=n_epochs,
                      batch_size=batch_size)
        print(f"network_{data_name} has been trained!")
    finally:
        # Close the session even if training fails, so repeated calls do not
        # accumulate open sessions.
        network.sess.close()
コード例 #3
0
def train_batch_removal(data_name="study",
                        z_dim=100,
                        alpha=0.00005,
                        n_epochs=300,
                        batch_size=32,
                        dropout_rate=0.2,
                        learning_rate=0.001,
                        condition_key="condition"):
    """Train a VAEArith model on ``../data/kang_cross_train.h5ad``.

    The data is read relative to the caller's working directory. The model is
    then built and trained with ``./vae_results/{data_name}/all/`` (created
    if missing) as the working directory. The caller's working directory is
    restored afterwards, even if training raises.

    Parameters
    ----------
    data_name : str
        Names the results directory and appears in the completion message.
    z_dim, alpha, dropout_rate, learning_rate
        Forwarded to ``scgen.VAEArith`` (``z_dim`` as ``z_dimension``).
    n_epochs, batch_size
        Forwarded to ``network.train``.
    condition_key : str
        Unused; kept for signature compatibility.
    """
    train = sc.read("../data/kang_cross_train.h5ad")

    out_dir = f"./vae_results/{data_name}/all/"
    os.makedirs(out_dir, exist_ok=True)
    # Remember the absolute cwd instead of climbing back with "../../../":
    # the relative climb was skipped whenever training raised.
    previous_cwd = os.getcwd()
    os.chdir(out_dir)
    try:
        network = scgen.VAEArith(x_dimension=train.X.shape[1],
                                 z_dimension=z_dim,
                                 alpha=alpha,
                                 dropout_rate=dropout_rate,
                                 learning_rate=learning_rate)
        try:
            network.train(train, n_epochs=n_epochs, batch_size=batch_size)
            print(f"network_{data_name} has been trained!")
        finally:
            network.sess.close()
    finally:
        os.chdir(previous_cwd)
コード例 #4
0
def test_batch_removal():
    """Smoke-test ``scgen.batch_removal`` on the pancreas dataset.

    The model is trained for 0 epochs and nothing is asserted, so this only
    checks that the batch-removal pipeline runs end to end.
    """
    train = sc.read("./tests/data/pancreas.h5ad",
                    backup_url="https://goo.gl/V29FNk")
    # batch_removal is called without an explicit cell_label_key; copying
    # 'celltype' to 'cell_type' presumably matches its default — confirm.
    train.obs["cell_type"] = train.obs["celltype"].tolist()
    network = scgen.VAEArith(x_dimension=train.shape[1],
                             model_path="./models/batch")
    try:
        network.train(train_data=train, n_epochs=0)
        scgen.batch_removal(network, train)
    finally:
        # Close the session even when the test fails so it does not leak
        # into later tests.
        network.sess.close()
コード例 #5
0
def train(data_name="study",
          z_dim=100,
          alpha=0.00005,
          n_epochs=300,
          batch_size=32,
          dropout_rate=0.2,
          learning_rate=0.001,
          condition_key="condition"):
    """Train a VAEArith model on the whole training set of ``data_name``.

    The training file is read relative to the caller's working directory. The
    model is then built and trained with ``./vae_results/{data_name}/whole/``
    (created if missing) as the working directory. The caller's working
    directory is restored afterwards, even if training raises.

    Parameters
    ----------
    data_name : str
        One of ``"pbmc"``, ``"hpoly"``, ``"salmonella"``, ``"species"`` or
        ``"study"``; selects the training file.
    z_dim, alpha, dropout_rate, learning_rate
        Forwarded to ``scgen.VAEArith`` (``z_dim`` as ``z_dimension``).
    n_epochs, batch_size
        Forwarded to ``network.train``.
    condition_key : str
        Unused; kept for signature compatibility.

    Raises
    ------
    ValueError
        If ``data_name`` is not one of the supported datasets. Previously this
        surfaced as an ``UnboundLocalError`` after the results directory had
        already been created and entered.
    """
    # Training file for each supported dataset.
    train_files = {
        "pbmc": "../data/train.h5ad",
        "hpoly": "../data/ch10_train_7000.h5ad",
        "salmonella": "../data/chsal_train_7000.h5ad",
        "species": "../data/train_all_lps6.h5ad",
        "study": "../data/kang_cross_train.h5ad",
    }
    if data_name not in train_files:
        raise ValueError(
            f"Unknown data_name {data_name!r}; "
            f"expected one of {sorted(train_files)}")
    train_data = sc.read(train_files[data_name])

    out_dir = f"./vae_results/{data_name}/whole/"
    os.makedirs(out_dir, exist_ok=True)
    previous_cwd = os.getcwd()
    os.chdir(out_dir)
    try:
        network = scgen.VAEArith(x_dimension=train_data.X.shape[1],
                                 z_dimension=z_dim,
                                 alpha=alpha,
                                 dropout_rate=dropout_rate,
                                 learning_rate=learning_rate)
        try:
            network.train(train_data, n_epochs=n_epochs, batch_size=batch_size)
            print("network has been trained!")
        finally:
            network.sess.close()
    finally:
        # Always return to the caller's directory, even if training failed.
        os.chdir(previous_cwd)
コード例 #6
0
def test_train_whole_data_one_celltype_out(data_name="pbmc",
                                           z_dim=50,
                                           alpha=0.1,
                                           n_epochs=1000,
                                           batch_size=32,
                                           dropout_rate=0.25,
                                           learning_rate=0.001,
                                           condition_key="condition",
                                           cell_type_to_train=None):
    """Train one VAEArith model per cell type, holding out its stimulated cells.

    For every cell type (or only ``cell_type_to_train`` when given), the
    stimulated cells of that type are removed from both the train and the
    validation split. A model is then trained with validation, using
    ``../models/scGen/{data_name}/{cell_type}/scgen`` as its model path.

    Parameters
    ----------
    data_name : str
        One of ``"pbmc"``, ``"hpoly"``, ``"salmonella"`` or ``"species"``.
    z_dim, alpha, dropout_rate, learning_rate
        Forwarded to ``scgen.VAEArith`` (``z_dim`` as ``z_dimension``).
    n_epochs, batch_size
        Forwarded to ``network.train``.
    condition_key : str
        ``obs`` column holding the condition labels.
    cell_type_to_train : optional
        If given, only this cell type is trained. If it does not occur in the
        data, nothing is trained.

    Raises
    ------
    ValueError
        If ``data_name`` is not a supported dataset.
    """
    # data_name -> (stimulated label, cell-type obs column, train file,
    #               validation file)
    datasets = {
        "pbmc": ("stimulated", "cell_type",
                 "../data/train_pbmc.h5ad", "../data/valid_pbmc.h5ad"),
        "hpoly": ("Hpoly.Day10", "cell_label",
                  "../data/train_hpoly.h5ad", "../data/valid_hpoly.h5ad"),
        "salmonella": ("Salmonella", "cell_label",
                       "../data/train_salmonella.h5ad",
                       "../data/valid_salmonella.h5ad"),
        "species": ("LPS6", "species",
                    "../data/train_species.h5ad",
                    "../data/valid_species.h5ad"),
    }
    if data_name not in datasets:
        raise ValueError(
            f"Unknown data_name {data_name!r}; "
            f"expected one of {sorted(datasets)}")
    stim_key, cell_type_key, train_path, valid_path = datasets[data_name]
    train = sc.read(train_path)
    valid = sc.read(valid_path)

    for cell_type in train.obs[cell_type_key].unique().tolist():
        if cell_type_to_train is not None and cell_type != cell_type_to_train:
            continue
        # Drop this cell type's stimulated cells from both splits so the model
        # never sees them during training or validation.
        net_train_data = train[~((train.obs[cell_type_key] == cell_type) &
                                 (train.obs[condition_key] == stim_key))]
        net_valid_data = valid[~((valid.obs[cell_type_key] == cell_type) &
                                 (valid.obs[condition_key] == stim_key))]
        network = scgen.VAEArith(
            x_dimension=net_train_data.X.shape[1],
            z_dimension=z_dim,
            alpha=alpha,
            dropout_rate=dropout_rate,
            learning_rate=learning_rate,
            model_path=f"../models/scGen/{data_name}/{cell_type}/scgen")
        try:
            network.train(net_train_data,
                          use_validation=True,
                          valid_data=net_valid_data,
                          n_epochs=n_epochs,
                          batch_size=batch_size)
        finally:
            # Close the session even if training fails so it is not leaked.
            network.sess.close()
        print(f"network_{cell_type} has been trained!")
コード例 #7
0
def test_batch_removal():
    """Smoke-test ``scgen.batch_removal`` on the pancreas dataset.

    Trains for a single epoch and prints the ``obs`` table of the corrected
    data; nothing is asserted.
    """
    train = sc.read(
        "./data/pancreas.h5ad",
        backup_url=
        "https://www.dropbox.com/s/qj1jlm9w10wmt0u/pancreas.h5ad?dl=1")
    # batch_removal is called without an explicit cell_label_key; copying
    # 'celltype' to 'cell_type' presumably matches its default — confirm.
    train.obs["cell_type"] = train.obs["celltype"].tolist()
    network = scgen.VAEArith(x_dimension=train.shape[1],
                             model_path="./models/batch")
    try:
        network.train(train_data=train, n_epochs=1, verbose=1)
        corrected_adata = scgen.batch_removal(network, train)
        print(corrected_adata.obs)
    finally:
        # Close the session even when the test fails so it does not leak
        # into later tests.
        network.sess.close()
コード例 #8
0
ファイル: test_plotting.py プロジェクト: martinkla/scgen
def test_reg_mean_plot():
    """Smoke-test ``scgen.plotting.reg_mean_plot`` on predicted CD4T data.

    An untrained model (0 epochs) predicts the stimulated response of control
    CD4T cells. The predictions are concatenated with the real CD4T cells and
    plotted four times:
    - with and without a gene list;
    - with and without the extra ``"y1": "stimulated"`` axis.
    """
    train = sc.read("./tests/data/train.h5ad",
                    backup_url="https://goo.gl/33HtVh")
    network = scgen.VAEArith(x_dimension=train.shape[1],
                             model_path="../models/test")
    try:
        network.train(train_data=train, n_epochs=0)
        unperturbed_data = train[((train.obs["cell_type"] == "CD4T") &
                                  (train.obs["condition"] == "control"))]
        condition = {"ctrl": "control", "stim": "stimulated"}
        pred, _delta = network.predict(adata=train,
                                       adata_to_predict=unperturbed_data,
                                       conditions=condition)
        pred_adata = anndata.AnnData(pred,
                                     obs={"condition": ["pred"] * len(pred)},
                                     var={"var_names": train.var_names})
        cd4t = train[train.obs["cell_type"] == "CD4T"]
        all_adata = cd4t.concatenate(pred_adata)

        scgen.plotting.reg_mean_plot(all_adata,
                                     condition_key="condition",
                                     axis_keys={
                                         "x": "control",
                                         "y": "pred"
                                     },
                                     path_to_save="tests/reg_mean1.pdf")
        scgen.plotting.reg_mean_plot(all_adata,
                                     condition_key="condition",
                                     axis_keys={
                                         "x": "control",
                                         "y": "pred"
                                     },
                                     path_to_save="tests/reg_mean2.pdf",
                                     gene_list=["ISG15", "CD3D"])
        scgen.plotting.reg_mean_plot(all_adata,
                                     condition_key="condition",
                                     axis_keys={
                                         "x": "control",
                                         "y": "pred",
                                         "y1": "stimulated"
                                     },
                                     path_to_save="tests/reg_mean3.pdf")
        scgen.plotting.reg_mean_plot(
            all_adata,
            condition_key="condition",
            axis_keys={
                "x": "control",
                "y": "pred",
                "y1": "stimulated"
            },
            gene_list=["ISG15", "CD3D"],
            path_to_save="tests/reg_mean.pdf",
        )
    finally:
        # Close the session even if a plotting call fails.
        network.sess.close()
コード例 #9
0
def create_model(x_train):
    """Train one VAEArith candidate and report its best training loss.

    NOTE(review): the ``{{choice([...])}}`` placeholders would fail if run as
    plain Python: a set literal nested in a set raises ``TypeError``. This
    looks like a hyperas search-space template, where each placeholder is
    replaced by a sampled value before execution. Confirm against the
    optimisation driver.

    Args:
        x_train: training data; ``x_train.X.shape[1]`` is the model's input
            dimension, and the object is passed straight to ``network.train``.

    Returns:
        dict with ``'loss'`` (the minimum of ``result.history['loss']``),
        ``'status'`` (``STATUS_OK``) and ``'model'`` (``network.vae_model``).
    """
    # Architecture / optimiser hyperparameters sampled from the search space.
    network = scgen.VAEArith(
        x_dimension=x_train.X.shape[1],
        z_dimension={{choice([10, 20, 50, 75, 100])}},
        learning_rate={{choice([0.1, 0.01, 0.001, 0.0001])}},
        alpha={{choice([0.1, 0.01, 0.001, 0.0001, 0.00001, 0.000001])}},
        dropout_rate={{choice([0.2, 0.25, 0.5, 0.75, 0.8])}},
        model_path=f"./")

    # save=False: presumably keeps search candidates from being written to
    # model_path — confirm against VAEArith.train.
    result = network.train(x_train,
                           n_epochs={{choice([100, 150, 200, 250])}},
                           batch_size={{choice([32, 64, 128, 256])}},
                           verbose=2,
                           shuffle=True,
                           save=False)
    # Score the candidate by the lowest loss seen during training.
    best_loss = np.amin(result.history['loss'])
    print('Best Loss of model:', best_loss)
    return {'loss': best_loss, 'status': STATUS_OK, 'model': network.vae_model}
コード例 #10
0
def test_train_whole_data_some_celltypes_out(data_name="pbmc",
                                             z_dim=100,
                                             alpha=0.00005,
                                             n_epochs=300,
                                             batch_size=32,
                                             dropout_rate=0.2,
                                             learning_rate=0.001,
                                             condition_key="condition",
                                             c_out=None,
                                             c_in=None):
    """Train a VAEArith model on PBMC data with some cell types held out.

    ``scgen.data_remover`` is applied to both the train and validation split
    with ``remain_list=c_in`` and ``remove_list=c_out``. The model path is
    ``../models/scGen/pbmc/heldout/{len(c_out)}/scgen``.

    Parameters
    ----------
    data_name : str
        Only ``"pbmc"`` is supported.
    z_dim, alpha, dropout_rate, learning_rate
        Forwarded to ``scgen.VAEArith`` (``z_dim`` as ``z_dimension``).
    n_epochs, batch_size
        Forwarded to ``network.train``.
    condition_key : str
        Forwarded to ``scgen.data_remover``.
    c_out
        Passed as ``remove_list``. It is required because its length names
        the model directory.
    c_in
        Passed as ``remain_list``.

    Raises
    ------
    ValueError
        If ``data_name`` is not ``"pbmc"`` or ``c_out`` is None. Previously
        these surfaced later as ``UnboundLocalError`` / ``TypeError``.
    """
    if data_name != "pbmc":
        raise ValueError(
            f"Unsupported data_name {data_name!r}; only 'pbmc' is available")
    if c_out is None:
        raise ValueError("c_out (the cell types to hold out) must be given")

    cell_type_key = "cell_type"
    train = sc.read("../data/train_pbmc.h5ad")
    valid = sc.read("../data/valid_pbmc.h5ad")

    net_train_data = scgen.data_remover(train,
                                        remain_list=c_in,
                                        remove_list=c_out,
                                        cell_type_key=cell_type_key,
                                        condition_key=condition_key)

    net_valid_data = scgen.data_remover(valid,
                                        remain_list=c_in,
                                        remove_list=c_out,
                                        cell_type_key=cell_type_key,
                                        condition_key=condition_key)

    network = scgen.VAEArith(
        x_dimension=net_train_data.X.shape[1],
        z_dimension=z_dim,
        alpha=alpha,
        dropout_rate=dropout_rate,
        learning_rate=learning_rate,
        model_path=f"../models/scGen/pbmc/heldout/{len(c_out)}/scgen")
    try:
        network.train(net_train_data,
                      use_validation=True,
                      valid_data=net_valid_data,
                      n_epochs=n_epochs,
                      batch_size=batch_size)
        print("network has been trained!")
    finally:
        # Close the session even if training fails.
        network.sess.close()
コード例 #11
0
def test_train_whole_data_some_celltypes_out(data_name="study",
                                             z_dim=100,
                                             alpha=0.00005,
                                             n_epochs=300,
                                             batch_size=32,
                                             dropout_rate=0.2,
                                             learning_rate=0.001,
                                             condition_key="condition",
                                             c_out=None,
                                             c_in=None):
    """Train a VAEArith model on PBMC data with some cell types held out.

    ``scgen.data_remover`` removes ``c_out`` / keeps ``c_in``. The model is
    trained with ``./vae_results/{data_name}/heldout/{len(c_out)}/`` (created
    if missing) as the working directory. The caller's working directory is
    restored afterwards, even if training raises.

    NOTE(review): only ``data_name="pbmc"`` is supported, so the default
    ``"study"`` always raises. The default is kept for signature
    compatibility.

    Parameters
    ----------
    data_name : str
        Only ``"pbmc"`` is supported.
    z_dim, alpha, dropout_rate, learning_rate
        Forwarded to ``scgen.VAEArith`` (``z_dim`` as ``z_dimension``).
    n_epochs, batch_size
        Forwarded to ``network.train``.
    condition_key : str
        Forwarded to ``scgen.data_remover``.
    c_out
        Passed as ``remove_list``. It is required because its length names
        the results directory.
    c_in
        Passed as ``remain_list``.

    Raises
    ------
    ValueError
        If ``data_name`` is not ``"pbmc"`` or ``c_out`` is None.
    """
    if data_name != "pbmc":
        raise ValueError(
            f"Unsupported data_name {data_name!r}; only 'pbmc' is available")
    if c_out is None:
        raise ValueError("c_out (the cell types to hold out) must be given")

    cell_type_key = "cell_type"
    train = sc.read("../data/train.h5ad")

    out_dir = f"./vae_results/{data_name}/heldout/{len(c_out)}/"
    os.makedirs(out_dir, exist_ok=True)
    previous_cwd = os.getcwd()
    os.chdir(out_dir)
    try:
        net_train_data = scgen.data_remover(train,
                                            remain_list=c_in,
                                            remove_list=c_out,
                                            cell_type_key=cell_type_key,
                                            condition_key=condition_key)

        print(net_train_data)

        network = scgen.VAEArith(x_dimension=net_train_data.X.shape[1],
                                 z_dimension=z_dim,
                                 alpha=alpha,
                                 dropout_rate=dropout_rate,
                                 learning_rate=learning_rate)
        try:
            network.train(net_train_data,
                          n_epochs=n_epochs,
                          batch_size=batch_size)
            print("network has been trained!")
        finally:
            network.sess.close()
    finally:
        # Replaces the relative "../../../../" climb, which was skipped
        # whenever an exception occurred.
        os.chdir(previous_cwd)
コード例 #12
0
ファイル: test_plotting.py プロジェクト: martinkla/scgen
def test_binary_classifier():
    """Smoke-test ``scgen.plotting.binary_classifier`` with an untrained model.

    The model is trained for 0 epochs. Only the ``delta`` returned by
    ``network.predict`` for control CD4T cells is passed on to the plot.
    """
    train = sc.read("./tests/data/train.h5ad",
                    backup_url="https://goo.gl/33HtVh")
    network = scgen.VAEArith(x_dimension=train.shape[1],
                             model_path="../models/test")
    try:
        network.train(train_data=train, n_epochs=0)
        unperturbed_data = train[((train.obs["cell_type"] == "CD4T") &
                                  (train.obs["condition"] == "control"))]
        condition = {"ctrl": "control", "stim": "stimulated"}
        _pred, delta = network.predict(adata=train,
                                       adata_to_predict=unperturbed_data,
                                       conditions=condition)
        scgen.plotting.binary_classifier(
            network,
            train,
            delta,
            condition_key="condition",
            conditions={
                "ctrl": "control",
                "stim": "stimulated"
            },
            path_to_save="tests/binary_classifier.pdf")
    finally:
        # Close the session even if prediction or plotting fails.
        network.sess.close()
コード例 #13
0
# Normalize every cell to 1e4 total counts, then log1p-transform.
sc.pp.normalize_per_cell(all_train, counts_per_cell_after=1e4)
sc.pp.log1p(all_train)
# Keep genes flagged as highly variable in either tmp_nuclei or tmp_cells.
# NOTE(review): the boolean masks come from tmp_nuclei / tmp_cells, so this
# assumes their var axis has the same genes in the same order as all_train —
# confirm.
all_train = all_train[:, ((tmp_nuclei.var['highly_variable']
                           | tmp_cells.var['highly_variable']))].copy()

print("DATA PROCESSING COMPLETE")

# Round 1: Correct for source
## Train scGen on the training dataset and apply batch removal to both the
## training set and the complete set of nuclei/cells.
# NOTE(review): `all` is an object defined before this chunk; the name
# shadows the builtin all().
# Copy annotations into the column names used below: 'cell_type' from
# 'leiden_annotated', 'batch' from 'cell_source'.
all.obs['cell_type'] = all.obs['leiden_annotated'].values
all.obs['batch'] = all.obs['cell_source'].values
all_train.obs['cell_type'] = all_train.obs['leiden_annotated'].values
all_train.obs['batch'] = all_train.obs['cell_source'].values

network = scgen.VAEArith(x_dimension=all_train.shape[1],
                         model_path=path_out +
                         'scgen_model/scgen_model_V_source')
print("Network made")
network.train(train_data=all_train, n_epochs=20)
print("Network trained")

# batch_removal_ct5 is defined outside this chunk; presumably a variant of
# scgen.batch_removal with the same keyword arguments — confirm.
corrected_adata = batch_removal_ct5(network,
                                    all,
                                    batch_key="batch",
                                    cell_label_key="cell_type")
corrected_adata_train = batch_removal_ct5(network,
                                          all_train,
                                          batch_key="batch",
                                          cell_label_key="cell_type")

# NOTE(review): corrected_adata_train is not written and network.sess is not
# closed in this chunk — check the code that follows.
corrected_adata.write(path_out + "ALL_corrected_FB_ATV_source.h5ad")
コード例 #14
0
def _labelled_adata(X, var_names, condition_key, condition_label,
                    cell_type_key, cell_type):
    """Wrap ``X`` in an AnnData whose rows all share one condition and cell-type label."""
    n_obs = X.shape[0]
    return anndata.AnnData(
        X,
        obs={
            condition_key: [condition_label] * n_obs,
            cell_type_key: [cell_type] * n_obs
        },
        var={"var_names": var_names})


def reconstruct_whole_data(data_name="pbmc", condition_key="condition"):
    """Collect control, predicted-stimulated and real-stimulated cells per cell type.

    For each cell type in the training file of ``data_name``, a VAEArith model
    is restored from inside ``./vae_results/{data_name}/{cell_type}``. It then
    predicts that cell type's stimulated response. Each cell type contributes
    three groups, labelled in ``obs[condition_key]``:
    - ``{cell_type}_ctrl``: its control cells;
    - ``{cell_type}_pred_stim``: the predictions;
    - ``{cell_type}_real_stim``: its real stimulated cells (densified if
      sparse).

    All groups are concatenated and written to
    ``./vae_results/{data_name}/reconstructed.h5ad``.

    Parameters
    ----------
    data_name : str
        One of ``"pbmc"``, ``"hpoly"``, ``"salmonella"``, ``"species"`` or
        ``"study"``.
    condition_key : str
        ``obs`` column holding the condition labels.

    Raises
    ------
    ValueError
        If ``data_name`` is not a supported dataset.
    """
    # data_name -> (stimulated label, control label, cell-type obs column,
    #               training file)
    datasets = {
        "pbmc": ("stimulated", "control", "cell_type", "../data/train.h5ad"),
        "hpoly": ("Hpoly.Day10", "Control", "cell_label",
                  "../data/ch10_train_7000.h5ad"),
        "salmonella": ("Salmonella", "Control", "cell_label",
                       "../data/chsal_train_7000.h5ad"),
        "species": ("LPS6", "unst", "species",
                    "../data/train_all_lps6.h5ad"),
        "study": ("stimulated", "control", "cell_type",
                  "../data/kang_cross_train.h5ad"),
    }
    if data_name not in datasets:
        raise ValueError(
            f"Unknown data_name {data_name!r}; "
            f"expected one of {sorted(datasets)}")
    stim_key, ctrl_key, cell_type_key, train_path = datasets[data_name]
    train = sc.read(train_path)

    base_dir = os.getcwd()
    all_data = anndata.AnnData()
    for idx, cell_type in enumerate(
            train.obs[cell_type_key].unique().tolist()):
        print(f"Reconstructing for {cell_type}")
        cell_type_data = train[train.obs[cell_type_key] == cell_type]
        cell_type_ctrl_data = train[((train.obs[cell_type_key] == cell_type) &
                                     (train.obs[condition_key] == ctrl_key))]

        # Restore this cell type's model from inside its results folder.
        # Always return to base_dir and close the session, even if restoring
        # or predicting fails; previously both were leaked on error.
        os.chdir(f"./vae_results/{data_name}/{cell_type}")
        try:
            network = scgen.VAEArith(x_dimension=train.X.shape[1],
                                     z_dimension=100,
                                     alpha=0.00005,
                                     dropout_rate=0.2,
                                     learning_rate=0.001)
            try:
                network.restore_model()
                pred, _delta = network.predict(adata=cell_type_data,
                                               conditions={
                                                   "ctrl": ctrl_key,
                                                   "stim": stim_key
                                               },
                                               cell_type_key=cell_type_key,
                                               condition_key=condition_key,
                                               celltype_to_predict=cell_type)
            finally:
                network.sess.close()
        finally:
            os.chdir(base_dir)

        pred_adata = _labelled_adata(pred, cell_type_data.var_names,
                                     condition_key, f"{cell_type}_pred_stim",
                                     cell_type_key, cell_type)
        ctrl_adata = _labelled_adata(cell_type_ctrl_data.X,
                                     cell_type_ctrl_data.var_names,
                                     condition_key, f"{cell_type}_ctrl",
                                     cell_type_key, cell_type)
        real_stim = cell_type_data[cell_type_data.obs[condition_key] ==
                                   stim_key].X
        if sparse.issparse(cell_type_data.X):
            # .toarray() replaces the deprecated sparse-matrix .A attribute.
            real_stim = real_stim.toarray()
        real_stim_adata = _labelled_adata(real_stim, cell_type_data.var_names,
                                          condition_key,
                                          f"{cell_type}_real_stim",
                                          cell_type_key, cell_type)
        if idx == 0:
            all_data = ctrl_adata.concatenate(pred_adata, real_stim_adata)
        else:
            all_data = all_data.concatenate(ctrl_adata, pred_adata,
                                            real_stim_adata)

        print(f"Finish Reconstructing for {cell_type}")
    all_data.write_h5ad(f"./vae_results/{data_name}/reconstructed.h5ad")