def test_get_binded_column_raise_error(self):
    """get_binded_column raises RuntimeError when the column cannot be bound.

    Only the call under test is placed inside ``assertRaises`` so that an
    exception raised while building the fixture cannot make the test pass
    by accident.
    """
    # no alias: the search over all alias maps finds nothing
    ctx = StatementBinderContext()
    search_all_mock = MagicMock(return_value=(None, None))
    ctx._search_all_alias_maps = search_all_mock
    with self.assertRaises(RuntimeError):
        ctx.get_binded_column('col_name')
    search_all_mock.assert_called_once_with('col_name')

    # with alias: neither the table nor the derived-table alias map has it
    ctx = StatementBinderContext()
    table_map_mock = MagicMock(return_value=None)
    ctx._check_table_alias_map = table_map_mock
    ctx._check_derived_table_alias_map = MagicMock(return_value=None)
    with self.assertRaises(RuntimeError):
        ctx.get_binded_column('col_name', 'alias')
    table_map_mock.assert_called_once_with('alias', 'col_name')
 def test_get_binded_column_check_table_alias_map(self):
     """With an alias, the column is resolved through the table alias map."""
     table_map_mock = MagicMock(return_value='col_obj')
     ctx = StatementBinderContext()
     ctx._check_table_alias_map = table_map_mock

     binded = ctx.get_binded_column('col_name', 'alias')

     # the lookup receives (alias, column name) and the alias is echoed back
     table_map_mock.assert_called_once_with('alias', 'col_name')
     self.assertEqual(binded, ('alias', 'col_obj'))
    def test_get_binded_column_should_search_all(self):
        """Without an alias, lookup is delegated to _search_all_alias_maps."""
        search_all_mock = MagicMock(return_value=('alias', 'col_obj'))
        ctx = StatementBinderContext()
        ctx._search_all_alias_maps = search_all_mock

        binded = ctx.get_binded_column('col_name')

        # the search result is returned unchanged
        search_all_mock.assert_called_once_with('col_name')
        self.assertEqual(binded, ('alias', 'col_obj'))
示例#4
0
class StatementBinder:
    """Bind parsed statements and expressions in place.

    Binding registers table aliases, attaches column objects and qualified
    aliases to column references, and resolves function expressions to their
    UDF implementations from the catalog. Name resolution is delegated to the
    current ``StatementBinderContext``; nested scopes (sub-queries and UNION
    branches) are bound with a fresh context that is always restored
    afterwards.
    """

    def __init__(self, binder_context: StatementBinderContext):
        self._binder_context = binder_context
        self._catalog = CatalogManager()

    @singledispatchmethod
    def bind(self, node):
        """Bind ``node`` in place, dispatching on its type.

        Raises:
            NotImplementedError: if no handler is registered for the type.
        """
        raise NotImplementedError(f'Cannot bind {type(node)}')

    @bind.register(AbstractStatement)
    def _bind_abstract_statement(self, node: AbstractStatement):
        """Statements without a dedicated handler need no binding."""
        pass

    @bind.register(AbstractExpression)
    def _bind_abstract_expr(self, node: AbstractExpression):
        """Default for expressions: bind every child expression."""
        for child in node.children:
            self.bind(child)

    @bind.register(SelectStatement)
    def _bind_select_statement(self, node: SelectStatement):
        """Bind the FROM, WHERE, projection, ORDER BY and UNION parts.

        FROM is bound first so its aliases are registered in the context
        before any column reference is resolved. A lone ``*`` projection is
        replaced by the columns produced by ``extend_star``.
        """
        self.bind(node.from_table)
        if node.where_clause:
            self.bind(node.where_clause)
        if node.target_list:
            # SELECT * support
            targets = node.target_list
            if (len(targets) == 1
                    and isinstance(targets[0], TupleValueExpression)
                    and targets[0].col_name == '*'):
                node.target_list = extend_star(self._binder_context)
            for expr in node.target_list:
                self.bind(expr)
        if node.orderby_list:
            # presumably each entry is an (expression, sort order) pair;
            # only the expression is bound
            for expr in node.orderby_list:
                self.bind(expr[0])
        if node.union_link:
            # The UNION branch is bound in its own scope; the outer context
            # is restored even if binding the branch raises.
            outer_context = self._binder_context
            self._binder_context = StatementBinderContext()
            try:
                self.bind(node.union_link)
            finally:
                self._binder_context = outer_context

    @bind.register(CreateMaterializedViewStatement)
    def _bind_create_mat_statement(self,
                                   node: CreateMaterializedViewStatement):
        """Bind the query that defines a materialized view."""
        self.bind(node.query)
        # TODO: verify that the number of projected columns matches the table

    @bind.register(LoadDataStatement)
    def _bind_load_data_statement(self, node: LoadDataStatement):
        """Bind the target table and column list of a LOAD statement.

        For video loads, ``create_video_metadata`` is called for the target
        table before it is bound. When the statement names no columns, every
        column of the target table is used.

        Raises:
            RuntimeError: if the target table has no catalog object after
                binding.
        """
        table_ref = node.table_ref
        if node.file_options['file_format'] == FileFormatType.VIDEO:
            # Create a new metadata object
            create_video_metadata(table_ref.table.table_name)

        self.bind(table_ref)

        table_ref_obj = table_ref.table.table_obj
        if table_ref_obj is None:
            # Implicit literal concatenation keeps the message free of the
            # stray indentation a backslash continuation would embed.
            error = ('{} does not exist. Create the table using '
                     'CREATE TABLE.'.format(table_ref.table.table_name))
            logger.error(error)
            raise RuntimeError(error)

        if node.column_list is not None:
            # Columns named explicitly in the query are used as-is.
            column_list = node.column_list
        else:
            # Otherwise use every column of the table, qualified by the
            # table's lowercased name.
            table_alias = table_ref_obj.name.lower()
            column_list = [
                TupleValueExpression(col_name=column.name,
                                     table_alias=table_alias,
                                     col_object=column)
                for column in table_ref_obj.columns
            ]

        # bind the columns
        for expr in column_list:
            self.bind(expr)

        node.column_list = column_list

    @bind.register(DropTableStatement)
    def _bind_drop_table_statement(self, node: DropTableStatement):
        """Bind every table referenced by a DROP TABLE statement."""
        for table in node.table_refs:
            self.bind(table)

    @bind.register(TableRef)
    def _bind_tableref(self, node: TableRef):
        """Bind a FROM-clause item: table, sub-query, join or table function.

        Raises:
            ValueError: if the table reference is of none of those kinds.
        """
        if node.is_table_atom():
            # Table
            self._binder_context.add_table_alias(node.alias,
                                                 node.table.table_name)
            bind_table_info(node.table)
        elif node.is_select():
            # A sub-query has its own scope; the outer context is restored
            # even if binding the sub-query raises.
            outer_context = self._binder_context
            self._binder_context = StatementBinderContext()
            try:
                self.bind(node.select_statement)
            finally:
                self._binder_context = outer_context
            self._binder_context.add_derived_table_alias(
                node.alias, node.select_statement.target_list)
        elif node.is_join():
            self.bind(node.join_node.left)
            self.bind(node.join_node.right)
            if node.join_node.predicate:
                self.bind(node.join_node.predicate)
        elif node.is_func_expr():
            self.bind(node.func_expr)
            # The function's alias (defaulted while binding it) names the
            # derived table.
            self._binder_context.add_derived_table_alias(
                node.func_expr.alias, [node.func_expr])
        else:
            raise ValueError(f'Unsupported node {type(node)}')

    @bind.register(TupleValueExpression)
    def _bind_tuple_expr(self, node: TupleValueExpression):
        """Resolve a column reference through the binder context.

        Sets ``col_alias`` to ``<table_alias>.<lowercased col_name>`` and
        attaches the resolved column object.
        """
        table_alias, col_obj = self._binder_context.get_binded_column(
            node.col_name, node.table_alias)
        node.col_alias = '{}.{}'.format(table_alias, node.col_name.lower())
        node.col_object = col_obj

    @bind.register(FunctionExpression)
    def _bind_func_expr(self, node: FunctionExpression):
        """Bind a UDF call: its arguments, output columns and implementation.

        The alias defaults to the lowercased function name. If ``node.output``
        is set, only the UDF output with that (case-insensitive) name is
        exposed; otherwise all UDF outputs are exposed.

        Raises:
            AssertionError: if the UDF is not in the catalog, or if
                ``node.output`` does not match exactly one UDF output.
        """
        # bind all the children
        for child in node.children:
            self.bind(child)

        node.alias = node.alias or node.name.lower()
        udf_obj = self._catalog.get_udf_by_name(node.name)
        # Explicit raise rather than ``assert`` so the check survives
        # ``python -O``; AssertionError is kept for existing callers.
        if udf_obj is None:
            raise AssertionError(
                'UDF with name {} does not exist in the catalog. Please '
                'create the UDF using CREATE UDF command'.format(node.name))

        output_objs = self._catalog.get_udf_outputs(udf_obj)
        if node.output:
            requested = node.output.lower()
            matches = [obj for obj in output_objs
                       if obj.name.lower() == requested]
            if not matches:
                raise AssertionError(
                    'Output {} does not exist in UDF {}'.format(
                        node.output, udf_obj.name))
            if len(matches) > 1:
                raise AssertionError(
                    'Duplicate columns {} in UDF {}'.format(
                        node.output, udf_obj.name))
            output_objs = matches

        # Assign rather than append so that binding the same node again does
        # not accumulate stale aliases.
        node.output_col_aliases = [
            '{}.{}'.format(node.alias, obj.name.lower())
            for obj in output_objs
        ]
        node.output_objs = output_objs

        node.function = path_to_class(udf_obj.impl_file_path, udf_obj.name)()