Exemple #1
0
class MySQLDataAPIDialect(DataAPIDialectMixin, MySQLDialect, DataAPIDialect):
    """MySQL dialect routed through the Data API.

    The reflection/sequence hooks below are stubs that always return
    ``None``; the temporal column types are swapped for their Data API
    variants through ``colspecs``.
    """

    name = "mysql"
    default_paramstyle = "named"

    # MySQL's default type map with the Data API date/time types overlaid.
    colspecs = util.update_copy(
        MySQLDialect.colspecs,
        {
            TIMESTAMP: DataAPITimestamp,
            DATE: DataAPIDate,
            TIME: DataAPITime,
            DATETIME: DataAPIDateTime,
        },
    )

    def get_primary_keys(self, connection, table_name, schema=None, **kw):  # type: ignore
        """Stub: always returns ``None``."""
        return None

    def get_temp_table_names(self, connection, schema=None, **kw):  # type: ignore
        """Stub: always returns ``None``."""
        return None

    def get_temp_view_names(self, connection, schema=None, **kw):  # type: ignore
        """Stub: always returns ``None``."""
        return None

    def has_sequence(self, connection, sequence_name, schema=None):  # type: ignore
        """Stub: always returns ``None``."""
        return None

    # Mirrors the mysqldb dialect:
    # https://github.com/sqlalchemy/sqlalchemy/blob/master/lib/sqlalchemy/dialects/mysql/mysqldb.py
    def _extract_error_code(self, exception: Exception) -> Any:  # pragma: no cover
        """Return the first positional argument of *exception* as its error code."""
        error_args = exception.args
        return error_args[0]

    def _detect_charset(self, connection: Any) -> Any:  # pragma: no cover
        """Return the value column of the ``character_set_client`` server variable."""
        result = connection.execute("show variables like 'character_set_client'")
        row = result.fetchone()
        return row[1]
Exemple #2
0
class PostgreSQLDataAPIDialect(DataAPIDialectMixin, PGDialect, DataAPIDialect):
    """PostgreSQL dialect routed through the Data API.

    ``dbapi`` hands back the ``Connection`` class, and several reflection
    hooks are stubs that always return ``None``.
    """

    name = "postgresql"
    default_paramstyle = "named"
    supports_alter = True
    max_identifier_length = 63
    supports_sane_rowcount = True
    isolation_level = None

    # PostgreSQL's default type map with the Data API date/time types overlaid.
    colspecs = util.update_copy(
        PGDialect.colspecs,
        {
            TIMESTAMP: DataAPITimestamp,
            DATE: DataAPIDate,
            TIME: DataAPITime,
        },
    )

    @classmethod
    def dbapi(cls) -> Type[Connection]:
        """Return the ``Connection`` class serving as this dialect's DBAPI."""
        return Connection

    def get_primary_keys(self, connection, table_name, schema=None, **kw):  # type: ignore
        """Stub: always returns ``None``."""
        return None

    def get_temp_table_names(self, connection, schema=None, **kw):  # type: ignore
        """Stub: always returns ``None``."""
        return None

    def get_temp_view_names(self, connection, schema=None, **kw):  # type: ignore
        """Stub: always returns ``None``."""
        return None
Exemple #3
0
class SQLiteCompiler(compiler.SQLCompiler):
    """SQL compiler emitting SQLite-specific syntax."""

    # SQLite has no EXTRACT(); each field maps to a STRFTIME() format code
    # which visit_extract() wraps in CAST(... AS INTEGER).
    extract_map = util.update_copy(
        compiler.SQLCompiler.extract_map,
        {
            'month': '%m',
            'day': '%d',
            'year': '%Y',
            'second': '%S',
            'hour': '%H',
            'doy': '%j',
            'minute': '%M',
            'epoch': '%s',
            'dow': '%w',
            'week': '%W'
        })

    def visit_now_func(self, fn, **kw):
        return "CURRENT_TIMESTAMP"

    def visit_localtimestamp_func(self, func, **kw):
        return 'DATETIME(CURRENT_TIMESTAMP, "localtime")'

    def visit_true(self, expr, **kw):
        # Boolean literals are rendered as the integers 1 / 0.
        return '1'

    def visit_false(self, expr, **kw):
        return '0'

    def visit_char_length_func(self, fn, **kw):
        # char_length() is rendered as length().
        return "length%s" % self.function_argspec(fn)

    def visit_cast(self, cast, **kwargs):
        """Render CAST if the dialect supports it, otherwise only the inner clause.

        Compiler keyword arguments (such as ``literal_binds``) are forwarded in
        both branches so they are not silently dropped.
        """
        if self.dialect.supports_cast:
            return super(SQLiteCompiler, self).visit_cast(cast, **kwargs)
        return self.process(cast.clause, **kwargs)

    def visit_extract(self, extract, **kw):
        """Render EXTRACT(field FROM expr) as an integer-cast STRFTIME() call.

        Raises ``exc.CompileError`` when *extract.field* has no entry in
        ``extract_map``.
        """
        try:
            fmt = self.extract_map[extract.field]
        except KeyError:
            raise exc.CompileError(
                "%s is not a valid extract argument." % extract.field)
        return "CAST(STRFTIME('%s', %s) AS INTEGER)" % (
            fmt, self.process(extract.expr, **kw))

    def limit_clause(self, select, **kw):
        """Render the LIMIT / OFFSET suffix for *select*.

        SQLite needs a LIMIT before OFFSET, so an offset-only query gets
        ``LIMIT -1`` (no upper bound). When neither limit nor offset is set,
        nothing is rendered.
        """
        text = ""
        if select._limit is not None:
            text += "\n LIMIT " + self.process(sql.literal(select._limit))
        if select._offset is not None:
            if select._limit is None:
                text += "\n LIMIT " + self.process(sql.literal(-1))
            text += " OFFSET " + self.process(sql.literal(select._offset))
        elif select._limit is not None:
            # Only pad with OFFSET 0 when a LIMIT precedes it; a bare
            # "OFFSET 0" with no LIMIT is not valid SQLite syntax.
            text += " OFFSET " + self.process(sql.literal(0))
        return text

    def for_update_clause(self, select, **kw):
        # sqlite has no "FOR UPDATE" AFAICT
        return ''
Exemple #4
0
class PGDialect_pg8000(PGDialect):
    """PostgreSQL dialect for the pg8000 DBAPI."""

    driver = 'pg8000'

    supports_unicode_statements = True

    supports_unicode_binds = True

    default_paramstyle = 'format'
    supports_sane_multi_rowcount = False
    execution_ctx_cls = PGExecutionContext_pg8000
    statement_compiler = PGCompiler_pg8000
    preparer = PGIdentifierPreparer_pg8000
    description_encoding = 'use_encoding'

    colspecs = util.update_copy(PGDialect.colspecs, {
        sqltypes.Numeric: _PGNumericNoBind,
        sqltypes.Float: _PGNumeric
    })

    @classmethod
    def dbapi(cls):
        """Import ``pg8000`` and return its ``dbapi`` attribute."""
        return __import__('pg8000').dbapi

    def create_connect_args(self, url):
        """Translate *url* into ``(args, kwargs)`` for the DBAPI ``connect()``.

        The URL's username is passed as the ``user`` keyword, ``port`` is
        coerced to ``int``, and query-string parameters are merged in last
        (overriding URL components of the same name).
        """
        # Rename URL.username to the DBAPI's ``user`` keyword; any other key
        # name would be forwarded verbatim as an unknown connect() kwarg.
        opts = url.translate_connect_args(username='user')
        if 'port' in opts:
            opts['port'] = int(opts['port'])
        opts.update(url.query)
        return ([], opts)

    def is_disconnect(self, e, connection, cursor):
        """Return True if the error text reports a closed connection."""
        return "connection is closed" in str(e)
Exemple #5
0
class SybaseSQLCompiler(compiler.SQLCompiler):
    """SQL compiler for Sybase; LIMIT/OFFSET are rendered as TOP / START AT."""

    ansi_bind_rules = True

    # Map generic EXTRACT field names onto Sybase DATEPART names.
    extract_map = util.update_copy(
        compiler.SQLCompiler.extract_map,
        {
            'doy': 'dayofyear',
            'dow': 'weekday',
            'milliseconds': 'millisecond'
        })

    def get_select_precolumns(self, select, **kw):
        """Return the DISTINCT / TOP / START AT prefix placed before the columns.

        ``**kw`` is accepted (and ignored) so callers that forward compiler
        keyword arguments do not raise ``TypeError``.
        """
        s = "DISTINCT " if select._distinct else ""
        # TODO: don't think Sybase supports bind params for FIRST / TOP,
        # so the values are rendered inline.
        limit = select._limit
        if limit:
            s += "TOP %s " % (limit,)
        offset = select._offset
        if offset:
            if not limit:
                # FIXME: sybase doesn't allow an offset without a limit
                # so use a huge value for TOP here
                s += "TOP 1000000 "
            # START AT is 1-based while OFFSET counts skipped rows.
            s += "START AT %s " % (offset + 1,)
        return s

    def get_from_hint_text(self, table, text):
        return text

    def limit_clause(self, select, **kw):
        # Limit in sybase is after the select keyword
        return ""

    def visit_extract(self, extract, **kw):
        """Render EXTRACT as DATEPART, translating field names via extract_map."""
        field = self.extract_map.get(extract.field, extract.field)
        return 'DATEPART("%s", %s)' % (
            field, self.process(extract.expr, **kw))

    def visit_now_func(self, fn, **kw):
        return "GETDATE()"

    def for_update_clause(self, select, **kw):
        # "FOR UPDATE" is only allowed on "DECLARE CURSOR"
        # which SQLAlchemy doesn't use
        return ''

    def order_by_clause(self, select, **kw):
        """Render ORDER BY with literal binds; dropped in subqueries lacking a LIMIT."""
        kw['literal_binds'] = True
        order_by = self.process(select._order_by_clause, **kw)

        # SybaseSQL only allows ORDER BY in subqueries if there is a LIMIT
        if order_by and (not self.is_subquery() or select._limit):
            return " ORDER BY " + order_by
        else:
            return ""
Exemple #6
0
class SybaseSQLCompiler(compiler.SQLCompiler):
    """SQL compiler for Sybase using a trailing ROWS LIMIT / OFFSET clause."""

    ansi_bind_rules = True

    # Generic EXTRACT field names -> Sybase DATEPART names.
    extract_map = util.update_copy(
        compiler.SQLCompiler.extract_map,
        {"doy": "dayofyear", "dow": "weekday", "milliseconds": "millisecond"},
    )

    def get_from_hint_text(self, table, text):
        return text

    def limit_clause(self, select, **kw):
        """Build the ``ROWS LIMIT n OFFSET m`` suffix (either part optional)."""
        limit, offset = select._limit_clause, select._offset_clause
        pieces = []
        if limit is not None:
            pieces.append(" ROWS LIMIT " + self.process(limit, **kw))
        if offset is not None:
            if limit is None:
                pieces.append(" ROWS")
            pieces.append(" OFFSET " + self.process(offset, **kw))
        return "".join(pieces)

    def visit_extract(self, extract, **kw):
        """Render EXTRACT as DATEPART, translating the field via extract_map."""
        part = self.extract_map.get(extract.field, extract.field)
        expr_sql = self.process(extract.expr, **kw)
        return 'DATEPART("%s", %s)' % (part, expr_sql)

    def visit_now_func(self, fn, **kw):
        return "GETDATE()"

    def for_update_clause(self, select):
        # FOR UPDATE is only legal on DECLARE CURSOR, which SQLAlchemy
        # never emits, so nothing is rendered.
        return ""

    def order_by_clause(self, select, **kw):
        """Render ORDER BY with literal binds; omitted in subqueries lacking a LIMIT."""
        rendered = self.process(
            select._order_by_clause, **dict(kw, literal_binds=True)
        )
        # Sybase permits ORDER BY inside a subquery only alongside a LIMIT.
        keep = rendered and (not self.is_subquery() or select._limit)
        return " ORDER BY " + rendered if keep else ""

    def delete_table_clause(self, delete_stmt, from_table, extra_froms):
        """Render the DELETE target, as a hint when extra FROMs are present."""
        return from_table._compiler_dispatch(
            self, asfrom=True, iscrud=True, ashint=bool(extra_froms)
        )

    def delete_extra_from_clause(
        self, delete_stmt, from_table, extra_froms, from_hints, **kw
    ):
        """Render the Sybase ``FROM t1, t2, ...`` clause of a multi-table DELETE."""
        tables = [from_table] + extra_froms
        rendered = (
            t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw)
            for t in tables
        )
        return "FROM " + ", ".join(rendered)
Exemple #7
0
class PGDialect_psycopg2(PGDialect):
    """PostgreSQL dialect for the psycopg2 DBAPI."""

    driver = 'psycopg2'
    supports_unicode_statements = False
    default_paramstyle = 'pyformat'
    supports_sane_multi_rowcount = False
    execution_ctx_cls = PGExecutionContext_psycopg2
    statement_compiler = PGCompiler_psycopg2
    preparer = PGIdentifierPreparer_psycopg2

    colspecs = util.update_copy(
        PGDialect.colspecs,
        {
            sqltypes.Numeric: _PGNumeric,
            ENUM: _PGEnum,  # needs force_unicode
            sqltypes.Enum: _PGEnum,  # needs force_unicode
            ARRAY: _PGArray,  # needs force_unicode
        }
    )

    def __init__(self, server_side_cursors=False, use_native_unicode=True, **kwargs):
        """Store cursor/unicode options; unicode binds follow ``use_native_unicode``."""
        PGDialect.__init__(self, **kwargs)
        self.server_side_cursors = server_side_cursors
        self.use_native_unicode = use_native_unicode
        self.supports_unicode_binds = use_native_unicode

    @classmethod
    def dbapi(cls):
        """Import and return the ``psycopg2`` module."""
        psycopg = __import__('psycopg2')
        return psycopg

    def on_connect(self):
        """Return a connect hook that registers psycopg2's UNICODE type.

        The hook is only returned when a DBAPI is loaded and
        ``use_native_unicode`` is set; it then chains to the base hook.
        Otherwise the base hook (possibly ``None``) is returned unchanged.
        """
        base_on_connect = super(PGDialect_psycopg2, self).on_connect()
        if self.dbapi and self.use_native_unicode:
            extensions = __import__('psycopg2.extensions').extensions

            def connect(conn):
                extensions.register_type(extensions.UNICODE, conn)
                if base_on_connect:
                    base_on_connect(conn)
            return connect
        else:
            return base_on_connect

    def create_connect_args(self, url):
        """Translate *url* into ``(args, kwargs)`` for ``psycopg2.connect``.

        The username is passed as ``user``, ``port`` is coerced to ``int``,
        and query-string parameters are merged in last.
        """
        # Rename URL.username to psycopg2's ``user`` keyword.
        opts = url.translate_connect_args(username='user')
        if 'port' in opts:
            opts['port'] = int(opts['port'])
        opts.update(url.query)
        return ([], opts)

    def is_disconnect(self, e, connection=None, cursor=None):
        """Return True if *e* reports a lost/closed connection.

        ``connection`` and ``cursor`` are optional and unused, so both the
        one-argument and three-argument calling conventions work.
        """
        msg = str(e)
        if isinstance(e, self.dbapi.OperationalError):
            return 'closed the connection' in msg or 'connection not open' in msg
        elif isinstance(e, self.dbapi.InterfaceError):
            return 'connection already closed' in msg or 'cursor already closed' in msg
        elif isinstance(e, self.dbapi.ProgrammingError):
            # yes, it really says "losed", not "closed"
            return "losed the connection unexpectedly" in msg
        else:
            return False
Exemple #8
0
class AccessDialect_pyodbc(PyODBCConnector, AccessDialect):
    """Microsoft Access dialect running over pyodbc."""

    pyodbc_driver_name = 'Microsoft Access'
    execution_ctx_cls = AccessExecutionContext_pyodbc

    # Route Numeric through the pyodbc-specific Access numeric type.
    colspecs = util.update_copy(
        AccessDialect.colspecs,
        {sqltypes.Numeric: _AccessNumeric_pyodbc},
    )
Exemple #9
0
class MSDialect_pymssql(MSDialect):
    """SQL Server dialect for the pymssql DBAPI."""

    supports_sane_rowcount = False
    max_identifier_length = 30
    driver = 'pymssql'

    colspecs = util.update_copy(
        MSDialect.colspecs,
        {
            sqltypes.Numeric: _MSNumeric_pymssql,
            sqltypes.Float: sqltypes.Float,
        }
    )

    @classmethod
    def dbapi(cls):
        """Import pymssql, patch in ``Binary``, and warn on pre-1.0 clients."""
        module = __import__('pymssql')
        # pymssql doesn't have a Binary method.  we use string
        # TODO: monkeypatching here is less than ideal
        module.Binary = str

        # Use only the leading numeric components so versions carrying
        # suffixes such as "2.1.4.dev5" or "1.0.2b" don't raise ValueError.
        client_ver = []
        for part in str(module.__version__).split("."):
            digits = re.match(r"\d+", part)
            if digits is None:
                break
            client_ver.append(int(digits.group()))
        if tuple(client_ver) < (1, ):
            util.warn("The pymssql dialect expects at least "
                      "the 1.0 series of the pymssql DBAPI.")
        return module

    def __init__(self, **params):
        super(MSDialect_pymssql, self).__init__(**params)
        self.use_scope_identity = True

    def _get_server_version_info(self, connection):
        """Parse the four-part version from ``@@version``; ``None`` if unmatched."""
        vers = connection.scalar("select @@version")
        m = re.match(
            r"Microsoft SQL Server.*? - (\d+)\.(\d+)\.(\d+)\.(\d+)", vers)
        if m:
            return tuple(int(x) for x in m.group(1, 2, 3, 4))
        else:
            return None

    def create_connect_args(self, url):
        """Translate *url* into connect args; a port is folded into ``host``.

        The username is passed as ``user``. Query-string parameters are merged
        in before the port is popped, so a query ``port`` is handled the same way.
        """
        opts = url.translate_connect_args(username='user')
        opts.update(url.query)
        port = opts.pop('port', None)
        if port and 'host' in opts:
            opts['host'] = "%s:%s" % (opts['host'], port)
        return [[], opts]

    def is_disconnect(self, e, connection=None, cursor=None):
        """Return True if the error text matches a known disconnect message.

        ``connection`` and ``cursor`` are optional and unused, so both the
        one-argument and three-argument calling conventions work.
        """
        msg = str(e)
        return any(
            marker in msg
            for marker in (
                "Error 10054",
                "Not connected to any MS SQL server",
                "Connection is closed",
            )
        )
Exemple #10
0
class SQLiteDialect_rqlite(SQLiteDialect):
    """SQLite-flavoured dialect talking to rqlite through pyrqlite."""

    default_paramstyle = 'qmark'

    colspecs = util.update_copy(
        SQLiteDialect.colspecs,
        {
            sqltypes.Date: _SQLite_rqliteDate,
            sqltypes.TIMESTAMP: _SQLite_rqliteTimeStamp,
        }
    )

    if not util.py2k:
        description_encoding = None

    driver = 'pyrqlite'

    # pylint: disable=method-hidden
    @classmethod
    def dbapi(cls):
        """Return ``pyrqlite.dbapi2``; an ImportError propagates if unavailable."""
        # pylint: disable=no-name-in-module
        from pyrqlite import dbapi2 as sqlite
        return sqlite

    @classmethod
    def get_pool_class(cls, url):
        """Use NullPool for a named database, SingletonThreadPool otherwise."""
        if url.database and url.database != ':memory:':
            return pool.NullPool
        else:
            return pool.SingletonThreadPool

    def create_connect_args(self, url):
        """Build pyrqlite connect kwargs from the URL host, port and query string.

        Raises ``exc.ArgumentError`` if the URL carries a username or password.
        """
        if url.username or url.password:
            raise exc.ArgumentError(
                "Invalid RQLite URL: %s\n"
                "Valid RQLite URL forms are:\n"
                " rqlite+pyrqlite://host:port/[?params]" % (url,))

        # Build a plain, mutable dict: coerce_kw_type() below assigns into it,
        # and url.query may be an immutable mapping in newer SQLAlchemy.
        opts = dict(url.query)
        util.coerce_kw_type(opts, 'connect_timeout', float)
        util.coerce_kw_type(opts, 'detect_types', int)
        util.coerce_kw_type(opts, 'max_redirects', int)
        # NOTE(review): these are set even when None, which overrides any
        # pyrqlite-side default -- confirm that is intended.
        opts['port'] = url.port
        opts['host'] = url.host

        return ([], opts)

    def is_disconnect(self, e, connection, cursor):
        """Never report a disconnect."""
        return False
Exemple #11
0
class MySQLDialect_zxjdbc(ZxJDBCConnector, MySQLDialect):
    """MySQL dialect for Jython's zxJDBC with the MySQL Connector/J driver."""

    jdbc_db_name = 'mysql'
    jdbc_driver_name = 'com.mysql.jdbc.Driver'

    execution_ctx_cls = MySQLExecutionContext_zxjdbc

    colspecs = util.update_copy(
        MySQLDialect.colspecs,
        {
            sqltypes.Time: sqltypes.Time,
            BIT: _ZxJDBCBit
        }
    )

    def _detect_charset(self, connection):
        """Sniff out the character set in use for connection results."""
        # Ask the server rather than the driver: SET NAMES or individual
        # variable SETs change the charset without updating the driver's view.
        # 'character_set_connection' is checked first, then 'character_set';
        # falls back to latin1 with a warning if neither is set.
        # NOTE(review): the upstream mysql dialect prefers
        # 'character_set_results' -- confirm which variable is intended here.
        rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'")
        opts = dict((row[0], row[1]) for row in self._compat_fetchall(rs))
        for key in ('character_set_connection', 'character_set'):
            if opts.get(key, None):
                return opts[key]

        util.warn("Could not detect the connection character set.  Assuming latin1.")
        return 'latin1'

    def _driver_kwargs(self):
        """return kw arg dict to be sent to connect()."""
        return dict(characterEncoding='UTF-8', yearIsDateType='false')

    def _extract_error_code(self, exception):
        """Return the integer ``[SQLCode: N]`` from the error text, or ``None``.

        e.g.: DBAPIError: (Error) Table 'test.u2' doesn't exist
        [SQLCode: 1146], [SQLState: 42S02] 'DESCRIBE `u2`' ()
        """
        m = re.search(r"\[SQLCode\: (\d+)\]", str(exception.args))
        # Messages without an SQLCode yield None instead of an AttributeError.
        if m:
            return int(m.group(1))
        return None

    def _get_server_version_info(self, connection):
        """Split the driver's dbversion on '.'/'-'; numeric parts become ints."""
        dbapi_con = connection.connection
        version = []
        for n in re.split(r'[.\-]', dbapi_con.dbversion):
            try:
                version.append(int(n))
            except ValueError:
                version.append(n)
        return tuple(version)
Exemple #12
0
class SQLAnySQLCompiler(compiler.SQLCompiler):
    """SQL compiler for SQL Anywhere: row limits render as FIRST/TOP ... START AT."""

    ansi_bind_rules = True

    # Generic EXTRACT field names -> DATEPART names.
    extract_map = util.update_copy(compiler.SQLCompiler.extract_map, {
        'doy': 'dayofyear',
        'dow': 'weekday',
        'milliseconds': 'millisecond'
    })

    def get_select_precolumns(self, select, **kw):
        """Return the DISTINCT / FIRST / TOP / START AT prefix, or defer to the base."""
        limit, offset = select._limit, select._offset
        parts = []
        if select._distinct:
            parts.append("DISTINCT ")
        if limit:
            parts.append("FIRST " if limit == 1 else "TOP %s " % limit)
        if offset:
            if not limit:
                # START AT requires a TOP n, so use TOP ALL when no limit.
                parts.append("TOP ALL ")
            parts.append("START AT %s " % (offset + 1, ))
        if parts:
            return "".join(parts)
        return compiler.SQLCompiler.get_select_precolumns(self, select, **kw)

    def get_from_hint_text(self, table, text):
        return text

    def limit_clause(self, select, **kw):
        # The limit is emitted after SELECT by get_select_precolumns.
        return ""

    def visit_extract(self, extract, **kw):
        """Render EXTRACT as DATEPART, translating the field via extract_map."""
        part = self.extract_map.get(extract.field, extract.field)
        expr_sql = self.process(extract.expr, **kw)
        return 'DATEPART("%s", %s)' % (part, expr_sql)

    def visit_now_func(self, fn, **kw):
        return "NOW()"

    def for_update_clause(self, select):
        # FOR UPDATE is only legal on DECLARE CURSOR, which SQLAlchemy
        # never emits, so nothing is rendered.
        return ''

    def order_by_clause(self, select, **kw):
        """Render ORDER BY with bound values inlined as literals."""
        rendered = self.process(
            select._order_by_clause, **dict(kw, literal_binds=True))
        return " ORDER BY " + rendered if rendered else ""
Exemple #13
0
class MSDialect_pyodbc(PyODBCConnector, MSDialect):
    """SQL Server dialect running over pyodbc."""

    pyodbc_driver_name = 'SQL Server'
    execution_ctx_cls = MSExecutionContext_pyodbc

    colspecs = util.update_copy(
        MSDialect.colspecs,
        {sqltypes.Numeric: _MSNumeric_pyodbc},
    )

    def __init__(self, description_encoding='latin-1', **params):
        """Store *description_encoding* and decide on scope_identity() support."""
        super(MSDialect_pyodbc, self).__init__(**params)
        self.description_encoding = description_encoding
        dbapi = self.dbapi
        # Enabled only when a DBAPI is loaded and its Cursor class exposes
        # ``nextset``; otherwise this holds the falsy dbapi value itself.
        self.use_scope_identity = dbapi and hasattr(dbapi.Cursor, 'nextset')
Exemple #14
0
class OracleDialect_zxjdbc(ZxJDBCConnector, OracleDialect):
    """Oracle dialect for Jython's zxJDBC with the Oracle thin JDBC driver."""

    jdbc_db_name = 'oracle'
    jdbc_driver_name = 'oracle.jdbc.OracleDriver'

    statement_compiler = OracleCompiler_zxjdbc
    execution_ctx_cls = OracleExecutionContext_zxjdbc

    colspecs = util.update_copy(
        OracleDialect.colspecs,
        {
            sqltypes.Date: _ZxJDBCDate,
            sqltypes.Numeric: _ZxJDBCNumeric
        }
    )

    def __init__(self, *args, **kwargs):
        """Import the Java/zxJDBC classes and install a RETURNING-aware DataHandler."""
        super(OracleDialect_zxjdbc, self).__init__(*args, **kwargs)
        # Module-level names are bound here because they only exist on Jython.
        global SQLException, zxJDBC
        from java.sql import SQLException
        from com.ziclix.python.sql import zxJDBC
        from com.ziclix.python.sql.handler import OracleDataHandler

        class OracleReturningDataHandler(OracleDataHandler):
            """zxJDBC DataHandler that specially handles ReturningParam."""

            def setJDBCObject(self, statement, index, object, dbtype=None):
                if type(object) is ReturningParam:
                    statement.registerReturnParameter(index, object.type)
                elif dbtype is None:
                    OracleDataHandler.setJDBCObject(
                        self, statement, index, object)
                else:
                    OracleDataHandler.setJDBCObject(
                        self, statement, index, object, dbtype)
        self.DataHandler = OracleReturningDataHandler

    @staticmethod
    def _parse_version_tuple(version):
        """Return the leading numeric components of a dotted version as ints."""
        parts = []
        for piece in str(version).split('.'):
            m = re.match(r'\d+', piece)
            if m is None:
                break
            parts.append(int(m.group()))
        return tuple(parts)

    def initialize(self, connection):
        """Run base initialization; enable implicit RETURNING for driver >= 10.2."""
        super(OracleDialect_zxjdbc, self).initialize(connection)
        # Compare numerically: as plain strings '9.2' >= '10.2' is True.
        # Assumes driverversion is a dotted version string -- TODO confirm.
        self.implicit_returning = self._parse_version_tuple(
            connection.connection.driverversion) >= (10, 2)

    def _create_jdbc_url(self, url):
        """Build a thin-driver JDBC URL; the port defaults to 1521."""
        return 'jdbc:oracle:thin:@%s:%s:%s' % (
            url.host, url.port or 1521, url.database)

    def _get_server_version_info(self, connection):
        """Parse the 'Release x.y.z' number from the server's dbversion string."""
        version = re.search(
            r'Release ([\d\.]+)', connection.connection.dbversion).group(1)
        return tuple(int(x) for x in version.split('.'))
Exemple #15
0
class AuroraPostgresDataAPIDialect(PGDialect):
    """PostgreSQL dialect whose DBAPI is the ``aurora_data_api`` module."""

    @classmethod
    def dbapi(cls):
        """Return the ``aurora_data_api`` module."""
        return aurora_data_api

    # PostgreSQL's default type map with JSON/UUID/temporal/ARRAY types
    # replaced by their aurora_data_api-specific implementations.
    colspecs = util.update_copy(
        PGDialect.colspecs,
        {
            SA_JSON: _ADA_SA_JSON,
            JSON: _ADA_JSON,
            JSONB: _ADA_JSONB,
            UUID: _ADA_UUID,
            DATE: _ADA_DATE,
            TIME: _ADA_TIME,
            TIMESTAMP: _ADA_TIMESTAMP,
            ARRAY: _ADA_ARRAY,
        },
    )
Exemple #16
0
class MSDialect_adodbapi(MSDialect):
    """SQL Server dialect for adodbapi, connecting via an OLE DB connection string."""

    supports_sane_rowcount = True
    supports_sane_multi_rowcount = True
    supports_unicode = sys.maxunicode == 65535
    supports_unicode_statements = True
    driver = 'adodbapi'

    @classmethod
    def import_dbapi(cls):
        """Import and return the ``adodbapi`` module."""
        import adodbapi as module
        return module

    colspecs = util.update_copy(
        MSDialect.colspecs,
        {
            sqltypes.DateTime: MSDateTime_adodbapi
        }
    )

    def create_connect_args(self, url):
        """Build a single SQLOLEDB connection string from *url*.

        Host, port, database, username and password come from the URL itself.
        Query-string parameters of the same name override them, so URLs that
        passed everything via the query string keep working. Without a user,
        integrated (SSPI) security is requested.
        """
        def check_quote(token):
            # Quote values containing ';' so they don't split the string.
            if ";" in str(token):
                token = "'%s'" % token
            return token

        params = url.translate_connect_args(username='user')
        params.update(url.query)
        keys = dict(
            (k, check_quote(v)) for k, v in params.items()
        )

        connectors = ["Provider=SQLOLEDB"]
        if 'port' in keys:
            connectors.append("Data Source=%s, %s" %
                              (keys.get("host"), keys.get("port")))
        else:
            connectors.append("Data Source=%s" % keys.get("host"))
        database = keys.get("database")
        if database:
            # Omit rather than emit the literal "Initial Catalog=None".
            connectors.append("Initial Catalog=%s" % database)
        user = keys.get("user")
        if user:
            connectors.append("User Id=%s" % user)
            connectors.append("Password=%s" % keys.get("password", ""))
        else:
            connectors.append("Integrated Security=SSPI")
        return [[";".join(connectors)], {}]

    def is_disconnect(self, e, connection, cursor):
        """Return True for an adodbapi DatabaseError reporting 'connection failure'."""
        return isinstance(e, self.dbapi.adodbapi.DatabaseError) and \
            "'connection failure'" in str(e)
class AuroraPostgresDataAPIDialect(PGDialect):
    """PostgreSQL dialect whose DBAPI is the ``aurora_data_api`` module."""

    # Dialect attributes documented at
    # https://docs.sqlalchemy.org/en/13/core/internals.html#sqlalchemy.engine.interfaces.Dialect
    driver = "aurora_data_api"
    default_schema_name = None

    # PostgreSQL's default type map with the aurora_data_api-specific
    # implementations swapped in.
    colspecs = util.update_copy(
        PGDialect.colspecs,
        {
            sqltypes.JSON: _ADA_SA_JSON,
            JSON: _ADA_JSON,
            JSONB: _ADA_JSONB,
            UUID: _ADA_UUID,
            sqltypes.Date: _ADA_DATE,
            sqltypes.Time: _ADA_TIME,
            sqltypes.DateTime: _ADA_TIMESTAMP,
            sqltypes.Enum: _ADA_ENUM,
            ARRAY: _ADA_ARRAY,
        },
    )

    @classmethod
    def dbapi(cls):
        """Return the ``aurora_data_api`` module."""
        return aurora_data_api
Exemple #18
0
class EXACompiler(compiler.SQLCompiler):
    """SQL compiler emitting Exasol (EXASolution) syntax."""

    # NOTE(review): these are SQLite-style strftime codes, and this class does
    # not override visit_extract, so the base compiler may render them
    # verbatim (e.g. EXTRACT(%m FROM ...)) -- confirm against Exasol.
    extract_map = util.update_copy(
        compiler.SQLCompiler.extract_map, {
            'month': '%m',
            'day': '%d',
            'year': '%Y',
            'second': '%S',
            'hour': '%H',
            'doy': '%j',
            'minute': '%M',
            'epoch': '%s',
            'dow': '%w',
            'week': '%W'
        })

    def visit_now_func(self, fn, **kw):
        return "CURRENT_TIMESTAMP"

    def visit_char_length_func(self, fn, **kw):
        # char_length() is rendered as length().
        return "length%s" % self.function_argspec(fn)

    def limit_clause(self, select, **kw):
        """Render LIMIT / OFFSET with the values inlined as integers."""
        text = ""
        if select._limit is not None:
            text += "\n LIMIT %d" % int(select._limit)
        if select._offset is not None:
            text += "\n OFFSET %d" % int(select._offset)

        return text

    def for_update_clause(self, select, **kw):
        """Warn that FOR UPDATE is unsupported and render nothing.

        ``**kw`` is accepted so callers forwarding compiler kwargs don't fail.
        """
        # Exasol has no "FOR UPDATE"
        util.warn("EXASolution does not support SELECT ... FOR UPDATE")
        return ''

    def default_from(self):
        """Called when a ``SELECT`` statement has no froms,
        and no ``FROM`` clause is to be appended.
        """
        return " FROM DUAL"

    def visit_empty_set_expr(self, type_, **kw):
        """Return a SELECT from DUAL that produces no rows.

        ``**kw`` is accepted so callers forwarding compiler kwargs don't fail.
        """
        return "SELECT 1 FROM DUAL WHERE 1!=1"
class AuroraMySQLDataAPIDialect(MySQLDialect):
    """MySQL dialect whose DBAPI is the ``aurora_data_api`` module."""

    # See https://docs.sqlalchemy.org/en/13/core/internals.html#sqlalchemy.engine.interfaces.Dialect
    driver = "aurora_data_api"
    default_schema_name = None
    colspecs = util.update_copy(MySQLDialect.colspecs, {
        sqltypes.Date: _ADA_DATE,
        sqltypes.Time: _ADA_TIME,
        sqltypes.DateTime: _ADA_TIMESTAMP,
    })

    @classmethod
    def dbapi(cls):
        """Return the ``aurora_data_api`` module."""
        return aurora_data_api

    def _detect_charset(self, connection):
        """Return the value column of the ``character_set_client`` server variable."""
        return connection.execute("SHOW VARIABLES LIKE 'character_set_client'").fetchone()[1]

    def _extract_error_code(self, exception):
        """Return the error code carried in ``exception.args[0]``.

        If that argument has a ``.value`` (presumably an enum member --
        confirm against aurora_data_api), the value is returned. Otherwise
        the argument itself is returned, and ``None`` when there are no args.
        This keeps the error-code lookup from raising while another error is
        being handled.
        """
        args = getattr(exception, "args", ())
        if not args:
            return None
        code = args[0]
        return getattr(code, "value", code)
class AccessDialect_pyodbc(PyODBCConnector, AccessDialect):
    """Microsoft Access dialect over pyodbc, with ODBC pooling switched off."""

    pyodbc_driver_name = "Microsoft Access"
    execution_ctx_cls = AccessExecutionContext_pyodbc
    supports_statement_cache = True

    # Route Numeric through the pyodbc-specific Access numeric type.
    colspecs = util.update_copy(
        AccessDialect.colspecs,
        {sqltypes.Numeric: _AccessNumeric_pyodbc},
    )

    @classmethod
    def dbapi(cls):
        """Import pyodbc, disable its connection pooling, and return it."""
        import pyodbc
        # required for Access databases with ODBC linked tables
        pyodbc.pooling = False
        return pyodbc
class SQLiteDialect_pysqlite(SQLiteDialect):
    """SQLite dialect for pysqlite2, falling back to the stdlib sqlite3 module."""

    default_paramstyle = 'qmark'

    colspecs = util.update_copy(
        SQLiteDialect.colspecs, {
            sqltypes.Date: _SQLite_pysqliteDate,
            sqltypes.TIMESTAMP: _SQLite_pysqliteTimeStamp,
        })

    # Py3K
    #description_encoding = None

    driver = 'pysqlite'

    def __init__(self, **kwargs):
        """Initialize the dialect and warn if the loaded DBAPI is older than 2.1.3."""
        SQLiteDialect.__init__(self, **kwargs)

        if self.dbapi is not None:
            sqlite_ver = self.dbapi.version_info
            if sqlite_ver < (2, 1, 3):
                util.warn(
                    ("The installed version of pysqlite2 (%s) is out-dated "
                     "and will cause errors in some cases.  Version 2.1.3 "
                     "or greater is recommended.") %
                    '.'.join([str(subver) for subver in sqlite_ver]))

    @classmethod
    def dbapi(cls):
        """Return pysqlite2's ``dbapi2``, falling back to ``sqlite3.dbapi2``.

        If neither can be imported, the original pysqlite2 ImportError is
        re-raised. The debug ``print`` calls previously here are removed,
        since a library should not write to stdout on import.
        """
        try:
            from pysqlite2 import dbapi2 as sqlite
        except ImportError as e:
            try:
                from sqlite3 import dbapi2 as sqlite  # try 2.5+ stdlib name.
            except ImportError:
                raise e
        return sqlite
class PGDialect_pypostgresql(PGDialect):
    """PostgreSQL dialect for the py-postgresql DBAPI (``postgresql.driver.dbapi20``)."""

    driver = 'pypostgresql'

    supports_unicode_statements = True
    supports_unicode_binds = True
    description_encoding = None
    default_paramstyle = 'pyformat'

    # requires trunk version to support sane rowcounts
    # TODO: use dbapi version information to set this flag appropriately
    supports_sane_rowcount = True
    supports_sane_multi_rowcount = False

    execution_ctx_cls = PGExecutionContext_pypostgresql
    colspecs = util.update_copy(
        PGDialect.colspecs,
        {
            sqltypes.Numeric: PGNumeric,
            sqltypes.Float:
            sqltypes.Float,  # prevents PGNumeric from being used
        })

    @classmethod
    def dbapi(cls):
        """Import and return ``postgresql.driver.dbapi20``."""
        from postgresql.driver import dbapi20
        return dbapi20

    def create_connect_args(self, url):
        """Translate *url* into ``(args, kwargs)`` for the DBAPI ``connect()``.

        The username is passed as ``user``, ``port`` is coerced to ``int``
        (defaulting to 5432), and query-string parameters are merged in last.
        """
        # Rename URL.username to the DBAPI's ``user`` keyword.
        opts = url.translate_connect_args(username='user')
        if 'port' in opts:
            opts['port'] = int(opts['port'])
        else:
            opts['port'] = 5432
        opts.update(url.query)
        return ([], opts)

    def is_disconnect(self, e, connection=None, cursor=None):
        """Return True if the error text reports a closed connection.

        ``connection`` and ``cursor`` are optional and unused, so both the
        one-argument and three-argument calling conventions work.
        """
        return "connection is closed" in str(e)
class OcientDbDialect_pyodbc(PyODBCConnector, OcientDbDialect):
    """OcientDB dialect over pyodbc, forcing UTF-8 text encoding/decoding."""

    execution_ctx_cls = OcientDbExecutionContext_pyodbc

    pyodbc_driver_name = 'OcientDB'
    supports_unicode_statements = False
    supports_unicode_binds = False

    colspecs = util.update_copy(OcientDbDialect.colspecs,
                                {sqltypes.Numeric: _OcientDbNumeric_pyodbc})

    def connect(self, *cargs, **cparams):
        """Open a connection via the connector, then set UTF-8 text codecs on it."""
        conn = super(OcientDbDialect_pyodbc, self).connect(*cargs, **cparams)

        # Decode both narrow and wide character columns as UTF-8.
        conn.setdecoding(self.dbapi.SQL_CHAR, encoding='utf-8')
        conn.setdecoding(self.dbapi.SQL_WCHAR, encoding='utf-8')

        try:
            text_type = unicode  # noqa: F821 -- only defined on Python 2
        except NameError:
            # Python 3 pyodbc: setencoding() takes no Python-type argument,
            # and the bare name ``unicode`` would raise NameError.
            conn.setencoding(encoding='utf-8')
        else:
            # Python 2 pyodbc: str and unicode encodings are set separately.
            conn.setencoding(str, encoding='utf-8')
            conn.setencoding(text_type, encoding='utf-8')

        return conn
Exemple #24
0
class SybaseSQLCompiler(compiler.SQLCompiler):
    """SQL compiler for Sybase ASE: LIMIT renders as TOP; OFFSET is unsupported."""

    ansi_bind_rules = True

    # Generic EXTRACT field names -> Sybase DATEPART names.
    extract_map = util.update_copy(
        compiler.SQLCompiler.extract_map,
        {"doy": "dayofyear", "dow": "weekday", "milliseconds": "millisecond"},
    )

    def get_select_precolumns(self, select, **kw):
        """Return the DISTINCT / TOP prefix placed before the columns.

        Raises ``NotImplementedError`` when the select has an offset.
        """
        s = select._distinct and "DISTINCT " or ""
        # TODO: don't think Sybase supports
        # bind params for FIRST / TOP
        limit = select._limit
        if limit:
            s += "TOP %s " % (limit,)
        offset = select._offset
        if offset:
            raise NotImplementedError("Sybase ASE does not support OFFSET")
        return s

    def get_from_hint_text(self, table, text):
        return text

    def limit_clause(self, select, **kw):
        # Limit in sybase is after the select keyword
        return ""

    def visit_extract(self, extract, **kw):
        """Render EXTRACT as DATEPART, translating the field via extract_map."""
        # Look the field up; previously the bound ``get`` method itself was
        # interpolated into the SQL instead of the mapped field name.
        field = self.extract_map.get(extract.field, extract.field)
        return 'DATEPART("%s", %s)' % (field, self.process(extract.expr, **kw))

    def visit_now_func(self, fn, **kw):
        return "GETDATE()"

    def for_update_clause(self, select):
        # "FOR UPDATE" is only allowed on "DECLARE CURSOR"
        # which SQLAlchemy doesn't use
        return ""

    def order_by_clause(self, select, **kw):
        """Render ORDER BY with literal binds; dropped in subqueries lacking a LIMIT."""
        kw["literal_binds"] = True
        order_by = self.process(select._order_by_clause, **kw)

        # SybaseSQL only allows ORDER BY in subqueries if there is a LIMIT
        if order_by and (not self.is_subquery() or select._limit):
            return " ORDER BY " + order_by
        else:
            return ""

    def delete_table_clause(self, delete_stmt, from_table, extra_froms):
        """If we have extra froms make sure we render any alias as hint."""
        ashint = False
        if extra_froms:
            ashint = True
        return from_table._compiler_dispatch(
            self, asfrom=True, iscrud=True, ashint=ashint
        )

    def delete_extra_from_clause(
        self, delete_stmt, from_table, extra_froms, from_hints, **kw
    ):
        """Render the DELETE .. FROM clause specific to Sybase."""
        return "FROM " + ", ".join(
            t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw)
            for t in [from_table] + extra_froms
        )
Exemple #25
0
class DremioDialect_pyodbc(PyODBCConnector, DremioDialect):
    """Dremio dialect that connects through pyodbc.

    The default ODBC driver is chosen from the running platform when the
    class is defined; ``create_connect_args`` turns a SQLAlchemy URL into a
    pyodbc connection string.
    """

    execution_ctx_cls = DremioExecutionContext_pyodbc

    # Default driver (library path or registered driver name), keyed by
    # ``platform.system()``, plus the architecture on Linux only.
    driver_for_platf = {
        'Linux 64bit': '/opt/dremio-odbc/lib64/libdrillodbc_sb64.so',
        'Linux 32bit': '/opt/dremio-odbc/lib32/libdrillodbc_sb32.so',
        'Windows': 'Dremio Connector',
        'Darwin': 'Dremio Connector'
    }
    platf = platform.system() + (' ' + platform.architecture()[0]
                                 if platform.system() == 'Linux' else '')
    # Unlisted platforms get ``None`` instead of a KeyError at import time.
    # create_connect_args() then warns that no driver name was specified,
    # unless the URL provides ``driver=`` or ``odbc_connect=``.
    drv = driver_for_platf.get(platf)
    pyodbc_driver_name = drv
    colspecs = util.update_copy(DremioDialect.colspecs,
                                {sqltypes.Numeric: _DremioNumeric_pyodbc})

    # pyodbc SQLSTATE codes that always mean the connection is gone.
    _disconnect_error_codes = frozenset([
        '40004',  # Connection lost.
        '40009',  # Connection lost after internal server error.
        '40018',  # Connection lost after system running out of memory.
        '40020',  # Connection lost after system running out of memory.
    ])

    # For these (generic) SQLSTATE codes the error message decides whether
    # the failure was a disconnect.
    _dremio_disconnect_patterns = {
        'HY000': (  # Generic dremio error code
            re.compile(six.u(r'operation timed out'), re.IGNORECASE),
            re.compile(six.u(r'connection lost'), re.IGNORECASE),
            re.compile(six.u(r'Socket closed by peer'), re.IGNORECASE),
        )
    }

    def __init__(self, **kw):
        kw.setdefault('convert_unicode', True)
        super(DremioDialect_pyodbc, self).__init__(**kw)

    def create_connect_args(self, url):
        """Build ``[[odbc_connection_string], connect_kwargs]`` from *url*.

        ``ansi``, ``unicode_results`` and ``autocommit`` query parameters
        become upper-cased boolean pyodbc keyword arguments. ``odbc_connect``
        supplies a raw (URL-quoted) connection string verbatim. Otherwise a
        DSN-style string (when ``dsn`` is given, or ``host`` without a
        database) or a DRIVER/HOST/PORT/Schema string is assembled. Any
        remaining query parameters are appended as ``key=value`` pairs.
        """
        # NOTE(review): '******' looks like a redacted placeholder; the code
        # below pops the key "user", so this mapping likely should be
        # username='user' -- confirm against the upstream source.
        opts = url.translate_connect_args(username='******')
        opts.update(url.query)

        keys = opts

        query = url.query

        connect_args = {}
        for param in ('ansi', 'unicode_results', 'autocommit'):
            if param in keys:
                connect_args[param.upper()] = util.asbool(keys.pop(param))

        if 'odbc_connect' in keys:
            connectors = [util.unquote_plus(keys.pop('odbc_connect'))]
        else:

            def check_quote(token):
                # Values containing ';' would split the connection string.
                if ";" in str(token):
                    token = "'%s'" % token
                return token

            keys = dict((k.lower(), check_quote(v)) for k, v in keys.items())

            dsn_connection = ('dsn' in keys or
                              ('host' in keys and 'database' not in keys))
            if dsn_connection:
                connectors = [
                    'DSN=%s' % (keys.pop('host', '') or keys.pop('dsn', ''))
                ]
                connectors.extend(['HOST=', 'PORT=', 'Schema='])
            else:
                port = ''
                # A port given in the query string is left in ``keys`` and
                # appended verbatim below.
                if 'port' in keys and 'port' not in query:
                    port = '%d' % int(keys.pop('port'))

                connectors = []
                driver = keys.pop('driver', self.pyodbc_driver_name)
                if driver is None:
                    util.warn("No driver name specified; "
                              "this is expected by PyODBC when using "
                              "DSN-less connections")
                else:
                    connectors.append("DRIVER={%s}" % driver)
                connectors.extend([
                    'HOST=%s' % keys.pop('host', ''),
                    'PORT=%s' % port,
                    'Schema=%s' % keys.pop('database', '')
                ])

            user = keys.pop("user", None)
            if not user:
                connectors.append("Trusted_Connection=Yes")
            elif 'password' in keys:
                connectors.append("UID=%s" % user)
                connectors.append("PWD=%s" % keys.pop('password', ''))
            # A user without a password adds no credential entries.

            # if set to 'Yes', the ODBC layer will try to automagically
            # convert textual data from your database encoding to your
            # client encoding.  This should obviously be set to 'No' if
            # you query a cp1253 encoded database from a latin1 client...
            if 'odbc_autotranslate' in keys:
                connectors.append("AutoTranslate=%s" %
                                  keys.pop("odbc_autotranslate"))

            connectors.append('INTTYPESINRESULTSIFPOSSIBLE=y')
            connectors.extend(['%s=%s' % (k, v) for k, v in keys.items()])
        return [[";".join(connectors)], connect_args]

    def is_disconnect(self, e, connection, cursor):
        """Return True if *e* signals that the connection was lost.

        For DB-API errors, ``e.args[0]`` is treated as the SQLSTATE. Codes
        in ``_dremio_disconnect_patterns`` are decided by matching
        ``e.args[1]`` against the listed phrases. Other codes are checked
        against ``_disconnect_error_codes``. Any other exception is
        delegated to the parent dialect.
        """
        if isinstance(e, self.dbapi.Error):
            args = e.args
            # pyodbc errors normally carry (sqlstate, message); tolerate
            # errors with fewer arguments instead of raising ValueError.
            error_code = args[0] if len(args) > 0 else None
            error_msg = args[1] if len(args) > 1 else ''

            patterns = self._dremio_disconnect_patterns.get(error_code)
            if patterns is not None:
                if not isinstance(error_msg, six.string_types):
                    error_msg = str(error_msg)
                return any(p.search(error_msg) for p in patterns)

            # Check Pyodbc error
            return error_code in self._disconnect_error_codes

        return super(DremioDialect_pyodbc,
                     self).is_disconnect(e, connection, cursor)
Exemple #26
0
class SQLiteDialect_pysqlite(SQLiteDialect):
    """SQLite dialect driven by pysqlite2, or the stdlib ``sqlite3`` module."""

    default_paramstyle = 'qmark'

    colspecs = util.update_copy(
        SQLiteDialect.colspecs,
        {
            sqltypes.Date: _SQLite_pysqliteDate,
            sqltypes.TIMESTAMP: _SQLite_pysqliteTimeStamp,
        },
    )

    # Only set on non-Python-2 interpreters.
    if not util.py2k:
        description_encoding = None

    driver = 'pysqlite'

    def __init__(self, **kwargs):
        """Initialise the dialect; warn when pysqlite is older than 2.1.3."""
        SQLiteDialect.__init__(self, **kwargs)

        if self.dbapi is None:
            return

        version = self.dbapi.version_info
        if version >= (2, 1, 3):
            return

        util.warn(
            ("The installed version of pysqlite2 (%s) is out-dated "
             "and will cause errors in some cases.  Version 2.1.3 "
             "or greater is recommended.") %
            '.'.join(str(part) for part in version))

    @classmethod
    def dbapi(cls):
        """Return ``pysqlite2.dbapi2``, falling back to ``sqlite3.dbapi2``.

        If neither module imports, the ImportError from pysqlite2 is raised.
        """
        try:
            from pysqlite2 import dbapi2 as module
        except ImportError as first_error:
            try:
                from sqlite3 import dbapi2 as module  # stdlib name, 2.5+
            except ImportError:
                raise first_error
        return module

    @classmethod
    def get_pool_class(cls, url):
        """Use SingletonThreadPool for in-memory databases, NullPool otherwise."""
        in_memory = not url.database or url.database == ':memory:'
        return pool.SingletonThreadPool if in_memory else pool.NullPool

    def _get_server_version_info(self, connection):
        """Return the SQLite library version tuple reported by the DB-API."""
        return self.dbapi.sqlite_version_info

    def create_connect_args(self, url):
        """Translate *url* into ``([filename], connect_kwargs)``.

        Raises ArgumentError when the URL carries a username, password, host
        or port. A missing database means ``:memory:``; any other filename is
        made absolute. Known query options are coerced to their types.
        """
        if url.username or url.password or url.host or url.port:
            raise exc.ArgumentError("Invalid SQLite URL: %s\n"
                                    "Valid SQLite URL forms are:\n"
                                    " sqlite:///:memory: (or, sqlite://)\n"
                                    " sqlite:///relative/path/to/file.db\n"
                                    " sqlite:////absolute/path/to/file.db" %
                                    (url, ))

        filename = url.database or ':memory:'
        if filename != ':memory:':
            filename = os.path.abspath(filename)

        opts = url.query.copy()
        coercions = (
            ('timeout', float),
            ('isolation_level', str),
            ('detect_types', int),
            ('check_same_thread', bool),
            ('cached_statements', int),
        )
        for option_name, option_type in coercions:
            util.coerce_kw_type(opts, option_name, option_type)

        return ([filename], opts)

    def is_disconnect(self, e, connection, cursor):
        """Return True for the ProgrammingError raised on a closed database."""
        if not isinstance(e, self.dbapi.ProgrammingError):
            return False
        return "Cannot operate on a closed database." in str(e)
Exemple #27
0
class MSSQLCompiler(compiler.SQLCompiler):
    """SQL compiler for Microsoft SQL Server.

    Renders LIMIT as ``TOP``, emulates OFFSET with ``ROW_NUMBER()``,
    aliases schema-qualified tables, and maps several generic functions and
    operators onto their T-SQL spellings.
    """

    returning_precedes_values = True

    extract_map = util.update_copy(
        compiler.SQLCompiler.extract_map, {
            'doy': 'dayofyear',
            'dow': 'weekday',
            'milliseconds': 'millisecond',
            'microseconds': 'microsecond'
        })

    def __init__(self, *args, **kwargs):
        super(MSSQLCompiler, self).__init__(*args, **kwargs)
        # Schema-qualified table -> Alias rendered in its place.
        self.tablealiases = {}

    def visit_now_func(self, fn, **kw):
        return "CURRENT_TIMESTAMP"

    def visit_current_date_func(self, fn, **kw):
        return "GETDATE()"

    def visit_length_func(self, fn, **kw):
        return "LEN%s" % self.function_argspec(fn, **kw)

    def visit_char_length_func(self, fn, **kw):
        return "LEN%s" % self.function_argspec(fn, **kw)

    def visit_concat_op(self, binary, **kw):
        """T-SQL concatenates strings with ``+``."""
        return "%s + %s" % (self.process(binary.left, **kw),
                            self.process(binary.right, **kw))

    def visit_match_op(self, binary, **kw):
        return "CONTAINS (%s, %s)" % (self.process(binary.left, **kw),
                                      self.process(binary.right, **kw))

    def get_select_precolumns(self, select):
        """ MS-SQL puts TOP, it's version of LIMIT here """
        if select._distinct or select._limit:
            s = select._distinct and "DISTINCT " or ""

            # With an OFFSET, the limit is applied by visit_select's
            # ROW_NUMBER() wrapper instead of TOP.
            if select._limit:
                if not select._offset:
                    s += "TOP %s " % (select._limit, )
            return s
        return compiler.SQLCompiler.get_select_precolumns(self, select)

    def limit_clause(self, select):
        # Limit in mssql is after the select keyword
        return ""

    def visit_select(self, select, **kwargs):
        """Look for ``LIMIT`` and OFFSET in a select statement, and if
        so tries to wrap it in a subquery with ``row_number()`` criterion.

        Raises ``InvalidRequestError`` when an OFFSET is used without an
        ORDER BY, because ROW_NUMBER() requires an ordering.
        """
        if not getattr(select, '_mssql_visit', None) and select._offset:
            # to use ROW_NUMBER(), an ORDER BY is required.
            orderby = self.process(select._order_by_clause)
            if not orderby:
                raise exc.InvalidRequestError(
                    'MSSQL requires an order_by when '
                    'using an offset.')

            _offset = select._offset
            _limit = select._limit
            # Flag a private copy, not the caller's statement: marking the
            # original would make any later compilation of the same
            # statement skip this wrapping and silently lose the OFFSET.
            select = select._generate()
            select._mssql_visit = True
            select = select.column(
                sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)"
                                   % orderby).label("mssql_rn")
            ).order_by(None).alias()

            limitselect = sql.select(
                [c for c in select.c if c.key != 'mssql_rn'])
            limitselect.append_whereclause("mssql_rn>%d" % _offset)
            if _limit is not None:
                limitselect.append_whereclause("mssql_rn<=%d" %
                                               (_limit + _offset))
            return self.process(limitselect, iswrapper=True, **kwargs)
        else:
            return compiler.SQLCompiler.visit_select(self, select, **kwargs)

    def _schema_aliased_table(self, table):
        """Return (creating on first use) the alias for a schema-qualified
        table, or None when *table* has no schema."""
        if getattr(table, 'schema', None) is not None:
            if table not in self.tablealiases:
                self.tablealiases[table] = table.alias()
            return self.tablealiases[table]
        else:
            return None

    def visit_table(self, table, mssql_aliased=False, **kwargs):
        if mssql_aliased:
            return super(MSSQLCompiler, self).visit_table(table, **kwargs)

        # alias schema-qualified tables
        alias = self._schema_aliased_table(table)
        if alias is not None:
            return self.process(alias, mssql_aliased=True, **kwargs)
        else:
            return super(MSSQLCompiler, self).visit_table(table, **kwargs)

    def visit_alias(self, alias, **kwargs):
        # translate for schema-qualified table aliases
        self.tablealiases[alias.original] = alias
        kwargs['mssql_aliased'] = True
        return super(MSSQLCompiler, self).visit_alias(alias, **kwargs)

    def visit_extract(self, extract, **kw):
        field = self.extract_map.get(extract.field, extract.field)
        return 'DATEPART("%s", %s)' % \
                        (field, self.process(extract.expr, **kw))

    def visit_rollback_to_savepoint(self, savepoint_stmt):
        return ("ROLLBACK TRANSACTION %s" %
                self.preparer.format_savepoint(savepoint_stmt))

    def visit_column(self, column, result_map=None, **kwargs):
        # Precedence made explicit (unchanged semantics): the subquery test
        # is OR-ed with the whole table/not-update/not-delete conjunction.
        if ((column.table is not None and
                not self.isupdate and not self.isdelete) or
                self.is_subquery()):
            # translate for schema-qualified table aliases
            t = self._schema_aliased_table(column.table)
            if t is not None:
                converted = expression._corresponding_column_or_error(
                    t, column)

                # Register the original column so results still map to it.
                if result_map is not None:
                    result_map[column.name.lower()] = (
                        column.name, (column, ), column.type)

                return super(MSSQLCompiler, self).visit_column(
                    converted, result_map=None, **kwargs)

        return super(MSSQLCompiler, self).visit_column(column,
                                                       result_map=result_map,
                                                       **kwargs)

    def visit_binary(self, binary, **kwargs):
        """Move bind parameters to the right-hand side of an operator, where
        possible.

        Equality/inequality against a scalar subquery is rendered as
        IN / NOT IN.
        """
        if (isinstance(binary.left, expression._BindParamClause)
                and binary.operator == operator.eq
                and not isinstance(binary.right, expression._BindParamClause)):
            return self.process(
                expression._BinaryExpression(binary.right, binary.left,
                                             binary.operator), **kwargs)
        else:
            if ((binary.operator is operator.eq
                 or binary.operator is operator.ne) and
                ((isinstance(binary.left, expression._FromGrouping) and
                  isinstance(binary.left.element, expression._ScalarSelect)) or
                 (isinstance(binary.right, expression._FromGrouping) and
                  isinstance(binary.right.element, expression._ScalarSelect))
                 or isinstance(binary.left, expression._ScalarSelect)
                 or isinstance(binary.right, expression._ScalarSelect))):
                op = binary.operator == operator.eq and "IN" or "NOT IN"
                return self.process(
                    expression._BinaryExpression(binary.left, binary.right,
                                                 op), **kwargs)
            return super(MSSQLCompiler, self).visit_binary(binary, **kwargs)

    def returning_clause(self, stmt, returning_cols):
        """Render RETURNING as T-SQL ``OUTPUT`` over the ``inserted`` (INSERT
        / UPDATE) or ``deleted`` (DELETE) pseudo-table."""

        if self.isinsert or self.isupdate:
            target = stmt.table.alias("inserted")
        else:
            target = stmt.table.alias("deleted")

        adapter = sql_util.ClauseAdapter(target)

        def col_label(col):
            adapted = adapter.traverse(col)
            if isinstance(col, expression._Label):
                # Use the column being labelled; the comprehension variable
                # below is not visible here on Python 3.
                return adapted.label(col.key)
            else:
                return self.label_select_column(None, adapted, asfrom=False)

        columns = [
            self.process(col_label(c),
                         within_columns_clause=True,
                         result_map=self.result_map)
            for c in expression._select_iterables(returning_cols)
        ]
        return 'OUTPUT ' + ', '.join(columns)

    def label_select_column(self, select, column, asfrom):
        # Functions get an anonymous label.
        if isinstance(column, expression.Function):
            return column.label(None)
        else:
            return super(MSSQLCompiler, self).\
                            label_select_column(select, column, asfrom)

    def for_update_clause(self, select):
        # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which
        # SQLAlchemy doesn't use
        return ''

    def order_by_clause(self, select, **kw):
        order_by = self.process(select._order_by_clause, **kw)

        # MSSQL only allows ORDER BY in subqueries if there is a LIMIT
        if order_by and (not self.is_subquery() or select._limit):
            return " ORDER BY " + order_by
        else:
            return ""
Exemple #28
0
class MySQLDialect_oursql(MySQLDialect):
    """MySQL dialect for the oursql DB-API driver.

    Several statements cannot go through oursql's parameterized query API,
    so they are sent in oursql's plain-query mode (``plain_query=True`` on
    the cursor, driven by the ``_oursql_plain_query`` execution option).
    """

    driver = 'oursql'
# Py3K
#    description_encoding = None
# Py2K
    supports_unicode_binds = True
    supports_unicode_statements = True
# end Py2K

    supports_native_decimal = True

    supports_sane_rowcount = True
    supports_sane_multi_rowcount = True
    execution_ctx_cls = MySQLExecutionContext_oursql

    colspecs = util.update_copy(
        MySQLDialect.colspecs,
        {
            sqltypes.Time: sqltypes.Time,
            BIT: _oursqlBIT,
        }
    )

    @classmethod
    def dbapi(cls):
        return __import__('oursql')

    def do_execute(self, cursor, statement, parameters, context=None):
        """Provide an implementation of *cursor.execute(statement, parameters)*."""

        if context and context.plain_query:
            cursor.execute(statement, plain_query=True)
        else:
            cursor.execute(statement, parameters)

    def do_begin(self, connection):
        connection.cursor().execute('BEGIN', plain_query=True)

    def _xa_query(self, connection, query, xid):
        """Run an XA statement *query* (one ``%s`` slot) for escaped *xid*."""
# Py2K
        arg = connection.connection._escape_string(xid)
# end Py2K
# Py3K
#        charset = self._connection_charset
#        arg = connection.connection._escape_string(xid.encode(charset)).decode(charset)
        connection.execution_options(_oursql_plain_query=True).execute(query % arg)

    # Because mysql is bad, these methods have to be
    # reimplemented to use _PlainQuery. Basically, some queries
    # refuse to return any data if they're run through
    # the parameterized query API, or refuse to be parameterized
    # in the first place.
    def do_begin_twophase(self, connection, xid):
        self._xa_query(connection, 'XA BEGIN "%s"', xid)

    def do_prepare_twophase(self, connection, xid):
        self._xa_query(connection, 'XA END "%s"', xid)
        self._xa_query(connection, 'XA PREPARE "%s"', xid)

    def do_rollback_twophase(self, connection, xid, is_prepared=True,
                             recover=False):
        if not is_prepared:
            self._xa_query(connection, 'XA END "%s"', xid)
        self._xa_query(connection, 'XA ROLLBACK "%s"', xid)

    def do_commit_twophase(self, connection, xid, is_prepared=True,
                           recover=False):
        if not is_prepared:
            self.do_prepare_twophase(connection, xid)
        self._xa_query(connection, 'XA COMMIT "%s"', xid)

    def _plain_query_connection(self, connection):
        """Return ``connection.connect()`` with the plain-query option set."""
        return connection.connect().execution_options(_oursql_plain_query=True)

    # Q: why didn't we need all these "plain_query" overrides earlier ?
    # am i on a newer/older version of OurSQL ?
    def has_table(self, connection, table_name, schema=None):
        return MySQLDialect.has_table(self,
                                      self._plain_query_connection(connection),
                                      table_name, schema)

    def get_table_options(self, connection, table_name, schema=None, **kw):
        return MySQLDialect.get_table_options(
            self,
            self._plain_query_connection(connection),
            table_name,
            schema=schema,
            **kw
        )

    def get_columns(self, connection, table_name, schema=None, **kw):
        return MySQLDialect.get_columns(
            self,
            self._plain_query_connection(connection),
            table_name,
            schema=schema,
            **kw
        )

    def get_view_names(self, connection, schema=None, **kw):
        return MySQLDialect.get_view_names(
            self,
            self._plain_query_connection(connection),
            schema=schema,
            **kw
        )

    def get_table_names(self, connection, schema=None, **kw):
        # Forward **kw (e.g. the reflection info cache) like the sibling
        # reflection methods do.
        return MySQLDialect.get_table_names(
            self,
            self._plain_query_connection(connection),
            schema=schema,
            **kw
        )

    def get_schema_names(self, connection, **kw):
        return MySQLDialect.get_schema_names(
            self,
            self._plain_query_connection(connection),
            **kw
        )

    def initialize(self, connection):
        return MySQLDialect.initialize(
            self,
            connection.execution_options(_oursql_plain_query=True)
        )

    def _show_create_table(self, connection, table, charset=None,
                           full_name=None):
        return MySQLDialect._show_create_table(
            self,
            connection.contextual_connect(close_with_result=True).
            execution_options(_oursql_plain_query=True),
            table, charset, full_name)

    def is_disconnect(self, e, connection=None, cursor=None):
        """Return True if *e* indicates the MySQL connection was lost.

        ``connection`` and ``cursor`` are accepted and ignored, so this works
        both as ``is_disconnect(e)`` and ``is_disconnect(e, connection,
        cursor)``.
        """
        if isinstance(e, self.dbapi.ProgrammingError):
            # Errors without an errno whose message ends in "closed"
            # (but do not mention the cursor).
            return e.errno is None and 'cursor' not in e.args[1] and e.args[1].endswith('closed')
        else:
            # MySQL client error codes treated as disconnects.
            return e.errno in (2006, 2013, 2014, 2045, 2055)

    def create_connect_args(self, url):
        """Translate *url* into ``[[], connect_kwargs]`` for oursql.connect()."""
        opts = url.translate_connect_args(database='db', username='******',
                                          password='******')
        opts.update(url.query)

        util.coerce_kw_type(opts, 'port', int)
        util.coerce_kw_type(opts, 'compress', bool)
        util.coerce_kw_type(opts, 'autoping', bool)

        util.coerce_kw_type(opts, 'default_charset', bool)
        if opts.pop('default_charset', False):
            opts['charset'] = None
        else:
            util.coerce_kw_type(opts, 'charset', str)
        opts['use_unicode'] = opts.get('use_unicode', True)
        util.coerce_kw_type(opts, 'use_unicode', bool)

        # FOUND_ROWS must be set in CLIENT_FLAGS to enable
        # supports_sane_rowcount.
        opts.setdefault('found_rows', True)

        # Collect ssl_* options into a nested dict without the prefix.
        ssl = {}
        for key in ['ssl_ca', 'ssl_key', 'ssl_cert',
                    'ssl_capath', 'ssl_cipher']:
            if key in opts:
                ssl[key[4:]] = opts[key]
                util.coerce_kw_type(ssl, key[4:], str)
                del opts[key]
        if ssl:
            opts['ssl'] = ssl

        return [[], opts]

    def _get_server_version_info(self, connection):
        """Split the server version string on '.' and '-'; numeric parts
        become ints, other parts stay strings."""
        dbapi_con = connection.connection
        version = []
        # Raw string: '\-' is not a valid escape in a normal string literal.
        r = re.compile(r'[.\-]')
        for n in r.split(dbapi_con.server_info):
            try:
                version.append(int(n))
            except ValueError:
                version.append(n)
        return tuple(version)

    def _extract_error_code(self, exception):
        return exception.errno

    def _detect_charset(self, connection):
        """Sniff out the character set in use for connection results."""

        return connection.connection.charset

    def _compat_fetchall(self, rp, charset=None):
        """oursql isn't super-broken like MySQLdb, yaaay."""
        return rp.fetchall()

    def _compat_fetchone(self, rp, charset=None):
        """oursql isn't super-broken like MySQLdb, yaaay."""
        return rp.fetchone()

    def _compat_first(self, rp, charset=None):
        return rp.first()
Exemple #29
0
class OracleCompiler(compiler.SQLCompiler):
    """Oracle compiler modifies the lexical structure of Select
    statements to work under non-ANSI configured Oracle databases, if
    the use_ansi flag is False.

    It also emulates LIMIT/OFFSET by wrapping selects in ROWNUM-filtered
    subqueries, renders EXCEPT as MINUS, and appends ``FROM DUAL`` to
    selects that have no FROM clause.
    """

    compound_keywords = util.update_copy(
        compiler.SQLCompiler.compound_keywords,
        {expression.CompoundSelect.EXCEPT: 'MINUS'})

    def __init__(self, *args, **kwargs):
        self.__wheres = {}
        self._quoted_bind_names = {}
        super(OracleCompiler, self).__init__(*args, **kwargs)

    def visit_mod_binary(self, binary, operator, **kw):
        """Render the modulo operator as Oracle's ``mod(a, b)``."""
        return "mod(%s, %s)" % (self.process(binary.left, **kw),
                                self.process(binary.right, **kw))

    def visit_now_func(self, fn, **kw):
        return "CURRENT_TIMESTAMP"

    def visit_char_length_func(self, fn, **kw):
        return "LENGTH" + self.function_argspec(fn, **kw)

    def visit_match_op_binary(self, binary, operator, **kw):
        """Render MATCH as ``CONTAINS (left, right)``."""
        # Forward **kw (e.g. literal_binds) to both operands, as
        # visit_mod_binary does.
        return "CONTAINS (%s, %s)" % (self.process(binary.left, **kw),
                                      self.process(binary.right, **kw))

    def visit_true(self, expr, **kw):
        return '1'

    def visit_false(self, expr, **kw):
        return '0'

    def get_select_hint_text(self, byfroms):
        """Render each hint as an Oracle ``/*+ ... */`` comment."""
        return " ".join("/*+ %s */" % text for text in byfroms.values())

    def function_argspec(self, fn, **kw):
        # Functions in NO_ARG_FNS are rendered without "()" when called
        # with no arguments.
        if len(fn.clauses) > 0 or fn.name.upper() not in NO_ARG_FNS:
            return compiler.SQLCompiler.function_argspec(self, fn, **kw)
        else:
            return ""

    def default_from(self):
        """Called when a ``SELECT`` statement has no froms,
        and no ``FROM`` clause is to be appended.

        The Oracle compiler tacks a "FROM DUAL" to the statement.
        """

        return " FROM DUAL"

    def visit_join(self, join, **kwargs):
        """Render a JOIN; without use_ansi, render both sides as a
        comma-separated FROM list (join criteria go to WHERE instead)."""
        if self.dialect.use_ansi:
            return compiler.SQLCompiler.visit_join(self, join, **kwargs)
        else:
            kwargs['asfrom'] = True
            if isinstance(join.right, expression.FromGrouping):
                right = join.right.element
            else:
                right = join.right
            return self.process(join.left, **kwargs) + \
                        ", " + self.process(right, **kwargs)

    def _get_nonansi_join_whereclause(self, froms):
        """Collect the ON clauses of all (nested) joins in *froms* into one
        AND-ed criterion, marking outer-joined columns with ``(+)``.

        Returns None when there are no join clauses.
        """
        clauses = []

        def visit_join(join):
            if join.isouter:

                def visit_binary(binary):
                    # For equality comparisons, wrap the side belonging to
                    # the outer-joined (right) table.
                    if binary.operator == sql_operators.eq:
                        if join.right.is_derived_from(binary.left.table):
                            binary.left = _OuterJoinColumn(binary.left)
                        elif join.right.is_derived_from(binary.right.table):
                            binary.right = _OuterJoinColumn(binary.right)

                clauses.append(
                    visitors.cloned_traverse(join.onclause, {},
                                             {'binary': visit_binary}))
            else:
                clauses.append(join.onclause)

            for j in join.left, join.right:
                if isinstance(j, expression.Join):
                    visit_join(j)
                elif isinstance(j, expression.FromGrouping):
                    visit_join(j.element)

        for f in froms:
            if isinstance(f, expression.Join):
                visit_join(f)

        if not clauses:
            return None
        else:
            return sql.and_(*clauses)

    def visit_outer_join_column(self, vc, **kw):
        """Render an outer-joined column with Oracle's ``(+)`` marker."""
        return self.process(vc.column, **kw) + "(+)"

    def visit_sequence(self, seq):
        return self.dialect.identifier_preparer.format_sequence(
            seq) + ".nextval"

    def visit_alias(self, alias, asfrom=False, ashint=False, **kwargs):
        """Render an alias without ``AS``: Oracle rejects
        ``FROM table AS alias``."""

        if asfrom or ashint:
            if isinstance(alias.name, expression._truncated_label):
                alias_name = self._truncated_identifier("alias", alias.name)
            else:
                alias_name = alias.name

        if ashint:
            return alias_name
        elif asfrom:
            return self.process(alias.original, asfrom=asfrom, **kwargs) + \
                            " " + self.preparer.format_alias(alias, alias_name)
        else:
            return self.process(alias.original, **kwargs)

    def returning_clause(self, stmt, returning_cols):
        """Render ``RETURNING cols INTO :ret_0, ...`` using OUT parameters,
        registering each parameter in ``self.binds`` and ``self.result_map``."""
        columns = []
        binds = []
        for i, column in enumerate(
                expression._select_iterables(returning_cols)):
            if column.type._has_column_expression:
                col_expr = column.type.column_expression(column)
            else:
                col_expr = column
            outparam = sql.outparam("ret_%d" % i, type_=column.type)
            self.binds[outparam.key] = outparam
            binds.append(
                self.bindparam_string(self._truncate_bindparam(outparam)))
            columns.append(self.process(col_expr, within_columns_clause=False))
            self.result_map[outparam.key] = (outparam.key,
                                             (column,
                                              getattr(column, 'name', None),
                                              getattr(column, 'key',
                                                      None)), column.type)

        return 'RETURNING ' + ', '.join(columns) + " INTO " + ", ".join(binds)

    def _TODO_visit_compound_select(self, select):
        """Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle."""
        pass

    def visit_select(self, select, **kwargs):
        """Look for ``LIMIT`` and OFFSET in a select statement, and if
        so tries to wrap it in a subquery with ``rownum`` criterion.
        """

        if not getattr(select, '_oracle_visit', None):
            if not self.dialect.use_ansi:
                froms = self._display_froms_for_select(
                    select, kwargs.get('asfrom', False))
                whereclause = self._get_nonansi_join_whereclause(froms)
                if whereclause is not None:
                    select = select.where(whereclause)
                    select._oracle_visit = True

            if select._limit is not None or select._offset is not None:
                # See http://www.oracle.com/technology/oramag/oracle/06-sep/o56asktom.html
                #
                # Generalized form of an Oracle pagination query:
                #   select ... from (
                #     select /*+ FIRST_ROWS(N) */ ...., rownum as ora_rn from (
                #         select distinct ... where ... order by ...
                #     ) where ROWNUM <= :limit+:offset
                #   ) where ora_rn > :offset
                # Outer select and "ROWNUM as ora_rn" can be dropped if limit=0

                # TODO: use annotations instead of clone + attr set ?
                select = select._generate()
                select._oracle_visit = True

                # Wrap the middle select and add the hint
                limitselect = sql.select([c for c in select.c])
                if select._limit and self.dialect.optimize_limits:
                    limitselect = limitselect.prefix_with(
                        "/*+ FIRST_ROWS(%d) */" % select._limit)

                limitselect._oracle_visit = True
                limitselect._is_wrapper = True

                # If needed, add the limiting clause
                if select._limit is not None:
                    max_row = select._limit
                    if select._offset is not None:
                        max_row += select._offset
                    if not self.dialect.use_binds_for_limits:
                        max_row = sql.literal_column("%d" % max_row)
                    limitselect.append_whereclause(
                        sql.literal_column("ROWNUM") <= max_row)

                # If needed, add the ora_rn, and wrap again with offset.
                if select._offset is None:
                    limitselect.for_update = select.for_update
                    select = limitselect
                else:
                    limitselect = limitselect.column(
                        sql.literal_column("ROWNUM").label("ora_rn"))
                    limitselect._oracle_visit = True
                    limitselect._is_wrapper = True

                    offsetselect = sql.select(
                        [c for c in limitselect.c if c.key != 'ora_rn'])
                    offsetselect._oracle_visit = True
                    offsetselect._is_wrapper = True

                    offset_value = select._offset
                    if not self.dialect.use_binds_for_limits:
                        offset_value = sql.literal_column("%d" % offset_value)
                    offsetselect.append_whereclause(
                        sql.literal_column("ora_rn") > offset_value)

                    offsetselect.for_update = select.for_update
                    select = offsetselect

        kwargs['iswrapper'] = getattr(select, '_is_wrapper', False)
        return compiler.SQLCompiler.visit_select(self, select, **kwargs)

    def limit_clause(self, select):
        # LIMIT/OFFSET are handled by the ROWNUM wrapping in visit_select.
        return ""

    def for_update_clause(self, select):
        if self.is_subquery():
            return ""
        elif select.for_update == "nowait":
            return " FOR UPDATE NOWAIT"
        else:
            return super(OracleCompiler, self).for_update_clause(select)
Exemple #30
0
class ctsqlCompiler(compiler.SQLCompiler):
    """SQL compiler for the c-tree SQL dialect.

    Spells EXCEPT as MINUS, renders LIMIT/OFFSET as TOP/SKIP keywords
    directly after SELECT, and compiles EXTRACT() through STRFTIME().
    """

    # SQLAlchemy compiler flag: presumably restricts bind-parameter
    # placement to what ANSI SQL allows (rendering literals otherwise)
    # -- confirm against the SQLAlchemy version in use.
    ansi_bind_rules = True

    # Set operations: the EXCEPT operator is written as MINUS.
    compound_keywords = util.update_copy(
        compiler.SQLCompiler.compound_keywords,
        {expression.CompoundSelect.EXCEPT: 'MINUS'},
    )

    # EXTRACT field name -> format code handed to STRFTIME() by
    # visit_extract. The codes look like SQLite/C strftime directives;
    # NOTE(review): confirm c-tree's STRFTIME accepts each of them.
    extract_map = util.update_copy(
        compiler.SQLCompiler.extract_map,
        dict(
            month='%m',
            day='%d',
            year='%Y',
            second='%S',
            hour='%H',
            doy='%j',
            minute='%M',
            epoch='%s',
            dow='%w',
            week='%W',
        ),
    )

    def visit_now_func(self, fn, **kw):
        """Render the ``now()`` function as the CURRENT_TIMESTAMP keyword."""
        keyword = "CURRENT_TIMESTAMP"
        return keyword

    def visit_localtimestamp_func(self, func, **kw):
        """Render ``localtimestamp()`` as CURRENT_TIMESTAMP passed through
        ``DATETIME(..., "localtime")``."""
        template = 'DATETIME(%s, "localtime")'
        return template % "CURRENT_TIMESTAMP"

    def visit_null(self, expr, **kw):
        """Render the SQL NULL constant."""
        null_keyword = 'NULL'
        return null_keyword

    def visit_char_length_func(self, fn, **kw):
        """Render ``char_length(...)`` as ``length(...)`` with the same
        argument list."""
        argspec = self.function_argspec(fn)
        return "length%s" % argspec

    def visit_cast(self, cast, **kwargs):
        """Render a CAST expression.

        When the dialect reports ``supports_cast`` as false, the cast is
        dropped and only the wrapped expression is compiled. Otherwise the
        parent compiler renders the CAST.
        """
        if not self.dialect.supports_cast:
            # No CAST support: emit the inner expression unchanged.
            return self.process(cast.clause, **kwargs)
        return super(ctsqlCompiler, self).visit_cast(cast, **kwargs)

    def get_select_precolumns(self, select, **kw):
        """Render the keywords placed between SELECT and the column list.

        c-tree SQL puts its version of LIMIT/OFFSET here; ``limit_clause``
        returns an empty string. The prefix is built in this order:
        ``DISTINCT``, ``TOP <limit>``, ``SKIP <offset>``.

        :param select: the SELECT construct being compiled.
        :returns: the prefix string, or the base compiler's prefix when
            none of the above apply.
        :raises CompileError: (via SQLAlchemy's ``_limit``/``_offset``
            accessors) when LIMIT or OFFSET is an SQL expression rather
            than a plain integer, because TOP/SKIP are rendered as literals.
        """
        s = ""
        if select._distinct:
            s += "DISTINCT "

        # Read the limit through ``_limit`` rather than gating on
        # ``_simple_int_limit``. A non-integer LIMIT then raises (as
        # ``_offset`` already does below) instead of being silently
        # dropped, which would return every row.
        limit = select._limit
        if limit is not None:
            s += "TOP %d " % limit

        offset = select._offset
        if offset:
            s += "SKIP %d " % offset

        if s:
            return s
        return compiler.SQLCompiler.get_select_precolumns(self, select, **kw)

    def visit_extract(self, extract, **kw):
        """Compile EXTRACT(field FROM expr) as an integer STRFTIME() call.

        ``extract.field`` is translated through ``extract_map`` into a
        STRFTIME format code, and the formatted result is CAST to INTEGER.

        :raises CompileError: if ``extract.field`` has no entry in
            ``extract_map``.
        """
        # Guard only the map lookup. A KeyError raised while compiling the
        # inner expression must not be misreported as an invalid field.
        try:
            fmt = self.extract_map[extract.field]
        except KeyError:
            raise exc.CompileError("%s is not a valid extract argument." %
                                   extract.field)
        return "CAST(STRFTIME('%s', %s) AS INTEGER)" % (
            fmt, self.process(extract.expr, **kw))

    def limit_clause(self, select, **kw):
        """Return no trailing LIMIT clause.

        Row limits are rendered as TOP/SKIP immediately after the SELECT
        keyword by ``get_select_precolumns``.
        """
        no_suffix = ""
        return no_suffix