示例#1
0
 def __init__(self,
              train_biom,
              test_biom,
              valid_biom,
              metadata=None,
              batch_category=None,
              batch_size=10,
              num_workers=1):
     """Store data-module configuration and optional sample metadata.

     Parameters
     ----------
     train_biom, test_biom, valid_biom
         Table sources for the train/test/validation splits; ``train_biom``
         is passed to ``load_table`` here (presumably BIOM file paths —
         confirm against callers).
     metadata : optional
         Anything ``pd.read_table`` accepts (e.g. a path).  All columns are
         read as strings and the first column becomes the index.
     batch_category : optional
         Passed to ``BiomDataset`` as ``batch_category`` (presumably the
         metadata column holding batch labels).  Selects ``collate_batch_f``
         when given, ``collate_single_f`` otherwise.
     batch_size : int
         Stored as ``self.batch_size``.
     num_workers : int
         Stored as ``self.num_workers``.
     """
     super().__init__()
     self.train_biom = train_biom
     self.test_biom = test_biom
     self.val_biom = valid_biom
     self.batch_size = batch_size
     self.num_workers = num_workers

     if metadata is None:
         self.metadata = None
     else:
         md = pd.read_table(metadata, dtype=str)
         # Use the first column as the sample index.
         self.metadata = md.set_index(md.columns[0])

     self.batch_category = batch_category
     # Always define ``batch_categories`` so it can be checked against None
     # rather than raising AttributeError when no batch category is used.
     self.batch_categories = None
     if batch_category is None:
         self.collate_f = collate_single_f
     else:
         self.collate_f = collate_batch_f
         # Build the training dataset once to collect its class mapping.
         train_dataset = BiomDataset(load_table(self.train_biom),
                                     metadata=self.metadata,
                                     batch_category=self.batch_category)
         self.batch_categories = train_dataset.batch_cats
示例#2
0
 def test_dataloader(self):
     """Return a DataLoader over the test table named in ``self.hparams``.

     Items are served in dataset order (no shuffling), collated with
     ``collate_single_f``, pinned in memory, and any final incomplete batch
     is dropped.
     """
     hp = self.hparams
     dataset = BiomDataset(load_table(hp.test_biom))
     return DataLoader(
         dataset,
         batch_size=hp.batch_size,
         collate_fn=collate_single_f,
         shuffle=False,
         num_workers=hp.num_workers,
         drop_last=True,
         pin_memory=True,
     )
示例#3
0
 def test_biom_getitem(self):
     """Check the first item of a batch-aware BiomDataset.

     Its first element must match the expected 50 values and its second
     element must equal 0.
     """
     dataset = BiomDataset(self.table, self.metadata, batch_category='batch')
     expected_counts = np.array([
         65.0, 66.0, 12.0, 94.0, 37.0, 43.0, 97.0, 69.0, 6.0, 22.0,
         87.0, 43.0, 87.0, 5.0, 51.0, 53.0, 26.0, 54.0, 51.0, 76.0,
         15.0, 92.0, 30.0, 43.0, 97.0, 98.0, 7.0, 43.0, 25.0, 51.0,
         75.0, 39.0, 13.0, 90.0, 89.0, 48.0, 60.0, 79.0, 9.0, 97.0,
         35.0, 47.0, 13.0, 44.0, 70.0, 94.0, 80.0, 62.0, 99.0, 73.0,
     ])
     first_item = dataset[0]
     observed_counts = first_item[0]
     observed_label = first_item[1]
     npt.assert_allclose(observed_counts, expected_counts)
     npt.assert_allclose(observed_label, np.array(0))
示例#4
0
 def test_dataloader(self):
     """Return a DataLoader over the test table.

     The dataset carries this module's metadata and batch category. Batches
     are built with ``self.collate_f``, are not shuffled, are pinned in
     memory, and any final incomplete batch is dropped.
     """
     dataset = BiomDataset(load_table(self.test_biom),
                           metadata=self.metadata,
                           batch_category=self.batch_category)
     loader_kwargs = dict(batch_size=self.batch_size,
                          collate_fn=self.collate_f,
                          shuffle=False,
                          num_workers=self.num_workers,
                          drop_last=True,
                          pin_memory=True)
     return DataLoader(dataset, **loader_kwargs)
示例#5
0
 def val_dataloader(self):
     """Return a DataLoader over the validation table.

     The batch size is ``self.batch_size``, capped at
     ``len(val_dataset) - 1`` so small validation sets still produce a
     batch. The cap is never allowed below 1. Without that floor, a
     one-sample (or empty) validation table gave a batch size of 0 (or -1),
     which ``DataLoader`` rejects with ``ValueError``.

     Batches use ``self.collate_f``, are not shuffled, are pinned in memory,
     and any final incomplete batch is dropped.
     """
     val_dataset = BiomDataset(load_table(self.val_biom),
                               metadata=self.metadata,
                               batch_category=self.batch_category)
     # NOTE(review): the ``- 1`` makes the batch smaller than the dataset;
     # combined with drop_last=True this can leave samples unused. Kept for
     # backward compatibility — confirm intent.
     size_cap = max(len(val_dataset) - 1, 1)
     batch_size = min(size_cap, self.batch_size)
     val_dataloader = DataLoader(val_dataset,
                                 batch_size=batch_size,
                                 collate_fn=self.collate_f,
                                 shuffle=False,
                                 num_workers=self.num_workers,
                                 drop_last=True,
                                 pin_memory=True)
     return val_dataloader
示例#6
0
 def test_biom(self):
     data = BiomDataset(self.table)