Example #1
File: rare.py Project: tanaylab/metacells
def _compress_modules(
    *,
    adata_of_all_genes_of_all_cells: AnnData,
    what: Union[str, ut.Matrix] = "__x__",
    min_cells_of_modules: int,
    max_cells_of_modules: int,
    target_metacell_size: float,
    min_modules_size_factor: float,
    related_gene_indices_of_modules: List[List[int]],
    rare_module_of_cells: ut.NumpyVector,
) -> List[List[int]]:
    list_of_rare_gene_indices_of_modules: List[List[int]] = []
    list_of_names_of_genes_of_modules: List[List[str]] = []

    min_umis_of_modules = target_metacell_size * min_modules_size_factor
    ut.log_calc("min_umis_of_modules", min_umis_of_modules)

    total_all_genes_of_all_cells = ut.get_o_numpy(
        adata_of_all_genes_of_all_cells, what, sum=True)

    cell_counts_of_modules: List[int] = []

    ut.log_calc("compress modules:")
    modules_count = len(related_gene_indices_of_modules)
    for module_index, gene_indices_of_module in enumerate(
            related_gene_indices_of_modules):
        if len(gene_indices_of_module) == 0:
            continue

        with ut.log_step(
                "- module",
                module_index,
                formatter=lambda module_index: ut.progress_description(
                    modules_count, module_index, "module"),
        ):
            module_cells_mask = rare_module_of_cells == module_index
            module_cells_count = np.sum(module_cells_mask)
            module_umis_count = np.sum(
                total_all_genes_of_all_cells[module_cells_mask])

            if module_cells_count < min_cells_of_modules:
                if ut.logging_calc():
                    ut.log_calc("cells",
                                str(module_cells_count) + " (too few)")
                rare_module_of_cells[module_cells_mask] = -1
                continue

            if module_cells_count > max_cells_of_modules:
                if ut.logging_calc():
                    ut.log_calc("cells",
                                str(module_cells_count) + " (too many)")
                rare_module_of_cells[module_cells_mask] = -1
                continue

            ut.log_calc("cells", module_cells_count)

            if module_umis_count < min_umis_of_modules:
                if ut.logging_calc():
                    ut.log_calc("UMIs", str(module_umis_count) + " (too few)")
                rare_module_of_cells[module_cells_mask] = -1
                continue

            ut.log_calc("UMIs", module_umis_count)

            next_module_index = len(list_of_rare_gene_indices_of_modules)
            if module_index != next_module_index:
                ut.log_calc("is reindexed to", next_module_index)
                rare_module_of_cells[module_cells_mask] = next_module_index
                module_index = next_module_index

            next_module_index += 1
            list_of_rare_gene_indices_of_modules.append(gene_indices_of_module)

            if ut.logging_calc():
                cell_counts_of_modules.append(np.sum(module_cells_mask))
            list_of_names_of_genes_of_modules.append(  #
                sorted(adata_of_all_genes_of_all_cells.
                       var_names[gene_indices_of_module]))

    if ut.logging_calc():
        ut.log_calc("final modules:")
        for module_index, (module_cells_count, module_gene_names) in enumerate(
                zip(cell_counts_of_modules,
                    list_of_names_of_genes_of_modules)):
            ut.log_calc(
                f"- module: {module_index} cells: {module_cells_count} genes: {module_gene_names}"
            )  #

    return list_of_rare_gene_indices_of_modules
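
The sketch below is a minimal, self-contained illustration of the filtering and reindexing idea in _compress_modules, using plain NumPy instead of the project's ut/AnnData helpers; the function name, argument names, and thresholds are hypothetical, not part of the metacells API.

# Standalone sketch of the _compress_modules filtering/reindexing idea.
# All names and thresholds here are hypothetical, for illustration only.
import numpy as np

def compress_modules_sketch(
    umis_of_cells: np.ndarray,      # total UMIs per cell
    module_of_cells: np.ndarray,    # module index per cell, -1 for "no module"
    gene_indices_of_modules: list,  # list of gene index lists, one per module
    *,
    min_cells: int,
    max_cells: int,
    min_umis: float,
) -> list:
    kept_gene_indices_of_modules: list = []
    for module_index, gene_indices in enumerate(gene_indices_of_modules):
        if len(gene_indices) == 0:
            continue
        cells_mask = module_of_cells == module_index
        cells_count = int(np.sum(cells_mask))
        umis_count = float(np.sum(umis_of_cells[cells_mask]))
        # Drop modules with too few cells, too many cells, or too few UMIs,
        # releasing their cells back to "no module".
        if cells_count < min_cells or cells_count > max_cells or umis_count < min_umis:
            module_of_cells[cells_mask] = -1
            continue
        # Surviving modules are renumbered consecutively.
        new_index = len(kept_gene_indices_of_modules)
        if new_index != module_index:
            module_of_cells[cells_mask] = new_index
        kept_gene_indices_of_modules.append(gene_indices)
    return kept_gene_indices_of_modules

# Tiny usage example (hypothetical data):
umis = np.array([10.0, 20.0, 5.0, 40.0, 3.0])
modules = np.array([0, 0, 1, 2, 2])
genes_of_modules = [[7, 8], [9], [3, 4]]
print(compress_modules_sketch(umis, modules, genes_of_modules,
                              min_cells=2, max_cells=10, min_umis=12.0))  # [[7, 8], [3, 4]]
print(modules)  # [ 0  0 -1  1  1]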
Example #2
File: rare.py Project: tanaylab/metacells
def _identify_cells(
    *,
    adata_of_all_genes_of_all_cells: AnnData,
    what: Union[str, ut.Matrix] = "__x__",
    related_gene_indices_of_modules: List[List[int]],
    min_cell_module_total: int,
    min_cells_of_modules: int,
    max_cells_of_modules: int,
    rare_module_of_cells: ut.NumpyVector,
) -> None:
    max_strength_of_cells = np.zeros(adata_of_all_genes_of_all_cells.n_obs)

    ut.log_calc("cells for modules:")
    modules_count = len(related_gene_indices_of_modules)
    for module_index, related_gene_indices_of_module in enumerate(
            related_gene_indices_of_modules):
        if len(related_gene_indices_of_module) == 0:
            continue

        with ut.log_step(
                "- module",
                module_index,
                formatter=lambda module_index: ut.progress_description(
                    modules_count, module_index, "module"),
        ):
            adata_of_related_genes_of_all_cells = ut.slice(
                adata_of_all_genes_of_all_cells,
                name=f".module{module_index}.related_genes",
                vars=related_gene_indices_of_module,
                top_level=False,
            )
            total_related_genes_of_all_cells = ut.get_o_numpy(
                adata_of_related_genes_of_all_cells, what, sum=True)

            mask_of_strong_cells_of_module = total_related_genes_of_all_cells >= min_cell_module_total

            median_strength_of_module = np.median(
                total_related_genes_of_all_cells[
                    mask_of_strong_cells_of_module])  #
            strong_cells_count = np.sum(mask_of_strong_cells_of_module)

            if strong_cells_count > max_cells_of_modules:
                if ut.logging_calc():
                    ut.log_calc(
                        "strong_cells",
                        ut.mask_description(mask_of_strong_cells_of_module) +
                        " (too many)")  #
                related_gene_indices_of_module.clear()
                continue

            if strong_cells_count < min_cells_of_modules:
                if ut.logging_calc():
                    ut.log_calc(
                        "strong_cells",
                        ut.mask_description(mask_of_strong_cells_of_module) +
                        " (too few)")  #
                related_gene_indices_of_module.clear()
                continue

            ut.log_calc("strong_cells", mask_of_strong_cells_of_module)

            strength_of_all_cells = total_related_genes_of_all_cells / median_strength_of_module
            mask_of_strong_cells_of_module &= strength_of_all_cells >= max_strength_of_cells
            max_strength_of_cells[
                mask_of_strong_cells_of_module] = strength_of_all_cells[
                    mask_of_strong_cells_of_module]

            rare_module_of_cells[mask_of_strong_cells_of_module] = module_index
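
Below is a small NumPy-only sketch of the core assignment rule in _identify_cells: a cell is "strong" for a module if its total UMIs over the module's genes reach a threshold, its strength is that total normalized by the module's median over strong cells, and a cell keeps the module for which it is strongest. The helper name and the toy data are hypothetical.

# Standalone sketch of the strength-based assignment in _identify_cells.
# Names, thresholds and data are hypothetical.
import numpy as np

def identify_cells_sketch(
    umis: np.ndarray,               # cells x genes UMI matrix
    gene_indices_of_modules: list,  # list of gene index lists, one per module
    *,
    min_cell_module_total: int,
) -> np.ndarray:
    cells_count = umis.shape[0]
    max_strength_of_cells = np.zeros(cells_count)
    module_of_cells = np.full(cells_count, -1)

    for module_index, gene_indices in enumerate(gene_indices_of_modules):
        if len(gene_indices) == 0:
            continue
        totals = umis[:, gene_indices].sum(axis=1)
        strong_mask = totals >= min_cell_module_total
        if not np.any(strong_mask):
            continue
        # Normalize by the median of the strong cells so strengths are
        # comparable across modules with different overall expression.
        strength = totals / np.median(totals[strong_mask])
        # Keep only cells for which this module is (so far) the strongest.
        strong_mask &= strength >= max_strength_of_cells
        max_strength_of_cells[strong_mask] = strength[strong_mask]
        module_of_cells[strong_mask] = module_index

    return module_of_cells

# Tiny usage example (hypothetical data): 4 cells x 5 genes, two modules.
umis = np.array([
    [5, 5, 0, 0, 1],
    [0, 1, 6, 6, 0],
    [4, 3, 5, 9, 0],
    [0, 0, 0, 1, 1],
])
print(identify_cells_sketch(umis, [[0, 1], [2, 3]], min_cell_module_total=6))  # [ 0  1  1 -1]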
Example #3
def compute_type_compatible_sizes(
    adatas: List[AnnData],
    *,
    size: str = "grouped",
    kind: str = "type",
) -> None:
    """
    Given multiple annotated data of groups, compute a "compatible" size for each one to allow for
    consistent inner normalized variance comparison.

    Since the inner normalized variance quality measure is sensitive to the group (metacell) sizes,
    it is useful to artificially shrink the groups so the sizes will be similar between the compared
    data sets. Assuming each group (metacell) has a type annotation, for each such type we give each
    group of that type a "compatible" size (less than or equal to its actual size) so that using this
    reduced size will give us comparable measures between all the data sets.

    The "compatible" sizes are chosen such that the density distributions of the sizes in all data
    sets would be as similar to each other as possible.

    .. note::

        This is only effective if the groups are of "similar" sizes. Using this to compare a very
        coarse grouping (a few thousand cells per group) with a fine-grained one (a few dozen cells
        per group) will still yield very different results.

    **Input**

    Several annotated ``adatas`` where each observation is a group. Should contain per-observation
    ``size`` annotation (default: {size}) and ``kind`` annotation (default: {kind}).

    **Returns**

    Sets the following in each ``adata``:

    Per-Observation (group) Annotations:

        ``compatible_size``
            The number of grouped cells in the group to use for computing excess R^2 and inner
            normalized variance.

    **Computation**

    1. For each type, sort the groups (metacells) in increasing number of grouped observations (cells).

    2. Consider the maximal quantile (rank) of the next smallest group (metacell) in each data set.

    3. Compute the minimal number of grouped observations in all the metacells whose quantile is up
       to this maximal quantile.

    4. Use this as the "compatible" size for all these groups, and remove them from consideration.

    5. Loop until all groups are assigned a "compatible" size.
    """
    assert len(adatas) > 0
    if len(adatas) == 1:
        ut.set_o_data(
            adatas[0], "compatible_size",
            ut.get_o_numpy(adatas[0], size, formatter=ut.sizes_description))
        return

    group_sizes_of_data = [
        ut.get_o_numpy(adata, size, formatter=ut.sizes_description)
        for adata in adatas
    ]
    group_types_of_data = [ut.get_o_numpy(adata, kind) for adata in adatas]

    unique_types: Set[Any] = set()
    for group_types in group_types_of_data:
        unique_types.update(group_types)

    compatible_size_of_data = [np.full(adata.n_obs, -1) for adata in adatas]

    groups_count_of_data: List[int] = []
    for type_index, group_type in enumerate(sorted(unique_types)):
        with ut.log_step(
                f"- {group_type}",
                ut.progress_description(len(unique_types), type_index,
                                        "type")):
            sorted_group_indices_of_data = [
                np.argsort(group_sizes)[group_types == group_type]
                for group_sizes, group_types in zip(group_sizes_of_data,
                                                    group_types_of_data)
            ]

            groups_count_of_data = [
                len(sorted_group_indices)
                for sorted_group_indices in sorted_group_indices_of_data
            ]

            ut.log_calc("group_counts", groups_count_of_data)

            def _for_each(value_of_data: List[T]) -> List[T]:
                return [
                    value for groups_count, value in zip(
                        groups_count_of_data, value_of_data)
                    if groups_count > 0
                ]

            groups_count_of_each = _for_each(groups_count_of_data)

            if len(groups_count_of_each) == 0:
                continue

            sorted_group_indices_of_each = _for_each(
                sorted_group_indices_of_data)
            group_sizes_of_each = _for_each(group_sizes_of_data)
            compatible_size_of_each = _for_each(compatible_size_of_data)

            if len(groups_count_of_each) == 1:
                compatible_size_of_each[0][
                    sorted_group_indices_of_each[0]] = group_sizes_of_each[0][
                        sorted_group_indices_of_each[0]]

            group_quantile_of_each = [
                (np.arange(len(sorted_group_indices)) + 1) /
                len(sorted_group_indices)
                for sorted_group_indices in sorted_group_indices_of_each
            ]

            next_position_of_each = np.full(len(group_quantile_of_each), 0)

            while True:
                next_quantile_of_each = [
                    group_quantile[next_position]
                    for group_quantile, next_position in zip(
                        group_quantile_of_each, next_position_of_each)
                ]
                next_quantile = max(next_quantile_of_each)

                last_position_of_each = next_position_of_each.copy()
                next_position_of_each[:] = [
                    np.sum(group_quantile <= next_quantile)
                    for group_quantile in group_quantile_of_each
                ]

                positions_of_each = [
                    range(last_position, next_position)
                    for last_position, next_position in zip(
                        last_position_of_each, next_position_of_each)
                ]

                sizes_of_each = [
                    group_sizes[sorted_group_indices[positions]]
                    for group_sizes, sorted_group_indices, positions in zip(
                        group_sizes_of_each, sorted_group_indices_of_each,
                        positions_of_each)
                ]

                min_size_of_each = [
                    np.min(sizes) for sizes, positions in zip(
                        sizes_of_each, positions_of_each)
                ]
                min_size = min(min_size_of_each)

                for sorted_group_indices, positions, compatible_size in zip(
                        sorted_group_indices_of_each, positions_of_each,
                        compatible_size_of_each):
                    compatible_size[sorted_group_indices[positions]] = min_size

                is_done_of_each = [
                    next_position == groups_count
                    for next_position, groups_count in zip(
                        next_position_of_each, groups_count_of_each)
                ]
                if all(is_done_of_each):
                    break

                assert not any(is_done_of_each)

    for adata, compatible_size in zip(adatas, compatible_size_of_data):
        assert np.min(compatible_size) > 0
        ut.set_o_data(adata, "compatible_size", compatible_size)
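
The following is a compact sketch of the quantile-matching loop described in the docstring's "Computation" steps, restricted to two plain arrays of group sizes (one per data set) for a single type; the names are hypothetical and all the ut/AnnData bookkeeping of the real function is omitted.

# Standalone sketch of the quantile-matching loop behind
# compute_type_compatible_sizes, for two plain arrays of group sizes.
# Names and data are hypothetical.
import numpy as np

def compatible_sizes_sketch(sizes_a: np.ndarray, sizes_b: np.ndarray) -> list:
    # Returns compatible sizes aligned with the *sorted* sizes of each data set.
    sizes_of_each = [np.sort(sizes_a), np.sort(sizes_b)]
    compatible_of_each = [np.empty_like(sizes) for sizes in sizes_of_each]
    # Quantile (rank fraction) of each group within its own data set.
    quantile_of_each = [
        (np.arange(len(sizes)) + 1) / len(sizes) for sizes in sizes_of_each
    ]
    next_position_of_each = [0, 0]

    while True:
        # The maximal quantile of the next smallest unassigned group across data sets.
        next_quantile = max(
            quantile[position]
            for quantile, position in zip(quantile_of_each, next_position_of_each)
        )
        last_position_of_each = list(next_position_of_each)
        next_position_of_each = [
            int(np.sum(quantile <= next_quantile)) for quantile in quantile_of_each
        ]
        # The minimal size among all groups covered in this step...
        min_size = min(
            sizes[last:stop].min()
            for sizes, last, stop in zip(
                sizes_of_each, last_position_of_each, next_position_of_each)
        )
        # ...becomes the "compatible" size of all of them.
        for compatible, last, stop in zip(
                compatible_of_each, last_position_of_each, next_position_of_each):
            compatible[last:stop] = min_size
        if all(stop == len(sizes)
               for stop, sizes in zip(next_position_of_each, sizes_of_each)):
            return compatible_of_each

# Tiny usage example (hypothetical sizes):
print(compatible_sizes_sketch(np.array([30, 10, 20]), np.array([25, 15])))
# [array([10, 20, 20]), array([10, 20])]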
Example #4
File: rare.py Project: tanaylab/metacells
def _related_genes(  # pylint: disable=too-many-statements,too-many-branches
    *,
    adata_of_all_genes_of_all_cells: AnnData,
    what: Union[str, ut.Matrix] = "__x__",
    rare_gene_indices_of_modules: List[List[int]],
    allowed_genes_mask: ut.NumpyVector,
    min_genes_of_modules: int,
    min_gene_maximum: int,
    min_cells_of_modules: int,
    max_cells_of_modules: int,
    min_cell_module_total: int,
    min_related_gene_fold_factor: float,
    max_related_gene_increase_factor: float,
) -> List[List[int]]:
    total_all_cells_umis_of_all_genes = ut.get_v_numpy(
        adata_of_all_genes_of_all_cells, what, sum=True)

    ut.log_calc("genes for modules:")
    modules_count = 0
    related_gene_indices_of_modules: List[List[int]] = []

    rare_gene_indices_of_any: Set[int] = set()
    for rare_gene_indices_of_module in rare_gene_indices_of_modules:
        if len(rare_gene_indices_of_module) >= min_genes_of_modules:
            rare_gene_indices_of_any.update(list(rare_gene_indices_of_module))

    for rare_gene_indices_of_module in rare_gene_indices_of_modules:
        if len(rare_gene_indices_of_module) < min_genes_of_modules:
            continue

        module_index = modules_count
        modules_count += 1

        with ut.log_step("- module", module_index):
            ut.log_calc(
                "rare_gene_names",
                sorted(adata_of_all_genes_of_all_cells.
                       var_names[rare_gene_indices_of_module]))

            adata_of_module_genes_of_all_cells = ut.slice(
                adata_of_all_genes_of_all_cells,
                name=f".module{module_index}.rare_gene",
                vars=rare_gene_indices_of_module,
                top_level=False,
            )

            total_module_genes_umis_of_all_cells = ut.get_o_numpy(
                adata_of_module_genes_of_all_cells, what, sum=True)

            mask_of_expressed_cells = total_module_genes_umis_of_all_cells > 0

            expressed_cells_count = np.sum(mask_of_expressed_cells)

            if expressed_cells_count > max_cells_of_modules:
                if ut.logging_calc():
                    ut.log_calc(
                        "expressed_cells",
                        ut.mask_description(mask_of_expressed_cells) +
                        " (too many)")
                continue

            if expressed_cells_count < min_cells_of_modules:
                if ut.logging_calc():
                    ut.log_calc(
                        "expressed_cells",
                        ut.mask_description(mask_of_expressed_cells) +
                        " (too few)")
                continue

            ut.log_calc("expressed_cells", mask_of_expressed_cells)

            adata_of_all_genes_of_expressed_cells_of_module = ut.slice(
                adata_of_all_genes_of_all_cells,
                name=f".module{module_index}.rare_cell",
                obs=mask_of_expressed_cells,
                top_level=False,
            )

            total_expressed_cells_umis_of_all_genes = ut.get_v_numpy(
                adata_of_all_genes_of_expressed_cells_of_module,
                what,
                sum=True)

            data = ut.get_vo_proper(
                adata_of_all_genes_of_expressed_cells_of_module,
                what,
                layout="column_major")
            max_expressed_cells_umis_of_all_genes = ut.max_per(data,
                                                               per="column")

            total_background_cells_umis_of_all_genes = (
                total_all_cells_umis_of_all_genes -
                total_expressed_cells_umis_of_all_genes)

            expressed_cells_fraction_of_all_genes = total_expressed_cells_umis_of_all_genes / sum(
                total_expressed_cells_umis_of_all_genes)

            background_cells_fraction_of_all_genes = total_background_cells_umis_of_all_genes / sum(
                total_background_cells_umis_of_all_genes)

            mask_of_related_genes = (
                allowed_genes_mask
                & (max_expressed_cells_umis_of_all_genes >= min_gene_maximum)
                & (expressed_cells_fraction_of_all_genes >=
                   background_cells_fraction_of_all_genes *
                   (2**min_related_gene_fold_factor)))

            related_gene_indices = np.where(mask_of_related_genes)[0]
            assert np.all(mask_of_related_genes[rare_gene_indices_of_module])

            base_genes_of_all_cells_adata = ut.slice(
                adata_of_all_genes_of_all_cells,
                name=f".module{module_index}.base",
                vars=rare_gene_indices_of_module)
            total_base_genes_of_all_cells = ut.get_o_numpy(
                base_genes_of_all_cells_adata, what, sum=True)
            mask_of_strong_base_cells = total_base_genes_of_all_cells >= min_cell_module_total
            count_of_strong_base_cells = np.sum(mask_of_strong_base_cells)

            if ut.logging_calc():
                ut.log_calc(
                    "candidate_gene_names",
                    sorted(adata_of_all_genes_of_all_cells.
                           var_names[related_gene_indices]))
                ut.log_calc("base_strong_genes", count_of_strong_base_cells)

            related_gene_indices_of_module = list(rare_gene_indices_of_module)
            for gene_index in related_gene_indices:
                if gene_index in rare_gene_indices_of_module:
                    continue

                if gene_index in rare_gene_indices_of_any:
                    ut.log_calc(
                        f"- candidate gene {adata_of_all_genes_of_all_cells.var_names[gene_index]} "
                        f"belongs to another module")
                    continue

                if gene_index not in rare_gene_indices_of_module:
                    related_gene_of_all_cells_adata = ut.slice(
                        adata_of_all_genes_of_all_cells,
                        name=
                        f".{adata_of_all_genes_of_all_cells.var_names[gene_index]}",
                        vars=np.array([gene_index]),
                    )
                    assert related_gene_of_all_cells_adata.n_vars == 1
                    total_related_genes_of_all_cells = ut.get_o_numpy(
                        related_gene_of_all_cells_adata, what, sum=True)
                    total_related_genes_of_all_cells += total_base_genes_of_all_cells
                    mask_of_strong_related_cells = total_related_genes_of_all_cells >= min_cell_module_total
                    count_of_strong_related_cells = np.sum(
                        mask_of_strong_related_cells)
                    ut.log_calc(
                        f"- candidate gene {adata_of_all_genes_of_all_cells.var_names[gene_index]} "
                        f"strong cells: {count_of_strong_related_cells} "
                        f"factor: {count_of_strong_related_cells / count_of_strong_base_cells}"
                    )
                    if count_of_strong_related_cells > max_related_gene_increase_factor * count_of_strong_base_cells:
                        continue

                related_gene_indices_of_module.append(gene_index)

            related_gene_indices_of_modules.append(
                related_gene_indices_of_module)  #

    if ut.logging_calc():
        ut.log_calc("related genes for modules:")
        for module_index, related_gene_indices_of_module in enumerate(
                related_gene_indices_of_modules):
            ut.log_calc(
                f"- module {module_index} related_gene_names",
                sorted(adata_of_all_genes_of_all_cells.
                       var_names[related_gene_indices_of_module]),
            )

    return related_gene_indices_of_modules
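
Below is a standalone NumPy sketch of the central relatedness test in _related_genes: a gene joins a module's candidate set if it is allowed, its maximal UMI count within the module's expressing cells reaches min_gene_maximum, and its UMI fraction within those cells exceeds its fraction in the remaining (background) cells by at least a factor of 2**min_related_gene_fold_factor. The helper name and the toy data are hypothetical.

# Standalone sketch of the fold-factor relatedness test in _related_genes.
# Names, thresholds and data are hypothetical.
import numpy as np

def related_genes_mask_sketch(
    umis: np.ndarray,                  # cells x genes UMI matrix
    expressed_cells_mask: np.ndarray,  # cells expressing the module's rare genes
    allowed_genes_mask: np.ndarray,    # genes allowed to join a module
    *,
    min_gene_maximum: int,
    min_related_gene_fold_factor: float,
) -> np.ndarray:
    expressed_umis = umis[expressed_cells_mask, :]
    background_umis = umis[~expressed_cells_mask, :]

    # Per-gene totals and maxima inside the module's expressing cells.
    total_expressed = expressed_umis.sum(axis=0)
    max_expressed = expressed_umis.max(axis=0)
    total_background = background_umis.sum(axis=0)

    # Per-gene fraction of UMIs inside vs. outside the expressing cells.
    expressed_fraction = total_expressed / total_expressed.sum()
    background_fraction = total_background / total_background.sum()

    return (
        allowed_genes_mask
        & (max_expressed >= min_gene_maximum)
        & (expressed_fraction
           >= background_fraction * 2 ** min_related_gene_fold_factor)
    )

# Tiny usage example (hypothetical data): 4 cells x 3 genes, cells 0 and 1
# express the module; only gene 0 is strongly enriched in them.
umis = np.array([
    [8, 1, 0],
    [6, 2, 1],
    [0, 3, 5],
    [1, 2, 4],
])
expressed = np.array([True, True, False, False])
allowed = np.array([True, True, True])
print(related_genes_mask_sketch(umis, expressed, allowed,
                                min_gene_maximum=2,
                                min_related_gene_fold_factor=2.0))  # [ True False False]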
Example #5
def compute_inner_normalized_variance(
    what: Union[str, ut.Matrix] = "__x__",
    *,
    compatible_size: Optional[str] = None,
    downsample_min_samples: int = pr.downsample_min_samples,
    downsample_min_cell_quantile: float = pr.downsample_min_cell_quantile,
    downsample_max_cell_quantile: float = pr.downsample_max_cell_quantile,
    min_gene_total: int = pr.quality_min_gene_total,
    adata: AnnData,
    gdata: AnnData,
    group: Union[str, ut.Vector] = "metacell",
    random_seed: int = pr.random_seed,
) -> None:
    """
    Compute the inner normalized variance (variance / mean) for each gene in each group.

    This is also known as the "index of dispersion" and can serve as a quality measure for
    the groups. An ideal group would contain only cells with "the same" biological state
    and all remaining inner variance would be due to technical sampling noise.

    **Input**

    Annotated ``adata``, where the observations are cells and the variables are genes, where
    ``what`` is a per-variable-per-observation matrix or the name of a per-variable-per-observation
    annotation containing such a matrix.

    In addition, ``gdata`` is assumed to have one observation for each group, and to use the same
    genes as ``adata``.

    **Returns**

    Sets the following in ``gdata``:

    Per-Variable Per-Observation (Gene-Cell) Annotations

        ``inner_variance``
            For each gene and group, the variance of the gene in the group.

        ``inner_normalized_variance``
            For each gene and group, the normalized variance (variance over mean) of the
            gene in the group.

    **Computation Parameters**

    For each group (metacell):

    1. If ``compatible_size`` (default: {compatible_size}) is specified, it should be an
       integer per-observation annotation of the groups, whose value is at most the
       number of grouped cells in the group. Pick a random subset of the cells of
       this size. If ``compatible_size`` is ``None``, use all the cells of the group.

    2. Invoke :py:func:`metacells.tools.downsample.downsample_cells` to downsample the surviving
       cells to the same total number of UMIs, using the ``downsample_min_samples`` (default:
       {downsample_min_samples}), ``downsample_min_cell_quantile`` (default:
       {downsample_min_cell_quantile}), ``downsample_max_cell_quantile`` (default:
       {downsample_max_cell_quantile}) and the ``random_seed`` (default: {random_seed}).

    3. Compute the normalized variance of each gene based on the downsampled data. Set the
       result to ``nan`` for genes with less than ``min_gene_total`` (default: {min_gene_total}).
    """
    cells_data = ut.get_vo_proper(adata, what, layout="row_major")

    if compatible_size is not None:
        compatible_size_of_groups: Optional[ut.NumpyVector] = ut.get_o_numpy(
            gdata, compatible_size, formatter=ut.sizes_description)
    else:
        compatible_size_of_groups = None

    group_of_cells = ut.get_o_numpy(adata,
                                    group,
                                    formatter=ut.groups_description)

    groups_count = np.max(group_of_cells) + 1
    assert groups_count > 0

    assert gdata.n_obs == groups_count
    variance_per_gene_per_group = np.full(gdata.shape, None, dtype="float32")
    normalized_variance_per_gene_per_group = np.full(gdata.shape,
                                                     None,
                                                     dtype="float32")

    for group_index in range(groups_count):
        with ut.log_step(
                "- group",
                group_index,
                formatter=lambda group_index: ut.progress_description(
                    groups_count, group_index, "group"),
        ):
            if compatible_size_of_groups is not None:
                compatible_size_of_group = compatible_size_of_groups[
                    group_index]
            else:
                compatible_size_of_group = None

            _collect_group_data(
                group_index,
                group_of_cells=group_of_cells,
                cells_data=cells_data,
                compatible_size=compatible_size_of_group,
                downsample_min_samples=downsample_min_samples,
                downsample_min_cell_quantile=downsample_min_cell_quantile,
                downsample_max_cell_quantile=downsample_max_cell_quantile,
                min_gene_total=min_gene_total,
                random_seed=random_seed,
                variance_per_gene_per_group=variance_per_gene_per_group,
                normalized_variance_per_gene_per_group=
                normalized_variance_per_gene_per_group,
            )

    ut.set_vo_data(gdata, "inner_variance", variance_per_gene_per_group)
    ut.set_vo_data(gdata, "inner_normalized_variance",
                   normalized_variance_per_gene_per_group)
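
The sketch below shows, with plain NumPy and hypothetical names, the quantity computed per group: for each gene, the variance of its values across the group's cells divided by their mean (the index of dispersion), with genes below min_gene_total left as NaN; the compatible-size subsampling and UMI downsampling steps of the real function are omitted.

# Standalone sketch of the per-group normalized variance (index of dispersion)
# computed by compute_inner_normalized_variance, without subsampling or
# downsampling. Names and data are hypothetical.
import numpy as np

def inner_normalized_variance_sketch(
    umis: np.ndarray,            # cells x genes UMI matrix
    group_of_cells: np.ndarray,  # group index per cell, -1 for ungrouped
    *,
    min_gene_total: int,
) -> np.ndarray:
    groups_count = int(np.max(group_of_cells)) + 1
    genes_count = umis.shape[1]
    result = np.full((groups_count, genes_count), np.nan, dtype="float32")

    for group_index in range(groups_count):
        group_umis = umis[group_of_cells == group_index, :]
        totals = group_umis.sum(axis=0)
        means = group_umis.mean(axis=0)
        variances = group_umis.var(axis=0)
        # Variance over mean; genes with too few UMIs stay NaN.
        strong = (totals >= min_gene_total) & (means > 0)
        result[group_index, strong] = variances[strong] / means[strong]

    return result

# Tiny usage example (hypothetical data): 6 cells in 2 groups, 3 genes.
umis = np.array([
    [4, 0, 2],
    [6, 1, 2],
    [5, 0, 2],
    [1, 9, 0],
    [2, 7, 1],
    [3, 8, 0],
])
groups = np.array([0, 0, 0, 1, 1, 1])
print(inner_normalized_variance_sketch(umis, groups, min_gene_total=3))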