def test_get_binded_column_raise_error(self):
    """get_binded_column raises RuntimeError when the column cannot be
    resolved, both without and with a table alias."""
    # no alias: the search over all alias maps finds nothing
    ctx = StatementBinderContext()
    mock_search_all = ctx._search_all_alias_maps = MagicMock()
    mock_search_all.return_value = (None, None)
    # Only the call under test sits inside assertRaises, so a RuntimeError
    # raised while building the fixture cannot make the test pass by accident.
    with self.assertRaises(RuntimeError):
        ctx.get_binded_column('col_name')
    mock_search_all.assert_called_once_with('col_name')

    # with alias: neither the table map nor the derived-table map knows it
    ctx = StatementBinderContext()
    mock_table_map = ctx._check_table_alias_map = MagicMock()
    mock_table_map.return_value = None
    mock_derived_map = ctx._check_derived_table_alias_map = MagicMock()
    mock_derived_map.return_value = None
    with self.assertRaises(RuntimeError):
        ctx.get_binded_column('col_name', 'alias')
    mock_table_map.assert_called_once_with('alias', 'col_name')
def test_get_binded_column_check_table_alias_map(self):
    """With an explicit alias, the column is resolved through the table
    alias map and returned together with that alias."""
    binder_ctx = StatementBinderContext()
    table_lookup = MagicMock(return_value='col_obj')
    binder_ctx._check_table_alias_map = table_lookup

    binded = binder_ctx.get_binded_column('col_name', 'alias')

    table_lookup.assert_called_once_with('alias', 'col_name')
    self.assertEqual(binded, ('alias', 'col_obj'))
def test_get_binded_column_should_search_all(self):
    """Without an alias, the column is resolved by searching every alias
    map, and that search's (alias, column) pair is returned unchanged."""
    binder_ctx = StatementBinderContext()
    global_search = MagicMock(return_value=('alias', 'col_obj'))
    binder_ctx._search_all_alias_maps = global_search

    binded = binder_ctx.get_binded_column('col_name')

    global_search.assert_called_once_with('col_name')
    self.assertEqual(binded, ('alias', 'col_obj'))
class StatementBinder:
    """Bind parsed statements and expressions in place.

    ``bind`` dispatches on the node's type. The registered handlers walk the
    tree, record table / derived-table aliases in the binder context, resolve
    column references through that context and attach UDF catalog metadata
    to function expressions. Handlers mutate the nodes and return nothing.
    """

    def __init__(self, binder_context: StatementBinderContext):
        # Alias scope of the query currently being bound; replaced by a
        # fresh context while a nested query (UNION / sub-select) is bound.
        self._binder_context = binder_context
        self._catalog = CatalogManager()

    def _bind_in_new_scope(self, node):
        """Bind ``node`` inside a fresh StatementBinderContext.

        The enclosing context is restored afterwards even if binding raises,
        so a failed nested bind cannot leak its scope into the outer query.
        """
        outer_context = self._binder_context
        self._binder_context = StatementBinderContext()
        try:
            self.bind(node)
        finally:
            self._binder_context = outer_context

    @singledispatchmethod
    def bind(self, node):
        """Bind ``node`` using the handler registered for its type.

        Raises:
            NotImplementedError: if no handler is registered for
                ``type(node)``.
        """
        raise NotImplementedError(f'Cannot bind {type(node)}')

    @bind.register(AbstractStatement)
    def _bind_abstract_statement(self, node: AbstractStatement):
        # Statements without a dedicated handler need no binding.
        pass

    @bind.register(AbstractExpression)
    def _bind_abstract_expr(self, node: AbstractExpression):
        """Generic expression: bind every child expression."""
        for child in node.children:
            self.bind(child)

    @bind.register(SelectStatement)
    def _bind_select_statement(self, node: SelectStatement):
        """Bind a SELECT.

        FROM is bound first so that its aliases are registered before WHERE,
        the target list and ORDER BY are resolved. A UNION-ed query is bound
        in its own scope.
        """
        self.bind(node.from_table)
        if node.where_clause:
            self.bind(node.where_clause)
        if node.target_list:
            # SELECT * support: a lone '*' column is replaced by the
            # expressions extend_star derives from the current context.
            if (len(node.target_list) == 1
                    and isinstance(node.target_list[0], TupleValueExpression)
                    and node.target_list[0].col_name == '*'):
                node.target_list = extend_star(self._binder_context)
            for expr in node.target_list:
                self.bind(expr)
        if node.orderby_list:
            # Only element 0 of each entry is bound; presumably entries are
            # (expression, sort order) pairs -- confirm against the parser.
            for expr in node.orderby_list:
                self.bind(expr[0])
        if node.union_link:
            self._bind_in_new_scope(node.union_link)

    @bind.register(CreateMaterializedViewStatement)
    def _bind_create_mat_statement(self,
                                   node: CreateMaterializedViewStatement):
        """Bind the query that defines the materialized view."""
        self.bind(node.query)
        # Todo Verify if the number projected columns matches table

    @bind.register(LoadDataStatement)
    def _bind_load_data_statement(self, node: LoadDataStatement):
        """Bind the target table and column list of a LOAD statement.

        If the statement lists no columns, the list is curated from the
        table's catalog columns and written back to ``node.column_list``.

        Raises:
            RuntimeError: if the target table has no catalog object after
                binding.
        """
        table_ref = node.table_ref
        if node.file_options['file_format'] == FileFormatType.VIDEO:
            # Create a new metadata object for the video table before the
            # table reference is bound below.
            create_video_metadata(table_ref.table.table_name)
        self.bind(table_ref)

        table_ref_obj = table_ref.table.table_obj
        if table_ref_obj is None:
            error = (f'{table_ref.table.table_name} does not exist. '
                     'Create the table using CREATE TABLE.')
            logger.error(error)
            raise RuntimeError(error)

        if node.column_list is not None:
            # the query specified its columns; use them as given
            column_list = node.column_list
        else:
            # otherwise curate the column list from the table metadata
            column_list = [
                TupleValueExpression(
                    col_name=column.name,
                    table_alias=table_ref_obj.name.lower(),
                    col_object=column)
                for column in table_ref_obj.columns
            ]
        for expr in column_list:
            self.bind(expr)
        node.column_list = column_list

    @bind.register(DropTableStatement)
    def _bind_drop_table_statement(self, node: DropTableStatement):
        """Bind every table referenced by the DROP statement."""
        for table in node.table_refs:
            self.bind(table)

    @bind.register(TableRef)
    def _bind_tableref(self, node: TableRef):
        """Bind one FROM item and register its alias in the binder context.

        Raises:
            ValueError: if the TableRef is none of table atom, sub-select,
                join or function expression.
        """
        if node.is_table_atom():
            # Table
            self._binder_context.add_table_alias(
                node.alias, node.table.table_name)
            bind_table_info(node.table)
        elif node.is_select():
            # The sub-select gets its own scope; only its alias and target
            # list are registered in the enclosing query's context.
            self._bind_in_new_scope(node.select_statement)
            self._binder_context.add_derived_table_alias(
                node.alias, node.select_statement.target_list)
        elif node.is_join():
            self.bind(node.join_node.left)
            self.bind(node.join_node.right)
            if node.join_node.predicate:
                self.bind(node.join_node.predicate)
        elif node.is_func_expr():
            self.bind(node.func_expr)
            self._binder_context.add_derived_table_alias(
                node.func_expr.alias, [node.func_expr])
        else:
            raise ValueError(f'Unsupported node {type(node)}')

    @bind.register(TupleValueExpression)
    def _bind_tuple_expr(self, node: TupleValueExpression):
        """Resolve a column reference through the binder context.

        Sets ``node.col_alias`` to ``'<table_alias>.<col_name lowercased>'``
        and ``node.col_object`` to the resolved column object.
        """
        table_alias, col_obj = self._binder_context.get_binded_column(
            node.col_name, node.table_alias)
        node.col_alias = '{}.{}'.format(table_alias, node.col_name.lower())
        node.col_object = col_obj

    @bind.register(FunctionExpression)
    def _bind_func_expr(self, node: FunctionExpression):
        """Bind a UDF call: its arguments, catalog entry, outputs and
        implementation.

        Defaults ``node.alias`` to the lowercased UDF name. It then fills
        ``node.output_col_aliases`` / ``node.output_objs`` with either the
        single requested ``node.output`` or every UDF output, and sets
        ``node.function`` to an instance of the UDF implementation.

        Raises:
            AssertionError: if the UDF is not in the catalog, or if
                ``node.output`` does not match exactly one UDF output.
        """
        # bind all the children
        for child in node.children:
            self.bind(child)

        node.alias = node.alias or node.name.lower()
        # NOTE(review): these asserts are skipped under `python -O`; they are
        # kept as asserts only because callers may expect AssertionError.
        udf_obj = self._catalog.get_udf_by_name(node.name)
        assert udf_obj is not None, (
            'UDF with name {} does not exist in the catalog. Please '
            'create the UDF using CREATE UDF command'.format(node.name))

        output_objs = self._catalog.get_udf_outputs(udf_obj)
        if node.output:
            # Compare case-insensitively on both sides; previously only the
            # catalog name was lowercased, so a mixed-case request never
            # matched.
            wanted = node.output.lower()
            matches = [obj for obj in output_objs
                       if obj.name.lower() == wanted]
            assert matches, 'Output {} does not exist in UDF {}'.format(
                node.output, udf_obj.name)
            assert len(matches) == 1, (
                'Duplicate columns {} in UDF {}'.format(
                    node.output, udf_obj.name))
            # Rebuild (not append) so re-binding the node is idempotent.
            node.output_col_aliases = [
                '{}.{}'.format(node.alias, matches[0].name.lower())]
            node.output_objs = matches
        else:
            node.output_col_aliases = [
                '{}.{}'.format(node.alias, obj.name.lower())
                for obj in output_objs
            ]
            node.output_objs = output_objs

        # path_to_class presumably loads the class named udf_obj.name from
        # udf_obj.impl_file_path; it is instantiated with no arguments.
        node.function = path_to_class(udf_obj.impl_file_path,
                                      udf_obj.name)()