Beispiel #1
0
def test_jumpstart_pytorch_script_uri(patched_get_model_specs):
    """Check the us-west-2 source-dir tarball URIs for the PyTorch EQA model.

    Specs are served by ``get_prototype_model_spec`` through the patched
    ``get_model_specs`` fixture. The inference scope is checked first, then
    the training scope, which maps to the ``transfer_learning`` tarball.
    """
    patched_get_model_specs.side_effect = get_prototype_model_spec

    tarball_prefix = "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/"
    expected_by_scope = (
        ("inference", tarball_prefix + "inference/eqa/v1.0.0/sourcedir.tar.gz"),
        ("training", tarball_prefix + "transfer_learning/eqa/v1.0.0/sourcedir.tar.gz"),
    )

    for scope, expected_uri in expected_by_scope:
        retrieved = script_uris.retrieve(
            region="us-west-2",
            script_scope=scope,
            model_id="pytorch-eqa-bert-base-cased",
            model_version="*",
        )
        assert retrieved == expected_uri
def test_jumpstart_inference_model_class(setup):
    """Deploy the CatBoost JumpStart model through ``Model`` and query it.

    The image, script and model artifact are resolved with the JumpStart
    ``retrieve`` helpers. The model is deployed with network isolation and
    tagged with the test-suite id from the environment. The endpoint is then
    invoked with the multiclass tabular sample, and the only requirement is
    a non-``None`` response.
    """
    model_id = "catboost-classification-model"
    model_version = "1.0.0"
    instance_type = "ml.m5.xlarge"
    instance_count = 1
    spec_kwargs = {"model_id": model_id, "model_version": model_version}

    print("Starting inference...")

    image_uri = image_uris.retrieve(
        region=None,
        framework=None,
        image_scope="inference",
        instance_type=instance_type,
        **spec_kwargs,
    )
    script_uri = script_uris.retrieve(script_scope="inference", **spec_kwargs)
    model_uri = model_uris.retrieve(model_scope="inference", **spec_kwargs)

    model = Model(
        image_uri=image_uri,
        model_data=model_uri,
        source_dir=script_uri,
        entry_point=INFERENCE_ENTRY_POINT_SCRIPT_NAME,
        role=get_sm_session().get_caller_identity_arn(),
        sagemaker_session=get_sm_session(),
        enable_network_isolation=True,
    )

    # The suite id is read only at deploy time, after the model object exists.
    suite_tags = [{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}]
    model.deploy(
        initial_instance_count=instance_count,
        instance_type=instance_type,
        tags=suite_tags,
    )

    invoker = EndpointInvoker(endpoint_name=model.endpoint_name)

    download_inference_assets()
    _, features = get_tabular_data(InferenceTabularDataname.MULTICLASS)

    assert invoker.invoke_tabular_endpoint(features) is not None
def test_jumpstart_inference_retrieve_functions(setup):
    """Deploy the CatBoost JumpStart model through ``InferenceJobLauncher``.

    Every deployment input (image, script, model artifact and default
    environment variables) comes from the JumpStart ``retrieve`` /
    ``retrieve_default`` helpers. After the endpoint reaches service, it is
    queried with the multiclass tabular sample and must return a
    non-``None`` response.
    """
    model_id = "catboost-classification-model"
    model_version = "1.0.0"
    instance_type = "ml.m5.xlarge"
    spec_kwargs = {"model_id": model_id, "model_version": model_version}

    print("Starting inference...")

    image_uri = image_uris.retrieve(
        region=None,
        framework=None,
        image_scope="inference",
        instance_type=instance_type,
        **spec_kwargs,
    )
    script_uri = script_uris.retrieve(script_scope="inference", **spec_kwargs)
    model_uri = model_uris.retrieve(model_scope="inference", **spec_kwargs)
    env_vars = environment_variables.retrieve_default(**spec_kwargs)

    launcher = InferenceJobLauncher(
        image_uri=image_uri,
        script_uri=script_uri,
        model_uri=model_uri,
        instance_type=instance_type,
        base_name="catboost",
        environment_variables=env_vars,
    )
    launcher.launch_inference_job()
    launcher.wait_until_endpoint_in_service()

    invoker = EndpointInvoker(endpoint_name=launcher.endpoint_name)

    download_inference_assets()
    _, features = get_tabular_data(InferenceTabularDataname.MULTICLASS)

    assert invoker.invoke_tabular_endpoint(features) is not None
Beispiel #4
0
def test_jumpstart_transfer_learning_estimator_class(setup):
    """Fine-tune the JumpStart HuggingFace SPC model with ``Estimator`` and deploy it.

    Training uses the JumpStart training image, script and base model, the
    default hyperparameters with ``epochs`` overridden to ``"1"``, and the
    JumpStart training dataset for this model/version. The fitted estimator
    is then deployed with the JumpStart inference image and script under a
    generated endpoint name. The endpoint is invoked via
    ``invoke_spc_endpoint(["hello", "world"])`` and must return a
    non-``None`` response.
    """
    model_id, model_version = "huggingface-spc-bert-base-cased", "1.0.0"
    training_instance_type = "ml.p3.2xlarge"
    inference_instance_type = "ml.p2.xlarge"
    instance_count = 1

    print("Starting training...")

    image_uri = image_uris.retrieve(
        region=None,
        framework=None,
        image_scope="training",
        model_id=model_id,
        model_version=model_version,
        instance_type=training_instance_type,
    )

    script_uri = script_uris.retrieve(
        model_id=model_id, model_version=model_version, script_scope="training"
    )

    model_uri = model_uris.retrieve(
        model_id=model_id, model_version=model_version, model_scope="training"
    )

    default_hyperparameters = hyperparameters.retrieve_default(
        model_id=model_id,
        model_version=model_version,
    )

    # A single epoch, presumably to keep the integration run short.
    default_hyperparameters["epochs"] = "1"

    estimator = Estimator(
        image_uri=image_uri,
        source_dir=script_uri,
        model_uri=model_uri,
        entry_point=TRAINING_ENTRY_POINT_SCRIPT_NAME,
        role=get_sm_session().get_caller_identity_arn(),
        sagemaker_session=get_sm_session(),
        enable_network_isolation=True,
        hyperparameters=default_hyperparameters,
        tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
        instance_count=instance_count,
        instance_type=training_instance_type,
    )

    estimator.fit(
        {
            "training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/"
            f"{get_training_dataset_for_model_and_version(model_id, model_version)}",
        }
    )

    print("Starting inference...")

    image_uri = image_uris.retrieve(
        region=None,
        framework=None,
        image_scope="inference",
        model_id=model_id,
        model_version=model_version,
        instance_type=inference_instance_type,
    )

    script_uri = script_uris.retrieve(
        model_id=model_id, model_version=model_version, script_scope="inference"
    )

    # No inference-scope model URI is retrieved: ``estimator.deploy`` is given
    # no model artifact and deploys the model trained by ``fit`` above.
    endpoint_name = name_from_base(f"{model_id}-transfer-learning")

    estimator.deploy(
        initial_instance_count=instance_count,
        instance_type=inference_instance_type,
        entry_point=INFERENCE_ENTRY_POINT_SCRIPT_NAME,
        image_uri=image_uri,
        source_dir=script_uri,
        endpoint_name=endpoint_name,
    )

    endpoint_invoker = EndpointInvoker(
        endpoint_name=endpoint_name,
    )

    response = endpoint_invoker.invoke_spc_endpoint(["hello", "world"])

    assert response is not None
def test_jumpstart_transfer_learning_retrieve_functions(setup):
    """Fine-tune and deploy the HuggingFace SPC model via the job launchers.

    Training inputs (image, script, base model, and default hyperparameters
    including container hyperparameters, with ``epochs`` set to ``"1"``) come
    from the JumpStart retrieve helpers. After the training job completes,
    an endpoint is launched from the tarball URI derived from the training
    job's output path and name. It uses the JumpStart inference image,
    script and default environment variables. The endpoint is invoked via
    ``invoke_spc_endpoint(["hello", "world"])`` and must return a
    non-``None`` response.
    """
    model_id = "huggingface-spc-bert-base-cased"
    model_version = "1.0.0"
    training_instance_type = "ml.p3.2xlarge"
    inference_instance_type = "ml.p2.xlarge"
    spec_kwargs = {"model_id": model_id, "model_version": model_version}

    # ---- training ----
    print("Starting training...")
    train_image_uri = image_uris.retrieve(
        region=None,
        framework=None,
        image_scope="training",
        instance_type=training_instance_type,
        **spec_kwargs,
    )
    train_script_uri = script_uris.retrieve(script_scope="training", **spec_kwargs)
    base_model_uri = model_uris.retrieve(model_scope="training", **spec_kwargs)

    train_hyperparameters = hyperparameters.retrieve_default(
        include_container_hyperparameters=True, **spec_kwargs
    )
    train_hyperparameters["epochs"] = "1"

    trainer = TrainingJobLauncher(
        image_uri=train_image_uri,
        script_uri=train_script_uri,
        model_uri=base_model_uri,
        hyperparameters=train_hyperparameters,
        instance_type=training_instance_type,
        training_dataset_s3_key=get_training_dataset_for_model_and_version(model_id, model_version),
        base_name="huggingface",
    )
    trainer.create_training_job()
    trainer.wait_until_training_job_complete()

    # ---- inference ----
    print("Starting inference...")
    serve_image_uri = image_uris.retrieve(
        region=None,
        framework=None,
        image_scope="inference",
        instance_type=inference_instance_type,
        **spec_kwargs,
    )
    serve_script_uri = script_uris.retrieve(script_scope="inference", **spec_kwargs)
    serve_env = environment_variables.retrieve_default(**spec_kwargs)

    # Point inference at the tarball derived from the training job's output location.
    tuned_model_uri = get_model_tarball_full_uri_from_base_uri(
        trainer.output_tarball_base_path, trainer.training_job_name
    )

    server = InferenceJobLauncher(
        image_uri=serve_image_uri,
        script_uri=serve_script_uri,
        model_uri=tuned_model_uri,
        instance_type=inference_instance_type,
        base_name="huggingface",
        environment_variables=serve_env,
    )
    server.launch_inference_job()
    server.wait_until_endpoint_in_service()

    invoker = EndpointInvoker(endpoint_name=server.endpoint_name)

    assert invoker.invoke_spc_endpoint(["hello", "world"]) is not None
def test_jumpstart_common_script_uri(
    patched_get_model_specs, patched_verify_model_region_and_return_specs
):
    """Check spec lookups made by ``script_uris.retrieve`` and its input validation.

    For each valid (region, scope, version) combination, ``get_model_specs``
    must be called exactly once with the effective region. When no region is
    passed, this is the JumpStart default region. The region/spec verifier
    must also be called exactly once. Mocks are reset between these lookups.
    Invalid scopes, unknown models, unknown regions and omitted arguments
    must raise the listed exceptions.
    """
    patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs
    patched_get_model_specs.side_effect = get_spec_from_base_spec

    model_id = "pytorch-ic-mobilenet-v2"
    default_region = sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME

    # (region passed to retrieve or None to omit it, scope, version, region seen by the spec lookup)
    lookups = [
        (None, "training", "*", default_region),
        (None, "inference", "1.*", default_region),
        ("us-west-2", "training", "*", "us-west-2"),
        ("us-west-2", "inference", "1.*", "us-west-2"),
    ]

    for position, (region, scope, version, expected_region) in enumerate(lookups):
        if position:
            # Clear call history left by the previous lookup.
            patched_get_model_specs.reset_mock()
            patched_verify_model_region_and_return_specs.reset_mock()

        retrieve_kwargs = {"script_scope": scope, "model_id": model_id, "model_version": version}
        if region is not None:
            retrieve_kwargs["region"] = region
        script_uris.retrieve(**retrieve_kwargs)

        patched_get_model_specs.assert_called_once_with(
            region=expected_region, model_id=model_id, version=version
        )
        patched_verify_model_region_and_return_specs.assert_called_once()

    invalid_calls = [
        (
            NotImplementedError,
            {"region": "us-west-2", "script_scope": "BAD_SCOPE", "model_id": model_id, "model_version": "*"},
        ),
        (
            KeyError,
            {"region": "us-west-2", "script_scope": "training", "model_id": "blah", "model_version": "*"},
        ),
        (
            ValueError,
            {"region": "mars-south-1", "script_scope": "training", "model_id": model_id, "model_version": "*"},
        ),
        # Each remaining call leaves out one of script_scope / model_id / model_version.
        (ValueError, {"model_id": model_id, "model_version": "*"}),
        (ValueError, {"script_scope": "training", "model_version": "*"}),
        (ValueError, {"script_scope": "training", "model_id": model_id}),
    ]

    for expected_error, call_kwargs in invalid_calls:
        with pytest.raises(expected_error):
            script_uris.retrieve(**call_kwargs)