def _distribute_dataset(self, dataset_fn):
    """Wrap `dataset_fn` in the distribution-aware dataset for this strategy.

    In local mode the dataset is materialized immediately and wrapped as a
    per-replica dataset on worker 0; otherwise the call is deferred via
    `functools.partial` so each worker builds its own copy, with auto-sharding
    controlled by `self._auto_shard_dataset`.
    """
    # Guard clause: local mode short-circuits to the single-worker wrapper.
    if self._local_mode:
        return input_lib.PerReplicaDataset(
            self._call_dataset_fn(dataset_fn), self._input_workers, 0)
    return input_lib.MultiWorkerDataset(
        functools.partial(self._call_dataset_fn, dataset_fn),
        self._input_workers,
        auto_shard=self._auto_shard_dataset)
def _test_dataset(self, dataset_fn, worker_devices, devices, expected_values):
    """Build a multi-worker dataset over `devices` and check its values.

    Creates the input workers from the given device placement, makes an
    initializable iterator, runs its initializer in a cached session, and
    delegates the value comparison to `self._test_iterator`.
    """
    replica_map = values.ReplicaDeviceMap(devices)
    workers = input_lib.InputWorkers(replica_map, worker_devices)
    dataset = input_lib.MultiWorkerDataset(dataset_fn, workers)
    iterator = dataset.make_initializable_iterator()
    with self.cached_session() as sess:
        sess.run(iterator.initializer)
        self._test_iterator(sess, iterator, devices, expected_values)
def testInitializableIterator(self):
    """Iterating twice after re-running the initializer yields the same data."""
    worker_devices, devices = self._cpu_devices()
    # Same per-step pairs expected on both passes; each step mirrors the
    # range value across the two devices.
    expected = [[i, i] for i in range(8)]
    with context.graph_mode(), self.cached_session() as sess:
        replica_map = values.ReplicaDeviceMap(devices)
        workers = input_lib.InputWorkers(replica_map, worker_devices)
        dataset = input_lib.MultiWorkerDataset(
            lambda: dataset_ops.Dataset.range(8), workers)
        iterator = dataset.make_initializable_iterator()

        sess.run(iterator.initializer)
        self._test_iterator(sess, iterator, devices, expected)

        # After re-initializing the iterator, should be able to iterate again.
        sess.run(iterator.initializer)
        self._test_iterator(sess, iterator, devices, expected)
def _distribute_dataset(self, dataset_fn):
    """Defer `dataset_fn` through `_call_dataset_fn` and wrap it per worker.

    The partial postpones dataset construction so each worker can build its
    own instance; no auto-sharding is requested in this variant.
    """
    deferred_fn = functools.partial(self._call_dataset_fn, dataset_fn)
    return input_lib.MultiWorkerDataset(deferred_fn, self._input_workers)