def test_get_all_table_metadata_from_information_schema(
            self, mock_settings) -> None:
        """Check the first table yielded by the bulk information_schema pull.

        ``execute`` on the engine is replaced by a mock driven by
        ``presto_engine_execute_side_effect``. The first ``TableMetadata``
        produced by the generator is compared, by repr, with one assembled
        from the two ``MOCK_INFORMATION_SCHEMA_RESULT_*`` fixtures.
        """
        # Show the complete diff if the two reprs ever disagree.
        self.maxDiff = None

        self.engine.init(self.conf)
        self.engine.execute = MagicMock(
            side_effect=presto_engine_execute_side_effect
        )

        fixture_rows = (
            MOCK_INFORMATION_SCHEMA_RESULT_1,
            MOCK_INFORMATION_SCHEMA_RESULT_2,
        )
        expected_columns = [
            ColumnMetadata(
                name=fixture_row['col_name'],
                description=fixture_row['col_description'],
                col_type=fixture_row['col_type'],
                sort_order=fixture_row['col_sort_order'],
                is_partition_column=None,
            )
            for fixture_row in fixture_rows
        ]
        expected = TableMetadata(
            database=MOCK_DATABASE_NAME,
            cluster=MOCK_CLUSTER_NAME,
            schema=MOCK_SCHEMA_NAME,
            name=MOCK_TABLE_NAME,
            columns=expected_columns,
            is_view=bool(MOCK_INFORMATION_SCHEMA_RESULT_1['is_view']),
        )

        table_iter = \
            self.engine.get_all_table_metadata_from_information_schema(
                cluster=MOCK_CLUSTER_NAME)
        first_table = next(table_iter)

        self.assertEqual(repr(first_table), repr(expected))
    def test_table_metadata_extraction_with_single_result(
            self, mock_settings) -> None:
        """Check what ``PrestoLoopExtractor.extract`` returns for one table.

        Table-metadata extraction is switched on in a copy of the shared
        config, ``execute`` is mocked with
        ``presto_engine_execute_side_effect``, and the extracted record is
        compared, by repr, with a ``TableMetadata`` built from
        ``MOCK_COLUMN_RESULT``.
        """
        loop_extractor = PrestoLoopExtractor()
        extractor_conf = self.conf.copy()
        extractor_conf.put('is_table_metadata_enabled', True)
        loop_extractor.init(extractor_conf)
        loop_extractor.execute = MagicMock(
            side_effect=presto_engine_execute_side_effect)

        extracted = loop_extractor.extract()

        # The third field of the mocked column row carries the partition
        # marker.
        expected_column = ColumnMetadata(
            name=MOCK_COLUMN_RESULT[0],
            description=MOCK_COLUMN_RESULT[3],
            col_type=MOCK_COLUMN_RESULT[1],
            sort_order=0,
            is_partition_column=MOCK_COLUMN_RESULT[2] == 'partition key',
        )
        expected = TableMetadata(
            database=loop_extractor._database,
            cluster=None,
            schema=MOCK_SCHEMA_NAME,
            name=MOCK_TABLE_NAME,
            columns=[expected_column],
        )

        self.assertEqual(repr(extracted), repr(expected))
Example #3
0
    def get_all_table_metadata_from_information_schema(
        self,
        cluster: Optional[str] = None,
        where_clause_suffix: str = '',
    ):
        """Yield one ``TableMetadata`` per table listed in information_schema.

        A single query joins ``information_schema.columns`` with
        ``information_schema.views`` and the resulting rows are grouped into
        tables with ``self._get_table_key``.

        Args:
            cluster: Optional cluster name. When given it is prepended (with a
                trailing dot) to both ``information_schema`` references;
                otherwise the unqualified names are used and
                ``self._default_cluster_name`` is recorded on the yielded
                metadata.
            where_clause_suffix: Raw SQL text appended verbatim after the
                join. It is interpolated without escaping, so it must come
                from trusted configuration.

        Yields:
            ``TableMetadata`` built from the rows of each group; schema, name,
            description and view flag are taken from the group's last row.
        """

        unformatted_query = """
        SELECT
          a.table_catalog AS catalog
          , a.table_schema AS schema
          , a.table_name AS name
          , NULL AS description
          , a.column_name AS col_name
          , a.ordinal_position as col_sort_order
          , IF(a.extra_info = 'partition key', 1, 0) AS is_partition_col
          , a.comment AS col_description
          , a.data_type AS col_type
          , IF(b.table_name is not null, 1, 0) AS is_view
        FROM {cluster_prefix}information_schema.columns a
        LEFT JOIN {cluster_prefix}information_schema.views b
            ON a.table_catalog = b.table_catalog
            and a.table_schema = b.table_schema
            and a.table_name = b.table_name
        {where_clause_suffix}
        """

        # The original message concatenated two literals without a separating
        # space and logged "frominformation_schema".
        LOGGER.info(
            'Pulling all table metadata in bulk from '
            'information_schema in cluster name: %s', cluster)

        cluster_prefix = cluster + '.' if cluster is not None else ''

        formatted_query = unformatted_query.format(
            cluster_prefix=cluster_prefix,
            where_clause_suffix=where_clause_suffix)

        LOGGER.info('SQL for presto: %s', formatted_query)

        query_results = self.execute(formatted_query,
                                     is_dict_return_enabled=True)

        # groupby only merges *adjacent* rows that share a key, and the query
        # above has no ORDER BY.
        # NOTE(review): confirm rows for one table always arrive contiguously,
        # or supply an ordering through where_clause_suffix.
        for _, group in groupby(query_results, self._get_table_key):
            # groupby never produces an empty group, so rows[-1] is safe.
            rows = list(group)
            last_row = rows[-1]

            # NOTE(review): the query selects is_partition_col, but it is not
            # forwarded to ColumnMetadata here.
            columns = [
                ColumnMetadata(row['col_name'], row['col_description'],
                               row['col_type'], row['col_sort_order'])
                for row in rows
            ]

            yield TableMetadata(
                self._database,
                cluster or self._default_cluster_name,
                last_row['schema'],
                last_row['name'],
                last_row['description'],
                columns,
                is_view=bool(last_row['is_view']),
            )
Example #4
0
    def test_transformed_record_contains_components(self):
        """
        The markdown blob produced by ``MarkdownTransformer`` must contain the
        database, cluster, schema, table, column name and column description
        of the record it was given.
        """
        record = TableMetadata(
            database=DATABASE,
            cluster=CLUSTER,
            schema=SCHEMA,
            name=TABLE,
            columns=[
                ColumnMetadata(
                    name=COLUMN,
                    col_type=int,
                    sort_order=0,
                    description=COLUMN_DESCRIPTION,
                ),
            ],
        )
        expected_fragments = (
            DATABASE,
            CLUSTER,
            SCHEMA,
            TABLE,
            COLUMN,
            COLUMN_DESCRIPTION,
        )

        markdown_transformer = MarkdownTransformer()
        markdown_transformer.init(self._conf)
        blob = markdown_transformer.transform(record).markdown_blob
        markdown_transformer.close()

        every_fragment_present = all(
            fragment in blob for fragment in expected_fragments)

        self.assertTrue(every_fragment_present)
Example #5
0
    def _retrieve_tables(self, dataset) -> Any:
        """Yield a ``TableMetadata`` for the tables listed in ``dataset``.

        Each page from ``self._page_table_list_results`` is scanned for a
        ``'tables'`` entry. Every listed table is fetched in full through
        ``self.bigquery_service`` and its schema fields are flattened with
        ``self._iterate_over_cols``. Tables that ``self._is_sharded_table``
        recognises are reported once per prefix, under the prefix name.
        """
        for page in self._page_table_list_results(dataset):
            if 'tables' not in page:
                continue

            for listed_table in page['tables']:
                table_ref = listed_table['tableReference']
                reported_name = table_ref['tableId']

                # Tables whose id ends in eight digits are date-range shards
                # that the BigQuery UI groups together (ga_sessions_20190101,
                # ga_sessions_20190102, ...). One schema definition per group
                # is enough, so later shards of a known prefix are skipped.
                if self._is_sharded_table(reported_name):
                    shard_prefix = reported_name[
                        :-BigQueryMetadataExtractor.DATE_LENGTH]
                    if shard_prefix in self.grouped_tables:
                        continue

                    self.grouped_tables.add(shard_prefix)
                    reported_name = shard_prefix

                # The lookup uses the original table id, not the shard
                # prefix.
                table_request = self.bigquery_service.tables().get(
                    projectId=table_ref['projectId'],
                    datasetId=table_ref['datasetId'],
                    tableId=table_ref['tableId'])
                table_details = table_request.execute(
                    num_retries=BigQueryMetadataExtractor.NUM_RETRIES)

                # Other table metadata (partitioning, data location, create
                # and modify times) is not extracted here yet.
                columns = []
                # Some tables carry no schema, or a schema without fields.
                if 'schema' in table_details \
                        and 'fields' in table_details['schema']:
                    column_count = 0
                    for field in table_details['schema']['fields']:
                        column_count = self._iterate_over_cols(
                            '', field, columns, column_count + 1)

                yield TableMetadata(
                    database='bigquery',
                    cluster=table_ref['projectId'],
                    schema=table_ref['datasetId'],
                    name=reported_name,
                    description=table_details.get('description', ''),
                    columns=columns,
                    is_view=table_details['type'] == 'VIEW',
                )
Example #6
0
    def get_table_metadata(self,
                           schema: str,
                           table: str,
                           cluster: Optional[str] = None,
                           is_view_query_enabled: Optional[bool] = False):
        """Build a ``TableMetadata`` for one table from ``show columns``.

        Args:
            schema: Schema that holds the table.
            table: Table name.
            cluster: Optional cluster name. It is passed to
                ``self._get_full_schema_address`` and, when set, also
                qualifies the ``information_schema`` view lookup.
            is_view_query_enabled: When truthy, an extra query against
                ``information_schema.tables`` decides whether the table is a
                view; otherwise ``is_view`` is always ``False``.

        Returns:
            ``TableMetadata`` with one ``ColumnMetadata`` per row returned by
            ``show columns`` (``sort_order`` is the zero-based row position)
            and ``description`` set to ``None``.
        """
        # Format table and schema addresses for queries.
        full_schema_address = self._get_full_schema_address(cluster, schema)
        full_table_address = '{}.{}'.format(full_schema_address, table)

        # Execute query that gets column type + partition information.
        columns_query = 'show columns in {}'.format(full_table_address)
        column_query_results = self.execute(columns_query, has_header=True)
        # With has_header=True the first item holds the field names.
        column_query_field_names = next(column_query_results)
        columns = []
        for i, column_query_result in enumerate(column_query_results):
            column_dict = \
                dict(zip(column_query_field_names, column_query_result))
            columns.append(
                ColumnMetadata(
                    name=column_dict['Column'],
                    description=column_dict['Comment'],
                    col_type=column_dict['Type'],
                    sort_order=i,
                    is_partition_column=column_dict['Extra'] ==
                    'partition key',
                ))

        if is_view_query_enabled:
            # Qualify information_schema with the cluster so the view lookup
            # targets the same place as the column query above. Without a
            # cluster the query text is unchanged.
            cluster_prefix = '{}.'.format(cluster) if cluster else ''

            # Execute query that returns if table is a view.
            view_query = """
                select table_type
                from {cluster_prefix}information_schema.tables
                where table_schema='{table_schema}'
                  and table_name='{table_name}'
                """.format(cluster_prefix=cluster_prefix,
                           table_schema=schema,
                           table_name=table)
            view_query_results = self.execute(view_query, has_header=False)

            # An empty result must not leak StopIteration: inside a generator
            # caller it would surface as RuntimeError (PEP 479). A table that
            # is not listed is treated as "not a view".
            first_row = next(view_query_results, None)
            is_view = first_row is not None and first_row[0] == 'VIEW'
        else:
            is_view = False

        return TableMetadata(
            database=self._database,
            cluster=cluster,
            schema=schema,
            name=table,
            description=None,
            columns=columns,
            is_view=is_view,
        )
Example #7
0
    def _get_extract_iter(self):
        """Yield a ``TableMetadata`` for every record in ``self.results``.

        ``self.results`` is filled on first use by running
        ``self._execute_query`` in a read transaction and is reused on later
        calls. A column is flagged as a partition column when its name
        matches the ``partition_key`` of one of the record's watermarks.
        """
        with self.driver.session() as session:
            if not hasattr(self, 'results'):
                self.results = session.read_transaction(self._execute_query)

            for record in self.results:
                # Partition keys named by the record's watermarks.
                partition_keys = [
                    watermark['partition_key']
                    for watermark in record['watermarks']
                ]

                # The four column attributes arrive as parallel lists;
                # zip_longest pads the shorter ones with None.
                column_rows = zip_longest(
                    record['column_names'],
                    record['column_descriptions'],
                    record['column_types'],
                    record['column_sort_orders'],
                )
                columns = [
                    ColumnMetadata(
                        name=col_name,
                        description=col_description,
                        col_type=col_type,
                        sort_order=col_sort_order,
                        is_partition_column=col_name in partition_keys,
                    )
                    for col_name, col_description, col_type, col_sort_order
                    in column_rows
                ]

                yield TableMetadata(
                    database=record['database'],
                    cluster=record['cluster'],
                    schema=record['schema'],
                    name=record['name'],
                    description=record['description'],
                    columns=columns,
                    is_view=record['is_view'],
                    tags=record['tags'],
                )
    def test_load_no_cluster(self):
        """Load a record whose cluster is ``None`` and read it back.

        The file at ``<database>/<schema>.<table>.md`` under the test
        artifacts directory must hold exactly the record's markdown blob.
        """
        table_record = TableMetadata(
            database='mock_database',
            cluster=None,
            schema='mock_schema',
            name='mock_table',
            markdown_blob='Test',
        )

        metaframe_loader = MetaframeLoader()
        metaframe_loader.init(self._conf)
        metaframe_loader.load(table_record)
        metaframe_loader.close()

        expected_path = \
            './.test_artifacts/mock_database/mock_schema.mock_table.md'
        with open(expected_path, 'r') as artifact:
            file_contents = artifact.read()
        print(file_contents)

        self.assertEqual(file_contents, table_record.markdown_blob)