예제 #1
0
def _latent_to_adata(latent, data, condition_key, cell_type_key):
    """Wrap a latent matrix in an AnnData annotated like ``data``.

    The condition and cell-type columns of ``data.obs`` are copied by
    position (``.values``), so ``latent`` must have one row per observation
    of ``data``, in the same order.
    """
    latent_adata = sc.AnnData(X=latent)
    latent_adata.obs[condition_key] = data.obs[condition_key].values
    latent_adata.obs[cell_type_key] = data.obs[cell_type_key].values
    return latent_adata


def _save_umap(adata, color, save):
    """Build the neighbour graph and UMAP of ``adata`` and save the plot.

    ``save`` is the file-name suffix handed to ``sc.pl.umap``; the figure is
    written below ``sc.settings.figdir``.
    """
    sc.pp.neighbors(adata)
    sc.tl.umap(adata)
    sc.pl.umap(adata,
               color=color,
               save=save,
               show=False,
               wspace=0.15,
               frameon=False)


def visualize_trained_network_results(data_dict,
                                      z_dim=100,
                                      mmd_dimension=128,
                                      loss_fn='mse'):
    """Restore one trained trVAEMulti model per cell type and plot results.

    For every cell type the model stored under
    ``../models/RCVAEMulti/{name}/{cell_type}/{z_dim}/`` is restored, the
    full dataset is projected to the z-latent space and to the MMD layer
    (with true labels and once per fake condition label), the configured
    perturbations are predicted, and UMAP / violin figures are saved under
    ``../results/RCVAEMulti/{name}/{cell_type}/{z_dim}/Visualizations/``.

    Args:
        data_dict: Dataset description. Keys read here: ``'name'``
            (required), ``'source_conditions'``, ``'target_conditions'``,
            ``'cell_type'``, ``'need_merge'``, ``'label_encoder'``,
            ``'condition'`` (defaults to ``'condition'``),
            ``'spec_cell_types'`` and ``'perturbation'`` (an iterable of
            ``(source, dest, name, source_label, target_label)`` tuples).
        z_dim: Latent size; forwarded to the network and used in paths.
        mmd_dimension: Forwarded to ``trvae.trVAEMulti``.
        loss_fn: Forwarded to ``trvae.trVAEMulti``. Any value other than
            ``'mse'`` also runs ``normalize_hvg`` on data read from disk.

    Returns:
        None. All results are written to disk.
    """
    import matplotlib as mpl

    plt.close("all")
    data_name = data_dict['name']
    source_keys = data_dict.get("source_conditions")
    target_keys = data_dict.get("target_conditions")
    cell_type_key = data_dict.get("cell_type", None)
    need_merge = data_dict.get('need_merge', False)
    label_encoder = data_dict.get('label_encoder', None)
    condition_key = data_dict.get('condition', 'condition')

    if need_merge:
        data, _ = merge_data(data_dict)
    else:
        data = sc.read(f"../data/{data_name}/{data_name}.h5ad")
        if loss_fn != 'mse':
            data = normalize_hvg(data,
                                 filter_min_counts=False,
                                 normalize_input=False,
                                 logtrans_input=True)

    cell_types = data.obs[cell_type_key].unique().tolist()

    # An explicit list of cell types overrides the ones found in the data.
    spec_cell_type = data_dict.get("spec_cell_types", None)
    if spec_cell_type:
        cell_types = spec_cell_type

    for cell_type in cell_types:
        path_to_save = f"../results/RCVAEMulti/{data_name}/{cell_type}/{z_dim}/Visualizations/"
        os.makedirs(path_to_save, exist_ok=True)
        sc.settings.figdir = os.path.abspath(path_to_save)

        is_cell_type = data.obs[cell_type_key] == cell_type
        is_target = data.obs[condition_key].isin(target_keys)

        # Everything except this cell type under the target conditions.
        # Subset first, then copy, so only the selected rows are duplicated.
        train_data = data[~(is_target & is_cell_type)].copy()
        cell_type_adata = data[is_cell_type].copy()

        n_conditions = len(train_data.obs[condition_key].unique())

        network = trvae.trVAEMulti(
            x_dimension=data.shape[1],
            z_dimension=z_dim,
            loss_fn=loss_fn,
            n_conditions=n_conditions,
            mmd_dimension=mmd_dimension,
            model_path=f"../models/RCVAEMulti/{data_name}/{cell_type}/{z_dim}/",
        )

        network.restore_model()

        if sparse.issparse(data.X):
            # `.toarray()` is the supported spelling of the deprecated `.A`.
            data.X = data.X.toarray()

        train_labels, _ = trvae.label_encoder(data, label_encoder,
                                              condition_key)

        # One constant label vector per condition index.
        fake_labels = [
            np.zeros(train_labels.shape) + i for i in range(n_conditions)
        ]

        latent_with_true_labels = network.to_z_latent(data.X, train_labels)
        latent_with_fake_labels = [
            network.to_z_latent(data.X, labels) for labels in fake_labels
        ]
        # NOTE(review): to_mmd_layer receives the AnnData itself while
        # to_z_latent receives the matrix -- kept exactly as in the original
        # call; confirm against the trVAEMulti signature in use.
        mmd_latent_with_true_labels = network.to_mmd_layer(data,
                                                           train_labels,
                                                           feed_fake=0)
        mmd_latent_with_fake_labels = [
            network.to_mmd_layer(data, train_labels, feed_fake=i)
            for i in range(n_conditions)
        ]

        if data_name in ['pancreas', 'nmuil_count']:
            # No differential-expression ranking for these datasets.
            gene_list = None
            top_100_genes = None
        else:
            sc.tl.rank_genes_groups(cell_type_adata,
                                    groupby=condition_key,
                                    n_genes=100,
                                    method="wilcoxon")
            ranked_names = cell_type_adata.uns["rank_genes_groups"]["names"]
            up_genes = ranked_names[target_keys[-1]].tolist()
            if data_name in ["pbmc", 'endo_norm']:
                # Only genes ranked for the last target condition.
                top_100_genes = up_genes
                gene_list = up_genes[:10]
            else:
                # Genes ranked for the last target condition plus genes
                # ranked for the first source condition.
                down_genes = ranked_names[source_keys[0]].tolist()
                top_100_genes = up_genes + down_genes
                gene_list = down_genes[:5] + up_genes[:5]

        perturbation_list = data_dict.get("perturbation", [])
        pred_adatas = None
        for source, dest, name, source_label, target_label in perturbation_list:
            print(source, dest, name)
            pred_adata = visualize_multi_perturbation_between(
                network,
                cell_type_adata,
                pred_adatas,
                source_condition=source,
                target_condition=dest,
                name=name,
                source_label=source_label,
                target_label=target_label,
                cell_type=cell_type,
                data_name=data_name,
                top_100_genes=top_100_genes,
                gene_list=gene_list,
                path_to_save=path_to_save,
                condition_key=condition_key)
            if pred_adatas is None:
                pred_adatas = pred_adata
            else:
                pred_adatas = pred_adatas.concatenate(pred_adata)

        # With no configured perturbation there is nothing to store; the
        # original crashed here with AttributeError on None.
        # NOTE(review): the path has no cell-type component, so every
        # iteration of this loop overwrites the same file.
        if pred_adatas is not None:
            pred_adatas.write_h5ad(
                filename=f"../data/reconstructed/RCVAEMulti/{data_name}.h5ad")

        # Reset matplotlib settings before the plots below.
        mpl.rcParams.update(mpl.rcParamsDefault)

        color = [condition_key, cell_type_key]

        latent_with_true_labels = _latent_to_adata(latent_with_true_labels,
                                                   data, condition_key,
                                                   cell_type_key)

        for i, latent in enumerate(latent_with_fake_labels):
            _save_umap(
                _latent_to_adata(latent, data, condition_key, cell_type_key),
                color,
                f"_{data_name}_{cell_type}_latent_with_fake_labels_{i}")

        mmd_latent_with_true_labels = _latent_to_adata(
            mmd_latent_with_true_labels, data, condition_key, cell_type_key)

        for i, mmd_latent in enumerate(mmd_latent_with_fake_labels):
            # NOTE(review): this suffix mentions neither "mmd" nor the cell
            # type, unlike the z-latent figures above; kept unchanged so the
            # output file names stay the same.
            _save_umap(
                _latent_to_adata(mmd_latent, data, condition_key,
                                 cell_type_key),
                color,
                f"_{data_name}_latent_with_fake_labels_{i}")

        _save_umap(train_data, color,
                   f'_{data_name}_{cell_type}_train_data')
        _save_umap(latent_with_true_labels, color,
                   f"_{data_name}_{cell_type}_latent_with_true_labels")
        _save_umap(mmd_latent_with_true_labels, color,
                   f"_{data_name}_{cell_type}_mmd_latent_with_true_labels")

        # Violin plots need both ranked genes and predicted cells.
        if gene_list is not None and pred_adatas is not None:
            for target_condition in target_keys:
                pred_adata = pred_adatas[pred_adatas.obs[condition_key].str.
                                         endswith(target_condition)]
                violin_adata = cell_type_adata.concatenate(pred_adata)
                for gene in gene_list[:3]:
                    sc.pl.violin(
                        violin_adata,
                        keys=gene,
                        groupby=condition_key,
                        save=
                        f"_{data_name}_{cell_type}_{gene}_{target_condition}.pdf",
                        show=False,
                        wspace=0.2,
                        rotation=90,
                        frameon=False)

        plt.close("all")
예제 #2
0
def visualize_batch_correction(data_dict, z_dim=100, mmd_dimension=128):
    """Plot and export latent representations of restored trVAEMulti models.

    For every cell type the model stored under
    ``../models/RCVAEMulti/{name}/{cell_type}/{z_dim}/`` is restored, the
    full dataset is projected to the MMD layer and to the z-latent space
    using its true condition labels, and UMAPs of the training subset and of
    both projections are saved under
    ``../results/RCVAEMulti/{name}/{cell_type}/{z_dim}/Visualizations/``.
    Both projections are also written to ``../data/mmd.h5ad`` and
    ``../data/latent.h5ad``.

    Args:
        data_dict: Dataset description. Keys read here: ``'name'``
            (required), ``'source_conditions'`` (its length is used as the
            number of conditions), ``'target_conditions'``, ``'cell_type'``,
            ``'need_merge'``, ``'label_encoder'``, ``'condition'`` (defaults
            to ``'condition'``) and ``'spec_cell_types'``.
        z_dim: Latent size; forwarded to the network and used in paths.
        mmd_dimension: Forwarded to ``trvae.trVAEMulti``.

    Returns:
        None. All results are written to disk.
    """
    import matplotlib as mpl

    plt.close("all")
    data_name = data_dict['name']
    source_keys = data_dict.get("source_conditions")
    target_keys = data_dict.get("target_conditions")
    cell_type_key = data_dict.get("cell_type", None)
    need_merge = data_dict.get('need_merge', False)
    label_encoder = data_dict.get('label_encoder', None)
    condition_key = data_dict.get('condition', 'condition')

    if need_merge:
        data, _ = merge_data(data_dict)
    else:
        data = sc.read(f"../data/{data_name}/train_{data_name}.h5ad")

    cell_types = data.obs[cell_type_key].unique().tolist()

    # An explicit list of cell types overrides the ones found in the data.
    spec_cell_type = data_dict.get("spec_cell_types", None)
    if spec_cell_type:
        cell_types = spec_cell_type

    for cell_type in cell_types:
        path_to_save = f"../results/RCVAEMulti/{data_name}/{cell_type}/{z_dim}/Visualizations/"
        os.makedirs(path_to_save, exist_ok=True)
        sc.settings.figdir = os.path.abspath(path_to_save)

        # Everything except this cell type under the target conditions.
        # Subset first, then copy, so only the selected rows are duplicated.
        held_out = ((data.obs[condition_key].isin(target_keys)) &
                    (data.obs[cell_type_key] == cell_type))
        train_data = data[~held_out].copy()

        network = trvae.trVAEMulti(
            x_dimension=data.shape[1],
            z_dimension=z_dim,
            n_conditions=len(source_keys),
            mmd_dimension=mmd_dimension,
            model_path=f"../models/RCVAEMulti/{data_name}/{cell_type}/{z_dim}/",
        )

        network.restore_model()

        if sparse.issparse(data.X):
            # `.toarray()` is the supported spelling of the deprecated `.A`.
            data.X = data.X.toarray()

        feed_data = data.X

        train_labels, _ = trvae.label_encoder(data, label_encoder,
                                              condition_key)

        # NOTE(review): the network is also passed as the first positional
        # argument; kept exactly as in the original call -- confirm against
        # the trVAEMulti.to_mmd_layer signature in use.
        mmd_latent_with_true_labels = network.to_mmd_layer(network,
                                                           feed_data,
                                                           train_labels,
                                                           feed_fake=0)

        latent_with_true_labels = network.to_z_latent(feed_data, train_labels)

        # Reset matplotlib settings before the plots below.
        mpl.rcParams.update(mpl.rcParamsDefault)

        color = [condition_key]

        # Wrap both projections in AnnData and carry over the annotations
        # (copied by position, one row per observation of `data`).
        mmd_latent_with_true_labels = sc.AnnData(X=mmd_latent_with_true_labels)
        mmd_latent_with_true_labels.obs[condition_key] = data.obs[
            condition_key].values
        mmd_latent_with_true_labels.obs[cell_type_key] = data.obs[
            cell_type_key].values

        latent_with_true_labels = sc.AnnData(X=latent_with_true_labels)
        latent_with_true_labels.obs[condition_key] = data.obs[
            condition_key].values
        latent_with_true_labels.obs[cell_type_key] = data.obs[
            cell_type_key].values

        # Same neighbours -> UMAP -> plot sequence for all three objects.
        figures = (
            (train_data, "train_data"),
            (mmd_latent_with_true_labels, "mmd_latent_with_true_labels"),
            (latent_with_true_labels, "latent_with_true_labels"),
        )
        for adata, tag in figures:
            sc.pp.neighbors(adata)
            sc.tl.umap(adata)
            sc.pl.umap(adata,
                       color=color,
                       save=f"_{data_name}_{cell_type}_{tag}",
                       show=False,
                       wspace=0.15,
                       frameon=False)

        # NOTE(review): these paths are fixed, so every iteration of this
        # loop overwrites the files written by the previous cell type.
        mmd_latent_with_true_labels.write_h5ad('../data/mmd.h5ad')
        latent_with_true_labels.write_h5ad('../data/latent.h5ad')
예제 #3
0
def train_network(
    data_dict=None,
    z_dim=100,
    mmd_dimension=256,
    alpha=0.001,
    beta=100,
    kernel='multi-scale-rbf',
    n_epochs=500,
    batch_size=512,
    early_stop_limit=50,
    dropout_rate=0.2,
    learning_rate=0.001,
    loss_fn='mse',
    verbose=2,
):
    """Train and save one trVAEMulti model per held-out cell type.

    The dataset is either produced by ``merge_data`` or read from
    ``../data/{name}/{name}.h5ad`` and split 80/20 into training and
    validation parts. For each cell type, the cells of that type under the
    target conditions are removed from both parts and a network is trained
    on the remainder, with its weights stored under
    ``../models/RCVAEMulti/{name}/{cell_type}/{z_dim}/``.

    Nothing is trained when ``data_dict`` has no ``'cell_type'`` entry.

    Args:
        data_dict: Dataset description. Keys read here: ``'name'``
            (required), ``'target_conditions'``, ``'cell_type'``,
            ``'need_merge'``, ``'label_encoder'``, ``'condition'`` (defaults
            to ``'condition'``) and ``'spec_cell_types'``.
        z_dim, mmd_dimension, alpha, beta, kernel, dropout_rate,
        learning_rate: Forwarded to ``trvae.trVAEMulti``.
        loss_fn: Forwarded to ``trvae.trVAEMulti``. Any value other than
            ``'mse'`` also runs ``normalize_hvg`` on data read from disk.
        n_epochs, batch_size, early_stop_limit, verbose: Forwarded to
            ``network.train``.

    Returns:
        None.
    """
    data_name = data_dict['name']
    target_keys = data_dict.get("target_conditions")
    cell_type_key = data_dict.get("cell_type", None)
    need_merge = data_dict.get('need_merge', False)
    label_encoder = data_dict.get('label_encoder', None)
    condition_key = data_dict.get('condition', 'condition')

    if not need_merge:
        adata = sc.read(f"../data/{data_name}/{data_name}.h5ad")
        if loss_fn != 'mse':
            adata = normalize_hvg(adata,
                                  filter_min_counts=False,
                                  normalize_input=False,
                                  logtrans_input=True)
        train_data, valid_data = train_test_split(adata, 0.80)
    else:
        train_data, valid_data = merge_data(data_dict)

    requested_cell_types = data_dict.get("spec_cell_types", None)

    # Without a cell-type column there is no hold-out scheme to train for.
    if cell_type_key is None:
        return

    observed_cell_types = train_data.obs[cell_type_key].unique().tolist()
    cell_types = (requested_cell_types
                  if requested_cell_types else observed_cell_types)

    # Leaky ReLU is enabled for the pancreas dataset only.
    use_leaky_relu = data_name == 'pancreas'

    def without_held_out(adata, held_out_cell_type):
        """Drop cells of the given type that are under a target condition."""
        is_cell_type = adata.obs[cell_type_key] == held_out_cell_type
        is_target = adata.obs[condition_key].isin(target_keys)
        return adata.copy()[~(is_cell_type & is_target)]

    for cell_type in cell_types:
        net_train_data = without_held_out(train_data, cell_type)
        net_valid_data = without_held_out(valid_data, cell_type)

        # Number of distinct conditions left in the training subset.
        remaining_conditions = net_train_data.obs[condition_key].unique()
        n_conditions = len(remaining_conditions)

        model_path = f"../models/RCVAEMulti/{data_name}/{cell_type}/{z_dim}/"
        network = trvae.trVAEMulti(
            x_dimension=net_train_data.shape[1],
            z_dimension=z_dim,
            n_conditions=n_conditions,
            mmd_dimension=mmd_dimension,
            alpha=alpha,
            beta=beta,
            kernel=kernel,
            learning_rate=learning_rate,
            loss_fn=loss_fn,
            model_path=model_path,
            dropout_rate=dropout_rate,
            use_leaky_relu=use_leaky_relu)

        network.train(net_train_data,
                      label_encoder,
                      condition_key,
                      use_validation=True,
                      valid_adata=net_valid_data,
                      n_epochs=n_epochs,
                      batch_size=batch_size,
                      verbose=verbose,
                      early_stop_limit=early_stop_limit,
                      shuffle=True,
                      save=True)

        print(f"Model for {cell_type} has been trained")