Example #1
    def setUp(self):
        args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}

        self.dag = DAG('test_dag_id', default_args=args)
        self.sensor = SQSSensor(task_id='test_task',
                                dag=self.dag,
                                sqs_queue='test',
                                aws_conn_id='aws_default')

        self.mock_context = MagicMock()
        self.sqs_hook = SQSHook()
Example #2
def filter_subdag():
    """ Subdag to contain parallel pipeline of filtering. """
    result = DAG(
        dag_id="k8s_wagl_nrt_filter.filter_subdag",
        default_args=default_args,
        concurrency=NUM_PARALLEL_PIPELINE,
        schedule_interval=None,
    )

    with result:
        for index in range(NUM_PARALLEL_PIPELINE):
            SENSOR = SQSSensor(
                task_id=f"filter_scene_queue_sensor_{index}",
                sqs_queue=FILTER_SCENE_QUEUE,
                aws_conn_id=AWS_CONN_ID,
                max_messages=NUM_MESSAGES_TO_POLL,
            )

            FILTER = PythonOperator(
                task_id=f"filter_scenes_{index}",
                python_callable=filter_scenes,
                op_kwargs={"index": index},
                provide_context=True,
            )

            SENSOR >> FILTER

    return result
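Note: the dag_id "k8s_wagl_nrt_filter.filter_subdag" follows Airflow's parent.task_id naming convention for subdags, so this factory is presumably attached to a parent DAG through a SubDagOperator. A minimal sketch of that wiring, in which the parent dag_id and the task_id are assumptions inferred from the naming convention rather than taken from the source repository:

# Sketch only: attaching filter_subdag() to a hypothetical parent DAG.
from airflow import DAG
from airflow.operators.subdag_operator import SubDagOperator

parent = DAG(
    dag_id="k8s_wagl_nrt_filter",    # assumed parent dag_id
    default_args=default_args,
    schedule_interval=None,
)

with parent:
    SubDagOperator(
        task_id="filter_subdag",     # must match the suffix of the subdag's dag_id
        subdag=filter_subdag(),
    )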
Example #3
    def setUp(self):
        configuration.load_test_config()

        args = {
            'owner': 'airflow',
            'start_date': DEFAULT_DATE
        }

        self.dag = DAG('test_dag_id', default_args=args)
        self.sensor = SQSSensor(
            task_id='test_task',
            dag=self.dag,
            sqs_queue='test',
            aws_conn_id='aws_default'
        )

        self.mock_context = MagicMock()
        self.sqs_hook = SQSHook()
Example #4
class TestSQSSensor(unittest.TestCase):
    def setUp(self):
        configuration.load_test_config()

        args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}

        self.dag = DAG('test_dag_id', default_args=args)
        self.sensor = SQSSensor(task_id='test_task',
                                dag=self.dag,
                                sqs_queue='test',
                                aws_conn_id='aws_default')

        self.mock_context = MagicMock()
        self.sqs_hook = SQSHook()

    @mock_sqs
    def test_poke_success(self):
        self.sqs_hook.create_queue('test')
        self.sqs_hook.send_message(queue_url='test', message_body='hello')

        result = self.sensor.poke(self.mock_context)
        self.assertTrue(result)

        self.assertTrue(
            "'Body': 'hello'" in str(self.mock_context['ti'].method_calls),
            "context call should contain message hello")

    @mock_sqs
    def test_poke_no_message_failed(self):
        self.sqs_hook.create_queue('test')
        result = self.sensor.poke(self.mock_context)
        self.assertFalse(result)

        context_calls = []

        self.assertTrue(self.mock_context['ti'].method_calls == context_calls,
                        "context calls should be empty")

    @patch('airflow.contrib.sensors.aws_sqs_sensor.SQSHook')
    def test_poke_delete_raise_airflow_exception(self, mock_sqs_hook):
        message = {
            'Messages': [{
                'MessageId': 'c585e508-2ea0-44c7-bf3e-d1ba0cb87834',
                'ReceiptHandle': 'mockHandle',
                'MD5OfBody': 'e5a9d8684a8edfed460b8d42fd28842f',
                'Body': 'h21'
            }],
            'ResponseMetadata': {
                'RequestId': '56cbf4aa-f4ef-5518-9574-a04e0a5f1411',
                'HTTPStatusCode': 200,
                'HTTPHeaders': {
                    'x-amzn-requestid': '56cbf4aa-f4ef-5518-9574-a04e0a5f1411',
                    'date': 'Mon, 18 Feb 2019 18:41:52 GMT',
                    'content-type': 'text/xml',
                    'content-length': '830'
                },
                'RetryAttempts': 0
            }
        }
        mock_sqs_hook().get_conn().receive_message.return_value = message
        mock_sqs_hook().get_conn().delete_message_batch.return_value = \
            {'Failed': [{'Id': '22f67273-4dbc-4c19-83b5-aee71bfeb832'}]}

        with self.assertRaises(AirflowException) as context:
            self.sensor.poke(self.mock_context)

        self.assertTrue(
            'Delete SQS Messages failed' in context.exception.args[0])

    @patch('airflow.contrib.sensors.aws_sqs_sensor.SQSHook')
    def test_poke_receive_raise_exception(self, mock_sqs_hook):
        mock_sqs_hook().get_conn().receive_message.side_effect = Exception(
            'test exception')

        with self.assertRaises(Exception) as context:
            self.sensor.poke(self.mock_context)

        self.assertTrue('test exception' in context.exception.args[0])
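For readers of these tests: the contrib SQSSensor's poke() receives up to max_messages from the queue, deletes whatever it got in a batch, and pushes the raw response to XCom under the key 'messages', which is why the assertions above inspect context['ti'].method_calls and expect the 'Delete SQS Messages failed' error. A simplified sketch of that flow (not the actual implementation; see airflow.contrib.sensors.aws_sqs_sensor):

# Simplified sketch of the poke() behaviour the tests above exercise.
from airflow.exceptions import AirflowException

def poke_sketch(sqs_conn, sqs_queue, max_messages, context):
    response = sqs_conn.receive_message(QueueUrl=sqs_queue,
                                        MaxNumberOfMessages=max_messages,
                                        WaitTimeSeconds=1)

    if 'Messages' not in response or not response['Messages']:
        return False  # nothing received -> sensor keeps waiting

    entries = [{'Id': m['MessageId'], 'ReceiptHandle': m['ReceiptHandle']}
               for m in response['Messages']]
    result = sqs_conn.delete_message_batch(QueueUrl=sqs_queue, Entries=entries)

    if 'Successful' in result:
        # the pushed response is what the success test finds in method_calls
        context['ti'].xcom_push(key='messages', value=response)
        return True
    raise AirflowException('Delete SQS Messages failed ' + str(response['Messages']))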
Example #5
    'owner': 'airflow',
    'depends_on_past': False,
    'retries': 10,
    'retry_delay': timedelta(minutes=2),
    'start_date': days_ago(1),
}

sqs_queue_url = Config.sqs_result_queue_url

with DAG(
    'ergo_job_collector',
    default_args=default_args,
    is_paused_upon_creation=False,
    schedule_interval=timedelta(seconds=10),
    catchup=False,
    max_active_runs=1
) as dag:
    sqs_collector = SQSSensor(
        task_id=TASK_ID_SQS_COLLECTOR,
        sqs_queue=sqs_queue_url,
        max_messages=10,
        wait_time_seconds=10
    )

    result_transformer = JobResultFromMessagesOperator(
        task_id='process_job_result',
        sqs_sensor_task_id=TASK_ID_SQS_COLLECTOR
    )

sqs_collector >> result_transformer
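The SQSSensor pushes the received batch to XCom under the key 'messages', keyed by its own task_id; JobResultFromMessagesOperator is a project-specific operator that reads that batch via sqs_sensor_task_id. If such an operator is not available, the same lookup can be done by hand; the task name and callable below are illustrative and not part of the original project:

# Illustrative only: pulling the messages that the sensor pushed to XCom.
from airflow.operators.python_operator import PythonOperator

def process_job_results(**context):
    response = context['ti'].xcom_pull(task_ids=TASK_ID_SQS_COLLECTOR,
                                       key='messages')
    for message in (response or {}).get('Messages', []):
        print(message['Body'])  # handle each SQS message body here

manual_transformer = PythonOperator(
    task_id='process_job_result_manual',   # hypothetical task name
    python_callable=process_job_results,
    provide_context=True,
    dag=dag,
)
sqs_collector >> manual_transformer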
Example #6
    default_args=default_args,
    description="DEA Sentinel-2 NRT processing",
    concurrency=MAX_ACTIVE_RUNS * NUM_PARALLEL_PIPELINE,
    max_active_runs=MAX_ACTIVE_RUNS,
    catchup=False,
    params={},
    schedule_interval=timedelta(minutes=5),
    tags=["k8s", "dea", "psc", "wagl", "nrt"],
)

with pipeline:
    for index in range(NUM_PARALLEL_PIPELINE):
        SENSOR = SQSSensor(
            task_id=f"process_scene_queue_sensor_{index}",
            sqs_queue=PROCESS_SCENE_QUEUE,
            aws_conn_id=AWS_CONN_ID,
            max_messages=NUM_MESSAGES_TO_POLL,
            retries=0,
            execution_timeout=timedelta(minutes=1),
        )

        CMD = PythonOperator(
            task_id=f"copy_cmd_{index}",
            python_callable=copy_cmd,
            op_kwargs={"index": index},
            provide_context=True,
        )

        COPY = KubernetesPodOperator(
            namespace="processing",
            name="dea-s2-wagl-nrt-copy-scene",
            task_id=f"dea-s2-wagl-nrt-copy-scene-{index}",
Example #7
class TestSQSSensor(unittest.TestCase):

    def setUp(self):
        configuration.load_test_config()

        args = {
            'owner': 'airflow',
            'start_date': DEFAULT_DATE
        }

        self.dag = DAG('test_dag_id', default_args=args)
        self.sensor = SQSSensor(
            task_id='test_task',
            dag=self.dag,
            sqs_queue='test',
            aws_conn_id='aws_default'
        )

        self.mock_context = MagicMock()
        self.sqs_hook = SQSHook()

    @mock_sqs
    def test_poke_success(self):
        self.sqs_hook.create_queue('test')
        self.sqs_hook.send_message(queue_url='test', message_body='hello')

        result = self.sensor.poke(self.mock_context)
        self.assertTrue(result)

        self.assertTrue("'Body': 'hello'" in str(self.mock_context['ti'].method_calls),
                        "context call should contain message hello")

    @mock_sqs
    def test_poke_no_message_failed(self):
        self.sqs_hook.create_queue('test')
        result = self.sensor.poke(self.mock_context)
        self.assertFalse(result)

        context_calls = []

        self.assertTrue(self.mock_context['ti'].method_calls == context_calls, "context calls should be empty")

    @patch('airflow.contrib.sensors.aws_sqs_sensor.SQSHook')
    def test_poke_delete_raise_airflow_exception(self, mock_sqs_hook):
        message = {'Messages': [{'MessageId': 'c585e508-2ea0-44c7-bf3e-d1ba0cb87834',
                                 'ReceiptHandle': 'mockHandle',
                                 'MD5OfBody': 'e5a9d8684a8edfed460b8d42fd28842f',
                                 'Body': 'h21'}],
                   'ResponseMetadata': {'RequestId': '56cbf4aa-f4ef-5518-9574-a04e0a5f1411',
                                        'HTTPStatusCode': 200,
                                        'HTTPHeaders': {
                                            'x-amzn-requestid': '56cbf4aa-f4ef-5518-9574-a04e0a5f1411',
                                            'date': 'Mon, 18 Feb 2019 18:41:52 GMT',
                                            'content-type': 'text/xml', 'content-length': '830'},
                                        'RetryAttempts': 0}}
        mock_sqs_hook().get_conn().receive_message.return_value = message
        mock_sqs_hook().get_conn().delete_message_batch.return_value = \
            {'Failed': [{'Id': '22f67273-4dbc-4c19-83b5-aee71bfeb832'}]}

        with self.assertRaises(AirflowException) as context:
            self.sensor.poke(self.mock_context)

        self.assertTrue('Delete SQS Messages failed' in context.exception.args[0])

    @patch('airflow.contrib.sensors.aws_sqs_sensor.SQSHook')
    def test_poke_receive_raise_exception(self, mock_sqs_hook):
        mock_sqs_hook().get_conn().receive_message.side_effect = Exception('test exception')

        with self.assertRaises(Exception) as context:
            self.sensor.poke(self.mock_context)

        self.assertTrue('test exception' in context.exception.args[0])