コード例 #1
0
def test_create_model(sagemaker_session):
    """Check the default image chosen when no image is given explicitly.

    The model is built for ``XGBOOST_LATEST_VERSION`` without an image URI. The
    container definition produced for ``CPU`` (presumably a CPU instance-type
    constant; confirm) must then report the URI returned by
    ``_get_full_cpu_image_uri`` for that same version.
    """
    model_data_uri = "s3://mybucket/source"

    model = XGBoostModel(
        model_data=model_data_uri,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        entry_point=SCRIPT_PATH,
        framework_version=XGBOOST_LATEST_VERSION,
    )
    expected_image = _get_full_cpu_image_uri(XGBOOST_LATEST_VERSION)

    container_def = model.prepare_container_def(CPU)
    assert container_def["Image"] == expected_image
コード例 #2
0
def test_create_model_with_network_isolation(upload, sagemaker_session,
                                             xgboost_framework_version):
    """Check the container definition when network isolation is enabled.

    The model's uploaded code and repacked model data are set by hand to the
    same S3 URI. The test then asserts two things about the container definition:
    the submit directory environment variable is ``/opt/ml/model/code``, and
    ``ModelDataUrl`` is the repacked URI.

    ``upload`` is not referenced in the body. Presumably it is injected by a
    ``mock.patch`` decorator outside this view (TODO confirm).
    """
    repacked_uri = "s3://mybucket/prefix/model.tar.gz"

    model = XGBoostModel(
        model_data="s3://mybucket/source",
        role=ROLE,
        sagemaker_session=sagemaker_session,
        entry_point=SCRIPT_PATH,
        framework_version=xgboost_framework_version,
        enable_network_isolation=True,
    )
    # Pre-populate the uploaded-code and repacked-data attributes directly.
    # This is presumably so that prepare_container_def does not have to upload
    # or repack anything itself; verify against XGBoostModel.
    model.uploaded_code = UploadedCode(
        s3_prefix=repacked_uri,
        script_name="script",
    )
    model.repacked_model_data = repacked_uri

    container_def = model.prepare_container_def(CPU)
    env = container_def["Environment"]
    assert env["SAGEMAKER_SUBMIT_DIRECTORY"] == "/opt/ml/model/code"
    assert container_def["ModelDataUrl"] == repacked_uri