def test_deploy_config_from_framework_estimator(sagemaker_session):
    """Deploy config built from a framework (MXNet) estimator after a
    simulated Airflow training task, checked against the full expected dict."""
    # NOTE(review): an identically-named test is defined again later in this
    # file; pytest collects only the last definition, so this one is shadowed.
    estimator = mxnet.MXNet(
        entry_point="{{ entry_point }}",
        source_dir="{{ source_dir }}",
        py_version="py3",
        framework_version="1.3.0",
        role="{{ role }}",
        train_instance_count=1,
        train_instance_type="ml.m4.xlarge",
        sagemaker_session=sagemaker_session,
        base_job_name="{{ base_job_name }}",
        hyperparameters={"batch_size": 100},
    )

    # Simulate a completed training step so the deploy config can reference it.
    airflow.training_config(estimator, "{{ train_data }}")

    config = airflow.deploy_config_from_estimator(
        estimator=estimator,
        task_id="task_id",
        task_type="training",
        initial_instance_count="{{ instance_count}}",
        instance_type="ml.c4.large",
        endpoint_name="mxnet-endpoint",
    )

    # Model/endpoint-config names are derived from the timestamped base name;
    # S3 URIs are Airflow templates pulling the training job name from XCom.
    model_name = "sagemaker-mxnet-%s" % TIME_STAMP
    job_ref = "{{ ti.xcom_pull(task_ids='task_id')['Training']['TrainingJobName'] }}"

    expected_config = {
        "Model": {
            "ModelName": model_name,
            "PrimaryContainer": {
                "Image": "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:1.3.0-cpu-py3",
                "Environment": {
                    "SAGEMAKER_PROGRAM": "{{ entry_point }}",
                    "SAGEMAKER_SUBMIT_DIRECTORY": "s3://output/" + job_ref + "/source/sourcedir.tar.gz",
                    "SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "false",
                    "SAGEMAKER_CONTAINER_LOG_LEVEL": "20",
                    "SAGEMAKER_REGION": "us-west-2",
                },
                "ModelDataUrl": "s3://output/" + job_ref + "/output/model.tar.gz",
            },
            "ExecutionRoleArn": "{{ role }}",
        },
        "EndpointConfig": {
            "EndpointConfigName": model_name,
            "ProductionVariants": [
                {
                    "InstanceType": "ml.c4.large",
                    "InitialInstanceCount": "{{ instance_count}}",
                    "ModelName": model_name,
                    "VariantName": "AllTraffic",
                    "InitialVariantWeight": 1,
                }
            ],
        },
        "Endpoint": {
            "EndpointName": "mxnet-endpoint",
            "EndpointConfigName": model_name,
        },
    }

    assert config == expected_config
def test_deploy_config_from_amazon_alg_estimator(sagemaker_session):
    """Deploy config built from a first-party (KNN) estimator after a
    simulated tuning task; model data comes from the best training job."""
    # NOTE(review): a later test in this file reuses this exact function name,
    # so pytest collects only the later definition and silently skips this one.
    estimator = knn.KNN(
        role='{{ role }}',
        train_instance_count='{{ instance_count }}',
        train_instance_type='ml.m4.xlarge',
        k=16,
        sample_size=128,
        predictor_type='regressor',
        sagemaker_session=sagemaker_session,
    )
    record = amazon_estimator.RecordSet('{{ record }}', 10000, 100, 'S3Prefix')

    # Simulate a completed tuning run so the deploy config can reference it.
    airflow.training_config(estimator, record, mini_batch_size=256)

    config = airflow.deploy_config_from_estimator(
        estimator=estimator,
        task_id='task_id',
        task_type='tuning',
        initial_instance_count='{{ instance_count }}',
        instance_type='ml.p2.xlarge',
    )

    # For task_type='tuning', the model data URL is templated on the best
    # training job name pulled from the tuning task's XCom result.
    name = 'knn-%s' % TIME_STAMP
    best_job_ref = (
        "{{ ti.xcom_pull(task_ids='task_id')['Tuning']['BestTrainingJob']"
        "['TrainingJobName'] }}"
    )

    expected_config = {
        'Model': {
            'ModelName': name,
            'PrimaryContainer': {
                'Image': '174872318107.dkr.ecr.us-west-2.amazonaws.com/knn:1',
                'Environment': {},
                'ModelDataUrl': 's3://output/' + best_job_ref + '/output/model.tar.gz',
            },
            'ExecutionRoleArn': '{{ role }}',
        },
        'EndpointConfig': {
            'EndpointConfigName': name,
            'ProductionVariants': [
                {
                    'InstanceType': 'ml.p2.xlarge',
                    'InitialInstanceCount': '{{ instance_count }}',
                    'ModelName': name,
                    'VariantName': 'AllTraffic',
                    'InitialVariantWeight': 1,
                }
            ],
        },
        'Endpoint': {
            'EndpointName': name,
            'EndpointConfigName': name,
        },
    }

    assert config == expected_config
def test_deploy_config_from_amazon_alg_estimator(sagemaker_session):
    """Deploy config from a KNN estimator using the default (no task_id)
    code path, where names are templated on the DAG execution date."""
    # NOTE(review): this redefines the identically-named test above; pytest
    # collects only this definition. Consider renaming one of the two.
    estimator = knn.KNN(
        role="{{ role }}",
        train_instance_count="{{ instance_count }}",
        train_instance_type="ml.m4.xlarge",
        k=16,
        sample_size=128,
        predictor_type="regressor",
        sagemaker_session=sagemaker_session,
    )
    record = amazon_estimator.RecordSet("{{ record }}", 10000, 100, "S3Prefix")

    # Simulate training so deploy_config_from_estimator has a job to build on.
    airflow.training_config(estimator, record, mini_batch_size=256)

    config = airflow.deploy_config_from_estimator(
        estimator=estimator,
        initial_instance_count="{{ instance_count }}",
        instance_type="ml.p2.xlarge",
    )

    # Without a task_id, model/endpoint names are timestamped with the Airflow
    # execution_date template rather than an XCom lookup.
    name = "knn-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}"

    expected_config = {
        "Model": {
            "ModelName": name,
            "PrimaryContainer": {
                "Image": "174872318107.dkr.ecr.us-west-2.amazonaws.com/knn:1",
                "Environment": {},
                "ModelDataUrl": "s3://output/" + name + "/output/model.tar.gz",
            },
            "ExecutionRoleArn": "{{ role }}",
        },
        "EndpointConfig": {
            "EndpointConfigName": name,
            "ProductionVariants": [
                {
                    "InstanceType": "ml.p2.xlarge",
                    "InitialInstanceCount": "{{ instance_count }}",
                    "ModelName": name,
                    "VariantName": "AllTraffic",
                    "InitialVariantWeight": 1,
                }
            ],
        },
        "Endpoint": {
            "EndpointName": name,
            "EndpointConfigName": name,
        },
    }

    assert config == expected_config
def test_deploy_config_from_framework_estimator(sagemaker_session):
    """Deploy config for an MXNet estimator after a simulated training task,
    validated section by section against the generated dict."""
    # NOTE(review): this redefines the identically-named test earlier in the
    # file; pytest collects only this definition.
    estimator = mxnet.MXNet(
        entry_point='{{ entry_point }}',
        source_dir='{{ source_dir }}',
        py_version='py3',
        framework_version='1.3.0',
        role='{{ role }}',
        train_instance_count=1,
        train_instance_type='ml.m4.xlarge',
        sagemaker_session=sagemaker_session,
        base_job_name='{{ base_job_name }}',
        hyperparameters={'batch_size': 100},
    )

    # Simulate a prior training task so the deploy config can reference it.
    airflow.training_config(estimator, '{{ train_data }}')

    config = airflow.deploy_config_from_estimator(
        estimator=estimator,
        task_id='task_id',
        task_type='training',
        initial_instance_count='{{ instance_count}}',
        instance_type='ml.c4.large',
        endpoint_name='mxnet-endpoint',
    )

    # Names derive from the timestamped base job name; S3 URIs are Airflow
    # templates that pull the training job name from the task's XCom result.
    model_name = 'sagemaker-mxnet-%s' % TIME_STAMP
    job_ref = "{{ ti.xcom_pull(task_ids='task_id')['Training']['TrainingJobName'] }}"

    expected_model = {
        'ModelName': model_name,
        'PrimaryContainer': {
            'Image': '520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:1.3.0-cpu-py3',
            'Environment': {
                'SAGEMAKER_PROGRAM': '{{ entry_point }}',
                'SAGEMAKER_SUBMIT_DIRECTORY': 's3://output/' + job_ref + '/source/sourcedir.tar.gz',
                'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false',
                'SAGEMAKER_CONTAINER_LOG_LEVEL': '20',
                'SAGEMAKER_REGION': 'us-west-2',
            },
            'ModelDataUrl': 's3://output/' + job_ref + '/output/model.tar.gz',
        },
        'ExecutionRoleArn': '{{ role }}',
    }
    expected_endpoint_config = {
        'EndpointConfigName': model_name,
        'ProductionVariants': [{
            'InstanceType': 'ml.c4.large',
            'InitialInstanceCount': '{{ instance_count}}',
            'ModelName': model_name,
            'VariantName': 'AllTraffic',
            'InitialVariantWeight': 1,
        }],
    }
    expected_endpoint = {
        'EndpointName': 'mxnet-endpoint',
        'EndpointConfigName': model_name,
    }

    assert config == {
        'Model': expected_model,
        'EndpointConfig': expected_endpoint_config,
        'Endpoint': expected_endpoint,
    }
# create tuning config tuner_config = tuning_config( tuner=fm_tuner, inputs=config["tune_model"]["inputs"]) # create transform config transform_config = transform_config_from_estimator( estimator=fm_estimator, task_id="model_tuning" if hpo_enabled else "model_training", task_type="tuning" if hpo_enabled else "training", **config["batch_transform"]["transform_config"] ) deploy_endpoint_config = deploy_config_from_estimator( estimator=fm_estimator, task_id="model_tuning" if hpo_enabled else "model_training", task_type="tuning" if hpo_enabled else "training", **config["deploy_endpoint"] ) # ============================================================================= # define airflow DAG and tasks # ============================================================================= # define airflow DAG args = { 'owner': 'airflow', 'start_date': airflow.utils.dates.days_ago(2) } dag = DAG(
}

# train_config specifies SageMaker training configuration
# NOTE(review): this rebinds the imported `training_config` function name to
# its own result, shadowing the helper for the rest of the module — rename
# the variable (e.g. `train_config`, as the comment suggests) if the helper
# is needed again later.
# NOTE(review): `sagemaker_taining_job_name` looks like a typo of
# "training"; it is defined outside this chunk, so presumably the definition
# carries the same spelling — verify before renaming.
training_config = training_config(estimator=xgb_estimator,
                                  inputs=sagemaker_training_inputs,
                                  job_name=sagemaker_taining_job_name)

# Model/endpoint names are made unique per run by appending the run `guid`
# to the configured prefixes.
sagemaker_model_name = config.SAGEMAKER_MODEL_NAME_PREFIX + '-{}'.format(guid)
sagemaker_endpoint_name = config.SAGEMAKER_ENDPOINT_NAME_PREFIX + '-{}'.format(
    guid)

# endpoint_config specifies SageMaker endpoint configuration
# Built from the estimator; reads the training job result of the "train"
# task to locate the model artifacts.
endpoint_config = deploy_config_from_estimator(
    estimator=xgb_estimator,
    task_id="train",
    task_type="training",
    initial_instance_count=1,
    instance_type="ml.m4.xlarge",
    model_name=sagemaker_model_name,
    endpoint_name=sagemaker_endpoint_name)

# =============================================================================
# define airflow DAG and tasks
# =============================================================================

# define airflow DAG
# Default arguments applied to every task in the DAG.
args = {
    "owner": "airflow",
    "start_date": airflow.utils.dates.days_ago(2),
    'depends_on_past': False
}