def __init__(self, per_device_dataset, incarnation_id):
  """Builds a generator dataset that reuses another dataset's functions.

  The init/next/finalize functions and their captured arguments are taken
  from `per_device_dataset`; only the incarnation id slot of the next-function
  captured arguments is replaced.

  Args:
    per_device_dataset: Dataset providing `element_spec`, `_init_func`,
      `_next_func`, `_next_captured_args`, `_incarnation_id_index`,
      `_finalize_func` and `_finalize_captured_args`.
    incarnation_id: Value written into the next-function captured arguments
      at index `per_device_dataset._incarnation_id_index`.
  """
  # pylint: disable=protected-access
  self._element_spec = per_device_dataset.element_spec
  self._init_func = per_device_dataset._init_func
  self._init_captured_args = self._init_func.captured_inputs

  self._next_func = per_device_dataset._next_func
  # Work on a copy: assigning into the prototype's own list would silently
  # change `per_device_dataset._next_captured_args` for every later user of
  # that object.
  self._next_captured_args = list(per_device_dataset._next_captured_args)
  # The captured arguments to the next_func are string_handle, incarnation_id.
  # We update the incarnation id to the new one.
  # NOTE(review): a negative `_incarnation_id_index` would address the list
  # from its end -- confirm the prototype always records a valid index.
  self._next_captured_args[
      per_device_dataset._incarnation_id_index] = incarnation_id

  self._finalize_func = per_device_dataset._finalize_func
  self._finalize_captured_args = per_device_dataset._finalize_captured_args

  variant_tensor = gen_dataset_ops.generator_dataset(
      self._init_captured_args,
      self._next_captured_args,
      self._finalize_captured_args,
      init_func=self._init_func,
      next_func=self._next_func,
      finalize_func=self._finalize_func,
      **self._flat_structure)
  super(_ReincarnatedPerDeviceGenerator, self).__init__(variant_tensor)
def __init__(self, per_device_dataset, incarnation_id):
  """Builds a generator dataset that reuses another dataset's functions.

  The init/next/finalize functions and their captured arguments are taken
  from `per_device_dataset`; only the incarnation id slot of the next-function
  captured arguments is replaced.

  Args:
    per_device_dataset: Dataset providing `_structure`, `_init_func`,
      `_next_func`, `_next_captured_args`, `_incarnation_id_index`,
      `_finalize_func` and `_finalize_captured_args`.
    incarnation_id: Value written into the next-function captured arguments
      at index `per_device_dataset._incarnation_id_index`.
  """
  # pylint: disable=protected-access
  self._structure = per_device_dataset._structure
  self._init_func = per_device_dataset._init_func
  self._init_captured_args = self._init_func.captured_inputs

  self._next_func = per_device_dataset._next_func
  # Work on a copy: assigning into the prototype's own list would silently
  # change `per_device_dataset._next_captured_args` for every later user of
  # that object.
  self._next_captured_args = list(per_device_dataset._next_captured_args)
  # The captured arguments to the next_func are string_handle, incarnation_id.
  # We update the incarnation id to the new one.
  # NOTE(review): a negative `_incarnation_id_index` would address the list
  # from its end -- confirm the prototype always records a valid index.
  self._next_captured_args[
      per_device_dataset._incarnation_id_index] = incarnation_id

  self._finalize_func = per_device_dataset._finalize_func
  self._finalize_captured_args = per_device_dataset._finalize_captured_args

  variant_tensor = gen_dataset_ops.generator_dataset(
      self._init_captured_args,
      self._next_captured_args,
      self._finalize_captured_args,
      init_func=self._init_func,
      next_func=self._next_func,
      finalize_func=self._finalize_func,
      **dataset_ops.flat_structure(self))
  super(_ReincarnatedPerDeviceGenerator, self).__init__(variant_tensor)
def _as_variant_tensor(self):
  """Returns the generator dataset variant built under the target device.

  The op receives the stored init/next/finalize functions, their captured
  arguments, and the keyword arguments from `dataset_ops.flat_structure(self)`.
  """
  target_scope = ops.device(self._target_device)
  with target_scope:
    generated_variant = gen_dataset_ops.generator_dataset(
        self._init_captured_args,
        self._next_captured_args,
        self._finalize_captured_args,
        init_func=self._init_func,
        next_func=self._next_func,
        finalize_func=self._finalize_func,
        **dataset_ops.flat_structure(self))
  return generated_variant
def _as_variant_tensor(self):
  """Returns the generator dataset variant built under the target device.

  The op receives the stored init/next/finalize functions, their captured
  arguments, and this dataset's flat output types and shapes.
  """
  target_scope = ops.device(self._target_device)
  with target_scope:
    generated_variant = gen_dataset_ops.generator_dataset(
        self._init_captured_args,
        self._next_captured_args,
        self._finalize_captured_args,
        init_func=self._init_func,
        next_func=self._next_func,
        finalize_func=self._finalize_func,
        output_types=self._flat_output_types,
        output_shapes=self._flat_output_shapes)
  return generated_variant
def _as_variant_tensor(self):
  """Returns the generator dataset variant built under the target device.

  The op receives the stored init/next/finalize functions, their captured
  arguments, and this dataset's flat output types and shapes.
  """
  target_scope = ops.device(self._target_device)
  with target_scope:
    generated_variant = gen_dataset_ops.generator_dataset(
        self._init_captured_args,
        self._next_captured_args,
        self._finalize_captured_args,
        init_func=self._init_func,
        next_func=self._next_func,
        finalize_func=self._finalize_func,
        output_types=self._flat_output_types,
        output_shapes=self._flat_output_shapes)
  return generated_variant
def __init__(self, input_dataset, target_device, source_device="/cpu:0"):
  """Constructs a _CopyToDeviceDataset.

  Three function pairs (init / next / finalize) are traced. Each local
  function is wrapped in a `_remote_*` function that invokes it through
  `functional_ops.remote_call` with `target` set to the source device. The
  resulting concrete functions and their captured inputs are handed to
  `gen_dataset_ops.generator_dataset`, which is created under
  `target_device`.

  Args:
    input_dataset: `Dataset` to be copied
    target_device: The name of the device to which elements would be copied.
    source_device: Device where input_dataset would be placed.
  """
  self._input_dataset = input_dataset._apply_debug_options()  # pylint: disable=protected-access
  self._target_device = target_device
  spec = framework_device.DeviceSpec().from_string(self._target_device)
  # Records whether the parsed target device type is "GPU".
  self._is_gpu_target = (spec.device_type == "GPU")
  # The source device is kept both as the raw string (used with `ops.device`)
  # and as a tensor (used as the `remote_call` target).
  self._source_device_string = source_device
  self._source_device = ops.convert_to_tensor(source_device)

  wrap_ds_variant = gen_dataset_ops.wrap_dataset_variant(
      self._input_dataset._variant_tensor)  # pylint: disable=protected-access

  @function.defun()
  def _init_func():
    """Creates an iterator for the input dataset.

    Returns:
      A `string` tensor that encapsulates the iterator created.
    """
    ds_variant = gen_dataset_ops.unwrap_dataset_variant(wrap_ds_variant)
    resource = gen_dataset_ops.anonymous_iterator(
        **self._input_dataset._flat_structure)  # pylint: disable=protected-access
    # The string handle is only produced after make_iterator has run.
    with ops.control_dependencies(
        [gen_dataset_ops.make_iterator(ds_variant, resource)]):
      return gen_dataset_ops.iterator_to_string_handle(resource)

  init_func_concrete = _init_func._get_concrete_function_internal()  # pylint: disable=protected-access

  @function.defun()
  def _remote_init_func():
    # Calls init_func_concrete via remote_call, forwarding its captures.
    return functional_ops.remote_call(
        target=self._source_device,
        args=init_func_concrete.captured_inputs,
        Tout=[dtypes.string],
        f=init_func_concrete)

  self._init_func = _remote_init_func._get_concrete_function_internal()  # pylint: disable=protected-access
  self._init_captured_args = self._init_func.captured_inputs

  @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
  def _next_func(string_handle):
    """Calls get_next for created iterator.

    Args:
      string_handle: An iterator string handle created by _init_func

    Returns:
      The elements generated from `input_dataset`
    """
    with ops.device(self._source_device_string):
      iterator = iterator_ops.Iterator.from_string_handle(
          string_handle,
          dataset_ops.get_legacy_output_types(self),
          dataset_ops.get_legacy_output_shapes(self),
          dataset_ops.get_legacy_output_classes(self))
    # NOTE(review): the flattened source does not show whether this return
    # sits inside the device scope above; it is placed outside -- confirm.
    return structure.to_tensor_list(self.element_spec, iterator.get_next())

  next_func_concrete = _next_func._get_concrete_function_internal()  # pylint: disable=protected-access

  @function.defun_with_attributes(
      input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
      attributes={"experimental_ints_on_device": True})
  def _remote_next_func(string_handle):
    # Calls next_func_concrete via remote_call; the string handle is passed
    # first, followed by the wrapped function's captures.
    return functional_ops.remote_call(
        target=self._source_device,
        args=[string_handle] + next_func_concrete.captured_inputs,
        Tout=self._input_dataset._flat_types,  # pylint: disable=protected-access
        f=next_func_concrete)

  self._next_func = _remote_next_func._get_concrete_function_internal()  # pylint: disable=protected-access
  self._next_captured_args = self._next_func.captured_inputs

  @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
  def _finalize_func(string_handle):
    """Destroys the iterator resource created.

    Args:
      string_handle: An iterator string handle created by _init_func

    Returns:
      Tensor constant 0
    """
    iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
        string_handle,
        **self._input_dataset._flat_structure)  # pylint: disable=protected-access
    # The constant is only produced after the destroy op has run; lookup
    # errors from the destroy op are ignored.
    with ops.control_dependencies([
        resource_variable_ops.destroy_resource_op(
            iterator_resource, ignore_lookup_error=True)]):
      return array_ops.constant(0, dtypes.int64)

  finalize_func_concrete = _finalize_func._get_concrete_function_internal()  # pylint: disable=protected-access

  @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
  def _remote_finalize_func(string_handle):
    # Calls finalize_func_concrete via remote_call.
    return functional_ops.remote_call(
        target=self._source_device,
        args=[string_handle] + finalize_func_concrete.captured_inputs,
        Tout=[dtypes.int64],
        f=finalize_func_concrete)

  self._finalize_func = _remote_finalize_func._get_concrete_function_internal(  # pylint: disable=protected-access
  )
  self._finalize_captured_args = self._finalize_func.captured_inputs

  # Register the three remote functions with the current default graph.
  g = ops.get_default_graph()
  self._init_func.add_to_graph(g)
  self._next_func.add_to_graph(g)
  self._finalize_func.add_to_graph(g)
  # pylint: enable=protected-access

  with ops.device(self._target_device):
    variant_tensor = gen_dataset_ops.generator_dataset(
        self._init_captured_args,
        self._next_captured_args,
        self._finalize_captured_args,
        init_func=self._init_func,
        next_func=self._next_func,
        finalize_func=self._finalize_func,
        **self._input_dataset._flat_structure)  # pylint: disable=protected-access
  # Note: the original `input_dataset` argument (not the debug-options
  # variant stored on self) is what gets passed to the base class.
  super(_CopyToDeviceDataset, self).__init__(input_dataset, variant_tensor)
def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id,
             source_device, element_spec):
  """Builds a generator dataset that reads one shard of a multi-device iterator.

  Three function pairs (init / next / finalize) are traced. Each local
  function is wrapped in a `_remote_*` function that invokes it through
  `functional_ops.remote_call` with `target=source_device`. The resulting
  concrete functions and their captured inputs are handed to
  `gen_dataset_ops.generator_dataset`.

  Args:
    shard_num: Passed as `shard_num` to
      `multi_device_iterator_get_next_from_shard`.
    multi_device_iterator_resource: Resource converted to the string handle
      that the init function returns.
    incarnation_id: Passed as `incarnation_id` to
      `multi_device_iterator_get_next_from_shard`; its position among the
      next-function captured inputs is recorded in `_incarnation_id_index`.
    source_device: `target` of every `remote_call` built here.
    element_spec: Stored as `self._element_spec`; supplies the flat tensor
      types and shapes used by the next function.
  """
  self._element_spec = element_spec

  multi_device_iterator_string_handle = (
      gen_dataset_ops.multi_device_iterator_to_string_handle(
          multi_device_iterator_resource))

  # TODO(b/124254153): Enable autograph once the overhead is low enough.
  @function.defun(autograph=False)  # Pure graph code.
  def _init_func():
    # Simply returns the string handle computed above.
    return multi_device_iterator_string_handle

  init_func_concrete = _init_func._get_concrete_function_internal()  # pylint: disable=protected-access

  # TODO(b/124254153): Enable autograph once the overhead is low enough.
  @function.defun(autograph=False)  # Pure graph code.
  def _remote_init_func():
    # Calls init_func_concrete via remote_call, forwarding its captures.
    return functional_ops.remote_call(
        target=source_device,
        args=init_func_concrete.captured_inputs,
        Tout=[dtypes.string],
        f=init_func_concrete)

  self._init_func = _remote_init_func._get_concrete_function_internal()  # pylint: disable=protected-access
  self._init_captured_args = self._init_func.captured_inputs

  # TODO(b/124254153): Enable autograph once the overhead is low enough.
  @function.defun(
      input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
      autograph=False)  # Pure graph code.
  def _next_func(string_handle):
    # pylint: disable=protected-access
    # Rebuild the multi-device iterator from its handle, then fetch the next
    # element for this shard.
    multi_device_iterator = (
        gen_dataset_ops.multi_device_iterator_from_string_handle(
            string_handle=string_handle,
            output_types=structure.get_flat_tensor_types(
                self._element_spec),
            output_shapes=structure.get_flat_tensor_shapes(
                self._element_spec)))
    return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
        multi_device_iterator=multi_device_iterator,
        shard_num=shard_num,
        incarnation_id=incarnation_id,
        output_types=structure.get_flat_tensor_types(
            self._element_spec),
        output_shapes=structure.get_flat_tensor_shapes(
            self._element_spec))

  next_func_concrete = _next_func._get_concrete_function_internal()  # pylint: disable=protected-access

  # TODO(b/124254153): Enable autograph once the overhead is low enough.
  @function.defun_with_attributes(
      input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
      attributes={"experimental_ints_on_device": True},
      autograph=False)  # Pure graph code.
  def _remote_next_func(string_handle):
    # Calls next_func_concrete via remote_call; the string handle is passed
    # first, followed by the wrapped function's captures.
    return functional_ops.remote_call(
        target=source_device,
        args=[string_handle] + next_func_concrete.captured_inputs,
        Tout=structure.get_flat_tensor_types(self._element_spec),
        f=next_func_concrete)

  self._next_func = _remote_next_func._get_concrete_function_internal()  # pylint: disable=protected-access
  self._next_captured_args = self._next_func.captured_inputs

  # Locate `incarnation_id` among the captured args by identity; the index
  # stays -1 when no captured arg is that exact object.
  self._incarnation_id_index = -1
  for i, arg in enumerate(self._next_captured_args):
    if arg is incarnation_id:
      self._incarnation_id_index = i

  # TODO(b/124254153): Enable autograph once the overhead is low enough.
  @function.defun(
      input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
      autograph=False)  # Pure graph code.
  def _finalize_func(unused_string_handle):
    # Ignores its argument and returns an int64 constant 0.
    return array_ops.constant(0, dtypes.int64)

  finalize_func_concrete = _finalize_func._get_concrete_function_internal(
  )  # pylint: disable=protected-access

  # TODO(b/124254153): Enable autograph once the overhead is low enough.
  @function.defun(
      input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
      autograph=False)  # Pure graph code.
  def _remote_finalize_func(string_handle):
    # Calls finalize_func_concrete via remote_call.
    return functional_ops.remote_call(
        target=source_device,
        args=[string_handle] + finalize_func_concrete.captured_inputs,
        Tout=[dtypes.int64],
        f=finalize_func_concrete)

  self._finalize_func = (
      _remote_finalize_func._get_concrete_function_internal())  # pylint: disable=protected-access
  self._finalize_captured_args = self._finalize_func.captured_inputs

  variant_tensor = gen_dataset_ops.generator_dataset(
      self._init_captured_args,
      self._next_captured_args,
      self._finalize_captured_args,
      init_func=self._init_func,
      next_func=self._next_func,
      finalize_func=self._finalize_func,
      **self._flat_structure)
  super(_PerDeviceGenerator, self).__init__(variant_tensor)
def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id,
             source_device, element_structure):
  """Builds a generator dataset that reads one shard of a multi-device iterator.

  Three function pairs (init / next / finalize) are traced. Each local
  function is wrapped in a `_remote_*` function that invokes it through
  `functional_ops.remote_call` with `target=source_device`. The resulting
  concrete functions and their captured inputs are handed to
  `gen_dataset_ops.generator_dataset`.

  Args:
    shard_num: Passed as `shard_num` to
      `multi_device_iterator_get_next_from_shard`.
    multi_device_iterator_resource: Resource converted to the string handle
      that the init function returns.
    incarnation_id: Passed as `incarnation_id` to
      `multi_device_iterator_get_next_from_shard`.
    source_device: `target` of every `remote_call` built here.
    element_structure: Stored as `self._structure`; its `_flat_types` and
      `_flat_shapes` are used by the next function.
  """
  self._structure = element_structure

  multi_device_iterator_string_handle = (
      gen_dataset_ops.multi_device_iterator_to_string_handle(
          multi_device_iterator_resource))

  @function.defun()
  def _init_func():
    # Simply returns the string handle computed above.
    return multi_device_iterator_string_handle

  init_func_concrete = _init_func._get_concrete_function_internal()  # pylint: disable=protected-access

  @function.defun()
  def _remote_init_func():
    # Calls init_func_concrete via remote_call, forwarding its captures.
    return functional_ops.remote_call(
        target=source_device,
        args=init_func_concrete.captured_inputs,
        Tout=[dtypes.string],
        f=init_func_concrete)

  self._init_func = _remote_init_func._get_concrete_function_internal()  # pylint: disable=protected-access
  self._init_captured_args = self._init_func.captured_inputs

  @function.defun(
      input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
  def _next_func(string_handle):
    # pylint: disable=protected-access
    # Rebuild the multi-device iterator from its handle, then fetch the next
    # element for this shard.
    multi_device_iterator = (
        gen_dataset_ops.multi_device_iterator_from_string_handle(
            string_handle=string_handle,
            output_types=self._structure._flat_types,
            output_shapes=self._structure._flat_shapes))
    return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
        multi_device_iterator=multi_device_iterator,
        shard_num=shard_num,
        incarnation_id=incarnation_id,
        output_types=self._structure._flat_types,
        output_shapes=self._structure._flat_shapes)

  next_func_concrete = _next_func._get_concrete_function_internal()  # pylint: disable=protected-access

  @function.defun_with_attributes(
      input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
      attributes={"experimental_ints_on_device": True})
  def _remote_next_func(string_handle):
    # Calls next_func_concrete via remote_call; the string handle is passed
    # first, followed by the wrapped function's captures.
    return functional_ops.remote_call(
        target=source_device,
        args=[string_handle] + next_func_concrete.captured_inputs,
        Tout=self._structure._flat_types,  # pylint: disable=protected-access
        f=next_func_concrete)

  self._next_func = _remote_next_func._get_concrete_function_internal()  # pylint: disable=protected-access
  self._next_captured_args = self._next_func.captured_inputs

  @function.defun(
      input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
  def _finalize_func(unused_string_handle):
    # Ignores its argument and returns an int64 constant 0.
    return array_ops.constant(0, dtypes.int64)

  finalize_func_concrete = _finalize_func._get_concrete_function_internal(
  )  # pylint: disable=protected-access

  @function.defun(
      input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
  def _remote_finalize_func(string_handle):
    # Calls finalize_func_concrete via remote_call.
    return functional_ops.remote_call(
        target=source_device,
        args=[string_handle] + finalize_func_concrete.captured_inputs,
        Tout=[dtypes.int64],
        f=finalize_func_concrete)

  self._finalize_func = _remote_finalize_func._get_concrete_function_internal(  # pylint: disable=protected-access
  )
  self._finalize_captured_args = self._finalize_func.captured_inputs

  variant_tensor = gen_dataset_ops.generator_dataset(
      self._init_captured_args,
      self._next_captured_args,
      self._finalize_captured_args,
      init_func=self._init_func,
      next_func=self._next_func,
      finalize_func=self._finalize_func,
      **dataset_ops.flat_structure(self))
  super(_PerDeviceGenerator, self).__init__(variant_tensor)
def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id,
             source_device, element_structure):
  """Builds a generator dataset that reads one shard of a multi-device iterator.

  Three function pairs (init / next / finalize) are traced. Each local
  function is wrapped in a `_remote_*` function that invokes it through
  `functional_ops.remote_call` with `target=source_device`. The resulting
  concrete functions and their captured inputs are handed to
  `gen_dataset_ops.generator_dataset`.

  Args:
    shard_num: Passed as `shard_num` to
      `multi_device_iterator_get_next_from_shard`.
    multi_device_iterator_resource: Resource converted to the string handle
      that the init function returns.
    incarnation_id: Passed as `incarnation_id` to
      `multi_device_iterator_get_next_from_shard`.
    source_device: `target` of every `remote_call` built here.
    element_structure: Stored as `self._structure`; its `_flat_types` and
      `_flat_shapes` are used by the next function.
  """
  self._structure = element_structure

  multi_device_iterator_string_handle = (
      gen_dataset_ops.multi_device_iterator_to_string_handle(
          multi_device_iterator_resource))

  @function.defun()
  def _init_func():
    # Simply returns the string handle computed above.
    return multi_device_iterator_string_handle

  init_func_concrete = _init_func._get_concrete_function_internal()  # pylint: disable=protected-access

  @function.defun()
  def _remote_init_func():
    # Calls init_func_concrete via remote_call, forwarding its captures.
    return functional_ops.remote_call(
        target=source_device,
        args=init_func_concrete.captured_inputs,
        Tout=[dtypes.string],
        f=init_func_concrete)

  self._init_func = _remote_init_func._get_concrete_function_internal()  # pylint: disable=protected-access
  self._init_captured_args = self._init_func.captured_inputs

  @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
  def _next_func(string_handle):
    # pylint: disable=protected-access
    # Rebuild the multi-device iterator from its handle, then fetch the next
    # element for this shard.
    multi_device_iterator = (
        gen_dataset_ops.multi_device_iterator_from_string_handle(
            string_handle=string_handle,
            output_types=self._structure._flat_types,
            output_shapes=self._structure._flat_shapes))
    return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
        multi_device_iterator=multi_device_iterator,
        shard_num=shard_num,
        incarnation_id=incarnation_id,
        output_types=self._structure._flat_types,
        output_shapes=self._structure._flat_shapes)

  next_func_concrete = _next_func._get_concrete_function_internal()  # pylint: disable=protected-access

  @function.defun_with_attributes(
      input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
      attributes={"experimental_ints_on_device": True})
  def _remote_next_func(string_handle):
    # Calls next_func_concrete via remote_call; the string handle is passed
    # first, followed by the wrapped function's captures.
    return functional_ops.remote_call(
        target=source_device,
        args=[string_handle] + next_func_concrete.captured_inputs,
        Tout=self._structure._flat_types,  # pylint: disable=protected-access
        f=next_func_concrete)

  self._next_func = _remote_next_func._get_concrete_function_internal()  # pylint: disable=protected-access
  self._next_captured_args = self._next_func.captured_inputs

  @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
  def _finalize_func(unused_string_handle):
    # Ignores its argument and returns an int64 constant 0.
    return array_ops.constant(0, dtypes.int64)

  finalize_func_concrete = _finalize_func._get_concrete_function_internal()  # pylint: disable=protected-access

  @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
  def _remote_finalize_func(string_handle):
    # Calls finalize_func_concrete via remote_call.
    return functional_ops.remote_call(
        target=source_device,
        args=[string_handle] + finalize_func_concrete.captured_inputs,
        Tout=[dtypes.int64],
        f=finalize_func_concrete)

  self._finalize_func = _remote_finalize_func._get_concrete_function_internal(  # pylint: disable=protected-access
  )
  self._finalize_captured_args = self._finalize_func.captured_inputs

  variant_tensor = gen_dataset_ops.generator_dataset(
      self._init_captured_args,
      self._next_captured_args,
      self._finalize_captured_args,
      init_func=self._init_func,
      next_func=self._next_func,
      finalize_func=self._finalize_func,
      **dataset_ops.flat_structure(self))
  super(_PerDeviceGenerator, self).__init__(variant_tensor)
def __init__(self, input_dataset, target_device, source_device="/cpu:0"):
  """Constructs a _CopyToDeviceDataset.

  Three function pairs (init / next / finalize) are traced. Each local
  function is wrapped in a `_remote_*` function that invokes it through
  `functional_ops.remote_call` with `target` set to the source device. The
  resulting concrete functions and their captured inputs are handed to
  `gen_dataset_ops.generator_dataset`, which is created under
  `target_device`.

  Args:
    input_dataset: `Dataset` to be copied
    target_device: The name of the device to which elements would be copied.
    source_device: Device where input_dataset would be placed.
  """
  self._input_dataset = input_dataset
  self._target_device = target_device
  spec = framework_device.DeviceSpec().from_string(self._target_device)
  # Records whether the parsed target device type is "GPU".
  self._is_gpu_target = (spec.device_type == "GPU")
  # The source device is kept both as the raw string (used with `ops.device`)
  # and as a tensor (used as the `remote_call` target).
  self._source_device_string = source_device
  self._source_device = ops.convert_to_tensor(source_device)

  wrap_ds_variant = gen_dataset_ops.wrap_dataset_variant(
      self._input_dataset._variant_tensor)  # pylint: disable=protected-access

  @function.defun()
  def _init_func():
    """Creates an iterator for the input dataset.

    Returns:
      A `string` tensor that encapsulates the iterator created.
    """
    ds_variant = gen_dataset_ops.unwrap_dataset_variant(wrap_ds_variant)
    resource = gen_dataset_ops.anonymous_iterator(
        **dataset_ops.flat_structure(self._input_dataset))
    # The string handle is only produced after make_iterator has run.
    with ops.control_dependencies(
        [gen_dataset_ops.make_iterator(ds_variant, resource)]):
      return gen_dataset_ops.iterator_to_string_handle(resource)

  init_func_concrete = _init_func._get_concrete_function_internal()  # pylint: disable=protected-access

  @function.defun()
  def _remote_init_func():
    # Calls init_func_concrete via remote_call, forwarding its captures.
    return functional_ops.remote_call(
        target=self._source_device,
        args=init_func_concrete.captured_inputs,
        Tout=[dtypes.string],
        f=init_func_concrete)

  self._init_func = _remote_init_func._get_concrete_function_internal()  # pylint: disable=protected-access
  self._init_captured_args = self._init_func.captured_inputs

  @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
  def _next_func(string_handle):
    """Calls get_next for created iterator.

    Args:
      string_handle: An iterator string handle created by _init_func

    Returns:
      The elements generated from `input_dataset`
    """
    with ops.device(self._source_device_string):
      iterator = iterator_ops.Iterator.from_string_handle(
          string_handle, self.output_types, self.output_shapes,
          self.output_classes)
    # NOTE(review): the flattened source does not show whether this return
    # sits inside the device scope above; it is placed outside -- confirm.
    return self._element_structure._to_tensor_list(iterator.get_next())  # pylint: disable=protected-access

  next_func_concrete = _next_func._get_concrete_function_internal()  # pylint: disable=protected-access

  @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
  def _remote_next_func(string_handle):
    # Calls next_func_concrete via remote_call; the string handle is passed
    # first, followed by the wrapped function's captures.
    return functional_ops.remote_call(
        target=self._source_device,
        args=[string_handle] + next_func_concrete.captured_inputs,
        Tout=self._input_dataset._element_structure._flat_types,  # pylint: disable=protected-access
        f=next_func_concrete)

  self._next_func = _remote_next_func._get_concrete_function_internal()  # pylint: disable=protected-access
  self._next_captured_args = self._next_func.captured_inputs

  @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
  def _finalize_func(string_handle):
    """Destroys the iterator resource created.

    Args:
      string_handle: An iterator string handle created by _init_func

    Returns:
      Tensor constant 0
    """
    iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
        string_handle,
        **dataset_ops.flat_structure(self._input_dataset))
    # The constant is only produced after the destroy op has run; lookup
    # errors from the destroy op are ignored.
    with ops.control_dependencies([
        resource_variable_ops.destroy_resource_op(
            iterator_resource, ignore_lookup_error=True)]):
      return array_ops.constant(0, dtypes.int64)

  finalize_func_concrete = _finalize_func._get_concrete_function_internal()  # pylint: disable=protected-access

  @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
  def _remote_finalize_func(string_handle):
    # Calls finalize_func_concrete via remote_call.
    return functional_ops.remote_call(
        target=self._source_device,
        args=[string_handle] + finalize_func_concrete.captured_inputs,
        Tout=[dtypes.int64],
        f=finalize_func_concrete)

  self._finalize_func = _remote_finalize_func._get_concrete_function_internal(  # pylint: disable=protected-access
  )
  self._finalize_captured_args = self._finalize_func.captured_inputs

  # Register the three remote functions with the current default graph.
  g = ops.get_default_graph()
  self._init_func.add_to_graph(g)
  self._next_func.add_to_graph(g)
  self._finalize_func.add_to_graph(g)
  # pylint: enable=protected-access

  with ops.device(self._target_device):
    variant_tensor = gen_dataset_ops.generator_dataset(
        self._init_captured_args,
        self._next_captured_args,
        self._finalize_captured_args,
        init_func=self._init_func,
        next_func=self._next_func,
        finalize_func=self._finalize_func,
        **dataset_ops.flat_structure(self._input_dataset))
  super(_CopyToDeviceDataset, self).__init__(input_dataset, variant_tensor)
def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id,
             source_device, element_spec, iterator_is_anonymous):
  """Builds a generator dataset that reads one shard of a multi-device iterator.

  Three function pairs (init / next / finalize) are traced. Each local
  function is wrapped in a `_remote_*` function that invokes it through
  `functional_ops.remote_call` with `target=source_device`. The resulting
  concrete functions and their captured inputs are handed to
  `gen_dataset_ops.generator_dataset`.

  Args:
    shard_num: Passed as `shard_num` to
      `multi_device_iterator_get_next_from_shard`.
    multi_device_iterator_resource: Resource converted to the string handle
      that the init function returns; also appended to the next-function
      captured args when `iterator_is_anonymous` is true.
    incarnation_id: Passed as `incarnation_id` to
      `multi_device_iterator_get_next_from_shard`; its position among the
      next-function captured args is recorded in `_incarnation_id_index`.
    source_device: `target` of every `remote_call` built here.
    element_spec: Stored as `self._element_spec`; supplies the flat tensor
      types, shapes and full-type information used by the next functions.
    iterator_is_anonymous: When true, `multi_device_iterator_resource` is
      appended to `self._next_captured_args`.
  """
  self._element_spec = element_spec

  multi_device_iterator_string_handle = (
      gen_dataset_ops.multi_device_iterator_to_string_handle(
          multi_device_iterator_resource))

  # TODO(b/124254153): Enable autograph once the overhead is low enough.
  @function.defun(autograph=False)  # Pure graph code.
  def _init_func():
    # Simply returns the string handle computed above.
    return multi_device_iterator_string_handle

  init_func_concrete = _init_func.get_concrete_function()

  # TODO(b/124254153): Enable autograph once the overhead is low enough.
  @function.defun(autograph=False)  # Pure graph code.
  def _remote_init_func():
    # Calls init_func_concrete via remote_call, forwarding its captures.
    return functional_ops.remote_call(
        target=source_device,
        args=init_func_concrete.captured_inputs,
        Tout=[dtypes.string],
        f=init_func_concrete)

  self._init_func = _remote_init_func.get_concrete_function()
  self._init_captured_args = self._init_func.captured_inputs

  # TODO(b/124254153): Enable autograph once the overhead is low enough.
  @function.defun(
      input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
      autograph=False)  # Pure graph code.
  def _next_func(string_handle):
    # pylint: disable=protected-access
    # Rebuild the multi-device iterator from its handle, then fetch the next
    # element for this shard.
    multi_device_iterator = (
        gen_dataset_ops.multi_device_iterator_from_string_handle(
            string_handle=string_handle,
            output_types=structure.get_flat_tensor_types(
                self._element_spec),
            output_shapes=structure.get_flat_tensor_shapes(
                self._element_spec)))
    return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
        multi_device_iterator=multi_device_iterator,
        shard_num=shard_num,
        incarnation_id=incarnation_id,
        output_types=structure.get_flat_tensor_types(
            self._element_spec),
        output_shapes=structure.get_flat_tensor_shapes(
            self._element_spec))

  next_func_concrete = _next_func.get_concrete_function()

  # TODO(b/124254153): Enable autograph once the overhead is low enough.
  @function.defun_with_attributes(
      input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
      attributes={"experimental_ints_on_device": True},
      autograph=False)  # Pure graph code.
  def _remote_next_func(string_handle):
    # Calls next_func_concrete via remote_call; the string handle is passed
    # first, followed by the wrapped function's captures.
    return_values = functional_ops.remote_call(
        target=source_device,
        args=[string_handle] + next_func_concrete.captured_inputs,
        Tout=structure.get_flat_tensor_types(self._element_spec),
        f=next_func_concrete)
    # Add full type information to the graph so that the RemoteCall op
    # can determine for each of its outputs whether or not they are ragged
    # tensors (or other types that use variants) that contain strings
    # (or other host memory types). Then RemoteCall can
    # appropriately set AllocatorAttributes to control copies so
    # strings/host memory types stay on CPU.
    fulltype_list = type_utils.fulltypes_for_flat_tensors(
        self._element_spec)
    fulltype = type_utils.fulltype_list_to_product(fulltype_list)
    # The same product type is set on the producing op of each return value.
    for return_value in return_values:
      return_value.op.experimental_set_type(fulltype)
    return return_values

  self._next_func = _remote_next_func.get_concrete_function()
  self._next_captured_args = self._next_func.captured_inputs
  if iterator_is_anonymous:
    # Builds a new list with the iterator resource appended at the end.
    self._next_captured_args = self._next_captured_args + [
        multi_device_iterator_resource
    ]

  # Locate `incarnation_id` among the captured args by identity; the index
  # stays -1 when no captured arg is that exact object.
  self._incarnation_id_index = -1
  for i, arg in enumerate(self._next_captured_args):
    if arg is incarnation_id:
      self._incarnation_id_index = i

  # TODO(b/124254153): Enable autograph once the overhead is low enough.
  @function.defun(
      input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
      autograph=False)  # Pure graph code.
  def _finalize_func(unused_string_handle):
    # Ignores its argument and returns an int64 constant 0.
    return array_ops.constant(0, dtypes.int64)

  finalize_func_concrete = _finalize_func.get_concrete_function()

  # TODO(b/124254153): Enable autograph once the overhead is low enough.
  @function.defun(
      input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
      autograph=False)  # Pure graph code.
  def _remote_finalize_func(string_handle):
    # Calls finalize_func_concrete via remote_call.
    return functional_ops.remote_call(
        target=source_device,
        args=[string_handle] + finalize_func_concrete.captured_inputs,
        Tout=[dtypes.int64],
        f=finalize_func_concrete)

  self._finalize_func = _remote_finalize_func.get_concrete_function()
  self._finalize_captured_args = self._finalize_func.captured_inputs

  variant_tensor = gen_dataset_ops.generator_dataset(
      self._init_captured_args,
      self._next_captured_args,
      self._finalize_captured_args,
      init_func=self._init_func,
      next_func=self._next_func,
      finalize_func=self._finalize_func,
      **self._flat_structure)
  super(_PerDeviceGenerator, self).__init__(variant_tensor)