Ejemplo n.º 1
0
def spatial_neighbors(
    adata: AnnData,
    spatial_key: str = Key.obsm.spatial,
    coord_type: Optional[Union[str, CoordType]] = None,
    n_rings: int = 1,
    n_neigh: int = 6,
    delaunay: bool = False,
    radius: Optional[float] = None,
    transform: Optional[Union[str, Transform]] = None,
    key_added: Optional[str] = None,
) -> None:
    """
    Create a graph from spatial coordinates.

    Parameters
    ----------
    %(adata)s
    %(spatial_key)s
    coord_type
        Type of coordinate system. Valid options are:

            - `{c.VISIUM!r}` - Visium coordinates.
            - `{c.GENERIC!r}` - generic coordinates.

        If `None`, use `{c.VISIUM!r}` if the default spatial key ``Key.uns.spatial`` is present in
        :attr:`anndata.AnnData.uns`, otherwise use `{c.GENERIC!r}`. Note that this check does not
        look at ``spatial_key``.
    n_rings
        Number of rings of neighbors for Visium data. Ignored for non-Visium data.
    n_neigh
        Number of neighbors to consider for non-Visium data. Visium data always uses `6` neighbors.
    delaunay
        Whether to compute the graph from Delaunay triangulation.
    radius
        Radius of neighbors for non-Visium data.
    transform
        Type of adjacency matrix transform. Valid options are:

            - `{t.SPECTRAL.s!r}` - spectral transformation of the adjacency matrix.
            - `{t.COSINE.s!r}` - cosine transformation of the adjacency matrix.
            - `{t.NONE.v}` - no transformation of the adjacency matrix.

    key_added
        Key which controls where the results are saved.

    Returns
    -------
    Modifies the ``adata`` with the following keys:

        - :attr:`anndata.AnnData.obsp` ``['{{key_added}}_connectivities']`` - spatial connectivity matrix.
        - :attr:`anndata.AnnData.obsp` ``['{{key_added}}_distances']`` - spatial distances matrix. Not written
          for Visium data with ``n_rings = 1``. For Visium data with ``n_rings > 1``, neighbors beyond the
          first ring are labeled with their ring number (2, 3, ...) instead of a metric distance.
        - :attr:`anndata.AnnData.uns`  ``['{{key_added}}']`` - spatial neighbors dictionary.
    """
    _assert_positive(n_rings, name="n_rings")
    _assert_positive(n_neigh, name="n_neigh")
    _assert_spatial_basis(adata, spatial_key)

    transform = Transform.NONE if transform is None else Transform(transform)
    if coord_type is None:
        # NOTE: the default is decided by the fixed ``Key.uns.spatial`` entry in ``adata.uns``,
        # not by ``spatial_key``.
        coord_type = CoordType.VISIUM if Key.uns.spatial in adata.uns else CoordType.GENERIC
    else:
        coord_type = CoordType(coord_type)

    start = logg.info(
        f"Creating graph using `{coord_type}` coordinates and `{transform}` transform"
    )

    coords = adata.obsm[spatial_key]
    if coord_type == CoordType.VISIUM:
        if n_rings > 1:
            # ``set_diag=True`` puts every spot into ``Res`` from the start, so walks that come back
            # to their origin are masked out below instead of being labeled as a later ring.
            Adj: csr_matrix = _build_connectivity(coords,
                                                  6,
                                                  neigh_correct=True,
                                                  set_diag=True,
                                                  delaunay=delaunay,
                                                  return_distance=False)
            Res = Adj
            Walk = Adj
            for i in range(n_rings - 1):
                # Expand the current frontier by one hop.
                Walk = Walk @ Adj
                # Drop spots that were already reached in an earlier ring. Writing into a CSR
                # matrix may change its sparsity structure, hence the silenced efficiency warning.
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore", SparseEfficiencyWarning)
                    Walk[Res.nonzero()] = 0.0
                Walk.eliminate_zeros()
                # Label the newly reached spots with their ring number (2 for i == 0, 3 for i == 1, ...).
                Walk.data[:] = i + 2.0
                Res = Res + Walk
            Adj = Res
            # Remove the self-loops that were only needed for masking.
            Adj.setdiag(0.0)
            Adj.eliminate_zeros()

            # Distances keep the ring labels; connectivities are binarized.
            Dst = Adj.copy()
            Adj.data[:] = 1.0
        else:
            Adj = _build_connectivity(coords,
                                      6,
                                      neigh_correct=True,
                                      delaunay=delaunay)
            # No distance matrix is produced for single-ring Visium graphs.
            Dst = None

    elif coord_type == CoordType.GENERIC:
        Adj, Dst = _build_connectivity(coords,
                                       n_neigh,
                                       radius,
                                       delaunay=delaunay,
                                       return_distance=True)
    else:
        raise NotImplementedError(coord_type)

    # check transform
    if transform == Transform.SPECTRAL:
        Adj = _transform_a_spectral(Adj)
    elif transform == Transform.COSINE:
        Adj = _transform_a_cosine(Adj)
    elif transform == Transform.NONE:
        pass
    else:
        raise NotImplementedError(
            f"Transform `{transform}` is not yet implemented.")

    neighs_key = Key.uns.spatial_neighs(key_added)
    conns_key = Key.obsp.spatial_conn(key_added)
    dists_key = Key.obsp.spatial_dist(key_added)

    # NOTE: ``distances_key`` is recorded even when no distance matrix was saved (Visium, n_rings == 1).
    neighbors_dict = {
        "connectivities_key": conns_key,
        "params": {
            "n_neighbors": n_neigh,
            "coord_type": coord_type.v,
            "radius": radius,
            "transform": transform.v
        },
        "distances_key": dists_key,
    }

    _save_data(adata, attr="obsp", key=conns_key, data=Adj)
    if Dst is not None:
        _save_data(adata, attr="obsp", key=dists_key, data=Dst, prefix=False)

    _save_data(adata,
               attr="uns",
               key=neighs_key,
               data=neighbors_dict,
               prefix=False,
               time=start)
Ejemplo n.º 2
0
def nhood_enrichment(
    adata: AnnData,
    cluster_key: str,
    connectivity_key: Optional[str] = None,
    n_perms: int = 1000,
    numba_parallel: bool = False,
    seed: Optional[int] = None,
    copy: bool = False,
    n_jobs: Optional[int] = None,
    backend: str = "loky",
    show_progress_bar: bool = True,
) -> Optional[Tuple[np.ndarray, np.ndarray]]:
    """
    Compute neighborhood enrichment by permutation test.

    Parameters
    ----------
    %(adata)s
    %(cluster_key)s
    %(conn_key)s
    %(n_perms)s
    %(numba_parallel)s
    %(seed)s
    %(copy)s
    %(parallelize)s

    Returns
    -------
    If ``copy = True``, returns a :class:`tuple` with the z-score and the enrichment count.

    Otherwise, modifies the ``adata`` with the following keys:

        - :attr:`anndata.AnnData.uns` ``['{cluster_key}_nhood_enrichment']['zscore']`` - the enrichment z-score.
        - :attr:`anndata.AnnData.uns` ``['{cluster_key}_nhood_enrichment']['count']`` - the enrichment count.

    Notes
    -----
    Entries whose permutation standard deviation is zero produce `inf` or `nan` z-scores.
    """
    connectivity_key = Key.obsp.spatial_conn(connectivity_key)
    _assert_categorical_obs(adata, cluster_key)
    _assert_connectivity_key(adata, connectivity_key)
    _assert_positive(n_perms, name="n_perms")

    adj = adata.obsp[connectivity_key]
    original_clust = adata.obs[cluster_key]
    # Map each category to its position in ``categories`` so that clusters become integer labels.
    clust_map = {
        v: i
        for i, v in enumerate(original_clust.cat.categories.values)
    }  # map categories
    int_clust = np.array([clust_map[c] for c in original_clust], dtype=ndt)

    # Only the CSR structure (which cells are connected) is used; edge weights are ignored.
    indices, indptr = (adj.indices.astype(ndt), adj.indptr.astype(ndt))
    n_cls = len(clust_map)

    _test = _create_function(n_cls, parallel=numba_parallel)
    # Observed counts for the real cluster labels.
    count = _test(indices, indptr, int_clust)

    n_jobs = _get_n_cores(n_jobs)
    start = logg.info(
        f"Calculating neighborhood enrichment using `{n_jobs}` core(s)")

    # Null distribution: one stacked row block per permutation of the labels.
    perms = parallelize(
        _nhood_enrichment_helper,
        collection=np.arange(n_perms),
        extractor=np.vstack,
        n_jobs=n_jobs,
        backend=backend,
        show_progress_bar=show_progress_bar,
    )(callback=_test,
      indices=indices,
      indptr=indptr,
      int_clust=int_clust,
      n_cls=n_cls,
      seed=seed)
    # Standardize the observed counts against the permutation null (std == 0 yields inf/nan).
    zscore = (count - perms.mean(axis=0)) / perms.std(axis=0)

    if copy:
        # Report timing on the copy path too, consistent with the other analysis functions.
        logg.info("Finish", time=start)
        return zscore, count

    _save_data(
        adata,
        attr="uns",
        key=Key.uns.nhood_enrichment(cluster_key),
        data={
            "zscore": zscore,
            "count": count
        },
        time=start,
    )
Ejemplo n.º 3
0
    def test(
        self,
        cluster_key: str,
        clusters: Optional[Cluster_t] = None,
        n_perms: int = 1000,
        threshold: float = 0.01,
        seed: Optional[int] = None,
        corr_method: Optional[str] = None,
        corr_axis: Union[str, CorrAxis] = CorrAxis.INTERACTIONS.v,
        alpha: float = 0.05,
        copy: bool = False,
        key_added: Optional[str] = None,
        numba_parallel: Optional[bool] = None,
        **kwargs: Any,
    ) -> Optional[Mapping[str, pd.DataFrame]]:
        """
        Perform the permutation test as described in :cite:`cellphonedb`.

        Parameters
        ----------
        %(cluster_key)s
        clusters
            Clusters from :attr:`anndata.AnnData.obs` ``['{{cluster_key}}']``. Can be specified either as a sequence
            of :class:`tuple` or just a sequence of cluster names, in which case all combinations considered.
        %(n_perms)s
        threshold
            Do not perform permutation test if any of the interacting components is being expressed
            in less than ``threshold`` percent of cells within a given cluster.
        %(seed)s
        %(corr_method)s
        corr_axis
            Axis over which to perform the FDR correction. Only used when ``corr_method != None``. Valid options are:

                - `{fa.INTERACTIONS.s!r}` - correct interactions by performing FDR correction across the clusters.
                - `{fa.CLUSTERS.s!r}` - correct clusters by performing FDR correction across the interactions.
        alpha
            Significance level for FDR correction. Only used when ``corr_method != None``.
        %(copy)s
        key_added
            Key in :attr:`anndata.AnnData.uns` where the result is stored if ``copy = False``.
            If `None`, ``'{{cluster_key}}_ligrec'`` will be used.
        %(numba_parallel)s
        %(parallelize)s

        Returns
        -------
        %(ligrec_test_returns)s
        """
        _assert_positive(n_perms, name="n_perms")
        _assert_categorical_obs(self._adata, key=cluster_key)

        if corr_method is not None:
            corr_axis = CorrAxis(corr_axis)
        if TYPE_CHECKING:
            assert isinstance(corr_axis, CorrAxis)

        if len(self._adata.obs[cluster_key].cat.categories) <= 1:
            raise ValueError(
                f"Expected at least `2` clusters, found `{len(self._adata.obs[cluster_key].cat.categories)}`."
            )
        if TYPE_CHECKING:
            assert isinstance(self.interactions, pd.DataFrame)
            assert isinstance(self._filtered_data, pd.DataFrame)

        interactions = self.interactions[[SOURCE, TARGET]]
        # Attach the cluster labels (as string categories) to the filtered expression data.
        # NOTE(review): this mutates ``self._filtered_data`` in place.
        self._filtered_data["clusters"] = self._adata.obs[cluster_key].astype(
            "string").astype("category").values

        if clusters is None:
            clusters = list(
                map(str, self._adata.obs[cluster_key].cat.categories))
        # A flat sequence of names means "every ordered pair of these clusters".
        if all(isinstance(c, str) for c in clusters):
            clusters = list(product(
                clusters, repeat=2))  # type: ignore[no-redef,assignment]
        clusters = sorted(
            _check_tuple_needles(
                clusters,  # type: ignore[arg-type]
                self._filtered_data["clusters"].cat.categories,
                msg="Invalid cluster `{0!r}`.",
                reraise=True,
            ))
        clusters_flat = list({c for cs in clusters for c in cs})

        # Keep only the cells that belong to one of the requested clusters.
        data = self._filtered_data.loc[
            np.isin(self._filtered_data["clusters"], clusters_flat), :]
        data["clusters"] = data["clusters"].cat.remove_unused_categories()
        cat = data["clusters"].cat

        # Integer encodings of clusters and genes, as consumed by ``_analysis``.
        cluster_mapper = dict(zip(cat.categories, range(len(cat.categories))))
        gene_mapper = dict(zip(data.columns[:-1],
                               range(len(data.columns) -
                                     1)))  # -1 for 'clusters'

        data.columns = [
            gene_mapper[c] if c != "clusters" else c for c in data.columns
        ]
        clusters_ = np.array([[cluster_mapper[c1], cluster_mapper[c2]]
                              for c1, c2 in clusters],
                             dtype=np.uint32)

        # Replace the cluster names by their integer codes. ``rename_categories(..., inplace=True)``
        # was deprecated in pandas 1.3 and removed in 2.0, so assign the renamed column back instead.
        data["clusters"] = data["clusters"].cat.rename_categories(cluster_mapper)
        # much faster than applymap (tested on 1M interactions)
        interactions_ = np.vectorize(lambda g: gene_mapper[g])(
            interactions.values)

        n_jobs = _get_n_cores(kwargs.pop("n_jobs", None))
        start = logg.info(
            f"Running `{n_perms}` permutations on `{len(interactions)}` interactions "
            f"and `{len(clusters)}` cluster combinations using `{n_jobs}` core(s)"
        )
        res = _analysis(
            data,
            interactions_,
            clusters_,
            threshold=threshold,
            n_perms=n_perms,
            seed=seed,
            n_jobs=n_jobs,
            numba_parallel=numba_parallel,
            **kwargs,
        )

        # Rows are (source, target) interactions; columns are (cluster_1, cluster_2) pairs.
        res = {
            "means":
            _create_sparse_df(
                res.means,
                index=pd.MultiIndex.from_frame(interactions,
                                               names=[SOURCE, TARGET]),
                columns=pd.MultiIndex.from_tuples(
                    clusters, names=["cluster_1", "cluster_2"]),
                fill_value=0,
            ),
            "pvalues":
            _create_sparse_df(
                res.pvalues,
                index=pd.MultiIndex.from_frame(interactions,
                                               names=[SOURCE, TARGET]),
                columns=pd.MultiIndex.from_tuples(
                    clusters, names=["cluster_1", "cluster_2"]),
                fill_value=np.nan,
            ),
            "metadata":
            self.interactions[self.interactions.columns.difference(
                [SOURCE, TARGET])],
        }
        res["metadata"].index = res["means"].index.copy()

        if TYPE_CHECKING:
            assert isinstance(res, dict)

        if corr_method is not None:
            logg.info(f"Performing FDR correction across the `{corr_axis.v}` "
                      f"using method `{corr_method}` at level `{alpha}`")
            res["pvalues"] = _fdr_correct(res["pvalues"],
                                          corr_method,
                                          corr_axis,
                                          alpha=alpha)

        if copy:
            logg.info("Finish", time=start)
            return res

        _save_data(self._adata,
                   attr="uns",
                   key=Key.uns.ligrec(cluster_key, key_added),
                   data=res,
                   time=start)
Ejemplo n.º 4
0
def moran(
    adata: AnnData,
    connectivity_key: str = Key.obsp.spatial_conn(),
    genes: Optional[Union[str, Sequence[str]]] = None,
    transformation: Literal["r", "B", "D", "U", "V"] = "r",
    n_perms: int = 1000,
    corr_method: Optional[str] = "fdr_bh",
    layer: Optional[str] = None,
    seed: Optional[int] = None,
    copy: bool = False,
    n_jobs: Optional[int] = None,
    backend: str = "loky",
    show_progress_bar: bool = True,
) -> Optional[pd.DataFrame]:
    """
    Calculate Moran’s I Global Autocorrelation Statistic.

    Parameters
    ----------
    %(adata)s
    %(conn_key)s
    genes
        Gene names, as stored in :attr:`anndata.AnnData.var_names`, for which Moran's I statistic
        :cite:`pysal` is computed.

        If `None`, the genes flagged in :attr:`anndata.AnnData.var` ``['highly_variable']`` are used when
        that column exists; otherwise all genes are used.
    transformation
        Weights transformation, as accepted by :class:`esda.Moran`. Defaults to `"r"` (row-standardized).
    %(n_perms)s
    %(corr_method)s
    layer
        Layer in :attr:`anndata.AnnData.layers` to use. If `None`, use :attr:`anndata.AnnData.X`.
    %(seed)s
    %(copy)s
    %(parallelize)s

    Returns
    -------
    If ``copy = True``, returns a :class:`pandas.DataFrame`, sorted by `'I'` in descending order,
    with the following keys:

        - `'I'` - Moran's I statistic.
        - `'pval_sim'` - p-value based on permutations.
        - `'VI_sim'` - variance of `'I'` from permutations.
        - `'pval_sim_{{corr_method}}'` - the corrected p-values if ``corr_method != None`` .

    Otherwise, stores that dataframe in :attr:`anndata.AnnData.uns` ``['moranI']``.

    Raises
    ------
    ImportError
        If `esda` or `libpysal` is not installed.
    """
    if esda is None or libpysal is None:
        raise ImportError(
            "Please install `esda` and `libpysal` as `pip install esda libpysal`."
        )

    _assert_positive(n_perms, name="n_perms")
    _assert_connectivity_key(adata, connectivity_key)

    if genes is None:
        # Prefer the highly variable genes whenever that annotation is available.
        genes = (
            adata[:, adata.var.highly_variable.values].var_names.values
            if "highly_variable" in adata.var.columns
            else adata.var_names.values
        )
    genes = _assert_non_empty_sequence(genes, name="genes")

    n_jobs = _get_n_cores(n_jobs)
    start = logg.info(f"Calculating for `{len(genes)}` genes using `{n_jobs}` core(s)")

    # Spatial weights are built once and shared by every worker.
    weights = _set_weight_class(adata, key=connectivity_key)
    run_parallel = parallelize(
        _moran_helper,
        collection=genes,
        extractor=pd.concat,
        use_ixs=True,
        n_jobs=n_jobs,
        backend=backend,
        show_progress_bar=show_progress_bar,
    )
    result = run_parallel(
        adata=adata,
        weights=weights,
        transformation=transformation,
        permutations=n_perms,
        layer=layer,
        seed=seed,
    )

    if corr_method is not None:
        # ``multipletests`` returns (reject, pvals_corrected, alphacSidak, alphacBonf).
        corrected = multipletests(result["pval_sim"].values, alpha=0.05, method=corr_method)[1]
        result[f"pval_sim_{corr_method}"] = corrected

    result.sort_values(by="I", ascending=False, inplace=True)

    if not copy:
        _save_data(adata, attr="uns", key="moranI", data=result, time=start)
        return None

    logg.info("Finish", time=start)
    return result
Ejemplo n.º 5
0
def spatial_autocorr(
    adata: AnnData,
    connectivity_key: str = Key.obsp.spatial_conn(),
    genes: Optional[Union[str, Sequence[str]]] = None,
    mode: Literal[
        "moran",
        "geary"] = SpatialAutocorr.MORAN.s,  # type: ignore[assignment]
    transformation: bool = True,
    n_perms: Optional[int] = None,
    two_tailed: bool = False,
    corr_method: Optional[str] = "fdr_bh",
    layer: Optional[str] = None,
    seed: Optional[int] = None,
    use_raw: bool = False,
    copy: bool = False,
    n_jobs: Optional[int] = None,
    backend: str = "loky",
    show_progress_bar: bool = True,
) -> Optional[pd.DataFrame]:
    """
    Calculate Global Autocorrelation Statistic (Moran’s I  or Geary's C).

    See  :cite:`pysal` for reference.

    Parameters
    ----------
    %(adata)s
    %(conn_key)s
    genes
        List of gene names, as stored in :attr:`anndata.AnnData.var_names`, used to compute global
        spatial autocorrelation statistic.

        If `None`, it's computed for the genes in :attr:`anndata.AnnData.var` ``['highly_variable']``,
        if present. Otherwise, it's computed for all genes.
    mode
        Mode of score calculation:

            - `{sp.MORAN.s!r}` - `Moran's I autocorrelation <https://en.wikipedia.org/wiki/Moran%27s_I>`_.
            - `{sp.GEARY.s!r}` - `Geary's C autocorrelation <https://en.wikipedia.org/wiki/Geary%27s_C>`_.

    transformation
        If `True`, weights in :attr:`anndata.AnnData.obsp` ``['{key}']`` are row-normalized,
        advised for analytic p-value calculation.
    %(n_perms)s
        If `None`, only p-values under normality assumption are computed.
    two_tailed
        If `True`, p-values are two-tailed, otherwise they are one-tailed.
    %(corr_method)s
    layer
        Layer in :attr:`anndata.AnnData.layers` to use. If `None`, use :attr:`anndata.AnnData.X`.
    %(seed)s
    use_raw
        Whether to compute the statistic on :attr:`anndata.AnnData.raw`. Forwarded, together with
        ``layer``, to the function that extracts the expression values.
    %(copy)s
    %(parallelize)s

    Returns
    -------
    If ``copy = True``, returns a :class:`pandas.DataFrame` with the following keys:

        - `'I' or 'C'` - Moran's I or Geary's C statistic.
        - `'pval_norm'` - p-value under normality assumption.
        - `'var_norm'` - variance of `'score'` under normality assumption.
        - `'{{p_val}}_{{corr_method}}'` - the corrected p-values if ``corr_method != None`` .

    If ``n_perms`` is not `None`, the following columns are additionally returned:

        - `'pval_z_sim'` - p-value based on standard normal approximation from permutations.
        - `'pval_sim'` - p-value based on permutations.
        - `'var_sim'` - variance of `'score'` from permutations.

    Otherwise, modifies the ``adata`` with the following key:

        - :attr:`anndata.AnnData.uns` ``['moranI']`` - the above mentioned dataframe, if ``mode = {sp.MORAN.s!r}``.
        - :attr:`anndata.AnnData.uns` ``['gearyC']`` - the above mentioned dataframe, if ``mode = {sp.GEARY.s!r}``.
    """
    _assert_connectivity_key(adata, connectivity_key)

    if genes is None:
        if "highly_variable" in adata.var.columns:
            genes = adata[:, adata.var.highly_variable.values].var_names.values
        else:
            genes = adata.var_names.values
    genes = _assert_non_empty_sequence(genes, name="genes")

    mode = SpatialAutocorr(mode)  # type: ignore[assignment]
    if TYPE_CHECKING:
        assert isinstance(mode, SpatialAutocorr)
    params = {
        "mode": mode.s,
        "transformation": transformation,
        "two_tailed": two_tailed
    }

    # Per-mode scoring function, output column, expected value under the null and sort order.
    if mode == SpatialAutocorr.MORAN:
        params["func"] = _morans_i
        params["stat"] = "I"
        params["expected"] = -1.0 / (adata.shape[0] - 1)  # expected score
        params["ascending"] = False
    elif mode == SpatialAutocorr.GEARY:
        params["func"] = _gearys_c
        params["stat"] = "C"
        params["expected"] = 1.0
        params["ascending"] = True
    else:
        raise NotImplementedError(f"Mode `{mode}` is not yet implemented.")

    n_jobs = _get_n_cores(n_jobs)

    # Transposed so that each row holds one gene's values.
    vals = _get_obs_rep(adata[:, genes], use_raw=use_raw, layer=layer).T
    # Copy so that the in-place normalization below does not alter ``adata.obsp``.
    g = adata.obsp[connectivity_key].copy()
    # row-normalize
    if transformation:
        normalize(g, norm="l1", axis=1, copy=False)

    score = params["func"](g, vals)

    start = logg.info(
        f"Calculating {mode}'s statistic for `{n_perms}` permutations using `{n_jobs}` core(s)"
    )
    if n_perms is not None:
        _assert_positive(n_perms, name="n_perms")
        perms = np.arange(n_perms)

        score_perms = parallelize(
            _score_helper,
            collection=perms,
            extractor=np.concatenate,
            use_ixs=True,
            n_jobs=n_jobs,
            backend=backend,
            show_progress_bar=show_progress_bar,
        )(mode=mode, g=g, vals=vals, seed=seed)
    else:
        score_perms = None

    with np.errstate(divide="ignore"):
        pval_results = _p_value_calc(score, score_perms, g, params)

    results = {params["stat"]: score}
    results.update(pval_results)

    df = pd.DataFrame(results, index=genes)

    if corr_method is not None:
        # Snapshot the p-value columns first: corrected columns are added to ``df`` inside the loop.
        pval_columns = [col for col in df.columns if "pval" in col]
        for pv in pval_columns:
            _, pvals_adj, _, _ = multipletests(df[pv].values,
                                               alpha=0.05,
                                               method=corr_method)
            df[f"{pv}_{corr_method}"] = pvals_adj

    df.sort_values(by=params["stat"],
                   ascending=params["ascending"],
                   inplace=True)

    if copy:
        logg.info("Finish", time=start)
        return df

    _save_data(adata,
               attr="uns",
               key=params["mode"] + params["stat"],
               data=df,
               time=start)
Ejemplo n.º 6
0
    def generate_spot_crops(
        self,
        adata: AnnData,
        library_id: Optional[str] = None,
        spatial_key: str = Key.obsm.spatial,
        spot_scale: float = 1.0,
        obs_names: Optional[Iterable[Any]] = None,
        as_array: Union[str, bool] = False,
        return_obs: bool = False,
        **kwargs: Any,
    ) -> Union[
        Iterator["ImageContainer"],
        Iterator[np.ndarray],
        Iterator[Tuple[np.ndarray, ...]],
        Iterator[Dict[str, np.ndarray]],
    ]:
        """
        Iterate over :attr:`adata.obs_names` and extract crops.

        Implemented for 10X spatial datasets.

        Parameters
        ----------
        %(adata)s
        library_id
            Key in :attr:`anndata.AnnData.uns` ``['{spatial_key}']`` used to get the spot diameter.
        %(spatial_key)s
        spot_scale
            Scaling factor for the spot diameter. Larger values mean more context. Must be positive.
        obs_names
            Observations from :attr:`adata.obs_names` for which to generate the crops. If `None`, all names are used.
        %(as_array)s
        return_obs
            Whether to also yield names from ``obs_names``.
        kwargs
            Keyword arguments for :meth:`crop_center`.

        Yields
        ------
        If ``return_obs = True``, yields a :class:`tuple` ``(crop, obs_name)``. Otherwise, yields just the crops.
        The type of the crops depends on ``as_array``.
        """
        self._assert_not_empty()
        # Report the actual parameter name in the validation error.
        _assert_positive(spot_scale, name="spot_scale")
        _assert_spatial_basis(adata, spatial_key)
        library_id = Key.uns.library_id(adata,
                                        spatial_key=spatial_key,
                                        library_id=library_id)

        if obs_names is None:
            obs_names = adata.obs_names
        obs_names = _assert_non_empty_sequence(obs_names, name="observations")

        adata = adata[obs_names, :]
        # Only the first two coordinate columns are used: column 0 is ``x``, column 1 is ``y``.
        spatial = adata.obsm[spatial_key][:, :2]

        diameter = adata.uns[spatial_key][library_id]["scalefactors"][
            "spot_diameter_fullres"]
        # NOTE(review): ``//`` floors the half-diameter before scaling, dropping its fractional part.
        # Kept as-is so that existing crop sizes do not change.
        radius = int(round(diameter // 2 * spot_scale))

        for i, obs in enumerate(adata.obs_names):
            crop = self.crop_center(y=spatial[i][1],
                                    x=spatial[i][0],
                                    radius=radius,
                                    **kwargs)
            # Tag the crop with the observation it was extracted for.
            crop.data.attrs[Key.img.obs] = obs
            crop = crop._maybe_as_array(as_array)

            yield (crop, obs) if return_obs else crop
Ejemplo n.º 7
0
    def crop_corner(
        self,
        y: FoI_t,
        x: FoI_t,
        size: Optional[Union[FoI_t, Tuple[FoI_t, FoI_t]]] = None,
        scale: float = 1.0,
        cval: Union[int, float] = 0,
        mask_circle: bool = False,
        preserve_dtypes: bool = True,
    ) -> "ImageContainer":
        """
        Extract a crop whose upper-left corner is at ``(y, x)``.

        Parameters
        ----------
        %(yx)s
        %(size)s
        scale
            Rescale the crop using :func:`skimage.transform.rescale`.
        cval
            Fill value to use if ``mask_circle = True`` or if crop goes out of the image boundary.
        mask_circle
            Whether to mask out values that are not within a circle defined by this crop.
            Only available if ``size`` defines a square.
        preserve_dtypes
            Whether to preserve the data types of underlying :class:`xarray.DataArray`, even if ``cval``
            is of different type.

        Returns
        -------
        The cropped image of size ``size * scale``.

        Raises
        ------
        ValueError
            If the crop would completely lie outside of the image or if ``mask_circle = True`` and
            ``size`` does not define a square.

        Notes
        -----
        If ``preserve_dtypes = True`` but ``cval`` cannot be safely cast, ``cval`` will be set to 0.
        """
        self._assert_not_empty()
        y, x = self._convert_to_pixel_space((y, x))

        size = self._get_size(size)
        size = self._convert_to_pixel_space(size)

        ys, xs = size
        _assert_positive(ys, name="height")
        _assert_positive(xs, name="width")
        _assert_positive(scale, name="scale")

        # Requested crop, possibly extending past the image boundary.
        orig = CropCoords(x0=x, y0=y, x1=x + xs, y1=y + ys)

        # Image extent along ``y`` and ``x``.
        ymax, xmax = self.shape
        # Clamp every corner into [0, max]. The end coordinates are clamped from below as well, so a
        # crop lying entirely above/left of the image gets an empty extent (rejected below) rather
        # than a negative end, which ``isel`` would otherwise interpret as counting from the end.
        coords = CropCoords(x0=min(max(x, 0), xmax),
                            y0=min(max(y, 0), ymax),
                            x1=min(max(x + xs, 0), xmax),
                            y1=min(max(y + ys, 0), ymax))

        if not coords.dy:
            raise ValueError("Height of the crop is empty.")
        if not coords.dx:
            raise ValueError("Width of the crop is empty.")

        crop = self.data.isel(x=slice(coords.x0, coords.x1),
                              y=slice(coords.y0, coords.y1)).copy(deep=False)
        crop.attrs[Key.img.coords] = coords

        if orig != coords:
            # Part of the requested crop lies outside the image: pad it back to the requested size.
            padding = orig - coords

            # because padding does not change dtype by itself
            # Iterate over a snapshot of the names, since variables may be reassigned below.
            for key in list(crop.keys()):
                if preserve_dtypes:
                    # Value-based check (e.g. ``0`` fits ``uint8`` but ``-1`` does not). Going through
                    # ``np.min_scalar_type`` keeps this working on NumPy >= 2, where ``np.can_cast``
                    # rejects Python scalars.
                    if not np.can_cast(np.min_scalar_type(cval), crop[key].dtype, casting="safe"):
                        cval = 0
                else:
                    crop[key] = crop[key].astype(np.dtype(type(cval)),
                                                 copy=False)

            crop = crop.pad(
                y=(padding.y_pre, padding.y_post),
                x=(padding.x_pre, padding.x_post),
                mode="constant",
                constant_values=cval,
            )
            crop.attrs["padding"] = padding
        else:
            crop.attrs["padding"] = _NULL_PADDING

        return self._from_dataset(
            self._post_process(data=crop,
                               scale=scale,
                               cval=cval,
                               mask_circle=mask_circle,
                               preserve_dtypes=preserve_dtypes))