def testInitializableIterator(self):
  """Drains an initializable per-device iterator, re-initializes, drains again.

  Random input is used on purpose: it is only allowed with an initializable
  iterator.
  """
  with context.graph_mode():
    # Random input is only allowed with an initializable iterator.
    source = dataset_ops.Dataset.from_tensor_slices(
        random_ops.random_uniform((10,)))
    replicated = values.PerDeviceDataset(
        source, ["/device:CPU:0"], prefetch_on_device=False)
    it = replicated.make_initializable_iterator()
    self.evaluate(it.initializer)
    element = it.get_next()

    def _drain_ten():
      # Fetch each of the ten sliced elements once.
      for _ in range(10):
        self.evaluate(element)

    _drain_ten()
    # All input has been consumed, so one more fetch must fail.
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(element)

    # Running the initializer again must allow ten more successful fetches.
    self.evaluate(it.initializer)
    _drain_ten()
def _test_iterator_with_prefetch(self, devices, dataset, expected_values):
  """Checks a prefetching PerDeviceDataset yields exactly the expected values.

  Only runs in graph mode; when executing eagerly this helper does nothing.

  Args:
    devices: device strings the dataset is distributed over.
    dataset: the source dataset to distribute.
    expected_values: one entry per `get_next()` step; each entry holds the
      values expected across all `devices` for that step, in any order.

  Raises:
    AssertionError: if the multiset of values seen across all devices differs
      from the expected multiset, or if the iterator is not exhausted after
      `len(expected_values)` steps.
  """
  import collections

  if context.executing_eagerly():
    # NOTE(review): this check is skipped entirely in eager mode; presumably
    # on-device prefetching is not exercised there — confirm.
    return

  per_device_dataset = values.PerDeviceDataset(
      dataset, devices, prefetch_on_device=True)
  iterator = per_device_dataset.make_one_shot_iterator()

  # With prefetching, we cannot guarantee which input ends up on which
  # device, so compare everything seen on all devices against everything
  # expected, ignoring order. Each step selects exactly one value per device,
  # so every device receives the same number of values by construction.
  combined_actual = []
  combined_expected = []
  for expected_value in expected_values:
    next_element = iterator.get_next()
    combined_actual.extend(
        self.evaluate(
            [values.select_device(d, next_element) for d in devices]))
    combined_expected.extend(expected_value)

  # Compare multisets rather than sets so that duplicated or missing repeats
  # of the same value are detected (a set comparison would hide them).
  self.assertEqual(collections.Counter(combined_expected),
                   collections.Counter(combined_actual))

  # Once all expected steps are consumed, the next step must be out of range.
  with self.assertRaises(errors.OutOfRangeError):
    next_element = iterator.get_next()
    self.evaluate([values.select_device(d, next_element) for d in devices])
def distribute_dataset(self, dataset_fn):
  """Wraps the dataset produced by `dataset_fn` for distribution.

  Without a cluster spec, the dataset is built once via
  `self._call_dataset_fn` and wrapped in a PerDeviceDataset over
  `self._devices`. With a cluster spec, a deferred call (built with `partial`)
  is handed to a MultiWorkerDataset together with the worker devices, the
  prefetch flag and the auto-shard flag.
  """
  if not self._cluster_spec:
    local_dataset = self._call_dataset_fn(dataset_fn)
    return values.PerDeviceDataset(
        local_dataset, self._devices, self._prefetch_on_device)

  deferred_dataset_fn = partial(self._call_dataset_fn, dataset_fn)
  return values.MultiWorkerDataset(
      deferred_dataset_fn,
      self._worker_devices,
      self._prefetch_on_device,
      self._auto_shard_dataset)
def distribute_dataset(self, dataset_fn):
  """Wraps the dataset produced by `dataset_fn` for distribution.

  Without a cluster spec, the dataset is built once via
  `self._call_dataset_fn` and wrapped in a PerDeviceDataset over
  `self._devices`, with its source device set to the resolved
  "/device:CPU:0". With a cluster spec, a deferred call (built with `partial`)
  is handed to a MultiWorkerDataset along with the worker device map and the
  prefetch flag.
  """
  if not self._cluster_spec:
    local_dataset = self._call_dataset_fn(dataset_fn)
    return values.PerDeviceDataset(
        local_dataset,
        self._devices,
        self._prefetch_on_device,
        source_device=device_util.resolve("/device:CPU:0"))

  deferred_dataset_fn = partial(self._call_dataset_fn, dataset_fn)
  return values.MultiWorkerDataset(
      deferred_dataset_fn,
      self._worker_device_map,
      self._prefetch_on_device)
def _test_iterator_no_prefetch(self, devices, dataset, expected_values):
  """Checks a non-prefetching PerDeviceDataset yields `expected_values` in order.

  Each entry of `expected_values` is compared with `assertEqual` against the
  list of per-device values from one `get_next()` step, ordered like
  `devices`. After the last entry, one more step must raise OutOfRangeError.
  """
  iterator = values.PerDeviceDataset(
      dataset, devices, prefetch_on_device=False).make_one_shot_iterator()

  def _fetch_step():
    # One get_next() step, evaluated on every device in `devices` order.
    step = iterator.get_next()
    return self.evaluate([values.select_device(d, step) for d in devices])

  for expected_value in expected_values:
    self.assertEqual(expected_value, _fetch_step())

  # The input is exhausted at this point.
  with self.assertRaises(errors.OutOfRangeError):
    _fetch_step()
def _test_iterator_with_prefetch(self, devices, dataset, expected_values):
  """Checks a prefetching PerDeviceDataset yields `expected_values` in order.

  Uses an initializable iterator, which is initialized before any step. Each
  entry of `expected_values` is compared with `assertEqual` against the list
  of per-device values from one `get_next()` step. After the last entry, one
  more step must raise OutOfRangeError. Does nothing when executing eagerly.
  """
  if context.executing_eagerly():
    return

  iterator = values.PerDeviceDataset(
      dataset, devices,
      prefetch_on_device=True).make_initializable_iterator()
  self.evaluate([iterator.initializer])

  def _fetch_step():
    # One get_next() step, evaluated on every device in `devices` order.
    step = iterator.get_next()
    return self.evaluate([values.select_device(d, step) for d in devices])

  for expected_value in expected_values:
    self.assertEqual(expected_value, _fetch_step())

  # The input is exhausted at this point.
  with self.assertRaises(errors.OutOfRangeError):
    _fetch_step()
def distribute_dataset(self, dataset_fn):
  """Distributes the dataset to each local GPU.

  The dataset is built via `self._call_dataset_fn` and replicated over
  `self._compute_devices` with on-device prefetching always enabled.
  """
  dataset = self._call_dataset_fn(dataset_fn)
  return values.PerDeviceDataset(
      dataset, self._compute_devices, prefetch_on_device=True)
def distribute_dataset(self, dataset_fn):
  """Builds the dataset via `self._call_dataset_fn` and replicates it.

  The result is a PerDeviceDataset over `self._devices`, using the instance's
  `_prefetch_on_device` setting.
  """
  source = self._call_dataset_fn(dataset_fn)
  return values.PerDeviceDataset(
      source, self._devices, prefetch_on_device=self._prefetch_on_device)
def distribute_dataset(self, dataset_fn):
  """Distributes the dataset to each local GPU.

  The dataset is built via `self._call_dataset_fn` and replicated (not
  sharded) over `self._devices` with on-device prefetching always enabled.
  """
  # TODO(yuefengz): shard the dataset.
  dataset = self._call_dataset_fn(dataset_fn)
  return values.PerDeviceDataset(
      dataset, self._devices, prefetch_on_device=True)
def distribute_dataset(self, dataset):
  """Replicates `dataset` over `self._devices` and returns a one-shot iterator.

  Unlike the other `distribute_dataset` variants, this one takes a dataset
  (not a function) and returns an iterator rather than a dataset wrapper.
  """
  return values.PerDeviceDataset(
      dataset, self._devices,
      prefetch_on_device=self._prefetch_on_device).make_one_shot_iterator()
def _replicated_input_fn():
  """Calls `input_fn` and replicates the returned Dataset across `devices`.

  `input_fn` and `devices` are free variables (presumably supplied by an
  enclosing function — confirm at the definition site). On-device prefetching
  is always enabled.

  Returns:
    A `values.PerDeviceDataset` wrapping the dataset returned by `input_fn`.

  Raises:
    ValueError: if `input_fn` does not return a `dataset_ops.Dataset`.
  """
  result = input_fn()
  if not isinstance(result, dataset_ops.Dataset):
    # Include the offending type so a misconfigured input_fn is easy to spot.
    raise ValueError(
        "input_fn should return an object of type Dataset for it to be "
        "replicated; got %s" % type(result).__name__)
  return values.PerDeviceDataset(result, devices, prefetch_on_device=True)