def test_amazon_alg_training_config_required_args(sagemaker_session):
    """Check the Airflow training config built for a minimally configured NTM estimator.

    The estimator is created with only role, topic count, instance count/type
    and the session fixture; ``epochs`` is set as an attribute afterwards and
    ``mini_batch_size`` is handed to ``airflow.training_config`` directly.
    The resulting config must equal the expected dictionary exactly, with all
    hyperparameters rendered as strings.
    """
    estimator_kwargs = {
        "role": "{{ role }}",
        "num_topics": 10,
        "train_instance_count": "{{ instance_count }}",
        "train_instance_type": "ml.c4.2xlarge",
        "sagemaker_session": sagemaker_session,
    }
    estimator = ntm.NTM(**estimator_kwargs)
    estimator.epochs = 32

    train_records = amazon_estimator.RecordSet("{{ record }}", 10000, 100, "S3Prefix")

    actual = airflow.training_config(estimator, train_records, mini_batch_size=256)

    # Expected config, assembled section by section for readability.
    algorithm_spec = {
        "TrainingImage": "174872318107.dkr.ecr.us-west-2.amazonaws.com/ntm:1",
        "TrainingInputMode": "File",
    }
    resource_config = {
        "InstanceCount": "{{ instance_count }}",
        "InstanceType": "ml.c4.2xlarge",
        "VolumeSizeInGB": 30,
    }
    train_channel = {
        "DataSource": {
            "S3DataSource": {
                "S3DataDistributionType": "ShardedByS3Key",
                "S3DataType": "S3Prefix",
                "S3Uri": "{{ record }}",
            }
        },
        "ChannelName": "train",
    }
    hyperparameters = {
        "num_topics": "10",
        "epochs": "32",
        "mini_batch_size": "256",
        "feature_dim": "100",
    }
    expected = {
        "AlgorithmSpecification": algorithm_spec,
        "OutputDataConfig": {"S3OutputPath": "s3://output/"},
        "TrainingJobName": "ntm-%s" % TIME_STAMP,
        "StoppingCondition": {"MaxRuntimeInSeconds": 86400},
        "ResourceConfig": resource_config,
        "RoleArn": "{{ role }}",
        "InputDataConfig": [train_channel],
        "HyperParameters": hyperparameters,
    }

    assert actual == expected
# Example #2
# 0
def test_amazon_alg_training_config_required_args(sagemaker_session):
    """Check the Airflow training config for an NTM estimator with required args only.

    The expected job name is obtained from ``get_job_name('ntm')`` before the
    estimator is created (helper defined elsewhere; presumably it yields the
    name the config will carry -- confirm against its definition). ``epochs``
    is set on the estimator and ``mini_batch_size`` is passed to
    ``airflow.training_config``; the result must match the expected
    dictionary exactly.
    """
    expected_job_name = get_job_name('ntm')

    estimator = ntm.NTM(
        role="{{ role }}",
        num_topics=10,
        train_instance_count="{{ instance_count }}",
        train_instance_type="ml.c4.2xlarge",
        sagemaker_session=sagemaker_session,
    )
    estimator.epochs = 32

    records = amazon_estimator.RecordSet("{{ record }}", 10000, 100, 'S3Prefix')
    actual = airflow.training_config(estimator, records, mini_batch_size=256)

    # Build the expected config from its named sections.
    s3_source = {
        'S3DataDistributionType': 'ShardedByS3Key',
        'S3DataType': 'S3Prefix',
        'S3Uri': '{{ record }}',
    }
    input_channels = [
        {'DataSource': {'S3DataSource': s3_source}, 'ChannelName': 'train'},
    ]
    expected = {
        'AlgorithmSpecification': {
            'TrainingImage': '174872318107.dkr.ecr.us-west-2.amazonaws.com/ntm:1',
            'TrainingInputMode': 'File',
        },
        'OutputDataConfig': {'S3OutputPath': 's3://output/'},
        'TrainingJobName': expected_job_name,
        'StoppingCondition': {'MaxRuntimeInSeconds': 86400},
        'ResourceConfig': {
            'InstanceCount': '{{ instance_count }}',
            'InstanceType': 'ml.c4.2xlarge',
            'VolumeSizeInGB': 30,
        },
        'RoleArn': '{{ role }}',
        'InputDataConfig': input_channels,
        'HyperParameters': {
            'num_topics': '10',
            'epochs': '32',
            'mini_batch_size': '256',
            'feature_dim': '100',
        },
    }

    assert actual == expected
# Example #3
# 0
def test_amazon_alg_training_config_all_args(sagemaker_session):
    """Check the Airflow training config when every optional NTM argument is supplied.

    Volume size/KMS key, max run time, input mode, output path/KMS key, base
    job name, tags, subnets and security groups are all passed to the
    estimator, mostly as template placeholder strings. ``epochs`` and
    ``mini_batch_size`` are both set as estimator attributes, so
    ``airflow.training_config`` is called with just the estimator and the
    record set. The result must equal the expected dictionary exactly,
    including the templated job name.
    """
    estimator = ntm.NTM(
        role="{{ role }}",
        num_topics=10,
        train_instance_count="{{ instance_count }}",
        train_instance_type="ml.c4.2xlarge",
        train_volume_size="{{ train_volume_size }}",
        train_volume_kms_key="{{ train_volume_kms_key }}",
        train_max_run="{{ train_max_run }}",
        input_mode='Pipe',
        output_path="{{ output_path }}",
        output_kms_key="{{ output_volume_kms_key }}",
        base_job_name="{{ base_job_name }}",
        tags=[{"{{ key }}": "{{ value }}"}],
        subnets=["{{ subnet }}"],
        security_group_ids=["{{ security_group_ids }}"],
        sagemaker_session=sagemaker_session,
    )
    estimator.epochs = 32
    estimator.mini_batch_size = 256

    records = amazon_estimator.RecordSet("{{ record }}", 10000, 100, 'S3Prefix')
    actual = airflow.training_config(estimator, records)

    # Expected sections, named individually and combined below.
    algorithm_spec = {
        'TrainingImage': '174872318107.dkr.ecr.us-west-2.amazonaws.com/ntm:1',
        'TrainingInputMode': 'Pipe',
    }
    output_config = {
        'S3OutputPath': '{{ output_path }}',
        'KmsKeyId': '{{ output_volume_kms_key }}',
    }
    resource_config = {
        'InstanceCount': '{{ instance_count }}',
        'InstanceType': 'ml.c4.2xlarge',
        'VolumeSizeInGB': '{{ train_volume_size }}',
        'VolumeKmsKeyId': '{{ train_volume_kms_key }}',
    }
    train_channel = {
        'DataSource': {
            'S3DataSource': {
                'S3DataDistributionType': 'ShardedByS3Key',
                'S3DataType': 'S3Prefix',
                'S3Uri': '{{ record }}',
            }
        },
        'ChannelName': 'train',
    }
    vpc_config = {
        'Subnets': ['{{ subnet }}'],
        'SecurityGroupIds': ['{{ security_group_ids }}'],
    }
    hyperparameters = {
        'num_topics': '10',
        'epochs': '32',
        'mini_batch_size': '256',
        'feature_dim': '100',
    }
    expected = {
        'AlgorithmSpecification': algorithm_spec,
        'OutputDataConfig': output_config,
        'TrainingJobName': "{{ base_job_name }}-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}",
        'StoppingCondition': {'MaxRuntimeInSeconds': '{{ train_max_run }}'},
        'ResourceConfig': resource_config,
        'RoleArn': '{{ role }}',
        'InputDataConfig': [train_channel],
        'VpcConfig': vpc_config,
        'HyperParameters': hyperparameters,
        'Tags': [{'{{ key }}': '{{ value }}'}],
    }

    assert actual == expected
def test_amazon_alg_training_config_all_args(sagemaker_session):
    """Check the Airflow training config for an NTM estimator with all optional args set.

    Every optional constructor argument (volume, KMS keys, max run, input
    mode, output location, base job name, tags, VPC settings) is supplied.
    ``epochs`` is set as an attribute while ``mini_batch_size`` is passed to
    ``airflow.training_config``. The expected job name is the base job name
    placeholder suffixed with ``TIME_STAMP``; the whole config must match the
    expected dictionary exactly.
    """
    estimator_kwargs = {
        "role": "{{ role }}",
        "num_topics": 10,
        "train_instance_count": "{{ instance_count }}",
        "train_instance_type": "ml.c4.2xlarge",
        "train_volume_size": "{{ train_volume_size }}",
        "train_volume_kms_key": "{{ train_volume_kms_key }}",
        "train_max_run": "{{ train_max_run }}",
        "input_mode": "Pipe",
        "output_path": "{{ output_path }}",
        "output_kms_key": "{{ output_volume_kms_key }}",
        "base_job_name": "{{ base_job_name }}",
        "tags": [{"{{ key }}": "{{ value }}"}],
        "subnets": ["{{ subnet }}"],
        "security_group_ids": ["{{ security_group_ids }}"],
        "sagemaker_session": sagemaker_session,
    }
    estimator = ntm.NTM(**estimator_kwargs)
    estimator.epochs = 32

    train_records = amazon_estimator.RecordSet("{{ record }}", 10000, 100, "S3Prefix")

    actual = airflow.training_config(estimator, train_records, mini_batch_size=256)

    # Expected config, assembled from its individual sections.
    s3_source = {
        "S3DataDistributionType": "ShardedByS3Key",
        "S3DataType": "S3Prefix",
        "S3Uri": "{{ record }}",
    }
    expected = {
        "AlgorithmSpecification": {
            "TrainingImage": "174872318107.dkr.ecr.us-west-2.amazonaws.com/ntm:1",
            "TrainingInputMode": "Pipe",
        },
        "OutputDataConfig": {
            "S3OutputPath": "{{ output_path }}",
            "KmsKeyId": "{{ output_volume_kms_key }}",
        },
        "TrainingJobName": "{{ base_job_name }}-%s" % TIME_STAMP,
        "StoppingCondition": {"MaxRuntimeInSeconds": "{{ train_max_run }}"},
        "ResourceConfig": {
            "InstanceCount": "{{ instance_count }}",
            "InstanceType": "ml.c4.2xlarge",
            "VolumeSizeInGB": "{{ train_volume_size }}",
            "VolumeKmsKeyId": "{{ train_volume_kms_key }}",
        },
        "RoleArn": "{{ role }}",
        "InputDataConfig": [
            {"DataSource": {"S3DataSource": s3_source}, "ChannelName": "train"},
        ],
        "VpcConfig": {
            "Subnets": ["{{ subnet }}"],
            "SecurityGroupIds": ["{{ security_group_ids }}"],
        },
        "HyperParameters": {
            "num_topics": "10",
            "epochs": "32",
            "mini_batch_size": "256",
            "feature_dim": "100",
        },
        "Tags": [{"{{ key }}": "{{ value }}"}],
    }

    assert actual == expected