def __init__(
    self,
    model,
    adata: anndata.AnnData,
    shuffle=False,
    indices=None,
    use_cuda=True,
    batch_size=128,
    data_loader_kwargs=None,
):
    """Build a loader that yields pre-batched tensors from ``adata``.

    Parameters
    ----------
    model
        Model object; stored on the instance. The tensors this loader must
        provide are declared by ``self._data_and_attributes`` (set elsewhere
        on the class).
    adata
        AnnData object that has been processed with ``setup_anndata()``.
    shuffle
        Whether batches are drawn in shuffled order when iterating the full
        dataset (``indices is None``).
    indices
        Optional subset of observations to iterate over; either integer
        positions or a boolean mask of length ``len(dataset)``.
    use_cuda
        Flag stored on the instance; consumed by other methods.
    batch_size
        Number of observations per batch emitted by the sampler.
    data_loader_kwargs
        Extra keyword arguments forwarded to ``DataLoader``. Defaults to an
        empty dict (was a mutable default argument; now ``None``-guarded).

    Raises
    ------
    ValueError
        If ``setup_anndata()`` was not run on ``adata``, or a tensor the
        model requires was not registered during setup.
    """
    self.model = model
    if "_scvi" not in adata.uns.keys():
        raise ValueError(
            "Please run setup_anndata() on your anndata object first.")
    for key in self._data_and_attributes.keys():
        if key not in adata.uns["_scvi"]["data_registry"].keys():
            raise ValueError(
                "{} required for model but not included when setup_anndata was run"
                .format(key))
    self.dataset = ScviDataset(adata,
                               getitem_tensors=self._data_and_attributes)
    self.to_monitor = []
    self.use_cuda = use_cuda

    if indices is None:
        # Full dataset: honor the shuffle flag directly instead of
        # duplicating the kwargs dict in two branches (original had an
        # if/else whose branches differed only in the "shuffle" value).
        sampler_kwargs = {
            "indices": np.arange(len(self.dataset)),
            "batch_size": batch_size,
            "shuffle": shuffle,
        }
    else:
        # Accept a boolean mask as well as integer positions.
        if hasattr(indices, "dtype") and indices.dtype is np.dtype("bool"):
            indices = np.where(indices)[0].ravel()
        indices = np.asarray(indices)
        # NOTE(review): shuffle is forced on in this branch, even when
        # shuffle=False was passed — preserved from the original; confirm
        # this is intentional (the sibling loader honors the flag).
        sampler_kwargs = {
            "indices": indices,
            "batch_size": batch_size,
            "shuffle": True,
        }
    self.sampler_kwargs = sampler_kwargs
    sampler = BatchSampler(**self.sampler_kwargs)
    self.data_loader_kwargs = copy.copy(data_loader_kwargs or {})
    # The sampler already yields whole batches of indices, so the
    # DataLoader's own batching must be disabled.
    self.data_loader_kwargs.update({"sampler": sampler, "batch_size": None})
    self.data_loader = DataLoader(self.dataset, **self.data_loader_kwargs)
    # NOTE(review): ``self.indices`` is never assigned in this method —
    # presumably a property defined on the class; verify before relying
    # on it here.
    self.original_indices = self.indices
def __init__(
    self,
    adata: anndata.AnnData,
    shuffle=False,
    indices=None,
    batch_size=128,
    data_and_attributes: Optional[dict] = None,
    **data_loader_kwargs,
):
    """Initialize a DataLoader over a setup AnnData object.

    Parameters
    ----------
    adata
        AnnData object processed with ``setup_anndata()``.
    shuffle
        Whether batches of indices are drawn in shuffled order.
    indices
        Optional subset of observations to iterate over; either integer
        positions or a boolean mask.
    batch_size
        Number of observations per batch emitted by the sampler.
    data_and_attributes
        Optional mapping of registered tensor names (validated against the
        setup registry) to retrieve; ``None`` retrieves everything.
    **data_loader_kwargs
        Extra keyword arguments forwarded to ``torch.utils.data.DataLoader``.

    Raises
    ------
    ValueError
        If ``setup_anndata()`` was not run, or a requested tensor was not
        registered during setup.
    """
    if "_scvi" not in adata.uns.keys():
        raise ValueError(
            "Please run setup_anndata() on your anndata object first.")

    if data_and_attributes is not None:
        # Every requested tensor must have been registered during setup.
        registry_keys = adata.uns["_scvi"]["data_registry"].keys()
        for required_key in data_and_attributes.keys():
            if required_key not in registry_keys:
                raise ValueError(
                    "{} required for model but not included when setup_anndata was run"
                    .format(required_key))

    self.dataset = ScviDataset(adata, getitem_tensors=data_and_attributes)

    # Normalize the observation selection: default to everything, and
    # convert a boolean mask into integer positions.
    if indices is None:
        selected = np.arange(len(self.dataset))
    else:
        selected = indices
        if hasattr(selected, "dtype") and selected.dtype is np.dtype("bool"):
            selected = np.where(selected)[0].ravel()
        selected = np.asarray(selected)
    self.indices = selected

    self.sampler_kwargs = {
        "indices": selected,
        "batch_size": batch_size,
        "shuffle": shuffle,
    }
    batched_sampler = BatchSampler(**self.sampler_kwargs)

    self.data_loader_kwargs = copy.copy(data_loader_kwargs)
    # The sampler emits whole batches of indices, so the DataLoader's own
    # batching must be disabled.
    self.data_loader_kwargs.update({
        "sampler": batched_sampler,
        "batch_size": None
    })
    super().__init__(self.dataset, **self.data_loader_kwargs)
def test_scvidataset_getitem():
    """Exercise ScviDataset.__getitem__ across tensor selections and storage formats."""
    adata = synthetic_iid()
    setup_anndata(
        adata,
        batch_key="batch",
        labels_key="labels",
        protein_expression_obsm_key="protein_expression",
        protein_names_uns_key="protein_names",
    )
    # check that we can successfully pass in a list of tensors to get
    tensors_to_get = ["batch_indices", "local_l_var"]
    bd = ScviDataset(adata, getitem_tensors=tensors_to_get)
    np.testing.assert_array_equal(tensors_to_get, list(bd[1].keys()))

    # check that we can successfully pass in a dict of tensors and their associated types
    # np.int was deprecated in NumPy 1.20 and removed in 1.24; np.int64 is
    # the explicit dtype the assertion below expects on every platform.
    bd = ScviDataset(adata,
                     getitem_tensors={
                         "X": np.int64,
                         "local_l_var": np.float64
                     })
    assert bd[1]["X"].dtype == np.int64
    assert bd[1]["local_l_var"].dtype == np.float64

    # check that by default we get all the registered tensors
    bd = ScviDataset(adata)
    all_registered_tensors = list(adata.uns["_scvi"]["data_registry"].keys())
    np.testing.assert_array_equal(all_registered_tensors, list(bd[1].keys()))
    assert bd[1]["X"].shape[0] == bd.n_vars

    # check that ScviDataset returns numpy array
    adata1 = synthetic_iid()
    bd = ScviDataset(adata1)
    for key, value in bd[1].items():
        # exact type on purpose: isinstance would accept np.matrix, which is
        # precisely what the sparse-input cases below must rule out
        assert type(value) is np.ndarray

    # check ScviDataset returns numpy array counts were sparse
    adata = synthetic_iid(run_setup_anndata=False)
    adata.X = sparse.csr_matrix(adata.X)
    setup_anndata(adata)
    bd = ScviDataset(adata)
    for key, value in bd[1].items():
        assert type(value) is np.ndarray

    # check ScviDataset returns numpy array if pro exp was sparse
    adata = synthetic_iid(run_setup_anndata=False)
    adata.obsm["protein_expression"] = sparse.csr_matrix(
        adata.obsm["protein_expression"])
    setup_anndata(adata,
                  batch_key="batch",
                  protein_expression_obsm_key="protein_expression")
    bd = ScviDataset(adata)
    for key, value in bd[1].items():
        assert type(value) is np.ndarray

    # check pro exp is being returned as numpy array even if its DF
    adata = synthetic_iid(run_setup_anndata=False)
    adata.obsm["protein_expression"] = pd.DataFrame(
        adata.obsm["protein_expression"], index=adata.obs_names)
    setup_anndata(adata,
                  batch_key="batch",
                  protein_expression_obsm_key="protein_expression")
    bd = ScviDataset(adata)
    for key, value in bd[1].items():
        assert type(value) is np.ndarray