def test_nested_top_down_rewrite(self):
    """Top-down rewrite of Project(Filter(DerivedGet(Project(Filter(Get))))).

    After rewriting, the first logical expressions of the memo groups are
    expected to form the chain Project -> QueryDerivedGet -> Get, with each
    filter predicate carried by the operator that was directly beneath it.
    """
    inner_pred = MagicMock()
    outer_pred = MagicMock()

    # Inner query: Project(Filter(Get)).
    scan = LogicalGet(MagicMock(), MagicMock(), MagicMock())
    inner_filter = LogicalFilter(inner_pred, children=[scan])
    inner_project = LogicalProject([MagicMock()], children=[inner_filter])

    # Outer query: Project(Filter(DerivedGet(<inner query>))).
    derived = LogicalQueryDerivedGet(MagicMock(), children=[inner_project])
    outer_filter = LogicalFilter(outer_pred, children=[derived])
    plan = LogicalProject([MagicMock()], children=[outer_filter])

    opt_cxt, root_grp_id = self.top_down_rewrite(plan)

    def first_logical_expr(grp_id):
        # Each level of the rewritten plan is checked through the first
        # logical expression recorded for its memo group.
        return opt_cxt.memo.groups[grp_id].logical_exprs[0]

    project_expr = first_logical_expr(root_grp_id)
    self.assertEqual(type(project_expr.opr), LogicalProject)
    self.assertEqual(len(project_expr.children), 1)

    derived_get_expr = first_logical_expr(project_expr.children[0])
    self.assertEqual(type(derived_get_expr.opr), LogicalQueryDerivedGet)
    self.assertEqual(derived_get_expr.opr.predicate, outer_pred)
    self.assertEqual(len(derived_get_expr.children), 1)

    get_expr = first_logical_expr(derived_get_expr.children[0])
    self.assertEqual(type(get_expr.opr), LogicalGet)
    self.assertEqual(get_expr.opr.predicate, inner_pred)
def test_nested_implementation(self):
    """Run top-down, bottom-up and implementation passes on a nested query.

    The root group's best expression, and the best expression of its first
    child group, are both expected to be SeqScanPlan nodes. Each carries the
    predicate of the filter that sat directly above the corresponding get
    operator in the logical plan.
    """
    inner_pred = MagicMock()
    outer_pred = MagicMock()

    scan = LogicalGet(MagicMock(), MagicMock(), MagicMock())
    inner_filter = LogicalFilter(inner_pred, children=[scan])
    inner_project = LogicalProject([MagicMock()], children=[inner_filter])
    derived = LogicalQueryDerivedGet(MagicMock(), children=[inner_project])
    outer_filter = LogicalFilter(outer_pred, children=[derived])
    plan = LogicalProject([MagicMock()], children=[outer_filter])

    # The passes run in this order; each returns the (possibly new) context
    # and root group id that the next pass consumes.
    opt_cxt, root_grp_id = self.top_down_rewrite(plan)
    opt_cxt, root_grp_id = self.bottom_up_rewrite(root_grp_id, opt_cxt)
    opt_cxt, root_grp_id = self.implement_group(root_grp_id, opt_cxt)

    def best_expr(grp_id):
        # Best expression the memo group reports under default properties.
        return opt_cxt.memo.groups[grp_id].get_best_expr(PropertyType.DEFAULT)

    outer_expr = best_expr(root_grp_id)
    self.assertEqual(type(outer_expr.opr), SeqScanPlan)
    self.assertEqual(outer_expr.opr.predicate, outer_pred)

    inner_expr = best_expr(outer_expr.children[0])
    self.assertEqual(type(inner_expr.opr), SeqScanPlan)
    self.assertEqual(inner_expr.opr.predicate, inner_pred)
def apply(self, before: LogicalProject, context: OptimizerContext):
    """Move a projection below the operator directly beneath it.

    Rewrites ``Project(Sample(child))`` into ``Sample(Project(child))``.
    The existing child node of ``before`` is reused: its children are cleared
    and replaced by a newly built ``LogicalProject`` that carries
    ``before.target_list`` and adopts that node's first child.

    Args:
        before: the matched projection; its first child is the operator the
            projection is pushed through (presumably a sample).
        context: optimizer context; not read by this rule.

    Returns:
        The re-parented child of ``before``, now the root of the subtree.
    """
    sample_opr = before.children[0]
    grandchild = sample_opr.children[0]

    pushed_down = LogicalProject(before.target_list)
    pushed_down.append_child(grandchild)

    # Re-parent in place: the sample now sits on top of the new projection.
    sample_opr.clear_children()
    sample_opr.append_child(pushed_down)
    return sample_opr
def test_simple_project_into_derived_get(self):
    """EmbedProjectIntoDerivedGet folds a projection into its derived get.

    The rule must return an operator other than the input derived get, and
    that operator must carry the projection's target list.
    """
    rule = EmbedProjectIntoDerivedGet()
    target_list = MagicMock()
    logi_derived_get = LogicalQueryDerivedGet(MagicMock())
    logi_project = LogicalProject(target_list, [logi_derived_get])

    rewrite_opr = rule.apply(logi_project, MagicMock())

    # assertIsNot reports both operands on failure, unlike assertFalse(a is b).
    self.assertIsNot(rewrite_opr, logi_derived_get)
    self.assertEqual(rewrite_opr.target_list, target_list)
def PushdownProjectThroughJoin(self):
    """Apply EmbedProjectIntoGet to Project(Join) and inspect the result.

    NOTE(review): this method name lacks the ``test`` prefix, so the default
    unittest loader does not collect or run it. It also applies
    EmbedProjectIntoGet (whose name suggests it targets LogicalGet) to a
    LogicalJoin and expects the join itself back. Confirm the intended rule
    and expectations before renaming it to enable it.
    """
    rule = EmbedProjectIntoGet()
    exprs = tuple(MagicMock() for _ in range(3))
    join_opr = LogicalJoin(MagicMock())
    project_opr = LogicalProject(list(exprs), [join_opr])

    result = rule.apply(project_opr, MagicMock())

    self.assertEqual(result, join_opr)
    self.assertEqual(result.target_list, list(exprs))
def test_simple_project_into_get(self):
    """EmbedProjectIntoGet folds a projection into the get beneath it.

    The rule must return an operator other than the input get, and that
    operator must carry the projection's target list.
    """
    rule = EmbedProjectIntoGet()
    expr1 = MagicMock()
    expr2 = MagicMock()
    expr3 = MagicMock()
    logi_get = LogicalGet(MagicMock(), MagicMock(), MagicMock())
    logi_project = LogicalProject([expr1, expr2, expr3], [logi_get])

    rewrite_opr = rule.apply(logi_project, MagicMock())

    # assertIsNot reports both operands on failure, unlike assertFalse(a is b).
    self.assertIsNot(rewrite_opr, logi_get)
    self.assertEqual(rewrite_opr.target_list, [expr1, expr2, expr3])
def test_pushdown_project_thru_sample(self):
    """PushdownProjectThroughSample rewrites Project(Sample(Get)).

    Expects the original sample node back as the new root. Its child must be
    a new LogicalProject, distinct from the input projection, that carries
    the original target list and sits directly above the untouched get.
    """
    rule = PushdownProjectThroughSample()
    target_list = MagicMock()
    constexpr = MagicMock()
    logi_get = LogicalGet(MagicMock(), MagicMock(), MagicMock())
    sample = LogicalSample(constexpr, [logi_get])
    logi_project = LogicalProject(target_list, [sample])

    rewrite_opr = rule.apply(logi_project, MagicMock())

    # Identity assertions report both operands on failure, unlike
    # assertTrue/assertFalse wrapped around an `is` expression.
    self.assertIs(rewrite_opr, sample)
    pushed_project = rewrite_opr.children[0]
    self.assertIsNot(pushed_project, logi_project)
    self.assertIsInstance(pushed_project, LogicalProject)
    self.assertIs(pushed_project.children[0], logi_get)
    self.assertEqual(pushed_project.target_list, target_list)