class QueryRunner:
    """Polls the matrix query SQS queue and executes each queued Redshift query.

    For every message it: loads the SQL text from S3, runs it read-only against
    Redshift, deletes the message on success, and records progress in the
    request state table. On failure the message payload is forwarded to the
    dead-letter queue and the original message is deleted.
    """

    def __init__(self):
        # AWS/service facades; S3 bucket name comes from the environment,
        # so MATRIX_QUERY_BUCKET must be set before construction (KeyError otherwise).
        self.sqs_handler = SQSHandler()
        self.s3_handler = S3Handler(os.environ["MATRIX_QUERY_BUCKET"])
        self.batch_handler = BatchHandler()
        self.redshift_handler = RedshiftHandler()
        self.matrix_infra_config = MatrixInfraConfig()

    @property
    def query_job_q_url(self):
        # URL of the main query-job queue, delegated to infra config.
        return self.matrix_infra_config.query_job_q_url

    @property
    def query_job_deadletter_q_url(self):
        # URL of the dead-letter queue for failed query jobs.
        return self.matrix_infra_config.query_job_deadletter_q_url

    def run(self, max_loops=None):
        """Poll the query queue and process messages one at a time.

        :param max_loops: maximum number of receive iterations; None means
            loop forever (the production daemon mode).
        """
        loops = 0
        while max_loops is None or loops < max_loops:
            loops += 1
            messages = self.sqs_handler.receive_messages_from_queue(
                self.query_job_q_url)
            if messages:
                # Only the first received message is processed per iteration;
                # any extra messages become visible again after their SQS
                # visibility timeout and are picked up on a later loop.
                message = messages[0]
                logger.info(f"Received {message} from {self.query_job_q_url}")
                payload = json.loads(message['Body'])
                request_id = payload['request_id']
                request_tracker = RequestTracker(request_id)
                # Tag all subsequent log lines with the request id.
                Logging.set_correlation_id(logger, value=request_id)
                obj_key = payload['s3_obj_key']
                receipt_handle = message['ReceiptHandle']
                try:
                    logger.info(f"Fetching query from {obj_key}")
                    query = self.s3_handler.load_content_from_obj_key(obj_key)
                    logger.info(f"Running query from {obj_key}")
                    self.redshift_handler.transaction([query], read_only=True)
                    logger.info(f"Finished running query from {obj_key}")
                    logger.info(
                        f"Deleting {message} from {self.query_job_q_url}")
                    # Delete only after the query succeeded, so a crash before
                    # this point lets SQS redeliver the message.
                    self.sqs_handler.delete_message_from_queue(
                        self.query_job_q_url, receipt_handle)
                    logger.info(
                        "Incrementing completed queries in state table")
                    request_tracker.complete_subtask_execution(Subtask.QUERY)
                    # When every query subtask for the request is done, kick
                    # off the Batch conversion job and persist its id.
                    if request_tracker.is_request_ready_for_conversion():
                        logger.info("Scheduling batch conversion job")
                        batch_job_id = self.batch_handler.schedule_matrix_conversion(
                            request_id, request_tracker.format)
                        request_tracker.write_batch_job_id_to_db(batch_job_id)
                except Exception as e:
                    # Any failure: record the error, move the payload to the
                    # dead-letter queue, then remove the original message so it
                    # is not retried on the main queue.
                    logger.info(
                        f"QueryRunner failed on {message} with error {e}")
                    request_tracker.log_error(str(e))
                    logger.info(
                        f"Adding {message} to {self.query_job_deadletter_q_url}"
                    )
                    self.sqs_handler.add_message_to_queue(
                        self.query_job_deadletter_q_url, payload)
                    logger.info(
                        f"Deleting {message} from {self.query_job_q_url}")
                    self.sqs_handler.delete_message_from_queue(
                        self.query_job_q_url, receipt_handle)
            else:
                logger.info(f"No messages to read from {self.query_job_q_url}")
class TestBatchHandler(unittest.TestCase):
    """Unit tests for BatchHandler.

    AWS Batch calls are intercepted either by mock.patch (for the scheduling
    wrapper) or by a botocore Stubber on the underlying boto3 client (for the
    methods that talk to Batch directly).
    """

    def setUp(self):
        self.request_id = str(uuid.uuid4())
        self.batch_handler = BatchHandler()
        # Stub the boto3 Batch client so no real AWS calls are made.
        self.mock_batch_client = Stubber(self.batch_handler._client)

    @mock.patch(
        "matrix.common.aws.cloudwatch_handler.CloudwatchHandler.put_metric_data"
    )
    @mock.patch(
        "matrix.common.aws.batch_handler.BatchHandler._enqueue_batch_job")
    def test_schedule_matrix_conversion(self, mock_enqueue_batch_job, mock_cw_put):
        # `fmt` instead of `format` to avoid shadowing the builtin.
        fmt = "test_format"
        job_name = f"conversion-{os.environ['DEPLOYMENT_STAGE']}-{self.request_id}-{fmt}"

        self.batch_handler.schedule_matrix_conversion(self.request_id, fmt)

        mock_enqueue_batch_job.assert_called_once_with(
            job_name=job_name,
            job_queue_arn=os.environ['BATCH_CONVERTER_JOB_QUEUE_ARN'],
            job_def_arn=os.environ['BATCH_CONVERTER_JOB_DEFINITION_ARN'],
            command=mock.ANY,
            environment=mock.ANY)
        mock_cw_put.assert_called_once_with(
            metric_name=MetricName.CONVERSION_REQUEST, metric_value=1)

    def test_enqueue_batch_job(self):
        expected_params = {
            'jobName': "test_job_name",
            'jobQueue': "test_job_queue",
            'jobDefinition': "test_job_definition",
            'containerOverrides': {
                'command': [],
                'environment': []
            }
        }
        expected_response = {'jobId': "test_id", 'jobName': "test_job_name"}
        self.mock_batch_client.add_response('submit_job', expected_response,
                                            expected_params)
        self.mock_batch_client.activate()

        self.batch_handler._enqueue_batch_job("test_job_name",
                                              "test_job_queue",
                                              "test_job_definition", [], {})
        # Verify submit_job was actually invoked (and with expected_params);
        # without this, a no-op implementation would pass silently.
        self.mock_batch_client.assert_no_pending_responses()

    def test_get_batch_job_status(self):
        expected_params = {'jobs': ['123']}
        expected_response = {
            'jobs': [{
                'status': "FAILED",
                'jobName': "test_job_name",
                'jobId': "test_job_id",
                'jobQueue': "test_job_queue",
                'startedAt': 123,
                'jobDefinition': "test_job_definition"
            }]
        }
        self.mock_batch_client.add_response('describe_jobs',
                                            expected_response,
                                            expected_params)
        self.mock_batch_client.activate()

        status = self.batch_handler.get_batch_job_status('123')

        self.assertEqual(status, 'FAILED')
        # Confirm the stubbed describe_jobs response was consumed.
        self.mock_batch_client.assert_no_pending_responses()