def train_dataloader(self):
    """Return a shuffled DataLoader over ``self.hparams.train_pairs``.

    ``construct_paths`` is enabled on the dataset only when the configured
    loss is a ``SoftPathLoss``. Batches are collated with ``collate_f``.
    """
    needs_paths = isinstance(self.loss_func, SoftPathLoss)
    dataset = TMAlignDataset(self.hparams.train_pairs, construct_paths=needs_paths)
    return DataLoader(
        dataset,
        self.hparams.batch_size,
        collate_fn=collate_f,
        shuffle=True,
        num_workers=self.hparams.num_workers,
        pin_memory=True,
    )
def test_dataloader(self):
    """Return an unshuffled DataLoader over ``self.hparams.test_pairs``.

    The dataset is built with ``return_names=True``, and ``construct_paths``
    is enabled only when the configured loss is a ``SoftPathLoss``.
    Batches are collated with ``test_collate_f``.
    """
    needs_paths = isinstance(self.loss_func, SoftPathLoss)
    dataset = TMAlignDataset(
        self.hparams.test_pairs,
        return_names=True,
        construct_paths=needs_paths,
    )
    return DataLoader(
        dataset,
        self.hparams.batch_size,
        shuffle=False,
        collate_fn=test_collate_f,
        num_workers=self.hparams.num_workers,
        pin_memory=True,
    )
def val_dataloader(self):
    """Return an unshuffled DataLoader over ``self.hparams.valid_pairs``.

    ``construct_paths`` is enabled on the dataset only when the configured
    loss is a ``SoftPathLoss``. Batches are collated with ``collate_f``.
    """
    needs_paths = isinstance(self.loss_func, SoftPathLoss)
    dataset = TMAlignDataset(self.hparams.valid_pairs, construct_paths=needs_paths)
    return DataLoader(
        dataset,
        self.hparams.batch_size,
        collate_fn=collate_f,
        shuffle=False,
        num_workers=self.hparams.num_workers,
        pin_memory=True,
    )
def test_getitem(self):
    """Index the first item of an unpadded, end-clipped dataset and check it.

    The item must have six fields. ``gene``, ``pos`` and ``states`` must each
    have length 21, and the alignment matrix must have shape (21, 21).
    """
    dataset = TMAlignDataset(self.data_path, tm_threshold=0,
                             pad_ends=False, clip_ends=True)
    item = dataset[0]
    self.assertEqual(len(item), 6)
    gene, pos, states, alignment_matrix, _, _ = item
    for field in (gene, pos, states):
        self.assertEqual(len(field), 21)
    # NOTE(review): the original author flagged this shape as surprising;
    # confirm that a 21 x 21 alignment matrix is really expected for the
    # first item of this fixture.
    self.assertEqual(alignment_matrix.shape, (21, 21))
def test_gappy_getitem(self):
    """Index the first item of a dataset built with ``clip_ends=False``.

    The dataset is indexed, not just constructed, so that failures in item
    retrieval with unclipped ends are actually exercised.
    """
    dataset = TMAlignDataset(self.data_path, tm_threshold=0,
                             pad_ends=False, clip_ends=False)
    item = dataset[0]
    # Assumes clip_ends does not change the number of returned fields
    # (six in the clipped case) -- TODO confirm against TMAlignDataset.
    self.assertEqual(len(item), 6)
def test_constructor(self):
    """Build the dataset with ``max_len=10000`` and check its length is 10."""
    expected_pairs = 10
    dataset = TMAlignDataset(self.data_path, tm_threshold=0, max_len=10000)
    self.assertEqual(len(dataset), expected_pairs)