def test_serialization_transformer_and_predict(
        spark_context, classification_model, mnist_data):
    """Save a configured transformer, reload it, and apply it to a data frame.

    The test builds a data frame from the last two entries of the
    ``mnist_data`` fixture, creates an ``ElephasTransformer`` carrying the
    weights and YAML config of ``classification_model``, writes it to
    ``test.h5``, loads it back with ``load_ml_transformer`` and calls
    ``transform`` on the data frame.

    There are no explicit assertions: the test passes when none of these
    steps raises.
    """
    # The fixture unpacks into exactly four values; only the last two are used.
    _unused_first, _unused_second, features, labels = mnist_data
    test_df = to_data_frame(spark_context, features, labels, categorical=True)

    model_weights = classification_model.get_weights()
    original = ElephasTransformer(
        weights=model_weights,
        model_type=ModelType.CLASSIFICATION,
    )
    yaml_config = classification_model.to_yaml()
    original.set_keras_model_config(yaml_config)

    # Round-trip through the HDF5 file, then use the reloaded transformer.
    original.save("test.h5")
    restored = load_ml_transformer("test.h5")
    restored.transform(test_df)
def test_serialization_transformer_config_roundtrip(classification_model):
    """Check that a saved and reloaded transformer keeps its Keras config.

    The YAML config of ``classification_model`` is stored on a fresh
    ``ElephasTransformer``, the transformer is written to ``test.h5`` and
    loaded back with ``load_ml_transformer``. The YAML of the model returned
    by the loaded transformer's ``get_model()`` must equal the YAML that was
    stored.

    This test has its own name on purpose: a later function in this module
    is also called ``test_serialization_transformer``. With the same name,
    that later definition would rebind the module attribute and pytest would
    never collect this test or run its assertion.
    """
    # Compute the YAML once so the assertion compares against exactly what
    # was handed to the transformer.
    expected_yaml = classification_model.to_yaml()

    transformer = ElephasTransformer()
    transformer.set_keras_model_config(expected_yaml)
    transformer.save("test.h5")

    loaded_transformer = load_ml_transformer("test.h5")
    assert loaded_transformer.get_model().to_yaml() == expected_yaml
def test_serialization_transformer():
    """Smoke-test saving a transformer to ``test.h5`` and loading it back.

    A default ``ElephasTransformer`` is given the YAML config of ``model``,
    saved, and reloaded with ``load_ml_transformer``. The loaded object is
    not inspected; the test passes when no step raises.
    """
    fresh = ElephasTransformer()
    # NOTE(review): `model` is neither a parameter nor a local here; it is
    # presumably a module-level Keras model defined elsewhere -- confirm.
    keras_yaml = model.to_yaml()
    fresh.set_keras_model_config(keras_yaml)

    h5_path = "test.h5"
    fresh.save(h5_path)
    load_ml_transformer(h5_path)