def test_valid_database(self):
        """
        Validate a well-formed two-table database and expect no issues.

        Guards against false positives: table2 declares a composite foreign key
        (col3, col4) onto table1's composite primary key (col1, col2).
        """
        good_db = Database(database_name="good_db")

        parent = Table(
            table_name="table1",
            primary_key=["col1", "col2"],
            shard_key=ShardKey(shard_keys="col1", number_shards=128),
        )
        for col_name in ("col1", "col2"):
            parent.add_column(Column(column_name=col_name, column_type="INT"))
        good_db.add_table(parent)

        child = Table(
            table_name="table2",
            primary_key=["col3", "col4"],
            shard_key=ShardKey(shard_keys="col3", number_shards=128),
        )
        # The foreign key is declared before the child's columns are added.
        child.add_foreign_key(
            from_keys=["col3", "col4"],
            to_table="table1",
            to_keys=["col1", "col2"],
        )
        for col_name in ("col3", "col4"):
            child.add_column(Column(column_name=col_name, column_type="INT"))
        good_db.add_table(child)

        validation = DatabaseValidator(good_db).validate()
        self.assertTrue(validation.is_valid)
        self.assertEqual([], validation.issues)
Example #2
0
    def get_complex_db():
        """
        Build a two-table test database with primary, shard and foreign keys
        plus a generic relationship.
        :return: Database "database2" containing table1 and table2.
        :rtype: Database
        """
        database = Database(database_name="database2")

        first = Table(table_name="table1",
                      primary_key="col1",
                      shard_key=ShardKey("col1", 128))
        for col_name, col_type in (("col1", "INT"),
                                   ("Col2", "DOUBLE"),
                                   ("COL3", "FLOAT")):
            first.add_column(Column(column_name=col_name, column_type=col_type))
        database.add_table(first)

        second = Table(table_name="table2",
                       primary_key=["col4", "Col5"],
                       shard_key=ShardKey(["col4", "Col5"], 96))
        for col_name, col_type in (("col4", "VARCHAR(0)"),
                                   ("Col5", "DATE"),
                                   ("COL6", "BOOL")):
            second.add_column(Column(column_name=col_name, column_type=col_type))
        database.add_table(second)

        # Keys and relationships are attached after both tables are in the database.
        second.add_foreign_key(from_keys="Col5",
                               to_table="table1",
                               to_keys="COL3")
        # NOTE(review): the quoting in this condition looks uneven ("table2."COL6");
        # probably meant "table2"."COL6".  Kept verbatim in case other tests
        # compare against the exact text.
        first.add_relationship(
            to_table="table2",
            conditions='("table1"."col1" == "table2."COL6")')

        return database
Example #3
0
    def test_create_excel(self):
        """
        Test writing to Excel.

        Only checks that the workbook can be written without raising; checks
        should be added for the validity of the output.
        """
        database = Database(database_name="xdb")

        table = Table(table_name="table1",
                      schema_name="s1",
                      primary_key="column_1",
                      shard_key=ShardKey("column_1", 128))
        table.add_column(Column(column_name="column_1", column_type="INT"))
        table.add_column(Column(column_name="column_2", column_type="DOUBLE"))
        table.add_column(Column(column_name="column_3", column_type="FLOAT"))
        database.add_table(table)

        table = Table(table_name="table2",
                      schema_name="s1",
                      primary_key="column_1")
        table.add_column(Column(column_name="column_1", column_type="INT"))
        table.add_column(Column(column_name="column_2",
                                column_type="DATETIME"))
        table.add_column(Column(column_name="column_3", column_type="BOOL"))
        table.add_column(Column(column_name="column_4", column_type="DOUBLE"))
        # The foreign key must reference the table defined above ("table1"),
        # the same target the relationship below uses.
        table.add_foreign_key(from_keys="column_1",
                              to_table="table1",
                              to_keys="column_1")
        table.add_relationship(to_table="table1",
                               conditions="table2.column_4 = table1.column_2")
        database.add_table(table)

        writer = XLSWriter()
        writer.write_database(database, "test_excel")
    def test_change_column(self):
        """Tests detecting a column whose type differs (INT in db1, FLOAT in db2)."""
        dc = DDLCompare()

        db1 = Database(database_name="database1")
        db2 = Database(database_name="database2")

        t1 = Table(table_name="table1")
        t1.add_column(column=Column(column_name="column1", column_type="INT"))
        db1.add_table(t1)

        t2 = Table(table_name="table1")
        t2.add_column(column=Column(column_name="column1", column_type="FLOAT"))
        db2.add_table(t2)

        diff1, diff2 = dc.compare_databases(db1=db1, db2=db2)

        # Compare the type of the difference itself; `type(x is T)` is always truthy.
        self.assertIs(type(diff1[0]), ColumnModifiedDifference)
        self.assertEqual(diff1[0].table_name, "table1")
        self.assertEqual(diff1[0].column.column_name, "column1")
        self.assertEqual(diff1[0].column.column_type, "FLOAT")

        self.assertIs(type(diff2[0]), ColumnModifiedDifference)
        self.assertEqual(diff2[0].table_name, "table1")
        self.assertEqual(diff2[0].column.column_name, "column1")
        self.assertEqual(diff2[0].column.column_type, "INT")
Example #5
0
    def _read_tables_from_workbook(self):
        """
        Reads the databases and tables from the "Tables" sheet.  These are used to populate from the remaining sheets.

        For each data row, a Database is looked up (or created) in ``self.databases``
        and a Table with the row's schema, primary key and shard key is added to it.
        """

        # "Tables":        ["Database", "Schema", "Table", "Updated", "Update Type", "# Rows", "# Columns",
        #                   "Primary Key", "Shard Key", "# Shards", "RLS Column"],
        table_sheet = self.workbook.sheet_by_name("Tables")
        indices = self.indices["Tables"]

        # Row 0 is skipped (presumably the header row).
        for row_count in range(1, table_sheet.nrows):
            row = table_sheet.row_values(rowx=row_count, start_colx=0)
            table_name = row[indices["Table"]]

            database_name = row[indices["Database"]]
            database = self.databases.get(database_name, None)
            if database is None:
                database = Database(database_name=database_name)
                self.databases[database_name] = database

            pk = row[indices["Primary Key"]].strip()
            if pk == "":
                pk = None
            else:
                pk = [x.strip() for x in pk.split(",")]

            sk_name = row[indices["Shard Key"]].strip()
            sk_nbr_shards = row[indices["# Shards"]]
            has_sk_name = sk_name != ""
            has_sk_nbr = sk_nbr_shards != ""

            if has_sk_name != has_sk_nbr:
                eprint(
                    "ERROR:  %s need to provide both a shard key name and number of shards."
                    % table_name
                )

            # NOTE: numeric cells may be read back as floats (e.g. 128.0);
            # normalize whole numbers to int so the shard count stays integral.
            if isinstance(sk_nbr_shards, float) and sk_nbr_shards.is_integer():
                sk_nbr_shards = int(sk_nbr_shards)

            shard_key = None
            if has_sk_name and has_sk_nbr:
                shard_key = ShardKey(
                    shard_keys=[x.strip() for x in sk_name.split(",")],
                    number_shards=sk_nbr_shards,
                )

            table = Table(
                table_name=table_name,
                schema_name=row[indices["Schema"]],
                primary_key=pk,
                shard_key=shard_key,  # pass the parsed shard key through to the table
            )
            database.add_table(table)
Example #6
0
    def get_simple_db():
        """
        Build a one-table test database.

        The single table, "table1", has three columns whose names use mixed
        case: col1 (INT), Col2 (DOUBLE) and COL3 (FLOAT).
        :return: Database "database1" containing table1.
        :rtype: Database
        """
        db = Database(database_name="database1")
        only_table = Table(table_name="table1")
        column_specs = (("col1", "INT"), ("Col2", "DOUBLE"), ("COL3", "FLOAT"))
        for col_name, col_type in column_specs:
            only_table.add_column(Column(column_name=col_name, column_type=col_type))
        db.add_table(only_table)
        return db
    def test_new_and_drop_table(self):
        """A table present only in db1 is reported as dropped in diff1 and created in diff2."""
        comparer = DDLCompare()

        db1 = Database(database_name="database1")
        db2 = Database(database_name="database2")
        db1.add_table(Table(table_name="table_from_1"))

        diff1, diff2 = comparer.compare_databases(db1=db1, db2=db2)

        dropped = diff1[0]
        self.assertIs(type(dropped), TableDroppedDifference)
        self.assertEqual(dropped.table_name, "table_from_1")

        created = diff2[0]
        self.assertIs(type(created), TableCreatedDifference)
        self.assertEqual(created.table_name, "table_from_1")
    def test_add_and_drop_fk(self):
        """
        Tests adding / dropping a primary key on a table.

        Despite the method name, this exercises primary key differences:
        table1 has a primary key in db1 and none in db2.
        """
        dc = DDLCompare()

        db1 = Database(database_name="database1")
        db2 = Database(database_name="database2")

        t1 = Table(table_name="table1", primary_key="column1")
        db1.add_table(t1)

        t2 = Table(table_name="table1")
        db2.add_table(t2)

        diff1, diff2 = dc.compare_databases(db1=db1, db2=db2)

        # Compare the type of the difference itself; `type(x is T)` is always truthy.
        self.assertIs(type(diff1[0]), PrimaryKeyDroppedDifference)
        self.assertEqual(diff1[0].table_name, "table1")

        self.assertIs(type(diff2[0]), PrimaryKeyAddedDifference)
        self.assertEqual(diff2[0].table_name, "table1")
    def test_add_and_drop_rel(self):
        """Tests adding / dropping a generic relationship: table1 has one in db1 and none in db2."""
        dc = DDLCompare()

        db1 = Database(database_name="database1")
        db2 = Database(database_name="database2")

        t1 = Table(table_name="table1")
        t1.add_relationship(relationship=GenericRelationship(from_table="table_1", to_table="table_2",
                                                             conditions="table1.col1 = table2.col2"))
        db1.add_table(t1)

        t2 = Table(table_name="table1")
        db2.add_table(t2)

        diff1, diff2 = dc.compare_databases(db1=db1, db2=db2)

        # Compare the type of the difference itself; `type(x is T)` is always truthy.
        self.assertIs(type(diff1[0]), GenericRelationshipDroppedDifference)
        self.assertEqual(diff1[0].table_name, "table1")

        self.assertIs(type(diff2[0]), GenericRelationshipAddedDifference)
        self.assertEqual(diff2[0].table_name, "table1")
    def test_with_csvfile(self):
        """Test the tsload writer: the generated command targets database "xdb" and schema "s1"."""
        # todo Create the csv file.
        database = Database(database_name="xdb")

        table = Table(
            table_name="table1",
            schema_name="s1",
            primary_key="column_1",
            shard_key=ShardKey("column_1", 128),
        )
        table.add_column(Column(column_name="column_1", column_type="INT"))
        table.add_column(Column(column_name="column_2", column_type="DOUBLE"))
        table.add_column(Column(column_name="column_3", column_type="FLOAT"))
        # Each column needs a distinct name; this one was a duplicate "column_3".
        table.add_column(Column(column_name="column_4", column_type="DATE"))
        database.add_table(table)

        table = Table(
            table_name="table2",
            schema_name="s1",
            primary_key="column_1",
            shard_key=ShardKey("column_1", 128),
        )
        table.add_column(Column(column_name="column_1", column_type="INT"))
        table.add_column(Column(column_name="column_2", column_type="FLOAT"))
        table.add_column(Column(column_name="column_3", column_type="DOUBLE"))
        database.add_table(table)

        table = Table(
            table_name="table3",
            schema_name="s1",
            primary_key="column_1",
            shard_key=ShardKey("column_1", 128),
        )
        table.add_column(Column(column_name="column_1", column_type="INT"))
        table.add_column(Column(column_name="column_2", column_type="FLOAT"))
        table.add_column(Column(column_name="column_3", column_type="VARCHAR"))
        database.add_table(table)

        tsload_writer = TsloadWriter()
        tsload_writer.write_tsloadcommand(database, "tsloadwriter_test")
        with open("tsloadwriter_test", "r") as infile:
            line = infile.readline()

        self.assertTrue(line.startswith("tsload "))
        # assertIn fails with a readable message; str.index would raise ValueError.
        self.assertIn('--target_database "xdb"', line)
        self.assertIn('--target_schema "s1"', line)
 def create_test_database():
     """
     Return a minimal test Database named "database1".

     It holds two tables: "table1" in schema "schema1", and "table2" created
     without an explicit schema.
     """
     db = Database("database1")
     table_specs = (
         {"table_name": "table1", "schema_name": "schema1"},
         {"table_name": "table2"},
     )
     for spec in table_specs:
         db.add_table(Table(**spec))
     return db
Example #12
0
class DDLParser(object):
    """
    Parses DDL from various formats and builds a Database model that can be used for writing data.
    The following assumptions are made about the DDL being read:
    * The words CREATE TABLE occur together on a single line, not split across lines.
    * CREATE TABLE statements will not occur inside of a comment block.
    * Delimiters, such as commas, will not be part of the table or column name.
    * Comment characters, such as #, --, or /* */ will not be part of a column name.
    * CREATE TABLE will have (....) with no embedded, unbalanced parentheses.
    """

    # TODO: capture primary keys, foreign keys, and relationships.

    def __init__(
        self,
        database_name,
        schema_name=DatamodelConstants.DEFAULT_SCHEMA,
        parse_keys=False,
    ):
        """
        Creates a new DDL parser.
        :param database_name: Name of the database to create.
        :type database_name: str
        :param schema_name: Name of the schema if not using the default.
        :type schema_name: str
        :param parse_keys: If true, the parser will attempt to parse keys as well.
        :type parse_keys: bool
        """
        self.schema_name = schema_name
        self.database = Database(database_name=database_name)
        self.parse_keys = parse_keys

    def parse_ddl(self, filename):
        """
        Parses DDL from a file (or standard input) and returns a populated Database.
        :param filename: Name of the file to read from, or None to read from stdin.
        :type filename: str
        :return: A Database object.
        :rtype: Database
        """
        # Read the entire input into memory so it can be passed over multiple times.
        if filename is None:
            # sys.stdin is already an open stream; it is not a path for open().
            input_ddl = sys.stdin.readlines()
        else:
            with open(filename, "r") as stream:
                input_ddl = stream.readlines()

        self.parse_tables(input_ddl)
        if self.parse_keys:
            self.parse_primary_keys(input_ddl)

        return self.database

    def parse_tables(self, input_ddl):
        """
        Parses the input DDL to convert to a database model.
        :param input_ddl: The DDL to convert.
        :type input_ddl: list of str
        """
        creating = False
        buff = ""
        for line in input_ddl:
            l = self.clean_line(line)

            if not creating:  # looking for CREATE TABLE statement.
                if l.lower().find("create table") >= 0:
                    creating = True
                    buff = l
                    if self.is_complete_create(buff):
                        self.parse_create_table(buff)
                        buff = ""
                        creating = False
            else:  # looking for the end of a create table.
                # Lines are concatenated without a separator.
                buff += l
                if self.is_complete_create(buff):
                    self.parse_create_table(buff)
                    buff = ""
                    creating = False

    @staticmethod
    def is_complete_create(buff):
        """
        Returns true if the buffer has at least one "(" and the number of open and close parentheses match.
        :param buff: The buffer being read.
        :type buff: str
        :return: True when the statement in the buffer is complete.
        :rtype: bool
        """
        nbr_open = buff.count("(")
        nbr_close = buff.count(")")
        return nbr_open > 0 and nbr_open == nbr_close

    def parse_create_table(self, buff):
        """
        Parses a create table statement and adds the resulting Table to the database.
        :param buff: The buffer read in.
        :type buff: str
        """
        buff = buff.replace("[", '"').replace(
            "]", '"'
        )  # for SQL Server quotes
        table_name = self.get_table_name(buff)
        columns = self.get_columns(buff)

        table = Table(table_name=table_name, schema_name=self.schema_name)
        table.add_columns(columns)
        self.database.add_table(table)

    def get_table_name(self, buff):
        """
        Gets the table name from the buffer.
        :param buff: The line with the create details.
        :type buff: str
        :return: The name of the table.
        :rtype: str
        """
        # The table name (and maybe a schema) are before the opening (
        tn = buff[0:buff.find("(")].rstrip()
        # split on spaces and assume last one is table name (and maybe schema)
        tn = tn.split(" ")[-1]
        tn = tn.split(".")[-1]
        tn = self.strip_quotes(tn)
        return tn

    @staticmethod
    def strip_quotes(line):
        """
        Strips off any quotes in the given line.
        :param line: The line to strip quotes from.
        :type line: str
        :return: The line without single, double, or back quotes.
        :rtype: str
        """
        return line.replace("'", "").replace("`", "").replace('"', "")

    def get_columns(self, buff):
        """
        Get the columns from the table statement.
        :param buff: The buffer with the create details.
        :type buff: str
        :return: A list of Columns
        :rtype: list
        """
        # The fields are between the first ( and the last ).
        columns = []
        buff = buff[buff.find("(") + 1:buff.rfind(")")].strip()

        # Split on commas that are not inside parentheses, so e.g. the comma in
        # DECIMAL(10,2) does not start a new field.  Only one level of
        # parentheses is tracked.
        field_buff = ""
        open_paren = False
        raw_fields = []

        for c in buff:
            if open_paren:
                field_buff += c
                if c == ")":
                    open_paren = False
            elif c == "(":
                field_buff += c
                open_paren = True
            elif c == ",":
                raw_fields.append(field_buff)
                field_buff = ""
            else:
                field_buff += c

        if field_buff != "":
            raw_fields.append(field_buff)

        for rf in raw_fields:
            # Fields that follow ", " start with a space; strip it so the name is
            # parsed correctly.  Skip empty fields (e.g. from a trailing comma).
            rf = rf.strip()
            if not rf:
                continue
            rfl = rf.lower()
            # ignore key declarations.
            if "key " in rfl:
                continue

            had_quote = rf[0] in "\"'`"  # should be a quote or letter
            if had_quote:
                name = rf[1:rf.find(rf[0], 1)]
            else:
                space_idx = rf.find(" ")
                # A field with no space is just a name (no type).
                name = rf if space_idx == -1 else rf[:space_idx]

            # The type comes after the name and goes up to the first of a
            #   space, close paren, or comma.  Assuming no space in type.
            start_idx = len(name) + (
                3 if had_quote else 1
            )  # two quotes (if quoted) plus one separating space
            if rfl.find(")") > 0:  # type with ()
                data_type = rf[start_idx:rf.find(")") + 1]
            else:
                # either next space or comma.
                space_end_idx = rf.find(" ", start_idx)
                comma_end_idx = rf.find(",", start_idx)
                if space_end_idx == -1:  # not found
                    if comma_end_idx == -1:  # neither found
                        end_idx = len(rf)  # end of line
                    else:
                        end_idx = comma_end_idx
                elif comma_end_idx == -1:
                    end_idx = space_end_idx
                else:
                    end_idx = min(space_end_idx, comma_end_idx)
                data_type = rf[start_idx:end_idx]

            columns.append(
                Column(
                    column_name=name, column_type=self.convert_type(data_type)
                )
            )

        return columns

    @staticmethod
    def convert_type(data_type):
        """
        Converts data types from other databases to ThoughtSpot types.
        Unrecognized types map to "UNKNOWN".
        :param data_type:  The datatype to convert.
        :type data_type: str
        :return: A ThoughtSpot data type.
        :rtype: str
        """
        if ")" in data_type:
            t = data_type[0:data_type.find(")") + 1]
        elif " " in data_type:
            t = data_type[0:data_type.find(" ") + 1]
        else:
            t = data_type

        t = t.lower()

        # Order matters: substring checks run top to bottom, e.g. "datetime"
        # and "timestamp" must be tested before "time" and "date".
        if "int" in t:
            new_t = "BIGINT"
        elif "rowversion" in t:  # MS type
            new_t = "INT"
        elif "uniqueidentifier" in t:  # SQL Server GUID type
            new_t = "VARCHAR(0)"
        elif "serial" in t:  # serial index, Oracle and others
            new_t = "INT"
        elif "bit" in t:
            new_t = "BOOL"
        elif "blob" in t or "binary" in t:
            new_t = "UNKNOWN"
        elif "number" in t:  # support for NUMBER(1), NUMBER(1,1)
            if ")" in t:
                numsize = t[t.find("(") + 1:t.find(")")]
                if "," in numsize:
                    first_num, second_num = numsize.split(",")
                    if second_num.strip() == "0":
                        if first_num == "*" or int(
                            first_num
                        ) > 9:  # Support Oracle Number(*,n)
                            new_t = "BIGINT"
                        else:
                            new_t = "INT"
                    else:
                        new_t = "DOUBLE"
                else:
                    new_t = "INT"
            else:
                new_t = "BIGINT"
        elif "decimal" in t or "numeric" in t or "float" in t or "double" in t or "money" in t or "real" in t:
            new_t = "DOUBLE"
        elif "datetime" in t:
            new_t = "DATETIME"
        elif "timestamp" in t:
            new_t = "DATETIME"
        elif "time" in t:
            new_t = "TIME"
        elif "date" in t:
            new_t = "DATE"
        elif "bool" in t:
            new_t = "BOOL"
        elif "text" in t:
            new_t = "VARCHAR(0)"
        elif "long" in t:  # Oracle variable type
            new_t = "VARCHAR(0)"
        elif "enum" in t:
            new_t = "VARCHAR(0)"
        elif "xml" in t:
            new_t = "VARCHAR(0)"
        elif "char" in t:
            # Character lengths are not preserved; all character types map to VARCHAR(0).
            new_t = "VARCHAR(0)"
        else:
            new_t = "UNKNOWN"

        return new_t

    def parse_primary_keys(self, input_ddl):
        """
        Parses primary keys (and shard keys for TQL).  Not yet implemented.
        :param input_ddl: The input DDL to parse from.
        :type input_ddl: list of str
        """
        # read through lines until a CREATE TABLE or ALTER TABLE is found.
        # next look for either a new CREATE TABLE, a PRIMARY KEY, or PARTITION BY HASH
        # add the primary key or partition

        # TODO - get this to work.
        # NOTE(review): this currently searches for "update table" (not ALTER
        # TABLE) and calls parse_create_table, which adds each table to the
        # database a second time -- whether that duplicates or overwrites
        # depends on Database.add_table; confirm before enabling parse_keys.
        create_or_update = False
        buff = ""
        for line in input_ddl:
            l = self.clean_line(line)

            if not create_or_update:  # looking for CREATE TABLE or UPDATE TABLE statement.
                if l.lower().find("create table") >= 0 or l.lower().find(
                    "update table"
                ) >= 0:
                    create_or_update = True
                    buff = l
                    if self.is_complete_create(buff):
                        self.parse_create_table(buff)
                        buff = ""
                        create_or_update = False
            else:  # looking for the end of a create table.
                buff += l
                if self.is_complete_create(buff):
                    self.parse_create_table(buff)
                    buff = ""
                    create_or_update = False

    @staticmethod
    def clean_line(line):
        """
        Removes unwanted characters from the input line.
        Strips leading/trailing whitespace and collapses every run of spaces
        and/or tabs into a single space.
        :param line:  The line to clean up.
        :type line: str
        :return: The cleaned up line.
        :rtype: str
        """
        # Collapse mixed runs in one pass; collapsing spaces and tabs
        # separately left double spaces for inputs like "col \tINT".
        return re.sub(r"[ \t]+", " ", line.strip())