Пример #1
0
    def __init__(
            self,
            model,
            adata: anndata.AnnData,
            shuffle=False,
            indices=None,
            use_cuda=True,
            batch_size=128,
            data_loader_kwargs=None,
    ):
        """Build a batched ``DataLoader`` over an AnnData object set up for scvi.

        Parameters
        ----------
        model
            Model object; stored on ``self.model`` (not otherwise used here).
        adata
            AnnData object that must already have gone through
            ``setup_anndata()``, i.e. it carries ``adata.uns["_scvi"]``.
        shuffle
            Whether batches are shuffled. Only honoured when ``indices`` is
            ``None``; an explicit subset is always shuffled (see note below).
        indices
            Optional subset of observations: integer positions or a boolean
            mask (numpy array, list or tuple). ``None`` selects every
            observation of the dataset.
        use_cuda
            Stored on ``self.use_cuda``; not used by this constructor.
        batch_size
            Number of observations per batch, forwarded to ``BatchSampler``.
        data_loader_kwargs
            Extra keyword arguments for ``DataLoader``. The mapping is copied,
            never mutated. ``None`` (default) means no extra arguments.

        Raises
        ------
        ValueError
            If ``setup_anndata()`` was not run on ``adata``, or if a key of
            ``self._data_and_attributes`` is missing from its data registry.
        """
        self.model = model
        if "_scvi" not in adata.uns:
            raise ValueError(
                "Please run setup_anndata() on your anndata object first.")
        for key in self._data_and_attributes.keys():
            if key not in adata.uns["_scvi"]["data_registry"]:
                raise ValueError(
                    "{} required for model but not included when setup_anndata was run"
                    .format(key))
        self.dataset = ScviDataset(adata,
                                   getitem_tensors=self._data_and_attributes)
        self.to_monitor = []
        self.use_cuda = use_cuda

        if indices is None:
            indices = np.arange(len(self.dataset))
            shuffle_batches = bool(shuffle)
        else:
            # Convert first so that plain lists/tuples of booleans are also
            # recognised as masks; checking ``dtype`` before conversion let
            # them through as a raw boolean array instead of positions.
            indices = np.asarray(indices)
            if indices.dtype == np.dtype("bool"):
                indices = np.where(indices)[0]
            # NOTE(review): an explicit subset is always shuffled, ignoring
            # ``shuffle``; preserved for backward compatibility — confirm intent.
            shuffle_batches = True

        self.sampler_kwargs = {
            "indices": indices,
            "batch_size": batch_size,
            "shuffle": shuffle_batches,
        }
        sampler = BatchSampler(**self.sampler_kwargs)
        self.data_loader_kwargs = (
            {} if data_loader_kwargs is None else copy.copy(data_loader_kwargs))
        # The sampler already yields batched indices, so DataLoader's own
        # automatic batching is turned off with batch_size=None.
        self.data_loader_kwargs.update({
            "sampler": sampler,
            "batch_size": None
        })
        self.data_loader = DataLoader(self.dataset, **self.data_loader_kwargs)
        # NOTE(review): ``self.indices`` is not assigned in this method;
        # presumably a property of the enclosing class — confirm.
        self.original_indices = self.indices
Пример #2
0
    def __init__(
        self,
        adata: anndata.AnnData,
        shuffle=False,
        indices=None,
        batch_size=128,
        data_and_attributes: Optional[dict] = None,
        **data_loader_kwargs,
    ):
        """Initialise a DataLoader that yields batches of an scvi-registered AnnData.

        Parameters
        ----------
        adata
            AnnData object that must already have gone through
            ``setup_anndata()``, i.e. it carries ``adata.uns["_scvi"]``.
        shuffle
            Whether batches are shuffled; forwarded to ``BatchSampler``.
        indices
            Optional subset of observations: integer positions or a boolean
            mask (numpy array, list or tuple). ``None`` selects every
            observation of the dataset. The resolved positions are stored on
            ``self.indices``.
        batch_size
            Number of observations per batch, forwarded to ``BatchSampler``.
        data_and_attributes
            Mapping whose keys name registered tensors to load; passed to
            ``ScviDataset`` as ``getitem_tensors``. Values are presumably
            target dtypes — confirm against ``ScviDataset``. When ``None`` no
            registry check is done here.
        **data_loader_kwargs
            Extra keyword arguments for the parent ``DataLoader``; ``sampler``
            and ``batch_size`` are always overridden.

        Raises
        ------
        ValueError
            If ``setup_anndata()`` was not run on ``adata``, or if a key of
            ``data_and_attributes`` is missing from its data registry.
        """
        if "_scvi" not in adata.uns:
            raise ValueError(
                "Please run setup_anndata() on your anndata object first.")

        if data_and_attributes is not None:
            data_registry = adata.uns["_scvi"]["data_registry"]
            for key in data_and_attributes.keys():
                if key not in data_registry:
                    raise ValueError(
                        "{} required for model but not included when setup_anndata was run"
                        .format(key))

        self.dataset = ScviDataset(adata, getitem_tensors=data_and_attributes)

        if indices is None:
            indices = np.arange(len(self.dataset))
        else:
            # Convert first so that plain lists/tuples of booleans are also
            # recognised as masks; checking ``dtype`` before conversion let
            # them through as a raw boolean array instead of positions.
            indices = np.asarray(indices)
            if indices.dtype == np.dtype("bool"):
                indices = np.where(indices)[0]

        self.indices = indices
        self.sampler_kwargs = {
            "batch_size": batch_size,
            "shuffle": shuffle,
            "indices": indices,
        }
        sampler = BatchSampler(**self.sampler_kwargs)
        # ``**data_loader_kwargs`` is already a fresh dict; build a new one with
        # the overrides. The sampler yields batched indices, so DataLoader's own
        # automatic batching is turned off with batch_size=None.
        self.data_loader_kwargs = dict(
            data_loader_kwargs, sampler=sampler, batch_size=None)

        super().__init__(self.dataset, **self.data_loader_kwargs)
Пример #3
0
def test_scvidataset_getitem():
    """Check ScviDataset tensor selection, dtype casting and ndarray outputs."""

    def assert_items_are_ndarrays(dataset):
        # Exact type check on purpose: np.matrix subclasses np.ndarray, so an
        # isinstance check would also accept a dense matrix from a sparse input.
        for value in dataset[1].values():
            assert type(value) is np.ndarray

    adata = synthetic_iid()
    setup_anndata(
        adata,
        batch_key="batch",
        labels_key="labels",
        protein_expression_obsm_key="protein_expression",
        protein_names_uns_key="protein_names",
    )
    # a list of tensor names selects exactly those tensors, in that order
    tensors_to_get = ["batch_indices", "local_l_var"]
    bd = ScviDataset(adata, getitem_tensors=tensors_to_get)
    np.testing.assert_array_equal(tensors_to_get, list(bd[1].keys()))

    # a dict maps tensor names to the dtype each one is returned as.
    # np.int (a deprecated alias of builtin int) was removed in NumPy 1.24,
    # so request np.int64 explicitly, matching the assertion below.
    bd = ScviDataset(adata,
                     getitem_tensors={
                         "X": np.int64,
                         "local_l_var": np.float64
                     })
    item = bd[1]
    assert item["X"].dtype == np.int64
    assert item["local_l_var"].dtype == np.float64

    # by default every registered tensor is returned
    bd = ScviDataset(adata)
    all_registered_tensors = list(adata.uns["_scvi"]["data_registry"].keys())
    np.testing.assert_array_equal(all_registered_tensors, list(bd[1].keys()))
    assert bd[1]["X"].shape[0] == bd.n_vars

    # ScviDataset returns numpy arrays
    adata1 = synthetic_iid()
    assert_items_are_ndarrays(ScviDataset(adata1))

    # ... even when the counts matrix (adata.X) is sparse
    adata = synthetic_iid(run_setup_anndata=False)
    adata.X = sparse.csr_matrix(adata.X)
    setup_anndata(adata)
    assert_items_are_ndarrays(ScviDataset(adata))

    # ... even when the protein expression in obsm is sparse
    adata = synthetic_iid(run_setup_anndata=False)
    adata.obsm["protein_expression"] = sparse.csr_matrix(
        adata.obsm["protein_expression"])
    setup_anndata(adata,
                  batch_key="batch",
                  protein_expression_obsm_key="protein_expression")
    assert_items_are_ndarrays(ScviDataset(adata))

    # ... even when the protein expression in obsm is a pandas DataFrame
    adata = synthetic_iid(run_setup_anndata=False)
    adata.obsm["protein_expression"] = pd.DataFrame(
        adata.obsm["protein_expression"], index=adata.obs_names)
    setup_anndata(adata,
                  batch_key="batch",
                  protein_expression_obsm_key="protein_expression")
    assert_items_are_ndarrays(ScviDataset(adata))