コード例 #1
0
 def setUp(self):
     """Build the SageMakerTuningOperator fixture used by the tests.

     The operator is created with ``wait_for_completion=False`` and a
     ``check_interval`` of 5, and stored on ``self.sagemaker``.
     """
     tuning_operator = SageMakerTuningOperator(
         task_id='test_sagemaker_operator',
         aws_conn_id='sagemaker_test_conn',
         config=create_tuning_params,
         wait_for_completion=False,
         check_interval=5,
     )
     self.sagemaker = tuning_operator
コード例 #2
0
class TestSageMakerTuningOperator(unittest.TestCase):
    """Unit tests for SageMakerTuningOperator with SageMakerHook patched out."""

    def setUp(self):
        """Create a fresh, non-waiting tuning operator for every test."""
        self.sagemaker = SageMakerTuningOperator(
            task_id='test_sagemaker_operator',
            aws_conn_id='sagemaker_test_conn',
            config=create_tuning_params,
            wait_for_completion=False,
            check_interval=5,
        )

    def test_parse_config_integers(self):
        """After parsing, each listed field equals its own int() conversion."""
        self.sagemaker.parse_config_integers()

        # (top-level section, sub-section, field) paths into the config,
        # checked in this order; the first mismatch fails the test.
        integer_fields = (
            ('TrainingJobDefinition', 'ResourceConfig', 'InstanceCount'),
            ('TrainingJobDefinition', 'ResourceConfig', 'VolumeSizeInGB'),
            ('HyperParameterTuningJobConfig', 'ResourceLimits',
             'MaxNumberOfTrainingJobs'),
            ('HyperParameterTuningJobConfig', 'ResourceLimits',
             'MaxParallelTrainingJobs'),
        )
        for section, group, field in integer_fields:
            parsed = self.sagemaker.config[section][group][field]
            self.assertEqual(parsed, int(parsed))

    # Decorators apply bottom-up: the innermost patch (create_tuning_job)
    # is passed first, the outermost (get_conn) second.
    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, 'create_tuning_job')
    def test_execute(self, mock_tuning, mock_client):
        """With a 200 response, execute() calls the hook once with the operator's settings."""
        success_response = {
            'TrainingJobArn': 'testarn',
            'ResponseMetadata': {'HTTPStatusCode': 200},
        }
        mock_tuning.return_value = success_response

        self.sagemaker.execute(None)

        mock_tuning.assert_called_once_with(
            create_tuning_params,
            wait_for_completion=False,
            check_interval=5,
            max_ingestion_time=None,
        )

    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, 'create_tuning_job')
    def test_execute_with_failure(self, mock_tuning, mock_client):
        """A 404 status in the hook's response makes execute() raise AirflowException."""
        failure_response = {
            'TrainingJobArn': 'testarn',
            'ResponseMetadata': {'HTTPStatusCode': 404},
        }
        mock_tuning.return_value = failure_response

        with self.assertRaises(AirflowException):
            self.sagemaker.execute(None)
コード例 #3
0
                                 if hpo_enabled else "model_training")

# Task definitions for the DAG. None of the SageMaker tasks below sets
# aws_conn_id, so each operator falls back to its own default; pass
# aws_conn_id='airflow-sagemaker' to any of them to use that connection.

# Start the SageMaker training job; the task waits until it completes.
train_model_task = SageMakerTrainingOperator(
    task_id='model_training',
    dag=dag,
    config=train_config,
    wait_for_completion=True,
    check_interval=30,
)

# Start the SageMaker hyperparameter tuning job; the task waits until it
# completes.
tune_model_task = SageMakerTuningOperator(
    task_id='model_tuning',
    dag=dag,
    config=tuner_config,
    wait_for_completion=True,
    check_interval=30,
)

# Start the SageMaker batch transform job; the task waits until it completes.
# NOTE(review): ONE_SUCCESS presumably lets this task run after whichever of
# the training/tuning tasks was actually executed -- confirm against the
# DAG wiring, which is not visible here.
batch_transform_task = SageMakerTransformOperator(
    task_id='predicting',
    dag=dag,
    config=transform_config,
    wait_for_completion=True,
    check_interval=30,
    trigger_rule=TriggerRule.ONE_SUCCESS,
)

# Final task of the pipeline, registered under the id 'cleaning_up'.
cleanup_task = DummyOperator(
    task_id='cleaning_up',
    dag=dag,
)