예제 #1
0
    def test_get_dataset(self):
        """Checks a one-task mixture yields the same examples as that task.

        The task's validation data (reduced to its output features and
        repeated twice) must equal the first elements of the mixture's
        validation data.
        """
        mixture_name = "test_mix3"
        output_keys = ("inputs", "targets")

        def lengths():
            # Build a fresh sequence-length mapping for every call.
            return {key: 13 for key in output_keys}

        MixtureRegistry.add(mixture_name, [(self.cached_task.name, 1)])

        task_ds = TaskRegistry.get_dataset(
            self.cached_task.name, lengths(), "validation",
            use_cached=False, shuffle=False)
        mix_ds = MixtureRegistry.get(mixture_name).get_dataset(
            lengths(), "validation", use_cached=False, shuffle=False)

        # The mixture only emits output features, so drop everything else
        # from the task's examples before comparing.
        task_ds = task_ds.map(lambda ex: {key: ex[key] for key in output_keys})

        # The mixture's get_dataset repeats the data, so cap it with take().
        # NOTE(review): 4 presumably equals two passes over the task's
        # validation split -- confirm against the fixture size.
        test_utils.assert_datasets_eq(task_ds.repeat(2), mix_ds.take(4))
 def test_same_seeds_random_tp_uncached_match(self):
   """Two uncached, shuffled reads of `random_task` with seed 0 must agree."""
   def read_repeated():
     # Same arguments for both reads; repeat(4) extends the comparison.
     return self.random_task.get_dataset(
         {"inputs": 13, "targets": 13},
         split="train", use_cached=False, shuffle=True, seed=0).repeat(4)

   first, second = read_repeated(), read_repeated()
   test_utils.assert_datasets_eq(first, second)
 def test_no_shuffle_with_seed_uncached_match(self):
   """With shuffling off, different seeds must give the same uncached data."""
   seeded_reads = [
       self.uncached_task.get_dataset(
           {"inputs": 13, "targets": 13},
           split="train", use_cached=False, shuffle=False, seed=seed)
       for seed in (0, 42)
   ]
   test_utils.assert_datasets_eq(*seeded_reads)
 def test_same_seeds_cached_match(self):
   """Two cached, shuffled reads of `cached_task` with seed 0 must agree."""
   read_kwargs = dict(split="train", use_cached=True, shuffle=True, seed=0)
   lhs = self.cached_task.get_dataset({"inputs": 13, "targets": 13},
                                      **read_kwargs)
   rhs = self.cached_task.get_dataset({"inputs": 13, "targets": 13},
                                      **read_kwargs)
   test_utils.assert_datasets_eq(lhs, rhs)
예제 #5
0
    def test_num_epochs(self):
        """Checks where repetition happens relative to preprocessing.

        Repeating the finished dataset must replay identical examples. By
        contrast, passing ``num_epochs=2`` to get_dataset (repeating before
        preprocessing) is expected to give a second epoch that differs from
        the first.
        """
        # `random_task` has 3 examples per epoch, so skipping that many
        # lands on the start of the second epoch.
        examples_per_epoch = 3

        def load(**extra_kwargs):
            # Shared read configuration for every dataset in this test.
            return self.random_task.get_dataset(
                {"inputs": 13, "targets": 13},
                split="train",
                use_cached=False,
                shuffle=True,
                seed=0,
                **extra_kwargs)

        # Repeat after preprocessing: the second epoch replays the first.
        first_epoch = load()
        second_epoch = load().repeat(2).skip(examples_per_epoch)
        test_utils.assert_datasets_eq(first_epoch, second_epoch)

        # Repeat before preprocessing: the test expects the second epoch to
        # differ from the first.
        first_epoch = load()
        second_epoch = load(num_epochs=2).skip(examples_per_epoch)
        test_utils.assert_datasets_neq(first_epoch, second_epoch)