Example #1
        def __init__(self, X=None, V=None, Grid=None, *args, **kwargs):
            self.norm_dict = {}

            if X is not None and V is not None:
                self.parameters = update_n_merge_dict(kwargs, {
                    "X": X,
                    "V": V,
                    "Grid": Grid
                })

                import tempfile
                from dynode.vectorfield import networkModels
                from dynode.vectorfield.samplers import VelocityDataSampler
                from dynode.vectorfield.losses_weighted import (
                    MSE,
                )  # MAD, BinomialChannel, WassersteinDistance, CosineDistance

                good_ind = np.where(~np.isnan(V.sum(1)))[0]
                good_V = V[good_ind, :]
                good_X = X[good_ind, :]

                self.valid_ind = good_ind
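                # `valid_ind` records which cells had NaN-free velocity vectors; only these
                # cells are used to fit the vector field, and the indices let downstream
                # code map results back onto the original cell ordering.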

                velocity_data_sampler = VelocityDataSampler(
                    adata={
                        "X": good_X,
                        "V": good_V
                    },
                    normalize_velocity=kwargs.get("normalize_velocity", False),
                )

                vf_kwargs = {
                    "X": X,
                    "V": V,
                    "Grid": Grid,
                    "model": networkModels,
                    "sirens": False,
                    "enforce_positivity": False,
                    "velocity_data_sampler": velocity_data_sampler,
                    "time_course_data_sampler": None,
                    "network_dim": X.shape[1],
                    "velocity_loss_function":
                    MSE(),  # CosineDistance(), # #MSE(), MAD()
                    "time_course_loss_function":
                    None,  # BinomialChannel(p=0.1, alpha=1)
                    "velocity_x_initialize": X,
                    "time_course_x0_initialize": None,
                    "smoothing_factor": None,
                    "stability_factor": None,
                    "load_model_from_buffer": False,
                    "buffer_path": tempfile.mkdtemp(),
                    "hidden_features": 256,
                    "hidden_layers": 3,
                    "first_omega_0": 30.0,
                    "hidden_omega_0": 30.0,
                }
                vf_kwargs = update_dict(vf_kwargs, self.parameters)
                super().__init__(**vf_kwargs)
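This constructor (apparently the dynode-based vector field wrapper in dynamo's scVectorField module, judging by the imports) only needs the coordinates X, the velocity matrix V and an optional Grid; everything else falls back to the defaults above. A minimal, illustrative sketch of how such a class might be constructed (the import path is an assumption and `dynode` must be installed):

import numpy as np
from dynamo.vectorfield.scVectorField import dynode_vectorfield  # assumed import path

X = np.random.rand(500, 30)           # cell coordinates, e.g. 30 principal components
V = np.random.rand(500, 30) - 0.5     # matching velocity vectors (NaN rows are filtered out)
vf = dynode_vectorfield(X=X, V=V)     # Grid is optional; extra kwargs override the defaults above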
Example #2
def VectorField(
    adata: anndata.AnnData,
    basis: Union[None, str] = None,
    layer: str = "X",
    dims: Union[int, list, None] = None,
    genes: Union[list, None] = None,
    normalize: bool = False,
    grid_velocity: bool = False,
    grid_num: int = 50,
    velocity_key: str = "velocity_S",
    method: str = "SparseVFC",
    model_buffer_path: Union[str, None] = None,
    return_vf_object: bool = False,
    map_topography: bool = True,
    pot_curl_div: bool = False,
    cores: int = 1,
    copy: bool = False,
    **kwargs,
) -> Union[anndata.AnnData, base_vectorfield]:
    """Learn a function of high dimensional vector field from sparse single cell samples in the entire space robustly.

    Parameters
    ----------
        adata:
            AnnData object that contains embedding and velocity data
        basis:
            The embedding data to use. The vector field function will be learned on this low dimensional embedding
            and can then be projected back to the high dimensional space.
        layer:
            Which layer of the data will be used for vector field function reconstruction. The layer is only used
            when `basis` is None, in which case the vector field function is learned in the high dimensional gene
            expression space.
        dims:
            The dimensions that will be used for reconstructing the vector field function. If it is an `int`, all
            dimensions from the first up to `dims` will be used; if it is a list, only the dimensions in the list
            will be used.
        genes:
            The gene names whose expression will be used for vector field reconstruction. By default (when `genes`
            is set to None), the genes used for velocity embedding (`var.use_for_transition`) will be used. Note
            that the genes to be used need to have velocity calculated.
        normalize:
            Logical flag determining whether to normalize the data to zero mean and unit covariance. This is often
            required for raw datasets (for example, raw UMI counts and RNA velocity values in high dimension), but
            is normally not required for low dimensional embeddings obtained from PCA or other non-linear dimension
            reduction methods.
        grid_velocity:
            Whether to generate grid velocity. By default this is set to False, but for datasets with embedding
            dimension less than 4 the grid velocity is still generated. Note that the total number of grid points
            grows exponentially with the number of dimensions, which can quickly exhaust memory: with `grid_num`
            set to 50 and 6 dimensions (50^6 grid points), the array cannot be allocated on a machine with 32 GB of
            memory. Even when grid velocity is not generated, the vector field function can still be learned for
            thousands of dimensions and we can still predict transcriptomic cell states over long time periods.
        grid_num:
            The number of grids in each dimension for generating the grid velocity.
        velocity_key:
            The key from the adata layer that corresponds to the velocity matrix.
        method:
            Method that is used to reconstruct the vector field function. Currently `SparseVFC` and `dynode` are
            supported, and other improved approaches are under development.
        model_buffer_path:
            The directory keeping all the saved/to-be-saved torch variables and NN modules. Only used when `method`
            is set to `dynode`; if None, a dated directory under the current working directory is used.
        return_vf_object:
            Whether or not to return an instance of the vector field class instead of only storing the `VecFld`
            dictionary in the `uns` attribute.
        map_topography:
            Whether to quantify the topography of the 2D vector field.
        pot_curl_div:
            Whether to calculate potential, curl or divergence for each cell. Potential can be calculated for any
            basis, while curl and divergence are by default only applied to 2D bases. However, divergence is
            applicable in any dimension, while curl is generally only defined for 2D or 3D systems.
        cores:
            Number of cores used to run the ddhodge function. If cores is set to be > 1, multiprocessing will be
            used to parallelize the ddhodge calculation.
        copy:
            Whether to return a new deep copy of `adata` instead of updating the `adata` object passed in and
            returning `None`.
        kwargs:
            Other additional parameters passed to the vector field class.

    Returns
    -------
        adata: :class:`Union[anndata.AnnData, base_vectorfield]`
            If `copy` and `return_vf_object` are both set to False, the `AnnData` object passed in is updated with
            the `VecFld` dictionary in the `uns` attribute and `None` is returned.
            If `return_vf_object` is set to True, the vector field class object is returned instead.
            If `copy` is set to True, a deep copy of the original `adata` object is returned.
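
    Examples
    --------
        A minimal, illustrative sketch (the example dataset and preprocessing recipe are placeholders; any AnnData
        object with velocities computed will do):

        >>> import dynamo as dyn
        >>> adata = dyn.sample_data.zebrafish()
        >>> dyn.pp.recipe_monocle(adata)
        >>> dyn.tl.dynamics(adata)
        >>> dyn.tl.reduceDimension(adata)
        >>> dyn.tl.cell_velocities(adata, basis="umap")
        >>> dyn.vf.VectorField(adata, basis="umap", method="SparseVFC")
        >>> adata.uns["VecFld_umap"]["method"]
        'SparseVFC'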
    """
    logger = LoggerManager.get_logger("dynamo-topography")
    logger.info("vectorfield calculation begins...", indent_level=1)
    logger.log_time()
    if copy:
        logger.info(
            "Deep copying AnnData object and working on the new copy. Original AnnData object will not be modified.",
            indent_level=1,
        )
        adata = adata.copy()

    if basis is not None:
        logger.info("Retrieve X and V based on basis: %s. \n "
                    "       Vector field will be learned in the %s space." %
                    (basis.upper(), basis.upper()))
        X = adata.obsm["X_" + basis].copy()
        V = adata.obsm["velocity_" + basis].copy()

        if np.isscalar(dims):
            X, V = X[:, :dims], V[:, :dims]
        elif type(dims) is list:
            X, V = X[:, dims], V[:, dims]
    else:
        logger.info(
            "Retrieving X and V based on `genes`, layer: %s. \n "
            "       Vector field will be learned in the gene expression space."
            % layer)
        valid_genes = (list(set(genes).intersection(adata.var.index))
                       if genes is not None else
                       adata.var_names[adata.var.use_for_transition])
        if layer == "X":
            X = adata[:, valid_genes].X.copy()
            X = np.expm1(X)
        else:
            X = inverse_norm(adata, adata.layers[layer])

        V = adata[:, valid_genes].layers[velocity_key].copy()

        if sp.issparse(X):
            X, V = X.A, V.A

    Grid = None
    if X.shape[1] < 4 or grid_velocity:
        logger.info(
            "Generating high dimensional grids and convert into a row matrix.")
        # smart way for generating high dimensional grids and convert into a row matrix
        min_vec, max_vec = (
            X.min(0),
            X.max(0),
        )
        min_vec = min_vec - 0.01 * np.abs(max_vec - min_vec)
        max_vec = max_vec + 0.01 * np.abs(max_vec - min_vec)

        Grid_list = np.meshgrid(
            *[np.linspace(i, j, grid_num) for i, j in zip(min_vec, max_vec)])
        Grid = np.array([i.flatten() for i in Grid_list]).T

    if X is None:
        raise Exception(
            f"X is None. Make sure you passed the correct X or the correct `{basis}` basis for dimension reduction."
        )
    elif V is None:
        raise Exception("V is None. Make sure you passed the correct V.")

    logger.info("Learning vector field with method: %s." % (method.lower()))
    if method.lower() == "sparsevfc":
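        # Default SparseVFC hyper-parameters; any of them can be overridden through
        # **kwargs, which are merged in with update_dict() below.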
        vf_kwargs = {
            "M": None,
            "a": 5,
            "beta": None,
            "ecr": 1e-5,
            "gamma": 0.9,
            "lambda_": 3,
            "minP": 1e-5,
            "MaxIter": 30,
            "theta": 0.75,
            "div_cur_free_kernels": False,
            "velocity_based_sampling": True,
            "sigma": 0.8,
            "eta": 0.5,
            "seed": 0,
        }
    elif method.lower() == "dynode":
        try:
            import dynode
            from dynode.vectorfield import networkModels
            from dynode.vectorfield.samplers import VelocityDataSampler

            # from dynode.vectorfield.losses_weighted import MAD, BinomialChannel, WassersteinDistance, CosineDistance
            from dynode.vectorfield.losses_weighted import MSE
            from .scVectorField import dynode_vectorfield
        except ImportError:
            raise ImportError("You need to install the package `dynode`. "
                              "Install dynode via `pip install dynode`.")

        velocity_data_sampler = VelocityDataSampler(
            adata={"X": X, "V": V}, normalize_velocity=normalize
        )
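        # Heuristic iteration budget: ~2e5 * log(n_cells) / (250 + log(n_cells)), i.e. it grows
        # with the number of cells but saturates around 2e5 iterations for very large datasets.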
        max_iter = 2 * 100000 * np.log(X.shape[0]) / (250 + np.log(X.shape[0]))

        cwd, cwt = os.getcwd(), datetime.datetime.now()

        if model_buffer_path is None:
            model_buffer_path = os.path.join(
                cwd, f"{basis}_{cwt.year}_{cwt.month}_{cwt.day}")
            warnings.warn(
                f"The buffer path for saving the dynode model is {model_buffer_path}.")

        vf_kwargs = {
            "model": networkModels,
            "sirens": False,
            "enforce_positivity": False,
            "velocity_data_sampler": velocity_data_sampler,
            "time_course_data_sampler": None,
            "network_dim": X.shape[1],
            "velocity_loss_function":
            MSE(),  # CosineDistance(), # #MSE(), MAD()
            # BinomialChannel(p=0.1, alpha=1)
            "time_course_loss_function": None,
            "velocity_x_initialize": X,
            "time_course_x0_initialize": None,
            "smoothing_factor": None,
            "stability_factor": None,
            "load_model_from_buffer": False,
            "buffer_path": model_buffer_path,
            "hidden_features": 256,
            "hidden_layers": 3,
            "first_omega_0": 30.0,
            "hidden_omega_0": 30.0,
        }
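        # Training schedule handed to dynode_vectorfield.train(); user kwargs are merged in below.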
        train_kwargs = {
            "max_iter": int(max_iter),
            "velocity_batch_size": 50,
            "time_course_batch_size": 100,
            "autoencoder_batch_size": 50,
            "velocity_lr": 1e-4,
            "velocity_x_lr": 0,
            "time_course_lr": 1e-4,
            "time_course_x0_lr": 1e4,
            "autoencoder_lr": 1e-4,
            "velocity_sample_fraction": 1,
            "time_course_sample_fraction": 1,
            "iter_per_sample_update": None,
        }
    else:
        raise ValueError(
            "Currently only two methods are supported: SparseVFC and dynode.")

    vf_kwargs = update_dict(vf_kwargs, kwargs)

    if method.lower() == "sparsevfc":
        VecFld = svc_vectorfield(X, V, Grid, **vf_kwargs)
        vf_dict = VecFld.train(normalize=normalize, **kwargs)
    elif method.lower() == "dynode":
        train_kwargs = update_dict(train_kwargs, kwargs)
        VecFld = dynode_vectorfield(X, V, Grid, **vf_kwargs)
        # {"VecFld": VecFld.train(**kwargs)}
        vf_dict = VecFld.train(**train_kwargs)

    vf_key = "VecFld" if basis is None else "VecFld_" + basis

    vf_dict["method"] = method
    if basis is not None:
        key = "velocity_" + basis + "_" + method
        X_copy_key = "X_" + basis + "_" + method

        logger.info_insert_adata(key, adata_attr="obsm")
        logger.info_insert_adata(X_copy_key, adata_attr="obsm")
        adata.obsm[key] = vf_dict["V"]
        adata.obsm[X_copy_key] = vf_dict["X"]

        vf_dict["dims"] = dims

        logger.info_insert_adata(vf_key, adata_attr="uns")
        adata.uns[vf_key] = vf_dict
    else:
        key = velocity_key + "_" + method

        logger.info_insert_adata(key, adata_attr="layers")
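        # Allocate an all-zero sparse layer with the full adata shape and fill in only the
        # columns of the genes that were actually used to fit the vector field.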
        adata.layers[key] = sp.csr_matrix(adata.shape)
        valid_gene_idx = [adata.var_names.get_loc(g) for g in valid_genes]
        adata.layers[key][:, valid_gene_idx] = vf_dict["V"]

        vf_dict["layer"] = layer
        vf_dict["genes"] = genes
        vf_dict["velocity_key"] = velocity_key

        logger.info_insert_adata(vf_key, adata_attr="uns")
        adata.uns[vf_key] = vf_dict

    if X.shape[1] == 2 and map_topography:
        tp_kwargs = {"n": 25}
        tp_kwargs = update_dict(tp_kwargs, kwargs)

        logger.info("Mapping topography...")
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            adata = topography(adata,
                               basis=basis,
                               X=X,
                               layer=layer,
                               dims=[0, 1],
                               VecFld=vf_dict,
                               **tp_kwargs)
    if pot_curl_div:
        if basis in ["pca", "umap", "tsne", "diffusion_map", "trimap"]:
            logger.info(
                "Running ddhodge to estimate vector field based pseudotime...")

            ddhodge(adata, basis=basis, cores=cores)
            if X.shape[1] == 2:
                logger.info("Computing curl...")
                curl(adata, basis=basis)

            logger.info("Computing divergence...")
            divergence(adata, basis=basis)

    control_point, inlier_prob, valid_ids = (
        "control_point_" + basis if basis is not None else "control_point",
        "inlier_prob_" + basis if basis is not None else "inlier_prob",
        vf_dict["valid_ind"],
    )
    if method.lower() == "sparsevfc":
        logger.info_insert_adata(control_point, adata_attr="obs")
        logger.info_insert_adata(inlier_prob, adata_attr="obs")

        adata.obs[control_point], adata.obs[inlier_prob] = False, np.nan
        adata.obs.loc[adata.obs_names[vf_dict["ctrl_idx"]],
                      control_point] = True
        adata.obs.loc[adata.obs_names[valid_ids],
                      inlier_prob] = vf_dict["P"].flatten()

    # angles between observed velocity and that predicted by vector field across cells:
    cell_angles = np.zeros(adata.n_obs)
    for i, u, v in zip(valid_ids, V[valid_ids], vf_dict["V"]):
        cell_angles[i] = angle(u, v)

    temp_key = "obs_vf_angle" if basis is None else "obs_vf_angle_" + basis
    logger.info_insert_adata(temp_key, adata_attr="obs")
    adata.obs[temp_key] = cell_angles

    logger.finish_progress("VectorField")
    if return_vf_object:
        return VecFld
    elif copy:
        return adata
    return None
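
For reference, a short sketch of how the results written by `VectorField` can be read back out of the AnnData object; the key names follow the assignments above and assume `basis="umap"` with the default `method="SparseVFC"`:

import numpy as np

# after dyn.vf.VectorField(adata, basis="umap", method="SparseVFC") has run:
vf_dict = adata.uns["VecFld_umap"]                 # full vector field dictionary
V_pred = adata.obsm["velocity_umap_SparseVFC"]     # velocities predicted by the fitted field
X_used = adata.obsm["X_umap_SparseVFC"]            # coordinates the field was fitted on
ctrl_points = adata.obs["control_point_umap"]      # SparseVFC control points (boolean)
angles = adata.obs["obs_vf_angle_umap"]            # angle between observed and predicted velocity
print(vf_dict["method"], np.nanmean(angles))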