def test_different_seeds_cached(self):
  """Cached, shuffled train splits built with seeds 0 and 42 must differ.

  Both datasets come from `self.cached_task` with `use_cached=True` and
  `shuffle=True`; only the shuffle seed changes between them.
  """
  # A fresh sequence-length dict is created for each call, as before.
  seeded_datasets = [
      self.cached_task.get_dataset(
          {"inputs": 13, "targets": 13},
          split="train",
          use_cached=True,
          shuffle=True,
          seed=seed,
      )
      for seed in (0, 42)
  ]
  test_utils.assert_datasets_neq(*seeded_datasets)
def test_no_shuffle_different_seeds_random_tp_uncached(self):
  """Unshuffled, uncached random-task data still differs across seeds.

  Shuffling is disabled (`shuffle=False`), so the expected difference between
  seed 0 and seed 42 must come from elsewhere. Presumably the random task's
  preprocessing consumes the seed — confirm against the task definition.
  """
  seeded_datasets = [
      self.uncached_random_task.get_dataset(
          {"inputs": 13, "targets": 13},
          split="train",
          use_cached=False,
          shuffle=False,
          seed=seed,
      )
      for seed in (0, 42)
  ]
  test_utils.assert_datasets_neq(*seeded_datasets)
def test_same_seeds_random_tp_uncached_mismatch(self):
  """Two uncached loads with the same seed (0) are asserted to differ.

  Each load shuffles the train split of `self.uncached_random_task` and then
  repeats it 4 times. The `.repeat(4)` presumably lengthens the compared
  streams so that a mismatch is observed reliably — confirm intent.
  """

  def _load_repeated_with_seed_zero():
    return self.uncached_random_task.get_dataset(
        {"inputs": 13, "targets": 13},
        split="train",
        use_cached=False,
        shuffle=True,
        seed=0,
    ).repeat(4)

  # Expected *not* to equal due to parallel mapping.
  first_run = _load_repeated_with_seed_zero()
  second_run = _load_repeated_with_seed_zero()
  test_utils.assert_datasets_neq(first_run, second_run)