Example 1
  def __init__(self,
               dataset,
               input_workers,
               strategy,
               split_batch_by=None,
               input_context=None,
               **kwargs):
    """Distribute the dataset on all workers.

    If `split_batch_by` is not None, we "split" each batch of the dataset by
    `split_batch_by` value.

    Args:
      dataset: `tf.data.Dataset` that will be used as the input source.
      input_workers: an `InputWorkers` object.
      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle last partial batch.
      split_batch_by: Optional integer. If present, we "split" each batch of the
        dataset by `split_batch_by` value.
      input_context: `InputContext` for sharding. Only pass this in for between
        graph multi-worker cases where there is only one `input_worker`. In
        these cases, we will shard based on the `input_pipeline_id` and
        `num_input_pipelines` in the `InputContext`.
      **kwargs: Additional experimental flags. Will be removed in future.
    """
    # We clone and shard the dataset on each worker. The current setup tries to
    # shard the dataset by files if possible so that each worker sees a
    # different subset of files. If that is not possible, will attempt to shard
    # the final input such that each worker will run the entire preprocessing
    # pipeline and only receive its own shard of the dataset.
    assert isinstance(input_workers, InputWorkers)
    if split_batch_by:
      dataset = distribute._RebatchDataset(dataset, split_batch_by)  # pylint: disable=protected-access

    self._cloned_datasets = []
    if input_context:
      # Between-graph where we rely on the input_context for sharding
      assert input_workers.num_workers == 1
      dataset = input_ops.auto_shard_dataset(  # pylint: disable=protected-access
          dataset, input_context.num_input_pipelines,
          input_context.input_pipeline_id)
      self._cloned_datasets.append(dataset)
    else:
      for i, worker in enumerate(input_workers.worker_devices):
        with ops.device(worker):
          cloned_dataset = dataset
          if not context.executing_eagerly():
            cloned_dataset = input_ops._clone_dataset(dataset)  # pylint: disable=protected-access
            cloned_dataset = cloned_dataset.with_options(dataset.options())
          # TODO(b/129506833): Figure out between graph cases
          cloned_dataset = input_ops.auto_shard_dataset(  # pylint: disable=protected-access
              cloned_dataset, len(input_workers.worker_devices), i)
          self._cloned_datasets.append(cloned_dataset)

    self._input_workers = input_workers
    # TODO(anjalisridhar): Identify if we need to set this property on the
    # iterator.
    self._element_structure = dataset._element_structure  # pylint: disable=protected-access
    self._strategy = strategy
    self._kwargs = kwargs
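Every snippet in this listing funnels through the same call. The sketch below is a minimal, self-contained illustration of that call pattern; the file names and shard numbers are placeholders rather than values taken from any example, and the three-argument `auto_shard_dataset(dataset, num_shards, index)` signature is the one these snippets use.

import tensorflow as tf
from tensorflow.python.distribute.input_ops import auto_shard_dataset

# Placeholder values: two workers in total, and this process is worker 0.
num_shards, shard_index = 2, 0
# Placeholder file list; in the examples it comes from test fixtures or a config.
dataset = tf.data.TFRecordDataset(["file-0.tfrecord", "file-1.tfrecord"])
# Shards by files when possible, otherwise falls back to sharding the elements.
dataset = auto_shard_dataset(dataset, num_shards, shard_index)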
Example 2
    def step(self,
             data_creator,
             epochs=1,
             verbose=1,
             callbacks=None,
             validation_data_creator=None,
             class_weight=None,
             steps_per_epoch=None,
             validation_steps=None,
             validation_freq=1):
        """Runs a training epoch and updates the model parameters."""

        train_dataset = data_creator(self.config)
        if validation_data_creator is not None:
            test_dataset = validation_data_creator(self.config)
        else:
            test_dataset = None

        if self.backend == "horovod":
            import horovod.tensorflow.keras as hvd
            from tensorflow.python.distribute.input_ops import auto_shard_dataset
            train_dataset = auto_shard_dataset(train_dataset, hvd.size(),
                                               hvd.rank())
            if test_dataset is not None:
                test_dataset = auto_shard_dataset(test_dataset, hvd.size(),
                                                  hvd.rank())

        if self.backend == "horovod":
            import horovod.tensorflow.keras as hvd
            hvd_callbacks = [
                hvd.callbacks.BroadcastGlobalVariablesCallback(0),
                hvd.callbacks.MetricAverageCallback()
            ]
            if hvd.rank() != 0:
                verbose = 0

            if callbacks is not None:
                callbacks = hvd_callbacks + callbacks
            else:
                callbacks = hvd_callbacks

        history = self.model.fit(train_dataset,
                                 epochs=self.epoch + epochs,
                                 verbose=verbose,
                                 callbacks=callbacks,
                                 validation_data=test_dataset,
                                 class_weight=class_weight,
                                 initial_epoch=self.epoch,
                                 steps_per_epoch=steps_per_epoch,
                                 validation_steps=validation_steps,
                                 validation_freq=validation_freq)
        if history is None:
            stats = {}
        else:
            stats = {"train_" + k: v[-1] for k, v in history.history.items()}

        self.epoch += epochs
        return stats
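The Horovod branch above boils down to per-rank sharding plus the standard Horovod Keras callbacks. The sketch below shows that wiring in isolation; `hvd.init()` and the placeholder dataset are additions for a self-contained illustration, not part of the snippet above.

import horovod.tensorflow.keras as hvd
import tensorflow as tf
from tensorflow.python.distribute.input_ops import auto_shard_dataset

hvd.init()
# Placeholder pipeline; each rank then reads only its own shard of the input.
train_dataset = tf.data.TFRecordDataset(["train.tfrecord"]).batch(32)
train_dataset = auto_shard_dataset(train_dataset, hvd.size(), hvd.rank())
callbacks = [
    hvd.callbacks.BroadcastGlobalVariablesCallback(0),  # sync initial weights from rank 0
    hvd.callbacks.MetricAverageCallback(),              # average metrics across ranks
]
verbose = 1 if hvd.rank() == 0 else 0  # only rank 0 prints progress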
Example 3
  def testTextLineReader(self):
    dataset = readers.TextLineDataset(self._createTextFiles())

    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    self._verifySimpleShardingOutput(dataset, self._text_line)
Example 4
  def testFixedLengthReaderWithFlatMap(self):
    dataset = readers.FixedLengthRecordDataset(
        self._createFixedLengthRecordFiles(), self._record_bytes)
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    self._verifySimpleShardingOutput(dataset, self._fixed_length_record)
Example 5
    def testFixedLengthReader(self):
        dataset = readers.FixedLengthRecordDataset(
            self._createFixedLengthRecordFiles(), self._record_bytes)
        dataset = input_ops.auto_shard_dataset(dataset, self._num_shards,
                                               self._shard_index)

        self._verifySimpleShardingOutput(dataset, self._fixed_length_record)
Example 6
  def DISABLED_testTextLineReaderWithFlatMap(self):
    dataset = dataset_ops.Dataset.from_tensor_slices(self._createTextFiles())
    dataset = dataset.flat_map(readers.TextLineDataset)
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    self._verifySimpleShardingOutput(dataset, self._text_line)
Example 7
  def testFlatMap(self):
    dataset = dataset_ops.Dataset.from_tensor_slices(
        self._createTFRecordFiles())
    dataset = dataset.flat_map(readers.TFRecordDataset)
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    self._verifySimpleShardingOutput(dataset, self._record)
Example 8
    def testFlatMap(self):
        dataset = dataset_ops.Dataset.from_tensor_slices(
            self._createTFRecordFiles())
        dataset = dataset.flat_map(readers.TFRecordDataset)
        dataset = input_ops.auto_shard_dataset(dataset, self._num_shards,
                                               self._shard_index)

        self._verifySimpleShardingOutput(dataset, self._record)
Example 9
    def testTextLineReaderWithFlatMap(self):
        dataset = dataset_ops.Dataset.from_tensor_slices(
            self._createTextFiles())
        dataset = dataset.flat_map(readers.TextLineDataset)
        dataset = input_ops.auto_shard_dataset(dataset, self._num_shards,
                                               self._shard_index)

        self._verifySimpleShardingOutput(dataset, self._text_line)
Example 10
    def validate(self,
                 data_creator,
                 verbose=1,
                 sample_weight=None,
                 steps=None,
                 callbacks=None,
                 data_config=None):
        """Evaluates the model on the validation data set."""
        config = self.config.copy()
        if data_config is not None:
            config.update(data_config)
        if self.backend == "horovod":
            import horovod.tensorflow.keras as hvd

            assert "batch_size" in config, "batch_size must be set in config"
            config["batch_size"] = config["batch_size"] // hvd.size()
            dataset = data_creator(config)
            from tensorflow.python.distribute.input_ops import auto_shard_dataset
            dataset = auto_shard_dataset(dataset, hvd.size(), hvd.rank())

        elif self.backend == "tf-distributed":
            with self.strategy.scope():
                dataset = data_creator(config)
        else:
            dataset = data_creator(config)

        if self.backend == "horovod":
            import horovod.tensorflow.keras as hvd
            if hvd.rank() != 0:
                verbose = 0
        elif self.backend == "tf-distributed":
            if self.strategy.cluster_resolver.task_id != 0:
                verbose = 0

        params = dict(
            verbose=verbose,
            sample_weight=sample_weight,
            steps=steps,
            callbacks=callbacks,
        )
        results = self.model.evaluate(dataset, **params)
        if results is None:
            # Using local Model since model.evaluate() returns None
            # for MultiWorkerMirroredStrategy
            logger.warning("Running a local model to get validation score.")
            self.local_model = self.model_creator(self.config)
            self.local_model.set_weights(self.model.get_weights())
            results = self.local_model.evaluate(dataset, **params)

        if isinstance(results, list):
            stats = {
                "validation_" + k: v
                for k, v in zip(self.model.metrics_names, results)
            }
        else:
            stats = {"results": results}

        return stats
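The Horovod branch above divides the configured batch size by the number of workers before building the dataset, so the configured value acts as a global batch size. A worked instance with illustrative numbers (the real values come from `self.config`):

# Illustrative numbers only: a global batch size of 128 split across 4 workers.
config = {"batch_size": 128}
world_size = 4                                   # stands in for hvd.size()
config["batch_size"] = config["batch_size"] // world_size
assert config["batch_size"] == 32                # per-worker batch size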
Example 11
    def testFixedLengthReaderWithFlatMap(self):
        dataset = dataset_ops.Dataset.from_tensor_slices(
            self._createFixedLengthRecordFiles())
        dataset = dataset.flat_map(
            lambda f: readers.FixedLengthRecordDataset(f, self._record_bytes))
        dataset = input_ops.auto_shard_dataset(dataset, self._num_shards,
                                               self._shard_index)

        self._verifySimpleShardingOutput(dataset, self._fixed_length_record)
Example 12
  def DISABLED_testFixedLengthReaderWithFlatMap(self):
    dataset = dataset_ops.Dataset.from_tensor_slices(
        self._createFixedLengthRecordFiles())
    dataset = dataset.flat_map(
        lambda f: readers.FixedLengthRecordDataset(f, self._record_bytes))
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    self._verifySimpleShardingOutput(dataset, self._fixed_length_record)
Example 13
  def DISABLED_testZip(self):
    dataset1 = readers.TFRecordDataset(self._createTFRecordFiles())
    dataset2 = readers.TextLineDataset(self._createTextFiles())
    dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    record_fn = lambda r, f: (self._record(r, f), self._text_line(r, f))
    self._verifySimpleShardingOutput(dataset, record_fn)
Example 14
    def testZip(self):
        dataset1 = readers.TFRecordDataset(self._createTFRecordFiles())
        dataset2 = readers.TextLineDataset(self._createTextFiles())
        dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
        dataset = input_ops.auto_shard_dataset(dataset, self._num_shards,
                                               self._shard_index)

        record_fn = lambda r, f: (self._record(r, f), self._text_line(r, f))
        self._verifySimpleShardingOutput(dataset, record_fn)
Example 15
  def testInterleave(self):
    dataset = dataset_ops.Dataset.from_tensor_slices(
        self._createTFRecordFiles())
    dataset = dataset.interleave(
        readers.TFRecordDataset, cycle_length=4, block_length=self._num_records)
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    # Since block_length == num records in each file, the output will still
    # contain records in order of files.
    self._verifySimpleShardingOutput(dataset, self._record)
Example 16
    def testZip(self):
        src1 = dataset_ops.Dataset.from_tensor_slices(
            self._createTFRecordFiles())
        dataset1 = src1.flat_map(readers.TFRecordDataset)
        src2 = dataset_ops.Dataset.from_tensor_slices(self._createTextFiles())
        dataset2 = src2.flat_map(readers.TextLineDataset)

        dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
        dataset = input_ops.auto_shard_dataset(dataset, self._num_shards,
                                               self._shard_index)

        record_fn = lambda r, f: (self._record(r, f), self._text_line(r, f))
        self._verifySimpleShardingOutput(dataset, record_fn)
Example 17
  def __init__(self, dataset, input_workers, split_batch_by=None, **kwargs):
    """Make an iterator for the dataset on given devices.

    If `split_batch_by` is not None, we "split" each batch of the
    dataset by `split_batch_by` value. To achieve this, we first unbatch the
    input dataset and then rebatch it with the per replica batch size that is
    calculated using `global_batch_size // split_batch_by`.
    The currently supported datasets are as follows:
    `dataset.batch()` is the last operation on the dataset OR
    `dataset.apply(map_and_batch)` is the last operation on the dataset OR
    `dataset.batch().prefetch()` are the last 2 operations on the dataset OR
    `dataset.apply(map_and_batch).prefetch()` are the last 2 operations.

    We clone and shard the dataset on each worker. The current setup tries to
    shard the dataset by files if possible so that each worker sees a different
    subset of files. If that is not possible, will attempt to shard the final
    input such that each worker will run the entire preprocessing pipeline and
    only receive its own shard of the dataset.

    Args:
      dataset: `tf.data.Dataset` that will be used as the input source.
      input_workers: an `InputWorkers` object.
      split_batch_by: Optional integer. If present, we "split" each batch of the
        dataset by `split_batch_by` value.
      **kwargs: Additional experimental flags. Will be removed in future.
    """
    assert isinstance(input_workers, InputWorkers)
    if split_batch_by:
      dataset = batching._RebatchDataset(dataset, split_batch_by)  # pylint: disable=protected-access

    iterators = []
    for i, worker in enumerate(input_workers.worker_devices):
      with ops.device(worker):
        worker_devices = input_workers.compute_devices_for_worker(i)
        cloned_dataset = dataset
        if not context.executing_eagerly():
          cloned_dataset = input_ops._clone_dataset(dataset)  # pylint: disable=protected-access
          cloned_dataset = cloned_dataset.with_options(dataset.options())
        # TODO(b/129506833): Figure out between graph cases
        cloned_dataset = input_ops.auto_shard_dataset(  # pylint: disable=protected-access
            cloned_dataset, len(input_workers.worker_devices), i)
        iterator = _SingleWorkerDatasetIterator(cloned_dataset, worker,
                                                worker_devices)
        iterators.append(iterator)

    self._element_structure = dataset._element_structure  # pylint: disable=protected-access

    super(DatasetIterator, self).__init__(input_workers, iterators, **kwargs)
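The docstring above restricts `split_batch_by` to pipelines whose final transformation is `batch()` (or `apply(map_and_batch)`), optionally followed by `prefetch()`. The sketch below builds a pipeline in that shape with illustrative sizes; the actual rebatching to `global_batch_size // split_batch_by` happens inside `_RebatchDataset`, not in this snippet.

import tensorflow as tf

global_batch_size, num_replicas = 64, 8    # split_batch_by would be 8 here
dataset = tf.data.Dataset.range(1024)
dataset = dataset.map(lambda x: x * 2)
dataset = dataset.batch(global_batch_size)
dataset = dataset.prefetch(1)              # batch().prefetch() is a supported tail
# With split_batch_by=num_replicas, each replica would see batches of 64 // 8 = 8.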
Example 18
  def testListfiles(self):
    filenames = self._createTFRecordFiles()
    file_pattern = filenames[0].rsplit(os.sep, 1)[0] + "/tf_record.*.txt"
    dataset = dataset_ops.Dataset.list_files(file_pattern, shuffle=False)
    dataset = dataset.flat_map(readers.TFRecordDataset)
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    next_element_fn = self._getNext(dataset)
    actual, expected = [], []
    for f in range(self._shard_index, self._num_files, self._num_shards):
      for r in range(self._num_records):
        actual.append(self.evaluate(next_element_fn()))
        expected.append(self._record(r, f))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element_fn())
    self.assertAllEqual(expected, actual)
Example 19
  def testConcat(self):
    dataset1 = readers.TFRecordDataset(self._createTFRecordFiles())
    dataset2 = readers.TextLineDataset(self._createTextFiles())
    dataset = dataset1.concatenate(dataset2)
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    with self.cached_session() as sess:
      for f in range(self._shard_index, self._num_files, self._num_shards):
        for r in range(self._num_records):
          self.assertAllEqual(self._record(r, f), sess.run(next_element))
      for f in range(self._shard_index, self._num_files, self._num_shards):
        for r in range(self._num_records):
          self.assertAllEqual(self._text_line(r, f), sess.run(next_element))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)
Example 20
  def testListfiles(self):
    filenames = self._createTFRecordFiles()
    file_pattern = filenames[0].rsplit("/", 1)[0] + "/tf_record.*.txt"
    dataset = dataset_ops.Dataset.list_files(file_pattern, shuffle=False)
    dataset = dataset.flat_map(readers.TFRecordDataset)
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    with self.cached_session() as sess:
      actual, expected = [], []
      for f in range(self._shard_index, self._num_files, self._num_shards):
        for r in range(self._num_records):
          actual.append(sess.run(next_element))
          expected.append(self._record(r, f))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)
      self.assertAllEqual(expected, actual)
Example 21
  def testConcat(self):
    dataset1 = readers.TFRecordDataset(self._createTFRecordFiles())
    dataset2 = readers.TextLineDataset(self._createTextFiles())

    dataset = dataset1.concatenate(dataset2)
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    next_element_fn = self._getNext(dataset)
    for f in range(self._shard_index, self._num_files, self._num_shards):
      for r in range(self._num_records):
        self.assertAllEqual(
            self._record(r, f), self.evaluate(next_element_fn()))
    for f in range(self._shard_index, self._num_files, self._num_shards):
      for r in range(self._num_records):
        self.assertAllEqual(
            self._text_line(r, f), self.evaluate(next_element_fn()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element_fn())
Example 22
    def testListfiles(self):
        filenames = self._createTFRecordFiles()
        file_pattern = filenames[0].rsplit("/", 1)[0] + "/tf_record.*.txt"
        dataset = dataset_ops.Dataset.list_files(file_pattern, shuffle=False)
        dataset = dataset.flat_map(readers.TFRecordDataset)
        dataset = input_ops.auto_shard_dataset(dataset, self._num_shards,
                                               self._shard_index)

        iterator = dataset.make_one_shot_iterator()
        next_element = iterator.get_next()
        with self.cached_session() as sess:
            actual, expected = [], []
            for f in range(self._shard_index, self._num_files,
                           self._num_shards):
                for r in range(self._num_records):
                    actual.append(sess.run(next_element))
                    expected.append(self._record(r, f))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element)
            self.assertAllEqual(expected, actual)
Example 23
    def testConcat(self):
        dataset1 = readers.TFRecordDataset(self._createTFRecordFiles())
        dataset2 = readers.TextLineDataset(self._createTextFiles())
        dataset = dataset1.concatenate(dataset2)
        dataset = input_ops.auto_shard_dataset(dataset, self._num_shards,
                                               self._shard_index)

        iterator = dataset.make_one_shot_iterator()
        next_element = iterator.get_next()
        with self.cached_session() as sess:
            for f in range(self._shard_index, self._num_files,
                           self._num_shards):
                for r in range(self._num_records):
                    self.assertAllEqual(self._record(r, f),
                                        sess.run(next_element))
            for f in range(self._shard_index, self._num_files,
                           self._num_shards):
                for r in range(self._num_records):
                    self.assertAllEqual(self._text_line(r, f),
                                        sess.run(next_element))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element)
Example 24
    def testComplexPipeline(self):
        # Setup a complex input pipeline.
        batch_size = 2
        num_epochs = 5
        dataset = dataset_ops.Dataset.from_tensor_slices(
            self._createTFRecordFiles())
        dataset = dataset.shuffle(buffer_size=self._num_files)
        dataset = dataset.flat_map(readers.TFRecordDataset)
        dataset = dataset.prefetch(buffer_size=batch_size)
        dataset = dataset.shuffle(2 * self._num_files * self._num_records)
        dataset = dataset.repeat(num_epochs)
        dataset = dataset.map(lambda x: x)
        dataset = dataset.batch(batch_size)
        dataset = dataset.prefetch(buffer_size=None)

        # Auto shard.
        dataset = input_ops.auto_shard_dataset(dataset, self._num_shards,
                                               self._shard_index)

        # Verify output.
        iterator = dataset.make_one_shot_iterator()
        next_element = iterator.get_next()
        with self.cached_session() as sess:
            actual = []
            num_iterations = (self._num_files * self._num_records *
                              num_epochs) // (self._num_shards * batch_size)
            for _ in range(num_iterations):
                actual.extend(sess.run(next_element))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element)

            expected = []
            for f in range(0, self._num_files, self._num_shards):
                for r in range(self._num_records):
                    expected.append(self._record(r, f))
            expected *= num_epochs

            self.assertAllEqual(sorted(expected), sorted(actual))
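The expected iteration count in the test above follows directly from the fixture constants. A worked instance with made-up numbers (the real constants live in the test's setup and are not shown in this snippet):

# Made-up fixture values: 10 files x 4 records x 5 epochs, 2 shards, batch size 2.
num_files, num_records, num_epochs = 10, 4, 5
num_shards, batch_size = 2, 2
num_iterations = (num_files * num_records * num_epochs) // (num_shards * batch_size)
assert num_iterations == 50  # each shard sees half the files, batched two at a time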
Example 25
  def DISABLED_testComplexPipeline(self):
    # Setup a complex input pipeline.
    batch_size = 2
    num_epochs = 5
    dataset = dataset_ops.Dataset.from_tensor_slices(
        self._createTFRecordFiles())
    dataset = dataset.shuffle(buffer_size=self._num_files)
    dataset = dataset.flat_map(readers.TFRecordDataset)
    dataset = dataset.prefetch(buffer_size=batch_size)
    dataset = dataset.shuffle(2 * self._num_files * self._num_records)
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.map(lambda x: x)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(buffer_size=None)

    # Auto shard.
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    # Verify output.
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    with self.cached_session():
      actual = []
      num_iterations = (self._num_files * self._num_records * num_epochs) // (
          self._num_shards * batch_size)
      for _ in range(num_iterations):
        actual.extend(self.evaluate(next_element))
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next_element)

      expected = []
      for f in range(0, self._num_files, self._num_shards):
        for r in range(self._num_records):
          expected.append(self._record(r, f))
      expected *= num_epochs

      self.assertAllEqual(sorted(expected), sorted(actual))
Example 26
    def _handle_sharding(self, dataset):
        from tensorflow.python.distribute.input_ops import auto_shard_dataset
        dataset = auto_shard_dataset(dataset, self.size, self.rank)
        return dataset
Example 27
    def step(self, data_creator, epochs=1, verbose=1,
             callbacks=None, validation_data_creator=None, class_weight=None,
             steps_per_epoch=None, validation_steps=None, validation_freq=1):
        """Runs a training epoch and updates the model parameters."""

        # process datasets
        if self.backend == "horovod":
            import horovod.tensorflow.keras as hvd
            config = self.config.copy()
            assert "batch_size" in config, "batch_size must be set in config"
            config["batch_size"] = config["batch_size"] // hvd.size()
            train_dataset = data_creator(config)
            if validation_data_creator is not None:
                test_dataset = validation_data_creator(config)
            else:
                test_dataset = None
            from tensorflow.python.distribute.input_ops import auto_shard_dataset
            train_dataset = auto_shard_dataset(train_dataset, hvd.size(), hvd.rank())
            if test_dataset is not None:
                test_dataset = auto_shard_dataset(test_dataset, hvd.size(), hvd.rank())
        elif self.backend == "tf-distributed":
            with self.strategy.scope():
                train_dataset = data_creator(self.config)
                if validation_data_creator is not None:
                    test_dataset = validation_data_creator(self.config)
                else:
                    test_dataset = None
        else:
            train_dataset = data_creator(self.config)
            if validation_data_creator is not None:
                test_dataset = validation_data_creator(self.config)
            else:
                test_dataset = None

        # process other arguments
        if self.backend == "horovod":
            import horovod.tensorflow.keras as hvd
            hvd_callbacks = [hvd.callbacks.BroadcastGlobalVariablesCallback(0),
                             hvd.callbacks.MetricAverageCallback()]
            if hvd.rank() != 0:
                verbose = 0

            if callbacks is not None:
                callbacks = hvd_callbacks + callbacks
            else:
                callbacks = hvd_callbacks
        elif self.backend == "tf-distributed":
            if self.strategy.cluster_resolver.task_id != 0:
                verbose = 0

        history = self.model.fit(train_dataset,
                                 epochs=self.epoch + epochs,
                                 verbose=verbose,
                                 callbacks=callbacks,
                                 validation_data=test_dataset,
                                 class_weight=class_weight,
                                 initial_epoch=self.epoch,
                                 steps_per_epoch=steps_per_epoch,
                                 validation_steps=validation_steps,
                                 validation_freq=validation_freq)
        if history is None:
            stats = {}
        else:
            stats = {"train_" + k: v[-1] for k, v in history.history.items()}

        self.epoch += epochs
        return stats
Example 28
  def __init__(self,
               dataset,
               input_workers,
               strategy,
               split_batch_by=None,
               input_context=None):
    """Distribute the dataset on all workers.

    If `split_batch_by` is not None, we "split" each batch of the dataset by
    `split_batch_by` value.

    Args:
      dataset: `tf.data.Dataset` that will be used as the input source.
      input_workers: an `InputWorkers` object.
      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle last partial batch.
      split_batch_by: Optional integer. If present, we "split" each batch of the
        dataset by `split_batch_by` value.
      input_context: `InputContext` for sharding. Only pass this in for between
        graph multi-worker cases where there is only one `input_worker`. In
        these cases, we will shard based on the `input_pipeline_id` and
        `num_input_pipelines` in the `InputContext`.
    """
    super(DistributedDataset, self).__init__(input_workers=input_workers)

    # We clone and shard the dataset on each worker. The current setup tries to
    # shard the dataset by files if possible so that each worker sees a
    # different subset of files. If that is not possible, will attempt to shard
    # the final input such that each worker will run the entire preprocessing
    # pipeline and only receive its own shard of the dataset.
    if split_batch_by:
      try:
        # pylint: disable=protected-access
        with ops.colocate_with(dataset._variant_tensor):
          dataset = distribute._RebatchDataset(dataset, split_batch_by)
          # Add a prefetch to pipeline rebatching for performance.
          # TODO(rachelim): Instead of inserting an extra prefetch stage here,
          # leverage static graph rewrites to insert _RebatchDataset before
          # the final `prefetch` if it exists.
          dataset = dataset.prefetch(split_batch_by)
      except errors.InvalidArgumentError as e:
        if "without encountering a batch" in str(e):
          six.reraise(
              ValueError,
              ValueError(
                  "Call the `batch` method on the input Dataset in order to be "
                  "able to split your input across {} replicas.\n Please "
                  "the tf.distribute.Strategy guide. {}".format(
                      split_batch_by, e)),
              sys.exc_info()[2])
        else:
          raise

    # TODO(b/138745411): Remove once stateful transformations are supported.
    options = dataset_ops.Options()
    options.experimental_distribute._make_stateless = True  # pylint: disable=protected-access
    dataset = dataset.with_options(options)

    self._cloned_datasets = []
    if input_context:
      # Between-graph where we rely on the input_context for sharding
      assert input_workers.num_workers == 1
      dataset = input_ops.auto_shard_dataset(dataset,
                                             input_context.num_input_pipelines,
                                             input_context.input_pipeline_id)
      self._cloned_datasets.append(dataset)
    else:
      replicated_ds = distribute.replicate(dataset,
                                           input_workers.worker_devices)
      for i, worker in enumerate(input_workers.worker_devices):
        with ops.device(worker):
          cloned_dataset = replicated_ds[worker]
          cloned_dataset = cloned_dataset.with_options(dataset.options())
          cloned_dataset = input_ops.auto_shard_dataset(
              cloned_dataset, len(input_workers.worker_devices), i)
          self._cloned_datasets.append(cloned_dataset)

    self._input_workers = input_workers
    self._strategy = strategy
    self._element_spec = _create_distributed_tensor_spec(self._strategy,
                                                         dataset.element_spec)  # pylint: disable=protected-access
Example 29
    def testTFRecordDataset(self):
        dataset = readers.TFRecordDataset(self._createTFRecordFiles())
        dataset = input_ops.auto_shard_dataset(dataset, self._num_shards,
                                               self._shard_index)

        self._verifySimpleShardingOutput(dataset, self._record)
Example 30
    def __init__(self,
                 dataset,
                 input_workers,
                 strategy,
                 split_batch_by=None,
                 input_context=None):
        """Distribute the dataset on all workers.

    If `split_batch_by` is not None, we "split" each batch of the dataset by
    `split_batch_by` value.

    Args:
      dataset: `tf.data.Dataset` that will be used as the input source.
      input_workers: an `InputWorkers` object.
      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle last partial batch.
      split_batch_by: Optional integer. If present, we "split" each batch of the
        dataset by `split_batch_by` value.
      input_context: `InputContext` for sharding. Only pass this in for between
        graph multi-worker cases where there is only one `input_worker`. In
        these cases, we will shard based on the `input_pipeline_id` and
        `num_input_pipelines` in the `InputContext`.
    """
        super(DistributedDataset, self).__init__(input_workers=input_workers)

        # We clone and shard the dataset on each worker. The current setup tries to
        # shard the dataset by files if possible so that each worker sees a
        # different subset of files. If that is not possible, will attempt to shard
        # the final input such that each worker will run the entire preprocessing
        # pipeline and only receive its own shard of the dataset.
        if split_batch_by:
            try:
                dataset = distribute._RebatchDataset(dataset, split_batch_by)  # pylint: disable=protected-access
            except errors.InvalidArgumentError as e:
                if "without encountering a batch" in str(e):
                    six.reraise(
                        ValueError,
                        ValueError(
                            "Call the `batch` method on the input Dataset in order to be "
                            "able to split your input across {} replicas.\n Please "
                            "the tf.distribute.Strategy guide. {}".format(
                                split_batch_by, e)),
                        sys.exc_info()[2])
                else:
                    raise

        self._cloned_datasets = []
        if input_context:
            # Between-graph where we rely on the input_context for sharding
            assert input_workers.num_workers == 1
            dataset = input_ops.auto_shard_dataset(  # pylint: disable=protected-access
                dataset, input_context.num_input_pipelines,
                input_context.input_pipeline_id)
            self._cloned_datasets.append(dataset)
        else:
            for i, worker in enumerate(input_workers.worker_devices):
                with ops.device(worker):
                    cloned_dataset = dataset
                    if not context.executing_eagerly():
                        cloned_dataset = input_ops._clone_dataset(dataset)  # pylint: disable=protected-access
                        cloned_dataset = cloned_dataset.with_options(
                            dataset.options())
                    # TODO(b/129506833): Figure out between graph cases
                    cloned_dataset = input_ops.auto_shard_dataset(  # pylint: disable=protected-access
                        cloned_dataset, len(input_workers.worker_devices), i)
                    self._cloned_datasets.append(cloned_dataset)

        self._input_workers = input_workers
        # TODO(anjalisridhar): Identify if we need to set this property on the
        # iterator.
        self._element_structure = dataset._element_structure  # pylint: disable=protected-access
        self._strategy = strategy
Example 31
  def testTFRecordDataset(self):
    dataset = readers.TFRecordDataset(self._createTFRecordFiles())
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    self._verifySimpleShardingOutput(dataset, self._record)