def finish_up(**context):
    """ Delete the SQS message to mark completion, broadcast to SNS. """
    task_instance = context["task_instance"]
    index = context["index"]

    message = task_instance.xcom_pull(task_ids=f"receive_task_{index}",
                                      key="message")
    sqs = get_sqs()

    _LOG.info("deleting %s", message["ReceiptHandle"])
    sqs.delete_message(QueueUrl=PROCESS_SCENE_QUEUE,
                       ReceiptHandle=message["ReceiptHandle"])

    msg = task_instance.xcom_pull(task_ids=f"dea-s2-wagl-nrt-{index}",
                                  key="return_value")

    if msg == {}:
        _LOG.info("dataset already existed, did not get processed by this DAG")
        return

    dataset_location = msg["dataset"]
    parsed = urlparse(dataset_location)
    _LOG.info("dataset location: %s", dataset_location)

    s3 = get_s3()
    response = s3.get_object(Bucket=parsed.netloc, Key=parsed.path.lstrip("/"))
    body = json.dumps(yaml.safe_load(response["Body"]), indent=2)

    _LOG.info("publishing to SNS: %s", body)
    sns_hook = AwsSnsHook(aws_conn_id=AWS_CONN_ID)
    sns_hook.publish_to_target(PUBLISH_S2_NRT_SNS, body)
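
For orientation, a minimal sketch of how finish_up might be wired into a DAG as a PythonOperator; the task id, the index variable, and the dag object are assumptions, not taken from the snippet above:

from airflow.operators.python_operator import PythonOperator

finish_up_task = PythonOperator(
    task_id=f"finish_up_{index}",   # assumed naming, mirrors receive_task_{index}
    python_callable=finish_up,
    op_kwargs={"index": index},     # merged into **context by provide_context
    provide_context=True,           # exposes task_instance etc. via **context
    dag=dag,                        # assumed DAG object
)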
def sns_publish(key, bucket_name):
    """
    Read the JSON message payload for an SNS topic from S3.
    """
    # get_key() is an S3Hook method, not an AwsSnsHook method, so the
    # original snippet instantiated the wrong hook here.
    hook = S3Hook()
    content_object = hook.get_key(key=key, bucket_name=bucket_name)
    file_content = content_object.get()['Body'].read().decode('utf-8')
    return json.loads(file_content)
Example #3
    def test_publish_to_target_plain(self):
        hook = AwsSnsHook(aws_conn_id='aws_default')

        message = "Hello world"
        topic_name = "test-topic"
        target = hook.get_conn().create_topic(Name=topic_name).get('TopicArn')

        response = hook.publish_to_target(target, message)

        assert 'MessageId' in response
def sns_broadcast(**context):
    task_instance = context["task_instance"]
    index = context["index"]
    msg = task_instance.xcom_pull(task_ids=f"dea-s2-wagl-nrt-{index}",
                                  key="return_value")

    msg_str = json.dumps(msg)

    sns_hook = AwsSnsHook(aws_conn_id=AWS_CONN_ID)
    sns_hook.publish_to_target(PUBLISH_S2_NRT_SNS, msg_str)
    def test_publish_to_target(self):
        hook = AwsSnsHook(aws_conn_id='aws_default')

        message = "Hello world"
        topic_name = "test-topic"
        target = hook.get_conn().create_topic(Name=topic_name).get('TopicArn')

        response = hook.publish_to_target(target, message)

        self.assertIn('MessageId', response)
    def execute(self, context):
        sns = AwsSnsHook(aws_conn_id=self.aws_conn_id)

        self.log.info(
            'Sending SNS notification to %s using %s:\n%s',
            self.target_arn,
            self.aws_conn_id,
            self.message,
        )

        return sns.publish_to_target(
            target_arn=self.target_arn,
            message=self.message
        )
    def execute(self, context):
        sns = AwsSnsHook(aws_conn_id=self.aws_conn_id)

        self.log.info(
            'Sending SNS notification to %s using %s:\nsubject=%s\nattributes=%s\nmessage=%s',
            self.target_arn,
            self.aws_conn_id,
            self.subject,
            self.message_attributes,
            self.message,
        )

        return sns.publish_to_target(
            target_arn=self.target_arn,
            message=self.message,
            subject=self.subject,
            message_attributes=self.message_attributes,
        )
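
Both execute() methods above assume an operator that stores these attributes on self; a minimal sketch of such a class, modeled on Airflow's SnsPublishOperator (the constructor signature is an assumption inferred from the attributes used, not the exact upstream source):

from airflow.models import BaseOperator

class SnsPublishOperator(BaseOperator):
    template_fields = ('target_arn', 'message', 'subject')

    def __init__(self, target_arn, message, subject=None,
                 message_attributes=None, aws_conn_id='aws_default',
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.target_arn = target_arn                # SNS topic or endpoint ARN
        self.message = message                      # message body to publish
        self.subject = subject                      # optional subject line
        self.message_attributes = message_attributes
        self.aws_conn_id = aws_conn_id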
def _load_new_files(*, aws_conn_id: str, state_table_name: str,
                    bucket_name: str, prefix: str, files_type_id: str,
                    data_load_trigger_sns_topic_arn: str,
                    data_load_trigger_client_id: str, task: BaseOperator, **_):
    """Python callable for the `LoadNewFilesOperator`."""

    log = task.log

    pipeline_state = PipelineStateHook(state_table_name,
                                       aws_conn_id=aws_conn_id)
    bucket = S3Hook(aws_conn_id).get_bucket(bucket_name)
    sns = AwsSnsHook(aws_conn_id)

    # list files in the bucket and find new ones
    new_files = []
    for file_obj in bucket.objects.filter(Prefix=prefix):
        state_key = _file_state_key(files_type_id, file_obj.key)
        state_value = pipeline_state.get_state(state_key)
        # a missing state entry (None) never equals the current ETag, so both
        # new and changed files are picked up here
        if state_value != file_obj.e_tag:
            new_files.append(file_obj)
    log.info("new files: %s", new_files)

    # check if found any new files
    if not new_files:
        raise AirflowSkipException("no new files found")

    # process each new file
    for file_obj in new_files:

        # trigger data loader
        file_s3_url = f's3://{bucket.name}/{file_obj.key}'
        log.info("triggering data load for %s", file_s3_url)
        sns.publish_to_target(target_arn=data_load_trigger_sns_topic_arn,
                              message=json.dumps({
                                  'client_id': data_load_trigger_client_id,
                                  'data_url': file_s3_url
                              }))

        # save state
        state_key = _file_state_key(files_type_id, file_obj.key)
        pipeline_state.save_state(state_key, file_obj.e_tag)
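
The LoadNewFilesOperator named in the docstring is not shown here; as a stand-in, a hedged sketch of invoking the callable through a plain PythonOperator, with made-up placeholder values for every parameter:

from airflow.operators.python_operator import PythonOperator

load_new_files = PythonOperator(
    task_id='load_new_files',
    python_callable=_load_new_files,
    op_kwargs=dict(
        aws_conn_id='aws_default',
        state_table_name='pipeline-state',   # hypothetical state table
        bucket_name='example-bucket',        # hypothetical bucket
        prefix='incoming/',
        files_type_id='example-files',
        data_load_trigger_sns_topic_arn='arn:aws:sns:us-east-1:123456789012:data-load-trigger',
        data_load_trigger_client_id='example-client',
    ),
    provide_context=True,   # supplies the task object the callable expects
    dag=dag,                # assumed DAG object
)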
Example #9
    def test_publish_to_target_with_attributes(self):
        hook = AwsSnsHook(aws_conn_id='aws_default')

        message = "Hello world"
        topic_name = "test-topic"
        target = hook.get_conn().create_topic(Name=topic_name).get('TopicArn')

        response = hook.publish_to_target(target,
                                          message,
                                          message_attributes={
                                              'test-string': 'string-value',
                                              'test-number': 123456,
                                              'test-array': ['first', 'second', 'third'],
                                              'test-binary': b'binary-value',
                                          })

        assert 'MessageId' in response
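
For reference, the hook converts plain Python values into SNS message attributes by type, so the attributes in the test above reach the underlying boto3 publish call in roughly this shape (a sketch of the mapping, not the hook's exact source):

message_attributes = {
    'test-string': {'DataType': 'String', 'StringValue': 'string-value'},
    'test-number': {'DataType': 'Number', 'StringValue': '123456'},
    'test-array': {'DataType': 'String.Array', 'StringValue': '["first", "second", "third"]'},
    'test-binary': {'DataType': 'Binary', 'BinaryValue': b'binary-value'},
}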
def _save_campaign_id_mappings(
        *,
        aws_conn_id: str,
        mysql_conn_id: str,
        bucket_name: str,
        bucket_data_prefix: str,
        partner: str,
        notifications_topic_arn: str,
        ds_nodash: str,
        task: BaseOperator,
        task_instance: TaskInstance,
        **_
):
    """Python callable for the operator that saves campaign id mappings in the firewall DB."""

    log = task.log

    # read new mappings from S3
    extracted_mappings = _load_extracted_mappings(
        aws_conn_id=aws_conn_id,
        bucket_name=bucket_name,
        bucket_data_prefix=bucket_data_prefix,
        partner=partner,
        ds_nodash=ds_nodash,
        log=log,
    )
    if not extracted_mappings:
        return "no mappings have been extracted, bailing out"

    # connect to firewall database and update it
    log.info("connecting to firewall database")
    mysql = MySqlHook(mysql_conn_id=mysql_conn_id)
    with closing(mysql.get_conn()) as conn:
        mysql.set_autocommit(conn, False)
        with closing(conn.cursor()) as cur:
            (invalid_ias_adv_entity_ids, overriding_campaign_id_mappings) = _update_firewall_db(
                cur=cur,
                log=log,
                extracted_mappings=extracted_mappings
            )
        log.info("committing transaction")
        conn.commit()

    # send notification about any exceptions
    if invalid_ias_adv_entity_ids or overriding_campaign_id_mappings:
        log.info("sending mapping exceptions notification")
        invalid_ias_adv_entity_ids_msg = ', '.join(
            str(eid) for eid in invalid_ias_adv_entity_ids
        ) if invalid_ias_adv_entity_ids else 'None'
        overriding_campaign_id_mappings_msg = '\n'.join(
            '{:<19} | {:<24} | {:<19}'.format(*row) for row in overriding_campaign_id_mappings
        ) if overriding_campaign_id_mappings else 'None'
        sns = AwsSnsHook(aws_conn_id=aws_conn_id)
        sns.publish_to_target(
            target_arn=notifications_topic_arn,
            subject=f'Campaign ID mapping exceptions ({partner})',
            message=(
                f"Encountered campaign ID mapping exceptions:\n"
                f"\n"
                f"\nDAG:             {task_instance.dag_id}"
                f"\nTask:            {task_instance.task_id}"
                f"\nExecution Date:  {task_instance.execution_date}"
                f"\nHost:            {task_instance.hostname}"
                f"\n"
                f"\nUnknown IAS adv entity IDs:"
                f"\n"
                f"\n{invalid_ias_adv_entity_ids_msg}"
                f"\n"
                f"\nAttempts to change existing mappings:"
                f"\n"
                f"\npartner campaign ID | existing IAS campaign ID | new IAS campaign ID"
                f"\n{overriding_campaign_id_mappings_msg}"
            )
        )

    # done
    return "campaign id mappings have been updated"
Example #11
    def test_get_conn_returns_a_boto3_connection(self):
        hook = AwsSnsHook(aws_conn_id='aws_default')
        self.assertIsNotNone(hook.get_conn())