def test_restore_from_checkpoint(self):
    model = create_xor_model()

    with spark_session('test_restore_from_checkpoint') as spark:
        df = create_noisy_xor_data(spark)

        ctx = CallbackBackend()

        run_id = 'run01'
        with local_store() as store:
            torch_estimator = hvd_spark.TorchEstimator(
                backend=ctx,
                store=store,
                model=model,
                input_shapes=[[-1, 2]],
                feature_cols=['features'],
                label_cols=['y'],
                validation=0.2,
                batch_size=4,
                epochs=2,
                verbose=2,
                run_id=run_id)

            torch_estimator._read_checkpoint = Mock(
                side_effect=torch_estimator._read_checkpoint)

            ckpt_path = store.get_checkpoint_path(run_id)
            assert not store.exists(ckpt_path)
            torch_estimator._read_checkpoint.assert_not_called()
            torch_estimator.fit(df)

            assert store.exists(ckpt_path)
            torch_estimator.fit(df)
            torch_estimator._read_checkpoint.assert_called()

def test_model_override_trainer_args(self):
    if skip_lightning_tests:
        self.skipTest(
            'Spark PyTorch Lightning tests conflict with Tensorflow 2.5.x: '
            'https://github.com/horovod/horovod/pull/3263')

    from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

    with spark_session('test_fit_model') as spark:
        df = create_noisy_xor_data(spark)
        model = create_xor_model()

        with tempdir() as dir:
            with local_store() as store:
                torch_estimator = hvd_spark.TorchEstimator(
                    num_proc=2,
                    store=store,
                    model=model,
                    input_shapes=[[-1, 2]],
                    feature_cols=['features'],
                    label_cols=['y'],
                    validation=0.2,
                    batch_size=4,
                    epochs=2,
                    verbose=2,
                    trainer_args={'stochastic_weight_avg': True})

                torch_model = torch_estimator.fit(df)

                # TODO: Find a way to pass log metrics from remote, and assert based on the logger.
                trained_model = torch_model.getModel()
                pred = trained_model(torch.ones([1, 2], dtype=torch.int32))
                assert len(pred) == 1
                assert pred.dtype == torch.float32

def test_legacy_fit_model(self):
    if skip_lightning_tests:
        self.skipTest(
            'Spark PyTorch Lightning tests conflict with Tensorflow 2.5.x: '
            'https://github.com/horovod/horovod/pull/3263')

    model = create_legacy_xor_model()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss = F.binary_cross_entropy

    with spark_session('test_fit_model') as spark:
        df = create_noisy_xor_data(spark)

        with local_store() as store:
            torch_estimator = hvd_spark.TorchEstimator(
                num_proc=2,
                store=store,
                model=model,
                optimizer=optimizer,
                loss=loss,
                input_shapes=[[-1, 2]],
                feature_cols=['features'],
                label_cols=['y'],
                batch_size=4,
                epochs=2,
                verbose=2,
                sample_weight_col='weight')

            torch_model = torch_estimator.fit(df)

            trained_model = torch_model.getModel()
            pred = trained_model(torch.ones([1, 2], dtype=torch.int32))
            assert len(pred) == 1
            assert pred.dtype == torch.float32

def test_direct_parquet_train(self):
    with spark_session('test_direct_parquet_train') as spark:
        df = create_noisy_xor_data(spark)

        backend = CallbackBackend()
        with local_store() as store:
            store.get_train_data_path = lambda v=None: store._train_path
            store.get_val_data_path = lambda v=None: store._val_path

            with util.prepare_data(backend.num_processes(),
                                   store,
                                   df,
                                   feature_columns=['features'],
                                   label_columns=['y'],
                                   validation=0.2):
                model = create_xor_model()

                for inmemory_cache_all in [False, True]:
                    est = hvd_spark.TorchEstimator(
                        backend=backend,
                        store=store,
                        model=model,
                        input_shapes=[[-1, 2]],
                        feature_cols=['features'],
                        label_cols=['y'],
                        validation=0.2,
                        batch_size=1,
                        epochs=3,
                        verbose=2,
                        inmemory_cache_all=inmemory_cache_all)

                    transformer = est.fit_on_parquet()
                    predictions = transformer.transform(df)
                    assert predictions.count() == df.count()

def test_early_stop_callback(self):
    from pytorch_lightning.callbacks.early_stopping import EarlyStopping

    with spark_session('test_fit_model') as spark:
        df = create_noisy_xor_data(spark)
        model = create_xor_model()

        early_stop_callback = EarlyStopping(monitor='val_loss',
                                            min_delta=0.00,
                                            patience=3,
                                            verbose=True,
                                            mode='max')
        callbacks = [early_stop_callback]

        with local_store() as store:
            torch_estimator = hvd_spark.TorchEstimator(
                num_proc=2,
                store=store,
                model=model,
                input_shapes=[[-1, 2]],
                feature_cols=['features'],
                label_cols=['y'],
                validation=0.2,
                batch_size=4,
                epochs=2,
                verbose=2,
                callbacks=callbacks)

            torch_model = torch_estimator.fit(df)

            # TODO: Find a way to pass log metrics from remote, and assert based on the logger.
            trained_model = torch_model.getModel()
            pred = trained_model(torch.ones([1, 2], dtype=torch.int32))
            assert len(pred) == 1
            assert pred.dtype == torch.float32

def test_train_with_inmemory_cache_all(self):
    if skip_lightning_tests:
        self.skipTest(
            'Spark PyTorch Lightning tests conflict with Tensorflow 2.5.x: '
            'https://github.com/horovod/horovod/pull/3263')

    with spark_session('test_fit_model') as spark:
        df = create_noisy_xor_data(spark)
        model = create_xor_model()

        with local_store() as store:
            torch_estimator = hvd_spark.TorchEstimator(
                num_proc=1,  # Normally the in-memory dataloader is for single-worker training with small data
                store=store,
                model=model,
                input_shapes=[[-1, 2]],
                feature_cols=['features'],
                label_cols=['y'],
                validation=0.2,
                batch_size=4,
                epochs=2,
                verbose=2,
                inmemory_cache_all=True)

            torch_model = torch_estimator.fit(df)

            # TODO: Find a way to pass log metrics from remote, and assert based on the logger.
            trained_model = torch_model.getModel()
            pred = trained_model(torch.ones([1, 2], dtype=torch.int32))
            assert len(pred) == 1
            assert pred.dtype == torch.float32

def test_legacy_fit_model(self):
    model = create_legacy_xor_model()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss = F.binary_cross_entropy

    with spark_session('test_fit_model') as spark:
        df = create_noisy_xor_data(spark)

        with local_store() as store:
            torch_estimator = hvd_spark.TorchEstimator(
                num_proc=2,
                store=store,
                model=model,
                optimizer=optimizer,
                loss=loss,
                input_shapes=[[-1, 2]],
                feature_cols=['features'],
                label_cols=['y'],
                batch_size=4,
                epochs=2,
                verbose=2,
                sample_weight_col='weight')

            torch_model = torch_estimator.fit(df)

            trained_model = torch_model.getModel()
            pred = trained_model(torch.ones([1, 2], dtype=torch.int32))
            assert len(pred) == 1
            assert pred.dtype == torch.float32

def test_model_checkpoint_callback(self):
    from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

    with spark_session('test_fit_model') as spark:
        df = create_noisy_xor_data(spark)
        model = create_xor_model()

        with tempdir() as dir:
            checkpoint_callback = ModelCheckpoint(dirpath=dir)
            callbacks = [checkpoint_callback]

            with local_store() as store:
                torch_estimator = hvd_spark.TorchEstimator(
                    num_proc=2,
                    store=store,
                    model=model,
                    input_shapes=[[-1, 2]],
                    feature_cols=['features'],
                    label_cols=['y'],
                    validation=0.2,
                    batch_size=4,
                    epochs=2,
                    verbose=2,
                    callbacks=callbacks)

                torch_model = torch_estimator.fit(df)

                # TODO: Find a way to pass log metrics from remote, and assert based on the logger.
                trained_model = torch_model.getModel()
                pred = trained_model(torch.ones([1, 2], dtype=torch.int32))
                assert len(pred) == 1
                assert pred.dtype == torch.float32

def test_fit_model(self):
    model = create_xor_model()

    with spark_session('test_fit_model') as spark:
        df = create_noisy_xor_data(spark)

        with local_store() as store:
            torch_estimator = hvd_spark.TorchEstimator(
                num_proc=2,
                store=store,
                model=model,
                input_shapes=[[-1, 2]],
                feature_cols=['features'],
                label_cols=['y'],
                validation=0.2,
                batch_size=4,
                epochs=2,
                verbose=2)

            torch_model = torch_estimator.fit(df)

            trained_model = torch_model.getModel()
            pred = trained_model(torch.ones([1, 2], dtype=torch.int32))
            assert len(pred) == 1
            assert pred.dtype == torch.float32

def test_train_with_pytorch_infinite_async_data_loader(self):
    from horovod.spark.data_loaders.pytorch_data_loaders import PytorchInfiniteAsyncDataLoader

    with spark_session('test_fit_model') as spark:
        df = create_noisy_xor_data(spark)
        model = create_xor_model()

        with local_store() as store:
            torch_estimator = hvd_spark.TorchEstimator(
                num_proc=2,
                store=store,
                model=model,
                input_shapes=[[-1, 2]],
                feature_cols=['features'],
                label_cols=['y'],
                validation=0.2,
                batch_size=4,
                epochs=2,
                verbose=2,
                data_loader_class=PytorchInfiniteAsyncDataLoader)

            torch_model = torch_estimator.fit(df)

            # TODO: Find a way to pass log metrics from remote, and assert based on the logger.
            trained_model = torch_model.getModel()
            pred = trained_model(torch.ones([1, 2], dtype=torch.int32))
            assert len(pred) == 1
            assert pred.dtype == torch.float32

def test_fit_model(self):
    if skip_lightning_tests:
        self.skipTest(
            'Spark PyTorch Lightning tests conflict with Tensorflow 2.5.x: '
            'https://github.com/horovod/horovod/pull/3263')

    model = create_xor_model()

    with spark_session('test_fit_model') as spark:
        df = create_noisy_xor_data(spark)

        with local_store() as store:
            torch_estimator = hvd_spark.TorchEstimator(
                num_proc=2,
                store=store,
                model=model,
                input_shapes=[[-1, 2]],
                feature_cols=['features'],
                label_cols=['y'],
                validation=0.2,
                batch_size=4,
                epochs=2,
                random_seed=1,
                verbose=2)

            torch_model = torch_estimator.fit(df)

            trained_model = torch_model.getModel()
            pred = trained_model(torch.ones([1, 2], dtype=torch.int32))
            assert len(pred) == 1
            assert pred.dtype == torch.float32

def test_dummy_callback(self):
    from pytorch_lightning.callbacks import Callback

    model = create_xor_model()

    with spark_session('test_fit_model') as spark:
        df = create_noisy_xor_data(spark)

        for num_proc in [1, 2]:
            for epochs in [2, 3]:

                class MyDummyCallback(Callback):
                    def __init__(self):
                        self.epoch_end_counter = 0
                        self.train_epoch_end_counter = 0

                    def on_init_start(self, trainer):
                        print('Starting to init trainer!')

                    def on_init_end(self, trainer):
                        print('Trainer is initialized.')

                    def on_epoch_end(self, trainer, model):
                        print('An epoch ended.')
                        self.epoch_end_counter += 1

                    def on_train_epoch_end(self, trainer, model, unused=None):
                        print('A train epoch ended.')
                        self.train_epoch_end_counter += 1

                    def on_train_end(self, trainer, model):
                        print('Training ends')
                        assert self.train_epoch_end_counter == epochs

                dm_callback = MyDummyCallback()
                callbacks = [dm_callback]

                with local_store() as store:
                    torch_estimator = hvd_spark.TorchEstimator(
                        num_proc=num_proc,
                        store=store,
                        model=model,
                        input_shapes=[[-1, 2]],
                        feature_cols=['features'],
                        label_cols=['y'],
                        validation=0.2,
                        batch_size=4,
                        epochs=epochs,
                        verbose=2,
                        callbacks=callbacks)

                    torch_model = torch_estimator.fit(df)

                    # TODO: Find a way to pass log metrics from remote, and assert based on the logger.
                    trained_model = torch_model.getModel()
                    pred = trained_model(torch.ones([1, 2], dtype=torch.int32))
                    assert len(pred) == 1
                    assert pred.dtype == torch.float32

def test_lr_scheduler_callback(self):
    if skip_lightning_tests:
        self.skipTest(
            'Spark PyTorch Lightning tests conflict with Tensorflow 2.5.x: '
            'https://github.com/horovod/horovod/pull/3263')

    from pytorch_lightning.callbacks import LearningRateMonitor

    class LRTestingModel(XOR):
        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=0.02)

            def lambda_func(epoch):
                return epoch // 30

            lr_scheduler = {
                'scheduler': torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                               lr_lambda=lambda_func),
                'name': 'my_logging_name'
            }
            return [optimizer], [lr_scheduler]

    model = LRTestingModel()

    with spark_session('test_fit_model') as spark:
        df = create_noisy_xor_data(spark)

        lr_monitor = LearningRateMonitor(logging_interval='step')
        callbacks = [lr_monitor]

        with local_store() as store:
            torch_estimator = hvd_spark.TorchEstimator(
                num_proc=2,
                store=store,
                model=model,
                input_shapes=[[-1, 2]],
                feature_cols=['features'],
                label_cols=['y'],
                validation=0.2,
                batch_size=4,
                epochs=2,
                verbose=2,
                callbacks=callbacks)

            torch_model = torch_estimator.fit(df)

            # TODO: Find a way to pass log metrics from remote, and assert based on the logger.
            trained_model = torch_model.getModel()
            pred = trained_model(torch.ones([1, 2], dtype=torch.int32))
            assert len(pred) == 1
            assert pred.dtype == torch.float32

def test_direct_parquet_train(self):
    if skip_lightning_tests:
        self.skipTest(
            'Spark PyTorch Lightning tests conflict with Tensorflow 2.5.x: '
            'https://github.com/horovod/horovod/pull/3263')

    with spark_session('test_direct_parquet_train') as spark:
        df = create_noisy_xor_data_with_val(spark)

        backend = CallbackBackend()
        with local_store() as store:
            store.get_train_data_path = lambda v=None: store._train_path
            store.get_val_data_path = lambda v=None: store._val_path

            # Make sure to cover val dataloader cases
            for validation in [None, 'val']:
                with util.prepare_data(backend.num_processes(),
                                       store,
                                       df,
                                       feature_columns=['features'],
                                       label_columns=['y'],
                                       validation=validation):
                    model = create_xor_model()

                    for inmemory_cache_all in [False, True]:
                        for reader_pool_type in ['process', 'thread']:
                            est = hvd_spark.TorchEstimator(
                                backend=backend,
                                store=store,
                                model=model,
                                input_shapes=[[-1, 2]],
                                feature_cols=['features'],
                                label_cols=['y'],
                                validation=validation,
                                batch_size=1,
                                epochs=3,
                                verbose=2,
                                inmemory_cache_all=inmemory_cache_all,
                                reader_pool_type=reader_pool_type)

                            transformer = est.fit_on_parquet()
                            predictions = transformer.transform(df)
                            assert predictions.count() == df.count()

def test_legacy_restore_from_checkpoint(self):
    self.skipTest(
        'There is a bug in the current lightning version for the checkpoint '
        'callback. Will add this test back when it is solved.')

    model = create_legacy_xor_model()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss = nn.BCELoss()

    with spark_session('test_restore_from_checkpoint') as spark:
        df = create_noisy_xor_data(spark)

        ctx = CallbackBackend()

        run_id = 'run01'
        with local_store() as store:
            torch_estimator = hvd_spark.TorchEstimator(
                backend=ctx,
                store=store,
                model=model,
                optimizer=optimizer,
                loss=loss,
                input_shapes=[[-1, 2]],
                feature_cols=['features'],
                label_cols=['y'],
                validation=0.2,
                batch_size=4,
                epochs=2,
                verbose=2,
                run_id=run_id)

            torch_estimator._read_checkpoint = mock.Mock(
                side_effect=torch_estimator._read_checkpoint)

            ckpt_path = store.get_checkpoint_path(run_id)
            assert not store.exists(ckpt_path)
            torch_estimator._read_checkpoint.assert_not_called()
            torch_estimator.fit(df)

            assert store.exists(ckpt_path)
            torch_estimator.fit(df)
            torch_estimator._read_checkpoint.assert_called()

def test_terminate_on_nan_flag(self):
    model = create_xor_model()

    with spark_session('test_terminate_on_nan_flag') as spark:
        df = create_noisy_xor_data(spark)

        with local_store() as store:
            torch_estimator = hvd_spark.TorchEstimator(
                num_proc=2,
                store=store,
                model=model,
                input_shapes=[[-1, 2]],
                feature_cols=['features'],
                label_cols=['y'],
                validation=0.2,
                batch_size=4,
                epochs=2,
                verbose=2,
                terminate_on_nan=True,
                profiler="pytorch")

            assert torch_estimator.getTerminateOnNan() == True

def train_model(args):
    # Do not run this test for pytorch lightning below the min supported version
    import pytorch_lightning as pl
    if LooseVersion(pl.__version__) < LooseVersion(MIN_PL_VERSION):
        print("Skip test for pytorch_lightning=={}, min supported version is {}"
              .format(pl.__version__, MIN_PL_VERSION))
        return

    # Initialize SparkSession
    conf = SparkConf().setAppName('pytorch_spark_mnist').set('spark.sql.shuffle.partitions', '16')
    if args.master:
        conf.setMaster(args.master)
    elif args.num_proc:
        conf.setMaster('local[{}]'.format(args.num_proc))
    spark = SparkSession.builder.config(conf=conf).getOrCreate()

    # Setup our store for intermediate data
    store = Store.create(args.work_dir)

    # Download MNIST dataset
    data_url = 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/mnist.bz2'
    libsvm_path = os.path.join(args.data_dir, 'mnist.bz2')
    if not os.path.exists(libsvm_path):
        subprocess.check_output(['wget', data_url, '-O', libsvm_path])

    # Load dataset into a Spark DataFrame
    df = spark.read.format('libsvm') \
        .option('numFeatures', '784') \
        .load(libsvm_path)

    # One-hot encode labels into SparseVectors
    encoder = OneHotEncoder(inputCols=['label'],
                            outputCols=['label_vec'],
                            dropLast=False)
    model = encoder.fit(df)
    train_df = model.transform(df)

    # Train/test split
    train_df, test_df = train_df.randomSplit([0.9, 0.1])

    # Define the PyTorch model without any Horovod-specific parameters
    class Net(LightningModule):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
            self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
            self.conv2_drop = nn.Dropout2d()
            self.fc1 = nn.Linear(320, 50)
            self.fc2 = nn.Linear(50, 10)

        def forward(self, x):
            x = x.float().reshape((-1, 1, 28, 28))
            x = F.relu(F.max_pool2d(self.conv1(x), 2))
            x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
            x = x.view(-1, 320)
            x = F.relu(self.fc1(x))
            x = F.dropout(x, training=self.training)
            x = self.fc2(x)
            return F.log_softmax(x, -1)

        def configure_optimizers(self):
            return optim.SGD(self.parameters(), lr=0.01, momentum=0.5)

        def training_step(self, batch, batch_idx):
            if batch_idx == 0:
                print(f"training data batch size: {batch['label'].shape}")
            x, y = batch['features'], batch['label']
            y_hat = self(x)
            loss = F.nll_loss(y_hat, y.long())
            self.log('train_loss', loss)
            return loss

        def validation_step(self, batch, batch_idx):
            if batch_idx == 0:
                print(f"validation data batch size: {batch['label'].shape}")
            x, y = batch['features'], batch['label']
            y_hat = self(x)
            loss = F.nll_loss(y_hat, y.long())
            self.log('val_loss', loss)

        def validation_epoch_end(self, outputs):
            avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() \
                if len(outputs) > 0 else float('inf')
            self.log('avg_val_loss', avg_loss)

    model = Net()

    # Train a Horovod Spark Estimator on the DataFrame
    backend = SparkBackend(num_proc=args.num_proc,
                           stdout=sys.stdout, stderr=sys.stderr,
                           prefix_output_with_timestamp=True)

    from pytorch_lightning.callbacks import Callback

    epochs = args.epochs

    class MyDummyCallback(Callback):
        def __init__(self):
            self.epoch_end_counter = 0
            self.train_epoch_end_counter = 0
            self.validation_epoch_end_counter = 0

        def on_init_start(self, trainer):
            print('Starting to init trainer!')

        def on_init_end(self, trainer):
            print('Trainer is initialized.')

        def on_epoch_end(self, trainer, model):
            print('A train or eval epoch ended.')
            self.epoch_end_counter += 1

        def on_train_epoch_end(self, trainer, model, unused=None):
            print('A train epoch ended.')
            self.train_epoch_end_counter += 1

        def on_validation_epoch_end(self, trainer, model, unused=None):
            print('A val epoch ended.')
            self.validation_epoch_end_counter += 1

        def on_train_end(self, trainer, model):
            print("Training ends: "
                  f"epoch_end_counter={self.epoch_end_counter}, "
                  f"train_epoch_end_counter={self.train_epoch_end_counter}, "
                  f"validation_epoch_end_counter={self.validation_epoch_end_counter} \n")
            assert self.train_epoch_end_counter <= epochs
            assert self.epoch_end_counter == self.train_epoch_end_counter + self.validation_epoch_end_counter

    callbacks = [MyDummyCallback()]

    # Add EarlyStopping and ModelCheckpoint callbacks
    from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
    callbacks.append(ModelCheckpoint(monitor='val_loss', mode="min",
                                     save_top_k=1, verbose=True))

    from pytorch_lightning.callbacks.early_stopping import EarlyStopping
    callbacks.append(EarlyStopping(monitor='val_loss',
                                   min_delta=0.001,
                                   patience=3,
                                   verbose=True,
                                   mode='min'))

    torch_estimator = hvd.TorchEstimator(backend=backend,
                                         store=store,
                                         model=model,
                                         input_shapes=[[-1, 1, 28, 28]],
                                         feature_cols=['features'],
                                         label_cols=['label'],
                                         batch_size=args.batch_size,
                                         epochs=args.epochs,
                                         validation=0.1,
                                         verbose=1,
                                         callbacks=callbacks,
                                         profiler="simple" if args.enable_profiler else None)

    torch_model = torch_estimator.fit(train_df).setOutputCols(['label_prob'])

    # Evaluate the model on the held-out test DataFrame
    pred_df = torch_model.transform(test_df)

    argmax = udf(lambda v: float(np.argmax(v)), returnType=T.DoubleType())
    pred_df = pred_df.withColumn('label_pred', argmax(pred_df.label_prob))
    evaluator = MulticlassClassificationEvaluator(predictionCol='label_pred',
                                                  labelCol='label',
                                                  metricName='accuracy')
    print('Test accuracy:', evaluator.evaluate(pred_df))

    spark.stop()

def test_train_with_custom_data_module(self):
    if skip_lightning_tests:
        self.skipTest(
            'Spark PyTorch Lightning tests conflict with Tensorflow 2.5.x: '
            'https://github.com/horovod/horovod/pull/3263')

    from horovod.spark.data_loaders.pytorch_data_loaders import PytorchAsyncDataLoader

    class CustomDataModule(pl.LightningDataModule):
        """Custom DataModule for Lightning Estimator, using PytorchAsyncDataLoader"""

        def __init__(self, train_dir: str, val_dir: str, has_val: bool = True,
                     train_batch_size: int = 32, val_batch_size: int = 32,
                     shuffle_size: int = 100, num_reader_epochs=None,
                     cur_shard: int = 0, shard_count: int = 1,
                     schema_fields=None, storage_options=None,
                     steps_per_epoch_train: int = 1, steps_per_epoch_val: int = 1,
                     verbose=True, **kwargs):
            super().__init__()
            self.train_dir = train_dir
            self.val_dir = val_dir
            self.has_val = has_val
            self.train_batch_size = train_batch_size
            self.val_batch_size = val_batch_size
            self.shuffle_size = shuffle_size
            self.num_reader_epochs = num_reader_epochs
            self.cur_shard = cur_shard
            self.shard_count = shard_count
            self.schema_fields = schema_fields
            self.storage_options = storage_options
            self.steps_per_epoch_train = steps_per_epoch_train
            self.steps_per_epoch_val = steps_per_epoch_val
            self.verbose = verbose

        def setup(self, stage=None):
            # Assign train/val datasets for use in dataloaders
            from petastorm import make_batch_reader
            if stage == 'fit' or stage is None:
                self.train_reader = make_batch_reader(
                    self.train_dir, num_epochs=self.num_reader_epochs,
                    cur_shard=self.cur_shard, shard_count=self.shard_count,
                    hdfs_driver='libhdfs',
                    schema_fields=self.schema_fields,
                    storage_options=self.storage_options)
                if self.has_val:
                    self.val_reader = make_batch_reader(
                        self.val_dir, num_epochs=self.num_reader_epochs,
                        cur_shard=self.cur_shard, shard_count=self.shard_count,
                        hdfs_driver='libhdfs',
                        schema_fields=self.schema_fields,
                        storage_options=self.storage_options)

        def teardown(self, stage=None):
            if stage == "fit" or stage is None:
                if self.verbose:
                    print("Tear down petastorm readers")
                self.train_reader.stop()
                self.train_reader.join()
                if self.has_val:
                    self.val_reader.stop()
                    self.val_reader.join()

        def train_dataloader(self):
            if self.verbose:
                print("Setup train dataloader")
            kwargs = dict(reader=self.train_reader,
                          batch_size=self.train_batch_size,
                          name="train dataloader",
                          shuffling_queue_capacity=self.shuffle_size,
                          limit_step_per_epoch=self.steps_per_epoch_train,
                          verbose=self.verbose)
            return PytorchAsyncDataLoader(**kwargs)

        def val_dataloader(self):
            if not self.has_val:
                return None
            if self.verbose:
                print("Setup val dataloader")
            kwargs = dict(reader=self.val_reader,
                          batch_size=self.val_batch_size,
                          name="val dataloader",
                          shuffling_queue_capacity=0,
                          limit_step_per_epoch=self.steps_per_epoch_val,
                          verbose=self.verbose)
            return PytorchAsyncDataLoader(**kwargs)

    with spark_session('test_fit_model') as spark:
        df = create_noisy_xor_data(spark)
        model = create_xor_model()

        with local_store() as store:
            torch_estimator = hvd_spark.TorchEstimator(
                num_proc=2,
                store=store,
                model=model,
                input_shapes=[[-1, 2]],
                feature_cols=['features'],
                label_cols=['y'],
                validation=0.2,
                batch_size=4,
                epochs=2,
                data_module=CustomDataModule,
                verbose=2)

            torch_model = torch_estimator.fit(df)

            # TODO: Find a way to pass log metrics from remote, and assert based on the logger.
            trained_model = torch_model.getModel()
            pred = trained_model(torch.ones([1, 2], dtype=torch.int32))
            assert len(pred) == 1
            assert pred.dtype == torch.float32

def train_model(args):
    # Do not run this test for pytorch lightning below the min supported version
    import pytorch_lightning as pl
    if LooseVersion(pl.__version__) < LooseVersion(MIN_PL_VERSION):
        print("Skip test for pytorch_lightning=={}, min supported version is {}".
              format(pl.__version__, MIN_PL_VERSION))
        return

    # Initialize SparkSession
    conf = SparkConf().setAppName('pytorch_spark_mnist').set(
        'spark.sql.shuffle.partitions', '16')
    if args.master:
        conf.setMaster(args.master)
    elif args.num_proc:
        conf.setMaster('local[{}]'.format(args.num_proc))
    spark = SparkSession.builder.config(conf=conf).getOrCreate()

    # Setup our store for intermediate data
    store = Store.create(args.work_dir)

    # Download MNIST dataset
    data_url = 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/mnist.bz2'
    libsvm_path = os.path.join(args.data_dir, 'mnist.bz2')
    if not os.path.exists(libsvm_path):
        subprocess.check_output(['wget', data_url, '-O', libsvm_path])

    # Load dataset into a Spark DataFrame
    df = spark.read.format('libsvm') \
        .option('numFeatures', '784') \
        .load(libsvm_path)

    # One-hot encode labels into SparseVectors
    encoder = OneHotEncoderEstimator(inputCols=['label'],
                                     outputCols=['label_vec'],
                                     dropLast=False)
    model = encoder.fit(df)
    train_df = model.transform(df)

    # Train/test split
    train_df, test_df = train_df.randomSplit([0.9, 0.1])

    # Define the PyTorch model without any Horovod-specific parameters
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
            self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
            self.conv2_drop = nn.Dropout2d()
            self.fc1 = nn.Linear(320, 50)
            self.fc2 = nn.Linear(50, 10)

        def forward(self, x):
            x = x.float()
            x = F.relu(F.max_pool2d(self.conv1(x), 2))
            x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
            x = x.view(-1, 320)
            x = F.relu(self.fc1(x))
            x = F.dropout(x, training=self.training)
            x = self.fc2(x)
            return F.log_softmax(x, dim=1)

    model = Net()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    loss = nn.NLLLoss()

    # Train a Horovod Spark Estimator on the DataFrame
    torch_estimator = hvd.TorchEstimator(
        num_proc=args.num_proc,
        store=store,
        model=model,
        optimizer=optimizer,
        loss=lambda input, target: loss(input, target.long()),
        input_shapes=[[-1, 1, 28, 28]],
        feature_cols=['features'],
        label_cols=['label'],
        batch_size=args.batch_size,
        epochs=args.epochs,
        verbose=1)

    torch_model = torch_estimator.fit(train_df).setOutputCols(['label_prob'])

    # Evaluate the model on the held-out test DataFrame
    pred_df = torch_model.transform(test_df)

    argmax = udf(lambda v: float(np.argmax(v)), returnType=T.DoubleType())
    pred_df = pred_df.withColumn('label_pred', argmax(pred_df.label_prob))
    evaluator = MulticlassClassificationEvaluator(predictionCol='label_pred',
                                                  labelCol='label',
                                                  metricName='accuracy')
    print('Test accuracy:', evaluator.evaluate(pred_df))

    spark.stop()

def test_direct_parquet_train_with_no_val_column(self):
    if skip_lightning_tests:
        self.skipTest(
            'Spark PyTorch Lightning tests conflict with Tensorflow 2.5.x: '
            'https://github.com/horovod/horovod/pull/3263')

    with spark_session('test_direct_parquet_train_with_no_val_column') as spark:
        df_train = create_noisy_xor_data(spark)
        df_val = create_noisy_xor_data(spark)

        def to_petastorm(df):
            metadata = None
            if util._has_vector_column(df):
                to_petastorm = util.to_petastorm_fn(["features", "y"], metadata)
                df = df.rdd.map(to_petastorm).toDF()
            return df

        df_train = to_petastorm(df_train)
        df_val = to_petastorm(df_val)

        df_train.show(1)
        print(df_train.count())
        df_val.show(1)
        print(df_val.count())

        backend = CallbackBackend()
        with local_store() as store:
            store.get_train_data_path = lambda v=None: store._train_path
            store.get_val_data_path = lambda v=None: store._val_path

            print(store.get_train_data_path())
            print(store.get_val_data_path())

            df_train \
                .coalesce(4) \
                .write \
                .mode('overwrite') \
                .parquet(store.get_train_data_path())
            df_val \
                .coalesce(4) \
                .write \
                .mode('overwrite') \
                .parquet(store.get_val_data_path())

            model = create_xor_model()

            inmemory_cache_all = True
            reader_pool_type = 'process'
            est = hvd_spark.TorchEstimator(
                backend=backend,
                store=store,
                model=model,
                input_shapes=[[-1, 2]],
                feature_cols=['features'],
                label_cols=['y'],
                batch_size=64,
                epochs=2,
                verbose=2,
                inmemory_cache_all=inmemory_cache_all,
                reader_pool_type=reader_pool_type)

            # Setting validation to any random string would work.
            est.setValidation("True")

            transformer = est.fit_on_parquet()
            predictions = transformer.transform(df_train)
            assert predictions.count() == df_train.count()