def _make_dataset_iterator(self, dataset):
    """Build the dataset iterator that feeds each of the TPU hosts.

    Args:
        dataset: The dataset to iterate over.

    Returns:
        An `input_lib.DatasetIterator` built from `dataset`, this object's
        input workers and its container strategy, with `split_batch_by` set
        to `self._num_replicas_in_sync`.
    """
    workers = self._input_workers
    strategy = self._container_strategy()
    return input_lib.DatasetIterator(
        dataset,
        workers,
        strategy,
        split_batch_by=self._num_replicas_in_sync,
    )
def _create_iterator(self, input_type, dataset_fn, worker_device_pairs,
                     devices, split_batch_by, enable_get_next_as_optional):
    """Create the input iterator selected by `input_type`.

    Args:
        input_type: `"input_fn"` selects an `InputFunctionIterator`; any
            other value selects a `DatasetIterator`.
        dataset_fn: For `"input_fn"`, handed to the iterator as the input
            function. Otherwise called once with a default-constructed
            `InputContext` to produce the dataset.
        worker_device_pairs: Passed to `input_lib.InputWorkers` along with
            the device map built from `devices`.
        devices: Devices used to build the `ReplicaDeviceMap`; their count
            is reported as `num_replicas_in_sync` in each input context.
        split_batch_by: Forwarded positionally to `DatasetIterator`; not
            used on the `"input_fn"` path.
        enable_get_next_as_optional: Forwarded to either iterator as
            `_enable_get_next_as_optional`.

    Returns:
        The constructed iterator.
    """
    input_workers = input_lib.InputWorkers(
        values.ReplicaDeviceMap(devices), worker_device_pairs)

    # Dataset path: a single dataset built from a default InputContext.
    if input_type != "input_fn":
        return input_lib.DatasetIterator(
            dataset_fn(distribute_lib.InputContext()),
            input_workers,
            split_batch_by,
            _enable_get_next_as_optional=enable_get_next_as_optional,
        )

    # Input-function path: one InputContext per input worker.
    per_worker_contexts = [
        distribute_lib.InputContext(
            num_input_pipelines=input_workers.num_workers,
            input_pipeline_id=pipeline_id,
            num_replicas_in_sync=len(devices),
        )
        for pipeline_id in range(input_workers.num_workers)
    ]
    return input_lib.InputFunctionIterator(
        dataset_fn,
        input_workers,
        per_worker_contexts,
        _enable_get_next_as_optional=enable_get_next_as_optional,
    )
def _make_dataset_iterator(self, dataset):
    """Distribute `dataset` to each local GPU through a dataset iterator.

    Args:
        dataset: The dataset to iterate over.

    Returns:
        An `input_lib.DatasetIterator` built from `dataset`, this object's
        input workers and `self._num_replicas_in_sync`, with the input
        context returned by `self._make_input_context()`.
    """
    # The input context is created first, before the other operands are read.
    context = self._make_input_context()
    workers = self._input_workers
    replica_count = self._num_replicas_in_sync
    return input_lib.DatasetIterator(
        dataset, workers, replica_count, input_context=context)
def _make_dataset_iterator(self, dataset):
    """Build the dataset iterator that feeds each of the TPU hosts.

    The input workers are created here from the items of
    `self._device_input_worker_devices`.

    Args:
        dataset: The dataset to iterate over.

    Returns:
        An `input_lib.DatasetIterator` built from `dataset`, the freshly
        created input workers and the container strategy, with
        `split_batch_by` set to `self._num_replicas_in_sync`.
    """
    worker_device_items = tuple(self._device_input_worker_devices.items())
    workers = input_lib.InputWorkers(worker_device_items)
    strategy = self._container_strategy()
    return input_lib.DatasetIterator(
        dataset,
        workers,
        strategy,
        split_batch_by=self._num_replicas_in_sync,
    )
def _make_dataset_iterator(self, dataset):
    """Create an iterator over `dataset` that leaves its batches unsplit.

    For backward compatibility this intentionally differs from the
    implementation in `tf.distribute.MirroredStrategy`: the batch size of
    the incoming dataset is treated as the per-replica batch size.

    Args:
        dataset: `tf.data.Dataset` for input.

    Returns:
        An `InputIterator` which returns inputs for each step of the
        computation.
    """
    workers = self._input_workers
    return input_lib.DatasetIterator(dataset, workers)
def _test_iterator(self, input_type, dataset_fn, worker_device_pairs,
                   expected_values, sess=None, split_batch_by=None):
    """Check that an input iterator yields `expected_values`, twice.

    The iterator is initialized and drained while comparing every step with
    `expected_values`, then one further read must raise `OutOfRangeError`,
    and finally the iterator is re-initialized and drained again.

    Args:
        input_type: `"input_fn"` tests an `InputFunctionIterator`; any other
            value tests a `DatasetIterator`.
        dataset_fn: Zero-argument callable returning the dataset.
        worker_device_pairs: Pairs whose second items are flattened into the
            device list; also passed to `input_lib.InputWorkers`.
        expected_values: Expected per-step values, one entry per step.
        sess: Optional session; when truthy, tensors are evaluated with
            `sess.run`, otherwise with `self.evaluate`.
        split_batch_by: Forwarded positionally to `DatasetIterator`.
    """
    devices = nest.flatten(
        [worker_devices for _, worker_devices in worker_device_pairs])
    input_workers = input_lib.InputWorkers(
        values.ReplicaDeviceMap(devices), worker_device_pairs)

    if input_type == "input_fn":
        contexts = [distribute_lib.InputContext() for _ in worker_device_pairs]

        def input_fn(_):
            # The input context argument is ignored.
            return dataset_fn()

        iterator = input_lib.InputFunctionIterator(
            input_fn, input_workers, contexts)
    else:
        iterator = input_lib.DatasetIterator(
            dataset_fn(), input_workers, split_batch_by)

    def evaluate(x):
        if sess:
            return sess.run(x)
        return self.evaluate(x)

    def fetch_step():
        # Pull the next element and evaluate its value on every replica.
        element = iterator.get_next()
        return evaluate([
            values.select_replica(replica_id, element)
            for replica_id in range(len(devices))
        ])

    def check_full_pass():
        evaluate(control_flow_ops.group(iterator.initialize()))
        for expected in expected_values:
            self.assertAllEqual(expected, fetch_step())

    check_full_pass()

    with self.assertRaises(errors.OutOfRangeError):
        fetch_step()

    # After re-initializing the iterator, should be able to iterate again.
    check_full_pass()
def _wrap_iterator(self, input_type, dataset_fn, input_workers, devices,
                   split_batch_by, enable_get_next_as_optional, strategy,
                   input_context=None):
    """Create the input iterator selected by `input_type`.

    Args:
        input_type: `"input_fn"` selects an `InputFunctionIterator`; any
            other value selects a `DatasetIterator`.
        dataset_fn: For `"input_fn"`, handed to the iterator as the input
            function. Otherwise called with `input_context` to produce the
            dataset.
        input_workers: Input workers given to the iterator; `num_workers`
            decides how many input contexts are created for `"input_fn"`.
        devices: Devices whose count is reported as `num_replicas_in_sync`
            in each created input context.
        split_batch_by: Forwarded to `DatasetIterator` only.
        enable_get_next_as_optional: Forwarded to either iterator as
            `_enable_get_next_as_optional`.
        strategy: Strategy object passed to either iterator.
        input_context: Used to shard the dataset for
            `MultiWorkerMirroredStrategy`. It does not apply to the in-graph
            case, where multiple `InputContext`s are needed, so it must be
            `None` when `input_type` is `"input_fn"`.

    Returns:
        The constructed iterator.
    """
    if input_type != "input_fn":
        return input_lib.DatasetIterator(
            dataset_fn(input_context),
            input_workers,
            strategy,
            split_batch_by=split_batch_by,
            input_context=input_context,
            _enable_get_next_as_optional=enable_get_next_as_optional,
        )

    self.assertIsNone(
        input_context,
        msg=("`The input_context` arg is only used to shard dataset in "
             "`MultiWorkerMirroredStrategy` when the input type is dataset."))

    # Note: `input_workers.num_workers` is always 1 in the between-graph case.
    per_worker_contexts = [
        distribute_lib.InputContext(
            num_input_pipelines=input_workers.num_workers,
            input_pipeline_id=pipeline_id,
            num_replicas_in_sync=len(devices),
        )
        for pipeline_id in range(input_workers.num_workers)
    ]
    return input_lib.InputFunctionIterator(
        dataset_fn,
        input_workers,
        per_worker_contexts,
        strategy,
        _enable_get_next_as_optional=enable_get_next_as_optional,
    )
def _make_dataset_iterator(self, dataset):
    """Build a dataset iterator over `dataset`.

    Args:
        dataset: The dataset to iterate over.

    Returns:
        An `input_lib.DatasetIterator` built from `dataset`, this object's
        input workers and its container strategy, with
        `num_replicas_in_sync` set to `self._num_replicas_in_sync`.
    """
    workers = self._input_workers
    strategy = self._container_strategy()
    replica_count = self._num_replicas_in_sync
    return input_lib.DatasetIterator(
        dataset, workers, strategy, num_replicas_in_sync=replica_count)
def _make_dataset_iterator(self, dataset):
    """Build a dataset iterator over `dataset`.

    Args:
        dataset: The dataset to iterate over.

    Returns:
        An `input_lib.DatasetIterator` built from `dataset`, this object's
        input workers and `self._num_replicas_in_sync`, in that positional
        order.
    """
    workers = self._input_workers
    replica_count = self._num_replicas_in_sync
    return input_lib.DatasetIterator(dataset, workers, replica_count)
def _make_dataset_iterator(self, dataset):
    """Create an iterator over `dataset` that leaves its batches unsplit.

    Args:
        dataset: The dataset to iterate over.

    Returns:
        An `input_lib.DatasetIterator` built from `dataset`, this object's
        input workers and its container strategy.
    """
    # `split_batch_by` is deliberately omitted: it is always 1 in this
    # strategy, and passing it would add unnecessary overhead to the dataset.
    workers = self._input_workers
    strategy = self._container_strategy()
    return input_lib.DatasetIterator(dataset, workers, strategy)
def _make_dataset_iterator(self, dataset):
    """Build a dataset iterator over `dataset`.

    Args:
        dataset: The dataset to iterate over.

    Returns:
        An `input_lib.DatasetIterator` built from `dataset` and this
        object's input workers.
    """
    workers = self._input_workers
    return input_lib.DatasetIterator(dataset, workers)
def _make_dataset_iterator(self, dataset):
    """Build the dataset iterator that feeds each of the TPU hosts.

    Args:
        dataset: The dataset to iterate over.

    Returns:
        An `input_lib.DatasetIterator` built from `dataset`, this object's
        input workers and `self._num_replicas_in_sync`, in that positional
        order.
    """
    iterator_args = (
        dataset,
        self._input_workers,
        self._num_replicas_in_sync,
    )
    return input_lib.DatasetIterator(*iterator_args)
def _make_dataset_iterator(self, dataset):
    """Create an iterator over `dataset` that leaves its batches unsplit.

    Args:
        dataset: The dataset to iterate over.

    Returns:
        An `input_lib.DatasetIterator` built from `dataset` and this
        object's input workers; no further arguments are supplied.
    """
    iterator = input_lib.DatasetIterator(dataset, self._input_workers)
    return iterator
def _make_dataset_iterator(self, dataset):
    """Build a dataset iterator over `dataset`.

    Args:
        dataset: The dataset to iterate over.

    Returns:
        An `input_lib.DatasetIterator` built from `dataset`, this object's
        input workers and its container strategy.
    """
    workers = self._input_workers
    container_strategy = self._container_strategy()
    iterator = input_lib.DatasetIterator(dataset, workers, container_strategy)
    return iterator
def _make_dataset_iterator(self, dataset):
    """Build the dataset iterator that feeds each of the TPU hosts.

    Args:
        dataset: The dataset to iterate over.

    Returns:
        An `input_lib.DatasetIterator` built from `dataset`, this object's
        input workers and `self._num_replicas_in_sync`, with
        `_enable_get_next_as_optional` switched on.
    """
    workers = self._input_workers
    replica_count = self._num_replicas_in_sync
    return input_lib.DatasetIterator(
        dataset,
        workers,
        replica_count,
        _enable_get_next_as_optional=True,
    )