def test_model_parameters(data, network_with_params):
    """Fit SparkTorch for a few iterations and check the trained model's layers.

    After fitting, the PyTorch model returned by ``getPytorchModel()`` must
    expose non-None ``fc1`` and ``fc2`` attributes.

    ``data`` and ``network_with_params`` are pytest fixtures defined outside
    this view (presumably a Spark DataFrame with ``features``/``label``
    columns and a serialized torch network — TODO confirm in conftest).
    """
    estimator = SparkTorch(
        inputCol='features',
        labelCol='label',
        predictionCol='predictions',
        torchObj=network_with_params,
        verbose=1,
        iters=5,
    )
    fitted = estimator.fit(data)
    torch_model = fitted.getPytorchModel()

    # Both layers must still be present on the model recovered after training.
    for layer_name in ('fc1', 'fc2'):
        assert getattr(torch_model, layer_name) is not None
def test_inference(lazy_model, data):
    """Check that a model rebuilt from the trained PyTorch module predicts identically.

    The test fits SparkTorch and records the first transformed row. It then
    extracts the PyTorch model, wraps it again with
    ``create_spark_torch_model``, and asserts that the rebuilt transformer
    produces the same first row.

    ``lazy_model`` and ``data`` are pytest fixtures defined outside this view.
    """
    fitted = SparkTorch(
        inputCol='features',
        labelCol='label',
        predictionCol='predictions',
        torchObj=lazy_model,
        verbose=1,
        iters=10,
    ).fit(data)
    # Only the first row is compared (take(1) returns a one-element list).
    expected_rows = fitted.transform(data).take(1)

    torch_model = fitted.getPytorchModel()
    rebuilt = create_spark_torch_model(torch_model, 'features', 'predictions')
    actual_rows = rebuilt.transform(data).take(1)

    assert expected_rows == actual_rows