Example #1
0
class TestSQSPublishOperator(unittest.TestCase):
    """Tests for ``SQSPublishOperator`` against a ``mock_sqs``-backed queue."""

    def setUp(self):
        """Build the DAG, the operator under test, a mock context and a hook."""
        args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}

        self.dag = DAG('test_dag_id', default_args=args)
        self.operator = SQSPublishOperator(task_id='test_task',
                                           dag=self.dag,
                                           sqs_queue='test',
                                           message_content='hello',
                                           aws_conn_id='aws_default')

        self.mock_context = MagicMock()
        self.sqs_hook = SQSHook()

    @mock_sqs
    def test_execute_success(self):
        """Publishing 'hello' puts exactly one matching message on the queue."""
        self.sqs_hook.create_queue('test')

        result = self.operator.execute(self.mock_context)
        self.assertIn('MD5OfMessageBody', result)
        self.assertIn('MessageId', result)

        message = self.sqs_hook.get_conn().receive_message(QueueUrl='test')

        self.assertEqual(len(message['Messages']), 1)
        self.assertEqual(message['Messages'][0]['MessageId'],
                         result['MessageId'])
        self.assertEqual(message['Messages'][0]['Body'], 'hello')

        # The operator must not call anything on the task-instance mock.
        self.assertEqual(self.mock_context['ti'].method_calls, [],
                         "context call  should be same")
def dag_result(**context):
    """Dead-letter any pending message and signal failure; otherwise succeed.

    ``fetch_sqs_message`` raising ``KeyError`` means there is nothing to
    report, so the function returns quietly. Otherwise the decoded message
    is JSON-serialised, sent to ``DEADLETTER_SCENE_QUEUE`` and a
    ``ValueError`` carrying that payload is raised.

    :param context: task context forwarded to ``fetch_sqs_message``.
    :raises ValueError: whenever a message was fetched.
    """
    try:
        pending = fetch_sqs_message(context)
    except KeyError:
        # Nothing was fetched: treat as success.
        return None

    deadletter_hook = SQSHook(aws_conn_id=AWS_CONN_ID)
    payload = json.dumps(decode(pending))
    deadletter_hook.send_message(DEADLETTER_SCENE_QUEUE, payload)
    raise ValueError(f"processing failed for {payload}")
    def execute(self, context, session=None):
        """Push the collected tasks onto an SQS queue with ``send_message_batch``.

        The task list and the queue URL are both pulled from XCom values pushed
        by ``self.task_id_collector`` in the current DAG. Every task becomes one
        batch entry whose ``Id`` and ``MessageDeduplicationId`` are the task's
        ``id`` and whose ``MessageGroupId`` is its ``task_id``.

        Entries SQS reports as successful are moved to ``State.QUEUED`` and an
        ``ErgoJob`` built from the SQS ``MessageId`` and the task id is added to
        the session; entries reported as failed are moved to
        ``State.UP_FOR_RESCHEDULE``. The session is committed before returning.
        If there are no tasks, nothing is sent and nothing is committed.

        :param context: task context; ``context['ti'].dag_id`` scopes the XCom pulls.
        :param session: database session (``add_all``/``commit`` are used).
            NOTE(review): defaults to ``None`` yet is used unconditionally --
            presumably injected by a session-providing decorator outside this
            view; confirm.
        :raises Exception: whatever ``send_message_batch`` raised is re-raised
            after every task has been moved to ``State.UP_FOR_RESCHEDULE``.
        """
        dag_id = context['ti'].dag_id
        tasks = self.xcom_pull(context, self.task_id_collector, dag_id,
                               self.xcom_tasks_key)
        if not tasks:
            # SQS rejects a batch with no entries (EmptyBatchRequest); without
            # this guard an empty collector result would fail the task.
            self.log.info('No tasks to push; skipping SQS send.')
            return
        queue_url = self.xcom_pull(context, self.task_id_collector, dag_id,
                                   self.xcom_sqs_queue_url_key)
        sqs_client = SQSHook(aws_conn_id=self.aws_conn_id).get_conn()
        self.log.info('Trying to push %d messages on queue: %s', len(tasks),
                      queue_url)
        # NOTE(review): SQS accepts at most 10 entries per batch; larger task
        # lists will be rejected by send_message_batch -- consider chunking.
        entries = [{
            'Id': str(task.id),
            'MessageBody': task.request_data,
            'MessageGroupId': task.task_id,
            'MessageDeduplicationId': str(task.id)
        } for task in tasks]
        try:
            response = sqs_client.send_message_batch(QueueUrl=queue_url,
                                                     Entries=entries)
        except Exception as e:
            self.log.exception(
                'SQS Send message API failed for "%s" queue!\nRequest Entries: %s',
                queue_url,
                str(entries),
                exc_info=e)

            # Nothing is known to have been delivered: hand every task back
            # for rescheduling before propagating the error.
            self.log.info("Setting the tasks up for reschedule!")
            self._set_task_states([task.id for task in tasks],
                                  State.UP_FOR_RESCHEDULE,
                                  session=session)
            session.commit()

            raise

        success_resps = response.get('Successful', [])
        failed_resps = response.get('Failed', [])
        if success_resps:
            self.log.info('Successfully pushed %d messages!',
                          len(success_resps))
            # Batch entry ``Id`` was set to the task id, so map it back.
            self._set_task_states([int(resp['Id']) for resp in success_resps],
                                  State.QUEUED,
                                  session=session)
            jobs = [
                ErgoJob(resp['MessageId'], int(resp['Id']))
                for resp in success_resps
            ]
            session.add_all(jobs)
        if failed_resps:
            self.log.error('Failed to push %d messages!', len(failed_resps))
            self._set_task_states([int(resp['Id']) for resp in failed_resps],
                                  State.UP_FOR_RESCHEDULE,
                                  session=session)
        session.commit()
Example #5
0
    def setUp(self):
        """Load test config and build the publish operator fixture.

        Creates ``self.dag``, an ``SQSPublishOperator`` (``self.operator``)
        that sends ``'hello'`` to queue ``'test'``, a ``MagicMock`` context and
        a default ``SQSHook``.
        """
        configuration.load_test_config()

        default_args = dict(owner='airflow', start_date=DEFAULT_DATE)
        self.dag = DAG('test_dag_id', default_args=default_args)

        operator_kwargs = dict(
            task_id='test_task',
            dag=self.dag,
            sqs_queue='test',
            message_content='hello',
            aws_conn_id='aws_default',
        )
        self.operator = SQSPublishOperator(**operator_kwargs)

        self.mock_context = MagicMock()
        self.sqs_hook = SQSHook()
def filter_scenes(**context):
    """Forward the Australian scenes seen by one queue sensor to the processing queue.

    Reads the ``messages`` XCom pushed by ``filter_scene_queue_sensor_<index>``,
    keeps the entries of its ``"Messages"`` list whose ``region_code`` appears
    in ``australian_region_codes()``, and sends each kept message (decoded,
    then serialised with ``json.dumps``) to ``PROCESS_SCENE_QUEUE``.

    :param context: task context; must provide ``task_instance`` and ``index``.
    """
    ti = context["task_instance"]
    sensor_task_id = f"filter_scene_queue_sensor_{context['index']}"
    pulled = ti.xcom_pull(task_ids=sensor_task_id, key="messages")
    candidates = pulled["Messages"]

    def _is_australian(msg):
        return region_code(msg) in australian_region_codes()

    selected = list(filter(_is_australian, candidates))

    hook = SQSHook(aws_conn_id=AWS_CONN_ID)
    for msg in selected:
        hook.send_message(PROCESS_SCENE_QUEUE, json.dumps(decode(msg)))
    def poke(self, context):
        """
        Poll the subscribed queue once and publish any received batch to XCom.

        Received messages are removed from the queue with
        ``delete_message_batch``; the raw ``receive_message`` response is then
        pushed to XCom under the key ``messages``.

        :param context: the context object
        :type context: dict
        :return: ``True`` when messages were received and deleted, ``False``
            when the queue returned nothing
        :rtype: bool
        :raises AirflowException: if the delete response has no ``Successful`` key
        """
        conn = SQSHook(aws_conn_id=self.aws_conn_id).get_conn()

        self.log.info('SQSSensor checking for message on queue: %s',
                      self.sqs_queue)

        response = conn.receive_message(
            QueueUrl=self.sqs_queue,
            MaxNumberOfMessages=self.max_messages,
            WaitTimeSeconds=self.wait_time_seconds)

        self.log.info("reveived message %s", str(response))

        if 'Messages' not in response or len(response['Messages']) == 0:
            return False

        receipts = []
        for msg in response['Messages']:
            receipts.append({'Id': msg['MessageId'],
                             'ReceiptHandle': msg['ReceiptHandle']})

        deletion = conn.delete_message_batch(QueueUrl=self.sqs_queue,
                                             Entries=receipts)

        if 'Successful' not in deletion:
            raise AirflowException('Delete SQS Messages failed ' +
                                   str(deletion) + ' for messages ' +
                                   str(response))

        context['ti'].xcom_push(key='messages', value=response)
        return True
    def execute(self, context):
        """
        Send ``self.message_content`` to ``self.sqs_queue`` through an SQSHook.

        :param context: the context object (not used by this method)
        :type context: dict
        :return: the ``send_message`` response (e.g. ``MessageId`` and
            ``MD5OfMessageBody``); see :py:meth:`botocore.client.SQS.send_message`
        :rtype: dict
        """
        hook = SQSHook(aws_conn_id=self.aws_conn_id)
        send_kwargs = {
            'queue_url': self.sqs_queue,
            'message_body': self.message_content,
            'delay_seconds': self.delay_seconds,
            'message_attributes': self.message_attributes,
        }
        response = hook.send_message(**send_kwargs)
        self.log.info('result is send_message is %s', response)
        return response
    def execute(self, context):
        """
        Publish the configured message body to the configured SQS queue.

        Delay and message attributes are taken from ``self.delay_seconds`` and
        ``self.message_attributes``.

        :param context: the context object (unused here)
        :type context: dict
        :return: the dict returned by the hook's ``send_message``; see
            :py:meth:`botocore.client.SQS.send_message`
        :rtype: dict
        """
        sqs = SQSHook(aws_conn_id=self.aws_conn_id)
        outcome = sqs.send_message(
            queue_url=self.sqs_queue,
            message_body=self.message_content,
            delay_seconds=self.delay_seconds,
            message_attributes=self.message_attributes,
        )
        self.log.info('result is send_message is %s', outcome)
        return outcome
Example #10
0
 def execute(self, context, session=None):
     """Push the collected tasks to ``self.sqs_queue_url`` via ``send_message_batch``.

     The task list is pulled from the XCom value pushed by
     ``self.task_id_collector`` in the current DAG. Each task becomes one batch
     entry keyed by its ``id``. Tasks SQS accepted are set to ``State.QUEUED``
     and an ``ErgoJob`` built from the SQS ``MessageId`` and the task id is added
     to the session; tasks SQS rejected are set to ``State.UP_FOR_RETRY``. The
     session is committed before returning. With no tasks, nothing is sent.

     :param context: task context; ``context['ti'].dag_id`` scopes the XCom pull.
     :param session: database session used for queries and state changes.
         NOTE(review): defaults to ``None`` but is used unconditionally --
         presumably injected by a session-providing decorator outside this
         view; confirm.
     :raises Exception: any error from ``send_message_batch`` propagates as-is.
     """
     dag_id = context['ti'].dag_id
     tasks = self.xcom_pull(context, self.task_id_collector, dag_id,
                            self.xcom_tasks_key)
     if not tasks:
         # send_message_batch rejects an empty Entries list (EmptyBatchRequest).
         self.log.info('SqsTaskPusherOperator found no tasks to push')
         return
     sqs_client = SQSHook(aws_conn_id=self.aws_conn_id).get_conn()
     self.log.info(
         'SqsTaskPusherOperator trying to push %d messages on queue: %s',
         len(tasks), self.sqs_queue_url)
     # NOTE(review): SQS accepts at most 10 entries per batch; larger task
     # lists will be rejected by send_message_batch -- consider chunking.
     entries = [{
         'Id': str(task.id),
         'MessageBody': task.request_data,
         'MessageGroupId': task.task_id,
         'MessageDeduplicationId': str(task.id)
     } for task in tasks]
     response = sqs_client.send_message_batch(QueueUrl=self.sqs_queue_url,
                                              Entries=entries)
     success_resps = response.get('Successful', [])
     failed_resps = response.get('Failed', [])
     if success_resps:
         self.log.info(
             'SqsTaskPusherOperator successfully pushed %d messages!',
             len(success_resps))
         # Batch entry ``Id`` was set to the task id, so map it back.
         success_tasks = session.query(ErgoTask).filter(
             ErgoTask.id.in_([int(resp['Id']) for resp in success_resps]))
         for task in success_tasks:
             task.state = State.QUEUED
         jobs = [
             ErgoJob(resp['MessageId'], int(resp['Id']))
             for resp in success_resps
         ]
         session.add_all(jobs)
     if failed_resps:
         self.log.error('SqsTaskPusherOperator failed to push %d messages!',
                        len(failed_resps))
         failed_tasks = session.query(ErgoTask).filter(
             ErgoTask.id.in_([int(resp['Id']) for resp in failed_resps]))
         for task in failed_tasks:
             task.state = State.UP_FOR_RETRY
     session.commit()
    def poke(self, context):
        """
        Check the queue once; if messages arrived, delete them and push them to XCom.

        The full ``receive_message`` response is stored under the XCom key
        ``messages`` only after ``delete_message_batch`` reports ``Successful``.

        :param context: the context object
        :type context: dict
        :return: ``True`` if messages were received, ``False`` otherwise
        :rtype: bool
        :raises AirflowException: when the delete response lacks ``Successful``
        """
        client = SQSHook(aws_conn_id=self.aws_conn_id).get_conn()

        self.log.info('SQSSensor checking for message on queue: %s', self.sqs_queue)

        received = client.receive_message(QueueUrl=self.sqs_queue,
                                          MaxNumberOfMessages=self.max_messages,
                                          WaitTimeSeconds=self.wait_time_seconds)

        self.log.info("reveived message %s", str(received))

        has_messages = 'Messages' in received and len(received['Messages']) > 0
        if not has_messages:
            return False

        to_delete = [
            {'Id': item['MessageId'], 'ReceiptHandle': item['ReceiptHandle']}
            for item in received['Messages']
        ]
        outcome = client.delete_message_batch(QueueUrl=self.sqs_queue,
                                              Entries=to_delete)

        if 'Successful' in outcome:
            context['ti'].xcom_push(key='messages', value=received)
            return True

        raise AirflowException(
            'Delete SQS Messages failed ' + str(outcome) + ' for messages ' + str(received))
    def setUp(self):
        """Load test config and build the sensor fixture.

        Creates ``self.dag``, an ``SQSSensor`` (``self.sensor``) watching queue
        ``'test'``, a ``MagicMock`` context and a default ``SQSHook``.
        """
        configuration.load_test_config()

        self.dag = DAG('test_dag_id',
                       default_args={'owner': 'airflow',
                                     'start_date': DEFAULT_DATE})
        self.sensor = SQSSensor(task_id='test_task', dag=self.dag,
                                sqs_queue='test', aws_conn_id='aws_default')
        self.mock_context = MagicMock()
        self.sqs_hook = SQSHook()
Example #13
0
class TestSQSSensor(unittest.TestCase):
    """Tests for ``SQSSensor.poke`` using ``mock_sqs`` and a patched ``SQSHook``."""

    def setUp(self):
        """Build the DAG, the sensor under test, a mock context and a hook."""
        configuration.load_test_config()

        args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}

        self.dag = DAG('test_dag_id', default_args=args)
        self.sensor = SQSSensor(task_id='test_task',
                                dag=self.dag,
                                sqs_queue='test',
                                aws_conn_id='aws_default')

        self.mock_context = MagicMock()
        self.sqs_hook = SQSHook()

    @mock_sqs
    def test_poke_success(self):
        """A queued message makes poke() return True and reach the ti mock."""
        self.sqs_hook.create_queue('test')
        self.sqs_hook.send_message(queue_url='test', message_body='hello')

        result = self.sensor.poke(self.mock_context)
        self.assertTrue(result)

        self.assertIn("'Body': 'hello'",
                      str(self.mock_context['ti'].method_calls),
                      "context call should contain message hello")

    @mock_sqs
    def test_poke_no_messsage_failed(self):
        """An empty queue makes poke() return False without touching ti."""
        self.sqs_hook.create_queue('test')
        result = self.sensor.poke(self.mock_context)
        self.assertFalse(result)

        self.assertEqual(self.mock_context['ti'].method_calls, [],
                         "context call  should be same")

    @patch('airflow.contrib.sensors.aws_sqs_sensor.SQSHook')
    def test_poke_delete_raise_airflow_exception(self, mock_sqs_hook):
        """A delete response without 'Successful' raises AirflowException."""
        message = {
            'Messages': [{
                'MessageId': 'c585e508-2ea0-44c7-bf3e-d1ba0cb87834',
                'ReceiptHandle': 'mockHandle',
                'MD5OfBody': 'e5a9d8684a8edfed460b8d42fd28842f',
                'Body': 'h21'
            }],
            'ResponseMetadata': {
                'RequestId': '56cbf4aa-f4ef-5518-9574-a04e0a5f1411',
                'HTTPStatusCode': 200,
                'HTTPHeaders': {
                    'x-amzn-requestid': '56cbf4aa-f4ef-5518-9574-a04e0a5f1411',
                    'date': 'Mon, 18 Feb 2019 18:41:52 GMT',
                    'content-type': 'text/xml',
                    'mock_sqs_hook-length': '830'
                },
                'RetryAttempts': 0
            }
        }
        mock_sqs_hook().get_conn().receive_message.return_value = message
        mock_sqs_hook().get_conn().delete_message_batch.return_value = \
            {'Failed': [{'Id': '22f67273-4dbc-4c19-83b5-aee71bfeb832'}]}

        with self.assertRaises(AirflowException) as context:
            self.sensor.poke(self.mock_context)

        self.assertIn('Delete SQS Messages failed', context.exception.args[0])

    @patch('airflow.contrib.sensors.aws_sqs_sensor.SQSHook')
    def test_poke_receive_raise_exception(self, mock_sqs_hook):
        """Errors from receive_message propagate out of poke() unchanged."""
        mock_sqs_hook().get_conn().receive_message.side_effect = Exception(
            'test exception')

        with self.assertRaises(Exception) as context:
            self.sensor.poke(self.mock_context)

        self.assertIn('test exception', context.exception.args[0])
 def test_get_conn(self):
     """An SQSHook bound to ``aws_default`` must hand back a client object."""
     client = SQSHook(aws_conn_id='aws_default').get_conn()
     self.assertIsNotNone(client)
 def test_get_conn(self):
     """get_conn() on an ``aws_default`` SQSHook returns a non-None connection."""
     hook_under_test = SQSHook(aws_conn_id='aws_default')
     conn = hook_under_test.get_conn()
     self.assertIsNotNone(conn)
class TestSQSSensor(unittest.TestCase):
    """Tests for ``SQSSensor.poke`` using ``mock_sqs`` and a patched ``SQSHook``.

    NOTE(review): ``TestSQSSensor`` is defined more than once at module level;
    the later definition rebinds the name, so only one set of these tests is
    collected. Consider renaming or removing the duplicate.
    """

    def setUp(self):
        """Build the DAG, the sensor under test, a mock context and a hook."""
        configuration.load_test_config()

        args = {
            'owner': 'airflow',
            'start_date': DEFAULT_DATE
        }

        self.dag = DAG('test_dag_id', default_args=args)
        self.sensor = SQSSensor(
            task_id='test_task',
            dag=self.dag,
            sqs_queue='test',
            aws_conn_id='aws_default'
        )

        self.mock_context = MagicMock()
        self.sqs_hook = SQSHook()

    @mock_sqs
    def test_poke_success(self):
        """A queued message makes poke() return True and reach the ti mock."""
        self.sqs_hook.create_queue('test')
        self.sqs_hook.send_message(queue_url='test', message_body='hello')

        result = self.sensor.poke(self.mock_context)
        self.assertTrue(result)

        self.assertIn("'Body': 'hello'", str(self.mock_context['ti'].method_calls),
                      "context call should contain message hello")

    @mock_sqs
    def test_poke_no_messsage_failed(self):
        """An empty queue makes poke() return False without touching ti."""
        self.sqs_hook.create_queue('test')
        result = self.sensor.poke(self.mock_context)
        self.assertFalse(result)

        self.assertEqual(self.mock_context['ti'].method_calls, [], "context call  should be same")

    @patch('airflow.contrib.sensors.aws_sqs_sensor.SQSHook')
    def test_poke_delete_raise_airflow_exception(self, mock_sqs_hook):
        """A delete response without 'Successful' raises AirflowException."""
        message = {'Messages': [{'MessageId': 'c585e508-2ea0-44c7-bf3e-d1ba0cb87834',
                                 'ReceiptHandle': 'mockHandle',
                                 'MD5OfBody': 'e5a9d8684a8edfed460b8d42fd28842f',
                                 'Body': 'h21'}],
                   'ResponseMetadata': {'RequestId': '56cbf4aa-f4ef-5518-9574-a04e0a5f1411',
                                        'HTTPStatusCode': 200,
                                        'HTTPHeaders': {
                                            'x-amzn-requestid': '56cbf4aa-f4ef-5518-9574-a04e0a5f1411',
                                            'date': 'Mon, 18 Feb 2019 18:41:52 GMT',
                                            'content-type': 'text/xml', 'mock_sqs_hook-length': '830'},
                                        'RetryAttempts': 0}}
        mock_sqs_hook().get_conn().receive_message.return_value = message
        mock_sqs_hook().get_conn().delete_message_batch.return_value = \
            {'Failed': [{'Id': '22f67273-4dbc-4c19-83b5-aee71bfeb832'}]}

        with self.assertRaises(AirflowException) as context:
            self.sensor.poke(self.mock_context)

        self.assertIn('Delete SQS Messages failed', context.exception.args[0])

    @patch('airflow.contrib.sensors.aws_sqs_sensor.SQSHook')
    def test_poke_receive_raise_exception(self, mock_sqs_hook):
        """Errors from receive_message propagate out of poke() unchanged."""
        mock_sqs_hook().get_conn().receive_message.side_effect = Exception('test exception')

        with self.assertRaises(Exception) as context:
            self.sensor.poke(self.mock_context)

        self.assertIn('test exception', context.exception.args[0])