示例#1
0
    def test_runs_for_hive_stats(self, mock_hive_metastore_hook):
        """Run the operator end to end with patched Presto and MySQL hooks.

        ``PrestoHook`` and ``MySqlHook`` inside the operator module are patched
        to hand back ``MockPrestoHook`` / ``MockMySqlHook`` instances. After the
        run the test checks the Presto COUNT query, the first MySQL
        ``get_records`` SELECT (runs of 2+ whitespace chars collapsed) and the
        rows passed to ``insert_rows``.
        """
        fake_mysql = MockMySqlHook()
        fake_presto = MockPrestoHook()

        module = 'airflow.providers.apache.hive.operators.hive_stats'
        presto_patch = patch(f'{module}.PrestoHook', return_value=fake_presto)
        mysql_patch = patch(f'{module}.MySqlHook', return_value=fake_mysql)
        with presto_patch, mysql_patch:
            stats_op = HiveStatsCollectionOperator(
                task_id='hive_stats_check',
                table="airflow.static_babynames_partitioned",
                partition={'ds': DEFAULT_DATE_DS},
                dag=self.dag,
            )
            stats_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

        fake_presto.get_first.assert_called_with(
            hql="SELECT COUNT(*) AS __count FROM airflow.static_babynames_partitioned "
                "WHERE ds = '2015-01-01';"
        )

        expected_select = (
            "SELECT 1 FROM hive_stats "
            "WHERE table_name='airflow.static_babynames_partitioned' "
            "AND partition_repr='{\"ds\": \"2015-01-01\"}' "
            "AND dttm='2015-01-01T00:00:00+00:00' LIMIT 1;"
        )
        # The operator's SQL is multi-line; squash indentation before comparing.
        first_get_records_call = fake_mysql.get_records.call_args_list[0]
        issued_select = first_get_records_call[0][0]
        normalized_select = re.sub(r'\s{2,}', ' ', issued_select).strip()
        assert expected_select == normalized_select

        expected_row = (
            '2015-01-01',                            # ds
            '2015-01-01T00:00:00+00:00',             # dttm
            'airflow.static_babynames_partitioned',  # table_name
            '{"ds": "2015-01-01"}',                  # partition_repr
            '',                                      # col ('' for the COUNT(*) metric)
            'count',                                 # metric
            ['val_0', 'val_1'],                      # value (presumably MockPrestoHook's get_first result)
        )
        fake_mysql.insert_rows.assert_called_with(
            table='hive_stats',
            rows=[expected_row],
            target_fields=['ds', 'dttm', 'table_name', 'partition_repr', 'col', 'metric', 'value'],
        )
示例#2
0
    def test_execute_delete_previous_runs_rows(self, mock_hive_metastore_hook,
                                               mock_presto_hook,
                                               mock_mysql_hook,
                                               mock_json_dumps):
        """When ``get_records`` reports an existing row, the operator issues a DELETE.

        The DELETE is expected exactly once and is keyed on the operator's
        table, the (mocked) ``json.dumps`` partition repr and its ``dttm``.
        """
        mock_mysql_hook.return_value.get_records.return_value = True
        metastore_table = mock_hive_metastore_hook.return_value.get_table.return_value
        metastore_table.sd.cols = [fake_col]

        operator = HiveStatsCollectionOperator(**self.kwargs)
        operator.execute(context={})

        # Whitespace inside this literal must match the operator's SQL exactly.
        expected_delete_sql = """
            DELETE FROM hive_stats
            WHERE
                table_name='{table}' AND
                partition_repr='{partition}' AND
                dttm='{dttm}';
            """.format(
            table=operator.table,
            partition=mock_json_dumps.return_value,
            dttm=operator.dttm,
        )
        mock_mysql_hook.return_value.run.assert_called_once_with(expected_delete_sql)
示例#3
0
 def test_hive_stats(self):
     """Smoke-test a full ``run`` of the operator for DEFAULT_DATE.

     No assertions are made; the test passes as long as the run does not raise.
     """
     stats_op = HiveStatsCollectionOperator(
         task_id='hive_stats_check',
         table="airflow.static_babynames_partitioned",
         partition={'ds': DEFAULT_DATE_DS},
         dag=self.dag,
     )
     run_window = dict(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
     stats_op.run(ignore_ti_state=True, **run_window)
示例#4
0
    def test_execute(self, mock_hive_metastore_hook, mock_presto_hook,
                     mock_mysql_hook, mock_json_dumps):
        """Check hook construction and the rows inserted when no stats row exists.

        With ``get_records`` returning ``False`` (unlike the delete-previous-rows
        case, which returns ``True``), the test expects ``insert_rows`` to be
        called once. Each inserted row is the operator's (ds, dttm, table,
        partition repr) followed by one ``(col, metric)`` expression key zipped
        with the Presto ``get_first`` result.
        """
        table_meta = mock_hive_metastore_hook.return_value.get_table.return_value
        table_meta.sd.cols = [fake_col]
        mock_mysql_hook.return_value.get_records.return_value = False

        hive_stats_collection_operator = HiveStatsCollectionOperator(**self.kwargs)
        hive_stats_collection_operator.execute(context={})

        mock_hive_metastore_hook.assert_called_once_with(
            metastore_conn_id=hive_stats_collection_operator.metastore_conn_id)
        mock_hive_metastore_hook.return_value.get_table.assert_called_once_with(
            table_name=hive_stats_collection_operator.table)
        mock_presto_hook.assert_called_once_with(
            presto_conn_id=hive_stats_collection_operator.presto_conn_id)
        mock_mysql_hook.assert_called_once_with(
            hive_stats_collection_operator.mysql_conn_id)
        mock_json_dumps.assert_called_once_with(
            hive_stats_collection_operator.partition, sort_keys=True)

        # Expected expression map: the table-level COUNT(*) first, then the
        # default expressions for every metastore column.
        field_types = {col.name: col.type for col in table_meta.sd.cols}
        exprs = {('', 'count'): 'COUNT(*)'}
        for col, col_type in field_types.items():
            exprs.update(hive_stats_collection_operator.get_default_exprs(col, col_type))
        exprs = OrderedDict(exprs)

        partition_repr = mock_json_dumps.return_value
        rows = [
            (hive_stats_collection_operator.ds,
             hive_stats_collection_operator.dttm,
             hive_stats_collection_operator.table,
             partition_repr,
             col_name,
             metric,
             value)
            for (col_name, metric), value in zip(
                exprs, mock_presto_hook.return_value.get_first.return_value)
        ]
        mock_mysql_hook.return_value.insert_rows.assert_called_once_with(
            table='hive_stats',
            rows=rows,
            target_fields=[
                'ds',
                'dttm',
                'table_name',
                'partition_repr',
                'col',
                'metric',
                'value',
            ])
示例#5
0
    def test_get_default_exprs(self):
        """An untyped (``None``) column only gets a non-null COUNT expression."""
        column = 'col'
        operator = HiveStatsCollectionOperator(**self.kwargs)

        actual = operator.get_default_exprs(column, None)

        assert actual == {(column, 'non_null'): f'COUNT({column})'}
示例#6
0
    def test_get_default_exprs_excluded_cols(self):
        """A column listed in ``excluded_columns`` gets no default expressions."""
        col = 'excluded_col'
        # Build a merged copy instead of updating self.kwargs in place, so the
        # exclusion cannot leak into other tests if the fixture dict is shared.
        op_kwargs = {**self.kwargs, 'excluded_columns': [col]}

        default_exprs = HiveStatsCollectionOperator(**op_kwargs).get_default_exprs(col, None)

        self.assertEqual(default_exprs, {})
示例#7
0
    def test_get_default_exprs_blacklist(self):
        """A column listed in ``col_blacklist`` gets no default expressions."""
        col = 'blacklisted_col'
        # Merged copy rather than an in-place update of self.kwargs, so the
        # blacklist cannot leak into other tests if the fixture dict is shared.
        op_kwargs = {**self.kwargs, 'col_blacklist': [col]}

        default_exprs = HiveStatsCollectionOperator(
            **op_kwargs).get_default_exprs(col, None)

        self.assertEqual(default_exprs, {})
示例#8
0
    def test_get_default_exprs(self):
        """With no column type, only the non-null COUNT expression is produced."""
        column_name = 'col'
        stats_op = HiveStatsCollectionOperator(**self.kwargs)

        actual = stats_op.get_default_exprs(column_name, None)

        self.assertEqual(actual, {(column_name, 'non_null'): f'COUNT({column_name})'})
示例#9
0
    def test_execute_with_assignment_func(self, mock_hive_metastore_hook,
                                          mock_presto_hook, mock_mysql_hook,
                                          mock_json_dumps):
        """A custom ``assignment_func`` supplies the per-column expressions.

        The expected rows are rebuilt from ``COUNT(*)`` plus whatever
        ``assignment_func`` returns for each metastore column. They are zipped
        with the Presto ``get_first`` result and checked against ``insert_rows``.
        """
        def assignment_func(col, _):
            return {(col, 'test'): f'TEST({col})'}

        # Merged copy instead of mutating self.kwargs, so the custom function
        # cannot leak into other tests if the fixture dict is shared.
        op_kwargs = {**self.kwargs, 'assignment_func': assignment_func}
        table_meta = mock_hive_metastore_hook.return_value.get_table.return_value
        table_meta.sd.cols = [fake_col]
        mock_mysql_hook.return_value.get_records.return_value = False

        hive_stats_collection_operator = HiveStatsCollectionOperator(**op_kwargs)
        hive_stats_collection_operator.execute(context={})

        field_types = {col.name: col.type for col in table_meta.sd.cols}
        exprs = {('', 'count'): 'COUNT(*)'}
        for col, col_type in field_types.items():
            exprs.update(hive_stats_collection_operator.assignment_func(col, col_type))
        exprs = OrderedDict(exprs)

        partition_repr = mock_json_dumps.return_value
        rows = [
            (hive_stats_collection_operator.ds,
             hive_stats_collection_operator.dttm,
             hive_stats_collection_operator.table,
             partition_repr,
             col_name,
             metric,
             value)
            for (col_name, metric), value in zip(
                exprs, mock_presto_hook.return_value.get_first.return_value)
        ]
        mock_mysql_hook.return_value.insert_rows.assert_called_once_with(
            table='hive_stats',
            rows=rows,
            target_fields=[
                'ds',
                'dttm',
                'table_name',
                'partition_repr',
                'col',
                'metric',
                'value',
            ])
示例#10
0
    def test_execute_no_query_results(self, mock_hive_metastore_hook,
                                      mock_presto_hook, mock_mysql_hook):
        """An empty (``None``) Presto result makes ``execute`` raise AirflowException."""
        metastore_table = mock_hive_metastore_hook.return_value.get_table.return_value
        metastore_table.sd.cols = [fake_col]
        mock_mysql_hook.return_value.get_records.return_value = False
        presto = mock_presto_hook.return_value
        presto.get_first.return_value = None

        # Construction stays inside the raises-block, as in the original test.
        with pytest.raises(AirflowException):
            operator = HiveStatsCollectionOperator(**self.kwargs)
            operator.execute(context={})
示例#11
0
    def test_get_default_exprs_string(self):
        """String columns get approx-distinct, total-length and non-null metrics."""
        column_name, column_type = 'col', 'string'
        stats_op = HiveStatsCollectionOperator(**self.kwargs)

        actual = stats_op.get_default_exprs(column_name, column_type)

        self.assertEqual(actual, {
            (column_name, 'approx_distinct'): f'APPROX_DISTINCT({column_name})',
            (column_name, 'len'): f'SUM(CAST(LENGTH({column_name}) AS BIGINT))',
            (column_name, 'non_null'): f'COUNT({column_name})',
        })
示例#12
0
    def test_get_default_exprs_boolean(self):
        """Boolean columns get true/false counts plus the non-null count."""
        column_name, column_type = 'col', 'boolean'
        stats_op = HiveStatsCollectionOperator(**self.kwargs)

        actual = stats_op.get_default_exprs(column_name, column_type)

        self.assertEqual(actual, {
            (column_name, 'false'): f'SUM(CASE WHEN NOT {column_name} THEN 1 ELSE 0 END)',
            (column_name, 'non_null'): f'COUNT({column_name})',
            (column_name, 'true'): f'SUM(CASE WHEN {column_name} THEN 1 ELSE 0 END)',
        })
示例#13
0
    def test_get_default_exprs_string(self):
        """String columns map to approx-distinct, summed length and non-null count."""
        col, col_type = 'col', 'string'
        operator = HiveStatsCollectionOperator(**self.kwargs)

        result = operator.get_default_exprs(col, col_type)

        expected = {
            (col, 'approx_distinct'): f'APPROX_DISTINCT({col})',
            (col, 'len'): f'SUM(CAST(LENGTH({col}) AS BIGINT))',
            (col, 'non_null'): f'COUNT({col})',
        }
        assert result == expected
示例#14
0
    def test_get_default_exprs_boolean(self):
        """Boolean columns map to true/false CASE sums and a non-null count."""
        col, col_type = 'col', 'boolean'
        operator = HiveStatsCollectionOperator(**self.kwargs)

        result = operator.get_default_exprs(col, col_type)

        expected = {
            (col, 'false'): f'SUM(CASE WHEN NOT {col} THEN 1 ELSE 0 END)',
            (col, 'non_null'): f'COUNT({col})',
            (col, 'true'): f'SUM(CASE WHEN {col} THEN 1 ELSE 0 END)',
        }
        assert result == expected
示例#15
0
    def test_get_default_exprs_number(self):
        """Every numeric Hive type gets avg/max/min/non-null/sum expressions."""
        col = 'col'
        # Same expectation for every numeric type, so build it once.
        expected = {
            (col, 'avg'): f'AVG({col})',
            (col, 'max'): f'MAX({col})',
            (col, 'min'): f'MIN({col})',
            (col, 'non_null'): f'COUNT({col})',
            (col, 'sum'): f'SUM({col})',
        }
        for numeric_type in ('double', 'int', 'bigint', 'float'):
            stats_op = HiveStatsCollectionOperator(**self.kwargs)
            self.assertEqual(stats_op.get_default_exprs(col, numeric_type), expected)
示例#16
0
    def test_get_default_exprs_number(self):
        """Numeric columns map each metric to the matching SQL aggregate of the column."""
        col = 'col'
        aggregate_by_metric = {
            'avg': 'AVG',
            'max': 'MAX',
            'min': 'MIN',
            'non_null': 'COUNT',
            'sum': 'SUM',
        }
        expected = {
            (col, metric): f'{func}({col})' for metric, func in aggregate_by_metric.items()
        }
        for numeric_type in ('double', 'int', 'bigint', 'float'):
            operator = HiveStatsCollectionOperator(**self.kwargs)
            assert operator.get_default_exprs(col, numeric_type) == expected