# Example #1
def test_model(
    sagemaker_session, mxnet_inference_version, mxnet_inference_py_version, skip_if_mms_version
):
    """Deploy an MXNetModel, register it, and check the model-package request.

    After a smoke check that ``deploy`` returns an ``MXNetPredictor``, the
    model is registered with content/response types, instance types, model
    metrics, a marketplace certification flag, an approval status and a
    description.  The test then asserts that the (mocked) session's most
    recent ``create_model_package_from_containers`` call received exactly
    those values, with model metrics serialized via ``_to_request_dict()``.

    ``skip_if_mms_version`` is requested only for its side effect
    (presumably skipping unsupported framework versions — confirm in
    conftest); it is not referenced in the body.
    """
    model = MXNetModel(
        MODEL_DATA,
        role=ROLE,
        entry_point=SCRIPT_PATH,
        framework_version=mxnet_inference_version,
        py_version=mxnet_inference_py_version,
        sagemaker_session=sagemaker_session,
    )
    predictor = model.deploy(1, GPU)
    assert isinstance(predictor, MXNetPredictor)

    model_package_name = "test-mxnet-register-model"
    content_types = ["application/json"]
    response_types = ["application/json"]
    inference_instances = ["ml.m4.xlarge"]
    transform_instances = ["ml.m4.xlarge"]

    # One dummy source reused for every metric slot; only round-tripping
    # into the request dict matters here, not the values themselves.
    dummy_metrics_source = MetricsSource(
        content_type="a",
        s3_uri="s3://b/c",
        content_digest="d",
    )
    model_metrics = ModelMetrics(
        model_statistics=dummy_metrics_source,
        model_constraints=dummy_metrics_source,
        model_data_statistics=dummy_metrics_source,
        model_data_constraints=dummy_metrics_source,
        bias=dummy_metrics_source,
        explainability=dummy_metrics_source,
    )
    model.register(
        content_types,
        response_types,
        inference_instances,
        transform_instances,
        model_package_name=model_package_name,
        model_metrics=model_metrics,
        marketplace_cert=True,
        approval_status="Approved",
        description="description",
    )
    expected_create_model_package_request = {
        # Container definitions are assembled by the SDK; not checked here.
        "containers": ANY,
        "content_types": content_types,
        "response_types": response_types,
        "inference_instances": inference_instances,
        "transform_instances": transform_instances,
        "model_package_name": model_package_name,
        "model_metrics": model_metrics._to_request_dict(),
        "marketplace_cert": True,
        "approval_status": "Approved",
        "description": "description",
    }
    sagemaker_session.create_model_package_from_containers.assert_called_with(
        **expected_create_model_package_request
    )
# Example #2
def test_model_register_all_args(
    sagemaker_session,
    mxnet_inference_version,
    mxnet_inference_py_version,
    skip_if_mms_version,
):
    """Register an MXNetModel with every optional argument and check the request.

    Like the basic register test, but the call also supplies extended model
    metrics (pre-/post-training bias), ``MetadataProperties`` and
    ``DriftCheckBaselines`` (which mix metric and file sources).  The test
    asserts that the (mocked) session's most recent
    ``create_model_package_from_containers`` call received each value.
    Structured objects are compared in their ``_to_request_dict()`` form.

    ``skip_if_mms_version`` is requested only for its side effect
    (presumably skipping unsupported framework versions — confirm in
    conftest); it is not referenced in the body.
    """
    model = MXNetModel(
        MODEL_DATA,
        role=ROLE,
        entry_point=SCRIPT_PATH,
        framework_version=mxnet_inference_version,
        py_version=mxnet_inference_py_version,
        sagemaker_session=sagemaker_session,
    )
    predictor = model.deploy(1, GPU)
    assert isinstance(predictor, MXNetPredictor)

    model_package_name = "test-mxnet-register-model"
    content_types = ["application/json"]
    response_types = ["application/json"]
    inference_instances = ["ml.m4.xlarge"]
    transform_instances = ["ml.m4.xlarge"]

    # Shared dummy sources: drift-check baselines take FileSource for the
    # config-file slots and MetricsSource for everything else.
    dummy_metrics_source = MetricsSource(
        content_type="a",
        s3_uri="s3://b/c",
        content_digest="d",
    )
    dummy_file_source = FileSource(
        content_type="a",
        s3_uri="s3://b/c",
        content_digest="d",
    )
    model_metrics = ModelMetrics(
        model_statistics=dummy_metrics_source,
        model_constraints=dummy_metrics_source,
        model_data_statistics=dummy_metrics_source,
        model_data_constraints=dummy_metrics_source,
        bias=dummy_metrics_source,
        bias_pre_training=dummy_metrics_source,
        bias_post_training=dummy_metrics_source,
        explainability=dummy_metrics_source,
    )
    drift_check_baselines = DriftCheckBaselines(
        model_statistics=dummy_metrics_source,
        model_constraints=dummy_metrics_source,
        model_data_statistics=dummy_metrics_source,
        model_data_constraints=dummy_metrics_source,
        bias_config_file=dummy_file_source,
        bias_pre_training_constraints=dummy_metrics_source,
        bias_post_training_constraints=dummy_metrics_source,
        explainability_constraints=dummy_metrics_source,
        explainability_config_file=dummy_file_source,
    )
    metadata_properties = MetadataProperties(
        commit_id="test-commit-id",
        repository="test-repository",
        generated_by="sagemaker-python-sdk-test",
        project_id="test-project-id",
    )
    model.register(
        content_types,
        response_types,
        inference_instances,
        transform_instances,
        model_package_name=model_package_name,
        model_metrics=model_metrics,
        metadata_properties=metadata_properties,
        marketplace_cert=True,
        approval_status="Approved",
        description="description",
        drift_check_baselines=drift_check_baselines,
    )
    expected_create_model_package_request = {
        # Container definitions are assembled by the SDK; not checked here.
        "containers": ANY,
        "content_types": content_types,
        "response_types": response_types,
        "inference_instances": inference_instances,
        "transform_instances": transform_instances,
        "model_package_name": model_package_name,
        "model_metrics": model_metrics._to_request_dict(),
        "metadata_properties": metadata_properties._to_request_dict(),
        "marketplace_cert": True,
        "approval_status": "Approved",
        "description": "description",
        "drift_check_baselines": drift_check_baselines._to_request_dict(),
    }
    sagemaker_session.create_model_package_from_containers.assert_called_with(
        **expected_create_model_package_request
    )