def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id,
             source_device, element_spec):
  """Builds a generator dataset that reads one shard of a multi-device iterator.

  The init/next/finalize functions handed to `generator_dataset` each issue a
  `functional_ops.remote_call` to `source_device`. On that device the
  multi-device iterator is recovered from its string handle and the next
  element of shard `shard_num` is fetched.

  Args:
    shard_num: shard index passed to
      `multi_device_iterator_get_next_from_shard`.
    multi_device_iterator_resource: resource whose string handle is returned
      by the init function.
    incarnation_id: passed to `multi_device_iterator_get_next_from_shard`.
      Its position among the next function's captured inputs is recorded in
      `self._incarnation_id_index`.
    source_device: `target` of every `remote_call` built here.
    element_spec: structure of one element. Its flattened tensor types and
      shapes form the output signature of the iterator ops built here.
  """
  self._element_spec = element_spec

  # The flat signature depends only on `element_spec`. Compute it once
  # instead of at every use inside the traced functions below. This is pure
  # Python and adds no ops to any graph.
  flat_types = structure.get_flat_tensor_types(self._element_spec)
  flat_shapes = structure.get_flat_tensor_shapes(self._element_spec)

  multi_device_iterator_string_handle = (
      gen_dataset_ops.multi_device_iterator_to_string_handle(
          multi_device_iterator_resource))

  # TODO(b/124254153): Enable autograph once the overhead is low enough.
  @function.defun(autograph=False)  # Pure graph code.
  def _init_func():
    return multi_device_iterator_string_handle

  init_func_concrete = _init_func._get_concrete_function_internal()  # pylint: disable=protected-access

  # TODO(b/124254153): Enable autograph once the overhead is low enough.
  @function.defun(autograph=False)  # Pure graph code.
  def _remote_init_func():
    return functional_ops.remote_call(
        target=source_device,
        args=init_func_concrete.captured_inputs,
        Tout=[dtypes.string],
        f=init_func_concrete)

  self._init_func = _remote_init_func._get_concrete_function_internal()  # pylint: disable=protected-access
  self._init_captured_args = self._init_func.captured_inputs

  # TODO(b/124254153): Enable autograph once the overhead is low enough.
  @function.defun(
      input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
      autograph=False)  # Pure graph code.
  def _next_func(string_handle):
    # Re-open the multi-device iterator from its handle, then read from this
    # generator's shard.
    multi_device_iterator = (
        gen_dataset_ops.multi_device_iterator_from_string_handle(
            string_handle=string_handle,
            output_types=flat_types,
            output_shapes=flat_shapes))
    return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
        multi_device_iterator=multi_device_iterator,
        shard_num=shard_num,
        incarnation_id=incarnation_id,
        output_types=flat_types,
        output_shapes=flat_shapes)

  next_func_concrete = _next_func._get_concrete_function_internal()  # pylint: disable=protected-access

  # NOTE(review): "experimental_ints_on_device" presumably keeps int32
  # outputs in device memory rather than host memory -- TODO confirm.
  # TODO(b/124254153): Enable autograph once the overhead is low enough.
  @function.defun_with_attributes(
      input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
      attributes={"experimental_ints_on_device": True},
      autograph=False)  # Pure graph code.
  def _remote_next_func(string_handle):
    return functional_ops.remote_call(
        target=source_device,
        args=[string_handle] + next_func_concrete.captured_inputs,
        Tout=flat_types,
        f=next_func_concrete)

  self._next_func = _remote_next_func._get_concrete_function_internal()  # pylint: disable=protected-access
  self._next_captured_args = self._next_func.captured_inputs

  # Position of `incarnation_id` (matched by identity) among the next
  # function's captured args, or -1 if absent. If it occurred more than once
  # the last position would win. Presumably used elsewhere to substitute a
  # different incarnation id -- confirm with the readers of this attribute.
  self._incarnation_id_index = -1
  for i, arg in enumerate(self._next_captured_args):
    if arg is incarnation_id:
      self._incarnation_id_index = i

  # TODO(b/124254153): Enable autograph once the overhead is low enough.
  @function.defun(
      input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
      autograph=False)  # Pure graph code.
  def _finalize_func(unused_string_handle):
    # Nothing to release per shard; return a dummy int64 scalar.
    return array_ops.constant(0, dtypes.int64)

  finalize_func_concrete = _finalize_func._get_concrete_function_internal()  # pylint: disable=protected-access

  # TODO(b/124254153): Enable autograph once the overhead is low enough.
  @function.defun(
      input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
      autograph=False)  # Pure graph code.
  def _remote_finalize_func(string_handle):
    return functional_ops.remote_call(
        target=source_device,
        args=[string_handle] + finalize_func_concrete.captured_inputs,
        Tout=[dtypes.int64],
        f=finalize_func_concrete)

  self._finalize_func = (
      _remote_finalize_func._get_concrete_function_internal())  # pylint: disable=protected-access
  self._finalize_captured_args = self._finalize_func.captured_inputs

  variant_tensor = gen_dataset_ops.generator_dataset(
      self._init_captured_args,
      self._next_captured_args,
      self._finalize_captured_args,
      init_func=self._init_func,
      next_func=self._next_func,
      finalize_func=self._finalize_func,
      **self._flat_structure)
  super(_PerDeviceGenerator, self).__init__(variant_tensor)
def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id,
             source_device, target_device, output_shapes, output_types,
             output_classes):
  """Builds the remote init/next/finalize functions for one device shard.

  Each function is a `function.Defun` wrapper. The wrapper forwards, through
  `functional_ops.remote_call`, to a local Defun executed on `source_device`.
  The wrappers and their captured inputs are stored on `self`:
  `_init_func`/`_init_captured_args`, `_next_func`/`_next_captured_args` and
  `_finalize_func`/`_finalize_captured_args`.

  Args:
    shard_num: shard index passed to
      `multi_device_iterator_get_next_from_shard`.
    multi_device_iterator_resource: resource converted to the string handle
      that the init function returns.
    incarnation_id: passed unchanged to
      `multi_device_iterator_get_next_from_shard`.
    source_device: `target` of every `remote_call`.
    target_device: stored as `self._target_device`; not otherwise used in
      this constructor.
    output_shapes: nested shapes of one element.
    output_types: nested types of one element.
    output_classes: nested classes of one element. Used together with shapes
      and types to derive the flat signatures below.
  """
  self._target_device = target_device
  self._output_types = output_types
  self._output_shapes = output_shapes
  self._output_classes = output_classes
  # Flat signatures used as output_shapes/output_types/Tout of the ops
  # below. NOTE(review): `sparse.as_dense_*` presumably replaces sparse
  # components with their dense encoding -- confirm.
  self._flat_output_shapes = nest.flatten(
      sparse.as_dense_shapes(self._output_shapes, self._output_classes))
  self._flat_output_types = nest.flatten(
      sparse.as_dense_types(self._output_types, self._output_classes))

  # Built outside the Defuns and referenced by `_init_func`, which returns it.
  multi_device_iterator_string_handle = (
      gen_dataset_ops.multi_device_iterator_to_string_handle(
          multi_device_iterator_resource))

  @function.Defun()
  def _init_func():
    return multi_device_iterator_string_handle

  @function.Defun()
  def _remote_init_func():
    return functional_ops.remote_call(target=source_device,
                                      args=_init_func.captured_inputs,
                                      Tout=[dtypes.string],
                                      f=_init_func)

  self._init_func = _remote_init_func
  self._init_captured_args = _remote_init_func.captured_inputs

  @function.Defun(dtypes.string)
  def _next_func(string_handle):
    # Re-open the multi-device iterator from its handle, then read this shard.
    multi_device_iterator = (
        gen_dataset_ops.multi_device_iterator_from_string_handle(
            string_handle=string_handle,
            output_types=self._flat_output_types,
            output_shapes=self._flat_output_shapes))
    return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
        multi_device_iterator=multi_device_iterator,
        shard_num=shard_num,
        incarnation_id=incarnation_id,
        output_types=self._flat_output_types,
        output_shapes=self._flat_output_shapes)

  # NOTE(review): experimental_ints_on_device presumably keeps int32 outputs
  # in device memory -- TODO confirm.
  @function.Defun(dtypes.string, experimental_ints_on_device=True)
  def _remote_next_func(string_handle):
    return functional_ops.remote_call(target=source_device,
                                      args=[string_handle] +
                                      _next_func.captured_inputs,
                                      Tout=self._flat_output_types,
                                      f=_next_func)

  self._next_func = _remote_next_func
  self._next_captured_args = _remote_next_func.captured_inputs

  @function.Defun(dtypes.string)
  def _finalize_func(unused_string_handle):
    # Ignores the handle and returns a dummy int64 scalar.
    return array_ops.constant(0, dtypes.int64)

  @function.Defun(dtypes.string)
  def _remote_finalize_func(string_handle):
    return functional_ops.remote_call(target=source_device,
                                      args=[string_handle] +
                                      _finalize_func.captured_inputs,
                                      Tout=[dtypes.int64],
                                      f=_finalize_func)

  self._finalize_func = _remote_finalize_func
  self._finalize_captured_args = _remote_finalize_func.captured_inputs
def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id,
             source_device, element_structure):
  """Builds a generator dataset that reads one shard of a multi-device iterator.

  The init/next/finalize functions given to `generator_dataset` each issue a
  `functional_ops.remote_call` to `source_device`. On that device the
  multi-device iterator is recovered from its string handle.

  Args:
    shard_num: shard index passed to
      `multi_device_iterator_get_next_from_shard`.
    multi_device_iterator_resource: resource whose string handle the init
      function returns.
    incarnation_id: passed to `multi_device_iterator_get_next_from_shard`.
    source_device: `target` of every `remote_call`.
    element_structure: structure of one element, stored as `self._structure`.
      Its `_flat_types`/`_flat_shapes` give the output signature.
  """
  self._structure = element_structure

  # Built outside the traced functions; `_init_func` returns it.
  multi_device_iterator_string_handle = (
      gen_dataset_ops.multi_device_iterator_to_string_handle(
          multi_device_iterator_resource))

  @function.defun()
  def _init_func():
    return multi_device_iterator_string_handle

  init_func_concrete = _init_func._get_concrete_function_internal()  # pylint: disable=protected-access

  @function.defun()
  def _remote_init_func():
    return functional_ops.remote_call(
        target=source_device,
        args=init_func_concrete.captured_inputs,
        Tout=[dtypes.string],
        f=init_func_concrete)

  self._init_func = _remote_init_func._get_concrete_function_internal()  # pylint: disable=protected-access
  self._init_captured_args = self._init_func.captured_inputs

  @function.defun(
      input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
  def _next_func(string_handle):
    # pylint: disable=protected-access
    # Re-open the multi-device iterator from its handle, then read this shard.
    multi_device_iterator = (
        gen_dataset_ops.multi_device_iterator_from_string_handle(
            string_handle=string_handle,
            output_types=self._structure._flat_types,
            output_shapes=self._structure._flat_shapes))
    return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
        multi_device_iterator=multi_device_iterator,
        shard_num=shard_num,
        incarnation_id=incarnation_id,
        output_types=self._structure._flat_types,
        output_shapes=self._structure._flat_shapes)

  next_func_concrete = _next_func._get_concrete_function_internal()  # pylint: disable=protected-access

  # NOTE(review): "experimental_ints_on_device" presumably keeps int32
  # outputs in device memory -- TODO confirm.
  @function.defun_with_attributes(
      input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
      attributes={"experimental_ints_on_device": True})
  def _remote_next_func(string_handle):
    return functional_ops.remote_call(
        target=source_device,
        args=[string_handle] + next_func_concrete.captured_inputs,
        Tout=self._structure._flat_types,  # pylint: disable=protected-access
        f=next_func_concrete)

  self._next_func = _remote_next_func._get_concrete_function_internal()  # pylint: disable=protected-access
  self._next_captured_args = self._next_func.captured_inputs

  @function.defun(
      input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
  def _finalize_func(unused_string_handle):
    # Ignores the handle and returns a dummy int64 scalar.
    return array_ops.constant(0, dtypes.int64)

  finalize_func_concrete = _finalize_func._get_concrete_function_internal(
  )  # pylint: disable=protected-access

  @function.defun(
      input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
  def _remote_finalize_func(string_handle):
    return functional_ops.remote_call(
        target=source_device,
        args=[string_handle] + finalize_func_concrete.captured_inputs,
        Tout=[dtypes.int64],
        f=finalize_func_concrete)

  self._finalize_func = _remote_finalize_func._get_concrete_function_internal(  # pylint: disable=protected-access
  )
  self._finalize_captured_args = self._finalize_func.captured_inputs

  # Assemble the dataset op from the three remote functions and their
  # captured arguments.
  variant_tensor = gen_dataset_ops.generator_dataset(
      self._init_captured_args,
      self._next_captured_args,
      self._finalize_captured_args,
      init_func=self._init_func,
      next_func=self._next_func,
      finalize_func=self._finalize_func,
      **dataset_ops.flat_structure(self))
  super(_PerDeviceGenerator, self).__init__(variant_tensor)
def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id,
             source_device, target_device, output_shapes, output_types,
             output_classes):
  """Builds the remote init/next/finalize functions for one device shard.

  Each function is a `function.Defun` wrapper. The wrapper forwards, through
  `functional_ops.remote_call`, to a local Defun executed on `source_device`.
  The wrappers and their captured inputs are stored on `self`:
  `_init_func`/`_init_captured_args`, `_next_func`/`_next_captured_args` and
  `_finalize_func`/`_finalize_captured_args`.

  Args:
    shard_num: shard index passed to
      `multi_device_iterator_get_next_from_shard`.
    multi_device_iterator_resource: resource converted to the string handle
      that the init function returns.
    incarnation_id: passed unchanged to
      `multi_device_iterator_get_next_from_shard`.
    source_device: `target` of every `remote_call`.
    target_device: stored as `self._target_device`; not otherwise used in
      this constructor.
    output_shapes: nested shapes of one element.
    output_types: nested types of one element.
    output_classes: nested classes of one element. Used together with shapes
      and types to derive the flat signatures below.
  """
  self._target_device = target_device
  self._output_types = output_types
  self._output_shapes = output_shapes
  self._output_classes = output_classes
  # Flat signatures used as output_shapes/output_types/Tout of the ops
  # below. NOTE(review): `sparse.as_dense_*` presumably replaces sparse
  # components with their dense encoding -- confirm.
  self._flat_output_shapes = nest.flatten(
      sparse.as_dense_shapes(self._output_shapes, self._output_classes))
  self._flat_output_types = nest.flatten(
      sparse.as_dense_types(self._output_types, self._output_classes))

  # Built outside the Defuns and referenced by `_init_func`, which returns it.
  multi_device_iterator_string_handle = (
      gen_dataset_ops.multi_device_iterator_to_string_handle(
          multi_device_iterator_resource))

  @function.Defun()
  def _init_func():
    return multi_device_iterator_string_handle

  @function.Defun()
  def _remote_init_func():
    return functional_ops.remote_call(
        target=source_device,
        args=_init_func.captured_inputs,
        Tout=[dtypes.string],
        f=_init_func)

  self._init_func = _remote_init_func
  self._init_captured_args = _remote_init_func.captured_inputs

  @function.Defun(dtypes.string)
  def _next_func(string_handle):
    # Re-open the multi-device iterator from its handle, then read this shard.
    multi_device_iterator = (
        gen_dataset_ops.multi_device_iterator_from_string_handle(
            string_handle=string_handle,
            output_types=self._flat_output_types,
            output_shapes=self._flat_output_shapes))
    return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
        multi_device_iterator=multi_device_iterator,
        shard_num=shard_num,
        incarnation_id=incarnation_id,
        output_types=self._flat_output_types,
        output_shapes=self._flat_output_shapes)

  # NOTE(review): experimental_ints_on_device presumably keeps int32 outputs
  # in device memory -- TODO confirm.
  @function.Defun(dtypes.string, experimental_ints_on_device=True)
  def _remote_next_func(string_handle):
    return functional_ops.remote_call(
        target=source_device,
        args=[string_handle] + _next_func.captured_inputs,
        Tout=self._flat_output_types,
        f=_next_func)

  self._next_func = _remote_next_func
  self._next_captured_args = _remote_next_func.captured_inputs

  @function.Defun(dtypes.string)
  def _finalize_func(unused_string_handle):
    # Ignores the handle and returns a dummy int64 scalar.
    return array_ops.constant(0, dtypes.int64)

  @function.Defun(dtypes.string)
  def _remote_finalize_func(string_handle):
    return functional_ops.remote_call(
        target=source_device,
        args=[string_handle] + _finalize_func.captured_inputs,
        Tout=[dtypes.int64],
        f=_finalize_func)

  self._finalize_func = _remote_finalize_func
  self._finalize_captured_args = _remote_finalize_func.captured_inputs
def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id,
             source_device, element_structure):
  """Builds a generator dataset that reads one shard of a multi-device iterator.

  The init/next/finalize functions given to `generator_dataset` each issue a
  `functional_ops.remote_call` to `source_device`. On that device the
  multi-device iterator is recovered from its string handle.

  Args:
    shard_num: shard index passed to
      `multi_device_iterator_get_next_from_shard`.
    multi_device_iterator_resource: resource whose string handle the init
      function returns.
    incarnation_id: passed to `multi_device_iterator_get_next_from_shard`.
    source_device: `target` of every `remote_call`.
    element_structure: structure of one element, stored as `self._structure`.
      Its `_flat_types`/`_flat_shapes` give the output signature.
  """
  self._structure = element_structure

  # Built outside the traced functions; `_init_func` returns it.
  multi_device_iterator_string_handle = (
      gen_dataset_ops.multi_device_iterator_to_string_handle(
          multi_device_iterator_resource))

  @function.defun()
  def _init_func():
    return multi_device_iterator_string_handle

  init_func_concrete = _init_func._get_concrete_function_internal()  # pylint: disable=protected-access

  @function.defun()
  def _remote_init_func():
    return functional_ops.remote_call(
        target=source_device,
        args=init_func_concrete.captured_inputs,
        Tout=[dtypes.string],
        f=init_func_concrete)

  self._init_func = _remote_init_func._get_concrete_function_internal()  # pylint: disable=protected-access
  self._init_captured_args = self._init_func.captured_inputs

  @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
  def _next_func(string_handle):
    # pylint: disable=protected-access
    # Re-open the multi-device iterator from its handle, then read this shard.
    multi_device_iterator = (
        gen_dataset_ops.multi_device_iterator_from_string_handle(
            string_handle=string_handle,
            output_types=self._structure._flat_types,
            output_shapes=self._structure._flat_shapes))
    return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
        multi_device_iterator=multi_device_iterator,
        shard_num=shard_num,
        incarnation_id=incarnation_id,
        output_types=self._structure._flat_types,
        output_shapes=self._structure._flat_shapes)

  next_func_concrete = _next_func._get_concrete_function_internal()  # pylint: disable=protected-access

  # NOTE(review): "experimental_ints_on_device" presumably keeps int32
  # outputs in device memory -- TODO confirm.
  @function.defun_with_attributes(
      input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
      attributes={"experimental_ints_on_device": True})
  def _remote_next_func(string_handle):
    return functional_ops.remote_call(
        target=source_device,
        args=[string_handle] + next_func_concrete.captured_inputs,
        Tout=self._structure._flat_types,  # pylint: disable=protected-access
        f=next_func_concrete)

  self._next_func = _remote_next_func._get_concrete_function_internal()  # pylint: disable=protected-access
  self._next_captured_args = self._next_func.captured_inputs

  @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
  def _finalize_func(unused_string_handle):
    # Ignores the handle and returns a dummy int64 scalar.
    return array_ops.constant(0, dtypes.int64)

  finalize_func_concrete = _finalize_func._get_concrete_function_internal()  # pylint: disable=protected-access

  @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
  def _remote_finalize_func(string_handle):
    return functional_ops.remote_call(
        target=source_device,
        args=[string_handle] + finalize_func_concrete.captured_inputs,
        Tout=[dtypes.int64],
        f=finalize_func_concrete)

  self._finalize_func = _remote_finalize_func._get_concrete_function_internal(  # pylint: disable=protected-access
  )
  self._finalize_captured_args = self._finalize_func.captured_inputs

  # Assemble the dataset op from the three remote functions and their
  # captured arguments.
  variant_tensor = gen_dataset_ops.generator_dataset(
      self._init_captured_args,
      self._next_captured_args,
      self._finalize_captured_args,
      init_func=self._init_func,
      next_func=self._next_func,
      finalize_func=self._finalize_func,
      **dataset_ops.flat_structure(self))
  super(_PerDeviceGenerator, self).__init__(variant_tensor)
def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id,
             source_device, element_spec, iterator_is_anonymous):
  """Builds a generator dataset that reads one shard of a multi-device iterator.

  The init/next/finalize functions handed to `generator_dataset` each issue a
  `functional_ops.remote_call` to `source_device`. On that device the
  multi-device iterator is recovered from its string handle and the next
  element of shard `shard_num` is fetched.

  Args:
    shard_num: shard index passed to
      `multi_device_iterator_get_next_from_shard`.
    multi_device_iterator_resource: resource whose string handle is returned
      by the init function. It is also appended to the next function's
      captured args when `iterator_is_anonymous` is true.
    incarnation_id: passed to `multi_device_iterator_get_next_from_shard`.
      Its position among the next function's captured args is recorded in
      `self._incarnation_id_index`.
    source_device: `target` of every `remote_call` built here.
    element_spec: structure of one element. Its flattened tensor types and
      shapes form the output signature, and its full types annotate the
      RemoteCall outputs.
    iterator_is_anonymous: if true, `multi_device_iterator_resource` is
      appended to `self._next_captured_args`.
  """
  self._element_spec = element_spec

  # The flat signature depends only on `element_spec`. Compute it once
  # instead of at every use inside the traced functions below. This is pure
  # Python and adds no ops to any graph.
  flat_types = structure.get_flat_tensor_types(self._element_spec)
  flat_shapes = structure.get_flat_tensor_shapes(self._element_spec)

  multi_device_iterator_string_handle = (
      gen_dataset_ops.multi_device_iterator_to_string_handle(
          multi_device_iterator_resource))

  # TODO(b/124254153): Enable autograph once the overhead is low enough.
  @function.defun(autograph=False)  # Pure graph code.
  def _init_func():
    return multi_device_iterator_string_handle

  init_func_concrete = _init_func.get_concrete_function()

  # TODO(b/124254153): Enable autograph once the overhead is low enough.
  @function.defun(autograph=False)  # Pure graph code.
  def _remote_init_func():
    return functional_ops.remote_call(
        target=source_device,
        args=init_func_concrete.captured_inputs,
        Tout=[dtypes.string],
        f=init_func_concrete)

  self._init_func = _remote_init_func.get_concrete_function()
  self._init_captured_args = self._init_func.captured_inputs

  # TODO(b/124254153): Enable autograph once the overhead is low enough.
  @function.defun(
      input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
      autograph=False)  # Pure graph code.
  def _next_func(string_handle):
    # Re-open the multi-device iterator from its handle, then read from this
    # generator's shard.
    multi_device_iterator = (
        gen_dataset_ops.multi_device_iterator_from_string_handle(
            string_handle=string_handle,
            output_types=flat_types,
            output_shapes=flat_shapes))
    return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
        multi_device_iterator=multi_device_iterator,
        shard_num=shard_num,
        incarnation_id=incarnation_id,
        output_types=flat_types,
        output_shapes=flat_shapes)

  next_func_concrete = _next_func.get_concrete_function()

  # NOTE(review): "experimental_ints_on_device" presumably keeps int32
  # outputs in device memory rather than host memory -- TODO confirm.
  # TODO(b/124254153): Enable autograph once the overhead is low enough.
  @function.defun_with_attributes(
      input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
      attributes={"experimental_ints_on_device": True},
      autograph=False)  # Pure graph code.
  def _remote_next_func(string_handle):
    return_values = functional_ops.remote_call(
        target=source_device,
        args=[string_handle] + next_func_concrete.captured_inputs,
        Tout=flat_types,
        f=next_func_concrete)
    # Add full type information to the graph so that the RemoteCall op
    # can determine for each of its outputs whether or not they are ragged
    # tensors (or other types that use variants) that contain strings
    # (or other host memory types). Then RemoteCall can
    # appropriately set AllocatorAttributes to control copies so
    # strings/host memory types stay on CPU.
    fulltype_list = type_utils.fulltypes_for_flat_tensors(self._element_spec)
    fulltype = type_utils.fulltype_list_to_product(fulltype_list)
    for return_value in return_values:
      return_value.op.experimental_set_type(fulltype)
    return return_values

  self._next_func = _remote_next_func.get_concrete_function()
  self._next_captured_args = self._next_func.captured_inputs
  if iterator_is_anonymous:
    # NOTE(review): presumably keeps the anonymous iterator resource among
    # the dataset op's inputs -- confirm the motivation.
    self._next_captured_args = self._next_captured_args + [
        multi_device_iterator_resource
    ]

  # Position of `incarnation_id` (matched by identity) among the next
  # function's captured args, or -1 if absent. If it occurred more than once
  # the last position would win. Presumably used elsewhere to substitute a
  # different incarnation id -- confirm with the readers of this attribute.
  self._incarnation_id_index = -1
  for i, arg in enumerate(self._next_captured_args):
    if arg is incarnation_id:
      self._incarnation_id_index = i

  # TODO(b/124254153): Enable autograph once the overhead is low enough.
  @function.defun(
      input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
      autograph=False)  # Pure graph code.
  def _finalize_func(unused_string_handle):
    # Nothing to release per shard; return a dummy int64 scalar.
    return array_ops.constant(0, dtypes.int64)

  finalize_func_concrete = _finalize_func.get_concrete_function()

  # TODO(b/124254153): Enable autograph once the overhead is low enough.
  @function.defun(
      input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
      autograph=False)  # Pure graph code.
  def _remote_finalize_func(string_handle):
    return functional_ops.remote_call(
        target=source_device,
        args=[string_handle] + finalize_func_concrete.captured_inputs,
        Tout=[dtypes.int64],
        f=finalize_func_concrete)

  self._finalize_func = _remote_finalize_func.get_concrete_function()
  self._finalize_captured_args = self._finalize_func.captured_inputs

  variant_tensor = gen_dataset_ops.generator_dataset(
      self._init_captured_args,
      self._next_captured_args,
      self._finalize_captured_args,
      init_func=self._init_func,
      next_func=self._next_func,
      finalize_func=self._finalize_func,
      **self._flat_structure)
  super(_PerDeviceGenerator, self).__init__(variant_tensor)
def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id,
             source_device, target_device, output_shapes, output_types,
             output_classes):
  """Builds the remote init/next/finalize functions for one device shard.

  Each function is a `function.defun` concrete function. It forwards, through
  `functional_ops.remote_call`, to a local concrete function executed on
  `source_device`. The remote concrete functions and their captured inputs
  are stored on `self`: `_init_func`/`_init_captured_args`,
  `_next_func`/`_next_captured_args` and
  `_finalize_func`/`_finalize_captured_args`.

  Args:
    shard_num: shard index passed to
      `multi_device_iterator_get_next_from_shard`.
    multi_device_iterator_resource: resource converted to the string handle
      that the init function returns.
    incarnation_id: passed unchanged to
      `multi_device_iterator_get_next_from_shard`.
    source_device: `target` of every `remote_call`.
    target_device: stored as `self._target_device`.
    output_shapes: nested shapes of one element.
    output_types: nested types of one element.
    output_classes: nested classes of one element. Used together with shapes
      and types to derive the flat signatures below.
  """
  self._target_device = target_device
  self._output_types = output_types
  self._output_shapes = output_shapes
  self._output_classes = output_classes
  # Flat signatures used as output_shapes/output_types/Tout of the ops
  # below. NOTE(review): `sparse.as_dense_*` presumably replaces sparse
  # components with their dense encoding -- confirm.
  self._flat_output_shapes = nest.flatten(
      sparse.as_dense_shapes(self._output_shapes, self._output_classes))
  self._flat_output_types = nest.flatten(
      sparse.as_dense_types(self._output_types, self._output_classes))

  # Built outside the traced functions; `_init_func` returns it.
  multi_device_iterator_string_handle = (
      gen_dataset_ops.multi_device_iterator_to_string_handle(
          multi_device_iterator_resource))

  @function.defun()
  def _init_func():
    return multi_device_iterator_string_handle

  init_func_concrete = _init_func.get_concrete_function()

  @function.defun()
  def _remote_init_func():
    return functional_ops.remote_call(
        target=source_device,
        args=init_func_concrete.captured_inputs,
        Tout=[dtypes.string],
        f=init_func_concrete)

  self._init_func = _remote_init_func.get_concrete_function()
  self._init_captured_args = self._init_func.captured_inputs

  @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
  def _next_func(string_handle):
    # Re-open the multi-device iterator from its handle, then read this shard.
    multi_device_iterator = (
        gen_dataset_ops.multi_device_iterator_from_string_handle(
            string_handle=string_handle,
            output_types=self._flat_output_types,
            output_shapes=self._flat_output_shapes))
    return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
        multi_device_iterator=multi_device_iterator,
        shard_num=shard_num,
        incarnation_id=incarnation_id,
        output_types=self._flat_output_types,
        output_shapes=self._flat_output_shapes)

  next_func_concrete = _next_func.get_concrete_function()

  # NOTE(review): "experimental_ints_on_device" presumably keeps int32
  # outputs in device memory -- TODO confirm.
  @function.defun_with_attributes(
      input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
      attributes={"experimental_ints_on_device": True})
  def _remote_next_func(string_handle):
    return functional_ops.remote_call(
        target=source_device,
        args=[string_handle] + next_func_concrete.captured_inputs,
        Tout=self._flat_output_types,
        f=next_func_concrete)

  self._next_func = _remote_next_func.get_concrete_function()
  self._next_captured_args = self._next_func.captured_inputs

  @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
  def _finalize_func(unused_string_handle):
    # Ignores the handle and returns a dummy int64 scalar.
    return array_ops.constant(0, dtypes.int64)

  finalize_func_concrete = _finalize_func.get_concrete_function()

  @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
  def _remote_finalize_func(string_handle):
    return functional_ops.remote_call(
        target=source_device,
        args=[string_handle] + finalize_func_concrete.captured_inputs,
        Tout=[dtypes.int64],
        f=finalize_func_concrete)

  self._finalize_func = _remote_finalize_func.get_concrete_function()
  self._finalize_captured_args = self._finalize_func.captured_inputs