def test_get_dataset(self):
    """Compare a one-entry mixture's dataset against the task's own dataset.

    Registers "test_mix3" containing only ``self.cached_task`` and checks
    that the first four mixture examples equal the task's validation
    dataset repeated twice.
    """
    seq_len = {"inputs": 13, "targets": 13}

    MixtureRegistry.add("test_mix3", [(self.cached_task.name, 1)])

    # Each call receives its own dict, as in the two separate literals
    # this replaces.
    expected_ds = TaskRegistry.get_dataset(
        self.cached_task.name,
        dict(seq_len),
        "validation",
        use_cached=False,
        shuffle=False,
    )
    mixture = MixtureRegistry.get("test_mix3")
    actual_ds = mixture.get_dataset(
        dict(seq_len),
        "validation",
        use_cached=False,
        shuffle=False,
    )

    # mix.get_dataset strips non-output features
    expected_ds = expected_ds.map(
        lambda ex: {key: ex[key] for key in ["inputs", "targets"]})

    # limit size since get_dataset repeats the dataset
    test_utils.assert_datasets_eq(expected_ds.repeat(2), actual_ds.take(4))
def test_no_shuffle_with_seed_uncached(self):
    """With shuffle=False, seeds 0 and 42 must give equal uncached datasets."""
    # Built in seed order (0, then 42), matching the two explicit calls
    # this comprehension replaces.
    by_seed = [
        self.uncached_task.get_dataset(
            {"inputs": 13, "targets": 13},
            split="train",
            use_cached=False,
            shuffle=False,
            seed=seed_value,
        )
        for seed_value in (0, 42)
    ]
    test_utils.assert_datasets_eq(by_seed[0], by_seed[1])
def test_same_seeds_random_tp_uncached(self):
    """Two shuffled, uncached loads of the random task with seed=0 must match."""

    def load_train():
        # Fresh, identically-configured dataset on every call.
        return self.uncached_random_task.get_dataset(
            {"inputs": 13, "targets": 13},
            split="train",
            use_cached=False,
            shuffle=True,
            seed=0,
        )

    first_ds = load_train()
    second_ds = load_train()
    test_utils.assert_datasets_eq(first_ds, second_ds)
def test_same_seeds_cached_match(self):
    """Two shuffled, cached loads of the cached task with seed=0 must match."""
    # Keyword arguments shared by both loads; the sequence-length dict is
    # still built separately for each call.
    load_kwargs = dict(split="train", use_cached=True, shuffle=True, seed=0)

    first_ds = self.cached_task.get_dataset(
        {"inputs": 13, "targets": 13}, **load_kwargs)
    second_ds = self.cached_task.get_dataset(
        {"inputs": 13, "targets": 13}, **load_kwargs)

    test_utils.assert_datasets_eq(first_ds, second_ds)