Example #1
0
    def test_restore_from_checkpoint(self):
        """Fit twice with the same run id and verify the checkpoint is read.

        ``_read_checkpoint`` is replaced by a ``Mock`` whose side effect is the
        original method, so its calls can be observed while it still runs. The
        checkpoint path must not exist before the first fit, must exist after
        it, and the reader must have been called at least once by the time the
        second fit has finished.
        """
        xor_model = create_xor_model()

        with spark_session('test_restore_from_checkpoint') as spark:
            train_df = create_noisy_xor_data(spark)
            backend = CallbackBackend()
            checkpoint_run = 'run01'

            with local_store() as store:
                estimator_kwargs = dict(
                    backend=backend,
                    store=store,
                    model=xor_model,
                    input_shapes=[[-1, 2]],
                    feature_cols=['features'],
                    label_cols=['y'],
                    validation=0.2,
                    batch_size=4,
                    epochs=2,
                    verbose=2,
                    run_id=checkpoint_run,
                )
                estimator = hvd_spark.TorchEstimator(**estimator_kwargs)

                # Spy on the checkpoint reader without changing what it does.
                read_spy = Mock(side_effect=estimator._read_checkpoint)
                estimator._read_checkpoint = read_spy

                checkpoint_path = store.get_checkpoint_path(checkpoint_run)

                # First fit: the store holds no checkpoint for this run yet.
                assert not store.exists(checkpoint_path)
                read_spy.assert_not_called()
                estimator.fit(train_df)

                # Second fit: the checkpoint written above is now in the store.
                assert store.exists(checkpoint_path)
                estimator.fit(train_df)
                read_spy.assert_called()
Example #2
0
    def test_model_override_trainer_args(self):
        """Fit with extra ``trainer_args`` and check the trained model's output.

        The estimator receives ``trainer_args={'stochastic_weight_avg': True}``.
        After fitting, the returned model must produce exactly one float32
        prediction for a single-row input.
        """
        if skip_lightning_tests:
            self.skipTest(
                'Spark PyTorch Lightning tests conflict with Tensorflow 2.5.x: '
                'https://github.com/horovod/horovod/pull/3263')

        with spark_session('test_fit_model') as spark:
            df = create_noisy_xor_data(spark)
            model = create_xor_model()

            with local_store() as store:
                torch_estimator = hvd_spark.TorchEstimator(
                    num_proc=2,
                    store=store,
                    model=model,
                    input_shapes=[[-1, 2]],
                    feature_cols=['features'],
                    label_cols=['y'],
                    validation=0.2,
                    batch_size=4,
                    epochs=2,
                    verbose=2,
                    trainer_args={'stochastic_weight_avg': True})

                torch_model = torch_estimator.fit(df)

                # TODO: Find a way to pass log metrics from remote, and assert based on the logger.
                trained_model = torch_model.getModel()
                pred = trained_model(torch.ones([1, 2], dtype=torch.int32))
                assert len(pred) == 1
                assert pred.dtype == torch.float32
Example #3
0
    def test_legacy_fit_model(self):
        """Fit the model from ``create_legacy_xor_model`` with explicit optimizer and loss.

        An SGD optimizer and ``F.binary_cross_entropy`` are handed to the
        estimator together with ``sample_weight_col='weight'``. The trained
        model must return exactly one float32 prediction for a one-row input.
        """
        if skip_lightning_tests:
            self.skipTest(
                'Spark PyTorch Lightning tests conflict with Tensorflow 2.5.x: '
                'https://github.com/horovod/horovod/pull/3263')

        legacy_model = create_legacy_xor_model()
        sgd = torch.optim.SGD(legacy_model.parameters(), lr=0.1)
        bce_loss = F.binary_cross_entropy

        with spark_session('test_fit_model') as spark:
            train_df = create_noisy_xor_data(spark)

            with local_store() as store:
                estimator_kwargs = {
                    'num_proc': 2,
                    'store': store,
                    'model': legacy_model,
                    'optimizer': sgd,
                    'loss': bce_loss,
                    'input_shapes': [[-1, 2]],
                    'feature_cols': ['features'],
                    'label_cols': ['y'],
                    'batch_size': 4,
                    'epochs': 2,
                    'verbose': 2,
                    'sample_weight_col': 'weight',
                }
                estimator = hvd_spark.TorchEstimator(**estimator_kwargs)
                fitted = estimator.fit(train_df)

                # Run the trained network on a single all-ones row.
                net = fitted.getModel()
                probe = torch.ones([1, 2], dtype=torch.int32)
                output = net(probe)
                assert len(output) == 1
                assert output.dtype == torch.float32
Example #4
0
    def test_direct_parquet_train(self):
        """Train via ``fit_on_parquet`` on data written by ``util.prepare_data``.

        The store's train/val path getters are overridden to return the
        store's ``_train_path`` / ``_val_path`` attributes. The estimator is
        then fitted once with ``inmemory_cache_all`` off and once with it on;
        each resulting transformer must yield as many rows as the input.
        """
        with spark_session('test_direct_parquet_train') as spark:
            source_df = create_noisy_xor_data(spark)
            callback_backend = CallbackBackend()

            with local_store() as store:
                # Point both data-path getters at the store's own attributes.
                store.get_train_data_path = lambda v=None: store._train_path
                store.get_val_data_path = lambda v=None: store._val_path

                prepared = util.prepare_data(callback_backend.num_processes(),
                                             store,
                                             source_df,
                                             feature_columns=['features'],
                                             label_columns=['y'],
                                             validation=0.2)
                with prepared:
                    xor_model = create_xor_model()

                    for inmemory_cache_all in (False, True):
                        estimator_kwargs = dict(
                            backend=callback_backend,
                            store=store,
                            model=xor_model,
                            input_shapes=[[-1, 2]],
                            feature_cols=['features'],
                            label_cols=['y'],
                            validation=0.2,
                            batch_size=1,
                            epochs=3,
                            verbose=2,
                            inmemory_cache_all=inmemory_cache_all,
                        )
                        estimator = hvd_spark.TorchEstimator(**estimator_kwargs)

                        fitted = estimator.fit_on_parquet()
                        scored = fitted.transform(source_df)
                        assert scored.count() == source_df.count()
Example #5
0
    def test_early_stop_callback(self):
        """Fit with an ``EarlyStopping`` callback and check the trained model's output.

        The callback monitors ``val_loss``. A loss improves by going down, so
        the callback runs in ``mode='min'``; ``mode='max'`` would treat a
        rising loss as progress. After fitting, the returned model must
        produce exactly one float32 prediction for a single-row input.
        """
        from pytorch_lightning.callbacks.early_stopping import EarlyStopping

        with spark_session('test_fit_model') as spark:
            df = create_noisy_xor_data(spark)
            model = create_xor_model()

            early_stop_callback = EarlyStopping(monitor='val_loss',
                                                min_delta=0.00,
                                                patience=3,
                                                verbose=True,
                                                mode='min')
            callbacks = [early_stop_callback]

            with local_store() as store:
                torch_estimator = hvd_spark.TorchEstimator(
                    num_proc=2,
                    store=store,
                    model=model,
                    input_shapes=[[-1, 2]],
                    feature_cols=['features'],
                    label_cols=['y'],
                    validation=0.2,
                    batch_size=4,
                    epochs=2,
                    verbose=2,
                    callbacks=callbacks)

                torch_model = torch_estimator.fit(df)

                # TODO: Find a way to pass log metrics from remote, and assert based on the logger.
                trained_model = torch_model.getModel()
                pred = trained_model(torch.ones([1, 2], dtype=torch.int32))
                assert len(pred) == 1
                assert pred.dtype == torch.float32
Example #6
0
    def test_train_with_inmemory_cache_all(self):
        """Fit with ``inmemory_cache_all=True`` on a single process.

        The trained model must return exactly one float32 prediction for a
        one-row input.
        """
        if skip_lightning_tests:
            self.skipTest(
                'Spark PyTorch Lightning tests conflict with Tensorflow 2.5.x: '
                'https://github.com/horovod/horovod/pull/3263')

        with spark_session('test_fit_model') as spark:
            train_df = create_noisy_xor_data(spark)
            xor_model = create_xor_model()

            with local_store() as store:
                # Normally inmem dataloader is for single worker training with small data
                single_worker = 1
                estimator_kwargs = dict(
                    num_proc=single_worker,
                    store=store,
                    model=xor_model,
                    input_shapes=[[-1, 2]],
                    feature_cols=['features'],
                    label_cols=['y'],
                    validation=0.2,
                    batch_size=4,
                    epochs=2,
                    verbose=2,
                    inmemory_cache_all=True,
                )
                estimator = hvd_spark.TorchEstimator(**estimator_kwargs)
                fitted = estimator.fit(train_df)

                # TODO: Find a way to pass log metrics from remote, and assert based on the logger.
                net = fitted.getModel()
                probe = torch.ones([1, 2], dtype=torch.int32)
                output = net(probe)
                assert len(output) == 1
                assert output.dtype == torch.float32
Example #7
0
    def test_legacy_fit_model(self):
        """Fit the model from ``create_legacy_xor_model`` with explicit optimizer and loss.

        The estimator gets an SGD optimizer, ``F.binary_cross_entropy`` as the
        loss and ``sample_weight_col='weight'``. The trained model must return
        exactly one float32 prediction for a one-row input.
        """
        legacy_model = create_legacy_xor_model()
        sgd = torch.optim.SGD(legacy_model.parameters(), lr=0.1)
        bce_loss = F.binary_cross_entropy

        with spark_session('test_fit_model') as spark:
            train_df = create_noisy_xor_data(spark)

            with local_store() as store:
                estimator_kwargs = dict(
                    num_proc=2,
                    store=store,
                    model=legacy_model,
                    optimizer=sgd,
                    loss=bce_loss,
                    input_shapes=[[-1, 2]],
                    feature_cols=['features'],
                    label_cols=['y'],
                    batch_size=4,
                    epochs=2,
                    verbose=2,
                    sample_weight_col='weight',
                )
                estimator = hvd_spark.TorchEstimator(**estimator_kwargs)
                fitted = estimator.fit(train_df)

                # Run the trained network on a single all-ones row.
                net = fitted.getModel()
                probe = torch.ones([1, 2], dtype=torch.int32)
                output = net(probe)
                assert len(output) == 1
                assert output.dtype == torch.float32
Example #8
0
    def test_model_checkpoint_callback(self):
        """Fit with a ``ModelCheckpoint`` callback that targets a temporary directory.

        The trained model must return exactly one float32 prediction for a
        one-row input.
        """
        from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

        with spark_session('test_fit_model') as spark:
            train_df = create_noisy_xor_data(spark)
            xor_model = create_xor_model()

            with tempdir() as checkpoint_dir:
                callbacks = [ModelCheckpoint(dirpath=checkpoint_dir)]

                with local_store() as store:
                    estimator_kwargs = dict(
                        num_proc=2,
                        store=store,
                        model=xor_model,
                        input_shapes=[[-1, 2]],
                        feature_cols=['features'],
                        label_cols=['y'],
                        validation=0.2,
                        batch_size=4,
                        epochs=2,
                        verbose=2,
                        callbacks=callbacks,
                    )
                    estimator = hvd_spark.TorchEstimator(**estimator_kwargs)
                    fitted = estimator.fit(train_df)

                    # TODO: Find a way to pass log metrics from remote, and assert based on the logger.
                    net = fitted.getModel()
                    probe = torch.ones([1, 2], dtype=torch.int32)
                    output = net(probe)
                    assert len(output) == 1
                    assert output.dtype == torch.float32
Example #9
0
    def test_fit_model(self):
        """Fit the XOR model on two processes and check the trained model's output.

        The trained model must return exactly one float32 prediction for a
        one-row input.
        """
        xor_model = create_xor_model()

        with spark_session('test_fit_model') as spark:
            train_df = create_noisy_xor_data(spark)

            with local_store() as store:
                estimator_kwargs = {
                    'num_proc': 2,
                    'store': store,
                    'model': xor_model,
                    'input_shapes': [[-1, 2]],
                    'feature_cols': ['features'],
                    'label_cols': ['y'],
                    'validation': 0.2,
                    'batch_size': 4,
                    'epochs': 2,
                    'verbose': 2,
                }
                estimator = hvd_spark.TorchEstimator(**estimator_kwargs)
                fitted = estimator.fit(train_df)

                # Run the trained network on a single all-ones row.
                net = fitted.getModel()
                probe = torch.ones([1, 2], dtype=torch.int32)
                output = net(probe)
                assert len(output) == 1
                assert output.dtype == torch.float32
Example #10
0
    def test_train_with_pytorch_infinite_async_data_loader(self):
        """Fit with ``PytorchInfiniteAsyncDataLoader`` as the data loader class.

        The trained model must return exactly one float32 prediction for a
        one-row input.
        """
        from horovod.spark.data_loaders.pytorch_data_loaders import PytorchInfiniteAsyncDataLoader

        with spark_session('test_fit_model') as spark:
            train_df = create_noisy_xor_data(spark)
            xor_model = create_xor_model()

            with local_store() as store:
                estimator_kwargs = dict(
                    num_proc=2,
                    store=store,
                    model=xor_model,
                    input_shapes=[[-1, 2]],
                    feature_cols=['features'],
                    label_cols=['y'],
                    validation=0.2,
                    batch_size=4,
                    epochs=2,
                    verbose=2,
                    data_loader_class=PytorchInfiniteAsyncDataLoader,
                )
                estimator = hvd_spark.TorchEstimator(**estimator_kwargs)
                fitted = estimator.fit(train_df)

                # TODO: Find a way to pass log metrics from remote, and assert based on the logger.
                net = fitted.getModel()
                probe = torch.ones([1, 2], dtype=torch.int32)
                output = net(probe)
                assert len(output) == 1
                assert output.dtype == torch.float32
Example #11
0
    def test_fit_model(self):
        """Fit the XOR model with ``random_seed=1`` and check the trained model's output.

        The trained model must return exactly one float32 prediction for a
        one-row input.
        """
        if skip_lightning_tests:
            self.skipTest(
                'Spark PyTorch Lightning tests conflict with Tensorflow 2.5.x: '
                'https://github.com/horovod/horovod/pull/3263')

        xor_model = create_xor_model()

        with spark_session('test_fit_model') as spark:
            train_df = create_noisy_xor_data(spark)

            with local_store() as store:
                estimator_kwargs = dict(
                    num_proc=2,
                    store=store,
                    model=xor_model,
                    input_shapes=[[-1, 2]],
                    feature_cols=['features'],
                    label_cols=['y'],
                    validation=0.2,
                    batch_size=4,
                    epochs=2,
                    random_seed=1,
                    verbose=2,
                )
                estimator = hvd_spark.TorchEstimator(**estimator_kwargs)
                fitted = estimator.fit(train_df)

                # Run the trained network on a single all-ones row.
                net = fitted.getModel()
                probe = torch.ones([1, 2], dtype=torch.int32)
                output = net(probe)
                assert len(output) == 1
                assert output.dtype == torch.float32
Example #12
0
    def test_dummy_callback(self):
        """Fit with a custom Lightning callback for each worker/epoch combination.

        For every ``num_proc`` in ``[1, 2]`` and ``epochs`` in ``[2, 3]`` a
        fresh ``MyDummyCallback`` counts epoch-end hooks and, in
        ``on_train_end``, asserts that the number of train-epoch-end calls
        equals ``epochs``. The trained model must return exactly one float32
        prediction for a one-row input.
        """
        from pytorch_lightning.callbacks import Callback
        xor_model = create_xor_model()

        with spark_session('test_fit_model') as spark:
            train_df = create_noisy_xor_data(spark)

            for num_proc in [1, 2]:
                for epochs in [2, 3]:

                    class MyDummyCallback(Callback):
                        # Counts epoch-end hooks; reads `epochs` from the
                        # enclosing loop when training ends.
                        def __init__(self):
                            self.epoch_end_counter = 0
                            self.train_epoch_end_counter = 0

                        def on_init_start(self, trainer):
                            print('Starting to init trainer!')

                        def on_init_end(self, trainer):
                            print('Trainer is initialized.')

                        def on_epoch_end(self, trainer, model):
                            print('A epoch ended.')
                            self.epoch_end_counter += 1

                        def on_train_epoch_end(self, trainer, model, unused=None):
                            print('A train epoch ended.')
                            self.train_epoch_end_counter += 1

                        def on_train_end(self, trainer, model):
                            print('Training ends')
                            assert self.train_epoch_end_counter == epochs

                    callbacks = [MyDummyCallback()]

                    with local_store() as store:
                        estimator_kwargs = dict(
                            num_proc=num_proc,
                            store=store,
                            model=xor_model,
                            input_shapes=[[-1, 2]],
                            feature_cols=['features'],
                            label_cols=['y'],
                            validation=0.2,
                            batch_size=4,
                            epochs=epochs,
                            verbose=2,
                            callbacks=callbacks,
                        )
                        estimator = hvd_spark.TorchEstimator(**estimator_kwargs)
                        fitted = estimator.fit(train_df)

                        # TODO: Find a way to pass log metrics from remote, and assert based on the logger.
                        net = fitted.getModel()
                        probe = torch.ones([1, 2], dtype=torch.int32)
                        output = net(probe)
                        assert len(output) == 1
                        assert output.dtype == torch.float32
Example #13
0
    def test_lr_scheduler_callback(self):
        """Fit a model that declares an LR scheduler, with ``LearningRateMonitor`` attached.

        ``LRTestingModel`` overrides ``configure_optimizers`` to return an Adam
        optimizer and a ``LambdaLR`` scheduler named ``my_logging_name``. After
        fitting, the returned model must produce exactly one float32
        prediction for a single-row input.
        """
        if skip_lightning_tests:
            self.skipTest(
                'Spark PyTorch Lightning tests conflict with Tensorflow 2.5.x: '
                'https://github.com/horovod/horovod/pull/3263')

        from pytorch_lightning.callbacks import LearningRateMonitor

        class LRTestingModel(XOR):
            def configure_optimizers(self):
                # Use this instance's parameters. Reading the enclosing
                # `model` variable instead would tie the optimizer to whatever
                # object that name refers to, not the module being trained.
                optimizer = torch.optim.Adam(self.parameters(), lr=0.02)

                def lambda_func(epoch):
                    return epoch // 30

                lr_scheduler = {
                    'scheduler':
                    torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                      lr_lambda=lambda_func),
                    'name':
                    'my_logging_name'
                }
                return [optimizer], [lr_scheduler]

        model = LRTestingModel()

        with spark_session('test_fit_model') as spark:
            df = create_noisy_xor_data(spark)

            lr_monitor = LearningRateMonitor(logging_interval='step')
            callbacks = [lr_monitor]

            with local_store() as store:
                torch_estimator = hvd_spark.TorchEstimator(
                    num_proc=2,
                    store=store,
                    model=model,
                    input_shapes=[[-1, 2]],
                    feature_cols=['features'],
                    label_cols=['y'],
                    validation=0.2,
                    batch_size=4,
                    epochs=2,
                    verbose=2,
                    callbacks=callbacks)

                torch_model = torch_estimator.fit(df)

                # TODO: Find a way to pass log metrics from remote, and assert based on the logger.
                trained_model = torch_model.getModel()
                pred = trained_model(torch.ones([1, 2], dtype=torch.int32))
                assert len(pred) == 1
                assert pred.dtype == torch.float32
Example #14
0
    def test_direct_parquet_train(self):
        """Train via ``fit_on_parquet`` across validation, cache and reader-pool settings.

        The store's train/val path getters are overridden to return the
        store's ``_train_path`` / ``_val_path`` attributes. For each
        ``validation`` in ``[None, 'val']`` the data is prepared once, then the
        estimator is fitted for every combination of ``inmemory_cache_all``
        and ``reader_pool_type``; each resulting transformer must yield as
        many rows as the input.
        """
        if skip_lightning_tests:
            self.skipTest(
                'Spark PyTorch Lightning tests conflict with Tensorflow 2.5.x: '
                'https://github.com/horovod/horovod/pull/3263')

        with spark_session('test_direct_parquet_train') as spark:
            source_df = create_noisy_xor_data_with_val(spark)
            callback_backend = CallbackBackend()

            with local_store() as store:
                # Point both data-path getters at the store's own attributes.
                store.get_train_data_path = lambda v=None: store._train_path
                store.get_val_data_path = lambda v=None: store._val_path

                # Make sure to cover val dataloader cases
                for validation in (None, 'val'):
                    prepared = util.prepare_data(callback_backend.num_processes(),
                                                 store,
                                                 source_df,
                                                 feature_columns=['features'],
                                                 label_columns=['y'],
                                                 validation=validation)
                    with prepared:
                        xor_model = create_xor_model()

                        for inmemory_cache_all in (False, True):
                            for reader_pool_type in ('process', 'thread'):
                                estimator_kwargs = dict(
                                    backend=callback_backend,
                                    store=store,
                                    model=xor_model,
                                    input_shapes=[[-1, 2]],
                                    feature_cols=['features'],
                                    label_cols=['y'],
                                    validation=validation,
                                    batch_size=1,
                                    epochs=3,
                                    verbose=2,
                                    inmemory_cache_all=inmemory_cache_all,
                                    reader_pool_type=reader_pool_type,
                                )
                                estimator = hvd_spark.TorchEstimator(**estimator_kwargs)

                                fitted = estimator.fit_on_parquet()
                                scored = fitted.transform(source_df)
                                assert scored.count() == source_df.count()
Example #15
0
    def test_legacy_restore_from_checkpoint(self):
        """Fit twice with the same run id and verify the checkpoint is read.

        Currently skipped unconditionally: ``skipTest`` is the first statement,
        so nothing below it runs. The body is kept so the test can be
        re-enabled by removing the skip.

        When enabled, ``_read_checkpoint`` is wrapped in a ``Mock`` that
        delegates to the original method; the checkpoint path must not exist
        before the first fit, must exist after it, and the reader must have
        been called by the time the second fit has finished.
        """
        self.skipTest(
            'There is a bug in current lightning version for checkpoint '
            'callback. Will add this test back when it is solved.')

        model = create_legacy_xor_model()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        loss = nn.BCELoss()

        with spark_session('test_restore_from_checkpoint') as spark:
            df = create_noisy_xor_data(spark)

            ctx = CallbackBackend()

            run_id = 'run01'
            with local_store() as store:
                torch_estimator = hvd_spark.TorchEstimator(
                    backend=ctx,
                    store=store,
                    model=model,
                    optimizer=optimizer,
                    loss=loss,
                    input_shapes=[[-1, 2]],
                    feature_cols=['features'],
                    label_cols=['y'],
                    validation=0.2,
                    batch_size=4,
                    epochs=2,
                    verbose=2,
                    run_id=run_id)

                # Spy on the checkpoint reader without changing what it does.
                torch_estimator._read_checkpoint = mock.Mock(
                    side_effect=torch_estimator._read_checkpoint)

                ckpt_path = store.get_checkpoint_path(run_id)
                assert not store.exists(ckpt_path)
                torch_estimator._read_checkpoint.assert_not_called()
                torch_estimator.fit(df)

                assert store.exists(ckpt_path)
                torch_estimator.fit(df)
                torch_estimator._read_checkpoint.assert_called()
Example #16
0
    def test_terminate_on_nan_flag(self):
        """Check that ``terminate_on_nan=True`` is reported by ``getTerminateOnNan``.

        The estimator is only constructed here, never fitted, so no training
        data is created.
        """
        model = create_xor_model()

        with spark_session('test_terminate_on_nan_flag'):
            with local_store() as store:
                torch_estimator = hvd_spark.TorchEstimator(
                    num_proc=2,
                    store=store,
                    model=model,
                    input_shapes=[[-1, 2]],
                    feature_cols=['features'],
                    label_cols=['y'],
                    validation=0.2,
                    batch_size=4,
                    epochs=2,
                    verbose=2,
                    terminate_on_nan=True,
                    profiler="pytorch")
                # The flag was passed as the literal True, so the getter is
                # expected to hand back that same boolean, not merely a
                # truthy value.
                assert torch_estimator.getTerminateOnNan() is True
def train_model(args):
    """Train an MNIST classifier with Horovod's Spark ``TorchEstimator`` and print test accuracy.

    Steps, in order: create a SparkSession, download the MNIST libsvm archive
    into ``args.data_dir`` if it is not already there, one-hot encode the
    labels, split 90/10 into train/test, fit a ``LightningModule`` through the
    estimator, and print the accuracy measured on the held-out split. The
    SparkSession is stopped on the way out, including when a step raises.

    Args:
        args: Options object. The attributes read here are ``master``,
            ``num_proc``, ``work_dir``, ``data_dir``, ``batch_size``,
            ``epochs`` and ``enable_profiler``.

    Returns:
        None. Returns early, before any Spark work, when the installed
        pytorch_lightning version is below ``MIN_PL_VERSION``.
    """
    # Do not run this test for pytorch lightning below the min supported version.
    import pytorch_lightning as pl
    if LooseVersion(pl.__version__) < LooseVersion(MIN_PL_VERSION):
        print("Skip test for pytorch_ligthning=={}, min support version is {}".format(pl.__version__, MIN_PL_VERSION))
        return

    # Initialize SparkSession
    conf = SparkConf().setAppName('pytorch_spark_mnist').set('spark.sql.shuffle.partitions', '16')
    if args.master:
        conf.setMaster(args.master)
    elif args.num_proc:
        conf.setMaster('local[{}]'.format(args.num_proc))
    spark = SparkSession.builder.config(conf=conf).getOrCreate()

    # Everything that uses the session runs inside try/finally so the session
    # is stopped even if data preparation, training or evaluation fails.
    try:
        # Setup our store for intermediate data
        store = Store.create(args.work_dir)

        # Download MNIST dataset
        data_url = 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/mnist.bz2'
        libsvm_path = os.path.join(args.data_dir, 'mnist.bz2')
        if not os.path.exists(libsvm_path):
            subprocess.check_output(['wget', data_url, '-O', libsvm_path])

        # Load dataset into a Spark DataFrame
        df = spark.read.format('libsvm') \
            .option('numFeatures', '784') \
            .load(libsvm_path)

        # One-hot encode labels into SparseVectors
        encoder = OneHotEncoder(inputCols=['label'],
                                outputCols=['label_vec'],
                                dropLast=False)
        # Named separately from the network below, which is bound to `model`.
        encoder_model = encoder.fit(df)
        train_df = encoder_model.transform(df)

        # Train/test split
        train_df, test_df = train_df.randomSplit([0.9, 0.1])

        # Define the PyTorch model without any Horovod-specific parameters
        class Net(LightningModule):
            def __init__(self):
                super(Net, self).__init__()
                self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
                self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
                self.conv2_drop = nn.Dropout2d()
                self.fc1 = nn.Linear(320, 50)
                self.fc2 = nn.Linear(50, 10)

            def forward(self, x):
                x = x.float().reshape((-1, 1, 28, 28))
                x = F.relu(F.max_pool2d(self.conv1(x), 2))
                x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
                x = x.view(-1, 320)
                x = F.relu(self.fc1(x))
                x = F.dropout(x, training=self.training)
                x = self.fc2(x)
                return F.log_softmax(x, -1)

            def configure_optimizers(self):
                return optim.SGD(self.parameters(), lr=0.01, momentum=0.5)

            def training_step(self, batch, batch_idx):
                if batch_idx == 0:
                    print(f"training data batch size: {batch['label'].shape}")
                x, y = batch['features'], batch['label']
                y_hat = self(x)
                loss = F.nll_loss(y_hat, y.long())
                self.log('train_loss', loss)
                return loss

            def validation_step(self, batch, batch_idx):
                if batch_idx == 0:
                    print(f"validation data batch size: {batch['label'].shape}")
                x, y = batch['features'], batch['label']
                y_hat = self(x)
                loss = F.nll_loss(y_hat, y.long())
                self.log('val_loss', loss)
                # validation_epoch_end reads x['val_loss'] from each step's
                # output, so the loss has to be returned under that key;
                # without a return value the epoch average is never computed.
                return {'val_loss': loss}

            def validation_epoch_end(self, outputs):
                avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() if len(outputs) > 0 else float('inf')
                self.log('avg_val_loss', avg_loss)

        model = Net()

        # Train a Horovod Spark Estimator on the DataFrame
        backend = SparkBackend(num_proc=args.num_proc,
                               stdout=sys.stdout, stderr=sys.stderr,
                               prefix_output_with_timestamp=True)

        from pytorch_lightning.callbacks import Callback

        epochs = args.epochs

        class MyDummyCallback(Callback):
            def __init__(self):
                self.epcoh_end_counter = 0
                self.train_epcoh_end_counter = 0
                self.validation_epoch_end_counter = 0

            def on_init_start(self, trainer):
                print('Starting to init trainer!')

            def on_init_end(self, trainer):
                print('Trainer is initialized.')

            def on_epoch_end(self, trainer, model):
                print('A train or eval epoch ended.')
                self.epcoh_end_counter += 1

            def on_train_epoch_end(self, trainer, model, unused=None):
                print('A train epoch ended.')
                self.train_epcoh_end_counter += 1

            def on_validation_epoch_end(self, trainer, model, unused=None):
                print('A val epoch ended.')
                self.validation_epoch_end_counter += 1

            def on_train_end(self, trainer, model):
                print("Training ends:"
                      f"epcoh_end_counter={self.epcoh_end_counter}, "
                      f"train_epcoh_end_counter={self.train_epcoh_end_counter}, "
                      f"validation_epoch_end_counter={self.validation_epoch_end_counter} \n")
                # "<=" rather than "==": the EarlyStopping callback added
                # below can end training before all epochs have run.
                assert self.train_epcoh_end_counter <= epochs
                assert self.epcoh_end_counter == self.train_epcoh_end_counter + self.validation_epoch_end_counter

        callbacks = [MyDummyCallback()]

        # added EarlyStopping and ModelCheckpoint
        from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
        callbacks.append(ModelCheckpoint(monitor='val_loss', mode="min",
                                         save_top_k=1, verbose=True))

        from pytorch_lightning.callbacks.early_stopping import EarlyStopping
        callbacks.append(EarlyStopping(monitor='val_loss',
                                       min_delta=0.001,
                                       patience=3,
                                       verbose=True,
                                       mode='min'))

        torch_estimator = hvd.TorchEstimator(backend=backend,
                                             store=store,
                                             model=model,
                                             input_shapes=[[-1, 1, 28, 28]],
                                             feature_cols=['features'],
                                             label_cols=['label'],
                                             batch_size=args.batch_size,
                                             epochs=args.epochs,
                                             validation=0.1,
                                             verbose=1,
                                             callbacks=callbacks,
                                             profiler="simple" if args.enable_profiler else None)

        torch_model = torch_estimator.fit(train_df).setOutputCols(['label_prob'])

        # Evaluate the model on the held-out test DataFrame
        pred_df = torch_model.transform(test_df)

        argmax = udf(lambda v: float(np.argmax(v)), returnType=T.DoubleType())
        pred_df = pred_df.withColumn('label_pred', argmax(pred_df.label_prob))
        evaluator = MulticlassClassificationEvaluator(predictionCol='label_pred', labelCol='label', metricName='accuracy')
        print('Test accuracy:', evaluator.evaluate(pred_df))
    finally:
        spark.stop()
Example #18
0
    def test_train_with_custom_data_module(self):
        # Checks that TorchEstimator accepts a caller-supplied LightningDataModule
        # class through ``data_module=``, fits the XOR model with it, and that the
        # fitted model still yields one float32 prediction for one input row.
        if skip_lightning_tests:
            self.skipTest(
                'Spark PyTorch Lightning tests conflict with Tensorflow 2.5.x: '
                'https://github.com/horovod/horovod/pull/3263')

        from horovod.spark.data_loaders.pytorch_data_loaders import PytorchAsyncDataLoader

        class CustomDataModule(pl.LightningDataModule):
            """LightningDataModule that serves petastorm readers through PytorchAsyncDataLoader."""

            def __init__(self,
                         train_dir: str,
                         val_dir: str,
                         has_val: bool = True,
                         train_batch_size: int = 32,
                         val_batch_size: int = 32,
                         shuffle_size: int = 100,
                         num_reader_epochs=None,
                         cur_shard: int = 0,
                         shard_count: int = 1,
                         schema_fields=None,
                         storage_options=None,
                         steps_per_epoch_train: int = 1,
                         steps_per_epoch_val: int = 1,
                         verbose=True,
                         **kwargs):
                super().__init__()
                # Dataset locations and whether a validation set is present.
                self.train_dir, self.val_dir, self.has_val = train_dir, val_dir, has_val
                # Batch sizes and shuffle capacity handed to the data loaders.
                self.train_batch_size, self.val_batch_size = train_batch_size, val_batch_size
                self.shuffle_size = shuffle_size
                # Options forwarded to the petastorm readers built in setup().
                self.num_reader_epochs = num_reader_epochs
                self.cur_shard, self.shard_count = cur_shard, shard_count
                self.schema_fields = schema_fields
                self.storage_options = storage_options
                # Per-epoch step limits handed to the data loaders.
                self.steps_per_epoch_train = steps_per_epoch_train
                self.steps_per_epoch_val = steps_per_epoch_val
                self.verbose = verbose

            def setup(self, stage=None):
                # Create the petastorm readers consumed by the dataloaders.
                from petastorm import make_batch_reader
                if stage is not None and stage != 'fit':
                    return

                # Train and validation readers differ only in the dataset location.
                reader_options = dict(
                    num_epochs=self.num_reader_epochs,
                    cur_shard=self.cur_shard,
                    shard_count=self.shard_count,
                    hdfs_driver='libhdfs',
                    schema_fields=self.schema_fields,
                    storage_options=self.storage_options)
                self.train_reader = make_batch_reader(self.train_dir, **reader_options)
                if self.has_val:
                    self.val_reader = make_batch_reader(self.val_dir, **reader_options)

            def teardown(self, stage=None):
                # Stop and join the readers opened by setup() for the same stages.
                if stage is not None and stage != "fit":
                    return
                if self.verbose:
                    print("Tear down petastorm readers")
                self.train_reader.stop()
                self.train_reader.join()
                if not self.has_val:
                    return
                self.val_reader.stop()
                self.val_reader.join()

            def train_dataloader(self):
                # Training loader: shuffles through a queue of ``shuffle_size`` rows.
                if self.verbose:
                    print("Setup train dataloader")
                return PytorchAsyncDataLoader(
                    reader=self.train_reader,
                    batch_size=self.train_batch_size,
                    name="train dataloader",
                    shuffling_queue_capacity=self.shuffle_size,
                    limit_step_per_epoch=self.steps_per_epoch_train,
                    verbose=self.verbose)

            def val_dataloader(self):
                # Validation loader: no shuffling queue; None when there is no validation set.
                if not self.has_val:
                    return None
                if self.verbose:
                    print("setup val dataloader")
                return PytorchAsyncDataLoader(
                    reader=self.val_reader,
                    batch_size=self.val_batch_size,
                    name="val dataloader",
                    shuffling_queue_capacity=0,
                    limit_step_per_epoch=self.steps_per_epoch_val,
                    verbose=self.verbose)

        with spark_session('test_fit_model') as spark:
            df = create_noisy_xor_data(spark)
            model = create_xor_model()

            with local_store() as store:
                estimator_options = dict(
                    num_proc=2,
                    store=store,
                    model=model,
                    input_shapes=[[-1, 2]],
                    feature_cols=['features'],
                    label_cols=['y'],
                    validation=0.2,
                    batch_size=4,
                    epochs=2,
                    data_module=CustomDataModule,
                    verbose=2)
                torch_estimator = hvd_spark.TorchEstimator(**estimator_options)

                torch_model = torch_estimator.fit(df)

                # TODO: Find a way to pass log metrics from remote, and assert based on the logger.
                trained_model = torch_model.getModel()
                sample = torch.ones([1, 2], dtype=torch.int32)
                pred = trained_model(sample)
                assert len(pred) == 1
                assert pred.dtype == torch.float32
Example #19
0
def train_model(args):
    """Train and evaluate an MNIST classifier with the Horovod Spark TorchEstimator.

    Downloads the MNIST libsvm file into ``args.data_dir`` when it is not there
    yet, loads it into a Spark DataFrame, fits a small convolutional network
    with ``hvd.TorchEstimator`` on a 90% split and prints the accuracy measured
    on the remaining 10%.

    Args:
        args: Parsed options. This function reads ``master``, ``num_proc``,
            ``work_dir``, ``data_dir``, ``batch_size`` and ``epochs``.

    Returns:
        None. Returns early, before any SparkSession is created, when the
        installed pytorch_lightning is older than ``MIN_PL_VERSION``.

    Raises:
        subprocess.CalledProcessError: If the dataset download fails.
    """
    # Do not run this test for pytorch lightning below the min supported version.
    import pytorch_lightning as pl
    if LooseVersion(pl.__version__) < LooseVersion(MIN_PL_VERSION):
        print("Skip test for pytorch_lightning=={}, min support version is {}".
              format(pl.__version__, MIN_PL_VERSION))
        return

    # Define the PyTorch model without any Horovod-specific parameters
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
            self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
            self.conv2_drop = nn.Dropout2d()
            self.fc1 = nn.Linear(320, 50)
            self.fc2 = nn.Linear(50, 10)

        def forward(self, x):
            x = x.float()
            x = F.relu(F.max_pool2d(self.conv1(x), 2))
            x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
            x = x.view(-1, 320)
            x = F.relu(self.fc1(x))
            x = F.dropout(x, training=self.training)
            x = self.fc2(x)
            # x is 2-D here (rows of 10 class scores, see view/fc2 above), so
            # normalise over dim=1. Omitting ``dim`` is deprecated and leaves
            # the axis to an implicit rule.
            return F.log_softmax(x, dim=1)

    # Initialize SparkSession
    conf = SparkConf().setAppName('keras_spark_mnist').set(
        'spark.sql.shuffle.partitions', '16')
    if args.master:
        conf.setMaster(args.master)
    elif args.num_proc:
        conf.setMaster('local[{}]'.format(args.num_proc))
    spark = SparkSession.builder.config(conf=conf).getOrCreate()

    # Everything below runs under try/finally so the session is stopped even
    # when the download, training or evaluation raises.
    try:
        # Setup our store for intermediate data
        store = Store.create(args.work_dir)

        # Download MNIST dataset
        data_url = 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/mnist.bz2'
        libsvm_path = os.path.join(args.data_dir, 'mnist.bz2')
        if not os.path.exists(libsvm_path):
            try:
                subprocess.check_output(['wget', data_url, '-O', libsvm_path])
            except subprocess.CalledProcessError:
                # ``wget -O`` creates the output file even when the download
                # fails; remove it so the next run retries instead of loading
                # an empty or truncated file.
                if os.path.exists(libsvm_path):
                    os.remove(libsvm_path)
                raise

        # Load dataset into a Spark DataFrame
        df = spark.read.format('libsvm') \
            .option('numFeatures', '784') \
            .load(libsvm_path)

        # One-hot encode labels into SparseVectors
        encoder = OneHotEncoderEstimator(inputCols=['label'],
                                         outputCols=['label_vec'],
                                         dropLast=False)
        encoder_model = encoder.fit(df)
        train_df = encoder_model.transform(df)

        # Train/test split
        train_df, test_df = train_df.randomSplit([0.9, 0.1])

        model = Net()
        optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
        loss = nn.NLLLoss()

        # Train a Horovod Spark Estimator on the DataFrame
        torch_estimator = hvd.TorchEstimator(
            num_proc=args.num_proc,
            store=store,
            model=model,
            optimizer=optimizer,
            # NLLLoss needs integer class indices, hence the cast of the target.
            loss=lambda input, target: loss(input, target.long()),
            input_shapes=[[-1, 1, 28, 28]],
            feature_cols=['features'],
            label_cols=['label'],
            batch_size=args.batch_size,
            epochs=args.epochs,
            verbose=1)

        torch_model = torch_estimator.fit(train_df).setOutputCols(['label_prob'])

        # Evaluate the model on the held-out test DataFrame
        pred_df = torch_model.transform(test_df)
        argmax = udf(lambda v: float(np.argmax(v)), returnType=T.DoubleType())
        pred_df = pred_df.withColumn('label_pred', argmax(pred_df.label_prob))
        evaluator = MulticlassClassificationEvaluator(predictionCol='label_pred',
                                                      labelCol='label',
                                                      metricName='accuracy')
        print('Test accuracy:', evaluator.evaluate(pred_df))
    finally:
        spark.stop()
Example #20
0
    def test_direct_parquet_train_with_no_val_column(self):
        # Writes train and validation parquet data straight into the store's
        # data paths, trains with fit_on_parquet(), and checks that transforming
        # the training DataFrame keeps its row count.
        if skip_lightning_tests:
            self.skipTest(
                'Spark PyTorch Lightning tests conflict with Tensorflow 2.5.x: '
                'https://github.com/horovod/horovod/pull/3263')

        with spark_session(
                'test_direct_parquet_train_with_no_val_column') as spark:
            df_train = create_noisy_xor_data(spark)
            df_val = create_noisy_xor_data(spark)

            def to_petastorm(df):
                # DataFrames without vector columns are returned unchanged.
                if not util._has_vector_column(df):
                    return df
                row_converter = util.to_petastorm_fn(["features", "y"], None)
                return df.rdd.map(row_converter).toDF()

            df_train = to_petastorm(df_train)
            df_val = to_petastorm(df_val)

            df_train.show(1)
            print(df_train.count())
            df_val.show(1)
            print(df_val.count())

            backend = CallbackBackend()
            with local_store() as store:
                # Pin the data paths to the store's base train/val locations,
                # whatever argument the caller passes.
                store.get_train_data_path = lambda v=None: store._train_path
                store.get_val_data_path = lambda v=None: store._val_path

                print(store.get_train_data_path())
                print(store.get_val_data_path())

                train_writer = df_train.coalesce(4).write.mode('overwrite')
                train_writer.parquet(store.get_train_data_path())

                val_writer = df_val.coalesce(4).write.mode('overwrite')
                val_writer.parquet(store.get_val_data_path())

                model = create_xor_model()

                est = hvd_spark.TorchEstimator(
                    backend=backend,
                    store=store,
                    model=model,
                    input_shapes=[[-1, 2]],
                    feature_cols=['features'],
                    label_cols=['y'],
                    batch_size=64,
                    epochs=2,
                    verbose=2,
                    inmemory_cache_all=True,
                    reader_pool_type='process')

                # Per the original author's note, any string works as the
                # validation setting here.
                est.setValidation("True")

                transformer = est.fit_on_parquet()

                predictions = transformer.transform(df_train)
                assert predictions.count() == df_train.count()