Exemplo n.º 1
0
def test_train_image_default(sagemaker_session):
    """Check that ``train_image()`` contains the CPU image URI for the latest version.

    An XGBoost estimator is constructed with
    ``framework_version=XGBOOST_LATEST_VERSION`` and a single training
    instance. The URI returned by ``_get_full_cpu_image_uri`` for that version
    must appear as a substring of the estimator's ``train_image()``.
    """
    estimator = XGBoost(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        framework_version=XGBOOST_LATEST_VERSION,
        py_version=PYTHON_VERSION,
        train_instance_count=1,
        train_instance_type=INSTANCE_TYPE,
    )

    # Compute the expected URI before querying the estimator, matching the
    # left-to-right evaluation of the original membership expression.
    expected_uri = _get_full_cpu_image_uri(XGBOOST_LATEST_VERSION)
    actual_uri = estimator.train_image()

    # Substring (not equality) check, as in the original test.
    assert expected_uri in actual_uri
Exemplo n.º 2
0
def test_train_image(sagemaker_session, xgboost_version):
    """Check the exact training image URI for the parametrized ``xgboost_version``.

    The estimator is built with a custom container log level, a base job name
    and an S3 ``source_dir`` alongside ``framework_version=xgboost_version``.
    The assertion compares ``train_image()`` against the ECR URI whose tag is
    derived from that same version, so the test holds for every value the
    ``xgboost_version`` fixture is parametrized with.
    """
    container_log_level = '"logging.INFO"'
    source_dir = "s3://mybucket/source"
    xgboost = XGBoost(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        train_instance_type=INSTANCE_TYPE,
        train_instance_count=1,
        framework_version=xgboost_version,
        container_log_level=container_log_level,
        py_version=PYTHON_VERSION,
        base_job_name="job",
        source_dir=source_dir,
    )

    # Build the expected tag from the fixture rather than hard-coding "0.90-1";
    # otherwise the parametrized version is ignored by the assertion.
    # NOTE(review): registry account/region assume the sagemaker_session
    # fixture targets us-west-2 — confirm against conftest.
    expected_image = (
        "246618743249.dkr.ecr.us-west-2.amazonaws.com/"
        "sagemaker-xgboost:{}-cpu-py3".format(xgboost_version)
    )

    train_image = xgboost.train_image()
    assert train_image == expected_image