Exemple #1
0
    def __init__(self, ray_ctx, verbose=None, start_timeout=None):
        """Launch one Horovod worker per Ray node and prepare each worker's environment.

        Steps performed here:

        1. start ``ray_ctx.num_ray_nodes`` remote worker actors created by
           ``make_horovod_worker(ray_ctx.ray_node_cpu_cores)``;
        2. collect the workers' hostnames and build a host allocation plan;
        3. start a rendezvous server on the driver;
        4. pick a network interface common to all workers;
        5. build ``self.per_worker_envs``, one dict of environment variables per
           worker, holding the Gloo rendezvous address/port plus that worker's
           rank/size values from the allocation plan.

        :param ray_ctx: Ray context exposing ``ray_node_cpu_cores`` and
            ``num_ray_nodes``.
        :param verbose: if truthy, Horovod settings are built with ``verbose=2``,
            otherwise ``verbose=0``.
        :param start_timeout: startup timeout handed to ``timeout.Timeout``; when
            None it is read from the ``HOROVOD_START_TIMEOUT`` environment
            variable (default ``'30'``).
        :raises RuntimeError: if no network interface is common to all workers.
        """
        self.cores_per_node = ray_ctx.ray_node_cpu_cores
        self.num_nodes = ray_ctx.num_ray_nodes
        self.worker_class = make_horovod_worker(self.cores_per_node)
        self.remote_workers = [self.worker_class.remote() for _ in range(self.num_nodes)]

        # Ask every worker for the hostname it is running on.
        hosts = ray.get([worker.hostname.remote() for worker in self.remote_workers])
        hosts_spec, name_rank_to_id, host_to_size = _hosts_to_hosts_spec(hosts)
        self.host_alloc_plan = _allocate(",".join(hosts_spec), self.num_nodes)
        global_rendezv = RendezvousServer(True)
        global_rendezv_port = global_rendezv.start_server(self.host_alloc_plan)

        if start_timeout is None:
            start_timeout = int(os.getenv('HOROVOD_START_TIMEOUT', '30'))

        tmout = timeout.Timeout(start_timeout,
                                message='Timed out waiting for {activity}. Please '
                                        'check connectivity between servers. You '
                                        'may need to increase the --start-timeout '
                                        'parameter if you have too many servers.')

        all_host_names = list(host_to_size)

        settings = hvd_settings.Settings(verbose=2 if verbose else 0,
                                         key=secret.make_secret_key(),
                                         timeout=tmout,
                                         num_hosts=len(all_host_names),
                                         num_proc=self.num_nodes,
                                         hosts=",".join(hosts_spec))

        common_intfs = _find_common_network_interface(host_to_size, name_rank_to_id,
                                                      self.remote_workers, settings)
        if not common_intfs:
            # An interface is required for HOROVOD_GLOO_IFACE and the driver IP
            # below; fail with a clear message instead of a bare IndexError.
            raise RuntimeError(
                'Unable to find a network interface common to all Ray workers '
                'on hosts {}. Please check connectivity between servers.'
                .format(all_host_names))
        iface = next(iter(common_intfs))
        driver_ip = _get_driver_ip([iface])

        common_envs = {
            "HOROVOD_GLOO_RENDEZVOUS_ADDR": driver_ip,
            "HOROVOD_GLOO_RENDEZVOUS_PORT": str(global_rendezv_port),
            "HOROVOD_CONTROLLER": "gloo",
            "HOROVOD_CPU_OPERATIONS": "gloo",
            "HOROVOD_GLOO_IFACE": iface,
            "PYTHONUNBUFFERED": '1',
        }

        # Forward the driver's HOROVOD_* variables to every worker.
        # NOTE(review): these are applied after the computed values above, so a
        # driver-side HOROVOD_GLOO_* / HOROVOD_CONTROLLER setting overrides the
        # computed one -- confirm this is intended.
        common_envs.update((key, value) for key, value in os.environ.items()
                           if key.startswith("HOROVOD"))

        # TODO: support other Horovod envs
        self.per_worker_envs = [common_envs.copy() for _ in range(self.num_nodes)]
        for alloc_info in self.host_alloc_plan:
            # (hostname, local_rank) -> worker index; presumably the index into
            # self.remote_workers -- verify against _hosts_to_hosts_spec.
            key = (alloc_info.hostname, alloc_info.local_rank)
            local_envs = self.per_worker_envs[name_rank_to_id[key]]
            local_envs["HOROVOD_RANK"] = str(alloc_info.rank)
            local_envs["HOROVOD_SIZE"] = str(alloc_info.size)
            local_envs["HOROVOD_LOCAL_RANK"] = str(alloc_info.local_rank)
            local_envs["HOROVOD_LOCAL_SIZE"] = str(alloc_info.local_size)
            local_envs["HOROVOD_CROSS_RANK"] = str(alloc_info.cross_rank)
            local_envs["HOROVOD_CROSS_SIZE"] = str(alloc_info.cross_size)
Exemple #2
0
    def test_mpi_run_full(self):
        """Check the exact mpirun command line and environment mpi_run hands to execute."""
        if not mpi_available():
            self.skipTest("MPI is not available")

        command = ['cmd', 'arg1', 'arg2']
        interfaces = ['eth0', 'eth1']
        environment = {'env1': 'val1', 'env2': 'val2'}
        out_target = '<stdout>'
        err_target = '<stderr>'
        start_tmout = timeout.Timeout(5, message='Timed out waiting for something.')
        settings = hvd_settings.Settings(
            verbose=0,
            ssh_port=1022,
            extra_mpi_args='>mpi-extra args go here<',
            binding_args='>binding args go here<',
            key=secret.make_secret_key(),
            start_timeout=start_tmout,
            num_hosts=1,
            num_proc=1,
            hosts='>host names go here<',
            output_filename='>output filename goes here<',
            run_func_mode=True,
        )

        def fake_impl_flags(tcp, env=None):
            # Stand-in for _get_mpi_implementation_flags: fixed flags, no extras.
            return ["--mock-mpi-impl-flags"], []

        flags_path = "horovod.run.mpi_run._get_mpi_implementation_flags"
        execute_path = "horovod.run.mpi_run.safe_shell_exec.execute"
        with mock.patch(flags_path, side_effect=fake_impl_flags) as flags_mock, \
                mock.patch(execute_path, return_value=0) as execute_mock:
            mpi_run(settings, interfaces, environment, command,
                    stdout=out_target, stderr=err_target)

            # mpi_run queried the implementation flags exactly once, with tcp=None.
            flags_mock.assert_called_once_with(None, env=environment)

            # The patch is still active, so this returns the fake flags.
            mpi_flags, _ = horovod.run.mpi_run._get_mpi_implementation_flags(
                False)
            self.assertIsNotNone(mpi_flags)
            joined_flags = ' '.join(mpi_flags)
            expected_command = (
                'mpirun '
                '--allow-run-as-root --tag-output '
                '-np 1 -H >host names go here< '
                '>binding args go here< '
                '{mpi_flags} '
                '-mca plm_rsh_args "-p 1022" '
                '-mca btl_tcp_if_include eth0,eth1 -x NCCL_SOCKET_IFNAME=eth0,eth1 '
                '--output-filename >output filename goes here< '
                '-x env1 -x env2 '
                '>mpi-extra args go here< '
                'cmd arg1 arg2').format(mpi_flags=joined_flags)
            expected_env = {
                'env1': 'val1',
                'env2': 'val2',
                'PATH': os.environ.get('PATH'),
            }
            execute_mock.assert_called_once_with(expected_command,
                                                 env=expected_env,
                                                 stdout=out_target,
                                                 stderr=err_target)
Exemple #3
0
    def test_generate_jsrun_rankfile(self):
        settings = hvd_settings.Settings(
            num_proc=5,
            hosts='host1:4,host2:4,host3:4',
        )

        with temppath() as rankfile_path:
            rankfile_path = generate_jsrun_rankfile(settings, rankfile_path)

            with open(rankfile_path, 'r') as file:
                gen_rankfile = file.read()

            expected_rankfile = (
"""overlapping_rs: allow
cpu_index_using: logical

rank: 0: { hostname: host1; cpu: {0-3} ; gpu: * ; mem: * }
rank: 1: { hostname: host1; cpu: {4-7} ; gpu: * ; mem: * }
rank: 2: { hostname: host1; cpu: {8-11} ; gpu: * ; mem: * }
rank: 3: { hostname: host1; cpu: {12-15} ; gpu: * ; mem: * }

rank: 4: { hostname: host2; cpu: {0-3} ; gpu: * ; mem: * }
""")

            self.assertMultiLineEqual(gen_rankfile, expected_rankfile)
Exemple #4
0
    def test_mpi_run_full(self):
        """Check the exact mpirun command and env mpi_run passes to a custom run_func."""
        if _get_mpi_implementation_flags(False)[0] is None:
            self.skipTest("MPI is not available")

        command = ['cmd', 'arg1', 'arg2']
        nics = ['eth0', 'eth1']
        environment = {'env1': 'val1', 'env2': 'val2'}
        out_target = '<stdout>'
        err_target = '<stderr>'
        launch_tmout = timeout.Timeout(5, message='Timed out waiting for something.')
        settings = hvd_settings.Settings(
            verbose=0,
            ssh_port=1022,
            extra_mpi_args='>mpi-extra args go here<',
            binding_args='>binding args go here<',
            key=secret.make_secret_key(),
            timeout=launch_tmout,
            num_hosts=1,
            num_proc=1,
            hosts='>host names go here<',
            output_filename='>output filename goes here<',
            run_func_mode=True,
        )
        fake_run = MagicMock(return_value=0)

        mpi_run(settings, nics, environment, command,
                stdout=out_target, stderr=err_target, run_func=fake_run)

        mpi_flags, _ = _get_mpi_implementation_flags(False)
        self.assertIsNotNone(mpi_flags)
        joined_flags = ' '.join(mpi_flags)
        expected_command = (
            'mpirun '
            '--allow-run-as-root --tag-output '
            '-np 1 -H >host names go here< '
            '>binding args go here< '
            '{mpi_flags} '
            '-mca plm_rsh_args "-p 1022" '
            '-mca btl_tcp_if_include eth0,eth1 -x NCCL_SOCKET_IFNAME=eth0,eth1 '
            '--output-filename >output filename goes here< '
            '-x env1 -x env2 '
            '>mpi-extra args go here< '
            'cmd arg1 arg2').format(mpi_flags=joined_flags)
        fake_run.assert_called_once_with(command=expected_command,
                                         env={'env1': 'val1', 'env2': 'val2'},
                                         stdout=out_target,
                                         stderr=err_target)
Exemple #5
0
    def test_js_run(self):
        """Check the exact jsrun command line and environment js_run hands to execute."""
        if _get_mpi_implementation_flags(False)[0] is None:
            self.skipTest("MPI is not available")

        command = ['cmd', 'arg1', 'arg2']
        environment = {'env1': 'val1', 'env2': 'val2'}
        out_target = '<stdout>'
        err_target = '<stderr>'
        settings = hvd_settings.Settings(
            verbose=0,
            extra_mpi_args='>mpi-extra args go here<',
            num_hosts=2,
            num_proc=4,
            hosts='>host names go here<',
            output_filename='>output filename goes here<',
            run_func_mode=True,
        )

        def fake_impl_flags(tcp):
            # Stand-in for _get_mpi_implementation_flags: fixed flags, no extras.
            return ["--mock-mpi-impl-flags"], []

        flags_path = "horovod.run.js_run._get_mpi_implementation_flags"
        execute_path = "horovod.run.js_run.safe_shell_exec.execute"
        with mock.patch(flags_path, side_effect=fake_impl_flags), \
                mock.patch(execute_path, return_value=0) as execute_mock:
            js_run(settings, None, environment, command,
                   stdout=out_target, stderr=err_target)

            # The patch is still active, so this returns the fake flags.
            mpi_flags, _ = horovod.run.js_run._get_mpi_implementation_flags(
                False)
            self.assertIsNotNone(mpi_flags)
            joined_flags = ' '.join(mpi_flags)
            expected_command = (
                'jsrun '
                '--erf_input /tmp/rankfile '
                '--stdio_stderr >output filename goes here< '
                '--stdio_stdout >output filename goes here< '
                '--smpiargs \'{mpi_args} >mpi-extra args go here<\' '
                'cmd arg1 arg2').format(mpi_args=joined_flags)
            execute_mock.assert_called_once_with(expected_command,
                                                 env={'env1': 'val1', 'env2': 'val2'},
                                                 stdout=out_target,
                                                 stderr=err_target)
Exemple #6
0
    def test_js_run(self):
        """Check the exact jsrun command and env js_run passes to a custom run_func."""
        if _get_mpi_implementation_flags(False)[0] is None:
            self.skipTest("MPI is not available")

        command = ['cmd', 'arg1', 'arg2']
        environment = {'env1': 'val1', 'env2': 'val2'}
        out_target = '<stdout>'
        err_target = '<stderr>'
        settings = hvd_settings.Settings(
            verbose=0,
            extra_mpi_args='>mpi-extra args go here<',
            num_hosts=2,
            num_proc=4,
            hosts='>host names go here<',
            output_filename='>output filename goes here<',
            run_func_mode=True,
        )
        fake_run = MagicMock(return_value=0)

        js_run(settings, None, environment, command,
               stdout=out_target, stderr=err_target, run_func=fake_run)

        mpi_flags, _ = _get_mpi_implementation_flags(False)
        self.assertIsNotNone(mpi_flags)
        joined_flags = ' '.join(mpi_flags)
        expected_command = (
            'jsrun '
            '--erf_input /tmp/rankfile '
            '--stdio_stderr >output filename goes here< '
            '--stdio_stdout >output filename goes here< '
            '--smpiargs \'{mpi_args} >mpi-extra args go here<\' '
            'cmd arg1 arg2').format(mpi_args=joined_flags)
        fake_run.assert_called_once_with(command=expected_command,
                                         env={'env1': 'val1', 'env2': 'val2'},
                                         stdout=out_target,
                                         stderr=err_target)
Exemple #7
0
class RunTests(unittest.TestCase):
    """
    Tests for horovod.run.
    """
    def __init__(self, *args, **kwargs):
        super(RunTests, self).__init__(*args, **kwargs)
        # Report each distinct warning once per module rather than per call site.
        warnings.simplefilter('module')

    def test_params_args(self):
        """Tuning parameters on the command line become HOROVOD_* env values."""
        with override_args('horovodrun', '-np', '2', '--fusion-threshold-mb',
                           '10', '--cycle-time-ms', '20', '--cache-capacity',
                           '512', '--hierarchical-allreduce',
                           '--hierarchical-allgather'):
            args = parse_args()
            env = {}
            config_parser.set_env_from_args(env, args)

            # The fusion threshold is given in MB but exported in bytes.
            self.assertEqual(env.get(config_parser.HOROVOD_FUSION_THRESHOLD),
                             str(10 * 1024 * 1024))
            self.assertEqual(env.get(config_parser.HOROVOD_CYCLE_TIME), '20.0')
            self.assertEqual(env.get(config_parser.HOROVOD_CACHE_CAPACITY),
                             '512')
            self.assertEqual(
                env.get(config_parser.HOROVOD_HIERARCHICAL_ALLREDUCE), '1')
            self.assertEqual(
                env.get(config_parser.HOROVOD_HIERARCHICAL_ALLGATHER), '1')

    def test_autotune_args(self):
        """Autotune options on the command line become HOROVOD_AUTOTUNE* env values."""
        with override_args('horovodrun', '-np', '2', '--autotune',
                           '--autotune-log-file', '/tmp/autotune.txt',
                           '--autotune-warmup-samples', '1',
                           '--autotune-steps-per-sample', '5',
                           '--autotune-bayes-opt-max-samples', '10',
                           '--autotune-gaussian-process-noise', '0.2'):
            args = parse_args()
            env = {}
            config_parser.set_env_from_args(env, args)

            self.assertEqual(env.get(config_parser.HOROVOD_AUTOTUNE), '1')
            self.assertEqual(env.get(config_parser.HOROVOD_AUTOTUNE_LOG),
                             '/tmp/autotune.txt')
            self.assertEqual(
                env.get(config_parser.HOROVOD_AUTOTUNE_WARMUP_SAMPLES), '1')
            self.assertEqual(
                env.get(config_parser.HOROVOD_AUTOTUNE_STEPS_PER_SAMPLE), '5')
            self.assertEqual(
                env.get(config_parser.HOROVOD_AUTOTUNE_BAYES_OPT_MAX_SAMPLES),
                '10')
            self.assertEqual(
                env.get(config_parser.HOROVOD_AUTOTUNE_GAUSSIAN_PROCESS_NOISE),
                '0.2')

    def test_autotuning_with_fixed_param(self):
        """With --autotune, only explicitly given parameters are exported."""
        with override_args('horovodrun', '-np', '2', '--autotune',
                           '--cache-capacity', '1024',
                           '--no-hierarchical-allgather'):
            args = parse_args()
            env = {}
            config_parser.set_env_from_args(env, args)

            self.assertNotIn(config_parser.HOROVOD_FUSION_THRESHOLD, env)
            self.assertNotIn(config_parser.HOROVOD_CYCLE_TIME, env)
            self.assertEqual(env.get(config_parser.HOROVOD_CACHE_CAPACITY),
                             '1024')
            self.assertNotIn(config_parser.HOROVOD_HIERARCHICAL_ALLREDUCE, env)
            self.assertEqual(
                env.get(config_parser.HOROVOD_HIERARCHICAL_ALLGATHER), '0')

    def test_timeline_args(self):
        """Timeline options become HOROVOD_TIMELINE* env values."""
        with override_args('horovodrun', '-np', '2', '--timeline-filename',
                           '/tmp/timeline.json', '--timeline-mark-cycles'):
            args = parse_args()
            env = {}
            config_parser.set_env_from_args(env, args)

            self.assertEqual(env.get(config_parser.HOROVOD_TIMELINE),
                             '/tmp/timeline.json')
            self.assertEqual(
                env.get(config_parser.HOROVOD_TIMELINE_MARK_CYCLES), '1')

    def test_stall_check_args(self):
        """Stall-check options: disabling, and custom warning/shutdown times."""
        with override_args('horovodrun', '-np', '2', '--no-stall-check'):
            args = parse_args()
            env = {}
            config_parser.set_env_from_args(env, args)

            self.assertEqual(
                env.get(config_parser.HOROVOD_STALL_CHECK_DISABLE), '1')

        with override_args('horovodrun', '-np', '2',
                           '--stall-check-warning-time-seconds', '10',
                           '--stall-check-shutdown-time-seconds', '20'):
            args = parse_args()
            env = {}
            config_parser.set_env_from_args(env, args)

            self.assertNotIn(config_parser.HOROVOD_STALL_CHECK_DISABLE, env)
            self.assertEqual(
                env.get(config_parser.HOROVOD_STALL_CHECK_TIME_SECONDS), '10')
            self.assertEqual(
                env.get(config_parser.HOROVOD_STALL_SHUTDOWN_TIME_SECONDS),
                '20')

    def test_library_args(self):
        """MPI/NCCL/CCL/Gloo library options become their HOROVOD_* env values."""
        with override_args('horovodrun', '-np', '2', '--mpi-threads-disable',
                           '--num-nccl-streams', '2', '--ccl-bgt-affinity',
                           '1', '--gloo-timeout-seconds', '60'):
            args = parse_args()
            env = {}
            config_parser.set_env_from_args(env, args)

            self.assertEqual(
                env.get(config_parser.HOROVOD_MPI_THREADS_DISABLE), '1')
            self.assertEqual(env.get(config_parser.HOROVOD_NUM_NCCL_STREAMS),
                             '2')
            self.assertEqual(env.get(config_parser.HOROVOD_CCL_BGT_AFFINITY),
                             '1')
            self.assertEqual(
                env.get(config_parser.HOROVOD_GLOO_TIMEOUT_SECONDS), '60')

    def test_logging_args(self):
        """Logging options become HOROVOD_LOG_* env values."""
        with override_args('horovodrun', '-np', '2', '--log-level', 'INFO',
                           '--log-hide-timestamp'):
            args = parse_args()
            env = {}
            config_parser.set_env_from_args(env, args)

            self.assertEqual(env.get(config_parser.HOROVOD_LOG_LEVEL), 'INFO')
            self.assertEqual(env.get(config_parser.HOROVOD_LOG_HIDE_TIME), '1')

    def test_config_file(self):
        """All option groups are populated from data/config.test.yaml."""
        config_filename = os.path.join(os.path.dirname(__file__),
                                       'data/config.test.yaml')
        with override_args('horovodrun', '-np', '2', '--config-file',
                           config_filename):
            args = parse_args()

            self.assertTrue(args.use_gloo)

            # Params
            self.assertEqual(args.fusion_threshold_mb, 32)
            self.assertEqual(args.cycle_time_ms, 10)
            self.assertEqual(args.cache_capacity, 2048)
            self.assertTrue(args.hierarchical_allreduce)
            self.assertTrue(args.hierarchical_allgather)

            # Autotune
            self.assertTrue(args.autotune)
            self.assertEqual(args.autotune_log_file,
                             'horovod_autotune_log.txt')
            self.assertEqual(args.autotune_warmup_samples, 5)
            self.assertEqual(args.autotune_steps_per_sample, 20)
            self.assertEqual(args.autotune_bayes_opt_max_samples, 50)
            self.assertEqual(args.autotune_gaussian_process_noise, 0.9)

            # Timeline
            self.assertEqual(args.timeline_filename, 'horovod_timeline.json')
            self.assertTrue(args.timeline_mark_cycles)

            # Stall Check
            self.assertFalse(args.no_stall_check)
            self.assertEqual(args.stall_check_warning_time_seconds, 120)
            self.assertEqual(args.stall_check_shutdown_time_seconds, 240)

            # Library Options
            self.assertTrue(args.mpi_threads_disable)
            self.assertEqual(args.num_nccl_streams, 2)
            self.assertEqual(args.ccl_bgt_affinity, 1)
            self.assertEqual(args.gloo_timeout_seconds, 60)

            # Logging
            self.assertEqual(args.log_level, 'INFO')
            self.assertTrue(args.log_hide_timestamp)

    def test_config_file_override_args(self):
        """Command-line values win over the config file, before or after --config-file."""
        config_filename = os.path.join(os.path.dirname(__file__),
                                       'data/config.test.yaml')
        with override_args(
                'horovodrun',
                '-np',
                '2',
                '--fusion-threshold-mb',
                '128',
                '--config-file',
                config_filename,
                '--cycle-time-ms',
                '20',
        ):
            args = parse_args()
            self.assertEqual(args.fusion_threshold_mb, 128)
            self.assertEqual(args.cycle_time_ms, 20)

    def test_validate_config_args(self):
        """A negative fusion threshold is rejected with ValueError."""
        with override_args('horovodrun', '-np', '2', '--fusion-threshold-mb',
                           '-1'):
            with pytest.raises(ValueError):
                parse_args()

    # test_on_event tests in_thread as well, but it does not test args
    def test_in_thread_args(self):
        """in_thread forwards tuple args to the function and rejects non-tuples."""
        fn = mock.Mock()
        thread = in_thread(fn, args=(1, ))
        thread.join(1.0)
        self.assertFalse(thread.is_alive())
        fn.assert_called_once_with(1)

        fn = mock.Mock()
        thread = in_thread(fn, args=(1, 2))
        thread.join(1.0)
        self.assertFalse(thread.is_alive())
        fn.assert_called_once_with(1, 2)

        fn = mock.Mock()
        thread = in_thread(fn, args=(1, 2), silent=True)
        thread.join(1.0)
        self.assertFalse(thread.is_alive())
        fn.assert_called_once_with(1, 2)

        fn = mock.Mock()
        with pytest.raises(
                ValueError,
                match="^args must be a tuple, not <(class|type) 'int'>, "
                "for a single argument use \\(arg,\\)$"):
            in_thread(fn, args=1)
        fn.assert_not_called()

    def test_on_event(self):
        """on_event calls fn once the event is set, honours stop, and survives exceptions."""
        # a happy run without args and stop event
        event = threading.Event()
        fn = mock.Mock()
        thread = on_event(event, fn)
        fn.assert_not_called()
        event.set()
        thread.join(1.0)
        self.assertFalse(thread.is_alive())
        fn.assert_called_once()

        # a happy run with args but without stop event
        event = threading.Event()
        fn = mock.Mock()
        thread = on_event(event, fn, ('a', 1))
        fn.assert_not_called()
        event.set()
        thread.join(1.0)
        self.assertFalse(thread.is_alive())
        fn.assert_called_once()
        fn.assert_called_once_with('a', 1)

        # a happy run with stop event but unused
        event = threading.Event()
        stop = threading.Event()
        fn = mock.Mock()
        thread = on_event(event, fn, stop=stop, check_interval_seconds=0.01)
        fn.assert_not_called()
        event.set()
        thread.join(1.0)
        self.assertFalse(thread.is_alive())
        fn.assert_called_once()
        stop.set()
        time.sleep(0.1)
        fn.assert_called_once()

        # stop the thread before we set the event
        event = threading.Event()
        stop = threading.Event()
        fn = mock.Mock()
        thread = on_event(event, fn, stop=stop, check_interval_seconds=0.01)
        fn.assert_not_called()
        stop.set()
        thread.join(1.0)
        self.assertFalse(thread.is_alive())
        fn.assert_not_called()
        event.set()
        time.sleep(0.1)
        fn.assert_not_called()

        # test with exception
        def exception():
            raise Exception("Test Exception")

        event = threading.Event()
        fn = mock.Mock(side_effect=exception)
        thread = on_event(event, fn)
        fn.assert_not_called()
        event.set()
        thread.join(1.0)
        self.assertFalse(thread.is_alive())
        fn.assert_called_once()

        # test with exception but silent
        event = threading.Event()
        fn = mock.Mock(side_effect=exception)
        thread = on_event(event, fn, silent=True)
        fn.assert_not_called()
        event.set()
        thread.join(1.0)
        self.assertFalse(thread.is_alive())
        fn.assert_called_once()

        # test non-tuple args
        event = threading.Event()
        fn = mock.Mock()
        with pytest.raises(
                ValueError,
                match="^args must be a tuple, not <(class|type) 'int'>, "
                "for a single argument use \\(arg,\\)$"):
            on_event(event, fn, args=1)
        fn.assert_not_called()

    def test_safe_shell_exec_captures_stdout(self):
        self.do_test_safe_shell_exec('echo hello', 0, 'hello\n', '')

    def test_safe_shell_exec_captures_stderr(self):
        self.do_test_safe_shell_exec('echo hello >&2', 0, '', 'hello\n')

    def test_safe_shell_exec_captures_last_line_wo_eol(self):
        # Output without a trailing newline must still be captured in full.
        cmd = 'bash -c "echo -e -n \\"hello\nstdout\\"; echo -e -n \\"hello\nstderr\\" >&2"'
        self.do_test_safe_shell_exec(cmd, 0, 'hello\nstdout', 'hello\nstderr')

    def test_safe_shell_exec_returns_exit_code(self):
        self.do_test_safe_shell_exec('false', 1, '', '')

    def test_safe_shell_exec_interrupts_on_event(self):
        """Setting an interrupt event terminates the command early (exit code 143)."""
        # interrupt execute in one second
        interrupt = threading.Event()
        delay(interrupt.set, 1.0)

        sleep = 10
        start = time.time()
        self.do_test_safe_shell_exec('sleep {}'.format(sleep), 143, '', None,
                                     interrupt)
        duration = time.time() - start

        self.assertGreaterEqual(duration, 1.0)
        self.assertLess(duration,
                        2.0 + safe_shell_exec.GRACEFUL_TERMINATION_TIME_S,
                        'sleep should not finish')
        self.assertGreater(
            sleep, 2.0 + safe_shell_exec.GRACEFUL_TERMINATION_TIME_S,
            'sleep should allow for GRACEFUL_TERMINATION_TIME_S')

    def test_safe_shell_exec_interrupts_on_parent_shutdown(self):
        """Hard-killing the parent of safe_shell_exec also ends the child process."""
        sleep = 20
        parent_script = os.path.join(os.path.dirname(__file__),
                                     'data/run_safe_shell_exec.py')
        child_script = os.path.join(os.path.dirname(__file__), 'data/sleep.py')

        def get_pid(logfile):
            # Wait until the script has written its PID to the logfile
            wait(lambda: os.path.exists(logfile), timeout=5)
            with open(logfile, 'r') as f:
                return int(f.read())

        with temppath() as parent_logfile, temppath() as child_logfile:
            # It's important that this executes in an entirely different interpreter with as little shared
            # state as possible, to avoid issues with the semaphore tracker.
            cmd = ' '.join([
                sys.executable, parent_script, parent_logfile, child_script,
                str(sleep), child_logfile
            ])
            p = subprocess.Popen(cmd, shell=True)

            parent = psutil.Process(get_pid(parent_logfile))
            child = psutil.Process(get_pid(child_logfile))

            self.assertTrue(parent.is_running())
            self.assertTrue(child.is_running())

            # Hard kill the parent process
            parent.kill()
            parent.wait(timeout=safe_shell_exec.GRACEFUL_TERMINATION_TIME_S)
            p.wait()

            # Child process will exit when pipe breaks
            child.wait(
                timeout=2 * safe_shell_exec.GRACEFUL_TERMINATION_TIME_S + 1)

            self.assertFalse(parent.is_running())
            self.assertFalse(child.is_running())

    def do_test_safe_shell_exec(self,
                                cmd,
                                expected_exit_code,
                                expected_stdout,
                                expected_stderr,
                                event=None):
        """Run cmd through safe_shell_exec.execute and check its results.

        :param cmd: shell command string.
        :param expected_exit_code: value execute must return.
        :param expected_stdout: expected captured stdout, or None to skip the check.
        :param expected_stderr: expected captured stderr, or None to skip the check.
        :param event: optional interrupt event forwarded to execute.
        """
        stdout = io.StringIO()
        stderr = io.StringIO()
        # Only forward a real event; never hand execute a None entry.
        events = [event] if event is not None else []
        res = safe_shell_exec.execute(cmd,
                                      stdout=stdout,
                                      stderr=stderr,
                                      events=events)
        self.assertEqual(expected_exit_code, res)
        if expected_stdout is not None:
            self.assertEqual(expected_stdout, stdout.getvalue())
        if expected_stderr is not None:
            self.assertEqual(expected_stderr, stderr.getvalue())

    def test_hash(self):
        digest = _hash("test string")
        self.assertEqual(digest, '6f8db599de986fab7a21625b7916589c')

    def test_host_hash(self):
        original = host_hash()
        # host_hash should consider CONTAINER_ID environment variable
        with override_env({'CONTAINER_ID': 'a container id'}):
            self.assertNotEqual(host_hash(), original)
        self.assertEqual(host_hash(), original)

    def test_settings_dump_drops_key(self):
        """Serialising Settings keeps verbose but drops the secret key."""
        settings = hvd_settings.Settings(verbose=2, key="a secret key")
        clone = codec.loads_base64(codec.dumps_base64(settings))
        self.assertEqual(settings.verbose, clone.verbose)
        self.assertIsNotNone(settings.key)
        self.assertIsNone(clone.key)

    def test_get_mpi_implementation(self):
        """_get_mpi_implementation classifies mpirun version output."""
        def test(output, expected, exit_code=0):
            # None output simulates the command not producing a result at all.
            ret = (output, exit_code) if output is not None else None
            with mock.patch("horovod.run.mpi_run.tiny_shell_exec.execute",
                            return_value=ret):
                implementation = _get_mpi_implementation()
                self.assertEqual(expected, implementation)

        test(("mpirun (Open MPI) 2.1.1\n"
              "Report bugs to http://www.open-mpi.org/community/help/\n"),
             _OMPI_IMPL)

        test("OpenRTE", _OMPI_IMPL)

        test("IBM Spectrum MPI", _SMPI_IMPL)

        test(("HYDRA build details:\n"
              "    Version:           3.3a2\n"
              "    Configure options: 'MPICHLIB_CFLAGS=-g -O2'\n"),
             _MPICH_IMPL)

        test("Unknown MPI v1.00", _UNKNOWN_IMPL)

        test("output", exit_code=1, expected=_MISSING_IMPL)

        test(None, _MISSING_IMPL)

    def test_run_controller(self):
        """run_controller picks gloo/mpi/js or raises, for every flag combination."""
        def test(use_gloo, use_mpi, use_js, gloo_is_built, mpi_is_built,
                 lsf_exists, jsrun_installed, expected, exception):
            gloo_run = MagicMock()
            mpi_run = MagicMock()
            js_run = MagicMock()

            with is_built(gloo_is_built, mpi_is_built):
                with lsf_and_jsrun(lsf_exists, jsrun_installed):
                    if exception is not None:
                        with pytest.raises(ValueError, match=exception):
                            run_controller(use_gloo,
                                           gloo_run,
                                           use_mpi,
                                           mpi_run,
                                           use_js,
                                           js_run,
                                           verbosity=2)
                        return
                    run_controller(use_gloo,
                                   gloo_run,
                                   use_mpi,
                                   mpi_run,
                                   use_js,
                                   js_run,
                                   verbosity=2)

            if expected == "gloo":
                gloo_run.assert_called_once()
                mpi_run.assert_not_called()
                js_run.assert_not_called()
            elif expected == "mpi":
                gloo_run.assert_not_called()
                mpi_run.assert_called_once()
                js_run.assert_not_called()
            elif expected == "js":
                gloo_run.assert_not_called()
                mpi_run.assert_not_called()
                js_run.assert_called_once()
            else:
                raise ValueError("unsupported framework: {}".format(expected))

        bool_values = [False, True]
        bool_values_and_none = [None, False, True]

        for use_gloo, use_mpi, use_js, \
            gloo_is_built, mpi_is_built, \
            lsf_exists, jsrun_installed in \
            itertools.product(bool_values_and_none, bool_values_and_none, bool_values_and_none,
                              bool_values, bool_values,
                              bool_values, bool_values):

            # Compute the expected outcome: explicit flags take priority
            # (gloo, then mpi, then js); otherwise fall back on what is built.
            expected = exception = None
            if use_gloo:
                if gloo_is_built:
                    expected = 'gloo'
                else:
                    exception = r'^Gloo support has not been built\.  If this is not expected, ensure CMake is installed ' \
                                r'and reinstall Horovod with HOROVOD_WITH_GLOO=1 to debug the build error\.$'
            elif use_mpi:
                if mpi_is_built:
                    expected = 'mpi'
                else:
                    exception = r'^MPI support has not been built\.  If this is not expected, ensure MPI is installed ' \
                                r'and reinstall Horovod with HOROVOD_WITH_MPI=1 to debug the build error\.$'
            elif use_js:
                if mpi_is_built:
                    if lsf_exists:
                        expected = 'js'
                    else:
                        exception = 'Horovod did not detect an LSF job.  The jsrun launcher can only be used in that environment. ' \
                                    'Please, pick a different launcher for other environments.'
                else:
                    exception = r'^MPI support has not been built\.  If this is not expected, ensure MPI is installed ' \
                                r'and reinstall Horovod with HOROVOD_WITH_MPI=1 to debug the build error\.$'
            elif mpi_is_built:
                if lsf_exists and jsrun_installed:
                    expected = 'js'
                else:
                    expected = 'mpi'
            elif gloo_is_built:
                expected = 'gloo'
            else:
                exception = r'Neither MPI nor Gloo support has been built\. Try reinstalling Horovod ensuring that ' \
                            r'either MPI is installed \(MPI\) or CMake is installed \(Gloo\)\.'

            test(use_gloo, use_mpi, use_js, gloo_is_built, mpi_is_built,
                 lsf_exists, jsrun_installed, expected, exception)

    """
    Minimal mpi_run settings for tests.
    """
    # Shared, class-level Settings used by the mpi_run tests: one host named
    # 'host' running two processes, with run_func_mode enabled. Because it is
    # class state shared by every test, copy it (copy.copy) before changing
    # any field, as test_mpi_run_on_large_cluster does.
    minimal_settings = hvd_settings.Settings(verbose=0,
                                             num_hosts=1,
                                             num_proc=2,
                                             hosts='host',
                                             run_func_mode=True)
    """
    Tests mpi_run with minimal settings.
    """

    def test_mpi_run_minimal(self):
        """mpi_run with the shared minimal settings emits the bare mpirun line.

        The MPI implementation flags are patched to fixed values and the shell
        execution is patched out, so only the assembled command and env are
        checked.
        """
        if not mpi_available():
            self.skipTest("MPI is not available")

        def fake_impl_flags(tcp):
            # Stand-in for _get_mpi_implementation_flags: (flags, binding args).
            return ["--mock-mpi-impl-flags"], ["--mock-mpi-binding-args"]

        with mock.patch("horovod.run.mpi_run._get_mpi_implementation_flags",
                        side_effect=fake_impl_flags), \
                mock.patch("horovod.run.mpi_run.safe_shell_exec.execute",
                           return_value=0) as execute:
            mpi_run(self.minimal_settings, None, {}, ['cmd'])

            # Ask the patched helper for its values so the expectation is
            # built from exactly what mpi_run saw.
            impl_flags, binding = \
                horovod.run.mpi_run._get_mpi_implementation_flags(False)
            self.assertIsNotNone(impl_flags)

            template = ('mpirun '
                        '--allow-run-as-root --tag-output '
                        '-np 2 -H host '
                        '{binding_args} '
                        '{mpi_flags}       '
                        'cmd')
            expected_cmd = template.format(binding_args=' '.join(binding),
                                           mpi_flags=' '.join(impl_flags))
            execute.assert_called_once_with(expected_cmd,
                                            env={'PATH': os.environ.get('PATH')},
                                            stdout=None,
                                            stderr=None)

    """
    Tests mpi_run on a large cluster.
    """

    def test_mpi_run_on_large_cluster(self):
        """At the large-cluster threshold mpi_run adds rsh tree-spawn flags.

        Expects `-mca plm_rsh_no_tree_spawn true` and
        `-mca plm_rsh_num_concurrent <num_hosts>` after the implementation
        flags.
        """
        if not mpi_available():
            self.skipTest("MPI is not available")

        # Work on a copy so the class-level settings object is not modified.
        big_cluster = copy.copy(self.minimal_settings)
        big_cluster.num_hosts = large_cluster_threshold

        def fake_impl_flags(tcp):
            # Stand-in for _get_mpi_implementation_flags: (flags, binding args).
            return ["--mock-mpi-impl-flags"], ["--mock-mpi-binding-args"]

        with mock.patch("horovod.run.mpi_run._get_mpi_implementation_flags",
                        side_effect=fake_impl_flags), \
                mock.patch("horovod.run.mpi_run.safe_shell_exec.execute",
                           return_value=0) as execute:
            mpi_run(big_cluster, None, {}, ['cmd'])

            impl_flags, binding = \
                horovod.run.mpi_run._get_mpi_implementation_flags(False)
            self.assertIsNotNone(impl_flags)
            # Extra flags expected for a large host count.
            impl_flags += [
                '-mca plm_rsh_no_tree_spawn true',
                '-mca plm_rsh_num_concurrent {}'.format(big_cluster.num_hosts),
            ]

            template = ('mpirun '
                        '--allow-run-as-root --tag-output '
                        '-np 2 -H host '
                        '{binding_args} '
                        '{mpi_flags}       '
                        'cmd')
            expected_cmd = template.format(binding_args=' '.join(binding),
                                           mpi_flags=' '.join(impl_flags))
            execute.assert_called_once_with(expected_cmd,
                                            env={'PATH': os.environ.get('PATH')},
                                            stdout=None,
                                            stderr=None)

    """
    Tests mpi_run with full settings.
    """

    def test_mpi_run_full(self):
        """Every populated Settings field shows up in the mpirun command.

        Covers ssh port, binding args, NIC selection, output filename, env
        forwarding (`-x`), extra MPI args, and stdout/stderr pass-through.
        """
        if not mpi_available():
            self.skipTest("MPI is not available")

        command = ['cmd', 'arg1', 'arg2']
        interfaces = ['eth0', 'eth1']
        job_env = {'env1': 'val1', 'env2': 'val2'}
        out_sink, err_sink = '<stdout>', '<stderr>'
        start_budget = timeout.Timeout(5, message='Timed out waiting for something.')
        settings = hvd_settings.Settings(
            verbose=0,
            ssh_port=1022,
            extra_mpi_args='>mpi-extra args go here<',
            binding_args='>binding args go here<',
            key=secret.make_secret_key(),
            start_timeout=start_budget,
            num_hosts=1,
            num_proc=1,
            hosts='>host names go here<',
            output_filename='>output filename goes here<',
            run_func_mode=True)

        def fake_impl_flags(tcp):
            # Only implementation flags; binding args come from settings.
            return ["--mock-mpi-impl-flags"], []

        with mock.patch("horovod.run.mpi_run._get_mpi_implementation_flags",
                        side_effect=fake_impl_flags), \
                mock.patch("horovod.run.mpi_run.safe_shell_exec.execute",
                           return_value=0) as execute:
            mpi_run(settings, interfaces, job_env, command,
                    stdout=out_sink, stderr=err_sink)

            impl_flags, _ = \
                horovod.run.mpi_run._get_mpi_implementation_flags(False)
            self.assertIsNotNone(impl_flags)

            template = ('mpirun '
                        '--allow-run-as-root --tag-output '
                        '-np 1 -H >host names go here< '
                        '>binding args go here< '
                        '{mpi_flags} '
                        '-mca plm_rsh_args "-p 1022" '
                        '-mca btl_tcp_if_include eth0,eth1 -x NCCL_SOCKET_IFNAME=eth0,eth1 '
                        '--output-filename >output filename goes here< '
                        '-x env1 -x env2 '
                        '>mpi-extra args go here< '
                        'cmd arg1 arg2')
            execute.assert_called_once_with(
                template.format(mpi_flags=' '.join(impl_flags)),
                env={'env1': 'val1',
                     'env2': 'val2',
                     'PATH': os.environ.get('PATH')},
                stdout=out_sink,
                stderr=err_sink)

    def test_mpi_run_with_non_zero_exit(self):
        """A non-zero exit status from the shell surfaces as RuntimeError."""
        if not mpi_available():
            self.skipTest("MPI is not available")

        def no_impl_flags(tcp):
            return [], []

        with mock.patch("horovod.run.mpi_run._get_mpi_implementation_flags",
                        side_effect=no_impl_flags), \
                mock.patch("horovod.run.mpi_run.safe_shell_exec.execute",
                           return_value=1), \
                pytest.raises(RuntimeError,
                              match="^mpirun failed with exit code 1$"):
            mpi_run(self.minimal_settings, None, {}, ['cmd'])

    def test_horovodrun_hostfile(self):
        """parse_host_files maps `<host> slots=<n>` lines to `host:n,...`."""
        with temppath() as hostfile:
            with open(hostfile, 'w+') as handle:
                handle.writelines(['172.31.32.7 slots=8\n',
                                   '172.31.33.9 slots=8\n'])

            self.assertEqual(parse_host_files(hostfile),
                             '172.31.32.7:8,172.31.33.9:8')

    """
    Tests js_run.
    """

    @mock.patch('horovod.run.js_run.is_jsrun_installed',
                MagicMock(return_value=True))
    @mock.patch('horovod.run.js_run.generate_jsrun_rankfile',
                MagicMock(return_value='/tmp/rankfile'))
    @mock.patch('horovod.run.util.lsf.LSFUtils.get_num_gpus',
                MagicMock(return_value=2))
    @mock.patch('horovod.run.util.lsf.LSFUtils.get_num_cores',
                MagicMock(return_value=2))
    def test_js_run(self):
        """js_run wraps the command in jsrun using the (mocked) rankfile.

        The output filename is used for both stdio streams. The MPI flags and
        the extra MPI args travel inside `--smpiargs`.
        """
        if _get_mpi_implementation_flags(False)[0] is None:
            self.skipTest("MPI is not available")

        command = ['cmd', 'arg1', 'arg2']
        job_env = {'env1': 'val1', 'env2': 'val2'}
        out_sink, err_sink = '<stdout>', '<stderr>'
        settings = hvd_settings.Settings(
            verbose=0,
            extra_mpi_args='>mpi-extra args go here<',
            num_hosts=2,
            num_proc=4,
            hosts='>host names go here<',
            output_filename='>output filename goes here<',
            run_func_mode=True)

        def fake_impl_flags(tcp):
            return ["--mock-mpi-impl-flags"], []

        with mock.patch("horovod.run.js_run._get_mpi_implementation_flags",
                        side_effect=fake_impl_flags), \
                mock.patch("horovod.run.js_run.safe_shell_exec.execute",
                           return_value=0) as execute:
            js_run(settings, None, job_env, command,
                   stdout=out_sink, stderr=err_sink)

            impl_flags, _ = \
                horovod.run.js_run._get_mpi_implementation_flags(False)
            self.assertIsNotNone(impl_flags)

            template = ("jsrun "
                        "--erf_input /tmp/rankfile "
                        "--stdio_stderr >output filename goes here< "
                        "--stdio_stdout >output filename goes here< "
                        "--smpiargs '{mpi_args} >mpi-extra args go here<' "
                        "cmd arg1 arg2")
            execute.assert_called_once_with(
                template.format(mpi_args=' '.join(impl_flags)),
                env={'env1': 'val1', 'env2': 'val2'},
                stdout=out_sink,
                stderr=err_sink)

    """
    Tests generate_jsrun_rankfile.
    """

    @mock.patch('horovod.run.util.lsf.LSFUtils.get_num_gpus',
                MagicMock(return_value=4))
    @mock.patch('horovod.run.util.lsf.LSFUtils.get_num_cores',
                MagicMock(return_value=4))
    @mock.patch('horovod.run.util.lsf.LSFUtils.get_num_threads',
                MagicMock(return_value=4))
    def test_generate_jsrun_rankfile(self):
        settings = hvd_settings.Settings(
            num_proc=5,
            hosts='host1:4,host2:4,host3:4',
        )

        with temppath() as rankfile_path:
            rankfile_path = generate_jsrun_rankfile(settings, rankfile_path)

            with open(rankfile_path, 'r') as file:
                gen_rankfile = file.read()

            expected_rankfile = ("""overlapping_rs: allow
cpu_index_using: logical

rank: 0: { hostname: host1; cpu: {0-3} ; gpu: * ; mem: * }
rank: 1: { hostname: host1; cpu: {4-7} ; gpu: * ; mem: * }
rank: 2: { hostname: host1; cpu: {8-11} ; gpu: * ; mem: * }
rank: 3: { hostname: host1; cpu: {12-15} ; gpu: * ; mem: * }

rank: 4: { hostname: host2; cpu: {0-3} ; gpu: * ; mem: * }
""")

            self.assertMultiLineEqual(gen_rankfile, expected_rankfile)

    """
    Tests horovod.run.runner._run with jsrun
    """

    @mock.patch('horovod.run.util.lsf.LSFUtils.using_lsf',
                MagicMock(return_value=True))
    @mock.patch('horovod.run.util.lsf.LSFUtils.get_compute_hosts',
                MagicMock(return_value=['host1', 'host2']))
    @mock.patch('horovod.run.util.lsf.LSFUtils.get_num_gpus',
                MagicMock(return_value=2))
    @mock.patch('horovod.run.util.network.filter_local_addresses',
                MagicMock(return_value=['host1', 'host2']))
    @mock.patch('horovod.run.runner._check_all_hosts_ssh_successful',
                MagicMock())
    @mock.patch('horovod.run.runner.run_controller')
    def test_run_with_jsrun(self, mocked_run_controller):
        """With LSF mocked in, _run hands off to run_controller exactly once."""
        _run(HorovodArgs())
        mocked_run_controller.assert_called_once()
Exemple #8
0
 def test_settings_dump_drops_key(self):
     """A base64 codec round trip of Settings keeps verbose but loses the key."""
     original = hvd_settings.Settings(verbose=2, key="a secret key")
     serialized = codec.dumps_base64(original)
     restored = codec.loads_base64(serialized)
     self.assertEqual(original.verbose, restored.verbose)
     self.assertIsNotNone(original.key)
     self.assertIsNone(restored.key)
Exemple #9
0
def _run(args):
    """Validate the cluster description in ``args`` and launch the job.

    Steps, in order:
      1. Run ``check_build`` when ``args.check_build`` is set.
      2. Resolve the host list from ``args.hosts``, else ``args.hostfile``,
         else ``localhost:<np>``.
      3. Build the run ``Settings``.
      4. Check SSH access to the remote hosts.
      5. Find the network interfaces common to all hosts, or the loopback
         interfaces when every host is local.
      6. Launch via ``_launch_job``.

    When ``args.run_func`` is set, the pickled function is published through
    a ``KVStoreServer`` started on the driver. The job runs
    ``horovod.run.run_task`` against that server, and one result per rank is
    read back. The server is always shut down before returning.

    :param args: parsed horovodrun arguments. ``args.hosts`` is filled in
        when it was empty.
    :return: list of ``args.np`` unpickled per-rank results when
        ``args.run_func`` is set, otherwise ``None``.
    :raises ValueError: if a host entry is not of the form ``name:slots``,
        or if all hosts are local and no interface has address 127.0.0.1.
    """
    if args.check_build:
        check_build(args.verbose)

    # if hosts are not specified, either parse from hostfile, or default as
    # localhost
    if not args.hosts:
        if args.hostfile:
            args.hosts = parse_host_files(args.hostfile)
        else:
            # Set hosts to localhost if not specified
            args.hosts = 'localhost:{np}'.format(np=args.np)

    host_list = args.hosts.split(',')
    all_host_names = []
    pattern = re.compile(r'^[\w.-]+:\d+$')
    for host in host_list:
        if not pattern.match(host.strip()):
            raise ValueError('Invalid host input, please make sure it has '
                             'format as : worker-0:2,worker-1:2.')
        all_host_names.append(host.strip().split(':')[0])

    # horovodrun has to finish all the checks before this timeout runs out.
    if args.start_timeout:
        start_timeout = args.start_timeout
    else:
        # Lookup default timeout from the environment variable.
        start_timeout = int(os.getenv('HOROVOD_START_TIMEOUT', '30'))

    tmout = timeout.Timeout(start_timeout,
                            message='Timed out waiting for {activity}. Please '
                            'check connectivity between servers. You '
                            'may need to increase the --start-timeout '
                            'parameter if you have too many servers.')
    settings = hvd_settings.Settings(verbose=2 if args.verbose else 0,
                                     ssh_port=args.ssh_port,
                                     ssh_ports=args.ssh_ports,
                                     extra_mpi_args=args.mpi_args,
                                     key=secret.make_secret_key(),
                                     timeout=tmout,
                                     num_hosts=len(all_host_names),
                                     num_proc=args.np,
                                     hosts=args.hosts,
                                     output_filename=args.output_filename,
                                     run_func_mode=args.run_func is not None,
                                     nic=args.nic)

    # This cache stores the results of checks performed by horovodrun
    # during the initialization step. It can be disabled by setting
    # --disable-cache flag.
    fn_cache = None
    if not args.disable_cache:
        params = ''
        if args.np:
            params += str(args.np) + ' '
        if args.hosts:
            params += str(args.hosts) + ' '
        if args.ssh_port:
            params += str(args.ssh_port)
        elif args.ssh_ports:
            params += str(args.ssh_ports)
        parameters_hash = hashlib.md5(params.encode('utf-8')).hexdigest()
        fn_cache = cache.Cache(CACHE_FOLDER, CACHE_STALENESS_THRESHOLD_MINUTES,
                               parameters_hash)

    if settings.verbose >= 2:
        print('Filtering local host names.')
    remote_host_names = network.filter_local_addresses(all_host_names)
    if settings.verbose >= 2:
        print('Remote host found: ' + ' '.join(remote_host_names))

    if len(remote_host_names) > 0:
        if settings.verbose >= 2:
            print('Checking ssh on all remote hosts.')
        # Check if we can ssh into all remote hosts successfully.
        if args.ssh_ports:
            # args.ssh_ports is positional: the i-th port belongs to the i-th
            # host in args.hosts. Keep only the ports of remote hosts. The
            # lookup set is built once rather than once per host.
            remote_host_set = set(remote_host_names)
            ssh_ports = ",".join(
                port for host, port in zip(all_host_names,
                                           args.ssh_ports.split(","))
                if host in remote_host_set)
        else:
            ssh_ports = None
        _check_all_hosts_ssh_successful(remote_host_names,
                                        ssh_port=args.ssh_port,
                                        ssh_ports=ssh_ports,
                                        fn_cache=fn_cache)
        if settings.verbose >= 2:
            print('SSH was successful into all the remote hosts.')

    if len(remote_host_names) > 0:
        if settings.verbose >= 2:
            print('Testing interfaces on all the hosts.')

        local_host_names = set(all_host_names) - set(remote_host_names)
        # Find the set of common, routed interfaces on all the hosts (remote
        # and local) and specify it in the args to be used by NCCL. It is
        # expected that the following function will find at least one interface
        # otherwise, it will raise an exception.
        common_intfs = _driver_fn(all_host_names,
                                  local_host_names,
                                  settings,
                                  fn_cache=fn_cache)

        if settings.verbose >= 2:
            print('Interfaces on all the hosts were successfully checked.')
            print('Common interface found: ' + ' '.join(common_intfs))

    else:
        if settings.verbose >= 2:
            print('All hosts are local, finding the interfaces '
                  'with address 127.0.0.1')
        # If all the given hosts are local, find the interfaces with address
        # 127.0.0.1, restricted to settings.nic when one was requested.
        common_intfs = set()
        for iface, addrs in net_if_addrs().items():
            if settings.nic and iface != settings.nic:
                continue
            for addr in addrs:
                if addr.family == AF_INET and addr.address == '127.0.0.1':
                    common_intfs.add(iface)
                    break

        if len(common_intfs) == 0:
            raise ValueError('No interface is found for address 127.0.0.1.')

        if settings.verbose >= 2:
            print('Local interface found ' + ' '.join(common_intfs))

    # get the driver IPv4 address
    driver_ip = _get_driver_ip(common_intfs)

    if args.run_func:
        run_func_server = KVStoreServer(verbose=settings.verbose)
        run_func_server_port = run_func_server.start_server()
        # Everything after start_server() is guarded by try/finally, so the
        # server is shut down even if pickling or publishing the function
        # fails. Before, those calls ran outside the try and could leak a
        # running server.
        try:
            pickled_exec_func = cloudpickle.dumps(args.run_func)
            put_data_into_kvstore(driver_ip, run_func_server_port, 'runfunc',
                                  'func', pickled_exec_func)

            command = [
                sys.executable, '-m', 'horovod.run.run_task',
                str(driver_ip),
                str(run_func_server_port)
            ]

            _launch_job(args, remote_host_names, settings, common_intfs,
                        command)
            results = [None] * args.np
            # TODO: make it parallel to improve performance
            for i in range(args.np):
                pickled_result = read_data_from_kvstore(
                    driver_ip, run_func_server_port, 'runfunc_result', str(i))
                # NOTE(security): cloudpickle.loads can execute arbitrary
                # code. This trusts whatever was stored under
                # 'runfunc_result' in the driver's KV store.
                results[i] = cloudpickle.loads(pickled_result)
            return results
        finally:
            run_func_server.shutdown_server()
    else:
        command = args.command
        _launch_job(args, remote_host_names, settings, common_intfs, command)
        return None
Exemple #10
0
def _run_static(args):
    """Launch a job on the fixed host set given in ``args.hosts``.

    Builds the run ``Settings`` (including the optional ``--nics`` set),
    checks SSH access to remote hosts, and asks
    ``driver_service.get_common_interfaces`` for the interfaces to use. It
    then launches via ``_launch_job``.

    When ``args.run_func`` is set, ``args.run_func`` is stored as-is in a
    ``KVStoreServer`` started on the driver (presumably already serialized by
    the caller; TODO confirm). The job runs ``horovod.run.run_task`` against
    that server, and one value per rank is read back. The server is always
    shut down before returning.

    :param args: parsed horovodrun arguments.
    :return: list of the ``args.np`` per-rank values read from the KV store
        when ``args.run_func`` is set, otherwise ``None``.
    """
    all_host_names, _ = parse_hosts_and_slots(args.hosts)

    # --nics is a comma-separated list; None means "no restriction".
    nics_set = set(args.nics.split(',')) if args.nics else None

    # horovodrun has to finish all the checks before this timeout runs out.
    if args.start_timeout:
        start_timeout = args.start_timeout
    else:
        # Lookup default timeout from the environment variable.
        start_timeout = int(os.getenv('HOROVOD_START_TIMEOUT', '30'))

    tmout = timeout.Timeout(start_timeout,
                            message='Timed out waiting for {activity}. Please '
                            'check connectivity between servers. You '
                            'may need to increase the --start-timeout '
                            'parameter if you have too many servers.')
    settings = hvd_settings.Settings(verbose=2 if args.verbose else 0,
                                     ssh_port=args.ssh_port,
                                     extra_mpi_args=args.mpi_args,
                                     tcp_flag=args.tcp_flag,
                                     binding_args=args.binding_args,
                                     key=secret.make_secret_key(),
                                     start_timeout=tmout,
                                     num_proc=args.np,
                                     hosts=args.hosts,
                                     num_hosts=len(all_host_names),
                                     output_filename=args.output_filename,
                                     run_func_mode=args.run_func is not None,
                                     nics=nics_set)

    # This cache stores the results of checks performed by horovod
    # during the initialization step. It can be disabled by setting
    # --disable-cache flag.
    fn_cache = None
    if not args.disable_cache:
        params = ''
        if args.np:
            params += str(args.np) + ' '
        if args.hosts:
            params += str(args.hosts) + ' '
        if args.ssh_port:
            params += str(args.ssh_port)
        parameters_hash = hashlib.md5(params.encode('utf-8')).hexdigest()
        fn_cache = cache.Cache(CACHE_FOLDER, CACHE_STALENESS_THRESHOLD_MINUTES,
                               parameters_hash)

    if settings.verbose >= 2:
        print('Filtering local host names.')
    remote_host_names = network.filter_local_addresses(all_host_names)
    if settings.verbose >= 2:
        print('Remote host found: ' + ' '.join(remote_host_names))

    if len(remote_host_names) > 0:
        if settings.verbose >= 2:
            print('Checking ssh on all remote hosts.')
        # Check if we can ssh into all remote hosts successfully.
        _check_all_hosts_ssh_successful(remote_host_names,
                                        args.ssh_port,
                                        fn_cache=fn_cache)
        if settings.verbose >= 2:
            print('SSH was successful into all the remote hosts.')

    nics = driver_service.get_common_interfaces(settings, all_host_names,
                                                remote_host_names, fn_cache)

    if args.run_func:
        # get the driver IPv4 address
        driver_ip = network.get_driver_ip(nics)
        run_func_server = KVStoreServer(verbose=settings.verbose)
        run_func_server_port = run_func_server.start_server()
        # Everything after start_server() is guarded by try/finally, so the
        # server is shut down even if publishing the function fails. Before,
        # put_data_into_kvstore ran outside the try and could leak a running
        # server.
        try:
            put_data_into_kvstore(driver_ip, run_func_server_port, 'runfunc',
                                  'func', args.run_func)

            command = [
                sys.executable, '-m', 'horovod.run.run_task',
                str(driver_ip),
                str(run_func_server_port)
            ]

            _launch_job(args, settings, nics, command)
            # TODO: make it parallel to improve performance
            return [read_data_from_kvstore(driver_ip,
                                           run_func_server_port,
                                           'runfunc_result', str(i))
                    for i in range(args.np)]
        finally:
            run_func_server.shutdown_server()
    else:
        command = args.command
        _launch_job(args, settings, nics, command)
        return None
Exemple #11
0
def run():
    """Entry point of the horovodrun command line tool.

    Steps, in order:
      1. Parse the command line.
      2. Normalize the host list.
      3. Prepare ``Settings`` and the optional check cache.
      4. Verify SSH access and network interfaces.
      5. Start the job with Gloo or MPI, depending on ``--gloo`` / ``--mpi``
         and on which controllers Horovod was built with. MPI is preferred
         when neither is requested.

    :raises ValueError: on a malformed host entry, when no 127.0.0.1
        interface exists for an all-local run, or when the requested (or
        any) controller was not built.
    """
    args = parse_args()

    if args.check_build:
        check_build(args.verbose)

    # Without explicit hosts, fall back to the hostfile, then to localhost.
    if not args.hosts:
        args.hosts = (parse_host_files(args.hostfile) if args.hostfile
                      else 'localhost:{np}'.format(np=args.np))

    # Every entry must look like `name:slots`; only the names are kept.
    host_entry = re.compile(r'^[\w.-]+:\d+$')
    all_host_names = []
    for raw_entry in args.hosts.split(','):
        entry = raw_entry.strip()
        if host_entry.match(entry) is None:
            raise ValueError('Invalid host input, please make sure it has '
                             'format as : worker-0:2,worker-1:2.')
        all_host_names.append(entry.partition(':')[0])

    # All checks must finish within this budget. An explicit --start-timeout
    # wins over the HOROVOD_START_TIMEOUT environment variable.
    start_timeout = (args.start_timeout or
                     int(os.getenv('HOROVOD_START_TIMEOUT', '30')))
    tmout = timeout.Timeout(start_timeout,
                            message='Timed out waiting for {activity}. Please '
                            'check connectivity between servers. You '
                            'may need to increase the --start-timeout '
                            'parameter if you have too many servers.')
    settings = hvd_settings.Settings(verbose=2 if args.verbose else 0,
                                     ssh_port=args.ssh_port,
                                     key=secret.make_secret_key(),
                                     timeout=tmout,
                                     num_hosts=len(all_host_names),
                                     num_proc=args.np,
                                     hosts=args.hosts,
                                     command=args.command)

    # Cache of initialization-check results, keyed by a hash of np, hosts
    # and ssh port. Skipped with --disable-cache.
    fn_cache = None
    if not args.disable_cache:
        key_parts = []
        if args.np:
            key_parts.append(str(args.np) + ' ')
        if args.hosts:
            key_parts.append(str(args.hosts) + ' ')
        if args.ssh_port:
            key_parts.append(str(args.ssh_port))
        parameters_hash = hashlib.md5(
            ''.join(key_parts).encode('utf-8')).hexdigest()
        fn_cache = cache.Cache(CACHE_FOLDER, CACHE_STALENESS_THRESHOLD_MINUTES,
                               parameters_hash)

    if settings.verbose >= 2:
        print('Filtering local host names.')
    remote_host_names = network.filter_local_addresses(all_host_names)
    if settings.verbose >= 2:
        print('Remote host found: ' + ' '.join(remote_host_names))

    if len(remote_host_names) > 0:
        if settings.verbose >= 2:
            print('Checking ssh on all remote hosts.')
        _check_all_hosts_ssh_successful(remote_host_names,
                                        args.ssh_port,
                                        fn_cache=fn_cache)
        if settings.verbose >= 2:
            print('SSH was successful into all the remote hosts.')
            print('Testing interfaces on all the hosts.')

        # Find the routed interfaces shared by all hosts (local and remote).
        # They are later handed to NCCL. _driver_fn is expected to raise if
        # none is found.
        local_host_names = set(all_host_names) - set(remote_host_names)
        common_intfs = _driver_fn(all_host_names,
                                  local_host_names,
                                  settings,
                                  fn_cache=fn_cache)

        if settings.verbose >= 2:
            print('Interfaces on all the hosts were successfully checked.')
            print('Common interface found: ' + ' '.join(common_intfs))
    else:
        if settings.verbose >= 2:
            print('All hosts are local, finding the interfaces '
                  'with address 127.0.0.1')
        # Every host is local: use the interfaces carrying IPv4 127.0.0.1.
        common_intfs = {
            iface
            for iface, addrs in net_if_addrs().items()
            if any(addr.family == AF_INET and addr.address == '127.0.0.1'
                   for addr in addrs)
        }
        if not common_intfs:
            raise ValueError('No interface is found for address 127.0.0.1.')

        if settings.verbose >= 2:
            print('Local interface found ' + ' '.join(common_intfs))

    env = os.environ.copy()
    config_parser.set_env_from_args(env, args)

    if args.use_gloo:
        if not gloo_built(verbose=(settings.verbose >= 2)):
            raise ValueError(
                'Gloo support has not been built.  If this is not expected, ensure CMake is installed '
                'and reinstall Horovod with HOROVOD_WITH_GLOO=1 to debug the build error.'
            )
        gloo_run(settings, remote_host_names, common_intfs, env)
    elif args.use_mpi:
        if not mpi_built(verbose=(settings.verbose >= 2)):
            raise ValueError(
                'MPI support has not been built.  If this is not expected, ensure MPI is installed '
                'and reinstall Horovod with HOROVOD_WITH_MPI=1 to debug the build error.'
            )
        mpi_run(settings, common_intfs, env)
    elif mpi_built(verbose=(settings.verbose >= 2)):
        # No explicit choice: prefer MPI when it is available...
        mpi_run(settings, common_intfs, env)
    elif gloo_built(verbose=(settings.verbose >= 2)):
        # ...otherwise fall back to Gloo.
        gloo_run(settings, remote_host_names, common_intfs, env)
    else:
        raise ValueError(
            'Neither MPI nor Gloo support has been built. Try reinstalling Horovod ensuring that '
            'either MPI is installed (MPI) or CMake is installed (Gloo).')
Exemple #12
0
def run():
    """Entry point of the ``horovodrun`` launcher (Open MPI only).

    Parses the command line, checks SSH connectivity and the common network
    interfaces of any remote hosts, then builds an ``mpirun`` command line and
    replaces the current process with ``/bin/sh -c <mpirun command>`` via
    ``os.execve``. On success this function therefore never returns.

    If ``--version`` is given, prints the Horovod version and exits with
    status 0.

    Raises:
        Exception: if Open MPI is not installed.
    """
    args = parse_args()

    if args.version:
        print(horovod.__version__)
        # Local import: the ``exit`` builtin is only installed by ``site``;
        # ``sys.exit`` is always available.
        import sys
        sys.exit(0)

    # Bare host names (slot counts stripped). An empty list means that only
    # localhost is used.
    if args.host:
        all_host_names = [entry.split(':')[0] for entry in args.host.split(',')]
    elif args.hostfile:
        # Close the hostfile deterministically; skip blank and '#' comment
        # lines, which would otherwise raise IndexError or yield '#' as a host.
        with open(args.hostfile) as hostfile:
            all_host_names = [
                line.split()[0] for line in hostfile
                if line.strip() and not line.lstrip().startswith('#')
            ]
    else:
        all_host_names = []

    # horovodrun has to finish all the checks before this timeout runs out.
    if args.start_timeout:
        start_timeout = args.start_timeout
    else:
        # Lookup default timeout from the environment variable.
        start_timeout = int(os.getenv('HOROVOD_START_TIMEOUT', '30'))

    tmout = timeout.Timeout(start_timeout,
                            message='Timed out waiting for {activity}. Please '
                            'check connectivity between servers. You '
                            'may need to increase the --start-timeout '
                            'parameter if you have too many servers.')
    settings = hvd_settings.Settings(verbose=2 if args.verbose else 0,
                                     ssh_port=args.ssh_port,
                                     key=secret.make_secret_key(),
                                     timeout=tmout,
                                     num_hosts=len(all_host_names),
                                     num_proc=args.np)

    # This cache stores the results of checks performed by horovodrun
    # during the initialization step. It can be disabled by setting
    # --disable-cache flag.
    fn_cache = None
    if not args.disable_cache:
        params = ''
        if args.np:
            params += str(args.np) + ' '
        if args.host:
            params += str(args.host) + ' '
        if args.hostfile:
            # Key the cache on the hosts read from the file (as is done for
            # --host) so different hostfiles do not share one cache file.
            params += ','.join(all_host_names) + ' '
        if args.ssh_port:
            params += str(args.ssh_port)
        parameters_hash = hashlib.md5(params.encode('utf-8')).hexdigest()
        fn_cache = cache.Cache(CACHE_FOLDER, CACHE_STALENESS_THRESHOLD_MINUTES,
                               parameters_hash)

    remote_host_names = []
    if args.host or args.hostfile:
        if settings.verbose >= 2:
            print("Filtering local host names.")
        remote_host_names = network.filter_local_addresses(all_host_names)

        if len(remote_host_names) > 0:
            if settings.verbose >= 2:
                print("Checking ssh on all remote hosts.")
            # Check if we can ssh into all remote hosts successfully.
            _check_all_hosts_ssh_successful(remote_host_names,
                                            args.ssh_port,
                                            fn_cache=fn_cache)
            if settings.verbose >= 2:
                print("SSH was successful into all the remote hosts.")

        if args.host:
            hosts_arg = "-H {hosts}".format(hosts=args.host)
        else:
            hosts_arg = "-hostfile {hostfile}".format(hostfile=args.hostfile)
    else:
        # If neither --host nor --hostfile is specified, localhost is used by
        # default, so there is no need to pass any host argument to mpirun.
        hosts_arg = ""

    if len(remote_host_names) > 0:
        if settings.verbose >= 2:
            print("Testing interfaces on all the hosts.")

        local_host_names = set(all_host_names) - set(remote_host_names)
        # Find the set of common, routed interfaces on all the hosts (remote
        # and local) and specify it in the args to be used by NCCL. It is
        # expected that the following function will find at least one interface
        # otherwise, it will raise an exception.
        common_intfs = _driver_fn(all_host_names,
                                  local_host_names,
                                  settings,
                                  fn_cache=fn_cache)

        tcp_intf_arg = "-mca btl_tcp_if_include {common_intfs}".format(
            common_intfs=','.join(common_intfs))
        nccl_socket_intf_arg = "-x NCCL_SOCKET_IFNAME={common_intfs}".format(
            common_intfs=','.join(common_intfs))

        if settings.verbose >= 2:
            print("Interfaces on all the hosts were successfully checked.")
    else:
        # If all the given hosts are local, no need to specify the interfaces
        # because MPI does not use network for local execution.
        tcp_intf_arg = ""
        nccl_socket_intf_arg = ""

    # Pass all the env variables to the mpirun command.
    env = os.environ.copy()

    # Pass secret key through the environment variables.
    env[secret.HOROVOD_SECRET_KEY] = codec.dumps_base64(settings.key)

    if not _is_open_mpi_installed():
        raise Exception(
            'horovodrun convenience script currently only supports '
            'Open MPI.\n\n'
            'Choose one of:\n'
            '1. Install Open MPI 4.0.0+ and re-install Horovod '
            '(use --no-cache-dir pip option).\n'
            '2. Run distributed '
            'training script using the standard way provided by your'
            ' MPI distribution (usually mpirun, srun, or jsrun).')

    if args.ssh_port:
        ssh_port_arg = "-mca plm_rsh_args \"-p {ssh_port}\"".format(
            ssh_port=args.ssh_port)
    else:
        ssh_port_arg = ""

    mpirun_command = (
        'mpirun --allow-run-as-root --tag-output '
        '-np {num_proc} {hosts_arg} '
        '-bind-to none -map-by slot '
        '-mca pml ob1 -mca btl ^openib '
        '{ssh_port_arg} '
        '{tcp_intf_arg} '
        '-x NCCL_DEBUG=INFO '
        '{nccl_socket_intf_arg} '
        '{env} {command}'  # expect a lot of environment variables
        .format(num_proc=settings.num_proc,
                hosts_arg=hosts_arg,
                tcp_intf_arg=tcp_intf_arg,
                nccl_socket_intf_arg=nccl_socket_intf_arg,
                ssh_port_arg=ssh_port_arg,
                env=' '.join('-x %s' % key for key in env.keys()
                             if env_util.is_exportable(key)),
                command=' '.join(quote(par) for par in args.command)))

    if settings.verbose >= 2:
        print(mpirun_command)
    # Execute the mpirun command; this replaces the current process.
    os.execve('/bin/sh', ['/bin/sh', '-c', mpirun_command], env)
Exemple #13
0
def _run(args):
    """Validate the launch configuration and start a Horovod job.

    Resolves the host list from ``--hosts``, ``--hostfile``, the LSF job
    configuration, or ``localhost:<np>``. It then builds the launcher
    settings, checks SSH connectivity to remote hosts, determines the common
    network interfaces and launches either ``args.command`` or
    ``args.run_func``.

    Args:
        args: Parsed ``horovodrun`` arguments. ``args.np`` and ``args.hosts``
            may be filled in (mutated) from LSF or the hostfile.

    Returns:
        When ``args.run_func`` is set, a list of ``args.np`` unpickled
        results, read from the KV store under keys ``'0'`` .. ``str(np - 1)``
        (presumably one per rank). Otherwise ``None``.

    Raises:
        ValueError: if a host entry does not match ``hostname:slots``.
    """
    if args.check_build:
        check_build(args.verbose)

    # If LSF is used, use default values from job config
    if lsf.LSFUtils.using_lsf():
        if not args.np:
            args.np = lsf.LSFUtils.get_num_processes()
        if not args.hosts and not args.hostfile:
            args.hosts = ','.join(
                '{host}:{np}'.format(host=host, np=lsf.LSFUtils.get_num_gpus())
                for host in lsf.LSFUtils.get_compute_hosts())

    # if hosts are not specified, either parse from hostfile, or default as
    # localhost
    if not args.hosts:
        if args.hostfile:
            args.hosts = parse_host_files(args.hostfile)
        else:
            # Set hosts to localhost if not specified
            args.hosts = 'localhost:{np}'.format(np=args.np)

    # Validate every "hostname:slots" entry and keep only the host names.
    pattern = re.compile(r'^[\w.-]+:\d+$')
    all_host_names = []
    for raw_host in args.hosts.split(','):
        host = raw_host.strip()
        if not pattern.match(host):
            raise ValueError('Invalid host input, please make sure it has '
                             'format as : worker-0:2,worker-1:2.')
        all_host_names.append(host.split(':')[0])

    # horovodrun has to finish all the checks before this timeout runs out.
    if args.start_timeout:
        start_timeout = args.start_timeout
    else:
        # Lookup default timeout from the environment variable.
        start_timeout = int(os.getenv('HOROVOD_START_TIMEOUT', '30'))

    tmout = timeout.Timeout(start_timeout,
                            message='Timed out waiting for {activity}. Please '
                            'check connectivity between servers. You '
                            'may need to increase the --start-timeout '
                            'parameter if you have too many servers.')
    settings = hvd_settings.Settings(verbose=2 if args.verbose else 0,
                                     ssh_port=args.ssh_port,
                                     extra_mpi_args=args.mpi_args,
                                     tcp_flag=args.tcp_flag,
                                     binding_args=args.binding_args,
                                     key=secret.make_secret_key(),
                                     timeout=tmout,
                                     num_hosts=len(all_host_names),
                                     num_proc=args.np,
                                     hosts=args.hosts,
                                     output_filename=args.output_filename,
                                     run_func_mode=args.run_func is not None,
                                     nic=args.nic)

    # This cache stores the results of checks performed by horovodrun
    # during the initialization step. It can be disabled by setting
    # --disable-cache flag.
    fn_cache = None
    if not args.disable_cache:
        params = ''
        if args.np:
            params += str(args.np) + ' '
        if args.hosts:
            params += str(args.hosts) + ' '
        if args.ssh_port:
            params += str(args.ssh_port)
        parameters_hash = hashlib.md5(params.encode('utf-8')).hexdigest()
        fn_cache = cache.Cache(CACHE_FOLDER, CACHE_STALENESS_THRESHOLD_MINUTES,
                               parameters_hash)

    if settings.verbose >= 2:
        print('Filtering local host names.')
    remote_host_names = network.filter_local_addresses(all_host_names)
    if settings.verbose >= 2:
        print('Remote host found: ' + ' '.join(remote_host_names))

    if len(remote_host_names) > 0:
        if settings.verbose >= 2:
            print('Checking ssh on all remote hosts.')
        # Check if we can ssh into all remote hosts successfully.
        _check_all_hosts_ssh_successful(remote_host_names,
                                        args.ssh_port,
                                        fn_cache=fn_cache)
        if settings.verbose >= 2:
            print('SSH was successful into all the remote hosts.')

    common_intfs = driver_service._get_common_interfaces(
        settings, all_host_names, remote_host_names, fn_cache)

    if not args.run_func:
        command = args.command
        _launch_job(args, remote_host_names, settings, common_intfs, command)
        return None

    # get the driver IPv4 address
    driver_ip = network._get_driver_ip(common_intfs)
    # Serialize the function before starting the server, so that a pickling
    # failure cannot leave a running KV-store server behind.
    pickled_exec_func = cloudpickle.dumps(args.run_func)
    run_func_server = KVStoreServer(verbose=settings.verbose)
    run_func_server_port = run_func_server.start_server()
    # Everything after start_server() must be inside the try, so that the
    # server is always shut down (including when publishing the function fails).
    try:
        put_data_into_kvstore(driver_ip, run_func_server_port, 'runfunc',
                              'func', pickled_exec_func)

        command = [
            sys.executable, '-m', 'horovod.run.run_task',
            str(driver_ip),
            str(run_func_server_port)
        ]

        _launch_job(args, remote_host_names, settings, common_intfs, command)
        results = [None] * args.np
        # TODO: make it parallel to improve performance
        for i in range(args.np):
            pickled_result = read_data_from_kvstore(
                driver_ip, run_func_server_port, 'runfunc_result', str(i))
            results[i] = cloudpickle.loads(pickled_result)
        return results
    finally:
        run_func_server.shutdown_server()
Exemple #14
0
class RunTests(unittest.TestCase):
    """
    Tests for horovod.run.
    """
    def __init__(self, *args, **kwargs):
        super(RunTests, self).__init__(*args, **kwargs)
        warnings.simplefilter('module')

    def test_params_args(self):
        """Tuning-parameter flags are exported as Horovod env variables."""
        with override_args('horovodrun', '-np', '2', '--fusion-threshold-mb',
                           '10', '--cycle-time-ms', '20', '--cache-capacity',
                           '512', '--hierarchical-allreduce',
                           '--hierarchical-allgather'):
            args = parse_args()
            env = {}
            config_parser.set_env_from_args(env, args)

            self.assertEqual(env.get(config_parser.HOROVOD_FUSION_THRESHOLD),
                             str(10 * 1024 * 1024))
            self.assertEqual(env.get(config_parser.HOROVOD_CYCLE_TIME), '20.0')
            self.assertEqual(env.get(config_parser.HOROVOD_CACHE_CAPACITY),
                             '512')
            self.assertEqual(
                env.get(config_parser.HOROVOD_HIERARCHICAL_ALLREDUCE), '1')
            self.assertEqual(
                env.get(config_parser.HOROVOD_HIERARCHICAL_ALLGATHER), '1')

    def test_autotune_args(self):
        """Autotune flags are exported as Horovod env variables."""
        with override_args('horovodrun', '-np', '2', '--autotune',
                           '--autotune-log-file', '/tmp/autotune.txt',
                           '--autotune-warmup-samples', '1',
                           '--autotune-steps-per-sample', '5',
                           '--autotune-bayes-opt-max-samples', '10',
                           '--autotune-gaussian-process-noise', '0.2'):
            args = parse_args()
            env = {}
            config_parser.set_env_from_args(env, args)

            self.assertEqual(env.get(config_parser.HOROVOD_AUTOTUNE), '1')
            self.assertEqual(env.get(config_parser.HOROVOD_AUTOTUNE_LOG),
                             '/tmp/autotune.txt')
            self.assertEqual(
                env.get(config_parser.HOROVOD_AUTOTUNE_WARMUP_SAMPLES), '1')
            self.assertEqual(
                env.get(config_parser.HOROVOD_AUTOTUNE_STEPS_PER_SAMPLE), '5')
            self.assertEqual(
                env.get(config_parser.HOROVOD_AUTOTUNE_BAYES_OPT_MAX_SAMPLES),
                '10')
            self.assertEqual(
                env.get(config_parser.HOROVOD_AUTOTUNE_GAUSSIAN_PROCESS_NOISE),
                '0.2')

    def test_autotuning_with_fixed_param(self):
        """With autotune, only explicitly fixed parameters are exported."""
        with override_args('horovodrun', '-np', '2', '--autotune',
                           '--cache-capacity', '1024',
                           '--no-hierarchical-allgather'):
            args = parse_args()
            env = {}
            config_parser.set_env_from_args(env, args)

            self.assertNotIn(config_parser.HOROVOD_FUSION_THRESHOLD, env)
            self.assertNotIn(config_parser.HOROVOD_CYCLE_TIME, env)
            self.assertEqual(env.get(config_parser.HOROVOD_CACHE_CAPACITY),
                             '1024')
            self.assertNotIn(config_parser.HOROVOD_HIERARCHICAL_ALLREDUCE, env)
            self.assertEqual(
                env.get(config_parser.HOROVOD_HIERARCHICAL_ALLGATHER), '0')

    def test_timeline_args(self):
        """Timeline flags are exported as Horovod env variables."""
        with override_args('horovodrun', '-np', '2', '--timeline-filename',
                           '/tmp/timeline.json', '--timeline-mark-cycles'):
            args = parse_args()
            env = {}
            config_parser.set_env_from_args(env, args)

            self.assertEqual(env.get(config_parser.HOROVOD_TIMELINE),
                             '/tmp/timeline.json')
            self.assertEqual(
                env.get(config_parser.HOROVOD_TIMELINE_MARK_CYCLES), '1')

    def test_stall_check_args(self):
        """Stall-check flags are exported as Horovod env variables."""
        with override_args('horovodrun', '-np', '2', '--no-stall-check'):
            args = parse_args()
            env = {}
            config_parser.set_env_from_args(env, args)

            self.assertEqual(
                env.get(config_parser.HOROVOD_STALL_CHECK_DISABLE), '1')

        with override_args('horovodrun', '-np', '2',
                           '--stall-check-warning-time-seconds', '10',
                           '--stall-check-shutdown-time-seconds', '20'):
            args = parse_args()
            env = {}
            config_parser.set_env_from_args(env, args)

            self.assertNotIn(config_parser.HOROVOD_STALL_CHECK_DISABLE, env)
            self.assertEqual(
                env.get(config_parser.HOROVOD_STALL_CHECK_TIME_SECONDS), '10')
            self.assertEqual(
                env.get(config_parser.HOROVOD_STALL_SHUTDOWN_TIME_SECONDS),
                '20')

    def test_library_args(self):
        """Library option flags are exported as Horovod env variables."""
        with override_args('horovodrun', '-np', '2', '--mpi-threads-disable',
                           '--num-nccl-streams', '2', '--mlsl-bgt-affinity',
                           '1', '--gloo-timeout-seconds', '60'):
            args = parse_args()
            env = {}
            config_parser.set_env_from_args(env, args)

            self.assertEqual(
                env.get(config_parser.HOROVOD_MPI_THREADS_DISABLE), '1')
            self.assertEqual(env.get(config_parser.HOROVOD_NUM_NCCL_STREAMS),
                             '2')
            self.assertEqual(env.get(config_parser.HOROVOD_MLSL_BGT_AFFINITY),
                             '1')
            self.assertEqual(
                env.get(config_parser.HOROVOD_GLOO_TIMEOUT_SECONDS), '60')

    def test_logging_args(self):
        """Logging flags are exported as Horovod env variables."""
        with override_args('horovodrun', '-np', '2', '--log-level', 'INFO',
                           '--log-hide-timestamp'):
            args = parse_args()
            env = {}
            config_parser.set_env_from_args(env, args)

            self.assertEqual(env.get(config_parser.HOROVOD_LOG_LEVEL), 'INFO')
            self.assertEqual(env.get(config_parser.HOROVOD_LOG_HIDE_TIME), '1')

    def test_config_file(self):
        """Values from data/config.test.yaml are loaded into parsed args."""
        config_filename = os.path.join(os.path.dirname(__file__),
                                       'data/config.test.yaml')
        with override_args('horovodrun', '-np', '2', '--config-file',
                           config_filename):
            args = parse_args()

            self.assertTrue(args.use_gloo)

            # Params
            self.assertEqual(args.fusion_threshold_mb, 32)
            self.assertEqual(args.cycle_time_ms, 10)
            self.assertEqual(args.cache_capacity, 2048)
            self.assertTrue(args.hierarchical_allreduce)
            self.assertTrue(args.hierarchical_allgather)

            # Autotune
            self.assertTrue(args.autotune)
            self.assertEqual(args.autotune_log_file,
                             'horovod_autotune_log.txt')
            self.assertEqual(args.autotune_warmup_samples, 5)
            self.assertEqual(args.autotune_steps_per_sample, 20)
            self.assertEqual(args.autotune_bayes_opt_max_samples, 50)
            self.assertEqual(args.autotune_gaussian_process_noise, 0.9)

            # Timeline
            self.assertEqual(args.timeline_filename, 'horovod_timeline.json')
            self.assertTrue(args.timeline_mark_cycles)

            # Stall Check
            self.assertFalse(args.no_stall_check)
            self.assertEqual(args.stall_check_warning_time_seconds, 120)
            self.assertEqual(args.stall_check_shutdown_time_seconds, 240)

            # Library Options
            self.assertTrue(args.mpi_threads_disable)
            self.assertEqual(args.num_nccl_streams, 2)
            self.assertEqual(args.mlsl_bgt_affinity, 1)
            self.assertEqual(args.gloo_timeout_seconds, 60)

            # Logging
            self.assertEqual(args.log_level, 'INFO')
            self.assertTrue(args.log_hide_timestamp)

    def test_config_file_override_args(self):
        """Command-line flags take precedence over config-file values."""
        config_filename = os.path.join(os.path.dirname(__file__),
                                       'data/config.test.yaml')
        with override_args(
                'horovodrun',
                '-np',
                '2',
                '--fusion-threshold-mb',
                '128',
                '--config-file',
                config_filename,
                '--cycle-time-ms',
                '20',
        ):
            args = parse_args()
            self.assertEqual(args.fusion_threshold_mb, 128)
            self.assertEqual(args.cycle_time_ms, 20)

    def test_validate_config_args(self):
        """A negative fusion threshold is rejected with ValueError."""
        with override_args('horovodrun', '-np', '2', '--fusion-threshold-mb',
                           '-1'):
            with pytest.raises(ValueError):
                parse_args()

    # Minimal mpi_run settings shared by the mpi_run tests below. It is a
    # class attribute, so a test that needs different values must modify a
    # copy (as test_mpi_run_on_large_cluster does), not this object.
    minimal_settings = hvd_settings.Settings(verbose=0,
                                             num_hosts=1,
                                             num_proc=2,
                                             hosts='host',
                                             run_func_mode=True)

    def test_mpi_run_minimal(self):
        """Tests mpi_run with minimal settings."""
        if _get_mpi_implementation_flags() is None:
            self.skipTest("MPI is not available")

        cmd = ['cmd']
        settings = self.minimal_settings
        run_func = MagicMock(return_value=0)

        mpi_run(settings, None, {}, cmd, run_func=run_func)

        mpi_flags = _get_mpi_implementation_flags()
        self.assertIsNotNone(mpi_flags)
        expected_cmd = ('mpirun '
                        '--allow-run-as-root --tag-output '
                        '-np 2 -H host '
                        '-bind-to none -map-by slot '
                        '{mpi_flags}       '
                        'cmd').format(mpi_flags=' '.join(mpi_flags))
        expected_env = {}
        run_func.assert_called_once_with(command=expected_cmd,
                                         env=expected_env,
                                         stdout=None,
                                         stderr=None)

    def test_mpi_run_on_large_cluster(self):
        """Tests mpi_run on a large cluster (extra plm_rsh flags expected)."""
        if _get_mpi_implementation_flags() is None:
            self.skipTest("MPI is not available")

        cmd = ['cmd']
        settings = copy.copy(self.minimal_settings)
        settings.num_hosts = large_cluster_threshold
        run_func = MagicMock(return_value=0)

        mpi_run(settings, None, {}, cmd, run_func=run_func)

        mpi_flags = _get_mpi_implementation_flags()
        self.assertIsNotNone(mpi_flags)
        # Build a new list instead of appending in place, so the list
        # returned by the helper is never mutated by this test.
        mpi_flags = mpi_flags + ['-mca plm_rsh_no_tree_spawn true',
                                 '-mca plm_rsh_num_concurrent 2']
        expected_cmd = ('mpirun '
                        '--allow-run-as-root --tag-output '
                        '-np 2 -H host '
                        '-bind-to none -map-by slot '
                        '{mpi_flags}       '
                        'cmd').format(mpi_flags=' '.join(mpi_flags))
        expected_env = {}
        run_func.assert_called_once_with(command=expected_cmd,
                                         env=expected_env,
                                         stdout=None,
                                         stderr=None)

    def test_mpi_run_full(self):
        """Tests mpi_run with full settings."""
        if _get_mpi_implementation_flags() is None:
            self.skipTest("MPI is not available")

        cmd = ['cmd', 'arg1', 'arg2']
        common_intfs = ['eth0', 'eth1']
        env = {'env1': 'val1', 'env2': 'val2'}
        stdout = '<stdout>'
        stderr = '<stderr>'
        tmout = timeout.Timeout(5, message='Timed out waiting for something.')
        settings = hvd_settings.Settings(
            verbose=0,
            ssh_port=1022,
            extra_mpi_args='>mpi-extra args go here<',
            key=secret.make_secret_key(),
            timeout=tmout,
            num_hosts=1,
            num_proc=1,
            hosts='>host names go here<',
            output_filename='>output filename goes here<',
            run_func_mode=True)
        run_func = MagicMock(return_value=0)

        mpi_run(settings,
                common_intfs,
                env,
                cmd,
                stdout=stdout,
                stderr=stderr,
                run_func=run_func)

        mpi_flags = _get_mpi_implementation_flags()
        self.assertIsNotNone(mpi_flags)
        expected_command = (
            'mpirun '
            '--allow-run-as-root --tag-output '
            '-np 1 -H >host names go here< '
            '-bind-to none -map-by slot '
            '{mpi_flags} '
            '-mca plm_rsh_args "-p 1022" '
            '-mca btl_tcp_if_include eth0,eth1 -x NCCL_SOCKET_IFNAME=eth0,eth1 '
            '--output-filename >output filename goes here< '
            '-x env1 -x env2 '
            '>mpi-extra args go here< '
            'cmd arg1 arg2').format(mpi_flags=' '.join(mpi_flags))
        expected_env = {'env1': 'val1', 'env2': 'val2'}
        run_func.assert_called_once_with(command=expected_command,
                                         env=expected_env,
                                         stdout=stdout,
                                         stderr=stderr)

    def test_mpi_run_with_non_zero_exit(self):
        """A non-zero mpirun exit code surfaces as a RuntimeError."""
        if _get_mpi_implementation_flags() is None:
            self.skipTest("MPI is not available")

        cmd = ['cmd']
        settings = self.minimal_settings
        run_func = MagicMock(return_value=1)

        with pytest.raises(RuntimeError,
                           match="^mpirun failed with exit code 1$"):
            mpi_run(settings, None, {}, cmd, run_func=run_func)
Exemple #15
0
def run(fn,
        args=(),
        kwargs=None,
        num_proc=None,
        start_timeout=None,
        use_mpi=None,
        use_gloo=None,
        extra_mpi_args=None,
        env=None,
        stdout=None,
        stderr=None,
        verbose=1,
        nics=None):
    """
    Runs Horovod in Spark.  Runs `num_proc` processes executing `fn` using the same amount of Spark tasks.

    Args:
        fn: Function to run.
        args: Arguments to pass to `fn`.
        kwargs: Keyword arguments to pass to `fn`. Defaults to no keyword arguments.
        num_proc: Number of Horovod processes.  Defaults to `spark.default.parallelism`.
        start_timeout: Timeout for Spark tasks to spawn, register and start running the code, in seconds.
                       If not set, falls back to `HOROVOD_SPARK_START_TIMEOUT` environment variable value.
                       If it is not set as well, defaults to 600 seconds.
        use_mpi: Request the MPI controller. Passed through to the job launcher (`_launch_job`),
                 which decides how it is combined with `use_gloo`.
        use_gloo: Request the Gloo controller. Passed through to the job launcher and to the
                  Spark task thread.
        extra_mpi_args: Extra arguments for mpi_run. Defaults to no extra args.
        env: Environment dictionary to use in Horovod run.  Defaults to `os.environ`.
        stdout: Horovod stdout is redirected to this stream. Defaults to sys.stdout.
        stderr: Horovod stderr is redirected to this stream. Defaults to sys.stderr.
        verbose: Debug output verbosity (0-2). Defaults to 1.
        nics: List of NICs for tcp network communication.

    Returns:
        List of results returned by running `fn` on each rank, ordered by rank.

    Raises:
        Exception: If there is no active SparkContext.
        RuntimeError: If no registered Spark task has index 0.
    """

    if start_timeout is None:
        # Lookup default timeout from the environment variable.
        start_timeout = int(os.getenv('HOROVOD_SPARK_START_TIMEOUT', '600'))

    # A fresh dict per call, instead of a mutable default shared by all calls.
    if kwargs is None:
        kwargs = {}

    # nics needs to be a set
    if nics and not isinstance(nics, set):
        nics = set(nics)

    tmout = timeout.Timeout(
        start_timeout,
        message='Timed out waiting for {activity}. Please check that you have '
        'enough resources to run all Horovod processes. Each Horovod '
        'process runs in a Spark task. You may need to increase the '
        'start_timeout parameter to a larger value if your Spark resources '
        'are allocated on-demand.')
    settings = hvd_settings.Settings(verbose=verbose,
                                     extra_mpi_args=extra_mpi_args,
                                     key=secret.make_secret_key(),
                                     timeout=tmout,
                                     nics=nics,
                                     run_func_mode=True)

    spark_context = pyspark.SparkContext._active_spark_context
    if spark_context is None:
        raise Exception('Could not find an active SparkContext, are you '
                        'running in a PySpark session?')

    if num_proc is None:
        num_proc = spark_context.defaultParallelism
        if settings.verbose >= 1:
            print(
                'Running %d processes (inferred from spark.default.parallelism)...'
                % num_proc)
    else:
        if settings.verbose >= 1:
            print('Running %d processes...' % num_proc)
    settings.num_proc = num_proc

    result_queue = queue.Queue(1)

    # start Spark driver service and launch settings.num_proc Spark tasks
    spark_job_group = 'horovod.spark.run.%d' % job_id.next_job_id()
    driver = driver_service.SparkDriverService(settings.num_proc, fn, args,
                                               kwargs, settings.key,
                                               settings.nics)
    spark_thread = _make_spark_thread(spark_context, spark_job_group, driver,
                                      result_queue, settings, use_gloo)
    try:
        # wait for all tasks to register and notify them
        driver.wait_for_initial_registration(settings.timeout)
        if settings.verbose >= 2:
            print('Initial Spark task registration is complete.')
        task_clients = [
            task_service.SparkTaskClient(
                index, driver.task_addresses_for_driver(index), settings.key,
                settings.verbose) for index in range(settings.num_proc)
        ]
        for task_client in task_clients:
            task_client.notify_initial_registration_complete()
        driver.wait_for_task_to_task_address_updates(settings.timeout)
        if settings.verbose >= 2:
            print('Spark task-to-task address registration is complete.')

        # Determine the index grouping based on host hashes.
        host_hash_indices = driver.task_host_hash_indices()
        host_hashes = sorted(host_hash_indices.keys())
        # Rotate the sorted host list so the host running task index 0 comes
        # first. This matches the barrel shift, but fails loudly instead of
        # looping forever if no host reports index 0.
        first = next((pos for pos, host_hash in enumerate(host_hashes)
                      if 0 in host_hash_indices[host_hash]), None)
        if first is None:
            raise RuntimeError('No registered Spark task has index 0.')
        host_hashes = host_hashes[first:] + host_hashes[:first]

        settings.hosts = ','.join(
            '%s:%d' % (host_hash, len(host_hash_indices[host_hash]))
            for host_hash in host_hashes)

        # Determine the ranks to indices
        ranks_to_indices = []
        for host_hash in host_hashes:
            ranks_to_indices += host_hash_indices[host_hash]
        driver.set_ranks_to_indices(ranks_to_indices)

        # Run the job
        _launch_job(use_mpi, use_gloo, settings, driver, env, stdout, stderr)
    except BaseException:
        # Terminate the Spark job on any failure, including
        # KeyboardInterrupt, then re-raise.
        spark_context.cancelJobGroup(spark_job_group)
        raise
    finally:
        spark_thread.join()
        driver.shutdown()

    # Make sure Spark Job did not fail.
    driver.check_for_spark_job_failure()

    # If there's no exception, execution results are in this queue.
    results = result_queue.get_nowait()
    return [results[index] for index in ranks_to_indices]
Exemple #16
0
class RunTests(unittest.TestCase):
    """
    Tests for horovod.run.
    """
    def __init__(self, *args, **kwargs):
        super(RunTests, self).__init__(*args, **kwargs)
        # Report each distinct warning once per module instead of once per call site.
        warnings.simplefilter('module')

    def test_params_args(self):
        """Tuning-parameter flags are exported as HOROVOD_* environment variables."""
        with override_args('horovodrun', '-np', '2', '--fusion-threshold-mb',
                           '10', '--cycle-time-ms', '20', '--cache-capacity',
                           '512', '--hierarchical-allreduce',
                           '--hierarchical-allgather'):
            args = parse_args()
            env = {}
            config_parser.set_env_from_args(env, args)

            # The fusion threshold is given in MB on the command line but exported in bytes.
            self.assertEqual(env.get(config_parser.HOROVOD_FUSION_THRESHOLD),
                             str(10 * 1024 * 1024))
            self.assertEqual(env.get(config_parser.HOROVOD_CYCLE_TIME), '20.0')
            self.assertEqual(env.get(config_parser.HOROVOD_CACHE_CAPACITY),
                             '512')
            self.assertEqual(
                env.get(config_parser.HOROVOD_HIERARCHICAL_ALLREDUCE), '1')
            self.assertEqual(
                env.get(config_parser.HOROVOD_HIERARCHICAL_ALLGATHER), '1')

    def test_autotune_args(self):
        """Autotune flags are exported as HOROVOD_AUTOTUNE* environment variables."""
        with override_args('horovodrun', '-np', '2', '--autotune',
                           '--autotune-log-file', '/tmp/autotune.txt',
                           '--autotune-warmup-samples', '1',
                           '--autotune-steps-per-sample', '5',
                           '--autotune-bayes-opt-max-samples', '10',
                           '--autotune-gaussian-process-noise', '0.2'):
            args = parse_args()
            env = {}
            config_parser.set_env_from_args(env, args)

            self.assertEqual(env.get(config_parser.HOROVOD_AUTOTUNE), '1')
            self.assertEqual(env.get(config_parser.HOROVOD_AUTOTUNE_LOG),
                             '/tmp/autotune.txt')
            self.assertEqual(
                env.get(config_parser.HOROVOD_AUTOTUNE_WARMUP_SAMPLES), '1')
            self.assertEqual(
                env.get(config_parser.HOROVOD_AUTOTUNE_STEPS_PER_SAMPLE), '5')
            self.assertEqual(
                env.get(config_parser.HOROVOD_AUTOTUNE_BAYES_OPT_MAX_SAMPLES),
                '10')
            self.assertEqual(
                env.get(config_parser.HOROVOD_AUTOTUNE_GAUSSIAN_PROCESS_NOISE),
                '0.2')

    def test_autotuning_with_fixed_param(self):
        """With --autotune, only explicitly given parameters are exported."""
        with override_args('horovodrun', '-np', '2', '--autotune',
                           '--cache-capacity', '1024',
                           '--no-hierarchical-allgather'):
            args = parse_args()
            env = {}
            config_parser.set_env_from_args(env, args)

            # Parameters left to the autotuner must not be pinned in the environment.
            self.assertNotIn(config_parser.HOROVOD_FUSION_THRESHOLD, env)
            self.assertNotIn(config_parser.HOROVOD_CYCLE_TIME, env)
            self.assertEqual(env.get(config_parser.HOROVOD_CACHE_CAPACITY),
                             '1024')
            self.assertNotIn(config_parser.HOROVOD_HIERARCHICAL_ALLREDUCE, env)
            self.assertEqual(
                env.get(config_parser.HOROVOD_HIERARCHICAL_ALLGATHER), '0')

    def test_timeline_args(self):
        """Timeline flags are exported as HOROVOD_TIMELINE* environment variables."""
        with override_args('horovodrun', '-np', '2', '--timeline-filename',
                           '/tmp/timeline.json', '--timeline-mark-cycles'):
            args = parse_args()
            env = {}
            config_parser.set_env_from_args(env, args)

            self.assertEqual(env.get(config_parser.HOROVOD_TIMELINE),
                             '/tmp/timeline.json')
            self.assertEqual(
                env.get(config_parser.HOROVOD_TIMELINE_MARK_CYCLES), '1')

    def test_stall_check_args(self):
        """Stall-check flags are exported; timing flags do not disable the check."""
        with override_args('horovodrun', '-np', '2', '--no-stall-check'):
            args = parse_args()
            env = {}
            config_parser.set_env_from_args(env, args)

            self.assertEqual(
                env.get(config_parser.HOROVOD_STALL_CHECK_DISABLE), '1')

        with override_args('horovodrun', '-np', '2',
                           '--stall-check-warning-time-seconds', '10',
                           '--stall-check-shutdown-time-seconds', '20'):
            args = parse_args()
            env = {}
            config_parser.set_env_from_args(env, args)

            self.assertNotIn(config_parser.HOROVOD_STALL_CHECK_DISABLE, env)
            self.assertEqual(
                env.get(config_parser.HOROVOD_STALL_CHECK_TIME_SECONDS), '10')
            self.assertEqual(
                env.get(config_parser.HOROVOD_STALL_SHUTDOWN_TIME_SECONDS),
                '20')

    def test_library_args(self):
        """MPI / NCCL / CCL / Gloo library flags are exported as environment variables."""
        with override_args('horovodrun', '-np', '2', '--mpi-threads-disable',
                           '--num-nccl-streams', '2', '--ccl-bgt-affinity',
                           '1', '--gloo-timeout-seconds', '60'):
            args = parse_args()
            env = {}
            config_parser.set_env_from_args(env, args)

            self.assertEqual(
                env.get(config_parser.HOROVOD_MPI_THREADS_DISABLE), '1')
            self.assertEqual(env.get(config_parser.HOROVOD_NUM_NCCL_STREAMS),
                             '2')
            self.assertEqual(env.get(config_parser.HOROVOD_CCL_BGT_AFFINITY),
                             '1')
            self.assertEqual(
                env.get(config_parser.HOROVOD_GLOO_TIMEOUT_SECONDS), '60')

    def test_logging_args(self):
        """Logging flags are exported as HOROVOD_LOG_* environment variables."""
        with override_args('horovodrun', '-np', '2', '--log-level', 'INFO',
                           '--log-hide-timestamp'):
            args = parse_args()
            env = {}
            config_parser.set_env_from_args(env, args)

            self.assertEqual(env.get(config_parser.HOROVOD_LOG_LEVEL), 'INFO')
            self.assertEqual(env.get(config_parser.HOROVOD_LOG_HIDE_TIME), '1')

    def test_config_file(self):
        """Every section of data/config.test.yaml is parsed into args."""
        config_filename = os.path.join(os.path.dirname(__file__),
                                       'data/config.test.yaml')
        with override_args('horovodrun', '-np', '2', '--config-file',
                           config_filename):
            args = parse_args()

            self.assertTrue(args.use_gloo)

            # Params
            self.assertEqual(args.fusion_threshold_mb, 32)
            self.assertEqual(args.cycle_time_ms, 10)
            self.assertEqual(args.cache_capacity, 2048)
            self.assertTrue(args.hierarchical_allreduce)
            self.assertTrue(args.hierarchical_allgather)

            # Autotune
            self.assertTrue(args.autotune)
            self.assertEqual(args.autotune_log_file,
                             'horovod_autotune_log.txt')
            self.assertEqual(args.autotune_warmup_samples, 5)
            self.assertEqual(args.autotune_steps_per_sample, 20)
            self.assertEqual(args.autotune_bayes_opt_max_samples, 50)
            self.assertEqual(args.autotune_gaussian_process_noise, 0.9)

            # Timeline
            self.assertEqual(args.timeline_filename, 'horovod_timeline.json')
            self.assertTrue(args.timeline_mark_cycles)

            # Stall Check
            self.assertFalse(args.no_stall_check)
            self.assertEqual(args.stall_check_warning_time_seconds, 120)
            self.assertEqual(args.stall_check_shutdown_time_seconds, 240)

            # Library Options
            self.assertTrue(args.mpi_threads_disable)
            self.assertEqual(args.num_nccl_streams, 2)
            self.assertEqual(args.ccl_bgt_affinity, 1)
            self.assertEqual(args.gloo_timeout_seconds, 60)

            # Logging
            self.assertEqual(args.log_level, 'INFO')
            self.assertTrue(args.log_hide_timestamp)

    def test_config_file_override_args(self):
        """Explicit command-line flags win over config-file values, regardless of position."""
        config_filename = os.path.join(os.path.dirname(__file__),
                                       'data/config.test.yaml')
        with override_args(
                'horovodrun',
                '-np',
                '2',
                '--fusion-threshold-mb',
                '128',
                '--config-file',
                config_filename,
                '--cycle-time-ms',
                '20',
        ):
            args = parse_args()
            self.assertEqual(args.fusion_threshold_mb, 128)
            self.assertEqual(args.cycle_time_ms, 20)

    def test_validate_config_args(self):
        """A negative fusion threshold is rejected with ValueError."""
        with override_args('horovodrun', '-np', '2', '--fusion-threshold-mb',
                           '-1'):
            with pytest.raises(ValueError):
                parse_args()

    def test_hash(self):
        """_hash is stable for a fixed input string."""
        digest = _hash("test string")
        self.assertEqual(digest, '6f8db599de986fab7a21625b7916589c')

    def test_host_hash(self):
        """host_hash changes with CONTAINER_ID and is restored once it is unset."""
        baseline = host_hash()
        # host_hash should consider CONTAINER_ID environment variable
        with override_env({'CONTAINER_ID': 'a container id'}):
            self.assertNotEqual(host_hash(), baseline)
        self.assertEqual(host_hash(), baseline)

    # Minimal mpi_run settings shared by the mpi_run tests. Tests pass a copy
    # to mpi_run so this class-level instance is never modified.
    minimal_settings = hvd_settings.Settings(verbose=0,
                                             num_hosts=1,
                                             num_proc=2,
                                             hosts='host',
                                             run_func_mode=True)

    def test_mpi_run_minimal(self):
        """Tests mpi_run with minimal settings."""
        if _get_mpi_implementation_flags(False)[0] is None:
            self.skipTest("MPI is not available")

        cmd = ['cmd']
        settings = copy.copy(self.minimal_settings)
        run_func = MagicMock(return_value=0)

        mpi_run(settings, None, {}, cmd, run_func=run_func)

        mpi_flags, binding_args = _get_mpi_implementation_flags(False)
        self.assertIsNotNone(mpi_flags)
        expected_cmd = ('mpirun '
                        '--allow-run-as-root --tag-output '
                        '-np 2 -H host '
                        '{binding_args} '
                        '{mpi_flags}       '
                        'cmd').format(binding_args=' '.join(binding_args),
                                      mpi_flags=' '.join(mpi_flags))
        expected_env = {}
        run_func.assert_called_once_with(command=expected_cmd,
                                         env=expected_env,
                                         stdout=None,
                                         stderr=None)

    def test_mpi_run_on_large_cluster(self):
        """Tests mpi_run on a large cluster."""
        if _get_mpi_implementation_flags(False)[0] is None:
            self.skipTest("MPI is not available")

        cmd = ['cmd']
        settings = copy.copy(self.minimal_settings)
        settings.num_hosts = large_cluster_threshold
        run_func = MagicMock(return_value=0)

        mpi_run(settings, None, {}, cmd, run_func=run_func)

        mpi_flags, binding_args = _get_mpi_implementation_flags(False)
        self.assertIsNotNone(mpi_flags)
        # Build a new list rather than appending in place, in case the helper
        # hands back a list that is shared with other callers.
        mpi_flags = mpi_flags + [
            '-mca plm_rsh_no_tree_spawn true',
            '-mca plm_rsh_num_concurrent {}'.format(settings.num_hosts),
        ]
        expected_cmd = ('mpirun '
                        '--allow-run-as-root --tag-output '
                        '-np 2 -H host '
                        '{binding_args} '
                        '{mpi_flags}       '
                        'cmd').format(binding_args=' '.join(binding_args),
                                      mpi_flags=' '.join(mpi_flags))
        expected_env = {}
        run_func.assert_called_once_with(command=expected_cmd,
                                         env=expected_env,
                                         stdout=None,
                                         stderr=None)

    def test_mpi_run_full(self):
        """Tests mpi_run with full settings."""
        if _get_mpi_implementation_flags(False)[0] is None:
            self.skipTest("MPI is not available")

        cmd = ['cmd', 'arg1', 'arg2']
        common_intfs = ['eth0', 'eth1']
        env = {'env1': 'val1', 'env2': 'val2'}
        stdout = '<stdout>'
        stderr = '<stderr>'
        tmout = timeout.Timeout(5, message='Timed out waiting for something.')
        settings = hvd_settings.Settings(
            verbose=0,
            ssh_port=1022,
            extra_mpi_args='>mpi-extra args go here<',
            binding_args='>binding args go here<',
            key=secret.make_secret_key(),
            timeout=tmout,
            num_hosts=1,
            num_proc=1,
            hosts='>host names go here<',
            output_filename='>output filename goes here<',
            run_func_mode=True)
        run_func = MagicMock(return_value=0)

        mpi_run(settings,
                common_intfs,
                env,
                cmd,
                stdout=stdout,
                stderr=stderr,
                run_func=run_func)

        mpi_flags, _ = _get_mpi_implementation_flags(False)
        self.assertIsNotNone(mpi_flags)
        expected_command = (
            'mpirun '
            '--allow-run-as-root --tag-output '
            '-np 1 -H >host names go here< '
            '>binding args go here< '
            '{mpi_flags} '
            '-mca plm_rsh_args "-p 1022" '
            '-mca btl_tcp_if_include eth0,eth1 -x NCCL_SOCKET_IFNAME=eth0,eth1 '
            '--output-filename >output filename goes here< '
            '-x env1 -x env2 '
            '>mpi-extra args go here< '
            'cmd arg1 arg2').format(mpi_flags=' '.join(mpi_flags))
        expected_env = {'env1': 'val1', 'env2': 'val2'}
        run_func.assert_called_once_with(command=expected_command,
                                         env=expected_env,
                                         stdout=stdout,
                                         stderr=stderr)

    def test_mpi_run_with_non_zero_exit(self):
        """mpi_run raises RuntimeError carrying the exit code when mpirun fails."""
        if _get_mpi_implementation_flags(False)[0] is None:
            self.skipTest("MPI is not available")

        cmd = ['cmd']
        settings = copy.copy(self.minimal_settings)
        run_func = MagicMock(return_value=1)

        with pytest.raises(RuntimeError,
                           match="^mpirun failed with exit code 1$"):
            mpi_run(settings, None, {}, cmd, run_func=run_func)

    def test_horovodrun_hostfile(self):
        """A hostfile with 'host slots=N' lines is parsed into 'host:N,...'."""
        with temppath() as host_filename:
            with open(host_filename, 'w+') as fp:
                fp.write('172.31.32.7 slots=8\n')
                fp.write('172.31.33.9 slots=8\n')

            hosts = parse_host_files(host_filename)
            self.assertEqual(hosts, '172.31.32.7:8,172.31.33.9:8')

    @patch('horovod.run.js_run.is_jsrun_installed',
           MagicMock(return_value=True))
    @patch('horovod.run.js_run.generate_jsrun_rankfile',
           MagicMock(return_value='/tmp/rankfile'))
    @patch('horovod.run.util.lsf.LSFUtils.get_num_gpus',
           MagicMock(return_value=2))
    @patch('horovod.run.util.lsf.LSFUtils.get_num_cores',
           MagicMock(return_value=2))
    def test_js_run(self):
        """Tests js_run."""
        if _get_mpi_implementation_flags(False)[0] is None:
            self.skipTest("MPI is not available")

        cmd = ['cmd', 'arg1', 'arg2']
        env = {'env1': 'val1', 'env2': 'val2'}
        stdout = '<stdout>'
        stderr = '<stderr>'
        settings = hvd_settings.Settings(
            verbose=0,
            extra_mpi_args='>mpi-extra args go here<',
            num_hosts=2,
            num_proc=4,
            hosts='>host names go here<',
            output_filename='>output filename goes here<',
            run_func_mode=True)
        run_func = MagicMock(return_value=0)

        js_run(settings,
               None,
               env,
               cmd,
               stdout=stdout,
               stderr=stderr,
               run_func=run_func)

        mpi_flags, _ = _get_mpi_implementation_flags(False)
        self.assertIsNotNone(mpi_flags)
        expected_command = (
            'jsrun '
            '--erf_input /tmp/rankfile '
            '--stdio_stderr >output filename goes here< '
            '--stdio_stdout >output filename goes here< '
            '--smpiargs \'{mpi_args} >mpi-extra args go here<\' '
            'cmd arg1 arg2').format(mpi_args=' '.join(mpi_flags))
        expected_env = {'env1': 'val1', 'env2': 'val2'}
        run_func.assert_called_once_with(command=expected_command,
                                         env=expected_env,
                                         stdout=stdout,
                                         stderr=stderr)

    @patch('horovod.run.util.lsf.LSFUtils.get_num_gpus',
           MagicMock(return_value=4))
    @patch('horovod.run.util.lsf.LSFUtils.get_num_cores',
           MagicMock(return_value=4))
    @patch('horovod.run.util.lsf.LSFUtils.get_num_threads',
           MagicMock(return_value=4))
    def test_generate_jsrun_rankfile(self):
        """Tests generate_jsrun_rankfile."""
        settings = hvd_settings.Settings(
            num_proc=5,
            hosts='host1:4,host2:4,host3:4',
        )

        with temppath() as rankfile_path:
            rankfile_path = generate_jsrun_rankfile(settings, rankfile_path)

            with open(rankfile_path, 'r') as rankfile:
                gen_rankfile = rankfile.read()

            # 5 ranks over hosts with 4 slots each: host1 is filled, host2 gets
            # one rank, and host3 is unused.
            expected_rankfile = ("""overlapping_rs: allow
cpu_index_using: logical

rank: 0: { hostname: host1; cpu: {0-3} ; gpu: * ; mem: * }
rank: 1: { hostname: host1; cpu: {4-7} ; gpu: * ; mem: * }
rank: 2: { hostname: host1; cpu: {8-11} ; gpu: * ; mem: * }
rank: 3: { hostname: host1; cpu: {12-15} ; gpu: * ; mem: * }

rank: 4: { hostname: host2; cpu: {0-3} ; gpu: * ; mem: * }
""")

            self.assertMultiLineEqual(gen_rankfile, expected_rankfile)
Exemple #17
0
def run(fn, args=(), kwargs=None, num_proc=None, start_timeout=None, env=None,
        stdout=None, stderr=None, verbose=1):
    """
    Runs Horovod in Spark.  Runs `num_proc` processes executing `fn` using the same amount of Spark tasks.

    Args:
        fn: Function to run.
        args: Arguments to pass to `fn`.
        kwargs: Keyword arguments to pass to `fn`. Defaults to no keyword arguments.
        num_proc: Number of Horovod processes.  Defaults to `spark.default.parallelism`.
        start_timeout: Timeout for Spark tasks to spawn, register and start running the code, in seconds.
                       If not set, falls back to `HOROVOD_SPARK_START_TIMEOUT` environment variable value.
                       If it is not set as well, defaults to 600 seconds.
        env: Environment dictionary to use in Horovod run.  Defaults to `os.environ`.
             The mapping is copied before use; the caller's object is never modified.
        stdout: Horovod stdout is redirected to this stream. Defaults to sys.stdout.
        stderr: Horovod stderr is redirected to this stream. Defaults to sys.stderr.
        verbose: Debug output verbosity (0-2). Defaults to 1.

    Returns:
        List of results returned by running `fn` on each rank, ordered by rank.

    Raises:
        Exception: if there is no active SparkContext, if the tasks share no common
            network interface, or if mpirun exits with a non-zero code.
    """
    # Avoid a shared mutable default argument.
    if kwargs is None:
        kwargs = {}

    if start_timeout is None:
        # Lookup default timeout from the environment variable.
        start_timeout = int(os.getenv('HOROVOD_SPARK_START_TIMEOUT', '600'))

    tmout = timeout.Timeout(start_timeout,
                            message='Timed out waiting for {activity}. Please check that you have '
                                    'enough resources to run all Horovod processes. Each Horovod '
                                    'process runs in a Spark task. You may need to increase the '
                                    'start_timeout parameter to a larger value if your Spark resources '
                                    'are allocated on-demand.')
    settings = hvd_settings.Settings(verbose=verbose,
                                     key=secret.make_secret_key(),
                                     timeout=tmout)

    spark_context = pyspark.SparkContext._active_spark_context
    if spark_context is None:
        raise Exception('Could not find an active SparkContext, are you '
                        'running in a PySpark session?')

    if num_proc is None:
        num_proc = spark_context.defaultParallelism
        if settings.verbose >= 1:
            print('Running %d processes (inferred from spark.default.parallelism)...' % num_proc)
    else:
        if settings.verbose >= 1:
            print('Running %d processes...' % num_proc)
    settings.num_proc = num_proc

    # Receives the per-index results produced by the Spark job.
    result_queue = queue.Queue(1)

    spark_job_group = 'horovod.spark.run.%d' % job_id.next_job_id()
    driver = driver_service.SparkDriverService(settings.num_proc, fn, args, kwargs,
                                               settings.key)
    spark_thread = _make_spark_thread(spark_context, spark_job_group, driver,
                                      result_queue, settings)
    try:
        driver.wait_for_initial_registration(settings.timeout)
        if settings.verbose >= 2:
            print('Initial Spark task registration is complete.')
        task_clients = [
            task_service.SparkTaskClient(index,
                                         driver.task_addresses_for_driver(index),
                                         settings.key, settings.verbose)
            for index in range(settings.num_proc)]
        for task_client in task_clients:
            task_client.notify_initial_registration_complete()
        driver.wait_for_task_to_task_address_updates(settings.timeout)
        if settings.verbose >= 2:
            print('Spark task-to-task address registration is complete.')

        # Determine a set of common interfaces for task-to-task communication.
        common_intfs = set(driver.task_addresses_for_tasks(0).keys())
        for index in range(1, settings.num_proc):
            common_intfs.intersection_update(driver.task_addresses_for_tasks(index).keys())
        if not common_intfs:
            raise Exception('Unable to find a set of common task-to-task communication interfaces: %s'
                            % [(index, driver.task_addresses_for_tasks(index)) for index in range(settings.num_proc)])
        # Sort so the generated mpirun command is deterministic across runs.
        common_intfs_arg = ','.join(sorted(common_intfs))

        # Query the host-hash -> task-indices mapping once so the host list and
        # the rank mapping below are derived from the same snapshot.
        host_hash_indices = driver.task_host_hash_indices()

        # Determine the index grouping based on host hashes.
        # Barrel shift until index 0 is in the first host.
        host_hashes = sorted(host_hash_indices.keys())
        while 0 not in host_hash_indices[host_hashes[0]]:
            host_hashes = host_hashes[1:] + host_hashes[:1]

        # ranks_to_indices[rank] is the Spark task index that runs that rank.
        ranks_to_indices = []
        for host_hash in host_hashes:
            ranks_to_indices += host_hash_indices[host_hash]
        driver.set_ranks_to_indices(ranks_to_indices)

        hosts_arg = ','.join('%s:%d' % (host_hash, len(host_hash_indices[host_hash]))
                             for host_hash in host_hashes)

        # Work on a private copy: the secret key is added below and must not
        # leak into the caller's mapping (or into os.environ itself).
        env = os.environ.copy() if env is None else dict(env)

        # Pass secret key through the environment variables.
        env[secret.HOROVOD_SECRET_KEY] = codec.dumps_base64(settings.key)

        mpirun_command = (
            'mpirun --allow-run-as-root --tag-output '
            '-np {num_proc} -H {hosts} '
            '-bind-to none -map-by slot '
            '-mca pml ob1 -mca btl ^openib -mca btl_tcp_if_include {common_intfs} '
            '-x NCCL_DEBUG=INFO -x NCCL_SOCKET_IFNAME={common_intfs} '
            '{env} '  # expect a lot of environment variables
            '-mca plm_rsh_agent "{python} -m horovod.spark.driver.mpirun_rsh {encoded_driver_addresses} {settings}" '
            '{python} -m horovod.spark.task.mpirun_exec_fn {encoded_driver_addresses} {settings}'
                .format(num_proc=settings.num_proc,
                        hosts=hosts_arg,
                        common_intfs=common_intfs_arg,
                        env=' '.join('-x %s' % key for key in env.keys() if env_util.is_exportable(key)),
                        python=sys.executable,
                        encoded_driver_addresses=codec.dumps_base64(driver.addresses()),
                        settings=codec.dumps_base64(settings)))
        if settings.verbose >= 2:
            print('+ %s' % mpirun_command)
        exit_code = safe_shell_exec.execute(mpirun_command, env, stdout, stderr)
        if exit_code != 0:
            raise Exception('mpirun exited with code %d, see the error above.' % exit_code)
    except BaseException:
        # Terminate Spark job on any failure, including KeyboardInterrupt.
        spark_context.cancelJobGroup(spark_job_group)

        # Re-raise exception.
        raise
    finally:
        spark_thread.join()
        driver.shutdown()

    # Make sure Spark Job did not fail.
    driver.check_for_spark_job_failure()

    # If there's no exception, execution results are in this queue.
    results = result_queue.get_nowait()
    return [results[index] for index in ranks_to_indices]
Exemple #18
0
class RunTests(unittest.TestCase):
    """
    Tests for horovod.run.
    """
    def __init__(self, *args, **kwargs):
        super(RunTests, self).__init__(*args, **kwargs)
        warnings.simplefilter('module')

    def test_params_args(self):
        with override_args('horovodrun', '-np', '2', '--fusion-threshold-mb',
                           '10', '--cycle-time-ms', '20', '--cache-capacity',
                           '512', '--hierarchical-allreduce',
                           '--hierarchical-allgather'):
            args = parse_args()
            env = {}
            config_parser.set_env_from_args(env, args)

            self.assertEqual(env.get(config_parser.HOROVOD_FUSION_THRESHOLD),
                             str(10 * 1024 * 1024))
            self.assertEqual(env.get(config_parser.HOROVOD_CYCLE_TIME), '20.0')
            self.assertEqual(env.get(config_parser.HOROVOD_CACHE_CAPACITY),
                             '512')
            self.assertEqual(
                env.get(config_parser.HOROVOD_HIERARCHICAL_ALLREDUCE), '1')
            self.assertEqual(
                env.get(config_parser.HOROVOD_HIERARCHICAL_ALLGATHER), '1')

    def test_autotune_args(self):
        with override_args('horovodrun', '-np', '2', '--autotune',
                           '--autotune-log-file', '/tmp/autotune.txt',
                           '--autotune-warmup-samples', '1',
                           '--autotune-steps-per-sample', '5',
                           '--autotune-bayes-opt-max-samples', '10',
                           '--autotune-gaussian-process-noise', '0.2'):
            args = parse_args()
            env = {}
            config_parser.set_env_from_args(env, args)

            self.assertEqual(env.get(config_parser.HOROVOD_AUTOTUNE), '1')
            self.assertEqual(env.get(config_parser.HOROVOD_AUTOTUNE_LOG),
                             '/tmp/autotune.txt')
            self.assertEqual(
                env.get(config_parser.HOROVOD_AUTOTUNE_WARMUP_SAMPLES), '1')
            self.assertEqual(
                env.get(config_parser.HOROVOD_AUTOTUNE_STEPS_PER_SAMPLE), '5')
            self.assertEqual(
                env.get(config_parser.HOROVOD_AUTOTUNE_BAYES_OPT_MAX_SAMPLES),
                '10')
            self.assertEqual(
                env.get(config_parser.HOROVOD_AUTOTUNE_GAUSSIAN_PROCESS_NOISE),
                '0.2')

    def test_autotuning_with_fixed_param(self):
        with override_args('horovodrun', '-np', '2', '--autotune',
                           '--cache-capacity', '1024',
                           '--no-hierarchical-allgather'):
            args = parse_args()
            env = {}
            config_parser.set_env_from_args(env, args)

            self.assertNotIn(config_parser.HOROVOD_FUSION_THRESHOLD, env)
            self.assertNotIn(config_parser.HOROVOD_CYCLE_TIME, env)
            self.assertEqual(env.get(config_parser.HOROVOD_CACHE_CAPACITY),
                             '1024')
            self.assertNotIn(config_parser.HOROVOD_HIERARCHICAL_ALLREDUCE, env)
            self.assertEqual(
                env.get(config_parser.HOROVOD_HIERARCHICAL_ALLGATHER), '0')

    def test_timeline_args(self):
        with override_args('horovodrun', '-np', '2', '--timeline-filename',
                           '/tmp/timeline.json', '--timeline-mark-cycles'):
            args = parse_args()
            env = {}
            config_parser.set_env_from_args(env, args)

            self.assertEqual(env.get(config_parser.HOROVOD_TIMELINE),
                             '/tmp/timeline.json')
            self.assertEqual(
                env.get(config_parser.HOROVOD_TIMELINE_MARK_CYCLES), '1')

    def test_stall_check_args(self):
        with override_args('horovodrun', '-np', '2', '--no-stall-check'):
            args = parse_args()
            env = {}
            config_parser.set_env_from_args(env, args)

            self.assertEqual(
                env.get(config_parser.HOROVOD_STALL_CHECK_DISABLE), '1')

        with override_args('horovodrun', '-np', '2',
                           '--stall-check-warning-time-seconds', '10',
                           '--stall-check-shutdown-time-seconds', '20'):
            args = parse_args()
            env = {}
            config_parser.set_env_from_args(env, args)

            self.assertNotIn(config_parser.HOROVOD_STALL_CHECK_DISABLE, env)
            self.assertEqual(
                env.get(config_parser.HOROVOD_STALL_CHECK_TIME_SECONDS), '10')
            self.assertEqual(
                env.get(config_parser.HOROVOD_STALL_SHUTDOWN_TIME_SECONDS),
                '20')

    def test_library_args(self):
        with override_args('horovodrun', '-np', '2', '--mpi-threads-disable',
                           '--num-nccl-streams', '2', '--ccl-bgt-affinity',
                           '1', '--gloo-timeout-seconds', '60'):
            args = parse_args()
            env = {}
            config_parser.set_env_from_args(env, args)

            self.assertEqual(
                env.get(config_parser.HOROVOD_MPI_THREADS_DISABLE), '1')
            self.assertEqual(env.get(config_parser.HOROVOD_NUM_NCCL_STREAMS),
                             '2')
            self.assertEqual(env.get(config_parser.HOROVOD_CCL_BGT_AFFINITY),
                             '1')
            self.assertEqual(
                env.get(config_parser.HOROVOD_GLOO_TIMEOUT_SECONDS), '60')

    def test_logging_args(self):
        with override_args('horovodrun', '-np', '2', '--log-level', 'INFO',
                           '--log-hide-timestamp'):
            args = parse_args()
            env = {}
            config_parser.set_env_from_args(env, args)

            self.assertEqual(env.get(config_parser.HOROVOD_LOG_LEVEL), 'INFO')
            self.assertEqual(env.get(config_parser.HOROVOD_LOG_HIDE_TIME), '1')

    def test_config_file(self):
        """Every option read from data/config.test.yaml must land on the parsed args."""
        yaml_path = os.path.join(os.path.dirname(__file__),
                                 'data/config.test.yaml')
        with override_args('horovodrun', '-np', '2', '--config-file',
                           yaml_path):
            args = parse_args()

            # Switches the test config turns on (checked for truthiness).
            enabled_flags = ('use_gloo',
                             'hierarchical_allreduce',
                             'hierarchical_allgather',
                             'autotune',
                             'timeline_mark_cycles',
                             'mpi_threads_disable',
                             'log_hide_timestamp')
            for flag in enabled_flags:
                self.assertTrue(getattr(args, flag))

            # The stall check must remain enabled.
            self.assertFalse(args.no_stall_check)

            # Options with concrete values, grouped by category.
            expected_values = (
                # Params
                ('fusion_threshold_mb', 32),
                ('cycle_time_ms', 10),
                ('cache_capacity', 2048),
                # Autotune
                ('autotune_log_file', 'horovod_autotune_log.txt'),
                ('autotune_warmup_samples', 5),
                ('autotune_steps_per_sample', 20),
                ('autotune_bayes_opt_max_samples', 50),
                ('autotune_gaussian_process_noise', 0.9),
                # Timeline
                ('timeline_filename', 'horovod_timeline.json'),
                # Stall check
                ('stall_check_warning_time_seconds', 120),
                ('stall_check_shutdown_time_seconds', 240),
                # Library options
                ('num_nccl_streams', 2),
                ('ccl_bgt_affinity', 1),
                ('gloo_timeout_seconds', 60),
                # Logging
                ('log_level', 'INFO'),
            )
            for attr_name, attr_value in expected_values:
                self.assertEqual(getattr(args, attr_name), attr_value)

    def test_config_file_override_args(self):
        """Explicit CLI flags must win over config-file values, before or after --config-file."""
        yaml_path = os.path.join(os.path.dirname(__file__),
                                 'data/config.test.yaml')
        # One override precedes --config-file, the other follows it.
        argv = ['horovodrun', '-np', '2',
                '--fusion-threshold-mb', '128',
                '--config-file', yaml_path,
                '--cycle-time-ms', '20']
        with override_args(*argv):
            args = parse_args()
            self.assertEqual(args.fusion_threshold_mb, 128)
            self.assertEqual(args.cycle_time_ms, 20)

    def test_validate_config_args(self):
        """A negative fusion threshold must make parse_args raise ValueError."""
        argv = ('horovodrun', '-np', '2', '--fusion-threshold-mb', '-1')
        with override_args(*argv), pytest.raises(ValueError):
            parse_args()

    def test_hash(self):
        """_hash must return a fixed 32-character hex digest for a known input."""
        # Named `digest` rather than `hash` to avoid shadowing the builtin.
        digest = _hash("test string")
        self.assertEqual(digest, '6f8db599de986fab7a21625b7916589c')

    def test_host_hash(self):
        """host_hash must change while CONTAINER_ID is overridden and revert afterwards."""
        # Named `baseline` rather than `hash` to avoid shadowing the builtin.
        baseline = host_hash()
        # host_hash should consider CONTAINER_ID environment variable
        with override_env({'CONTAINER_ID': 'a container id'}):
            self.assertNotEqual(host_hash(), baseline)
        self.assertEqual(host_hash(), baseline)

    def test_settings_dump_drops_key(self):
        """A base64 codec round trip of Settings must not carry the secret key."""
        original = hvd_settings.Settings(verbose=2, key="a secret key")
        encoded = codec.dumps_base64(original)
        restored = codec.loads_base64(encoded)

        self.assertEqual(original.verbose, restored.verbose)
        # The in-memory object keeps its key; only the deserialized copy lacks it.
        self.assertIsNotNone(original.key)
        self.assertIsNone(restored.key)

    def test_get_mpi_implementation(self):
        """_get_mpi_implementation must classify the (mocked) shell output correctly."""
        def check(output, expected, exit_code=0):
            # The mocked tiny_shell_exec.execute yields (output, exit_code);
            # an output of None makes the mock return None instead.
            mocked_result = None if output is None else (output, exit_code)
            with mock.patch("horovod.run.mpi_run.tiny_shell_exec.execute",
                            return_value=mocked_result):
                detected = _get_mpi_implementation()
                self.assertEqual(expected, detected)

        # (shell output, expected implementation, exit code)
        cases = [
            (("mpirun (Open MPI) 2.1.1\n"
              "Report bugs to http://www.open-mpi.org/community/help/\n"),
             _OMPI_IMPL, 0),
            ("OpenRTE", _OMPI_IMPL, 0),
            ("IBM Spectrum MPI", _SMPI_IMPL, 0),
            (("HYDRA build details:\n"
              "    Version:           3.3a2\n"
              "    Configure options: 'MPICHLIB_CFLAGS=-g -O2'\n"),
             _MPICH_IMPL, 0),
            ("Unknown MPI v1.00", _UNKNOWN_IMPL, 0),
            # A non-zero exit code is reported as missing MPI.
            ("output", _MISSING_IMPL, 1),
            # No result from the shell at all.
            (None, _MISSING_IMPL, 0),
        ]
        for output, expected, exit_code in cases:
            check(output, expected, exit_code=exit_code)

    def test_run_controller(self):
        """run_controller must dispatch to the right launcher or raise a descriptive error.

        Every combination of the requested launcher flags (use_gloo, use_mpi,
        use_js: each None/False/True), the built backends (Gloo, MPI) and the
        LSF / jsrun environment is checked against the expected decision.
        """
        # Expected error patterns (matched with pytest.raises(match=...)).
        # Raw strings keep the regex escapes literal instead of relying on
        # invalid string escape sequences such as '\.' in plain literals,
        # which emit DeprecationWarning (SyntaxWarning on Python 3.12+).
        gloo_not_built = (r'^Gloo support has not been built\.  If this is not expected, ensure CMake is installed '
                          r'and reinstall Horovod with HOROVOD_WITH_GLOO=1 to debug the build error\.$')
        mpi_not_built = (r'^MPI support has not been built\.  If this is not expected, ensure MPI is installed '
                         r'and reinstall Horovod with HOROVOD_WITH_MPI=1 to debug the build error\.$')
        no_lsf_job = (r'Horovod did not detect an LSF job\.  The jsrun launcher can only be used in that environment\. '
                      r'Please, pick a different launcher for other environments\.')
        nothing_built = (r'Neither MPI nor Gloo support has been built\. Try reinstalling Horovod ensuring that '
                         r'either MPI is installed \(MPI\) or CMake is installed \(Gloo\)\.')

        def expected_outcome(use_gloo, use_mpi, use_js, gloo_is_built,
                             mpi_is_built, lsf_exists, jsrun_installed):
            # Returns (expected_launcher, exception_pattern); exactly one is None.
            # An explicitly requested launcher takes precedence over auto-detection.
            if use_gloo:
                if gloo_is_built:
                    return 'gloo', None
                return None, gloo_not_built
            if use_mpi:
                if mpi_is_built:
                    return 'mpi', None
                return None, mpi_not_built
            if use_js:
                # jsrun is expected to require MPI and an LSF job.
                if not mpi_is_built:
                    return None, mpi_not_built
                if not lsf_exists:
                    return None, no_lsf_job
                return 'js', None
            # Nothing requested: MPI (jsrun inside LSF when installed), then Gloo.
            if mpi_is_built:
                if lsf_exists and jsrun_installed:
                    return 'js', None
                return 'mpi', None
            if gloo_is_built:
                return 'gloo', None
            return None, nothing_built

        def check(use_gloo, use_mpi, use_js, gloo_is_built, mpi_is_built,
                  lsf_exists, jsrun_installed, expected, exception):
            print('testing run controller with gloo={gloo} mpi={mpi} js={js} '
                  'gloo_built={gloo_is_built} mpi_built={mpi_is_built} '
                  'lsf_exists={lsf} js_installed={js_is_installed} '
                  'expected={expected} exception={exception}'.format(
                      gloo=use_gloo,
                      mpi=use_mpi,
                      js=use_js,
                      gloo_is_built=gloo_is_built,
                      mpi_is_built=mpi_is_built,
                      lsf=lsf_exists,
                      js_is_installed=jsrun_installed,
                      expected=expected,
                      exception=exception))

            gloo_run = MagicMock()
            mpi_run = MagicMock()
            js_run = MagicMock()

            with is_built(gloo_is_built, mpi_is_built):
                with lsf_and_jsrun(lsf_exists, jsrun_installed):
                    if exception is not None:
                        with pytest.raises(ValueError, match=exception):
                            run_controller(use_gloo, gloo_run,
                                           use_mpi, mpi_run,
                                           use_js, js_run,
                                           verbosity=2)
                        return
                    run_controller(use_gloo, gloo_run,
                                   use_mpi, mpi_run,
                                   use_js, js_run,
                                   verbosity=2)

            if expected not in ('gloo', 'mpi', 'js'):
                raise ValueError("unsupported framework: {}".format(expected))
            # Exactly the expected launcher must have run, and only once.
            launchers = (('gloo', gloo_run), ('mpi', mpi_run), ('js', js_run))
            for name, launcher in launchers:
                if name == expected:
                    launcher.assert_called_once()
                else:
                    launcher.assert_not_called()

        bool_values = [False, True]
        bool_values_and_none = [None, False, True]

        # combo = (use_gloo, use_mpi, use_js, gloo_is_built, mpi_is_built,
        #          lsf_exists, jsrun_installed)
        for combo in itertools.product(bool_values_and_none, bool_values_and_none, bool_values_and_none,
                                       bool_values, bool_values,
                                       bool_values, bool_values):
            expected, exception = expected_outcome(*combo)
            check_args = combo + (expected, exception)
            check(*check_args)

    # Minimal mpi_run settings shared by the mpi_run tests below
    # (one host named 'host', two processes, run_func_mode enabled).
    minimal_settings = hvd_settings.Settings(
        verbose=0,
        num_hosts=1,
        num_proc=2,
        hosts='host',
        run_func_mode=True,
    )
    def test_mpi_run_minimal(self):
        """Tests mpi_run with minimal settings.

        The shell executor and the MPI flag helper are mocked; the test checks
        the exact mpirun command line and that an empty env is passed through.
        """
        if not mpi_available():
            self.skipTest("MPI is not available")

        settings = self.minimal_settings

        def fake_impl_flags(tcp):
            return ["--mock-mpi-impl-flags"], ["--mock-mpi-binding-args"]

        with mock.patch("horovod.run.mpi_run._get_mpi_implementation_flags",
                        side_effect=fake_impl_flags), \
                mock.patch("horovod.run.mpi_run.safe_shell_exec.execute",
                           return_value=0) as execute:
            mpi_run(settings, None, {}, ['cmd'])

            # Ask the mocked helper for the same flags mpi_run received.
            mpi_flags, binding_args = horovod.run.mpi_run._get_mpi_implementation_flags(
                False)
            self.assertIsNotNone(mpi_flags)
            command_template = ('mpirun '
                                '--allow-run-as-root --tag-output '
                                '-np 2 -H host '
                                '{binding_args} '
                                '{mpi_flags}       '
                                'cmd')
            expected_cmd = command_template.format(
                binding_args=' '.join(binding_args),
                mpi_flags=' '.join(mpi_flags))
            execute.assert_called_once_with(expected_cmd,
                                            env={},
                                            stdout=None,
                                            stderr=None)

    def test_mpi_run_on_large_cluster(self):
        """Tests mpi_run on a large cluster.

        With num_hosts set to large_cluster_threshold, the expected mpirun
        command additionally carries the plm_rsh tree-spawn and concurrency
        MCA options.
        """
        if not mpi_available():
            self.skipTest("MPI is not available")

        # Copy so the shared class-level minimal_settings is left untouched.
        settings = copy.copy(self.minimal_settings)
        settings.num_hosts = large_cluster_threshold

        def fake_impl_flags(tcp):
            return ["--mock-mpi-impl-flags"], ["--mock-mpi-binding-args"]

        with mock.patch("horovod.run.mpi_run._get_mpi_implementation_flags",
                        side_effect=fake_impl_flags), \
                mock.patch("horovod.run.mpi_run.safe_shell_exec.execute",
                           return_value=0) as execute:
            mpi_run(settings, None, {}, ['cmd'])

            # Ask the mocked helper for the same flags mpi_run received, then
            # add the large-cluster options the command is expected to carry.
            mpi_flags, binding_args = horovod.run.mpi_run._get_mpi_implementation_flags(
                False)
            self.assertIsNotNone(mpi_flags)
            mpi_flags.extend([
                '-mca plm_rsh_no_tree_spawn true',
                '-mca plm_rsh_num_concurrent {}'.format(settings.num_hosts),
            ])
            command_template = ('mpirun '
                                '--allow-run-as-root --tag-output '
                                '-np 2 -H host '
                                '{binding_args} '
                                '{mpi_flags}       '
                                'cmd')
            expected_cmd = command_template.format(
                binding_args=' '.join(binding_args),
                mpi_flags=' '.join(mpi_flags))
            execute.assert_called_once_with(expected_cmd,
                                            env={},
                                            stdout=None,
                                            stderr=None)

    def test_mpi_run_full(self):
        """Tests mpi_run with full settings.

        The ssh port, NICs, output filename, env var names and extra MPI and
        binding args must all appear in the mpirun command, and env, stdout
        and stderr must be forwarded to the (mocked) shell executor.
        """
        if not mpi_available():
            self.skipTest("MPI is not available")

        command = ['cmd', 'arg1', 'arg2']
        nics = ['eth0', 'eth1']
        env = {'env1': 'val1', 'env2': 'val2'}
        stdout, stderr = '<stdout>', '<stderr>'
        tmout = timeout.Timeout(5, message='Timed out waiting for something.')
        settings = hvd_settings.Settings(verbose=0,
                                         ssh_port=1022,
                                         extra_mpi_args='>mpi-extra args go here<',
                                         binding_args='>binding args go here<',
                                         key=secret.make_secret_key(),
                                         timeout=tmout,
                                         num_hosts=1,
                                         num_proc=1,
                                         hosts='>host names go here<',
                                         output_filename='>output filename goes here<',
                                         run_func_mode=True)

        def fake_impl_flags(tcp):
            return ["--mock-mpi-impl-flags"], []

        with mock.patch("horovod.run.mpi_run._get_mpi_implementation_flags",
                        side_effect=fake_impl_flags), \
                mock.patch("horovod.run.mpi_run.safe_shell_exec.execute",
                           return_value=0) as execute:
            mpi_run(settings, nics, env, command, stdout=stdout, stderr=stderr)

            # Ask the mocked helper for the same flags mpi_run received.
            mpi_flags, _ = horovod.run.mpi_run._get_mpi_implementation_flags(
                False)
            self.assertIsNotNone(mpi_flags)
            command_template = (
                'mpirun '
                '--allow-run-as-root --tag-output '
                '-np 1 -H >host names go here< '
                '>binding args go here< '
                '{mpi_flags} '
                '-mca plm_rsh_args "-p 1022" '
                '-mca btl_tcp_if_include eth0,eth1 -x NCCL_SOCKET_IFNAME=eth0,eth1 '
                '--output-filename >output filename goes here< '
                '-x env1 -x env2 '
                '>mpi-extra args go here< '
                'cmd arg1 arg2')
            expected_command = command_template.format(mpi_flags=' '.join(mpi_flags))
            # A fresh literal, so mutation of `env` by mpi_run would be caught.
            execute.assert_called_once_with(expected_command,
                                            env={'env1': 'val1', 'env2': 'val2'},
                                            stdout=stdout,
                                            stderr=stderr)

    def test_mpi_run_with_non_zero_exit(self):
        """A non-zero exit code from the mocked executor must surface as RuntimeError."""
        if not mpi_available():
            self.skipTest("MPI is not available")

        def no_impl_flags(tcp):
            return [], []

        flags_patch = mock.patch("horovod.run.mpi_run._get_mpi_implementation_flags",
                                 side_effect=no_impl_flags)
        exec_patch = mock.patch("horovod.run.mpi_run.safe_shell_exec.execute",
                                return_value=1)
        with flags_patch, exec_patch:
            with pytest.raises(RuntimeError,
                               match="^mpirun failed with exit code 1$"):
                mpi_run(self.minimal_settings, None, {}, ['cmd'])

    def test_horovodrun_hostfile(self):
        """parse_host_files must turn 'host slots=N' lines into comma-joined 'host:N'."""
        hostfile_lines = ['172.31.32.7 slots=8\n',
                          '172.31.33.9 slots=8\n']
        with temppath() as hostfile_path:
            with open(hostfile_path, 'w+') as hostfile:
                hostfile.writelines(hostfile_lines)

            parsed = parse_host_files(hostfile_path)
            self.assertEqual(parsed, '172.31.32.7:8,172.31.33.9:8')

    @mock.patch('horovod.run.js_run.is_jsrun_installed',
                MagicMock(return_value=True))
    @mock.patch('horovod.run.js_run.generate_jsrun_rankfile',
                MagicMock(return_value='/tmp/rankfile'))
    @mock.patch('horovod.run.util.lsf.LSFUtils.get_num_gpus',
                MagicMock(return_value=2))
    @mock.patch('horovod.run.util.lsf.LSFUtils.get_num_cores',
                MagicMock(return_value=2))
    def test_js_run(self):
        """Tests js_run.

        The jsrun install check, rankfile generation and LSF GPU/core queries
        are mocked. The test checks the exact jsrun command line and that env,
        stdout and stderr are forwarded to the (mocked) shell executor.
        """
        # Same availability check as the mpi_run tests.
        if not mpi_available():
            self.skipTest("MPI is not available")

        cmd = ['cmd', 'arg1', 'arg2']
        env = {'env1': 'val1', 'env2': 'val2'}
        stdout = '<stdout>'
        stderr = '<stderr>'
        settings = hvd_settings.Settings(
            verbose=0,
            extra_mpi_args='>mpi-extra args go here<',
            num_hosts=2,
            num_proc=4,
            hosts='>host names go here<',
            output_filename='>output filename goes here<',
            run_func_mode=True)

        def mpi_impl_flags(tcp):
            return ["--mock-mpi-impl-flags"], []

        with mock.patch("horovod.run.js_run._get_mpi_implementation_flags",
                        side_effect=mpi_impl_flags):
            with mock.patch("horovod.run.js_run.safe_shell_exec.execute",
                            return_value=0) as execute:
                js_run(settings, None, env, cmd, stdout=stdout, stderr=stderr)

                # call the mocked _get_mpi_implementation_flags method
                mpi_flags, _ = horovod.run.js_run._get_mpi_implementation_flags(
                    False)
                self.assertIsNotNone(mpi_flags)
                expected_command = (
                    'jsrun '
                    '--erf_input /tmp/rankfile '
                    '--stdio_stderr >output filename goes here< '
                    '--stdio_stdout >output filename goes here< '
                    '--smpiargs \'{mpi_args} >mpi-extra args go here<\' '
                    'cmd arg1 arg2').format(mpi_args=' '.join(mpi_flags))
                expected_env = {'env1': 'val1', 'env2': 'val2'}
                execute.assert_called_once_with(expected_command,
                                                env=expected_env,
                                                stdout=stdout,
                                                stderr=stderr)

    """
    Tests generate_jsrun_rankfile.
    """

    @mock.patch('horovod.run.util.lsf.LSFUtils.get_num_gpus',
                MagicMock(return_value=4))
    @mock.patch('horovod.run.util.lsf.LSFUtils.get_num_cores',
                MagicMock(return_value=4))
    @mock.patch('horovod.run.util.lsf.LSFUtils.get_num_threads',
                MagicMock(return_value=4))
    def test_generate_jsrun_rankfile(self):
        settings = hvd_settings.Settings(
            num_proc=5,
            hosts='host1:4,host2:4,host3:4',
        )

        with temppath() as rankfile_path:
            rankfile_path = generate_jsrun_rankfile(settings, rankfile_path)

            with open(rankfile_path, 'r') as file:
                gen_rankfile = file.read()

            expected_rankfile = ("""overlapping_rs: allow
cpu_index_using: logical

rank: 0: { hostname: host1; cpu: {0-3} ; gpu: * ; mem: * }
rank: 1: { hostname: host1; cpu: {4-7} ; gpu: * ; mem: * }
rank: 2: { hostname: host1; cpu: {8-11} ; gpu: * ; mem: * }
rank: 3: { hostname: host1; cpu: {12-15} ; gpu: * ; mem: * }

rank: 4: { hostname: host2; cpu: {0-3} ; gpu: * ; mem: * }
""")

            self.assertMultiLineEqual(gen_rankfile, expected_rankfile)