def make_one_shot_iterator(self):
  """Returns a single-use iterator over the distributed PerDeviceDataset.

  The underlying iterator depends on configuration and execution mode:
    * prefetching enabled: `prefetch_to_devices(self._devices)` is applied to
      `self._dataset` and a one-shot iterator is built from the result;
    * no prefetching, eager execution: `self._dataset` is wrapped in a
      `datasets.Iterator`;
    * no prefetching, graph execution: a plain one-shot iterator over
      `self._dataset`.

  Returns:
    A `PerDeviceDataIterator` built from the chosen iterator, `self._devices`
    and `self._prefetch_on_device`.
  """
  prefetch = self._prefetch_on_device
  # The eager-mode check is only reached when prefetching is off.
  if not prefetch and context.executing_eagerly():
    iterator = datasets.Iterator(self._dataset)
  else:
    source = self._dataset
    if prefetch:
      # NOTE(review): if the matching __init__ already applied
      # prefetch_to_devices to self._dataset, this applies it a second
      # time -- confirm against the constructor this method ships with.
      source = source.apply(
          prefetching_ops_v2.prefetch_to_devices(self._devices))
    iterator = source.make_one_shot_iterator()
  return PerDeviceDataIterator(
      iterator, self._devices, self._prefetch_on_device)
def testPrefetchToOneDevice(self):
  """Checks that range(10) prefetched onto /gpu:0 yields 0..9, then ends."""
  if not test_util.is_gpu_available():
    self.skipTest("No GPU available")

  source = dataset_ops.Dataset.range(10)
  on_gpu = source.apply(prefetching_ops_v2.prefetch_to_devices("/gpu:0"))
  get_next = on_gpu.make_one_shot_iterator().get_next()

  with self.cached_session() as sess:
    for expected in range(10):
      self.assertEqual(expected, sess.run(get_next))
    # The dataset is exhausted, so one more fetch must signal end of input.
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(get_next)
def testPrefetchToOneDevice(self):
  """Checks that range(10) prefetched onto /gpu:0 yields 0..9, then ends."""
  if not test_util.is_gpu_available():
    self.skipTest("No GPU available")
  host_dataset = dataset_ops.Dataset.range(10)
  device_dataset = host_dataset.apply(
      prefetching_ops_v2.prefetch_to_devices("/gpu:0"))
  iterator = device_dataset.make_one_shot_iterator()
  next_element = iterator.get_next()
  # `test_session` is deprecated; `cached_session` is the documented
  # replacement.
  with self.cached_session() as sess:
    for i in range(10):
      self.assertEqual(i, sess.run(next_element))
    # The dataset is exhausted, so one more fetch must signal end of input.
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(next_element)
def __init__(self, dataset, devices, prefetch_on_device=None):
  """Prepares `dataset` for consumption across `devices`.

  Args:
    dataset: the input dataset. Only `.apply(...)` and `.batch(...)` are
      called on it here (presumably a `tf.data.Dataset` -- confirm).
    devices: the target devices. Passed to `prefetch_to_devices` when
      prefetching; otherwise `len(devices)` is used as the batch size.
    prefetch_on_device: whether to prefetch onto `devices`. `None` means
      "prefetch iff not executing eagerly".

  Raises:
    AssertionError: if prefetching is requested while executing eagerly.
  """
  self._devices = devices
  # Default to using prefetching in graph mode, unless specified.
  # TODO(priyag): Enable prefetching in eager mode.
  self._prefetch_on_device = prefetch_on_device
  if self._prefetch_on_device is None:
    self._prefetch_on_device = not context.executing_eagerly()
  # Explicit raise instead of `assert`, so the check is not stripped under
  # `python -O`. AssertionError is kept so callers see the same exception
  # type as before.
  if self._prefetch_on_device and context.executing_eagerly():
    raise AssertionError(
        "Prefetching is only supported in graph mode currently")
  if self._prefetch_on_device:
    self._dataset = dataset.apply(
        prefetching_ops_v2.prefetch_to_devices(self._devices))
  else:
    # TODO(priyag): If dropping remainder is not appropriate, find another
    # approach to distributing the dataset when not possible to divide evenly.
    # Possibly not an issue when we start using PartitionedDataset.
    # Each batch holds len(devices) elements; a trailing partial batch is
    # dropped.
    # NOTE(review): if `devices` can be a single device string, len() counts
    # characters here -- confirm callers always pass a sequence.
    self._dataset = dataset.batch(len(devices), drop_remainder=True)
def __init__(self, dataset, devices, prefetch_on_device=None):
  """Prepares `dataset` for consumption across `devices`.

  Args:
    dataset: the input dataset. Only `.apply(...)` and `.batch(...)` are
      called on it here (presumably a `tf.data.Dataset` -- confirm).
    devices: the target devices. Passed to `prefetch_to_devices` when
      prefetching; otherwise `len(devices)` is used as the batch size.
    prefetch_on_device: whether to prefetch onto `devices`. `None` means
      "prefetch iff not executing eagerly".

  Raises:
    AssertionError: if prefetching is requested while executing eagerly.
  """
  self._devices = devices
  # Default to using prefetching in graph mode, unless specified.
  # TODO(priyag): Enable prefetching in eager mode.
  self._prefetch_on_device = prefetch_on_device
  if self._prefetch_on_device is None:
    self._prefetch_on_device = not context.executing_eagerly()
  # Explicit raise instead of `assert`, so the check is not stripped under
  # `python -O`. AssertionError is kept so callers see the same exception
  # type as before.
  if self._prefetch_on_device and context.executing_eagerly():
    raise AssertionError(
        "Prefetching is only supported in graph mode currently")
  if self._prefetch_on_device:
    self._dataset = dataset.apply(
        prefetching_ops_v2.prefetch_to_devices(self._devices))
  else:
    # TODO(priyag): If dropping remainder is not appropriate, find another
    # approach to distributing the dataset when not possible to divide evenly.
    # Possibly not an issue when we start using PartitionedDataset.
    # Each batch holds len(devices) elements; a trailing partial batch is
    # dropped.
    # NOTE(review): if `devices` can be a single device string, len() counts
    # characters here -- confirm callers always pass a sequence.
    self._dataset = dataset.batch(len(devices), drop_remainder=True)
def testPrefetchToTwoDevicesInAList(self):
  """Prefetching range(10) to two devices yields 5 pairs covering 0..9."""
  if not test_util.is_gpu_available():
    self.skipTest("No GPU available")
  host_dataset = dataset_ops.Dataset.range(10)
  device_dataset = host_dataset.apply(
      prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"]))
  iterator = device_dataset.make_one_shot_iterator()
  next_element = iterator.get_next()
  output = []
  # `test_session` is deprecated in favor of `cached_session`.
  with self.cached_session() as sess:
    for _ in range(5):
      result = sess.run(next_element)
      # One value per target device in each fetch.
      self.assertEqual(2, len(result))
      output.extend(result)
    # `assertEquals` is a deprecated alias (removed in Python 3.12).
    self.assertEqual(set(range(10)), set(output))
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(next_element)
def testPrefetchToTwoDevicesWithReinit(self):
  """Re-running the initializer restarts a two-device prefetching iterator."""
  if not test_util.is_gpu_available():
    self.skipTest("No GPU available")

  source = dataset_ops.Dataset.range(10)
  split = source.apply(
      prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"]))
  iterator = split.make_initializable_iterator()
  get_next = iterator.get_next()

  # TODO(rohanj): Modify test to go till the end of the dataset when we
  # switch to MultiDeviceIterator.
  with self.cached_session() as sess:
    # Two identical passes: (re)initialize, then perform four fetches.
    for _pass in range(2):
      sess.run(iterator.initializer)
      for _ in range(4):
        sess.run(get_next)
def testPrefetchToTwoDevicesInAList(self):
  """Four fetches from a two-device prefetch of range(10) cover 0..7."""
  if not test_util.is_gpu_available():
    self.skipTest("No GPU available")
  host_dataset = dataset_ops.Dataset.range(10)
  device_dataset = host_dataset.apply(
      prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"]))
  iterator = device_dataset.make_one_shot_iterator()
  next_element = iterator.get_next()
  output = []
  # TODO(rohanj): Modify test to go till the end of the dataset when we
  # switch to MultiDeviceIterator.
  with self.cached_session() as sess:
    for _ in range(4):
      result = sess.run(next_element)
      # One value per target device in each fetch.
      self.assertEqual(2, len(result))
      output.extend(result)
    # `assertEquals` is a deprecated alias (removed in Python 3.12).
    self.assertEqual(set(range(8)), set(output))
def testPrefetchToTwoDevicesWithReinit(self):
  """Exhausts a two-device prefetching iterator, re-initializes, reads again."""
  if not test_util.is_gpu_available():
    self.skipTest("No GPU available")
  host_dataset = dataset_ops.Dataset.range(10)
  device_dataset = host_dataset.apply(
      prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"]))
  iterator = device_dataset.make_initializable_iterator()
  next_element = iterator.get_next()
  # `test_session` is deprecated; `cached_session` is the documented
  # replacement.
  with self.cached_session() as sess:
    sess.run(iterator.initializer)
    for _ in range(5):
      sess.run(next_element)
    # The first pass is exhausted after five fetches.
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(next_element)
    # Re-initializing must make the dataset readable again.
    sess.run(iterator.initializer)
    for _ in range(5):
      sess.run(next_element)
def testPrefetchToTwoDevicesInAList(self):
  """Prefetching range(10) to two devices yields 5 pairs covering 0..9."""
  if not test_util.is_gpu_available():
    self.skipTest("No GPU available")
  host_dataset = dataset_ops.Dataset.range(10)
  device_dataset = host_dataset.apply(
      prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"]))
  iterator = device_dataset.make_one_shot_iterator()
  next_element = iterator.get_next()
  output = []
  # `test_session` is deprecated in favor of `cached_session`.
  with self.cached_session() as sess:
    for _ in range(5):
      result = sess.run(next_element)
      # One value per target device in each fetch.
      self.assertEqual(2, len(result))
      output.extend(result)
    # `assertEquals` is a deprecated alias (removed in Python 3.12).
    self.assertEqual(set(range(10)), set(output))
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(next_element)