Example #1
    def __init__(self, input_dataset, device, buffer_size):
        with ops.device("/device:CPU:0"):
            super(_PrefetchToDeviceEagerIterator, self).__init__(input_dataset)
            input_iterator_handle = core_gen_dataset_ops.iterator_to_string_handle(
                self._resource)

        self._device = device

        @function.Defun(dtypes.string)
        def _prefetch_fn(handle):
            """Prefetches one element from `input_iterator`."""
            remote_iterator = iterator_ops.Iterator.from_string_handle(
                handle, self.output_types, self.output_shapes,
                self.output_classes)
            ret = remote_iterator.get_next()
            return nest.flatten(sparse.serialize_sparse_tensors(ret))

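        # Register the prefetch function with the eager context (no enclosing
        # graph), so the buffering resource created below can invoke it.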
        _prefetch_fn.add_to_graph(None)

        with ops.device(device):
            self._buffering_resource = function_buffering_resource(
                f=_prefetch_fn,
                target_device=gen_dataset_ops.iterator_get_device(
                    self._resource),
                string_arg=input_iterator_handle,
                buffer_size=buffer_size,
                shared_name=iterator_ops._generate_shared_name(
                    "function_buffer_resource"))
  def __init__(self, input_dataset, devices, buffer_size):
    self._input_dataset = input_dataset
    self._get_next_call_count = 0
    self._devices = devices
    input_iterator = input_dataset.make_one_shot_iterator()
    input_iterator_handle = input_iterator.string_handle()

    @function.Defun(dtypes.string)
    def _prefetch_fn(handle):
      """Prefetches one element from `input_iterator`."""
      remote_iterator = iterator_ops.Iterator.from_string_handle(
          handle, input_iterator.output_types, input_iterator.output_shapes,
          input_iterator.output_classes)
      ret = remote_iterator.get_next()
      return nest.flatten(sparse.serialize_sparse_tensors(ret))

    target_device = gen_dataset_ops.iterator_get_device(
        input_iterator._iterator_resource)
    self._buffering_resources = []
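    # Create one buffering resource per destination device, all fed from the
    # same source iterator via its string handle.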
    for device in nest.flatten(self._devices):
      with ops.device(device):
        buffer_resource_handle = prefetching_ops.function_buffering_resource(
            f=_prefetch_fn,
            target_device=target_device,
            string_arg=input_iterator_handle,
            buffer_size=buffer_size)
        self._buffering_resources.append(buffer_resource_handle)
Example #3
  def __init__(self,
               input_dataset,
               device,
               buffer_size):
    with ops.device("/device:CPU:0"):
      super(_PrefetchToDeviceEagerIterator, self).__init__(input_dataset)
      input_iterator_handle = core_gen_dataset_ops.iterator_to_string_handle(
          self._resource)

    self._device = device

    @function.Defun(dtypes.string)
    def _prefetch_fn(handle):
      """Prefetches one element from `input_iterator`."""
      remote_iterator = iterator_ops.Iterator.from_string_handle(
          handle, self.output_types, self.output_shapes, self.output_classes)
      ret = remote_iterator.get_next()
      return nest.flatten(sparse.serialize_sparse_tensors(ret))

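    # Register the prefetch function with the eager context (no enclosing
    # graph), so the buffering resource created below can invoke it.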
    _prefetch_fn.add_to_graph(None)

    with ops.device(device):
      self._buffering_resource = function_buffering_resource(
          f=_prefetch_fn,
          output_types=self._flat_output_types,
          target_device=gen_dataset_ops.iterator_get_device(self._resource),
          string_arg=input_iterator_handle,
          buffer_size=buffer_size,
          shared_name=iterator_ops._generate_shared_name(
              "function_buffer_resource"))
    def __init__(self,
                 input_dataset,
                 one_shot,
                 devices,
                 buffer_size,
                 shared_name=None):
        self._input_dataset = input_dataset
        self._get_next_call_count = 0
        self._one_shot = one_shot
        if shared_name is None:
            shared_name = ""
        self._devices = devices

        if self._one_shot:
            self._input_iterator = input_dataset.make_one_shot_iterator()
        else:
            self._input_iterator = iterator_ops.Iterator.from_structure(
                self._input_dataset.output_types,
                self._input_dataset.output_shapes, shared_name,
                self._input_dataset.output_classes)
        input_iterator_handle = self._input_iterator.string_handle()

        @function.Defun(dtypes.string)
        def _prefetch_fn(handle):
            """Prefetches one element from `input_iterator`."""
            remote_iterator = iterator_ops.Iterator.from_string_handle(
                handle, self._input_iterator.output_types,
                self._input_iterator.output_shapes,
                self._input_iterator.output_classes)
            ret = remote_iterator.get_next()
            return nest.flatten(sparse.serialize_sparse_tensors(ret))

        target_device = gen_dataset_ops.iterator_get_device(
            self._input_iterator._iterator_resource)
        self._buffering_resources = []
        for device in nest.flatten(self._devices):
            with ops.device(device):
                buffer_resource_handle = prefetching_ops.function_buffering_resource(
                    f=_prefetch_fn,
                    output_types=data_nest.flatten(
                        sparse.as_dense_types(
                            self._input_dataset.output_types,
                            self._input_dataset.output_classes)),
                    target_device=target_device,
                    string_arg=input_iterator_handle,
                    buffer_size=buffer_size,
                    shared_name=shared_name)
                self._buffering_resources.append(buffer_resource_handle)

        if not self._one_shot:
            reset_ops = []
            for buffer_resource in self._buffering_resources:
                reset_ops.append(
                    prefetching_ops.function_buffering_resource_reset(
                        buffer_resource))
            with ops.control_dependencies(reset_ops):
                self._initializer = self._input_iterator.make_initializer(
                    self._input_dataset)
  def __init__(self,
               input_dataset,
               one_shot,
               device,
               buffer_size,
               shared_name=None):
    self._input_dataset = input_dataset
    self._get_next_call_count = 0
    self._one_shot = one_shot
    if shared_name is None:
      shared_name = ""

    if self._one_shot:
      self._input_iterator = input_dataset.make_one_shot_iterator()
    else:
      self._input_iterator = iterator_ops.Iterator.from_structure(
          self._input_dataset.output_types, self._input_dataset.output_shapes,
          shared_name, self._input_dataset.output_classes)
    input_iterator_handle = self._input_iterator.string_handle()

    @function.Defun(dtypes.string)
    def _prefetch_fn(handle):
      """Prefetches one element from `input_iterator`."""
      remote_iterator = iterator_ops.Iterator.from_string_handle(
          handle, self._input_iterator.output_types,
          self._input_iterator.output_shapes,
          self._input_iterator.output_classes)
      ret = remote_iterator.get_next()

      # Convert any `SparseTensorValue`s to `SparseTensor`s.
      ret = nest.pack_sequence_as(ret, [
          sparse_tensor_lib.SparseTensor.from_value(t)
          if sparse_tensor_lib.is_sparse(t) else t for t in nest.flatten(ret)
      ])

      # Serialize any sparse tensors and convert result to tensors.
      ret = nest.pack_sequence_as(ret, [
          ops.convert_to_tensor(t)
          for t in nest.flatten(sparse.serialize_sparse_tensors(ret))
      ])
      return nest.flatten(ret)

    with ops.device(device):
      self._buffering_resource = function_buffering_resource(
          f=_prefetch_fn,
          target_device=gen_dataset_ops.iterator_get_device(
              self._input_iterator._iterator_resource),
          string_arg=input_iterator_handle,
          buffer_size=buffer_size,
          shared_name=shared_name)

    if not self._one_shot:
      reset_op = function_buffering_resource_reset(self._buffering_resource)
      with ops.control_dependencies([reset_op]):
        self._initializer = self._input_iterator.make_initializer(
            self._input_dataset)
Example #7
    def __init__(self, input_dataset, device, buffer_size):
        self._input_dataset = input_dataset
        self._get_next_call_count = 0
        input_iterator = input_dataset.make_one_shot_iterator()
        input_iterator_handle = input_iterator.string_handle()

        @function.Defun(dtypes.string)
        def _prefetch_fn(handle):
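            """Prefetches one element from `input_iterator`."""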
            remote_iterator = iterator_ops.Iterator.from_string_handle(
                handle, input_iterator.output_types,
                input_iterator.output_shapes, input_iterator.output_classes)
            return remote_iterator.get_next()

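        # Place the buffering resource on the destination device; it invokes
        # _prefetch_fn against the source iterator to keep up to `buffer_size`
        # elements ready ahead of time.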
        with ops.device(device):
            self._buffering_resource = function_buffering_resource(
                f=_prefetch_fn,
                target_device=gen_dataset_ops.iterator_get_device(
                    input_iterator._iterator_resource),
                string_arg=input_iterator_handle,
                buffer_size=buffer_size,
                thread_pool_size=0)
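
The constructors above are internal plumbing behind the public prefetch_to_device transformation in TensorFlow 1.x (exposed as `tf.contrib.data.prefetch_to_device`, later `tf.data.experimental.prefetch_to_device`). A minimal usage sketch follows; the "/gpu:0" device string, the toy dataset, and the `buffer_size` value are illustrative assumptions, and a visible GPU is assumed.

import tensorflow as tf

# Build a small host-side input pipeline.
dataset = tf.data.Dataset.range(100).batch(10)

# prefetch_to_device must be the final transformation in the pipeline; it
# buffers elements on the target device using the iterators shown above.
dataset = dataset.apply(
    tf.contrib.data.prefetch_to_device("/gpu:0", buffer_size=2))

iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
  print(sess.run(next_element))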