    def setUp(self):
        super(TestQueryRunner, self).setUp()
        self.query_runner = QueryRunner()
        self.matrix_infra_config.set(self.__class__.TEST_CONFIG)
        self.query_runner.matrix_infra_config = self.matrix_infra_config
        self.sqs_handler = SQSHandler()
        self.sqs.meta.client.purge_queue(QueueUrl="test_query_job_q_name")
        self.sqs.meta.client.purge_queue(
            QueueUrl="test_deadletter_query_job_q_name")
Example #2
    def __init__(self, request_id: str):
        Logging.set_correlation_id(logger, value=request_id)

        self.request_id = request_id
        self.request_tracker = RequestTracker(request_id)
        self.dynamo_handler = DynamoHandler()
        self.sqs_handler = SQSHandler()
        self.infra_config = MatrixInfraConfig()
        self.redshift_config = MatrixRedshiftConfig()
        self.query_results_bucket = os.environ['MATRIX_QUERY_RESULTS_BUCKET']
        self.s3_handler = S3Handler(os.environ['MATRIX_QUERY_BUCKET'])
Example #3
    def dss_notification():
        body = app.current_request.json_body
        bundle_uuid = body['match']['bundle_uuid']
        bundle_version = body['match']['bundle_version']
        subscription_id = body['subscription_id']
        event_type = body['event_type']

        config = MatrixInfraConfig()
        hmac_secret_key = config.dss_subscription_hmac_secret_key.encode()
        HTTPSignatureAuth.verify(
            requests.Request(url="http://host/dss/notification",
                             method=app.current_request.method,
                             headers=app.current_request.headers),
            key_resolver=lambda key_id, algorithm: hmac_secret_key)

        payload = {
            'bundle_uuid': bundle_uuid,
            'bundle_version': bundle_version,
            'event_type': event_type,
        }
        queue_url = config.notification_q_url
        SQSHandler().add_message_to_queue(queue_url, payload)

        return chalice.Response(
            status_code=requests.codes.ok,
            body=f"Received notification from subscription {subscription_id}: "
            f"{event_type} {bundle_uuid}.{bundle_version}")
Example #4
class TestQueryRunner(MatrixTestCaseUsingMockAWS):
    def setUp(self):
        super(TestQueryRunner, self).setUp()
        self.query_runner = QueryRunner()
        self.matrix_infra_config.set(self.__class__.TEST_CONFIG)
        self.query_runner.matrix_infra_config = self.matrix_infra_config
        self.sqs_handler = SQSHandler()
        self.sqs.meta.client.purge_queue(QueueUrl="test_query_job_q_name")
        self.sqs.meta.client.purge_queue(
            QueueUrl="test_deadletter_query_job_q_name")

    @mock.patch(
        "matrix.common.aws.s3_handler.S3Handler.load_content_from_obj_key")
    @mock.patch(
        "matrix.common.aws.sqs_handler.SQSHandler.receive_messages_from_queue")
    def test_run__with_no_messages_in_queue(self, mock_receive_messages,
                                            mock_load_obj):
        mock_receive_messages.return_value = None
        self.query_runner.run(max_loops=1)
        mock_receive_messages.assert_called_once_with(
            self.query_runner.query_job_q_url)
        mock_load_obj.assert_not_called()

    @mock.patch(
        "matrix.common.aws.batch_handler.BatchHandler.schedule_matrix_conversion"
    )
    @mock.patch(
        "matrix.common.request.request_tracker.RequestTracker.is_request_ready_for_conversion"
    )
    @mock.patch(
        "matrix.common.request.request_tracker.RequestTracker.complete_subtask_execution"
    )
    @mock.patch(
        "matrix.common.aws.redshift_handler.RedshiftHandler.transaction")
    @mock.patch(
        "matrix.common.aws.s3_handler.S3Handler.load_content_from_obj_key")
    def test_run__with_one_message_in_queue_and_not_ready_for_conversion(
            self, mock_load_obj, mock_transaction, mock_complete_subtask,
            mock_is_ready_for_conversion, mock_schedule_conversion):
        request_id = str(uuid.uuid4())
        payload = {'request_id': request_id, 's3_obj_key': "test_s3_obj_key"}
        self.sqs_handler.add_message_to_queue("test_query_job_q_name", payload)
        mock_is_ready_for_conversion.return_value = False

        self.query_runner.run(max_loops=1)

        mock_load_obj.assert_called_once_with("test_s3_obj_key")
        mock_transaction.assert_called()
        mock_complete_subtask.assert_called_once_with(Subtask.QUERY)
        mock_schedule_conversion.assert_not_called()

    @mock.patch(
        "matrix.common.request.request_tracker.RequestTracker.write_batch_job_id_to_db"
    )
    @mock.patch("matrix.common.request.request_tracker.RequestTracker.format")
    @mock.patch(
        "matrix.common.aws.batch_handler.BatchHandler.schedule_matrix_conversion"
    )
    @mock.patch(
        "matrix.common.request.request_tracker.RequestTracker.is_request_ready_for_conversion"
    )
    @mock.patch(
        "matrix.common.request.request_tracker.RequestTracker.complete_subtask_execution"
    )
    @mock.patch(
        "matrix.common.aws.redshift_handler.RedshiftHandler.transaction")
    @mock.patch(
        "matrix.common.aws.s3_handler.S3Handler.load_content_from_obj_key")
    def test_run__with_one_message_in_queue_and_ready_for_conversion(
            self, mock_load_obj, mock_transaction, mock_complete_subtask,
            mock_is_ready_for_conversion, mock_schedule_conversion,
            mock_request_format, mock_write_batch_job_id_to_db):
        request_id = str(uuid.uuid4())
        payload = {'request_id': request_id, 's3_obj_key': "test_s3_obj_key"}
        self.sqs_handler.add_message_to_queue("test_query_job_q_name", payload)
        mock_is_ready_for_conversion.return_value = True
        mock_schedule_conversion.return_value = "123-123"

        self.query_runner.run(max_loops=1)

        mock_schedule_conversion.assert_called_once_with(request_id, mock.ANY)
        mock_write_batch_job_id_to_db.assert_called_once_with("123-123")

    @mock.patch(
        "matrix.common.request.request_tracker.RequestTracker.log_error")
    @mock.patch("matrix.common.request.request_tracker.RequestTracker.format")
    @mock.patch(
        "matrix.common.request.request_tracker.RequestTracker.complete_subtask_execution"
    )
    @mock.patch(
        "matrix.common.aws.redshift_handler.RedshiftHandler.transaction")
    @mock.patch(
        "matrix.common.aws.s3_handler.S3Handler.load_content_from_obj_key")
    def test_run__with_one_message_in_queue_and_fails(self, mock_load_obj,
                                                      mock_transaction,
                                                      mock_complete_subtask,
                                                      mock_request_format,
                                                      mock_log_error):
        request_id = str(uuid.uuid4())
        payload = {'request_id': request_id, 's3_obj_key': "test_s3_obj_key"}
        self.sqs_handler.add_message_to_queue("test_query_job_q_name", payload)
        mock_complete_subtask.side_effect = MatrixException(
            status=requests.codes.not_found, title="Unable to find")

        self.query_runner.run(max_loops=1)

        mock_log_error.assert_called_once()
        query_queue_messages = self.sqs_handler.receive_messages_from_queue(
            "test_query_job_q_name", 1)
        self.assertEqual(query_queue_messages, None)
        deadletter_queue_messages = self.sqs_handler.receive_messages_from_queue(
            "test_deadletter_query_job_q_name", 1)
        self.assertEqual(len(deadletter_queue_messages), 1)
        message_body = json.loads(deadletter_queue_messages[0]['Body'])
        self.assertEqual(message_body['request_id'], request_id)
        self.assertEqual(message_body['s3_obj_key'], "test_s3_obj_key")
Example #5
import uuid

from connexion.lifecycle import ConnexionResponse

from matrix.common import constants
from matrix.common import query_constructor
from matrix.common.exceptions import MatrixException
from matrix.common.constants import MatrixFormat, MatrixRequestStatus
from matrix.common.config import MatrixInfraConfig
from matrix.common.aws.lambda_handler import LambdaHandler, LambdaName
from matrix.common.aws.redshift_handler import RedshiftHandler, TableName
from matrix.common.request.request_tracker import RequestTracker
from matrix.common.aws.sqs_handler import SQSHandler

lambda_handler = LambdaHandler()
sqs_handler = SQSHandler()
matrix_infra_config = MatrixInfraConfig()


def post_matrix(body: dict):

    feature = body.get("feature", constants.DEFAULT_FEATURE)
    fields = body.get("fields", constants.DEFAULT_FIELDS)
    format_ = body.get('format', MatrixFormat.LOOM.value)
    expected_formats = [mf.value for mf in MatrixFormat]

    # Validate input parameters
    if format_ not in expected_formats:
        # The original example is truncated here; a plausible completion of
        # the error response (shape assumed, not taken from the source):
        return ({'message': "Invalid parameters supplied. "
                            f"Expected 'format' to be one of {expected_formats}."},
                400)
Example #6
class Driver:
    """
    Formats and stores redshift queries in s3 and sqs for execution.
    """
    def __init__(self, request_id: str):
        Logging.set_correlation_id(logger, value=request_id)

        self.request_id = request_id
        self.request_tracker = RequestTracker(request_id)
        self.dynamo_handler = DynamoHandler()
        self.sqs_handler = SQSHandler()
        self.infra_config = MatrixInfraConfig()
        self.redshift_config = MatrixRedshiftConfig()
        self.query_results_bucket = os.environ['MATRIX_QUERY_RESULTS_BUCKET']
        self.s3_handler = S3Handler(os.environ['MATRIX_QUERY_BUCKET'])

    @property
    def query_job_q_url(self):
        return self.infra_config.query_job_q_url

    @property
    def redshift_role_arn(self):
        return self.redshift_config.redshift_role_arn

    def run(self, filter_: typing.Dict[str, typing.Any], fields: typing.List[str], feature: str):
        """
        Initialize a matrix service request and spawn redshift queries.

        :param filter_: Filter dict describing which cells to get expression data for
        :param fields: Which metadata fields to return
        :param feature: Which feature (gene vs transcript) to include in output
        """
        logger.debug(f"Driver running with parameters: filter={filter_}, "
                     f"fields={fields}, feature={feature}")

        try:
            matrix_request_queries = query_constructor.create_matrix_request_queries(
                filter_, fields, feature)
        except (query_constructor.MalformedMatrixFilter, query_constructor.MalformedMatrixFeature) as exc:
            self.request_tracker.log_error(f"Query construction failed with error: {str(exc)}")
            raise

        s3_obj_keys = self._format_and_store_queries_in_s3(matrix_request_queries)
        for key in s3_obj_keys:
            self._add_request_query_to_sqs(key, s3_obj_keys[key])
        self.request_tracker.complete_subtask_execution(Subtask.DRIVER)

    def _format_and_store_queries_in_s3(self, queries: dict):
        feature_query = queries[QueryType.FEATURE].format(results_bucket=self.query_results_bucket,
                                                          request_id=self.request_id,
                                                          iam_role=self.redshift_role_arn)
        feature_query_obj_key = self.s3_handler.store_content_in_s3(f"{self.request_id}/{QueryType.FEATURE.value}",
                                                                    feature_query)

        exp_query = queries[QueryType.EXPRESSION].format(results_bucket=self.query_results_bucket,
                                                         request_id=self.request_id,
                                                         iam_role=self.redshift_role_arn)
        exp_query_obj_key = self.s3_handler.store_content_in_s3(f"{self.request_id}/{QueryType.EXPRESSION.value}",
                                                                exp_query)

        cell_query = queries[QueryType.CELL].format(results_bucket=self.query_results_bucket,
                                                    request_id=self.request_id,
                                                    iam_role=self.redshift_role_arn)
        cell_query_obj_key = self.s3_handler.store_content_in_s3(f"{self.request_id}/{QueryType.CELL.value}",
                                                                 cell_query)

        return {
            QueryType.CELL: cell_query_obj_key,
            QueryType.EXPRESSION: exp_query_obj_key,
            QueryType.FEATURE: feature_query_obj_key
        }

    def _add_request_query_to_sqs(self, query_type: QueryType, s3_obj_key: str):
        queue_url = self.query_job_q_url
        payload = {
            'request_id': self.request_id,
            's3_obj_key': s3_obj_key,
            'type': query_type.value
        }
        logger.debug(f"Adding {payload} to sqs {queue_url}")
        self.sqs_handler.add_message_to_queue(queue_url, payload)
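
A hypothetical invocation of this Driver. The filter schema is owned by query_constructor, which is not shown on this page, so the dictionary shape below is an assumption:

import uuid

# Hypothetical usage; the field name and filter shape are illustrative only.
driver = Driver(request_id=str(uuid.uuid4()))
driver.run(
    filter_={"op": "=",
             "field": "project.project_core.project_short_name",
             "value": "project1"},
    fields=["cell_suspension.provenance.document_id"],
    feature="gene")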
Example #7
class Driver:
    """
    Formats and stores redshift queries in s3 and sqs for execution.
    """
    def __init__(self, request_id: str, bundles_per_worker: int = 100):
        Logging.set_correlation_id(logger, value=request_id)

        self.request_id = request_id
        self.bundles_per_worker = bundles_per_worker
        self.request_tracker = RequestTracker(request_id)
        self.dynamo_handler = DynamoHandler()
        self.sqs_handler = SQSHandler()
        self.infra_config = MatrixInfraConfig()
        self.redshift_config = MatrixRedshiftConfig()
        self.query_results_bucket = os.environ['MATRIX_QUERY_RESULTS_BUCKET']
        self.s3_handler = S3Handler(os.environ['MATRIX_QUERY_BUCKET'])
        self.redshift_handler = RedshiftHandler()

    @property
    def query_job_q_url(self):
        return self.infra_config.query_job_q_url

    @property
    def redshift_role_arn(self):
        return self.redshift_config.redshift_role_arn

    def run(self, bundle_fqids: typing.List[str], bundle_fqids_url: str,
            format: str):
        """
        Initialize a matrix service request and spawn redshift queries.

        :param bundle_fqids: List of bundle fqids to be queried on
        :param bundle_fqids_url: URL from which bundle_fqids can be retrieved
        :param format: MatrixFormat file format of output expression matrix
        """
        logger.debug(
            f"Driver running with parameters: bundle_fqids={bundle_fqids}, "
            f"bundle_fqids_url={bundle_fqids_url}, format={format}, "
            f"bundles_per_worker={self.bundles_per_worker}")

        if bundle_fqids_url:
            response = self._get_bundle_manifest(bundle_fqids_url)
            resolved_bundle_fqids = self._parse_download_manifest(
                response.text)
            if len(resolved_bundle_fqids) == 0:
                error_msg = "no bundles found in the supplied bundle manifest"
                logger.info(error_msg)
                self.request_tracker.log_error(error_msg)
                return
        else:
            resolved_bundle_fqids = bundle_fqids
        logger.debug(f"resolved bundles: {resolved_bundle_fqids}")

        self.dynamo_handler.set_table_field_with_value(
            DynamoTable.REQUEST_TABLE, self.request_id,
            RequestTableField.NUM_BUNDLES, len(resolved_bundle_fqids))
        s3_obj_keys = self._format_and_store_queries_in_s3(
            resolved_bundle_fqids)

        analysis_table_bundle_count = self._fetch_bundle_count_from_analysis_table(
            resolved_bundle_fqids)
        if analysis_table_bundle_count != len(resolved_bundle_fqids):
            error_msg = "resolved bundles in request do not match bundles available in matrix service"
            logger.info(error_msg)
            self.request_tracker.log_error(error_msg)
            return

        for key in s3_obj_keys:
            self._add_request_query_to_sqs(key, s3_obj_keys[key])
        self.request_tracker.complete_subtask_execution(Subtask.DRIVER)

    @retry(reraise=True, wait=wait_fixed(5), stop=stop_after_attempt(60))
    def _get_bundle_manifest(self, bundle_fqids_url):
        response = requests.get(bundle_fqids_url)
        return response

    @staticmethod
    def _parse_download_manifest(data: str) -> typing.List[str]:
        def _parse_line(line: str) -> str:
            tokens = line.split("\t")
            return f"{tokens[0]}.{tokens[1]}"

        lines = data.splitlines()[1:]
        return list(map(_parse_line, lines))

    def _format_and_store_queries_in_s3(self, resolved_bundle_fqids: list):
        feature_query = feature_query_template.format(
            self.query_results_bucket, self.request_id, self.redshift_role_arn)
        feature_query_obj_key = self.s3_handler.store_content_in_s3(
            f"{self.request_id}/feature", feature_query)

        exp_query = expression_query_template.format(
            self.query_results_bucket, self.request_id, self.redshift_role_arn,
            format_str_list(resolved_bundle_fqids))
        exp_query_obj_key = self.s3_handler.store_content_in_s3(
            f"{self.request_id}/expression", exp_query)

        cell_query = cell_query_template.format(
            self.query_results_bucket, self.request_id, self.redshift_role_arn,
            format_str_list(resolved_bundle_fqids))
        cell_query_obj_key = self.s3_handler.store_content_in_s3(
            f"{self.request_id}/cell", cell_query)

        return {
            QueryType.CELL: cell_query_obj_key,
            QueryType.EXPRESSION: exp_query_obj_key,
            QueryType.FEATURE: feature_query_obj_key
        }

    def _add_request_query_to_sqs(self, query_type: QueryType,
                                  s3_obj_key: str):
        queue_url = self.query_job_q_url
        payload = {
            'request_id': self.request_id,
            's3_obj_key': s3_obj_key,
            'type': query_type.value
        }
        logger.debug(f"Adding {payload} to sqs {queue_url}")
        self.sqs_handler.add_message_to_queue(queue_url, payload)

    def _fetch_bundle_count_from_analysis_table(self,
                                                resolved_bundle_fqids: list):
        analysis_table_bundle_count_query = analysis_bundle_count_query_template.format(
            format_str_list(resolved_bundle_fqids))
        analysis_table_bundle_count_query = (
            analysis_table_bundle_count_query.strip().replace('\n', ''))
        results = self.redshift_handler.transaction(
            [analysis_table_bundle_count_query],
            read_only=True,
            return_results=True)
        analysis_table_bundle_count = results[0][0]
        return analysis_table_bundle_count
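
_parse_download_manifest above drops the header row and joins the first two tab-separated columns of each remaining line into a "<uuid>.<version>" fqid. A quick example (the data values are made up):

manifest = ("bundle_uuid\tbundle_version\n"
            "4ad2ee7d-4dd7-4b04-9c11-5b37b01a1a26\t2019-01-01T000000.000000Z\n")
Driver._parse_download_manifest(manifest)
# -> ['4ad2ee7d-4dd7-4b04-9c11-5b37b01a1a26.2019-01-01T000000.000000Z']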
Example #8
    def __init__(self):
        self.sqs_handler = SQSHandler()
        self.s3_handler = S3Handler(os.environ["MATRIX_QUERY_BUCKET"])
        self.batch_handler = BatchHandler()
        self.redshift_handler = RedshiftHandler()
        self.matrix_infra_config = MatrixInfraConfig()
Example #9
class QueryRunner:
    def __init__(self):
        self.sqs_handler = SQSHandler()
        self.s3_handler = S3Handler(os.environ["MATRIX_QUERY_BUCKET"])
        self.batch_handler = BatchHandler()
        self.redshift_handler = RedshiftHandler()
        self.matrix_infra_config = MatrixInfraConfig()

    @property
    def query_job_q_url(self):
        return self.matrix_infra_config.query_job_q_url

    @property
    def query_job_deadletter_q_url(self):
        return self.matrix_infra_config.query_job_deadletter_q_url

    def run(self, max_loops=None):
        loops = 0
        while max_loops is None or loops < max_loops:
            loops += 1
            messages = self.sqs_handler.receive_messages_from_queue(
                self.query_job_q_url)
            if messages:
                message = messages[0]
                logger.info(f"Received {message} from {self.query_job_q_url}")
                payload = json.loads(message['Body'])
                request_id = payload['request_id']
                request_tracker = RequestTracker(request_id)
                Logging.set_correlation_id(logger, value=request_id)
                obj_key = payload['s3_obj_key']
                receipt_handle = message['ReceiptHandle']
                try:
                    logger.info(f"Fetching query from {obj_key}")
                    query = self.s3_handler.load_content_from_obj_key(obj_key)

                    logger.info(f"Running query from {obj_key}")
                    self.redshift_handler.transaction([query], read_only=True)
                    logger.info(f"Finished running query from {obj_key}")

                    logger.info(
                        f"Deleting {message} from {self.query_job_q_url}")
                    self.sqs_handler.delete_message_from_queue(
                        self.query_job_q_url, receipt_handle)

                    logger.info(
                        "Incrementing completed queries in state table")
                    request_tracker.complete_subtask_execution(Subtask.QUERY)

                    if request_tracker.is_request_ready_for_conversion():
                        logger.info("Scheduling batch conversion job")
                        batch_job_id = self.batch_handler.schedule_matrix_conversion(
                            request_id, request_tracker.format)
                        request_tracker.write_batch_job_id_to_db(batch_job_id)
                except Exception as e:
                    logger.info(
                        f"QueryRunner failed on {message} with error {e}")
                    request_tracker.log_error(str(e))
                    logger.info(
                        f"Adding {message} to {self.query_job_deadletter_q_url}"
                    )
                    self.sqs_handler.add_message_to_queue(
                        self.query_job_deadletter_q_url, payload)
                    logger.info(
                        f"Deleting {message} from {self.query_job_q_url}")
                    self.sqs_handler.delete_message_from_queue(
                        self.query_job_q_url, receipt_handle)
            else:
                logger.info(f"No messages to read from {self.query_job_q_url}")
Example #10
    def setUp(self):
        super(TestSQSHandler, self).setUp()
        self.sqs_handler = SQSHandler()
        self.sqs.meta.client.purge_queue(QueueUrl="test_query_job_q_name")
Example #11
class TestSQSHandler(MatrixTestCaseUsingMockAWS):
    def setUp(self):
        super(TestSQSHandler, self).setUp()
        self.sqs_handler = SQSHandler()
        self.sqs.meta.client.purge_queue(QueueUrl="test_query_job_q_name")

    def test_add_message_to_queue(self):
        payload = {'test_key': "test_value"}
        self.sqs_handler.add_message_to_queue("test_query_job_q_name", payload)

        messages = self.sqs.meta.client.receive_message(
            QueueUrl="test_query_job_q_name")
        message_body = json.loads(messages['Messages'][0]['Body'])
        self.assertEqual(message_body['test_key'], "test_value")

    def test_receive_messages_from_queue__returns_None_when_no_messages_found(
            self):
        message = self.sqs_handler.receive_messages_from_queue(
            "test_query_job_q_name", 1)
        self.assertEqual(message, None)

    def test_retrieve_messages_from_queue__returns_message_when_message_is_found(
            self):
        payload = {'test_key': "test_value"}
        self.sqs_handler.add_message_to_queue("test_query_job_q_name", payload)

        messages = self.sqs_handler.receive_messages_from_queue(
            queue_url="test_query_job_q_name")

        message_body = json.loads(messages[0]['Body'])
        self.assertEqual(len(messages), 1)
        self.assertEqual(message_body['test_key'], "test_value")

    def test_delete_message_from_queue(self):
        payload = {'test_key': "test_value"}
        self.sqs_handler.add_message_to_queue("test_query_job_q_name", payload)
        messages = self.sqs_handler.receive_messages_from_queue(
            queue_url="test_query_job_q_name")
        receipt_handle = messages[0]['ReceiptHandle']

        self.sqs_handler.delete_message_from_queue("test_query_job_q_name",
                                                   receipt_handle)

        message = self.sqs_handler.receive_messages_from_queue(
            "test_query_job_q_name", 1)
        self.assertEqual(message, None)
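
Every example on this page leans on the same small SQSHandler surface. For reference, here is a minimal sketch of what such a handler could look like on top of boto3, with method names and return shapes taken from the calls above; the wait-time parameter and its default are assumptions, not the real matrix.common.aws.sqs_handler implementation:

import json

import boto3


class SQSHandler:
    """Minimal sketch of a JSON-over-SQS helper (assumed interface)."""

    def __init__(self):
        self.sqs = boto3.resource("sqs")

    def add_message_to_queue(self, queue_url: str, payload: dict):
        # Messages are JSON-encoded dicts; consumers json.loads() the Body.
        self.sqs.meta.client.send_message(QueueUrl=queue_url,
                                          MessageBody=json.dumps(payload))

    def receive_messages_from_queue(self, queue_url: str, wait_time: int = 15):
        response = self.sqs.meta.client.receive_message(
            QueueUrl=queue_url, WaitTimeSeconds=wait_time)
        # boto3 omits the 'Messages' key when the queue is empty, which is
        # why the callers above check for a None return.
        return response.get('Messages')

    def delete_message_from_queue(self, queue_url: str, receipt_handle: str):
        self.sqs.meta.client.delete_message(QueueUrl=queue_url,
                                            ReceiptHandle=receipt_handle)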