Example #1
  def __init__(self,
               dataset,
               input_workers,
               strategy,
               split_batch_by=None,
               input_context=None,
               **kwargs):
    """Distribute the dataset on all workers.

    If `split_batch_by` is not None, we "split" each batch of the dataset by
    the `split_batch_by` value.

    Args:
      dataset: `tf.data.Dataset` that will be used as the input source.
      input_workers: an `InputWorkers` object.
      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle the last partial batch.
      split_batch_by: Optional integer. If present, we "split" each batch of
        the dataset by the `split_batch_by` value.
      input_context: `InputContext` for sharding. Only pass this in for
        between-graph multi-worker cases where there is only one
        `input_worker`. In these cases, we will shard based on the
        `input_pipeline_id` and `num_input_pipelines` in the `InputContext`.
      **kwargs: Additional experimental flags. Will be removed in the future.
    """
    # We clone and shard the dataset on each worker. The current setup tries
    # to shard the dataset by files if possible so that each worker sees a
    # different subset of files. If that is not possible, we attempt to shard
    # the final input so that each worker runs the entire preprocessing
    # pipeline but receives only its own shard of the dataset.
    assert isinstance(input_workers, InputWorkers)
    if split_batch_by:
      dataset = distribute._RebatchDataset(dataset, split_batch_by)  # pylint: disable=protected-access

    self._cloned_datasets = []
    if input_context:
      # Between-graph where we rely on the input_context for sharding
      assert input_workers.num_workers == 1
      dataset = input_ops.auto_shard_dataset(  # pylint: disable=protected-access
          dataset, input_context.num_input_pipelines,
          input_context.input_pipeline_id)
      self._cloned_datasets.append(dataset)
    else:
      for i, worker in enumerate(input_workers.worker_devices):
        with ops.device(worker):
          cloned_dataset = dataset
          if not context.executing_eagerly():
            cloned_dataset = input_ops._clone_dataset(dataset)  # pylint: disable=protected-access
            cloned_dataset = cloned_dataset.with_options(dataset.options())
          # TODO(b/129506833): Figure out between graph cases
          cloned_dataset = input_ops.auto_shard_dataset(  # pylint: disable=protected-access
              cloned_dataset, len(input_workers.worker_devices), i)
          self._cloned_datasets.append(cloned_dataset)

    self._input_workers = input_workers
    # TODO(anjalisridhar): Identify if we need to set this property on the
    # iterator.
    self._element_structure = dataset._element_structure  # pylint: disable=protected-access
    self._strategy = strategy
    self._kwargs = kwargs
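
The constructor above delegates the actual splitting to `distribute._RebatchDataset`: each global batch becomes `split_batch_by` smaller per-replica batches. The sketch below reproduces that effect with public tf.data transformations only, as a rough illustration rather than the internal kernel; the helper name `split_batches` is hypothetical, and it assumes the global batch size divides evenly by the split factor.

import tensorflow as tf

def split_batches(dataset, split_batch_by):
  # Hypothetical helper: approximate the rebatching effect by re-slicing each
  # global batch into `split_batch_by` consecutive smaller batches.
  def _split(batch):
    per_replica = tf.cast(tf.shape(batch)[0] // split_batch_by, tf.int64)
    return tf.data.Dataset.from_tensor_slices(batch).batch(per_replica)
  return dataset.flat_map(_split)

# Global batches of 32 become per-replica batches of 8, as in the tests below.
global_batches = tf.data.Dataset.range(1024).batch(32, drop_remainder=True)
per_replica_batches = split_batches(global_batches, split_batch_by=4)
print(next(iter(per_replica_batches)).numpy())  # [0 1 2 3 4 5 6 7]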
Example #2
 def testNotDivisibleError(self, drop_remainder):
   dataset = dataset_ops.Dataset.range(1024).batch(
       32, drop_remainder=drop_remainder)
   with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                "not divisible by"):
     rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=5)
     next_element = self.getNext(rebatched_dataset)
     self.evaluate(next_element())
Example #3
 def testUnsupportedTransformError(self, drop_remainder):
   dataset = dataset_ops.Dataset.range(1024).batch(
       32, drop_remainder=drop_remainder).apply(
           scan_ops.scan([0], lambda _, a: ([0], a)))
   with self.assertRaises(errors.InvalidArgumentError):
     rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
     next_element = self.getNext(rebatched_dataset)
     self.evaluate(next_element())
Example #4
 def testBatchSizesDontMatch(self):
     dataset = dataset_ops.Dataset.from_tensors(
         (np.arange(10), np.arange(5)))
     with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                  "Cannot use rebatching fallback"):
         rebatched_dataset = distribute._RebatchDataset(dataset,
                                                        num_replicas=5)
         next_element = self.getNext(rebatched_dataset)
         self.evaluate(next_element())
Example #5
 def testNotDivisibleError(self, drop_remainder):
     dataset = dataset_ops.Dataset.range(1024).batch(
         32, drop_remainder=drop_remainder)
     with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                  "not divisible by"):
         rebatched_dataset = distribute._RebatchDataset(dataset,
                                                        num_workers=5)
         next_element = self.getNext(rebatched_dataset)
         self.evaluate(next_element())
Example #6
    def __init__(self,
                 dataset,
                 input_workers,
                 split_batch_by=None,
                 input_context=None,
                 **kwargs):
        """Distribute the dataset on all workers.

    If `split_batch_by` is not None, we "split" each batch of the dataset by
    `split_batch_by` value.

    Args:
      dataset: `tf.data.Dataset` that will be used as the input source.
      input_workers: an `InputWorkers` object.
      split_batch_by: Optional integer. If present, we "split" each batch of the
        dataset by `split_batch_by` value.
      input_context: `InputContext` for sharding. Only pass this in for between
        graph multi-worker cases where there is only one `input_worker`. In
        these cases, we will shard based on the `input_pipeline_id` and
        `num_input_pipelines` in the `InputContext`.
      **kwargs: Additional experimental flags. Will be removed in future.
    """
        # We clone and shard the dataset on each worker. The current setup
        # tries to shard the dataset by files if possible so that each worker
        # sees a different subset of files. If that is not possible, we
        # attempt to shard the final input so that each worker runs the
        # entire preprocessing pipeline but receives only its own shard of
        # the dataset.
        assert isinstance(input_workers, InputWorkers)
        if split_batch_by:
            dataset = distribute._RebatchDataset(dataset, split_batch_by)  # pylint: disable=protected-access

        self._cloned_datasets = []
        if input_context:
            # Between-graph where we rely on the input_context for sharding
            assert input_workers.num_workers == 1
            dataset = input_ops.auto_shard_dataset(  # pylint: disable=protected-access
                dataset, input_context.num_input_pipelines,
                input_context.input_pipeline_id)
            self._cloned_datasets.append(dataset)
        else:
            for i, worker in enumerate(input_workers.worker_devices):
                with ops.device(worker):
                    cloned_dataset = dataset
                    if not context.executing_eagerly():
                        cloned_dataset = input_ops._clone_dataset(dataset)  # pylint: disable=protected-access
                        cloned_dataset = cloned_dataset.with_options(
                            dataset.options())
                    # TODO(b/129506833): Figure out between graph cases
                    cloned_dataset = input_ops.auto_shard_dataset(  # pylint: disable=protected-access
                        cloned_dataset, len(input_workers.worker_devices), i)
                    self._cloned_datasets.append(cloned_dataset)

        self._input_workers = input_workers
        # TODO(anjalisridhar): Identify if we need to set this property on the
        # iterator.
        self._element_structure = dataset._element_structure  # pylint: disable=protected-access
        self._kwargs = kwargs
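
When `input_context` is provided, the constructor shards the single input pipeline by `input_pipeline_id` out of `num_input_pipelines`, so each between-graph worker receives a disjoint subset of the data (and, per the comment above, `auto_shard_dataset` prefers sharding by files when it can). Below is a minimal sketch of that data-level sharding using the public `Dataset.shard` API, with illustrative stand-in values for the `InputContext` fields.

import tensorflow as tf

# Stand-ins for InputContext.num_input_pipelines and input_pipeline_id.
num_input_pipelines = 2
input_pipeline_id = 0

dataset = tf.data.Dataset.range(8).batch(2)
# Each worker keeps every num_input_pipelines-th batch, starting at its own
# id, so the workers see disjoint subsets of the batches.
sharded = dataset.shard(num_input_pipelines, input_pipeline_id)
print([batch.numpy().tolist() for batch in sharded])  # [[0, 1], [4, 5]]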
Example #7
 def testNestedDictionaryOutput(self, drop_remainder):
   dataset = dataset_ops.Dataset.range(1024).map(
       lambda x: {"a": x, "b": {"c": x}}).batch(
           32, drop_remainder=drop_remainder)
   rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
   expected_output = [{"a": [k for k in range(i, i + 8)],  # pylint: disable=g-complex-comprehension
                       "b": {"c": [k for k in range(i, i + 8)]}}
                      for i in range(0, 1024, 8)]
   self.assertDatasetProduces(rebatched_dataset, expected_output)
Example #8
 def testUnsupportedTransformError(self, drop_remainder):
     dataset = dataset_ops.Dataset.range(1024).batch(
         32, drop_remainder=drop_remainder).apply(
             scan_ops.scan([0], lambda _, a: ([0], a)))
     with self.assertRaises(errors.InvalidArgumentError):
         rebatched_dataset = distribute._RebatchDataset(dataset,
                                                        num_workers=4)
         next_element = self.getNext(rebatched_dataset)
         self.evaluate(next_element())
Example #9
  def testPartialBatchWithDropRemainder(self):
    dataset = dataset_ops.Dataset.range(5).batch(4, drop_remainder=False)
    rebatched_dataset = distribute._RebatchDataset(
        dataset, batch_sizes=[2, 2], drop_remainder=True)

    expected_shapes = [[2]]
    self.assertEqual(expected_shapes, _flat_shapes(rebatched_dataset))
    expected_output = [[0, 1], [2, 3]]
    self.assertDatasetProduces(rebatched_dataset, expected_output)
Example #10
 def testUnsupportedTransformError(self):
     dataset = dataset_ops.Dataset.range(1024).batch(32).apply(
         sleep.sleep(10))
     with self.assertRaises(errors.InvalidArgumentError):
         rebatched_dataset = distribute._RebatchDataset(dataset,
                                                        num_replicas=4,
                                                        use_fallback=False)
         next_element = self.getNext(rebatched_dataset)
         self.evaluate(next_element())
Example #11
 def testMapAndBatch(self):
   dataset = dataset_ops.Dataset.range(1024).apply(
       batching.map_and_batch(math_ops.square, 32))
   rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
   self.assertEqual([[None]],
                    [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
   expected_output = [[k**2 for k in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
                      for i in range(0, 1024, 8)]
   self.assertDatasetProduces(rebatched_dataset, expected_output)
Example #12
 def testUnsupportedTransformInFlatMapError(self):
   dataset = dataset_ops.Dataset.range(2).flat_map(
       lambda _: dataset_ops.Dataset.range(32).batch(  # pylint: disable=g-long-lambda
           32).apply(sleep.sleep(10)))
   with self.assertRaises(errors.InvalidArgumentError):
     rebatched_dataset = distribute._RebatchDataset(
         dataset, num_replicas=4, use_fallback=False)
     next_element = self.getNext(rebatched_dataset)
     self.evaluate(next_element())
Example #13
 def testTupleOutput(self, drop_remainder):
   dataset = (
       dataset_ops.Dataset.range(1024).map(lambda x: (x, x)).batch(
           32, drop_remainder=drop_remainder))
   rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
   expected_output = [([k for k in range(i, i + 8)],  # pylint: disable=g-complex-comprehension
                       [k for k in range(i, i + 8)])
                      for i in range(0, 1024, 8)]
   self.assertDatasetProduces(rebatched_dataset, expected_output)
Example #14
 def testZip(self):
   dataset1 = dataset_ops.Dataset.range(64).batch(8)
   dataset2 = dataset_ops.Dataset.range(32).batch(8)
   dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
   rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
   self.assertEqual([[None], [None]],
                    [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
   expected_output = [([i, i + 1], [i, i + 1]) for i in range(0, 32, 2)]
   self.assertDatasetProduces(rebatched_dataset, expected_output)
Example #15
    def testUseLegacyRebatchWithDataSharding(self, sharding_policy,
                                             with_prefetch):
        # This test simulates a distributed environment with 3 workers, each with
        # 1 replica.
        dataset = dataset_ops.Dataset.range(8)
        dataset = dataset.batch(4)
        options = options_lib.Options()
        options.experimental_distribute.auto_shard_policy = sharding_policy
        dataset = dataset.with_options(options)
        # We expect the auto-shard rewrite to rewrite RebatchDatasetV2 to
        # RebatchDataset(V1) for correctness reasons. This will modify the output
        # of the dataset.
        worker_a_dataset = distribute._RebatchDataset(dataset,
                                                      batch_sizes=[2, 1, 1])
        if with_prefetch:
            worker_a_dataset = worker_a_dataset.prefetch(1)
        worker_a_dataset = distribute._AutoShardDataset(worker_a_dataset,
                                                        3,
                                                        0,
                                                        num_replicas=3)
        expected = [[0, 1], [4, 5]]
        self.assertDatasetProduces(worker_a_dataset, expected)

        worker_b_dataset = distribute._RebatchDataset(dataset,
                                                      batch_sizes=[1, 1, 2])
        if with_prefetch:
            worker_b_dataset = worker_b_dataset.prefetch(1)
        worker_b_dataset = distribute._AutoShardDataset(worker_b_dataset,
                                                        3,
                                                        1,
                                                        num_replicas=3)
        expected = [[2, 3], [6, 7]]
        self.assertDatasetProduces(worker_b_dataset, expected)

        worker_c_dataset = distribute._RebatchDataset(dataset,
                                                      batch_sizes=[1, 2, 1])
        if with_prefetch:
            worker_c_dataset = worker_c_dataset.prefetch(1)
        worker_c_dataset = distribute._AutoShardDataset(worker_c_dataset,
                                                        3,
                                                        2,
                                                        num_replicas=3)
        expected = [[], []]
        self.assertDatasetProduces(worker_c_dataset, expected)
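
In the test above, `batch_sizes=[2, 1, 1]` tells `_RebatchDataset` how to carve each global batch of 4 into consecutive sub-batches whose sizes cycle through the list (the same semantics seen in the other `batch_sizes` tests in this collection). The sketch below illustrates just that carving with `tf.split`; it is not the rebatch op itself, and in the test the auto-shard rewrite then swaps in the legacy rebatch, which is why the per-worker outputs differ from a plain split.

import tensorflow as tf

global_batch = tf.range(4)                       # one global batch: [0 1 2 3]
per_replica = tf.split(global_batch, [2, 1, 1])  # sub-batch sizes [2, 1, 1]
print([part.numpy().tolist() for part in per_replica])  # [[0, 1], [2], [3]]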
Example #16
    def testWithUnknownBatchDim(self):
        dataset = dataset_ops.Dataset.range(1024).batch(
            32, drop_remainder=False).apply(sleep.sleep(10))

        with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                     "Cannot use rebatching fallback"):
            rebatched_dataset = distribute._RebatchDataset(dataset,
                                                           num_workers=4)
            next_element = self.getNext(rebatched_dataset)
            self.evaluate(next_element())
Example #17
 def testConcatenateDifferentShapes(self):
   dataset1 = dataset_ops.Dataset.range(64).batch(16)
   dataset2 = dataset_ops.Dataset.range(32).batch(8)
   dataset = dataset1.concatenate(dataset2)
   rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
   self.assertEqual([[None]],
                    [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
   expected_output = ([[i, i + 1, i + 2, i + 3] for i in range(0, 64, 4)] +
                      [[i, i + 1] for i in range(0, 32, 2)])
   self.assertDatasetProduces(rebatched_dataset, expected_output)
Example #18
    def testScalarBatchSizeInput(self, drop_remainder):
        dataset = dataset_ops.Dataset.range(8).batch(4, drop_remainder=True)
        rebatched_dataset = distribute._RebatchDataset(
            dataset, batch_sizes=2, drop_remainder=drop_remainder)

        expected_shapes = [[2]]
        self.assertEqual(expected_shapes, _flat_shapes(rebatched_dataset))

        expected_output = [[0, 1], [2, 3], [4, 5], [6, 7]]
        self.assertDatasetProduces(rebatched_dataset, expected_output)
Example #19
 def testCanHandleUnknownRank(self):
   dataset = dataset_ops.Dataset.from_tensors("xxx")
   # decode_image results in a tensor of completely unknown shape (i.e. unknown
   # rank)
   dataset = dataset.map(image_ops.decode_image)
   self.assertEqual([tensor_shape.TensorShape(None)], _flat_shapes(dataset))
   rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
   # Note that we are just testing the dataset shapes, not the actual output.
   self.assertEqual([tensor_shape.TensorShape(None)],
                    _flat_shapes(rebatched_dataset))
Example #20
 def testCanHandleUnknownDims(self):
   dataset = dataset_ops.Dataset.range(1000)
   dataset = dataset.batch(10, drop_remainder=False)
   dataset = dataset.batch(10, drop_remainder=False)
   self.assertEqual([[None, None]],
                    [ts.as_list() for ts in _flat_shapes(dataset)])
   rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
   # Note that we are just testing the dataset shapes, not the actual output.
   self.assertEqual([[None, None]],
                    [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
Example #21
 def testTupleOutput(self, drop_remainder):
     dataset = (dataset_ops.Dataset.range(1024).map(lambda x: (x, x)).batch(
         32, drop_remainder=drop_remainder))
     rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
     expected_output = [
         (
             [k for k in range(i, i + 8)],  # pylint: disable=g-complex-comprehension
             [k for k in range(i, i + 8)]) for i in range(0, 1024, 8)
     ]
     self.assertDatasetProduces(rebatched_dataset, expected_output)
Example #22
  def testWithNoBatchDataset(self):
    dataset = dataset_ops.Dataset.from_tensor_slices(
        [[k for k in range(i, i + 32)] for i in range(0, 1024, 32)])  # pylint: disable=g-complex-comprehension
    rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
    self.assertEqual([[32]], [ts.as_list() for ts in _flat_shapes(dataset)])
    self.assertEqual([[8]],
                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])

    expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)]  # pylint: disable=g-complex-comprehension
    self.assertDatasetProduces(rebatched_dataset, expected_output)
Example #23
  def testWithUnhandledTransformation(self):
    dataset = dataset_ops.Dataset.range(1024).batch(
        32, drop_remainder=True).apply(sleep.sleep(10))
    rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
    self.assertEqual([[32]], [ts.as_list() for ts in _flat_shapes(dataset)])
    self.assertEqual([[8]],
                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])

    expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)]  # pylint: disable=g-complex-comprehension
    self.assertDatasetProduces(rebatched_dataset, expected_output)
Example #24
    def testScanAfterBatch(self, drop_remainder):
        dataset = dataset_ops.Dataset.range(40).batch(10).apply(
            scan_ops.scan(np.int64(2), lambda state, value:
                          (state, value * state)))
        dataset = distribute._RebatchDataset(dataset, num_workers=2)

        self.assertEqual([[None]],
                         [ts.as_list() for ts in _flat_shapes(dataset)])
        expected_output = [[i * 2 for i in range(j * 5, (j + 1) * 5)]
                           for j in range(8)]  # pylint: disable=g-complex-comprehension
        self.assertDatasetProduces(dataset, expected_output)
Example #25
  def testBatchSizeIndivisibleByNumWorkers(self):
    # This doesn't work: the fallback reshape requires the batch size to be
    # exactly divisible by the number of workers.
    dataset = dataset_ops.Dataset.range(64).batch(
        32, drop_remainder=True).apply(sleep.sleep(10))

    with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                 "Cannot use rebatching fallback"):
      rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=5)
      next_element = self.getNext(rebatched_dataset)
      self.evaluate(next_element())
Example #26
    def testEmptyLastSplits(self, drop_remainder):
        dataset = dataset_ops.Dataset.range(8).batch(4, drop_remainder=True)
        rebatched_dataset = distribute._RebatchDataset(
            dataset, batch_sizes=[1, 0], drop_remainder=drop_remainder)

        expected_shapes = [[None]]
        self.assertEqual(expected_shapes, _flat_shapes(rebatched_dataset))

        expected_output = [[0], [], [1], [], [2], [], [3], [], [4], [], [5],
                           [], [6], [], [7], []]
        self.assertDatasetProduces(rebatched_dataset, expected_output)
Example #27
  def testFinalPartialBatchAfterRebatch(self, drop_remainder):
    dataset = dataset_ops.Dataset.range(34).batch(
        32, drop_remainder=drop_remainder)
    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
    self.assertEqual([[8] if drop_remainder else [None]],
                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])

    expected_output = [[k for k in range(i, i + 8)] for i in range(0, 32, 8)]  # pylint: disable=g-complex-comprehension
    if not drop_remainder:
      expected_output += [[32, 33]]
    self.assertDatasetProduces(rebatched_dataset, expected_output)
Example #28
    def testMultipleBatches(self):
        dataset = dataset_ops.Dataset.range(16).batch(
            2, drop_remainder=True).batch(4, drop_remainder=True)
        self.assertEqual([[4, 2]], _flat_shapes(dataset))

        rebatched_dataset = distribute._RebatchDataset(dataset, [2, 2])
        self.assertEqual([[2, 2]], _flat_shapes(rebatched_dataset))
        # Each element is a list of 2 elements where each element is a list of 2.
        expected_output = [[[0, 1], [2, 3]], [[4, 5], [6, 7]],
                           [[8, 9], [10, 11]], [[12, 13], [14, 15]]]
        self.assertDatasetProduces(rebatched_dataset, expected_output)
Example #29
    def testBasic(self, drop_remainder):
        dataset = dataset_ops.Dataset.range(1024).batch(
            32, drop_remainder=drop_remainder)
        rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
        self.assertEqual(
            [[8] if drop_remainder else [None]],
            [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])

        expected_output = [[k for k in range(i, i + 8)]
                           for i in range(0, 1024, 8)]  # pylint: disable=g-complex-comprehension
        self.assertDatasetProduces(rebatched_dataset, expected_output)
Example #30
  def testNestedDictionaryOutput(self):
    dataset = dataset_ops.Dataset.range(8).map(
        lambda x: {"a": x, "b": {"c": x + 1}}).batch(4, drop_remainder=True)
    rebatched_dataset = distribute._RebatchDataset(dataset, [2, 2])
    self.assertEqual([[2], [2]], _flat_shapes(rebatched_dataset))

    expected_output = [{"a": [0, 1], "b": {"c": [1, 2]}},
                       {"a": [2, 3], "b": {"c": [3, 4]}},
                       {"a": [4, 5], "b": {"c": [5, 6]}},
                       {"a": [6, 7], "b": {"c": [7, 8]}}]
    self.assertDatasetProduces(rebatched_dataset, expected_output)
Example #31
    def testWithUnknownBatchDimInSecondComponent(self):
        dataset0 = dataset_ops.Dataset.range(1024).batch(32,
                                                         drop_remainder=True)
        dataset1 = dataset_ops.Dataset.range(1024).batch(
            32, drop_remainder=False).apply(sleep.sleep(10))
        dataset = dataset_ops.Dataset.zip((dataset0, dataset1))
        rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)

        expected_output = [[k for k in range(i, i + 8)]
                           for i in range(0, 1024, 8)]  # pylint: disable=g-complex-comprehension
        expected_output = [(x, x) for x in expected_output]
        self.assertDatasetProduces(rebatched_dataset, expected_output)
Example #32
 def testMapAndBatchWithCapturedInput(self):
   captured_t = variables.Variable(42)
   dataset = dataset_ops.Dataset.range(1024).apply(
       batching.map_and_batch(lambda x: captured_t, 32))
   rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
   self.assertEqual([[None]],
                    [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
   expected_output = [[42 for _ in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
                      for i in range(0, 1024, 8)]
   self.evaluate(variables.global_variables_initializer())
   self.assertDatasetProduces(
       rebatched_dataset, expected_output, requires_initialization=True)
Example #33
 def testShardWithRebatch(self):
     # Tests that RebatchDatasetV2 is a passthrough op.
     dataset = dataset_ops.Dataset.list_files(self.test_filenames,
                                              shuffle=False)
     dataset = dataset.apply(
         testing.assert_next(["Shard", "FlatMap", "Batch", "Rebatch"]))
     dataset = dataset.flat_map(core_readers.TFRecordDataset)
     dataset = dataset.batch(5)
     dataset = distribute._RebatchDataset(dataset, batch_sizes=5)
     dataset = distribute._AutoShardDataset(dataset, 5, 3)
     nxt = self.getNext(dataset)
     self.evaluate(nxt())
Example #34
 def testBatchSizeNotDivisibleByNumReplicas2(self):
   dataset = dataset_ops.Dataset.range(32).batch(16, drop_remainder=True)
   rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5)
   # This will rebatch into sub-batches of size 4, since
   # ceil(16 / 5) = 4. However, that means only the first 4 replicas will get
   # data.
   expected_output = [[k for k in range(i, i + 4)] for i in range(0, 16, 4)]
   expected_output.extend([[]])  # Last replica gets an empty batch
   expected_output.extend(
       [[k for k in range(i, i + 4)] for i in range(16, 32, 4)])
   expected_output.extend([[]])  # Last replica gets an empty batch
   self.assertDatasetProduces(rebatched_dataset, expected_output)
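
The comment above can be checked with a little arithmetic: with a global batch of 16 and 5 replicas, each replica is assigned ceil(16 / 5) = 4 elements, so the first four replicas consume the whole batch and the fifth gets an empty one. A back-of-the-envelope sketch in plain Python (not the rebatch op itself):

import math

global_batch = 16
num_replicas = 5
per_replica = math.ceil(global_batch / num_replicas)  # 4
sizes = [min(per_replica, max(0, global_batch - i * per_replica))
         for i in range(num_replicas)]
print(sizes)  # [4, 4, 4, 4, 0] -> the fifth replica receives an empty batch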
Example #35
    def testFinalPartialBatchOriginal(self, drop_remainder):
        dataset = dataset_ops.Dataset.range(1032).batch(
            32, drop_remainder=drop_remainder)
        rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
        self.assertEqual([[32 if drop_remainder else None]],
                         [ts.as_list() for ts in _flat_shapes(dataset)])
        self.assertEqual(
            [[8 if drop_remainder else None]],
            [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])

        expected_output = [[k for k in range(i, i + 8)]
                           for i in range(0, 1032, 8)]  # pylint: disable=g-complex-comprehension
        self.assertDatasetProduces(rebatched_dataset, expected_output)
Example #36
  def testFinalPartialBatchOriginal(self, drop_remainder):
    dataset = dataset_ops.Dataset.range(1032).batch(
        32, drop_remainder=drop_remainder)
    rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
    self.assertEqual(
        [[32 if drop_remainder else None]],
        [ts.as_list() for ts in _flat_shapes(dataset)])
    self.assertEqual(
        [[8 if drop_remainder else None]],
        [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])

    expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1032, 8)]  # pylint: disable=g-complex-comprehension
    self.assertDatasetProduces(rebatched_dataset, expected_output)
Example #37
 def testMapAndBatch(self, drop_remainder):
   dataset = dataset_ops.Dataset.range(1024).apply(
       batching.map_and_batch(
           math_ops.square, 32, drop_remainder=drop_remainder))
   rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
   self.assertEqual(
       [[32 if drop_remainder else None]],
       [ts.as_list() for ts in _flat_shapes(dataset)])
   self.assertEqual(
       [[8 if drop_remainder else None]],
       [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
   expected_output = [[k**2 for k in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
                      for i in range(0, 1024, 8)]
   self.assertDatasetProduces(rebatched_dataset, expected_output)
Example #38
 def testZip(self, drop_remainder):
   dataset1 = dataset_ops.Dataset.range(64).batch(
       8, drop_remainder=drop_remainder)
   dataset2 = dataset_ops.Dataset.range(32).batch(
       8, drop_remainder=drop_remainder)
   dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
   rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
   self.assertEqual(
       [[8], [8]] if drop_remainder else [[None], [None]],
       [ts.as_list() for ts in _flat_shapes(dataset)])
   self.assertEqual(
       [[2], [2]] if drop_remainder else [[None], [None]],
       [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
   expected_output = [([i, i + 1], [i, i + 1]) for i in range(0, 32, 2)]
   self.assertDatasetProduces(rebatched_dataset, expected_output)
Example #39
 def testConcatenateDifferentShapes(self, drop_remainder):
   dataset1 = dataset_ops.Dataset.range(64).batch(
       16, drop_remainder=drop_remainder)
   dataset2 = dataset_ops.Dataset.range(32).batch(
       8, drop_remainder=drop_remainder)
   dataset = dataset1.concatenate(dataset2)
   rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
   self.assertEqual(
       [[None]], [ts.as_list() for ts in _flat_shapes(dataset)])
   self.assertEqual(
       [[None]],
       [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
   expected_output = ([[i, i + 1, i + 2, i + 3] for i in range(0, 64, 4)] +
                      [[i, i + 1] for i in range(0, 32, 2)])
   self.assertDatasetProduces(rebatched_dataset, expected_output)
Example #40
 def testMapAndBatchWithCapturedInput(self, drop_remainder):
   captured_t = variables.Variable(42)
   dataset = dataset_ops.Dataset.range(1024).apply(
       batching.map_and_batch(
           lambda x: captured_t, 32, drop_remainder=drop_remainder))
   rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
   self.assertEqual([[32 if drop_remainder else None]],
                    [ts.as_list() for ts in _flat_shapes(dataset)])
   self.assertEqual([[8 if drop_remainder else None]],
                    [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
   expected_output = [[42 for _ in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
                      for i in range(0, 1024, 8)]
   self.evaluate(variables.global_variables_initializer())
   self.assertDatasetProduces(
       rebatched_dataset, expected_output, requires_initialization=True)
Example #41
  def testGroupByWindowBatching(self, drop_remainder):
    dataset = dataset_ops.Dataset.from_tensor_slices(
        [[array_ops.constant(i, dtype=dtypes.int64)] * 3 for i in range(40)])
    reduce_fn = lambda bucket_id, ds: ds.batch(
        batch_size=10, drop_remainder=drop_remainder)
    dataset = dataset.apply(
        grouping.group_by_window(
            key_func=lambda x: x[0] % 4, reduce_func=reduce_fn, window_size=10))
    rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=2)

    self.assertEqual([[5, 3] if drop_remainder else [None, 3]],
                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
    # pylint: disable=g-complex-comprehension
    expected_output = [[[j + i * 4 + k * 20] * 3
                        for i in range(5)]
                       for j in range(4)
                       for k in range(2)]
    self.assertDatasetProduces(rebatched_dataset, expected_output)
Example #42
  def testInterleaveBatching(self, drop_remainder):
    dataset = dataset_ops.Dataset.range(
        2).interleave(lambda _: dataset_ops.Dataset.range(32).batch(  # pylint: disable=g-long-lambda
            32, drop_remainder=drop_remainder), cycle_length=2)
    self.assertEqual(
        [[32 if drop_remainder else None]],
        [ts.as_list() for ts in _flat_shapes(dataset)])
    # Two elements where each element is range(32)
    expected_output = [[k for k in range(32)] for _ in range(2)]  # pylint: disable=g-complex-comprehension
    self.assertDatasetProduces(dataset, expected_output)

    rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
    self.assertEqual(
        [[8 if drop_remainder else None]],
        [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
    # After rebatching, each of the two interleaved datasets yields 4 batches
    # of 8 covering 0 to 31; with cycle_length=2 the identical batches
    # alternate, so each batch of 8 appears twice in a row.
    expected_output = [[k for k in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
                       for i in range(0, 32, 8)  # generates 4 elements
                       for _ in range(2)]
    self.assertDatasetProduces(rebatched_dataset, expected_output)
Example #43
  def testFlatMapBatching(self, drop_remainder):
    dataset = dataset_ops.Dataset.range(
        2).flat_map(lambda _: dataset_ops.Dataset.range(32).batch(  # pylint: disable=g-long-lambda
            32, drop_remainder=drop_remainder))
    self.assertEqual(
        [[32 if drop_remainder else None]],
        [ts.as_list() for ts in _flat_shapes(dataset)])
    # Two elements where each element is range(32)
    expected_output = [[k for k in range(32)] for _ in range(2)]  # pylint: disable=g-complex-comprehension
    self.assertDatasetProduces(dataset, expected_output)

    rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
    self.assertEqual(
        [[8 if drop_remainder else None]],
        [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
    # Two elements where each element is a list of 4 elements where each element
    # is a list of 8.
    expected_output = [[k for k in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
                       for _ in range(2)
                       for i in range(0, 32, 8)]  # generates 4 elements
    self.assertDatasetProduces(rebatched_dataset, expected_output)
Example #44
  def testMultipleBatches(self, drop_remainder):
    dataset = dataset_ops.Dataset.range(128).batch(
        4, drop_remainder=drop_remainder)
    dataset = dataset.batch(8, drop_remainder=drop_remainder)
    self.assertEqual(
        [[8, 4]] if drop_remainder else [[None, None]],
        [ts.as_list() for ts in _flat_shapes(dataset)])
    # Each element is a list of 8 elements where each element is a list of 4.
    expected_output = [[[j, j + 1, j + 2, j + 3]  # pylint: disable=g-complex-comprehension
                        for j in range(i, i + 32, 4)]  # generates 8 elements
                       for i in range(0, 128, 32)]
    self.assertDatasetProduces(dataset, expected_output)

    rebatched_dataset = distribute._RebatchDataset(dataset, 4)
    self.assertEqual(
        [[2, 4]] if drop_remainder else [[None, None]],
        [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
    # Each element is a list of 2 elements where each element is a list of 4.
    expected_output = [[[j, j + 1, j + 2, j + 3]  # pylint: disable=g-complex-comprehension
                        for j in range(i, i + 8, 4)]  # generates 2 elements
                       for i in range(0, 128, 8)]
    self.assertDatasetProduces(rebatched_dataset, expected_output)
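
As the shape assertions above show, rebatching divides only the outermost (batch) dimension: a doubly batched element of shape [8, 4] becomes elements of shape [2, 4], with the inner dimension untouched. A small illustration of that behavior using `tf.split` (not the rebatch op itself):

import tensorflow as tf

element = tf.reshape(tf.range(32), [8, 4])  # one [8, 4] element
parts = tf.split(element, 4, axis=0)        # only the outer dim is divided
print([part.shape.as_list() for part in parts])  # [[2, 4], [2, 4], [2, 4], [2, 4]]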
Example #45
 def testPaddedBatch(self, drop_remainder):
   dataset = dataset_ops.Dataset.range(128).batch(4).padded_batch(
       8, padded_shapes=[5], drop_remainder=drop_remainder)
   rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
   self.assertEqual(
       [[8, 5]] if drop_remainder else [[None, 5]],
       [ts.as_list() for ts in _flat_shapes(dataset)])
   # Each element is a list of 8 elements in which each element is a list of 5
   # elements, first four are numbers and the last one is a padded zero.
   expected_output = [[[j, j + 1, j + 2, j + 3, 0]  # pylint: disable=g-complex-comprehension
                       for j in range(i, i + 32, 4)]  # generates 8 elements
                      for i in range(0, 128, 32)]
   self.assertDatasetProduces(dataset, expected_output)
   self.assertEqual(
       [[2, 5]] if drop_remainder else [[None, 5]],
       [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
   # Each element is a list of 2 elements in which each element is a list of 5
   # elements, first four are numbers and the last one is a padded zero.
   expected_output = [[[j, j + 1, j + 2, j + 3, 0]  # pylint: disable=g-complex-comprehension
                       for j in range(i, i + 8, 4)]  # generates 2 elements
                      for i in range(0, 128, 8)]
   self.assertDatasetProduces(rebatched_dataset, expected_output)
Example #46
 def build_dataset(num_elements, batch_size):
   return distribute._RebatchDataset(
       dataset_ops.Dataset.range(num_elements).batch(
           4 * batch_size, drop_remainder=True),
       num_workers=4)
Example #47
 def testScalarInputError(self, _):
   dataset = dataset_ops.Dataset.range(1024)
   with self.assertRaisesRegexp(ValueError, "at least one dimension"):
     distribute._RebatchDataset(dataset, num_workers=4)