def dss_notification():
    body = app.current_request.json_body
    bundle_uuid = body['match']['bundle_uuid']
    bundle_version = body['match']['bundle_version']
    subscription_id = body['subscription_id']
    event_type = body['event_type']

    config = MatrixInfraConfig()
    hmac_secret_key = config.dss_subscription_hmac_secret_key.encode()
    HTTPSignatureAuth.verify(requests.Request(url="http://host/dss/notification",
                                              method=app.current_request.method,
                                              headers=app.current_request.headers),
                             key_resolver=lambda key_id, algorithm: hmac_secret_key)

    payload = {
        'bundle_uuid': bundle_uuid,
        'bundle_version': bundle_version,
        'event_type': event_type,
    }
    queue_url = config.notification_q_url
    SQSHandler().add_message_to_queue(queue_url, payload)

    return chalice.Response(status_code=requests.codes.ok,
                            body=f"Received notification from subscription {subscription_id}: "
                                 f"{event_type} {bundle_uuid}.{bundle_version}")
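# A hedged sketch of the sender's side of the verification above, assuming the
# v0 API of the requests_http_signature package. The URL matches the one used in
# dss_notification; the secret and key_id below are illustrative placeholders,
# not values from the service. The request body must be signed with the same
# HMAC secret that the verify() call resolves.
import requests
from requests_http_signature import HTTPSignatureAuth

hmac_secret_key = b"example-secret"  # placeholder; must match dss_subscription_hmac_secret_key
requests.post("http://host/dss/notification",
              json={'match': {'bundle_uuid': "some-uuid", 'bundle_version': "some-version"},
                    'subscription_id': "some-subscription",
                    'event_type': "CREATE"},
              auth=HTTPSignatureAuth(key=hmac_secret_key, key_id="hca"))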
class TestQueryRunner(MatrixTestCaseUsingMockAWS):
    def setUp(self):
        super(TestQueryRunner, self).setUp()

        self.query_runner = QueryRunner()
        self.matrix_infra_config.set(self.__class__.TEST_CONFIG)
        self.query_runner.matrix_infra_config = self.matrix_infra_config
        self.sqs_handler = SQSHandler()
        self.sqs.meta.client.purge_queue(QueueUrl="test_query_job_q_name")
        self.sqs.meta.client.purge_queue(QueueUrl="test_deadletter_query_job_q_name")

    @mock.patch("matrix.common.aws.s3_handler.S3Handler.load_content_from_obj_key")
    @mock.patch("matrix.common.aws.sqs_handler.SQSHandler.receive_messages_from_queue")
    def test_run__with_no_messages_in_queue(self, mock_receive_messages, mock_load_obj):
        mock_receive_messages.return_value = None

        self.query_runner.run(max_loops=1)

        mock_receive_messages.assert_called_once_with(self.query_runner.query_job_q_url)
        mock_load_obj.assert_not_called()

    @mock.patch("matrix.common.aws.batch_handler.BatchHandler.schedule_matrix_conversion")
    @mock.patch("matrix.common.request.request_tracker.RequestTracker.is_request_ready_for_conversion")
    @mock.patch("matrix.common.request.request_tracker.RequestTracker.complete_subtask_execution")
    @mock.patch("matrix.common.aws.redshift_handler.RedshiftHandler.transaction")
    @mock.patch("matrix.common.aws.s3_handler.S3Handler.load_content_from_obj_key")
    def test_run__with_one_message_in_queue_and_not_ready_for_conversion(self,
                                                                         mock_load_obj,
                                                                         mock_transaction,
                                                                         mock_complete_subtask,
                                                                         mock_is_ready_for_conversion,
                                                                         mock_schedule_conversion):
        request_id = str(uuid.uuid4())
        payload = {'request_id': request_id, 's3_obj_key': "test_s3_obj_key"}
        self.sqs_handler.add_message_to_queue("test_query_job_q_name", payload)
        mock_is_ready_for_conversion.return_value = False

        self.query_runner.run(max_loops=1)

        mock_load_obj.assert_called_once_with("test_s3_obj_key")
        mock_transaction.assert_called()
        mock_complete_subtask.assert_called_once_with(Subtask.QUERY)
        mock_schedule_conversion.assert_not_called()

    @mock.patch("matrix.common.request.request_tracker.RequestTracker.write_batch_job_id_to_db")
    @mock.patch("matrix.common.request.request_tracker.RequestTracker.format")
    @mock.patch("matrix.common.aws.batch_handler.BatchHandler.schedule_matrix_conversion")
    @mock.patch("matrix.common.request.request_tracker.RequestTracker.is_request_ready_for_conversion")
    @mock.patch("matrix.common.request.request_tracker.RequestTracker.complete_subtask_execution")
    @mock.patch("matrix.common.aws.redshift_handler.RedshiftHandler.transaction")
    @mock.patch("matrix.common.aws.s3_handler.S3Handler.load_content_from_obj_key")
    def test_run__with_one_message_in_queue_and_ready_for_conversion(self,
                                                                     mock_load_obj,
                                                                     mock_transaction,
                                                                     mock_complete_subtask,
                                                                     mock_is_ready_for_conversion,
                                                                     mock_schedule_conversion,
                                                                     mock_request_format,
                                                                     mock_write_batch_job_id_to_db):
        request_id = str(uuid.uuid4())
        payload = {'request_id': request_id, 's3_obj_key': "test_s3_obj_key"}
        self.sqs_handler.add_message_to_queue("test_query_job_q_name", payload)
        mock_is_ready_for_conversion.return_value = True
        mock_schedule_conversion.return_value = "123-123"

        self.query_runner.run(max_loops=1)

        mock_schedule_conversion.assert_called_once_with(request_id, mock.ANY)
        mock_write_batch_job_id_to_db.assert_called_once_with("123-123")

    @mock.patch("matrix.common.request.request_tracker.RequestTracker.log_error")
    @mock.patch("matrix.common.request.request_tracker.RequestTracker.format")
    @mock.patch("matrix.common.request.request_tracker.RequestTracker.complete_subtask_execution")
    @mock.patch("matrix.common.aws.redshift_handler.RedshiftHandler.transaction")
    @mock.patch("matrix.common.aws.s3_handler.S3Handler.load_content_from_obj_key")
    def test_run__with_one_message_in_queue_and_fails(self,
                                                      mock_load_obj,
                                                      mock_transaction,
                                                      mock_complete_subtask,
                                                      mock_request_format,
                                                      mock_log_error):
        request_id = str(uuid.uuid4())
        payload = {'request_id': request_id, 's3_obj_key': "test_s3_obj_key"}
        self.sqs_handler.add_message_to_queue("test_query_job_q_name", payload)
        mock_complete_subtask.side_effect = MatrixException(status=requests.codes.not_found,
                                                            title="Unable to find")

        self.query_runner.run(max_loops=1)

        mock_log_error.assert_called_once()

        query_queue_messages = self.sqs_handler.receive_messages_from_queue("test_query_job_q_name", 1)
        self.assertEqual(query_queue_messages, None)

        deadletter_queue_messages = self.sqs_handler.receive_messages_from_queue(
            "test_deadletter_query_job_q_name", 1)
        self.assertEqual(len(deadletter_queue_messages), 1)
        message_body = json.loads(deadletter_queue_messages[0]['Body'])
        self.assertEqual(message_body['request_id'], request_id)
        self.assertEqual(message_body['s3_obj_key'], "test_s3_obj_key")
import uuid

from connexion.lifecycle import ConnexionResponse

from matrix.common import constants
from matrix.common import query_constructor
from matrix.common.exceptions import MatrixException
from matrix.common.constants import MatrixFormat, MatrixRequestStatus
from matrix.common.config import MatrixInfraConfig
from matrix.common.aws.lambda_handler import LambdaHandler, LambdaName
from matrix.common.aws.redshift_handler import RedshiftHandler, TableName
from matrix.common.request.request_tracker import RequestTracker
from matrix.common.aws.sqs_handler import SQSHandler

lambda_handler = LambdaHandler()
sqs_handler = SQSHandler()
matrix_infra_config = MatrixInfraConfig()


def post_matrix(body: dict):
    feature = body.get("feature", constants.DEFAULT_FEATURE)
    fields = body.get("fields", constants.DEFAULT_FIELDS)
    format_ = body['format'] if 'format' in body else MatrixFormat.LOOM.value
    expected_formats = [mf.value for mf in MatrixFormat]

    # Validate input parameters
    if format_ not in expected_formats:
        # The source snippet is truncated mid-statement here; the rest of the
        # message and the status code below are a minimal assumed completion.
        return ({'message': "Invalid parameters supplied. "
                            f"Expected `format` to be one of {expected_formats}."},
                400)
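# Hedged usage sketch of the validation branch above: an unrecognized "format"
# (the "xlsx" value below is made up and is not a MatrixFormat member) returns
# the error tuple, while omitting "format" falls back to MatrixFormat.LOOM per
# the default in post_matrix.
error_body, status = post_matrix({"format": "xlsx"})
assert status == 400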
class Driver:
    """
    Formats and stores redshift queries in s3 and sqs for execution.
    """
    def __init__(self, request_id: str):
        Logging.set_correlation_id(logger, value=request_id)

        self.request_id = request_id
        self.request_tracker = RequestTracker(request_id)
        self.dynamo_handler = DynamoHandler()
        self.sqs_handler = SQSHandler()
        self.infra_config = MatrixInfraConfig()
        self.redshift_config = MatrixRedshiftConfig()
        self.query_results_bucket = os.environ['MATRIX_QUERY_RESULTS_BUCKET']
        self.s3_handler = S3Handler(os.environ['MATRIX_QUERY_BUCKET'])

    @property
    def query_job_q_url(self):
        return self.infra_config.query_job_q_url

    @property
    def redshift_role_arn(self):
        return self.redshift_config.redshift_role_arn

    def run(self, filter_: typing.Dict[str, typing.Any], fields: typing.List[str], feature: str):
        """
        Initialize a matrix service request and spawn redshift queries.

        :param filter_: Filter dict describing which cells to get expression data for
        :param fields: Which metadata fields to return
        :param feature: Which feature (gene vs transcript) to include in output
        """
        logger.debug(f"Driver running with parameters: filter={filter_}, "
                     f"fields={fields}, feature={feature}")

        try:
            matrix_request_queries = query_constructor.create_matrix_request_queries(
                filter_, fields, feature)
        except (query_constructor.MalformedMatrixFilter,
                query_constructor.MalformedMatrixFeature) as exc:
            self.request_tracker.log_error(f"Query construction failed with error: {str(exc)}")
            raise

        s3_obj_keys = self._format_and_store_queries_in_s3(matrix_request_queries)

        for key in s3_obj_keys:
            self._add_request_query_to_sqs(key, s3_obj_keys[key])

        self.request_tracker.complete_subtask_execution(Subtask.DRIVER)

    def _format_and_store_queries_in_s3(self, queries: dict):
        feature_query = queries[QueryType.FEATURE].format(results_bucket=self.query_results_bucket,
                                                          request_id=self.request_id,
                                                          iam_role=self.redshift_role_arn)
        feature_query_obj_key = self.s3_handler.store_content_in_s3(
            f"{self.request_id}/{QueryType.FEATURE.value}", feature_query)

        exp_query = queries[QueryType.EXPRESSION].format(results_bucket=self.query_results_bucket,
                                                         request_id=self.request_id,
                                                         iam_role=self.redshift_role_arn)
        exp_query_obj_key = self.s3_handler.store_content_in_s3(
            f"{self.request_id}/{QueryType.EXPRESSION.value}", exp_query)

        cell_query = queries[QueryType.CELL].format(results_bucket=self.query_results_bucket,
                                                    request_id=self.request_id,
                                                    iam_role=self.redshift_role_arn)
        cell_query_obj_key = self.s3_handler.store_content_in_s3(
            f"{self.request_id}/{QueryType.CELL.value}", cell_query)

        return {
            QueryType.CELL: cell_query_obj_key,
            QueryType.EXPRESSION: exp_query_obj_key,
            QueryType.FEATURE: feature_query_obj_key
        }

    def _add_request_query_to_sqs(self, query_type: QueryType, s3_obj_key: str):
        queue_url = self.query_job_q_url
        payload = {
            'request_id': self.request_id,
            's3_obj_key': s3_obj_key,
            'type': query_type.value
        }
        logger.debug(f"Adding {payload} to sqs {queue_url}")
        self.sqs_handler.add_message_to_queue(queue_url, payload)
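# A minimal sketch (assumed; the real entry point is not part of this section)
# of how this Driver could be invoked from a lambda handler: the request id and
# the query parameters arrive in the event and are handed straight to run().
# driver_handler and the event keys are hypothetical names for illustration.
def driver_handler(event, context):
    driver = Driver(request_id=event["request_id"])
    driver.run(filter_=event["filter"],
               fields=event["fields"],
               feature=event["feature"])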
class Driver:
    """
    Formats and stores redshift queries in s3 and sqs for execution.
    """
    def __init__(self, request_id: str, bundles_per_worker: int = 100):
        Logging.set_correlation_id(logger, value=request_id)

        self.request_id = request_id
        self.bundles_per_worker = bundles_per_worker
        self.request_tracker = RequestTracker(request_id)
        self.dynamo_handler = DynamoHandler()
        self.sqs_handler = SQSHandler()
        self.infra_config = MatrixInfraConfig()
        self.redshift_config = MatrixRedshiftConfig()
        self.query_results_bucket = os.environ['MATRIX_QUERY_RESULTS_BUCKET']
        self.s3_handler = S3Handler(os.environ['MATRIX_QUERY_BUCKET'])
        self.redshift_handler = RedshiftHandler()

    @property
    def query_job_q_url(self):
        return self.infra_config.query_job_q_url

    @property
    def redshift_role_arn(self):
        return self.redshift_config.redshift_role_arn

    def run(self, bundle_fqids: typing.List[str], bundle_fqids_url: str, format: str):
        """
        Initialize a matrix service request and spawn redshift queries.

        :param bundle_fqids: List of bundle fqids to be queried on
        :param bundle_fqids_url: URL from which bundle_fqids can be retrieved
        :param format: MatrixFormat file format of output expression matrix
        """
        logger.debug(f"Driver running with parameters: bundle_fqids={bundle_fqids}, "
                     f"bundle_fqids_url={bundle_fqids_url}, format={format}, "
                     f"bundles_per_worker={self.bundles_per_worker}")

        if bundle_fqids_url:
            response = self._get_bundle_manifest(bundle_fqids_url)
            resolved_bundle_fqids = self._parse_download_manifest(response.text)
            if len(resolved_bundle_fqids) == 0:
                error_msg = "no bundles found in the supplied bundle manifest"
                logger.info(error_msg)
                self.request_tracker.log_error(error_msg)
                return
        else:
            resolved_bundle_fqids = bundle_fqids
        logger.debug(f"resolved bundles: {resolved_bundle_fqids}")

        self.dynamo_handler.set_table_field_with_value(DynamoTable.REQUEST_TABLE,
                                                       self.request_id,
                                                       RequestTableField.NUM_BUNDLES,
                                                       len(resolved_bundle_fqids))
        s3_obj_keys = self._format_and_store_queries_in_s3(resolved_bundle_fqids)

        analysis_table_bundle_count = self._fetch_bundle_count_from_analysis_table(resolved_bundle_fqids)
        if analysis_table_bundle_count != len(resolved_bundle_fqids):
            error_msg = "resolved bundles in request do not match bundles available in matrix service"
            logger.info(error_msg)
            self.request_tracker.log_error(error_msg)
            return

        for key in s3_obj_keys:
            self._add_request_query_to_sqs(key, s3_obj_keys[key])
        self.request_tracker.complete_subtask_execution(Subtask.DRIVER)

    @retry(reraise=True, wait=wait_fixed(5), stop=stop_after_attempt(60))
    def _get_bundle_manifest(self, bundle_fqids_url):
        response = requests.get(bundle_fqids_url)
        return response

    @staticmethod
    def _parse_download_manifest(data: str) -> typing.List[str]:
        def _parse_line(line: str) -> str:
            tokens = line.split("\t")
            return f"{tokens[0]}.{tokens[1]}"

        lines = data.splitlines()[1:]
        return list(map(_parse_line, lines))

    def _format_and_store_queries_in_s3(self, resolved_bundle_fqids: list):
        feature_query = feature_query_template.format(self.query_results_bucket,
                                                      self.request_id,
                                                      self.redshift_role_arn)
        feature_query_obj_key = self.s3_handler.store_content_in_s3(f"{self.request_id}/feature",
                                                                    feature_query)

        exp_query = expression_query_template.format(self.query_results_bucket,
                                                     self.request_id,
                                                     self.redshift_role_arn,
                                                     format_str_list(resolved_bundle_fqids))
        exp_query_obj_key = self.s3_handler.store_content_in_s3(f"{self.request_id}/expression",
                                                                exp_query)

        cell_query = cell_query_template.format(self.query_results_bucket,
                                                self.request_id,
                                                self.redshift_role_arn,
                                                format_str_list(resolved_bundle_fqids))
        cell_query_obj_key = self.s3_handler.store_content_in_s3(f"{self.request_id}/cell",
                                                                 cell_query)

        return {
            QueryType.CELL: cell_query_obj_key,
            QueryType.EXPRESSION: exp_query_obj_key,
            QueryType.FEATURE: feature_query_obj_key
        }

    def _add_request_query_to_sqs(self, query_type: QueryType, s3_obj_key: str):
        queue_url = self.query_job_q_url
        payload = {
            'request_id': self.request_id,
            's3_obj_key': s3_obj_key,
            'type': query_type.value
        }
        logger.debug(f"Adding {payload} to sqs {queue_url}")
        self.sqs_handler.add_message_to_queue(queue_url, payload)

    def _fetch_bundle_count_from_analysis_table(self, resolved_bundle_fqids: list):
        analysis_table_bundle_count_query = analysis_bundle_count_query_template.format(
            format_str_list(resolved_bundle_fqids))
        analysis_table_bundle_count_query = analysis_table_bundle_count_query.strip().replace('\n', '')
        results = self.redshift_handler.transaction([analysis_table_bundle_count_query],
                                                    read_only=True,
                                                    return_results=True)
        analysis_table_bundle_count = results[0][0]
        return analysis_table_bundle_count
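# _parse_download_manifest expects a tab-separated manifest whose first line is
# a header and whose first two columns are the bundle uuid and version; any
# extra columns are ignored. A quick illustration (the fqid below is made up):
manifest = "bundle_uuid\tbundle_version\nabc-123\t2019-01-01T000000.000000Z\n"
assert Driver._parse_download_manifest(manifest) == ["abc-123.2019-01-01T000000.000000Z"]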
class QueryRunner:
    def __init__(self):
        self.sqs_handler = SQSHandler()
        self.s3_handler = S3Handler(os.environ["MATRIX_QUERY_BUCKET"])
        self.batch_handler = BatchHandler()
        self.redshift_handler = RedshiftHandler()
        self.matrix_infra_config = MatrixInfraConfig()

    @property
    def query_job_q_url(self):
        return self.matrix_infra_config.query_job_q_url

    @property
    def query_job_deadletter_q_url(self):
        return self.matrix_infra_config.query_job_deadletter_q_url

    def run(self, max_loops=None):
        loops = 0
        while max_loops is None or loops < max_loops:
            loops += 1
            messages = self.sqs_handler.receive_messages_from_queue(self.query_job_q_url)
            if messages:
                message = messages[0]
                logger.info(f"Received {message} from {self.query_job_q_url}")
                payload = json.loads(message['Body'])
                request_id = payload['request_id']
                request_tracker = RequestTracker(request_id)
                Logging.set_correlation_id(logger, value=request_id)
                obj_key = payload['s3_obj_key']
                receipt_handle = message['ReceiptHandle']
                try:
                    logger.info(f"Fetching query from {obj_key}")
                    query = self.s3_handler.load_content_from_obj_key(obj_key)

                    logger.info(f"Running query from {obj_key}")
                    self.redshift_handler.transaction([query], read_only=True)
                    logger.info(f"Finished running query from {obj_key}")

                    logger.info(f"Deleting {message} from {self.query_job_q_url}")
                    self.sqs_handler.delete_message_from_queue(self.query_job_q_url, receipt_handle)

                    logger.info("Incrementing completed queries in state table")
                    request_tracker.complete_subtask_execution(Subtask.QUERY)

                    if request_tracker.is_request_ready_for_conversion():
                        logger.info("Scheduling batch conversion job")
                        batch_job_id = self.batch_handler.schedule_matrix_conversion(
                            request_id, request_tracker.format)
                        request_tracker.write_batch_job_id_to_db(batch_job_id)
                except Exception as e:
                    logger.info(f"QueryRunner failed on {message} with error {e}")
                    request_tracker.log_error(str(e))
                    logger.info(f"Adding {message} to {self.query_job_deadletter_q_url}")
                    self.sqs_handler.add_message_to_queue(self.query_job_deadletter_q_url, payload)
                    logger.info(f"Deleting {message} from {self.query_job_q_url}")
                    self.sqs_handler.delete_message_from_queue(self.query_job_q_url, receipt_handle)
            else:
                logger.info(f"No messages to read from {self.query_job_q_url}")
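# A minimal sketch (assumed; no entry point appears in this section) of running
# QueryRunner as a long-lived worker: with the default max_loops=None, run()
# polls the query job queue indefinitely, while the tests above and below pass
# max_loops=1 to execute a single iteration.
if __name__ == "__main__":
    QueryRunner().run()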
class TestSQSHandler(MatrixTestCaseUsingMockAWS):
    def setUp(self):
        super(TestSQSHandler, self).setUp()

        self.sqs_handler = SQSHandler()
        self.sqs.meta.client.purge_queue(QueueUrl="test_query_job_q_name")

    def test_add_message_to_queue(self):
        payload = {'test_key': "test_value"}

        self.sqs_handler.add_message_to_queue("test_query_job_q_name", payload)

        messages = self.sqs.meta.client.receive_message(QueueUrl="test_query_job_q_name")
        message_body = json.loads(messages['Messages'][0]['Body'])
        self.assertEqual(message_body['test_key'], "test_value")

    def test_receive_messages_from_queue__returns_None_when_no_messages_found(self):
        message = self.sqs_handler.receive_messages_from_queue("test_query_job_q_name", 1)

        self.assertEqual(message, None)

    def test_retrieve_messages_from_queue__returns_message_when_message_is_found(self):
        payload = {'test_key': "test_value"}
        self.sqs_handler.add_message_to_queue("test_query_job_q_name", payload)

        messages = self.sqs_handler.receive_messages_from_queue(queue_url="test_query_job_q_name")

        message_body = json.loads(messages[0]['Body'])
        self.assertEqual(len(messages), 1)
        self.assertEqual(message_body['test_key'], "test_value")

    def test_delete_message_from_queue(self):
        payload = {'test_key': "test_value"}
        self.sqs_handler.add_message_to_queue("test_query_job_q_name", payload)
        messages = self.sqs_handler.receive_messages_from_queue(queue_url="test_query_job_q_name")
        receipt_handle = messages[0]['ReceiptHandle']

        self.sqs_handler.delete_message_from_queue("test_query_job_q_name", receipt_handle)

        message = self.sqs_handler.receive_messages_from_queue("test_query_job_q_name", 1)
        self.assertEqual(message, None)