def test_alias_double_inner(self):
    # type: () -> None
    """
    Pin the resolved column list for a query that reads through two levels
    of nested sub-queries.

    Every expected column traces back to ETL.HIVE_QUERY_LOGS and carries the
    outer sub-query alias EXPR_QRY together with the outermost column alias.
    DS is expected as an OrTable of that table and ``None``.
    """
    query = """\
SELECT cluster AS cluster,
       date_trunc('day', CAST(ds AS TIMESTAMP)) AS __timestamp,
       sum(p90_time) AS sum__p90_time
FROM
  (select ds,
          cluster,
          approx_percentile(latency_mins, .50) as p50_time,
          approx_percentile(latency_mins, .60) as p60_time,
          approx_percentile(latency_mins, .70) as p70_time,
          approx_percentile(latency_mins, .75) as p75_time,
          approx_percentile(latency_mins, .80) as p80_time,
          approx_percentile(latency_mins, .90) as p90_time,
          approx_percentile(latency_mins, .95) as p95_time
   from
     (SELECT ds,
             cluster_name as cluster,
             query_id,
             date_diff('second',query_starttime,query_endtime)/60.0 as latency_mins
      FROM etl.hive_query_logs
      WHERE date(ds) > date_add('day', -60, current_date)
        AND environment = 'production'
        AND operation_name = 'QUERY' )
   group by ds,
            cluster
   order by ds) AS expr_qry
WHERE ds >= '2018-03-30 00:00:00'
  AND ds <= '2018-05-29 23:29:57'
GROUP BY cluster,
         date_trunc('day', CAST(ds AS TIMESTAMP))
ORDER BY sum__p90_time DESC
LIMIT 5000
"""
    actual = ColumnUsageProvider.get_columns(query)

    def _logs_table():
        # A fresh instance per expected column, mirroring the original literals.
        return Table(name='HIVE_QUERY_LOGS', schema='ETL', alias='EXPR_QRY')

    expected = [
        Column(name='CLUSTER_NAME', table=_logs_table(), col_alias='CLUSTER'),
        Column(name='DS',
               table=OrTable(tables=[_logs_table(), None]),
               col_alias='__TIMESTAMP'),
        Column(name='QUERY_ENDTIME', table=_logs_table(),
               col_alias='SUM__P90_TIME'),
    ]
    self.assertEqual(repr(expected), repr(actual))
 def test_inner_sql_col_alias(self):
     # type: () -> None
     """
     A qualified reference (TMP1.A) and an outer use of a sub-query column
     alias (F, defined as B AS F) both resolve to table FOOBAR via alias TMP1.
     """
     sql = 'SELECT TMP1.A, F FROM (SELECT A, B AS F, C FROM FOOBAR) AS TMP1'
     actual = ColumnUsageProvider.get_columns(sql)
     expected = [
         Column(name=col_name,
                table=Table(name='FOOBAR', schema=None, alias='TMP1'),
                col_alias=alias)
         for col_name, alias in (('A', None), ('B', 'F'))
     ]
     self.assertEqual(repr(expected), repr(actual))
 def test_join_with_alias(self):
     # type: () -> None
     """
     Alias-qualified columns in a JOIN (FOO.A, BAR.B) are expected to map to
     their aliased base tables FOOTABLE and BARTABLE respectively.
     """
     sql = ('SELECT FOO.A, BAR.B FROM FOOTABLE AS FOO '
            'JOIN BARTABLE AS BAR ON FOO.A = BAR.A')
     actual = ColumnUsageProvider.get_columns(sql)
     foo_a = Column(name='A',
                    table=Table(name='FOOTABLE', schema=None, alias='FOO'),
                    col_alias=None)
     bar_b = Column(name='B',
                    table=Table(name='BARTABLE', schema=None, alias='BAR'),
                    col_alias=None)
     self.assertEqual(repr([foo_a, bar_b]), repr(actual))
    def test_with_schema(self):
        # type: () -> None
        """
        A schema-qualified table (scm.foobar) is expected to be reported as
        table FOOBAR in schema SCM, with names upper-cased.
        """
        actual = ColumnUsageProvider.get_columns('SELECT foo, bar FROM scm.foobar;')
        expected = [
            Column(name=col_name,
                   table=Table(name='FOOBAR', schema='SCM', alias=None),
                   col_alias=None)
            for col_name in ('FOO', 'BAR')
        ]
        self.assertEqual(repr(expected), repr(actual))
 def test_inner_sql_table_alias(self):
     # type: () -> None
     """
     Both a bare column and one qualified by the sub-query alias (temp.col2)
     are expected to resolve to FOOBAR under alias TEMP.
     """
     sql = """
     SELECT col1, temp.col2 FROM (SELECT col1, col2, col3 FROM foobar) as temp
     """
     actual = ColumnUsageProvider.get_columns(sql)
     expected = [
         Column(name=col_name,
                table=Table(name='FOOBAR', schema=None, alias='TEMP'),
                col_alias=None)
         for col_name in ('COL1', 'COL2')
     ]
     self.assertEqual(repr(expected), repr(actual))
    def exitQuerySpecification(
            self,
            ctx  # type: SqlBaseParser.QuerySpecificationContext
    ):
        # type: (...) -> None
        """
        Call back method invoked when the parser leaves a query specification.

        Each pending column collected for this query specification is passed
        to ``Column.resolve`` together with the current ``processed_cols``.
        Every column that call yields is kept, and the flattened result
        replaces ``processed_cols``. The pending list is then reset: it is
        restored from ``self._stack`` when the stack is non-empty (presumably
        the pending columns of an enclosing query — TODO confirm), otherwise
        it becomes empty. The current column is cleared.

        :param ctx: the QuerySpecification parse-tree node (not read here)
        :return: None
        """
        if LOGGER.isEnabledFor(logging.DEBUG):
            LOGGER.debug('processing_cols: {}'.format(self._processing_cols))
            LOGGER.debug('processed_cols: {}'.format(self.processed_cols))

        # Column.resolve returns an iterable, so one pending column may expand
        # into zero or more resolved columns; flatten them in order.
        self.processed_cols = [
            resolved
            for pending in self._processing_cols
            for resolved in Column.resolve(pending, self.processed_cols)
        ]
        self._processing_cols = self._stack.pop() if self._stack else []
        self._current_col = None

        if LOGGER.isEnabledFor(logging.DEBUG):
            LOGGER.debug('done processing_cols: {}'.format(
                self._processing_cols))
            LOGGER.debug('done processed_cols: {}'.format(self.processed_cols))
 def exitColumnReference(
         self,
         ctx  # type: SqlBaseParser.ColumnReferenceContext
 ):
     # type: (...) -> None
     """
     Call back method for a column referenced without a table qualifier.

     The node's full text becomes the name of a new Column with no table,
     which is stored as the current column.

     :param ctx: the ColumnReference parse-tree node
     :return: None
     """
     column_text = ctx.getText()
     self._current_col = Column(column_text)
    def test_join(self):
        # type: () -> None
        """
        Unqualified columns in a two-table JOIN are expected to resolve to an
        OrTable listing both joined tables (SCM.FOO and BAR).
        """
        actual = ColumnUsageProvider.get_columns(
            'SELECT A, B FROM scm.FOO JOIN BAR ON FOO.A = BAR.B')

        def _either_table():
            # Fresh OrTable per expected column, mirroring the original literals.
            return OrTable(tables=[
                Table(name='FOO', schema='SCM', alias=None),
                Table(name='BAR', schema=None, alias=None),
            ])

        expected = [
            Column(name=col_name, table=_either_table(), col_alias=None)
            for col_name in ('A', 'B')
        ]
        self.assertEqual(repr(expected), repr(actual))
 def exitDereference(
         self,
         ctx  # type: SqlBaseParser.DereferenceContext
 ):
     # type: (...) -> None
     """
     Call back method for a qualified column reference such as ``foo.bar``.

     The identifier after the dot becomes the column name. The text of the
     base expression before it becomes the Table name, which may be a table
     alias rather than a physical table name. The new Column is stored as
     the current column.

     :param ctx: the Dereference parse-tree node
     :return: None
     """
     column_name = ctx.identifier().getText()
     qualifier = ctx.base.getText()
     self._current_col = Column(column_name, table=Table(qualifier))
 def test_table_alias(self):
     # type: () -> None
     """
     Only the select-list ``A.*`` is expected in the output, resolved to
     FACT_RIDES through alias A. Columns used only in the ON and WHERE
     clauses are not expected.
     """
     sql = """
     SELECT A.*  FROM FACT_RIDES A LEFT JOIN DIM_VEHICLES B ON A.VEHICLE_KEY = B.VEHICLE_KEY
     WHERE B.RENTAL_PROVIDER_ID IS NOT NULL  LIMIT 100
     """
     actual = ColumnUsageProvider.get_columns(sql)
     fact_rides_star = Column(name='*',
                              table=Table(name='FACT_RIDES', schema=None, alias='A'),
                              col_alias=None)
     self.assertEqual(repr([fact_rides_star]), repr(actual))
    def exitSelectAll(
            self,
            ctx  # type: SqlBaseParser.SelectAllContext
    ):
        # type: (...) -> None
        """
        Call back method for ``*`` or ``qualifier.*`` in a select list.

        A ``*`` Column is appended to the pending columns. When the node
        carries a qualified name, its text becomes the column's Table. The
        current column is left as None afterwards.

        :param ctx: the SelectAll parse-tree node
        :return: None
        """
        star = Column('*')
        qualifier = ctx.qualifiedName()
        if qualifier:
            star.table = Table(qualifier.getText())

        self._processing_cols.append(star)
        self._current_col = None
    def exitTableName(
            self,
            ctx  # type: SqlBaseParser.TableNameContext
    ):
        # type: (...) -> None
        """
        Call back method for table name.

        Builds a Table from the node's text. For a dotted name (e.g.
        ``scm.foobar``), the last component becomes the table name and the
        one before it becomes the schema. Any earlier components are dropped
        (presumably a catalog; TODO confirm). The table is recorded as a
        ``*`` Column in ``self._current_col``.

        :param ctx: the TableName parse-tree node
        :return: None
        """
        table_name = ctx.getText()
        table = Table(table_name)
        if '.' in table_name:
            # Keep only the trailing "schema.table" pair of a dotted name.
            db_tbl = table_name.split('.')
            table = Table(db_tbl[len(db_tbl) - 1],
                          schema=db_tbl[len(db_tbl) - 2])

        # NOTE(review): the '*' column presumably stands for "all columns of
        # this relation" and is consumed by later resolution — confirm.
        self._current_col = Column('*', table=table)