Пример #1
0
    def __init__(self, per_device_dataset, incarnation_id):
        """Rebuilds `per_device_dataset` with a different incarnation id.

        The init/next/finalize functions of `per_device_dataset` are reused
        unchanged; only the incarnation id captured by the next function is
        swapped out before a new generator dataset op is created.

        Args:
          per_device_dataset: Dataset whose `_init_func`, `_next_func`,
            `_finalize_func`, captured arguments, `element_spec` and
            `_incarnation_id_index` are read.
          incarnation_id: Value written at `_incarnation_id_index` in this
            object's copy of the next function's captured arguments.
        """
        # pylint: disable=protected-access
        self._element_spec = per_device_dataset.element_spec
        self._init_func = per_device_dataset._init_func
        self._init_captured_args = self._init_func.captured_inputs

        self._next_func = per_device_dataset._next_func
        # Take a copy: assigning into the original list would also rewrite the
        # captured arguments stored on `per_device_dataset`.
        self._next_captured_args = list(
            per_device_dataset._next_captured_args)
        # The captured arguments to the next_func are string_handle, incarnation_id.
        # We update the incarnation id to the new one.
        self._next_captured_args[
            per_device_dataset._incarnation_id_index] = incarnation_id

        self._finalize_func = per_device_dataset._finalize_func
        self._finalize_captured_args = per_device_dataset._finalize_captured_args

        variant_tensor = gen_dataset_ops.generator_dataset(
            self._init_captured_args,
            self._next_captured_args,
            self._finalize_captured_args,
            init_func=self._init_func,
            next_func=self._next_func,
            finalize_func=self._finalize_func,
            **self._flat_structure)
        super(_ReincarnatedPerDeviceGenerator, self).__init__(variant_tensor)
  def __init__(self, per_device_dataset, incarnation_id):
    """Rebuilds `per_device_dataset` with a different incarnation id.

    The init/next/finalize functions of `per_device_dataset` are reused
    unchanged; only the incarnation id captured by the next function is swapped
    out before a new generator dataset op is created.

    Args:
      per_device_dataset: Dataset whose `_structure`, `_init_func`,
        `_next_func`, `_finalize_func`, captured arguments and
        `_incarnation_id_index` are read.
      incarnation_id: Value written at `_incarnation_id_index` in this object's
        copy of the next function's captured arguments.
    """
    # pylint: disable=protected-access
    self._structure = per_device_dataset._structure

    self._init_func = per_device_dataset._init_func
    self._init_captured_args = self._init_func.captured_inputs

    self._next_func = per_device_dataset._next_func
    # Take a copy: assigning into the original list would also rewrite the
    # captured arguments stored on `per_device_dataset`.
    self._next_captured_args = list(per_device_dataset._next_captured_args)
    # The captured arguments to the next_func are string_handle, incarnation_id.
    # We update the incarnation id to the new one.
    self._next_captured_args[
        per_device_dataset._incarnation_id_index] = incarnation_id

    self._finalize_func = per_device_dataset._finalize_func
    self._finalize_captured_args = per_device_dataset._finalize_captured_args

    variant_tensor = gen_dataset_ops.generator_dataset(
        self._init_captured_args,
        self._next_captured_args,
        self._finalize_captured_args,
        init_func=self._init_func,
        next_func=self._next_func,
        finalize_func=self._finalize_func,
        **dataset_ops.flat_structure(self))
    super(_ReincarnatedPerDeviceGenerator, self).__init__(variant_tensor)
Пример #3
0
 def _as_variant_tensor(self):
     """Creates the generator dataset op under the target device scope.

     Returns:
       The value returned by `gen_dataset_ops.generator_dataset`, called with
       the stored captured arguments and init/next/finalize functions.
     """
     with ops.device(self._target_device):
         captured_groups = (
             self._init_captured_args,
             self._next_captured_args,
             self._finalize_captured_args,
         )
         variant = gen_dataset_ops.generator_dataset(
             *captured_groups,
             init_func=self._init_func,
             next_func=self._next_func,
             finalize_func=self._finalize_func,
             **dataset_ops.flat_structure(self))
     return variant
Пример #4
0
 def _as_variant_tensor(self):
     """Creates the generator dataset op under the target device scope.

     Returns:
       The value returned by `gen_dataset_ops.generator_dataset`, using the
       flat output types and shapes of this dataset.
     """
     with ops.device(self._target_device):
         init_args = self._init_captured_args
         next_args = self._next_captured_args
         finalize_args = self._finalize_captured_args
         variant = gen_dataset_ops.generator_dataset(
             init_args,
             next_args,
             finalize_args,
             init_func=self._init_func,
             next_func=self._next_func,
             finalize_func=self._finalize_func,
             output_types=self._flat_output_types,
             output_shapes=self._flat_output_shapes)
     return variant
Пример #5
0
 def _as_variant_tensor(self):
   """Creates the generator dataset op under the target device scope.

   Returns:
     The value returned by `gen_dataset_ops.generator_dataset`, using the flat
     output types and shapes of this dataset.
   """
   with ops.device(self._target_device):
     captured_groups = (
         self._init_captured_args,
         self._next_captured_args,
         self._finalize_captured_args,
     )
     variant = gen_dataset_ops.generator_dataset(
         *captured_groups,
         init_func=self._init_func,
         next_func=self._next_func,
         finalize_func=self._finalize_func,
         output_types=self._flat_output_types,
         output_shapes=self._flat_output_shapes)
   return variant
Пример #6
0
  def __init__(self, input_dataset, target_device, source_device="/cpu:0"):
    """Constructs a _CopyToDeviceDataset.

    Traces init/next/finalize functions, wraps each one in a `remote_call`
    targeting `source_device`, and passes the wrappers to a generator dataset
    op that is created under `ops.device(target_device)`.

    Args:
      input_dataset: `Dataset` to be copied
      target_device: The name of the device to which elements would be copied.
      source_device: Device where input_dataset would be placed.
    """
    self._input_dataset = input_dataset._apply_debug_options()  # pylint: disable=protected-access
    self._target_device = target_device
    # Build a DeviceSpec from the target device string to record whether its
    # device type is "GPU".
    spec = framework_device.DeviceSpec().from_string(self._target_device)
    self._is_gpu_target = (spec.device_type == "GPU")
    # The source device is kept both as the raw string (used with `ops.device`
    # inside `_next_func`) and as a tensor (used as the `remote_call` target).
    self._source_device_string = source_device
    self._source_device = ops.convert_to_tensor(source_device)

    # Wrap the input dataset's variant tensor; `_init_func` unwraps it again.
    wrap_ds_variant = gen_dataset_ops.wrap_dataset_variant(
        self._input_dataset._variant_tensor)  # pylint: disable=protected-access

    @function.defun()
    def _init_func():
      """Creates an iterator for the input dataset.

      Returns:
        A `string` tensor that encapsulates the iterator created.
      """
      ds_variant = gen_dataset_ops.unwrap_dataset_variant(wrap_ds_variant)
      resource = gen_dataset_ops.anonymous_iterator(
          **self._input_dataset._flat_structure)  # pylint: disable=protected-access
      # The string handle is only produced after make_iterator has run.
      with ops.control_dependencies(
          [gen_dataset_ops.make_iterator(ds_variant, resource)]):
        return gen_dataset_ops.iterator_to_string_handle(resource)

    init_func_concrete = _init_func._get_concrete_function_internal()  # pylint: disable=protected-access

    # Invokes the traced `_init_func` through `remote_call` on the source
    # device, forwarding its captured inputs as arguments.
    @function.defun()
    def _remote_init_func():
      return functional_ops.remote_call(
          target=self._source_device,
          args=init_func_concrete.captured_inputs,
          Tout=[dtypes.string],
          f=init_func_concrete)

    self._init_func = _remote_init_func._get_concrete_function_internal()  # pylint: disable=protected-access
    self._init_captured_args = self._init_func.captured_inputs

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _next_func(string_handle):
      """Calls get_next for created iterator.

      Args:
        string_handle: An iterator string handle created by _init_func
      Returns:
        The elements generated from `input_dataset`
      """
      with ops.device(self._source_device_string):
        iterator = iterator_ops.Iterator.from_string_handle(
            string_handle,
            dataset_ops.get_legacy_output_types(self),
            dataset_ops.get_legacy_output_shapes(self),
            dataset_ops.get_legacy_output_classes(self))
      # The next element is converted to a tensor list according to this
      # dataset's element_spec.
      return structure.to_tensor_list(self.element_spec, iterator.get_next())

    next_func_concrete = _next_func._get_concrete_function_internal()  # pylint: disable=protected-access

    # Invokes the traced `_next_func` through `remote_call` on the source
    # device; the string handle is passed first, followed by the captured
    # inputs of `_next_func`.
    @function.defun_with_attributes(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        attributes={"experimental_ints_on_device": True})
    def _remote_next_func(string_handle):
      return functional_ops.remote_call(
          target=self._source_device,
          args=[string_handle] + next_func_concrete.captured_inputs,
          Tout=self._input_dataset._flat_types,  # pylint: disable=protected-access
          f=next_func_concrete)

    self._next_func = _remote_next_func._get_concrete_function_internal()  # pylint: disable=protected-access
    self._next_captured_args = self._next_func.captured_inputs

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _finalize_func(string_handle):
      """Destroys the iterator resource created.

      Args:
        string_handle: An iterator string handle created by _init_func
      Returns:
        Tensor constant 0
      """
      iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
          string_handle,
          **self._input_dataset._flat_structure)  # pylint: disable=protected-access
      # The constant is only returned after the destroy op has run; lookup
      # errors from the destroy op are ignored.
      with ops.control_dependencies([
          resource_variable_ops.destroy_resource_op(
              iterator_resource, ignore_lookup_error=True)]):
        return array_ops.constant(0, dtypes.int64)

    finalize_func_concrete = _finalize_func._get_concrete_function_internal()  # pylint: disable=protected-access

    # Invokes the traced `_finalize_func` through `remote_call` on the source
    # device.
    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _remote_finalize_func(string_handle):
      return functional_ops.remote_call(
          target=self._source_device,
          args=[string_handle] + finalize_func_concrete.captured_inputs,
          Tout=[dtypes.int64],
          f=finalize_func_concrete)

    self._finalize_func = _remote_finalize_func._get_concrete_function_internal(  # pylint: disable=protected-access
    )
    self._finalize_captured_args = self._finalize_func.captured_inputs

    # Add the three remote wrapper functions to the current default graph.
    g = ops.get_default_graph()
    self._init_func.add_to_graph(g)
    self._next_func.add_to_graph(g)
    self._finalize_func.add_to_graph(g)
    # pylint: enable=protected-scope

    # The generator dataset op itself is created under the target device.
    with ops.device(self._target_device):
      variant_tensor = gen_dataset_ops.generator_dataset(
          self._init_captured_args,
          self._next_captured_args,
          self._finalize_captured_args,
          init_func=self._init_func,
          next_func=self._next_func,
          finalize_func=self._finalize_func,
          **self._input_dataset._flat_structure)  # pylint: disable=protected-access
    # NOTE(review): the original `input_dataset` argument, not
    # `self._input_dataset`, is passed to the base constructor - confirm this
    # is intended.
    super(_CopyToDeviceDataset, self).__init__(input_dataset, variant_tensor)
Пример #7
0
    def __init__(self, shard_num, multi_device_iterator_resource,
                 incarnation_id, source_device, element_spec):
        """Builds a generator dataset that reads one shard of a multi-device iterator.

        Init/next/finalize functions are traced, each is wrapped in a
        `remote_call` targeting `source_device`, and the wrappers are passed to
        a generator dataset op.

        Args:
          shard_num: Passed as `shard_num` to
            `multi_device_iterator_get_next_from_shard`.
          multi_device_iterator_resource: Resource converted to the string
            handle that `_init_func` returns.
          incarnation_id: Passed as `incarnation_id` to
            `multi_device_iterator_get_next_from_shard`; its position among the
            captured arguments of the next function is recorded in
            `self._incarnation_id_index`.
          source_device: Target of every `remote_call` issued here.
          element_spec: Stored as `self._element_spec`; its flat tensor types
            and shapes are used as the output types and shapes.
        """
        self._element_spec = element_spec

        multi_device_iterator_string_handle = (
            gen_dataset_ops.multi_device_iterator_to_string_handle(
                multi_device_iterator_resource))

        # TODO(b/124254153): Enable autograph once the overhead is low enough.
        @function.defun(autograph=False)  # Pure graph code.
        def _init_func():
            return multi_device_iterator_string_handle

        init_func_concrete = _init_func._get_concrete_function_internal()  # pylint: disable=protected-access

        # Invokes the traced `_init_func` through `remote_call` on the source
        # device.
        # TODO(b/124254153): Enable autograph once the overhead is low enough.
        @function.defun(autograph=False)  # Pure graph code.
        def _remote_init_func():
            return functional_ops.remote_call(
                target=source_device,
                args=init_func_concrete.captured_inputs,
                Tout=[dtypes.string],
                f=init_func_concrete)

        self._init_func = _remote_init_func._get_concrete_function_internal()  # pylint: disable=protected-access
        self._init_captured_args = self._init_func.captured_inputs

        # Rebuilds the multi-device iterator from its string handle and
        # fetches the next element for `shard_num`.
        # TODO(b/124254153): Enable autograph once the overhead is low enough.
        @function.defun(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
            autograph=False)  # Pure graph code.
        def _next_func(string_handle):
            # pylint: disable=protected-access
            multi_device_iterator = (
                gen_dataset_ops.multi_device_iterator_from_string_handle(
                    string_handle=string_handle,
                    output_types=structure.get_flat_tensor_types(
                        self._element_spec),
                    output_shapes=structure.get_flat_tensor_shapes(
                        self._element_spec)))
            return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
                multi_device_iterator=multi_device_iterator,
                shard_num=shard_num,
                incarnation_id=incarnation_id,
                output_types=structure.get_flat_tensor_types(
                    self._element_spec),
                output_shapes=structure.get_flat_tensor_shapes(
                    self._element_spec))

        next_func_concrete = _next_func._get_concrete_function_internal()  # pylint: disable=protected-access

        # Invokes the traced `_next_func` through `remote_call` on the source
        # device; the string handle is passed first, followed by the captured
        # inputs of `_next_func`.
        # TODO(b/124254153): Enable autograph once the overhead is low enough.
        @function.defun_with_attributes(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
            attributes={"experimental_ints_on_device": True},
            autograph=False)  # Pure graph code.
        def _remote_next_func(string_handle):
            return functional_ops.remote_call(
                target=source_device,
                args=[string_handle] + next_func_concrete.captured_inputs,
                Tout=structure.get_flat_tensor_types(self._element_spec),
                f=next_func_concrete)

        self._next_func = _remote_next_func._get_concrete_function_internal()  # pylint: disable=protected-access
        self._next_captured_args = self._next_func.captured_inputs

        # Record where `incarnation_id` sits among the captured arguments of
        # the next function (matched by identity). The index stays -1 when it
        # is not found; if it occurs more than once the last position wins.
        self._incarnation_id_index = -1
        for i, arg in enumerate(self._next_captured_args):
            if arg is incarnation_id:
                self._incarnation_id_index = i

        # The finalize function ignores its argument and returns constant 0.
        # TODO(b/124254153): Enable autograph once the overhead is low enough.
        @function.defun(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
            autograph=False)  # Pure graph code.
        def _finalize_func(unused_string_handle):
            return array_ops.constant(0, dtypes.int64)

        finalize_func_concrete = _finalize_func._get_concrete_function_internal(
        )  # pylint: disable=protected-access

        # Invokes the traced `_finalize_func` through `remote_call` on the
        # source device.
        # TODO(b/124254153): Enable autograph once the overhead is low enough.
        @function.defun(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
            autograph=False)  # Pure graph code.
        def _remote_finalize_func(string_handle):
            return functional_ops.remote_call(
                target=source_device,
                args=[string_handle] + finalize_func_concrete.captured_inputs,
                Tout=[dtypes.int64],
                f=finalize_func_concrete)

        self._finalize_func = (
            _remote_finalize_func._get_concrete_function_internal())  # pylint: disable=protected-access
        self._finalize_captured_args = self._finalize_func.captured_inputs

        # Create the generator dataset op from the three remote wrappers and
        # their captured arguments.
        variant_tensor = gen_dataset_ops.generator_dataset(
            self._init_captured_args,
            self._next_captured_args,
            self._finalize_captured_args,
            init_func=self._init_func,
            next_func=self._next_func,
            finalize_func=self._finalize_func,
            **self._flat_structure)
        super(_PerDeviceGenerator, self).__init__(variant_tensor)
Пример #8
0
    def __init__(self, shard_num, multi_device_iterator_resource,
                 incarnation_id, source_device, element_structure):
        """Builds a generator dataset that reads one shard of a multi-device iterator.

        Init/next/finalize functions are traced, each is wrapped in a
        `remote_call` targeting `source_device`, and the wrappers are passed to
        a generator dataset op.

        Args:
          shard_num: Passed as `shard_num` to
            `multi_device_iterator_get_next_from_shard`.
          multi_device_iterator_resource: Resource converted to the string
            handle that `_init_func` returns.
          incarnation_id: Passed as `incarnation_id` to
            `multi_device_iterator_get_next_from_shard`.
          source_device: Target of every `remote_call` issued here.
          element_structure: Stored as `self._structure`; its `_flat_types` and
            `_flat_shapes` are used as the output types and shapes.
        """
        self._structure = element_structure

        multi_device_iterator_string_handle = (
            gen_dataset_ops.multi_device_iterator_to_string_handle(
                multi_device_iterator_resource))

        @function.defun()
        def _init_func():
            return multi_device_iterator_string_handle

        init_func_concrete = _init_func._get_concrete_function_internal()  # pylint: disable=protected-access

        # Invokes the traced `_init_func` through `remote_call` on the source
        # device.
        @function.defun()
        def _remote_init_func():
            return functional_ops.remote_call(
                target=source_device,
                args=init_func_concrete.captured_inputs,
                Tout=[dtypes.string],
                f=init_func_concrete)

        self._init_func = _remote_init_func._get_concrete_function_internal()  # pylint: disable=protected-access
        self._init_captured_args = self._init_func.captured_inputs

        # Rebuilds the multi-device iterator from its string handle and
        # fetches the next element for `shard_num`.
        @function.defun(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
        def _next_func(string_handle):
            # pylint: disable=protected-access
            multi_device_iterator = (
                gen_dataset_ops.multi_device_iterator_from_string_handle(
                    string_handle=string_handle,
                    output_types=self._structure._flat_types,
                    output_shapes=self._structure._flat_shapes))
            return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
                multi_device_iterator=multi_device_iterator,
                shard_num=shard_num,
                incarnation_id=incarnation_id,
                output_types=self._structure._flat_types,
                output_shapes=self._structure._flat_shapes)

        next_func_concrete = _next_func._get_concrete_function_internal()  # pylint: disable=protected-access

        # Invokes the traced `_next_func` through `remote_call` on the source
        # device; the string handle is passed first, followed by the captured
        # inputs of `_next_func`.
        @function.defun_with_attributes(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
            attributes={"experimental_ints_on_device": True})
        def _remote_next_func(string_handle):
            return functional_ops.remote_call(
                target=source_device,
                args=[string_handle] + next_func_concrete.captured_inputs,
                Tout=self._structure._flat_types,  # pylint: disable=protected-access
                f=next_func_concrete)

        self._next_func = _remote_next_func._get_concrete_function_internal()  # pylint: disable=protected-access
        self._next_captured_args = self._next_func.captured_inputs

        # The finalize function ignores its argument and returns constant 0.
        @function.defun(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
        def _finalize_func(unused_string_handle):
            return array_ops.constant(0, dtypes.int64)

        finalize_func_concrete = _finalize_func._get_concrete_function_internal(
        )  # pylint: disable=protected-access

        # Invokes the traced `_finalize_func` through `remote_call` on the
        # source device.
        @function.defun(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
        def _remote_finalize_func(string_handle):
            return functional_ops.remote_call(
                target=source_device,
                args=[string_handle] + finalize_func_concrete.captured_inputs,
                Tout=[dtypes.int64],
                f=finalize_func_concrete)

        self._finalize_func = _remote_finalize_func._get_concrete_function_internal(  # pylint: disable=protected-access
        )
        self._finalize_captured_args = self._finalize_func.captured_inputs

        # Create the generator dataset op from the three remote wrappers and
        # their captured arguments.
        variant_tensor = gen_dataset_ops.generator_dataset(
            self._init_captured_args,
            self._next_captured_args,
            self._finalize_captured_args,
            init_func=self._init_func,
            next_func=self._next_func,
            finalize_func=self._finalize_func,
            **dataset_ops.flat_structure(self))
        super(_PerDeviceGenerator, self).__init__(variant_tensor)
  def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id,
               source_device, element_structure):
    """Builds a generator dataset that reads one shard of a multi-device iterator.

    Init/next/finalize functions are traced, each is wrapped in a `remote_call`
    targeting `source_device`, and the wrappers are passed to a generator
    dataset op.

    Args:
      shard_num: Passed as `shard_num` to
        `multi_device_iterator_get_next_from_shard`.
      multi_device_iterator_resource: Resource converted to the string handle
        that `_init_func` returns.
      incarnation_id: Passed as `incarnation_id` to
        `multi_device_iterator_get_next_from_shard`.
      source_device: Target of every `remote_call` issued here.
      element_structure: Stored as `self._structure`; its `_flat_types` and
        `_flat_shapes` are used as the output types and shapes.
    """
    self._structure = element_structure

    multi_device_iterator_string_handle = (
        gen_dataset_ops.multi_device_iterator_to_string_handle(
            multi_device_iterator_resource))

    @function.defun()
    def _init_func():
      return multi_device_iterator_string_handle

    init_func_concrete = _init_func._get_concrete_function_internal()  # pylint: disable=protected-access

    # Invokes the traced `_init_func` through `remote_call` on the source
    # device.
    @function.defun()
    def _remote_init_func():
      return functional_ops.remote_call(
          target=source_device,
          args=init_func_concrete.captured_inputs,
          Tout=[dtypes.string],
          f=init_func_concrete)

    self._init_func = _remote_init_func._get_concrete_function_internal()  # pylint: disable=protected-access
    self._init_captured_args = self._init_func.captured_inputs

    # Rebuilds the multi-device iterator from its string handle and fetches the
    # next element for `shard_num`.
    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _next_func(string_handle):
      # pylint: disable=protected-access
      multi_device_iterator = (
          gen_dataset_ops.multi_device_iterator_from_string_handle(
              string_handle=string_handle,
              output_types=self._structure._flat_types,
              output_shapes=self._structure._flat_shapes))
      return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
          multi_device_iterator=multi_device_iterator,
          shard_num=shard_num,
          incarnation_id=incarnation_id,
          output_types=self._structure._flat_types,
          output_shapes=self._structure._flat_shapes)

    next_func_concrete = _next_func._get_concrete_function_internal()  # pylint: disable=protected-access

    # Invokes the traced `_next_func` through `remote_call` on the source
    # device; the string handle is passed first, followed by the captured
    # inputs of `_next_func`.
    @function.defun_with_attributes(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        attributes={"experimental_ints_on_device": True})
    def _remote_next_func(string_handle):
      return functional_ops.remote_call(
          target=source_device,
          args=[string_handle] + next_func_concrete.captured_inputs,
          Tout=self._structure._flat_types,  # pylint: disable=protected-access
          f=next_func_concrete)

    self._next_func = _remote_next_func._get_concrete_function_internal()  # pylint: disable=protected-access
    self._next_captured_args = self._next_func.captured_inputs

    # The finalize function ignores its argument and returns constant 0.
    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _finalize_func(unused_string_handle):
      return array_ops.constant(0, dtypes.int64)

    finalize_func_concrete = _finalize_func._get_concrete_function_internal()  # pylint: disable=protected-access

    # Invokes the traced `_finalize_func` through `remote_call` on the source
    # device.
    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _remote_finalize_func(string_handle):
      return functional_ops.remote_call(
          target=source_device,
          args=[string_handle] + finalize_func_concrete.captured_inputs,
          Tout=[dtypes.int64],
          f=finalize_func_concrete)

    self._finalize_func = _remote_finalize_func._get_concrete_function_internal(  # pylint: disable=protected-access
    )
    self._finalize_captured_args = self._finalize_func.captured_inputs

    # Create the generator dataset op from the three remote wrappers and their
    # captured arguments.
    variant_tensor = gen_dataset_ops.generator_dataset(
        self._init_captured_args,
        self._next_captured_args,
        self._finalize_captured_args,
        init_func=self._init_func,
        next_func=self._next_func,
        finalize_func=self._finalize_func,
        **dataset_ops.flat_structure(self))
    super(_PerDeviceGenerator, self).__init__(variant_tensor)
Пример #10
0
  def __init__(self, input_dataset, target_device, source_device="/cpu:0"):
    """Constructs a _CopyToDeviceDataset.

    Traces init/next/finalize functions, wraps each one in a `remote_call`
    targeting `source_device`, and passes the wrappers to a generator dataset
    op that is created under `ops.device(target_device)`.

    Args:
      input_dataset: `Dataset` to be copied
      target_device: The name of the device to which elements would be copied.
      source_device: Device where input_dataset would be placed.
    """
    self._input_dataset = input_dataset
    self._target_device = target_device
    # Build a DeviceSpec from the target device string to record whether its
    # device type is "GPU".
    spec = framework_device.DeviceSpec().from_string(self._target_device)
    self._is_gpu_target = (spec.device_type == "GPU")
    # The source device is kept both as the raw string (used with `ops.device`
    # inside `_next_func`) and as a tensor (used as the `remote_call` target).
    self._source_device_string = source_device
    self._source_device = ops.convert_to_tensor(source_device)

    # Wrap the input dataset's variant tensor; `_init_func` unwraps it again.
    wrap_ds_variant = gen_dataset_ops.wrap_dataset_variant(
        self._input_dataset._variant_tensor)  # pylint: disable=protected-access

    @function.defun()
    def _init_func():
      """Creates an iterator for the input dataset.

      Returns:
        A `string` tensor that encapsulates the iterator created.
      """
      ds_variant = gen_dataset_ops.unwrap_dataset_variant(wrap_ds_variant)
      resource = gen_dataset_ops.anonymous_iterator(
          **dataset_ops.flat_structure(self._input_dataset))
      # The string handle is only produced after make_iterator has run.
      with ops.control_dependencies(
          [gen_dataset_ops.make_iterator(ds_variant, resource)]):
        return gen_dataset_ops.iterator_to_string_handle(resource)

    init_func_concrete = _init_func._get_concrete_function_internal()  # pylint: disable=protected-access

    # Invokes the traced `_init_func` through `remote_call` on the source
    # device, forwarding its captured inputs as arguments.
    @function.defun()
    def _remote_init_func():
      return functional_ops.remote_call(
          target=self._source_device,
          args=init_func_concrete.captured_inputs,
          Tout=[dtypes.string],
          f=init_func_concrete)

    self._init_func = _remote_init_func._get_concrete_function_internal()  # pylint: disable=protected-access
    self._init_captured_args = self._init_func.captured_inputs

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _next_func(string_handle):
      """Calls get_next for created iterator.

      Args:
        string_handle: An iterator string handle created by _init_func
      Returns:
        The elements generated from `input_dataset`
      """
      with ops.device(self._source_device_string):
        iterator = iterator_ops.Iterator.from_string_handle(
            string_handle, self.output_types, self.output_shapes,
            self.output_classes)
      # The next element is converted to a tensor list through this dataset's
      # element structure.
      return self._element_structure._to_tensor_list(iterator.get_next())  # pylint: disable=protected-access

    next_func_concrete = _next_func._get_concrete_function_internal()  # pylint: disable=protected-access

    # Invokes the traced `_next_func` through `remote_call` on the source
    # device; the string handle is passed first, followed by the captured
    # inputs of `_next_func`.
    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _remote_next_func(string_handle):
      return functional_ops.remote_call(
          target=self._source_device,
          args=[string_handle] +
          next_func_concrete.captured_inputs,
          Tout=self._input_dataset._element_structure._flat_types,  # pylint: disable=protected-access
          f=next_func_concrete)

    self._next_func = _remote_next_func._get_concrete_function_internal()  # pylint: disable=protected-access
    self._next_captured_args = self._next_func.captured_inputs

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _finalize_func(string_handle):
      """Destroys the iterator resource created.

      Args:
        string_handle: An iterator string handle created by _init_func
      Returns:
        Tensor constant 0
      """
      iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
          string_handle,
          **dataset_ops.flat_structure(self._input_dataset))
      # The constant is only returned after the destroy op has run; lookup
      # errors from the destroy op are ignored.
      with ops.control_dependencies([
          resource_variable_ops.destroy_resource_op(
              iterator_resource, ignore_lookup_error=True)]):
        return array_ops.constant(0, dtypes.int64)

    finalize_func_concrete = _finalize_func._get_concrete_function_internal()  # pylint: disable=protected-access

    # Invokes the traced `_finalize_func` through `remote_call` on the source
    # device.
    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _remote_finalize_func(string_handle):
      return functional_ops.remote_call(
          target=self._source_device,
          args=[string_handle] + finalize_func_concrete.captured_inputs,
          Tout=[dtypes.int64],
          f=finalize_func_concrete)

    self._finalize_func = _remote_finalize_func._get_concrete_function_internal(  # pylint: disable=protected-access
    )
    self._finalize_captured_args = self._finalize_func.captured_inputs

    # Add the three remote wrapper functions to the current default graph.
    g = ops.get_default_graph()
    self._init_func.add_to_graph(g)
    self._next_func.add_to_graph(g)
    self._finalize_func.add_to_graph(g)
    # pylint: enable=protected-scope

    # The generator dataset op itself is created under the target device.
    with ops.device(self._target_device):
      variant_tensor = gen_dataset_ops.generator_dataset(
          self._init_captured_args,
          self._next_captured_args,
          self._finalize_captured_args,
          init_func=self._init_func,
          next_func=self._next_func,
          finalize_func=self._finalize_func,
          **dataset_ops.flat_structure(self._input_dataset))
    super(_CopyToDeviceDataset, self).__init__(input_dataset, variant_tensor)
    def __init__(self, shard_num, multi_device_iterator_resource,
                 incarnation_id, source_device, element_spec,
                 iterator_is_anonymous):
        """Builds a generator dataset that reads one shard of a multi-device iterator.

        Init/next/finalize functions are traced, each is wrapped in a
        `remote_call` targeting `source_device`, and the wrappers are passed to
        a generator dataset op.

        Args:
          shard_num: Passed as `shard_num` to
            `multi_device_iterator_get_next_from_shard`.
          multi_device_iterator_resource: Resource converted to the string
            handle that `_init_func` returns. It is also appended to the
            captured arguments of the next function when
            `iterator_is_anonymous` is true.
          incarnation_id: Passed as `incarnation_id` to
            `multi_device_iterator_get_next_from_shard`; its position among the
            captured arguments of the next function is recorded in
            `self._incarnation_id_index`.
          source_device: Target of every `remote_call` issued here.
          element_spec: Stored as `self._element_spec`; its flat tensor types
            and shapes are used as the output types and shapes.
          iterator_is_anonymous: When true, `multi_device_iterator_resource` is
            added to the captured arguments of the next function.
        """
        self._element_spec = element_spec

        multi_device_iterator_string_handle = (
            gen_dataset_ops.multi_device_iterator_to_string_handle(
                multi_device_iterator_resource))

        # TODO(b/124254153): Enable autograph once the overhead is low enough.
        @function.defun(autograph=False)  # Pure graph code.
        def _init_func():
            return multi_device_iterator_string_handle

        init_func_concrete = _init_func.get_concrete_function()

        # Invokes the traced `_init_func` through `remote_call` on the source
        # device.
        # TODO(b/124254153): Enable autograph once the overhead is low enough.
        @function.defun(autograph=False)  # Pure graph code.
        def _remote_init_func():
            return functional_ops.remote_call(
                target=source_device,
                args=init_func_concrete.captured_inputs,
                Tout=[dtypes.string],
                f=init_func_concrete)

        self._init_func = _remote_init_func.get_concrete_function()
        self._init_captured_args = self._init_func.captured_inputs

        # Rebuilds the multi-device iterator from its string handle and
        # fetches the next element for `shard_num`.
        # TODO(b/124254153): Enable autograph once the overhead is low enough.
        @function.defun(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
            autograph=False)  # Pure graph code.
        def _next_func(string_handle):
            # pylint: disable=protected-access
            multi_device_iterator = (
                gen_dataset_ops.multi_device_iterator_from_string_handle(
                    string_handle=string_handle,
                    output_types=structure.get_flat_tensor_types(
                        self._element_spec),
                    output_shapes=structure.get_flat_tensor_shapes(
                        self._element_spec)))
            return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
                multi_device_iterator=multi_device_iterator,
                shard_num=shard_num,
                incarnation_id=incarnation_id,
                output_types=structure.get_flat_tensor_types(
                    self._element_spec),
                output_shapes=structure.get_flat_tensor_shapes(
                    self._element_spec))

        next_func_concrete = _next_func.get_concrete_function()

        # Invokes the traced `_next_func` through `remote_call` on the source
        # device; the string handle is passed first, followed by the captured
        # inputs of `_next_func`.
        # TODO(b/124254153): Enable autograph once the overhead is low enough.
        @function.defun_with_attributes(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
            attributes={"experimental_ints_on_device": True},
            autograph=False)  # Pure graph code.
        def _remote_next_func(string_handle):
            return_values = functional_ops.remote_call(
                target=source_device,
                args=[string_handle] + next_func_concrete.captured_inputs,
                Tout=structure.get_flat_tensor_types(self._element_spec),
                f=next_func_concrete)
            # Add full type information to the graph so that the RemoteCall op
            # can determine for each of its outputs whether or not they are ragged
            # tensors (or other types that use variants) that contain strings
            # (or other host memory types). Then RemoteCall can
            # appropriately set AllocatorAttributes to control copies so
            # strings/host memory types stay on CPU.
            fulltype_list = type_utils.fulltypes_for_flat_tensors(
                self._element_spec)
            fulltype = type_utils.fulltype_list_to_product(fulltype_list)
            # The same product type is set on the op of every returned value.
            for return_value in return_values:
                return_value.op.experimental_set_type(fulltype)
            return return_values

        self._next_func = _remote_next_func.get_concrete_function()
        self._next_captured_args = self._next_func.captured_inputs

        # For an anonymous iterator the resource itself is appended to the
        # captured arguments; a new list is built rather than extending the
        # existing one in place.
        if iterator_is_anonymous:
            self._next_captured_args = self._next_captured_args + [
                multi_device_iterator_resource
            ]

        # Record where `incarnation_id` sits among the captured arguments of
        # the next function (matched by identity). The index stays -1 when it
        # is not found; if it occurs more than once the last position wins.
        self._incarnation_id_index = -1
        for i, arg in enumerate(self._next_captured_args):
            if arg is incarnation_id:
                self._incarnation_id_index = i

        # The finalize function ignores its argument and returns constant 0.
        # TODO(b/124254153): Enable autograph once the overhead is low enough.
        @function.defun(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
            autograph=False)  # Pure graph code.
        def _finalize_func(unused_string_handle):
            return array_ops.constant(0, dtypes.int64)

        finalize_func_concrete = _finalize_func.get_concrete_function()

        # Invokes the traced `_finalize_func` through `remote_call` on the
        # source device.
        # TODO(b/124254153): Enable autograph once the overhead is low enough.
        @function.defun(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
            autograph=False)  # Pure graph code.
        def _remote_finalize_func(string_handle):
            return functional_ops.remote_call(
                target=source_device,
                args=[string_handle] + finalize_func_concrete.captured_inputs,
                Tout=[dtypes.int64],
                f=finalize_func_concrete)

        self._finalize_func = _remote_finalize_func.get_concrete_function()
        self._finalize_captured_args = self._finalize_func.captured_inputs

        # Create the generator dataset op from the three remote wrappers and
        # their captured arguments.
        variant_tensor = gen_dataset_ops.generator_dataset(
            self._init_captured_args,
            self._next_captured_args,
            self._finalize_captured_args,
            init_func=self._init_func,
            next_func=self._next_func,
            finalize_func=self._finalize_func,
            **self._flat_structure)
        super(_PerDeviceGenerator, self).__init__(variant_tensor)