Example #1
def test_local_mode_serving_from_s3_model(sagemaker_local_session,
                                          mxnet_model):
    local_mode_lock_fd = open(LOCK_PATH, 'w')
    local_mode_lock = local_mode_lock_fd.fileno()

    model_data = mxnet_model.model_data
    boto_session = sagemaker_local_session.boto_session
    default_bucket = sagemaker_local_session.default_bucket()
    uploaded_data = tar_and_upload_dir(boto_session, default_bucket,
                                       'test_mxnet_local_mode', '', model_data)

    s3_model = MXNetModel(model_data=uploaded_data.s3_prefix,
                          role='SageMakerRole',
                          entry_point=mxnet_model.entry_point,
                          image=mxnet_model.image,
                          sagemaker_session=sagemaker_local_session)

    predictor = None
    try:
        # Since Local Mode serving always binds the same host port, we take a
        # lock so that concurrently running tests do not collide on it. The
        # serving test is fast, so serializing this section is an acceptable cost.
        fcntl.lockf(local_mode_lock, fcntl.LOCK_EX)
        predictor = s3_model.deploy(initial_instance_count=1,
                                    instance_type='local')
        data = numpy.zeros(shape=(1, 1, 28, 28))
        predictor.predict(data)
    finally:
        if predictor:
            predictor.delete_endpoint()
            time.sleep(5)
        fcntl.lockf(local_mode_lock, fcntl.LOCK_UN)
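Several of the later examples wrap this same idea in a local_mode_utils.lock() context manager. A minimal sketch of such a helper, assuming an fcntl file lock like the one above (illustrative only, not the test suite's actual implementation):

import contextlib
import fcntl

LOCK_PATH = '/tmp/sagemaker_local_mode.lock'  # hypothetical lock-file path

@contextlib.contextmanager
def lock(path=LOCK_PATH):
    # Hold an exclusive file lock so that only one Local Mode server at a
    # time binds the shared serving port; released even if the body raises.
    with open(path, 'w') as lock_fd:
        fcntl.lockf(lock_fd.fileno(), fcntl.LOCK_EX)
        try:
            yield
        finally:
            fcntl.lockf(lock_fd.fileno(), fcntl.LOCK_UN)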
Example #2
def _predictor(image, framework_version, sagemaker_local_session,
               instance_type):
    model_dir = os.path.join(RESOURCE_PATH, 'model')
    source_dir = os.path.join(RESOURCE_PATH, 'scripts')

    versions_map = {
        # container version -> autogluon version
        '0.3.2': '0.3.1',
    }
    ag_framework_version = versions_map.get(framework_version,
                                            framework_version)
    model = MXNetModel(
        model_data=f"file://{model_dir}/model_{ag_framework_version}.tar.gz",
        role=ROLE,
        image_uri=image,
        sagemaker_session=sagemaker_local_session,
        source_dir=source_dir,
        entry_point="tabular_serve.py",
        framework_version="1.9.0")
    with local_mode_utils.lock():
        try:
            predictor = model.deploy(1, instance_type)
            yield predictor
        finally:
            predictor.delete_endpoint()
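_predictor is written as a generator so it can back a pytest fixture; a hedged sketch of the wiring (the fixture name and decorator placement are assumptions, not from the source):

import pytest

@pytest.fixture
def predictor(image, framework_version, sagemaker_local_session, instance_type):
    # pytest runs the code after the yield inside _predictor as teardown,
    # so the local endpoint is deleted even when the test fails.
    yield from _predictor(image, framework_version,
                          sagemaker_local_session, instance_type)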
Example #3
def test_deploy_elastic_inference_with_pretrained_model(
        pretrained_model_data, ecr_image, sagemaker_session, instance_type,
        accelerator_type):
    default_handler_path = os.path.join(RESOURCE_PATH, 'default_handlers')
    endpoint_name = 'test-mxnet-ei-deploy-model-{}'.format(
        sagemaker_timestamp())

    with timeout_and_delete_endpoint_by_name(
            endpoint_name=endpoint_name,
            sagemaker_session=sagemaker_session,
            minutes=20):
        model = MXNetModel(model_data=pretrained_model_data,
                           entry_point=os.path.join(default_handler_path,
                                                    'code', 'empty_module.py'),
                           role='SageMakerRole',
                           image=ecr_image,
                           sagemaker_session=sagemaker_session)

        logger.info('deploying model to endpoint: {}'.format(endpoint_name))
        predictor = model.deploy(initial_instance_count=1,
                                 instance_type=instance_type,
                                 accelerator_type=accelerator_type,
                                 endpoint_name=endpoint_name)

        random_input = np.zeros(shape=(1, 3, 224, 224))

        predict_response = predictor.predict(random_input.tolist())
        assert predict_response
Example #4
def test_model(sagemaker_session):
    model = MXNetModel("s3://some/data.tar.gz",
                       role=ROLE,
                       entry_point=SCRIPT_PATH,
                       sagemaker_session=sagemaker_session)
    predictor = model.deploy(1, GPU)
    assert isinstance(predictor, MXNetPredictor)
Example #6
def test_model_register(
    sagemaker_session, mxnet_inference_version, mxnet_inference_py_version, skip_if_mms_version
):
    model = MXNetModel(
        MODEL_DATA,
        role=ROLE,
        entry_point=SCRIPT_PATH,
        framework_version=mxnet_inference_version,
        py_version=mxnet_inference_py_version,
        sagemaker_session=sagemaker_session,
    )
    predictor = model.deploy(1, GPU)
    assert isinstance(predictor, MXNetPredictor)

    model_package_name = "test-mxnet-register-model"
    content_types = ["application/json"]
    response_types = ["application/json"]
    inference_instances = ["ml.m4.xlarge"]
    transform_instances = ["ml.m4.xlarge"]
    model.register(
        content_types,
        response_types,
        inference_instances,
        transform_instances,
        model_package_name=model_package_name,
    )
    sagemaker_session.create_model_package_from_containers.assert_called()
Example #7
def test_elastic_inference(ecr_image, sagemaker_session, instance_type,
                           accelerator_type, framework_version):
    entry_point = DEFAULT_SCRIPT_PATH
    image_framework, image_framework_version = get_framework_and_version_from_tag(
        ecr_image)
    if image_framework_version == "1.5.1":
        entry_point = os.path.join(DEFAULT_HANDLER_PATH, 'model', 'code',
                                   'empty_module.py')

    endpoint_name = utils.unique_name_from_base('test-mxnet-ei')

    with timeout_and_delete_endpoint_by_name(
            endpoint_name=endpoint_name,
            sagemaker_session=sagemaker_session,
            minutes=20):
        prefix = 'mxnet-serving/default-handlers'
        model_data = sagemaker_session.upload_data(path=MODEL_PATH,
                                                   key_prefix=prefix)
        model = MXNetModel(model_data=model_data,
                           entry_point=entry_point,
                           role='SageMakerRole',
                           image_uri=ecr_image,
                           framework_version=framework_version,
                           sagemaker_session=sagemaker_session)

        predictor = model.deploy(initial_instance_count=1,
                                 instance_type=instance_type,
                                 accelerator_type=accelerator_type,
                                 endpoint_name=endpoint_name)

        output = predictor.predict([[1, 2]])
        assert [[4.9999918937683105]] == output
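get_framework_and_version_from_tag above is a test helper; a minimal sketch of plausible behavior (the parsing rules and return shape are assumptions, not the helper's actual source):

import re

def get_framework_and_version_from_tag(image_uri):
    # e.g. '.../mxnet-inference:1.5.1-gpu-py36-cu101-ubuntu16.04'
    #   -> ('mxnet', '1.5.1')
    repository, tag = image_uri.rsplit('/', 1)[-1].split(':')
    framework = repository.split('-')[0]
    version = re.match(r'\d+(\.\d+)+', tag).group(0)
    return framework, version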
Example #8
def test_elastic_inference(ecr_image, sagemaker_session, instance_type,
                           accelerator_type, framework_version):
    endpoint_name = utils.unique_name_from_base('test-mxnet-ei')

    with timeout_and_delete_endpoint_by_name(
            endpoint_name=endpoint_name,
            sagemaker_session=sagemaker_session,
            minutes=20):
        prefix = 'mxnet-serving/default-handlers'
        model_data = sagemaker_session.upload_data(path=MODEL_PATH,
                                                   key_prefix=prefix)
        model = MXNetModel(model_data=model_data,
                           entry_point=SCRIPT_PATH,
                           role='SageMakerRole',
                           image=ecr_image,
                           framework_version=framework_version,
                           sagemaker_session=sagemaker_session)

        predictor = model.deploy(initial_instance_count=1,
                                 instance_type=instance_type,
                                 accelerator_type=accelerator_type,
                                 endpoint_name=endpoint_name)

        output = predictor.predict([[1, 2]])
        assert [[4.9999918937683105]] == output
Example #9
def _test_sm_trained_model(sagemaker_session, ecr_image, instance_type, framework_version):
    model_dir = os.path.join(RESOURCE_PATH, 'model')
    source_dir = os.path.join(RESOURCE_PATH, 'scripts')

    endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-autogluon-serving-trained-model")
    ag_framework_version = '0.3.1' if framework_version == '0.3.2' else framework_version
    model_data = sagemaker_session.upload_data(
        path=os.path.join(model_dir, f'model_{ag_framework_version}.tar.gz'),
        key_prefix='sagemaker-autogluon-serving-trained-model/models')

    model = MXNetModel(
        model_data=model_data,
        role='SageMakerRole',
        image_uri=ecr_image,
        sagemaker_session=sagemaker_session,
        source_dir=source_dir,
        entry_point="tabular_serve.py",
        framework_version="1.8.0"
    )

    with timeout_and_delete_endpoint(endpoint_name, sagemaker_session, minutes=30):
        predictor = model.deploy(
            initial_instance_count=1,
            instance_type=instance_type,
            endpoint_name=endpoint_name,
        )
        predictor.serializer = CSVSerializer()
        predictor.deserializer = JSONDeserializer()

        data_path = os.path.join(RESOURCE_PATH, 'data')
        data = pd.read_csv(f'{data_path}/data.csv')
        assert 3 == len(data)

        preds = predictor.predict(data.values)
        assert preds == [' <=50K', ' <=50K', ' <=50K']
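With CSVSerializer and JSONDeserializer attached as above, predictor.predict(data.values) posts the rows as CSV and parses the JSON reply. A hedged, roughly equivalent hand-rolled request against the same endpoint (the boto3 client and content types follow the SageMaker runtime API; the exact payload the container expects is an assumption):

import io
import json

import boto3
import numpy as np

runtime = boto3.client('sagemaker-runtime')
csv_buf = io.StringIO()
np.savetxt(csv_buf, data.values, fmt='%s', delimiter=',')  # one row per line
response = runtime.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType='text/csv',
    Accept='application/json',
    Body=csv_buf.getvalue(),
)
preds = json.loads(response['Body'].read())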
Example #10
def test_model(sagemaker_session):
    model = MXNetModel(MODEL_DATA,
                       role=ROLE,
                       entry_point=SCRIPT_PATH,
                       sagemaker_session=sagemaker_session)
    predictor = model.deploy(1, GPU)
    assert isinstance(predictor, MXNetPredictor)
Example #11
def test_model_mms_version(repack_model, sagemaker_session):
    model_kms_key = "kms-key"
    model = MXNetModel(
        MODEL_DATA,
        role=ROLE,
        entry_point=SCRIPT_PATH,
        framework_version=MXNetModel._LOWEST_MMS_VERSION,
        sagemaker_session=sagemaker_session,
        name="test-mxnet-model",
        model_kms_key=model_kms_key,
    )
    predictor = model.deploy(1, GPU)

    repack_model.assert_called_once_with(
        inference_script=SCRIPT_PATH,
        source_directory=None,
        dependencies=[],
        model_uri=MODEL_DATA,
        repacked_model_uri="s3://mybucket/test-mxnet-model/model.tar.gz",
        sagemaker_session=sagemaker_session,
        kms_key=model_kms_key,
    )

    assert model.model_data == MODEL_DATA
    assert model.repacked_model_data == "s3://mybucket/test-mxnet-model/model.tar.gz"
    assert model.uploaded_code == UploadedCode(
        s3_prefix="s3://mybucket/test-mxnet-model/model.tar.gz",
        script_name=os.path.basename(SCRIPT_PATH),
    )
    assert isinstance(predictor, MXNetPredictor)
Example #12
def test_model(
    sagemaker_session, mxnet_inference_version, mxnet_inference_py_version, skip_if_mms_version
):
    model = MXNetModel(
        MODEL_DATA,
        role=ROLE,
        entry_point=SCRIPT_PATH,
        framework_version=mxnet_inference_version,
        py_version=mxnet_inference_py_version,
        sagemaker_session=sagemaker_session,
    )
    predictor = model.deploy(1, GPU)
    assert isinstance(predictor, MXNetPredictor)

    model_package_name = "test-mxnet-register-model"
    content_types = ["application/json"]
    response_types = ["application/json"]
    inference_instances = ["ml.m4.xlarge"]
    transform_instances = ["ml.m4.xlarge"]

    dummy_metrics_source = MetricsSource(
        content_type="a",
        s3_uri="s3://b/c",
        content_digest="d",
    )
    model_metrics = ModelMetrics(
        model_statistics=dummy_metrics_source,
        model_constraints=dummy_metrics_source,
        model_data_statistics=dummy_metrics_source,
        model_data_constraints=dummy_metrics_source,
        bias=dummy_metrics_source,
        explainability=dummy_metrics_source,
    )
    model.register(
        content_types,
        response_types,
        inference_instances,
        transform_instances,
        model_package_name=model_package_name,
        model_metrics=model_metrics,
        marketplace_cert=True,
        approval_status="Approved",
        description="description",
    )
    expected_create_model_package_request = {
        "containers": ANY,
        "content_types": content_types,
        "response_types": response_types,
        "inference_instances": inference_instances,
        "transform_instances": transform_instances,
        "model_package_name": model_package_name,
        "model_metrics": model_metrics._to_request_dict(),
        "marketplace_cert": True,
        "approval_status": "Approved",
        "description": "description",
    }
    sagemaker_session.create_model_package_from_containers.assert_called_with(
        **expected_create_model_package_request
    )
Example #13
def test_model_prepare_container_def_no_instance_type_or_image():
    model = MXNetModel(MODEL_DATA, role=ROLE, entry_point=SCRIPT_PATH)

    with pytest.raises(ValueError) as e:
        model.prepare_container_def()

    expected_msg = "Must supply either an instance type (for choosing CPU vs GPU) or an image URI."
    assert expected_msg in str(e)
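A hedged sketch of the two ways to satisfy prepare_container_def and avoid that ValueError (the instance type and image URI below are placeholders):

# Option 1: pass an instance type so the SDK can pick a CPU or GPU image.
container_def = model.prepare_container_def(instance_type='ml.c4.xlarge')

# Option 2: pin the image explicitly when constructing the model.
model_with_image = MXNetModel(
    MODEL_DATA,
    role=ROLE,
    entry_point=SCRIPT_PATH,
    image_uri='123456789012.dkr.ecr.us-west-2.amazonaws.com/my-mxnet:latest',
)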
Example #14
def test_model_image_accelerator(sagemaker_session):
    model = MXNetModel(MODEL_DATA,
                       role=ROLE,
                       entry_point=SCRIPT_PATH,
                       sagemaker_session=sagemaker_session)
    container_def = model.prepare_container_def(
        INSTANCE_TYPE, accelerator_type=ACCELERATOR_TYPE)
    assert container_def['Image'] == _get_full_image_uri_with_ei(
        defaults.MXNET_VERSION)
Example #15
def test_model_image_accelerator_mms_version(sagemaker_session):
    model = MXNetModel(MODEL_DATA,
                       role=ROLE,
                       entry_point=SCRIPT_PATH,
                       framework_version=MXNetModel._LOWEST_MMS_VERSION,
                       sagemaker_session=sagemaker_session)
    container_def = model.prepare_container_def(
        INSTANCE_TYPE, accelerator_type=ACCELERATOR_TYPE)
    assert container_def['Image'] == _get_full_image_uri_with_ei(
        MXNetModel._LOWEST_MMS_VERSION, IMAGE_REPO_SERVING_NAME)
Example #16
def test_model(sagemaker_session, mxnet_inference_version,
               mxnet_inference_py_version, skip_if_mms_version):
    model = MXNetModel(
        MODEL_DATA,
        role=ROLE,
        entry_point=SCRIPT_PATH,
        framework_version=mxnet_inference_version,
        py_version=mxnet_inference_py_version,
        sagemaker_session=sagemaker_session,
    )
    predictor = model.deploy(1, GPU)
    assert isinstance(predictor, MXNetPredictor)
Example #17
def test_onnx_import(docker_image, sagemaker_local_session, local_instance_type):
    model_path = 'file://{}'.format(os.path.join(ONNX_PATH, 'onnx_model'))
    m = MXNetModel(model_path, 'SageMakerRole', SCRIPT_PATH, image=docker_image,
                   sagemaker_session=sagemaker_local_session,
                   model_server_workers=NUM_MODEL_SERVER_WORKERS)

    input_data = numpy.zeros(shape=(1, 1, 28, 28))

    with local_mode_utils.lock():
        try:
            predictor = m.deploy(1, local_instance_type)
            output = predictor.predict(input_data)
        finally:
            sagemaker_local_session.delete_endpoint(m.endpoint_name)

    # Check that there is a probability for each possible class in the prediction
    assert len(output[0]) == 10
Example #18
def mxnet_model(sagemaker_session):
    return MXNetModel(
        MXNET_MODEL_DATA,
        role=MXNET_ROLE,
        entry_point=ENTRY_POINT,
        sagemaker_session=sagemaker_session,
        name=MXNET_MODEL_NAME,
        enable_network_isolation=True,
    )
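This fixture builds the model with network isolation enabled; a hedged spot-check (enable_network_isolation() is the SDK's accessor method; calling the fixture function directly is for illustration only):

# Given a model constructed as above, isolation should be reported as on.
assert mxnet_model(sagemaker_session).enable_network_isolation()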
Example #19
def test_model_mms_version(repack_model, sagemaker_session):
    model = MXNetModel(MODEL_DATA,
                       role=ROLE,
                       entry_point=SCRIPT_PATH,
                       framework_version=MXNetModel._LOWEST_MMS_VERSION,
                       sagemaker_session=sagemaker_session)
    predictor = model.deploy(1, GPU)

    repack_model.assert_called_once_with(inference_script=SCRIPT_PATH,
                                         source_directory=None,
                                         model_uri=MODEL_DATA,
                                         sagemaker_session=sagemaker_session)

    assert model.model_data == MODEL_DATA
    assert model.repacked_model_data == REPACKED_MODEL_DATA
    assert model.uploaded_code == UploadedCode(
        s3_prefix=REPACKED_MODEL_DATA,
        script_name=os.path.basename(SCRIPT_PATH))
    assert isinstance(predictor, MXNetPredictor)
Example #20
def test_model_py2_warning(warning, sagemaker_session):
    model = MXNetModel(
        MODEL_DATA,
        role=ROLE,
        entry_point=SCRIPT_PATH,
        sagemaker_session=sagemaker_session,
        py_version="py2",
    )
    assert model.py_version == "py2"
    warning.assert_called_with(model.__framework_name__, defaults.LATEST_PY2_VERSION)
Example #21
def test_onnx_import(docker_image, sagemaker_local_session,
                     local_instance_type):
    model = MXNetModel('file://{}'.format(MODEL_PATH),
                       'SageMakerRole',
                       SCRIPT_PATH,
                       image=docker_image,
                       sagemaker_session=sagemaker_local_session)

    input_data = numpy.zeros(shape=(1, 1, 28, 28))

    with local_mode_utils.lock():
        try:
            predictor = model.deploy(1, local_instance_type)
            output = predictor.predict(input_data)
        finally:
            predictor.delete_endpoint()

    # Check that there is a probability for each possible class in the prediction
    assert len(output[0]) == 10
Example #22
def test_model_empty_framework_version(warning, sagemaker_session):
    model = MXNetModel(
        MODEL_DATA,
        role=ROLE,
        entry_point=SCRIPT_PATH,
        sagemaker_session=sagemaker_session,
        framework_version=None,
    )
    assert model.framework_version == defaults.MXNET_VERSION
    warning.assert_called_with(defaults.MXNET_VERSION, defaults.LATEST_VERSION)
Example #23
def _predictor(image, framework_version, sagemaker_local_session, instance_type):
    model_dir = os.path.join(RESOURCE_PATH, 'model')
    source_dir = os.path.join(RESOURCE_PATH, 'scripts')

    model = MXNetModel(
        model_data=f"file://{model_dir}/model.tar.gz",
        role=ROLE,
        image_uri=image,
        sagemaker_session=sagemaker_local_session,
        source_dir=source_dir,
        entry_point="tabular_serve.py",
        framework_version="1.8.0"
    )
    with local_mode_utils.lock():
        try:
            predictor = model.deploy(1, instance_type)
            yield predictor
        finally:
            predictor.delete_endpoint()
Example #24
def test_model_image_accelerator(
    retrieve_image_uri,
    repack_model,
    tar_and_upload,
    sagemaker_session,
    mxnet_eia_version,
    mxnet_eia_py_version,
):
    model = MXNetModel(
        MODEL_DATA,
        role=ROLE,
        entry_point=SCRIPT_PATH,
        framework_version=mxnet_eia_version,
        py_version=mxnet_eia_py_version,
        sagemaker_session=sagemaker_session,
    )
    container_def = model.prepare_container_def(INSTANCE_TYPE, accelerator_type=ACCELERATOR_TYPE)
    assert container_def["Image"] == IMAGE
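    # Exactly one packaging path should have run: MMS-era images repack the
    # model archive to include the entry point (repack_model), while pre-MMS
    # images tar and upload the source separately (tar_and_upload); the XOR
    # below encodes that the two cases are mutually exclusive.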
    assert _is_mms_version(mxnet_eia_version) ^ (tar_and_upload.called and not repack_model.called)
Example #25
def test_elastic_inference():
    endpoint_name = utils.unique_name_from_base('mx-p3-8x-resnet')
    instance_type = 'ml.p3.8xlarge'
    framework_version = '1.4.1'

    maeve_client = boto3.client(
        "maeve",
        "us-west-2",
        endpoint_url="https://maeve.loadtest.us-west-2.ml-platform.aws.a2z.com"
    )
    runtime_client = boto3.client(
        "sagemaker-runtime",
        "us-west-2",
        endpoint_url="https://maeveruntime.loadtest.us-west-2.ml-platform.aws.a2z.com")

    sagemaker_session = session.Session(
        sagemaker_client=maeve_client, sagemaker_runtime_client=runtime_client)

    with timeout_and_delete_endpoint_by_name(
            endpoint_name=endpoint_name,
            sagemaker_session=sagemaker_session,
            minutes=20):
        prefix = 'mxnet-serving/default-handlers'
        model_data = sagemaker_session.upload_data(path=MODEL_PATH,
                                                   key_prefix=prefix)
        model = MXNetModel(
            model_data=model_data,
            entry_point=SCRIPT_PATH,
            role='arn:aws:iam::841569659894:role/sagemaker-access-role',
            image='763104351884.dkr.ecr.us-west-2.amazonaws.com/mxnet-inference:1.4.1-gpu-py36-cu100-ubuntu16.04',
            framework_version=framework_version,
            py_version='py3',
            sagemaker_session=sagemaker_session)

        predictor = model.deploy(initial_instance_count=1,
                                 instance_type=instance_type,
                                 endpoint_name=endpoint_name)

        output = predictor.predict([[1, 2]])
        assert [[4.9999918937683105]] == output
Example #26
def mxnet_model(sagemaker_session):
    return MXNetModel(
        MXNET_MODEL_DATA,
        entry_point=ENTRY_POINT,
        framework_version=MXNET_FRAMEWORK_VERSION,
        py_version=MXNET_PY_VERSION,
        role=MXNET_ROLE,
        sagemaker_session=sagemaker_session,
        name=MXNET_MODEL_NAME,
        enable_network_isolation=True,
    )
Example #27
def test_model_custom_serialization(
    sagemaker_session, mxnet_inference_version, mxnet_inference_py_version, skip_if_mms_version
):
    model = MXNetModel(
        MODEL_DATA,
        role=ROLE,
        entry_point=SCRIPT_PATH,
        framework_version=mxnet_inference_version,
        py_version=mxnet_inference_py_version,
        sagemaker_session=sagemaker_session,
    )
    custom_serializer = Mock()
    custom_deserializer = Mock()
    predictor = model.deploy(
        1,
        CPU,
        serializer=custom_serializer,
        deserializer=custom_deserializer,
    )
    assert isinstance(predictor, MXNetPredictor)
    assert predictor.serializer is custom_serializer
    assert predictor.deserializer is custom_deserializer
Example #29
import tarfile
import time

import mxnet as mx
import numpy as np

from sagemaker.session import Session
from sagemaker.mxnet import MXNetModel
from mxnet.gluon.data.vision import transforms

mx.test_utils.download('https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet50v2/resnet50v2.onnx')

with tarfile.open('onnx_model.tar.gz', mode='w:gz') as archive:
    archive.add('resnet50v2.onnx')

model_data = Session().upload_data(path='onnx_model.tar.gz', key_prefix='model')
role = 'arn:aws:iam::841569659894:role/sagemaker-access-role'

mxnet_model = MXNetModel(model_data=model_data,
                         entry_point='resnet50.py',
                         role=role,
                         image='763104351884.dkr.ecr.us-west-2.amazonaws.com/mxnet-inference:1.4.1-gpu-py36-cu100-ubuntu16.04',
                         py_version='py3',
                         framework_version='1.4.1')

predictor = mxnet_model.deploy(initial_instance_count=1, instance_type='ml.p3.8xlarge')

def do_pred():
    data = np.random.rand(1, 3, 224, 224)
    start_time = time.time()
    scores = predictor.predict(data)
    end_time = time.time()
    
    return end_time-start_time

costtime = do_pred()
print("this run cost {:.3f}s".format(costtime))
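The script above times a single cold call and leaves the endpoint running; a hedged follow-up that averages a few warm calls and then tears the endpoint down (delete_endpoint is the SDK's cleanup method):

# Average several warm calls to factor out first-request initialization.
latencies = [do_pred() for _ in range(10)]
print("mean latency over 10 calls: {:.3f}s".format(sum(latencies) / len(latencies)))

# Delete the endpoint so the ml.p3.8xlarge instance stops billing.
predictor.delete_endpoint()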
Example #30
def test_model_register_all_args(
    sagemaker_session,
    mxnet_inference_version,
    mxnet_inference_py_version,
    skip_if_mms_version,
):
    model = MXNetModel(
        MODEL_DATA,
        role=ROLE,
        entry_point=SCRIPT_PATH,
        framework_version=mxnet_inference_version,
        py_version=mxnet_inference_py_version,
        sagemaker_session=sagemaker_session,
    )
    predictor = model.deploy(1, GPU)
    assert isinstance(predictor, MXNetPredictor)

    model_package_name = "test-mxnet-register-model"
    content_types = ["application/json"]
    response_types = ["application/json"]
    inference_instances = ["ml.m4.xlarge"]
    transform_instances = ["ml.m4.xlarge"]

    dummy_metrics_source = MetricsSource(
        content_type="a",
        s3_uri="s3://b/c",
        content_digest="d",
    )
    dummy_file_source = FileSource(
        content_type="a",
        s3_uri="s3://b/c",
        content_digest="d",
    )
    model_metrics = ModelMetrics(
        model_statistics=dummy_metrics_source,
        model_constraints=dummy_metrics_source,
        model_data_statistics=dummy_metrics_source,
        model_data_constraints=dummy_metrics_source,
        bias=dummy_metrics_source,
        bias_pre_training=dummy_metrics_source,
        bias_post_training=dummy_metrics_source,
        explainability=dummy_metrics_source,
    )
    drift_check_baselines = DriftCheckBaselines(
        model_statistics=dummy_metrics_source,
        model_constraints=dummy_metrics_source,
        model_data_statistics=dummy_metrics_source,
        model_data_constraints=dummy_metrics_source,
        bias_config_file=dummy_file_source,
        bias_pre_training_constraints=dummy_metrics_source,
        bias_post_training_constraints=dummy_metrics_source,
        explainability_constraints=dummy_metrics_source,
        explainability_config_file=dummy_file_source,
    )
    metadata_properties = MetadataProperties(
        commit_id="test-commit-id",
        repository="test-repository",
        generated_by="sagemaker-python-sdk-test",
        project_id="test-project-id",
    )
    model.register(
        content_types,
        response_types,
        inference_instances,
        transform_instances,
        model_package_name=model_package_name,
        model_metrics=model_metrics,
        metadata_properties=metadata_properties,
        marketplace_cert=True,
        approval_status="Approved",
        description="description",
        drift_check_baselines=drift_check_baselines,
    )
    expected_create_model_package_request = {
        "containers": ANY,
        "content_types": content_types,
        "response_types": response_types,
        "inference_instances": inference_instances,
        "transform_instances": transform_instances,
        "model_package_name": model_package_name,
        "model_metrics": model_metrics._to_request_dict(),
        "metadata_properties": metadata_properties._to_request_dict(),
        "marketplace_cert": True,
        "approval_status": "Approved",
        "description": "description",
        "drift_check_baselines": drift_check_baselines._to_request_dict(),
    }
    sagemaker_session.create_model_package_from_containers.assert_called_with(
        **expected_create_model_package_request
    )