Example #1
0
    def test_mpi_run_on_large_cluster(self):
        """Check the mpirun command built when num_hosts equals large_cluster_threshold."""
        if _get_mpi_implementation_flags(False)[0] is None:
            self.skipTest("MPI is not available")

        # Work on a shallow copy so the shared minimal settings keep their
        # original num_hosts.
        large_settings = copy.copy(self.minimal_settings)
        large_settings.num_hosts = large_cluster_threshold
        fake_run = MagicMock(return_value=0)

        mpi_run(large_settings, None, {}, ['cmd'], run_func=fake_run)

        impl_flags, binding = _get_mpi_implementation_flags(False)
        self.assertIsNotNone(impl_flags)
        # The command is expected to carry these two extra rsh launcher flags
        # after the implementation flags.
        impl_flags.extend([
            '-mca plm_rsh_no_tree_spawn true',
            '-mca plm_rsh_num_concurrent {}'.format(large_settings.num_hosts),
        ])
        command_template = ('mpirun '
                            '--allow-run-as-root --tag-output '
                            '-np 2 -H host '
                            '{binding_args} '
                            '{mpi_flags}       '
                            'cmd')
        fake_run.assert_called_once_with(
            command=command_template.format(
                binding_args=' '.join(binding),
                mpi_flags=' '.join(impl_flags)),
            env={},
            stdout=None,
            stderr=None)
Example #2
0
    def test_mpi_run_full(self):
        """Check the mpirun command when all optional settings and streams are given."""
        if _get_mpi_implementation_flags(False)[0] is None:
            self.skipTest("MPI is not available")

        command = ['cmd', 'arg1', 'arg2']
        interfaces = ['eth0', 'eth1']
        run_env = {'env1': 'val1', 'env2': 'val2'}
        out_stream = '<stdout>'
        err_stream = '<stderr>'
        start_timeout = timeout.Timeout(
            5, message='Timed out waiting for something.')
        # The '>... go here<' placeholders make each setting easy to locate
        # in the expected command below.
        full_settings = hvd_settings.Settings(
            verbose=0,
            ssh_port=1022,
            extra_mpi_args='>mpi-extra args go here<',
            binding_args='>binding args go here<',
            key=secret.make_secret_key(),
            timeout=start_timeout,
            num_hosts=1,
            num_proc=1,
            hosts='>host names go here<',
            output_filename='>output filename goes here<',
            run_func_mode=True)
        fake_run = MagicMock(return_value=0)

        mpi_run(full_settings,
                interfaces,
                run_env,
                command,
                stdout=out_stream,
                stderr=err_stream,
                run_func=fake_run)

        impl_flags, _ = _get_mpi_implementation_flags(False)
        self.assertIsNotNone(impl_flags)
        command_template = (
            'mpirun '
            '--allow-run-as-root --tag-output '
            '-np 1 -H >host names go here< '
            '>binding args go here< '
            '{mpi_flags} '
            '-mca plm_rsh_args "-p 1022" '
            '-mca btl_tcp_if_include eth0,eth1 -x NCCL_SOCKET_IFNAME=eth0,eth1 '
            '--output-filename >output filename goes here< '
            '-x env1 -x env2 '
            '>mpi-extra args go here< '
            'cmd arg1 arg2')
        # Compare against a fresh dict rather than the object handed to mpi_run.
        fake_run.assert_called_once_with(
            command=command_template.format(mpi_flags=' '.join(impl_flags)),
            env={'env1': 'val1', 'env2': 'val2'},
            stdout=out_stream,
            stderr=err_stream)
Example #3
0
    def test_mpi_run_with_non_zero_exit(self):
        """mpi_run must raise RuntimeError when the run function returns exit code 1."""
        if _get_mpi_implementation_flags(False)[0] is None:
            self.skipTest("MPI is not available")

        failing_run = MagicMock(return_value=1)
        expected_message = "^mpirun failed with exit code 1$"

        with pytest.raises(RuntimeError, match=expected_message):
            mpi_run(self.minimal_settings,
                    None,
                    {},
                    ['cmd'],
                    run_func=failing_run)
Example #4
0
    def test_js_run(self):
        """Check the jsrun command that js_run hands to a mocked run function."""
        if _get_mpi_implementation_flags(False)[0] is None:
            self.skipTest("MPI is not available")

        command = ['cmd', 'arg1', 'arg2']
        run_env = {'env1': 'val1', 'env2': 'val2'}
        out_stream = '<stdout>'
        err_stream = '<stderr>'
        js_settings = hvd_settings.Settings(
            verbose=0,
            extra_mpi_args='>mpi-extra args go here<',
            num_hosts=2,
            num_proc=4,
            hosts='>host names go here<',
            output_filename='>output filename goes here<',
            run_func_mode=True)
        fake_run = MagicMock(return_value=0)

        js_run(js_settings,
               None,
               run_env,
               command,
               stdout=out_stream,
               stderr=err_stream,
               run_func=fake_run)

        impl_flags, _ = _get_mpi_implementation_flags(False)
        self.assertIsNotNone(impl_flags)
        command_template = (
            'jsrun '
            '--erf_input /tmp/rankfile '
            '--stdio_stderr >output filename goes here< '
            '--stdio_stdout >output filename goes here< '
            '--smpiargs \'{mpi_args} >mpi-extra args go here<\' '
            'cmd arg1 arg2')
        # Compare against a fresh dict rather than the object handed to js_run.
        fake_run.assert_called_once_with(
            command=command_template.format(mpi_args=' '.join(impl_flags)),
            env={'env1': 'val1', 'env2': 'val2'},
            stdout=out_stream,
            stderr=err_stream)
Example #5
0
    def test_mpi_run_minimal(self):
        """Check the mpirun command built from the minimal settings only."""
        if _get_mpi_implementation_flags() is None:
            self.skipTest("MPI is not available")

        fake_run = MagicMock(return_value=0)

        mpi_run(self.minimal_settings, None, {}, ['cmd'], run_func=fake_run)

        impl_flags = _get_mpi_implementation_flags()
        self.assertIsNotNone(impl_flags)
        command_template = ('mpirun '
                            '--allow-run-as-root --tag-output '
                            '-np 2 -H host '
                            '-bind-to none -map-by slot '
                            '{mpi_flags}       '
                            'cmd')
        fake_run.assert_called_once_with(
            command=command_template.format(mpi_flags=' '.join(impl_flags)),
            env={},
            stdout=None,
            stderr=None)
Example #6
0
    def test_js_run(self):
        """Check the jsrun command js_run passes to a patched safe_shell_exec.execute."""
        if _get_mpi_implementation_flags(False)[0] is None:
            self.skipTest("MPI is not available")

        command = ['cmd', 'arg1', 'arg2']
        run_env = {'env1': 'val1', 'env2': 'val2'}
        out_stream = '<stdout>'
        err_stream = '<stderr>'
        js_settings = hvd_settings.Settings(
            verbose=0,
            extra_mpi_args='>mpi-extra args go here<',
            num_hosts=2,
            num_proc=4,
            hosts='>host names go here<',
            output_filename='>output filename goes here<',
            run_func_mode=True)

        def fake_impl_flags(tcp):
            # Stand-in for the real flag detection: fixed flags, no binding args.
            return ["--mock-mpi-impl-flags"], []

        flags_patch = mock.patch(
            "horovod.run.js_run._get_mpi_implementation_flags",
            side_effect=fake_impl_flags)
        execute_patch = mock.patch(
            "horovod.run.js_run.safe_shell_exec.execute", return_value=0)

        with flags_patch, execute_patch as execute:
            js_run(js_settings,
                   None,
                   run_env,
                   command,
                   stdout=out_stream,
                   stderr=err_stream)

            # call the mocked _get_mpi_implementation_flags method
            impl_flags, _ = horovod.run.js_run._get_mpi_implementation_flags(
                False)
            self.assertIsNotNone(impl_flags)
            command_template = (
                'jsrun '
                '--erf_input /tmp/rankfile '
                '--stdio_stderr >output filename goes here< '
                '--stdio_stdout >output filename goes here< '
                '--smpiargs \'{mpi_args} >mpi-extra args go here<\' '
                'cmd arg1 arg2')
            execute.assert_called_once_with(
                command_template.format(mpi_args=' '.join(impl_flags)),
                env={'env1': 'val1', 'env2': 'val2'},
                stdout=out_stream,
                stderr=err_stream)
Example #7
0
    def do_test_spark_run_func(self,
                               args=(),
                               kwargs=None,
                               num_proc=1,
                               extra_mpi_args=None,
                               env=None,
                               stdout=None,
                               stderr=None,
                               verbose=0,
                               cores=2,
                               expected_np=1,
                               expected_env=''):
        """
        Run horovod.spark.run with a mocked run_func and verify how it was called.

        horovod.spark.run is expected to raise inside the Spark session; the
        call recorded on the mocked run_func is then compared against the
        expected mpirun command, env, stdout and stderr.

        Args:
            args: Positional arguments forwarded to horovod.spark.run.
            kwargs: Keyword arguments forwarded to horovod.spark.run.
                    None means an empty dict.
            num_proc: Forwarded to horovod.spark.run.
            extra_mpi_args: Forwarded to horovod.spark.run and expected
                            verbatim in the command.
            env: Environment dict forwarded to horovod.spark.run.
                 None means an empty dict.
            stdout: Forwarded to horovod.spark.run; expected to reach run_func.
            stderr: Forwarded to horovod.spark.run; expected to reach run_func.
            verbose: Forwarded to horovod.spark.run.
            cores: Passed to spark_session.
            expected_np: Value expected after '-np' in the command.
            expected_env: Text expected right after '-x _HOROVOD_SECRET_KEY '
                          in the command; '' when nothing is expected there.
        """
        # Fresh dicts per call: mutable defaults in the signature would be
        # shared between invocations.
        kwargs = {} if kwargs is None else kwargs
        env = {} if env is None else env

        def fn():
            return 1

        run_func = MagicMock(return_value=0)

        with spark_session('test_spark_run_func', cores=cores):
            with pytest.raises(Exception) as e:
                # we need to timeout horovod because our mocked run_func will block spark otherwise
                # this raises above exception, but allows us to catch run_func arguments
                horovod.spark.run(fn,
                                  args=args,
                                  kwargs=kwargs,
                                  num_proc=num_proc,
                                  start_timeout=1,
                                  extra_mpi_args=extra_mpi_args,
                                  env=env,
                                  stdout=stdout,
                                  stderr=stderr,
                                  verbose=verbose,
                                  run_func=run_func)

        self.assertFalse(
            str(e.value).startswith(
                'Timed out waiting for Spark tasks to start.'),
            'Spark timed out before mpi_run was called, test setup is broken.')
        self.assertEqual(str(e.value),
                         'Spark job has failed, see the error above.')

        mpi_flags = _get_mpi_implementation_flags()
        self.assertIsNotNone(mpi_flags)
        expected_command = (
            'mpirun '
            '--allow-run-as-root --tag-output '
            '-np {expected_np} -H [^ ]+ '
            '-bind-to none -map-by slot '
            '{mpi_flags}  '
            '-mca btl_tcp_if_include [^ ]+ -x NCCL_SOCKET_IFNAME=[^ ]+  '
            '-x _HOROVOD_SECRET_KEY {expected_env}'
            '{extra_mpi_args} '
            '-x NCCL_DEBUG=INFO '
            r'-mca plm_rsh_agent "[^"]+python[\d]* -m horovod.spark.driver.mpirun_rsh [^ ]+ [^ ]+" '
            r'[^"]+python[\d]* -m horovod.spark.task.mpirun_exec_fn [^ ]+ [^ ]+'
            .format(expected_np=expected_np,
                    expected_env=expected_env + ' ' if expected_env else '',
                    mpi_flags=' '.join(mpi_flags),
                    extra_mpi_args=extra_mpi_args if extra_mpi_args else ''))

        run_func.assert_called_once()
        run_func_args, run_func_kwargs = run_func.call_args
        actual_command = run_func_kwargs.get('command')
        actual_env = run_func_kwargs.get('env')
        actual_stdout = run_func_kwargs.get('stdout')
        actual_stderr = run_func_kwargs.get('stderr')
        actual_secret = actual_env.pop('_HOROVOD_SECRET_KEY', None)

        # for better comparison replace sections in actual_command that change across runs / hosts
        for replacement in (
                '-H [^ ]+', '-mca btl_tcp_if_include [^ ]+',
                '-x NCCL_SOCKET_IFNAME=[^ ]+', r'"[^"]+python[\d]*',
                r' [^"]+python[\d]*',
                '-m horovod.spark.driver.mpirun_rsh [^ ]+ [^ ]+"',
                '-m horovod.spark.task.mpirun_exec_fn [^ ]+ [^ ]+'):
            # Each matched section is replaced by the pattern text itself.
            # A callable inserts that text literally; used as a template
            # string, the backslash in '[\d]' is rejected as a bad escape by
            # re.sub on Python 3.7+.
            actual_command = re.sub(
                replacement,
                lambda _match, literal=replacement: literal,
                actual_command,
                count=1)

        self.assertEqual(run_func_args, ())
        self.assertEqual(actual_command, expected_command)
        if env:
            self.assertEqual(actual_env, env)
        else:
            self.assertIsNotNone(actual_env)
        self.assertIsNotNone(actual_secret)
        self.assertTrue(len(actual_secret) > 0)
        self.assertEqual(actual_stdout, stdout)
        self.assertEqual(actual_stderr, stderr)
Example #8
0
def js_run(settings,
           common_intfs,
           env,
           command,
           stdout=None,
           stderr=None,
           run_func=safe_shell_exec.execute):
    """
    Build the jsrun command line for a Horovod job and execute it.

    Args:
        settings: Settings for running jsrun.
                  Note: settings.num_proc and settings.hosts must not be None.
        common_intfs: Interfaces to include by jsrun. When non-empty and env
                      has no 'NCCL_SOCKET_IFNAME' entry, that entry is set in
                      env to the comma-joined interfaces.
        env: Environment dictionary to use for running jsrun.
        command: Command and arguments to run as a list of string.
        stdout: Stdout of the mpi process.
                Only used when settings.run_func_mode is True.
        stderr: Stderr of the mpi process.
                Only used when settings.run_func_mode is True.
        run_func: Run function to use. Must have arguments 'command' and 'env'.
                  Only used when settings.run_func_mode is True.
                  Defaults to safe_shell_exec.execute.

    Raises:
        Exception: If no MPI implementation flags are found or jsrun is not
                   installed.
        RuntimeError: If run_func returns a non-zero exit code.
    """
    impl_flags, _ = _get_mpi_implementation_flags(settings.tcp_flag)
    if impl_flags is None:
        raise Exception(_MPI_NOT_FOUND_ERROR_MSG)

    if not is_jsrun_installed():
        raise Exception(
            'horovodrun convenience script does not find the jsrun command.\n\n'
            'Please, make sure you are running on a cluster with jsrun installed or '
            'use one of the other launchers.')

    if common_intfs and 'NCCL_SOCKET_IFNAME' not in env:
        env['NCCL_SOCKET_IFNAME'] = ','.join(common_intfs)

    # MPI arguments: implementation flags first, then any user-supplied extras.
    mpi_args = ' '.join(impl_flags)
    if settings.extra_mpi_args:
        mpi_args = mpi_args + ' ' + settings.extra_mpi_args

    # Explicit binding arguments win; otherwise bind through a generated rankfile.
    placement = settings.binding_args
    if not placement:
        rankfile = generate_jsrun_rankfile(settings)
        if settings.verbose >= 2:
            safe_shell_exec.execute('cat {rf}'.format(rf=rankfile))
        placement = '--erf_input {rf}'.format(rf=rankfile)

    if settings.output_filename:
        output_arg = '--stdio_stderr {file} --stdio_stdout {file}'.format(
            file=settings.output_filename)
    else:
        output_arg = ''

    if mpi_args:
        smpi_arg = '--smpiargs {args}'.format(args=quote(mpi_args))
    else:
        smpi_arg = ''

    quoted_command = ' '.join(quote(par) for par in command)

    # Empty optional parts still leave their separating space in the command.
    command_template = ('jsrun {binding_args} '
                        '{output_filename_arg} '
                        '{smpiargs} '
                        '{command}')
    jsrun_command = command_template.format(binding_args=placement,
                                            output_filename_arg=output_arg,
                                            smpiargs=smpi_arg,
                                            command=quoted_command)

    if settings.verbose >= 2:
        print(jsrun_command)

    # Execute the jsrun command.
    if not settings.run_func_mode:
        os.execve('/bin/sh', ['/bin/sh', '-c', jsrun_command], env)
        return

    exit_code = run_func(command=jsrun_command,
                         env=env,
                         stdout=stdout,
                         stderr=stderr)
    if exit_code != 0:
        raise RuntimeError(
            "jsrun failed with exit code {exit_code}".format(
                exit_code=exit_code))