def _weights_type_combinations():
  return combinations.combine(weights_type=["list", "tensor", "dataset"])
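Below the helper, a minimal sketch of how such a combinations helper is typically consumed (the `SampleWeightsTest`/`testWeightsType` names are hypothetical; the usual tf.data `test_base` and `combinations` imports are assumed):

class SampleWeightsTest(test_base.DatasetTestBase, parameterized.TestCase):
    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           _weights_type_combinations()))
    def testWeightsType(self, weights_type):
        # Each generated test case receives one of "list", "tensor", "dataset".
        self.assertIn(weights_type, ["list", "tensor", "dataset"])
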
Example #2
def default_test_combinations():
    """Returns the default test combinations for tf.data tests."""
    return combinations.combine(tf_api_version=[1, 2], mode=["eager", "graph"])
class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
    def create_cluster(self, num_workers):
        """Creates a cluster of tf.data service servers.

    Args:
      num_workers: The number of workers in the cluster.

    Returns:
      A target for connecting to the service, e.g.
      "grpc+local://localhost:2000".
    """
        self._master = server_lib.MasterServer(PROTOCOL)
        master_address = self._master.target[len(PROTOCOL + "://"):]

        self._servers = []
        for _ in range(num_workers):
            self._servers.append(
                server_lib.WorkerServer(PROTOCOL,
                                        master_address=master_address))

        return self._master.target

    @combinations.generate(test_base.eager_only_combinations())
    def testMultipleEpochs(self):
        service = self.create_cluster(1)
        ds = dataset_ops.Dataset.range(3)
        ds = ds.apply(data_service_ops.distribute(service))
        for _ in range(10):
            token = data_service_ops.create_job(
                ds, processing_mode="parallel_epochs")
            it = data_service_ops.create_iterator(ds, token)
            self.assertEqual(list(range(3)), [t.numpy() for t in it])

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeBasic(self):
        num_elements = 10
        service = self.create_cluster(1)
        ds = dataset_ops.Dataset.range(num_elements)
        ds = ds.apply(data_service_ops.distribute(service))
        token = data_service_ops.create_job(ds,
                                            processing_mode="parallel_epochs")
        it = data_service_ops.create_iterator(ds, token)
        results = [t.numpy() for t in it]
        self.assertEqual(list(range(num_elements)), results)

    @combinations.generate(test_base.eager_only_combinations())
    def testConcurrentEpoch(self):
        num_elements = 10
        num_datasets = 3
        service = self.create_cluster(1)
        iterators = []
        results = []
        for _ in range(num_datasets):
            ds = dataset_ops.Dataset.range(num_elements)
            ds = ds.apply(data_service_ops.distribute(service))
            token = data_service_ops.create_job(
                ds, processing_mode="parallel_epochs")
            it = data_service_ops.create_iterator(ds, token)
            iterators.append(it)
            results.append([])

        for _ in range(num_elements):
            for dataset_ind in range(num_datasets):
                result = next(iterators[dataset_ind]).numpy()
                results[dataset_ind].append(result)
        for result in results:
            self.assertEqual(list(range(num_elements)), result)

    @combinations.generate(test_base.eager_only_combinations())
    def testSharedEpoch(self):
        num_elements = 10
        num_iterators = 3
        service = self.create_cluster(1)
        ds = dataset_ops.Dataset.range(num_elements)
        ds = ds.apply(data_service_ops.distribute(service))
        result = []
        iterators = []
        token = data_service_ops.create_job(ds,
                                            processing_mode="parallel_epochs")
        for _ in range(num_iterators):
            iterators.append(data_service_ops.create_iterator(ds, token))

        # Alternate reading between the iterators.
        for _ in range(2):
            for it in iterators:
                result.append(next(it).numpy())

        # Drain the rest of the elements.
        for it in iterators:
            for elem in it:
                result.append(elem.numpy())

        self.assertCountEqual(list(range(num_elements)), result)

    @combinations.generate(test_base.eager_only_combinations())
    def testMultiWorker(self):
        num_workers = 3
        num_elements = 10
        service = self.create_cluster(num_workers)
        ds = dataset_ops.Dataset.range(num_elements)
        ds = ds.apply(data_service_ops.distribute(service))
        token = data_service_ops.create_job(ds,
                                            processing_mode="parallel_epochs")
        iterator = data_service_ops.create_iterator(ds, token)
        results = [elem.numpy() for elem in iterator]
        self.assertCountEqual(num_workers * list(range(num_elements)), results)

    @combinations.generate(test_base.eager_only_combinations())
    def testInsideFunction(self):
        num_workers = 3
        num_elements = 10
        service = self.create_cluster(num_workers)

        @def_function.function
        def f():
            ds = dataset_ops.Dataset.range(num_elements)
            ds = ds.apply(data_service_ops.distribute(service))
            token = data_service_ops.create_job(
                ds, processing_mode="parallel_epochs")
            it = data_service_ops.create_iterator(ds, token)
            result = tensor_array_ops.TensorArray(dtypes.int64,
                                                  size=num_workers *
                                                  num_elements,
                                                  dynamic_size=True)
            i = 0
            for elem in it:
                result = result.write(i, elem)
                i += 1
            return result.stack()

        result = list(f().numpy())
        self.assertCountEqual(num_workers * list(range(num_elements)), result)

    def run_stateful(self, external_state_policy):
        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements).map(
            lambda _: random_ops.random_uniform(()))

        options = dataset_ops.Options()
        options.experimental_external_state_policy = external_state_policy
        ds = ds.with_options(options)

        service = self.create_cluster(3)
        ds = ds.apply(data_service_ops.distribute(service))
        token = data_service_ops.create_job(ds,
                                            processing_mode="parallel_epochs")
        iterator = data_service_ops.create_iterator(ds, token)
        next(iterator)

    @combinations.generate(
        combinations.times(
            test_base.eager_only_combinations(),
            combinations.combine(external_state_policy=[
                distribute_options.ExternalStatePolicy.IGNORE,
                distribute_options.ExternalStatePolicy.WARN
            ])))
    def testStatefulNoError(self, external_state_policy):
        self.run_stateful(external_state_policy)

    @combinations.generate(test_base.eager_only_combinations())
    def testStatefulError(self):
        with self.assertRaises(errors.FailedPreconditionError):
            self.run_stateful(distribute_options.ExternalStatePolicy.FAIL)

    @combinations.generate(test_base.eager_only_combinations())
    def testNoDistributeCalls(self):
        ds = dataset_ops.Dataset.range(1)
        with self.assertRaisesWithLiteralMatch(
                ValueError,
                "Dataset does not contain any distribute() transformations"):
            data_service_ops.create_job(ds, processing_mode="parallel_epochs")

    @combinations.generate(test_base.eager_only_combinations())
    def testMultipleDistributeCalls(self):
        service = self.create_cluster(1)
        ds1 = dataset_ops.Dataset.range(1)
        ds1 = ds1.apply(data_service_ops.distribute(service))
        ds2 = dataset_ops.Dataset.range(1)
        ds2 = ds2.apply(data_service_ops.distribute(service))
        ds = dataset_ops.Dataset.zip((ds1, ds2))
        with self.assertRaisesWithLiteralMatch(
                ValueError,
                "Datasets containing multiple calls to .distribute(...) "
                "are not supported"):
            data_service_ops.create_job(ds, processing_mode="parallel_epochs")

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeFromInterleave(self):
        service = self.create_cluster(1)
        ds = dataset_ops.Dataset.range(2)

        def interleave_fn(_):
            ds = dataset_ops.Dataset.range(2)
            ds = ds.apply(data_service_ops.distribute(service))
            return ds

        with self.assertRaisesRegex(
                errors.InvalidArgumentError,
                r"The `.distribute\(...\)` dataset "
                "transformation is not supported within tf.data functions"):
            ds = ds.interleave(interleave_fn, cycle_length=2)

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeNonStringAddresses(self):
        ds = dataset_ops.Dataset.range(10)
        with self.assertRaisesRegex(ValueError, "service must be a string"):
            ds = ds.apply(data_service_ops.distribute(service=1))

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeEmptyAddress(self):
        ds = dataset_ops.Dataset.range(10)
        with self.assertRaisesWithLiteralMatch(ValueError,
                                               "service must not be empty"):
            ds = ds.apply(data_service_ops.distribute(service=""))
Example #4
class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase):
    def _testFromGenerator(self, generator, elem_sequence, num_repeats,
                           requires_initialization):
        dataset = dataset_ops.Dataset.from_generator(
            generator,
            output_types=dtypes.int64).repeat(num_repeats).prefetch(5)
        self.assertDatasetProduces(
            dataset,
            elem_sequence * num_repeats,
            requires_initialization=requires_initialization,
            num_test_iterations=2)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(num_repeats=[1, 5],
                                 requires_initialization=[True, False])))
    def testFromGeneratorUsingFn(self, num_repeats, requires_initialization):
        def generator():
            for i in range(1, 100):
                yield [i] * i

        elem_sequence = list(generator())
        self._testFromGenerator(
            generator,
            elem_sequence,
            num_repeats=num_repeats,
            requires_initialization=requires_initialization)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(num_repeats=[1, 5],
                                 requires_initialization=[True, False])))
    def testFromGeneratorUsingList(self, num_repeats, requires_initialization):
        generator = lambda: [[i] * i for i in range(1, 100)]
        elem_sequence = list(generator())
        self._testFromGenerator(
            generator,
            elem_sequence,
            num_repeats=num_repeats,
            requires_initialization=requires_initialization)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(num_repeats=[1, 5],
                                 requires_initialization=[True, False])))
    def testFromGeneratorUsingNdarray(self, num_repeats,
                                      requires_initialization):
        generator = lambda: np.arange(100, dtype=np.int64)
        elem_sequence = list(generator())
        self._testFromGenerator(
            generator,
            elem_sequence,
            num_repeats=num_repeats,
            requires_initialization=requires_initialization)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(num_repeats=[1, 5],
                                 requires_initialization=[True, False])))
    def testFromGeneratorUsingGeneratorExpression(self, num_repeats,
                                                  requires_initialization):
        # NOTE(mrry): Generator *expressions* are not repeatable (or in general
        # reusable), because they eagerly evaluate the `for` expression as
        # `iter(range(1, 100))` and discard the means of reconstructing
        # `range(1, 100)`. Wrapping the generator expression in a `lambda` makes
        # it repeatable.
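        # For example, `gen = ([i] * i for i in range(1, 100))` would be
        # exhausted after a single pass, so a repeated dataset could not
        # re-create it; the `lambda` below builds a fresh generator per call.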
        generator = lambda: ([i] * i for i in range(1, 100))
        elem_sequence = list(generator())
        self._testFromGenerator(
            generator,
            elem_sequence,
            num_repeats=num_repeats,
            requires_initialization=requires_initialization)

    @combinations.generate(test_base.default_test_combinations())
    def testFromMultipleConcurrentGenerators(self):
        num_inner_repeats = 5
        num_outer_repeats = 100

        def generator():
            for i in range(1, 10):
                yield ([i] * i, [i, i**2, i**3])

        input_list = list(generator())

        # The interleave transformation is essentially a flat map that
        # draws from multiple input datasets concurrently (in a cyclic
        # fashion). By placing `Dataset.from_generator()` inside an
        # interleave, we test its behavior when multiple iterators are
        # active at the same time; by additionally prefetching inside the
        # interleave, we create the possibility of parallel (modulo GIL)
        # invocations to several iterators created by the same dataset.
        def interleave_fn(_):
            return (dataset_ops.Dataset.from_generator(
                generator,
                output_types=(dtypes.int64, dtypes.int64),
                output_shapes=([None],
                               [3])).repeat(num_inner_repeats).prefetch(5))

        dataset = dataset_ops.Dataset.range(num_outer_repeats).interleave(
            interleave_fn, cycle_length=10, block_length=len(input_list))
        get_next = self.getNext(dataset)
        for _ in range(num_inner_repeats * num_outer_repeats):
            for elem in input_list:
                val0, val1 = self.evaluate(get_next())
                self.assertAllEqual(elem[0], val0)
                self.assertAllEqual(elem[1], val1)
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())

    # TODO(b/67868766): Reenable this when the source of flakiness is discovered.
    def _testFromGeneratorsRunningInParallel(self):
        num_parallel_iterators = 3

        # Define shared state that multiple iterator instances will access to
        # demonstrate their concurrent activity.
        lock = threading.Lock()
        condition = threading.Condition(lock)
        next_ticket = [0]  # GUARDED_BY(lock)

        def generator():
            # NOTE(mrry): We yield one element before the barrier, because
            # the current implementation of `Dataset.interleave()` must
            # fetch one element from each incoming dataset to start the
            # prefetching.
            yield 0

            # Define a barrier that `num_parallel_iterators` iterators must enter
            # before any can proceed. Demonstrates that multiple iterators may be
            # active at the same time.
            condition.acquire()
            ticket = next_ticket[0]
            next_ticket[0] += 1
            if ticket == num_parallel_iterators - 1:
                # The last iterator to join the barrier notifies the others.
                condition.notify_all()
            else:
                # Wait until the last iterator enters the barrier.
                while next_ticket[0] < num_parallel_iterators:
                    condition.wait()
            condition.release()

            yield 1

        # As in `testFromMultipleConcurrentGenerators()`, we use a combination of
        # `Dataset.interleave()` and `Dataset.prefetch()` to cause multiple
        # iterators to be active concurrently.
        def interleave_fn(_):
            return dataset_ops.Dataset.from_generator(
                generator, output_types=dtypes.int64,
                output_shapes=[]).prefetch(2)

        dataset = dataset_ops.Dataset.range(num_parallel_iterators).interleave(
            interleave_fn, cycle_length=num_parallel_iterators, block_length=1)
        get_next = self.getNext(dataset)

        for elem in [0, 1]:
            for _ in range(num_parallel_iterators):
                self.assertAllEqual(elem, self.evaluate(get_next()))
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())

    @combinations.generate(test_base.default_test_combinations())
    def testFromGeneratorImplicitConversion(self):
        def generator():
            yield [1]
            yield [2]
            yield [3]

        for dtype in [dtypes.int8, dtypes.int32, dtypes.int64]:
            dataset = dataset_ops.Dataset.from_generator(generator,
                                                         output_types=dtype,
                                                         output_shapes=[1])
            get_next = self.getNext(dataset)

            for expected in [[1], [2], [3]]:
                next_val = self.evaluate(get_next())
                self.assertEqual(dtype.as_numpy_dtype, next_val.dtype)
                self.assertAllEqual(expected, next_val)
            with self.assertRaises(errors.OutOfRangeError):
                self.evaluate(get_next())

    @combinations.generate(test_base.default_test_combinations())
    def testFromGeneratorString(self):
        def generator():
            yield "foo"
            yield b"bar"
            yield u"baz"

        dataset = dataset_ops.Dataset.from_generator(
            generator, output_types=dtypes.string, output_shapes=[])
        self.assertDatasetProduces(dataset,
                                   expected_output=[b"foo", b"bar", b"baz"])

    @combinations.generate(test_base.default_test_combinations())
    def testFromGeneratorTypeError(self):
        def generator():
            yield np.array([1, 2, 3], dtype=np.int64)
            yield np.array([4, 5, 6], dtype=np.int64)
            yield "ERROR"
            yield np.array([7, 8, 9], dtype=np.int64)

        dataset = dataset_ops.Dataset.from_generator(generator,
                                                     output_types=dtypes.int64,
                                                     output_shapes=[3])

        get_next = self.getNext(dataset)

        self.assertAllEqual([1, 2, 3], self.evaluate(get_next()))
        self.assertAllEqual([4, 5, 6], self.evaluate(get_next()))
        with self.assertRaisesOpError("The expected type was int64"):
            self.evaluate(get_next())
        self.assertAllEqual([7, 8, 9], self.evaluate(get_next()))
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())

    @combinations.generate(test_base.default_test_combinations())
    def testFromGeneratorShapeError(self):
        def generator():
            yield np.array([1, 2, 3], dtype=np.int64)
            yield np.array([4, 5, 6], dtype=np.int64)
            yield np.array([7, 8, 9, 10], dtype=np.int64)
            yield np.array([11, 12, 13], dtype=np.int64)

        dataset = dataset_ops.Dataset.from_generator(generator,
                                                     output_types=dtypes.int64,
                                                     output_shapes=[3])
        get_next = self.getNext(dataset)

        self.assertAllEqual([1, 2, 3], self.evaluate(get_next()))
        self.assertAllEqual([4, 5, 6], self.evaluate(get_next()))
        with self.assertRaisesOpError(r"element of shape \(3,\) was expected"):
            self.evaluate(get_next())
        self.assertAllEqual([11, 12, 13], self.evaluate(get_next()))
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())

    @combinations.generate(test_base.default_test_combinations())
    def testFromGeneratorStructureError(self):
        def generator():
            yield 1, 2
            yield 3, 4
            yield 5
            yield 6, 7, 8
            yield 9, 10

        dataset = dataset_ops.Dataset.from_generator(
            generator, output_types=(dtypes.int64, dtypes.int64))
        get_next = self.getNext(dataset)

        self.assertEqual((1, 2), self.evaluate(get_next()))
        self.assertEqual((3, 4), self.evaluate(get_next()))
        with self.assertRaisesOpError(
                r"The expected structure was \(tf\.int64, tf\.int64\)"):
            self.evaluate(get_next())
        with self.assertRaisesOpError(
                r"The expected structure was \(tf\.int64, tf\.int64\)"):
            self.evaluate(get_next())
        self.assertEqual((9, 10), self.evaluate(get_next()))
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())

    @combinations.generate(test_base.default_test_combinations())
    def testFromGeneratorHeterogeneous(self):
        def generator():
            yield 1
            yield [2, 3]

        dataset = dataset_ops.Dataset.from_generator(generator,
                                                     output_types=dtypes.int64)
        self.assertDatasetProduces(dataset, expected_output=[1, [2, 3]])

    @combinations.generate(test_base.default_test_combinations())
    def testFromGeneratorStopShort(self):
        def generator():
            yield 0
            yield 1
            yield 2

        dataset = dataset_ops.Dataset.from_generator(generator,
                                                     output_types=dtypes.int64)
        get_next = self.getNext(dataset)
        self.assertAllEqual(0, self.evaluate(get_next()))
        self.assertAllEqual(1, self.evaluate(get_next()))

    @combinations.generate(test_base.default_test_combinations())
    def testFromGeneratorDestructorCalled(self):
        # Use an `Event` to signal that the generator has been deleted.
        event = threading.Event()

        class GeneratorWrapper(object):
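            # An "infinite iterator": __next__ always returns 42, and __del__
            # sets `event` so the test can observe when the wrapper is destroyed.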
            def __iter__(self):
                return self

            def next(self):
                return self.__next__()

            def __next__(self):
                return 42

            def __del__(self):
                event.set()

        dataset = dataset_ops.Dataset.from_generator(
            GeneratorWrapper, output_types=dtypes.int64).take(2)
        get_next = self.getNext(dataset)

        self.assertAllEqual(42, self.evaluate(get_next()))
        self.assertAllEqual(42, self.evaluate(get_next()))
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())
        # Test that `GeneratorWrapper` object is destroyed when the
        # iterator terminates (and the generator iterator is deleted).
        self.assertTrue(event.is_set())

    @combinations.generate(test_base.default_test_combinations())
    def testFromGeneratorWithArgs(self):
        def flat_map_fn(elem):
            def generator_with_arg(n):
                for _ in range(n):
                    yield np.array(n, dtype=np.int64)

            return dataset_ops.Dataset.from_generator(
                generator_with_arg,
                output_types=dtypes.int64,
                output_shapes=(),
                args=(elem, ))

        dataset = dataset_ops.Dataset.range(5).flat_map(flat_map_fn)
        self.assertDatasetProduces(
            dataset, expected_output=[1, 2, 2, 3, 3, 3, 4, 4, 4, 4])

    @combinations.generate(test_base.default_test_combinations())
    def testFromGeneratorWithTwoArgs(self):
        def flat_map_fn(elem, message):
            def generator_with_arg(n, msg):
                for i in range(n):
                    yield i, msg

            return dataset_ops.Dataset.from_generator(
                generator_with_arg,
                output_types=(dtypes.int64, dtypes.string),
                output_shapes=((), ()),
                args=(elem, message))

        dataset = dataset_ops.Dataset.zip(
            (dataset_ops.Dataset.range(5),
             dataset_ops.Dataset.from_tensors("Hi!").repeat(None)
             )).flat_map(flat_map_fn)

        self.assertDatasetProduces(dataset,
                                   expected_output=[(0, b"Hi!"), (0, b"Hi!"),
                                                    (1, b"Hi!"), (0, b"Hi!"),
                                                    (1, b"Hi!"), (2, b"Hi!"),
                                                    (0, b"Hi!"), (1, b"Hi!"),
                                                    (2, b"Hi!"), (3, b"Hi!")])

    @combinations.generate(test_base.default_test_combinations())
    def testGeneratorDatasetFinalizeFunctionCalled(self):
        # NOTE(mrry): This test exercises the internal `_GeneratorDataset`,
        # which affords more control over what the finalize function can do
        # than the `Dataset.from_generator()` wrapper.

        # Use an `Event` to signal that the generator has been deleted.
        event = threading.Event()

        def finalize_fn(_):
            def finalize_py_func():
                event.set()
                return 0

            return script_ops.py_func(finalize_py_func, [], [dtypes.int64],
                                      stateful=True)

        dummy = constant_op.constant(37)
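        # The positional arguments to `_GeneratorDataset` below appear to be
        # (init_args, init_fn, next_fn, finalize_fn); only `finalize_fn`
        # matters for this test (an assumption about the internal signature,
        # which is not part of the public API).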
        dataset = dataset_ops._GeneratorDataset(dummy, lambda x: x,
                                                lambda x: x,
                                                finalize_fn).take(2)
        get_next = self.getNext(dataset)

        self.assertAllEqual(37, self.evaluate(get_next()))
        self.assertAllEqual(37, self.evaluate(get_next()))
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())

    @combinations.generate(test_base.default_test_combinations())
    def testSharedName(self):
        def generator():
            for _ in range(10):
                yield [20]

        dataset = dataset_ops.Dataset.from_generator(
            generator, output_types=(dtypes.int64))
        get_next = self.getNext(dataset,
                                requires_initialization=True,
                                shared_name="shared_dataset")

        self.assertAllEqual([20], self.evaluate(get_next()))
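
For reference, a minimal standalone sketch of `Dataset.from_generator` with `args`, distilled from `testFromGeneratorWithArgs` above (the `repeat_n` name is illustrative):

def repeat_n(n):
    # Yields the value `n`, `n` times.
    for _ in range(n):
        yield np.array(n, dtype=np.int64)

ds = dataset_ops.Dataset.range(5).flat_map(
    lambda n: dataset_ops.Dataset.from_generator(
        repeat_n, output_types=dtypes.int64, output_shapes=(), args=(n,)))
# Produces 1, 2, 2, 3, 3, 3, 4, 4, 4, 4.
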
Example #5
class RangeTest(test_base.DatasetTestBase, parameterized.TestCase):
    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(output_type=[
                dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64
            ])))
    def testStop(self, output_type):
        stop = 5
        dataset = dataset_ops.Dataset.range(stop, output_type=output_type)
        expected_output = np.arange(stop, dtype=output_type.as_numpy_dtype)
        self.assertDatasetProduces(dataset, expected_output=expected_output)
        self.assertEqual(output_type,
                         dataset_ops.get_legacy_output_types(dataset))

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(output_type=[
                dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64
            ])))
    def testStartStop(self, output_type):
        start, stop = 2, 5
        dataset = dataset_ops.Dataset.range(start,
                                            stop,
                                            output_type=output_type)
        expected_output = np.arange(start,
                                    stop,
                                    dtype=output_type.as_numpy_dtype)
        self.assertDatasetProduces(dataset, expected_output=expected_output)
        self.assertEqual(output_type,
                         dataset_ops.get_legacy_output_types(dataset))

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(output_type=[
                dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64
            ])))
    def testStartStopStep(self, output_type):
        start, stop, step = 2, 10, 2
        dataset = dataset_ops.Dataset.range(start,
                                            stop,
                                            step,
                                            output_type=output_type)
        expected_output = np.arange(start,
                                    stop,
                                    step,
                                    dtype=output_type.as_numpy_dtype)
        self.assertDatasetProduces(dataset, expected_output=expected_output)
        self.assertEqual(output_type,
                         dataset_ops.get_legacy_output_types(dataset))

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(output_type=[
                dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64
            ])))
    def testZeroStep(self, output_type):
        start, stop, step = 2, 10, 0
        with self.assertRaises(errors.InvalidArgumentError):
            dataset = dataset_ops.Dataset.range(start,
                                                stop,
                                                step,
                                                output_type=output_type)
            self.evaluate(dataset._variant_tensor)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(output_type=[
                dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64
            ])))
    def testNegativeStep(self, output_type):
        start, stop, step = 2, 10, -1
        dataset = dataset_ops.Dataset.range(start,
                                            stop,
                                            step,
                                            output_type=output_type)
        expected_output = np.arange(start,
                                    stop,
                                    step,
                                    dtype=output_type.as_numpy_dtype)
        self.assertDatasetProduces(dataset, expected_output=expected_output)
        self.assertEqual(output_type,
                         dataset_ops.get_legacy_output_types(dataset))

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(output_type=[
                dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64
            ])))
    def testStopLessThanStart(self, output_type):
        start, stop = 10, 2
        dataset = dataset_ops.Dataset.range(start,
                                            stop,
                                            output_type=output_type)
        expected_output = np.arange(start,
                                    stop,
                                    dtype=output_type.as_numpy_dtype)
        self.assertDatasetProduces(dataset, expected_output=expected_output)
        self.assertEqual(output_type,
                         dataset_ops.get_legacy_output_types(dataset))

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(output_type=[
                dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64
            ])))
    def testStopLessThanStartWithPositiveStep(self, output_type):
        start, stop, step = 10, 2, 2
        dataset = dataset_ops.Dataset.range(start,
                                            stop,
                                            step,
                                            output_type=output_type)
        expected_output = np.arange(start,
                                    stop,
                                    step,
                                    dtype=output_type.as_numpy_dtype)
        self.assertDatasetProduces(dataset, expected_output=expected_output)
        self.assertEqual(output_type,
                         dataset_ops.get_legacy_output_types(dataset))

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(output_type=[
                dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64
            ])))
    def testStopLessThanStartWithNegativeStep(self, output_type):
        start, stop, step = 10, 2, -1
        dataset = dataset_ops.Dataset.range(start,
                                            stop,
                                            step,
                                            output_type=output_type)
        expected_output = np.arange(start,
                                    stop,
                                    step,
                                    dtype=output_type.as_numpy_dtype)
        self.assertDatasetProduces(dataset, expected_output=expected_output)
        self.assertEqual(output_type,
                         dataset_ops.get_legacy_output_types(dataset))

    @combinations.generate(test_base.default_test_combinations())
    def testName(self):
        dataset = dataset_ops.Dataset.range(5, name="range")
        self.assertDatasetProduces(dataset, list(range(5)))
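
In summary, the signatures exercised above (a sketch; `output_type` defaults to `tf.int64`):

dataset_ops.Dataset.range(5)                              # 0, 1, 2, 3, 4
dataset_ops.Dataset.range(2, 5)                           # 2, 3, 4
dataset_ops.Dataset.range(2, 10, 2)                       # 2, 4, 6, 8
dataset_ops.Dataset.range(5, output_type=dtypes.float32)  # 0.0, 1.0, ..., 4.0
dataset_ops.Dataset.range(10, 2, -1)                      # 10, 9, ..., 3
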
Example #6
class DynamicShardingTest(data_service_test_base.TestBase,
                          parameterized.TestCase):
    def _make_dynamic_sharding_dataset(self, dataset, cluster):
        return self.make_distributed_dataset(
            dataset,
            cluster,
            processing_mode=data_service_ops.ShardingPolicy.DYNAMIC,
            job_name="job_name")

    @combinations.generate(test_base.default_test_combinations())
    def testBasic(self):
        cluster = data_service_test_base.TestCluster(num_workers=2)
        num_elements = 100
        ds = dataset_ops.Dataset.range(num_elements)
        ds = self._make_dynamic_sharding_dataset(ds, cluster)
        self.assertDatasetProduces(ds,
                                   list(range(num_elements)),
                                   assert_items_equal=True)

    @combinations.generate(test_base.default_test_combinations())
    def testTensorSlices(self):
        cluster = data_service_test_base.TestCluster(num_workers=2)
        vals = [5, 1, 2, 4]
        ds = dataset_ops.Dataset.from_tensor_slices(vals)
        ds = self._make_dynamic_sharding_dataset(ds, cluster)
        self.assertDatasetProduces(ds, vals, assert_items_equal=True)

    @combinations.generate(test_base.default_test_combinations())
    def testInterleave(self):
        cluster = data_service_test_base.TestCluster(num_workers=2)
        elements = [1, 5, 0]
        ds = dataset_ops.Dataset.from_tensor_slices(elements)
        ds = ds.interleave(
            lambda x: dataset_ops.Dataset.from_tensor_slices([x]))
        ds = self._make_dynamic_sharding_dataset(ds, cluster)
        self.assertDatasetProduces(ds, elements, assert_items_equal=True)

    @combinations.generate(test_base.default_test_combinations())
    def testParallelInterleave(self):
        cluster = data_service_test_base.TestCluster(num_workers=2)
        elements = [1, 5, 0]
        ds = dataset_ops.Dataset.from_tensor_slices(elements)
        ds = ds.interleave(
            lambda x: dataset_ops.Dataset.from_tensor_slices([x]),
            num_parallel_calls=dataset_ops.AUTOTUNE)
        ds = self._make_dynamic_sharding_dataset(ds, cluster)
        self.assertDatasetProduces(ds, elements, assert_items_equal=True)

    @combinations.generate(test_base.default_test_combinations())
    def testFlatMap(self):
        cluster = data_service_test_base.TestCluster(num_workers=2)
        elements = [1, 5, 0]
        ds = dataset_ops.Dataset.from_tensor_slices(elements)
        ds = ds.flat_map(lambda x: dataset_ops.Dataset.from_tensor_slices([x]))
        ds = self._make_dynamic_sharding_dataset(ds, cluster)
        self.assertDatasetProduces(ds, elements, assert_items_equal=True)

    @combinations.generate(test_base.default_test_combinations())
    def testGroupByWindow(self):
        # Verify that split providers are not propagated into the iterators
        # created for the reduce datasets produced by `reduce_fn` in
        # `group_by_window`.
        cluster = data_service_test_base.TestCluster(num_workers=2)
        elements = [1, 5, 0]
        ds = dataset_ops.Dataset.from_tensor_slices(elements)

        def reduce_fn(_, window):
            return dataset_ops.Dataset.zip(
                (window, dataset_ops.Dataset.range(100)))

        ds = ds.group_by_window(lambda x: 0, reduce_fn, window_size=3)
        ds = self._make_dynamic_sharding_dataset(ds, cluster)
        # This will fail if the tensor_slices split provider is propagated into
        # the `reduce_fn`, since the `zip` requires either 0 or 2 split providers.
        self.getDatasetOutput(ds)

    @combinations.generate(test_base.default_test_combinations())
    def testRepeatBeforeDistribution(self):
        cluster = data_service_test_base.TestCluster(num_workers=2)
        num_repeats = 5
        num_elements = 20
        ds = dataset_ops.Dataset.range(num_elements).repeat(num_repeats)
        ds = self._make_dynamic_sharding_dataset(ds, cluster)
        self.assertDatasetProduces(ds,
                                   num_repeats * list(range(num_elements)),
                                   assert_items_equal=True)

    @combinations.generate(test_base.default_test_combinations())
    def testRepeatAfterDistribution(self):
        cluster = data_service_test_base.TestCluster(num_workers=2)
        num_repeats = 5
        num_elements = 20
        ds = dataset_ops.Dataset.range(num_elements)
        ds = self._make_dynamic_sharding_dataset(ds, cluster)
        ds = ds.repeat(num_repeats)
        self.assertDatasetProduces(ds,
                                   num_repeats * list(range(num_elements)),
                                   assert_items_equal=True)

    @combinations.generate(test_base.default_test_combinations())
    def testForeverRepeat(self):
        cluster = data_service_test_base.TestCluster(num_workers=2)
        num_elements = 20
        elements_to_read = 1000
        ds = dataset_ops.Dataset.range(num_elements).repeat()
        ds = self._make_dynamic_sharding_dataset(ds, cluster)
        get_next = self.getNext(ds)
        results = {}
        for _ in range(elements_to_read):
            val = self.evaluate(get_next())
            if val not in results:
                results[val] = 0
            results[val] += 1
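        # Each element should appear roughly elements_to_read / num_elements
        # times; the assertion below allows a 2x margin for uneven scheduling
        # across workers.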
        for i in range(num_elements):
            self.assertGreater(results[i], elements_to_read / num_elements / 2)

    @combinations.generate(test_base.default_test_combinations())
    def testForeverRepeatFewElements(self):
        num_workers = 5
        cluster = data_service_test_base.TestCluster(num_workers=num_workers)
        # Use fewer elements than the number of workers, so that some workers
        # get zero elements on the first repetition.
        num_elements = 1
        ds = dataset_ops.Dataset.range(num_elements).repeat()
        ds = self._make_dynamic_sharding_dataset(ds, cluster)
        get_next = self.getNext(ds)
        for _ in range(20):
            self.assertEqual(self.evaluate(get_next()), 0)

        # Stop all but one worker and check that we can still read.
        for i in range(num_workers - 1):
            cluster.workers[i].stop()
        for _ in range(20):
            self.assertEqual(self.evaluate(get_next()), 0)

    @combinations.generate(test_base.default_test_combinations())
    def testShuffleAndRepeat(self):
        cluster = data_service_test_base.TestCluster(num_workers=2)
        num_repeats = 5
        num_elements = 20
        ds = dataset_ops.Dataset.range(num_elements).shuffle(
            num_elements).repeat(num_repeats)
        ds = self._make_dynamic_sharding_dataset(ds, cluster)
        self.assertDatasetProduces(ds,
                                   num_repeats * list(range(num_elements)),
                                   assert_items_equal=True)

    @combinations.generate(test_base.default_test_combinations())
    def testZip(self):
        num_elements = 10
        cluster = data_service_test_base.TestCluster(num_workers=1)
        a = dataset_ops.Dataset.range(num_elements)

        ds = dataset_ops.Dataset.zip((a, a))
        ds = self._make_dynamic_sharding_dataset(ds, cluster)

        self.assertDatasetProduces(
            ds, list(zip(range(num_elements), range(num_elements))))

    @combinations.generate(test_base.default_test_combinations())
    def testNestedZip(self):
        num_elements = 10
        cluster = data_service_test_base.TestCluster(num_workers=1)
        a = dataset_ops.Dataset.range(num_elements)

        ds = dataset_ops.Dataset.zip((a, a))
        ds = dataset_ops.Dataset.zip((a, a, ds, a))
        ds = self._make_dynamic_sharding_dataset(ds, cluster)

        b = list(range(10))
        self.assertDatasetProduces(ds, list(zip(b, b, zip(b, b), b)))

    @combinations.generate(test_base.default_test_combinations())
    def testImbalancedZip(self):
        smaller_num_elements = 200
        larger_num_elements = 1000

        cluster = data_service_test_base.TestCluster(num_workers=1)
        a = dataset_ops.Dataset.range(smaller_num_elements)
        b = dataset_ops.Dataset.range(larger_num_elements)

        ds = dataset_ops.Dataset.zip((a, b))
        ds = self._make_dynamic_sharding_dataset(ds, cluster)

        self.assertDatasetProduces(
            ds,
            list(zip(range(smaller_num_elements),
                     range(smaller_num_elements))))

    @combinations.generate(test_base.default_test_combinations())
    def testImbalancedZipMultiWorker(self):
        smaller_num_elements = 200
        larger_num_elements = 1000
        cluster = data_service_test_base.TestCluster(num_workers=3)
        a = dataset_ops.Dataset.range(smaller_num_elements)
        b = dataset_ops.Dataset.range(larger_num_elements)

        ds = dataset_ops.Dataset.zip((a, b))
        ds = self._make_dynamic_sharding_dataset(ds, cluster)

        # Cannot assert specific elements because the range datasets are split
        # nondeterministically and may not line up.
        self.assertLen(self.getDatasetOutput(ds), smaller_num_elements)

    @combinations.generate(test_base.default_test_combinations())
    def testZipDifferentRates(self):
        cluster = data_service_test_base.TestCluster(num_workers=3)
        a = dataset_ops.Dataset.range(100)
        b = dataset_ops.Dataset.range(100).filter(
            lambda x: math_ops.equal(x % 10, 0))

        ds = dataset_ops.Dataset.zip((a, b))
        ds = self._make_dynamic_sharding_dataset(ds, cluster)

        self.assertLen(self.getDatasetOutput(ds), 10)

    @combinations.generate(test_base.default_test_combinations())
    def testZipDifferentRepeats(self):
        cluster = data_service_test_base.TestCluster(num_workers=3)
        a = dataset_ops.Dataset.range(50)
        b = dataset_ops.Dataset.range(10).repeat(10)

        ds = dataset_ops.Dataset.zip((a, b))
        ds = self._make_dynamic_sharding_dataset(ds, cluster)

        self.assertLen(self.getDatasetOutput(ds), 50)

    @combinations.generate(test_base.default_test_combinations())
    def testSampleFromDatasets(self):
        cluster = data_service_test_base.TestCluster(num_workers=3)
        num_samples = 200
        weights = [.6, .3, .1]
        classes = len(weights)

        # Create a dataset that samples each integer `i` in `[0, classes)`
        # with probability given by `weights[i]`.
        ds = dataset_ops.Dataset.sample_from_datasets([
            dataset_ops.Dataset.from_tensors(i).repeat()
            for i in range(classes)
        ], weights)
        ds = self._make_dynamic_sharding_dataset(ds, cluster)
        ds = ds.take(num_samples)

        freqs = np.zeros([classes])
        for v in self.getDatasetOutput(ds):
            freqs[v] += 1

        self.assertGreater(freqs[0], freqs[1])
        self.assertGreater(freqs[1], freqs[2])

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(num_workers=[1, 3])))
    def testChooseFromDatasets(self, num_workers):
        cluster = data_service_test_base.TestCluster(num_workers=num_workers)
        words = [b"foo", b"bar", b"baz"]
        datasets = [
            dataset_ops.Dataset.from_tensors(w).repeat() for w in words
        ]
        choice_array = np.random.randint(3, size=(15, ), dtype=np.int64)
        choice_dataset = dataset_ops.Dataset.from_tensor_slices(choice_array)
        ds = dataset_ops.Dataset.choose_from_datasets(datasets, choice_dataset)
        ds = self._make_dynamic_sharding_dataset(ds, cluster)
        expected = [words[i] for i in choice_array]
        if compat.forward_compatible(2022, 6, 6):
            expected *= num_workers

        assert_items_equal = (num_workers > 1)
        self.assertDatasetProduces(ds,
                                   expected,
                                   assert_items_equal=assert_items_equal)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations()))
    def testEnumerateReplicateOnSplit(self):
        if not compat.forward_compatible(2022, 6, 6):
            self.skipTest("Replicate on split is not yet available.")

        num_workers = 3
        cluster = data_service_test_base.TestCluster(num_workers)
        ds = dataset_ops.Dataset.from_tensor_slices(["a", "b", "c"]).repeat()
        ds = ds.enumerate()
        ds = self._make_dynamic_sharding_dataset(ds, cluster)
        get_next = self.getNext(ds)

        counts = collections.defaultdict(int)
        while True:
            i, _ = self.evaluate(get_next())
            counts[i] += 1
            # Read until all workers have reached enumeration index 10.
            if counts[10] == num_workers:
                break

        for i in range(10):
            self.assertEqual(counts[i], num_workers)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(num_workers=[1, 3])))
    def testConcatenate(self, num_workers):
        cluster = data_service_test_base.TestCluster(num_workers=num_workers)
        a = dataset_ops.Dataset.range(100)
        b = dataset_ops.Dataset.range(100, 200)
        ds = a.concatenate(b)
        ds = self._make_dynamic_sharding_dataset(ds, cluster)

        assert_items_equal = (num_workers > 1)
        self.assertDatasetProduces(ds,
                                   list(range(200)),
                                   assert_items_equal=assert_items_equal)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(already_written=[True, False]))
    )
    def testSnapshot(self, already_written):
        num_workers = 3
        cluster = data_service_test_base.TestCluster(num_workers=num_workers)
        ds = dataset_ops.Dataset.range(100)
        ds = ds.snapshot(self.get_temp_dir())
        if already_written:
            # Materialize the snapshot.
            self.getDatasetOutput(ds)

        ds = self._make_dynamic_sharding_dataset(ds, cluster)
        error_regex = "Splitting is not implemented for snapshot datasets"
        with self.assertRaisesRegex(errors.UnimplementedError, error_regex):
            self.getDatasetOutput(ds)

    @combinations.generate(test_base.default_test_combinations())
    def testDistributedDataset(self):
        cluster_1 = data_service_test_base.TestCluster(num_workers=1)
        cluster_2 = data_service_test_base.TestCluster(num_workers=1)
        num_sizes = 10
        size_repeats = 5
        numbers = [1 * i for i in range(num_sizes)] * size_repeats
        ds = dataset_ops.Dataset.from_tensor_slices(numbers)
        ds = self.make_distributed_dataset(
            ds, cluster_1, processing_mode=data_service_ops.ShardingPolicy.OFF)
        ds = ds.map(lambda x: x + 1)
        ds = self._make_dynamic_sharding_dataset(ds, cluster_2)

        error_regex = ("Cannot create split providers for dataset " +
                       "of type DataServiceDataset")
        with self.assertRaisesRegex(errors.UnimplementedError, error_regex):
            self.getDatasetOutput(ds)

    @combinations.generate(test_base.default_test_combinations())
    def testDistributedEpoch(self):
        cluster = data_service_test_base.TestCluster(num_workers=2)
        num_elements = 100
        ds = dataset_ops.Dataset.range(num_elements)
        ds = self.make_distributed_dataset(ds,
                                           cluster,
                                           processing_mode="distributed_epoch")
        self.assertDatasetProduces(ds,
                                   list(range(num_elements)),
                                   assert_items_equal=True)

    @combinations.generate(test_base.default_test_combinations())
    def testFlatMapWithRepeat(self):
        cluster = data_service_test_base.TestCluster(num_workers=3)
        ds = dataset_ops.Dataset.range(5)

        def flat_map_fn(_):
            return dataset_ops.Dataset.from_tensor_slices(["a", "b",
                                                           "c"]).repeat(10)

        ds = ds.flat_map(flat_map_fn)
        ds = self._make_dynamic_sharding_dataset(ds, cluster)

        self.assertDatasetProduces(ds, [b"a", b"b", b"c"] * 50,
                                   assert_items_equal=True)
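
Distilled from `_make_dynamic_sharding_dataset` above, a minimal sketch of the dynamic-sharding setup these tests rely on, as it would appear inside one of these test methods (the method name and `job_name` value are illustrative):

    def testDynamicShardingSketch(self):
        cluster = data_service_test_base.TestCluster(num_workers=2)
        ds = dataset_ops.Dataset.range(100)
        ds = self.make_distributed_dataset(
            ds,
            cluster,
            processing_mode=data_service_ops.ShardingPolicy.DYNAMIC,
            job_name="job_name")
        # With dynamic sharding, each input element is processed by exactly one
        # worker, so range(100) is produced once, in nondeterministic order.
        self.assertDatasetProduces(ds,
                                   list(range(100)),
                                   assert_items_equal=True)
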
Example #7
class RemoteReplicateTest(test_base.DatasetTestBase, parameterized.TestCase):

  def __init__(self, methodName="runTest"):  # pylint: disable=invalid-name
    super(RemoteReplicateTest, self).__init__(methodName)
    self._cached_server1 = server_lib.Server.create_local_server()
    self._cached_server2 = server_lib.Server.create_local_server()
    self._cached_server1_target = self._cached_server1.target[len("grpc://"):]
    self._cached_server2_target = self._cached_server2.target[len("grpc://"):]
    self._device0 = "/job:%s/replica:0/task:0/device:CPU:0" % JOB_NAME
    self._device1 = "/job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME
    self._device2 = "/job:%s/replica:0/task:2/device:CPU:0" % JOB_NAME

  def setUp(self):
    super(RemoteReplicateTest, self).setUp()
    # Start the local server.
    local_port = pywrap_tensorflow.TF_PickUnusedPortOrDie()
    context.set_server_def(
        server_def=_get_server_def(
            JOB_NAME,
            local_server_port=local_port,
            remote_server_addresses=[
                self._cached_server1_target, self._cached_server2_target
            ],
            task_index=0))

  @combinations.generate(
      combinations.combine(tf_api_version=[2], mode=["eager"]))
  def testBasic(self):
    with ops.device(self._device0):
      dataset0 = dataset_ops.Dataset.range(100)
    replicated_ds = distribute.replicate(dataset0,
                                         [self._device1, self._device2])
    dataset1 = replicated_ds[self._device1]
    dataset2 = replicated_ds[self._device2]
    with ops.device(self._device0):
      self.assertDatasetProduces(dataset0, range(100))
    with ops.device(self._device1):
      self.assertDatasetProduces(dataset1, range(100))
    with ops.device(self._device2):
      self.assertDatasetProduces(dataset2, range(100))

  @combinations.generate(
      combinations.combine(tf_api_version=[2], mode=["eager"]))
  def testMap(self):
    with ops.device(self._device0):
      dataset0 = dataset_ops.Dataset.range(100).map(lambda x: x * 2)
    replicated_ds = distribute.replicate(dataset0,
                                         [self._device1, self._device2])
    dataset1 = replicated_ds[self._device1]
    dataset2 = replicated_ds[self._device2]
    with ops.device(self._device0):
      self.assertDatasetProduces(dataset0, range(0, 200, 2))
    with ops.device(self._device1):
      self.assertDatasetProduces(dataset1, range(0, 200, 2))
    with ops.device(self._device2):
      self.assertDatasetProduces(dataset2, range(0, 200, 2))

  @combinations.generate(
      combinations.combine(tf_api_version=[2], mode=["eager"]))
  def testVariableInput(self):
    with ops.device(self._device0):
      counter_var = variable_scope.get_variable(
          "counter", (), dtypes.int32, use_resource=True)
      dataset0 = dataset_ops.Dataset.range(100).map(
          lambda _: counter_var.assign_add(1))
    # We don't support stateful ops in functions as of now.
    with self.assertRaises(errors.FailedPreconditionError):
      replicated_ds = distribute.replicate(dataset0,
                                           [self._device1, self._device2])
      self.evaluate(replicated_ds[self._device1]._variant_tensor)

  @combinations.generate(
      combinations.combine(tf_api_version=[2], mode=["eager"]))
  def testAllowStatefulOp(self):
    with compat.forward_compatibility_horizon(2019, 9, 12):
      with ops.device(self._device0):
        dataset0 = dataset_ops.Dataset.range(100).map(
            lambda _: random_ops.random_uniform(  # pylint:disable=g-long-lambda
                [],
                minval=1,
                maxval=10,
                dtype=dtypes.float32))
        opt = dataset_ops.Options()
        opt.experimental_allow_stateful = True
        dataset0 = dataset0.with_options(opt)
      replicated_ds = distribute.replicate(dataset0,
                                           [self._device1, self._device2])
      dataset1 = replicated_ds[self._device1]
      dataset2 = replicated_ds[self._device2]

      with ops.device(self._device0):
        get_next0 = self.getNext(dataset0)
      with ops.device(self._device1):
        get_next1 = self.getNext(dataset1)
      with ops.device(self._device2):
        get_next2 = self.getNext(dataset2)

      for _ in range(100):
        get_next0()
        get_next1()
        get_next2()
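
In brief, a sketch of the `distribute.replicate` pattern exercised by the tests above: it returns a mapping from device string to a replica dataset (the flush-left snippet and device names are illustrative only):

device1 = "/job:worker/replica:0/task:1/device:CPU:0"  # illustrative device names
device2 = "/job:worker/replica:0/task:2/device:CPU:0"

dataset0 = dataset_ops.Dataset.range(100)
replicas = distribute.replicate(dataset0, [device1, device2])
dataset1 = replicas[device1]  # Produces the same elements as dataset0 on device1.
dataset2 = replicas[device2]  # Produces the same elements as dataset0 on device2.
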

class MakeDeterministicTest(test_base.DatasetTestBase, parameterized.TestCase):
    def _set_seed(self):
        # Set the seed, since in graph mode some non-random dataset ops call
        # tf.compat.v1.get_seed to copy the seed to a Defun. When determinism is
        # enabled, calling get_seed raises an error if no seed is set.
        # TODO(reedwm): Ensure such dataset ops do not raise an error when no seed
        # is set.
        random_seed.set_random_seed(1)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(use_function=[False, True],
                                 use_legacy_interleave=[False, True])))
    def test_stateful_ops_interleave(self, use_function,
                                     use_legacy_interleave):
        with test_util.deterministic_ops():

            v = variables.Variable(0.)

            def map_fn(x):
                v.assign_add(1.)
                return (x, v.read_value())

            def interleave_fn(x):
                del x
                return dataset_ops.Dataset.range(2).map(map_fn)

            if use_function:
                map_fn = def_function.function(map_fn)
                interleave_fn = def_function.function(interleave_fn)

            dataset = dataset_ops.Dataset.range(5)
            if use_legacy_interleave:
                dataset = dataset.apply(
                    interleave_ops.parallel_interleave(interleave_fn,
                                                       cycle_length=5))
            else:
                dataset = dataset.interleave(interleave_fn,
                                             cycle_length=5,
                                             num_parallel_calls=3)
            self.evaluate(variables.global_variables_initializer())
            expected_output = list(zip([0] * 5 + [1] * 5, range(1, 11)))
            self.assertDatasetProduces(dataset,
                                       expected_output=expected_output,
                                       requires_initialization=True)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(use_function=[False, True])))
    def test_stateful_ops_map(self, use_function):
        with test_util.deterministic_ops():

            v = variables.Variable(0.)

            def map_fn(x):
                v.assign_add(1.)
                return (x, v.read_value())

            if use_function:
                map_fn = def_function.function(map_fn)

            dataset = dataset_ops.Dataset.range(5)
            dataset = dataset.map(map_fn, num_parallel_calls=5)
            self.evaluate(variables.global_variables_initializer())
            expected_output = list(zip(range(0, 5), range(1, 6)))
            self.assertDatasetProduces(dataset,
                                       expected_output=expected_output,
                                       requires_initialization=True)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(use_function=[False, True])))
    def test_stateful_ops_batch(self, use_function):
        with test_util.deterministic_ops():

            v = variables.Variable(0.)

            def map_fn(x):
                return (x, v.read_value())

            if use_function:
                map_fn = def_function.function(map_fn)

            dataset = dataset_ops.Dataset.range(5)
            dataset = dataset.map(map_fn)
            dataset = dataset.apply(testing.assert_next(["Batch"]))
            dataset = dataset.batch(2, num_parallel_calls=2)
            self.evaluate(variables.global_variables_initializer())
            expected_output = [
                (np.array([0, 1]), np.array([0, 0])),
                (np.array([2, 3]), np.array([0, 0])),
                (np.array([4]), np.array([0])),
            ]
            self.assertDatasetProduces(dataset,
                                       expected_output=expected_output,
                                       requires_initialization=True)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(use_function=[False, True],
                                 use_legacy_map_and_batch=[False, True])))
    def test_stateful_ops_map_and_batch(self, use_function,
                                        use_legacy_map_and_batch):
        with test_util.deterministic_ops():

            v = variables.Variable(0.)

            def map_fn(x):
                v.assign_add(1.)
                return (x, v.read_value())

            if use_function:
                map_fn = def_function.function(map_fn)

            dataset = dataset_ops.Dataset.range(5)
            if use_legacy_map_and_batch:
                dataset = dataset.apply(
                    batching.map_and_batch(map_fn, 2, num_parallel_calls=5))
            else:
                dataset = dataset.map(map_fn, num_parallel_calls=5)
                dataset = dataset.batch(2)
            self.evaluate(variables.global_variables_initializer())
            expected_output = [
                (np.array([0, 1]), np.array([1, 2])),
                (np.array([2, 3]), np.array([3, 4])),
                (np.array([4]), np.array([5])),
            ]
            self.assertDatasetProduces(dataset,
                                       expected_output=expected_output,
                                       requires_initialization=True)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(use_function=[False, True],
                                 use_legacy_interleave=[False, True])))
    def test_no_stateful_ops_interleave(self, use_function,
                                        use_legacy_interleave):
        self._set_seed()
        with test_util.deterministic_ops():

            def interleave_fn(x):
                del x
                return dataset_ops.Dataset.range(2)

            if use_function:
                interleave_fn = def_function.function(interleave_fn)

            dataset = dataset_ops.Dataset.range(5)
            if use_legacy_interleave:
                dataset = dataset.apply(
                    testing.assert_next(["LegacyParallelInterleaveV2"]))
                dataset = dataset.apply(
                    interleave_ops.parallel_interleave(interleave_fn,
                                                       cycle_length=5))
            else:
                dataset = dataset.apply(
                    testing.assert_next(["ParallelInterleave"]))
                dataset = dataset.interleave(interleave_fn,
                                             cycle_length=5,
                                             num_parallel_calls=3)
            self.evaluate(variables.global_variables_initializer())
            self.assertDatasetProduces(dataset,
                                       expected_output=[0] * 5 + [1] * 5)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(use_function=[False, True])))
    def test_no_stateful_ops_map(self, use_function):
        self._set_seed()
        with test_util.deterministic_ops():

            def map_fn(x):
                return x + 1

            if use_function:
                map_fn = def_function.function(map_fn)

            dataset = dataset_ops.Dataset.range(5)
            dataset = dataset.apply(testing.assert_next(["ParallelMap"]))
            dataset = dataset.map(map_fn, num_parallel_calls=5)
            self.evaluate(variables.global_variables_initializer())
            expected_output = range(1, 6)
            self.assertDatasetProduces(dataset,
                                       expected_output=expected_output)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(use_function=[False, True],
                                 use_control_flow=[False, True])))
    def test_text_line_dataset(self, use_function, use_control_flow):
        self._set_seed()
        with test_util.deterministic_ops():

            def write_nums_to_file(filename, numbers):
                path = os.path.join(self.get_temp_dir(), filename)
                with open(path, "w") as f:
                    f.write("\n".join(str(n) for n in numbers))
                return path

            f1 = write_nums_to_file("f1", (1, 2, 3))
            f2 = write_nums_to_file("f2", (4, 5, 6))
            f3 = write_nums_to_file("f3", (7, 8, 9))

            if use_control_flow:

                def interleave_fn(filename):
                    # Test function that uses control flow. The True branch is never taken.
                    concat = string_ops.string_join([filename, "abc"])
                    return control_flow_ops.cond(
                        math_ops.equal(filename, "abc"),
                        lambda: reader_ops.TextLineDataset(concat),
                        lambda: reader_ops.TextLineDataset(filename))
            else:

                def interleave_fn(filename):
                    return reader_ops.TextLineDataset(filename)

            if use_function:
                interleave_fn = def_function.function(interleave_fn)

            dataset = dataset_ops.Dataset.from_tensor_slices([f1, f2, f3])
            dataset = dataset.apply(
                testing.assert_next(["ParallelInterleave"]))
            dataset = dataset.interleave(interleave_fn,
                                         cycle_length=3,
                                         num_parallel_calls=3)

            self.assertDatasetProduces(
                dataset,
                expected_output=["1", "4", "7", "2", "5", "8", "3", "6", "9"])

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(local_determinism=[None, True, False],
                                 global_determinism=[True, False])))
    def test_deterministic_attribute(self, local_determinism,
                                     global_determinism):
        self._set_seed()
        with test_util.deterministic_ops():

            def sleep(x):
                time.sleep(0.1)
                return x

            def map_function(x):
                if math_ops.equal(x, 0):
                    return script_ops.py_func(sleep, [x],
                                              x.dtype,
                                              stateful=False)
                else:
                    return x

            dataset = dataset_ops.Dataset.range(100)
            dataset = dataset.map(map_function,
                                  num_parallel_calls=2,
                                  deterministic=local_determinism)
            opts = options_lib.Options()
            opts.deterministic = global_determinism
            dataset = dataset.with_options(opts)

            self.assertDatasetProduces(dataset, expected_output=range(100))

    @combinations.generate(test_base.default_test_combinations())
    def test_rewrite_prefetch(self):
        with test_util.deterministic_ops():
            v = variables.Variable(-1, dtype=dtypes.int64)

            def map_fn(x):
                v.assign(x)
                return x

            dataset = dataset_ops.Dataset.range(5)
            dataset = dataset.map(map_fn)
            dataset = dataset.prefetch(5)
            self.evaluate(variables.global_variables_initializer())
            get_next = self.getNext(dataset, requires_initialization=True)
            self.assertEqual(self.evaluate(v), -1)
            self.assertEqual(self.evaluate(get_next()), 0)
            time.sleep(0.01)
            self.assertEqual(self.evaluate(v), 0)
            self.assertEqual(self.evaluate(get_next()), 1)
            time.sleep(0.01)
            self.assertEqual(self.evaluate(v), 1)

    @combinations.generate(test_base.default_test_combinations())
    def test_no_determinism(self):
        config.disable_op_determinism()
        v = variables.Variable(0.)

        def interleave_fn(x):
            del x
            v.assign(1.)
            return dataset_ops.Dataset.range(2)

        dataset = dataset_ops.Dataset.range(5)
        dataset = dataset.apply(testing.assert_next(["ParallelInterleave"]))
        dataset = dataset.interleave(interleave_fn,
                                     cycle_length=5,
                                     num_parallel_calls=3)
        self.evaluate(variables.global_variables_initializer())
        expected_output = [0] * 5 + [1] * 5
        self.assertDatasetProduces(dataset,
                                   expected_output=expected_output,
                                   requires_initialization=True)
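
The tests above rely on tf.data's make_deterministic rewrite: when op determinism is enabled, stateful parallel transformations are forced to run serially, so their side effects line up with element order. Below is a minimal standalone sketch of that behavior, assuming a recent TF 2.x release that exposes tf.config.experimental.enable_op_determinism (the public counterpart of the test_util.deterministic_ops helper used in these tests):

import tensorflow as tf

tf.config.experimental.enable_op_determinism()

v = tf.Variable(0.)

def map_fn(x):
    # Stateful side effect; with determinism enabled the map runs in element order.
    v.assign_add(1.)
    return x, v.read_value()

ds = tf.data.Dataset.range(5).map(map_fn, num_parallel_calls=5)
# Expected on every run: [(0, 1.0), (1, 2.0), (2, 3.0), (3, 4.0), (4, 5.0)]
print(list(ds.as_numpy_iterator()))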
Example #9
def _test_combinations():
    return combinations.combine(tf_api_version=[1], mode=["graph"])
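
For reference, a rough sketch of what these combination helpers produce, assuming the import path used by the tf.data tests: combine() expands its keyword arguments into one parameter dict per element of the cross product, times() takes the cross product of several such lists, and @combinations.generate runs the test once per dict.

from tensorflow.python.framework import combinations

combos = combinations.combine(tf_api_version=[1, 2], mode=["eager", "graph"])
# Four parameter dicts, e.g. {'tf_api_version': 1, 'mode': 'eager'}, ...
for combo in combos:
    print(dict(combo))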
Example #10
class DataServiceOpsTest(data_service_test_base.TestBase,
                         parameterized.TestCase):
    @combinations.generate(test_base.eager_only_combinations())
    def testDispatcherStop(self):
        cluster = self.create_cluster(num_workers=1)
        num_elements = 100
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        iterator = iter(ds)
        results = []
        results.append(next(iterator).numpy())
        cluster.stop_dispatcher()
        # After the dispatcher dies, the worker should continue providing the rest
        # of the dataset's elements.
        for _ in range(num_elements - 1):
            results.append(next(iterator).numpy())
        self.assertEqual(results, list(range(num_elements)))

    @combinations.generate(test_base.eager_only_combinations())
    def testDispatcherRestartBeforeReading(self):
        cluster = self.create_cluster(num_workers=1)
        num_elements = 100
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        cluster.restart_dispatcher()

        self.assertDatasetProduces(ds, list(range(num_elements)))

    @combinations.generate(test_base.eager_only_combinations())
    def testDispatcherRestartDuringReading(self):
        cluster = self.create_cluster(num_workers=1)
        num_elements = 100
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        iterator = iter(ds)
        results = []
        for _ in range(num_elements // 2):
            results.append(next(iterator).numpy())
        cluster.restart_dispatcher()
        for elem in iterator:
            results.append(elem.numpy())

        self.assertEqual(list(range(num_elements)), results)

    @combinations.generate(test_base.eager_only_combinations())
    def testDispatcherRestartBetweenIterations(self):
        cluster = self.create_cluster(num_workers=1)
        num_elements = 100
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        self.assertDatasetProduces(ds, list(range(num_elements)))
        cluster.restart_dispatcher()
        self.assertDatasetProduces(ds, list(range(num_elements)))

    @combinations.generate(test_base.eager_only_combinations())
    def testDispatcherManyRestarts(self):
        cluster = self.create_cluster(num_workers=1)
        num_elements_start = 10
        num_elements_end = 15
        datasets = []
        for num_elements in range(num_elements_start, num_elements_end):
            datasets.append(
                self.make_distributed_range_dataset(num_elements, cluster))
            cluster.restart_dispatcher()
        for ds, num_elements in zip(
                datasets, range(num_elements_start, num_elements_end)):
            self.assertDatasetProduces(ds, list(range(num_elements)))

    @combinations.generate(test_base.eager_only_combinations())
    def testDispatcherAndWorkerRestart(self):
        cluster = self.create_cluster(num_workers=1)
        num_elements = 100
        ds = self.make_distributed_range_dataset(num_elements, cluster)

        cluster.restart_dispatcher()
        cluster.restart_worker()
        self.assertDatasetProduces(ds, list(range(num_elements)))
        cluster.restart_dispatcher()
        cluster.restart_worker()
        self.assertDatasetProduces(ds, list(range(num_elements)))

    @combinations.generate(test_base.eager_only_combinations())
    def testDispatcherAndMultiWorkerRestart(self):
        num_workers = 2
        cluster = self.create_cluster(num_workers=num_workers)
        num_elements = 100
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        iterator = iter(ds)
        results = []

        cluster.restart_dispatcher()
        for worker_index in range(num_workers):
            cluster.restart_worker(worker_index=worker_index)
        for elem in iterator:
            results.append(elem.numpy())
        self.assertCountEqual(num_workers * list(range(num_elements)), results)
        cluster.restart_dispatcher()
        for worker_index in range(num_workers):
            cluster.restart_worker(worker_index=worker_index)
        for elem in iterator:
            results.append(elem.numpy())
        self.assertCountEqual(num_workers * list(range(num_elements)), results)

    @combinations.generate(test_base.eager_only_combinations())
    def testStartServersLate(self):
        # Test that the data service client retries instead of failing when the
        # dataset is created before the dispatcher and worker are started.
        try:
            import portpicker  # pylint: disable=g-import-not-at-top
            dispatcher_port = portpicker.pick_unused_port()
        except:
            raise self.skipTest(
                "Flakes in portpicker library do not represent "
                "TensorFlow errors.")
        cluster = self.create_cluster(num_workers=1,
                                      dispatcher_port=dispatcher_port,
                                      start=False)

        def start_servers():
            time.sleep(0.5)
            cluster.start_dispatcher()
            cluster.start_workers()

        start_servers_thread = threading.Thread(target=start_servers,
                                                daemon=True)
        start_servers_thread.start()

        num_elements = 10
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        results = [elem.numpy() for elem in ds]
        self.assertEqual(list(range(num_elements)), results)
        start_servers_thread.join()

    @combinations.generate(test_base.eager_only_combinations())
    def testAddWorkerMidJob(self):
        cluster = self.create_cluster(num_workers=1)
        num_elements = 100
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        iterator = iter(ds)
        results = []
        # Read halfway through the dataset.
        for _ in range(num_elements // 2):
            results.append(next(iterator).numpy())

        cluster.add_worker()
        # Wait for the new worker to register with the dispatcher.
        while cluster.num_registered_workers() < 2:
            time.sleep(10 / 1000)  # 10ms

        for elem in iterator:
            results.append(elem.numpy())

        self.assertCountEqual(2 * list(range(num_elements)), results)

    @combinations.generate(
        combinations.times(test_base.eager_only_combinations(),
                           combinations.combine(use_same_port=[True, False]),
                           data_service_test_base.all_cluster_configurations())
    )
    def testRestartWorker(self, use_same_port, work_dir, fault_tolerant_mode):
        cluster = self.create_cluster(num_workers=1,
                                      work_dir=work_dir,
                                      fault_tolerant_mode=fault_tolerant_mode)
        num_elements = 100
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        iterator = iter(ds)
        # Read halfway through the dataset.
        midpoint = num_elements // 2
        for i in range(midpoint):
            self.assertEqual(i, next(iterator).numpy())

        # Stop the original worker and start a new one.
        cluster.restart_worker(use_same_port=use_same_port)

        # There may have been some elements prefetched from the first worker
        # before it was stopped.
        while True:
            val = next(iterator).numpy()
            if val == 0:
                break

        # The dataset starts over now that we read from the new worker.
        # TODO(b/157086991): Iterate until end of sequence when we support
        # detecting lost workers.
        for i in range(1, num_elements // 2):
            val = next(iterator).numpy()
            self.assertEqual(i, val)

    @combinations.generate(test_base.eager_only_combinations())
    def testChangeProcessingModeAfterRestart(self):
        cluster = self.create_cluster(num_workers=1)
        num_elements = 100
        range_dataset = dataset_ops.Dataset.range(num_elements)
        ds = range_dataset.apply(
            data_service_ops.distribute(processing_mode="parallel_epochs",
                                        service=cluster.target,
                                        job_name="test"))
        iterator = iter(ds)
        for i in range(num_elements // 2):
            self.assertEqual(i, next(iterator).numpy())
        cluster.restart_dispatcher()
        ds = range_dataset.apply(
            data_service_ops.distribute(processing_mode="distributed_epoch",
                                        service=cluster.target,
                                        job_name="test"))
        with self.assertRaisesOpError(
                "already an existing job with that name "
                "using processing mode <parallel_epochs>"):
            next(iter(ds)).numpy()

    @combinations.generate(
        combinations.times(
            test_base.eager_only_combinations(),
            combinations.combine(work_dir=[TMP_WORK_DIR, NO_WORK_DIR])))
    def testDistributeLargeGraphThenRegisterWorker(self, work_dir):
        cluster = self.create_cluster(num_workers=0,
                                      work_dir=work_dir,
                                      fault_tolerant_mode=False)
        # Larger than default OSS grpc message size limit of 4MB.
        tensor = array_ops.ones((2, 1000, 1000), dtype=dtypes.float32)
        ds = dataset_ops.Dataset.from_tensors(tensor)
        ds = self.make_distributed_dataset(ds, cluster)
        it = iter(ds)
        cluster.add_worker()
        self.assertAllEqual(next(it), tensor)
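
The helpers used above (create_cluster, make_distributed_range_dataset, restart_dispatcher, and so on) come from the data_service_test_base harness. A minimal standalone sketch of roughly what they wrap, using only the public tf.data service API and assuming a recent TF 2.x:

import tensorflow as tf

# One in-process dispatcher and one worker pointed at it.
dispatcher = tf.data.experimental.service.DispatchServer()
dispatcher_address = dispatcher.target.split("://")[1]
worker = tf.data.experimental.service.WorkerServer(
    tf.data.experimental.service.WorkerConfig(
        dispatcher_address=dispatcher_address))

# Distribute a small range dataset through the service and read it back.
ds = tf.data.Dataset.range(10)
ds = ds.apply(
    tf.data.experimental.service.distribute(
        processing_mode="parallel_epochs", service=dispatcher.target))
print([int(x) for x in ds])  # [0, 1, ..., 9], served by the single worker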
Example #11

class CollectiveOpsV1(object):
    all_reduce = _collective_ops.all_reduce
    all_gather = _collective_ops.all_gather


class CollectiveOpsV2(object):
    all_reduce = _collective_ops.all_reduce_v2
    all_gather = _collective_ops.all_gather_v2


@combinations.generate(
    combinations.combine(
        collective_ops=[
            combinations.NamedObject('v1', CollectiveOpsV1),
            combinations.NamedObject('v2', CollectiveOpsV2)
        ],
        mode='eager'))
class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
    def setUp(self):
        _setup_context()
        super().setUp()

    def testReduce(self, collective_ops):
        @def_function.function
        def run_all_reduce_1cpu():
            with ops.device('/device:CPU:0'):
                in_value = constant_op.constant([1.])
                group_size = 1
                group_key = 1
                instance_key = 1
Example #12
class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(drop_remainder=[True, False])))
    def testBasic(self, drop_remainder):
        dataset = dataset_ops.Dataset.range(1024).batch(
            32, drop_remainder=drop_remainder)
        rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
        self.assertEqual(
            [[8] if drop_remainder else [None]],
            [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])

        expected_output = [[k for k in range(i, i + 8)]
                           for i in range(0, 1024, 8)]  # pylint: disable=g-complex-comprehension
        self.assertDatasetProduces(rebatched_dataset, expected_output)

    @combinations.generate(test_base.default_test_combinations())
    def testCanHandleUnknownRank(self):
        dataset = dataset_ops.Dataset.from_tensors("xxx")
        # decode_image results in a tensor of completely unknown shape (i.e. unknown
        # rank)
        dataset = dataset.map(image_ops.decode_image)
        self.assertEqual([tensor_shape.TensorShape(None)],
                         _flat_shapes(dataset))
        rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
        # Note that we are just testing the dataset shapes, not the actual output.
        self.assertEqual([tensor_shape.TensorShape(None)],
                         _flat_shapes(rebatched_dataset))

    @combinations.generate(test_base.default_test_combinations())
    def testCanHandleUnknownDims(self):
        dataset = dataset_ops.Dataset.range(1000)
        dataset = dataset.batch(10, drop_remainder=False)
        dataset = dataset.batch(10, drop_remainder=False)
        self.assertEqual([[None, None]],
                         [ts.as_list() for ts in _flat_shapes(dataset)])
        rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
        # Note that we are just testing the dataset shapes, not the actual output.
        self.assertEqual(
            [[None, None]],
            [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])

    @combinations.generate(test_base.default_test_combinations())
    def testScalarInputError(self):
        dataset = dataset_ops.Dataset.range(1024)
        distribute._RebatchDataset(dataset.batch(4), num_replicas=4)
        with self.assertRaisesRegex(ValueError, ("You can fix the issue "
                                                 "by adding the `batch`")):
            distribute._RebatchDataset(dataset, num_replicas=4)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(drop_remainder=[True, False])))
    def testBatchNotDivisibleByNumReplicas(self, drop_remainder):
        dataset = dataset_ops.Dataset.range(1024).batch(
            32, drop_remainder=drop_remainder)
        rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5)
        self.assertEqual(
            [[None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
        expected_output = []
        i = 0
        for _ in range(32):  # number of steps
            # first four minibatches have seven elements
            for _ in range(4):
                expected_output.append([k for k in range(i, i + 7)])
                i += 7
            # last minibatch has four elements
            expected_output.append([k for k in range(i, i + 4)])
            i += 4
        self.assertDatasetProduces(rebatched_dataset, expected_output)

    @combinations.generate(test_base.default_test_combinations())
    def testBatchSizeNotDivisibleByNumReplicas2(self):
        dataset = dataset_ops.Dataset.range(32).batch(16, drop_remainder=True)
        rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5)
        # This will rebatch into sub-batches of size 4, since
        # ceil(16 / 5) = 4. However, that means only the first 4 replicas will get
        # data.
        expected_output = [[k for k in range(i, i + 4)]
                           for i in range(0, 16, 4)]
        expected_output.extend([[]])  # Last replica gets an empty batch
        expected_output.extend([[k for k in range(i, i + 4)]
                                for i in range(16, 32, 4)])
        expected_output.extend([[]])  # Last replica gets an empty batch
        self.assertDatasetProduces(rebatched_dataset, expected_output)

    @combinations.generate(test_base.default_test_combinations())
    def testTupleOutput(self):
        dataset = dataset_ops.Dataset.range(1024).map(lambda x: (x, x)).batch(
            32)
        rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
        expected_output = [
            (
                [k for k in range(i, i + 8)],  # pylint: disable=g-complex-comprehension
                [k for k in range(i, i + 8)]) for i in range(0, 1024, 8)
        ]
        self.assertDatasetProduces(rebatched_dataset, expected_output)

    @combinations.generate(test_base.default_test_combinations())
    def testNestedDictionaryOutput(self):
        dataset = dataset_ops.Dataset.range(1024).map(lambda x: {
            "a": x,
            "b": {
                "c": x
            }
        }).batch(32)
        rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
        expected_output = [
            {
                "a": [k for k in range(i, i + 8)],  # pylint: disable=g-complex-comprehension
                "b": {
                    "c": [k for k in range(i, i + 8)]
                }
            } for i in range(0, 1024, 8)
        ]
        self.assertDatasetProduces(rebatched_dataset, expected_output)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(drop_remainder=[True, False])))
    def testFinalPartialBatch(self, drop_remainder):
        dataset = dataset_ops.Dataset.range(1032).batch(
            32, drop_remainder=drop_remainder)
        rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
        self.assertEqual(
            [[8] if drop_remainder else [None]],
            [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])

        # if drop_remainder, the final partial batch is dropped, even though it
        # makes up a complete minibatch.
        expected_output = [[k for k in range(i, i + 8)]
                           for i in range(0, 1024, 8)]  # pylint: disable=g-complex-comprehension
        if not drop_remainder:
            # The last partial batch of size 8 is split over 4 replicas
            expected_output.extend([[k for k in range(i, i + 2)]
                                    for i in range(1024, 1032, 2)])
        self.assertDatasetProduces(rebatched_dataset, expected_output)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(drop_remainder=[True, False])))
    def testFinalPartialBatchAfterRebatch(self, drop_remainder):
        dataset = dataset_ops.Dataset.range(34).batch(
            32, drop_remainder=drop_remainder)
        rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
        self.assertEqual(
            [[8] if drop_remainder else [None]],
            [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])

        expected_output = [[k for k in range(i, i + 8)]
                           for i in range(0, 32, 8)]  # pylint: disable=g-complex-comprehension
        if not drop_remainder:
            # The last partial batch of size 2 is split over 4 replicas
            expected_output += [[32], [33], [], []]
        self.assertDatasetProduces(rebatched_dataset, expected_output)

    @combinations.generate(test_base.default_test_combinations())
    def testMultipleBatches(self):
        dataset = dataset_ops.Dataset.range(128).batch(4).batch(8)
        self.assertEqual([[None, None]],
                         [ts.as_list() for ts in _flat_shapes(dataset)])

        # Each element is a list of 8 elements where each element is a list of 4.
        expected_output = [
            [
                [j, j + 1, j + 2, j + 3]  # pylint: disable=g-complex-comprehension
                for j in range(i, i + 32, 4)
            ]  # generates 8 elements
            for i in range(0, 128, 32)
        ]
        self.assertDatasetProduces(dataset, expected_output)

        rebatched_dataset = distribute._RebatchDataset(dataset, 4)
        self.assertEqual(
            [[None, None]],
            [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
        # Each element is a list of 2 elements where each element is a list of 4.
        expected_output = [
            [
                [j, j + 1, j + 2, j + 3]  # pylint: disable=g-complex-comprehension
                for j in range(i, i + 8, 4)
            ]  # generates 2 elements
            for i in range(0, 128, 8)
        ]
        self.assertDatasetProduces(rebatched_dataset, expected_output)

    @combinations.generate(test_base.default_test_combinations())
    def testRaggedTensorDataset(self):
        # Set up a dataset that produces ragged tensors with a static batch size.
        row_lengths = np.random.randint(8, size=128)
        values = np.random.normal(size=np.sum(row_lengths)).astype(np.float32)
        dataset = dataset_ops.Dataset.from_tensor_slices(
            ragged_tensor.RaggedTensor.from_row_lengths(values, row_lengths))
        dataset = dataset.batch(32, drop_remainder=True)

        # The map changes the internal representation of the ragged tensor.
        # This test will fail if we don't normalize the tensor representation.
        dataset = dataset.map(lambda x: x)

        dataset = distribute._RebatchDataset(dataset, num_replicas=8)
        # After rebatching, batch size is now 4.
        expected_output = []
        value_index = 0
        for batch_row_lengths in row_lengths.reshape((-1, 4)):
            num_values = np.sum(batch_row_lengths)
            expected_output.append(
                ragged_tensor.RaggedTensor.from_row_lengths(
                    values[value_index:(value_index + num_values)],
                    batch_row_lengths))
            value_index += num_values
        self.assertDatasetProduces(dataset, expected_output)
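
The expected outputs in these rebatching tests all follow the same per-replica arithmetic: a global batch of size B is split into num_replicas minibatches of size ceil(B / num_replicas), with smaller or empty minibatches at the tail when B is not evenly divisible. A small illustrative helper (plain Python, not part of the tests) that reproduces the sizes asserted above:

import math

def rebatch_sizes(batch_size, num_replicas):
    # Per-replica minibatch size used when splitting one global batch.
    minibatch = math.ceil(batch_size / num_replicas)
    sizes = []
    remaining = batch_size
    for _ in range(num_replicas):
        sizes.append(min(minibatch, max(remaining, 0)))
        remaining -= minibatch
    return sizes

print(rebatch_sizes(32, 4))  # [8, 8, 8, 8]
print(rebatch_sizes(32, 5))  # [7, 7, 7, 7, 4]
print(rebatch_sizes(16, 5))  # [4, 4, 4, 4, 0] -- last replica gets an empty batch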
Example #13
class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):
    @combinations.generate(test_base.default_test_combinations())
    def testBasic(self):
        components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
                      np.array([9.0, 10.0, 11.0, 12.0]))

        def dataset_fn(count=5, buffer_size=None, seed=0):
            repeat_dataset = (dataset_ops.Dataset.from_tensor_slices(
                components).repeat(count))
            if buffer_size:
                shuffle_dataset = repeat_dataset.shuffle(buffer_size, seed)

                self.assertEqual(
                    tuple([c.shape[1:] for c in components]),
                    dataset_ops.get_legacy_output_shapes(shuffle_dataset))
                return shuffle_dataset
            else:
                return repeat_dataset

        # First run without shuffling to collect the "ground truth".
        get_next = self.getNext(dataset_fn())
        unshuffled_elements = []
        for _ in range(20):
            unshuffled_elements.append(self.evaluate(get_next()))
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())

        # Assert that the shuffled dataset has the same elements as the
        # "ground truth".
        get_next = self.getNext(dataset_fn(buffer_size=100, seed=37))
        shuffled_elements = []
        for _ in range(20):
            shuffled_elements.append(self.evaluate(get_next()))
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())
        self.assertAllEqual(sorted(unshuffled_elements),
                            sorted(shuffled_elements))

        # Assert that shuffling twice with the same seeds gives the same sequence.
        get_next = self.getNext(dataset_fn(buffer_size=100, seed=37))
        reshuffled_elements_same_seed = []
        for _ in range(20):
            reshuffled_elements_same_seed.append(self.evaluate(get_next()))
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())
        self.assertEqual(shuffled_elements, reshuffled_elements_same_seed)

        # Assert that shuffling twice with a different seed gives a different
        # permutation of the same elements.
        get_next = self.getNext(dataset_fn(buffer_size=100, seed=137))
        reshuffled_elements_different_seed = []
        for _ in range(20):
            reshuffled_elements_different_seed.append(self.evaluate(
                get_next()))
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())
        self.assertNotEqual(shuffled_elements,
                            reshuffled_elements_different_seed)
        self.assertAllEqual(sorted(shuffled_elements),
                            sorted(reshuffled_elements_different_seed))

        # Assert that the shuffled dataset has the same elements as the
        # "ground truth" when the buffer size is smaller than the input
        # dataset.
        get_next = self.getNext(dataset_fn(buffer_size=2, seed=37))
        reshuffled_elements_small_buffer = []
        for _ in range(20):
            reshuffled_elements_small_buffer.append(self.evaluate(get_next()))
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())
        self.assertAllEqual(sorted(unshuffled_elements),
                            sorted(reshuffled_elements_small_buffer))

        # Test the case of shuffling an empty dataset.
        get_next = self.getNext(dataset_fn(count=0, buffer_size=100, seed=37))

        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())

    @combinations.generate(combinations.combine(tf_api_version=1,
                                                mode="graph"))
    def testSeedZero(self):
        """Test for same behavior when the seed is a Python or Tensor zero."""
        iterator = dataset_ops.make_one_shot_iterator(
            dataset_ops.Dataset.range(10).shuffle(10, seed=0))
        get_next = iterator.get_next()

        elems = []
        with self.cached_session() as sess:
            for _ in range(10):
                elems.append(sess.run(get_next))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(get_next)

        seed_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
        iterator = dataset_ops.make_initializable_iterator(
            dataset_ops.Dataset.range(10).shuffle(10, seed=seed_placeholder))
        get_next = iterator.get_next()

        with self.cached_session() as sess:
            sess.run(iterator.initializer, feed_dict={seed_placeholder: 0})
            for elem in elems:
                self.assertEqual(elem, sess.run(get_next))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(get_next)

    @combinations.generate(test_base.default_test_combinations())
    def testDefaultArguments(self):
        components = [0, 1, 2, 3, 4]
        dataset = dataset_ops.Dataset.from_tensor_slices(components).shuffle(
            5).repeat()
        get_next = self.getNext(dataset)
        counts = collections.defaultdict(lambda: 0)
        for _ in range(10):
            for _ in range(5):
                counts[self.evaluate(get_next())] += 1

        for i in range(5):
            self.assertEqual(10, counts[i])

    @combinations.generate(
        combinations.times(
            test_base.graph_only_combinations(),
            combinations.combine(reshuffle=[True, False]),
            combinations.combine(graph_seed=38, op_seed=None) +
            combinations.combine(graph_seed=None, op_seed=42) +
            combinations.combine(graph_seed=38, op_seed=42)))
    def testShuffleSeed(self, reshuffle, graph_seed, op_seed):
        results = []
        for _ in range(2):
            with ops.Graph().as_default() as g:
                random_seed.set_random_seed(graph_seed)
                dataset = dataset_ops.Dataset.range(10).shuffle(
                    10, seed=op_seed,
                    reshuffle_each_iteration=reshuffle).repeat(3)
                iterator = dataset_ops.make_one_shot_iterator(dataset)
                next_element = iterator.get_next()

                run_results = []
                with self.session(graph=g) as sess:
                    for _ in range(30):
                        run_results.append(sess.run(next_element))
                    with self.assertRaises(errors.OutOfRangeError):
                        sess.run(next_element)
                results.append(run_results)

        self.assertAllEqual(results[0], results[1])

    # TODO(b/117581999): enable this test for eager-mode.
    @combinations.generate(
        combinations.times(
            test_base.graph_only_combinations(),
            combinations.combine(reshuffle=[True, False],
                                 initializable=[True, False])))
    def testMultipleIterators(self, reshuffle, initializable):
        with ops.Graph().as_default() as g:
            dataset = dataset_ops.Dataset.range(100).shuffle(
                10, reshuffle_each_iteration=reshuffle).repeat(3)

            if initializable:
                iterators = [
                    dataset_ops.make_initializable_iterator(dataset)
                    for _ in range(2)
                ]
            else:
                iterators = [
                    dataset_ops.make_one_shot_iterator(dataset)
                    for _ in range(2)
                ]

            results = []
            with self.session(graph=g) as sess:
                for iterator in iterators:
                    if initializable:
                        sess.run(iterator.initializer)
                    next_element = iterator.get_next()
                    run_results = []
                    for _ in range(300):
                        run_results.append(sess.run(next_element))
                    with self.assertRaises(errors.OutOfRangeError):
                        sess.run(next_element)

                    results.append(run_results)

                self.assertNotEqual(results[0], results[1])

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(reshuffle=[True, False], seed=[None, 42])))
    def testReshuffleRepeatEpochs(self, reshuffle, seed):
        dataset = dataset_ops.Dataset.range(10).shuffle(
            10, seed=seed, reshuffle_each_iteration=reshuffle).repeat(2)
        next_element = self.getNext(dataset)

        first_epoch = []
        for _ in range(10):
            first_epoch.append(self.evaluate(next_element()))

        second_epoch = []
        for _ in range(10):
            second_epoch.append(self.evaluate(next_element()))

        self.assertEqual(first_epoch == second_epoch, not reshuffle)

    @combinations.generate(
        combinations.times(
            combinations.combine(tf_api_version=2, mode="eager"),
            combinations.combine(reshuffle=[True, False], seed=[None, 42])))
    def testReshuffleIterationEpochs(self, reshuffle, seed):
        dataset = dataset_ops.Dataset.range(10).shuffle(
            10, seed=seed, reshuffle_each_iteration=reshuffle)

        first_epoch = []
        for elem in dataset:
            first_epoch.append(elem.numpy())

        second_epoch = []
        for elem in dataset:
            second_epoch.append(elem.numpy())

        self.assertEqual(first_epoch == second_epoch, not reshuffle)

    @combinations.generate(combinations.combine(tf_api_version=2,
                                                mode="eager"))
    def testShuffleV2ResourceCapture(self):
        def make_dataset():
            ids = dataset_ops.Dataset.range(10)
            ids = ids.shuffle(1)

            def interleave_fn(dataset, _):
                return dataset

            dataset = dataset_ops.Dataset.range(1)
            dataset = dataset.interleave(functools.partial(interleave_fn, ids))
            return dataset

        results = []
        for elem in make_dataset():
            results.append(elem.numpy())

        self.assertAllEqual(results, range(10))

    @combinations.generate(
        combinations.times(
            test_base.eager_only_combinations(),
            combinations.combine(reshuffle=[True, False], seed=[None, 42])))
    def testReshuffleSeparateTransformations(self, reshuffle, seed):
        dataset = dataset_ops.Dataset.range(10)

        first_epoch = []
        for elem in dataset.shuffle(10,
                                    seed=seed,
                                    reshuffle_each_iteration=reshuffle):
            first_epoch.append(elem.numpy())

        second_epoch = []
        for elem in dataset.shuffle(10,
                                    seed=seed,
                                    reshuffle_each_iteration=reshuffle):
            second_epoch.append(elem.numpy())

        self.assertEqual(first_epoch != second_epoch, seed is None)

    @combinations.generate(combinations.combine(tf_api_version=2,
                                                mode="eager"))
    def testShuffleV2InFunction(self):
        counter_var = variables.Variable(0)

        @function.defun
        def consume():
            ds = dataset_ops.Dataset.range(10)
            ds = ds.shuffle(1)
            for _ in ds:
                counter_var.assign(counter_var + 1)

        consume()
        self.assertAllEqual(self.evaluate(counter_var), 10)

    @combinations.generate(test_base.default_test_combinations())
    def testEmptyDataset(self):
        dataset = dataset_ops.Dataset.from_tensors(1)

        def map_fn(x):
            with ops.control_dependencies([check_ops.assert_equal(x, 0)]):
                return x

        dataset = dataset.map(map_fn)
        dataset = dataset.cache()
        dataset = dataset.shuffle(buffer_size=10).repeat()

        get_next = self.getNext(dataset)

        # First time around, we get an error for the failed assertion.
        with self.assertRaises(errors.InvalidArgumentError):
            self.evaluate(get_next())

        # Second time around, we get an EOF because the cached dataset is empty.
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(reshuffle=[True, False])))
    def testRerandomizeOnReplicate(self, reshuffle):
        random_seed.set_random_seed(None)
        # When no seeds are fixed, each instantiation of the shuffle dataset should
        # produce elements in a different order.
        num_elements = 100
        dataset = dataset_ops.Dataset.range(num_elements)
        dataset = dataset.shuffle(num_elements,
                                  reshuffle_each_iteration=reshuffle)

        shuffle_1 = self.getDatasetOutput(dataset)
        dataset = self.graphRoundTrip(dataset, allow_stateful=True)
        shuffle_2 = self.getDatasetOutput(dataset)

        self.assertCountEqual(shuffle_1, shuffle_2)
        self.assertNotEqual(shuffle_1, shuffle_2)
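
A minimal eager-mode sketch (assuming TF 2.x) of the seed semantics these shuffle tests pin down: with a fixed seed and reshuffle_each_iteration=False the order repeats across epochs, while reshuffle_each_iteration=True reshuffles each epoch but still yields the same multiset of elements:

import tensorflow as tf

ds = tf.data.Dataset.range(10).shuffle(
    10, seed=42, reshuffle_each_iteration=False)
epoch1 = list(ds.as_numpy_iterator())
epoch2 = list(ds.as_numpy_iterator())
assert epoch1 == epoch2  # same order every epoch

ds = tf.data.Dataset.range(10).shuffle(
    10, seed=42, reshuffle_each_iteration=True)
epoch1 = list(ds.as_numpy_iterator())
epoch2 = list(ds.as_numpy_iterator())
assert sorted(epoch1) == sorted(epoch2)  # same elements, order very likely differs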
Example #14
class DataServiceOpsTest(data_service_test_base.TestBase,
                         parameterized.TestCase):
    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           data_service_test_base.all_cluster_configurations())
    )
    def testDistributeBasic(self, work_dir, fault_tolerant_mode):
        cluster = data_service_test_base.TestCluster(
            num_workers=1,
            work_dir=work_dir,
            fault_tolerant_mode=fault_tolerant_mode)
        num_elements = 10
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        self.assertDatasetProduces(ds, list(range(num_elements)))

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(compression=[None, "AUTO"])))
    def testDistributeCompression(self, compression):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 10
        ds = self.make_distributed_range_dataset(num_elements,
                                                 cluster,
                                                 compression=compression)
        self.assertDatasetProduces(ds, list(range(num_elements)))

    @combinations.generate(test_base.default_test_combinations())
    def testDistributeInvalidCompression(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        with self.assertRaisesRegex(ValueError,
                                    "Invalid compression argument"):
            self.make_distributed_range_dataset(10, cluster, compression="foo")

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeSparse(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        element = sparse_tensor.SparseTensor(indices=[[0]],
                                             values=constant_op.constant(
                                                 [0], dtype=dtypes.int32),
                                             dense_shape=[1])
        ds = dataset_ops.Dataset.from_tensors(element)
        ds = self.make_distributed_dataset(ds, cluster)
        results = [sparse_ops.sparse_tensor_to_dense(elem) for elem in ds]
        self.assertAllEqual(results, [[0]])

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeRagged(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        ds = dataset_ops.Dataset.from_tensor_slices([1, 5, 3, 2, 8])
        ds = ds.map(math_ops.range)
        ds = ds.apply(batching.dense_to_ragged_batch(2))
        ds = self.make_distributed_dataset(ds, cluster)
        results = [elem.to_tensor() for elem in ds]
        self.assertAllEqual(results[0], [[0, 0, 0, 0, 0], [0, 1, 2, 3, 4]])
        self.assertAllEqual(results[1], [[0, 1, 2], [0, 1, 0]])
        self.assertAllEqual(results[2], [[0, 1, 2, 3, 4, 5, 6, 7]])

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(
                init_source=["textfile", "keyvaluetensor", "dataset"])))
    def testDistributeLookupTable(self, init_source):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        initializer = self.lookupTableInitializer(init_source, [10, 11])
        table = lookup_ops.StaticHashTable(initializer, -1)
        ds = dataset_ops.Dataset.range(3)
        ds = ds.map(table.lookup)
        ds = self.make_distributed_dataset(ds, cluster)
        self.evaluate(lookup_ops.tables_initializer())
        self.assertDatasetProduces(ds, [10, 11, -1],
                                   requires_initialization=True)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(value_rank=[0, 1])))
    def testDistributeMutableHashTable(self, value_rank):
        def value(v):
            for _ in range(value_rank):
                v = [v, v]
            return v

        v1 = value(10)
        v2 = value(11)
        default_value = value(-1)

        cluster = data_service_test_base.TestCluster(num_workers=1)
        table = lookup_ops.MutableHashTable(dtypes.int64, dtypes.int64,
                                            default_value)
        self.evaluate(table.insert([0, 1], [v1, v2]))
        ds = dataset_ops.Dataset.range(3)
        ds = ds.map(table.lookup)
        ds = self.make_distributed_dataset(ds, cluster)
        self.assertDatasetProduces(ds, [v1, v2, default_value],
                                   requires_initialization=True)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(shuffle_seed=[None, 10])))
    def testShuffleOrder(self, shuffle_seed):
        random_seed.set_random_seed(None)
        num_elements = 100
        cluster = data_service_test_base.TestCluster(num_workers=2)
        ds = dataset_ops.Dataset.range(num_elements)
        ds = ds.shuffle(num_elements, seed=shuffle_seed)
        ds = self.make_distributed_dataset(ds, cluster)
        output = self.getDatasetOutput(ds)

        # The output will be two sequences of range(num_elements)
        # non-deterministically interleaved together. If the orders of the elements
        # were the same, first_order and second_order computed below will be equal.
        first_order = {}
        second_order = {}
        for element in output:
            if element in first_order:
                second_order[element] = len(second_order)
            else:
                first_order[element] = len(first_order)
        if shuffle_seed is None:
            self.assertNotEqual(first_order, second_order)
        else:
            self.assertEqual(first_order, second_order)

    @combinations.generate(test_base.default_test_combinations())
    def testMultipleEpochs(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 3
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        for _ in range(10):
            self.assertDatasetProduces(ds, list(range(num_elements)))

    @combinations.generate(test_base.default_test_combinations())
    def testRepeatedDataset(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 10
        num_repetitions = 5
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        ds = ds.repeat(num_repetitions)
        self.assertDatasetProduces(ds,
                                   expected_output=num_repetitions *
                                   list(range(num_elements)))

    @combinations.generate(test_base.default_test_combinations())
    def testConcurrentEpoch(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 10
        num_datasets = 3
        get_nexts = []
        results = []
        for _ in range(num_datasets):
            ds = self.make_distributed_range_dataset(num_elements, cluster)
            get_nexts.append(self.getNext(ds))
            results.append([])

        for _ in range(num_elements):
            for dataset_ind in range(num_datasets):
                result = self.evaluate(get_nexts[dataset_ind]())
                results[dataset_ind].append(result)
        for result in results:
            self.assertEqual(list(range(num_elements)), result)

    @combinations.generate(test_base.default_test_combinations())
    def testMultiWorker(self):
        num_workers = 3
        cluster = data_service_test_base.TestCluster(num_workers=num_workers)
        num_elements = 10
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        self.assertDatasetProduces(ds,
                                   num_workers * list(range(num_elements)),
                                   assert_items_equal=True)

    @combinations.generate(test_base.default_test_combinations())
    def testMaxOutstandingRequests(self):
        num_workers = 3
        cluster = data_service_test_base.TestCluster(num_workers=num_workers)
        num_elements = 10
        ds = self.make_distributed_range_dataset(num_elements,
                                                 cluster,
                                                 max_outstanding_requests=1)
        self.assertDatasetProduces(ds,
                                   num_workers * list(range(num_elements)),
                                   assert_items_equal=True)

    @combinations.generate(test_base.eager_only_combinations())
    def testInsideFunction(self):
        num_workers = 3
        cluster = data_service_test_base.TestCluster(num_workers=num_workers)
        num_elements = 10

        @def_function.function
        def f():
            ds = self.make_distributed_range_dataset(num_elements, cluster)
            result = tensor_array_ops.TensorArray(dtypes.int64,
                                                  size=num_workers *
                                                  num_elements,
                                                  dynamic_size=True)
            i = 0
            for elem in ds:
                result = result.write(i, elem)
                i += 1
            return result.stack()

        result = list(f().numpy())
        self.assertCountEqual(num_workers * list(range(num_elements)), result)

    @combinations.generate(test_base.default_test_combinations())
    def testEmptyJobNameDistribute(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        with self.assertRaisesRegex(ValueError, "job_name must not be empty"):
            dataset_ops.Dataset.range(10).apply(
                data_service_ops.distribute(processing_mode="parallel_epochs",
                                            service=cluster.dispatcher.target,
                                            job_name=""))

    @combinations.generate(test_base.default_test_combinations())
    def testEmptyJobNameFromDatasetId(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        dataset_id = data_service_ops.register_dataset(
            cluster.dispatcher.target, dataset_ops.Dataset.range(10))
        with self.assertRaisesRegex(ValueError, "job_name must not be empty"):
            data_service_ops.from_dataset_id(dataset_id=dataset_id,
                                             processing_mode="parallel_epochs",
                                             service=cluster.dispatcher.target,
                                             job_name="")

    @combinations.generate(test_base.default_test_combinations())
    def testNonStringJobNameDistribute(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        with self.assertRaisesRegex(ValueError, "job_name must be a string"):
            dataset_ops.Dataset.range(10).apply(
                data_service_ops.distribute(
                    processing_mode="parallel_epochs",
                    service=cluster.dispatcher.target,
                    job_name=constant_op.constant("foo")))

    @combinations.generate(test_base.default_test_combinations())
    def testNonStringJobNameFromDatasetId(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        dataset_id = data_service_ops.register_dataset(
            cluster.dispatcher.target, dataset_ops.Dataset.range(10))
        with self.assertRaisesRegex(ValueError, "job_name must be a string"):
            data_service_ops.from_dataset_id(
                dataset_id=dataset_id,
                processing_mode="parallel_epochs",
                service=cluster.dispatcher.target,
                job_name=constant_op.constant("foo"))

    @combinations.generate(test_base.default_test_combinations())
    def testSharedJobName(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 1000

        def make_ds():
            return dataset_ops.Dataset.range(num_elements).shuffle(
                num_elements)

        ds1 = self.make_distributed_dataset(make_ds(),
                                            cluster,
                                            job_name="job_name")
        ds2 = self.make_distributed_dataset(make_ds(),
                                            cluster,
                                            job_name="job_name")
        get_next_1 = self.getNext(ds1)
        get_next_2 = self.getNext(ds2)
        results = []
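        # The two readers share a single job, so together they should consume
        # each element of the shuffled range exactly once.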
        for _ in range(num_elements // 5):
            results.append(self.evaluate(get_next_1()))
            results.append(self.evaluate(get_next_2()))
        results += self.getIteratorOutput(get_next_1)
        results += self.getIteratorOutput(get_next_2)
        self.assertCountEqual(list(range(num_elements)), results)

    @combinations.generate(test_base.default_test_combinations())
    def testDifferentJobNames(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 10
        ds1 = self.make_distributed_range_dataset(num_elements,
                                                  cluster,
                                                  job_name="job_name1")
        ds2 = self.make_distributed_range_dataset(num_elements,
                                                  cluster,
                                                  job_name="job_name2")
        self.assertDatasetProduces(ds1, list(range(num_elements)))
        self.assertDatasetProduces(ds2, list(range(num_elements)))

    @combinations.generate(test_base.eager_only_combinations())
    def testSharedJobNameMultiIteration(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 10
        ds1 = self.make_distributed_range_dataset(num_elements,
                                                  cluster,
                                                  job_name="job_name")
        ds2 = self.make_distributed_range_dataset(num_elements,
                                                  cluster,
                                                  job_name="job_name")
        # iteration 1
        self.assertDatasetProduces(ds1, list(range(num_elements)))
        self.assertDatasetProduces(ds2, [])
        # iteration 2
        self.assertDatasetProduces(ds2, list(range(num_elements)))
        self.assertDatasetProduces(ds1, [])

    @combinations.generate(test_base.default_test_combinations())
    def testSharedJobNameRepeat(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 100
        num_repetitions = 3
        ds1 = self.make_distributed_range_dataset(num_elements,
                                                  cluster,
                                                  job_name="job_name")
        ds1 = ds1.repeat(num_repetitions)
        ds2 = self.make_distributed_range_dataset(num_elements,
                                                  cluster,
                                                  job_name="job_name")
        ds2 = ds2.repeat(num_repetitions)
        results = []
        get_next_1 = self.getNext(ds1)
        get_next_2 = self.getNext(ds2)
        for _ in range((num_elements * num_repetitions) // 5):
            results.append(self.evaluate(get_next_1()))
        for _ in range((num_elements * num_repetitions) // 5):
            results.append(self.evaluate(get_next_2()))
        results += self.getIteratorOutput(get_next_1)
        results += self.getIteratorOutput(get_next_2)
        self.assertCountEqual(num_repetitions * list(range(num_elements)),
                              results)

    @combinations.generate(test_base.eager_only_combinations())
    def testSharedJobNameMultipleEpochs(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        dataset = self.make_distributed_range_dataset(10,
                                                      cluster,
                                                      job_name="job_name")

        num_epochs = 5
        for _ in range(num_epochs):
            get_next = self.getNext(dataset)
            self.assertEqual(self.getIteratorOutput(get_next), list(range(10)))

    @combinations.generate(
        combinations.times(test_base.eager_only_combinations(),
                           combinations.combine(job_name=[None, "test"])))
    def testGcUnusedJob(self, job_name):
        cluster = data_service_test_base.TestCluster(
            num_workers=1, job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
        num_elements = 100
        ds = self.make_distributed_range_dataset(num_elements,
                                                 cluster,
                                                 job_name=job_name)
        it = iter(ds)
        self.assertEqual(next(it).numpy(), 0)
        self.assertEqual(cluster.workers[0].num_tasks(), 1)
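        # Deleting the only iterator leaves the job unused, so the dispatcher
        # should garbage-collect it and the worker should drop its task.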
        del it
        while cluster.workers[0].num_tasks() > 0:
            time.sleep(0.1)

    @combinations.generate(test_base.eager_only_combinations())
    def testDontGcUsedJob(self):
        cluster = data_service_test_base.TestCluster(
            num_workers=1, job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
        num_elements = 10
        it1 = iter(
            self.make_distributed_range_dataset(num_elements,
                                                cluster,
                                                job_name="test1"))
        it2 = iter(
            self.make_distributed_range_dataset(num_elements,
                                                cluster,
                                                job_name="test2"))
        it3 = iter(  # this iterator keeps the task alive. pylint: disable=unused-variable
            self.make_distributed_range_dataset(num_elements,
                                                cluster,
                                                job_name="test2"))
        self.assertEqual(cluster.workers[0].num_tasks(), 2)
        del it1
        del it2
        # Check that only the first job is gced. The second job will not be gced
        # because there is still an outstanding iterator for it.
        while cluster.workers[0].num_tasks() > 1:
            time.sleep(0.1)
        self.assertEqual(cluster.workers[0].num_tasks(), 1)

    @combinations.generate(test_base.eager_only_combinations())
    def testGcAndRecreate(self):
        cluster = data_service_test_base.TestCluster(
            num_workers=3, job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
        num_elements = 1000
        # Repeatedly create and garbage-collect the same job.
        for _ in range(3):
            ds = self.make_distributed_range_dataset(num_elements,
                                                     cluster,
                                                     job_name="test")
            it = iter(ds)
            for _ in range(50):
                next(it)
            del it
            # Wait for the task to be garbage-collected on all workers.
            while cluster.num_tasks_on_workers() > 0:
                time.sleep(0.1)

    @combinations.generate(test_base.default_test_combinations())
    def testApplyDeterminismOption(self):
        elements = list(range(10))
        cluster = data_service_test_base.TestCluster(num_workers=1)

        def dataset_fn(delay_ms):
            def interleave_fn(x):
                ds = dataset_ops.Dataset.from_tensors(x)
                if math_ops.equal(x, 0):
                    ds = ds.apply(testing.sleep(delay_ms * 1000))
                else:
                    ds = ds.apply(testing.sleep(0))
                return ds

            ds = dataset_ops.Dataset.from_tensor_slices(elements)
            ds = ds.interleave(interleave_fn,
                               cycle_length=10,
                               num_parallel_calls=10)
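            # Element 0 sleeps much longer than the others, so with
            # `deterministic=False` later elements may overtake it.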
            opts = options_lib.Options()
            opts.deterministic = False
            ds = ds.with_options(opts)
            ds = self.make_distributed_dataset(ds, cluster)
            return ds

        self.checkDeterminism(dataset_fn=dataset_fn,
                              expect_determinism=False,
                              expected_elements=elements)

    def run_stateful(self, external_state_policy):
        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements).map(
            lambda _: random_ops.random_uniform(()))
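        # random_uniform makes the dataset stateful; the external state policy
        # controls whether the tf.data service accepts it.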

        options = options_lib.Options()
        options.experimental_external_state_policy = external_state_policy
        ds = ds.with_options(options)

        cluster = data_service_test_base.TestCluster(num_workers=3)
        ds = self.make_distributed_dataset(ds, cluster)
        self.getDatasetOutput(ds)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(external_state_policy=[
                options_lib.ExternalStatePolicy.IGNORE,
                options_lib.ExternalStatePolicy.WARN
            ])))
    def testStatefulNoError(self, external_state_policy):
        self.run_stateful(external_state_policy)

    @combinations.generate(test_base.default_test_combinations())
    def testStatefulError(self):
        with self.assertRaises(errors.FailedPreconditionError):
            self.run_stateful(options_lib.ExternalStatePolicy.FAIL)

    @combinations.generate(test_base.default_test_combinations())
    def testDistributeFromInterleave(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        ds = dataset_ops.Dataset.range(2)

        def interleave_fn(x):
            dataset = dataset_ops.Dataset.range(10 * x, 10 * x + 2)
            dataset = self.make_distributed_dataset(dataset, cluster)
            return dataset

        ds = ds.interleave(interleave_fn, cycle_length=2)
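        # Each interleave branch is its own distributed dataset; with
        # cycle_length=2 the outputs alternate between the two branches.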
        self.assertDatasetProduces(ds, [0, 10, 1, 11])

    @combinations.generate(test_base.default_test_combinations())
    def testDistributeNonStringAddresses(self):
        ds = dataset_ops.Dataset.range(10)
        with self.assertRaisesRegex(ValueError, "service must be a string"):
            ds = ds.apply(
                data_service_ops.distribute(processing_mode="parallel_epochs",
                                            service=1))

    @combinations.generate(test_base.default_test_combinations())
    def testDistributeEmptyAddress(self):
        ds = dataset_ops.Dataset.range(10)
        with self.assertRaisesWithLiteralMatch(ValueError,
                                               "service must not be empty"):
            ds = ds.apply(
                data_service_ops.distribute(processing_mode="parallel_epochs",
                                            service=""))

    @combinations.generate(test_base.default_test_combinations())
    def testDistributeExplicitProtocol(self):
        cluster = data_service_test_base.TestCluster(
            num_workers=1, data_transfer_protocol="grpc")
        ds = dataset_ops.Dataset.range(10)
        ds = ds.apply(
            data_service_ops.distribute(processing_mode="parallel_epochs",
                                        service="grpc://" +
                                        cluster.dispatcher_address()))
        self.assertDatasetProduces(ds, list(range(10)))

    @combinations.generate(test_base.default_test_combinations())
    def testDistributeInvalidProtocol(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        ds = dataset_ops.Dataset.range(10)
        with self.assertRaisesRegex(
                errors.NotFoundError,
                "No credentials factory has been registered for protocol grp"):
            ds = ds.apply(
                data_service_ops.distribute(processing_mode="parallel_epochs",
                                            service="grp://" +
                                            cluster.dispatcher_address()))
            self.getDatasetOutput(ds)

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeInvalidProcessingMode(self):
        ds = dataset_ops.Dataset.range(10)
        with self.assertRaisesRegex(
                ValueError,
                "should be a ShardingPolicy, `\"parallel_epochs\"`, or "
                "`\"distributed_epoch\"`. Got 'invalid'."):
            ds = ds.apply(
                data_service_ops.distribute(processing_mode="invalid",
                                            service="grpc://localhost:5000"))

    @combinations.generate(test_base.default_test_combinations())
    def testZipDifferentProcessingModesDatasets(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 100
        ds1 = dataset_ops.Dataset.range(num_elements)
        ds1 = self.make_distributed_dataset(
            ds1, cluster, processing_mode="distributed_epoch")
        ds2 = dataset_ops.Dataset.range(num_elements)
        ds2 = self.make_distributed_dataset(ds2,
                                            cluster,
                                            processing_mode="parallel_epochs")
        ds = dataset_ops.Dataset.zip((ds1, ds2))
        self.assertDatasetProduces(ds,
                                   list(
                                       zip(range(num_elements),
                                           range(num_elements))),
                                   assert_items_equal=True)

    @combinations.generate(test_base.default_test_combinations())
    def testZipDifferentProcessingModesDatasetsSharedJobName(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 100
        ds1 = dataset_ops.Dataset.range(num_elements)
        ds1 = self.make_distributed_dataset(
            ds1,
            cluster,
            processing_mode="distributed_epoch",
            job_name="job_name")
        ds2 = dataset_ops.Dataset.range(num_elements)
        ds2 = self.make_distributed_dataset(ds2,
                                            cluster,
                                            processing_mode="parallel_epochs",
                                            job_name="job_name")
        ds = dataset_ops.Dataset.zip((ds1, ds2))
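        # Reusing a job name with a different processing mode is rejected by
        # the dispatcher.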
        with self.assertRaisesRegex(errors.FailedPreconditionError,
                                    "but there is already an existing job"):
            self.getDatasetOutput(ds)

    @combinations.generate(test_base.default_test_combinations())
    def testFromDatasetId(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)

        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements)
        dataset_id = self.register_dataset(cluster.dispatcher_address(), ds)
        from_dataset_id_ds = self.from_dataset_id("parallel_epochs", cluster,
                                                  dataset_id, ds.element_spec)
        self.assertDatasetProduces(from_dataset_id_ds,
                                   list(range(num_elements)))

    @combinations.generate(test_base.default_test_combinations())
    def testFromDatasetIdSharedJobs(self):
        cluster = data_service_test_base.TestCluster(num_workers=2)

        datasets = [
            dataset_ops.Dataset.range(20, output_type=dtypes.int32),
            dataset_ops.Dataset.from_tensor_slices(list(range(20, 40)))
        ]
        dataset_ids = []

        for ds in datasets:
            dataset_id = self.register_dataset(cluster.dispatcher_address(),
                                               ds)
            dataset_ids.append(dataset_id)

        # Read from both jobs in parallel, with 2 consumers for each job.
        data_service_datasets = []
        for _ in range(2):
            for dataset, dataset_id in zip(datasets, dataset_ids):
                ds = self.from_dataset_id("distributed_epoch",
                                          cluster,
                                          dataset_id,
                                          dataset.element_spec,
                                          job_name="shared_job")
                data_service_datasets.append(ds)
        ds = dataset_ops.Dataset.from_tensor_slices(data_service_datasets)
        ds = ds.interleave(lambda x: x,
                           cycle_length=len(data_service_datasets))

        self.assertDatasetProduces(ds,
                                   list(range(40)),
                                   assert_items_equal=True)

    @combinations.generate(test_base.default_test_combinations())
    def testRegisteringDatasetAsTfFunction(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)

        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements)
        register_func = def_function.function(self.register_dataset)
        dataset_id = register_func(
            (constant_op.constant("grpc"),
             constant_op.constant(cluster.dispatcher_address())), ds)
        from_dataset_id_ds = self.from_dataset_id("parallel_epochs", cluster,
                                                  dataset_id, ds.element_spec)
        self.assertDatasetProduces(from_dataset_id_ds,
                                   list(range(num_elements)))

    @combinations.generate(test_base.default_test_combinations())
    def testFromDatasetIdMultipleComponents(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)

        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements)
        ds = dataset_ops.Dataset.zip({"a": (ds, ds), "b": ds})
        dataset_id = self.register_dataset(cluster.dispatcher_address(), ds)
        from_dataset_id_ds = self.from_dataset_id("parallel_epochs", cluster,
                                                  dataset_id, ds.element_spec)
        output = self.getDatasetOutput(from_dataset_id_ds)
        for i in range(num_elements):
            self.assertEqual(i, output[i]["a"][0])
            self.assertEqual(i, output[i]["a"][1])
            self.assertEqual(i, output[i]["b"])

    @combinations.generate(test_base.default_test_combinations())
    def testFromDatasetIdWrongElementSpec(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)

        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements)
        dataset_id = self.register_dataset(cluster.dispatcher_address(), ds)
        wrong_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
        from_dataset_id_ds = self.from_dataset_id("parallel_epochs", cluster,
                                                  dataset_id, wrong_spec)
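        # How the element spec mismatch surfaces depends on the data transfer
        # protocol in use.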

        if data_service_test_base.TRANSFER_PROTOCOL.value:
            with self.assertRaisesRegex(errors.InvalidArgumentError,
                                        "Data type mismatch at component 0"):
                self.evaluate(self.getNext(from_dataset_id_ds)())
        else:
            with self.assertRaisesRegex(errors.FailedPreconditionError,
                                        "Expected a tensor of type variant"):
                self.evaluate(self.getNext(from_dataset_id_ds)())

    @combinations.generate(test_base.default_test_combinations())
    def testFromDatasetIdNotRegistered(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)

        dataset_id = 0
        element_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
        from_dataset_id_ds = self.from_dataset_id("parallel_epochs", cluster,
                                                  dataset_id, element_spec)
        with self.assertRaisesRegex(errors.NotFoundError, "Dataset id"):
            self.evaluate(self.getNext(from_dataset_id_ds)())

    @combinations.generate(test_base.default_test_combinations())
    def testCancellation(self):
        self.skipTest("b/162521601")
        sleep_microseconds = int(1e6) * 1000

        cluster = data_service_test_base.TestCluster(num_workers=1)
        # Create a dataset which produces the first element quickly, and the second
        # element slowly. Fetching the first element triggers prefetching of the
        # second element, which we should be able to cancel.
        slow = dataset_ops.Dataset.range(1)
        slow = slow.apply(testing.sleep(sleep_microseconds))
        ds = dataset_ops.Dataset.range(1).concatenate(slow)
        ds = self.make_distributed_dataset(ds, cluster)
        ds = ds.prefetch(1)
        get_next = self.getNext(ds)
        self.assertEqual(0, self.evaluate(get_next()))
        # Without properly implemented cancellation, we will hang here while trying
        # to garbage collect the dataset iterator.

    @combinations.generate(test_base.default_test_combinations())
    def testRegisterEquivalentDatasets(self):
        ds_1 = dataset_ops.Dataset.range(10)
        ds_2 = dataset_ops.Dataset.range(10)
        cluster = data_service_test_base.TestCluster(num_workers=1)
        id_1 = self.register_dataset(cluster.dispatcher_address(), ds_1)
        id_2 = self.register_dataset(cluster.dispatcher_address(), ds_2)
        self.assertEqual(self.evaluate(id_1), self.evaluate(id_2))

    @combinations.generate(test_base.default_test_combinations())
    def testRegisterDifferentDatasets(self):
        ds_1 = dataset_ops.Dataset.range(10)
        ds_2 = dataset_ops.Dataset.range(20)
        cluster = data_service_test_base.TestCluster(num_workers=1)
        id_1 = self.register_dataset(cluster.dispatcher_address(), ds_1)
        id_2 = self.register_dataset(cluster.dispatcher_address(), ds_2)
        self.assertNotEqual(self.evaluate(id_1), self.evaluate(id_2))

    @combinations.generate(test_base.default_test_combinations())
    def testTwoLevelDistribute(self):
        cluster_1_size = 3
        cluster_1 = data_service_test_base.TestCluster(
            num_workers=cluster_1_size)
        cluster_2 = data_service_test_base.TestCluster(num_workers=1)
        num_sizes = 10
        size_repeats = 5
        strings = ["a" * i for i in range(num_sizes)] * size_repeats
        ds = dataset_ops.Dataset.from_tensor_slices(strings)
        ds = ds.shuffle(len(strings))
        ds = self.make_distributed_dataset(ds, cluster_1)
        # Large enough so that all strings of the same size are windowed together.
        window_size = cluster_1_size * size_repeats
        batch_size = size_repeats

        def key_func(x):
            return math_ops.cast(string_ops.string_length_v2(x), dtypes.int64)

        ds = ds.apply(
            grouping.group_by_window(
                key_func=key_func,
                reduce_func=lambda _, x: x.batch(batch_size),
                window_size=window_size))
        ds = self.make_distributed_dataset(ds, cluster_2)

        get_next = self.getNext(ds)
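        # Each of cluster_1's workers produces every string, so each batch of
        # same-length strings should be seen cluster_1_size times.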
        for _ in range(num_sizes):
            element = self.evaluate(get_next())
            for _ in range(1, cluster_1_size):
                self.assertAllEqual(self.evaluate(get_next()), element)
        self.assertEmpty(self.getIteratorOutput(get_next))

    @combinations.generate(
        combinations.times(test_base.default_test_combinations()))
    def testDistributeLargeGraph(self):
        cluster = data_service_test_base.TestCluster(num_workers=1,
                                                     work_dir=NO_WORK_DIR,
                                                     fault_tolerant_mode=False)
        # Larger than default OSS grpc message size limit of 4MB.
        tensor = array_ops.ones((2, 1000, 1000), dtype=dtypes.float32)
        ds = dataset_ops.Dataset.from_tensors(tensor)
        ds = self.make_distributed_dataset(ds, cluster)
        self.assertDatasetProduces(ds, [tensor])

    @combinations.generate(
        combinations.times(test_base.graph_only_combinations(),
                           combinations.combine(use_resource=False)) +
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(use_resource=True)))
    def testVariables(self, use_resource):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        if not use_resource:
            with variable_scope.variable_scope("foo", use_resource=False):
                v = variables.VariableV1(10, dtype=dtypes.int64)
        else:
            v = variables.Variable(10, dtype=dtypes.int64)

        ds = dataset_ops.Dataset.range(3)
        ds = ds.map(lambda x: x + v)
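        # The map function captures the variable, so each element is offset by
        # the variable's value (10).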
        ds = self.make_distributed_dataset(ds, cluster)
        self.evaluate(v.initializer)
        self.assertDatasetProduces(ds,
                                   list(range(10, 13)),
                                   requires_initialization=True)

    @combinations.generate(test_base.graph_only_combinations())
    def testElementSpecGraphMode(self):
        cluster = data_service_test_base.TestCluster(num_workers=1,
                                                     work_dir=NO_WORK_DIR,
                                                     fault_tolerant_mode=False)
        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements)
        dataset_id = data_service_ops.register_dataset(
            cluster.dispatcher_address(), ds)
        with self.assertRaisesRegex(
                ValueError,
                "In graph mode element_spec must be provided manually."):
            ds = data_service_ops.from_dataset_id("parallel_epochs",
                                                  cluster.dispatcher_address(),
                                                  dataset_id)

    @combinations.generate(test_base.eager_only_combinations())
    def testFromDatasetIdDoesntRequireElementSpec(self):
        cluster = data_service_test_base.TestCluster(
            num_workers=1,
            work_dir=NO_WORK_DIR,
            fault_tolerant_mode=False,
            data_transfer_protocol="grpc")
        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements)

        dataset_id = data_service_ops.register_dataset(
            cluster.dispatcher_address(), ds)
        ds = data_service_ops.from_dataset_id("parallel_epochs",
                                              cluster.dispatcher_address(),
                                              dataset_id)
        self.assertDatasetProduces(ds, list(range(num_elements)))

    @combinations.generate(test_base.eager_only_combinations())
    def testElementSpecMixedMode(self):
        cluster = data_service_test_base.TestCluster(num_workers=1,
                                                     work_dir=NO_WORK_DIR,
                                                     fault_tolerant_mode=False)
        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements)

        @def_function.function
        def get_dataset_id():
            return data_service_ops.register_dataset(
                cluster.dispatcher_address(), ds)

        dataset_id = get_dataset_id()
        dataset_id_val = tensor_util.constant_value(dataset_id)
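        # The dataset was registered inside a tf.function, so the element spec
        # cannot be fetched and must be passed to from_dataset_id explicitly.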

        with self.assertRaisesRegex(
                ValueError, "Failed to fetch element spec for dataset id " +
                str(dataset_id_val) + " from tf.data service. If the "
                "dataset was registered in graph mode or inside a "
                "tf.function, the `element_spec` must be specified as "
                "an argument to `from_dataset_id`."):
            ds = data_service_ops.from_dataset_id("parallel_epochs",
                                                  cluster.dispatcher_address(),
                                                  dataset_id)

    @combinations.generate(test_base.default_test_combinations())
    def testNoShardingPolicy(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        dataset = dataset_ops.Dataset.range(20)
        dataset = self.make_distributed_dataset(
            dataset, cluster=cluster, processing_mode=ShardingPolicy.OFF)
        self.assertDatasetProduces(dataset, list(range(20)))

    @combinations.generate(test_base.default_test_combinations())
    def testCardinality(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        dataset = self.make_distributed_range_dataset(10, cluster)
        self.assertEqual(self.evaluate(dataset.cardinality()),
                         dataset_ops.UNKNOWN)

class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
    @combinations.generate(test_base.graph_only_combinations())
    def testNoGradients(self):
        component = constant_op.constant([1.])
        side = constant_op.constant(0.)
        add = lambda x: x + side
        dataset = dataset_ops.Dataset.from_tensor_slices(component).map(add)
        value = dataset_ops.make_one_shot_iterator(dataset).get_next()
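        # Gradients do not flow through iterator outputs, so every gradient is
        # None.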
        self.assertIsNone(gradients_impl.gradients(value, component)[0])
        self.assertIsNone(gradients_impl.gradients(value, side)[0])
        self.assertIsNone(
            gradients_impl.gradients(value, [component, side])[0])

    @combinations.generate(test_base.graph_only_combinations())
    def testCapturingStateInOneShotRaisesException(self):
        var = variables.Variable(37.0, name="myvar")
        dataset = (dataset_ops.Dataset.from_tensor_slices(
            [0.0, 1.0, 2.0]).map(lambda x: x + var))
        with self.assertRaisesRegex(
                ValueError,
                r"`Dataset.make_one_shot_iterator\(\)` does not support "
                "datasets that capture stateful objects.+myvar"):
            dataset_ops.make_one_shot_iterator(dataset)

    @combinations.generate(test_base.graph_only_combinations())
    def testOneShotIterator(self):
        components = (np.arange(7),
                      np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                      np.array(37.0) * np.arange(7))

        def _map_fn(x, y, z):
            return math_ops.square(x), math_ops.square(y), math_ops.square(z)

        iterator = dataset_ops.make_one_shot_iterator(
            dataset_ops.Dataset.from_tensor_slices(components).map(
                _map_fn).repeat(14))
        get_next = iterator.get_next()

        self.assertEqual([c.shape[1:] for c in components],
                         [t.shape for t in get_next])

        with self.cached_session() as sess:
            for _ in range(14):
                for i in range(7):
                    result = sess.run(get_next)
                    for component, result_component in zip(components, result):
                        self.assertAllEqual(component[i]**2, result_component)
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(get_next)

    @combinations.generate(test_base.graph_only_combinations())
    def testOneShotIteratorCaptureByValue(self):
        components = (np.arange(7),
                      np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                      np.array(37.0) * np.arange(7))
        tensor_components = tuple(
            [ops.convert_to_tensor(c) for c in components])

        def _map_fn(x, y, z):
            return math_ops.square(x), math_ops.square(y), math_ops.square(z)

        iterator = dataset_ops.make_one_shot_iterator(
            dataset_ops.Dataset.from_tensor_slices(tensor_components).map(
                _map_fn).repeat(14))
        get_next = iterator.get_next()

        self.assertEqual([c.shape[1:] for c in components],
                         [t.shape for t in get_next])

        with self.cached_session() as sess:
            for _ in range(14):
                for i in range(7):
                    result = sess.run(get_next)
                    for component, result_component in zip(components, result):
                        self.assertAllEqual(component[i]**2, result_component)
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(get_next)

    @combinations.generate(test_base.default_test_combinations())
    def testOneShotIteratorInsideContainer(self):
        components = (np.arange(7),
                      np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                      np.array(37.0) * np.arange(7))

        def within_container():
            def _map_fn(x, y, z):
                return math_ops.square(x), math_ops.square(y), math_ops.square(
                    z)

            iterator = dataset_ops.make_one_shot_iterator(
                dataset_ops.Dataset.from_tensor_slices(components).map(
                    _map_fn).repeat(14))
            return iterator.get_next()

        server = server_lib.Server.create_local_server()

        # Create two iterators within unique containers, and run them to
        # make sure that the resources aren't shared.
        #
        # The test below would fail if cname were the same across both
        # sessions.
        for j in range(2):
            with session.Session(server.target) as sess:
                cname = "iteration%d" % j
                with ops.container(cname):
                    get_next = within_container()

                for _ in range(14):
                    for i in range(7):
                        result = sess.run(get_next)
                        for component, result_component in zip(
                                components, result):
                            self.assertAllEqual(component[i]**2,
                                                result_component)
                with self.assertRaises(errors.OutOfRangeError):
                    sess.run(get_next)

    @combinations.generate(test_base.graph_only_combinations())
    def testOneShotIteratorNonBlocking(self):
        dataset = dataset_ops.Dataset.from_tensors([1, 2,
                                                    3]).map(lambda x: x * x)
        iterator = dataset_ops.make_one_shot_iterator(dataset)
        next_element = iterator.get_next()

        # Create a session with a single thread to ensure that the
        # one-shot iterator initializer does not deadlock.
        config = config_pb2.ConfigProto(inter_op_parallelism_threads=1,
                                        use_per_session_threads=True)
        with session.Session(config=config) as sess:
            self.assertAllEqual([1, 4, 9], sess.run(next_element))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element)

        # Test with multiple threads invoking the one-shot iterator concurrently.
        with session.Session(config=config) as sess:
            results = []

            def consumer_thread():
                try:
                    results.append(sess.run(next_element))
                except errors.OutOfRangeError:
                    results.append(None)

            num_threads = 8
            threads = [
                self.checkedThread(consumer_thread) for _ in range(num_threads)
            ]
            for t in threads:
                t.start()
            for t in threads:
                t.join()

            self.assertLen(results, num_threads)
            self.assertLen([None for r in results if r is None],
                           num_threads - 1)
            self.assertAllEqual([[1, 4, 9]],
                                [r for r in results if r is not None])

    @combinations.generate(test_base.graph_only_combinations())
    def testOneShotIteratorInitializerFails(self):
        # Define a dataset whose initialization will always fail.
        dataset = dataset_ops.Dataset.from_tensors(array_ops.gather([0], [4]))
        iterator = dataset_ops.make_one_shot_iterator(dataset)
        next_element = iterator.get_next()

        with self.cached_session() as sess:
            with self.assertRaisesRegex(errors.InvalidArgumentError, ""):
                sess.run(next_element)

            # Test that subsequent attempts to use the iterator also fail.
            with self.assertRaisesRegex(errors.InvalidArgumentError, ""):
                sess.run(next_element)

        with self.cached_session() as sess:

            def consumer_thread():
                with self.assertRaisesRegex(errors.InvalidArgumentError, ""):
                    sess.run(next_element)

            num_threads = 8
            threads = [
                self.checkedThread(consumer_thread) for _ in range(num_threads)
            ]
            for t in threads:
                t.start()
            for t in threads:
                t.join()

    @combinations.generate(test_base.graph_only_combinations())
    def testSimpleSharedResource(self):
        components = (np.array(1, dtype=np.int64),
                      np.array([1, 2, 3],
                               dtype=np.int64), np.array(37.0,
                                                         dtype=np.float64))

        server = server_lib.Server.create_local_server()

        # Create two non-overlapping sessions that share the same iterator
        # resource on the same server, and verify that an action of the
        # first session (initializing the iterator) is visible in the
        # second session.
        with ops.Graph().as_default():
            iterator = dataset_ops.make_initializable_iterator(
                dataset_ops.Dataset.from_tensors(components).map(
                    lambda x, y, z: (x, y, z)),
                shared_name="shared_iterator")
            init_op = iterator.initializer
            get_next = iterator.get_next()

            with session.Session(server.target) as sess:
                sess.run(init_op)
                results = sess.run(get_next)
                for component, result_component in zip(components, results):
                    self.assertAllEqual(component, result_component)
                with self.assertRaises(errors.OutOfRangeError):
                    sess.run(get_next)

                # Re-initialize the iterator in the first session.
                sess.run(init_op)

        with ops.Graph().as_default():
            # Re-define the iterator manually, without defining any of the
            # functions in this graph, to ensure that we are not
            # accidentally redefining functions with the same names in the
            # new graph.
            iterator = iterator_ops.Iterator.from_structure(
                shared_name="shared_iterator",
                output_types=(dtypes.int64, dtypes.int64, dtypes.float64),
                output_shapes=([], [3], []))
            get_next = iterator.get_next()

            with session.Session(server.target) as sess:
                # Use the iterator without re-initializing in the second session.
                results = sess.run(get_next)
                for component, result_component in zip(components, results):
                    self.assertAllEqual(component, result_component)
                with self.assertRaises(errors.OutOfRangeError):
                    sess.run(get_next)

    @combinations.generate(test_base.graph_only_combinations())
    def testNotInitializedError(self):
        components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
        iterator = dataset_ops.make_initializable_iterator(
            dataset_ops.Dataset.from_tensors(components))
        get_next = iterator.get_next()

        with self.cached_session() as sess:
            with self.assertRaisesRegex(errors.FailedPreconditionError,
                                        "iterator has not been initialized"):
                sess.run(get_next)

    @combinations.generate(test_base.graph_only_combinations())
    def testReinitializableIterator(self):
        dataset_3 = dataset_ops.Dataset.from_tensors(
            constant_op.constant([1, 2, 3]))
        dataset_4 = dataset_ops.Dataset.from_tensors(
            constant_op.constant([4, 5, 6, 7]))
        iterator = iterator_ops.Iterator.from_structure(
            dataset_ops.get_legacy_output_types(dataset_3), [None])

        dataset_3_init_op = iterator.make_initializer(dataset_3)
        dataset_4_init_op = iterator.make_initializer(dataset_4)
        get_next = iterator.get_next()

        self.assertEqual(dataset_ops.get_legacy_output_types(dataset_3),
                         dataset_ops.get_legacy_output_types(iterator))
        self.assertEqual(dataset_ops.get_legacy_output_types(dataset_4),
                         dataset_ops.get_legacy_output_types(iterator))
        self.assertEqual(
            [None],
            dataset_ops.get_legacy_output_shapes(iterator).as_list())

        with self.cached_session() as sess:
            # The iterator is initially uninitialized.
            with self.assertRaises(errors.FailedPreconditionError):
                sess.run(get_next)

            # Initialize with one dataset.
            sess.run(dataset_3_init_op)
            self.assertAllEqual([1, 2, 3], sess.run(get_next))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(get_next)

            # Initialize with a different dataset.
            sess.run(dataset_4_init_op)
            self.assertAllEqual([4, 5, 6, 7], sess.run(get_next))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(get_next)

            # Reinitialize with the first dataset.
            sess.run(dataset_3_init_op)
            self.assertAllEqual([1, 2, 3], sess.run(get_next))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(get_next)

    @combinations.generate(test_base.graph_only_combinations())
    def testReinitializableIteratorWithFunctions(self):
        def g():
            for i in range(10):
                yield i

        iterator = iterator_ops.Iterator.from_structure(dtypes.int64, [])
        next_element = iterator.get_next()

        with self.cached_session() as sess:
            dataset_1 = dataset_ops.Dataset.from_generator(
                g, output_types=dtypes.int64)
            sess.run(iterator.make_initializer(dataset_1))
            for expected in range(10):
                self.assertEqual(expected, sess.run(next_element))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element)

            dataset_2 = dataset_ops.Dataset.from_generator(
                g, output_types=dtypes.int64)
            sess.run(iterator.make_initializer(dataset_2))
            for expected in range(10):
                self.assertEqual(expected, sess.run(next_element))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element)

    @combinations.generate(test_base.default_test_combinations())
    def testReinitializableIteratorStaticErrors(self):
        # Non-matching structure for types and shapes.
        with self.assertRaises(TypeError):
            iterator = iterator_ops.Iterator.from_structure(
                (dtypes.int64, dtypes.float64), [None])

        # Test validation of dataset argument.
        iterator = iterator_ops.Iterator.from_structure(
            (dtypes.int64, dtypes.float64))

        # Incompatible structure.
        with self.assertRaises(ValueError):
            iterator.make_initializer(
                dataset_ops.Dataset.from_tensors(
                    ((constant_op.constant([1, 2, 3], dtype=dtypes.int64), ),
                     (constant_op.constant([4., 5., 6., 7.],
                                           dtype=dtypes.float64), ))))

        # Incompatible types.
        with self.assertRaises(TypeError):
            iterator.make_initializer(
                dataset_ops.Dataset.from_tensors(
                    (constant_op.constant([1, 2, 3], dtype=dtypes.int32),
                     constant_op.constant([4., 5., 6., 7.],
                                          dtype=dtypes.float32))))

        # Incompatible shapes.
        iterator = iterator_ops.Iterator.from_structure(
            (dtypes.int64, dtypes.float64), ([None], []))
        with self.assertRaises(TypeError):
            iterator.make_initializer(
                dataset_ops.Dataset.from_tensors(
                    (constant_op.constant([1, 2, 3], dtype=dtypes.int64),
                     constant_op.constant([4., 5., 6., 7.],
                                          dtype=dtypes.float64))))

    @combinations.generate(test_base.graph_only_combinations())
    def testIteratorStringHandle(self):
        dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
        dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])

        iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
        iterator_4 = dataset_ops.make_one_shot_iterator(dataset_4)

        handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
        feedable_iterator = iterator_ops.Iterator.from_string_handle(
            handle_placeholder, dataset_ops.get_legacy_output_types(dataset_3),
            dataset_ops.get_legacy_output_shapes(dataset_3))
        next_element = feedable_iterator.get_next()

        self.assertTrue(
            structure.are_compatible(
                dataset_ops.get_structure(dataset_3),
                dataset_ops.get_structure(feedable_iterator)))

        with self.cached_session() as sess:
            iterator_3_handle = sess.run(iterator_3.string_handle())
            iterator_4_handle = sess.run(iterator_4.string_handle())

            self.assertEqual(
                10,
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_4_handle}))
            self.assertEqual(
                1,
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_3_handle}))
            self.assertEqual(
                20,
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_4_handle}))
            self.assertEqual(
                2,
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_3_handle}))
            self.assertEqual(
                30,
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_4_handle}))
            self.assertEqual(
                3,
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_3_handle}))
            self.assertEqual(
                40,
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_4_handle}))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_3_handle})
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_4_handle})

    @combinations.generate(test_base.graph_only_combinations())
    def testIteratorStringHandleFuture(self):
        dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
        dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])

        iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
        iterator_4 = dataset_ops.make_one_shot_iterator(dataset_4)

        handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
        feedable_iterator = iterator_ops.Iterator.from_string_handle(
            handle_placeholder, dataset_ops.get_legacy_output_types(dataset_3),
            dataset_ops.get_legacy_output_shapes(dataset_3))
        next_element = feedable_iterator.get_next()

        self.assertTrue(
            structure.are_compatible(
                dataset_ops.get_structure(dataset_3),
                dataset_ops.get_structure(feedable_iterator)))

        with self.cached_session() as sess:
            iterator_3_handle = sess.run(iterator_3.string_handle())
            iterator_4_handle = sess.run(iterator_4.string_handle())

            self.assertEqual(
                10,
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_4_handle}))
            self.assertEqual(
                1,
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_3_handle}))
            self.assertEqual(
                20,
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_4_handle}))
            self.assertEqual(
                2,
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_3_handle}))
            self.assertEqual(
                30,
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_4_handle}))
            self.assertEqual(
                3,
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_3_handle}))
            self.assertEqual(
                40,
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_4_handle}))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_3_handle})
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_4_handle})

    @combinations.generate(test_base.graph_only_combinations())
    def testIteratorStringHandleReuseTensorObject(self):
        dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
        one_shot_iterator = dataset_ops.make_one_shot_iterator(dataset)
        initializable_iterator = dataset_ops.make_initializable_iterator(
            dataset)
        structure_iterator = iterator_ops.Iterator.from_structure(
            dataset_ops.get_legacy_output_types(dataset))

        created_ops = len(ops.get_default_graph().get_operations())

        self.assertIs(one_shot_iterator.string_handle(),
                      one_shot_iterator.string_handle())
        self.assertIs(initializable_iterator.string_handle(),
                      initializable_iterator.string_handle())
        self.assertIs(structure_iterator.string_handle(),
                      structure_iterator.string_handle())

        # Assert that getting the (default) string handle creates no ops.
        self.assertEqual(created_ops,
                         len(ops.get_default_graph().get_operations()))

        # Specifying an explicit name will create a new op.
        handle_with_name = one_shot_iterator.string_handle(name="foo")
        self.assertEqual("foo", handle_with_name.op.name)
        self.assertIsNot(one_shot_iterator.string_handle(), handle_with_name)

        handle_with_same_name = one_shot_iterator.string_handle(name="foo")
        self.assertEqual("foo_1", handle_with_same_name.op.name)
        self.assertIsNot(handle_with_name, handle_with_same_name)

    @combinations.generate(test_base.graph_only_combinations())
    def testIteratorStringHandleError(self):
        dataset_int_scalar = (dataset_ops.Dataset.from_tensor_slices(
            [1, 2, 3]).repeat())
        dataset_float_vector = (dataset_ops.Dataset.from_tensors(
            [1.0, 2.0, 3.0]))

        handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])

        feedable_int_scalar = iterator_ops.Iterator.from_string_handle(
            handle_placeholder, dtypes.int32, [])
        feedable_int_vector = iterator_ops.Iterator.from_string_handle(
            handle_placeholder, dtypes.int32, [None])
        feedable_int_any = iterator_ops.Iterator.from_string_handle(
            handle_placeholder, dtypes.int32)
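        # Type and shape mismatches are only caught at run time, when get_next
        # is evaluated with a handle whose iterator does not match.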

        with self.cached_session() as sess:
            handle_int_scalar = sess.run(
                dataset_ops.make_one_shot_iterator(
                    dataset_int_scalar).string_handle())
            handle_float_vector = sess.run(
                dataset_ops.make_one_shot_iterator(
                    dataset_float_vector).string_handle())

            self.assertEqual(
                1,
                sess.run(feedable_int_scalar.get_next(),
                         feed_dict={handle_placeholder: handle_int_scalar}))

            self.assertEqual(
                2,
                sess.run(feedable_int_any.get_next(),
                         feed_dict={handle_placeholder: handle_int_scalar}))

            with self.assertRaises(errors.InvalidArgumentError):
                print(
                    sess.run(feedable_int_vector.get_next(),
                             feed_dict={handle_placeholder:
                                        handle_int_scalar}))

            with self.assertRaises(errors.InvalidArgumentError):
                print(
                    sess.run(
                        feedable_int_vector.get_next(),
                        feed_dict={handle_placeholder: handle_float_vector}))

    @combinations.generate(test_base.graph_only_combinations())
    def testRemoteIteratorUsingRemoteCallOpDirectSession(self):
        worker_config = config_pb2.ConfigProto()
        worker_config.device_count["CPU"] = 3

        with ops.device("/job:localhost/replica:0/task:0/cpu:1"):
            dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
            iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
            iterator_3_handle = iterator_3.string_handle()

        @function.Defun(dtypes.string)
        def _remote_fn(h):
            remote_iterator = iterator_ops.Iterator.from_string_handle(
                h, dataset_ops.get_legacy_output_types(dataset_3),
                dataset_ops.get_legacy_output_shapes(dataset_3))
            return remote_iterator.get_next()

        with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
            target_placeholder = array_ops.placeholder(dtypes.string, shape=[])
            remote_op = functional_ops.remote_call(args=[iterator_3_handle],
                                                   Tout=[dtypes.int32],
                                                   f=_remote_fn,
                                                   target=target_placeholder)
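        # remote_call executes _remote_fn on the device named in the feed, so
        # it only succeeds when that device actually holds the iterator
        # resource (cpu:1 here).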

        with self.session(config=worker_config) as sess:
            elem = sess.run(remote_op,
                            feed_dict={
                                target_placeholder:
                                "/job:localhost/replica:0/task:0/cpu:1"
                            })
            self.assertEqual(elem, [1])
            # Fails when target is cpu:2 where the resource is not located.
            with self.assertRaises(errors.InvalidArgumentError):
                sess.run(remote_op,
                         feed_dict={
                             target_placeholder:
                             "/job:localhost/replica:0/task:0/cpu:2"
                         })
            elem = sess.run(remote_op,
                            feed_dict={
                                target_placeholder:
                                "/job:localhost/replica:0/task:0/cpu:1"
                            })
            self.assertEqual(elem, [2])
            elem = sess.run(remote_op,
                            feed_dict={
                                target_placeholder:
                                "/job:localhost/replica:0/task:0/cpu:1"
                            })
            self.assertEqual(elem, [3])
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(remote_op,
                         feed_dict={
                             target_placeholder:
                             "/job:localhost/replica:0/task:0/cpu:1"
                         })

    @combinations.generate(test_base.graph_only_combinations())
    def testRemoteIteratorUsingRemoteCallOpMultiWorkers(self):
        s1 = server_lib.Server.create_local_server()
        s2 = server_lib.Server.create_local_server()
        s3 = server_lib.Server.create_local_server()

        cluster_def = cluster_pb2.ClusterDef()
        workers = cluster_def.job.add()
        workers.name = "worker"
        workers.tasks[0] = s1.target[len("grpc://"):]
        workers.tasks[1] = s2.target[len("grpc://"):]
        client = cluster_def.job.add()
        client.name = "client"
        client.tasks[0] = s3.target[len("grpc://"):]
        config = config_pb2.ConfigProto(cluster_def=cluster_def)

        worker_devices = [
            "/job:worker/replica:0/task:%d/cpu:0" % i for i in range(2)
        ]
        itr_handles = []
        for device in worker_devices:
            with ops.device(device):
                src = dataset_ops.Dataset.from_tensor_slices([device])
                itr = dataset_ops.make_one_shot_iterator(src)
                itr_handles.append(itr.string_handle())

        targets = dataset_ops.Dataset.from_tensor_slices(worker_devices)
        handles = dataset_ops.Dataset.from_tensor_slices(itr_handles)

        @function.Defun(dtypes.string)
        def loading_func(h):
            remote_itr = iterator_ops.Iterator.from_string_handle(
                h, dataset_ops.get_legacy_output_types(itr),
                dataset_ops.get_legacy_output_shapes(itr))
            return remote_itr.get_next()

        def map_fn(target, handle):
            return functional_ops.remote_call(args=[handle],
                                              Tout=[dtypes.string],
                                              f=loading_func,
                                              target=target)

        with ops.device("/job:client"):
            client_dataset = dataset_ops.Dataset.zip(
                (targets, handles)).map(map_fn)
            itr = dataset_ops.make_initializable_iterator(client_dataset)
            n = itr.get_next()

        with session.Session(s3.target, config=config) as sess:
            sess.run(itr.initializer)
            expected_values = worker_devices
            for expected in expected_values:
                self.assertEqual((compat.as_bytes(expected), ), sess.run(n))

            with self.assertRaises(errors.OutOfRangeError):
                sess.run(n)

    @combinations.generate(test_base.graph_only_combinations())
    def testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU(self):
        if not test_util.is_gpu_available():
            self.skipTest("No GPU available")

        with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
            dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
            iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
            iterator_3_handle = iterator_3.string_handle()

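        # String tensors cannot be placed on a GPU, so the iterator handle is
        # decoded to uint8 before being routed through the GPU device and
        # converted back to a string inside the remote function.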
        def _encode_raw(byte_array):
            return bytes(bytearray(byte_array))

        @function.Defun(dtypes.uint8)
        def _remote_fn(h):
            handle = script_ops.py_func(_encode_raw, [h], dtypes.string)
            remote_iterator = iterator_ops.Iterator.from_string_handle(
                handle, dataset_ops.get_legacy_output_types(dataset_3),
                dataset_ops.get_legacy_output_shapes(dataset_3))
            return remote_iterator.get_next()

        with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"):
            target_placeholder = array_ops.placeholder(dtypes.string, shape=[])
            iterator_3_handle_uint8 = parsing_ops.decode_raw(
                input_bytes=iterator_3_handle, out_type=dtypes.uint8)
            remote_op = functional_ops.remote_call(
                args=[iterator_3_handle_uint8],
                Tout=[dtypes.int32],
                f=_remote_fn,
                target=target_placeholder)

        with self.cached_session() as sess:
            elem = sess.run(remote_op,
                            feed_dict={
                                target_placeholder:
                                "/job:localhost/replica:0/task:0/cpu:0"
                            })
            self.assertEqual(elem, [1])
            elem = sess.run(remote_op,
                            feed_dict={
                                target_placeholder:
                                "/job:localhost/replica:0/task:0/cpu:0"
                            })
            self.assertEqual(elem, [2])
            elem = sess.run(remote_op,
                            feed_dict={
                                target_placeholder:
                                "/job:localhost/replica:0/task:0/cpu:0"
                            })
            self.assertEqual(elem, [3])
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(remote_op,
                         feed_dict={
                             target_placeholder:
                             "/job:localhost/replica:0/task:0/cpu:0"
                         })

    @combinations.generate(test_base.graph_only_combinations())
    def testRepeatedGetNextWarning(self):
        iterator = dataset_ops.make_one_shot_iterator(
            dataset_ops.Dataset.range(10))
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            for _ in range(100):
                iterator.get_next()
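        # Each get_next() call past GET_NEXT_CALL_WARNING_THRESHOLD emits one
        # warning, so 100 calls should yield (100 - threshold) warnings.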
        self.assertEqual(100 - iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD,
                         len(w))
        for warning in w:
            self.assertIn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE,
                          str(warning.message))

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(
                expected_element_structure=tensor_spec.TensorSpec(
                    [], dtypes.float32),
                expected_output_classes=ops.Tensor,
                expected_output_types=dtypes.float32,
                expected_output_shapes=[[]])))
    def testTensorIteratorStructure(self, expected_element_structure,
                                    expected_output_classes,
                                    expected_output_types,
                                    expected_output_shapes):
        tf_value_fn = lambda: constant_op.constant(37.0)
        tf_value = tf_value_fn()
        iterator = dataset_ops.make_one_shot_iterator(
            dataset_ops.Dataset.from_tensors(tf_value))

        self.assertTrue(
            structure.are_compatible(dataset_ops.get_structure(iterator),
                                     expected_element_structure))
        self.assertEqual(expected_output_classes,
                         dataset_ops.get_legacy_output_classes(iterator))
        self.assertEqual(expected_output_types,
                         dataset_ops.get_legacy_output_types(iterator))
        self.assertEqual(expected_output_shapes,
                         dataset_ops.get_legacy_output_shapes(iterator))

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(
                expected_element_structure=sparse_tensor.SparseTensorSpec(
                    [1], dtypes.int32),
                expected_output_classes=sparse_tensor.SparseTensor,
                expected_output_types=dtypes.int32,
                expected_output_shapes=[[1]])))
    def testSparseTensorIteratorStructure(self, expected_element_structure,
                                          expected_output_classes,
                                          expected_output_types,
                                          expected_output_shapes):
        def tf_value_fn():
            return sparse_tensor.SparseTensor(indices=[[0]],
                                              values=constant_op.constant(
                                                  [0], dtype=dtypes.int32),
                                              dense_shape=[1])

        tf_value = tf_value_fn()
        iterator = dataset_ops.make_one_shot_iterator(
            dataset_ops.Dataset.from_tensors(tf_value))

        self.assertTrue(
            structure.are_compatible(dataset_ops.get_structure(iterator),
                                     expected_element_structure))
        self.assertEqual(expected_output_classes,
                         dataset_ops.get_legacy_output_classes(iterator))
        self.assertEqual(expected_output_types,
                         dataset_ops.get_legacy_output_types(iterator))
        self.assertEqual(expected_output_shapes,
                         dataset_ops.get_legacy_output_shapes(iterator))

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(expected_element_structure={
                "a":
                tensor_spec.TensorSpec([], dtypes.float32),
                "b": (tensor_spec.TensorSpec([1], dtypes.string),
                      tensor_spec.TensorSpec([], dtypes.string))
            },
                                 expected_output_classes={
                                     "a": ops.Tensor,
                                     "b": (ops.Tensor, ops.Tensor)
                                 },
                                 expected_output_types={
                                     "a": dtypes.float32,
                                     "b": (dtypes.string, dtypes.string)
                                 },
                                 expected_output_shapes={
                                     "a": [],
                                     "b": ([1], [])
                                 })))
    def testNestedTensorIteratorStructure(self, expected_element_structure,
                                          expected_output_classes,
                                          expected_output_types,
                                          expected_output_shapes):
        def tf_value_fn():
            return {
                "a": constant_op.constant(37.0),
                "b":
                (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
            }

        tf_value = tf_value_fn()
        iterator = dataset_ops.make_one_shot_iterator(
            dataset_ops.Dataset.from_tensors(tf_value))

        self.assertTrue(
            structure.are_compatible(dataset_ops.get_structure(iterator),
                                     expected_element_structure))
        self.assertEqual(expected_output_classes,
                         dataset_ops.get_legacy_output_classes(iterator))
        self.assertEqual(expected_output_types,
                         dataset_ops.get_legacy_output_types(iterator))
        self.assertEqual(expected_output_shapes,
                         dataset_ops.get_legacy_output_shapes(iterator))

    @combinations.generate(test_base.graph_only_combinations())
    def testIteratorGetNextName(self):
        with ops.Graph().as_default():
            iterator = dataset_ops.make_one_shot_iterator(
                dataset_ops.Dataset.from_tensors(37.0))
            next_element = iterator.get_next(name="overridden_name")
            self.assertEqual("overridden_name", next_element.op.name)

    @combinations.generate(
        combinations.combine(tf_api_version=[1, 2],
                             mode="eager",
                             execution_mode=[context.ASYNC, context.SYNC]))
    def testIteratorEagerIteration(self, execution_mode):
        with context.eager_mode(), context.execution_mode(execution_mode):
            val = 0
            dataset = dataset_ops.Dataset.range(10)
            iterator = iter(dataset)
            for foo in iterator:
                self.assertEqual(val, foo.numpy())
                val += 1

    @combinations.generate(test_base.eager_only_combinations())
    def testOwnedIteratorFunction(self):

        queue = data_flow_ops.FIFOQueue(10, dtypes.int64)

        @def_function.function
        def fn():
            dataset = dataset_ops.Dataset.range(10)
            iterator = iter(dataset)
            for _ in range(10):
                queue.enqueue(next(iterator))

        fn()

        for i in range(10):
            self.assertEqual(queue.dequeue().numpy(), i)

    @combinations.generate(test_base.eager_only_combinations())
    def testOwnedIteratorFunctionError(self):
        # In this test we verify that a function that raises an error ends up
        # properly deallocating the iterator resource.

        queue = data_flow_ops.FIFOQueue(10, dtypes.int64)
        queue.enqueue(0)

        def init_fn(n):
            return n

        def next_fn(_):
            ds = dataset_ops.Dataset.range(0)
            return next(iter(ds))

        def finalize_fn(n):
            queue.enqueue(0)
            return n

        @def_function.function
        def fn():
            output_signature = tensor_spec.TensorSpec((), dtypes.int64)
            dataset = dataset_ops._GeneratorDataset(1, init_fn, next_fn,
                                                    finalize_fn,
                                                    output_signature)
            iterator = iter(dataset)
            next(iterator)

        with self.assertRaises(errors.OutOfRangeError):
            fn()

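        # The queue now holds the initial element plus the one enqueued by
        # finalize_fn, confirming the iterator resource was destroyed even
        # though next_fn raised.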
        self.assertEqual(queue.size().numpy(), 2)

    @combinations.generate(test_base.eager_only_combinations())
    def testLimitedRetracing(self):
        trace_count = [0]

        @def_function.function
        def f(iterator):
            trace_count[0] += 1
            counter = np.int64(0)
            for elem in iterator:
                counter += elem
            return counter

        dataset = dataset_ops.Dataset.range(5)
        dataset2 = dataset_ops.Dataset.range(10)

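        # Iterators over both datasets have the same element spec, so `f`
        # should be traced only once despite the differing dataset lengths.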
        for _ in range(10):
            self.assertEqual(self.evaluate(f(iter(dataset))), 10)
            self.assertEqual(self.evaluate(f(iter(dataset2))), 45)
            self.assertEqual(trace_count[0], 1)

    @combinations.generate(test_base.eager_only_combinations())
    def testNestedFunctionsIteratorResource(self):
        @def_function.function
        def sum_dataset(ds):
            it = iter(ds)

            @def_function.function
            def next_element(it):
                return next(it)

            total = 0
            for _ in range(10):
                total += next_element(it)
            return total

        ds = dataset_ops.Dataset.range(10)
        self.assertEqual(sum_dataset(ds).numpy(), 45)
        self.assertEqual(sum_dataset(ds).numpy(), 45)

    @combinations.generate(test_base.default_test_combinations())
    def testNestedAutomaticControlDependencies(self):
        counter_var = variables.Variable(0)

        def map_fn(x):
            counter_var.assign_add(1)
            return x

        def dataset_fn():
            return dataset_ops.Dataset.range(10).map(map_fn)

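        # Automatic control dependencies should sequence the `assign_add` side
        # effects in `map_fn` before the final read of `counter_var`, so `fn()`
        # returns 10.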
        @def_function.function
        def fn():
            it = iter(dataset_fn())
            for _ in range(10):
                _ = next(it)
            return counter_var

        self.evaluate(counter_var.initializer)
        self.assertEqual(self.evaluate(fn()), 10)
Example #16
0
class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(num_parallel_calls=[None, 1, 2],
                                 num_parallel_batches=None) +
            combinations.combine(num_parallel_calls=None,
                                 num_parallel_batches=10)))
    def testMapAndBatch(self, num_parallel_calls, num_parallel_batches):
        """Test a dataset that maps a TF function across its input elements."""
        # The pipeline is TensorSliceDataset ->
        # RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size).
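        # `map_and_batch` fuses the map and batch steps; an equivalent unfused
        # pipeline (illustrative sketch, ignoring the parallelism knobs) would
        # be `.repeat(count).map(_map_fn).batch(batch_size)`.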
        components = (np.arange(7),
                      np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                      np.array(37.0) * np.arange(7))

        def _map_fn(x, y, z):
            return math_ops.square(x), math_ops.square(y), math_ops.square(z)

        def dataset_fn(batch_size, count):
            dataset = dataset_ops.Dataset.from_tensor_slices(
                components).repeat(count).apply(
                    batching.map_and_batch(
                        map_func=_map_fn,
                        batch_size=batch_size,
                        num_parallel_calls=num_parallel_calls,
                        num_parallel_batches=num_parallel_batches))
            return dataset

        # Batch of a finite input, where the batch_size divides the
        # total number of elements.
        dataset = dataset_fn(14, 28)
        get_next = self.getNext(dataset)
        self.assertEqual([[None] + list(c.shape[1:]) for c in components], [
            shape.as_list()
            for shape in dataset_ops.get_legacy_output_shapes(dataset)
        ])
        num_batches = (28 * 7) // 14
        for i in range(num_batches):
            result = self.evaluate(get_next())
            for component, result_component in zip(components, result):
                for j in range(14):
                    self.assertAllEqual(component[(i * 14 + j) % 7]**2,
                                        result_component[j])
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())

        # Batch of a finite input, where the batch_size does not
        # divide the total number of elements.
        get_next = self.getNext(dataset_fn(8, 14))

        # We expect (num_batches - 1) full-sized batches.
        num_batches = int(math.ceil((14 * 7) / 8))
        for i in range(num_batches - 1):
            result = self.evaluate(get_next())
            for component, result_component in zip(components, result):
                for j in range(8):
                    self.assertAllEqual(component[(i * 8 + j) % 7]**2,
                                        result_component[j])

        result = self.evaluate(get_next())
        for component, result_component in zip(components, result):
            for j in range((14 * 7) % 8):
                self.assertAllEqual(
                    component[((num_batches - 1) * 8 + j) % 7]**2,
                    result_component[j])
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())

        # Batch of an empty input should fail straight away.
        self.assertDatasetProduces(dataset_fn(8, 0), expected_output=[])

        # Empty batch should be an initialization time error.
        with self.assertRaises(errors.InvalidArgumentError):
            self.assertDatasetProduces(dataset_fn(0, 14), expected_output=[])

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(drop_remainder=[True, False])))
    def testMapAndBatchPartialBatch(self, drop_remainder):
        dataset = (dataset_ops.Dataset.range(10).apply(
            batching.map_and_batch(lambda x: array_ops.reshape(x * x, [1]),
                                   batch_size=4,
                                   drop_remainder=drop_remainder)))

        if drop_remainder:
            self.assertEqual(
                [4, 1],
                dataset_ops.get_legacy_output_shapes(dataset).as_list())
        else:
            self.assertEqual(
                [None, 1],
                dataset_ops.get_legacy_output_shapes(dataset).as_list())
        expected_output = [[[0], [1], [4], [9]], [[16], [25], [36], [49]]]
        if not drop_remainder:
            expected_output.append([[64], [81]])
        self.assertDatasetProduces(dataset, expected_output=expected_output)

    @combinations.generate(test_base.default_test_combinations())
    def testMapAndBatchYieldsPartialBatch(self):
        dataset = (dataset_ops.Dataset.range(10).apply(
            batching.map_and_batch(lambda x: array_ops.reshape(x * x, [1]),
                                   4)))

        self.assertEqual(
            [None, 1],
            dataset_ops.get_legacy_output_shapes(dataset).as_list())
        expected_output = [[[0], [1], [4], [9]], [[16], [25], [36], [49]],
                           [[64], [81]]]
        self.assertDatasetProduces(dataset, expected_output=expected_output)

    @combinations.generate(test_base.default_test_combinations())
    def testMapAndBatchParallelGetNext(self):
        dataset = dataset_ops.Dataset.range(50000).apply(
            batching.map_and_batch(lambda x: x, batch_size=100))

        if context.executing_eagerly():
            iterator = iter(dataset)
            get_next = iterator._next_internal  # pylint: disable=protected-access
        else:
            iterator = dataset_ops.make_one_shot_iterator(dataset)
            get_next = iterator.get_next

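        # Build 100 references to the same get_next callable so that evaluating
        # them together exercises concurrent GetNext calls on a single iterator.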
        elements = []
        for _ in range(100):
            elements.append(get_next)

        for i in range(5):
            got = self.evaluate([element() for element in elements])
            got.sort(key=lambda x: x[0])
            expected = []
            for j in range(100):
                expected.append(
                    range(i * 10000 + j * 100, i * 10000 + (j + 1) * 100))
            self.assertAllEqual(got, expected)
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate([element() for element in elements])

    @combinations.generate(test_base.default_test_combinations())
    def testMapAndBatchParallelGetNextDropRemainder(self):
        dataset = dataset_ops.Dataset.range(49999).apply(
            batching.map_and_batch(lambda x: x,
                                   batch_size=100,
                                   drop_remainder=True))

        if context.executing_eagerly():
            iterator = iter(dataset)
            get_next = iterator._next_internal  # pylint: disable=protected-access
        else:
            iterator = dataset_ops.make_one_shot_iterator(dataset)
            get_next = iterator.get_next

        elements = []
        for _ in range(100):
            elements.append(get_next)

        for i in range(4):
            got = self.evaluate([element() for element in elements])
            got.sort(key=lambda x: x[0])
            expected = []
            for j in range(100):
                expected.append(
                    range(i * 10000 + j * 100, i * 10000 + (j + 1) * 100))
            self.assertAllEqual(got, expected)
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate([element() for element in elements])

    @combinations.generate(test_base.default_test_combinations())
    def testMapAndBatchSparse(self):
        def _sparse(i):
            return sparse_tensor.SparseTensorValue(indices=[[0]],
                                                   values=(i * [1]),
                                                   dense_shape=[1])

        dataset = dataset_ops.Dataset.range(10).apply(
            batching.map_and_batch(_sparse, 5))

        self.assertDatasetProduces(
            dataset,
            expected_output=[
                sparse_tensor.SparseTensorValue(
                    indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
                    values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4],
                    dense_shape=[5, 1]) for i in range(2)
            ])

    @combinations.generate(test_base.default_test_combinations())
    def testMapAndBatchFails(self):
        """Test a dataset that maps a TF function across its input elements."""

        with self.assertRaisesRegex(errors.InvalidArgumentError, "oops"):
            dataset = dataset_ops.Dataset.from_tensors(
                array_ops.check_numerics(
                    constant_op.constant(1.0) / constant_op.constant(0.0),
                    "oops"))
            dataset = dataset.apply(batching.map_and_batch(lambda x: x, 14))
            get_next = self.getNext(dataset, requires_initialization=True)
            self.evaluate(get_next())

    @combinations.generate(test_base.default_test_combinations())
    def testMapAndBatchShapeMismatch(self):
        """Test a dataset that maps a TF function across its input elements."""
        def generator():
            yield [1]
            yield [2]
            yield [3]
            yield [[4, 5, 6]]

        dataset = dataset_ops.Dataset.from_generator(generator,
                                                     output_types=dtypes.int32)
        batch_size = 4
        dataset = dataset.apply(batching.map_and_batch(lambda x: x,
                                                       batch_size))
        self.assertDatasetProduces(
            dataset,
            expected_error=(errors.InvalidArgumentError,
                            "number of elements does not match"))

    @combinations.generate(test_base.default_test_combinations())
    def testMapAndBatchImplicitDispose(self):
        # Tests whether a map and batch dataset will be cleaned up correctly when
        # the pipeline does not run it until exhaustion.
        # The pipeline is TensorSliceDataset -> RepeatDataset(1000) ->
        # MapAndBatchDataset(f=square_3, batch_size=100).
        components = (np.arange(1000),
                      np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis],
                      np.array(37.0) * np.arange(1000))

        def _map_fn(x, y, z):
            return math_ops.square(x), math_ops.square(y), math_ops.square(z)

        dataset = dataset_ops.Dataset.from_tensor_slices(components).repeat(
            1000).apply(batching.map_and_batch(_map_fn, batch_size=100))
        dataset = dataset.prefetch(5)
        get_next = self.getNext(dataset)
        for _ in range(3):
            self.evaluate(get_next())

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(threshold=[0, 5, 10, 90, 95, 99])))
    def testMapAndBatchMapError(self, threshold):
        def raising_py_fn(i):
            if i >= threshold:
                raise StopIteration()
            else:
                return i

        dataset = dataset_ops.Dataset.range(100).apply(
            batching.map_and_batch(
                lambda x: script_ops.py_func(raising_py_fn, [x], dtypes.int64),
                batch_size=10))

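        # Batches made entirely of elements below the threshold succeed; any
        # batch containing an element at or past the threshold fails with
        # InvalidArgumentError, and the dataset then reports end of sequence.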
        get_next = self.getNext(dataset)
        for i in range(threshold // 10):
            self.assertAllEqual([i * 10 + j for j in range(10)],
                                self.evaluate(get_next()))
        for i in range(threshold // 10, 10):
            with self.assertRaises(errors.InvalidArgumentError):
                self.evaluate(get_next())
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(element=False, dtype=dtypes.bool) +
            combinations.combine(
                element=-42,
                dtype=[dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64
                       ]) + combinations.combine(
                           element=42, dtype=[dtypes.uint8, dtypes.uint16]) +
            combinations.combine(
                element=42.0,
                dtype=[dtypes.float16, dtypes.float32, dtypes.float64]) +
            combinations.combine(element=b"hello", dtype=[dtypes.string])))
    def testMapAndBatchTypes(self, element, dtype):
        def gen():
            yield element

        dataset = dataset_ops.Dataset.from_generator(
            gen, dtype).repeat(100).apply(
                batching.map_and_batch(lambda x: x, batch_size=10))

        get_next = self.getNext(dataset)
        for _ in range(10):
            self.assertAllEqual([element for _ in range(10)],
                                self.evaluate(get_next()))

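    # The following "short circuit" tests use map functions (identity,
    # replicate, swap, project, captured input) that tf.data can presumably
    # evaluate without invoking a user-defined function per element.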
    @combinations.generate(test_base.default_test_combinations())
    def testShortCircuitIdentity(self):
        map_fn = lambda x: x
        dataset = self.structuredDataset(None).repeat().apply(
            batching.map_and_batch(map_fn, batch_size=10))
        get_next = self.getNext(dataset)
        expected = map_fn(
            self.evaluate(self.structuredElement(None, shape=[10])))
        self.assertAllEqual(expected, self.evaluate(get_next()))

    @combinations.generate(test_base.default_test_combinations())
    def testShortCircuitReplicate(self):
        map_fn = lambda x: (x, x)
        dataset = self.structuredDataset(None).repeat().apply(
            batching.map_and_batch(map_fn, batch_size=10))
        get_next = self.getNext(dataset)
        expected = map_fn(
            self.evaluate(self.structuredElement(None, shape=[10])))
        self.assertAllEqual(expected, self.evaluate(get_next()))

    @combinations.generate(test_base.default_test_combinations())
    def testShortCircuitSwap(self):
        map_fn = lambda x, y: (y, x)
        dataset = self.structuredDataset((None, None)).repeat().apply(
            batching.map_and_batch(map_fn, batch_size=10))
        get_next = self.getNext(dataset)
        expected = map_fn(
            *self.evaluate(self.structuredElement((None, None), shape=[10])))
        self.assertAllEqual(expected, self.evaluate(get_next()))

    @combinations.generate(test_base.default_test_combinations())
    def testShortCircuitProject(self):
        map_fn = lambda x, y: x
        dataset = self.structuredDataset((None, None)).repeat().apply(
            batching.map_and_batch(map_fn, batch_size=10))
        get_next = self.getNext(dataset)
        expected = map_fn(
            *self.evaluate(self.structuredElement((None, None), shape=[10])))
        self.assertAllEqual(expected, self.evaluate(get_next()))

    @combinations.generate(test_base.default_test_combinations())
    def testShortCircuitCapturedInput(self):
        captured_t = variables.Variable(42)
        dataset = self.structuredDataset(None).repeat().apply(
            batching.map_and_batch(lambda x: captured_t, batch_size=10))
        self.evaluate(variables.global_variables_initializer())
        get_next = self.getNext(dataset, requires_initialization=True)
        self.assertAllEqual([42] * 10, self.evaluate(get_next()))

    @combinations.generate(test_base.default_test_combinations())
    def testMapAndBatchControlFlow(self):
        def map_fn(x):
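            # Temporarily enable control flow v2 so the `cond` below uses the
            # v2 lowering, then restore the previous setting afterwards.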
            previous_control_flow_v2_value = control_flow_util.ENABLE_CONTROL_FLOW_V2
            control_flow_util.ENABLE_CONTROL_FLOW_V2 = True
            return_value = control_flow_ops.cond(x < 50, lambda: x + 1,
                                                 lambda: x * x)
            control_flow_util.ENABLE_CONTROL_FLOW_V2 = previous_control_flow_v2_value
            return return_value

        dataset = dataset_ops.Dataset.range(100).apply(
            batching.map_and_batch(map_fn, batch_size=10))
        get_next = self.getNext(dataset)
        for i in range(10):
            if i < 5:
                self.assertAllEqual([i * 10 + j + 1 for j in range(10)],
                                    self.evaluate(get_next()))
            else:
                self.assertAllEqual([((i * 10) + j) * ((i * 10) + j)
                                     for j in range(10)],
                                    self.evaluate(get_next()))
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())
def keras_model_type_combinations():
    return combinations.combine(model_type=KERAS_MODEL_TYPES)
Example #18
0
class ParallelInterleaveTest(test_base.DatasetTestBase,
                             parameterized.TestCase):
    def setUp(self):

        self.error = None
        self.repeat_count = 2

        # Set up threading events used to sequence when items are produced that
        # are subsequently interleaved. These events allow us to deterministically
        # simulate slowdowns and force sloppiness.
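        # The handshake: the test sets write_coordination_events[x] to let the
        # map function for element x proceed; the map function then releases
        # read_coordination_events[x], which the test acquires to confirm that
        # x has been produced.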
        self.read_coordination_events = {}
        self.write_coordination_events = {}
        # Input values [4, 5, 6] are the common case for these tests, so set
        # up their coordination events by default.
        for i in range(4, 7):
            self.read_coordination_events[i] = threading.Semaphore(0)
            self.write_coordination_events[i] = threading.Event()

    def dataset_fn(self, input_values, cycle_length, block_length, sloppy,
                   buffer_output_elements, prefetch_input_elements):
        def map_py_fn(x):
            self.write_coordination_events[x].wait()
            self.write_coordination_events[x].clear()
            self.read_coordination_events[x].release()
            if self.error:
                err = self.error
                self.error = None
                raise err  # pylint: disable=raising-bad-type
            return x * x

        def map_fn(x):
            return script_ops.py_func(map_py_fn, [x], x.dtype)

        def interleave_fn(x):
            dataset = dataset_ops.Dataset.from_tensors(x)
            dataset = dataset.repeat(x)
            return dataset.map(map_fn)

        return dataset_ops.Dataset.from_tensor_slices(input_values).repeat(
            self.repeat_count).apply(
                interleave_ops.parallel_interleave(interleave_fn, cycle_length,
                                                   block_length, sloppy,
                                                   buffer_output_elements,
                                                   prefetch_input_elements))

    def _interleave(self, lists, cycle_length, block_length):
        """Python implementation of interleave used for testing."""
        num_open = 0

        # `all_iterators` acts as a queue of iterators over each element of `lists`.
        all_iterators = [iter(l) for l in lists]

        # `open_iterators` are the iterators whose elements are currently being
        # interleaved.
        open_iterators = []
        for i in range(cycle_length):
            if all_iterators:
                open_iterators.append(all_iterators.pop(0))
                num_open += 1
            else:
                open_iterators.append(None)

        while num_open or all_iterators:
            for i in range(cycle_length):
                if open_iterators[i] is None:
                    if all_iterators:
                        open_iterators[i] = all_iterators.pop(0)
                        num_open += 1
                    else:
                        continue
                for _ in range(block_length):
                    try:
                        yield next(open_iterators[i])
                    except StopIteration:
                        open_iterators[i] = None
                        num_open -= 1
                        break
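
    # A minimal usage sketch of `_interleave` (illustrative, not part of the
    # original test): with cycle_length=2 and block_length=1 it round-robins
    # one element at a time across the currently open lists, e.g.
    #   list(self._interleave([[1, 1], [2, 2, 2]], 2, 1)) == [1, 2, 1, 2, 2]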

    @combinations.generate(
        combinations.times(
            combinations.combine(
                input_lists=[[[4, 4, 4, 4], [5, 5, 5, 5, 5], [
                    6, 6, 6, 6, 6, 6
                ], [4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]]],
                expected_elements=[[
                    4, 4, 4, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 4, 4, 4, 4, 5,
                    5, 5, 5, 5, 6, 6, 6, 6, 6, 6
                ]],
                cycle_length=1,
                block_length=1) +
            combinations.combine(
                input_lists=[[[4, 4, 4, 4], [5, 5, 5, 5, 5], [
                    6, 6, 6, 6, 6, 6
                ], [4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]]],
                expected_elements=[[
                    4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6, 4, 6, 4, 6, 5,
                    6, 5, 6, 5, 6, 5, 6, 5, 6, 6
                ]],
                cycle_length=2,
                block_length=1) + combinations.combine(
                    input_lists=[[[4] * 4, [5] * 5, [6] * 6] * 2],
                    expected_elements=[[
                        4, 4, 5, 5, 4, 4, 5, 5, 5, 6, 6, 4, 4, 6, 6, 4, 4, 6,
                        6, 5, 5, 6, 6, 5, 5, 6, 6, 5, 6, 6
                    ]],
                    cycle_length=2,
                    block_length=2) +
            combinations.combine(
                input_lists=[[[4, 4, 4, 4], [], [6, 6, 6, 6, 6, 6],
                              [4, 4, 4, 4], [], [6, 6, 6, 6, 6, 6]]],
                expected_elements=[[
                    4, 4, 6, 4, 6, 4, 6, 6, 4, 6, 4, 6, 4, 4, 6, 6, 6, 6, 6, 6
                ]],
                cycle_length=2,
                block_length=1)))
    def testPythonImplementation(self, input_lists, expected_elements,
                                 cycle_length, block_length):
        for index, (expected, produced) in enumerate(
                zip_longest(
                    expected_elements,
                    self._interleave(input_lists, cycle_length,
                                     block_length))):
            self.assertEqual(
                expected, produced,
                "Values differ at %s. %s != %s" % (index, expected, produced))

    def _clear_coordination_events(self):
        for i in range(4, 7):
            self.read_coordination_events[i] = threading.Semaphore(0)
            self.write_coordination_events[i].clear()

    def _allow_all_map_threads(self):
        for i in range(4, 7):
            self.write_coordination_events[i].set()

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(sloppy=[False, True],
                                 prefetch_input_elements=[0, 1])))
    def testSingleThreaded(self, sloppy, prefetch_input_elements):
        # cycle_length=1, block_length=1 acts like `Dataset.interleave()` and
        # `Dataset.flat_map()` and is single-threaded. No synchronization required.
        self.skipTest("b/131722904")
        self._clear_coordination_events()
        next_element = self.getNext(
            self.dataset_fn(input_values=np.int64([4, 5, 6]),
                            cycle_length=1,
                            block_length=1,
                            sloppy=sloppy,
                            buffer_output_elements=1,
                            prefetch_input_elements=prefetch_input_elements))
        for expected_element in self._interleave([[4] * 4, [5] * 5, [6] * 6] *
                                                 self.repeat_count, 1, 1):
            self.write_coordination_events[expected_element].set()
            self.assertEqual(expected_element * expected_element,
                             self.evaluate(next_element()))
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(next_element())

    @combinations.generate(test_base.default_test_combinations())
    def testSingleThreadedRagged(self):
        # Tests a sequence with wildly different elements per iterator.
        self.skipTest("b/131722904")
        self._clear_coordination_events()
        next_element = self.getNext(
            self.dataset_fn(input_values=np.int64([3, 7, 4]),
                            cycle_length=2,
                            block_length=1,
                            sloppy=False,
                            buffer_output_elements=1,
                            prefetch_input_elements=1))

        # Add coordination values for 3 and 7
        self.read_coordination_events[3] = threading.Semaphore(0)
        self.write_coordination_events[3] = threading.Event()
        self.read_coordination_events[7] = threading.Semaphore(0)
        self.write_coordination_events[7] = threading.Event()

        for expected_element in self._interleave([[3] * 3, [7] * 7, [4] * 4] *
                                                 self.repeat_count, 2, 1):
            self.write_coordination_events[expected_element].set()
            output = self.evaluate(next_element())
            self.assertEqual(expected_element * expected_element, output)
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(next_element())

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(sloppy=[False, True])))
    def testTwoThreadsNoContention(self, sloppy):
        # num_threads > 1.
        # Explicit coordination should result in `Dataset.interleave()` behavior
        self.skipTest("b/131722904")
        self._clear_coordination_events()
        done_first_event = False
        next_element = self.getNext(
            self.dataset_fn(input_values=np.int64([4, 5, 6]),
                            cycle_length=2,
                            block_length=1,
                            sloppy=sloppy,
                            buffer_output_elements=1,
                            prefetch_input_elements=1))
        for i, expected_element in enumerate(
                self._interleave([[4] * 4, [5] * 5, [6] * 6] *
                                 self.repeat_count, 2, 1)):
            self.write_coordination_events[expected_element].set()
            if done_first_event:  # First event starts the worker threads.
                self.read_coordination_events[expected_element].acquire()
            actual_element = self.evaluate(next_element())
            if not done_first_event:
                self.read_coordination_events[expected_element].acquire()
                done_first_event = True
            self.assertEqual(
                expected_element * expected_element, actual_element,
                "At index %s: %s expected, got: %s" %
                (i, expected_element, actual_element))
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(next_element())

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(sloppy=[False, True])))
    def testTwoThreadsNoContentionWithRaces(self, sloppy):
        """Tests where all the workers race in producing elements.

    Note: this is in contrast with the previous test which carefully sequences
    the execution of the map functions.

    Args:
      sloppy: Whether to be sloppy or not.
    """
        self.skipTest("b/131722904")
        self._clear_coordination_events()
        done_first_event = False
        next_element = self.getNext(
            self.dataset_fn(input_values=np.int64([4, 5, 6]),
                            cycle_length=2,
                            block_length=1,
                            sloppy=sloppy,
                            buffer_output_elements=1,
                            prefetch_input_elements=1))
        for i, expected_element in enumerate(
                self._interleave([[4] * 4, [5] * 5, [6] * 6] *
                                 self.repeat_count, 2, 1)):
            if done_first_event:  # First event starts the worker threads.
                self._allow_all_map_threads()
                self.read_coordination_events[expected_element].acquire()
            else:
                self.write_coordination_events[expected_element].set()
            time.sleep(
                0.5)  # Sleep to consistently "avoid" the race condition.
            actual_element = self.evaluate(next_element())
            if not done_first_event:
                done_first_event = True
                self.assertTrue(
                    self.read_coordination_events[expected_element].acquire(
                        False))
            self.assertEqual(
                expected_element * expected_element, actual_element,
                "At index %s: %s expected, got: %s" %
                (i, expected_element, actual_element))
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(next_element())

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(sloppy=[False, True])))
    def testTwoThreadsNoContentionBlockLength(self, sloppy):
        # num_threads > 1.
        # Explicit coordination should result in `Dataset.interleave()` behavior
        self.skipTest("b/131722904")
        self._clear_coordination_events()
        done_first_event = False
        next_element = self.getNext(
            self.dataset_fn(input_values=np.int64([4, 5, 6]),
                            cycle_length=2,
                            block_length=2,
                            sloppy=sloppy,
                            buffer_output_elements=1,
                            prefetch_input_elements=1))
        for i, expected_element in enumerate(
                self._interleave([[4] * 4, [5] * 5, [6] * 6] *
                                 self.repeat_count, 2, 2)):
            self.write_coordination_events[expected_element].set()
            if done_first_event:  # First event starts the worker threads.
                self.read_coordination_events[expected_element].acquire()
            actual_element = self.evaluate(next_element())
            if not done_first_event:
                done_first_event = True
                self.read_coordination_events[expected_element].acquire()
            self.assertEqual(
                expected_element * expected_element, actual_element,
                "At index %s: %s expected, got: %s" %
                (i, expected_element, actual_element))
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(next_element())

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(sloppy=[False, True])))
    def testTwoThreadsNoContentionWithRacesAndBlocking(self, sloppy):
        """Tests where all the workers race in producing elements.

    Note: this is in contrast with the previous test which carefully sequences
    the execution of the map functions.

    Args:
      sloppy: Whether to be sloppy or not.
    """
        self.skipTest("b/131722904")
        self._clear_coordination_events()
        done_first_event = False
        next_element = self.getNext(
            self.dataset_fn(input_values=np.int64([4, 5, 6]),
                            cycle_length=2,
                            block_length=2,
                            sloppy=sloppy,
                            buffer_output_elements=1,
                            prefetch_input_elements=1))
        for i, expected_element in enumerate(
                self._interleave([[4] * 4, [5] * 5, [6] * 6] *
                                 self.repeat_count, 2, 2)):
            if done_first_event:  # First event starts the worker threads.
                self._allow_all_map_threads()
                self.read_coordination_events[expected_element].acquire()
            else:
                self.write_coordination_events[expected_element].set()
            time.sleep(
                0.5)  # Sleep to consistently "avoid" the race condition.
            actual_element = self.evaluate(next_element())
            if not done_first_event:
                done_first_event = True
                self.assertTrue(
                    self.read_coordination_events[expected_element].acquire(
                        False))
            self.assertEqual(
                expected_element * expected_element, actual_element,
                "At index %s: %s expected, got: %s" %
                (i, expected_element, actual_element))
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(next_element())

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(sloppy=[False, True])))
    def testEmptyInput(self, sloppy):
        # Empty input.
        self._clear_coordination_events()
        next_element = self.getNext(
            self.dataset_fn(input_values=np.int64([]),
                            cycle_length=2,
                            block_length=3,
                            sloppy=sloppy,
                            buffer_output_elements=1,
                            prefetch_input_elements=0))
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(next_element())

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(sloppy=[False, True])))
    def _testNonEmptyInputIntoEmptyOutputs(self, sloppy):
        # Non-empty input leading to empty output.
        self._clear_coordination_events()
        next_element = self.getNext(
            self.dataset_fn(input_values=np.int64([0, 0, 0]),
                            cycle_length=2,
                            block_length=3,
                            sloppy=sloppy,
                            buffer_output_elements=1,
                            prefetch_input_elements=0))
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(next_element())

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(sloppy=[False, True],
                                 prefetch_input_elements=[1, 0])))
    def testPartiallyEmptyOutputs(self, sloppy, prefetch_input_elements):
        # Sequence points at which sloppy mode has race conditions.
        race_indices = {2, 8, 14}
        # Mixture of non-empty and empty interleaved datasets.
        self.skipTest("b/131722904")
        self._clear_coordination_events()
        done_first_event = False
        next_element = self.getNext(
            self.dataset_fn(input_values=np.int64([4, 0, 6]),
                            cycle_length=2,
                            block_length=1,
                            sloppy=sloppy,
                            buffer_output_elements=1,
                            prefetch_input_elements=prefetch_input_elements))
        for i, expected_element in enumerate(
                self._interleave([[4] * 4, [], [6] * 6] * self.repeat_count, 2,
                                 1)):
            self.write_coordination_events[expected_element].set()
            # First event starts the worker threads. Additionally, when running the
            # sloppy case with prefetch_input_elements=0, we get stuck if we wait
            # for the read coordination event for certain event orderings in the
            # presence of finishing iterators.
            if done_first_event and not (sloppy and (i in race_indices)):
                self.read_coordination_events[expected_element].acquire()
            actual_element = self.evaluate(next_element())
            if not done_first_event or (sloppy and (i in race_indices)):
                done_first_event = True
                self.read_coordination_events[expected_element].acquire()
            self.assertEqual(
                expected_element * expected_element, actual_element,
                "At index %s: %s expected, got: %s" %
                (i, expected_element, actual_element))

    @combinations.generate(test_base.default_test_combinations())
    def testDelayedOutputSloppy(self):
        # Explicitly control the sequence of events to ensure we correctly avoid
        # head-of-line blocking.
        self.skipTest("b/131722904")
        self._clear_coordination_events()
        next_element = self.getNext(
            self.dataset_fn(input_values=np.int64([4, 5, 6]),
                            cycle_length=2,
                            block_length=1,
                            sloppy=True,
                            buffer_output_elements=1,
                            prefetch_input_elements=0))

        mis_ordering = [
            4, 4, 5, 4, 5, 5, 4, 5, 6, 6, 6, 5, 4, 4, 6, 6, 4, 4, 6, 5, 6, 6,
            6, 6, 5, 5, 5, 5, 6, 6
        ]
        for element in mis_ordering:
            self.write_coordination_events[element].set()
            self.assertEqual(element * element, self.evaluate(next_element()))
            self.assertTrue(
                self.read_coordination_events[element].acquire(False))
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(next_element())

    @combinations.generate(test_base.default_test_combinations())
    def testBlockLengthWithContentionSloppy(self):
        self.skipTest("b/131722904")
        self._clear_coordination_events()
        done_first_event = False
        next_element = self.getNext(
            self.dataset_fn(input_values=np.int64([4, 5, 6]),
                            cycle_length=2,
                            block_length=1,
                            sloppy=True,
                            buffer_output_elements=1,
                            prefetch_input_elements=1))
        # Test against a generating sequence that differs from the uncontended
        # case, in order to prove sloppy correctness.
        for i, expected_element in enumerate(
                self._interleave([[4] * 4, [5] * 5, [6] * 6] *
                                 self.repeat_count,
                                 cycle_length=2,
                                 block_length=3)):
            self.write_coordination_events[expected_element].set()
            if done_first_event:  # First event starts the worker threads.
                self.read_coordination_events[expected_element].acquire()
            actual_element = self.evaluate(next_element())
            if not done_first_event:
                self.read_coordination_events[expected_element].acquire()
                done_first_event = True
            self.assertEqual(
                expected_element * expected_element, actual_element,
                "At index %s: %s expected, got: %s" %
                (i, expected_element, actual_element))
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(next_element())

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(sloppy=[False, True])))
    def testEarlyExit(self, sloppy):
        # Exiting without consuming all input should not block
        self.skipTest("b/131722904")
        self._clear_coordination_events()
        next_element = self.getNext(
            self.dataset_fn(input_values=np.int64([4, 5, 6]),
                            cycle_length=3,
                            block_length=2,
                            sloppy=sloppy,
                            buffer_output_elements=1,
                            prefetch_input_elements=0))
        for i in range(4, 7):
            self.write_coordination_events[i].set()
        elem = self.evaluate(next_element())  # Start all workers
        # Allow the one successful worker to progress beyond the py_func again.
        elem = int(math.sqrt(elem))
        self.write_coordination_events[elem].set()
        self.read_coordination_events[elem].acquire()
        # Allow the prefetch to succeed
        for i in range(4, 7):
            self.read_coordination_events[i].acquire()
            self.write_coordination_events[i].set()

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(sloppy=[False, True])))
    def testTooManyReaders(self, sloppy=False):
        def interleave_fn(x):
            dataset = dataset_ops.Dataset.from_tensors(x)
            dataset = dataset.repeat(math_ops.cast(x, dtype=dtypes.int64))
            return dataset

        dataset = dataset_ops.Dataset.from_tensor_slices([4, 5, 6])
        dataset = dataset.repeat(self.repeat_count)
        dataset = dataset.apply(
            interleave_ops.parallel_interleave(interleave_fn,
                                               cycle_length=16,
                                               block_length=2,
                                               sloppy=sloppy))
        get_next = self.getNext(dataset)
        output_values = []
        for _ in range(30):
            output_values.append(self.evaluate(get_next()))

        expected_values = self._interleave([[4] * 4, [5] * 5, [6] * 6] *
                                           self.repeat_count, 1, 2)
        self.assertCountEqual(output_values, expected_values)

    @combinations.generate(test_base.default_test_combinations())
    def testSparse(self):
        def _map_fn(i):
            return sparse_tensor.SparseTensor(indices=[[0, 0], [1, 1]],
                                              values=(i * [1, -1]),
                                              dense_shape=[2, 2])

        def _interleave_fn(x):
            return dataset_ops.Dataset.from_tensor_slices(
                sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values))

        dataset = dataset_ops.Dataset.range(10).map(_map_fn).apply(
            interleave_ops.parallel_interleave(_interleave_fn, cycle_length=1))
        get_next = self.getNext(dataset)

        for i in range(10):
            for j in range(2):
                expected = [i, 0] if j % 2 == 0 else [0, -i]
                self.assertAllEqual(expected, self.evaluate(get_next()))
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())

    @combinations.generate(test_base.default_test_combinations())
    def testErrorsInOutputFn(self):
        self.skipTest("b/131722904")
        self._clear_coordination_events()
        next_element = self.getNext(
            self.dataset_fn(input_values=np.int64([4, 5, 6]),
                            cycle_length=2,
                            block_length=1,
                            sloppy=False,
                            buffer_output_elements=1,
                            prefetch_input_elements=0))

        except_on_element_indices = set([3])

        for i, expected_element in enumerate(
                self._interleave([[4] * 4, [5] * 5, [6] * 6] *
                                 self.repeat_count, 2, 1)):
            if i in except_on_element_indices:
                self.error = ValueError()
                self.write_coordination_events[expected_element].set()
                with self.assertRaises(errors.InvalidArgumentError):
                    self.evaluate(next_element())
            else:
                self.write_coordination_events[expected_element].set()
                actual_element = self.evaluate(next_element())
                self.assertEqual(
                    expected_element * expected_element, actual_element,
                    "At index %s: %s expected, got: %s" %
                    (i, expected_element, actual_element))
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(next_element())

    @combinations.generate(test_base.default_test_combinations())
    def testErrorsInInputFn(self):
        def map_py_fn(x):
            if x == 5:
                raise ValueError()
            return x

        def map_fn(x):
            return script_ops.py_func(map_py_fn, [x], x.dtype)

        def interleave_fn(x):
            dataset = dataset_ops.Dataset.from_tensors(x)
            dataset = dataset.repeat(x)
            return dataset

        def dataset_fn(input_values, cycle_length, block_length, sloppy,
                       buffer_output_elements, prefetch_input_elements):
            return dataset_ops.Dataset.from_tensor_slices(input_values).map(
                map_fn).repeat(self.repeat_count).apply(
                    interleave_ops.parallel_interleave(
                        interleave_fn, cycle_length, block_length, sloppy,
                        buffer_output_elements, prefetch_input_elements))

        next_element = self.getNext(
            dataset_fn(input_values=np.int64([4, 5, 6]),
                       cycle_length=2,
                       block_length=1,
                       sloppy=False,
                       buffer_output_elements=1,
                       prefetch_input_elements=0))
        for i, expected_element in enumerate(
                self._interleave([[4] * 4, [5], [6] * 6] * self.repeat_count,
                                 2, 1)):
            if expected_element == 5:
                with self.assertRaises(errors.InvalidArgumentError):
                    self.evaluate(next_element())
            else:
                actual_element = self.evaluate(next_element())
                self.assertEqual(
                    expected_element, actual_element,
                    "At index %s: %s expected, got: %s" %
                    (i, expected_element, actual_element))
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(next_element())

    @combinations.generate(test_base.default_test_combinations())
    def testErrorsInInterleaveFn(self):
        def map_py_fn(x):
            if x == 5:
                raise ValueError()
            return x

        def interleave_fn(x):
            dataset = dataset_ops.Dataset.from_tensors(x)
            y = script_ops.py_func(map_py_fn, [x], x.dtype)
            dataset = dataset.repeat(y)
            return dataset

        def dataset_fn(input_values, cycle_length, block_length, sloppy,
                       buffer_output_elements, prefetch_input_elements):
            return dataset_ops.Dataset.from_tensor_slices(input_values).repeat(
                self.repeat_count).apply(
                    interleave_ops.parallel_interleave(
                        interleave_fn, cycle_length, block_length, sloppy,
                        buffer_output_elements, prefetch_input_elements))

        next_element = self.getNext(
            dataset_fn(input_values=np.int64([4, 5, 6]),
                       cycle_length=2,
                       block_length=1,
                       sloppy=False,
                       buffer_output_elements=1,
                       prefetch_input_elements=0))
        for i, expected_element in enumerate(
                self._interleave([[4] * 4, [5], [6] * 6] * self.repeat_count,
                                 2, 1)):
            if expected_element == 5:
                with self.assertRaises(errors.InvalidArgumentError):
                    self.evaluate(next_element())
            else:
                actual_element = self.evaluate(next_element())
                self.assertEqual(
                    expected_element, actual_element,
                    "At index %s: %s expected, got: %s" %
                    (i, expected_element, actual_element))
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(next_element())

    @combinations.generate(test_base.default_test_combinations())
    def testShutdownRace(self):
        dataset = dataset_ops.Dataset.range(20)
        map_fn = lambda x: dataset_ops.Dataset.range(20 * x, 20 * (x + 1))
        dataset = dataset.apply(
            interleave_ops.parallel_interleave(map_fn,
                                               cycle_length=3,
                                               sloppy=False,
                                               buffer_output_elements=1,
                                               prefetch_input_elements=0))
        dataset = dataset.batch(32)

        results = []
        for _ in range(2):
            elements = []
            next_element = self.getNext(dataset)
            try:
                while True:
                    elements.extend(self.evaluate(next_element()))
            except errors.OutOfRangeError:
                pass
            results.append(elements)
        self.assertAllEqual(results[0], results[1])

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(sloppy=[None, True, False],
                                 global_determinism=[True, False])))
    def testDeterminismConfiguration(self, sloppy, global_determinism):
        if sloppy is None:
            expect_determinism = global_determinism
        else:
            expect_determinism = not sloppy
        elements = list(range(1000))

        def dataset_fn(delay_ms):
            def interleave_fn(x):
                ds = dataset_ops.Dataset.from_tensors(x)
                if math_ops.equal(x, 0):
                    ds = ds.apply(testing.sleep(delay_ms * 1000))
                else:
                    ds = ds.apply(testing.sleep(0))
                return ds

            dataset = dataset_ops.Dataset.from_tensor_slices(elements)
            dataset = dataset.apply(
                interleave_ops.parallel_interleave(interleave_fn,
                                                   cycle_length=10,
                                                   sloppy=sloppy))

            opts = dataset_ops.Options()
            opts.experimental_deterministic = global_determinism
            dataset = dataset.with_options(opts)
            return dataset

        self.checkDeterminism(dataset_fn, expect_determinism, elements)

class RejectionResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(initial_known=[True, False])))
    def testDistribution(self, initial_known):
        classes = np.random.randint(5, size=(10000, ))  # Uniformly sampled
        target_dist = [0.9, 0.05, 0.05, 0.0, 0.0]
        initial_dist = [0.2] * 5 if initial_known else None
        classes = math_ops.cast(classes,
                                dtypes.int64)  # needed for Windows build.
        dataset = dataset_ops.Dataset.from_tensor_slices(classes).shuffle(
            200, seed=21).map(lambda c: (c, string_ops.as_string(c))).repeat()

        get_next = self.getNext(
            dataset.rejection_resample(target_dist=target_dist,
                                       initial_dist=initial_dist,
                                       class_func=lambda c, _: c,
                                       seed=27))

        returned = []
        while len(returned) < 2000:
            returned.append(self.evaluate(get_next()))

        returned_classes, returned_classes_and_data = zip(*returned)
        _, returned_data = zip(*returned_classes_and_data)
        self.assertAllEqual(
            [compat.as_bytes(str(c)) for c in returned_classes], returned_data)
        total_returned = len(returned_classes)
        class_counts = np.array([
            len([True for v in returned_classes if v == c]) for c in range(5)
        ])
        returned_dist = class_counts / total_returned
        self.assertAllClose(target_dist, returned_dist, atol=1e-2)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(only_initial_dist=[True, False])))
    def testEdgeCasesSampleFromInitialDataset(self, only_initial_dist):
        init_dist = [0.5, 0.5]
        target_dist = [0.5, 0.5] if only_initial_dist else [0.0, 1.0]
        num_classes = len(init_dist)
        # We don't need many samples to test that this works.
        num_samples = 100
        data_np = np.random.choice(num_classes, num_samples, p=init_dist)

        dataset = dataset_ops.Dataset.from_tensor_slices(data_np)

        # Reshape distribution.
        dataset = dataset.rejection_resample(class_func=lambda x: x,
                                             target_dist=target_dist,
                                             initial_dist=init_dist)

        get_next = self.getNext(dataset)

        returned = []
        with self.assertRaises(errors.OutOfRangeError):
            while True:
                returned.append(self.evaluate(get_next()))

    @combinations.generate(test_base.default_test_combinations())
    def testRandomClasses(self):
        init_dist = [0.25, 0.25, 0.25, 0.25]
        target_dist = [0.0, 0.0, 0.0, 1.0]
        num_classes = len(init_dist)
        # We don't need many samples to test a dirac-delta target distribution.
        num_samples = 100
        data_np = np.random.choice(num_classes, num_samples, p=init_dist)

        dataset = dataset_ops.Dataset.from_tensor_slices(data_np)

        # Apply a random mapping that preserves the data distribution.
        def _remap_fn(_):
            return math_ops.cast(
                random_ops.random_uniform([1]) * num_classes, dtypes.int32)[0]

        dataset = dataset.map(_remap_fn)

        # Reshape distribution.
        dataset = dataset.rejection_resample(class_func=lambda x: x,
                                             target_dist=target_dist,
                                             initial_dist=init_dist)

        get_next = self.getNext(dataset)

        returned = []
        with self.assertRaises(errors.OutOfRangeError):
            while True:
                returned.append(self.evaluate(get_next()))

        classes, _ = zip(*returned)
        bincount = np.bincount(np.array(classes),
                               minlength=num_classes).astype(
                                   np.float32) / len(classes)

        self.assertAllClose(target_dist, bincount, atol=1e-2)

    @combinations.generate(test_base.default_test_combinations())
    def testExhaustion(self):
        init_dist = [0.5, 0.5]
        target_dist = [0.9, 0.1]
        dataset = dataset_ops.Dataset.range(10000)
        dataset = dataset.rejection_resample(class_func=lambda x: x % 2,
                                             target_dist=target_dist,
                                             initial_dist=init_dist)

        get_next = self.getNext(dataset)
        returned = []
        with self.assertRaises(errors.OutOfRangeError):
            while True:
                returned.append(self.evaluate(get_next()))

        classes, _ = zip(*returned)
        bincount = np.bincount(np.array(classes),
                               minlength=len(init_dist)).astype(
                                   np.float32) / len(classes)

        self.assertAllClose(target_dist, bincount, atol=1e-2)

    @parameterized.parameters(
        ("float32", "float64"),
        ("float64", "float32"),
        ("float64", "float64"),
        ("float64", None),
    )
    def testOtherDtypes(self, target_dtype, init_dtype):
        target_dist = np.array([0.5, 0.5], dtype=target_dtype)

        if init_dtype is None:
            init_dist = None
        else:
            init_dist = np.array([0.5, 0.5], dtype=init_dtype)

        dataset = dataset_ops.Dataset.range(10)
        dataset = dataset.rejection_resample(class_func=lambda x: x % 2,
                                             target_dist=target_dist,
                                             initial_dist=init_dist)
        get_next = self.getNext(dataset)
        self.evaluate(get_next())
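
The checks above only assert that the emitted class frequencies converge to `target_dist`. As a rough illustration of the acceptance rule they rely on, here is a hypothetical, NumPy-only sketch; the helper name and the example distributions are invented for this note and are not taken from tf.data:

import numpy as np

def _rejection_resample_sketch(classes, initial_dist, target_dist, seed=0):
    # Accept class c with probability target[c] / initial[c], rescaled so the
    # largest ratio becomes 1; the surviving samples then follow target_dist.
    rng = np.random.default_rng(seed)
    ratios = np.asarray(target_dist) / np.asarray(initial_dist)
    accept_prob = ratios / ratios.max()
    keep = rng.random(len(classes)) < accept_prob[classes]
    return classes[keep]

_classes = np.random.default_rng(0).integers(0, 2, size=100_000)
_kept = _rejection_resample_sketch(_classes, [0.5, 0.5], [0.9, 0.1])
print(np.bincount(_kept) / len(_kept))  # roughly [0.9, 0.1]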
Example #20
def all_cluster_configurations():
    with_work_dir = combinations.combine(work_dir=TMP_WORK_DIR,
                                         fault_tolerant_mode=[True, False])
    without_work_dir = combinations.combine(work_dir=NO_WORK_DIR,
                                            fault_tolerant_mode=False)
    return with_work_dir + without_work_dir
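
For intuition about what this helper returns, here is a hypothetical, dependency-free sketch of the same idea: `combine()` yields one parameter dict per combination, and `+` simply concatenates the two lists. The `combine` reimplementation and the placeholder constants below are illustrative, not the real test_combinations API:

import itertools

TMP_WORK_DIR = "/tmp/tf_data_work_dir"  # placeholder for the real constant
NO_WORK_DIR = ""  # placeholder for the real constant

def combine(**kwargs):
    # Cross product over every keyword argument; scalars act as 1-element lists.
    normalized = {k: v if isinstance(v, list) else [v] for k, v in kwargs.items()}
    keys = sorted(normalized)
    return [dict(zip(keys, values))
            for values in itertools.product(*(normalized[k] for k in keys))]

with_work_dir = combine(work_dir=TMP_WORK_DIR, fault_tolerant_mode=[True, False])
without_work_dir = combine(work_dir=NO_WORK_DIR, fault_tolerant_mode=False)

# Three configurations: both fault-tolerance modes with a work directory, plus
# a single non-fault-tolerant configuration without one.
print(with_work_dir + without_work_dir)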
Example #21
class LocalReplicateTest(test_base.DatasetTestBase, parameterized.TestCase):

  def __init__(self, methodName="runTest"):  # pylint: disable=invalid-name
    super(LocalReplicateTest, self).__init__(methodName)
    self._device0 = "/device:CPU:0"
    self._device1 = "/device:CPU:1"
    self._device2 = "/device:CPU:2"

  @combinations.generate(
      combinations.combine(tf_api_version=[1], mode=["graph", "eager"]))
  def testBasic(self):
    with ops.device(self._device0):
      dataset0 = dataset_ops.Dataset.range(100)
    replicated_ds = distribute.replicate(dataset0,
                                         [self._device1, self._device2])
    dataset1 = replicated_ds[self._device1]
    dataset2 = replicated_ds[self._device2]

    with ops.device(self._device0):
      self.assertDatasetProduces(dataset0, range(100))
    with ops.device(self._device1):
      self.assertDatasetProduces(dataset1, range(100))
    with ops.device(self._device2):
      self.assertDatasetProduces(dataset2, range(100))

  @combinations.generate(
      combinations.combine(tf_api_version=[1], mode=["graph", "eager"]))
  def testVariableInput(self):
    with ops.device(self._device0):
      counter_var = variable_scope.get_variable(
          "counter", (), dtypes.int32, use_resource=True)
      dataset0 = dataset_ops.Dataset.range(100).map(
          lambda _: counter_var.assign_add(1))
    # We don't support stateful ops in functions as of now.
    with self.assertRaises(errors.FailedPreconditionError):
      replicated_ds = distribute.replicate(dataset0,
                                           [self._device1, self._device2])
      self.evaluate(replicated_ds[self._device1]._variant_tensor)

  @combinations.generate(
      combinations.combine(tf_api_version=[1], mode=["graph", "eager"]))
  def testAllowStatefulOp(self):
    with compat.forward_compatibility_horizon(2019, 9, 12):
      with ops.device(self._device0):
        dataset0 = dataset_ops.Dataset.range(100).map(
            lambda _: random_ops.random_uniform(  # pylint:disable=g-long-lambda
                [],
                minval=1,
                maxval=10,
                dtype=dtypes.float32))
        opt = dataset_ops.Options()
        opt.experimental_allow_stateful = True
        dataset0 = dataset0.with_options(opt)
      replicated_ds = distribute.replicate(dataset0,
                                           [self._device1, self._device2])
      dataset1 = replicated_ds[self._device1]
      dataset2 = replicated_ds[self._device2]

      with ops.device(self._device0):
        get_next0 = self.getNext(dataset0)
      with ops.device(self._device1):
        get_next1 = self.getNext(dataset1)
      with ops.device(self._device2):
        get_next2 = self.getNext(dataset2)

      for _ in range(100):
        get_next0()
        get_next1()
        get_next2()
Example #22
def reduce_fn(x, y):
  name, dataset_fn, expected_output = y
  return x + combinations.combine(
      dataset_fn=combinations.NamedObject(name, dataset_fn),
      expected_output=[expected_output])
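
This `reduce_fn` is meant to be folded over a list of named cases so each case becomes its own `combinations.combine(...)` entry. Below is a hedged sketch of that wiring, assuming the same `dataset_ops` and `combinations` imports used throughout these examples; the case list and helper name are invented for illustration:

import functools

def _example_test_combinations():
  cases = [
      ("Range", lambda: dataset_ops.Dataset.range(3), [0, 1, 2]),
      ("FromTensors", lambda: dataset_ops.Dataset.from_tensors(7), [7]),
  ]

  def reduce_fn(x, y):
    name, dataset_fn, expected_output = y
    return x + combinations.combine(
        dataset_fn=combinations.NamedObject(name, dataset_fn),
        expected_output=[expected_output])

  # Start from an empty list and append one combination per named case.
  return functools.reduce(reduce_fn, cases, [])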
Example #23
class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(count=[32, 34],
                                 padded_shapes=[[None], [25]],
                                 drop_remainder=[True, False])))
    def testPaddedBatchDataset(self, count, padded_shapes, drop_remainder):
        seq_lens = np.random.randint(20, size=(count, )).astype(np.int32)
        batch_size = 4
        dataset = dataset_ops.Dataset.from_tensor_slices(seq_lens).map(
            lambda x: array_ops.fill([x], x)).padded_batch(
                batch_size=batch_size,
                drop_remainder=drop_remainder,
                padded_shapes=padded_shapes)

        num_full_batches = len(seq_lens) // batch_size
        get_next = self.getNext(dataset)
        for i in range(num_full_batches):
            result = self.evaluate(get_next())
            padded_len = padded_shapes[0]
            if padded_len is None or padded_len == -1:
                padded_len = np.max(result) if result.size > 0 else 0
            self.assertEqual((batch_size, padded_len), result.shape)
            for j in range(batch_size):
                seq_len = seq_lens[(i * batch_size) + j]
                self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len)
                self.assertAllEqual(result[j, seq_len:],
                                    [0] * (padded_len - seq_len))

        if not drop_remainder and len(seq_lens) % batch_size > 0:
            result = self.evaluate(get_next())
            padded_len = padded_shapes[0]
            if padded_len is None or padded_len == -1:
                padded_len = np.max(result) if result.size > 0 else 0
            self.assertEqual((len(seq_lens) % batch_size, padded_len),
                             result.shape)
            for j in range(len(seq_lens) % batch_size):
                seq_len = seq_lens[num_full_batches * batch_size + j]
                self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len)
                self.assertAllEqual(result[j, seq_len:],
                                    [0] * (padded_len - seq_len))

        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())

    @combinations.generate(test_base.default_test_combinations())
    def testPaddedBatchShortPadding(self):
        dataset = (dataset_ops.Dataset.from_tensor_slices(
            [6, 5, 5, 5,
             5]).map(lambda x: array_ops.fill([x], x)).padded_batch(
                 batch_size=4, padded_shapes=[5]))
        self.assertDatasetProduces(dataset,
                                   expected_error=(errors.DataLossError, ''))

    @combinations.generate(test_base.default_test_combinations())
    def testPaddedBatchEmptyTensors(self):
        dataset = (dataset_ops.Dataset.from_tensor_slices(
            [0, 0, 0, 0]).map(lambda x: array_ops.fill([x], x)).padded_batch(
                batch_size=4, padded_shapes=[-1]))
        self.assertDatasetProduces(dataset, expected_output=[[[], [], [], []]])

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(padding_values=[(-1, '<end>', {
                'structure': ''
            }), (-1, '<end>', None)])))
    def testPaddedBatchDatasetNonDefaultPadding(self, padding_values):
        def fill_tuple(x):
            filled = array_ops.fill([x], x)
            return (filled, string_ops.as_string(filled), {
                'structure': string_ops.as_string(filled)
            })

        random_seq_lens = np.random.randint(20, size=(32, )).astype(np.int32)
        dataset = (dataset_ops.Dataset.from_tensor_slices(random_seq_lens).map(
            fill_tuple).padded_batch(4,
                                     padded_shapes=([-1], [-1], {
                                         'structure': [-1]
                                     }),
                                     padding_values=padding_values))

        get_next = self.getNext(dataset)
        for i in range(8):
            result = self.evaluate(get_next())
            padded_len = np.max(result[0])
            self.assertEqual((4, padded_len), result[0].shape)
            self.assertEqual((4, padded_len), result[1].shape)
            self.assertEqual((4, padded_len), result[2]['structure'].shape)
            for j in range(4):
                seq_len = random_seq_lens[(i * 4) + j]
                self.assertAllEqual(result[0][j, :seq_len],
                                    [seq_len] * seq_len)
                self.assertAllEqual(result[0][j, seq_len:],
                                    [-1] * (padded_len - seq_len))
                self.assertAllEqual(result[1][j, :seq_len],
                                    [compat.as_bytes(str(seq_len))] * seq_len)
                self.assertAllEqual(result[1][j, seq_len:],
                                    [b'<end>'] * (padded_len - seq_len))
                self.assertAllEqual(result[2]['structure'][j, :seq_len],
                                    [compat.as_bytes(str(seq_len))] * seq_len)
                self.assertAllEqual(result[2]['structure'][j, seq_len:],
                                    [b''] * (padded_len - seq_len))
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())

    @combinations.generate(test_base.default_test_combinations())
    def testPaddedBatchDatasetUnicode(self):
        # See GitHub issue 16149
        def generator():
            data = [[u'Простой', u'тест', u'юникода'],
                    [u'никогда', u'не', u'бывает', u'простым']]

            for seq in data:
                yield seq, [0, 1, 2, 3]

        dataset = dataset_ops.Dataset.from_generator(
            generator, (dtypes.string, dtypes.int32),
            (tensor_shape.TensorShape([None]),
             tensor_shape.TensorShape([None])))
        padded_dataset = dataset.padded_batch(2,
                                              padded_shapes=([None], [None]),
                                              padding_values=('', 0))
        next_element = self.getNext(padded_dataset)
        self.evaluate(next_element())

    @combinations.generate(test_base.graph_only_combinations())
    def testPaddedBatchDatasetShapeSpecifications(self):
        int_placeholder = array_ops.placeholder(dtypes.int32)
        float_placeholder = array_ops.placeholder(dtypes.float32)
        string_placeholder = array_ops.placeholder(dtypes.string)
        input_dataset = dataset_ops.Dataset.from_tensors(
            (int_placeholder, float_placeholder, string_placeholder))

        # Test different ways of specifying the `padded_shapes` argument.
        dynamic_padding_from_tensor_shapes = input_dataset.padded_batch(
            32,
            padded_shapes=(tensor_shape.TensorShape([None]),
                           tensor_shape.TensorShape([None, None]),
                           tensor_shape.TensorShape([37])))
        dynamic_padding_from_lists = input_dataset.padded_batch(
            32, padded_shapes=([None], [None, None], [37]))
        dynamic_padding_from_lists_with_minus_one = input_dataset.padded_batch(
            32, padded_shapes=([-1], [-1, -1], [37]))
        dynamic_padding_from_tensors = input_dataset.padded_batch(
            32,
            padded_shapes=(constant_op.constant([-1], dtype=dtypes.int64),
                           constant_op.constant([-1, -1], dtype=dtypes.int64),
                           constant_op.constant([37], dtype=dtypes.int64)))

        for dataset in [
                dynamic_padding_from_tensor_shapes, dynamic_padding_from_lists,
                dynamic_padding_from_lists_with_minus_one,
                dynamic_padding_from_tensors
        ]:
            dataset_output_shapes = dataset_ops.get_legacy_output_shapes(
                dataset)
            self.assertEqual([None, None], dataset_output_shapes[0].as_list())
            self.assertEqual([None, None, None],
                             dataset_output_shapes[1].as_list())
            self.assertEqual([None, 37], dataset_output_shapes[2].as_list())

    @combinations.generate(test_base.default_test_combinations())
    def testPaddedBatchSparseError(self):
        def _map_fn(i):
            return sparse_tensor.SparseTensorValue(indices=[[0, 0]],
                                                   values=(i * [1]),
                                                   dense_shape=[1, 1]), i

        with self.assertRaises(TypeError):
            _ = dataset_ops.Dataset.range(10).map(_map_fn).padded_batch(10)

    @combinations.generate(test_base.default_test_combinations())
    def testPaddedBatchShapeError(self):
        with self.assertRaisesRegexp(
                ValueError,
                r'The padded shape \(1,\) is not compatible with the '
                r'corresponding input component shape \(\).'):
            _ = dataset_ops.Dataset.range(10).padded_batch(5,
                                                           padded_shapes=[1])

        with self.assertRaisesRegexp(
                ValueError,
                r'The padded shape \(1,\) is not compatible with the '
                r'corresponding input component shape \(3,\).'):
            _ = dataset_ops.Dataset.from_tensors([1, 2, 3]).padded_batch(
                5, padded_shapes=[1])

        with self.assertRaisesRegexp(
                ValueError, r'Padded shape .* must be a 1-D tensor '
                r'of tf.int64 values, but its shape was \(2, 2\).'):
            _ = dataset_ops.Dataset.from_tensors([1, 2, 3]).padded_batch(
                5, padded_shapes=[[1, 1], [1, 1]])

        with self.assertRaisesRegexp(
                TypeError, r'Padded shape .* must be a 1-D tensor '
                r'of tf.int64 values, but its element type was float32.'):
            _ = dataset_ops.Dataset.from_tensors([1, 2, 3]).padded_batch(
                5, padded_shapes=constant_op.constant([1.5, 2., 3.]))

        with self.assertRaisesRegexp(
                ValueError,
                r'The padded shape \(1,\) is not compatible with the '
                r'corresponding input component shape \(\).'):
            shape_as_tensor = constant_op.constant([1], dtype=dtypes.int64)
            _ = dataset_ops.Dataset.range(10).padded_batch(
                5, padded_shapes=shape_as_tensor)

    @combinations.generate(test_base.graph_only_combinations())
    def testPaddedBatchShapeErrorPlaceholder(self):
        with self.assertRaisesRegexp(
                ValueError,
                r'The padded shape \((\?|None), (\?|None)\) is not compatible with the '
                r'corresponding input component shape \(\).'):
            shape_as_tensor = array_ops.placeholder(dtypes.int64, shape=[2])
            _ = dataset_ops.Dataset.range(10).padded_batch(
                5, padded_shapes=shape_as_tensor)
Example #24
class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):

  @combinations.generate(test_base.default_test_combinations())
  def testOptimizationStatefulFunction(self):
    dataset = dataset_ops.Dataset.range(
        10).map(lambda _: random_ops.random_uniform([])).batch(10)
    options = dataset_ops.Options()
    options.experimental_optimization.apply_default_optimizations = False
    dataset = dataset.with_options(options)
    get_next = self.getNext(dataset)
    self.evaluate(get_next())

  # TODO(b/123902160)
  @combinations.generate(test_base.graph_only_combinations())
  def testOptimizationLargeInputFromTensor(self):
    input_t = array_ops.placeholder(dtypes.int32, (None, None, None))
    dataset = dataset_ops.Dataset.from_tensors(input_t)
    options = dataset_ops.Options()
    options.experimental_optimization.apply_default_optimizations = False
    dataset = dataset.with_options(options)
    iterator = dataset_ops.make_initializable_iterator(dataset)
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.cached_session() as sess:
      sess.run(init_op, {input_t: np.ones([512, 1024, 1025], np.int32)})
      self.evaluate(get_next)

  # TODO(b/123902160)
  @combinations.generate(test_base.graph_only_combinations())
  def testOptimizationLargeInputFromTensorSlices(self):
    input_t = array_ops.placeholder(dtypes.int32, (None, None, None, None))
    dataset = dataset_ops.Dataset.from_tensor_slices(input_t)
    options = dataset_ops.Options()
    options.experimental_optimization.apply_default_optimizations = False
    dataset = dataset.with_options(options)
    iterator = dataset_ops.make_initializable_iterator(dataset)
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.cached_session() as sess:
      sess.run(init_op, {input_t: np.ones([1, 512, 1024, 1025], np.int32)})
      self.evaluate(get_next)

  @combinations.generate(test_base.default_test_combinations())
  def testOptimizationNestedDataset(self):

    def flat_map_fn(_):
      dataset = dataset_ops.Dataset.from_tensors(0)
      dataset = dataset.apply(testing.assert_next(["MemoryCacheImpl"]))
      dataset = dataset.skip(0)  # Should be removed by noop elimination
      dataset = dataset.cache()
      return dataset

    dataset = dataset_ops.Dataset.range(1)
    dataset = dataset.flat_map(flat_map_fn)
    options = dataset_ops.Options()
    options.experimental_optimization.apply_default_optimizations = False
    options.experimental_optimization.noop_elimination = True
    dataset = dataset.with_options(options)
    self.assertDatasetProduces(dataset, expected_output=[0])

  @combinations.generate(test_base.default_test_combinations())
  def testOptimizationNestedDatasetWithModifiedRetval(self):

    def flat_map_fn(_):
      dataset = dataset_ops.Dataset.from_tensors(0)
      dataset = dataset.apply(testing.assert_next(["MapAndBatch"]))
      # Should be fused by map and batch fusion
      dataset = dataset.map(lambda x: x)
      dataset = dataset.batch(1)
      return dataset

    dataset = dataset_ops.Dataset.range(1)
    dataset = dataset.flat_map(flat_map_fn)

    options = dataset_ops.Options()
    options.experimental_optimization.apply_default_optimizations = False
    options.experimental_optimization.map_and_batch_fusion = True
    dataset = dataset.with_options(options)
    self.assertDatasetProduces(dataset, expected_output=[[0]])

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          _disable_intra_op_parallelism_test_combinations(),
          combinations.combine(apply_autotune=[None, True, False])))
  def testOptimizationDisableIntraOpParallelism(self, dataset_fn,
                                                expected_output,
                                                apply_autotune):
    dataset = dataset_fn()
    dataset = dataset.apply(testing.assert_next(["MaxIntraOpParallelism"]))
    if apply_autotune is not None:
      options = dataset_ops.Options()
      options.experimental_optimization.autotune = apply_autotune
      dataset = dataset.with_options(options)

    self.assertDatasetProduces(dataset, expected_output=expected_output)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(autotune=False, autotune_buffers=False) +
          combinations.combine(autotune=True, autotune_buffers=False) +
          combinations.combine(autotune=True, autotune_buffers=True),
          combinations.combine(first_buffer_sizes=[(1, -1, -1, 4),
                                                   (2, -1, 3, -1),
                                                   (2, 1, -1, -1)]),
          combinations.combine(second_buffer_sizes=[(1, -1, -1, 4),
                                                    (2, -1, 3, -1),
                                                    (2, 1, -1, -1)]))
  )
  def testOptimizationAutotuneBuffers(self, autotune, autotune_buffers,
                                      first_buffer_sizes, second_buffer_sizes):
    dataset = dataset_ops.Dataset.range(10)
    for buffer_size in first_buffer_sizes:
      dataset = dataset.prefetch(buffer_size=buffer_size)
    dataset = dataset.map(lambda x: x + 1)
    for buffer_size in second_buffer_sizes:
      dataset = dataset.prefetch(buffer_size=buffer_size)
    options = dataset_ops.Options()
    options.experimental_optimization.autotune = autotune
    options.experimental_optimization.autotune_buffers = autotune_buffers
    dataset = dataset.with_options(options)
    self.assertDatasetProduces(dataset, expected_output=list(range(1, 11)))

  @combinations.generate(test_base.default_test_combinations())
  def testOptimizationThreadPoolDataset(self):
    dataset = dataset_ops.Dataset.range(10).batch(10)

    dataset = threadpool.override_threadpool(
        dataset,
        threadpool.PrivateThreadPool(
            2, display_name="private_thread_pool_%d" % 2))

    options = dataset_ops.Options()
    options.experimental_optimization.apply_default_optimizations = False
    dataset = dataset.with_options(options)
    self.assertDatasetProduces(
        dataset,
        expected_output=[list(range(10))],
        requires_initialization=True)

  # Reference variables are not supported in eager mode.
  @combinations.generate(
      combinations.times(test_base.graph_only_combinations(),
                         _captured_refvar_test_combinations()))
  def testOptimizationWithCapturedRefVar(self, dataset_fn):
    """Tests that default optimizations are disabled with ref variables."""
    variable = variable_scope.get_variable(
        "v", initializer=0, use_resource=False)
    assign_op = variable.assign_add(1)

    # Check that warning is logged.
    warnings.simplefilter("always")
    with warnings.catch_warnings(record=True) as w:
      unoptimized_dataset = dataset_fn(variable)

      options = dataset_ops.Options()
      options.experimental_optimization.apply_default_optimizations = False
      options.experimental_optimization.noop_elimination = True
      options.experimental_optimization.map_and_batch_fusion = True
      optimized_dataset = unoptimized_dataset.with_options(options)
      optimized_it = dataset_ops.make_initializable_iterator(optimized_dataset)

    self.assertGreaterEqual(len(w), 1)
    graph_rewrites = options._graph_rewrites()
    expected = (
        "tf.data graph rewrites are not compatible with "
        "tf.Variable. The following rewrites will be disabled: %s."
        " To enable rewrites, use resource variables instead by "
        "calling `tf.enable_resource_variables()` at the start of the "
        "program." %
        (", ".join(graph_rewrites.enabled + graph_rewrites.default)))
    self.assertTrue(any(expected in str(warning) for warning in w))

    # Check that outputs are the same in the optimized and unoptimized cases,
    # when the variable value is changing.
    unoptimized_it = dataset_ops.make_initializable_iterator(
        unoptimized_dataset)
    with ops.control_dependencies([assign_op]):
      unoptimized_output = unoptimized_it.get_next()
      optimized_output = optimized_it.get_next()

    self.evaluate(variable.initializer)
    self.evaluate((unoptimized_it.initializer, optimized_it.initializer))
    while True:
      try:
        unoptimized, optimized = self.evaluate((unoptimized_output,
                                                optimized_output))
        self.assertEqual(unoptimized, optimized)
      except errors.OutOfRangeError:
        break

  @combinations.generate(test_base.default_test_combinations())
  def testOptimizationDefault(self):
    """Tests the optimization settings by default."""
    options = dataset_ops.Options()
    expected_optimizations_enabled = []
    expected_optimizations_disabled = []
    expected_optimizations_default = [
        "map_and_batch_fusion",
        "noop_elimination",
        "shuffle_and_repeat_fusion",
    ]
    graph_rewrites = options._graph_rewrites()
    self.assertEqual(set(graph_rewrites.enabled),
                     set(expected_optimizations_enabled))
    self.assertEqual(set(graph_rewrites.disabled),
                     set(expected_optimizations_disabled))
    self.assertEqual(set(graph_rewrites.default),
                     set(expected_optimizations_default))

    options.experimental_optimization.apply_default_optimizations = True
    graph_rewrites = options._graph_rewrites()
    self.assertEqual(set(graph_rewrites.enabled),
                     set(expected_optimizations_enabled))
    self.assertEqual(set(graph_rewrites.disabled),
                     set(expected_optimizations_disabled))
    self.assertEqual(set(graph_rewrites.default),
                     set(expected_optimizations_default))

    options.experimental_optimization.apply_default_optimizations = False
    expected_optimizations_default = []
    graph_rewrites = options._graph_rewrites()
    self.assertEqual(set(graph_rewrites.enabled),
                     set(expected_optimizations_enabled))
    self.assertEqual(set(graph_rewrites.disabled),
                     set(expected_optimizations_disabled))
    self.assertEqual(set(graph_rewrites.default),
                     set(expected_optimizations_default))

  @combinations.generate(test_base.default_test_combinations())
  def testOptimizationEnabled(self):
    """Tests the optimization settings by enabling all."""
    options = dataset_ops.Options()
    options.experimental_optimization.filter_fusion = True
    options.experimental_optimization.filter_with_random_uniform_fusion = True
    options.experimental_optimization.hoist_random_uniform = True
    options.experimental_optimization.map_and_batch_fusion = True
    options.experimental_optimization.map_and_filter_fusion = True
    options.experimental_optimization.map_parallelization = True
    options.experimental_optimization.map_fusion = True
    options.experimental_optimization.noop_elimination = True
    options.experimental_optimization.parallel_batch = True
    options.experimental_optimization.shuffle_and_repeat_fusion = True
    options.experimental_optimization.map_vectorization.enabled = True
    options.experimental_optimization.autotune_buffers = True
    options.experimental_deterministic = False
    options.experimental_stats.latency_all_edges = True
    options.experimental_slack = True

    expected_optimizations_enabled = [
        "filter_fusion",
        "filter_with_random_uniform_fusion",
        "hoist_random_uniform",
        "map_and_batch_fusion",
        "map_and_filter_fusion",
        "map_parallelization",
        "map_fusion",
        "noop_elimination",
        "parallel_batch",
        "shuffle_and_repeat_fusion",
        "map_vectorization",
        "autotune_buffer_sizes",
        "make_sloppy",
        "latency_all_edges",
        "slack",
        "disable_prefetch_legacy_autotune",
    ]
    expected_optimizations_disabled = []
    expected_optimizations_default = []
    graph_rewrites = options._graph_rewrites()
    self.assertEqual(set(graph_rewrites.enabled),
                     set(expected_optimizations_enabled))
    self.assertEqual(set(graph_rewrites.disabled),
                     set(expected_optimizations_disabled))
    self.assertEqual(set(graph_rewrites.default),
                     set(expected_optimizations_default))

  @combinations.generate(test_base.default_test_combinations())
  def testOptimizationDisabled(self):
    """Tests the optimization settings by disabling all."""
    options = dataset_ops.Options()
    options.experimental_optimization.filter_fusion = False
    options.experimental_optimization.filter_with_random_uniform_fusion = False
    options.experimental_optimization.hoist_random_uniform = False
    options.experimental_optimization.map_and_batch_fusion = False
    options.experimental_optimization.map_and_filter_fusion = False
    options.experimental_optimization.map_parallelization = False
    options.experimental_optimization.map_fusion = False
    options.experimental_optimization.noop_elimination = False
    options.experimental_optimization.parallel_batch = False
    options.experimental_optimization.shuffle_and_repeat_fusion = False
    options.experimental_optimization.map_vectorization.enabled = False
    options.experimental_optimization.autotune = False
    options.experimental_deterministic = True
    options.experimental_stats.latency_all_edges = False
    options.experimental_slack = False

    expected_optimizations_enabled = []
    expected_optimizations_disabled = [
        "filter_fusion",
        "filter_with_random_uniform_fusion",
        "hoist_random_uniform",
        "map_and_batch_fusion",
        "map_and_filter_fusion",
        "map_parallelization",
        "map_fusion",
        "noop_elimination",
        "parallel_batch",
        "shuffle_and_repeat_fusion",
        "map_vectorization",
        "autotune_buffer_sizes",
        "make_sloppy",
        "latency_all_edges",
        "slack",
        "disable_prefetch_legacy_autotune",
    ]
    expected_optimizations_default = []
    graph_rewrites = options._graph_rewrites()
    self.assertEqual(set(graph_rewrites.enabled),
                     set(expected_optimizations_enabled))
    self.assertEqual(set(graph_rewrites.disabled),
                     set(expected_optimizations_disabled))
    self.assertEqual(set(graph_rewrites.default),
                     set(expected_optimizations_default))

  @combinations.generate(test_base.default_test_combinations())
  def testAutotuningDefaults(self):
    options = dataset_ops.Options()

    # Check defaults
    autotune, algorithm, cpu_budget, ram_budget = options._autotune_settings()
    self.assertTrue(autotune)
    self.assertEqual(algorithm,
                     optimization_options._AutotuneAlgorithm.HILL_CLIMB)
    self.assertEqual(cpu_budget, 0)
    self.assertEqual(ram_budget, 0)

  @combinations.generate(test_base.default_test_combinations())
  def testAutotuningSettings(self):
    options = dataset_ops.Options()
    options.experimental_optimization.autotune_cpu_budget = 1000
    options.experimental_optimization.autotune_ram_budget = 999999999
    options.experimental_optimization.autotune_buffers = True
    self.assertIn("autotune_buffer_sizes", options._graph_rewrites().enabled)
    self.assertIn("disable_prefetch_legacy_autotune",
                  options._graph_rewrites().enabled)

    autotune, algorithm, cpu_budget, ram_budget = options._autotune_settings()
    self.assertTrue(autotune)
    self.assertEqual(algorithm,
                     optimization_options._AutotuneAlgorithm.GRADIENT_DESCENT)
    self.assertEqual(cpu_budget, 1000)
    self.assertEqual(ram_budget, 999999999)

class BucketBySequenceLengthTest(test_base.DatasetTestBase,
                                 parameterized.TestCase):
    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(param_no_padding=[True, False])))
    def testBucketDropRemainder(self, param_no_padding):

        boundaries = [10, 20, 30]
        batch_sizes = [10, 8, 4, 2]
        lengths = [8, 13, 25, 35]

        n_bucket_elements = [28, 7, 6, 5]
        n_expected_batches = 5

        # Expected sequence lengths of the individual batches.
        expected_lengths = []

        # Expected sum of all batches with an equal sequence length.
        # <seq-length>: <expected-total-sum>
        expected_sums = {}

        # Expected batch sizes of batches depending on the sequence length.
        # <seq-length>: [batch1_size, ..., batchN_size]
        expected_batch_sizes = {}

        for length, batch_size, bucket_elements in zip(lengths, batch_sizes,
                                                       n_bucket_elements):
            # Calculate the expected sum across all batches of a specific sequence
            # length.
            expected_sums[length] = \
                (bucket_elements - bucket_elements % batch_size) * length
            # Calculate the expected occurrence of individual batch sizes.
            expected_batch_sizes[length] = \
                [batch_size] * (bucket_elements // batch_size)
            # Calculate the expected occurrence of individual sequence lengths.
            expected_lengths.extend([length] * (bucket_elements // batch_size))
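        # Illustrative spot check (not part of the original test): length 8 has
        # 28 bucket elements and batch size 10, so drop_remainder keeps
        # 2 * 10 = 20 elements, giving expected_sums[8] == 20 * 8 == 160 and
        # expected_batch_sizes[8] == [10, 10].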

        def build_dataset(sparse):
            def _generator():
                # Produce `n_bucket_elements[i]` full-length sequences for each bucket.
                elements = []
                for bucket_elements, length in zip(n_bucket_elements, lengths):
                    # Using only full sequences (as opposed to the strategy employed
                    # in `testBucket`) makes checking the sum a lot easier.
                    record_len = length
                    for _ in range(bucket_elements):
                        elements.append([1] * record_len)
                random.shuffle(elements)
                for el in elements:
                    yield (_format_record(el, sparse), )

            dataset = dataset_ops.Dataset.from_generator(
                _generator, (_get_record_type(sparse), ),
                (_get_record_shape(sparse), ))
            if sparse:
                dataset = dataset.map(lambda x: (_to_sparse_tensor(x), ))
            return dataset

        def _test_bucket_by_padding(no_padding):
            dataset = build_dataset(sparse=no_padding)
            dataset = dataset.bucket_by_sequence_length(
                element_length_func=_element_length_fn,
                bucket_boundaries=boundaries,
                bucket_batch_sizes=batch_sizes,
                no_padding=no_padding,
                drop_remainder=True)

            get_next = self.getNext(dataset)
            batches = []
            for _ in range(n_expected_batches):
                batch, = self.evaluate(get_next())
                batches.append(batch)

            with self.assertRaises(errors.OutOfRangeError):
                self.evaluate(get_next())

            generated_lengths = []

            # <seq-length>: <total-sum>
            generated_sums = {}

            # <seq-length>: [<batch_size>, ...]
            generated_batch_sizes = {}

            for length, batch_size, bucket_elements in zip(
                    lengths, batch_sizes, n_bucket_elements):
                # Initialize the sum across all batches.
                generated_sums[length] = 0
                # Initialize the individual batch sizes.
                generated_batch_sizes[length] = []

            for batch in batches:
                shape = batch.dense_shape if no_padding else batch.shape
                length = shape[1]
                generated_lengths.append(length)

                batch_size = shape[0]
                generated_batch_sizes[length].append(batch_size)

                batch_sum = batch.values.sum() if no_padding else batch.sum()
                generated_sums[length] += batch_sum

            for l in lengths:
                # Make sure the sum of the batch contents is correct for the individual
                # sequence lengths.
                self.assertEqual(
                    generated_sums[l], expected_sums[l],
                    "Tensor sums did not match! "
                    "expected: {}, generated: {}".format(
                        expected_sums, generated_sums))

                # Make sure the individual batch sizes are generated as expected.
                self.assertEqual(
                    sorted(generated_batch_sizes[l]),
                    sorted(expected_batch_sizes[l]),
                    "Batch-sizes did not match! "
                    "expected: {}, generated: {}".format(
                        sorted(expected_batch_sizes[l]),
                        sorted(generated_batch_sizes[l])))

            # Make sure the generated sequence lengths appear as often as expected.
            self.assertEqual(
                sorted(generated_lengths), sorted(expected_lengths),
                "The generated sequence lengths did not match! "
                "expected: {}, generated: {}".format(
                    sorted(expected_lengths), sorted(generated_lengths)))

        _test_bucket_by_padding(param_no_padding)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(param_no_padding=[True, False])))
    def testBucket(self, param_no_padding):

        boundaries = [10, 20, 30]
        batch_sizes = [10, 8, 4, 2]
        lengths = [8, 13, 25, 35]

        def build_dataset(sparse):
            def _generator():
                # Produce 1 batch for each bucket
                elements = []
                for batch_size, length in zip(batch_sizes, lengths):
                    record_len = length - 1
                    for _ in range(batch_size):
                        elements.append([1] * record_len)
                        record_len = length
                random.shuffle(elements)
                for el in elements:
                    yield (_format_record(el, sparse), )

            dataset = dataset_ops.Dataset.from_generator(
                _generator, (_get_record_type(sparse), ),
                (_get_record_shape(sparse), ))
            if sparse:
                dataset = dataset.map(lambda x: (_to_sparse_tensor(x), ))
            return dataset

        def _test_bucket_by_padding(no_padding):
            dataset = build_dataset(sparse=no_padding)
            dataset = dataset.bucket_by_sequence_length(
                element_length_func=_element_length_fn,
                bucket_boundaries=boundaries,
                bucket_batch_sizes=batch_sizes,
                no_padding=no_padding)
            get_next = self.getNext(dataset)
            batches = []
            for _ in range(4):
                batch, = self.evaluate(get_next())
                batches.append(batch)
            with self.assertRaises(errors.OutOfRangeError):
                self.evaluate(get_next())

            batch_sizes_val = []
            lengths_val = []
            for batch in batches:
                shape = batch.dense_shape if no_padding else batch.shape
                batch_size = shape[0]
                length = shape[1]
                batch_sizes_val.append(batch_size)
                lengths_val.append(length)
                if not context.executing_eagerly():
                    sum_check = batch.values.sum(
                    ) if no_padding else batch.sum()
                    self.assertEqual(sum_check, batch_size * length - 1)
            self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
            self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
            self.assertEqual(sorted(lengths), sorted(lengths_val))

        _test_bucket_by_padding(param_no_padding)

    @combinations.generate(test_base.default_test_combinations())
    def testPadToBoundary(self):

        boundaries = [10, 20, 30]
        batch_sizes = [10, 8, 4, 2]
        lengths = [8, 13, 25]

        def element_gen():
            # Produce 1 batch for each bucket
            elements = []
            for batch_size, length in zip(batch_sizes[:-1], lengths):
                for _ in range(batch_size):
                    elements.append([1] * length)
            random.shuffle(elements)
            for el in elements:
                yield (el, )
            for _ in range(batch_sizes[-1]):
                el = [1] * (boundaries[-1] + 5)
                yield (el, )

        element_len = lambda el: array_ops.shape(el)[0]
        dataset = dataset_ops.Dataset.from_generator(element_gen,
                                                     (dtypes.int64, ),
                                                     ([None], ))
        dataset = dataset.bucket_by_sequence_length(
            element_length_func=element_len,
            bucket_boundaries=boundaries,
            bucket_batch_sizes=batch_sizes,
            pad_to_bucket_boundary=True)
        get_next = self.getNext(dataset)

        batches = []
        for _ in range(3):
            batch, = self.evaluate(get_next())
            batches.append(batch)
        with self.assertRaisesOpError("bucket_boundaries"):
            self.evaluate(get_next())

        batch_sizes_val = []
        lengths_val = []
        for batch in batches:
            batch_size = batch.shape[0]
            length = batch.shape[1]
            batch_sizes_val.append(batch_size)
            lengths_val.append(length)
        batch_sizes = batch_sizes[:-1]
        self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
        self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
        self.assertEqual([boundary - 1 for boundary in sorted(boundaries)],
                         sorted(lengths_val))

    @combinations.generate(test_base.default_test_combinations())
    def testPadToBoundaryNoExtraneousPadding(self):

        boundaries = [3, 7, 11]
        batch_sizes = [2, 2, 2, 2]
        lengths = range(1, 11)

        def element_gen():
            for length in lengths:
                yield ([1] * length, )

        element_len = lambda element: array_ops.shape(element)[0]
        dataset = dataset_ops.Dataset.from_generator(element_gen,
                                                     (dtypes.int64, ),
                                                     ([None], ))
        dataset = dataset.bucket_by_sequence_length(
            element_length_func=element_len,
            bucket_boundaries=boundaries,
            bucket_batch_sizes=batch_sizes,
            pad_to_bucket_boundary=True)
        get_next = self.getNext(dataset)

        batches = []
        for _ in range(5):
            batch, = self.evaluate(get_next())
            batches.append(batch)
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())

        self.assertAllEqual(batches[0], [[1, 0], [1, 1]])
        self.assertAllEqual(batches[1],
                            [[1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 0, 0]])
        self.assertAllEqual(batches[2],
                            [[1, 1, 1, 1, 1, 0], [1, 1, 1, 1, 1, 1]])
        self.assertAllEqual(
            batches[3],
            [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
        self.assertAllEqual(
            batches[4],
            [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(param_no_padding=[True, False])))
    def testTupleElements(self, param_no_padding):
        def build_dataset(sparse):
            def _generator():
                text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]]
                label = [1, 2, 1, 2]
                for x, y in zip(text, label):
                    yield (_format_record(x, sparse), y)

            dataset = dataset_ops.Dataset.from_generator(
                generator=_generator,
                output_types=(_get_record_type(sparse), dtypes.int32),
                output_shapes=(_get_record_shape(sparse),
                               tensor_shape.TensorShape([])))
            if sparse:
                dataset = dataset.map(lambda x, y: (_to_sparse_tensor(x), y))
            return dataset

        def _test_tuple_elements_by_padding(no_padding):
            dataset = build_dataset(sparse=no_padding)
            dataset = dataset.bucket_by_sequence_length(
                element_length_func=_element_length_fn,
                bucket_batch_sizes=[2, 2, 2],
                bucket_boundaries=[0, 8],
                no_padding=no_padding)
            shapes = dataset_ops.get_legacy_output_shapes(dataset)
            self.assertEqual([None, None], shapes[0].as_list())
            self.assertEqual([None], shapes[1].as_list())

        _test_tuple_elements_by_padding(param_no_padding)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(param_drop_remainder=[True, False])))
    def testBucketSparse(self, param_drop_remainder):  # pylint: disable=g-doc-args
        """Tests bucketing of sparse tensors (case where `no_padding` == True).

    The test runs on the following dataset:
      [
        [0],
        [0, 1],
        [0, 1, 2],
        ...
        [0, ..., max_len - 1]
      ]
    Sequences are bucketed by length and batched with
      `batch_size` < `bucket_size`.
    """

        min_len = 0
        max_len = 100
        batch_size = 7
        bucket_size = 10

        def _build_dataset():
            input_data = [range(i + 1) for i in range(min_len, max_len)]

            def generator_fn():
                for record in input_data:
                    yield _format_record(record, sparse=True)

            dataset = dataset_ops.Dataset.from_generator(
                generator=generator_fn,
                output_types=_get_record_type(sparse=True))
            dataset = dataset.map(_to_sparse_tensor)
            return dataset

        def _compute_expected_batches(drop_remainder):
            """Computes expected batch outputs and stores in a set."""
            all_expected_sparse_tensors = set()
            for bucket_start_len in range(min_len, max_len, bucket_size):
                if drop_remainder:
                    batch_offsets = [0]
                else:
                    batch_offsets = range(0, bucket_size, batch_size)

                for batch_offset in batch_offsets:
                    batch_start_len = bucket_start_len + batch_offset
                    batch_end_len = min(batch_start_len + batch_size,
                                        bucket_start_len + bucket_size)
                    expected_indices = []
                    expected_values = []
                    for length in range(batch_start_len, batch_end_len):
                        for val in range(length + 1):
                            expected_indices.append(
                                (length - batch_start_len, val))
                            expected_values.append(val)
                    expected_sprs_tensor = (tuple(expected_indices),
                                            tuple(expected_values))
                    all_expected_sparse_tensors.add(expected_sprs_tensor)
            return all_expected_sparse_tensors

        def _compute_batches(dataset):
            """Computes actual batch outputs of dataset and stores in a set."""
            batch = self.getNext(dataset)
            all_sparse_tensors = set()
            with self.assertRaises(errors.OutOfRangeError):
                while True:
                    output = self.evaluate(batch())
                    sprs_tensor = (tuple([
                        tuple(idx) for idx in output.indices
                    ]), tuple(output.values))
                    all_sparse_tensors.add(sprs_tensor)

            return all_sparse_tensors

        dataset = _build_dataset()
        boundaries = range(min_len + bucket_size + 1, max_len, bucket_size)
        dataset = dataset.bucket_by_sequence_length(
            element_length_func=_element_length_fn,
            bucket_boundaries=boundaries,
            bucket_batch_sizes=[batch_size] * (len(boundaries) + 1),
            no_padding=True,
            drop_remainder=param_drop_remainder)
        batches = _compute_batches(dataset)
        expected_batches = _compute_expected_batches(param_drop_remainder)
        self.assertEqual(batches, expected_batches)

    @combinations.generate(test_base.default_test_combinations())
    def testCardinality(self):

        boundaries = [3, 7, 11]
        batch_sizes = [2, 2, 2, 2]
        lengths = range(1, 11)

        def element_gen():
            for length in lengths:
                yield ([1] * length, )

        element_len = lambda element: array_ops.shape(element)[0]
        dataset = dataset_ops.Dataset.from_generator(element_gen,
                                                     (dtypes.int64, ),
                                                     ([None], )).repeat()
        dataset = dataset.bucket_by_sequence_length(
            element_length_func=element_len,
            bucket_boundaries=boundaries,
            bucket_batch_sizes=batch_sizes,
            pad_to_bucket_boundary=True)
        self.assertEqual(self.evaluate(dataset.cardinality()),
                         dataset_ops.INFINITE)
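# For reference, a minimal standalone sketch of the bucketing pattern these
# tests exercise, written against the public API. Assumptions: TF 2.6+ (where
# `bucket_by_sequence_length` is a `tf.data.Dataset` method) running eagerly;
# the helper and variable names below are illustrative only.
import tensorflow as tf

def _variable_length_gen():
    # Sequences of lengths 1..10, mirroring the generators used above.
    for length in range(1, 11):
        yield [1] * length

sketch_ds = tf.data.Dataset.from_generator(
    _variable_length_gen, output_signature=tf.TensorSpec([None], tf.int64))
sketch_ds = sketch_ds.bucket_by_sequence_length(
    element_length_func=lambda elem: tf.shape(elem)[0],
    bucket_boundaries=[3, 7, 11],
    bucket_batch_sizes=[2, 2, 2, 2],
    pad_to_bucket_boundary=True)
for padded_batch in sketch_ds:
    print(padded_batch.shape)  # second dim is the matching bucket boundary - 1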
Example #26
0
def reduce_fn(x, y):
  # Folds a (name, dataset_fn) pair into the running combinations list,
  # wrapping dataset_fn in a NamedObject so the generated test names stay
  # readable.
  name, dataset_fn = y
  return x + combinations.combine(
      dataset_fn=combinations.NamedObject(name, dataset_fn))
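# A hedged usage sketch for the reducer above, assuming the `combinations` and
# `dataset_ops` modules imported elsewhere in these examples.
# `named_dataset_fns` below is an illustrative list of (name, dataset_fn)
# pairs; folding it with functools.reduce yields a single combinations list in
# which each dataset_fn carries a readable name for the generated test cases.
import functools

named_dataset_fns = [
    ("Range", lambda: dataset_ops.Dataset.range(10)),
    ("RangeRepeated", lambda: dataset_ops.Dataset.range(10).repeat(2)),
]
dataset_fn_combinations = functools.reduce(reduce_fn, named_dataset_fns, [])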
Example #27
0
class ReplicateClusterTest(test_base.DatasetTestBase, parameterized.TestCase):
    def setUp(self):
        super(ReplicateClusterTest, self).setUp()
        # Start a local cluster with three worker tasks.
        worker_config = config_pb2.ConfigProto()
        worker_config.device_count["CPU"] = 2
        worker, _ = test_util.create_local_cluster(3,
                                                   0,
                                                   worker_config=worker_config)
        self._device0 = "/job:worker/replica:0/task:0/device:CPU:0"
        self._device1 = "/job:worker/replica:0/task:1/device:CPU:0"
        self._device2 = "/job:worker/replica:0/task:2/device:CPU:0"
        self._target = worker[0].target

    @combinations.generate(
        combinations.combine(tf_api_version=[1], mode=["graph"]))
    def testBasic(self):
        with ops.device(self._device0):
            dataset0 = dataset_ops.Dataset.range(100)
        replicated_ds = distribute.replicate(dataset0,
                                             [self._device1, self._device2])
        dataset1 = replicated_ds[self._device1]
        dataset2 = replicated_ds[self._device2]
        with ops.device(self._device0):
            get_next = self.getNext(dataset0)
        with ops.device(self._device1):
            get_next1 = self.getNext(dataset1)
        with ops.device(self._device2):
            get_next2 = self.getNext(dataset2)

        with session.Session(self._target) as sess:
            for i in range(100):
                self.assertEqual(i, sess.run(get_next()))
                self.assertEqual(i, sess.run(get_next1()))
                self.assertEqual(i, sess.run(get_next2()))

    @combinations.generate(
        combinations.combine(tf_api_version=[1], mode=["graph"]))
    def testMap(self):
        with ops.device(self._device0):
            dataset0 = dataset_ops.Dataset.range(100).map(lambda x: x * 2)
        replicated_ds = distribute.replicate(dataset0,
                                             [self._device1, self._device2])
        dataset1 = replicated_ds[self._device1]
        dataset2 = replicated_ds[self._device2]
        with ops.device(self._device0):
            get_next = self.getNext(dataset0)
        with ops.device(self._device1):
            get_next1 = self.getNext(dataset1)
        with ops.device(self._device2):
            get_next2 = self.getNext(dataset2)

        with session.Session(self._target) as sess:
            for i in range(100):
                self.assertEqual(i * 2, sess.run(get_next()))
                self.assertEqual(i * 2, sess.run(get_next1()))
                self.assertEqual(i * 2, sess.run(get_next2()))

    @combinations.generate(
        combinations.combine(tf_api_version=[1], mode=["graph"]))
    def testVariableInput(self):
        with ops.device(self._device0):
            counter_var = variable_scope.get_variable("counter", (),
                                                      dtypes.int32,
                                                      use_resource=True)
            dataset0 = dataset_ops.Dataset.range(100).map(
                lambda _: counter_var.assign_add(1))
        replicated_ds = distribute.replicate(dataset0,
                                             [self._device1, self._device2])
        dataset1 = replicated_ds[self._device1]
        with ops.device(self._device1):
            it1 = dataset_ops.make_initializable_iterator(dataset1)
        # We don't support stateful ops across processes in functions as of now.
        with session.Session(self._target) as sess:
            with self.assertRaises(errors.OpError):
                sess.run(it1.initializer)
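# The replicate pattern above, condensed into a hedged sketch. Note that
# `distribute.replicate` is the internal helper from
# tensorflow.python.data.experimental.ops used by this test, not a stable
# public API, and the device strings are placeholders for a real cluster.
with ops.device("/job:worker/replica:0/task:0/device:CPU:0"):
    source_ds = dataset_ops.Dataset.range(100)
per_device = distribute.replicate(
    source_ds,
    ["/job:worker/replica:0/task:1/device:CPU:0",
     "/job:worker/replica:0/task:2/device:CPU:0"])
# `per_device` maps each requested device string to a dataset that replays the
# same input pipeline on that device; iterate each entry inside a session
# targeting the cluster, as testBasic does above.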
Example #28
0
class TakeWhileTest(test_base.DatasetTestBase, parameterized.TestCase):
    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(num_elements=[14, 15], window_size=[2]) +
            combinations.combine(num_elements=[100], window_size=[3])))
    def testTakeWhileDataset(self, num_elements, window_size):
        def _predicate_func(elem):
            return array_ops.shape(elem)[0] > (window_size - 1)

        dataset = dataset_ops.Dataset.range(num_elements).batch(window_size)
        dataset = dataset.take_while(predicate=_predicate_func).flat_map(
            dataset_ops.Dataset.from_tensor_slices)

        expected_num_elements = (num_elements // window_size) * window_size
        self.assertDatasetProduces(dataset, np.arange(expected_num_elements))

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(num_elements=[10], upper_bound=[2]) +
            combinations.combine(num_elements=[16], upper_bound=[7]) +
            combinations.combine(num_elements=[100], upper_bound=[99]) +
            combinations.combine(num_elements=[100], upper_bound=[101]) +
            combinations.combine(num_elements=[0], upper_bound=[1])))
    def testTakeWhileDatasetRange(self, num_elements, upper_bound):
        dataset = dataset_ops.Dataset.range(num_elements).take_while(
            lambda x: x < upper_bound)

        self.assertDatasetProduces(dataset,
                                   np.arange(min(num_elements, upper_bound)))

    @combinations.generate(test_base.default_test_combinations())
    def testTakeWhileDatasetString(self):
        def not_equal(string):
            return lambda x: math_ops.not_equal(x, constant_op.constant(string))

        string = ["this", "is", "the", "test", "for", "strings"]
        dataset = dataset_ops.Dataset.from_tensor_slices(string).take_while(
            predicate=not_equal("test"))

        next_element = self.getNext(dataset)
        self.assertEqual(b"this", self.evaluate(next_element()))
        self.assertEqual(b"is", self.evaluate(next_element()))
        self.assertEqual(b"the", self.evaluate(next_element()))

        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(next_element())

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(size=[5], index=[3]) +
            combinations.combine(size=[10], index=[0]) +
            combinations.combine(size=[100], index=[5]) +
            combinations.combine(size=[8], index=[7])))
    def testTakewhileDatasetShortCircuit(self, size, index):
        def _predicate_func(data_elem):
            return data_elem

        boolean_array = [True] * size
        boolean_array[index] = False
        dataset = dataset_ops.Dataset.from_tensor_slices(
            boolean_array).take_while(predicate=_predicate_func)

        next_element = self.getNext(dataset)

        for _ in range(index):
            self.assertTrue(self.evaluate(next_element()))

        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(next_element())

    @combinations.generate(test_base.default_test_combinations())
    def testTakeWhileDatasetWithRepeat(self):
        dataset = dataset_ops.Dataset.range(10).take_while(
            predicate=lambda x: x < 2).repeat(5)
        self.assertDatasetProduces(dataset, np.tile([0, 1], 5))

    @combinations.generate(test_base.default_test_combinations())
    def testTakeWhileDatasetStops(self):
        dataset = dataset_ops.Dataset.range(10)
        dataset = dataset.take_while(
            lambda x: math_ops.logical_not(math_ops.equal(x, 5)))
        self.assertDatasetProduces(dataset, range(5))
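# A minimal standalone sketch of `take_while` against the public API.
# Assumption: TF 2.6+, where `take_while` is a `tf.data.Dataset` method;
# earlier 2.x releases expose the same behavior via
# `tf.data.experimental.take_while`.
import tensorflow as tf

take_while_ds = tf.data.Dataset.range(10).take_while(lambda x: x < 5)
print(list(take_while_ds.as_numpy_iterator()))  # [0, 1, 2, 3, 4]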
Example #29
0
class ClusterCombinationTest(test.TestCase, parameterized.TestCase):
  # For this test we need to use `framework.test_combinations` because our
  # `generate` eats the cluster parameters.
  #
  # Note that we don't have a standalone combination for ClusterParameters, so
  # we should use GPUCombination which contains it.

  @framework_combinations.generate(  # pylint: disable=redundant-keyword-arg
      framework_combinations.combine(distribution=[
          combinations.NamedDistribution(
              "HasClusterParams", lambda: None, has_chief=True, num_workers=2),
      ]),
      test_combinations=(combinations.ClusterCombination(),))
  def testClusterParams(self, distribution, has_chief, num_workers):
    self.assertTrue(has_chief)
    self.assertEqual(num_workers, 2)

  @framework_combinations.generate(  # pylint: disable=redundant-keyword-arg
      framework_combinations.combine(distribution=[
          combinations.NamedDistribution("NoClusterParams", lambda: None),
      ]),
      test_combinations=(combinations.ClusterCombination(),))
  def testClusterParamsHasDefault(self, distribution, has_chief, num_workers):
    self.assertFalse(has_chief)
    self.assertEqual(num_workers, 1)

  @framework_combinations.generate(  # pylint: disable=redundant-keyword-arg
      framework_combinations.combine(v=1),
      test_combinations=(combinations.ClusterCombination(),))
  def testClusterParamsNoStrategy(self, v, has_chief, num_workers):
    self.assertFalse(has_chief)
    self.assertEqual(num_workers, 1)

  @framework_combinations.generate(  # pylint: disable=redundant-keyword-arg
      framework_combinations.combine(distribution=[
          combinations.NamedDistribution(
              "WithClusterParams", lambda: None, has_chief=True, num_workers=2),
          combinations.NamedDistribution("WithoutClusterParams", lambda: None),
      ]),
      test_combinations=(combinations.ClusterCombination(),))
  def testClusterParamsAreOptional(self, distribution):
    # If combinations library doesn't raise an exception, the test is passed.
    pass

  @framework_combinations.generate(  # pylint: disable=redundant-keyword-arg
      framework_combinations.combine(
          ds1=combinations.NamedDistribution(
              "Strategy1", lambda: None, has_chief=True, num_workers=0),
          ds2=combinations.NamedDistribution(
              "Strategy2", lambda: None, has_chief=False, num_workers=1),
          ds3=combinations.NamedDistribution(
              "Strategy3", lambda: None, has_chief=True, num_workers=0),
      ),
      test_combinations=(combinations.ClusterCombination(),))
  def testMultipleDistributionSingleWorker(self, ds1, ds2, ds3):
    # If combinations library doesn't raise an exception, the test is passed.
    pass

  @combinations.generate(combinations.combine(num_workers=2,))
  def testUseWithoutStrategy(self):
    # There's no perfect way to check whether the test runs in a subprocess.
    # We approximate it by checking for TF_CONFIG, which is normally not set
    # in the main process.
    self.assertNotEqual(os.getenv("TF_CONFIG"), "")
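# For orientation, a hedged sketch of what `combine` produces: a list of
# parameter dictionaries, one per combination in the cartesian product, which
# `generate` then expands into individual test cases. The exact container
# types may differ between TF versions.
params = framework_combinations.combine(mode=["graph", "eager"], v=[1, 2])
# params is roughly:
#   [{"mode": "graph", "v": 1}, {"mode": "graph", "v": 2},
#    {"mode": "eager", "v": 1}, {"mode": "eager", "v": 2}]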
Example #30
0
class RaggedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(make_dataset=[
                _make_scalar_ds,
                _make_vector_ds,
                _make_matrix_ds1,
                _make_matrix_ds2,
                _make_ragged_ds,
                _make_5dtensor_ds,
                _make_dict_ds,
                _make_tuple_ds,
                _make_matrix_ds_fully_defined,
            ],
                                 nrows=[0, 20, 23],
                                 batch_size=[4],
                                 drop_remainder=[True, False])))
    def testBasic(self, make_dataset, nrows, batch_size, drop_remainder):
        dataset = make_dataset(nrows)

        # Get the unbatched rows (so we can check expected values).
        get_next = self.getNext(dataset)
        rows = [
            nest.map_structure(_to_list, self.evaluate(get_next()))
            for _ in range(nrows)
        ]

        # Batch the dataset, and check that batches match slices from `rows`.
        batched_dataset = dataset.apply(
            batching.dense_to_ragged_batch(batch_size, drop_remainder))
        get_next = self.getNext(batched_dataset)
        for start_row in range(0, nrows, batch_size):
            end_row = start_row + batch_size
            if end_row > nrows and drop_remainder:
                break
            end_row = min(end_row, nrows)
            result = self.evaluate(get_next())

            # Use nest for potentially nested datasets.
            nest.map_structure_up_to(
                result, lambda a, *b: self.assertAllEqual(a, list(b)), result,
                *rows[start_row:end_row])

        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())

    @combinations.generate(test_base.default_test_combinations())
    def testWithStructuredElements(self):
        nrows = 20
        batch_size = 4

        def make_structure(x):
            return {
                'dense': array_ops.fill([x], x),
                'ragged': ragged_concat_ops.stack(
                    [array_ops.stack([x]), array_ops.stack([x, x])]),
                'sparse': sparse_tensor.SparseTensor([[x]], [x], [100]),
            }

        dataset = dataset_ops.Dataset.from_tensor_slices(np.arange(nrows))
        dataset = dataset.map(make_structure)
        dataset = dataset.apply(batching.dense_to_ragged_batch(batch_size))
        get_next = self.getNext(dataset)

        for i in range(0, nrows, batch_size):
            result = self.evaluate(get_next())
            rows = range(i, i + batch_size)
            self.assertAllEqual(result['dense'], [[r] * r for r in rows])
            self.assertAllEqual(result['ragged'],
                                [[[r], [r, r]] for r in rows])
            self.assertAllEqual(result['sparse'].indices,
                                list(enumerate(rows)))
            self.assertAllEqual(result['sparse'].values, rows)
            self.assertAllEqual(result['sparse'].dense_shape, [4, 100])
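        # A hedged aside: outside this test harness the same ragged batching
        # is available publicly (assuming TF 2.2+) as
        # `tf.data.experimental.dense_to_ragged_batch`, e.g.:
        #
        #   ds = tf.data.Dataset.range(1, 6).map(lambda x: tf.fill([x], x))
        #   ds = ds.apply(
        #       tf.data.experimental.dense_to_ragged_batch(batch_size=2))
        #   # each element is now a tf.RaggedTensor of up to 2 variable-length
        #   # rows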