def test_no_shuffle_different_seeds_random_tp_uncached_mismatch(self):
   """Uncached, unshuffled `random_task` datasets differ across seeds."""
   # Build the same "train" split twice, changing only the seed (0, then 42).
   # A fresh sequence-length dict is passed on every call.
   per_seed = [
       self.random_task.get_dataset(
           {"inputs": 13, "targets": 13},
           split="train", use_cached=False, shuffle=False, seed=seed)
       for seed in (0, 42)
   ]
   test_utils.assert_datasets_neq(*per_seed)
 def test_different_seeds_cached_mismatch(self):
   """Cached, shuffled `cached_task` datasets differ across seeds."""
   def cached_train(seed):
     # Only the seed varies between the two datasets under comparison.
     return self.cached_task.get_dataset(
         {"inputs": 13, "targets": 13},
         split="train", use_cached=True, shuffle=True, seed=seed)

   test_utils.assert_datasets_neq(cached_train(0), cached_train(42))
コード例 #3
0
    def test_num_epochs(self):
        """Compare `.repeat()` on the output against the `num_epochs` argument."""
        def train_ds(**overrides):
            # Every dataset in this test shares these settings; callers only
            # add `num_epochs` when the original call did.
            return self.random_task.get_dataset(
                {"inputs": 13, "targets": 13},
                split="train",
                use_cached=False,
                shuffle=True,
                seed=0,
                **overrides)

        # `random_task` has 3 examples per epoch.
        epoch_size = 3

        # Try repeating after preprocessing the dataset to verify the outputs
        # are the same: skip the first pass and compare the remainder.
        first_pass = train_ds()
        second_pass = train_ds().repeat(2).skip(epoch_size)
        test_utils.assert_datasets_eq(first_pass, second_pass)

        # Try repeating before preprocessing the dataset (via `num_epochs`) to
        # verify the outputs are different.
        single_epoch = train_ds()
        second_epoch = train_ds(num_epochs=2).skip(epoch_size)
        test_utils.assert_datasets_neq(single_epoch, second_epoch)