def testReplicateAndShardProduceDisjointData(self, shuffle,
                                                 sharding_policy):
        dataset = dataset_ops.Dataset.list_files(self._filenames,
                                                 shuffle=shuffle)
        dataset = dataset.flat_map(core_readers.TFRecordDataset)

        graph_def = dataset._as_serialized_graph(
            strip_device_assignment=True,
            external_state_policy=options_lib.ExternalStatePolicy.WARN)

        options = options_lib.Options()
        options.experimental_distribute.auto_shard_policy = sharding_policy

        ds1 = distribute._RemoteDataset(graph_def, "/device:CPU:0",
                                        dataset.element_spec)
        ds2 = distribute._RemoteDataset(graph_def, "/device:CPU:0",
                                        dataset.element_spec)

        ds1 = ds1.with_options(options)
        ds2 = ds2.with_options(options)

        ds1 = distribute._AutoShardDataset(ds1, 2, 0)
        ds2 = distribute._AutoShardDataset(ds2, 2, 1)

        elems1 = set(self.getAllDatasetElements(ds1))
        elems2 = set(self.getAllDatasetElements(ds2))

        self.assertEmpty(elems1.intersection(elems2))
 def testWorkersGreaterThanNumFiles(self):
   dataset = dataset_ops.Dataset.list_files(self.test_filenames)
   dataset = dataset.apply(
       interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
   dataset = dataset.batch(5)
   dataset = distribute._AutoShardDataset(dataset, 500, 499)
   self.assertDatasetProduces(dataset, [])
    def testDatasetOfReaderDatasetsPipeline(self, batch_size):
        # This tests a scenario where list_files may return multiple files
        # due to the glob containing wildcards.
        def batch(iterator, n):
            l = len(iterator)
            for i in range(0, l, n):
                yield iterator[i:min(i + n, l)]

        datasets = []
        for files in batch(self._filenames, batch_size):
            datasets.append(
                dataset_ops.Dataset.list_files(files, shuffle=False).map(
                    core_readers.TFRecordDataset))
        dataset = dataset_ops.Dataset.from_tensor_slices(datasets)
        dataset = dataset.flat_map(lambda x: x)

        # Simulate additional ops in between flat_map and interleave. This should be
        # a no-op: if ShardDataset were placed right after flat_map, we would only
        # have two datasets left at this point.
        dataset = dataset.prefetch(1)
        dataset = dataset.prefetch(1)

        dataset = dataset.interleave(lambda x: x,
                                     cycle_length=1,
                                     num_parallel_calls=1)

        dataset = distribute._AutoShardDataset(dataset, 5, 0)
        expected = [
            b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
            for f in (0, 5) for r in range(0, 10)
        ]

        self.assertDatasetProduces(dataset, expected)
Example #4
def auto_shard_dataset(dataset, num_shards, index, num_replicas_in_sync=None):
    """Shard the input pipeline by sharding the underlying list of files.

  Args:
    dataset: A `tf.data.Dataset` instance, typically the result of a bunch of
      dataset transformations.
    num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of
        shards operating in parallel. Same usage as in `tf.data.Dataset.shard`.
    index: A `tf.int64` scalar `tf.Tensor`, representing the worker index.
      Same usage as in `tf.data.Dataset.shard`.
    num_replicas_in_sync: An integer representing the total number of replicas
      across all workers. This is used in the rewrite when sharding by data.

  Returns:
    A modified `Dataset` obtained by updating the pipeline sharded by the
    files. The input dataset will be returned if we cannot automatically
    determine a good way to shard the input dataset.
  """
    if (dataset.options().experimental_distribute.auto_shard_policy !=
            AutoShardPolicy.OFF):
        if num_replicas_in_sync is None:
            num_replicas_in_sync = 1
        if isinstance(dataset, dataset_ops.DatasetV1):
            return distribute._AutoShardDatasetV1(dataset, num_shards, index,
                                                  num_replicas_in_sync)
        else:
            return distribute._AutoShardDataset(dataset, num_shards, index,
                                                num_replicas_in_sync)
    else:
        return dataset
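A minimal usage sketch for the helper above (not taken from the source): it assumes the same module aliases as the surrounding snippets (dataset_ops, core_readers, options_lib), and the file glob and worker counts are placeholders. It shows one worker out of two applying the DATA auto-shard policy.

# Hypothetical input pipeline; the glob and worker counts are placeholders.
dataset = dataset_ops.Dataset.list_files("/tmp/data/part-*", shuffle=False)
dataset = dataset.interleave(core_readers.TFRecordDataset, cycle_length=4)
dataset = dataset.batch(32)

# Enable an auto-shard policy so auto_shard_dataset applies the rewrite.
options = options_lib.Options()
options.experimental_distribute.auto_shard_policy = (
    options_lib.AutoShardPolicy.DATA)
dataset = dataset.with_options(options)

# Shard for worker 0 of 2, with one replica per worker.
dataset = auto_shard_dataset(dataset, num_shards=2, index=0,
                             num_replicas_in_sync=2)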
    def testInvalidWorkerIndex(self):
        dataset = dataset_ops.Dataset.list_files(self._filenames)
        dataset = dataset.flat_map(core_readers.TFRecordDataset)
        dataset = dataset.batch(5)

        with self.assertRaises(errors.InvalidArgumentError):
            dataset = distribute._AutoShardDataset(dataset, 2, 2)
            self.evaluate(self.getNext(dataset)())
 def build_dataset():
     dataset = dataset_ops.Dataset.list_files(self._filenames,
                                              shuffle=False)
     dataset = dataset.apply(
         interleave_ops.parallel_interleave(
             core_readers.TFRecordDataset, 10))
     dataset = distribute._AutoShardDataset(dataset, 5, 3)
     return dataset
  def testInvalidWorkerIndex(self):
    dataset = dataset_ops.Dataset.list_files(self.test_filenames)
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.batch(5)

    with self.assertRaises(errors.InvalidArgumentError):
      dataset = distribute._AutoShardDataset(dataset, 2, 2)
      self.evaluate(self.getNext(dataset)())
  def testUnsupportedOpInPipeline(self):
    dataset = dataset_ops.Dataset.list_files(self.test_filenames)
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.batch(5)
    dataset = dataset.apply(unique.unique())

    with self.assertRaises(errors.NotFoundError):
      dataset = distribute._AutoShardDataset(dataset, 2, 0)
      self.evaluate(self.getNext(dataset)())
Example #10
    def testUnsupportedOpInPipeline(self):
        dataset = dataset_ops.Dataset.list_files(self.test_filenames)
        dataset = dataset.flat_map(core_readers.TFRecordDataset)
        dataset = dataset.batch(5)
        dataset = dataset.apply(unique.unique())

        with self.assertRaises(errors.NotFoundError):
            dataset = distribute._AutoShardDataset(dataset, 2, 0)
            self.evaluate(self.getNext(dataset)())
Example #11
    def testDirectFilenameTextLineReaderPipeline(self):
        dataset = core_readers.TextLineDataset(self.test_filenames)
        dataset = distribute._AutoShardDataset(dataset, 5, 0)

        expected = [
            b"%d: %d" % (f, r)  # pylint:disable=g-complex-comprehension
            for f in (0, 5) for r in range(0, 10)
        ]
        self.assertDatasetProduces(dataset, expected)
 def testShardByDataBeforePrefetch(self, sharding_policy):
     dataset = dataset_ops.Dataset.range(4)
     dataset = dataset.apply(testing.assert_next(["Shard", "Prefetch"]))
     dataset = dataset.prefetch(1)
     options = options_lib.Options()
     options.experimental_distribute.auto_shard_policy = sharding_policy
     dataset = dataset.with_options(options)
     dataset = distribute._AutoShardDataset(dataset, 2, 0)
     self.assertDatasetProduces(dataset, [0, 2])
    def testDirectFilenameTFRecordReaderPipeline(self):
        dataset = core_readers.TFRecordDataset(self._filenames)
        dataset = distribute._AutoShardDataset(dataset, 5, 0)

        expected = [
            b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
            for f in (0, 5) for r in range(0, 10)
        ]
        self.assertDatasetProduces(dataset, expected)
    def testUseLegacyRebatchWithDataSharding(self, sharding_policy,
                                             with_prefetch):
        # This test simulates a distributed environment with 3 workers, each with
        # 1 replica.
        dataset = dataset_ops.Dataset.range(8)
        dataset = dataset.batch(4)
        options = options_lib.Options()
        options.experimental_distribute.auto_shard_policy = sharding_policy
        dataset = dataset.with_options(options)
        # We expect the auto-shard rewrite to rewrite RebatchDatasetV2 to
        # RebatchDataset(V1) for correctness reasons. This will modify the output
        # of the dataset.
        worker_a_dataset = distribute._RebatchDataset(dataset,
                                                      batch_sizes=[2, 1, 1])
        if with_prefetch:
            worker_a_dataset = worker_a_dataset.prefetch(1)
        worker_a_dataset = distribute._AutoShardDataset(worker_a_dataset,
                                                        3,
                                                        0,
                                                        num_replicas=3)
        expected = [[0, 1], [4, 5]]
        self.assertDatasetProduces(worker_a_dataset, expected)

        worker_b_dataset = distribute._RebatchDataset(dataset,
                                                      batch_sizes=[1, 1, 2])
        if with_prefetch:
            worker_b_dataset = worker_b_dataset.prefetch(1)
        worker_b_dataset = distribute._AutoShardDataset(worker_b_dataset,
                                                        3,
                                                        1,
                                                        num_replicas=3)
        expected = [[2, 3], [6, 7]]
        self.assertDatasetProduces(worker_b_dataset, expected)

        worker_c_dataset = distribute._RebatchDataset(dataset,
                                                      batch_sizes=[1, 2, 1])
        if with_prefetch:
            worker_c_dataset = worker_c_dataset.prefetch(1)
        worker_c_dataset = distribute._AutoShardDataset(worker_c_dataset,
                                                        3,
                                                        2,
                                                        num_replicas=3)
        expected = [[], []]
        self.assertDatasetProduces(worker_c_dataset, expected)
Example #15
 def testEnumerateAutoShardPolicies(self, auto_shard_policy):
   """Verifies tf.data handles every auto-shard policy with no errors."""
   dataset = dataset_ops.Dataset.list_files(self._filenames, shuffle=False)
   dataset = dataset.flat_map(core_readers.TFRecordDataset)
   dataset = dataset.batch(5)
   options = options_lib.Options()
   options.experimental_distribute.auto_shard_policy = auto_shard_policy
   dataset = dataset.with_options(options)
   dataset = distribute._AutoShardDataset(dataset, 5, 3)
   self.getDatasetOutput(dataset, requires_initialization=True)
    def testHintShardingInvalidPattern(self):
        options = options_lib.Options()
        options.experimental_distribute.auto_shard_policy = (
            options_lib.AutoShardPolicy.HINT)

        dataset = dataset_ops.Dataset.range(100).shard(1, 0)
        dataset = dataset.with_options(options)
        dataset = distribute._AutoShardDataset(dataset, 10, 0)

        self.assertDatasetProduces(dataset, list(range(100)))
  def testDirectFilenameTextLineReaderPipeline(self):
    dataset = core_readers.TextLineDataset(self.test_filenames)
    dataset = distribute._AutoShardDataset(dataset, 5, 0)

    expected = [
        b"%d: %d" % (f, r)  # pylint:disable=g-complex-comprehension
        for f in (0, 5)
        for r in range(0, 10)
    ]
    self.assertDatasetProduces(dataset, expected)
    def testTFRecordReaderWithDirectFileNames(self):
        # Using `_TFRecordDataset` creates a raw op rather than automatically
        # wrapping it in a flat_map.
        dataset = core_readers._TFRecordDataset(self._filenames)
        dataset = distribute._AutoShardDataset(dataset, 5, 0)

        expected = [
            b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
            for f in range(0, 10) for r in (0, 5)
        ]
        self.assertDatasetProduces(dataset, expected)
  def testTFRecordReaderWithDirectFileNames(self):
    # Using `_TFRecordDataset` creates a raw op rather than automatically
    # wrapping it in a flat_map.
    dataset = core_readers._TFRecordDataset(self.test_filenames)
    dataset = distribute._AutoShardDataset(dataset, 5, 0)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in (0, 5)
    ]
    self.assertDatasetProduces(dataset, expected)
 def testShardWithRebatch(self):
     # Tests that RebatchDatasetV2 is a passthrough op.
     dataset = dataset_ops.Dataset.list_files(self.test_filenames,
                                              shuffle=False)
     dataset = dataset.apply(
         testing.assert_next(["Shard", "FlatMap", "Batch", "Rebatch"]))
     dataset = dataset.flat_map(core_readers.TFRecordDataset)
     dataset = dataset.batch(5)
     dataset = distribute._RebatchDataset(dataset, batch_sizes=5)
     dataset = distribute._AutoShardDataset(dataset, 5, 3)
     nxt = self.getNext(dataset)
     self.evaluate(nxt())
    def testFlatMapReaderPipeline(self, shuffle):
        dataset = dataset_ops.Dataset.list_files(self._filenames,
                                                 shuffle=shuffle)
        dataset = dataset.flat_map(core_readers.TFRecordDataset)
        dataset = dataset.batch(5)
        dataset = distribute._AutoShardDataset(dataset, 5, 3)

        expected = [
            b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
            for f in (3, 8) for r in range(0, 10)
        ]
        self.assertDatasetProducesWithShuffle(dataset, expected, 5, 4, shuffle)
Example #22
 def testFileShardingWithLegacyRebatch(self):
   # Tests that RebatchDatasetV1 is a passthrough op.
   self._setUpFiles(num_files=5, num_records_per_file=10)
   dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
   dataset = dataset.apply(
       testing.assert_next(["Shard", "FlatMap", "Batch", "Rebatch"]))
   dataset = dataset.flat_map(core_readers.TFRecordDataset)
   dataset = dataset.batch(5)
   dataset = distribute._LegacyRebatchDataset(dataset, num_replicas=5)
   dataset = distribute._AutoShardDataset(dataset, 5, 3)
   expected = [[self._record(3, i)] for i in range(10)]
   self.assertDatasetProduces(dataset, expected)
Example #23
 def testFileShardingWithRebatch(self):
   # Tests that RebatchDatasetV2 is a passthrough op.
   self._setUpFiles(num_files=3, num_records_per_file=5)
   dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
   dataset = dataset.apply(
       testing.assert_next(["Shard", "FlatMap", "Batch", "Rebatch"]))
   dataset = dataset.flat_map(core_readers.TFRecordDataset)
   dataset = dataset.batch(5)
   dataset = distribute._RebatchDataset(dataset, batch_sizes=[2, 1, 2])
   dataset = distribute._AutoShardDataset(dataset, 3, 1)
   expected = [[self._record(1, 0), self._record(1, 1)], [self._record(1, 2)],
               [self._record(1, 3), self._record(1, 4)]]
   self.assertDatasetProduces(dataset, expected)
  def testFlatMapReaderPipeline(self, shuffle):
    dataset = dataset_ops.Dataset.list_files(
        self.test_filenames, shuffle=shuffle)
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.batch(5)
    dataset = distribute._AutoShardDataset(dataset, 5, 3)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in (3, 8)
        for r in range(0, 10)
    ]
    self.assertDatasetProducesWithShuffle(dataset, expected, 5, 4, shuffle)
    def testUnknownOpInPipelineStillShardsAtTheEnd(self):
        dataset = dataset_ops.Dataset.list_files(self._filenames,
                                                 shuffle=False)
        dataset = dataset.flat_map(core_readers.TFRecordDataset)
        dataset = dataset.apply(unique.unique())

        dataset = distribute._AutoShardDataset(dataset, 5, 0)

        expected = [
            b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
            for f in range(0, 10) for r in (0, 5)
        ]
        self.assertDatasetProduces(dataset, expected)
Example #26
  def testPrivateThreadpool(self):
    dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.batch(5)
    dataset = dataset_ops._PrivateThreadPoolDataset(dataset, 1)
    dataset = distribute._AutoShardDataset(dataset, 5, 0)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in (0, 5)
        for r in range(0, 10)
    ]
    self.assertDatasetProduces(dataset, list(chunk(expected, 5)))
    def testFileShardingWithoutReaderDatasetOp(self):
        options = options_lib.Options()
        options.experimental_distribute.auto_shard_policy = (
            options_lib.AutoShardPolicy.FILE)

        dataset = dataset_ops.Dataset.range(1024)
        dataset = dataset.with_options(options)

        # We are specifying that we want a file sharding policy, and this pipeline
        # doesn't start with file reading, so we should error out.
        with self.assertRaises(errors.NotFoundError):
            dataset = distribute._AutoShardDataset(dataset, 10, 0)
            self.evaluate(self.getNext(dataset)())
    def testAssertCardinality(self):
        dataset = dataset_ops.Dataset.list_files(self._filenames,
                                                 shuffle=False)
        dataset = dataset.flat_map(core_readers.TFRecordDataset)
        dataset = dataset.batch(5)
        dataset = dataset.apply(cardinality.assert_cardinality(42))
        dataset = distribute._AutoShardDataset(dataset, 5, 0)

        expected = [
            b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
            for f in (0, 5) for r in range(0, 10)
        ]
        self.assertDatasetProduces(dataset, list(chunk(expected, 5)))
  def testUnknownOpInPipelineStillShardsAtTheEnd(self):
    dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.apply(unique.unique())

    dataset = distribute._AutoShardDataset(dataset, 5, 0)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in (0, 5)
    ]
    self.assertDatasetProduces(dataset, expected)
    def testShardInputToInterleave(self):
        file1 = self._writeFile("f0", [1, 2, 3])
        file2 = self._writeFile("f1", [4, 5, 6])
        file3 = self._writeFile("f2", [7, 8, 9])
        dataset = dataset_ops.Dataset.from_tensor_slices([file1, file2, file3])
        dataset = dataset.interleave(core_readers.TFRecordDataset,
                                     cycle_length=3)
        dataset = distribute._AutoShardDataset(dataset, 2, 0)

        # Sharding by file will interleave files 0 and 2
        expected = [str.encode(str(i)) for i in [1, 7, 2, 8, 3, 9]]
        actual = self.getDatasetOutput(dataset)
        self.assertEqual(actual, expected)
    def testShardInputToInterleaveWithIdentityFunction(self):
        self.skipTest("Currently fails due to b/238645949")
        file1 = self._writeFile("f0", [1, 2, 3])
        file2 = self._writeFile("f1", [4, 5, 6])
        file3 = self._writeFile("f2", [7, 8, 9])
        dataset = dataset_ops.Dataset.from_tensor_slices([file1, file2, file3])
        dataset = dataset.map(core_readers.TFRecordDataset)
        dataset = dataset.interleave(lambda x: x, cycle_length=3)
        dataset = distribute._AutoShardDataset(dataset, 2, 0)

        # Sharding by file will interleave files 0 and 2
        expected = [str.encode(str(i)) for i in [1, 7, 2, 8, 3, 9]]
        actual = self.getDatasetOutput(dataset)
        self.assertEqual(actual, expected)
    def testTFRecordReaderWithDirectFileNamesAndShapes(self):
        # Using `_TFRecordDataset` creates a raw op rather than automatically
        # wrapping it in a flat_map.
        dataset = core_readers._TFRecordDataset(self._filenames)

        # BatchDataset contains `output_types` and `output_shapes`
        dataset = dataset.batch(5)
        dataset = distribute._AutoShardDataset(dataset, 2, 0)

        expected = [
            b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
            for f in range(0, 10) for r in range(0, 5)
        ]
        self.assertDatasetProduces(dataset, list(chunk(expected, 5)))
  def testSampleResNetPipeline(self, shuffle):
    dataset = dataset_ops.Dataset.list_files(
        self.test_filenames, shuffle=shuffle)
    dataset = dataset.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    dataset = dataset.batch(5)
    dataset = distribute._AutoShardDataset(dataset, 5, 3)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for r in range(0, 10)
        for f in (3, 8)
    ]
    self.assertDatasetProducesWithShuffle(dataset, expected, 5, 4, shuffle)
    def testSampleResNetPipeline(self, shuffle):
        dataset = dataset_ops.Dataset.list_files(self._filenames,
                                                 shuffle=shuffle)
        dataset = dataset.apply(
            interleave_ops.parallel_interleave(core_readers.TFRecordDataset,
                                               10))
        dataset = dataset.batch(5)
        dataset = distribute._AutoShardDataset(dataset, 5, 3)

        expected = [
            b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
            for r in range(0, 10) for f in (3, 8)
        ]
        self.assertDatasetProducesWithShuffle(dataset, expected, 5, 4, shuffle)
    def testAutoshardPolicyOff(self):
        options = options_lib.Options()
        options.experimental_distribute.auto_shard_policy = (
            options_lib.AutoShardPolicy.OFF)

        dataset = core_readers._TFRecordDataset(self._filenames)
        dataset = dataset.with_options(options)
        dataset = distribute._AutoShardDataset(dataset, 5, 0)

        # Should return every record in every file since autosharding is turned off.
        expected = [
            b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
            for f in range(0, 10) for r in range(0, 10)
        ]
        self.assertDatasetProduces(dataset, expected)
  def testTFRecordReaderWithDirectFileNamesAndShapes(self):
    # Using `_TFRecordDataset` creates a raw op rather than automatically
    # wrapping it in a flat_map.
    dataset = core_readers._TFRecordDataset(self.test_filenames)

    # BatchDataset contains `output_types` and `output_shapes`
    dataset = dataset.batch(5)
    dataset = distribute._AutoShardDataset(dataset, 2, 0)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in range(0, 5)
    ]
    self.assertDatasetProduces(dataset, list(chunk(expected, 5)))
    def testPipelineWithMap(self):
        dataset = dataset_ops.Dataset.list_files(self.test_filenames,
                                                 shuffle=True)
        dataset = dataset.apply(
            interleave_ops.parallel_interleave(core_readers.TFRecordDataset,
                                               10))
        dataset = dataset.map(lambda x: string_ops.substr_v2(x, 2, 1000))
        dataset = dataset.batch(5)
        dataset = distribute._AutoShardDataset(dataset, 5, 3)

        expected = [
            b"cord %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
            for r in range(0, 10) for f in (3, 8)
        ]
        self.assertDatasetProduces(dataset, list(chunk(expected, 5)))
    def testWorkersGreaterThanNumFilesWithDataSharding(self):
        options = options_lib.Options()
        options.experimental_distribute.auto_shard_policy = (
            options_lib.AutoShardPolicy.DATA)

        dataset = core_readers._TFRecordDataset(self._filenames)
        dataset = dataset.with_options(options)
        dataset = distribute._AutoShardDataset(dataset, 5, 0)

        # Should return "Record (0,5) of file (0 --> 9)" since we are sharding by
        # individual elements, we should be able to get some data from all files.
        expected = [
            b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
            for f in range(0, 10) for r in (0, 5)
        ]
        self.assertDatasetProduces(dataset, expected)
Example #39
  def testValidPipelineWithRangeDataset(self, shuffle):
    dataset = dataset_ops.Dataset.range(self._num_files)
    dataset = dataset.map(lambda n: string_ops.string_join(  # pylint:disable=g-long-lambda
        [self.get_temp_dir(),
         string_ops.string_format("/tf_record.{}.txt", [n])]))
    dataset = dataset.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    dataset = dataset.map(lambda x: string_ops.substr_v2(x, 2, 1000))
    dataset = dataset.batch(5)
    dataset = distribute._AutoShardDataset(dataset, 5, 3)

    expected = [
        b"cord %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for r in range(0, 10)
        for f in (3, 8)
    ]
    self.assertDatasetProducesWithShuffle(dataset, expected, 5, 4, shuffle)
  def testZipReaderPipeline(self):
    dataset1 = dataset_ops.Dataset.list_files(
        self.test_filenames, shuffle=False)
    dataset1 = dataset1.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    dataset2 = dataset_ops.Dataset.list_files(
        self.test_filenames, shuffle=False)
    dataset2 = dataset2.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))

    dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
    dataset = distribute._AutoShardDataset(dataset, 5, 3)

    expected = [
        (b"Record %d of file %d" % (r, f), b"Record %d of file %d" % (r, f))  # pylint:disable=g-complex-comprehension
        for r in range(0, 10)
        for f in (3, 8)
    ]

    self.assertDatasetProduces(dataset, expected)
Example #42
def auto_shard_dataset(dataset, num_shards, index):
  """Shard the input pipeline by sharding the underlying list of files.

  Args:
    dataset: A `tf.data.Dataset` instance, typically the result of a bunch of
      dataset transformations.
    num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of
        shards operating in parallel. Same usage as in `tf.data.Dataset.shard`.
    index: A `tf.int64` scalar `tf.Tensor`, representing the worker index.
      Same usage as in `tf.data.Dataset.shard`.

  Returns:
    A modified `Dataset` obtained by updating the pipeline sharded by the
    files. The input dataset will be returned if we cannot automatically
    determine a good way to shard the input dataset.
  """
  if isinstance(dataset, dataset_ops.DatasetV1):
    return distribute._AutoShardDatasetV1(dataset, num_shards, index)
  else:
    return distribute._AutoShardDataset(dataset, num_shards, index)
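A similarly hedged sketch for this two-argument variant (the filenames are placeholders, module aliases as in the earlier snippets); the num_shards/index arguments follow the same semantics as tf.data.Dataset.shard.

# Hypothetical file list; a real pipeline would derive it from a glob or config.
filenames = ["/tmp/data/part-00000", "/tmp/data/part-00001"]
dataset = core_readers.TFRecordDataset(filenames).batch(16)

# Shard for worker 1 of 2; same num_shards/index usage as tf.data.Dataset.shard.
dataset = auto_shard_dataset(dataset, num_shards=2, index=1)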
 def testStandardReaderPipeline(self, num_epochs, index, batch_size,
                                parallel_reads):
   dataset = readers.make_tf_record_dataset(
       file_pattern=self.test_filenames,
       num_epochs=num_epochs,
       batch_size=batch_size,
       parser_fn=None,
       num_parallel_reads=parallel_reads,
       drop_final_batch=True,
       shuffle=False)
   dataset = distribute._AutoShardDataset(dataset, 2, index)
   outputs = self.getNext(dataset)
   self._verify_records(
       outputs,
       batch_size=batch_size,
       file_index=[i for i in range(index, self._num_records, 2)],
       num_epochs=num_epochs,
       interleave_cycle_length=parallel_reads,
       drop_final_batch=True,
       use_parser_fn=None)
   with self.assertRaises(errors.OutOfRangeError):
     self.evaluate(outputs())
 def testNoReaderPipelines(self):
   dataset = dataset_ops.Dataset.range(1024)
   dataset = distribute._AutoShardDataset(dataset, 2, 0)
   self.assertDatasetProduces(dataset, [i for i in range(1024) if i % 2 == 0])
 def testNoReaderPipelines(self):
   dataset = dataset_ops.Dataset.range(1024)
   with self.assertRaises(errors.NotFoundError):
     dataset = distribute._AutoShardDataset(dataset, 2, 0)
     self.evaluate(self.getNext(dataset)())
 def testShardOutOfRangeEmptyDataset(self):
   dataset = dataset_ops.Dataset.range(0)
   with self.assertRaises(errors.OutOfRangeError):
     dataset = distribute._AutoShardDataset(dataset, 10, 0)
     self.evaluate(self.getNext(dataset)())
 def testShardOutOfRange(self):
   dataset = dataset_ops.Dataset.range(5)
   with self.assertRaises(errors.InvalidArgumentError):
     dataset = distribute._AutoShardDataset(dataset, 10, 0)
     self.evaluate(self.getNext(dataset)())