def test_jumpstart_pytorch_script_uri(patched_get_model_specs):
    """Check the S3 script tarball URIs resolved for the PyTorch EQA model.

    ``patched_get_model_specs`` is a mock (presumably injected by a
    ``mock.patch`` decorator outside this view) whose side effect is set to
    ``get_prototype_model_spec``. The inference scope and then the training
    scope are resolved in ``us-west-2`` and compared with their expected URIs.
    """
    patched_get_model_specs.side_effect = get_prototype_model_spec

    tarball_prefix = "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/"
    # Insertion order matters: inference is checked first, then training.
    expected_uri_by_scope = {
        "inference": tarball_prefix + "inference/eqa/v1.0.0/sourcedir.tar.gz",
        # The training scope maps to the transfer-learning script bundle.
        "training": tarball_prefix + "transfer_learning/eqa/v1.0.0/sourcedir.tar.gz",
    }

    for scope, expected_uri in expected_uri_by_scope.items():
        resolved_uri = script_uris.retrieve(
            region="us-west-2",
            script_scope=scope,
            model_id="pytorch-eqa-bert-base-cased",
            model_version="*",
        )
        assert resolved_uri == expected_uri
def test_jumpstart_inference_model_class(setup):
    """Deploy a JumpStart CatBoost classifier through ``Model`` and invoke it.

    Resolves the inference image, script bundle and model artifact URIs for
    ``catboost-classification-model`` 1.0.0. It deploys them with network
    isolation on a single ``ml.m5.xlarge`` instance, sends multiclass tabular
    features to the endpoint, and asserts that a response comes back.

    ``setup`` is not referenced in the body; presumably it is a fixture that
    performs suite-level setup/teardown (confirm against its definition).
    """
    model_id = "catboost-classification-model"
    model_version = "1.0.0"
    instance_type = "ml.m5.xlarge"
    instance_count = 1

    print("Starting inference...")

    image_uri = image_uris.retrieve(
        region=None,
        framework=None,
        image_scope="inference",
        model_id=model_id,
        model_version=model_version,
        instance_type=instance_type,
    )
    script_uri = script_uris.retrieve(
        model_id=model_id, model_version=model_version, script_scope="inference"
    )
    model_uri = model_uris.retrieve(
        model_id=model_id, model_version=model_version, model_scope="inference"
    )

    execution_role = get_sm_session().get_caller_identity_arn()
    model = Model(
        image_uri=image_uri,
        model_data=model_uri,
        source_dir=script_uri,
        entry_point=INFERENCE_ENTRY_POINT_SCRIPT_NAME,
        role=execution_role,
        sagemaker_session=get_sm_session(),
        enable_network_isolation=True,
    )

    # Tag the endpoint with the test-suite id; presumably the suite's teardown
    # uses this tag to locate and clean up resources (confirm).
    suite_tag = {"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}
    model.deploy(
        initial_instance_count=instance_count,
        instance_type=instance_type,
        tags=[suite_tag],
    )

    invoker = EndpointInvoker(endpoint_name=model.endpoint_name)

    download_inference_assets()
    # The ground-truth label is not needed: only a non-empty response is checked.
    _, features = get_tabular_data(InferenceTabularDataname.MULTICLASS)

    assert invoker.invoke_tabular_endpoint(features) is not None
def test_jumpstart_inference_retrieve_functions(setup):
    """Deploy a JumpStart CatBoost classifier via ``InferenceJobLauncher``.

    Resolves the inference image, script bundle, model artifact and default
    environment variables for ``catboost-classification-model`` 1.0.0 on
    ``ml.m5.xlarge``. It launches an endpoint with those values and waits until
    the endpoint is in service. It then sends multiclass tabular features and
    asserts that a response comes back.

    ``setup`` is not referenced in the body; presumably it is a fixture that
    performs suite-level setup/teardown (confirm against its definition).
    """
    model_id = "catboost-classification-model"
    model_version = "1.0.0"
    instance_type = "ml.m5.xlarge"

    print("Starting inference...")

    image_uri = image_uris.retrieve(
        region=None,
        framework=None,
        image_scope="inference",
        model_id=model_id,
        model_version=model_version,
        instance_type=instance_type,
    )
    script_uri = script_uris.retrieve(
        model_id=model_id, model_version=model_version, script_scope="inference"
    )
    model_uri = model_uris.retrieve(
        model_id=model_id, model_version=model_version, model_scope="inference"
    )
    env_vars = environment_variables.retrieve_default(
        model_id=model_id, model_version=model_version
    )

    launcher = InferenceJobLauncher(
        image_uri=image_uri,
        script_uri=script_uri,
        model_uri=model_uri,
        instance_type=instance_type,
        base_name="catboost",
        environment_variables=env_vars,
    )
    launcher.launch_inference_job()
    launcher.wait_until_endpoint_in_service()

    invoker = EndpointInvoker(endpoint_name=launcher.endpoint_name)

    download_inference_assets()
    # The ground-truth label is not needed: only a non-empty response is checked.
    _, features = get_tabular_data(InferenceTabularDataname.MULTICLASS)

    assert invoker.invoke_tabular_endpoint(features) is not None
def test_jumpstart_transfer_learning_estimator_class(setup):
    """Fine-tune a JumpStart HuggingFace SPC model via ``Estimator`` and deploy it.

    Training phase: trains ``huggingface-spc-bert-base-cased`` 1.0.0 on one
    ``ml.p3.2xlarge`` instance. It uses the JumpStart training image, script
    bundle, base model artifact and the model's default hyperparameters, with
    ``epochs`` forced to "1".

    Inference phase: deploys the fitted estimator on one ``ml.p2.xlarge``
    instance with the JumpStart inference image and script bundle. It invokes
    the endpoint with ``["hello", "world"]`` and asserts that a response comes
    back.

    ``setup`` is not referenced in the body; presumably it is a fixture that
    performs suite-level setup/teardown (confirm against its definition).
    """
    model_id, model_version = "huggingface-spc-bert-base-cased", "1.0.0"
    training_instance_type = "ml.p3.2xlarge"
    inference_instance_type = "ml.p2.xlarge"
    instance_count = 1

    print("Starting training...")

    image_uri = image_uris.retrieve(
        region=None,
        framework=None,
        image_scope="training",
        model_id=model_id,
        model_version=model_version,
        instance_type=training_instance_type,
    )
    script_uri = script_uris.retrieve(
        model_id=model_id, model_version=model_version, script_scope="training"
    )
    model_uri = model_uris.retrieve(
        model_id=model_id, model_version=model_version, model_scope="training"
    )

    default_hyperparameters = hyperparameters.retrieve_default(
        model_id=model_id,
        model_version=model_version,
    )
    # Override the epoch count to 1 to keep the training job short.
    default_hyperparameters["epochs"] = "1"

    estimator = Estimator(
        image_uri=image_uri,
        source_dir=script_uri,
        model_uri=model_uri,
        entry_point=TRAINING_ENTRY_POINT_SCRIPT_NAME,
        role=get_sm_session().get_caller_identity_arn(),
        sagemaker_session=get_sm_session(),
        enable_network_isolation=True,
        hyperparameters=default_hyperparameters,
        # Tag with the test-suite id; presumably used by the suite to find and
        # clean up resources (confirm).
        tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
        instance_count=instance_count,
        instance_type=training_instance_type,
    )

    estimator.fit(
        {
            "training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/"
            f"{get_training_dataset_for_model_and_version(model_id, model_version)}",
        }
    )

    print("Starting inference...")

    image_uri = image_uris.retrieve(
        region=None,
        framework=None,
        image_scope="inference",
        model_id=model_id,
        model_version=model_version,
        instance_type=inference_instance_type,
    )
    script_uri = script_uris.retrieve(
        model_id=model_id, model_version=model_version, script_scope="inference"
    )
    # No inference-scope model URI is retrieved here: deploy() below is not given
    # one, so such a lookup would only be an unused extra call.

    endpoint_name = name_from_base(f"{model_id}-transfer-learning")

    estimator.deploy(
        initial_instance_count=instance_count,
        instance_type=inference_instance_type,
        entry_point=INFERENCE_ENTRY_POINT_SCRIPT_NAME,
        image_uri=image_uri,
        source_dir=script_uri,
        endpoint_name=endpoint_name,
    )

    endpoint_invoker = EndpointInvoker(endpoint_name=endpoint_name)

    response = endpoint_invoker.invoke_spc_endpoint(["hello", "world"])

    assert response is not None
def test_jumpstart_transfer_learning_retrieve_functions(setup):
    """Fine-tune and serve a JumpStart HuggingFace SPC model via job launchers.

    Training phase: ``TrainingJobLauncher`` runs
    ``huggingface-spc-bert-base-cased`` 1.0.0 on ``ml.p3.2xlarge``. It uses the
    JumpStart training image, script bundle, base model artifact, and the
    default hyperparameters (including container hyperparameters) with
    ``epochs`` forced to "1".

    Inference phase: ``InferenceJobLauncher`` serves the tarball produced by
    that training job on ``ml.p2.xlarge`` with the default inference
    environment variables. The endpoint is invoked with ``["hello", "world"]``
    and must return a response.

    ``setup`` is not referenced in the body; presumably it is a fixture that
    performs suite-level setup/teardown (confirm against its definition).
    """
    model_id, model_version = "huggingface-spc-bert-base-cased", "1.0.0"
    training_instance_type = "ml.p3.2xlarge"
    inference_instance_type = "ml.p2.xlarge"

    # ---- training ----
    print("Starting training...")

    train_image_uri = image_uris.retrieve(
        region=None,
        framework=None,
        image_scope="training",
        model_id=model_id,
        model_version=model_version,
        instance_type=training_instance_type,
    )
    train_script_uri = script_uris.retrieve(
        model_id=model_id, model_version=model_version, script_scope="training"
    )
    base_model_uri = model_uris.retrieve(
        model_id=model_id, model_version=model_version, model_scope="training"
    )
    train_hyperparameters = hyperparameters.retrieve_default(
        model_id=model_id,
        model_version=model_version,
        include_container_hyperparameters=True,
    )
    train_hyperparameters["epochs"] = "1"

    dataset_key = get_training_dataset_for_model_and_version(model_id, model_version)
    training_job = TrainingJobLauncher(
        image_uri=train_image_uri,
        script_uri=train_script_uri,
        model_uri=base_model_uri,
        hyperparameters=train_hyperparameters,
        instance_type=training_instance_type,
        training_dataset_s3_key=dataset_key,
        base_name="huggingface",
    )
    training_job.create_training_job()
    training_job.wait_until_training_job_complete()

    # ---- inference ----
    print("Starting inference...")

    serve_image_uri = image_uris.retrieve(
        region=None,
        framework=None,
        image_scope="inference",
        model_id=model_id,
        model_version=model_version,
        instance_type=inference_instance_type,
    )
    serve_script_uri = script_uris.retrieve(
        model_id=model_id, model_version=model_version, script_scope="inference"
    )
    serve_env = environment_variables.retrieve_default(
        model_id=model_id, model_version=model_version
    )

    # Serve the artifact written by the training job above, not the base model.
    trained_model_uri = get_model_tarball_full_uri_from_base_uri(
        training_job.output_tarball_base_path, training_job.training_job_name
    )
    inference_job = InferenceJobLauncher(
        image_uri=serve_image_uri,
        script_uri=serve_script_uri,
        model_uri=trained_model_uri,
        instance_type=inference_instance_type,
        base_name="huggingface",
        environment_variables=serve_env,
    )
    inference_job.launch_inference_job()
    inference_job.wait_until_endpoint_in_service()

    invoker = EndpointInvoker(endpoint_name=inference_job.endpoint_name)
    assert invoker.invoke_spc_endpoint(["hello", "world"]) is not None
def test_jumpstart_common_script_uri(
    patched_get_model_specs, patched_verify_model_region_and_return_specs
):
    """Check how ``script_uris.retrieve`` forwards arguments and rejects bad input.

    Both parameters are mocks; presumably they are injected by ``mock.patch``
    decorators outside this view. Their side effects are set to
    ``get_spec_from_base_spec`` and ``verify_model_region_and_return_specs``.

    For each valid call, two things must hold:
    - The specs accessor is called exactly once with the requested region,
      model id and version. When no region is passed, the region is
      ``JUMPSTART_DEFAULT_REGION_NAME``.
    - The region/spec verification hook is called exactly once.

    The invalid calls must raise the listed exceptions.
    """
    patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs
    patched_get_model_specs.side_effect = get_spec_from_base_spec

    model_id = "pytorch-ic-mobilenet-v2"
    default_region = sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME

    # Each entry: (region passed to retrieve, or None to omit it; script scope; version).
    valid_calls = (
        (None, "training", "*"),
        (None, "inference", "1.*"),
        ("us-west-2", "training", "*"),
        ("us-west-2", "inference", "1.*"),
    )
    for region, scope, version in valid_calls:
        # Drop call history from the previous case; reset_mock() keeps side_effect.
        patched_get_model_specs.reset_mock()
        patched_verify_model_region_and_return_specs.reset_mock()

        # Omit the region keyword entirely (rather than passing None) to exercise the default.
        region_kwargs = {} if region is None else {"region": region}
        script_uris.retrieve(
            **region_kwargs,
            script_scope=scope,
            model_id=model_id,
            model_version=version,
        )

        patched_get_model_specs.assert_called_once_with(
            region=default_region if region is None else region,
            model_id=model_id,
            version=version,
        )
        patched_verify_model_region_and_return_specs.assert_called_once()

    invalid_calls = (
        # Unsupported script scope.
        (
            NotImplementedError,
            {
                "region": "us-west-2",
                "script_scope": "BAD_SCOPE",
                "model_id": model_id,
                "model_version": "*",
            },
        ),
        # Model id "blah".
        (
            KeyError,
            {
                "region": "us-west-2",
                "script_scope": "training",
                "model_id": "blah",
                "model_version": "*",
            },
        ),
        # Region "mars-south-1".
        (
            ValueError,
            {
                "region": "mars-south-1",
                "script_scope": "training",
                "model_id": model_id,
                "model_version": "*",
            },
        ),
        # Missing script_scope.
        (ValueError, {"model_id": model_id, "model_version": "*"}),
        # Missing model_id.
        (ValueError, {"script_scope": "training", "model_version": "*"}),
        # Missing model_version.
        (ValueError, {"script_scope": "training", "model_id": model_id}),
    )
    for expected_error, retrieve_kwargs in invalid_calls:
        with pytest.raises(expected_error):
            script_uris.retrieve(**retrieve_kwargs)