Пример #1
0
 def train_dataloader(self):
     """Return the DataLoader used for training.

     The dataset is built from ``self.hparams.train_pairs``; ``construct_paths``
     is enabled only when ``self.loss_func`` is a ``SoftPathLoss``.

     Returns:
         A shuffled DataLoader that batches with ``collate_f`` and pins memory.
     """
     hp = self.hparams
     wants_paths = isinstance(self.loss_func, SoftPathLoss)
     dataset = TMAlignDataset(hp.train_pairs, construct_paths=wants_paths)
     return DataLoader(
         dataset,
         hp.batch_size,
         collate_fn=collate_f,
         shuffle=True,
         num_workers=hp.num_workers,
         pin_memory=True,
     )
Пример #2
0
 def test_dataloader(self):
     """Return the DataLoader used for testing.

     The dataset is built from ``self.hparams.test_pairs`` with
     ``return_names=True``; ``construct_paths`` is enabled only when
     ``self.loss_func`` is a ``SoftPathLoss``.

     Returns:
         An unshuffled DataLoader that batches with ``test_collate_f`` and
         pins memory.
     """
     hp = self.hparams
     wants_paths = isinstance(self.loss_func, SoftPathLoss)
     dataset = TMAlignDataset(
         hp.test_pairs,
         return_names=True,
         construct_paths=wants_paths,
     )
     return DataLoader(
         dataset,
         hp.batch_size,
         shuffle=False,
         collate_fn=test_collate_f,
         num_workers=hp.num_workers,
         pin_memory=True,
     )
Пример #3
0
 def val_dataloader(self):
     """Return the DataLoader used for validation.

     The dataset is built from ``self.hparams.valid_pairs``; ``construct_paths``
     is enabled only when ``self.loss_func`` is a ``SoftPathLoss``.

     Returns:
         An unshuffled DataLoader that batches with ``collate_f`` and pins
         memory.
     """
     hp = self.hparams
     wants_paths = isinstance(self.loss_func, SoftPathLoss)
     dataset = TMAlignDataset(hp.valid_pairs, construct_paths=wants_paths)
     return DataLoader(
         dataset,
         hp.batch_size,
         collate_fn=collate_f,
         shuffle=False,
         num_workers=hp.num_workers,
         pin_memory=True,
     )
Пример #4
0
 def test_getitem(self):
     """Index the first pair (clipped, unpadded) and check field sizes.

     The item must be a 6-tuple. Its first three fields must each have
     length 21, and its fourth field (the alignment matrix) must have
     shape (21, 21).
     """
     dataset = TMAlignDataset(self.data_path, tm_threshold=0,
                              pad_ends=False, clip_ends=True)
     item = dataset[0]
     self.assertEqual(len(item), 6)
     gene, pos, states, alignment_matrix, _, _ = item
     # Sequence-level fields share one length.
     for field in (gene, pos, states):
         self.assertEqual(len(field), 21)
     # NOTE(review): the matrix is expected to be square with the same side
     # (21) as the sequence fields -- confirm against the fixture why this
     # holds with clip_ends=True.
     self.assertEqual(alignment_matrix.shape, (21, 21))
Пример #5
0
 def test_gappy_getitem(self):
     """Index the first pair with ``clip_ends=False`` (gappy ends kept).

     The original version only built the dataset and never called
     ``__getitem__``, so it could not catch a failure when indexing an item
     with gappy ends. This version fetches item 0 and checks that it has the
     same 6-field arity that ``test_getitem`` checks for the clipped case.
     """
     dataset = TMAlignDataset(self.data_path, tm_threshold=0,
                              pad_ends=False, clip_ends=False)
     item = dataset[0]
     # NOTE(review): arity assumed to match the clip_ends=True case, since
     # return_names is left at its default in both tests -- confirm.
     self.assertEqual(len(item), 6)
Пример #6
0
 def test_constructor(self):
     """Build the dataset with a large ``max_len`` and check it has 10 items."""
     expected_size = 10
     dataset = TMAlignDataset(self.data_path, tm_threshold=0, max_len=10000)
     self.assertEqual(len(dataset), expected_size)