def test_transformer_creation_without_endpoint_type(create_model, sagemaker_session):
    """Transformer built without ``endpoint_type`` forwards only defaults.

    Builds a ``TensorFlow`` estimator using the ``train_instance_count`` /
    ``train_instance_type`` argument names (the version-parametrized tests
    in this chunk use ``instance_count`` / ``instance_type`` instead).
    It attaches a fake training job and calls ``transformer`` with only the
    instance count and type.

    ``create_model`` and the model's own ``transformer`` are mocks, so the
    test checks only the arguments each one receives:

    * ``create_model`` gets ``None`` for ``endpoint_type``,
      ``model_server_workers`` and ``entry_point``, the estimator's role, and
      the literal ``"VPC_CONFIG_DEFAULT"`` placeholder for the VPC override.
    * every transform-job option reaches the model's ``transformer`` as
      ``None``.
    """
    mock_model = Mock()
    create_model.return_value = mock_model

    estimator = TensorFlow(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        train_instance_count=INSTANCE_COUNT,
        train_instance_type=INSTANCE_TYPE,
    )
    estimator.latest_training_job = _TrainingJob(sagemaker_session, "some-job-name")
    estimator.transformer(INSTANCE_COUNT, INSTANCE_TYPE)

    # No optional arguments were passed, so create_model sees only defaults.
    expected_model_kwargs = dict(
        endpoint_type=None,
        model_server_workers=None,
        role=ROLE,
        vpc_config_override="VPC_CONFIG_DEFAULT",
        entry_point=None,
    )
    create_model.assert_called_with(**expected_model_kwargs)

    # Every transform-job option is expected to be left unset (None).
    unset_transform_options = (
        "accept",
        "assemble_with",
        "env",
        "max_concurrent_transforms",
        "max_payload",
        "output_kms_key",
        "output_path",
        "strategy",
        "tags",
        "volume_kms_key",
    )
    mock_model.transformer.assert_called_with(
        INSTANCE_COUNT,
        INSTANCE_TYPE,
        **dict.fromkeys(unset_transform_options)
    )
def test_transformer_creation_without_optional_args(
    name_from_base,
    create_model,
    sagemaker_session,
    tensorflow_inference_version,
    tensorflow_inference_py_version,
):
    """Transformer with no optional arguments uses a generated model name.

    Skips framework versions below 1.11. Per the skip message, those need an
    explicit image URI and are covered by
    ``test_create_model_with_custom_image``.

    Otherwise the test builds a ``TensorFlow`` estimator with
    ``base_job_name="tensorflow"``, attaches a fake training job, and calls
    ``transformer`` with only the instance count and type. It expects:

    * ``name_from_base`` to be called with the base job name;
    * its (mocked) return value to be passed to ``create_model`` as ``name``,
      with no entry point, network isolation off, and the literal
      ``"VPC_CONFIG_DEFAULT"`` placeholder for the VPC override;
    * every transform-job option to reach the model's ``transformer`` as
      ``None``.
    """
    is_legacy_version = version.Version(tensorflow_inference_version) < version.Version("1.11")
    if is_legacy_version:
        pytest.skip(
            "Legacy TF version requires explicit image URI, and "
            "this logic is tested in test_create_model_with_custom_image."
        )

    generated_name = "generated-model-name"
    mock_model = Mock()
    name_from_base.return_value = generated_name
    create_model.return_value = mock_model

    job_name_prefix = "tensorflow"
    estimator = TensorFlow(
        entry_point=SCRIPT_PATH,
        framework_version=tensorflow_inference_version,
        py_version=tensorflow_inference_py_version,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
        base_job_name=job_name_prefix,
    )
    estimator.latest_training_job = _TrainingJob(sagemaker_session, "some-job-name")
    estimator.transformer(INSTANCE_COUNT, INSTANCE_TYPE)

    # The model name is expected to come from name_from_base(base_job_name).
    name_from_base.assert_called_with(job_name_prefix)

    expected_model_kwargs = dict(
        role=ROLE,
        vpc_config_override="VPC_CONFIG_DEFAULT",
        entry_point=None,
        enable_network_isolation=False,
        name=generated_name,
    )
    create_model.assert_called_with(**expected_model_kwargs)

    # Every transform-job option is expected to be left unset (None).
    unset_transform_options = (
        "accept",
        "assemble_with",
        "env",
        "max_concurrent_transforms",
        "max_payload",
        "output_kms_key",
        "output_path",
        "strategy",
        "tags",
        "volume_kms_key",
    )
    mock_model.transformer.assert_called_with(
        INSTANCE_COUNT,
        INSTANCE_TYPE,
        **dict.fromkeys(unset_transform_options)
    )
def test_transformer_creation_with_optional_args(
    create_model,
    sagemaker_session,
    tensorflow_inference_version,
    tensorflow_inference_py_version,
):
    """Every optional ``transformer`` argument is routed to the right callee.

    Skips framework versions below 1.11 (see the skip message). The
    model-building arguments must reach ``create_model``:

    * role,
    * entry point,
    * VPC override,
    * network isolation,
    * model name, which ``create_model`` is expected to receive under the
      key ``name`` rather than ``model_name``.

    The transform-job options must reach the model's ``transformer``
    unchanged.

    NOTE(review): a second function with this exact name is defined further
    down in this chunk (the variant using ``train_instance_*`` arguments).
    If both end up in one module, the later definition replaces this one and
    pytest never collects it. One of them should be renamed.
    """
    if version.Version(tensorflow_inference_version) < version.Version("1.11"):
        pytest.skip(
            "Legacy TF version requires explicit image URI, and "
            "this logic is tested in test_create_model_with_custom_image."
        )

    mock_model = Mock()
    create_model.return_value = mock_model

    estimator = TensorFlow(
        entry_point=SCRIPT_PATH,
        framework_version=tensorflow_inference_version,
        py_version=tensorflow_inference_py_version,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
    )
    estimator.latest_training_job = _TrainingJob(sagemaker_session, "some-job-name")

    # Options the test expects to be passed through to the model's transformer().
    kms_key = "kms"
    transform_options = {
        "strategy": "SingleRecord",
        "assemble_with": "Line",
        "output_path": "s3://{}/batch-output".format(BUCKET_NAME),
        "output_kms_key": kms_key,
        "accept": "text/bytes",
        "env": {"foo": "bar"},
        "max_concurrent_transforms": 3,
        "max_payload": 100,
        "tags": {"Key": "foo", "Value": "bar"},
        "volume_kms_key": kms_key,
    }

    # Options the test expects to be consumed when building the model.
    replacement_role = "role"
    vpc_override = {"Subnets": ["1234"], "SecurityGroupIds": ["5678"]}
    explicit_model_name = "model-name"

    estimator.transformer(
        INSTANCE_COUNT,
        INSTANCE_TYPE,
        role=replacement_role,
        entry_point=SERVING_SCRIPT_FILE,
        vpc_config_override=vpc_override,
        enable_network_isolation=True,
        model_name=explicit_model_name,
        **transform_options
    )

    create_model.assert_called_with(
        role=replacement_role,
        vpc_config_override=vpc_override,
        entry_point=SERVING_SCRIPT_FILE,
        enable_network_isolation=True,
        name=explicit_model_name,
    )
    mock_model.transformer.assert_called_with(
        INSTANCE_COUNT,
        INSTANCE_TYPE,
        **transform_options
    )
# Example #4
# 0
def test_transformer_creation_with_optional_args(create_model, sagemaker_session):
    """Optional ``transformer`` arguments, including serving options, are routed.

    Uses the ``train_instance_count`` / ``train_instance_type`` estimator
    arguments. The test expects the model-building arguments to reach
    ``create_model``:

    * role,
    * ``model_server_workers``,
    * ``endpoint_type``,
    * entry point,
    * VPC override,
    * network isolation.

    The transform-job options must reach the model's ``transformer``
    unchanged.

    NOTE(review): an earlier function in this chunk has the same name. In a
    single module this definition silently replaces it, so only one of the
    two would be collected by pytest.
    """
    mock_model = Mock()
    create_model.return_value = mock_model

    estimator = TensorFlow(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        train_instance_count=INSTANCE_COUNT,
        train_instance_type=INSTANCE_TYPE,
    )
    estimator.latest_training_job = _TrainingJob(sagemaker_session, "some-job-name")

    # Options the test expects to be passed through to the model's transformer().
    kms_key = "kms"
    transform_options = {
        "strategy": "SingleRecord",
        "assemble_with": "Line",
        "output_path": "s3://{}/batch-output".format(BUCKET_NAME),
        "output_kms_key": kms_key,
        "accept": "text/bytes",
        "env": {"foo": "bar"},
        "max_concurrent_transforms": 3,
        "max_payload": 100,
        "tags": {"Key": "foo", "Value": "bar"},
        "volume_kms_key": kms_key,
    }

    # Options the test expects to be consumed when building the model.
    replacement_role = "role"
    worker_count = 2
    serving_endpoint = "tensorflow-serving"
    vpc_override = {"Subnets": ["1234"], "SecurityGroupIds": ["5678"]}

    estimator.transformer(
        INSTANCE_COUNT,
        INSTANCE_TYPE,
        role=replacement_role,
        model_server_workers=worker_count,
        endpoint_type=serving_endpoint,
        entry_point=SERVING_SCRIPT_FILE,
        vpc_config_override=vpc_override,
        enable_network_isolation=True,
        **transform_options
    )

    create_model.assert_called_with(
        role=replacement_role,
        model_server_workers=worker_count,
        endpoint_type=serving_endpoint,
        entry_point=SERVING_SCRIPT_FILE,
        vpc_config_override=vpc_override,
        enable_network_isolation=True,
    )
    mock_model.transformer.assert_called_with(
        INSTANCE_COUNT,
        INSTANCE_TYPE,
        **transform_options
    )