Example #1
    def test_fit_model(self):
        if sys.version_info < (3, 0, 0) and is_gloo_used():
            self.skipTest(
                'Horovod on Spark over Gloo only supported on Python3')

        model = create_xor_model()
        optimizer = tf.keras.optimizers.SGD(lr=0.1)
        loss = 'binary_crossentropy'

        with spark_session('test_fit_model') as spark:
            df = create_xor_data(spark)

            with local_store() as store:
                keras_estimator = hvd.KerasEstimator(num_proc=2,
                                                     store=store,
                                                     model=model,
                                                     optimizer=optimizer,
                                                     loss=loss,
                                                     feature_cols=['features'],
                                                     label_cols=['y'],
                                                     batch_size=1,
                                                     epochs=3,
                                                     verbose=2)

                keras_model = keras_estimator.fit(df)

                trained_model = keras_model.getModel()
                pred = trained_model.predict(
                    [np.ones([1, 2], dtype=np.float32)])
                assert len(pred) == 1
                assert pred.dtype == np.float32
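
The test above relies on helpers from the surrounding test suite, `create_xor_model()` and `create_xor_data(spark)`. A minimal sketch of what they might look like, assuming a tiny XOR network and a four-row Spark DataFrame (layer sizes and column values are illustrative, not the suite's exact code):

import tensorflow as tf
from pyspark.ml.linalg import DenseVector


def create_xor_model():
    # Small feed-forward network for the 2-input XOR problem (assumed layout).
    return tf.keras.models.Sequential([
        tf.keras.layers.Dense(8, input_dim=2, activation='tanh'),
        tf.keras.layers.Dense(1, activation='sigmoid'),
    ])


def create_xor_data(spark):
    # XOR truth table; 'features' is a DenseVector column and 'y' the binary
    # label, matching the estimator's feature_cols/label_cols above.
    rows = [(DenseVector([0.0, 0.0]), 0.0),
            (DenseVector([0.0, 1.0]), 1.0),
            (DenseVector([1.0, 0.0]), 1.0),
            (DenseVector([1.0, 1.0]), 0.0)]
    return spark.createDataFrame(rows, ['features', 'y'])
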
Example #2
    def test_fit_model(self):
        if sys.version_info < (3, 0, 0) and is_gloo_used():
            self.skipTest(
                'Horovod on Spark over Gloo only supported on Python3')

        model = create_xor_model()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        loss = F.binary_cross_entropy

        with spark_session('test_fit_model') as spark:
            df = create_xor_data(spark)

            with local_store() as store:
                torch_estimator = hvd.TorchEstimator(
                    num_proc=2,
                    store=store,
                    model=model,
                    optimizer=optimizer,
                    loss=loss,
                    input_shapes=[[2]],
                    feature_cols=['features'],
                    label_cols=['y'],
                    batch_size=1,
                    epochs=3,
                    verbose=2,
                    sample_weight_col='weight')

                torch_model = torch_estimator.fit(df)

                trained_model = torch_model.getModel()
                pred = trained_model(torch.ones([1, 2], dtype=torch.int32))
                assert len(pred) == 1
                assert pred.dtype == torch.float32
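
The PyTorch variant assumes analogous helpers; `create_xor_data(spark)` must also provide the `weight` column referenced by `sample_weight_col='weight'`, and the model presumably casts its input to float so the int32 prediction above works. A hedged sketch under those assumptions:

import torch
import torch.nn as nn
from pyspark.ml.linalg import DenseVector


class XOR(nn.Module):
    # Small MLP for XOR; input_shapes=[[2]] in the estimator matches input_dim=2.
    def __init__(self, input_dim=2, output_dim=1):
        super(XOR, self).__init__()
        self.lin1 = nn.Linear(input_dim, 8)
        self.lin2 = nn.Linear(8, output_dim)

    def forward(self, features):
        # Cast to float so integer tensors (e.g. torch.int32 above) are accepted.
        x = features.float()
        x = torch.tanh(self.lin1(x))
        return torch.sigmoid(self.lin2(x))


def create_xor_model():
    return XOR()


def create_xor_data(spark):
    # XOR truth table plus a per-row sample weight column.
    rows = [(DenseVector([0.0, 0.0]), 0.0, 1.0),
            (DenseVector([0.0, 1.0]), 1.0, 1.0),
            (DenseVector([1.0, 0.0]), 1.0, 1.0),
            (DenseVector([1.0, 1.0]), 0.0, 1.0)]
    return spark.createDataFrame(rows, ['features', 'y', 'weight'])
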
Example #3
    def test_fit_model_multiclass(self):
        if sys.version_info < (3, 0, 0) and is_gloo_used():
            self.skipTest(
                'Horovod on Spark over Gloo only supported on Python3')

        model = create_mnist_model()
        optimizer = tf.keras.optimizers.Adadelta(1.0)
        loss = tf.keras.losses.categorical_crossentropy

        for num_cores in [2, constants.TOTAL_BUFFER_MEMORY_CAP_GIB + 1]:
            with spark_session('test_fit_model_multiclass',
                               cores=num_cores) as spark:
                df = create_mnist_data(spark)

                with local_store() as store:
                    keras_estimator = hvd.KerasEstimator(
                        num_proc=num_cores,
                        store=store,
                        model=model,
                        optimizer=optimizer,
                        loss=loss,
                        metrics=['accuracy'],
                        feature_cols=['features'],
                        label_cols=['label_vec'],
                        batch_size=2,
                        epochs=2,
                        verbose=2)

                    keras_model = keras_estimator.fit(df).setOutputCols(
                        ['label_prob'])
                    pred_df = keras_model.transform(df)

                    argmax = udf(lambda v: float(np.argmax(v)),
                                 returnType=T.DoubleType())
                    pred_df = pred_df.withColumn('label_pred',
                                                 argmax(pred_df.label_prob))

                    preds = pred_df.collect()
                    assert len(preds) == df.count()

                    row = preds[0]
                    label_prob = row.label_prob.toArray().tolist()
                    assert label_prob[int(row.label_pred)] == max(label_prob)
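
The multiclass test assumes `create_mnist_model()` and `create_mnist_data(spark)` helpers producing a ten-class classifier and a DataFrame with `features` and one-hot `label_vec` columns. A plausible sketch of the model half only, assuming flattened 28x28 single-channel inputs (layer sizes are illustrative):

import tensorflow as tf


def create_mnist_model():
    # Small convolutional classifier ending in a 10-way softmax, matching the
    # categorical_crossentropy loss and the 'label_prob' output column above.
    return tf.keras.models.Sequential([
        tf.keras.layers.Reshape((28, 28, 1), input_shape=(784,)),
        tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu'),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax'),
    ])
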
Example #4
def run(fn,
        args=(),
        kwargs={},
        num_proc=None,
        start_timeout=None,
        use_mpi=None,
        use_gloo=None,
        extra_mpi_args=None,
        env=None,
        stdout=None,
        stderr=None,
        verbose=1,
        nics=None):
    """
    Runs Horovod in Spark.  Runs `num_proc` processes executing `fn` using the same number of Spark tasks.

    Args:
        fn: Function to run.
        args: Arguments to pass to `fn`.
        kwargs: Keyword arguments to pass to `fn`.
        num_proc: Number of Horovod processes.  Defaults to `spark.default.parallelism`.
        start_timeout: Timeout for Spark tasks to spawn, register and start running the code, in seconds.
                       If not set, falls back to `HOROVOD_SPARK_START_TIMEOUT` environment variable value.
                       If it is not set as well, defaults to 600 seconds.
        use_mpi: Run Horovod using the MPI controller. This will be the default
                 if Horovod was built with MPI support.
        use_gloo: Run Horovod using the Gloo controller. This will be the default
                  if Horovod was not built with MPI support.
        extra_mpi_args: Extra arguments for mpi_run. Defaults to no extra args.
        env: Environment dictionary to use in Horovod run.
        stdout: Horovod stdout is redirected to this stream. Defaults to sys.stdout.
        stderr: Horovod stderr is redirected to this stream. Defaults to sys.stderr.
        verbose: Debug output verbosity (0-2). Defaults to 1.
        nics: List of NICs for tcp network communication.

    Returns:
        List of results returned by running `fn` on each rank.
    """

    if start_timeout is None:
        # Lookup default timeout from the environment variable.
        start_timeout = int(os.getenv('HOROVOD_SPARK_START_TIMEOUT', '600'))

    # nics needs to be a set
    if nics and not isinstance(nics, set):
        nics = set(nics)

    tmout = timeout.Timeout(
        start_timeout,
        message='Timed out waiting for {activity}. Please check that you have '
        'enough resources to run all Horovod processes. Each Horovod '
        'process runs in a Spark task. You may need to increase the '
        'start_timeout parameter to a larger value if your Spark resources '
        'are allocated on-demand.')
    settings = hvd_settings.Settings(verbose=verbose,
                                     extra_mpi_args=extra_mpi_args,
                                     key=secret.make_secret_key(),
                                     timeout=tmout,
                                     nics=nics,
                                     run_func_mode=True)

    spark_context = pyspark.SparkContext._active_spark_context
    if spark_context is None:
        raise Exception('Could not find an active SparkContext, are you '
                        'running in a PySpark session?')

    if num_proc is None:
        num_proc = spark_context.defaultParallelism
        if settings.verbose >= 1:
            print(
                'Running %d processes (inferred from spark.default.parallelism)...'
                % num_proc)
    else:
        if settings.verbose >= 1:
            print('Running %d processes...' % num_proc)
    settings.num_proc = num_proc

    result_queue = queue.Queue(1)

    # start Spark driver service and launch settings.num_proc Spark tasks
    spark_job_group = 'horovod.spark.run.%d' % job_id.next_job_id()
    driver = driver_service.SparkDriverService(settings.num_proc, fn, args,
                                               kwargs, settings.key,
                                               settings.nics)
    gloo_is_used = is_gloo_used(use_gloo=use_gloo,
                                use_mpi=use_mpi,
                                use_jsrun=False)
    spark_thread = _make_spark_thread(spark_context, spark_job_group, driver,
                                      result_queue, settings, gloo_is_used)
    try:
        # wait for all tasks to register, notify them and initiate task-to-task address registration
        _notify_and_register_task_addresses(driver, settings)

        # Determine the index grouping based on host hashes.
        # Barrel shift until index 0 is in the first host.
        host_hashes = list(driver.task_host_hash_indices().keys())
        host_hashes.sort()
        while 0 not in driver.task_host_hash_indices()[host_hashes[0]]:
            host_hashes = host_hashes[1:] + host_hashes[:1]

        settings.hosts = ','.join(
            '%s:%d' %
            (host_hash, len(driver.task_host_hash_indices()[host_hash]))
            for host_hash in host_hashes)

        # Determine the mapping of ranks to task indices.
        ranks_to_indices = []
        for host_hash in host_hashes:
            ranks_to_indices += driver.task_host_hash_indices()[host_hash]
        driver.set_ranks_to_indices(ranks_to_indices)

        # Run the job
        _launch_job(use_mpi, use_gloo, settings, driver, env, stdout, stderr)
    except:
        # Terminate Spark job.
        spark_context.cancelJobGroup(spark_job_group)

        # Re-raise exception.
        raise
    finally:
        spark_thread.join()
        driver.shutdown()

    # Make sure Spark Job did not fail.
    driver.check_for_spark_job_failure()

    # If there's no exception, execution results are in this queue.
    results = result_queue.get_nowait()
    return [results[index] for index in ranks_to_indices]
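
A minimal usage sketch of the `run` function above, launching a user-defined training function on two Spark tasks and collecting each rank's return value (the body of `train` is illustrative only):

import horovod.spark
import horovod.tensorflow.keras as hvd


def train(learning_rate):
    # Each Spark task executes this function as one Horovod rank.
    hvd.init()
    # ... build a model, scale learning_rate by hvd.size(), fit it ...
    return hvd.rank()


# Returns one result per rank, ordered by rank, e.g. [0, 1].
results = horovod.spark.run(train, args=(0.01,), num_proc=2, verbose=2)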