Example #1
0
def test_model_parameters(data, network_with_params):
    """Train with a network that exposes named layers and check the layers survive training."""
    estimator = SparkTorch(inputCol='features',
                           labelCol='label',
                           predictionCol='predictions',
                           torchObj=network_with_params,
                           verbose=1,
                           iters=5)
    fitted = estimator.fit(data)

    trained = fitted.getPytorchModel()
    # Both fully-connected layers should be present on the recovered model.
    assert trained.fc1 is not None
    assert trained.fc2 is not None
Example #2
0
def test_simple_torch_module(data, general_model):
    """Smoke-test the basic fit/transform round trip with a general model."""
    estimator = SparkTorch(inputCol='features',
                           labelCol='label',
                           predictionCol='predictions',
                           torchObj=general_model,
                           iters=5,
                           verbose=1)
    fitted = estimator.fit(data)

    first_row = fitted.transform(data).take(1)[0]
    # The transform must attach a scalar float prediction column.
    assert 'predictions' in first_row
    assert type(first_row['predictions']) is float
Example #3
0
def test_early_stopping_async(data, general_model):
    """Verify training completes and predicts when early stopping is enabled."""
    estimator = SparkTorch(inputCol='features',
                           labelCol='label',
                           predictionCol='predictions',
                           torchObj=general_model,
                           iters=25,
                           verbose=1,
                           earlyStopPatience=2)
    fitted = estimator.fit(data)

    first_row = fitted.transform(data).take(1)[0]
    assert 'predictions' in first_row
Example #4
0
def test_lazy(lazy_model, data):
    """Train from a lazily-constructed model object and check scalar output."""
    estimator = SparkTorch(inputCol='features',
                           labelCol='label',
                           predictionCol='predictions',
                           torchObj=lazy_model,
                           verbose=1,
                           iters=5)
    fitted = estimator.fit(data)

    first_row = fitted.transform(data).take(1)[0]
    assert 'predictions' in first_row
    assert type(first_row['predictions']) is float
Example #5
0
def test_validation_pct(data, general_model):
    """Check that holding out a validation fraction still yields predictions."""
    estimator = SparkTorch(inputCol='features',
                           labelCol='label',
                           predictionCol='predictions',
                           torchObj=general_model,
                           iters=10,
                           verbose=1,
                           partitions=2,
                           validationPct=0.25)
    fitted = estimator.fit(data)

    first_row = fitted.transform(data).take(1)[0]
    assert 'predictions' in first_row
Example #6
0
def test_barrier(data, general_model):
    """Exercise barrier-mode execution across partitions."""
    estimator = SparkTorch(inputCol='features',
                           labelCol='label',
                           predictionCol='predictions',
                           torchObj=general_model,
                           iters=5,
                           verbose=1,
                           partitions=2,
                           useBarrier=True)
    fitted = estimator.fit(data)

    first_row = fitted.transform(data).take(1)[0]
    assert 'predictions' in first_row
Example #7
0
def test_simple_hogwild(data, sequential_model):
    """Run lock-free (hogwild) training with a sequential model and check output type."""
    estimator = SparkTorch(inputCol='features',
                           labelCol='label',
                           predictionCol='predictions',
                           torchObj=sequential_model,
                           verbose=1,
                           mode='hogwild',
                           iters=5)
    fitted = estimator.fit(data)

    first_row = fitted.transform(data).take(1)[0]
    assert 'predictions' in first_row
    assert type(first_row['predictions']) is float
Example #8
0
def test_mini_batch(data, general_model):
    """Train with mini-batching and an acquired update lock; expect predictions."""
    estimator = SparkTorch(inputCol='features',
                           labelCol='label',
                           predictionCol='predictions',
                           torchObj=general_model,
                           iters=10,
                           verbose=1,
                           partitions=2,
                           miniBatch=5,
                           acquireLock=True)
    fitted = estimator.fit(data)

    first_row = fitted.transform(data).take(1)[0]
    assert 'predictions' in first_row
Example #9
0
def test_inference(lazy_model, data):
    """Confirm that a rebuilt inference-only model reproduces the fitted model's output."""
    estimator = SparkTorch(inputCol='features',
                           labelCol='label',
                           predictionCol='predictions',
                           torchObj=lazy_model,
                           verbose=1,
                           iters=10)
    fitted = estimator.fit(data)

    baseline = fitted.transform(data).take(1)

    # Wrap the raw PyTorch model in a fresh Spark transformer and re-run inference.
    torch_model = fitted.getPytorchModel()
    inference_only = create_spark_torch_model(torch_model, 'features', 'predictions')

    rebuilt = inference_only.transform(data).take(1)
    assert baseline == rebuilt
Example #10
0
def test_classification(data):
    """Serialize a classification net with cross-entropy loss and train it end to end."""
    serialized = serialize_torch_obj(ClassificationNet(),
                                     nn.CrossEntropyLoss(),
                                     torch.optim.Adam,
                                     lr=0.001)

    estimator = SparkTorch(inputCol='features',
                           labelCol='label',
                           predictionCol='predictions',
                           torchObj=serialized,
                           iters=5,
                           verbose=1,
                           partitions=2)
    fitted = estimator.fit(data)

    first_row = fitted.transform(data).take(1)[0]
    assert 'predictions' in first_row
Example #11
0
def test_autoencoder(data):
    """Train an unsupervised autoencoder (no label column) and check the vector output size."""
    serialized = serialize_torch_obj(AutoEncoder(),
                                     nn.MSELoss(),
                                     torch.optim.Adam,
                                     lr=0.001)

    # No labelCol: the autoencoder reconstructs its input features.
    estimator = SparkTorch(inputCol='features',
                           predictionCol='predictions',
                           torchObj=serialized,
                           iters=5,
                           verbose=1,
                           partitions=2,
                           useVectorOut=True)
    fitted = estimator.fit(data)

    first_row = fitted.transform(data).take(1)[0]
    assert 'predictions' in first_row
    # Vector output should match the autoencoder's 10-dimensional reconstruction.
    assert len(first_row['predictions']) == 10