def test_nested_pool(self):
        """Run a pool task that itself submits work to `_global_pool`.

        Skipped in OSS, where the test is marked as failing (b/170360740).
        """
        if multi_process_runner.is_oss():
            skip_reason = 'TODO(b/170360740): Failing in OSS'
            self.skipTest(skip_reason)

        def run_inner_task():
            # Executed inside the sub processes; each of them works with its
            # own MultiProcessPoolRunner.
            _global_pool.run(fn_that_does_nothing)

        _global_pool.run(run_inner_task)
Пример #2
0
    def test_timeout_none(self):
        """Check that `join(timeout=None)` surfaces a late worker error.

        The single worker sleeps for 250 seconds and then raises; joining
        with `timeout=None` is expected to raise that `ValueError`.
        Skipped in OSS because of the long runtime.
        """
        if multi_process_runner.is_oss():
            skip_reason = 'Intentionally skipping longer test in OSS.'
            self.skipTest(skip_reason)

        def slow_failing_worker():
            time.sleep(250)
            raise ValueError('Worker 0 errored')

        cluster_spec = multi_worker_test_base.create_cluster_spec(
            num_workers=1)
        runner = multi_process_runner.MultiProcessRunner(
            slow_failing_worker, cluster_spec)
        runner.start()

        with self.assertRaisesRegex(ValueError, 'Worker 0 errored'):
            runner.join(timeout=None)
Пример #3
0
    def testCheckHealthPeerDown(self):
        """Check the peer health check fails when the peer is never started.

        Only worker 0 of a two-worker cluster is started. It checks the
        health of task 1, and joining is expected to raise
        `errors.UnavailableError`.
        """
        if multi_process_runner.is_oss():
            skip_reason = "TODO(b/170838845): Failing in OSS"
            self.skipTest(skip_reason)

        def check_peer_health():
            resolver = cluster_resolver_lib.TFConfigClusterResolver()
            enable_collective_ops(resolver)
            # Task 1 is the worker that this test never launches.
            peer_task = "/job:worker/replica:0/task:1"
            context.context().check_collective_ops_peer_health(peer_task)

        two_worker_spec = multi_worker_test_base.create_cluster_spec(
            num_workers=2)
        runner = multi_process_runner.MultiProcessRunner(
            check_peer_health, two_worker_spec)
        runner.start_single_process("worker", 0)

        with self.assertRaises(errors.UnavailableError):
            runner.join()
Пример #4
0
    def test_seg_fault_raises_error(self):
        """Check that a seg-faulting worker is reported as an unexpected exit.

        The worker reads from address 0. The run is expected to raise
        `UnexpectedSubprocessExitError` naming worker-0, with
        'Segmentation fault' present in the captured stdout.
        Skipped in OSS and on Python 3.7+ (b/171004637).
        """
        if multi_process_runner.is_oss() or sys.version_info >= (3, 7):
            skip_reason = 'TODO(b/171004637): Failing in OSS and Python 3.7+'
            self.skipTest(skip_reason)

        def crash_with_seg_fault():
            # Reading from address 0 intentionally triggers a seg fault.
            ctypes.string_at(0)

        expected_error = multi_process_runner.UnexpectedSubprocessExitError
        with self.assertRaises(expected_error) as caught:
            cluster_spec = multi_worker_test_base.create_cluster_spec(
                num_workers=1)
            multi_process_runner.run(
                crash_with_seg_fault, cluster_spec, return_output=True)

        error_text = str(caught.exception)
        self.assertIn('Subprocess worker-0 exited with exit code', error_text)

        stdout_lines = caught.exception.mpr_result.stdout
        self.assertTrue(
            any('Segmentation fault' in line for line in stdout_lines))
        def proc_model_checkpoint_works_with_same_file_path(
                test_obj, saving_filepath):
            """Check checkpoint and backup files across an interrupted fit.

            The first `fit` runs with `InterruptingCallback` and may be cut
            short by a `RuntimeError` containing 'Interrupting!'; afterwards
            both the backup checkpoint and `saving_filepath` must exist. The
            second `fit` runs with `AssertCallback`; afterwards the backup
            checkpoint must be gone while `saving_filepath` still exists.

            Args:
              test_obj: Test case object used for skipping and assertions.
              saving_filepath: Path handed to `ModelCheckpoint`; it must not
                exist when this function starts.
            """
            if multi_process_runner.is_oss():
                test_obj.skipTest('TODO(b/170838633): Failing in OSS')
            model, _, train_ds, steps = _model_setup(test_obj, file_format='')
            num_epoch = 4

            # The saving_filepath shouldn't exist at the beginning (as it's
            # unique).
            test_obj.assertFalse(file_io.file_exists_v2(saving_filepath))
            backup_root = os.path.join(
                os.path.dirname(saving_filepath), 'backup')

            def build_callbacks(last_callback_cls):
                # The two file-writing callbacks are shared by both fits;
                # only the final callback differs between them.
                return [
                    callbacks.ModelCheckpoint(filepath=saving_filepath),
                    callbacks.BackupAndRestore(backup_dir=backup_root),
                    last_callback_cls(),
                ]

            fit_kwargs = dict(
                x=train_ds, epochs=num_epoch, steps_per_epoch=steps)

            try:
                model.fit(
                    callbacks=build_callbacks(InterruptingCallback),
                    **fit_kwargs)
            except RuntimeError as interruption:
                # Only the expected interruption is tolerated; anything else
                # propagates.
                if 'Interrupting!' not in str(interruption):
                    raise

            multi_process_runner.get_barrier().wait()
            backup_checkpoint = os.path.join(
                backup_root, 'chief', 'checkpoint')
            for expected_path in (backup_checkpoint, saving_filepath):
                test_obj.assertTrue(file_io.file_exists_v2(expected_path))

            model.fit(callbacks=build_callbacks(AssertCallback), **fit_kwargs)
            multi_process_runner.get_barrier().wait()
            test_obj.assertFalse(file_io.file_exists_v2(backup_checkpoint))
            test_obj.assertTrue(file_io.file_exists_v2(saving_filepath))
Пример #6
0
    def test_seg_fault_in_chief_raises_error(self):
        """Check that a seg fault in the chief is reported as an unexpected exit.

        The cluster has a chief and one worker. The run is expected to raise
        `UnexpectedSubprocessExitError` naming chief-0, with
        'Segmentation fault' present in the captured stdout.
        Skipped in OSS (b/171004637).
        """
        if multi_process_runner.is_oss():
            skip_reason = 'TODO(b/171004637): Failing in OSS'
            self.skipTest(skip_reason)

        def crash_with_seg_fault():
            is_worker = multi_worker_test_base.get_task_type() == 'worker'
            if is_worker:
                # Workers sleep before reaching the crash below; presumably
                # this keeps the chief's exit the one that gets reported.
                time.sleep(10000)
            # Reading from address 0 intentionally triggers a seg fault.
            ctypes.string_at(0)

        expected_error = multi_process_runner.UnexpectedSubprocessExitError
        with self.assertRaises(expected_error) as caught:
            cluster_spec = multi_worker_test_base.create_cluster_spec(
                has_chief=True, num_workers=1)
            multi_process_runner.run(
                crash_with_seg_fault, cluster_spec, return_output=True)

        error_text = str(caught.exception)
        self.assertIn('Subprocess chief-0 exited with exit code', error_text)

        stdout_lines = caught.exception.mpr_result.stdout
        self.assertTrue(
            any('Segmentation fault' in line for line in stdout_lines))
Пример #7
0
        parameterized.TestCase):
    @ds_combinations.generate(
        combinations.combine(
            mode=['graph'],
            required_gpus=[2, 4],
        ))
    def testComplexModel(self, required_gpus):
        """Run the complex-model check for each generated GPU count.

        Args:
          required_gpus: GPU count supplied by the test combination (2 or 4).
        """
        self._test_complex_model(None, None, required_gpus)

    @ds_combinations.generate(
        combinations.combine(
            mode=['graph'],
            required_gpus=[2, 4],
        ))
    @testing_utils.enable_v2_dtype_behavior
    def testMixedPrecision(self, required_gpus):
        """Run the mixed-precision check under the 'mixed_float16' policy.

        Args:
          required_gpus: GPU count supplied by the test combination (2 or 4).
        """
        policy_name = 'mixed_float16'
        with policy.policy_scope(policy_name):
            self._test_mixed_precision(None, None, required_gpus)


# TODO(b/170360740): Timeout in OSS
# The test class below is only defined when not running in OSS.
if not multi_process_runner.is_oss():

    @ds_combinations.generate(
        combinations.combine(strategy=[
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
        ],
                             mode=['eager']))
    class DistributedCollectiveAllReduceStrategyEagerTest(
            test.TestCase, parameterized.TestCase):
        """Eager-mode tests generated for the 2x1 CPU and GPU strategies."""

        def testFitWithoutStepsPerEpochPartialBatch(self, strategy):
            # NOTE(review): only the model builder is visible here; the rest
            # of this test body appears to be missing -- confirm against the
            # full file.
            def _model_fn():
                """Build a model with one input named 'input' and one Dense(1) layer."""
                x = layers.Input(shape=(1, ), name='input')
                y = layers.Dense(1, name='dense')(x)
                model = training.Model(x, y)
                return model
 def test_global_pool(self):
      """Run a do-nothing function on `_global_pool`.

      Skipped in OSS, where the test is marked as failing (b/170360740).
      """
      if multi_process_runner.is_oss():
          skip_reason = 'TODO(b/170360740): Failing in OSS'
          self.skipTest(skip_reason)
      _global_pool.run(fn_that_does_nothing)