def setup_anndata(
    cls,
    adata: AnnData,
    layer: Optional[str] = None,
    **kwargs,
):
    """
    %(summary)s.

    Parameters
    ----------
    %(param_layer)s

    Notes
    -----
    Writes an int64 ``"_indices"`` column (row positions 0 .. n_obs - 1) into
    ``adata.obs`` and registers it alongside the count layer.
    """
    # Must stay the very first statement: ``locals()`` is snapshotted here, so
    # no helper variables may exist yet.
    setup_method_args = cls._get_setup_method_args(**locals())

    # Per-cell row position, provided to the pyro plate for correct minibatching.
    n_cells = adata.n_obs
    adata.obs["_indices"] = np.arange(n_cells, dtype=np.int64)

    registered_fields = [
        LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
        NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"),
    ]
    manager = AnnDataManager(
        fields=registered_fields,
        setup_method_args=setup_method_args,
    )
    manager.register_fields(adata, **kwargs)
    cls.register_manager(manager)
def setup_anndata(
    cls,
    adata: AnnData,
    batch_key: Optional[str] = None,
    labels_key: Optional[str] = None,
    categorical_covariate_keys: Optional[List[str]] = None,
    continuous_covariate_keys: Optional[List[str]] = None,
    layer: Optional[str] = None,
    **kwargs,
):
    """
    %(summary)s.

    Parameters
    ----------
    %(param_batch_key)s
    %(param_labels_key)s
    %(param_layer)s
    %(param_cat_cov_keys)s
    %(param_cont_cov_keys)s
    """
    # Snapshot the call arguments before any other local name is bound.
    setup_method_args = cls._get_setup_method_args(**locals())

    # Field order matters for registration: counts, batch, labels, then covariates.
    count_field = LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True)
    obs_fields = [
        CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
        CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key),
    ]
    covariate_fields = [
        CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys),
        NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys),
    ]

    manager = AnnDataManager(
        fields=[count_field, *obs_fields, *covariate_fields],
        setup_method_args=setup_method_args,
    )
    manager.register_fields(adata, **kwargs)
    cls.register_manager(manager)
def setup_anndata(
    cls,
    adata: AnnData,
    labels_key: Optional[str] = None,
    layer: Optional[str] = None,
    **kwargs,
):
    """
    %(summary)s.

    Parameters
    ----------
    %(param_labels_key)s
    %(param_layer)s
    """
    # Snapshot the call arguments before any other local name is bound.
    setup_method_args = cls._get_setup_method_args(**locals())

    manager = AnnDataManager(
        fields=[
            LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
            CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key),
        ],
        setup_method_args=setup_method_args,
    )
    manager.register_fields(adata, **kwargs)
    cls.register_manager(manager)
def generic_setup_adata_manager(
    adata: AnnData,
    batch_key: Optional[str] = None,
    labels_key: Optional[str] = None,
    categorical_covariate_keys: Optional[List[str]] = None,
    continuous_covariate_keys: Optional[List[str]] = None,
    layer: Optional[str] = None,
    protein_expression_obsm_key: Optional[str] = None,
    protein_names_uns_key: Optional[str] = None,
) -> AnnDataManager:
    """Build an :class:`AnnDataManager`, register ``adata`` with it and return it.

    Registers, in order: a batch field, the count layer, a labels field, and
    the categorical / continuous joint covariate fields. When
    ``protein_expression_obsm_key`` is given, a count-data protein field is
    appended. It uses a batch mask keyed on the batch field's ``attr_key``.

    Unlike the model ``setup_anndata`` methods, no ``setup_method_args`` are
    passed to the manager.

    Returns
    -------
    The manager after ``register_fields(adata)`` has been called on it.
    """
    # Kept as a named object because the protein field reuses its attr_key.
    batch_field = CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key)
    manager_fields = [
        batch_field,
        LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
        CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key),
        CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys),
        NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys),
    ]

    wants_protein = protein_expression_obsm_key is not None
    if wants_protein:
        manager_fields.extend(
            [
                ProteinObsmField(
                    REGISTRY_KEYS.PROTEIN_EXP_KEY,
                    protein_expression_obsm_key,
                    use_batch_mask=True,
                    batch_key=batch_field.attr_key,
                    colnames_uns_key=protein_names_uns_key,
                    is_count_data=True,
                )
            ]
        )

    manager = AnnDataManager(fields=manager_fields)
    manager.register_fields(adata)
    return manager
def _create_indices_adata_manager(adata: AnnData) -> AnnDataManager:
    """Register ``adata`` with counts, a labels field and per-cell row indices.

    Side effect: writes an int64 ``"_indices"`` column (0 .. n_obs - 1) into
    ``adata.obs``. The count layer and labels field are both given ``None`` as
    their key. Returns the manager after ``register_fields(adata)``.
    """
    # Per-cell row position, provided to the pyro plate for correct minibatching.
    adata.obs["_indices"] = np.arange(adata.n_obs, dtype=np.int64)

    manager = AnnDataManager(
        fields=[
            LayerField(REGISTRY_KEYS.X_KEY, None, is_count_data=True),
            CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, None),
            NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"),
        ]
    )
    manager.register_fields(adata)
    return manager
def setup_anndata(
    cls,
    adata: AnnData,
    **kwargs,
) -> Optional[AnnData]:
    """Register ``adata`` for this model using only the count matrix.

    A single count-data layer field is registered with ``layer=None``. No
    batch, label or covariate fields are set up. Extra ``kwargs`` are forwarded
    to ``AnnDataManager.register_fields``. Returns nothing.
    """
    # Snapshot the call arguments before any other local name is bound.
    setup_method_args = cls._get_setup_method_args(**locals())

    manager = AnnDataManager(
        fields=[LayerField(REGISTRY_KEYS.X_KEY, None, is_count_data=True)],
        setup_method_args=setup_method_args,
    )
    manager.register_fields(adata, **kwargs)
    cls.register_manager(manager)
def setup_anndata(
    cls,
    adata: AnnData,
    **kwargs,
) -> Optional[AnnData]:
    """Register ``adata`` with counts, a labels field and per-cell row indices.

    Side effect: writes an int64 ``"_indices"`` column (0 .. n_obs - 1) into
    ``adata.obs``. The count layer and labels field are both given ``None`` as
    their key. Extra ``kwargs`` are forwarded to
    ``AnnDataManager.register_fields``. Returns nothing.
    """
    # Must stay the very first statement: ``locals()`` is snapshotted here.
    setup_method_args = cls._get_setup_method_args(**locals())

    # Per-cell row position, provided to the pyro plate for correct minibatching.
    adata.obs["_indices"] = np.arange(adata.n_obs, dtype=np.int64)

    manager = AnnDataManager(
        fields=[
            LayerField(REGISTRY_KEYS.X_KEY, None, is_count_data=True),
            CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, None),
            NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"),
        ],
        setup_method_args=setup_method_args,
    )
    manager.register_fields(adata, **kwargs)
    cls.register_manager(manager)
def setup_anndata(
    cls,
    adata: AnnData,
    protein_expression_obsm_key: str,
    protein_names_uns_key: Optional[str] = None,
    batch_key: Optional[str] = None,
    layer: Optional[str] = None,
    size_factor_key: Optional[str] = None,
    categorical_covariate_keys: Optional[List[str]] = None,
    continuous_covariate_keys: Optional[List[str]] = None,
    **kwargs,
) -> Optional[AnnData]:
    """
    %(summary)s.

    Parameters
    ----------
    %(param_adata)s
    protein_expression_obsm_key
        key in `adata.obsm` for protein expression data.
    protein_names_uns_key
        key in `adata.uns` for protein names. If None, will use the column names of
        `adata.obsm[protein_expression_obsm_key]` if it is a DataFrame, else will
        assign sequential names to proteins.
    %(param_batch_key)s
    %(param_layer)s
    %(param_size_factor_key)s
    %(param_cat_cov_keys)s
    %(param_cont_cov_keys)s

    Returns
    -------
    %(returns)s
    """
    # Must stay the very first statement: ``locals()`` is snapshotted here and
    # passed to AnnDataManager as ``setup_method_args``, so no other local
    # names may exist yet.
    setup_method_args = cls._get_setup_method_args(**locals())

    # Built up front because the protein field below needs its attr_key to
    # build the per-batch protein mask.
    batch_field = CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key)
    anndata_fields = [
        LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
        # Default labels field for compatibility with TOTALVAE
        CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, None),
        batch_field,
        # Size factors are optional (required=False).
        NumericalObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False),
        CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys),
        NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys),
        ProteinObsmField(
            REGISTRY_KEYS.PROTEIN_EXP_KEY,
            protein_expression_obsm_key,
            use_batch_mask=True,
            batch_key=batch_field.attr_key,
            colnames_uns_key=protein_names_uns_key,
            is_count_data=True,
        ),
    ]
    adata_manager = AnnDataManager(
        fields=anndata_fields, setup_method_args=setup_method_args
    )
    adata_manager.register_fields(adata, **kwargs)
    cls.register_manager(adata_manager)