def test_get_binded_column_check_table_alias_map(self):
    """With an explicit alias, the column comes from the table alias map."""
    context = StatementBinderContext()
    table_lookup = MagicMock(return_value='col_obj')
    context._check_table_alias_map = table_lookup

    bound = context.get_binded_column('col_name', 'alias')

    table_lookup.assert_called_once_with('alias', 'col_name')
    self.assertEqual(bound, ('alias', 'col_obj'))
def test_get_binded_column_should_search_all(self):
    """Without an alias, the lookup is delegated to the search over all maps."""
    context = StatementBinderContext()
    search_all = MagicMock(return_value=('alias', 'col_obj'))
    context._search_all_alias_maps = search_all

    bound = context.get_binded_column('col_name')

    search_all.assert_called_once_with('col_name')
    self.assertEqual(bound, ('alias', 'col_obj'))
def test_add_table_alias(self, mock_catalog):
    """add_table_alias checks for duplicates and stores the catalog metadata."""
    metadata_getter = MagicMock(return_value='table_obj')
    mock_catalog().get_dataset_metadata = metadata_getter

    context = StatementBinderContext()
    duplicate_check = MagicMock()
    context._check_duplicate_alias = duplicate_check

    context.add_table_alias('alias', 'table_name')

    duplicate_check.assert_called_with('alias')
    metadata_getter.assert_called_with(None, 'table_name')
    self.assertEqual(context._table_alias_map['alias'], 'table_obj')
def test_add_derived_table_alias(self):
    """Derived aliases collect column objects from tuple and function exprs."""
    tuple_expr = MagicMock(spec=TupleValueExpression, col_object='A')
    func_expr = MagicMock(spec=FunctionExpression, output_objs=['B', 'C'])

    context = StatementBinderContext()
    duplicate_check = MagicMock()
    context._check_duplicate_alias = duplicate_check

    context.add_derived_table_alias('alias', [tuple_expr, func_expr])

    duplicate_check.assert_called_with('alias')
    self.assertEqual(context._derived_table_alias_map['alias'],
                     ['A', 'B', 'C'])
def test_bind_select_statement_union_starts_new_context(self, mock_ctx):
    """A new binder context is created only when the select has a union."""
    with patch.object(StatementBinder, 'bind'):
        # No union link: no extra context is constructed.
        without_union = MagicMock(union_link=None)
        binder = StatementBinder(StatementBinderContext())
        binder._bind_select_statement(without_union)
        self.assertEqual(mock_ctx.call_count, 0)

        # Default MagicMock attribute is truthy, i.e. a union link is present.
        with_union = MagicMock()
        binder = StatementBinder(StatementBinderContext())
        binder._bind_select_statement(with_union)
        self.assertEqual(mock_ctx.call_count, 1)
def test_bind_unknown_object(self):
    """Binding an object of an unregistered type raises NotImplementedError."""
    class Unregistered:
        pass

    with self.assertRaises(NotImplementedError):
        StatementBinder(StatementBinderContext()).bind(Unregistered())
def test_search_all_alias_raise_duplicate_error(self):
    """A column found under an alias in both maps raises RuntimeError."""
    with self.assertRaises(RuntimeError):
        context = StatementBinderContext()
        context._check_table_alias_map = MagicMock()
        context._check_derived_table_alias_map = MagicMock()
        # Register the same alias in both maps to provoke the duplicate.
        for alias_map in (context._table_alias_map,
                          context._derived_table_alias_map):
            alias_map['alias'] = 'col_name'
        context._search_all_alias_maps('col_name')
def extend_star(binder_context: StatementBinderContext) \
        -> List[TupleValueExpression]:
    """Expand a ``SELECT *`` target into explicit column expressions.

    Args:
        binder_context: context whose ``_get_all_alias_and_col_name()``
            yields ``(alias, col_name)`` pairs.

    Returns:
        A new list with one ``TupleValueExpression`` per pair, in the order
        the context yields them.
    """
    # The comprehension already builds a list; no extra list() copy needed.
    return [
        TupleValueExpression(col_name=col_name, table_alias=alias)
        for alias, col_name in binder_context._get_all_alias_and_col_name()
    ]
def test_bind_tableref_with_func_expr(self):
    """A table reference that is a function expression binds that expression."""
    with patch.object(StatementBinder, 'bind') as bind_spy:
        binder = StatementBinder(StatementBinderContext())
        func_ref = MagicMock()
        # Every earlier branch is switched off; is_func_expr stays truthy
        # because an unconfigured MagicMock call returns a MagicMock.
        func_ref.configure_mock(**{
            'is_table_atom.return_value': False,
            'is_select.return_value': False,
            'is_join.return_value': False,
        })

        binder._bind_tableref(func_ref)

        bind_spy.assert_called_with(func_ref.func_expr)
def test_bind_load_data_raises(self, mock_tve, mock_create):
    """Loading into a table with no table object raises RuntimeError."""
    column = MagicMock()
    load_statement = MagicMock()
    load_statement.column_list = [column]
    load_statement.table_ref.table.table_obj = None

    with self.assertRaises(RuntimeError), \
            patch.object(StatementBinder, 'bind'):
        binder = StatementBinder(StatementBinderContext())
        binder._bind_load_data_statement(load_statement)
def test_bind_tableref_starts_new_context(self, mock_ctx):
    """Binding a sub-select table reference creates one new binder context."""
    with patch.object(StatementBinder, 'bind'):
        binder = StatementBinder(StatementBinderContext())
        select_ref = MagicMock()
        select_ref.configure_mock(**{
            'is_table_atom.return_value': False,
            'is_join.return_value': False,
            'is_select.return_value': True,
        })

        binder._bind_tableref(select_ref)

        self.assertEqual(mock_ctx.call_count, 1)
def test_bind_tableref_with_join(self):
    """Both sides of a join table reference are bound."""
    with patch.object(StatementBinder, 'bind') as bind_spy:
        binder = StatementBinder(StatementBinderContext())
        join_ref = MagicMock()
        join_ref.configure_mock(**{
            'is_table_atom.return_value': False,
            'is_select.return_value': False,
            'is_join.return_value': True,
        })

        binder._bind_tableref(join_ref)

        for side in (join_ref.join_node.left, join_ref.join_node.right):
            bind_spy.assert_any_call(side)
def test_bind_func_expr(self, mock_path_to_class, mock_catalog):
    """Function expressions get outputs, aliases and the UDF callable bound."""
    # setup
    func_expr = MagicMock(alias='func_expr', output_col_aliases=[])
    first_out = MagicMock()
    first_out.name.lower.return_value = 'out1'
    second_out = MagicMock()
    second_out.name.lower.return_value = 'out2'
    udf_outputs = [first_out, second_out]
    udf_obj = MagicMock()

    udf_by_name = MagicMock(return_value=udf_obj)
    mock_catalog().get_udf_by_name = udf_by_name
    udf_outputs_getter = MagicMock(return_value=udf_outputs)
    mock_catalog().get_udf_outputs = udf_outputs_getter
    mock_path_to_class.return_value.return_value = 'path_to_class'

    def check_lookups():
        # Catalog and class-loader interactions are the same in both cases.
        udf_by_name.assert_called_with(func_expr.name)
        udf_outputs_getter.assert_called_with(udf_obj)
        mock_path_to_class.assert_called_with(udf_obj.impl_file_path,
                                              udf_obj.name)

    # Case 1: a specific output is requested
    func_expr.output = 'out1'
    StatementBinder(StatementBinderContext())._bind_func_expr(func_expr)
    check_lookups()
    self.assertEqual(func_expr.output_objs, [first_out])
    self.assertEqual(
        func_expr.output_col_aliases,
        ['{}.{}'.format(func_expr.alias, first_out.name.lower())])
    self.assertEqual(func_expr.function, 'path_to_class')

    # Case 2: no output requested, every UDF output is exposed
    func_expr.output = None
    StatementBinder(StatementBinderContext())._bind_func_expr(func_expr)
    check_lookups()
    self.assertEqual(func_expr.output_objs, udf_outputs)
    self.assertEqual(func_expr.output_col_aliases,
                     ['func_expr.out1', 'func_expr.out2'])
    self.assertEqual(func_expr.function, 'path_to_class')
def test_bind_tableref_should_raise(self):
    """A table reference matching no known kind raises ValueError."""
    with patch.object(StatementBinder, 'bind'), \
            self.assertRaises(ValueError):
        binder = StatementBinder(StatementBinderContext())
        unknown_ref = MagicMock()
        unknown_ref.configure_mock(**{
            'is_select.return_value': False,
            'is_func_expr.return_value': False,
            'is_join.return_value': False,
            'is_table_atom.return_value': False,
        })
        binder._bind_tableref(unknown_ref)
def test_bind_tableref(self, mock_bind_tabe_info):
    """Table atoms register a table alias; sub-selects a derived alias."""
    # Plain table reference.
    with patch.object(StatementBinderContext,
                      'add_table_alias') as add_alias:
        binder = StatementBinder(StatementBinderContext())
        atom_ref = MagicMock()
        atom_ref.is_table_atom.return_value = True
        binder._bind_tableref(atom_ref)
        add_alias.assert_called_with(atom_ref.alias,
                                     atom_ref.table.table_name)
        mock_bind_tabe_info.assert_called_once_with(atom_ref.table)

    # Sub-select table reference.
    with patch.object(StatementBinder, 'bind') as bind_spy, \
            patch.object(StatementBinderContext,
                         'add_derived_table_alias') as add_derived:
        binder = StatementBinder(StatementBinderContext())
        select_ref = MagicMock()
        select_ref.is_table_atom.return_value = False
        select_ref.is_select.return_value = True
        binder._bind_tableref(select_ref)
        add_derived.assert_called_with(
            select_ref.alias, select_ref.select_statement.target_list)
        bind_spy.assert_called_with(select_ref.select_statement)
def test_bind_tuple_value_expression(self):
    """Binding a tuple value sets its column object and qualified alias."""
    with patch.object(StatementBinderContext,
                      'get_binded_column') as lookup:
        lookup.return_value = ['table_alias', 'col_obj']
        binder = StatementBinder(StatementBinderContext())
        tuple_expr = MagicMock()
        tuple_expr.col_name = 'col_name'

        binder._bind_tuple_expr(tuple_expr)

        expected_alias = '{}.{}'.format('table_alias', 'col_name')
        lookup.assert_called_with(tuple_expr.col_name,
                                  tuple_expr.table_alias)
        self.assertEqual(tuple_expr.col_object, 'col_obj')
        self.assertEqual(tuple_expr.col_alias, expected_alias)
def _bind_select_statement(self, node: SelectStatement):
    """Bind every part of a SELECT statement.

    Binds, in order: the FROM table, the WHERE clause, the target list
    (expanding a lone ``*`` first), the ORDER BY expressions, and finally
    the union branch, which is bound in a fresh binder context.

    Args:
        node: the select statement to bind; ``node.target_list`` is replaced
            in place when it consists of a single ``*`` column.
    """
    self.bind(node.from_table)
    if node.where_clause:
        self.bind(node.where_clause)
    if node.target_list:
        # SELECT * support
        if len(node.target_list) == 1 and \
                isinstance(node.target_list[0], TupleValueExpression) and \
                node.target_list[0].col_name == '*':
            node.target_list = extend_star(self._binder_context)
        for expr in node.target_list:
            self.bind(expr)
    if node.orderby_list:
        for expr in node.orderby_list:
            # each ORDER BY entry is indexable; element 0 is the expression
            self.bind(expr[0])
    if node.union_link:
        current_context = self._binder_context
        self._binder_context = StatementBinderContext()
        try:
            self.bind(node.union_link)
        finally:
            # Restore the outer context even if binding the union fails, so
            # the binder is never left pointing at the temporary context.
            self._binder_context = current_context
def execute_query(query) -> Iterator[Batch]:
    """Execute the query and return a result generator.

    The first statement produced by the parser is bound, converted to a
    logical plan, turned into a physical plan and executed.

    Args:
        query: query text handed to ``Parser().parse``.

    Returns:
        Whatever ``PlanExecutor.execute_plan()`` returns for the plan.

    Raises:
        RuntimeError: if binding the statement fails; the original exception
            is attached as ``__cause__``.
    """
    stmt = Parser().parse(query)[0]
    try:
        StatementBinder(StatementBinderContext()).bind(stmt)
    except Exception as error:
        # Chain explicitly so the traceback marks the binder error as the
        # direct cause rather than an incidental context.
        raise RuntimeError(f'Binder failed: {error}') from error
    l_plan = StatementToPlanConvertor().visit(stmt)
    p_plan = PlanGenerator().build(l_plan)
    return PlanExecutor(p_plan).execute_plan()
def test_bind_select_statement(self):
    """Every clause of a select statement is passed to bind."""
    with patch.object(StatementBinder, 'bind') as bind_spy:
        binder = StatementBinder(StatementBinderContext())
        statement = MagicMock()
        target_a, target_b, order_a, order_b = (
            MagicMock() for _ in range(4))
        statement.target_list = [target_a, target_b]
        statement.orderby_list = [(order_a, 0), (order_b, 0)]

        binder._bind_select_statement(statement)

        expected_bound = (
            statement.from_table,
            statement.where_clause,
            statement.union_link,
            target_a, target_b, order_a, order_b,
        )
        for bound in expected_bound:
            bind_spy.assert_any_call(bound)
def test_get_binded_column_raise_error(self):
    """An unresolvable column raises RuntimeError, with or without alias."""
    # Lookup without an alias finds nothing.
    with self.assertRaises(RuntimeError):
        context = StatementBinderContext()
        context._search_all_alias_maps = MagicMock(
            return_value=(None, None))
        context.get_binded_column('col_name')

    # Lookup with an alias finds nothing in either map.
    with self.assertRaises(RuntimeError):
        context = StatementBinderContext()
        context._check_table_alias_map = MagicMock(return_value=None)
        context._check_derived_table_alias_map = MagicMock(
            return_value=None)
        context.get_binded_column('col_name', 'alias')
def test_bind_load_data(self, mock_tve, mock_create):
    """Explicit load columns are bound as-is, without creating metadata.

    With a column list given on the statement, the binder must bind the
    table reference and each listed column, and must neither create video
    metadata nor build new TupleValueExpression objects.
    """
    # The original also built a `table_ref_obj` mock that was never attached
    # to the statement; that dead setup is dropped.
    load_statement = MagicMock()
    column = MagicMock()
    load_statement.column_list = [column]
    with patch.object(StatementBinder, 'bind') as mock_binder:
        binder = StatementBinder(StatementBinderContext())
        binder._bind_load_data_statement(load_statement)
        mock_binder.assert_any_call(load_statement.table_ref)
        mock_create.assert_not_called()
        mock_tve.assert_not_called()
        mock_binder.assert_any_call(column)
def _bind_tableref(self, node: TableRef):
    """Bind a table reference according to its kind.

    * table atom: registers the alias and binds the table info;
    * sub-select: binds the select in a fresh context, then registers its
      target list as a derived table under the alias;
    * join: binds left, right and (if present) the predicate;
    * function expression: binds it and registers it as a derived table
      under the function's alias.

    Raises:
        ValueError: if the node is none of the kinds above.
    """
    if node.is_table_atom():
        # Table
        self._binder_context.add_table_alias(node.alias,
                                             node.table.table_name)
        bind_table_info(node.table)
    elif node.is_select():
        current_context = self._binder_context
        self._binder_context = StatementBinderContext()
        try:
            self.bind(node.select_statement)
        finally:
            # Restore the outer context even if the sub-select fails to
            # bind, so the binder never keeps the temporary context.
            self._binder_context = current_context
        self._binder_context.add_derived_table_alias(
            node.alias, node.select_statement.target_list)
    elif node.is_join():
        self.bind(node.join_node.left)
        self.bind(node.join_node.right)
        if node.join_node.predicate:
            self.bind(node.join_node.predicate)
    elif node.is_func_expr():
        self.bind(node.func_expr)
        self._binder_context.add_derived_table_alias(
            node.func_expr.alias, [node.func_expr])
    else:
        raise ValueError(f'Unsupported node {type(node)}')
def test_check_table_alias_map(self, mock_catalog):
    """Known aliases are resolved via the catalog; unknown ones give None."""
    column_getter = MagicMock(return_value='catalog_value')
    mock_catalog().get_column_object = column_getter

    # alias is registered
    context = StatementBinderContext()
    context._table_alias_map['alias'] = 'table_obj'
    found = context._check_table_alias_map('alias', 'col_name')
    column_getter.assert_called_once_with('table_obj', 'col_name')
    self.assertEqual(found, 'catalog_value')

    # alias is not registered
    column_getter.reset_mock()
    context = StatementBinderContext()
    missing = context._check_table_alias_map('alias', 'col_name')
    column_getter.assert_not_called()
    self.assertEqual(missing, None)
def test_check_derived_table_alias_map(self):
    """Derived aliases return the matching column object, else None."""
    # alias is registered
    context = StatementBinderContext()
    first_col = MagicMock()
    first_col.name.lower.return_value = 'col_name1'
    second_col = MagicMock()
    second_col.name.lower.return_value = 'col_name2'
    context._derived_table_alias_map['alias'] = [first_col, second_col]
    found = context._check_derived_table_alias_map('alias', 'col_name1')
    self.assertEqual(found, first_col)

    # alias is not registered
    context = StatementBinderContext()
    missing = context._check_derived_table_alias_map('alias', 'col_name')
    self.assertEqual(missing, None)
def test_search_all_alias_maps(self):
    """_search_all_alias_maps consults only the maps that have entries.

    Scenario 1 puts the alias only in the table alias map, scenario 2 only
    in the derived table alias map. Each scenario checks which lookup helper
    was used and the (alias, column object) pair returned.
    """
    ctx = StatementBinderContext()
    check_table_map = ctx._check_table_alias_map = MagicMock()
    check_derived_map = ctx._check_derived_table_alias_map = MagicMock()

    # only _table_alias_map has entry
    check_table_map.return_value = 'col_obj'
    ctx._table_alias_map['alias'] = 'col_name'
    ctx._derived_table_alias_map = {}
    result = ctx._search_all_alias_maps('col_name')
    check_table_map.assert_called_once_with('alias', 'col_name')
    check_derived_map.assert_not_called()
    self.assertEqual(result, ('alias', 'col_obj'))

    # only _derived_table_alias_map has entry
    # Clear recorded calls (return values are kept) so the assertions below
    # describe this scenario only.
    check_table_map.reset_mock()
    check_derived_map.reset_mock()
    check_derived_map.return_value = 'derived_col_obj'
    ctx._table_alias_map = {}
    ctx._derived_table_alias_map['alias'] = 'col_name'
    result = ctx._search_all_alias_maps('col_name')
    # The original asserted check_table_map twice here (copy-paste slip) and
    # never verified that the derived map was actually consulted.
    check_table_map.assert_not_called()
    check_derived_map.assert_called_once_with('alias', 'col_name')
    self.assertEqual(result, ('alias', 'derived_col_obj'))
def test_bind_load_video_statement(self, mock_tve, mock_create):
    """Video loads create metadata and derive columns from the table."""
    column = MagicMock()
    table_obj = MagicMock()
    table_obj.columns = [column]
    table_obj.name = 'table_alias'

    load_statement = MagicMock()
    load_statement.file_options = {'file_format': FileFormatType.VIDEO}
    load_statement.column_list = None
    load_statement.table_ref.table.table_obj = table_obj
    load_statement.table_ref.table.table_name = 'table_name'

    built_tve = MagicMock()
    mock_tve.return_value = built_tve

    with patch.object(StatementBinder, 'bind') as bind_spy:
        binder = StatementBinder(StatementBinderContext())
        binder._bind_load_data_statement(load_statement)

        bind_spy.assert_any_call(load_statement.table_ref)
        mock_create.assert_any_call('table_name')
        mock_tve.assert_called_with(col_name=column.name,
                                    table_alias='table_alias',
                                    col_object=column)
        bind_spy.assert_any_call(built_tve)
        self.assertEqual(load_statement.column_list, [built_tve])
def test_check_duplicate_alias(self):
    """An alias already present in either map is rejected as a duplicate."""
    # Alias already used by a derived table.
    with self.assertRaises(RuntimeError):
        context = StatementBinderContext()
        context._derived_table_alias_map['alias'] = MagicMock()
        context._check_duplicate_alias('alias')

    # Alias already used by a regular table.
    with self.assertRaises(RuntimeError):
        context = StatementBinderContext()
        context._table_alias_map['alias'] = MagicMock()
        context._check_duplicate_alias('alias')

    # Unused alias: the check must pass without raising.
    StatementBinderContext()._check_duplicate_alias('alias')
class StatementBinder:
    """Bind parsed statements and expressions using a binder context.

    ``bind`` dispatches on the node's type (``singledispatchmethod``); each
    ``_bind_*`` method handles one registered node type and mutates the node
    in place (aliases, column objects, UDF callable, column lists).
    """

    def __init__(self, binder_context: StatementBinderContext):
        self._binder_context = binder_context
        self._catalog = CatalogManager()

    @singledispatchmethod
    def bind(self, node):
        """Bind ``node``; types with no registered handler are rejected.

        Raises:
            NotImplementedError: if no handler is registered for the type.
        """
        raise NotImplementedError(f'Cannot bind {type(node)}')

    @bind.register(AbstractStatement)
    def _bind_abstract_statement(self, node: AbstractStatement):
        """Fallback for statements without a dedicated handler: no-op."""
        pass

    @bind.register(AbstractExpression)
    def _bind_abstract_expr(self, node: AbstractExpression):
        """Bind a generic expression by binding each of its children."""
        for child in node.children:
            self.bind(child)

    @bind.register(SelectStatement)
    def _bind_select_statement(self, node: SelectStatement):
        """Bind FROM, WHERE, target list, ORDER BY and union of a SELECT.

        A target list consisting of a single ``*`` column is replaced by the
        expanded column list before binding. The union branch is bound in a
        fresh binder context.
        """
        self.bind(node.from_table)
        if node.where_clause:
            self.bind(node.where_clause)
        if node.target_list:
            # SELECT * support
            if len(node.target_list) == 1 and \
                    isinstance(node.target_list[0], TupleValueExpression) and \
                    node.target_list[0].col_name == '*':
                node.target_list = extend_star(self._binder_context)
            for expr in node.target_list:
                self.bind(expr)
        if node.orderby_list:
            for expr in node.orderby_list:
                # each ORDER BY entry is indexable; element 0 is the expr
                self.bind(expr[0])
        if node.union_link:
            current_context = self._binder_context
            self._binder_context = StatementBinderContext()
            try:
                self.bind(node.union_link)
            finally:
                # Always restore the outer context, even when binding fails.
                self._binder_context = current_context

    @bind.register(CreateMaterializedViewStatement)
    def _bind_create_mat_statement(self,
                                   node: CreateMaterializedViewStatement):
        """Bind the query that defines the materialized view."""
        self.bind(node.query)
        # Todo Verify if the number projected columns matches table

    @bind.register(LoadDataStatement)
    def _bind_load_data_statement(self, node: LoadDataStatement):
        """Bind a LOAD statement and settle its column list.

        For video files the table metadata is created first. If the
        statement names no columns, one ``TupleValueExpression`` per column
        of the table object is generated. ``node.column_list`` is set to the
        bound list.

        Raises:
            RuntimeError: if the table object is missing after binding the
                table reference.
        """
        table_ref = node.table_ref
        if node.file_options['file_format'] == FileFormatType.VIDEO:
            # Create a new metadata object
            create_video_metadata(table_ref.table.table_name)

        self.bind(table_ref)

        table_ref_obj = table_ref.table.table_obj
        if table_ref_obj is None:
            # Implicit literal concatenation instead of a backslash inside
            # the literal, which embedded indentation into the message.
            error = ('{} does not exists. Create the table using '
                     'CREATE TABLE.'.format(table_ref.table.table_name))
            logger.error(error)
            raise RuntimeError(error)

        if node.column_list is not None:
            # if query had columns specified, we just copy them
            column_list = node.column_list
        else:
            # else we curate the column list from the metadata
            column_list = [
                TupleValueExpression(
                    col_name=column.name,
                    table_alias=table_ref_obj.name.lower(),
                    col_object=column)
                for column in table_ref_obj.columns
            ]

        # bind the columns
        for expr in column_list:
            self.bind(expr)

        node.column_list = column_list

    @bind.register(DropTableStatement)
    def _bind_drop_table_statement(self, node: DropTableStatement):
        """Bind every table reference named by the DROP statement."""
        for table in node.table_refs:
            self.bind(table)

    @bind.register(TableRef)
    def _bind_tableref(self, node: TableRef):
        """Bind a table reference according to its kind.

        Table atoms register a table alias; sub-selects are bound in a fresh
        context and registered as a derived table; joins bind both sides and
        the optional predicate; function expressions are bound and
        registered as a derived table under the function's alias.

        Raises:
            ValueError: if the node is none of the supported kinds.
        """
        if node.is_table_atom():
            # Table
            self._binder_context.add_table_alias(node.alias,
                                                 node.table.table_name)
            bind_table_info(node.table)
        elif node.is_select():
            current_context = self._binder_context
            self._binder_context = StatementBinderContext()
            try:
                self.bind(node.select_statement)
            finally:
                # Always restore the outer context, even when binding fails.
                self._binder_context = current_context
            self._binder_context.add_derived_table_alias(
                node.alias, node.select_statement.target_list)
        elif node.is_join():
            self.bind(node.join_node.left)
            self.bind(node.join_node.right)
            if node.join_node.predicate:
                self.bind(node.join_node.predicate)
        elif node.is_func_expr():
            self.bind(node.func_expr)
            self._binder_context.add_derived_table_alias(
                node.func_expr.alias, [node.func_expr])
        else:
            raise ValueError(f'Unsupported node {type(node)}')

    @bind.register(TupleValueExpression)
    def _bind_tuple_expr(self, node: TupleValueExpression):
        """Resolve a column reference and set its alias and column object."""
        table_alias, col_obj = self._binder_context.get_binded_column(
            node.col_name, node.table_alias)
        node.col_alias = '{}.{}'.format(table_alias, node.col_name.lower())
        node.col_object = col_obj

    @bind.register(FunctionExpression)
    def _bind_func_expr(self, node: FunctionExpression):
        """Bind a function expression against its UDF catalog entry.

        Binds the children, defaults the alias to the lower-cased function
        name, and sets ``output_objs``, ``output_col_aliases`` and
        ``function`` on the node. When ``node.output`` is set, only the UDF
        output with that name is exposed.

        Raises:
            AssertionError: if the UDF is not in the catalog, or the
                requested output matches zero or several UDF outputs.
        """
        # bind all the children
        for child in node.children:
            self.bind(child)

        node.alias = node.alias or node.name.lower()
        udf_obj = self._catalog.get_udf_by_name(node.name)
        assert udf_obj is not None, (
            'UDF with name {} does not exist in the catalog. Please '
            'create the UDF using CREATE UDF command'.format(node.name))

        output_objs = self._catalog.get_udf_outputs(udf_obj)
        if node.output:
            # Collect matches locally rather than appending to whatever the
            # node already holds: re-binding an already bound node must not
            # trip the uniqueness check.
            matched = [obj for obj in output_objs
                       if obj.name.lower() == node.output]
            assert matched, (
                'Output {} does not exist in UDF {}'.format(
                    node.output, udf_obj.name))
            assert len(matched) == 1, (
                'Duplicate columns {} in UDF {}'.format(
                    node.output, udf_obj.name))
            output_objs = matched

        node.output_col_aliases = [
            '{}.{}'.format(node.alias, obj.name.lower())
            for obj in output_objs
        ]
        node.output_objs = output_objs

        # path_to_class returns the UDF class; instantiate it right away.
        node.function = path_to_class(udf_obj.impl_file_path,
                                      udf_obj.name)()
def test_bind_create_mat_statement(self):
    """Binding a materialized-view statement binds its defining query."""
    with patch.object(StatementBinder, 'bind') as bind_spy:
        view_statement = MagicMock()
        binder = StatementBinder(StatementBinderContext())

        binder._bind_create_mat_statement(view_statement)

        bind_spy.assert_called_with(view_statement.query)