Esempio n. 1
0
def build_multi_device_iterator(batch_size, num_splits, preprocess_fn,
                                cpu_device, params, gpu_devices, dataset):
    """Build a MultiDeviceIterator feeding `gpu_devices` from `cpu_device`.

    The input pipeline comes from `create_dataset`, using the 'validation'
    subset when `params.eval` is truthy and the 'train' subset otherwise.
    The iterator's initializer is appended to the
    `tf.GraphKeys.TABLE_INITIALIZERS` collection before returning.

    Args:
        batch_size: Total batch size, forwarded to `create_dataset`.
        num_splits: Number of splits; must equal `len(gpu_devices)`.
        preprocess_fn: Forwarded unchanged to `create_dataset`.
        cpu_device: Passed to the iterator as `source_device`.
        params: Object providing `eval`, `cache_data`,
            `datasets_num_private_threads` and
            `multi_device_iterator_max_buffer_size`.
        gpu_devices: Devices handed to the MultiDeviceIterator.
        dataset: Forwarded unchanged to `create_dataset`.

    Returns:
        The constructed `prefetching_ops.MultiDeviceIterator`.

    Raises:
        AssertionError: If `num_splits` differs from `len(gpu_devices)`.
    """
    assert num_splits == len(gpu_devices)
    with tf.name_scope('batch_processing'):
        subset = 'validation' if params.eval else 'train'
        # Integer division: any remainder of batch_size is dropped here.
        per_split_batch = batch_size // num_splits
        input_ds = create_dataset(
            batch_size,
            num_splits,
            per_split_batch,
            preprocess_fn,
            dataset,
            subset,
            train=(not params.eval),
            cache_data=params.cache_data,
            num_threads=params.datasets_num_private_threads)
        iterator = prefetching_ops.MultiDeviceIterator(
            input_ds,
            gpu_devices,
            source_device=cpu_device,
            max_buffer_size=params.multi_device_iterator_max_buffer_size)
        # NOTE(review): presumably registered here so the iterator is
        # initialized together with the table initializers -- confirm.
        tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS,
                             iterator.initializer)
    return iterator
Esempio n. 2
0
    def testBasic(self):
        """Check that range(10) is split alternately across two CPU devices.

        The first device is expected to yield 0, 2, 4, 6, 8 and the second
        1, 3, 5, 7, 9. Afterwards a further fetch on either device must
        raise OutOfRangeError.
        """
        dataset = dataset_ops.Dataset.range(10)
        multi_device_iterator = prefetching_ops.MultiDeviceIterator(
            dataset, ["/cpu:1", "/cpu:2"])
        elem_on_1, elem_on_2 = multi_device_iterator.get_next()

        config = config_pb2.ConfigProto(device_count={"CPU": 3})
        with self.test_session(config=config) as sess:
            sess.run(multi_device_iterator.initializer)
            for i in range(0, 10, 2):
                self.assertEqual(i, sess.run(elem_on_1))
                self.assertEqual(i + 1, sess.run(elem_on_2))
            # One assertRaises context per fetch: inside a shared context the
            # first raise leaves the block, so later fetches would never run.
            for exhausted_elem in (elem_on_1, elem_on_2):
                with self.assertRaises(errors.OutOfRangeError):
                    sess.run(exhausted_elem)
Esempio n. 3
0
    def testMultipleInitializations(self):
        """Re-run the initializer 1000 times with a different fed epoch.

        The dataset zips a constant `epoch` placeholder value with
        range(1000). After each initialization the first element on each of
        the two devices must be (epoch, 0) and (epoch, 1) respectively.
        """
        with ops.device("/cpu:0"):
            epoch = array_ops.placeholder(dtypes.int64, shape=[])
            epoch_ds = dataset_ops.Dataset.from_tensors(epoch).repeat(1000)
            counter_ds = dataset_ops.Dataset.range(1000)
            zipped_ds = dataset_ops.Dataset.zip((epoch_ds, counter_ds))
        iterator = prefetching_ops.MultiDeviceIterator(
            zipped_ds, ["/cpu:1", "/cpu:2"], prefetch_buffer_size=4)
        first_elem, second_elem = iterator.get_next()
        initializer = iterator.initializer

        session_config = config_pb2.ConfigProto(device_count={"CPU": 3})
        with self.test_session(config=session_config) as sess:
            for epoch_value in range(1000):
                sess.run(initializer, feed_dict={epoch: epoch_value})
                expected = [(epoch_value, 0), (epoch_value, 1)]
                fetched = sess.run([first_elem, second_elem])
                self.assertEqual(expected, fetched)
Esempio n. 4
0
    def testBasicGpu(self):
        """Check that range(10) is split alternately across a CPU and a GPU.

        Skipped when no GPU is available. "/cpu:1" is expected to yield the
        even values and "/gpu:0" the odd ones. Afterwards a further fetch on
        either device must raise OutOfRangeError.
        """
        if not test_util.is_gpu_available():
            self.skipTest("No GPU available")

        with compat.forward_compatibility_horizon(2018, 8, 4):
            dataset = dataset_ops.Dataset.range(10)
            multi_device_iterator = prefetching_ops.MultiDeviceIterator(
                dataset, ["/cpu:1", "/gpu:0"])
            elem_on_1, elem_on_2 = multi_device_iterator.get_next()

            config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
            with self.test_session(config=config) as sess:
                sess.run(multi_device_iterator.initializer)
                for i in range(0, 10, 2):
                    self.assertEqual(i, sess.run(elem_on_1))
                    self.assertEqual(i + 1, sess.run(elem_on_2))
                # One assertRaises context per fetch: inside a shared context
                # the first raise leaves the block, so the second fetch would
                # never run.
                for exhausted_elem in (elem_on_1, elem_on_2):
                    with self.assertRaises(errors.OutOfRangeError):
                        sess.run(exhausted_elem)
Esempio n. 5
0
    def testRepeatDevices(self):
        """Check element distribution when each device is listed twice.

        With devices ["/cpu:1", "/cpu:2", "/cpu:1", "/cpu:2"] and range(20),
        the four outputs are expected to yield i, i + 1, i + 2 and i + 3 for
        i in 0, 4, ..., 16. Afterwards a further fetch on any output must
        raise OutOfRangeError.
        """
        with ops.device("/cpu:0"):
            dataset = dataset_ops.Dataset.range(20)
        multi_device_iterator = prefetching_ops.MultiDeviceIterator(
            dataset, ["/cpu:1", "/cpu:2", "/cpu:1", "/cpu:2"])
        elements = multi_device_iterator.get_next()
        elem_on_1, elem_on_2, elem_on_3, elem_on_4 = elements

        config = config_pb2.ConfigProto(device_count={"CPU": 3})
        with self.test_session(config=config) as sess:
            sess.run(multi_device_iterator.initializer)
            for i in range(0, 20, 4):
                self.assertEqual(i, sess.run(elem_on_1))
                self.assertEqual(i + 1, sess.run(elem_on_2))
                self.assertEqual(i + 2, sess.run(elem_on_3))
                self.assertEqual(i + 3, sess.run(elem_on_4))
            # One assertRaises context per fetch: inside a shared context the
            # first raise leaves the block, so the other three fetches would
            # never run.
            for exhausted_elem in (elem_on_1, elem_on_2, elem_on_3, elem_on_4):
                with self.assertRaises(errors.OutOfRangeError):
                    sess.run(exhausted_elem)