コード例 #1
0
ファイル: test_join.py プロジェクト: sirily11/SECYAN-GEN
    def setUp(self):
        """Create the fixture tables ``a``, ``b`` and ``c`` used by the tests.

        ``a`` and ``b`` both have a string ``name`` and an int ``id`` column;
        ``c`` additionally has a string ``address`` column.
        """

        def new_table(label, column_specs):
            # The list literals are created on every call so that the tables
            # never share a list object.
            return Table(table_name=label,
                         columns=[
                             Column(name=column_name, column_type=column_type)
                             for column_name, column_type in column_specs
                         ],
                         data_sizes=[100],
                         data_paths=[""],
                         annotations=[])

        name_and_id = (("name", TypeEnum.string), ("id", TypeEnum.int))

        self.a_table = new_table("a", name_and_id)
        self.b_table = new_table("b", name_and_id)
        self.c_table = new_table(
            "c", name_and_id + (("address", TypeEnum.string), ))
コード例 #2
0
    def test_simple_join2(self):
        """Join A, B and C on aa = ab and ec = eb and expect non-empty code."""
        join_keys = [
            JoinData(left_key="aa", right_key="ab"),
            JoinData(left_key="ec", right_key="eb"),
        ]

        def int_table(label, *column_names):
            # All columns in this test are int columns.
            return Table(table_name=label,
                         columns=[
                             Column(name=column_name,
                                    column_type=TypeEnum.int)
                             for column_name in column_names
                         ],
                         data_sizes=[100],
                         data_paths=[""],
                         annotations=[])

        table_a = int_table("A", "aa", "b", "c")
        table_b = int_table("B", "ab", "eb")
        table_c = int_table("C", "ec", "f")

        # The same list object is handed to both nodes.
        all_tables = [table_a, table_b, table_c]

        select_node = SelectNode(tables=all_tables, annotation_name="demo")
        select_node.set_identifier_list([
            Identifier(tokens=[Token(None, selected)])
            for selected in ("ec", "f")
        ])

        # Link the join node behind the select node in both directions.
        select_node.next = JoinNode(join_list=join_keys, tables=all_tables)
        select_node.next.prev = select_node

        select_node.next.merge()
        generated = select_node.next.to_code(table_a.get_root())
        self.assertTrue(len(generated) > 0)
コード例 #3
0
 def setUp(self) -> None:
     """Create the client-owned fixture table ``A`` (int ``a``, string ``b``)."""
     fixture_columns = [
         Column(name="a", column_type=TypeEnum.int),
         Column(name="b", column_type=TypeEnum.string),
     ]
     self.table_a = Table(table_name="A",
                          columns=fixture_columns,
                          owner=CharacterEnum.client,
                          data_sizes=[100],
                          data_paths=[""],
                          annotations=[])
コード例 #4
0
ファイル: test_codegendb.py プロジェクト: sirily11/SECYAN-GEN
    def test_join(self):
        """Build the plan from the first test DB plan and perform its joins.

        The ``orders``, ``lineitem`` and ``customer`` tables are then looked
        up by ``variable_table_name``; a missing table raises ``IndexError``.
        """
        tables = [Table.load_from_json(entry) for entry in TEST_CONFIG]
        plan = PostgresDBPlan.from_json(TEST_DB_PLAN[0]["Plan"], tables=tables)
        plan.perform_join(is_free_connex_table=self.is_free_connex_table)

        def first_named(wanted):
            # First table whose variable name matches.
            return [
                candidate for candidate in tables
                if candidate.variable_table_name == wanted
            ][0]

        order_table = first_named("orders")
        lineitem_table = first_named("lineitem")
        customer_table = first_named("customer")
コード例 #5
0
ファイル: test_join.py プロジェクト: sirily11/SECYAN-GEN
    def test_get_aggregate_columns2(self):
        """Join A to B on differently named keys (aa = ba) and check B's aggregate columns."""

        def int_table(label, column_names):
            # Every column in this test is an int column; the list literals
            # are created per call so the tables do not share them.
            return Table(table_name=label,
                         columns=[
                             Column(name=column_name,
                                    column_type=TypeEnum.int)
                             for column_name in column_names
                         ],
                         data_sizes=[100],
                         data_paths=[""],
                         annotations=[])

        left = int_table("A", ["aa", "b", "c"])
        right = int_table("B", ["ba", "e"])

        left.join(to_table=right, from_table_key="aa", to_table_key="ba")

        # After the join, A is expected to expose four columns.
        joined_columns = left.column_names
        self.assertEqual(len(joined_columns), 4)

        aggregate_columns = right.get_aggregate_columns()
        self.assertEqual(1, len(aggregate_columns))
        self.assertEqual(aggregate_columns[0].name, "ba")
コード例 #6
0
    def test_simple_join1(self):
        """Join A and B on aa = ba and check the first generated entry aggregates on ``aa``."""
        join_keys = [JoinData(left_key="aa", right_key="ba")]

        def int_columns(*column_names):
            return [
                Column(name=column_name, column_type=TypeEnum.int)
                for column_name in column_names
            ]

        table_a = Table(table_name="A",
                        columns=int_columns("aa", "b", "c"),
                        data_sizes=[100],
                        data_paths=[""],
                        annotations=[])

        table_b = Table(table_name="B",
                        columns=int_columns("ba", "e"),
                        data_sizes=[100],
                        data_paths=[""],
                        annotations=[])

        select_node = SelectNode(tables=[table_a, table_b],
                                 annotation_name="demo")
        select_node.set_identifier_list([
            Identifier(tokens=[Token(None, selected)])
            for selected in ("b", "c")
        ])

        # Link the join node behind the select node in both directions.
        select_node.next = JoinNode(join_list=join_keys,
                                    tables=[table_a, table_b])
        select_node.next.prev = select_node

        select_node.next.merge()
        generated = select_node.next.to_code(table_a.get_root())
        self.assertTrue('a.Aggregate({ "aa" });' in generated[0])
コード例 #7
0
    def test_equal(self):
        """Columns that list each other in ``related_columns`` compare equal."""
        seed_a = Column(name="a", column_type=TypeEnum.int)
        seed_b = Column(name="b", column_type=TypeEnum.int)

        first_table = Table(columns=[seed_a],
                            table_name="1",
                            data_sizes=[100],
                            data_paths=[""],
                            annotations=[])
        second_table = Table(columns=[seed_b],
                             table_name="2",
                             data_sizes=[100],
                             data_paths=[""],
                             annotations=[])

        # Work with the column objects as exposed by the tables rather than
        # the ones passed to the constructors.
        left = first_table.original_column_names[0]
        right = second_table.original_column_names[0]

        # Relate the two columns to each other in both directions.
        left.related_columns.append(right)
        right.related_columns.append(left)

        self.assertTrue(left == right)
コード例 #8
0
ファイル: test_join.py プロジェクト: sirily11/SECYAN-GEN
    def test_get_aggregate_columns3(self):
        """Join both B (on ``a``) and C (on ``b``) to A and check each child's aggregate columns."""

        def int_table(label, *column_names):
            # Every column in this test is an int column.
            return Table(table_name=label,
                         columns=[
                             Column(name=column_name,
                                    column_type=TypeEnum.int)
                             for column_name in column_names
                         ],
                         data_sizes=[100],
                         data_paths=[""],
                         annotations=[])

        parent = int_table("A", "a", "b")
        child_b = int_table("B", "a", "c")
        child_c = int_table("C", "b", "d")

        parent.join(child_b, 'a', 'a')
        parent.join(child_c, 'b', 'b')

        aggregates_b = child_b.get_aggregate_columns()
        self.assertEqual(1, len(aggregates_b))
        self.assertEqual(aggregates_b[0].name, 'a')

        aggregates_c = child_c.get_aggregate_columns()
        self.assertEqual(1, len(aggregates_c))
        self.assertEqual(aggregates_c[0].name, "b")
コード例 #9
0
class TestSelect(QueryTestCase):
    """Tests for ``SelectNode`` on a single client-owned table."""

    def setUp(self) -> None:
        """Create table ``A`` with an int column ``a`` and a string column ``b``."""
        self.table_a = Table(table_name="A",
                             columns=[
                                 Column(name="a", column_type=TypeEnum.int),
                                 Column(name="b", column_type=TypeEnum.string)
                             ],
                             owner=CharacterEnum.client,
                             data_sizes=[100],
                             data_paths=[""],
                             annotations=[])

    def _merged_select_node(self, sql):
        """Parse *sql*, feed its select item and table to a ``SelectNode`` and merge it.

        :param sql: a statement of the form ``select <item> from <table>``
        :return: the ``SelectNode`` after ``merge()`` has been called
        """
        tokens = sqlparse.parse(sql)[0].tokens
        select_node = SelectNode(tables=[self.table_a], annotation_name="demo")

        # Positions 2 and 6 are taken to be the select item and the table in
        # sqlparse's token list -- presumably because whitespace is tokenized
        # as well; TODO confirm against sqlparse.
        select_node.from_tables = tokens[6]
        select_node.set_identifier_list([tokens[2]])
        select_node.merge()
        return select_node

    def test_simple_select(self):
        """Generated code declares the relation, its owner and its column types."""
        select_node = SelectNode(tables=[self.table_a], annotation_name="demo")
        select_node.from_tables = [
            Identifier(tokens=[Token("int", "A")]),
        ]

        code = select_node.to_code(self.table_a.get_root())
        self.assert_content_in_arr(code, "Relation a(a_ri, a_ai);")
        self.assert_content_in_arr(code, "CLIENT")
        self.assert_content_in_arr(code, "{ Relation::INT,Relation::STRING }")

    def test_select_with_aggregation(self):
        """After merging ``sum(a)`` the table is not flagged as boolean."""
        select_node = self._merged_select_node("""select sum(a) from A""")
        self.assertFalse(select_node.tables[0].is_bool)

    def test_select_with_aggregation2(self):
        """Same as above, with the aggregate aliased via ``as re``."""
        select_node = self._merged_select_node("""select sum(a) as re from A""")
        self.assertFalse(select_node.tables[0].is_bool)
コード例 #10
0
ファイル: test_load_json.py プロジェクト: sirily11/SECYAN-GEN
    def test_simple_load_table1(self):
        """Load a table description from a dict and check name, owner and columns."""
        column_entries = [
            {"column_type": "int", "name": "a"},
            {"column_type": "int", "name": "b"},
            {"column_type": "string", "name": "c"},
        ]
        json_content = {
            "table_name": "A",
            "owner": "client",
            "data_sizes": [100],
            "data_paths": [""],
            "annotations": [],
            "columns": column_entries,
        }

        table = Table.load_from_json(json_content=json_content)
        self.assertEqual("a", table.variable_table_name)
        self.assertEqual(CharacterEnum.client, table.owner)

        # Expected (type, name) of each loaded column, in declaration order.
        expected_columns = (
            (TypeEnum.int, "a"),
            (TypeEnum.int, "b"),
            (TypeEnum.string, "c"),
        )
        for position, (expected_type, expected_name) in enumerate(
                expected_columns):
            self.assertEqual(
                table.original_column_names[position].column_type,
                expected_type)
            self.assertEqual(table.original_column_names[position].name,
                             expected_name)
コード例 #11
0
   LINEITEM
where
   c_custkey = o_custkey
   and l_orderkey = o_orderkey
   and o_orderdate >= date '1993-08-01'
   and o_orderdate < date '1993-08-01' + interval '3' month
   and l_returnflag = 'R'
 group by
   c_custkey,
   c_name,
   c_nationkey
order by
   revenue desc
limit
   20;
"""

if __name__ == '__main__':
    # Script entry point: load the table configuration, then generate
    # "db-test.cpp" from the module-level ``sql`` statement.

    # Credentials are read from the environment; the remaining connection
    # settings are fixed values.
    password = getenv('password')
    user = getenv('user')
    database = "tpch"
    host = "localhost"
    port = "5432"

    # Only reading the configuration needs the file handle, so it is closed
    # before the driver is created and code generation runs. The encoding is
    # given explicitly so that decoding the JSON text does not depend on the
    # platform default.
    with open("./table_config.json", 'r', encoding="utf-8") as f:
        tables = [Table.load_from_json(t) for t in json.load(f)]

    driver = PostgresDBDriver(password=password, user=user, database_name=database, host=host, port=port,
                              tables=tables)

    codegen = CodeGenDB(sql=sql, db_driver=driver, tables=tables)
    codegen.parse().to_file("db-test.cpp")
コード例 #12
0
ファイル: test_table.py プロジェクト: sirily11/SECYAN-GEN
    def test_table_height(self):
        """
        Heights along the join chain A - B - C are 2, 1 and 0.
        """

        def int_table(label, *column_names):
            # Every column in this test is an int column.
            return Table(table_name=label,
                         columns=[
                             Column(name=column_name,
                                    column_type=TypeEnum.int)
                             for column_name in column_names
                         ],
                         data_sizes=[100],
                         data_paths=[""],
                         annotations=[])

        top = int_table("A", "a", "b", "c")
        middle = int_table("B", "a", "e")
        bottom = int_table("C", "e", "f")

        top.join(middle, "a", "a")
        middle.join(bottom, "e", "e")

        self.assertEqual(top.get_height(), 2)
        self.assertEqual(middle.get_height(), 1)
        self.assertEqual(bottom.get_height(), 0)
コード例 #13
0
ファイル: test_join.py プロジェクト: sirily11/SECYAN-GEN
class JoinTest(unittest.TestCase):
    """
    Test join on tables
    """

    @staticmethod
    def _make_table(table_name, column_specs):
        """Build a ``Table`` from ``(column name, column type)`` pairs.

        The ``data_sizes``/``data_paths``/``annotations`` lists are created
        on every call so that no two tables share a list object.
        """
        return Table(table_name=table_name,
                     columns=[
                         Column(name=name, column_type=column_type)
                         for name, column_type in column_specs
                     ],
                     data_sizes=[100],
                     data_paths=[""],
                     annotations=[])

    @classmethod
    def _make_int_table(cls, table_name, column_names):
        """Build a ``Table`` whose columns are all of type ``TypeEnum.int``."""
        return cls._make_table(table_name,
                               [(name, TypeEnum.int) for name in column_names])

    def setUp(self):
        """Create tables ``a``, ``b`` (name, id) and ``c`` (name, id, address)."""
        name_and_id = [("name", TypeEnum.string), ("id", TypeEnum.int)]
        self.a_table = self._make_table("a", name_and_id)
        self.b_table = self._make_table("b", name_and_id)
        self.c_table = self._make_table(
            "c", name_and_id + [("address", TypeEnum.string)])

    def test_simple_join(self):
        """Joining a to b on ``id`` adds one child and yields three columns."""
        self.a_table.join(self.b_table, "id", "id")
        self.assertEqual(len(self.a_table.children), 1)
        column_names = self.a_table.column_names
        self.assertEqual(len(column_names), 3)

    def test_simple_join_2(self):
        """Joining b to c on ``id`` yields four columns drawn from the expected names."""
        self.b_table.join(self.c_table, "id", "id")
        column_names = self.b_table.column_names
        self.assertEqual(len(column_names), 4)
        expected_names = ["c.name", "b.id", "c.address", "b.name"]
        for column in column_names:
            # assertIn reports the offending name on failure.
            self.assertIn(column.name_with_table, expected_names)

    def test_simple_join_3(self):
        """Chaining a - b - c on ``id`` yields five columns drawn from the expected names."""
        self.a_table.join(self.b_table, "id", "id")
        self.b_table.join(self.c_table, "id", "id")
        column_names = self.a_table.column_names
        self.assertEqual(len(column_names), 5)
        expected_names = ["a.name", "a.id", "b.name", "c.name", "c.address"]
        for column in column_names:
            self.assertIn(column.name_with_table, expected_names)

    def test_simple_join_with_error(self):
        """Joining on the key ``abc`` raises ``RuntimeError``."""
        with self.assertRaises(RuntimeError):
            self.a_table.join(self.b_table, "id", "abc")

    def test_get_aggregate_columns(self):
        """Aggregate columns for each table of the tree C -> A -> B."""
        table_a = self._make_int_table("A", ["a", "b", "c"])
        table_b = self._make_int_table("B", ["a", "e"])
        table_c = self._make_int_table("C", ["e", "f"])

        table_a.join(table_b, "a", "a")
        table_c.join(table_a, "e", "e")

        column_names = table_a.column_names
        self.assertEqual(len(column_names), 4)

        agg = table_b.get_aggregate_columns()
        self.assertEqual(2, len(agg))
        self.assertEqual(agg[0].name, "a")
        self.assertEqual(agg[1].name, "e")

        agg = table_a.get_aggregate_columns()
        self.assertEqual(1, len(agg))
        self.assertEqual(agg[0].name, "e")

        agg = table_c.get_aggregate_columns()
        self.assertEqual(0, len(agg))

    def test_get_aggregate_columns2(self):
        """Join keys with different names (aa = ba): B aggregates on ``ba``."""
        table_a = self._make_int_table("A", ["aa", "b", "c"])
        table_b = self._make_int_table("B", ["ba", "e"])

        table_a.join(to_table=table_b, from_table_key="aa", to_table_key="ba")

        column_names = table_a.column_names
        self.assertEqual(len(column_names), 4)

        agg = table_b.get_aggregate_columns()
        self.assertEqual(1, len(agg))
        self.assertEqual(agg[0].name, "ba")

    def test_get_aggregate_columns3(self):
        """Two children of A: B aggregates on ``a`` and C on ``b``."""
        table_a = self._make_int_table("A", ["a", "b"])
        table_b = self._make_int_table("B", ["a", "c"])
        table_c = self._make_int_table("C", ["b", "d"])

        table_a.join(table_b, 'a', 'a')
        table_a.join(table_c, 'b', 'b')

        agg = table_b.get_aggregate_columns()
        self.assertEqual(1, len(agg))
        self.assertEqual(agg[0].name, 'a')

        agg = table_c.get_aggregate_columns()
        self.assertEqual(1, len(agg))
        self.assertEqual(agg[0].name, "b")
コード例 #14
0
ファイル: test_join.py プロジェクト: sirily11/SECYAN-GEN
    def test_get_aggregate_columns(self):
        """Check the aggregate columns of every table after joining A to B and C to A."""

        def int_table(label, *column_names):
            # Every column in this test is an int column.
            return Table(table_name=label,
                         columns=[
                             Column(name=column_name,
                                    column_type=TypeEnum.int)
                             for column_name in column_names
                         ],
                         data_sizes=[100],
                         data_paths=[""],
                         annotations=[])

        table_a = int_table("A", "a", "b", "c")
        table_b = int_table("B", "a", "e")
        table_c = int_table("C", "e", "f")

        table_a.join(table_b, "a", "a")
        table_c.join(table_a, "e", "e")

        # After the joins, A is expected to expose four columns.
        joined_columns = table_a.column_names
        self.assertEqual(len(joined_columns), 4)

        aggregates_b = table_b.get_aggregate_columns()
        self.assertEqual(2, len(aggregates_b))
        self.assertEqual(aggregates_b[0].name, "a")
        self.assertEqual(aggregates_b[1].name, "e")

        aggregates_a = table_a.get_aggregate_columns()
        self.assertEqual(1, len(aggregates_a))
        self.assertEqual(aggregates_a[0].name, "e")

        aggregates_c = table_c.get_aggregate_columns()
        self.assertEqual(0, len(aggregates_c))
コード例 #15
0
ファイル: JoinNode.py プロジェクト: sirily11/SECYAN-GEN
    def __to_code_util__(self,
                         root: Table,
                         from_key=None,
                         to_key=None) -> List[str]:
        """
        Generate code for the join tree below ``root``, children first.

        Every child of ``root`` is processed recursively before ``root``
        itself. A table that has a parent is rendered as a join against that
        parent; a table without a parent is rendered with the reveal step,
        aggregating on the GROUP BY identifiers if there are any and on the
        SELECT identifiers otherwise.

        :param root: current table
        :param from_key: join key. From table's column name
        :param to_key: join key. To table's column name
        :return: list of generated code
        :raises SyntaxError: if the parentless table is reached and neither a
            GROUP BY nor a SELECT is available
        """
        lines = []
        join_template = Template(self.open_template_file("join.template.j2"))

        # Post-order: emit the code of every subtree first.
        for edge in root.children:
            lines += self.__to_code_util__(edge.to_table,
                                           edge.from_table_key,
                                           edge.to_table_key)

        if root.parent:
            # Join this table with its parent; aggregate only when there is
            # at least one aggregate column.
            # NOTE: the original carries a disabled check here that raised
            # RuntimeError when parent and child have the same owner.
            join_columns = self.remove_duplicates(
                root.get_aggregate_columns())

            output = join_template.render(
                left_table=root.parent,
                right_table=root,
                aggregate=join_columns,
                left=from_key,
                right=to_key,
                should_aggregate=len(join_columns) > 0,
                should_join=True)

            lines += output.split("\n")
            return lines

        # No parent: pick the identifiers to aggregate on, preferring
        # GROUP BY over SELECT.
        group_by = self.__get_group_by__()
        select = self.__get_select__()

        grouped = bool(group_by)
        if grouped:
            identifier_source = group_by
        elif select:
            identifier_source = select
        else:
            raise SyntaxError("SQL Statement should have select statement")

        selections = [
            identifier.normalized
            for identifier in identifier_source.identifier_list
        ]

        available_columns = root.get_columns_after_aggregate()
        processed = self.__preprocess_selection__(selections=selections,
                                                  columns=available_columns)

        # Each processed selection becomes an int-typed aggregate column.
        final_columns = self.remove_duplicates([
            Column(name=selection, column_type=TypeEnum.int)
            for selection in processed
        ])

        output = join_template.render(
            left_table=root.parent,
            right_table=root,
            aggregate=final_columns,
            left=from_key,
            right=to_key,
            should_aggregate=len(final_columns) > 0,
            should_join=False,
            reveal_table=root,
            should_reveal=True,
            is_group_by=grouped)
        lines += output.split("\n")

        return lines