Beispiel #1
0
class EnablingTF2Behavior(test.TestCase, parameterized.TestCase):
    """Checks that toggling V2 behavior is visible through both status queries.

    Each test inspects `tf2.enabled()` and `_pywrap_tf2.is_enabled()` before
    and after calling `v2_compat.enable_v2_behavior()` and
    `v2_compat.disable_v2_behavior()`.
    """

    def __init__(self, methodName):
        super().__init__(methodName)
        # NOTE(review): presumably consulted by the base test case when it
        # decides whether to seed RNGs -- confirm against `test.TestCase`.
        self._set_default_seed = False

    def _assert_tf2_enabled(self):
        """Asserts that both status queries report V2 behavior as on."""
        self.assertTrue(tf2.enabled())
        self.assertTrue(_pywrap_tf2.is_enabled())

    def _assert_tf2_disabled(self):
        """Asserts that both status queries report V2 behavior as off."""
        self.assertFalse(tf2.enabled())
        self.assertFalse(_pywrap_tf2.is_enabled())

    @combinations.generate(test_base.v1_only_combinations())
    def test_tf1_enable_tf2_behaviour(self):
        """Starts disabled, then enables and disables again."""
        self._assert_tf2_disabled()

        v2_compat.enable_v2_behavior()
        self._assert_tf2_enabled()

        v2_compat.disable_v2_behavior()
        self._assert_tf2_disabled()

    @combinations.generate(test_base.v1_only_combinations())
    def test_tf1_disable_tf2_behaviour(self):
        """Starts disabled, disables (no change expected), then enables."""
        self._assert_tf2_disabled()

        v2_compat.disable_v2_behavior()
        self._assert_tf2_disabled()

        v2_compat.enable_v2_behavior()
        self._assert_tf2_enabled()

    @combinations.generate(test_base.v2_only_combinations())
    def test_tf2_enable_tf2_behaviour(self):
        """Starts enabled, enables (no change expected), then disables."""
        self._assert_tf2_enabled()

        v2_compat.enable_v2_behavior()
        self._assert_tf2_enabled()

        v2_compat.disable_v2_behavior()
        self._assert_tf2_disabled()

    @combinations.generate(test_base.v2_only_combinations())
    def test_tf2_disable_tf2_behaviour(self):
        """Starts enabled, then disables and enables again."""
        self._assert_tf2_enabled()

        v2_compat.disable_v2_behavior()
        self._assert_tf2_disabled()

        v2_compat.enable_v2_behavior()
        self._assert_tf2_enabled()
Beispiel #2
0
class FromSparseTensorSlicesCheckpointTest(
        checkpoint_test_base.CheckpointTestBase, parameterized.TestCase):
    """Runs the checkpoint verification on a sparse-tensor-slices dataset."""

    def _build_sparse_tensor_slice_dataset(self, slices):
        """Builds a `from_sparse_tensor_slices` dataset from nested lists.

        Args:
          slices: A list of rows, each a list of numbers. Every value of row
            `r` at position `c` becomes a sparse entry at index `[r, c]`.

        Returns:
          The dataset returned by `Dataset.from_sparse_tensor_slices` for a
          `SparseTensor` whose dense shape is
          `[len(slices), longest row length + 1]`.
        """
        row_lengths = [len(row) for row in slices]
        # pylint: disable=g-complex-comprehension
        coordinates = [[row_index, col_index]
                       for row_index, row_length in enumerate(row_lengths)
                       for col_index in range(row_length)]
        flat_values = [value for row in slices for value in row]
        # pylint: enable=g-complex-comprehension
        indices = np.array(coordinates, dtype=np.int64)
        values = np.array(flat_values, dtype=np.float64)
        dense_shape = np.array([len(slices), max(row_lengths) + 1],
                               dtype=np.int64)
        return dataset_ops.Dataset.from_sparse_tensor_slices(
            sparse_tensor.SparseTensor(indices, values, dense_shape))

    @combinations.generate(
        combinations.times(test_base.v1_only_combinations(),
                           checkpoint_test_base.default_test_combinations()))
    def test(self, verify_fn):
        """Invokes `verify_fn` on a nine-row dataset with some empty rows.

        Args:
          verify_fn: Verification callable supplied by the test combination.
        """
        slices = [[1., 2., 3.], [1.], [1.], [1., 2.], [], [1., 2.], [], [], []]

        def build_dataset():
            return self._build_sparse_tensor_slice_dataset(slices)

        verify_fn(self, build_dataset, num_outputs=9, sparse_tensors=True)
class DataServiceOpsTest(data_service_test_base.TestBase,
                         parameterized.TestCase):
    @combinations.generate(
        combinations.times(test_base.eager_only_combinations(),
                           data_service_test_base.all_cluster_configurations())
    )
    def testDistributeBasic(self, work_dir, fault_tolerant_mode):
        """Reads a distributed range dataset once and checks its order.

        Args:
          work_dir: Forwarded to `create_cluster` as `work_dir`.
          fault_tolerant_mode: Forwarded to `create_cluster` as
            `fault_tolerant_mode`.
        """
        num_elements = 10
        cluster = self.create_cluster(
            num_workers=1,
            work_dir=work_dir,
            fault_tolerant_mode=fault_tolerant_mode)
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        results = []
        for elem in ds:
            results.append(elem.numpy())
        self.assertEqual(list(range(num_elements)), results)

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeSparse(self):
        """Distributes one sparse tensor and checks its dense form."""
        cluster = self.create_cluster(num_workers=1)
        sparse_values = constant_op.constant([0], dtype=dtypes.int32)
        element = sparse_tensor.SparseTensor(
            indices=[[0]], values=sparse_values, dense_shape=[1])
        ds = self.make_distributed_dataset(
            dataset_ops.Dataset.from_tensors(element), cluster)
        results = []
        for elem in ds:
            results.append(sparse_ops.sparse_tensor_to_dense(elem))
        self.assertAllEqual(results, [[0]])

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeRagged(self):
        """Distributes ragged batches and checks their dense conversions."""
        cluster = self.create_cluster(num_workers=1)
        ds = (dataset_ops.Dataset.from_tensor_slices([1, 5, 3, 2, 8])
              .map(math_ops.range)
              .apply(batching.dense_to_ragged_batch(2)))
        ds = self.make_distributed_dataset(ds, cluster)
        results = [elem.to_tensor() for elem in ds]
        # Expected `to_tensor()` value of each batch, in output order.
        expected_batches = [
            [[0, 0, 0, 0, 0], [0, 1, 2, 3, 4]],
            [[0, 1, 2], [0, 1, 0]],
            [[0, 1, 2, 3, 4, 5, 6, 7]],
        ]
        for batch_index, expected in enumerate(expected_batches):
            self.assertAllEqual(results[batch_index], expected)

    @combinations.generate(test_base.eager_only_combinations())
    def testDifferentShuffleOrders(self):
        """Checks that the two shuffled copies read back differ in order."""
        random_seed.set_random_seed(None)
        num_elements = 100
        cluster = self.create_cluster(num_workers=2)
        shuffled = dataset_ops.Dataset.range(num_elements).shuffle(
            num_elements)
        ds = self.make_distributed_dataset(shuffled, cluster)
        output = [elem.numpy() for elem in ds]

        # `output` is expected to contain range(num_elements) twice,
        # interleaved non-deterministically. Each mapping records the rank at
        # which an element showed up within its first / second occurrence;
        # identical shuffle orders would make the two mappings equal.
        first_order, second_order = {}, {}
        for element in output:
            if element not in first_order:
                first_order[element] = len(first_order)
            else:
                second_order[element] = len(second_order)
        self.assertNotEqual(first_order, second_order)

    @combinations.generate(test_base.eager_only_combinations())
    def testMultipleEpochs(self):
        """Iterates the same distributed dataset ten times in a row."""
        num_elements = 3
        cluster = self.create_cluster(num_workers=1)
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        expected = list(range(num_elements))
        for _ in range(10):
            epoch_output = [elem.numpy() for elem in ds]
            self.assertEqual(expected, epoch_output)

    @combinations.generate(test_base.eager_only_combinations())
    def testRepeatedDataset(self):
        """Applies `repeat` after distribution and checks the full output."""
        num_elements = 10
        num_repetitions = 5
        cluster = self.create_cluster(num_workers=1)
        ds = self.make_distributed_range_dataset(
            num_elements, cluster).repeat(num_repetitions)
        expected = list(range(num_elements)) * num_repetitions
        self.assertDatasetProduces(ds, expected_output=expected)

    @combinations.generate(test_base.eager_only_combinations())
    def testConcurrentEpoch(self):
        """Reads several separately created datasets in lockstep."""
        cluster = self.create_cluster(num_workers=1)
        num_elements = 10
        num_datasets = 3
        iterators = [
            iter(self.make_distributed_range_dataset(num_elements, cluster))
            for _ in range(num_datasets)
        ]
        results = [[] for _ in range(num_datasets)]

        # Take one element from every iterator per step so that all of them
        # are in progress at the same time.
        for _ in range(num_elements):
            for iterator, collected in zip(iterators, results):
                collected.append(next(iterator).numpy())

        expected = list(range(num_elements))
        for collected in results:
            self.assertEqual(expected, collected)

    @combinations.generate(test_base.eager_only_combinations())
    def testSharedEpoch(self):
        """Skipped: several iterators over one dataset splitting its output."""
        # `skipTest` aborts the test here; the statements below do not run
        # while the skip is in place.
        self.skipTest("Not yet implemented")
        num_elements = 10
        num_iterators = 3
        cluster = self.create_cluster(num_workers=1)
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        iterators = [iter(ds) for _ in range(num_iterators)]
        result = []

        # Two rounds of taking a single element from each iterator in turn.
        for _ in range(2):
            result.extend(next(it).numpy() for it in iterators)

        # Read whatever remains, one iterator after another.
        for it in iterators:
            result.extend(elem.numpy() for elem in it)

        self.assertCountEqual(list(range(num_elements)), result)

    @combinations.generate(test_base.eager_only_combinations())
    def testMultiWorker(self):
        """Expects one copy of the range per worker, in any order."""
        num_workers = 3
        num_elements = 10
        cluster = self.create_cluster(num_workers=num_workers)
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        results = []
        for elem in ds:
            results.append(elem.numpy())
        expected = list(range(num_elements)) * num_workers
        self.assertCountEqual(expected, results)

    @combinations.generate(test_base.eager_only_combinations())
    def testMaxOutstandingRequests(self):
        """Reads from three workers with `max_outstanding_requests=1`."""
        num_workers = 3
        num_elements = 10
        cluster = self.create_cluster(num_workers=num_workers)
        ds = self.make_distributed_range_dataset(
            num_elements, cluster, max_outstanding_requests=1)
        expected = list(range(num_elements)) * num_workers
        self.assertCountEqual(expected, self.getDatasetOutput(ds))

    @combinations.generate(test_base.eager_only_combinations())
    def testInsideFunction(self):
        """Creates and consumes the distributed dataset in a `function`."""
        num_workers = 3
        num_elements = 10
        cluster = self.create_cluster(num_workers=num_workers)
        total_elements = num_workers * num_elements

        @def_function.function
        def read_all():
            # Collect every element into a TensorArray and return it stacked.
            ds = self.make_distributed_range_dataset(num_elements, cluster)
            buffer = tensor_array_ops.TensorArray(
                dtypes.int64, size=total_elements, dynamic_size=True)
            position = 0
            for elem in ds:
                buffer = buffer.write(position, elem)
                position += 1
            return buffer.stack()

        result = list(read_all().numpy())
        self.assertCountEqual(
            list(range(num_elements)) * num_workers, result)

    @combinations.generate(test_base.eager_only_combinations())
    def testSharedJobName(self):
        """Two datasets with one job name together yield each element once."""
        cluster = self.create_cluster(num_workers=1)
        num_elements = 1000

        def shuffled_range():
            source = dataset_ops.Dataset.range(num_elements)
            return source.shuffle(num_elements)

        ds1 = self.make_distributed_dataset(
            shuffled_range(), cluster, job_name="job_name")
        ds2 = self.make_distributed_dataset(
            shuffled_range(), cluster, job_name="job_name")
        iter1 = iter(ds1)
        iter2 = iter(ds2)
        results = []
        # Alternate between the two iterators for a while ...
        for _ in range(num_elements // 5):
            for iterator in (iter1, iter2):
                results.append(next(iterator).numpy())
        # ... then drain the first one followed by the second one.
        results.extend(elem.numpy() for elem in iter1)
        results.extend(elem.numpy() for elem in iter2)
        self.assertCountEqual(list(range(num_elements)), results)

    @combinations.generate(test_base.eager_only_combinations())
    def testDifferentJobNames(self):
        """Datasets with distinct job names each produce the full range."""
        cluster = self.create_cluster(num_workers=1)
        num_elements = 10
        ds1, ds2 = [
            self.make_distributed_range_dataset(
                num_elements, cluster, job_name=job_name)
            for job_name in ("job_name1", "job_name2")
        ]
        expected = list(range(num_elements))
        self.assertDatasetProduces(ds1, expected)
        self.assertDatasetProduces(ds2, expected)

    @combinations.generate(test_base.eager_only_combinations())
    def testSharedJobNameMultiIteration(self):
        """Runs two passes over two datasets that share one job name.

        In each pass the dataset read first is expected to produce the whole
        range and the one read second is expected to produce nothing.
        """
        cluster = self.create_cluster(num_workers=1)
        num_elements = 10
        ds1 = self.make_distributed_range_dataset(
            num_elements, cluster, job_name="job_name")
        ds2 = self.make_distributed_range_dataset(
            num_elements, cluster, job_name="job_name")
        full_range = list(range(num_elements))
        # The second pass swaps which dataset is read first.
        for first, second in ((ds1, ds2), (ds2, ds1)):
            self.assertDatasetProduces(first, full_range)
            self.assertDatasetProduces(second, [])

    @combinations.generate(test_base.eager_only_combinations())
    def testSharedJobNameRepeat(self):
        """Shares a job name between two datasets that are both repeated."""
        cluster = self.create_cluster(num_workers=1)
        num_elements = 100
        num_repetitions = 3
        ds1 = self.make_distributed_range_dataset(
            num_elements, cluster, job_name="job_name")
        ds1 = ds1.repeat(num_repetitions)
        ds2 = self.make_distributed_range_dataset(
            num_elements, cluster, job_name="job_name")
        ds2 = ds2.repeat(num_repetitions)
        results = []
        iter1 = iter(ds1)
        iter2 = iter(ds2)
        # Read a fixed-size prefix from each iterator in turn, then drain
        # both of them in the same order.
        prefix_length = (num_elements * num_repetitions) // 5
        for iterator in (iter1, iter2):
            for _ in range(prefix_length):
                results.append(next(iterator).numpy())
        for iterator in (iter1, iter2):
            results.extend(elem.numpy() for elem in iterator)
        expected = list(range(num_elements)) * num_repetitions
        self.assertCountEqual(expected, results)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(num_workers=[1, 3], num_consumers=[1, 2, 5])))
    def testRoundRobin(self, num_workers, num_consumers):
        """Checks that round-robin consumers read consecutive elements.

        Args:
          num_workers: Number of workers the cluster is created with.
          num_consumers: Number of consumer datasets reading the shared job.
        """
        cluster = self.create_cluster(num_workers=num_workers)
        # Round robin reads can cause slow cluster shutdown.
        data_service_test_base.GLOBAL_CLUSTERS.add(cluster)
        source = dataset_ops.Dataset.range(10000000).repeat()
        consumers = [
            self.make_distributed_dataset(source,
                                          cluster,
                                          job_name="test",
                                          consumer_index=consumer_index,
                                          num_consumers=num_consumers)
            for consumer_index in range(num_consumers)
        ]
        # Use parallel interleave to read from consumers in parallel.
        ds = dataset_ops.Dataset.from_tensor_slices(consumers)
        ds = ds.interleave(lambda x: x,
                           cycle_length=num_consumers,
                           num_parallel_calls=num_consumers)
        ds = ds.take(1000)
        results = self.getDatasetOutput(ds, requires_initialization=True)

        for group_start in range(0, len(results), num_consumers):
            group = results[group_start:group_start + num_consumers]
            leader = group[0]
            self.assertEqual(0, leader % num_consumers)
            # Check that each group of `num_consumers` results are
            # consecutive; the last group may be shorter.
            for offset, value in enumerate(group[1:], start=1):
                self.assertEqual(leader + offset, value)

    @combinations.generate(test_base.default_test_combinations())
    def testRoundRobinBucketizing(self):
        """Checks that all consumers get same-bucket batches at every step.

        Tests a common use case for round robin reads. At each step, all
        consumers should get batches with the same bucket size.
        """
        cluster = self.create_cluster(num_workers=4)
        # Round robin reads can cause slow cluster shutdown.
        data_service_test_base.GLOBAL_CLUSTERS.add(cluster)
        num_elements = 100
        low_bucket_max = 30
        mid_bucket_max = 60
        bucket_boundaries = [low_bucket_max, mid_bucket_max]
        batch_size = 10
        num_consumer_hosts = 3
        replicas_per_consumer_host = 5
        num_consumers = num_consumer_hosts * replicas_per_consumer_host
        bucket_batch_sizes = [batch_size] * (len(bucket_boundaries) + 1)
        # Set up the dataset that will run on the tf.data workers.
        ds = dataset_ops.Dataset.range(num_elements, output_type=dtypes.int32)
        ds = ds.shuffle(num_elements)
        ds = ds.repeat()
        ds = ds.apply(
            grouping.bucket_by_sequence_length(lambda x: x,
                                               bucket_boundaries,
                                               bucket_batch_sizes,
                                               drop_remainder=True))
        # Group batches into windows of `num_consumers`, keyed by the second
        # element of each batch.
        ds = ds.apply(
            grouping.group_by_window(
                lambda x: math_ops.cast(x[1], dtypes.int64),
                lambda _, x: dataset_ops.Dataset.from_tensors(x),
                window_size=num_consumers))
        ds = ds.flat_map(lambda x: x)

        # Set up the per-consumer-host datasets. During each global step, we
        # pull `replicas_per_consumer_host` batches from each of these
        # datasets.
        host_datasets = []
        for host_index in range(num_consumer_hosts):
            per_replica_datasets = []
            for i in range(replicas_per_consumer_host):
                consumer_index = host_index * replicas_per_consumer_host + i
                per_replica_datasets.append(
                    self.make_distributed_dataset(
                        ds,
                        cluster,
                        job_name="test",
                        consumer_index=consumer_index,
                        num_consumers=num_consumers))
            host_dataset = dataset_ops.Dataset.from_tensor_slices(
                per_replica_datasets)
            host_dataset = host_dataset.interleave(
                lambda x: x,
                cycle_length=len(per_replica_datasets),
                num_parallel_calls=len(per_replica_datasets),
                deterministic=True)
            host_datasets.append(host_dataset)

        # Use parallel interleave to read from host datasets in parallel.
        ds = dataset_ops.Dataset.from_tensor_slices(host_datasets)
        ds = ds.interleave(lambda x: x,
                           block_length=replicas_per_consumer_host,
                           cycle_length=len(host_datasets),
                           num_parallel_calls=len(host_datasets),
                           deterministic=True)

        num_rounds = 10
        get_next = self.getNext(ds, requires_initialization=True)
        results = []
        for _ in range(num_rounds * num_consumers):
            results.append(self.evaluate(get_next()))

        def get_bucket(elem):
            # Index of the first boundary that is greater than `elem`, or
            # len(bucket_boundaries) when there is none.
            bucket_ind = 0
            while bucket_ind < len(bucket_boundaries
                                   ) and elem >= bucket_boundaries[bucket_ind]:
                bucket_ind += 1
            return bucket_ind

        # Check that the batches for each step contain elements from the same
        # bucket. `step_start` already advances by `num_consumers`, so the
        # batches of one step are results[step_start:step_start +
        # num_consumers]; scaling the index by `num_consumers` again would
        # slice past the end for every step after the first and check
        # nothing.
        for step_start in range(0, len(results), num_consumers):
            batches = results[step_start:step_start + num_consumers]
            bucket_inds = [get_bucket(batch[0]) for batch in batches]
            for bucket_ind in bucket_inds[1:]:
                self.assertEqual(bucket_inds[0], bucket_ind)

    @combinations.generate(test_base.v1_only_combinations())
    def testRoundRobinFiniteV1(self):
        """Expects a finite round-robin read to fail at end of sequence."""
        num_elements = 100
        cluster = self.create_cluster(num_workers=1)
        ds = self.make_distributed_dataset(
            dataset_ops.Dataset.range(num_elements),
            cluster,
            job_name="test",
            consumer_index=0,
            num_consumers=1)

        expected_error = ("Encountered end of sequence on a "
                          "round-robin read iterator")
        with self.assertRaisesRegex(errors.FailedPreconditionError,
                                    expected_error):
            self.getDatasetOutput(ds, requires_initialization=True)

    @combinations.generate(test_base.v2_only_combinations())
    def testRoundRobinFiniteV2(self):
        """Expects a finite round-robin read to fail on its cardinality."""
        num_elements = 100
        cluster = self.create_cluster(num_workers=1)
        ds = self.make_distributed_dataset(
            dataset_ops.Dataset.range(num_elements),
            cluster,
            job_name="test",
            consumer_index=0,
            num_consumers=1)

        # The expected message ends with the dataset's cardinality.
        expected_error = ("Round robin reads "
                          "require that the input dataset has infinite "
                          "cardinality, but the dataset has cardinality " +
                          str(num_elements))
        with self.assertRaisesRegex(errors.FailedPreconditionError,
                                    expected_error):
            self.getDatasetOutput(ds, requires_initialization=True)

    @combinations.generate(
        combinations.times(test_base.eager_only_combinations(),
                           combinations.combine(job_name=[None, "test"])))
    def testGcUnusedJob(self, job_name):
        """Waits for the worker's task to go away once the iterator is gone.

        Args:
          job_name: Forwarded to `make_distributed_range_dataset` as
            `job_name`.
        """
        num_elements = 100
        cluster = self.create_cluster(num_workers=1,
                                      job_gc_check_interval_ms=50,
                                      job_gc_timeout_ms=20)
        ds = self.make_distributed_range_dataset(
            num_elements, cluster, job_name=job_name)
        it = iter(ds)
        first_value = next(it).numpy()
        self.assertEqual(first_value, 0)
        self.assertEqual(cluster.num_tasks_on_worker(), 1)
        del it
        # Poll until the worker reports no tasks. The loop has no timeout of
        # its own, so it only ends once the task count drops to zero.
        while cluster.num_tasks_on_worker() > 0:
            time.sleep(0.1)

    @combinations.generate(test_base.eager_only_combinations())
    def testDontGcUsedJob(self):
        """Checks that a job with a live iterator keeps its worker task."""
        num_elements = 10
        cluster = self.create_cluster(num_workers=1,
                                      job_gc_check_interval_ms=50,
                                      job_gc_timeout_ms=20)

        def open_iterator(job_name):
            return iter(
                self.make_distributed_range_dataset(num_elements,
                                                    cluster,
                                                    job_name=job_name))

        it1 = open_iterator("test1")
        it2 = open_iterator("test2")
        # This iterator keeps the task alive.
        it3 = open_iterator("test2")  # pylint: disable=unused-variable
        self.assertEqual(2, cluster.num_tasks_on_worker())
        del it1, it2
        # Check that only the first job is gced. The second job will not be
        # gced because there is still an outstanding iterator for it.
        while cluster.num_tasks_on_worker() > 1:
            time.sleep(0.1)
        self.assertEqual(1, cluster.num_tasks_on_worker())

    @combinations.generate(test_base.eager_only_combinations())
    def testApplyDeterminismOption(self):
        """Runs `checkDeterminism` on a dataset marked non-deterministic."""
        elements = list(range(10))
        cluster = self.create_cluster(num_workers=1)

        def dataset_fn(delay_ms):
            def interleave_fn(x):
                # Only the element equal to 0 gets the non-zero sleep.
                single = dataset_ops.Dataset.from_tensors(x)
                if math_ops.equal(x, 0):
                    single = single.apply(testing.sleep(delay_ms * 1000))
                else:
                    single = single.apply(testing.sleep(0))
                return single

            opts = dataset_ops.Options()
            opts.experimental_deterministic = False
            interleaved = dataset_ops.Dataset.from_tensor_slices(
                elements).interleave(interleave_fn,
                                     cycle_length=10,
                                     num_parallel_calls=10)
            interleaved = interleaved.with_options(opts)
            return self.make_distributed_dataset(interleaved, cluster)

        self.checkDeterminism(dataset_fn=dataset_fn,
                              expect_determinism=False,
                              expected_elements=elements)

    def run_stateful(self, external_state_policy):
        """Distributes a dataset with a random `map` and reads one element.

        Args:
          external_state_policy: Value assigned to
            `Options.experimental_external_state_policy` on the dataset
            before it is distributed.
        """
        num_elements = 10
        options = dataset_ops.Options()
        options.experimental_external_state_policy = external_state_policy

        ds = dataset_ops.Dataset.range(num_elements)
        ds = ds.map(lambda _: random_ops.random_uniform(()))
        ds = ds.with_options(options)

        cluster = self.create_cluster(num_workers=3)
        distributed = self.make_distributed_dataset(ds, cluster)
        # Fetch a single element; the value itself is not inspected.
        next(iter(distributed))

    @combinations.generate(
        combinations.times(
            test_base.eager_only_combinations(),
            combinations.combine(external_state_policy=[
                distribute_options.ExternalStatePolicy.IGNORE,
                distribute_options.ExternalStatePolicy.WARN
            ])))
    def testStatefulNoError(self, external_state_policy):
        """Runs `run_stateful` and expects it to finish without raising.

        Args:
          external_state_policy: Policy under test, passed through to
            `run_stateful`.
        """
        self.run_stateful(external_state_policy)

    @combinations.generate(test_base.eager_only_combinations())
    def testStatefulError(self):
        """Expects `run_stateful` to raise under the FAIL policy."""
        with self.assertRaises(errors.FailedPreconditionError):
            self.run_stateful(distribute_options.ExternalStatePolicy.FAIL)

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeDistributedEpochTensorSlices(self):
        """Distributes tensor slices in "distributed_epoch" mode."""
        vals = [5, 1, 2, 4]
        cluster = self.create_cluster(num_workers=2)
        ds = self.make_distributed_dataset(
            dataset_ops.Dataset.from_tensor_slices(vals),
            cluster,
            processing_mode="distributed_epoch")
        # Only the multiset of elements is compared, not their order.
        self.assertDatasetProduces(ds, vals, assert_items_equal=True)

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeDistributedEpochInterleave(self):
        """Distributes an interleave in "distributed_epoch" mode."""
        elements = [1, 5, 0]
        cluster = self.create_cluster(num_workers=2)

        def single_element_dataset(x):
            return dataset_ops.Dataset.from_tensor_slices([x])

        source = dataset_ops.Dataset.from_tensor_slices(elements)
        ds = self.make_distributed_dataset(
            source.interleave(single_element_dataset),
            cluster,
            processing_mode="distributed_epoch")
        self.assertDatasetProduces(ds, elements, assert_items_equal=True)

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeDistributedEpochParallelInterleave(self):
        """Distributes an autotuned parallel interleave per epoch."""
        elements = [1, 5, 0]
        cluster = self.create_cluster(num_workers=2)

        def single_element_dataset(x):
            return dataset_ops.Dataset.from_tensor_slices([x])

        source = dataset_ops.Dataset.from_tensor_slices(elements)
        interleaved = source.interleave(
            single_element_dataset, num_parallel_calls=dataset_ops.AUTOTUNE)
        ds = self.make_distributed_dataset(
            interleaved, cluster, processing_mode="distributed_epoch")
        self.assertDatasetProduces(ds, elements, assert_items_equal=True)

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeDistributedEpochFlatMap(self):
        """Distributes a `flat_map` in "distributed_epoch" mode."""
        elements = [1, 5, 0]
        cluster = self.create_cluster(num_workers=2)

        def single_element_dataset(x):
            return dataset_ops.Dataset.from_tensor_slices([x])

        source = dataset_ops.Dataset.from_tensor_slices(elements)
        ds = self.make_distributed_dataset(
            source.flat_map(single_element_dataset),
            cluster,
            processing_mode="distributed_epoch")
        self.assertDatasetProduces(ds, elements, assert_items_equal=True)

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeDistributedEpochRepeat(self):
        """Distributes a finitely repeated range in "distributed_epoch" mode."""
        num_repeats = 5
        num_elements = 20
        cluster = self.create_cluster(num_workers=2)
        repeated = dataset_ops.Dataset.range(num_elements).repeat(num_repeats)
        ds = self.make_distributed_dataset(
            repeated, cluster, processing_mode="distributed_epoch")
        expected = list(range(num_elements)) * num_repeats
        self.assertDatasetProduces(ds, expected, assert_items_equal=True)

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeDistributedEpochForeverRepeat(self):
        """Reads from an endlessly repeated range and checks element counts."""
        num_elements = 20
        elements_to_read = 1000
        cluster = self.create_cluster(num_workers=2)
        endless = dataset_ops.Dataset.range(num_elements).repeat()
        ds = self.make_distributed_dataset(
            endless, cluster, processing_mode="distributed_epoch")
        it = iter(ds)
        # Number of times each value was read.
        results = {}
        for _ in range(elements_to_read):
            val = next(it).numpy()
            results[val] = results.get(val, 0) + 1
        # Each value must show up more than half as often as it would under
        # a perfectly even split of the reads.
        min_count = elements_to_read / num_elements / 2
        for i in range(num_elements):
            self.assertGreater(results[i], min_count)

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeDistributedEpochForeverRepeatFewElements(self):
        """Keeps reading a one-element endless dataset as workers stop."""
        num_workers = 5
        cluster = self.create_cluster(num_workers=num_workers)
        # Less than the number of workers, so that some workers get zero
        # elements on the first repetition.
        num_elements = 1
        endless = dataset_ops.Dataset.range(num_elements).repeat()
        ds = self.make_distributed_dataset(
            endless, cluster, processing_mode="distributed_epoch")
        it = iter(ds)

        def expect_zeros(count):
            for _ in range(count):
                self.assertEqual(next(it).numpy(), 0)

        expect_zeros(100)

        # Stop all but one worker and check that we can still read.
        for worker_index in range(num_workers - 1):
            cluster.workers[worker_index]._stop()
        expect_zeros(100)

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeDistributedEpochShuffleAndRepeat(self):
        """Distributes a shuffled, repeated range in "distributed_epoch" mode."""
        num_repeats = 5
        num_elements = 20
        cluster = self.create_cluster(num_workers=2)
        source = dataset_ops.Dataset.range(num_elements)
        source = source.shuffle(num_elements)
        source = source.repeat(num_repeats)
        ds = self.make_distributed_dataset(
            source, cluster, processing_mode="distributed_epoch")
        expected = list(range(num_elements)) * num_repeats
        self.assertDatasetProduces(ds, expected, assert_items_equal=True)

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeFromInterleave(self):
        """Calls `make_distributed_dataset` from inside an interleave function.

        The distributed dataset created in `interleave_fn` is discarded; the
        function returns the plain range dataset, so the expected output is
        the interleaving of two `range(2)` datasets.
        """
        cluster = self.create_cluster(num_workers=1)
        ds = dataset_ops.Dataset.range(2)

        def interleave_fn(_):
            dataset = dataset_ops.Dataset.range(2)
            # Only the act of creating the distributed dataset inside the
            # function is exercised; its result is intentionally unused.
            self.make_distributed_dataset(dataset, cluster)
            return dataset

        ds = ds.interleave(interleave_fn, cycle_length=2)
        self.assertDatasetProduces(ds, [0, 0, 1, 1])

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeDistributedEpoch(self):
        """Reads `range(100)` in "distributed_epoch" mode from two workers.

        Only the set of produced elements is asserted, not their order.
        """
        element_count = 100
        cluster = self.create_cluster(num_workers=2)
        distributed = self.make_distributed_dataset(
            dataset_ops.Dataset.range(element_count),
            cluster,
            processing_mode="distributed_epoch")
        expected = list(range(element_count))
        self.assertDatasetProduces(distributed,
                                   expected,
                                   assert_items_equal=True)

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeNonStringAddresses(self):
        """Expects a `ValueError` when `service` is an integer, not a string."""
        base = dataset_ops.Dataset.range(10)
        with self.assertRaisesRegex(ValueError, "service must be a string"):
            # Both the transformation and its application stay inside the
            # context manager so that either may raise.
            transformation = data_service_ops.distribute(
                processing_mode="parallel_epochs", service=1)
            base.apply(transformation)

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeEmptyAddress(self):
        """Expects a `ValueError` when `service` is the empty string."""
        base = dataset_ops.Dataset.range(10)
        expected_message = "service must not be empty"
        with self.assertRaisesWithLiteralMatch(ValueError, expected_message):
            # Both the transformation and its application stay inside the
            # context manager so that either may raise.
            transformation = data_service_ops.distribute(
                processing_mode="parallel_epochs", service="")
            base.apply(transformation)

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeInvalidProcessingMode(self):
        """Expects a `ValueError` for the processing mode "invalid"."""
        base = dataset_ops.Dataset.range(10)
        expected_regex = "invalid is not a valid processing mode"
        with self.assertRaisesRegex(ValueError, expected_regex):
            # Both the transformation and its application stay inside the
            # context manager so that either may raise.
            transformation = data_service_ops.distribute(
                processing_mode="invalid", service="grpc://localhost:5000")
            base.apply(transformation)

    @combinations.generate(test_base.eager_only_combinations())
    def testZipDifferentProcessingModesDatasets(self):
        """Zips a "distributed_epoch" dataset with a "parallel_epochs" one.

        Both inputs are `range(100)` served by the same single-worker cluster;
        the zipped output must contain the pairs `(i, i)` in any order.
        """
        element_count = 100
        cluster = self.create_cluster(num_workers=1)

        epoch_ds = self.make_distributed_dataset(
            dataset_ops.Dataset.range(element_count),
            cluster,
            processing_mode="distributed_epoch")
        parallel_ds = self.make_distributed_dataset(
            dataset_ops.Dataset.range(element_count),
            cluster,
            processing_mode="parallel_epochs")

        zipped = dataset_ops.Dataset.zip((epoch_ds, parallel_ds))
        expected_pairs = [(value, value) for value in range(element_count)]
        self.assertDatasetProduces(zipped,
                                   expected_pairs,
                                   assert_items_equal=True)

    @combinations.generate(test_base.eager_only_combinations())
    def testZipDifferentProcessingModesDatasetsSharedJobName(self):
        """Expects a failure when two processing modes share one job name.

        Two distributed datasets use the job name "job_name" but different
        processing modes; reading their zip must raise
        `FailedPreconditionError`.
        """
        element_count = 100
        shared_job_name = "job_name"
        cluster = self.create_cluster(num_workers=1)

        epoch_ds = self.make_distributed_dataset(
            dataset_ops.Dataset.range(element_count),
            cluster,
            processing_mode="distributed_epoch",
            job_name=shared_job_name)
        parallel_ds = self.make_distributed_dataset(
            dataset_ops.Dataset.range(element_count),
            cluster,
            processing_mode="parallel_epochs",
            job_name=shared_job_name)

        zipped = dataset_ops.Dataset.zip((epoch_ds, parallel_ds))
        expected_regex = "but there is already an existing job"
        with self.assertRaisesRegex(errors.FailedPreconditionError,
                                    expected_regex):
            self.getDatasetOutput(zipped)

    @combinations.generate(test_base.eager_only_combinations())
    def testFromDatasetId(self):
        """Registers `range(10)` and reads it back through its dataset id."""
        element_count = 10
        cluster = self.create_cluster(num_workers=1)

        source = dataset_ops.Dataset.range(element_count)
        registered_id = data_service_ops.register_dataset(
            cluster.target, source)
        restored = data_service_ops.from_dataset_id(
            "parallel_epochs", cluster.target, registered_id,
            source.element_spec)

        self.assertDatasetProduces(restored, list(range(element_count)))

    @combinations.generate(test_base.eager_only_combinations())
    def testFromDatasetIdMultipleComponents(self):
        """Reads back a registered dataset with nested dict/tuple elements.

        Each element has the structure `{"a": (i, i), "b": i}`; every
        component of the first 10 outputs is compared against its index.
        """
        element_count = 10
        cluster = self.create_cluster(num_workers=1)

        numbers = dataset_ops.Dataset.range(element_count)
        nested = dataset_ops.Dataset.zip({"a": (numbers, numbers),
                                          "b": numbers})
        registered_id = data_service_ops.register_dataset(
            cluster.target, nested)
        restored = data_service_ops.from_dataset_id(
            "parallel_epochs", cluster.target, registered_id,
            nested.element_spec)

        output = self.getDatasetOutput(restored)
        for expected in range(element_count):
            element = output[expected]
            self.assertEqual(expected, element["a"][0])
            self.assertEqual(expected, element["a"][1])
            self.assertEqual(expected, element["b"])

    @combinations.generate(test_base.eager_only_combinations())
    def testFromDatasetIdWrongElementSpec(self):
        """Expects a failure when reading with a mismatching element spec.

        `range(10)` is registered, but it is read back with a scalar
        `variant` spec; fetching the first element must raise
        `FailedPreconditionError`.
        """
        element_count = 10
        cluster = self.create_cluster(num_workers=1)

        source = dataset_ops.Dataset.range(element_count)
        registered_id = data_service_ops.register_dataset(
            cluster.target, source)
        mismatching_spec = tensor_spec.TensorSpec(shape=(),
                                                  dtype=dtypes.variant)
        restored = data_service_ops.from_dataset_id(
            "parallel_epochs", cluster.target, registered_id,
            mismatching_spec)

        expected_regex = "Expected a tensor of type variant"
        with self.assertRaisesRegex(errors.FailedPreconditionError,
                                    expected_regex):
            get_next = self.getNext(restored)
            self.evaluate(get_next())

    @combinations.generate(test_base.eager_only_combinations())
    def testFromDatasetIdNotRegistered(self):
        """Expects `NotFoundError` when reading a dataset id never registered."""
        cluster = self.create_cluster(num_workers=1)

        # Nothing is registered with this cluster before id 0 is requested.
        unregistered_id = 0
        spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
        restored = data_service_ops.from_dataset_id(
            "parallel_epochs", cluster.target, unregistered_id, spec)

        with self.assertRaisesRegex(errors.NotFoundError, "Dataset id"):
            get_next = self.getNext(restored)
            self.evaluate(get_next())

    @combinations.generate(test_base.default_test_combinations())
    def testCancellation(self):
        """Checks that an in-flight prefetch can be cancelled (disabled).

        The test is skipped unconditionally with the reason "b/162521601";
        the remaining body documents the intended scenario.
        """
        self.skipTest("b/162521601")
        sleep_microseconds = 1000 * int(1e6)

        cluster = self.create_cluster(num_workers=1)
        # The first element is produced quickly and the second one slowly.
        # Fetching the first element triggers a prefetch of the second, which
        # should be cancellable.
        slow_part = dataset_ops.Dataset.range(1).apply(
            testing.sleep(sleep_microseconds))
        combined = dataset_ops.Dataset.range(1).concatenate(slow_part)
        combined = self.make_distributed_dataset(combined, cluster)
        combined = combined.prefetch(1)

        get_next = self.getNext(combined, requires_initialization=True)
        first_element = self.evaluate(get_next())
        self.assertEqual(0, first_element)
        # Without properly implemented cancellation, we will hang here while
        # trying to garbage collect the dataset iterator.

    @combinations.generate(test_base.eager_only_combinations())
    def testRegisterEquivalentDatasets(self):
        """Two separately built `range(10)` datasets get the same dataset id."""
        first = dataset_ops.Dataset.range(10)
        second = dataset_ops.Dataset.range(10)
        cluster = self.create_cluster(num_workers=1)

        first_id = data_service_ops.register_dataset(cluster.target, first)
        second_id = data_service_ops.register_dataset(cluster.target, second)

        self.assertEqual(first_id.numpy(), second_id.numpy())

    @combinations.generate(test_base.eager_only_combinations())
    def testRegisterDifferentDatasets(self):
        """`range(10)` and `range(20)` are registered under different ids."""
        shorter = dataset_ops.Dataset.range(10)
        longer = dataset_ops.Dataset.range(20)
        cluster = self.create_cluster(num_workers=1)

        shorter_id = data_service_ops.register_dataset(cluster.target, shorter)
        longer_id = data_service_ops.register_dataset(cluster.target, longer)

        self.assertNotEqual(shorter_id.numpy(), longer_id.numpy())

    @combinations.generate(test_base.default_test_combinations())
    def testDistributedEpochOnZippedDataset(self):
        """Expects `UnimplementedError` for "distributed_epoch" on a zip.

        The error message must state that no split provider can be created
        for a dataset of type ZipDataset.
        """
        left = dataset_ops.Dataset.range(10)
        right = dataset_ops.Dataset.range(10)
        cluster = self.create_cluster(num_workers=1)

        zipped = dataset_ops.Dataset.zip((left, right))
        distributed = self.make_distributed_dataset(
            zipped, cluster, processing_mode="distributed_epoch")

        expected_regex = ("Cannot create a split provider for dataset "
                          "of type ZipDataset")
        with self.assertRaisesRegex(errors.UnimplementedError,
                                    expected_regex):
            self.getDatasetOutput(distributed, requires_initialization=True)

    @combinations.generate(test_base.default_test_combinations())
    def testDistributedEpochOnDistributedDataset(self):
        """Expects `UnimplementedError` when re-distributing in epoch mode.

        A dataset already distributed through one cluster ("parallel_epochs")
        is mapped and then distributed through a second cluster in
        "distributed_epoch" mode; reading it must fail because no split
        provider exists for a DataServiceDataset.
        """
        first_cluster = self.create_cluster(num_workers=1)
        second_cluster = self.create_cluster(num_workers=1)
        size_count = 10
        repetitions = 5
        values = list(range(size_count)) * repetitions

        pipeline = dataset_ops.Dataset.from_tensor_slices(values)
        pipeline = self.make_distributed_dataset(
            pipeline, first_cluster, processing_mode="parallel_epochs")
        pipeline = pipeline.map(lambda x: x + 1)
        pipeline = self.make_distributed_dataset(
            pipeline, second_cluster, processing_mode="distributed_epoch")

        expected_regex = ("Cannot create a split provider for dataset "
                          "of type DataServiceDataset")
        with self.assertRaisesRegex(errors.UnimplementedError,
                                    expected_regex):
            self.getDatasetOutput(pipeline, requires_initialization=True)

    @combinations.generate(test_base.eager_only_combinations())
    def testTwoLevelDistribute(self):
        """Chains two clusters with a group-by-window step in between.

        Strings of 10 different lengths (each repeated 5 times) are shuffled
        and distributed over a 3-worker cluster, grouped by string length
        into batches, and distributed again over a 1-worker cluster. For each
        length the output must contain 3 identical consecutive batches, and
        nothing may be left afterwards.
        """
        first_cluster_size = 3
        first_cluster = self.create_cluster(num_workers=first_cluster_size)
        second_cluster = self.create_cluster(num_workers=1)
        size_count = 10
        repetitions = 5
        strings = ["a" * length for length in range(size_count)] * repetitions

        pipeline = dataset_ops.Dataset.from_tensor_slices(strings)
        pipeline = pipeline.shuffle(len(strings))
        pipeline = self.make_distributed_dataset(pipeline, first_cluster)

        # Large enough so that all strings of the same size are windowed
        # together.
        window_size = first_cluster_size * repetitions
        batch_size = repetitions

        def length_key(value):
            return math_ops.cast(string_ops.string_length_v2(value),
                                 dtypes.int64)

        def batch_window(_, window):
            return window.batch(batch_size)

        pipeline = pipeline.apply(
            grouping.group_by_window(key_func=length_key,
                                     reduce_func=batch_window,
                                     window_size=window_size))
        pipeline = self.make_distributed_dataset(pipeline, second_cluster)

        iterator = iter(pipeline)
        for _ in range(size_count):
            reference = next(iterator).numpy()
            # The remaining batches of this group must equal the first one.
            for _ in range(first_cluster_size - 1):
                self.assertAllEqual(next(iterator).numpy(), reference)
        self.assertEmpty(list(iterator))

    @combinations.generate(
        combinations.times(test_base.eager_only_combinations()))
    def testDistributeLargeGraph(self):
        """Distributes a dataset whose graph embeds a large float32 tensor.

        The cluster is created without a work dir and with fault tolerance
        disabled.
        """
        cluster = self.create_cluster(num_workers=1,
                                      work_dir=NO_WORK_DIR,
                                      fault_tolerant_mode=False)
        # Larger than default OSS grpc message size limit of 4MB.
        large_tensor = array_ops.ones((2, 1000, 1000), dtype=dtypes.float32)
        distributed = self.make_distributed_dataset(
            dataset_ops.Dataset.from_tensors(large_tensor), cluster)
        self.assertDatasetProduces(distributed, [large_tensor])
class CoordinatedReadTest(data_service_test_base.TestBase,
                          parameterized.TestCase):
  """Tests for coordinated (round-robin) reads from tf.data service jobs."""

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(num_workers=[1, 3], num_consumers=[1, 2, 5])))
  def testBasic(self, num_workers, num_consumers):
    """Reads 100 elements and validates the coordinated-read groups.

    Args:
      num_workers: Number of workers in the test cluster.
      num_consumers: Number of consumers sharing the coordinated read.
    """
    cluster = data_service_test_base.TestCluster(num_workers=num_workers)
    # Round robin reads can cause slow cluster shutdown.
    data_service_test_base.GLOBAL_CLUSTERS.add(cluster)
    ds = self.make_coordinated_read_dataset(cluster, num_consumers)
    ds = ds.take(100)
    results = self.getDatasetOutput(ds)
    self.checkCoordinatedReadGroups(results, num_consumers)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations()))
  def testConsumerRestart(self):
    """Expects a failure when a second dataset reads from the same cluster.

    After a first coordinated-read dataset has been consumed, reading from a
    second one built the same way must raise `FailedPreconditionError`.
    """
    cluster = data_service_test_base.TestCluster(num_workers=1)
    # Round robin reads can cause slow cluster shutdown.
    data_service_test_base.GLOBAL_CLUSTERS.add(cluster)
    num_consumers = 3
    ds = self.make_coordinated_read_dataset(cluster, num_consumers)
    ds = ds.take(20)
    self.getDatasetOutput(ds)
    ds2 = self.make_coordinated_read_dataset(cluster, num_consumers)
    ds2 = ds2.take(20)
    with self.assertRaisesRegex(errors.FailedPreconditionError,
                                "current round has already reached"):
      self.getDatasetOutput(ds2)

  @combinations.generate(test_base.default_test_combinations())
  def testBucketizing(self):
    """Checks that every round delivers batches from a single bucket."""
    # Tests a common use case for round robin reads. At each step, all
    # consumers should get batches with the same bucket size.
    cluster = data_service_test_base.TestCluster(num_workers=4)
    # Round robin reads can cause slow cluster shutdown.
    data_service_test_base.GLOBAL_CLUSTERS.add(cluster)
    num_elements = 100
    low_bucket_max = 30
    mid_bucket_max = 60
    bucket_boundaries = [low_bucket_max, mid_bucket_max]
    batch_size = 10
    num_consumer_hosts = 3
    replicas_per_consumer_host = 5
    num_consumers = num_consumer_hosts * replicas_per_consumer_host
    bucket_batch_sizes = [batch_size] * (len(bucket_boundaries) + 1)
    # Set up the dataset that will run on the tf.data workers.
    ds = dataset_ops.Dataset.range(num_elements, output_type=dtypes.int32)
    ds = ds.shuffle(num_elements)
    ds = ds.repeat()
    ds = ds.apply(
        grouping.bucket_by_sequence_length(
            lambda x: x,
            bucket_boundaries,
            bucket_batch_sizes,
            drop_remainder=True))
    ds = ds.apply(
        grouping.group_by_window(
            lambda x: math_ops.cast(x[1], dtypes.int64),
            lambda _, x: dataset_ops.Dataset.from_tensors(x),
            window_size=num_consumers))
    ds = ds.flat_map(lambda x: x)

    # Set up the per-consumer-host datasets. During each global step, we pull
    # `replicas_per_consumer_host` batches from each of these datasets.
    host_datasets = []
    for host_index in range(num_consumer_hosts):
      per_replica_datasets = []
      for i in range(replicas_per_consumer_host):
        consumer_index = host_index * replicas_per_consumer_host + i
        per_replica_datasets.append(
            self.make_distributed_dataset(
                ds,
                cluster,
                job_name="test",
                consumer_index=consumer_index,
                num_consumers=num_consumers))
      host_dataset = dataset_ops.Dataset.from_tensor_slices(
          per_replica_datasets)
      host_dataset = host_dataset.interleave(
          lambda x: x,
          cycle_length=len(per_replica_datasets),
          num_parallel_calls=len(per_replica_datasets),
          deterministic=True)
      host_datasets.append(host_dataset)

    # Use parallel interleave to read from host datasets in parallel.
    ds = dataset_ops.Dataset.from_tensor_slices(host_datasets)
    ds = ds.interleave(
        lambda x: x,
        block_length=replicas_per_consumer_host,
        cycle_length=len(host_datasets),
        num_parallel_calls=len(host_datasets),
        deterministic=True)

    num_rounds = 4
    get_next = self.getNext(ds)
    results = [
        self.evaluate(get_next()) for _ in range(num_rounds * num_consumers)
    ]

    def get_bucket(elem):
      """Returns the index of the bucket that `elem` falls into."""
      bucket_ind = 0
      while bucket_ind < len(
          bucket_boundaries) and elem >= bucket_boundaries[bucket_ind]:
        bucket_ind += 1
      return bucket_ind

    # Check that the batches for each step contain elements from the same
    # bucket. `round_start` already advances by `num_consumers`, so it is
    # used directly as the slice offset; multiplying it by `num_consumers`
    # again would yield empty slices for every round after the first and the
    # check would pass vacuously.
    for round_start in range(0, len(results), num_consumers):
      batches = results[round_start:round_start + num_consumers]
      bucket_inds = [get_bucket(batch[0]) for batch in batches]
      for bucket_ind in bucket_inds[1:]:
        self.assertEqual(
            bucket_inds[0], bucket_ind,
            "Batches: {}, Buckets: {}".format(batches, bucket_inds))

  @combinations.generate(test_base.v1_only_combinations())
  def testFiniteV1(self):
    """TF1: reading a finite dataset fails at its end of sequence."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 100
    ds = dataset_ops.Dataset.range(num_elements)
    ds = self.make_distributed_dataset(
        ds, cluster, job_name="test", consumer_index=0, num_consumers=1)

    with self.assertRaisesRegex(
        errors.FailedPreconditionError, "Encountered end of sequence on a "
        "round-robin read iterator"):
      self.getDatasetOutput(ds)

  @combinations.generate(test_base.v2_only_combinations())
  def testFiniteV2(self):
    """TF2: a finite dataset is rejected with its cardinality in the error."""
    cluster = data_service_test_base.TestCluster(num_workers=1)
    num_elements = 100
    ds = dataset_ops.Dataset.range(num_elements)
    ds = self.make_distributed_dataset(
        ds, cluster, job_name="test", consumer_index=0, num_consumers=1)

    with self.assertRaisesRegex(
        errors.FailedPreconditionError, "Round robin reads "
        "require that the input dataset has infinite "
        "cardinality, but the dataset has cardinality " + str(num_elements)):
      self.getDatasetOutput(ds)
Beispiel #5
0
class MultiDeviceIteratorTest(test_base.DatasetTestBase,
                              parameterized.TestCase):
    """Tests for `MultiDeviceIterator` under the TF1 test combinations.

    Each test wraps a small dataset in a `MultiDeviceIterator` spanning
    several devices and checks which element every device receives.
    """

    def _assert_round_robin(self, multi_device_iterator, num_elements,
                            num_devices=2):
        """Checks that `get_next()` hands out consecutive values per device.

        Args:
          multi_device_iterator: The iterator under test.
          num_elements: Exclusive upper bound of the values to read; values
            are read in groups of `num_devices`.
          num_devices: Number of elements expected from each `get_next()`.
        """
        for first_value in range(0, num_elements, num_devices):
            elements = multi_device_iterator.get_next()
            self.assertEqual(num_devices, len(elements))
            for offset, element in enumerate(elements):
                self.assertEqual(first_value + offset, self.evaluate(element))

    def _assert_exhausted(self, multi_device_iterator):
        """Checks that reading the next elements raises `OutOfRangeError`."""
        with self.assertRaises(errors.OutOfRangeError):
            for element in multi_device_iterator.get_next():
                self.evaluate(element)

    def _check_get_next_as_optional(self, devices, config):
        """Shared body of the `get_next_as_optional` tests.

        Reads a 9-element range over two devices through optionals: four full
        rounds, then one last value on the first device, after which both
        optionals must report no value and `get_value()` must fail.

        Args:
          devices: The two device names handed to `MultiDeviceIterator`.
          config: The `ConfigProto` used for the test session.
        """
        dataset = dataset_ops.Dataset.range(9)
        multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset, devices)
        elem_on_1, elem_on_2 = multi_device_iterator.get_next_as_optional()
        elem_on_1_has_value_t = elem_on_1.has_value()
        elem_on_1_t = elem_on_1.get_value()
        elem_on_2_has_value_t = elem_on_2.has_value()
        elem_on_2_t = elem_on_2.get_value()

        with self.test_session(config=config) as sess:
            self.evaluate(multi_device_iterator.initializer)
            for i in range(0, 8, 2):
                elem_on_1_has_value, elem_on_1_value = sess.run(
                    [elem_on_1_has_value_t, elem_on_1_t])
                self.assertTrue(elem_on_1_has_value)
                self.assertEqual(i, elem_on_1_value)
                elem_on_2_has_value, elem_on_2_value = sess.run(
                    [elem_on_2_has_value_t, elem_on_2_t])
                self.assertTrue(elem_on_2_has_value)
                self.assertEqual(i + 1, elem_on_2_value)
            elem_on_1_has_value, elem_on_1_value = sess.run(
                [elem_on_1_has_value_t, elem_on_1_t])
            self.assertTrue(elem_on_1_has_value)
            self.assertEqual(8, elem_on_1_value)
            self.assertFalse(self.evaluate(elem_on_1_has_value_t))
            self.assertFalse(self.evaluate(elem_on_2_has_value_t))
            with self.assertRaises(errors.InvalidArgumentError):
                self.evaluate(elem_on_1_t)
            with self.assertRaises(errors.InvalidArgumentError):
                self.evaluate(elem_on_2_t)

    @combinations.generate(
        combinations.times(test_base.v1_only_combinations(),
                           combinations.combine(num_inits=[0, 1, 42])))
    def testInitOnly(self, num_inits):
        """Runs the initializer `num_inits` times without reading elements."""
        dataset = dataset_ops.Dataset.range(10)
        multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset, ["/cpu:1", "/cpu:2"])

        config = config_pb2.ConfigProto(device_count={"CPU": 3})
        with self.test_session(config=config):
            for _ in range(num_inits):
                self.evaluate(multi_device_iterator.initializer)

    @combinations.generate(test_base.v1_only_combinations())
    def testBasic(self):
        """Reads `range(10)` alternately on two CPU devices until exhausted."""
        dataset = dataset_ops.Dataset.range(10)
        multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset, ["/cpu:1", "/cpu:2"])

        config = config_pb2.ConfigProto(device_count={"CPU": 3})
        with self.test_session(config=config):
            self.evaluate(multi_device_iterator.initializer)
            self._assert_round_robin(multi_device_iterator, 10)
            self._assert_exhausted(multi_device_iterator)

    @combinations.generate(test_base.v1_only_combinations())
    def testOneOnSameDevice(self):
        """Like `testBasic`, with the dataset created on a target device."""
        with ops.device("/cpu:0"):
            dataset = dataset_ops.Dataset.range(10)
        multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset, ["/cpu:0", "/cpu:1"])

        config = config_pb2.ConfigProto(device_count={"CPU": 2})
        with self.test_session(config=config):
            self.evaluate(multi_device_iterator.initializer)
            self._assert_round_robin(multi_device_iterator, 10)
            self._assert_exhausted(multi_device_iterator)

    @combinations.generate(test_base.v1_only_combinations())
    def testRepeatDevices(self):
        """Lists each of two devices twice and reads `range(20)` in fours."""
        with ops.device("/cpu:0"):
            dataset = dataset_ops.Dataset.range(20)
        multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset, ["/cpu:1", "/cpu:2", "/cpu:1", "/cpu:2"])

        config = config_pb2.ConfigProto(device_count={"CPU": 3})
        with self.test_session(config=config):
            self.evaluate(multi_device_iterator.initializer)
            self._assert_round_robin(multi_device_iterator, 20, num_devices=4)
            self._assert_exhausted(multi_device_iterator)

    @combinations.generate(test_base.v1_only_combinations())
    def testNotFullyDivisible(self):
        """Reads 9 elements on two devices; the last one only on "/cpu:1"."""
        dataset = dataset_ops.Dataset.range(9)
        multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset, ["/cpu:1", "/cpu:2"])

        config = config_pb2.ConfigProto(device_count={"CPU": 3})
        with self.test_session(config=config):
            self.evaluate(multi_device_iterator.initializer)
            self._assert_round_robin(multi_device_iterator, 8)
            # The ninth element is fetched from the first device alone.
            elem_on_1 = multi_device_iterator.get_next("/cpu:1")
            self.assertEqual(8, self.evaluate(elem_on_1))
            self._assert_exhausted(multi_device_iterator)

    @combinations.generate(test_base.v1_only_combinations())
    def testGetNextAsOptional(self):
        """Reads through `get_next_as_optional` on two CPU devices."""
        if context.executing_eagerly():
            # Report the eager combination as skipped instead of silently
            # passing without running any assertion.
            self.skipTest("Test only applies to graph mode")

        self._check_get_next_as_optional(
            ["/cpu:1", "/cpu:2"],
            config_pb2.ConfigProto(device_count={"CPU": 3}))

    @combinations.generate(test_base.v1_only_combinations())
    def testUneven(self):
        """Drains one device completely before reading from the other."""
        dataset = dataset_ops.Dataset.range(10)
        multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset, ["/cpu:1", "/cpu:2"], max_buffer_size=4)

        config = config_pb2.ConfigProto(device_count={"CPU": 3})
        with self.test_session(config=config):
            self.evaluate(multi_device_iterator.initializer)
            for i in range(0, 10, 2):
                elem_on_1 = multi_device_iterator.get_next("/cpu:1")
                self.assertEqual(i, self.evaluate(elem_on_1))
            for i in range(0, 10, 2):
                elem_on_2 = multi_device_iterator.get_next("/cpu:2")
                self.assertEqual(i + 1, self.evaluate(elem_on_2))
            self._assert_exhausted(multi_device_iterator)

    @combinations.generate(test_base.v1_only_combinations())
    def testMultipleInitializationsGraph(self):
        """Re-initializes the iterator 1000 times with a fed epoch value."""
        if context.executing_eagerly():
            # Report the eager combination as skipped instead of silently
            # passing without running any assertion.
            self.skipTest("Test only applies to graph mode")

        with ops.device("/cpu:0"):
            epoch = array_ops.placeholder(dtypes.int64, shape=[])
            dataset1 = dataset_ops.Dataset.from_tensors(epoch).repeat(1000)
            dataset2 = dataset_ops.Dataset.range(1000)
            dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
        multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset, ["/cpu:1", "/cpu:2"], prefetch_buffer_size=4)
        elem_on_1, elem_on_2 = multi_device_iterator.get_next()
        init_op = multi_device_iterator.initializer

        config = config_pb2.ConfigProto(device_count={"CPU": 3})
        pool = config.session_inter_op_thread_pool.add()
        pool.num_threads = 2
        with session.Session(config=config) as sess:
            for i in range(1000):
                sess.run(init_op, feed_dict={epoch: i})
                # After each re-initialization the first two elements must
                # carry the freshly fed epoch.
                self.assertEqual([(i, 0), (i, 1)],
                                 self.evaluate([elem_on_1, elem_on_2]))

    @combinations.generate(test_base.v1_only_combinations())
    def testMultipleInitializationsEager(self):
        """Builds a fresh iterator five times and reads its first elements."""
        if not context.executing_eagerly():
            # Report the graph combination as skipped instead of silently
            # passing without running any assertion.
            self.skipTest("Test only applies to eager mode")

        with ops.device("/cpu:0"):
            dataset1 = dataset_ops.Dataset.range(1000)
            dataset2 = dataset_ops.Dataset.range(1000)
            dataset = dataset_ops.Dataset.zip((dataset1, dataset2))

        for _ in range(5):
            multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
                dataset, ["/cpu:1", "/cpu:2"], prefetch_buffer_size=4)
            elem_on_1, elem_on_2 = multi_device_iterator.get_next()
            self.assertEqual([(0, 0), (1, 1)],
                             self.evaluate([elem_on_1, elem_on_2]))

    @combinations.generate(test_base.v1_only_combinations())
    def testBasicGpu(self):
        """Like `testBasic`, alternating between a CPU and a GPU device."""
        if not test_util.is_gpu_available():
            self.skipTest("No GPU available")

        dataset = dataset_ops.Dataset.range(10)
        multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset, ["/cpu:1", "/gpu:0"])

        config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
        with self.test_session(config=config):
            self.evaluate(multi_device_iterator.initializer)
            self._assert_round_robin(multi_device_iterator, 10)
            self._assert_exhausted(multi_device_iterator)

    @combinations.generate(test_base.v1_only_combinations())
    def testUnevenGpu(self):
        """Like `testUneven`, with a CPU and a GPU device."""
        if not test_util.is_gpu_available():
            self.skipTest("No GPU available")

        dataset = dataset_ops.Dataset.range(10)
        multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset, ["/cpu:1", "/gpu:0"], max_buffer_size=4)

        config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
        with self.test_session(config=config):
            self.evaluate(multi_device_iterator.initializer)
            for i in range(0, 10, 2):
                elem_on_1 = multi_device_iterator.get_next("/cpu:1")
                self.assertEqual(i, self.evaluate(elem_on_1))
            for i in range(0, 10, 2):
                elem_on_2 = multi_device_iterator.get_next("/gpu:0")
                self.assertEqual(i + 1, self.evaluate(elem_on_2))
            self._assert_exhausted(multi_device_iterator)

    @combinations.generate(test_base.v1_only_combinations())
    def testGetNextAsOptionalGpu(self):
        """Like `testGetNextAsOptional`, with a CPU and a GPU device."""
        # The two skip conditions are reported separately; a single
        # "No GPU available" reason would be wrong for the eager case.
        if context.executing_eagerly():
            self.skipTest("Test only applies to graph mode")
        if not test_util.is_gpu_available():
            self.skipTest("No GPU available")

        self._check_get_next_as_optional(
            ["/cpu:1", "/gpu:0"],
            config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1}))

    @combinations.generate(test_base.v1_only_combinations())
    def testOptimization(self):
        """Checks that dataset options are applied by the iterator.

        With `noop_elimination` enabled, `assert_next` expects
        "MemoryCacheImpl" to follow directly, i.e. the `skip(0)` in between
        has to be removed.
        """
        dataset = dataset_ops.Dataset.range(10)
        dataset = dataset.apply(testing.assert_next(["MemoryCacheImpl"]))
        dataset = dataset.skip(0)  # this should be optimized away
        dataset = dataset.cache()

        options = dataset_ops.Options()
        options.experimental_optimization.noop_elimination = True
        dataset = dataset.with_options(options)

        multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset, ["/cpu:1", "/cpu:2"])

        config = config_pb2.ConfigProto(device_count={"CPU": 3})
        with self.test_session(config=config):
            self.evaluate(multi_device_iterator.initializer)
            self._assert_round_robin(multi_device_iterator, 10)
            self._assert_exhausted(multi_device_iterator)
Beispiel #6
0
class CheckpointInputPipelineHookTest(test.TestCase, parameterized.TestCase):
    @staticmethod
    def _model_fn(features, labels, mode, config):
        """Model function that bumps the global step and stores the feature.

        `labels`, `mode` and `config` are accepted but ignored. The global step
        and the `latest_feature` variable are both registered in the 'my_vars'
        collection.

        Returns:
          An `EstimatorSpec` in 'train' mode whose train op groups the global
          step increment with the assignment of `features` to `latest_feature`,
          and whose loss is the constant 2.0.
        """
        del labels, mode, config  # Unused.

        step = training_util.get_or_create_global_step()
        bump_step = step.assign_add(1)

        last_seen = variables.VariableV1(
            0, name='latest_feature', dtype=dtypes.int64)
        record_feature = last_seen.assign(features)

        for tracked in (step, last_seen):
            ops.add_to_collection('my_vars', tracked)

        train_op = control_flow_ops.group([bump_step, record_feature])
        return model_fn.EstimatorSpec(
            mode='train', train_op=train_op, loss=constant_op.constant(2.0))

    def _read_vars(self, model_dir):
        """Returns (global_step, latest_feature) from the latest checkpoint.

        Args:
          model_dir: Directory in which to look up the latest checkpoint.

        Returns:
          The evaluated contents of the 'my_vars' collection after importing
          the checkpoint's meta graph into a fresh graph and restoring from it.

        Raises:
          AssertionError: If `model_dir` holds no checkpoint.
        """
        with ops.Graph().as_default() as g:
            ckpt_path = checkpoint_management.latest_checkpoint(model_dir)
            # `latest_checkpoint` returns None when nothing has been saved;
            # fail with a clear message instead of a TypeError from the
            # string concatenation below.
            self.assertIsNotNone(
                ckpt_path, 'No checkpoint found in %r' % (model_dir,))
            meta_filename = ckpt_path + '.meta'
            saver_lib.import_meta_graph(meta_filename)
            saver = saver_lib.Saver()
            with self.session(graph=g) as sess:
                saver.restore(sess, ckpt_path)
                return sess.run(ops.get_collection('my_vars'))

    def _build_iterator_saver_hook(self, est):
        """Creates a `CheckpointInputPipelineHook` for the estimator `est`."""
        hook = iterator_ops.CheckpointInputPipelineHook(est)
        return hook

    @combinations.generate(test_base.v1_only_combinations())
    def testReturnDatasetFromInputFn(self):
        """Trains twice with the hook when the input_fn returns a Dataset.

        After each two-step run the checkpointed (global_step, latest_feature)
        pair is checked: (2, 1) first, then (4, 3), i.e. the second run picks
        up the range dataset where the first one stopped.
        """
        def _input_fn():
            return dataset_ops.Dataset.range(10)

        est = estimator.Estimator(model_fn=self._model_fn)

        # A fresh hook is built for every train() call.
        for expected_vars in ((2, 1), (4, 3)):
            est.train(_input_fn,
                      steps=2,
                      hooks=[self._build_iterator_saver_hook(est)])
            self.assertSequenceEqual(
                self._read_vars(est.model_dir), expected_vars)

    @combinations.generate(test_base.v1_only_combinations())
    def testBuildIteratorInInputFn(self):
        """Trains twice with the hook when the input_fn builds its own iterator.

        The input_fn returns the `get_next()` tensor of a one-shot iterator.
        The checkpointed (global_step, latest_feature) pair must be (2, 1)
        after the first two-step run and (4, 3) after the second.
        """
        def _input_fn():
            return (dataset_ops.Dataset.range(10)
                    .make_one_shot_iterator()
                    .get_next())

        est = estimator.Estimator(model_fn=self._model_fn)

        # A fresh hook is built for every train() call.
        for expected_vars in ((2, 1), (4, 3)):
            est.train(_input_fn,
                      steps=2,
                      hooks=[self._build_iterator_saver_hook(est)])
            self.assertSequenceEqual(
                self._read_vars(est.model_dir), expected_vars)

    @combinations.generate(test_base.v1_only_combinations())
    def testDoNotRestore(self):
        """Checks that training without the hook restarts the input pipeline.

        Two hooked runs yield (2, 1) and (4, 3). A third run without the hook
        advances the global step to 6 but leaves latest_feature at 1, showing
        that the dataset started over from its first element.
        """
        def _input_fn():
            return dataset_ops.Dataset.range(10)

        est = estimator.Estimator(model_fn=self._model_fn)

        for expected_vars in ((2, 1), (4, 3)):
            est.train(_input_fn,
                      steps=2,
                      hooks=[self._build_iterator_saver_hook(est)])
            self.assertSequenceEqual(
                self._read_vars(est.model_dir), expected_vars)

        # Hook not provided, input pipeline was not restored.
        est.train(_input_fn, steps=2)
        restored_vars = self._read_vars(est.model_dir)
        self.assertSequenceEqual(restored_vars, (6, 1))

    @combinations.generate(test_base.v1_only_combinations())
    def testRaiseErrorIfNoIterator(self):
        """Expects a ValueError when the input_fn yields a plain constant.

        The input_fn returns an int64 constant rather than a dataset or an
        iterator output, and training with the hook must raise.
        """
        def _input_fn():
            return constant_op.constant(1, dtype=dtypes.int64)

        est = estimator.Estimator(model_fn=self._model_fn)

        with self.assertRaises(ValueError):
            # Hook creation stays inside the context, so the expected error
            # covers both building the hook and running train().
            hooks = [self._build_iterator_saver_hook(est)]
            est.train(_input_fn, steps=2, hooks=hooks)