def test_deploy_multi_data_model(sagemaker_session):
    """Deploy a MultiDataModel built from an explicit image and role.

    After ``deploy`` returns, checks the arguments most recently passed to
    the session's ``create_model`` and ``endpoint_from_production_variants``.
    """
    multi_model = MultiDataModel(
        name=MODEL_NAME,
        model_data_prefix=VALID_MULTI_MODEL_DATA_PREFIX,
        image=IMAGE,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        env={"EXTRA_ENV_MOCK": "MockValue"},
    )

    deploy_kwargs = {
        "initial_instance_count": INSTANCE_COUNT,
        "instance_type": INSTANCE_TYPE,
        "endpoint_name": MULTI_MODEL_ENDPOINT_NAME,
    }
    multi_model.deploy(**deploy_kwargs)

    # The expected container definition is recomputed after deployment and
    # compared against what the session was asked to create.
    expected_container_def = multi_model.prepare_container_def(INSTANCE_TYPE)
    sagemaker_session.create_model.assert_called_with(
        MODEL_NAME,
        ROLE,
        expected_container_def,
        vpc_config=None,
        enable_network_isolation=False,
        tags=None,
    )

    expected_endpoint_kwargs = {
        "name": MULTI_MODEL_ENDPOINT_NAME,
        "production_variants": EXPECTED_PROD_VARIANT,
        "tags": None,
        "kms_key": None,
        "wait": True,
        "data_capture_config_dict": None,
    }
    sagemaker_session.endpoint_from_production_variants.assert_called_with(
        **expected_endpoint_kwargs
    )
def test_deploy_multi_data_framework_model(sagemaker_session, mxnet_model):
    """Deploy a MultiDataModel that wraps the ``mxnet_model`` fixture.

    Checks the arguments most recently passed to the session's
    ``create_model`` and ``endpoint_from_production_variants``, that
    ``create_endpoint_config`` was never called, and that ``deploy``
    returned an ``MXNetPredictor``.
    """
    multi_model = MultiDataModel(
        name=MODEL_NAME,
        model_data_prefix=VALID_MULTI_MODEL_DATA_PREFIX,
        sagemaker_session=sagemaker_session,
        model=mxnet_model,
    )

    deploy_kwargs = {
        "initial_instance_count": INSTANCE_COUNT,
        "instance_type": INSTANCE_TYPE,
        "endpoint_name": MULTI_MODEL_ENDPOINT_NAME,
    }
    predictor = multi_model.deploy(**deploy_kwargs)

    # Assert if this is called with mxnet_model parameters
    expected_container_def = multi_model.prepare_container_def(INSTANCE_TYPE)
    sagemaker_session.create_model.assert_called_with(
        MODEL_NAME,
        MXNET_ROLE,
        expected_container_def,
        vpc_config=None,
        enable_network_isolation=True,
        tags=None,
    )

    expected_endpoint_kwargs = {
        "name": MULTI_MODEL_ENDPOINT_NAME,
        "production_variants": EXPECTED_PROD_VARIANT,
        "tags": None,
        "kms_key": None,
        "wait": True,
        "data_capture_config_dict": None,
    }
    sagemaker_session.endpoint_from_production_variants.assert_called_with(
        **expected_endpoint_kwargs
    )
    sagemaker_session.create_endpoint_config.assert_not_called()

    assert isinstance(predictor, MXNetPredictor)
def test_prepare_container_def_mxnet(sagemaker_session, mxnet_model):
    """Check the container definition of a MultiDataModel wrapping ``mxnet_model``.

    Verifies the image, model data URL and container mode, and that the
    environment holds exactly the expected set of variable names.
    """
    expected_env_names = {
        "SAGEMAKER_CONTAINER_LOG_LEVEL",
        "SAGEMAKER_ENABLE_CLOUDWATCH_METRICS",
        "SAGEMAKER_PROGRAM",
        "SAGEMAKER_REGION",
        "SAGEMAKER_SUBMIT_DIRECTORY",
    }

    multi_model = MultiDataModel(
        name=MODEL_NAME,
        model_data_prefix=VALID_MULTI_MODEL_DATA_PREFIX,
        sagemaker_session=sagemaker_session,
        model=mxnet_model,
    )
    definition = multi_model.prepare_container_def(INSTANCE_TYPE)

    assert definition["Image"] == MXNET_IMAGE
    assert definition["ModelDataUrl"] == VALID_MULTI_MODEL_DATA_PREFIX
    assert definition["Mode"] == MULTI_MODEL_CONTAINER_MODE

    # Check if the environment variables defined only for MXNetModel
    # are part of the MultiDataModel container definition
    actual_env_names = set(definition["Environment"].keys())
    assert actual_env_names == expected_env_names