def testLimitedRetracing(self):
    """Checks that a tf.function fed fresh MultiDeviceIteratorV2s is traced once.

    Every outer iteration builds two new iterators (over ``range(10)`` and
    ``range(20)``) and passes each to the same tf.function, which pulls five
    steps and sums both elements of each step. The Python-side trace counter
    is expected to remain at 1 for the whole test.
    """
    if not test_util.is_gpu_available():
        self.skipTest("No GPU available")

    # One-element list used as a mutable cell the traced body can increment.
    traces = [0]

    @def_function.function
    def sum_five_pairs(iterator):
        # Python statement: increments once per trace of this function.
        traces[0] += 1
        total = np.int64(0)
        for _ in range(5):
            step = next(iterator)
            total += step[0]
            total += step[1]
        return total

    short_ds = dataset_ops.Dataset.range(10)
    long_ds = dataset_ops.Dataset.range(20)

    for _ in range(10):
        # Five steps x two elements consume the first ten values of either
        # range, i.e. 0 + 1 + ... + 9 == 45.
        first_iter = multi_device_iterator_ops.MultiDeviceIteratorV2(
            short_ds, ["/cpu:0", "/gpu:0"])
        self.assertEqual(self.evaluate(sum_five_pairs(first_iter)), 45)
        second_iter = multi_device_iterator_ops.MultiDeviceIteratorV2(
            long_ds, ["/cpu:0", "/gpu:0"])
        self.assertEqual(self.evaluate(sum_five_pairs(second_iter)), 45)
        self.assertEqual(traces[0], 1)
 def fn():
     """Builds a two-device iterator over range(10) and enqueues five steps.

     For each step, the element for the first device ("/cpu:0") is enqueued
     before the element for the second ("/gpu:0").
     NOTE(review): `queue` is a free variable from the enclosing scope, which
     is not visible here.
     """
     with ops.device("/cpu:0"):
         source = dataset_ops.Dataset.range(10)
     per_device = multi_device_iterator_ops.MultiDeviceIteratorV2(
         source, ["/cpu:0", "/gpu:0"])
     for _ in range(5):
         on_cpu, on_gpu = next(per_device)
         queue.enqueue(on_cpu)
         queue.enqueue(on_gpu)
    def testMultipleInitializations(self):
        """Re-creates a MultiDeviceIteratorV2 several times over one dataset.

        Each fresh iterator must replay the dataset from the start: step ``i``
        yields ``i * 2`` for the first device and ``i * 2 + 1`` for the second.
        The steps are also counted, so an iterator that stops early (or yields
        nothing) cannot make the element checks pass vacuously.
        """
        if not test_util.is_gpu_available():
            self.skipTest("No GPU available")

        num_elements = 1000
        with ops.device("/cpu:0"):
            dataset = dataset_ops.Dataset.range(num_elements)

        for _ in range(5):
            multi_device_iterator = multi_device_iterator_ops.MultiDeviceIteratorV2(
                dataset, ["/cpu:0", "/gpu:0"])
            num_steps = 0
            for i, el in enumerate(multi_device_iterator):
                self.assertEqual([i * 2, i * 2 + 1],
                                 [el[0].numpy(), el[1].numpy()])
                num_steps = i + 1
            # Two devices consume two elements per step, so a complete pass
            # over range(num_elements) takes num_elements // 2 steps.
            self.assertEqual(num_elements // 2, num_steps)
 def fn():
     """Builds a multi-device iterator over a `_GeneratorDataset` and pulls one step.

     `init_fn`, `next_fn` and `finalize_fn` are free variables resolved from the
     enclosing scope, which is not visible here.
     NOTE(review): this fragment may continue past the visible lines; what the
     caller expects from `next(iterator)` (a value or an error) is not shown.
     """
     # `_GeneratorDataset` is a private dataset_ops helper; the leading `1` is
     # passed as its first argument (presumably the init_fn args -- confirm).
     dataset = dataset_ops._GeneratorDataset(1, init_fn, next_fn,
                                             finalize_fn)
     # Iterator targeting the devices "/cpu:0" and "/gpu:0".
     iterator = multi_device_iterator_ops.MultiDeviceIteratorV2(
         dataset, ["/cpu:0", "/gpu:0"])
     # Fetch a single step; the returned value is not kept.
     next(iterator)