  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    # A single default InputContext covers the one local input pipeline here.
    return input_lib.InputFunctionIterator(input_fn, self._input_workers,
                                           [distribute_lib.InputContext()],
                                           self._container_strategy())
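For context, the `input_fn` handed to `InputFunctionIterator` in these examples takes an `InputContext` and returns a `tf.data.Dataset`. A minimal sketch of such a function, using only the public `tf.distribute.InputContext` API; the batch size and dataset contents are made up for illustration:

import tensorflow as tf

GLOBAL_BATCH_SIZE = 64  # hypothetical value

def input_fn(input_context):
  # Derive the per-replica batch size from the global batch size.
  batch_size = input_context.get_per_replica_batch_size(GLOBAL_BATCH_SIZE)
  dataset = tf.data.Dataset.range(1000)
  # Give each input pipeline a disjoint shard of the data.
  dataset = dataset.shard(input_context.num_input_pipelines,
                          input_context.input_pipeline_id)
  return dataset.batch(batch_size)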
Example #2
    def _create_iterator(self, input_type, dataset_fn, worker_device_pairs,
                         devices, split_batch_by, enable_get_next_as_optional):
        """Builds an input_fn- or dataset-backed iterator over the given devices."""
        device_map = values.ReplicaDeviceMap(devices)
        input_workers = input_lib.InputWorkers(device_map, worker_device_pairs)

        if input_type == "input_fn":
            input_contexts = []
            for i in range(input_workers.num_workers):
                input_contexts.append(
                    distribute_lib.InputContext(
                        num_input_pipelines=input_workers.num_workers,
                        input_pipeline_id=i,
                        num_replicas_in_sync=len(devices)))

            iterator = input_lib.InputFunctionIterator(
                dataset_fn,
                input_workers,
                input_contexts,
                _enable_get_next_as_optional=enable_get_next_as_optional)
        else:
            iterator = input_lib.DatasetIterator(
                dataset_fn(distribute_lib.InputContext()),
                input_workers,
                split_batch_by,
                _enable_get_next_as_optional=enable_get_next_as_optional)
        return iterator
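The `worker_device_pairs` and `devices` arguments above follow the structure used throughout these test helpers: a list of (input worker device, compute devices) pairs, plus the flattened list of compute devices. A hypothetical single-worker, two-GPU layout, just for illustration:

# Hypothetical layout: one input worker feeding two local GPUs.
worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0", "/device:GPU:1"])]
devices = ["/device:GPU:0", "/device:GPU:1"]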
Example #3
  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    """Distributes the input function to each local GPU."""
    input_context = self._make_input_context()
    return input_lib.InputFunctionIterator(input_fn, self._input_workers,
                                           [input_context])
Example #4
  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    # Build one InputContext per worker so each input pipeline knows its shard.
    input_contexts = []
    num_workers = self._input_workers.num_workers
    for i in range(num_workers):
      input_contexts.append(distribute_lib.InputContext(
          num_input_pipelines=num_workers,
          input_pipeline_id=i,
          num_replicas_in_sync=self._num_replicas_in_sync))
    return input_lib.InputFunctionIterator(
        input_fn, self._input_workers, input_contexts)
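The loop above simply enumerates one `InputContext` per worker. A runnable sketch of what it produces, using the public `tf.distribute.InputContext` and made-up worker and replica counts:

import tensorflow as tf

num_workers = 3           # hypothetical worker count
num_replicas_in_sync = 6  # hypothetical replica count

input_contexts = [
    tf.distribute.InputContext(
        num_input_pipelines=num_workers,
        input_pipeline_id=i,
        num_replicas_in_sync=num_replicas_in_sync)
    for i in range(num_workers)
]
for ctx in input_contexts:
  # Each pipeline knows its own id and the total pipeline count.
  print("pipeline %d of %d, replicas in sync: %d" %
        (ctx.input_pipeline_id, ctx.num_input_pipelines,
         ctx.num_replicas_in_sync))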
Example #5
    def _test_iterator(self,
                       input_type,
                       dataset_fn,
                       worker_device_pairs,
                       expected_values,
                       sess=None,
                       split_batch_by=None):
        """Checks that the iterator yields `expected_values` and can be re-initialized."""
        devices = nest.flatten([ds for _, ds in worker_device_pairs])
        device_map = values.ReplicaDeviceMap(devices)
        input_workers = input_lib.InputWorkers(device_map, worker_device_pairs)

        if input_type == "input_fn":
            input_contexts = [
                distribute_lib.InputContext() for _ in worker_device_pairs
            ]
            input_fn = lambda _: dataset_fn()
            iterator = input_lib.InputFunctionIterator(input_fn, input_workers,
                                                       input_contexts)
        else:
            iterator = input_lib.DatasetIterator(dataset_fn(), input_workers,
                                                 split_batch_by)

        evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)

        evaluate(control_flow_ops.group(iterator.initialize()))

        for expected_value in expected_values:
            next_element = iterator.get_next()
            computed_value = evaluate([
                values.select_replica(r, next_element)
                for r in range(len(devices))
            ])
            self.assertAllEqual(expected_value, computed_value)

        with self.assertRaises(errors.OutOfRangeError):
            next_element = iterator.get_next()
            evaluate([
                values.select_replica(r, next_element)
                for r in range(len(devices))
            ])

        # After re-initializing the iterator, should be able to iterate again.
        evaluate(control_flow_ops.group(iterator.initialize()))

        for expected_value in expected_values:
            next_element = iterator.get_next()
            computed_value = evaluate([
                values.select_replica(r, next_element)
                for r in range(len(devices))
            ])
            self.assertAllEqual(expected_value, computed_value)
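The drive pattern in this test (initialize, consume until `OutOfRangeError`, re-initialize, consume again) mirrors ordinary TF1-style iterators. A self-contained sketch of the same cycle without any distribution machinery, for comparison:

import tensorflow as tf

tf.compat.v1.disable_eager_execution()

dataset = tf.compat.v1.data.Dataset.range(4).batch(2)
iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
next_element = iterator.get_next()

with tf.compat.v1.Session() as sess:
  sess.run(iterator.initializer)
  try:
    while True:
      print(sess.run(next_element))
  except tf.errors.OutOfRangeError:
    pass
  # Re-initializing allows a second pass, just like in the test above.
  sess.run(iterator.initializer)
  print(sess.run(next_element))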
Example #6
    def _make_input_fn_iterator(
            self,
            input_fn,
            replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
        """Distributes the dataset to each local GPU."""
        if self._cluster_spec is None:
            input_pipeline_id = 0
        else:
            input_pipeline_id = multi_worker_util.id_in_cluster(
                self._cluster_spec, self._task_type, self._task_id)
        input_context = distribute_lib.InputContext(
            num_input_pipelines=self._num_workers,
            input_pipeline_id=input_pipeline_id,
            num_replicas_in_sync=self._num_replicas_in_sync)

        return input_lib.InputFunctionIterator(input_fn, self._input_workers,
                                               [input_context])
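Here `input_pipeline_id` comes from the task's position in the cluster, so each worker reads a different shard. A hypothetical two-worker cluster spec, just to make the mapping concrete (host names are placeholders):

import tensorflow as tf

# Hypothetical cluster with two workers; with task_type="worker" and a given
# task_id, the snippet above builds an InputContext whose input_pipeline_id
# identifies that task among self._num_workers input pipelines.
cluster_spec = tf.train.ClusterSpec({
    "worker": ["worker0.example.com:2222", "worker1.example.com:2222"],
})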
Example #7
  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    # One InputContext per input worker, all seeing the same replica count.
    input_contexts = []
    input_workers = input_lib.InputWorkers(
        tuple(self._device_input_worker_devices.items()))
    num_workers = input_workers.num_workers
    for i in range(num_workers):
      input_contexts.append(
          distribute_lib.InputContext(
              num_input_pipelines=num_workers,
              input_pipeline_id=i,
              num_replicas_in_sync=self._num_replicas_in_sync))
    return input_lib.InputFunctionIterator(input_fn, input_workers,
                                           input_contexts,
                                           self._container_strategy())
Example #8
    def _wrap_iterator(self,
                       input_type,
                       dataset_fn,
                       input_workers,
                       devices,
                       split_batch_by,
                       enable_get_next_as_optional,
                       strategy,
                       input_context=None):
        # The `input_context` passed in is used to shard the dataset for
        # MultiWorkerMirroredStrategy. It doesn't apply to the in-graph case,
        # where multiple InputContexts are needed.
        if input_type == "input_fn":
            self.assertIsNone(
                input_context,
                msg=("The `input_context` arg is only used to shard the dataset "
                     "in `MultiWorkerMirroredStrategy` when the input type is "
                     "dataset."))

            input_contexts = []
            for i in range(input_workers.num_workers):
                input_contexts.append(
                    distribute_lib.InputContext(
                        # Note: `input_workers.num_workers` is always 1 in the
                        # between-graph case.
                        num_input_pipelines=input_workers.num_workers,
                        input_pipeline_id=i,
                        num_replicas_in_sync=len(devices)))

            iterator = input_lib.InputFunctionIterator(
                dataset_fn,
                input_workers,
                input_contexts,
                strategy,
                _enable_get_next_as_optional=enable_get_next_as_optional)
        else:
            iterator = input_lib.DatasetIterator(
                dataset_fn(input_context),
                input_workers,
                strategy,
                split_batch_by=split_batch_by,
                input_context=input_context,
                _enable_get_next_as_optional=enable_get_next_as_optional)
        return iterator
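A hypothetical call into this helper from a test body, assuming `input_workers`, `devices`, and `strategy` are already set up by the test case and `tf` is imported; every name below is a placeholder, not part of the snippet above:

# Dataset path: the dataset_fn receives the single InputContext used for
# sharding under MultiWorkerMirroredStrategy.
iterator = self._wrap_iterator(
    input_type="dataset",
    dataset_fn=lambda ctx: tf.data.Dataset.range(16).batch(4),
    input_workers=input_workers,
    devices=devices,
    split_batch_by=None,
    enable_get_next_as_optional=True,
    strategy=strategy,
    input_context=distribute_lib.InputContext(
        num_input_pipelines=2, input_pipeline_id=0, num_replicas_in_sync=2))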