def test_create_sagemaker_model_optional_model_params(
    base_name_from_image, name_from_base, prepare_container_def, sagemaker_session
):
    """Optional VPC / network-isolation settings are forwarded to ``create_model``.

    An explicit ``name`` is supplied, so neither name-generation helper may be
    invoked.
    """
    expected_container = {
        "Image": MODEL_IMAGE,
        "Environment": {},
        "ModelDataUrl": MODEL_DATA,
    }
    prepare_container_def.return_value = expected_container

    network_settings = {"Subnets": ["123"], "SecurityGroupIds": ["456", "789"]}
    model = Model(
        MODEL_IMAGE,
        MODEL_DATA,
        name=MODEL_NAME,
        role=ROLE,
        vpc_config=network_settings,
        enable_network_isolation=True,
        sagemaker_session=sagemaker_session,
    )
    model._create_sagemaker_model(INSTANCE_TYPE)

    # With an explicit name there is nothing to derive.
    for name_helper in (base_name_from_image, name_from_base):
        name_helper.assert_not_called()

    sagemaker_session.create_model.assert_called_with(
        MODEL_NAME,
        ROLE,
        expected_container,
        vpc_config=network_settings,
        enable_network_isolation=True,
        tags=None,
    )
def test_create_sagemaker_model_generates_model_name(
    base_name_from_image, name_from_base, prepare_container_def, sagemaker_session
):
    """Without an explicit ``name``, the model name is derived from the image.

    The base name comes from ``base_name_from_image(MODEL_IMAGE)`` and is then
    passed to ``name_from_base``. The final assertion expects ``MODEL_NAME``,
    so presumably ``name_from_base`` is patched (outside this view) to return
    it.
    """
    expected_container = dict(Image=MODEL_IMAGE, Environment={}, ModelDataUrl=MODEL_DATA)
    prepare_container_def.return_value = expected_container

    model = Model(MODEL_IMAGE, MODEL_DATA, sagemaker_session=sagemaker_session)
    model._create_sagemaker_model(INSTANCE_TYPE)

    base_name_from_image.assert_called_with(MODEL_IMAGE)
    derived_base = base_name_from_image.return_value
    name_from_base.assert_called_with(derived_base)

    # No role / VPC / tags were configured, so defaults are forwarded.
    sagemaker_session.create_model.assert_called_with(
        MODEL_NAME,
        None,
        expected_container,
        vpc_config=None,
        enable_network_isolation=False,
        tags=None,
    )
def test_create_sagemaker_model_instance_type(prepare_container_def, sagemaker_session):
    """The requested instance type is forwarded to ``prepare_container_def``.

    No accelerator is requested, so ``accelerator_type`` must be ``None``.
    """
    # Image first, then model data: the positional order used by the other
    # Model(...) calls in this module whose name assertions depend on the image.
    model = Model(MODEL_IMAGE, MODEL_DATA, sagemaker_session=sagemaker_session)
    model._create_sagemaker_model(INSTANCE_TYPE)

    prepare_container_def.assert_called_with(INSTANCE_TYPE, accelerator_type=None)
def test_create_sagemaker_model_name_from_image(
    name_from_image, prepare_container_def, sagemaker_session
):
    """The model name comes from ``name_from_image`` when none is given.

    This test was previously named ``test_create_sagemaker_model``. A later
    function in this module has that same name, so this definition was
    silently shadowed and never collected by pytest.

    NOTE(review): because it never ran, its expectations may be stale. These
    include the ``name_from_image`` helper, the (model data, image) positional
    order, and a ``prepare_container_def`` call without a
    ``serverless_inference_config`` kwarg. Confirm them against the current
    ``Model`` API.
    """
    name_from_image.return_value = MODEL_NAME

    container_def = {
        "Image": MODEL_IMAGE,
        "Environment": {},
        "ModelDataUrl": MODEL_DATA,
    }
    prepare_container_def.return_value = container_def

    model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=sagemaker_session)
    model._create_sagemaker_model(INSTANCE_TYPE)

    prepare_container_def.assert_called_with(INSTANCE_TYPE, accelerator_type=None)
    name_from_image.assert_called_with(MODEL_IMAGE)
    sagemaker_session.create_model.assert_called_with(
        MODEL_NAME,
        None,
        container_def,
        vpc_config=None,
        enable_network_isolation=False,
        tags=None,
    )
def test_create_sagemaker_model_accelerator_type(prepare_container_def, sagemaker_session):
    """An explicit accelerator type is passed through to ``prepare_container_def``."""
    eia_accelerator = "ml.eia.medium"

    model = Model(MODEL_IMAGE, MODEL_DATA, name=MODEL_NAME, sagemaker_session=sagemaker_session)
    model._create_sagemaker_model(INSTANCE_TYPE, accelerator_type=eia_accelerator)

    prepare_container_def.assert_called_with(INSTANCE_TYPE, accelerator_type=eia_accelerator)
def test_create_sagemaker_model_creates_correct_session(local_session, session):
    """The session created for a model depends on the instance type.

    ``"local"`` must yield ``local_session.return_value``, and
    ``"ml.m5.xlarge"`` must yield ``session.return_value``. Both fixtures are
    presumably patched session constructors; verify against their
    definitions.
    """
    cases = (
        ("local", local_session),
        ("ml.m5.xlarge", session),
    )
    for instance_type, session_factory in cases:
        # A fresh Model per case, created without a session.
        model = Model(MODEL_IMAGE, MODEL_DATA)
        model._create_sagemaker_model(instance_type)
        assert model.sagemaker_session == session_factory.return_value
def test_create_sagemaker_model_tags(prepare_container_def, sagemaker_session):
    """Tags passed to ``_create_sagemaker_model`` reach ``create_model``.

    NOTE(review): a single Key/Value dict is used as ``tags`` here. SageMaker
    tags are often given as a list of such dicts, so confirm which shapes are
    meant to be supported.
    """
    container = {"Image": MODEL_IMAGE, "Environment": {}, "ModelDataUrl": MODEL_DATA}
    prepare_container_def.return_value = container
    model_tags = {"Key": "foo", "Value": "bar"}

    model = Model(MODEL_IMAGE, MODEL_DATA, name=MODEL_NAME, sagemaker_session=sagemaker_session)
    model._create_sagemaker_model(INSTANCE_TYPE, tags=model_tags)

    sagemaker_session.create_model.assert_called_with(
        MODEL_NAME,
        None,
        container,
        vpc_config=None,
        enable_network_isolation=False,
        tags=model_tags,
    )
def test_create_sagemaker_model_generates_model_name_each_time(
    base_name_from_image, name_from_base, prepare_container_def, sagemaker_session
):
    """The base name is derived once, but a new name is generated on every call.

    ``_create_sagemaker_model`` is called twice on the same unnamed model.
    ``base_name_from_image`` must run exactly once, while ``name_from_base``
    must run once per call.
    """
    prepare_container_def.return_value = {
        "Image": MODEL_IMAGE,
        "Environment": {},
        "ModelDataUrl": MODEL_DATA,
    }

    model = Model(MODEL_IMAGE, MODEL_DATA, sagemaker_session=sagemaker_session)
    for _ in range(2):
        model._create_sagemaker_model(INSTANCE_TYPE)

    base_name_from_image.assert_called_once_with(MODEL_IMAGE)
    name_from_base.assert_called_with(base_name_from_image.return_value)
    assert name_from_base.call_count == 2
def test_create_sagemaker_model(prepare_container_def, sagemaker_session):
    """Calling ``_create_sagemaker_model()`` with no arguments uses ``None`` defaults.

    ``prepare_container_def`` must receive a ``None`` instance type, with
    ``accelerator_type`` and ``serverless_inference_config`` also ``None``.
    ``create_model`` must receive the explicit name, no role, and default
    VPC, network-isolation and tag settings.
    """
    container_def = {
        "Image": MODEL_IMAGE,
        "Environment": {},
        "ModelDataUrl": MODEL_DATA,
    }
    prepare_container_def.return_value = container_def

    # Image first, then model data: the positional order used by the other
    # Model(...) calls in this module whose name assertions depend on the image.
    model = Model(MODEL_IMAGE, MODEL_DATA, name=MODEL_NAME, sagemaker_session=sagemaker_session)
    model._create_sagemaker_model()

    prepare_container_def.assert_called_with(
        None, accelerator_type=None, serverless_inference_config=None
    )
    sagemaker_session.create_model.assert_called_with(
        MODEL_NAME,
        None,
        container_def,
        vpc_config=None,
        enable_network_isolation=False,
        tags=None,
    )