Example #1
0
 def _distribute_dataset(self, dataset_fn):
     """Wrap `dataset_fn` in the input_lib dataset type matching this mode.

     Args:
       dataset_fn: callable that builds the input dataset; it is always
         invoked through `self._call_dataset_fn`.

     Returns:
       When `self._local_mode` is false, an `input_lib.MultiWorkerDataset`
       that receives a deferred (not yet called) `self._call_dataset_fn`
       bound to `dataset_fn`, with `auto_shard` taken from
       `self._auto_shard_dataset`. Otherwise, an
       `input_lib.PerReplicaDataset` built from calling
       `self._call_dataset_fn(dataset_fn)` immediately, with `0` as its
       third argument (presumably an input-worker index -- confirm in
       input_lib).
     """
     if not self._local_mode:
         # Hand over a callable instead of a dataset so MultiWorkerDataset
         # decides when (and how often) to build it.
         deferred_fn = functools.partial(self._call_dataset_fn, dataset_fn)
         return input_lib.MultiWorkerDataset(
             deferred_fn,
             self._input_workers,
             auto_shard=self._auto_shard_dataset)

     # Local mode: build the dataset eagerly, right here.
     built_dataset = self._call_dataset_fn(dataset_fn)
     return input_lib.PerReplicaDataset(built_dataset, self._input_workers, 0)
 def _test_dataset(self, dataset_fn, worker_devices, devices,
                   expected_values):
     """Distribute `dataset_fn` via MultiWorkerDataset and check its output.

     Args:
       dataset_fn: callable producing the dataset to distribute.
       worker_devices: passed as the second argument to
         `input_lib.InputWorkers`.
       devices: used to build the `values.ReplicaDeviceMap` and forwarded
         to `self._test_iterator`.
       expected_values: forwarded to `self._test_iterator` as the values the
         iterator should produce.
     """
     input_workers = input_lib.InputWorkers(
         values.ReplicaDeviceMap(devices), worker_devices)
     iterator = input_lib.MultiWorkerDataset(
         dataset_fn, input_workers).make_initializable_iterator()

     with self.cached_session() as sess:
         # Run the iterator's initializer in this session before checking it.
         sess.run(iterator.initializer)
         self._test_iterator(sess, iterator, devices, expected_values)
    def testInitializableIterator(self):
        """Check that an initializable MultiWorkerDataset iterator can be re-run.

        Inside a graph-mode cached session, the iterator built over
        `Dataset.range(8)` is initialized and checked twice; each pass must
        produce the expected values [[0, 0], [1, 1], ..., [7, 7]].
        """
        worker_devices, devices = self._cpu_devices()
        with context.graph_mode(), self.cached_session() as sess:
            dataset_fn = lambda: dataset_ops.Dataset.range(8)
            input_workers = input_lib.InputWorkers(
                values.ReplicaDeviceMap(devices), worker_devices)
            iterator = input_lib.MultiWorkerDataset(
                dataset_fn, input_workers).make_initializable_iterator()

            # Pass 1 checks the first initialization; pass 2 checks that
            # re-running the initializer lets the iterator be consumed again.
            for _ in range(2):
                sess.run(iterator.initializer)
                # A fresh expected list per pass, as the original did.
                expected = [[step, step] for step in range(8)]
                self._test_iterator(sess, iterator, devices, expected)
Example #4
0
 def _distribute_dataset(self, dataset_fn):
     """Wrap `dataset_fn` in an `input_lib.MultiWorkerDataset`.

     Args:
       dataset_fn: callable that builds the input dataset.

     Returns:
       An `input_lib.MultiWorkerDataset` over `self._input_workers` whose
       dataset function is `self._call_dataset_fn` bound to `dataset_fn`
       (the call is deferred, not made here).
     """
     bound_dataset_fn = functools.partial(self._call_dataset_fn, dataset_fn)
     return input_lib.MultiWorkerDataset(bound_dataset_fn, self._input_workers)