コード例 #1
0
def test_train_image_default(sagemaker_session):
    """Check the training image when no framework_version is given.

    The estimator is built without ``framework_version``; its training image
    is expected to contain the CPU image URI for ``defaults.SKLEARN_VERSION``.
    """
    estimator = SKLearn(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        train_instance_type=INSTANCE_TYPE,
        py_version=PYTHON_VERSION,
    )

    default_image_uri = _get_full_cpu_image_uri(defaults.SKLEARN_VERSION)
    assert default_image_uri in estimator.train_image()
コード例 #2
0
def test_train_image(sagemaker_session, sklearn_version):
    """Check that an explicit framework_version selects the matching CPU image.

    The expected URI is built from the ``sklearn_version`` fixture instead of
    a hard-coded version string. This keeps the assertion correct for every
    value the fixture supplies, rather than only for ``0.20.0``.

    Extra options (container log level, base job name, source dir) are also
    passed to the estimator. The assertion checks only the resulting image.
    """
    container_log_level = '"logging.INFO"'
    source_dir = 's3://mybucket/source'
    sklearn = SKLearn(entry_point=SCRIPT_PATH,
                      role=ROLE,
                      sagemaker_session=sagemaker_session,
                      train_instance_type=INSTANCE_TYPE,
                      framework_version=sklearn_version,
                      container_log_level=container_log_level,
                      py_version=PYTHON_VERSION,
                      base_job_name='job',
                      source_dir=source_dir)

    # NOTE(review): the registry account/region and the "py3" suffix are kept
    # literal as in the original. They presumably correspond to this file's
    # session region and PYTHON_VERSION — confirm.
    # Only the version component must track the fixture.
    expected_image = ('246618743249.dkr.ecr.us-west-2.amazonaws.com/'
                      'sagemaker-scikit-learn:{}-cpu-py3'.format(sklearn_version))

    train_image = sklearn.train_image()
    assert train_image == expected_image