Ejemplo n.º 1
0
  def _create_ops(self, ds, ds_iterator, buffer_name, device0, device1):
    """Builds prefetch/reset/destroy ops around a function buffering resource.

    A function buffering resource is created on `device1`. The buffered
    function rebuilds `ds_iterator` from its string handle and returns the
    next element. `device0` is passed to the resource as its `target_device`.

    Args:
      ds: Dataset whose `output_types` / `output_shapes` describe the
        elements produced by `ds_iterator`.
      ds_iterator: Iterator whose string handle is fed to the buffered
        function.
      buffer_name: `shared_name` given to the buffering resource.
      device0: Device string wrapped in a constant and used as the
        resource's `target_device`.
      device1: Device scope under which the resource and all of its ops are
        created.

    Returns:
      A `(prefetch_op, reset_op, destroy_op)` tuple.
    """
    handle_tensor = ds_iterator.string_handle()

    @function.Defun(dtypes.string)
    def _remote_fn(h):
      # Re-materialize the iterator from its handle and pull one element.
      it = iterator_ops.Iterator.from_string_handle(
          h, ds.output_types, ds.output_shapes)
      return it.get_next()

    target_device = constant_op.constant(device0)

    # NOTE(review): output types are fixed to a single float32 component, so
    # this presumably assumes `ds` yields float32 scalars/tensors — confirm
    # against the callers.
    # The resource and every op that touches it share one device scope.
    with ops.device(device1):
      resource = prefetching_ops.function_buffering_resource(
          f=_remote_fn,
          output_types=[dtypes.float32],
          target_device=target_device,
          string_arg=handle_tensor,
          buffer_size=3,
          shared_name=buffer_name)
      get_next = prefetching_ops.function_buffering_resource_get_next(
          function_buffer_resource=resource,
          output_types=[dtypes.float32])
      reset = prefetching_ops.function_buffering_resource_reset(
          function_buffer_resource=resource)
      # ignore_lookup_error lets the destroy op run even if the resource was
      # already removed.
      destroy = resource_variable_ops.destroy_resource_op(
          resource, ignore_lookup_error=True)

    return get_next, reset, destroy
    def __init__(self,
                 input_dataset,
                 one_shot,
                 devices,
                 buffer_size,
                 shared_name=None):
        """Creates per-device function buffering resources over one iterator.

        Args:
            input_dataset: Dataset whose elements are prefetched.
            one_shot: If truthy, `input_dataset.make_one_shot_iterator()` is
                used. Otherwise a structure-only iterator is built with
                `Iterator.from_structure`, and `self._initializer` is created.
                That initializer runs only after every buffer's reset op.
            devices: Device spec, flattened with `nest.flatten`. One
                buffering resource is created per entry and stored in
                `self._buffering_resources`.
            buffer_size: Buffer size passed to every buffering resource.
            shared_name: Shared name for the iterator (non-one-shot case
                only) and for every buffering resource. `None` is treated
                as "".
        """
        self._input_dataset = input_dataset
        self._get_next_call_count = 0
        self._one_shot = one_shot
        shared_name = "" if shared_name is None else shared_name
        self._devices = devices

        dataset = self._input_dataset
        if self._one_shot:
            self._input_iterator = input_dataset.make_one_shot_iterator()
        else:
            self._input_iterator = iterator_ops.Iterator.from_structure(
                dataset.output_types, dataset.output_shapes, shared_name,
                dataset.output_classes)
        handle_tensor = self._input_iterator.string_handle()

        @function.Defun(dtypes.string)
        def _prefetch_fn(handle):
            """Prefetches one element from `input_iterator`."""
            remote_iterator = iterator_ops.Iterator.from_string_handle(
                handle, self._input_iterator.output_types,
                self._input_iterator.output_shapes,
                self._input_iterator.output_classes)
            ret = remote_iterator.get_next()
            return nest.flatten(sparse.serialize_sparse_tensors(ret))

        # Each buffer targets the device that hosts the input iterator's resource.
        iterator_device = gen_dataset_ops.iterator_get_device(
            self._input_iterator._iterator_resource)  # pylint: disable=protected-access

        def _buffer_on(device):
            # Builds one buffering resource under `device`'s scope.
            with ops.device(device):
                return prefetching_ops.function_buffering_resource(
                    f=_prefetch_fn,
                    output_types=data_nest.flatten(
                        sparse.as_dense_types(dataset.output_types,
                                              dataset.output_classes)),
                    target_device=iterator_device,
                    string_arg=handle_tensor,
                    buffer_size=buffer_size,
                    shared_name=shared_name)

        self._buffering_resources = [
            _buffer_on(device) for device in nest.flatten(self._devices)
        ]

        if not self._one_shot:
            # Every buffer must be reset before the iterator is re-initialized.
            reset_ops = [
                prefetching_ops.function_buffering_resource_reset(resource)
                for resource in self._buffering_resources
            ]
            with ops.control_dependencies(reset_ops):
                self._initializer = self._input_iterator.make_initializer(
                    dataset)
Ejemplo n.º 3
0
  def __init__(self,
               input_dataset,
               one_shot,
               devices,
               buffer_size,
               shared_name=None):
    """Sets up one function buffering resource per device over a shared iterator.

    Args:
      input_dataset: Dataset whose elements are prefetched.
      one_shot: If truthy, `input_dataset.make_one_shot_iterator()` is used
        and no initializer is created. Otherwise a structure-only iterator is
        built via `Iterator.from_structure`, and `self._initializer` is
        created with control dependencies on every buffer's reset op.
      devices: Device spec, flattened with `nest.flatten`. A buffering
        resource is created under each entry's device scope and collected
        in `self._buffering_resources`.
      buffer_size: Buffer size given to every buffering resource.
      shared_name: Shared name for the iterator (non-one-shot case only) and
        for every buffering resource. `None` is treated as "".
    """
    if shared_name is None:
      shared_name = ""
    dataset = input_dataset
    self._input_dataset = dataset
    self._get_next_call_count = 0
    self._one_shot = one_shot
    self._devices = devices

    if one_shot:
      iterator = dataset.make_one_shot_iterator()
    else:
      iterator = iterator_ops.Iterator.from_structure(
          dataset.output_types, dataset.output_shapes, shared_name,
          dataset.output_classes)
    self._input_iterator = iterator
    handle_tensor = iterator.string_handle()

    @function.Defun(dtypes.string)
    def _prefetch_fn(handle):
      """Prefetches one element from `input_iterator`."""
      remote_iterator = iterator_ops.Iterator.from_string_handle(
          handle, self._input_iterator.output_types,
          self._input_iterator.output_shapes,
          self._input_iterator.output_classes)
      element = remote_iterator.get_next()
      return nest.flatten(sparse.serialize_sparse_tensors(element))

    # Device that hosts the input iterator's resource; every buffer targets it.
    iterator_device = gen_dataset_ops.iterator_get_device(
        iterator._iterator_resource)  # pylint: disable=protected-access

    self._buffering_resources = []
    for device in nest.flatten(devices):
      with ops.device(device):
        dense_types = data_nest.flatten(
            sparse.as_dense_types(dataset.output_types,
                                  dataset.output_classes))
        self._buffering_resources.append(
            prefetching_ops.function_buffering_resource(
                f=_prefetch_fn,
                output_types=dense_types,
                target_device=iterator_device,
                string_arg=handle_tensor,
                buffer_size=buffer_size,
                shared_name=shared_name))

    if one_shot:
      return

    # The iterator may only be re-initialized after all buffers are reset.
    resets = [
        prefetching_ops.function_buffering_resource_reset(resource)
        for resource in self._buffering_resources
    ]
    with ops.control_dependencies(resets):
      self._initializer = iterator.make_initializer(dataset)