def test_jumpstart_inference_model_class(setup):
    """Deploy a JumpStart CatBoost classifier through ``Model`` and invoke it.

    Retrieves the inference-scope image, script and model URIs for the pinned
    model id/version, deploys them on a single instance with a tag whose value
    comes from the test-suite-id environment variable, then asserts that a
    tabular invocation of the endpoint returns something other than ``None``.

    Args:
        setup: not referenced in the body; presumably a pytest fixture that
            prepares the suite (confirm against the conftest).
    """
    model_id = "catboost-classification-model"
    model_version = "1.0.0"
    instance_type = "ml.m5.xlarge"
    instance_count = 1

    # Keyword arguments shared by every JumpStart retrieval below.
    model_spec = {"model_id": model_id, "model_version": model_version}

    print("Starting inference...")

    image_uri = image_uris.retrieve(
        region=None,
        framework=None,
        image_scope="inference",
        instance_type=instance_type,
        **model_spec,
    )
    script_uri = script_uris.retrieve(script_scope="inference", **model_spec)
    model_uri = model_uris.retrieve(model_scope="inference", **model_spec)

    # Two separate get_sm_session() calls, one for the role and one for the
    # session handed to the model.
    role_arn = get_sm_session().get_caller_identity_arn()
    session = get_sm_session()

    model = Model(
        image_uri=image_uri,
        model_data=model_uri,
        source_dir=script_uri,
        entry_point=INFERENCE_ENTRY_POINT_SCRIPT_NAME,
        role=role_arn,
        sagemaker_session=session,
        enable_network_isolation=True,
    )

    suite_tag = {
        "Key": JUMPSTART_TAG,
        "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID],
    }
    model.deploy(
        initial_instance_count=instance_count,
        instance_type=instance_type,
        tags=[suite_tag],
    )

    endpoint_invoker = EndpointInvoker(endpoint_name=model.endpoint_name)

    download_inference_assets()
    # Only the features are sent to the endpoint; the label is not checked.
    _ground_truth_label, features = get_tabular_data(
        InferenceTabularDataname.MULTICLASS
    )

    response = endpoint_invoker.invoke_tabular_endpoint(features)

    assert response is not None
Exemplo n.º 2
0
    def __init__(
        self,
        image_uri,
        script_uri,
        model_uri,
        hyperparameters,
        instance_type,
        training_dataset_s3_key,
        suffix=None,
        region=JUMPSTART_DEFAULT_REGION_NAME,
        boto_config=Config(retries={
            "max_attempts": 10,
            "mode": "standard"
        }),
        base_name="jumpstart-training-job",
        execution_role=None,
    ) -> None:
        """Store the training-job settings and create the SageMaker client.

        Args:
            image_uri: stored on the instance as ``self.image_uri``.
            script_uri: stored on the instance as ``self.script_uri``.
            model_uri: stored on the instance as ``self.model_uri``.
            hyperparameters: stored on the instance as ``self.hyperparameters``.
            instance_type: stored on the instance as ``self.instance_type``.
            training_dataset_s3_key: stored on the instance as
                ``self.training_dataset_s3_key``.
            suffix: value stored as ``self.suffix``. When ``None`` (the
                default) a UTC timestamp formatted ``%Y-%m-%d-%H-%M-%S`` is
                generated at construction time.
            region: stored as ``self.region``.
            boto_config: stored as ``self.config``. The default object is
                created once when the method is defined and is therefore
                shared by all instances that do not pass their own.
            base_name: stored as ``self.base_name``.
            execution_role: stored as ``self.execution_role``; when falsy, the
                caller identity ARN of ``get_sm_session()`` is used instead.

        Raises:
            KeyError: if the test-suite-id environment variable is not set.
        """
        # Resolve the timestamp per call. A ``time.strftime(...)`` default in
        # the signature would be evaluated only once, at definition time, and
        # every instance would silently share the same suffix.
        if suffix is None:
            suffix = time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

        self.account_id = boto3.client("sts").get_caller_identity()["Account"]
        self.suffix = suffix
        self.test_suite_id = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]
        self.region = region
        self.config = boto_config
        self.base_name = base_name
        self.execution_role = execution_role or get_sm_session(
        ).get_caller_identity_arn()
        self.image_uri = image_uri
        self.script_uri = script_uri
        self.model_uri = model_uri
        self.hyperparameters = hyperparameters
        self.instance_type = instance_type
        self.training_dataset_s3_key = training_dataset_s3_key
        # Created last; get_sagemaker_client() is defined outside this view
        # and presumably reads self.region / self.config (confirm).
        self.sagemaker_client = self.get_sagemaker_client()
Exemplo n.º 3
0
    def package_artifacts(self):
        """Build the repacked-model S3 URI and invoke ``repack_model`` for it.

        Sets ``self.model_name`` from ``self.get_model_name()``, composes the
        destination URI from the test artifact bucket, ``self.test_suite_id``
        and the model name, and passes it to ``repack_model`` together with
        ``self.script_uri`` and ``self.model_uri``.

        Returns:
            The destination URI that was handed to ``repack_model``.
        """
        self.model_name = self.get_model_name()

        bucket_root = f"s3://{get_test_artifact_bucket()}"
        key_parts = (
            self.test_suite_id,
            "inference_model_tarballs",
            self.model_name,
            "repacked_model.tar.gz",
        )
        destination_uri = "/".join((bucket_root, *key_parts))

        repack_kwargs = {
            "inference_script": "inference.py",
            "source_directory": self.script_uri,
            "dependencies": None,
            "model_uri": self.model_uri,
            "repacked_model_uri": destination_uri,
            "sagemaker_session": get_sm_session(),
            "kms_key": None,
        }
        repack_model(**repack_kwargs)

        return destination_uri
Exemplo n.º 4
0
def test_jumpstart_transfer_learning_estimator_class(setup):
    """Train a JumpStart model with ``Estimator``, deploy it and invoke it.

    Training phase: retrieves the training-scope image, script and model URIs
    plus default hyperparameters (with ``epochs`` overridden to ``"1"``), then
    fits an ``Estimator`` on the model's training dataset in the JumpStart
    content bucket.

    Inference phase: retrieves the inference-scope image and script URIs,
    deploys the fitted estimator to a freshly named endpoint and asserts that
    invoking it with two sentences returns something other than ``None``.

    Args:
        setup: not referenced in the body; presumably a pytest fixture that
            prepares the suite (confirm against the conftest).
    """
    model_id = "huggingface-spc-bert-base-cased"
    model_version = "1.0.0"
    training_instance_type = "ml.p3.2xlarge"
    inference_instance_type = "ml.p2.xlarge"
    instance_count = 1

    # Keyword arguments shared by every JumpStart retrieval below.
    model_spec = {"model_id": model_id, "model_version": model_version}

    print("Starting training...")

    training_image_uri = image_uris.retrieve(
        region=None,
        framework=None,
        image_scope="training",
        instance_type=training_instance_type,
        **model_spec,
    )
    training_script_uri = script_uris.retrieve(script_scope="training", **model_spec)
    training_model_uri = model_uris.retrieve(model_scope="training", **model_spec)

    training_hyperparameters = hyperparameters.retrieve_default(**model_spec)
    # Override the default epoch count (presumably to keep the job short).
    training_hyperparameters["epochs"] = "1"

    # Two separate get_sm_session() calls, one for the role and one for the
    # session handed to the estimator.
    role_arn = get_sm_session().get_caller_identity_arn()
    session = get_sm_session()
    suite_tag = {
        "Key": JUMPSTART_TAG,
        "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID],
    }

    estimator = Estimator(
        image_uri=training_image_uri,
        source_dir=training_script_uri,
        model_uri=training_model_uri,
        entry_point=TRAINING_ENTRY_POINT_SCRIPT_NAME,
        role=role_arn,
        sagemaker_session=session,
        enable_network_isolation=True,
        hyperparameters=training_hyperparameters,
        tags=[suite_tag],
        instance_count=instance_count,
        instance_type=training_instance_type,
    )

    content_bucket = get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)
    dataset_key = get_training_dataset_for_model_and_version(model_id, model_version)
    estimator.fit({"training": f"s3://{content_bucket}/{dataset_key}"})

    print("Starting inference...")

    inference_image_uri = image_uris.retrieve(
        region=None,
        framework=None,
        image_scope="inference",
        instance_type=inference_instance_type,
        **model_spec,
    )
    inference_script_uri = script_uris.retrieve(script_scope="inference", **model_spec)

    # The returned URI is not used below (deploy goes through the fitted
    # estimator); the call is kept so the retrieval itself is still exercised.
    model_uris.retrieve(model_scope="inference", **model_spec)

    endpoint_name = name_from_base(f"{model_id}-transfer-learning")

    estimator.deploy(
        initial_instance_count=instance_count,
        instance_type=inference_instance_type,
        entry_point=INFERENCE_ENTRY_POINT_SCRIPT_NAME,
        image_uri=inference_image_uri,
        source_dir=inference_script_uri,
        endpoint_name=endpoint_name,
    )

    endpoint_invoker = EndpointInvoker(endpoint_name=endpoint_name)

    response = endpoint_invoker.invoke_spc_endpoint(["hello", "world"])

    assert response is not None