Example #1
    def test_host_hash(self):
        hash = host_hash()
        salted = host_hash('salt')
        empty_salted = host_hash('')

        self.assertNotEqual(salted, hash)
        self.assertEqual(empty_salted, hash)
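
This test pins down the salt semantics: a non-empty salt must change the hash, while an empty salt must behave like no salt at all. A minimal sketch of a compatible implementation, assuming a hostname-based MD5 scheme (the real helper lives in horovod.run.common.util.host_hash; this body is an illustration, not the actual source):

import hashlib
import socket


def host_hash(salt=None):
    # falsy salts ('' or None) leave the hash unchanged, matching the test above
    hostname = socket.gethostname()
    if salt:
        hostname = '{}-{}'.format(hostname, salt)
    return hashlib.md5(hostname.encode('utf-8')).hexdigest()
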
Example #2
def _task_fn(index, driver_addresses, settings):
    task = task_service.SparkTaskService(index, settings.key)
    try:
        driver_client = driver_service.SparkDriverClient(driver_addresses, settings.key, settings.verbose)
        driver_client.register_task(index, task.addresses(), host_hash.host_hash())
        task.wait_for_initial_registration(settings.timeout)
        # Tasks ping each other in a circular fashion to determine interfaces reachable within
        # the cluster.
        next_task_index = (index + 1) % settings.num_proc
        next_task_addresses = driver_client.all_task_addresses(next_task_index)
        # We request interface matching to weed out all the NAT'ed interfaces.
        next_task_client = \
            task_service.SparkTaskClient(next_task_index, next_task_addresses,
                                         settings.key, settings.verbose,
                                         match_intf=True)
        driver_client.register_task_to_task_addresses(next_task_index, next_task_client.addresses())
        task_indices_on_this_host = driver_client.task_host_hash_indices(
            host_hash.host_hash())
        if task_indices_on_this_host[0] == index:
            # The task with the first index will execute orted, which will run mpirun_exec_fn for all tasks.
            task.wait_for_command_start(settings.timeout)
            task.wait_for_command_termination()
        else:
            # The rest of the tasks need to wait for the first task to finish.
            first_task_addresses = driver_client.all_task_addresses(task_indices_on_this_host[0])
            first_task_client = \
                task_service.SparkTaskClient(task_indices_on_this_host[0],
                                             first_task_addresses, settings.key,
                                             settings.verbose)
            first_task_client.wait_for_command_termination()
        return task.fn_result()
    finally:
        task.shutdown()
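
In horovod.spark, a task function like this is fanned out over the Spark executors with one partition per task index. A hedged sketch of such a dispatch, assuming an RDD-based mapper (_make_mapper and _launch_tasks are illustrative names, not Horovod's actual launch code):

def _make_mapper(driver_addresses, settings):
    def _mapper(index, _):
        # each partition runs exactly one task; the partition index doubles as the task index
        yield _task_fn(index, driver_addresses, settings)
    return _mapper


def _launch_tasks(spark_context, driver_addresses, settings):
    return spark_context.range(settings.num_proc, numSlices=settings.num_proc) \
        .mapPartitionsWithIndex(_make_mapper(driver_addresses, settings)) \
        .collect()
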
Example #3
def _task_fn(index, driver_addresses, key, settings, use_gloo):
    # Settings are deserialized on the Spark workers without the key, so it is passed here explicitly.
    # Spark RPC communicates the key and supports encryption;
    # for convenience, we put it back into settings.
    settings.key = key

    task = task_service.SparkTaskService(index, settings.key, settings.nics,
                                         settings.verbose)
    try:
        driver_client = driver_service.SparkDriverClient(
            driver_addresses, settings.key, settings.verbose)
        driver_client.register_task(index, task.addresses(),
                                    host_hash.host_hash())
        task.wait_for_initial_registration(settings.timeout)
        task_indices_on_this_host = driver_client.task_host_hash_indices(
            host_hash.host_hash())

        # With Gloo, all tasks wait for the command.
        # With MPI, the task with the first index executes orted, which will run mpirun_exec_fn for all tasks.
        if use_gloo or task_indices_on_this_host[0] == index:
            task.wait_for_command_start(settings.timeout)
            task.wait_for_command_termination()
        else:
            # The rest of the tasks need to wait for the first task to finish.
            first_task_addresses = driver_client.all_task_addresses(
                task_indices_on_this_host[0])
            first_task_client = \
                task_service.SparkTaskClient(task_indices_on_this_host[0],
                                             first_task_addresses, settings.key,
                                             settings.verbose)
            first_task_client.wait_for_command_termination()
        return task.fn_result()
    finally:
        task.shutdown()
Example #4
def _task_fn(index, num_hosts, driver_addresses, settings):
    task = task_service.HorovodRunTaskService(index, settings.key,
                                              settings.nics)
    try:
        driver = driver_service.HorovodRunDriverClient(driver_addresses,
                                                       settings.key,
                                                       settings.verbose)
        driver.register_task(index, task.addresses(), host_hash.host_hash())
        task.wait_for_initial_registration(settings.start_timeout)
        # Tasks ping each other in a circular fashion to determine interfaces
        # reachable within the cluster.
        next_task_index = (index + 1) % num_hosts
        next_task_addresses = driver.all_task_addresses(next_task_index)
        # We request interface matching to weed out all the NAT'ed interfaces.
        next_task = task_service.HorovodRunTaskClient(next_task_index,
                                                      next_task_addresses,
                                                      settings.key,
                                                      settings.verbose,
                                                      match_intf=True,
                                                      attempts=10)
        driver.register_task_to_task_addresses(next_task_index,
                                               next_task.addresses())
        # Notify the next task that the address checks are completed.
        next_task.task_to_task_address_check_completed()
        # Wait to get a notification from previous task that its address checks
        # are completed as well.
        task.wait_for_task_to_task_address_check_finish_signal(
            settings.start_timeout)

    finally:
        task.shutdown()
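
The (index + 1) % num_hosts arithmetic closes the probe chain into a ring: every task probes exactly one successor and is probed by exactly one predecessor, so all interfaces get checked without a central coordinator. For example:

num_hosts = 4
ring = [(index, (index + 1) % num_hosts) for index in range(num_hosts)]
print(ring)  # [(0, 1), (1, 2), (2, 3), (3, 0)]
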
Example #5
def _task_fn(index, driver_addresses, key, settings, use_gloo):
    # Settings are deserialized on the Spark workers without the key, so it is passed here explicitly.
    # Spark RPC communicates the key and supports encryption;
    # for convenience, we put it back into settings.
    settings.key = key

    task = task_service.SparkTaskService(index, settings.key, settings.nics,
                                         settings.verbose)
    try:
        driver_client = driver_service.SparkDriverClient(
            driver_addresses, settings.key, settings.verbose)
        driver_client.register_task(index, task.addresses(),
                                    host_hash.host_hash())
        task.wait_for_initial_registration(settings.timeout)
        task_indices_on_this_host = driver_client.task_host_hash_indices(
            host_hash.host_hash())

        # With Gloo, all tasks wait for the command.
        # With MPI, the task with the first index executes orted, which will run mpirun_exec_fn for all tasks.
        minimum_lifetime_after_start = None
        if use_gloo or task_indices_on_this_host[0] == index:
            task.wait_for_command_start(settings.timeout)
            minimum_lifetime_after_start = timeout.Timeout(
                MINIMUM_COMMAND_LIFETIME_S, message='Just measuring runtime')
            task.wait_for_command_termination()
        else:
            # The rest of the tasks need to wait for the first task to finish.
            first_task_addresses = driver_client.all_task_addresses(
                task_indices_on_this_host[0])
            first_task_client = \
                task_service.SparkTaskClient(task_indices_on_this_host[0],
                                             first_task_addresses, settings.key,
                                             settings.verbose)
            first_task_client.wait_for_command_termination()

        # The command has terminated. Make sure this task service does not shut down too quickly
        # after the client started the command, as the client needs some time to connect again
        # to wait for the result after starting the command (see horovod.spark.driver.rsh).
        if minimum_lifetime_after_start is not None:
            time.sleep(minimum_lifetime_after_start.remaining())

        return task.fn_result()
    finally:
        # This has to block on running requests (wait_for_command_exit_code)
        # so they can finish serving the exit code.
        # Shutdown does block, as network.BasicService._server._block_on_close = True.
        task.shutdown()
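
The minimum-lifetime guard needs a timeout object that remembers its deadline and reports the seconds left. A minimal sketch of such a helper with the interface used above (remaining() and the constructor arguments follow the call sites; this body is an assumption, not horovod's actual timeout utility):

import time


class Timeout(object):
    def __init__(self, timeout, message):
        # deadline is fixed at construction time; message describes the waiter
        self._deadline = time.time() + timeout
        self._message = message

    def remaining(self):
        # seconds left until the deadline, never negative
        return max(0.0, self._deadline - time.time())
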
Example #6
import os

from horovod.run.common.util import host_hash as hh


def host_hash(salt=None):
    """
    Computes this host's host hash by invoking horovod.run.common.util.host_hash.host_hash.

    Considers the environment variable CONTAINER_ID, which is present when running Spark via YARN.
    A YARN container does not share memory with other containers on the same host,
    so it must be considered a `host` in the sense of the `host_hash`.

    :param salt: extra information to include in the hash; falsy values are ignored
    :return: host hash
    """
    # turn the salt into a list containing a single string, if given
    salt = [str(salt)] if salt else []

    # We would violate the resource allocation if we ran all tasks of a host in one container.
    # See [issue 1497](https://github.com/horovod/horovod/issues/1497) for details.
    container = os.environ.get("CONTAINER_ID")
    if container is not None:
        salt.append(container)

    return hh.host_hash(salt='-'.join(salt))
Example #7
    def test_host_hash(self):
        hash = host_hash()
        # host_hash should consider the CONTAINER_ID environment variable
        with override_env({'CONTAINER_ID': 'a container id'}):
            self.assertNotEqual(host_hash(), hash)
        self.assertEqual(host_hash(), hash)
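
override_env is a test helper that temporarily sets environment variables and restores the previous state on exit. A hedged sketch matching the usage above (the name and context-manager shape come from the test; this implementation is an assumption):

import os
from contextlib import contextmanager


@contextmanager
def override_env(env):
    # remember prior values so the environment can be restored afterwards
    previous = {key: os.environ.get(key) for key in env}
    os.environ.update(env)
    try:
        yield
    finally:
        for key, value in previous.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value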