示例#1
0
    def test_shutdown_during_request_basic(self):
        """Shutting down must wait for in-flight sleep requests to finish.

        Five sleep requests are started from non-daemon threads and the
        service is shut down half-way through the sleep. Shutdown may only
        return once every request has completed, and the requests must have
        been handled concurrently (total time below ``sleep + 1`` seconds).
        """
        nap = 2.0
        secret_key = secret.make_secret_key()
        service = TestSleepService(secret_key, duration=nap)
        try:
            sleeper = TestSleepClient(service.addresses(), secret_key, attempts=1)
            started_at = time.time()
            workers = []
            for request_no in range(1, 6):
                workers.append(in_thread(sleeper.sleep,
                                         name='request {}'.format(request_no),
                                         daemon=False))
            time.sleep(nap / 2.0)
        finally:
            service.shutdown()

        elapsed = time.time() - started_at
        print('shutdown completed in {} seconds'.format(elapsed))
        self.assertGreaterEqual(elapsed, nap,
                                'sleep requests should have been completed')
        self.assertLess(elapsed, nap + 1.0,
                        'sleep requests should have been concurrent')

        # shutdown already waited for the requests, so threads end promptly
        for worker in workers:
            worker.join(0.1)
            self.assertFalse(worker.is_alive(),
                             'thread should have terminated by now')
    def test_invalid_dispatcher_ids(self):
        """Every dispatcher-indexed client call rejects ids outside [0, 2)."""
        secret_key = secret.make_secret_key()
        service = ComputeService(2, 4, secret_key, nics=None)
        try:
            compute = ComputeClient(service.addresses(), secret_key, verbose=2)

            # (client method, positional args that follow the dispatcher id)
            calls = [
                (compute.register_dispatcher, ('grpc://localhost:10000',)),
                (compute.wait_for_dispatcher_registration, (0.1,)),
                (compute.register_worker_for_dispatcher, (0,)),
                (compute.wait_for_dispatcher_worker_registration, (0.1,)),
            ]
            # -1 is below the valid range, 2 equals the number of dispatchers
            for method, extra_args in calls:
                for bad_id in (-1, 2):
                    with self.assertRaises(IndexError):
                        method(bad_id, *extra_args)
        finally:
            service.shutdown()
示例#3
0
    def test_shutdown_during_request_basic_task(self):
        """Shutdown waits for an in-flight request, then refuses new ones.

        A background thread blocks on ``wait_for_command_exit_code`` while a
        ``sleep 2`` command runs. ``service.shutdown()`` must not return before
        that request has been served. Afterwards, new requests must fail with a
        connection error, and the waiting thread must have received the
        command's exit code.
        """
        result_queue = queue.Queue(1)

        # the parameter is called ``results`` so it does not shadow the
        # ``queue`` module used above
        def wait_for_exit_code(client, results):
            results.put(client.wait_for_command_exit_code())

        key = secret.make_secret_key()
        service_name = 'test-service'
        service = BasicTaskService(service_name, key, nics=None, verbose=2)
        client = BasicTaskClient(service_name, service.addresses(), key, verbose=2, attempts=1)
        thread = threading.Thread(target=wait_for_exit_code, args=(client, result_queue))

        start = time.time()
        thread.start()  # wait for command exit code
        client.run_command('sleep 2', {})  # execute command
        time.sleep(0.5)  # give the thread some time to connect before shutdown
        service.shutdown()  # shutdown should wait on request to finish
        duration = time.time() - start
        self.assertGreaterEqual(duration, 2)

        # we cannot call after shutdown; the alternation is grouped so that
        # both anchors apply to both messages (``^A|B$`` anchors only one side
        # of each alternative)
        with pytest.raises(Exception, match=r'^(?:\[[Ee]rrno 104\] Connection reset by peer'
                                            r'|\[[Ee]rrno 111\] Connection refused)$'):
            client.command_result()

        # but still our long running request succeeded
        thread.join(1.0)
        self.assertFalse(thread.is_alive())
        # the thread must have delivered the exit code of `sleep 2`; a thread
        # that died with an exception would leave the queue empty
        self.assertEqual(0, result_queue.get_nowait())
示例#4
0
文件: worker.py 项目: zw0610/horovod
    def init(self,
             rendezvous_addr=None,
             rendezvous_port=None,
             nic=None,
             hostname=None,
             local_rank=None):
        """Start the worker notification service and publish its addresses.

        Does nothing if the service is already running. It also does nothing
        if no rendezvous address is available, either as an argument or from
        the environment variable named by ``HOROVOD_GLOO_RENDEZVOUS_ADDR``.

        The other arguments work the same way. When left as ``None`` (or empty,
        for ``nic`` and ``hostname``), they are read from the environment
        variables named by the corresponding ``HOROVOD_*`` constants.
        ``rendezvous_port`` and ``local_rank`` are converted with ``int``, so
        a missing variable raises ``TypeError``.

        The tuple ``(service addresses, secret key)`` is put into the
        rendezvous KV store under ``PUT_WORKER_ADDRESSES`` with the id
        ``self._create_id(hostname, local_rank)``. If that fails, the service
        is shut down and ``self._service`` is reset to ``None``. A later call
        can then retry instead of returning early because a half-initialized
        service is present.
        """
        with self._lock:
            if self._service:
                return

            rendezvous_addr = rendezvous_addr or os.environ.get(
                HOROVOD_GLOO_RENDEZVOUS_ADDR)
            if not rendezvous_addr:
                return

            rendezvous_port = rendezvous_port if rendezvous_port is not None else \
                int(os.environ.get(HOROVOD_GLOO_RENDEZVOUS_PORT))
            nic = nic or os.environ.get(HOROVOD_GLOO_IFACE)
            hostname = hostname or os.environ.get(HOROVOD_HOSTNAME)
            local_rank = local_rank if local_rank is not None else \
                int(os.environ.get(HOROVOD_LOCAL_RANK))

            secret_key = secret.make_secret_key()
            self._service = WorkerNotificationService(secret_key, nic, self)

            try:
                value = (self._service.addresses(), secret_key)
                put_data_into_kvstore(rendezvous_addr, rendezvous_port,
                                      PUT_WORKER_ADDRESSES,
                                      self._create_id(hostname, local_rank), value)
            except BaseException:
                # do not leave a running but unregistered service behind,
                # otherwise every later init() would return early
                service, self._service = self._service, None
                service.shutdown()
                raise
示例#5
0
def _run_elastic(args):
    """Run an elastic (fault tolerant) Horovod job with Gloo from CLI args.

    Hosts come either from ``--host-discovery-script`` or from a fixed
    ``--hosts`` list. A fixed list must name at least two hosts.

    Raises:
        ValueError: if neither host source is given, if the fixed host list
            is too short, or if Horovod was built without Gloo support.

    Returns:
        Whatever ``gloo_run_elastic`` returns.
    """
    if args.host_discovery_script:
        discover_hosts = discovery.HostDiscoveryScript(
            args.host_discovery_script, args.slots)
    elif not args.hosts:
        raise ValueError(
            'One of --host-discovery-script, --hosts, or --hostnames must be provided'
        )
    else:
        _, host_slots = hosts.parse_hosts_and_slots(args.hosts)
        if len(host_slots) < 2:
            raise ValueError(
                'Cannot run in fault tolerance mode with fewer than 2 hosts.')
        discover_hosts = discovery.FixedHosts(host_slots)

    # horovodrun must finish all of its checks before this timeout expires;
    # without --start-timeout, HOROVOD_START_TIMEOUT (default 30) is used
    start_timeout = args.start_timeout or int(os.getenv('HOROVOD_START_TIMEOUT', '30'))

    timeout_hint = ('Timed out waiting for {activity}. Please '
                    'check connectivity between servers. You '
                    'may need to increase the --start-timeout '
                    'parameter if you have too many servers.')
    tmout = timeout.Timeout(start_timeout, message=timeout_hint)

    settings = elastic_settings.ElasticSettings(
        discovery=discover_hosts,
        min_num_proc=args.min_num_proc or args.num_proc,
        max_num_proc=args.max_num_proc,
        elastic_timeout=args.elastic_timeout,
        reset_limit=args.reset_limit,
        cooldown_range=args.cooldown_range,
        num_proc=args.num_proc,
        verbose=2 if args.verbose else 0,
        ssh_port=args.ssh_port,
        ssh_identity_file=args.ssh_identity_file,
        extra_mpi_args=args.mpi_args,
        key=secret.make_secret_key(),
        start_timeout=tmout,
        output_filename=args.output_filename,
        run_func_mode=args.run_func is not None,
        nics=args.nics,
        prefix_output_with_timestamp=args.prefix_output_with_timestamp)

    if not gloo_built(verbose=(settings.verbose >= 2)):
        raise ValueError(
            'Gloo support is required to use elastic training, but has not been built.  Ensure CMake is '
            'installed and reinstall Horovod with HOROVOD_WITH_GLOO=1 to debug the build error.'
        )

    env = os.environ.copy()
    config_parser.set_env_from_args(env, args)
    interpreter = args.executable or sys.executable
    target = args.run_func if args.run_func else args.command
    return gloo_run_elastic(settings, env, target, interpreter)
示例#6
0
文件: horovod.py 项目: stjordanis/ray
    def __post_init__(self):
        """Post-initialisation hook.

        If ``ssh_str`` is set and ``ssh_identity_file`` does not exist yet,
        ``ssh_str`` is written to that file with permissions 0o600. A secret
        key is generated when ``self.key`` is ``None``.
        """
        if self.ssh_str and not os.path.exists(self.ssh_identity_file):
            # os.open applies the 0o600 mode at creation time, so the private
            # key is never readable by other users, not even for the moment
            # between creating the file and calling chmod.
            fd = os.open(self.ssh_identity_file,
                         os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
            with os.fdopen(fd, "w") as f:
                # the creation mode is filtered by the umask; chmod pins it to
                # exactly 0o600 as before
                os.chmod(self.ssh_identity_file, 0o600)
                f.write(self.ssh_str)

        if self.key is None:
            self.key = secret.make_secret_key()
示例#7
0
    def test_mpi_run_full(self):
        """mpi_run turns a fully populated Settings into the expected mpirun call."""
        if not mpi_available():
            self.skipTest("MPI is not available")

        command = ['cmd', 'arg1', 'arg2']
        interfaces = ['eth0', 'eth1']
        environment = {'env1': 'val1', 'env2': 'val2'}
        out_stream = '<stdout>'
        err_stream = '<stderr>'
        start_tmout = timeout.Timeout(5, message='Timed out waiting for something.')
        settings = hvd_settings.Settings(
            verbose=0,
            ssh_port=1022,
            extra_mpi_args='>mpi-extra args go here<',
            binding_args='>binding args go here<',
            key=secret.make_secret_key(),
            start_timeout=start_tmout,
            num_proc=1,
            hosts='localhost:1',
            output_filename='>output filename goes here<',
            run_func_mode=True
        )

        def fake_impl_flags(tcp, env=None):
            return ["--mock-mpi-impl-flags"], []

        flags_target = "horovod.runner.mpi_run._get_mpi_implementation_flags"
        execute_target = "horovod.runner.mpi_run.safe_shell_exec.execute"
        with mock.patch(flags_target, side_effect=fake_impl_flags) as impl, \
                mock.patch(execute_target, return_value=0) as execute:
            mpi_run(settings, interfaces, environment, command,
                    stdout=out_stream, stderr=err_stream)

            # mpi_run must have asked for the implementation flags exactly once
            impl.assert_called_once_with(None, env=environment)

            # invoke the mocked helper ourselves to learn the injected flags
            mpi_flags, _ = horovod.runner.mpi_run._get_mpi_implementation_flags(False)
            self.assertIsNotNone(mpi_flags)
            expected_command = ' '.join([
                'mpirun',
                '--allow-run-as-root --tag-output',
                '-np 1 -H {}'.format(settings.hosts),
                '>binding args go here<',
                ' '.join(mpi_flags),
                '-mca plm_rsh_args "-p 1022"',
                '-mca btl_tcp_if_include eth0,eth1 -x NCCL_SOCKET_IFNAME=eth0,eth1',
                '--output-filename >output filename goes here<',
                '-x env1 -x env2',
                '>mpi-extra args go here<',
                'cmd arg1 arg2',
            ])

            # PYTHONPATH depends on the test environment, so it is dropped
            # here; its handling is covered by the test_mpi_run_*pythonpath* tests
            call_kwargs = execute.call_args.kwargs
            self.assertIn('env', call_kwargs)
            call_kwargs['env'].pop('PYTHONPATH', None)

            expected_env = {'env1': 'val1', 'env2': 'val2', 'PATH': os.environ.get('PATH')}
            execute.assert_called_once_with(expected_command, env=expected_env,
                                            stdout=out_stream, stderr=err_stream)
示例#8
0
    def test_exit_code(self):
        """test non-zero exit code: `false` is reported with exit code 1"""
        key = secret.make_secret_key()
        service_name = 'test-service'
        service = BasicTaskService(service_name, key, nics=None, verbose=2)
        try:
            client = BasicTaskClient(service_name, service.addresses(), key, verbose=2, attempts=1)

            client.run_command('false', {})
            res = client.wait_for_command_exit_code()
            self.assertEqual(1, res)
        finally:
            # always shut the service down, like the sibling tests do, so it
            # does not outlive this test when an assertion fails
            service.shutdown()
示例#9
0
def spark_task_service(index, key=None, nics=None, match_intf=False,
                       minimum_command_lifetime_s=0, verbose=2):
    """Yield ``(task, client, key)`` for a running SparkTaskService.

    A new secret key is generated when ``key`` is not given. The task service
    is shut down when the generator finishes. This also happens if
    constructing the client raises, which previously leaked the service.
    """
    key = key or secret.make_secret_key()
    task = SparkTaskService(index, key, nics, minimum_command_lifetime_s, verbose)
    try:
        client = SparkTaskClient(index, task.addresses(), key, verbose, match_intf)
        yield task, client, key
    finally:
        task.shutdown()
示例#10
0
    def create_settings(min_np: int = 1,
                        max_np: int = None,
                        reset_limit: int = None,
                        elastic_timeout: int = 600,
                        timeout_s: int = 30,
                        ssh_identity_file: str = None,
                        nics: str = None,
                        **kwargs):
        """Build the ElasticSettings object used by ElasticRayExecutor.

        ``discovery`` is left as ``None`` here; it is filled in at runtime.

        Args:
            min_np (int): Minimum number of processes that must be running for
                training to continue. If fewer processes are available,
                training waits until more instances become available. Also
                used as the initial ``num_proc``.
            max_np (int): Maximum number of training processes; no more
                processes are created beyond it. Unbounded if not given.
            reset_limit (int): Maximum number of times the job may scale the
                number of workers up or down before it is terminated.
            elastic_timeout (int): Timeout for elastic initialisation after
                the cluster is re-scaled. Defaults to 600 seconds.
                The HOROVOD_ELASTIC_TIMEOUT environment variable can be
                used instead.
            timeout_s (int): Horovod must finish all checks and start the
                processes within this many seconds. Defaults to 30.
            ssh_identity_file (str): File on the driver from which the
                identity (private key) is read. Defaults to
                ``~/ray_bootstrap_key.pem``.
            nics: Network interfaces that can be used for communication.
            **kwargs: Passed unchanged to ``ElasticSettings``.

        Returns:
            The constructed ``ElasticSettings``.
        """
        hint = ("Timed out waiting for {activity}. Please "
                "check connectivity between servers. You "
                "may need to increase the --start-timeout "
                "parameter if you have too many servers.")
        start_timeout = timeout.Timeout(timeout_s, message=hint)

        if not ssh_identity_file:
            ssh_identity_file = os.path.expanduser("~/ray_bootstrap_key.pem")

        # ``secret`` may be unavailable (falsy), in which case no key is set
        secret_key = secret.make_secret_key() if secret else None

        return ElasticSettings(
            discovery=None,
            min_np=min_np,
            max_np=max_np,
            elastic_timeout=elastic_timeout,
            reset_limit=reset_limit,
            num_proc=min_np,
            ssh_identity_file=ssh_identity_file,
            nics=nics,
            start_timeout=start_timeout,
            key=secret_key,
            **kwargs)
示例#11
0
def spark_driver_service(num_proc, initial_np=None, fn=fn, args=(), kwargs=None,
                         key=None, nics=None, verbose=2):
    """Yield ``(driver, client, key)`` for a running SparkDriverService.

    ``initial_np`` defaults to ``num_proc``. ``kwargs`` defaults to a fresh
    empty dict, so calls no longer share one mutable default. A new secret
    key is generated when ``key`` is not given. The driver is shut down when
    the generator finishes. This also happens if constructing the client
    raises, which previously leaked the driver.
    """
    initial_np = initial_np or num_proc
    kwargs = {} if kwargs is None else kwargs
    key = key or secret.make_secret_key()
    driver = SparkDriverService(initial_np, num_proc, fn, args, kwargs, key, nics)
    try:
        client = SparkDriverClient(driver.addresses(), key, verbose)
        yield driver, client, key
    finally:
        driver.shutdown()
示例#12
0
文件: horovod.py 项目: RuofanKong/ray
    def __post_init__(self):
        """Post-initialisation hook.

        Raises ``ValueError`` when ``horovod[ray]`` is not installed. If
        ``ssh_str`` is set and ``ssh_identity_file`` does not exist yet,
        ``ssh_str`` is written to that file with permissions 0o600. A secret
        key is generated when ``self.key`` is ``None``.
        """
        if Coordinator is None:
            raise ValueError(
                "`horovod[ray]` is not installed. "
                "Please install 'horovod[ray]' to use this backend.")

        if self.ssh_str and not os.path.exists(self.ssh_identity_file):
            # os.open applies the 0o600 mode at creation time, so the private
            # key is never readable by other users, not even for the moment
            # between creating the file and calling chmod.
            fd = os.open(self.ssh_identity_file,
                         os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
            with os.fdopen(fd, "w") as f:
                # the creation mode is filtered by the umask; chmod pins it to
                # exactly 0o600 as before
                os.chmod(self.ssh_identity_file, 0o600)
                f.write(self.ssh_str)

        if self.key is None:
            self.key = secret.make_secret_key()
示例#13
0
 def test_register_dispatcher_duplicate(self):
     """Re-registering dispatcher 0 under a different address is rejected."""
     secret_key = secret.make_secret_key()
     service = ComputeService(2, 1, secret_key, nics=None)
     try:
         compute = ComputeClient(service.addresses(), secret_key, verbose=2)
         compute.register_dispatcher(0, 'grpc://localhost:10000')
         expected_message = (
             'Dispatcher with id 0 has already been registered under '
             'different address grpc://localhost:10000: grpc://localhost:10001')
         with self.assertRaisesRegex(ValueError, expected_message):
             compute.register_dispatcher(0, 'grpc://localhost:10001')
     finally:
         service.shutdown()
示例#14
0
 def test_register_dispatcher_worker_timeout(self):
     """Waiting for workers that never register raises TimeoutException."""
     secret_key = secret.make_secret_key()
     service = ComputeService(1, 1, secret_key, nics=None)
     try:
         compute = ComputeClient(service.addresses(), secret_key, verbose=2)
         expected_message = (
             'Timed out waiting for workers for dispatcher 0 to register. '
             'Try to find out what takes the workers so long '
             'to register or increase timeout. Timeout after 0.1 seconds.')
         with self.assertRaisesRegex(TimeoutException,
                                     expected_regex=expected_message):
             compute.wait_for_dispatcher_worker_registration(0, timeout=0.1)
     finally:
         service.shutdown()
示例#15
0
    def test_concurrent_requests_basic(self):
        """Several sleep requests issued at once are served concurrently.

        Each request sleeps ``sleep`` seconds. If all of them finish in less
        than ``sleep + 1`` seconds, they cannot have been handled one after
        another.
        """
        sleep = 2.0
        key = secret.make_secret_key()
        service = TestSleepService(key, duration=sleep)
        try:
            client = TestSleepClient(service.addresses(), key, attempts=1)

            start = time.time()
            # more than one request is needed for the timing bound below to
            # detect serialized handling (same count as the shutdown test)
            threads = [in_thread(client.sleep, daemon=False) for _ in range(5)]
            for thread in threads:
                thread.join(sleep + 1.0)
                self.assertFalse(thread.is_alive(), 'thread should have terminated by now')
            duration = time.time() - start
            print('concurrent requests completed in {} seconds'.format(duration))

            self.assertGreaterEqual(duration, sleep, 'sleep requests should have been completed')
            self.assertLess(duration, sleep + 1.0, 'sleep requests should have been concurrent')
        finally:
            # the service was previously never shut down
            service.shutdown()
示例#16
0
    def test_stream(self):
        """A streamed sleep request writes its message into the given stream."""
        nap = 2.0
        secret_key = secret.make_secret_key()
        service = TestStreamService(secret_key, duration=nap)
        try:
            streamer = TestStreamClient(service.addresses(), secret_key, attempts=1)

            began = time.time()
            sink = io.StringIO()
            streamer.sleep(sink)
            elapsed = time.time() - began

            self.assertEqual(f'slept {nap}', sink.getvalue())
            self.assertGreaterEqual(elapsed, 2)
        finally:
            service.shutdown()
示例#17
0
    def test_register_dispatcher_worker_replay(self):
        """A replayed worker registration must not count as another worker."""
        secret_key = secret.make_secret_key()
        service = ComputeService(1, 2, secret_key, nics=None)
        try:
            compute = ComputeClient(service.addresses(), secret_key, verbose=2)
            compute.register_dispatcher(0, 'grpc://localhost:10000')

            # dispatcher 0 expects two workers; registering worker 0 twice
            # still leaves it one short, so waiting has to time out
            for _ in range(2):
                compute.register_worker_for_dispatcher(0, 0)
            with self.assertRaises(TimeoutException):
                compute.wait_for_dispatcher_worker_registration(0, timeout=2)

            # registering the second worker (id 1) completes registration
            compute.register_worker_for_dispatcher(0, 1)
            compute.wait_for_dispatcher_worker_registration(0, timeout=2)
        finally:
            service.shutdown()
示例#18
0
class MiniSettings:
    """Minimal set of settings that Ray needs in order to work.

    A proper Horovod Settings object can be used in its place.
    """
    nics: set = None
    verbose: int = 1
    # NOTE(review): this default is evaluated once, when the class body runs
    # (import time), so every instance keeping the default shares one key.
    key: str = secret.make_secret_key() if secret else None
    ssh_port: int = None
    ssh_identity_file: str = None
    timeout_s: int = 300

    @property
    def start_timeout(self):
        """A new ``timeout.Timeout`` of ``timeout_s`` seconds, built on each access."""
        hint = ("Timed out waiting for {activity}. Please "
                "check connectivity between servers. You "
                "may need to increase the --start-timeout "
                "parameter if you have too many servers.")
        return timeout.Timeout(self.timeout_s, message=hint)
示例#19
0
def main(dispatchers: int, dispatcher_side: str, configfile: str,
         timeout: int):
    """Run the tf.data compute service across all Horovod ranks.

    1. Rank 0 starts a ComputeService and writes a TfDataServiceConfig
       describing it to ``configfile``.
    2. The config is broadcast to every rank.
    3. Every rank runs ``compute_worker_fn`` with the config.

    The ComputeService is shut down when this function exits.

    Raises:
        ValueError: if the number of processes is not a multiple of
            ``dispatchers``.
    """
    hvd.init()
    rank = hvd.rank()
    size = hvd.size()

    workers_per_dispatcher, leftover = divmod(size, dispatchers)
    if leftover:
        raise ValueError(
            f'Number of processes ({size}) must be a multiple of number of dispatchers ({dispatchers}).'
        )

    compute = None
    try:
        compute_config = None

        # only rank 0 hosts the compute service and writes its config
        if rank == 0:
            secret_key = secret.make_secret_key()
            compute = ComputeService(dispatchers,
                                     workers_per_dispatcher,
                                     key=secret_key)

            compute_config = TfDataServiceConfig(
                dispatchers=dispatchers,
                workers_per_dispatcher=workers_per_dispatcher,
                dispatcher_side=dispatcher_side,
                addresses=compute.addresses(),
                key=secret_key,
                timeout=timeout)
            compute_config.write(configfile)

        # share rank 0's config with every rank, using CPU ops
        with tf.device('/cpu:0'):
            compute_config = hvd.broadcast_object(compute_config,
                                                  name='TfDataServiceConfig')

        compute_worker_fn(compute_config)
    finally:
        if compute is not None:
            compute.shutdown()
示例#20
0
    def test_good_path(self):
        """Full dispatcher and worker registration round-trip for many layouts."""
        layouts = [(1, 1), (1, 2), (1, 4),
                   (2, 1), (2, 2), (2, 4),
                   (32, 16), (1, 512)]
        for dispatchers_num, workers_per_dispatcher in layouts:
            with self.subTest(dispatchers=dispatchers_num,
                              workers_per_dispatcher=workers_per_dispatcher):
                secret_key = secret.make_secret_key()
                service = ComputeService(dispatchers_num,
                                         workers_per_dispatcher,
                                         secret_key,
                                         nics=None)
                try:
                    client = ComputeClient(service.addresses(), secret_key, verbose=2)

                    # a background thread waits for the shutdown request
                    shutdown = Queue()
                    shutdown_thread = self.wait_for_shutdown(client, shutdown)

                    # --- dispatcher registration ---
                    addresses = [f'grpc://localhost:{10000 + d}'
                                 for d in range(dispatchers_num)]

                    # start the waiters before registering anything
                    registered = Queue()
                    waiters = [self.wait_for_dispatcher(client, d, registered)
                               for d in range(dispatchers_num)]

                    for d, address in enumerate(addresses):
                        client.register_dispatcher(d, address)

                    for waiter in waiters:
                        waiter.join(10)
                        self.assertFalse(
                            waiter.is_alive(),
                            msg="threads waiting for dispatchers did not terminate")

                    self.assertEqual(list(enumerate(addresses)),
                                     sorted(self.get_all(registered)))

                    # --- worker registration ---
                    completed = Queue()
                    waiters = [self.wait_for_dispatcher_workers(client, d, completed)
                               for d in range(dispatchers_num)]

                    # consecutive worker ids are grouped per dispatcher
                    for worker in range(dispatchers_num * workers_per_dispatcher):
                        client.register_worker_for_dispatcher(
                            dispatcher_id=worker // workers_per_dispatcher,
                            worker_id=worker)

                    for waiter in waiters:
                        waiter.join(10)
                        self.assertFalse(
                            waiter.is_alive(),
                            msg="threads waiting for dispatchers' workers did not terminate")

                    self.assertEqual(list(range(dispatchers_num)),
                                     sorted(self.get_all(completed)))

                    # --- shutdown ---
                    self.assertTrue(
                        shutdown_thread.is_alive(),
                        msg="thread waiting for shutdown, terminated early")
                    client.shutdown()
                    shutdown_thread.join(10)
                    self.assertFalse(
                        shutdown_thread.is_alive(),
                        msg="thread waiting for shutdown did not terminate")
                    self.assertEqual([True], list(self.get_all(shutdown)))
                finally:
                    service.shutdown()
示例#21
0
    # --dispatcher-side selects where the tf.data dispatchers run:
    # on the 'compute' side (default) or the 'training' side
    parser.add_argument("--dispatcher-side", required=False, default='compute', type=str,
                        help=f"Where do the dispatcher run? On 'compute' side or 'training' side.",
                        dest="dispatcher_side")

    parsed_args = parser.parse_args()

    # the number of compute workers is Spark's default parallelism
    spark = SparkSession.builder.getOrCreate()
    spark_context = spark.sparkContext
    workers = spark_context.defaultParallelism

    # every dispatcher gets the same number of workers, so the worker count
    # has to divide evenly by the number of dispatchers
    if workers % parsed_args.dispatchers:
        raise ValueError(f'Number of processes ({workers}) must be '
                         f'a multiple of number of dispatchers ({parsed_args.dispatchers}).')
    workers_per_dispatcher = workers // parsed_args.dispatchers

    # start the compute service under a fresh secret key
    key = secret.make_secret_key()
    compute = ComputeService(parsed_args.dispatchers, workers_per_dispatcher, key=key)

    # write the service's addresses and key to the config file
    compute_config = TfDataServiceConfig(
        dispatchers=parsed_args.dispatchers,
        workers_per_dispatcher=workers_per_dispatcher,
        dispatcher_side=parsed_args.dispatcher_side,
        addresses=compute.addresses(),
        key=key
    )
    compute_config.write(parsed_args.configfile)

    def _exit_gracefully():
        """Log the shutdown and stop the SparkContext.

        The log message says this runs when the driver receives SIGTERM.
        NOTE(review): it takes no arguments, but a function registered via
        ``signal.signal`` is called with ``(signum, frame)``. Confirm how it
        is registered (the rest of this script is not visible here).
        """
        logging.info('Spark driver receiving SIGTERM. Exiting gracefully')
        spark_context.stop()
示例#22
0
def run_elastic(
        fn,
        args=(),
        kwargs=None,
        num_proc=None,
        min_num_proc=None,
        max_num_proc=None,
        start_timeout=None,
        elastic_timeout=None,
        reset_limit=None,
        env=None,
        stdout=None,
        stderr=None,
        verbose=1,
        nics=None,
        prefix_output_with_timestamp=False,
        # min_np is deprecated, use min_num_proc instead
        min_np=None,
        # max_np is deprecated, use max_num_proc instead
        max_np=None):
    """
    Runs Elastic Horovod on Spark.  Runs `num_proc` processes executing `fn` using the same amount of Spark tasks.

    Args:
        fn: Function to run.
        args: Arguments to pass to `fn`.
        kwargs: Keyword arguments to pass to `fn`. Defaults to no keyword arguments.
        num_proc: Number of Horovod processes.  Defaults to `spark.default.parallelism`.
        min_num_proc: Minimum number of processes running for training to continue.
                      If number of available processes dips below this threshold,
                      then training will wait for more instances to become available.
                      Defaults to `num_proc`.
        max_num_proc: Maximum number of training processes,
                      beyond which no additional processes will be created.
                      Defaults to `num_proc`.
        start_timeout: Timeout for Spark tasks to spawn, register and start running the code, in seconds.
                       If not set, falls back to `HOROVOD_SPARK_START_TIMEOUT` environment variable value.
                       If it is not set as well, defaults to 600 seconds.
        elastic_timeout: Timeout for elastic initialisation after re-scaling the cluster.
                       If not set, falls back to `HOROVOD_ELASTIC_TIMEOUT` environment variable value.
                       If it is not set as well, defaults to 600 seconds.
        reset_limit: Maximum number of resets after which the job is terminated.
        env: Environment dictionary to use in Horovod run.  Defaults to `os.environ`.
        stdout: Horovod stdout is redirected to this stream.
        stderr: Horovod stderr is redirected to this stream.
        verbose: Debug output verbosity (0-2). Defaults to 1.
        nics: List of NICs for tcp network communication.
        prefix_output_with_timestamp: shows timestamp in stdout/stderr forwarding on the driver
        min_np: Deprecated alias of `min_num_proc`; takes precedence over it when given.
        max_np: Deprecated alias of `max_num_proc`; takes precedence over it when given.

    Returns:
        List of results returned by running `fn` on each rank.

    Raises:
        ValueError: if Horovod was built without Gloo support.
        Exception: if there is no active SparkContext.
    """
    # a fresh dict per call instead of one mutable default shared by all calls
    if kwargs is None:
        kwargs = {}

    if min_np is not None:
        min_num_proc = min_np
        warnings.warn('min_np is deprecated, use min_num_proc instead',
                      DeprecationWarning)
    if max_np is not None:
        max_num_proc = max_np
        warnings.warn('max_np is deprecated, use max_num_proc instead',
                      DeprecationWarning)

    if not gloo_built(verbose=(verbose >= 2)):
        raise ValueError(
            'Gloo support is required to use elastic training, but has not been built.  Ensure CMake is '
            'installed and reinstall Horovod with HOROVOD_WITH_GLOO=1 to debug the build error.'
        )

    spark_context = pyspark.SparkContext._active_spark_context
    if spark_context is None:
        raise Exception('Could not find an active SparkContext, are you '
                        'running in a PySpark session?')

    if start_timeout is None:
        # Lookup default timeout from the environment variable.
        start_timeout = int(os.getenv('HOROVOD_SPARK_START_TIMEOUT', '600'))

    # nics needs to be a set
    if nics and not isinstance(nics, set):
        nics = set(nics)

    if num_proc is None:
        # TODO: #2023 try spark.dynamicAllocation.initialExecutors
        num_proc = spark_context.defaultParallelism
        if verbose >= 1:
            logging.info(
                'Running %d processes (inferred from spark.default.parallelism)...',
                num_proc)
    else:
        if verbose >= 1:
            logging.info('Running %d processes...', num_proc)

    if min_num_proc is None:
        # TODO: #2023 try spark.dynamicAllocation.minExecutors
        min_num_proc = num_proc
    if max_num_proc is None:
        # TODO: #2023 try spark.dynamicAllocation.maxExecutors
        max_num_proc = num_proc

    # start Spark driver service and launch settings.num_proc Spark tasks
    key = secret.make_secret_key()
    spark_job_group = 'horovod.spark.run.%d' % job_id.next_job_id()
    driver = driver_service.SparkDriverService(num_proc, max_num_proc, fn,
                                               args, kwargs, key, nics)

    discovery = host_discovery.SparkDriverHostDiscovery(driver)

    tmout = timeout.Timeout(
        start_timeout,
        message='Timed out waiting for {activity}. Please check that you have '
        'enough resources to run all Horovod processes. Each Horovod '
        'process runs in a Spark task. You may need to increase the '
        'start_timeout parameter to a larger value if your Spark resources '
        'are allocated on-demand.')
    settings = hvd_elastic_settings.ElasticSettings(
        discovery=discovery,
        min_num_proc=min_num_proc,
        max_num_proc=max_num_proc,
        elastic_timeout=elastic_timeout,
        reset_limit=reset_limit,
        num_proc=num_proc,
        verbose=verbose,
        key=key,
        start_timeout=tmout,
        nics=nics,
        run_func_mode=True,
        prefix_output_with_timestamp=prefix_output_with_timestamp)

    result_queue = queue.Queue(1)

    # launch settings.num_proc / settings.max_num_proc Spark tasks
    spark_thread = _make_spark_thread(spark_context,
                                      spark_job_group,
                                      driver,
                                      result_queue,
                                      settings,
                                      use_gloo=True,
                                      is_elastic=True)
    try:
        # Register task addresses of initial num_proc tasks
        _register_task_addresses(driver, settings)

        # Run the job
        gloo_run_elastic(settings, driver, env, stdout, stderr)
    except BaseException:
        # Deliberately broad (also KeyboardInterrupt): always cancel the
        # Spark job group before re-raising.
        spark_context.cancelJobGroup(spark_job_group)
        raise
    finally:
        spark_thread.join()
        driver.shutdown()

    # Make sure Spark Job did not fail.
    driver.check_for_spark_job_failure()

    # get ranks from driver
    indices_in_rank_order = _get_indices_in_rank_order(driver)

    # If there's no exception, execution results are in this queue.
    results = result_queue.get_nowait()
    return [results[index] for index in indices_in_rank_order]
示例#23
0
def _run_static(args):
    """Launch a static (non-elastic) Horovod job from parsed ``horovodrun`` arguments.

    Runs the pre-launch checks: ssh reachability of remote hosts and
    discovery of common network interfaces. The results of these checks may
    be cached unless ``--disable-cache`` is given. It then either launches
    ``args.command`` or, when ``args.run_func`` is set, publishes the
    function through a local KV store server and collects one result per
    process from it.

    Args:
        args: Parsed command-line arguments. Fields read here are ``nics``,
            ``start_timeout``, ``verbose``, ``ssh_port``,
            ``ssh_identity_file``, ``mpi_args``, ``tcp_flag``,
            ``binding_args``, ``np``, ``hosts``, ``output_filename``,
            ``run_func``, ``disable_cache`` and ``command``.

    Returns:
        When ``args.run_func`` is set, a list of ``args.np`` values. Entry
        ``i`` is the value stored under key ``str(i)`` of the
        ``'runfunc_result'`` scope (presumably the result of rank ``i``).
        Otherwise ``None``.

    Raises:
        RuntimeError: If some remote hosts cannot be reached via ssh.
    """
    nics_set = set(args.nics.split(',')) if args.nics else None

    # horovodrun has to finish all the checks before this timeout runs out.
    if args.start_timeout:
        start_timeout = args.start_timeout
    else:
        # Lookup default timeout from the environment variable.
        start_timeout = int(os.getenv('HOROVOD_START_TIMEOUT', '30'))

    tmout = timeout.Timeout(start_timeout,
                            message='Timed out waiting for {activity}. Please '
                            'check connectivity between servers. You '
                            'may need to increase the --start-timeout '
                            'parameter if you have too many servers.')
    settings = hvd_settings.Settings(verbose=2 if args.verbose else 0,
                                     ssh_port=args.ssh_port,
                                     ssh_identity_file=args.ssh_identity_file,
                                     extra_mpi_args=args.mpi_args,
                                     tcp_flag=args.tcp_flag,
                                     binding_args=args.binding_args,
                                     key=secret.make_secret_key(),
                                     start_timeout=tmout,
                                     num_proc=args.np,
                                     hosts=args.hosts,
                                     output_filename=args.output_filename,
                                     run_func_mode=args.run_func is not None,
                                     nics=nics_set)

    # This cache stores the results of checks performed by horovod
    # during the initialization step. It can be disabled by setting
    # --disable-cache flag.
    fn_cache = None
    if not args.disable_cache:
        params = ''
        if args.np:
            params += str(args.np) + ' '
        if args.hosts:
            params += str(args.hosts) + ' '
        if args.ssh_port:
            # Keep a separator before the identity file. Otherwise, e.g.,
            # port 2 + file '2id' and port 22 + file 'id' would hash alike.
            params += str(args.ssh_port) + ' '
        if args.ssh_identity_file:
            params += args.ssh_identity_file
        parameters_hash = hashlib.md5(params.encode('utf-8')).hexdigest()
        fn_cache = cache.Cache(CACHE_FOLDER, CACHE_STALENESS_THRESHOLD_MINUTES,
                               parameters_hash)

    all_host_names, _ = hosts.parse_hosts_and_slots(args.hosts)
    if settings.verbose >= 2:
        print('Filtering local host names.')
    remote_host_names = network.filter_local_addresses(all_host_names)
    if settings.verbose >= 2:
        print('Remote host found: ' + ' '.join(remote_host_names))

    if len(remote_host_names) > 0:
        if settings.verbose >= 2:
            print('Checking ssh on all remote hosts.')
        # Check if we can ssh into all remote hosts successfully.
        if not _check_all_hosts_ssh_successful(remote_host_names,
                                               args.ssh_port,
                                               args.ssh_identity_file,
                                               fn_cache=fn_cache):
            raise RuntimeError('could not connect to some hosts via ssh')
        if settings.verbose >= 2:
            print('SSH was successful into all the remote hosts.')

    nics = driver_service.get_common_interfaces(settings, all_host_names,
                                                remote_host_names, fn_cache)

    if args.run_func:
        # get the driver IPv4 address
        driver_ip = network.get_driver_ip(nics)
        run_func_server = KVStoreServer(verbose=settings.verbose)
        run_func_server_port = run_func_server.start_server()
        try:
            # Publishing the function is inside the try block, so the server
            # is shut down even if this step fails.
            put_data_into_kvstore(driver_ip, run_func_server_port, 'runfunc',
                                  'func', args.run_func)

            command = [
                sys.executable, '-m', 'horovod.runner.run_task',
                str(driver_ip),
                str(run_func_server_port)
            ]

            _launch_job(args, settings, nics, command)
            # TODO: make it parallel to improve performance
            return [
                read_data_from_kvstore(driver_ip, run_func_server_port,
                                       'runfunc_result', str(i))
                for i in range(args.np)
            ]
        finally:
            run_func_server.shutdown_server()
    else:
        command = args.command
        _launch_job(args, settings, nics, command)
        return None
示例#24
0
def run(fn,
        args=(),
        kwargs=None,
        num_proc=None,
        start_timeout=None,
        use_mpi=None,
        use_gloo=None,
        extra_mpi_args=None,
        env=None,
        stdout=None,
        stderr=None,
        verbose=1,
        nics=None,
        prefix_output_with_timestamp=False,
        executable=None):
    """
    Runs Horovod on Spark.  Runs `num_proc` processes executing `fn` using the same amount of Spark tasks.

    Args:
        fn: Function to run.
        args: Arguments to pass to `fn`.
        kwargs: Keyword arguments to pass to `fn`. Defaults to no keyword arguments.
        num_proc: Number of Horovod processes.  Defaults to `spark.default.parallelism`.
        start_timeout: Timeout for Spark tasks to spawn, register and start running the code, in seconds.
                       If not set, falls back to `HOROVOD_SPARK_START_TIMEOUT` environment variable value.
                       If it is not set as well, defaults to 600 seconds.
        use_mpi: Controller selection flag. Forwarded to `is_gloo_used` and `_launch_job`.
        use_gloo: Controller selection flag. Forwarded to `is_gloo_used` and `_launch_job`.
        extra_mpi_args: Extra arguments for mpi_run. Defaults to no extra args.
        env: Environment dictionary to use in Horovod run.
        stdout: Horovod stdout is redirected to this stream. Defaults to sys.stdout when used with MPI.
        stderr: Horovod stderr is redirected to this stream. Defaults to sys.stderr when used with MPI.
        verbose: Debug output verbosity (0-2). Defaults to 1.
        nics: List of NICs for tcp network communication.
        prefix_output_with_timestamp: shows timestamp in stdout/stderr forwarding on the driver
        executable: Optional executable to run when launching the workers. Defaults to `sys.executable`.

    Returns:
        List of results returned by running `fn` on each rank.

    Raises:
        Exception: If no active SparkContext can be found.
        RuntimeError: If no registered host runs the Spark task with index 0.
    """
    # Avoid a shared mutable default; None means "no keyword arguments".
    if kwargs is None:
        kwargs = {}

    if start_timeout is None:
        # Lookup default timeout from the environment variable.
        start_timeout = int(os.getenv('HOROVOD_SPARK_START_TIMEOUT', '600'))

    # nics needs to be a set
    if nics and not isinstance(nics, set):
        nics = set(nics)

    tmout = timeout.Timeout(
        start_timeout,
        message='Timed out waiting for {activity}. Please check that you have '
        'enough resources to run all Horovod processes. Each Horovod '
        'process runs in a Spark task. You may need to increase the '
        'start_timeout parameter to a larger value if your Spark resources '
        'are allocated on-demand.')
    settings = hvd_settings.Settings(
        verbose=verbose,
        extra_mpi_args=extra_mpi_args,
        key=secret.make_secret_key(),
        start_timeout=tmout,
        nics=nics,
        run_func_mode=True,
        prefix_output_with_timestamp=prefix_output_with_timestamp)

    spark_context = pyspark.SparkContext._active_spark_context
    if spark_context is None:
        raise Exception('Could not find an active SparkContext, are you '
                        'running in a PySpark session?')

    if num_proc is None:
        num_proc = spark_context.defaultParallelism
        if settings.verbose >= 1:
            logging.info(
                'Running %d processes (inferred from spark.default.parallelism)...',
                num_proc)
    else:
        if settings.verbose >= 1:
            logging.info('Running %d processes...', num_proc)
    settings.num_proc = num_proc

    result_queue = queue.Queue(1)

    # start Spark driver service and launch settings.num_proc Spark tasks
    spark_job_group = 'horovod.spark.run.%d' % job_id.next_job_id()
    driver = driver_service.SparkDriverService(settings.num_proc,
                                               settings.num_proc, fn, args,
                                               kwargs, settings.key,
                                               settings.nics)
    gloo_is_used = is_gloo_used(use_gloo=use_gloo,
                                use_mpi=use_mpi,
                                use_jsrun=False)
    spark_thread = _make_spark_thread(spark_context,
                                      spark_job_group,
                                      driver,
                                      result_queue,
                                      settings,
                                      use_gloo=gloo_is_used,
                                      is_elastic=False)
    try:
        # wait for all tasks to register, notify them and initiate task-to-task address registration
        _notify_and_register_task_addresses(driver, settings)

        # Determine the index grouping based on host hashes.
        # Barrel shift until index 0 is in the first host.
        # Take one snapshot of the mapping so all lookups below agree.
        host_hash_indices = driver.task_host_hash_indices()
        host_hashes = sorted(host_hash_indices.keys())
        first = next((pos for pos, host_hash in enumerate(host_hashes)
                      if 0 in host_hash_indices[host_hash]), None)
        if first is None:
            # The previous rotate-until-found loop would never terminate here.
            raise RuntimeError('Could not find a host running the Spark task '
                               'with index 0')
        host_hashes = host_hashes[first:] + host_hashes[:first]

        settings.hosts = ','.join(
            '%s:%d' % (host_hash, len(host_hash_indices[host_hash]))
            for host_hash in host_hashes)

        # Run the job
        _launch_job(use_mpi, use_gloo, settings, driver, env, stdout, stderr,
                    executable)
    except BaseException:
        # Terminate Spark job.
        spark_context.cancelJobGroup(spark_job_group)

        # Re-raise exception.
        raise
    finally:
        spark_thread.join()
        driver.shutdown()

    # Make sure Spark Job did not fail.
    driver.check_for_spark_job_failure()

    # get ranks from driver
    indices_in_rank_order = _get_indices_in_rank_order(driver)

    # If there's no exception, execution results are in this queue.
    results = result_queue.get_nowait()
    return [results[index] for index in indices_in_rank_order]