def __init__(self, dataset, input_workers, strategy, split_batch_by=None,
             input_context=None, **kwargs):
  """Distribute the dataset on all workers.

  If `split_batch_by` is not None, we "split" each batch of the dataset by
  `split_batch_by` value.

  Args:
    dataset: `tf.data.Dataset` that will be used as the input source.
    input_workers: an `InputWorkers` object.
    strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
      handle the last partial batch.
    split_batch_by: Optional integer. If present, we "split" each batch of
      the dataset by `split_batch_by` value.
    input_context: `InputContext` for sharding. Only pass this in for
      between-graph multi-worker cases where there is only one
      `input_worker`. In these cases, we will shard based on the
      `input_pipeline_id` and `num_input_pipelines` in the `InputContext`.
    **kwargs: Additional experimental flags. Will be removed in the future.
  """
  # We clone and shard the dataset on each worker. The current setup tries to
  # shard the dataset by files if possible so that each worker sees a
  # different subset of files. If that is not possible, we will attempt to
  # shard the final input such that each worker will run the entire
  # preprocessing pipeline and only receive its own shard of the dataset.
  assert isinstance(input_workers, InputWorkers)
  if split_batch_by:
    dataset = distribute._RebatchDataset(dataset, split_batch_by)  # pylint: disable=protected-access

  self._cloned_datasets = []
  if input_context:
    # Between-graph case where we rely on the input_context for sharding.
    assert input_workers.num_workers == 1
    dataset = input_ops.auto_shard_dataset(  # pylint: disable=protected-access
        dataset, input_context.num_input_pipelines,
        input_context.input_pipeline_id)
    self._cloned_datasets.append(dataset)
  else:
    for i, worker in enumerate(input_workers.worker_devices):
      with ops.device(worker):
        cloned_dataset = dataset
        if not context.executing_eagerly():
          cloned_dataset = input_ops._clone_dataset(dataset)  # pylint: disable=protected-access
          cloned_dataset = cloned_dataset.with_options(dataset.options())
        # TODO(b/129506833): Figure out between-graph cases.
        cloned_dataset = input_ops.auto_shard_dataset(  # pylint: disable=protected-access
            cloned_dataset, len(input_workers.worker_devices), i)
        self._cloned_datasets.append(cloned_dataset)

  self._input_workers = input_workers
  # TODO(anjalisridhar): Identify if we need to set this property on the
  # iterator.
  self._element_structure = dataset._element_structure  # pylint: disable=protected-access
  self._strategy = strategy
  self._kwargs = kwargs
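# Illustrative sketch of the between-graph branch above: an `InputContext`
# supplies the shard count and shard index that feed `auto_shard_dataset`.
# This is a hedged example using only public TensorFlow APIs; the file
# pattern and pipeline counts are assumptions, and `Dataset.shard` stands in
# for the element-level fallback described in the docstring.
import tensorflow as tf

def example_between_graph_shard(file_pattern="/tmp/data/tf_record.*.txt"):
    ctx = tf.distribute.InputContext(num_input_pipelines=2,
                                     input_pipeline_id=0)
    dataset = tf.data.Dataset.list_files(file_pattern, shuffle=False)
    dataset = dataset.flat_map(tf.data.TFRecordDataset)
    # Roughly what auto_shard_dataset(dataset, ctx.num_input_pipelines,
    # ctx.input_pipeline_id) achieves when file-level sharding is not
    # possible: keep every N-th element for this pipeline.
    return dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)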
def step(self, data_creator, epochs=1, verbose=1, callbacks=None,
         validation_data_creator=None, class_weight=None,
         steps_per_epoch=None, validation_steps=None, validation_freq=1):
    """Runs a training epoch and updates the model parameters."""
    train_dataset = data_creator(self.config)
    if validation_data_creator is not None:
        test_dataset = validation_data_creator(self.config)
    else:
        test_dataset = None

    if self.backend == "horovod":
        import horovod.tensorflow.keras as hvd
        from tensorflow.python.distribute.input_ops import auto_shard_dataset
        train_dataset = auto_shard_dataset(train_dataset, hvd.size(),
                                           hvd.rank())
        if test_dataset is not None:
            test_dataset = auto_shard_dataset(test_dataset, hvd.size(),
                                              hvd.rank())

    if self.backend == "horovod":
        import horovod.tensorflow.keras as hvd
        hvd_callbacks = [
            hvd.callbacks.BroadcastGlobalVariablesCallback(0),
            hvd.callbacks.MetricAverageCallback()
        ]
        if hvd.rank() != 0:
            verbose = 0
        if callbacks is not None:
            callbacks = hvd_callbacks + callbacks
        else:
            callbacks = hvd_callbacks

    history = self.model.fit(train_dataset,
                             epochs=self.epoch + epochs,
                             verbose=verbose,
                             callbacks=callbacks,
                             validation_data=test_dataset,
                             class_weight=class_weight,
                             initial_epoch=self.epoch,
                             steps_per_epoch=steps_per_epoch,
                             validation_steps=validation_steps,
                             validation_freq=validation_freq)
    if history is None:
        stats = {}
    else:
        stats = {"train_" + k: v[-1] for k, v in history.history.items()}

    self.epoch += epochs
    return stats
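# Hedged sketch of the kind of `data_creator` callable that `step` above
# expects: a function taking the runner's config dict and returning a
# batched `tf.data.Dataset` that can be passed straight to `model.fit`.
# The feature shapes, sizes, and default batch size are assumptions.
import numpy as np
import tensorflow as tf

def example_train_data_creator(config):
    batch_size = config.get("batch_size", 32)
    x = np.random.rand(1024, 10).astype("float32")
    y = np.random.randint(0, 2, size=(1024,)).astype("float32")
    return tf.data.Dataset.from_tensor_slices((x, y)).batch(batch_size)

# A runner implementing `step` could then be driven roughly like:
#   stats = runner.step(example_train_data_creator, epochs=1)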
def testTextLineReader(self):
  dataset = readers.TextLineDataset(self._createTextFiles())

  dataset = input_ops.auto_shard_dataset(
      dataset, self._num_shards, self._shard_index)

  self._verifySimpleShardingOutput(dataset, self._text_line)
def testFixedLengthReaderWithFlatMap(self):
  dataset = readers.FixedLengthRecordDataset(
      self._createFixedLengthRecordFiles(), self._record_bytes)

  dataset = input_ops.auto_shard_dataset(
      dataset, self._num_shards, self._shard_index)

  self._verifySimpleShardingOutput(dataset, self._fixed_length_record)
def testFixedLengthReader(self):
  dataset = readers.FixedLengthRecordDataset(
      self._createFixedLengthRecordFiles(), self._record_bytes)

  dataset = input_ops.auto_shard_dataset(dataset, self._num_shards,
                                         self._shard_index)

  self._verifySimpleShardingOutput(dataset, self._fixed_length_record)
def DISABLED_testTextLineReaderWithFlatMap(self):
  dataset = dataset_ops.Dataset.from_tensor_slices(self._createTextFiles())
  dataset = dataset.flat_map(readers.TextLineDataset)

  dataset = input_ops.auto_shard_dataset(
      dataset, self._num_shards, self._shard_index)

  self._verifySimpleShardingOutput(dataset, self._text_line)
def testFlatMap(self):
  dataset = dataset_ops.Dataset.from_tensor_slices(
      self._createTFRecordFiles())
  dataset = dataset.flat_map(readers.TFRecordDataset)

  dataset = input_ops.auto_shard_dataset(
      dataset, self._num_shards, self._shard_index)

  self._verifySimpleShardingOutput(dataset, self._record)
def testFlatMap(self):
  dataset = dataset_ops.Dataset.from_tensor_slices(
      self._createTFRecordFiles())
  dataset = dataset.flat_map(readers.TFRecordDataset)

  dataset = input_ops.auto_shard_dataset(dataset, self._num_shards,
                                         self._shard_index)

  self._verifySimpleShardingOutput(dataset, self._record)
def testTextLineReaderWithFlatMap(self):
  dataset = dataset_ops.Dataset.from_tensor_slices(
      self._createTextFiles())
  dataset = dataset.flat_map(readers.TextLineDataset)

  dataset = input_ops.auto_shard_dataset(dataset, self._num_shards,
                                         self._shard_index)

  self._verifySimpleShardingOutput(dataset, self._text_line)
def validate(self, data_creator, verbose=1, sample_weight=None,
             steps=None, callbacks=None, data_config=None):
    """Evaluates the model on the validation data set."""
    config = self.config.copy()
    if data_config is not None:
        config.update(data_config)

    if self.backend == "horovod":
        import horovod.tensorflow.keras as hvd
        assert "batch_size" in config, "batch_size must be set in config"
        config["batch_size"] = config["batch_size"] // hvd.size()
        dataset = data_creator(config)
        from tensorflow.python.distribute.input_ops import auto_shard_dataset
        dataset = auto_shard_dataset(dataset, hvd.size(), hvd.rank())
    elif self.backend == "tf-distributed":
        with self.strategy.scope():
            dataset = data_creator(config)
    else:
        dataset = data_creator(config)

    if self.backend == "horovod":
        import horovod.tensorflow.keras as hvd
        if hvd.rank() != 0:
            verbose = 0
    elif self.backend == "tf-distributed":
        if self.strategy.cluster_resolver.task_id != 0:
            verbose = 0

    params = dict(
        verbose=verbose,
        sample_weight=sample_weight,
        steps=steps,
        callbacks=callbacks,
    )
    results = self.model.evaluate(dataset, **params)
    if results is None:
        # Using local Model since model.evaluate() returns None
        # for MultiWorkerMirroredStrategy
        logger.warning("Running a local model to get validation score.")
        self.local_model = self.model_creator(self.config)
        self.local_model.set_weights(self.model.get_weights())
        results = self.local_model.evaluate(dataset, **params)

    if isinstance(results, list):
        stats = {
            "validation_" + k: v
            for k, v in zip(self.model.metrics_names, results)
        }
    else:
        stats = {"results": results}
    return stats
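# Hedged usage note for `validate` above: the runner config is copied and
# overlaid with `data_config`, so per-call settings such as an evaluation
# batch size can be overridden without touching the training config,
# roughly: runner.validate(val_data_creator, data_config={"batch_size": 256}).
# The returned stats mirror the Keras metric names; a minimal reconstruction
# of that mapping, with assumed metric names and values:
def example_validation_stats(metrics_names=("loss", "accuracy"),
                             results=(0.42, 0.88)):
    return {"validation_" + k: v for k, v in zip(metrics_names, results)}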
def testFixedLengthReaderWithFlatMap(self):
  dataset = dataset_ops.Dataset.from_tensor_slices(
      self._createFixedLengthRecordFiles())
  dataset = dataset.flat_map(
      lambda f: readers.FixedLengthRecordDataset(f, self._record_bytes))

  dataset = input_ops.auto_shard_dataset(dataset, self._num_shards,
                                         self._shard_index)

  self._verifySimpleShardingOutput(dataset, self._fixed_length_record)
def DISABLED_testFixedLengthReaderWithFlatMap(self):
  dataset = dataset_ops.Dataset.from_tensor_slices(
      self._createFixedLengthRecordFiles())
  dataset = dataset.flat_map(
      lambda f: readers.FixedLengthRecordDataset(f, self._record_bytes))

  dataset = input_ops.auto_shard_dataset(
      dataset, self._num_shards, self._shard_index)

  self._verifySimpleShardingOutput(dataset, self._fixed_length_record)
def DISABLED_testZip(self):
  dataset1 = readers.TFRecordDataset(self._createTFRecordFiles())
  dataset2 = readers.TextLineDataset(self._createTextFiles())
  dataset = dataset_ops.Dataset.zip((dataset1, dataset2))

  dataset = input_ops.auto_shard_dataset(
      dataset, self._num_shards, self._shard_index)

  record_fn = lambda r, f: (self._record(r, f), self._text_line(r, f))
  self._verifySimpleShardingOutput(dataset, record_fn)
def testZip(self):
  dataset1 = readers.TFRecordDataset(self._createTFRecordFiles())
  dataset2 = readers.TextLineDataset(self._createTextFiles())
  dataset = dataset_ops.Dataset.zip((dataset1, dataset2))

  dataset = input_ops.auto_shard_dataset(dataset, self._num_shards,
                                         self._shard_index)

  record_fn = lambda r, f: (self._record(r, f), self._text_line(r, f))
  self._verifySimpleShardingOutput(dataset, record_fn)
def testInterleave(self):
  dataset = dataset_ops.Dataset.from_tensor_slices(
      self._createTFRecordFiles())
  dataset = dataset.interleave(
      readers.TFRecordDataset, cycle_length=4,
      block_length=self._num_records)

  dataset = input_ops.auto_shard_dataset(
      dataset, self._num_shards, self._shard_index)

  # Since block_length == num records in each file, the output will still
  # contain records in order of files.
  self._verifySimpleShardingOutput(dataset, self._record)
def testZip(self):
  src1 = dataset_ops.Dataset.from_tensor_slices(
      self._createTFRecordFiles())
  dataset1 = src1.flat_map(readers.TFRecordDataset)

  src2 = dataset_ops.Dataset.from_tensor_slices(self._createTextFiles())
  dataset2 = src2.flat_map(readers.TextLineDataset)

  dataset = dataset_ops.Dataset.zip((dataset1, dataset2))

  dataset = input_ops.auto_shard_dataset(dataset, self._num_shards,
                                         self._shard_index)

  record_fn = lambda r, f: (self._record(r, f), self._text_line(r, f))
  self._verifySimpleShardingOutput(dataset, record_fn)
def __init__(self, dataset, input_workers, split_batch_by=None, **kwargs):
  """Make an iterator for the dataset on given devices.

  If `split_batch_by` is not None, we "split" each batch of the dataset by
  `split_batch_by` value. To achieve this, we first unbatch the input
  dataset and then rebatch it with the per-replica batch size that is
  calculated using `global_batch_size // split_batch_by`. The currently
  supported datasets are as follows:
  `dataset.batch()` is the last operation on the dataset OR
  `dataset.apply(map_and_batch)` is the last operation on the dataset OR
  `dataset.batch().prefetch()` are the last 2 operations on the dataset OR
  `dataset.apply(map_and_batch).prefetch()` are the last 2 operations.

  We clone and shard the dataset on each worker. The current setup tries to
  shard the dataset by files if possible so that each worker sees a
  different subset of files. If that is not possible, we will attempt to
  shard the final input such that each worker will run the entire
  preprocessing pipeline and only receive its own shard of the dataset.

  Args:
    dataset: `tf.data.Dataset` that will be used as the input source.
    input_workers: an `InputWorkers` object.
    split_batch_by: Optional integer. If present, we "split" each batch of
      the dataset by `split_batch_by` value.
    **kwargs: Additional experimental flags. Will be removed in the future.
  """
  assert isinstance(input_workers, InputWorkers)
  if split_batch_by:
    dataset = batching._RebatchDataset(dataset, split_batch_by)  # pylint: disable=protected-access

  iterators = []
  for i, worker in enumerate(input_workers.worker_devices):
    with ops.device(worker):
      worker_devices = input_workers.compute_devices_for_worker(i)
      cloned_dataset = dataset
      if not context.executing_eagerly():
        cloned_dataset = input_ops._clone_dataset(dataset)  # pylint: disable=protected-access
        cloned_dataset = cloned_dataset.with_options(dataset.options())
      # TODO(b/129506833): Figure out between-graph cases.
      cloned_dataset = input_ops.auto_shard_dataset(  # pylint: disable=protected-access
          cloned_dataset, len(input_workers.worker_devices), i)
      iterator = _SingleWorkerDatasetIterator(cloned_dataset, worker,
                                              worker_devices)
      iterators.append(iterator)

  self._element_structure = dataset._element_structure  # pylint: disable=protected-access

  super(DatasetIterator, self).__init__(input_workers, iterators, **kwargs)
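# Hedged illustration of the `split_batch_by` constraint documented above:
# the input pipeline must end in `batch()` (optionally followed by
# `prefetch()`) so the rebatch can derive a per-replica batch size. The
# global batch size and replica count below are assumed values.
import tensorflow as tf

global_batch_size = 64
num_replicas = 4  # plays the role of split_batch_by

example_dataset = (tf.data.Dataset.range(1024)
                   .map(lambda x: tf.cast(x, tf.float32))
                   .batch(global_batch_size)  # batch() is the last op...
                   .prefetch(1))              # ...or batch().prefetch()
# After rebatching, each replica would see batches of
# global_batch_size // num_replicas == 16 elements.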
def testListfiles(self):
  filenames = self._createTFRecordFiles()
  file_pattern = filenames[0].rsplit(os.sep, 1)[0] + "/tf_record.*.txt"
  dataset = dataset_ops.Dataset.list_files(file_pattern, shuffle=False)
  dataset = dataset.flat_map(readers.TFRecordDataset)

  dataset = input_ops.auto_shard_dataset(
      dataset, self._num_shards, self._shard_index)

  next_element_fn = self._getNext(dataset)
  actual, expected = [], []
  for f in range(self._shard_index, self._num_files, self._num_shards):
    for r in range(self._num_records):
      actual.append(self.evaluate(next_element_fn()))
      expected.append(self._record(r, f))

  with self.assertRaises(errors.OutOfRangeError):
    self.evaluate(next_element_fn())

  self.assertAllEqual(expected, actual)
def testConcat(self):
  dataset1 = readers.TFRecordDataset(self._createTFRecordFiles())
  dataset2 = readers.TextLineDataset(self._createTextFiles())
  dataset = dataset1.concatenate(dataset2)

  dataset = input_ops.auto_shard_dataset(
      dataset, self._num_shards, self._shard_index)

  iterator = dataset.make_one_shot_iterator()
  next_element = iterator.get_next()
  with self.cached_session() as sess:
    for f in range(self._shard_index, self._num_files, self._num_shards):
      for r in range(self._num_records):
        self.assertAllEqual(self._record(r, f), sess.run(next_element))
    for f in range(self._shard_index, self._num_files, self._num_shards):
      for r in range(self._num_records):
        self.assertAllEqual(self._text_line(r, f), sess.run(next_element))
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(next_element)
def testListfiles(self):
  filenames = self._createTFRecordFiles()
  file_pattern = filenames[0].rsplit("/", 1)[0] + "/tf_record.*.txt"
  dataset = dataset_ops.Dataset.list_files(file_pattern, shuffle=False)
  dataset = dataset.flat_map(readers.TFRecordDataset)

  dataset = input_ops.auto_shard_dataset(
      dataset, self._num_shards, self._shard_index)

  iterator = dataset.make_one_shot_iterator()
  next_element = iterator.get_next()
  with self.cached_session() as sess:
    actual, expected = [], []
    for f in range(self._shard_index, self._num_files, self._num_shards):
      for r in range(self._num_records):
        actual.append(sess.run(next_element))
        expected.append(self._record(r, f))
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(next_element)
    self.assertAllEqual(expected, actual)
def testConcat(self):
  dataset1 = readers.TFRecordDataset(self._createTFRecordFiles())
  dataset2 = readers.TextLineDataset(self._createTextFiles())
  dataset = dataset1.concatenate(dataset2)

  dataset = input_ops.auto_shard_dataset(
      dataset, self._num_shards, self._shard_index)

  next_element_fn = self._getNext(dataset)
  for f in range(self._shard_index, self._num_files, self._num_shards):
    for r in range(self._num_records):
      self.assertAllEqual(
          self._record(r, f), self.evaluate(next_element_fn()))
  for f in range(self._shard_index, self._num_files, self._num_shards):
    for r in range(self._num_records):
      self.assertAllEqual(
          self._text_line(r, f), self.evaluate(next_element_fn()))
  with self.assertRaises(errors.OutOfRangeError):
    self.evaluate(next_element_fn())
def testListfiles(self):
  filenames = self._createTFRecordFiles()
  file_pattern = filenames[0].rsplit("/", 1)[0] + "/tf_record.*.txt"
  dataset = dataset_ops.Dataset.list_files(file_pattern, shuffle=False)
  dataset = dataset.flat_map(readers.TFRecordDataset)

  dataset = input_ops.auto_shard_dataset(dataset, self._num_shards,
                                         self._shard_index)

  iterator = dataset.make_one_shot_iterator()
  next_element = iterator.get_next()
  with self.cached_session() as sess:
    actual, expected = [], []
    for f in range(self._shard_index, self._num_files, self._num_shards):
      for r in range(self._num_records):
        actual.append(sess.run(next_element))
        expected.append(self._record(r, f))
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(next_element)
    self.assertAllEqual(expected, actual)
def testConcat(self):
  dataset1 = readers.TFRecordDataset(self._createTFRecordFiles())
  dataset2 = readers.TextLineDataset(self._createTextFiles())
  dataset = dataset1.concatenate(dataset2)

  dataset = input_ops.auto_shard_dataset(dataset, self._num_shards,
                                         self._shard_index)

  iterator = dataset.make_one_shot_iterator()
  next_element = iterator.get_next()
  with self.cached_session() as sess:
    for f in range(self._shard_index, self._num_files, self._num_shards):
      for r in range(self._num_records):
        self.assertAllEqual(self._record(r, f), sess.run(next_element))
    for f in range(self._shard_index, self._num_files, self._num_shards):
      for r in range(self._num_records):
        self.assertAllEqual(self._text_line(r, f), sess.run(next_element))
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(next_element)
def testComplexPipeline(self):
  # Set up a complex input pipeline.
  batch_size = 2
  num_epochs = 5
  dataset = dataset_ops.Dataset.from_tensor_slices(
      self._createTFRecordFiles())
  dataset = dataset.shuffle(buffer_size=self._num_files)
  dataset = dataset.flat_map(readers.TFRecordDataset)
  dataset = dataset.prefetch(buffer_size=batch_size)
  dataset = dataset.shuffle(2 * self._num_files * self._num_records)
  dataset = dataset.repeat(num_epochs)
  dataset = dataset.map(lambda x: x)
  dataset = dataset.batch(batch_size)
  dataset = dataset.prefetch(buffer_size=None)

  # Auto shard.
  dataset = input_ops.auto_shard_dataset(dataset, self._num_shards,
                                         self._shard_index)

  # Verify output.
  iterator = dataset.make_one_shot_iterator()
  next_element = iterator.get_next()
  with self.cached_session() as sess:
    actual = []
    num_iterations = (self._num_files * self._num_records * num_epochs) // (
        self._num_shards * batch_size)
    for _ in range(num_iterations):
      actual.extend(sess.run(next_element))
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(next_element)

    expected = []
    for f in range(0, self._num_files, self._num_shards):
      for r in range(self._num_records):
        expected.append(self._record(r, f))
    expected *= num_epochs
    self.assertAllEqual(sorted(expected), sorted(actual))
def DISABLED_testComplexPipeline(self):
  # Set up a complex input pipeline.
  batch_size = 2
  num_epochs = 5
  dataset = dataset_ops.Dataset.from_tensor_slices(
      self._createTFRecordFiles())
  dataset = dataset.shuffle(buffer_size=self._num_files)
  dataset = dataset.flat_map(readers.TFRecordDataset)
  dataset = dataset.prefetch(buffer_size=batch_size)
  dataset = dataset.shuffle(2 * self._num_files * self._num_records)
  dataset = dataset.repeat(num_epochs)
  dataset = dataset.map(lambda x: x)
  dataset = dataset.batch(batch_size)
  dataset = dataset.prefetch(buffer_size=None)

  # Auto shard.
  dataset = input_ops.auto_shard_dataset(
      dataset, self._num_shards, self._shard_index)

  # Verify output.
  iterator = dataset.make_one_shot_iterator()
  next_element = iterator.get_next()
  with self.cached_session():
    actual = []
    num_iterations = (self._num_files * self._num_records * num_epochs) // (
        self._num_shards * batch_size)
    for _ in range(num_iterations):
      actual.extend(self.evaluate(next_element))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element)

    expected = []
    for f in range(0, self._num_files, self._num_shards):
      for r in range(self._num_records):
        expected.append(self._record(r, f))
    expected *= num_epochs
    self.assertAllEqual(sorted(expected), sorted(actual))
def _handle_sharding(self, dataset):
    from tensorflow.python.distribute.input_ops import auto_shard_dataset
    dataset = auto_shard_dataset(dataset, self.size, self.rank)
    return dataset
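# Hedged illustration: `auto_shard_dataset` above prefers to shard at the
# file level so each worker reads a disjoint subset of files; per the
# docstrings earlier in this section, it otherwise falls back to sharding
# the final elements. The public-API fallback behaves roughly like
# `Dataset.shard`; the worker count and rank below are assumed values.
import tensorflow as tf

def example_manual_shard(dataset, size=4, rank=0):
    # Keep every `size`-th element starting at `rank`, mirroring the
    # element-level fallback of auto_shard_dataset(dataset, size, rank).
    return dataset.shard(num_shards=size, index=rank)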
def step(self, data_creator, epochs=1, verbose=1, callbacks=None,
         validation_data_creator=None, class_weight=None,
         steps_per_epoch=None, validation_steps=None, validation_freq=1):
    """Runs a training epoch and updates the model parameters."""
    # process datasets
    if self.backend == "horovod":
        import horovod.tensorflow.keras as hvd
        config = self.config.copy()
        assert "batch_size" in config, "batch_size must be set in config"
        config["batch_size"] = config["batch_size"] // hvd.size()
        train_dataset = data_creator(config)
        if validation_data_creator is not None:
            test_dataset = validation_data_creator(config)
        else:
            test_dataset = None
        from tensorflow.python.distribute.input_ops import auto_shard_dataset
        train_dataset = auto_shard_dataset(train_dataset, hvd.size(),
                                           hvd.rank())
        if test_dataset is not None:
            test_dataset = auto_shard_dataset(test_dataset, hvd.size(),
                                              hvd.rank())
    elif self.backend == "tf-distributed":
        with self.strategy.scope():
            train_dataset = data_creator(self.config)
            if validation_data_creator is not None:
                test_dataset = validation_data_creator(self.config)
            else:
                test_dataset = None
    else:
        train_dataset = data_creator(self.config)
        if validation_data_creator is not None:
            test_dataset = validation_data_creator(self.config)
        else:
            test_dataset = None

    # process other arguments
    if self.backend == "horovod":
        import horovod.tensorflow.keras as hvd
        hvd_callbacks = [hvd.callbacks.BroadcastGlobalVariablesCallback(0),
                         hvd.callbacks.MetricAverageCallback()]
        if hvd.rank() != 0:
            verbose = 0
        if callbacks is not None:
            callbacks = hvd_callbacks + callbacks
        else:
            callbacks = hvd_callbacks
    elif self.backend == "tf-distributed":
        if self.strategy.cluster_resolver.task_id != 0:
            verbose = 0

    history = self.model.fit(train_dataset,
                             epochs=self.epoch + epochs,
                             verbose=verbose,
                             callbacks=callbacks,
                             validation_data=test_dataset,
                             class_weight=class_weight,
                             initial_epoch=self.epoch,
                             steps_per_epoch=steps_per_epoch,
                             validation_steps=validation_steps,
                             validation_freq=validation_freq)
    if history is None:
        stats = {}
    else:
        stats = {"train_" + k: v[-1] for k, v in history.history.items()}

    self.epoch += epochs
    return stats
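# Hedged illustration of the Horovod branch above: the global batch size in
# the config is divided by the number of workers before the data creator
# runs, so the effective global batch stays fixed as workers are added. The
# worker count and batch size below are illustrative assumptions.
def example_per_worker_config(config, num_workers=4):
    worker_config = config.copy()
    worker_config["batch_size"] = worker_config["batch_size"] // num_workers
    return worker_config

# example_per_worker_config({"batch_size": 128}) -> {"batch_size": 32}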
def __init__(self, dataset, input_workers, strategy, split_batch_by=None,
             input_context=None):
  """Distribute the dataset on all workers.

  If `split_batch_by` is not None, we "split" each batch of the dataset by
  `split_batch_by` value.

  Args:
    dataset: `tf.data.Dataset` that will be used as the input source.
    input_workers: an `InputWorkers` object.
    strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
      handle the last partial batch.
    split_batch_by: Optional integer. If present, we "split" each batch of
      the dataset by `split_batch_by` value.
    input_context: `InputContext` for sharding. Only pass this in for
      between-graph multi-worker cases where there is only one
      `input_worker`. In these cases, we will shard based on the
      `input_pipeline_id` and `num_input_pipelines` in the `InputContext`.
  """
  super(DistributedDataset, self).__init__(input_workers=input_workers)

  # We clone and shard the dataset on each worker. The current setup tries to
  # shard the dataset by files if possible so that each worker sees a
  # different subset of files. If that is not possible, we will attempt to
  # shard the final input such that each worker will run the entire
  # preprocessing pipeline and only receive its own shard of the dataset.
  if split_batch_by:
    try:
      # pylint: disable=protected-access
      with ops.colocate_with(dataset._variant_tensor):
        dataset = distribute._RebatchDataset(dataset, split_batch_by)
        # Add a prefetch to pipeline rebatching for performance.
        # TODO(rachelim): Instead of inserting an extra prefetch stage here,
        # leverage static graph rewrites to insert _RebatchDataset before
        # the final `prefetch` if it exists.
        dataset = dataset.prefetch(split_batch_by)
    except errors.InvalidArgumentError as e:
      if "without encountering a batch" in str(e):
        six.reraise(
            ValueError,
            ValueError(
                "Call the `batch` method on the input Dataset in order to "
                "be able to split your input across {} replicas.\n Please "
                "see the tf.distribute.Strategy guide. {}".format(
                    split_batch_by, e)),
            sys.exc_info()[2])
      else:
        raise

  # TODO(b/138745411): Remove once stateful transformations are supported.
  options = dataset_ops.Options()
  options.experimental_distribute._make_stateless = True  # pylint: disable=protected-access
  dataset = dataset.with_options(options)

  self._cloned_datasets = []
  if input_context:
    # Between-graph case where we rely on the input_context for sharding.
    assert input_workers.num_workers == 1
    dataset = input_ops.auto_shard_dataset(dataset,
                                           input_context.num_input_pipelines,
                                           input_context.input_pipeline_id)
    self._cloned_datasets.append(dataset)
  else:
    replicated_ds = distribute.replicate(dataset,
                                         input_workers.worker_devices)
    for i, worker in enumerate(input_workers.worker_devices):
      with ops.device(worker):
        cloned_dataset = replicated_ds[worker]
        cloned_dataset = cloned_dataset.with_options(dataset.options())
        cloned_dataset = input_ops.auto_shard_dataset(
            cloned_dataset, len(input_workers.worker_devices), i)
        self._cloned_datasets.append(cloned_dataset)

  self._input_workers = input_workers
  self._strategy = strategy
  self._element_spec = _create_distributed_tensor_spec(
      self._strategy, dataset.element_spec)  # pylint: disable=protected-access
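# Hedged illustration of the error path above: using `split_batch_by` on a
# dataset that was never batched trips the `InvalidArgumentError` branch and
# is re-raised as the ValueError asking the caller to add `.batch(...)`. The
# element counts and batch size below are assumed values.
import tensorflow as tf

unbatched = tf.data.Dataset.range(100).map(lambda x: x * 2)
# Passing `unbatched` with split_batch_by=4 would surface a message like:
# "Call the `batch` method on the input Dataset in order to be able to
#  split your input across 4 replicas."
batched = unbatched.batch(32)  # this form can be rebatched and split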
def testTFRecordDataset(self):
  dataset = readers.TFRecordDataset(self._createTFRecordFiles())

  dataset = input_ops.auto_shard_dataset(dataset, self._num_shards,
                                         self._shard_index)

  self._verifySimpleShardingOutput(dataset, self._record)
def __init__(self, dataset, input_workers, strategy, split_batch_by=None,
             input_context=None):
  """Distribute the dataset on all workers.

  If `split_batch_by` is not None, we "split" each batch of the dataset by
  `split_batch_by` value.

  Args:
    dataset: `tf.data.Dataset` that will be used as the input source.
    input_workers: an `InputWorkers` object.
    strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
      handle the last partial batch.
    split_batch_by: Optional integer. If present, we "split" each batch of
      the dataset by `split_batch_by` value.
    input_context: `InputContext` for sharding. Only pass this in for
      between-graph multi-worker cases where there is only one
      `input_worker`. In these cases, we will shard based on the
      `input_pipeline_id` and `num_input_pipelines` in the `InputContext`.
  """
  super(DistributedDataset, self).__init__(input_workers=input_workers)

  # We clone and shard the dataset on each worker. The current setup tries to
  # shard the dataset by files if possible so that each worker sees a
  # different subset of files. If that is not possible, we will attempt to
  # shard the final input such that each worker will run the entire
  # preprocessing pipeline and only receive its own shard of the dataset.
  if split_batch_by:
    try:
      dataset = distribute._RebatchDataset(dataset, split_batch_by)  # pylint: disable=protected-access
    except errors.InvalidArgumentError as e:
      if "without encountering a batch" in str(e):
        six.reraise(
            ValueError,
            ValueError(
                "Call the `batch` method on the input Dataset in order to "
                "be able to split your input across {} replicas.\n Please "
                "see the tf.distribute.Strategy guide. {}".format(
                    split_batch_by, e)),
            sys.exc_info()[2])
      else:
        raise

  self._cloned_datasets = []
  if input_context:
    # Between-graph case where we rely on the input_context for sharding.
    assert input_workers.num_workers == 1
    dataset = input_ops.auto_shard_dataset(  # pylint: disable=protected-access
        dataset, input_context.num_input_pipelines,
        input_context.input_pipeline_id)
    self._cloned_datasets.append(dataset)
  else:
    for i, worker in enumerate(input_workers.worker_devices):
      with ops.device(worker):
        cloned_dataset = dataset
        if not context.executing_eagerly():
          cloned_dataset = input_ops._clone_dataset(dataset)  # pylint: disable=protected-access
          cloned_dataset = cloned_dataset.with_options(dataset.options())
        # TODO(b/129506833): Figure out between-graph cases.
        cloned_dataset = input_ops.auto_shard_dataset(  # pylint: disable=protected-access
            cloned_dataset, len(input_workers.worker_devices), i)
        self._cloned_datasets.append(cloned_dataset)

  self._input_workers = input_workers
  # TODO(anjalisridhar): Identify if we need to set this property on the
  # iterator.
  self._element_structure = dataset._element_structure  # pylint: disable=protected-access
  self._strategy = strategy
def testTFRecordDataset(self):
  dataset = readers.TFRecordDataset(self._createTFRecordFiles())

  dataset = input_ops.auto_shard_dataset(
      dataset, self._num_shards, self._shard_index)

  self._verifySimpleShardingOutput(dataset, self._record)