Esempio n. 1
0
    def test_select_limit_offset(self):
        """Render LIMIT/OFFSET with the default flavor and with max_limit=-1."""
        select = self.table.select(limit=50, offset=10)

        def check(expected_sql):
            # Limit and offset are expected inline, never as parameters.
            self.assertEqual(str(select), expected_sql)
            self.assertEqual(select.params, ())

        check('SELECT * FROM "t" AS "a" LIMIT 50 OFFSET 10')

        select.limit = None
        check('SELECT * FROM "t" AS "a" OFFSET 10')

        select.offset = 0
        check('SELECT * FROM "t" AS "a"')

        # With a max_limit flavor an OFFSET without LIMIT is expected to be
        # rendered with the flavor's max_limit value as LIMIT.
        Flavor.set(Flavor(max_limit=-1))
        try:
            for offset, expected_sql in (
                    (None, 'SELECT * FROM "t" AS "a"'),
                    (0, 'SELECT * FROM "t" AS "a"'),
                    (10, 'SELECT * FROM "t" AS "a" LIMIT -1 OFFSET 10'),
                    ):
                select.offset = offset
                check(expected_sql)
        finally:
            # Restore the default flavor for the other tests.
            Flavor.set(Flavor())
Esempio n. 2
0
 def test_no_as(self):
     """Check that the no_as flavor renders the table alias without AS."""
     select = self.table.select(self.table.c)
     try:
         flavor_without_as = Flavor(no_as=True)
         Flavor.set(flavor_without_as)
         rendered = str(select)
         self.assertEqual(rendered, 'SELECT "a"."c" FROM "t" "a"')
         self.assertEqual(select.params, ())
     finally:
         # Restore the default flavor for the other tests.
         Flavor.set(Flavor())
Esempio n. 3
0
 def test_filter(self):
     """Check the FILTER clause of an aggregate under the filter_ flavor."""
     Flavor.set(Flavor(filter_=True))
     try:
         operand = self.table.a + 1
         positive_only = self.table.a > 0
         average = Avg(operand, filter_=positive_only)
         self.assertEqual(
             str(average), 'AVG(("a" + %s)) FILTER (WHERE ("a" > %s))')
         self.assertEqual(average.params, (1, 0))
     finally:
         # Restore the default flavor for the other tests.
         Flavor.set(Flavor())
Esempio n. 4
0
    def test_at_time_zone_mapping(self):
        """Check that AtTimeZone honours the flavor's function_mapping."""
        class MyAtTimeZone(Function):
            _function = 'MY_TIMEZONE'

        expression = AtTimeZone(self.table.c1, 'UTC')
        mapping = {AtTimeZone: MyAtTimeZone}
        Flavor.set(Flavor(function_mapping=mapping))
        try:
            rendered = str(expression)
            self.assertEqual(rendered, 'MY_TIMEZONE("c1", %s)')
            self.assertEqual(expression.params, ('UTC', ))
        finally:
            # Restore the default flavor for the other tests.
            Flavor.set(Flavor())
Esempio n. 5
0
    def test_mapping(self):
        """Check that Abs is replaced by the class given in function_mapping."""
        class MyAbs(Function):
            _function = 'MY_ABS'
            params = ('test', )

        absolute = Abs(self.table.c1)
        Flavor.set(Flavor(function_mapping={Abs: MyAbs}))
        try:
            rendered = str(absolute)
            self.assertEqual(rendered, 'MY_ABS("c1")')
            # The params are expected to come from the mapped class.
            self.assertEqual(absolute.params, ('test', ))
        finally:
            # Restore the default flavor for the other tests.
            Flavor.set(Flavor())
 def test_no_boolean(self):
     """Check boolean literals with and without the no_boolean flavor."""
     yes, no = Literal(True), Literal(False)

     # Default flavor: booleans are passed as query parameters.
     for literal, value in ((yes, True), (no, False)):
         self.assertEqual(str(literal), '%s')
         self.assertEqual(literal.params, (value, ))

     try:
         Flavor.set(Flavor(no_boolean=True))
         # no_boolean flavor: booleans become parameterless comparisons.
         for literal, expected_sql in ((yes, '(1 = 1)'), (no, '(1 != 1)')):
             self.assertEqual(str(literal), expected_sql)
         for literal in (yes, no):
             self.assertEqual(literal.params, ())
     finally:
         # Restore the default flavor for the other tests.
         Flavor.set(Flavor())
Esempio n. 7
0
    def test_mapping(self):
        """Check function_mapping for each kind of function class."""
        class MyAbs(Function):
            _function = 'MY_ABS'
            params = ('test', )

        class MyOverlay(FunctionKeyword):
            _function = 'MY_OVERLAY'
            _keywords = ('', 'PLACING', 'FROM', 'FOR')

        class MyCurrentTime(FunctionNotCallable):
            _function = 'MY_CURRENT_TIME'

        class MyTrim(Trim):
            _function = 'MY_TRIM'

        # The expressions are built before the custom flavor is installed.
        absolute = Abs(self.table.c1)
        overlay = Overlay(self.table.c1, 'test', 2)
        now = CurrentTime()
        trimmed = Trim(' test ')

        mapping = {
            Abs: MyAbs,
            Overlay: MyOverlay,
            CurrentTime: MyCurrentTime,
            Trim: MyTrim,
        }
        Flavor.set(Flavor(function_mapping=mapping))
        try:
            expectations = (
                (absolute, 'MY_ABS("c1")', ('test', )),
                (overlay, 'MY_OVERLAY("c1" PLACING %s FROM %s)', ('test', 2)),
                (now, 'MY_CURRENT_TIME', ()),
                (trimmed, 'MY_TRIM(BOTH %s FROM %s)', (' ', ' test ')),
            )
            for expression, expected_sql, expected_params in expectations:
                self.assertEqual(str(expression), expected_sql)
                self.assertEqual(expression.params, expected_params)
        finally:
            # Restore the default flavor for the other tests.
            Flavor.set(Flavor())
Esempio n. 8
0
    def test_mod_paramstyle(self):
        """Check how the modulo operator is rendered for each paramstyle."""
        # Under 'format' the percent sign is expected doubled; under 'qmark'
        # it is expected as-is.
        for paramstyle, expected_sql in (
                ('format', '("c1" %% "c2")'),
                ('qmark', '("c1" % "c2")'),
                ):
            Flavor.set(Flavor(paramstyle=paramstyle))
            try:
                modulo = Mod(self.table.c1, self.table.c2)
                self.assertEqual(str(modulo), expected_sql)
                self.assertEqual(modulo.params, ())
            finally:
                # Restore the default flavor before the next case.
                Flavor.set(Flavor())
Esempio n. 9
0
    def test_not_ilike(self):
        """Check NOT ILIKE rendering with and without native ILIKE support."""
        Flavor.set(Flavor(ilike=True))
        try:
            candidates = (
                NotILike(self.table.c1, 'foo'),
                ~self.table.c1.ilike('foo'),
            )
            for expression in candidates:
                self.assertEqual(str(expression), '("c1" NOT ILIKE %s)')
                self.assertEqual(expression.params, ('foo',))
        finally:
            Flavor.set(Flavor())

        # Without ILIKE support the expected fallback uppercases both sides.
        Flavor.set(Flavor(ilike=False))
        try:
            expression = NotILike(self.table.c1, 'foo')
            self.assertEqual(
                str(expression), '(UPPER("c1") NOT LIKE UPPER(%s))')
            self.assertEqual(expression.params, ('foo',))
        finally:
            Flavor.set(Flavor())
Esempio n. 10
0
    def test_select_offset_fetch(self):
        """Check OFFSET ... FETCH rendering under the 'fetch' limitstyle."""
        def check(select, expected_sql):
            # Limit and offset are expected inline, never as parameters.
            self.assertEqual(str(select), expected_sql)
            self.assertEqual(select.params, ())

        try:
            Flavor.set(Flavor(limitstyle='fetch'))
            select = self.table.select(limit=50, offset=10)
            check(
                select,
                'SELECT * FROM "t" AS "a" '
                'OFFSET (10) ROWS FETCH FIRST (50) ROWS ONLY')

            select.limit = None
            check(select, 'SELECT * FROM "t" AS "a" OFFSET (10) ROWS')

            select.offset = 0
            check(select, 'SELECT * FROM "t" AS "a"')
        finally:
            # Restore the default flavor for the other tests.
            Flavor.set(Flavor())
Esempio n. 11
0
    def test_no_null_ordering(self):
        """Check the CASE-based rendering used when null_ordering is off."""
        def check(expression, expected_sql, expected_params):
            self.assertEqual(str(expression), expected_sql)
            self.assertEqual(expression.params, expected_params)

        try:
            Flavor.set(Flavor(null_ordering=False))

            check(
                NullsFirst(self.column),
                'CASE WHEN ("c" IS NULL) THEN %s ELSE %s END ASC, "c"',
                (0, 1))

            check(
                NullsFirst(Desc(self.column)),
                'CASE WHEN ("c" IS NULL) THEN %s ELSE %s END ASC, "c" DESC',
                (0, 1))

            check(
                NullsLast(Literal(2)),
                'CASE WHEN (%s IS NULL) THEN %s ELSE %s END ASC, %s',
                (2, 1, 0, 2))
        finally:
            # Restore the default flavor for the other tests.
            Flavor.set(Flavor())
Esempio n. 12
0
    def test_ilike(self):
        """Check ILIKE rendering with and without native ILIKE support."""
        Flavor.set(Flavor(ilike=True))
        try:
            candidates = (
                ILike(self.table.c1, 'foo'),
                self.table.c1.ilike('foo'),
                ~NotILike(self.table.c1, 'foo'),
            )
            for expression in candidates:
                self.assertEqual(str(expression), '("c1" ILIKE %s)')
                self.assertEqual(expression.params, ('foo', ))
        finally:
            Flavor.set(Flavor())

        # Without ILIKE support a plain LIKE is expected here.
        Flavor.set(Flavor(ilike=False))
        try:
            expression = ILike(self.table.c1, 'foo')
            self.assertEqual(str(expression), '("c1" LIKE %s)')
            self.assertEqual(expression.params, ('foo', ))
        finally:
            Flavor.set(Flavor())
Esempio n. 13
0
class Database(DatabaseInterface):
    """PostgreSQL implementation of ``DatabaseInterface``.

    One instance is kept per database name in ``_databases``; each instance
    owns a ``ThreadedConnectionPool`` created when the instance is first
    built.
    """

    _lock = RLock()  # taken while creating or closing an instance
    _databases = {}  # database name -> Database instance
    _connpool = None
    _list_cache = None  # result of the last list() call
    _list_cache_timestamp = None  # time.time() of the last list() call
    _version_cache = {}  # database name -> (major, minor, patch)
    _search_path = None
    _current_user = None
    _has_returning = None
    flavor = Flavor(ilike=True)

    TYPES_MAPPING = {
        'INTEGER': SQLType('INT4', 'INT4'),
        'BIGINT': SQLType('INT8', 'INT8'),
        'FLOAT': SQLType('FLOAT8', 'FLOAT8'),
        'BLOB': SQLType('BYTEA', 'BYTEA'),
        'DATETIME': SQLType('TIMESTAMP', 'TIMESTAMP(0)'),
        'TIMESTAMP': SQLType('TIMESTAMP', 'TIMESTAMP(6)'),
    }

    def __new__(cls, name='template1'):
        """Return the shared instance for ``name``, creating it if needed.

        A new instance gets a connection pool sized from the ``minconn`` and
        ``maxconn`` options of the ``database`` configuration section.
        """
        with cls._lock:
            if name in cls._databases:
                return cls._databases[name]
            inst = DatabaseInterface.__new__(cls, name=name)

            logger.info('connect to "%s"', name)
            minconn = config.getint('database', 'minconn', default=1)
            maxconn = config.getint('database', 'maxconn', default=64)
            inst._connpool = ThreadedConnectionPool(minconn,
                                                    maxconn,
                                                    cls.dsn(name),
                                                    cursor_factory=PerfCursor)

            cls._databases[name] = inst
            return inst

    @classmethod
    def dsn(cls, name):
        """Build the connection string for database ``name``.

        Host, port, user and password come from the configured database URI;
        parts missing from the URI are left empty.
        """
        uri = parse_uri(config.get('database', 'uri'))
        host = "host=%s" % uri.hostname if uri.hostname else ''
        port = "port=%s" % uri.port if uri.port else ''
        name = "dbname=%s" % name
        user = "user=%s" % uri.username if uri.username else ''
        password = ("password=%s" % urllib.unquote_plus(uri.password)
                    if uri.password else '')
        return '%s %s %s %s %s' % (host, port, name, user, password)

    def connect(self):
        """Return ``self``; the pool is already created in ``__new__``."""
        return self

    def get_connection(self, autocommit=False, readonly=False):
        """Take a connection from the pool and configure it.

        On ``PoolError`` the call sleeps one second and retries, up to the
        ``retry`` option of the ``database`` section, unless the pool is
        closed; the last ``PoolError`` is re-raised.

        :param autocommit: use autocommit isolation instead of repeatable
            read
        :param readonly: start the transaction as read only
        """
        for count in range(config.getint('database', 'retry'), -1, -1):
            try:
                conn = self._connpool.getconn()
                break
            except PoolError:
                if count and not self._connpool.closed:
                    logger.info('waiting a connection')
                    time.sleep(1)
                    continue
                raise
        if autocommit:
            conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
        else:
            conn.set_isolation_level(ISOLATION_LEVEL_REPEATABLE_READ)
        if readonly:
            cursor = conn.cursor()
            cursor.execute('SET TRANSACTION READ ONLY')
        conn.cursor_factory = PerfCursor
        return conn

    def put_connection(self, connection, close=False):
        """Give ``connection`` back to the pool, optionally closing it."""
        self._connpool.putconn(connection, close=close)

    def close(self):
        """Close every pooled connection and forget this instance."""
        with self._lock:
            self._connpool.closeall()
            self._databases.pop(self.name)

    @classmethod
    def create(cls, connection, database_name, template='template0'):
        """Create ``database_name`` from ``template`` and reset the list cache.

        The names are concatenated into the statement as quoted identifiers,
        so they must come from a trusted source.
        """
        cursor = connection.cursor()
        cursor.execute('CREATE DATABASE "' + database_name + '" '
                       'TEMPLATE "' + template + '" ENCODING \'unicode\'')
        connection.commit()
        cls._list_cache = None

    def drop(self, connection, database_name):
        """Drop ``database_name`` and reset the list cache."""
        cursor = connection.cursor()
        cursor.execute('DROP DATABASE "' + database_name + '"')
        self.__class__._list_cache = None

    def get_version(self, connection):
        """Return the server version as a ``(major, minor, patch)`` tuple.

        The value is derived from ``server_version_num`` and cached per
        database name.
        """
        if self.name not in self._version_cache:
            cursor = connection.cursor()
            cursor.execute('SHOW server_version_num')
            version, = cursor.fetchone()
            major, rest = divmod(int(version), 10000)
            minor, patch = divmod(rest, 100)
            self._version_cache[self.name] = (major, minor, patch)
        return self._version_cache[self.name]

    def list(self):
        """Return the names of the non-template databases passing ``_test``.

        A non-empty result is cached on the class and reused while it is
        younger than the ``timeout`` option of the ``session`` section.
        Databases that cannot be connected to or tested are skipped.
        """
        now = time.time()
        timeout = config.getint('session', 'timeout')
        res = self.__class__._list_cache
        if res and abs(self.__class__._list_cache_timestamp - now) < timeout:
            return res

        connection = self.get_connection()
        try:
            cursor = connection.cursor()
            cursor.execute('SELECT datname FROM pg_database '
                           'WHERE datistemplate = false ORDER BY datname')
            res = []
            for db_name, in cursor:
                try:
                    conn = connect(self.dsn(db_name))
                    try:
                        # The connection context manager only ends the
                        # transaction, so the connection must be closed
                        # explicitly to avoid leaking one per database.
                        with conn:
                            if self._test(conn):
                                res.append(db_name)
                    finally:
                        conn.close()
                except Exception:
                    continue
        finally:
            self.put_connection(connection)

        self.__class__._list_cache = res
        self.__class__._list_cache_timestamp = now
        return res

    def init(self):
        """Run ``init.sql`` and register the ``ir`` and ``res`` modules.

        The statements of ``init.sql`` (located next to this file) are split
        on ``;`` and executed one by one. Each module is then inserted in
        ``ir_module`` with state ``'to activate'`` together with its
        dependencies, and everything is committed.
        """
        from trytond.modules import get_module_info

        connection = self.get_connection()
        try:
            cursor = connection.cursor()
            sql_file = os.path.join(os.path.dirname(__file__), 'init.sql')
            with open(sql_file) as fp:
                for line in fp.read().split(';'):
                    if (len(line) > 0) and (not line.isspace()):
                        cursor.execute(line)

            # Both base modules are always scheduled for activation.
            state = 'to activate'
            for module in ('ir', 'res'):
                info = get_module_info(module)
                cursor.execute('SELECT NEXTVAL(\'ir_module_id_seq\')')
                module_id = cursor.fetchone()[0]
                cursor.execute(
                    'INSERT INTO ir_module '
                    '(id, create_uid, create_date, name, state) '
                    'VALUES (%s, %s, now(), %s, %s)',
                    (module_id, 0, module, state))
                for dependency in info.get('depends', []):
                    cursor.execute(
                        'INSERT INTO ir_module_dependency '
                        '(create_uid, create_date, module, name) '
                        'VALUES (%s, now(), %s, %s)',
                        (0, module_id, dependency))

            connection.commit()
        finally:
            self.put_connection(connection)

    def test(self):
        """Return whether this database passes ``_test``."""
        connection = self.get_connection()
        try:
            return self._test(connection)
        finally:
            # Return the connection to the pool even when the check fails.
            self.put_connection(connection)

    @classmethod
    def _test(cls, connection):
        """Return True if at least one of the core tables exists."""
        cursor = connection.cursor()
        cursor.execute(
            'SELECT 1 FROM information_schema.tables '
            'WHERE table_name IN %s',
            (('ir_model', 'ir_model_field', 'ir_ui_view', 'ir_ui_menu',
              'res_user', 'res_group', 'ir_module', 'ir_module_dependency',
              'ir_translation', 'ir_lang'), ))
        return len(cursor.fetchall()) != 0

    def nextid(self, connection, table):
        """Return the next value of the ``<table>_id_seq`` sequence."""
        cursor = connection.cursor()
        cursor.execute("SELECT NEXTVAL('" + table + "_id_seq')")
        return cursor.fetchone()[0]

    def setnextid(self, connection, table, value):
        """Set the ``<table>_id_seq`` sequence to the integer ``value``."""
        cursor = connection.cursor()
        cursor.execute("SELECT SETVAL('" + table + "_id_seq', %d)" % value)

    def currid(self, connection, table):
        """Return ``last_value`` of the ``<table>_id_seq`` sequence."""
        cursor = connection.cursor()
        cursor.execute('SELECT last_value FROM "' + table + '_id_seq"')
        return cursor.fetchone()[0]

    def lock(self, connection, table):
        """Lock ``table`` in exclusive mode without waiting."""
        cursor = connection.cursor()
        cursor.execute('LOCK "%s" IN EXCLUSIVE MODE NOWAIT' % table)

    def has_constraint(self):
        return True

    def has_multirow_insert(self):
        return True

    def get_table_schema(self, connection, table_name):
        """Return the first schema of the search path holding ``table_name``.

        Returns None when no schema of the search path contains the table.
        """
        cursor = connection.cursor()
        for schema in self.search_path:
            cursor.execute(
                'SELECT 1 '
                'FROM information_schema.tables '
                'WHERE table_name = %s AND table_schema = %s',
                (table_name, schema))
            if cursor.rowcount:
                return schema

    @property
    def current_user(self):
        """Name returned by ``SELECT current_user``, cached on the instance."""
        if self._current_user is None:
            connection = self.get_connection()
            try:
                cursor = connection.cursor()
                cursor.execute('SELECT current_user')
                self._current_user = cursor.fetchone()[0]
            finally:
                self.put_connection(connection)
        return self._current_user

    @property
    def search_path(self):
        """List of schemas from ``SHOW search_path``, cached on the instance.

        Each entry is stripped, has its special values replaced (``user`` by
        the current user) and is unquoted.
        """
        if self._search_path is None:
            connection = self.get_connection()
            try:
                cursor = connection.cursor()
                cursor.execute('SHOW search_path')
                path, = cursor.fetchone()
                special_values = {
                    'user': self.current_user,
                }
                self._search_path = [
                    unescape_quote(
                        replace_special_values(p.strip(), **special_values))
                    for p in path.split(',')
                ]
            finally:
                self.put_connection(connection)
        return self._search_path

    def has_returning(self):
        """Return whether the server supports RETURNING (cached)."""
        if self._has_returning is None:
            connection = self.get_connection()
            try:
                # RETURNING clause is available since PostgreSQL 8.2
                self._has_returning = self.get_version(connection) >= (8, 2)
            finally:
                self.put_connection(connection)
        return self._has_returning

    def has_select_for(self):
        return True

    def has_window_functions(self):
        return True

    @classmethod
    def has_sequence(cls):
        return True

    def sql_type(self, type_):
        """Return the ``SQLType`` for ``type_``.

        Types listed in ``TYPES_MAPPING`` are translated; ``VARCHAR...``
        types keep their definition with base ``VARCHAR``; anything else is
        used unchanged as both base and definition.
        """
        if type_ in self.TYPES_MAPPING:
            return self.TYPES_MAPPING[type_]
        if type_.startswith('VARCHAR'):
            return SQLType('VARCHAR', type_)
        return SQLType(type_, type_)

    def sql_format(self, type_, value):
        """Wrap non-None ``BLOB`` values in ``Binary``; return others as is."""
        if type_ == 'BLOB':
            if value is not None:
                return Binary(value)
        return value

    def sequence_exist(self, connection, name):
        """Return whether sequence ``name`` exists in the search path."""
        cursor = connection.cursor()
        for schema in self.search_path:
            cursor.execute(
                'SELECT 1 '
                'FROM information_schema.sequences '
                'WHERE sequence_name = %s AND sequence_schema = %s',
                (name, schema))
            if cursor.rowcount:
                return True
        return False

    def sequence_create(self,
                        connection,
                        name,
                        number_increment=1,
                        start_value=1):
        """Create sequence ``name`` with the given increment and start."""
        cursor = connection.cursor()

        # The name is interpolated as an identifier; the numbers are passed
        # as query parameters using the flavor's placeholder.
        param = self.flavor.param
        cursor.execute(
            'CREATE SEQUENCE "%s" '
            'INCREMENT BY %s '
            'START WITH %s' % (name, param, param),
            (number_increment, start_value))

    def sequence_update(self,
                        connection,
                        name,
                        number_increment=1,
                        start_value=1):
        """Change the increment of sequence ``name`` and restart it."""
        cursor = connection.cursor()
        param = self.flavor.param
        cursor.execute(
            'ALTER SEQUENCE "%s" '
            'INCREMENT BY %s '
            'RESTART WITH %s' % (name, param, param),
            (number_increment, start_value))

    def sequence_rename(self, connection, old_name, new_name):
        """Rename a sequence if the old name exists and the new one is free."""
        cursor = connection.cursor()
        if (self.sequence_exist(connection, old_name)
                and not self.sequence_exist(connection, new_name)):
            cursor.execute('ALTER TABLE "%s" RENAME TO "%s"' %
                           (old_name, new_name))

    def sequence_delete(self, connection, name):
        """Drop sequence ``name``."""
        cursor = connection.cursor()
        cursor.execute('DROP SEQUENCE "%s"' % name)

    def sequence_next_number(self, connection, name):
        """Return the number the sequence would yield next, without using it.

        From server version 10 the increment is read from ``pg_sequences``;
        on older servers it is read from the sequence relation itself.
        """
        cursor = connection.cursor()
        version = self.get_version(connection)
        if version >= (10, 0):
            cursor.execute(
                'SELECT increment_by '
                'FROM pg_sequences '
                'WHERE sequencename=%s ' % self.flavor.param, (name, ))
            increment, = cursor.fetchone()
            cursor.execute(
                'SELECT CASE WHEN NOT is_called THEN last_value '
                'ELSE last_value + %s '
                'END '
                'FROM "%s"' % (self.flavor.param, name), (increment, ))
        else:
            cursor.execute('SELECT CASE WHEN NOT is_called THEN last_value '
                           'ELSE last_value + increment_by '
                           'END '
                           'FROM "%s"' % name)
        return cursor.fetchone()[0]
Esempio n. 14
0
class Database(DatabaseInterface):
    """PostgreSQL implementation of ``DatabaseInterface``.

    One instance is kept per database name in ``_databases``; its connection
    pool is created lazily by ``connect``.
    """

    _databases = {}  # database name -> Database instance
    _connpool = None
    _list_cache = None  # result of the last list() call
    _list_cache_timestamp = None  # time.time() of the last list() call
    _version_cache = {}  # database name -> version tuple
    flavor = Flavor(ilike=True)

    def __new__(cls, database_name='template1'):
        """Return the registered instance for ``database_name`` or a new one."""
        if database_name in cls._databases:
            return cls._databases[database_name]
        return DatabaseInterface.__new__(cls, database_name=database_name)

    def __init__(self, database_name='template1'):
        super(Database, self).__init__(database_name=database_name)
        # Register only the first instance created for a given name.
        self._databases.setdefault(database_name, self)

    def connect(self):
        """Create the connection pool if needed and return ``self``.

        Connection settings come from the configured database URI, whose
        scheme must be ``postgresql``; the pool size comes from the
        ``minconn`` and ``maxconn`` options of the ``database`` section.
        """
        if self._connpool is not None:
            return self
        logger.info('connect to "%s"', self.database_name)
        uri = parse_uri(config.get('database', 'uri'))
        assert uri.scheme == 'postgresql'
        host = "host=%s" % uri.hostname if uri.hostname else ''
        port = "port=%s" % uri.port if uri.port else ''
        name = "dbname=%s" % self.database_name
        user = "user=%s" % uri.username if uri.username else ''
        password = ("password=%s" % urllib.unquote_plus(uri.password)
                    if uri.password else '')
        minconn = config.getint('database', 'minconn', default=1)
        maxconn = config.getint('database', 'maxconn', default=64)
        dsn = '%s %s %s %s %s' % (host, port, name, user, password)
        self._connpool = ThreadedConnectionPool(minconn, maxconn, dsn)
        return self

    def cursor(self, autocommit=False, readonly=False):
        """Return a ``Cursor`` on a pooled connection, connecting if needed.

        :param autocommit: use autocommit isolation instead of repeatable
            read
        :param readonly: start the transaction as read only
        """
        if self._connpool is None:
            self.connect()
        conn = self._connpool.getconn()
        if autocommit:
            conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
        else:
            conn.set_isolation_level(ISOLATION_LEVEL_REPEATABLE_READ)
        cursor = Cursor(self._connpool, conn, self)
        if readonly:
            cursor.execute('SET TRANSACTION READ ONLY')
        return cursor

    def close(self):
        """Close every pooled connection and drop the pool, if any."""
        if self._connpool is None:
            return
        self._connpool.closeall()
        self._connpool = None

    @classmethod
    def create(cls, cursor, database_name):
        """Create ``database_name`` from template0 and reset the list cache.

        The name is concatenated into the statement as a quoted identifier,
        so it must come from a trusted source.
        """
        cursor.execute('CREATE DATABASE "' + database_name + '" '
                       'TEMPLATE template0 ENCODING \'unicode\'')
        cls._list_cache = None

    @classmethod
    def drop(cls, cursor, database_name):
        """Drop ``database_name`` and reset the list cache."""
        cursor.execute('DROP DATABASE "' + database_name + '"')
        cls._list_cache = None

    def get_version(self, cursor):
        """Return the server version as a tuple of ints, cached per database.

        The numbers are the groups matched by ``RE_VERSION`` in the output of
        ``SELECT version()``.
        """
        if self.database_name not in self._version_cache:
            cursor.execute('SELECT version()')
            version, = cursor.fetchone()
            self._version_cache[self.database_name] = tuple(
                map(int,
                    RE_VERSION.search(version).groups()))
        return self._version_cache[self.database_name]

    @staticmethod
    def dump(database_name):
        """Return a custom-format ``pg_dump`` of ``database_name``.

        Raises ``Exception`` when ``pg_dump`` exits with a non-zero status.
        """
        from trytond.tools import exec_command_pipe

        cmd = ['pg_dump', '--format=c', '--no-owner']
        env = {}
        uri = parse_uri(config.get('database', 'uri'))
        if uri.username:
            cmd.append('--username=' + uri.username)
        if uri.hostname:
            cmd.append('--host=' + uri.hostname)
        if uri.port:
            cmd.append('--port=' + str(uri.port))
        if uri.password:
            # if db_password is set in configuration we should pass
            # an environment variable PGPASSWORD to our subprocess
            # see libpg documentation
            env['PGPASSWORD'] = uri.password
        cmd.append(database_name)

        pipe = exec_command_pipe(*tuple(cmd), env=env)
        pipe.stdin.close()
        data = pipe.stdout.read()
        res = pipe.wait()
        if res:
            raise Exception('Couldn\'t dump database!')
        return data

    @staticmethod
    def restore(database_name, data):
        """Create ``database_name`` and load the dump ``data`` into it.

        The database is created, ``pg_restore`` is run on the dump, and the
        result is checked with the cursor's ``test``. Returns True on
        success; raises ``Exception`` when ``pg_restore`` fails or the check
        does not pass.
        """
        from trytond.tools import exec_command_pipe

        database = Database().connect()
        cursor = database.cursor(autocommit=True)
        try:
            database.create(cursor, database_name)
            cursor.commit()
        finally:
            # Release the cursor and pool even if the creation fails.
            cursor.close()
            database.close()

        cmd = ['pg_restore', '--no-owner']
        env = {}
        uri = parse_uri(config.get('database', 'uri'))
        if uri.username:
            cmd.append('--username=' + uri.username)
        if uri.hostname:
            cmd.append('--host=' + uri.hostname)
        if uri.port:
            cmd.append('--port=' + str(uri.port))
        if uri.password:
            env['PGPASSWORD'] = uri.password
        cmd.append('--dbname=' + database_name)
        args2 = tuple(cmd)

        if os.name == "nt":
            # On Windows the dump is written to a file passed as argument
            # instead of being piped through stdin.
            # NOTE(review): os.tmpnam() is insecure and Python 2 only, and
            # the file is never removed; consider the tempfile module.
            tmpfile = (os.environ['TMP'] or 'C:\\') + os.tmpnam()
            with open(tmpfile, 'wb') as fp:
                fp.write(data)
            args2 = list(args2)
            args2.append(' ' + tmpfile)
            args2 = tuple(args2)

        pipe = exec_command_pipe(*args2, env=env)
        if not os.name == "nt":
            pipe.stdin.write(data)
        pipe.stdin.close()
        res = pipe.wait()
        if res:
            raise Exception('Couldn\'t restore database')

        database = Database(database_name).connect()
        cursor = database.cursor()
        try:
            if not cursor.test():
                raise Exception('Couldn\'t restore database!')
        finally:
            cursor.close()
            database.close()
        Database._list_cache = None
        return True

    @staticmethod
    def list(cursor):
        """Return the names of the databases whose cursor ``test`` passes.

        Candidates exclude template0, template1 and postgres and, when a
        user name can be determined (URI, ``PGUSER`` or the posix account),
        are limited to the databases owned by that user. A non-empty result
        is cached and reused while it is younger than the ``timeout`` option
        of the ``session`` section.
        """
        now = time.time()
        timeout = config.getint('session', 'timeout')
        res = Database._list_cache
        if res and abs(Database._list_cache_timestamp - now) < timeout:
            return res
        uri = parse_uri(config.get('database', 'uri'))
        db_user = uri.username or os.environ.get('PGUSER')
        if not db_user and os.name == 'posix':
            db_user = pwd.getpwuid(os.getuid())[0]
        if db_user:
            cursor.execute(
                "SELECT datname "
                "FROM pg_database "
                "WHERE datdba = ("
                "SELECT usesysid "
                "FROM pg_user "
                "WHERE usename=%s) "
                "AND datname not in "
                "('template0', 'template1', 'postgres') "
                "ORDER BY datname", (db_user, ))
        else:
            cursor.execute("SELECT datname "
                           "FROM pg_database "
                           "WHERE datname not in "
                           "('template0', 'template1','postgres') "
                           "ORDER BY datname")
        res = []
        for db_name, in cursor.fetchall():
            db_name = db_name.encode('utf-8')
            try:
                database = Database(db_name).connect()
            except Exception:
                continue
            cursor2 = database.cursor()
            try:
                is_tryton = cursor2.test()
            finally:
                # Always release the probing cursor, even if the test fails.
                cursor2.close(close=True)
            if is_tryton:
                res.append(db_name)
            else:
                database.close()
        Database._list_cache = res
        Database._list_cache_timestamp = now
        return res

    @staticmethod
    def init(cursor):
        """Run ``init.sql`` and register the base modules.

        The statements of ``init.sql`` (located next to this file) are split
        on ``;`` and executed one by one. ``ir`` and ``res`` are inserted in
        ``ir_module`` as ``'to install'`` and ``webdav`` as
        ``'uninstalled'``, each with its dependencies.
        """
        from trytond.modules import get_module_info
        sql_file = os.path.join(os.path.dirname(__file__), 'init.sql')
        with open(sql_file) as fp:
            for line in fp.read().split(';'):
                if (len(line) > 0) and (not line.isspace()):
                    cursor.execute(line)

        for module in ('ir', 'res', 'webdav'):
            state = 'uninstalled'
            if module in ('ir', 'res'):
                state = 'to install'
            info = get_module_info(module)
            cursor.execute('SELECT NEXTVAL(\'ir_module_id_seq\')')
            module_id = cursor.fetchone()[0]
            cursor.execute(
                'INSERT INTO ir_module '
                '(id, create_uid, create_date, name, state) '
                'VALUES (%s, %s, now(), %s, %s)',
                (module_id, 0, module, state))
            for dependency in info.get('depends', []):
                cursor.execute(
                    'INSERT INTO ir_module_dependency '
                    '(create_uid, create_date, module, name) '
                    'VALUES (%s, now(), %s, %s)', (0, module_id, dependency))
Esempio n. 15
0
class Database(DatabaseInterface):
    """PostgreSQL backend sharing one pooled instance per database name."""

    # Guards the instance registry in __new__() and close().
    _lock = RLock()
    # Shared instances, keyed by process id and then by database name.
    _databases = defaultdict(dict)
    _connpool = None
    # Results of list() and the time they were stored, keyed by hostname.
    _list_cache = {}
    _list_cache_timestamp = {}
    # Filled lazily by the search_path and current_user properties.
    _search_path = None
    _current_user = None
    # Filled lazily by has_returning() and get_select_for_skip_locked().
    _has_returning = None
    _has_select_for_skip_locked = None
    # Results of has_unaccent(), keyed by database name.
    _has_unaccent = {}
    flavor = Flavor(ilike=True)

    # Type names that sql_type() translates instead of passing through.
    TYPES_MAPPING = {
        'INTEGER': SQLType('INT4', 'INT4'),
        'BIGINT': SQLType('INT8', 'INT8'),
        'FLOAT': SQLType('FLOAT8', 'FLOAT8'),
        'BLOB': SQLType('BYTEA', 'BYTEA'),
        'DATETIME': SQLType('TIMESTAMP', 'TIMESTAMP(0)'),
        'TIMESTAMP': SQLType('TIMESTAMP', 'TIMESTAMP(6)'),
    }

    def __new__(cls, name=_default_name):
        """Return the shared instance for database *name* in this process.

        Under the class lock, instances other than *name* that have been
        idle for more than ``_timeout`` seconds and have no connection in
        use are closed first.  A missing instance is created together with
        its connection pool; a pool creation failure is logged and
        re-raised.  The instance's ``_last_use`` is refreshed on every call.
        """
        with cls._lock:
            now = datetime.now()
            databases = cls._databases[os.getpid()]
            # Iterate over a copy: close() removes entries from the registry.
            for database in list(databases.values()):
                if ((now - database._last_use).total_seconds() > _timeout
                        and database.name != name
                        and not database._connpool._used):
                    database.close()
            if name in databases:
                inst = databases[name]
            else:
                minconn = 0 if name == _default_name else _minconn
                inst = DatabaseInterface.__new__(cls, name=name)
                try:
                    inst._connpool = ThreadedConnectionPool(
                        minconn,
                        _maxconn,
                        **cls._connection_params(name),
                        cursor_factory=LoggingCursor)
                except Exception:
                    logger.error('connection to "%s" failed',
                                 name,
                                 exc_info=True)
                    raise
                else:
                    # Success is reported once, here, instead of both in
                    # the try body and in this branch.
                    logger.info('connection to "%s" succeeded', name)
                databases[name] = inst
            inst._last_use = datetime.now()
            return inst

    def __init__(self, name=_default_name):
        """Forward *name* to the base class initialiser."""
        super(Database, self).__init__(name)

    @classmethod
    def _connection_params(cls, name):
        """Build the connection keyword arguments for database *name*.

        ``dbname`` is always set; ``user``, ``password``, ``host`` and
        ``port`` are copied from the configured database URI only when the
        URI provides them.  The password is URL-unquoted.
        """
        uri = parse_uri(config.get('database', 'uri'))
        params = {'dbname': name}
        optional = (
            ('user', uri.username),
            ('password', uri.password),
            ('host', uri.hostname),
            ('port', uri.port),
        )
        for key, value in optional:
            if not value:
                continue
            if key == 'password':
                value = urllib.parse.unquote_plus(value)
            params[key] = value
        return params

    def _kill_session_query(self, database_name):
        return 'SELECT pg_terminate_backend(pg_stat_activity.pid) ' \
            'FROM pg_stat_activity WHERE pg_stat_activity.datname = \'%s\'' \
            ' AND pid <> pg_backend_pid();' % database_name

    def connect(self):
        """Return this instance; no connection is opened here."""
        return self

    def get_connection(self, autocommit=False, readonly=False):
        """Take a connection from the pool and configure it.

        While the pool is exhausted (``PoolError``) and still open, the
        request is retried once per second, up to the configured
        'database' / 'retry' count.  The isolation level is autocommit when
        *autocommit* is true and repeatable read otherwise; *readonly*
        issues ``SET TRANSACTION READ ONLY``.

        Raises whatever the pool or the configuration step raises.  If the
        configuration step fails, the connection is returned to the pool
        (and closed) instead of being leaked.
        """
        for count in range(config.getint('database', 'retry'), -1, -1):
            try:
                conn = self._connpool.getconn()
                break
            except PoolError:
                if count and not self._connpool.closed:
                    logger.info('waiting a connection')
                    time.sleep(1)
                    continue
                raise
            except Exception:
                logger.error('connection to "%s" failed',
                             self.name,
                             exc_info=True)
                raise
        try:
            if autocommit:
                conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
            else:
                conn.set_isolation_level(ISOLATION_LEVEL_REPEATABLE_READ)
            if readonly:
                cursor = conn.cursor()
                cursor.execute('SET TRANSACTION READ ONLY')
            conn.cursor_factory = PerfCursor
        except Exception:
            # The caller never receives this connection, so nobody else
            # could give it back to the pool.
            self._connpool.putconn(conn, close=True)
            raise
        return conn

    def put_connection(self, connection, close=False):
        """Give *connection* back to the pool, forwarding the *close* flag."""
        self._connpool.putconn(connection, close=close)

    def close(self):
        """Close all pooled connections and unregister this instance."""
        with self._lock:
            logger.info('disconnection from "%s"', self.name)
            self._connpool.closeall()
            # The registry is keyed by the current pid; tolerate an instance
            # that is not registered under it instead of raising KeyError.
            self._databases[os.getpid()].pop(self.name, None)

    @classmethod
    def create(cls, connection, database_name, template='template0'):
        cursor = connection.cursor()
        cursor.execute('CREATE DATABASE "' + database_name + '" '
                       'TEMPLATE "' + template + '" ENCODING \'unicode\'')
        connection.commit()
        cls._list_cache.clear()

    def drop(self, connection, database_name):
        """Drop *database_name* and reset the class-level list cache.

        The name is emitted as a quoted identifier with embedded double
        quotes doubled.
        """
        cursor = connection.cursor()
        quoted_name = '"' + database_name.replace('"', '""') + '"'
        cursor.execute('DROP DATABASE ' + quoted_name)
        self.__class__._list_cache.clear()

    def get_version(self, connection):
        """Split ``connection.server_version`` into (major, minor, patch).

        The integer is read as ``major * 10000 + minor * 100 + patch``.
        """
        number = int(connection.server_version)
        major = number // 10000
        minor = number % 10000 // 100
        patch = number % 100
        return (major, minor, patch)

    def list(self, hostname=None):
        """Return the names of the databases that pass ``_test``.

        A non-empty result is cached per *hostname* and reused while it is
        younger than the configured 'session' / 'timeout'.  Otherwise every
        non-template database is opened with a dedicated connection and
        tested; databases whose connection or test fails are skipped and the
        failure is logged at debug level.
        """
        now = time.time()
        timeout = config.getint('session', 'timeout')
        res = self.__class__._list_cache.get(hostname)
        timestamp = self.__class__._list_cache_timestamp.get(hostname, now)
        if res and abs(timestamp - now) < timeout:
            return res

        connection = self.get_connection()
        try:
            cursor = connection.cursor()
            cursor.execute('SELECT datname FROM pg_database '
                           'WHERE datistemplate = false ORDER BY datname')
            res = []
            for db_name, in cursor:
                try:
                    conn = connect(**self._connection_params(db_name))
                    # Close explicitly: a connection used as a context
                    # manager only ends its transaction on exit, and closing
                    # it inside the ``with`` makes that exit fail on an
                    # already closed connection.
                    try:
                        if self._test(conn, hostname=hostname):
                            res.append(db_name)
                    finally:
                        conn.close()
                except Exception:
                    logger.debug('Test failed for "%s"',
                                 db_name,
                                 exc_info=True)
                    continue
        finally:
            self.put_connection(connection)

        self.__class__._list_cache[hostname] = res
        self.__class__._list_cache_timestamp[hostname] = now
        return res

    def init(self):
        """Create the base schema and register the 'ir' and 'res' modules.

        Runs every non-blank ';'-separated statement of the ``init.sql``
        file located next to this module, inserts one ``ir_module`` row
        (state 'to activate') per module with its dependencies, and commits.
        The connection is returned to the pool in every case.
        """
        from trytond.modules import get_module_info

        connection = self.get_connection()
        try:
            cursor = connection.cursor()
            sql_file = os.path.join(os.path.dirname(__file__), 'init.sql')
            with open(sql_file) as fp:
                for line in fp.read().split(';'):
                    if line and not line.isspace():
                        cursor.execute(line)

            # Both bootstrap modules are always flagged for activation; the
            # former 'not activated' default could never be selected.
            state = 'to activate'
            for module in ('ir', 'res'):
                info = get_module_info(module)
                cursor.execute('SELECT NEXTVAL(\'ir_module_id_seq\')')
                module_id = cursor.fetchone()[0]
                cursor.execute(
                    'INSERT INTO ir_module '
                    '(id, create_uid, create_date, name, state) '
                    'VALUES (%s, %s, now(), %s, %s)',
                    (module_id, 0, module, state))
                for dependency in info.get('depends', []):
                    cursor.execute(
                        'INSERT INTO ir_module_dependency '
                        '(create_uid, create_date, module, name) '
                        'VALUES (%s, now(), %s, %s)',
                        (0, module_id, dependency))

            connection.commit()
        finally:
            self.put_connection(connection)

    def test(self, hostname=None):
        """Tell whether this database passes ``_test``.

        Returns False (after a debug log) when no connection can be
        obtained; the connection is always returned to the pool.
        """
        try:
            connection = self.get_connection()
        except Exception:
            logger.debug('Test failed for "%s"', self.name, exc_info=True)
            return False
        else:
            try:
                outcome = self._test(connection, hostname=hostname)
            finally:
                self.put_connection(connection)
            return outcome

    @classmethod
    def _test(cls, connection, hostname=None):
        cursor = connection.cursor()
        tables = ('ir_model', 'ir_model_field', 'ir_ui_view', 'ir_ui_menu',
                  'res_user', 'res_group', 'ir_module', 'ir_module_dependency',
                  'ir_translation', 'ir_lang', 'ir_configuration')
        cursor.execute(
            'SELECT table_name FROM information_schema.tables '
            'WHERE table_name IN %s', (tables, ))
        if len(cursor.fetchall()) != len(tables):
            return False
        if hostname:
            try:
                cursor.execute('SELECT hostname FROM ir_configuration')
                hostnames = {h for h, in cursor.fetchall() if h}
                if hostnames and hostname not in hostnames:
                    return False
            except ProgrammingError:
                pass
        return True

    def nextid(self, connection, table):
        """Return the next value of the ``<table>_id_seq`` sequence.

        The sequence name is passed as a query parameter instead of being
        concatenated into the statement.
        """
        cursor = connection.cursor()
        cursor.execute('SELECT NEXTVAL(%s)', (table + '_id_seq',))
        return cursor.fetchone()[0]

    def setnextid(self, connection, table, value):
        """Set the ``<table>_id_seq`` sequence to *value*.

        Sequence name and value are passed as query parameters; *value* is
        converted with ``int()``, matching the former ``%d`` formatting of
        numbers.
        """
        cursor = connection.cursor()
        cursor.execute('SELECT SETVAL(%s, %s)',
                       (table + '_id_seq', int(value)))

    def currid(self, connection, table):
        """Return ``last_value`` of the ``<table>_id_seq`` sequence."""
        cursor = connection.cursor()
        statement = ''.join(('SELECT last_value FROM "', table, '_id_seq"'))
        cursor.execute(statement)
        row = cursor.fetchone()
        return row[0]

    def lock(self, connection, table):
        """Take an exclusive lock on *table* without waiting (NOWAIT)."""
        cursor = connection.cursor()
        statement = 'LOCK "%s" IN EXCLUSIVE MODE NOWAIT' % table
        cursor.execute(statement)

    def lock_id(self, id, timeout=None):
        """Return an advisory lock expression for *id*.

        ``TryAdvisoryLock`` is used when *timeout* is falsy and
        ``AdvisoryLock`` otherwise.
        """
        # NOTE(review): only *id* is forwarded; the timeout value itself is
        # not passed to AdvisoryLock.
        lock_class = AdvisoryLock if timeout else TryAdvisoryLock
        return lock_class(id)

    def has_constraint(self, constraint):
        """Return True whatever *constraint* is."""
        return True

    def has_multirow_insert(self):
        """Return True: this backend reports multi-row INSERT support."""
        return True

    def get_table_schema(self, connection, table_name):
        """Return the first schema of the search path holding *table_name*.

        Schemas are probed in ``self.search_path`` order; None is returned
        when no schema contains the table.
        """
        cursor = connection.cursor()
        query = (
            'SELECT 1 '
            'FROM information_schema.tables '
            'WHERE table_name = %s AND table_schema = %s')
        for candidate in self.search_path:
            cursor.execute(query, (table_name, candidate))
            if cursor.rowcount:
                return candidate
        return None

    @property
    def current_user(self):
        if self._current_user is None:
            connection = self.get_connection()
            try:
                cursor = connection.cursor()
                cursor.execute('SELECT current_user')
                self._current_user = cursor.fetchone()[0]
            finally:
                self.put_connection(connection)
        return self._current_user

    @property
    def search_path(self):
        """List of schemas from ``SHOW search_path``, queried once and kept.

        Each comma-separated entry is stripped, passed through
        ``replace_special_values`` with ``user`` bound to the current user,
        then through ``unescape_quote``.
        """
        if self._search_path is not None:
            return self._search_path
        connection = self.get_connection()
        try:
            cursor = connection.cursor()
            cursor.execute('SHOW search_path')
            path, = cursor.fetchone()
            special_values = {'user': self.current_user}
            schemas = []
            for part in path.split(','):
                resolved = replace_special_values(
                    part.strip(), **special_values)
                schemas.append(unescape_quote(resolved))
            self._search_path = schemas
        finally:
            self.put_connection(connection)
        return self._search_path

    def has_returning(self):
        """Tell whether the server version is at least 8.2.

        The answer is computed on first use and kept on the instance.
        """
        if self._has_returning is not None:
            return self._has_returning
        connection = self.get_connection()
        try:
            # RETURNING clause is available since PostgreSQL 8.2
            version = self.get_version(connection)
            self._has_returning = version >= (8, 2)
        finally:
            self.put_connection(connection)
        return self._has_returning

    def has_select_for(self):
        """Return True: this backend reports SELECT ... FOR support."""
        return True

    def get_select_for_skip_locked(self):
        """Return ``ForSkipLocked`` on servers >= 9.5, ``For`` otherwise.

        The version check is done on first use and kept on the instance.
        """
        if self._has_select_for_skip_locked is None:
            connection = self.get_connection()
            try:
                # SKIP LOCKED clause is available since PostgreSQL 9.5
                version = self.get_version(connection)
                self._has_select_for_skip_locked = version >= (9, 5)
            finally:
                self.put_connection(connection)
        return ForSkipLocked if self._has_select_for_skip_locked else For

    def has_window_functions(self):
        """Return True: this backend reports window function support."""
        return True

    @classmethod
    def has_sequence(cls):
        """Return True: this backend reports sequence support."""
        return True

    def has_unaccent(self):
        """Tell whether ``pg_proc`` lists the function named by ``Unaccent``.

        The answer is cached in the class-level ``_has_unaccent`` mapping
        under this database's name.
        """
        try:
            return self._has_unaccent[self.name]
        except KeyError:
            pass
        connection = self.get_connection()
        try:
            cursor = connection.cursor()
            cursor.execute("SELECT 1 FROM pg_proc WHERE proname=%s",
                           (Unaccent._function, ))
            available = bool(cursor.rowcount)
        finally:
            self.put_connection(connection)
        self._has_unaccent[self.name] = available
        return available

    def sql_type(self, type_):
        """Translate the generic type name *type_* into an ``SQLType``.

        Names listed in ``TYPES_MAPPING`` use the mapped value; any name
        starting with 'VARCHAR' gets base 'VARCHAR'; every other name is
        used unchanged for both fields.
        """
        try:
            return self.TYPES_MAPPING[type_]
        except KeyError:
            pass
        base = 'VARCHAR' if type_.startswith('VARCHAR') else type_
        return SQLType(base, type_)

    def sql_format(self, type_, value):
        """Wrap non-None 'BLOB' values in ``Binary``; return others as is."""
        if type_ == 'BLOB' and value is not None:
            return Binary(value)
        return value

    def unaccent(self, value):
        """Wrap *value* in ``Unaccent`` when ``has_unaccent()`` is true."""
        return Unaccent(value) if self.has_unaccent() else value

    def sequence_exist(self, connection, name):
        """Tell whether sequence *name* exists in a search-path schema.

        Schemas are probed in ``self.search_path`` order and probing stops
        at the first match.
        """
        cursor = connection.cursor()
        query = (
            'SELECT 1 '
            'FROM information_schema.sequences '
            'WHERE sequence_name = %s AND sequence_schema = %s')

        def found_in(schema):
            cursor.execute(query, (name, schema))
            return bool(cursor.rowcount)

        return any(found_in(schema) for schema in self.search_path)

    def sequence_create(self,
                        connection,
                        name,
                        number_increment=1,
                        start_value=1):
        """Create sequence *name* with the given increment and start value.

        The name is formatted into the statement; increment and start value
        are passed as query parameters using ``self.flavor.param``.
        """
        cursor = connection.cursor()
        placeholder = self.flavor.param
        statement = (
            'CREATE SEQUENCE "%s" '
            'INCREMENT BY %s '
            'START WITH %s') % (name, placeholder, placeholder)
        cursor.execute(statement, (number_increment, start_value))

    def sequence_update(self,
                        connection,
                        name,
                        number_increment=1,
                        start_value=1):
        """Change the increment of sequence *name* and restart it.

        The name is formatted into the statement; increment and restart
        value are passed as query parameters using ``self.flavor.param``.
        """
        cursor = connection.cursor()
        placeholder = self.flavor.param
        statement = (
            'ALTER SEQUENCE "%s" '
            'INCREMENT BY %s '
            'RESTART WITH %s') % (name, placeholder, placeholder)
        cursor.execute(statement, (number_increment, start_value))

    def sequence_rename(self, connection, old_name, new_name):
        """Rename sequence *old_name* to *new_name* when that is possible.

        Nothing happens unless *old_name* exists and *new_name* does not.
        """
        cursor = connection.cursor()
        if not self.sequence_exist(connection, old_name):
            return
        if self.sequence_exist(connection, new_name):
            return
        statement = 'ALTER TABLE "%s" RENAME TO "%s"' % (old_name, new_name)
        cursor.execute(statement)

    def sequence_delete(self, connection, name):
        """Drop sequence *name*."""
        cursor = connection.cursor()
        statement = 'DROP SEQUENCE "%s"' % name
        cursor.execute(statement)

    def sequence_next_number(self, connection, name):
        """Return the number the sequence *name* will deliver next.

        That is ``last_value`` when the sequence was never called, else
        ``last_value`` plus the increment.  On servers >= 10 the increment
        is first read from ``pg_sequences``; on older servers it is read
        from the sequence relation itself.
        """
        cursor = connection.cursor()
        if self.get_version(connection) < (10, 0):
            cursor.execute('SELECT CASE WHEN NOT is_called THEN last_value '
                           'ELSE last_value + increment_by '
                           'END '
                           'FROM "%s"' % name)
            return cursor.fetchone()[0]

        placeholder = self.flavor.param
        cursor.execute(
            'SELECT increment_by '
            'FROM pg_sequences '
            'WHERE sequencename=%s ' % placeholder, (name, ))
        increment, = cursor.fetchone()
        cursor.execute(
            'SELECT CASE WHEN NOT is_called THEN last_value '
            'ELSE last_value + %s '
            'END '
            'FROM "%s"' % (placeholder, name), (increment, ))
        return cursor.fetchone()[0]

    def has_channel(self):
        """Return True: this backend reports channel support."""
        return True

    def json_get(self, column, key=None):
        """Return *column* cast to jsonb, narrowed to *key* when truthy."""
        as_jsonb = Cast(column, 'jsonb')
        if not key:
            return as_jsonb
        return JSONBExtractPath(as_jsonb, key)

    def json_key_exists(self, column, key):
        """Build a ``JSONKeyExists`` over *column* cast to jsonb and *key*."""
        return JSONKeyExists(Cast(column, 'jsonb'), key)

    def json_any_keys_exist(self, column, keys):
        """Build a ``JSONAnyKeyExist`` over *column* cast to jsonb and *keys*."""
        return JSONAnyKeyExist(Cast(column, 'jsonb'), keys)

    def json_all_keys_exist(self, column, keys):
        """Build a ``JSONAllKeyExist`` over *column* cast to jsonb and *keys*."""
        return JSONAllKeyExist(Cast(column, 'jsonb'), keys)

    def json_contains(self, column, json):
        """Build a ``JSONContains`` over *column* and *json*, both cast to jsonb."""
        return JSONContains(Cast(column, 'jsonb'), Cast(json, 'jsonb'))
Esempio n. 16
0
class Database(DatabaseInterface):
    """SQLite backend; the ':memory:' database is shared per thread."""

    # Thread-local storage holding the ':memory:' instance (set in
    # __init__, reused by __new__).
    _local = threading.local()
    # Connection opened by connect(); None until then.
    _conn = None
    flavor = Flavor(paramstyle='qmark', function_mapping=MAPPING)
    IN_MAX = 200

    def __new__(cls, name=':memory:'):
        if (name == ':memory:'
                and getattr(cls._local, 'memory_database', None)):
            return cls._local.memory_database
        return DatabaseInterface.__new__(cls, name=name)

    def __init__(self, name=':memory:'):
        """Initialise the base class and remember the ':memory:' instance."""
        super(Database, self).__init__(name=name)
        if name == ':memory:':
            # Stored in thread-local storage so __new__ can hand it out again.
            Database._local.memory_database = self

    def connect(self):
        """Open the SQLite connection if needed and return this instance.

        For a named database the ``<name>.sqlite`` file must exist in the
        configured 'database' / 'path' directory, otherwise ``IOError`` is
        raised.  A new connection gets the SQL helper functions registered
        and foreign keys switched on.
        """
        path = ':memory:'
        if self.name != ':memory:':
            db_filename = self.name + '.sqlite'
            path = os.path.join(config.get('database', 'path'), db_filename)
            if not os.path.isfile(path):
                raise IOError('Database "%s" doesn\'t exist!' % db_filename)
        if self._conn is not None:
            return self
        detect_types = sqlite.PARSE_DECLTYPES | sqlite.PARSE_COLNAMES
        conn = self._conn = sqlite.connect(path,
                                           detect_types=detect_types,
                                           factory=SQLiteConnection)
        register = conn.create_function
        for fname, nargs, func in (
                ('extract', 2, SQLiteExtract.extract),
                ('date_trunc', 2, date_trunc),
                ('split_part', 3, split_part),
                ('position', 2, SQLitePosition.position),
                ('overlay', 3, SQLiteOverlay.overlay),
                ('overlay', 4, SQLiteOverlay.overlay)):
            register(fname, nargs, func)
        if sqlite.sqlite_version_info < (3, 3, 14):
            register('replace', 3, replace)
        for fname, nargs, func in (
                ('now', 0, now),
                ('sign', 1, sign),
                ('greatest', -1, greatest),
                ('least', -1, least)):
            register(fname, nargs, func)
        conn.execute('PRAGMA foreign_keys = ON')
        return self

    def get_connection(self, autocommit=False, readonly=False):
        """Return the single connection, connecting first if necessary.

        ``isolation_level`` is set to None when *autocommit* is true and to
        'IMMEDIATE' otherwise.  *readonly* is not used.
        """
        if self._conn is None:
            self.connect()
        self._conn.isolation_level = None if autocommit else 'IMMEDIATE'
        return self._conn

    def put_connection(self, connection=None, close=False):
        """Do nothing; both arguments are ignored."""
        pass

    def close(self):
        """Forget the connection of a named database.

        The ':memory:' database keeps its connection; the reference is only
        dropped, ``close()`` is not called on the connection object.
        """
        if self.name != ':memory:' and self._conn is not None:
            self._conn = None

    @classmethod
    def create(cls, connection, database_name):
        """Create the database file for *database_name*.

        ':memory:' opens a throw-away in-memory connection.  A name holding
        the path separator is ignored.  *connection* is not used.
        """
        if database_name == ':memory:':
            path = ':memory:'
        else:
            if os.sep in database_name:
                return
            path = os.path.join(config.get('database', 'path'),
                                database_name + '.sqlite')
        # A sqlite connection used as a context manager only commits or
        # rolls back; close it explicitly so the handle is not leaked.
        conn = sqlite.connect(path)
        try:
            cursor = conn.cursor()
            cursor.close()
            conn.commit()
        finally:
            conn.close()

    def drop(self, connection, database_name):
        """Drop *database_name*.

        ':memory:' only forgets the connection of this thread's in-memory
        instance.  A name holding the path separator is ignored.  Otherwise
        the ``<name>.sqlite`` file is removed from the configured
        'database' / 'path' directory.  *connection* is not used.
        """
        if database_name == ':memory:':
            self._local.memory_database._conn = None
        elif os.sep not in database_name:
            directory = config.get('database', 'path')
            os.remove(os.path.join(directory, database_name + '.sqlite'))

    def list(self):
        """Return the names of the databases that pass ``test()``.

        Candidates are ':memory:' plus every ``*.sqlite`` file of the
        configured 'database' / 'path' directory (an unreadable directory
        is ignored).  Candidates that cannot be connected are skipped.
        """
        candidates = [':memory:']
        try:
            candidates += os.listdir(config.get('database', 'path'))
        except OSError:
            pass
        suffix = '.sqlite'
        res = []
        for db_file in candidates:
            if db_file == ':memory:':
                db_name = db_file
            elif db_file.endswith(suffix):
                db_name = db_file[:-len(suffix)]
            else:
                continue
            try:
                database = Database(db_name).connect()
            except Exception:
                continue
            if database.test():
                res.append(db_name)
            database.close()
        return res

    def init(self):
        """Create the base schema and register the 'ir' and 'res' modules.

        Runs every non-blank ';'-separated statement of the ``init.sql``
        file located next to this module, then inserts one ``ir_module`` row
        (state 'to install') per module with its dependencies, and commits.
        """
        from trytond.modules import get_module_info
        with self._conn as conn:
            cursor = conn.cursor()
            sql_file = os.path.join(os.path.dirname(__file__), 'init.sql')
            with open(sql_file) as fp:
                for line in fp.read().split(';'):
                    if line and not line.isspace():
                        cursor.execute(line)

            ir_module = Table('ir_module')
            ir_module_dependency = Table('ir_module_dependency')
            # Both bootstrap modules are always flagged for installation;
            # the former 'uninstalled' default could never be selected.
            state = 'to install'
            for module in ('ir', 'res'):
                info = get_module_info(module)
                insert = ir_module.insert([
                    ir_module.create_uid, ir_module.create_date,
                    ir_module.name, ir_module.state
                ], [[0, CurrentTimestamp(), module, state]])
                cursor.execute(*insert)
                cursor.execute('SELECT last_insert_rowid()')
                module_id, = cursor.fetchone()
                for dependency in info.get('depends', []):
                    insert = ir_module_dependency.insert([
                        ir_module_dependency.create_uid,
                        ir_module_dependency.create_date,
                        ir_module_dependency.module,
                        ir_module_dependency.name,
                    ], [[0, CurrentTimestamp(), module_id, dependency]])
                    cursor.execute(*insert)
            conn.commit()

    def test(self):
        """Tell whether at least one of the core tables exists.

        Looks the table names up in ``sqlite_master``; a failing query
        yields False.
        """
        core_tables = [
            'ir_model',
            'ir_model_field',
            'ir_ui_view',
            'ir_ui_menu',
            'res_user',
            'res_group',
            'ir_module',
            'ir_module_dependency',
            'ir_translation',
            'ir_lang',
        ]
        sqlite_master = Table('sqlite_master')
        select = sqlite_master.select(sqlite_master.name)
        select.where = sqlite_master.type == 'table'
        select.where &= sqlite_master.name.in_(core_tables)
        with self._conn as conn:
            cursor = conn.cursor()
            try:
                cursor.execute(*select)
            except Exception:
                return False
            return bool(cursor.fetchall())

    def lastid(self, cursor):
        """Return ``cursor.lastrowid``."""
        # This call is not thread safe
        return cursor.lastrowid

    def lock(self, connection, table):
        """Do nothing; both arguments are ignored."""
        pass

    def has_constraint(self, constraint=None):
        """Return False whatever *constraint* is.

        *constraint* is optional and unused; it is accepted so that callers
        written against the backend taking a constraint argument work here
        too.
        """
        return False

    def has_multirow_insert(self):
        """Return True: this backend reports multi-row INSERT support."""
        return True
Esempio n. 17
0
class Database(DatabaseInterface):
    """MySQL backend opening a new connection for every request."""

    # Result of the last list() call and the time.time() at which it was
    # stored; the cache is reset to None by create() and drop().
    _list_cache = None
    _list_cache_timestamp = None
    flavor = Flavor(max_limit=18446744073709551610, function_mapping=MAPPING)

    def connect(self):
        """Return this instance; no connection is opened here."""
        return self

    def get_connection(self, autocommit=False, readonly=False):
        """Open and return a new MySQLdb connection to this database.

        Host, port, user and password come from the configured database
        URI when present (the password is URL-unquoted).  The session time
        zone is set to '+00:00'.  *autocommit* and *readonly* are not used.
        """
        # urllib.unquote_plus only exists on Python 2; on Python 3 the
        # function lives in urllib.parse.
        from urllib.parse import unquote_plus

        conv = MySQLdb.converters.conversions.copy()
        conv[float] = lambda value, _: repr(value)
        conv[MySQLdb.constants.FIELD_TYPE.TIME] = MySQLdb.times.Time_or_None
        args = {
            'db': self.name,
            'sql_mode': 'traditional,postgresql',
            'use_unicode': True,
            'charset': 'utf8',
            'conv': conv,
        }
        uri = parse_uri(config.get('database', 'uri'))
        assert uri.scheme == 'mysql'
        if uri.hostname:
            args['host'] = uri.hostname
        if uri.port:
            args['port'] = uri.port
        if uri.username:
            args['user'] = uri.username
        if uri.password:
            args['passwd'] = unquote_plus(uri.password)
        conn = MySQLdb.connect(**args)
        try:
            cursor = conn.cursor()
            cursor.execute('SET time_zone = "+00:00"')
        except Exception:
            # The caller never receives this connection; close it here.
            conn.close()
            raise
        return conn

    def put_connection(self, connection, close=False):
        """Close *connection*; the *close* flag is ignored."""
        connection.close()

    def close(self):
        """Do nothing and return None."""
        return

    @classmethod
    def create(cls, connection, database_name):
        cursor = connection.cursor()
        cursor.execute('CREATE DATABASE `' + database_name + '` '
            'DEFAULT CHARACTER SET = \'utf8\'')
        cls._list_cache = None

    @classmethod
    def drop(cls, connection, database_name):
        cursor = connection.cursor()
        cursor.execute('DROP DATABASE `' + database_name + '`')
        cls._list_cache = None

    def list(self):
        """Return the names of the databases that pass ``test()``.

        A non-empty result is cached on the class and reused while it is
        younger than the configured 'session' / 'timeout'.  Databases that
        cannot be tested are skipped, and the listing connection is
        released in every case.
        """
        now = time.time()
        timeout = config.getint('session', 'timeout')
        res = Database._list_cache
        if res and abs(Database._list_cache_timestamp - now) < timeout:
            return res
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute('SHOW DATABASES')
            res = []
            for db_name, in cursor.fetchall():
                try:
                    database = Database(db_name).connect()
                    # connect() only returns the instance; the connection
                    # is really opened by test(), so a failure there must
                    # skip the database as well.
                    is_tryton = database.test()
                except Exception:
                    continue
                if is_tryton:
                    res.append(db_name)
                database.close()
        finally:
            self.put_connection(conn)
        Database._list_cache = res
        Database._list_cache_timestamp = now
        return res

    def init(self):
        """Create the base schema and register the 'ir' and 'res' modules.

        Runs every non-blank ';'-separated statement of the ``init.sql``
        file located next to this module, inserts one ``ir_module`` row
        (state 'to activate') per module with its dependencies, and commits.
        The connection is released in every case.
        """
        from trytond.modules import get_module_info

        connection = self.get_connection()
        try:
            cursor = connection.cursor()
            sql_file = os.path.join(os.path.dirname(__file__), 'init.sql')
            with open(sql_file) as fp:
                for line in fp.read().split(';'):
                    if line and not line.isspace():
                        cursor.execute(line)

            # Both bootstrap modules are always flagged for activation; the
            # former 'not activated' default could never be selected.
            state = 'to activate'
            for module in ('ir', 'res'):
                info = get_module_info(module)
                cursor.execute('INSERT INTO ir_module '
                    '(create_uid, create_date, name, state) '
                    'VALUES (%s, now(), %s, %s)',
                    (0, module, state))
                cursor.execute('SELECT LAST_INSERT_ID()')
                module_id, = cursor.fetchone()
                for dependency in info.get('depends', []):
                    cursor.execute('INSERT INTO ir_module_dependency '
                        '(create_uid, create_date, module, name) '
                        'VALUES (%s, now(), %s, %s)',
                        (0, module_id, dependency))

            connection.commit()
        finally:
            self.put_connection(connection)

    def test(self):
        is_tryton_database = False
        connection = self.get_connection()
        cursor = connection.cursor()
        cursor.execute("SHOW TABLES")
        for table, in cursor.fetchall():
            if table in (
                    'ir_model',
                    'ir_model_field',
                    'ir_ui_view',
                    'ir_ui_menu',
                    'res_user',
                    'res_group',
                    'ir_module',
                    'ir_module_dependency',
                    'ir_translation',
                    'ir_lang',
                    ):
                is_tryton_database = True
                break
        self.put_connection(connection)
        return is_tryton_database

    def lastid(self, cursor):
        """Return the value of ``SELECT LAST_INSERT_ID()`` run on *cursor*."""
        # This call is not thread safe
        cursor.execute('SELECT LAST_INSERT_ID()')
        return cursor.fetchone()[0]

    def lock(self, connection, table):
        """Do nothing; table locking is not implemented for this backend."""
        # Locking a table does not work because MySQL requires the session
        # to lock every table it will access, and
        # 'FLUSH TABLES WITH READ LOCK' creates a deadlock.
        pass

    def has_constraint(self, constraint=None):
        """Return False whatever *constraint* is.

        *constraint* is optional and unused; it is accepted so that callers
        written against the backend taking a constraint argument work here
        too.
        """
        return False

    def has_multirow_insert(self):
        """Return True: this backend reports multi-row INSERT support."""
        return True

    def update_auto_increment(self, connection, table, value):
        """Set the AUTO_INCREMENT counter of *table* to *value*.

        The table name is formatted into the statement; *value* is passed
        as a query parameter.
        """
        cursor = connection.cursor()
        statement = 'ALTER TABLE `%s` AUTO_INCREMENT = %%s' % table
        cursor.execute(statement, (value,))
Esempio n. 18
0
class Database(DatabaseInterface):
    """SQLite implementation of the database backend interface."""

    # Thread-local storage; __new__/__init__ keep the ':memory:' instance
    # on it so that it is reused within a thread.
    _local = threading.local()
    # SQLite connection, opened lazily by connect()
    _conn = None
    # python-sql flavor activated through Flavor.set() in init() and test()
    flavor = Flavor(paramstyle='qmark',
                    function_mapping=MAPPING,
                    null_ordering=False)
    # NOTE(review): presumably the maximum number of values put in one
    # IN clause -- not used in this class, confirm against callers
    IN_MAX = 200

    # Type names translated by sql_type() before the generic fallbacks
    TYPES_MAPPING = {
        'DATETIME': SQLType('TIMESTAMP', 'TIMESTAMP'),
        'BIGINT': SQLType('INTEGER', 'INTEGER'),
        'BOOL': SQLType('BOOLEAN', 'BOOLEAN'),
    }

    def __new__(cls, name=':memory:'):
        """Return a Database, reusing the thread's ':memory:' instance.

        When ``name`` is ':memory:' and an in-memory database is already
        stored on the thread-local holder, that instance is returned;
        otherwise a new object is allocated.
        """
        if name == ':memory:':
            cached = getattr(cls._local, 'memory_database', None)
            if cached:
                return cached
        return DatabaseInterface.__new__(cls, name=name)

    def __init__(self, name=':memory:'):
        """Initialise the database and remember it when it is in-memory."""
        super(Database, self).__init__(name=name)
        if name != ':memory:':
            return
        # Kept per thread so that __new__ can hand it out again
        Database._local.memory_database = self

    def connect(self):
        """Open the SQLite connection for this database and return self.

        For a file database the '<name>.sqlite' file must exist in the
        configured database path, otherwise IOError is raised.  If a
        connection is already open nothing else is done.  A new connection
        gets the SQL helper functions registered and foreign keys enabled.
        """
        if self.name != ':memory:':
            db_filename = self.name + '.sqlite'
            path = os.path.join(config.get('database', 'path'), db_filename)
            if not os.path.isfile(path):
                raise IOError('Database "%s" doesn\'t exist!' % db_filename)
        else:
            path = ':memory:'
        if self._conn is not None:
            return self
        self._conn = sqlite.connect(
            path,
            detect_types=sqlite.PARSE_DECLTYPES | sqlite.PARSE_COLNAMES,
            factory=SQLiteConnection)
        register = self._conn.create_function
        for sql_name, n_args, func in (
                ('extract', 2, SQLiteExtract.extract),
                ('date_trunc', 2, date_trunc),
                ('split_part', 3, split_part),
                ('position', 2, SQLitePosition.position),
                ('to_char', 2, to_char),
                ('overlay', 3, SQLiteOverlay.overlay),
                ('overlay', 4, SQLiteOverlay.overlay),
                ):
            register(sql_name, n_args, func)
        # Only SQLite older than 3.3.14 gets the Python 'replace' function
        if sqlite.sqlite_version_info < (3, 3, 14):
            register('replace', 3, replace)
        for sql_name, n_args, func in (
                ('now', 0, now),
                ('sign', 1, sign),
                ('greatest', -1, greatest),
                ('least', -1, least),
                ):
            register(sql_name, n_args, func)
        can_trace = hasattr(self._conn, 'set_trace_callback')
        if can_trace and logger.isEnabledFor(logging.DEBUG):
            self._conn.set_trace_callback(logger.debug)
        self._conn.execute('PRAGMA foreign_keys = ON')
        return self

    def get_connection(self, autocommit=False, readonly=False):
        """Return the single SQLite connection, connecting first if needed.

        The isolation level is set to None when ``autocommit`` is true and
        to 'IMMEDIATE' otherwise.  ``readonly`` is not used.
        """
        if self._conn is None:
            self.connect()
        self._conn.isolation_level = None if autocommit else 'IMMEDIATE'
        return self._conn

    def put_connection(self, connection=None, close=False):
        """Do nothing: get_connection() always returns the same connection."""
        return None

    def close(self):
        """Forget the connection of a file database.

        Nothing happens for ':memory:' databases or when no connection is
        held; otherwise the reference to the connection is dropped.
        """
        if self.name == ':memory:' or self._conn is None:
            return
        self._conn = None

    @classmethod
    def create(cls, connection, database_name):
        """Create the SQLite database ``database_name``.

        ':memory:' is opened in memory; any other name becomes
        '<name>.sqlite' inside the configured database path.  Names
        containing the path separator are silently ignored.
        ``connection`` is not used.
        """
        if database_name == ':memory:':
            path = ':memory:'
        else:
            if os.sep in database_name:
                return
            path = os.path.join(config.get('database', 'path'),
                                database_name + '.sqlite')
        conn = sqlite.connect(path)
        try:
            # ``with conn`` only commits or rolls back the transaction, it
            # does not close the connection, hence the explicit close below
            with conn:
                conn.cursor().close()
        finally:
            conn.close()

    def drop(self, connection, database_name):
        """Drop the database ``database_name``.

        The in-memory database only loses its connection; a file database
        has its '<name>.sqlite' file removed.  Names containing the path
        separator are silently ignored.  ``connection`` is not used.
        """
        if database_name == ':memory:':
            self._local.memory_database._conn = None
            return
        if os.sep in database_name:
            return
        db_path = os.path.join(
            config.get('database', 'path'), database_name + '.sqlite')
        os.remove(db_path)

    def list(self, hostname=None):
        """Return the names of the databases that pass ``test(hostname)``.

        Candidates are ':memory:' plus every '*.sqlite' file of the
        configured database path (a missing directory is ignored).
        Databases that cannot be connected to are skipped.
        """
        candidates = [':memory:']
        try:
            candidates += os.listdir(config.get('database', 'path'))
        except OSError:
            pass
        res = []
        for entry in candidates:
            if entry == ':memory:':
                db_name = ':memory:'
            elif entry.endswith('.sqlite'):
                # Strip the '.sqlite' extension
                db_name = entry[:-len('.sqlite')]
            else:
                continue
            try:
                database = Database(db_name).connect()
            except Exception:
                continue
            if database.test(hostname=hostname):
                res.append(db_name)
            database.close()
        return res

    def init(self):
        """Create the base schema and register the 'ir' and 'res' modules.

        Executes every non-blank ';'-separated statement of the 'init.sql'
        file located next to this module, then inserts one ir_module row
        per base module and one ir_module_dependency row per dependency
        reported by ``get_module_info``, and commits.
        """
        from trytond.modules import get_module_info
        Flavor.set(self.flavor)
        with self.get_connection() as conn:
            cursor = conn.cursor()
            sql_file = os.path.join(os.path.dirname(__file__), 'init.sql')
            with open(sql_file) as fp:
                for statement in fp.read().split(';'):
                    if statement and not statement.isspace():
                        cursor.execute(statement)

            ir_module = Table('ir_module')
            ir_module_dependency = Table('ir_module_dependency')
            module_columns = [
                ir_module.create_uid, ir_module.create_date,
                ir_module.name, ir_module.state
            ]
            dependency_columns = [
                ir_module_dependency.create_uid,
                ir_module_dependency.create_date,
                ir_module_dependency.module,
                ir_module_dependency.name,
            ]
            for module in ('ir', 'res'):
                if module in ('ir', 'res'):
                    state = 'to activate'
                else:
                    state = 'not activated'
                info = get_module_info(module)
                cursor.execute(*ir_module.insert(
                    module_columns,
                    [[0, CurrentTimestamp(), module, state]]))
                cursor.execute('SELECT last_insert_rowid()')
                module_id, = cursor.fetchone()
                for dependency in info.get('depends', []):
                    cursor.execute(*ir_module_dependency.insert(
                        dependency_columns,
                        [[0, CurrentTimestamp(), module_id, dependency]]))
            conn.commit()

    def test(self, hostname=None):
        """Tell whether this database is a Tryton database.

        All the core tables must be listed in ``sqlite_master``.  When
        ``hostname`` is given and ir_configuration stores at least one
        non-empty hostname, ``hostname`` must be one of them.  Any query
        failure yields False.
        """
        Flavor.set(self.flavor)
        required = [
            'ir_model', 'ir_model_field', 'ir_ui_view', 'ir_ui_menu',
            'res_user', 'res_group', 'ir_module', 'ir_module_dependency',
            'ir_translation', 'ir_lang', 'ir_configuration'
        ]
        sqlite_master = Table('sqlite_master')
        query = sqlite_master.select(sqlite_master.name)
        query.where = ((sqlite_master.type == 'table')
                       & sqlite_master.name.in_(required))
        with self._conn as conn:
            cursor = conn.cursor()
            try:
                cursor.execute(*query)
            except Exception:
                return False
            found = cursor.fetchall()
            if len(found) != len(required):
                return False
            if not hostname:
                return True
            configuration = Table('ir_configuration')
            try:
                cursor.execute(
                    *configuration.select(configuration.hostname))
            except Exception:
                return False
            hostnames = {h for h, in cursor.fetchall() if h}
            if hostnames and hostname not in hostnames:
                return False
            return True

    def lastid(self, cursor):
        """Return ``cursor.lastrowid``."""
        # This call is not thread safe
        last_id = cursor.lastrowid
        return last_id

    def lock(self, connection, table):
        """Do nothing: no table lock is taken by this backend."""
        return None

    def lock_id(self, id, timeout=None):
        """Return a constant ``Literal(True)``; both arguments are unused."""
        always_true = Literal(True)
        return always_true

    def has_constraint(self, constraint):
        """Always return False, whatever ``constraint`` is."""
        return False

    def has_multirow_insert(self):
        """Always return True (multi-row INSERT capability flag)."""
        return True

    def sql_type(self, type_):
        """Return the SQLType to use for the type name ``type_``.

        Names found in TYPES_MAPPING are translated, every 'VARCHAR...'
        name collapses to plain VARCHAR, anything else is kept as is.
        """
        mapping = self.TYPES_MAPPING
        if type_ in mapping:
            return mapping[type_]
        name = 'VARCHAR' if type_.startswith('VARCHAR') else type_
        return SQLType(name, name)

    def sql_format(self, type_, value):
        """Return ``value`` prepared for a column of type ``type_``.

        Values for 'INTEGER' and 'BIGINT' columns are converted with
        ``int()`` unless they are None, a Query or an Expression; every
        other value is returned unchanged.
        """
        if type_ not in ('INTEGER', 'BIGINT'):
            return value
        if value is None or isinstance(value, (Query, Expression)):
            return value
        return int(value)

    def json_get(self, column, key):
        """Build the expression reading ``key`` from the JSON ``column``.

        With a truthy ``key`` the value at path '$.<key>' is extracted,
        otherwise the column itself is used.  The quoted result is wrapped
        in NullIf against the quoted Null.
        """
        target = JSONExtract(column, '$.%s' % key) if key else column
        return NullIf(JSONQuote(target), JSONQuote(Null))
Esempio n. 19
0
class Database(DatabaseInterface):
    """PostgreSQL implementation of the database backend interface."""

    # Guards instance creation (__new__) and close()
    _lock = RLock()
    # Database name -> Database instance, filled by __new__
    _databases = {}
    # psycopg2 connection pool, set on each instance by __new__
    _connpool = None
    # Result of the last list() call and the time it was computed
    _list_cache = None
    _list_cache_timestamp = None
    # Database name -> server version tuple, filled by get_version()
    _version_cache = {}
    # Lazily computed values, see the search_path, current_user and
    # has_returning members
    _search_path = None
    _current_user = None
    _has_returning = None
    flavor = Flavor(ilike=True)

    def __new__(cls, name='template1'):
        """Return the shared Database instance for ``name``.

        The first call for a name creates the instance together with its
        connection pool (sized by the 'minconn'/'maxconn' options of the
        'database' section); later calls return the same object.  The
        whole operation runs under the class lock.
        """
        with cls._lock:
            if name in cls._databases:
                return cls._databases[name]
            inst = DatabaseInterface.__new__(cls, name=name)

            logger.info('connect to "%s"', name)
            minconn = config.getint('database', 'minconn', default=1)
            maxconn = config.getint('database', 'maxconn', default=64)
            inst._connpool = ThreadedConnectionPool(
                minconn, maxconn, cls.dsn(name), cursor_factory=LoggingCursor)

            # Register the instance only once its pool exists: a failed
            # connection must not leave a pool-less instance in the cache
            cls._databases[name] = inst
            return inst

    @classmethod
    def dsn(cls, name):
        """Return the connection string for the database ``name``.

        Host, port, user and password come from the configured database
        URI; parts missing from the URI are left as empty strings, so the
        result always holds five space-separated fields.
        """
        uri = parse_uri(config.get('database', 'uri'))
        fields = (
            "host=%s" % uri.hostname if uri.hostname else '',
            "port=%s" % uri.port if uri.port else '',
            "dbname=%s" % name,
            "user=%s" % uri.username if uri.username else '',
            ("password=%s" % urllib.unquote_plus(uri.password)
                if uri.password else ''),
            )
        return ' '.join(fields)

    def connect(self):
        """Return self; the connection pool is already built by __new__."""
        return self

    def get_connection(self, autocommit=False, readonly=False):
        """Take a connection from the pool and configure it.

        When the pool raises PoolError the request is retried every second,
        up to the 'retry' option of the 'database' section, unless the pool
        is closed.  The isolation level is autocommit or repeatable read
        depending on ``autocommit``; ``readonly`` makes the transaction
        read only.
        """
        retries = config.getint('database', 'retry')
        for remaining in range(retries, -1, -1):
            try:
                conn = self._connpool.getconn()
            except PoolError:
                if not remaining or self._connpool.closed:
                    raise
                logger.info('waiting a connection')
                time.sleep(1)
            else:
                break
        if autocommit:
            level = ISOLATION_LEVEL_AUTOCOMMIT
        else:
            level = ISOLATION_LEVEL_REPEATABLE_READ
        conn.set_isolation_level(level)
        if readonly:
            conn.cursor().execute('SET TRANSACTION READ ONLY')
        return conn

    def put_connection(self, connection, close=False):
        """Give ``connection`` back to the pool, forwarding ``close``."""
        pool = self._connpool
        pool.putconn(connection, close=close)

    def close(self):
        """Close every pooled connection and unregister this instance."""
        with self._lock:
            pool = self._connpool
            pool.closeall()
            self._databases.pop(self.name)

    @classmethod
    def create(cls, connection, database_name):
        """Create the database ``database_name`` from template0 and commit.

        The cached database list is invalidated afterwards.
        """
        statement = ('CREATE DATABASE "' + database_name + '" '
                     'TEMPLATE template0 ENCODING \'unicode\'')
        cursor = connection.cursor()
        cursor.execute(statement)
        connection.commit()
        cls._list_cache = None

    def drop(self, connection, database_name):
        """Drop the database ``database_name`` and reset the list cache."""
        statement = 'DROP DATABASE "' + database_name + '"'
        cursor = connection.cursor()
        cursor.execute(statement)
        Database._list_cache = None

    def get_version(self, connection):
        """Return the server version of this database as a tuple of ints.

        The version is read once per database name with
        ``SELECT version()``, parsed with RE_VERSION and then cached.
        """
        cache = self._version_cache
        if self.name not in cache:
            cursor = connection.cursor()
            cursor.execute('SELECT version()')
            version, = cursor.fetchone()
            match = RE_VERSION.search(version)
            cache[self.name] = tuple(int(part) for part in match.groups())
        return cache[self.name]

    def list(self):
        """Return the names of the Tryton databases of the server.

        A non-empty previous result is reused while it is younger than the
        'timeout' option of the 'session' section.  Otherwise every
        non-template database is probed with ``_test`` through a dedicated
        connection; databases that cannot be reached or probed are left
        out.
        """
        now = time.time()
        timeout = config.getint('session', 'timeout')
        res = Database._list_cache
        if res and abs(Database._list_cache_timestamp - now) < timeout:
            return res

        connection = self.get_connection()
        try:
            cursor = connection.cursor()
            cursor.execute('SELECT datname FROM pg_database '
                           'WHERE datistemplate = false ORDER BY datname')
            res = []
            for db_name, in cursor:
                try:
                    conn = connect(self.dsn(db_name))
                    try:
                        # ``with conn`` only ends the transaction, the
                        # connection itself must be closed explicitly
                        with conn:
                            if self._test(conn):
                                res.append(db_name)
                    finally:
                        conn.close()
                except Exception:
                    continue
        finally:
            # Hand the pooled connection back even if the listing failed
            self.put_connection(connection)

        Database._list_cache = res
        Database._list_cache_timestamp = now
        return res

    def init(self):
        """Create the base schema and register the 'ir' and 'res' modules.

        Executes every non-blank ';'-separated statement of the 'init.sql'
        file located next to this module, then inserts one ir_module row
        per base module and one ir_module_dependency row per dependency
        reported by ``get_module_info``, and commits.  The connection is
        returned to the pool even when a statement fails.
        """
        from trytond.modules import get_module_info

        connection = self.get_connection()
        try:
            cursor = connection.cursor()
            sql_file = os.path.join(os.path.dirname(__file__), 'init.sql')
            with open(sql_file) as fp:
                for line in fp.read().split(';'):
                    if line and not line.isspace():
                        cursor.execute(line)

            for module in ('ir', 'res'):
                state = 'not activated'
                if module in ('ir', 'res'):
                    state = 'to activate'
                info = get_module_info(module)
                # The id is reserved first so that it can be reused for
                # the dependency rows
                cursor.execute('SELECT NEXTVAL(\'ir_module_id_seq\')')
                module_id = cursor.fetchone()[0]
                cursor.execute(
                    'INSERT INTO ir_module '
                    '(id, create_uid, create_date, name, state) '
                    'VALUES (%s, %s, now(), %s, %s)',
                    (module_id, 0, module, state))
                for dependency in info.get('depends', []):
                    cursor.execute(
                        'INSERT INTO ir_module_dependency '
                        '(create_uid, create_date, module, name) '
                        'VALUES (%s, now(), %s, %s)',
                        (0, module_id, dependency))

            connection.commit()
        finally:
            self.put_connection(connection)

    def test(self):
        """Tell whether this database is a Tryton database.

        Delegates to ``_test`` on a pooled connection, which is returned
        to the pool even when the check raises.
        """
        connection = self.get_connection()
        try:
            return self._test(connection)
        finally:
            self.put_connection(connection)

    @classmethod
    def _test(cls, connection):
        """Return True if at least one core Tryton table is found.

        The lookup is done in ``information_schema.tables`` through
        ``connection``.
        """
        core_tables = (
            'ir_model', 'ir_model_field', 'ir_ui_view', 'ir_ui_menu',
            'res_user', 'res_group', 'ir_module', 'ir_module_dependency',
            'ir_translation', 'ir_lang')
        cursor = connection.cursor()
        cursor.execute(
            'SELECT 1 FROM information_schema.tables '
            'WHERE table_name IN %s',
            (core_tables, ))
        rows = cursor.fetchall()
        return bool(rows)

    def nextid(self, connection, table):
        """Return the next value of the '<table>_id_seq' sequence."""
        cursor = connection.cursor()
        statement = "SELECT NEXTVAL('" + table + "_id_seq')"
        cursor.execute(statement)
        row = cursor.fetchone()
        return row[0]

    def setnextid(self, connection, table, value):
        """Set the '<table>_id_seq' sequence to the integer ``value``."""
        cursor = connection.cursor()
        prefix = "SELECT SETVAL('" + table
        # Only the tail of the statement goes through %-formatting, so the
        # table name is never interpreted as a format string
        suffix = "_id_seq', %d)" % value
        cursor.execute(prefix + suffix)

    def currid(self, connection, table):
        """Return ``last_value`` of the '<table>_id_seq' sequence."""
        cursor = connection.cursor()
        statement = 'SELECT last_value FROM "' + table + '_id_seq"'
        cursor.execute(statement)
        row = cursor.fetchone()
        return row[0]

    def lock(self, connection, table):
        """Lock ``table`` in exclusive mode without waiting (NOWAIT)."""
        statement = 'LOCK "%s" IN EXCLUSIVE MODE NOWAIT' % table
        connection.cursor().execute(statement)

    def has_constraint(self):
        """Always return True (constraint capability flag of this backend)."""
        return True

    def has_multirow_insert(self):
        """Always return True (multi-row INSERT capability flag)."""
        return True

    def get_table_schema(self, connection, table_name):
        """Return the first schema of the search path holding the table.

        Schemas are tried in ``search_path`` order; None is returned when
        no schema contains ``table_name``.
        """
        cursor = connection.cursor()
        query = (
            'SELECT 1 '
            'FROM information_schema.tables '
            'WHERE table_name = %s AND table_schema = %s')
        for schema in self.search_path:
            cursor.execute(query, (table_name, schema))
            if cursor.rowcount:
                return schema
        return None

    @property
    def current_user(self):
        """Name returned by ``SELECT current_user``, queried only once."""
        if self._current_user is not None:
            return self._current_user
        connection = self.get_connection()
        try:
            cursor = connection.cursor()
            cursor.execute('SELECT current_user')
            row = cursor.fetchone()
            self._current_user = row[0]
        finally:
            self.put_connection(connection)
        return self._current_user

    @property
    def search_path(self):
        """List of the schema names of ``SHOW search_path``, computed once.

        Each comma-separated entry is stripped, has its special values
        replaced ('user' by the current user) and is unquoted.
        """
        if self._search_path is not None:
            return self._search_path
        connection = self.get_connection()
        try:
            cursor = connection.cursor()
            cursor.execute('SHOW search_path')
            path, = cursor.fetchone()
            special_values = {
                'user': self.current_user,
            }
            schemas = []
            for entry in path.split(','):
                resolved = replace_special_values(
                    entry.strip(), **special_values)
                schemas.append(unescape_quote(resolved))
            self._search_path = schemas
        finally:
            self.put_connection(connection)
        return self._search_path

    def has_returning(self):
        """Tell whether the server supports the RETURNING clause.

        The answer is computed once from the server version and cached.
        """
        if self._has_returning is not None:
            return self._has_returning
        connection = self.get_connection()
        try:
            # RETURNING clause is available since PostgreSQL 8.2
            version = self.get_version(connection)
            self._has_returning = version >= (8, 2)
        finally:
            self.put_connection(connection)
        return self._has_returning

    def has_select_for(self):
        """Always return True (SELECT ... FOR capability flag)."""
        return True
Esempio n. 20
0
 def _format(value):
     """Return the SQL text for ``value``.

     Expressions are rendered with ``str()``; anything else becomes the
     parameter placeholder of the current flavor.
     """
     if not isinstance(value, Expression):
         return Flavor().get().param
     return str(value)
Esempio n. 21
0
class Database(DatabaseInterface):
    """MySQL implementation of the database backend interface."""

    # Result of the last list() call and the time it was computed
    _list_cache = None
    _list_cache_timestamp = None
    # NOTE(review): max_limit looks like the "no limit" value emitted when
    # only an OFFSET is requested -- confirm against python-sql
    flavor = Flavor(max_limit=18446744073709551610, function_mapping=MAPPING)

    def connect(self):
        """Return self; connections are opened by ``cursor()``."""
        return self

    def cursor(self, autocommit=False, readonly=False):
        """Open a new MySQL connection and return a Cursor wrapping it.

        Host, port, user and password are taken from the configured
        database URI, whose scheme must be 'mysql'.  The session time zone
        is set to UTC.  ``autocommit`` and ``readonly`` are not used.
        """
        conv = MySQLdb.converters.conversions.copy()
        conv[float] = lambda value, _: repr(value)
        conv[MySQLdb.constants.FIELD_TYPE.TIME] = MySQLdb.times.Time_or_None
        args = {
            'db': self.database_name,
            'sql_mode': 'traditional,postgresql',
            'use_unicode': True,
            'charset': 'utf8',
            'conv': conv,
        }
        uri = parse_uri(config.get('database', 'uri'))
        assert uri.scheme == 'mysql'
        # Only the URI parts that are actually set are passed on
        for key, attribute in (
                ('host', 'hostname'),
                ('port', 'port'),
                ('user', 'username'),
                ('passwd', 'password'),
                ):
            setting = getattr(uri, attribute)
            if setting:
                args[key] = setting
        conn = MySQLdb.connect(**args)
        cursor = Cursor(conn, self.database_name)
        cursor.execute('SET time_zone = `UTC`')
        return cursor

    def close(self):
        """Do nothing: this object holds no connection of its own."""
        return None

    @classmethod
    def create(cls, cursor, database_name):
        """Create the utf8 database ``database_name`` and reset the cache."""
        statement = ('CREATE DATABASE `' + database_name + '` '
                     'DEFAULT CHARACTER SET = \'utf8\'')
        cursor.execute(statement)
        cls._list_cache = None

    @classmethod
    def drop(cls, cursor, database_name):
        """Drop the database ``database_name`` and reset the list cache."""
        statement = 'DROP DATABASE `' + database_name + '`'
        cursor.execute(statement)
        cls._list_cache = None

    @staticmethod
    def dump(database_name):
        """Return the ``mysqldump`` output for ``database_name``.

        User, host, port and password are taken from the configured
        database URI.  An Exception is raised when mysqldump exits with a
        non-zero status.
        """
        from trytond.tools import exec_command_pipe

        # NOTE(review): the credential handling below was rebuilt from a
        # copy in which these lines were masked ('******'); verify against
        # the upstream source.
        cmd = ['mysqldump', '--no-create-db']
        uri = parse_uri(config.get('database', 'uri'))
        if uri.username:
            cmd.append('--user=' + uri.username)
        if uri.hostname:
            cmd.append('--host=' + uri.hostname)
        if uri.port:
            cmd.append('--port=' + str(uri.port))
        if uri.password:
            cmd.append('--password=' + uri.password)
        cmd.append(database_name)

        pipe = exec_command_pipe(*tuple(cmd))
        pipe.stdin.close()
        data = pipe.stdout.read()
        res = pipe.wait()
        if res:
            raise Exception('Couldn\'t dump database!')
        return data

    @staticmethod
    def restore(database_name, data):
        """Create ``database_name`` and load the dump ``data`` into it.

        The dump is written to a temporary file handed to the ``mysql``
        client; the file is removed afterwards, also when the client could
        not be run.  An Exception is raised when the client fails or when
        the restored database does not pass the Tryton test.  Returns True
        on success.
        """
        import tempfile
        from trytond.tools import exec_command_pipe

        database = Database().connect()
        cursor = database.cursor(autocommit=True)
        database.create(cursor, database_name)
        cursor.commit()
        cursor.close()

        # NOTE(review): the credential handling and the temporary file
        # creation were rebuilt from a copy in which these lines were
        # masked ('******'); verify against the upstream source.
        cmd = ['mysql']
        uri = parse_uri(config.get('database', 'uri'))
        if uri.username:
            cmd.append('--user=' + uri.username)
        if uri.hostname:
            cmd.append('--host=' + uri.hostname)
        if uri.port:
            cmd.append('--port=' + str(uri.port))
        if uri.password:
            cmd.append('--password=' + uri.password)
        cmd.append(database_name)

        fd, file_name = tempfile.mkstemp()
        with os.fdopen(fd, 'wb+') as fd:
            fd.write(data)

        # NOTE(review): '<' is only a redirection when the command goes
        # through a shell -- confirm how exec_command_pipe runs it
        cmd.append('<')
        cmd.append(file_name)

        args2 = tuple(cmd)

        try:
            pipe = exec_command_pipe(*args2)
            pipe.stdin.close()
            res = pipe.wait()
        finally:
            os.remove(file_name)
        if res:
            raise Exception('Couldn\'t restore database')

        database = Database(database_name).connect()
        cursor = database.cursor()
        if not cursor.test():
            cursor.close()
            database.close()
            raise Exception('Couldn\'t restore database!')
        cursor.close()
        database.close()
        Database._list_cache = None
        return True

    @staticmethod
    def list(cursor):
        """Return the names of the Tryton databases of the server.

        A non-empty previous result is reused while it is younger than the
        'timeout' option of the 'session' section.  Otherwise every
        database of ``SHOW DATABASES`` that can be connected to is probed
        with the ``test`` method of its cursor.
        """
        now = time.time()
        timeout = config.getint('session', 'timeout')
        cached = Database._list_cache
        if cached and abs(Database._list_cache_timestamp - now) < timeout:
            return cached
        cursor.execute('SHOW DATABASES')
        names = []
        for db_name, in cursor.fetchall():
            try:
                database = Database(db_name).connect()
            except Exception:
                continue
            cursor2 = database.cursor()
            if not cursor2.test():
                cursor2.close()
                database.close()
                continue
            names.append(db_name)
            cursor2.close(close=True)
        Database._list_cache = names
        Database._list_cache_timestamp = now
        return names

    @staticmethod
    def init(cursor):
        """Create the base schema and register the base modules.

        Executes every non-blank ';'-separated statement of the 'init.sql'
        file located next to this module, then inserts the 'ir', 'res' and
        'webdav' modules ('ir' and 'res' as 'to install', 'webdav' as
        'uninstalled') together with their dependencies.
        """
        from trytond.modules import get_module_info
        sql_file = os.path.join(os.path.dirname(__file__), 'init.sql')
        with open(sql_file) as fp:
            for statement in fp.read().split(';'):
                if statement and not statement.isspace():
                    cursor.execute(statement)

        for module in ('ir', 'res', 'webdav'):
            if module in ('ir', 'res'):
                state = 'to install'
            else:
                state = 'uninstalled'
            info = get_module_info(module)
            cursor.execute(
                'INSERT INTO ir_module_module '
                '(create_uid, create_date, name, state) '
                'VALUES (%s, now(), %s, %s)', (0, module, state))
            cursor.execute('SELECT LAST_INSERT_ID()')
            module_id, = cursor.fetchone()
            dependencies = info.get('depends', [])
            for dependency in dependencies:
                cursor.execute(
                    'INSERT INTO ir_module_module_dependency '
                    '(create_uid, create_date, module, name) '
                    'VALUES (%s, now(), %s, %s)', (0, module_id, dependency))
Esempio n. 22
0
class Database(DatabaseInterface):
    """PostgreSQL implementation of the database backend interface."""

    # Database name -> Database instance, filled by __init__
    _databases = {}
    # psycopg2 connection pool, created lazily by connect()
    _connpool = None
    # Result of the last list() call and the time it was computed
    _list_cache = None
    _list_cache_timestamp = None
    # Database name -> server version tuple, filled by get_version()
    _version_cache = {}
    flavor = Flavor(ilike=True)

    def __new__(cls, name='template1'):
        """Return the registered instance for ``name`` or allocate one."""
        registry = cls._databases
        if name not in registry:
            return DatabaseInterface.__new__(cls, name=name)
        return registry[name]

    def __init__(self, name='template1'):
        """Initialise the instance and register it under ``name``.

        An instance already registered for ``name`` is kept.  The cached
        search path and current user are reset.
        """
        super(Database, self).__init__(name=name)
        self._search_path = None
        self._current_user = None
        self._databases.setdefault(name, self)

    @classmethod
    def dsn(cls, name):
        """Return the connection string for the database ``name``.

        Host, port, user and password come from the configured database
        URI, whose scheme must be 'postgresql'; parts missing from the URI
        are left as empty strings, so the result always holds five
        space-separated fields.
        """
        uri = parse_uri(config.get('database', 'uri'))
        assert uri.scheme == 'postgresql'
        fields = (
            "host=%s" % uri.hostname if uri.hostname else '',
            "port=%s" % uri.port if uri.port else '',
            "dbname=%s" % name,
            "user=%s" % uri.username if uri.username else '',
            ("password=%s" % urllib.unquote_plus(uri.password)
                if uri.password else ''),
            )
        return ' '.join(fields)

    def connect(self):
        """Create the connection pool if it does not exist; return self.

        The pool is sized by the 'minconn' and 'maxconn' options of the
        'database' section.
        """
        if self._connpool is None:
            logger.info('connect to "%s"', self.name)
            minconn = config.getint('database', 'minconn', default=1)
            maxconn = config.getint('database', 'maxconn', default=64)
            self._connpool = ThreadedConnectionPool(
                minconn, maxconn, self.dsn(self.name))
        return self

    def get_connection(self, autocommit=False, readonly=False):
        """Take a connection from the pool and configure it.

        The pool is created first when needed.  When the pool raises
        PoolError the request is retried every second, up to the 'retry'
        option of the 'database' section, unless the pool is closed.  The
        isolation level is autocommit or repeatable read depending on
        ``autocommit``; ``readonly`` makes the transaction read only.
        """
        if self._connpool is None:
            self.connect()
        retries = config.getint('database', 'retry')
        for remaining in range(retries, -1, -1):
            try:
                conn = self._connpool.getconn()
            except PoolError:
                if not remaining or self._connpool.closed:
                    raise
                logger.info('waiting a connection')
                time.sleep(1)
            else:
                break
        if autocommit:
            level = ISOLATION_LEVEL_AUTOCOMMIT
        else:
            level = ISOLATION_LEVEL_REPEATABLE_READ
        conn.set_isolation_level(level)
        if readonly:
            conn.cursor().execute('SET TRANSACTION READ ONLY')
        return conn

    def put_connection(self, connection, close=False):
        """Give ``connection`` back to the pool, forwarding ``close``."""
        pool = self._connpool
        pool.putconn(connection, close=close)

    def close(self):
        """Close every pooled connection and discard the pool, if any."""
        if self._connpool is not None:
            self._connpool.closeall()
            self._connpool = None

    @classmethod
    def create(cls, connection, database_name):
        """Create the database ``database_name`` from template0 and commit.

        The cached database list is invalidated afterwards.
        """
        statement = ('CREATE DATABASE "' + database_name + '" '
                     'TEMPLATE template0 ENCODING \'unicode\'')
        cursor = connection.cursor()
        cursor.execute(statement)
        connection.commit()
        cls._list_cache = None

    def drop(self, connection, database_name):
        """Drop the database ``database_name`` and reset the list cache."""
        statement = 'DROP DATABASE "' + database_name + '"'
        cursor = connection.cursor()
        cursor.execute(statement)
        Database._list_cache = None

    def get_version(self, connection):
        """Return the server version of this database as a tuple of ints.

        The version is read once per database name with
        ``SELECT version()``, parsed with RE_VERSION and then cached.
        """
        cache = self._version_cache
        if self.name not in cache:
            cursor = connection.cursor()
            cursor.execute('SELECT version()')
            version, = cursor.fetchone()
            match = RE_VERSION.search(version)
            cache[self.name] = tuple(int(part) for part in match.groups())
        return cache[self.name]

    @staticmethod
    def dump(database_name):
        """Return the ``pg_dump`` (custom format) output for the database.

        User, host and port are taken from the configured database URI;
        the password, when set, is passed through the PGPASSWORD
        environment variable.  An Exception is raised when pg_dump exits
        with a non-zero status.
        """
        from trytond.tools import exec_command_pipe

        # NOTE(review): the username/host handling below was rebuilt from
        # a copy in which these lines were masked ('******'); verify
        # against the upstream source.
        cmd = ['pg_dump', '--format=c', '--no-owner']
        env = {}
        uri = parse_uri(config.get('database', 'uri'))
        if uri.username:
            cmd.append('--username=' + uri.username)
        if uri.hostname:
            cmd.append('--host=' + uri.hostname)
        if uri.port:
            cmd.append('--port=' + str(uri.port))
        if uri.password:
            # if db_password is set in configuration we should pass
            # an environment variable PGPASSWORD to our subprocess
            # see libpg documentation
            env['PGPASSWORD'] = uri.password
        cmd.append(database_name)

        pipe = exec_command_pipe(*tuple(cmd), env=env)
        pipe.stdin.close()
        data = pipe.stdout.read()
        res = pipe.wait()
        if res:
            raise Exception('Couldn\'t dump database!')
        return data

    @staticmethod
    def restore(database_name, data):
        """Create ``database_name`` and load the dump ``data`` into it.

        The dump is fed to ``pg_restore`` on its standard input, or
        through a temporary file on Windows.  An Exception is raised when
        pg_restore fails or when the restored database does not pass the
        Tryton test.  Returns True on success.
        """
        from trytond.tools import exec_command_pipe

        database = Database().connect()
        connection = database.get_connection(autocommit=True)
        database.create(connection, database_name)
        database.close()

        # NOTE(review): the username/host handling below was rebuilt from
        # a copy in which these lines were masked ('******'); verify
        # against the upstream source.
        cmd = ['pg_restore', '--no-owner']
        env = {}
        uri = parse_uri(config.get('database', 'uri'))
        if uri.username:
            cmd.append('--username=' + uri.username)
        if uri.hostname:
            cmd.append('--host=' + uri.hostname)
        if uri.port:
            cmd.append('--port=' + str(uri.port))
        if uri.password:
            env['PGPASSWORD'] = uri.password
        cmd.append('--dbname=' + database_name)
        args2 = tuple(cmd)

        if os.name == "nt":
            # Fall back to C:\ when TMP is unset or empty
            tmpfile = (os.environ.get('TMP') or 'C:\\') + os.tmpnam()
            with open(tmpfile, 'wb') as fp:
                fp.write(data)
            args2 = list(args2)
            args2.append(' ' + tmpfile)
            args2 = tuple(args2)

        pipe = exec_command_pipe(*args2, env=env)
        if not os.name == "nt":
            pipe.stdin.write(data)
        pipe.stdin.close()
        res = pipe.wait()
        if res:
            raise Exception('Couldn\'t restore database')

        database = Database(database_name).connect()
        cursor = database.get_connection().cursor()
        if not database.test():
            cursor.close()
            database.close()
            raise Exception('Couldn\'t restore database!')
        cursor.close()
        database.close()
        Database._list_cache = None
        return True

    def list(self):
        """Return the names of the Tryton databases of the server.

        A non-empty previous result is reused while it is younger than the
        'timeout' option of the 'session' section.  Otherwise every
        non-template database is probed with ``_test`` through a dedicated
        connection; databases that cannot be reached or probed are left
        out.
        """
        now = time.time()
        timeout = config.getint('session', 'timeout')
        res = Database._list_cache
        if res and abs(Database._list_cache_timestamp - now) < timeout:
            return res

        connection = self.get_connection()
        try:
            cursor = connection.cursor()
            cursor.execute('SELECT datname FROM pg_database '
                           'WHERE datistemplate = false ORDER BY datname')
            res = []
            for db_name, in cursor:
                try:
                    conn = connect(self.dsn(db_name))
                    try:
                        # ``with conn`` only ends the transaction, the
                        # connection itself must be closed explicitly
                        with conn:
                            if self._test(conn):
                                res.append(db_name)
                    finally:
                        conn.close()
                except Exception:
                    continue
        finally:
            # Hand the pooled connection back even if the listing failed
            self.put_connection(connection)

        Database._list_cache = res
        Database._list_cache_timestamp = now
        return res

    def init(self):
        """Create the base schema and register the 'ir' and 'res' modules.

        Executes every non-blank ';'-separated statement of the 'init.sql'
        file located next to this module, then inserts one ir_module row
        per base module and one ir_module_dependency row per dependency
        reported by ``get_module_info``, and commits.  The connection is
        returned to the pool even when a statement fails.
        """
        from trytond.modules import get_module_info

        connection = self.get_connection()
        try:
            cursor = connection.cursor()
            sql_file = os.path.join(os.path.dirname(__file__), 'init.sql')
            with open(sql_file) as fp:
                for line in fp.read().split(';'):
                    if line and not line.isspace():
                        cursor.execute(line)

            for module in ('ir', 'res'):
                state = 'uninstalled'
                if module in ('ir', 'res'):
                    state = 'to install'
                info = get_module_info(module)
                # The id is reserved first so that it can be reused for
                # the dependency rows
                cursor.execute('SELECT NEXTVAL(\'ir_module_id_seq\')')
                module_id = cursor.fetchone()[0]
                cursor.execute(
                    'INSERT INTO ir_module '
                    '(id, create_uid, create_date, name, state) '
                    'VALUES (%s, %s, now(), %s, %s)',
                    (module_id, 0, module, state))
                for dependency in info.get('depends', []):
                    cursor.execute(
                        'INSERT INTO ir_module_dependency '
                        '(create_uid, create_date, module, name) '
                        'VALUES (%s, now(), %s, %s)',
                        (0, module_id, dependency))

            connection.commit()
        finally:
            self.put_connection(connection)

    def test(self):
        """Tell whether this database is a Tryton database.

        Delegates to ``_test`` on a pooled connection, which is returned
        to the pool even when the check raises.
        """
        connection = self.get_connection()
        try:
            return self._test(connection)
        finally:
            self.put_connection(connection)

    @classmethod
    def _test(cls, connection):
        """Return True if at least one core Tryton table is found.

        The lookup is done in ``information_schema.tables`` through
        ``connection``.
        """
        core_tables = (
            'ir_model', 'ir_model_field', 'ir_ui_view', 'ir_ui_menu',
            'res_user', 'res_group', 'ir_module', 'ir_module_dependency',
            'ir_translation', 'ir_lang')
        cursor = connection.cursor()
        cursor.execute(
            'SELECT 1 FROM information_schema.tables '
            'WHERE table_name IN %s',
            (core_tables, ))
        rows = cursor.fetchall()
        return bool(rows)

    def nextid(self, connection, table):
        """Return the next value of the '<table>_id_seq' sequence."""
        cursor = connection.cursor()
        statement = "SELECT NEXTVAL('" + table + "_id_seq')"
        cursor.execute(statement)
        row = cursor.fetchone()
        return row[0]

    def setnextid(self, connection, table, value):
        """Set the '<table>_id_seq' sequence to the integer ``value``."""
        cursor = connection.cursor()
        prefix = "SELECT SETVAL('" + table
        # Only the tail of the statement goes through %-formatting, so the
        # table name is never interpreted as a format string
        suffix = "_id_seq', %d)" % value
        cursor.execute(prefix + suffix)

    def currid(self, connection, table):
        """Return ``last_value`` of the '<table>_id_seq' sequence."""
        cursor = connection.cursor()
        statement = 'SELECT last_value FROM "' + table + '_id_seq"'
        cursor.execute(statement)
        row = cursor.fetchone()
        return row[0]

    def lock(self, connection, table):
        """Lock ``table`` in exclusive mode without waiting (NOWAIT)."""
        statement = 'LOCK "%s" IN EXCLUSIVE MODE NOWAIT' % table
        connection.cursor().execute(statement)

    def has_constraint(self):
        """Always return True (constraint capability flag of this backend)."""
        return True

    def has_multirow_insert(self):
        """Always return True (multi-row INSERT capability flag)."""
        return True

    def get_table_schema(self, connection, table_name):
        """Return the first schema of the search path holding the table.

        Schemas are tried in ``search_path`` order; None is returned when
        no schema contains ``table_name``.
        """
        cursor = connection.cursor()
        query = (
            'SELECT 1 '
            'FROM information_schema.tables '
            'WHERE table_name = %s AND table_schema = %s')
        for schema in self.search_path:
            cursor.execute(query, (table_name, schema))
            if cursor.rowcount:
                return schema
        return None

    @property
    def current_user(self):
        """Name returned by ``SELECT current_user``, queried only once."""
        if self._current_user is not None:
            return self._current_user
        connection = self.get_connection()
        try:
            cursor = connection.cursor()
            cursor.execute('SELECT current_user')
            row = cursor.fetchone()
            self._current_user = row[0]
        finally:
            self.put_connection(connection)
        return self._current_user

    @property
    def search_path(self):
        """List of the schema names of ``SHOW search_path``, computed once.

        Each comma-separated entry is stripped, has its special values
        replaced ('user' by the current user) and is unquoted.
        """
        if self._search_path is not None:
            return self._search_path
        connection = self.get_connection()
        try:
            cursor = connection.cursor()
            cursor.execute('SHOW search_path')
            path, = cursor.fetchone()
            special_values = {
                'user': self.current_user,
            }
            schemas = []
            for entry in path.split(','):
                resolved = replace_special_values(
                    entry.strip(), **special_values)
                schemas.append(unescape_quote(resolved))
            self._search_path = schemas
        finally:
            self.put_connection(connection)
        return self._search_path
Esempio n. 23
0
class Database(DatabaseInterface):
    """Database backend backed by a ``ThreadedConnectionPool`` per name.

    Instances are shared: ``__new__`` returns the existing instance for a
    given process id and database name when there is one.
    """

    # Guards instance creation (__new__) and close().
    _lock = RLock()
    # Instances indexed by process id, then by database name.
    _databases = defaultdict(dict)
    # Connection pool; set on each instance by __new__.
    _connpool = None
    # Result of list() and the time it was computed, both keyed by hostname.
    _list_cache = {}
    _list_cache_timestamp = {}
    # Lazily filled by the properties of the same name (without underscore).
    _search_path = None
    _current_user = None
    # Lazily computed feature flags derived from the server version.
    _has_returning = None
    _has_select_for_skip_locked = None
    # has_proc() results indexed by database name, then function name.
    _has_proc = defaultdict(dict)
    # Text search configuration names indexed by database name, then
    # language code.
    _search_full_text_languages = defaultdict(dict)
    flavor = Flavor(ilike=True)

    # Generic type name -> SQLType used by sql_type() before its fallbacks.
    TYPES_MAPPING = {
        'INTEGER': SQLType('INT4', 'INT4'),
        'BIGINT': SQLType('INT8', 'INT8'),
        'FLOAT': SQLType('FLOAT8', 'FLOAT8'),
        'BLOB': SQLType('BYTEA', 'BYTEA'),
        'DATETIME': SQLType('TIMESTAMP', 'TIMESTAMP(0)'),
        'TIMESTAMP': SQLType('TIMESTAMP', 'TIMESTAMP(6)'),
        'FULLTEXT': SQLType('TSVECTOR', 'TSVECTOR'),
        }

    def __new__(cls, name=_default_name):
        """Return the shared instance for ``name``, creating it if needed.

        Under the class lock, every other registered instance of this
        process that has been idle for more than ``_timeout`` seconds and
        has no connection in use is closed first. A new instance gets its
        own connection pool; failures to build the pool are logged and
        re-raised.
        """
        with cls._lock:
            started = datetime.now()
            registry = cls._databases[os.getpid()]
            # Iterate over a copy: close() removes entries from registry.
            for other in list(registry.values()):
                idle = (started - other._last_use).total_seconds()
                if idle <= _timeout:
                    continue
                if other.name == name or other._connpool._used:
                    continue
                other.close()
            inst = registry.get(name)
            if inst is None:
                minconn = 0 if name == _default_name else _minconn
                inst = DatabaseInterface.__new__(cls, name=name)
                try:
                    inst._connpool = ThreadedConnectionPool(
                        minconn, _maxconn, **cls._connection_params(name),
                        cursor_factory=LoggingCursor)
                except Exception:
                    logger.error(
                        'connection to "%s" failed', name, exc_info=True)
                    raise
                logger.info('connection to "%s" succeeded', name)
                registry[name] = inst
            inst._last_use = datetime.now()
            return inst

    def __init__(self, name=_default_name):
        """Forward ``name`` to the parent initializer."""
        super().__init__(name)

    @classmethod
    def _connection_params(cls, name):
        """Build the connection keyword arguments for database ``name``.

        The configured ``database``/``uri`` is parsed and its path is
        replaced by ``name``; a warning is emitted when the configured URI
        already carries a path other than ``/``.
        """
        uri = parse_uri(config.get('database', 'uri'))
        configured_path = uri.path
        if configured_path and configured_path != '/':
            warnings.warn("The path specified in the URI will be overridden")
        dsn = uri._replace(path=name).geturl()
        return {'dsn': dsn}

    def connect(self):
        """Return ``self``; no work is done here."""
        return self

    def get_connection(self, autocommit=False, readonly=False):
        """Take a connection from the pool and configure its session.

        On ``PoolError`` the request is retried after one second, up to
        the configured ``database``/``retry`` count, unless the pool is
        closed. Any other error is logged and re-raised.

        The isolation level is autocommit when ``autocommit`` is true and
        repeatable read otherwise. ``readonly`` is applied through the
        connection attribute when it exists, else with an explicit
        ``SET TRANSACTION`` statement (non-autocommit only).
        """
        retries = config.getint('database', 'retry')
        for remaining in range(retries, -1, -1):
            try:
                conn = self._connpool.getconn()
            except PoolError:
                # Give up on the last attempt or once the pool is closed.
                if not remaining or self._connpool.closed:
                    raise
                logger.info('waiting a connection')
                time.sleep(1)
            except Exception:
                logger.error(
                    'connection to "%s" failed', self.name, exc_info=True)
                raise
            else:
                break
        # We do not use set_session because psycopg2 < 2.7 and psycopg2cffi
        # change the default_transaction_* attributes which breaks external
        # pooling at the transaction level.
        level = (
            ISOLATION_LEVEL_AUTOCOMMIT if autocommit
            else ISOLATION_LEVEL_REPEATABLE_READ)
        conn.set_isolation_level(level)
        # psycopg2cffi does not have the readonly property
        if hasattr(conn, 'readonly'):
            conn.readonly = readonly
        elif readonly and not autocommit:
            conn.cursor().execute('SET TRANSACTION READ ONLY')
        return conn

    def put_connection(self, connection, close=False):
        """Hand ``connection`` back to the pool, forwarding ``close``."""
        pool = self._connpool
        pool.putconn(connection, close=close)

    def close(self):
        """Close every pooled connection and unregister this instance.

        Runs under the class lock; the instance is removed from the
        registry entry of the current process.
        """
        with self._lock:
            logger.info('disconnection from "%s"', self.name)
            self._connpool.closeall()
            per_process = self._databases[os.getpid()]
            per_process.pop(self.name)

    @classmethod
    def create(cls, connection, database_name, template='template0'):
        """Create ``database_name`` from ``template`` and commit.

        Both names are passed as quoted identifiers. The cached database
        lists are cleared afterwards.
        """
        cursor = connection.cursor()
        statement = SQL(
            "CREATE DATABASE {} TEMPLATE {} ENCODING 'unicode'")
        cursor.execute(statement.format(
                Identifier(database_name),
                Identifier(template)))
        connection.commit()
        cls._list_cache.clear()

    def drop(self, connection, database_name):
        """Drop ``database_name`` and clear the cached database lists."""
        cursor = connection.cursor()
        statement = SQL("DROP DATABASE {}").format(Identifier(database_name))
        cursor.execute(statement)
        type(self)._list_cache.clear()

    def get_version(self, connection):
        """Return the server version as a ``(major, minor, patch)`` tuple.

        ``connection.server_version`` is read as an integer whose last two
        decimal digits are the patch, the two before them the minor, and
        the remainder the major.
        """
        number = int(connection.server_version)
        major = number // 10000
        minor = number % 10000 // 100
        patch = number % 100
        return (major, minor, patch)

    def list(self, hostname=None):
        """Return the names of the databases that pass ``_test``.

        A non-empty result cached for ``hostname`` is reused while it is
        younger than the configured ``session``/``timeout``. Otherwise
        every non-template database is opened and tested; databases whose
        test raises are logged at debug level and skipped.
        """
        klass = self.__class__
        now = time.time()
        timeout = config.getint('session', 'timeout')
        cached = klass._list_cache.get(hostname)
        stamp = klass._list_cache_timestamp.get(hostname, now)
        if cached and abs(stamp - now) < timeout:
            return cached

        connection = self.get_connection()
        try:
            cursor = connection.cursor()
            cursor.execute('SELECT datname FROM pg_database '
                'WHERE datistemplate = false ORDER BY datname')
            names = []
            for db_name, in cursor:
                try:
                    conn = connect(**self._connection_params(db_name))
                    try:
                        with conn:
                            if self._test(conn, hostname=hostname):
                                names.append(db_name)
                    finally:
                        conn.close()
                except Exception:
                    logger.debug(
                        'Test failed for "%s"', db_name, exc_info=True)
        finally:
            self.put_connection(connection)

        klass._list_cache[hostname] = names
        klass._list_cache_timestamp[hostname] = now
        return names

    def init(self):
        """Run ``init.sql`` and register the base modules.

        Each non-blank ``;``-separated statement of ``init.sql`` (located
        next to this file) is executed, then the ``ir`` and ``res``
        modules and their dependencies are inserted before committing.
        """
        from trytond.modules import get_module_info

        core_modules = ('ir', 'res')
        connection = self.get_connection()
        try:
            cursor = connection.cursor()
            sql_file = os.path.join(os.path.dirname(__file__), 'init.sql')
            with open(sql_file) as fp:
                for statement in fp.read().split(';'):
                    if statement and not statement.isspace():
                        cursor.execute(statement)

            for module in ('ir', 'res'):
                if module in core_modules:
                    state = 'to activate'
                else:
                    state = 'not activated'
                info = get_module_info(module)
                cursor.execute('SELECT NEXTVAL(\'ir_module_id_seq\')')
                module_id = cursor.fetchone()[0]
                cursor.execute('INSERT INTO ir_module '
                    '(id, create_uid, create_date, name, state) '
                    'VALUES (%s, %s, now(), %s, %s)',
                    (module_id, 0, module, state))
                for dependency in info.get('depends', []):
                    cursor.execute('INSERT INTO ir_module_dependency '
                        '(create_uid, create_date, module, name) '
                        'VALUES (%s, now(), %s, %s)',
                        (0, module_id, dependency))

            connection.commit()
        finally:
            self.put_connection(connection)

    def test(self, hostname=None):
        """Return the result of ``_test`` on a pooled connection.

        ``False`` is returned (and the failure logged at debug level) when
        no connection can be obtained.
        """
        try:
            conn = self.get_connection()
        except Exception:
            logger.debug('Test failed for "%s"', self.name, exc_info=True)
            return False
        try:
            outcome = self._test(conn, hostname=hostname)
        finally:
            self.put_connection(conn)
        return outcome

    @classmethod
    def _test(cls, connection, hostname=None):
        """Tell whether ``connection`` points at an initialised database.

        All the expected base tables must be listed in
        ``information_schema.tables``. When ``hostname`` is given and
        ``ir_configuration`` stores at least one non-empty hostname,
        ``hostname`` must be one of them; a ``ProgrammingError`` raised by
        that check is ignored.
        """
        cursor = connection.cursor()
        required = ('ir_model', 'ir_model_field', 'ir_ui_view', 'ir_ui_menu',
            'res_user', 'res_group', 'ir_module', 'ir_module_dependency',
            'ir_translation', 'ir_lang', 'ir_configuration')
        cursor.execute('SELECT table_name FROM information_schema.tables '
            'WHERE table_name IN %s', (required,))
        found = cursor.fetchall()
        if len(found) != len(required):
            return False
        if not hostname:
            return True
        try:
            cursor.execute(
                'SELECT hostname FROM ir_configuration')
            allowed = {h for h, in cursor if h}
            if allowed and hostname not in allowed:
                return False
        except ProgrammingError:
            pass
        return True

    def nextid(self, connection, table):
        """Advance the ``<table>_id_seq`` sequence and return its value."""
        cur = connection.cursor()
        sequence = table + '_id_seq'
        cur.execute("SELECT NEXTVAL(%s)", (sequence,))
        row = cur.fetchone()
        return row[0]

    def setnextid(self, connection, table, value):
        """Set the ``<table>_id_seq`` sequence to ``value``."""
        cur = connection.cursor()
        sequence = table + '_id_seq'
        cur.execute("SELECT SETVAL(%s, %s)", (sequence, value))

    def currid(self, connection, table):
        """Return the ``last_value`` of the ``<table>_id_seq`` sequence."""
        cursor = connection.cursor()
        sequence = Identifier(table + '_id_seq')
        cursor.execute(SQL("SELECT last_value FROM {}").format(sequence))
        row = cursor.fetchone()
        return row[0]

    def lock(self, connection, table):
        """Lock ``table`` in exclusive mode, failing instead of waiting."""
        cursor = connection.cursor()
        statement = SQL('LOCK {} IN EXCLUSIVE MODE NOWAIT')
        cursor.execute(statement.format(Identifier(table)))

    def lock_id(self, id, timeout=None):
        """Return an advisory lock expression for ``id``.

        A falsy ``timeout`` selects ``TryAdvisoryLock``; any other value
        selects ``AdvisoryLock``. The timeout value itself is not passed
        on to the expression.
        """
        lock_type = AdvisoryLock if timeout else TryAdvisoryLock
        return lock_type(id)

    def has_constraint(self, constraint):
        """Report ``constraint`` as supported (always True)."""
        return True

    def has_multirow_insert(self):
        """Report that multi-row INSERT is supported (always True)."""
        return True

    def get_table_schema(self, connection, table_name):
        """Return the first schema of ``search_path`` holding ``table_name``.

        Schemas are probed in ``search_path`` order through
        ``information_schema.tables``; ``None`` is returned when no schema
        contains the table.
        """
        cursor = connection.cursor()
        query = ('SELECT 1 '
            'FROM information_schema.tables '
            'WHERE table_name = %s AND table_schema = %s')
        for candidate in self.search_path:
            cursor.execute(query, (table_name, candidate))
            if cursor.rowcount:
                return candidate
        return None

    @property
    def current_user(self):
        """Name returned by ``SELECT current_user``, cached after first use.

        The value is stored in ``_current_user``; the connection used for
        the lookup is always handed back through ``put_connection``.
        """
        if self._current_user is not None:
            return self._current_user
        conn = self.get_connection()
        try:
            cur = conn.cursor()
            cur.execute('SELECT current_user')
            self._current_user = cur.fetchone()[0]
        finally:
            self.put_connection(conn)
        return self._current_user

    @property
    def search_path(self):
        """Schemas listed by ``SHOW search_path``, cached after first use.

        Each comma-separated entry is stripped, has its special values
        substituted (``user`` maps to ``current_user``) and is then
        unquoted.
        """
        if self._search_path is not None:
            return self._search_path
        conn = self.get_connection()
        try:
            cur = conn.cursor()
            cur.execute('SHOW search_path')
            raw_path, = cur.fetchone()
            substitutions = {
                'user': self.current_user,
            }
            schemas = []
            for entry in raw_path.split(','):
                resolved = replace_special_values(
                    entry.strip(), **substitutions)
                schemas.append(unescape_quote(resolved))
            self._search_path = schemas
        finally:
            self.put_connection(conn)
        return self._search_path

    def has_returning(self):
        """Tell whether the server supports RETURNING, caching the answer.

        The check compares ``get_version`` of a pooled connection with
        ``(8, 2)`` and stores the result in ``_has_returning``.
        """
        if self._has_returning is not None:
            return self._has_returning
        conn = self.get_connection()
        try:
            # RETURNING clause is available since PostgreSQL 8.2
            server_version = self.get_version(conn)
            self._has_returning = server_version >= (8, 2)
        finally:
            self.put_connection(conn)
        return self._has_returning

    def has_select_for(self):
        """Report that ``SELECT ... FOR`` is supported (always True)."""
        return True

    def get_select_for_skip_locked(self):
        """Return ``ForSkipLocked`` when the server supports it, else ``For``.

        Support is detected once by comparing ``get_version`` with
        ``(9, 5)`` and cached in ``_has_select_for_skip_locked``.
        """
        if self._has_select_for_skip_locked is None:
            conn = self.get_connection()
            try:
                # SKIP LOCKED clause is available since PostgreSQL 9.5
                server_version = self.get_version(conn)
                self._has_select_for_skip_locked = server_version >= (9, 5)
            finally:
                self.put_connection(conn)
        return ForSkipLocked if self._has_select_for_skip_locked else For

    def has_window_functions(self):
        """Report that window functions are supported (always True)."""
        return True

    @classmethod
    def has_sequence(cls):
        return True

    def has_proc(self, name):
        """Tell whether ``pg_proc`` lists a function called ``name``.

        Answers are cached per database name in ``_has_proc``; the
        catalogue is only queried on a cache miss.
        """
        known = self._has_proc[self.name]
        if name in known:
            return known[name]
        conn = self.get_connection()
        try:
            cur = conn.cursor()
            cur.execute(
                "SELECT 1 FROM pg_proc WHERE proname=%s",
                (name,))
            found = bool(cur.rowcount)
        finally:
            self.put_connection(conn)
        self._has_proc[self.name][name] = found
        return found

    def has_unaccent(self):
        """Tell whether the function named by ``Unaccent`` exists."""
        function_name = Unaccent._function
        return self.has_proc(function_name)

    def has_similarity(self):
        """Tell whether the function named by ``Similarity`` exists."""
        function_name = Similarity._function
        return self.has_proc(function_name)

    def similarity(self, column, value):
        """Return a ``Similarity`` expression of ``column`` and ``value``."""
        expression = Similarity(column, value)
        return expression

    def has_search_full_text(self):
        """Report that full text search is supported (always True)."""
        return True

    def _search_full_text_language(self, language):
        """Return the text search configuration name for ``language``.

        The name is read from ``ir_lang.pg_text_search`` (``'simple'``
        when that column is NULL) for the row whose code is ``language``,
        and cached per database name.
        """
        cache = self._search_full_text_languages[self.name]
        if language in cache:
            return cache[language]
        lang = Table('ir_lang')
        conn = self.get_connection()
        try:
            cur = conn.cursor()
            query = lang.select(
                Coalesce(lang.pg_text_search, 'simple'),
                where=lang.code == language,
                limit=1)
            cur.execute(*query)
            config_name, = cur.fetchone()
        finally:
            self.put_connection(conn)
        cache[language] = config_name
        return config_name

    def format_full_text(self, *documents, language=None):
        """Combine ``documents`` into one weighted text search vector.

        With more than one document, the first quarter (at least one) is
        weighted ``A``, the next ``B``, then ``C``, and everything after
        ``D``; a single document gets no weight. Falsy documents are
        skipped. The ``'simple'`` configuration is used unless
        ``language`` resolves to a configuration name. Returns ``None``
        when no document contributes.
        """
        group = max(len(documents) // 4, 1)
        if len(documents) > 1:
            weights = chain(
                ['A'] * group, ['B'] * group, ['C'] * group, repeat('D'))
        else:
            weights = [None]
        combined = None
        config_name = (
            self._search_full_text_language(language) if language else None)
        for document, weight in zip(documents, weights):
            if not document:
                continue
            vector = ToTsvector(config_name or 'simple', document)
            if weight:
                vector = Setweight(vector, weight)
            combined = vector if combined is None else Concat(combined, vector)
        return combined

    def format_full_text_query(self, query, language=None):
        """Wrap ``query`` in a text search query unless it already is one.

        ``WebsearchToTsQuery`` is used when the server version is at
        least ``(11, 0)`` and ``PlainToTsQuery`` otherwise. When
        ``language`` is given its configuration name is resolved and
        passed as first argument.
        """
        conn = self.get_connection()
        try:
            server_version = self.get_version(conn)
        finally:
            self.put_connection(conn)
        ToTsQuery = (
            WebsearchToTsQuery if server_version >= (11, 0)
            else PlainToTsQuery)
        leading = []
        if language:
            # Resolved even when query is already a TsQuery.
            leading.append(self._search_full_text_language(language))
        if not isinstance(query, TsQuery):
            query = ToTsQuery(*leading, query)
        return query

    def search_full_text(self, document, query):
        """Return a ``Match`` expression of ``document`` against ``query``."""
        expression = Match(document, query)
        return expression

    def rank_full_text(self, document, query, normalize=None):
        """Return a ``TsRank`` expression for ``document`` and ``query``.

        ``normalize`` is an optional iterable of option names; each known
        name contributes one bit to the integer passed as third argument,
        unknown names contribute nothing.
        """
        # TODO: weights and cover density
        bits = {
            'document log': 1,
            'document': 2,
            'mean': 4,
            'word': 8,
            'word log': 16,
            'rank': 32,
            }
        norm_int = 0
        for option in normalize or ():
            norm_int |= bits.get(option, 0)
        return TsRank(document, query, norm_int)

    def sql_type(self, type_):
        """Return the ``SQLType`` to use for the generic type ``type_``.

        ``TYPES_MAPPING`` takes precedence. Otherwise any type starting
        with ``VARCHAR`` gets ``VARCHAR`` as base with ``type_`` kept as
        full type, and every other type maps to itself.
        """
        mapped = self.TYPES_MAPPING.get(type_)
        if mapped is not None:
            return mapped
        base = 'VARCHAR' if type_.startswith('VARCHAR') else type_
        return SQLType(base, type_)

    def sql_format(self, type_, value):
        """Wrap non-``None`` ``BLOB`` values in ``Binary``.

        Every other value is returned unchanged.
        """
        if type_ == 'BLOB' and value is not None:
            return Binary(value)
        return value

    def unaccent(self, value):
        """Wrap ``value`` in ``Unaccent`` when the function is available.

        ``value`` is returned unchanged when ``has_unaccent`` is false.
        """
        if not self.has_unaccent():
            return value
        return Unaccent(value)

    def sequence_exist(self, connection, name):
        """Tell whether a schema of ``search_path`` has sequence ``name``.

        Schemas are probed in ``search_path`` order through
        ``information_schema.sequences``.
        """
        cursor = connection.cursor()
        query = ('SELECT 1 '
            'FROM information_schema.sequences '
            'WHERE sequence_name = %s AND sequence_schema = %s')
        for candidate in self.search_path:
            cursor.execute(query, (name, candidate))
            if cursor.rowcount:
                return True
        return False

    def sequence_create(
            self, connection, name, number_increment=1, start_value=1):
        """Create sequence ``name`` with the given increment and start."""
        cursor = connection.cursor()

        statement = SQL(
            "CREATE SEQUENCE {} INCREMENT BY %s START WITH %s").format(
                Identifier(name))
        cursor.execute(statement, (number_increment, start_value))

    def sequence_update(
            self, connection, name, number_increment=1, start_value=1):
        """Change the increment of sequence ``name`` and restart it."""
        cursor = connection.cursor()
        statement = SQL(
            "ALTER SEQUENCE {} INCREMENT BY %s RESTART WITH %s").format(
                Identifier(name))
        cursor.execute(statement, (number_increment, start_value))

    def sequence_rename(self, connection, old_name, new_name):
        """Rename sequence ``old_name`` to ``new_name`` when it is safe.

        Nothing happens unless ``old_name`` exists and ``new_name`` does
        not.
        """
        cursor = connection.cursor()
        if not self.sequence_exist(connection, old_name):
            return
        if self.sequence_exist(connection, new_name):
            return
        statement = SQL("ALTER TABLE {} RENAME TO {}").format(
            Identifier(old_name),
            Identifier(new_name))
        cursor.execute(statement)

    def sequence_delete(self, connection, name):
        """Drop sequence ``name``."""
        cursor = connection.cursor()
        statement = SQL("DROP SEQUENCE {}").format(Identifier(name))
        cursor.execute(statement)

    def sequence_next_number(self, connection, name):
        """Return the number the sequence ``name`` will deliver next.

        The sequence itself is not advanced: the value is ``last_value``
        when the sequence has not been called yet, otherwise
        ``last_value`` plus the increment.

        For server versions from ``(10, 0)`` the increment is read from
        ``pg_sequences``; for older versions it is read from the
        ``increment_by`` column of the sequence relation itself.
        """
        cursor = connection.cursor()
        version = self.get_version(connection)
        if version >= (10, 0):
            cursor.execute(
                'SELECT increment_by '
                'FROM pg_sequences '
                'WHERE sequencename=%s',
                (name,))
            increment, = cursor.fetchone()
            cursor.execute(
                SQL(
                    'SELECT CASE WHEN NOT is_called THEN last_value '
                    'ELSE last_value + %s '
                    'END '
                    'FROM {}').format(Identifier(name)),
                (increment,))
        else:
            # The placeholder is positional, so the identifier must be
            # passed positionally too: a keyword argument leaves ``{}``
            # without a value and format() fails.
            cursor.execute(
                SQL(
                    'SELECT CASE WHEN NOT is_called THEN last_value '
                    'ELSE last_value + increment_by '
                    'END '
                    'FROM {}').format(Identifier(name)))
        return cursor.fetchone()[0]

    def has_channel(self):
        """Report that channels are supported (always True)."""
        return True

    def json_get(self, column, key=None):
        """Return ``column`` cast to ``jsonb``, narrowed to ``key`` if set.

        A falsy ``key`` yields the cast column itself; otherwise the cast
        column is wrapped in ``JSONBExtractPath`` with ``key``.
        """
        as_jsonb = Cast(column, 'jsonb')
        if not key:
            return as_jsonb
        return JSONBExtractPath(as_jsonb, key)

    def json_key_exists(self, column, key):
        """Return a ``JSONKeyExists`` test of ``key`` on the jsonb column."""
        as_jsonb = Cast(column, 'jsonb')
        return JSONKeyExists(as_jsonb, key)

    def json_any_keys_exist(self, column, keys):
        """Return a ``JSONAnyKeyExist`` test of ``keys`` on the jsonb column."""
        as_jsonb = Cast(column, 'jsonb')
        return JSONAnyKeyExist(as_jsonb, keys)

    def json_all_keys_exist(self, column, keys):
        """Return a ``JSONAllKeyExist`` test of ``keys`` on the jsonb column."""
        as_jsonb = Cast(column, 'jsonb')
        return JSONAllKeyExist(as_jsonb, keys)

    def json_contains(self, column, json):
        """Return a ``JSONContains`` test with both operands cast to jsonb."""
        container = Cast(column, 'jsonb')
        contained = Cast(json, 'jsonb')
        return JSONContains(container, contained)
Esempio n. 24
0
class Database(DatabaseInterface):
    """Database backend storing each database in a ``.sqlite`` file.

    The special name ``:memory:`` designates an in-memory database whose
    instance is shared through ``_local``.
    """

    # Thread-local holder of the ``memory_database`` instance.
    _local = threading.local()
    # Connection opened lazily by connect(); None while disconnected.
    _conn = None
    flavor = Flavor(paramstyle='qmark', function_mapping=MAPPING)

    def __new__(cls, database_name=':memory:'):
        """Reuse the in-memory instance stored in ``_local`` when possible.

        Any other request goes to ``DatabaseInterface.__new__``.
        """
        if database_name == ':memory:':
            shared = getattr(cls._local, 'memory_database', None)
            if shared:
                return shared
        return DatabaseInterface.__new__(cls, database_name=database_name)

    def __init__(self, database_name=':memory:'):
        """Initialise the parent and remember the in-memory instance."""
        super(Database, self).__init__(database_name=database_name)
        is_memory = database_name == ':memory:'
        if is_memory:
            # Stored so that __new__ can hand the same instance back.
            Database._local.memory_database = self

    def connect(self):
        """Open the SQLite connection if needed and return ``self``.

        File databases live in the configured ``database``/``path``
        directory as ``<database_name>.sqlite``; ``IOError`` is raised
        when that file is missing, even if a connection is already open.
        A new connection gets the helper SQL functions registered and
        foreign keys enabled.
        """
        if self.database_name == ':memory:':
            path = ':memory:'
        else:
            db_filename = self.database_name + '.sqlite'
            path = os.path.join(config.get('database', 'path'), db_filename)
            if not os.path.isfile(path):
                raise IOError('Database "%s" doesn\'t exist!' % db_filename)
        if self._conn is not None:
            return self
        self._conn = sqlite.connect(path,
                                    detect_types=sqlite.PARSE_DECLTYPES
                                    | sqlite.PARSE_COLNAMES)
        register = self._conn.create_function
        register('extract', 2, SQLiteExtract.extract)
        register('date_trunc', 2, date_trunc)
        register('split_part', 3, split_part)
        register('position', 2, SQLitePosition.position)
        # overlay is registered for both its 3 and 4 argument forms.
        register('overlay', 3, SQLiteOverlay.overlay)
        register('overlay', 4, SQLiteOverlay.overlay)
        if sqlite.sqlite_version_info < (3, 3, 14):
            register('replace', 3, replace)
        register('now', 0, now)
        register('sign', 1, sign)
        register('greatest', -1, max)
        register('least', -1, min)
        self._conn.execute('PRAGMA foreign_keys = ON')
        return self

    def cursor(self, autocommit=False, readonly=False):
        """Return a ``Cursor`` on the connection, connecting if needed.

        The connection isolation level is set to ``None`` when
        ``autocommit`` is true and to ``'IMMEDIATE'`` otherwise.
        ``readonly`` is accepted but not used.
        """
        if self._conn is None:
            self.connect()
        self._conn.isolation_level = None if autocommit else 'IMMEDIATE'
        return Cursor(self._conn, self.database_name)

    def close(self):
        """Forget the connection of a file database.

        The in-memory database keeps its connection.
        """
        if self.database_name != ':memory:' and self._conn is not None:
            self._conn = None

    @staticmethod
    def create(cursor, database_name):
        """Create the SQLite database ``database_name``.

        File databases are created in the configured ``database``/``path``
        directory as ``<database_name>.sqlite``. A name containing the
        path separator is ignored silently. ``cursor`` is not used.
        """
        if database_name == ':memory:':
            path = ':memory:'
        else:
            if os.sep in database_name:
                return
            path = os.path.join(config.get('database', 'path'),
                                database_name + '.sqlite')
        conn = sqlite.connect(path)
        try:
            with conn:
                conn.cursor().close()
        finally:
            # Using the connection as a context manager only commits or
            # rolls back; it must still be closed explicitly.
            conn.close()

    @classmethod
    def drop(cls, cursor, database_name):
        """Drop the SQLite database ``database_name``.

        The in-memory database only loses its connection. File databases
        are removed from the configured ``database``/``path`` directory;
        a name containing the path separator is ignored silently.
        ``cursor`` is not used.
        """
        if database_name == ':memory:':
            cls._local.memory_database._conn = None
            return
        if os.sep in database_name:
            return
        directory = config.get('database', 'path')
        os.remove(os.path.join(directory, database_name + '.sqlite'))

    @staticmethod
    def dump(database_name):
        """Return the raw bytes of the ``<database_name>.sqlite`` file.

        Raises ``Exception`` for the in-memory database and for names
        containing the path separator.
        """
        if database_name == ':memory:':
            raise Exception('Unable to dump memory database!')
        if os.sep in database_name:
            raise Exception('Wrong database name!')
        directory = config.get('database', 'path')
        path = os.path.join(directory, database_name + '.sqlite')
        with open(path, 'rb') as file_p:
            return file_p.read()

    @staticmethod
    def restore(database_name, data):
        """Write ``data`` as the ``<database_name>.sqlite`` file.

        Raises ``Exception`` for the in-memory database, for names
        containing the path separator, and when the file already exists.
        """
        if database_name == ':memory:':
            raise Exception('Unable to restore memory database!')
        if os.sep in database_name:
            raise Exception('Wrong database name!')
        directory = config.get('database', 'path')
        path = os.path.join(directory, database_name + '.sqlite')
        if os.path.isfile(path):
            raise Exception('Database already exists!')
        with open(path, 'wb') as file_p:
            file_p.write(data)

    @staticmethod
    def list(cursor):
        """Return the names of the databases whose cursor test succeeds.

        Candidates are ``:memory:`` plus every ``*.sqlite`` file of the
        configured ``database``/``path`` directory (an unreadable
        directory contributes nothing). Candidates that cannot be
        instantiated are skipped. ``cursor`` is not used.
        """
        candidates = [':memory:']
        try:
            candidates.extend(os.listdir(config.get('database', 'path')))
        except OSError:
            pass
        found = []
        for entry in candidates:
            if entry == ':memory:':
                db_name = ':memory:'
            elif entry.endswith('.sqlite'):
                # Strip the 7 characters of the '.sqlite' suffix.
                db_name = entry[:-7]
            else:
                continue
            try:
                database = Database(db_name)
            except Exception:
                continue
            probe = database.cursor()
            if probe.test():
                found.append(db_name)
            probe.close()
        return found

    @staticmethod
    def init(cursor):
        """Run ``init.sql`` and register the base modules through ``cursor``.

        Each non-blank ``;``-separated statement of ``init.sql`` (located
        next to this file) is executed, then ``ir``, ``res`` and
        ``webdav`` are inserted in ``ir_module`` together with their
        dependencies. Only ``ir`` and ``res`` are flagged ``to install``.
        """
        from trytond.modules import get_module_info
        sql_file = os.path.join(os.path.dirname(__file__), 'init.sql')
        with open(sql_file) as fp:
            for statement in fp.read().split(';'):
                if statement and not statement.isspace():
                    cursor.execute(statement)

        ir_module = Table('ir_module')
        ir_module_dependency = Table('ir_module_dependency')
        module_columns = [
            ir_module.create_uid, ir_module.create_date, ir_module.name,
            ir_module.state
        ]
        dependency_columns = [
            ir_module_dependency.create_uid,
            ir_module_dependency.create_date,
            ir_module_dependency.module, ir_module_dependency.name
        ]
        for module in ('ir', 'res', 'webdav'):
            state = 'to install' if module in ('ir', 'res') else 'uninstalled'
            info = get_module_info(module)
            cursor.execute(*ir_module.insert(
                module_columns, [[0, CurrentTimestamp(), module, state]]))
            cursor.execute('SELECT last_insert_rowid()')
            module_id, = cursor.fetchone()
            for dependency in info.get('depends', []):
                cursor.execute(*ir_module_dependency.insert(
                    dependency_columns,
                    [[0, CurrentTimestamp(), module_id, dependency]]))
Esempio n. 25
0
class Database(DatabaseInterface):
    """SQLite implementation of ``DatabaseInterface``.

    Each instance lazily opens one connection (see ``connect``) and keeps
    reusing it.  The ``:memory:`` database instance is remembered in the
    thread-local ``_local`` so that repeated ``Database(':memory:')`` calls
    in a thread return the same object.
    """

    # Thread-local storage holding the ':memory:' database instance.
    _local = threading.local()
    # Connection created by connect(); None while not connected.
    _conn = None
    flavor = Flavor(paramstyle='qmark',
                    function_mapping=MAPPING,
                    null_ordering=False)
    IN_MAX = 200

    # Column types that are stored under a different name in SQLite.
    TYPES_MAPPING = {
        'DATETIME': SQLType('TIMESTAMP', 'TIMESTAMP'),
        'BIGINT': SQLType('INTEGER', 'INTEGER'),
        'BOOL': SQLType('BOOLEAN', 'BOOLEAN'),
        'FULLTEXT': SQLType('TEXT', 'TEXT'),
    }

    def __new__(cls, name=_default_name):
        """Return the thread's ':memory:' instance if one already exists."""
        if (name == ':memory:'
                and getattr(cls._local, 'memory_database', None)):
            return cls._local.memory_database
        return DatabaseInterface.__new__(cls, name=name)

    def __init__(self, name=_default_name):
        """Initialise the database and register the ':memory:' instance."""
        super(Database, self).__init__(name=name)
        if name == ':memory:':
            Database._local.memory_database = self

    def connect(self):
        """Open the connection if needed and return ``self``.

        On first call the SQL functions used by the backend are registered
        on the new connection and foreign key enforcement is switched on.
        Calling it again while connected is a no-op.
        """
        if self._conn is not None:
            return self
        self._conn = sqlite.connect(self._make_uri(),
                                    uri=True,
                                    detect_types=sqlite.PARSE_DECLTYPES
                                    | sqlite.PARSE_COLNAMES,
                                    factory=SQLiteConnection)
        self._conn.create_function('extract', 2, SQLiteExtract.extract)
        self._conn.create_function('date_trunc', 2, date_trunc)
        self._conn.create_function('split_part', 3, split_part)
        self._conn.create_function('to_char', 2, to_char)
        # Only register the Python 'replace' on SQLite older than 3.3.14.
        if sqlite.sqlite_version_info < (3, 3, 14):
            self._conn.create_function('replace', 3, replace)
        self._conn.create_function('now', 0, now)
        self._conn.create_function('greatest', -1, greatest)
        self._conn.create_function('least', -1, least)

        # Mathematical functions
        self._conn.create_function('cbrt', 1, cbrt)
        self._conn.create_function('ceil', 1, math.ceil)
        self._conn.create_function('degrees', 1, math.degrees)
        self._conn.create_function('div', 2, div)
        self._conn.create_function('exp', 1, math.exp)
        self._conn.create_function('floor', 1, math.floor)
        self._conn.create_function('ln', 1, math.log)
        self._conn.create_function('log', 1, math.log10)
        self._conn.create_function('mod', 2, math.fmod)
        self._conn.create_function('pi', 0, lambda: math.pi)
        self._conn.create_function('power', 2, math.pow)
        self._conn.create_function('radians', 1, math.radians)
        self._conn.create_function('sign', 1, sign)
        self._conn.create_function('sqrt', 1, math.sqrt)
        # 'trunc' is registered twice: once per supported argument count.
        self._conn.create_function('trunc', 1, math.trunc)
        self._conn.create_function('trunc', 2, trunc)

        # Trigonometric functions
        self._conn.create_function('acos', 1, math.acos)
        self._conn.create_function('asin', 1, math.asin)
        self._conn.create_function('atan', 1, math.atan)
        self._conn.create_function('atan2', 2, math.atan2)
        self._conn.create_function('cos', 1, math.cos)
        self._conn.create_function(
            'cot', 1, lambda x: 1 / math.tan(x) if x else math.inf)
        self._conn.create_function('sin', 1, math.sin)
        self._conn.create_function('tan', 1, math.tan)

        # Random functions
        self._conn.create_function('random', 0, random.random)
        self._conn.create_function('setseed', 1, random.seed)

        # String functions
        self._conn.create_function('overlay', 3, SQLiteOverlay.overlay)
        self._conn.create_function('overlay', 4, SQLiteOverlay.overlay)
        self._conn.create_function('position', 2, SQLitePosition.position)

        # Forward executed statements to the logger when debugging and the
        # connection supports tracing.
        if (hasattr(self._conn, 'set_trace_callback')
                and logger.isEnabledFor(logging.DEBUG)):
            self._conn.set_trace_callback(logger.debug)
        self._conn.execute('PRAGMA foreign_keys = ON')
        return self

    def _make_uri(self):
        """Build the ``file:`` URI handed to ``sqlite.connect``.

        The URI is derived from the configured ``database.uri``; its path is
        replaced by either the in-memory target or the ``<name>.sqlite``
        file under the configured ``database.path``.

        :raises IOError: if the database file does not exist.
        """
        uri = config.get('database', 'uri')
        base_uri = parse_uri(uri)
        if base_uri.path and base_uri.path != '/':
            warnings.warn("The path specified in the URI will be overridden")

        if self.name == ':memory:':
            query_string = urllib.parse.parse_qs(base_uri.query)
            query_string['mode'] = 'memory'
            query = urllib.parse.urlencode(query_string, doseq=True)
            db_uri = base_uri._replace(netloc='', path='/', query=query)
        else:
            db_path = safe_join(config.get('database', 'path'),
                                self.name + '.sqlite')
            if not os.path.isfile(db_path):
                raise IOError("Database '%s' doesn't exist!" % db_path)
            db_uri = base_uri._replace(path=db_path)

        # Use unparse before replacing sqlite with file because SQLite accepts
        # a relative path URI like file:db/test.sqlite which doesn't conform to
        # RFC8089 which urllib follows and enforces when the scheme is 'file'
        db_uri = urllib.parse.urlunparse(db_uri)
        return db_uri.replace('sqlite', 'file', 1)

    def get_connection(self, autocommit=False, readonly=False):
        """Return the shared connection, connecting first if necessary.

        :param autocommit: when true the isolation level is set to ``None``,
            otherwise to ``'IMMEDIATE'``.
        :param readonly: accepted but not used by this backend.
        """
        if self._conn is None:
            self.connect()
        if autocommit:
            self._conn.isolation_level = None
        else:
            self._conn.isolation_level = 'IMMEDIATE'
        return self._conn

    def put_connection(self, connection=None, close=False):
        """Do nothing: the single connection stays owned by the instance."""
        pass

    def close(self):
        """Forget the connection, except for the ':memory:' database."""
        if self.name == ':memory:':
            return
        if self._conn is None:
            return
        self._conn = None

    @classmethod
    def create(cls, connection, database_name):
        """Create the SQLite file for ``database_name``.

        Names containing ``os.sep`` are silently ignored.

        :param connection: not used by this backend.
        :param database_name: ':memory:' or the database name without the
            ``.sqlite`` suffix.
        """
        if database_name == ':memory:':
            path = ':memory:'
        else:
            if os.sep in database_name:
                return
            path = os.path.join(config.get('database', 'path'),
                                database_name + '.sqlite')
        conn = sqlite.connect(path)
        try:
            # The connection context manager only commits or rolls back; it
            # never closes the connection, hence the explicit close below.
            with conn:
                cursor = conn.cursor()
                cursor.close()
        finally:
            conn.close()

    def drop(self, connection, database_name):
        """Remove the database named ``database_name``.

        For ':memory:' the connection of the thread's memory database is
        dropped; otherwise the ``.sqlite`` file is deleted.  Names
        containing ``os.sep`` are silently ignored.

        :param connection: not used by this backend.
        """
        if database_name == ':memory:':
            self._local.memory_database._conn = None
            return
        if os.sep in database_name:
            return
        os.remove(
            os.path.join(config.get('database', 'path'),
                         database_name + '.sqlite'))

    def _kill_session_query(self, database_name):
        """Return a trivial query; ``database_name`` is ignored."""
        # JMO : not necessary
        return 'select 1'

    def list(self, hostname=None):
        """Return the names of the databases that pass ``test``.

        Candidates are ':memory:' plus every ``*.sqlite`` file in the
        configured ``database.path``.  Databases that fail to connect are
        logged at debug level and skipped.

        :param hostname: forwarded to ``test``.
        """
        res = []
        listdir = [':memory:']
        try:
            listdir += os.listdir(config.get('database', 'path'))
        except OSError:
            # An unreadable directory just means no file database.
            pass
        for db_file in listdir:
            if db_file.endswith('.sqlite') or db_file == ':memory:':
                if db_file == ':memory:':
                    db_name = ':memory:'
                else:
                    # Strip the '.sqlite' suffix (7 characters).
                    db_name = db_file[:-7]
                try:
                    database = Database(db_name).connect()
                except Exception:
                    logger.debug('Test failed for "%s"',
                                 db_name,
                                 exc_info=True)
                    continue
                if database.test(hostname=hostname):
                    res.append(db_name)
                database.close()
        return res

    def init(self):
        """Create the base schema and register the 'ir' and 'res' modules.

        Executes every statement of the ``init.sql`` file located next to
        this module, inserts both modules in ``ir_module`` with the state
        'to activate' and stores their dependencies in
        ``ir_module_dependency``, then commits.
        """
        from trytond.modules import get_module_info
        Flavor.set(self.flavor)
        with self.get_connection() as conn:
            cursor = conn.cursor()
            sql_file = os.path.join(os.path.dirname(__file__), 'init.sql')
            with open(sql_file) as fp:
                # Statements are separated with a plain split on ';';
                # empty or whitespace-only chunks are skipped.
                for line in fp.read().split(';'):
                    if line.strip():
                        cursor.execute(line)

            ir_module = Table('ir_module')
            ir_module_dependency = Table('ir_module_dependency')
            # Every module listed here must be activated.
            state = 'to activate'
            for module in ('ir', 'res'):
                info = get_module_info(module)
                insert = ir_module.insert([
                    ir_module.create_uid, ir_module.create_date,
                    ir_module.name, ir_module.state
                ], [[0, CurrentTimestamp(), module, state]])
                cursor.execute(*insert)
                # Id of the module row just inserted.
                cursor.execute('SELECT last_insert_rowid()')
                module_id, = cursor.fetchone()
                for dependency in info.get('depends', []):
                    insert = ir_module_dependency.insert([
                        ir_module_dependency.create_uid,
                        ir_module_dependency.create_date,
                        ir_module_dependency.module,
                        ir_module_dependency.name,
                    ], [[0, CurrentTimestamp(), module_id, dependency]])
                    cursor.execute(*insert)
            conn.commit()

    def test(self, hostname=None):
        """Tell whether the connected database has the expected tables.

        Uses ``self._conn`` directly, so the database must already be
        connected.

        :param hostname: when given and ``ir_configuration`` stores at least
            one non-empty hostname, it must be one of them.
        :return: ``True`` if all the tables exist and the hostname check
            passes, ``False`` otherwise (including on query errors).
        """
        Flavor.set(self.flavor)
        tables = [
            'ir_model', 'ir_model_field', 'ir_ui_view', 'ir_ui_menu',
            'res_user', 'res_group', 'ir_module', 'ir_module_dependency',
            'ir_translation', 'ir_lang', 'ir_configuration'
        ]
        sqlite_master = Table('sqlite_master')
        select = sqlite_master.select(sqlite_master.name)
        select.where = sqlite_master.type == 'table'
        select.where &= sqlite_master.name.in_(tables)
        with self._conn as conn:
            cursor = conn.cursor()
            try:
                cursor.execute(*select)
            except Exception:
                return False
            if len(cursor.fetchall()) != len(tables):
                return False
            if hostname:
                configuration = Table('ir_configuration')
                try:
                    cursor.execute(
                        *configuration.select(configuration.hostname))
                except Exception:
                    return False
                hostnames = {h for h, in cursor if h}
                if hostnames and hostname not in hostnames:
                    return False
        return True

    def lastid(self, cursor):
        """Return ``cursor.lastrowid``."""
        # This call is not thread safe
        return cursor.lastrowid

    def lock(self, connection, table):
        """Do nothing: no table lock is taken by this backend."""
        pass

    def lock_id(self, id, timeout=None):
        """Return a literal ``True`` expression; arguments are ignored."""
        return Literal(True)

    def has_constraint(self, constraint):
        """Always return ``False``."""
        return False

    def has_multirow_insert(self):
        """Always return ``True``."""
        return True

    def sql_type(self, type_):
        """Return the ``SQLType`` to use for the generic type ``type_``.

        Types in ``TYPES_MAPPING`` are translated, any ``VARCHAR...`` type
        is reduced to a plain ``VARCHAR`` and other types are kept as is.
        """
        if type_ in self.TYPES_MAPPING:
            return self.TYPES_MAPPING[type_]
        if type_.startswith('VARCHAR'):
            return SQLType('VARCHAR', 'VARCHAR')
        return SQLType(type_, type_)

    def sql_format(self, type_, value):
        """Cast plain values of integer types to ``int``.

        ``None`` and ``Query`` / ``Expression`` values are returned
        untouched, as are values of any other type.
        """
        if type_ in ('INTEGER', 'BIGINT'):
            if (value is not None and not isinstance(value,
                                                     (Query, Expression))):
                value = int(value)
        return value

    def json_get(self, column, key=None):
        """Return an expression for ``column`` or for its JSON ``key``.

        The quoted value is wrapped in ``NullIf`` against the quoted
        ``Null`` so that a quoted null compares as SQL NULL.
        """
        if key:
            column = JSONExtract(column, '$.%s' % key)
        return NullIf(JSONQuote(column), JSONQuote(Null))
Esempio n. 26
0
    def test_select_rownum(self):
        """Check how SELECT is rendered with the 'rownum' limit style.

        Exercises limit with offset, aliased columns, a query reused as a
        subquery, ORDER BY, limit only, offset only and a WHERE clause.
        The default Flavor is restored at the end.
        """
        try:
            Flavor.set(Flavor(limitstyle='rownum'))

            # Limit and offset on a plain select
            query = self.table.select(limit=50, offset=10)
            expected = (
                'SELECT "a".* FROM ('
                'SELECT "b".*, ROWNUM AS "rnum" FROM ('
                'SELECT * FROM "t" AS "c") AS "b" '
                'WHERE (ROWNUM <= %s)) AS "a" '
                'WHERE ("rnum" > %s)')
            self.assertEqual(str(query), expected)
            self.assertEqual(query.params, (60, 10))

            # Limit and offset with aliased columns
            query = self.table.select(
                self.table.c1.as_('col1'), self.table.c2.as_('col2'),
                limit=50, offset=10)
            expected = (
                'SELECT "a"."col1", "a"."col2" FROM ('
                'SELECT "b"."col1", "b"."col2", ROWNUM AS "rnum" FROM ('
                'SELECT "c"."c1" AS "col1", "c"."c2" AS "col2" '
                'FROM "t" AS "c") AS "b" '
                'WHERE (ROWNUM <= %s)) AS "a" '
                'WHERE ("rnum" > %s)')
            self.assertEqual(str(query), expected)
            self.assertEqual(query.params, (60, 10))

            # The previous query used as a subquery
            subquery = query.select(query.col1, query.col2)
            expected = (
                'SELECT "a"."col1", "a"."col2" FROM ('
                'SELECT "b"."col1", "b"."col2" FROM ('
                'SELECT "a"."col1", "a"."col2", ROWNUM AS "rnum" '
                'FROM ('
                'SELECT "c"."c1" AS "col1", "c"."c2" AS "col2" '
                'FROM "t" AS "c") AS "a" '
                'WHERE (ROWNUM <= %s)) AS "b" '
                'WHERE ("rnum" > %s)) AS "a"')
            self.assertEqual(str(subquery), expected)
            # XXX alias of query is reused but not a problem
            # as it is hidden in subquery
            # NOTE(review): the params of ``query`` are checked here, not
            # those of ``subquery`` -- confirm this is intended.
            self.assertEqual(query.params, (60, 10))

            # Limit and offset combined with ORDER BY
            query = self.table.select(
                limit=50, offset=10, order_by=[self.table.c])
            expected = (
                'SELECT "a".* FROM ('
                'SELECT "b".*, ROWNUM AS "rnum" FROM ('
                'SELECT * FROM "t" AS "c" ORDER BY "c"."c") AS "b" '
                'WHERE (ROWNUM <= %s)) AS "a" '
                'WHERE ("rnum" > %s)')
            self.assertEqual(str(query), expected)
            self.assertEqual(query.params, (60, 10))

            # Limit only
            query = self.table.select(limit=50)
            expected = (
                'SELECT "a".* FROM ('
                'SELECT * FROM "t" AS "b") AS "a" '
                'WHERE (ROWNUM <= %s)')
            self.assertEqual(str(query), expected)
            self.assertEqual(query.params, (50, ))

            # Offset only
            query = self.table.select(offset=10)
            expected = (
                'SELECT "a".* FROM ('
                'SELECT "b".*, ROWNUM AS "rnum" FROM ('
                'SELECT * FROM "t" AS "c") AS "b") AS "a" '
                'WHERE ("rnum" > %s)')
            self.assertEqual(str(query), expected)
            self.assertEqual(query.params, (10, ))

            # Limit and offset combined with a WHERE clause
            query = self.table.select(
                self.table.c.as_('col'),
                where=self.table.c >= 20,
                limit=50, offset=10)
            expected = (
                'SELECT "a"."col" FROM ('
                'SELECT "b"."col", ROWNUM AS "rnum" FROM ('
                'SELECT "c"."c" AS "col" FROM "t" AS "c" '
                'WHERE ("c"."c" >= %s)) AS "b" '
                'WHERE (ROWNUM <= %s)) AS "a" '
                'WHERE ("rnum" > %s)')
            self.assertEqual(str(query), expected)
            self.assertEqual(query.params, (20, 60, 10))
        finally:
            Flavor.set(Flavor())