def _create_tables():
    """
    Builds one CREATE statement per TableName member and runs them all
    in a single Redshift transaction.
    """
    handler = RedshiftHandler()
    # The first two template slots are left empty here; only the table name is filled in.
    create_statements = [
        CREATE_QUERY_TEMPLATE[member.value].format("", "", member.value)
        for member in TableName
    ]

    handler.transaction(create_statements)
 def setUp(self):
     """
     Prepares the per-test fixtures: target DSS environment, API base URL, resource directory,
     default JSON request headers, an S3 filesystem client and a Redshift handler.
     """
     # DEPLOYMENT_STAGE is mapped to the matching DSS environment name
     self.dss_env = MATRIX_ENV_TO_DSS_ENV[os.environ['DEPLOYMENT_STAGE']]
     # all requests go to the v0 API served from API_HOST
     self.api_url = f"https://{os.environ['API_HOST']}/v0"
     # "res" directory located next to this file
     self.res_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "res")
     self.headers = {'Content-type': 'application/json', 'Accept': 'application/json'}
     self.verbose = True
     # anon=False: the S3 filesystem is created with credentials rather than anonymous access
     self.s3_file_system = s3fs.S3FileSystem(anon=False)
     self.redshift_handler = RedshiftHandler()
    def __init__(self, bundle_uuid, bundle_version, event_type):
        """
        Logs the received parameters, stores them on the instance and creates a RedshiftHandler.

        :param bundle_uuid: UUID part of the bundle the notification refers to
        :param bundle_version: version part of the bundle the notification refers to
        :param event_type: notification event type, stored as given
        """
        logger.info(
            f"Running NotificationHandler with parameters: {bundle_uuid}, {bundle_version}, {event_type}"
        )
        self.bundle_uuid = bundle_uuid
        self.bundle_version = bundle_version
        self.event_type = event_type

        self.redshift = RedshiftHandler()
Beispiel #4
0
    def __init__(self, request_id: str, bundles_per_worker: int = 100):
        """
        Sets the logging correlation id and creates the trackers, handlers and configs used by this instance.

        :param request_id: ID of the request; also used as the logging correlation id
        :param bundles_per_worker: stored on the instance as given (defaults to 100)
        """
        # tag subsequent log records from this logger with the request id
        Logging.set_correlation_id(logger, value=request_id)

        self.request_id = request_id
        self.bundles_per_worker = bundles_per_worker
        self.request_tracker = RequestTracker(request_id)
        self.dynamo_handler = DynamoHandler()
        self.sqs_handler = SQSHandler()
        self.infra_config = MatrixInfraConfig()
        self.redshift_config = MatrixRedshiftConfig()
        # bucket names come from the environment; a missing variable raises KeyError here
        self.query_results_bucket = os.environ['MATRIX_QUERY_RESULTS_BUCKET']
        self.s3_handler = S3Handler(os.environ['MATRIX_QUERY_BUCKET'])
        self.redshift_handler = RedshiftHandler()
def load_tables(job_id: str, is_update: bool = False):
    """
    Creates tables and loads PSVs from S3 into Redshift via SQL COPY.

    All statements are collected into one list, starting with a LOCK on write_lock, and
    submitted through a single RedshiftHandler.transaction call. When is_update is True,
    each table is first copied into a `<table>_temp` TEMP table; rows of the real table
    whose primary key matches a temp row are deleted and the temp rows are then inserted.

    A psycopg2.Error raised by the transaction is logged and not re-raised.

    :param job_id: UUID of the S3 prefix containing the data to load in
    :param is_update: True if existing database rows will be updated, else False
    """
    _create_tables()
    lock_query = """LOCK TABLE write_lock"""
    delete_query_template = """DELETE FROM {0} USING {0}_temp WHERE {0}.{1} = {0}_temp.{1};"""
    insert_query_template = """INSERT INTO {0} SELECT * FROM {0}_temp;"""

    # Extra COPY options per table; tables not listed here are copied without options.
    copy_options = {
        TableName.FEATURE: " COMPUPDATE ON",
        TableName.CELL: " GZIP COMPUPDATE ON",
        TableName.EXPRESSION: " GZIP COMPUPDATE ON COMPROWS 10000000",
    }

    redshift = RedshiftHandler()

    # Loop-invariant settings, read once.
    preload_bucket = os.environ['MATRIX_PRELOAD_BUCKET']
    iam = os.environ['MATRIX_REDSHIFT_IAM_ROLE_ARN']

    transaction = [lock_query]
    for table in TableName:
        # The feature table is skipped on updates; write_lock is never loaded from S3.
        if table == TableName.WRITE_LOCK or (is_update and table == TableName.FEATURE):
            continue

        s3_prefix = f"s3://{preload_bucket}/{job_id}/{table.value}"
        if table == TableName.CELL:
            # cell data is addressed as a directory-style prefix
            s3_prefix += '/'

        # updates are staged in a temp table before being merged into the real one
        table_name = f"{table.value}_temp" if is_update else table.value
        copy_stmt = (f"COPY {table_name} FROM '{s3_prefix}' iam_role '{iam}'"
                     f"{copy_options.get(table, '')};")

        if is_update:
            logger.info(f"ETL: Building queries to update {table.value} table")
            transaction.extend([
                CREATE_QUERY_TEMPLATE[table.value].format("TEMP ", "_temp", table_name),
                copy_stmt,
                delete_query_template.format(table.value, RedshiftHandler.PRIMARY_KEY[table]),
                insert_query_template.format(table.value),
            ])
        else:
            logger.info(f"ETL: Building queries to load {table_name} table")
            transaction.append(copy_stmt)

    logger.info("ETL: Populating Redshift tables. Committing transaction.")
    try:
        redshift.transaction(transaction)
    except psycopg2.Error as e:
        # The exception must be bound to a %s placeholder: passing it as a bare extra
        # argument makes the logging module fail to format the record.
        logger.error("Failed to populate Redshift tables. Rolling back. %s", e)
def _verify_load(es_query):
    """
    Checks that the bundles matched by es_query in the DSS are present in the Redshift analysis table.

    Progress is printed to stdout; an AssertionError is raised when the number of analysis rows
    found differs from the number of bundles returned by the DSS search.

    :param es_query: Elasticsearch query passed to the DSS search endpoint
    """
    stage = os.environ['DEPLOYMENT_STAGE']
    dss_client = etl.get_dss_client(deployment_stage=stage)
    search_hits = dss_client.post_search.iterate(es_query=es_query, replica='aws', per_page=500)
    expected_bundles = [hit['bundle_fqid'] for hit in search_hits]
    expected_count = len(expected_bundles)

    print(f"Loading {expected_count} bundles to {stage} complete.\n"
          f"Verifying row counts in Redshift...")

    count_bundles_query = f"SELECT COUNT(*) FROM analysis WHERE bundle_fqid IN {format_str_list(expected_bundles)}"
    rows = RedshiftHandler().transaction(queries=[count_bundles_query], return_results=True)
    found_count = rows[0][0]

    print(f"Found {found_count} analysis rows for {expected_count} expected bundles.")
    assert found_count == expected_count
Beispiel #7
0
def _redshift_detail_lookup(name, description):
    """
    Builds the detail response for a metadata field by querying Redshift.

    For a "categorical" field the query rows are turned into a dict reported under "cell_counts",
    with the keys None, True and False replaced by the strings "", "True" and "False".
    For a "numeric" field the first two columns of the first row are reported as
    "minimum" and "maximum".

    :param name: metadata field name, used to look up its type, column and table
    :param description: field description, echoed back in the response
    :return: tuple of (response dict, requests.codes.ok), or None for any other field type
    """
    field_type = constants.METADATA_FIELD_TO_TYPE[name]
    column = constants.METADATA_FIELD_TO_TABLE_COLUMN[name]
    table = constants.TABLE_COLUMN_TO_TABLE[column]
    qualified_column = ".".join((table, column))

    primary_key = RedshiftHandler.PRIMARY_KEY[TableName(table)]
    handler = RedshiftHandler()

    if field_type not in ("categorical", "numeric"):
        return None

    detail_query = query_constructor.create_field_detail_query(
        qualified_column, table, primary_key, field_type)
    rows = handler.transaction([detail_query], return_results=True, read_only=True)

    response = {
        "field_name": name,
        "field_description": description,
        "field_type": field_type,
    }

    if field_type == "categorical":
        counts = dict(rows)
        # Replace non-string keys with their string labels, in this fixed order.
        for raw_key, label in ((None, ""), (True, "True"), (False, "False")):
            if raw_key in counts:
                counts[label] = counts.pop(raw_key)
        response["cell_counts"] = counts
    else:
        first_row = rows[0]
        response["minimum"] = first_row[0]
        response["maximum"] = first_row[1]

    return (response, requests.codes.ok)
#!/usr/bin/env python
"""
This script creates a readonly user in the redshift db
"""

from matrix.common.aws.redshift_handler import RedshiftHandler
from matrix.common.config import MatrixRedshiftConfig

# Module-level objects used by handler(): the Redshift config supplies the read-only
# credentials, and the Redshift handler runs the queries.
matrix_config = MatrixRedshiftConfig()
redshift_handler = RedshiftHandler()


def handler():
    """
    Recreates the read-only Redshift user and grants it SELECT on all tables of the
    public and information_schema schemas.

    Any exception raised while dropping/creating the user is printed and otherwise ignored;
    the grants are attempted regardless.
    """
    username = matrix_config.readonly_username
    password = matrix_config.readonly_password

    recreate_user_queries = [
        f"DROP USER IF EXISTS {username};",
        f"CREATE USER {username} WITH PASSWORD '{password}';",
    ]
    grant_queries = [
        f"grant select on all tables in schema {schema} to {username};"
        for schema in ("public", "information_schema")
    ]

    try:
        redshift_handler.transaction(recreate_user_queries)
    except Exception as err:
        # best effort: report the failure and still apply the grants below
        print(err)

    redshift_handler.transaction(grant_queries)
    print("permissions applied to readonly user")


# Run the handler only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    handler()
class TestMatrixServiceV0(MatrixServiceTest):
    """
    Tests that exercise the v0 matrix service API served from API_HOST.

    Each test posts a matrix request (or a DSS notification), polls until the expected
    status or database row counts are reached, and then validates the outcome.
    """

    def setUp(self):
        """
        Resolves the DSS environment and API URL from the process environment and creates
        the S3 filesystem client and Redshift handler used by the tests.
        """
        self.dss_env = MATRIX_ENV_TO_DSS_ENV[os.environ['DEPLOYMENT_STAGE']]
        self.api_url = f"https://{os.environ['API_HOST']}/v0"
        self.res_dir = os.path.join(
            os.path.dirname(os.path.realpath(__file__)), "res")
        self.headers = {
            'Content-type': 'application/json',
            'Accept': 'application/json'
        }
        self.verbose = True
        self.s3_file_system = s3fs.S3FileSystem(anon=False)
        self.redshift_handler = RedshiftHandler()

    def tearDown(self):
        """
        Cleans up the matrix result of the request recorded in self.request_id, if a test set one.
        """
        if hasattr(self, 'request_id') and self.request_id:
            self._cleanup_matrix_result(self.request_id)

    def test_single_bundle_request(self):
        """
        Requests a loom matrix for the first input bundle only and validates the result.
        """
        self.request_id = self._post_matrix_service_request(
            bundle_fqids=[INPUT_BUNDLE_IDS[self.dss_env][0]], format="loom")
        WaitFor(self._poll_get_matrix_service_request, self.request_id)\
            .to_return_value(MatrixRequestStatus.COMPLETE.value, timeout_seconds=1200)
        self._analyze_loom_matrix_results(self.request_id,
                                          [INPUT_BUNDLE_IDS[self.dss_env][0]])

    def test_loom_output_matrix_service(self):
        """
        Requests a loom matrix for all input bundles and validates the result.
        """
        self.request_id = self._post_matrix_service_request(
            bundle_fqids=INPUT_BUNDLE_IDS[self.dss_env], format="loom")
        # timeout seconds is increased to 1200 as batch may take time to spin up spot instances for conversion.
        WaitFor(self._poll_get_matrix_service_request, self.request_id)\
            .to_return_value(MatrixRequestStatus.COMPLETE.value, timeout_seconds=1200)
        self._analyze_loom_matrix_results(self.request_id,
                                          INPUT_BUNDLE_IDS[self.dss_env])

    def test_csv_output_matrix_service(self):
        """
        Requests a csv matrix for all input bundles and validates the result.
        """
        self.request_id = self._post_matrix_service_request(
            bundle_fqids=INPUT_BUNDLE_IDS[self.dss_env], format="csv")
        # timeout seconds is increased to 1200 as batch may take time to spin up spot instances for conversion.
        WaitFor(self._poll_get_matrix_service_request, self.request_id) \
            .to_return_value(MatrixRequestStatus.COMPLETE.value, timeout_seconds=1200)
        self._analyze_csv_matrix_results(self.request_id,
                                         INPUT_BUNDLE_IDS[self.dss_env])

    def test_mtx_output_matrix_service(self):
        """
        Requests an mtx matrix for all input bundles and validates the result.
        """
        self.request_id = self._post_matrix_service_request(
            bundle_fqids=INPUT_BUNDLE_IDS[self.dss_env], format="mtx")
        # timeout seconds is increased to 1200 as batch may take time to spin up spot instances for conversion.
        WaitFor(self._poll_get_matrix_service_request, self.request_id) \
            .to_return_value(MatrixRequestStatus.COMPLETE.value, timeout_seconds=1200)
        self._analyze_mtx_matrix_results(self.request_id,
                                         INPUT_BUNDLE_IDS[self.dss_env])

    def test_cache_hit_matrix_service(self):
        """
        Posts the same loom request twice and checks that both results validate and that
        their loom metrics match.

        Both results are cleaned up in a finally block so that a failing assertion does not
        leave them behind (tearDown does not cover them: self.request_id is not set here).
        """
        created_request_ids = []
        try:
            request_id_1 = self._post_matrix_service_request(
                bundle_fqids=INPUT_BUNDLE_IDS[self.dss_env], format="loom")
            created_request_ids.append(request_id_1)
            # timeout seconds is increased to 1200 as batch may take time to spin up spot instances for conversion.
            WaitFor(self._poll_get_matrix_service_request, request_id_1) \
                .to_return_value(MatrixRequestStatus.COMPLETE.value, timeout_seconds=1200)
            self._analyze_loom_matrix_results(request_id_1,
                                              INPUT_BUNDLE_IDS[self.dss_env])

            request_id_2 = self._post_matrix_service_request(
                bundle_fqids=INPUT_BUNDLE_IDS[self.dss_env], format="loom")
            created_request_ids.append(request_id_2)
            # timeout seconds is reduced to 300 as cache hits do not run conversion in batch.
            WaitFor(self._poll_get_matrix_service_request, request_id_2) \
                .to_return_value(MatrixRequestStatus.COMPLETE.value, timeout_seconds=300)
            self._analyze_loom_matrix_results(request_id_2,
                                              INPUT_BUNDLE_IDS[self.dss_env])

            matrix_location_1 = self._retrieve_matrix_location(request_id_1)
            matrix_location_2 = self._retrieve_matrix_location(request_id_2)

            loom_metrics_1 = validation.calculate_ss2_metrics_loom(
                matrix_location_1)
            loom_metrics_2 = validation.calculate_ss2_metrics_loom(
                matrix_location_2)

            self._compare_metrics(loom_metrics_1, loom_metrics_2)
        finally:
            # clean up every request that was actually created, in creation order
            for created_request_id in created_request_ids:
                self._cleanup_matrix_result(created_request_id)

    def test_matrix_service_without_specified_output(self):
        """
        Posts a request without a format and validates the result as a loom matrix.
        """
        self.request_id = self._post_matrix_service_request(
            bundle_fqids=INPUT_BUNDLE_IDS[self.dss_env])
        WaitFor(self._poll_get_matrix_service_request, self.request_id)\
            .to_return_value(MatrixRequestStatus.COMPLETE.value, timeout_seconds=300)
        self._analyze_loom_matrix_results(self.request_id,
                                          INPUT_BUNDLE_IDS[self.dss_env])

    def test_matrix_service_with_unexpected_bundles(self):
        """
        Posts a request for bundles that do not exist and expects the request to end up FAILED.
        """
        input_bundles = ['non-existent-bundle1', 'non-existent-bundle2']
        self.request_id = self._post_matrix_service_request(
            bundle_fqids=input_bundles)
        WaitFor(self._poll_get_matrix_service_request, self.request_id)\
            .to_return_value(MatrixRequestStatus.FAILED.value, timeout_seconds=300)

    @unittest.skipUnless(
        os.getenv('DEPLOYMENT_STAGE') != "prod",
        "Do not want to process fake notifications in production.")
    def test_dss_notification(self):
        """
        Posts DELETE, CREATE, TOMBSTONE and UPDATE notifications for a test bundle and checks
        the analysis/cell/expression row counts in Redshift after each one.

        A final CREATE notification is always posted, whether or not the checks passed.
        """
        bundle_data = NOTIFICATION_TEST_DATA[self.dss_env]
        bundle_fqid = bundle_data['bundle_fqid']
        cell_row_count = bundle_data['cell_count']
        expression_row_count = bundle_data['exp_count']

        cellkeys = format_str_list(self._get_cellkeys_from_fqid(bundle_fqid))
        # NOTE(review): cellkeys is the output of format_str_list, not the raw keys, so this
        # only checks that the formatted value is non-empty - confirm that is the intent.
        self.assertTrue(len(cellkeys) > 0)

        try:
            self._post_notification(bundle_fqid=bundle_fqid,
                                    event_type="DELETE")
            WaitFor(self._poll_db_get_row_counts_for_fqid, bundle_fqid, cellkeys)\
                .to_return_value((0, 0, 0), timeout_seconds=60)

            self._post_notification(bundle_fqid=bundle_fqid,
                                    event_type="CREATE")
            WaitFor(self._poll_db_get_row_counts_for_fqid, bundle_fqid, cellkeys)\
                .to_return_value((1, cell_row_count, expression_row_count), timeout_seconds=600)

            self._post_notification(bundle_fqid=bundle_fqid,
                                    event_type="TOMBSTONE")
            WaitFor(self._poll_db_get_row_counts_for_fqid, bundle_fqid, cellkeys)\
                .to_return_value((0, 0, 0), timeout_seconds=60)

            self._post_notification(bundle_fqid=bundle_fqid,
                                    event_type="UPDATE")
            WaitFor(self._poll_db_get_row_counts_for_fqid, bundle_fqid, cellkeys)\
                .to_return_value((1, cell_row_count, expression_row_count), timeout_seconds=600)
        finally:
            self._post_notification(bundle_fqid=bundle_fqid,
                                    event_type="CREATE")

    @unittest.skip
    def test_matrix_service_ss2(self):
        """
        Requests a loom matrix for up to MATRIX_TEST_NUM_BUNDLES bundles listed in a resource
        file and validates the result. Currently skipped.
        """
        timeout = int(os.getenv("MATRIX_TEST_TIMEOUT", 300))
        num_bundles = int(os.getenv("MATRIX_TEST_NUM_BUNDLES", 200))
        # context manager so the resource file is closed as soon as it has been parsed
        with open(f"{self.res_dir}/pancreas_ss2_2544_demo_bundles.json", "r") as bundles_file:
            bundle_fqids = json.load(bundles_file)[:num_bundles]

        self.request_id = self._post_matrix_service_request(
            bundle_fqids=bundle_fqids, format="loom")

        # wait for request to complete
        time.sleep(2)
        WaitFor(self._poll_get_matrix_service_request, self.request_id)\
            .to_return_value(MatrixRequestStatus.COMPLETE.value, timeout_seconds=timeout)

        self._analyze_loom_matrix_results(self.request_id, bundle_fqids)

    def test_bundle_url(self):
        """
        Requests a loom matrix by bundle manifest URL and validates the result against the
        bundle fqids read from that same URL.
        """
        timeout = int(os.getenv("MATRIX_TEST_TIMEOUT", 300))
        bundle_fqids_url = INPUT_BUNDLE_URL.format(dss_env=self.dss_env)

        self.request_id = self._post_matrix_service_request(
            bundle_fqids_url=bundle_fqids_url, format="loom")

        # wait for request to complete
        WaitFor(self._poll_get_matrix_service_request, self.request_id)\
            .to_return_value(MatrixRequestStatus.COMPLETE.value, timeout_seconds=timeout)
        # skip the first (header) line; each remaining line is tab-separated and joined with '.'
        bundle_fqids = [
            '.'.join(el.split('\t')) for el in requests.get(
                bundle_fqids_url).text.strip().split('\n')[1:]
        ]

        self._analyze_loom_matrix_results(self.request_id, bundle_fqids)

    def _poll_db_get_row_counts_for_fqid(self, bundle_fqid, cellkeys):
        """
        :param bundle_fqid: bundle fqid used to count analysis rows
        :param cellkeys: pre-formatted SQL list of cellkeys used to count cell and expression rows
        :return: Row counts associated with the given bundle fqid for (analysis_count, cell_count, exp_count)
        """
        analysis_query = f"SELECT COUNT(*) FROM analysis WHERE analysis.bundle_fqid = '{bundle_fqid}'"
        cell_query = f"SELECT COUNT(*) FROM cell WHERE cellkey IN {cellkeys}"
        exp_query = f"SELECT COUNT(*) FROM expression WHERE cellkey IN {cellkeys}"

        analysis_row_count = self.redshift_handler.transaction(
            [analysis_query], return_results=True)[0][0]
        cell_row_count = self.redshift_handler.transaction(
            [cell_query], return_results=True)[0][0]
        exp_row_count = self.redshift_handler.transaction(
            [exp_query], return_results=True)[0][0]

        return analysis_row_count, cell_row_count, exp_row_count

    def _get_cellkeys_from_fqid(self, bundle_fqid):
        """
        Returns a generator of cellkeys associated with a given bundle fqid.
        """
        cellkeys_query = f"""
        SELECT DISTINCT cellkey FROM cell
         JOIN analysis ON cell.analysiskey = analysis.analysiskey
         WHERE analysis.bundle_fqid = '{bundle_fqid}'
        """

        results = self.redshift_handler.transaction([cellkeys_query],
                                                    return_results=True)
        return (row[0] for row in results)

    def _post_matrix_service_request(self,
                                     bundle_fqids=None,
                                     bundle_fqids_url=None,
                                     format=None):
        """
        Posts a matrix request and returns the request id from the response.

        Only the arguments that are truthy are included in the request body.

        :param bundle_fqids: list of bundle fqids to request
        :param bundle_fqids_url: URL of a bundle manifest to request
        :param format: output format of the matrix
        :return: value of "request_id" in the JSON response
        """
        data = {}
        if bundle_fqids:
            data["bundle_fqids"] = bundle_fqids
        if bundle_fqids_url:
            data["bundle_fqids_url"] = bundle_fqids_url
        if format:
            data["format"] = format
        response = self._make_request(
            description="POST REQUEST TO MATRIX SERVICE",
            verb='POST',
            url=f"{self.api_url}/matrix",
            expected_status=202,
            data=json.dumps(data),
            headers=self.headers)
        data = json.loads(response)
        return data["request_id"]

    def _post_notification(self, bundle_fqid, event_type):
        """
        Posts a DSS notification for the given bundle and returns the decoded JSON response.

        :param bundle_fqid: bundle fqid of the form "<uuid>.<version>"
        :param event_type: event type placed in the notification body
        :return: decoded JSON response body
        """
        # split once on the first '.'; an fqid without a '.' raises IndexError below
        fqid_parts = bundle_fqid.split('.', 1)
        bundle_uuid = fqid_parts[0]
        bundle_version = fqid_parts[1]

        data = {
            "transaction_id": "test_transaction_id",
            "subscription_id": "test_subscription_id",
            "event_type": event_type,
            "match": {
                "bundle_uuid": bundle_uuid,
                "bundle_version": bundle_version,
            },
        }

        response = self._make_request(
            description="POST NOTIFICATION TO MATRIX SERVICE",
            verb='POST',
            url=f"{self.api_url}/dss/notification",
            expected_status=200,
            data=json.dumps(data),
            headers=self.headers)
        data = json.loads(response)
        return data
class NotificationHandler:
    """
    Applies a single bundle notification: CREATE/UPDATE events run the bundle ETL,
    DELETE/TOMBSTONE events delete the bundle's rows from Redshift.
    """

    # NOTE(review): the bundle uuid/version are interpolated directly into these SQL templates.
    DELETE_ANALYSIS_QUERY_TEMPLATE = """DELETE FROM analysis WHERE bundle_fqid='{bundle_uuid}.{bundle_version}'"""
    DELETE_CELL_QUERY_TEMPLATE = """
                                 DELETE FROM cell
                                  WHERE cell.analysiskey IN
                                  (SELECT analysis.analysiskey FROM analysis
                                  WHERE analysis.bundle_fqid='{bundle_uuid}.{bundle_version}')
                                 """
    DELETE_EXPRESSION_QUERY_TEMPLATE = """
                                       DELETE FROM expression
                                        WHERE expression.cellkey IN
                                        (SELECT cell.cellkey FROM cell
                                        JOIN analysis ON cell.analysiskey = analysis.analysiskey
                                        WHERE analysis.bundle_fqid='{bundle_uuid}.{bundle_version}')
                                       """

    def __init__(self, bundle_uuid, bundle_version, event_type):
        """
        :param bundle_uuid: UUID part of the bundle fqid
        :param bundle_version: version part of the bundle fqid
        :param event_type: one of CREATE, UPDATE, DELETE or TOMBSTONE; anything else is rejected by run()
        """
        logger.info(
            f"Running NotificationHandler with parameters: {bundle_uuid}, {bundle_version}, {event_type}"
        )
        self.bundle_uuid = bundle_uuid
        self.bundle_version = bundle_version
        self.event_type = event_type

        self.redshift = RedshiftHandler()

    def run(self):
        """
        Dispatches on the event type; an unknown event type is logged as an error and ignored.
        """
        if self.event_type in ('CREATE', 'UPDATE'):
            self.update_bundle()
        elif self.event_type in ('DELETE', 'TOMBSTONE'):
            self.remove_bundle()
        else:
            logger.error(
                f"Failed to process notification. Received invalid event type {self.event_type}."
            )
            return

        logger.info(
            f"Done processing DSS notification for {self.bundle_uuid}.{self.bundle_version}."
        )

    def update_bundle(self):
        """
        Clears the previous transformer output under /tmp and runs the bundle ETL for this bundle.
        """
        staging_dir = "/tmp"
        metadata_content_types = ['application/json; dcp-type="metadata*"']
        data_filename_patterns = [
            # match expression data
            "*zarr*",
            # match SS2 results files
            "*.results",
            # match 10X results files
            "*.mtx",
            "genes.tsv",
            "barcodes.tsv",
            "empty_drops_result.csv",
        ]

        # clean up working directory in case of Lambda container reuse
        stale_output_dir = f"{staging_dir}/{MetadataToPsvTransformer.OUTPUT_DIRNAME}"
        shutil.rmtree(stale_output_dir, ignore_errors=True)

        etl.etl_dss_bundle(
            bundle_uuid=self.bundle_uuid,
            bundle_version=self.bundle_version,
            content_type_patterns=metadata_content_types,
            filename_patterns=data_filename_patterns,
            transformer_cb=etl.transform_bundle,
            finalizer_cb=etl.finalizer_notification,
            staging_directory=os.path.abspath(staging_dir),
            deployment_stage=os.environ['DEPLOYMENT_STAGE'],
        )

    def remove_bundle(self):
        """
        Deletes the bundle's expression, cell and analysis rows in one Redshift transaction.
        """
        fqid_fields = {
            "bundle_uuid": self.bundle_uuid,
            "bundle_version": self.bundle_version,
        }
        # Order matters: the expression and cell deletes look rows up through the
        # cell/analysis tables, so those are removed last.
        ordered_templates = (
            self.DELETE_EXPRESSION_QUERY_TEMPLATE,
            self.DELETE_CELL_QUERY_TEMPLATE,
            self.DELETE_ANALYSIS_QUERY_TEMPLATE,
        )
        self.redshift.transaction(
            [template.format(**fqid_fields) for template in ordered_templates])
Beispiel #11
0
class Driver:
    """
    Formats and stores redshift queries in s3 and sqs for execution.
    """
    def __init__(self, request_id: str, bundles_per_worker: int = 100):
        """
        :param request_id: ID of the matrix request; also used as the logging correlation id
        :param bundles_per_worker: stored on the instance and included in debug logging
        """
        Logging.set_correlation_id(logger, value=request_id)

        self.request_id = request_id
        self.bundles_per_worker = bundles_per_worker
        self.request_tracker = RequestTracker(request_id)
        self.dynamo_handler = DynamoHandler()
        self.sqs_handler = SQSHandler()
        self.infra_config = MatrixInfraConfig()
        self.redshift_config = MatrixRedshiftConfig()
        self.query_results_bucket = os.environ['MATRIX_QUERY_RESULTS_BUCKET']
        self.s3_handler = S3Handler(os.environ['MATRIX_QUERY_BUCKET'])
        self.redshift_handler = RedshiftHandler()

    @property
    def query_job_q_url(self):
        """URL of the query job queue, read from the infra config."""
        return self.infra_config.query_job_q_url

    @property
    def redshift_role_arn(self):
        """Redshift role ARN, read from the redshift config."""
        return self.redshift_config.redshift_role_arn

    def run(self, bundle_fqids: typing.List[str], bundle_fqids_url: str,
            format: str):
        """
        Initialize a matrix service request and spawn redshift queries.

        When bundle_fqids_url is given it takes precedence over bundle_fqids. The request is
        marked as errored (and nothing is queued) when the manifest yields no bundles or when
        the analysis table does not contain every resolved bundle.

        :param bundle_fqids: List of bundle fqids to be queried on
        :param bundle_fqids_url: URL from which bundle_fqids can be retrieved
        :param format: MatrixFormat file format of output expression matrix
        """
        logger.debug(
            f"Driver running with parameters: bundle_fqids={bundle_fqids}, "
            f"bundle_fqids_url={bundle_fqids_url}, format={format}, "
            f"bundles_per_worker={self.bundles_per_worker}")

        if bundle_fqids_url:
            response = self._get_bundle_manifest(bundle_fqids_url)
            resolved_bundle_fqids = self._parse_download_manifest(
                response.text)
            if not resolved_bundle_fqids:
                error_msg = "no bundles found in the supplied bundle manifest"
                logger.info(error_msg)
                self.request_tracker.log_error(error_msg)
                return
        else:
            resolved_bundle_fqids = bundle_fqids
        logger.debug(f"resolved bundles: {resolved_bundle_fqids}")

        self.dynamo_handler.set_table_field_with_value(
            DynamoTable.REQUEST_TABLE, self.request_id,
            RequestTableField.NUM_BUNDLES, len(resolved_bundle_fqids))
        s3_obj_keys = self._format_and_store_queries_in_s3(
            resolved_bundle_fqids)

        # every resolved bundle must be present in the analysis table before queries are queued
        analysis_table_bundle_count = self._fetch_bundle_count_from_analysis_table(
            resolved_bundle_fqids)
        if analysis_table_bundle_count != len(resolved_bundle_fqids):
            error_msg = "resolved bundles in request do not match bundles available in matrix service"
            logger.info(error_msg)
            self.request_tracker.log_error(error_msg)
            return

        for query_type, s3_obj_key in s3_obj_keys.items():
            self._add_request_query_to_sqs(query_type, s3_obj_key)
        self.request_tracker.complete_subtask_execution(Subtask.DRIVER)

    @retry(reraise=True, wait=wait_fixed(5), stop=stop_after_attempt(60))
    def _get_bundle_manifest(self, bundle_fqids_url):
        """
        Fetches the bundle manifest with an HTTP GET.

        The retry decorator is configured with a fixed 5 second wait, a stop after 60
        attempts, and reraise=True.

        :param bundle_fqids_url: URL of the manifest
        :return: the response object returned by requests.get
        """
        response = requests.get(bundle_fqids_url)
        return response

    @staticmethod
    def _parse_download_manifest(data: str) -> typing.List[str]:
        """
        Parses a tab-separated manifest into bundle fqids.

        The first line is treated as a header and skipped. Blank lines are ignored.
        For every other line the first two tab-separated tokens are joined as
        "<token0>.<token1>".

        :param data: manifest text
        :return: list of bundle fqids in manifest order
        :raises IndexError: if a non-blank line has fewer than two tab-separated tokens
        """
        bundle_fqids = []
        for line in data.splitlines()[1:]:
            # blank lines (e.g. trailing empty lines) carry no bundle and would otherwise crash
            if not line.strip():
                continue
            tokens = line.split("\t")
            bundle_fqids.append(f"{tokens[0]}.{tokens[1]}")
        return bundle_fqids

    def _format_and_store_queries_in_s3(self, resolved_bundle_fqids: list):
        """
        Renders the feature, expression and cell queries for this request and stores each in S3.

        :param resolved_bundle_fqids: bundle fqids the expression and cell queries are restricted to
        :return: dict mapping QueryType to the S3 object key of the stored query
        """
        # the same formatted bundle list is used by the expression and the cell query
        formatted_bundle_fqids = format_str_list(resolved_bundle_fqids)

        feature_query = feature_query_template.format(
            self.query_results_bucket, self.request_id, self.redshift_role_arn)
        feature_query_obj_key = self.s3_handler.store_content_in_s3(
            f"{self.request_id}/feature", feature_query)

        exp_query = expression_query_template.format(
            self.query_results_bucket, self.request_id, self.redshift_role_arn,
            formatted_bundle_fqids)
        exp_query_obj_key = self.s3_handler.store_content_in_s3(
            f"{self.request_id}/expression", exp_query)

        cell_query = cell_query_template.format(
            self.query_results_bucket, self.request_id, self.redshift_role_arn,
            formatted_bundle_fqids)
        cell_query_obj_key = self.s3_handler.store_content_in_s3(
            f"{self.request_id}/cell", cell_query)

        return {
            QueryType.CELL: cell_query_obj_key,
            QueryType.EXPRESSION: exp_query_obj_key,
            QueryType.FEATURE: feature_query_obj_key
        }

    def _add_request_query_to_sqs(self, query_type: QueryType,
                                  s3_obj_key: str):
        """
        Queues one stored query for execution.

        :param query_type: type of the query; its value is sent in the message
        :param s3_obj_key: S3 object key under which the query text is stored
        """
        queue_url = self.query_job_q_url
        payload = {
            'request_id': self.request_id,
            's3_obj_key': s3_obj_key,
            'type': query_type.value
        }
        logger.debug(f"Adding {payload} to sqs {queue_url}")
        self.sqs_handler.add_message_to_queue(queue_url, payload)

    def _fetch_bundle_count_from_analysis_table(self,
                                                resolved_bundle_fqids: list):
        """
        Runs the analysis bundle count query for the given bundles in a read-only transaction.

        :param resolved_bundle_fqids: bundle fqids to count
        :return: first column of the first result row
        """
        count_query = analysis_bundle_count_query_template.format(
            format_str_list(resolved_bundle_fqids))
        # collapse the rendered template onto a single line before sending it
        count_query = count_query.strip().replace('\n', '')
        results = self.redshift_handler.transaction(
            [count_query],
            read_only=True,
            return_results=True)
        return results[0][0]
 def __init__(self):
     """
     Creates the SQS, S3, Batch and Redshift handlers and the infra config used by this instance.
     """
     self.sqs_handler = SQSHandler()
     # the S3 handler is bound to the bucket named by MATRIX_QUERY_BUCKET
     self.s3_handler = S3Handler(os.environ["MATRIX_QUERY_BUCKET"])
     self.batch_handler = BatchHandler()
     self.redshift_handler = RedshiftHandler()
     self.matrix_infra_config = MatrixInfraConfig()
class QueryRunner:
    """
    Polls the query job queue and runs each received query against Redshift.
    """

    def __init__(self):
        """
        Creates the SQS, S3, Batch and Redshift handlers and loads the infra config.
        """
        self.sqs_handler = SQSHandler()
        self.s3_handler = S3Handler(os.environ["MATRIX_QUERY_BUCKET"])
        self.batch_handler = BatchHandler()
        self.redshift_handler = RedshiftHandler()
        self.matrix_infra_config = MatrixInfraConfig()

    @property
    def query_job_q_url(self):
        """URL of the query job queue, read from the infra config."""
        return self.matrix_infra_config.query_job_q_url

    @property
    def query_job_deadletter_q_url(self):
        """URL of the query job dead-letter queue, read from the infra config."""
        return self.matrix_infra_config.query_job_deadletter_q_url

    def run(self, max_loops=None):
        """
        Repeatedly receives messages from the query job queue and processes the first one of each batch.

        :param max_loops: number of polling iterations to perform; None polls forever
        """
        iterations = 0
        while max_loops is None or iterations < max_loops:
            iterations += 1
            messages = self.sqs_handler.receive_messages_from_queue(
                self.query_job_q_url)
            if not messages:
                logger.info(f"No messages to read from {self.query_job_q_url}")
                continue
            # only the first message of the batch is handled per iteration
            self._process_message(messages[0])

    def _process_message(self, message):
        """
        Runs the query referenced by one queue message and updates the request state.

        On success the message is deleted, the QUERY subtask is marked complete and, once the
        request is ready, a batch conversion job is scheduled. On any exception the error is
        logged on the request, the payload is forwarded to the dead-letter queue and the
        message is deleted from the job queue.

        :param message: queue message with a JSON 'Body' and a 'ReceiptHandle'
        """
        logger.info(f"Received {message} from {self.query_job_q_url}")
        payload = json.loads(message['Body'])
        request_id = payload['request_id']
        request_tracker = RequestTracker(request_id)
        Logging.set_correlation_id(logger, value=request_id)
        obj_key = payload['s3_obj_key']
        receipt_handle = message['ReceiptHandle']

        try:
            logger.info(f"Fetching query from {obj_key}")
            query = self.s3_handler.load_content_from_obj_key(obj_key)

            logger.info(f"Running query from {obj_key}")
            self.redshift_handler.transaction([query], read_only=True)
            logger.info(f"Finished running query from {obj_key}")

            logger.info(f"Deleting {message} from {self.query_job_q_url}")
            self.sqs_handler.delete_message_from_queue(
                self.query_job_q_url, receipt_handle)

            logger.info("Incrementing completed queries in state table")
            request_tracker.complete_subtask_execution(Subtask.QUERY)

            if request_tracker.is_request_ready_for_conversion():
                logger.info("Scheduling batch conversion job")
                batch_job_id = self.batch_handler.schedule_matrix_conversion(
                    request_id, request_tracker.format)
                request_tracker.write_batch_job_id_to_db(batch_job_id)
        except Exception as e:
            logger.info(f"QueryRunner failed on {message} with error {e}")
            request_tracker.log_error(str(e))

            # park the payload on the dead-letter queue, then remove it from the job queue
            logger.info(f"Adding {message} to {self.query_job_deadletter_q_url}")
            self.sqs_handler.add_message_to_queue(
                self.query_job_deadletter_q_url, payload)
            logger.info(f"Deleting {message} from {self.query_job_q_url}")
            self.sqs_handler.delete_message_from_queue(
                self.query_job_q_url, receipt_handle)