Example #1
def build_prefetch_image_processing(height, width, batch_size, num_splits,
                                    preprocess_fn, cpu_device, params,
                                    gpu_devices, dataset):
    """"Returns FunctionBufferingResources that do image pre(processing)."""
    with tf.device(cpu_device):
        if params.eval:
            subset = 'validation'
        else:
            subset = 'train'

        function_buffering_resources = []
        remote_fn, args = minibatch_fn(
            height=height,
            width=width,
            batch_size=batch_size,
            num_splits=num_splits,
            preprocess_fn=preprocess_fn,
            dataset=dataset,
            subset=subset,
            train=(not params.eval),
            cache_data=params.cache_data,
            num_threads=params.datasets_num_private_threads)
        for device_num in range(len(gpu_devices)):
            with tf.device(gpu_devices[device_num]):
                buffer_resource_handle = prefetching_ops.function_buffering_resource(
                    f=remote_fn,
                    target_device=cpu_device,
                    string_arg=args[0],
                    buffer_size=params.datasets_prefetch_buffer_size,
                    thread_pool_size=params.datasets_prefetch_threadpool_size,
                    shared_name=None)
                function_buffering_resources.append(buffer_resource_handle)
        return function_buffering_resources
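
The returned handles are drained on each GPU with `prefetching_ops.function_buffering_resource_get_next`, as the later examples show. A minimal consumption sketch, assuming the surrounding benchmark supplies `params`, `gpu_devices`, `preprocess_fn`, and `dataset`; the `output_types` below are placeholder assumptions that must match what `preprocess_fn` actually produces:

```python
resources = build_prefetch_image_processing(
    height=224, width=224, batch_size=32, num_splits=1,
    preprocess_fn=preprocess_fn, cpu_device='/cpu:0', params=params,
    gpu_devices=gpu_devices, dataset=dataset)

for device_num, handle in enumerate(resources):
    with tf.device(gpu_devices[device_num]):
        # Each call dequeues one prefetched minibatch on the owning GPU.
        images, labels = prefetching_ops.function_buffering_resource_get_next(
            function_buffer_resource=handle,
            output_types=[tf.float32, tf.int32])  # assumed element types
```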
Example #2
    def __init__(self, dataset):
        """Creates a new iterator over the given dataset.

    For example:
    ```python
    dataset = tf.data.Dataset.range(4)
    for x in Iterator(dataset):
      print(x)
    ```

    Tensors produced will be placed on the device on which this iterator object
    was created.

    Args:
      dataset: A `tf.data.Dataset` object.

    Raises:
      TypeError: If `dataset` is an unsupported type.
      RuntimeError: When invoked without eager execution enabled.
    """
        if isinstance(dataset, prefetching_ops._PrefetchToDeviceDataset):  # pylint: disable=protected-access
            raise TypeError(
                "`tf.contrib.data.prefetch_to_device()` is not compatible with "
                "`tf.contrib.eager.Iterator`. Use `for ... in dataset:` to iterate "
                "over the dataset instead.")

        super(Iterator, self).__init__(dataset)
        if not context.context().device_spec.device_type:
            is_remote_device = False
        else:
            is_remote_device = (
                context.context().device_spec.device_type != "CPU")
        self._buffer_resource_handle = None
        if is_remote_device:
            with ops.device("/device:CPU:0"):
                iter_string_handle = gen_dataset_ops.iterator_to_string_handle(
                    self._resource)

                @function.Defun(dtypes.string)
                def remote_fn(h):
                    remote_iterator = iterator_ops.Iterator.from_string_handle(
                        h, self.output_types, self.output_shapes,
                        self.output_classes)
                    return remote_iterator.get_next()

                remote_fn.add_to_graph(None)
                target = constant_op.constant("/device:CPU:0")
            with ops.device(self._device):
                self._buffer_resource_handle = prefetching_ops.function_buffering_resource(  # pylint: disable=line-too-long
                    string_arg=iter_string_handle,
                    output_types=self._flat_output_types,
                    f=remote_fn,
                    target_device=target,
                    buffer_size=10,
                    container="",
                    shared_name=_generate_shared_name(
                        "contrib_eager_iterator_function_buffer_resource"))
                self._buffer_resource_deleter = resource_variable_ops.EagerResourceDeleter(  # pylint: disable=line-too-long
                    handle=self._buffer_resource_handle,
                    handle_device=self._device)
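
The consuming side of this handle is not shown in the snippet. A sketch of what it plausibly looks like, with the method name `_next_internal` and the re-packing step treated as assumptions:

```python
    def _next_internal(self):
        """Sketch: fetch one element, via the device-side buffer if present."""
        if self._buffer_resource_handle is None:
            return super(Iterator, self)._next_internal()
        with ops.device(self._device):
            # Dequeue one prefetched, flattened element; the real code then
            # re-packs it according to output_types/output_shapes/classes.
            return prefetching_ops.function_buffering_resource_get_next(
                function_buffer_resource=self._buffer_resource_handle,
                output_types=self._flat_output_types)
```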
Example #3
  def _create_ops(self, ds, ds_iterator, buffer_name, device0, device1):
    ds_iterator_handle = ds_iterator.string_handle()

    @function.Defun(dtypes.string)
    def _remote_fn(h):
      remote_iterator = iterator_ops.Iterator.from_string_handle(
          h, ds.output_types, ds.output_shapes)
      return remote_iterator.get_next()

    target = constant_op.constant(device0)
    with ops.device(device1):
      buffer_resource_handle = prefetching_ops.function_buffering_resource(
          f=_remote_fn,
          output_types=[dtypes.float32],
          target_device=target,
          string_arg=ds_iterator_handle,
          buffer_size=3,
          shared_name=buffer_name)

    with ops.device(device1):
      prefetch_op = prefetching_ops.function_buffering_resource_get_next(
          function_buffer_resource=buffer_resource_handle,
          output_types=[dtypes.float32])
      reset_op = prefetching_ops.function_buffering_resource_reset(
          function_buffer_resource=buffer_resource_handle)
      destroy_op = resource_variable_ops.destroy_resource_op(
          buffer_resource_handle, ignore_lookup_error=True)

    return (prefetch_op, reset_op, destroy_op)
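
A test method would then feed a small dataset through these ops; a hedged sketch of such a caller (the test name, dataset contents, and buffer name are illustrative):

```python
  def testSimplePrefetch(self):  # hypothetical test name
    worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
    ds = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0, 3.0])
    ds_iterator = ds.make_one_shot_iterator()
    prefetch_op, reset_op, destroy_op = self._create_ops(
        ds, ds_iterator, "test_buffer", "/device:CPU:0", "/device:CPU:1")

    with self.test_session(config=worker_config) as sess:
      self.assertEqual([1.0], sess.run(prefetch_op))
      self.assertEqual([2.0], sess.run(prefetch_op))
      sess.run(reset_op)    # discard anything still buffered
      sess.run(destroy_op)  # release the resource handle
```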
Example #4

  def __init__(self, input_dataset, devices, buffer_size):
    self._input_dataset = input_dataset
    self._get_next_call_count = 0
    self._devices = devices
    input_iterator = input_dataset.make_one_shot_iterator()
    input_iterator_handle = input_iterator.string_handle()

    @function.Defun(dtypes.string)
    def _prefetch_fn(handle):
      """Prefetches one element from `input_iterator`."""
      remote_iterator = iterator_ops.Iterator.from_string_handle(
          handle, input_iterator.output_types, input_iterator.output_shapes,
          input_iterator.output_classes)
      ret = remote_iterator.get_next()
      return nest.flatten(sparse.serialize_sparse_tensors(ret))

    target_device = gen_dataset_ops.iterator_get_device(
        input_iterator._iterator_resource)
    self._buffering_resources = []
    for device in nest.flatten(self._devices):
      with ops.device(device):
        buffer_resource_handle = prefetching_ops.function_buffering_resource(
            f=_prefetch_fn,
            target_device=target_device,
            string_arg=input_iterator_handle,
            buffer_size=buffer_size)
        self._buffering_resources.append(buffer_resource_handle)
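
The snippet above only builds the buffers; a matching `get_next` would dequeue from each device's resource. A sketch, with the dense-type flattening mirrored from the fuller variant in the next example and treated as an assumption:

```python
  def get_next(self, name=None):
    """Sketch: dequeue one prefetched element from each device's buffer."""
    self._get_next_call_count += 1
    results = []
    for i, device in enumerate(nest.flatten(self._devices)):
      with ops.device(device):
        # output_types must match what _prefetch_fn returns (flattened,
        # dense-serialized); this mirrors Example #5's resource setup.
        results.append(prefetching_ops.function_buffering_resource_get_next(
            function_buffer_resource=self._buffering_resources[i],
            output_types=data_nest.flatten(
                sparse.as_dense_types(self._input_dataset.output_types,
                                      self._input_dataset.output_classes))))
    return results
```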
Example #5

    def __init__(self,
                 input_dataset,
                 one_shot,
                 devices,
                 buffer_size,
                 shared_name=None):
        self._input_dataset = input_dataset
        self._get_next_call_count = 0
        self._one_shot = one_shot
        if shared_name is None:
            shared_name = ""
        self._devices = devices

        if self._one_shot:
            self._input_iterator = input_dataset.make_one_shot_iterator()
        else:
            self._input_iterator = iterator_ops.Iterator.from_structure(
                self._input_dataset.output_types,
                self._input_dataset.output_shapes, shared_name,
                self._input_dataset.output_classes)
        input_iterator_handle = self._input_iterator.string_handle()

        @function.Defun(dtypes.string)
        def _prefetch_fn(handle):
            """Prefetches one element from `input_iterator`."""
            remote_iterator = iterator_ops.Iterator.from_string_handle(
                handle, self._input_iterator.output_types,
                self._input_iterator.output_shapes,
                self._input_iterator.output_classes)
            ret = remote_iterator.get_next()
            return nest.flatten(sparse.serialize_sparse_tensors(ret))

        target_device = gen_dataset_ops.iterator_get_device(
            self._input_iterator._iterator_resource)
        self._buffering_resources = []
        for device in nest.flatten(self._devices):
            with ops.device(device):
                buffer_resource_handle = prefetching_ops.function_buffering_resource(
                    f=_prefetch_fn,
                    output_types=data_nest.flatten(
                        sparse.as_dense_types(
                            self._input_dataset.output_types,
                            self._input_dataset.output_classes)),
                    target_device=target_device,
                    string_arg=input_iterator_handle,
                    buffer_size=buffer_size,
                    shared_name=shared_name)
                self._buffering_resources.append(buffer_resource_handle)

        if not self._one_shot:
            reset_ops = []
            for buffer_resource in self._buffering_resources:
                reset_ops.append(
                    prefetching_ops.function_buffering_resource_reset(
                        buffer_resource))
            with ops.control_dependencies(reset_ops):
                self._initializer = self._input_iterator.make_initializer(
                    self._input_dataset)
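
In the non-one-shot case, the `_initializer` built above must run before the first read: it resets every buffering resource and then re-binds the input iterator. A hedged usage sketch; the enclosing class name `_PrefetchToDeviceIterator` is inferred from the `_PrefetchToDeviceDataset` reference in Example #2:

```python
# Hypothetical graph-mode usage of the reinitializable variant.
iterator = _PrefetchToDeviceIterator(
    dataset, one_shot=False, devices=["/device:GPU:0"], buffer_size=8)
with tf.Session() as sess:
    sess.run(iterator._initializer)  # reset buffers, then re-bind the input
```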
Example #6
  def __init__(self, dataset):
    """Creates a new iterator over the given dataset.

    For example:
    ```python
    dataset = tf.data.Dataset.range(4)
    for x in Iterator(dataset):
      print(x)
    ```

    Tensors produced will be placed on the device on which this iterator object
    was created.

    Args:
      dataset: A `tf.data.Dataset` object.

    Raises:
      TypeError: If `dataset` is an unsupported type.
      RuntimeError: When invoked without eager execution enabled.
    """
    if isinstance(dataset, prefetching_ops._PrefetchToDeviceDataset):  # pylint: disable=protected-access
      raise TypeError(
          "`tf.contrib.data.prefetch_to_device()` is not compatible with "
          "`tf.contrib.eager.Iterator`. Use `for ... in dataset:` to iterate "
          "over the dataset instead.")

    super(Iterator, self).__init__(dataset)
    if not context.context().device_spec.device_type:
      is_remote_device = False
    else:
      is_remote_device = context.context().device_spec.device_type != "CPU"
    self._buffer_resource_handle = None
    if is_remote_device:
      with ops.device("/device:CPU:0"):
        iter_string_handle = gen_dataset_ops.iterator_to_string_handle(
            self._resource)

        @function.Defun(dtypes.string)
        def remote_fn(h):
          remote_iterator = iterator_ops.Iterator.from_string_handle(
              h, self.output_types, self.output_shapes, self.output_classes)
          return remote_iterator.get_next()

        remote_fn.add_to_graph(None)
        target = constant_op.constant("/device:CPU:0")
      with ops.device(self._device):
        self._buffer_resource_handle = prefetching_ops.function_buffering_resource(  # pylint: disable=line-too-long
            string_arg=iter_string_handle,
            output_types=self._flat_output_types,
            f=remote_fn,
            target_device=target,
            buffer_size=10,
            container="",
            shared_name=_generate_shared_name(
                "contrib_eager_iterator_function_buffer_resource"))
        self._buffer_resource_deleter = resource_variable_ops.EagerResourceDeleter(  # pylint: disable=line-too-long
            handle=self._buffer_resource_handle,
            handle_device=self._device)
Example #7
  def __init__(self,
               input_dataset,
               one_shot,
               devices,
               buffer_size,
               shared_name=None):
    self._input_dataset = input_dataset
    self._get_next_call_count = 0
    self._one_shot = one_shot
    if shared_name is None:
      shared_name = ""
    self._devices = devices

    if self._one_shot:
      self._input_iterator = input_dataset.make_one_shot_iterator()
    else:
      self._input_iterator = iterator_ops.Iterator.from_structure(
          self._input_dataset.output_types, self._input_dataset.output_shapes,
          shared_name, self._input_dataset.output_classes)
    input_iterator_handle = self._input_iterator.string_handle()

    @function.Defun(dtypes.string)
    def _prefetch_fn(handle):
      """Prefetches one element from `input_iterator`."""
      remote_iterator = iterator_ops.Iterator.from_string_handle(
          handle, self._input_iterator.output_types,
          self._input_iterator.output_shapes,
          self._input_iterator.output_classes)
      ret = remote_iterator.get_next()
      return nest.flatten(sparse.serialize_sparse_tensors(ret))

    target_device = ged_ops.experimental_iterator_get_device(
        self._input_iterator._iterator_resource)
    self._buffering_resources = []
    for device in nest.flatten(self._devices):
      with ops.device(device):
        buffer_resource_handle = prefetching_ops.function_buffering_resource(
            f=_prefetch_fn,
            output_types=data_nest.flatten(
                sparse.as_dense_types(self._input_dataset.output_types,
                                      self._input_dataset.output_classes)),
            target_device=target_device,
            string_arg=input_iterator_handle,
            buffer_size=buffer_size,
            shared_name=shared_name)
        self._buffering_resources.append(buffer_resource_handle)

    if not self._one_shot:
      reset_ops = []
      for buffer_resource in self._buffering_resources:
        reset_ops.append(
            ged_ops.experimental_function_buffering_resource_reset(
                buffer_resource))
      with ops.control_dependencies(reset_ops):
        self._initializer = self._input_iterator.make_initializer(
            self._input_dataset)
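
Example #7 is the same iterator as Example #5, rebuilt against the renamed `experimental_*` generated ops of a later TensorFlow release. Either way, the class backs the public `tf.contrib.data.prefetch_to_device` transformation; a minimal usage sketch (TF 1.x):

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(100).batch(8)
# prefetch_to_device wraps the dataset so that elements are staged into a
# function buffering resource on the target device ahead of consumption.
dataset = dataset.apply(tf.contrib.data.prefetch_to_device("/device:GPU:0"))
```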
Example #8
    def _prefetch_fn_helper(self, buffer_name, device0, device1):
        worker_config = config_pb2.ConfigProto()
        worker_config.device_count["CPU"] = 2

        def gen():
            for i in itertools.count(start=1, step=1):
                yield [i + 0.0]
                if i == 6:
                    self._event.set()

        with ops.device(device0):
            dataset_3 = dataset_ops.Dataset.from_generator(
                gen, (dtypes.float32))
            iterator_3 = dataset_3.make_one_shot_iterator()
            iterator_3_handle = iterator_3.string_handle()

        @function.Defun(dtypes.string)
        def _remote_fn(h):
            remote_iterator = iterator_ops.Iterator.from_string_handle(
                h, dataset_3.output_types, dataset_3.output_shapes)
            return remote_iterator.get_next()

        target = constant_op.constant(device0)
        with ops.device(device1):
            buffer_resource_handle = prefetching_ops.function_buffering_resource(
                f=_remote_fn,
                target_device=target,
                string_arg=iterator_3_handle,
                buffer_size=3,
                thread_pool_size=2,
                shared_name=buffer_name)

        with ops.device(device1):
            prefetch_op = prefetching_ops.function_buffering_resource_get_next(
                function_buffer_resource=buffer_resource_handle,
                output_types=[dtypes.float32])

        with self.test_session(config=worker_config) as sess:
            elem = sess.run(prefetch_op)
            self.assertEqual(elem, [1.0])
            elem = sess.run(prefetch_op)
            self.assertEqual(elem, [2.0])
            elem = sess.run(prefetch_op)
            self.assertEqual(elem, [3.0])
            elem = sess.run(prefetch_op)
            self.assertEqual(elem, [4.0])
            self._event.wait()
            elem = sess.run(prefetch_op)
            self.assertEqual(elem, [5.0])
            sess.run(
                resource_variable_ops.destroy_resource_op(
                    buffer_resource_handle, ignore_lookup_error=True))
Example #9

  def _prefetch_fn_helper(self, buffer_name, device0, device1):
    worker_config = config_pb2.ConfigProto()
    worker_config.device_count["CPU"] = 2

    def gen():
      for i in itertools.count(start=1, step=1):
        yield [i + 0.0]
        if i == 6:
          self._event.set()

    with ops.device(device0):
      dataset_3 = dataset_ops.Dataset.from_generator(gen, (dtypes.float32))
      iterator_3 = dataset_3.make_one_shot_iterator()
      iterator_3_handle = iterator_3.string_handle()

    @function.Defun(dtypes.string)
    def _remote_fn(h):
      remote_iterator = iterator_ops.Iterator.from_string_handle(
          h, dataset_3.output_types, dataset_3.output_shapes)
      return remote_iterator.get_next()

    target = constant_op.constant(device0)
    with ops.device(device1):
      buffer_resource_handle = prefetching_ops.function_buffering_resource(
          f=_remote_fn,
          target_device=target,
          string_arg=iterator_3_handle,
          buffer_size=3,
          thread_pool_size=2,
          shared_name=buffer_name)

    with ops.device(device1):
      prefetch_op = prefetching_ops.function_buffering_resource_get_next(
          function_buffer_resource=buffer_resource_handle,
          output_types=[dtypes.float32])

    with self.test_session(config=worker_config) as sess:
      elem = sess.run(prefetch_op)
      self.assertEqual(elem, [1.0])
      elem = sess.run(prefetch_op)
      self.assertEqual(elem, [2.0])
      elem = sess.run(prefetch_op)
      self.assertEqual(elem, [3.0])
      elem = sess.run(prefetch_op)
      self.assertEqual(elem, [4.0])
      self._event.wait()
      elem = sess.run(prefetch_op)
      self.assertEqual(elem, [5.0])
      sess.run(
          resource_variable_ops.destroy_resource_op(
              buffer_resource_handle, ignore_lookup_error=True))
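
Callers of this helper pin two logical CPUs (made available by the `device_count` override inside it) as source and destination; a hedged sketch of such a test method, with the name and buffer string illustrative:

```python
  def testPrefetchBetweenCpus(self):  # hypothetical test name
    self._prefetch_fn_helper("cpu_to_cpu_buffer",
                             "/job:localhost/replica:0/task:0/cpu:0",
                             "/job:localhost/replica:0/task:0/cpu:1")
```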
Example #10
  def testStringsGPU(self):
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    device0 = "/job:localhost/replica:0/task:0/cpu:0"
    device1 = "/job:localhost/replica:0/task:0/gpu:0"

    ds = dataset_ops.Dataset.from_tensor_slices(["a", "b", "c"])
    ds_iterator = ds.make_one_shot_iterator()
    ds_iterator_handle = ds_iterator.string_handle()

    @function.Defun(dtypes.string)
    def _remote_fn(h):
      remote_iterator = iterator_ops.Iterator.from_string_handle(
          h, ds.output_types, ds.output_shapes)
      return remote_iterator.get_next()

    target = constant_op.constant(device0)
    with ops.device(device1):
      buffer_resource_handle = prefetching_ops.function_buffering_resource(
          f=_remote_fn,
          output_types=[dtypes.string],
          target_device=target,
          string_arg=ds_iterator_handle,
          buffer_size=3,
          shared_name="strings")

    with ops.device(device1):
      prefetch_op = prefetching_ops.function_buffering_resource_get_next(
          function_buffer_resource=buffer_resource_handle,
          output_types=[dtypes.string])
      destroy_op = resource_variable_ops.destroy_resource_op(
          buffer_resource_handle, ignore_lookup_error=True)

    with self.cached_session() as sess:
      self.assertEqual([b"a"], sess.run(prefetch_op))
      self.assertEqual([b"b"], sess.run(prefetch_op))
      self.assertEqual([b"c"], sess.run(prefetch_op))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(prefetch_op)

      sess.run(destroy_op)
Example #11
    def __init__(self, dataset):
        """Creates a new iterator over the given dataset.

    For example:
    ```python
    dataset = tf.data.Dataset.range(4)
    for x in Iterator(dataset):
      print(x)
    ```

    Tensors produced will be placed on the device on which this iterator object
    was created.

    Args:
      dataset: A `tf.data.Dataset` object.

    Raises:
      RuntimeError: When invoked without eager execution enabled.
    """

        if not context.in_eager_mode():
            raise RuntimeError(
                "{} objects can only be used when eager execution is enabled, use "
                "tf.data.Dataset.make_iterator or "
                "tf.data.Dataset.make_one_shot_iterator for graph construction"
                .format(type(self)))
        with ops.device("/device:CPU:0"):
            ds_variant = dataset._as_variant_tensor()  # pylint: disable=protected-access
            self._output_types = dataset.output_types
            self._output_shapes = dataset.output_shapes
            self._flat_output_types = nest.flatten(dataset.output_types)
            self._flat_output_shapes = nest.flatten(dataset.output_shapes)
            self._resource = gen_dataset_ops.iterator(
                container="",
                shared_name=_generate_shared_name("eager_iterator"),
                output_types=self._flat_output_types,
                output_shapes=self._flat_output_shapes)
            gen_dataset_ops.make_iterator(ds_variant, self._resource)
            # Delete the resource when this object is deleted
            self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
                handle=self._resource, handle_device="/device:CPU:0")
        self._device = context.context().device_name
        self._buffer_resource_handle = None
        if not context.context().device_spec.device_type:
            is_remote_device = False
        else:
            is_remote_device = (
                context.context().device_spec.device_type != "CPU")
        if is_remote_device:
            with ops.device("/device:CPU:0"):
                iter_string_handle = gen_dataset_ops.iterator_to_string_handle(
                    self._resource)

                @function.Defun(dtypes.string)
                def remote_fn(h):
                    remote_iterator = iterator_ops.Iterator.from_string_handle(
                        h, self._output_types, self._output_shapes)
                    return remote_iterator.get_next()

                remote_fn.add_to_graph(None)
                target = constant_op.constant("/device:CPU:0")
            with ops.device(self._device):
                self._buffer_resource_handle = prefetching_ops.function_buffering_resource(
                    string_arg=iter_string_handle,
                    f=remote_fn,
                    target_device=target,
                    buffer_size=10,
                    thread_pool_size=1,
                    container="",
                    shared_name=_generate_shared_name(
                        "function_buffer_resource"))
                self._buffer_resource_deleter = resource_variable_ops.EagerResourceDeleter(
                    handle=self._buffer_resource_handle,
                    handle_device=self._device)
Example #12
  def __init__(self, dataset):
    """Creates a new iterator over the given dataset.

    For example:
    ```python
    dataset = tf.data.Dataset.range(4)
    for x in Iterator(dataset):
      print(x)
    ```

    Tensors produced will be placed on the device on which this iterator object
    was created.

    Args:
      dataset: A `tf.data.Dataset` object.

    Raises:
      RuntimeError: When invoked without eager execution enabled.
    """

    if not context.in_eager_mode():
      raise RuntimeError(
          "{} objects can only be used when eager execution is enabled, use "
          "tf.data.Dataset.make_iterator or "
          "tf.data.Dataset.make_one_shot_iterator for graph construction".
          format(type(self)))
    with ops.device("/device:CPU:0"):
      ds_variant = dataset._as_variant_tensor()  # pylint: disable=protected-access
      self._output_types = dataset.output_types
      self._output_shapes = dataset.output_shapes
      self._flat_output_types = nest.flatten(dataset.output_types)
      self._flat_output_shapes = nest.flatten(dataset.output_shapes)
      self._resource = gen_dataset_ops.iterator(
          container="",
          shared_name=_generate_shared_name("eager_iterator"),
          output_types=self._flat_output_types,
          output_shapes=self._flat_output_shapes)
      gen_dataset_ops.make_iterator(ds_variant, self._resource)
      # Delete the resource when this object is deleted
      self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
          handle=self._resource, handle_device="/device:CPU:0")
    self._device = context.context().device_name
    self._buffer_resource_handle = None
    if not context.context().device_spec.device_type:
      is_remote_device = False
    else:
      is_remote_device = context.context().device_spec.device_type != "CPU"
    if is_remote_device:
      with ops.device("/device:CPU:0"):
        iter_string_handle = gen_dataset_ops.iterator_to_string_handle(
            self._resource)

        @function.Defun(dtypes.string)
        def remote_fn(h):
          remote_iterator = iterator_ops.Iterator.from_string_handle(
              h, self._output_types, self._output_shapes)
          return remote_iterator.get_next()

        remote_fn.add_to_graph(None)
        target = constant_op.constant("/device:CPU:0")
      with ops.device(self._device):
        self._buffer_resource_handle = prefetching_ops.function_buffering_resource(
            string_arg=iter_string_handle,
            f=remote_fn,
            target_device=target,
            buffer_size=10,
            thread_pool_size=1,
            container="",
            shared_name=_generate_shared_name("function_buffer_resource"))
        self._buffer_resource_deleter = resource_variable_ops.EagerResourceDeleter(
            handle=self._buffer_resource_handle, handle_device=self._device)
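
With the remote-device branch above in place, the iterator can be built under a GPU scope and consumed with a plain `for` loop, matching its own docstring; a minimal eager-mode sketch for TF 1.x (assuming a GPU is available):

```python
import tensorflow as tf
import tensorflow.contrib.eager as tfe

tf.enable_eager_execution()

dataset = tf.data.Dataset.range(4)
with tf.device("/device:GPU:0"):
    # A non-CPU device triggers the buffering path: elements are produced
    # on the CPU and staged into a function buffering resource on the GPU.
    for x in tfe.Iterator(dataset):
        print(x)
```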