Esempio n. 1
0
 def _baseline_1_create_tables(self):
     """Declare the ``Zoo`` and ``Animal`` tables on ``self.metadata``
     and then call ``create_all()`` on that metadata."""
     zoo_columns = [
         Column(
             "ID",
             Integer,
             Sequence("zoo_id_seq"),
             primary_key=True,
             index=True,
         ),
         Column("Name", Unicode(255)),
         Column("Founded", Date),
         Column("Opens", Time),
         Column("LastEscape", DateTime),
         Column("Admission", Float),
     ]
     Table("Zoo", self.metadata, *zoo_columns)

     # ZooID points at Zoo.ID; MotherID points back at Animal.ID itself.
     animal_columns = [
         Column("ID", Integer, Sequence("animal_id_seq"), primary_key=True),
         Column("ZooID", Integer, ForeignKey("Zoo.ID"), index=True),
         Column("Name", Unicode(100)),
         Column("Species", Unicode(100)),
         Column("Legs", Integer, default=4),
         Column("LastEscape", DateTime),
         Column("Lifespan", Float(4)),
         Column("MotherID", Integer, ForeignKey("Animal.ID")),
         Column("PreferredFoodID", Integer),
         Column("AlternateFoodID", Integer),
     ]
     Table("Animal", self.metadata, *animal_columns)

     self.metadata.create_all()
Esempio n. 2
0
    def test_no_pk(self):
        """Reflect the unique indexes of a table declared without a
        primary key and compare them with the expected index records."""
        metadata = self.metadata

        # (index name, ordered column names): both indexes cover the same
        # two columns, in opposite order.
        index_specs = [
            ("pk_idx_1", ["id_a", "id_b"]),
            ("pk_idx_2", ["id_b", "id_a"]),
        ]

        Table(
            "sometable",
            metadata,
            Column("id_a", Unicode(255)),
            Column("id_b", Unicode(255)),
            *[Index(name, *cols, unique=True) for name, cols in index_specs]
        )
        metadata.create_all()

        expected = [
            {
                "name": name,
                "column_names": cols,
                "dialect_options": {},
                "unique": True,
            }
            for name, cols in index_specs
        ]
        eq_(inspect(testing.db).get_indexes("sometable"), expected)
Esempio n. 3
0
 def test_string_types(self):
     """Hand (input type, expected MySQL type) pairs to ``_run_test``,
     asking it to consider the ``length`` attribute."""
     # Generic string types paired with their expected MySQL types.
     generic_pairs = [
         (String(1), mysql.MSString(1)),
         (String(3), mysql.MSString(3)),
         (Text(), mysql.MSText()),
         (Unicode(1), mysql.MSString(1)),
         (Unicode(3), mysql.MSString(3)),
         (UnicodeText(), mysql.MSText()),
     ]
     # CHAR / national-character variants.
     char_pairs = [
         (mysql.MSChar(1), mysql.MSChar(1)),
         (mysql.MSChar(3), mysql.MSChar(3)),
         (NCHAR(2), mysql.MSChar(2)),
         (mysql.MSNChar(2), mysql.MSChar(2)),
         (mysql.MSNVarChar(22), mysql.MSString(22)),
     ]
     self._run_test(generic_pairs + char_pairs, ["length"])
    def test_varchar_raise(self):
        """Compiling a length-less string type against the firebird
        dialect must raise ``CompileError``, both standalone and as part
        of a CREATE TABLE statement."""
        # Classes and instances are both accepted; to_instance() below
        # normalizes each candidate to an instance.
        length_less = (String, VARCHAR, String(), VARCHAR(), Unicode, Unicode())

        bare_message = "VARCHAR requires a length on dialect firebird"
        in_table_message = (
            r"\(in table 'sometable', column 'somecolumn'\)\: "
            r"(?:N)?VARCHAR requires a length on dialect firebird"
        )

        for candidate in length_less:
            type_ = sqltypes.to_instance(candidate)

            # 1. the type compiled on its own
            assert_raises_message(
                exc.CompileError,
                bare_message,
                type_.compile,
                dialect=firebird.dialect(),
            )

            # 2. the type used for a column; the message then also names
            #    the table and the column
            sometable = Table(
                "sometable", MetaData(), Column("somecolumn", type_)
            )
            create_stmt = schema.CreateTable(sometable)
            assert_raises_message(
                exc.CompileError,
                in_table_message,
                create_stmt.compile,
                dialect=firebird.dialect(),
            )
Esempio n. 5
0
    def test_unicode_warnings(self):
        """Run a SELECT that compares a ``Unicode`` column with a plain
        string under ``profile_memory``; the table is dropped afterwards
        whether or not the profiled call succeeds."""
        metadata = MetaData(self.engine)
        mytable = Table(
            "mytable",
            metadata,
            Column(
                "col1",
                Integer,
                primary_key=True,
                test_needs_autoincrement=True,
            ),
            Column("col2", Unicode(30)),
        )
        metadata.create_all()

        # One-element list so the nested function can update the value
        # in place; each call to go() compares against a new string.
        counter = [1]

        @testing.emits_warning()
        @profile_memory()
        def go():
            # Execute with a non-unicode object. A warning is emitted,
            # and this warning shouldn't clog up memory.
            stmt = mytable.select().where(
                mytable.c.col2 == "foo%d" % counter[0]
            )
            self.engine.execute(stmt)
            counter[0] += 1

        try:
            go()
        finally:
            metadata.drop_all()
    def _baseline_1_create_tables(self):
        """Create the ``Zoo`` and ``Animal`` tables, then publish two
        module-level classes of the same names and map each class to its
        table."""
        global Zoo, Animal

        zoo_columns = [
            Column(
                "ID",
                Integer,
                Sequence("zoo_id_seq"),
                primary_key=True,
                index=True,
            ),
            Column("Name", Unicode(255)),
            Column("Founded", Date),
            Column("Opens", Time),
            Column("LastEscape", DateTime),
            Column("Admission", Float),
        ]
        zoo_table = Table("Zoo", self.metadata, *zoo_columns)

        animal_columns = [
            Column("ID", Integer, Sequence("animal_id_seq"), primary_key=True),
            Column("ZooID", Integer, ForeignKey("Zoo.ID"), index=True),
            Column("Name", Unicode(100)),
            Column("Species", Unicode(100)),
            Column("Legs", Integer, default=4),
            Column("LastEscape", DateTime),
            Column("Lifespan", Float(4)),
            Column("MotherID", Integer, ForeignKey("Animal.ID")),
            Column("PreferredFoodID", Integer),
            Column("AlternateFoodID", Integer),
        ]
        animal_table = Table("Animal", self.metadata, *animal_columns)

        self.metadata.create_all()

        # Both classes accept arbitrary keyword arguments and store each
        # one as an attribute of the new instance.
        class Zoo(object):
            def __init__(self, **kwargs):
                for attr_name, attr_value in kwargs.items():
                    setattr(self, attr_name, attr_value)

        class Animal(object):
            def __init__(self, **kwargs):
                for attr_name, attr_value in kwargs.items():
                    setattr(self, attr_name, attr_value)

        mapper(Zoo, zoo_table)
        mapper(Animal, animal_table)
Esempio n. 7
0
    def test_use_nchar(self):
        """With ``use_nchar_for_unicode=True`` on a (10, 2, 5) dialect,
        the Unicode types are expected to compile to the national
        character types."""
        nchar_dialect = self._dialect((10, 2, 5), use_nchar_for_unicode=True)
        nchar_dialect.initialize(Mock())

        assert nchar_dialect._use_nchar_for_unicode

        expected_ddl = (
            (String(50), "VARCHAR2(50 CHAR)"),
            (Unicode(50), "NVARCHAR2(50)"),
            (UnicodeText(), "NCLOB"),
        )
        for type_, ddl in expected_ddl:
            self.assert_compile(type_, ddl, dialect=nchar_dialect)
Esempio n. 8
0
    def test_ora10_flags(self):
        """Check the feature flags and string type DDL of a dialect
        built for version (10, 2, 5) after ``initialize()``."""
        ora10 = self._dialect((10, 2, 5))
        ora10.initialize(Mock())

        assert ora10._supports_char_length
        assert not ora10._use_nchar_for_unicode
        assert ora10.use_ansi

        expected_ddl = (
            (String(50), "VARCHAR2(50 CHAR)"),
            (Unicode(50), "VARCHAR2(50 CHAR)"),
            (UnicodeText(), "CLOB"),
        )
        for type_, ddl in expected_ddl:
            self.assert_compile(type_, ddl, dialect=ora10)
Esempio n. 9
0
    def test_default_flags(self):
        """Check flags and string DDL with no initialization and no
        server version info (``_dialect`` is given ``None``)."""
        uninitialized = self._dialect(None)

        assert uninitialized._supports_char_length
        assert not uninitialized._use_nchar_for_unicode
        assert uninitialized.use_ansi

        expected_ddl = (
            (String(50), "VARCHAR2(50 CHAR)"),
            (Unicode(50), "VARCHAR2(50 CHAR)"),
            (UnicodeText(), "CLOB"),
        )
        for type_, ddl in expected_ddl:
            self.assert_compile(type_, ddl, dialect=uninitialized)
Esempio n. 10
0
 def test_nonunicode_default(self):
     """Constructing a ``Unicode`` column with a bytes default is
     expected to raise ``SAWarning`` with a message naming the column."""
     # Used as a pattern: "b?" lets the message match with or without
     # the bytes prefix in the repr of the default.
     expected_message = (
         "Unicode column 'foobar' has non-unicode "
         "default value b?'foo' specified."
     )
     assert_raises_message(
         sa.exc.SAWarning,
         expected_message,
         Column,
         "foobar",
         Unicode(32),
         default=b("foo"),
     )
Esempio n. 11
0
    def test_include_indexes_resembling_pk(self, explicit_pk):
        """Unique indexes covering the same columns as the primary key
        must still be returned by ``get_indexes()``.

        When ``explicit_pk`` is true, a named ``PrimaryKeyConstraint``
        over the same three columns is appended before the table is
        created.
        """
        metadata = self.metadata
        pk_names = ("id_a", "id_b", "group")

        # Oracle won't let you do this unless the indexes have
        # the columns in different order
        index_specs = [
            ("pk_idx_1", ["id_b", "id_a", "group"]),
            ("pk_idx_2", ["id_b", "group", "id_a"]),
        ]

        pk_columns = [
            Column(name, Unicode(255), primary_key=True) for name in pk_names
        ]
        other_columns = [Column("col", Unicode(255))]
        indexes = [
            Index(name, *cols, unique=True) for name, cols in index_specs
        ]
        sometable = Table(
            "sometable", metadata, *(pk_columns + other_columns + indexes)
        )

        if explicit_pk:
            named_pk = PrimaryKeyConstraint(*pk_names, name="some_primary_key")
            sometable.append_constraint(named_pk)
        metadata.create_all()

        expected = [
            {
                "name": name,
                "column_names": cols,
                "dialect_options": {},
                "unique": True,
            }
            for name, cols in index_specs
        ]
        eq_(inspect(testing.db).get_indexes("sometable"), expected)
Esempio n. 12
0
    def test_quoted_column_non_unicode(self):
        """Insert a non-ASCII value into a column whose name starts with
        an underscore, select it back by equality and compare."""
        payload = u("’é")
        metadata = self.metadata

        atable = Table(
            "atable",
            metadata,
            Column("_underscorecolumn", Unicode(255), primary_key=True),
        )
        metadata.create_all()

        atable.insert().execute({"_underscorecolumn": payload})

        lookup = atable.select().where(atable.c._underscorecolumn == payload)
        fetched = testing.db.execute(lookup).scalar()
        eq_(fetched, payload)
Esempio n. 13
0
    def test_quoted_column_unicode(self):
        """Insert a non-ASCII value into a column whose name is itself
        non-ASCII, select it back by equality and compare."""
        column_name = u("méil")
        payload = u("’é")
        metadata = self.metadata

        atable = Table(
            "atable",
            metadata,
            Column(column_name, Unicode(255), primary_key=True),
        )
        metadata.create_all()

        atable.insert().execute({column_name: payload})

        # Item access on .c, because the name is looked up by string.
        lookup = atable.select().where(atable.c[column_name] == payload)
        fetched = testing.db.execute(lookup).scalar()
        eq_(fetched, payload)
Esempio n. 14
0
 def setup_class(cls):
     """Publish ``metadata``, ``t`` and ``t2`` as module globals: two
     tables of ``NUM_FIELDS`` columns each, ``String(50)`` columns in
     ``table1`` and ``Unicode(50)`` columns in ``table2``."""
     global t, t2, metadata

     def field_columns(type_cls):
         # field0 .. field<NUM_FIELDS - 1>, all of type_cls(50)
         return [
             Column("field%d" % fnum, type_cls(50))
             for fnum in range(NUM_FIELDS)
         ]

     metadata = MetaData(testing.db)
     t = Table("table1", metadata, *field_columns(String))
     t2 = Table("table2", metadata, *field_columns(Unicode))
Esempio n. 15
0
    def test_ora8_flags(self):
        """Check the flags of a dialect built for version (8, 2, 5)
        before and after ``initialize()``, then check that
        ``implicit_returning=True`` is kept after initialization."""
        ora8 = self._dialect((8, 2, 5))

        # before connect, assume modern DB
        assert ora8._supports_char_length
        assert ora8.use_ansi
        assert not ora8._use_nchar_for_unicode

        ora8.initialize(Mock())

        # After initialize() the flags are expected to be switched off.
        assert not ora8.implicit_returning
        assert not ora8._supports_char_length
        assert not ora8.use_ansi

        expected_ddl = (
            (String(50), "VARCHAR2(50)"),
            (Unicode(50), "VARCHAR2(50)"),
            (UnicodeText(), "CLOB"),
        )
        for type_, ddl in expected_ddl:
            self.assert_compile(type_, ddl, dialect=ora8)

        # This second dialect is initialized from a real connection.
        returning_dialect = self._dialect((8, 2, 5), implicit_returning=True)
        returning_dialect.initialize(testing.db.connect())
        assert returning_dialect.implicit_returning
Esempio n. 16
0
 def test_no_default(self):
     """Construct a ``Unicode(32)`` column without a ``default``; the
     construction itself is the whole test."""
     unicode_type = Unicode(32)
     Column(unicode_type)
Esempio n. 17
0
 def test_unicode_literal_binds(self):
     """A comparison against a ``Unicode`` column, compiled with
     ``literal_binds=True``, is expected to render an ``N'...'``
     literal."""
     comparison = column("x", Unicode()) == "foo"
     self.assert_compile(comparison, "x = N'foo'", literal_binds=True)
Esempio n. 18
0
    def test_basic(self):
        """Create a table with several indexes and constraints, rebuild
        the schema from a reflected copy, reflect again, and check which
        indexes and constraints come back."""
        metadata = self.metadata

        sometable = Table(
            "sometable",
            metadata,
            Column("id_a", Unicode(255), primary_key=True),
            Column("id_b", Unicode(255), primary_key=True, unique=True),
            Column("group", Unicode(255), primary_key=True),
            Column("col", Unicode(255)),
            UniqueConstraint("col", "group"),
        )
        cols = sometable.c

        # "group" is a keyword, so lower case
        plain_index = Index("tableind", cols.id_b, cols.group)
        Index("compress1", cols.id_a, cols.id_b, oracle_compress=True)
        Index("compress2", cols.id_a, cols.id_b, cols.col, oracle_compress=1)

        metadata.create_all()

        # Reflect what was just created, drop the original schema, and
        # re-create it from the reflected copy.
        mirror = MetaData(testing.db)
        mirror.reflect()

        metadata.drop_all()
        mirror.create_all()

        # Second reflection pass, this time of the re-created schema.
        rereflected = MetaData(testing.db)
        rereflected.reflect()

        def signature(obj):
            # (class, column names, unique flag or None) identifies an
            # index or constraint regardless of its name
            column_names = tuple(col.name for col in obj.columns)
            return (obj.__class__, column_names, getattr(obj, "unique", None))

        # find what the primary key constraint name should be
        pk_name_query = text("""SELECT constraint_name
               FROM all_constraints
               WHERE table_name = :table_name
               AND owner = :owner
               AND constraint_type = 'P' """)
        expected_pk_name = testing.db.scalar(
            pk_name_query,
            table_name=sometable.name.upper(),
            owner=testing.db.dialect.default_schema_name.upper(),
        )

        reflected_table = rereflected.tables[sometable.name]

        # index the reflected objects by their signature
        reflected = {
            signature(item): item
            for item in reflected_table.indexes | reflected_table.constraints
        }

        # The lookups below raise KeyError if the object was not
        # reflected, which fails the test.
        pk_key = (PrimaryKeyConstraint, ("id_a", "id_b", "group"), None)
        assert reflected[pk_key].name.upper() == expected_pk_name.upper()

        eq_(
            reflected[(Index, ("id_b", "group"), False)].name,
            plain_index.name,
        )
        assert (Index, ("id_b",), True) in reflected
        assert (Index, ("col", "group"), True) in reflected

        # compress1 was declared with oracle_compress=True and is
        # expected to come back as 2; compress2 was declared with 1.
        expected_compress = (
            (("id_a", "id_b"), 2),
            (("id_a", "id_b", "col"), 1),
        )
        for column_names, compress in expected_compress:
            idx = reflected[(Index, column_names, False)]
            assert idx.dialect_options["oracle"]["compress"] == compress

        eq_(len(reflected_table.constraints), 1)
        eq_(len(reflected_table.indexes), 5)
Esempio n. 19
0
 def test_unicode_default(self):
     """Construct a ``Unicode(32)`` column whose default is a text
     value; the construction itself is the whole test."""
     Column(Unicode(32), default=u("foo"))
Esempio n. 20
0
class DialectTypesTest(fixtures.TestBase, AssertsCompiledSQL):
    """Type adaptation and type DDL checks against the Oracle dialect.

    ``__dialect__`` is an ``OracleDialect`` instance; individual tests
    also build their own ``oracle`` / ``cx_oracle`` dialects where they
    need specific options.
    """

    __dialect__ = oracle.OracleDialect()

    def test_no_clobs_for_string_params(self):
        """test that simple string params get a DBAPI type of
        VARCHAR, not CLOB. This is to prevent setinputsizes
        from setting up cx_oracle.CLOBs on
        string-based bind params [ticket:793]."""

        # Stand-in DBAPI module: any attribute lookup returns the
        # attribute's own name, so the DBAPI type can be compared as a
        # plain string below.
        class FakeDBAPI(object):
            def __getattr__(self, attr):
                return attr

        dialect = oracle.OracleDialect()
        dbapi = FakeDBAPI()

        b = bindparam("foo", "hello world!")
        eq_(b.type.dialect_impl(dialect).get_dbapi_type(dbapi), "STRING")

        # NOTE(review): this check is identical to the one above -- same
        # bindparam arguments, same expected value. Confirm whether a
        # different input was intended here.
        b = bindparam("foo", "hello world!")
        eq_(b.type.dialect_impl(dialect).get_dbapi_type(dbapi), "STRING")

    def test_long(self):
        """``oracle.LONG()`` is expected to compile to ``LONG``."""
        self.assert_compile(oracle.LONG(), "LONG")

    # Each combination is (type instance, class that its dialect_impl()
    # must be an instance of).
    @testing.combinations(
        (Date(), cx_oracle._OracleDate),
        (oracle.OracleRaw(), cx_oracle._OracleRaw),
        (String(), String),
        (VARCHAR(), cx_oracle._OracleString),
        (DATE(), cx_oracle._OracleDate),
        (oracle.DATE(), oracle.DATE),
        (String(50), cx_oracle._OracleString),
        (Unicode(), cx_oracle._OracleUnicodeStringCHAR),
        (Text(), cx_oracle._OracleText),
        (UnicodeText(), cx_oracle._OracleUnicodeTextCLOB),
        (CHAR(), cx_oracle._OracleChar),
        (NCHAR(), cx_oracle._OracleNChar),
        (NVARCHAR(), cx_oracle._OracleUnicodeStringNCHAR),
        (oracle.RAW(50), cx_oracle._OracleRaw),
    )
    def test_type_adapt(self, start, test):
        """Assert that ``start.dialect_impl()`` for a ``cx_oracle``
        dialect built with no arguments is an instance of ``test``."""
        dialect = cx_oracle.dialect()

        assert isinstance(
            start.dialect_impl(dialect), test
        ), "wanted %r got %r" % (test, start.dialect_impl(dialect))

    # Same pair layout as test_type_adapt; here Unicode / UnicodeText
    # are paired with the NCHAR / NCLOB implementation classes.
    @testing.combinations(
        (String(), String),
        (VARCHAR(), cx_oracle._OracleString),
        (String(50), cx_oracle._OracleString),
        (Unicode(), cx_oracle._OracleUnicodeStringNCHAR),
        (Text(), cx_oracle._OracleText),
        (UnicodeText(), cx_oracle._OracleUnicodeTextNCLOB),
        (NCHAR(), cx_oracle._OracleNChar),
        (NVARCHAR(), cx_oracle._OracleUnicodeStringNCHAR),
    )
    def test_type_adapt_nchar(self, start, test):
        """Assert that ``start.dialect_impl()`` for a ``cx_oracle``
        dialect built with ``use_nchar_for_unicode=True`` is an instance
        of ``test``."""
        dialect = cx_oracle.dialect(use_nchar_for_unicode=True)

        assert isinstance(
            start.dialect_impl(dialect), test
        ), "wanted %r got %r" % (test, start.dialect_impl(dialect))

    def test_raw_compile(self):
        """``oracle.RAW`` is expected to compile with and without a
        length."""
        self.assert_compile(oracle.RAW(), "RAW")
        self.assert_compile(oracle.RAW(35), "RAW(35)")

    def test_char_length(self):
        """Check length rendering for VARCHAR, NVARCHAR and CHAR, both
        against the class dialect and against a dialect whose
        ``server_version_info`` is set to ``(8, 0)``."""
        self.assert_compile(VARCHAR(50), "VARCHAR(50 CHAR)")

        # With the version forced to 8.0 the expected DDL has no CHAR
        # qualifier.
        oracle8dialect = oracle.dialect()
        oracle8dialect.server_version_info = (8, 0)
        self.assert_compile(VARCHAR(50), "VARCHAR(50)", dialect=oracle8dialect)

        self.assert_compile(NVARCHAR(50), "NVARCHAR2(50)")
        self.assert_compile(CHAR(50), "CHAR(50)")

    # Each combination is (type instance, expected DDL string).
    @testing.combinations(
        (String(50), "VARCHAR2(50 CHAR)"),
        (Unicode(50), "VARCHAR2(50 CHAR)"),
        (NVARCHAR(50), "NVARCHAR2(50)"),
        (VARCHAR(50), "VARCHAR(50 CHAR)"),
        (oracle.NVARCHAR2(50), "NVARCHAR2(50)"),
        (oracle.VARCHAR2(50), "VARCHAR2(50 CHAR)"),
        (String(), "VARCHAR2"),
        (Unicode(), "VARCHAR2"),
        (NVARCHAR(), "NVARCHAR2"),
        (VARCHAR(), "VARCHAR"),
        (oracle.NVARCHAR2(), "NVARCHAR2"),
        (oracle.VARCHAR2(), "VARCHAR2"),
    )
    def test_varchar_types(self, typ, exp):
        """Compile ``typ`` against an ``oracle`` dialect built with no
        arguments and compare the result with ``exp``."""
        dialect = oracle.dialect()
        self.assert_compile(typ, exp, dialect=dialect)

    # Same pair layout as test_varchar_types; the Unicode rows expect
    # NVARCHAR2 here.
    @testing.combinations(
        (String(50), "VARCHAR2(50 CHAR)"),
        (Unicode(50), "NVARCHAR2(50)"),
        (NVARCHAR(50), "NVARCHAR2(50)"),
        (VARCHAR(50), "VARCHAR(50 CHAR)"),
        (oracle.NVARCHAR2(50), "NVARCHAR2(50)"),
        (oracle.VARCHAR2(50), "VARCHAR2(50 CHAR)"),
        (String(), "VARCHAR2"),
        (Unicode(), "NVARCHAR2"),
        (NVARCHAR(), "NVARCHAR2"),
        (VARCHAR(), "VARCHAR"),
        (oracle.NVARCHAR2(), "NVARCHAR2"),
        (oracle.VARCHAR2(), "VARCHAR2"),
    )
    def test_varchar_use_nchar_types(self, typ, exp):
        """Compile ``typ`` against an ``oracle`` dialect built with
        ``use_nchar_for_unicode=True`` and compare the result with
        ``exp``."""
        dialect = oracle.dialect(use_nchar_for_unicode=True)
        self.assert_compile(typ, exp, dialect=dialect)

    # Each combination is (INTERVAL instance, expected DDL string).
    @testing.combinations(
        (oracle.INTERVAL(), "INTERVAL DAY TO SECOND"),
        (oracle.INTERVAL(day_precision=3), "INTERVAL DAY(3) TO SECOND"),
        (oracle.INTERVAL(second_precision=5), "INTERVAL DAY TO SECOND(5)"),
        (
            oracle.INTERVAL(day_precision=2, second_precision=5),
            "INTERVAL DAY(2) TO SECOND(5)",
        ),
    )
    def test_interval(self, type_, expected):
        """Compile ``type_`` against the class dialect and compare the
        result with ``expected``."""
        self.assert_compile(type_, expected)
Esempio n. 21
0
class SetInputSizesTest(fixtures.TestBase):
    """Checks which arguments reach the DBAPI cursor's
    ``setinputsizes()`` when a single value is inserted into a column
    of a given datatype."""

    __only_on__ = "oracle+cx_oracle"
    __backend__ = True

    # Each combination is
    # (datatype, value, sis_value_text, set_nchar_flag);
    # see test_setinputsizes for how each element is used.
    @testing.combinations(
        (SmallInteger, 25, int, False),
        (Integer, 25, int, False),
        (Numeric(10, 8), decimal.Decimal("25.34534"), None, False),
        (Float(15), 25.34534, None, False),
        (oracle.BINARY_DOUBLE, 25.34534, "NATIVE_FLOAT", False),
        (oracle.BINARY_FLOAT, 25.34534, "NATIVE_FLOAT", False),
        (oracle.DOUBLE_PRECISION, 25.34534, None, False),
        (Unicode(30), u("test"), "NCHAR", True),
        (UnicodeText(), u("test"), "NCLOB", True),
        (Unicode(30), u("test"), None, False),
        (UnicodeText(), u("test"), "CLOB", False),
        (String(30), "test", None, False),
        (CHAR(30), "test", "FIXED_CHAR", False),
        (NCHAR(30), u("test"), "FIXED_NCHAR", False),
        (oracle.LONG(), "test", None, False),
    )
    @testing.provide_metadata
    def test_setinputsizes(
        self, datatype, value, sis_value_text, set_nchar_flag
    ):
        """Insert ``value`` into three tables that use ``datatype`` in
        different ways and check the recorded ``setinputsizes()`` call.

        :param datatype: column type under test.
        :param value: value inserted into the ``foo`` column.
        :param sis_value_text: expected ``setinputsizes`` value. A
         string is resolved as an attribute of the dialect's DBAPI
         module; anything else is used as given. A falsy value means
         ``setinputsizes()`` is expected to be called with no arguments.
        :param set_nchar_flag: when true, a separate engine created with
         ``use_nchar_for_unicode=True`` is used instead of
         ``testing.db``.
        """
        if isinstance(sis_value_text, str):
            sis_value = getattr(testing.db.dialect.dbapi, sis_value_text)
        else:
            sis_value = sis_value_text

        # Type decorator that resolves to ``datatype`` only when the
        # dialect name is "oracle"; otherwise it keeps its NullType impl.
        class TestTypeDec(TypeDecorator):
            impl = NullType()

            def load_dialect_impl(self, dialect):
                if dialect.name == "oracle":
                    return dialect.type_descriptor(datatype)
                else:
                    return self.impl

        m = self.metadata
        # Oracle can have only one column of type LONG so we make three
        # tables rather than one table w/ three columns
        # t1: the datatype directly; t2: as an "oracle" variant of
        # NullType; t3: through the TypeDecorator above.
        t1 = Table("t1", m, Column("foo", datatype))
        t2 = Table(
            "t2", m, Column("foo", NullType().with_variant(datatype, "oracle"))
        )
        t3 = Table("t3", m, Column("foo", TestTypeDec()))
        m.create_all()

        class CursorWrapper(object):
            # cx_oracle cursor can't be modified so we have to
            # invent a whole wrapping scheme

            def __init__(self, connection_fairy):
                self.cursor = connection_fairy.connection.cursor()
                # A fresh Mock per wrapper, stored in the connection's
                # info dict so the test can read the calls afterwards.
                self.mock = mock.Mock()
                connection_fairy.info["mock"] = self.mock

            def setinputsizes(self, *arg, **kw):
                # record the call, then forward it to the real cursor
                self.mock.setinputsizes(*arg, **kw)
                self.cursor.setinputsizes(*arg, **kw)

            def __getattr__(self, key):
                # everything else is delegated to the real cursor
                return getattr(self.cursor, key)

        if set_nchar_flag:
            engine = testing_engine(options={"use_nchar_for_unicode": True})
        else:
            engine = testing.db

        with engine.connect() as conn:
            connection_fairy = conn.connection
            for tab in [t1, t2, t3]:
                # The cursor factory is replaced only for the duration
                # of the INSERT.
                with mock.patch.object(
                    connection_fairy,
                    "cursor",
                    lambda: CursorWrapper(connection_fairy),
                ):
                    conn.execute(tab.insert(), {"foo": value})

                if sis_value:
                    eq_(
                        conn.info["mock"].mock_calls,
                        [mock.call.setinputsizes(foo=sis_value)],
                    )
                else:
                    eq_(
                        conn.info["mock"].mock_calls,
                        [mock.call.setinputsizes()],
                    )

    def test_event_no_native_float(self):
        """Run ``test_setinputsizes`` for ``BINARY_FLOAT`` while a
        ``do_setinputsizes`` listener removes every ``NATIVE_FLOAT``
        entry; the expected ``setinputsizes`` value is therefore
        ``None``."""

        def _remove_type(inputsizes, cursor, statement, parameters, context):
            # iterate over a copy, since entries are deleted in the loop
            for param, dbapitype in list(inputsizes.items()):
                if dbapitype is testing.db.dialect.dbapi.NATIVE_FLOAT:
                    del inputsizes[param]

        event.listen(testing.db, "do_setinputsizes", _remove_type)
        try:
            self.test_setinputsizes(oracle.BINARY_FLOAT, 25.34534, None, False)
        finally:
            # always detach the listener from the shared testing.db engine
            event.remove(testing.db, "do_setinputsizes", _remove_type)
Esempio n. 22
0
    def test_conflicting_backref_one(self):
        """test that conflicting backrefs raises an exception"""

        metadata = MetaData(testing.db)

        orders_table = Table(
            "orders",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("type", Unicode(16)),
        )
        products_table = Table(
            "products", metadata, Column("id", Integer, primary_key=True)
        )
        orderproducts_table = Table(
            "orderproducts",
            metadata,
            Column("id", Integer, primary_key=True),
            Column(
                "order_id", Integer, ForeignKey("orders.id"), nullable=False
            ),
            Column(
                "product_id",
                Integer,
                ForeignKey("products.id"),
                nullable=False,
            ),
        )

        class Order(object):
            pass

        class Product(object):
            pass

        class OrderProduct(object):
            pass

        polymorphic_selectable = orders_table.select().alias("pjoin")

        # Order and Product each declare an "orderproducts" relationship
        # to OrderProduct with the same backref name, "product".
        order_properties = {
            "orderproducts": relationship(
                OrderProduct, lazy="select", backref="product"
            )
        }
        mapper(
            Order,
            orders_table,
            with_polymorphic=("*", polymorphic_selectable),
            polymorphic_on=polymorphic_selectable.c.type,
            polymorphic_identity="order",
            properties=order_properties,
        )

        product_properties = {
            "orderproducts": relationship(
                OrderProduct, lazy="select", backref="product"
            )
        }
        mapper(Product, products_table, properties=product_properties)

        mapper(OrderProduct, orderproducts_table)

        # configure_mappers is the callable expected to raise.
        assert_raises_message(
            sa_exc.ArgumentError, "Error creating backref", configure_mappers
        )
Esempio n. 23
0
class SQLTest(fixtures.TestBase, AssertsCompiledSQL):
    """Tests MySQL-dialect specific compilation.

    Every ``assert_compile`` call that does not pass ``dialect=`` explicitly
    compiles against the ``__dialect__`` instance declared below.
    """

    __dialect__ = mysql.dialect()

    # Short alias so the long ``testing.combinations`` tables further down
    # stay readable; bound once and shared by all of them.
    m = mysql

    def test_precolumns(self):
        """DISTINCT and SELECT prefixes render in the expected order."""
        dialect = self.__dialect__

        def gen(distinct=None, prefixes=None):
            # Only forward the keywords that were actually supplied so the
            # defaults of select() are exercised when they are omitted.
            kw = {}
            if distinct is not None:
                kw["distinct"] = distinct
            if prefixes is not None:
                kw["prefixes"] = prefixes
            return str(select([column("q")], **kw).compile(dialect=dialect))

        eq_(gen(None), "SELECT q")
        eq_(gen(True), "SELECT DISTINCT q")

        eq_(gen(prefixes=["ALL"]), "SELECT ALL q")
        eq_(gen(prefixes=["DISTINCTROW"]), "SELECT DISTINCTROW q")

        # Interaction with MySQL prefix extensions
        eq_(gen(None, ["straight_join"]), "SELECT straight_join q")
        eq_(
            gen(False, ["HIGH_PRIORITY", "SQL_SMALL_RESULT", "ALL"]),
            "SELECT HIGH_PRIORITY SQL_SMALL_RESULT ALL q",
        )
        eq_(
            gen(True, ["high_priority", sql.text("sql_cache")]),
            "SELECT high_priority sql_cache DISTINCT q",
        )

    def test_backslash_escaping(self):
        """The LIKE ... ESCAPE backslash is doubled unless disabled.

        With the default dialect the escape character renders as two
        backslashes; with ``_backslash_escapes`` set to False it renders as
        a single backslash.
        """
        self.assert_compile(
            sql.column("foo").like("bar", escape="\\"),
            "foo LIKE %s ESCAPE '\\\\'",
        )

        dialect = mysql.dialect()
        dialect._backslash_escapes = False
        self.assert_compile(
            sql.column("foo").like("bar", escape="\\"),
            "foo LIKE %s ESCAPE '\\'",
            dialect=dialect,
        )

    def test_limit(self):
        """LIMIT/OFFSET render in MySQL's ``LIMIT offset, count`` form."""
        t = sql.table("t", sql.column("col1"), sql.column("col2"))

        self.assert_compile(
            select([t]).limit(10).offset(20),
            "SELECT t.col1, t.col2 FROM t  LIMIT %s, %s",
            {
                "param_1": 20,
                "param_2": 10
            },
        )
        self.assert_compile(
            select([t]).limit(10),
            "SELECT t.col1, t.col2 FROM t  LIMIT %s",
            {"param_1": 10},
        )

        # OFFSET without LIMIT: the count is rendered as the literal
        # 18446744073709551615, i.e. 2**64 - 1.
        self.assert_compile(
            select([t]).offset(10),
            "SELECT t.col1, t.col2 FROM t  LIMIT %s, 18446744073709551615",
            {"param_1": 10},
        )

    @testing.combinations(
        (String, ),
        (VARCHAR, ),
        (String(), ),
        (VARCHAR(), ),
        (NVARCHAR(), ),
        (Unicode, ),
        (Unicode(), ),
    )
    def test_varchar_raise(self, type_):
        """Length-less string types raise CompileError on MySQL.

        The error is checked both when compiling the type directly and when
        compiling a CREATE TABLE that uses it; in the latter case the
        message is also expected to name the table and column.
        """
        type_ = sqltypes.to_instance(type_)
        assert_raises_message(
            exc.CompileError,
            "VARCHAR requires a length on dialect mysql",
            type_.compile,
            dialect=mysql.dialect(),
        )

        t1 = Table("sometable", MetaData(), Column("somecolumn", type_))
        assert_raises_message(
            exc.CompileError,
            r"\(in table 'sometable', column 'somecolumn'\)\: "
            r"(?:N)?VARCHAR requires a length on dialect mysql",
            schema.CreateTable(t1).compile,
            dialect=mysql.dialect(),
        )

    def test_update_limit(self):
        """``mysql_limit`` appends LIMIT to UPDATE; ``None`` omits it."""
        t = sql.table("t", sql.column("col1"), sql.column("col2"))

        self.assert_compile(t.update(values={"col1": 123}),
                            "UPDATE t SET col1=%s")
        self.assert_compile(
            t.update(values={"col1": 123}, mysql_limit=5),
            "UPDATE t SET col1=%s LIMIT 5",
        )
        self.assert_compile(
            t.update(values={"col1": 123}, mysql_limit=None),
            "UPDATE t SET col1=%s",
        )
        self.assert_compile(
            t.update(t.c.col2 == 456, values={"col1": 123}, mysql_limit=1),
            "UPDATE t SET col1=%s WHERE t.col2 = %s LIMIT 1",
        )

    def test_utc_timestamp(self):
        """``func.utc_timestamp()`` with no arguments renders as called."""
        self.assert_compile(func.utc_timestamp(), "utc_timestamp()")

    def test_utc_timestamp_fsp(self):
        """An argument to ``utc_timestamp`` becomes a bound parameter."""
        self.assert_compile(
            func.utc_timestamp(5),
            "utc_timestamp(%s)",
            checkparams={"utc_timestamp_1": 5},
        )

    def test_sysdate(self):
        """``func.sysdate()`` renders upper-cased with parentheses."""
        self.assert_compile(func.sysdate(), "SYSDATE()")

    @testing.combinations(
        (Integer, "CAST(t.col AS SIGNED INTEGER)"),
        (INT, "CAST(t.col AS SIGNED INTEGER)"),
        (m.MSInteger, "CAST(t.col AS SIGNED INTEGER)"),
        (m.MSInteger(unsigned=True), "CAST(t.col AS UNSIGNED INTEGER)"),
        (SmallInteger, "CAST(t.col AS SIGNED INTEGER)"),
        (m.MSSmallInteger, "CAST(t.col AS SIGNED INTEGER)"),
        (m.MSTinyInteger, "CAST(t.col AS SIGNED INTEGER)"),
        # 'SIGNED INTEGER' is a bigint, so this is ok.
        (m.MSBigInteger, "CAST(t.col AS SIGNED INTEGER)"),
        (m.MSBigInteger(unsigned=False), "CAST(t.col AS SIGNED INTEGER)"),
        (m.MSBigInteger(unsigned=True), "CAST(t.col AS UNSIGNED INTEGER)"),
        # this is kind of sucky.  thank you default arguments!
        (NUMERIC, "CAST(t.col AS DECIMAL)"),
        (DECIMAL, "CAST(t.col AS DECIMAL)"),
        (Numeric, "CAST(t.col AS DECIMAL)"),
        (m.MSNumeric, "CAST(t.col AS DECIMAL)"),
        (m.MSDecimal, "CAST(t.col AS DECIMAL)"),
        (TIMESTAMP, "CAST(t.col AS DATETIME)"),
        (DATETIME, "CAST(t.col AS DATETIME)"),
        (DATE, "CAST(t.col AS DATE)"),
        (TIME, "CAST(t.col AS TIME)"),
        (DateTime, "CAST(t.col AS DATETIME)"),
        (Date, "CAST(t.col AS DATE)"),
        (Time, "CAST(t.col AS TIME)"),
        (m.MSTime, "CAST(t.col AS TIME)"),
        (m.MSTimeStamp, "CAST(t.col AS DATETIME)"),
        (String, "CAST(t.col AS CHAR)"),
        (Unicode, "CAST(t.col AS CHAR)"),
        (UnicodeText, "CAST(t.col AS CHAR)"),
        (VARCHAR, "CAST(t.col AS CHAR)"),
        (NCHAR, "CAST(t.col AS CHAR)"),
        (CHAR, "CAST(t.col AS CHAR)"),
        (m.CHAR(charset="utf8"), "CAST(t.col AS CHAR CHARACTER SET utf8)"),
        (CLOB, "CAST(t.col AS CHAR)"),
        (TEXT, "CAST(t.col AS CHAR)"),
        (m.TEXT(charset="utf8"), "CAST(t.col AS CHAR CHARACTER SET utf8)"),
        (String(32), "CAST(t.col AS CHAR(32))"),
        (Unicode(32), "CAST(t.col AS CHAR(32))"),
        (CHAR(32), "CAST(t.col AS CHAR(32))"),
        (m.MSString, "CAST(t.col AS CHAR)"),
        (m.MSText, "CAST(t.col AS CHAR)"),
        (m.MSTinyText, "CAST(t.col AS CHAR)"),
        (m.MSMediumText, "CAST(t.col AS CHAR)"),
        (m.MSLongText, "CAST(t.col AS CHAR)"),
        (m.MSNChar, "CAST(t.col AS CHAR)"),
        (m.MSNVarChar, "CAST(t.col AS CHAR)"),
        (LargeBinary, "CAST(t.col AS BINARY)"),
        (BLOB, "CAST(t.col AS BINARY)"),
        (m.MSBlob, "CAST(t.col AS BINARY)"),
        (m.MSBlob(32), "CAST(t.col AS BINARY)"),
        (m.MSTinyBlob, "CAST(t.col AS BINARY)"),
        (m.MSMediumBlob, "CAST(t.col AS BINARY)"),
        (m.MSLongBlob, "CAST(t.col AS BINARY)"),
        (m.MSBinary, "CAST(t.col AS BINARY)"),
        (m.MSBinary(32), "CAST(t.col AS BINARY)"),
        (m.MSVarBinary, "CAST(t.col AS BINARY)"),
        (m.MSVarBinary(32), "CAST(t.col AS BINARY)"),
        (Interval, "CAST(t.col AS DATETIME)"),
    )
    def test_cast(self, type_, expected):
        """``cast()`` to ``type_`` renders the ``expected`` CAST target."""
        t = sql.table("t", sql.column("col"))
        self.assert_compile(cast(t.c.col, type_), expected)

    def test_cast_type_decorator(self):
        """A TypeDecorator casts the same way as its ``impl`` type."""
        class MyInteger(sqltypes.TypeDecorator):
            impl = Integer

        type_ = MyInteger()
        t = sql.table("t", sql.column("col"))
        self.assert_compile(cast(t.c.col, type_),
                            "CAST(t.col AS SIGNED INTEGER)")

    def test_cast_literal_bind(self):
        """CAST of an expression renders inline values with literal_binds."""
        expr = cast(column("foo", Integer) + 5, Integer())

        self.assert_compile(expr,
                            "CAST(foo + 5 AS SIGNED INTEGER)",
                            literal_binds=True)

    def test_unsupported_cast_literal_bind(self):
        """A skipped CAST still renders its grouped inner expression.

        Covers both reasons the CAST is skipped with a warning: an
        unsupported target type (Float) and a server version set to
        (3, 9, 8).
        """
        expr = cast(column("foo", Integer) + 5, Float)

        with expect_warnings("Datatype FLOAT does not support CAST on MySQL;"):
            self.assert_compile(expr, "(foo + 5)", literal_binds=True)

        dialect = mysql.MySQLDialect()
        dialect.server_version_info = (3, 9, 8)
        with expect_warnings("Current MySQL version does not support CAST"):
            eq_(
                str(
                    expr.compile(dialect=dialect,
                                 compile_kwargs={"literal_binds": True})),
                "(foo + 5)",
            )

    @testing.combinations(
        (m.MSBit, "t.col"),
        (FLOAT, "t.col"),
        (Float, "t.col"),
        (m.MSFloat, "t.col"),
        (m.MSDouble, "t.col"),
        (m.MSReal, "t.col"),
        (m.MSYear, "t.col"),
        (m.MSYear(2), "t.col"),
        (Boolean, "t.col"),
        (BOOLEAN, "t.col"),
        (m.MSEnum, "t.col"),
        (m.MSEnum("1", "2"), "t.col"),
        (m.MSSet, "t.col"),
        (m.MSSet("1", "2"), "t.col"),
    )
    def test_unsupported_casts(self, type_, expected):
        """Types without CAST support warn and render the bare column."""

        t = sql.table("t", sql.column("col"))
        with expect_warnings("Datatype .* does not support CAST on MySQL;"):
            self.assert_compile(cast(t.c.col, type_), expected)

    def test_no_cast_pre_4(self):
        """With server version (3, 2, 3) CAST is dropped with a warning."""
        self.assert_compile(cast(Column("foo", Integer), String),
                            "CAST(foo AS CHAR)")
        dialect = mysql.dialect()
        dialect.server_version_info = (3, 2, 3)
        with expect_warnings("Current MySQL version does not support CAST;"):
            self.assert_compile(cast(Column("foo", Integer), String),
                                "foo",
                                dialect=dialect)

    def test_cast_grouped_expression_non_castable(self):
        """A dropped CAST (unsupported type) keeps the operand grouped."""
        with expect_warnings("Datatype FLOAT does not support CAST on MySQL;"):
            self.assert_compile(cast(sql.column("x") + sql.column("y"), Float),
                                "(x + y)")

    def test_cast_grouped_expression_pre_4(self):
        """A dropped CAST (server version 3.2.3) keeps the operand grouped."""
        dialect = mysql.dialect()
        dialect.server_version_info = (3, 2, 3)
        with expect_warnings("Current MySQL version does not support CAST;"):
            self.assert_compile(
                cast(sql.column("x") + sql.column("y"), Integer),
                "(x + y)",
                dialect=dialect,
            )

    def test_extract(self):
        """EXTRACT renders field names, mapping milliseconds to singular."""
        t = sql.table("t", sql.column("col1"))

        for field in "year", "month", "day":
            self.assert_compile(
                select([extract(field, t.c.col1)]),
                "SELECT EXTRACT(%s FROM t.col1) AS anon_1 FROM t" % field,
            )

        # "milliseconds" (plural) is rendered as "millisecond" (singular)
        self.assert_compile(
            select([extract("milliseconds", t.c.col1)]),
            "SELECT EXTRACT(millisecond FROM t.col1) AS anon_1 FROM t",
        )

    def test_too_long_index(self):
        """An ``index=True`` column on long names renders the ``exp`` name.

        The expected index name is a shortened form that ends in ``_5cd2``
        rather than the full table-plus-column name.
        """
        exp = "ix_zyrenian_zyme_zyzzogeton_zyzzogeton_zyrenian_zyme_zyz_5cd2"
        tname = "zyrenian_zyme_zyzzogeton_zyzzogeton"
        cname = "zyrenian_zyme_zyzzogeton_zo"

        t1 = Table(tname, MetaData(), Column(cname, Integer, index=True))
        ix1 = list(t1.indexes)[0]

        self.assert_compile(
            schema.CreateIndex(ix1),
            "CREATE INDEX %s "
            "ON %s (%s)" % (exp, tname, cname),
        )

    def test_innodb_autoincrement(self):
        """The AUTO_INCREMENT column is listed first in PRIMARY KEY.

        Checked for a composite key with the autoincrement column declared
        second, then with it declared first.
        """
        t1 = Table(
            "sometable",
            MetaData(),
            Column("assigned_id",
                   Integer(),
                   primary_key=True,
                   autoincrement=False),
            Column("id", Integer(), primary_key=True, autoincrement=True),
            mysql_engine="InnoDB",
        )
        self.assert_compile(
            schema.CreateTable(t1),
            "CREATE TABLE sometable (assigned_id "
            "INTEGER NOT NULL, id INTEGER NOT NULL "
            "AUTO_INCREMENT, PRIMARY KEY (id, assigned_id)"
            ")ENGINE=InnoDB",
        )

        t1 = Table(
            "sometable",
            MetaData(),
            Column("assigned_id",
                   Integer(),
                   primary_key=True,
                   autoincrement=True),
            Column("id", Integer(), primary_key=True, autoincrement=False),
            mysql_engine="InnoDB",
        )
        self.assert_compile(
            schema.CreateTable(t1),
            "CREATE TABLE sometable (assigned_id "
            "INTEGER NOT NULL AUTO_INCREMENT, id "
            "INTEGER NOT NULL, PRIMARY KEY "
            "(assigned_id, id))ENGINE=InnoDB",
        )

    def test_innodb_autoincrement_reserved_word_column_name(self):
        """A reserved-word autoincrement column is quoted everywhere."""
        t1 = Table(
            "sometable",
            MetaData(),
            Column("id", Integer(), primary_key=True, autoincrement=False),
            Column("order", Integer(), primary_key=True, autoincrement=True),
            mysql_engine="InnoDB",
        )
        self.assert_compile(
            schema.CreateTable(t1),
            "CREATE TABLE sometable ("
            "id INTEGER NOT NULL, "
            "`order` INTEGER NOT NULL AUTO_INCREMENT, "
            "PRIMARY KEY (`order`, id)"
            ")ENGINE=InnoDB",
        )

    def test_create_table_with_partition(self):
        """``mysql_partition_by`` / ``mysql_partitions`` render after ')'."""
        t1 = Table(
            "testtable",
            MetaData(),
            Column("id", Integer(), primary_key=True, autoincrement=True),
            Column("other_id",
                   Integer(),
                   primary_key=True,
                   autoincrement=False),
            mysql_partitions="2",
            mysql_partition_by="KEY(other_id)",
        )
        self.assert_compile(
            schema.CreateTable(t1),
            "CREATE TABLE testtable ("
            "id INTEGER NOT NULL AUTO_INCREMENT, "
            "other_id INTEGER NOT NULL, "
            "PRIMARY KEY (id, other_id)"
            ")PARTITION BY KEY(other_id) PARTITIONS 2",
        )

    def test_create_table_with_subpartition(self):
        """Subpartition options render after the partition options."""
        t1 = Table(
            "testtable",
            MetaData(),
            Column("id", Integer(), primary_key=True, autoincrement=True),
            Column("other_id",
                   Integer(),
                   primary_key=True,
                   autoincrement=False),
            mysql_partitions="2",
            mysql_partition_by="KEY(other_id)",
            mysql_subpartition_by="HASH(some_expr)",
            mysql_subpartitions="2",
        )
        self.assert_compile(
            schema.CreateTable(t1),
            "CREATE TABLE testtable ("
            "id INTEGER NOT NULL AUTO_INCREMENT, "
            "other_id INTEGER NOT NULL, "
            "PRIMARY KEY (id, other_id)"
            ")PARTITION BY KEY(other_id) PARTITIONS 2 "
            "SUBPARTITION BY HASH(some_expr) SUBPARTITIONS 2",
        )

    def test_create_table_with_partition_hash(self):
        """A HASH partition expression is rendered verbatim."""
        t1 = Table(
            "testtable",
            MetaData(),
            Column("id", Integer(), primary_key=True, autoincrement=True),
            Column("other_id",
                   Integer(),
                   primary_key=True,
                   autoincrement=False),
            mysql_partitions="2",
            mysql_partition_by="HASH(other_id)",
        )
        self.assert_compile(
            schema.CreateTable(t1),
            "CREATE TABLE testtable ("
            "id INTEGER NOT NULL AUTO_INCREMENT, "
            "other_id INTEGER NOT NULL, "
            "PRIMARY KEY (id, other_id)"
            ")PARTITION BY HASH(other_id) PARTITIONS 2",
        )

    def test_create_table_with_partition_and_other_opts(self):
        """Other table options render before the PARTITION clauses."""
        t1 = Table(
            "testtable",
            MetaData(),
            Column("id", Integer(), primary_key=True, autoincrement=True),
            Column("other_id",
                   Integer(),
                   primary_key=True,
                   autoincrement=False),
            mysql_stats_sample_pages="2",
            mysql_partitions="2",
            mysql_partition_by="HASH(other_id)",
        )
        self.assert_compile(
            schema.CreateTable(t1),
            "CREATE TABLE testtable ("
            "id INTEGER NOT NULL AUTO_INCREMENT, "
            "other_id INTEGER NOT NULL, "
            "PRIMARY KEY (id, other_id)"
            ")STATS_SAMPLE_PAGES=2 PARTITION BY HASH(other_id) PARTITIONS 2",
        )

    def test_create_table_with_collate(self):
        """COLLATE renders after CHARSET in CREATE TABLE (issue #5411).

        Two renderings are accepted: ENGINE before CHARSET, or CHARSET
        before ENGINE; in both, COLLATE comes last.
        """
        # issue #5411
        t1 = Table(
            "testtable",
            MetaData(),
            Column("id", Integer(), primary_key=True, autoincrement=True),
            mysql_engine="InnoDB",
            mysql_collate="utf8_icelandic_ci",
            mysql_charset="utf8",
        )
        first_part = ("CREATE TABLE testtable ("
                      "id INTEGER NOT NULL AUTO_INCREMENT, "
                      "PRIMARY KEY (id))")
        # NOTE(review): presumably the relative order of ENGINE and CHARSET
        # is not deterministic, hence the fallback -- confirm against the
        # dialect's table-option ordering.
        try:
            self.assert_compile(
                schema.CreateTable(t1),
                first_part +
                "ENGINE=InnoDB CHARSET=utf8 COLLATE utf8_icelandic_ci",
            )
        except AssertionError:
            self.assert_compile(
                schema.CreateTable(t1),
                first_part +
                "CHARSET=utf8 ENGINE=InnoDB COLLATE utf8_icelandic_ci",
            )

    def test_inner_join(self):
        """``join()`` renders as INNER JOIN."""
        t1 = table("t1", column("x"))
        t2 = table("t2", column("y"))

        self.assert_compile(t1.join(t2, t1.c.x == t2.c.y),
                            "t1 INNER JOIN t2 ON t1.x = t2.y")

    def test_outer_join(self):
        """``outerjoin()`` renders as LEFT OUTER JOIN."""
        t1 = table("t1", column("x"))
        t2 = table("t2", column("y"))

        self.assert_compile(
            t1.outerjoin(t2, t1.c.x == t2.c.y),
            "t1 LEFT OUTER JOIN t2 ON t1.x = t2.y",
        )

    def test_full_outer_join(self):
        """``outerjoin(full=True)`` renders as FULL OUTER JOIN."""
        t1 = table("t1", column("x"))
        t2 = table("t2", column("y"))

        self.assert_compile(
            t1.outerjoin(t2, t1.c.x == t2.c.y, full=True),
            "t1 FULL OUTER JOIN t2 ON t1.x = t2.y",
        )
Esempio n. 24
0
    def test_with_polymorphic(self):
        """Mappers with a ``with_polymorphic`` selectable configure cleanly.

        ``Order`` is mapped polymorphically against an aliased SELECT of its
        own table and takes part in relationships/backrefs with ``Employee``,
        ``Product`` and ``OrderProduct``; ``configure_mappers`` must finish
        without raising.
        """

        # Plain classes that get mapped imperatively below.
        class Order(object):
            pass

        class Employee(object):
            pass

        class Product(object):
            pass

        class OrderProduct(object):
            pass

        meta = MetaData(testing.db)

        orders_tbl = Table(
            "orders",
            meta,
            Column("id", Integer, primary_key=True),
            Column(
                "employee_id",
                Integer,
                ForeignKey("employees.id"),
                nullable=False,
            ),
            Column("type", Unicode(16)),
        )
        employees_tbl = Table(
            "employees",
            meta,
            Column("id", Integer, primary_key=True),
            Column("name", Unicode(16), unique=True, nullable=False),
        )
        products_tbl = Table(
            "products",
            meta,
            Column("id", Integer, primary_key=True),
        )
        orderproducts_tbl = Table(
            "orderproducts",
            meta,
            Column("id", Integer, primary_key=True),
            Column(
                "order_id", Integer, ForeignKey("orders.id"), nullable=False
            ),
            Column(
                "product_id",
                Integer,
                ForeignKey("products.id"),
                nullable=False,
            ),
        )

        # The polymorphic selectable is an alias of SELECT * FROM orders;
        # its "type" column is used as the discriminator.
        pjoin = orders_tbl.select().alias("pjoin")
        discriminator = pjoin.c.type
        order_props = {
            "orderproducts": relationship(
                OrderProduct, lazy="select", backref="order"
            )
        }
        mapper(
            Order,
            orders_tbl,
            with_polymorphic=("*", pjoin),
            polymorphic_on=discriminator,
            polymorphic_identity="order",
            properties=order_props,
        )

        product_props = {
            "orderproducts": relationship(
                OrderProduct, lazy="select", backref="product"
            )
        }
        mapper(Product, products_tbl, properties=product_props)

        employee_props = {
            "orders": relationship(Order, lazy="select", backref="employee")
        }
        mapper(Employee, employees_tbl, properties=employee_props)

        mapper(OrderProduct, orderproducts_tbl)

        # this requires that the compilation of order_mapper's "surrogate
        # mapper" occur after the initial setup of MapperProperty objects on
        # the mapper.
        configure_mappers()