# Example #1
0
def test_serialization_transformer_and_predict(spark_context,
                                               classification_model,
                                               mnist_data):
    """Round-trip an ElephasTransformer through save/load, then run transform.

    Builds a transformer from the classification model's weights and Keras
    config, saves it to an HDF5 file, reloads it with ``load_ml_transformer``
    and calls ``transform`` on a DataFrame built from the test portion
    (3rd/4th elements) of ``mnist_data``.

    The file is written into a fresh temporary directory instead of as
    ``test.h5`` in the current working directory. The test then leaves no
    artifact behind and cannot collide with other tests that use the same
    filename.
    """
    import os
    import tempfile

    _, _, x_test, y_test = mnist_data
    df = to_data_frame(spark_context, x_test, y_test, categorical=True)
    transformer = ElephasTransformer(
        weights=classification_model.get_weights(),
        model_type=ModelType.CLASSIFICATION)
    transformer.set_keras_model_config(classification_model.to_yaml())
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "test.h5")
        transformer.save(path)
        loaded_transformer = load_ml_transformer(path)
        # Transform while the file still exists, in case loading is lazy
        # and reads from disk later.
        loaded_transformer.transform(df)
# Example #2
0
def test_serialization_transformer(classification_model):
    """Check that a saved-and-reloaded ElephasTransformer keeps the model config.

    Stores the Keras config of ``classification_model`` on a transformer,
    saves it to an HDF5 file, reloads it with ``load_ml_transformer`` and
    asserts that the reloaded model's YAML matches the original's.

    The file lives in a temporary directory rather than as ``test.h5`` in
    the current working directory, so no artifact is left behind and
    concurrently running tests cannot overwrite each other's file.
    """
    import os
    import tempfile

    transformer = ElephasTransformer()
    transformer.set_keras_model_config(classification_model.to_yaml())
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "test.h5")
        transformer.save(path)
        loaded_model = load_ml_transformer(path)
        assert loaded_model.get_model().to_yaml() == classification_model.to_yaml()
def test_serialization_transformer():
    """Check that an ElephasTransformer survives a save/load round trip.

    Stores the Keras config of ``model`` on a transformer, saves it to an
    HDF5 file in a temporary directory, reloads it and asserts that the
    reloaded model's YAML matches the original's. Without that assertion a
    broken round trip could not fail this test.

    NOTE(review): ``model`` is not defined in this function. It presumably
    refers to a module-level Keras model defined elsewhere in the file;
    TODO confirm it exists.
    NOTE(review): this ``def`` reuses the name of another test function in
    this file, so whichever definition comes later replaces the earlier one
    and only one of them is collected by pytest. Consider renaming one of
    them.
    """
    import os
    import tempfile

    transformer = ElephasTransformer()
    transformer.set_keras_model_config(model.to_yaml())
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "test.h5")
        transformer.save(path)
        loaded_model = load_ml_transformer(path)
        assert loaded_model.get_model().to_yaml() == model.to_yaml()