def test_mpi_run_on_large_cluster(self):
    """Checks the command mpi_run hands to run_func when num_hosts equals large_cluster_threshold."""
    if _get_mpi_implementation_flags(False)[0] is None:
        self.skipTest("MPI is not available")

    # Work on a copy so the raised host count does not leak into self.minimal_settings.
    big_settings = copy.copy(self.minimal_settings)
    big_settings.num_hosts = large_cluster_threshold
    launcher = MagicMock(return_value=0)

    mpi_run(big_settings, None, {}, ['cmd'], run_func=launcher)

    impl_flags, impl_binding = _get_mpi_implementation_flags(False)
    self.assertIsNotNone(impl_flags)
    # Flags this test expects in addition to the implementation defaults.
    impl_flags.extend([
        '-mca plm_rsh_no_tree_spawn true',
        '-mca plm_rsh_num_concurrent {}'.format(big_settings.num_hosts),
    ])

    command_template = ('mpirun '
                        '--allow-run-as-root --tag-output '
                        '-np 2 -H host '
                        '{binding_args} '
                        '{mpi_flags} '
                        'cmd')
    wanted_command = command_template.format(binding_args=' '.join(impl_binding),
                                             mpi_flags=' '.join(impl_flags))

    launcher.assert_called_once_with(command=wanted_command, env={}, stdout=None, stderr=None)
def test_mpi_run_full(self):
    """Checks the mpirun command and environment produced when all settings used here are supplied."""
    if _get_mpi_implementation_flags(False)[0] is None:
        self.skipTest("MPI is not available")

    command = ['cmd', 'arg1', 'arg2']
    interfaces = ['eth0', 'eth1']
    environment = {'env1': 'val1', 'env2': 'val2'}
    out_stream = '<stdout>'
    err_stream = '<stderr>'
    start_timeout = timeout.Timeout(5, message='Timed out waiting for something.')

    # Placeholder values are deliberately recognisable so they can be spotted in the command.
    settings_kwargs = dict(
        verbose=0,
        ssh_port=1022,
        extra_mpi_args='>mpi-extra args go here<',
        binding_args='>binding args go here<',
        key=secret.make_secret_key(),
        timeout=start_timeout,
        num_hosts=1,
        num_proc=1,
        hosts='>host names go here<',
        output_filename='>output filename goes here<',
        run_func_mode=True,
    )
    full_settings = hvd_settings.Settings(**settings_kwargs)
    launcher = MagicMock(return_value=0)

    mpi_run(full_settings, interfaces, environment, command,
            stdout=out_stream, stderr=err_stream, run_func=launcher)

    impl_flags, _ = _get_mpi_implementation_flags(False)
    self.assertIsNotNone(impl_flags)

    command_template = (
        'mpirun '
        '--allow-run-as-root --tag-output '
        '-np 1 -H >host names go here< '
        '>binding args go here< '
        '{mpi_flags} '
        '-mca plm_rsh_args "-p 1022" '
        '-mca btl_tcp_if_include eth0,eth1 -x NCCL_SOCKET_IFNAME=eth0,eth1 '
        '--output-filename >output filename goes here< '
        '-x env1 -x env2 '
        '>mpi-extra args go here< '
        'cmd arg1 arg2')
    wanted_command = command_template.format(mpi_flags=' '.join(impl_flags))
    # Built separately from `environment` so the assertion does not compare a dict to itself.
    wanted_env = {'env1': 'val1', 'env2': 'val2'}

    launcher.assert_called_once_with(command=wanted_command, env=wanted_env,
                                     stdout=out_stream, stderr=err_stream)
def test_mpi_run_with_non_zero_exit(self):
    """Checks that mpi_run raises RuntimeError when run_func returns exit code 1."""
    if _get_mpi_implementation_flags(False)[0] is None:
        self.skipTest("MPI is not available")

    failing_launcher = MagicMock(return_value=1)
    expected_message = "^mpirun failed with exit code 1$"

    with pytest.raises(RuntimeError, match=expected_message):
        mpi_run(self.minimal_settings, None, {}, ['cmd'], run_func=failing_launcher)
def test_js_run(self):
    """Checks the jsrun command and environment that js_run passes to an injected run_func."""
    # NOTE(review): this file defines test_js_run a second time further down; if both
    # definitions live in the same class, the later one replaces this one - confirm intent.
    if _get_mpi_implementation_flags(False)[0] is None:
        self.skipTest("MPI is not available")

    command = ['cmd', 'arg1', 'arg2']
    environment = {'env1': 'val1', 'env2': 'val2'}
    out_stream = '<stdout>'
    err_stream = '<stderr>'
    settings_kwargs = dict(
        verbose=0,
        extra_mpi_args='>mpi-extra args go here<',
        num_hosts=2,
        num_proc=4,
        hosts='>host names go here<',
        output_filename='>output filename goes here<',
        run_func_mode=True,
    )
    js_settings = hvd_settings.Settings(**settings_kwargs)
    launcher = MagicMock(return_value=0)

    js_run(js_settings, None, environment, command,
           stdout=out_stream, stderr=err_stream, run_func=launcher)

    impl_flags, _ = _get_mpi_implementation_flags(False)
    self.assertIsNotNone(impl_flags)

    command_template = (
        'jsrun '
        '--erf_input /tmp/rankfile '
        '--stdio_stderr >output filename goes here< '
        '--stdio_stdout >output filename goes here< '
        '--smpiargs \'{mpi_args} >mpi-extra args go here<\' '
        'cmd arg1 arg2')
    wanted_command = command_template.format(mpi_args=' '.join(impl_flags))
    wanted_env = {'env1': 'val1', 'env2': 'val2'}

    launcher.assert_called_once_with(command=wanted_command, env=wanted_env,
                                     stdout=out_stream, stderr=err_stream)
def test_mpi_run_minimal(self):
    """Checks the mpirun command mpi_run builds from self.minimal_settings alone.

    The MPI helper is used here the same way as in the sibling tests: it is
    called with a TCP flag and returns a (flags, binding_args) pair. The
    expected command therefore takes its binding arguments from the helper
    rather than hard-coding them.
    """
    # Index [0] is the flag list; the pair itself is never None, so testing
    # the bare return value would never skip.
    if _get_mpi_implementation_flags(False)[0] is None:
        self.skipTest("MPI is not available")

    cmd = ['cmd']
    settings = self.minimal_settings
    run_func = MagicMock(return_value=0)

    mpi_run(settings, None, {}, cmd, run_func=run_func)

    mpi_flags, binding_args = _get_mpi_implementation_flags(False)
    self.assertIsNotNone(mpi_flags)
    expected_cmd = ('mpirun '
                    '--allow-run-as-root --tag-output '
                    '-np 2 -H host '
                    '{binding_args} '
                    '{mpi_flags} '
                    'cmd').format(binding_args=' '.join(binding_args),
                                  mpi_flags=' '.join(mpi_flags))
    expected_env = {}

    run_func.assert_called_once_with(command=expected_cmd, env=expected_env, stdout=None, stderr=None)
def test_js_run(self):
    """Checks the jsrun command js_run executes, with the MPI flag lookup and shell execution mocked."""
    # NOTE(review): this file also defines test_js_run earlier; if both definitions live
    # in the same class, this later one replaces the earlier one - confirm intent.
    if _get_mpi_implementation_flags(False)[0] is None:
        self.skipTest("MPI is not available")

    command = ['cmd', 'arg1', 'arg2']
    environment = {'env1': 'val1', 'env2': 'val2'}
    out_stream = '<stdout>'
    err_stream = '<stderr>'
    settings_kwargs = dict(
        verbose=0,
        extra_mpi_args='>mpi-extra args go here<',
        num_hosts=2,
        num_proc=4,
        hosts='>host names go here<',
        output_filename='>output filename goes here<',
        run_func_mode=True,
    )
    js_settings = hvd_settings.Settings(**settings_kwargs)

    def fake_impl_flags(tcp):
        # Fixed (flags, binding_args) pair standing in for the real lookup.
        return ["--mock-mpi-impl-flags"], []

    with mock.patch("horovod.run.js_run._get_mpi_implementation_flags", side_effect=fake_impl_flags), \
            mock.patch("horovod.run.js_run.safe_shell_exec.execute", return_value=0) as shell_execute:
        js_run(js_settings, None, environment, command, stdout=out_stream, stderr=err_stream)

        # call the mocked _get_mpi_implementation_flags method
        impl_flags, _ = horovod.run.js_run._get_mpi_implementation_flags(False)
        self.assertIsNotNone(impl_flags)

        command_template = (
            'jsrun '
            '--erf_input /tmp/rankfile '
            '--stdio_stderr >output filename goes here< '
            '--stdio_stdout >output filename goes here< '
            '--smpiargs \'{mpi_args} >mpi-extra args go here<\' '
            'cmd arg1 arg2')
        wanted_command = command_template.format(mpi_args=' '.join(impl_flags))
        wanted_env = {'env1': 'val1', 'env2': 'val2'}

        shell_execute.assert_called_once_with(wanted_command, env=wanted_env,
                                              stdout=out_stream, stderr=err_stream)
def do_test_spark_run_func(self, args=(), kwargs=None, num_proc=1, extra_mpi_args=None, env=None,
                           stdout=None, stderr=None, verbose=0,
                           cores=2, expected_np=1, expected_env=''):
    """
    Runs horovod.spark.run with a mocked run_func and checks the arguments the mock receives.

    Args:
        args: Positional arguments forwarded to horovod.spark.run.
        kwargs: Keyword arguments forwarded to horovod.spark.run. None means an empty dict.
        num_proc: Forwarded to horovod.spark.run.
        extra_mpi_args: Forwarded to horovod.spark.run; also expected verbatim in the command.
        env: Environment forwarded to horovod.spark.run. None means an empty dict.
        stdout: Forwarded to horovod.spark.run and expected to reach run_func unchanged.
        stderr: Forwarded to horovod.spark.run and expected to reach run_func unchanged.
        verbose: Forwarded to horovod.spark.run.
        cores: Passed to spark_session.
        expected_np: Value expected after '-np' in the mpirun command.
        expected_env: Text expected after '-x _HOROVOD_SECRET_KEY ' in the command, '' for none.
    """
    # Fresh dicts per call: a mutable default would be one object shared by every call.
    kwargs = {} if kwargs is None else kwargs
    env = {} if env is None else env

    def fn():
        return 1

    run_func = MagicMock(return_value=0)

    with spark_session('test_spark_run_func', cores=cores):
        with pytest.raises(Exception) as e:
            # we need to timeout horovod because our mocked run_func will block spark otherwise
            # this raises above exception, but allows us to catch run_func arguments
            horovod.spark.run(fn, args=args, kwargs=kwargs,
                              num_proc=num_proc, start_timeout=1,
                              extra_mpi_args=extra_mpi_args, env=env,
                              stdout=stdout, stderr=stderr, verbose=verbose,
                              run_func=run_func)

    self.assertFalse(str(e.value).startswith('Timed out waiting for Spark tasks to start.'),
                     'Spark timed out before mpi_run was called, test setup is broken.')
    self.assertEqual(str(e.value), 'Spark job has failed, see the error above.')

    # Same calling convention as the other users of this helper in the file:
    # it takes a TCP flag and returns a (flags, binding_args) pair.
    mpi_flags, binding_args = _get_mpi_implementation_flags(False)
    self.assertIsNotNone(mpi_flags)
    expected_command = (
        'mpirun '
        '--allow-run-as-root --tag-output '
        '-np {expected_np} -H [^ ]+ '
        '{binding_args} '
        '{mpi_flags} '
        '-mca btl_tcp_if_include [^ ]+ -x NCCL_SOCKET_IFNAME=[^ ]+ '
        '-x _HOROVOD_SECRET_KEY {expected_env}'
        '{extra_mpi_args} '
        '-x NCCL_DEBUG=INFO '
        r'-mca plm_rsh_agent "[^"]+python[\d]* -m horovod.spark.driver.mpirun_rsh [^ ]+ [^ ]+" '
        r'[^"]+python[\d]* -m horovod.spark.task.mpirun_exec_fn [^ ]+ [^ ]+'
        .format(expected_np=expected_np,
                expected_env=expected_env + ' ' if expected_env else '',
                binding_args=' '.join(binding_args),
                mpi_flags=' '.join(mpi_flags),
                extra_mpi_args=extra_mpi_args if extra_mpi_args else ''))

    run_func.assert_called_once()
    run_func_args, run_func_kwargs = run_func.call_args
    actual_command = run_func_kwargs.get('command')
    actual_env = run_func_kwargs.get('env')
    actual_stdout = run_func_kwargs.get('stdout')
    actual_stderr = run_func_kwargs.get('stderr')
    actual_secret = actual_env.pop('_HOROVOD_SECRET_KEY', None)

    # for better comparison replace sections in actual_command that change across runs / hosts
    for replacement in ('-H [^ ]+', '-mca btl_tcp_if_include [^ ]+', '-x NCCL_SOCKET_IFNAME=[^ ]+',
                        r'"[^"]+python[\d]*', r' [^"]+python[\d]*',
                        '-m horovod.spark.driver.mpirun_rsh [^ ]+ [^ ]+"',
                        '-m horovod.spark.task.mpirun_exec_fn [^ ]+ [^ ]+'):
        # Each matched section is replaced by the pattern text itself. A callable
        # replacement inserts that text verbatim; a string replacement is parsed as
        # a template, and re.sub rejects the backslash-d escape there on Python 3.7+.
        actual_command = re.sub(replacement, lambda _match, text=replacement: text,
                                actual_command, count=1)

    self.assertEqual(run_func_args, ())
    self.assertEqual(actual_command, expected_command)
    if env:
        self.assertEqual(actual_env, env)
    else:
        self.assertIsNotNone(actual_env)
    self.assertIsNotNone(actual_secret)
    self.assertTrue(len(actual_secret) > 0)
    self.assertEqual(actual_stdout, stdout)
    self.assertEqual(actual_stderr, stderr)
def js_run(settings, common_intfs, env, command, stdout=None, stderr=None, run_func=safe_shell_exec.execute):
    """
    Launches Horovod through jsrun.

    Args:
        settings: Settings for running jsrun.
                  Note: settings.num_proc and settings.hosts must not be None.
        common_intfs: Interfaces to include by jsrun.
        env: Environment dictionary to use for running jsrun. NCCL_SOCKET_IFNAME is
             added to it when common_intfs is given and the key is not already set.
        command: Command and arguments to run as a list of string.
        stdout: Stdout of the mpi process.
                Only used when settings.run_func_mode is True.
        stderr: Stderr of the mpi process.
                Only used when settings.run_func_mode is True.
        run_func: Run function to use. Must have arguments 'command' and 'env'.
                  Only used when settings.run_func_mode is True.
                  Defaults to safe_shell_exec.execute.

    Raises:
        Exception: If no MPI implementation flags are found or jsrun is not installed.
        RuntimeError: If run_func returns a non-zero exit code.
    """
    impl_flags, _ = _get_mpi_implementation_flags(settings.tcp_flag)
    if impl_flags is None:
        raise Exception(_MPI_NOT_FOUND_ERROR_MSG)

    if not is_jsrun_installed():
        raise Exception(
            'horovodrun convenience script does not find the jsrun command.\n\n'
            'Please, make sure you are running on a cluster with jsrun installed or '
            'use one of the other launchers.')

    # Only fill in the NCCL interface list when the caller has not chosen one.
    if common_intfs and 'NCCL_SOCKET_IFNAME' not in env:
        env['NCCL_SOCKET_IFNAME'] = ','.join(common_intfs)

    smpi_value = ' '.join(impl_flags)
    if settings.extra_mpi_args:
        smpi_value += ' ' + settings.extra_mpi_args

    # Without explicit binding arguments, bind through a generated rankfile.
    binding_part = settings.binding_args
    if not binding_part:
        rankfile = generate_jsrun_rankfile(settings)
        if settings.verbose >= 2:
            safe_shell_exec.execute('cat {rf}'.format(rf=rankfile))
        binding_part = '--erf_input {rf}'.format(rf=rankfile)

    output_part = ''
    if settings.output_filename:
        output_part = '--stdio_stderr {file} --stdio_stdout {file}'.format(file=settings.output_filename)

    smpi_part = ''
    if smpi_value:
        smpi_part = '--smpiargs {args}'.format(args=quote(smpi_value))

    quoted_command = ' '.join(quote(par) for par in command)

    jsrun_command = (
        'jsrun {binding_args} '
        '{output_filename_arg} '
        '{smpiargs} '
        '{command}'.format(binding_args=binding_part,
                           output_filename_arg=output_part,
                           smpiargs=smpi_part,
                           command=quoted_command))

    if settings.verbose >= 2:
        print(jsrun_command)

    # Execute the jsrun command.
    if not settings.run_func_mode:
        os.execve('/bin/sh', ['/bin/sh', '-c', jsrun_command], env)
        return

    exit_code = run_func(command=jsrun_command, env=env, stdout=stdout, stderr=stderr)
    if exit_code != 0:
        raise RuntimeError("jsrun failed with exit code {exit_code}".format(exit_code=exit_code))