コード例 #1
0
ファイル: api.py プロジェクト: linhcao1611/Trac-JIRA
class DatabaseManagerTestCase(unittest.TestCase):
    """Tests for `DatabaseManager`, run against an `EnvironmentStub`
    created with the default data.
    """

    def setUp(self):
        env = EnvironmentStub(default_data=True)
        self.env = env
        self.dbm = DatabaseManager(env)

    def tearDown(self):
        self.env.reset_db()

    def test_destroy_db(self):
        """The database is gone once destroy_db has been called."""
        manager = self.dbm
        # Query the database before looking at the connection pool.
        self.env.db_query("SELECT name FROM system")
        self.assertIsNotNone(manager._cnx_pool)

        manager.destroy_db()

        # The pool is dropped together with the database.
        self.assertIsNone(manager._cnx_pool)
        self.assertFalse(manager.db_exists())

    def test_get_column_names(self):
        """Column names match the default schema for every table."""
        for table in default_schema:
            expected = [column.name for column in table.columns]
            actual = self.dbm.get_column_names(table.name)
            self.assertEqual(expected, actual)

    def test_get_default_database_version(self):
        """The default `database_version` entry holds the default
        database version.
        """
        self.assertEqual(default_db_version,
                         self.dbm.get_database_version())

    def test_get_table_names(self):
        """Table names match the tables of the default schema."""
        expected = sorted(table.name for table in default_schema)
        actual = sorted(self.dbm.get_table_names())
        self.assertEqual(expected, actual)

    def test_set_default_database_version(self):
        """The default `database_version` entry can be changed."""
        manager = self.dbm
        bumped_version = default_db_version + 1

        manager.set_database_version(bumped_version)
        self.assertEqual(bumped_version, manager.get_database_version())

        # Put the original version back, otherwise the database would
        # be destroyed on teardown.
        manager.set_database_version(default_db_version)
        self.assertEqual(default_db_version, manager.get_database_version())

    def test_set_get_plugin_database_version(self):
        """A version can be stored and read back under an arbitrary
        entry name.
        """
        manager = self.dbm
        entry_name = 'a_trac_plugin_version'
        entry_version = 1

        # Nothing is stored under that name yet.
        self.assertFalse(manager.get_database_version(entry_name))
        manager.set_database_version(entry_version, entry_name)
        self.assertEqual(entry_version,
                         manager.get_database_version(entry_name))
コード例 #2
0
class DatabaseManagerTestCase(unittest.TestCase):
    """Version and table-name tests for `DatabaseManager`, run against
    an `EnvironmentStub` created with the default data.
    """

    def setUp(self):
        stub = EnvironmentStub(default_data=True)
        self.env = stub
        self.dbm = DatabaseManager(stub)

    def tearDown(self):
        self.env.reset_db()

    def test_get_default_database_version(self):
        """The default `database_version` entry holds the default
        database version.
        """
        self.assertEqual(default_db_version,
                         self.dbm.get_database_version())

    def test_get_table_names(self):
        """Table names match the tables of the default schema."""
        schema_names = sorted(table.name for table in default_schema)
        reported_names = sorted(self.dbm.get_table_names())
        self.assertEqual(schema_names, reported_names)

    def test_set_default_database_version(self):
        """The default `database_version` entry can be changed."""
        dbm = self.dbm
        next_version = default_db_version + 1

        dbm.set_database_version(next_version)
        self.assertEqual(next_version, dbm.get_database_version())

        # Put the original version back, otherwise the database would
        # be destroyed on teardown.
        dbm.set_database_version(default_db_version)
        self.assertEqual(default_db_version, dbm.get_database_version())

    def test_set_get_plugin_database_version(self):
        """A version can be stored and read back under an arbitrary
        entry name.
        """
        dbm = self.dbm
        plugin_entry = 'a_trac_plugin_version'
        plugin_version = 1

        # Nothing is stored under that name yet.
        self.assertFalse(dbm.get_database_version(plugin_entry))
        dbm.set_database_version(plugin_version, plugin_entry)
        self.assertEqual(plugin_version,
                         dbm.get_database_version(plugin_entry))
コード例 #3
0
def copy_tables(src_env, dst_env, src_db, dst_db, src_dburi, dst_dburi):
    """Copy the rows of every table that exists in both databases.

    Only tables reported by `DatabaseManager.get_table_names()` for both
    environments are copied. Each destination table is emptied before
    its rows are inserted. Afterwards `dst_db.update_sequence()` is
    called for the copied tables that `get_sequence_names()` also lists.
    The destination is committed once at the end, or rolled back (and
    the error re-raised) if anything fails.

    :param src_env: environment used to list the source tables
    :param dst_env: environment used to list the destination tables
                    and sequences
    :param src_db: source connection; rows are read through its cursor
    :param dst_db: destination connection; receives the rows and is
                   committed or rolled back
    :param src_dburi: source connection URI; a `sqlite:` prefix enables
                      the SQLite-specific cursor handling
    :param dst_dburi: destination connection URI; a `sqlite:` prefix
                      enables the SQLite-specific insert tuning
    :raises AssertionError: if the source is SQLite and its cursor is
                            not a `sqlite_backend.PyFormatCursor`
    """
    printfout("Copying tables:")

    if src_dburi.startswith('sqlite:'):
        src_db.cnx._eager = False  # avoid use of the eager cursor
    src_cursor = src_db.cursor()
    if src_dburi.startswith('sqlite:'):
        # Exact type match on purpose (not isinstance).
        # NOTE(review): presumably an eager cursor would be a different
        # (sub)class -- confirm against sqlite_backend.
        if type(src_cursor.cursor) is not sqlite_backend.PyFormatCursor:
            raise AssertionError('src_cursor.cursor is %r' % src_cursor.cursor)
    src_tables = set(DatabaseManager(src_env).get_table_names())
    cursor = dst_db.cursor()
    dst_dbm = DatabaseManager(dst_env)
    tables = set(dst_dbm.get_table_names()) & src_tables
    sequences = set(dst_dbm.get_sequence_names())
    # Only emit in-place progress updates when both streams are terminals.
    progress = sys.stdout.isatty() and sys.stderr.isatty()
    replace_cast = get_replace_cast(src_db, dst_db, src_dburi, dst_dburi)

    # speed-up copying data with SQLite database
    if dst_dburi.startswith('sqlite:'):
        sqlite_backend.set_synchronous(cursor, 'OFF')
        # NOTE(review): presumably multi-row VALUES needs SQLite >= 3.7.11
        # and 999 is SQLite's default bound-parameter limit -- confirm.
        multirows_insert = sqlite_backend.sqlite_version >= (3, 7, 11)
        max_parameters = 999
    else:
        multirows_insert = True
        max_parameters = None

    def copy_table(db, cursor, table):
        """Replace the rows of `table` in the destination with the rows
        read from the source; return the number of rows copied.
        """
        src_cursor.execute('SELECT * FROM ' + src_db.quote(table))
        columns = get_column_names(src_cursor)
        n_rows = 100
        if multirows_insert and max_parameters:
            # Keep rows-per-statement * columns within the parameter cap.
            n_rows = min(n_rows, int(max_parameters // len(columns)))
        quoted_table = db.quote(table)
        holders = '(%s)' % ','.join(['%s'] * len(columns))
        # The statement prefix is the same for every batch.
        query = 'INSERT INTO %s (%s) VALUES ' % \
                (quoted_table, ','.join(map(db.quote, columns)))
        count = 0

        cursor.execute('DELETE FROM ' + quoted_table)
        while True:
            rows = src_cursor.fetchmany(n_rows)
            if not rows:
                break
            count += len(rows)
            if progress:
                printfout("%d records\r  %s table... ",
                          count,
                          table,
                          newline=False)
            if replace_cast is not None and table == 'report':
                rows = replace_report_query(rows, columns, replace_cast)
            if multirows_insert:
                # Flatten in one pass: `sum(rows, ())` re-copies the
                # accumulated tuple for every row and fails when a row
                # is not a tuple.
                params = tuple(value for row in rows for value in row)
                cursor.execute(query + ','.join([holders] * len(rows)),
                               params)
            else:
                cursor.executemany(query + holders, rows)

        return count

    try:
        cursor = dst_db.cursor()
        for table in sorted(tables):
            printfout("  %s table... ", table, newline=False)
            count = copy_table(dst_db, cursor, table)
            printfout("%d records.", count)
        for table in tables & sequences:
            dst_db.update_sequence(cursor, table)
        dst_db.commit()
    except:
        # Deliberately bare: roll back on any failure or interruption,
        # then let the original exception propagate.
        dst_db.rollback()
        raise
コード例 #4
0
class DatabaseManagerTestCase(unittest.TestCase):
    """Tests for `DatabaseManager`, run against an `EnvironmentStub`
    created with the default data.
    """

    def setUp(self):
        self.env = EnvironmentStub(default_data=True)
        self.dbm = DatabaseManager(self.env)

    def tearDown(self):
        self.env.reset_db()

    def test_destroy_db(self):
        """Database doesn't exist after calling destroy_db."""
        with self.env.db_query as db:
            db("SELECT name FROM " + db.quote('system'))
        self.assertIsNotNone(self.dbm._cnx_pool)
        self.dbm.destroy_db()
        self.assertIsNone(self.dbm._cnx_pool)  # No connection pool
        scheme, params = parse_connection_uri(get_dburi())
        if scheme != 'postgres' or params.get('schema', 'public') != 'public':
            self.assertFalse(self.dbm.db_exists())
        else:
            # PostgreSQL with the `public` schema: only the absence of
            # tables is checked in this case.
            self.assertEqual([], self.dbm.get_table_names())

    def test_get_column_names(self):
        """Get column names for the default database."""
        for table in default_schema:
            column_names = [col.name for col in table.columns]
            self.assertEqual(column_names,
                             self.dbm.get_column_names(table.name))

    def test_get_default_database_version(self):
        """Get database version for the default entry named
        `database_version`.
        """
        self.assertEqual(default_db_version, self.dbm.get_database_version())

    def test_get_table_names(self):
        """Get table names for the default database."""
        self.assertEqual(sorted(table.name for table in default_schema),
                         sorted(self.dbm.get_table_names()))

    def test_has_table(self):
        """has_table returns exactly True or False for a table name."""
        self.assertIs(True, self.dbm.has_table('system'))
        self.assertIs(True, self.dbm.has_table('wiki'))
        self.assertIs(False, self.dbm.has_table('trac'))
        self.assertIs(False, self.dbm.has_table('blah.blah'))

    def test_no_database_version(self):
        """A false value is returned when the entry doesn't exist."""
        self.assertFalse(self.dbm.get_database_version('trac_plugin_version'))

    def test_set_default_database_version(self):
        """Set database version for the default entry named
        `database_version`.
        """
        new_db_version = default_db_version + 1
        self.dbm.set_database_version(new_db_version)
        try:
            self.assertEqual(new_db_version,
                             self.dbm.get_database_version())
            # Build the expected message from the actual versions so the
            # test keeps working when the default version changes.
            expected_message = 'Upgraded database_version from %d to %d' \
                               % (default_db_version, new_db_version)
            self.assertEqual([('INFO', expected_message)],
                             self.env.log_messages)
        finally:
            # Restore the previous version to avoid destroying the
            # database on teardown, even when an assertion above failed.
            self.dbm.set_database_version(default_db_version)
        self.assertEqual(default_db_version, self.dbm.get_database_version())

    def test_set_get_plugin_database_version(self):
        """Get and set database version for an entry with an
        arbitrary name.
        """
        name = 'trac_plugin_version'
        db_ver = 1

        self.dbm.set_database_version(db_ver, name)
        self.assertEqual([], self.env.log_messages)
        self.assertEqual(db_ver, self.dbm.get_database_version(name))
        # DB update will be skipped when new value equals database version
        self.dbm.set_database_version(db_ver, name)
        self.assertEqual([], self.env.log_messages)

    def test_get_sequence_names(self):
        """Sequence names are the tables with an auto-increment `id`
        column on PostgreSQL, and empty for other connection URIs.
        """
        sequence_names = []
        if self.dbm.connection_uri.startswith('postgres'):
            sequence_names = sorted(
                table.name
                for table in default_schema
                for column in table.columns
                if column.name == 'id' and column.auto_increment)

        self.assertEqual(sequence_names, self.dbm.get_sequence_names())