Exemplo n.º 1
0
    def __init__(
        self,
        model,
        gene_dataset: GeneExpressionDataset,
        shuffle=False,
        indices=None,
        use_cuda=True,
        data_loader_kwargs=None,
    ):
        """Store the model and dataset and build the ``DataLoader`` over the dataset.

        Note carried over from the original docstring: "When added to
        annotation, has a private name attribute".

        :param model: kept as ``self.model``; not otherwise used in this method.
        :param gene_dataset: dataset passed to the sampler and to the
            ``DataLoader``; its ``collate_fn_builder()`` provides the collate
            function.
        :param shuffle: when ``indices`` is ``None``, selects a
            ``RandomSampler`` (truthy) or a ``SequentialSampler`` (falsy).
        :param indices: optional subset of the dataset. An object whose
            ``dtype`` is the NumPy boolean dtype is treated as a mask and
            converted to integer positions. The subset is always served
            through a ``SubsetRandomSampler``.
        :param use_cuda: kept as ``self.use_cuda``.
        :param data_loader_kwargs: extra keyword arguments for the
            ``DataLoader``. A shallow copy is stored on the instance, and its
            ``"collate_fn"`` and ``"sampler"`` entries are overwritten.
            ``None`` (the default) means no extra arguments.
        :raises ValueError: if ``indices`` is given together with a truthy
            ``shuffle``.
        """
        self.model = model
        self.gene_dataset = gene_dataset
        self.to_monitor = []
        self.use_cuda = use_cuda

        if indices is not None and shuffle:
            raise ValueError("indices is mutually exclusive with shuffle")
        if indices is None:
            if shuffle:
                sampler = RandomSampler(gene_dataset)
            else:
                sampler = SequentialSampler(gene_dataset)
        else:
            # Compare dtypes by equality rather than identity. The isinstance
            # guard keeps non-NumPy dtype objects out of the comparison.
            indices_dtype = getattr(indices, "dtype", None)
            if isinstance(indices_dtype, np.dtype) and indices_dtype == np.dtype(
                "bool"
            ):
                # Boolean mask -> flat array of the selected positions.
                indices = np.where(indices)[0].ravel()
            sampler = SubsetRandomSampler(indices)

        # Work on a private copy so the caller's dict is never modified; a
        # None default avoids sharing one dict object between all calls.
        if data_loader_kwargs is None:
            self.data_loader_kwargs = {}
        else:
            self.data_loader_kwargs = copy.copy(data_loader_kwargs)
        self.data_loader_kwargs.update(
            {"collate_fn": gene_dataset.collate_fn_builder(), "sampler": sampler}
        )
        self.data_loader = DataLoader(gene_dataset, **self.data_loader_kwargs)
Exemplo n.º 2
0
    def test_collate_add(self):
        """Check that extra attributes requested from ``collate_fn_builder``
        are returned after the five default outputs.

        Row ``i`` of every input array encodes ``i``, so the batch built from
        cells 1 and 2 has easily predictable values.
        """
        n_cells = 25
        n_genes = 2
        protein_labels = ["A", "B", "C"]

        # Each array is built independently so that no buffer is shared
        # between the different dataset inputs.
        expression = np.ones((n_cells, n_genes)) * np.arange(0, n_cells).reshape(
            (-1, 1)
        )
        batches = np.arange(0, n_cells).reshape((-1, 1))
        coords = np.arange(0, n_cells).reshape((-1, 1))
        # Entry (i, j) equals 1 + i + j.
        protein_counts = (
            np.ones((n_cells, len(protein_labels)))
            + np.arange(0, n_cells).reshape((-1, 1))
            + np.arange(0, len(protein_labels))
        )

        protein_measurement = CellMeasurement(
            name="proteins",
            data=protein_counts,
            columns_attr_name="protein_names",
            columns=protein_labels,
        )
        dataset = GeneExpressionDataset()
        dataset.populate_from_data(
            expression,
            batch_indices=batches,
            cell_attributes_dict={"x_coords": coords},
            Ys=[protein_measurement],
        )

        extra_attributes = {"x_coords": np.float32, "proteins": np.float32}
        collate_fn = dataset.collate_fn_builder(
            add_attributes_and_types=extra_attributes
        )

        # Seven outputs are expected: the five defaults, then the two extras.
        _, _, _, _, _, coords_out, proteins_out = collate_fn([1, 2])

        self.assertListEqual(coords_out.tolist(), [[1.0], [2.0]])
        self.assertListEqual(
            proteins_out.tolist(), [[2.0, 3.0, 4.0], [3.0, 4.0, 5.0]]
        )
Exemplo n.º 3
0
    def __init__(
        self,
        model: TOTALVI,
        gene_dataset: GeneExpressionDataset,
        shuffle: bool = False,
        indices: Optional[np.ndarray] = None,
        use_cuda: bool = True,
        data_loader_kwargs: Optional[dict] = None,
    ):
        """Initialise the parent class, then rebuild the ``DataLoader`` with a
        collate function that also requests ``"protein_expression"``.

        :param model: forwarded unchanged to the parent ``__init__``.
        :param gene_dataset: forwarded to the parent ``__init__`` and used to
            build the replacement collate function and ``DataLoader``.
        :param shuffle: forwarded unchanged to the parent ``__init__``.
        :param indices: forwarded unchanged to the parent ``__init__``.
        :param use_cuda: forwarded unchanged to the parent ``__init__``.
        :param data_loader_kwargs: extra ``DataLoader`` keyword arguments,
            forwarded to the parent ``__init__``. ``None`` (the default) is
            forwarded as an empty dict.
        """
        # A None default avoids a dict object shared between calls. The
        # parent always receives a real dict, so this works with a parent
        # that does not itself accept None.
        if data_loader_kwargs is None:
            data_loader_kwargs = {}

        super().__init__(
            model,
            gene_dataset,
            shuffle=shuffle,
            indices=indices,
            use_cuda=use_cuda,
            data_loader_kwargs=data_loader_kwargs,
        )
        # Add protein tensor as another tensor to be loaded
        # (presumably appended to each collated batch — confirm against
        # collate_fn_builder).
        self.data_loader_kwargs.update(
            {
                "collate_fn": gene_dataset.collate_fn_builder(
                    {"protein_expression": np.float32}
                )
            }
        )
        # The collate function changed, so the loader is rebuilt from the
        # updated keyword arguments.
        self.data_loader = DataLoader(gene_dataset, **self.data_loader_kwargs)
Exemplo n.º 4
0
    def test_collate_normal(self):
        """Check the default collate function on a two-cell batch.

        Row ``i`` of the expression matrix and of the batch indices both
        encode ``i``, so selecting cells 1 and 2 must return rows filled with
        1 and 2.
        """
        n_cells = 25
        row_values = np.arange(0, n_cells).reshape((-1, 1))
        expression = np.ones((n_cells, 2)) * row_values
        batches = np.arange(0, n_cells).reshape((-1, 1))

        dataset = GeneExpressionDataset()
        dataset.populate_from_data(expression, batch_indices=batches)

        collate = dataset.collate_fn_builder()
        # The default collate function is expected to return five outputs.
        x_out, _, _, batch_out, _ = collate([1, 2])

        self.assertListEqual(x_out.tolist(), [[1.0, 1.0], [2.0, 2.0]])
        self.assertListEqual(batch_out.tolist(), [[1], [2]])