Example 1
def _component_specs(self):
  specs = [
      # The resource handle backing the multi-device iterator.
      tensor_spec.TensorSpec([], dtypes.resource),
  ]
  # One nested IteratorSpec per target device, all sharing the same
  # element structure.
  for _ in range(len(self._devices)):
    specs.append(iterator_ops.IteratorSpec(self._element_spec))
  return specs
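
For context: `_component_specs` is part of TensorFlow's `tf.TypeSpec` contract. It must list, in order, exactly one spec per tensor that `_to_components` produces, so that `tf.nest` and `tf.function` tracing can decompose and rebuild the value. Below is a minimal self-contained sketch of that contract against the public `tf.TypeSpec` base class; `Pair` and `PairSpec` are hypothetical names invented for illustration, not TensorFlow APIs.

import tensorflow as tf

class Pair:
  """Hypothetical value type holding two scalar tensors."""

  def __init__(self, a, b):
    self.a = tf.convert_to_tensor(a, tf.float32)
    self.b = tf.convert_to_tensor(b, tf.float32)

class PairSpec(tf.TypeSpec):
  """Hypothetical spec illustrating the _component_specs contract."""

  @property
  def value_type(self):
    return Pair

  def _serialize(self):
    return ()  # this spec has no constructor arguments to record

  @property
  def _component_specs(self):
    # One spec per tensor returned by _to_components, in the same order.
    return (tf.TensorSpec([], tf.float32), tf.TensorSpec([], tf.float32))

  def _to_components(self, value):
    return (value.a, value.b)

  def _from_components(self, components):
    return Pair(*components)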
Example 2
  def __iter__(self):
    # We would like users to create iterators outside `tf.function`s so that we
    # can track them.
    if (not context.executing_eagerly() or
        ops.get_default_graph().building_function):
      raise RuntimeError(
          "__iter__() is not supported inside of tf.function or in graph mode.")

    def _create_per_worker_iterator():
      dataset = self._dataset_fn()
      return iter(dataset)

    # If _PerWorkerDistributedDataset.__iter__ is called multiple times on
    # the same object, the per-worker resource should be created and
    # registered only once; the object id is used to distinguish different
    # iterator resources.
    per_worker_iterator = self._client._create_per_worker_resources(
        _create_per_worker_iterator)

    # Setting type_spec of each RemoteValue so that functions taking these
    # RemoteValues as inputs can be traced.
    for iterator_remote_value in per_worker_iterator._values:
      iterator_remote_value._set_type_spec(
          iterator_ops.IteratorSpec(
              self._dataset_fn.structured_outputs.element_spec))
    return _PerWorkerDistributedIterator(per_worker_iterator._values)
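
The `_set_type_spec` calls above exist so that functions taking the per-worker `RemoteValue`s as inputs can be traced. The same mechanism is visible with ordinary `tf.data`: because an eagerly created iterator carries a concrete TypeSpec, it can be passed as an argument to a `tf.function`. A quick sketch (plain `tf.data`, not the distributed path above):

import tensorflow as tf

dataset = tf.data.Dataset.range(8).batch(2)
iterator = iter(dataset)  # created eagerly, as __iter__ above requires

@tf.function
def next_batch(it):
  # Tracing works because the iterator argument has a concrete TypeSpec.
  return next(it)

print(next_batch(iterator))  # tf.Tensor([0 1], shape=(2,), dtype=int64)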
Example 3
def _component_specs(self):
  if use_anonymous_multi_device_iterator_v3():
    # The anonymous v3 iterator is fully described by its resource handle.
    specs = [
        tensor_spec.TensorSpec([], dtypes.resource),
    ]
  else:
    # The legacy iterator exposes a resource handle plus a variant tensor.
    specs = [
        tensor_spec.TensorSpec([], dtypes.resource),
        tensor_spec.TensorSpec([], dtypes.variant),
    ]
  # One nested IteratorSpec per target device, as in Example 1.
  for _ in range(len(self._devices)):
    specs.append(iterator_ops.IteratorSpec(self._element_spec))
  return specs
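
A note on the branch above: the legacy path advertises a `variant` component alongside the `resource` handle (in the TensorFlow sources this appears to be the iterator's deleter tensor), whereas the v3 anonymous multi-device iterator ties its lifetime to the resource handle itself and needs only the single `resource` component. Whichever branch is taken, the spec list must stay in lockstep with the tensors `_to_components` returns for that version.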
Example 4
def _type_spec(self):
  # The TypeSpec used by the CompositeTensor machinery to describe this
  # iterator, keyed on its element structure.
  return iterator_ops.IteratorSpec(self._element_spec)
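
`_type_spec` is the property the CompositeTensor protocol uses to recover a value's spec; the public `tf.type_spec_from_value` goes through it. A minimal sketch using `tf.experimental.ExtensionType`, which generates the spec automatically (`Point` is a hypothetical name):

import tensorflow as tf

class Point(tf.experimental.ExtensionType):
  x: tf.Tensor
  y: tf.Tensor

p = Point(tf.constant(1.0), tf.constant(2.0))
print(tf.type_spec_from_value(p))  # Point.Spec(x=TensorSpec(...), y=TensorSpec(...))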