def testMultipleBatches(self, drop_remainder):
        """Rebatching nested batches splits only the outermost batch dimension.

        `range(128)` is batched by 4 and then by 8, giving elements of shape
        [8, 4]. Rebatching for 4 workers must yield elements of shape [2, 4]
        (both dimensions unknown when `drop_remainder` is False).
        """
        dataset = dataset_ops.Dataset.range(128).batch(
            4, drop_remainder=drop_remainder)
        dataset = dataset.batch(8, drop_remainder=drop_remainder)
        # `Dataset.output_shapes` is deprecated; the functional accessor is the
        # supported way to read a dataset's legacy static shapes.
        self.assertEqual(
            [[8, 4]] if drop_remainder else [[None, None]],
            [ts.as_list() for ts in nest.flatten(
                dataset_ops.get_legacy_output_shapes(dataset))])
        # Each element is a list of 8 elements where each element is a list of 4.
        expected_output = [
            [
                [j, j + 1, j + 2, j + 3]  # pylint: disable=g-complex-comprehension
                for j in range(i, i + 32, 4)
            ]  # generates 8 elements
            for i in range(0, 128, 32)
        ]
        self.assertDatasetProduces(dataset, expected_output)

        rebatched_dataset = batching._RebatchDataset(dataset, 4)
        self.assertEqual([[2, 4]] if drop_remainder else [[None, None]], [
            ts.as_list()
            for ts in nest.flatten(
                dataset_ops.get_legacy_output_shapes(rebatched_dataset))
        ])
        # Each element is a list of 2 elements where each element is a list of 4.
        expected_output = [
            [
                [j, j + 1, j + 2, j + 3]  # pylint: disable=g-complex-comprehension
                for j in range(i, i + 8, 4)
            ]  # generates 2 elements
            for i in range(0, 128, 8)
        ]
        self.assertDatasetProduces(rebatched_dataset, expected_output)
 def testPaddedBatch(self, drop_remainder):
     """Rebatching a `padded_batch` splits only the outer batch dimension.

     Batches of 4 are padded to length 5 and grouped 8 at a time, giving
     elements of shape [8, 5]. Rebatching for 4 workers yields [2, 5]. The
     padded dimension stays static even when `drop_remainder` is False.
     """
     dataset = dataset_ops.Dataset.range(128).batch(4).padded_batch(
         8, padded_shapes=[5], drop_remainder=drop_remainder)
     rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)
     # `Dataset.output_shapes` is deprecated; use the functional accessor.
     self.assertEqual(
         [[8, 5]] if drop_remainder else [[None, 5]],
         [ts.as_list() for ts in nest.flatten(
             dataset_ops.get_legacy_output_shapes(dataset))])
     # Each element is a list of 8 elements in which each element is a list of 5
     # elements, first four are numbers and the last one is a padded zero.
     expected_output = [
         [
             [j, j + 1, j + 2, j + 3, 0]  # pylint: disable=g-complex-comprehension
             for j in range(i, i + 32, 4)
         ]  # generates 8 elements
         for i in range(0, 128, 32)
     ]
     self.assertDatasetProduces(dataset, expected_output)
     self.assertEqual([[2, 5]] if drop_remainder else [[None, 5]], [
         ts.as_list()
         for ts in nest.flatten(
             dataset_ops.get_legacy_output_shapes(rebatched_dataset))
     ])
     # Each element is a list of 2 elements in which each element is a list of 5
     # elements, first four are numbers and the last one is a padded zero.
     expected_output = [
         [
             [j, j + 1, j + 2, j + 3, 0]  # pylint: disable=g-complex-comprehension
             for j in range(i, i + 8, 4)
         ]  # generates 2 elements
         for i in range(0, 128, 8)
     ]
     self.assertDatasetProduces(rebatched_dataset, expected_output)
    def testParallelInterleaveBatching(self, drop_remainder):
        """Rebatching reaches batches produced inside a parallel interleave.

        Each of the two interleaved inputs yields one batch holding range(32).
        Rebatching for 4 workers splits every such batch into four batches
        of 8.
        """
        dataset = dataset_ops.Dataset.range(2).interleave(
            lambda _: dataset_ops.Dataset.range(32).batch(  # pylint: disable=g-long-lambda
                32,
                drop_remainder=drop_remainder),
            cycle_length=2,
            num_parallel_calls=2)
        # `Dataset.output_shapes` is deprecated; use the functional accessor.
        self.assertEqual(
            [[32 if drop_remainder else None]],
            [ts.as_list() for ts in nest.flatten(
                dataset_ops.get_legacy_output_shapes(dataset))])
        # Two elements where each element is range(32)
        expected_output = [[k for k in range(32)] for _ in range(2)]  # pylint: disable=g-complex-comprehension
        self.assertDatasetProduces(dataset, expected_output)

        rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)
        self.assertEqual([[8 if drop_remainder else None]], [
            ts.as_list()
            for ts in nest.flatten(
                dataset_ops.get_legacy_output_shapes(rebatched_dataset))
        ])
        # Eight elements of 8 numbers each. Each input's range(32) is split into
        # four chunks and the two inputs are interleaved, so every chunk appears
        # twice in a row: [0..7], [0..7], [8..15], [8..15], ..., [24..31].
        expected_output = [
            [k for k in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
            for i in range(0, 32, 8)  # generates 4 chunks
            for _ in range(2)  # each chunk once per interleaved input
        ]
        self.assertDatasetProduces(rebatched_dataset, expected_output)
 def testUnsupportedTransformError(self, drop_remainder):
   """Rebatching through a `scan` transformation must raise an error.

   Building and iterating the rebatched dataset both run inside the
   assertion context, so an InvalidArgumentError from either step passes.
   """
   batched = dataset_ops.Dataset.range(1024).batch(
       32, drop_remainder=drop_remainder)
   dataset = batched.apply(scan_ops.scan([0], lambda _, a: ([0], a)))
   with self.assertRaises(errors.InvalidArgumentError):
     rebatched = batching._RebatchDataset(dataset, num_workers=4)
     get_next = self.getNext(rebatched)
     self.evaluate(get_next())
示例#5
0
 def testUnsupportedTransformError(self, drop_remainder):
   """A `scan` after the batch cannot be rebatched; expect an error.

   Both construction and evaluation happen under `assertRaises`.
   """
   scan_transform = scan_ops.scan([0], lambda _, a: ([0], a))
   dataset = (dataset_ops.Dataset.range(1024)
              .batch(32, drop_remainder=drop_remainder)
              .apply(scan_transform))
   with self.assertRaises(errors.InvalidArgumentError):
     next_fn = self.getNext(
         batching._RebatchDataset(dataset, num_workers=4))
     self.evaluate(next_fn())
 def testNotDivisibleError(self, drop_remainder):
   """Rebatching must reject a batch size not divisible by the worker count.

   A batch of 32 cannot be split evenly across 5 workers, so an
   InvalidArgumentError whose message contains "not divisible by" is expected.
   """
   dataset = dataset_ops.Dataset.range(1024).batch(
       32, drop_remainder=drop_remainder)
   # `assertRaisesRegexp` is a deprecated alias of `assertRaisesRegex` and was
   # removed from `unittest` in Python 3.12.
   with self.assertRaisesRegex(errors.InvalidArgumentError,
                               "not divisible by"):
     rebatched_dataset = batching._RebatchDataset(dataset, num_workers=5)
     next_element = self.getNext(rebatched_dataset)
     self.evaluate(next_element())
 def testNotDivisibleError(self, drop_remainder):
     """Rebatching must reject a batch size not divisible by the worker count.

     A batch of 32 cannot be split evenly across 5 workers, so an
     InvalidArgumentError whose message contains "not divisible by" is
     expected.
     """
     dataset = dataset_ops.Dataset.range(1024).batch(
         32, drop_remainder=drop_remainder)
     # `assertRaisesRegexp` is a deprecated alias of `assertRaisesRegex` and
     # was removed from `unittest` in Python 3.12.
     with self.assertRaisesRegex(errors.InvalidArgumentError,
                                 "not divisible by"):
         rebatched_dataset = batching._RebatchDataset(dataset,
                                                      num_workers=5)
         next_element = self.getNext(rebatched_dataset)
         self.evaluate(next_element())
 def testTupleOutput(self, drop_remainder):
   """Both components of a tuple element are rebatched in lockstep."""
   pairs = dataset_ops.Dataset.range(1024).map(lambda x: (x, x))
   dataset = pairs.batch(32, drop_remainder=drop_remainder)
   rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)
   expected_output = []
   for start in range(0, 1024, 8):
     chunk = list(range(start, start + 8))
     expected_output.append((chunk, list(chunk)))
   self.assertDatasetProduces(rebatched_dataset, expected_output)
 def testNestedDictionaryOutput(self, drop_remainder):
   """Every leaf of a nested-dict element is rebatched consistently."""
   dataset = dataset_ops.Dataset.range(1024).map(
       lambda x: {"a": x, "b": {"c": x}})
   dataset = dataset.batch(32, drop_remainder=drop_remainder)
   rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)
   expected_output = []
   for start in range(0, 1024, 8):
     chunk = list(range(start, start + 8))
     expected_output.append({"a": chunk, "b": {"c": list(chunk)}})
   self.assertDatasetProduces(rebatched_dataset, expected_output)
示例#10
0
 def testTupleOutput(self, drop_remainder):
   """A (x, x) tuple dataset is split into matching 8-element halves."""
   source = dataset_ops.Dataset.range(1024)
   dataset = source.map(lambda x: (x, x)).batch(
       32, drop_remainder=drop_remainder)
   rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)
   starts = range(0, 1024, 8)
   expected_output = [
       (list(range(s, s + 8)), list(range(s, s + 8))) for s in starts
   ]
   self.assertDatasetProduces(rebatched_dataset, expected_output)
示例#11
0
 def testNestedDictionaryOutput(self, drop_remainder):
   """Nested dict leaves "a" and "b"/"c" are each split into batches of 8."""
   source = dataset_ops.Dataset.range(1024)
   mapped = source.map(lambda x: {"a": x, "b": {"c": x}})
   dataset = mapped.batch(32, drop_remainder=drop_remainder)
   rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)
   expected_output = [
       {"a": list(range(s, s + 8)), "b": {"c": list(range(s, s + 8))}}
       for s in range(0, 1024, 8)
   ]
   self.assertDatasetProduces(rebatched_dataset, expected_output)
示例#12
0
    def __init__(self,
                 dataset,
                 input_workers,
                 split_batch_by=None,
                 input_context=None,
                 **kwargs):
        """Distribute the dataset on all workers.

        When `split_batch_by` is set, the dataset is first wrapped in
        `batching._RebatchDataset` so that each batch is "split" by that
        value. The resulting dataset is then sharded. In the between-graph
        case this uses `input_context`. In the in-graph case the dataset is
        cloned onto every worker device and each clone is auto-sharded.

        Args:
          dataset: `tf.data.Dataset` that will be used as the input source.
          input_workers: an `InputWorkers` object.
          split_batch_by: Optional integer. If present, we "split" each batch
            of the dataset by `split_batch_by` value.
          input_context: `InputContext` for sharding. Only pass this in for
            between graph multi-worker cases where there is only one
            `input_worker`. In these cases, we will shard based on the
            `input_pipeline_id` and `num_input_pipelines` in the
            `InputContext`.
          **kwargs: Additional experimental flags. Will be removed in future.
        """
        assert isinstance(input_workers, InputWorkers)
        if split_batch_by:
            dataset = batching._RebatchDataset(dataset, split_batch_by)  # pylint: disable=protected-access

        self._cloned_datasets = []
        if input_context:
            # Between-graph: a single local worker keeps only the shard picked
            # by the input context.
            assert input_workers.num_workers == 1
            dataset = input_ops.auto_shard_dataset(
                dataset, input_context.num_input_pipelines,
                input_context.input_pipeline_id)
            self._cloned_datasets.append(dataset)
        else:
            # In-graph: each worker device gets its own copy of the pipeline.
            # Auto-sharding tries to split by input files so each worker sees
            # a different subset. Otherwise every worker runs the whole
            # pipeline and keeps only its own shard of the final output.
            worker_devices = input_workers.worker_devices
            num_shards = len(worker_devices)
            for shard_index, device in enumerate(worker_devices):
                with ops.device(device):
                    if context.executing_eagerly():
                        replica = dataset
                    else:
                        # Outside eager mode, clone the pipeline and keep the
                        # original dataset's options on the clone.
                        replica = input_ops._clone_dataset(dataset).with_options(  # pylint: disable=protected-access
                            dataset.options())
                    # TODO(b/129506833): Figure out between graph cases
                    self._cloned_datasets.append(
                        input_ops.auto_shard_dataset(
                            replica, num_shards, shard_index))

        self._input_workers = input_workers
        # TODO(anjalisridhar): Identify if we need to set this property on the
        # iterator.
        self._element_structure = dataset._element_structure  # pylint: disable=protected-access
        self._kwargs = kwargs
示例#13
0
  def testBasic(self):
    """Rebatching batches of 32 for 4 workers yields batches of 8."""
    dataset = dataset_ops.Dataset.range(1024).batch(32, drop_remainder=True)
    rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)
    # `Dataset.output_shapes` is deprecated; use the functional accessor.
    self.assertEqual(
        [[32]],
        [ts.as_list() for ts in nest.flatten(
            dataset_ops.get_legacy_output_shapes(dataset))])
    self.assertEqual(
        [[8]],
        [ts.as_list() for ts in nest.flatten(
            dataset_ops.get_legacy_output_shapes(rebatched_dataset))])

    expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)]  # pylint: disable=g-complex-comprehension
    self.assertDatasetProduces(rebatched_dataset, expected_output)
示例#14
0
  def __init__(self, dataset, input_workers, split_batch_by=None,
               input_context=None, **kwargs):
    """Distribute the dataset on all workers.

    When `split_batch_by` is set, the dataset is first wrapped in
    `batching._RebatchDataset` so that each batch is "split" by that value.
    The resulting dataset is then sharded. In the between-graph case this
    uses `input_context`. In the in-graph case the dataset is cloned onto
    every worker device and each clone is auto-sharded.

    Args:
      dataset: `tf.data.Dataset` that will be used as the input source.
      input_workers: an `InputWorkers` object.
      split_batch_by: Optional integer. If present, we "split" each batch of the
        dataset by `split_batch_by` value.
      input_context: `InputContext` for sharding. Only pass this in for between
        graph multi-worker cases where there is only one `input_worker`. In
        these cases, we will shard based on the `input_pipeline_id` and
        `num_input_pipelines` in the `InputContext`.
      **kwargs: Additional experimental flags. Will be removed in future.
    """
    assert isinstance(input_workers, InputWorkers)
    if split_batch_by:
      dataset = batching._RebatchDataset(dataset, split_batch_by)  # pylint: disable=protected-access

    self._cloned_datasets = []
    if input_context:
      # Between-graph: one local worker keeps the shard chosen by the context.
      assert input_workers.num_workers == 1
      dataset = input_ops.auto_shard_dataset(
          dataset, input_context.num_input_pipelines,
          input_context.input_pipeline_id)
      self._cloned_datasets.append(dataset)
    else:
      # In-graph: clone and shard per worker device. Auto-sharding tries to
      # split by input files first. Otherwise each worker runs the entire
      # pipeline and keeps only its own shard of the final output.
      worker_devices = input_workers.worker_devices
      num_shards = len(worker_devices)
      for shard_index, device in enumerate(worker_devices):
        with ops.device(device):
          if context.executing_eagerly():
            replica = dataset
          else:
            # Graph mode: clone the pipeline and carry the options across.
            replica = input_ops._clone_dataset(dataset).with_options(  # pylint: disable=protected-access
                dataset.options())
          # TODO(b/129506833): Figure out between graph cases
          self._cloned_datasets.append(
              input_ops.auto_shard_dataset(replica, num_shards, shard_index))

    self._input_workers = input_workers
    # TODO(anjalisridhar): Identify if we need to set this property on the
    # iterator.
    self._element_structure = dataset._element_structure  # pylint: disable=protected-access
    self._kwargs = kwargs
示例#15
0
  def testFinalPartialBatchOriginal(self, drop_remainder):
    """A trailing partial batch of the input is rebatched like a full one.

    `range(1032)` batched by 32 ends with a partial batch of 8. The expected
    output does not depend on `drop_remainder`: after rebatching for 4 workers
    every element, including the last, holds 8 consecutive values.
    """
    dataset = dataset_ops.Dataset.range(1032).batch(
        32, drop_remainder=drop_remainder)
    rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)
    # `Dataset.output_shapes` is deprecated; use the functional accessor.
    self.assertEqual(
        [[32 if drop_remainder else None]],
        [ts.as_list() for ts in nest.flatten(
            dataset_ops.get_legacy_output_shapes(dataset))])
    self.assertEqual(
        [[8 if drop_remainder else None]],
        [ts.as_list() for ts in nest.flatten(
            dataset_ops.get_legacy_output_shapes(rebatched_dataset))])

    expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1032, 8)]  # pylint: disable=g-complex-comprehension
    self.assertDatasetProduces(rebatched_dataset, expected_output)
  def testFinalPartialBatchOriginal(self, drop_remainder):
    """The input's trailing partial batch still becomes a full batch of 8.

    `range(1032)` batched by 32 leaves 8 trailing values. The expected
    output is the same for either `drop_remainder` setting: 129 batches of 8
    consecutive values.
    """
    dataset = dataset_ops.Dataset.range(1032).batch(
        32, drop_remainder=drop_remainder)
    rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)
    original_dim = 32 if drop_remainder else None
    rebatched_dim = 8 if drop_remainder else None
    self.assertEqual([[original_dim]],
                     [ts.as_list() for ts in _flat_shapes(dataset)])
    self.assertEqual([[rebatched_dim]],
                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])

    expected_output = [list(range(start, start + 8))
                       for start in range(0, 1032, 8)]
    self.assertDatasetProduces(rebatched_dataset, expected_output)
 def testZip(self, drop_remainder):
     """Each component of a zipped element is split across the workers.

     The zip stops with the shorter input (4 batches of 8), so the rebatched
     dataset yields 16 pairs of 2-element batches.
     """
     longer = dataset_ops.Dataset.range(64).batch(
         8, drop_remainder=drop_remainder)
     shorter = dataset_ops.Dataset.range(32).batch(
         8, drop_remainder=drop_remainder)
     dataset = dataset_ops.Dataset.zip((longer, shorter))
     rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)
     if drop_remainder:
         expected_in, expected_out = [[8], [8]], [[2], [2]]
     else:
         expected_in, expected_out = [[None], [None]], [[None], [None]]
     self.assertEqual(expected_in,
                      [ts.as_list() for ts in _flat_shapes(dataset)])
     self.assertEqual(
         expected_out,
         [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
     expected_output = [(list(range(i, i + 2)), list(range(i, i + 2)))
                        for i in range(0, 32, 2)]
     self.assertDatasetProduces(rebatched_dataset, expected_output)
 def testMapAndBatch(self, drop_remainder):
   """Rebatching also applies to the fused `map_and_batch` transformation.

   Values are squared by `math_ops.square` and batched by 32. After
   rebatching for 4 workers, each element holds 8 consecutive squares.
   """
   source = dataset_ops.Dataset.range(1024)
   dataset = source.apply(
       batching.map_and_batch(
           math_ops.square, 32, drop_remainder=drop_remainder))
   rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)
   batch_dim = 32 if drop_remainder else None
   rebatched_dim = 8 if drop_remainder else None
   self.assertEqual([[batch_dim]],
                    [ts.as_list() for ts in _flat_shapes(dataset)])
   self.assertEqual([[rebatched_dim]],
                    [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
   expected_output = [[k * k for k in range(start, start + 8)]  # pylint: disable=g-complex-comprehension
                      for start in range(0, 1024, 8)]
   self.assertDatasetProduces(rebatched_dataset, expected_output)
示例#19
0
 def testMapAndBatch(self, drop_remainder):
   """Rebatching also applies to the fused `map_and_batch` transformation.

   Values are squared and batched by 32. After rebatching for 4 workers,
   each element holds 8 consecutive squares.
   """
   dataset = dataset_ops.Dataset.range(1024).apply(
       batching.map_and_batch(
           math_ops.square, 32, drop_remainder=drop_remainder))
   rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)
   # `Dataset.output_shapes` is deprecated; use the functional accessor.
   self.assertEqual(
       [[32 if drop_remainder else None]],
       [ts.as_list() for ts in nest.flatten(
           dataset_ops.get_legacy_output_shapes(dataset))])
   self.assertEqual(
       [[8 if drop_remainder else None]],
       [ts.as_list() for ts in nest.flatten(
           dataset_ops.get_legacy_output_shapes(rebatched_dataset))])
   expected_output = [[k**2 for k in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
                      for i in range(0, 1024, 8)]
   self.assertDatasetProduces(rebatched_dataset, expected_output)
 def testZip(self, drop_remainder):
   """Zipping two batch-of-8 datasets rebatches both components to 2."""
   def _batched_range(limit):
     # Both zip inputs share the same batch size and remainder policy.
     return dataset_ops.Dataset.range(limit).batch(
         8, drop_remainder=drop_remainder)

   dataset = dataset_ops.Dataset.zip((_batched_range(64), _batched_range(32)))
   rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)
   original_shapes = [ts.as_list() for ts in _flat_shapes(dataset)]
   self.assertEqual(
       [[8], [8]] if drop_remainder else [[None], [None]], original_shapes)
   rebatched_shapes = [ts.as_list() for ts in _flat_shapes(rebatched_dataset)]
   self.assertEqual(
       [[2], [2]] if drop_remainder else [[None], [None]], rebatched_shapes)
   expected_output = []
   for start in range(0, 32, 2):
     expected_output.append(([start, start + 1], [start, start + 1]))
   self.assertDatasetProduces(rebatched_dataset, expected_output)
    def testFinalPartialBatchAfterRebatch(self, drop_remainder):
        """A partial batch left over after rebatching honors `drop_remainder`.

        `range(34)` batched by 32 leaves [32, 33] as trailing values. After
        rebatching for 4 workers they are emitted only when `drop_remainder`
        is False.
        """
        dataset = dataset_ops.Dataset.range(34).batch(
            32, drop_remainder=drop_remainder)
        rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)
        full_dim = 32 if drop_remainder else None
        split_dim = 8 if drop_remainder else None
        self.assertEqual([[full_dim]],
                         [ts.as_list() for ts in _flat_shapes(dataset)])
        self.assertEqual(
            [[split_dim]],
            [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])

        expected_output = [list(range(start, start + 8))
                           for start in range(0, 32, 8)]
        if not drop_remainder:
            expected_output.append([32, 33])
        self.assertDatasetProduces(rebatched_dataset, expected_output)
 def testConcatenate(self, drop_remainder):
     """Rebatching splits every batch of a concatenation of equal batch sizes."""
     head = dataset_ops.Dataset.range(64).batch(
         8, drop_remainder=drop_remainder)
     tail = dataset_ops.Dataset.range(32).batch(
         8, drop_remainder=drop_remainder)
     dataset = head.concatenate(tail)
     rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)
     input_dim = 8 if drop_remainder else None
     output_dim = 2 if drop_remainder else None
     self.assertEqual([[input_dim]],
                      [ts.as_list() for ts in _flat_shapes(dataset)])
     self.assertEqual(
         [[output_dim]],
         [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
     # Pairs from range(64) followed by pairs from range(32).
     expected_output = [[start, start + 1]
                        for limit in (64, 32)
                        for start in range(0, limit, 2)]
     self.assertDatasetProduces(rebatched_dataset, expected_output)
 def testConcatenateDifferentShapes(self, drop_remainder):
   """Concatenating batches of 16 and 8 leaves the batch dimension unknown.

   Each batch is still split across 4 workers. Batches of 16 become batches
   of 4, and batches of 8 become batches of 2.
   """
   sixteens = dataset_ops.Dataset.range(64).batch(
       16, drop_remainder=drop_remainder)
   eights = dataset_ops.Dataset.range(32).batch(
       8, drop_remainder=drop_remainder)
   dataset = sixteens.concatenate(eights)
   rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)
   unknown_batch = [[None]]
   self.assertEqual(
       unknown_batch, [ts.as_list() for ts in _flat_shapes(dataset)])
   self.assertEqual(
       unknown_batch,
       [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
   expected_output = (
       [list(range(start, start + 4)) for start in range(0, 64, 4)] +
       [list(range(start, start + 2)) for start in range(0, 32, 2)])
   self.assertDatasetProduces(rebatched_dataset, expected_output)
示例#24
0
 def testConcatenateDifferentShapes(self, drop_remainder):
   """Concatenating batches of 16 and 8 leaves the batch dimension unknown.

   Each batch is still split across 4 workers. Batches of 16 become batches
   of 4, and batches of 8 become batches of 2.
   """
   dataset1 = dataset_ops.Dataset.range(64).batch(
       16, drop_remainder=drop_remainder)
   dataset2 = dataset_ops.Dataset.range(32).batch(
       8, drop_remainder=drop_remainder)
   dataset = dataset1.concatenate(dataset2)
   rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)
   # `Dataset.output_shapes` is deprecated; use the functional accessor.
   self.assertEqual(
       [[None]],
       [ts.as_list() for ts in nest.flatten(
           dataset_ops.get_legacy_output_shapes(dataset))])
   self.assertEqual(
       [[None]],
       [ts.as_list() for ts in nest.flatten(
           dataset_ops.get_legacy_output_shapes(rebatched_dataset))])
   expected_output = ([[i, i + 1, i + 2, i + 3] for i in range(0, 64, 4)] +
                      [[i, i + 1] for i in range(0, 32, 2)])
   self.assertDatasetProduces(rebatched_dataset, expected_output)
示例#25
0
    def testBasic(self):
        """Rebatching batches of 32 for 4 workers yields batches of 8."""
        dataset = dataset_ops.Dataset.range(1024).batch(32,
                                                        drop_remainder=True)
        rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)
        # `Dataset.output_shapes` is deprecated; use the functional accessor.
        self.assertEqual(
            [[32]],
            [ts.as_list() for ts in nest.flatten(
                dataset_ops.get_legacy_output_shapes(dataset))])
        self.assertEqual([[8]], [
            ts.as_list()
            for ts in nest.flatten(
                dataset_ops.get_legacy_output_shapes(rebatched_dataset))
        ])

        expected_output = [[k for k in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
                           for i in range(0, 1024, 8)]
        self.assertDatasetProduces(rebatched_dataset, expected_output)
示例#26
0
    def __init__(self, dataset, input_workers, split_batch_by=None, **kwargs):
        """Make an iterator for the dataset on given devices.

        If `split_batch_by` is not None, the dataset is wrapped in
        `batching._RebatchDataset(dataset, split_batch_by)` so that each batch
        is "split" by that value. The design notes describe this as producing
        a per replica batch size of `global_batch_size // split_batch_by`, and
        list the supported datasets as those whose last operations are
        `batch()`, `apply(map_and_batch)`, `batch().prefetch()` or
        `apply(map_and_batch).prefetch()`. NOTE(review): confirm that
        restriction still holds for `_RebatchDataset`.

        One `_SingleWorkerDatasetIterator` is created per worker device. Outside
        eager mode, each worker gets a clone of the dataset. No sharding is
        applied here.

        TODO(priyag): Support multi worker / host cases properly by cloning
        and sharding the dataset on each worker. Current setup will only work in
        some cases, such as in-graph multi worker GPU case. If the input
        pipeline has random shuffling (with a different seed on each worker),
        each worker will see random input from the same overall dataset in each
        step. Otherwise, each worker will see the same input in each step.

        Args:
          dataset: `tf.data.Dataset` that will be used as the input source.
          input_workers: an `InputWorkers` object.
          split_batch_by: Optional integer. If present, we "split" each batch
            of the dataset by `split_batch_by` value.
          **kwargs: Additional experimental flags. Will be removed in future.
        """
        assert isinstance(input_workers, InputWorkers)
        if split_batch_by:
            dataset = batching._RebatchDataset(dataset, split_batch_by)  # pylint: disable=protected-access

        iterators = []
        for worker_index, worker in enumerate(input_workers.worker_devices):
            with ops.device(worker):
                compute_devices = input_workers.compute_devices_for_worker(
                    worker_index)
                if context.executing_eagerly():
                    worker_dataset = dataset
                else:
                    # Graph mode: clone the pipeline and keep the options.
                    worker_dataset = input_ops._clone_dataset(dataset).with_options(  # pylint: disable=protected-access
                        dataset.options())
                iterators.append(
                    _SingleWorkerDatasetIterator(worker_dataset, worker,
                                                 compute_devices))

        self._element_structure = dataset._element_structure  # pylint: disable=protected-access

        super(DatasetIterator, self).__init__(input_workers, iterators,
                                              **kwargs)
示例#27
0
  def __init__(self, dataset, input_workers, split_batch_by=None, **kwargs):
    """Make an iterator for the dataset on given devices.

    If `split_batch_by` is not None, the dataset is wrapped in
    `batching._RebatchDataset(dataset, split_batch_by)` so that each batch is
    "split" by that value. The design notes describe this as producing a per
    replica batch size of `global_batch_size // split_batch_by`, and list the
    supported datasets as those whose last operations are `batch()`,
    `apply(map_and_batch)`, `batch().prefetch()` or
    `apply(map_and_batch).prefetch()`. NOTE(review): confirm that restriction
    still holds for `_RebatchDataset`.

    The dataset is then cloned (outside eager mode) and auto-sharded for every
    worker device, and one `_SingleWorkerDatasetIterator` is built per worker.
    Auto-sharding tries to shard by files so that each worker sees a different
    subset of files. If that is not possible, each worker runs the entire
    preprocessing pipeline and only receives its own shard of the dataset.

    Args:
      dataset: `tf.data.Dataset` that will be used as the input source.
      input_workers: an `InputWorkers` object.
      split_batch_by: Optional integer. If present, we "split" each batch of the
        dataset by `split_batch_by` value.
      **kwargs: Additional experimental flags. Will be removed in future.
    """
    assert isinstance(input_workers, InputWorkers)
    if split_batch_by:
      dataset = batching._RebatchDataset(dataset, split_batch_by)  # pylint: disable=protected-access

    worker_devices = input_workers.worker_devices
    num_shards = len(worker_devices)
    iterators = []
    for shard_index, worker in enumerate(worker_devices):
      with ops.device(worker):
        compute_devices = input_workers.compute_devices_for_worker(shard_index)
        if context.executing_eagerly():
          worker_dataset = dataset
        else:
          # Graph mode: clone the pipeline and keep the original's options.
          worker_dataset = input_ops._clone_dataset(dataset).with_options(  # pylint: disable=protected-access
              dataset.options())
        # TODO(b/129506833): Figure out between graph cases
        worker_dataset = input_ops.auto_shard_dataset(
            worker_dataset, num_shards, shard_index)
        iterators.append(
            _SingleWorkerDatasetIterator(worker_dataset, worker,
                                         compute_devices))

    self._element_structure = dataset._element_structure  # pylint: disable=protected-access

    super(DatasetIterator, self).__init__(input_workers, iterators, **kwargs)
示例#28
0
 def testZipDifferentShapes(self, drop_remainder):
   """Zipped components with different batch sizes are each split by 4.

   Batches of 16 and 8 become batches of 4 and 2 respectively.
   """
   dataset1 = dataset_ops.Dataset.range(64).batch(
       16, drop_remainder=drop_remainder)
   dataset2 = dataset_ops.Dataset.range(32).batch(
       8, drop_remainder=drop_remainder)
   dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
   rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)
   # `Dataset.output_shapes` is deprecated; use the functional accessor.
   self.assertEqual(
       [[16], [8]] if drop_remainder else [[None], [None]],
       [ts.as_list() for ts in nest.flatten(
           dataset_ops.get_legacy_output_shapes(dataset))])
   self.assertEqual(
       [[4], [2]] if drop_remainder else [[None], [None]],
       [ts.as_list() for ts in nest.flatten(
           dataset_ops.get_legacy_output_shapes(rebatched_dataset))])
   expected_output = [([2 * i, 2 * i + 1, 2 * i + 2, 2 * i + 3], [i, i + 1])
                      for i in range(0, 32, 2)]
   self.assertDatasetProduces(rebatched_dataset, expected_output)
示例#29
0
  def __init__(self, dataset, input_workers, split_batch_by=None, **kwargs):
    """Make an iterator for the dataset on given devices.

    If `split_batch_by` is not None, the dataset is wrapped in
    `batching._RebatchDataset(dataset, split_batch_by)` so that each batch is
    "split" by that value. The design notes describe this as producing a per
    replica batch size of `global_batch_size // split_batch_by`, and list the
    supported datasets as those whose last operations are `batch()`,
    `apply(map_and_batch)`, `batch().prefetch()` or
    `apply(map_and_batch).prefetch()`. NOTE(review): confirm that restriction
    still holds for `_RebatchDataset`.

    One `_SingleWorkerDatasetIterator` is created per worker device. Outside
    eager mode, each worker gets a clone of the dataset. No sharding is
    applied here.

    TODO(priyag): Support multi worker / host cases properly by cloning
    and sharding the dataset on each worker. Current setup will only work in
    some cases, such as in-graph multi worker GPU case. If the input pipeline
    has random shuffling (with a different seed on each worker), each worker
    will see random input from the same overall dataset in each step. Otherwise,
    each worker will see the same input in each step.

    Args:
      dataset: `tf.data.Dataset` that will be used as the input source.
      input_workers: an `InputWorkers` object.
      split_batch_by: Optional integer. If present, we "split" each batch of the
        dataset by `split_batch_by` value.
      **kwargs: Additional experimental flags. Will be removed in future.
    """
    assert isinstance(input_workers, InputWorkers)
    if split_batch_by:
      dataset = batching._RebatchDataset(dataset, split_batch_by)  # pylint: disable=protected-access

    iterators = []
    for worker_index, worker in enumerate(input_workers.worker_devices):
      with ops.device(worker):
        compute_devices = input_workers.compute_devices_for_worker(worker_index)
        if context.executing_eagerly():
          worker_dataset = dataset
        else:
          # Graph mode: clone the pipeline and keep the original's options.
          worker_dataset = input_ops._clone_dataset(dataset).with_options(  # pylint: disable=protected-access
              dataset.options())
        iterators.append(
            _SingleWorkerDatasetIterator(worker_dataset, worker,
                                         compute_devices))

    self._element_structure = dataset._element_structure  # pylint: disable=protected-access

    super(DatasetIterator, self).__init__(input_workers, iterators, **kwargs)
示例#30
0
  def testFlatMapBatching(self, drop_remainder):
    """Rebatching reaches batches produced inside a `flat_map`.

    Each of the two flat_map inputs yields one batch holding range(32).
    Rebatching for 4 workers splits each into four batches of 8.
    """
    dataset = dataset_ops.Dataset.range(
        2).flat_map(lambda _: dataset_ops.Dataset.range(32).batch(  # pylint: disable=g-long-lambda
            32, drop_remainder=drop_remainder))
    # `Dataset.output_shapes` is deprecated; use the functional accessor.
    self.assertEqual(
        [[32 if drop_remainder else None]],
        [ts.as_list() for ts in nest.flatten(
            dataset_ops.get_legacy_output_shapes(dataset))])
    # Two elements where each element is range(32)
    expected_output = [[k for k in range(32)] for _ in range(2)]  # pylint: disable=g-complex-comprehension
    self.assertDatasetProduces(dataset, expected_output)

    rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)
    self.assertEqual(
        [[8 if drop_remainder else None]],
        [ts.as_list() for ts in nest.flatten(
            dataset_ops.get_legacy_output_shapes(rebatched_dataset))])
    # Eight elements of 8 numbers each: range(32) split into four batches of
    # 8, emitted once for each of the two flat_map inputs.
    expected_output = [[k for k in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
                       for _ in range(2)
                       for i in range(0, 32, 8)]  # generates 4 elements
    self.assertDatasetProduces(rebatched_dataset, expected_output)
  def testFlatMapBatching(self, drop_remainder):
    """Rebatching must reach a batch that lives inside a flat_map function.

    The mapped function batches range(32) into batches of 32; rebatching across
    4 workers is expected to shrink each of those batches to 8.
    """
    dataset = dataset_ops.Dataset.range(
        2).flat_map(lambda _: dataset_ops.Dataset.range(32).batch(  # pylint: disable=g-long-lambda
            32, drop_remainder=drop_remainder))
    self.assertEqual(
        [[32 if drop_remainder else None]],
        [ts.as_list() for ts in _flat_shapes(dataset)])
    # Two elements, each of which is range(32).
    expected_output = [list(range(32)) for _ in range(2)]
    self.assertDatasetProduces(dataset, expected_output)

    rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)
    self.assertEqual(
        [[8 if drop_remainder else None]],
        [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
    # Eight elements of 8 values each: range(32) is split into the four chunks
    # 0-7, 8-15, 16-23 and 24-31, and that sequence of chunks is produced once
    # per flat_map input (twice in total).
    expected_output = [list(range(i, i + 8))  # pylint: disable=g-complex-comprehension
                       for _ in range(2)
                       for i in range(0, 32, 8)]  # generates 4 chunks
    self.assertDatasetProduces(rebatched_dataset, expected_output)
  def testInterleaveBatching(self, drop_remainder):
    """Rebatching must reach a batch that lives inside an interleave function.

    Each of the two interleaved inputs batches range(32) into batches of 32;
    rebatching across 4 workers is expected to shrink those batches to 8.
    """
    dataset = dataset_ops.Dataset.range(
        2).interleave(lambda _: dataset_ops.Dataset.range(32).batch(  # pylint: disable=g-long-lambda
            32, drop_remainder=drop_remainder), cycle_length=2)
    self.assertEqual(
        [[32 if drop_remainder else None]],
        [ts.as_list() for ts in _flat_shapes(dataset)])
    # Two elements, each of which is range(32).
    expected_output = [list(range(32)) for _ in range(2)]
    self.assertDatasetProduces(dataset, expected_output)

    rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)
    self.assertEqual(
        [[8 if drop_remainder else None]],
        [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
    # Eight elements of 8 values each: range(32) is split into the four chunks
    # 0-7, 8-15, 16-23 and 24-31, and every chunk appears twice in a row, as
    # interleave alternates between its two inputs (cycle_length=2).
    expected_output = [list(range(i, i + 8))  # pylint: disable=g-complex-comprehension
                       for i in range(0, 32, 8)  # generates 4 chunks
                       for _ in range(2)]  # each chunk once per input
    self.assertDatasetProduces(rebatched_dataset, expected_output)
  def testMultipleBatches(self, drop_remainder):
    """Only the outermost of two stacked batches is split by rebatching.

    range(128) is batched by 4 and then by 8; after rebatching across 4
    workers the outer dimension goes from 8 to 2 while the inner 4 is kept.
    """
    inner_batched = dataset_ops.Dataset.range(128).batch(
        4, drop_remainder=drop_remainder)
    dataset = inner_batched.batch(8, drop_remainder=drop_remainder)

    if drop_remainder:
      shapes_before = [[8, 4]]
    else:
      shapes_before = [[None, None]]
    self.assertEqual(shapes_before,
                     [ts.as_list() for ts in _flat_shapes(dataset)])
    # Four elements; each holds 8 consecutive rows of 4 consecutive numbers.
    outer_batches = [[list(range(row, row + 4))  # pylint: disable=g-complex-comprehension
                      for row in range(start, start + 32, 4)]
                     for start in range(0, 128, 32)]
    self.assertDatasetProduces(dataset, outer_batches)

    rebatched_dataset = batching._RebatchDataset(dataset, 4)
    if drop_remainder:
      shapes_after = [[2, 4]]
    else:
      shapes_after = [[None, None]]
    self.assertEqual(shapes_after,
                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
    # Sixteen elements; each holds 2 consecutive rows of 4 consecutive numbers.
    split_batches = [[list(range(row, row + 4))  # pylint: disable=g-complex-comprehension
                      for row in range(start, start + 8, 4)]
                     for start in range(0, 128, 8)]
    self.assertDatasetProduces(rebatched_dataset, split_batches)
 def testPaddedBatch(self, drop_remainder):
   """Rebatching a padded_batch splits the batch but keeps the padded width.

   Every row holds four consecutive numbers followed by one padded zero, giving
   a row length of 5 both before and after rebatching across 4 workers.
   """
   batched_by_four = dataset_ops.Dataset.range(128).batch(4)
   dataset = batched_by_four.padded_batch(
       8, padded_shapes=[5], drop_remainder=drop_remainder)
   rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)

   def padded_rows(start, stop):
     # Rows [n, n + 1, n + 2, n + 3, 0] for n = start, start + 4, ... < stop.
     return [list(range(n, n + 4)) + [0] for n in range(start, stop, 4)]

   self.assertEqual(
       [[8, 5]] if drop_remainder else [[None, 5]],
       [ts.as_list() for ts in _flat_shapes(dataset)])
   # Four elements of 8 padded rows each.
   self.assertDatasetProduces(
       dataset, [padded_rows(i, i + 32) for i in range(0, 128, 32)])

   self.assertEqual(
       [[2, 5]] if drop_remainder else [[None, 5]],
       [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
   # Sixteen elements of 2 padded rows each.
   self.assertDatasetProduces(
       rebatched_dataset, [padded_rows(i, i + 8) for i in range(0, 128, 8)])
示例#35
0
 def build_dataset(num_elements, batch_size):
   """Returns range(num_elements) batched globally, then rebatched for 4 workers.

   The global batch size is 4 * batch_size with drop_remainder=True, so after
   rebatching across 4 workers each batch should hold batch_size elements.
   """
   workers = 4
   batched = dataset_ops.Dataset.range(num_elements).batch(
       workers * batch_size, drop_remainder=True)
   return batching._RebatchDataset(batched, num_workers=workers)
示例#36
0
 def testUnknownBatchSizeError(self):
   """Rebatching a dataset whose batch size is not statically known fails."""
   # batch() without drop_remainder leaves the batch dimension unknown (None).
   unknown_batch = dataset_ops.Dataset.range(1024).batch(32)
   with self.assertRaisesRegexp(ValueError, "unknown batch size datasets"):
     batching._RebatchDataset(unknown_batch, num_workers=4)
示例#37
0
 def testNotDivisibleError(self):
   """A static batch of 32 cannot be divided evenly among 5 workers."""
   source = dataset_ops.Dataset.range(1024)
   batched = source.batch(32, drop_remainder=True)
   with self.assertRaisesRegexp(ValueError, "not divisible by"):
     batching._RebatchDataset(batched, num_workers=5)
示例#38
0
 def testNotDivisibleError(self, drop_remainder):
   """A static batch of 32 cannot be divided evenly among 5 workers."""
   # TODO(rohanj): This should fail even with drop_remainder=False, by adding
   # an assertion in the mutated graph. Until then the dataset below always
   # uses drop_remainder=True and the `drop_remainder` argument is unused.
   fixed_batch = dataset_ops.Dataset.range(1024).batch(
       32, drop_remainder=True)
   with self.assertRaisesRegexp(ValueError, "not divisible by"):
     batching._RebatchDataset(fixed_batch, num_workers=5)
示例#39
0
 def testScalarInputError(self, _):
   """An unbatched dataset has no dimension to split, so rebatching fails."""
   scalars = dataset_ops.Dataset.range(1024)
   with self.assertRaisesRegexp(ValueError, "at least one dimension"):
     batching._RebatchDataset(scalars, num_workers=4)
示例#40
0
 def testNotDivisibleError(self):
     """Rebatching rejects a batch of 32 that cannot be split over 5 workers."""
     elements = dataset_ops.Dataset.range(1024)
     batched = elements.batch(32, drop_remainder=True)
     with self.assertRaisesRegexp(ValueError, "not divisible by"):
         batching._RebatchDataset(batched, num_workers=5)
示例#41
0
 def testUnknownBatchSizeError(self):
     """Rebatching rejects a dataset with a statically unknown batch size."""
     # No drop_remainder, so the static batch dimension is None.
     elements = dataset_ops.Dataset.range(1024)
     batched = elements.batch(32)
     with self.assertRaisesRegexp(ValueError,
                                  "unknown batch size datasets"):
         batching._RebatchDataset(batched, num_workers=4)
 def testScalarInputError(self, _):
   """Rebatching requires elements with at least one (batch) dimension."""
   unbatched = dataset_ops.Dataset.range(1024)
   with self.assertRaisesRegexp(ValueError, "at least one dimension"):
     batching._RebatchDataset(unbatched, num_workers=4)