def test_fit_model(self):
    if sys.version_info < (3, 0, 0) and is_gloo_used():
        self.skipTest('Horovod on Spark over Gloo only supported on Python3')

    model = create_xor_model()
    optimizer = tf.keras.optimizers.SGD(lr=0.1)
    loss = 'binary_crossentropy'

    with spark_session('test_fit_model') as spark:
        df = create_xor_data(spark)

        with local_store() as store:
            keras_estimator = hvd.KerasEstimator(
                num_proc=2,
                store=store,
                model=model,
                optimizer=optimizer,
                loss=loss,
                feature_cols=['features'],
                label_cols=['y'],
                batch_size=1,
                epochs=3,
                verbose=2)

            keras_model = keras_estimator.fit(df)

            trained_model = keras_model.getModel()
            pred = trained_model.predict([np.ones([1, 2], dtype=np.float32)])
            assert len(pred) == 1
            assert pred.dtype == np.float32
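
# The test above relies on helpers such as create_xor_model() and
# create_xor_data() that are not shown here. The following is a minimal,
# hypothetical sketch of what such helpers could look like for the Keras
# estimator: a tiny network with a 2-feature input and a sigmoid output
# (matching binary_crossentropy), and a 4-row XOR DataFrame with a vector
# 'features' column and a float 'y' label. Names and shapes are assumptions,
# not the actual test utilities.

import tensorflow as tf
from pyspark.ml.linalg import DenseVector


def create_xor_model_sketch():
    # Small dense network that can fit XOR; input_shape=(2,) matches the
    # np.ones([1, 2]) prediction input used in the test.
    return tf.keras.models.Sequential([
        tf.keras.layers.Dense(8, activation='tanh', input_shape=(2,)),
        tf.keras.layers.Dense(1, activation='sigmoid'),
    ])


def create_xor_data_sketch(spark):
    # The four XOR rows as a Spark DataFrame; the estimator reads
    # feature_cols=['features'] and label_cols=['y'].
    data = [
        (DenseVector([0.0, 0.0]), 0.0),
        (DenseVector([0.0, 1.0]), 1.0),
        (DenseVector([1.0, 0.0]), 1.0),
        (DenseVector([1.0, 1.0]), 0.0),
    ]
    return spark.createDataFrame(data, ['features', 'y'])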
def test_fit_model(self):
    if sys.version_info < (3, 0, 0) and is_gloo_used():
        self.skipTest('Horovod on Spark over Gloo only supported on Python3')

    model = create_xor_model()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss = F.binary_cross_entropy

    with spark_session('test_fit_model') as spark:
        df = create_xor_data(spark)

        with local_store() as store:
            torch_estimator = hvd.TorchEstimator(
                num_proc=2,
                store=store,
                model=model,
                optimizer=optimizer,
                loss=loss,
                input_shapes=[[2]],
                feature_cols=['features'],
                label_cols=['y'],
                batch_size=1,
                epochs=3,
                verbose=2,
                sample_weight_col='weight')

            torch_model = torch_estimator.fit(df)

            trained_model = torch_model.getModel()
            pred = trained_model(torch.ones([1, 2], dtype=torch.int32))
            assert len(pred) == 1
            assert pred.dtype == torch.float32
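
# The PyTorch variant assumes create_xor_model() returns an nn.Module and that
# the training DataFrame also carries a 'weight' column (sample_weight_col).
# Below is a hypothetical sketch of such a model; the explicit .float() cast in
# forward() is what would allow the test to feed an int32 tensor into the
# trained model and still get a float32 prediction back.

import torch
import torch.nn as nn


class XORSketch(nn.Module):
    def __init__(self, input_dim=2, output_dim=1):
        super(XORSketch, self).__init__()
        self.lin1 = nn.Linear(input_dim, 8)
        self.lin2 = nn.Linear(8, output_dim)

    def forward(self, features):
        # Cast to float so integer inputs (as in the test) work as well.
        x = features.float()
        x = torch.tanh(self.lin1(x))
        x = torch.sigmoid(self.lin2(x))
        return x


def create_xor_model_sketch(input_dim=2, output_dim=1):
    return XORSketch(input_dim, output_dim)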
def test_fit_model_multiclass(self):
    if sys.version_info < (3, 0, 0) and is_gloo_used():
        self.skipTest('Horovod on Spark over Gloo only supported on Python3')

    model = create_mnist_model()
    optimizer = tf.keras.optimizers.Adadelta(1.0)
    loss = tf.keras.losses.categorical_crossentropy

    for num_cores in [2, constants.TOTAL_BUFFER_MEMORY_CAP_GIB + 1]:
        with spark_session('test_fit_model_multiclass', cores=num_cores) as spark:
            df = create_mnist_data(spark)

            with local_store() as store:
                keras_estimator = hvd.KerasEstimator(
                    num_proc=num_cores,
                    store=store,
                    model=model,
                    optimizer=optimizer,
                    loss=loss,
                    metrics=['accuracy'],
                    feature_cols=['features'],
                    label_cols=['label_vec'],
                    batch_size=2,
                    epochs=2,
                    verbose=2)

                keras_model = keras_estimator.fit(df).setOutputCols(['label_prob'])
                pred_df = keras_model.transform(df)

                argmax = udf(lambda v: float(np.argmax(v)), returnType=T.DoubleType())
                pred_df = pred_df.withColumn('label_pred', argmax(pred_df.label_prob))
                preds = pred_df.collect()
                assert len(preds) == df.count()

                row = preds[0]
                label_prob = row.label_prob.toArray().tolist()
                assert label_prob[int(row.label_pred)] == max(label_prob)
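
# create_mnist_model() and create_mnist_data() are likewise assumed helpers.
# A hypothetical Keras model compatible with the estimator arguments above
# would end in a softmax layer, so that categorical_crossentropy can be
# computed against a one-hot 'label_vec' column and the transformed
# 'label_prob' column holds a probability vector for the argmax check.
# The input and class dimensions here are illustrative assumptions.

import tensorflow as tf


def create_mnist_model_sketch(input_dim=784, num_classes=10):
    # Small dense classifier; softmax output produces the probability vector
    # that the test inspects after transform().
    return tf.keras.models.Sequential([
        tf.keras.layers.Dense(64, activation='relu', input_shape=(input_dim,)),
        tf.keras.layers.Dense(num_classes, activation='softmax'),
    ])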
def run(fn, args=(), kwargs={}, num_proc=None, start_timeout=None,
        use_mpi=None, use_gloo=None, extra_mpi_args=None, env=None,
        stdout=None, stderr=None, verbose=1, nics=None):
    """
    Runs Horovod in Spark.  Runs `num_proc` processes executing `fn` using the same number of Spark tasks.

    Args:
        fn: Function to run.
        args: Arguments to pass to `fn`.
        kwargs: Keyword arguments to pass to `fn`.
        num_proc: Number of Horovod processes.  Defaults to `spark.default.parallelism`.
        start_timeout: Timeout for Spark tasks to spawn, register and start running the code, in seconds.
                       If not set, falls back to `HOROVOD_SPARK_START_TIMEOUT` environment variable value.
                       If it is not set as well, defaults to 600 seconds.
        use_mpi: Run Horovod using the MPI controller.
        use_gloo: Run Horovod using the Gloo controller.
        extra_mpi_args: Extra arguments for mpi_run. Defaults to no extra args.
        env: Environment dictionary to use in Horovod run.
        stdout: Horovod stdout is redirected to this stream. Defaults to sys.stdout.
        stderr: Horovod stderr is redirected to this stream. Defaults to sys.stderr.
        verbose: Debug output verbosity (0-2). Defaults to 1.
        nics: List of NICs for tcp network communication.

    Returns:
        List of results returned by running `fn` on each rank.
    """

    if start_timeout is None:
        # Lookup default timeout from the environment variable.
        start_timeout = int(os.getenv('HOROVOD_SPARK_START_TIMEOUT', '600'))

    # nics needs to be a set
    if nics and not isinstance(nics, set):
        nics = set(nics)

    tmout = timeout.Timeout(start_timeout,
                            message='Timed out waiting for {activity}. Please check that you have '
                                    'enough resources to run all Horovod processes. Each Horovod '
                                    'process runs in a Spark task. You may need to increase the '
                                    'start_timeout parameter to a larger value if your Spark resources '
                                    'are allocated on-demand.')
    settings = hvd_settings.Settings(verbose=verbose,
                                     extra_mpi_args=extra_mpi_args,
                                     key=secret.make_secret_key(),
                                     timeout=tmout,
                                     nics=nics,
                                     run_func_mode=True)

    spark_context = pyspark.SparkContext._active_spark_context
    if spark_context is None:
        raise Exception('Could not find an active SparkContext, are you '
                        'running in a PySpark session?')

    if num_proc is None:
        num_proc = spark_context.defaultParallelism
        if settings.verbose >= 1:
            print('Running %d processes (inferred from spark.default.parallelism)...' % num_proc)
    else:
        if settings.verbose >= 1:
            print('Running %d processes...' % num_proc)
    settings.num_proc = num_proc

    result_queue = queue.Queue(1)

    # start Spark driver service and launch settings.num_proc Spark tasks
    spark_job_group = 'horovod.spark.run.%d' % job_id.next_job_id()
    driver = driver_service.SparkDriverService(settings.num_proc, fn, args, kwargs,
                                               settings.key, settings.nics)
    gloo_is_used = is_gloo_used(use_gloo=use_gloo, use_mpi=use_mpi, use_jsrun=False)
    spark_thread = _make_spark_thread(spark_context, spark_job_group, driver,
                                      result_queue, settings, gloo_is_used)
    try:
        # wait for all tasks to register, notify them and initiate task-to-task address registration
        _notify_and_register_task_addresses(driver, settings)

        # Determine the index grouping based on host hashes.
        # Barrel shift until index 0 is in the first host.
        host_hashes = list(driver.task_host_hash_indices().keys())
        host_hashes.sort()
        while 0 not in driver.task_host_hash_indices()[host_hashes[0]]:
            host_hashes = host_hashes[1:] + host_hashes[:1]

        settings.hosts = ','.join('%s:%d' % (host_hash, len(driver.task_host_hash_indices()[host_hash]))
                                  for host_hash in host_hashes)

        # Determine the ranks to indices
        ranks_to_indices = []
        for host_hash in host_hashes:
            ranks_to_indices += driver.task_host_hash_indices()[host_hash]
        driver.set_ranks_to_indices(ranks_to_indices)

        # Run the job
        _launch_job(use_mpi, use_gloo, settings, driver, env, stdout, stderr)
    except:
        # Terminate Spark job.
        spark_context.cancelJobGroup(spark_job_group)

        # Re-raise exception.
        raise
    finally:
        spark_thread.join()
        driver.shutdown()

    # Make sure Spark Job did not fail.
    driver.check_for_spark_job_failure()

    # If there's no exception, execution results are in this queue.
    results = result_queue.get_nowait()
    return [results[index] for index in ranks_to_indices]
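
# A minimal usage sketch of the run() function above. The training function
# here (train_fn) and its argument are hypothetical; run() only requires a
# picklable callable, executes it once per Horovod rank inside a Spark task,
# and returns the per-rank results ordered by rank. An active SparkContext
# (e.g. a PySpark session) must already exist, and with num_proc=None the
# process count falls back to spark.default.parallelism.

import horovod.spark


def train_fn(offset):
    # Each rank initializes Horovod and returns a value; run() collects the
    # return values into a list indexed by rank.
    import horovod.torch as hvd
    hvd.init()
    return hvd.rank() + offset


results = horovod.spark.run(train_fn, args=(10,), num_proc=2, verbose=2)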