def test_deploy_framework_model_config(sagemaker_session):
    """Check the Airflow deploy config generated for a Chainer framework model.

    The Jinja placeholders (``{{ ... }}``) handed to the model and to
    ``airflow.deploy_config`` are expected to appear verbatim in the output.
    The model, endpoint config and endpoint are all expected to carry the same
    ``sagemaker-chainer-<TIME_STAMP>`` name.
    """
    chainer_model = chainer.ChainerModel(
        model_data="{{ model_data }}",
        role="{{ role }}",
        entry_point="{{ entry_point }}",
        source_dir="{{ source_dir }}",
        image=None,
        py_version="py3",
        framework_version="5.0.0",
        model_server_workers="{{ model_server_worker }}",
        sagemaker_session=sagemaker_session,
    )

    config = airflow.deploy_config(
        chainer_model,
        initial_instance_count="{{ instance_count }}",
        instance_type="ml.m4.xlarge",
    )

    # One name is shared by the model, the endpoint config and the endpoint.
    shared_name = "sagemaker-chainer-%s" % TIME_STAMP
    # The packaged source_dir is uploaded under the shared name in bucket "output".
    source_key = "%s/source/sourcedir.tar.gz" % shared_name
    image_uri = (
        "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-chainer:5.0.0-cpu-py3"
    )

    expected_model = {
        "ModelName": shared_name,
        "PrimaryContainer": {
            "Image": image_uri,
            "Environment": {
                "SAGEMAKER_PROGRAM": "{{ entry_point }}",
                "SAGEMAKER_SUBMIT_DIRECTORY": "s3://output/%s" % source_key,
                "SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "false",
                "SAGEMAKER_CONTAINER_LOG_LEVEL": "20",
                "SAGEMAKER_REGION": "us-west-2",
                "SAGEMAKER_MODEL_SERVER_WORKERS": "{{ model_server_worker }}",
            },
            "ModelDataUrl": "{{ model_data }}",
        },
        "ExecutionRoleArn": "{{ role }}",
    }
    expected_variant = {
        "InstanceType": "ml.m4.xlarge",
        "InitialInstanceCount": "{{ instance_count }}",
        "ModelName": shared_name,
        "VariantName": "AllTraffic",
        "InitialVariantWeight": 1,
    }
    expected_config = {
        "Model": expected_model,
        "EndpointConfig": {
            "EndpointConfigName": shared_name,
            "ProductionVariants": [expected_variant],
        },
        "Endpoint": {
            "EndpointName": shared_name,
            "EndpointConfigName": shared_name,
        },
        "S3Operations": {
            "S3Upload": [
                {
                    "Path": "{{ source_dir }}",
                    "Bucket": "output",
                    "Key": source_key,
                    "Tar": True,
                }
            ]
        },
    }

    assert config == expected_config
def test_framework_model_config(sagemaker_session):
    """Check the Airflow model config generated for a Chainer framework model.

    Jinja placeholders given to the model are expected to pass through
    unchanged. The model name, the submit directory and the S3 upload key are
    expected to be built from ``sagemaker-chainer-`` plus ``TIME_STAMP``.
    """
    model_kwargs = {
        "model_data": "{{ model_data }}",
        "role": "{{ role }}",
        "entry_point": "{{ entry_point }}",
        "source_dir": "{{ source_dir }}",
        "image": None,
        "py_version": "py3",
        "framework_version": "5.0.0",
        "model_server_workers": "{{ model_server_worker }}",
        "sagemaker_session": sagemaker_session,
    }
    chainer_model = chainer.ChainerModel(**model_kwargs)

    config = airflow.model_config(instance_type="ml.c4.xlarge", model=chainer_model)

    model_name = "sagemaker-chainer-%s" % TIME_STAMP
    upload_key = "%s/source/sourcedir.tar.gz" % model_name
    container_env = {
        "SAGEMAKER_PROGRAM": "{{ entry_point }}",
        "SAGEMAKER_SUBMIT_DIRECTORY": "s3://output/%s" % upload_key,
        "SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "false",
        "SAGEMAKER_CONTAINER_LOG_LEVEL": "20",
        "SAGEMAKER_REGION": "us-west-2",
        "SAGEMAKER_MODEL_SERVER_WORKERS": "{{ model_server_worker }}",
    }
    upload = {
        "Path": "{{ source_dir }}",
        "Bucket": "output",
        "Key": upload_key,
        "Tar": True,
    }
    expected_config = {
        "ModelName": model_name,
        "PrimaryContainer": {
            "Image": "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-chainer:5.0.0-cpu-py3",
            "Environment": container_env,
            "ModelDataUrl": "{{ model_data }}",
        },
        "ExecutionRoleArn": "{{ role }}",
        "S3Operations": {"S3Upload": [upload]},
    }

    assert config == expected_config
def test_framework_model_config_uses_timestamped_names(sagemaker_session):
    """Check that model_config names a Chainer model with ``TIME_STAMP``.

    The generated model name, the ``SAGEMAKER_SUBMIT_DIRECTORY`` environment
    variable and the S3 upload key are all expected to use the
    ``sagemaker-chainer-<TIME_STAMP>`` form, not an Airflow
    ``execution_date`` Jinja template. The other Jinja placeholders supplied
    by the caller must pass through verbatim.

    The test has a distinct name on purpose. Reusing
    ``test_framework_model_config`` would rebind the module attribute and
    silently drop the earlier test of that name from collection.
    """
    chainer_model = chainer.ChainerModel(
        model_data="{{ model_data }}",
        role="{{ role }}",
        entry_point="{{ entry_point }}",
        source_dir="{{ source_dir }}",
        image=None,
        py_version="py3",
        framework_version="5.0.0",
        model_server_workers="{{ model_server_worker }}",
        sagemaker_session=sagemaker_session,
    )

    config = airflow.model_config(instance_type="ml.c4.xlarge", model=chainer_model)

    model_name = "sagemaker-chainer-%s" % TIME_STAMP
    source_key = "%s/source/sourcedir.tar.gz" % model_name

    # Check the name first so a naming regression gives a targeted failure
    # message instead of a large dict diff.
    assert config["ModelName"] == model_name

    expected_config = {
        "ModelName": model_name,
        "PrimaryContainer": {
            "Image": "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-chainer:5.0.0-cpu-py3",
            "Environment": {
                "SAGEMAKER_PROGRAM": "{{ entry_point }}",
                "SAGEMAKER_SUBMIT_DIRECTORY": "s3://output/%s" % source_key,
                "SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "false",
                "SAGEMAKER_CONTAINER_LOG_LEVEL": "20",
                "SAGEMAKER_REGION": "us-west-2",
                "SAGEMAKER_MODEL_SERVER_WORKERS": "{{ model_server_worker }}",
            },
            "ModelDataUrl": "{{ model_data }}",
        },
        "ExecutionRoleArn": "{{ role }}",
        "S3Operations": {
            "S3Upload": [
                {
                    "Path": "{{ source_dir }}",
                    "Bucket": "output",
                    "Key": source_key,
                    "Tar": True,
                }
            ]
        },
    }

    assert config == expected_config