def test_sparkml_model_deploy(sagemaker_session, cpu_instance_type):
    """Deploy an MLeap-serialized SparkML model and run inference against it.

    The model archive is uploaded to S3, a ``SparkMLModel`` pointing at it is
    deployed to a uniquely named endpoint, and two requests are sent: a
    well-formed CSV row, whose transformed output is compared to a fixed
    string, and a malformed row, for which the predictor must return ``None``.
    """
    # (column name, column type) pairs of the CSV input, in request order.
    input_columns = [
        ("Pclass", "float"),
        ("Embarked", "string"),
        ("Age", "float"),
        ("Fare", "float"),
        ("SibSp", "float"),
        ("Sex", "string"),
    ]

    model_dir = os.path.join(DATA_DIR, "sparkml_model")
    endpoint_name = f"test-sparkml-deploy-{sagemaker_timestamp()}"

    model_archive = os.path.join(model_dir, "mleap_model.tar.gz")
    model_data = sagemaker_session.upload_data(
        path=model_archive,
        key_prefix="integ-test-data/sparkml/model",
    )

    # The schema is handed to the model through an environment variable.
    schema = json.dumps(
        {
            "input": [
                {"name": column, "type": kind} for column, kind in input_columns
            ],
            "output": {"name": "features", "struct": "vector", "type": "double"},
        }
    )

    with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
        model = SparkMLModel(
            model_data=model_data,
            role="SageMakerRole",
            sagemaker_session=sagemaker_session,
            env={"SAGEMAKER_SPARKML_SCHEMA": schema},
        )
        predictor = model.deploy(
            initial_instance_count=1,
            instance_type=cpu_instance_type,
            endpoint_name=endpoint_name,
        )

        # Row whose fields line up with the declared input columns.
        valid_data = "1.0,C,38.0,71.5,1.0,female"
        expected_output = "1.0,0.0,38.0,1.0,71.5,0.0,1.0"
        assert predictor.predict(valid_data) == expected_output

        # Row whose fields do not line up with the declared input columns.
        invalid_data = "1.0,28.0,C,38.0,71.5,1.0"
        assert predictor.predict(invalid_data) is None
# Deployment script: wraps the model archive stored in S3 in a SparkMLModel
# and, when executed directly, deploys it to a SageMaker endpoint.
from pathlib import Path

import boto3
import sagemaker
from cfn_tools import load_yaml
from sagemaker.sparkml.model import SparkMLModel

# Account id of the calling identity, used to build the role ARN below.
account = boto3.client('sts').get_caller_identity()['Account']

# The stack template is resolved relative to this file: the ancestor directory
# at parents[2], then setup/stack.yaml.
with Path(__file__).parents[2].joinpath('setup/stack.yaml').open() as f:
    env = load_yaml(f)

# Values read from the stack template.
bucket_name = env['Parameters']['PreBucket']['Default']
endpoint_name = env['Mappings']['TaskMap']['predict']['name']

# Fixed deployment settings.
file_name = "model.tar.gz"
model_name = "model"
instance_type = "ml.t2.medium"
role = f'arn:aws:iam::{account}:role/sagemaker-role'

if __name__ == "__main__":
    boto_session = boto3.session.Session()
    sagemaker_session = sagemaker.Session(boto_session)

    sparkml_model = SparkMLModel(
        model_data=f"s3://{bucket_name}/{file_name}",
        role=role,
        sagemaker_session=sagemaker_session,
        name=model_name,
    )
    sparkml_model.deploy(
        initial_instance_count=1,
        instance_type=instance_type,
        endpoint_name=endpoint_name,
    )