Ejemplo n.º 1
0
    def __init__(self, shard_num, multi_device_iterator_resource,
                 incarnation_id, source_device, element_spec):
        """Builds a generator dataset that pulls one shard from a multi-device iterator.

        The constructor traces three pairs of graph functions (init, next,
        finalize). Each pair consists of a local function plus a wrapper that
        invokes it through `functional_ops.remote_call` on `source_device`.
        The wrappers and their captured inputs are then handed to
        `gen_dataset_ops.generator_dataset`, and the resulting variant tensor
        is passed to the base-class initializer.

        Args:
          shard_num: Forwarded as `shard_num` to
            `multi_device_iterator_get_next_from_shard`.
          multi_device_iterator_resource: Resource that is converted to a
            string handle and re-opened inside the next function.
          incarnation_id: Forwarded as `incarnation_id` to
            `multi_device_iterator_get_next_from_shard`; its position among
            the next function's captured inputs is stored in
            `self._incarnation_id_index`.
          source_device: Used as the `target` of every `remote_call`.
          element_spec: Structure from which the flat tensor types and shapes
            of the produced elements are derived.
        """
        self._element_spec = element_spec

        # Serialize the iterator resource into a string handle; the functions
        # below pass this handle around instead of the resource itself.
        multi_device_iterator_string_handle = (
            gen_dataset_ops.multi_device_iterator_to_string_handle(
                multi_device_iterator_resource))

        # Init function: simply yields the string handle (captured from the
        # enclosing scope).
        # TODO(b/124254153): Enable autograph once the overhead is low enough.
        @function.defun(autograph=False)  # Pure graph code.
        def _init_func():
            return multi_device_iterator_string_handle

        init_func_concrete = _init_func._get_concrete_function_internal()  # pylint: disable=protected-access

        # Wrapper that runs the init function on `source_device`, forwarding
        # whatever tensors the traced init function captured.
        # TODO(b/124254153): Enable autograph once the overhead is low enough.
        @function.defun(autograph=False)  # Pure graph code.
        def _remote_init_func():
            return functional_ops.remote_call(
                target=source_device,
                args=init_func_concrete.captured_inputs,
                Tout=[dtypes.string],
                f=init_func_concrete)

        self._init_func = _remote_init_func._get_concrete_function_internal()  # pylint: disable=protected-access
        self._init_captured_args = self._init_func.captured_inputs

        # Next function: takes a scalar string handle, re-opens the
        # multi-device iterator from it and fetches the next element for
        # shard `shard_num`.
        # TODO(b/124254153): Enable autograph once the overhead is low enough.
        @function.defun(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
            autograph=False)  # Pure graph code.
        def _next_func(string_handle):
            # pylint: disable=protected-access
            multi_device_iterator = (
                gen_dataset_ops.multi_device_iterator_from_string_handle(
                    string_handle=string_handle,
                    output_types=structure.get_flat_tensor_types(
                        self._element_spec),
                    output_shapes=structure.get_flat_tensor_shapes(
                        self._element_spec)))
            return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
                multi_device_iterator=multi_device_iterator,
                shard_num=shard_num,
                incarnation_id=incarnation_id,
                output_types=structure.get_flat_tensor_types(
                    self._element_spec),
                output_shapes=structure.get_flat_tensor_shapes(
                    self._element_spec))

        next_func_concrete = _next_func._get_concrete_function_internal()  # pylint: disable=protected-access

        # Wrapper that runs the next function on `source_device`. The string
        # handle is passed first, followed by the next function's captures.
        # TODO(b/124254153): Enable autograph once the overhead is low enough.
        @function.defun_with_attributes(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
            attributes={"experimental_ints_on_device": True},
            autograph=False)  # Pure graph code.
        def _remote_next_func(string_handle):
            return functional_ops.remote_call(
                target=source_device,
                args=[string_handle] + next_func_concrete.captured_inputs,
                Tout=structure.get_flat_tensor_types(self._element_spec),
                f=next_func_concrete)

        self._next_func = _remote_next_func._get_concrete_function_internal()  # pylint: disable=protected-access
        self._next_captured_args = self._next_func.captured_inputs

        # Record where `incarnation_id` sits among the captured arguments of
        # the next function (matched by identity); -1 means it was not found.
        # NOTE(review): presumably used elsewhere to swap in a new
        # incarnation id — confirm against the rest of the class.
        self._incarnation_id_index = -1
        for i, arg in enumerate(self._next_captured_args):
            if arg is incarnation_id:
                self._incarnation_id_index = i

        # Finalize function: ignores its argument and returns the int64
        # constant 0.
        # TODO(b/124254153): Enable autograph once the overhead is low enough.
        @function.defun(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
            autograph=False)  # Pure graph code.
        def _finalize_func(unused_string_handle):
            return array_ops.constant(0, dtypes.int64)

        finalize_func_concrete = _finalize_func._get_concrete_function_internal(
        )  # pylint: disable=protected-access

        # Wrapper that runs the finalize function on `source_device`.
        # TODO(b/124254153): Enable autograph once the overhead is low enough.
        @function.defun(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
            autograph=False)  # Pure graph code.
        def _remote_finalize_func(string_handle):
            return functional_ops.remote_call(
                target=source_device,
                args=[string_handle] + finalize_func_concrete.captured_inputs,
                Tout=[dtypes.int64],
                f=finalize_func_concrete)

        self._finalize_func = (
            _remote_finalize_func._get_concrete_function_internal())  # pylint: disable=protected-access
        self._finalize_captured_args = self._finalize_func.captured_inputs

        # Assemble the generator dataset from the three remote wrappers and
        # their captured inputs. `self._flat_structure` is not defined in this
        # method; it is expected to be provided by the class.
        variant_tensor = gen_dataset_ops.generator_dataset(
            self._init_captured_args,
            self._next_captured_args,
            self._finalize_captured_args,
            init_func=self._init_func,
            next_func=self._next_func,
            finalize_func=self._finalize_func,
            **self._flat_structure)
        super(_PerDeviceGenerator, self).__init__(variant_tensor)
Ejemplo n.º 2
0
    def __init__(self, shard_num, multi_device_iterator_resource,
                 incarnation_id, source_device, target_device, output_shapes,
                 output_types, output_classes):
        """Prepares the init/next/finalize functions for one multi-device iterator shard.

        This variant uses `function.Defun` and stores the decorated functions
        directly. It only records the functions and their captured inputs on
        `self`; it does not build a dataset variant tensor or call a base-class
        initializer within this method.

        Args:
          shard_num: Forwarded as `shard_num` to
            `multi_device_iterator_get_next_from_shard`.
          multi_device_iterator_resource: Resource that is converted to a
            string handle and re-opened inside the next function.
          incarnation_id: Forwarded as `incarnation_id` to
            `multi_device_iterator_get_next_from_shard`.
          source_device: Used as the `target` of every `remote_call`.
          target_device: Stored as `self._target_device`; not otherwise used
            in this method.
          output_shapes: Element shapes; stored and flattened to dense shapes.
          output_types: Element types; stored and flattened to dense types.
          output_classes: Element classes; used when converting shapes and
            types to their dense equivalents.
        """
        self._target_device = target_device
        self._output_types = output_types
        self._output_shapes = output_shapes
        self._output_classes = output_classes
        # Flat lists of dense shapes/types, used as the output signature of
        # the iterator ops and of the remote next call below.
        self._flat_output_shapes = nest.flatten(
            sparse.as_dense_shapes(self._output_shapes, self._output_classes))
        self._flat_output_types = nest.flatten(
            sparse.as_dense_types(self._output_types, self._output_classes))

        # Serialize the iterator resource into a string handle.
        multi_device_iterator_string_handle = (
            gen_dataset_ops.multi_device_iterator_to_string_handle(
                multi_device_iterator_resource))

        # Init function: yields the string handle captured from this scope.
        @function.Defun()
        def _init_func():
            return multi_device_iterator_string_handle

        # Wrapper that runs the init function on `source_device`.
        @function.Defun()
        def _remote_init_func():
            return functional_ops.remote_call(target=source_device,
                                              args=_init_func.captured_inputs,
                                              Tout=[dtypes.string],
                                              f=_init_func)

        self._init_func = _remote_init_func
        self._init_captured_args = _remote_init_func.captured_inputs

        # Next function: re-opens the iterator from the string handle and
        # fetches the next element for shard `shard_num`.
        @function.Defun(dtypes.string)
        def _next_func(string_handle):
            multi_device_iterator = (
                gen_dataset_ops.multi_device_iterator_from_string_handle(
                    string_handle=string_handle,
                    output_types=self._flat_output_types,
                    output_shapes=self._flat_output_shapes))
            return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
                multi_device_iterator=multi_device_iterator,
                shard_num=shard_num,
                incarnation_id=incarnation_id,
                output_types=self._flat_output_types,
                output_shapes=self._flat_output_shapes)

        # Wrapper that runs the next function on `source_device`; the string
        # handle goes first, followed by the next function's captures.
        @function.Defun(dtypes.string, experimental_ints_on_device=True)
        def _remote_next_func(string_handle):
            return functional_ops.remote_call(target=source_device,
                                              args=[string_handle] +
                                              _next_func.captured_inputs,
                                              Tout=self._flat_output_types,
                                              f=_next_func)

        self._next_func = _remote_next_func
        self._next_captured_args = _remote_next_func.captured_inputs

        # Finalize function: ignores its argument and returns int64 0.
        @function.Defun(dtypes.string)
        def _finalize_func(unused_string_handle):
            return array_ops.constant(0, dtypes.int64)

        # Wrapper that runs the finalize function on `source_device`.
        @function.Defun(dtypes.string)
        def _remote_finalize_func(string_handle):
            return functional_ops.remote_call(target=source_device,
                                              args=[string_handle] +
                                              _finalize_func.captured_inputs,
                                              Tout=[dtypes.int64],
                                              f=_finalize_func)

        self._finalize_func = _remote_finalize_func
        self._finalize_captured_args = _remote_finalize_func.captured_inputs
Ejemplo n.º 3
0
    def __init__(self, shard_num, multi_device_iterator_resource,
                 incarnation_id, source_device, element_structure):
        """Builds a generator dataset that pulls one shard from a multi-device iterator.

        Traces init, next and finalize functions, wraps each one in a
        `functional_ops.remote_call` targeting `source_device`, and feeds the
        wrappers plus their captured inputs to
        `gen_dataset_ops.generator_dataset`. The resulting variant tensor is
        passed to the base-class initializer.

        Args:
          shard_num: Forwarded as `shard_num` to
            `multi_device_iterator_get_next_from_shard`.
          multi_device_iterator_resource: Resource that is converted to a
            string handle and re-opened inside the next function.
          incarnation_id: Forwarded as `incarnation_id` to
            `multi_device_iterator_get_next_from_shard`.
          source_device: Used as the `target` of every `remote_call`.
          element_structure: Stored as `self._structure`; its `_flat_types`
            and `_flat_shapes` give the output signature of the iterator ops.
        """
        self._structure = element_structure

        # Serialize the iterator resource into a string handle.
        multi_device_iterator_string_handle = (
            gen_dataset_ops.multi_device_iterator_to_string_handle(
                multi_device_iterator_resource))

        # Init function: yields the string handle captured from this scope.
        @function.defun()
        def _init_func():
            return multi_device_iterator_string_handle

        init_func_concrete = _init_func._get_concrete_function_internal()  # pylint: disable=protected-access

        # Wrapper that runs the init function on `source_device`.
        @function.defun()
        def _remote_init_func():
            return functional_ops.remote_call(
                target=source_device,
                args=init_func_concrete.captured_inputs,
                Tout=[dtypes.string],
                f=init_func_concrete)

        self._init_func = _remote_init_func._get_concrete_function_internal()  # pylint: disable=protected-access
        self._init_captured_args = self._init_func.captured_inputs

        # Next function: takes a scalar string handle, re-opens the iterator
        # from it and fetches the next element for shard `shard_num`.
        @function.defun(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
        def _next_func(string_handle):
            # pylint: disable=protected-access
            multi_device_iterator = (
                gen_dataset_ops.multi_device_iterator_from_string_handle(
                    string_handle=string_handle,
                    output_types=self._structure._flat_types,
                    output_shapes=self._structure._flat_shapes))
            return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
                multi_device_iterator=multi_device_iterator,
                shard_num=shard_num,
                incarnation_id=incarnation_id,
                output_types=self._structure._flat_types,
                output_shapes=self._structure._flat_shapes)

        next_func_concrete = _next_func._get_concrete_function_internal()  # pylint: disable=protected-access

        # Wrapper that runs the next function on `source_device`; the string
        # handle goes first, followed by the next function's captures.
        @function.defun_with_attributes(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
            attributes={"experimental_ints_on_device": True})
        def _remote_next_func(string_handle):
            return functional_ops.remote_call(
                target=source_device,
                args=[string_handle] + next_func_concrete.captured_inputs,
                Tout=self._structure._flat_types,  # pylint: disable=protected-access
                f=next_func_concrete)

        self._next_func = _remote_next_func._get_concrete_function_internal()  # pylint: disable=protected-access
        self._next_captured_args = self._next_func.captured_inputs

        # Finalize function: ignores its argument and returns int64 0.
        @function.defun(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
        def _finalize_func(unused_string_handle):
            return array_ops.constant(0, dtypes.int64)

        finalize_func_concrete = _finalize_func._get_concrete_function_internal(
        )  # pylint: disable=protected-access

        # Wrapper that runs the finalize function on `source_device`.
        @function.defun(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
        def _remote_finalize_func(string_handle):
            return functional_ops.remote_call(
                target=source_device,
                args=[string_handle] + finalize_func_concrete.captured_inputs,
                Tout=[dtypes.int64],
                f=finalize_func_concrete)

        self._finalize_func = _remote_finalize_func._get_concrete_function_internal(  # pylint: disable=protected-access
        )
        self._finalize_captured_args = self._finalize_func.captured_inputs

        # Assemble the generator dataset from the three remote wrappers and
        # their captured inputs; the extra keyword arguments come from
        # `dataset_ops.flat_structure(self)`.
        variant_tensor = gen_dataset_ops.generator_dataset(
            self._init_captured_args,
            self._next_captured_args,
            self._finalize_captured_args,
            init_func=self._init_func,
            next_func=self._next_func,
            finalize_func=self._finalize_func,
            **dataset_ops.flat_structure(self))
        super(_PerDeviceGenerator, self).__init__(variant_tensor)
  def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id,
               source_device, target_device, output_shapes, output_types,
               output_classes):
    """Prepares the init/next/finalize functions for one multi-device iterator shard.

    This variant uses `function.Defun` and stores the decorated functions
    directly. It only records the functions and their captured inputs on
    `self`; it does not build a dataset variant tensor or call a base-class
    initializer within this method.

    Args:
      shard_num: Forwarded as `shard_num` to
        `multi_device_iterator_get_next_from_shard`.
      multi_device_iterator_resource: Resource that is converted to a string
        handle and re-opened inside the next function.
      incarnation_id: Forwarded as `incarnation_id` to
        `multi_device_iterator_get_next_from_shard`.
      source_device: Used as the `target` of every `remote_call`.
      target_device: Stored as `self._target_device`; not otherwise used in
        this method.
      output_shapes: Element shapes; stored and flattened to dense shapes.
      output_types: Element types; stored and flattened to dense types.
      output_classes: Element classes; used when converting shapes and types
        to their dense equivalents.
    """
    self._target_device = target_device
    self._output_types = output_types
    self._output_shapes = output_shapes
    self._output_classes = output_classes
    # Flat lists of dense shapes/types, used as the output signature of the
    # iterator ops and of the remote next call below.
    self._flat_output_shapes = nest.flatten(
        sparse.as_dense_shapes(self._output_shapes, self._output_classes))
    self._flat_output_types = nest.flatten(
        sparse.as_dense_types(self._output_types, self._output_classes))

    # Serialize the iterator resource into a string handle.
    multi_device_iterator_string_handle = (
        gen_dataset_ops.multi_device_iterator_to_string_handle(
            multi_device_iterator_resource))

    # Init function: yields the string handle captured from this scope.
    @function.Defun()
    def _init_func():
      return multi_device_iterator_string_handle

    # Wrapper that runs the init function on `source_device`.
    @function.Defun()
    def _remote_init_func():
      return functional_ops.remote_call(
          target=source_device,
          args=_init_func.captured_inputs,
          Tout=[dtypes.string],
          f=_init_func)

    self._init_func = _remote_init_func
    self._init_captured_args = _remote_init_func.captured_inputs

    # Next function: re-opens the iterator from the string handle and fetches
    # the next element for shard `shard_num`.
    @function.Defun(dtypes.string)
    def _next_func(string_handle):
      multi_device_iterator = (
          gen_dataset_ops.multi_device_iterator_from_string_handle(
              string_handle=string_handle,
              output_types=self._flat_output_types,
              output_shapes=self._flat_output_shapes))
      return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
          multi_device_iterator=multi_device_iterator,
          shard_num=shard_num,
          incarnation_id=incarnation_id,
          output_types=self._flat_output_types,
          output_shapes=self._flat_output_shapes)

    # Wrapper that runs the next function on `source_device`; the string
    # handle goes first, followed by the next function's captures.
    @function.Defun(dtypes.string, experimental_ints_on_device=True)
    def _remote_next_func(string_handle):
      return functional_ops.remote_call(
          target=source_device,
          args=[string_handle] + _next_func.captured_inputs,
          Tout=self._flat_output_types,
          f=_next_func)

    self._next_func = _remote_next_func
    self._next_captured_args = _remote_next_func.captured_inputs

    # Finalize function: ignores its argument and returns int64 0.
    @function.Defun(dtypes.string)
    def _finalize_func(unused_string_handle):
      return array_ops.constant(0, dtypes.int64)

    # Wrapper that runs the finalize function on `source_device`.
    @function.Defun(dtypes.string)
    def _remote_finalize_func(string_handle):
      return functional_ops.remote_call(
          target=source_device,
          args=[string_handle] + _finalize_func.captured_inputs,
          Tout=[dtypes.int64],
          f=_finalize_func)

    self._finalize_func = _remote_finalize_func
    self._finalize_captured_args = _remote_finalize_func.captured_inputs
  def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id,
               source_device, element_structure):
    """Builds a generator dataset that pulls one shard from a multi-device iterator.

    Traces init, next and finalize functions, wraps each one in a
    `functional_ops.remote_call` targeting `source_device`, and feeds the
    wrappers plus their captured inputs to `gen_dataset_ops.generator_dataset`.
    The resulting variant tensor is passed to the base-class initializer.

    Args:
      shard_num: Forwarded as `shard_num` to
        `multi_device_iterator_get_next_from_shard`.
      multi_device_iterator_resource: Resource that is converted to a string
        handle and re-opened inside the next function.
      incarnation_id: Forwarded as `incarnation_id` to
        `multi_device_iterator_get_next_from_shard`.
      source_device: Used as the `target` of every `remote_call`.
      element_structure: Stored as `self._structure`; its `_flat_types` and
        `_flat_shapes` give the output signature of the iterator ops.
    """
    self._structure = element_structure

    # Serialize the iterator resource into a string handle.
    multi_device_iterator_string_handle = (
        gen_dataset_ops.multi_device_iterator_to_string_handle(
            multi_device_iterator_resource))

    # Init function: yields the string handle captured from this scope.
    @function.defun()
    def _init_func():
      return multi_device_iterator_string_handle

    init_func_concrete = _init_func._get_concrete_function_internal()  # pylint: disable=protected-access

    # Wrapper that runs the init function on `source_device`.
    @function.defun()
    def _remote_init_func():
      return functional_ops.remote_call(
          target=source_device,
          args=init_func_concrete.captured_inputs,
          Tout=[dtypes.string],
          f=init_func_concrete)

    self._init_func = _remote_init_func._get_concrete_function_internal()  # pylint: disable=protected-access
    self._init_captured_args = self._init_func.captured_inputs

    # Next function: takes a scalar string handle, re-opens the iterator from
    # it and fetches the next element for shard `shard_num`.
    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _next_func(string_handle):
      # pylint: disable=protected-access
      multi_device_iterator = (
          gen_dataset_ops.multi_device_iterator_from_string_handle(
              string_handle=string_handle,
              output_types=self._structure._flat_types,
              output_shapes=self._structure._flat_shapes))
      return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
          multi_device_iterator=multi_device_iterator,
          shard_num=shard_num,
          incarnation_id=incarnation_id,
          output_types=self._structure._flat_types,
          output_shapes=self._structure._flat_shapes)

    next_func_concrete = _next_func._get_concrete_function_internal()  # pylint: disable=protected-access

    # Wrapper that runs the next function on `source_device`; the string
    # handle goes first, followed by the next function's captures.
    @function.defun_with_attributes(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        attributes={"experimental_ints_on_device": True})
    def _remote_next_func(string_handle):
      return functional_ops.remote_call(
          target=source_device,
          args=[string_handle] + next_func_concrete.captured_inputs,
          Tout=self._structure._flat_types,  # pylint: disable=protected-access
          f=next_func_concrete)

    self._next_func = _remote_next_func._get_concrete_function_internal()  # pylint: disable=protected-access
    self._next_captured_args = self._next_func.captured_inputs

    # Finalize function: ignores its argument and returns int64 0.
    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _finalize_func(unused_string_handle):
      return array_ops.constant(0, dtypes.int64)

    finalize_func_concrete = _finalize_func._get_concrete_function_internal()  # pylint: disable=protected-access

    # Wrapper that runs the finalize function on `source_device`.
    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _remote_finalize_func(string_handle):
      return functional_ops.remote_call(
          target=source_device,
          args=[string_handle] + finalize_func_concrete.captured_inputs,
          Tout=[dtypes.int64],
          f=finalize_func_concrete)

    self._finalize_func = _remote_finalize_func._get_concrete_function_internal(  # pylint: disable=protected-access
    )
    self._finalize_captured_args = self._finalize_func.captured_inputs

    # Assemble the generator dataset from the three remote wrappers and their
    # captured inputs; the extra keyword arguments come from
    # `dataset_ops.flat_structure(self)`.
    variant_tensor = gen_dataset_ops.generator_dataset(
        self._init_captured_args,
        self._next_captured_args,
        self._finalize_captured_args,
        init_func=self._init_func,
        next_func=self._next_func,
        finalize_func=self._finalize_func,
        **dataset_ops.flat_structure(self))
    super(_PerDeviceGenerator, self).__init__(variant_tensor)
    def __init__(self, shard_num, multi_device_iterator_resource,
                 incarnation_id, source_device, element_spec,
                 iterator_is_anonymous):
        """Builds a generator dataset that pulls one shard from a multi-device iterator.

        Traces init, next and finalize functions, wraps each one in a
        `functional_ops.remote_call` targeting `source_device`, and feeds the
        wrappers plus their captured inputs to
        `gen_dataset_ops.generator_dataset`. The resulting variant tensor is
        passed to the base-class initializer.

        Args:
          shard_num: Forwarded as `shard_num` to
            `multi_device_iterator_get_next_from_shard`.
          multi_device_iterator_resource: Resource that is converted to a
            string handle and re-opened inside the next function.
          incarnation_id: Forwarded as `incarnation_id` to
            `multi_device_iterator_get_next_from_shard`; its position among
            the next-function arguments is stored in
            `self._incarnation_id_index`.
          source_device: Used as the `target` of every `remote_call`.
          element_spec: Structure from which the flat tensor types, shapes
            and full-type information of the produced elements are derived.
          iterator_is_anonymous: When truthy, the iterator resource itself is
            appended to the argument list passed for the next function.
        """
        self._element_spec = element_spec

        # Serialize the iterator resource into a string handle.
        multi_device_iterator_string_handle = (
            gen_dataset_ops.multi_device_iterator_to_string_handle(
                multi_device_iterator_resource))

        # Init function: yields the string handle captured from this scope.
        # TODO(b/124254153): Enable autograph once the overhead is low enough.
        @function.defun(autograph=False)  # Pure graph code.
        def _init_func():
            return multi_device_iterator_string_handle

        init_func_concrete = _init_func.get_concrete_function()

        # Wrapper that runs the init function on `source_device`.
        # TODO(b/124254153): Enable autograph once the overhead is low enough.
        @function.defun(autograph=False)  # Pure graph code.
        def _remote_init_func():
            return functional_ops.remote_call(
                target=source_device,
                args=init_func_concrete.captured_inputs,
                Tout=[dtypes.string],
                f=init_func_concrete)

        self._init_func = _remote_init_func.get_concrete_function()
        self._init_captured_args = self._init_func.captured_inputs

        # Next function: takes a scalar string handle, re-opens the iterator
        # from it and fetches the next element for shard `shard_num`.
        # TODO(b/124254153): Enable autograph once the overhead is low enough.
        @function.defun(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
            autograph=False)  # Pure graph code.
        def _next_func(string_handle):
            # pylint: disable=protected-access
            multi_device_iterator = (
                gen_dataset_ops.multi_device_iterator_from_string_handle(
                    string_handle=string_handle,
                    output_types=structure.get_flat_tensor_types(
                        self._element_spec),
                    output_shapes=structure.get_flat_tensor_shapes(
                        self._element_spec)))
            return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
                multi_device_iterator=multi_device_iterator,
                shard_num=shard_num,
                incarnation_id=incarnation_id,
                output_types=structure.get_flat_tensor_types(
                    self._element_spec),
                output_shapes=structure.get_flat_tensor_shapes(
                    self._element_spec))

        next_func_concrete = _next_func.get_concrete_function()

        # Wrapper that runs the next function on `source_device` and then
        # attaches full-type information to the op producing each output.
        # TODO(b/124254153): Enable autograph once the overhead is low enough.
        @function.defun_with_attributes(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
            attributes={"experimental_ints_on_device": True},
            autograph=False)  # Pure graph code.
        def _remote_next_func(string_handle):
            return_values = functional_ops.remote_call(
                target=source_device,
                args=[string_handle] + next_func_concrete.captured_inputs,
                Tout=structure.get_flat_tensor_types(self._element_spec),
                f=next_func_concrete)
            # Add full type information to the graph so that the RemoteCall op
            # can determine for each of its outputs whether or not they are ragged
            # tensors (or other types that use variants) that contain strings
            # (or other host memory types). Then RemoteCall can
            # appropriately set AllocatorAttributes to control copies so
            # strings/host memory types stay on CPU.
            fulltype_list = type_utils.fulltypes_for_flat_tensors(
                self._element_spec)
            fulltype = type_utils.fulltype_list_to_product(fulltype_list)
            # The same product full type is set on the op behind every output.
            for return_value in return_values:
                return_value.op.experimental_set_type(fulltype)
            return return_values

        self._next_func = _remote_next_func.get_concrete_function()
        self._next_captured_args = self._next_func.captured_inputs

        # For anonymous iterators, pass the resource itself as an additional
        # trailing argument (a new list is built; the function's own
        # captured-input list is not mutated).
        # NOTE(review): presumably this keeps the anonymous resource
        # referenced by the dataset — confirm the intent.
        if iterator_is_anonymous:
            self._next_captured_args = self._next_captured_args + [
                multi_device_iterator_resource
            ]

        # Record where `incarnation_id` sits among the next-function
        # arguments (matched by identity); -1 means it was not found.
        self._incarnation_id_index = -1
        for i, arg in enumerate(self._next_captured_args):
            if arg is incarnation_id:
                self._incarnation_id_index = i

        # Finalize function: ignores its argument and returns int64 0.
        # TODO(b/124254153): Enable autograph once the overhead is low enough.
        @function.defun(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
            autograph=False)  # Pure graph code.
        def _finalize_func(unused_string_handle):
            return array_ops.constant(0, dtypes.int64)

        finalize_func_concrete = _finalize_func.get_concrete_function()

        # Wrapper that runs the finalize function on `source_device`.
        # TODO(b/124254153): Enable autograph once the overhead is low enough.
        @function.defun(
            input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
            autograph=False)  # Pure graph code.
        def _remote_finalize_func(string_handle):
            return functional_ops.remote_call(
                target=source_device,
                args=[string_handle] + finalize_func_concrete.captured_inputs,
                Tout=[dtypes.int64],
                f=finalize_func_concrete)

        self._finalize_func = _remote_finalize_func.get_concrete_function()
        self._finalize_captured_args = self._finalize_func.captured_inputs

        # Assemble the generator dataset from the three remote wrappers and
        # their argument lists. `self._flat_structure` is not defined in this
        # method; it is expected to be provided by the class.
        variant_tensor = gen_dataset_ops.generator_dataset(
            self._init_captured_args,
            self._next_captured_args,
            self._finalize_captured_args,
            init_func=self._init_func,
            next_func=self._next_func,
            finalize_func=self._finalize_func,
            **self._flat_structure)
        super(_PerDeviceGenerator, self).__init__(variant_tensor)
  def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id,
               source_device, target_device, output_shapes, output_types,
               output_classes):
    """Prepares the init/next/finalize functions for one multi-device iterator shard.

    Traces init, next and finalize functions with `function.defun`, wraps each
    one in a `functional_ops.remote_call` targeting `source_device`, and
    stores the wrappers' concrete functions and captured inputs on `self`.
    No dataset variant tensor is built and no base-class initializer is
    called within the visible part of this method.

    Args:
      shard_num: Forwarded as `shard_num` to
        `multi_device_iterator_get_next_from_shard`.
      multi_device_iterator_resource: Resource that is converted to a string
        handle and re-opened inside the next function.
      incarnation_id: Forwarded as `incarnation_id` to
        `multi_device_iterator_get_next_from_shard`.
      source_device: Used as the `target` of every `remote_call`.
      target_device: Stored as `self._target_device`; not otherwise used in
        this method.
      output_shapes: Element shapes; stored and flattened to dense shapes.
      output_types: Element types; stored and flattened to dense types.
      output_classes: Element classes; used when converting shapes and types
        to their dense equivalents.
    """
    self._target_device = target_device
    self._output_types = output_types
    self._output_shapes = output_shapes
    self._output_classes = output_classes
    # Flat lists of dense shapes/types, used as the output signature of the
    # iterator ops and of the remote next call below.
    self._flat_output_shapes = nest.flatten(
        sparse.as_dense_shapes(self._output_shapes, self._output_classes))
    self._flat_output_types = nest.flatten(
        sparse.as_dense_types(self._output_types, self._output_classes))

    # Serialize the iterator resource into a string handle.
    multi_device_iterator_string_handle = (
        gen_dataset_ops.multi_device_iterator_to_string_handle(
            multi_device_iterator_resource))

    # Init function: yields the string handle captured from this scope.
    @function.defun()
    def _init_func():
      return multi_device_iterator_string_handle

    init_func_concrete = _init_func.get_concrete_function()
    # Wrapper that runs the init function on `source_device`.
    @function.defun()
    def _remote_init_func():
      return functional_ops.remote_call(
          target=source_device,
          args=init_func_concrete.captured_inputs,
          Tout=[dtypes.string],
          f=init_func_concrete)

    self._init_func = _remote_init_func.get_concrete_function()
    self._init_captured_args = self._init_func.captured_inputs

    # Next function: takes a scalar string handle, re-opens the iterator from
    # it and fetches the next element for shard `shard_num`.
    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _next_func(string_handle):
      multi_device_iterator = (
          gen_dataset_ops.multi_device_iterator_from_string_handle(
              string_handle=string_handle,
              output_types=self._flat_output_types,
              output_shapes=self._flat_output_shapes))
      return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
          multi_device_iterator=multi_device_iterator,
          shard_num=shard_num,
          incarnation_id=incarnation_id,
          output_types=self._flat_output_types,
          output_shapes=self._flat_output_shapes)

    next_func_concrete = _next_func.get_concrete_function()
    # Wrapper that runs the next function on `source_device`; the string
    # handle goes first, followed by the next function's captures.
    @function.defun_with_attributes(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        attributes={"experimental_ints_on_device": True})
    def _remote_next_func(string_handle):
      return functional_ops.remote_call(
          target=source_device,
          args=[string_handle] +
          next_func_concrete.captured_inputs,
          Tout=self._flat_output_types,
          f=next_func_concrete)

    self._next_func = _remote_next_func.get_concrete_function()
    self._next_captured_args = self._next_func.captured_inputs

    # Finalize function: ignores its argument and returns int64 0.
    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _finalize_func(unused_string_handle):
      return array_ops.constant(0, dtypes.int64)

    finalize_func_concrete = _finalize_func.get_concrete_function()
    # Wrapper that runs the finalize function on `source_device`.
    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _remote_finalize_func(string_handle):
      return functional_ops.remote_call(
          target=source_device,
          args=[string_handle] +
          finalize_func_concrete.captured_inputs,
          Tout=[dtypes.int64],
          f=finalize_func_concrete)

    self._finalize_func = _remote_finalize_func.get_concrete_function()
    self._finalize_captured_args = self._finalize_func.captured_inputs