def test_create_model(name_from_base, sagemaker_session, chainer_version, chainer_py_version):
    container_log_level = '"logging.INFO"'
    source_dir = "s3://mybucket/source"
    base_job_name = "job"

    chainer = Chainer(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
        framework_version=chainer_version,
        container_log_level=container_log_level,
        py_version=chainer_py_version,
        base_job_name=base_job_name,
        source_dir=source_dir,
    )

    chainer.fit(inputs="s3://mybucket/train", job_name="new_name")

    model_name = "model_name"
    name_from_base.return_value = model_name
    model = chainer.create_model()

    assert model.sagemaker_session == sagemaker_session
    assert model.framework_version == chainer_version
    assert model.py_version == chainer.py_version
    assert model.entry_point == SCRIPT_PATH
    assert model.role == ROLE
    assert model.name == model_name
    assert model.container_log_level == container_log_level
    assert model.source_dir == source_dir
    assert model.vpc_config is None

    name_from_base.assert_called_with(base_job_name)
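# The tests in this listing reference module-level constants, pytest fixtures,
# and mock decorators defined elsewhere in the test module. A minimal sketch of
# plausible definitions (all names and values below are illustrative
# assumptions, not the module's actual ones):
import pytest
from mock import MagicMock

SCRIPT_PATH = "dummy_script.py"
ROLE = "Dummy"
INSTANCE_COUNT = 1
INSTANCE_TYPE = "ml.c4.4xlarge"


@pytest.fixture()
def sagemaker_session():
    # A MagicMock stands in for sagemaker.session.Session, so fit() records a
    # `train` method call instead of launching a real training job.
    return MagicMock(name="sagemaker_session", boto_region_name="us-west-2")

# The `name_from_base` argument in the test above implies a patch decorator
# such as @patch("sagemaker.estimator.name_from_base") applied to it.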
def test_create_model(sagemaker_session, chainer_version):
    container_log_level = '"logging.INFO"'
    source_dir = 's3://mybucket/source'
    enable_cloudwatch_metrics = 'true'
    chainer = Chainer(entry_point=SCRIPT_PATH,
                      role=ROLE,
                      sagemaker_session=sagemaker_session,
                      train_instance_count=INSTANCE_COUNT,
                      train_instance_type=INSTANCE_TYPE,
                      framework_version=chainer_version,
                      container_log_level=container_log_level,
                      py_version=PYTHON_VERSION,
                      base_job_name='job',
                      source_dir=source_dir,
                      enable_cloudwatch_metrics=enable_cloudwatch_metrics)

    job_name = 'new_name'
    chainer.fit(inputs='s3://mybucket/train', job_name=job_name)
    model = chainer.create_model()

    assert model.sagemaker_session == sagemaker_session
    assert model.framework_version == chainer_version
    assert model.py_version == chainer.py_version
    assert model.entry_point == SCRIPT_PATH
    assert model.role == ROLE
    assert model.name == job_name
    assert model.container_log_level == container_log_level
    assert model.source_dir == source_dir
    assert model.enable_cloudwatch_metrics == enable_cloudwatch_metrics
def test_create_model_with_optional_params(sagemaker_session):
    container_log_level = '"logging.INFO"'
    source_dir = 's3://mybucket/source'
    enable_cloudwatch_metrics = 'true'
    chainer = Chainer(entry_point=SCRIPT_PATH,
                      role=ROLE,
                      sagemaker_session=sagemaker_session,
                      train_instance_count=INSTANCE_COUNT,
                      train_instance_type=INSTANCE_TYPE,
                      container_log_level=container_log_level,
                      py_version=PYTHON_VERSION,
                      base_job_name='job',
                      source_dir=source_dir,
                      enable_cloudwatch_metrics=enable_cloudwatch_metrics)

    chainer.fit(inputs='s3://mybucket/train', job_name='new_name')

    new_role = 'role'
    model_server_workers = 2
    vpc_config = {'Subnets': ['foo'], 'SecurityGroupIds': ['bar']}
    model = chainer.create_model(role=new_role,
                                 model_server_workers=model_server_workers,
                                 vpc_config_override=vpc_config)

    assert model.role == new_role
    assert model.model_server_workers == model_server_workers
    assert model.vpc_config == vpc_config
def test_create_model(sagemaker_session, chainer_version):
    container_log_level = '"logging.INFO"'
    source_dir = "s3://mybucket/source"
    chainer = Chainer(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        train_instance_count=INSTANCE_COUNT,
        train_instance_type=INSTANCE_TYPE,
        framework_version=chainer_version,
        container_log_level=container_log_level,
        py_version=PYTHON_VERSION,
        base_job_name="job",
        source_dir=source_dir,
    )

    job_name = "new_name"
    chainer.fit(inputs="s3://mybucket/train", job_name=job_name)
    model = chainer.create_model()

    assert model.sagemaker_session == sagemaker_session
    assert model.framework_version == chainer_version
    assert model.py_version == chainer.py_version
    assert model.entry_point == SCRIPT_PATH
    assert model.role == ROLE
    assert model.name == job_name
    assert model.container_log_level == container_log_level
    assert model.source_dir == source_dir
    assert model.vpc_config is None
def test_create_model_with_optional_params(sagemaker_session):
    container_log_level = '"logging.INFO"'
    source_dir = "s3://mybucket/source"
    enable_cloudwatch_metrics = "true"
    chainer = Chainer(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        train_instance_count=INSTANCE_COUNT,
        train_instance_type=INSTANCE_TYPE,
        container_log_level=container_log_level,
        py_version=PYTHON_VERSION,
        base_job_name="job",
        source_dir=source_dir,
        enable_cloudwatch_metrics=enable_cloudwatch_metrics,
    )

    chainer.fit(inputs="s3://mybucket/train", job_name="new_name")

    new_role = "role"
    model_server_workers = 2
    vpc_config = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]}
    model = chainer.create_model(
        role=new_role,
        model_server_workers=model_server_workers,
        vpc_config_override=vpc_config,
        entry_point=SERVING_SCRIPT_FILE,
    )

    assert model.role == new_role
    assert model.model_server_workers == model_server_workers
    assert model.vpc_config == vpc_config
    assert model.entry_point == SERVING_SCRIPT_FILE
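# For context: vpc_config_override distinguishes "not passed" from "explicitly
# None". A hedged sketch of the other direction, assuming the
# sagemaker.vpc_utils sentinel that the SDK uses as the parameter's default:
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT

def _sketch_vpc_override(chainer):
    # The sentinel default inherits whatever VPC config the estimator carries;
    # passing None explicitly strips VPC settings from the created model.
    assert chainer.create_model(vpc_config_override=VPC_CONFIG_DEFAULT).vpc_config \
        == chainer.create_model().vpc_config
    assert chainer.create_model(vpc_config_override=None).vpc_config is None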
def _test_mnist_train(sagemaker_session, ecr_image, instance_type,
                      instance_count, script):
    source_dir = 'test/resources/mnist'

    with timeout(minutes=15):
        data_path = 'test/resources/mnist/data'

        chainer = Chainer(entry_point=script,
                          source_dir=source_dir,
                          role='SageMakerRole',
                          train_instance_count=instance_count,
                          train_instance_type=instance_type,
                          sagemaker_session=sagemaker_session,
                          image_name=ecr_image,
                          hyperparameters={
                              'batch-size': 10000,
                              'epochs': 1
                          })

        prefix = 'chainer_mnist/{}'.format(sagemaker_timestamp())

        train_data_path = os.path.join(data_path, 'train')

        key_prefix = prefix + '/train'
        train_input = sagemaker_session.upload_data(path=train_data_path,
                                                    key_prefix=key_prefix)

        test_path = os.path.join(data_path, 'test')
        test_input = sagemaker_session.upload_data(path=test_path,
                                                   key_prefix=prefix + '/test')

        chainer.fit({'train': train_input, 'test': test_input})
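# The leading underscore keeps pytest from collecting this helper directly;
# concrete tests call it with their own fixtures. A hypothetical wrapper
# (fixture names and the script filename are assumptions for illustration):
def test_mnist_train_single_machine(sagemaker_session, ecr_image, instance_type):
    _test_mnist_train(sagemaker_session, ecr_image, instance_type,
                      instance_count=1,
                      script='single_machine_customer_script.py')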
def test_chainer(strftime, time, sagemaker_session, chainer_version,
                 chainer_py_version):
    chainer = Chainer(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
        framework_version=chainer_version,
        py_version=chainer_py_version,
    )

    inputs = "s3://mybucket/train"

    chainer.fit(inputs=inputs)

    sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls]
    assert sagemaker_call_names == ["train", "logs_for_job"]
    boto_call_names = [
        c[0] for c in sagemaker_session.boto_session.method_calls
    ]
    assert boto_call_names == ["resource"]

    expected_train_args = _create_train_job(chainer_version,
                                            chainer_py_version)
    expected_train_args["input_config"][0]["DataSource"]["S3DataSource"][
        "S3Uri"] = inputs

    actual_train_args = sagemaker_session.method_calls[0][2]
    assert actual_train_args == expected_train_args

    model = chainer.create_model()

    expected_image_base = "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-chainer:{}-gpu-{}"
    assert {
        "Environment": {
            "SAGEMAKER_SUBMIT_DIRECTORY": "s3://mybucket/sagemaker-chainer-{}/source/sourcedir.tar.gz".format(TIMESTAMP),
            "SAGEMAKER_PROGRAM": "dummy_script.py",
            "SAGEMAKER_REGION": "us-west-2",
            "SAGEMAKER_CONTAINER_LOG_LEVEL": "20",
        },
        "Image": expected_image_base.format(chainer_version, chainer_py_version),
        "ModelDataUrl": "s3://m/m.tar.gz",
    } == model.prepare_container_def(GPU)

    assert "cpu" in model.prepare_container_def(CPU)["Image"]
    predictor = chainer.deploy(1, GPU)
    assert isinstance(predictor, ChainerPredictor)
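# The strftime and time parameters in the signature above imply the test runs
# under patch decorators; a plausible setup (patch targets and the time value
# are assumptions -- note the bottom decorator supplies the first argument):
from mock import patch

@patch("time.time", return_value=1510006209.073025)
@patch("time.strftime", return_value=TIMESTAMP)
def test_chainer(strftime, time, sagemaker_session, chainer_version, chainer_py_version):
    ...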
def test_chainer_mnist_distributed(docker_image, sagemaker_local_session,
                                   instance_type, customer_script, tmpdir):
    if instance_type == 'local_gpu':
        pytest.skip('Local Mode does not support distributed GPU training.')

    # pure_nccl communicator hangs when only one gpu is available.
    cluster_size = 2
    hyperparameters = {
        'sagemaker_process_slots_per_host': 1,
        'sagemaker_num_processes': cluster_size,
        'batch-size': 10000,
        'epochs': 1,
        'communicator': 'hierarchical'
    }

    estimator = Chainer(entry_point=customer_script,
                        source_dir=mnist_path,
                        role=role,
                        image_name=docker_image,
                        train_instance_count=cluster_size,
                        train_instance_type=instance_type,
                        sagemaker_session=sagemaker_local_session,
                        hyperparameters=hyperparameters,
                        output_path='file://{}'.format(tmpdir))

    estimator.fit({
        'train': 'file://{}'.format(os.path.join(data_dir, 'train')),
        'test': 'file://{}'.format(os.path.join(data_dir, 'test'))
    })

    success_files = {
        'model': ['model.npz'],
        'output': [
            'success', 'data/accuracy.png', 'data/cg.dot', 'data/log',
            'data/loss.png'
        ],
    }

    test_utils.files_exist(str(tmpdir), success_files)

    request_data = np.zeros((100, 784), dtype='float32')

    test_utils.predict_and_assert_response_length(estimator, request_data,
                                                  instance_type)
    test_utils.predict_and_assert_response_length(estimator, request_data,
                                                  instance_type,
                                                  json_serializer,
                                                  json_deserializer,
                                                  'application/json')
    test_utils.predict_and_assert_response_length(estimator, request_data,
                                                  instance_type,
                                                  csv_serializer,
                                                  csv_deserializer, 'text/csv')
def test_chainer_mnist_single_machine(docker_image, sagemaker_local_session,
                                      instance_type, tmpdir):
    customer_script = 'single_machine_customer_script.py'
    hyperparameters = {'batch-size': 10000, 'epochs': 1}

    estimator = Chainer(entry_point=customer_script,
                        source_dir=mnist_path,
                        role=role,
                        image_name=docker_image,
                        train_instance_count=1,
                        train_instance_type=instance_type,
                        sagemaker_session=sagemaker_local_session,
                        hyperparameters=hyperparameters,
                        output_path='file://{}'.format(tmpdir))

    estimator.fit({
        'train': 'file://{}'.format(os.path.join(data_dir, 'train')),
        'test': 'file://{}'.format(os.path.join(data_dir, 'test'))
    })

    success_files = {
        'model': ['model.npz'],
        'output': [
            'success', 'data/accuracy.png', 'data/cg.dot', 'data/log',
            'data/loss.png'
        ],
    }
    test_utils.files_exist(str(tmpdir), success_files)

    request_data = np.zeros((100, 784), dtype='float32')

    test_utils.predict_and_assert_response_length(estimator, request_data,
                                                  instance_type)
    test_utils.predict_and_assert_response_length(estimator, request_data,
                                                  instance_type,
                                                  csv_serializer,
                                                  csv_deserializer, 'text/csv')

    test_arrays = [
        np.zeros((100, 784), dtype='float32'),
        np.zeros((100, 1, 28, 28), dtype='float32'),
        np.zeros((100, 28, 28), dtype='float32')
    ]

    with test_utils.local_mode_lock():
        predictor = None
        try:
            predictor = _json_predictor(estimator, instance_type)
            for array in test_arrays:
                response = predictor.predict(array)
                assert len(response) == len(array)
        finally:
            # Guard against a failed deploy, which would leave `predictor` unset.
            if predictor is not None:
                predictor.delete_endpoint()
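# _json_predictor is defined elsewhere in the test module; a plausible sketch
# against the pre-2.0 predictor attributes (an assumption, not the module's
# actual helper; json_serializer/json_deserializer come from the SDK imports):
def _json_predictor(estimator, instance_type):
    predictor = estimator.deploy(1, instance_type)
    # Switch the deployed predictor to JSON on both request and response.
    predictor.content_type = 'application/json'
    predictor.serializer = json_serializer
    predictor.deserializer = json_deserializer
    predictor.accept = 'application/json'
    return predictor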
def test_create_model_with_custom_image(sagemaker_session):
    container_log_level = '"logging.INFO"'
    source_dir = 's3://mybucket/source'
    custom_image = 'ubuntu:latest'
    chainer = Chainer(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
                      train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
                      image_name=custom_image, container_log_level=container_log_level,
                      py_version=PYTHON_VERSION, base_job_name='job', source_dir=source_dir)

    chainer.fit(inputs='s3://mybucket/train', job_name='new_name')
    model = chainer.create_model()

    assert model.image == custom_image
def test_chainer(strftime, sagemaker_session, chainer_version):
    chainer = Chainer(entry_point=SCRIPT_PATH,
                      role=ROLE,
                      sagemaker_session=sagemaker_session,
                      train_instance_count=INSTANCE_COUNT,
                      train_instance_type=INSTANCE_TYPE,
                      py_version=PYTHON_VERSION,
                      framework_version=chainer_version)

    inputs = 's3://mybucket/train'

    chainer.fit(inputs=inputs)

    sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls]
    assert sagemaker_call_names == ['train', 'logs_for_job']
    boto_call_names = [
        c[0] for c in sagemaker_session.boto_session.method_calls
    ]
    assert boto_call_names == ['resource']

    expected_train_args = _create_train_job(chainer_version)
    expected_train_args['input_config'][0]['DataSource']['S3DataSource']['S3Uri'] = inputs

    actual_train_args = sagemaker_session.method_calls[0][2]
    assert actual_train_args == expected_train_args

    model = chainer.create_model()

    expected_image_base = '520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-chainer:{}-gpu-{}'
    assert {
        'Environment': {
            'SAGEMAKER_SUBMIT_DIRECTORY': 's3://mybucket/sagemaker-chainer-{}/source/sourcedir.tar.gz'.format(TIMESTAMP),
            'SAGEMAKER_PROGRAM': 'dummy_script.py',
            'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false',
            'SAGEMAKER_REGION': 'us-west-2',
            'SAGEMAKER_CONTAINER_LOG_LEVEL': '20'
        },
        'Image': expected_image_base.format(chainer_version, PYTHON_VERSION),
        'ModelDataUrl': 's3://m/m.tar.gz'
    } == model.prepare_container_def(GPU)

    assert 'cpu' in model.prepare_container_def(CPU)['Image']
    predictor = chainer.deploy(1, GPU)
    assert isinstance(predictor, ChainerPredictor)
def test_single_machine_failure(docker_image, instance_type, sagemaker_local_session, tmpdir):
    customer_script = 'failure_script.py'
    estimator = Chainer(entry_point=customer_script,
                        source_dir=resource_path,
                        role=role,
                        image_name=docker_image,
                        train_instance_count=1,
                        train_instance_type=instance_type,
                        sagemaker_session=sagemaker_local_session,
                        output_path='file://{}'.format(tmpdir))

    with pytest.raises(RuntimeError):
        estimator.fit()

    failure_files = {'output': ['failure', os.path.join('data', 'this_file_is_expected')]}
    test_utils.files_exist(str(tmpdir), failure_files)
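# failure_script.py lives under resource_path and is not shown here; a
# hypothetical sketch consistent with the assertions above (the env var name
# and fallback path are assumptions):
import os

if __name__ == '__main__':
    output_dir = os.environ.get('SM_OUTPUT_DATA_DIR', '/opt/ml/output/data')
    # Leave a marker so the test can verify output survives a failed job.
    open(os.path.join(output_dir, 'this_file_is_expected'), 'a').close()
    raise Exception('deliberate failure to exercise the failure path')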
def test_create_model_with_optional_params(sagemaker_session):
    container_log_level = '"logging.INFO"'
    source_dir = 's3://mybucket/source'
    enable_cloudwatch_metrics = 'true'
    chainer = Chainer(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
                      train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
                      container_log_level=container_log_level, py_version=PYTHON_VERSION, base_job_name='job',
                      source_dir=source_dir, enable_cloudwatch_metrics=enable_cloudwatch_metrics)

    chainer.fit(inputs='s3://mybucket/train', job_name='new_name')

    new_role = 'role'
    model_server_workers = 2
    model = chainer.create_model(role=new_role, model_server_workers=model_server_workers)

    assert model.role == new_role
    assert model.model_server_workers == model_server_workers
def test_create_model_with_custom_image(sagemaker_session):
    container_log_level = '"logging.INFO"'
    source_dir = "s3://mybucket/source"
    custom_image = "ubuntu:latest"
    chainer = Chainer(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
        image_uri=custom_image,
        container_log_level=container_log_level,
        base_job_name="job",
        source_dir=source_dir,
    )

    chainer.fit(inputs="s3://mybucket/train", job_name="new_name")
    model = chainer.create_model()

    assert model.image_uri == custom_image
def test_create_model(sagemaker_session, chainer_version):
    container_log_level = '"logging.INFO"'
    source_dir = 's3://mybucket/source'
    chainer = Chainer(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
                      train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
                      framework_version=chainer_version, container_log_level=container_log_level,
                      py_version=PYTHON_VERSION, base_job_name='job', source_dir=source_dir)

    job_name = 'new_name'
    chainer.fit(inputs='s3://mybucket/train', job_name=job_name)
    model = chainer.create_model()

    assert model.sagemaker_session == sagemaker_session
    assert model.framework_version == chainer_version
    assert model.py_version == chainer.py_version
    assert model.entry_point == SCRIPT_PATH
    assert model.role == ROLE
    assert model.name == job_name
    assert model.container_log_level == container_log_level
    assert model.source_dir == source_dir
def _test_mnist(sagemaker_session, ecr_image, instance_type, instance_count,
                script):
    source_dir = 'test/resources/mnist'

    with timeout(minutes=15):
        data_path = 'test/resources/mnist/data'

        chainer = Chainer(entry_point=script,
                          source_dir=source_dir,
                          role='SageMakerRole',
                          train_instance_count=instance_count,
                          train_instance_type=instance_type,
                          sagemaker_session=sagemaker_session,
                          image_name=ecr_image,
                          hyperparameters={
                              'batch-size': 10000,
                              'epochs': 1
                          })

        prefix = 'chainer_mnist/{}'.format(sagemaker_timestamp())

        train_data_path = os.path.join(data_path, 'train')

        key_prefix = prefix + '/train'
        train_input = sagemaker_session.upload_data(path=train_data_path,
                                                    key_prefix=key_prefix)

        test_path = os.path.join(data_path, 'test')
        test_input = sagemaker_session.upload_data(path=test_path,
                                                   key_prefix=prefix + '/test')

        chainer.fit({'train': train_input, 'test': test_input})

    with timeout_and_delete_endpoint(estimator=chainer, minutes=30):
        predictor = chainer.deploy(initial_instance_count=1,
                                   instance_type=instance_type)

        batch_size = 100
        data = np.zeros(shape=(batch_size, 1, 28, 28), dtype='float32')
        output = predictor.predict(data)
        assert len(output) == batch_size
def test_chainer_mnist_custom_loop(docker_image, sagemaker_local_session,
                                   instance_type, tmpdir):
    customer_script = 'single_machine_custom_loop.py'
    hyperparameters = {'batch-size': 10000, 'epochs': 1}

    estimator = Chainer(entry_point=customer_script,
                        source_dir=mnist_path,
                        role=role,
                        image_name=docker_image,
                        train_instance_count=1,
                        train_instance_type=instance_type,
                        sagemaker_session=sagemaker_local_session,
                        hyperparameters=hyperparameters,
                        output_path='file://{}'.format(tmpdir))

    estimator.fit({
        'train': 'file://{}'.format(os.path.join(data_dir, 'train')),
        'test': 'file://{}'.format(os.path.join(data_dir, 'test'))
    })

    success_files = {
        'model': ['model.npz'],
        'output': ['success'],
    }

    test_utils.files_exist(str(tmpdir), success_files)

    request_data = np.zeros((100, 784), dtype='float32')

    test_utils.predict_and_assert_response_length(estimator, request_data,
                                                  instance_type)
    test_utils.predict_and_assert_response_length(estimator, request_data,
                                                  instance_type,
                                                  json_serializer,
                                                  json_deserializer,
                                                  'application/json')
    test_utils.predict_and_assert_response_length(estimator, request_data,
                                                  instance_type,
                                                  csv_serializer,
                                                  csv_deserializer, 'text/csv')
def test_all_processes_finish_with_mpi(docker_image, sagemaker_local_session, tmpdir):
    """
    This test validates that all training processes finish before containers are shut down.
    """
    customer_script = 'all_processes_finish_customer_script.py'
    hyperparameters = {'sagemaker_use_mpi': True, 'sagemaker_process_slots_per_host': 2,
                       'sagemaker_num_processes': 4}

    estimator = Chainer(entry_point=customer_script,
                        source_dir=resource_path,
                        role=role,
                        image_name=docker_image,
                        train_instance_count=2,
                        train_instance_type='local',
                        sagemaker_session=sagemaker_local_session,
                        hyperparameters=hyperparameters,
                        output_path='file://{}'.format(tmpdir))

    estimator.fit()

    completion_file = {'output': [os.path.join('data', 'algo-2', 'process_could_complete')]}
    test_utils.files_exist(str(tmpdir), completion_file)
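# all_processes_finish_customer_script.py is not shown; a hypothetical sketch
# of the behavior the test verifies (rank logic and env var are assumptions):
import os
import time
from mpi4py import MPI

if __name__ == '__main__':
    # A slow non-master rank writes its marker only after the faster ranks
    # finish; the marker can exist only if the containers waited for it.
    if MPI.COMM_WORLD.Get_rank() == MPI.COMM_WORLD.Get_size() - 1:
        time.sleep(60)
        output_dir = os.environ.get('SM_OUTPUT_DATA_DIR', '/opt/ml/output/data')
        open(os.path.join(output_dir, 'process_could_complete'), 'a').close()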
def test_distributed_failure(docker_image, sagemaker_local_session, tmpdir):
    customer_script = 'failure_script.py'
    cluster_size = 2
    failure_node = 1
    hyperparameters = {'sagemaker_process_slots_per_host': 1,
                       'sagemaker_num_processes': cluster_size, 'node_to_fail': failure_node}

    estimator = Chainer(entry_point=customer_script,
                        source_dir=resource_path,
                        role=role,
                        image_name=docker_image,
                        train_instance_count=cluster_size,
                        train_instance_type='local',
                        sagemaker_session=sagemaker_local_session,
                        hyperparameters=hyperparameters,
                        output_path='file://{}'.format(tmpdir))

    with pytest.raises(RuntimeError):
        estimator.fit()

    node_failure_file = os.path.join('data', 'file_from_node_{}'.format(failure_node))
    failure_files = {'output': ['failure', node_failure_file]}
    test_utils.files_exist(str(tmpdir), failure_files)
def test_chainer(strftime, sagemaker_session, chainer_version):
    chainer = Chainer(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
                      train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, py_version=PYTHON_VERSION,
                      framework_version=chainer_version)

    inputs = 's3://mybucket/train'

    chainer.fit(inputs=inputs)

    sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls]
    assert sagemaker_call_names == ['train', 'logs_for_job']
    boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls]
    assert boto_call_names == ['resource']

    expected_train_args = _create_train_job(chainer_version)
    expected_train_args['input_config'][0]['DataSource']['S3DataSource']['S3Uri'] = inputs

    actual_train_args = sagemaker_session.method_calls[0][2]
    assert actual_train_args == expected_train_args

    model = chainer.create_model()

    expected_image_base = '520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-chainer:{}-gpu-{}'
    assert {'Environment':
            {'SAGEMAKER_SUBMIT_DIRECTORY':
             's3://mybucket/sagemaker-chainer-{}/source/sourcedir.tar.gz'.format(TIMESTAMP),
             'SAGEMAKER_PROGRAM': 'dummy_script.py',
             'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false',
             'SAGEMAKER_REGION': 'us-west-2',
             'SAGEMAKER_CONTAINER_LOG_LEVEL': '20'},
            'Image': expected_image_base.format(chainer_version, PYTHON_VERSION),
            'ModelDataUrl': 's3://m/m.tar.gz'} == model.prepare_container_def(GPU)

    assert 'cpu' in model.prepare_container_def(CPU)['Image']
    predictor = chainer.deploy(1, GPU)
    assert isinstance(predictor, ChainerPredictor)
def test_training_jobs_do_not_stall(docker_image, sagemaker_local_session, tmpdir):
    """
    This test validates that training fails fast instead of stalling when one process dies.
    https://github.com/chainer/chainermn/issues/236
    """
    customer_script = 'training_jobs_do_not_stall_customer_script.py'
    hyperparameters = {'sagemaker_use_mpi': True, 'sagemaker_process_slots_per_host': 1,
                       'sagemaker_num_processes': 2}

    estimator = Chainer(entry_point=customer_script,
                        source_dir=resource_path,
                        role=role,
                        image_name=docker_image,
                        train_instance_count=2,
                        train_instance_type='local',
                        sagemaker_session=sagemaker_local_session,
                        hyperparameters=hyperparameters,
                        output_path='file://{}'.format(tmpdir))

    with pytest.raises(RuntimeError):
        estimator.fit()

    failure_files = {'output': ['failure', os.path.join('data', 'this_file_is_expected')]}
    test_utils.files_exist(str(tmpdir), failure_files)
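# training_jobs_do_not_stall_customer_script.py is not shown; a hypothetical
# sketch of the scenario (rank logic and env var are assumptions): one MPI
# rank dies while the others keep running, and the job must surface the
# failure instead of hanging (chainer/chainermn#236).
import os
from mpi4py import MPI

if __name__ == '__main__':
    output_dir = os.environ.get('SM_OUTPUT_DATA_DIR', '/opt/ml/output/data')
    open(os.path.join(output_dir, 'this_file_is_expected'), 'a').close()
    if MPI.COMM_WORLD.Get_rank() == 0:
        raise Exception('one rank fails; the whole job should fail fast')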