def test_hive_stats(self):
    """Smoke test: a stats-collection task over a static partition runs to completion."""
    stats_task = HiveStatsCollectionOperator(
        task_id='hive_stats_check',
        table="airflow.static_babynames_partitioned",
        partition={'ds': DEFAULT_DATE_DS},
        dag=self.dag,
    )
    stats_task.run(
        start_date=DEFAULT_DATE,
        end_date=DEFAULT_DATE,
        ignore_ti_state=True,
    )
def test_execute(self, mock_hive_metastore_hook, mock_presto_hook, mock_mysql_hook, mock_json_dumps):
    """Happy path: hooks are constructed with the operator's conn ids and the
    computed stats rows are inserted into ``hive_stats``.

    NOTE(review): another ``test_execute`` is defined later in this file; the
    later definition shadows this one at class-creation time — confirm which
    one is intended to survive.
    """
    mock_hive_metastore_hook.return_value.get_table.return_value.sd.cols = [fake_col]
    # No pre-existing stats rows for this partition.
    mock_mysql_hook.return_value.get_records.return_value = False

    operator = HiveStatsCollectionOperator(**self.kwargs)
    operator.execute(context={})

    # Each hook must be built exactly once, using the operator's connection ids.
    mock_hive_metastore_hook.assert_called_once_with(
        metastore_conn_id=operator.metastore_conn_id)
    mock_hive_metastore_hook.return_value.get_table.assert_called_once_with(
        table_name=operator.table)
    mock_presto_hook.assert_called_once_with(presto_conn_id=operator.presto_conn_id)
    mock_mysql_hook.assert_called_once_with(operator.mysql_conn_id)
    mock_json_dumps.assert_called_once_with(operator.partition, sort_keys=True)

    # Rebuild the expression map the operator is expected to have produced.
    table_cols = mock_hive_metastore_hook.return_value.get_table.return_value.sd.cols
    expected_exprs = OrderedDict()
    expected_exprs[('', 'count')] = 'COUNT(*)'
    for col_name, col_type in {c.name: c.type for c in table_cols}.items():
        expected_exprs.update(operator.get_default_exprs(col_name, col_type))

    presto_values = mock_presto_hook.return_value.get_first.return_value
    expected_rows = []
    for (stat_col, metric), value in zip(expected_exprs, presto_values):
        expected_rows.append(
            (operator.ds, operator.dttm, operator.table, mock_json_dumps.return_value)
            + (stat_col, metric, value))

    mock_mysql_hook.return_value.insert_rows.assert_called_once_with(
        table='hive_stats',
        rows=expected_rows,
        target_fields=[
            'ds',
            'dttm',
            'table_name',
            'partition_repr',
            'col',
            'metric',
            'value',
        ])
def test_get_default_exprs_excluded_cols(self):
    """A column listed in ``excluded_columns`` yields no default expressions."""
    excluded = 'excluded_col'
    self.kwargs.update(dict(excluded_columns=[excluded]))
    operator = HiveStatsCollectionOperator(**self.kwargs)
    self.assertEqual(operator.get_default_exprs(excluded, None), {})
def test_get_default_exprs_blacklist(self):
    """A column listed in ``col_blacklist`` yields no default expressions."""
    blacklisted = 'blacklisted_col'
    self.kwargs.update(dict(col_blacklist=[blacklisted]))
    operator = HiveStatsCollectionOperator(**self.kwargs)
    self.assertEqual(operator.get_default_exprs(blacklisted, None), {})
def test_get_default_exprs(self):
    """A column with no (known) type gets only the non-null COUNT expression."""
    column = 'col'
    exprs = HiveStatsCollectionOperator(**self.kwargs).get_default_exprs(column, None)
    expected = {(column, 'non_null'): 'COUNT({})'.format(column)}
    self.assertEqual(exprs, expected)
def test_execute_with_assignment_func(self, mock_hive_metastore_hook, mock_presto_hook, mock_mysql_hook, mock_json_dumps):
    """A user-supplied ``assignment_func`` replaces the default expressions.

    NOTE(review): another ``test_execute_with_assignment_func`` is defined
    later in this file; the later definition shadows this one at
    class-creation time — confirm which one is intended to survive.
    """
    def assignment_func(col, _):
        return {(col, 'test'): 'TEST({})'.format(col)}

    self.kwargs.update(dict(assignment_func=assignment_func))
    mock_hive_metastore_hook.return_value.get_table.return_value.sd.cols = [fake_col]
    # No pre-existing stats rows for this partition.
    mock_mysql_hook.return_value.get_records.return_value = False

    operator = HiveStatsCollectionOperator(**self.kwargs)
    operator.execute(context={})

    # Rebuild the expression map using the operator's assignment_func.
    table_cols = mock_hive_metastore_hook.return_value.get_table.return_value.sd.cols
    expected_exprs = OrderedDict()
    expected_exprs[('', 'count')] = 'COUNT(*)'
    for col_name, col_type in {c.name: c.type for c in table_cols}.items():
        expected_exprs.update(operator.assignment_func(col_name, col_type))

    presto_values = mock_presto_hook.return_value.get_first.return_value
    expected_rows = []
    for (stat_col, metric), value in zip(expected_exprs, presto_values):
        expected_rows.append(
            (operator.ds, operator.dttm, operator.table, mock_json_dumps.return_value)
            + (stat_col, metric, value))

    mock_mysql_hook.return_value.insert_rows.assert_called_once_with(
        table='hive_stats',
        rows=expected_rows,
        target_fields=[
            'ds',
            'dttm',
            'table_name',
            'partition_repr',
            'col',
            'metric',
            'value',
        ])
def test_runs_for_hive_stats(self, mock_hive_metastore_hook):
    """End-to-end run with mocked Presto/MySQL hooks, checking the emitted SQL
    and the rows handed to ``insert_rows``."""
    mock_mysql_hook = MockMySqlHook()
    mock_presto_hook = MockPrestoHook()
    with patch('airflow.operators.hive_stats_operator.PrestoHook',
               return_value=mock_presto_hook), \
            patch('airflow.operators.hive_stats_operator.MySqlHook',
                  return_value=mock_mysql_hook):
        operator = HiveStatsCollectionOperator(
            task_id='hive_stats_check',
            table="airflow.static_babynames_partitioned",
            partition={'ds': DEFAULT_DATE_DS},
            dag=self.dag,
        )
        operator.run(
            start_date=DEFAULT_DATE,
            end_date=DEFAULT_DATE,
            ignore_ti_state=True,
        )

    select_count_query = (
        "SELECT COUNT(*) AS __count FROM airflow."
        "static_babynames_partitioned WHERE ds = '2015-01-01';")
    mock_presto_hook.get_first.assert_called_with(hql=select_count_query)

    expected_stats_select_query = (
        "SELECT 1 FROM hive_stats WHERE table_name='airflow."
        "static_babynames_partitioned' AND "
        "partition_repr='{\"ds\": \"2015-01-01\"}' AND "
        "dttm='2015-01-01T00:00:00+00:00' "
        "LIMIT 1;")
    raw_stats_select_query = mock_mysql_hook.get_records.call_args_list[0][0][0]
    # The operator builds its SQL with source indentation; collapse runs of
    # whitespace before comparing.
    actual_stats_select_query = re.sub(r'\s{2,}', ' ', raw_stats_select_query).strip()
    self.assertEqual(expected_stats_select_query, actual_stats_select_query)

    expected_insert_rows = [(
        '2015-01-01',
        '2015-01-01T00:00:00+00:00',
        'airflow.static_babynames_partitioned',
        '{"ds": "2015-01-01"}',
        '',
        'count',
        ['val_0', 'val_1'],
    )]
    mock_mysql_hook.insert_rows.assert_called_with(
        table='hive_stats',
        rows=expected_insert_rows,
        target_fields=[
            'ds',
            'dttm',
            'table_name',
            'partition_repr',
            'col',
            'metric',
            'value',
        ])
def test_execute_with_assignment_func(self, mock_hive_metastore_hook, mock_presto_hook, mock_mysql_hook, mock_json_dumps):
    """A user-supplied ``assignment_func`` drives the per-column expressions.

    NOTE(review): this method name is also defined earlier in this file; this
    later definition shadows the earlier one at class-creation time — confirm
    which one is intended to survive.
    """
    def assignment_func(col, col_type):
        return {(col, 'test'): 'TEST({})'.format(col)}

    self.kwargs.update(dict(assignment_func=assignment_func))
    mock_hive_metastore_hook.return_value.get_table.return_value.sd.cols = [fake_col]
    # No pre-existing stats rows for this partition.
    mock_mysql_hook.return_value.get_records.return_value = False

    op = HiveStatsCollectionOperator(**self.kwargs)
    op.execute(context={})

    # Reconstruct the expression map the operator should have built via
    # its assignment_func.
    cols = mock_hive_metastore_hook.return_value.get_table.return_value.sd.cols
    field_types = {c.name: c.type for c in cols}
    exprs = OrderedDict()
    exprs[('', 'count')] = 'COUNT(*)'
    for name, hive_type in field_types.items():
        exprs.update(op.assignment_func(name, hive_type))

    first_row = mock_presto_hook.return_value.get_first.return_value
    rows = [
        (op.ds, op.dttm, op.table, mock_json_dumps.return_value)
        + (key[0], key[1], value)
        for key, value in zip(exprs, first_row)
    ]

    mock_mysql_hook.return_value.insert_rows.assert_called_once_with(
        table='hive_stats',
        rows=rows,
        target_fields=[
            'ds',
            'dttm',
            'table_name',
            'partition_repr',
            'col',
            'metric',
            'value',
        ])
def test_execute(self, mock_hive_metastore_hook, mock_presto_hook, mock_mysql_hook, mock_json_dumps):
    """Happy path: hook wiring and insertion of the computed stats rows.

    NOTE(review): this method name is also defined earlier in this file; this
    later definition shadows the earlier one at class-creation time — confirm
    which one is intended to survive.
    """
    mock_hive_metastore_hook.return_value.get_table.return_value.sd.cols = [fake_col]
    # No pre-existing stats rows for this partition.
    mock_mysql_hook.return_value.get_records.return_value = False

    op = HiveStatsCollectionOperator(**self.kwargs)
    op.execute(context={})

    # Hooks must each be constructed once with the operator's conn ids.
    mock_hive_metastore_hook.assert_called_once_with(metastore_conn_id=op.metastore_conn_id)
    mock_hive_metastore_hook.return_value.get_table.assert_called_once_with(table_name=op.table)
    mock_presto_hook.assert_called_once_with(presto_conn_id=op.presto_conn_id)
    mock_mysql_hook.assert_called_once_with(op.mysql_conn_id)
    mock_json_dumps.assert_called_once_with(op.partition, sort_keys=True)

    # Reconstruct the default expression map.
    cols = mock_hive_metastore_hook.return_value.get_table.return_value.sd.cols
    field_types = {c.name: c.type for c in cols}
    exprs = OrderedDict()
    exprs[('', 'count')] = 'COUNT(*)'
    for name, hive_type in field_types.items():
        exprs.update(op.get_default_exprs(name, hive_type))

    first_row = mock_presto_hook.return_value.get_first.return_value
    rows = [
        (op.ds, op.dttm, op.table, mock_json_dumps.return_value)
        + (key[0], key[1], value)
        for key, value in zip(exprs, first_row)
    ]

    mock_mysql_hook.return_value.insert_rows.assert_called_once_with(
        table='hive_stats',
        rows=rows,
        target_fields=[
            'ds',
            'dttm',
            'table_name',
            'partition_repr',
            'col',
            'metric',
            'value',
        ])
def test_execute_no_query_results(self, mock_hive_metastore_hook, mock_presto_hook, mock_mysql_hook):
    """``execute`` raises AirflowException when the Presto stats query returns no row."""
    mock_hive_metastore_hook.return_value.get_table.return_value.sd.cols = [fake_col]
    mock_mysql_hook.return_value.get_records.return_value = False
    # Simulate Presto producing no result row at all.
    mock_presto_hook.return_value.get_first.return_value = None

    with self.assertRaises(AirflowException):
        HiveStatsCollectionOperator(**self.kwargs).execute(context={})
def test_get_default_exprs_string(self):
    """String columns get approx-distinct, total-length, and non-null expressions."""
    column = 'col'
    exprs = HiveStatsCollectionOperator(**self.kwargs).get_default_exprs(column, 'string')
    expected = {
        (column, 'approx_distinct'): 'APPROX_DISTINCT({})'.format(column),
        (column, 'len'): 'SUM(CAST(LENGTH({}) AS BIGINT))'.format(column),
        (column, 'non_null'): 'COUNT({})'.format(column),
    }
    self.assertEqual(exprs, expected)
def test_get_default_exprs_boolean(self):
    """Boolean columns get true/false tallies plus the non-null count."""
    column = 'col'
    exprs = HiveStatsCollectionOperator(**self.kwargs).get_default_exprs(column, 'boolean')
    expected = {
        (column, 'false'): 'SUM(CASE WHEN NOT {} THEN 1 ELSE 0 END)'.format(column),
        (column, 'non_null'): 'COUNT({})'.format(column),
        (column, 'true'): 'SUM(CASE WHEN {} THEN 1 ELSE 0 END)'.format(column),
    }
    self.assertEqual(exprs, expected)
def test_get_default_exprs_number(self):
    """Every numeric column type gets avg/max/min/non-null/sum expressions."""
    column = 'col'
    for numeric_type in ('double', 'int', 'bigint', 'float'):
        exprs = HiveStatsCollectionOperator(**self.kwargs).get_default_exprs(
            column, numeric_type)
        expected = {
            (column, 'avg'): 'AVG({})'.format(column),
            (column, 'max'): 'MAX({})'.format(column),
            (column, 'min'): 'MIN({})'.format(column),
            (column, 'non_null'): 'COUNT({})'.format(column),
            (column, 'sum'): 'SUM({})'.format(column),
        }
        self.assertEqual(exprs, expected)
def test_execute_delete_previous_runs_rows(self, mock_hive_metastore_hook, mock_presto_hook, mock_mysql_hook, mock_json_dumps):
    """When stats rows for this partition already exist, ``execute`` deletes them first.

    NOTE(review): this method name is also defined later in this file; the
    later definition shadows this one at class-creation time — confirm which
    one is intended to survive.
    """
    mock_hive_metastore_hook.return_value.get_table.return_value.sd.cols = [fake_col]
    # Pretend a previous run already wrote stats rows for this partition.
    mock_mysql_hook.return_value.get_records.return_value = True

    operator = HiveStatsCollectionOperator(**self.kwargs)
    operator.execute(context={})

    # Expected DELETE must match the operator's generated SQL byte-for-byte.
    expected_sql = """ DELETE FROM hive_stats WHERE table_name='{}' AND partition_repr='{}' AND dttm='{}'; """.format(
        operator.table,
        mock_json_dumps.return_value,
        operator.dttm,
    )
    mock_mysql_hook.return_value.run.assert_called_once_with(expected_sql)
def test_execute_delete_previous_runs_rows(self, mock_hive_metastore_hook, mock_presto_hook, mock_mysql_hook, mock_json_dumps):
    """Pre-existing stats rows trigger a DELETE before new rows are written.

    NOTE(review): this method name is also defined earlier in this file; this
    later definition shadows the earlier one at class-creation time — confirm
    which one is intended to survive.
    """
    mock_hive_metastore_hook.return_value.get_table.return_value.sd.cols = [fake_col]
    # get_records returning truthy means stats for this partition already exist.
    mock_mysql_hook.return_value.get_records.return_value = True

    op = HiveStatsCollectionOperator(**self.kwargs)
    op.execute(context={})

    # The DELETE statement handed to MySqlHook.run must match exactly.
    delete_sql = """ DELETE FROM hive_stats WHERE table_name='{}' AND partition_repr='{}' AND dttm='{}'; """.format(
        op.table, mock_json_dumps.return_value, op.dttm)
    mock_mysql_hook.return_value.run.assert_called_once_with(delete_sql)