def test_generic_deploy_vpc_config_override(sagemaker_session):
    """Check the ``vpc_config`` that ``deploy`` passes to ``create_model``.

    The config should be:

    * ``None`` when the estimator has no subnets or security groups;
    * the estimator's own config once ``subnets`` and ``security_group_ids``
      are set;
    * the given dict when ``vpc_config_override`` is supplied;
    * ``None`` when ``vpc_config_override=None`` is passed explicitly.
    """
    first_vpc = {'Subnets': ['foo'], 'SecurityGroupIds': ['bar']}
    second_vpc = {'Subnets': ['foo', 'bar'], 'SecurityGroupIds': ['baz']}

    estimator = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE,
                          sagemaker_session=sagemaker_session)
    estimator.fit({'train': 's3://bucket/training-prefix'})

    def vpc_config_of_call(index):
        # Keyword args ([1]) of the index-th recorded create_model call.
        recorded = sagemaker_session.create_model.call_args_list[index]
        return recorded[1]['vpc_config']

    # Nothing VPC-related configured on the estimator yet.
    estimator.deploy(INSTANCE_COUNT, INSTANCE_TYPE)
    assert vpc_config_of_call(0) is None

    # Estimator-level settings are picked up by default.
    estimator.subnets = first_vpc['Subnets']
    estimator.security_group_ids = first_vpc['SecurityGroupIds']
    estimator.deploy(INSTANCE_COUNT, INSTANCE_TYPE)
    assert vpc_config_of_call(1) == first_vpc

    # An explicit override replaces the estimator's config.
    estimator.deploy(INSTANCE_COUNT, INSTANCE_TYPE,
                     vpc_config_override=second_vpc)
    assert vpc_config_of_call(2) == second_vpc

    # An explicit None override clears it entirely.
    estimator.deploy(INSTANCE_COUNT, INSTANCE_TYPE, vpc_config_override=None)
    assert vpc_config_of_call(3) is None
def test_generic_create_model_vpc_config_override(sagemaker_session):
    """Check how ``vpc_config_override`` is resolved.

    Covers both ``get_vpc_config`` and the ``vpc_config`` of the model
    returned by ``create_model``:

    * Before any VPC settings are applied, the default is ``None``.
    * After ``subnets`` and ``security_group_ids`` are set, the default is
      the estimator's config.
    * A dict override wins over the default.
    * An explicit ``None`` override yields ``None``.
    * A malformed override raises ``ValueError``.
    """
    first_vpc = {'Subnets': ['foo'], 'SecurityGroupIds': ['bar']}
    second_vpc = {'Subnets': ['foo', 'bar'], 'SecurityGroupIds': ['baz']}

    estimator = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE,
                          sagemaker_session=sagemaker_session)
    estimator.fit({'train': 's3://bucket/training-prefix'})

    def model_vpc(**kwargs):
        # vpc_config attribute of a freshly created model.
        return estimator.create_model(**kwargs).vpc_config

    # No VPC settings on the estimator.
    assert estimator.get_vpc_config() is None
    assert model_vpc() is None
    assert model_vpc(vpc_config_override=first_vpc) == first_vpc
    assert model_vpc(vpc_config_override=None) is None

    # Estimator-level VPC settings become the default.
    estimator.subnets = first_vpc['Subnets']
    estimator.security_group_ids = first_vpc['SecurityGroupIds']
    assert estimator.get_vpc_config() == first_vpc
    assert model_vpc() == first_vpc
    assert model_vpc(vpc_config_override=second_vpc) == second_vpc
    assert model_vpc(vpc_config_override=None) is None

    # A set literal is not a valid VPC config dict; both entry points reject it.
    with pytest.raises(ValueError):
        estimator.get_vpc_config(vpc_config_override={'invalid'})
    with pytest.raises(ValueError):
        estimator.create_model(vpc_config_override={'invalid'})