def visit_select(self, statement: AbstractStatement):
    """Translate a SELECT statement into logical operators on ``self._plan``.

    Operators are stacked bottom-up: the FROM clause (if any) is visited
    first, then a ``LogicalFilter`` is wrapped around the current plan when
    a WHERE clause exists, and finally a ``LogicalProject`` is wrapped
    around that when a target list exists.

    Arguments:
        statement {AbstractStatement} -- the parsed select statement
    """

    def wrap_plan(node):
        # Make the current plan the child of ``node`` and promote ``node``
        # to be the new root of the plan.
        node.append_child(self._plan)
        self._plan = node

    # FROM clause: visit_table_ref is expected to produce the logical get
    # node that the operators below are stacked on.
    source_table = statement.from_table
    if source_table is not None:
        self.visit_table_ref(source_table)

    # WHERE clause: bind the predicate expression, then wrap it in a filter.
    where_expr = statement.where_clause
    if where_expr is not None:
        bind_predicate_expr(where_expr)
        wrap_plan(LogicalFilter(where_expr))

    # Target list: bind the columns using the catalog, then wrap them in a
    # projection.
    # TODO: SELECT * is not supported yet.
    target_columns = statement.target_list
    if target_columns is not None:
        bind_columns_expr(target_columns)
        wrap_plan(LogicalProject(target_columns))
def test_nested_implementation(self):
    """Implementing a nested query yields a SeqScanPlan at both the outer
    and the inner level, each carrying the predicate of its own filter."""
    inner_pred = MagicMock()
    outer_pred = MagicMock()

    # Inner query: Project -> Filter -> Get
    inner_get = LogicalGet(MagicMock(), MagicMock())
    inner_filter = LogicalFilter(inner_pred, [inner_get])
    inner_project = LogicalProject(MagicMock(), [inner_filter])

    # Outer query over the derived table: Project -> Filter -> DerivedGet
    derived_get = LogicalQueryDerivedGet([inner_project])
    outer_filter = LogicalFilter(outer_pred, [derived_get])
    outer_project = LogicalProject(MagicMock(), [outer_filter])

    ctx, grp_id = self.top_down_rewrite(outer_project)
    ctx, grp_id = self.bottom_up_rewrite(grp_id, ctx)
    ctx, grp_id = self.implement_group(grp_id, ctx)

    def best_expr_of(group_id):
        # Cheapest expression recorded in the memo for ``group_id``.
        group = ctx.memo.groups[group_id]
        return group.get_best_expr(PropertyType.DEFAULT)

    outer_best = best_expr_of(grp_id)
    outer_opr = outer_best.opr
    self.assertEqual(type(outer_opr), SeqScanPlan)
    self.assertEqual(outer_opr.predicate, outer_pred)

    inner_best = best_expr_of(outer_best.children[0])
    inner_opr = inner_best.opr
    self.assertEqual(type(inner_opr), SeqScanPlan)
    self.assertEqual(inner_opr.predicate, inner_pred)
def test_should_visit_select_if_nested_query(self, mock_p, mock_c, mock_d):
    """A nested SELECT in the FROM clause becomes a LogicalQueryDerivedGet
    whose child is the inner query's Project -> Filter -> Get chain.

    All three patched callables return the same MagicMock ``m``. The same
    ``m`` is also the second argument of the expected LogicalGet.
    """
    m = MagicMock()
    mock_p.return_value = mock_c.return_value = mock_d.return_value = m
    stmt = Parser().parse(""" SELECT id FROM (SELECT data, id FROM video \
        WHERE data > 2) WHERE id>3;""")[0]
    converter = StatementToPlanConvertor()
    actual_plan = converter.visit(stmt)

    def build_nodes():
        # Fresh, unlinked operator nodes ordered root-first. A new list is
        # built on every call so the expected tree and the deliberately
        # wrong tree never share (and mutate) the same node objects.
        return [
            LogicalProject([TupleValueExpression('id')]),
            LogicalFilter(
                ComparisonExpression(ExpressionType.COMPARE_GREATER,
                                     TupleValueExpression('id'),
                                     ConstantValueExpression(3))),
            LogicalQueryDerivedGet(),
            LogicalProject(
                [TupleValueExpression('data'),
                 TupleValueExpression('id')]),
            LogicalFilter(
                ComparisonExpression(ExpressionType.COMPARE_GREATER,
                                     TupleValueExpression('data'),
                                     ConstantValueExpression(2))),
            LogicalGet(TableRef(TableInfo('video')), m),
        ]

    # Correct shape: each node is the single child of the one before it,
    # linked from the leaf upward.
    expected_plan = None
    for plan in reversed(build_nodes()):
        if expected_plan is not None:
            plan.append_child(expected_plan)
        expected_plan = plan
    self.assertEqual(expected_plan, actual_plan)

    # Wrong shape: the same operators, all hung directly off the root.
    wrong_root, *rest = build_nodes()
    for plan in rest:
        wrong_root.append_child(plan)
    self.assertNotEqual(wrong_root, actual_plan)
def test_should_not_call_insert_generator_for_other_types(
        self, mock_class):
    """Building plans for non-insert operators never touches the (patched)
    insert generator class."""
    # Factories keep each node's construction interleaved with its build,
    # exactly as in a sequence of separate calls.
    node_factories = (
        lambda: None,
        lambda: LogicalFilter(None),
        lambda: LogicalGet(None, 1),
        lambda: LogicalProject([]),
    )
    for make_node in node_factories:
        PlanGenerator().build(Operator(make_node()))
    mock_class.assert_not_called()
def test_should_return_correct_plan_tree_for_input_logical_tree(self):
    """ScanGenerator turns Project -> Filter -> Get into a SeqScanPlan that
    holds the filter's predicate and the projected columns. Its first child
    is a storage plan for the video."""
    logical_plan = LogicalProject(
        [1, 2], [LogicalFilter("a", [LogicalGet("video", 1)])])
    plan = ScanGenerator().build(logical_plan)

    # assertIsInstance reports the actual type on failure, whereas
    # assertTrue(isinstance(...)) only reports "False is not true".
    self.assertIsInstance(plan, SeqScanPlan)
    self.assertEqual("a", plan.predicate)
    self.assertEqual([1, 2], plan.columns)

    storage_plan = plan.children[0]
    self.assertEqual(PlanNodeType.STORAGE_PLAN, storage_plan.node_type)
    self.assertEqual(1, storage_plan.video)
def test_simple_project_into_derived_get(self):
    """EmbedProjectIntoDerivedGet rewrites Project -> DerivedGet into the
    derived-get node, now carrying the projection's target list."""
    rule = EmbedProjectIntoDerivedGet()
    columns = MagicMock()
    derived_get = LogicalQueryDerivedGet([Dummy()])
    project = LogicalProject(columns, [derived_get])

    rewritten = rule.apply(project, MagicMock())

    self.assertEqual(rewritten, derived_get)
    self.assertEqual(rewritten.target_list, columns)
def test_pushdown_project_thru_sample(self):
    """PushdownProjectThroughSample returns the sample node as the new root.
    Its first child then carries the projection's target list."""
    rule = PushdownProjectThroughSample()
    columns = MagicMock()
    sample_expr = MagicMock()
    get_node = LogicalGet(MagicMock(), MagicMock(), [Dummy()])
    sample_node = LogicalSample(sample_expr, [get_node])
    project = LogicalProject(columns, [sample_node])

    rewritten = rule.apply(project, MagicMock())

    self.assertEqual(rewritten, sample_node)
    self.assertEqual(rewritten.children[0].target_list, columns)
def test_simple_project_into_get(self):
    """EmbedProjectIntoGet rewrites Project -> Get into the get node, which
    then carries the projection's expressions as its target list."""
    rule = EmbedProjectIntoGet()
    exprs = [MagicMock() for _ in range(3)]
    get_node = LogicalGet(MagicMock(), MagicMock())
    # Hand the operator its own copy so ``exprs`` stays an independent
    # expected value.
    project = LogicalProject(list(exprs), [get_node])

    rewritten = rule.apply(project, MagicMock())

    self.assertEqual(rewritten, get_node)
    self.assertEqual(rewritten.target_list, exprs)
def test_should_visit_select_union_if_union_query(self, mock_p, mock_c, mock_d):
    """UNION ALL of two SELECTs becomes LogicalUnion(True) over the two
    branch plans, with the right branch appended before the left."""
    m = MagicMock()
    mock_p.return_value = mock_c.return_value = mock_d.return_value = m
    stmt = Parser().parse(""" SELECT id FROM video WHERE id>3 UNION ALL SELECT id FROM video WHERE id<=3;""")[0]
    actual_plan = StatementToPlanConvertor().visit(stmt)

    def chain(plans):
        # Link plans[i] -> plans[i + 1] starting from the leaf; return root.
        root = None
        for plan in reversed(plans):
            if root:
                plan.append_child(root)
            root = plan
        return root

    def branch(compare_type):
        # Plan for: SELECT id FROM video WHERE id <compare_type> 3
        return chain([
            LogicalProject([TupleValueExpression('id')]),
            LogicalFilter(
                ComparisonExpression(compare_type,
                                     TupleValueExpression('id'),
                                     ConstantValueExpression(3))),
            LogicalGet(TableRef(TableInfo('video')), m),
        ])

    expect_left_plan = branch(ExpressionType.COMPARE_GREATER)
    expect_right_plan = branch(ExpressionType.COMPARE_LEQ)

    expected_plan = LogicalUnion(True)
    expected_plan.append_child(expect_right_plan)
    expected_plan.append_child(expect_left_plan)
    self.assertEqual(expected_plan, actual_plan)
def test_nested_bottom_up_rewrite(self):
    """After the top-down and bottom-up rewrites, the root group's first
    expression is a LogicalQueryDerivedGet with the outer predicate. Its
    single child is a LogicalGet with the inner predicate."""
    inner_pred = MagicMock()
    outer_pred = MagicMock()

    inner_get = LogicalGet(MagicMock(), MagicMock())
    inner_filter = LogicalFilter(inner_pred, [inner_get])
    inner_project = LogicalProject(MagicMock(), [inner_filter])
    derived_get = LogicalQueryDerivedGet([inner_project])
    outer_filter = LogicalFilter(outer_pred, [derived_get])
    outer_project = LogicalProject(MagicMock(), [outer_filter])

    ctx, grp_id = self.top_down_rewrite(outer_project)
    ctx, grp_id = self.bottom_up_rewrite(grp_id, ctx)

    root_opr = ctx.memo.groups[grp_id].logical_exprs[0].opr
    self.assertEqual(type(root_opr), LogicalQueryDerivedGet)
    self.assertEqual(len(root_opr.children), 1)
    self.assertEqual(root_opr.predicate, outer_pred)

    leaf_opr = root_opr.children[0]
    self.assertEqual(type(leaf_opr), LogicalGet)
    self.assertEqual(leaf_opr.predicate, inner_pred)
def test_nested_top_down_rewrite(self):
    """The top-down rewrite keeps the Project -> Filter -> DerivedGet ->
    Project chain but folds the inner Filter -> Get into a LogicalGet that
    carries the inner predicate."""
    inner_pred = MagicMock()
    outer_pred = MagicMock()

    inner_get = LogicalGet(MagicMock(), MagicMock())
    inner_filter = LogicalFilter(inner_pred, [inner_get])
    inner_project = LogicalProject(MagicMock(), [inner_filter])
    derived_get = LogicalQueryDerivedGet([inner_project])
    outer_filter = LogicalFilter(outer_pred, [derived_get])
    outer_project = LogicalProject(MagicMock(), [outer_filter])

    ctx, grp_id = self.top_down_rewrite(outer_project)
    # The rewrite inserts the new expression into a new group, so the
    # root group's first logical expression is inspected.
    node = ctx.memo.groups[grp_id].logical_exprs[0].opr

    # Each pair: expected operator type at this depth, and the predicate it
    # must carry (None = predicate not checked). Every node on this path
    # has exactly one child.
    expected_chain = (
        (LogicalProject, None),
        (LogicalFilter, outer_pred),
        (LogicalQueryDerivedGet, None),
        (LogicalProject, None),
    )
    for expected_type, expected_pred in expected_chain:
        self.assertEqual(type(node), expected_type)
        self.assertEqual(len(node.children), 1)
        if expected_pred is not None:
            self.assertEqual(node.predicate, expected_pred)
        node = node.children[0]

    self.assertEqual(type(node), LogicalGet)
    self.assertEqual(node.predicate, inner_pred)
def test_visit_select_orderby(self, mock_p, mock_c, mock_d):
    """ORDER BY places a LogicalOrderBy above the Project -> Filter -> Get
    chain. Columns without an explicit direction (``data``) sort ASC.

    All three patched callables return the same MagicMock ``m``. The same
    ``m`` is also the second argument of the expected LogicalGet.
    """
    m = MagicMock()
    mock_p.return_value = mock_c.return_value = mock_d.return_value = m
    stmt = Parser().parse(""" SELECT data, id FROM video \
        WHERE data > 2 ORDER BY data, id DESC;""")[0]
    converter = StatementToPlanConvertor()
    actual_plan = converter.visit(stmt)

    def build_nodes():
        # Fresh, unlinked operator nodes ordered root-first. A new list is
        # built on every call so the expected tree and the deliberately
        # wrong tree never share (and mutate) the same node objects.
        return [
            LogicalOrderBy([
                (TupleValueExpression('data'), ParserOrderBySortType.ASC),
                (TupleValueExpression('id'), ParserOrderBySortType.DESC),
            ]),
            LogicalProject(
                [TupleValueExpression('data'),
                 TupleValueExpression('id')]),
            LogicalFilter(
                ComparisonExpression(ExpressionType.COMPARE_GREATER,
                                     TupleValueExpression('data'),
                                     ConstantValueExpression(2))),
            LogicalGet(TableRef(TableInfo('video')), m),
        ]

    # Correct shape: each node is the single child of the one before it,
    # linked from the leaf upward.
    expected_plan = None
    for plan in reversed(build_nodes()):
        if expected_plan is not None:
            plan.append_child(expected_plan)
        expected_plan = plan
    self.assertEqual(expected_plan, actual_plan)

    # Wrong shape: the same operators, all hung directly off the root.
    wrong_root, *rest = build_nodes()
    for plan in rest:
        wrong_root.append_child(plan)
    self.assertNotEqual(wrong_root, actual_plan)
def test_visit_select_sample(self, mock_p, mock_c, mock_d):
    """SAMPLE + LIMIT produce Limit -> Project -> Filter -> Sample -> Get.
    The sample constant is also recorded on the TableRef of the get."""
    m = MagicMock()
    mock_p.return_value = mock_c.return_value = mock_d.return_value = m
    stmt = Parser().parse(""" SELECT data, id FROM video SAMPLE 2 \
        WHERE id > 2 LIMIT 3;""")[0]
    actual_plan = StatementToPlanConvertor().visit(stmt)

    root_first = [
        LogicalLimit(ConstantValueExpression(3)),
        LogicalProject(
            [TupleValueExpression('data'), TupleValueExpression('id')]),
        LogicalFilter(
            ComparisonExpression(ExpressionType.COMPARE_GREATER,
                                 TupleValueExpression('id'),
                                 ConstantValueExpression(2))),
        LogicalSample(ConstantValueExpression(2)),
        LogicalGet(
            TableRef(TableInfo('video'), ConstantValueExpression(2)), m),
    ]

    # Link each node under its predecessor, starting from the leaf.
    expected_plan = None
    for node in reversed(root_first):
        if expected_plan:
            node.append_child(expected_plan)
        expected_plan = node

    self.assertEqual(expected_plan, actual_plan)
def test_should_return_use_scan_generator_for_logical_project(self, mock_class):
    """PlanGenerator hands a LogicalProject to the build method of the
    (patched) scan generator instance."""
    project = LogicalProject([])
    PlanGenerator().build(project)
    generator_instance = mock_class.return_value
    generator_instance.build.assert_called_with(project)
def _visit_projection(self, select_columns):
    """Bind ``select_columns`` and put a LogicalProject on top of the plan.

    The current ``self._plan`` becomes the projection's child, and the
    projection replaces it as ``self._plan``.
    """
    # Resolve the column expressions using the catalog column map.
    column_map = self._column_map
    bind_columns_expr(select_columns, column_map)

    project_node = LogicalProject(select_columns)
    previous_root = self._plan
    project_node.append_child(previous_root)
    self._plan = project_node