def test(use_gloo, use_mpi, use_js,
         gloo_is_built, mpi_is_built,
         lsf_exists, jsrun_installed,
         expected, exception):
    """Run ``run_controller`` with mock launchers and check which one ran.

    The call is made inside the ``is_built(gloo_is_built, mpi_is_built)`` and
    ``lsf_and_jsrun(lsf_exists, jsrun_installed)`` contexts.

    Args:
        use_gloo, use_mpi, use_js: flags forwarded to ``run_controller``.
        gloo_is_built, mpi_is_built: forwarded to ``is_built``.
        lsf_exists, jsrun_installed: forwarded to ``lsf_and_jsrun``.
        expected: ``"gloo"``, ``"mpi"`` or ``"js"`` -- the single launcher that
            must have been called exactly once; the other two must not have
            been called. Ignored when ``exception`` is given.
        exception: when not ``None``, a pattern passed as ``match`` to
            ``pytest.raises(ValueError, ...)``; the call must raise and no
            launcher assertions are made.

    Raises:
        ValueError: if ``exception`` is ``None`` and ``expected`` is not one
            of the three known launcher names.
    """
    gloo_run = MagicMock()
    mpi_run = MagicMock()
    js_run = MagicMock()

    def invoke_controller():
        run_controller(use_gloo, gloo_run,
                       use_mpi, mpi_run,
                       use_js, js_run,
                       verbosity=2)

    with is_built(gloo_is_built, mpi_is_built), \
            lsf_and_jsrun(lsf_exists, jsrun_installed):
        if exception is None:
            invoke_controller()
        else:
            # Failure case: the controller must reject this combination,
            # so there is nothing further to verify about the launchers.
            with pytest.raises(ValueError, match=exception):
                invoke_controller()
            return

    # Order matters: assertions run gloo -> mpi -> js, so the first failing
    # assertion reported is the same one the if/elif ladder would report.
    launchers = (("gloo", gloo_run), ("mpi", mpi_run), ("js", js_run))

    if not any(expected == name for name, _ in launchers):
        raise ValueError("unsupported framework: {}".format(expected))

    for name, launcher in launchers:
        if expected == name:
            launcher.assert_called_once()
        else:
            launcher.assert_not_called()
def _launch_job(use_mpi, use_gloo, settings, driver, env, stdout=None, stderr=None):
    """Hand the job to ``run_controller`` with Gloo and MPI launch callbacks.

    The interfaces returned by ``driver.get_common_interfaces()`` are fetched
    once, up front, and shared by both callbacks.

    Args:
        use_mpi: flag forwarded to ``run_controller`` for the MPI callback.
        use_gloo: flag forwarded to ``run_controller`` for the Gloo callback.
        settings: passed to the launchers; ``settings.verbose`` is forwarded
            to ``run_controller`` as its last positional argument.
        driver: provides ``get_common_interfaces()`` and is passed to the
            launchers.
        env: passed through to the launchers.
        stdout, stderr: passed only to ``mpi_run``.
    """
    common_nics = driver.get_common_interfaces()

    def launch_with_gloo():
        return gloo_run(settings, common_nics, driver, env)

    def launch_with_mpi():
        return mpi_run(settings, common_nics, driver, env, stdout, stderr)

    def launch_with_js():
        # The third (js) launcher slot is always passed as disabled (False)
        # below, so its callback is a no-op.
        return None

    run_controller(
        use_gloo, launch_with_gloo,
        use_mpi, launch_with_mpi,
        False, launch_with_js,
        settings.verbose,
    )