Example #1
class AchillesTest(unittest.TestCase):
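    """Integration tests for Achilles: creating the Achilles tables, loading
    the analysis lookup, and running the analyses against a fake HPO dataset
    in BigQuery."""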
    @classmethod
    def setUpClass(cls):
        print('**************************************************************')
        print(cls.__name__)
        print('**************************************************************')

    def setUp(self):
        self.hpo_bucket = gcs_utils.get_hpo_bucket(test_util.FAKE_HPO_ID)
        self.project_id = app_identity.get_application_id()
        self.storage_client = StorageClient(self.project_id)
        self.storage_client.empty_bucket(self.hpo_bucket)
        test_util.delete_all_tables(bq_utils.get_dataset_id())

    def tearDown(self):
        test_util.delete_all_tables(bq_utils.get_dataset_id())
        self.storage_client.empty_bucket(self.hpo_bucket)

    def _load_dataset(self):
        for cdm_table in resources.CDM_TABLES:
            cdm_filename: str = f'{cdm_table}.csv'
            cdm_filepath: str = os.path.join(test_util.FIVE_PERSONS_PATH,
                                             cdm_filename)

            bucket = self.storage_client.get_bucket(self.hpo_bucket)
            cdm_blob = bucket.blob(cdm_filename)
            if os.path.exists(cdm_filepath):
                cdm_blob.upload_from_filename(cdm_filepath)
            else:
                cdm_blob.upload_from_string('dummy\n')

            bq_utils.load_cdm_csv(test_util.FAKE_HPO_ID, cdm_table)

    def test_load_analyses(self):
        achilles.create_tables(test_util.FAKE_HPO_ID, True)
        achilles.load_analyses(test_util.FAKE_HPO_ID)
        cmd = sql_wrangle.qualify_tables(
            'SELECT DISTINCT(analysis_id) FROM %sachilles_analysis' %
            sql_wrangle.PREFIX_PLACEHOLDER, test_util.FAKE_HPO_ID)
        result = bq_utils.query(cmd)
        self.assertEqual(ACHILLES_LOOKUP_COUNT, int(result['totalRows']))

    def test_run_analyses(self):
        # Long-running test
        self._load_dataset()
        achilles.create_tables(test_util.FAKE_HPO_ID, True)
        achilles.load_analyses(test_util.FAKE_HPO_ID)
        achilles.run_analyses(hpo_id=test_util.FAKE_HPO_ID)
        cmd = sql_wrangle.qualify_tables(
            'SELECT COUNT(1) FROM %sachilles_results' %
            sql_wrangle.PREFIX_PLACEHOLDER, test_util.FAKE_HPO_ID)
        result = bq_utils.query(cmd)
        self.assertEqual(int(result['rows'][0]['f'][0]['v']),
                         ACHILLES_RESULTS_COUNT)
Example #2
class GcsUtilsTest(unittest.TestCase):
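    """Tests for gcs_utils helpers: uploading objects, reading them back,
    fetching object metadata, and 404 handling for a nonexistent bucket."""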
    @classmethod
    def setUpClass(cls):
        print('**************************************************************')
        print(cls.__name__)
        print('**************************************************************')

    def setUp(self):
        self.hpo_bucket = gcs_utils.get_hpo_bucket(FAKE_HPO_ID)
        self.gcs_path = '/'.join([self.hpo_bucket, 'dummy'])
        self.project_id = app_identity.get_application_id()
        self.storage_client = StorageClient(self.project_id)
        self.storage_client.empty_bucket(self.hpo_bucket)

    def test_upload_object(self):
        bucket_items = gcs_utils.list_bucket(self.hpo_bucket)
        self.assertEqual(len(bucket_items), 0)
        with open(FIVE_PERSONS_PERSON_CSV, 'rb') as fp:
            gcs_utils.upload_object(self.hpo_bucket, 'person.csv', fp)
        bucket_items = gcs_utils.list_bucket(self.hpo_bucket)
        self.assertEqual(len(bucket_items), 1)
        bucket_item = bucket_items[0]
        self.assertEqual(bucket_item['name'], 'person.csv')

    def test_get_object(self):
        with open(FIVE_PERSONS_PERSON_CSV, 'r') as fp:
            expected = fp.read()
        with open(FIVE_PERSONS_PERSON_CSV, 'rb') as fp:
            gcs_utils.upload_object(self.hpo_bucket, 'person.csv', fp)
        result = gcs_utils.get_object(self.hpo_bucket, 'person.csv')
        self.assertEqual(expected, result)

    def test_get_metadata_on_existing_file(self):
        expected_file_name = 'person.csv'
        with open(FIVE_PERSONS_PERSON_CSV, 'rb') as fp:
            gcs_utils.upload_object(self.hpo_bucket, expected_file_name, fp)
        metadata = gcs_utils.get_metadata(self.hpo_bucket, expected_file_name)
        self.assertIsNotNone(metadata)
        self.assertEqual(metadata['name'], expected_file_name)

    def test_get_metadata_on_not_existing_file(self):
        expected = 100
        actual = gcs_utils.get_metadata(self.hpo_bucket,
                                        'this_file_does_not_exist', expected)
        self.assertEqual(expected, actual)

    def test_list_bucket_404_when_bucket_does_not_exist(self):
        with self.assertRaises(HttpError) as cm:
            gcs_utils.list_bucket('some-bucket-which-does-not-exist-123')
        self.assertEqual(cm.exception.resp.status, 404)

    def tearDown(self):
        self.storage_client.empty_bucket(self.hpo_bucket)
Example #3
class BqUtilsTest(unittest.TestCase):
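    """Tests for bq_utils: CSV loads into BigQuery, table creation with
    clustering/partitioning checks, query helpers, and conversion of CSV
    lines to SQL row expressions."""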
    @classmethod
    def setUpClass(cls):
        print('**************************************************************')
        print(cls.__name__)
        print('**************************************************************')

    def setUp(self):
        self.hpo_bucket = gcs_utils.get_hpo_bucket(FAKE_HPO_ID)
        self.person_table_id = bq_utils.get_table_id(FAKE_HPO_ID,
                                                     common.PERSON)
        self.dataset_id = bq_utils.get_dataset_id()
        test_util.delete_all_tables(self.dataset_id)
        self.project_id = app_identity.get_application_id()
        self.TEST_FIELDS = [
            {
                "type": "integer",
                "name": "integer_field",
                "mode": "required",
                "description": "An integer field"
            },
            # DC-586 Import RDR rules should support null fields
            {
                "type": "integer",
                "name": "nullable_integer_field",
                "mode": "nullable",
                "description": "A nullable integer field"
            },
            {
                "type": "string",
                "name": "string_field",
                "mode": "required",
                "description": "A string field"
            },
            {
                "type": "date",
                "name": "date_field",
                "mode": "required",
                "description": "A date field"
            },
            {
                "type": "timestamp",
                "name": "timestamp_field",
                "mode": "required",
                "description": "A timestamp field"
            },
            {
                "type": "boolean",
                "name": "boolean_field",
                "mode": "required",
                "description": "A boolean field"
            },
            {
                "type": "float",
                "name": "float_field",
                "mode": "required",
                "description": "A float field"
            }
        ]
        self.DT_FORMAT = '%Y-%m-%d %H:%M:%S'
        self.client = StorageClient(self.project_id)
        self.client.empty_bucket(self.hpo_bucket)

    def _drop_tables(self):
        tables = bq_utils.list_tables()
        for table in tables:
            table_id = table['tableReference']['tableId']
            if table_id not in common.VOCABULARY_TABLES:
                bq_utils.delete_table(table_id)

    def _table_has_clustering(self, table_info):
        clustering = table_info.get('clustering')
        self.assertIsNotNone(clustering)
        fields = clustering.get('fields')
        self.assertSetEqual(set(fields), {'person_id'})
        time_partitioning = table_info.get('timePartitioning')
        self.assertIsNotNone(time_partitioning)
        tpe = time_partitioning.get('type')
        self.assertEqual(tpe, 'DAY')

    def test_load_csv(self):
        app_id = app_identity.get_application_id()
        table_name = 'achilles_analysis'
        csv_file_name = table_name + '.csv'
        local_csv_path = os.path.join(test_util.TEST_DATA_EXPORT_PATH,
                                      csv_file_name)
        sc_bucket = self.client.get_bucket(self.hpo_bucket)
        bucket_blob = sc_bucket.blob(csv_file_name)
        with open(local_csv_path, 'rb') as fp:
            bucket_blob.upload_from_file(fp)
        hpo_bucket = self.hpo_bucket
        gcs_object_path = 'gs://%(hpo_bucket)s/%(csv_file_name)s' % locals()
        dataset_id = self.dataset_id
        load_results = bq_utils.load_csv(table_name, gcs_object_path, app_id,
                                         dataset_id, table_name)

        load_job_id = load_results['jobReference']['jobId']
        incomplete_jobs = bq_utils.wait_on_jobs([load_job_id])
        self.assertEqual(len(incomplete_jobs), 0,
                         'loading table {} timed out'.format(table_name))
        query_response = bq_utils.query('SELECT COUNT(1) FROM %(table_name)s' %
                                        locals())
        self.assertEqual(query_response['kind'], 'bigquery#queryResponse')

    def test_load_cdm_csv(self):
        sc_bucket = self.client.get_bucket(self.hpo_bucket)
        bucket_blob = sc_bucket.blob('person.csv')
        with open(FIVE_PERSONS_PERSON_CSV, 'rb') as fp:
            bucket_blob.upload_from_file(fp)
        result = bq_utils.load_cdm_csv(FAKE_HPO_ID, common.PERSON)
        self.assertEqual(result['status']['state'], 'RUNNING')
        load_job_id = result['jobReference']['jobId']
        table_id = result['configuration']['load']['destinationTable'][
            'tableId']
        incomplete_jobs = bq_utils.wait_on_jobs([load_job_id])
        self.assertEqual(len(incomplete_jobs), 0,
                         'loading table {} timed out'.format(table_id))
        table_info = bq_utils.get_table_info(table_id)
        num_rows = table_info.get('numRows')
        self.assertEqual(num_rows, '5')

    def test_query_result(self):
        sc_bucket = self.client.get_bucket(self.hpo_bucket)
        bucket_blob = sc_bucket.blob('person.csv')
        with open(FIVE_PERSONS_PERSON_CSV, 'rb') as fp:
            bucket_blob.upload_from_file(fp)
        result = bq_utils.load_cdm_csv(FAKE_HPO_ID, common.PERSON)
        load_job_id = result['jobReference']['jobId']
        incomplete_jobs = bq_utils.wait_on_jobs([load_job_id])
        self.assertEqual(len(incomplete_jobs), 0,
                         'loading table {} timed out'.format(common.PERSON))
        table_id = bq_utils.get_table_id(FAKE_HPO_ID, common.PERSON)
        q = 'SELECT person_id FROM %s' % table_id
        result = bq_utils.query(q)
        self.assertEqual(5, int(result['totalRows']))

    def test_create_table(self):
        table_id = 'some_random_table_id'
        fields = [
            dict(name='person_id', type='integer', mode='required'),
            dict(name='name', type='string', mode='nullable')
        ]
        result = bq_utils.create_table(table_id, fields)
        self.assertTrue('kind' in result)
        self.assertEqual(result['kind'], 'bigquery#table')
        table_info = bq_utils.get_table_info(table_id)
        self._table_has_clustering(table_info)

    def test_create_existing_table_without_drop_raises_error(self):
        table_id = 'some_random_table_id'
        fields = [
            dict(name='id', type='integer', mode='required'),
            dict(name='name', type='string', mode='nullable')
        ]
        bq_utils.create_table(table_id, fields)
        with self.assertRaises(bq_utils.InvalidOperationError):
            bq_utils.create_table(table_id, fields, drop_existing=False)

    def test_create_table_drop_existing_success(self):
        table_id = 'some_random_table_id'
        fields = [
            dict(name='id', type='integer', mode='required'),
            dict(name='name', type='string', mode='nullable')
        ]
        result_1 = bq_utils.create_table(table_id, fields)
        # sanity check
        table_id = result_1['tableReference']['tableId']
        self.assertTrue(bq_utils.table_exists(table_id))
        result_2 = bq_utils.create_table(table_id, fields, drop_existing=True)
        # same id and second one created after first one
        self.assertEqual(result_1['id'], result_2['id'])
        self.assertTrue(result_2['creationTime'] > result_1['creationTime'])

    def test_create_standard_table(self):
        standard_tables = list(resources.CDM_TABLES) + ACHILLES_TABLES
        for standard_table in standard_tables:
            table_id = f'prefix_for_test_{standard_table}'
            result = bq_utils.create_standard_table(standard_table, table_id)
            self.assertTrue('kind' in result)
            self.assertEqual(result['kind'], 'bigquery#table')
            # sanity check
            self.assertTrue(bq_utils.table_exists(table_id))

    def test_load_ehr_observation(self):
        hpo_id = 'pitt'
        dataset_id = self.dataset_id
        table_id = bq_utils.get_table_id(hpo_id, table_name='observation')
        q = 'SELECT observation_id FROM {dataset_id}.{table_id} ORDER BY observation_id'.format(
            dataset_id=dataset_id, table_id=table_id)
        expected_observation_ids = [
            int(row['observation_id'])
            for row in resources.csv_to_list(PITT_FIVE_PERSONS_OBSERVATION_CSV)
        ]
        sc_bucket = self.client.get_bucket(gcs_utils.get_hpo_bucket(hpo_id))
        bucket_blob = sc_bucket.blob('observation.csv')
        with open(PITT_FIVE_PERSONS_OBSERVATION_CSV, 'rb') as fp:
            bucket_blob.upload_from_file(fp)
        result = bq_utils.load_cdm_csv(hpo_id, 'observation')
        job_id = result['jobReference']['jobId']
        incomplete_jobs = bq_utils.wait_on_jobs([job_id])
        self.assertEqual(len(incomplete_jobs), 0,
                         'pitt_observation load job did not complete')
        load_job_result = bq_utils.get_job_details(job_id)
        load_job_result_status = load_job_result['status']
        load_job_errors = load_job_result_status.get('errors')
        self.assertIsNone(load_job_errors,
                          msg='pitt_observation load job failed: ' +
                          str(load_job_errors))
        query_results_response = bq_utils.query(q)
        query_job_errors = query_results_response.get('errors')
        self.assertIsNone(query_job_errors)
        actual_result = [
            int(row['f'][0]['v']) for row in query_results_response['rows']
        ]
        self.assertCountEqual(actual_result, expected_observation_ids)

    def test_load_table_from_csv(self):
        table_id = 'test_csv_table'
        csv_file = 'load_csv_test_data.csv'
        csv_path = os.path.join(test_util.TEST_DATA_PATH, csv_file)
        with open(csv_path, 'r') as f:
            expected = list(csv.DictReader(f))
        bq_utils.load_table_from_csv(self.project_id, self.dataset_id,
                                     table_id, csv_path, self.TEST_FIELDS)
        q = """ SELECT *
                FROM `{project_id}.{dataset_id}.{table_id}`""".format(
            project_id=self.project_id,
            dataset_id=self.dataset_id,
            table_id=table_id)
        r = bq_utils.query(q)
        actual = bq_utils.response2rows(r)

        # Convert the epoch times to datetime with time zone
        for row in actual:
            row['timestamp_field'] = time.strftime(
                self.DT_FORMAT + ' UTC', time.gmtime(row['timestamp_field']))
        expected.sort(key=lambda row: row['integer_field'])
        actual.sort(key=lambda row: row['integer_field'])
        for i, _ in enumerate(expected):
            self.assertCountEqual(expected[i], actual[i])

    def test_get_hpo_info(self):
        hpo_info = bq_utils.get_hpo_info()
        self.assertGreater(len(hpo_info), 0)

    def test_csv_line_to_sql_row_expr(self):
        fields = [{
            'name': 'nullable_date_col',
            'type': 'date',
            'mode': 'nullable',
            'description': ''
        }, {
            'name': 'nullable_float_col',
            'type': 'float',
            'mode': 'nullable',
            'description': ''
        }, {
            'name': 'nullable_integer_col',
            'type': 'integer',
            'mode': 'nullable',
            'description': ''
        }, {
            'name': 'nullable_string_col',
            'type': 'string',
            'mode': 'nullable',
            'description': ''
        }, {
            'name': 'nullable_timestamp_col',
            'type': 'timestamp',
            'mode': 'nullable',
            'description': ''
        }, {
            'name': 'required_date_col',
            'type': 'date',
            'mode': 'required',
            'description': ''
        }, {
            'name': 'required_float_col',
            'type': 'float',
            'mode': 'required',
            'description': ''
        }, {
            'name': 'required_integer_col',
            'type': 'integer',
            'mode': 'required',
            'description': ''
        }, {
            'name': 'required_string_col',
            'type': 'string',
            'mode': 'required',
            'description': ''
        }, {
            'name': 'required_timestamp_col',
            'type': 'timestamp',
            'mode': 'required',
            'description': ''
        }]

        # dummy values for each type
        flt_str = "3.14"
        int_str = "1234"
        str_str = "abc"
        dt_str = "2019-01-01"
        ts_str = "2019-01-01 14:00:00.0"
        row = {
            'nullable_date_col': dt_str,
            'nullable_float_col': flt_str,
            'nullable_integer_col': int_str,
            'nullable_string_col': str_str,
            'nullable_timestamp_col': ts_str,
            'required_date_col': dt_str,
            'required_float_col': flt_str,
            'required_integer_col': int_str,
            'required_string_col': str_str,
            'required_timestamp_col': ts_str
        }
        # all fields populated
        expected_expr = f"('{dt_str}',{flt_str},{int_str},'{str_str}','{ts_str}','{dt_str}',{flt_str},{int_str},'{str_str}','{ts_str}')"
        actual_expr = bq_utils.csv_line_to_sql_row_expr(row, fields)
        self.assertEqual(expected_expr, actual_expr)

        # zero in a nullable integer column is kept as 0 (not converted to NULL)
        row['nullable_integer_col'] = '0'
        expected_expr = f"('{dt_str}',{flt_str},0,'{str_str}','{ts_str}','{dt_str}',{flt_str},{int_str},'{str_str}','{ts_str}')"
        actual_expr = bq_utils.csv_line_to_sql_row_expr(row, fields)
        self.assertEqual(expected_expr, actual_expr)

        # empty nullable fields are converted to NULL
        row['nullable_date_col'] = ''
        row['nullable_float_col'] = ''
        row['nullable_integer_col'] = ''
        row['nullable_string_col'] = ''
        row['nullable_timestamp_col'] = ''
        expected_expr = f"(NULL,NULL,NULL,NULL,NULL,'{dt_str}',{flt_str},{int_str},'{str_str}','{ts_str}')"
        actual_expr = bq_utils.csv_line_to_sql_row_expr(row, fields)
        self.assertEqual(expected_expr, actual_expr)

        # an empty required string is kept as an empty string literal
        row['required_string_col'] = ''
        actual_expr = bq_utils.csv_line_to_sql_row_expr(row, fields)
        expected_expr = f"(NULL,NULL,NULL,NULL,NULL,'{dt_str}',{flt_str},{int_str},'','{ts_str}')"
        self.assertEqual(expected_expr, actual_expr)

        # empty required field raises error
        row['required_integer_col'] = ''
        with self.assertRaises(bq_utils.InvalidOperationError) as c:
            bq_utils.csv_line_to_sql_row_expr(row, fields)
        self.assertEqual(
            c.exception.msg,
            'Value not provided for required field required_integer_col')

    def tearDown(self):
        test_util.delete_all_tables(self.dataset_id)
        self.client.empty_bucket(self.hpo_bucket)
Example #4
class AchillesHeelTest(unittest.TestCase):
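    """Integration test that loads five-person CDM data under a randomized
    HPO id, runs Achilles Heel, and checks result counts plus the
    error/warning/notification categorization of analysis ids."""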
    @classmethod
    def setUpClass(cls):
        print('**************************************************************')
        print(cls.__name__)
        print('**************************************************************')

    def setUp(self):
        self.hpo_bucket = gcs_utils.get_hpo_bucket(FAKE_HPO_ID)
        self.dataset = bq_utils.get_dataset_id()
        self.project_id = app_identity.get_application_id()
        self.storage_client = StorageClient(self.project_id)
        self.storage_client.empty_bucket(self.hpo_bucket)
        test_util.delete_all_tables(self.dataset)

    def tearDown(self):
        test_util.delete_all_tables(bq_utils.get_dataset_id())
        self.storage_client.empty_bucket(self.hpo_bucket)

    def _load_dataset(self, hpo_id):
        for cdm_table in resources.CDM_TABLES:

            cdm_filename: str = f'{cdm_table}.csv'
            cdm_filepath: str = os.path.join(test_util.FIVE_PERSONS_PATH,
                                             cdm_filename)

            bucket = self.storage_client.get_bucket(self.hpo_bucket)
            cdm_blob = bucket.blob(cdm_filename)
            if os.path.exists(cdm_filepath):
                cdm_blob.upload_from_filename(cdm_filepath)
            else:
                cdm_blob.upload_from_string('dummy\n')

            bq_utils.load_cdm_csv(hpo_id, cdm_table)

        # ensure concept table exists
        if not bq_utils.table_exists(common.CONCEPT):
            bq_utils.create_standard_table(common.CONCEPT, common.CONCEPT)
            q = """INSERT INTO {dataset}.concept
            SELECT * FROM {vocab}.concept""".format(
                dataset=self.dataset, vocab=common.VOCABULARY_DATASET)
            bq_utils.query(q)

    @staticmethod
    def get_mock_hpo_bucket():
        bucket_env = 'BUCKET_NAME_' + FAKE_HPO_ID.upper()
        hpo_bucket_name = os.getenv(bucket_env)
        if hpo_bucket_name is None:
            raise EnvironmentError()
        return hpo_bucket_name

    @mock.patch('gcs_utils.get_hpo_bucket')
    def test_heel_analyses(self, mock_hpo_bucket):
        # Long-running test
        mock_hpo_bucket.return_value = self.get_mock_hpo_bucket()

        # create randomized tables to bypass BQ rate limits
        random_string = str(randint(10000, 99999))
        randomized_hpo_id = FAKE_HPO_ID + '_' + random_string

        # prepare
        self._load_dataset(randomized_hpo_id)
        test_util.populate_achilles(hpo_id=randomized_hpo_id,
                                    include_heel=False)

        # define tables
        achilles_heel_results = randomized_hpo_id + '_' + achilles_heel.ACHILLES_HEEL_RESULTS
        achilles_results_derived = randomized_hpo_id + '_' + achilles_heel.ACHILLES_RESULTS_DERIVED

        # run achilles heel
        achilles_heel.create_tables(randomized_hpo_id, True)
        achilles_heel.run_heel(hpo_id=randomized_hpo_id)

        # validate
        query = sql_wrangle.qualify_tables(
            'SELECT COUNT(1) as num_rows FROM %s' % achilles_heel_results)
        response = bq_utils.query(query)
        rows = bq_utils.response2rows(response)
        self.assertEqual(ACHILLES_HEEL_RESULTS_COUNT, rows[0]['num_rows'])
        query = sql_wrangle.qualify_tables(
            'SELECT COUNT(1) as num_rows FROM %s' % achilles_results_derived)
        response = bq_utils.query(query)
        rows = bq_utils.response2rows(response)
        self.assertEqual(ACHILLES_RESULTS_DERIVED_COUNT, rows[0]['num_rows'])

        # test new heel re-categorization
        errors = [
            2, 4, 5, 101, 200, 206, 207, 209, 400, 405, 406, 409, 411, 413,
            500, 505, 506, 509, 600, 605, 606, 609, 613, 700, 705, 706, 709,
            711, 713, 715, 716, 717, 800, 805, 806, 809, 813, 814, 906, 1006,
            1609, 1805
        ]
        query = sql_wrangle.qualify_tables(
            """SELECT analysis_id FROM {table_id}
            WHERE achilles_heel_warning LIKE 'ERROR:%'
            GROUP BY analysis_id""".format(table_id=achilles_heel_results))
        response = bq_utils.query(query)
        rows = bq_utils.response2rows(response)
        actual_result = [row["analysis_id"] for row in rows]
        for analysis_id in actual_result:
            self.assertIn(analysis_id, errors)

        warnings = [
            4, 5, 7, 8, 9, 200, 210, 302, 400, 402, 412, 420, 500, 511, 512,
            513, 514, 515, 602, 612, 620, 702, 712, 720, 802, 812, 820
        ]
        query = sql_wrangle.qualify_tables(
            """SELECT analysis_id FROM {table_id}
            WHERE achilles_heel_warning LIKE 'WARNING:%'
            GROUP BY analysis_id""".format(table_id=achilles_heel_results))
        response = bq_utils.query(query)
        rows = bq_utils.response2rows(response)
        actual_result = [row["analysis_id"] for row in rows]
        for analysis_id in actual_result:
            self.assertIn(analysis_id, warnings)

        notifications = [
            101, 103, 105, 114, 115, 118, 208, 301, 410, 610, 710, 810, 900,
            907, 1000, 1800, 1807
        ]
        query = sql_wrangle.qualify_tables(
            """SELECT analysis_id FROM {table_id}
            WHERE achilles_heel_warning LIKE 'NOTIFICATION:%' and analysis_id is not null
            GROUP BY analysis_id""".format(table_id=achilles_heel_results))
        response = bq_utils.query(query)
        rows = bq_utils.response2rows(response)
        actual_result = [row["analysis_id"] for row in rows]
        for analysis_id in actual_result:
            self.assertIn(analysis_id, notifications)
Example #5
class RetractDataGcsTest(TestCase):
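    """Integration tests for GCS data retraction: five-person CSV files are
    uploaded, then retraction is run and verified to remove only rows for the
    targeted person ids (and nothing when those ids are absent)."""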

    @classmethod
    def setUpClass(cls):
        print('**************************************************************')
        print(cls.__name__)
        print('**************************************************************')

    def setUp(self):
        self.project_id = app_identity.get_application_id()
        self.hpo_id = test_util.FAKE_HPO_ID
        self.bucket = os.environ.get('BUCKET_NAME_FAKE')
        self.site_bucket = 'test_bucket'
        self.folder_1 = '2019-01-01-v1/'
        self.folder_2 = '2019-02-02-v2/'
        self.client = StorageClient(self.project_id)
        self.folder_prefix_1 = f'{self.hpo_id}/{self.site_bucket}/{self.folder_1}'
        self.folder_prefix_2 = f'{self.hpo_id}/{self.site_bucket}/{self.folder_2}'
        self.pids = [17, 20]
        self.skip_pids = [10, 25]
        self.project_id = 'project_id'
        self.sandbox_dataset_id = os.environ.get('UNIONED_DATASET_ID')
        self.pid_table_id = 'pid_table'
        self.gcs_bucket = self.client.bucket(self.bucket)
        self.client.empty_bucket(self.gcs_bucket)

    @patch('retraction.retract_data_gcs.extract_pids_from_table')
    @patch('gcs_utils.get_drc_bucket')
    @patch('gcs_utils.get_hpo_bucket')
    def test_integration_five_person_data_retraction_skip(
        self, mock_hpo_bucket, mock_bucket, mock_extract_pids):
        mock_hpo_bucket.return_value = self.site_bucket
        mock_bucket.return_value = self.bucket
        mock_extract_pids.return_value = self.skip_pids
        lines_to_remove = {}
        expected_lines_post = {}
        for file_path in test_util.FIVE_PERSONS_FILES:
            # generate results files
            file_name = file_path.split('/')[-1]
            lines_to_remove[file_name] = 0
            with open(file_path, 'rb') as f:
                # skip header
                next(f)
                expected_lines_post[file_name] = []
                for line in f:
                    line = line.strip()
                    if line != b'':
                        expected_lines_post[file_name].append(line)

                # write file to cloud for testing
                blob = self.gcs_bucket.blob(self.folder_prefix_1 + file_name)
                blob.upload_from_file(f, rewind=True, content_type='text/csv')
                blob = self.gcs_bucket.blob(self.folder_prefix_2 + file_name)
                blob.upload_from_file(f, rewind=True, content_type='text/csv')

        rd.run_gcs_retraction(self.project_id,
                              self.sandbox_dataset_id,
                              self.pid_table_id,
                              self.hpo_id,
                              folder='all_folders',
                              force_flag=True,
                              bucket=self.bucket,
                              site_bucket=self.site_bucket)

        total_lines_post = {}
        for file_path in test_util.FIVE_PERSONS_FILES:
            file_name = file_path.split('/')[-1]
            blob = self.gcs_bucket.blob(self.folder_prefix_1 + file_name)
            actual_result_contents = blob.download_as_string().split(b'\n')
            # drop the header row and the trailing empty element left by the final newline
            total_lines_post[file_name] = actual_result_contents[1:-1]

        for key in expected_lines_post:
            self.assertEqual(lines_to_remove[key], 0)
            self.assertListEqual(expected_lines_post[key],
                                 total_lines_post[key])

    @patch('retraction.retract_data_gcs.extract_pids_from_table')
    @patch('gcs_utils.get_drc_bucket')
    @patch('gcs_utils.get_hpo_bucket')
    def test_integration_five_person_data_retraction(self, mock_hpo_bucket,
                                                     mock_bucket,
                                                     mock_extract_pids):
        mock_hpo_bucket.return_value = self.site_bucket
        mock_bucket.return_value = self.bucket
        mock_extract_pids.return_value = self.pids
        expected_lines_post = {}
        for file_path in test_util.FIVE_PERSONS_FILES:
            # generate results files
            file_name = file_path.split('/')[-1]
            table_name = file_name.split('.')[0]
            expected_lines_post[file_name] = []
            with open(file_path, 'rb') as f:
                # skip header
                next(f)
                expected_lines_post[file_name] = []
                for line in f:
                    line = line.strip()
                    if line != b'':
                        if not ((table_name in rd.PID_IN_COL1 and
                                 int(line.split(b",")[0]) in self.pids) or
                                (table_name in rd.PID_IN_COL2 and
                                 int(line.split(b",")[1]) in self.pids)):
                            expected_lines_post[file_name].append(line)

                # write file to cloud for testing
                blob = self.gcs_bucket.blob(self.folder_prefix_1 + file_name)
                blob.upload_from_file(f, rewind=True, content_type='text/csv')
                blob = self.gcs_bucket.blob(self.folder_prefix_2 + file_name)
                blob.upload_from_file(f, rewind=True, content_type='text/csv')

        rd.run_gcs_retraction(self.project_id,
                              self.sandbox_dataset_id,
                              self.pid_table_id,
                              self.hpo_id,
                              folder='all_folders',
                              force_flag=True,
                              bucket=self.bucket,
                              site_bucket=self.site_bucket)

        total_lines_post = {}
        for file_path in test_util.FIVE_PERSONS_FILES:
            file_name = file_path.split('/')[-1]
            blob = self.gcs_bucket.blob(self.folder_prefix_1 + file_name)
            actual_result_contents = blob.download_as_string().split(b'\n')
            # drop the header row and the trailing empty element left by the final newline
            total_lines_post[file_name] = actual_result_contents[1:-1]

        for key in expected_lines_post:
            self.assertListEqual(expected_lines_post[key],
                                 total_lines_post[key])

    def tearDown(self):
        self.client.empty_bucket(self.gcs_bucket)
Example #6
class ExportTest(unittest.TestCase):
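    """Tests for Achilles report export: per-report payloads from
    export_from_path and end-to-end run_export output written to the HPO or
    target bucket."""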

    @classmethod
    def setUpClass(cls):
        print(
            '\n**************************************************************')
        print(cls.__name__)
        print('**************************************************************')
        dataset_id = bq_utils.get_dataset_id()
        test_util.delete_all_tables(dataset_id)
        test_util.populate_achilles()

    def setUp(self):
        self.hpo_bucket = gcs_utils.get_hpo_bucket(FAKE_HPO_ID)
        self.project_id = app_identity.get_application_id()
        self.storage_client = StorageClient(self.project_id)

    def _test_report_export(self, report):
        data_density_path = os.path.join(export.EXPORT_PATH, report)
        result = export.export_from_path(data_density_path, FAKE_HPO_ID)
        return result
        # TODO more strict testing of result payload. The following doesn't work because field order is random.
        # actual_payload = json.dumps(result, sort_keys=True, indent=4, separators=(',', ': '))
        # expected_path = os.path.join(test_util.TEST_DATA_EXPORT_SYNPUF_PATH, report + '.json')
        # with open(expected_path, 'r') as f:
        #     expected_payload = f.read()
        #     self.assertEqual(actual_payload, expected_payload)
        # return result

    @mock.patch('validation.export.is_hpo_id')
    def test_export_data_density(self, mock_is_hpo_id):
        # INTEGRATION TEST
        mock_is_hpo_id.return_value = True
        export_result = self._test_report_export('datadensity')
        expected_keys = [
            'CONCEPTS_PER_PERSON', 'RECORDS_PER_PERSON', 'TOTAL_RECORDS'
        ]
        for expected_key in expected_keys:
            self.assertTrue(expected_key in export_result)
        self.assertEqual(
            len(export_result['TOTAL_RECORDS']['X_CALENDAR_MONTH']), 283)

    @mock.patch('validation.export.is_hpo_id')
    def test_export_person(self, mock_is_hpo_id):
        # INTEGRATION TEST
        mock_is_hpo_id.return_value = True
        export_result = self._test_report_export('person')
        expected_keys = [
            'BIRTH_YEAR_HISTOGRAM', 'ETHNICITY_DATA', 'GENDER_DATA',
            'RACE_DATA', 'SUMMARY'
        ]
        for expected_key in expected_keys:
            self.assertTrue(expected_key in export_result)
        self.assertEqual(
            len(export_result['BIRTH_YEAR_HISTOGRAM']['DATA']['COUNT_VALUE']),
            72)

    @mock.patch('validation.export.is_hpo_id')
    def test_export_achillesheel(self, mock_is_hpo_id):
        # INTEGRATION TEST
        mock_is_hpo_id.return_value = True
        export_result = self._test_report_export('achillesheel')
        self.assertTrue('MESSAGES' in export_result)
        self.assertEqual(len(export_result['MESSAGES']['ATTRIBUTENAME']), 14)

    @mock.patch('validation.export.is_hpo_id')
    def test_run_export(self, mock_is_hpo_id):
        # validation/main.py INTEGRATION TEST
        mock_is_hpo_id.return_value = True
        folder_prefix: str = 'dummy-prefix-2018-03-24/'

        main._upload_achilles_files(FAKE_HPO_ID, folder_prefix)
        main.run_export(datasource_id=FAKE_HPO_ID, folder_prefix=folder_prefix)

        storage_bucket = self.storage_client.get_bucket(self.hpo_bucket)
        bucket_objects = storage_bucket.list_blobs()
        actual_object_names: list = [obj.name for obj in bucket_objects]
        for report in common.ALL_REPORT_FILES:
            prefix: str = f'{folder_prefix}{common.ACHILLES_EXPORT_PREFIX_STRING}{FAKE_HPO_ID}/'
            expected_object_name: str = f'{prefix}{report}'
            self.assertIn(expected_object_name, actual_object_names)

        datasources_json_path: str = folder_prefix + common.ACHILLES_EXPORT_DATASOURCES_JSON
        self.assertIn(datasources_json_path, actual_object_names)

        datasources_blob = storage_bucket.blob(datasources_json_path)
        datasources_json: str = datasources_blob.download_as_bytes().decode()
        datasources_actual: dict = json.loads(datasources_json)
        datasources_expected: dict = {
            'datasources': [{
                'name': FAKE_HPO_ID,
                'folder': FAKE_HPO_ID,
                'cdmVersion': 5
            }]
        }
        self.assertDictEqual(datasources_expected, datasources_actual)

    def test_run_export_without_datasource_id(self):
        # validation/main.py INTEGRATION TEST
        with self.assertRaises(RuntimeError):
            main.run_export(datasource_id=None, target_bucket=None)

    @mock.patch('validation.export.is_hpo_id')
    def test_run_export_with_target_bucket_and_datasource_id(
        self, mock_is_hpo_id):
        # validation/main.py INTEGRATION TEST
        mock_is_hpo_id.return_value = True
        folder_prefix: str = 'dummy-prefix-2018-03-24/'
        bucket_nyc: str = gcs_utils.get_hpo_bucket('nyc')
        main.run_export(datasource_id=FAKE_HPO_ID,
                        folder_prefix=folder_prefix,
                        target_bucket=bucket_nyc)
        storage_bucket = self.storage_client.get_bucket(bucket_nyc)
        bucket_objects = storage_bucket.list_blobs()
        actual_object_names: list = [obj.name for obj in bucket_objects]
        for report in common.ALL_REPORT_FILES:
            prefix: str = f'{folder_prefix}{common.ACHILLES_EXPORT_PREFIX_STRING}{FAKE_HPO_ID}/'
            expected_object_name: str = f'{prefix}{report}'
            self.assertIn(expected_object_name, actual_object_names)
        datasources_json_path: str = f'{folder_prefix}{common.ACHILLES_EXPORT_DATASOURCES_JSON}'
        self.assertIn(datasources_json_path, actual_object_names)

        datasources_blob = storage_bucket.blob(datasources_json_path)
        datasources_json: str = datasources_blob.download_as_bytes().decode()
        datasources_actual: dict = json.loads(datasources_json)
        datasources_expected: dict = {
            'datasources': [{
                'name': FAKE_HPO_ID,
                'folder': FAKE_HPO_ID,
                'cdmVersion': 5
            }]
        }
        self.assertDictEqual(datasources_expected, datasources_actual)

    def tearDown(self):
        self.storage_client.empty_bucket(self.hpo_bucket)
        bucket_nyc: str = gcs_utils.get_hpo_bucket('nyc')
        self.storage_client.empty_bucket(bucket_nyc)

    @classmethod
    def tearDownClass(cls):
        dataset_id = bq_utils.get_dataset_id()
        test_util.delete_all_tables(dataset_id)
Example #7
class ValidationMainTest(unittest.TestCase):
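    """Integration tests for validation.main: submission validation results,
    submission folder detection, file copying, Achilles file uploads, PII
    loads, and the generated HTML results report."""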

    @classmethod
    def setUpClass(cls):
        print('**************************************************************')
        print(cls.__name__)
        print('**************************************************************')

    def setUp(self):
        self.hpo_id = test_util.FAKE_HPO_ID
        self.hpo_bucket = gcs_utils.get_hpo_bucket(self.hpo_id)
        self.project_id = app_identity.get_application_id()
        self.rdr_dataset_id = bq_utils.get_rdr_dataset_id()
        mock_get_hpo_name = mock.patch('validation.main.get_hpo_name')

        self.mock_get_hpo_name = mock_get_hpo_name.start()
        self.mock_get_hpo_name.return_value = 'Fake HPO'
        self.addCleanup(mock_get_hpo_name.stop)

        self.bigquery_dataset_id = bq_utils.get_dataset_id()
        self.folder_prefix = '2019-01-01-v1/'

        self.storage_client = StorageClient(self.project_id)
        self.storage_bucket = self.storage_client.get_bucket(self.hpo_bucket)
        self.storage_client.empty_bucket(self.hpo_bucket)

        test_util.delete_all_tables(self.bigquery_dataset_id)
        self._create_drug_class_table(self.bigquery_dataset_id)

    @staticmethod
    def _create_drug_class_table(bigquery_dataset_id):

        table_name = 'drug_class'
        fields = [{
            "type": "integer",
            "name": "concept_id",
            "mode": "required"
        }, {
            "type": "string",
            "name": "concept_name",
            "mode": "required"
        }, {
            "type": "string",
            "name": "drug_class_name",
            "mode": "required"
        }]
        bq_utils.create_table(table_id=table_name,
                              fields=fields,
                              drop_existing=True,
                              dataset_id=bigquery_dataset_id)

        bq_utils.query(q=main_consts.DRUG_CLASS_QUERY.format(
            dataset_id=bigquery_dataset_id),
                       use_legacy_sql=False,
                       destination_table_id='drug_class',
                       retry_count=bq_consts.BQ_DEFAULT_RETRY_COUNT,
                       write_disposition='WRITE_TRUNCATE',
                       destination_dataset_id=bigquery_dataset_id)

        # ensure concept ancestor table exists
        if not bq_utils.table_exists(common.CONCEPT_ANCESTOR):
            bq_utils.create_standard_table(common.CONCEPT_ANCESTOR,
                                           common.CONCEPT_ANCESTOR)
            q = """INSERT INTO {dataset}.concept_ancestor
            SELECT * FROM {vocab}.concept_ancestor""".format(
                dataset=bigquery_dataset_id, vocab=common.VOCABULARY_DATASET)
            bq_utils.query(q)

    def table_has_clustering(self, table_info):
        clustering = table_info.get('clustering')
        self.assertIsNotNone(clustering)
        fields = clustering.get('fields')
        self.assertSetEqual(set(fields), {'person_id'})
        time_partitioning = table_info.get('timePartitioning')
        self.assertIsNotNone(time_partitioning)
        tpe = time_partitioning.get('type')
        self.assertEqual(tpe, 'DAY')

    def test_all_files_unparseable_output(self):
        # TODO possible bug: if no pre-existing table, results in bq table not found error
        for cdm_table in common.SUBMISSION_FILES:
            cdm_blob = self.storage_bucket.blob(
                f'{self.folder_prefix}{cdm_table}')
            cdm_blob.upload_from_string('.\n .')

        bucket_items = gcs_utils.list_bucket(self.hpo_bucket)
        folder_items = main.get_folder_items(bucket_items, self.folder_prefix)
        expected_results = [(f, 1, 0, 0) for f in common.SUBMISSION_FILES]
        r = main.validate_submission(self.hpo_id, self.hpo_bucket, folder_items,
                                     self.folder_prefix)
        self.assertSetEqual(set(expected_results), set(r['results']))

    def test_bad_file_names(self):
        bad_file_names: list = [
            "avisit_occurrence.csv",
            "condition_occurence.csv",  # misspelled
            "person_final.csv",
            "procedure_occurrence.tsv"  # unsupported file extension
        ]
        expected_warnings: list = []
        for file_name in bad_file_names:
            bad_blob = self.storage_bucket.blob(
                f'{self.folder_prefix}{file_name}')
            bad_blob.upload_from_string('.')

            expected_item: tuple = (file_name, common.UNKNOWN_FILE)
            expected_warnings.append(expected_item)
        bucket_items = gcs_utils.list_bucket(self.hpo_bucket)
        folder_items = main.get_folder_items(bucket_items, self.folder_prefix)
        r = main.validate_submission(self.hpo_id, self.hpo_bucket, folder_items,
                                     self.folder_prefix)
        self.assertCountEqual(expected_warnings, r['warnings'])

    @mock.patch('api_util.check_cron')
    def test_validate_five_persons_success(self, mock_check_cron):
        expected_results: list = []
        test_file_names: list = [
            os.path.basename(f) for f in test_util.FIVE_PERSONS_FILES
        ]

        for cdm_filename in common.SUBMISSION_FILES:
            if cdm_filename in test_file_names:
                expected_result: tuple = (cdm_filename, 1, 1, 1)
                test_filepath: str = os.path.join(test_util.FIVE_PERSONS_PATH,
                                                  cdm_filename)
                test_blob = self.storage_bucket.blob(
                    f'{self.folder_prefix}{cdm_filename}')
                test_blob.upload_from_filename(test_filepath)

            else:
                expected_result: tuple = (cdm_filename, 0, 0, 0)
            expected_results.append(expected_result)
        bucket_items = gcs_utils.list_bucket(self.hpo_bucket)
        folder_items = main.get_folder_items(bucket_items, self.folder_prefix)
        r = main.validate_submission(self.hpo_id, self.hpo_bucket, folder_items,
                                     self.folder_prefix)
        self.assertSetEqual(set(r['results']), set(expected_results))

        # check tables exist and are clustered as expected
        for table in resources.CDM_TABLES + common.PII_TABLES:
            table_id = bq_utils.get_table_id(test_util.FAKE_HPO_ID, table)
            table_info = bq_utils.get_table_info(table_id)
            fields = resources.fields_for(table)
            field_names = [field['name'] for field in fields]
            if 'person_id' in field_names:
                self.table_has_clustering(table_info)

    @mock.patch('validation.main.updated_datetime_object')
    def test_check_processed(self, mock_updated_datetime_object):

        mock_updated_datetime_object.return_value = datetime.datetime.today(
        ) - datetime.timedelta(minutes=7)

        for fname in common.AOU_REQUIRED_FILES:
            blob_name: str = f'{self.folder_prefix}{fname}'
            test_blob = self.storage_bucket.blob(blob_name)
            test_blob.upload_from_string('\n')

            sleep(1)

        blob_name: str = f'{self.folder_prefix}{common.PROCESSED_TXT}'
        test_blob = self.storage_bucket.blob(blob_name)
        test_blob.upload_from_string('\n')

        bucket_items = gcs_utils.list_bucket(self.hpo_bucket)
        result = main._get_submission_folder(self.hpo_bucket,
                                             bucket_items,
                                             force_process=False)
        self.assertIsNone(result)
        result = main._get_submission_folder(self.hpo_bucket,
                                             bucket_items,
                                             force_process=True)
        self.assertEqual(result, self.folder_prefix)

    @mock.patch('api_util.check_cron')
    def test_copy_five_persons(self, mock_check_cron):
        # upload all five_persons files
        for cdm_pathfile in test_util.FIVE_PERSONS_FILES:
            test_filename: str = os.path.basename(cdm_pathfile)

            blob_name: str = f'{self.folder_prefix}{test_filename}'
            test_blob = self.storage_bucket.blob(blob_name)
            test_blob.upload_from_filename(cdm_pathfile)

            blob_name: str = f'{self.folder_prefix}{self.folder_prefix}{test_filename}'
            test_blob = self.storage_bucket.blob(blob_name)
            test_blob.upload_from_filename(cdm_pathfile)

        main.app.testing = True
        with main.app.test_client() as c:
            c.get(test_util.COPY_HPO_FILES_URL)
            prefix = test_util.FAKE_HPO_ID + '/' + self.hpo_bucket + '/' + self.folder_prefix
            expected_bucket_items = [
                prefix + item.split(os.sep)[-1]
                for item in test_util.FIVE_PERSONS_FILES
            ]
            expected_bucket_items.extend([
                prefix + self.folder_prefix + item.split(os.sep)[-1]
                for item in test_util.FIVE_PERSONS_FILES
            ])

            list_bucket_result = gcs_utils.list_bucket(
                gcs_utils.get_drc_bucket())
            actual_bucket_items = [item['name'] for item in list_bucket_result]
            self.assertSetEqual(set(expected_bucket_items),
                                set(actual_bucket_items))

    def test_target_bucket_upload(self):
        bucket_nyc = gcs_utils.get_hpo_bucket('nyc')
        folder_prefix = 'test-folder-fake/'
        self.storage_client.empty_bucket(bucket_nyc)

        main._upload_achilles_files(hpo_id=None,
                                    folder_prefix=folder_prefix,
                                    target_bucket=bucket_nyc)
        actual_bucket_files = set(
            [item['name'] for item in gcs_utils.list_bucket(bucket_nyc)])
        expected_bucket_files = set([
            'test-folder-fake/' + item
            for item in resources.ALL_ACHILLES_INDEX_FILES
        ])
        self.assertSetEqual(expected_bucket_files, actual_bucket_files)

    @mock.patch('api_util.check_cron')
    def test_pii_files_loaded(self, mock_check_cron):
        # tests if pii files are loaded
        test_file_paths: list = [
            test_util.PII_NAME_FILE, test_util.PII_MRN_BAD_PERSON_ID_FILE
        ]
        test_file_names: list = [os.path.basename(f) for f in test_file_paths]

        blob_name: str = f'{self.folder_prefix}{os.path.basename(test_util.PII_NAME_FILE)}'
        test_blob = self.storage_bucket.blob(blob_name)
        test_blob.upload_from_filename(test_util.PII_NAME_FILE)

        blob_name: str = f'{self.folder_prefix}{os.path.basename(test_util.PII_MRN_BAD_PERSON_ID_FILE)}'
        test_blob = self.storage_bucket.blob(blob_name)
        test_blob.upload_from_filename(test_util.PII_MRN_BAD_PERSON_ID_FILE)

        rs = resources.csv_to_list(test_util.PII_FILE_LOAD_RESULT_CSV)
        expected_results = [(r['file_name'], int(r['found']), int(r['parsed']),
                             int(r['loaded'])) for r in rs]
        for f in common.SUBMISSION_FILES:
            if f not in test_file_names:
                expected_result = (f, 0, 0, 0)
                expected_results.append(expected_result)

        bucket_items = gcs_utils.list_bucket(self.hpo_bucket)
        folder_items = main.get_folder_items(bucket_items, self.folder_prefix)
        r = main.validate_submission(self.hpo_id, self.hpo_bucket, folder_items,
                                     self.folder_prefix)
        self.assertSetEqual(set(expected_results), set(r['results']))

    @mock.patch('validation.main.updated_datetime_object')
    @mock.patch('validation.main._has_all_required_files')
    @mock.patch('validation.main.all_required_files_loaded')
    @mock.patch('validation.main.is_first_validation_run')
    @mock.patch('api_util.check_cron')
    def test_html_report_five_person(self, mock_check_cron, mock_first_run,
                                     mock_required_files_loaded,
                                     mock_has_all_required_files,
                                     mock_updated_datetime_object):
        mock_required_files_loaded.return_value = False
        mock_first_run.return_value = False
        mock_has_all_required_files.return_value = True
        mock_updated_datetime_object.return_value = datetime.datetime.today(
        ) - datetime.timedelta(minutes=7)

        for cdm_file in test_util.FIVE_PERSONS_FILES:
            blob_name = f'{self.folder_prefix}{os.path.basename(cdm_file)}'
            test_blob = self.storage_bucket.blob(blob_name)
            test_blob.upload_from_filename(cdm_file)

        # load person table in RDR
        bq_utils.load_table_from_csv(self.project_id, self.rdr_dataset_id,
                                     common.PERSON,
                                     test_util.FIVE_PERSONS_PERSON_CSV)

        # Load measurement_concept_sets
        required_labs.load_measurement_concept_sets_table(
            project_id=self.project_id, dataset_id=self.bigquery_dataset_id)
        # Load measurement_concept_sets_descendants
        required_labs.load_measurement_concept_sets_descendants_table(
            project_id=self.project_id, dataset_id=self.bigquery_dataset_id)

        main.app.testing = True
        with main.app.test_client() as c:
            c.get(test_util.VALIDATE_HPO_FILES_URL)
            actual_result = test_util.read_cloud_file(
                self.hpo_bucket, self.folder_prefix + common.RESULTS_HTML)

        # ensure emails are not sent
        bucket_items = gcs_utils.list_bucket(self.hpo_bucket)
        folder_items = main.get_folder_items(bucket_items, self.folder_prefix)
        self.assertFalse(main.is_first_validation_run(folder_items))

        # parse html
        soup = bs(actual_result, parser="lxml", features="lxml")
        missing_pii_html_table = soup.find('table', id='missing_pii')
        table_headers = missing_pii_html_table.find_all('th')
        self.assertEqual('Missing Participant Record Type',
                         table_headers[0].get_text())
        self.assertEqual('Count', table_headers[1].get_text())

        table_rows = missing_pii_html_table.find_next('tbody').find_all('tr')
        missing_record_types = [
            table_row.find('td').text for table_row in table_rows
        ]
        self.assertIn(main_consts.EHR_NO_PII, missing_record_types)
        self.assertIn(main_consts.PII_NO_EHR, missing_record_types)

        # the missing from RDR component is obsolete (see DC-1932)
        # this is to confirm it was removed successfully from the report
        rdr_date = '2020-01-01'
        self.assertNotIn(main_consts.EHR_NO_RDR.format(date=rdr_date),
                         missing_record_types)
        self.assertIn(main_consts.EHR_NO_PARTICIPANT_MATCH,
                      missing_record_types)

        required_lab_html_table = soup.find('table', id='required-lab')
        table_headers = required_lab_html_table.find_all('th')
        self.assertEqual(3, len(table_headers))
        self.assertEqual('Ancestor Concept ID', table_headers[0].get_text())
        self.assertEqual('Ancestor Concept Name', table_headers[1].get_text())
        self.assertEqual('Found', table_headers[2].get_text())

        table_rows = required_lab_html_table.find_next('tbody').find_all('tr')
        table_rows_last_column = [
            table_row.find_all('td')[-1] for table_row in table_rows
        ]
        submitted_labs = [
            row for row in table_rows_last_column
            if 'result-1' in row.attrs['class']
        ]
        missing_labs = [
            row for row in table_rows_last_column
            if 'result-0' in row.attrs['class']
        ]
        self.assertTrue(len(table_rows) > 0)
        self.assertTrue(len(submitted_labs) > 0)
        self.assertTrue(len(missing_labs) > 0)

    def tearDown(self):
        self.storage_client.empty_bucket(self.hpo_bucket)
        bucket_nyc = gcs_utils.get_hpo_bucket('nyc')
        self.storage_client.empty_bucket(bucket_nyc)
        self.storage_client.empty_bucket(gcs_utils.get_drc_bucket())
        test_util.delete_all_tables(self.bigquery_dataset_id)
Example #8
class TopHeelErrorsTest(TestCase):
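    """Tests for tools.top_heel_errors using achilles_heel_results data
    loaded from CSV, with and without an HPO prefix."""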
    @classmethod
    def setUpClass(cls):
        print('**************************************************************')
        print(cls.__name__)
        print('**************************************************************')

    def setUp(self):
        self.project_id = app_identity.get_application_id()
        self.dataset_id = bq_utils.get_dataset_id()
        self.bucket: str = gcs_utils.get_drc_bucket()
        self.storage_client = StorageClient(self.project_id)

        self.storage_client.empty_bucket(self.bucket)
        test_util.delete_all_tables(self.dataset_id)
        self.load_test_data(hpo_id=HPO_NYC)

    def load_test_data(self, hpo_id: str = None):
        """
        Load to bq test achilles heel results data from csv file

        :param hpo_id: if specified, prefix to use on csv test file and bq table, otherwise no prefix is used
        :return: contents of the file as list of objects
        """

        table_name: str = common.ACHILLES_HEEL_RESULTS
        if hpo_id is not None:
            table_id: str = bq_utils.get_table_id(hpo_id, table_name)
        else:
            table_id: str = table_name
        test_file_name: str = f'{table_id}.csv'
        test_file_path: str = os.path.join(test_util.TEST_DATA_PATH,
                                           test_file_name)

        target_bucket = self.storage_client.get_bucket(self.bucket)
        test_blob = target_bucket.blob(test_file_name)
        test_blob.upload_from_filename(test_file_path)

        gcs_path: str = f'gs://{self.bucket}/{test_file_name}'
        load_results = bq_utils.load_csv(table_name, gcs_path, self.project_id,
                                         self.dataset_id, table_id)
        job_id = load_results['jobReference']['jobId']
        bq_utils.wait_on_jobs([job_id])
        return resources.csv_to_list(test_file_path)

    def test_top_heel_errors_no_hpo_prefix(self):
        rows = self.load_test_data()
        for row in rows:
            row[FIELD_DATASET_NAME] = self.dataset_id
        errors = top_n_errors(rows)
        expected_results = comparison_view(errors)
        dataset_errors = top_heel_errors(self.project_id, self.dataset_id)
        actual_results = comparison_view(dataset_errors)
        self.assertCountEqual(actual_results, expected_results)

    @mock.patch('tools.top_heel_errors.get_hpo_ids')
    def test_top_heel_errors_all_hpo(self, mock_hpo_ids):
        hpo_ids = [HPO_NYC, HPO_PITT]
        mock_hpo_ids.return_value = hpo_ids
        expected_results = []
        for hpo_id in [HPO_NYC, HPO_PITT]:
            rows = self.load_test_data(hpo_id)
            for row in rows:
                row[FIELD_DATASET_NAME] = hpo_id
            errors = top_n_errors(rows)
            expected_results += comparison_view(errors)
        dataset_errors = top_heel_errors(self.project_id,
                                         self.dataset_id,
                                         all_hpo=True)
        actual_results = comparison_view(dataset_errors)
        self.assertCountEqual(actual_results, expected_results)

    def tearDown(self):
        self.storage_client.empty_bucket(self.bucket)
        test_util.delete_all_tables(self.dataset_id)
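
load_test_data above stages a csv in GCS and then loads it into BigQuery through bq_utils.load_csv and wait_on_jobs. A minimal sketch of the same upload-then-load pattern with the standard google-cloud clients, assuming the csv header matches the destination schema (bucket, dataset and table names are placeholders):

import os

from google.cloud import bigquery, storage


def load_csv_via_gcs(project_id: str, bucket_name: str, local_path: str,
                     dataset_id: str, table_id: str) -> None:
    """Upload a local csv to GCS, then load it into a BigQuery table."""
    file_name = os.path.basename(local_path)
    storage.Client(project=project_id).bucket(bucket_name).blob(
        file_name).upload_from_filename(local_path)

    job_config = bigquery.LoadJobConfig(source_format=bigquery.SourceFormat.CSV,
                                        skip_leading_rows=1,
                                        autodetect=True)
    load_job = bigquery.Client(project=project_id).load_table_from_uri(
        f'gs://{bucket_name}/{file_name}',
        f'{project_id}.{dataset_id}.{table_id}',
        job_config=job_config)
    load_job.result()  # block until the load job completes
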
class GcsClientTest(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        print('**************************************************************')
        print(cls.__name__)
        print('**************************************************************')

    def setUp(self):
        self.project_id = app_identity.get_application_id()
        self.client = StorageClient(self.project_id)
        self.bucket_name: str = os.environ.get('BUCKET_NAME_FAKE')
        self.prefix: str = 'prefix'
        self.data: bytes = b'bytes'

        # NOTE: this needs to be in sorted order
        self.sub_prefixes: tuple = (f'{self.prefix}/a', f'{self.prefix}/b',
                                    f'{self.prefix}/c', f'{self.prefix}/d')
        self.client.empty_bucket(self.bucket_name)
        self._stage_bucket()

    def test_get_bucket_items_metadata(self):

        items_metadata: list = self.client.get_bucket_items_metadata(
            self.bucket_name)

        actual_metadata: list = [item['name'] for item in items_metadata]
        expected_metadata: list = [
            f'{prefix}/obj.txt' for prefix in self.sub_prefixes
        ]

        self.assertCountEqual(actual_metadata, expected_metadata)
        self.assertIsNotNone(items_metadata[0]['id'])

    def test_get_blob_metadata(self):

        bucket = self.client.get_bucket(self.bucket_name)
        blob_name: str = f'{self.sub_prefixes[0]}/obj.txt'

        blob = bucket.blob(blob_name)
        metadata: dict = self.client.get_blob_metadata(blob)

        self.assertIsNotNone(metadata['id'])
        self.assertIsNotNone(metadata['name'])
        self.assertIsNotNone(metadata['bucket'])
        self.assertIsNotNone(metadata['generation'])
        self.assertIsNotNone(metadata['metageneration'])
        self.assertIsNotNone(metadata['contentType'])
        self.assertIsNotNone(metadata['storageClass'])
        self.assertIsNotNone(metadata['size'])
        self.assertIsNotNone(metadata['md5Hash'])
        self.assertIsNotNone(metadata['crc32c'])
        self.assertIsNotNone(metadata['etag'])
        self.assertIsNotNone(metadata['updated'])
        self.assertIsNotNone(metadata['timeCreated'])

        self.assertEqual(metadata['name'], blob_name)
        self.assertEqual(metadata['size'], len(self.data))

    def test_empty_bucket(self):

        self.client.empty_bucket(self.bucket_name)
        items: list = self.client.list_blobs(self.bucket_name)

        # check that bucket is empty
        self.assertCountEqual(items, [])

    def test_list_sub_prefixes(self):

        items: list = self.client.list_sub_prefixes(self.bucket_name,
                                                    self.prefix)

        # Check same number of elements
        self.assertEqual(len(self.sub_prefixes), len(items))

        # Check same prefix
        for index, item in enumerate(items):
            self.assertEqual(item[:-1], self.sub_prefixes[index])

    def _stage_bucket(self):

        bucket = self.client.bucket(self.bucket_name)
        for sub_prefix in self.sub_prefixes:
            blob = bucket.blob(f'{sub_prefix}/obj.txt')
            blob.upload_from_string(self.data)

    def tearDown(self):
        self.client.empty_bucket(self.bucket_name)
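
test_list_sub_prefixes above shows that StorageClient.list_sub_prefixes returns each sub-prefix with a trailing delimiter (hence the item[:-1] in the assertion). A small sketch of the underlying delimiter-listing pattern with the standard google-cloud-storage client, where the iterator has to be consumed before its prefixes attribute is populated (the helper name is illustrative):

from google.cloud import storage


def list_sub_prefixes(bucket_name: str, prefix: str) -> list:
    """Return the immediate 'sub-directories' under prefix, each ending with '/'."""
    client = storage.Client()
    iterator = client.list_blobs(bucket_name, prefix=f'{prefix}/', delimiter='/')
    # prefixes is only filled in once the result pages have been consumed
    for _ in iterator:
        pass
    return sorted(iterator.prefixes)
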
Exemple #10
0
class EhrUnionTest(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        print('**************************************************************')
        print(cls.__name__)
        print('**************************************************************')

    def setUp(self):
        self.project_id = bq_utils.app_identity.get_application_id()
        self.hpo_ids = [PITT_HPO_ID, NYC_HPO_ID, EXCLUDED_HPO_ID]
        self.input_dataset_id = bq_utils.get_dataset_id()
        self.output_dataset_id = bq_utils.get_unioned_dataset_id()
        self.storage_client = StorageClient(self.project_id)
        self.tearDown()

        # TODO Generalize to work for all foreign key references
        # Collect all primary key fields in CDM tables
        mapped_fields = []
        for table in cdm.tables_to_map():
            field = table + '_id'
            mapped_fields.append(field)
        self.mapped_fields = mapped_fields
        self.implemented_foreign_keys = [
            eu_constants.VISIT_OCCURRENCE_ID, eu_constants.VISIT_DETAIL_ID,
            eu_constants.CARE_SITE_ID, eu_constants.LOCATION_ID
        ]

    def _empty_hpo_buckets(self):
        for hpo_id in self.hpo_ids:
            bucket = gcs_utils.get_hpo_bucket(hpo_id)
            self.storage_client.empty_bucket(bucket)

    def _create_hpo_table(self, hpo_id, table, dataset_id):
        table_id = bq_utils.get_table_id(hpo_id, table)
        bq_utils.create_table(table_id,
                              resources.fields_for(table),
                              dataset_id=dataset_id)
        return table_id

    def _load_datasets(self):
        """
        Load five persons data for nyc and pitt test hpo and rdr data for the excluded_hpo
        # expected_tables is for testing output
        # it maps table name to list of expected records ex: "unioned_ehr_visit_occurrence" -> [{}, {}, ...]
        """
        expected_tables: dict = {}
        running_jobs: list = []
        for cdm_table in resources.CDM_TABLES:
            output_table: str = ehr_union.output_table_for(cdm_table)
            expected_tables[output_table] = []
            for hpo_id in self.hpo_ids:
                # upload csv into hpo bucket
                cdm_filename: str = f'{cdm_table}.csv'
                if hpo_id == NYC_HPO_ID:
                    cdm_filepath: str = os.path.join(
                        test_util.FIVE_PERSONS_PATH, cdm_filename)
                elif hpo_id == PITT_HPO_ID:
                    cdm_filepath: str = os.path.join(
                        test_util.PITT_FIVE_PERSONS_PATH, cdm_filename)
                elif hpo_id == EXCLUDED_HPO_ID:
                    if cdm_table in [
                            'observation', 'person', 'visit_occurrence'
                    ]:
                        cdm_filepath: str = os.path.join(
                            test_util.RDR_PATH, cdm_filename)
                    else:
                        # no rdr test data for this table; clear the path so the
                        # dummy upload below is used instead of a stale filepath
                        cdm_filepath: str = ''
                bucket: str = gcs_utils.get_hpo_bucket(hpo_id)
                gcs_bucket = self.storage_client.get_bucket(bucket)
                if os.path.exists(cdm_filepath):

                    csv_rows = resources.csv_to_list(cdm_filepath)
                    cdm_blob = gcs_bucket.blob(cdm_filename)
                    cdm_blob.upload_from_filename(cdm_filepath)

                else:
                    # results in empty table
                    cdm_blob = gcs_bucket.blob(cdm_filename)
                    cdm_blob.upload_from_string('dummy\n')
                    csv_rows: list = []
                # load table from csv
                result = bq_utils.load_cdm_csv(hpo_id, cdm_table)
                running_jobs.append(result['jobReference']['jobId'])
                if hpo_id != EXCLUDED_HPO_ID:
                    expected_tables[output_table] += list(csv_rows)
        # ensure person to observation output is as expected
        output_table_person: str = ehr_union.output_table_for(common.PERSON)
        output_table_observation: str = ehr_union.output_table_for(
            common.OBSERVATION)
        expected_tables[output_table_observation] += 4 * expected_tables[
            output_table_person]

        incomplete_jobs: list = bq_utils.wait_on_jobs(running_jobs)
        if len(incomplete_jobs) > 0:
            message: str = "Job id(s) %s failed to complete" % incomplete_jobs
            raise RuntimeError(message)
        self.expected_tables = expected_tables

    def _table_has_clustering(self, table_info):
        clustering = table_info.get('clustering')
        self.assertIsNotNone(clustering)
        fields = clustering.get('fields')
        self.assertSetEqual(set(fields), {'person_id'})
        time_partitioning = table_info.get('timePartitioning')
        self.assertIsNotNone(time_partitioning)
        tpe = time_partitioning.get('type')
        self.assertEqual(tpe, 'DAY')

    def _dataset_tables(self, dataset_id):
        """
        Get names of existing tables in specified dataset

        :param dataset_id: identifies the dataset
        :return: list of table_ids
        """
        tables = bq_utils.list_tables(dataset_id)
        return [table['tableReference']['tableId'] for table in tables]

    @mock.patch('bq_utils.get_hpo_info')
    def test_union_ehr(self, mock_hpo_info):
        self._load_datasets()
        input_tables_before = set(self._dataset_tables(self.input_dataset_id))

        # output should be mapping tables and cdm tables
        output_tables_before = self._dataset_tables(self.output_dataset_id)
        mapping_tables = [
            ehr_union.mapping_table_for(table)
            for table in cdm.tables_to_map() + [combine_ehr_rdr.PERSON_TABLE]
        ]
        output_cdm_tables = [
            ehr_union.output_table_for(table) for table in resources.CDM_TABLES
        ]
        expected_output = set(output_tables_before + mapping_tables +
                              output_cdm_tables)

        mock_hpo_info.return_value = [{
            'hpo_id': hpo_id
        } for hpo_id in self.hpo_ids]

        # perform ehr union
        ehr_union.main(self.input_dataset_id, self.output_dataset_id,
                       self.project_id, [EXCLUDED_HPO_ID])

        # input dataset should be unchanged
        input_tables_after = set(self._dataset_tables(self.input_dataset_id))
        self.assertSetEqual(input_tables_before, input_tables_after)

        # fact_relationship from pitt
        hpo_unique_identifiers = ehr_union.get_hpo_offsets(self.hpo_ids)
        pitt_offset = hpo_unique_identifiers[PITT_HPO_ID]
        q = '''SELECT fact_id_1, fact_id_2
               FROM `{input_dataset}.{hpo_id}_fact_relationship`
               where domain_concept_id_1 = 21 and domain_concept_id_2 = 21'''.format(
            input_dataset=self.input_dataset_id, hpo_id=PITT_HPO_ID)
        response = bq_utils.query(q)
        result = bq_utils.response2rows(response)

        expected_fact_id_1 = result[0]["fact_id_1"] + pitt_offset
        expected_fact_id_2 = result[0]["fact_id_2"] + pitt_offset

        q = '''SELECT fr.fact_id_1, fr.fact_id_2 FROM `{dataset_id}.unioned_ehr_fact_relationship` fr
            join `{dataset_id}._mapping_measurement` mm on fr.fact_id_1 = mm.measurement_id
            and mm.src_hpo_id = "{hpo_id}"'''.format(
            dataset_id=self.output_dataset_id, hpo_id=PITT_HPO_ID)
        response = bq_utils.query(q)
        result = bq_utils.response2rows(response)
        actual_fact_id_1, actual_fact_id_2 = result[0]["fact_id_1"], result[0][
            "fact_id_2"]
        self.assertEqual(expected_fact_id_1, actual_fact_id_1)
        self.assertEqual(expected_fact_id_2, actual_fact_id_2)

        # mapping tables
        tables_to_map = cdm.tables_to_map()
        for table_to_map in tables_to_map:
            mapping_table = ehr_union.mapping_table_for(table_to_map)
            expected_fields = {
                'src_table_id',
                'src_%s_id' % table_to_map,
                '%s_id' % table_to_map, 'src_hpo_id', 'src_dataset_id'
            }
            mapping_table_info = bq_utils.get_table_info(
                mapping_table, dataset_id=self.output_dataset_id)
            mapping_table_fields = mapping_table_info.get('schema', dict()).get(
                'fields', [])
            actual_fields = set([f['name'] for f in mapping_table_fields])
            message = 'Table %s has fields %s when %s expected' % (
                mapping_table, actual_fields, expected_fields)
            self.assertSetEqual(expected_fields, actual_fields, message)
            result_table = ehr_union.output_table_for(table_to_map)
            expected_num_rows = len(self.expected_tables[result_table])
            actual_num_rows = int(mapping_table_info.get('numRows', -1))
            message = 'Table %s has %s rows when %s expected' % (
                mapping_table, actual_num_rows, expected_num_rows)
            self.assertEqual(expected_num_rows, actual_num_rows, message)

        # check for each output table
        for table_name in resources.CDM_TABLES:
            # output table exists and row count is sum of those submitted by hpos
            result_table = ehr_union.output_table_for(table_name)
            expected_rows = self.expected_tables[result_table]
            expected_count = len(expected_rows)
            table_info = bq_utils.get_table_info(
                result_table, dataset_id=self.output_dataset_id)
            actual_count = int(table_info.get('numRows'))
            msg = 'Unexpected row count in table {result_table} after ehr union'.format(
                result_table=result_table)
            self.assertEqual(expected_count, actual_count, msg)
            # TODO Compare table rows to expected accounting for the new ids and ignoring field types
            # q = 'SELECT * FROM {dataset}.{table}'.format(dataset=self.output_dataset_id, table=result_table)
            # query_response = bq_utils.query(q)
            # actual_rows = bq_utils.response2rows(query_response)

            # output table has clustering on person_id where applicable
            fields = resources.fields_for(table_name)
            field_names = [field['name'] for field in fields]
            if 'person_id' in field_names:
                self._table_has_clustering(table_info)

        actual_output = set(self._dataset_tables(self.output_dataset_id))
        self.assertSetEqual(expected_output, actual_output)

        # explicit check that output person_ids are same as input
        nyc_person_table_id = bq_utils.get_table_id(NYC_HPO_ID, 'person')
        pitt_person_table_id = bq_utils.get_table_id(PITT_HPO_ID, 'person')
        q = '''SELECT DISTINCT person_id FROM (
           SELECT person_id FROM {dataset_id}.{nyc_person_table_id}
           UNION ALL
           SELECT person_id FROM {dataset_id}.{pitt_person_table_id}
        ) ORDER BY person_id ASC'''.format(
            dataset_id=self.input_dataset_id,
            nyc_person_table_id=nyc_person_table_id,
            pitt_person_table_id=pitt_person_table_id)
        response = bq_utils.query(q)
        expected_rows = bq_utils.response2rows(response)
        person_table_id = ehr_union.output_table_for('person')
        q = '''SELECT DISTINCT person_id
               FROM {dataset_id}.{table_id}
               ORDER BY person_id ASC'''.format(
            dataset_id=self.output_dataset_id, table_id=person_table_id)
        response = bq_utils.query(q)
        actual_rows = bq_utils.response2rows(response)
        self.assertCountEqual(expected_rows, actual_rows)

    # TODO Figure out a good way to test query structure
    # One option may be for each query under test to generate an abstract syntax tree
    # (using e.g. https://github.com/andialbrecht/sqlparse) and compare it to an expected tree fragment.
    # Functions below are for reference

    def convert_ehr_person_to_observation(self, person_row):
        obs_rows = []
        dob_row = {
            'observation_concept_id': eu_constants.DOB_CONCEPT_ID,
            'observation_source_value': None,
            'value_as_string': person_row['birth_datetime'],
            'person_id': person_row['person_id'],
            'observation_date': person_row['birth_date'],
            'value_as_concept_id': None
        }
        gender_row = {
            'observation_concept_id': eu_constants.GENDER_CONCEPT_ID,
            'observation_source_value': person_row['gender_source_value'],
            'value_as_string': None,
            'person_id': person_row['person_id'],
            'observation_date': person_row['birth_date'],
            'value_as_concept_id': person_row['gender_concept_id']
        }
        race_row = {
            'observation_concept_id': eu_constants.RACE_CONCEPT_ID,
            'observation_source_value': person_row['race_source_value'],
            'value_as_string': None,
            'person_id': person_row['person_id'],
            'observation_date': person_row['birth_date'],
            'value_as_concept_id': person_row['race_concept_id']
        }
        ethnicity_row = {
            'observation_concept_id': eu_constants.ETHNICITY_CONCEPT_ID,
            'observation_source_value': person_row['ethnicity_source_value'],
            'value_as_string': None,
            'person_id': person_row['person_id'],
            'observation_date': person_row['birth_date'],
            'value_as_concept_id': person_row['ethnicity_concept_id']
        }
        obs_rows.extend([dob_row, gender_row, race_row, ethnicity_row])
        return obs_rows

    @mock.patch('bq_utils.get_hpo_info')
    @mock.patch('resources.CDM_TABLES', [
        common.PERSON, common.OBSERVATION, common.LOCATION, common.CARE_SITE,
        common.VISIT_OCCURRENCE, common.VISIT_DETAIL
    ])
    @mock.patch('cdm.tables_to_map')
    def test_ehr_person_to_observation(self, mock_tables_map, mock_hpo_info):
        # ehr person table converts to observation records
        self._load_datasets()
        mock_tables_map.return_value = [
            common.OBSERVATION, common.LOCATION, common.CARE_SITE,
            common.VISIT_OCCURRENCE, common.VISIT_DETAIL
        ]

        mock_hpo_info.return_value = [{
            'hpo_id': hpo_id
        } for hpo_id in self.hpo_ids]

        # perform ehr union
        ehr_union.main(self.input_dataset_id, self.output_dataset_id,
                       self.project_id)

        person_query = '''
            SELECT
                p.person_id,
                gender_concept_id,
                gender_source_value,
                race_concept_id,
                race_source_value,
                CAST(birth_datetime AS STRING) AS birth_datetime,
                ethnicity_concept_id,
                ethnicity_source_value,
                EXTRACT(DATE FROM birth_datetime) AS birth_date
            FROM {output_dataset_id}.unioned_ehr_person p
            JOIN {output_dataset_id}._mapping_person AS mp
                ON mp.person_id = p.person_id
            '''.format(output_dataset_id=self.output_dataset_id)
        person_response = bq_utils.query(person_query)
        person_rows = bq_utils.response2rows(person_response)

        # construct dicts of expected values
        expected = []
        for person_row in person_rows:
            expected.extend(self.convert_ehr_person_to_observation(person_row))

        # query for observation table records
        query = '''
            SELECT person_id,
                    observation_concept_id,
                    value_as_concept_id,
                    value_as_string,
                    observation_source_value,
                    observation_date
            FROM {output_dataset_id}.unioned_ehr_observation AS obs
            WHERE obs.observation_concept_id IN ({gender_concept_id},{race_concept_id},{dob_concept_id},
            {ethnicity_concept_id})
            '''

        obs_query = query.format(
            output_dataset_id=self.output_dataset_id,
            gender_concept_id=eu_constants.GENDER_CONCEPT_ID,
            race_concept_id=eu_constants.RACE_CONCEPT_ID,
            dob_concept_id=eu_constants.DOB_CONCEPT_ID,
            ethnicity_concept_id=eu_constants.ETHNICITY_CONCEPT_ID)
        obs_response = bq_utils.query(obs_query)
        obs_rows = bq_utils.response2rows(obs_response)
        actual = obs_rows

        self.assertCountEqual(expected, actual)

    @mock.patch('bq_utils.get_hpo_info')
    @mock.patch('resources.CDM_TABLES', [
        common.PERSON, common.OBSERVATION, common.LOCATION, common.CARE_SITE,
        common.VISIT_OCCURRENCE, common.VISIT_DETAIL
    ])
    @mock.patch('cdm.tables_to_map')
    def test_ehr_person_to_observation_counts(self, mock_tables_map,
                                              mock_hpo_info):
        self._load_datasets()
        mock_tables_map.return_value = [
            common.OBSERVATION, common.LOCATION, common.CARE_SITE,
            common.VISIT_OCCURRENCE, common.VISIT_DETAIL
        ]

        mock_hpo_info.return_value = [{
            'hpo_id': hpo_id
        } for hpo_id in self.hpo_ids]

        # perform ehr union
        ehr_union.main(self.input_dataset_id, self.output_dataset_id,
                       self.project_id)

        q_person = '''
                    SELECT p.*
                    FROM {output_dataset_id}.unioned_ehr_person AS p
                    JOIN {output_dataset_id}._mapping_person AS mp
                        ON mp.person_id = p.person_id
                    '''.format(output_dataset_id=self.output_dataset_id)
        person_response = bq_utils.query(q_person)
        person_rows = bq_utils.response2rows(person_response)
        q_observation = '''
                    SELECT *
                    FROM {output_dataset_id}.unioned_ehr_observation
                    WHERE observation_type_concept_id = 38000280
                    '''.format(output_dataset_id=self.output_dataset_id)
        # observation should contain 4 records of type EHR per person per hpo
        expected = len(person_rows) * 4
        observation_response = bq_utils.query(q_observation)
        observation_rows = bq_utils.response2rows(observation_response)
        actual = len(observation_rows)
        self.assertEqual(
            actual, expected,
            'Expected %s EHR person records in observation but found %s' %
            (expected, actual))

    def _test_table_hpo_subquery(self):
        # person is a simple select, no ids should be mapped
        person = ehr_union.table_hpo_subquery('person',
                                              hpo_id=NYC_HPO_ID,
                                              input_dataset_id='input',
                                              output_dataset_id='output')

        # _mapping_visit_occurrence(src_table_id, src_visit_occurrence_id, visit_occurrence_id)
        # visit_occurrence_id should be mapped
        visit_occurrence = ehr_union.table_hpo_subquery(
            'visit_occurrence',
            hpo_id=NYC_HPO_ID,
            input_dataset_id='input',
            output_dataset_id='output')

        # visit_occurrence_id and condition_occurrence_id should be mapped
        condition_occurrence = ehr_union.table_hpo_subquery(
            'condition_occurrence',
            hpo_id=NYC_HPO_ID,
            input_dataset_id='input',
            output_dataset_id='output')

    def get_table_hpo_subquery_error(self, table, dataset_in, dataset_out):
        subquery = ehr_union.table_hpo_subquery(table, NYC_HPO_ID, dataset_in,
                                                dataset_out)

        # moz-sql-parser does not support the ROW_NUMBER() OVER() analytic function, so we remove
        # that clause from the returned query so that the parser can parse the query without erroring out.

        subquery = re.sub(
            r",\s+ROW_NUMBER\(\) OVER \(PARTITION BY nm\..+?_id\) AS row_num",
            " ", subquery)
        # offset is used as a column name in the note_nlp table.
        # Although BigQuery does not throw any errors for this, moz_sql_parser identifies it as a SQL keyword,
        # so the column is quoted here in the test script as a workaround.
        if 'offset,' in subquery:
            subquery = subquery.replace('offset,', '"offset",')
        stmt = moz_sql_parser.parse(subquery)

        # Sanity check it is a select statement
        if 'select' not in stmt:
            return SUBQUERY_FAIL_MSG.format(expr='query type',
                                            table=table,
                                            expected='select',
                                            actual=str(stmt),
                                            subquery=subquery)

        # Input table should be first in FROM expression
        actual_from = first_or_none(
            dpath.util.values(stmt, 'from/0/value/from/value') or
            dpath.util.values(stmt, 'from'))
        expected_from = dataset_in + '.' + bq_utils.get_table_id(
            NYC_HPO_ID, table)
        if expected_from != actual_from:
            return SUBQUERY_FAIL_MSG.format(expr='first object in FROM',
                                            table=table,
                                            expected=expected_from,
                                            actual=actual_from,
                                            subquery=subquery)

        # Ensure all key fields (primary or foreign) yield joins with their associated mapping tables
        # Note: ordering of joins in the subquery is assumed to be consistent with field order in the json file
        fields = resources.fields_for(table)
        id_field = table + '_id'
        key_ind = 0
        expected_join = None
        actual_join = None
        for field in fields:
            if field['name'] in self.mapped_fields:
                # key_ind += 1  # TODO use this increment when we generalize solution for all foreign keys
                if field['name'] == id_field:
                    # Primary key, mapping table associated with this one should be INNER joined
                    key_ind += 1
                    expr = 'inner join on primary key'
                    actual_join = first_or_none(
                        dpath.util.values(stmt, 'from/%s/join/value' % key_ind))
                    expected_join = dataset_out + '.' + ehr_union.mapping_table_for(
                        table)
                elif field['name'] in self.implemented_foreign_keys:
                    # Foreign key, mapping table associated with the referenced table should be LEFT joined
                    key_ind += 1
                    expr = 'left join on foreign key'
                    # The visit_detail table has its 'visit_occurrence' column after 'care_site', unlike
                    # other cdm tables, where 'visit_occurrence' comes before the other foreign keys.
                    # The test expects the same order as the other cdm tables, so the expected query has
                    # 'visit_occurrence' before 'care_site'. The following reorder is required to match
                    # the sequence of the actual query.
                    if table == 'visit_detail' and key_ind == 2:
                        stmt['from'][2], stmt['from'][3] = stmt['from'][
                            3], stmt['from'][2]
                    actual_join = first_or_none(
                        dpath.util.values(stmt,
                                          'from/%s/left join/value' % key_ind))
                    joined_table = field['name'].replace('_id', '')
                    expected_join = dataset_out + '.' + ehr_union.mapping_table_for(
                        joined_table)
                if expected_join != actual_join:
                    return SUBQUERY_FAIL_MSG.format(expr=expr,
                                                    table=table,
                                                    expected=expected_join,
                                                    actual=actual_join,
                                                    subquery=subquery)

    def test_hpo_subquery(self):
        input_dataset_id = 'input'
        output_dataset_id = 'output'
        subquery_fails = []

        # Key fields should be populated using associated mapping tables
        for table in resources.CDM_TABLES:
            # Exempt the person table from the hpo subquery check
            if table != common.PERSON:
                subquery_fail = self.get_table_hpo_subquery_error(
                    table, input_dataset_id, output_dataset_id)
                if subquery_fail is not None:
                    subquery_fails.append(subquery_fail)

        if len(subquery_fails) > 0:
            self.fail('\n\n'.join(subquery_fails))

    def tearDown(self):
        self._empty_hpo_buckets()
        test_util.delete_all_tables(self.input_dataset_id)
        test_util.delete_all_tables(self.output_dataset_id)
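
The fact_relationship assertions in test_union_ehr above rely on ehr_union.get_hpo_offsets: each hpo gets a fixed numeric offset that is added to its source ids so rows from different sites cannot collide in the unioned tables. A toy sketch of that idea, with an illustrative constant and helpers that are not the project's implementation:

HPO_OFFSET = 100_000_000  # illustrative constant, not the value used by ehr_union


def get_hpo_offsets(hpo_ids: list) -> dict:
    """Assign each hpo a distinct, non-overlapping id offset."""
    return {hpo_id: (index + 1) * HPO_OFFSET for index, hpo_id in enumerate(hpo_ids)}


def remap_ids(rows: list, id_field: str, offset: int) -> list:
    """Shift a site's primary keys by its offset so the unioned ids stay unique."""
    return [{**row, id_field: row[id_field] + offset} for row in rows]


# e.g. pitt's fact_id 5 is expected to appear in the union as 5 + pitt's offset
offsets = get_hpo_offsets(['nyc', 'pitt'])
remapped = remap_ids([{'fact_id_1': 5}], 'fact_id_1', offsets['pitt'])
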
class RequiredLabsTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        print('**************************************************************')
        print(cls.__name__)
        print('**************************************************************')

    def setUp(self):
        self.hpo_bucket = gcs_utils.get_hpo_bucket(FAKE_HPO_ID)
        self.project_id = app_identity.get_application_id()
        self.dataset_id = bq_utils.get_dataset_id()
        self.rdr_dataset_id = bq_utils.get_rdr_dataset_id()
        self.folder_prefix = '2019-01-01/'
        test_util.delete_all_tables(self.dataset_id)

        self.storage_client = StorageClient(self.project_id)
        self.storage_client.empty_bucket(self.hpo_bucket)

        self.client = bq.get_client(self.project_id)

        mock_get_hpo_name = mock.patch('validation.main.get_hpo_name')

        self.mock_get_hpo_name = mock_get_hpo_name.start()
        self.mock_get_hpo_name.return_value = 'Fake HPO'
        self.addCleanup(mock_get_hpo_name.stop)

        self._load_data()

    def tearDown(self):
        test_util.delete_all_tables(bq_utils.get_dataset_id())
        self.storage_client.empty_bucket(self.hpo_bucket)

    def _load_data(self):

        # Load measurement_concept_sets
        required_labs.load_measurement_concept_sets_table(
            project_id=self.project_id, dataset_id=self.dataset_id)
        # Load measurement_concept_sets_descendants
        required_labs.load_measurement_concept_sets_descendants_table(
            project_id=self.project_id, dataset_id=self.dataset_id)

        # Load measurement.csv into the dataset in advance for the other integration tests
        ehr_measurement_result = bq_utils.load_table_from_csv(
            project_id=self.project_id,
            dataset_id=self.dataset_id,
            table_name=bq_utils.get_table_id(FAKE_HPO_ID, common.MEASUREMENT),
            csv_path=test_util.FIVE_PERSONS_MEASUREMENT_CSV,
            fields=resources.fields_for(common.MEASUREMENT))
        bq_utils.wait_on_jobs(
            [ehr_measurement_result['jobReference']['jobId']])

    def test_check_and_copy_tables(self):
        """
        Test to ensure all the necessary tables for required_labs.py are copied and/or created
        """
        # Preconditions
        descendants_table_name = f'{self.project_id}.{self.dataset_id}.{MEASUREMENT_CONCEPT_SETS_DESCENDANTS_TABLE}'
        concept_sets_table_name = f'{self.project_id}.{self.dataset_id}.{MEASUREMENT_CONCEPT_SETS_TABLE}'
        concept_table_name = f'{self.project_id}.{self.dataset_id}.{common.CONCEPT}'
        concept_ancestor_table_name = f'{self.project_id}.{self.dataset_id}.{common.CONCEPT_ANCESTOR}'

        actual_descendants_table = self.client.get_table(
            descendants_table_name)
        actual_concept_sets_table = self.client.get_table(
            concept_sets_table_name)
        actual_concept_table = self.client.get_table(concept_table_name)
        actual_concept_ancestor_table = self.client.get_table(
            concept_ancestor_table_name)

        # Test
        required_labs.check_and_copy_tables(self.project_id, self.dataset_id)

        # Post conditions
        self.assertIsNotNone(actual_descendants_table.created)
        self.assertIsNotNone(actual_concept_sets_table.created)
        self.assertIsNotNone(actual_concept_table.created)
        self.assertIsNotNone(actual_concept_ancestor_table.created)

    def test_measurement_concept_sets_table(self):

        query = sql_wrangle.qualify_tables(
            '''SELECT * FROM {dataset_id}.{table_id}'''.format(
                dataset_id=self.dataset_id,
                table_id=MEASUREMENT_CONCEPT_SETS_TABLE))
        response = bq_utils.query(query)

        actual_fields = [{
            'name': field['name'].lower(),
            'type': field['type'].lower()
        } for field in response['schema']['fields']]

        expected_fields = [{
            'name': field['name'].lower(),
            'type': field['type'].lower()
        } for field in resources.fields_for(MEASUREMENT_CONCEPT_SETS_TABLE)]

        self.assertListEqual(expected_fields, actual_fields)

        measurement_concept_sets_table_path = os.path.join(
            resources.resource_files_path,
            MEASUREMENT_CONCEPT_SETS_TABLE + '.csv')
        expected_total_rows = len(
            resources.csv_to_list(measurement_concept_sets_table_path))
        self.assertEqual(expected_total_rows, int(response['totalRows']))

    def test_load_measurement_concept_sets_descendants_table(self):

        query = sql_wrangle.qualify_tables(
            """SELECT * FROM {dataset_id}.{table_id}""".format(
                dataset_id=self.dataset_id,
                table_id=MEASUREMENT_CONCEPT_SETS_DESCENDANTS_TABLE))
        response = bq_utils.query(query)

        actual_fields = [{
            'name': field['name'].lower(),
            'type': field['type'].lower()
        } for field in response['schema']['fields']]

        expected_fields = [{
            'name': field['name'].lower(),
            'type': field['type'].lower()
        } for field in resources.fields_for(
            MEASUREMENT_CONCEPT_SETS_DESCENDANTS_TABLE)]

        self.assertListEqual(expected_fields, actual_fields)

    def test_get_lab_concept_summary_query(self):
        summary_query = required_labs.get_lab_concept_summary_query(
            FAKE_HPO_ID)
        summary_response = bq_utils.query(summary_query)
        summary_rows = bq_utils.response2rows(summary_response)
        submitted_labs = [
            row for row in summary_rows
            if row['measurement_concept_id_exists'] == 1
        ]
        actual_total_labs = summary_response['totalRows']

        # Count the total number of labs required. This number should equal the total number of rows in the
        # results generated by get_lab_concept_summary_query, including both submitted and missing labs.
        unique_ancestor_concept_query = sql_wrangle.qualify_tables(
            """SELECT DISTINCT ancestor_concept_id FROM `{project_id}.{dataset_id}.{table_id}`"""
            .format(project_id=self.project_id,
                    dataset_id=self.dataset_id,
                    table_id=MEASUREMENT_CONCEPT_SETS_DESCENDANTS_TABLE))
        unique_ancestor_concept_response = bq_utils.query(
            unique_ancestor_concept_query)
        expected_total_labs = unique_ancestor_concept_response['totalRows']

        # Count the number of labs in the measurement table. This number should equal the number of labs
        # submitted by the fake site
        unique_measurement_concept_id_query = '''
                SELECT
                  DISTINCT c.ancestor_concept_id
                FROM
                  `{project_id}.{dataset_id}.{measurement_concept_sets_descendants}` AS c
                JOIN
                  `{project_id}.{dataset_id}.{measurement}` AS m
                ON
                  c.descendant_concept_id = m.measurement_concept_id
                '''.format(project_id=self.project_id,
                           dataset_id=self.dataset_id,
                           measurement_concept_sets_descendants=
                           MEASUREMENT_CONCEPT_SETS_DESCENDANTS_TABLE,
                           measurement=bq_utils.get_table_id(
                               FAKE_HPO_ID, common.MEASUREMENT))

        unique_measurement_concept_id_response = bq_utils.query(
            unique_measurement_concept_id_query)
        unique_measurement_concept_id_total_labs = unique_measurement_concept_id_response[
            'totalRows']

        self.assertEqual(int(expected_total_labs),
                         int(actual_total_labs),
                         msg='Compare the total number of labs')
        self.assertEqual(int(unique_measurement_concept_id_total_labs),
                         len(submitted_labs),
                         msg='Compare the number of labs submitted in the measurement')
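
test_get_lab_concept_summary_query checks two counts: every distinct ancestor_concept_id in the descendants table is a required lab, and a required lab counts as submitted when any of its descendant concepts shows up in the site's measurement table. A small in-memory sketch of that bookkeeping, with illustrative row dicts rather than real query results:

def summarize_required_labs(descendant_rows: list, measurement_rows: list) -> dict:
    """Map each required ancestor_concept_id to 1 if any descendant was measured, else 0."""
    measured = {row['measurement_concept_id'] for row in measurement_rows}
    summary = {}
    for row in descendant_rows:
        ancestor = row['ancestor_concept_id']
        found = int(row['descendant_concept_id'] in measured)
        summary[ancestor] = max(summary.get(ancestor, 0), found)
    return summary


# two required labs, only the first has a submitted descendant measurement
descendants = [{'ancestor_concept_id': 1, 'descendant_concept_id': 10},
               {'ancestor_concept_id': 2, 'descendant_concept_id': 20}]
measurements = [{'measurement_concept_id': 10}]
summary = summarize_required_labs(descendants, measurements)
assert len(summary) == 2 and sum(summary.values()) == 1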