Ejemplo n.º 1
0
def test_classification(data):
    """Fit SparkTorch on a classification network and check it emits predictions.

    The network is serialized with cross-entropy loss and Adam (lr=0.001).
    It is trained for 5 iterations over 2 partitions. The fitted model then
    transforms the same data, and the first output row must carry a
    'predictions' field.

    Args:
        data: presumably a pytest fixture yielding a Spark DataFrame with
            'features' and 'label' columns (TODO confirm against conftest).
    """
    torch_obj = serialize_torch_obj(
        ClassificationNet(),
        nn.CrossEntropyLoss(),
        torch.optim.Adam,
        lr=0.001,
    )

    estimator = SparkTorch(
        inputCol='features',
        labelCol='label',
        predictionCol='predictions',
        torchObj=torch_obj,
        iters=5,
        verbose=1,
        partitions=2,
    )
    fitted = estimator.fit(data)

    # take(1) returns a list; index 0 is the first transformed row.
    first_row = fitted.transform(data).take(1)[0]
    # Assumes take() yields pyspark Row objects, where `in` checks field names.
    assert 'predictions' in first_row
Ejemplo n.º 2
0
def test_autoencoder(data):
    """Fit SparkTorch on an autoencoder and check the vector-valued output.

    The autoencoder is serialized with MSE loss and Adam (lr=0.001). No
    label column is configured. `useVectorOut=True` is passed so that each
    prediction is a sequence. After fitting (5 iterations, 2 partitions) and
    transforming the same data, the first row's 'predictions' must hold
    exactly 10 elements.

    Args:
        data: presumably a pytest fixture yielding a Spark DataFrame with a
            'features' column (TODO confirm against conftest).
    """
    serialized = serialize_torch_obj(AutoEncoder(), nn.MSELoss(),
                                     torch.optim.Adam, lr=0.001)

    fitted = SparkTorch(
        inputCol='features',
        predictionCol='predictions',
        torchObj=serialized,
        iters=5,
        verbose=1,
        partitions=2,
        useVectorOut=True,
    ).fit(data)

    rows = fitted.transform(data).take(1)
    head = rows[0]
    assert 'predictions' in head
    # NOTE(review): 10 presumably matches the autoencoder's output width
    # (and the input feature size) -- confirm against AutoEncoder's definition.
    assert len(head['predictions']) == 10
Ejemplo n.º 3
0
def general_model():
    """Return a serialized `Net` paired with MSE loss and an Adam optimizer (lr=0.001)."""
    return serialize_torch_obj(Net(), nn.MSELoss(), torch.optim.Adam, lr=0.001)
Ejemplo n.º 4
0
def sequential_model():
    """Return a serialized small MLP with MSE loss and Adam (lr=0.001).

    The network is Linear(10 -> 20), then ReLU, then Linear(20 -> 1),
    wrapped in a `torch.nn.Sequential` container.
    """
    layers = [
        nn.Linear(10, 20),
        nn.ReLU(),
        nn.Linear(20, 1),
    ]
    network = torch.nn.Sequential(*layers)
    return serialize_torch_obj(network, nn.MSELoss(), torch.optim.Adam, lr=0.001)