def test_train_image_default(sagemaker_session):
    """Check the training image resolved for the latest XGBoost version.

    Builds an ``XGBoost`` estimator pinned to ``XGBOOST_LATEST_VERSION`` and
    verifies that the full CPU image URI for that version is contained in
    the value returned by ``train_image()``.
    """
    estimator = XGBoost(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        framework_version=XGBOOST_LATEST_VERSION,
        sagemaker_session=sagemaker_session,
        train_instance_type=INSTANCE_TYPE,
        train_instance_count=1,
        py_version=PYTHON_VERSION,
    )

    expected_uri = _get_full_cpu_image_uri(XGBOOST_LATEST_VERSION)
    resolved_image = estimator.train_image()

    # Containment (not equality) check, exactly as in the original test.
    assert expected_uri in resolved_image
def test_train_image(sagemaker_session, xgboost_version):
    """Check the exact training image URI produced by an ``XGBoost`` estimator.

    The estimator is configured with an explicit framework version, a
    container log level, a base job name and an S3 source directory; the
    test then asserts that ``train_image()`` returns the expected ECR URI.
    """
    # NOTE(review): the expected URI below hard-codes the 0.90-1 tag while the
    # estimator is built from ``xgboost_version`` -- presumably the fixture
    # only yields that version; confirm if more versions are added.
    expected_image = (
        "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:0.90-1-cpu-py3"
    )

    estimator_kwargs = {
        "entry_point": SCRIPT_PATH,
        "role": ROLE,
        "sagemaker_session": sagemaker_session,
        "train_instance_type": INSTANCE_TYPE,
        "train_instance_count": 1,
        "framework_version": xgboost_version,
        # The log level is deliberately passed as this quoted string literal.
        "container_log_level": '"logging.INFO"',
        "py_version": PYTHON_VERSION,
        "base_job_name": "job",
        "source_dir": "s3://mybucket/source",
    }
    estimator = XGBoost(**estimator_kwargs)

    assert estimator.train_image() == expected_image