예제 #1
0
def test_model_parameters(data, network_with_params):
    """Fit SparkTorch on ``data`` and inspect the extracted PyTorch model.

    The estimator is built around the ``network_with_params`` fixture
    (``iters=5``, ``verbose=1``) and fitted on the ``data`` fixture.
    The model returned by ``getPytorchModel()`` must then have non-None
    ``fc1`` and ``fc2`` attributes.
    """
    estimator = SparkTorch(
        inputCol='features',
        labelCol='label',
        predictionCol='predictions',
        torchObj=network_with_params,
        verbose=1,
        iters=5,
    )
    fitted = estimator.fit(data)

    torch_model = fitted.getPytorchModel()
    # Both layers must exist and must not have been replaced by None.
    for layer_name in ('fc1', 'fc2'):
        assert getattr(torch_model, layer_name) is not None
예제 #2
0
def test_inference(lazy_model, data):
    """Check that a model rebuilt from the extracted network matches the fitted one.

    SparkTorch is fitted on ``data`` around the ``lazy_model`` fixture
    (``iters=10``). The first row produced by the fitted model's
    ``transform`` is compared with the first row produced by a model that
    ``create_spark_torch_model`` re-creates from ``getPytorchModel()``.
    """
    fitted = SparkTorch(
        inputCol='features',
        labelCol='label',
        predictionCol='predictions',
        torchObj=lazy_model,
        verbose=1,
        iters=10,
    ).fit(data)

    expected_rows = fitted.transform(data).take(1)

    # Round-trip: pull out the torch network and wrap it in a new Spark model.
    torch_net = fitted.getPytorchModel()
    rebuilt = create_spark_torch_model(torch_net, 'features', 'predictions')
    actual_rows = rebuilt.transform(data).take(1)

    assert expected_rows == actual_rows