def test_get_binded_column_check_table_alias_map(self):
    """An explicit alias is resolved through _check_table_alias_map."""
    ctx = StatementBinderContext()
    ctx._check_table_alias_map = MagicMock(return_value='col_obj')

    binded = ctx.get_binded_column('col_name', 'alias')

    ctx._check_table_alias_map.assert_called_once_with('alias', 'col_name')
    self.assertEqual(binded, ('alias', 'col_obj'))
    def test_get_binded_column_should_search_all(self):
        """Without an alias, every alias map is searched for the column."""
        ctx = StatementBinderContext()
        search_all = MagicMock(return_value=('alias', 'col_obj'))
        ctx._search_all_alias_maps = search_all

        binded = ctx.get_binded_column('col_name')

        search_all.assert_called_once_with('col_name')
        self.assertEqual(binded, ('alias', 'col_obj'))
    def test_add_table_alias(self, mock_catalog):
        """add_table_alias checks for duplicates and stores catalog metadata."""
        get_metadata = MagicMock(return_value='table_obj')
        mock_catalog().get_dataset_metadata = get_metadata
        ctx = StatementBinderContext()

        check_dup = MagicMock()
        ctx._check_duplicate_alias = check_dup
        ctx.add_table_alias('alias', 'table_name')

        check_dup.assert_called_with('alias')
        get_metadata.assert_called_with(None, 'table_name')
        self.assertEqual(ctx._table_alias_map['alias'], 'table_obj')
    def test_add_derived_table_alias(self):
        """Derived aliases collect column objects from TVEs and UDF outputs."""
        tve = MagicMock(spec=TupleValueExpression, col_object='A')
        func = MagicMock(spec=FunctionExpression, output_objs=['B', 'C'])
        ctx = StatementBinderContext()

        check_dup = MagicMock()
        ctx._check_duplicate_alias = check_dup
        ctx.add_derived_table_alias('alias', [tve, func])

        check_dup.assert_called_with('alias')
        self.assertEqual(ctx._derived_table_alias_map['alias'],
                         ['A', 'B', 'C'])
# Example #5
# 0
    def test_bind_select_statement_union_starts_new_context(self, mock_ctx):
        """A fresh binder context is created only when a UNION is present."""
        with patch.object(StatementBinder, 'bind'):
            # No union link: the current context is reused.
            no_union_binder = StatementBinder(StatementBinderContext())
            no_union = MagicMock()
            no_union.union_link = None
            no_union_binder._bind_select_statement(no_union)
            self.assertEqual(mock_ctx.call_count, 0)

            # An unconfigured MagicMock union_link is truthy, so the
            # union branch runs and a new context is created.
            union_binder = StatementBinder(StatementBinderContext())
            with_union = MagicMock()
            union_binder._bind_select_statement(with_union)
            self.assertEqual(mock_ctx.call_count, 1)
# Example #6
# 0
    def test_bind_unknown_object(self):
        """Binding an object of an unregistered type raises NotImplementedError."""
        class UnknownType:
            pass

        binder = StatementBinder(StatementBinderContext())
        # Only the bind call is expected to raise; keeping construction
        # outside assertRaises stops a constructor failure from passing.
        with self.assertRaises(NotImplementedError):
            binder.bind(UnknownType())
 def test_search_all_alias_raise_duplicate_error(self):
     """A column found through both alias maps raises RuntimeError."""
     ctx = StatementBinderContext()
     # Both checks return a truthy MagicMock, so the column matches twice.
     ctx._check_table_alias_map = MagicMock()
     ctx._check_derived_table_alias_map = MagicMock()
     # duplicate: the same alias is registered in both maps
     ctx._table_alias_map['alias'] = 'col_name'
     ctx._derived_table_alias_map['alias'] = 'col_name'
     # Only the search itself should raise, not the setup above.
     with self.assertRaises(RuntimeError):
         ctx._search_all_alias_maps('col_name')
# Example #8
# 0
def extend_star(binder_context: StatementBinderContext) \
        -> List[TupleValueExpression]:
    """Expand ``SELECT *`` into one TupleValueExpression per visible column.

    Args:
        binder_context: context whose ``(alias, col_name)`` pairs are
            obtained from ``_get_all_alias_and_col_name()``.

    Returns:
        A list of ``TupleValueExpression(col_name=..., table_alias=...)``,
        in the order the context yields the pairs.
    """
    col_objs = binder_context._get_all_alias_and_col_name()
    return [TupleValueExpression(col_name=col_name, table_alias=alias)
            for alias, col_name in col_objs]
# Example #9
# 0
 def test_bind_tableref_with_func_expr(self):
     """A function-expression table reference binds its func_expr."""
     with patch.object(StatementBinder, 'bind') as mock_binder:
         binder = StatementBinder(StatementBinderContext())
         tableref = MagicMock()
         tableref.is_table_atom.return_value = False
         tableref.is_select.return_value = False
         tableref.is_join.return_value = False
         # Select the func_expr branch explicitly instead of relying on an
         # unconfigured MagicMock method returning a truthy value.
         tableref.is_func_expr.return_value = True
         binder._bind_tableref(tableref)
         mock_binder.assert_called_with(tableref.func_expr)
# Example #10
# 0
 def test_bind_load_data_raises(self, mock_tve, mock_create):
     """Loading into a table without a catalog object raises RuntimeError."""
     load_statement = MagicMock()
     column = MagicMock()
     load_statement.column_list = [column]
     # table_obj stays None after (mocked) binding: the table does not exist.
     load_statement.table_ref.table.table_obj = None
     with patch.object(StatementBinder, 'bind'):
         binder = StatementBinder(StatementBinderContext())
         # Only the LOAD binding should raise, not the binder construction.
         with self.assertRaises(RuntimeError):
             binder._bind_load_data_statement(load_statement)
# Example #11
# 0
 def test_bind_tableref_starts_new_context(self, mock_ctx):
     """A sub-select table reference is bound in a fresh context."""
     with patch.object(StatementBinder, 'bind'):
         binder = StatementBinder(StatementBinderContext())
         subquery_ref = MagicMock()
         subquery_ref.configure_mock(**{
             'is_table_atom.return_value': False,
             'is_join.return_value': False,
             'is_select.return_value': True,
         })
         binder._bind_tableref(subquery_ref)
         self.assertEqual(mock_ctx.call_count, 1)
# Example #12
# 0
 def test_bind_tableref_with_join(self):
     """A join binds its left side, right side and join predicate."""
     with patch.object(StatementBinder, 'bind') as mock_binder:
         binder = StatementBinder(StatementBinderContext())
         tableref = MagicMock()
         tableref.is_table_atom.return_value = False
         tableref.is_select.return_value = False
         tableref.is_join.return_value = True
         binder._bind_tableref(tableref)
         mock_binder.assert_any_call(tableref.join_node.left)
         mock_binder.assert_any_call(tableref.join_node.right)
         # join_node.predicate is an unconfigured (truthy) MagicMock, so the
         # predicate branch runs and it must be bound as well.
         mock_binder.assert_any_call(tableref.join_node.predicate)
# Example #13
# 0
    def test_bind_func_expr(self, mock_path_to_class, mock_catalog):
        """_bind_func_expr resolves the UDF and the outputs it exposes."""
        # setup
        func_expr = MagicMock(alias='func_expr', output_col_aliases=[])
        out1 = MagicMock()
        out1.name.lower.return_value = 'out1'
        out2 = MagicMock()
        out2.name.lower.return_value = 'out2'
        udf_outputs = [out1, out2]
        udf_obj = MagicMock()

        catalog = mock_catalog()
        get_udf_by_name = MagicMock(return_value=udf_obj)
        catalog.get_udf_by_name = get_udf_by_name
        get_udf_outputs = MagicMock(return_value=udf_outputs)
        catalog.get_udf_outputs = get_udf_outputs
        # path_to_class(...) returns a class that the binder then calls.
        mock_path_to_class.return_value.return_value = 'path_to_class'

        def bind_and_check(expected_objs, expected_aliases):
            StatementBinder(StatementBinderContext())._bind_func_expr(
                func_expr)
            get_udf_by_name.assert_called_with(func_expr.name)
            get_udf_outputs.assert_called_with(udf_obj)
            mock_path_to_class.assert_called_with(udf_obj.impl_file_path,
                                                  udf_obj.name)
            self.assertEqual(func_expr.output_objs, expected_objs)
            self.assertEqual(func_expr.output_col_aliases, expected_aliases)
            self.assertEqual(func_expr.function, 'path_to_class')

        # Case 1: a single output is requested.
        func_expr.output = 'out1'
        bind_and_check([out1], ['func_expr.out1'])

        # Case 2: no output requested, so every UDF output is exposed.
        func_expr.output = None
        bind_and_check(udf_outputs, ['func_expr.out1', 'func_expr.out2'])
# Example #14
# 0
 def test_bind_tableref_should_raise(self):
     """A table reference of no known kind is rejected with ValueError."""
     with patch.object(StatementBinder, 'bind'):
         binder = StatementBinder(StatementBinderContext())
         tableref = MagicMock()
         tableref.is_select.return_value = False
         tableref.is_func_expr.return_value = False
         tableref.is_join.return_value = False
         tableref.is_table_atom.return_value = False
         # Only the tableref binding should raise, not the setup above.
         with self.assertRaises(ValueError):
             binder._bind_tableref(tableref)
# Example #15
# 0
    def test_bind_tableref(self, mock_bind_tabe_info):
        """Table atoms register an alias; sub-selects register a derived one."""
        # Case 1: a plain table reference.
        with patch.object(StatementBinderContext,
                          'add_table_alias') as add_alias:
            atom_binder = StatementBinder(StatementBinderContext())
            atom_ref = MagicMock()
            atom_ref.is_table_atom.return_value = True
            atom_binder._bind_tableref(atom_ref)
            add_alias.assert_called_with(atom_ref.alias,
                                         atom_ref.table.table_name)
            mock_bind_tabe_info.assert_called_once_with(atom_ref.table)

        # Case 2: a sub-select table reference.
        with patch.object(StatementBinder, 'bind') as bind, \
                patch.object(StatementBinderContext,
                             'add_derived_table_alias') as add_derived:
            select_binder = StatementBinder(StatementBinderContext())
            select_ref = MagicMock()
            select_ref.is_table_atom.return_value = False
            select_ref.is_select.return_value = True
            select_binder._bind_tableref(select_ref)
            add_derived.assert_called_with(
                select_ref.alias, select_ref.select_statement.target_list)
            bind.assert_called_with(select_ref.select_statement)
# Example #16
# 0
 def test_bind_tuple_value_expression(self):
     """A TVE receives its column object and a table-qualified alias."""
     with patch.object(StatementBinderContext,
                       'get_binded_column') as get_column:
         get_column.return_value = ['table_alias', 'col_obj']
         binder = StatementBinder(StatementBinderContext())
         tve = MagicMock()
         tve.col_name = 'col_name'
         binder._bind_tuple_expr(tve)
         get_column.assert_called_with(tve.col_name, tve.table_alias)
         self.assertEqual(tve.col_object, 'col_obj')
         self.assertEqual(tve.col_alias, 'table_alias.col_name')
# Example #17
# 0
 def _bind_select_statement(self, node: SelectStatement):
     """Bind every clause of a SELECT statement.

     The FROM clause is bound first, then WHERE, the target list, ORDER BY
     and finally any UNION branch. A target list consisting of a single
     ``*`` column is replaced by the columns from ``extend_star`` before
     binding. The UNION branch is bound in a fresh StatementBinderContext,
     and the caller's context is restored afterwards even if that raises.
     """
     self.bind(node.from_table)
     if node.where_clause:
         self.bind(node.where_clause)
     if node.target_list:
         # SELECT * support
         if len(node.target_list) == 1 and \
                 isinstance(node.target_list[0], TupleValueExpression) and \
                 node.target_list[0].col_name == '*':
             node.target_list = extend_star(self._binder_context)
         for expr in node.target_list:
             self.bind(expr)
     if node.orderby_list:
         # Only element 0 of each ORDER BY entry (the expression) is bound.
         for expr in node.orderby_list:
             self.bind(expr[0])
     if node.union_link:
         current_context = self._binder_context
         self._binder_context = StatementBinderContext()
         try:
             self.bind(node.union_link)
         finally:
             # Never leave the binder on the UNION branch's context.
             self._binder_context = current_context
# Example #18
# 0
def execute_query(query) -> Iterator[Batch]:
    """
    Execute the query and return a result generator.

    Only the first statement returned by the parser is executed. The
    statement is bound, converted to a logical plan, turned into a physical
    plan and handed to PlanExecutor.

    Raises:
        RuntimeError: if binding fails; the original binder exception is
            chained as ``__cause__`` so its traceback is preserved.
    """
    stmt = Parser().parse(query)[0]
    try:
        StatementBinder(StatementBinderContext()).bind(stmt)
    except Exception as error:
        raise RuntimeError(f'Binder failed: {error}') from error
    l_plan = StatementToPlanConvertor().visit(stmt)
    p_plan = PlanGenerator().build(l_plan)
    return PlanExecutor(p_plan).execute_plan()
# Example #19
# 0
 def test_bind_select_statement(self):
     """Every clause of a SELECT statement is forwarded to bind()."""
     with patch.object(StatementBinder, 'bind') as bind:
         binder = StatementBinder(StatementBinderContext())
         select_statement = MagicMock()
         targets = [MagicMock(), MagicMock()]
         order_exprs = [MagicMock(), MagicMock()]
         select_statement.target_list = targets
         select_statement.orderby_list = [(expr, 0) for expr in order_exprs]
         binder._bind_select_statement(select_statement)
         expected_bound = [select_statement.from_table,
                           select_statement.where_clause,
                           select_statement.union_link,
                           *targets,
                           *order_exprs]
         for clause in expected_bound:
             bind.assert_any_call(clause)
 def test_get_binded_column_raise_error(self):
     """An unresolvable column raises RuntimeError, with or without alias."""
     # no alias: the search over all maps finds nothing
     ctx = StatementBinderContext()
     mock_search_all = ctx._search_all_alias_maps = MagicMock()
     mock_search_all.return_value = (None, None)
     with self.assertRaises(RuntimeError):
         ctx.get_binded_column('col_name')
     # with alias: neither the table map nor the derived map resolves it
     ctx = StatementBinderContext()
     mock_table_map = ctx._check_table_alias_map = MagicMock()
     mock_table_map.return_value = None
     mock_derived_map = ctx._check_derived_table_alias_map = MagicMock()
     mock_derived_map.return_value = None
     with self.assertRaises(RuntimeError):
         ctx.get_binded_column('col_name', 'alias')
# Example #21
# 0
    def test_bind_load_data(self, mock_tve, mock_create):
        """An explicit column list is bound as-is; no columns are derived."""
        load_statement = MagicMock()
        column = MagicMock()
        load_statement.column_list = [column]

        # Attach the table fixture (previously built but never used) so the
        # table-exists check sees it rather than an auto-created attribute.
        table_ref_obj = MagicMock()
        table_ref_obj.columns = [column]
        load_statement.table_ref.table.table_obj = table_ref_obj
        # file_options stays a MagicMock, so its 'file_format' entry is not
        # FileFormatType.VIDEO and no video metadata is created.

        with patch.object(StatementBinder, 'bind') as mock_binder:
            binder = StatementBinder(StatementBinderContext())
            binder._bind_load_data_statement(load_statement)
            mock_binder.assert_any_call(load_statement.table_ref)
            mock_create.assert_not_called()
            mock_tve.assert_not_called()
            mock_binder.assert_any_call(column)
# Example #22
# 0
 def _bind_tableref(self, node: TableRef):
     """Bind a FROM-clause table reference and register its alias.

     Handles four kinds of reference: a plain table, a sub-select (bound in
     a fresh StatementBinderContext and registered as a derived alias), a
     join (both sides plus an optional predicate) and a function
     expression (registered as a derived alias under the function's alias).

     Raises:
         ValueError: if the reference is none of the supported kinds.
     """
     if node.is_table_atom():
         # Table
         self._binder_context.add_table_alias(node.alias,
                                              node.table.table_name)
         bind_table_info(node.table)
     elif node.is_select():
         current_context = self._binder_context
         self._binder_context = StatementBinderContext()
         try:
             self.bind(node.select_statement)
         finally:
             # Restore the outer context even if binding the sub-query fails.
             self._binder_context = current_context
         self._binder_context.add_derived_table_alias(
             node.alias, node.select_statement.target_list)
     elif node.is_join():
         self.bind(node.join_node.left)
         self.bind(node.join_node.right)
         if node.join_node.predicate:
             self.bind(node.join_node.predicate)
     elif node.is_func_expr():
         self.bind(node.func_expr)
         self._binder_context.add_derived_table_alias(
             node.func_expr.alias, [node.func_expr])
     else:
         raise ValueError(f'Unsupported node {type(node)}')
    def test_check_table_alias_map(self, mock_catalog):
        """The catalog is only consulted when the alias is registered."""
        get_column_object = MagicMock(return_value='catalog_value')
        mock_catalog().get_column_object = get_column_object

        # Registered alias: the catalog resolves the column.
        known = StatementBinderContext()
        known._table_alias_map['alias'] = 'table_obj'
        found = known._check_table_alias_map('alias', 'col_name')
        get_column_object.assert_called_once_with('table_obj', 'col_name')
        self.assertEqual(found, 'catalog_value')

        # Unregistered alias: nothing is looked up.
        get_column_object.reset_mock()
        unknown = StatementBinderContext()
        missing = unknown._check_table_alias_map('alias', 'col_name')
        get_column_object.assert_not_called()
        self.assertEqual(missing, None)
    def test_check_derived_table_alias_map(self):
        """A derived alias yields the column object whose name matches."""
        # Registered alias: the matching object is returned.
        known = StatementBinderContext()
        first, second = MagicMock(), MagicMock()
        first.name.lower.return_value = 'col_name1'
        second.name.lower.return_value = 'col_name2'
        known._derived_table_alias_map['alias'] = [first, second]
        self.assertEqual(
            known._check_derived_table_alias_map('alias', 'col_name1'), first)

        # Unregistered alias: nothing is found.
        unknown = StatementBinderContext()
        self.assertEqual(
            unknown._check_derived_table_alias_map('alias', 'col_name'), None)
    def test_search_all_alias_maps(self):
        """_search_all_alias_maps consults whichever map holds the alias."""
        ctx = StatementBinderContext()
        check_table_map = ctx._check_table_alias_map = MagicMock()
        check_derived_map = ctx._check_derived_table_alias_map = MagicMock()

        # only _table_alias_map has an entry
        check_table_map.return_value = 'col_obj'
        ctx._table_alias_map['alias'] = 'col_name'
        ctx._derived_table_alias_map = {}
        result = ctx._search_all_alias_maps('col_name')
        check_table_map.assert_called_once_with('alias', 'col_name')
        check_derived_map.assert_not_called()
        self.assertEqual(result, ('alias', 'col_obj'))

        # only _derived_table_alias_map has an entry
        check_table_map.reset_mock()
        check_derived_map.return_value = 'derived_col_obj'
        ctx._table_alias_map = {}
        ctx._derived_table_alias_map['alias'] = 'col_name'
        result = ctx._search_all_alias_maps('col_name')
        # The original asserted on check_table_map twice, so the derived-map
        # lookup was never actually verified.
        check_table_map.assert_not_called()
        check_derived_map.assert_called_once_with('alias', 'col_name')
        self.assertEqual(result, ('alias', 'derived_col_obj'))
# Example #26
# 0
    def test_bind_load_video_statement(self, mock_tve, mock_create):
        """A video LOAD with no column list derives columns from metadata."""
        load_statement = MagicMock()
        load_statement.file_options = {'file_format': FileFormatType.VIDEO}
        load_statement.column_list = None

        column = MagicMock()
        table_obj = MagicMock()
        table_obj.columns = [column]
        table_obj.name = 'table_alias'
        target_table = load_statement.table_ref.table
        target_table.table_obj = table_obj
        target_table.table_name = 'table_name'

        created_tve = MagicMock()
        mock_tve.return_value = created_tve

        with patch.object(StatementBinder, 'bind') as bind:
            binder = StatementBinder(StatementBinderContext())
            binder._bind_load_data_statement(load_statement)
            bind.assert_any_call(load_statement.table_ref)
            mock_create.assert_any_call('table_name')
            mock_tve.assert_called_with(col_name=column.name,
                                        table_alias='table_alias',
                                        col_object=column)
            bind.assert_any_call(created_tve)
            self.assertEqual(load_statement.column_list, [created_tve])
    def test_check_duplicate_alias(self):
        """An alias already registered in either alias map is rejected."""
        # alias already present in the derived table map
        ctx = StatementBinderContext()
        ctx._derived_table_alias_map['alias'] = MagicMock()
        with self.assertRaises(RuntimeError):
            ctx._check_duplicate_alias('alias')

        # alias already present in the table map
        ctx = StatementBinderContext()
        ctx._table_alias_map['alias'] = MagicMock()
        with self.assertRaises(RuntimeError):
            ctx._check_duplicate_alias('alias')

        # no duplicate
        ctx = StatementBinderContext()
        ctx._check_duplicate_alias('alias')
# Example #28
# 0
class StatementBinder:
    """Bind parsed statements and expressions against the catalog.

    ``bind`` is a ``functools.singledispatchmethod`` that dispatches on the
    node's type. Handlers annotate nodes in place: for example, column
    aliases and objects on TupleValueExpressions, and output objects and
    implementations on FunctionExpressions. Aliases visible to the statement
    being bound live in ``self._binder_context``. Sub-selects in FROM and
    UNION branches are bound in a fresh context, and the outer context is
    always restored afterwards.
    """

    def __init__(self, binder_context: StatementBinderContext):
        self._binder_context = binder_context
        self._catalog = CatalogManager()

    @singledispatchmethod
    def bind(self, node):
        """Bind ``node``; raise NotImplementedError for unregistered types."""
        raise NotImplementedError(f'Cannot bind {type(node)}')

    @bind.register(AbstractStatement)
    def _bind_abstract_statement(self, node: AbstractStatement):
        # Fallback for statements without a more specific handler: no-op.
        pass

    @bind.register(AbstractExpression)
    def _bind_abstract_expr(self, node: AbstractExpression):
        # Fallback for expressions: bind each child expression.
        for child in node.children:
            self.bind(child)

    @bind.register(SelectStatement)
    def _bind_select_statement(self, node: SelectStatement):
        """Bind FROM, WHERE, the target list, ORDER BY and any UNION branch.

        A target list consisting of a single ``*`` column is replaced by the
        columns from ``extend_star`` before binding.
        """
        self.bind(node.from_table)
        if node.where_clause:
            self.bind(node.where_clause)
        if node.target_list:
            # SELECT * support
            if len(node.target_list) == 1 and \
                    isinstance(node.target_list[0], TupleValueExpression) and \
                    node.target_list[0].col_name == '*':
                node.target_list = extend_star(self._binder_context)
            for expr in node.target_list:
                self.bind(expr)
        if node.orderby_list:
            # Only element 0 of each ORDER BY entry (the expression) is bound.
            for expr in node.orderby_list:
                self.bind(expr[0])
        if node.union_link:
            current_context = self._binder_context
            self._binder_context = StatementBinderContext()
            try:
                self.bind(node.union_link)
            finally:
                # Never leave the binder on the UNION branch's context.
                self._binder_context = current_context

    @bind.register(CreateMaterializedViewStatement)
    def _bind_create_mat_statement(self,
                                   node: CreateMaterializedViewStatement):
        # Only the underlying query is bound.
        self.bind(node.query)
        # Todo Verify if the number projected columns matches table

    @bind.register(LoadDataStatement)
    def _bind_load_data_statement(self, node: LoadDataStatement):
        """Bind a LOAD statement's target table and its column list.

        For video files, video metadata is created for the target table
        before it is bound. If the statement has no column list, one
        TupleValueExpression is built per column of the table's catalog
        object.

        Raises:
            RuntimeError: if the table has no catalog object after binding.
        """
        table_ref = node.table_ref
        if node.file_options['file_format'] == FileFormatType.VIDEO:
            # Create a new metadata object
            create_video_metadata(table_ref.table.table_name)

        self.bind(table_ref)

        table_ref_obj = table_ref.table.table_obj
        if table_ref_obj is None:
            # Adjacent literals are joined at compile time, so no source
            # indentation leaks into the message (as a backslash
            # continuation inside the literal would).
            error = ('{} does not exist. Create the table using '
                     'CREATE TABLE.'.format(table_ref.table.table_name))
            logger.error(error)
            raise RuntimeError(error)

        # if query had columns specified, we just copy them
        if node.column_list is not None:
            column_list = node.column_list

        # else we curate the column list from the metadata
        else:
            column_list = []
            for column in table_ref_obj.columns:
                column_list.append(
                    TupleValueExpression(
                        col_name=column.name,
                        table_alias=table_ref_obj.name.lower(),
                        col_object=column))

        # bind the columns
        for expr in column_list:
            self.bind(expr)

        node.column_list = column_list

    @bind.register(DropTableStatement)
    def _bind_drop_table_statement(self, node: DropTableStatement):
        # Bind every table reference named in the DROP statement.
        for table in node.table_refs:
            self.bind(table)

    @bind.register(TableRef)
    def _bind_tableref(self, node: TableRef):
        """Bind a FROM-clause table reference and register its alias.

        Raises:
            ValueError: if the reference is not a table, sub-select, join or
                function expression.
        """
        if node.is_table_atom():
            # Table
            self._binder_context.add_table_alias(node.alias,
                                                 node.table.table_name)
            bind_table_info(node.table)
        elif node.is_select():
            current_context = self._binder_context
            self._binder_context = StatementBinderContext()
            try:
                self.bind(node.select_statement)
            finally:
                # Restore the outer context even if the sub-query fails.
                self._binder_context = current_context
            self._binder_context.add_derived_table_alias(
                node.alias, node.select_statement.target_list)
        elif node.is_join():
            self.bind(node.join_node.left)
            self.bind(node.join_node.right)
            if node.join_node.predicate:
                self.bind(node.join_node.predicate)
        elif node.is_func_expr():
            self.bind(node.func_expr)
            self._binder_context.add_derived_table_alias(
                node.func_expr.alias, [node.func_expr])
        else:
            raise ValueError(f'Unsupported node {type(node)}')

    @bind.register(TupleValueExpression)
    def _bind_tuple_expr(self, node: TupleValueExpression):
        # Resolve the column (optionally qualified by node.table_alias) and
        # record a "<table_alias>.<lower-cased column>" alias on the node.
        table_alias, col_obj = self._binder_context.get_binded_column(
            node.col_name, node.table_alias)
        node.col_alias = '{}.{}'.format(table_alias, node.col_name.lower())
        node.col_object = col_obj

    @bind.register(FunctionExpression)
    def _bind_func_expr(self, node: FunctionExpression):
        """Resolve a UDF call: its catalog entry, outputs and implementation.

        If ``node.output`` names a single output column, only that output is
        exposed; otherwise every UDF output is. Output aliases take the form
        ``<function alias>.<lower-cased output name>``.
        """
        # bind all the children
        for child in node.children:
            self.bind(child)

        node.alias = node.alias or node.name.lower()
        udf_obj = self._catalog.get_udf_by_name(node.name)
        assert udf_obj is not None, (
            'UDF with name {} does not exist in the catalog. Please '
            'create the UDF using CREATE UDF command'.format(node.name))

        output_objs = self._catalog.get_udf_outputs(udf_obj)
        if node.output:
            matched = [obj for obj in output_objs
                       if obj.name.lower() == node.output]
            assert len(matched) == 1, (
                'Duplicate columns {} in UDF {}'.format(
                    node.output, udf_obj.name) if matched else
                'Output {} does not exist in UDF {}'.format(
                    node.output, udf_obj.name))
            # Assign fresh lists rather than appending, so re-binding a node
            # (or one with pre-populated aliases) cannot accumulate stale
            # aliases.
            node.output_col_aliases = [
                '{}.{}'.format(node.alias, obj.name.lower())
                for obj in matched
            ]
            node.output_objs = matched
        else:
            node.output_col_aliases = [
                '{}.{}'.format(node.alias, obj.name.lower())
                for obj in output_objs
            ]
            node.output_objs = output_objs

        # The result of path_to_class(...) is called; presumably this loads
        # and instantiates the UDF class (TODO confirm against path_to_class).
        node.function = path_to_class(udf_obj.impl_file_path, udf_obj.name)()
# Example #29
# 0
 def test_bind_create_mat_statement(self):
     """CREATE MATERIALIZED VIEW binds the statement's underlying query."""
     with patch.object(StatementBinder, 'bind') as bind:
         view_binder = StatementBinder(StatementBinderContext())
         view_statement = MagicMock()
         view_binder._bind_create_mat_statement(view_statement)
         bind.assert_called_with(view_statement.query)