def test_train_val_split_zero_no_xform(self):
    """Split a 10-sample dataset with a 0.0 validation fraction and no transforms.

    The assertions require every sample in the training split and an empty
    validation split.

    This test was previously named identically to the transform-passing variant
    defined after it. Two ``def``s with the same name in one class body mean the
    later one rebinds the name, so this test was silently shadowed and never ran.
    It now has a distinct name.
    """
    t1 = torch.Tensor(np.arange(10))
    dataset = torch.utils.data.TensorDataset(t1)
    split_amt = 0.0
    train_dataset, val_dataset = train_val_dataset_split(dataset, split_amt)
    self.assertEqual(len(train_dataset), int(len(t1) * (1 - split_amt)))
    self.assertEqual(len(val_dataset), int(len(t1) * split_amt))
    # The two splits together must account for every input sample.
    self.assertEqual(len(train_dataset) + len(val_dataset), len(t1))
def test_train_val_split2(self):
    """Zero validation fraction with explicit validation transforms supplied.

    Checks the train/validation split sizes. Also checks that the data and label
    transforms given to ``train_val_dataset_split`` are exposed on the returned
    validation dataset as ``data_transform`` and ``label_transform``.
    """
    values = torch.Tensor(np.arange(10))
    fraction = 0.0

    def identity_xform(x):
        return x

    def square_xform(y):
        return y**2

    train_part, val_part = train_val_dataset_split(
        torch.utils.data.TensorDataset(values),
        fraction,
        identity_xform,
        square_xform,
    )

    total = len(values)
    expected_train = int(total * (1 - fraction))
    expected_val = int(total * fraction)

    self.assertEqual(len(train_part), expected_train)
    self.assertEqual(len(val_part), expected_val)
    self.assertEqual(val_part.data_transform, identity_xform)
    self.assertEqual(val_part.label_transform, square_xform)