def _component_specs(self):
    """Return the list of component specs for this object.

    The list starts with one scalar `dtypes.resource` `TensorSpec` and is
    followed by one `IteratorSpec` built from `self._element_spec` for each
    entry in `self._devices`.

    Returns:
      A list of specs: `[resource_spec, iterator_spec_0, iterator_spec_1, ...]`.
    """
    handle_spec = tensor_spec.TensorSpec([], dtypes.resource)
    # A fresh IteratorSpec is built per device rather than sharing one object.
    per_device_specs = [
        iterator_ops.IteratorSpec(self._element_spec)
        for _ in range(len(self._devices))
    ]
    return [handle_spec, *per_device_specs]
def __iter__(self):
    """Create per-worker iterators for this dataset and wrap them.

    A dataset is built with `self._dataset_fn()` and turned into an iterator
    inside a function handed to `self._client._create_per_worker_resources`.
    Each resulting remote value is then tagged with an `IteratorSpec` so that
    functions taking these remote values as inputs can be traced.

    Returns:
      A `_PerWorkerDistributedIterator` wrapping the per-worker remote values.

    Raises:
      RuntimeError: If called while not executing eagerly, or while the
        default graph is building a function.
    """
    # We would like users to create iterators outside `tf.function`s so that
    # we can track them.
    eager_outside_function = (
        context.executing_eagerly()
        and not ops.get_default_graph().building_function)
    if not eager_outside_function:
        raise RuntimeError(
            "__iter__() is not supported inside of tf.function or in graph mode.")

    def _create_per_worker_iterator():
        return iter(self._dataset_fn())

    # NOTE(review): an earlier comment here said repeated `__iter__` calls on
    # the same object should create and register the resource only once,
    # keyed by object id. No such deduplication is visible in this method --
    # confirm whether `_create_per_worker_resources` takes care of it.
    per_worker_resource = self._client._create_per_worker_resources(
        _create_per_worker_iterator)
    remote_values = per_worker_resource._values

    # Setting type_spec of each RemoteValue so that functions taking these
    # RemoteValues as inputs can be traced.
    for remote_value in remote_values:
        element_spec = self._dataset_fn.structured_outputs.element_spec
        remote_value._set_type_spec(iterator_ops.IteratorSpec(element_spec))

    return _PerWorkerDistributedIterator(remote_values)
def _component_specs(self):
    """Return the list of component specs for this object.

    The list always starts with a scalar `dtypes.resource` `TensorSpec`. When
    `use_anonymous_multi_device_iterator_v3()` is false, a scalar
    `dtypes.variant` `TensorSpec` follows it. One `IteratorSpec` built from
    `self._element_spec` is then appended for each entry in `self._devices`.

    Returns:
      A list of specs in the order described above.
    """
    # Query the flag first, before any spec is constructed.
    needs_variant_spec = not use_anonymous_multi_device_iterator_v3()

    component_specs = [tensor_spec.TensorSpec([], dtypes.resource)]
    if needs_variant_spec:
        # NOTE(review): presumably the non-v3 op carries an extra variant
        # component alongside the resource handle -- confirm against the op.
        component_specs.append(tensor_spec.TensorSpec([], dtypes.variant))

    component_specs.extend(
        iterator_ops.IteratorSpec(self._element_spec)
        for _ in range(len(self._devices)))
    return component_specs
def _type_spec(self):
    """Return an `IteratorSpec` built from `self._element_spec`."""
    element_spec = self._element_spec
    return iterator_ops.IteratorSpec(element_spec)