def _create_ops(self, ds, ds_iterator, buffer_name, device0, device1,
                output_types=None, buffer_size=3):
    """Builds prefetch, reset and destroy ops for a FunctionBufferingResource.

    A buffering resource is created on `device1`. It calls a defun that
    rebuilds `ds_iterator` from its string handle and returns the next
    element, targeting `device0`.

    Args:
      ds: dataset iterated by `ds_iterator`; its `output_types` and
        `output_shapes` are used to rebuild the iterator from its handle.
      ds_iterator: iterator whose string handle is passed to the resource as
        `string_arg`.
      buffer_name: `shared_name` given to the buffering resource.
      device0: device string passed as the resource's `target_device`.
      device1: device on which the resource and all returned ops are placed.
      output_types: list of dtypes produced by the buffer. Defaults to
        `[dtypes.float32]`. The same list is used when creating the resource
        and when reading from it, so the two always agree.
      buffer_size: value passed as the resource's `buffer_size`; defaults
        to 3.

    Returns:
      A `(prefetch_op, reset_op, destroy_op)` tuple.
    """
    if output_types is None:
      output_types = [dtypes.float32]
    ds_iterator_handle = ds_iterator.string_handle()

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _remote_fn(h):
      remote_iterator = iterator_ops.Iterator.from_string_handle(
          h, ds.output_types, ds.output_shapes)
      return remote_iterator.get_next()

    target = constant_op.constant(device0)
    # One device scope covers the resource and every op that reads or
    # manages it.
    with ops.device(device1):
      buffer_resource_handle = prefetching_ops.function_buffering_resource(
          f=_remote_fn.get_concrete_function(),
          output_types=output_types,
          target_device=target,
          string_arg=ds_iterator_handle,
          buffer_size=buffer_size,
          shared_name=buffer_name)
      prefetch_op = prefetching_ops.function_buffering_resource_get_next(
          function_buffer_resource=buffer_resource_handle,
          output_types=output_types)
      reset_op = prefetching_ops.function_buffering_resource_reset(
          function_buffer_resource=buffer_resource_handle)
      destroy_op = resource_variable_ops.destroy_resource_op(
          buffer_resource_handle, ignore_lookup_error=True)

    return (prefetch_op, reset_op, destroy_op)
Exemplo n.º 2
0
    def _create_ops(self, ds, ds_iterator, buffer_name, device0, device1,
                    output_types=None, buffer_size=3):
        """Builds prefetch, reset and destroy ops for a FunctionBufferingResource.

        A buffering resource is created on `device1`. It calls a defun that
        rebuilds `ds_iterator` from its string handle and returns the next
        element, targeting `device0`.

        Args:
            ds: dataset iterated by `ds_iterator`; its `output_types` and
                `output_shapes` are used to rebuild the iterator from its
                handle.
            ds_iterator: iterator whose string handle is passed to the
                resource as `string_arg`.
            buffer_name: `shared_name` given to the buffering resource.
            device0: device string passed as the resource's `target_device`.
            device1: device on which the resource and all returned ops are
                placed.
            output_types: list of dtypes produced by the buffer. Defaults to
                `[dtypes.float32]`. The same list is used when creating the
                resource and when reading from it, so the two always agree.
            buffer_size: value passed as the resource's `buffer_size`;
                defaults to 3.

        Returns:
            A `(prefetch_op, reset_op, destroy_op)` tuple.
        """
        if output_types is None:
            output_types = [dtypes.float32]
        ds_iterator_handle = ds_iterator.string_handle()

        @function.defun(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
        def _remote_fn(h):
            remote_iterator = iterator_ops.Iterator.from_string_handle(
                h, ds.output_types, ds.output_shapes)
            return remote_iterator.get_next()

        target = constant_op.constant(device0)
        # One device scope covers the resource and every op that reads or
        # manages it.
        with ops.device(device1):
            buffer_resource_handle = prefetching_ops.function_buffering_resource(
                f=_remote_fn.get_concrete_function(),
                output_types=output_types,
                target_device=target,
                string_arg=ds_iterator_handle,
                buffer_size=buffer_size,
                shared_name=buffer_name)
            prefetch_op = prefetching_ops.function_buffering_resource_get_next(
                function_buffer_resource=buffer_resource_handle,
                output_types=output_types)
            reset_op = prefetching_ops.function_buffering_resource_reset(
                function_buffer_resource=buffer_resource_handle)
            destroy_op = resource_variable_ops.destroy_resource_op(
                buffer_resource_handle, ignore_lookup_error=True)

        return (prefetch_op, reset_op, destroy_op)
Exemplo n.º 3
0
    def testStringsGPU(self):
        """Buffers string elements from a CPU-side iterator on the GPU.

        The test is skipped when no GPU is available. It reads the three
        dataset elements in order and checks that a fourth read raises
        OutOfRangeError. It always destroys the buffering resource
        afterwards, even if an assertion fails.
        """
        if not test_util.is_gpu_available():
            self.skipTest("No GPU available")

        device0 = "/job:localhost/replica:0/task:0/cpu:0"
        device1 = "/job:localhost/replica:0/task:0/gpu:0"

        ds = dataset_ops.Dataset.from_tensor_slices(["a", "b", "c"])
        ds_iterator = ds.make_one_shot_iterator()
        ds_iterator_handle = ds_iterator.string_handle()

        @function.defun(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
        def _remote_fn(h):
            remote_iterator = iterator_ops.Iterator.from_string_handle(
                h, ds.output_types, ds.output_shapes)
            return remote_iterator.get_next()

        target = constant_op.constant(device0)
        # One device scope covers the resource and every op that uses it.
        with ops.device(device1):
            buffer_resource_handle = prefetching_ops.function_buffering_resource(
                f=_remote_fn.get_concrete_function(),
                output_types=[dtypes.string],
                target_device=target,
                string_arg=ds_iterator_handle,
                buffer_size=3,
                shared_name="strings")
            prefetch_op = prefetching_ops.function_buffering_resource_get_next(
                function_buffer_resource=buffer_resource_handle,
                output_types=[dtypes.string])
            destroy_op = resource_variable_ops.destroy_resource_op(
                buffer_resource_handle, ignore_lookup_error=True)

        # Every op below runs through self.evaluate, so the session object
        # itself is never referenced.
        with self.cached_session():
            try:
                self.assertEqual([b"a"], self.evaluate(prefetch_op))
                self.assertEqual([b"b"], self.evaluate(prefetch_op))
                self.assertEqual([b"c"], self.evaluate(prefetch_op))
                with self.assertRaises(errors.OutOfRangeError):
                    self.evaluate(prefetch_op)
            finally:
                # Destroy the shared buffer even when an assertion above fails.
                self.evaluate(destroy_op)
  def testStringsGPU(self):
    """Buffers string elements from a CPU-side iterator on the GPU.

    The test is skipped when no GPU is available. It reads the three dataset
    elements in order and checks that a fourth read raises OutOfRangeError.
    It always destroys the buffering resource afterwards, even if an
    assertion fails.
    """
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    device0 = "/job:localhost/replica:0/task:0/cpu:0"
    device1 = "/job:localhost/replica:0/task:0/gpu:0"

    ds = dataset_ops.Dataset.from_tensor_slices(["a", "b", "c"])
    ds_iterator = ds.make_one_shot_iterator()
    ds_iterator_handle = ds_iterator.string_handle()

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _remote_fn(h):
      remote_iterator = iterator_ops.Iterator.from_string_handle(
          h, ds.output_types, ds.output_shapes)
      return remote_iterator.get_next()

    target = constant_op.constant(device0)
    # One device scope covers the resource and every op that uses it.
    with ops.device(device1):
      buffer_resource_handle = prefetching_ops.function_buffering_resource(
          f=_remote_fn.get_concrete_function(),
          output_types=[dtypes.string],
          target_device=target,
          string_arg=ds_iterator_handle,
          buffer_size=3,
          shared_name="strings")
      prefetch_op = prefetching_ops.function_buffering_resource_get_next(
          function_buffer_resource=buffer_resource_handle,
          output_types=[dtypes.string])
      destroy_op = resource_variable_ops.destroy_resource_op(
          buffer_resource_handle, ignore_lookup_error=True)

    # Every op below runs through self.evaluate, so the session object itself
    # is never referenced.
    with self.cached_session():
      try:
        self.assertEqual([b"a"], self.evaluate(prefetch_op))
        self.assertEqual([b"b"], self.evaluate(prefetch_op))
        self.assertEqual([b"c"], self.evaluate(prefetch_op))
        with self.assertRaises(errors.OutOfRangeError):
          self.evaluate(prefetch_op)
      finally:
        # Destroy the shared buffer even when an assertion above fails.
        self.evaluate(destroy_op)
Exemplo n.º 5
0
def get_inputs_and_labels(function_buffering_resource, data_type):
    """Read the next (inputs, labels) element from a FunctionBufferingResource.

    Args:
        function_buffering_resource: handle of the buffering resource to read.
        data_type: dtype requested for the first output component (the
            inputs/images); the second component is always requested as
            `tf.int32`.

    Returns:
        Whatever `prefetching_ops.function_buffering_resource_get_next`
        returns for output types `[data_type, tf.int32]` (presumably the
        image and label tensors; confirm against callers).
    """
    requested_types = [data_type, tf.int32]
    return prefetching_ops.function_buffering_resource_get_next(
        function_buffer_resource=function_buffering_resource,
        output_types=requested_types,
    )