def _make_input_fn_iterator(
    self,
    input_fn,
    replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
  """Wraps `input_fn` in an iterator that feeds this strategy's one device.

  Args:
    input_fn: Callable forwarded unchanged to `values.InputFunctionIterator`.
    replication_mode: Accepted for interface compatibility; not read here.

  Returns:
    A `values.InputFunctionIterator` with a single ("/job:localhost",
    [self._device]) worker/device pair and one default `InputContext`.
  """
  local_pairs = [("/job:localhost", [self._device])]
  default_contexts = [distribute_lib.InputContext()]
  return values.InputFunctionIterator(input_fn, local_pairs, default_contexts)
def _test_iterator(self, input_fn, worker_device_pairs, expected_values,
                   sess=None):
  """Checks that an `InputFunctionIterator` yields `expected_values` twice.

  The iterator is drained once, must then raise `OutOfRangeError`, and after
  re-initialization must yield the same sequence again.

  Args:
    input_fn: Callable forwarded to `values.InputFunctionIterator`.
    worker_device_pairs: Sequence of (worker, devices) pairs; the devices are
      flattened to choose which per-device values to compare.
    expected_values: Iterable of expected per-device results, one per step.
    sess: Optional session; when truthy, tensors are run with `sess.run`,
      otherwise with `self.evaluate`.
  """
  devices = nest.flatten([ds for _, ds in worker_device_pairs])
  iterator = values.InputFunctionIterator(input_fn, worker_device_pairs)

  def evaluate(tensors):
    if sess:
      return sess.run(tensors)
    return self.evaluate(tensors)

  def fetch_per_device():
    # Pull one element and evaluate the slice belonging to every device.
    element = iterator.get_next()
    return evaluate([values.select_device(d, element) for d in devices])

  def check_one_pass():
    for expected in expected_values:
      self.assertEqual(expected, fetch_per_device())

  evaluate(iterator.initialize())
  check_one_pass()

  with self.assertRaises(errors.OutOfRangeError):
    fetch_per_device()

  # Re-initializing must rewind the iterator so a second full pass succeeds.
  evaluate(iterator.initialize())
  check_one_pass()
def _make_input_fn_iterator(
    self,
    input_fn,
    replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
  """Wraps `input_fn` in an iterator hosted on the canonical local CPU.

  Args:
    input_fn: Callable forwarded unchanged to `values.InputFunctionIterator`.
    replication_mode: Accepted for interface compatibility; not read here.

  Returns:
    A `values.InputFunctionIterator` whose single worker is the canonicalized
    "/device:CPU:0", feeding `self._device`, with one default `InputContext`.
  """
  host = device_util.canonicalize("/device:CPU:0")
  return values.InputFunctionIterator(
      input_fn,
      [(host, [self._device])],
      [distribute_lib.InputContext()])
def _make_input_fn_iterator(
    self,
    input_fn,
    replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
  """Builds an iterator with one `InputContext` per input worker.

  Args:
    input_fn: Callable forwarded unchanged to `values.InputFunctionIterator`.
    replication_mode: Accepted for interface compatibility; not read here.

  Returns:
    A `values.InputFunctionIterator` over `self._input_workers`. Context `i`
    reports pipeline id `i` out of `self._input_workers.num_workers`.
  """
  pipeline_count = self._input_workers.num_workers
  contexts = [
      distribute_lib.InputContext(
          num_input_pipelines=pipeline_count,
          input_pipeline_id=pipeline_id,
          num_replicas_in_sync=self._num_replicas_in_sync)
      for pipeline_id in range(pipeline_count)
  ]
  return values.InputFunctionIterator(input_fn, self._input_workers, contexts)
def _make_input_fn_iterator(
    self,
    input_fn,
    replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
  """Builds an iterator whose single `InputContext` identifies this task.

  Without a cluster spec the pipeline id is 0. Otherwise it is this task's
  position as computed by `multi_worker_util.id_in_cluster`.

  Args:
    input_fn: Callable forwarded unchanged to `values.InputFunctionIterator`.
    replication_mode: Accepted for interface compatibility; not read here.

  Returns:
    A `values.InputFunctionIterator` over `self._input_workers` with one
    `InputContext` for `self._num_workers` pipelines.
  """
  if self._cluster_spec is not None:
    pipeline_id = multi_worker_util.id_in_cluster(
        self._cluster_spec, self._task_type, self._task_id)
  else:
    pipeline_id = 0
  context = distribute_lib.InputContext(
      num_input_pipelines=self._num_workers,
      input_pipeline_id=pipeline_id,
      num_replicas_in_sync=self._num_replicas_in_sync)
  return values.InputFunctionIterator(input_fn, self._input_workers, [context])
def _make_input_fn_iterator(
    self,
    input_fn,
    replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
  """Builds an iterator with one `InputContext` per (worker, devices) pair.

  With a cluster spec, the pairs come from `self._worker_devices`. Otherwise
  a single ("/job:localhost", self._devices) pair is used.

  Args:
    input_fn: Callable forwarded unchanged to `values.InputFunctionIterator`.
    replication_mode: Accepted for interface compatibility; not read here.

  Returns:
    A `values.InputFunctionIterator` over the chosen worker/device pairs.
  """
  if self._cluster_spec:
    worker_device_pairs = self._worker_devices
  else:
    worker_device_pairs = [("/job:localhost", self._devices)]
  # Both branches need one pipeline per pair. For the local case that is
  # exactly 1, since the list above has a single entry.
  pipeline_count = len(worker_device_pairs)
  contexts = [
      distribute_lib.InputContext(
          num_input_pipelines=pipeline_count,
          input_pipeline_id=pipeline_id,
          num_replicas_in_sync=self._num_replicas_in_sync)
      for pipeline_id in range(pipeline_count)
  ]
  return values.InputFunctionIterator(input_fn, worker_device_pairs, contexts)
def _make_input_fn_iterator(
    self,
    input_fn,
    replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
  """Builds an iterator feeding `self._compute_devices` from this worker.

  With a cluster spec, the pipeline id and count come from
  `multi_worker_util`. Otherwise this is treated as pipeline 0 of 1.

  Args:
    input_fn: Callable forwarded unchanged to `values.InputFunctionIterator`.
    replication_mode: Accepted for interface compatibility; not read here.

  Returns:
    A `values.InputFunctionIterator` with the single pair
    (self._worker_device, self._compute_devices) and one `InputContext`.
  """
  if not self._cluster_spec:
    pipeline_id, pipeline_count = 0, 1
  else:
    pipeline_id = multi_worker_util.id_in_cluster(
        self._cluster_spec, self._task_type, self._task_id)
    pipeline_count = multi_worker_util.worker_count(
        self._cluster_spec, self._task_type)
  context = distribute_lib.InputContext(
      num_input_pipelines=pipeline_count,
      input_pipeline_id=pipeline_id,
      num_replicas_in_sync=self._num_replicas_in_sync)
  return values.InputFunctionIterator(
      input_fn,
      [(self._worker_device, self._compute_devices)],
      [context])
def _make_input_fn_iterator(
    self,
    input_fn,
    replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
  """Builds an iterator with one `InputContext` per (worker, devices) pair.

  In local mode, a single pair rooted at the canonicalized "/device:CPU:0"
  feeds `self._devices`. Otherwise the pairs come from `self._worker_devices`.

  Args:
    input_fn: Callable forwarded unchanged to `values.InputFunctionIterator`.
    replication_mode: Accepted for interface compatibility; not read here.

  Returns:
    A `values.InputFunctionIterator` over the chosen worker/device pairs.
  """
  if self._local_mode:
    host = device_util.canonicalize("/device:CPU:0")
    worker_device_pairs = [(host, self._devices)]
  else:
    worker_device_pairs = self._worker_devices
  # One pipeline per pair. In local mode the list above has exactly 1 entry.
  pipeline_count = len(worker_device_pairs)
  contexts = [
      distribute_lib.InputContext(
          num_input_pipelines=pipeline_count,
          input_pipeline_id=pipeline_id,
          num_replicas_in_sync=self._num_replicas_in_sync)
      for pipeline_id in range(pipeline_count)
  ]
  return values.InputFunctionIterator(input_fn, worker_device_pairs, contexts)
def _test_iterator(self, input_type, dataset_fn, worker_device_pairs,
                   expected_values, sess=None, split_batch_by=None):
  """Checks that a distributed iterator yields `expected_values` twice.

  The iterator is drained once, must then raise `OutOfRangeError`, and after
  re-initialization must yield the same sequence again.

  Args:
    input_type: "input_fn" to build a `values.InputFunctionIterator`; any
      other value builds a `values.DatasetIterator`.
    dataset_fn: Zero-argument callable returning the dataset to iterate.
    worker_device_pairs: Sequence of (worker, devices) pairs.
    expected_values: Iterable of expected per-replica results, one per step.
    sess: Optional session; when truthy, tensors are run with `sess.run`,
      otherwise with `self.evaluate`.
    split_batch_by: Forwarded to `values.DatasetIterator` only.
  """
  devices = nest.flatten([ds for _, ds in worker_device_pairs])
  device_map = values.ReplicaDeviceMap(devices)
  input_workers = values.InputWorkers(device_map, worker_device_pairs)

  if input_type != "input_fn":
    iterator = values.DatasetIterator(
        dataset_fn(), input_workers, split_batch_by)
  else:
    contexts = [distribute_lib.InputContext() for _ in worker_device_pairs]
    iterator = values.InputFunctionIterator(
        lambda _: dataset_fn(), input_workers, contexts)

  replica_ids = range(len(devices))

  def evaluate(tensors):
    return sess.run(tensors) if sess else self.evaluate(tensors)

  def fetch_per_replica():
    # Pull one element and evaluate the slice belonging to every replica.
    element = iterator.get_next()
    return evaluate([values.select_replica(r, element) for r in replica_ids])

  def check_one_pass():
    for expected in expected_values:
      self.assertAllEqual(expected, fetch_per_replica())

  evaluate(control_flow_ops.group(iterator.initialize()))
  check_one_pass()

  with self.assertRaises(errors.OutOfRangeError):
    fetch_per_replica()

  # Re-initializing must rewind the iterator so a second full pass succeeds.
  evaluate(control_flow_ops.group(iterator.initialize()))
  check_one_pass()
def _make_input_fn_iterator(
    self,
    input_fn,
    replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
  """Wraps `input_fn` in an iterator over `self._input_workers`.

  Args:
    input_fn: Callable forwarded unchanged to `values.InputFunctionIterator`.
    replication_mode: Accepted for interface compatibility; not read here.

  Returns:
    A `values.InputFunctionIterator` with a single default `InputContext`.
  """
  workers = self._input_workers
  contexts = [distribute_lib.InputContext()]
  return values.InputFunctionIterator(input_fn, workers, contexts)