def test_tf_config(self):
  cluster_spec = multi_worker_test_base.create_cluster_spec(
      has_chief=True, num_workers=2)
  runner = multi_process_runner.MultiProcessPoolRunner(cluster_spec)
  result = runner.run(fn_that_adds_task_type_in_return_data)
  job_count_dict = {'worker': 2, 'chief': 1}
  for data in result:
    job_count_dict[data] -= 1
  self.assertEqual(job_count_dict['worker'], 0)
  self.assertEqual(job_count_dict['chief'], 0)
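
# A plausible sketch (an assumption, not from the source) of the helper run in
# each worker by test_tf_config above: every process reports its own task
# type, read from the TF_CONFIG environment variable that the pool runner sets
# for it. Assumes `json` and `os` are imported at the top of the file.
def fn_that_adds_task_type_in_return_data():
  return json.loads(os.environ['TF_CONFIG'])['task']['type']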
def runner(self):
  if not self._runner:
    if (_num_total_workers(self.has_chief, self.num_workers) > 1 and
        self.use_pool_runner):
      # Need to create the strategy in the initializer so that collectives are
      # configured before eager context initialization.
      cluster_spec = multi_worker_test_base.create_cluster_spec(
          has_chief=self.has_chief,
          num_workers=self.num_workers,
          num_ps=0,
          has_eval=False)
      self._runner = multi_process_runner.MultiProcessPoolRunner(
          cluster_spec, initializer=self._distribution_fn)
  return self._runner
def __init__(self,
             name,
             distribution_fn,
             required_gpus=None,
             required_tpu=False,
             use_cloud_tpu=False,
             has_chief=False,
             num_workers=1,
             use_pool_runner=False,
             no_xla=False):
  """Initialize NamedDistribution.

  Args:
    name: Name that will be a part of the name of the test case.
    distribution_fn: A callable that creates a `tf.distribute.Strategy`.
    required_gpus: The number of GPUs that the strategy requires.
    required_tpu: Whether the strategy requires TPU.
    use_cloud_tpu: Whether the strategy requires cloud TPU.
    has_chief: Whether the strategy requires a chief worker.
    num_workers: The number of workers that the strategy requires.
    use_pool_runner: Whether to use a pool runner so that workers are re-used
      each time.
    no_xla: Whether to skip in XLA tests.
  """
  object.__init__(self)
  self._name = name
  self._distribution_fn = distribution_fn
  self.required_gpus = required_gpus
  self.required_tpu = required_tpu
  self.use_cloud_tpu = use_cloud_tpu
  self.has_chief = has_chief
  self.num_workers = num_workers
  self.no_xla = no_xla
  self._runner = None

  if _num_total_workers(self.has_chief, self.num_workers) > 1:
    cluster_spec = multi_worker_test_base.create_cluster_spec(
        has_chief=has_chief, num_workers=num_workers, num_ps=0,
        has_eval=False)
    if use_pool_runner:
      # Need to create the strategy in the initializer so that collectives are
      # configured before eager context initialization.
      self._runner = multi_process_runner.MultiProcessPoolRunner(
          cluster_spec, initializer=self._distribution_fn)
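
# A minimal usage sketch, not taken from the source: how a multi-worker entry
# might be declared with this constructor. `_get_multi_worker_mirrored_creator`
# is a hypothetical factory that returns a callable building a
# `tf.distribute.MultiWorkerMirroredStrategy`.
multi_worker_mirrored_2x1_cpu = NamedDistribution(
    "MultiWorkerMirrored2x1CPU",
    _get_multi_worker_mirrored_creator(required_gpus=0),
    has_chief=False,
    num_workers=2,
    use_pool_runner=True)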
def __init__(self, num_processes):
  cluster_spec_dict = multi_worker_test_base.create_cluster_spec(
      num_workers=num_processes)
  self.runner = multi_process_runner.MultiProcessPoolRunner(cluster_spec_dict)
def testCheckHealthInvalidPeer(self):

  def worker_fn():
    enable_collective_ops(cluster_resolver_lib.TFConfigClusterResolver())
    context.context().check_collective_ops_peer_health(
        "localhost:12345", timeout_in_ms=1000)

  cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
  mpr = multi_process_runner.MultiProcessRunner(worker_fn, cluster_spec)
  mpr.start_single_process("worker", 0)
  with self.assertRaises(errors.InvalidArgumentError):
    mpr.join()


two_worker_pool_runner = multi_process_runner.MultiProcessPoolRunner(
    multi_worker_test_base.create_cluster_spec(num_workers=2),
    initializer=lambda: enable_collective_ops(
        cluster_resolver_lib.TFConfigClusterResolver()))


@combinations.generate(
    combinations.times(
        combinations.combine(
            mode="eager", num_workers=2, runner=two_worker_pool_runner),
        device_combination))
class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase):

  def testAbortCommunication(self, device, communication):
    if communication == "NCCL":
      self.skipTest("b/171358086: cannot test multi worker NCCL")
    dev0 = "/device:%s:0" % device
    cluster_resolver = cluster_resolver_lib.TFConfigClusterResolver()
    if multi_process_runner.is_oss():
      self.skipTest('Intentionally skipping longer test in OSS.')

    def fn():
      time.sleep(250)
      raise ValueError('Worker 0 errored')

    mpr = multi_process_runner.MultiProcessRunner(
        fn, multi_worker_test_base.create_cluster_spec(num_workers=1))
    mpr.start()
    with self.assertRaisesRegex(ValueError, 'Worker 0 errored'):
      mpr.join(timeout=None)


_global_pool = multi_process_runner.MultiProcessPoolRunner(
    multi_worker_test_base.create_cluster_spec(num_workers=2))


class MultiProcessPoolRunnerTest(test.TestCase):

  def test_same_process_across_runs(self):
    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
    runner = multi_process_runner.MultiProcessPoolRunner(cluster_spec)
    pid = runner.run(fn_that_returns_pid)
    for _ in range(3):
      self.assertAllEqual(runner.run(fn_that_returns_pid), pid)

  def test_exceptions_in_sub_process(self):
    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
    runner = multi_process_runner.MultiProcessPoolRunner(cluster_spec)
def test_initializer(self):
  cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
  runner = multi_process_runner.MultiProcessPoolRunner(
      cluster_spec, initializer=lambda: proc_func_that_sets_global(1))
  result = runner.run(proc_func_that_sets_global, args=(2,))
  self.assertAllEqual(result, [1, 1])
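
# A plausible sketch (an assumption, not from the source) of the helper used
# by test_initializer above: it stores `val` in a module-level global and
# returns the previous value. The initializer sets the global to 1 in each
# pooled worker, so a later run with 2 returns the old value 1 from both
# workers, i.e. [1, 1].
_global_value = 0


def proc_func_that_sets_global(val):
  global _global_value
  previous = _global_value
  _global_value = val
  return previous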
def test_same_process_across_runs(self):
  cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
  runner = multi_process_runner.MultiProcessPoolRunner(cluster_spec)
  pid = runner.run(proc_func_that_returns_pid)
  for _ in range(3):
    self.assertAllEqual(runner.run(proc_func_that_returns_pid), pid)
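
# A plausible sketch (an assumption) of the helper referenced above: each
# worker simply reports its process id, which lets the test confirm that the
# pool reuses the same worker processes across runs. Assumes `os` is imported
# at the top of the file.
def proc_func_that_returns_pid():
  return os.getpid()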