Example #1
0
    def test_happy_run(self):
        """Run a two-worker Horovod job on a local Spark cluster and check
        that every worker sees both ranks in the allgather result."""
        from pyspark import SparkConf
        from pyspark.sql import SparkSession

        # local[2] gives two executor slots, so horovod.spark.run launches
        # two worker processes by default.
        spark_conf = SparkConf().setAppName("test_happy_run").setMaster("local[2]")
        spark = SparkSession.builder.config(conf=spark_conf).getOrCreate()

        def fn():
            hvd.init()
            gathered = hvd.allgather(torch.tensor([hvd.rank()])).tolist()
            return gathered, hvd.rank()

        try:
            results = horovod.spark.run(fn, env={'PATH': os.environ.get('PATH')})
            # One (gathered_ranks, own_rank) tuple per worker, ordered by rank.
            self.assertListEqual([([0, 1], 0), ([0, 1], 1)], results)
        finally:
            # Always tear down the Spark session, even on assertion failure.
            spark.stop()
Example #2
0
    def test_timeout(self):
        """Verify horovod.spark.run fails fast with a start timeout when more
        processes are requested than the cluster can ever schedule."""
        from pyspark import SparkConf
        from pyspark.sql import SparkSession
        # Fixed copy-paste bug: the app name previously read "test_happy_run".
        conf = SparkConf().setAppName("test_timeout").setMaster("local[2]")
        spark = SparkSession \
            .builder \
            .config(conf=conf) \
            .getOrCreate()

        try:
            # Asking for 4 processes on a 2-slot cluster can never succeed,
            # so the 5-second start_timeout must trip.
            horovod.spark.run(None, num_proc=4, start_timeout=5,
                              env={'PATH': os.environ.get('PATH')})
            self.fail("Timeout expected")
        except Exception as e:
            print('Caught exception:')
            traceback.print_exc()
            self.assertIn("Timed out waiting for Spark tasks to start", str(e))
        finally:
            # Always tear down the Spark session, even if assertions fail.
            spark.stop()
Example #3
0
    def test_mpirun_not_found(self):
        """Verify that a missing mpirun binary surfaces as a prompt failure
        rather than a hang."""
        from pyspark import SparkConf
        from pyspark.sql import SparkSession
        # Fixed copy-paste bug: the app name previously read "test_happy_run".
        conf = SparkConf().setAppName("test_mpirun_not_found").setMaster("local[2]")
        spark = SparkSession \
            .builder \
            .config(conf=conf) \
            .getOrCreate()

        start = time.time()
        try:
            # An empty-ish PATH means mpirun cannot be located; the launch
            # must fail and the error must propagate back quickly.
            horovod.spark.run(None, env={'PATH': '/nonexistent'})
            self.fail("Failure expected")
        except Exception as e:
            print('Caught exception:')
            traceback.print_exc()
            self.assertIn("mpirun exited with code", str(e))
            # Guard against the failure taking a long time to propagate.
            self.assertLessEqual(time.time() - start, 10, "Failure propagation took too long")
        finally:
            # Always tear down the Spark session, even if assertions fail.
            spark.stop()
Example #4
0
    conf = set_gpu_conf(conf)
    spark = SparkSession.builder.config(conf=conf).getOrCreate()

    # Horovod: run training.
    history, best_model_bytes = \
        horovod.spark.run(train_fn, args=(model_bytes,), num_proc=args.num_proc, verbose=2)[0]

    best_val_rmspe = min(history['val_exp_rmspe'])
    print('Best RMSPE: %f' % best_val_rmspe)

    # Write checkpoint.
    with open(args.local_checkpoint_file, 'wb') as f:
        f.write(best_model_bytes)
    print('Written checkpoint to %s' % args.local_checkpoint_file)

    spark.stop()

    # ================ #
    # FINAL PREDICTION #
    # ================ #

    print('================')
    print('Final prediction')
    print('================')

    # Create Spark session for prediction.
    conf = SparkConf().setAppName('prediction') \
        .setExecutorEnv('LD_LIBRARY_PATH', os.environ.get('LD_LIBRARY_PATH')) \
        .setExecutorEnv('PATH', os.environ.get('PATH'))

    if GPU_INFERENCE_ENABLED: