def test_model(
    sagemaker_session, mxnet_inference_version, mxnet_inference_py_version, skip_if_mms_version
):
    """Deploy and register an MXNetModel, then check the model-package request.

    Builds an ``MXNetModel``, deploys it, and calls ``model.register``. The
    register call passes these values:

    - content and response types
    - inference and transform instance types
    - a model package name
    - model metrics
    - marketplace certification, approval status and a description

    The test then asserts that
    ``sagemaker_session.create_model_package_from_containers`` was most
    recently called with exactly those values. ``containers`` is matched with
    ``ANY`` because its contents are not what this test pins down.

    ``skip_if_mms_version`` is requested only for its fixture side effect and
    is not used in the body. Presumably it skips on certain MMS versions;
    confirm in conftest.

    NOTE(review): ``test_model`` is a very generic name. If this module
    defines another ``test_model``, the later definition silently shadows
    this one and pytest collects only that one. Consider renaming this to
    ``test_model_register``.
    """
    model = MXNetModel(
        MODEL_DATA,
        role=ROLE,
        entry_point=SCRIPT_PATH,
        framework_version=mxnet_inference_version,
        py_version=mxnet_inference_py_version,
        sagemaker_session=sagemaker_session,
    )
    predictor = model.deploy(1, GPU)
    assert isinstance(predictor, MXNetPredictor)

    model_package_name = "test-mxnet-register-model"
    content_types = ["application/json"]
    response_types = ["application/json"]
    inference_instances = ["ml.m4.xlarge"]
    # Previously "ml.m4.xlarget" (typo). Use a real instance type so the
    # fixture data is realistic.
    transform_instances = ["ml.m4.xlarge"]

    # One placeholder source reused for every metric slot. Only the
    # round-trip through register() is under test, not the values.
    dummy_metrics_source = MetricsSource(
        content_type="a",
        s3_uri="s3://b/c",
        content_digest="d",
    )
    model_metrics = ModelMetrics(
        model_statistics=dummy_metrics_source,
        model_constraints=dummy_metrics_source,
        model_data_statistics=dummy_metrics_source,
        model_data_constraints=dummy_metrics_source,
        bias=dummy_metrics_source,
        explainability=dummy_metrics_source,
    )

    model.register(
        content_types,
        response_types,
        inference_instances,
        transform_instances,
        model_package_name=model_package_name,
        model_metrics=model_metrics,
        marketplace_cert=True,
        approval_status="Approved",
        description="description",
    )

    # Every register() argument should be forwarded unchanged. model_metrics
    # is compared in its serialized request-dict form.
    expected_create_model_package_request = {
        "containers": ANY,
        "content_types": content_types,
        "response_types": response_types,
        "inference_instances": inference_instances,
        "transform_instances": transform_instances,
        "model_package_name": model_package_name,
        "model_metrics": model_metrics._to_request_dict(),
        "marketplace_cert": True,
        "approval_status": "Approved",
        "description": "description",
    }
    sagemaker_session.create_model_package_from_containers.assert_called_with(
        **expected_create_model_package_request
    )
def test_model_register_all_args(
    sagemaker_session,
    mxnet_inference_version,
    mxnet_inference_py_version,
    skip_if_mms_version,
):
    """Register an MXNetModel using every optional argument and check the request.

    This extends the basic register test with the remaining optional inputs:

    - pre/post-training bias metrics
    - ``MetadataProperties``
    - ``DriftCheckBaselines``, including ``FileSource`` config files

    The test asserts that
    ``sagemaker_session.create_model_package_from_containers`` was most
    recently called with all of them. The structured objects are compared
    via their ``_to_request_dict()`` form, and ``containers`` is matched with
    ``ANY``.

    ``skip_if_mms_version`` is requested only for its fixture side effect and
    is not used in the body. Presumably it skips on certain MMS versions;
    confirm in conftest.
    """
    model = MXNetModel(
        MODEL_DATA,
        role=ROLE,
        entry_point=SCRIPT_PATH,
        framework_version=mxnet_inference_version,
        py_version=mxnet_inference_py_version,
        sagemaker_session=sagemaker_session,
    )
    predictor = model.deploy(1, GPU)
    assert isinstance(predictor, MXNetPredictor)

    model_package_name = "test-mxnet-register-model"
    content_types = ["application/json"]
    response_types = ["application/json"]
    inference_instances = ["ml.m4.xlarge"]
    # Previously "ml.m4.xlarget" (typo). Use a real instance type so the
    # fixture data is realistic.
    transform_instances = ["ml.m4.xlarge"]

    # Placeholder sources reused for every metric/baseline slot. Only the
    # round-trip through register() is under test, not the values.
    dummy_metrics_source = MetricsSource(
        content_type="a",
        s3_uri="s3://b/c",
        content_digest="d",
    )
    dummy_file_source = FileSource(
        content_type="a",
        s3_uri="s3://b/c",
        content_digest="d",
    )
    model_metrics = ModelMetrics(
        model_statistics=dummy_metrics_source,
        model_constraints=dummy_metrics_source,
        model_data_statistics=dummy_metrics_source,
        model_data_constraints=dummy_metrics_source,
        bias=dummy_metrics_source,
        bias_pre_training=dummy_metrics_source,
        bias_post_training=dummy_metrics_source,
        explainability=dummy_metrics_source,
    )
    # The config-file slots take a FileSource; the rest take MetricsSource.
    drift_check_baselines = DriftCheckBaselines(
        model_statistics=dummy_metrics_source,
        model_constraints=dummy_metrics_source,
        model_data_statistics=dummy_metrics_source,
        model_data_constraints=dummy_metrics_source,
        bias_config_file=dummy_file_source,
        bias_pre_training_constraints=dummy_metrics_source,
        bias_post_training_constraints=dummy_metrics_source,
        explainability_constraints=dummy_metrics_source,
        explainability_config_file=dummy_file_source,
    )
    metadata_properties = MetadataProperties(
        commit_id="test-commit-id",
        repository="test-repository",
        generated_by="sagemaker-python-sdk-test",
        project_id="test-project-id",
    )

    model.register(
        content_types,
        response_types,
        inference_instances,
        transform_instances,
        model_package_name=model_package_name,
        model_metrics=model_metrics,
        metadata_properties=metadata_properties,
        marketplace_cert=True,
        approval_status="Approved",
        description="description",
        drift_check_baselines=drift_check_baselines,
    )

    # Every register() argument should be forwarded unchanged. The structured
    # objects are compared in their serialized request-dict form.
    expected_create_model_package_request = {
        "containers": ANY,
        "content_types": content_types,
        "response_types": response_types,
        "inference_instances": inference_instances,
        "transform_instances": transform_instances,
        "model_package_name": model_package_name,
        "model_metrics": model_metrics._to_request_dict(),
        "metadata_properties": metadata_properties._to_request_dict(),
        "marketplace_cert": True,
        "approval_status": "Approved",
        "description": "description",
        "drift_check_baselines": drift_check_baselines._to_request_dict(),
    }
    sagemaker_session.create_model_package_from_containers.assert_called_with(
        **expected_create_model_package_request
    )