def test_init_sagemaker_deployment_client_with_iam_role_arn_but_no_region_name_raises_exception():
    """A target URI carrying a role ARN but no region is rejected as an invalid parameter."""
    role_only_uri = "sagemaker:/arn:aws:iam::123456789012:role/dummy.company.com/assumed_role"
    expected_message = "A region name must be provided when the target_uri contains a role ARN."

    with pytest.raises(MlflowException, match=expected_message) as exc_info:
        mfs.SageMakerDeploymentClient(role_only_uri)

    assert exc_info.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
def test_initialize_sagemaker_deployment_client_with_region_name_and_iam_role_arn():
    """Region and role ARN are both parsed from the URI, even across redundant slashes."""
    role_arn = "arn:aws:iam::123456789012:role/dummy.company.com/assumed_role"
    client = mfs.SageMakerDeploymentClient(f"sagemaker:/us-east-1/////////{role_arn}")

    assert client.region_name == "us-east-1"
    assert client.assumed_role_arn == role_arn
def test_get_deployment_successful(pretrained_model, sagemaker_client):
    """get_deployment returns the same description that DescribeEndpoint reports."""
    app_name = "test-app"
    region = sagemaker_client.meta.region_name
    client = mfs.SageMakerDeploymentClient(f"sagemaker:/{region}")
    client.create_deployment(
        name=app_name,
        model_uri=pretrained_model.model_uri,
        config={"region_name": region},
    )

    actual = client.get_deployment(app_name)
    expected = sagemaker_client.describe_endpoint(EndpointName=app_name)

    assert actual == expected
def test_create_deployment_with_non_existent_assume_role_arn_raises_exception(pretrained_model):
    """Deploying with an unknown role ARN surfaces botocore's NoSuchEntity GetRole error."""
    client = mfs.SageMakerDeploymentClient(
        "sagemaker:/us-west-2/arn:aws:iam::123456789012:role/non-existent-role-arn"
    )
    expected_error = (
        r"An error occurred \(NoSuchEntity\) when calling the GetRole "
        r"operation: Role non-existent-role-arn not found"
    )

    with pytest.raises(botocore.exceptions.ClientError, match=expected_error):
        client.create_deployment(name="bad_assume_role_arn", model_uri=pretrained_model.model_uri)
def test_list_deployments_returns_all_endpoints(pretrained_model, sagemaker_client):
    """list_deployments reports every endpoint created through the client.

    The check ignores ordering. SageMaker's ListEndpoints sorts by creation time,
    descending by default, so a positional assertion would tie this test to one
    backend's sort behavior rather than to the "returns all endpoints" contract.
    """
    region_name = sagemaker_client.meta.region_name
    sagemaker_deployment_client = mfs.SageMakerDeploymentClient(f"sagemaker:/{region_name}")
    expected_names = ["test-app-1", "test-app-2"]
    for name in expected_names:
        sagemaker_deployment_client.create_deployment(
            name=name,
            model_uri=pretrained_model.model_uri,
            config=dict(region_name=region_name),
        )

    endpoints = sagemaker_deployment_client.list_deployments()

    assert len(endpoints) == len(expected_names)
    # Compare as a sorted list so the assertion does not depend on the listing order.
    assert sorted(endpoint["EndpointName"] for endpoint in endpoints) == expected_names
def sagemaker_deployment_client():
    """Return a deployment client for us-west-2 whose target URI embeds a dummy role ARN."""
    target_uri = "sagemaker:/us-west-2/arn:aws:iam::123456789012:role/assumed_role"
    return mfs.SageMakerDeploymentClient(target_uri)
def test_initialize_sagemaker_deployment_client_with_region_name():
    """A URI with only a region sets that region and leaves the assumed role unset."""
    client = mfs.SageMakerDeploymentClient("sagemaker:/us-east-1")

    assert client.region_name == "us-east-1"
    assert client.assumed_role_arn is None
def test_initialize_sagemaker_deployment_client_with_empty_path():
    """A bare "sagemaker:/" URI uses mfs.DEFAULT_REGION_NAME and assumes no role."""
    client = mfs.SageMakerDeploymentClient("sagemaker:/")

    assert client.region_name == mfs.DEFAULT_REGION_NAME
    assert client.assumed_role_arn is None
def test_get_deployment_non_existent_deployment():
    """Looking up an endpoint that was never created raises an MlflowException."""
    client = mfs.SageMakerDeploymentClient("sagemaker:/us-west-2")
    missing_name = "non-existent app"

    with pytest.raises(MlflowException, match="There was an error while"):
        client.get_deployment(missing_name)