示例#1
0
class DatabaseManager(Component):
    """Component used to manage the `IDatabaseConnector` implementations.

    It selects the connector matching the scheme of the configured
    connection string and hands out pooled connections created by it.
    """

    connectors = ExtensionPoint(IDatabaseConnector)

    connection_uri = Option(
        'trac', 'database', 'sqlite:db/trac.db', """Database connection
        [wiki:TracEnvironment#DatabaseConnectionStrings string] for this
        project""")

    timeout = IntOption(
        'trac', 'timeout', '20',
        """Timeout value for database connection, in seconds.
        Use '0' to specify ''no timeout''. ''(Since 0.11)''""")

    def __init__(self):
        # The connection pool is created lazily by `get_connection()`.
        self._cnx_pool = None

    def init_db(self):
        """Initialize the database through the selected connector."""
        connector, args = self._get_connector()
        connector.init_db(**args)

    def get_connection(self):
        """Return a connection from the pool, creating the pool on first use.

        A `timeout` of 0 is passed to the pool as `None`.
        """
        if not self._cnx_pool:
            connector, args = self._get_connector()
            self._cnx_pool = ConnectionPool(5, connector, **args)
        return self._cnx_pool.get_cnx(self.timeout or None)

    def shutdown(self, tid=None):
        """Shut down pooled connections.

        :param tid: forwarded to the pool's `shutdown`. When it is falsy the
                    pool is also discarded, so the next `get_connection()`
                    call creates a fresh one.
        """
        if self._cnx_pool:
            self._cnx_pool.shutdown(tid)
            if not tid:
                self._cnx_pool = None

    def _get_connector(self):  ### FIXME: Make it public?
        """Return the `(connector, args)` pair for `connection_uri`.

        Among the connectors supporting the URI scheme, the one with the
        highest strictly positive priority is selected; on equal priorities
        the first one found is kept.

        :raises TracError: if no connector supports the scheme.
        """
        scheme, args = _parse_db_str(self.connection_uri)
        connector, best_priority = None, 0
        for candidate in self.connectors:
            for scheme_, priority in candidate.get_supported_schemes():
                if scheme_ == scheme and priority > best_priority:
                    connector, best_priority = candidate, priority
        if not connector:
            raise TracError('Unsupported database type "%s"' % scheme)

        if scheme == 'sqlite':
            # Special case for SQLite to support a path relative to the
            # environment directory. `os.path.isabs` (rather than a
            # leading-'/' test) recognizes absolute paths on every platform.
            path = args['path']
            if path != ':memory:' and not os.path.isabs(path):
                args['path'] = os.path.join(self.env.path, path)

        return connector, args
示例#2
0
class DatabaseManager(Component):
    """Component used to manage the `IDatabaseConnector` implementations.

    It selects the connector matching the scheme of the configured
    connection string and hands out pooled connections created by it.
    """

    connectors = ExtensionPoint(IDatabaseConnector)

    connection_uri = Option('trac', 'database', 'sqlite:db/trac.db',
        """Database connection
        [wiki:TracEnvironment#DatabaseConnectionStrings string] for this
        project""")

    timeout = IntOption('trac', 'timeout', '20',
        """Timeout value for database connection, in seconds.
        Use '0' to specify ''no timeout''. ''(Since 0.11)''""")

    def __init__(self):
        # The connection pool is created lazily by `get_connection()`.
        self._cnx_pool = None

    def init_db(self):
        """Initialize the database through the selected connector."""
        connector, args = self._get_connector()
        connector.init_db(**args)

    def get_connection(self):
        """Return a connection from the pool, creating the pool on first use.

        A `timeout` of 0 is passed to the pool as `None`.
        """
        if not self._cnx_pool:
            connector, args = self._get_connector()
            self._cnx_pool = ConnectionPool(5, connector, **args)
        return self._cnx_pool.get_cnx(self.timeout or None)

    def shutdown(self, tid=None):
        """Shut down pooled connections.

        :param tid: forwarded to the pool's `shutdown`. When it is falsy the
                    pool is also discarded, so the next `get_connection()`
                    call creates a fresh one.
        """
        if self._cnx_pool:
            self._cnx_pool.shutdown(tid)
            if not tid:
                self._cnx_pool = None

    def _get_connector(self): ### FIXME: Make it public?
        """Return the `(connector, args)` pair for `connection_uri`.

        Among the connectors supporting the URI scheme, the one with the
        highest strictly positive priority is selected; on equal priorities
        the first one found is kept.

        :raises TracError: if no connector supports the scheme.
        """
        scheme, args = _parse_db_str(self.connection_uri)
        connector, best_priority = None, 0
        for candidate in self.connectors:
            for scheme_, priority in candidate.get_supported_schemes():
                if scheme_ == scheme and priority > best_priority:
                    connector, best_priority = candidate, priority
        if not connector:
            raise TracError('Unsupported database type "%s"' % scheme)

        if scheme == 'sqlite':
            # Special case for SQLite to support a path relative to the
            # environment directory. `os.path.isabs` (rather than a
            # leading-'/' test) recognizes absolute paths on every platform.
            path = args['path']
            if path != ':memory:' and not os.path.isabs(path):
                args['path'] = os.path.join(self.env.path, path)

        return connector, args
示例#3
0
    def get_connection(self, readonly=False):
        """Hand out a connection taken from the (lazily built) pool.

        The pool is constructed from `get_connector()` the first time a
        connection is requested. A falsy `timeout` is passed to the pool
        as `None`.

        :param readonly: when true, the connection is returned wrapped in
                         `ConnectionWrapper(..., readonly=True)`, which
                         purposely lacks the `rollback` and `commit` methods.
        """
        if not self._cnx_pool:
            factory, params = self.get_connector()
            self._cnx_pool = ConnectionPool(5, factory, **params)
        cnx = self._cnx_pool.get_cnx(self.timeout or None)
        if not readonly:
            return cnx
        return ConnectionWrapper(cnx, readonly=True)
示例#4
0
文件: api.py 项目: pkdevbox/trac
    def get_connection(self, readonly=False):
        """Hand out a connection taken from the (lazily built) pool.

        The pool is constructed from `get_connector()` the first time a
        connection is requested. A falsy `timeout` is passed to the pool
        as `None`.

        :param readonly: when true, the connection is returned wrapped in
                         `ConnectionWrapper(..., readonly=True)`, which
                         purposely lacks the `rollback` and `commit` methods.
        """
        if not self._cnx_pool:
            factory, params = self.get_connector()
            self._cnx_pool = ConnectionPool(5, factory, **params)
        cnx = self._cnx_pool.get_cnx(self.timeout or None)
        if not readonly:
            return cnx
        return ConnectionWrapper(cnx, readonly=True)
示例#5
0
class DatabaseManager(Component):
    """Component used to manage the `IDatabaseConnector` implementations."""

    connectors = ExtensionPoint(IDatabaseConnector)

    connection_uri = Option(
        'trac', 'database', 'sqlite:db/trac.db', """Database connection
        [wiki:TracEnvironment#DatabaseConnectionStrings string] for this
        project""")

    backup_dir = Option('trac', 'backup_dir', 'db',
                        """Database backup location""")

    timeout = IntOption(
        'trac', 'timeout', '20',
        """Timeout value for database connection, in seconds.
        Use '0' to specify ''no timeout''. (''since 0.11'')""")

    debug_sql = BoolOption(
        'trac', 'debug_sql', False,
        """Show the SQL queries in the Trac log, at DEBUG level.
        (''since 0.11.5'')""")

    def __init__(self):
        # Connection pool, created lazily by `get_connection()`.
        self._cnx_pool = None
        # Thread-local `wdb`/`rdb` slots; not used by this class itself
        # (presumably the write/read connections of transaction helpers
        # elsewhere in the module -- TODO confirm).
        self._transaction_local = ThreadLocal(wdb=None, rdb=None)

    def init_db(self):
        """Create the database with the default Trac schema.

        The schema from `trac.db_default` is passed to the connector's
        `init_db` as the `schema` keyword argument.
        """
        connector, args = self.get_connector()
        from trac.db_default import schema
        args['schema'] = schema
        connector.init_db(**args)

    def create_tables(self, schema):
        """Create the specified tables.

        :param schema: an iterable of table objects.

        :since: version 1.0.2
        """
        connector = self.get_connector()[0]
        with self.env.db_transaction as db:
            for table in schema:
                for sql in connector.to_sql(table):
                    db(sql)

    def drop_tables(self, schema):
        """Drop the specified tables.

        :param schema: an iterable of `Table` objects or table names.

        :since: version 1.0.2
        """
        with self.env.db_transaction as db:
            for table in schema:
                table_name = table.name if isinstance(table, Table) else table
                db.drop_table(table_name)

    def get_connection(self, readonly=False):
        """Get a database connection from the pool.

        The pool is created on first use. A falsy `timeout` is passed to
        the pool as `None`.

        If `readonly` is `True`, the returned connection will purposely
        lack the `rollback` and `commit` methods.
        """
        if not self._cnx_pool:
            connector, args = self.get_connector()
            self._cnx_pool = ConnectionPool(5, connector, **args)
        db = self._cnx_pool.get_cnx(self.timeout or None)
        if readonly:
            db = ConnectionWrapper(db, readonly=True)
        return db

    def get_exceptions(self):
        """Return the result of the selected connector's `get_exceptions()`."""
        return self.get_connector()[0].get_exceptions()

    def shutdown(self, tid=None):
        """Shut down pooled connections.

        :param tid: forwarded to the pool's `shutdown`. When it is falsy the
                    pool is also discarded, so the next `get_connection()`
                    call creates a fresh one.
        """
        if self._cnx_pool:
            self._cnx_pool.shutdown(tid)
            if not tid:
                self._cnx_pool = None

    def backup(self, dest=None):
        """Save a backup of the database.

        :param dest: base filename to write to. When omitted, a name of the
                     form `<scheme>.<db version>.<timestamp>.bak` is built
                     inside `backup_dir` (resolved against the environment
                     directory unless absolute).

        Returns the file actually written.
        """
        connector = self.get_connector()[0]
        if not dest:
            backup_dir = self.backup_dir
            if not os.path.isabs(backup_dir):
                backup_dir = os.path.join(self.env.path, backup_dir)
            # Only the scheme part of the connection string is needed.
            db_str = self.config.get('trac', 'database')
            db_name = db_str.split(":", 1)[0]
            dest_name = '%s.%i.%d.bak' % (db_name, self.env.get_version(),
                                          int(time.time()))
            dest = os.path.join(backup_dir, dest_name)
        else:
            backup_dir = os.path.dirname(dest)
        # `dirname` is empty for a bare filename (current directory);
        # `os.makedirs('')` would raise, so only create a real directory.
        if backup_dir and not os.path.exists(backup_dir):
            os.makedirs(backup_dir)
        return connector.backup(dest)

    def get_connector(self):
        """Return the `(connector, args)` pair for `connection_uri`.

        The connector advertising the highest priority for the URI scheme
        is selected; on equal priorities the first one found wins.

        :raises TracError: if no connector supports the scheme, or if the
                           best one has a negative priority (its `error`
                           attribute is then used as the message).
        """
        scheme, args = _parse_db_str(self.connection_uri)
        candidates = [
            (priority, connector) for connector in self.connectors
            for scheme_, priority in connector.get_supported_schemes()
            if scheme_ == scheme
        ]
        if not candidates:
            raise TracError(
                _('Unsupported database type "%(scheme)s"', scheme=scheme))
        # Compare on priority only: breaking ties by comparing connector
        # objects is arbitrary (and a TypeError on Python 3).
        priority, connector = max(candidates, key=lambda c: c[0])
        if priority < 0:
            raise TracError(connector.error)

        if scheme == 'sqlite':
            path = args['path']
            if path == ':memory:':
                # SQLite in-memory database: the special path is kept as is.
                pass
            elif not os.path.isabs(path):
                # Special case for SQLite to support a path relative to the
                # environment directory
                args['path'] = os.path.join(self.env.path, path)

        if self.debug_sql:
            args['log'] = self.log
        return connector, args

    _get_connector = get_connector  # For 0.11 compatibility
示例#6
0
文件: api.py 项目: wataash/trac
class DatabaseManager(Component):
    """Component used to manage the `IDatabaseConnector` implementations."""

    implements(IEnvironmentSetupParticipant, ISystemInfoProvider)

    connectors = ExtensionPoint(IDatabaseConnector)

    connection_uri = Option(
        'trac', 'database', 'sqlite:db/trac.db', """Database connection
        [wiki:TracEnvironment#DatabaseConnectionStrings string] for this
        project""")

    backup_dir = Option('trac', 'backup_dir', 'db',
                        """Database backup location""")

    timeout = IntOption(
        'trac', 'timeout', '20',
        """Timeout value for database connection, in seconds.
        Use '0' to specify ''no timeout''.""")

    debug_sql = BoolOption(
        'trac', 'debug_sql', False,
        """Show the SQL queries in the Trac log, at DEBUG level.
        """)

    def __init__(self):
        # Connection pool, created lazily by `get_connection()`.
        self._cnx_pool = None
        # Thread-local `wdb`/`rdb` slots; not used by this class itself
        # (presumably the write/read connections of transaction helpers
        # elsewhere in the module -- TODO confirm).
        self._transaction_local = ThreadLocal(wdb=None, rdb=None)

    def init_db(self):
        """Create the database with the default Trac schema.

        `db_default.db_version` is recorded both as the
        `initial_database_version` and the `database_version` entries.
        """
        connector, args = self.get_connector()
        args['schema'] = db_default.schema
        connector.init_db(**args)
        version = db_default.db_version
        self.set_database_version(version, 'initial_database_version')
        self.set_database_version(version)

    def insert_default_data(self):
        """Insert the data provided by `db_default.get_data`."""
        self.insert_into_tables(db_default.get_data)

    def destroy_db(self):
        """Delete the database through the selected connector."""
        connector, args = self.get_connector()
        # Connections to on-disk db must be closed before deleting it.
        self.shutdown()
        connector.destroy_db(**args)

    def db_exists(self):
        """Return the selected connector's `db_exists` answer."""
        connector, args = self.get_connector()
        return connector.db_exists(**args)

    def create_tables(self, schema):
        """Create the specified tables.

        :param schema: an iterable of table objects.

        :since: version 1.0.2
        """
        connector = self.get_connector()[0]
        with self.env.db_transaction as db:
            for table in schema:
                for sql in connector.to_sql(table):
                    db(sql)

    def drop_columns(self, table, columns):
        """Drops the specified columns from table.

        :param table: a `Table` object or table name.
        :param columns: an iterable of column names.
        :raises OperationalError: (from `env.db_exc`) if the table does
                                  not exist.

        :since: version 1.2
        """
        table_name = table.name if isinstance(table, Table) else table
        with self.env.db_transaction as db:
            if not db.has_table(table_name):
                raise self.env.db_exc.OperationalError('Table %s not found' %
                                                       db.quote(table_name))
            for col in columns:
                db.drop_column(table_name, col)

    def drop_tables(self, schema):
        """Drop the specified tables.

        :param schema: an iterable of `Table` objects or table names.

        :since: version 1.0.2
        """
        with self.env.db_transaction as db:
            for table in schema:
                table_name = table.name if isinstance(table, Table) else table
                db.drop_table(table_name)

    def insert_into_tables(self, data_or_callable):
        """Insert data into existing tables.

        :param data_or_callable: an iterable of `(table, columns, rows)`
                                 triples::

                                   ((table1,
                                     (column1, column2),
                                     ((row1col1, row1col2),
                                      (row2col1, row2col2))),
                                    (table2, ...), ...)

                                 or a callable that takes a single parameter
                                 `db` and returns the aforementioned
                                 iterable.
        :since: version 1.1.3
        """
        with self.env.db_transaction as db:
            if callable(data_or_callable):
                data = data_or_callable(db)
            else:
                data = data_or_callable
            for table, cols, vals in data:
                placeholders = ','.join(['%s'] * len(cols))
                db.executemany(
                    "INSERT INTO %s (%s) VALUES (%s)" %
                    (db.quote(table), ','.join(cols), placeholders), vals)

    def reset_tables(self):
        """Deletes all data from the tables and resets autoincrement indexes.

        :return: list of names of the tables that were reset.

        :since: version 1.1.3
        """
        with self.env.db_transaction as db:
            return db.reset_tables()

    def upgrade_tables(self, new_schema):
        """Upgrade table schema to `new_schema`, preserving data in
        columns that exist in the current schema and `new_schema`.

        Each existing table is copied to a temporary `<name>_old` table,
        dropped, recreated from `new_schema`, and refilled with the common
        columns; auto-increment sequences are then updated.

        :param new_schema: tuple or list of `Table` objects

        :since: version 1.2
        """
        with self.env.db_transaction as db:
            cursor = db.cursor()
            for new_table in new_schema:
                temp_table_name = new_table.name + '_old'
                has_table = self.has_table(new_table)
                # Columns to carry over; stays empty for a brand new table.
                column_names = ()
                if has_table:
                    old_column_names = set(self.get_column_names(new_table))
                    new_column_names = {col.name for col in new_table.columns}
                    column_names = old_column_names & new_column_names
                    if column_names:
                        cols_to_copy = ','.join(
                            db.quote(name) for name in column_names)
                        cursor.execute("""
                            CREATE TEMPORARY TABLE %s AS SELECT * FROM %s
                            """ % (db.quote(temp_table_name),
                                   db.quote(new_table.name)))
                    self.drop_tables((new_table, ))
                self.create_tables((new_table, ))
                if column_names:
                    cursor.execute("""
                        INSERT INTO %s (%s) SELECT %s FROM %s
                        """ % (db.quote(new_table.name), cols_to_copy,
                               cols_to_copy, db.quote(temp_table_name)))
                    for col in new_table.columns:
                        if col.auto_increment:
                            db.update_sequence(cursor, new_table.name,
                                               col.name)
                    self.drop_tables((temp_table_name, ))

    def get_connection(self, readonly=False):
        """Get a database connection from the pool.

        The pool is created on first use. A falsy `timeout` is passed to
        the pool as `None`.

        If `readonly` is `True`, the returned connection will purposely
        lack the `rollback` and `commit` methods.
        """
        if not self._cnx_pool:
            connector, args = self.get_connector()
            self._cnx_pool = ConnectionPool(5, connector, **args)
        db = self._cnx_pool.get_cnx(self.timeout or None)
        if readonly:
            db = ConnectionWrapper(db, readonly=True)
        return db

    def get_database_version(self, name='database_version'):
        """Returns the database version from the SYSTEM table as an int,
        or `False` if the entry is not found.

        :param name: The name of the entry that contains the database version
                     in the SYSTEM table. Defaults to `database_version`,
                     which contains the database version for Trac.
        """
        with self.env.db_query as db:
            for value, in db(
                    """
                    SELECT value FROM {0} WHERE name=%s
                    """.format(db.quote('system')), (name, )):
                return int(value)
            return False

    def get_exceptions(self):
        """Return the result of the selected connector's `get_exceptions()`."""
        return self.get_connector()[0].get_exceptions()

    def get_sequence_names(self):
        """Returns a list of the sequence names.

        :since: 1.3.2
        """
        with self.env.db_query as db:
            return db.get_sequence_names()

    def get_table_names(self):
        """Returns a list of the table names.

        :since: 1.1.6
        """
        with self.env.db_query as db:
            return db.get_table_names()

    def get_column_names(self, table):
        """Returns a list of the column names for `table`.

        :param table: a `Table` object or table name.
        :raises OperationalError: (from `env.db_exc`) if the table does
                                  not exist.

        :since: 1.2
        """
        table_name = table.name if isinstance(table, Table) else table
        with self.env.db_query as db:
            if not db.has_table(table_name):
                raise self.env.db_exc.OperationalError('Table %s not found' %
                                                       db.quote(table_name))
            return db.get_column_names(table_name)

    def has_table(self, table):
        """Returns whether the table exists.

        :param table: a `Table` object or table name.
        """
        table_name = table.name if isinstance(table, Table) else table
        with self.env.db_query as db:
            return db.has_table(table_name)

    def set_database_version(self, version, name='database_version'):
        """Sets the database version in the SYSTEM table.

        The entry is inserted if missing, updated (and the change logged)
        if it differs, and left alone if it already equals `version`.

        :param version: an integer database version.
        :param name: The name of the entry that contains the database version
                     in the SYSTEM table. Defaults to `database_version`,
                     which contains the database version for Trac.
        """
        current_database_version = self.get_database_version(name)
        if current_database_version is False:
            with self.env.db_transaction as db:
                db(
                    """
                    INSERT INTO {0} (name, value) VALUES (%s, %s)
                    """.format(db.quote('system')), (name, version))
        # Compare with the value already fetched instead of re-querying.
        elif version != current_database_version:
            with self.env.db_transaction as db:
                db(
                    """
                    UPDATE {0} SET value=%s WHERE name=%s
                    """.format(db.quote('system')), (version, name))
            self.log.info("Upgraded %s from %d to %d", name,
                          current_database_version, version)

    def needs_upgrade(self, version, name='database_version'):
        """Checks the database version to determine if an upgrade is needed.

        :param version: the expected integer database version.
        :param name: the name of the entry in the SYSTEM table that contains
                     the database version. Defaults to `database_version`,
                     which contains the database version for Trac.

        :return: `True` if the stored version is less than the expected
                  version, `False` if it is equal to the expected version.
        :raises TracError: if the stored version is greater than the expected
                           version.
        """
        dbver = self.get_database_version(name)
        if dbver == version:
            return False
        elif dbver > version:
            raise TracError(_("Need to downgrade %(name)s.", name=name))
        self.log.info("Need to upgrade %s from %d to %d", name, dbver, version)
        return True

    def upgrade(self, version, name='database_version', pkg='trac.upgrades'):
        """Invokes `do_upgrade(env, version, cursor)` in module
        `"%s/db%i.py" % (pkg, version)`, for each required version upgrade.

        Each step runs in its own transaction, which also records the new
        version number.

        :param version: the expected integer database version.
        :param name: the name of the entry in the SYSTEM table that contains
                     the database version. Defaults to `database_version`,
                     which contains the database version for Trac.
        :param pkg: the package containing the upgrade modules.

        :raises TracError: if the package or module doesn't exist.
        """
        dbver = self.get_database_version(name)
        for i in xrange(dbver + 1, version + 1):
            module = '%s.db%i' % (pkg, i)
            try:
                upgrader = importlib.import_module(module)
            except ImportError:
                raise TracError(
                    _("No upgrade module %(module)s.py", module=module))
            with self.env.db_transaction as db:
                cursor = db.cursor()
                upgrader.do_upgrade(self.env, i, cursor)
                self.set_database_version(i, name)

    def shutdown(self, tid=None):
        """Shut down pooled connections.

        :param tid: forwarded to the pool's `shutdown`. When it is falsy the
                    pool is also discarded, so the next `get_connection()`
                    call creates a fresh one.
        """
        if self._cnx_pool:
            self._cnx_pool.shutdown(tid)
            if not tid:
                self._cnx_pool = None

    def backup(self, dest=None):
        """Save a backup of the database.

        :param dest: base filename to write to. When omitted, a name of the
                     form `<scheme>.<db version>.<timestamp>.bak` is built
                     inside `backup_dir` (resolved against the environment
                     directory unless absolute).

        Returns the file actually written.
        """
        connector = self.get_connector()[0]
        if not dest:
            backup_dir = self.backup_dir
            if not os.path.isabs(backup_dir):
                backup_dir = os.path.join(self.env.path, backup_dir)
            # Only the scheme part of the connection string is needed.
            db_str = self.config.get('trac', 'database')
            db_name = db_str.split(":", 1)[0]
            dest_name = '%s.%i.%d.bak' % (db_name, self.env.database_version,
                                          int(time.time()))
            dest = os.path.join(backup_dir, dest_name)
        else:
            backup_dir = os.path.dirname(dest)
        # `dirname` is empty for a bare filename (current directory);
        # `os.makedirs('')` would raise, so only create a real directory.
        if backup_dir and not os.path.exists(backup_dir):
            os.makedirs(backup_dir)
        return connector.backup(dest)

    def get_connector(self):
        """Return the `(connector, args)` pair for `connection_uri`.

        The connector advertising the highest priority for the URI scheme
        is selected; on equal priorities the first one found wins.

        :raises TracError: if no connector supports the scheme, or if the
                           best one has a negative priority (its `error`
                           attribute is then used as the message).
        """
        scheme, args = parse_connection_uri(self.connection_uri)
        candidates = [
            (priority, connector) for connector in self.connectors
            for scheme_, priority in connector.get_supported_schemes()
            if scheme_ == scheme
        ]
        if not candidates:
            raise TracError(
                _('Unsupported database type "%(scheme)s"', scheme=scheme))
        # Compare on priority only: breaking ties by comparing connector
        # objects is arbitrary (and a TypeError on Python 3).
        priority, connector = max(candidates, key=lambda c: c[0])
        if priority < 0:
            raise TracError(connector.error)

        if scheme == 'sqlite':
            path = args['path']
            if path == ':memory:':
                # SQLite in-memory database: the special path is kept as is.
                pass
            elif not os.path.isabs(path):
                # Special case for SQLite to support a path relative to the
                # environment directory
                args['path'] = os.path.join(self.env.path, path)

        if self.debug_sql:
            args['log'] = self.log
        return connector, args

    # IEnvironmentSetupParticipant methods

    def environment_created(self):
        pass

    def environment_needs_upgrade(self):
        return self.needs_upgrade(db_default.db_version)

    def upgrade_environment(self):
        self.upgrade(db_default.db_version)

    # ISystemInfoProvider methods

    def get_system_info(self):
        connector = self.get_connector()[0]
        for info in connector.get_system_info():
            yield info
示例#7
0
class DatabaseManager(Component):
    """Component used to manage the `IDatabaseConnector` implementations."""

    connectors = ExtensionPoint(IDatabaseConnector)

    connection_uri = Option(
        'trac', 'database', 'sqlite:db/trac.db', """Database connection
        [wiki:TracEnvironment#DatabaseConnectionStrings string] for this
        project""")

    backup_dir = Option('trac', 'backup_dir', 'db',
                        """Database backup location""")

    timeout = IntOption(
        'trac', 'timeout', '20',
        """Timeout value for database connection, in seconds.
        Use '0' to specify ''no timeout''.""")

    debug_sql = BoolOption(
        'trac', 'debug_sql', False,
        """Show the SQL queries in the Trac log, at DEBUG level.
        """)

    def __init__(self):
        # Connection pool, created lazily by `get_connection()`.
        self._cnx_pool = None
        # Thread-local `wdb`/`rdb` slots; not used by this class itself
        # (presumably the write/read connections of transaction helpers
        # elsewhere in the module -- TODO confirm).
        self._transaction_local = ThreadLocal(wdb=None, rdb=None)

    def init_db(self):
        """Create the database with the default Trac schema.

        The schema from `trac.db_default` is passed to the connector's
        `init_db` as the `schema` keyword argument.
        """
        connector, args = self.get_connector()
        from trac.db_default import schema
        args['schema'] = schema
        connector.init_db(**args)

    def destroy_db(self):
        """Delete the database through the selected connector."""
        connector, args = self.get_connector()
        # Connections to an on-disk database must be closed before the
        # database is deleted, so shut the pool down first.
        self.shutdown()
        connector.destroy_db(**args)

    def create_tables(self, schema):
        """Create the specified tables.

        :param schema: an iterable of table objects.

        :since: version 1.0.2
        """
        connector = self.get_connector()[0]
        with self.env.db_transaction as db:
            for table in schema:
                for sql in connector.to_sql(table):
                    db(sql)

    def drop_tables(self, schema):
        """Drop the specified tables.

        :param schema: an iterable of `Table` objects or table names.

        :since: version 1.0.2
        """
        with self.env.db_transaction as db:
            for table in schema:
                table_name = table.name if isinstance(table, Table) else table
                db.drop_table(table_name)

    def insert_into_tables(self, data_or_callable):
        """Insert data into existing tables.

        :param data_or_callable: an iterable of `(table, columns, rows)`
                                 triples::

                                   ((table1,
                                     (column1, column2),
                                     ((row1col1, row1col2),
                                      (row2col1, row2col2))),
                                    (table2, ...), ...)

                                 or a callable that takes a single parameter
                                 `db` and returns the aforementioned
                                 iterable.
        :since: version 1.1.3
        """
        with self.env.db_transaction as db:
            if callable(data_or_callable):
                data = data_or_callable(db)
            else:
                data = data_or_callable
            for table, cols, vals in data:
                placeholders = ','.join(['%s'] * len(cols))
                db.executemany(
                    "INSERT INTO %s (%s) VALUES (%s)" %
                    (table, ','.join(cols), placeholders),
                    vals)

    def reset_tables(self):
        """Deletes all data from the tables and resets autoincrement indexes.

        :return: list of names of the tables that were reset.

        :since: version 1.1.3
        """
        with self.env.db_transaction as db:
            return db.reset_tables()

    def get_connection(self, readonly=False):
        """Get a database connection from the pool.

        The pool is created on first use. A falsy `timeout` is passed to
        the pool as `None`.

        If `readonly` is `True`, the returned connection will purposely
        lack the `rollback` and `commit` methods.
        """
        if not self._cnx_pool:
            connector, args = self.get_connector()
            self._cnx_pool = ConnectionPool(5, connector, **args)
        db = self._cnx_pool.get_cnx(self.timeout or None)
        if readonly:
            db = ConnectionWrapper(db, readonly=True)
        return db

    def get_database_version(self, name='database_version'):
        """Returns the database version from the SYSTEM table as an int,
        or `False` if the entry is not found.

        :param name: The name of the entry that contains the database version
                     in the SYSTEM table. Defaults to `database_version`,
                     which contains the database version for Trac.
        """
        rows = self.env.db_query(
            """
                SELECT value FROM system WHERE name=%s
                """, (name, ))
        return int(rows[0][0]) if rows else False

    def get_exceptions(self):
        """Return the result of the selected connector's `get_exceptions()`."""
        return self.get_connector()[0].get_exceptions()

    def set_database_version(self, version, name='database_version'):
        """Sets the database version in the SYSTEM table.

        The entry is inserted if missing, updated (and the change logged)
        if it differs, and left alone if it already equals `version`.

        :param version: an integer database version.
        :param name: The name of the entry that contains the database version
                     in the SYSTEM table. Defaults to `database_version`,
                     which contains the database version for Trac.
        """
        current_database_version = self.get_database_version(name)
        if current_database_version is False:
            self.env.db_transaction(
                """
                    INSERT INTO system (name, value) VALUES (%s, %s)
                    """, (name, version))
        # Skip the UPDATE and the misleading "Upgraded" log message when
        # the stored version is already the requested one.
        elif version != current_database_version:
            self.env.db_transaction(
                """
                    UPDATE system SET value=%s WHERE name=%s
                    """, (version, name))
            self.log.info("Upgraded %s from %d to %d", name,
                          current_database_version, version)

    def needs_upgrade(self, version, name='database_version'):
        """Checks the database version to determine if an upgrade is needed.

        :param version: the expected integer database version.
        :param name: the name of the entry in the SYSTEM table that contains
                     the database version. Defaults to `database_version`,
                     which contains the database version for Trac.

        :return: `True` if the stored version is less than the expected
                  version, `False` if it is equal to the expected version.
        :raises TracError: if the stored version is greater than the expected
                           version.
        """
        dbver = self.get_database_version(name)
        if dbver == version:
            return False
        elif dbver > version:
            raise TracError(_("Need to downgrade %(name)s.", name=name))
        self.log.info("Need to upgrade %s from %d to %d", name, dbver, version)
        return True

    def upgrade(self, version, name='database_version', pkg=None):
        """Invokes `do_upgrade(env, version, cursor)` in module
        `"%s/db%i.py" % (pkg, version)`, for each required version upgrade.

        Each step runs in its own transaction, which also records the new
        version number.

        :param version: the expected integer database version.
        :param name: the name of the entry in the SYSTEM table that contains
                     the database version. Defaults to `database_version`,
                     which contains the database version for Trac.
        :param pkg: the package containing the upgrade modules.

        :raises TracError: if the package or module doesn't exist.
        """
        dbver = self.get_database_version(name)
        for i in range(dbver + 1, version + 1):
            module = 'db%i' % i
            try:
                upgrades = __import__(pkg, globals(), locals(), [module])
            except ImportError:
                raise TracError(_("No upgrade package %(pkg)s", pkg=pkg))
            try:
                script = getattr(upgrades, module)
            except AttributeError:
                raise TracError(
                    _("No upgrade module %(module)s.py", module=module))
            with self.env.db_transaction as db:
                cursor = db.cursor()
                script.do_upgrade(self.env, i, cursor)
                self.set_database_version(i, name)

    def shutdown(self, tid=None):
        """Shut down pooled connections.

        :param tid: forwarded to the pool's `shutdown`. When it is falsy the
                    pool is also discarded, so the next `get_connection()`
                    call creates a fresh one.
        """
        if self._cnx_pool:
            self._cnx_pool.shutdown(tid)
            if not tid:
                self._cnx_pool = None

    def backup(self, dest=None):
        """Save a backup of the database.

        :param dest: base filename to write to. When omitted, a name of the
                     form `<scheme>.<db version>.<timestamp>.bak` is built
                     inside `backup_dir` (resolved against the environment
                     directory unless absolute).

        Returns the file actually written.
        """
        connector = self.get_connector()[0]
        if not dest:
            backup_dir = self.backup_dir
            if not os.path.isabs(backup_dir):
                backup_dir = os.path.join(self.env.path, backup_dir)
            # Only the scheme part of the connection string is needed.
            db_str = self.config.get('trac', 'database')
            db_name = db_str.split(":", 1)[0]
            dest_name = '%s.%i.%d.bak' % (db_name, self.env.database_version,
                                          int(time.time()))
            dest = os.path.join(backup_dir, dest_name)
        else:
            backup_dir = os.path.dirname(dest)
        # `dirname` is empty for a bare filename (current directory);
        # `os.makedirs('')` would raise, so only create a real directory.
        if backup_dir and not os.path.exists(backup_dir):
            os.makedirs(backup_dir)
        return connector.backup(dest)

    def get_connector(self):
        """Return the `(connector, args)` pair for `connection_uri`.

        The connector advertising the highest priority for the URI scheme
        is selected; on equal priorities the first one found wins.

        :raises TracError: if no connector supports the scheme, or if the
                           best one has a negative priority (its `error`
                           attribute is then used as the message).
        """
        scheme, args = parse_connection_uri(self.connection_uri)
        candidates = [
            (priority, connector) for connector in self.connectors
            for scheme_, priority in connector.get_supported_schemes()
            if scheme_ == scheme
        ]
        if not candidates:
            raise TracError(
                _('Unsupported database type "%(scheme)s"', scheme=scheme))
        # Compare on priority only: breaking ties by comparing connector
        # objects is arbitrary (and a TypeError on Python 3).
        priority, connector = max(candidates, key=lambda c: c[0])
        if priority < 0:
            raise TracError(connector.error)

        if scheme == 'sqlite':
            path = args['path']
            if path == ':memory:':
                # SQLite in-memory database: the special path is kept as is.
                pass
            elif not os.path.isabs(path):
                # Special case for SQLite to support a path relative to the
                # environment directory
                args['path'] = os.path.join(self.env.path, path)

        if self.debug_sql:
            args['log'] = self.log
        return connector, args

    _get_connector = get_connector  # For 0.11 compatibility
示例#8
0
 def get_connection(self):
     """Return a connection from the pool, building the pool on first use.

     A falsy `timeout` is handed to the pool as `None`.
     """
     if not self._cnx_pool:
         factory, params = self.get_connector()
         self._cnx_pool = ConnectionPool(5, factory, **params)
     pool = self._cnx_pool
     return pool.get_cnx(self.timeout or None)
示例#9
0
class DatabaseManager(Component):
    """Component managing the `IDatabaseConnector` implementations and the
    pool of connections to the environment's database."""

    connectors = ExtensionPoint(IDatabaseConnector)

    connection_uri = Option(
        'trac', 'database', 'sqlite:db/trac.db', """Database connection
        [wiki:TracEnvironment#DatabaseConnectionStrings string] for this
        project""")

    backup_dir = Option('trac', 'backup_dir', 'db',
                        """Database backup location""")

    timeout = IntOption(
        'trac', 'timeout', '20',
        """Timeout value for database connection, in seconds.
        Use '0' to specify ''no timeout''. ''(Since 0.11)''""")

    debug_sql = BoolOption(
        'trac', 'debug_sql', False,
        """Show the SQL queries in the Trac log, at DEBUG level.
        ''(Since 0.11.5)''""")

    def __init__(self):
        # Built lazily by `get_connection`; discarded by `shutdown()`.
        self._cnx_pool = None

    def init_db(self):
        """Initialise the database through the configured connector.

        Calls the connector's `init_db` with the parsed connection
        arguments.
        """
        connector, args = self.get_connector()
        connector.init_db(**args)

    def get_connection(self):
        """Return a connection from the shared connection pool.

        The pool is created on first use.  A `timeout` of 0 is passed to
        the pool as `None`, i.e. no timeout.
        """
        if not self._cnx_pool:
            connector, args = self.get_connector()
            self._cnx_pool = ConnectionPool(5, connector, **args)
        return self._cnx_pool.get_cnx(self.timeout or None)

    def shutdown(self, tid=None):
        """Shut down pooled connections.

        @param tid: passed on to the pool's `shutdown`.  When it is not
                    given (falsy), the pool itself is discarded as well, so
                    the next `get_connection` creates a fresh one.
        """
        if self._cnx_pool:
            self._cnx_pool.shutdown(tid)
            if not tid:
                self._cnx_pool = None

    def backup(self, dest=None):
        """Save a backup of the database.

        @param dest: base filename to write to.  When omitted, a name of
                     the form `<scheme>.<version>.<timestamp>.bak` is used,
                     inside the `[trac] backup_dir` directory (a relative
                     `backup_dir` is taken from the environment directory).
        Returns the file actually written.
        """
        connector, args = self.get_connector()
        if not dest:
            backup_dir = self.backup_dir
            if not os.path.isabs(backup_dir):
                backup_dir = os.path.join(self.env.path, backup_dir)
            db_str = self.config.get('trac', 'database')
            db_name, db_path = db_str.split(":", 1)
            dest_name = '%s.%i.%d.bak' % (db_name, self.env.get_version(),
                                          int(time.time()))
            dest = os.path.join(backup_dir, dest_name)
        else:
            backup_dir = os.path.dirname(dest)
        # A bare filename has no directory part: os.makedirs('') would
        # raise, and the current directory exists anyway.
        if backup_dir and not os.path.exists(backup_dir):
            os.makedirs(backup_dir)
        return connector.backup(dest)

    def get_connector(self):
        """Return a `(connector, args)` tuple for the configured database.

        The highest-priority connector supporting the URI scheme is chosen;
        on a tie, the first one in `self.connectors` wins.

        @raise TracError: if no connector supports the scheme, or the
                          selected one has a negative priority.
        """
        scheme, args = _parse_db_str(self.connection_uri)
        candidates = [
            (priority, connector) for connector in self.connectors
            for scheme_, priority in connector.get_supported_schemes()
            if scheme_ == scheme
        ]
        if not candidates:
            raise TracError(
                _('Unsupported database type "%(scheme)s"', scheme=scheme))
        # Compare on the priority only: comparing connector objects on a tie
        # is arbitrary (Python 2) or a TypeError (Python 3).
        priority, connector = max(candidates, key=lambda c: c[0])
        if priority < 0:
            raise TracError(connector.error)

        if scheme == 'sqlite':
            # Special case for SQLite to support a path relative to the
            # environment directory
            if args['path'] != ':memory:' and \
                   not os.path.isabs(args['path']):
                args['path'] = os.path.join(self.env.path,
                                            args['path'].lstrip('/'))

        if self.debug_sql:
            args['log'] = self.log
        return connector, args

    _get_connector = get_connector  # For 0.11 compatibility
示例#10
0
文件: api.py 项目: pkdevbox/trac
class DatabaseManager(Component):
    """Component used to manage the `IDatabaseConnector` implementations.

    It selects the connector matching the configured connection URI, owns
    the pool of database connections, and provides helpers to create,
    populate, version, upgrade and back up the database.
    """

    connectors = ExtensionPoint(IDatabaseConnector)

    connection_uri = Option('trac', 'database', 'sqlite:db/trac.db',
        """Database connection
        [wiki:TracEnvironment#DatabaseConnectionStrings string] for this
        project""")

    backup_dir = Option('trac', 'backup_dir', 'db',
        """Database backup location""")

    timeout = IntOption('trac', 'timeout', '20',
        """Timeout value for database connection, in seconds.
        Use '0' to specify ''no timeout''.""")

    debug_sql = BoolOption('trac', 'debug_sql', False,
        """Show the SQL queries in the Trac log, at DEBUG level.
        """)

    def __init__(self):
        # Built lazily by `get_connection`; discarded by `shutdown()`.
        self._cnx_pool = None
        # Thread-local slots `wdb` / `rdb` (presumably the current write
        # and read connections); not used within this class -- verify
        # against the transaction helpers.
        self._transaction_local = ThreadLocal(wdb=None, rdb=None)

    def init_db(self):
        """Initialise the database with the default Trac schema.

        `trac.db_default.schema` is passed to the connector's `init_db`
        as the `schema` keyword argument, together with the connection
        arguments.
        """
        connector, args = self.get_connector()
        from trac.db_default import schema
        args['schema'] = schema
        connector.init_db(**args)

    def destroy_db(self):
        """Destroy the database through the configured connector.

        The connection pool is shut down first: pooled connections still
        open on the database can prevent its removal (notably on Windows,
        where an open SQLite file cannot be deleted).
        """
        connector, args = self.get_connector()
        # Connections to the database must be closed before deleting it.
        self.shutdown()
        connector.destroy_db(**args)

    def create_tables(self, schema):
        """Create the specified tables.

        All statements produced by the connector's `to_sql` run in a
        single transaction.

        :param schema: an iterable of table objects.

        :since: version 1.0.2
        """
        connector = self.get_connector()[0]
        with self.env.db_transaction as db:
            for table in schema:
                for sql in connector.to_sql(table):
                    db(sql)

    def drop_tables(self, schema):
        """Drop the specified tables.

        :param schema: an iterable of `Table` objects or table names.

        :since: version 1.0.2
        """
        with self.env.db_transaction as db:
            for table in schema:
                table_name = table.name if isinstance(table, Table) else table
                db.drop_table(table_name)

    def insert_into_tables(self, data_or_callable):
        """Insert data into existing tables.

        :param data_or_callable: an iterable of `(table, columns, rows)`
                                 triples, e.g.:
                                 (('table1',
                                   ('column1', 'column2'),
                                   (('row1col1', 'row1col2'),
                                    ('row2col1', 'row2col2'))),
                                  ('table2', ...),
                                  ...)
                                 or a callable that takes a single
                                 parameter `db` and returns such an
                                 iterable.
        :since: version 1.1.3
        """
        with self.env.db_transaction as db:
            data = data_or_callable(db) if callable(data_or_callable) \
                                        else data_or_callable
            for table, cols, vals in data:
                # Table and column names are interpolated into the SQL text
                # and must therefore be trusted; row values are passed as
                # query parameters.
                db.executemany("INSERT INTO %s (%s) VALUES (%s)"
                               % (table, ','.join(cols),
                                  ','.join(['%s'] * len(cols))), vals)

    def reset_tables(self):
        """Deletes all data from the tables and resets autoincrement indexes.

        :return: list of names of the tables that were reset.

        :since: version 1.1.3
        """
        with self.env.db_transaction as db:
            return db.reset_tables()

    def get_connection(self, readonly=False):
        """Get a database connection from the pool.

        The pool is created on first use.  A `timeout` of 0 is passed to
        the pool as `None`, i.e. no timeout.

        If `readonly` is `True`, the returned connection will purposely
        lack the `rollback` and `commit` methods.
        """
        if not self._cnx_pool:
            connector, args = self.get_connector()
            self._cnx_pool = ConnectionPool(5, connector, **args)
        db = self._cnx_pool.get_cnx(self.timeout or None)
        if readonly:
            db = ConnectionWrapper(db, readonly=True)
        return db

    def get_database_version(self, name='database_version'):
        """Returns the database version from the SYSTEM table as an int,
        or `False` if the entry is not found.

        :param name: The name of the entry that contains the database version
                     in the SYSTEM table. Defaults to `database_version`,
                     which contains the database version for Trac.
        """
        rows = self.env.db_query("""
                SELECT value FROM system WHERE name=%s
                """, (name,))
        return int(rows[0][0]) if rows else False

    def get_exceptions(self):
        """Return the configured connector's `get_exceptions()` result
        (presumably the DB-API exception classes of its driver)."""
        return self.get_connector()[0].get_exceptions()

    def get_table_names(self):
        """Returns a list of the table names.

        :since: 1.1.6
        """
        with self.env.db_query as db:
            return db.get_table_names()

    def set_database_version(self, version, name='database_version'):
        """Sets the database version in the SYSTEM table.

        The entry is inserted when missing, updated otherwise; only an
        update is logged.

        :param version: an integer database version.
        :param name: The name of the entry that contains the database version
                     in the SYSTEM table. Defaults to `database_version`,
                     which contains the database version for Trac.
        """
        current_database_version = self.get_database_version(name)
        if current_database_version is False:
            self.env.db_transaction("""
                    INSERT INTO system (name, value) VALUES (%s, %s)
                    """, (name, version))
        else:
            self.env.db_transaction("""
                    UPDATE system SET value=%s WHERE name=%s
                    """, (version, name))
            self.log.info("Upgraded %s from %d to %d",
                          name, current_database_version, version)

    def needs_upgrade(self, version, name='database_version'):
        """Checks the database version to determine if an upgrade is needed.

        A missing entry (`get_database_version` returns `False`) compares
        like version 0.

        :param version: the expected integer database version.
        :param name: the name of the entry in the SYSTEM table that contains
                     the database version. Defaults to `database_version`,
                     which contains the database version for Trac.

        :return: `True` if the stored version is less than the expected
                  version, `False` if it is equal to the expected version.
        :raises TracError: if the stored version is greater than the expected
                           version.
        """
        dbver = self.get_database_version(name)
        if dbver == version:
            return False
        elif dbver > version:
            raise TracError(_("Need to downgrade %(name)s.", name=name))
        self.log.info("Need to upgrade %s from %d to %d",
                      name, dbver, version)
        return True

    def upgrade(self, version, name='database_version', pkg=None):
        """Invokes `do_upgrade(env, version, cursor)` in module
        `"%s/db%i.py" % (pkg, version)`, for each required version upgrade.

        Each step runs in its own transaction, together with the update
        of the stored version, so a failing step leaves the database at
        the last successfully applied version.

        :param version: the expected integer database version.
        :param name: the name of the entry in the SYSTEM table that contains
                     the database version. Defaults to `database_version`,
                     which contains the database version for Trac.
        :param pkg: the (dotted) package containing the upgrade modules;
                    must be given whenever an upgrade is actually needed.

        :raises TracError: if the package or module doesn't exist.
        """
        dbver = self.get_database_version(name)
        for i in range(dbver + 1, version + 1):
            module = 'db%i' % i
            try:
                upgrades = __import__(pkg, globals(), locals(), [module])
            except ImportError:
                raise TracError(_("No upgrade package %(pkg)s", pkg=pkg))
            try:
                script = getattr(upgrades, module)
            except AttributeError:
                raise TracError(_("No upgrade module %(module)s.py",
                                  module=module))
            with self.env.db_transaction as db:
                cursor = db.cursor()
                script.do_upgrade(self.env, i, cursor)
                self.set_database_version(i, name)

    def shutdown(self, tid=None):
        """Shut down pooled connections.

        :param tid: passed on to the pool's `shutdown`.  When it is not
                    given (falsy), the pool itself is discarded as well, so
                    the next `get_connection` creates a fresh one.
        """
        if self._cnx_pool:
            self._cnx_pool.shutdown(tid)
            if not tid:
                self._cnx_pool = None

    def backup(self, dest=None):
        """Save a backup of the database.

        :param dest: base filename to write to.  When omitted, a name of
                     the form `<scheme>.<version>.<timestamp>.bak` is used,
                     inside the `[trac] backup_dir` directory (a relative
                     `backup_dir` is taken from the environment directory).

        Returns the file actually written.
        """
        connector, args = self.get_connector()
        if not dest:
            backup_dir = self.backup_dir
            if not os.path.isabs(backup_dir):
                backup_dir = os.path.join(self.env.path, backup_dir)
            db_str = self.config.get('trac', 'database')
            db_name, db_path = db_str.split(":", 1)
            dest_name = '%s.%i.%d.bak' % (db_name, self.env.database_version,
                                          int(time.time()))
            dest = os.path.join(backup_dir, dest_name)
        else:
            backup_dir = os.path.dirname(dest)
        # A bare filename has no directory part: os.makedirs('') would
        # raise, and the current directory exists anyway.
        if backup_dir and not os.path.exists(backup_dir):
            os.makedirs(backup_dir)
        return connector.backup(dest)

    def get_connector(self):
        """Return a `(connector, args)` tuple for the configured database.

        The highest-priority connector supporting the URI scheme is chosen;
        on a tie, the first one in `self.connectors` wins.  A relative
        SQLite path is resolved against the environment directory.

        :raises TracError: if no connector supports the scheme, or the
                           selected one has a negative priority (its
                           `error` attribute is then used as the message).
        """
        scheme, args = parse_connection_uri(self.connection_uri)
        candidates = [
            (priority, connector)
            for connector in self.connectors
            for scheme_, priority in connector.get_supported_schemes()
            if scheme_ == scheme
        ]
        if not candidates:
            raise TracError(_('Unsupported database type "%(scheme)s"',
                              scheme=scheme))
        # Compare on the priority only: comparing connector objects on a tie
        # is arbitrary (Python 2) or a TypeError (Python 3).
        priority, connector = max(candidates, key=lambda c: c[0])
        if priority < 0:
            raise TracError(connector.error)

        if scheme == 'sqlite':
            if args['path'] == ':memory:':
                # SQLite in-memory database: keep the ':memory:' path as-is.
                pass
            elif not os.path.isabs(args['path']):
                # Special case for SQLite to support a path relative to the
                # environment directory
                args['path'] = os.path.join(self.env.path,
                                            args['path'].lstrip('/'))

        if self.debug_sql:
            args['log'] = self.log
        return connector, args

    _get_connector = get_connector  # For 0.11 compatibility
示例#11
0
 def get_connection(self):
     """Hand out a database connection from the shared pool.

     On first use a `ConnectionPool` is built from `self._get_connector()`
     (with 5 as its first argument, presumably the pool size) and cached
     on `self._cnx_pool`.  A `timeout` of 0 is passed to the pool as None.
     """
     pool = self._cnx_pool
     if not pool:
         connector, args = self._get_connector()
         pool = self._cnx_pool = ConnectionPool(5, connector, **args)
     timeout = self.timeout or None
     return pool.get_cnx(timeout)
示例#12
0
class DatabaseManager(Component):
    """Component managing the `IDatabaseConnector` implementations and the
    pool of connections to the environment's database."""

    connectors = ExtensionPoint(IDatabaseConnector)

    connection_uri = Option('trac', 'database', 'sqlite:db/trac.db',
        """Database connection
        [wiki:TracEnvironment#DatabaseConnectionStrings string] for this
        project""")

    backup_dir = Option('trac', 'backup_dir', 'db',
        """Database backup location""")

    timeout = IntOption('trac', 'timeout', '20',
        """Timeout value for database connection, in seconds.
        Use '0' to specify ''no timeout''. ''(Since 0.11)''""")

    debug_sql = BoolOption('trac', 'debug_sql', False,
        """Show the SQL queries in the Trac log, at DEBUG level.
        ''(Since 0.11.5)''""")

    def __init__(self):
        # Built lazily by `get_connection`; discarded by `shutdown()`.
        self._cnx_pool = None

    def init_db(self):
        """Initialise the database through the configured connector.

        Calls the connector's `init_db` with the parsed connection
        arguments.
        """
        connector, args = self._get_connector()
        connector.init_db(**args)

    def get_connection(self):
        """Return a connection from the shared connection pool.

        The pool is created on first use.  A `timeout` of 0 is passed to
        the pool as `None`, i.e. no timeout.
        """
        if not self._cnx_pool:
            connector, args = self._get_connector()
            self._cnx_pool = ConnectionPool(5, connector, **args)
        return self._cnx_pool.get_cnx(self.timeout or None)

    def shutdown(self, tid=None):
        """Shut down pooled connections.

        @param tid: passed on to the pool's `shutdown`.  When it is not
                    given (falsy), the pool itself is discarded as well, so
                    the next `get_connection` creates a fresh one.
        """
        if self._cnx_pool:
            self._cnx_pool.shutdown(tid)
            if not tid:
                self._cnx_pool = None

    def backup(self, dest=None):
        """Save a backup of the database.

        @param dest: base filename to write to.  When omitted, a name of
                     the form `<scheme>.<version>.<timestamp>.bak` is used,
                     inside the `[trac] backup_dir` directory (a relative
                     `backup_dir` is taken from the environment directory).
        Returns the file actually written.
        """
        connector, args = self._get_connector()
        if not dest:
            backup_dir = self.backup_dir
            # os.path.isabs is portable and, unlike indexing [0], does not
            # fail on an empty `backup_dir` setting.
            if not os.path.isabs(backup_dir):
                backup_dir = os.path.join(self.env.path, backup_dir)
            db_str = self.config.get('trac', 'database')
            db_name, db_path = db_str.split(":", 1)
            dest_name = '%s.%i.%d.bak' % (db_name, self.env.get_version(),
                                          int(time.time()))
            dest = os.path.join(backup_dir, dest_name)
        else:
            backup_dir = os.path.dirname(dest)
        # A bare filename has no directory part: os.makedirs('') would
        # raise, and the current directory exists anyway.
        if backup_dir and not os.path.exists(backup_dir):
            os.makedirs(backup_dir)
        return connector.backup(dest)

    def _get_connector(self):  # FIXME: Make it public?
        """Return a `(connector, args)` tuple for the configured database.

        The highest-priority connector supporting the URI scheme is chosen;
        on a tie, the first one in `self.connectors` wins.

        @raise TracError: if no connector supports the scheme, or the
                          selected one has a negative priority.
        """
        scheme, args = _parse_db_str(self.connection_uri)
        candidates = [
            (priority, connector)
            for connector in self.connectors
            for scheme_, priority in connector.get_supported_schemes()
            if scheme_ == scheme
        ]
        if not candidates:
            raise TracError('Unsupported database type "%s"' % scheme)
        # Compare on the priority only: comparing connector objects on a tie
        # is arbitrary (Python 2) or a TypeError (Python 3).
        priority, connector = max(candidates, key=lambda c: c[0])
        if priority < 0:
            raise TracError(connector.error)

        if scheme == 'sqlite':
            # Special case for SQLite to support a path relative to the
            # environment directory
            if args['path'] != ':memory:' and \
                   not os.path.isabs(args['path']):
                args['path'] = os.path.join(self.env.path,
                                            args['path'].lstrip('/'))

        if self.debug_sql:
            args['log'] = self.log
        return connector, args