Exemple #1
0
def test_create_model(sagemaker_session, mxnet_version):
    """A model built from a fitted MXNet estimator mirrors the estimator's settings.

    Checks session, framework/python versions, entry point, role, container log
    level and source dir. Also checks that the model name equals the training job
    name, and that no image or VPC config is set.
    """
    log_level = '"logging.INFO"'
    s3_source = "s3://mybucket/source"
    training_job = "new_name"

    estimator = MXNet(
        entry_point=SCRIPT_PATH,
        source_dir=s3_source,
        framework_version=mxnet_version,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        train_instance_count=INSTANCE_COUNT,
        train_instance_type=INSTANCE_TYPE,
        container_log_level=log_level,
        base_job_name="job",
    )
    estimator.fit(inputs="s3://mybucket/train", job_name=training_job)

    model = estimator.create_model()

    # Settings carried over from the estimator.
    assert model.sagemaker_session == sagemaker_session
    assert model.framework_version == mxnet_version
    assert model.py_version == estimator.py_version
    assert model.entry_point == SCRIPT_PATH
    assert model.role == ROLE
    # The model is expected to be named after the training job.
    assert model.name == training_job
    assert model.container_log_level == log_level
    assert model.source_dir == s3_source
    # Nothing custom was requested, so these stay unset.
    assert model.image is None
    assert model.vpc_config is None
Exemple #2
0
    def _create_model(output_path):
        """Train the MNIST example in local mode and return a model built from it.

        Relies on ``sagemaker_local_session``, ``mxnet_inference_latest_version`` and
        ``mxnet_inference_latest_py_version`` from the enclosing scope (not visible
        here).

        :param output_path: ``output_path`` forwarded to the ``MXNet`` estimator.
        :return: the model returned by ``create_model`` on the fitted estimator,
            configured with one model server worker.
        """
        script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist.py")
        data_path = os.path.join(DATA_DIR, "mxnet_mnist")

        mx = MXNet(
            entry_point=script_path,
            role="SageMakerRole",
            instance_count=1,
            instance_type="local",
            output_path=output_path,
            framework_version=mxnet_inference_latest_version,
            py_version=mxnet_inference_latest_py_version,
            sagemaker_session=sagemaker_local_session,
        )

        train_input = mx.sagemaker_session.upload_data(
            path=os.path.join(data_path, "train"),
            key_prefix="integ-test-data/mxnet_mnist/train",
        )
        test_input = mx.sagemaker_session.upload_data(
            path=os.path.join(data_path, "test"),
            key_prefix="integ-test-data/mxnet_mnist/test",
        )

        mx.fit({"train": train_input, "test": test_input})
        # Name the argument explicitly: a bare positional ``1`` hides that this
        # sets the number of model server workers.
        model = mx.create_model(model_server_workers=1)
        return model
Exemple #3
0
def _mxnet_training_job(
    sagemaker_session, container_image, mxnet_version, py_version, cpu_instance_type, learning_rate
):
    """Fit the MNIST example and return a model that uses ``container_image``.

    Runs under ``timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES)``. Both the
    ``train`` and ``test`` channels are uploaded from ``DATA_DIR/mxnet_mnist``
    before fitting.
    """
    with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
        mnist_dir = os.path.join(DATA_DIR, "mxnet_mnist")
        entry = os.path.join(mnist_dir, "mnist.py")

        estimator = MXNet(
            entry_point=entry,
            role=ROLE,
            framework_version=mxnet_version,
            py_version=py_version,
            instance_count=1,
            instance_type=cpu_instance_type,
            sagemaker_session=sagemaker_session,
            hyperparameters={"learning-rate": learning_rate},
        )

        # Channel name -> S3 key prefix; uploaded in this order (train, then test).
        prefixes = {
            "train": "integ-test-data/mxnet_mnist/train",
            "test": "integ-test-data/mxnet_mnist/test",
        }
        channels = {
            channel: estimator.sagemaker_session.upload_data(
                path=os.path.join(mnist_dir, channel), key_prefix=prefix
            )
            for channel, prefix in prefixes.items()
        }
        estimator.fit(channels)

        # Replace the container image value for now since the frameworks do not support
        # multi-model container image yet.
        return estimator.create_model(image_uri=container_image)
def test_create_model_with_custom_image(sagemaker_session):
    """A custom ``image_name`` given to the estimator is carried over to the model.

    Also checks session, entry point, role, container log level and source dir.
    The model name is expected to equal the training job name.
    """
    container_log_level = '"logging.INFO"'
    source_dir = 's3://mybucket/source'
    custom_image = 'mxnet:2.0'
    job_name = 'new_name'
    mx = MXNet(entry_point=SCRIPT_PATH,
               role=ROLE,
               sagemaker_session=sagemaker_session,
               train_instance_count=INSTANCE_COUNT,
               train_instance_type=INSTANCE_TYPE,
               image_name=custom_image,
               container_log_level=container_log_level,
               base_job_name='job',
               source_dir=source_dir)

    # Use the variable rather than a duplicated literal so the name assertion
    # below always checks the value that was actually passed to ``fit``.
    mx.fit(inputs='s3://mybucket/train', job_name=job_name)
    model = mx.create_model()

    assert model.sagemaker_session == sagemaker_session
    assert model.image == custom_image
    assert model.entry_point == SCRIPT_PATH
    assert model.role == ROLE
    assert model.name == job_name
    assert model.container_log_level == container_log_level
    assert model.source_dir == source_dir
def test_create_model_with_optional_params(sagemaker_session):
    """Optional ``create_model`` arguments are reflected on the resulting model.

    Covers ``role``, ``model_server_workers`` and ``vpc_config_override``.
    """
    estimator = MXNet(entry_point=SCRIPT_PATH,
                      role=ROLE,
                      sagemaker_session=sagemaker_session,
                      train_instance_count=INSTANCE_COUNT,
                      train_instance_type=INSTANCE_TYPE,
                      container_log_level='"logging.INFO"',
                      base_job_name='job',
                      source_dir='s3://mybucket/source',
                      enable_cloudwatch_metrics='true')
    estimator.fit(inputs='s3://mybucket/train', job_name='new_name')

    overrides = {
        'role': 'role',
        'model_server_workers': 2,
        'vpc_config_override': {'Subnets': ['foo'], 'SecurityGroupIds': ['bar']},
    }
    model = estimator.create_model(**overrides)

    assert model.role == overrides['role']
    assert model.model_server_workers == overrides['model_server_workers']
    # The override is exposed on the model as ``vpc_config``.
    assert model.vpc_config == overrides['vpc_config_override']
def test_create_model(sagemaker_session, mxnet_version):
    """A model created from a fitted estimator inherits the estimator's settings.

    Checks session, framework/python versions, entry point, role, container log
    level, source dir and CloudWatch metrics flag. The model name is expected
    to equal the training job name.
    """
    container_log_level = '"logging.INFO"'
    source_dir = 's3://mybucket/source'
    enable_cloudwatch_metrics = 'true'
    job_name = 'new_name'
    mx = MXNet(entry_point=SCRIPT_PATH,
               role=ROLE,
               sagemaker_session=sagemaker_session,
               train_instance_count=INSTANCE_COUNT,
               train_instance_type=INSTANCE_TYPE,
               framework_version=mxnet_version,
               container_log_level=container_log_level,
               base_job_name='job',
               source_dir=source_dir,
               enable_cloudwatch_metrics=enable_cloudwatch_metrics)

    # Pass the variable (not a duplicated literal) so the name assertion below
    # cannot drift from the job name actually used for training.
    mx.fit(inputs='s3://mybucket/train', job_name=job_name)
    model = mx.create_model()

    assert model.sagemaker_session == sagemaker_session
    assert model.framework_version == mxnet_version
    assert model.py_version == mx.py_version
    assert model.entry_point == SCRIPT_PATH
    assert model.role == ROLE
    assert model.name == job_name
    assert model.container_log_level == container_log_level
    assert model.source_dir == source_dir
    assert model.enable_cloudwatch_metrics == enable_cloudwatch_metrics
Exemple #7
0
def test_create_model_with_custom_image(name_from_base, sagemaker_session):
    """A custom ``image_uri`` carries over to the model, and the model name comes from ``name_from_base``.

    ``name_from_base`` is a mock (presumably injected by a ``patch`` decorator
    not visible here). The test fixes its return value and checks that it was
    last called with the estimator's base job name.
    """
    log_level = '"logging.INFO"'
    s3_source = "s3://mybucket/source"
    training_image = "mxnet:2.0"
    job_prefix = "job"

    estimator = MXNet(
        entry_point=SCRIPT_NAME,
        source_dir=s3_source,
        framework_version="2.0",
        py_version="py3",
        role=ROLE,
        sagemaker_session=sagemaker_session,
        instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
        image_uri=training_image,
        container_log_level=log_level,
        base_job_name=job_prefix,
    )
    estimator.fit(inputs="s3://mybucket/train", job_name="new_name")

    # Stub the generated model name only after training, right before the model is built.
    generated_name = "model_name"
    name_from_base.return_value = generated_name
    model = estimator.create_model()

    assert model.sagemaker_session == sagemaker_session
    assert model.image_uri == training_image
    assert model.entry_point == SCRIPT_NAME
    assert model.role == ROLE
    assert model.name == generated_name
    assert model.container_log_level == log_level
    assert model.source_dir == s3_source

    name_from_base.assert_called_with(job_prefix)
Exemple #8
0
def test_create_model_with_optional_params(sagemaker_session):
    """Optional ``create_model`` arguments are reflected on the resulting model.

    Covers ``role``, ``model_server_workers``, ``vpc_config_override`` and a
    serving ``entry_point``.
    """
    estimator = MXNet(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        train_instance_count=INSTANCE_COUNT,
        train_instance_type=INSTANCE_TYPE,
        container_log_level='"logging.INFO"',
        base_job_name="job",
        source_dir="s3://mybucket/source",
        enable_cloudwatch_metrics="true",
    )
    estimator.fit(inputs="s3://mybucket/train", job_name="new_name")

    role_override = "role"
    worker_count = 2
    vpc = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]}
    model = estimator.create_model(
        role=role_override,
        model_server_workers=worker_count,
        vpc_config_override=vpc,
        entry_point=SERVING_SCRIPT_FILE,
    )

    assert model.role == role_override
    assert model.model_server_workers == worker_count
    assert model.vpc_config == vpc
    assert model.entry_point == SERVING_SCRIPT_FILE
Exemple #9
0
def test_mnist_training_and_serving(docker_image, sagemaker_local_session,
                                    local_instance_type, framework_version,
                                    tmpdir):
    """Train MNIST in local mode, serve it, and check that one prediction has 10 values.

    Serving happens while holding ``local_mode_utils.lock()``. The estimator's
    endpoint is deleted in a ``finally`` clause, even if prediction fails.
    """
    estimator = MXNet(entry_point=SCRIPT_PATH,
                      role='SageMakerRole',
                      train_instance_count=1,
                      train_instance_type=local_instance_type,
                      sagemaker_session=sagemaker_local_session,
                      image_name=docker_image,
                      framework_version=framework_version,
                      output_path='file://{}'.format(tmpdir))

    _train_and_assert_success(estimator, str(tmpdir))

    with local_mode_utils.lock():
        try:
            predictor = _csv_predictor(
                estimator.create_model(model_server_workers=NUM_MODEL_SERVER_WORKERS),
                local_instance_type)
            prediction = predictor.predict(numpy.zeros(shape=(1, 1, 28, 28)))
        finally:
            estimator.delete_endpoint()

    # Expect one comma-separated value per possible class.
    scores = prediction.decode('utf-8').split(',')
    assert len(scores) == 10
Exemple #10
0
def test_mxnet(
    retrieve_image_uri,
    time,
    strftime,
    repack_model,
    create_tar_file,
    sagemaker_session,
    mxnet_training_version,
    mxnet_training_py_version,
):
    """End-to-end check of fit/create_model/deploy against a mocked session.

    Verifies the recorded session calls and the exact training arguments
    (including the experiment config and disabled SageMaker metrics). It also
    checks the GPU container definition, the CPU image, the predictor type, and
    the tar/repack behaviour relative to the MMS check.
    """
    train_uri = "s3://mybucket/train"
    estimator = MXNet(
        entry_point=SCRIPT_PATH,
        framework_version=mxnet_training_version,
        py_version=mxnet_training_py_version,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
        enable_sagemaker_metrics=False,
    )
    estimator.fit(inputs=train_uri, experiment_config=EXPERIMENT_CONFIG)

    assert [call[0] for call in sagemaker_session.method_calls] == ["train", "logs_for_job"]
    assert [call[0] for call in sagemaker_session.boto_session.method_calls] == ["resource"]

    # Keyword arguments of the first recorded session call ("train").
    actual_train_args = sagemaker_session.method_calls[0][2]
    expected_train_args = _get_train_args(actual_train_args["job_name"])
    expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = train_uri
    expected_train_args.update(
        experiment_config=EXPERIMENT_CONFIG,
        enable_sagemaker_metrics=False,
    )
    assert actual_train_args == expected_train_args

    model = estimator.create_model()

    gpu_container = model.prepare_container_def(GPU)
    expected_container = _get_environment(
        gpu_container["Environment"]["SAGEMAKER_SUBMIT_DIRECTORY"],
        gpu_container["ModelDataUrl"],
        IMAGE,
    )
    assert gpu_container == expected_container

    assert "cpu" in model.prepare_container_def(CPU)["Image"]
    predictor = estimator.deploy(1, GPU)
    assert isinstance(predictor, MXNetPredictor)
    # XOR: the MMS-version check and "tarred without repacking" must disagree
    # (assumes both sides are booleans).
    assert _is_mms_version(mxnet_training_version) ^ (
        create_tar_file.called and not repack_model.called
    )
Exemple #11
0
def test_mxnet(strftime, sagemaker_session, mxnet_version,
               skip_if_mms_version):
    """Fit/create_model/deploy against a mocked session (non-MMS versions).

    Checks the recorded session calls, the training arguments including the
    experiment config, the full GPU container definition, the CPU image and the
    predictor type.
    """
    train_uri = "s3://mybucket/train"
    estimator = MXNet(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        train_instance_count=INSTANCE_COUNT,
        train_instance_type=INSTANCE_TYPE,
        framework_version=mxnet_version,
    )
    estimator.fit(inputs=train_uri, experiment_config=EXPERIMENT_CONFIG)

    session_calls = [name for name, _, _ in sagemaker_session.method_calls]
    assert session_calls == ["train", "logs_for_job"]
    boto_calls = [name for name, _, _ in sagemaker_session.boto_session.method_calls]
    assert boto_calls == ["resource"]

    expected_train_args = _create_train_job(mxnet_version)
    s3_source = expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]
    s3_source["S3Uri"] = train_uri
    expected_train_args["experiment_config"] = EXPERIMENT_CONFIG

    assert sagemaker_session.method_calls[0][2] == expected_train_args

    model = estimator.create_model()

    submit_dir = "s3://mybucket/sagemaker-mxnet-{}/source/sourcedir.tar.gz".format(
        TIMESTAMP)
    expected_container = {
        "Environment": {
            "SAGEMAKER_SUBMIT_DIRECTORY": submit_dir,
            "SAGEMAKER_PROGRAM": "dummy_script.py",
            "SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "false",
            "SAGEMAKER_REGION": "us-west-2",
            "SAGEMAKER_CONTAINER_LOG_LEVEL": "20",
        },
        "Image": "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:{}-gpu-py2".format(
            mxnet_version),
        "ModelDataUrl": "s3://m/m.tar.gz",
    }
    assert expected_container == model.prepare_container_def(GPU)

    assert "cpu" in model.prepare_container_def(CPU)["Image"]
    assert isinstance(estimator.deploy(1, GPU), MXNetPredictor)
def test_custom_image_estimator_deploy(sagemaker_session):
    """An ``image`` passed to ``create_model`` ends up as the model's image."""
    hosting_image = "mycustomimage:latest"
    estimator = MXNet(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        train_instance_count=INSTANCE_COUNT,
        train_instance_type=INSTANCE_TYPE,
    )
    estimator.fit(inputs="s3://mybucket/train", job_name="new_name")

    built = estimator.create_model(image=hosting_image)
    assert built.image == hosting_image
Exemple #13
0
def test_mxnet_mms_version(strftime, repack_model, sagemaker_session,
                           mxnet_version, skip_if_not_mms_version):
    """Fit/create_model/deploy for MMS-based versions against a mocked session.

    For these versions the container definition points both the submit
    directory and the model data at the same ``model.tar.gz`` and uses the
    serving image.
    """
    train_uri = "s3://mybucket/train"
    estimator = MXNet(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        train_instance_count=INSTANCE_COUNT,
        train_instance_type=INSTANCE_TYPE,
        framework_version=mxnet_version,
    )
    estimator.fit(inputs=train_uri)

    assert [name for name, _, _ in sagemaker_session.method_calls] == [
        "train", "logs_for_job"]
    assert [name for name, _, _ in sagemaker_session.boto_session.method_calls] == [
        "resource"]

    expected_train_args = _create_train_job(mxnet_version)
    channel_source = expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]
    channel_source["S3Uri"] = train_uri
    assert sagemaker_session.method_calls[0][2] == expected_train_args

    model = estimator.create_model()

    serving_image = _get_full_image_uri(mxnet_version, IMAGE_REPO_SERVING_NAME, "gpu")
    model_archive = "s3://mybucket/sagemaker-mxnet-2017-11-06-14:14:15.672/model.tar.gz"
    expected_container = {
        "Environment": {
            "SAGEMAKER_SUBMIT_DIRECTORY": model_archive,
            "SAGEMAKER_PROGRAM": "dummy_script.py",
            "SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "false",
            "SAGEMAKER_REGION": "us-west-2",
            "SAGEMAKER_CONTAINER_LOG_LEVEL": "20",
        },
        "Image": serving_image.format(mxnet_version),
        "ModelDataUrl": model_archive,
    }
    assert expected_container == model.prepare_container_def(GPU)

    assert "cpu" in model.prepare_container_def(CPU)["Image"]
    assert isinstance(estimator.deploy(1, GPU), MXNetPredictor)
def test_mxnet(strftime, sagemaker_session, mxnet_version,
               skip_if_mms_version):
    """Fit/create_model/deploy against a mocked session (non-MMS versions).

    Checks the recorded session calls, the training arguments, the GPU
    container definition, the CPU image and the predictor type.
    """
    train_uri = 's3://mybucket/train'
    estimator = MXNet(entry_point=SCRIPT_PATH,
                      role=ROLE,
                      sagemaker_session=sagemaker_session,
                      train_instance_count=INSTANCE_COUNT,
                      train_instance_type=INSTANCE_TYPE,
                      framework_version=mxnet_version)
    estimator.fit(inputs=train_uri)

    recorded = [entry[0] for entry in sagemaker_session.method_calls]
    assert recorded == ['train', 'logs_for_job']
    recorded_boto = [entry[0] for entry in sagemaker_session.boto_session.method_calls]
    assert recorded_boto == ['resource']

    expected_train_args = _create_train_job(mxnet_version)
    data_source = expected_train_args['input_config'][0]['DataSource']
    data_source['S3DataSource']['S3Uri'] = train_uri
    train_kwargs = sagemaker_session.method_calls[0][2]
    assert train_kwargs == expected_train_args

    model = estimator.create_model()

    image_template = '520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:{}-gpu-py2'
    env_vars = {
        'SAGEMAKER_SUBMIT_DIRECTORY':
            's3://mybucket/sagemaker-mxnet-{}/source/sourcedir.tar.gz'.format(TIMESTAMP),
        'SAGEMAKER_PROGRAM': 'dummy_script.py',
        'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false',
        'SAGEMAKER_REGION': 'us-west-2',
        'SAGEMAKER_CONTAINER_LOG_LEVEL': '20',
    }
    expected_container = {
        'Environment': env_vars,
        'Image': image_template.format(mxnet_version),
        'ModelDataUrl': 's3://m/m.tar.gz',
    }
    assert expected_container == model.prepare_container_def(GPU)

    assert 'cpu' in model.prepare_container_def(CPU)['Image']
    predictor = estimator.deploy(1, GPU)
    assert isinstance(predictor, MXNetPredictor)
def mxnet_model(sagemaker_local_session):
    """Train the MNIST example in local mode and return a model built from it.

    :param sagemaker_local_session: session used by the estimator for data
        upload and training.
    :return: the model returned by ``create_model`` on the fitted estimator,
        configured with one model server worker.
    """
    script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'mnist.py')
    data_path = os.path.join(DATA_DIR, 'mxnet_mnist')

    mx = MXNet(entry_point=script_path, role='SageMakerRole',
               train_instance_count=1, train_instance_type='local',
               sagemaker_session=sagemaker_local_session)

    train_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'),
                                                   key_prefix='integ-test-data/mxnet_mnist/train')
    test_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'test'),
                                                  key_prefix='integ-test-data/mxnet_mnist/test')

    mx.fit({'train': train_input, 'test': test_input})
    # Name the argument: a bare positional ``1`` hides that this sets the number
    # of model server workers.
    model = mx.create_model(model_server_workers=1)
    return model
def test_mxnet_mms_version(strftime, repack_model, sagemaker_session,
                           mxnet_version, skip_if_not_mms_version):
    """Fit/create_model/deploy for MMS-based versions against a mocked session.

    The expected container definition uses ``REPACKED_MODEL_DATA`` for both the
    submit directory and the model data URL, together with the serving image.
    """
    train_uri = 's3://mybucket/train'
    estimator = MXNet(entry_point=SCRIPT_PATH,
                      role=ROLE,
                      sagemaker_session=sagemaker_session,
                      train_instance_count=INSTANCE_COUNT,
                      train_instance_type=INSTANCE_TYPE,
                      framework_version=mxnet_version)
    estimator.fit(inputs=train_uri)

    assert [c[0] for c in sagemaker_session.method_calls] == ['train', 'logs_for_job']
    assert [c[0] for c in sagemaker_session.boto_session.method_calls] == ['resource']

    expected_train_args = _create_train_job(mxnet_version)
    input_channel = expected_train_args['input_config'][0]
    input_channel['DataSource']['S3DataSource']['S3Uri'] = train_uri
    assert sagemaker_session.method_calls[0][2] == expected_train_args

    model = estimator.create_model()

    image_template = _get_full_image_uri(mxnet_version, IMAGE_REPO_SERVING_NAME, 'gpu')
    expected_container = dict(
        Environment={
            'SAGEMAKER_SUBMIT_DIRECTORY': REPACKED_MODEL_DATA,
            'SAGEMAKER_PROGRAM': 'dummy_script.py',
            'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false',
            'SAGEMAKER_REGION': 'us-west-2',
            'SAGEMAKER_CONTAINER_LOG_LEVEL': '20',
        },
        Image=image_template.format(mxnet_version),
        ModelDataUrl=REPACKED_MODEL_DATA,
    )
    assert expected_container == model.prepare_container_def(GPU)

    assert 'cpu' in model.prepare_container_def(CPU)['Image']
    predictor = estimator.deploy(1, GPU)
    assert isinstance(predictor, MXNetPredictor)
def test_create_model_with_optional_params(sagemaker_session):
    """``role`` and ``model_server_workers`` passed to ``create_model`` are set on the model."""
    estimator = MXNet(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        train_instance_count=INSTANCE_COUNT,
        train_instance_type=INSTANCE_TYPE,
        container_log_level='"logging.INFO"',
        base_job_name='job',
        source_dir='s3://mybucket/source',
        enable_cloudwatch_metrics='true',
    )
    estimator.fit(inputs='s3://mybucket/train', job_name='new_name')

    role_override, worker_count = 'role', 2
    model = estimator.create_model(role=role_override, model_server_workers=worker_count)

    assert model.role == role_override
    assert model.model_server_workers == worker_count
def mxnet_model(sagemaker_local_session):
    """Train the MNIST example in local mode and return a model built from it.

    :param sagemaker_local_session: session used by the estimator for data
        upload and training.
    :return: the model returned by ``create_model`` on the fitted estimator,
        configured with one model server worker.
    """
    script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'mnist.py')
    data_path = os.path.join(DATA_DIR, 'mxnet_mnist')

    mx = MXNet(entry_point=script_path,
               role='SageMakerRole',
               train_instance_count=1,
               train_instance_type='local',
               sagemaker_session=sagemaker_local_session)

    train_input = mx.sagemaker_session.upload_data(
        path=os.path.join(data_path, 'train'),
        key_prefix='integ-test-data/mxnet_mnist/train')
    test_input = mx.sagemaker_session.upload_data(
        path=os.path.join(data_path, 'test'),
        key_prefix='integ-test-data/mxnet_mnist/test')

    mx.fit({'train': train_input, 'test': test_input})
    # Name the argument: a bare positional ``1`` hides that this sets the number
    # of model server workers.
    model = mx.create_model(model_server_workers=1)
    return model
Exemple #19
0
def test_create_model_with_custom_hosting_image(sagemaker_session):
    """An ``image_uri`` given to ``create_model`` is used instead of the training image."""
    training_image, hosting_image = "mxnet:2.0", "mxnet_hosting:2.0"

    estimator = MXNet(
        entry_point=SCRIPT_PATH,
        framework_version="2.0",
        py_version="py3",
        role=ROLE,
        sagemaker_session=sagemaker_session,
        instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
        image_uri=training_image,
        container_log_level='"logging.INFO"',
        base_job_name="job",
    )
    estimator.fit(inputs="s3://mybucket/train", job_name="new_name")

    hosted = estimator.create_model(image_uri=hosting_image)
    assert hosted.image_uri == hosting_image
Exemple #20
0
def test_create_model_with_optional_params(
    sagemaker_session, mxnet_inference_version, mxnet_inference_py_version
):
    """Every optional ``create_model`` argument is reflected on the resulting model.

    Covers ``role``, ``model_server_workers``, ``vpc_config_override`` (exposed
    as ``vpc_config``), ``entry_point``, ``env`` and ``name``.
    """
    estimator = MXNet(
        entry_point=SCRIPT_NAME,
        source_dir="s3://mybucket/source",
        framework_version=mxnet_inference_version,
        py_version=mxnet_inference_py_version,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
        container_log_level='"logging.INFO"',
        base_job_name="job",
    )
    estimator.fit(inputs="s3://mybucket/train", job_name="new_name")

    role_override = "role"
    worker_count = 2
    vpc = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]}
    explicit_name = "model-name"
    model = estimator.create_model(
        role=role_override,
        model_server_workers=worker_count,
        vpc_config_override=vpc,
        entry_point=SERVING_SCRIPT_FILE,
        env=ENV,
        name=explicit_name,
    )

    assert model.role == role_override
    assert model.model_server_workers == worker_count
    assert model.vpc_config == vpc
    assert model.entry_point == SERVING_SCRIPT_FILE
    assert model.env == ENV
    assert model.name == explicit_name
def test_create_model_with_custom_image(sagemaker_session):
    """A custom ``image_name`` on the estimator is carried over to the model.

    Also checks session, entry point, role, container log level, source dir and
    CloudWatch metrics flag. The model name is expected to equal the training
    job name.
    """
    container_log_level = '"logging.INFO"'
    source_dir = 's3://mybucket/source'
    enable_cloudwatch_metrics = 'true'
    custom_image = 'mxnet:2.0'
    job_name = 'new_name'
    mx = MXNet(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
               train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
               image_name=custom_image, container_log_level=container_log_level,
               base_job_name='job', source_dir=source_dir,
               enable_cloudwatch_metrics=enable_cloudwatch_metrics)

    # Use the variable rather than a duplicated literal so the name assertion
    # below always checks the value actually passed to ``fit``.
    mx.fit(inputs='s3://mybucket/train', job_name=job_name)
    model = mx.create_model()

    assert model.sagemaker_session == sagemaker_session
    assert model.image == custom_image
    assert model.entry_point == SCRIPT_PATH
    assert model.role == ROLE
    assert model.name == job_name
    assert model.container_log_level == container_log_level
    assert model.source_dir == source_dir
    assert model.enable_cloudwatch_metrics == enable_cloudwatch_metrics
def test_mxnet(strftime, sagemaker_session, mxnet_version):
    """Fit/create_model/deploy against a mocked session.

    Checks the recorded session calls, the exact training arguments, the GPU
    container definition, the CPU image and the predictor type.
    """
    train_uri = 's3://mybucket/train'
    estimator = MXNet(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        train_instance_count=INSTANCE_COUNT,
        train_instance_type=INSTANCE_TYPE,
        framework_version=mxnet_version,
    )
    estimator.fit(inputs=train_uri)

    call_names = [recorded[0] for recorded in sagemaker_session.method_calls]
    assert call_names == ['train', 'logs_for_job']
    boto_names = [recorded[0] for recorded in sagemaker_session.boto_session.method_calls]
    assert boto_names == ['resource']

    expected_train_args = _create_train_job(mxnet_version)
    first_channel = expected_train_args['input_config'][0]
    first_channel['DataSource']['S3DataSource']['S3Uri'] = train_uri
    assert sagemaker_session.method_calls[0][2] == expected_train_args

    model = estimator.create_model()

    gpu_image = '520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:{}-gpu-py2'.format(
        mxnet_version)
    expected_container = {
        'Environment': {
            'SAGEMAKER_SUBMIT_DIRECTORY':
                's3://mybucket/sagemaker-mxnet-{}/source/sourcedir.tar.gz'.format(TIMESTAMP),
            'SAGEMAKER_PROGRAM': 'dummy_script.py',
            'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false',
            'SAGEMAKER_REGION': 'us-west-2',
            'SAGEMAKER_CONTAINER_LOG_LEVEL': '20',
        },
        'Image': gpu_image,
        'ModelDataUrl': 's3://m/m.tar.gz',
    }
    assert expected_container == model.prepare_container_def(GPU)

    assert 'cpu' in model.prepare_container_def(CPU)['Image']
    assert isinstance(estimator.deploy(1, GPU), MXNetPredictor)