def test_no_shuffle_different_seeds_random_tp_uncached_mismatch(self):
    """Check that uncached, unshuffled `random_task` data differs by seed.

    Two datasets are requested from `self.random_task` with identical
    arguments except for `seed` (0 and 42); they are then asserted to be
    unequal through `test_utils.assert_datasets_neq`.
    """
    # Seeds are consumed in order, so the seed=0 dataset is built first.
    seed_zero_ds, seed_other_ds = (
        self.random_task.get_dataset(
            {"inputs": 13, "targets": 13},
            split="train",
            use_cached=False,
            shuffle=False,
            seed=seed_value,
        )
        for seed_value in (0, 42)
    )
    test_utils.assert_datasets_neq(seed_zero_ds, seed_other_ds)
def test_different_seeds_cached_mismatch(self):
    """Check that cached, shuffled `cached_task` data differs by seed.

    Both datasets come from `self.cached_task` with `use_cached=True` and
    `shuffle=True`; only `seed` varies (0 versus 42). The pair is asserted
    to be unequal through `test_utils.assert_datasets_neq`.
    """
    common_options = dict(split="train", use_cached=True, shuffle=True)

    def load_with_seed(seed_value):
        # A fresh sequence-length dict is passed on every call, as before.
        return self.cached_task.get_dataset(
            {"inputs": 13, "targets": 13}, seed=seed_value, **common_options
        )

    baseline_ds = load_with_seed(0)
    reseeded_ds = load_with_seed(42)
    test_utils.assert_datasets_neq(baseline_ds, reseeded_ds)
def test_num_epochs(self):
    """Contrast repeating after preprocessing with `num_epochs` repetition.

    Part one repeats an already-built dataset and skips its first three
    examples, expecting the remainder to equal a single-epoch dataset.
    Part two requests two epochs through `num_epochs=2` instead and expects
    the remainder (after the same skip) to differ from a single epoch.
    """

    def build_train_ds(**extra_options):
        # All calls share these arguments; only `num_epochs` is ever added.
        return self.random_task.get_dataset(
            {"inputs": 13, "targets": 13},
            split="train",
            use_cached=False,
            shuffle=True,
            seed=0,
            **extra_options,
        )

    # `random_task` has 3 examples per epoch, so skipping 3 examples drops
    # exactly the first epoch of a two-epoch dataset.
    examples_per_epoch = 3

    # Repeating after preprocessing: the second pass should reproduce the
    # single-epoch outputs.
    single_epoch_ds = build_train_ds()
    repeated_ds = build_train_ds().repeat(2)
    second_pass_ds = repeated_ds.skip(examples_per_epoch)
    test_utils.assert_datasets_eq(single_epoch_ds, second_pass_ds)

    # Repeating before preprocessing (via `num_epochs`): the second epoch is
    # expected to differ from the single-epoch outputs.
    single_epoch_ds = build_train_ds()
    two_epoch_ds = build_train_ds(num_epochs=2)
    second_epoch_ds = two_epoch_ds.skip(examples_per_epoch)
    test_utils.assert_datasets_neq(single_epoch_ds, second_epoch_ds)