def test_generic_deploy_vpc_config_override(sagemaker_session):
    """Check the ``vpc_config`` that ``Estimator.deploy`` passes to ``create_model``.

    The same fitted estimator is deployed four times in a row, and each
    deployment's ``create_model`` call on the session is inspected:

    1. no VPC settings on the estimator -> ``vpc_config`` is ``None``;
    2. subnets / security groups set on the estimator -> that config is used;
    3. ``vpc_config_override`` given -> the override is used instead;
    4. ``vpc_config_override=None`` -> ``vpc_config`` is ``None`` even though
       the estimator itself has VPC settings.
    """
    estimator_vpc = {'Subnets': ['foo'], 'SecurityGroupIds': ['bar']}
    override_vpc = {'Subnets': ['foo', 'bar'], 'SecurityGroupIds': ['baz']}

    def vpc_config_of_call(index):
        # Keyword arguments of the index-th recorded create_model call;
        # mock call objects unpack into (args, kwargs).
        _, kwargs = sagemaker_session.create_model.call_args_list[index]
        return kwargs['vpc_config']

    estimator = Estimator(IMAGE_NAME,
                          ROLE,
                          INSTANCE_COUNT,
                          INSTANCE_TYPE,
                          sagemaker_session=sagemaker_session)
    estimator.fit({'train': 's3://bucket/training-prefix'})

    # 1. Nothing configured.
    estimator.deploy(INSTANCE_COUNT, INSTANCE_TYPE)
    assert vpc_config_of_call(0) is None

    # 2. VPC settings taken from the estimator.
    estimator.subnets = estimator_vpc['Subnets']
    estimator.security_group_ids = estimator_vpc['SecurityGroupIds']
    estimator.deploy(INSTANCE_COUNT, INSTANCE_TYPE)
    assert vpc_config_of_call(1) == estimator_vpc

    # 3. Explicit override takes precedence over the estimator's settings.
    estimator.deploy(INSTANCE_COUNT, INSTANCE_TYPE,
                     vpc_config_override=override_vpc)
    assert vpc_config_of_call(2) == override_vpc

    # 4. An explicit None override drops the VPC entirely.
    estimator.deploy(INSTANCE_COUNT, INSTANCE_TYPE, vpc_config_override=None)
    assert vpc_config_of_call(3) is None


def test_generic_create_model_vpc_config_override(sagemaker_session):
    """Check how ``vpc_config_override`` interacts with an estimator's VPC.

    ``get_vpc_config`` and ``create_model`` should fall back to the
    estimator's own subnets and security groups, use an explicit override
    dict when one is given, yield ``None`` for ``vpc_config_override=None``,
    and raise ``ValueError`` for a malformed override.
    """
    first_vpc = {'Subnets': ['foo'], 'SecurityGroupIds': ['bar']}
    second_vpc = {'Subnets': ['foo', 'bar'], 'SecurityGroupIds': ['baz']}

    estimator = Estimator(IMAGE_NAME,
                          ROLE,
                          INSTANCE_COUNT,
                          INSTANCE_TYPE,
                          sagemaker_session=sagemaker_session)
    estimator.fit({'train': 's3://bucket/training-prefix'})

    def model_vpc(**kwargs):
        # vpc_config of a model freshly built by the estimator.
        return estimator.create_model(**kwargs).vpc_config

    # Phase 1: the estimator has no VPC settings of its own.
    assert estimator.get_vpc_config() is None
    assert model_vpc() is None
    assert model_vpc(vpc_config_override=first_vpc) == first_vpc
    assert model_vpc(vpc_config_override=None) is None

    # Phase 2: give the estimator its own subnets / security groups.
    estimator.subnets = first_vpc['Subnets']
    estimator.security_group_ids = first_vpc['SecurityGroupIds']
    assert estimator.get_vpc_config() == first_vpc
    assert model_vpc() == first_vpc
    assert model_vpc(vpc_config_override=second_vpc) == second_vpc
    assert model_vpc(vpc_config_override=None) is None

    # {'invalid'} is a set literal rather than a VpcConfig dict; both
    # entry points must reject it.
    for call_with_override in (estimator.get_vpc_config,
                               estimator.create_model):
        with pytest.raises(ValueError):
            call_with_override(vpc_config_override={'invalid'})