def _make_input_fn_iterator(
    self,
    input_fn,
    replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
  """Builds a per-replica dataset from `input_fn` for the single local device.

  `input_fn` is invoked once with a default (single-pipeline) InputContext;
  the resulting dataset is wrapped for the one device this strategy owns.
  """
  dataset = self._call_dataset_fn(input_fn, distribute_lib.InputContext())
  return values.PerReplicaDataset(dataset, [self._device])
def _make_input_fn_iterator(
    self,
    input_fn,
    replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
  """Wraps `input_fn` in an input iterator for the single local device.

  Uses one worker ("/job:localhost") with one device and a default
  single-pipeline InputContext.
  """
  worker_device_pairs = [("/job:localhost", [self._device])]
  contexts = [distribute_lib.InputContext()]
  return values.InputFunctionIterator(input_fn, worker_device_pairs, contexts)
def testPerReplicaBatchSize(self):
  """Checks per-replica batch-size computation and its divisibility guard."""
  ctx = distribute_lib.InputContext(
      num_input_pipelines=2, input_pipeline_id=1, num_replicas_in_sync=6)
  # A global batch of 12 splits evenly across 6 in-sync replicas.
  self.assertEqual(2, ctx.get_per_replica_batch_size(12))
  # 13 is not divisible by 6, so the context must reject it.
  with self.assertRaises(ValueError):
    ctx.get_per_replica_batch_size(13)
def testProperties(self):
  """Verifies constructor arguments are exposed as read-back properties."""
  ctx = distribute_lib.InputContext(
      num_input_pipelines=2, input_pipeline_id=1, num_replicas_in_sync=6)
  self.assertEqual(2, ctx.num_input_pipelines)
  self.assertEqual(1, ctx.input_pipeline_id)
  self.assertEqual(6, ctx.num_replicas_in_sync)
def _test_iterator(self, input_fn, worker_device_pairs, expected_values,
                   sess=None):
  """Runs `input_fn` through an InputFunctionIterator and checks its output.

  Consumes the iterator once against `expected_values`, asserts that one
  more `get_next` raises OutOfRangeError, then re-initializes and consumes
  it a second time to prove the iterator is re-usable.
  """
  devices = nest.flatten([ds for _, ds in worker_device_pairs])
  contexts = [distribute_lib.InputContext() for _ in worker_device_pairs]
  iterator = values.InputFunctionIterator(input_fn, worker_device_pairs,
                                          contexts)

  def evaluate(x):
    # Graph mode runs through the supplied session; eager uses self.evaluate.
    return sess.run(x) if sess else self.evaluate(x)

  def consume_and_check():
    for expected in expected_values:
      element = iterator.get_next()
      actual = evaluate(
          [values.select_device(d, element) for d in devices])
      self.assertEqual(expected, actual)

  evaluate(iterator.initialize())
  consume_and_check()

  # One extra get_next past the end must signal exhaustion.
  with self.assertRaises(errors.OutOfRangeError):
    element = iterator.get_next()
    evaluate([values.select_device(d, element) for d in devices])

  # After re-initializing the iterator, should be able to iterate again.
  evaluate(iterator.initialize())
  consume_and_check()
def _make_input_fn_iterator(
    self,
    input_fn,
    replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
  """Builds a dataset from `input_fn`, sharded per worker when clustered.

  With a cluster spec, one input pipeline is created per worker device and
  combined into a MultiWorkerDataset; otherwise a single pipeline feeds all
  local devices through a PerReplicaDataset.
  """
  if not self._cluster_spec:
    # Single-worker case: one pipeline covering every local device.
    ctx = distribute_lib.InputContext(
        num_input_pipelines=1,
        input_pipeline_id=0,
        num_replicas_in_sync=self.num_replicas_in_sync)
    return values.PerReplicaDataset(
        self._call_dataset_fn(input_fn, ctx), self._devices)

  # Multi-worker case: one deferred dataset factory per worker device.
  num_pipelines = len(self._worker_devices)
  input_fns = []
  for pipeline_id in range(num_pipelines):
    ctx = distribute_lib.InputContext(
        num_input_pipelines=num_pipelines,
        input_pipeline_id=pipeline_id,
        num_replicas_in_sync=self.num_replicas_in_sync)
    input_fns.append(partial(self._call_dataset_fn, input_fn, ctx))
  return values.MultiWorkerDataset(input_fns, self._worker_devices,
                                   self._auto_shard_dataset)
def _make_input_fn_iterator(
    self,
    input_fn,
    replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
  """Distributes the dataset to each local GPU."""
  # Without a cluster spec this worker is the only pipeline (id 0);
  # otherwise derive this worker's pipeline id from its cluster position.
  pipeline_id = (0 if self._cluster_spec is None else
                 multi_worker_util.id_in_cluster(
                     self._cluster_spec, self._task_type, self._task_id))
  ctx = distribute_lib.InputContext(
      num_input_pipelines=self._num_workers,
      input_pipeline_id=pipeline_id,
      num_replicas_in_sync=self._num_replicas_in_sync)
  dataset = self._call_dataset_fn(input_fn, ctx)
  return values.PerReplicaDataset(dataset, self._devices, True)
def _make_input_fn_iterator(
    self,
    input_fn,
    replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
  """Creates one input pipeline per worker (or a single local pipeline).

  Clustered runs get one InputContext per worker device; the non-clustered
  fallback is a single "/job:localhost" pipeline over the local devices.
  """
  if self._cluster_spec:
    worker_device_pairs = self._worker_devices
    num_workers = len(worker_device_pairs)
  else:
    worker_device_pairs = [("/job:localhost", self._devices)]
    num_workers = 1
  input_contexts = [
      distribute_lib.InputContext(
          num_input_pipelines=num_workers,
          input_pipeline_id=pipeline_id,
          num_replicas_in_sync=self._num_replicas_in_sync)
      for pipeline_id in range(num_workers)
  ]
  return values.InputFunctionIterator(input_fn, worker_device_pairs,
                                      input_contexts)
def _make_input_fn_iterator(
    self,
    input_fn,
    replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
  """Distributes the dataset to each local GPU."""
  if self._cluster_spec:
    # Clustered: this worker's pipeline id and the total worker count come
    # from its position in the cluster spec.
    pipeline_id = multi_worker_util.id_in_cluster(
        self._cluster_spec, self._task_type, self._task_id)
    num_pipelines = multi_worker_util.worker_count(
        self._cluster_spec, self._task_type)
  else:
    pipeline_id, num_pipelines = 0, 1
  ctx = distribute_lib.InputContext(
      num_input_pipelines=num_pipelines,
      input_pipeline_id=pipeline_id,
      num_replicas_in_sync=self._num_replicas_in_sync)
  return values.InputFunctionIterator(
      input_fn, [(self._worker_device, self._compute_devices)], [ctx])