Example #1
0
def build_multi_device_iterator(batch_size, num_splits, preprocess_fn,
                                cpu_device, params, gpu_devices, dataset):
    """Build a MultiDeviceIterator that feeds `gpu_devices` from `dataset`.

    Each of the `num_splits` devices receives a split of size
    ``batch_size // num_splits``. The iterator's initializer is registered in
    the TABLE_INITIALIZERS collection so it runs with standard init ops.
    """
    assert num_splits == len(gpu_devices)
    with tf.name_scope('batch_processing'):
        # Evaluation reads the validation subset; otherwise train.
        subset = 'validation' if params.eval else 'train'
        per_split_batch = batch_size // num_splits
        ds = create_dataset(batch_size,
                            num_splits,
                            per_split_batch,
                            preprocess_fn,
                            dataset,
                            subset,
                            train=(not params.eval),
                            cache_data=params.cache_data,
                            num_threads=params.datasets_num_private_threads)
        iterator = multi_device_iterator_ops.MultiDeviceIterator(
            ds,
            gpu_devices,
            source_device=cpu_device,
            max_buffer_size=params.multi_device_iterator_max_buffer_size)
        # Piggy-back on table initialization so callers need no extra init op.
        tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS,
                             iterator.initializer)
        return iterator
Example #2
0
 def dataset_producer(t):
     """Batch `t` in pairs and return the first element, as an Optional, on GPU:0.

     NOTE(review): reads `drop_remainder` from the enclosing scope — confirm it
     is defined by the surrounding test.
     """
     ragged_ds = dataset_ops.Dataset.from_tensor_slices(t).batch(
         2, drop_remainder)
     it = multi_device_iterator_ops.MultiDeviceIterator(
         ragged_ds, ['GPU:0'])
     # Fetch the element on the GPU device the iterator targets.
     with ops.device_v2('GPU:0'):
         return it.get_next_as_optional()
    def testGetNextAsOptional(self):
        """get_next_as_optional yields values until exhaustion, then empty Optionals."""
        # Graph-mode-only test: it relies on session-based re-evaluation below.
        if context.executing_eagerly():
            return

        dataset = dataset_ops.Dataset.range(9)
        multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset, ["/cpu:1", "/cpu:2"])
        elem_on_1, elem_on_2 = multi_device_iterator.get_next_as_optional()
        elem_on_1_has_value_t = elem_on_1.has_value()
        elem_on_1_t = elem_on_1.get_value()
        elem_on_2_has_value_t = elem_on_2.has_value()
        elem_on_2_t = elem_on_2.get_value()

        # Three CPUs: /cpu:0 hosts the dataset; /cpu:1 and /cpu:2 consume it.
        config = config_pb2.ConfigProto(device_count={"CPU": 3})
        with self.test_session(config=config) as sess:
            self.evaluate(multi_device_iterator.initializer)
            # The first 8 of 9 elements arrive in pairs, one per device.
            for i in range(0, 8, 2):
                elem_on_1_has_value, elem_on_1_value = sess.run(
                    [elem_on_1_has_value_t, elem_on_1_t])
                self.assertTrue(elem_on_1_has_value)
                self.assertEqual(i, elem_on_1_value)
                elem_on_2_has_value, elem_on_2_value = sess.run(
                    [elem_on_2_has_value_t, elem_on_2_t])
                self.assertTrue(elem_on_2_has_value)
                self.assertEqual(i + 1, elem_on_2_value)
            # The ninth element goes to device 1 only; device 2 is then empty.
            elem_on_1_has_value, elem_on_1_value = sess.run(
                [elem_on_1_has_value_t, elem_on_1_t])
            self.assertTrue(elem_on_1_has_value)
            self.assertEqual(8, elem_on_1_value)
            self.assertFalse(self.evaluate(elem_on_1_has_value_t))
            self.assertFalse(self.evaluate(elem_on_2_has_value_t))
            # Reading the value of an empty Optional raises.
            with self.assertRaises(errors.InvalidArgumentError):
                self.evaluate(elem_on_1_t)
            with self.assertRaises(errors.InvalidArgumentError):
                self.evaluate(elem_on_2_t)
Example #4
0
  def testPrefetchWithSlackOption(self):
    """Determines slack_period based on num devices attached to iterator."""
    dataset = dataset_ops.Dataset.range(10)
    dataset = dataset.prefetch(1)
    options = dataset_ops.Options()
    options.experimental_slack = True
    dataset = dataset.with_options(options)
    multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
        dataset, ["/cpu:1", "/cpu:2"])
    dataset = multi_device_iterator._dataset  # pylint: disable=protected-access
    # With two attached devices the slack rewrite should use slack_period=2.
    self.assertIn("slack", dataset.options()._graph_rewrites().enabled)
    self.assertIn("slack:slack_period:2",
                  dataset.options()._graph_rewrite_configs())

    config = config_pb2.ConfigProto(device_count={"CPU": 3})
    with self.test_session(config=config):
      self.evaluate(multi_device_iterator.initializer)
      # Elements still round-robin across the two devices as usual.
      for i in range(0, 10, 2):
        elem_on_1, elem_on_2 = multi_device_iterator.get_next()
        self.assertEqual(i, self.evaluate(elem_on_1))
        self.assertEqual(i + 1, self.evaluate(elem_on_2))
      with self.assertRaises(errors.OutOfRangeError):
        elem_on_1, elem_on_2 = multi_device_iterator.get_next()
        self.evaluate(elem_on_1)
        self.evaluate(elem_on_2)
    def testGetNextAsOptional(self):
        """Optionals report has_value/get_value correctly through exhaustion."""
        dataset = dataset_ops.Dataset.range(10)
        multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset, [self._devices[1], self._devices[2]])

        self.evaluate(multi_device_iterator.initializer)
        # Ten elements round-robin across the two devices in pairs.
        for i in range(0, 10, 2):
            elem_on_1, elem_on_2 = multi_device_iterator.get_next_as_optional()
            has_elem_1, get_elem_1 = self.evaluate(
                [elem_on_1.has_value(),
                 elem_on_1.get_value()])
            has_elem_2, get_elem_2 = self.evaluate(
                [elem_on_2.has_value(),
                 elem_on_2.get_value()])
            self.assertTrue(has_elem_1)
            self.assertEqual(i, get_elem_1)
            self.assertTrue(has_elem_2)
            self.assertEqual(i + 1, get_elem_2)
        # After exhaustion both Optionals are empty ...
        elem_on_1, elem_on_2 = multi_device_iterator.get_next_as_optional()
        has_elem_1 = elem_on_1.has_value()
        has_elem_2 = elem_on_2.has_value()
        self.assertFalse(self.evaluate(has_elem_1))
        self.assertFalse(self.evaluate(has_elem_2))
        # ... and reading their values raises InvalidArgumentError.
        with self.assertRaises(errors.InvalidArgumentError):
            elem_1 = elem_on_1.get_value()
            self.evaluate(elem_1)
        with self.assertRaises(errors.InvalidArgumentError):
            elem_2 = elem_on_2.get_value()
            self.evaluate(elem_2)
Example #6
0
 def _make_iterator(self):
     """Create the MultiDeviceIterator for this worker's dataset.

     When options are supplied, both max_buffer_size and prefetch_buffer_size
     are taken from ``experimental_per_replica_buffer_size``; otherwise the
     op's defaults apply.
     """
     buffer_kwargs = {}
     if self._options is not None:
         per_replica = self._options.experimental_per_replica_buffer_size
         buffer_kwargs['max_buffer_size'] = per_replica
         buffer_kwargs['prefetch_buffer_size'] = per_replica
     # Pin iterator construction to the worker device.
     with ops.device(self._worker):
         self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
             self._dataset, self._devices, **buffer_kwargs)
Example #7
0
    def testGetNextAsOptionalGpu(self):
        """Same Optional semantics when one target device is a GPU."""
        if not test_util.is_gpu_available():
            self.skipTest("No GPU available")

        dataset = dataset_ops.Dataset.range(9)
        multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset, ["/cpu:1", "/gpu:0"])
        elem_on_1, elem_on_2 = multi_device_iterator.get_next_as_optional()
        elem_on_1_has_value_t = elem_on_1.has_value()
        elem_on_1_t = elem_on_1.get_value()
        elem_on_2_has_value_t = elem_on_2.has_value()
        elem_on_2_t = elem_on_2.get_value()

        config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
        with self.test_session(config=config) as sess:
            sess.run(multi_device_iterator.initializer)
            # First 8 of 9 elements arrive in pairs: CPU gets even, GPU odd.
            for i in range(0, 8, 2):
                elem_on_1_has_value, elem_on_1_value = sess.run(
                    [elem_on_1_has_value_t, elem_on_1_t])
                self.assertTrue(elem_on_1_has_value)
                self.assertEqual(i, elem_on_1_value)
                elem_on_2_has_value, elem_on_2_value = sess.run(
                    [elem_on_2_has_value_t, elem_on_2_t])
                self.assertTrue(elem_on_2_has_value)
                self.assertEqual(i + 1, elem_on_2_value)
            # The odd ninth element lands on the CPU stream only.
            elem_on_1_has_value, elem_on_1_value = sess.run(
                [elem_on_1_has_value_t, elem_on_1_t])
            self.assertTrue(elem_on_1_has_value)
            self.assertEqual(8, elem_on_1_value)
            self.assertFalse(sess.run(elem_on_1_has_value_t))
            self.assertFalse(sess.run(elem_on_2_has_value_t))
            # Empty Optionals raise when their value is read.
            with self.assertRaises(errors.InvalidArgumentError):
                sess.run(elem_on_1_t)
            with self.assertRaises(errors.InvalidArgumentError):
                sess.run(elem_on_2_t)
    def testInitOnly(self, num_inits):
        """Initializing repeatedly must succeed even if nothing is consumed."""
        dataset = dataset_ops.Dataset.range(10)
        iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset, [self._devices[1], self._devices[2]])

        # Re-initialization is exercised num_inits times with no get_next().
        for _ in range(num_inits):
            self.evaluate(iterator.initializer)
 def f():
     """Build, step, and drop a MultiDeviceIterator (memory-check helper)."""
     ds = dataset_ops.Dataset.range(10)
     it = multi_device_iterator_ops.MultiDeviceIterator(
         ds, ["/cpu:1", "/cpu:2"])
     self.evaluate(it.get_next())
     # Drop references explicitly so per-call memory can be reclaimed.
     del it
     del ds
Example #10
0
 def build_multi_device_iterator(self, batch_size, num_splits, cpu_device,
                                 params, gpu_devices, dataset):
   """Creates a MultiDeviceIterator."""
   assert self.supports_datasets()
   assert num_splits == len(gpu_devices)
   with tf.name_scope('batch_processing'):
     # Evaluation reads the validation subset; otherwise train.
     if params.eval:
       subset = 'validation'
     else:
       subset = 'train'
     batch_size_per_split = batch_size // num_splits
     ds = self.create_dataset(
         batch_size,
         num_splits,
         batch_size_per_split,
         dataset,
         subset,
         train=(not params.eval),
         datasets_repeat_cached_sample=params.datasets_repeat_cached_sample,
         num_threads=params.datasets_num_private_threads,
         datasets_use_caching=params.datasets_use_caching,
         datasets_parallel_interleave_cycle_length=(
             params.datasets_parallel_interleave_cycle_length),
         datasets_sloppy_parallel_interleave=(
             params.datasets_sloppy_parallel_interleave))
     multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
         ds,
         gpu_devices,
         source_device=cpu_device,
         max_buffer_size=params.multi_device_iterator_max_buffer_size)
     # Register the initializer so it runs with the table initializers.
     tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS,
                          multi_device_iterator.initializer)
     return multi_device_iterator
Example #11
0
    def testNoGetNext(self):
        """An iterator may be initialized and discarded without consumption."""
        dataset = dataset_ops.Dataset.range(10)
        iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset, ["/cpu:1", "/cpu:2"])

        # Three CPUs: /cpu:0 as source plus the two target devices.
        config = config_pb2.ConfigProto(device_count={"CPU": 3})
        with self.test_session(config=config) as sess:
            sess.run(iterator.initializer)
    def testInitOnly(self, num_inits):
        """Repeated initialization with no reads must not fail."""
        dataset = dataset_ops.Dataset.range(10)
        iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset, ["/cpu:1", "/cpu:2"])

        # Three CPUs: /cpu:0 as source plus the two target devices.
        config = config_pb2.ConfigProto(device_count={"CPU": 3})
        with self.test_session(config=config):
            for _ in range(num_inits):
                self.evaluate(iterator.initializer)
  def testMultipleInitializationsEager(self):
    """Recreating the iterator restarts the zipped dataset from the beginning."""
    dataset = dataset_ops.Dataset.zip(
        (dataset_ops.Dataset.range(1000), dataset_ops.Dataset.range(1000)))

    for _ in range(5):
      iterator = multi_device_iterator_ops.MultiDeviceIterator(
          dataset, [self._devices[1], self._devices[2]], prefetch_buffer_size=4)
      self.evaluate(iterator.initializer)
      first, second = iterator.get_next()
      # A freshly initialized iterator always yields (0, 0) then (1, 1).
      self.assertEqual([(0, 0), (1, 1)], self.evaluate([first, second]))
    def testEagerMemoryUsageWithReset(self):
        """_eager_reset must not leak memory across repeated step/reset cycles."""
        if memory_profiler is None:
            self.skipTest("memory_profiler required to run this test")

        dataset = dataset_ops.Dataset.range(10)
        iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset, [self._devices[1], self._devices[2]])

        def step():
            # One read followed by a reset of the underlying eager iterator.
            self.evaluate(iterator.get_next())
            iterator._eager_reset()

        self.assertMemoryNotIncreasing(step, num_iters=50, max_increase_mb=250)
    def testEagerMemoryUsageWithReset(self):
        """Memory stays bounded across repeated get_next/_eager_reset cycles."""
        if memory_profiler is None:
            self.skipTest("memory_profiler required to run this test")

        dataset = dataset_ops.Dataset.range(10)
        multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset, ["/cpu:1", "/cpu:2"])

        def f():
            # One read followed by a reset of the underlying eager iterator.
            self.evaluate(multi_device_iterator.get_next())
            multi_device_iterator._eager_reset()

        self.assertNotIncreasingMemory(f,
                                       num_iters=100,
                                       increase_threshold_absolute_mb=350)
    def testMultipleInitializationsEager(self):
        """Each freshly created iterator starts the dataset from the beginning."""
        # Eager-only test; graph mode is covered elsewhere.
        if not context.executing_eagerly():
            return

        with ops.device("/cpu:0"):
            dataset1 = dataset_ops.Dataset.range(1000)
            dataset2 = dataset_ops.Dataset.range(1000)
            dataset = dataset_ops.Dataset.zip((dataset1, dataset2))

        for _ in range(5):
            multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
                dataset, ["/cpu:1", "/cpu:2"], prefetch_buffer_size=4)
            elem_on_1, elem_on_2 = multi_device_iterator.get_next()
            # A new iterator always yields (0, 0) then (1, 1).
            self.assertEqual([(0, 0), (1, 1)],
                             self.evaluate([elem_on_1, elem_on_2]))
Example #17
0
 def make_initializable_iterator(self):
   """Get an initializable iterator for the distributed PerReplicaDataset."""
   # Eager mode generates already initialized iterators. Hence we cannot create
   # an initializable iterator.
   if context.executing_eagerly():
     raise ValueError("Cannot create initializable iterator in Eager mode. "
                      "Please use `make_one_shot_iterator` instead.")
   # With on-device prefetching, use a MultiDeviceIterator so each replica
   # device gets its own buffered stream; otherwise use a plain iterator.
   if self._prefetch_on_device:
     dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator(
         self._dataset, self._replica_devices)
   else:
     dataset_iterator = dataset_ops.make_initializable_iterator(self._dataset)
   return PerReplicaDataIterator(
       dataset_iterator, self._input_workers, self._worker_index,
       prefetch_on_device=self._prefetch_on_device)
Example #18
0
    def testBasic(self):
        """Elements alternate between the two devices in round-robin order."""
        dataset = dataset_ops.Dataset.range(10)
        iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset, ["/cpu:1", "/cpu:2"])
        first, second = iterator.get_next()

        config = config_pb2.ConfigProto(device_count={"CPU": 3})
        with self.test_session(config=config) as sess:
            sess.run(iterator.initializer)
            for base in range(0, 10, 2):
                self.assertEqual(base, sess.run(first))
                self.assertEqual(base + 1, sess.run(second))
            # Once exhausted, further reads raise OutOfRangeError.
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(first)
                sess.run(second)
    def testRepeatDevices(self):
        """The same device may appear more than once in the device list."""
        dataset = dataset_ops.Dataset.range(10)
        # Both logical streams are placed on the same physical device.
        multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset, [self._devices[1], self._devices[1]])

        self.evaluate(multi_device_iterator.initializer)
        for i in range(0, 10, 2):
            elements = multi_device_iterator.get_next()
            elem_on_1, elem_on_2 = elements
            self.assertEqual(i, self.evaluate(elem_on_1))
            self.assertEqual(i + 1, self.evaluate(elem_on_2))
        # Once exhausted, further reads raise OutOfRangeError.
        with self.assertRaises(errors.OutOfRangeError):
            elements = multi_device_iterator.get_next()
            elem_on_1, elem_on_2 = elements
            self.evaluate(elem_on_1)
            self.evaluate(elem_on_2)
    def testOneOnSameDevice(self):
        """A target device may coincide with the iterator's source device."""
        dataset = dataset_ops.Dataset.range(12)
        # self._devices[0] is presumably the source device — three-way split.
        multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset, [self._devices[0], self._devices[1], self._devices[2]])

        self.evaluate(multi_device_iterator.initializer)
        for i in range(0, 12, 3):
            elem_on_0, elem_on_1, elem_on_2 = multi_device_iterator.get_next()
            self.assertEqual(i, self.evaluate(elem_on_0))
            self.assertEqual(i + 1, self.evaluate(elem_on_1))
            self.assertEqual(i + 2, self.evaluate(elem_on_2))
        # Once exhausted, further reads raise OutOfRangeError.
        with self.assertRaises(errors.OutOfRangeError):
            elem_on_0, elem_on_1, elem_on_2 = multi_device_iterator.get_next()
            self.evaluate(elem_on_0)
            self.evaluate(elem_on_1)
            self.evaluate(elem_on_2)
    def testUneven(self):
        """Per-device get_next allows draining one device before the other."""
        dataset = dataset_ops.Dataset.range(10)
        # max_buffer_size=4 lets device 2's elements queue while device 1 drains.
        multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset, [self._devices[1], self._devices[2]], max_buffer_size=4)

        self.evaluate(multi_device_iterator.initializer)
        for i in range(0, 10, 2):
            elem_on_1 = multi_device_iterator.get_next(self._devices[1])
            self.assertEqual(i, self.evaluate(elem_on_1))
        for i in range(0, 10, 2):
            elem_on_2 = multi_device_iterator.get_next(self._devices[2])
            self.assertEqual(i + 1, self.evaluate(elem_on_2))
        with self.assertRaises(errors.OutOfRangeError):
            elem_on_1, elem_on_2 = multi_device_iterator.get_next()
            self.evaluate(elem_on_1)
            self.evaluate(elem_on_2)
    def testNotFullyDivisible(self):
        """When elements don't divide evenly, the leftover goes to one device."""
        dataset = dataset_ops.Dataset.range(9)
        multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset, [self._devices[1], self._devices[2]])

        self.evaluate(multi_device_iterator.initializer)
        for i in range(0, 8, 2):
            elem_on_1, elem_on_2 = multi_device_iterator.get_next()
            self.assertEqual(i, self.evaluate(elem_on_1))
            self.assertEqual(i + 1, self.evaluate(elem_on_2))
        # The odd ninth element is fetched from device 1 only.
        elem_on_1 = multi_device_iterator.get_next(self._devices[1])
        self.assertEqual(8, self.evaluate(elem_on_1))
        with self.assertRaises(errors.OutOfRangeError):
            elem_on_1, elem_on_2 = multi_device_iterator.get_next()
            self.evaluate(elem_on_1)
            self.evaluate(elem_on_2)
Example #23
0
 def make_one_shot_iterator(self):
   """Get a one time use iterator for the distributed PerReplicaDataset."""
   # Graph mode with one shot iterator is disabled.
   if not context.executing_eagerly():
     raise ValueError("Cannot create a one shot iterator. Please use "
                      "`make_initializable_iterator()` instead.")
   # With on-device prefetching, use a MultiDeviceIterator so each replica
   # device gets its own buffered stream; otherwise use a plain iterator.
   if self._prefetch_on_device:
     dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator(
         self._dataset, self._replica_devices)
   else:
     dataset_iterator = dataset_ops.make_one_shot_iterator(self._dataset)
   return PerReplicaDataIterator(
       dataset_iterator,
       self._input_workers,
       self._worker_index,
       prefetch_on_device=self._prefetch_on_device)
Example #24
0
    def testMultipleInitializations(self):
        """Re-running the initializer restarts the stream with new feed values."""
        with ops.device("/cpu:0"):
            # `epoch` is fed at init time so each epoch's elements are tagged.
            epoch = array_ops.placeholder(dtypes.int64, shape=[])
            dataset1 = dataset_ops.Dataset.from_tensors(epoch).repeat(1000)
            dataset2 = dataset_ops.Dataset.range(1000)
            dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
        multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset, ["/cpu:1", "/cpu:2"], prefetch_buffer_size=4)
        elem_on_1, elem_on_2 = multi_device_iterator.get_next()
        init_op = multi_device_iterator.initializer

        config = config_pb2.ConfigProto(device_count={"CPU": 3})
        with self.test_session(config=config) as sess:
            for i in range(1000):
                sess.run(init_op, feed_dict={epoch: i})
                # After each re-init the range restarts at 0 with the new epoch.
                self.assertEqual([(i, 0), (i, 1)],
                                 sess.run([elem_on_1, elem_on_2]))
    def testOneOnSameDevice(self):
        """One target device may be the same device hosting the dataset."""
        with ops.device("/cpu:0"):
            dataset = dataset_ops.Dataset.range(10)
        # /cpu:0 is both the dataset's device and a consumer device.
        multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset, ["/cpu:0", "/cpu:1"])

        config = config_pb2.ConfigProto(device_count={"CPU": 2})
        with self.test_session(config=config):
            self.evaluate(multi_device_iterator.initializer)
            for i in range(0, 10, 2):
                elem_on_1, elem_on_2 = multi_device_iterator.get_next()
                self.assertEqual(i, self.evaluate(elem_on_1))
                self.assertEqual(i + 1, self.evaluate(elem_on_2))
            # Once exhausted, further reads raise OutOfRangeError.
            with self.assertRaises(errors.OutOfRangeError):
                elem_on_1, elem_on_2 = multi_device_iterator.get_next()
                self.evaluate(elem_on_1)
                self.evaluate(elem_on_2)
Example #26
0
    def testBasicGpu(self):
        """Round-robin distribution works when one target device is a GPU."""
        if not test_util.is_gpu_available():
            self.skipTest("No GPU available")

        dataset = dataset_ops.Dataset.range(10)
        multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset, ["/cpu:1", "/gpu:0"])
        elem_on_1, elem_on_2 = multi_device_iterator.get_next()

        config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
        with self.test_session(config=config) as sess:
            sess.run(multi_device_iterator.initializer)
            # CPU stream receives even elements, GPU stream odd ones.
            for i in range(0, 10, 2):
                self.assertEqual(i, sess.run(elem_on_1))
                self.assertEqual(i + 1, sess.run(elem_on_2))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(elem_on_1)
                sess.run(elem_on_2)
    def testUneven(self):
        """Per-device get_next allows draining one device before the other."""
        dataset = dataset_ops.Dataset.range(10)
        # max_buffer_size=4 lets /cpu:2's elements queue while /cpu:1 drains.
        multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset, ["/cpu:1", "/cpu:2"], max_buffer_size=4)

        config = config_pb2.ConfigProto(device_count={"CPU": 3})
        with self.test_session(config=config):
            self.evaluate(multi_device_iterator.initializer)
            for i in range(0, 10, 2):
                elem_on_1 = multi_device_iterator.get_next("/cpu:1")
                self.assertEqual(i, self.evaluate(elem_on_1))
            for i in range(0, 10, 2):
                elem_on_2 = multi_device_iterator.get_next("/cpu:2")
                self.assertEqual(i + 1, self.evaluate(elem_on_2))
            with self.assertRaises(errors.OutOfRangeError):
                elem_on_1, elem_on_2 = multi_device_iterator.get_next()
                self.evaluate(elem_on_1)
                self.evaluate(elem_on_2)
    def testNotFullyDivisible(self):
        """When elements don't divide evenly, the leftover goes to one device."""
        dataset = dataset_ops.Dataset.range(9)
        multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset, ["/cpu:1", "/cpu:2"])

        config = config_pb2.ConfigProto(device_count={"CPU": 3})
        with self.test_session(config=config):
            self.evaluate(multi_device_iterator.initializer)
            for i in range(0, 8, 2):
                elem_on_1, elem_on_2 = multi_device_iterator.get_next()
                self.assertEqual(i, self.evaluate(elem_on_1))
                self.assertEqual(i + 1, self.evaluate(elem_on_2))
            # The odd ninth element is fetched from /cpu:1 only.
            elem_on_1 = multi_device_iterator.get_next("/cpu:1")
            self.assertEqual(8, self.evaluate(elem_on_1))
            with self.assertRaises(errors.OutOfRangeError):
                elem_on_1, elem_on_2 = multi_device_iterator.get_next()
                self.evaluate(elem_on_1)
                self.evaluate(elem_on_2)
Example #29
0
 def testPrefetchWithSlackOption(self):
     """Determines slack_period based on num devices attached to iterator."""
     dataset = dataset_ops.Dataset.range(10)
     dataset = dataset.prefetch(1)
     options = dataset_ops.Options()
     # The slack option tunes prefetch timing based on consumer count.
     options.experimental_slack = True
     dataset = dataset.with_options(options)
     multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
         dataset, [self._devices[1], self._devices[2]])
     self.evaluate(multi_device_iterator.initializer)
     # Elements still round-robin across the two devices as usual.
     for i in range(0, 10, 2):
         elem_on_1, elem_on_2 = multi_device_iterator.get_next()
         self.assertEqual(i, self.evaluate(elem_on_1))
         self.assertEqual(i + 1, self.evaluate(elem_on_2))
     with self.assertRaises(errors.OutOfRangeError):
         elem_on_1, elem_on_2 = multi_device_iterator.get_next()
         self.evaluate(elem_on_1)
         self.evaluate(elem_on_2)
Example #30
0
    def testRepeatDevices(self):
        """Devices may repeat in the list; each repetition gets its own stream."""
        with ops.device("/cpu:0"):
            dataset = dataset_ops.Dataset.range(20)
        iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset, ["/cpu:1", "/cpu:2", "/cpu:1", "/cpu:2"])
        outputs = iterator.get_next()

        config = config_pb2.ConfigProto(device_count={"CPU": 3})
        with self.test_session(config=config) as sess:
            sess.run(iterator.initializer)
            for base in range(0, 20, 4):
                # Four streams consume consecutive elements per round.
                for offset, elem in enumerate(outputs):
                    self.assertEqual(base + offset, sess.run(elem))
            # Once exhausted, further reads raise OutOfRangeError.
            with self.assertRaises(errors.OutOfRangeError):
                for elem in outputs:
                    sess.run(elem)