Example #1
class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase):
    @combinations.generate(test_base.default_test_combinations())
    def testOptionsDefault(self):
        ds = dataset_ops.Dataset.range(0)
        self.assertEqual(dataset_ops.Options(), ds.options())

    @combinations.generate(test_base.default_test_combinations())
    def testOptionsOnce(self):
        options = dataset_ops.Options()
        ds = dataset_ops.Dataset.range(0).with_options(options).cache()
        self.assertEqual(options, ds.options())

    @combinations.generate(test_base.default_test_combinations())
    def testOptionsTwiceSame(self):
        options = dataset_ops.Options()
        options.experimental_optimization.autotune = True
        ds = dataset_ops.Dataset.range(0).with_options(options).with_options(
            options)
        self.assertEqual(options, ds.options())

    @combinations.generate(test_base.default_test_combinations())
    def testOptionsTwiceDifferentOptions(self):
        options1 = dataset_ops.Options()
        options1.experimental_optimization.autotune = True
        options2 = dataset_ops.Options()
        options2.experimental_deterministic = False
        ds = dataset_ops.Dataset.range(0)
        ds = ds.with_options(options1)
        ds = ds.with_options(options2)
        self.assertTrue(ds.options().experimental_optimization.autotune)
        # Explicitly check that flag is False since assertFalse allows None
        self.assertIs(ds.options().experimental_deterministic, False)

    @combinations.generate(test_base.default_test_combinations())
    def testOptionsTwiceSameOption(self):
        if sys.version_info >= (3, 8) and platform.system() == "Windows":
            # TODO(b/165013260): Fix this
            self.skipTest(
                "Test is currently broken on Windows with Python 3.8")
        options1 = dataset_ops.Options()
        options1.experimental_optimization.autotune = False
        options2 = dataset_ops.Options()
        options2.experimental_optimization.autotune = True
        ds = dataset_ops.Dataset.range(0)
        ds = ds.with_options(options1)
        ds = ds.with_options(options2)
        self.assertTrue(ds.options().experimental_optimization.autotune)

    @combinations.generate(test_base.default_test_combinations())
    def testOptionsMergeOptionsFromMultipleInputs(self):
        options1 = dataset_ops.Options()
        options1.experimental_optimization.autotune = True
        options2 = dataset_ops.Options()
        options2.experimental_deterministic = True
        ds1 = dataset_ops.Dataset.range(0).with_options(options1)
        ds2 = dataset_ops.Dataset.range(0).with_options(options2)
        ds = dataset_ops.Dataset.zip((ds1, ds2))
        self.assertTrue(ds.options().experimental_optimization.autotune)
        self.assertTrue(ds.options().experimental_deterministic)

    @combinations.generate(test_base.default_test_combinations())
    def testOptionsHaveDefaults(self):
        options1 = dataset_ops.Options()
        options2 = dataset_ops.Options()
        self.assertIsNot(options1.experimental_optimization,
                         options2.experimental_optimization)
        self.assertIsNot(options1.experimental_stats,
                         options2.experimental_stats)
        self.assertIsNot(options1.experimental_threading,
                         options2.experimental_threading)
        self.assertEqual(options1.experimental_optimization,
                         optimization_options.OptimizationOptions())
        self.assertEqual(options1.experimental_stats,
                         stats_options.StatsOptions())
        self.assertEqual(options1.experimental_threading,
                         threading_options.ThreadingOptions())

    @combinations.generate(test_base.default_test_combinations())
    def testMutatingOptionsRaiseValueError(self):
        ds = dataset_ops.Dataset.range(0)
        options1 = dataset_ops.Options()
        options1.experimental_slack = True
        options2 = dataset_ops.Options()
        options2.experimental_optimization.autotune = True
        ds = ds.with_options(options1)
        ds = ds.map(lambda x: 2 * x)
        ds = ds.with_options(options2)
        with self.assertRaises(ValueError):
            dataset_options = ds.options()
            dataset_options.experimental_deterministic = True

    @combinations.generate(test_base.eager_only_combinations())
    def testNestedDataset(self):
        ds = dataset_ops.Dataset.from_tensors(0)
        result = ds

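        # Concatenating 999 copies yields a 1000-element dataset and exercises
        # deeply nested dataset structures.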
        for _ in range(999):
            result = result.concatenate(ds)
        self.assertDatasetProduces(result, [0] * 1000)

    @combinations.generate(test_base.default_test_combinations())
    def testOptionsProtoRoundTrip(self):
        options = dataset_ops.Options()
        options.experimental_deterministic = True
        options.experimental_external_state_policy = (
            distribute_options.ExternalStatePolicy.FAIL)
        options.experimental_distribute.auto_shard_policy = (
            distribute_options.AutoShardPolicy.DATA)
        options.experimental_distribute.num_devices = 1000
        options.experimental_optimization.apply_default_optimizations = True
        options.experimental_optimization.autotune = True
        options.experimental_optimization.autotune_buffers = True
        options.experimental_optimization.autotune_cpu_budget = 10
        options.experimental_optimization.autotune_ram_budget = 20
        options.experimental_optimization.filter_fusion = True
        options.experimental_optimization.filter_with_random_uniform_fusion = True
        options.experimental_optimization.hoist_random_uniform = True
        options.experimental_optimization.map_and_batch_fusion = True
        options.experimental_optimization.map_and_filter_fusion = True
        options.experimental_optimization.map_fusion = True
        options.experimental_optimization.map_parallelization = True
        options.experimental_optimization.map_vectorization.enabled = True
        options.experimental_optimization.map_vectorization.use_choose_fastest = (
            True)
        options.experimental_optimization.noop_elimination = True
        options.experimental_optimization.parallel_batch = True
        options.experimental_optimization.reorder_data_discarding_ops = True
        options.experimental_optimization.shuffle_and_repeat_fusion = True
        options.experimental_slack = True
        options.experimental_threading.max_intra_op_parallelism = 30
        options.experimental_threading.private_threadpool_size = 40
        pb = options._to_proto()
        result = dataset_ops.Options()
        result._from_proto(pb)
        self.assertEqual(options, result)

    @combinations.generate(test_base.default_test_combinations())
    def testOptionsProtoDefaultValuesRoundTrip(self):
        options = dataset_ops.Options()
        pb = options._to_proto()
        result = dataset_ops.Options()
        result._from_proto(pb)
        self.assertEqual(options, result)

    @combinations.generate(test_base.default_test_combinations())
    def testProtoOptionsDefaultValuesRoundTrip(self):
        pb = dataset_options_pb2.Options()
        options = dataset_ops.Options()
        options._from_proto(pb)
        result = options._to_proto()
        expected_pb = dataset_options_pb2.Options()
        expected_pb.distribute_options.CopyFrom(
            dataset_options_pb2.DistributeOptions())
        expected_pb.optimization_options.CopyFrom(
            dataset_options_pb2.OptimizationOptions())
        expected_pb.optimization_options.map_vectorization.CopyFrom(
            dataset_options_pb2.MapVectorization())
        expected_pb.threading_options.CopyFrom(
            dataset_options_pb2.ThreadingOptions())
        self.assertProtoEquals(expected_pb, result)
Example #2
class GetSingleElementTest(test_base.DatasetTestBase, parameterized.TestCase):
    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(
                skip=[0, 5, 10], take=[1], error=[None], error_msg=[None]) +
            combinations.combine(skip=[100],
                                 take=[1],
                                 error=[errors.InvalidArgumentError],
                                 error_msg=["Dataset was empty."]) +
            combinations.combine(
                skip=[0],
                take=[2],
                error=[errors.InvalidArgumentError],
                error_msg=["Dataset had more than one element."])))
    def testGetSingleElement(self, skip, take, error=None, error_msg=None):
        def make_sparse(x):
            x_1d = array_ops.reshape(x, [1])
            x_2d = array_ops.reshape(x, [1, 1])
            return sparse_tensor.SparseTensor(x_2d, x_1d, x_1d)

        dataset = dataset_ops.Dataset.range(100).skip(skip).map(
            lambda x: (x * x, make_sparse(x))).take(take)
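        # In the success cases, the single remaining element is `skip`, so the
        # dense value is skip * skip and the sparse tensor stores `skip` at
        # index [skip].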
        if error is None:
            dense_val, sparse_val = self.evaluate(
                get_single_element.get_single_element(dataset))
            self.assertEqual(skip * skip, dense_val)
            self.assertAllEqual([[skip]], sparse_val.indices)
            self.assertAllEqual([skip], sparse_val.values)
            self.assertAllEqual([skip], sparse_val.dense_shape)
        else:
            with self.assertRaisesRegex(error, error_msg):
                self.evaluate(get_single_element.get_single_element(dataset))

    @combinations.generate(test_base.default_test_combinations())
    def testWindow(self):
        """Test that `get_single_element()` can consume a nested dataset."""
        def flat_map_func(ds):
            batched = ds.batch(2)
            element = get_single_element.get_single_element(batched)
            return dataset_ops.Dataset.from_tensors(element)

        dataset = dataset_ops.Dataset.range(10).window(2).flat_map(
            flat_map_func)
        self.assertDatasetProduces(dataset,
                                   [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]])

    @combinations.generate(test_base.default_test_combinations())
    def testSideEffect(self):
        counter_var = variables.Variable(0)

        def increment_fn(x):
            counter_var.assign_add(1)
            return x

        def dataset_fn():
            return dataset_ops.Dataset.range(1).map(increment_fn)

        @function.defun
        def fn():
            _ = get_single_element.get_single_element(dataset_fn())
            return "hello"

        self.evaluate(counter_var.initializer)
        self.assertEqual(self.evaluate(fn()), b"hello")
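        # get_single_element consumed the one-element dataset, so increment_fn
        # ran exactly once even though fn() discarded its result.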
        self.assertEqual(self.evaluate(counter_var), 1)

    @combinations.generate(test_base.default_test_combinations())
    def testAutomaticControlDependencies(self):
        counter_var = variables.Variable(1)

        def increment_fn(x):
            counter_var.assign(counter_var + 1)
            return x

        def multiply_fn(x):
            counter_var.assign(counter_var * 2)
            return x

        def dataset1_fn():
            return dataset_ops.Dataset.range(1).map(increment_fn)

        def dataset2_fn():
            return dataset_ops.Dataset.range(1).map(multiply_fn)

        @function.defun
        def fn():
            _ = get_single_element.get_single_element(dataset1_fn())
            _ = get_single_element.get_single_element(dataset2_fn())
            return "hello"

        self.evaluate(counter_var.initializer)
        self.assertEqual(self.evaluate(fn()), b"hello")
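        # Automatic control dependencies run increment_fn before multiply_fn,
        # so the counter ends at (1 + 1) * 2 = 4.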
        self.assertEqual(self.evaluate(counter_var), 4)
Example #3
class LocalTaskGarbageCollectTest(data_service_test_base.TestBase,
                                  parameterized.TestCase):
    """Tests garbage collecting unused local worker tasks.

    The user typically creates a new iterator for each epoch. This should delete
    the previous iterator and release its resources.
    """
    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(num_remote_workers=[0, 3])))
    def testMultipleEpochs(self, num_remote_workers):
        num_local_workers = 1
        cluster = multi_process_cluster.MultiProcessCluster(
            num_local_workers=num_local_workers,
            num_remote_workers=num_remote_workers)

        num_epochs, num_steps = 5, 5
        dataset = self._make_distributed_infinite_range_dataset(cluster)
        for _ in range(num_epochs):
            # Creating a new iterator each epoch garbage-collects the previous one.
            get_next = self.getNext(dataset)
            for i in range(num_steps):
                self.assertEqual(self.evaluate(get_next()), i)

    @combinations.generate(
        combinations.times(test_base.eager_only_combinations(),
                           combinations.combine(num_remote_workers=[0, 3])))
    def testMultipleEpochsSharedJob(self, num_remote_workers):
        num_local_workers = 1
        cluster = multi_process_cluster.MultiProcessCluster(
            num_local_workers=num_local_workers,
            num_remote_workers=num_remote_workers)

        num_epochs, num_steps = 5, 5
        dataset = self._make_distributed_infinite_range_dataset(
            cluster, job_name="shared_job_name")
        for _ in range(num_epochs):
            # Creating a new iterator each epoch garbage-collects the previous one.
            get_next = self.getNext(dataset)
            for i in range(num_steps):
                self.assertEqual(self.evaluate(get_next()), i)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(num_remote_workers=[0, 3],
                                 job_name=[None, "shared_job_name"])))
    def testRepeatDistributedDataset(self, num_remote_workers, job_name):
        num_local_workers = 1
        cluster = multi_process_cluster.MultiProcessCluster(
            num_local_workers=num_local_workers,
            num_remote_workers=num_remote_workers)
        dataset = self.make_distributed_range_dataset(10,
                                                      cluster,
                                                      job_name=job_name,
                                                      target_workers="LOCAL")
        dataset = dataset.repeat(3)
        self.assertDatasetProduces(dataset, list(range(10)) * 3)

    @combinations.generate(
        combinations.times(test_base.eager_only_combinations(),
                           combinations.combine(num_remote_workers=[0, 3])))
    def testReadFromDeletedTask(self, num_remote_workers):
        num_local_workers = 1
        cluster = multi_process_cluster.MultiProcessCluster(
            num_local_workers=num_local_workers,
            num_remote_workers=num_remote_workers)

        num_steps = 10
        dataset = self._make_distributed_infinite_range_dataset(
            cluster, job_name="shared_job_name")
        get_next = self.getNext(dataset)
        for i in range(num_steps):
            self.assertEqual(self.evaluate(get_next()), i)

        # Re-creating the dataset resets the iterator index, so the second iterator
        # reads from the same task as the first, which has been deleted.
        dataset = self._make_distributed_infinite_range_dataset(
            cluster, job_name="shared_job_name")
        get_next = self.getNext(dataset)
        with self.assertRaisesRegex(errors.FailedPreconditionError,
                                    "which has been deleted."):
            _ = self.evaluate(get_next())

    @combinations.generate(
        combinations.times(test_base.graph_only_combinations(),
                           combinations.combine(num_remote_workers=[0, 3])))
    def testReadFromDeletedTask_GraphMode(self, num_remote_workers):
        num_local_workers = 1
        cluster = multi_process_cluster.MultiProcessCluster(
            num_local_workers=num_local_workers,
            num_remote_workers=num_remote_workers)

        num_steps = 10
        dataset = self._make_distributed_infinite_range_dataset(
            cluster, job_name="shared_job_name")
        with self.session() as sess:
            get_next = self.getNext(dataset)
            for i in range(num_steps):
                self.assertEqual(sess.run(get_next()), i)

        # Re-creating the dataset resets the iterator index, so the second iterator
        # reads from the same task as the first, which has been deleted.
        dataset = self._make_distributed_infinite_range_dataset(
            cluster, job_name="shared_job_name")
        with self.assertRaisesRegex(errors.FailedPreconditionError,
                                    "which has been deleted."):
            with self.session() as sess:
                get_next = self.getNext(dataset)
                sess.run(get_next())

    @combinations.generate(
        combinations.times(test_base.eager_only_combinations(),
                           combinations.combine(num_remote_workers=[0, 3])))
    def testMultipleEpochs_WorkerRestart(self, num_remote_workers):
        num_local_workers = 1
        cluster = multi_process_cluster.MultiProcessCluster(
            num_local_workers=num_local_workers,
            num_remote_workers=num_remote_workers)

        num_steps = 10
        dataset = self._make_distributed_infinite_range_dataset(
            cluster, job_name="shared_job_name")

        get_next = self.getNext(dataset)
        for i in range(num_steps):
            self.assertEqual(self.evaluate(get_next()), i)

        # Verifies the worker re-creates the task after the iterator is deleted and
        # the worker restarts.
        del get_next
        cluster.restart_local_workers()

        get_next = self.getNext(dataset)
        for i in range(num_steps):
            self.assertEqual(self.evaluate(get_next()), i)

    @combinations.generate(
        combinations.times(test_base.eager_only_combinations(),
                           combinations.combine(num_remote_workers=[0, 3])))
    def testMultipleEpochs_DispatcherRestart(self, num_remote_workers):
        num_local_workers = 1
        cluster = multi_process_cluster.MultiProcessCluster(
            num_local_workers=num_local_workers,
            num_remote_workers=num_remote_workers)

        num_steps = 10
        dataset = self._make_distributed_infinite_range_dataset(
            cluster, job_name="shared_job_name")
        get_next = self.getNext(dataset)
        for i in range(num_steps):
            self.assertEqual(self.evaluate(get_next()), i)

        # Verifies the worker re-creates the task after the iterator is deleted and
        # the dispatcher restarts.
        del get_next
        cluster.restart_dispatcher()

        get_next = self.getNext(dataset)
        for i in range(num_steps):
            self.assertEqual(self.evaluate(get_next()), i)

    def _make_distributed_infinite_range_dataset(self, cluster, job_name=None):
        dataset = dataset_ops.Dataset.range(1000000).repeat()
        return self.make_distributed_dataset(
            dataset,
            cluster=cluster,
            job_name=job_name,
            processing_mode=ShardingPolicy.OFF,
            target_workers="LOCAL")
Example #4
class PrefetchWithSlackTest(test_base.DatasetTestBase, parameterized.TestCase):

    # TODO(b/121264236)
    @combinations.generate(
        combinations.combine(tf_api_version=[1], mode=["graph"]))
    def testPrefetchWithSlackOption(self):
        """Determines slack_period based on num devices attached to iterator."""
        dataset = dataset_ops.Dataset.range(10)
        dataset = dataset.prefetch(1)
        options = dataset_ops.Options()
        options.experimental_slack = True
        dataset = dataset.with_options(options)
        multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset, ["/cpu:1", "/cpu:2"])
        dataset = multi_device_iterator._dataset  # pylint: disable=protected-access
        config = config_pb2.ConfigProto(device_count={"CPU": 3})
        with self.test_session(config=config):
            self.evaluate(multi_device_iterator.initializer)
            for i in range(0, 10, 2):
                elem_on_1, elem_on_2 = multi_device_iterator.get_next()
                self.assertEqual(i, self.evaluate(elem_on_1))
                self.assertEqual(i + 1, self.evaluate(elem_on_2))
            with self.assertRaises(errors.OutOfRangeError):
                elem_on_1, elem_on_2 = multi_device_iterator.get_next()
                self.evaluate(elem_on_1)
                self.evaluate(elem_on_2)

    @combinations.generate(test_base.default_test_combinations())
    def testPrefetchWithSlackOptionWithoutIterator(self):
        """Defaults to slack period of 1 without iterator."""
        dataset = dataset_ops.Dataset.range(10)
        dataset = dataset.prefetch(1)
        options = dataset_ops.Options()
        options.experimental_slack = True
        dataset = dataset.with_options(options)
        self.assertDatasetProduces(dataset, range(10))

    @combinations.generate(test_base.default_test_combinations())
    def testWithPassthroughDataset(self):
        """Should still work with a passthrough dataset after prefetch()."""
        dataset = dataset_ops.Dataset.range(10)
        dataset = dataset.prefetch(1)
        dataset = dataset.map(lambda x: x + 1)
        options = dataset_ops.Options()
        options.experimental_slack = True
        dataset = dataset.with_options(options)
        self.assertDatasetProduces(dataset, range(1, 11))

    @combinations.generate(test_base.default_test_combinations())
    def testNoErrorWithoutPrefetch(self):
        """The rewrite should not fail if there is no prefetch() in the pipeline."""
        dataset = dataset_ops.Dataset.range(10)
        options = dataset_ops.Options()
        options.experimental_slack = True
        dataset = dataset.with_options(options)
        self.assertDatasetProduces(dataset, range(10))

    @combinations.generate(test_base.default_test_combinations())
    def testNoErrorWithInvalidDataset(self):
        """The rewrite cannot be applied when a nested dataset op follows
        prefetch(), but it should not raise an error."""
        dataset = dataset_ops.Dataset.range(10)
        dataset = dataset.prefetch(1)
        dataset = dataset.flat_map(dataset_ops.Dataset.from_tensors)
        options = dataset_ops.Options()
        options.experimental_slack = True
        dataset = dataset.with_options(options)
        self.assertDatasetProduces(dataset, range(10))

class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(drop_remainder=[True, False])))
    def testBasic(self, drop_remainder):
        dataset = dataset_ops.Dataset.range(1024).batch(
            32, drop_remainder=drop_remainder)
        rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
        self.assertEqual(
            [[8] if drop_remainder else [None]],
            [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])

        expected_output = [[k for k in range(i, i + 8)]
                           for i in range(0, 1024, 8)]  # pylint: disable=g-complex-comprehension
        self.assertDatasetProduces(rebatched_dataset, expected_output)

    @combinations.generate(test_base.default_test_combinations())
    def testCanHandleUnknownRank(self):
        dataset = dataset_ops.Dataset.from_tensors("xxx")
        # decode_image results in a tensor of completely unknown shape (i.e. unknown
        # rank)
        dataset = dataset.map(image_ops.decode_image)
        self.assertEqual([tensor_shape.TensorShape(None)],
                         _flat_shapes(dataset))
        rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
        # Note that we are just testing the dataset shapes, not the actual output.
        self.assertEqual([tensor_shape.TensorShape(None)],
                         _flat_shapes(rebatched_dataset))

    @combinations.generate(test_base.default_test_combinations())
    def testCanHandleUnknownDims(self):
        dataset = dataset_ops.Dataset.range(1000)
        dataset = dataset.batch(10, drop_remainder=False)
        dataset = dataset.batch(10, drop_remainder=False)
        self.assertEqual([[None, None]],
                         [ts.as_list() for ts in _flat_shapes(dataset)])
        rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
        # Note that we are just testing the dataset shapes, not the actual output.
        self.assertEqual(
            [[None, None]],
            [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])

    @combinations.generate(test_base.default_test_combinations())
    def testScalarInputError(self):
        dataset = dataset_ops.Dataset.range(1024)
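        # Rebatching an already-batched dataset should succeed.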
        distribute._RebatchDataset(dataset.batch(4), num_replicas=4)
        with self.assertRaisesRegex(ValueError, ("You can fix the issue "
                                                 "by adding the `batch`")):
            distribute._RebatchDataset(dataset, num_replicas=4)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(drop_remainder=[True, False])))
    def testBatchNotDivisibleByNumReplicas(self, drop_remainder):
        dataset = dataset_ops.Dataset.range(1024).batch(
            32, drop_remainder=drop_remainder)
        rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5)
        self.assertEqual(
            [[None]], [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
        expected_output = []
        i = 0
        for _ in range(32):  # number of steps
            # first four minibatches have seven elements
            for _ in range(4):
                expected_output.append([k for k in range(i, i + 7)])
                i += 7
            # last minibatch has four elements
            expected_output.append([k for k in range(i, i + 4)])
            i += 4
        self.assertDatasetProduces(rebatched_dataset, expected_output)

    @combinations.generate(test_base.default_test_combinations())
    def testBatchSizeNotDivisibleByNumReplicas2(self):
        dataset = dataset_ops.Dataset.range(32).batch(16, drop_remainder=True)
        rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5)
        # This will rebatch into sub-batches of size 4, since
        # ceil(16 / 5) = 4. However, that means only the first 4 replicas will get
        # data.
        expected_output = [[k for k in range(i, i + 4)]
                           for i in range(0, 16, 4)]
        expected_output.extend([[]])  # Last replica gets an empty batch
        expected_output.extend([[k for k in range(i, i + 4)]
                                for i in range(16, 32, 4)])
        expected_output.extend([[]])  # Last replica gets an empty batch
        self.assertDatasetProduces(rebatched_dataset, expected_output)

    @combinations.generate(test_base.default_test_combinations())
    def testTupleOutput(self):
        dataset = dataset_ops.Dataset.range(1024).map(lambda x: (x, x)).batch(
            32)
        rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
        expected_output = [
            (
                [k for k in range(i, i + 8)],  # pylint: disable=g-complex-comprehension
                [k for k in range(i, i + 8)]) for i in range(0, 1024, 8)
        ]
        self.assertDatasetProduces(rebatched_dataset, expected_output)

    @combinations.generate(test_base.default_test_combinations())
    def testNestedDictionaryOutput(self):
        dataset = dataset_ops.Dataset.range(1024).map(lambda x: {
            "a": x,
            "b": {
                "c": x
            }
        }).batch(32)
        rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
        expected_output = [
            {
                "a": [k for k in range(i, i + 8)],  # pylint: disable=g-complex-comprehension
                "b": {
                    "c": [k for k in range(i, i + 8)]
                }
            } for i in range(0, 1024, 8)
        ]
        self.assertDatasetProduces(rebatched_dataset, expected_output)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(drop_remainder=[True, False])))
    def testFinalPartialBatch(self, drop_remainder):
        dataset = dataset_ops.Dataset.range(1032).batch(
            32, drop_remainder=drop_remainder)
        rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
        self.assertEqual(
            [[8] if drop_remainder else [None]],
            [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])

        # If drop_remainder, the final partial batch is dropped, even though it
        # makes up a complete minibatch.
        expected_output = [[k for k in range(i, i + 8)]
                           for i in range(0, 1024, 8)]  # pylint: disable=g-complex-comprehension
        if not drop_remainder:
            # The last partial batch of size 8 is split over 4 replicas
            expected_output.extend([[k for k in range(i, i + 2)]
                                    for i in range(1024, 1032, 2)])
        self.assertDatasetProduces(rebatched_dataset, expected_output)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(drop_remainder=[True, False])))
    def testFinalPartialBatchAfterRebatch(self, drop_remainder):
        dataset = dataset_ops.Dataset.range(34).batch(
            32, drop_remainder=drop_remainder)
        rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
        self.assertEqual(
            [[8] if drop_remainder else [None]],
            [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])

        expected_output = [[k for k in range(i, i + 8)]
                           for i in range(0, 32, 8)]  # pylint: disable=g-complex-comprehension
        if not drop_remainder:
            # The last partial batch of size 2 is split over 4 replicas
            expected_output += [[32], [33], [], []]
        self.assertDatasetProduces(rebatched_dataset, expected_output)

    @combinations.generate(test_base.default_test_combinations())
    def testMultipleBatches(self):
        dataset = dataset_ops.Dataset.range(128).batch(4).batch(8)
        self.assertEqual([[None, None]],
                         [ts.as_list() for ts in _flat_shapes(dataset)])

        # Each element is a list of 8 elements where each element is a list of 4.
        expected_output = [
            [
                [j, j + 1, j + 2, j + 3]  # pylint: disable=g-complex-comprehension
                for j in range(i, i + 32, 4)
            ]  # generates 8 elements
            for i in range(0, 128, 32)
        ]
        self.assertDatasetProduces(dataset, expected_output)

        rebatched_dataset = distribute._RebatchDataset(dataset, 4)
        self.assertEqual(
            [[None, None]],
            [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
        # Each element is a list of 2 elements where each element is a list of 4.
        expected_output = [
            [
                [j, j + 1, j + 2, j + 3]  # pylint: disable=g-complex-comprehension
                for j in range(i, i + 8, 4)
            ]  # generates 2 elements
            for i in range(0, 128, 8)
        ]
        self.assertDatasetProduces(rebatched_dataset, expected_output)

    @combinations.generate(test_base.default_test_combinations())
    def testRaggedTensorDataset(self):
        # Set up a dataset that produces ragged tensors with a static batch size.
        row_lengths = np.random.randint(8, size=128)
        values = np.random.normal(size=np.sum(row_lengths)).astype(np.float32)
        dataset = dataset_ops.Dataset.from_tensor_slices(
            ragged_tensor.RaggedTensor.from_row_lengths(values, row_lengths))
        dataset = dataset.batch(32, drop_remainder=True)

        # The map changes the internal representation of the ragged tensor.
        # This test will fail if we don't normalize the tensor representation.
        dataset = dataset.map(lambda x: x)

        dataset = distribute._RebatchDataset(dataset, num_replicas=8)
        # After rebatching, batch size is now 4.
        expected_output = []
        value_index = 0
        for batch_row_lengths in row_lengths.reshape((-1, 4)):
            num_values = np.sum(batch_row_lengths)
            expected_output.append(
                ragged_tensor.RaggedTensor.from_row_lengths(
                    values[value_index:(value_index + num_values)],
                    batch_row_lengths))
            value_index += num_values
        self.assertDatasetProduces(dataset, expected_output)

    @combinations.generate(test_base.default_test_combinations())
    def testNoOutputShapes(self):
        # Some datasets, e.g. datasets with None tensors, have components without
        # output shapes. Test that this doesn't break rebatching shape inference
        # logic.
        dataset = dataset_ops.Dataset.range(4)
        dataset = dataset.map(lambda x: (x, None))
        dataset = dataset.batch(4, drop_remainder=True)
        _ = distribute._RebatchDataset(dataset, num_replicas=2)
Example #6
class AutoShardWithRebatchDatasetTest(
    reader_dataset_ops_test_base.TFRecordDatasetTestBase,
    parameterized.TestCase):

  def _setUpFiles(self, num_files, num_records_per_file):
    self._num_files = num_files
    self._num_records = num_records_per_file
    self.test_filenames = self._createFiles()

  @combinations.generate(test_base.default_test_combinations())
  def testFileShardingWithLegacyRebatch(self):
    # Tests that RebatchDatasetV1 is a passthrough op.
    self._setUpFiles(num_files=5, num_records_per_file=10)
    dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
    dataset = dataset.apply(
        testing.assert_next(["Shard", "FlatMap", "Batch", "Rebatch"]))
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.batch(5)
    dataset = distribute._LegacyRebatchDataset(dataset, num_replicas=5)
    dataset = distribute._AutoShardDataset(dataset, 5, 3)
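    # Worker index 3 reads only file 3; each batch of 5 records is then split
    # into 5 minibatches of 1 record.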
    expected = [[self._record(3, i)] for i in range(10)]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testFileShardingWithRebatch(self):
    # Tests that RebatchDatasetV2 is a passthrough op.
    self._setUpFiles(num_files=3, num_records_per_file=5)
    dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
    dataset = dataset.apply(
        testing.assert_next(["Shard", "FlatMap", "Batch", "Rebatch"]))
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.batch(5)
    dataset = distribute._RebatchDataset(dataset, batch_sizes=[2, 1, 2])
    dataset = distribute._AutoShardDataset(dataset, 3, 1)
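    # Worker index 1 reads only file 1; its single batch of 5 records is
    # rebatched into sizes [2, 1, 2].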
    expected = [[self._record(1, 0), self._record(1, 1)], [self._record(1, 2)],
                [self._record(1, 3), self._record(1, 4)]]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.times(
              combinations.combine(sharding_policy=[
                  distribute_options.AutoShardPolicy.DATA,
                  distribute_options.AutoShardPolicy.AUTO
              ]), combinations.combine(with_prefetch=[True, False]))))
  def testUseLegacyRebatchWithDataSharding(self, sharding_policy,
                                           with_prefetch):
    # This test simulates a distributed environment with 3 workers, each with
    # 1 replica.
    dataset = dataset_ops.Dataset.range(8)
    dataset = dataset.batch(4)
    options = dataset_ops.Options()
    options.experimental_distribute.auto_shard_policy = sharding_policy
    dataset = dataset.with_options(options)
    # We expect the auto-shard rewrite to rewrite RebatchDatasetV2 to
    # RebatchDataset(V1) for correctness reasons. This will modify the output
    # of the dataset.
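    # With legacy rebatch, each batch of 4 is split into sub-batches of
    # ceil(4 / 3) = 2, so after DATA sharding the first two workers receive the
    # data and the third receives empty batches.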
    worker_a_dataset = distribute._RebatchDataset(
        dataset, batch_sizes=[2, 1, 1])
    if with_prefetch:
      worker_a_dataset = worker_a_dataset.prefetch(1)
    worker_a_dataset = distribute._AutoShardDataset(
        worker_a_dataset, 3, 0, num_replicas=3)
    expected = [[0, 1], [4, 5]]
    self.assertDatasetProduces(worker_a_dataset, expected)

    worker_b_dataset = distribute._RebatchDataset(
        dataset, batch_sizes=[1, 1, 2])
    if with_prefetch:
      worker_b_dataset = worker_b_dataset.prefetch(1)
    worker_b_dataset = distribute._AutoShardDataset(
        worker_b_dataset, 3, 1, num_replicas=3)
    expected = [[2, 3], [6, 7]]
    self.assertDatasetProduces(worker_b_dataset, expected)

    worker_c_dataset = distribute._RebatchDataset(
        dataset, batch_sizes=[1, 2, 1])
    if with_prefetch:
      worker_c_dataset = worker_c_dataset.prefetch(1)
    worker_c_dataset = distribute._AutoShardDataset(
        worker_c_dataset, 3, 2, num_replicas=3)
    expected = [[], []]
    self.assertDatasetProduces(worker_c_dataset, expected)

class ShuffleDatasetSerializationTest(
        dataset_serialization_test_base.DatasetSerializationTestBase,
        parameterized.TestCase):
    def _build_shuffle_dataset(
        self,
        range_limit=10,
        num_repeats=5,
        buffer_size=5,
        seed=None,
        reshuffle_each_iteration=None,
    ):
        return dataset_ops.Dataset.range(range_limit).shuffle(
            buffer_size,
            seed=seed,
            reshuffle_each_iteration=reshuffle_each_iteration).repeat(
                num_repeats)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(reshuffle_each_iteration=[True, False],
                                 buffer_size=[1, 3, 5, 8, 10])))
    def testShuffleCore(self, reshuffle_each_iteration, buffer_size):
        seed = 55
        range_limit = 5
        num_repeats = 2
        num_outputs = range_limit * num_repeats
        # pylint: disable=g-long-lambda
        self.run_core_tests(
            lambda: self._build_shuffle_dataset(range_limit=range_limit,
                                                num_repeats=num_repeats,
                                                buffer_size=buffer_size,
                                                seed=seed,
                                                reshuffle_each_iteration=
                                                reshuffle_each_iteration),
            num_outputs)

    @combinations.generate(
        combinations.combine(tf_api_version=1,
                             mode=["graph"],
                             reshuffle_each_iteration=[True, False],
                             buffer_size=[1, 3, 5, 8, 10]))
    def testMultipleIterators(self, reshuffle_each_iteration, buffer_size):
        range_limit = 5
        num_repeats = 2
        num_outputs = range_limit * num_repeats

        def ds_fn():
            # pylint: disable=cell-var-from-loop
            return self._build_shuffle_dataset(
                range_limit=range_limit,
                num_repeats=num_repeats,
                buffer_size=buffer_size,
                seed=None,  # Iterator seeds are generated non-deterministically.
                reshuffle_each_iteration=reshuffle_each_iteration)
            # pylint: enable=cell-var-from-loop

        with ops.Graph().as_default() as g:
            ds = ds_fn()
            iterators = [
                ds.make_one_shot_iterator(),
                ds.make_one_shot_iterator()
            ]
            get_next_ops = [it.get_next() for it in iterators]
            saveables = [
                contrib_iterator_ops.make_saveable_from_iterator(it)
                for it in iterators
            ]
            for saveable in saveables:
                ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
            saver = saver_lib.Saver(allow_empty=True)
            with self.session(graph=g) as sess:
                self._save(sess, saver)
                expected = [
                    self.evaluate(get_next_ops) for _ in range(num_outputs)
                ]
                self._restore(saver, sess)
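                # Restoring rewinds both iterators to the saved state, so they
                # replay the same sequence of elements.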
                actual = [
                    self.evaluate(get_next_ops) for _ in range(num_outputs)
                ]
                self.match(expected, actual)
Example #8
class FileCacheTest(test_base.DatasetTestBase, parameterized.TestCase):

  def setUp(self):
    super(FileCacheTest, self).setUp()
    self.tmp_dir = tempfile.mkdtemp()
    self.cache_prefix = path.join(self.tmp_dir, "cache")

  def tearDown(self):
    if self.tmp_dir:
      shutil.rmtree(self.tmp_dir, ignore_errors=True)
    super(FileCacheTest, self).tearDown()

  @combinations.generate(test_base.default_test_combinations())
  def testCacheDatasetPassthrough(self):
    components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
                  np.array([9.0, 10.0, 11.0, 12.0]))

    def dataset_fn(count=5, filename=None):
      repeat_dataset = (
          dataset_ops.Dataset.from_tensor_slices(components).repeat(count))
      if filename:
        return repeat_dataset.cache(filename)
      else:
        return repeat_dataset

    self.assertEqual(
        tuple([c.shape[1:] for c in components]),
        dataset_ops.get_legacy_output_shapes(dataset_fn()))

    get_next = self.getNext(dataset_fn())

    # First run without caching to collect the "ground truth".
    elements = []
    for _ in range(20):
      elements.append(self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

    # Assert that the cached dataset has the same elements as the
    # "ground truth".
    get_next = self.getNext(dataset_fn(filename=self.cache_prefix))
    cached_elements = []
    for _ in range(20):
      cached_elements.append(self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
    self.assertAllEqual(elements, cached_elements)

    # Re-initialize with an empty upstream (to throw errors.OutOfRangeError
    # if we didn't use the cache).
    get_next = self.getNext(dataset_fn(count=0, filename=self.cache_prefix))
    replayed_elements = []
    for _ in range(20):
      replayed_elements.append(self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
    self.assertEqual(cached_elements, replayed_elements)

    # Re-initialize with an empty upstream and a missing cache file (should
    # throw errors.OutOfRangeError immediately).
    get_next = self.getNext(
        dataset_fn(count=0, filename=self.cache_prefix + "nonsense"))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @combinations.generate(test_base.default_test_combinations())
  def testConcurrentWriters(self):
    components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
                  np.array([9.0, 10.0, 11.0, 12.0]))

    cache_dataset1 = (
        dataset_ops.Dataset.from_tensor_slices(components).cache(
            self.cache_prefix))
    cache_dataset2 = (
        dataset_ops.Dataset.from_tensor_slices(components).cache(
            self.cache_prefix))

    get_next1 = self.getNext(cache_dataset1)
    get_next2 = self.getNext(cache_dataset2)

    self.evaluate(get_next1())  # this should succeed

    with self.assertRaises(errors.AlreadyExistsError):
      self.evaluate(get_next2())

    self.evaluate(get_next1())  # this should continue to succeed

  @combinations.generate(test_base.default_test_combinations())
  def testConcurrentReaders(self):
    components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
                  np.array([9.0, 10.0, 11.0, 12.0]))

    cache_dataset1 = (
        dataset_ops.Dataset.from_tensor_slices(components).cache(
            self.cache_prefix))
    cache_dataset2 = (
        dataset_ops.Dataset.from_tensor_slices(components).cache(
            self.cache_prefix))

    get_next1 = self.getNext(cache_dataset1)
    get_next2 = self.getNext(cache_dataset2)

    elements = []
    for _ in range(4):
      elements.append(self.evaluate(get_next1()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next1())

    # Re-initialize
    get_next1 = self.getNext(cache_dataset1, requires_initialization=True)
    get_next2 = self.getNext(cache_dataset2, requires_initialization=True)

    # Reading concurrently should succeed.
    elements_itr1 = []
    elements_itr2 = []
    elements_itr2.append(self.evaluate(get_next2()))
    elements_itr1.append(self.evaluate(get_next1()))
    elements_itr2.append(self.evaluate(get_next2()))
    elements_itr1.append(self.evaluate(get_next1()))
    # Intentionally reversing the order
    elements_itr1.append(self.evaluate(get_next1()))
    elements_itr2.append(self.evaluate(get_next2()))
    elements_itr1.append(self.evaluate(get_next1()))
    elements_itr2.append(self.evaluate(get_next2()))

    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next2())

    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next1())

    self.assertAllEqual(elements, elements_itr1)
    self.assertAllEqual(elements, elements_itr2)

  @combinations.generate(test_base.default_test_combinations())
  def testReadingPastEndOfSequence(self):
    dataset = dataset_ops.Dataset.range(10).cache(self.cache_prefix)
    dataset = dataset.map(lambda a: a).batch(4).repeat(2)
    expected_output = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]] * 2
    self.assertDatasetProduces(dataset, expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testCleaningUpCacheFiles(self):

    def do_test(i):
      dataset = dataset_ops.Dataset.range(10).cache(self.cache_prefix)
      get_next = self.getNext(dataset)
      for _ in range(i):
        try:
          self.evaluate(get_next())
        except errors.OutOfRangeError:
          break

    if not context.executing_eagerly():
      self.skipTest(
          "Test requires eager mode for iterators to be deconstructed")

    for i in [0, 3, 10, 12, 15]:
      do_test(i)

class MultiDeviceIteratorTest(test_base.DatasetTestBase,
                              parameterized.TestCase):
    def setUp(self):
        super(MultiDeviceIteratorTest, self).setUp()
        self._devices = self.configureDevicesForMultiDeviceTest(3)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(num_inits=[0, 1, 42])))
    def testInitOnly(self, num_inits):
        dataset = dataset_ops.Dataset.range(10)
        multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset, [self._devices[1], self._devices[2]])

        for _ in range(num_inits):
            self.evaluate(multi_device_iterator.initializer)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(max_buffer_size=[0, 1, 10],
                                 prefetch_buffer_size=[0, 1, 10])))
    def testBasic(self, prefetch_buffer_size, max_buffer_size):
        dataset = dataset_ops.Dataset.range(10)
        multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset, [self._devices[1], self._devices[2]],
            max_buffer_size=max_buffer_size,
            prefetch_buffer_size=prefetch_buffer_size)

        self.evaluate(multi_device_iterator.initializer)
        for i in range(0, 10, 2):
            elem_on_1, elem_on_2 = multi_device_iterator.get_next()
            self.assertEqual(i, self.evaluate(elem_on_1))
            self.assertEqual(i + 1, self.evaluate(elem_on_2))
        with self.assertRaises(errors.OutOfRangeError):
            elem_on_1, elem_on_2 = multi_device_iterator.get_next()
            self.evaluate(elem_on_1)
            self.evaluate(elem_on_2)

    @combinations.generate(test_base.default_test_combinations())
    def testOneOnSameDevice(self):
        dataset = dataset_ops.Dataset.range(12)
        multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset, [self._devices[0], self._devices[1], self._devices[2]])

        self.evaluate(multi_device_iterator.initializer)
        for i in range(0, 12, 3):
            elem_on_0, elem_on_1, elem_on_2 = multi_device_iterator.get_next()
            self.assertEqual(i, self.evaluate(elem_on_0))
            self.assertEqual(i + 1, self.evaluate(elem_on_1))
            self.assertEqual(i + 2, self.evaluate(elem_on_2))
        with self.assertRaises(errors.OutOfRangeError):
            elem_on_0, elem_on_1, elem_on_2 = multi_device_iterator.get_next()
            self.evaluate(elem_on_0)
            self.evaluate(elem_on_1)
            self.evaluate(elem_on_2)

    @combinations.generate(test_base.default_test_combinations())
    def testRepeatDevices(self):
        dataset = dataset_ops.Dataset.range(10)
        multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset, [self._devices[1], self._devices[1]])

        self.evaluate(multi_device_iterator.initializer)
        for i in range(0, 10, 2):
            elements = multi_device_iterator.get_next()
            elem_on_1, elem_on_2 = elements
            self.assertEqual(i, self.evaluate(elem_on_1))
            self.assertEqual(i + 1, self.evaluate(elem_on_2))
        with self.assertRaises(errors.OutOfRangeError):
            elements = multi_device_iterator.get_next()
            elem_on_1, elem_on_2 = elements
            self.evaluate(elem_on_1)
            self.evaluate(elem_on_2)

    @combinations.generate(test_base.default_test_combinations())
    def testNotFullyDivisible(self):
        dataset = dataset_ops.Dataset.range(9)
        multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset, [self._devices[1], self._devices[2]])

        self.evaluate(multi_device_iterator.initializer)
        for i in range(0, 8, 2):
            elem_on_1, elem_on_2 = multi_device_iterator.get_next()
            self.assertEqual(i, self.evaluate(elem_on_1))
            self.assertEqual(i + 1, self.evaluate(elem_on_2))
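        # The leftover ninth element is only available from the first device.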
        elem_on_1 = multi_device_iterator.get_next(self._devices[1])
        self.assertEqual(8, self.evaluate(elem_on_1))
        with self.assertRaises(errors.OutOfRangeError):
            elem_on_1, elem_on_2 = multi_device_iterator.get_next()
            self.evaluate(elem_on_1)
            self.evaluate(elem_on_2)

    @combinations.generate(test_base.default_test_combinations())
    def testGetNextAsOptional(self):
        dataset = dataset_ops.Dataset.range(10)
        multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset, [self._devices[1], self._devices[2]])

        self.evaluate(multi_device_iterator.initializer)
        for i in range(0, 10, 2):
            elem_on_1, elem_on_2 = multi_device_iterator.get_next_as_optional()
            has_elem_1, get_elem_1 = self.evaluate(
                [elem_on_1.has_value(),
                 elem_on_1.get_value()])
            has_elem_2, get_elem_2 = self.evaluate(
                [elem_on_2.has_value(),
                 elem_on_2.get_value()])
            self.assertTrue(has_elem_1)
            self.assertEqual(i, get_elem_1)
            self.assertTrue(has_elem_2)
            self.assertEqual(i + 1, get_elem_2)
        elem_on_1, elem_on_2 = multi_device_iterator.get_next_as_optional()
        has_elem_1 = elem_on_1.has_value()
        has_elem_2 = elem_on_2.has_value()
        self.assertFalse(self.evaluate(has_elem_1))
        self.assertFalse(self.evaluate(has_elem_2))
        with self.assertRaises(errors.InvalidArgumentError):
            elem_1 = elem_on_1.get_value()
            self.evaluate(elem_1)
        with self.assertRaises(errors.InvalidArgumentError):
            elem_2 = elem_on_2.get_value()
            self.evaluate(elem_2)

    @combinations.generate(test_base.default_test_combinations())
    def testUneven(self):
        dataset = dataset_ops.Dataset.range(10)
        multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset, [self._devices[1], self._devices[2]], max_buffer_size=4)

        self.evaluate(multi_device_iterator.initializer)
        for i in range(0, 10, 2):
            elem_on_1 = multi_device_iterator.get_next(self._devices[1])
            self.assertEqual(i, self.evaluate(elem_on_1))
        for i in range(0, 10, 2):
            elem_on_2 = multi_device_iterator.get_next(self._devices[2])
            self.assertEqual(i + 1, self.evaluate(elem_on_2))
        with self.assertRaises(errors.OutOfRangeError):
            elem_on_1, elem_on_2 = multi_device_iterator.get_next()
            self.evaluate(elem_on_1)
            self.evaluate(elem_on_2)

    @combinations.generate(test_base.graph_only_combinations())
    def testMultipleInitializationsGraph(self):
        dataset1 = dataset_ops.Dataset.range(1000)
        dataset2 = dataset_ops.Dataset.range(1000)
        dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
        multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset, [self._devices[1], self._devices[2]],
            prefetch_buffer_size=4)
        elem_on_1, elem_on_2 = multi_device_iterator.get_next()

        for _ in range(5):
            self.evaluate(multi_device_iterator.initializer)
            self.assertEqual([(0, 0), (1, 1)],
                             self.evaluate([elem_on_1, elem_on_2]))

    @combinations.generate(test_base.eager_only_combinations())
    def testMultipleInitializationsEager(self):
        dataset1 = dataset_ops.Dataset.range(1000)
        dataset2 = dataset_ops.Dataset.range(1000)
        dataset = dataset_ops.Dataset.zip((dataset1, dataset2))

        for _ in range(5):
            multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
                dataset, [self._devices[1], self._devices[2]],
                prefetch_buffer_size=4)
            self.evaluate(multi_device_iterator.initializer)
            elem_on_1, elem_on_2 = multi_device_iterator.get_next()
            self.assertEqual([(0, 0), (1, 1)],
                             self.evaluate([elem_on_1, elem_on_2]))

    @combinations.generate(test_base.default_test_combinations())
    def testOptimization(self):
        dataset = dataset_ops.Dataset.range(10)
        dataset = dataset.apply(testing.assert_next(["MemoryCacheImpl"]))
        dataset = dataset.skip(0)  # this should be optimized away
        dataset = dataset.cache()

        options = options_lib.Options()
        options.experimental_optimization.noop_elimination = True
        dataset = dataset.with_options(options)

        multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset, [self._devices[1], self._devices[2]])

        self.evaluate(multi_device_iterator.initializer)
        for i in range(0, 10, 2):
            elem_on_1, elem_on_2 = multi_device_iterator.get_next()
            self.assertEqual(i, self.evaluate(elem_on_1))
            self.assertEqual(i + 1, self.evaluate(elem_on_2))
        with self.assertRaises(errors.OutOfRangeError):
            elem_on_1, elem_on_2 = multi_device_iterator.get_next()
            self.evaluate(elem_on_1)
            self.evaluate(elem_on_2)
Example #10
class MemoryCacheTest(test_base.DatasetTestBase, parameterized.TestCase):

  @combinations.generate(test_base.default_test_combinations())
  def testCacheDatasetPassthrough(self):
    with ops.device("cpu:0"):
      repeat_count = variables.Variable(constant_op.constant(10, dtypes.int64))
      dataset = dataset_ops.Dataset.range(3).flat_map(
          lambda x: dataset_ops.Dataset.from_tensors(x).repeat(repeat_count))

      cached_dataset = dataset.cache().repeat(2)
      uncached_dataset = dataset.repeat(2)

      self.evaluate(repeat_count.initializer)
      # Needs to be initializable to capture the variable.
      cached_next = self.getNext(cached_dataset, requires_initialization=True)
      uncached_next = self.getNext(
          uncached_dataset, requires_initialization=True)
      for i in range(3):
        for _ in range(10):
          self.assertEqual(self.evaluate(cached_next()), i)
          self.assertEqual(self.evaluate(uncached_next()), i)

      self.evaluate(repeat_count.assign(0))

      # The uncached iterator should now be empty.
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(uncached_next())

      # The cached iterator replays from cache.
      for i in range(3):
        for _ in range(10):
          self.assertEqual(self.evaluate(cached_next()), i)

      # The cached iterator should now be empty.
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(cached_next())

  @combinations.generate(test_base.default_test_combinations())
  def testEmptyCacheReading(self):
    components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
                  np.array([9.0, 10.0, 11.0, 12.0]))

    repeat_dataset = (
        dataset_ops.Dataset.from_tensor_slices(components).repeat(0))
    cache_dataset = repeat_dataset.cache()

    self.assertDatasetProduces(cache_dataset, expected_output=[])

  @combinations.generate(test_base.default_test_combinations())
  def testConcurrentReaders(self):

    dataset_fn = lambda: dataset_ops.Dataset.range(5).cache()
    d1 = dataset_fn().map(lambda x: x + 1)
    d2 = dataset_fn().map(lambda x: x + 6)

    get_next1 = self.getNext(d1)

    self.assertEqual(1, self.evaluate(get_next1()))
    self.assertEqual(2, self.evaluate(get_next1()))
    self.assertEqual(3, self.evaluate(get_next1()))

    get_next2 = self.getNext(d2)

    self.assertEqual(6, self.evaluate(get_next2()))
    self.assertEqual(7, self.evaluate(get_next2()))
    self.assertEqual(4, self.evaluate(get_next1()))  # interleaved execution
    self.assertEqual([8, 5],
                     [self.evaluate(get_next2()),
                      self.evaluate(get_next1())])
    self.assertEqual(9, self.evaluate(get_next2()))
    self.assertEqual(10, self.evaluate(get_next2()))

    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next2())
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next1())

  @combinations.generate(test_base.default_test_combinations())
  def testCacheTakeRepeat(self):
    dataset = dataset_ops.Dataset.range(10).cache().take(5).repeat(2)

    expected_output = [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
    self.assertDatasetProduces(dataset, expected_output=expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testCacheRepeatEpochs(self):
    counter = variables.Variable(0)
    self.evaluate(counter.initializer)

    def increment_fn(x):
      counter.assign_add(1)
      return x

    dataset = dataset_ops.Dataset.range(10).map(increment_fn).cache().repeat(2)
    get_next = self.getNext(dataset, requires_initialization=True)

    # first epoch
    for i in range(10):
      self.assertEqual(i, self.evaluate(counter))
      self.assertEqual(i, self.evaluate(get_next()))
    # second epoch
    for i in range(10):
      self.assertEqual(10, self.evaluate(counter))
      self.assertEqual(i, self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
  def testCacheIterationEpochs(self):
    counter = variables.Variable(0)
    self.evaluate(counter.initializer)

    def increment_fn(x):
      counter.assign_add(1)
      return x

    dataset = dataset_ops.Dataset.range(10).map(increment_fn).cache()

    # first epoch
    i = 0
    for elem in dataset:
      self.assertEqual(i, self.evaluate(elem))
      i += 1
      self.assertEqual(i, self.evaluate(counter))

    # second epoch
    i = 0
    for elem in dataset:
      self.assertEqual(10, self.evaluate(counter))
      self.assertEqual(i, self.evaluate(elem))
      i += 1

  @combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
  def testCacheV2ResourceCapture(self):

    def make_dataset():
      ids = dataset_ops.Dataset.range(10)
      ids = ids.cache()

      def interleave_fn(dataset, _):
        return dataset

      dataset = dataset_ops.Dataset.range(1)
      dataset = dataset.interleave(functools.partial(interleave_fn, ids))
      return dataset

    results = []
    for elem in make_dataset():
      results.append(elem.numpy())

    self.assertAllEqual(results, range(10))

  @combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
  def testCacheV2ConcurrentIterators(self):

    dataset = dataset_ops.Dataset.range(10).cache()

    it1 = iter(dataset)
    it2 = iter(dataset)

    for i in range(10):
      self.assertEqual(next(it1), i)
      self.assertEqual(next(it2), i)

  @combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
  def testCacheKnownCardinality(self):

    # Check that a dataset which produces a random permutation of range(10)
    # ends up being cached when we read all of its elements but do not reach
    # EOF.
    dataset = dataset_ops.Dataset.range(10)
    dataset = dataset.shuffle(10, reshuffle_each_iteration=True).cache()

    it = iter(dataset)

    results = []
    for _ in range(10):
      results.append(next(it))

    it = iter(dataset)
    for i in range(10):
      self.assertEqual(next(it), results[i])

  @combinations.generate(test_base.eager_only_combinations())
  def testCheckpointFinishedCache(self):
    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements)
    ds = ds.cache()

    iterator = iter(ds)
    for i in range(num_elements):
      self.assertEqual(next(iterator).numpy(), i)
    ckpt = trackable_utils.Checkpoint(iterator=iterator)
    manager = checkpoint_management.CheckpointManager(
        ckpt, self.get_temp_dir(), max_to_keep=1)
    manager.save()
    manager.restore_or_initialize()
    with self.assertRaises(StopIteration):
      next(iterator)

  @combinations.generate(test_base.eager_only_combinations())
  def testCheckpointLargeCache(self):
    # Tensor of size 100M
    dataset = dataset_ops.Dataset.from_tensors(
        array_ops.ones((25, 1000, 1000), dtype=dtypes.float32))
    # Repeat 25 times to exceed the 2G proto limit
    dataset = dataset.repeat(25)
    dataset = dataset.cache()

    # Iterate to fill most of the cache.
    iterator = iter(dataset)
    for _ in range(23):
      next(iterator)
    ckpt = trackable_utils.Checkpoint(iterator=iterator)
    manager = checkpoint_management.CheckpointManager(
        ckpt, self.get_temp_dir(), max_to_keep=1)
    manager.save()
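
Both checkpoint tests above save iterator state through trackable_utils.Checkpoint. A minimal sketch of the same idea with the public API, assuming TF 2.x eager mode; the directory and element counts are illustrative.

import tensorflow as tf

ds = tf.data.Dataset.range(10).cache()
it = iter(ds)
for _ in range(4):  # consume a few elements; the in-memory cache is partially filled
  next(it)

ckpt = tf.train.Checkpoint(iterator=it)
manager = tf.train.CheckpointManager(ckpt, "/tmp/cache_iter_ckpt", max_to_keep=1)
manager.save()

# Restoring should resume iteration (including the cache state) at element 4.
ckpt.restore(manager.latest_checkpoint)
print(next(it).numpy())
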
Example #11
class CacheCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                          parameterized.TestCase):

  def setUp(self):
    self.range_size = 10
    self.num_repeats = 3
    self.num_outputs = self.range_size * self.num_repeats
    self.cache_file_prefix = "test"

  def make_dataset_fn(self, is_memory):
    if is_memory:
      filename = ""
    else:
      filename = os.path.join(self.get_temp_dir(), self.cache_file_prefix)

    def ds_fn():
      return dataset_ops.Dataset.range(self.range_size).cache(filename).repeat(
          self.num_repeats)

    return ds_fn

  def expected_outputs(self):
    return list(range(self.range_size)) * self.num_repeats

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(is_memory=[True, False])))
  def testCheckpointBeforeOneEpoch(self, is_memory):
    ds_fn = self.make_dataset_fn(is_memory)

    # Generate 5 entries from iterator and save checkpoint.
    outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False)
    self.assertSequenceEqual(outputs, range(5))

    # Restore from checkpoint and produce the rest of the elements from the
    # iterator.
    outputs.extend(
        self.gen_outputs(
            ds_fn, [],
            self.num_outputs - 5,
            ckpt_saved=True,
            verify_exhausted=False))
    self.assertSequenceEqual(outputs, self.expected_outputs())

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(is_memory=[True, False])))
  def testCheckpointBeforeOneEpochThenRunFewSteps(self, is_memory):
    ds_fn = self.make_dataset_fn(is_memory)

    # Generate 8 entries from iterator but save checkpoint after producing 5.
    outputs = self.gen_outputs(
        ds_fn, [5], 8, verify_exhausted=False, save_checkpoint_at_end=False)
    self.assertSequenceEqual(outputs, range(8))

    outputs = outputs[:5]
    outputs.extend(
        self.gen_outputs(
            ds_fn, [],
            self.num_outputs - 5,
            ckpt_saved=True,
            verify_exhausted=False))
    self.assertSequenceEqual(outputs, self.expected_outputs())

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(is_memory=[True, False])))
  def testCheckpointAfterOneEpoch(self, is_memory):
    ds_fn = self.make_dataset_fn(is_memory)

    # Generate 15 entries from iterator and save checkpoint.
    outputs = self.gen_outputs(ds_fn, [], 15, verify_exhausted=False)
    self.assertSequenceEqual(outputs, list(range(10)) + list(range(5)))

    # Restore from checkpoint and produce the rest of the elements from the
    # iterator.
    outputs.extend(
        self.gen_outputs(
            ds_fn, [],
            self.num_outputs - 15,
            ckpt_saved=True,
            verify_exhausted=False))
    self.assertSequenceEqual(outputs, self.expected_outputs())

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(is_memory=[True, False])))
  def testCheckpointAfterOneEpochThenRunFewSteps(self, is_memory):
    ds_fn = self.make_dataset_fn(is_memory)

    # Generate 18 entries from iterator but save checkpoint after producing 15.
    outputs = self.gen_outputs(
        ds_fn, [15], 18, verify_exhausted=False, save_checkpoint_at_end=False)
    self.assertSequenceEqual(outputs, list(range(10)) + list(range(8)))

    outputs = list(range(10)) + list(range(5)) + self.gen_outputs(
        ds_fn, [],
        self.num_outputs - 15,
        ckpt_saved=True,
        verify_exhausted=False)
    self.assertSequenceEqual(outputs, list(range(10)) * 3)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(is_memory=[True, False])))
  def testCheckpointBeforeOneEpochButRunCompleteEpoch(self, is_memory):
    ds_fn = self.make_dataset_fn(is_memory)

    # Generate 13 entries from iterator but save checkpoint after producing 5.
    outputs = self.gen_outputs(
        ds_fn, [5], 13, verify_exhausted=False, save_checkpoint_at_end=False)
    self.assertSequenceEqual(outputs, list(range(10)) + list(range(3)))

    # Since we ran for more than one epoch, the cache was completely written.
    # The ckpt was saved when the iterator was in cache-write mode. Test that
    # the iterator falls back to read mode after restoring if the cache has
    # been completely written.

    outputs = list(range(5)) + self.gen_outputs(
        ds_fn, [],
        self.num_outputs - 5,
        ckpt_saved=True,
        verify_exhausted=False)
    self.assertSequenceEqual(outputs, list(range(10)) * 3)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(is_memory=[True, False])))
  def testCheckpointUnusedWriterIterator(self, is_memory):
    ds_fn = self.make_dataset_fn(is_memory)

    # Checkpoint before get_next is called even once.
    outputs = self.gen_outputs(ds_fn, [], 0, verify_exhausted=False)
    self.assertSequenceEqual(outputs, [])

    outputs = self.gen_outputs(
        ds_fn, [], self.num_outputs, ckpt_saved=True, verify_exhausted=False)
    self.assertSequenceEqual(outputs, list(range(10)) * 3)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(is_memory=[True, False])))
  def testCheckpointUnusedMidwayWriterIterator(self, is_memory):
    ds_fn = self.make_dataset_fn(is_memory)

    # Produce 5 elements and checkpoint.
    outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False)
    self.assertSequenceEqual(outputs, range(5))

    # Restore from checkpoint, then produce no elements and checkpoint.
    outputs.extend(
        self.gen_outputs(ds_fn, [], 0, ckpt_saved=True, verify_exhausted=False))
    self.assertSequenceEqual(outputs, range(5))

    # Restore from checkpoint and produce rest of the elements.
    outputs.extend(
        self.gen_outputs(
            ds_fn, [],
            self.num_outputs - 5,
            ckpt_saved=True,
            verify_exhausted=False))
    self.assertSequenceEqual(outputs, list(range(10)) * 3)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(is_memory=[True, False])))
  def testUnusedCheckpointError(self, is_memory):
    ds_fn = self.make_dataset_fn(is_memory)

    # Produce 5 elements and save ckpt.
    outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False)
    self.assertSequenceEqual(outputs, range(5))

    if is_memory:
      outputs = self.gen_outputs(
          ds_fn, [], self.num_outputs, verify_exhausted=False)
      self.assertSequenceEqual(outputs, self.expected_outputs())
    else:
      # Since the complete cache has not been written, a new iterator that does
      # not restore the checkpoint will raise an error because a partial cache
      # shard already exists.
      with self.assertRaises(errors.AlreadyExistsError):
        outputs = self.gen_outputs(
            ds_fn, [], self.num_outputs, verify_exhausted=False)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(is_memory=[True, False])))
  def testIgnoreCheckpointIfCacheWritten(self, is_memory):
    ds_fn = self.make_dataset_fn(is_memory)

    # Produce 15 elements and save ckpt. This will write the complete cache.
    outputs = self.gen_outputs(ds_fn, [], 15, verify_exhausted=False)
    self.assertSequenceEqual(outputs, list(range(10)) + list(range(5)))

    # Build the iterator again but do not restore from ckpt. Since the cache
    # has already been written we should be able to use it.
    outputs = self.gen_outputs(
        ds_fn, [], self.num_outputs, verify_exhausted=False)
    self.assertSequenceEqual(outputs, list(range(10)) * 3)
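
CacheCheckpointTest exercises both the in-memory cache (empty filename) and the file-backed cache. A minimal sketch of the file-backed variant via the public API, assuming TF 2.x eager mode; the cache path is illustrative.

import tensorflow as tf

cache_prefix = "/tmp/range_cache"  # shard files are created next to this prefix
ds = tf.data.Dataset.range(5).cache(cache_prefix).repeat(3)

# The first pass over range(5) writes the cache; later passes read it back.
print(list(ds.as_numpy_iterator()))  # [0, 1, 2, 3, 4] repeated three times
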
Example #12
class OptimizationTest(test_base.DatasetTestBase, parameterized.TestCase):
    @combinations.generate(test_base.default_test_combinations())
    def testOptimizationStatefulFunction(self):
        dataset = dataset_ops.Dataset.range(10).map(
            lambda _: random_ops.random_uniform([])).batch(10)
        options = options_lib.Options()
        options.experimental_optimization.apply_default_optimizations = False
        dataset = dataset.with_options(options)
        get_next = self.getNext(dataset)
        self.evaluate(get_next())

    # TODO(b/123354468)
    @combinations.generate(test_base.graph_only_combinations())
    def testOptimizationLargeInputFromTensor(self):
        input_t = array_ops.placeholder(dtypes.int32, (None, None, None))
        dataset = dataset_ops.Dataset.from_tensors(input_t)
        options = options_lib.Options()
        options.experimental_optimization.apply_default_optimizations = False
        dataset = dataset.with_options(options)
        iterator = dataset_ops.make_initializable_iterator(dataset)
        init_op = iterator.initializer
        get_next = iterator.get_next()

        with self.cached_session() as sess:
            sess.run(init_op, {input_t: np.ones([512, 1024, 1025], np.int32)})
            self.evaluate(get_next)

    # TODO(b/123354468)
    @combinations.generate(test_base.graph_only_combinations())
    def testOptimizationLargeInputFromTensorSlices(self):
        input_t = array_ops.placeholder(dtypes.int32, (None, None, None, None))
        dataset = dataset_ops.Dataset.from_tensor_slices(input_t)
        options = options_lib.Options()
        options.experimental_optimization.apply_default_optimizations = False
        dataset = dataset.with_options(options)
        iterator = dataset_ops.make_initializable_iterator(dataset)
        init_op = iterator.initializer
        get_next = iterator.get_next()

        with self.cached_session() as sess:
            sess.run(init_op,
                     {input_t: np.ones([1, 512, 1024, 1025], np.int32)})
            self.evaluate(get_next)

    @combinations.generate(test_base.default_test_combinations())
    def testOptimizationNestedDataset(self):
        def flat_map_fn(_):
            dataset = dataset_ops.Dataset.from_tensors(0)
            dataset = dataset.apply(testing.assert_next(["MemoryCacheImpl"]))
            dataset = dataset.skip(0)  # Should be removed by noop elimination
            dataset = dataset.cache()
            return dataset

        dataset = dataset_ops.Dataset.range(1)
        dataset = dataset.flat_map(flat_map_fn)
        options = options_lib.Options()
        options.experimental_optimization.apply_default_optimizations = False
        options.experimental_optimization.noop_elimination = True
        dataset = dataset.with_options(options)
        self.assertDatasetProduces(dataset, expected_output=[0])

    @combinations.generate(test_base.default_test_combinations())
    def testOptimizationNestedDatasetWithModifiedRetval(self):
        def flat_map_fn(_):
            dataset = dataset_ops.Dataset.from_tensors(0)
            dataset = dataset.apply(testing.assert_next(["MapAndBatch"]))
            # Should be fused by map and batch fusion
            dataset = dataset.map(lambda x: x)
            dataset = dataset.batch(1)
            return dataset

        dataset = dataset_ops.Dataset.range(1)
        dataset = dataset.flat_map(flat_map_fn)

        options = options_lib.Options()
        options.experimental_optimization.apply_default_optimizations = False
        options.experimental_optimization.map_and_batch_fusion = True
        dataset = dataset.with_options(options)
        self.assertDatasetProduces(dataset, expected_output=[[0]])

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(autotune=[True, False, None]),
            combinations.combine(map_parallelization=[True, False, None])))
    def testOptimizationMapParallelization(self, autotune,
                                           map_parallelization):
        dataset = dataset_ops.Dataset.range(5)
        if autotune is not False and map_parallelization is not False:  # pylint: disable=g-bool-id-comparison
            dataset = dataset.apply(testing.assert_next(["ParallelMap"]))
        else:
            dataset = dataset.apply(testing.assert_next(["Map"]))
        dataset = dataset.map(lambda x: x + 1)

        options = options_lib.Options()
        if autotune is not None:
            options.autotune.enabled = autotune
        if map_parallelization is not None:
            options.experimental_optimization.map_parallelization = (
                map_parallelization)
        dataset = dataset.with_options(options)

        self.assertDatasetProduces(dataset, expected_output=list(range(1, 6)))

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(existing_prefetch=[True, False]),
            combinations.combine(autotune=[True, False]),
            combinations.combine(set_env=[True, False])))
    def testOptimizationInjectPrefetch(self, existing_prefetch, autotune,
                                       set_env):
        if set_env:
            os.environ["TF_DATA_EXPERIMENT_OPT_IN"] = "inject_prefetch"
            os.environ["TF_JOB_NAME"] = "test_job"

        dataset = dataset_ops.Dataset.range(5)
        dataset = dataset.map(lambda x: x + 1,
                              num_parallel_calls=dataset_ops.AUTOTUNE)
        if existing_prefetch:
            dataset = dataset.prefetch(1)
        if autotune and set_env and not existing_prefetch:
            dataset = dataset.apply(testing.assert_next(["Prefetch", "Root"]))
        else:
            dataset = dataset.apply(testing.assert_next(["Root"]))

        options = options_lib.Options()
        options.autotune.enabled = autotune
        dataset = dataset.with_options(options)

        self.assertDatasetProduces(dataset, expected_output=list(range(1, 6)))

        if set_env:
            del os.environ["TF_DATA_EXPERIMENT_OPT_IN"]
            del os.environ["TF_JOB_NAME"]

    # Reference variables are not supported in eager mode.
    @combinations.generate(
        combinations.times(test_base.graph_only_combinations(),
                           _captured_refvar_test_combinations()))
    def testOptimizationWithCapturedRefVar(self, dataset_fn):
        """Tests that default optimizations are disabled with ref variables."""
        variable = variable_scope.get_variable("v",
                                               initializer=0,
                                               use_resource=False)
        assign_op = variable.assign_add(1)
        unoptimized_dataset = dataset_fn(variable)

        options = options_lib.Options()
        options.experimental_optimization.apply_default_optimizations = False
        options.experimental_optimization.noop_elimination = True
        options.experimental_optimization.map_and_batch_fusion = True
        optimized_dataset = unoptimized_dataset.with_options(options)
        optimized_it = dataset_ops.make_initializable_iterator(
            optimized_dataset)

        # Check that outputs are the same in the optimized and unoptimized cases,
        # when the variable value is changing.
        unoptimized_it = dataset_ops.make_initializable_iterator(
            unoptimized_dataset)
        with ops.control_dependencies([assign_op]):
            unoptimized_output = unoptimized_it.get_next()
            optimized_output = optimized_it.get_next()

        self.evaluate(variable.initializer)
        self.evaluate((unoptimized_it.initializer, optimized_it.initializer))
        while True:
            try:
                unoptimized, optimized = self.evaluate(
                    (unoptimized_output, optimized_output))
                self.assertEqual(unoptimized, optimized)
            except errors.OutOfRangeError:
                break
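
testOptimizationMapParallelization toggles autotuning and map parallelization through tf.data options. A minimal sketch of the same switches on the public API, assuming a TF release that exposes the autotune option namespace used above.

import tensorflow as tf

ds = tf.data.Dataset.range(5).map(lambda x: x + 1)

options = tf.data.Options()
options.autotune.enabled = True
options.experimental_optimization.map_parallelization = True
ds = ds.with_options(options)  # the Map stage may be rewritten to a ParallelMap

print(list(ds.as_numpy_iterator()))  # [1, 2, 3, 4, 5]
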
Example #13
class MapAndFilterFusionTest(test_base.DatasetTestBase,
                             parameterized.TestCase):
    def _testDataset(self, dataset, function, predicate):
        expected_output = []
        for x in range(10):
            r = function(x)
            if isinstance(r, tuple):
                b = predicate(*r)  # Pass tuple as multiple arguments.
            else:
                b = predicate(r)
            if self.evaluate(b):
                expected_output.append(r)
        self.assertDatasetProduces(dataset, expected_output=expected_output)

    def _testMapAndFilterFusion(self, function, predicate):
        dataset = dataset_ops.Dataset.range(10).apply(
            testing.assert_next(["Map", "Filter",
                                 "Map"])).map(function).filter(predicate)
        options = dataset_ops.Options()
        options.experimental_optimization.apply_default_optimizations = False
        options.experimental_optimization.map_and_filter_fusion = True
        dataset = dataset.with_options(options)
        self._testDataset(dataset, function, predicate)

    @combinations.generate(test_base.default_test_combinations())
    def testMapAndFilterFusionScalar(self):
        identity = lambda x: x
        increment = lambda x: x + 1
        minus_five = lambda x: x - 5

        def increment_and_square(x):
            y = x + 1
            return y * y

        functions = [identity, increment, minus_five, increment_and_square]

        take_all = lambda x: constant_op.constant(True)
        is_zero = lambda x: math_ops.equal(x, 0)
        is_even = lambda x: math_ops.equal(x % 2, 0)
        greater = lambda x: math_ops.greater(x + 5, 0)
        predicates = [take_all, is_zero, is_even, greater]

        for function in functions:
            for predicate in predicates:
                self._testMapAndFilterFusion(function, predicate)

    @combinations.generate(test_base.default_test_combinations())
    def testMapAndFilterFusionTuple(self):
        replicate = lambda x: (x, x)
        with_two = lambda x: (x, 2)
        functions = [replicate, with_two]
        take_all = lambda x, y: constant_op.constant(True)
        is_zero = lambda x, y: math_ops.equal(
            x * math_ops.cast(y, dtypes.int64), 0)
        predicates = [take_all, is_zero]

        for function in functions:
            for predicate in predicates:
                self._testMapAndFilterFusion(function, predicate)

    @combinations.generate(test_base.default_test_combinations())
    def testCapturedInputs(self):
        a = constant_op.constant(3, dtype=dtypes.int64)
        b = constant_op.constant(4, dtype=dtypes.int64)
        some_tensor = math_ops.mul(a, b)
        function = lambda x: x * x

        def predicate(y):
            return math_ops.less(math_ops.cast(y, dtypes.int64), some_tensor)

        # We are currently not supporting functions with captured inputs.
        dataset = dataset_ops.Dataset.range(10).apply(
            testing.assert_next(["Map",
                                 "Filter"])).map(function).filter(predicate)
        options = dataset_ops.Options()
        options.experimental_optimization.apply_default_optimizations = False
        options.experimental_optimization.map_and_filter_fusion = True
        dataset = dataset.with_options(options)
        self._testDataset(dataset, function, predicate)
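
The fusion tests above only change how the pipeline is executed, not what it produces. A minimal sketch of enabling map-and-filter fusion on an ordinary pipeline, assuming a TF release that still exposes this experimental option.

import tensorflow as tf

ds = tf.data.Dataset.range(10).map(lambda x: x * x).filter(lambda y: y < 20)

options = tf.data.Options()
options.experimental_optimization.map_and_filter_fusion = True
ds = ds.with_options(options)  # Map and Filter may be fused into a single stage

print(list(ds.as_numpy_iterator()))  # [0, 1, 4, 9, 16]
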
Example #14
class TFRecordWriterTest(test_base.DatasetTestBase, parameterized.TestCase):
    def setUp(self):
        super(TFRecordWriterTest, self).setUp()
        self._num_records = 8

    def writer_fn(self, filename, compression_type=""):
        input_dataset = readers.TFRecordDataset([filename], compression_type)
        return writers.TFRecordWriter(self._outputFilename(),
                                      compression_type).write(input_dataset)

    def _record(self, i):
        return compat.as_bytes("Record %d" % (i))

    def _createFile(self, options=None):
        filename = self._inputFilename()
        writer = python_io.TFRecordWriter(filename, options)
        for i in range(self._num_records):
            writer.write(self._record(i))
        writer.close()
        return filename

    def _inputFilename(self):
        return os.path.join(self.get_temp_dir(), "tf_record.in.txt")

    def _outputFilename(self):
        return os.path.join(self.get_temp_dir(), "tf_record.out.txt")

    @combinations.generate(test_base.default_test_combinations())
    def testWrite(self):
        self.evaluate(self.writer_fn(self._createFile()))
        for i, r in enumerate(
                tf_record.tf_record_iterator(self._outputFilename())):
            self.assertAllEqual(self._record(i), r)

    @combinations.generate(test_base.default_test_combinations())
    def testWriteZLIB(self):
        options = tf_record.TFRecordOptions(
            tf_record.TFRecordCompressionType.ZLIB)
        self.evaluate(
            self.writer_fn(self._createFile(options), compression_type="ZLIB"))
        for i, r in enumerate(
                tf_record.tf_record_iterator(self._outputFilename(),
                                             options=options)):
            self.assertAllEqual(self._record(i), r)

    @combinations.generate(test_base.default_test_combinations())
    def testWriteGZIP(self):
        options = tf_record.TFRecordOptions(
            tf_record.TFRecordCompressionType.GZIP)
        self.evaluate(
            self.writer_fn(self._createFile(options), compression_type="GZIP"))
        for i, r in enumerate(
                tf_record.tf_record_iterator(self._outputFilename(),
                                             options=options)):
            self.assertAllEqual(self._record(i), r)

    @combinations.generate(test_base.default_test_combinations())
    def testFailDataset(self):
        with self.assertRaises(TypeError):
            writers.TFRecordWriter(self._outputFilename(), "").write("whoops")

    @combinations.generate(test_base.default_test_combinations())
    def testFailDType(self):
        input_dataset = dataset_ops.Dataset.from_tensors(10)
        with self.assertRaises(TypeError):
            writers.TFRecordWriter(self._outputFilename(),
                                   "").write(input_dataset)

    @combinations.generate(test_base.default_test_combinations())
    def testFailShape(self):
        input_dataset = dataset_ops.Dataset.from_tensors([["hello"],
                                                          ["world"]])
        with self.assertRaises(TypeError):
            writers.TFRecordWriter(self._outputFilename(),
                                   "").write(input_dataset)

    @combinations.generate(test_base.default_test_combinations())
    def testSideEffect(self):
        def writer_fn():
            input_dataset = readers.TFRecordDataset(self._createFile())
            return writers.TFRecordWriter(
                self._outputFilename()).write(input_dataset)

        @function.defun
        def fn():
            _ = writer_fn()
            return "hello"

        self.assertEqual(self.evaluate(fn()), b"hello")
        for i, r in enumerate(
                tf_record.tf_record_iterator(self._outputFilename())):
            self.assertAllEqual(self._record(i), r)

    @combinations.generate(test_base.default_test_combinations())
    def testShard(self):
        filename = self._createFile()
        dataset = readers.TFRecordDataset([filename])

        def reduce_func(key, dataset):
            shard_filename = string_ops.string_join(
                [filename, string_ops.as_string(key)])
            writer = writers.TFRecordWriter(shard_filename)
            writer.write(dataset.map(lambda _, x: x))
            return dataset_ops.Dataset.from_tensors(shard_filename)

        dataset = dataset.enumerate()
        dataset = dataset.apply(
            grouping.group_by_window(lambda i, _: i % 2, reduce_func,
                                     dtypes.int64.max))

        get_next = self.getNext(dataset)
        for i in range(2):
            shard_filename = (filename + str(i)).encode()
            self.assertEqual(self.evaluate(get_next()), shard_filename)
            for j, r in enumerate(
                    tf_record.tf_record_iterator(shard_filename)):
                self.assertAllEqual(self._record(i + 2 * j), r)
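
The writer under test wraps writers.TFRecordWriter, which corresponds to the public tf.data.experimental.TFRecordWriter. A minimal round-trip sketch, assuming TF 2.x eager mode; the file path and record contents are illustrative.

import tensorflow as tf

path = "/tmp/example.tfrecord"
records = tf.data.Dataset.from_tensor_slices([b"first", b"second", b"third"])

# write() expects a dataset of scalar strings and serializes it to the file.
tf.data.experimental.TFRecordWriter(path).write(records)

readback = tf.data.TFRecordDataset([path])
print(list(readback.as_numpy_iterator()))  # [b'first', b'second', b'third']
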
Example #15
class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):

  @combinations.generate(test_base.default_test_combinations())
  def testBasic(self):
    components = (
        np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
        np.array([9.0, 10.0, 11.0, 12.0])
    )

    def dataset_fn(count=5, buffer_size=None, seed=0):
      repeat_dataset = (
          dataset_ops.Dataset.from_tensor_slices(components).repeat(count))
      if buffer_size:
        shuffle_dataset = repeat_dataset.shuffle(buffer_size, seed)

        self.assertEqual(
            tuple([c.shape[1:] for c in components]),
            dataset_ops.get_legacy_output_shapes(shuffle_dataset))
        return shuffle_dataset
      else:
        return repeat_dataset

    # First run without shuffling to collect the "ground truth".
    get_next = self.getNext(dataset_fn())
    unshuffled_elements = []
    for _ in range(20):
      unshuffled_elements.append(self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

    # Assert that the shuffled dataset has the same elements as the
    # "ground truth".
    get_next = self.getNext(dataset_fn(buffer_size=100, seed=37))
    shuffled_elements = []
    for _ in range(20):
      shuffled_elements.append(self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
    self.assertAllEqual(sorted(unshuffled_elements), sorted(shuffled_elements))

    # Assert that shuffling twice with the same seeds gives the same sequence.
    get_next = self.getNext(dataset_fn(buffer_size=100, seed=37))
    reshuffled_elements_same_seed = []
    for _ in range(20):
      reshuffled_elements_same_seed.append(self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
    self.assertEqual(shuffled_elements, reshuffled_elements_same_seed)

    # Assert that shuffling twice with a different seed gives a different
    # permutation of the same elements.
    get_next = self.getNext(dataset_fn(buffer_size=100, seed=137))
    reshuffled_elements_different_seed = []
    for _ in range(20):
      reshuffled_elements_different_seed.append(self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
    self.assertNotEqual(shuffled_elements, reshuffled_elements_different_seed)
    self.assertAllEqual(
        sorted(shuffled_elements), sorted(reshuffled_elements_different_seed))

    # Assert that the shuffled dataset has the same elements as the
    # "ground truth" when the buffer size is smaller than the input
    # dataset.
    get_next = self.getNext(dataset_fn(buffer_size=2, seed=37))
    reshuffled_elements_small_buffer = []
    for _ in range(20):
      reshuffled_elements_small_buffer.append(self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
    self.assertAllEqual(
        sorted(unshuffled_elements), sorted(reshuffled_elements_small_buffer))

    # Test the case of shuffling an empty dataset.
    get_next = self.getNext(dataset_fn(count=0, buffer_size=100, seed=37))

    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @combinations.generate(combinations.combine(tf_api_version=1, mode="graph"))
  def testSeedZero(self):
    """Test for same behavior when the seed is a Python or Tensor zero."""
    iterator = dataset_ops.make_one_shot_iterator(
        dataset_ops.Dataset.range(10).shuffle(10, seed=0))
    get_next = iterator.get_next()

    elems = []
    with self.cached_session() as sess:
      for _ in range(10):
        elems.append(sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

    seed_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
    iterator = dataset_ops.make_initializable_iterator(
        dataset_ops.Dataset.range(10).shuffle(10, seed=seed_placeholder))
    get_next = iterator.get_next()

    with self.cached_session() as sess:
      sess.run(iterator.initializer, feed_dict={seed_placeholder: 0})
      for elem in elems:
        self.assertEqual(elem, sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  @combinations.generate(test_base.default_test_combinations())
  def testDefaultArguments(self):
    components = [0, 1, 2, 3, 4]
    dataset = dataset_ops.Dataset.from_tensor_slices(components).shuffle(
        5).repeat()
    get_next = self.getNext(dataset)
    counts = collections.defaultdict(lambda: 0)
    for _ in range(10):
      for _ in range(5):
        counts[self.evaluate(get_next())] += 1

    for i in range(5):
      self.assertEqual(10, counts[i])

  @combinations.generate(
      combinations.times(
          test_base.graph_only_combinations(),
          combinations.combine(reshuffle=[True, False]),
          combinations.combine(graph_seed=38, op_seed=None) +
          combinations.combine(graph_seed=None, op_seed=42) +
          combinations.combine(graph_seed=38, op_seed=42)))
  def testShuffleSeed(self, reshuffle, graph_seed, op_seed):
    results = []
    for _ in range(2):
      with ops.Graph().as_default() as g:
        random_seed.set_random_seed(graph_seed)
        dataset = dataset_ops.Dataset.range(10).shuffle(
            10, seed=op_seed, reshuffle_each_iteration=reshuffle).repeat(3)
        iterator = dataset_ops.make_one_shot_iterator(dataset)
        next_element = iterator.get_next()

        run_results = []
        with self.session(graph=g) as sess:
          for _ in range(30):
            run_results.append(sess.run(next_element))
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(next_element)
        results.append(run_results)

    self.assertAllEqual(results[0], results[1])

  # TODO(b/117581999): enable this test for eager-mode.
  @combinations.generate(
      combinations.times(
          test_base.graph_only_combinations(),
          combinations.combine(
              reshuffle=[True, False], initializable=[True, False])))
  def testMultipleIterators(self, reshuffle, initializable):
    with ops.Graph().as_default() as g:
      dataset = dataset_ops.Dataset.range(100).shuffle(
          10, reshuffle_each_iteration=reshuffle).repeat(3)

      if initializable:
        iterators = [dataset_ops.make_initializable_iterator(dataset)
                     for _ in range(2)]
      else:
        iterators = [dataset_ops.make_one_shot_iterator(dataset)
                     for _ in range(2)]

      results = []
      with self.session(graph=g) as sess:
        for iterator in iterators:
          if initializable:
            sess.run(iterator.initializer)
          next_element = iterator.get_next()
          run_results = []
          for _ in range(300):
            run_results.append(sess.run(next_element))
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(next_element)

          results.append(run_results)

        self.assertNotEqual(results[0], results[1])

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(reshuffle=[True, False], seed=[None, 42])))
  def testReshuffleRepeatEpochs(self, reshuffle, seed):
    dataset = dataset_ops.Dataset.range(10).shuffle(
        10, seed=seed, reshuffle_each_iteration=reshuffle).repeat(2)
    next_element = self.getNext(dataset)

    first_epoch = []
    for _ in range(10):
      first_epoch.append(self.evaluate(next_element()))

    second_epoch = []
    for _ in range(10):
      second_epoch.append(self.evaluate(next_element()))

    self.assertEqual(first_epoch == second_epoch, not reshuffle)

  @combinations.generate(
      combinations.times(
          combinations.combine(tf_api_version=2, mode="eager"),
          combinations.combine(reshuffle=[True, False], seed=[None, 42])))
  def testReshuffleIterationEpochs(self, reshuffle, seed):
    dataset = dataset_ops.Dataset.range(10).shuffle(
        10, seed=seed, reshuffle_each_iteration=reshuffle)

    first_epoch = []
    for elem in dataset:
      first_epoch.append(elem.numpy())

    second_epoch = []
    for elem in dataset:
      second_epoch.append(elem.numpy())

    self.assertEqual(first_epoch == second_epoch, not reshuffle)

  @combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
  def testShuffleV2ResourceCapture(self):

    def make_dataset():
      ids = dataset_ops.Dataset.range(10)
      ids = ids.shuffle(1)

      def interleave_fn(dataset, _):
        return dataset

      dataset = dataset_ops.Dataset.range(1)
      dataset = dataset.interleave(functools.partial(interleave_fn, ids))
      return dataset

    results = []
    for elem in make_dataset():
      results.append(elem.numpy())

    self.assertAllEqual(results, range(10))

  @combinations.generate(
      combinations.times(
          test_base.eager_only_combinations(),
          combinations.combine(reshuffle=[True, False], seed=[None, 42])))
  def testReshuffleSeparateTransformations(self, reshuffle, seed):
    dataset = dataset_ops.Dataset.range(10)

    first_epoch = []
    for elem in dataset.shuffle(
        10, seed=seed, reshuffle_each_iteration=reshuffle):
      first_epoch.append(elem.numpy())

    second_epoch = []
    for elem in dataset.shuffle(
        10, seed=seed, reshuffle_each_iteration=reshuffle):
      second_epoch.append(elem.numpy())

    self.assertEqual(first_epoch != second_epoch, seed is None)

  @combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
  def testShuffleV2InFunction(self):
    counter_var = variables.Variable(0)

    @function.defun
    def consume():
      ds = dataset_ops.Dataset.range(10)
      ds = ds.shuffle(1)
      for _ in ds:
        counter_var.assign(counter_var + 1)

    consume()
    self.assertAllEqual(self.evaluate(counter_var), 10)

  @combinations.generate(test_base.default_test_combinations())
  def testEmptyDataset(self):
    dataset = dataset_ops.Dataset.from_tensors(1)

    def map_fn(x):
      with ops.control_dependencies([check_ops.assert_equal(x, 0)]):
        return x

    dataset = dataset.map(map_fn)
    dataset = dataset.cache()
    dataset = dataset.shuffle(buffer_size=10).repeat()

    get_next = self.getNext(dataset)

    # First time around, we get an error for the failed assertion.
    with self.assertRaises(errors.InvalidArgumentError):
      self.evaluate(get_next())

    # Second time around, we get an EOF because the cached dataset is empty.
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
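
Several tests above hinge on reshuffle_each_iteration: with it disabled and a fixed seed, every pass over the dataset repeats the same order. A minimal sketch of the contrast, assuming TF 2.x eager mode; for a tiny range the "fresh" orders can still coincide by chance.

import tensorflow as tf

fixed = tf.data.Dataset.range(5).shuffle(5, seed=42, reshuffle_each_iteration=False)
fresh = tf.data.Dataset.range(5).shuffle(5, seed=42, reshuffle_each_iteration=True)

# The same permutation on every pass:
assert list(fixed.as_numpy_iterator()) == list(fixed.as_numpy_iterator())

# Typically (not guaranteed for tiny datasets) a different permutation per pass:
print(list(fresh.as_numpy_iterator()), list(fresh.as_numpy_iterator()))
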
Example #16
class BucketBySequenceLengthTest(test_base.DatasetTestBase,
                                 parameterized.TestCase):
    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(param_no_padding=[True, False])))
    def testBucketDropRemainder(self, param_no_padding):

        boundaries = [10, 20, 30]
        batch_sizes = [10, 8, 4, 2]
        lengths = [8, 13, 25, 35]

        n_bucket_elements = [28, 7, 6, 5]
        n_expected_batches = 5

        # Expected sequence lengths of the individual batches.
        expected_lengths = []

        # Expected sum of all batches with an equal sequence length.
        # <seq-length>: <expected-total-sum>
        expected_sums = {}

        # Expected batch sizes of batches depending on the sequence length.
        # <seq-length>: [batch1_size, ..., batchN_size]
        expected_batch_sizes = {}

        for length, batch_size, bucket_elements in zip(lengths, batch_sizes,
                                                       n_bucket_elements):
            # Calculate the expected sum across all batches of a specific sequence length.
            expected_sums[length] = \
                (bucket_elements - bucket_elements % batch_size) * length
            # Calculate the expected occurrence of individual batch sizes.
            expected_batch_sizes[length] = \
                [batch_size] * (bucket_elements // batch_size)
            # Calculate the expected occurrence of individual sequence lengths.
            expected_lengths.extend([length] * (bucket_elements // batch_size))

        def build_dataset(sparse):
            def _generator():
                # Produce 1 batch for each bucket
                elements = []
                for bucket_elements, length in zip(n_bucket_elements, lengths):
                    # Using only full sequences (as opposed to the strategy
                    # employed in `testBucket`) makes checking the sum a lot
                    # easier.
                    record_len = length
                    for _ in range(bucket_elements):
                        elements.append([1] * record_len)
                random.shuffle(elements)
                for el in elements:
                    yield (_format_record(el, sparse), )

            dataset = dataset_ops.Dataset.from_generator(
                _generator, (_get_record_type(sparse), ),
                (_get_record_shape(sparse), ))
            if sparse:
                dataset = dataset.map(lambda x: (_to_sparse_tensor(x), ))
            return dataset

        def _test_bucket_by_padding(no_padding):
            dataset = build_dataset(sparse=no_padding)
            dataset = dataset.apply(
                grouping.bucket_by_sequence_length(_element_length_fn,
                                                   boundaries,
                                                   batch_sizes,
                                                   no_padding=no_padding,
                                                   drop_remainder=True))

            get_next = self.getNext(dataset)
            batches = []
            for _ in range(n_expected_batches):
                batch, = self.evaluate(get_next())
                batches.append(batch)

            with self.assertRaises(errors.OutOfRangeError):
                self.evaluate(get_next())

            generated_lengths = []

            # <seq-length>: <total-sum>
            generated_sums = {}

            # <seq-length>: [<batch_size>, ...]
            generated_batch_sizes = {}

            for length, batch_size, bucket_elements in zip(
                    lengths, batch_sizes, n_bucket_elements):
                # Initialize the sum across all batches.
                generated_sums[length] = 0
                # Initialize the individual batch sizes.
                generated_batch_sizes[length] = []

            for batch in batches:
                shape = batch.dense_shape if no_padding else batch.shape
                length = shape[1]
                generated_lengths.append(length)

                batch_size = shape[0]
                generated_batch_sizes[length].append(batch_size)

                batch_sum = batch.values.sum() if no_padding else batch.sum()
                generated_sums[length] += batch_sum

            for l in lengths:
                # Make sure the sum of the batch contents is correct for the individual sequence lengths.
                self.assertEqual(
                    generated_sums[l], expected_sums[l],
                    "Tensor sums did not match! "
                    "expected: {}, generated: {}".format(
                        expected_sums, generated_sums))

                # Make sure the individual batch sizes are generated as expected.
                self.assertEqual(
                    sorted(generated_batch_sizes[l]),
                    sorted(expected_batch_sizes[l]),
                    "Batch-sizes did not match! "
                    "expected: {}, generated: {}".format(
                        sorted(expected_batch_sizes[l]),
                        sorted(generated_batch_sizes[l])))

            # Make sure the generated sequence lengths appear as often as expected.
            self.assertEqual(
                sorted(generated_lengths), sorted(expected_lengths),
                "The generated sequence lengths did not match! "
                "expected: {}, generated: {}".format(
                    sorted(expected_lengths), sorted(generated_lengths)))

        _test_bucket_by_padding(param_no_padding)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(param_no_padding=[True, False])))
    def testBucket(self, param_no_padding):

        boundaries = [10, 20, 30]
        batch_sizes = [10, 8, 4, 2]
        lengths = [8, 13, 25, 35]

        def build_dataset(sparse):
            def _generator():
                # Produce 1 batch for each bucket
                elements = []
                for batch_size, length in zip(batch_sizes, lengths):
                    record_len = length - 1
                    for _ in range(batch_size):
                        elements.append([1] * record_len)
                        record_len = length
                random.shuffle(elements)
                for el in elements:
                    yield (_format_record(el, sparse), )

            dataset = dataset_ops.Dataset.from_generator(
                _generator, (_get_record_type(sparse), ),
                (_get_record_shape(sparse), ))
            if sparse:
                dataset = dataset.map(lambda x: (_to_sparse_tensor(x), ))
            return dataset

        def _test_bucket_by_padding(no_padding):
            dataset = build_dataset(sparse=no_padding)
            dataset = dataset.apply(
                grouping.bucket_by_sequence_length(_element_length_fn,
                                                   boundaries,
                                                   batch_sizes,
                                                   no_padding=no_padding))
            get_next = self.getNext(dataset)
            batches = []
            for _ in range(4):
                batch, = self.evaluate(get_next())
                batches.append(batch)
            with self.assertRaises(errors.OutOfRangeError):
                self.evaluate(get_next())

            batch_sizes_val = []
            lengths_val = []
            for batch in batches:
                shape = batch.dense_shape if no_padding else batch.shape
                batch_size = shape[0]
                length = shape[1]
                batch_sizes_val.append(batch_size)
                lengths_val.append(length)
                if not context.executing_eagerly():
                    sum_check = batch.values.sum(
                    ) if no_padding else batch.sum()
                    self.assertEqual(sum_check, batch_size * length - 1)
            self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
            self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
            self.assertEqual(sorted(lengths), sorted(lengths_val))

        _test_bucket_by_padding(param_no_padding)

    def testPadToBoundary(self):

        boundaries = [10, 20, 30]
        batch_sizes = [10, 8, 4, 2]
        lengths = [8, 13, 25]

        def element_gen():
            # Produce 1 batch for each bucket
            elements = []
            for batch_size, length in zip(batch_sizes[:-1], lengths):
                for _ in range(batch_size):
                    elements.append([1] * length)
            random.shuffle(elements)
            for el in elements:
                yield (el, )
            for _ in range(batch_sizes[-1]):
                el = [1] * (boundaries[-1] + 5)
                yield (el, )

        element_len = lambda el: array_ops.shape(el)[0]
        dataset = dataset_ops.Dataset.from_generator(
            element_gen, (dtypes.int64, ), ([None], )).apply(
                grouping.bucket_by_sequence_length(
                    element_len,
                    boundaries,
                    batch_sizes,
                    pad_to_bucket_boundary=True))
        get_next = self.getNext(dataset)

        batches = []
        for _ in range(3):
            batch, = self.evaluate(get_next())
            batches.append(batch)
        with self.assertRaisesOpError("bucket_boundaries"):
            self.evaluate(get_next())

        batch_sizes_val = []
        lengths_val = []
        for batch in batches:
            batch_size = batch.shape[0]
            length = batch.shape[1]
            batch_sizes_val.append(batch_size)
            lengths_val.append(length)
        batch_sizes = batch_sizes[:-1]
        self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
        self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
        self.assertEqual([boundary - 1 for boundary in sorted(boundaries)],
                         sorted(lengths_val))

    def testPadToBoundaryNoExtraneousPadding(self):

        boundaries = [3, 7, 11]
        batch_sizes = [2, 2, 2, 2]
        lengths = range(1, 11)

        def element_gen():
            for length in lengths:
                yield ([1] * length, )

        element_len = lambda element: array_ops.shape(element)[0]
        dataset = dataset_ops.Dataset.from_generator(
            element_gen, (dtypes.int64, ), ([None], )).apply(
                grouping.bucket_by_sequence_length(
                    element_len,
                    boundaries,
                    batch_sizes,
                    pad_to_bucket_boundary=True))
        get_next = self.getNext(dataset)

        batches = []
        for _ in range(5):
            batch, = self.evaluate(get_next())
            batches.append(batch)
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())

        self.assertAllEqual(batches[0], [[1, 0], [1, 1]])
        self.assertAllEqual(batches[1],
                            [[1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 0, 0]])
        self.assertAllEqual(batches[2],
                            [[1, 1, 1, 1, 1, 0], [1, 1, 1, 1, 1, 1]])
        self.assertAllEqual(
            batches[3],
            [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
        self.assertAllEqual(
            batches[4],
            [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(param_no_padding=[True, False])))
    def testTupleElements(self, param_no_padding):
        def build_dataset(sparse):
            def _generator():
                text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]]
                label = [1, 2, 1, 2]
                for x, y in zip(text, label):
                    yield (_format_record(x, sparse), y)

            dataset = dataset_ops.Dataset.from_generator(
                generator=_generator,
                output_types=(_get_record_type(sparse), dtypes.int32),
                output_shapes=(_get_record_shape(sparse),
                               tensor_shape.TensorShape([])))
            if sparse:
                dataset = dataset.map(lambda x, y: (_to_sparse_tensor(x), y))
            return dataset

        def _test_tuple_elements_by_padding(no_padding):
            dataset = build_dataset(sparse=no_padding)
            dataset = dataset.apply(
                grouping.bucket_by_sequence_length(
                    element_length_func=_element_length_fn,
                    bucket_batch_sizes=[2, 2, 2],
                    bucket_boundaries=[0, 8],
                    no_padding=no_padding))
            shapes = dataset_ops.get_legacy_output_shapes(dataset)
            self.assertEqual([None, None], shapes[0].as_list())
            self.assertEqual([None], shapes[1].as_list())

        _test_tuple_elements_by_padding(param_no_padding)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(param_drop_remainder=[True, False])))
    def testBucketSparse(self, param_drop_remainder):  # pylint: disable=g-doc-args
        """Tests bucketing of sparse tensors (case where `no_padding` == True).

    Test runs on following dataset:
      [
        [0],
        [0, 1],
        [0, 1, 2]
        ...
        [0, ..., max_len - 1]
      ]
    Sequences are bucketed by length and batched with
      `batch_size` < `bucket_size`.
    """

        min_len = 0
        max_len = 100
        batch_size = 7
        bucket_size = 10

        def _build_dataset():
            input_data = [range(i + 1) for i in range(min_len, max_len)]

            def generator_fn():
                for record in input_data:
                    yield _format_record(record, sparse=True)

            dataset = dataset_ops.Dataset.from_generator(
                generator=generator_fn,
                output_types=_get_record_type(sparse=True))
            dataset = dataset.map(_to_sparse_tensor)
            return dataset

        def _compute_expected_batches(drop_remainder):
            """Computes expected batch outputs and stores in a set."""
            all_expected_sparse_tensors = set()
            for bucket_start_len in range(min_len, max_len, bucket_size):
                if drop_remainder:
                    batch_offsets = [0]
                else:
                    batch_offsets = range(0, bucket_size, batch_size)

                for batch_offset in batch_offsets:
                    batch_start_len = bucket_start_len + batch_offset
                    batch_end_len = min(batch_start_len + batch_size,
                                        bucket_start_len + bucket_size)
                    expected_indices = []
                    expected_values = []
                    for length in range(batch_start_len, batch_end_len):
                        for val in range(length + 1):
                            expected_indices.append(
                                (length - batch_start_len, val))
                            expected_values.append(val)
                    expected_sprs_tensor = (tuple(expected_indices),
                                            tuple(expected_values))
                    all_expected_sparse_tensors.add(expected_sprs_tensor)
            return all_expected_sparse_tensors

        def _compute_batches(dataset):
            """Computes actual batch outputs of dataset and stores in a set."""
            batch = self.getNext(dataset)
            all_sparse_tensors = set()
            with self.assertRaises(errors.OutOfRangeError):
                while True:
                    output = self.evaluate(batch())
                    sprs_tensor = (tuple([
                        tuple(idx) for idx in output.indices
                    ]), tuple(output.values))
                    all_sparse_tensors.add(sprs_tensor)

            return all_sparse_tensors

        dataset = _build_dataset()
        boundaries = range(min_len + bucket_size + 1, max_len, bucket_size)
        dataset = dataset.apply(
            grouping.bucket_by_sequence_length(
                _element_length_fn,
                boundaries, [batch_size] * (len(boundaries) + 1),
                no_padding=True,
                drop_remainder=param_drop_remainder))
        batches = _compute_batches(dataset)
        expected_batches = _compute_expected_batches(param_drop_remainder)
        self.assertEqual(batches, expected_batches)
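
For orientation, here is a minimal standalone sketch of the bucketing behaviour exercised above, written against the public tf.data.experimental.bucket_by_sequence_length API rather than the internal grouping module; the generator contents, boundary, and batch sizes are illustrative only.

import tensorflow as tf

def gen():
    # Variable-length sequences, mirroring the shapes used in the tests above.
    for seq in ([1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]):
        yield seq

ds = tf.data.Dataset.from_generator(
    gen, output_signature=tf.TensorSpec(shape=[None], dtype=tf.int32))
ds = ds.apply(
    tf.data.experimental.bucket_by_sequence_length(
        element_length_func=lambda x: tf.shape(x)[0],
        bucket_boundaries=[4],        # two buckets: length < 4 and length >= 4
        bucket_batch_sizes=[2, 2]))   # one batch size per bucket
for batch in ds:
    print(batch.shape)  # e.g. (2, 3) and (2, 5): padding happens per bucket
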
Example #17
0
class AutoShardDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
                           parameterized.TestCase):

  def setUp(self):
    super(AutoShardDatasetTest, self).setUp()
    self._num_files = 10
    self._num_records = 10
    self.test_filenames = self._createFiles()

  def getAllDatasetElements(self, dataset):
    actual = []
    next_fn = self.getNext(dataset)
    while True:
      try:
        actual.append(self.evaluate(next_fn()))
      except errors.OutOfRangeError:
        break
    return actual

  def assertDatasetProducesWithShuffle(self, dataset, expected, batch,
                                       num_examples, shuffle):
    if shuffle:
      actual = []
      next_fn = self.getNext(dataset)
      for _ in range(num_examples):
        elem = self.evaluate(next_fn())
        if isinstance(elem, tuple):
          actual.extend(elem)
        else:
          actual.extend(elem.tolist())

      self.assertCountEqual(actual, expected)
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next_fn())
    else:
      self.assertDatasetProduces(dataset, list(chunk(expected, batch)))

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(shuffle=[True, False])))
  def testFlatMapReaderPipeline(self, shuffle):
    dataset = dataset_ops.Dataset.list_files(
        self.test_filenames, shuffle=shuffle)
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.batch(5)
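    # Shard across 5 workers and read back worker 3's shard (files 3 and 8).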
    dataset = distribute._AutoShardDataset(dataset, 5, 3)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in (3, 8)
        for r in range(0, 10)
    ]
    self.assertDatasetProducesWithShuffle(dataset, expected, 5, 4, shuffle)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(batch_size=[1, 3, 10])))
  def testDatasetOfReaderDatasetsPipeline(self, batch_size):
    # This tests a scenario where list_files may return multiple files
    # due to the glob containing wildcards.
    def batch(iterator, n):
      l = len(iterator)
      for i in range(0, l, n):
        yield iterator[i:min(i + n, l)]

    datasets = []
    for files in batch(self.test_filenames, batch_size):
      datasets.append(
          dataset_ops.Dataset.list_files(files, shuffle=False).map(
              core_readers.TFRecordDataset))
    dataset = dataset_ops.Dataset.from_tensor_slices(datasets)
    dataset = dataset.flat_map(lambda x: x)

    # Simulate additional ops in between flat_map and interleave. This should be
    # a no-op since, if ShardDataset were placed right after flat_map, we would
    # only have two datasets left at this point.
    dataset = dataset.prefetch(1)
    dataset = dataset.prefetch(1)

    dataset = dataset.interleave(
        lambda x: x, cycle_length=1, num_parallel_calls=1)

    dataset = distribute._AutoShardDataset(dataset, 5, 0)
    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in (0, 5)
        for r in range(0, 10)
    ]

    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testZipReaderPipeline(self):
    dataset1 = dataset_ops.Dataset.list_files(
        self.test_filenames, shuffle=False)
    dataset1 = dataset1.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    dataset2 = dataset_ops.Dataset.list_files(
        self.test_filenames, shuffle=False)
    dataset2 = dataset2.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))

    dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
    dataset = distribute._AutoShardDataset(dataset, 5, 3)

    expected = [
        (b"Record %d of file %d" % (r, f), b"Record %d of file %d" % (r, f))  # pylint:disable=g-complex-comprehension
        for r in range(0, 10)
        for f in (3, 8)
    ]

    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(shuffle=[True, False])))
  def testConcatenateReaderPipeline(self, shuffle):
    dataset1 = dataset_ops.Dataset.list_files(
        self.test_filenames, shuffle=shuffle)
    dataset1 = dataset1.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    dataset1 = dataset1.batch(5)
    dataset2 = dataset_ops.Dataset.list_files(
        self.test_filenames, shuffle=shuffle)
    dataset2 = dataset2.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    dataset2 = dataset2.batch(5)

    dataset = dataset1.concatenate(dataset2)
    dataset = distribute._AutoShardDataset(dataset, 5, 3)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for r in range(0, 10)
        for f in (3, 8)
    ]
    expected += expected
    self.assertDatasetProducesWithShuffle(dataset, expected, 5, 8, shuffle)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(shuffle=[True, False])))
  def testPipelineWithMap(self, shuffle):
    dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
    dataset = dataset.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    dataset = dataset.map(lambda x: string_ops.substr_v2(x, 2, 1000))
    dataset = dataset.batch(5)
    dataset = distribute._AutoShardDataset(dataset, 5, 3)

    expected = [
        b"cord %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for r in range(0, 10)
        for f in (3, 8)
    ]
    self.assertDatasetProducesWithShuffle(dataset, expected, 5, 4, shuffle)

  @combinations.generate(test_base.default_test_combinations())
  def testDirectFilenameTFRecordReaderPipeline(self):
    dataset = core_readers.TFRecordDataset(self.test_filenames)
    dataset = distribute._AutoShardDataset(dataset, 5, 0)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in (0, 5)
        for r in range(0, 10)
    ]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(shuffle=[True, False])))
  def testValidPipelineWithRangeDataset(self, shuffle):
    dataset = dataset_ops.Dataset.range(self._num_files)
    dataset = dataset.map(lambda n: string_ops.string_join(  # pylint:disable=g-long-lambda
        [self.get_temp_dir(),
         string_ops.string_format("/tf_record.{}.txt", [n])]))
    dataset = dataset.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    dataset = dataset.map(lambda x: string_ops.substr_v2(x, 2, 1000))
    dataset = dataset.batch(5)
    dataset = distribute._AutoShardDataset(dataset, 5, 3)

    expected = [
        b"cord %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for r in range(0, 10)
        for f in (3, 8)
    ]
    self.assertDatasetProducesWithShuffle(dataset, expected, 5, 4, shuffle)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(params=[(1, 0, 10, 10), (2, 1, 20, 5),
                                       (10, 1, 1, 10)])))
  def testStandardReaderPipeline(self, params):
    num_epochs, index, batch_size, parallel_reads = params
    dataset = readers.make_tf_record_dataset(
        file_pattern=self.test_filenames,
        num_epochs=num_epochs,
        batch_size=batch_size,
        parser_fn=None,
        num_parallel_reads=parallel_reads,
        drop_final_batch=True,
        shuffle=False)
    dataset = distribute._AutoShardDataset(dataset, 2, index)
    outputs = self.getNext(dataset)
    self._verify_records(
        outputs,
        batch_size=batch_size,
        file_index=[i for i in range(index, self._num_records, 2)],
        num_epochs=num_epochs,
        interleave_cycle_length=parallel_reads,
        drop_final_batch=True,
        use_parser_fn=None)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(outputs())

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(shuffle=[True, False])))
  def testSampleResNetPipeline(self, shuffle):
    dataset = dataset_ops.Dataset.list_files(
        self.test_filenames, shuffle=shuffle)
    dataset = dataset.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    dataset = dataset.batch(5)
    dataset = distribute._AutoShardDataset(dataset, 5, 3)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for r in range(0, 10)
        for f in (3, 8)
    ]
    self.assertDatasetProducesWithShuffle(dataset, expected, 5, 4, shuffle)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(sharding_policy=[
              distribute_options.AutoShardPolicy.DATA,
              distribute_options.AutoShardPolicy.AUTO
          ])))
  def testShardByDataBeforePrefetch(self, sharding_policy):
    dataset = dataset_ops.Dataset.range(4)
    dataset = dataset.apply(testing.assert_next(["Shard", "Prefetch"]))
    dataset = dataset.prefetch(1)
    options = dataset_ops.Options()
    options.experimental_distribute.auto_shard_policy = sharding_policy
    dataset = dataset.with_options(options)
    dataset = distribute._AutoShardDataset(dataset, 2, 0)
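    # With 2 workers and DATA sharding, worker 0 keeps every other element.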
    self.assertDatasetProduces(dataset, [0, 2])

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.times(combinations.combine(
              sharding_policy=[distribute_options.AutoShardPolicy.DATA,
                               distribute_options.AutoShardPolicy.FILE]),
                             combinations.combine(shuffle=[True, False]))))
  def testReplicateAndShardProduceDisjointData(self, shuffle, sharding_policy):
    dataset = dataset_ops.Dataset.list_files(self.test_filenames,
                                             shuffle=shuffle)
    dataset = dataset.flat_map(core_readers.TFRecordDataset)

    graph_def = dataset._as_serialized_graph(
        strip_device_assignment=True,
        external_state_policy=distribute_options.ExternalStatePolicy.WARN)

    options = dataset_ops.Options()
    options.experimental_distribute.auto_shard_policy = sharding_policy

    ds1 = distribute._RemoteDataset(graph_def, "/device:CPU:0",
                                    dataset.element_spec)
    ds2 = distribute._RemoteDataset(graph_def, "/device:CPU:0",
                                    dataset.element_spec)

    ds1 = ds1.with_options(options)
    ds2 = ds2.with_options(options)

    ds1 = distribute._AutoShardDataset(ds1, 2, 0)
    ds2 = distribute._AutoShardDataset(ds2, 2, 1)

    elems1 = set(self.getAllDatasetElements(ds1))
    elems2 = set(self.getAllDatasetElements(ds2))

    self.assertEmpty(elems1.intersection(elems2))

  @combinations.generate(test_base.default_test_combinations())
  def testWorkersGreaterThanNumFilesWithDataSharding(self):
    options = dataset_ops.Options()
    options.experimental_distribute.auto_shard_policy = (
        distribute_options.AutoShardPolicy.DATA)

    dataset = core_readers._TFRecordDataset(self.test_filenames)
    dataset = dataset.with_options(options)
    dataset = distribute._AutoShardDataset(dataset, 5, 0)

    # Should return "Record (0,5) of file (0 --> 9)" since we are sharding by
    # individual elements, we should be able to get some data from all files.
    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in (0, 5)
    ]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testAutoshardPolicyOff(self):
    options = dataset_ops.Options()
    options.experimental_distribute.auto_shard_policy = (
        distribute_options.AutoShardPolicy.OFF)

    dataset = core_readers._TFRecordDataset(self.test_filenames)
    dataset = dataset.with_options(options)
    dataset = distribute._AutoShardDataset(dataset, 5, 0)

    # Should return every record in every file since autosharding is turned off.
    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in range(0, 10)
    ]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testFileShardingWithoutReaderDatasetOp(self):
    options = dataset_ops.Options()
    options.experimental_distribute.auto_shard_policy = (
        distribute_options.AutoShardPolicy.FILE)

    dataset = dataset_ops.Dataset.range(1024)
    dataset = dataset.with_options(options)

    # We are specifying that we want a file sharding policy, but this pipeline
    # doesn't start with file reading, so we should error out.
    with self.assertRaises(errors.NotFoundError):
      dataset = distribute._AutoShardDataset(dataset, 10, 0)
      self.evaluate(self.getNext(dataset)())

  @combinations.generate(test_base.default_test_combinations())
  def testWorkersGreaterThanNumFiles(self):
    dataset = dataset_ops.Dataset.list_files(self.test_filenames)
    dataset = dataset.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    dataset = dataset.batch(5)
    dataset = distribute._AutoShardDataset(dataset, 500, 499)
    self.assertDatasetProduces(dataset, [])

  @combinations.generate(test_base.default_test_combinations())
  def testTFRecordReaderWithDirectFileNames(self):
    # Using `_TFRecordDataset` creates a raw op rather than automatically
    # wrapping it in a flat_map.
    dataset = core_readers._TFRecordDataset(self.test_filenames)
    dataset = distribute._AutoShardDataset(dataset, 5, 0)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in (0, 5)
    ]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testTFRecordReaderWithDirectFileNamesAndShapes(self):
    # Using `_TFRecordDataset` creates a raw op rather than automatically
    # wrapping it in a flat_map.
    dataset = core_readers._TFRecordDataset(self.test_filenames)

    # BatchDataset contains `output_types` and `output_shapes`
    dataset = dataset.batch(5)
    dataset = distribute._AutoShardDataset(dataset, 2, 0)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in range(0, 5)
    ]
    self.assertDatasetProduces(dataset, list(chunk(expected, 5)))

  @combinations.generate(test_base.default_test_combinations())
  def testShardOutOfRange(self):
    dataset = dataset_ops.Dataset.range(5)
    with self.assertRaises(errors.InvalidArgumentError):
      dataset = distribute._AutoShardDataset(dataset, 10, 0)
      self.evaluate(self.getNext(dataset)())

  @combinations.generate(test_base.default_test_combinations())
  def testShardOutOfRangeEmptyDataset(self):
    dataset = dataset_ops.Dataset.range(0)
    with self.assertRaises(errors.OutOfRangeError):
      dataset = distribute._AutoShardDataset(dataset, 10, 0)
      self.evaluate(self.getNext(dataset)())

  @combinations.generate(test_base.default_test_combinations())
  def testNoReaderPipelines(self):
    dataset = dataset_ops.Dataset.range(1024)
    dataset = distribute._AutoShardDataset(dataset, 2, 0)
    self.assertDatasetProduces(dataset, [i for i in range(1024) if i % 2 == 0])

  @combinations.generate(test_base.default_test_combinations())
  def testUnknownOpInPipelineStillShardsAtTheEnd(self):
    dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.apply(unique.unique())

    dataset = distribute._AutoShardDataset(dataset, 5, 0)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in (0, 5)
    ]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testInvalidWorkerIndex(self):
    dataset = dataset_ops.Dataset.list_files(self.test_filenames)
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.batch(5)

    with self.assertRaises(errors.InvalidArgumentError):
      dataset = distribute._AutoShardDataset(dataset, 2, 2)
      self.evaluate(self.getNext(dataset)())

  @combinations.generate(test_base.default_test_combinations())
  def testAssertCardinality(self):
    dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.batch(5)
    dataset = dataset.apply(cardinality.assert_cardinality(42))
    dataset = distribute._AutoShardDataset(dataset, 5, 0)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in (0, 5)
        for r in range(0, 10)
    ]
    self.assertDatasetProduces(dataset, list(chunk(expected, 5)))

  @combinations.generate(test_base.default_test_combinations())
  def testMaxIntraOpParallelism(self):
    dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.batch(5)
    dataset = dataset_ops._MaxIntraOpParallelismDataset(dataset, 1)
    dataset = distribute._AutoShardDataset(dataset, 5, 0)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in (0, 5)
        for r in range(0, 10)
    ]
    self.assertDatasetProduces(dataset, list(chunk(expected, 5)))

  @combinations.generate(test_base.default_test_combinations())
  def testPrivateThreadpool(self):
    dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.batch(5)
    dataset = dataset_ops._PrivateThreadPoolDataset(dataset, 1)
    dataset = distribute._AutoShardDataset(dataset, 5, 0)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in (0, 5)
        for r in range(0, 10)
    ]
    self.assertDatasetProduces(dataset, list(chunk(expected, 5)))

  @combinations.generate(test_base.default_test_combinations())
  def testMakeBatchedFeaturesDataset(self):
    files = 2
    records_per_file = 5

    def make_record(file_index):
      example = example_pb2.Example(
          features=feature_pb2.Features(
              feature={
                  "file":
                      feature_pb2.Feature(
                          int64_list=feature_pb2.Int64List(value=[file_index])),
              }))
      return example.SerializeToString()

    filenames = []
    for file_index in range(files):
      filename = os.path.join(self.get_temp_dir(),
                              "tf_record.%d.txt" % file_index)
      filenames.append(filename)
      writer = python_io.TFRecordWriter(filename)
      for _ in range(records_per_file):
        writer.write(make_record(file_index))
      writer.close()

    dataset = readers.make_batched_features_dataset(
        file_pattern=filenames,
        batch_size=records_per_file,
        features={
            "file": parsing_ops.FixedLenFeature([], dtypes.int64),
        },
        reader=core_readers.TFRecordDataset,
        num_epochs=1)
    # We should shard at the file level, so that all records come from file 0.
    dataset = distribute._AutoShardDataset(dataset, 2, 0)
    dataset = dataset.unbatch()
    output = self.getDatasetOutput(dataset)
    files = [elem["file"] for elem in output]
    self.assertEqual(files, [0] * records_per_file)
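
As a complement to the internal distribute._AutoShardDataset op used throughout this example, here is a minimal sketch of how an auto-shard policy is normally configured through the public options API; the DATA policy and the toy pipeline are illustrative only.

import tensorflow as tf

dataset = tf.data.Dataset.range(8).batch(2)
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = (
    tf.data.experimental.AutoShardPolicy.DATA)
dataset = dataset.with_options(options)
# Under a tf.distribute strategy each worker then receives a disjoint shard;
# outside a strategy the option has no visible effect on iteration.
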
Example #18
0
class InterleaveTest(test_base.DatasetTestBase, parameterized.TestCase):
    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(
                input_values=[[4, 5, 6]],
                cycle_length=1,
                block_length=1,
                expected_elements=[[
                    4, 4, 4, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 4, 4, 4, 4, 5,
                    5, 5, 5, 5, 6, 6, 6, 6, 6, 6
                ]]) + combinations.combine(
                    input_values=[[4, 5, 6]],
                    cycle_length=2,
                    block_length=1,
                    expected_elements=[[
                        4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6, 4, 6, 4,
                        6, 5, 6, 5, 6, 5, 6, 5, 6, 5, 6, 6
                    ]]) + combinations.combine(
                        input_values=[[4, 5, 6]],
                        cycle_length=2,
                        block_length=3,
                        expected_elements=[[
                            4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6,
                            6, 4, 5, 5, 5, 6, 6, 6, 5, 5, 6, 6, 6
                        ]]) + combinations.combine(
                            input_values=[[4, 5, 6]],
                            cycle_length=7,
                            block_length=2,
                            expected_elements=[[
                                4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5,
                                6, 6, 4, 4, 5, 5, 6, 6, 5, 6, 6, 5, 6, 6
                            ]]) +
            combinations.combine(input_values=[[4, 0, 6]],
                                 cycle_length=2,
                                 block_length=1,
                                 expected_elements=[[
                                     4, 4, 6, 4, 6, 4, 6, 6, 4, 6, 4, 6, 4, 4,
                                     6, 6, 6, 6, 6, 6
                                 ]])))
    def testPythonImplementation(self, input_values, cycle_length,
                                 block_length, expected_elements):
        input_lists = _repeat(input_values, 2)

        for expected, produced in zip(
                expected_elements,
                _interleave(input_lists, cycle_length, block_length)):
            self.assertEqual(expected, produced)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(input_values=[np.int64([4, 5, 6])],
                                 cycle_length=1,
                                 block_length=3,
                                 num_parallel_calls=[None, 1]) +
            combinations.combine(input_values=[np.int64([4, 5, 6])],
                                 cycle_length=2,
                                 block_length=[1, 3],
                                 num_parallel_calls=[None, 1, 2]) +
            combinations.combine(input_values=[np.int64([4, 5, 6])],
                                 cycle_length=7,
                                 block_length=2,
                                 num_parallel_calls=[None, 1, 3, 5, 7]) +
            combinations.combine(input_values=[np.int64([4, 5, 6, 7])],
                                 cycle_length=dataset_ops.AUTOTUNE,
                                 block_length=3,
                                 num_parallel_calls=[None, 1]) +
            combinations.combine(
                input_values=[np.int64([]), np.int64([0, 0, 0])],
                cycle_length=2,
                block_length=3,
                num_parallel_calls=[None]) +
            combinations.combine(input_values=[np.int64([4, 0, 6])],
                                 cycle_length=2,
                                 block_length=3,
                                 num_parallel_calls=[None, 1, 2])))
    def testInterleaveDataset(self, input_values, cycle_length, block_length,
                              num_parallel_calls):
        count = 2
        dataset = dataset_ops.Dataset.from_tensor_slices(input_values).repeat(
            count).interleave(
                lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
                cycle_length, block_length, num_parallel_calls)
        expected_output = [
            element for element in _interleave(_repeat(input_values, count),
                                               cycle_length, block_length)
        ]
        self.assertDatasetProduces(dataset, expected_output)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(
                input_values=[np.float32([1., np.nan, 2., np.nan, 3.])],
                cycle_length=1,
                block_length=3,
                num_parallel_calls=[None, 1]) + combinations.combine(
                    input_values=[np.float32([1., np.nan, 2., np.nan, 3.])],
                    cycle_length=2,
                    block_length=[1, 3],
                    num_parallel_calls=[None, 1, 2]) +
            combinations.combine(
                input_values=[np.float32([1., np.nan, 2., np.nan, 3.])],
                cycle_length=7,
                block_length=2,
                num_parallel_calls=[None, 1, 3, 5, 7])))
    def testInterleaveDatasetError(self, input_values, cycle_length,
                                   block_length, num_parallel_calls):
        dataset = dataset_ops.Dataset.from_tensor_slices(input_values).map(
            lambda x: array_ops.check_numerics(x, "message")).interleave(
                dataset_ops.Dataset.from_tensors, cycle_length, block_length,
                num_parallel_calls)
        get_next = self.getNext(dataset)

        for value in input_values:
            if np.isnan(value):
                with self.assertRaises(errors.InvalidArgumentError):
                    self.evaluate(get_next())
            else:
                self.assertEqual(value, self.evaluate(get_next()))
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())

    @combinations.generate(test_base.default_test_combinations())
    def testInterleaveSparse(self):
        def _map_fn(i):
            return sparse_tensor.SparseTensorValue(indices=[[0, 0], [1, 1]],
                                                   values=(i * [1, -1]),
                                                   dense_shape=[2, 2])

        def _interleave_fn(x):
            return dataset_ops.Dataset.from_tensor_slices(
                sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values))

        dataset = dataset_ops.Dataset.range(10).map(_map_fn).interleave(
            _interleave_fn, cycle_length=1)
        get_next = self.getNext(dataset)
        for i in range(10):
            for j in range(2):
                expected = [i, 0] if j % 2 == 0 else [0, -i]
                self.assertAllEqual(expected, self.evaluate(get_next()))
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(input_values=[np.int64([4, 5, 6])],
                                 cycle_length=1,
                                 block_length=3,
                                 num_parallel_calls=1) +
            combinations.combine(input_values=[np.int64([4, 5, 6])],
                                 cycle_length=2,
                                 block_length=[1, 3],
                                 num_parallel_calls=[1, 2]) +
            combinations.combine(input_values=[np.int64([4, 5, 6])],
                                 cycle_length=7,
                                 block_length=2,
                                 num_parallel_calls=[1, 3, 5, 7]) +
            combinations.combine(input_values=[np.int64([4, 5, 6, 7])],
                                 cycle_length=dataset_ops.AUTOTUNE,
                                 block_length=3,
                                 num_parallel_calls=1) +
            combinations.combine(input_values=[np.int64([4, 0, 6])],
                                 cycle_length=2,
                                 block_length=3,
                                 num_parallel_calls=[1, 2])))
    def testSloppyInterleaveDataset(self, input_values, cycle_length,
                                    block_length, num_parallel_calls):
        count = 2
        dataset = dataset_ops.Dataset.from_tensor_slices(input_values).repeat(
            count).interleave(
                lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
                cycle_length, block_length, num_parallel_calls)
        options = dataset_ops.Options()
        options.experimental_deterministic = False
        dataset = dataset.with_options(options)
        expected_output = [
            element for element in _interleave(_repeat(input_values, count),
                                               cycle_length, block_length)
        ]
        get_next = self.getNext(dataset)
        actual_output = []
        for _ in range(len(expected_output)):
            actual_output.append(self.evaluate(get_next()))
        # list.sort() returns None, so compare sorted copies instead.
        self.assertAllEqual(sorted(expected_output), sorted(actual_output))

    @combinations.generate(test_base.default_test_combinations())
    def testInterleaveMap(self):
        dataset = dataset_ops.Dataset.range(100)

        def interleave_fn(x):
            dataset = dataset_ops.Dataset.from_tensors(x)
            return dataset.map(lambda x: x + x)

        dataset = dataset.interleave(interleave_fn, cycle_length=5)
        dataset = dataset.interleave(interleave_fn, cycle_length=5)

        self.assertDatasetProduces(dataset, [4 * x for x in range(100)])

    @combinations.generate(test_base.default_test_combinations())
    def testParallelInterleaveCached(self):
        dataset = dataset_ops.Dataset.range(5)
        dataset = dataset.cache(os.path.join(self.get_temp_dir(), "cache_dir"))

        def interleave_fn(x):
            return dataset_ops.Dataset.from_tensors(x)

        dataset = dataset.interleave(interleave_fn,
                                     cycle_length=2,
                                     num_parallel_calls=2)
        self.assertDatasetProduces(dataset, list(range(5)))

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(local_determinism=[None, True, False],
                                 global_determinism=[True, False])))
    def testDeterminismConfiguration(self, local_determinism,
                                     global_determinism):
        expect_determinism = local_determinism or (local_determinism is None
                                                   and global_determinism)
        elements = list(range(1000))

        def dataset_fn(delay_ms):
            def interleave_fn(x):
                ds = dataset_ops.Dataset.from_tensors(x)
                if math_ops.equal(x, 0):
                    ds = ds.apply(testing.sleep(delay_ms * 1000))
                else:
                    ds = ds.apply(testing.sleep(0))
                return ds

            dataset = dataset_ops.Dataset.from_tensor_slices(elements)
            dataset = dataset.interleave(interleave_fn,
                                         cycle_length=10,
                                         num_parallel_calls=10,
                                         deterministic=local_determinism)
            opts = dataset_ops.Options()
            opts.experimental_deterministic = global_determinism
            dataset = dataset.with_options(opts)
            return dataset

        self.checkDeterminism(dataset_fn, expect_determinism, elements)
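
As a quick, self-contained illustration of the cycle_length/block_length semantics these tests exercise (values chosen only for readability, and separate from the test class above):

import tensorflow as tf

# With cycle_length=2 and block_length=2, two elements are pulled from each of
# the two open input datasets per turn.
ds = tf.data.Dataset.range(2).interleave(
    lambda x: tf.data.Dataset.from_tensors(x).repeat(3),
    cycle_length=2,
    block_length=2)
print([int(v) for v in ds])  # [0, 0, 1, 1, 0, 1]
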
Example #19
0
class WindowTest(test_base.DatasetTestBase, parameterized.TestCase):

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(
              count=20,
              size=[10, 14, 17],
              shift=[7, 14],
              stride=[1, 2, 6],
              drop_remainder=[True, False]) + combinations.combine(
                  count=[0, 1],
                  size=10,
                  shift=4,
                  stride=1,
                  drop_remainder=[True, False])))
  def testWindowDataset(self, count, size, shift, stride, drop_remainder):
    """Tests a dataset that slides a window its input elements."""
    components = (np.arange(7),
                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                  np.array(37.0) * np.arange(7))

    def _map_fn(x, y, z):
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)

    def _flat_map_fn(x, y, z):
      return dataset_ops.Dataset.zip((x.batch(batch_size=size),
                                      y.batch(batch_size=size),
                                      z.batch(batch_size=size)))

    dataset = dataset_ops.Dataset.from_tensor_slices(components).map(
        _map_fn).repeat(count).window(
            size=size,
            shift=shift,
            stride=stride,
            drop_remainder=drop_remainder).flat_map(_flat_map_fn)
    get_next = self.getNext(dataset)

    self.assertEqual([[None] + list(c.shape[1:]) for c in components],
                     [ts.as_list() for ts in nest.flatten(
                         dataset_ops.get_legacy_output_shapes(dataset))])

    num_full_batches = max(0,
                           (count * 7 - ((size - 1) * stride + 1)) // shift + 1)
    for i in range(num_full_batches):
      result = self.evaluate(get_next())
      for component, result_component in zip(components, result):
        for j in range(size):
          self.assertAllEqual(component[(i * shift + j * stride) % 7]**2,
                              result_component[j])
    if not drop_remainder:
      num_partial_batches = (count * 7) // shift + (
          (count * 7) % shift > 0) - num_full_batches
      for i in range(num_partial_batches):
        result = self.evaluate(get_next())
        for component, result_component in zip(components, result):
          remaining = (count * 7) - ((num_full_batches + i) * shift)
          num_elements = remaining // stride + ((remaining % stride) > 0)
          for j in range(num_elements):
            self.assertAllEqual(
                component[((num_full_batches + i) * shift + j * stride) % 7]**2,
                result_component[j])
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(count=20, size=0, shift=3, stride=1) +
          combinations.combine(count=20, size=3, shift=0, stride=1) +
          combinations.combine(count=20, size=3, shift=3, stride=0)))
  def testWindowDatasetInvalid(self, count, size, shift, stride):
    with self.assertRaises(errors.InvalidArgumentError):
      ds = dataset_ops.Dataset.range(10).map(lambda x: x).repeat(count).window(
          size=size, shift=shift,
          stride=stride).flat_map(lambda x: x.batch(batch_size=size))
      self.evaluate(ds._variant_tensor)

  @combinations.generate(test_base.default_test_combinations())
  def testWindowDifferentNestedStructures(self):
    ds = dataset_ops.Dataset.from_tensor_slices(([1, 2], [3, 4])).window(2)
    self.getNext(ds)
    ds = dataset_ops.Dataset.from_tensor_slices({"a": [1, 2]}).window(2)
    self.getNext(ds)

  @combinations.generate(test_base.default_test_combinations())
  def testWindowSparse(self):

    def _sparse(i):
      return sparse_tensor.SparseTensorValue(
          indices=[[0]], values=(i * [1]), dense_shape=[1])

    dataset = dataset_ops.Dataset.range(10).map(_sparse).window(
        size=5, shift=3,
        drop_remainder=True).flat_map(lambda x: x.batch(batch_size=5))

    num_batches = (10 - 5) // 3 + 1
    expected_output = [
        sparse_tensor.SparseTensorValue(
            indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
            values=[i * 3, i * 3 + 1, i * 3 + 2, i * 3 + 3, i * 3 + 4],
            dense_shape=[5, 1]) for i in range(num_batches)
    ]
    self.assertDatasetProduces(dataset, expected_output=expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testWindowSparseWithDifferentDenseShapes(self):

    def _sparse(i):
      return sparse_tensor.SparseTensorValue(
          indices=array_ops.expand_dims(
              math_ops.range(i, dtype=dtypes.int64), 1),
          values=array_ops.fill([math_ops.cast(i, dtypes.int32)], i),
          dense_shape=[i])

    dataset = dataset_ops.Dataset.range(10).map(_sparse).window(
        size=5, shift=3,
        drop_remainder=True).flat_map(lambda x: x.batch(batch_size=5))

    expected_output = []
    num_batches = (10 - 5) // 3 + 1
    for i in range(num_batches):
      expected_indices = []
      expected_values = []
      for j in range(5):
        for k in range(i * 3 + j):
          expected_indices.append([j, k])
          expected_values.append(i * 3 + j)
      expected_output.append(
          sparse_tensor.SparseTensorValue(
              indices=expected_indices,
              values=expected_values,
              dense_shape=[5, i * 3 + 5 - 1]))
    self.assertDatasetProduces(dataset, expected_output=expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testNestedWindowSparse(self):

    def _sparse(i):
      return sparse_tensor.SparseTensorValue(
          indices=[[0]], values=(i * [1]), dense_shape=[1])

    dataset = dataset_ops.Dataset.range(10).map(_sparse).window(
        size=4, shift=2,
        drop_remainder=True).flat_map(lambda x: x.batch(batch_size=4)).window(
            size=3, shift=1,
            drop_remainder=True).flat_map(lambda x: x.batch(batch_size=3))

    expected_output = [
        sparse_tensor.SparseTensorValue(
            indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [1, 0, 0],
                     [1, 1, 0], [1, 2, 0], [1, 3, 0], [2, 0, 0], [2, 1, 0],
                     [2, 2, 0], [2, 3, 0]],
            values=[0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7],
            dense_shape=[3, 4, 1]),
        sparse_tensor.SparseTensorValue(
            indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [1, 0, 0],
                     [1, 1, 0], [1, 2, 0], [1, 3, 0], [2, 0, 0], [2, 1, 0],
                     [2, 2, 0], [2, 3, 0]],
            values=[2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9],
            dense_shape=[3, 4, 1])
    ]
    self.assertDatasetProduces(dataset, expected_output=expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testWindowShapeError(self):

    def generator():
      yield [1.0, 2.0, 3.0]
      yield [4.0, 5.0, 6.0]
      yield [7.0, 8.0, 9.0, 10.0]

    dataset = dataset_ops.Dataset.from_generator(
        generator, dtypes.float32, output_shapes=[None]).window(
            size=3, shift=1).flat_map(lambda x: x.batch(batch_size=3))
    self.assertDatasetProduces(
        dataset,
        expected_error=(
            errors.InvalidArgumentError,
            r"Cannot batch tensors with different shapes in component 0. "
            r"First element had shape \[3\] and element 2 had shape \[4\]."))

  @combinations.generate(test_base.default_test_combinations())
  def testWindowIgnoreErrors(self):
    input_values = np.float32([1., np.nan, 2., np.nan, 3.])
    dataset = dataset_ops.Dataset.from_tensor_slices(input_values).map(
        lambda x: array_ops.check_numerics(x, "message")).window(
            size=2, shift=2, stride=2,
            drop_remainder=True).flat_map(lambda x: x.batch(batch_size=2))
    self.assertDatasetProduces(
        dataset, expected_output=[np.float32([1., 2.]),
                                  np.float32([2., 3.])])

  @combinations.generate(test_base.default_test_combinations())
  def testNestedOutput(self):
    if not context.executing_eagerly():
      self.skipTest("self.evaluate() does not work with a dataset")
    dataset = dataset_ops.Dataset.range(100)
    dataset = dataset_ops.Dataset.zip((dataset, dataset)).window(10)
    for i, nested_dataset in enumerate(dataset):
      x, y = nested_dataset
      self.assertDatasetProduces(x, range(i*10, (i+1)*10))
      self.assertDatasetProduces(y, range(i*10, (i+1)*10))
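
A small standalone check of the full-window count formula used in testWindowDataset above, specialized to stride=1 so a window spans `size` consecutive elements; the numbers are illustrative.

import tensorflow as tf

n, size, shift = 10, 3, 2
ds = tf.data.Dataset.range(n).window(size=size, shift=shift, drop_remainder=True)
ds = ds.flat_map(lambda w: w.batch(size))  # turn each window into one tensor
windows = [list(w.numpy()) for w in ds]
print(len(windows), (n - size) // shift + 1)  # both print 4
print(windows[0], windows[1])                 # [0, 1, 2] [2, 3, 4]
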
Example #20
0
class LegacySnapshotDatasetTest(
        reader_dataset_ops_test_base.TFRecordDatasetTestBase,
        parameterized.TestCase):
    def setUp(self):
        super(LegacySnapshotDatasetTest, self).setUp()
        self.removeTFRecords()
        tmpdir = self.get_temp_dir()
        tmpdir = os.path.join(tmpdir, "snapshot")
        os.mkdir(tmpdir)
        self.snapshot_dir = tmpdir

    def tearDown(self):
        super(LegacySnapshotDatasetTest, self).tearDown()
        shutil.rmtree(self.snapshot_dir)

    def removeTFRecords(self):
        for filename in self.test_filenames:
            os.remove(filename)
        self.test_filenames = []

    def setUpTFRecord(self, num_files=10, num_records=10):
        self._num_files = num_files
        self._num_records = num_records
        self.test_filenames = self._createFiles()

    def makeSnapshotDirectory(self):
        return self.snapshot_dir

    def assertSnapshotDirectoryContains(self, directory, num_fingerprints,
                                        num_runs_per_fp, num_snapshot_files):
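        # Expected layout: <directory>/<fingerprint>/<run>/NNNNNNNN.snapshot,
        # plus one snapshot.metadata file per fingerprint directory.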
        dirlist_raw = os.listdir(directory)
        dirlist = []

        # Ignore the graphdef pbtxts we write for debugging purposes.
        for i in range(len(dirlist_raw)):
            if not dirlist_raw[i].endswith("-graph.pbtxt"):
                dirlist.append(dirlist_raw[i])

        self.assertLen(dirlist, num_fingerprints)

        for i in range(num_fingerprints):
            fingerprint_dir = os.path.join(directory, dirlist[i])
            fingerprint_dir_list = sorted(os.listdir(fingerprint_dir))
            self.assertLen(fingerprint_dir_list, num_runs_per_fp + 1)
            self.assertEqual(fingerprint_dir_list[num_runs_per_fp],
                             "snapshot.metadata")

            for j in range(num_runs_per_fp):
                run_dir = os.path.join(fingerprint_dir,
                                       fingerprint_dir_list[j])
                run_dirlist = sorted(os.listdir(run_dir))
                self.assertLen(run_dirlist, num_snapshot_files)

                file_counter = 0
                for filename in run_dirlist:
                    self.assertEqual(filename, "%08d.snapshot" % file_counter)
                    file_counter += 1

    @combinations.generate(test_base.default_test_combinations())
    def testWriteDifferentPipelinesInOneDirectory(self):
        tmpdir = self.snapshot_dir

        dataset = dataset_ops.Dataset.range(1000)
        dataset = dataset.apply(snapshot.legacy_snapshot(tmpdir))
        self.assertDatasetProduces(dataset, list(range(1000)))

        dataset = dataset_ops.Dataset.range(1001)
        dataset = dataset.apply(snapshot.legacy_snapshot(tmpdir))
        self.assertDatasetProduces(dataset, list(range(1001)))

        self.assertSnapshotDirectoryContains(tmpdir, 2, 1, 1)

    @combinations.generate(test_base.default_test_combinations())
    def testWriteSnapshotMultipleSimultaneous(self):
        tmpdir = self.snapshot_dir

        dataset1 = dataset_ops.Dataset.range(1000)
        dataset1 = dataset1.apply(snapshot.legacy_snapshot(tmpdir))
        next1 = self.getNext(dataset1)

        dataset2 = dataset_ops.Dataset.range(1000)
        dataset2 = dataset2.apply(snapshot.legacy_snapshot(tmpdir))
        next2 = self.getNext(dataset2)

        for i in range(0, 1000):
            self.assertEqual(i, self.evaluate(next1()))
            self.assertEqual(i, self.evaluate(next2()))

        # We check that only one copy of the metadata has been written; the
        # writer that lost the race falls back to passthrough mode.
        self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)

    @combinations.generate(test_base.default_test_combinations())
    def testGetNextCreatesDir(self):
        tmpdir = self.snapshot_dir

        # We create two iterators but call getNext on only one.
        dataset1 = dataset_ops.Dataset.range(1000)
        dataset1 = dataset1.apply(snapshot.legacy_snapshot(tmpdir))
        next1 = self.getNext(dataset1)

        dataset2 = dataset_ops.Dataset.range(1001)
        dataset2 = dataset2.apply(snapshot.legacy_snapshot(tmpdir))
        _ = self.getNext(dataset2)

        for _ in range(1000):
            self.evaluate(next1())

        # We check that only one directory is created.
        self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(compression=[
                snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP,
                snapshot.COMPRESSION_SNAPPY
            ])))
    def testWriteSnapshotSimpleSuccessful(self, compression):
        tmpdir = self.snapshot_dir

        dataset = dataset_ops.Dataset.range(1000)
        dataset = dataset.apply(
            snapshot.legacy_snapshot(tmpdir, compression=compression))
        self.assertDatasetProduces(dataset, list(range(1000)))

        self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(compression=[
                snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP,
                snapshot.COMPRESSION_SNAPPY
            ])))
    def testWriteSnapshotRepeatAfterwards(self, compression):
        tmpdir = self.snapshot_dir

        dataset = dataset_ops.Dataset.range(10)
        dataset = dataset.apply(
            snapshot.legacy_snapshot(tmpdir, compression=compression))
        dataset = dataset.repeat(10)
        self.assertDatasetProduces(dataset, list(range(10)) * 10)

        self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(compression=[
                snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP,
                snapshot.COMPRESSION_SNAPPY
            ])))
    def testWriteSnapshotMixTypes(self, compression):
        tmpdir = self.snapshot_dir

        dataset = dataset_ops.Dataset.range(10)

        def map_fn(x):
            return (x, string_ops.as_string(x), string_ops.as_string(2 * x),
                    2 * x)

        dataset = dataset.map(map_fn)
        dataset = dataset.apply(
            snapshot.legacy_snapshot(tmpdir, compression=compression))
        dataset = dataset.repeat(10)

        expected = []
        for i in range(10):
            expected.append((i, str(i), str(2 * i), 2 * i))
        self.assertDatasetProduces(dataset, expected * 10)

        self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)

    @combinations.generate(test_base.default_test_combinations())
    def testSpecifySnapshotNameWriteAndRead(self):
        tmpdir = self.snapshot_dir

        dataset = dataset_ops.Dataset.range(10)
        dataset = dataset.apply(
            snapshot.legacy_snapshot(tmpdir,
                                     snapshot_name="my_custom_snapshot"))
        dataset = dataset.repeat(10)
        self.assertDatasetProduces(dataset, list(range(10)) * 10)

        self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
        self.assertTrue(
            os.path.exists(os.path.join(tmpdir, "custom-my_custom_snapshot")))
        self.assertTrue(
            os.path.exists(
                os.path.join(tmpdir, "custom-my_custom_snapshot", "custom")))

    @combinations.generate(test_base.default_test_combinations())
    def testForcePassthroughMode(self):
        tmpdir = self.snapshot_dir

        dataset = dataset_ops.Dataset.range(10)
        dataset = dataset.apply(
            snapshot.legacy_snapshot(tmpdir, mode="passthrough"))
        dataset = dataset.repeat(10)
        self.assertDatasetProduces(dataset, list(range(10)) * 10)

        self.assertSnapshotDirectoryContains(tmpdir, 0, 0, 0)

    @combinations.generate(test_base.default_test_combinations())
    def testForceWriteMode(self):
        tmpdir = self.snapshot_dir

        dataset = dataset_ops.Dataset.range(10)
        dataset = dataset.apply(snapshot.legacy_snapshot(tmpdir, mode="write"))
        dataset = dataset.repeat(10)
        self.assertDatasetProduces(dataset, list(range(10)) * 10)

        # We will end up writing 10 different runs.
        self.assertSnapshotDirectoryContains(tmpdir, 1, 10, 1)

    @combinations.generate(test_base.default_test_combinations())
    def testForceReadMode(self):
        tmpdir = self.snapshot_dir

        # We write a copy of the snapshot first.
        dataset = dataset_ops.Dataset.range(10)
        dataset = dataset.apply(
            snapshot.legacy_snapshot(tmpdir,
                                     mode="write",
                                     snapshot_name="my_custom_snapshot"))
        self.assertDatasetProduces(dataset, list(range(10)))

        # We move the run to a new name.
        shutil.move(os.path.join(tmpdir, "custom-my_custom_snapshot"),
                    os.path.join(tmpdir, "custom-my_custom_snapshot_2"))

        # Even though snapshot.metadata still points to the old run, which no
        # longer exists after the move, we force the dataset to read from the
        # run we specify.
        dataset = dataset_ops.Dataset.range(10)
        dataset = dataset.apply(
            snapshot.legacy_snapshot(tmpdir,
                                     mode="read",
                                     snapshot_name="my_custom_snapshot_2"))
        self.assertDatasetProduces(dataset, list(range(10)))

        # We should still have one snapshot and one run.
        self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)

    @combinations.generate(test_base.default_test_combinations())
    def testForceReadNonexistentSnapshot(self):
        tmpdir = self.snapshot_dir
        dataset = dataset_ops.Dataset.range(10)
        with self.assertRaises(errors.NotFoundError):
            dataset = dataset.apply(
                snapshot.legacy_snapshot(tmpdir, mode="read"))
            get_next = self.getNext(dataset)
            self.evaluate(get_next())

    @combinations.generate(test_base.default_test_combinations())
    def testForceReadNonexistentNamedSnapshot(self):
        tmpdir = self.snapshot_dir
        dataset = dataset_ops.Dataset.range(10)
        with self.assertRaises(errors.NotFoundError):
            dataset = dataset.apply(
                snapshot.legacy_snapshot(
                    tmpdir,
                    mode="read",
                    snapshot_name="my_nonexistent_snapshot"))
            get_next = self.getNext(dataset)
            self.evaluate(get_next())

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(compression=[
                snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP,
                snapshot.COMPRESSION_SNAPPY
            ])))
    def testReadSnapshotBackAfterWrite(self, compression):
        self.setUpTFRecord()
        filenames = self.test_filenames

        expected = [
            b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
            for f in range(0, 10) for r in range(0, 10)
        ]

        tmpdir = self.snapshot_dir
        dataset = core_readers._TFRecordDataset(filenames)
        dataset = dataset.apply(
            snapshot.legacy_snapshot(tmpdir, compression=compression))
        self.assertDatasetProduces(dataset, expected)

        # Remove the original files and try to read the data back only from the
        # snapshot.
        self.removeTFRecords()

        dataset2 = core_readers._TFRecordDataset(filenames)
        dataset2 = dataset2.apply(
            snapshot.legacy_snapshot(tmpdir, compression=compression))
        self.assertDatasetProduces(dataset2, expected)

    @combinations.generate(test_base.default_test_combinations())
    def testReadShuffledSnapshotAfterWrite(self):
        self.setUpTFRecord(num_files=10, num_records=50)
        filenames = self.test_filenames

        expected = [
            b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
            for f in range(0, 10) for r in range(0, 50)
        ]

        tmpdir = self.snapshot_dir
        dataset = core_readers._TFRecordDataset(filenames)
        dataset = dataset.apply(
            snapshot.legacy_snapshot(tmpdir, shard_size_bytes=100))
        self.assertDatasetProduces(dataset, expected)

        # Remove the original files and try to read the data back only from the
        # snapshot.
        self.removeTFRecords()

        dataset2 = core_readers._TFRecordDataset(filenames)
        dataset2 = dataset2.apply(
            snapshot.legacy_snapshot(tmpdir,
                                     shard_size_bytes=100,
                                     shuffle_on_read=True))
        next2 = self.getNext(dataset2)

        res1 = self.evaluate(next2())
        res2 = self.evaluate(next2())
        res3 = self.evaluate(next2())
        res4 = self.evaluate(next2())
        res5 = self.evaluate(next2())

        # Make sure that we don't read the data back in the same order.
        self.assertNotEqual([res1, res2, res3, res4, res5], expected[0:5])

        # Make sure all the elements are still there.
        dataset3 = core_readers._TFRecordDataset(filenames)
        dataset3 = dataset3.apply(
            snapshot.legacy_snapshot(tmpdir,
                                     shard_size_bytes=100,
                                     shuffle_on_read=True))
        self.assertDatasetProduces(dataset3, expected, assert_items_equal=True)

    @combinations.generate(test_base.default_test_combinations())
    def testReadShuffledSnapshotWithSeedAfterWrite(self):
        self.setUpTFRecord(num_files=10, num_records=50)
        filenames = self.test_filenames

        expected = [
            b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
            for f in range(0, 10) for r in range(0, 50)
        ]

        tmpdir = self.snapshot_dir
        dataset = core_readers._TFRecordDataset(filenames)
        dataset = dataset.apply(
            snapshot.legacy_snapshot(tmpdir, shard_size_bytes=10))
        self.assertDatasetProduces(dataset, expected)

        # remove the original files and try to read the data back only from snapshot
        self.removeTFRecords()

        dataset2 = core_readers._TFRecordDataset(filenames)
        dataset2 = dataset2.apply(
            snapshot.legacy_snapshot(tmpdir,
                                     shard_size_bytes=10,
                                     shuffle_on_read=True,
                                     shuffle_seed=123456))
        next2 = self.getNext(dataset2)

        dataset3 = core_readers._TFRecordDataset(filenames)
        dataset3 = dataset3.apply(
            snapshot.legacy_snapshot(tmpdir,
                                     shard_size_bytes=10,
                                     shuffle_on_read=True,
                                     shuffle_seed=123456))
        next3 = self.getNext(dataset3)

        # make sure that the items are read back in the same order for both datasets
        for _ in range(500):
            res2 = self.evaluate(next2())
            res3 = self.evaluate(next3())
            self.assertEqual(res2, res3)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(compression=[
                snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP,
                snapshot.COMPRESSION_SNAPPY
            ])))
    def testReadSnapshotParallelAfterWrite(self, compression):
        self.setUpTFRecord(10, 4000)
        filenames = self.test_filenames

        expected = [
            b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
            for f in range(0, 10) for r in range(0, 4000)
        ]

        tmpdir = self.snapshot_dir
        dataset = core_readers._TFRecordDataset(filenames)
        dataset = dataset.apply(
            snapshot.legacy_snapshot(tmpdir,
                                     shard_size_bytes=1024 * 1024,
                                     num_reader_threads=2,
                                     reader_buffer_size=10,
                                     compression=compression))
        self.assertDatasetProduces(dataset, expected, assert_items_equal=True)

        # remove the original files and try to read the data back only from
        # snapshot.
        self.removeTFRecords()

        dataset2 = core_readers._TFRecordDataset(filenames)
        dataset2 = dataset2.apply(
            snapshot.legacy_snapshot(tmpdir,
                                     shard_size_bytes=1024 * 1024,
                                     num_reader_threads=2,
                                     reader_buffer_size=10,
                                     compression=compression))
        self.assertDatasetProduces(dataset2, expected, assert_items_equal=True)

    # Not testing Snappy here because Snappy reads currently require a lot of
    # memory.
    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.times(
                combinations.combine(compression=[
                    snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP
                ]),
                combinations.combine(threads=2, size=[1, 2]) +
                combinations.combine(threads=8, size=[1, 4, 8]))))
    def testReadSnapshotBackAfterMultiThreadedWrite(self, compression, threads,
                                                    size):
        self.setUpTFRecord()
        filenames = self.test_filenames

        expected = [
            b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
            for f in range(0, 10) for r in range(0, 10)
        ]

        tmpdir = self.snapshot_dir
        dataset = core_readers._TFRecordDataset(filenames)
        dataset = dataset.apply(
            snapshot.legacy_snapshot(tmpdir,
                                     compression=compression,
                                     num_writer_threads=threads,
                                     writer_buffer_size=size))
        self.assertDatasetProduces(dataset, expected)

        # remove the original files and try to read the data back only from
        # snapshot
        self.removeTFRecords()

        dataset2 = core_readers._TFRecordDataset(filenames)
        dataset2 = dataset2.apply(
            snapshot.legacy_snapshot(tmpdir, compression=compression))
        self.assertDatasetProduces(dataset2, expected, assert_items_equal=True)

    @combinations.generate(test_base.default_test_combinations())
    def testSameFingerprintWithDifferentInitializationOrder(self):
        tmpdir = self.snapshot_dir

        dataset1 = dataset_ops.Dataset.range(0, 100)
        dataset2 = dataset_ops.Dataset.range(100, 200)
        dataset3 = dataset_ops.Dataset.range(200, 300)

        dataset = dataset1.concatenate(dataset2).concatenate(dataset3)
        dataset = dataset.apply(snapshot.legacy_snapshot(tmpdir))
        self.assertDatasetProduces(dataset, list(range(300)))

        dataset4 = dataset_ops.Dataset.range(200, 300)
        dataset5 = dataset_ops.Dataset.range(100, 200)
        dataset6 = dataset_ops.Dataset.range(0, 100)

        dataset = dataset6.concatenate(dataset5).concatenate(dataset4)
        dataset = dataset.apply(snapshot.legacy_snapshot(tmpdir))
        self.assertDatasetProduces(dataset, list(range(300)))

        self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)

    @combinations.generate(test_base.default_test_combinations())
    def testExpiredSnapshotRewrite(self):
        tmpdir = self.snapshot_dir

        dataset1 = dataset_ops.Dataset.range(1000)
        dataset1 = dataset1.apply(
            snapshot.legacy_snapshot(tmpdir,
                                     pending_snapshot_expiry_seconds=1))
        next1 = self.getNext(dataset1)

        # Don't finish reading dataset1, so it is never finalized
        for _ in range(500):
            self.evaluate(next1())
        self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)

        time.sleep(2)

        # Create dataset2 only after running through dataset1: in eager mode the
        # snapshot state is determined immediately upon dataset creation, and we
        # only want to determine the snapshot state for dataset2 after the first
        # snapshot has expired.
        dataset2 = dataset_ops.Dataset.range(1000)
        dataset2 = dataset2.apply(
            snapshot.legacy_snapshot(tmpdir,
                                     pending_snapshot_expiry_seconds=1))
        next2 = self.getNext(dataset2)

        for _ in range(500):
            self.evaluate(next2())
        self.assertSnapshotDirectoryContains(tmpdir, 1, 2, 1)

    @combinations.generate(test_base.default_test_combinations())
    def testSnapshotArgsCreateNewSnapshot(self):
        tmpdir = self.snapshot_dir

        dataset1 = dataset_ops.Dataset.range(1000)
        dataset1 = dataset1.apply(
            snapshot.legacy_snapshot(tmpdir, shard_size_bytes=10000))
        next1 = self.getNext(dataset1)

        for _ in range(1000):
            self.evaluate(next1())
        self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)

        # Create second snapshot with a different shard_size_bytes
        dataset2 = dataset_ops.Dataset.range(1000)
        dataset2 = dataset2.apply(
            snapshot.legacy_snapshot(tmpdir, shard_size_bytes=20000))
        next2 = self.getNext(dataset2)

        for _ in range(1000):
            self.evaluate(next2())
        self.assertSnapshotDirectoryContains(tmpdir, 2, 1, 1)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(compression=[
                snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP,
                snapshot.COMPRESSION_SNAPPY
            ])))
    def testSpecifyShardSize(self, compression):
        tmpdir = self.snapshot_dir

        dataset = dataset_ops.Dataset.from_tensor_slices([1.0])
        dataset = dataset.map(
            lambda x: gen_array_ops.broadcast_to(x, [1024, 1024]))
        dataset = dataset.repeat(10)
        dataset = dataset.apply(
            snapshot.legacy_snapshot(tmpdir,
                                     shard_size_bytes=10 * 1024 * 1024,
                                     compression=compression))
        next_fn = self.getNext(dataset)

        for _ in range(10):
            self.evaluate(next_fn())

        num_files = 1
        if compression == snapshot.COMPRESSION_NONE:
            num_files = 3
        self.assertSnapshotDirectoryContains(tmpdir, 1, 1, num_files)

    @combinations.generate(test_base.default_test_combinations())
    def testAdditionalOperationsAfterReadBack(self):
        self.setUpTFRecord()
        filenames = self.test_filenames

        expected = [
            b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
            for f in range(0, 10) for r in range(0, 10)
        ]

        tmpdir = self.snapshot_dir
        dataset = core_readers._TFRecordDataset(filenames)
        dataset = dataset.apply(snapshot.legacy_snapshot(tmpdir))
        self.assertDatasetProduces(dataset, expected)

        # remove the original files and try to read the data back only from snapshot
        self.removeTFRecords()

        dataset2 = core_readers._TFRecordDataset(filenames)
        dataset2 = dataset2.apply(snapshot.legacy_snapshot(tmpdir))
        self.assertDatasetProduces(dataset2, expected)

        expected_after = [
            b"cord %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
            for f in range(0, 10) for r in range(0, 10)
        ]

        dataset3 = core_readers._TFRecordDataset(filenames)
        dataset3 = dataset3.apply(snapshot.legacy_snapshot(tmpdir))
        dataset3 = dataset3.map(lambda x: string_ops.substr_v2(x, 2, 1000))
        self.assertDatasetProduces(dataset3, expected_after)
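
The legacy snapshot tests above all follow one pattern: materialize the pipeline output to disk on the first pass, then delete the upstream inputs and verify that the same elements come back from the snapshot files alone. As a rough, non-authoritative sketch of that write-then-read pattern using the public API (assuming tf.data.experimental.snapshot as available in TF 2.3+; the snapshot path is a placeholder, not one used by the tests):

import tensorflow as tf

# Minimal sketch, assuming the public tf.data.experimental.snapshot API (TF 2.3+).
# The snapshot path below is a placeholder.
snapshot_path = "/tmp/example_snapshot"

dataset = tf.data.Dataset.range(100).map(lambda x: x * 2)
# The first full iteration writes the snapshot; later iterations of a pipeline
# with the same fingerprint read the stored elements back instead of
# recomputing the upstream ops.
dataset = dataset.apply(
    tf.data.experimental.snapshot(snapshot_path, compression="AUTO"))

for _ in range(2):
    for _ in dataset:
        pass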
Example #21
0
class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase):
    @combinations.generate(test_base.default_test_combinations())
    def testOptionsDefault(self):
        ds = dataset_ops.Dataset.range(0)
        self.assertEqual(dataset_ops.Options(), ds.options())

    @combinations.generate(test_base.default_test_combinations())
    def testOptionsOnce(self):
        options = dataset_ops.Options()
        ds = dataset_ops.Dataset.range(0).with_options(options).cache()
        self.assertEqual(options, ds.options())

    @combinations.generate(test_base.default_test_combinations())
    def testOptionsTwiceSame(self):
        options = dataset_ops.Options()
        options.experimental_optimization.autotune = True
        ds = dataset_ops.Dataset.range(0).with_options(options).with_options(
            options)
        self.assertEqual(options, ds.options())

    @combinations.generate(test_base.default_test_combinations())
    def testOptionsTwiceDifferent(self):
        options1 = dataset_ops.Options()
        options1.experimental_optimization.autotune = True
        options2 = dataset_ops.Options()
        options2.experimental_deterministic = False
        ds = dataset_ops.Dataset.range(0).with_options(options1).with_options(
            options2)
        self.assertTrue(ds.options().experimental_optimization.autotune)
        # Explicitly check that flag is False since assertFalse allows None
        self.assertIs(ds.options().experimental_deterministic, False)

    @combinations.generate(test_base.default_test_combinations())
    def testOptionsTwiceDifferentError(self):
        options1 = dataset_ops.Options()
        options1.experimental_optimization.autotune = True
        options2 = dataset_ops.Options()
        options2.experimental_optimization.autotune = False
        with self.assertRaisesRegex(ValueError,
                                    "Cannot merge incompatible values"):
            dataset_ops.Dataset.range(0).with_options(options1).with_options(
                options2)

    @combinations.generate(test_base.default_test_combinations())
    def testOptionsMergeOptionsFromMultipleInputs(self):
        options1 = dataset_ops.Options()
        options1.experimental_optimization.autotune = True
        options2 = dataset_ops.Options()
        options2.experimental_deterministic = True
        ds = dataset_ops.Dataset.zip(
            (dataset_ops.Dataset.range(0).with_options(options1),
             dataset_ops.Dataset.range(0).with_options(options2)))
        self.assertTrue(ds.options().experimental_optimization.autotune)
        self.assertTrue(ds.options().experimental_deterministic)

    @combinations.generate(test_base.default_test_combinations())
    def testOptionsHaveDefaults(self):
        options1 = dataset_ops.Options()
        options2 = dataset_ops.Options()
        self.assertIsNot(options1.experimental_optimization,
                         options2.experimental_optimization)
        self.assertIsNot(options1.experimental_stats,
                         options2.experimental_stats)
        self.assertIsNot(options1.experimental_threading,
                         options2.experimental_threading)
        self.assertEqual(options1.experimental_optimization,
                         optimization_options.OptimizationOptions())
        self.assertEqual(options1.experimental_stats,
                         stats_options.StatsOptions())
        self.assertEqual(options1.experimental_threading,
                         threading_options.ThreadingOptions())
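
Outside the test harness, the merge behaviour checked by OptionsTest can be reproduced with the public tf.data.Options API. A minimal sketch, using the TF 2.x era attribute names that appear in these tests (experimental_deterministic and experimental_optimization were later renamed in newer releases):

import tensorflow as tf

# Minimal sketch of options merging via with_options (attribute names as in
# the tests above).
options1 = tf.data.Options()
options1.experimental_optimization.autotune = True

options2 = tf.data.Options()
options2.experimental_deterministic = False

ds = tf.data.Dataset.range(10)
ds = ds.with_options(options1)
ds = ds.with_options(options2)  # options set on different fields are merged

merged = ds.options()
print(merged.experimental_optimization.autotune)  # True
print(merged.experimental_deterministic)          # False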
Example #22
0
class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
                          parameterized.TestCase):
    def setUp(self):
        super(SnapshotDatasetTest, self).setUp()
        tmpdir = self.get_temp_dir()
        tmpdir = os.path.join(tmpdir, "snapshot")
        os.mkdir(tmpdir)
        self._snapshot_dir = tmpdir

    def tearDown(self):
        super(SnapshotDatasetTest, self).tearDown()
        shutil.rmtree(self._snapshot_dir)

    def createTFRecords(self, num_files=10, num_records=100):
        self._num_files = num_files
        self._num_records = num_records
        self._test_filenames = self._createFiles()

    def removeTFRecords(self):
        for filename in self._test_filenames:
            os.remove(filename)
        self._test_filenames = []
        self._num_files = None
        self._num_records = None

    def assertDatasetProducesSet(self, dataset, expected):
        actual = []
        next_fn = self.getNext(dataset)
        for _ in range(len(expected)):
            elem = self.evaluate(next_fn())
            actual.append(elem)
        self.assertCountEqual(actual, expected)
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(next_fn())

    def assertSnapshotDirectoryContains(self, directory, num_fingerprints,
                                        num_runs_per_fingerprint,
                                        num_snapshot_shards_per_run):
        dirlist_raw = os.listdir(directory)
        dirlist = []

        # Ignore the graphdef pbtxts we write for debugging purposes.
        for i in range(len(dirlist_raw)):
            if not dirlist_raw[i].endswith("-graph.pbtxt"):
                dirlist.append(dirlist_raw[i])

        self.assertLen(dirlist, num_fingerprints)

        for i in range(num_fingerprints):
            fingerprint_dir = os.path.join(directory, dirlist[i])
            fingerprint_dir_list = sorted(os.listdir(fingerprint_dir))
            self.assertLen(fingerprint_dir_list, num_runs_per_fingerprint + 1)
            self.assertEqual(fingerprint_dir_list[num_runs_per_fingerprint],
                             "snapshot.metadata")

            for j in range(num_runs_per_fingerprint):
                run_dir = os.path.join(fingerprint_dir,
                                       fingerprint_dir_list[j])
                run_dirlist = sorted(os.listdir(run_dir))
                self.assertLen(run_dirlist, num_snapshot_shards_per_run)

                file_counter = 0
                for filename in run_dirlist:
                    self.assertEqual(filename, "%08d.shard" % file_counter)
                    file_counter += 1

    @combinations.generate(test_base.default_test_combinations())
    def testCreateSnapshotDataset(self):
        dataset = dataset_ops.Dataset.from_tensors([1, 2, 3])
        dataset.apply(snapshot.snapshot(self._snapshot_dir))

    @combinations.generate(test_base.default_test_combinations())
    def testReadSnapshotDatasetDefault(self):
        self.createTFRecords()
        filenames = self._test_filenames
        expected = [
            b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
            for f in range(0, 10) for r in range(0, 100)
        ]

        dataset = core_readers._TFRecordDataset(filenames)
        dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
        self.assertDatasetProduces(dataset, expected)
        self.assertSnapshotDirectoryContains(
            self._snapshot_dir,
            num_fingerprints=1,
            num_runs_per_fingerprint=1,
            num_snapshot_shards_per_run=multiprocessing.cpu_count())

        self.removeTFRecords()
        dataset2 = core_readers._TFRecordDataset(filenames)
        dataset2 = dataset2.apply(snapshot.snapshot(self._snapshot_dir))
        self.assertDatasetProduces(dataset2, expected)

    @combinations.generate(test_base.default_test_combinations())
    def testReadSnapshotDatasetAutoWriteSnappyRead(self):
        self.createTFRecords()
        filenames = self._test_filenames
        expected = [
            b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
            for f in range(0, 10) for r in range(0, 100)
        ]

        dataset = core_readers._TFRecordDataset(filenames)
        dataset = dataset.apply(
            snapshot.snapshot(self._snapshot_dir, compression="AUTO"))
        self.assertDatasetProduces(dataset, expected)

        self.removeTFRecords()
        dataset2 = core_readers._TFRecordDataset(filenames)
        dataset2 = dataset2.apply(
            snapshot.snapshot(self._snapshot_dir, compression="SNAPPY"))
        self.assertDatasetProduces(dataset2, expected)

    @combinations.generate(test_base.default_test_combinations())
    def testReadSnapshotDatasetCustomShardFn(self):
        self.createTFRecords()
        filenames = self._test_filenames
        expected = [
            b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
            for f in range(0, 10) for r in range(0, 100)
        ]

        dataset = core_readers._TFRecordDataset(filenames)
        dataset = dataset.apply(
            snapshot.snapshot(self._snapshot_dir,
                              shard_func=lambda _: np.int64(0)))
        self.assertDatasetProduces(dataset, expected)
        self.assertSnapshotDirectoryContains(self._snapshot_dir,
                                             num_fingerprints=1,
                                             num_runs_per_fingerprint=1,
                                             num_snapshot_shards_per_run=1)

        self.removeTFRecords()
        dataset2 = core_readers._TFRecordDataset(filenames)
        dataset2 = dataset2.apply(
            snapshot.snapshot(self._snapshot_dir, shard_func=lambda _: 0))
        self.assertDatasetProduces(dataset2, expected)

    @combinations.generate(test_base.default_test_combinations())
    def testReadSnapshotDatasetCustomReaderFn(self):
        self.createTFRecords()
        filenames = self._test_filenames
        expected = [
            b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
            for f in range(0, 10) for r in range(0, 100)
        ]

        dataset = core_readers._TFRecordDataset(filenames)
        dataset = dataset.apply(
            snapshot.snapshot(
                self._snapshot_dir,
                reader_func=(
                    lambda ds: ds.interleave(  # pylint:disable=g-long-lambda
                        lambda x: x,
                        cycle_length=4,
                        num_parallel_calls=4))))
        self.assertDatasetProduces(dataset, expected)
        self.assertSnapshotDirectoryContains(
            self._snapshot_dir,
            num_fingerprints=1,
            num_runs_per_fingerprint=1,
            num_snapshot_shards_per_run=multiprocessing.cpu_count())

        self.removeTFRecords()
        dataset2 = core_readers._TFRecordDataset(filenames)
        dataset2 = dataset2.apply(
            snapshot.snapshot(
                self._snapshot_dir,
                reader_func=(
                    lambda ds: ds.interleave(  # pylint:disable=g-long-lambda
                        lambda x: x,
                        cycle_length=4,
                        num_parallel_calls=4))))
        self.assertDatasetProducesSet(dataset2, expected)

    @combinations.generate(test_base.default_test_combinations())
    def testSnapshotDatasetInvalidShardFn(self):
        dataset = dataset_ops.Dataset.range(1000)
        with self.assertRaises(TypeError):
            dataset = dataset.apply(
                snapshot.snapshot(self._snapshot_dir,
                                  shard_func=lambda _: "invalid_fn"))
            next_fn = self.getNext(dataset)
            self.evaluate(next_fn())

    @combinations.generate(test_base.default_test_combinations())
    def testSnapshotDatasetInvalidReaderFn(self):
        dataset = dataset_ops.Dataset.range(1000)
        with self.assertRaises(TypeError):
            dataset = dataset.apply(
                snapshot.snapshot(self._snapshot_dir,
                                  reader_func=lambda x: x + 1))
            next_fn = self.getNext(dataset)
            self.evaluate(next_fn())

    @combinations.generate(test_base.default_test_combinations())
    def testWriteSnapshotDatasetSimple(self):
        dataset = dataset_ops.Dataset.range(1000)
        dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
        self.assertDatasetProduces(dataset, list(range(1000)))
        self.assertSnapshotDirectoryContains(
            self._snapshot_dir,
            num_fingerprints=1,
            num_runs_per_fingerprint=1,
            num_snapshot_shards_per_run=multiprocessing.cpu_count())

    @combinations.generate(test_base.default_test_combinations())
    def testWriteSnapshotDatasetMultipleFingerprints(self):
        dataset1 = dataset_ops.Dataset.range(1000)
        dataset1 = dataset1.apply(snapshot.snapshot(self._snapshot_dir))
        self.assertDatasetProduces(dataset1, list(range(1000)))

        dataset2 = dataset_ops.Dataset.range(2000)
        dataset2 = dataset2.apply(snapshot.snapshot(self._snapshot_dir))
        self.assertDatasetProduces(dataset2, list(range(2000)))

        self.assertSnapshotDirectoryContains(
            self._snapshot_dir,
            num_fingerprints=2,
            num_runs_per_fingerprint=1,
            num_snapshot_shards_per_run=multiprocessing.cpu_count())

    @combinations.generate(test_base.default_test_combinations())
    def testWriteSnapshotDatasetSameFingerprintMultipleCompleteRuns(self):
        dataset1 = dataset_ops.Dataset.range(1000)
        dataset1 = dataset1.apply(snapshot.snapshot(self._snapshot_dir))
        self.assertDatasetProduces(dataset1, list(range(1000)))
        dataset2 = dataset_ops.Dataset.range(1000)
        dataset2 = dataset2.apply(snapshot.snapshot(self._snapshot_dir))
        self.assertDatasetProduces(dataset2, list(range(1000)))

        self.assertSnapshotDirectoryContains(
            self._snapshot_dir,
            num_fingerprints=1,
            num_runs_per_fingerprint=1,
            num_snapshot_shards_per_run=multiprocessing.cpu_count())

    @combinations.generate(test_base.default_test_combinations())
    def testWriteSnapshotDatasetSameFingerprintIncompleteRunRestart(self):
        dataset1 = dataset_ops.Dataset.range(1000)
        dataset1 = dataset1.apply(snapshot.snapshot(self._snapshot_dir))
        next1 = self.getNext(dataset1)
        for i in range(500):
            self.assertEqual(i, self.evaluate(next1()))

        dataset2 = dataset_ops.Dataset.range(1000)
        dataset2 = dataset2.apply(snapshot.snapshot(self._snapshot_dir))
        next2 = self.getNext(dataset2)
        for i in range(500):
            self.assertEqual(i, self.evaluate(next2()))

        for i in range(500, 1000):
            self.assertEqual(i, self.evaluate(next1()))
            self.assertEqual(i, self.evaluate(next2()))

        self.assertSnapshotDirectoryContains(
            self._snapshot_dir,
            num_fingerprints=1,
            num_runs_per_fingerprint=2,
            num_snapshot_shards_per_run=multiprocessing.cpu_count())

    @combinations.generate(test_base.default_test_combinations())
    def testWriteSnapshotCustomShardFunction(self):
        dataset = dataset_ops.Dataset.range(1000)
        dataset = dataset.enumerate()
        dataset = dataset.apply(
            snapshot.snapshot(self._snapshot_dir,
                              shard_func=lambda i, _: i % 2))
        dataset = dataset.map(lambda _, elem: elem)
        self.assertDatasetProduces(dataset, list(range(1000)))
        self.assertSnapshotDirectoryContains(self._snapshot_dir,
                                             num_fingerprints=1,
                                             num_runs_per_fingerprint=1,
                                             num_snapshot_shards_per_run=2)

    @combinations.generate(test_base.default_test_combinations())
    def testWriteSnapshotDatasetWithTuples(self):
        dataset1 = dataset_ops.Dataset.range(0, 1000)
        dataset2 = dataset_ops.Dataset.range(1000, 2000)
        dataset3 = dataset_ops.Dataset.range(2000, 3000)
        dataset4 = dataset_ops.Dataset.range(3000, 4000)

        dataset = dataset_ops.Dataset.zip(
            (dataset1, dataset2, dataset3, dataset4))
        dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))

        expected = list(
            zip(range(0, 1000), range(1000, 2000), range(2000, 3000),
                range(3000, 4000)))
        self.assertDatasetProduces(dataset, expected)
        self.assertSnapshotDirectoryContains(
            self._snapshot_dir,
            num_fingerprints=1,
            num_runs_per_fingerprint=1,
            num_snapshot_shards_per_run=multiprocessing.cpu_count())

    @combinations.generate(test_base.default_test_combinations())
    def testWriteSnapshotShuffleSameFingerprint(self):
        def make_dataset():
            dataset = dataset_ops.Dataset.range(1000)
            dataset = dataset.shuffle(1000)
            dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
            return dataset

        dataset1 = make_dataset()
        self.assertDatasetProducesSet(dataset1, list(range(1000)))
        dataset2 = make_dataset()
        self.assertDatasetProducesSet(dataset2, list(range(1000)))
        self.assertSnapshotDirectoryContains(
            self._snapshot_dir,
            num_fingerprints=1,
            num_runs_per_fingerprint=1,
            num_snapshot_shards_per_run=multiprocessing.cpu_count())

    @combinations.generate(test_base.default_test_combinations())
    def testReadUsingFlatMap(self):
        dataset = dataset_ops.Dataset.range(1000)
        dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
        self.assertDatasetProduces(dataset, list(range(1000)))
        flat_map = dataset_ops.Dataset.from_tensors(dataset).flat_map(
            lambda x: x)
        self.assertDatasetProduces(flat_map, list(range(1000)))
        self.assertSnapshotDirectoryContains(
            self._snapshot_dir,
            num_fingerprints=1,
            num_runs_per_fingerprint=1,
            num_snapshot_shards_per_run=multiprocessing.cpu_count())

    @combinations.generate(test_base.default_test_combinations())
    def testReadOptimizableUsingFlatMap(self):
        if context.context().use_tfrt:
            self.skipTest("b/177260096: Flaky test.")
        dataset = dataset_ops.Dataset.range(100)
        # Will be optimized into ShuffleAndRepeat.
        dataset = dataset.shuffle(10)
        dataset = dataset.repeat(2)
        dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
        self.assertDatasetProducesSet(dataset, 2 * list(range(100)))
        flat_map = dataset_ops.Dataset.from_tensors(dataset).flat_map(
            lambda x: x)
        self.assertDatasetProducesSet(flat_map, 2 * list(range(100)))
        self.assertSnapshotDirectoryContains(
            self._snapshot_dir,
            num_fingerprints=1,
            num_runs_per_fingerprint=1,
            num_snapshot_shards_per_run=multiprocessing.cpu_count())
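
The shard_func and reader_func hooks exercised in this class decide how elements are routed into shard files at write time and how the per-shard datasets are recombined at read time. A hedged sketch of the same hooks through the public API (assuming tf.data.experimental.snapshot, TF 2.3+; the path and the even/odd sharding scheme are illustrative assumptions):

import tensorflow as tf

# Minimal sketch, assuming tf.data.experimental.snapshot (TF 2.3+).
# The path and the two-shard scheme are illustrative assumptions.
path = "/tmp/sharded_snapshot"

dataset = tf.data.Dataset.range(1000).enumerate()
dataset = dataset.apply(
    tf.data.experimental.snapshot(
        path,
        # Route elements into two shard files by index parity at write time.
        shard_func=lambda index, _: index % 2,
        # Interleave the per-shard datasets when reading the snapshot back.
        reader_func=lambda shards: shards.interleave(
            lambda shard: shard,
            cycle_length=2,
            num_parallel_calls=tf.data.experimental.AUTOTUNE)))
dataset = dataset.map(lambda _, value: value)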
Example #23
0
class DataServiceOpsTest(data_service_test_base.TestBase,
                         parameterized.TestCase):
    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           data_service_test_base.all_cluster_configurations())
    )
    def testDistributeBasic(self, work_dir, fault_tolerant_mode):
        cluster = data_service_test_base.TestCluster(
            num_workers=1,
            work_dir=work_dir,
            fault_tolerant_mode=fault_tolerant_mode)
        num_elements = 10
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        self.assertDatasetProduces(ds, list(range(num_elements)))

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(compression=[None, "AUTO"])))
    def testDistributeCompression(self, compression):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 10
        ds = self.make_distributed_range_dataset(num_elements,
                                                 cluster,
                                                 compression=compression)
        self.assertDatasetProduces(ds, list(range(num_elements)))

    @combinations.generate(test_base.default_test_combinations())
    def testDistributeInvalidCompression(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        with self.assertRaisesRegex(ValueError,
                                    "Invalid compression argument"):
            self.make_distributed_range_dataset(10, cluster, compression="foo")

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeSparse(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        element = sparse_tensor.SparseTensor(indices=[[0]],
                                             values=constant_op.constant(
                                                 [0], dtype=dtypes.int32),
                                             dense_shape=[1])
        ds = dataset_ops.Dataset.from_tensors(element)
        ds = self.make_distributed_dataset(ds, cluster)
        results = [sparse_ops.sparse_tensor_to_dense(elem) for elem in ds]
        self.assertAllEqual(results, [[0]])

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeRagged(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        ds = dataset_ops.Dataset.from_tensor_slices([1, 5, 3, 2, 8])
        ds = ds.map(math_ops.range)
        ds = ds.apply(batching.dense_to_ragged_batch(2))
        ds = self.make_distributed_dataset(ds, cluster)
        results = [elem.to_tensor() for elem in ds]
        self.assertAllEqual(results[0], [[0, 0, 0, 0, 0], [0, 1, 2, 3, 4]])
        self.assertAllEqual(results[1], [[0, 1, 2], [0, 1, 0]])
        self.assertAllEqual(results[2], [[0, 1, 2, 3, 4, 5, 6, 7]])

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(
                init_source=["textfile", "keyvaluetensor", "dataset"])))
    def testDistributeLookupTable(self, init_source):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        initializer = self.lookupTableInitializer(init_source, [10, 11])
        table = lookup_ops.StaticHashTable(initializer, -1)
        ds = dataset_ops.Dataset.range(3)
        ds = ds.map(table.lookup)
        ds = self.make_distributed_dataset(ds, cluster)
        self.evaluate(lookup_ops.tables_initializer())
        self.assertDatasetProduces(ds, [10, 11, -1],
                                   requires_initialization=True)

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           combinations.combine(value_rank=[0, 1])))
    def testDistributeMutableHashTable(self, value_rank):
        def value(v):
            for _ in range(value_rank):
                v = [v, v]
            return v

        v1 = value(10)
        v2 = value(11)
        default_value = value(-1)

        cluster = data_service_test_base.TestCluster(num_workers=1)
        table = lookup_ops.MutableHashTable(dtypes.int64, dtypes.int64,
                                            default_value)
        self.evaluate(table.insert([0, 1], [v1, v2]))
        ds = dataset_ops.Dataset.range(3)
        ds = ds.map(table.lookup)
        ds = self.make_distributed_dataset(ds, cluster)
        self.assertDatasetProduces(ds, [v1, v2, default_value],
                                   requires_initialization=True)

    @combinations.generate(test_base.default_test_combinations())
    def testDifferentShuffleOrders(self):
        random_seed.set_random_seed(None)
        num_elements = 100
        cluster = data_service_test_base.TestCluster(num_workers=2)
        ds = dataset_ops.Dataset.range(num_elements)
        ds = ds.shuffle(num_elements)
        ds = self.make_distributed_dataset(ds, cluster)
        output = self.getDatasetOutput(ds)

        # The output will be two sequences of range(num_elements)
        # non-deterministically interleaved together. If the orders of the elements
        # were the same, first_order and second_order computed below would be equal.
        first_order = {}
        second_order = {}
        for element in output:
            if element in first_order:
                second_order[element] = len(second_order)
            else:
                first_order[element] = len(first_order)
        self.assertNotEqual(first_order, second_order)

    @combinations.generate(test_base.default_test_combinations())
    def testMultipleEpochs(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 3
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        for _ in range(10):
            self.assertDatasetProduces(ds, list(range(num_elements)))

    @combinations.generate(test_base.default_test_combinations())
    def testRepeatedDataset(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 10
        num_repetitions = 5
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        ds = ds.repeat(num_repetitions)
        self.assertDatasetProduces(ds,
                                   expected_output=num_repetitions *
                                   list(range(num_elements)))

    @combinations.generate(test_base.default_test_combinations())
    def testConcurrentEpoch(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 10
        num_datasets = 3
        get_nexts = []
        results = []
        for _ in range(num_datasets):
            ds = self.make_distributed_range_dataset(num_elements, cluster)
            get_nexts.append(self.getNext(ds))
            results.append([])

        for _ in range(num_elements):
            for dataset_ind in range(num_datasets):
                result = self.evaluate(get_nexts[dataset_ind]())
                results[dataset_ind].append(result)
        for result in results:
            self.assertEqual(list(range(num_elements)), result)

    @combinations.generate(test_base.default_test_combinations())
    def testMultiWorker(self):
        num_workers = 3
        cluster = data_service_test_base.TestCluster(num_workers=num_workers)
        num_elements = 10
        ds = self.make_distributed_range_dataset(num_elements, cluster)
        self.assertDatasetProduces(ds,
                                   num_workers * list(range(num_elements)),
                                   assert_items_equal=True)

    @combinations.generate(test_base.default_test_combinations())
    def testMaxOutstandingRequests(self):
        num_workers = 3
        cluster = data_service_test_base.TestCluster(num_workers=num_workers)
        num_elements = 10
        ds = self.make_distributed_range_dataset(num_elements,
                                                 cluster,
                                                 max_outstanding_requests=1)
        self.assertDatasetProduces(ds,
                                   num_workers * list(range(num_elements)),
                                   assert_items_equal=True)

    @combinations.generate(test_base.eager_only_combinations())
    def testInsideFunction(self):
        num_workers = 3
        cluster = data_service_test_base.TestCluster(num_workers=num_workers)
        num_elements = 10

        @def_function.function
        def f():
            ds = self.make_distributed_range_dataset(num_elements, cluster)
            result = tensor_array_ops.TensorArray(dtypes.int64,
                                                  size=num_workers *
                                                  num_elements,
                                                  dynamic_size=True)
            i = 0
            for elem in ds:
                result = result.write(i, elem)
                i += 1
            return result.stack()

        result = list(f().numpy())
        self.assertCountEqual(num_workers * list(range(num_elements)), result)

    @combinations.generate(test_base.default_test_combinations())
    def testSharedJobName(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 1000

        def make_ds():
            return dataset_ops.Dataset.range(num_elements).shuffle(
                num_elements)

        ds1 = self.make_distributed_dataset(make_ds(),
                                            cluster,
                                            job_name="job_name")
        ds2 = self.make_distributed_dataset(make_ds(),
                                            cluster,
                                            job_name="job_name")
        get_next_1 = self.getNext(ds1)
        get_next_2 = self.getNext(ds2)
        results = []
        for _ in range(num_elements // 5):
            results.append(self.evaluate(get_next_1()))
            results.append(self.evaluate(get_next_2()))
        results += self.getIteratorOutput(get_next_1)
        results += self.getIteratorOutput(get_next_2)
        self.assertCountEqual(list(range(num_elements)), results)

    @combinations.generate(test_base.default_test_combinations())
    def testDifferentJobNames(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 10
        ds1 = self.make_distributed_range_dataset(num_elements,
                                                  cluster,
                                                  job_name="job_name1")
        ds2 = self.make_distributed_range_dataset(num_elements,
                                                  cluster,
                                                  job_name="job_name2")
        self.assertDatasetProduces(ds1, list(range(num_elements)))
        self.assertDatasetProduces(ds2, list(range(num_elements)))

    @combinations.generate(test_base.eager_only_combinations())
    def testSharedJobNameMultiIteration(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 10
        ds1 = self.make_distributed_range_dataset(num_elements,
                                                  cluster,
                                                  job_name="job_name")
        ds2 = self.make_distributed_range_dataset(num_elements,
                                                  cluster,
                                                  job_name="job_name")
        # iteration 1
        self.assertDatasetProduces(ds1, list(range(num_elements)))
        self.assertDatasetProduces(ds2, [])
        # iteration 2
        self.assertDatasetProduces(ds2, list(range(num_elements)))
        self.assertDatasetProduces(ds1, [])

    @combinations.generate(test_base.default_test_combinations())
    def testSharedJobNameRepeat(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 100
        num_repetitions = 3
        ds1 = self.make_distributed_range_dataset(num_elements,
                                                  cluster,
                                                  job_name="job_name")
        ds1 = ds1.repeat(num_repetitions)
        ds2 = self.make_distributed_range_dataset(num_elements,
                                                  cluster,
                                                  job_name="job_name")
        ds2 = ds2.repeat(num_repetitions)
        results = []
        get_next_1 = self.getNext(ds1)
        get_next_2 = self.getNext(ds2)
        for _ in range((num_elements * num_repetitions) // 5):
            results.append(self.evaluate(get_next_1()))
        for _ in range((num_elements * num_repetitions) // 5):
            results.append(self.evaluate(get_next_2()))
        results += self.getIteratorOutput(get_next_1)
        results += self.getIteratorOutput(get_next_2)
        self.assertCountEqual(num_repetitions * list(range(num_elements)),
                              results)

    @combinations.generate(
        combinations.times(test_base.eager_only_combinations(),
                           combinations.combine(job_name=[None, "test"])))
    def testGcUnusedJob(self, job_name):
        cluster = data_service_test_base.TestCluster(
            num_workers=1, job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
        num_elements = 100
        ds = self.make_distributed_range_dataset(num_elements,
                                                 cluster,
                                                 job_name=job_name)
        it = iter(ds)
        self.assertEqual(next(it).numpy(), 0)
        self.assertEqual(cluster.workers[0].num_tasks(), 1)
        del it
        while cluster.workers[0].num_tasks() > 0:
            time.sleep(0.1)

    @combinations.generate(test_base.eager_only_combinations())
    def testDontGcUsedJob(self):
        cluster = data_service_test_base.TestCluster(
            num_workers=1, job_gc_check_interval_ms=50, job_gc_timeout_ms=20)
        num_elements = 10
        it1 = iter(
            self.make_distributed_range_dataset(num_elements,
                                                cluster,
                                                job_name="test1"))
        it2 = iter(
            self.make_distributed_range_dataset(num_elements,
                                                cluster,
                                                job_name="test2"))
        it3 = iter(  # this iterator keeps the task alive. pylint: disable=unused-variable
            self.make_distributed_range_dataset(num_elements,
                                                cluster,
                                                job_name="test2"))
        self.assertEqual(cluster.workers[0].num_tasks(), 2)
        del it1
        del it2
        # Check that only the first job is gced. The second job will not be gced
        # because there is still an outstanding iterator for it.
        while cluster.workers[0].num_tasks() > 1:
            time.sleep(0.1)
        self.assertEqual(cluster.workers[0].num_tasks(), 1)

    @combinations.generate(test_base.default_test_combinations())
    def testApplyDeterminismOption(self):
        elements = list(range(10))
        cluster = data_service_test_base.TestCluster(num_workers=1)

        def dataset_fn(delay_ms):
            def interleave_fn(x):
                ds = dataset_ops.Dataset.from_tensors(x)
                if math_ops.equal(x, 0):
                    ds = ds.apply(testing.sleep(delay_ms * 1000))
                else:
                    ds = ds.apply(testing.sleep(0))
                return ds

            ds = dataset_ops.Dataset.from_tensor_slices(elements)
            ds = ds.interleave(interleave_fn,
                               cycle_length=10,
                               num_parallel_calls=10)
            opts = dataset_ops.Options()
            opts.experimental_deterministic = False
            ds = ds.with_options(opts)
            ds = self.make_distributed_dataset(ds, cluster)
            return ds

        self.checkDeterminism(dataset_fn=dataset_fn,
                              expect_determinism=False,
                              expected_elements=elements)

    def run_stateful(self, external_state_policy):
        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements).map(
            lambda _: random_ops.random_uniform(()))

        options = dataset_ops.Options()
        options.experimental_external_state_policy = external_state_policy
        ds = ds.with_options(options)

        cluster = data_service_test_base.TestCluster(num_workers=3)
        ds = self.make_distributed_dataset(ds, cluster)
        self.getDatasetOutput(ds)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(external_state_policy=[
                distribute_options.ExternalStatePolicy.IGNORE,
                distribute_options.ExternalStatePolicy.WARN
            ])))
    def testStatefulNoError(self, external_state_policy):
        self.run_stateful(external_state_policy)

    @combinations.generate(test_base.default_test_combinations())
    def testStatefulError(self):
        with self.assertRaises(errors.FailedPreconditionError):
            self.run_stateful(distribute_options.ExternalStatePolicy.FAIL)

    @combinations.generate(test_base.default_test_combinations())
    def testDistributeFromInterleave(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        ds = dataset_ops.Dataset.range(2)

        def interleave_fn(_):
            dataset = dataset_ops.Dataset.range(2)
            self.make_distributed_dataset(dataset, cluster)
            return dataset

        ds = ds.interleave(interleave_fn, cycle_length=2)
        self.assertDatasetProduces(ds, [0, 0, 1, 1])

    @combinations.generate(test_base.default_test_combinations())
    def testDistributeNonStringAddresses(self):
        ds = dataset_ops.Dataset.range(10)
        with self.assertRaisesRegex(ValueError, "service must be a string"):
            ds = ds.apply(
                data_service_ops.distribute(processing_mode="parallel_epochs",
                                            service=1))

    @combinations.generate(test_base.default_test_combinations())
    def testDistributeEmptyAddress(self):
        ds = dataset_ops.Dataset.range(10)
        with self.assertRaisesWithLiteralMatch(ValueError,
                                               "service must not be empty"):
            ds = ds.apply(
                data_service_ops.distribute(processing_mode="parallel_epochs",
                                            service=""))

    @combinations.generate(test_base.default_test_combinations())
    def testDistributeExplicitProtocol(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        ds = dataset_ops.Dataset.range(10)
        ds = ds.apply(
            data_service_ops.distribute(processing_mode="parallel_epochs",
                                        service="grpc://" +
                                        cluster.dispatcher_address()))
        self.assertDatasetProduces(ds, list(range(10)))

    @combinations.generate(test_base.default_test_combinations())
    def testDistributeInvalidProtocol(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        ds = dataset_ops.Dataset.range(10)
        with self.assertRaisesRegex(
                errors.NotFoundError,
                "No credentials factory has been registered for protocol grp"):
            ds = ds.apply(
                data_service_ops.distribute(processing_mode="parallel_epochs",
                                            service="grp://" +
                                            cluster.dispatcher_address()))
            self.getDatasetOutput(ds)

    @combinations.generate(test_base.eager_only_combinations())
    def testDistributeInvalidProcessingMode(self):
        ds = dataset_ops.Dataset.range(10)
        with self.assertRaisesRegex(ValueError,
                                    "invalid is not a valid processing mode"):
            ds = ds.apply(
                data_service_ops.distribute(processing_mode="invalid",
                                            service="grpc://localhost:5000"))

    @combinations.generate(test_base.default_test_combinations())
    def testZipDifferentProcessingModesDatasets(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 100
        ds1 = dataset_ops.Dataset.range(num_elements)
        ds1 = self.make_distributed_dataset(
            ds1, cluster, processing_mode="distributed_epoch")
        ds2 = dataset_ops.Dataset.range(num_elements)
        ds2 = self.make_distributed_dataset(ds2,
                                            cluster,
                                            processing_mode="parallel_epochs")
        ds = dataset_ops.Dataset.zip((ds1, ds2))
        self.assertDatasetProduces(ds,
                                   list(
                                       zip(range(num_elements),
                                           range(num_elements))),
                                   assert_items_equal=True)

    @combinations.generate(test_base.default_test_combinations())
    def testZipDifferentProcessingModesDatasetsSharedJobName(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)
        num_elements = 100
        ds1 = dataset_ops.Dataset.range(num_elements)
        ds1 = self.make_distributed_dataset(
            ds1,
            cluster,
            processing_mode="distributed_epoch",
            job_name="job_name")
        ds2 = dataset_ops.Dataset.range(num_elements)
        ds2 = self.make_distributed_dataset(ds2,
                                            cluster,
                                            processing_mode="parallel_epochs",
                                            job_name="job_name")
        ds = dataset_ops.Dataset.zip((ds1, ds2))
        with self.assertRaisesRegex(errors.FailedPreconditionError,
                                    "but there is already an existing job"):
            self.getDatasetOutput(ds)

    @combinations.generate(test_base.default_test_combinations())
    def testFromDatasetId(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)

        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements)
        dataset_id = data_service_ops.register_dataset(
            cluster.dispatcher_address(), ds)
        from_dataset_id_ds = data_service_ops.from_dataset_id(
            "parallel_epochs", cluster.dispatcher_address(), dataset_id,
            ds.element_spec)
        self.assertDatasetProduces(from_dataset_id_ds,
                                   list(range(num_elements)))

    @combinations.generate(test_base.default_test_combinations())
    def testFromDatasetIdMultipleComponents(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)

        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements)
        ds = dataset_ops.Dataset.zip({"a": (ds, ds), "b": ds})
        dataset_id = data_service_ops.register_dataset(
            cluster.dispatcher_address(), ds)
        from_dataset_id_ds = data_service_ops.from_dataset_id(
            "parallel_epochs", cluster.dispatcher_address(), dataset_id,
            ds.element_spec)
        output = self.getDatasetOutput(from_dataset_id_ds)
        for i in range(num_elements):
            self.assertEqual(i, output[i]["a"][0])
            self.assertEqual(i, output[i]["a"][1])
            self.assertEqual(i, output[i]["b"])

    @combinations.generate(test_base.default_test_combinations())
    def testFromDatasetIdWrongElementSpec(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)

        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements)
        dataset_id = data_service_ops.register_dataset(
            cluster.dispatcher_address(), ds)
        wrong_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
        from_dataset_id_ds = data_service_ops.from_dataset_id(
            "parallel_epochs", cluster.dispatcher_address(), dataset_id,
            wrong_spec)
        with self.assertRaisesRegex(errors.FailedPreconditionError,
                                    "Expected a tensor of type variant"):
            self.evaluate(self.getNext(from_dataset_id_ds)())

    @combinations.generate(test_base.default_test_combinations())
    def testFromDatasetIdNotRegistered(self):
        cluster = data_service_test_base.TestCluster(num_workers=1)

        dataset_id = 0
        element_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
        from_dataset_id_ds = data_service_ops.from_dataset_id(
            "parallel_epochs", cluster.dispatcher_address(), dataset_id,
            element_spec)
        with self.assertRaisesRegex(errors.NotFoundError, "Dataset id"):
            self.evaluate(self.getNext(from_dataset_id_ds)())

    @combinations.generate(test_base.default_test_combinations())
    def testCancellation(self):
        self.skipTest("b/162521601")
        sleep_microseconds = int(1e6) * 1000

        cluster = data_service_test_base.TestCluster(num_workers=1)
        # Create a dataset which produces the first element quickly, and the second
        # element slowly. Fetching the first element triggers prefetching of the
        # second element, which we should be able to cancel.
        slow = dataset_ops.Dataset.range(1)
        slow = slow.apply(testing.sleep(sleep_microseconds))
        ds = dataset_ops.Dataset.range(1).concatenate(slow)
        ds = self.make_distributed_dataset(ds, cluster)
        ds = ds.prefetch(1)
        get_next = self.getNext(ds)
        self.assertEqual(0, self.evaluate(get_next()))
        # Without properly implemented cancellation, we will hang here while trying
        # to garbage collect the dataset iterator.

    @combinations.generate(test_base.default_test_combinations())
    def testRegisterEquivalentDatasets(self):
        ds_1 = dataset_ops.Dataset.range(10)
        ds_2 = dataset_ops.Dataset.range(10)
        cluster = data_service_test_base.TestCluster(num_workers=1)
        id_1 = data_service_ops.register_dataset(cluster.dispatcher_address(),
                                                 ds_1)
        id_2 = data_service_ops.register_dataset(cluster.dispatcher_address(),
                                                 ds_2)
        self.assertEqual(self.evaluate(id_1), self.evaluate(id_2))

    @combinations.generate(test_base.default_test_combinations())
    def testRegisterDifferentDatasets(self):
        ds_1 = dataset_ops.Dataset.range(10)
        ds_2 = dataset_ops.Dataset.range(20)
        cluster = data_service_test_base.TestCluster(num_workers=1)
        id_1 = data_service_ops.register_dataset(cluster.dispatcher_address(),
                                                 ds_1)
        id_2 = data_service_ops.register_dataset(cluster.dispatcher_address(),
                                                 ds_2)
        self.assertNotEqual(self.evaluate(id_1), self.evaluate(id_2))

    @combinations.generate(test_base.default_test_combinations())
    def testTwoLevelDistribute(self):
        cluster_1_size = 3
        cluster_1 = data_service_test_base.TestCluster(
            num_workers=cluster_1_size)
        cluster_2 = data_service_test_base.TestCluster(num_workers=1)
        num_sizes = 10
        size_repeats = 5
        strings = ["a" * i for i in range(num_sizes)] * size_repeats
        ds = dataset_ops.Dataset.from_tensor_slices(strings)
        ds = ds.shuffle(len(strings))
        ds = self.make_distributed_dataset(ds, cluster_1)
        # Large enough so that all strings of the same size are windowed together.
        window_size = cluster_1_size * size_repeats
        batch_size = size_repeats

        def key_func(x):
            return math_ops.cast(string_ops.string_length_v2(x), dtypes.int64)

        ds = ds.apply(
            grouping.group_by_window(
                key_func=key_func,
                reduce_func=lambda _, x: x.batch(batch_size),
                window_size=window_size))
        ds = self.make_distributed_dataset(ds, cluster_2)

        get_next = self.getNext(ds)
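        # Every cluster_1 worker produces a full copy of the shuffled strings,
        # so each string length should yield `cluster_1_size` identical batches.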
        for _ in range(num_sizes):
            element = self.evaluate(get_next())
            for _ in range(1, cluster_1_size):
                self.assertAllEqual(self.evaluate(get_next()), element)
        self.assertEmpty(self.getIteratorOutput(get_next))

    @combinations.generate(test_base.default_test_combinations())
    def testDistributeLargeGraph(self):
        cluster = data_service_test_base.TestCluster(num_workers=1,
                                                     work_dir=NO_WORK_DIR,
                                                     fault_tolerant_mode=False)
        # Larger than the default OSS gRPC message size limit of 4 MB.
        tensor = array_ops.ones((2, 1000, 1000), dtype=dtypes.float32)
        ds = dataset_ops.Dataset.from_tensors(tensor)
        ds = self.make_distributed_dataset(ds, cluster)
        self.assertDatasetProduces(ds, [tensor])

class DenseToSparseBatchTest(test_base.DatasetTestBase,
                             parameterized.TestCase):
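    """Tests for the `dense_to_sparse_batch` transformation."""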
    @combinations.generate(test_base.default_test_combinations())
    def testBasic(self):
        components = np.random.randint(12, size=(100, )).astype(np.int32)
        dataset = dataset_ops.Dataset.from_tensor_slices(components).map(
            lambda x: array_ops.fill([x], x)).apply(
                batching.dense_to_sparse_batch(4, [12]))
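        # dense_to_sparse_batch packs each group of up to 4 variable-length
        # rows into a single sparse tensor with dense_shape [batch_size, 12].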
        get_next = self.getNext(dataset)

        for start in range(0, len(components), 4):
            results = self.evaluate(get_next())
            self.assertAllEqual(
                [[i, j] for i, c in enumerate(components[start:start + 4])
                 for j in range(c)], results.indices)
            self.assertAllEqual(
                [c for c in components[start:start + 4] for _ in range(c)],
                results.values)
            self.assertAllEqual(
                [min(4, len(components) - start), 12], results.dense_shape)

        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())

    @combinations.generate(test_base.default_test_combinations())
    def testWithUnknownShape(self):
        components = np.random.randint(5, size=(40, )).astype(np.int32)
        dataset = dataset_ops.Dataset.from_tensor_slices(components).map(
            lambda x: array_ops.fill([x, x], x)).apply(
                batching.dense_to_sparse_batch(4, [5, None]))

        get_next = self.getNext(dataset)

        for start in range(0, len(components), 4):
            results = self.evaluate(get_next())
            self.assertAllEqual(
                [[i, j, z] for i, c in enumerate(components[start:start + 4])
                 for j in range(c) for z in range(c)], results.indices)
            self.assertAllEqual([
                c for c in components[start:start + 4] for _ in range(c)
                for _ in range(c)
            ], results.values)
            self.assertAllEqual([
                min(4, len(components) - start), 5,
                np.max(components[start:start + 4])
            ], results.dense_shape)

        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())

    @combinations.generate(test_base.default_test_combinations())
    def testWithInvalidShape(self):
        input_tensor = array_ops.constant([[1]])
        with self.assertRaisesRegex(ValueError, "Dimension -2 must be >= 0"):
            dataset_ops.Dataset.from_tensors(input_tensor).apply(
                batching.dense_to_sparse_batch(4, [-2]))

    @combinations.generate(test_base.default_test_combinations())
    def testShapeErrors(self):
        def dataset_fn(input_tensor):
            return dataset_ops.Dataset.from_tensors(input_tensor).apply(
                batching.dense_to_sparse_batch(4, [12]))

        # Initialize with an input tensor of incompatible rank.
        get_next = self.getNext(dataset_fn([[1]]))
        with self.assertRaisesRegex(errors.InvalidArgumentError,
                                    "incompatible with the row shape"):
            self.evaluate(get_next())

        # Initialize with an input tensor that is larger than `row_shape`.
        get_next = self.getNext(dataset_fn(np.int32(range(13))))
        with self.assertRaisesRegex(errors.DataLossError,
                                    "larger than the row shape"):
            self.evaluate(get_next())
Example #25
0
class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
                          parameterized.TestCase):
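    """Tests for the experimental `snapshot` transformation."""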
    def setUp(self):
        super(SnapshotDatasetTest, self).setUp()
        self.removeTFRecords()

    def removeTFRecords(self):
        for filename in self.test_filenames:
            os.remove(filename)
        self.test_filenames = []

    def setUpTFRecord(self, num_files=10, num_records=10):
        self._num_files = num_files
        self._num_records = num_records
        self.test_filenames = self._createFiles()

    def makeSnapshotDirectory(self):
        tmpdir = self.get_temp_dir()
        tmpdir = os.path.join(tmpdir, "snapshot")
        os.mkdir(tmpdir)
        return tmpdir

    def assertSnapshotDirectoryContains(self, directory, num_fingerprints,
                                        num_runs_per_fp, num_snapshot_files):
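        # Expected layout: <directory>/<fingerprint>/{run dirs..., snapshot.metadata},
        # where each run directory holds sequentially numbered *.snapshot files.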
        dirlist = os.listdir(directory)
        self.assertLen(dirlist, num_fingerprints)

        for i in range(num_fingerprints):
            fingerprint_dir = os.path.join(directory, dirlist[i])
            fingerprint_dir_list = sorted(os.listdir(fingerprint_dir))
            self.assertLen(fingerprint_dir_list, num_runs_per_fp + 1)
            self.assertEqual(fingerprint_dir_list[num_runs_per_fp],
                             "snapshot.metadata")

            for j in range(num_runs_per_fp):
                run_dir = os.path.join(fingerprint_dir,
                                       fingerprint_dir_list[j])
                run_dirlist = sorted(os.listdir(run_dir))
                self.assertLen(run_dirlist, num_snapshot_files)

                file_counter = 0
                for filename in run_dirlist:
                    self.assertEqual(filename, "%08d.snapshot" % file_counter)
                    file_counter += 1

    @combinations.generate(test_base.default_test_combinations())
    def testWriteDifferentPipelinesInOneDirectory(self):
        tmpdir = self.makeSnapshotDirectory()

        dataset = dataset_ops.Dataset.range(1000)
        dataset = dataset.apply(snapshot.snapshot(tmpdir))
        self.assertDatasetProduces(dataset, list(range(1000)))

        dataset = dataset_ops.Dataset.range(1001)
        dataset = dataset.apply(snapshot.snapshot(tmpdir))
        self.assertDatasetProduces(dataset, list(range(1001)))

        self.assertSnapshotDirectoryContains(tmpdir, 2, 1, 1)

    @combinations.generate(test_base.default_test_combinations())
    def testWriteSnapshotMultipleSimultaneous(self):
        tmpdir = self.makeSnapshotDirectory()

        dataset1 = dataset_ops.Dataset.range(1000)
        dataset1 = dataset1.apply(snapshot.snapshot(tmpdir))
        next1 = self.getNext(dataset1)

        dataset2 = dataset_ops.Dataset.range(1000)
        dataset2 = dataset2.apply(snapshot.snapshot(tmpdir))
        next2 = self.getNext(dataset2)

        for _ in range(1000):
            self.evaluate(next1())
            self.evaluate(next2())

        # We check that only one copy of the metadata has been written, and
        # that the pipeline that lost the race falls back to passthrough mode.
        self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)

    @combinations.generate(test_base.default_test_combinations())
    def testGetNextCreatesDir(self):
        tmpdir = self.makeSnapshotDirectory()

        # We create two iterators but call getNext on only one.
        dataset1 = dataset_ops.Dataset.range(1000)
        dataset1 = dataset1.apply(snapshot.snapshot(tmpdir))
        next1 = self.getNext(dataset1)

        dataset2 = dataset_ops.Dataset.range(1001)
        dataset2 = dataset2.apply(snapshot.snapshot(tmpdir))
        _ = self.getNext(dataset2)

        for _ in range(1000):
            self.evaluate(next1())

        # We check that only one directory is created.
        self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(compression=[
                snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP
            ])))
    def testWriteSnapshotSimpleSuccessful(self, compression):
        tmpdir = self.makeSnapshotDirectory()

        dataset = dataset_ops.Dataset.range(1000)
        dataset = dataset.apply(
            snapshot.snapshot(tmpdir, compression=compression))
        self.assertDatasetProduces(dataset, list(range(1000)))

        self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)

    @combinations.generate(test_base.default_test_combinations())
    def testWriteSnapshotRepeatAfterwards(self):
        tmpdir = self.makeSnapshotDirectory()

        dataset = dataset_ops.Dataset.range(10)
        dataset = dataset.apply(snapshot.snapshot(tmpdir))
        dataset = dataset.repeat(10)
        self.assertDatasetProduces(dataset, list(range(10)) * 10)

        self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(compression=[
                snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP
            ])))
    def testReadSnapshotBackAfterWrite(self, compression):
        self.setUpTFRecord()
        filenames = self.test_filenames

        expected = [
            b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
            for f in range(0, 10) for r in range(0, 10)
        ]

        tmpdir = self.makeSnapshotDirectory()
        dataset = core_readers._TFRecordDataset(filenames)
        dataset = dataset.apply(
            snapshot.snapshot(tmpdir, compression=compression))
        self.assertDatasetProduces(dataset, expected)

        # Remove the original files and try to read the data back only from the snapshot.
        self.removeTFRecords()

        dataset2 = core_readers._TFRecordDataset(filenames)
        dataset2 = dataset2.apply(
            snapshot.snapshot(tmpdir, compression=compression))
        self.assertDatasetProduces(dataset2, expected)

    @combinations.generate(test_base.default_test_combinations())
    def testReadShuffledSnapshotAfterWrite(self):
        self.setUpTFRecord(num_files=10, num_records=50)
        filenames = self.test_filenames

        expected = [
            b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
            for f in range(0, 10) for r in range(0, 50)
        ]

        tmpdir = self.makeSnapshotDirectory()
        dataset = core_readers._TFRecordDataset(filenames)
        dataset = dataset.apply(snapshot.snapshot(tmpdir, shard_size_bytes=10))
        self.assertDatasetProduces(dataset, expected)

        # Remove the original files and try to read the data back only from the snapshot.
        self.removeTFRecords()

        dataset2 = core_readers._TFRecordDataset(filenames)
        dataset2 = dataset2.apply(
            snapshot.snapshot(tmpdir, shuffle_on_read=True))
        next2 = self.getNext(dataset2)

        res1 = self.evaluate(next2())
        res2 = self.evaluate(next2())
        res3 = self.evaluate(next2())
        res4 = self.evaluate(next2())
        res5 = self.evaluate(next2())

        # Make sure that we don't read the records back in their original order.
        self.assertNotEqual([res1, res2, res3, res4, res5], expected[0:5])

        # Make sure all the elements are still there.
        dataset3 = core_readers._TFRecordDataset(filenames)
        dataset3 = dataset3.apply(
            snapshot.snapshot(tmpdir, shuffle_on_read=True))
        self.assertDatasetProduces(dataset3, expected, assert_items_equal=True)

    @combinations.generate(test_base.default_test_combinations())
    def testReadSnapshotParallelAfterWrite(self):
        self.setUpTFRecord(10, 4000)
        filenames = self.test_filenames

        expected = [
            b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
            for f in range(0, 10) for r in range(0, 4000)
        ]

        tmpdir = self.makeSnapshotDirectory()
        dataset = core_readers._TFRecordDataset(filenames)
        dataset = dataset.apply(
            snapshot.snapshot(tmpdir,
                              shard_size_bytes=1024 * 1024,
                              num_reader_threads=2,
                              reader_buffer_size=10))
        self.assertDatasetProduces(dataset, expected, assert_items_equal=True)

        # Remove the original files and try to read the data back only from the
        # snapshot.
        self.removeTFRecords()

        dataset2 = core_readers._TFRecordDataset(filenames)
        dataset2 = dataset2.apply(
            snapshot.snapshot(tmpdir,
                              shard_size_bytes=1024 * 1024,
                              num_reader_threads=2,
                              reader_buffer_size=10))
        self.assertDatasetProduces(dataset2, expected, assert_items_equal=True)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.times(
                combinations.combine(compression=[
                    snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP
                ]),
                combinations.combine(threads=2, size=[1, 2]) +
                combinations.combine(threads=8, size=[1, 4, 8]))))
    def testReadSnapshotBackAfterMultiThreadedWrite(self, compression, threads,
                                                    size):
        self.setUpTFRecord()
        filenames = self.test_filenames

        expected = [
            b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
            for f in range(0, 10) for r in range(0, 10)
        ]

        tmpdir = self.makeSnapshotDirectory()
        dataset = core_readers._TFRecordDataset(filenames)
        dataset = dataset.apply(
            snapshot.snapshot(tmpdir,
                              compression=compression,
                              num_writer_threads=threads,
                              writer_buffer_size=size))
        self.assertDatasetProduces(dataset, expected)

        # Remove the original files and try to read the data back only from the
        # snapshot.
        self.removeTFRecords()

        dataset2 = core_readers._TFRecordDataset(filenames)
        dataset2 = dataset2.apply(
            snapshot.snapshot(tmpdir, compression=compression))
        self.assertDatasetProduces(dataset2, expected, assert_items_equal=True)

    @combinations.generate(test_base.default_test_combinations())
    def testSameFingerprintWithDifferentInitializationOrder(self):
        tmpdir = self.makeSnapshotDirectory()

        dataset1 = dataset_ops.Dataset.range(0, 100)
        dataset2 = dataset_ops.Dataset.range(100, 200)
        dataset3 = dataset_ops.Dataset.range(200, 300)

        dataset = dataset1.concatenate(dataset2).concatenate(dataset3)
        dataset = dataset.apply(snapshot.snapshot(tmpdir))
        self.assertDatasetProduces(dataset, list(range(300)))

        dataset4 = dataset_ops.Dataset.range(200, 300)
        dataset5 = dataset_ops.Dataset.range(100, 200)
        dataset6 = dataset_ops.Dataset.range(0, 100)

        dataset = dataset6.concatenate(dataset5).concatenate(dataset4)
        dataset = dataset.apply(snapshot.snapshot(tmpdir))
        self.assertDatasetProduces(dataset, list(range(300)))

        self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)

    @combinations.generate(test_base.default_test_combinations())
    def testExpiredSnapshotRewrite(self):
        tmpdir = self.makeSnapshotDirectory()

        dataset1 = dataset_ops.Dataset.range(1000)
        dataset1 = dataset1.apply(
            snapshot.snapshot(tmpdir, pending_snapshot_expiry_seconds=1))
        next1 = self.getNext(dataset1)

        # Don't finish reading dataset1, so that it is never finalized.
        for _ in range(500):
            self.evaluate(next1())
        self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)

        time.sleep(2)

        # We create dataset2 only after running through dataset1 because, in
        # eager mode, the snapshot state is determined immediately upon dataset
        # creation. We only want to determine the snapshot state for dataset2
        # after the first snapshot has expired.
        dataset2 = dataset_ops.Dataset.range(1000)
        dataset2 = dataset2.apply(
            snapshot.snapshot(tmpdir, pending_snapshot_expiry_seconds=1))
        next2 = self.getNext(dataset2)

        for _ in range(500):
            self.evaluate(next2())
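        # Once the first pending snapshot expires, dataset2 starts a second run
        # under the same fingerprint.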
        self.assertSnapshotDirectoryContains(tmpdir, 1, 2, 1)

    @combinations.generate(test_base.default_test_combinations())
    def testSpecifyShardSize(self):
        tmpdir = self.makeSnapshotDirectory()

        dataset = dataset_ops.Dataset.from_tensor_slices([1.0])
        dataset = dataset.map(
            lambda x: gen_array_ops.broadcast_to(x, [1024, 1024]))
        dataset = dataset.repeat(10)
        dataset = dataset.apply(
            snapshot.snapshot(tmpdir, shard_size_bytes=10 * 1024 * 1024))
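        # Ten 1024x1024 float32 elements (~4MB each) written with a 10MB shard
        # size should produce 4 shard files.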
        next_fn = self.getNext(dataset)

        for _ in range(10):
            self.evaluate(next_fn())

        self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 4)

    @combinations.generate(test_base.default_test_combinations())
    def testAdditionalOperationsAfterReadBack(self):
        self.setUpTFRecord()
        filenames = self.test_filenames

        expected = [
            b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
            for f in range(0, 10) for r in range(0, 10)
        ]

        tmpdir = self.makeSnapshotDirectory()
        dataset = core_readers._TFRecordDataset(filenames)
        dataset = dataset.apply(snapshot.snapshot(tmpdir))
        self.assertDatasetProduces(dataset, expected)

        # Remove the original files and try to read the data back only from the snapshot.
        self.removeTFRecords()

        dataset2 = core_readers._TFRecordDataset(filenames)
        dataset2 = dataset2.apply(snapshot.snapshot(tmpdir))
        self.assertDatasetProduces(dataset2, expected)

        expected_after = [
            b"cord %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
            for f in range(0, 10) for r in range(0, 10)
        ]
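        # `substr_v2(x, 2, 1000)` strips the leading "Re" from each record,
        # hence the "cord ..." strings above.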

        dataset3 = core_readers._TFRecordDataset(filenames)
        dataset3 = dataset3.apply(snapshot.snapshot(tmpdir))
        dataset3 = dataset3.map(lambda x: string_ops.substr_v2(x, 2, 1000))
        self.assertDatasetProduces(dataset3, expected_after)
Example #26
0
class AutoShardTest(data_service_test_base.TestBase,
                    tf_record_test_base.TFRecordTestBase,
                    parameterized.TestCase):
    """Tests auto-sharding datasets with tf.data service."""
    def setUp(self):
        super(AutoShardTest, self).setUp()
        self._num_files = 10
        self._num_records = 10
        self._filenames = self._createFiles()

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(sharding_policy=[
                ShardingPolicy.DATA, ShardingPolicy.FILE_OR_DATA
            ])))
    def testRangeDataset_AutoShard(self, sharding_policy):
        cluster = _make_service_cluster(num_workers=5, local_shard_index=1)
        dataset = dataset_ops.Dataset.range(20)
        dataset = self.make_distributed_dataset(
            dataset, cluster=cluster, processing_mode=sharding_policy)
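        # With 5 workers and data-based sharding, the worker at shard index 1
        # receives elements 1, 6, 11, 16 of range(20).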
        self.assertDatasetProduces(dataset, [1, 6, 11, 16])

    @combinations.generate(test_base.default_test_combinations())
    def testRangeDataset_FileShard(self):
        cluster = _make_service_cluster(num_workers=5, local_shard_index=1)
        dataset = dataset_ops.Dataset.range(20)
        dataset = self.make_distributed_dataset(
            dataset, cluster=cluster, processing_mode=ShardingPolicy.FILE)
        with self.assertRaisesRegex(errors.NotFoundError,
                                    "Found an unshardable source dataset"):
            self.getDatasetOutput(dataset)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(worker_index=[distribute.SHARD_HINT, 0, 5])))
    def testRangeDataset_ShardHint(self, worker_index):
        cluster = _make_service_cluster(num_workers=5, local_shard_index=1)
        dataset = dataset_ops.Dataset.range(20)
        # With HINT sharding, `num_shards` should be `SHARD_HINT`; `index` can be
        # any value.
        dataset = dataset.shard(num_shards=distribute.SHARD_HINT,
                                index=worker_index)
        dataset = self.make_distributed_dataset(
            dataset, cluster=cluster, processing_mode=ShardingPolicy.HINT)
        self.assertDatasetProduces(dataset, [1, 6, 11, 16])

    @combinations.generate(test_base.default_test_combinations())
    def testRangeDataset_InvalidWorkerIndexUsingShardHint(self):
        cluster = _make_service_cluster(num_workers=5, local_shard_index=1)
        dataset = dataset_ops.Dataset.range(20)
        # With HINT sharding, `SHARD_HINT` should be passed to `num_shards`, not
        # `index`.
        with self.assertRaisesRegex(
                errors.InvalidArgumentError,
                r"Index must be between 0 and 4 \(currently index = -1\)."):
            dataset = dataset.shard(num_shards=5, index=distribute.SHARD_HINT)
            dataset = self.make_distributed_dataset(
                dataset, cluster=cluster, processing_mode=ShardingPolicy.HINT)
            self.getDatasetOutput(dataset)

    @combinations.generate(test_base.default_test_combinations())
    def testRangeDataset_NoShardHint(self):
        cluster = _make_service_cluster(num_workers=5, local_shard_index=1)
        dataset = dataset_ops.Dataset.range(20)
        # No SHARD_HINT is provided. The given sharding arguments will be used.
        dataset = dataset.shard(num_shards=1, index=0)
        dataset = self.make_distributed_dataset(
            dataset, cluster=cluster, processing_mode=ShardingPolicy.HINT)
        self.assertDatasetProduces(dataset, list(range(20)))

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(sharding_policy=[
                ShardingPolicy.OFF, ShardingPolicy.FILE_OR_DATA
            ])))
    def testRangeDataset_ShardHintUsedInWrongShardingPolicy(
            self, sharding_policy):
        cluster = _make_service_cluster(num_workers=5, local_shard_index=1)
        dataset = dataset_ops.Dataset.range(20)
        dataset = dataset.shard(distribute.SHARD_HINT, distribute.SHARD_HINT)
        dataset = self.make_distributed_dataset(
            dataset, cluster=cluster, processing_mode=sharding_policy)
        with self.assertRaisesRegex(
                errors.FailedPreconditionError, "tf.data service with "
                "`tf.data.experimental.service.ShardingPolicy.HINT` processing mode."
        ):
            self.getDatasetOutput(dataset)

    @combinations.generate(test_base.default_test_combinations())
    def testRangeDataset_NoShard(self):
        cluster = _make_service_cluster(num_workers=5, local_shard_index=1)
        dataset = dataset_ops.Dataset.range(20)
        dataset = self.make_distributed_dataset(
            dataset,
            cluster=cluster,
            processing_mode=ShardingPolicy.OFF,
            target_workers="LOCAL")
        self.assertDatasetProduces(dataset, list(range(20)))

    @combinations.generate(test_base.default_test_combinations())
    def testRangeDataset_OneWorker(self):
        """Makes sure shards from all workers form the complete dataset."""
        cluster = _make_service_cluster(num_workers=1, local_shard_index=0)
        dataset = dataset_ops.Dataset.range(20)
        dataset = self.make_distributed_dataset(
            dataset,
            cluster=cluster,
            processing_mode=ShardingPolicy.FILE_OR_DATA)
        self.assertDatasetProduces(dataset, list(range(20)))

    @combinations.generate(test_base.default_test_combinations())
    def testRangeDataset_ReadFromAllWorkers(self):
        """Makes sure shards from all workers form the complete dataset."""
        cluster = _make_service_cluster(num_workers=5, local_shard_index=1)
        dataset = dataset_ops.Dataset.range(20)
        dataset = self.make_distributed_dataset(
            dataset,
            cluster=cluster,
            processing_mode=ShardingPolicy.FILE_OR_DATA,
            target_workers="ANY")
        with self.assertRaisesRegex(
                errors.InvalidArgumentError,
                "Static sharding requires reading from local workers"):
            self.getDatasetOutput(dataset)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(sharding_policy=[
                ShardingPolicy.FILE_OR_DATA, ShardingPolicy.FILE
            ])))
    def testTFRecordDataset_AutoShard(self, sharding_policy):
        cluster = _make_service_cluster(num_workers=5, local_shard_index=3)
        dataset = dataset_ops.Dataset.list_files(self._filenames,
                                                 shuffle=False)
        dataset = dataset.flat_map(readers.TFRecordDataset)
        dataset = self.make_distributed_dataset(
            dataset,
            cluster=cluster,
            processing_mode=sharding_policy,
            target_workers="LOCAL")

        expected = [
            b"Record %d of file %d" % (record, file) for file in (3, 8)
            for record in range(0, 10)
        ]
        self.assertDatasetProduces(dataset, expected)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(sharding_policy=[
                ShardingPolicy.FILE_OR_DATA, ShardingPolicy.FILE
            ])))
    def testTFRecordDataset_ShuffleFileList(self, sharding_policy):
        cluster = _make_service_cluster(num_workers=5, local_shard_index=3)
        dataset = dataset_ops.Dataset.list_files(self._filenames, shuffle=True)
        dataset = dataset.flat_map(readers.TFRecordDataset)
        dataset = self.make_distributed_dataset(
            dataset, cluster=cluster, processing_mode=sharding_policy)

        expected = [
            b"Record %d of file %d" % (record, file) for file in (3, 8)
            for record in range(0, 10)
        ]
        self.assertDatasetProduces(dataset, expected, assert_items_equal=True)

    @combinations.generate(test_base.default_test_combinations())
    def testTFRecordDataset_DataShard(self):
        cluster = _make_service_cluster(num_workers=5, local_shard_index=3)
        dataset = dataset_ops.Dataset.list_files(self._filenames,
                                                 shuffle=False)
        dataset = dataset.flat_map(readers.TFRecordDataset)
        dataset = self.make_distributed_dataset(
            dataset, cluster=cluster, processing_mode=ShardingPolicy.DATA)
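        # Data-based sharding gives the worker at index 3 every fifth record of
        # each file, i.e. records 3 and 8.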

        expected = [
            b"Record %d of file %d" % (record, file) for file in range(0, 10)
            for record in (3, 8)
        ]
        self.assertDatasetProduces(dataset, expected)

    @combinations.generate(test_base.default_test_combinations())
    def testTFRecordDataset_HintDataShard(self):
        cluster = _make_service_cluster(num_workers=5, local_shard_index=3)
        dataset = dataset_ops.Dataset.list_files(self._filenames,
                                                 shuffle=False)
        dataset = dataset.flat_map(readers.TFRecordDataset)
        dataset = dataset.shard(distribute.SHARD_HINT, distribute.SHARD_HINT)
        dataset = self.make_distributed_dataset(
            dataset, cluster=cluster, processing_mode=ShardingPolicy.HINT)

        expected = [
            b"Record %d of file %d" % (record, file) for file in range(0, 10)
            for record in (3, 8)
        ]
        self.assertDatasetProduces(dataset, expected)

    @combinations.generate(test_base.default_test_combinations())
    def testTFRecordDataset_HintFileShard(self):
        cluster = _make_service_cluster(num_workers=5, local_shard_index=3)
        dataset = dataset_ops.Dataset.list_files(self._filenames,
                                                 shuffle=False)
        dataset = dataset.shard(distribute.SHARD_HINT, distribute.SHARD_HINT)
        dataset = dataset.flat_map(readers.TFRecordDataset)
        dataset = self.make_distributed_dataset(
            dataset, cluster=cluster, processing_mode=ShardingPolicy.HINT)

        expected = [
            b"Record %d of file %d" % (record, file) for file in (3, 8)
            for record in range(0, 10)
        ]
        self.assertDatasetProduces(dataset, expected)

    @combinations.generate(test_base.default_test_combinations())
    def testTFRecordDataset_NoShard(self):
        cluster = _make_service_cluster(num_workers=5, local_shard_index=3)
        dataset = dataset_ops.Dataset.list_files(self._filenames,
                                                 shuffle=False)
        dataset = dataset.flat_map(readers.TFRecordDataset)
        dataset = self.make_distributed_dataset(
            dataset,
            cluster=cluster,
            processing_mode=ShardingPolicy.OFF,
            target_workers="LOCAL")

        expected = [
            b"Record %d of file %d" % (record, file) for file in range(0, 10)
            for record in range(0, 10)
        ]
        self.assertDatasetProduces(dataset, expected)

    @combinations.generate(test_base.default_test_combinations())
    def testTFRecordDataset_ReadFromAllWorkers(self):
        """Makes sure shards from all workers form the complete dataset."""
        cluster = _make_service_cluster(num_workers=5, local_shard_index=3)
        dataset = dataset_ops.Dataset.list_files(self._filenames,
                                                 shuffle=False)
        dataset = dataset.flat_map(readers.TFRecordDataset)
        dataset = self.make_distributed_dataset(
            dataset,
            cluster=cluster,
            processing_mode=ShardingPolicy.FILE_OR_DATA,
            target_workers="ANY")
        with self.assertRaisesRegex(
                errors.InvalidArgumentError,
                "Static sharding requires reading from local workers"):
            self.getDatasetOutput(dataset)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(sharding_policy=[
                ShardingPolicy.FILE_OR_DATA, ShardingPolicy.FILE
            ])))
    def testTFRecordDataset_FewerFilesThanWorkers(self, sharding_policy):
        cluster = _make_service_cluster(num_workers=5, local_shard_index=3)
        dataset = dataset_ops.Dataset.list_files(self._filenames[:4],
                                                 shuffle=False)
        dataset = dataset.flat_map(readers.TFRecordDataset)
        dataset = self.make_distributed_dataset(
            dataset, cluster=cluster, processing_mode=sharding_policy)

        with self.assertRaisesRegex(
                errors.InvalidArgumentError,
                "not enough for the required 5 shards/workers."):
            self.getDatasetOutput(dataset)

    @combinations.generate(test_base.default_test_combinations())
    def testTFRecordDataset_FewerFilesThanWorkers_HintShard(self):
        cluster = _make_service_cluster(num_workers=5, local_shard_index=3)
        dataset = dataset_ops.Dataset.list_files(self._filenames[:4],
                                                 shuffle=False)
        dataset = dataset.shard(distribute.SHARD_HINT, distribute.SHARD_HINT)
        dataset = dataset.flat_map(readers.TFRecordDataset)
        dataset = self.make_distributed_dataset(
            dataset, cluster=cluster, processing_mode=ShardingPolicy.HINT)

        with self.assertRaisesRegex(
                errors.InvalidArgumentError,
                "not enough for the required 5 shards/workers."):
            self.getDatasetOutput(dataset)

    @combinations.generate(test_base.default_test_combinations())
    def testTFRecordDataset_FewerFilesThanWorkers_DataShard(self):
        cluster = _make_service_cluster(num_workers=5, local_shard_index=3)
        dataset = dataset_ops.Dataset.list_files(self._filenames[:4],
                                                 shuffle=False)
        dataset = dataset.flat_map(readers.TFRecordDataset)
        dataset = self.make_distributed_dataset(
            dataset, cluster=cluster, processing_mode=ShardingPolicy.DATA)

        expected = [
            b"Record %d of file %d" % (record, file) for file in range(0, 4)
            for record in (3, 8)
        ]
        self.assertDatasetProduces(dataset, expected, assert_items_equal=True)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(sharding_policy=[
                ShardingPolicy.FILE_OR_DATA, ShardingPolicy.DATA
            ])))
    def testBatchDataset(self, sharding_policy):
        cluster = _make_service_cluster(num_workers=5, local_shard_index=1)
        dataset = dataset_ops.Dataset.range(20)
        dataset = dataset.batch(batch_size=3, drop_remainder=False)
        dataset = self.make_distributed_dataset(
            dataset, cluster=cluster, processing_mode=sharding_policy)
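        # range(20).batch(3) produces 7 batches; the worker at shard index 1 of
        # 5 receives batches 1 and 6: [3, 4, 5] and [18, 19].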
        self.assertDatasetProduces(dataset, [[3, 4, 5], [18, 19]])

    @combinations.generate(test_base.default_test_combinations())
    def testInterleaveDataset(self):
        cluster = _make_service_cluster(num_workers=5, local_shard_index=3)
        dataset = dataset_ops.Dataset.list_files(self._filenames,
                                                 shuffle=False)
        dataset = dataset.interleave(readers.TFRecordDataset,
                                     cycle_length=10,
                                     num_parallel_calls=dataset_ops.AUTOTUNE)
        dataset = dataset.prefetch(buffer_size=dataset_ops.AUTOTUNE)
        dataset = self.make_distributed_dataset(
            dataset,
            cluster=cluster,
            processing_mode=ShardingPolicy.FILE_OR_DATA)
        dataset = dataset.prefetch(buffer_size=dataset_ops.AUTOTUNE)

        expected = [
            b"Record %d of file %d" % (record, file)
            for record in range(0, 10) for file in (3, 8)
        ]
        self.assertDatasetProduces(dataset, expected)

    @combinations.generate(test_base.default_test_combinations())
    def testZipDataset(self):
        cluster = _make_service_cluster(num_workers=5, local_shard_index=3)
        dataset1 = dataset_ops.Dataset.list_files(self._filenames,
                                                  shuffle=False)
        dataset1 = dataset1.interleave(readers.TFRecordDataset,
                                       cycle_length=10,
                                       num_parallel_calls=dataset_ops.AUTOTUNE)
        dataset2 = dataset_ops.Dataset.list_files(self._filenames,
                                                  shuffle=False)
        dataset2 = dataset2.interleave(readers.TFRecordDataset,
                                       cycle_length=10,
                                       num_parallel_calls=dataset_ops.AUTOTUNE)
        dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
        dataset = dataset.prefetch(buffer_size=dataset_ops.AUTOTUNE)
        dataset = self.make_distributed_dataset(
            dataset,
            cluster=cluster,
            processing_mode=ShardingPolicy.FILE_OR_DATA)

        expected = [(b"Record %d of file %d" % (record, file),
                     b"Record %d of file %d" % (record, file))
                    for record in range(0, 10) for file in (3, 8)]
        self.assertDatasetProduces(dataset, expected)

    @combinations.generate(test_base.default_test_combinations())
    def testConcatenateDataset(self):
        cluster = _make_service_cluster(num_workers=5, local_shard_index=3)
        dataset1 = dataset_ops.Dataset.list_files(self._filenames,
                                                  shuffle=False)
        dataset1 = dataset1.interleave(readers.TFRecordDataset,
                                       cycle_length=10,
                                       num_parallel_calls=dataset_ops.AUTOTUNE)
        dataset2 = dataset_ops.Dataset.list_files(self._filenames,
                                                  shuffle=False)
        dataset2 = dataset2.interleave(readers.TFRecordDataset,
                                       cycle_length=10,
                                       num_parallel_calls=dataset_ops.AUTOTUNE)
        dataset = dataset1.concatenate(dataset2)
        dataset = dataset.prefetch(buffer_size=dataset_ops.AUTOTUNE)
        dataset = self.make_distributed_dataset(
            dataset,
            cluster=cluster,
            processing_mode=ShardingPolicy.FILE_OR_DATA)

        expected = [
            b"Record %d of file %d" % (record, file)
            for record in range(0, 10) for file in (3, 8)
        ]
        expected += expected
        self.assertDatasetProduces(dataset, expected)

    @combinations.generate(test_base.default_test_combinations())
    def testEmptyDataset(self):
        cluster = _make_service_cluster(num_workers=5, local_shard_index=3)
        dataset = dataset_ops.Dataset.range(0)
        dataset = self.make_distributed_dataset(
            dataset,
            cluster=cluster,
            processing_mode=ShardingPolicy.FILE_OR_DATA)
        self.assertDatasetProduces(dataset, [])

    @combinations.generate(test_base.default_test_combinations())
    def testAnonymousPorts(self):
        cluster = _make_service_cluster(
            num_workers=5,
            local_shard_index=3,
            worker_addresses=["localhost:%port%" for _ in range(5)])
        dataset = dataset_ops.Dataset.range(20)
        dataset = self.make_distributed_dataset(
            dataset,
            cluster=cluster,
            processing_mode=ShardingPolicy.FILE_OR_DATA)
        self.assertDatasetProduces(dataset, [3, 8, 13, 18])

    @combinations.generate(test_base.default_test_combinations())
    def testNamedPorts(self):
        cluster = _make_service_cluster(
            num_workers=5,
            local_shard_index=3,
            worker_addresses=["localhost:%port_worker%" for _ in range(5)])
        dataset = dataset_ops.Dataset.range(20)
        dataset = self.make_distributed_dataset(
            dataset,
            cluster=cluster,
            processing_mode=ShardingPolicy.FILE_OR_DATA)
        self.assertDatasetProduces(dataset, [3, 8, 13, 18])

    @combinations.generate(test_base.default_test_combinations())
    def testInvalidPorts(self):
        with self.assertRaisesRegex(RuntimeError,
                                    "The worker's address is not configured"):
            _ = _make_service_cluster(
                num_workers=5,
                local_shard_index=0,
                worker_addresses=["localhost:worker" for _ in range(5)])

    @combinations.generate(test_base.default_test_combinations())
    def testEmptyWorkerList(self):
        cluster = _make_service_cluster(num_workers=5,
                                        local_shard_index=1,
                                        worker_addresses=[])
        dataset = dataset_ops.Dataset.range(20)
        dataset = self.make_distributed_dataset(
            dataset,
            cluster=cluster,
            processing_mode=ShardingPolicy.FILE_OR_DATA)
        with self.assertRaisesRegex(errors.NotFoundError,
                                    "Worker .* is not in the workers list."):
            self.getDatasetOutput(dataset)

    @combinations.generate(test_base.default_test_combinations())
    def testWorkerNotFound(self):
        worker_addresses = [f"fake_worker_{i}" for i in range(5)]
        with self.assertRaisesRegex(RuntimeError,
                                    "The worker's address is not configured"):
            _ = _make_service_cluster(num_workers=5,
                                      local_shard_index=0,
                                      worker_addresses=worker_addresses)

    @combinations.generate(test_base.default_test_combinations())
    def testMoreWorkersThanConfigured(self):
        worker_addresses = ["localhost:%port%"]
        with self.assertRaisesRegex(
                RuntimeError,
                "other workers are already running at the configured host"):
            _ = _make_service_cluster(num_workers=5,
                                      local_shard_index=1,
                                      worker_addresses=worker_addresses)

    @combinations.generate(test_base.default_test_combinations())
    def testNoLocalWorkers(self):
        cluster = multi_process_cluster.MultiProcessCluster(
            num_local_workers=0, num_remote_workers=3)
        dataset = dataset_ops.Dataset.list_files(self._filenames,
                                                 shuffle=False)
        dataset = dataset.flat_map(readers.TFRecordDataset)
        dataset = self.make_distributed_dataset(
            dataset,
            cluster=cluster,
            processing_mode=ShardingPolicy.FILE_OR_DATA)
        with self.assertRaisesRegex(
                errors.InvalidArgumentError,
                "Local reads or static sharding require local tf.data workers"
        ):
            self.getDatasetOutput(dataset)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(sharding_policy=list(ShardingPolicy))))
    def testEnumerateShardingPolicies(self, sharding_policy):
        """Verifies tf.data service handles every sharding policy with no errors."""
        cluster = _make_service_cluster(num_workers=5, local_shard_index=3)
        dataset = dataset_ops.Dataset.list_files(self._filenames,
                                                 shuffle=False)
        dataset = dataset.flat_map(readers.TFRecordDataset)
        dataset = self.make_distributed_dataset(
            dataset, cluster=cluster, processing_mode=sharding_policy)
        self.getDatasetOutput(dataset)
Example #27
0
class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
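    """Tests for creating and consuming tf.data iterators."""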
    @combinations.generate(test_base.graph_only_combinations())
    def testNoGradients(self):
        component = constant_op.constant([1.])
        side = constant_op.constant(0.)
        add = lambda x: x + side
        dataset = dataset_ops.Dataset.from_tensor_slices(component).map(add)
        value = dataset_ops.make_one_shot_iterator(dataset).get_next()
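        # The iterator's get_next outputs are not differentiable, so every
        # gradient should be None.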
        self.assertIsNone(gradients_impl.gradients(value, component)[0])
        self.assertIsNone(gradients_impl.gradients(value, side)[0])
        self.assertIsNone(
            gradients_impl.gradients(value, [component, side])[0])

    @combinations.generate(test_base.graph_only_combinations())
    def testCapturingStateInOneShotRaisesException(self):
        var = variables.Variable(37.0, name="myvar")
        dataset = (dataset_ops.Dataset.from_tensor_slices(
            [0.0, 1.0, 2.0]).map(lambda x: x + var))
        with self.assertRaisesRegex(
                ValueError,
                r"`Dataset.make_one_shot_iterator\(\)` does not support "
                "datasets that capture stateful objects.+myvar"):
            dataset_ops.make_one_shot_iterator(dataset)

    @combinations.generate(test_base.graph_only_combinations())
    def testOneShotIterator(self):
        components = (np.arange(7),
                      np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                      np.array(37.0) * np.arange(7))

        def _map_fn(x, y, z):
            return math_ops.square(x), math_ops.square(y), math_ops.square(z)

        iterator = dataset_ops.make_one_shot_iterator(
            dataset_ops.Dataset.from_tensor_slices(components).map(
                _map_fn).repeat(14))
        get_next = iterator.get_next()

        self.assertEqual([c.shape[1:] for c in components],
                         [t.shape for t in get_next])

        with self.cached_session() as sess:
            for _ in range(14):
                for i in range(7):
                    result = sess.run(get_next)
                    for component, result_component in zip(components, result):
                        self.assertAllEqual(component[i]**2, result_component)
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(get_next)

    @combinations.generate(test_base.graph_only_combinations())
    def testOneShotIteratorCaptureByValue(self):
        components = (np.arange(7),
                      np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                      np.array(37.0) * np.arange(7))
        tensor_components = tuple(
            [ops.convert_to_tensor(c) for c in components])

        def _map_fn(x, y, z):
            return math_ops.square(x), math_ops.square(y), math_ops.square(z)

        iterator = dataset_ops.make_one_shot_iterator(
            dataset_ops.Dataset.from_tensor_slices(tensor_components).map(
                _map_fn).repeat(14))
        get_next = iterator.get_next()

        self.assertEqual([c.shape[1:] for c in components],
                         [t.shape for t in get_next])

        with self.cached_session() as sess:
            for _ in range(14):
                for i in range(7):
                    result = sess.run(get_next)
                    for component, result_component in zip(components, result):
                        self.assertAllEqual(component[i]**2, result_component)
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(get_next)

    @combinations.generate(test_base.default_test_combinations())
    def testOneShotIteratorInsideContainer(self):
        components = (np.arange(7),
                      np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                      np.array(37.0) * np.arange(7))

        def within_container():
            def _map_fn(x, y, z):
                return math_ops.square(x), math_ops.square(y), math_ops.square(
                    z)

            iterator = dataset_ops.make_one_shot_iterator(
                dataset_ops.Dataset.from_tensor_slices(components).map(
                    _map_fn).repeat(14))
            return iterator.get_next()

        server = server_lib.Server.create_local_server()

        # Create two iterators within unique containers, and run them to
        # make sure that the resources aren't shared.
        #
        # The test below would fail if cname were the same across both
        # sessions.
        for j in range(2):
            with session.Session(server.target) as sess:
                cname = "iteration%d" % j
                with ops.container(cname):
                    get_next = within_container()

                for _ in range(14):
                    for i in range(7):
                        result = sess.run(get_next)
                        for component, result_component in zip(
                                components, result):
                            self.assertAllEqual(component[i]**2,
                                                result_component)
                with self.assertRaises(errors.OutOfRangeError):
                    sess.run(get_next)

    @combinations.generate(test_base.graph_only_combinations())
    def testOneShotIteratorNonBlocking(self):
        dataset = dataset_ops.Dataset.from_tensors([1, 2,
                                                    3]).map(lambda x: x * x)
        iterator = dataset_ops.make_one_shot_iterator(dataset)
        next_element = iterator.get_next()

        # Create a session with a single thread to ensure that the
        # one-shot iterator initializer does not deadlock.
        config = config_pb2.ConfigProto(inter_op_parallelism_threads=1,
                                        use_per_session_threads=True)
        with session.Session(config=config) as sess:
            self.assertAllEqual([1, 4, 9], sess.run(next_element))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element)

        # Test with multiple threads invoking the one-shot iterator concurrently.
        with session.Session(config=config) as sess:
            results = []

            def consumer_thread():
                try:
                    results.append(sess.run(next_element))
                except errors.OutOfRangeError:
                    results.append(None)

            num_threads = 8
            threads = [
                self.checkedThread(consumer_thread) for _ in range(num_threads)
            ]
            for t in threads:
                t.start()
            for t in threads:
                t.join()

            self.assertLen(results, num_threads)
            self.assertLen([None for r in results if r is None],
                           num_threads - 1)
            self.assertAllEqual([[1, 4, 9]],
                                [r for r in results if r is not None])

    @combinations.generate(test_base.graph_only_combinations())
    def testOneShotIteratorInitializerFails(self):
        # Define a dataset whose initialization will always fail.
        dataset = dataset_ops.Dataset.from_tensors(array_ops.gather([0], [4]))
        iterator = dataset_ops.make_one_shot_iterator(dataset)
        next_element = iterator.get_next()

        with self.cached_session() as sess:
            with self.assertRaisesRegex(errors.InvalidArgumentError, ""):
                sess.run(next_element)

            # Test that subsequent attempts to use the iterator also fail.
            with self.assertRaisesRegex(errors.InvalidArgumentError, ""):
                sess.run(next_element)

        with self.cached_session() as sess:

            def consumer_thread():
                with self.assertRaisesRegex(errors.InvalidArgumentError, ""):
                    sess.run(next_element)

            num_threads = 8
            threads = [
                self.checkedThread(consumer_thread) for _ in range(num_threads)
            ]
            for t in threads:
                t.start()
            for t in threads:
                t.join()

    @combinations.generate(test_base.graph_only_combinations())
    def testSimpleSharedResource(self):
        components = (np.array(1, dtype=np.int64),
                      np.array([1, 2, 3],
                               dtype=np.int64), np.array(37.0,
                                                         dtype=np.float64))

        server = server_lib.Server.create_local_server()

        # Create two non-overlapping sessions that share the same iterator
        # resource on the same server, and verify that an action of the
        # first session (initializing the iterator) is visible in the
        # second session.
        with ops.Graph().as_default():
            iterator = dataset_ops.make_initializable_iterator(
                dataset_ops.Dataset.from_tensors(components).map(
                    lambda x, y, z: (x, y, z)),
                shared_name="shared_iterator")
            init_op = iterator.initializer
            get_next = iterator.get_next()

            with session.Session(server.target) as sess:
                sess.run(init_op)
                results = sess.run(get_next)
                for component, result_component in zip(components, results):
                    self.assertAllEqual(component, result_component)
                with self.assertRaises(errors.OutOfRangeError):
                    sess.run(get_next)

                # Re-initialize the iterator in the first session.
                sess.run(init_op)

        with ops.Graph().as_default():
            # Re-define the iterator manually, without defining any of the
            # functions in this graph, to ensure that we are not
            # accidentally redefining functions with the same names in the
            # new graph.
            iterator = iterator_ops.Iterator.from_structure(
                shared_name="shared_iterator",
                output_types=(dtypes.int64, dtypes.int64, dtypes.float64),
                output_shapes=([], [3], []))
            get_next = iterator.get_next()

            with session.Session(server.target) as sess:
                # Use the iterator without re-initializing in the second session.
                results = sess.run(get_next)
                for component, result_component in zip(components, results):
                    self.assertAllEqual(component, result_component)
                with self.assertRaises(errors.OutOfRangeError):
                    sess.run(get_next)

    @combinations.generate(test_base.graph_only_combinations())
    def testNotInitializedError(self):
        components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
        iterator = dataset_ops.make_initializable_iterator(
            dataset_ops.Dataset.from_tensors(components))
        get_next = iterator.get_next()

        with self.cached_session() as sess:
            with self.assertRaisesRegex(errors.FailedPreconditionError,
                                        "iterator has not been initialized"):
                sess.run(get_next)

    @combinations.generate(test_base.graph_only_combinations())
    def testReinitializableIterator(self):
        dataset_3 = dataset_ops.Dataset.from_tensors(
            constant_op.constant([1, 2, 3]))
        dataset_4 = dataset_ops.Dataset.from_tensors(
            constant_op.constant([4, 5, 6, 7]))
        iterator = iterator_ops.Iterator.from_structure(
            dataset_ops.get_legacy_output_types(dataset_3), [None])

        dataset_3_init_op = iterator.make_initializer(dataset_3)
        dataset_4_init_op = iterator.make_initializer(dataset_4)
        get_next = iterator.get_next()

        self.assertEqual(dataset_ops.get_legacy_output_types(dataset_3),
                         dataset_ops.get_legacy_output_types(iterator))
        self.assertEqual(dataset_ops.get_legacy_output_types(dataset_4),
                         dataset_ops.get_legacy_output_types(iterator))
        self.assertEqual(
            [None],
            dataset_ops.get_legacy_output_shapes(iterator).as_list())

        with self.cached_session() as sess:
            # The iterator is initially uninitialized.
            with self.assertRaises(errors.FailedPreconditionError):
                sess.run(get_next)

            # Initialize with one dataset.
            sess.run(dataset_3_init_op)
            self.assertAllEqual([1, 2, 3], sess.run(get_next))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(get_next)

            # Initialize with a different dataset.
            sess.run(dataset_4_init_op)
            self.assertAllEqual([4, 5, 6, 7], sess.run(get_next))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(get_next)

            # Reinitialize with the first dataset.
            sess.run(dataset_3_init_op)
            self.assertAllEqual([1, 2, 3], sess.run(get_next))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(get_next)

    @combinations.generate(test_base.graph_only_combinations())
    def testReinitializableIteratorWithFunctions(self):
        def g():
            for i in range(10):
                yield i

        iterator = iterator_ops.Iterator.from_structure(dtypes.int64, [])
        next_element = iterator.get_next()

        with self.cached_session() as sess:
            dataset_1 = dataset_ops.Dataset.from_generator(
                g, output_types=dtypes.int64)
            sess.run(iterator.make_initializer(dataset_1))
            for expected in range(10):
                self.assertEqual(expected, sess.run(next_element))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element)

            dataset_2 = dataset_ops.Dataset.from_generator(
                g, output_types=dtypes.int64)
            sess.run(iterator.make_initializer(dataset_2))
            for expected in range(10):
                self.assertEqual(expected, sess.run(next_element))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element)

    @combinations.generate(test_base.default_test_combinations())
    def testReinitializableIteratorStaticErrors(self):
        # Non-matching structure for types and shapes.
        with self.assertRaises(TypeError):
            iterator = iterator_ops.Iterator.from_structure(
                (dtypes.int64, dtypes.float64), [None])

        # Test validation of dataset argument.
        iterator = iterator_ops.Iterator.from_structure(
            (dtypes.int64, dtypes.float64))

        # Incompatible structure.
        with self.assertRaises(ValueError):
            iterator.make_initializer(
                dataset_ops.Dataset.from_tensors(
                    ((constant_op.constant([1, 2, 3], dtype=dtypes.int64), ),
                     (constant_op.constant([4., 5., 6., 7.],
                                           dtype=dtypes.float64), ))))

        # Incompatible types.
        with self.assertRaises(TypeError):
            iterator.make_initializer(
                dataset_ops.Dataset.from_tensors(
                    (constant_op.constant([1, 2, 3], dtype=dtypes.int32),
                     constant_op.constant([4., 5., 6., 7.],
                                          dtype=dtypes.float32))))

        # Incompatible shapes.
        iterator = iterator_ops.Iterator.from_structure(
            (dtypes.int64, dtypes.float64), ([None], []))
        with self.assertRaises(TypeError):
            iterator.make_initializer(
                dataset_ops.Dataset.from_tensors(
                    (constant_op.constant([1, 2, 3], dtype=dtypes.int64),
                     constant_op.constant([4., 5., 6., 7.],
                                          dtype=dtypes.float64))))

    @combinations.generate(test_base.graph_only_combinations())
    def testIteratorStringHandle(self):
        dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
        dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])

        iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
        iterator_4 = dataset_ops.make_one_shot_iterator(dataset_4)

        handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
        feedable_iterator = iterator_ops.Iterator.from_string_handle(
            handle_placeholder, dataset_ops.get_legacy_output_types(dataset_3),
            dataset_ops.get_legacy_output_shapes(dataset_3))
        next_element = feedable_iterator.get_next()

        self.assertTrue(
            structure.are_compatible(
                dataset_ops.get_structure(dataset_3),
                dataset_ops.get_structure(feedable_iterator)))

        with self.cached_session() as sess:
            iterator_3_handle = sess.run(iterator_3.string_handle())
            iterator_4_handle = sess.run(iterator_4.string_handle())

            self.assertEqual(
                10,
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_4_handle}))
            self.assertEqual(
                1,
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_3_handle}))
            self.assertEqual(
                20,
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_4_handle}))
            self.assertEqual(
                2,
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_3_handle}))
            self.assertEqual(
                30,
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_4_handle}))
            self.assertEqual(
                3,
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_3_handle}))
            self.assertEqual(
                40,
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_4_handle}))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_3_handle})
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_4_handle})

    @combinations.generate(test_base.graph_only_combinations())
    def testIteratorStringHandleFuture(self):
        dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
        dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])

        iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
        iterator_4 = dataset_ops.make_one_shot_iterator(dataset_4)

        handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
        feedable_iterator = iterator_ops.Iterator.from_string_handle(
            handle_placeholder, dataset_ops.get_legacy_output_types(dataset_3),
            dataset_ops.get_legacy_output_shapes(dataset_3))
        next_element = feedable_iterator.get_next()

        self.assertTrue(
            structure.are_compatible(
                dataset_ops.get_structure(dataset_3),
                dataset_ops.get_structure(feedable_iterator)))

        with self.cached_session() as sess:
            iterator_3_handle = sess.run(iterator_3.string_handle())
            iterator_4_handle = sess.run(iterator_4.string_handle())

            self.assertEqual(
                10,
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_4_handle}))
            self.assertEqual(
                1,
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_3_handle}))
            self.assertEqual(
                20,
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_4_handle}))
            self.assertEqual(
                2,
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_3_handle}))
            self.assertEqual(
                30,
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_4_handle}))
            self.assertEqual(
                3,
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_3_handle}))
            self.assertEqual(
                40,
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_4_handle}))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_3_handle})
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_4_handle})

    @combinations.generate(test_base.graph_only_combinations())
    def testIteratorStringHandleReuseTensorObject(self):
        dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
        one_shot_iterator = dataset_ops.make_one_shot_iterator(dataset)
        initializable_iterator = dataset_ops.make_initializable_iterator(
            dataset)
        structure_iterator = iterator_ops.Iterator.from_structure(
            dataset_ops.get_legacy_output_types(dataset))

        created_ops = len(ops.get_default_graph().get_operations())

        self.assertIs(one_shot_iterator.string_handle(),
                      one_shot_iterator.string_handle())
        self.assertIs(initializable_iterator.string_handle(),
                      initializable_iterator.string_handle())
        self.assertIs(structure_iterator.string_handle(),
                      structure_iterator.string_handle())

        # Assert that getting the (default) string handle creates no ops.
        self.assertEqual(created_ops,
                         len(ops.get_default_graph().get_operations()))

        # Specifying an explicit name will create a new op.
        handle_with_name = one_shot_iterator.string_handle(name="foo")
        self.assertEqual("foo", handle_with_name.op.name)
        self.assertIsNot(one_shot_iterator.string_handle(), handle_with_name)

        handle_with_same_name = one_shot_iterator.string_handle(name="foo")
        self.assertEqual("foo_1", handle_with_same_name.op.name)
        self.assertIsNot(handle_with_name, handle_with_same_name)

    @combinations.generate(test_base.graph_only_combinations())
    def testIteratorStringHandleError(self):
        dataset_int_scalar = (dataset_ops.Dataset.from_tensor_slices(
            [1, 2, 3]).repeat())
        dataset_float_vector = (dataset_ops.Dataset.from_tensors(
            [1.0, 2.0, 3.0]))

        handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])

        feedable_int_scalar = iterator_ops.Iterator.from_string_handle(
            handle_placeholder, dtypes.int32, [])
        feedable_int_vector = iterator_ops.Iterator.from_string_handle(
            handle_placeholder, dtypes.int32, [None])
        feedable_int_any = iterator_ops.Iterator.from_string_handle(
            handle_placeholder, dtypes.int32)

        with self.cached_session() as sess:
            handle_int_scalar = sess.run(
                dataset_ops.make_one_shot_iterator(
                    dataset_int_scalar).string_handle())
            handle_float_vector = sess.run(
                dataset_ops.make_one_shot_iterator(
                    dataset_float_vector).string_handle())

            self.assertEqual(
                1,
                sess.run(feedable_int_scalar.get_next(),
                         feed_dict={handle_placeholder: handle_int_scalar}))

            self.assertEqual(
                2,
                sess.run(feedable_int_any.get_next(),
                         feed_dict={handle_placeholder: handle_int_scalar}))

            with self.assertRaises(errors.InvalidArgumentError):
                print(
                    sess.run(feedable_int_vector.get_next(),
                             feed_dict={handle_placeholder:
                                        handle_int_scalar}))

            with self.assertRaises(errors.InvalidArgumentError):
                print(
                    sess.run(
                        feedable_int_vector.get_next(),
                        feed_dict={handle_placeholder: handle_float_vector}))

    @combinations.generate(test_base.graph_only_combinations())
    def testRemoteIteratorUsingRemoteCallOpDirectSession(self):
        worker_config = config_pb2.ConfigProto()
        worker_config.device_count["CPU"] = 3

        with ops.device("/job:localhost/replica:0/task:0/cpu:1"):
            dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
            iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
            iterator_3_handle = iterator_3.string_handle()

        @function.Defun(dtypes.string)
        def _remote_fn(h):
            remote_iterator = iterator_ops.Iterator.from_string_handle(
                h, dataset_ops.get_legacy_output_types(dataset_3),
                dataset_ops.get_legacy_output_shapes(dataset_3))
            return remote_iterator.get_next()

        with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
            target_placeholder = array_ops.placeholder(dtypes.string, shape=[])
            remote_op = functional_ops.remote_call(args=[iterator_3_handle],
                                                   Tout=[dtypes.int32],
                                                   f=_remote_fn,
                                                   target=target_placeholder)

        with self.session(config=worker_config) as sess:
            elem = sess.run(remote_op,
                            feed_dict={
                                target_placeholder:
                                "/job:localhost/replica:0/task:0/cpu:1"
                            })
            self.assertEqual(elem, [1])
            # Fails when target is cpu:2 where the resource is not located.
            with self.assertRaises(errors.InvalidArgumentError):
                sess.run(remote_op,
                         feed_dict={
                             target_placeholder:
                             "/job:localhost/replica:0/task:0/cpu:2"
                         })
            elem = sess.run(remote_op,
                            feed_dict={
                                target_placeholder:
                                "/job:localhost/replica:0/task:0/cpu:1"
                            })
            self.assertEqual(elem, [2])
            elem = sess.run(remote_op,
                            feed_dict={
                                target_placeholder:
                                "/job:localhost/replica:0/task:0/cpu:1"
                            })
            self.assertEqual(elem, [3])
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(remote_op,
                         feed_dict={
                             target_placeholder:
                             "/job:localhost/replica:0/task:0/cpu:1"
                         })

    @combinations.generate(test_base.graph_only_combinations())
    def testRemoteIteratorUsingRemoteCallOpMultiWorkers(self):
        s1 = server_lib.Server.create_local_server()
        s2 = server_lib.Server.create_local_server()
        s3 = server_lib.Server.create_local_server()

        cluster_def = cluster_pb2.ClusterDef()
        workers = cluster_def.job.add()
        workers.name = "worker"
        workers.tasks[0] = s1.target[len("grpc://"):]
        workers.tasks[1] = s2.target[len("grpc://"):]
        client = cluster_def.job.add()
        client.name = "client"
        client.tasks[0] = s3.target[len("grpc://"):]
        config = config_pb2.ConfigProto(cluster_def=cluster_def)

        worker_devices = [
            "/job:worker/replica:0/task:%d/cpu:0" % i for i in range(2)
        ]
        itr_handles = []
        for device in worker_devices:
            with ops.device(device):
                src = dataset_ops.Dataset.from_tensor_slices([device])
                itr = dataset_ops.make_one_shot_iterator(src)
                itr_handles.append(itr.string_handle())

        targets = dataset_ops.Dataset.from_tensor_slices(worker_devices)
        handles = dataset_ops.Dataset.from_tensor_slices(itr_handles)

        @function.Defun(dtypes.string)
        def loading_func(h):
            remote_itr = iterator_ops.Iterator.from_string_handle(
                h, dataset_ops.get_legacy_output_types(itr),
                dataset_ops.get_legacy_output_shapes(itr))
            return remote_itr.get_next()

        def map_fn(target, handle):
            return functional_ops.remote_call(args=[handle],
                                              Tout=[dtypes.string],
                                              f=loading_func,
                                              target=target)

        with ops.device("/job:client"):
            client_dataset = dataset_ops.Dataset.zip(
                (targets, handles)).map(map_fn)
            itr = dataset_ops.make_initializable_iterator(client_dataset)
            n = itr.get_next()

        with session.Session(s3.target, config=config) as sess:
            sess.run(itr.initializer)
            expected_values = worker_devices
            for expected in expected_values:
                self.assertEqual((compat.as_bytes(expected), ), sess.run(n))

            with self.assertRaises(errors.OutOfRangeError):
                sess.run(n)

    @combinations.generate(test_base.graph_only_combinations())
    def testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU(self):
        if not test_util.is_gpu_available():
            self.skipTest("No GPU available")

        with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
            dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
            iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
            iterator_3_handle = iterator_3.string_handle()

        def _encode_raw(byte_array):
            return bytes(bytearray(byte_array))

        @function.Defun(dtypes.uint8)
        def _remote_fn(h):
            handle = script_ops.py_func(_encode_raw, [h], dtypes.string)
            remote_iterator = iterator_ops.Iterator.from_string_handle(
                handle, dataset_ops.get_legacy_output_types(dataset_3),
                dataset_ops.get_legacy_output_shapes(dataset_3))
            return remote_iterator.get_next()

        with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"):
            target_placeholder = array_ops.placeholder(dtypes.string, shape=[])
            iterator_3_handle_uint8 = parsing_ops.decode_raw(
                input_bytes=iterator_3_handle, out_type=dtypes.uint8)
            remote_op = functional_ops.remote_call(
                args=[iterator_3_handle_uint8],
                Tout=[dtypes.int32],
                f=_remote_fn,
                target=target_placeholder)

        with self.cached_session() as sess:
            elem = sess.run(remote_op,
                            feed_dict={
                                target_placeholder:
                                "/job:localhost/replica:0/task:0/cpu:0"
                            })
            self.assertEqual(elem, [1])
            elem = sess.run(remote_op,
                            feed_dict={
                                target_placeholder:
                                "/job:localhost/replica:0/task:0/cpu:0"
                            })
            self.assertEqual(elem, [2])
            elem = sess.run(remote_op,
                            feed_dict={
                                target_placeholder:
                                "/job:localhost/replica:0/task:0/cpu:0"
                            })
            self.assertEqual(elem, [3])
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(remote_op,
                         feed_dict={
                             target_placeholder:
                             "/job:localhost/replica:0/task:0/cpu:0"
                         })

    @combinations.generate(test_base.graph_only_combinations())
    def testRepeatedGetNextWarning(self):
        iterator = dataset_ops.make_one_shot_iterator(
            dataset_ops.Dataset.range(10))
        warnings.simplefilter("always")
        with warnings.catch_warnings(record=True) as w:
            for _ in range(100):
                iterator.get_next()
        self.assertEqual(100 - iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD,
                         len(w))
        for warning in w:
            self.assertIn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE,
                          str(warning.message))

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(
                expected_element_structure=tensor_spec.TensorSpec(
                    [], dtypes.float32),
                expected_output_classes=ops.Tensor,
                expected_output_types=dtypes.float32,
                expected_output_shapes=[[]])))
    def testTensorIteratorStructure(self, expected_element_structure,
                                    expected_output_classes,
                                    expected_output_types,
                                    expected_output_shapes):
        tf_value_fn = lambda: constant_op.constant(37.0)
        tf_value = tf_value_fn()
        iterator = dataset_ops.make_one_shot_iterator(
            dataset_ops.Dataset.from_tensors(tf_value))

        self.assertTrue(
            structure.are_compatible(dataset_ops.get_structure(iterator),
                                     expected_element_structure))
        self.assertEqual(expected_output_classes,
                         dataset_ops.get_legacy_output_classes(iterator))
        self.assertEqual(expected_output_types,
                         dataset_ops.get_legacy_output_types(iterator))
        self.assertEqual(expected_output_shapes,
                         dataset_ops.get_legacy_output_shapes(iterator))

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(
                expected_element_structure=sparse_tensor.SparseTensorSpec(
                    [1], dtypes.int32),
                expected_output_classes=sparse_tensor.SparseTensor,
                expected_output_types=dtypes.int32,
                expected_output_shapes=[[1]])))
    def testSparseTensorIteratorStructure(self, expected_element_structure,
                                          expected_output_classes,
                                          expected_output_types,
                                          expected_output_shapes):
        def tf_value_fn():
            return sparse_tensor.SparseTensor(indices=[[0]],
                                              values=constant_op.constant(
                                                  [0], dtype=dtypes.int32),
                                              dense_shape=[1])

        tf_value = tf_value_fn()
        iterator = dataset_ops.make_one_shot_iterator(
            dataset_ops.Dataset.from_tensors(tf_value))

        self.assertTrue(
            structure.are_compatible(dataset_ops.get_structure(iterator),
                                     expected_element_structure))
        self.assertEqual(expected_output_classes,
                         dataset_ops.get_legacy_output_classes(iterator))
        self.assertEqual(expected_output_types,
                         dataset_ops.get_legacy_output_types(iterator))
        self.assertEqual(expected_output_shapes,
                         dataset_ops.get_legacy_output_shapes(iterator))

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(expected_element_structure={
                "a":
                tensor_spec.TensorSpec([], dtypes.float32),
                "b": (tensor_spec.TensorSpec([1], dtypes.string),
                      tensor_spec.TensorSpec([], dtypes.string))
            },
                                 expected_output_classes={
                                     "a": ops.Tensor,
                                     "b": (ops.Tensor, ops.Tensor)
                                 },
                                 expected_output_types={
                                     "a": dtypes.float32,
                                     "b": (dtypes.string, dtypes.string)
                                 },
                                 expected_output_shapes={
                                     "a": [],
                                     "b": ([1], [])
                                 })))
    def testNestedTensorIteratorStructure(self, expected_element_structure,
                                          expected_output_classes,
                                          expected_output_types,
                                          expected_output_shapes):
        def tf_value_fn():
            return {
                "a": constant_op.constant(37.0),
                "b":
                (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
            }

        tf_value = tf_value_fn()
        iterator = dataset_ops.make_one_shot_iterator(
            dataset_ops.Dataset.from_tensors(tf_value))

        self.assertTrue(
            structure.are_compatible(dataset_ops.get_structure(iterator),
                                     expected_element_structure))
        self.assertEqual(expected_output_classes,
                         dataset_ops.get_legacy_output_classes(iterator))
        self.assertEqual(expected_output_types,
                         dataset_ops.get_legacy_output_types(iterator))
        self.assertEqual(expected_output_shapes,
                         dataset_ops.get_legacy_output_shapes(iterator))

    @combinations.generate(test_base.default_test_combinations())
    def testIteratorGetNextName(self):
        with ops.Graph().as_default():
            iterator = dataset_ops.make_one_shot_iterator(
                dataset_ops.Dataset.from_tensors(37.0))
            next_element = iterator.get_next(name="overridden_name")
            self.assertEqual("overridden_name", next_element.op.name)

    @combinations.generate(
        combinations.combine(tf_api_version=[1, 2],
                             mode="eager",
                             execution_mode=[context.ASYNC, context.SYNC]))
    def testIteratorEagerIteration(self, execution_mode):
        with context.eager_mode(), context.execution_mode(execution_mode):
            val = 0
            dataset = dataset_ops.Dataset.range(10)
            iterator = iter(dataset)
            for foo in iterator:
                self.assertEqual(val, foo.numpy())
                val += 1

    @combinations.generate(test_base.eager_only_combinations())
    def testOwnedIteratorFunction(self):

        queue = data_flow_ops.FIFOQueue(10, dtypes.int64)

        @def_function.function
        def fn():
            dataset = dataset_ops.Dataset.range(10)
            iterator = iter(dataset)
            for _ in range(10):
                queue.enqueue(next(iterator))

        fn()

        for i in range(10):
            self.assertEqual(queue.dequeue().numpy(), i)

    @combinations.generate(test_base.eager_only_combinations())
    def testOwnedIteratorFunctionError(self):
        # In this test we verify that a function that raises an error ends up
        # properly deallocating the iterator resource.

        queue = data_flow_ops.FIFOQueue(10, dtypes.int64)
        queue.enqueue(0)

        def init_fn(n):
            return n

        def next_fn(_):
            ds = dataset_ops.Dataset.range(0)
            return next(iter(ds))

        def finalize_fn(n):
            queue.enqueue(0)
            return n

        @def_function.function
        def fn():
            output_signature = tensor_spec.TensorSpec((), dtypes.int64)
            dataset = dataset_ops._GeneratorDataset(1, init_fn, next_fn,
                                                    finalize_fn,
                                                    output_signature)
            iterator = iter(dataset)
            next(iterator)

        with self.assertRaises(errors.OutOfRangeError):
            fn()

        self.assertEqual(queue.size().numpy(), 2)

    @combinations.generate(test_base.eager_only_combinations())
    def testLimitedRetracing(self):
        trace_count = [0]

        @def_function.function
        def f(iterator):
            trace_count[0] += 1
            counter = np.int64(0)
            for elem in iterator:
                counter += elem
            return counter

        dataset = dataset_ops.Dataset.range(5)
        dataset2 = dataset_ops.Dataset.range(10)

        for _ in range(10):
            self.assertEqual(self.evaluate(f(iter(dataset))), 10)
            self.assertEqual(self.evaluate(f(iter(dataset2))), 45)
            self.assertEqual(trace_count[0], 1)

    @combinations.generate(test_base.eager_only_combinations())
    def testNestedFunctionsIteratorResource(self):
        @def_function.function
        def sum_dataset(ds):
            it = iter(ds)

            @def_function.function
            def next_element(it):
                return next(it)

            total = 0
            for _ in range(10):
                total += next_element(it)
            return total

        ds = dataset_ops.Dataset.range(10)
        self.assertEqual(sum_dataset(ds).numpy(), 45)
        self.assertEqual(sum_dataset(ds).numpy(), 45)

    @combinations.generate(test_base.default_test_combinations())
    def testNestedAutomaticControlDependencies(self):
        counter_var = variables.Variable(0)

        def map_fn(x):
            counter_var.assign_add(1)
            return x

        def dataset_fn():
            return dataset_ops.Dataset.range(10).map(map_fn)

        @def_function.function
        def fn():
            it = iter(dataset_fn())
            for _ in range(10):
                _ = next(it)
            return counter_var

        self.evaluate(counter_var.initializer)
        self.assertEqual(self.evaluate(fn()), 10)
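The iterator tests above go through TensorFlow's internal modules (dataset_ops, iterator_ops). For orientation only, here is a minimal sketch of the same reinitializable-iterator pattern written against the public v1 compatibility API; the tf.compat.v1 names and the session usage are assumptions about how such a snippet would run outside the test harness, not part of the original tests.

import tensorflow as tf

# Minimal sketch (assumes TF 2.x with the v1 compatibility API available).
tf.compat.v1.disable_eager_execution()

dataset_a = tf.compat.v1.data.Dataset.range(3)
dataset_b = tf.compat.v1.data.Dataset.range(10, 13)

# One reinitializable iterator shared by two datasets of the same structure.
iterator = tf.compat.v1.data.Iterator.from_structure(
    tf.compat.v1.data.get_output_types(dataset_a),
    tf.compat.v1.data.get_output_shapes(dataset_a))
next_element = iterator.get_next()

init_a = iterator.make_initializer(dataset_a)
init_b = iterator.make_initializer(dataset_b)

with tf.compat.v1.Session() as sess:
    sess.run(init_a)
    print([sess.run(next_element) for _ in range(3)])  # [0, 1, 2]
    sess.run(init_b)
    print([sess.run(next_element) for _ in range(3)])  # [10, 11, 12]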

class FlatMapDatasetCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                   parameterized.TestCase):
    @combinations.generate(test_base.default_test_combinations())
    def testCore(self):
        # Complicated way of saying range(start, start+25).
        def build_ds(start):
            def map_fn(x):
                return dataset_ops.Dataset.range(x, x + 5)

            return dataset_ops.Dataset.range(start, start + 5 * 5,
                                             5).flat_map(map_fn)

        self.run_core_tests(lambda: build_ds(0), 25)

    @combinations.generate(test_base.default_test_combinations())
    def testMapThenFlatMap(self):
        def build_ds():
            def flat_map_fn(_):
                def map_fn(y):
                    return 10 * math_ops.cast(y, dtypes.int32)

                return dataset_ops.Dataset.range(100).map(map_fn)

            return dataset_ops.Dataset.range(5).flat_map(flat_map_fn)

        self.run_core_tests(build_ds, 500)

    @combinations.generate(test_base.default_test_combinations())
    def testCaptureDefunInMapFn(self):
        def build_ds():
            def map_fn(x):
                @function.Defun(dtypes.int64)
                def defun_fn(x):
                    return constant_op.constant(1000) + math_ops.cast(
                        x, dtypes.int32)

                return dataset_ops.Dataset.from_tensor_slices([defun_fn(x)])

            return dataset_ops.Dataset.range(100).flat_map(map_fn)

        self.run_core_tests(build_ds, 100)

    @combinations.generate(test_base.default_test_combinations())
    def testDisallowVariableCapture(self):
        def build_ds():
            test_var = variable_scope.get_variable(name="test_var",
                                                   shape=(),
                                                   use_resource=True)
            return dataset_ops.Dataset.range(5).flat_map(
                lambda _: dataset_ops.Dataset.from_tensor_slices([test_var]))

        self.verify_error_on_save(build_ds, 5, errors.FailedPreconditionError)

    @combinations.generate(test_base.default_test_combinations())
    def testDisallowCapturingStatefulOps(self):
        def build_ds():
            def flat_map_fn(_):
                def map_fn(x):
                    return random_ops.random_uniform(
                        (), 0, 10, dtype=dtypes.int32) * math_ops.cast(
                            x, dtypes.int32)

                return dataset_ops.Dataset.range(100).map(map_fn)

            return dataset_ops.Dataset.range(5).flat_map(flat_map_fn)

        self.verify_error_on_save(build_ds, 500,
                                  errors.FailedPreconditionError)

    @combinations.generate(test_base.default_test_combinations())
    def testSparseCore(self):
        def _map_fn(i):
            return sparse_tensor.SparseTensorValue(indices=[[0, 0], [1, 1]],
                                                   values=(i * [1, -1]),
                                                   dense_shape=[2, 2])

        def _flat_map_fn(x):
            return dataset_ops.Dataset.from_tensor_slices(
                sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values))

        def _build_ds():
            return dataset_ops.Dataset.range(10).map(_map_fn).flat_map(
                _flat_map_fn)

        self.run_core_tests(_build_ds, 20)
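For reference, the input pipeline exercised by testCore above can be reproduced with the public API alone. This is only a sketch of the dataset being checkpointed (eager mode assumed), not of the checkpointing machinery itself.

import tensorflow as tf

# Sketch of the flat_map pipeline from testCore: range(start, start + 25)
# assembled from five sub-ranges of five elements each.
def build_ds(start):
    return tf.data.Dataset.range(start, start + 25, 5).flat_map(
        lambda x: tf.data.Dataset.range(x, x + 5))

print(list(build_ds(0).as_numpy_iterator()))  # [0, 1, 2, ..., 24]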
Example #29
class LocalWorkersTest(data_service_test_base.TestBase,
                       parameterized.TestCase):
    """Tests reading from local workers if `target_workers` is `local`."""
    @combinations.generate(test_base.default_test_combinations())
    def testOneLocalWorker(self):
        cluster = multi_process_cluster.MultiProcessCluster(
            num_local_workers=1, num_remote_workers=5)
        num_elements = 10
        ds = self.make_distributed_range_dataset(num_elements,
                                                 cluster,
                                                 target_workers="local")
        self.assertDatasetProduces(ds, list(range(num_elements)))

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(num_local_workers=[1, 3],
                                 num_remote_workers=[0, 3])))
    def testLocalWorkers(self, num_local_workers, num_remote_workers):
        cluster = multi_process_cluster.MultiProcessCluster(
            num_local_workers=num_local_workers,
            num_remote_workers=num_remote_workers)
        num_elements = 10
        ds = self.make_distributed_range_dataset(num_elements,
                                                 cluster,
                                                 target_workers="LOCAL")
        self.assertDatasetProduces(ds,
                                   num_local_workers *
                                   list(range(num_elements)),
                                   assert_items_equal=True)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(num_local_workers=[1, 3],
                                 num_remote_workers=[0, 3])))
    def testRepeatedDataset(self, num_local_workers, num_remote_workers):
        cluster = multi_process_cluster.MultiProcessCluster(
            num_local_workers=num_local_workers,
            num_remote_workers=num_remote_workers)
        num_elements = 10
        num_repetitions = 5
        ds = self.make_distributed_range_dataset(num_elements,
                                                 cluster,
                                                 target_workers="LOCAL")
        ds = ds.repeat(num_repetitions)
        self.assertDatasetProduces(ds,
                                   expected_output=num_local_workers *
                                   num_repetitions * list(range(num_elements)),
                                   assert_items_equal=True)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(num_local_workers=[1, 3],
                                 num_remote_workers=[0, 3])))
    def testPrefetchingDataset(self, num_local_workers, num_remote_workers):
        cluster = multi_process_cluster.MultiProcessCluster(
            num_local_workers=num_local_workers,
            num_remote_workers=num_remote_workers)
        num_elements = 10
        ds = self.make_distributed_range_dataset(num_elements,
                                                 cluster,
                                                 target_workers="LOCAL")
        ds = ds.prefetch(10)
        self.assertDatasetProduces(ds,
                                   expected_output=num_local_workers *
                                   list(range(num_elements)),
                                   assert_items_equal=True)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(num_local_workers=[1, 3],
                                 num_remote_workers=[0, 3])))
    def testMultipleEpochs(self, num_local_workers, num_remote_workers):
        cluster = multi_process_cluster.MultiProcessCluster(
            num_local_workers=num_local_workers,
            num_remote_workers=num_remote_workers)
        num_elements = 10
        ds = self.make_distributed_range_dataset(num_elements,
                                                 cluster,
                                                 target_workers="LOCAL")
        for _ in range(10):
            self.assertDatasetProduces(ds,
                                       num_local_workers *
                                       list(range(num_elements)),
                                       assert_items_equal=True)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(num_local_workers=[1, 3],
                                 num_remote_workers=[0, 3])))
    def testDynamicSharding(self, num_local_workers, num_remote_workers):
        cluster = multi_process_cluster.MultiProcessCluster(
            num_local_workers=num_local_workers,
            num_remote_workers=num_remote_workers)
        num_elements = 100
        ds = self.make_distributed_range_dataset(
            num_elements,
            cluster,
            processing_mode=ShardingPolicy.DYNAMIC,
            target_workers="LOCAL")
        self.assertDatasetProduces(ds,
                                   list(range(num_elements)),
                                   assert_items_equal=True)

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(num_local_workers=[1, 3],
                                 num_remote_workers=[0, 3])))
    def testEmptyDataset(self, num_local_workers, num_remote_workers):
        cluster = multi_process_cluster.MultiProcessCluster(
            num_local_workers=num_local_workers,
            num_remote_workers=num_remote_workers)
        num_elements = 0
        ds = self.make_distributed_range_dataset(num_elements,
                                                 cluster,
                                                 target_workers="LOCAL")
        self.assertDatasetProduces(ds, [])

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(num_local_workers=[1, 3],
                                 num_remote_workers=[0, 3])))
    def testNonLocalRead(self, num_local_workers, num_remote_workers):
        """This test ensures the remote workers are running and producing data."""

        cluster = multi_process_cluster.MultiProcessCluster(
            num_local_workers=num_local_workers,
            num_remote_workers=num_remote_workers)
        num_elements = 10
        ds = self.make_distributed_range_dataset(num_elements,
                                                 cluster,
                                                 target_workers="any")
        num_workers = num_local_workers + num_remote_workers
        self.assertDatasetProduces(ds,
                                   num_workers * list(range(num_elements)),
                                   assert_items_equal=True)

    @combinations.generate(test_base.default_test_combinations())
    def testNoLocalWorker(self):
        cluster = multi_process_cluster.MultiProcessCluster(
            num_local_workers=0, num_remote_workers=3)
        num_elements = 10
        ds = self.make_distributed_range_dataset(num_elements,
                                                 cluster,
                                                 target_workers="LOCAL")

        with self.assertRaisesRegex(
                errors.InvalidArgumentError,
                "Local reads require local tf.data workers, but no local worker is "
                "found."):
            self.getDatasetOutput(ds)

    @combinations.generate(test_base.default_test_combinations())
    def testInconsistentTargetWorkers(self):
        cluster = multi_process_cluster.MultiProcessCluster(
            num_local_workers=3, num_remote_workers=3)
        ds = dataset_ops.Dataset.range(10)
        datasets = [
            self.make_distributed_dataset(ds,
                                          cluster,
                                          job_name="test_job",
                                          target_workers=target_workers)
            for target_workers in ["AUTO", "ANY", "LOCAL"]
        ]

        with self.assertRaisesRegex(
                errors.InvalidArgumentError,
                "but there is already an existing job with that name using "
                "target_workers <AUTO>."):
            for dataset in datasets:
                self.getDatasetOutput(dataset)

    @combinations.generate(test_base.default_test_combinations())
    def testAnonymousJobWithDifferentTargetWorkers(self):
        num_local_workers, num_remote_workers = (3, 3)
        cluster = multi_process_cluster.MultiProcessCluster(
            num_local_workers, num_remote_workers)
        num_elements = 10
        ds = dataset_ops.Dataset.range(num_elements)
        datasets = {
            target_workers:
            self.make_distributed_dataset(ds,
                                          cluster,
                                          target_workers=target_workers)
            for target_workers in ["AUTO", "ANY", "LOCAL"]
        }

        num_workers = num_local_workers + num_remote_workers
        self.assertDatasetProduces(datasets["AUTO"],
                                   num_workers * list(range(num_elements)),
                                   assert_items_equal=True)
        self.assertDatasetProduces(datasets["ANY"],
                                   num_workers * list(range(num_elements)),
                                   assert_items_equal=True)
        self.assertDatasetProduces(datasets["LOCAL"],
                                   num_local_workers *
                                   list(range(num_elements)),
                                   assert_items_equal=True)

    @combinations.generate(test_base.default_test_combinations())
    def testCoordinatedRead(self):
        cluster = multi_process_cluster.MultiProcessCluster(
            num_local_workers=3, num_remote_workers=3)
        ds = dataset_ops.Dataset.range(10).repeat()
        ds = self.make_distributed_dataset(ds,
                                           cluster,
                                           job_name="test_job",
                                           consumer_index=0,
                                           num_consumers=3,
                                           target_workers="LOCAL")
        with self.assertRaisesRegex(
                errors.InvalidArgumentError,
                "Coordinated reads require non-local workers"):
            self.getDatasetOutput(ds)
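The LocalWorkersTest helpers (make_distributed_range_dataset, make_distributed_dataset) wrap the public tf.data service API. A rough sketch of the underlying call is shown below; the dispatcher address is a hypothetical placeholder, and the exact parameter set (notably target_workers) assumes a TF release that exposes it on tf.data.experimental.service.distribute.

import tensorflow as tf

# Hypothetical dispatcher address; a real tf.data service cluster would
# provide this value (it is not part of the original tests).
dispatcher_address = "grpc://localhost:5000"

def make_local_read_dataset(num_elements=10):
    # Read only from workers co-located with this client.
    ds = tf.data.Dataset.range(num_elements)
    return ds.apply(
        tf.data.experimental.service.distribute(
            processing_mode="parallel_epochs",
            service=dispatcher_address,
            target_workers="LOCAL"))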
Example #30
class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
    @combinations.generate(test_base.default_test_combinations())
    def testOptimizationStatefulFunction(self):
        dataset = dataset_ops.Dataset.range(10).map(
            lambda _: random_ops.random_uniform([])).batch(10)
        options = dataset_ops.Options()
        options.experimental_optimization.apply_default_optimizations = False
        dataset = dataset.with_options(options)
        get_next = self.getNext(dataset)
        self.evaluate(get_next())

    # TODO(b/123902160)
    @combinations.generate(test_base.graph_only_combinations())
    def testOptimizationLargeInputFromTensor(self):
        input_t = array_ops.placeholder(dtypes.int32, (None, None, None))
        dataset = dataset_ops.Dataset.from_tensors(input_t)
        options = dataset_ops.Options()
        options.experimental_optimization.apply_default_optimizations = False
        dataset = dataset.with_options(options)
        iterator = dataset_ops.make_initializable_iterator(dataset)
        init_op = iterator.initializer
        get_next = iterator.get_next()

        with self.cached_session() as sess:
            sess.run(init_op, {input_t: np.ones([512, 1024, 1025], np.int32)})
            self.evaluate(get_next)

    # TODO(b/123902160)
    @combinations.generate(test_base.graph_only_combinations())
    def testOptimizationLargeInputFromTensorSlices(self):
        input_t = array_ops.placeholder(dtypes.int32, (None, None, None, None))
        dataset = dataset_ops.Dataset.from_tensor_slices(input_t)
        options = dataset_ops.Options()
        options.experimental_optimization.apply_default_optimizations = False
        dataset = dataset.with_options(options)
        iterator = dataset_ops.make_initializable_iterator(dataset)
        init_op = iterator.initializer
        get_next = iterator.get_next()

        with self.cached_session() as sess:
            sess.run(init_op,
                     {input_t: np.ones([1, 512, 1024, 1025], np.int32)})
            self.evaluate(get_next)

    @combinations.generate(test_base.default_test_combinations())
    def testOptimizationNestedDataset(self):
        def flat_map_fn(_):
            dataset = dataset_ops.Dataset.from_tensors(0)
            dataset = dataset.apply(testing.assert_next(["MemoryCacheImpl"]))
            dataset = dataset.skip(0)  # Should be removed by noop elimination
            dataset = dataset.cache()
            return dataset

        dataset = dataset_ops.Dataset.range(1)
        dataset = dataset.flat_map(flat_map_fn)
        options = dataset_ops.Options()
        options.experimental_optimization.apply_default_optimizations = False
        options.experimental_optimization.noop_elimination = True
        dataset = dataset.with_options(options)
        self.assertDatasetProduces(dataset, expected_output=[0])

    @combinations.generate(test_base.default_test_combinations())
    def testOptimizationNestedDatasetWithModifiedRetval(self):
        def flat_map_fn(_):
            dataset = dataset_ops.Dataset.from_tensors(0)
            dataset = dataset.apply(testing.assert_next(["MapAndBatch"]))
            # Should be fused by map and batch fusion
            dataset = dataset.map(lambda x: x)
            dataset = dataset.batch(1)
            return dataset

        dataset = dataset_ops.Dataset.range(1)
        dataset = dataset.flat_map(flat_map_fn)

        options = dataset_ops.Options()
        options.experimental_optimization.apply_default_optimizations = False
        options.experimental_optimization.map_and_batch_fusion = True
        dataset = dataset.with_options(options)
        self.assertDatasetProduces(dataset, expected_output=[[0]])

    @combinations.generate(
        combinations.times(test_base.default_test_combinations(),
                           _disable_intra_op_parallelism_test_combinations()))
    def testOptimizationDisableIntraOpParallelism(self, dataset_fn,
                                                  expected_output):
        os.environ[
            "TF_DATA_EXPERIMENT_OPT_IN"] = "disable_intra_op_parallelism"
        os.environ["TF_JOB_NAME"] = "test_job"

        dataset = dataset_fn()
        dataset = dataset.apply(testing.assert_next(["MaxIntraOpParallelism"]))

        self.assertDatasetProduces(dataset, expected_output=expected_output)

        del os.environ["TF_DATA_EXPERIMENT_OPT_IN"]
        del os.environ["TF_JOB_NAME"]

    @combinations.generate(test_base.default_test_combinations())
    def testOptimizationThreadPoolDataset(self):
        dataset = dataset_ops.Dataset.range(10).batch(10)

        dataset = threadpool.override_threadpool(
            dataset,
            threadpool.PrivateThreadPool(
                2, display_name="private_thread_pool_%d" % 2))

        options = dataset_ops.Options()
        options.experimental_optimization.apply_default_optimizations = False
        dataset = dataset.with_options(options)
        self.assertDatasetProduces(dataset,
                                   expected_output=[list(range(10))],
                                   requires_initialization=True)

    # Reference variables are not supported in eager mode.
    @combinations.generate(
        combinations.times(test_base.graph_only_combinations(),
                           _captured_refvar_test_combinations()))
    def testOptimizationWithCapturedRefVar(self, dataset_fn):
        """Tests that default optimizations are disabled with ref variables."""
        variable = variable_scope.get_variable("v",
                                               initializer=0,
                                               use_resource=False)
        assign_op = variable.assign_add(1)

        # Check that warning is logged.
        warnings.simplefilter("always")
        with warnings.catch_warnings(record=True) as w:
            unoptimized_dataset = dataset_fn(variable)

            options = dataset_ops.Options()
            options.experimental_optimization.apply_default_optimizations = False
            options.experimental_optimization.noop_elimination = True
            options.experimental_optimization.map_and_batch_fusion = True
            optimized_dataset = unoptimized_dataset.with_options(options)
            optimized_it = dataset_ops.make_initializable_iterator(
                optimized_dataset)

        self.assertGreaterEqual(len(w), 1)
        graph_rewrites = options._graph_rewrites()
        expected = (
            "tf.data graph rewrites are not compatible with "
            "tf.Variable. The following rewrites will be disabled: %s."
            " To enable rewrites, use resource variables instead by "
            "calling `tf.enable_resource_variables()` at the start of the "
            "program." %
            (", ".join(graph_rewrites.enabled + graph_rewrites.default)))
        self.assertTrue(any(expected in str(warning) for warning in w))

        # Check that outputs are the same in the optimized and unoptimized cases,
        # when the variable value is changing.
        unoptimized_it = dataset_ops.make_initializable_iterator(
            unoptimized_dataset)
        with ops.control_dependencies([assign_op]):
            unoptimized_output = unoptimized_it.get_next()
            optimized_output = optimized_it.get_next()

        self.evaluate(variable.initializer)
        self.evaluate((unoptimized_it.initializer, optimized_it.initializer))
        while True:
            try:
                unoptimized, optimized = self.evaluate(
                    (unoptimized_output, optimized_output))
                self.assertEqual(unoptimized, optimized)
            except errors.OutOfRangeError:
                break

    @combinations.generate(test_base.default_test_combinations())
    def testOptimizationDefault(self):
        """Tests the optimization settings by default."""
        options = dataset_ops.Options()
        expected_optimizations_enabled = []
        expected_optimizations_disabled = []
        expected_optimizations_default = [
            "map_and_batch_fusion",
            "noop_elimination",
            "shuffle_and_repeat_fusion",
        ]
        graph_rewrites = options._graph_rewrites()
        self.assertEqual(set(graph_rewrites.enabled),
                         set(expected_optimizations_enabled))
        self.assertEqual(set(graph_rewrites.disabled),
                         set(expected_optimizations_disabled))
        self.assertEqual(set(graph_rewrites.default),
                         set(expected_optimizations_default))

        options.experimental_optimization.apply_default_optimizations = True
        graph_rewrites = options._graph_rewrites()
        self.assertEqual(set(graph_rewrites.enabled),
                         set(expected_optimizations_enabled))
        self.assertEqual(set(graph_rewrites.disabled),
                         set(expected_optimizations_disabled))
        self.assertEqual(set(graph_rewrites.default),
                         set(expected_optimizations_default))

        options.experimental_optimization.apply_default_optimizations = False
        expected_optimizations_default = []
        graph_rewrites = options._graph_rewrites()
        self.assertEqual(set(graph_rewrites.enabled),
                         set(expected_optimizations_enabled))
        self.assertEqual(set(graph_rewrites.disabled),
                         set(expected_optimizations_disabled))
        self.assertEqual(set(graph_rewrites.default),
                         set(expected_optimizations_default))

    @combinations.generate(test_base.default_test_combinations())
    def testOptimizationEnabled(self):
        """Tests the optimization settings by enabling all."""
        options = dataset_ops.Options()
        options.experimental_optimization.filter_fusion = True
        options.experimental_optimization.filter_with_random_uniform_fusion = True
        options.experimental_optimization.hoist_random_uniform = True
        options.experimental_optimization.map_and_batch_fusion = True
        options.experimental_optimization.map_and_filter_fusion = True
        options.experimental_optimization.map_parallelization = True
        options.experimental_optimization.map_fusion = True
        options.experimental_optimization.noop_elimination = True
        options.experimental_optimization.parallel_batch = True
        options.experimental_optimization.shuffle_and_repeat_fusion = True
        options.experimental_optimization.map_vectorization.enabled = True
        options.experimental_optimization.autotune_buffers = True
        options.experimental_deterministic = False
        options.experimental_stats.latency_all_edges = True
        options.experimental_slack = True

        expected_optimizations_enabled = [
            "filter_fusion",
            "filter_with_random_uniform_fusion",
            "hoist_random_uniform",
            "map_and_batch_fusion",
            "map_and_filter_fusion",
            "map_parallelization",
            "map_fusion",
            "noop_elimination",
            "parallel_batch",
            "shuffle_and_repeat_fusion",
            "map_vectorization",
            "inject_prefetch",
            "make_sloppy",
            "latency_all_edges",
            "slack",
        ]
        expected_optimizations_disabled = []
        expected_optimizations_default = []
        graph_rewrites = options._graph_rewrites()
        self.assertEqual(set(graph_rewrites.enabled),
                         set(expected_optimizations_enabled))
        self.assertEqual(set(graph_rewrites.disabled),
                         set(expected_optimizations_disabled))
        self.assertEqual(set(graph_rewrites.default),
                         set(expected_optimizations_default))

    @combinations.generate(test_base.default_test_combinations())
    def testOptimizationDisabled(self):
        """Tests the optimization settings by disabling all."""
        options = dataset_ops.Options()
        options.experimental_optimization.filter_fusion = False
        options.experimental_optimization.filter_with_random_uniform_fusion = False
        options.experimental_optimization.hoist_random_uniform = False
        options.experimental_optimization.map_and_batch_fusion = False
        options.experimental_optimization.map_and_filter_fusion = False
        options.experimental_optimization.map_parallelization = False
        options.experimental_optimization.map_fusion = False
        options.experimental_optimization.noop_elimination = False
        options.experimental_optimization.parallel_batch = False
        options.experimental_optimization.shuffle_and_repeat_fusion = False
        options.experimental_optimization.map_vectorization.enabled = False
        options.experimental_optimization.autotune = False
        options.experimental_deterministic = True
        options.experimental_stats.latency_all_edges = False
        options.experimental_slack = False

        expected_optimizations_enabled = []
        expected_optimizations_disabled = [
            "filter_fusion",
            "filter_with_random_uniform_fusion",
            "hoist_random_uniform",
            "map_and_batch_fusion",
            "map_and_filter_fusion",
            "map_parallelization",
            "map_fusion",
            "noop_elimination",
            "parallel_batch",
            "shuffle_and_repeat_fusion",
            "map_vectorization",
            "inject_prefetch",
            "make_sloppy",
            "latency_all_edges",
            "slack",
        ]
        expected_optimizations_default = []
        graph_rewrites = options._graph_rewrites()
        self.assertEqual(set(graph_rewrites.enabled),
                         set(expected_optimizations_enabled))
        self.assertEqual(set(graph_rewrites.disabled),
                         set(expected_optimizations_disabled))
        self.assertEqual(set(graph_rewrites.default),
                         set(expected_optimizations_default))

    @combinations.generate(test_base.default_test_combinations())
    def testAutotuningDefaults(self):
        options = dataset_ops.Options()

        # Check defaults
        autotune, algorithm, cpu_budget = options._autotune_settings()
        self.assertTrue(autotune)
        self.assertEqual(algorithm,
                         optimization_options._AutotuneAlgorithm.HILL_CLIMB)
        self.assertEqual(cpu_budget, 0)

    @combinations.generate(test_base.default_test_combinations())
    def testAutotuningBufferSizes(self):
        options = dataset_ops.Options()
        options.experimental_optimization.autotune_buffers = True
        self.assertIn("inject_prefetch", options._graph_rewrites().enabled)
        autotune, algorithm, cpu_budget = options._autotune_settings()
        self.assertTrue(autotune)
        self.assertEqual(
            algorithm,
            optimization_options._AutotuneAlgorithm.GRADIENT_DESCENT)
        self.assertEqual(cpu_budget, 0)
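All of the OptimizeDatasetTest cases follow the same pattern: build a tf.data.Options object, toggle individual graph rewrites, and attach it with with_options. A minimal public-API sketch of that pattern is shown below; the attribute names mirror the experimental_optimization namespace used in this file and may differ in newer TF releases.

import tensorflow as tf

# Minimal sketch of the Options pattern used above (eager mode assumed).
options = tf.data.Options()
options.experimental_optimization.apply_default_optimizations = False
options.experimental_optimization.map_and_batch_fusion = True
options.experimental_optimization.noop_elimination = True

dataset = tf.data.Dataset.range(10).map(lambda x: x).batch(2)
dataset = dataset.with_options(options)
print(list(dataset.as_numpy_iterator()))  # five batches of two elements each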