def _create_ops(self, ds, ds_iterator, buffer_name, device0, device1):
    """Build the prefetch, reset and destroy ops for one buffering resource.

    A function buffering resource is created under the `device1` scope. It is
    handed a `Defun` that rebuilds an iterator from the string handle of
    `ds_iterator` and calls `get_next()` on it, with `device0` supplied as
    the target device. The element dtype list (`[dtypes.float32]`) and the
    buffer size (3) are fixed by this helper.

    Args:
      ds: Dataset whose `output_types` and `output_shapes` are used to
        rebuild the iterator inside the remote function.
      ds_iterator: Iterator whose string handle is passed to the resource.
      buffer_name: Value passed as the resource's `shared_name`.
      device0: Device string, wrapped in a constant and passed as
        `target_device`.
      device1: Device scope under which the resource and its ops are built.

    Returns:
      A `(prefetch_op, reset_op, destroy_op)` tuple.
    """
    iterator_handle = ds_iterator.string_handle()

    @function.Defun(dtypes.string)
    def _remote_fn(h):
        return iterator_ops.Iterator.from_string_handle(
            h, ds.output_types, ds.output_shapes).get_next()

    source_device = constant_op.constant(device0)
    element_dtypes = [dtypes.float32]

    # The resource and every op that consumes it live under the same device
    # scope, so a single `with` block covers them all.
    with ops.device(device1):
        buffer_handle = prefetching_ops.function_buffering_resource(
            f=_remote_fn,
            output_types=element_dtypes,
            target_device=source_device,
            string_arg=iterator_handle,
            buffer_size=3,
            shared_name=buffer_name)
        prefetch_op = prefetching_ops.function_buffering_resource_get_next(
            function_buffer_resource=buffer_handle,
            output_types=element_dtypes)
        reset_op = prefetching_ops.function_buffering_resource_reset(
            function_buffer_resource=buffer_handle)
        destroy_op = resource_variable_ops.destroy_resource_op(
            buffer_handle, ignore_lookup_error=True)

    return prefetch_op, reset_op, destroy_op
def __init__(self, input_dataset, one_shot, devices, buffer_size,
             shared_name=None):
    """Set up the source iterator and one buffering resource per device.

    Args:
      input_dataset: Dataset the source iterator is built from; its
        `output_types`, `output_shapes` and `output_classes` are read here.
      one_shot: If truthy, the source iterator comes from
        `make_one_shot_iterator()` and no initializer is created. Otherwise
        an iterator is built with `Iterator.from_structure` and
        `self._initializer` is set.
      devices: Structure of devices; it is flattened with `nest.flatten` and
        one buffering resource is created under each device scope.
      buffer_size: Passed as `buffer_size` to every buffering resource.
      shared_name: Passed to `Iterator.from_structure` (non-one-shot case)
        and to every buffering resource. `None` is replaced by `""`.
    """
    shared_name = "" if shared_name is None else shared_name

    self._input_dataset = input_dataset
    self._one_shot = one_shot
    self._devices = devices
    self._get_next_call_count = 0

    if one_shot:
        self._input_iterator = input_dataset.make_one_shot_iterator()
    else:
        self._input_iterator = iterator_ops.Iterator.from_structure(
            input_dataset.output_types,
            input_dataset.output_shapes,
            shared_name,
            input_dataset.output_classes)
    source_iterator = self._input_iterator
    iterator_handle = source_iterator.string_handle()

    @function.Defun(dtypes.string)
    def _prefetch_fn(handle):
        """Prefetches one element from `input_iterator`."""
        element = iterator_ops.Iterator.from_string_handle(
            handle,
            source_iterator.output_types,
            source_iterator.output_shapes,
            source_iterator.output_classes).get_next()
        return nest.flatten(sparse.serialize_sparse_tensors(element))

    target_device = gen_dataset_ops.iterator_get_device(
        source_iterator._iterator_resource)

    # The flattened dtype list depends only on the dataset, not on the
    # device, so it is computed once and reused for every resource.
    flat_dense_types = data_nest.flatten(
        sparse.as_dense_types(input_dataset.output_types,
                              input_dataset.output_classes))

    self._buffering_resources = []
    for device in nest.flatten(devices):
        with ops.device(device):
            resource = prefetching_ops.function_buffering_resource(
                f=_prefetch_fn,
                output_types=flat_dense_types,
                target_device=target_device,
                string_arg=iterator_handle,
                buffer_size=buffer_size,
                shared_name=shared_name)
        self._buffering_resources.append(resource)

    if one_shot:
        return

    # The initializer is created under control dependencies on a reset op
    # for every buffering resource.
    reset_ops = [
        prefetching_ops.function_buffering_resource_reset(resource)
        for resource in self._buffering_resources
    ]
    with ops.control_dependencies(reset_ops):
        self._initializer = source_iterator.make_initializer(input_dataset)
def __init__(self, input_dataset, one_shot, devices, buffer_size,
             shared_name=None):
    """Create the input iterator plus a buffering resource for each device.

    Args:
      input_dataset: Dataset backing the input iterator. Its `output_types`,
        `output_shapes` and `output_classes` are consulted.
      one_shot: When truthy, `make_one_shot_iterator()` supplies the input
        iterator and `self._initializer` is not set. When falsy, the iterator
        is built via `Iterator.from_structure` and `self._initializer` is
        created.
      devices: Structure of devices, flattened with `nest.flatten`; each
        entry gets its own buffering resource.
      buffer_size: Forwarded as `buffer_size` to each buffering resource.
      shared_name: Forwarded to `Iterator.from_structure` (only when
        `one_shot` is falsy) and to each buffering resource. Defaults to
        `""` when `None`.
    """
    if shared_name is None:
        shared_name = ""

    self._get_next_call_count = 0
    self._devices = devices
    self._one_shot = one_shot
    self._input_dataset = input_dataset

    dataset_types = input_dataset.output_types
    dataset_classes = input_dataset.output_classes

    if one_shot:
        iterator = input_dataset.make_one_shot_iterator()
    else:
        iterator = iterator_ops.Iterator.from_structure(
            dataset_types, input_dataset.output_shapes, shared_name,
            dataset_classes)
    self._input_iterator = iterator
    handle_tensor = iterator.string_handle()

    @function.Defun(dtypes.string)
    def _prefetch_fn(handle):
        """Prefetches one element from `input_iterator`."""
        remote = iterator_ops.Iterator.from_string_handle(
            handle, iterator.output_types, iterator.output_shapes,
            iterator.output_classes)
        fetched = remote.get_next()
        serialized = sparse.serialize_sparse_tensors(fetched)
        return nest.flatten(serialized)

    iterator_device = gen_dataset_ops.iterator_get_device(
        iterator._iterator_resource)

    # Same for every device, so it is built once outside the loop.
    resource_dtypes = data_nest.flatten(
        sparse.as_dense_types(dataset_types, dataset_classes))

    buffering_resources = []
    self._buffering_resources = buffering_resources
    for target in nest.flatten(devices):
        with ops.device(target):
            handle_for_device = prefetching_ops.function_buffering_resource(
                f=_prefetch_fn,
                output_types=resource_dtypes,
                target_device=iterator_device,
                string_arg=handle_tensor,
                buffer_size=buffer_size,
                shared_name=shared_name)
        buffering_resources.append(handle_for_device)

    if not one_shot:
        # One reset op per buffering resource; the initializer is created
        # under control dependencies on all of them.
        resets = [
            prefetching_ops.function_buffering_resource_reset(res)
            for res in buffering_resources
        ]
        with ops.control_dependencies(resets):
            self._initializer = iterator.make_initializer(input_dataset)