Example No. 1
def setUp(self):
    self.sagemaker = SageMakerTrainingOperator(
        task_id='test_sagemaker_operator',
        aws_conn_id='sagemaker_test_id',
        config=create_training_params,
        wait_for_completion=False,
        check_interval=5)
Example No. 2
class TestSageMakerTrainingOperator(unittest.TestCase):
    def setUp(self):
        configuration.load_test_config()
        self.sagemaker = SageMakerTrainingOperator(
            task_id='test_sagemaker_operator',
            aws_conn_id='sagemaker_test_id',
            config=create_training_params,
            wait_for_completion=False,
            check_interval=5)

    def test_parse_config_integers(self):
        self.sagemaker.parse_config_integers()
        self.assertEqual(
            self.sagemaker.config['ResourceConfig']['InstanceCount'],
            int(self.sagemaker.config['ResourceConfig']['InstanceCount']))
        self.assertEqual(
            self.sagemaker.config['ResourceConfig']['VolumeSizeInGB'],
            int(self.sagemaker.config['ResourceConfig']['VolumeSizeInGB']))
        self.assertEqual(
            self.sagemaker.config['StoppingCondition']['MaxRuntimeInSeconds'],
            int(self.sagemaker.config['StoppingCondition']
                ['MaxRuntimeInSeconds']))

    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, 'create_training_job')
    def test_execute(self, mock_training, mock_client):
        mock_training.return_value = {
            'TrainingJobArn': 'testarn',
            'ResponseMetadata': {
                'HTTPStatusCode': 200
            }
        }
        self.sagemaker.execute(None)
        mock_training.assert_called_once_with(create_training_params,
                                              wait_for_completion=False,
                                              print_log=True,
                                              check_interval=5,
                                              max_ingestion_time=None)

    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, 'create_training_job')
    def test_execute_with_failure(self, mock_training, mock_client):
        mock_training.return_value = {
            'TrainingJobArn': 'testarn',
            'ResponseMetadata': {
                'HTTPStatusCode': 404
            }
        }
        self.assertRaises(AirflowException, self.sagemaker.execute, None)
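
Both test classes rely on a create_training_params dict that is defined elsewhere in the test module. The following is only an illustrative sketch of what such a config typically contains, following the SageMaker CreateTrainingJob request structure (image URI, role ARN and bucket names are placeholders); the integer fields are given as strings to show what parse_config_integers() converts in the test above.

# Illustrative placeholder for the create_training_params fixture used by the
# tests above (not the actual fixture). Field names follow the SageMaker
# CreateTrainingJob API; all resource names and ARNs are dummies.
create_training_params = {
    'TrainingJobName': 'test-training-job',
    'AlgorithmSpecification': {
        'TrainingImage': '123456789012.dkr.ecr.us-east-1.amazonaws.com/my-algo:1',
        'TrainingInputMode': 'File',
    },
    'RoleArn': 'arn:aws:iam::123456789012:role/SageMakerRole',
    'InputDataConfig': [{
        'ChannelName': 'train',
        'DataSource': {
            'S3DataSource': {
                'S3DataType': 'S3Prefix',
                'S3Uri': 's3://your-bucket/train',
                'S3DataDistributionType': 'FullyReplicated',
            }
        },
    }],
    'OutputDataConfig': {'S3OutputPath': 's3://your-bucket/output'},
    'ResourceConfig': {
        'InstanceCount': '1',  # strings here are what parse_config_integers() turns into ints
        'InstanceType': 'ml.m4.xlarge',
        'VolumeSizeInGB': '10',
    },
    'StoppingCondition': {'MaxRuntimeInSeconds': '3600'},
}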
def setUp(self):
    configuration.load_test_config()
    self.sagemaker = SageMakerTrainingOperator(
        task_id='test_sagemaker_operator',
        aws_conn_id='sagemaker_test_id',
        config=create_training_params,
        wait_for_completion=False,
        check_interval=5
    )
class TestSageMakerTrainingOperator(unittest.TestCase):

    def setUp(self):
        configuration.load_test_config()
        self.sagemaker = SageMakerTrainingOperator(
            task_id='test_sagemaker_operator',
            aws_conn_id='sagemaker_test_id',
            config=create_training_params,
            wait_for_completion=False,
            check_interval=5
        )

    def test_parse_config_integers(self):
        self.sagemaker.parse_config_integers()
        self.assertEqual(self.sagemaker.config['ResourceConfig']['InstanceCount'],
                         int(self.sagemaker.config['ResourceConfig']['InstanceCount']))
        self.assertEqual(self.sagemaker.config['ResourceConfig']['VolumeSizeInGB'],
                         int(self.sagemaker.config['ResourceConfig']['VolumeSizeInGB']))
        self.assertEqual(self.sagemaker.config['StoppingCondition']['MaxRuntimeInSeconds'],
                         int(self.sagemaker.config['StoppingCondition']['MaxRuntimeInSeconds']))

    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, 'create_training_job')
    def test_execute(self, mock_training, mock_client):
        mock_training.return_value = {'TrainingJobArn': 'testarn',
                                      'ResponseMetadata':
                                          {'HTTPStatusCode': 200}}
        self.sagemaker.execute(None)
        mock_training.assert_called_once_with(create_training_params,
                                              wait_for_completion=False,
                                              print_log=True,
                                              check_interval=5,
                                              max_ingestion_time=None
                                              )

    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, 'create_training_job')
    def test_execute_with_failure(self, mock_training, mock_client):
        mock_training.return_value = {'TrainingJobArn': 'testarn',
                                      'ResponseMetadata':
                                          {'HTTPStatusCode': 404}}
        self.assertRaises(AirflowException, self.sagemaker.execute, None)
Example No. 5
def test_pytorch_airflow_config_uploads_data_source_to_s3_when_inputs_not_provided(
        sagemaker_session, cpu_instance_type):
    with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
        estimator = PyTorch(
            entry_point=PYTORCH_MNIST_SCRIPT,
            role=ROLE,
            framework_version="1.1.0",
            train_instance_count=2,
            train_instance_type=cpu_instance_type,
            hyperparameters={
                "epochs": 6,
                "backend": "gloo"
            },
        )

        train_config = sm_airflow.training_config(estimator=estimator)

        uploaded_s3_data = train_config["HyperParameters"][
            "sagemaker_submit_directory"].strip('"')

        transform_config = sm_airflow.transform_config_from_estimator(
            estimator=estimator,
            task_id="transform_config",
            task_type="training",
            instance_count=SINGLE_INSTANCE_COUNT,
            instance_type=cpu_instance_type,
            data=uploaded_s3_data,
            content_type="text/csv",
        )

        default_args = {
            "owner": "airflow",
            "start_date": airflow.utils.dates.days_ago(2),
            "provide_context": True,
        }

        dag = DAG("tensorflow_example",
                  default_args=default_args,
                  schedule_interval="@once")

        train_op = SageMakerTrainingOperator(task_id="tf_training",
                                             config=train_config,
                                             wait_for_completion=True,
                                             dag=dag)

        transform_op = SageMakerTransformOperator(task_id="transform_operator",
                                                  config=transform_config,
                                                  wait_for_completion=True,
                                                  dag=dag)

        transform_op.set_upstream(train_op)

        _assert_that_s3_url_contains_data(sagemaker_session, uploaded_s3_data)
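
The test ends by calling an _assert_that_s3_url_contains_data helper that is defined elsewhere in the integration-test module. A minimal sketch of what such a helper might look like (an assumption, not the SDK's actual implementation) is simply a check that the uploaded S3 prefix is non-empty:

# Hypothetical reconstruction of the _assert_that_s3_url_contains_data helper
# referenced above: list the objects under the uploaded prefix and assert
# that at least one exists.
from urllib.parse import urlparse

def _assert_that_s3_url_contains_data(sagemaker_session, s3_url):
    parsed = urlparse(s3_url)
    s3_client = sagemaker_session.boto_session.client("s3")
    response = s3_client.list_objects_v2(Bucket=parsed.netloc,
                                         Prefix=parsed.path.lstrip("/"))
    assert response["KeyCount"] > 0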
def _build_airflow_workflow(estimator,
                            instance_type,
                            inputs=None,
                            mini_batch_size=None):
    training_config = sm_airflow.training_config(
        estimator=estimator, inputs=inputs, mini_batch_size=mini_batch_size)

    model = estimator.create_model()
    assert model is not None

    model_config = sm_airflow.model_config(instance_type, model)
    assert model_config is not None

    transform_config = sm_airflow.transform_config_from_estimator(
        estimator=estimator,
        task_id="transform_config",
        task_type="training",
        instance_count=SINGLE_INSTANCE_COUNT,
        instance_type=estimator.train_instance_type,
        data=inputs,
        content_type="text/csv",
        input_filter="$",
        output_filter="$",
    )

    default_args = {
        "owner": "airflow",
        "start_date": airflow.utils.dates.days_ago(2),
        "provide_context": True,
    }

    dag = DAG("tensorflow_example",
              default_args=default_args,
              schedule_interval="@once")

    train_op = SageMakerTrainingOperator(task_id="tf_training",
                                         config=training_config,
                                         wait_for_completion=True,
                                         dag=dag)

    transform_op = SageMakerTransformOperator(task_id="transform_operator",
                                              config=transform_config,
                                              wait_for_completion=True,
                                              dag=dag)

    transform_op.set_upstream(train_op)

    return training_config
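
As a usage illustration (not part of the original tests), the helper above could be driven with the same PyTorch estimator pattern used earlier in this example; the instance type and S3 input prefix below are placeholders.

# Sketch only: invoking _build_airflow_workflow with an estimator configured
# like the one in the test above. ROLE, PYTORCH_MNIST_SCRIPT and
# sagemaker_session come from the surrounding test module and fixtures.
estimator = PyTorch(
    entry_point=PYTORCH_MNIST_SCRIPT,
    role=ROLE,
    framework_version="1.1.0",
    train_instance_count=1,
    train_instance_type="ml.c5.xlarge",
    sagemaker_session=sagemaker_session,
)

training_config = _build_airflow_workflow(
    estimator=estimator,
    instance_type="ml.c5.xlarge",
    inputs="s3://your-bucket/pytorch-mnist/training",  # placeholder input prefix
)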
# prepare the data for training
prepare_task = PythonOperator(task_id='preparing',
                              dag=dag,
                              provide_context=False,
                              python_callable=prepare.prepare,
                              op_kwargs=config["prepare_data"])

branching = BranchPythonOperator(task_id='branching',
                                 dag=dag,
                                 python_callable=lambda: "model_tuning"
                                 if hpo_enabled else "model_training")

# launch sagemaker training job and wait until it completes
train_model_task = SageMakerTrainingOperator(task_id='model_training',
                                             dag=dag,
                                             config=train_config,
                                             aws_conn_id='airflow-sagemaker',
                                             wait_for_completion=True,
                                             check_interval=30)

# launch sagemaker hyperparameter job and wait until it completes
tune_model_task = SageMakerTuningOperator(task_id='model_tuning',
                                          dag=dag,
                                          config=tuner_config,
                                          aws_conn_id='airflow-sagemaker',
                                          wait_for_completion=True,
                                          check_interval=30)

# launch sagemaker batch transform job and wait until it completes
batch_transform_task = SageMakerTransformOperator(
    task_id='predicting',
    dag=dag,
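
The snippet above is cut off before the task dependencies are declared. One plausible wiring for this branching pattern (the join task and its trigger rule are assumptions, not shown in the original code) is sketched below.

# Assumed dependency wiring for the branching DAG above. Only one of the
# training/tuning branches runs per DAG run, so the join uses NONE_FAILED
# to tolerate the skipped branch before the batch transform starts.
from airflow.operators.dummy_operator import DummyOperator
from airflow.utils.trigger_rule import TriggerRule

join_task = DummyOperator(task_id='join',
                          dag=dag,
                          trigger_rule=TriggerRule.NONE_FAILED)

prepare_task >> branching
branching >> train_model_task >> join_task
branching >> tune_model_task >> join_task
join_task >> batch_transform_task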
Example No. 8
    task_id='sm_proc_job',
    dag=dag,
    provide_context=True,
    python_callable=sm_proc_job.sm_proc_job,
    op_kwargs={
        'role': role,
        'sess': sess,
        'bucket': bucket,
        'spark_repo_uri': config["processing_job"]['spark_repo_uri'],
        'base_job_name': config["processing_job"]["base_job_name"]
    })

# Train xgboost model task
train_model_task = SageMakerTrainingOperator(task_id='xgboost_model_training',
                                             dag=dag,
                                             config=train_config,
                                             aws_conn_id='airflow-sagemaker',
                                             wait_for_completion=True,
                                             check_interval=30)

# Inference pipeline endpoint task
inference_pipeline_task = PythonOperator(
    task_id='inference_pipeline',
    dag=dag,
    python_callable=inference_pipeline_ep.inference_pipeline_ep,
    op_kwargs={
        'role': role,
        'sess': sess,
        'spark_model_uri':
        config['inference_pipeline']['inputs']['spark_model'],
        'pipeline_model_name':
        config['inference_pipeline']['pipeline_model_name'],
Example No. 9
# define airflow DAG and tasks
# =============================================================================

# define airflow DAG
default_args = {
    'owner': 'airflow',
    'start_date': airflow.utils.dates.days_ago(2),
    'provide_context': True
}

dag = DAG(dag_id='train_dkn',
          default_args=default_args,
          schedule_interval='@once')

train_op = SageMakerTrainingOperator(task_id='tf_training',
                                     config=train_dkn_config,
                                     wait_for_completion=True,
                                     dag=dag)

task_def_op = PythonOperator(task_id='task_definition',
                             python_callable=task_def,
                             op_args=['gw1'],
                             provide_context=True,
                             dag=dag)

deploy_ecs_op = PythonOperator(task_id='run_task',
                               python_callable=deploy_model_service,
                               op_args=['gw1'],
                               provide_context=True,
                               dag=dag)

deploy_ecs_op.set_upstream(task_def_op)
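
The example wires the ECS deployment after the task-definition step but never connects the training task; presumably the task definition is registered only once training finishes, for example:

# Assumed ordering (not shown in the original snippet): register the ECS task
# definition only after the SageMaker training job has completed.
task_def_op.set_upstream(train_op)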
Example No. 11
sagemaker_train_model = SageMakerTrainingOperator(
    task_id="sagemaker_train_model",
    config={
        "TrainingJobName":
        "mnistclassifier-{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}",
        "AlgorithmSpecification": {
            "TrainingImage":
            "438346466558.dkr.ecr.eu-west-1.amazonaws.com/kmeans:1",
            "TrainingInputMode": "File",
        },
        "HyperParameters": {
            "k": "10",
            "feature_dim": "784"
        },
        "InputDataConfig": [{
            "ChannelName": "train",
            "DataSource": {
                "S3DataSource": {
                    "S3DataType": "S3Prefix",
                    "S3Uri": "s3://your-bucket/mnist_data",
                    "S3DataDistributionType": "FullyReplicated",
                }
            },
        }],
        "OutputDataConfig": {
            "S3OutputPath": "s3://your-bucket/mnistclassifier-output"
        },
        "ResourceConfig": {
            "InstanceType": "ml.c4.xlarge",
            "InstanceCount": 1,
            "VolumeSizeInGB": 10,
        },
        "RoleArn": ("arn:aws:iam::297623009465:role/service-role/"
                    "AmazonSageMaker-ExecutionRole-20180905T153196"),
        "StoppingCondition": {
            "MaxRuntimeInSeconds": 24 * 60 * 60
        },
    },
    wait_for_completion=True,
    print_log=True,
    check_interval=10,
    dag=dag,
)
        schedule_interval=None,
        concurrency=1,
        max_active_runs=1,
) as dag:
    process_task = PythonOperator(
        task_id="process",
        dag=dag,
        #provide_context=False,
        python_callable=preprocess_glue,
    )

    train_task = SageMakerTrainingOperator(
        task_id="train",
        config=training_config,
        aws_conn_id="airflow-sagemaker",
        wait_for_completion=True,
        check_interval=60,  # check the status of the job every minute
        max_ingestion_time=None,  # let the training job run as long as it needs; set a value for early stop
    )

    endpoint_deploy_task = SageMakerEndpointOperator(
        task_id="endpoint-deploy",
        config=endpoint_config,
        aws_conn_id="sagemaker-airflow",
        wait_for_completion=True,
        check_interval=60,  # check the status of the endpoint deployment every minute
        max_ingestion_time=None,
        operation='create',  # change to 'update' if updating rather than creating an endpoint
    )
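
This fragment references a training_config and an endpoint_config that are built elsewhere. A sketch of how they are typically produced with the SageMaker Python SDK's Airflow helpers is shown below; the estimator, image URI, role ARN, bucket and endpoint name are all placeholders, not values from the original DAG.

# Sketch only: producing the training_config and endpoint_config dictionaries
# consumed by the operators above. All identifiers below are placeholders.
import sagemaker
import sagemaker.workflow.airflow as sm_airflow
from sagemaker.estimator import Estimator

sess = sagemaker.Session()
estimator = Estimator(
    image_name="123456789012.dkr.ecr.us-east-1.amazonaws.com/my-algo:latest",
    role="arn:aws:iam::123456789012:role/SageMakerRole",
    train_instance_count=1,
    train_instance_type="ml.m5.xlarge",
    output_path="s3://your-bucket/output",
    sagemaker_session=sess,
)

training_config = sm_airflow.training_config(
    estimator=estimator,
    inputs="s3://your-bucket/training-data/",
)

model = estimator.create_model()
endpoint_config = sm_airflow.deploy_config(
    model=model,
    initial_instance_count=1,
    instance_type="ml.m5.large",
    endpoint_name="my-endpoint",
)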