예제 #1
0
    def _validate_anndata(self,
                          adata: Optional[AnnData] = None,
                          copy_if_view: bool = True):
        """Return an AnnData ready for use with this model.

        Parameters
        ----------
        adata
            AnnData to validate; ``self.adata`` is used when ``None``.
        copy_if_view
            If ``adata`` is a view, copy it when True, otherwise raise.

        Returns
        -------
        The validated AnnData (a copy when the input was a view).

        Raises
        ------
        ValueError
            If ``adata`` is a view and ``copy_if_view`` is False.

        Notes
        -----
        When ``"_scvi"`` is missing from ``adata.uns``, the setup stored in
        ``self.scvi_setup_dict_`` is transferred onto it. The registered X
        field is passed to ``_check_nonnegative_integers`` and a warning is
        logged if that check fails. Finally the setup is compared against
        ``self.scvi_setup_dict_`` via ``_check_anndata_setup_equivalence``.
        """
        target = self.adata if adata is None else adata

        if target.is_view:
            if not copy_if_view:
                raise ValueError("Please run `adata = adata.copy()`")
            logger.info("Received view of anndata, making copy.")
            target = target.copy()

        if "_scvi" not in target.uns_keys():
            logger.info("Input adata not setup with scvi. "
                        "attempting to transfer anndata setup")
            transfer_anndata_setup(self.scvi_setup_dict_, target)

        registered_x = get_from_registry(target, _CONSTANTS.X_KEY)
        if not _check_nonnegative_integers(registered_x):
            logger.warning(
                "Make sure the registered X field in anndata contains unnormalized count data."
            )

        _check_anndata_setup_equivalence(self.scvi_setup_dict_, target)
        return target
예제 #2
0
    def validate_field(self, adata: AnnData) -> None:
        """Run the base validation, then warn about suspicious count data.

        The count check only runs when ``self.is_count_data`` is set. A
        warning is emitted if ``_check_nonnegative_integers`` rejects the
        field data. The message names ``adata.X`` when ``self.attr_key`` is
        None, and the corresponding ``adata.layers`` entry otherwise.
        """
        super().validate_field(adata)
        field_data = self.get_field_data(adata)

        if not self.is_count_data:
            return
        if _check_nonnegative_integers(field_data):
            return

        if self.attr_key is None:
            location = "adata.X"
        else:
            location = f"adata.layers[{self.attr_key}]"
        warnings.warn(
            f"{location} does not contain unnormalized count data. "
            "Are you sure this is what you want?")
예제 #3
0
    def validate_field(self, adata: AnnData) -> None:
        """Validate that ``self.attr_key`` is present in ``adata.obsm``.

        Raises
        ------
        KeyError
            If ``self.attr_key`` is not a key of ``adata.obsm``.

        Notes
        -----
        When ``self.is_count_data`` is set and ``_check_nonnegative_integers``
        rejects the stored array, a warning is emitted instead of an error.
        """
        super().validate_field(adata)
        key = self.attr_key
        if key not in adata.obsm:
            raise KeyError(f"{key} not found in adata.obsm.")

        stored = self.get_field_data(adata)
        if self.is_count_data:
            if not _check_nonnegative_integers(stored):
                warnings.warn(
                    f"adata.obsm['{key}'] does not contain unnormalized count data. "
                    "Are you sure this is what you want?"
                )
예제 #4
0
def _setup_protein_expression(
    adata, protein_expression_obsm_key, protein_names_uns_key, batch_key
):
    assert (
        protein_expression_obsm_key in adata.obsm.keys()
    ), "{} is not a valid key in adata.obsm".format(protein_expression_obsm_key)

    logger.info(
        "Using protein expression from adata.obsm['{}']".format(
            protein_expression_obsm_key
        )
    )
    pro_exp = adata.obsm[protein_expression_obsm_key]
    if _check_nonnegative_integers(pro_exp) is False:
        warnings.warn(
            "adata.obsm[{}] does not contain unnormalized count data. Are you sure this is what you want?".format(
                protein_expression_obsm_key
            )
        )
    # setup protein names
    if protein_names_uns_key is None and isinstance(
        adata.obsm[protein_expression_obsm_key], pd.DataFrame
    ):
        logger.info(
            "Using protein names from columns of adata.obsm['{}']".format(
                protein_expression_obsm_key
            )
        )
        protein_names = list(adata.obsm[protein_expression_obsm_key].columns)
    elif protein_names_uns_key is not None:
        logger.info(
            "Using protein names from adata.uns['{}']".format(protein_names_uns_key)
        )
        protein_names = adata.uns[protein_names_uns_key]
    else:
        logger.info("Generating sequential protein names")
        protein_names = np.arange(adata.obsm[protein_expression_obsm_key].shape[1])

    adata.uns["scvi_protein_names"] = protein_names

    # batch mask totalVI
    batch_mask = _get_batch_mask_protein_data(
        adata, protein_expression_obsm_key, batch_key
    )

    # check if it's actually needed
    if np.sum([~b for b in batch_mask]) > 0:
        logger.info("Found batches with missing protein expression")
        adata.uns["_scvi"]["totalvi_batch_mask"] = batch_mask
    return protein_expression_obsm_key
예제 #5
0
    def _validate_anndata(
        self, adata: Optional[AnnData] = None, copy_if_view: bool = True
    ):
        """Run the base validation, then check the registered protein data.

        Returns
        -------
        The AnnData returned by the parent ``_validate_anndata``.

        Raises
        ------
        ValueError
            If no protein expression is registered in
            ``adata.uns["_scvi"]["data_registry"]``, or if its number of
            columns differs from ``self.summary_stats["n_proteins"]``.

        Notes
        -----
        A warning is emitted, not an error, when
        ``_check_nonnegative_integers`` rejects the protein expression.
        """
        adata = super()._validate_anndata(adata, copy_if_view)
        registry = adata.uns["_scvi"]["data_registry"]
        if _CONSTANTS.PROTEIN_EXP_KEY not in registry.keys():
            raise ValueError("No protein data found, please setup or transfer anndata")

        protein_data = get_from_registry(adata, _CONSTANTS.PROTEIN_EXP_KEY)
        if self.summary_stats["n_proteins"] != protein_data.shape[1]:
            raise ValueError(
                "Number of {} in anndata different from when setup_anndata was run. "
                "Please rerun setup_anndata.".format("proteins")
            )
        if not _check_nonnegative_integers(protein_data):
            warnings.warn(
                "Make sure the registered protein expression in anndata contains unnormalized count data."
            )
        return adata
예제 #6
0
def _setup_x(adata, layer):
    if layer is not None:
        assert (layer in adata.layers.keys()
                ), "{} is not a valid key in adata.layers".format(layer)
        logger.info('Using data from adata.layers["{}"]'.format(layer))
        x_loc = "layers"
        x_key = layer
        x = adata.layers[x_key]
    else:
        logger.info("Using data from adata.X")
        x_loc = "X"
        x_key = "None"
        x = adata.X

    if _check_nonnegative_integers(x) is False:
        logger_data_loc = ("adata.X" if layer is None else
                           "adata.layers[{}]".format(layer))
        warnings.warn(
            "{} does not contain unnormalized count data. Are you sure this is what you want?"
            .format(logger_data_loc))

    return x_loc, x_key
예제 #7
0
def _setup_x(adata, layer, use_raw):
    if use_raw and layer:
        logging.warning("use_raw and layer were both passed in. Defaulting to use_raw.")

    # checking layers
    if use_raw:
        if adata.raw is None:
            raise ValueError("use_raw is True but adata.raw is None")
        logger.info("Using data from adata.raw.X")
        x_loc = "X"
        x_key = "None"
        x = adata.raw.X
    elif layer is not None:
        assert (
            layer in adata.layers.keys()
        ), "{} is not a valid key in adata.layers".format(layer)
        logger.info('Using data from adata.layers["{}"]'.format(layer))
        x_loc = "layers"
        x_key = layer
        x = adata.layers[x_key]
    else:
        logger.info("Using data from adata.X")
        x_loc = "X"
        x_key = "None"
        x = adata._X

    if _check_nonnegative_integers(x) is False:
        logger_data_loc = (
            "adata.X" if layer is None else "adata.layers[{}]".format(layer)
        )
        warnings.warn(
            "{} does not contain unnormalized count data. Are you sure this is what you want?".format(
                logger_data_loc
            )
        )

    return x_loc, x_key