def build_multi_device_iterator(batch_size, num_splits, preprocess_fn, cpu_device, params, gpu_devices, dataset):
  """Creates a MultiDeviceIterator.

  Builds an input pipeline via `create_dataset`, wraps it in a
  `prefetching_ops.MultiDeviceIterator` that feeds `gpu_devices` from
  `cpu_device`, and registers the iterator's initializer in the
  TABLE_INITIALIZERS collection so it is run alongside table init.

  Args:
    batch_size: Global batch size across all splits.
    num_splits: Number of splits; must equal `len(gpu_devices)`.
    preprocess_fn: Preprocessing function forwarded to `create_dataset`.
    cpu_device: Source device hosting the dataset.
    params: Config object; this function reads `eval`, `cache_data`,
      `datasets_num_private_threads` and
      `multi_device_iterator_max_buffer_size`.
    gpu_devices: Destination devices for the iterator outputs.
    dataset: Dataset descriptor forwarded to `create_dataset`.

  Returns:
    A `prefetching_ops.MultiDeviceIterator`.
  """
  assert num_splits == len(gpu_devices)
  with tf.name_scope('batch_processing'):
    # Evaluation reads the validation subset; training reads the train subset.
    subset = 'validation' if params.eval else 'train'
    per_split_batch_size = batch_size // num_splits
    ds = create_dataset(
        batch_size,
        num_splits,
        per_split_batch_size,
        preprocess_fn,
        dataset,
        subset,
        train=(not params.eval),
        cache_data=params.cache_data,
        num_threads=params.datasets_num_private_threads)
    multi_device_iterator = prefetching_ops.MultiDeviceIterator(
        ds,
        gpu_devices,
        source_device=cpu_device,
        max_buffer_size=params.multi_device_iterator_max_buffer_size)
    # Piggyback on table initialization so callers need no extra init op.
    tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS,
                         multi_device_iterator.initializer)
    return multi_device_iterator
def testBasic(self):
  """Elements of a range dataset are distributed round-robin over two CPUs."""
  dataset = dataset_ops.Dataset.range(10)
  multi_device_iterator = prefetching_ops.MultiDeviceIterator(
      dataset, ["/cpu:1", "/cpu:2"])
  elem_on_1, elem_on_2 = multi_device_iterator.get_next()

  config = config_pb2.ConfigProto(device_count={"CPU": 3})
  with self.test_session(config=config) as sess:
    sess.run(multi_device_iterator.initializer)
    for i in range(0, 10, 2):
      self.assertEqual(i, sess.run(elem_on_1))
      self.assertEqual(i + 1, sess.run(elem_on_2))
    # Bug fix: the original ran both get-next ops inside a single
    # assertRaises block, so the second sess.run never executed once the
    # first raised — exhaustion was only verified for one device.
    # Assert each device's op raises independently.
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(elem_on_1)
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(elem_on_2)
def testMultipleInitializations(self):
  """Re-initializing the iterator picks up a fresh placeholder value."""
  with ops.device("/cpu:0"):
    epoch = array_ops.placeholder(dtypes.int64, shape=[])
    epoch_ds = dataset_ops.Dataset.from_tensors(epoch).repeat(1000)
    index_ds = dataset_ops.Dataset.range(1000)
    zipped = dataset_ops.Dataset.zip((epoch_ds, index_ds))
  multi_device_iterator = prefetching_ops.MultiDeviceIterator(
      zipped, ["/cpu:1", "/cpu:2"], prefetch_buffer_size=4)
  first_elem, second_elem = multi_device_iterator.get_next()
  init_op = multi_device_iterator.initializer

  config = config_pb2.ConfigProto(device_count={"CPU": 3})
  with self.test_session(config=config) as sess:
    for i in range(1000):
      # Each initialization feeds a new epoch value; the first two
      # elements fetched afterwards must reflect it.
      sess.run(init_op, feed_dict={epoch: i})
      self.assertEqual([(i, 0), (i, 1)],
                       sess.run([first_elem, second_elem]))
def testBasicGpu(self):
  """Elements are distributed round-robin over one CPU and one GPU."""
  if not test_util.is_gpu_available():
    self.skipTest("No GPU available")

  with compat.forward_compatibility_horizon(2018, 8, 4):
    dataset = dataset_ops.Dataset.range(10)
    multi_device_iterator = prefetching_ops.MultiDeviceIterator(
        dataset, ["/cpu:1", "/gpu:0"])
    elem_on_1, elem_on_2 = multi_device_iterator.get_next()

    config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
    with self.test_session(config=config) as sess:
      sess.run(multi_device_iterator.initializer)
      for i in range(0, 10, 2):
        self.assertEqual(i, sess.run(elem_on_1))
        self.assertEqual(i + 1, sess.run(elem_on_2))
      # Bug fix: the original grouped both sess.run calls under one
      # assertRaises, so only the first device's exhaustion was actually
      # checked. Verify each op raises OutOfRangeError on its own.
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(elem_on_1)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(elem_on_2)
def testRepeatDevices(self):
  """A device may appear multiple times in the device list."""
  with ops.device("/cpu:0"):
    dataset = dataset_ops.Dataset.range(20)
    multi_device_iterator = prefetching_ops.MultiDeviceIterator(
        dataset, ["/cpu:1", "/cpu:2", "/cpu:1", "/cpu:2"])
    elements = multi_device_iterator.get_next()
    elem_on_1, elem_on_2, elem_on_3, elem_on_4 = elements

    config = config_pb2.ConfigProto(device_count={"CPU": 3})
    with self.test_session(config=config) as sess:
      sess.run(multi_device_iterator.initializer)
      for i in range(0, 20, 4):
        self.assertEqual(i, sess.run(elem_on_1))
        self.assertEqual(i + 1, sess.run(elem_on_2))
        self.assertEqual(i + 2, sess.run(elem_on_3))
        self.assertEqual(i + 3, sess.run(elem_on_4))
      # Bug fix: the original put all four sess.run calls inside a single
      # assertRaises block, so the last three never executed after the
      # first raised — exhaustion was unverified for three of the four
      # slots. Assert each get-next op raises independently.
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(elem_on_1)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(elem_on_2)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(elem_on_3)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(elem_on_4)