def test_train_image_default(sagemaker_session):
    """Check the training image when no ``framework_version`` is given.

    The estimator is built without an explicit framework version. The test
    then checks that the CPU image URI computed for
    ``defaults.SKLEARN_VERSION`` appears within ``train_image()``.
    """
    estimator = SKLearn(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        train_instance_type=INSTANCE_TYPE,
        py_version=PYTHON_VERSION,
    )
    expected_uri = _get_full_cpu_image_uri(defaults.SKLEARN_VERSION)
    # Substring match, not equality: only require the expected URI to be contained.
    assert expected_uri in estimator.train_image()
def test_train_image(sagemaker_session, sklearn_version):
    """Check the training image URI for an explicitly requested version.

    The estimator is configured with ``framework_version=sklearn_version``
    (plus log level, base job name and an S3 source dir). The test asserts
    that ``train_image()`` returns the full CPU image URI tagged with that
    same version.
    """
    container_log_level = '"logging.INFO"'
    source_dir = 's3://mybucket/source'
    sklearn = SKLearn(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        train_instance_type=INSTANCE_TYPE,
        framework_version=sklearn_version,
        container_log_level=container_log_level,
        py_version=PYTHON_VERSION,
        base_job_name='job',
        source_dir=source_dir,
    )

    train_image = sklearn.train_image()

    # Build the tag from the fixture rather than hard-coding '0.20.0', so the
    # assertion keeps matching the version actually passed to the estimator.
    # NOTE(review): the registry host assumes the sagemaker_session fixture
    # is configured for us-west-2 and that the '-cpu-py3' suffix matches
    # PYTHON_VERSION; confirm against conftest.
    expected_image = (
        '246618743249.dkr.ecr.us-west-2.amazonaws.com/'
        'sagemaker-scikit-learn:{}-cpu-py3'.format(sklearn_version)
    )
    assert train_image == expected_image