def test_simple_torch_module(data, general_model):
    """Fit a plain SparkTorch estimator and check it emits scalar float predictions."""
    fitted = SparkTorch(
        inputCol='features',
        labelCol='label',
        predictionCol='predictions',
        torchObj=general_model,
        iters=5,
        verbose=1,
    ).fit(data)
    first_row = fitted.transform(data).take(1)[0]
    assert 'predictions' in first_row
    # Exact-type check: the prediction column must hold native Python floats.
    assert type(first_row['predictions']) is float
def test_early_stopping_async(data, general_model):
    """Training with early-stop patience enabled should still produce predictions."""
    estimator = SparkTorch(
        inputCol='features',
        labelCol='label',
        predictionCol='predictions',
        torchObj=general_model,
        iters=25,
        verbose=1,
        earlyStopPatience=2,
    )
    fitted = estimator.fit(data)
    sample = fitted.transform(data).take(1)
    assert 'predictions' in sample[0]
def test_lazy(lazy_model, data):
    """A lazily-constructed model should train and yield scalar float predictions."""
    fitted = SparkTorch(
        inputCol='features',
        labelCol='label',
        predictionCol='predictions',
        torchObj=lazy_model,
        verbose=1,
        iters=5,
    ).fit(data)
    head = fitted.transform(data).take(1)[0]
    assert 'predictions' in head
    # Exact-type check: predictions must be native Python floats.
    assert type(head['predictions']) is float
def test_validation_pct(data, general_model):
    """Holding out a validation fraction should not break prediction output."""
    estimator = SparkTorch(
        inputCol='features',
        labelCol='label',
        predictionCol='predictions',
        torchObj=general_model,
        iters=10,
        verbose=1,
        partitions=2,
        validationPct=0.25,
    )
    sample = estimator.fit(data).transform(data).take(1)
    assert 'predictions' in sample[0]
def test_barrier(data, general_model):
    """Barrier-mode execution across partitions should still produce predictions."""
    estimator = SparkTorch(
        inputCol='features',
        labelCol='label',
        predictionCol='predictions',
        torchObj=general_model,
        iters=5,
        verbose=1,
        partitions=2,
        useBarrier=True,
    )
    sample = estimator.fit(data).transform(data).take(1)
    assert 'predictions' in sample[0]
def test_simple_hogwild(data, sequential_model):
    """Hogwild training mode should produce scalar float predictions."""
    fitted = SparkTorch(
        inputCol='features',
        labelCol='label',
        predictionCol='predictions',
        torchObj=sequential_model,
        verbose=1,
        mode='hogwild',
        iters=5,
    ).fit(data)
    head = fitted.transform(data).take(1)[0]
    assert 'predictions' in head
    # Exact-type check: predictions must be native Python floats.
    assert type(head['predictions']) is float
def test_mini_batch(data, general_model):
    """Mini-batch training with lock acquisition should still emit predictions."""
    estimator = SparkTorch(
        inputCol='features',
        labelCol='label',
        predictionCol='predictions',
        torchObj=general_model,
        iters=10,
        verbose=1,
        partitions=2,
        miniBatch=5,
        acquireLock=True,
    )
    sample = estimator.fit(data).transform(data).take(1)
    assert 'predictions' in sample[0]
def test_inference(lazy_model, data):
    """A fitted estimator and a model rebuilt from its PyTorch weights must agree."""
    fitted = SparkTorch(
        inputCol='features',
        labelCol='label',
        predictionCol='predictions',
        torchObj=lazy_model,
        verbose=1,
        iters=10,
    ).fit(data)
    baseline = fitted.transform(data).take(1)
    # Extract the trained torch model, then re-wrap it as a standalone transformer.
    torch_model = fitted.getPytorchModel()
    rewrapped = create_spark_torch_model(torch_model, 'features', 'predictions')
    rewrapped_output = rewrapped.transform(data).take(1)
    assert baseline == rewrapped_output
def test_classification(data):
    """A serialized classification net should train and produce predictions."""
    serialized = serialize_torch_obj(
        ClassificationNet(),
        nn.CrossEntropyLoss(),
        torch.optim.Adam,
        lr=0.001,
    )
    estimator = SparkTorch(
        inputCol='features',
        labelCol='label',
        predictionCol='predictions',
        torchObj=serialized,
        iters=5,
        verbose=1,
        partitions=2,
    )
    sample = estimator.fit(data).transform(data).take(1)
    assert 'predictions' in sample[0]
def test_autoencoder(data):
    """An autoencoder with vector output should emit 10-element prediction vectors."""
    serialized = serialize_torch_obj(
        AutoEncoder(),
        nn.MSELoss(),
        torch.optim.Adam,
        lr=0.001,
    )
    # No labelCol: the autoencoder reconstructs its input features.
    estimator = SparkTorch(
        inputCol='features',
        predictionCol='predictions',
        torchObj=serialized,
        iters=5,
        verbose=1,
        partitions=2,
        useVectorOut=True,
    )
    head = estimator.fit(data).transform(data).take(1)[0]
    assert 'predictions' in head
    assert len(head['predictions']) == 10