Example #1
    def test_tf_config(self):
        cluster_spec = multi_worker_test_base.create_cluster_spec(
            has_chief=True, num_workers=2)
        runner = multi_process_runner.MultiProcessPoolRunner(cluster_spec)
        result = runner.run(fn_that_adds_task_type_in_return_data)

        job_count_dict = {'worker': 2, 'chief': 1}
        for data in result:
            job_count_dict[data] -= 1

        self.assertEqual(job_count_dict['worker'], 0)
        self.assertEqual(job_count_dict['chief'], 0)
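The helper fn_that_adds_task_type_in_return_data is not defined in this excerpt. A minimal sketch of what it could look like, assuming it reads the task type back from the TF_CONFIG environment variable that the pool runner sets in each subprocess (the body below is an assumption, not the actual test helper):

import json
import os


def fn_that_adds_task_type_in_return_data():
    # Assumed helper: TF_CONFIG is set per subprocess by the runner, so the
    # task type ('chief' or 'worker') can be read back and returned.
    return json.loads(os.environ['TF_CONFIG'])['task']['type']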
Example #2
def runner(self):
    if not self._runner:
        if (_num_total_workers(self.has_chief, self.num_workers) > 1
                and self.use_pool_runner):
            # Need to create the strategy in the initializer so that
            # collectives are configured before eager context initialization.
            cluster_spec = multi_worker_test_base.create_cluster_spec(
                has_chief=self.has_chief,
                num_workers=self.num_workers,
                num_ps=0,
                has_eval=False)
            self._runner = multi_process_runner.MultiProcessPoolRunner(
                cluster_spec, initializer=self._distribution_fn)
    return self._runner
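A minimal usage sketch of this lazy accessor, assuming a NamedDistribution-like object named dist configured with num_workers=2 and use_pool_runner=True (the name dist is illustrative):

# Hypothetical usage: the first call builds the pool runner, later calls
# reuse it; run() executes the function once per worker process.
result = dist.runner().run(lambda: 1 + 1)
assert list(result) == [2, 2]  # one return value per worker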
Example #3
    def __init__(self,
                 name,
                 distribution_fn,
                 required_gpus=None,
                 required_tpu=False,
                 use_cloud_tpu=False,
                 has_chief=False,
                 num_workers=1,
                 use_pool_runner=False,
                 no_xla=False):
        """Initialize NamedDistribution.

    Args:
      name: Name that will be a part of the name of the test case.
      distribution_fn: A callable that creates a `tf.distribute.Strategy`.
      required_gpus: The number of GPUs that the strategy requires.
      required_tpu: Whether the strategy requires TPU.
      use_cloud_tpu: Whether the strategy requires cloud TPU.
      has_chief: Whether the strategy requires a chief worker.
      num_workers: The number of workers that the strategy requires.
      use_pool_runner: Whether to use a pool runner so that workers are re-used
        each time.
      no_xla: Whether to skip in XLA tests.
    """
        object.__init__(self)
        self._name = name
        self._distribution_fn = distribution_fn
        self.required_gpus = required_gpus
        self.required_tpu = required_tpu
        self.use_cloud_tpu = use_cloud_tpu
        self.has_chief = has_chief
        self.num_workers = num_workers
        self.no_xla = no_xla
        self._runner = None

        if _num_total_workers(self.has_chief, self.num_workers) > 1:
            cluster_spec = multi_worker_test_base.create_cluster_spec(
                has_chief=has_chief,
                num_workers=num_workers,
                num_ps=0,
                has_eval=False)
            if use_pool_runner:
                # Need to create the strategy in the initializer so that collectives are
                # configured before eager context initialization.
                self._runner = multi_process_runner.MultiProcessPoolRunner(
                    cluster_spec, initializer=self._distribution_fn)
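As an illustration, a two-worker entry might be declared as below; the name, strategy, and argument values are hypothetical, not taken from the original file (assumes tensorflow imported as tf):

# Hypothetical construction; with num_workers=2 and use_pool_runner=True,
# __init__ builds the cluster spec and the pool runner up front, and each
# worker creates its strategy via the distribution_fn initializer.
multi_worker_2x = NamedDistribution(
    'MultiWorkerMirrored2Workers',
    lambda: tf.distribute.MultiWorkerMirroredStrategy(),
    has_chief=False,
    num_workers=2,
    use_pool_runner=True)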
Example #4
def __init__(self, num_processes):
  cluster_spec_dict = multi_worker_test_base.create_cluster_spec(
      num_workers=num_processes)
  self.runner = multi_process_runner.MultiProcessPoolRunner(cluster_spec_dict)
Example #5
  def testCheckHealthInvalidPeer(self):

    def worker_fn():
      enable_collective_ops(cluster_resolver_lib.TFConfigClusterResolver())
      context.context().check_collective_ops_peer_health(
          "localhost:12345", timeout_in_ms=1000)

    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
    mpr = multi_process_runner.MultiProcessRunner(worker_fn, cluster_spec)
    mpr.start_single_process("worker", 0)
    with self.assertRaises(errors.InvalidArgumentError):
      mpr.join()


two_worker_pool_runner = multi_process_runner.MultiProcessPoolRunner(
    multi_worker_test_base.create_cluster_spec(num_workers=2),
    initializer=lambda: enable_collective_ops(
        cluster_resolver_lib.TFConfigClusterResolver()))


@combinations.generate(
    combinations.times(
        combinations.combine(
            mode="eager", num_workers=2, runner=two_worker_pool_runner),
        device_combination))
class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase):

  def testAbortCommunication(self, device, communication):
    if communication == "NCCL":
      self.skipTest("b/171358086: cannot test multi worker NCCL")
    dev0 = "/device:%s:0" % device
    cluster_resolver = cluster_resolver_lib.TFConfigClusterResolver()
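enable_collective_ops is referenced above but not defined in this excerpt. A rough sketch of what such a helper could do, using only the public ServerDef proto and eager-context call (this body is an assumption, not the test file's actual implementation):

from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python.eager import context


def enable_collective_ops(cluster_resolver):
  # Assumed helper: build a ServerDef from the resolver and enable
  # collective ops on the current eager context.
  server_def = tensorflow_server_pb2.ServerDef(
      cluster=cluster_resolver.cluster_spec().as_cluster_def(),
      job_name=cluster_resolver.task_type,
      task_index=cluster_resolver.task_id,
      protocol=cluster_resolver.rpc_layer or "grpc")
  context.context().enable_collective_ops(server_def)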
Example #6
        if multi_process_runner.is_oss():
            self.skipTest('Intentionally skipping longer test in OSS.')

        def fn():
            time.sleep(250)
            raise ValueError('Worker 0 errored')

        mpr = multi_process_runner.MultiProcessRunner(
            fn, multi_worker_test_base.create_cluster_spec(num_workers=1))

        mpr.start()
        with self.assertRaisesRegex(ValueError, 'Worker 0 errored'):
            mpr.join(timeout=None)


_global_pool = multi_process_runner.MultiProcessPoolRunner(
    multi_worker_test_base.create_cluster_spec(num_workers=2))


class MultiProcessPoolRunnerTest(test.TestCase):
    def test_same_process_across_runs(self):
        cluster_spec = multi_worker_test_base.create_cluster_spec(
            num_workers=2)
        runner = multi_process_runner.MultiProcessPoolRunner(cluster_spec)
        pid = runner.run(fn_that_returns_pid)
        for _ in range(3):
            self.assertAllEqual(runner.run(fn_that_returns_pid), pid)

    def test_exceptions_in_sub_process(self):
        cluster_spec = multi_worker_test_base.create_cluster_spec(
            num_workers=2)
        runner = multi_process_runner.MultiProcessPoolRunner(cluster_spec)
Example #7
def test_initializer(self):
  cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
  runner = multi_process_runner.MultiProcessPoolRunner(
      cluster_spec, initializer=lambda: proc_func_that_sets_global(1))
  result = runner.run(proc_func_that_sets_global, args=(2,))
  self.assertAllEqual(result, [1, 1])
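proc_func_that_sets_global is not shown. For the [1, 1] assertion to hold, it plausibly swaps a process-global value and returns the previous one, so each worker returns the 1 written by the initializer. A sketch under that assumption:

_global_value = None


def proc_func_that_sets_global(val):
  # Assumed helper: store the new value and return the old one; workers
  # initialized with 1 therefore return 1 when the test passes in 2.
  global _global_value
  old_val = _global_value
  _global_value = val
  return old_val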
Example #8
def test_same_process_across_runs(self):
  cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
  runner = multi_process_runner.MultiProcessPoolRunner(cluster_spec)
  pid = runner.run(proc_func_that_returns_pid)
  for _ in range(3):
    self.assertAllEqual(runner.run(proc_func_that_returns_pid), pid)
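proc_func_that_returns_pid (like fn_that_returns_pid in Example #6) is presumably just a process-id probe; a minimal sketch under that assumption:

import os


def proc_func_that_returns_pid():
  # Assumed helper: report the worker's process id so the test can verify
  # that the pool reuses the same processes across runs.
  return os.getpid()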