# Example #1
# 0
def test_transform_config_templated_inputs(sagemaker_session):
    """Check ``airflow.transform_config`` for a Transformer built from templates.

    A second ``test_transform_config`` is defined later in this module. With a
    shared name, that later ``def`` rebinds the module attribute, so this test
    would be shadowed and never collected by pytest. It therefore carries a
    distinct name.

    The Transformer and the ``transform_config`` call receive Airflow/Jinja
    template strings (``"{{ ... }}"``). The expected config asserts that each
    one appears verbatim in the corresponding request field. The literal
    arguments (model name, instance type, strategy, assemble_with) are expected
    under their request field names.

    Args:
        sagemaker_session: pytest fixture passed to the Transformer as its
            session (defined outside this view).
    """
    tf_transformer = transformer.Transformer(
        model_name="tensorflow-model",
        instance_count="{{ instance_count }}",
        instance_type="ml.p2.xlarge",
        strategy="SingleRecord",
        assemble_with="Line",
        output_path="{{ output_path }}",
        output_kms_key="{{ kms_key }}",
        accept="{{ accept }}",
        max_concurrent_transforms="{{ max_parallel_job }}",
        max_payload="{{ max_payload }}",
        tags=[{"{{ key }}": "{{ value }}"}],
        env={"{{ key }}": "{{ value }}"},
        base_transform_job_name="tensorflow-transform",
        sagemaker_session=sagemaker_session,
        volume_kms_key="{{ kms_key }}",
    )

    data = "{{ transform_data }}"

    config = airflow.transform_config(
        tf_transformer,
        data,
        data_type="S3Prefix",
        content_type="{{ content_type }}",
        compression_type="{{ compression_type }}",
        split_type="{{ split_type }}",
    )

    expected_config = {
        # TIME_STAMP comes from outside this view; presumably the fixed value
        # the job-name generator is patched to produce — TODO confirm.
        "TransformJobName": "tensorflow-transform-%s" % TIME_STAMP,
        "ModelName": "tensorflow-model",
        "TransformInput": {
            "DataSource": {
                "S3DataSource": {
                    "S3DataType": "S3Prefix",
                    "S3Uri": "{{ transform_data }}",
                },
            },
            "ContentType": "{{ content_type }}",
            "CompressionType": "{{ compression_type }}",
            "SplitType": "{{ split_type }}",
        },
        "TransformOutput": {
            "S3OutputPath": "{{ output_path }}",
            # output_kms_key is expected under KmsKeyId here ...
            "KmsKeyId": "{{ kms_key }}",
            "AssembleWith": "Line",
            "Accept": "{{ accept }}",
        },
        "TransformResources": {
            "InstanceCount": "{{ instance_count }}",
            "InstanceType": "ml.p2.xlarge",
            # ... and volume_kms_key under VolumeKmsKeyId.
            "VolumeKmsKeyId": "{{ kms_key }}",
        },
        "BatchStrategy": "SingleRecord",
        "MaxConcurrentTransforms": "{{ max_parallel_job }}",
        "MaxPayloadInMB": "{{ max_payload }}",
        "Environment": {"{{ key }}": "{{ value }}"},
        "Tags": [{"{{ key }}": "{{ value }}"}],
    }

    assert config == expected_config


def test_transform_config(sagemaker_session):
    """Verify the Airflow transform config for a templated TensorFlow Transformer.

    Template placeholders (``"{{ ... }}"``) given to the Transformer and to
    ``airflow.transform_config`` must show up unchanged in the generated
    request. Literal settings are expected under their request field names
    (e.g. ``strategy`` -> ``BatchStrategy``, ``max_payload`` -> ``MaxPayloadInMB``).

    Args:
        sagemaker_session: pytest fixture used as the Transformer's session
            (defined outside this view).
    """
    # Placeholders that are referenced more than once below.
    kms_key = "{{ kms_key }}"
    pair_key, pair_value = "{{ key }}", "{{ value }}"
    input_location = "{{ transform_data }}"

    tf_transformer = transformer.Transformer(
        model_name="tensorflow-model",
        instance_count="{{ instance_count }}",
        instance_type="ml.p2.xlarge",
        strategy="SingleRecord",
        assemble_with="Line",
        output_path="{{ output_path }}",
        output_kms_key=kms_key,
        accept="{{ accept }}",
        max_concurrent_transforms="{{ max_parallel_job }}",
        max_payload="{{ max_payload }}",
        tags=[{pair_key: pair_value}],
        env={pair_key: pair_value},
        base_transform_job_name="tensorflow-transform",
        sagemaker_session=sagemaker_session,
        volume_kms_key=kms_key,
    )

    config = airflow.transform_config(
        tf_transformer,
        input_location,
        data_type="S3Prefix",
        content_type="{{ content_type }}",
        compression_type="{{ compression_type }}",
        split_type="{{ split_type }}",
    )

    # Build the expected request piece by piece, one sub-dict per section.
    transform_input = {
        "DataSource": {
            "S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": input_location},
        },
        "ContentType": "{{ content_type }}",
        "CompressionType": "{{ compression_type }}",
        "SplitType": "{{ split_type }}",
    }
    transform_output = {
        "S3OutputPath": "{{ output_path }}",
        "KmsKeyId": kms_key,
        "AssembleWith": "Line",
        "Accept": "{{ accept }}",
    }
    transform_resources = {
        "InstanceCount": "{{ instance_count }}",
        "InstanceType": "ml.p2.xlarge",
        "VolumeKmsKeyId": kms_key,
    }

    expected_config = {
        # TIME_STAMP is defined outside this view — presumably the fixed
        # timestamp appended to base_transform_job_name; TODO confirm.
        "TransformJobName": "tensorflow-transform-%s" % TIME_STAMP,
        "ModelName": "tensorflow-model",
        "TransformInput": transform_input,
        "TransformOutput": transform_output,
        "TransformResources": transform_resources,
        "BatchStrategy": "SingleRecord",
        "MaxConcurrentTransforms": "{{ max_parallel_job }}",
        "MaxPayloadInMB": "{{ max_payload }}",
        "Environment": {pair_key: pair_value},
        "Tags": [{pair_key: pair_value}],
    }

    assert config == expected_config