Example 1
import datetime

import anndata
import mnnpy
import numpy as np

def run_MNN(datasets, task, task_adata, method_name, log_dir, args):
    method_key = '{}_aligned'.format(method_name)
    A_idx = task_adata.obs[task.batch_key] == task.source_batch
    B_idx = task_adata.obs[task.batch_key] == task.target_batch
    if args.input_space == 'PCA':
        A_X = task_adata[A_idx].obsm['PCA']
        B_X = task_adata[B_idx].obsm['PCA']
    else:
        A_X = task_adata[A_idx].X
        B_X = task_adata[B_idx].X
    # optional standardization (kept commented out in the original):
    # scaler = StandardScaler().fit(np.concatenate((A_X, B_X)))
    # A_X = scaler.transform(A_X)
    # B_X = scaler.transform(B_X)
    mnn_adata_A = anndata.AnnData(X=A_X, obs=task_adata[A_idx].obs)
    mnn_adata_B = anndata.AnnData(X=B_X, obs=task_adata[B_idx].obs)
    t0 = datetime.datetime.now()
    corrected = mnnpy.mnn_correct(mnn_adata_A, mnn_adata_B)
    t1 = datetime.datetime.now()
    time_str = pretty_tdelta(t1 - t0)  # project helper: formats the timedelta
    print(f'took: {time_str}')
    with open(log_dir / 'fit_time.txt', 'w') as f:
        f.write(time_str + '\n')
    task_adata.obsm[method_key] = np.zeros(corrected[0].shape)
    task_adata.obsm[method_key][np.where(A_idx)[0]] = corrected[0].X[:mnn_adata_A.shape[0]]
    task_adata.obsm[method_key][np.where(B_idx)[0]] = corrected[0].X[mnn_adata_A.shape[0]:]
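The corrected values land in `task_adata.obsm` under a per-method key, aligned to the original row order. A minimal sketch of reading them back, assuming `method_name` was `'MNN'`:

aligned = task_adata.obsm['MNN_aligned']
assert aligned.shape[0] == task_adata.n_obs  # one corrected row per cell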
Example 2
def runMNN(adata, batch, hvg=None):
    import mnnpy
    checkSanity(adata, batch, hvg)      # project helper: validates the inputs
    split = splitBatches(adata, batch)  # project helper: one AnnData per batch

    corrected = mnnpy.mnn_correct(*split, var_subset=hvg)

    return corrected[0]
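A hypothetical call, assuming `adata` carries a `'batch'` column in `.obs` and `hvg_list` is a list of highly variable gene names:

corrected_adata = runMNN(adata, batch='batch', hvg=hvg_list)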
Example 3
def correction(self):
    print("Start MNN...\n")
    start = time.time()
    mnn_adata = self.adata.copy()
    adata_list = [
        mnn_adata[mnn_adata.obs[self.batch] == i]
        for i in mnn_adata.obs[self.batch].unique()
    ]
    corrected = mnnpy.mnn_correct(*adata_list, var_subset=self.hvgs)
    self.adata = corrected[0]
    print(f"MNN has taken {round(time.time() - start, 2)} seconds")
Example 4
def runMNN(adata, batch, hvg=None):
    import mnnpy
    checkSanity(adata, batch, hvg)
    split, categories = splitBatches(adata, batch, return_categories=True)

    corrected, _, _ = mnnpy.mnn_correct(*split,
                                        var_subset=hvg,
                                        batch_key=batch,
                                        batch_categories=categories,
                                        index_unique=None)

    return corrected
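Unlike Example 2, this variant forwards `batch_key` and `batch_categories` to the concatenation step, so the merged result should keep the original batch labels. A hypothetical call:

corrected_adata = runMNN(adata, batch='batch', hvg=hvg_list)
print(corrected_adata.obs['batch'].value_counts())  # batch labels preserved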
Example 5
import os
import time

import matplotlib.pyplot as plt
import mnnpy
import scanpy as sc

# `adata`, `args`, and the `umapplot` helper come from earlier in the script
sc.pp.log1p(adata)
sc.pp.scale(adata, max_value=10)
sc.tl.pca(adata, svd_solver='arpack')

# Plot UMAP and T-SNE before correction
umapplot(adata,
         color_by=[args.celltype, args.batch],
         save_file_prefix=f"mnn_umap_{args.adata}_before_cor")

# Correction
print("Starting MNN...")
start = time.time()
adata_list = [
    adata[adata.obs[args.batch] == i] for i in adata.obs[args.batch].unique()
]
hvgs = adata.var_names  # note: this passes all genes, not a true HVG subset
# assumes exactly two batches; `*adata_list` would generalize to more
adata_mnn = mnnpy.mnn_correct(adata_list[0], adata_list[1], var_subset=hvgs)[0]
print(f"MNN has taken {time.time() - start:.2f} seconds")

# Plot UMAP and T-SNE after correction
sc.pp.neighbors(adata_mnn, n_neighbors=10, n_pcs=20)
sc.tl.umap(adata_mnn)
sc.pl.umap(adata_mnn, color=[args.celltype, args.batch], show=False)
resname = f"./visualization/mnn_umap_{args.adata}_after_cor.png"
plt.savefig(resname, dpi=100)

# Save corrected adata
if not os.path.exists(f"./{args.adata[:6]}"):
    os.makedirs(f"./{args.adata[:6]}")
adata_mnn.write_h5ad(os.path.join(f"./{args.adata[:6]}/mnn_{args.adata}"))
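Example 5 passes every gene as `hvgs`. If a genuine highly-variable subset is wanted instead, scanpy can flag one before correction (a sketch, assuming the usual scanpy workflow; in practice HVGs are normally selected before `sc.pp.scale`):

sc.pp.highly_variable_genes(adata, n_top_genes=2000)
hvgs = adata.var_names[adata.var['highly_variable']]
adata_mnn = mnnpy.mnn_correct(*adata_list, var_subset=hvgs)[0]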
Example 6
def mnn_correct(
    *datas: Union[AnnData, np.ndarray],
    var_index: Optional[Collection[str]] = None,
    var_subset: Optional[Collection[str]] = None,
    batch_key: str = 'batch',
    index_unique: str = '-',
    batch_categories: Optional[Collection[Any]] = None,
    k: int = 20,
    sigma: float = 1.0,
    cos_norm_in: bool = True,
    cos_norm_out: bool = True,
    svd_dim: Optional[int] = None,
    var_adj: bool = True,
    compute_angle: bool = False,
    mnn_order: Optional[Sequence[int]] = None,
    svd_mode: Literal['svd', 'rsvd', 'irlb'] = 'rsvd',
    do_concatenate: bool = True,
    save_raw: bool = False,
    n_jobs: Optional[int] = None,
    **kwargs,
) -> Tuple[Union[np.ndarray, AnnData], List[pd.DataFrame], Optional[List[Tuple[
        Optional[float], int]]], ]:
    """\
    Correct batch effects by matching mutual nearest neighbors [Haghverdi18]_ [Kang18]_.
    This uses the implementation of `mnnpy
    <https://github.com/chriscainx/mnnpy>`__ [Kang18]_.
    Depending on `do_concatenate`, returns matrices or `AnnData` objects in
    the original order containing corrected expression values, or a single
    concatenated matrix or `AnnData` object.

    Note that it is not advised to use the corrected data matrices for
    differential expression testing.

    More information and bug reports `here <https://github.com/chriscainx/mnnpy>`__.

    Parameters
    ----------
    datas
        Expression matrices or AnnData objects. Matrices should be shaped
        n_obs × n_vars (n_subject × n_feature) and have a consistent number
        of columns; AnnData objects should have the same number of variables.
    var_index
        The index (list of str) of vars (features). Necessary when using only a
        subset of vars to perform MNN correction, and should be supplied with
        `var_subset`. When `datas` are AnnData objects, `var_index` is ignored.
    var_subset
        The subset of vars (list of str) to be used when performing MNN
        correction. Typically, a list of highly variable features.
        When set to `None`, uses all vars.
    batch_key
        The `batch_key` for :meth:`~anndata.AnnData.concatenate`.
        Only valid when `do_concatenate=True` and AnnData objects are supplied.
    index_unique
        The `index_unique` for :meth:`~anndata.AnnData.concatenate`.
        Only valid when `do_concatenate=True` and AnnData objects are supplied.
    batch_categories
        The `batch_categories` for :meth:`~anndata.AnnData.concatenate`.
        Only valid when `do_concatenate=True` and AnnData objects are supplied.
    k
        Number of mutual nearest neighbors.
    sigma
        The bandwidth of the Gaussian smoothing kernel used to compute the
        correction vectors. Default is 1.
    cos_norm_in
        Whether cosine normalization should be performed on the input data prior
        to calculating distances between subjects.
    cos_norm_out
        Whether cosine normalization should be performed prior to computing corrected expression values.
    svd_dim
        The number of dimensions to use for summarizing meaningful substructure
        within each batch. If None, meaningful components will not be removed
        from the correction vectors.
    var_adj
        Whether to adjust variance of the correction vectors. Note this step
        takes most computing time.
    compute_angle
        Whether to compute the angle between each subject’s correction vector and
        the meaningful subspace of the reference batch.
    mnn_order
        The order in which batches are to be corrected. When set to None, datas
        are corrected sequentially.
    svd_mode
        `'svd'` computes SVD using a non-randomized SVD-via-ID algorithm,
        while `'rsvd'` uses a randomized version. `'irlb'` performs
        truncated SVD by implicitly restarted Lanczos bidiagonalization
        (forked from https://github.com/airysen/irlbpy).
    do_concatenate
        Whether to concatenate the corrected matrices or AnnData objects. Default is True.
    save_raw
        Whether to save the original expression data in the
        :attr:`~anndata.AnnData.raw` attribute.
    n_jobs
        The number of jobs. When set to `None`, automatically uses
        :attr:`quanp._settings.QuanpyConfig.n_jobs`.
    kwargs
        Optional keyword arguments passed to irlb.

    Returns
    -------
    datas
        Corrected matrix/matrices or AnnData object/objects, depending on the
        input type and `do_concatenate`.
    mnn_list
        A list containing MNN pairing information as DataFrames in each iteration step.
    angle_list
        A list containing angles of each batch.
    """
    if len(datas) < 2:
        return datas, [], []

    try:
        from mnnpy import mnn_correct
    except ImportError:
        raise ImportError('Please install the package mnnpy '
                          '(https://github.com/chriscainx/mnnpy). ')

    n_jobs = settings.n_jobs if n_jobs is None else n_jobs
    datas, mnn_list, angle_list = mnn_correct(
        *datas,
        var_index=var_index,
        var_subset=var_subset,
        batch_key=batch_key,
        index_unique=index_unique,
        batch_categories=batch_categories,
        k=k,
        sigma=sigma,
        cos_norm_in=cos_norm_in,
        cos_norm_out=cos_norm_out,
        svd_dim=svd_dim,
        var_adj=var_adj,
        compute_angle=compute_angle,
        mnn_order=mnn_order,
        svd_mode=svd_mode,
        do_concatenate=do_concatenate,
        save_raw=save_raw,
        n_jobs=n_jobs,
        **kwargs,
    )
    return datas, mnn_list, angle_list
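A minimal usage sketch for this wrapper, assuming two preprocessed AnnData objects `adata_a` and `adata_b` with matching variables:

corrected, mnn_list, angle_list = mnn_correct(
    adata_a, adata_b, batch_key='batch', k=20, do_concatenate=True)
# `corrected` is one concatenated AnnData here; with do_concatenate=False
# it would instead be per-batch results in the original input order.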
Example 7
    lisi_scores.append(
        metrics.lisi2(task_adata.obsm[method_key],
                      task_adata.obs, [task.batch_key, task.ct_key],
                      perplexity=30))
elif method == 'MNN':
    A_idx = task_adata.obs[task.batch_key] == task.source_batch
    B_idx = task_adata.obs[task.batch_key] == task.target_batch
    A_X = task_adata[A_idx].obsm['PCA']
    B_X = task_adata[B_idx].obsm['PCA']
    # optional standardization (kept commented out in the original):
    # scaler = StandardScaler().fit(np.concatenate((A_X, B_X)))
    # A_X = scaler.transform(A_X)
    # B_X = scaler.transform(B_X)
    mnn_adata_A = anndata.AnnData(X=A_X, obs=task_adata[A_idx].obs)
    mnn_adata_B = anndata.AnnData(X=B_X, obs=task_adata[B_idx].obs)
    corrected = mnnpy.mnn_correct(mnn_adata_A, mnn_adata_B)
    task_adata.obsm[method_key] = np.zeros(corrected[0].shape)
    task_adata.obsm[method_key][np.where(
        A_idx)[0]] = corrected[0].X[:mnn_adata_A.shape[0]]
    task_adata.obsm[method_key][np.where(
        B_idx)[0]] = corrected[0].X[mnn_adata_A.shape[0]:]
    task_adata.obsm[method_key + '_TSNE'] = TSNE(
        n_components=2).fit_transform(task_adata.obsm[method_key])
    task_adata.obsm[method_key + '_PCA'] = PCA(
        n_components=2).fit_transform(task_adata.obsm[method_key])
    plot_embedding(task_adata, method_key + '_PCA', task, pca_fig,
                   pca_outer_grid, i + 1, j + 1)
    plot_embedding(task_adata, method_key + '_TSNE', task, fig,
                   outer_grid, i + 1, j + 1)
    lisi_scores.append(
        metrics.lisi2(task_adata.obsm[method_key],
                      task_adata.obs, [task.batch_key, task.ct_key],
                      perplexity=30))
Example 8
    index_unique = None

    if len(cell_ids) != len(np.unique(cell_ids)):
        print("Non-unique cell index detected!")
        print(
            "Make the index unique by joining the existing index names with the batch category, using index_unique='-'"
        )
        index_unique = '-'

    # [mnn_correct] Subset to HVGs only; otherwise "ValueError: Lengths must match to compare"
    print(f"Perform mnnCorrect...")
    corrected = mnn_correct(
        *_adatas,
        var_index=args.var_index,
        var_subset=hvgs if args.var_subset is None else args.var_subset,
        batch_key=args.batch_key,
        index_unique=index_unique,
        k=args.k,
        n_jobs=args.n_jobs)
    adata = corrected[0]
    # Run MNN_CORRECT (mnnpy)
    # Open GitHub issue: https://github.com/theislab/scanpy/issues/757
    # sc.external.pp.mnn_correct(adata,
    #     var_index=options.var_index,
    #     var_subset=options.var_subset,
    #     batch_key=options.batch_key,
    #     k=options.k)
else:
    raise Exception(
        f"The given batch effect correction method {args.method} is not implemented."
    )
Example 9
from scipy.io import loadmat
from scipy import sparse

import random

import mnnpy

dat = loadmat('data/10x_pooled_400.mat')
data = sparse.csc_matrix(dat['data'])
labs = dat['labels'].flatten()

# split cells (columns) into two random pseudo-batches of 200 cells each
indices = list(range(data.shape[1]))
random.shuffle(indices)
batch1 = indices[:200]
batch2 = indices[200:]
data1 = data[:, batch1]
data2 = data[:, batch2]
# mnnpy expects obs × vars, hence the transpose
data1_dense = data1.T.toarray()
data2_dense = data2.T.toarray()
var_index = list(range(data1.shape[0]))
data_corrected = mnnpy.mnn_correct(data1.T, data2.T, var_index=var_index)

data_corrected_dense = mnnpy.mnn_correct(data1_dense,
                                         data2_dense,
                                         var_index=var_index)
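Per the return signature documented in Example 6, each call yields a `(corrected, mnn_list, angle_list)` tuple; for the dense run the pieces can be unpacked directly:

corrected_X, mnn_pairs, angles = data_corrected_dense
print(corrected_X.shape)  # all 400 cells stacked, since do_concatenate defaults to True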