# NOTE: This excerpt assumes the test module's top-level imports, fixtures, and
# constants (pytest, unittest.mock.Mock/patch, packaging.version, TensorFlow,
# _TrainingJob, SCRIPT_PATH, SERVING_SCRIPT_FILE, ROLE, BUCKET_NAME,
# INSTANCE_COUNT, INSTANCE_TYPE) are defined outside the lines shown here.


# The create_model argument implies TensorFlow.create_model is patched; the
# decorator is not shown in this excerpt, but would look something like:
# @patch("sagemaker.tensorflow.estimator.TensorFlow.create_model")
def test_transformer_creation_without_endpoint_type(create_model, sagemaker_session):
    model = Mock()
    create_model.return_value = model

    # Legacy (SDK v1) constructor arguments: train_instance_count/train_instance_type.
    tf = TensorFlow(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        train_instance_count=INSTANCE_COUNT,
        train_instance_type=INSTANCE_TYPE,
    )
    tf.latest_training_job = _TrainingJob(sagemaker_session, "some-job-name")
    tf.transformer(INSTANCE_COUNT, INSTANCE_TYPE)

    create_model.assert_called_with(
        endpoint_type=None,
        model_server_workers=None,
        role=ROLE,
        vpc_config_override="VPC_CONFIG_DEFAULT",
        entry_point=None,
    )
    model.transformer.assert_called_with(
        INSTANCE_COUNT,
        INSTANCE_TYPE,
        accept=None,
        assemble_with=None,
        env=None,
        max_concurrent_transforms=None,
        max_payload=None,
        output_kms_key=None,
        output_path=None,
        strategy=None,
        tags=None,
        volume_kms_key=None,
    )
# Assumed decorators (not shown in the excerpt); stacked patch mocks are passed
# bottom-up, so name_from_base is the first argument:
# @patch("sagemaker.tensorflow.estimator.TensorFlow.create_model")
# @patch("sagemaker.tensorflow.estimator.name_from_base")
def test_transformer_creation_without_optional_args(
    name_from_base,
    create_model,
    sagemaker_session,
    tensorflow_inference_version,
    tensorflow_inference_py_version,
):
    if version.Version(tensorflow_inference_version) < version.Version("1.11"):
        pytest.skip(
            "Legacy TF version requires explicit image URI, and "
            "this logic is tested in test_create_model_with_custom_image."
        )

    model_name = "generated-model-name"
    name_from_base.return_value = model_name

    model = Mock()
    create_model.return_value = model

    base_job_name = "tensorflow"
    tf = TensorFlow(
        entry_point=SCRIPT_PATH,
        framework_version=tensorflow_inference_version,
        py_version=tensorflow_inference_py_version,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
        base_job_name=base_job_name,
    )
    tf.latest_training_job = _TrainingJob(sagemaker_session, "some-job-name")
    tf.transformer(INSTANCE_COUNT, INSTANCE_TYPE)

    name_from_base.assert_called_with(base_job_name)
    create_model.assert_called_with(
        role=ROLE,
        vpc_config_override="VPC_CONFIG_DEFAULT",
        entry_point=None,
        enable_network_isolation=False,
        name=model_name,
    )
    model.transformer.assert_called_with(
        INSTANCE_COUNT,
        INSTANCE_TYPE,
        accept=None,
        assemble_with=None,
        env=None,
        max_concurrent_transforms=None,
        max_payload=None,
        output_kms_key=None,
        output_path=None,
        strategy=None,
        tags=None,
        volume_kms_key=None,
    )
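# The tensorflow_inference_version and tensorflow_inference_py_version arguments
# above are pytest fixtures supplied by the surrounding suite (e.g. a shared
# conftest.py). A minimal sketch of what such fixtures could look like -- the
# parameter values are illustrative assumptions, not the suite's real version
# matrix, so the sketch is left commented out rather than shadowing the real
# fixtures:
#
# @pytest.fixture(params=["1.15.2", "2.1.0"])  # hypothetical versions
# def tensorflow_inference_version(request):
#     return request.param
#
# @pytest.fixture(params=["py3"])  # hypothetical Python version tags
# def tensorflow_inference_py_version(request):
#     return request.param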
def test_transformer_creation_with_optional_args(
    create_model,
    sagemaker_session,
    tensorflow_inference_version,
    tensorflow_inference_py_version,
):
    if version.Version(tensorflow_inference_version) < version.Version("1.11"):
        pytest.skip(
            "Legacy TF version requires explicit image URI, and "
            "this logic is tested in test_create_model_with_custom_image."
        )

    model = Mock()
    create_model.return_value = model

    tf = TensorFlow(
        entry_point=SCRIPT_PATH,
        framework_version=tensorflow_inference_version,
        py_version=tensorflow_inference_py_version,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
    )
    tf.latest_training_job = _TrainingJob(sagemaker_session, "some-job-name")

    strategy = "SingleRecord"
    assemble_with = "Line"
    output_path = "s3://{}/batch-output".format(BUCKET_NAME)
    kms_key = "kms"
    accept_type = "text/bytes"
    env = {"foo": "bar"}
    max_concurrent_transforms = 3
    max_payload = 100
    tags = {"Key": "foo", "Value": "bar"}
    new_role = "role"
    vpc_config = {"Subnets": ["1234"], "SecurityGroupIds": ["5678"]}
    model_name = "model-name"

    tf.transformer(
        INSTANCE_COUNT,
        INSTANCE_TYPE,
        strategy=strategy,
        assemble_with=assemble_with,
        output_path=output_path,
        output_kms_key=kms_key,
        accept=accept_type,
        env=env,
        max_concurrent_transforms=max_concurrent_transforms,
        max_payload=max_payload,
        tags=tags,
        role=new_role,
        volume_kms_key=kms_key,
        entry_point=SERVING_SCRIPT_FILE,
        vpc_config_override=vpc_config,
        enable_network_isolation=True,
        model_name=model_name,
    )

    create_model.assert_called_with(
        role=new_role,
        vpc_config_override=vpc_config,
        entry_point=SERVING_SCRIPT_FILE,
        enable_network_isolation=True,
        name=model_name,
    )
    model.transformer.assert_called_with(
        INSTANCE_COUNT,
        INSTANCE_TYPE,
        accept=accept_type,
        assemble_with=assemble_with,
        env=env,
        max_concurrent_transforms=max_concurrent_transforms,
        max_payload=max_payload,
        output_kms_key=kms_key,
        output_path=output_path,
        strategy=strategy,
        tags=tags,
        volume_kms_key=kms_key,
    )
# Renamed from test_transformer_creation_with_optional_args: the file defined two
# tests with that name, so this second definition silently shadowed the one above
# at collection time. This variant exercises the legacy (SDK v1) interface, where
# create_model still accepted endpoint_type and model_server_workers and the
# constructor used train_instance_count/train_instance_type.
def test_transformer_creation_with_optional_args_legacy(create_model, sagemaker_session):
    model = Mock()
    create_model.return_value = model

    tf = TensorFlow(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        train_instance_count=INSTANCE_COUNT,
        train_instance_type=INSTANCE_TYPE,
    )
    tf.latest_training_job = _TrainingJob(sagemaker_session, "some-job-name")

    strategy = "SingleRecord"
    assemble_with = "Line"
    output_path = "s3://{}/batch-output".format(BUCKET_NAME)
    kms_key = "kms"
    accept_type = "text/bytes"
    env = {"foo": "bar"}
    max_concurrent_transforms = 3
    max_payload = 100
    tags = {"Key": "foo", "Value": "bar"}
    new_role = "role"
    model_server_workers = 2
    vpc_config = {"Subnets": ["1234"], "SecurityGroupIds": ["5678"]}

    tf.transformer(
        INSTANCE_COUNT,
        INSTANCE_TYPE,
        strategy=strategy,
        assemble_with=assemble_with,
        output_path=output_path,
        output_kms_key=kms_key,
        accept=accept_type,
        env=env,
        max_concurrent_transforms=max_concurrent_transforms,
        max_payload=max_payload,
        tags=tags,
        role=new_role,
        model_server_workers=model_server_workers,
        volume_kms_key=kms_key,
        endpoint_type="tensorflow-serving",
        entry_point=SERVING_SCRIPT_FILE,
        vpc_config_override=vpc_config,
        enable_network_isolation=True,
    )

    create_model.assert_called_with(
        model_server_workers=model_server_workers,
        role=new_role,
        vpc_config_override=vpc_config,
        endpoint_type="tensorflow-serving",
        entry_point=SERVING_SCRIPT_FILE,
        enable_network_isolation=True,
    )
    model.transformer.assert_called_with(
        INSTANCE_COUNT,
        INSTANCE_TYPE,
        accept=accept_type,
        assemble_with=assemble_with,
        env=env,
        max_concurrent_transforms=max_concurrent_transforms,
        max_payload=max_payload,
        output_kms_key=kms_key,
        output_path=output_path,
        strategy=strategy,
        tags=tags,
        volume_kms_key=kms_key,
    )
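# For context, the user-facing flow these tests mock out looks roughly like the
# sketch below (a minimal sketch with illustrative values; fit() and transform()
# would call real AWS services, which is exactly what the mocks above avoid):
#
# tf = TensorFlow(
#     entry_point="train.py",                                # hypothetical script
#     role="arn:aws:iam::123456789012:role/SageMakerRole",   # hypothetical role ARN
#     instance_count=1,
#     instance_type="ml.c4.xlarge",
#     framework_version="2.1.0",
#     py_version="py3",
# )
# tf.fit("s3://mybucket/training-data")                      # hypothetical S3 input
# transformer = tf.transformer(instance_count=1, instance_type="ml.c4.xlarge")
# transformer.transform("s3://mybucket/batch-input", content_type="text/csv")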