Example #1
class EnablingTF2Behavior(test.TestCase, parameterized.TestCase):
    def __init__(self, methodName):
        super().__init__(methodName)
        self._set_default_seed = False

    @combinations.generate(test_base.v1_only_combinations())
    def test_tf1_enable_tf2_behaviour(self):
        self.assertFalse(tf2.enabled())
        self.assertFalse(_pywrap_tf2.is_enabled())

        v2_compat.enable_v2_behavior()
        self.assertTrue(tf2.enabled())
        self.assertTrue(_pywrap_tf2.is_enabled())

        v2_compat.disable_v2_behavior()
        self.assertFalse(tf2.enabled())
        self.assertFalse(_pywrap_tf2.is_enabled())

    @combinations.generate(test_base.v1_only_combinations())
    def test_tf1_disable_tf2_behaviour(self):
        self.assertFalse(tf2.enabled())
        self.assertFalse(_pywrap_tf2.is_enabled())

        v2_compat.disable_v2_behavior()
        self.assertFalse(tf2.enabled())
        self.assertFalse(_pywrap_tf2.is_enabled())

        v2_compat.enable_v2_behavior()
        self.assertTrue(tf2.enabled())
        self.assertTrue(_pywrap_tf2.is_enabled())

    @combinations.generate(test_base.v2_only_combinations())
    def test_tf2_enable_tf2_behaviour(self):
        self.assertTrue(tf2.enabled())
        self.assertTrue(_pywrap_tf2.is_enabled())

        v2_compat.enable_v2_behavior()
        self.assertTrue(tf2.enabled())
        self.assertTrue(_pywrap_tf2.is_enabled())

        v2_compat.disable_v2_behavior()
        self.assertFalse(tf2.enabled())
        self.assertFalse(_pywrap_tf2.is_enabled())

    @combinations.generate(test_base.v2_only_combinations())
    def test_tf2_disable_tf2_behaviour(self):
        self.assertTrue(tf2.enabled())
        self.assertTrue(_pywrap_tf2.is_enabled())

        v2_compat.disable_v2_behavior()
        self.assertFalse(tf2.enabled())
        self.assertFalse(_pywrap_tf2.is_enabled())

        v2_compat.enable_v2_behavior()
        self.assertTrue(tf2.enabled())
        self.assertTrue(_pywrap_tf2.is_enabled())
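The test above flips the TF2 behavior flag through TensorFlow internals (tf2, _pywrap_tf2, v2_compat). A rough public-API sketch of the same toggle, assuming a TF 1.x-style program that opts in at startup; tf.executing_eagerly() is used here only as a visible proxy for the flag:

import tensorflow as tf

# Must run before any graphs, sessions, or tensors are created.
tf.compat.v1.enable_v2_behavior()
print(tf.executing_eagerly())   # True once v2 behavior (eager execution) is on

# The switch is process-global; flipping it back mid-program is unusual
# outside of tests.
tf.compat.v1.disable_v2_behavior()
print(tf.executing_eagerly())   # False again
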
Example #2
class SaveCheckpointTest(IOTest, checkpoint_test_base.CheckpointTestBase):

  def _build_ds(self):
    dataset = dataset_ops.Dataset.range(42)
    return dataset_ops._SaveDataset(
        dataset=dataset, path=self._save_dir, shard_func=None, compression=None)

  # This tests checkpointing for the _SaveDataset, which is internally
  # consumed in the save() function. The purpose of this test is to
  # thoroughly test the checkpointing functionality of the internal dataset.
  @combinations.generate(
      combinations.times(test_base.v2_only_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):
    verify_fn(self, self._build_ds, num_outputs=42)

  @combinations.generate(test_base.eager_only_combinations())
  def testSaveCheckpointingAPI(self):
    dataset = dataset_ops.Dataset.range(40)
    checkpoint_args = {"directory": self._checkpoint_prefix, "max_to_keep": 50}
    dataset.save(self._save_dir, checkpoint_args=checkpoint_args)
    num_checkpoint_files = len(list(os.listdir(self._checkpoint_prefix)))
    # By default, we checkpoint every increment. Each checkpoint writes a
    # file containing the data and a file containing the index. There is
    # also an overall checkpoint file. Thus, we expect (2 * 40) + 1 files.
    self.assertEqual(81, num_checkpoint_files)

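The expected count of 81 follows from the comment above: with the default interval, every one of the 40 elements produces a checkpoint, each checkpoint writes a data file and an index file, and there is one top-level checkpoint file. A tiny sketch of that arithmetic (plain Python, also covering the custom-interval test below):

def expected_checkpoint_files(num_elements, checkpoint_interval=1):
    # Two files (data + index) per checkpoint, plus one top-level file.
    num_checkpoints = num_elements // checkpoint_interval
    return 2 * num_checkpoints + 1

assert expected_checkpoint_files(40) == 81      # checkpoint every element
assert expected_checkpoint_files(40, 5) == 17   # checkpoint every 5 elements
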
  @combinations.generate(test_base.eager_only_combinations())
  def testSaveCheckpointingAPICustomCheckpointInterval(self):
    dataset = dataset_ops.Dataset.range(40)
    step_counter = variables.Variable(0, trainable=False)
    checkpoint_args = {
        "checkpoint_interval": 5,
        "step_counter": step_counter,
        "directory": self._checkpoint_prefix,
        "max_to_keep": 10,
    }
    dataset.save(self._save_dir, checkpoint_args=checkpoint_args)
    num_checkpoint_files = len(list(os.listdir(self._checkpoint_prefix)))
    # With a checkpoint interval of 5, the 40 elements yield 40 / 5 = 8
    # checkpoints, so we expect (2 * 8) + 1 = 17 files.
    self.assertEqual(17, num_checkpoint_files)

  @combinations.generate(test_base.eager_only_combinations())
  def testSaveCheckpointingAPIIncorrectArgs(self):
    dataset = dataset_ops.Dataset.range(42)
    checkpoint_args = {
        "directory": self._checkpoint_prefix,
        "incorrect_arg": "incorrect_arg"
    }
    with self.assertRaises(TypeError):
      dataset.save(self._save_dir, checkpoint_args=checkpoint_args)
Example #3
class DataServiceOpsTest(data_service_test_base.TestBase,
                         parameterized.TestCase):
    @combinations.generate(
        combinations.times(test_base.eager_only_combinations(),
                           data_service_test_base.all_cluster_configurations())
    )
    def testDistributeBasic(self, work_dir, fault_tolerant_mode):
        cluster = self.create_cluster(num_workers=1,
                                      work_dir=work_dir,
                                      fault_tolerant_mode=fault_tolerant_mode)
        num_elements = 10
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        results = [elem.numpy() for elem in ds]
        self.assertEqual(list(range(num_elements)), results)

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeSparse(self):
        cluster = self.create_cluster(num_workers=1)
        element = sparse_tensor.SparseTensor(indices=[[0]],
                                             values=constant_op.constant(
                                                 [0], dtype=dtypes.int32),
                                             dense_shape=[1])
        ds = dataset_ops.Dataset.from_tensors(element)
        ds = self.make_distributed_dataset(ds, cluster)
        results = [sparse_ops.sparse_tensor_to_dense(elem) for elem in ds]
        self.assertAllEqual(results, [[0]])

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeRagged(self):
        cluster = self.create_cluster(num_workers=1)
        ds = dataset_ops.Dataset.from_tensor_slices([1, 5, 3, 2, 8])
        ds = ds.map(math_ops.range)
        ds = ds.apply(batching.dense_to_ragged_batch(2))
        ds = self.make_distributed_dataset(ds, cluster)
        results = [elem.to_tensor() for elem in ds]
        self.assertAllEqual(results[0], [[0, 0, 0, 0, 0], [0, 1, 2, 3, 4]])
        self.assertAllEqual(results[1], [[0, 1, 2], [0, 1, 0]])
        self.assertAllEqual(results[2], [[0, 1, 2, 3, 4, 5, 6, 7]])

    @combinations.generate(test_base.eager_only_combinations())
    def testDifferentShuffleOrders(self):
        random_seed.set_random_seed(None)
        num_elements = 100
        cluster = self.create_cluster(num_workers=2)
        ds = dataset_ops.Dataset.range(num_elements)
        ds = ds.shuffle(num_elements)
        ds = self.make_distributed_dataset(ds, cluster)
        output = [elem.numpy() for elem in ds]

        # The output will be two sequences of range(num_elements),
        # non-deterministically interleaved together. If the two sequences had
        # been shuffled in the same order, the first_order and second_order
        # dicts computed below would be equal.
        first_order = {}
        second_order = {}
        for element in output:
            if element in first_order:
                second_order[element] = len(second_order)
            else:
                first_order[element] = len(first_order)
        self.assertNotEqual(first_order, second_order)

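The first_order/second_order bookkeeping above fingerprints the two interleaved copies of range(num_elements): the first occurrence of each value gets a rank in first_order, the second occurrence a rank in second_order. A small standalone illustration (plain Python) of why equal shuffle orders would make the two dicts equal:

def split_orders(output):
    first_order, second_order = {}, {}
    for element in output:
        if element in first_order:
            second_order[element] = len(second_order)
        else:
            first_order[element] = len(first_order)
    return first_order, second_order

# Two copies interleaved in the same order -> identical rank dicts.
assert split_orders([0, 0, 1, 1, 2, 2])[0] == split_orders([0, 0, 1, 1, 2, 2])[1]

# Two copies in different orders -> the dicts differ, which is what the test asserts.
first, second = split_orders([0, 2, 1, 1, 2, 0])
assert first != second
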
    @combinations.generate(test_base.eager_only_combinations())
    def testMultipleEpochs(self):
        cluster = self.create_cluster(num_workers=1)
        num_elements = 3
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        for _ in range(10):
            self.assertEqual(list(range(num_elements)),
                             [elem.numpy() for elem in ds])

    @combinations.generate(test_base.eager_only_combinations())
    def testRepeatedDataset(self):
        cluster = self.create_cluster(num_workers=1)
        num_elements = 10
        num_repetitions = 5
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        ds = ds.repeat(num_repetitions)
        self.assertDatasetProduces(ds,
                                   expected_output=num_repetitions *
                                   list(range(num_elements)))

    @combinations.generate(test_base.eager_only_combinations())
    def testConcurrentEpoch(self):
        cluster = self.create_cluster(num_workers=1)
        num_elements = 10
        num_datasets = 3
        iterators = []
        results = []
        for _ in range(num_datasets):
            ds = self.make_distributed_range_dataset(num_elements, cluster)
            iterators.append(iter(ds))
            results.append([])

        for _ in range(num_elements):
            for dataset_ind in range(num_datasets):
                result = next(iterators[dataset_ind]).numpy()
                results[dataset_ind].append(result)
        for result in results:
            self.assertEqual(list(range(num_elements)), result)

    @combinations.generate(test_base.eager_only_combinations())
    def testSharedEpoch(self):
        self.skipTest("Not yet implemented")
        cluster = self.create_cluster(num_workers=1)
        num_elements = 10
        num_iterators = 3
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        result = []
        iterators = []
        for _ in range(num_iterators):
            iterators.append(iter(ds))

        # Alternate reading between the iterators.
        for _ in range(2):
            for it in iterators:
                result.append(next(it).numpy())

        # Drain the rest of the elements.
        for it in iterators:
            for elem in it:
                result.append(elem.numpy())

        self.assertCountEqual(list(range(num_elements)), result)

    @combinations.generate(test_base.eager_only_combinations())
    def testMultiWorker(self):
        num_workers = 3
        cluster = self.create_cluster(num_workers=num_workers)
        num_elements = 10
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        results = [elem.numpy() for elem in ds]
        self.assertCountEqual(num_workers * list(range(num_elements)), results)

    @combinations.generate(test_base.eager_only_combinations())
    def testMaxOutstandingRequests(self):
        num_workers = 3
        cluster = self.create_cluster(num_workers=num_workers)
        num_elements = 10
        ds = self.make_distributed_range_dataset(num_elements,
                                                 cluster,
                                                 max_outstanding_requests=1)
        self.assertCountEqual(num_workers * list(range(num_elements)),
                              self.getDatasetOutput(ds))

    @combinations.generate(test_base.eager_only_combinations())
    def testInsideFunction(self):
        num_workers = 3
        cluster = self.create_cluster(num_workers=num_workers)
        num_elements = 10

        @def_function.function
        def f():
            ds = self.make_distributed_range_dataset(num_elements, cluster)
            result = tensor_array_ops.TensorArray(dtypes.int64,
                                                  size=num_workers *
                                                  num_elements,
                                                  dynamic_size=True)
            i = 0
            for elem in ds:
                result = result.write(i, elem)
                i += 1
            return result.stack()

        result = list(f().numpy())
        self.assertCountEqual(num_workers * list(range(num_elements)), result)

    @combinations.generate(test_base.eager_only_combinations())
    def testSharedJobName(self):
        cluster = self.create_cluster(num_workers=1)
        num_elements = 1000

        def make_ds():
            return dataset_ops.Dataset.range(num_elements).shuffle(
                num_elements)

        ds1 = self.make_distributed_dataset(make_ds(),
                                            cluster,
                                            job_name="job_name")
        ds2 = self.make_distributed_dataset(make_ds(),
                                            cluster,
                                            job_name="job_name")
        iter1 = iter(ds1)
        iter2 = iter(ds2)
        results = []
        for _ in range(num_elements // 5):
            results.append(next(iter1).numpy())
            results.append(next(iter2).numpy())
        for elem in iter1:
            results.append(elem.numpy())
        for elem in iter2:
            results.append(elem.numpy())
        self.assertCountEqual(list(range(num_elements)), results)

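In public-API terms, the job sharing exercised above comes from passing the same job_name to tf.data.experimental.service.distribute: consumers that share a job name split the elements of one job instead of each receiving a full copy. A hedged sketch, assuming a dispatcher is already running at the placeholder address below:

import tensorflow as tf

service = "grpc://dispatcher-host:5000"  # placeholder; point at a real dispatcher

def make_ds():
    return tf.data.Dataset.range(1000).shuffle(1000)

ds1 = make_ds().apply(tf.data.experimental.service.distribute(
    processing_mode="parallel_epochs", service=service, job_name="shared_job"))
ds2 = make_ds().apply(tf.data.experimental.service.distribute(
    processing_mode="parallel_epochs", service=service, job_name="shared_job"))
# Together ds1 and ds2 see each element once, not twice.
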
    @combinations.generate(test_base.eager_only_combinations())
    def testDifferentJobNames(self):
        cluster = self.create_cluster(num_workers=1)
        num_elements = 10
        ds1 = self.make_distributed_range_dataset(num_elements,
                                                  cluster,
                                                  job_name="job_name1")
        ds2 = self.make_distributed_range_dataset(num_elements,
                                                  cluster,
                                                  job_name="job_name2")
        self.assertDatasetProduces(ds1, list(range(num_elements)))
        self.assertDatasetProduces(ds2, list(range(num_elements)))

    @combinations.generate(test_base.eager_only_combinations())
    def testSharedJobNameMultiIteration(self):
        cluster = self.create_cluster(num_workers=1)
        num_elements = 10
        ds1 = self.make_distributed_range_dataset(num_elements,
                                                  cluster,
                                                  job_name="job_name")
        ds2 = self.make_distributed_range_dataset(num_elements,
                                                  cluster,
                                                  job_name="job_name")
        # iteration 1
        self.assertDatasetProduces(ds1, list(range(num_elements)))
        self.assertDatasetProduces(ds2, [])
        # iteration 2
        self.assertDatasetProduces(ds2, list(range(num_elements)))
        self.assertDatasetProduces(ds1, [])

    @combinations.generate(test_base.eager_only_combinations())
    def testSharedJobNameRepeat(self):
        cluster = self.create_cluster(num_workers=1)
        num_elements = 100
        num_repetitions = 3
        ds1 = self.make_distributed_range_dataset(num_elements,
                                                  cluster,
                                                  job_name="job_name")
        ds1 = ds1.repeat(num_repetitions)
        ds2 = self.make_distributed_range_dataset(num_elements,
                                                  cluster,
                                                  job_name="job_name")
        ds2 = ds2.repeat(num_repetitions)
        results = []
        iter1 = iter(ds1)
        iter2 = iter(ds2)
        for _ in range((num_elements * num_repetitions) // 5):
            results.append(next(iter1).numpy())
        for _ in range((num_elements * num_repetitions) // 5):
            results.append(next(iter2).numpy())
        for elem in iter1:
            results.append(elem.numpy())
        for elem in iter2:
            results.append(elem.numpy())
        self.assertCountEqual(num_repetitions * list(range(num_elements)),
                              results)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(num_workers=[1, 3], num_consumers=[1, 2, 5])))
    def testRoundRobin(self, num_workers, num_consumers):
        cluster = self.create_cluster(num_workers=num_workers)
        # Round robin reads can cause slow cluster shutdown.
        data_service_test_base.GLOBAL_CLUSTERS.add(cluster)
        ds = dataset_ops.Dataset.range(10000000)
        ds = ds.repeat()
        consumers = []
        for consumer_index in range(num_consumers):
            consumers.append(
                self.make_distributed_dataset(ds,
                                              cluster,
                                              job_name="test",
                                              consumer_index=consumer_index,
                                              num_consumers=num_consumers))
        # Use parallel interleave to read from consumers in parallel.
        ds = dataset_ops.Dataset.from_tensor_slices(consumers)
        ds = ds.interleave(lambda x: x,
                           cycle_length=num_consumers,
                           num_parallel_calls=num_consumers)
        ds = ds.take(1000)
        results = self.getDatasetOutput(ds, requires_initialization=True)

        for i in range(0, len(results), num_consumers):
            self.assertEqual(0, results[i] % num_consumers)
            # Check that each group of `num_consumers` results are consecutive.
            for offset in range(1, num_consumers):
                if i + offset < len(results):
                    self.assertEqual(results[i] + offset, results[i + offset])

    @combinations.generate(test_base.default_test_combinations())
    def testRoundRobinBucketizing(self):
        # Tests a common use case for round robin reads. At each step, all
        # consumers should get batches with the same bucket size.
        cluster = self.create_cluster(num_workers=4)
        # Round robin reads can cause slow cluster shutdown.
        data_service_test_base.GLOBAL_CLUSTERS.add(cluster)
        num_elements = 100
        low_bucket_max = 30
        mid_bucket_max = 60
        bucket_boundaries = [low_bucket_max, mid_bucket_max]
        batch_size = 10
        num_consumer_hosts = 3
        replicas_per_consumer_host = 5
        num_consumers = num_consumer_hosts * replicas_per_consumer_host
        bucket_batch_sizes = [batch_size] * (len(bucket_boundaries) + 1)
        # Set up the dataset that will run on the tf.data workers.
        ds = dataset_ops.Dataset.range(num_elements, output_type=dtypes.int32)
        ds = ds.shuffle(num_elements)
        ds = ds.repeat()
        ds = ds.apply(
            grouping.bucket_by_sequence_length(lambda x: x,
                                               bucket_boundaries,
                                               bucket_batch_sizes,
                                               drop_remainder=True))
        ds = ds.apply(
            grouping.group_by_window(
                lambda x: math_ops.cast(x[1], dtypes.int64),
                lambda _, x: dataset_ops.Dataset.from_tensors(x),
                window_size=num_consumers))
        ds = ds.flat_map(lambda x: x)

        # Set up the per-consumer-host datasets. During each global step, we pull
        # `replicas_per_consumer_host` batches from each of these datasets.
        host_datasets = []
        for host_index in range(num_consumer_hosts):
            per_replica_datasets = []
            for i in range(replicas_per_consumer_host):
                consumer_index = host_index * replicas_per_consumer_host + i
                per_replica_datasets.append(
                    self.make_distributed_dataset(
                        ds,
                        cluster,
                        job_name="test",
                        consumer_index=consumer_index,
                        num_consumers=num_consumers))
            host_dataset = dataset_ops.Dataset.from_tensor_slices(
                per_replica_datasets)
            host_dataset = host_dataset.interleave(
                lambda x: x,
                cycle_length=len(per_replica_datasets),
                num_parallel_calls=len(per_replica_datasets),
                deterministic=True)
            host_datasets.append(host_dataset)

        # Use parallel interleave to read from host datasets in parallel.
        ds = dataset_ops.Dataset.from_tensor_slices(host_datasets)
        ds = ds.interleave(lambda x: x,
                           block_length=replicas_per_consumer_host,
                           cycle_length=len(host_datasets),
                           num_parallel_calls=len(host_datasets),
                           deterministic=True)

        num_rounds = 10
        get_next = self.getNext(ds, requires_initialization=True)
        results = []
        for _ in range(num_rounds * num_consumers):
            results.append(self.evaluate(get_next()))

        def get_bucket(elem):
            bucket_ind = 0
            while (bucket_ind < len(bucket_boundaries)
                   and elem >= bucket_boundaries[bucket_ind]):
                bucket_ind += 1
            return bucket_ind

        # Check that the batches for each step contain elements from the same
        # bucket.
        for i in range(0, len(results), num_consumers):
            batches = results[i:i + num_consumers]
            bucket_inds = [get_bucket(batch[0]) for batch in batches]
            for bucket_ind in bucket_inds[1:]:
                self.assertEqual(bucket_inds[0], bucket_ind)

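The round-robin tests above use coordinated reads: each consumer passes its consumer_index along with the total num_consumers, and the dispatcher hands out elements in lockstep rounds so consumers stay in sync (useful for synchronized bucketing). A rough public-API sketch for a single consumer, assuming a running dispatcher at the placeholder address and an infinite (repeated) source dataset, which this mode requires:

import tensorflow as tf

service = "grpc://dispatcher-host:5000"   # placeholder address
consumer_index, num_consumers = 0, 3      # this process is consumer 0 of 3

ds = tf.data.Dataset.range(10000000).repeat()  # must have infinite cardinality
ds = ds.apply(tf.data.experimental.service.distribute(
    processing_mode="parallel_epochs",
    service=service,
    job_name="round_robin_job",
    consumer_index=consumer_index,
    num_consumers=num_consumers))
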
    @combinations.generate(test_base.v1_only_combinations())
    def testRoundRobinFiniteV1(self):
        cluster = self.create_cluster(num_workers=1)
        num_elements = 100
        ds = dataset_ops.Dataset.range(num_elements)
        ds = self.make_distributed_dataset(ds,
                                           cluster,
                                           job_name="test",
                                           consumer_index=0,
                                           num_consumers=1)

        with self.assertRaisesRegex(
                errors.FailedPreconditionError,
                "Encountered end of sequence on a "
                "round-robin read iterator"):
            self.getDatasetOutput(ds, requires_initialization=True)

    @combinations.generate(test_base.v2_only_combinations())
    def testRoundRobinFiniteV2(self):
        cluster = self.create_cluster(num_workers=1)
        num_elements = 100
        ds = dataset_ops.Dataset.range(num_elements)
        ds = self.make_distributed_dataset(ds,
                                           cluster,
                                           job_name="test",
                                           consumer_index=0,
                                           num_consumers=1)

        with self.assertRaisesRegex(
                errors.FailedPreconditionError, "Round robin reads "
                "require that the input dataset has infinite "
                "cardinality, but the dataset has cardinality " +
                str(num_elements)):
            self.getDatasetOutput(ds, requires_initialization=True)

    @combinations.generate(
        combinations.times(test_base.eager_only_combinations(),
                           combinations.combine(job_name=[None, "test"])))
    def testGcUnusedJob(self, job_name):
        cluster = self.create_cluster(num_workers=1,
                                      job_gc_check_interval_ms=50,
                                      job_gc_timeout_ms=20)
        num_elements = 100
        ds = self.make_distributed_range_dataset(num_elements,
                                                 cluster,
                                                 job_name=job_name)
        it = iter(ds)
        self.assertEqual(next(it).numpy(), 0)
        self.assertEqual(cluster.num_tasks_on_worker(), 1)
        del it
        while cluster.num_tasks_on_worker() > 0:
            time.sleep(0.1)

    @combinations.generate(test_base.eager_only_combinations())
    def testDontGcUsedJob(self):
        cluster = self.create_cluster(num_workers=1,
                                      job_gc_check_interval_ms=50,
                                      job_gc_timeout_ms=20)
        num_elements = 10
        it1 = iter(
            self.make_distributed_range_dataset(num_elements,
                                                cluster,
                                                job_name="test1"))
        it2 = iter(
            self.make_distributed_range_dataset(num_elements,
                                                cluster,
                                                job_name="test2"))
        it3 = iter(  # this iterator keeps the task alive. pylint: disable=unused-variable
            self.make_distributed_range_dataset(num_elements,
                                                cluster,
                                                job_name="test2"))
        self.assertEqual(2, cluster.num_tasks_on_worker())
        del it1
        del it2
        # Check that only the first job is gced. The second job will not be gced
        # because there is still an outstanding iterator for it.
        while cluster.num_tasks_on_worker() > 1:
            time.sleep(0.1)
        self.assertEqual(1, cluster.num_tasks_on_worker())

    @combinations.generate(test_base.eager_only_combinations())
    def testApplyDeterminismOption(self):
        elements = list(range(10))
        cluster = self.create_cluster(num_workers=1)

        def dataset_fn(delay_ms):
            def interleave_fn(x):
                ds = dataset_ops.Dataset.from_tensors(x)
                if math_ops.equal(x, 0):
                    ds = ds.apply(testing.sleep(delay_ms * 1000))
                else:
                    ds = ds.apply(testing.sleep(0))
                return ds

            ds = dataset_ops.Dataset.from_tensor_slices(elements)
            ds = ds.interleave(interleave_fn,
                               cycle_length=10,
                               num_parallel_calls=10)
            opts = dataset_ops.Options()
            opts.experimental_deterministic = False
            ds = ds.with_options(opts)
            ds = self.make_distributed_dataset(ds, cluster)
            return ds

        self.checkDeterminism(dataset_fn=dataset_fn,
                              expect_determinism=False,
                              expected_elements=elements)

    def run_stateful(self, external_state_policy):
        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements).map(
            lambda _: random_ops.random_uniform(()))

        options = dataset_ops.Options()
        options.experimental_external_state_policy = external_state_policy
        ds = ds.with_options(options)

        cluster = self.create_cluster(num_workers=3)
        ds = self.make_distributed_dataset(ds, cluster)
        next(iter(ds))

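run_stateful above exercises the external state policy, which decides whether a dataset containing stateful ops (here random_uniform) may be shipped to tf.data service workers. With the public API the policy is set through tf.data.Options; a brief sketch, assuming a TF 2.x release where the option is still named experimental_external_state_policy:

import tensorflow as tf

ds = tf.data.Dataset.range(10).map(lambda _: tf.random.uniform(()))
options = tf.data.Options()
# IGNORE and WARN let the stateful map through; FAIL raises FailedPreconditionError.
options.experimental_external_state_policy = (
    tf.data.experimental.ExternalStatePolicy.WARN)
ds = ds.with_options(options)
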
    @combinations.generate(
        combinations.times(
            test_base.eager_only_combinations(),
            combinations.combine(external_state_policy=[
                distribute_options.ExternalStatePolicy.IGNORE,
                distribute_options.ExternalStatePolicy.WARN
            ])))
    def testStatefulNoError(self, external_state_policy):
        self.run_stateful(external_state_policy)

    @combinations.generate(test_base.eager_only_combinations())
    def testStatefulError(self):
        with self.assertRaises(errors.FailedPreconditionError):
            self.run_stateful(distribute_options.ExternalStatePolicy.FAIL)

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeDistributedEpochTensorSlices(self):
        cluster = self.create_cluster(num_workers=2)
        vals = [5, 1, 2, 4]
        ds = dataset_ops.Dataset.from_tensor_slices(vals)
        ds = self.make_distributed_dataset(ds,
                                           cluster,
                                           processing_mode="distributed_epoch")
        self.assertDatasetProduces(ds, vals, assert_items_equal=True)

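The tests from here on switch processing_mode. With "parallel_epochs" every worker processes the whole dataset, so each element is produced once per worker; with "distributed_epoch" a single epoch is split across workers, so each element is produced once in total (which is why these tests use assert_items_equal: the order depends on the split). A minimal hedged sketch of the two modes with the public API, again assuming a running dispatcher:

import tensorflow as tf

service = "grpc://dispatcher-host:5000"  # placeholder address
ds = tf.data.Dataset.range(100)

# Each element appears once in total; the epoch is split across workers.
per_epoch = ds.apply(tf.data.experimental.service.distribute(
    processing_mode="distributed_epoch", service=service))

# Each worker reads the full dataset; 100 * num_workers elements in total.
per_worker = ds.apply(tf.data.experimental.service.distribute(
    processing_mode="parallel_epochs", service=service))
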
    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeDistributedEpochInterleave(self):
        cluster = self.create_cluster(num_workers=2)
        elements = [1, 5, 0]
        ds = dataset_ops.Dataset.from_tensor_slices(elements)
        ds = ds.interleave(
            lambda x: dataset_ops.Dataset.from_tensor_slices([x]))
        ds = self.make_distributed_dataset(ds,
                                           cluster,
                                           processing_mode="distributed_epoch")
        self.assertDatasetProduces(ds, elements, assert_items_equal=True)

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeDistributedEpochParallelInterleave(self):
        cluster = self.create_cluster(num_workers=2)
        elements = [1, 5, 0]
        ds = dataset_ops.Dataset.from_tensor_slices(elements)
        ds = ds.interleave(
            lambda x: dataset_ops.Dataset.from_tensor_slices([x]),
            num_parallel_calls=dataset_ops.AUTOTUNE)
        ds = self.make_distributed_dataset(ds,
                                           cluster,
                                           processing_mode="distributed_epoch")
        self.assertDatasetProduces(ds, elements, assert_items_equal=True)

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeDistributedEpochFlatMap(self):
        cluster = self.create_cluster(num_workers=2)
        elements = [1, 5, 0]
        ds = dataset_ops.Dataset.from_tensor_slices(elements)
        ds = ds.flat_map(lambda x: dataset_ops.Dataset.from_tensor_slices([x]))
        ds = self.make_distributed_dataset(ds,
                                           cluster,
                                           processing_mode="distributed_epoch")
        self.assertDatasetProduces(ds, elements, assert_items_equal=True)

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeDistributedEpochRepeat(self):
        cluster = self.create_cluster(num_workers=2)
        num_repeats = 5
        num_elements = 20
        ds = dataset_ops.Dataset.range(num_elements).repeat(num_repeats)
        ds = self.make_distributed_dataset(ds,
                                           cluster,
                                           processing_mode="distributed_epoch")
        self.assertDatasetProduces(ds,
                                   num_repeats * list(range(num_elements)),
                                   assert_items_equal=True)

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeDistributedEpochForeverRepeat(self):
        cluster = self.create_cluster(num_workers=2)
        num_elements = 20
        elements_to_read = 1000
        ds = dataset_ops.Dataset.range(num_elements).repeat()
        ds = self.make_distributed_dataset(ds,
                                           cluster,
                                           processing_mode="distributed_epoch")
        it = iter(ds)
        results = {}
        for _ in range(elements_to_read):
            val = next(it).numpy()
            if val not in results:
                results[val] = 0
            results[val] += 1
        for i in range(num_elements):
            self.assertGreater(results[i], elements_to_read / num_elements / 2)

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeDistributedEpochForeverRepeatFewElements(self):
        num_workers = 5
        cluster = self.create_cluster(num_workers=num_workers)
        # Use fewer elements than workers, so that some workers get zero
        # elements on the first repetition.
        num_elements = 1
        ds = dataset_ops.Dataset.range(num_elements).repeat()
        ds = self.make_distributed_dataset(ds,
                                           cluster,
                                           processing_mode="distributed_epoch")
        it = iter(ds)
        for _ in range(100):
            self.assertEqual(next(it).numpy(), 0)

        # Stop all but one worker and check that we can still read.
        for i in range(num_workers - 1):
            cluster.workers[i]._stop()
        for _ in range(100):
            self.assertEqual(next(it).numpy(), 0)

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeDistributedEpochShuffleAndRepeat(self):
        cluster = self.create_cluster(num_workers=2)
        num_repeats = 5
        num_elements = 20
        ds = dataset_ops.Dataset.range(num_elements).shuffle(
            num_elements).repeat(num_repeats)
        ds = self.make_distributed_dataset(ds,
                                           cluster,
                                           processing_mode="distributed_epoch")
        self.assertDatasetProduces(ds,
                                   num_repeats * list(range(num_elements)),
                                   assert_items_equal=True)

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeFromInterleave(self):
        cluster = self.create_cluster(num_workers=1)
        ds = dataset_ops.Dataset.range(2)

        def interleave_fn(_):
            dataset = dataset_ops.Dataset.range(2)
            self.make_distributed_dataset(dataset, cluster)
            return dataset

        ds = ds.interleave(interleave_fn, cycle_length=2)
        self.assertDatasetProduces(ds, [0, 0, 1, 1])

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeDistributedEpoch(self):
        cluster = self.create_cluster(num_workers=2)
        num_elements = 100
        ds = dataset_ops.Dataset.range(num_elements)
        ds = self.make_distributed_dataset(ds,
                                           cluster,
                                           processing_mode="distributed_epoch")
        self.assertDatasetProduces(ds,
                                   list(range(num_elements)),
                                   assert_items_equal=True)

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeNonStringAddresses(self):
        ds = dataset_ops.Dataset.range(10)
        with self.assertRaisesRegex(ValueError, "service must be a string"):
            ds = ds.apply(
                data_service_ops.distribute(processing_mode="parallel_epochs",
                                            service=1))

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeEmptyAddress(self):
        ds = dataset_ops.Dataset.range(10)
        with self.assertRaisesWithLiteralMatch(ValueError,
                                               "service must not be empty"):
            ds = ds.apply(
                data_service_ops.distribute(processing_mode="parallel_epochs",
                                            service=""))

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeInvalidProcessingMode(self):
        ds = dataset_ops.Dataset.range(10)
        with self.assertRaisesRegex(ValueError,
                                    "invalid is not a valid processing mode"):
            ds = ds.apply(
                data_service_ops.distribute(processing_mode="invalid",
                                            service="grpc://localhost:5000"))

    @combinations.generate(test_base.eager_only_combinations())
    def testZipDifferentProcessingModesDatasets(self):
        cluster = self.create_cluster(num_workers=1)
        num_elements = 100
        ds1 = dataset_ops.Dataset.range(num_elements)
        ds1 = self.make_distributed_dataset(
            ds1, cluster, processing_mode="distributed_epoch")
        ds2 = dataset_ops.Dataset.range(num_elements)
        ds2 = self.make_distributed_dataset(ds2,
                                            cluster,
                                            processing_mode="parallel_epochs")
        ds = dataset_ops.Dataset.zip((ds1, ds2))
        self.assertDatasetProduces(ds,
                                   list(
                                       zip(range(num_elements),
                                           range(num_elements))),
                                   assert_items_equal=True)

    @combinations.generate(test_base.eager_only_combinations())
    def testZipDifferentProcessingModesDatasetsSharedJobName(self):
        cluster = self.create_cluster(num_workers=1)
        num_elements = 100
        ds1 = dataset_ops.Dataset.range(num_elements)
        ds1 = self.make_distributed_dataset(
            ds1,
            cluster,
            processing_mode="distributed_epoch",
            job_name="job_name")
        ds2 = dataset_ops.Dataset.range(num_elements)
        ds2 = self.make_distributed_dataset(ds2,
                                            cluster,
                                            processing_mode="parallel_epochs",
                                            job_name="job_name")
        ds = dataset_ops.Dataset.zip((ds1, ds2))
        with self.assertRaisesRegex(errors.FailedPreconditionError,
                                    "but there is already an existing job"):
            self.getDatasetOutput(ds)

    @combinations.generate(test_base.eager_only_combinations())
    def testFromDatasetId(self):
        cluster = self.create_cluster(num_workers=1)

        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements)
        dataset_id = data_service_ops.register_dataset(cluster.target, ds)
        from_dataset_id_ds = data_service_ops.from_dataset_id(
            "parallel_epochs", cluster.target, dataset_id, ds.element_spec)
        self.assertDatasetProduces(from_dataset_id_ds,
                                   list(range(num_elements)))

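register_dataset/from_dataset_id decouple defining a dataset from consuming it: one process registers the dataset with the dispatcher and shares the returned id, and any number of consumers attach by id, supplying the element_spec since they never see the original dataset object. A short public-API sketch, assuming the placeholder dispatcher address below:

import tensorflow as tf

service = "grpc://dispatcher-host:5000"  # placeholder address
ds = tf.data.Dataset.range(10)

dataset_id = tf.data.experimental.service.register_dataset(service, ds)
consumer = tf.data.experimental.service.from_dataset_id(
    processing_mode="parallel_epochs",
    service=service,
    dataset_id=dataset_id,
    element_spec=ds.element_spec)
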
    @combinations.generate(test_base.eager_only_combinations())
    def testFromDatasetIdMultipleComponents(self):
        cluster = self.create_cluster(num_workers=1)

        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements)
        ds = dataset_ops.Dataset.zip({"a": (ds, ds), "b": ds})
        dataset_id = data_service_ops.register_dataset(cluster.target, ds)
        from_dataset_id_ds = data_service_ops.from_dataset_id(
            "parallel_epochs", cluster.target, dataset_id, ds.element_spec)
        output = self.getDatasetOutput(from_dataset_id_ds)
        for i in range(num_elements):
            self.assertEqual(i, output[i]["a"][0])
            self.assertEqual(i, output[i]["a"][1])
            self.assertEqual(i, output[i]["b"])

    @combinations.generate(test_base.eager_only_combinations())
    def testFromDatasetIdWrongElementSpec(self):
        cluster = self.create_cluster(num_workers=1)

        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements)
        dataset_id = data_service_ops.register_dataset(cluster.target, ds)
        wrong_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
        from_dataset_id_ds = data_service_ops.from_dataset_id(
            "parallel_epochs", cluster.target, dataset_id, wrong_spec)
        with self.assertRaisesRegex(errors.FailedPreconditionError,
                                    "Expected a tensor of type variant"):
            self.evaluate(self.getNext(from_dataset_id_ds)())

    @combinations.generate(test_base.eager_only_combinations())
    def testFromDatasetIdNotRegistered(self):
        cluster = self.create_cluster(num_workers=1)

        dataset_id = 0
        element_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
        from_dataset_id_ds = data_service_ops.from_dataset_id(
            "parallel_epochs", cluster.target, dataset_id, element_spec)
        with self.assertRaisesRegex(errors.NotFoundError, "Dataset id"):
            self.evaluate(self.getNext(from_dataset_id_ds)())

    @combinations.generate(test_base.default_test_combinations())
    def testCancellation(self):
        self.skipTest("b/162521601")
        sleep_microseconds = int(1e6) * 1000

        cluster = self.create_cluster(num_workers=1)
        # Create a dataset which produces the first element quickly, and the second
        # element slowly. Fetching the first element triggers prefetching of the
        # second element, which we should be able to cancel.
        slow = dataset_ops.Dataset.range(1)
        slow = slow.apply(testing.sleep(sleep_microseconds))
        ds = dataset_ops.Dataset.range(1).concatenate(slow)
        ds = self.make_distributed_dataset(ds, cluster)
        ds = ds.prefetch(1)
        get_next = self.getNext(ds, requires_initialization=True)
        self.assertEqual(0, self.evaluate(get_next()))
        # Without properly implemented cancellation, we will hang here while trying
        # to garbage collect the dataset iterator.

    @combinations.generate(test_base.eager_only_combinations())
    def testRegisterEquivalentDatasets(self):
        ds_1 = dataset_ops.Dataset.range(10)
        ds_2 = dataset_ops.Dataset.range(10)
        cluster = self.create_cluster(num_workers=1)
        id_1 = data_service_ops.register_dataset(cluster.target, ds_1)
        id_2 = data_service_ops.register_dataset(cluster.target, ds_2)
        self.assertEqual(id_1.numpy(), id_2.numpy())

    @combinations.generate(test_base.eager_only_combinations())
    def testRegisterDifferentDatasets(self):
        ds_1 = dataset_ops.Dataset.range(10)
        ds_2 = dataset_ops.Dataset.range(20)
        cluster = self.create_cluster(num_workers=1)
        id_1 = data_service_ops.register_dataset(cluster.target, ds_1)
        id_2 = data_service_ops.register_dataset(cluster.target, ds_2)
        self.assertNotEqual(id_1.numpy(), id_2.numpy())

    @combinations.generate(test_base.default_test_combinations())
    def testDistributedEpochOnZippedDataset(self):
        ds_1 = dataset_ops.Dataset.range(10)
        ds_2 = dataset_ops.Dataset.range(10)
        cluster = self.create_cluster(num_workers=1)

        ds_3 = dataset_ops.Dataset.zip((ds_1, ds_2))
        ds_3 = self.make_distributed_dataset(
            ds_3, cluster, processing_mode="distributed_epoch")

        error_regex = ("Cannot create a split provider for dataset "
                       "of type ZipDataset")
        with self.assertRaisesRegex(errors.UnimplementedError, error_regex):
            self.getDatasetOutput(ds_3, requires_initialization=True)

    @combinations.generate(test_base.default_test_combinations())
    def testDistributedEpochOnDistributedDataset(self):
        cluster_1 = self.create_cluster(num_workers=1)
        cluster_2 = self.create_cluster(num_workers=1)
        num_sizes = 10
        size_repeats = 5
        numbers = [1 * i for i in range(num_sizes)] * size_repeats
        ds = dataset_ops.Dataset.from_tensor_slices(numbers)
        ds = self.make_distributed_dataset(ds,
                                           cluster_1,
                                           processing_mode="parallel_epochs")
        ds = ds.map(lambda x: x + 1)
        ds = self.make_distributed_dataset(ds,
                                           cluster_2,
                                           processing_mode="distributed_epoch")

        error_regex = ("Cannot create a split provider for dataset "
                       "of type DataServiceDataset")
        with self.assertRaisesRegex(errors.UnimplementedError, error_regex):
            self.getDatasetOutput(ds, requires_initialization=True)

    @combinations.generate(test_base.eager_only_combinations())
    def testTwoLevelDistribute(self):
        cluster_1_size = 3
        cluster_1 = self.create_cluster(num_workers=cluster_1_size)
        cluster_2 = self.create_cluster(num_workers=1)
        num_sizes = 10
        size_repeats = 5
        strings = ["a" * i for i in range(num_sizes)] * size_repeats
        ds = dataset_ops.Dataset.from_tensor_slices(strings)
        ds = ds.shuffle(len(strings))
        ds = self.make_distributed_dataset(ds, cluster_1)
        # Large enough so that all strings of the same size are windowed together.
        window_size = cluster_1_size * size_repeats
        batch_size = size_repeats

        def key_func(x):
            return math_ops.cast(string_ops.string_length_v2(x), dtypes.int64)

        ds = ds.apply(
            grouping.group_by_window(
                key_func=key_func,
                reduce_func=lambda _, x: x.batch(batch_size),
                window_size=window_size))
        ds = self.make_distributed_dataset(ds, cluster_2)

        it = iter(ds)
        for _ in range(num_sizes):
            element = next(it).numpy()
            for _ in range(1, cluster_1_size):
                self.assertAllEqual(next(it).numpy(), element)
        self.assertEmpty(list(it))

    @combinations.generate(
        combinations.times(test_base.eager_only_combinations()))
    def testDistributeLargeGraph(self):
        cluster = self.create_cluster(num_workers=1,
                                      work_dir=NO_WORK_DIR,
                                      fault_tolerant_mode=False)
        # Larger than default OSS grpc message size limit of 4MB.
        tensor = array_ops.ones((2, 1000, 1000), dtype=dtypes.float32)
        ds = dataset_ops.Dataset.from_tensors(tensor)
        ds = self.make_distributed_dataset(ds, cluster)
        self.assertDatasetProduces(ds, [tensor])
Example #4
class CoordinatedReadTest(data_service_test_base.TestBase,
                          parameterized.TestCase):

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(num_workers=[1, 3], num_consumers=[1, 2, 5])))
  def testBasic(self, num_workers, num_consumers):
    cluster = data_service_test_base.TestCluster(num_workers=num_workers)
    # Round robin reads can cause slow cluster shutdown.
    data_service_test_base.GLOBAL_CLUSTERS.add(cluster)
    ds = self.make_coordinated_read_dataset(cluster, num_consumers)
    ds = ds.take(100)
    results = self.getDatasetOutput(ds)
    self.checkCoordinatedReadGroups(results, num_consumers)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations()))
  def testConsumerRestart(self):
    cluster = data_service_test_base.TestCluster(num_workers=1)
    # Round robin reads can cause slow cluster shutdown.
    data_service_test_base.GLOBAL_CLUSTERS.add(cluster)
    num_consumers = 3
    ds = self.make_coordinated_read_dataset(cluster, num_consumers)
    ds = ds.take(20)
    self.getDatasetOutput(ds)
    ds2 = self.make_coordinated_read_dataset(cluster, num_consumers)
    ds2 = ds2.take(20)
    with self.assertRaisesRegex(errors.FailedPreconditionError,
                                "current round has already reached"):
      self.getDatasetOutput(ds2)

  @combinations.generate(test_base.default_test_combinations())
  def testBucketizing(self):
    # Tests a common use case for round robin reads. At each step, all
    # consumers should get batches with the same bucket size.
    cluster = data_service_test_base.TestCluster(num_workers=4)
    # Round robin reads can cause slow cluster shutdown.
    data_service_test_base.GLOBAL_CLUSTERS.add(cluster)
    num_elements = 100
    low_bucket_max = 30
    mid_bucket_max = 60
    bucket_boundaries = [low_bucket_max, mid_bucket_max]
    batch_size = 10
    num_consumer_hosts = 3
    replicas_per_consumer_host = 5
    num_consumers = num_consumer_hosts * replicas_per_consumer_host
    bucket_batch_sizes = [batch_size] * (len(bucket_boundaries) + 1)
    # Set up the dataset that will run on the tf.data workers.
    ds = dataset_ops.Dataset.range(num_elements, output_type=dtypes.int32)
    ds = ds.shuffle(num_elements)
    ds = ds.repeat()
    ds = ds.apply(
        grouping.bucket_by_sequence_length(
            lambda x: x,
            bucket_boundaries,
            bucket_batch_sizes,
            drop_remainder=True))
    ds = ds.apply(
        grouping.group_by_window(
            lambda x: math_ops.cast(x[1], dtypes.int64),
            lambda _, x: dataset_ops.Dataset.from_tensors(x),
            window_size=num_consumers))
    ds = ds.flat_map(lambda x: x)

    # Set up the per-consumer-host datasets. During each global step, we pull
    # `replicas_per_consumer_host` batches from each of these datasets.
    host_datasets = []
    for host_index in range(num_consumer_hosts):
      per_replica_datasets = []
      for i in range(replicas_per_consumer_host):
        consumer_index = host_index * replicas_per_consumer_host + i
        per_replica_datasets.append(
            self.make_distributed_dataset(
                ds,
                cluster,
                job_name="test",
                consumer_index=consumer_index,
                num_consumers=num_consumers))
      host_dataset = dataset_ops.Dataset.from_tensor_slices(
          per_replica_datasets)
      host_dataset = host_dataset.interleave(
          lambda x: x,
          cycle_length=len(per_replica_datasets),
          num_parallel_calls=len(per_replica_datasets),
          deterministic=True)
      host_datasets.append(host_dataset)

    # Use parallel interleave to read from host datasets in parallel.
    ds = dataset_ops.Dataset.from_tensor_slices(host_datasets)
    ds = ds.interleave(
        lambda x: x,
        block_length=replicas_per_consumer_host,
        cycle_length=len(host_datasets),
        num_parallel_calls=len(host_datasets),
        deterministic=True)

    num_rounds = 4
    get_next = self.getNext(ds)
    results = []
    for i in range(num_rounds * num_consumers):
      results.append(self.evaluate(get_next()))

    def get_bucket(elem):
      bucket_ind = 0
      while (bucket_ind < len(bucket_boundaries)
             and elem >= bucket_boundaries[bucket_ind]):
        bucket_ind += 1
      return bucket_ind

    # Check that the batches for each step contain elements from the same
    # bucket.
    for i in range(0, len(results), num_consumers):
      batches = results[i:i + num_consumers]
      bucket_inds = [get_bucket(batch[0]) for batch in batches]
      for bucket_ind in bucket_inds[1:]:
        self.assertEqual(
            bucket_inds[0], bucket_ind,
            "Batches: {}, Buckets: {}".format(batches, bucket_inds))

  @combinations.generate(test_base.v1_only_combinations())
  def testFiniteV1(self):
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 100
    ds = dataset_ops.Dataset.range(num_elements)
    ds = self.make_distributed_dataset(
        ds, cluster, job_name="test", consumer_index=0, num_consumers=1)

    with self.assertRaisesRegex(
        errors.FailedPreconditionError, "Encountered end of sequence on a "
        "round-robin read iterator"):
      self.getDatasetOutput(ds)

  @combinations.generate(test_base.v2_only_combinations())
  def testFiniteV2(self):
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 100
    ds = dataset_ops.Dataset.range(num_elements)
    ds = self.make_distributed_dataset(
        ds, cluster, job_name="test", consumer_index=0, num_consumers=1)

    with self.assertRaisesRegex(
        errors.FailedPreconditionError, "Round robin reads "
        "require that the input dataset has infinite "
        "cardinality, but the dataset has cardinality " + str(num_elements)):
      self.getDatasetOutput(ds)
Example #5
class CsvDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
    def _setup_files(self, inputs, linebreak='\n', compression_type=None):
        filenames = []
        for i, file_rows in enumerate(inputs):
            fn = os.path.join(self.get_temp_dir(), 'temp_%d.csv' % i)
            contents = linebreak.join(file_rows).encode('utf-8')
            if compression_type is None:
                with open(fn, 'wb') as f:
                    f.write(contents)
            elif compression_type == 'GZIP':
                with gzip.GzipFile(fn, 'wb') as f:
                    f.write(contents)
            elif compression_type == 'ZLIB':
                contents = zlib.compress(contents)
                with open(fn, 'wb') as f:
                    f.write(contents)
            else:
                raise ValueError('Unsupported compression_type',
                                 compression_type)
            filenames.append(fn)
        return filenames

    def _make_test_datasets(self, inputs, **kwargs):
        # Test by comparing its output to what we could get with map->decode_csv
        filenames = self._setup_files(inputs)
        dataset_expected = core_readers.TextLineDataset(filenames)
        dataset_expected = dataset_expected.map(
            lambda l: parsing_ops.decode_csv(l, **kwargs))
        dataset_actual = readers.CsvDataset(filenames, **kwargs)
        return (dataset_actual, dataset_expected)

    def _test_by_comparison(self, inputs, **kwargs):
        """Checks that CsvDataset is equiv to TextLineDataset->map(decode_csv)."""
        dataset_actual, dataset_expected = self._make_test_datasets(
            inputs, **kwargs)
        self.assertDatasetsEqual(dataset_actual, dataset_expected)

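The comparison helpers rely on CsvDataset behaving like reading each line of text and decoding it. In public-API terms, the two pipelines being compared look roughly like this (the file path and the four-int schema are placeholder assumptions):

import tensorflow as tf

filenames = ["/tmp/example.csv"]   # placeholder path
record_defaults = [[0]] * 4        # assume four int32 columns

# Reference pipeline: read lines, then decode each one.
expected = tf.data.TextLineDataset(filenames).map(
    lambda line: tf.io.decode_csv(line, record_defaults=record_defaults))

# The fused reader that the tests check against the reference.
actual = tf.data.experimental.CsvDataset(
    filenames, record_defaults=record_defaults)
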
    def _test_dataset(
            self,
            inputs,
            expected_output=None,
            expected_err_re=None,
            linebreak='\n',
            compression_type=None,  # Used for both setup and parsing
            **kwargs):
        """Checks that elements produced by CsvDataset match expected output."""
        # Convert str type because py3 tf strings are bytestrings
        filenames = self._setup_files(inputs, linebreak, compression_type)
        kwargs['compression_type'] = compression_type
        if expected_err_re is not None:
            # Verify that OpError is produced as expected
            with self.assertRaisesOpError(expected_err_re):
                dataset = readers.CsvDataset(filenames, **kwargs)
                self.getDatasetOutput(dataset)
        else:
            dataset = readers.CsvDataset(filenames, **kwargs)
            expected_output = [
                tuple(
                    v.encode('utf-8') if isinstance(v, str) else v for v in op)
                for op in expected_output
            ]
            self.assertDatasetProduces(dataset, expected_output)

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_requiredFields(self):
        record_defaults = [[]] * 4
        inputs = [['1,2,3,4']]
        self._test_by_comparison(inputs, record_defaults=record_defaults)

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_int(self):
        record_defaults = [[0]] * 4
        inputs = [['1,2,3,4', '5,6,7,8']]
        self._test_by_comparison(inputs, record_defaults=record_defaults)

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_float(self):
        record_defaults = [[0.0]] * 4
        inputs = [['1.0,2.1,3.2,4.3', '5.4,6.5,7.6,8.7']]
        self._test_by_comparison(inputs, record_defaults=record_defaults)

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_string(self):
        record_defaults = [['']] * 4
        inputs = [['1.0,2.1,hello,4.3', '5.4,6.5,goodbye,8.7']]
        self._test_by_comparison(inputs, record_defaults=record_defaults)

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_withEmptyFields(self):
        record_defaults = [[0]] * 4
        inputs = [[',,,', '1,1,1,', ',2,2,2']]
        self._test_dataset(inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
                           record_defaults=record_defaults)

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_errWithUnquotedQuotes(self):
        record_defaults = [['']] * 3
        inputs = [['1,2"3,4']]
        self._test_dataset(
            inputs,
            expected_err_re='Unquoted fields cannot have quotes inside',
            record_defaults=record_defaults)

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_errWithUnescapedQuotes(self):
        record_defaults = [['']] * 3
        inputs = [['"a"b","c","d"']]
        self._test_dataset(
            inputs,
            expected_err_re=
            'Quote inside a string has to be escaped by another quote',
            record_defaults=record_defaults)

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_ignoreErrWithUnescapedQuotes(self):
        record_defaults = [['']] * 3
        inputs = [['1,"2"3",4', '1,"2"3",4",5,5', 'a,b,"c"d"', 'e,f,g']]
        filenames = self._setup_files(inputs)
        dataset = readers.CsvDataset(filenames,
                                     record_defaults=record_defaults)
        dataset = dataset.apply(error_ops.ignore_errors())
        self.assertDatasetProduces(dataset, [(b'e', b'f', b'g')])

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_ignoreErrWithUnquotedQuotes(self):
        record_defaults = [['']] * 3
        inputs = [['1,2"3,4', 'a,b,c"d', '9,8"7,6,5', 'e,f,g']]
        filenames = self._setup_files(inputs)
        dataset = readers.CsvDataset(filenames,
                                     record_defaults=record_defaults)
        dataset = dataset.apply(error_ops.ignore_errors())
        self.assertDatasetProduces(dataset, [(b'e', b'f', b'g')])

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_withNoQuoteDelimAndUnquotedQuotes(self):
        record_defaults = [['']] * 3
        inputs = [['1,2"3,4']]
        self._test_by_comparison(inputs,
                                 record_defaults=record_defaults,
                                 use_quote_delim=False)

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_mixedTypes(self):
        record_defaults = [
            constant_op.constant([], dtype=dtypes.int32),
            constant_op.constant([], dtype=dtypes.float32),
            constant_op.constant([], dtype=dtypes.string),
            constant_op.constant([], dtype=dtypes.float64)
        ]
        inputs = [['1,2.1,3.2,4.3', '5,6.5,7.6,8.7']]
        self._test_by_comparison(inputs, record_defaults=record_defaults)

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_withUseQuoteDelimFalse(self):
        record_defaults = [['']] * 4
        inputs = [['1,2,"3,4"', '"5,6",7,8']]
        self._test_by_comparison(inputs,
                                 record_defaults=record_defaults,
                                 use_quote_delim=False)

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_withFieldDelim(self):
        record_defaults = [[0]] * 4
        inputs = [['1:2:3:4', '5:6:7:8']]
        self._test_by_comparison(inputs,
                                 record_defaults=record_defaults,
                                 field_delim=':')

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_withNaValue(self):
        record_defaults = [[0]] * 4
        inputs = [['1,NA,3,4', 'NA,6,7,8']]
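        # Fields equal to the `na_value` string are treated as missing and
        # replaced by the column's record default (0 here).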
        self._test_by_comparison(inputs,
                                 record_defaults=record_defaults,
                                 na_value='NA')

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_withSelectCols(self):
        record_defaults = [['']] * 2
        inputs = [['1,2,3,4', '"5","6","7","8"']]
        self._test_by_comparison(inputs,
                                 record_defaults=record_defaults,
                                 select_cols=[1, 2])

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_withSelectColsTooHigh(self):
        record_defaults = [[0]] * 2
        inputs = [['1,2,3,4', '5,6,7,8']]
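        # Column index 4 is out of range for a 4-column file, so only one of the
        # two selected columns can be filled and parsing should fail.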
        self._test_dataset(
            inputs,
            expected_err_re='Expect 2 fields but have 1 in record',
            record_defaults=record_defaults,
            select_cols=[3, 4])

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_withOneCol(self):
        record_defaults = [['NA']]
        inputs = [['0', '', '2']]
        self._test_dataset(inputs, [['0'], ['NA'], ['2']],
                           record_defaults=record_defaults)

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_withMultipleFiles(self):
        record_defaults = [[0]] * 4
        inputs = [['1,2,3,4', '5,6,7,8'], ['5,6,7,8']]
        self._test_by_comparison(inputs, record_defaults=record_defaults)

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_withLeadingAndTrailingSpaces(self):
        record_defaults = [[0.0]] * 4
        inputs = [['0, 1, 2, 3']]
        expected = [[0.0, 1.0, 2.0, 3.0]]
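        # Whitespace around numeric fields is tolerated when converting to float.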
        self._test_dataset(inputs, expected, record_defaults=record_defaults)

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_errorWithMissingDefault(self):
        record_defaults = [[]] * 2
        inputs = [['0,']]
        self._test_dataset(
            inputs,
            expected_err_re='Field 1 is required but missing in record!',
            record_defaults=record_defaults)

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_errorWithFewerDefaultsThanFields(self):
        record_defaults = [[0.0]] * 2
        inputs = [['0,1,2,3']]
        self._test_dataset(
            inputs,
            expected_err_re='Expect 2 fields but have more in record',
            record_defaults=record_defaults)

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_errorWithMoreDefaultsThanFields(self):
        record_defaults = [[0.0]] * 5
        inputs = [['0,1,2,3']]
        self._test_dataset(
            inputs,
            expected_err_re='Expect 5 fields but have 4 in record',
            record_defaults=record_defaults)

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_withHeader(self):
        record_defaults = [[0]] * 2
        inputs = [['col1,col2', '1,2']]
        expected = [[1, 2]]
        self._test_dataset(
            inputs,
            expected,
            record_defaults=record_defaults,
            header=True,
        )

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_withHeaderAndNoRecords(self):
        record_defaults = [[0]] * 2
        inputs = [['col1,col2']]
        expected = []
        self._test_dataset(
            inputs,
            expected,
            record_defaults=record_defaults,
            header=True,
        )

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_errorWithHeaderEmptyFile(self):
        record_defaults = [[0]] * 2
        inputs = [[]]
        expected_err_re = "Can't read header of file"
        self._test_dataset(
            inputs,
            expected_err_re=expected_err_re,
            record_defaults=record_defaults,
            header=True,
        )

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_withEmptyFile(self):
        record_defaults = [['']] * 2
        inputs = [['']]  # Empty file
        self._test_dataset(inputs,
                           expected_output=[],
                           record_defaults=record_defaults)

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_errorWithEmptyRecord(self):
        record_defaults = [['']] * 2
        inputs = [['', '1,2']]  # First record is empty
        self._test_dataset(
            inputs,
            expected_err_re='Expect 2 fields but have 1 in record',
            record_defaults=record_defaults)

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_withChainedOps(self):
        # Test that a single dataset can create multiple iterators correctly:
        # `repeat` creates multiple iterators from the same underlying C++ Dataset.
        record_defaults = [[0]] * 4
        inputs = [['1,,3,4', '5,6,,8']]
        ds_actual, ds_expected = self._make_test_datasets(
            inputs, record_defaults=record_defaults)
        self.assertDatasetsEqual(
            ds_actual.repeat(5).prefetch(1),
            ds_expected.repeat(5).prefetch(1))

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_withTypeDefaults(self):
        # Test using dtypes as record_defaults for required fields.
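        # Passing a dtype (rather than a default value) marks the column as
        # required while still fixing its output type explicitly.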
        record_defaults = [dtypes.float32, [0.0]]
        inputs = [['1.0,2.0', '3.0,4.0']]
        self._test_dataset(
            inputs,
            [[1.0, 2.0], [3.0, 4.0]],
            record_defaults=record_defaults,
        )

    @combinations.generate(test_base.default_test_combinations())
    def testMakeCsvDataset_fieldOrder(self):
        data = [[
            '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19',
            '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19'
        ]]
        file_path = self._setup_files(data)

        ds = readers.make_csv_dataset(file_path,
                                      batch_size=1,
                                      shuffle=False,
                                      num_epochs=1)
        nxt = self.getNext(ds)

        result = list(self.evaluate(nxt()).values())
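        # make_csv_dataset yields an (ordered) dict of column -> value; the values
        # should come out in the same left-to-right order as the columns appear in
        # the file, which is what the sorted-order check below verifies.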

        self.assertEqual(result, sorted(result))

    ## The following tests exercise parsing logic for quoted fields

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_withQuoted(self):
        record_defaults = [['']] * 4
        inputs = [['"a","b","c :)","d"', '"e","f","g :(","h"']]
        self._test_by_comparison(inputs, record_defaults=record_defaults)

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_withOneColAndQuotes(self):
        record_defaults = [['']]
        inputs = [['"0"', '"1"', '"2"']]
        self._test_dataset(inputs, [['0'], ['1'], ['2']],
                           record_defaults=record_defaults)

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_withNewLine(self):
        # In this case, we expect it to behave differently from
        # TextLineDataset->map(decode_csv), since that flow splits records on
        # newlines before parsing and so mishandles newlines inside quoted fields.
        record_defaults = [['']] * 4
        inputs = [['a,b,"""c""\n0","d\ne"', 'f,g,h,i']]
        expected = [['a', 'b', '"c"\n0', 'd\ne'], ['f', 'g', 'h', 'i']]
        self._test_dataset(inputs, expected, record_defaults=record_defaults)

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_withNewLineInUnselectedCol(self):
        record_defaults = [['']]
        inputs = [['1,"2\n3",4', '5,6,7']]
        self._test_dataset(inputs,
                           expected_output=[['1'], ['5']],
                           record_defaults=record_defaults,
                           select_cols=[0])

    @combinations.generate(test_base.v2_only_combinations())
    def testCsvDataset_withExcludeCol(self):
        record_defaults = [['']]
        inputs = [['1,2,3', '5,6,7']]
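        # exclude_cols is the complement of select_cols: the listed columns are
        # dropped, so record_defaults only covers the single remaining column.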
        self._test_dataset(inputs,
                           expected_output=[['1'], ['5']],
                           record_defaults=record_defaults,
                           exclude_cols=[1, 2])

    @combinations.generate(test_base.v2_only_combinations())
    def testCsvDataset_withSelectandExcludeCol(self):
        record_defaults = [['']]
        inputs = [['1,2,3', '5,6,7']]
        self._test_dataset(
            inputs,
            expected_err_re=
            'Either select_cols or exclude_cols should be empty',
            record_defaults=record_defaults,
            select_cols=[0],
            exclude_cols=[1, 2])

    @combinations.generate(test_base.v2_only_combinations())
    def testCsvDataset_withExcludeColandRecordDefaultsTooLow(self):
        record_defaults = [['']]
        inputs = [['1,2,3', '5,6,7']]
        self._test_dataset(
            inputs,
            expected_err_re='Expect 1 fields but have more in record',
            record_defaults=record_defaults,
            exclude_cols=[0])

    @combinations.generate(test_base.v2_only_combinations())
    def testCsvDataset_withExcludeColandRecordDefaultsTooHigh(self):
        record_defaults = [['']] * 3
        inputs = [['1,2,3', '5,6,7']]
        self._test_dataset(
            inputs,
            expected_err_re='Expect 3 fields but have 2 in record',
            record_defaults=record_defaults,
            exclude_cols=[0])

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_withMultipleNewLines(self):
        # In this case, we expect it to behave differently from
        # TextLineDataset->map(decode_csv), since that flow splits records on
        # newlines before parsing and so mishandles newlines inside quoted fields.
        record_defaults = [['']] * 4
        inputs = [['a,"b\n\nx","""c""\n \n0","d\ne"', 'f,g,h,i']]
        expected = [['a', 'b\n\nx', '"c"\n \n0', 'd\ne'], ['f', 'g', 'h', 'i']]
        self._test_dataset(inputs, expected, record_defaults=record_defaults)

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_errorWithTerminateMidRecord(self):
        record_defaults = [['']] * 4
        inputs = [['a,b,c,"a']]
        self._test_dataset(
            inputs,
            expected_err_re=
            'Reached end of file without closing quoted field in record',
            record_defaults=record_defaults)

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_withEscapedQuotes(self):
        record_defaults = [['']] * 4
        inputs = [['1.0,2.1,"she said: ""hello""",4.3', '5.4,6.5,goodbye,8.7']]
        self._test_by_comparison(inputs, record_defaults=record_defaults)


    ## Testing that parsing works with all buffer sizes, quoted/unquoted fields,
    ## and different types of line breaks

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_withInvalidBufferSize(self):
        record_defaults = [['']] * 4
        inputs = [['a,b,c,d']]
        self._test_dataset(inputs,
                           expected_err_re='buffer_size should be positive',
                           record_defaults=record_defaults,
                           buffer_size=0)

    def _test_dataset_on_buffer_sizes(self,
                                      inputs,
                                      expected,
                                      linebreak,
                                      record_defaults,
                                      compression_type=None,
                                      num_sizes_to_test=20):
        # Testing reading with a range of buffer sizes that should all work.
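        # Small buffers force quoted fields and multi-character line breaks to
        # straddle buffer boundaries, which is what these tests try to stress.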
        for i in list(range(1, 1 + num_sizes_to_test)) + [None]:
            self._test_dataset(inputs,
                               expected,
                               linebreak=linebreak,
                               compression_type=compression_type,
                               record_defaults=record_defaults,
                               buffer_size=i)

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_withLF(self):
        record_defaults = [['NA']] * 3
        inputs = [['abc,def,ghi', '0,1,2', ',,']]
        expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']]
        self._test_dataset_on_buffer_sizes(inputs,
                                           expected,
                                           linebreak='\n',
                                           record_defaults=record_defaults)

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_withCR(self):
        # Test that when the line separator is '\r', parsing works with all buffer
        # sizes
        record_defaults = [['NA']] * 3
        inputs = [['abc,def,ghi', '0,1,2', ',,']]
        expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']]
        self._test_dataset_on_buffer_sizes(inputs,
                                           expected,
                                           linebreak='\r',
                                           record_defaults=record_defaults)

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_withCRLF(self):
        # Test that when the line separator is '\r\n', parsing works with all buffer
        # sizes
        record_defaults = [['NA']] * 3
        inputs = [['abc,def,ghi', '0,1,2', ',,']]
        expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']]
        self._test_dataset_on_buffer_sizes(inputs,
                                           expected,
                                           linebreak='\r\n',
                                           record_defaults=record_defaults)

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_withBufferSizeAndQuoted(self):
        record_defaults = [['NA']] * 3
        inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
        expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
                    ['NA', 'NA', 'NA']]
        self._test_dataset_on_buffer_sizes(inputs,
                                           expected,
                                           linebreak='\n',
                                           record_defaults=record_defaults)

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_withCRAndQuoted(self):
        # Test that when the line separator is '\r', parsing works with all buffer
        # sizes
        record_defaults = [['NA']] * 3
        inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
        expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
                    ['NA', 'NA', 'NA']]
        self._test_dataset_on_buffer_sizes(inputs,
                                           expected,
                                           linebreak='\r',
                                           record_defaults=record_defaults)

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_withCRLFAndQuoted(self):
        # Test that when the line separator is '\r\n', parsing works with all buffer
        # sizes
        record_defaults = [['NA']] * 3
        inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
        expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
                    ['NA', 'NA', 'NA']]
        self._test_dataset_on_buffer_sizes(inputs,
                                           expected,
                                           linebreak='\r\n',
                                           record_defaults=record_defaults)

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_withGzipCompressionType(self):
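        # compression_type='GZIP' is forwarded to the reader; the test inputs are
        # presumably written GZIP-compressed by the file-setup helper, so
        # CsvDataset has to decompress them transparently while parsing.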
        record_defaults = [['NA']] * 3
        inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
        expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
                    ['NA', 'NA', 'NA']]
        self._test_dataset_on_buffer_sizes(inputs,
                                           expected,
                                           linebreak='\r\n',
                                           compression_type='GZIP',
                                           record_defaults=record_defaults)

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_withZlibCompressionType(self):
        record_defaults = [['NA']] * 3
        inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
        expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
                    ['NA', 'NA', 'NA']]
        self._test_dataset_on_buffer_sizes(inputs,
                                           expected,
                                           linebreak='\r\n',
                                           compression_type='ZLIB',
                                           record_defaults=record_defaults)

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_withScalarDefaults(self):
        record_defaults = [constant_op.constant(0, dtype=dtypes.int64)] * 4
        inputs = [[',,,', '1,1,1,', ',2,2,2']]
        self._test_dataset(inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
                           record_defaults=record_defaults)

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_with2DDefaults(self):
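        # Record defaults must be scalars or rank-1 tensors; a [[0]] constant is
        # rank 2, so dataset construction should fail. Eager mode surfaces an
        # InvalidArgumentError, while graph construction fails earlier with a
        # ValueError from shape inference.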
        record_defaults = [constant_op.constant([[0]], dtype=dtypes.int64)] * 4
        inputs = [[',,,', '1,1,1,', ',2,2,2']]

        if context.executing_eagerly():
            err_spec = errors.InvalidArgumentError, (
                'Each record default should be at '
                'most rank 1')
        else:
            err_spec = ValueError, 'Shape must be at most rank 1 but is rank 2'

        with self.assertRaisesWithPredicateMatch(*err_spec):
            self._test_dataset(inputs,
                               [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
                               record_defaults=record_defaults)

    @combinations.generate(test_base.default_test_combinations())
    def testCsvDataset_immutableParams(self):
        inputs = [['a,b,c', '1,2,3', '4,5,6']]
        filenames = self._setup_files(inputs)
        select_cols = ['a', 'c']
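        # make_csv_dataset should not mutate the caller's `select_columns` list
        # (e.g. by sorting it or rewriting names to indices in place).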
        _ = readers.make_csv_dataset(filenames,
                                     batch_size=1,
                                     select_columns=select_cols)
        self.assertAllEqual(select_cols, ['a', 'c'])