def test_amazon_alg_training_config_required_args(sagemaker_session):
    """Verify ``airflow.training_config`` for an NTM estimator built from required args.

    The estimator is given only a role, a topic count, an instance count/type
    and the session. ``epochs`` is assigned afterwards, and ``mini_batch_size``
    is passed straight to ``training_config``. The expected job name is
    ``"ntm-" + TIME_STAMP``.

    NOTE(review): a later definition with this exact name exists in this file
    and rebinds it, so pytest never collects this body. Rename or remove one
    of the two copies.
    """
    estimator = ntm.NTM(
        role="{{ role }}",
        num_topics=10,
        train_instance_count="{{ instance_count }}",
        train_instance_type="ml.c4.2xlarge",
        sagemaker_session=sagemaker_session,
    )
    estimator.epochs = 32
    records = amazon_estimator.RecordSet("{{ record }}", 10000, 100, "S3Prefix")

    actual_config = airflow.training_config(estimator, records, mini_batch_size=256)

    # A single "train" channel pointing at the templated S3 prefix.
    train_channel = {
        "ChannelName": "train",
        "DataSource": {
            "S3DataSource": {
                "S3DataDistributionType": "ShardedByS3Key",
                "S3DataType": "S3Prefix",
                "S3Uri": "{{ record }}",
            }
        },
    }
    # Hyperparameters are expected as strings. feature_dim presumably comes
    # from the RecordSet's 100 (confirm).
    hyperparameters = {
        "num_topics": "10",
        "epochs": "32",
        "mini_batch_size": "256",
        "feature_dim": "100",
    }
    expected_config = {
        "AlgorithmSpecification": {
            "TrainingImage": "174872318107.dkr.ecr.us-west-2.amazonaws.com/ntm:1",
            "TrainingInputMode": "File",
        },
        "OutputDataConfig": {"S3OutputPath": "s3://output/"},
        "TrainingJobName": "ntm-%s" % TIME_STAMP,
        "StoppingCondition": {"MaxRuntimeInSeconds": 86400},
        "ResourceConfig": {
            "InstanceCount": "{{ instance_count }}",
            "InstanceType": "ml.c4.2xlarge",
            "VolumeSizeInGB": 30,
        },
        "RoleArn": "{{ role }}",
        "InputDataConfig": [train_channel],
        "HyperParameters": hyperparameters,
    }
    assert actual_config == expected_config
def test_amazon_alg_training_config_required_args(sagemaker_session):
    """Check the Airflow training config for an NTM estimator with only required args.

    The expected job name comes from ``get_job_name('ntm')``, which is computed
    before the estimator is built. ``mini_batch_size`` is supplied to
    ``training_config`` rather than set on the estimator.

    NOTE(review): this definition shadows an earlier one with the same name in
    this file. That copy expects the job name ``"ntm-%s" % TIME_STAMP``
    instead of ``get_job_name('ntm')``. Confirm which reflects current
    behaviour and keep only that one.
    """
    expected_job_name = get_job_name('ntm')

    estimator = ntm.NTM(
        role="{{ role }}",
        num_topics=10,
        train_instance_count="{{ instance_count }}",
        train_instance_type="ml.c4.2xlarge",
        sagemaker_session=sagemaker_session,
    )
    estimator.epochs = 32
    records = amazon_estimator.RecordSet("{{ record }}", 10000, 100, 'S3Prefix')

    actual_config = airflow.training_config(estimator, records, mini_batch_size=256)

    s3_source = {
        'S3DataDistributionType': 'ShardedByS3Key',
        'S3DataType': 'S3Prefix',
        'S3Uri': '{{ record }}',
    }
    resources = {
        'InstanceCount': '{{ instance_count }}',
        'InstanceType': 'ml.c4.2xlarge',
        'VolumeSizeInGB': 30,
    }
    expected_config = {
        'AlgorithmSpecification': {
            'TrainingImage': '174872318107.dkr.ecr.us-west-2.amazonaws.com/ntm:1',
            'TrainingInputMode': 'File',
        },
        'OutputDataConfig': {'S3OutputPath': 's3://output/'},
        'TrainingJobName': expected_job_name,
        'StoppingCondition': {'MaxRuntimeInSeconds': 86400},
        'ResourceConfig': resources,
        'RoleArn': '{{ role }}',
        'InputDataConfig': [
            {'DataSource': {'S3DataSource': s3_source}, 'ChannelName': 'train'}
        ],
        'HyperParameters': {
            'num_topics': '10',
            'epochs': '32',
            'mini_batch_size': '256',
            'feature_dim': '100',
        },
    }
    assert actual_config == expected_config
def test_amazon_alg_training_config_all_args(sagemaker_session):
    """Check the Airflow training config when every optional NTM argument is templated.

    Here ``epochs`` and ``mini_batch_size`` are both set as attributes on the
    estimator, and ``training_config`` is called without extra kwargs. The
    expected job name is the base-job-name template followed by an Airflow
    ``execution_date`` timestamp template.

    NOTE(review): a later definition with this exact name exists in this file
    and rebinds it, so pytest never collects this body. That copy passes
    ``mini_batch_size`` to ``training_config`` and expects a
    ``TIME_STAMP``-based job name. Keep only the one that matches current
    behaviour.
    """
    estimator = ntm.NTM(
        role="{{ role }}",
        num_topics=10,
        train_instance_count="{{ instance_count }}",
        train_instance_type="ml.c4.2xlarge",
        train_volume_size="{{ train_volume_size }}",
        train_volume_kms_key="{{ train_volume_kms_key }}",
        train_max_run="{{ train_max_run }}",
        input_mode='Pipe',
        output_path="{{ output_path }}",
        output_kms_key="{{ output_volume_kms_key }}",
        base_job_name="{{ base_job_name }}",
        tags=[{"{{ key }}": "{{ value }}"}],
        subnets=["{{ subnet }}"],
        security_group_ids=["{{ security_group_ids }}"],
        sagemaker_session=sagemaker_session,
    )
    estimator.epochs = 32
    estimator.mini_batch_size = 256
    records = amazon_estimator.RecordSet("{{ record }}", 10000, 100, 'S3Prefix')

    actual_config = airflow.training_config(estimator, records)

    train_channel = {
        'DataSource': {
            'S3DataSource': {
                'S3DataDistributionType': 'ShardedByS3Key',
                'S3DataType': 'S3Prefix',
                'S3Uri': '{{ record }}',
            }
        },
        'ChannelName': 'train',
    }
    vpc = {
        'Subnets': ['{{ subnet }}'],
        'SecurityGroupIds': ['{{ security_group_ids }}'],
    }
    expected_config = {
        'AlgorithmSpecification': {
            'TrainingImage': '174872318107.dkr.ecr.us-west-2.amazonaws.com/ntm:1',
            'TrainingInputMode': 'Pipe',
        },
        'OutputDataConfig': {
            'S3OutputPath': '{{ output_path }}',
            'KmsKeyId': '{{ output_volume_kms_key }}',
        },
        'TrainingJobName': "{{ base_job_name }}-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}",
        'StoppingCondition': {'MaxRuntimeInSeconds': '{{ train_max_run }}'},
        'ResourceConfig': {
            'InstanceCount': '{{ instance_count }}',
            'InstanceType': 'ml.c4.2xlarge',
            'VolumeSizeInGB': '{{ train_volume_size }}',
            'VolumeKmsKeyId': '{{ train_volume_kms_key }}',
        },
        'RoleArn': '{{ role }}',
        'InputDataConfig': [train_channel],
        'VpcConfig': vpc,
        'HyperParameters': {
            'num_topics': '10',
            'epochs': '32',
            'mini_batch_size': '256',
            'feature_dim': '100',
        },
        'Tags': [{'{{ key }}': '{{ value }}'}],
    }
    assert actual_config == expected_config
def test_amazon_alg_training_config_all_args(sagemaker_session):
    """Verify ``airflow.training_config`` when every optional NTM argument is templated.

    Volume size/KMS key, max runtime, Pipe input mode, output path/KMS key,
    base job name, tags and VPC settings are all passed to the estimator.
    ``epochs`` is set afterwards and ``mini_batch_size`` goes directly to
    ``training_config``. The expected job name is the base-job-name template
    suffixed with ``TIME_STAMP``.

    NOTE(review): this definition shadows an earlier one with the same name in
    this file. That copy sets ``mini_batch_size`` on the estimator and expects
    an ``execution_date``-templated job name. Keep only the one that matches
    current behaviour.
    """
    estimator = ntm.NTM(
        role="{{ role }}",
        num_topics=10,
        train_instance_count="{{ instance_count }}",
        train_instance_type="ml.c4.2xlarge",
        train_volume_size="{{ train_volume_size }}",
        train_volume_kms_key="{{ train_volume_kms_key }}",
        train_max_run="{{ train_max_run }}",
        input_mode="Pipe",
        output_path="{{ output_path }}",
        output_kms_key="{{ output_volume_kms_key }}",
        base_job_name="{{ base_job_name }}",
        tags=[{"{{ key }}": "{{ value }}"}],
        subnets=["{{ subnet }}"],
        security_group_ids=["{{ security_group_ids }}"],
        sagemaker_session=sagemaker_session,
    )
    estimator.epochs = 32
    records = amazon_estimator.RecordSet("{{ record }}", 10000, 100, "S3Prefix")

    actual_config = airflow.training_config(estimator, records, mini_batch_size=256)

    algorithm_spec = {
        "TrainingImage": "174872318107.dkr.ecr.us-west-2.amazonaws.com/ntm:1",
        "TrainingInputMode": "Pipe",
    }
    output_config = {
        "S3OutputPath": "{{ output_path }}",
        "KmsKeyId": "{{ output_volume_kms_key }}",
    }
    resource_config = {
        "InstanceCount": "{{ instance_count }}",
        "InstanceType": "ml.c4.2xlarge",
        "VolumeSizeInGB": "{{ train_volume_size }}",
        "VolumeKmsKeyId": "{{ train_volume_kms_key }}",
    }
    train_channel = {
        "DataSource": {
            "S3DataSource": {
                "S3DataDistributionType": "ShardedByS3Key",
                "S3DataType": "S3Prefix",
                "S3Uri": "{{ record }}",
            }
        },
        "ChannelName": "train",
    }
    expected_config = {
        "AlgorithmSpecification": algorithm_spec,
        "OutputDataConfig": output_config,
        "TrainingJobName": "{{ base_job_name }}-%s" % TIME_STAMP,
        "StoppingCondition": {"MaxRuntimeInSeconds": "{{ train_max_run }}"},
        "ResourceConfig": resource_config,
        "RoleArn": "{{ role }}",
        "InputDataConfig": [train_channel],
        "VpcConfig": {
            "Subnets": ["{{ subnet }}"],
            "SecurityGroupIds": ["{{ security_group_ids }}"],
        },
        "HyperParameters": {
            "num_topics": "10",
            "epochs": "32",
            "mini_batch_size": "256",
            "feature_dim": "100",
        },
        "Tags": [{"{{ key }}": "{{ value }}"}],
    }
    assert actual_config == expected_config