def testInitializableIterator(self):
  """An initializable iterator can be drained, raises when exhausted, and resets.

  Iterates a PerReplicaDataset on a single CPU device to the end, checks that
  one more read raises OutOfRangeError, then re-runs the initializer and
  checks the full dataset can be read again.
  """
  num_elements = 10
  with context.graph_mode():
    devices = ["/device:CPU:0"]
    # Random input is only allowed with an initializable iterator, which is
    # the kind of iterator this test exercises.
    dataset = dataset_ops.Dataset.from_tensor_slices(
        random_ops.random_uniform((num_elements,)))
    input_workers = input_lib.InputWorkers(values.ReplicaDeviceMap(devices))
    per_replica_dataset = input_lib.PerReplicaDataset(
        dataset, input_workers, 0)
    iterator = per_replica_dataset.make_initializable_iterator()

    self.evaluate(iterator.initializer)
    next_element = iterator.get_next_as_list()

    def consume_all():
      # Reads exactly as many elements as the dataset was built with.
      for _ in range(num_elements):
        self.evaluate(next_element)

    consume_all()

    # The input is exhausted, so one more read must fail.
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element)

    # Re-running the initializer rewinds the iterator for a second full pass.
    self.evaluate(iterator.initializer)
    consume_all()
def _distribute_dataset(self, dataset_fn):
  """Distributes the dataset to each local GPU.

  Args:
    dataset_fn: Passed to `self._call_dataset_fn` to build the input dataset.

  Returns:
    An `input_lib.PerReplicaDataset` over `self._input_workers` for worker
    index 0, created with `prefetch_on_device=True`.
  """
  source_dataset = self._call_dataset_fn(dataset_fn)
  return input_lib.PerReplicaDataset(
      source_dataset,
      self._input_workers,
      0,
      prefetch_on_device=True)
def _distribute_dataset(self, dataset_fn):
  """Distributes the dataset to each local GPU.

  Args:
    dataset_fn: Passed to `self._call_dataset_fn` to build the input dataset.

  Returns:
    An `input_lib.PerReplicaDataset` over `self._input_workers` for worker
    index 0, created with `prefetch_on_device=True`. The dataset is not
    sharded across workers.
  """
  # TODO(yuefengz): shard the dataset.
  worker_index = 0
  built = self._call_dataset_fn(dataset_fn)
  return input_lib.PerReplicaDataset(
      built,
      self._input_workers,
      worker_index,
      prefetch_on_device=True)
def _distribute_dataset(self, dataset_fn):
  """Wraps the dataset produced by `dataset_fn` for this strategy's workers.

  Args:
    dataset_fn: Passed to `self._call_dataset_fn`. In local mode it is called
      immediately; otherwise it is bound into a `functools.partial` so that
      `MultiWorkerDataset` can invoke it.

  Returns:
    An `input_lib.PerReplicaDataset` (worker index 0) when `self._local_mode`
    is truthy. Otherwise an `input_lib.MultiWorkerDataset` over
    `self._input_workers` with `auto_shard=self._auto_shard_dataset`.
  """
  if not self._local_mode:
    # Defer dataset construction: hand over a callable rather than a dataset.
    deferred_fn = functools.partial(self._call_dataset_fn, dataset_fn)
    return input_lib.MultiWorkerDataset(
        deferred_fn,
        self._input_workers,
        auto_shard=self._auto_shard_dataset)

  local_dataset = self._call_dataset_fn(dataset_fn)
  return input_lib.PerReplicaDataset(local_dataset, self._input_workers, 0)
def _test_iterator(self, devices, dataset, expected_values):
  """Asserts a PerReplicaDataset over `devices` yields exactly `expected_values`.

  Builds a PerReplicaDataset for `dataset` on `devices` (worker index 0).
  It uses a one-shot iterator in eager mode and an initializable iterator
  (initialized here) otherwise. It compares one `get_next_as_list()` result
  per entry of `expected_values`, then asserts that a further read raises
  OutOfRangeError.

  Args:
    devices: Device list passed to `values.ReplicaDeviceMap`.
    dataset: The dataset to distribute.
    expected_values: Sequence of values, each compared with `assertEqual`
      against one evaluated `get_next_as_list()` result.
  """
  input_workers = input_lib.InputWorkers(values.ReplicaDeviceMap(devices))
  replicated = input_lib.PerReplicaDataset(dataset, input_workers, 0)

  if not context.executing_eagerly():
    iterator = replicated.make_initializable_iterator()
    self.evaluate([iterator.initializer])
  else:
    iterator = replicated.make_one_shot_iterator()

  for want in expected_values:
    got = self.evaluate(iterator.get_next_as_list())
    self.assertEqual(want, got)

  # Every expected element has been consumed, so the next read must fail.
  with self.assertRaises(errors.OutOfRangeError):
    self.evaluate(iterator.get_next_as_list())
def _distribute_dataset(self, dataset_fn):
  """Wraps the dataset built from `dataset_fn` in a PerReplicaDataset.

  Args:
    dataset_fn: Passed to `self._call_dataset_fn` to build the input dataset.

  Returns:
    An `input_lib.PerReplicaDataset` over `self._input_workers` for worker
    index 0.
  """
  built_dataset = self._call_dataset_fn(dataset_fn)
  worker_index = 0
  return input_lib.PerReplicaDataset(
      built_dataset, self._input_workers, worker_index)