예제 #1
0
 def setUp(self):
     """Build the ``SageMakerTrainingOperator`` fixture stored on ``self.sagemaker``."""
     # wait_for_completion=False is passed so the operator is configured
     # not to wait on the training job.
     operator = SageMakerTrainingOperator(
         check_interval=5,
         wait_for_completion=False,
         config=create_training_params,
         aws_conn_id='sagemaker_test_id',
         task_id='test_sagemaker_operator',
     )
     self.sagemaker = operator
예제 #2
0
class TestSageMakerTrainingOperator(unittest.TestCase):
    """Tests for ``SageMakerTrainingOperator`` with ``SageMakerHook`` methods patched."""

    def setUp(self):
        """Build the operator under test and store it on ``self.sagemaker``."""
        operator = SageMakerTrainingOperator(
            task_id='test_sagemaker_operator',
            aws_conn_id='sagemaker_test_id',
            config=create_training_params,
            wait_for_completion=False,
            check_interval=5,
        )
        self.sagemaker = operator

    def test_parse_config_integers(self):
        """Each listed config field must equal its own ``int()`` conversion after parsing."""
        self.sagemaker.parse_config_integers()
        # (section, field) pairs, checked in this order.
        integer_fields = (
            ('ResourceConfig', 'InstanceCount'),
            ('ResourceConfig', 'VolumeSizeInGB'),
            ('StoppingCondition', 'MaxRuntimeInSeconds'),
        )
        for section, field in integer_fields:
            value = self.sagemaker.config[section][field]
            self.assertEqual(value, int(value))

    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, 'create_training_job')
    def test_execute(self, mock_training, mock_client):
        """With a mocked HTTP 200 response, the hook is called exactly once with the expected arguments."""
        ok_response = {
            'TrainingJobArn': 'testarn',
            'ResponseMetadata': {'HTTPStatusCode': 200},
        }
        mock_training.return_value = ok_response

        self.sagemaker.execute(None)

        expected_kwargs = {
            'wait_for_completion': False,
            'print_log': True,
            'check_interval': 5,
            'max_ingestion_time': None,
        }
        mock_training.assert_called_once_with(create_training_params,
                                              **expected_kwargs)

    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, 'create_training_job')
    def test_execute_with_failure(self, mock_training, mock_client):
        """With a mocked HTTP 404 response, ``execute`` must raise ``AirflowException``."""
        not_found_response = {
            'TrainingJobArn': 'testarn',
            'ResponseMetadata': {'HTTPStatusCode': 404},
        }
        mock_training.return_value = not_found_response

        with self.assertRaises(AirflowException):
            self.sagemaker.execute(None)
예제 #3
0
class TestSageMakerTrainingOperator(unittest.TestCase):
    """Tests for ``SageMakerTrainingOperator`` with ``SageMakerHook`` methods patched."""

    def setUp(self):
        """Build the operator under test and store it on ``self.sagemaker``."""
        operator = SageMakerTrainingOperator(
            task_id='test_sagemaker_operator',
            aws_conn_id='sagemaker_test_id',
            config=create_training_params,
            wait_for_completion=False,
            check_interval=5,
        )
        self.sagemaker = operator

    def test_parse_config_integers(self):
        """Each listed config field must equal its own ``int()`` conversion after parsing."""
        self.sagemaker.parse_config_integers()
        # (section, field) pairs, checked in this order.
        integer_fields = (
            ('ResourceConfig', 'InstanceCount'),
            ('ResourceConfig', 'VolumeSizeInGB'),
            ('StoppingCondition', 'MaxRuntimeInSeconds'),
        )
        for section, field in integer_fields:
            value = self.sagemaker.config[section][field]
            self.assertEqual(value, int(value))

    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, 'create_training_job')
    def test_execute(self, mock_training, mock_client):
        """With a mocked HTTP 200 response, the hook is called exactly once with the expected arguments."""
        ok_response = {
            'TrainingJobArn': 'testarn',
            'ResponseMetadata': {'HTTPStatusCode': 200},
        }
        mock_training.return_value = ok_response

        self.sagemaker.execute(None)

        expected_kwargs = {
            'wait_for_completion': False,
            'print_log': True,
            'check_interval': 5,
            'max_ingestion_time': None,
        }
        mock_training.assert_called_once_with(create_training_params,
                                              **expected_kwargs)

    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, 'create_training_job')
    def test_execute_with_failure(self, mock_training, mock_client):
        """With a mocked HTTP 404 response, ``execute`` must raise ``AirflowException``."""
        not_found_response = {
            'TrainingJobArn': 'testarn',
            'ResponseMetadata': {'HTTPStatusCode': 404},
        }
        mock_training.return_value = not_found_response

        with self.assertRaises(AirflowException):
            self.sagemaker.execute(None)

    # pylint: enable=unused-argument

    @mock.patch.object(SageMakerHook, "get_conn")
    @mock.patch.object(SageMakerHook, "list_training_jobs")
    @mock.patch.object(SageMakerHook, "create_training_job")
    def test_execute_with_existing_job_increment(
        self, mock_create_training_job, mock_list_training_jobs, mock_client
    ):
        """With ``action_if_job_exists="increment"`` and one listed job, the submitted name ends in "-2"."""
        mock_list_training_jobs.return_value = [{"TrainingJobName": job_name}]
        mock_create_training_job.return_value = {
            "ResponseMetadata": {"HTTPStatusCode": 200},
        }
        self.sagemaker.action_if_job_exists = "increment"

        self.sagemaker.execute(None)

        # The expected config is derived only after execute() has run,
        # as a copy of the shared params with the job name overridden.
        # One existing job is reported, so the expected suffix is "-2".
        expected_config = {
            **create_training_params,
            "TrainingJobName": f"{job_name}-2",
        }
        expected_kwargs = {
            "wait_for_completion": False,
            "print_log": True,
            "check_interval": 5,
            "max_ingestion_time": None,
        }
        mock_create_training_job.assert_called_once_with(expected_config,
                                                         **expected_kwargs)

    @mock.patch.object(SageMakerHook, "get_conn")
    @mock.patch.object(SageMakerHook, "list_training_jobs")
    @mock.patch.object(SageMakerHook, "create_training_job")
    def test_execute_with_existing_job_fail(
        self, mock_create_training_job, mock_list_training_jobs, mock_client
    ):
        """With ``action_if_job_exists="fail"`` and one listed job, ``execute`` must raise."""
        mock_list_training_jobs.return_value = [{"TrainingJobName": job_name}]
        mock_create_training_job.return_value = {
            "ResponseMetadata": {"HTTPStatusCode": 200},
        }
        self.sagemaker.action_if_job_exists = "fail"

        with self.assertRaises(AirflowException):
            self.sagemaker.execute(None)
예제 #4
0
# Data-preparation task: calls prepare.prepare, passing the "prepare_data"
# section of config as op_kwargs.
prepare_task = PythonOperator(
    task_id='preparing',
    python_callable=prepare.prepare,
    op_kwargs=config["prepare_data"],
    provide_context=True,
    dag=dag,
)

# Branch selector: the callable returns the task id "model_tuning" when
# hpo_enabled is truthy and "model_training" otherwise.
branching = BranchPythonOperator(
    task_id='branching',
    python_callable=lambda: (
        "model_tuning" if hpo_enabled else "model_training"
    ),
    dag=dag,
)

# Launch the SageMaker training job described by train_config and wait for
# it to complete (check_interval=30).
train_model_task = SageMakerTrainingOperator(
    task_id='model_training',
    config=train_config,
    # aws_conn_id='airflow-sagemaker',
    check_interval=30,
    wait_for_completion=True,
    dag=dag,
)

# Launch the SageMaker hyperparameter tuning job described by tuner_config
# and wait for it to complete (check_interval=30).
tune_model_task = SageMakerTuningOperator(
    task_id='model_tuning',
    config=tuner_config,
    # aws_conn_id='airflow-sagemaker',
    check_interval=30,
    wait_for_completion=True,
    dag=dag,
)

# launch sagemaker batch transform job and wait until it completes
batch_transform_task = SageMakerTransformOperator(
# SageMaker training task using the kmeans:1 training image. It waits for
# the job to complete (check_interval=10) with log printing enabled.
sagemaker_train_model = SageMakerTrainingOperator(
    task_id="sagemaker_train_model",
    dag=dag,
    config={
        # The job name embeds a {{ ... }} template expression that formats
        # execution_date as a timestamp.
        "TrainingJobName": (
            "mnistclassifier-"
            "{{ execution_date.strftime('%Y-%m-%d-%H-%M-%S') }}"
        ),
        "AlgorithmSpecification": {
            "TrainingImage": "438346466558.dkr.ecr.eu-west-1.amazonaws.com/kmeans:1",
            "TrainingInputMode": "File",
        },
        # Hyperparameter values are given as strings.
        "HyperParameters": {
            "k": "10",
            "feature_dim": "784",
        },
        # Single "train" channel read from an S3 prefix.
        "InputDataConfig": [
            {
                "ChannelName": "train",
                "DataSource": {
                    "S3DataSource": {
                        "S3DataType": "S3Prefix",
                        "S3Uri": "s3://your-bucket/mnist_data",
                        "S3DataDistributionType": "FullyReplicated",
                    },
                },
            },
        ],
        "OutputDataConfig": {
            "S3OutputPath": "s3://your-bucket/mnistclassifier-output",
        },
        "ResourceConfig": {
            "InstanceType": "ml.c4.xlarge",
            "InstanceCount": 1,
            "VolumeSizeInGB": 10,
        },
        "RoleArn": (
            "arn:aws:iam::297623009465:role/"
            "service-role/AmazonSageMaker-ExecutionRole-20180905T153196"
        ),
        "StoppingCondition": {
            # 24 hours expressed in seconds.
            "MaxRuntimeInSeconds": 24 * 60 * 60,
        },
    },
    check_interval=10,
    print_log=True,
    wait_for_completion=True,
)
예제 #6
0
        schedule_interval=None,
        concurrency=1,
        max_active_runs=1,
) as dag:
    # Preprocessing task: runs preprocess_glue as a Python callable.
    process_task = PythonOperator(
        task_id="process",
        python_callable=preprocess_glue,
        #provide_context=False,
        dag=dag,
    )

    # Training task: submits training_config and waits for the job to complete.
    train_task = SageMakerTrainingOperator(
        task_id="train",
        aws_conn_id="airflow-sagemaker",
        config=training_config,
        wait_for_completion=True,
        # Check the status of the job every minute.
        check_interval=60,
        # Allow the training job to run as long as it needs; change this
        # value for an early stop.
        max_ingestion_time=None,
    )

    # Endpoint deployment task: submits endpoint_config and waits for the
    # deployment to complete.
    # NOTE(review): this task uses connection id "sagemaker-airflow" while
    # the training task uses "airflow-sagemaker" — confirm this is intended.
    endpoint_deploy_task = SageMakerEndpointOperator(
        task_id="endpoint-deploy",
        aws_conn_id="sagemaker-airflow",
        config=endpoint_config,
        # Change to 'update' when updating rather than creating an endpoint.
        operation='create',
        wait_for_completion=True,
        # Check the status of the endpoint deployment every minute.
        check_interval=60,
        max_ingestion_time=None,
    )