def test_deploy_multi_data_model(sagemaker_session):
    """Deploy an image-based MultiDataModel and check the session calls.

    The model is built from an explicit image and role (no wrapped framework
    model). The test asserts that the mocked session receives the expected
    arguments to ``create_model`` and ``endpoint_from_production_variants``.
    """
    multi_data_model = MultiDataModel(
        name=MODEL_NAME,
        model_data_prefix=VALID_MULTI_MODEL_DATA_PREFIX,
        image=IMAGE,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        env={"EXTRA_ENV_MOCK": "MockValue"},
    )
    multi_data_model.deploy(
        initial_instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
        endpoint_name=MULTI_MODEL_ENDPOINT_NAME,
    )

    # Built after deploy() so the comparison reflects the model's final state.
    expected_container_def = multi_data_model.prepare_container_def(INSTANCE_TYPE)

    sagemaker_session.create_model.assert_called_with(
        MODEL_NAME,
        ROLE,
        expected_container_def,
        vpc_config=None,
        enable_network_isolation=False,
        tags=None,
    )
    sagemaker_session.endpoint_from_production_variants.assert_called_with(
        name=MULTI_MODEL_ENDPOINT_NAME,
        wait=True,
        tags=None,
        kms_key=None,
        data_capture_config_dict=None,
        production_variants=EXPECTED_PROD_VARIANT,
    )


def test_deploy_multi_data_framework_model(sagemaker_session, mxnet_model):
    """Deploy a MultiDataModel that wraps an MXNet framework model.

    The role and network-isolation flag passed to ``create_model`` are
    expected to come from the wrapped ``mxnet_model``. ``deploy`` must not
    call ``create_endpoint_config`` directly, and it must return an
    ``MXNetPredictor``.
    """
    multi_data_model = MultiDataModel(
        name=MODEL_NAME,
        model_data_prefix=VALID_MULTI_MODEL_DATA_PREFIX,
        sagemaker_session=sagemaker_session,
        model=mxnet_model,
    )

    deployed_predictor = multi_data_model.deploy(
        initial_instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
        endpoint_name=MULTI_MODEL_ENDPOINT_NAME,
    )

    # Role and isolation flag are expected to match the wrapped MXNet model.
    expected_container_def = multi_data_model.prepare_container_def(INSTANCE_TYPE)
    sagemaker_session.create_model.assert_called_with(
        MODEL_NAME,
        MXNET_ROLE,
        expected_container_def,
        vpc_config=None,
        enable_network_isolation=True,
        tags=None,
    )
    sagemaker_session.endpoint_from_production_variants.assert_called_with(
        name=MULTI_MODEL_ENDPOINT_NAME,
        wait=True,
        tags=None,
        kms_key=None,
        data_capture_config_dict=None,
        production_variants=EXPECTED_PROD_VARIANT,
    )
    sagemaker_session.create_endpoint_config.assert_not_called()
    assert isinstance(deployed_predictor, MXNetPredictor)


def test_prepare_container_def_mxnet(sagemaker_session, mxnet_model):
    """Check the container definition of a MultiDataModel wrapping MXNet.

    The definition should use the MXNet image and the multi-model data
    prefix, run in multi-model mode, and have exactly the environment
    variable names checked below.
    """
    # Environment variable names expected from the wrapped MXNet model.
    expected_env_keys = {
        "SAGEMAKER_CONTAINER_LOG_LEVEL",
        "SAGEMAKER_ENABLE_CLOUDWATCH_METRICS",
        "SAGEMAKER_PROGRAM",
        "SAGEMAKER_REGION",
        "SAGEMAKER_SUBMIT_DIRECTORY",
    }
    multi_data_model = MultiDataModel(
        name=MODEL_NAME,
        model_data_prefix=VALID_MULTI_MODEL_DATA_PREFIX,
        sagemaker_session=sagemaker_session,
        model=mxnet_model,
    )

    container = multi_data_model.prepare_container_def(INSTANCE_TYPE)

    assert container["Image"] == MXNET_IMAGE
    assert container["ModelDataUrl"] == VALID_MULTI_MODEL_DATA_PREFIX
    assert container["Mode"] == MULTI_MODEL_CONTAINER_MODE
    # The key set must match exactly: none missing, no extras.
    assert set(container["Environment"]) == expected_env_keys