def __init__(
    self,
    model,
    gene_dataset: GeneExpressionDataset,
    shuffle=False,
    indices=None,
    use_cuda=True,
    data_loader_kwargs=None,
):
    """Store the model and dataset and build a ``DataLoader`` over the dataset.

    The sampler handed to the ``DataLoader`` is chosen as follows:

    * ``indices`` given -> ``SubsetRandomSampler(indices)``
    * ``shuffle`` true  -> ``RandomSampler(gene_dataset)``
    * otherwise         -> ``SequentialSampler(gene_dataset)``

    :param model: object stored as ``self.model``; not used further here.
    :param gene_dataset: dataset to load from; must provide
        ``collate_fn_builder()``.
    :param shuffle: request random ordering over the whole dataset.
    :param indices: optional subset of cells, either integer positions or a
        NumPy boolean mask (a mask is converted to integer positions).
    :param use_cuda: stored as ``self.use_cuda``; not used further here.
    :param data_loader_kwargs: extra keyword arguments for ``DataLoader``.
        The mapping is shallow-copied, never mutated; its ``"collate_fn"``
        and ``"sampler"`` entries are always overwritten. ``None`` means
        no extra arguments.
    :raises ValueError: if both ``indices`` and ``shuffle`` are supplied.
    """
    self.model = model
    self.gene_dataset = gene_dataset
    self.to_monitor = []
    self.use_cuda = use_cuda

    if indices is not None and shuffle:
        raise ValueError("indices is mutually exclusive with shuffle")

    if indices is None:
        if shuffle:
            sampler = RandomSampler(gene_dataset)
        else:
            sampler = SequentialSampler(gene_dataset)
    else:
        # Compare dtypes by value rather than identity, and only when the
        # attribute really is a NumPy dtype, so dtype objects from other
        # libraries are left alone.
        dtype = getattr(indices, "dtype", None)
        if isinstance(dtype, np.dtype) and dtype == np.bool_:
            # Boolean mask -> flat array of the selected positions.
            indices = np.where(indices)[0].ravel()
        sampler = SubsetRandomSampler(indices)

    # A None default avoids sharing one dict object across every call; the
    # caller's mapping is copied so the update below cannot leak back out.
    if data_loader_kwargs is None:
        self.data_loader_kwargs = {}
    else:
        self.data_loader_kwargs = copy.copy(data_loader_kwargs)
    self.data_loader_kwargs.update(
        {"collate_fn": gene_dataset.collate_fn_builder(), "sampler": sampler}
    )
    self.data_loader = DataLoader(gene_dataset, **self.data_loader_kwargs)
def test_collate_add(self):
    """Check that extra attributes requested from ``collate_fn_builder`` are
    appended to the collated batch.

    The dataset has 25 cells; row ``i`` of every array is derived from ``i``,
    so the expected values for cells 1 and 2 can be written down directly.
    """

    def cell_column():
        # Fresh (25, 1) column holding 0..24 on every call, so the arrays
        # handed to the dataset never alias each other.
        return np.arange(0, 25).reshape((-1, 1))

    expression = np.ones((25, 2)) * cell_column()
    batches = cell_column()
    coords = cell_column()
    # Row i of the protein matrix is [i + 1, i + 2, i + 3].
    protein_values = np.ones((25, 3)) + cell_column() + np.arange(0, 3)
    protein_labels = ["A", "B", "C"]

    protein_measurement = CellMeasurement(
        name="proteins",
        data=protein_values,
        columns_attr_name="protein_names",
        columns=protein_labels,
    )
    dataset = GeneExpressionDataset()
    dataset.populate_from_data(
        expression,
        batch_indices=batches,
        cell_attributes_dict={"x_coords": coords},
        Ys=[protein_measurement],
    )

    extra_attributes = {"x_coords": np.float32, "proteins": np.float32}
    collate = dataset.collate_fn_builder(
        add_attributes_and_types=extra_attributes
    )

    # Exactly seven items are expected back; only the last two (the
    # requested extras) are inspected.
    _, _, _, _, _, coords_out, proteins_out = collate([1, 2])

    self.assertListEqual(coords_out.tolist(), [[1.0], [2.0]])
    self.assertListEqual(
        proteins_out.tolist(), [[2.0, 3.0, 4.0], [3.0, 4.0, 5.0]]
    )
def __init__(
    self,
    model: TOTALVI,
    gene_dataset: GeneExpressionDataset,
    shuffle: bool = False,
    indices: Optional[np.ndarray] = None,
    use_cuda: bool = True,
    data_loader_kwargs: Optional[dict] = None,
):
    """Initialise via the parent class, then rebuild the ``DataLoader`` so
    each batch also carries the ``"protein_expression"`` attribute.

    :param model: forwarded unchanged to the parent initializer.
    :param gene_dataset: dataset to load from; must provide
        ``collate_fn_builder``.
    :param shuffle: forwarded unchanged to the parent initializer.
    :param indices: forwarded unchanged to the parent initializer.
    :param use_cuda: forwarded unchanged to the parent initializer.
    :param data_loader_kwargs: extra keyword arguments for ``DataLoader``;
        ``None`` means no extra arguments. The ``"collate_fn"`` entry is
        always replaced here.
    """
    # A None default avoids a dict object shared across calls; the parent is
    # always handed a real dict, so it never has to deal with None itself.
    if data_loader_kwargs is None:
        data_loader_kwargs = {}

    super().__init__(
        model,
        gene_dataset,
        shuffle=shuffle,
        indices=indices,
        use_cuda=use_cuda,
        data_loader_kwargs=data_loader_kwargs,
    )

    # Add protein tensor as another tensor to be loaded.
    # self.data_loader_kwargs is expected to have been set by the parent
    # initializer above.
    self.data_loader_kwargs["collate_fn"] = gene_dataset.collate_fn_builder(
        {"protein_expression": np.float32}
    )
    self.data_loader = DataLoader(gene_dataset, **self.data_loader_kwargs)
def test_collate_normal(self):
    """Check the default collate function on a 25-cell, 2-gene dataset.

    Row ``i`` of the expression matrix is ``[i, i]`` and cell ``i`` carries
    batch index ``i``, so collating cells 1 and 2 has a known result.
    """
    expression = np.ones((25, 2)) * np.arange(0, 25).reshape((-1, 1))
    batches = np.arange(0, 25).reshape((-1, 1))

    dataset = GeneExpressionDataset()
    dataset.populate_from_data(expression, batch_indices=batches)
    collate = dataset.collate_fn_builder()

    # Exactly five items are expected back; only the first and the fourth
    # are inspected.
    x_out, _, _, batch_out, _ = collate([1, 2])

    expected_x = [[1.0, 1.0], [2.0, 2.0]]
    expected_batch = [[1], [2]]
    self.assertListEqual(x_out.tolist(), expected_x)
    self.assertListEqual(batch_out.tolist(), expected_batch)