def testAllowStatefulOp(self):
  """Replicates a stateful pipeline that opts in via `allow_stateful`.

  The source dataset maps every element through `random_uniform`. Setting
  `experimental_allow_stateful` on its options lets `distribute.replicate`
  build the copies, and the original plus both copies are then each pulled
  100 times.
  """
  with compat.forward_compatibility_horizon(2019, 9, 12):
    with ops.device(self._device0):
      dataset0 = dataset_ops.Dataset.range(100).map(
          lambda _: random_ops.random_uniform(  # pylint:disable=g-long-lambda
              [], minval=1, maxval=10, dtype=dtypes.float32))
      opt = dataset_ops.Options()
      opt.experimental_allow_stateful = True
      dataset0 = dataset0.with_options(opt)
    replicated_ds = distribute.replicate(
        dataset0, [self._device1, self._device2])
    dataset1 = replicated_ds[self._device1]
    dataset2 = replicated_ds[self._device2]

    with ops.device(self._device0):
      get_next0 = self.getNext(dataset0)
    with ops.device(self._device1):
      get_next1 = self.getNext(dataset1)
    with ops.device(self._device2):
      get_next2 = self.getNext(dataset2)

    # The getNext callables return next-element tensors. In graph mode those
    # tensors must be evaluated for the pipelines to actually run; simply
    # calling get_next*() would leave the replicated datasets unexercised.
    for _ in range(100):
      self.evaluate(get_next0())
      self.evaluate(get_next1())
      self.evaluate(get_next2())
def testWhitelistStatefulOp(self):
  """Replicates a pipeline whose only stateful op is whitelisted by name.

  `RandomUniform` is listed in `experimental_stateful_whitelist`. The original
  dataset and both replicas are then pulled 100 times each through a session
  connected to `self._target`.
  """
  with compat.forward_compatibility_horizon(2019, 9, 12):
    with ops.device(self._device0):
      whitelist_options = dataset_ops.Options()
      whitelist_options.experimental_stateful_whitelist = ["RandomUniform"]
      source = dataset_ops.Dataset.range(100).map(
          lambda _: random_ops.random_uniform(  # pylint:disable=g-long-lambda
              [], minval=1, maxval=10, dtype=dtypes.float32))
      source = source.with_options(whitelist_options)
    replicas = distribute.replicate(source, [self._device1, self._device2])
    placements = ((self._device0, source),
                  (self._device1, replicas[self._device1]),
                  (self._device2, replicas[self._device2]))

    next_fns = []
    for device, ds in placements:
      with ops.device(device):
        next_fns.append(self.getNext(ds))

    with session.Session(self._target) as sess:
      for _ in range(100):
        for get_next in next_fns:
          sess.run(get_next())
def testExternalStatePolicyIgnore(self):
  """Replication proceeds when the external state policy is IGNORE.

  The source map calls `random_uniform`. With
  `ExternalStatePolicy.IGNORE` set on its options, the dataset is replicated and
  the original plus both copies are evaluated 100 times each.
  """
  with ops.device(self._device0):
    policy_options = dataset_ops.Options()
    policy_options.experimental_external_state_policy = (
        distribute_options.ExternalStatePolicy.IGNORE)
    source = dataset_ops.Dataset.range(100).map(
        lambda _: random_ops.random_uniform(  # pylint:disable=g-long-lambda
            [], minval=1, maxval=10, dtype=dtypes.float32))
    source = source.with_options(policy_options)
  replicas = distribute.replicate(source, [self._device1, self._device2])
  placements = ((self._device0, source),
                (self._device1, replicas[self._device1]),
                (self._device2, replicas[self._device2]))

  next_fns = []
  for device, ds in placements:
    with ops.device(device):
      next_fns.append(self.getNext(ds))

  for _ in range(100):
    for get_next in next_fns:
      self.evaluate(get_next())
def testExternalStatePolicyFail(self):
  """Replication is rejected when the external state policy is FAIL.

  With `ExternalStatePolicy.FAIL` on a `random_uniform` map, the test expects
  a `FailedPreconditionError` somewhere while replicating the dataset or
  iterating the original and its copies.
  """
  with compat.forward_compatibility_horizon(2019, 11, 30):
    with ops.device(self._device0):
      source = dataset_ops.Dataset.range(100).map(
          lambda _: random_ops.random_uniform(  # pylint:disable=g-long-lambda
              [], minval=1, maxval=10, dtype=dtypes.float32))
      fail_options = dataset_ops.Options()
      fail_options.experimental_external_state_policy = (
          dataset_ops.ExternalStatePolicy.FAIL)
      source = source.with_options(fail_options)

    with self.assertRaises(errors.FailedPreconditionError):
      replicas = distribute.replicate(source, [self._device1, self._device2])
      placements = ((self._device0, source),
                    (self._device1, replicas[self._device1]),
                    (self._device2, replicas[self._device2]))
      next_fns = []
      for device, ds in placements:
        with ops.device(device):
          next_fns.append(self.getNext(ds))
      for _ in range(100):
        for get_next in next_fns:
          self.evaluate(get_next())
def testVariableInput(self):
  """Each copy of a variable-incrementing map is expected to yield 1..100.

  The map bumps a resource variable for every element. After the variable is
  initialized, both replicas and then the original are checked against
  `range(1, 101)`.
  """
  with ops.device(self._device0):
    counter_var = variable_scope.get_variable(
        "counter", (), dtypes.int32, use_resource=True)
    source = dataset_ops.Dataset.range(100).map(
        lambda _: counter_var.assign_add(1))
  replicas = distribute.replicate(source, [self._device1, self._device2])
  first_replica = replicas[self._device1]
  second_replica = replicas[self._device2]
  self.evaluate(counter_var.initializer)

  # Iterate through the original device last so that replication happens
  # before counter_var is modified. The order only matters in graph mode.
  ordered = ((self._device1, first_replica),
             (self._device2, second_replica),
             (self._device0, source))
  for device, ds in ordered:
    with ops.device(device):
      self.assertDatasetProduces(
          ds, range(1, 101), requires_initialization=True)
def testVariableInput(self):
  """Replicating a map that mutates a variable raises FailedPreconditionError.

  The error is expected either from `distribute.replicate` itself or when the
  first replica's variant tensor is evaluated.
  """
  with ops.device(self._device0):
    counter_var = variable_scope.get_variable(
        "counter", (), dtypes.int32, use_resource=True)
    source = dataset_ops.Dataset.range(100).map(
        lambda _: counter_var.assign_add(1))
  # We don't support stateful ops in functions as of now.
  with self.assertRaises(errors.FailedPreconditionError):
    replicas = distribute.replicate(source, [self._device1, self._device2])
    first_replica = replicas[self._device1]
    self.evaluate(first_replica._variant_tensor)  # pylint: disable=protected-access
def testMap(self):
  """A stateless doubling map replicates and yields 0, 2, ..., 198 everywhere."""
  with ops.device(self._device0):
    doubled = dataset_ops.Dataset.range(100).map(lambda x: x * 2)
  replicas = distribute.replicate(doubled, [self._device1, self._device2])
  placements = ((self._device0, doubled),
                (self._device1, replicas[self._device1]),
                (self._device2, replicas[self._device2]))

  for device, ds in placements:
    with ops.device(device):
      self.assertDatasetProduces(ds, range(0, 200, 2))
def testBasic(self):
  """A plain `range(100)` dataset replicates and yields 0..99 on every device."""
  with ops.device(self._device0):
    numbers = dataset_ops.Dataset.range(100)
  replicas = distribute.replicate(numbers, [self._device1, self._device2])
  placements = ((self._device0, numbers),
                (self._device1, replicas[self._device1]),
                (self._device2, replicas[self._device2]))

  for device, ds in placements:
    with ops.device(device):
      self.assertDatasetProduces(ds, range(100))
def testVariableInput(self):
  """Replicating or iterating a variable-mutating map raises InvalidArgumentError."""
  with ops.device(self._device0):
    counter_var = variable_scope.get_variable(
        "counter", (), dtypes.int32, use_resource=True)
    source = dataset_ops.Dataset.range(100).map(
        lambda _: counter_var.assign_add(1))
  # We don't support stateful ops in functions as of now.
  with self.assertRaises(errors.InvalidArgumentError):
    replicas = distribute.replicate(source, [self._device1, self._device2])
    first_replica = replicas[self._device1]
    with ops.device(self._device1):
      self.assertDatasetProduces(
          first_replica, range(100), requires_initialization=True)
def testVariableInput(self):
  """Initializing an iterator over a replicated stateful map fails remotely.

  Replication itself is built without error here; the test expects
  `FailedPreconditionError` when the replica's iterator initializer runs in a
  session against `self._target`.
  """
  with ops.device(self._device0):
    counter_var = variable_scope.get_variable(
        "counter", (), dtypes.int32, use_resource=True)
    source = dataset_ops.Dataset.range(100).map(
        lambda _: counter_var.assign_add(1))
  replicas = distribute.replicate(source, [self._device1, self._device2])
  first_replica = replicas[self._device1]
  with ops.device(self._device1):
    replica_iterator = dataset_ops.make_initializable_iterator(first_replica)
  # We don't support stateful ops in functions as of now.
  with session.Session(self._target) as sess:
    with self.assertRaises(errors.FailedPreconditionError):
      sess.run(replica_iterator.initializer)
def testFromTensorSlicesWithDataset(self):
  """A dataset nested via `from_tensor_slices` + `flat_map` replicates correctly.

  The inner `range(100)` dataset is wrapped as the single slice of an outer
  dataset and flattened back out. Every device is expected to produce 0..99.
  """
  with ops.device(self._device0):
    inner = dataset_ops.Dataset.range(100)
    flattened = dataset_ops.Dataset.from_tensor_slices([inner]).flat_map(
        lambda x: x)
  replicas = distribute.replicate(flattened, [self._device1, self._device2])
  placements = ((self._device0, flattened),
                (self._device1, replicas[self._device1]),
                (self._device2, replicas[self._device2]))

  for device, ds in placements:
    with ops.device(device):
      self.assertDatasetProduces(ds, range(100))
def testVariableInput(self):
  """A variable-mutating map cannot be replicated across processes.

  The test expects `InvalidArgumentError` while replicating the dataset or
  while iterating the original and the first replica side by side.
  """
  with ops.device(self._device0):
    counter_var = variable_scope.get_variable(
        "counter", (), dtypes.int32, use_resource=True)
    source = dataset_ops.Dataset.range(100).map(
        lambda _: counter_var.assign_add(1))
  # We don't support stateful ops across processes in functions as of now.
  with self.assertRaises(errors.InvalidArgumentError):
    replicas = distribute.replicate(source, [self._device1, self._device2])
    first_replica = replicas[self._device1]
    with ops.device(self._device0):
      next_from_source = self.getNext(source)
    with ops.device(self._device1):
      next_from_replica = self.getNext(first_replica)
    for _ in range(100):
      self.evaluate(next_from_source())
      self.evaluate(next_from_replica())
def testMap(self):
  """Session-driven check that every copy of a doubling map yields 2*i at step i."""
  with ops.device(self._device0):
    doubled = dataset_ops.Dataset.range(100).map(lambda x: x * 2)
  replicas = distribute.replicate(doubled, [self._device1, self._device2])
  placements = ((self._device0, doubled),
                (self._device1, replicas[self._device1]),
                (self._device2, replicas[self._device2]))

  next_fns = []
  for device, ds in placements:
    with ops.device(device):
      next_fns.append(self.getNext(ds))

  with session.Session(self._target) as sess:
    for step in range(100):
      for get_next in next_fns:
        self.assertEqual(step * 2, sess.run(get_next()))
def testVariableInput(self):
  """Iterating original then replicas yields consecutive counter ranges.

  The map bumps a resource variable for every element. The test expects the
  original to produce 1..100, the first replica 101..200 and the second
  replica 201..300, i.e. each iteration continues counting where the previous
  one stopped.
  """
  with ops.device(self._device0):
    counter_var = variable_scope.get_variable(
        "counter", (), dtypes.int32, use_resource=True)
    source = dataset_ops.Dataset.range(100).map(
        lambda _: counter_var.assign_add(1))
  replicas = distribute.replicate(source, [self._device1, self._device2])
  ordered = ((self._device0, source),
             (self._device1, replicas[self._device1]),
             (self._device2, replicas[self._device2]))
  self.evaluate(counter_var.initializer)

  for position, (device, ds) in enumerate(ordered):
    first_value = 1 + 100 * position
    with ops.device(device):
      self.assertDatasetProduces(
          ds, range(first_value, first_value + 100),
          requires_initialization=True)
def _testVariableInput(self):
  """(Disabled) Replicating a variable-mutating map raises InvalidArgumentError.

  The leading underscore keeps the test runner from collecting this method.
  When enabled, it expects the error while replicating or while iterating the
  original and both replicas.
  """
  with ops.device(self._device0):
    counter_var = variable_scope.get_variable(
        "counter", (), dtypes.int32, use_resource=True)
    source = dataset_ops.Dataset.range(100).map(
        lambda _: counter_var.assign_add(1))
  with self.assertRaises(errors.InvalidArgumentError):
    replicas = distribute.replicate(source, [self._device1, self._device2])
    placements = ((self._device0, source),
                  (self._device1, replicas[self._device1]),
                  (self._device2, replicas[self._device2]))
    next_fns = []
    for device, ds in placements:
      with ops.device(device):
        next_fns.append(self.getNext(ds))
    for _ in range(100):
      for get_next in next_fns:
        self.evaluate(get_next())
def __init__(self,
             dataset,
             input_workers,
             strategy,
             split_batch_by=None,
             input_context=None):
  """Distribute the dataset on all workers.

  If `split_batch_by` is not None, we "split" each batch of the dataset by
  `split_batch_by` value.

  Args:
    dataset: `tf.data.Dataset` that will be used as the input source.
    input_workers: an `InputWorkers` object.
    strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
      handle last partial batch.
    split_batch_by: Optional integer. If present, we "split" each batch of the
      dataset by `split_batch_by` value.
    input_context: `InputContext` for sharding. Only pass this in for between
      graph multi-worker cases where there is only one `input_worker`. In
      these cases, we will shard based on the `input_pipeline_id` and
      `num_input_pipelines` in the `InputContext`.

  Raises:
    ValueError: if `split_batch_by` is set and rebatching fails with an
      `InvalidArgumentError` mentioning "without encountering a batch", i.e.
      the input dataset was never batched. Other `InvalidArgumentError`s are
      re-raised unchanged.
  """
  super(DistributedDataset, self).__init__(input_workers=input_workers)

  # We clone and shard the dataset on each worker. The current setup tries to
  # shard the dataset by files if possible so that each worker sees a
  # different subset of files. If that is not possible, will attempt to shard
  # the final input such that each worker will run the entire preprocessing
  # pipeline and only receive its own shard of the dataset.
  if split_batch_by:
    try:
      # pylint: disable=protected-access
      with ops.colocate_with(dataset._variant_tensor):
        dataset = distribute._RebatchDataset(dataset, split_batch_by)
        # Add a prefetch to pipeline rebatching for performance.
        # TODO(rachelim): Instead of inserting an extra prefetch stage here,
        # leverage static graph rewrites to insert _RebatchDataset before
        # the final `prefetch` if it exists.
        dataset = dataset.prefetch(split_batch_by)
    except errors.InvalidArgumentError as e:
      if "without encountering a batch" in str(e):
        # Translate the low-level rebatch failure into an actionable
        # ValueError while keeping the original traceback.
        six.reraise(
            ValueError,
            ValueError(
                "Call the `batch` method on the input Dataset in order to be "
                "able to split your input across {} replicas.\n Please see "
                "the tf.distribute.Strategy guide. {}".format(
                    split_batch_by, e)),
            sys.exc_info()[2])
      else:
        raise

  # TODO(b/138745411): Remove once stateful transformations are supported.
  options = dataset_ops.Options()
  options.experimental_distribute._make_stateless = True  # pylint: disable=protected-access
  dataset = dataset.with_options(options)

  self._cloned_datasets = []
  if input_context:
    # Between-graph where we rely on the input_context for sharding
    assert input_workers.num_workers == 1
    dataset = input_ops.auto_shard_dataset(dataset,
                                           input_context.num_input_pipelines,
                                           input_context.input_pipeline_id)
    self._cloned_datasets.append(dataset)
  else:
    worker_devices = input_workers.worker_devices
    replicated_ds = distribute.replicate(dataset, worker_devices)
    for i, worker in enumerate(worker_devices):
      with ops.device(worker):
        cloned_dataset = replicated_ds[worker]
        # Re-apply the source options to each copy. Presumably `replicate`
        # does not carry them over -- TODO confirm.
        cloned_dataset = cloned_dataset.with_options(dataset.options())
        # Shard i of len(worker_devices): each worker keeps only its slice.
        cloned_dataset = input_ops.auto_shard_dataset(
            cloned_dataset, len(worker_devices), i)
        self._cloned_datasets.append(cloned_dataset)

  self._input_workers = input_workers
  self._strategy = strategy
  self._element_spec = _create_distributed_tensor_spec(self._strategy,
                                                       dataset.element_spec)  # pylint: disable=protected-access