Code example #1
  def _create_ops(self, ds, ds_iterator, buffer_name, device0, device1):
    """Builds the prefetch, reset and destroy ops for a FunctionBufferingResource."""
    ds_iterator_handle = ds_iterator.string_handle()

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _remote_fn(h):
      remote_iterator = iterator_ops.Iterator.from_string_handle(
          h, ds.output_types, ds.output_shapes)
      return remote_iterator.get_next()

    target = constant_op.constant(device0)
    with ops.device(device1):
      buffer_resource_handle = prefetching_ops.function_buffering_resource(
          f=_remote_fn.get_concrete_function(),
          output_types=[dtypes.float32],
          target_device=target,
          string_arg=ds_iterator_handle,
          buffer_size=3,
          shared_name=buffer_name)

    with ops.device(device1):
      prefetch_op = prefetching_ops.function_buffering_resource_get_next(
          function_buffer_resource=buffer_resource_handle,
          output_types=[dtypes.float32])
      reset_op = prefetching_ops.function_buffering_resource_reset(
          function_buffer_resource=buffer_resource_handle)
      destroy_op = resource_variable_ops.destroy_resource_op(
          buffer_resource_handle, ignore_lookup_error=True)

    return (prefetch_op, reset_op, destroy_op)
Code example #2
def build_prefetch_input_processing(batch_size, data_point_shape, num_splits,
                                    preprocess_fn, cpu_device, params,
                                    gpu_devices, data_type, dataset):
    """"Returns FunctionBufferingResources that do image pre(processing)."""
    with tf.device(cpu_device):
        if params.eval:
            subset = 'validation'
        else:
            subset = 'train'

        function_buffering_resources = []
        remote_fn, args = minibatch_fn(
            batch_size=batch_size,
            data_point_shape=data_point_shape,
            num_splits=num_splits,
            preprocess_fn=preprocess_fn,
            dataset=dataset,
            subset=subset,
            train=(not params.eval),
            cache_data=params.cache_data,
            num_threads=params.datasets_num_private_threads)
        for device_num in range(len(gpu_devices)):
            with tf.device(gpu_devices[device_num]):
                buffer_resource_handle = prefetching_ops.function_buffering_resource(
                    f=remote_fn,
                    output_types=[data_type, tf.int32],
                    target_device=cpu_device,
                    string_arg=args[0],
                    buffer_size=params.datasets_prefetch_buffer_size,
                    shared_name=None)
                function_buffering_resources.append(buffer_resource_handle)
        return function_buffering_resources
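
The function above only builds the buffering resources; nothing in it pulls elements out of them. A minimal consumer sketch, assuming the resources are drained with the same prefetching_ops.function_buffering_resource_get_next op that code examples #1 and #4 use; the helper name get_prefetched_batches is hypothetical and not part of the original code:

def get_prefetched_batches(function_buffering_resources, gpu_devices, data_type):
    """Hypothetical sketch: pulls one prefetched element per GPU device."""
    batches = []
    for device_num, buffer_resource_handle in enumerate(function_buffering_resources):
        with tf.device(gpu_devices[device_num]):
            # output_types must match the types the resource was created with,
            # i.e. [data_type, tf.int32] in build_prefetch_input_processing above.
            outputs = prefetching_ops.function_buffering_resource_get_next(
                function_buffer_resource=buffer_resource_handle,
                output_types=[data_type, tf.int32])
            batches.append(outputs)
    return batches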
Code example #3
    def __init__(self,
                 input_dataset,
                 one_shot,
                 devices,
                 buffer_size,
                 shared_name=None):
        """Creates a FunctionBufferingResource on each device that
        prefetches elements from `input_dataset`."""
        self._input_dataset = input_dataset
        self._get_next_call_count = 0
        self._one_shot = one_shot
        if shared_name is None:
            shared_name = ""
        self._devices = devices

        if self._one_shot:
            self._input_iterator = input_dataset.make_one_shot_iterator()
        else:
            self._input_iterator = iterator_ops.Iterator.from_structure(
                self._input_dataset.output_types,
                self._input_dataset.output_shapes, shared_name,
                self._input_dataset.output_classes)
        input_iterator_handle = self._input_iterator.string_handle()

        @function.Defun(dtypes.string)
        def _prefetch_fn(handle):
            """Prefetches one element from `input_iterator`."""
            remote_iterator = iterator_ops.Iterator.from_string_handle(
                handle, self._input_iterator.output_types,
                self._input_iterator.output_shapes,
                self._input_iterator.output_classes)
            ret = remote_iterator.get_next()
            return nest.flatten(sparse.serialize_sparse_tensors(ret))

        target_device = ged_ops.experimental_iterator_get_device(
            self._input_iterator._iterator_resource)
        self._buffering_resources = []
        for device in nest.flatten(self._devices):
            with ops.device(device):
                buffer_resource_handle = prefetching_ops.function_buffering_resource(
                    f=_prefetch_fn,
                    output_types=data_nest.flatten(
                        sparse.as_dense_types(
                            self._input_dataset.output_types,
                            self._input_dataset.output_classes)),
                    target_device=target_device,
                    string_arg=input_iterator_handle,
                    buffer_size=buffer_size,
                    shared_name=shared_name)
                self._buffering_resources.append(buffer_resource_handle)

        if not self._one_shot:
            reset_ops = []
            for buffer_resource in self._buffering_resources:
                reset_ops.append(
                    ged_ops.experimental_function_buffering_resource_reset(
                        buffer_resource))
            with ops.control_dependencies(reset_ops):
                self._initializer = self._input_iterator.make_initializer(
                    self._input_dataset)
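
The excerpt ends with construction; the class it belongs to is not shown here. A rough sketch of what the consuming side could look like, assuming a hypothetical get_next method and reusing the prefetching_ops.function_buffering_resource_get_next op from the other examples on this page:

    def get_next(self):
        """Hypothetical sketch: returns one prefetched element per device."""
        self._get_next_call_count += 1
        flat_results = []
        for device, buffer_resource in zip(
                nest.flatten(self._devices), self._buffering_resources):
            with ops.device(device):
                # output_types must match the flattened dense types the
                # buffering resource was created with above.
                flat_results.append(
                    prefetching_ops.function_buffering_resource_get_next(
                        function_buffer_resource=buffer_resource,
                        output_types=data_nest.flatten(
                            sparse.as_dense_types(
                                self._input_dataset.output_types,
                                self._input_dataset.output_classes))))
        return flat_results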
Code example #4
    def testStringsGPU(self):
        if not test_util.is_gpu_available():
            self.skipTest("No GPU available")

        device0 = "/job:localhost/replica:0/task:0/cpu:0"
        device1 = "/job:localhost/replica:0/task:0/gpu:0"

        ds = dataset_ops.Dataset.from_tensor_slices(["a", "b", "c"])
        ds_iterator = ds.make_one_shot_iterator()
        ds_iterator_handle = ds_iterator.string_handle()

        @function.defun(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
        def _remote_fn(h):
            remote_iterator = iterator_ops.Iterator.from_string_handle(
                h, ds.output_types, ds.output_shapes)
            return remote_iterator.get_next()

        target = constant_op.constant(device0)
        with ops.device(device1):
            buffer_resource_handle = prefetching_ops.function_buffering_resource(
                f=_remote_fn.get_concrete_function(),
                output_types=[dtypes.string],
                target_device=target,
                string_arg=ds_iterator_handle,
                buffer_size=3,
                shared_name="strings")

        with ops.device(device1):
            prefetch_op = prefetching_ops.function_buffering_resource_get_next(
                function_buffer_resource=buffer_resource_handle,
                output_types=[dtypes.string])
            destroy_op = resource_variable_ops.destroy_resource_op(
                buffer_resource_handle, ignore_lookup_error=True)

        with self.cached_session() as sess:
            self.assertEqual([b"a"], self.evaluate(prefetch_op))
            self.assertEqual([b"b"], self.evaluate(prefetch_op))
            self.assertEqual([b"c"], self.evaluate(prefetch_op))
            with self.assertRaises(errors.OutOfRangeError):
                self.evaluate(prefetch_op)

            self.evaluate(destroy_op)
Code example #5
  def build_prefetch_input_processing(self, batch_size, model_input_shapes,
                                      num_splits, cpu_device, params,
                                      gpu_devices, model_input_data_types,
                                      dataset):
    """"Returns FunctionBufferingResources that do input pre(processing)."""
    assert self.supports_datasets()
    with tf.device(cpu_device):
      if params.eval:
        subset = 'validation'
      else:
        subset = 'train'

      function_buffering_resources = []
      remote_fn, args = self.minibatch_fn(
          batch_size=batch_size,
          model_input_shapes=model_input_shapes,
          num_splits=num_splits,
          dataset=dataset,
          subset=subset,
          train=(not params.eval),
          datasets_repeat_cached_sample=params.datasets_repeat_cached_sample,
          num_threads=params.datasets_num_private_threads,
          datasets_use_caching=params.datasets_use_caching,
          datasets_parallel_interleave_cycle_length=(
              params.datasets_parallel_interleave_cycle_length),
          datasets_sloppy_parallel_interleave=(
              params.datasets_sloppy_parallel_interleave))
      for device_num in range(len(gpu_devices)):
        with tf.device(gpu_devices[device_num]):
          buffer_resource_handle = prefetching_ops.function_buffering_resource(
              f=remote_fn,
              output_types=model_input_data_types,
              target_device=cpu_device,
              string_arg=args[0],
              buffer_size=params.datasets_prefetch_buffer_size,
              shared_name=None)
          function_buffering_resources.append(buffer_resource_handle)
      return function_buffering_resources