Example #1
    def test_reserved_words_mysql_vs_mariadb(self,
                                             mysql_mariadb_reserved_words):
        """test #7167 - real backend level

        We want to make sure that the "is_mariadb" flag and the correct
        identifier preparer are set up on a dialect regardless of how the
        dialect determined that it is talking to MariaDB.

        """

        dialect = testing.db.dialect
        expect_mariadb = testing.only_on("mariadb").enabled

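        # the fixture supplies a table plus the SQL expected under each
        # dialect; compile against the live dialect and compare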
        table, expected_mysql, expected_mdb = mysql_mariadb_reserved_words
        self.assert_compile(
            select(table),
            expected_mdb if expect_mariadb else expected_mysql,
            dialect=dialect,
        )
Example #2
class TypeRoundTripTest(fixtures.TestBase, AssertsExecutionResults,
                        ComparesTables):
    __only_on__ = "mssql"

    __backend__ = True

    def test_decimal_notation(self, metadata, connection):
        numeric_table = Table(
            "numeric_table",
            metadata,
            Column(
                "id",
                Integer,
                Sequence("numeric_id_seq", optional=True),
                primary_key=True,
            ),
            Column("numericcol", Numeric(precision=38,
                                         scale=20,
                                         asdecimal=True)),
        )
        metadata.create_all(connection)
        test_items = [
            decimal.Decimal(d) for d in (
                "1500000.00000000000000000000",
                "-1500000.00000000000000000000",
                "1500000",
                "0.0000000000000000002",
                "0.2",
                "-0.0000000000000000002",
                "-2E-2",
                "156666.458923543",
                "-156666.458923543",
                "1",
                "-1",
                "-1234",
                "1234",
                "2E-12",
                "4E8",
                "3E-6",
                "3E-7",
                "4.1",
                "1E-1",
                "1E-2",
                "1E-3",
                "1E-4",
                "1E-5",
                "1E-6",
                "1E-7",
                "1E-1",
                "1E-8",
                "0.2732E2",
                "-0.2432E2",
                "4.35656E2",
                "-02452E-2",
                "45125E-2",
                "1234.58965E-2",
                "1.521E+15",
                # previously, these were at -1E-25, which were inserted
                # cleanly however we only got back 20 digits of accuracy.
                # pyodbc as of 4.0.22 now disallows the silent truncation.
                "-1E-20",
                "1E-20",
                "1254E-20",
                "-1203E-20",
                "0",
                "-0.00",
                "-0",
                "4585E12",
                "000000000000000000012",
                "000000000000.32E12",
                "00000000000000.1E+12",
                # these are no longer accepted by pyodbc 4.0.22 but it seems
                # they were not actually round-tripping correctly before that
                # in any case
                # '-1E-25',
                # '1E-25',
                # '1254E-25',
                # '-1203E-25',
                # '000000000000.2E-32',
            )
        ]

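        # round-trip each value: INSERT it, SELECT it back by primary key,
        # and assert the Decimal comes back unchanged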
        for value in test_items:
            result = connection.execute(numeric_table.insert(),
                                        dict(numericcol=value))
            primary_key = result.inserted_primary_key
            returned = connection.scalar(
                select(numeric_table.c.numericcol).where(
                    numeric_table.c.id == primary_key[0]))
            eq_(value, returned)

    def test_float(self, metadata, connection):

        float_table = Table(
            "float_table",
            metadata,
            Column(
                "id",
                Integer,
                Sequence("numeric_id_seq", optional=True),
                primary_key=True,
            ),
            Column("floatcol", Float()),
        )

        metadata.create_all(connection)
        test_items = [
            float(d) for d in (
                "1500000.00000000000000000000",
                "-1500000.00000000000000000000",
                "1500000",
                "0.0000000000000000002",
                "0.2",
                "-0.0000000000000000002",
                "156666.458923543",
                "-156666.458923543",
                "1",
                "-1",
                "1234",
                "2E-12",
                "4E8",
                "3E-6",
                "3E-7",
                "4.1",
                "1E-1",
                "1E-2",
                "1E-3",
                "1E-4",
                "1E-5",
                "1E-6",
                "1E-7",
                "1E-8",
            )
        ]
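        # same round-trip pattern as the Decimal test, but through FLOAT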
        for value in test_items:
            result = connection.execute(float_table.insert(),
                                        dict(floatcol=value))
            primary_key = result.inserted_primary_key
            returned = connection.scalar(
                select(float_table.c.floatcol).where(
                    float_table.c.id == primary_key[0]))
            eq_(value, returned)

    def test_dates(self, metadata, connection):
        """Exercise type specification for date types."""

        columns = [
            # column type, args, kwargs, expected ddl
            (mssql.MSDateTime, [], {}, "DATETIME", []),
            (types.DATE, [], {}, "DATE", [">=", (10, )]),
            (types.Date, [], {}, "DATE", [">=", (10, )]),
            (types.Date, [], {}, "DATETIME", ["<", (10, )], mssql.MSDateTime),
            (mssql.MSDate, [], {}, "DATE", [">=", (10, )]),
            (mssql.MSDate, [], {}, "DATETIME", ["<",
                                                (10, )], mssql.MSDateTime),
            (types.TIME, [], {}, "TIME", [">=", (10, )]),
            (types.Time, [], {}, "TIME", [">=", (10, )]),
            (mssql.MSTime, [], {}, "TIME", [">=", (10, )]),
            (mssql.MSTime, [1], {}, "TIME(1)", [">=", (10, )]),
            (types.Time, [], {}, "DATETIME", ["<", (10, )], mssql.MSDateTime),
            (mssql.MSTime, [], {}, "TIME", [">=", (10, )]),
            (mssql.MSSmallDateTime, [], {}, "SMALLDATETIME", []),
            (mssql.MSDateTimeOffset, [], {}, "DATETIMEOFFSET", [">=", (10, )]),
            (
                mssql.MSDateTimeOffset,
                [1],
                {},
                "DATETIMEOFFSET(1)",
                [">=", (10, )],
            ),
            (mssql.MSDateTime2, [], {}, "DATETIME2", [">=", (10, )]),
            (mssql.MSDateTime2, [0], {}, "DATETIME2(0)", [">=", (10, )]),
            (mssql.MSDateTime2, [1], {}, "DATETIME2(1)", [">=", (10, )]),
        ]

        table_args = ["test_mssql_dates", metadata]
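        # add a column for each spec whose server-version requirement is
        # satisfied (an empty "requires" list means no requirement)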
        for index, spec in enumerate(columns):
            type_, args, kw, res, requires = spec[0:5]
            if not requires or testing._is_excluded("mssql", *requires):
                c = Column("c%s" % index, type_(*args, **kw), nullable=None)
                connection.dialect.type_descriptor(c.type)
                table_args.append(c)
        dates_table = Table(*table_args)
        gen = connection.dialect.ddl_compiler(connection.dialect,
                                              schema.CreateTable(dates_table))
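        # the compiled column specification should match the expected DDL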
        for col in dates_table.c:
            index = int(col.name[1:])
            testing.eq_(
                gen.get_column_specification(col),
                "%s %s" % (col.name, columns[index][3]),
            )
            self.assert_(repr(col))
        dates_table.create(connection)
        reflected_dates = Table("test_mssql_dates",
                                MetaData(),
                                autoload_with=connection)
        for col in reflected_dates.c:
            self.assert_types_base(col, dates_table.c[col.key])

    @testing.metadata_fixture()
    def date_fixture(self, metadata):
        t = Table(
            "test_dates",
            metadata,
            Column("adate", Date),
            Column("atime1", Time),
            Column("atime2", Time),
            Column("adatetime", DateTime),
            Column("adatetimeoffset", DATETIMEOFFSET),
            Column("adatetimewithtimezone", DateTime(timezone=True)),
        )

        d1 = datetime.date(2007, 10, 30)
        t1 = datetime.time(11, 2, 32)
        d2 = datetime.datetime(2007, 10, 30, 11, 2, 32)
        d3 = datetime.datetime(
            2007,
            10,
            30,
            11,
            2,
            32,
            123456,
            util.timezone(datetime.timedelta(hours=-5)),
        )
        return t, (d1, t1, d2, d3)

    def test_date_roundtrips(self, date_fixture, connection):
        t, (d1, t1, d2, d3) = date_fixture
        connection.execute(
            t.insert(),
            dict(
                adate=d1,
                adatetime=d2,
                atime1=t1,
                atime2=d2,
                adatetimewithtimezone=d3,
            ),
        )

        row = connection.execute(t.select()).first()
        eq_(
            (
                row.adate,
                row.adatetime,
                row.atime1,
                row.atime2,
                row.adatetimewithtimezone,
            ),
            (d1, d2, t1, d2.time(), d3),
        )

    @testing.combinations(
        (datetime.datetime(
            2007,
            10,
            30,
            11,
            2,
            32,
            tzinfo=util.timezone(datetime.timedelta(hours=-5)),
        ), ),
        (datetime.datetime(2007, 10, 30, 11, 2, 32),),
        argnames="date",
    )
    def test_tz_present_or_non_in_dates(self, date_fixture, connection, date):
        t, (d1, t1, d2, d3) = date_fixture
        connection.execute(
            t.insert(),
            dict(
                adatetime=date,
                adatetimewithtimezone=date,
            ),
        )

        row = connection.execute(
            select(t.c.adatetime, t.c.adatetimewithtimezone)).first()

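        # the plain DATETIME column always comes back naive; the
        # timezone=True column always comes back aware (UTC when the input
        # had no tzinfo)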
        if not date.tzinfo:
            eq_(row, (date, date.replace(tzinfo=util.timezone.utc)))
        else:
            eq_(row, (date.replace(tzinfo=None), date))

    @testing.metadata_fixture()
    def datetimeoffset_fixture(self, metadata):
        t = Table(
            "test_dates",
            metadata,
            Column("adatetimeoffset", DATETIMEOFFSET),
        )

        return t

    @testing.combinations(
        ("dto_param_none", lambda: None, None, False),
        (
            "dto_param_datetime_aware_positive",
            lambda: datetime.datetime(
                2007,
                10,
                30,
                11,
                2,
                32,
                123456,
                util.timezone(datetime.timedelta(hours=1)),
            ),
            1,
            False,
        ),
        (
            "dto_param_datetime_aware_negative",
            lambda: datetime.datetime(
                2007,
                10,
                30,
                11,
                2,
                32,
                123456,
                util.timezone(datetime.timedelta(hours=-5)),
            ),
            -5,
            False,
        ),
        (
            "dto_param_datetime_aware_seconds_frac_fail",
            lambda: datetime.datetime(
                2007,
                10,
                30,
                11,
                2,
                32,
                123456,
                util.timezone(datetime.timedelta(seconds=4000)),
            ),
            None,
            True,
        ),
        (
            "dto_param_datetime_naive",
            lambda: datetime.datetime(2007, 10, 30, 11, 2, 32, 123456),
            0,
            False,
        ),
        (
            "dto_param_string_one",
            lambda: "2007-10-30 11:02:32.123456 +01:00",
            1,
            False,
        ),
        # wow
        (
            "dto_param_string_two",
            lambda: "October 30, 2007 11:02:32.123456",
            0,
            False,
        ),
        ("dto_param_string_invalid", lambda: "this is not a date", 0, True),
        id_="iaaa",
        argnames="dto_param_value, expected_offset_hours, should_fail",
    )
    def test_datetime_offset(
        self,
        datetimeoffset_fixture,
        dto_param_value,
        expected_offset_hours,
        should_fail,
        connection,
    ):
        t = datetimeoffset_fixture
        dto_param_value = dto_param_value()

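        # values the server/driver cannot represent are expected to fail
        # at the DBAPI level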
        if should_fail:
            assert_raises(
                sa.exc.DBAPIError,
                connection.execute,
                t.insert(),
                dict(adatetimeoffset=dto_param_value),
            )
            return

        connection.execute(
            t.insert(),
            dict(adatetimeoffset=dto_param_value),
        )

        row = connection.execute(t.select()).first()

        if dto_param_value is None:
            is_(row.adatetimeoffset, None)
        else:
            eq_(
                row.adatetimeoffset,
                datetime.datetime(
                    2007,
                    10,
                    30,
                    11,
                    2,
                    32,
                    123456,
                    util.timezone(
                        datetime.timedelta(hours=expected_offset_hours)),
                ),
            )

    @testing.combinations(
        ("legacy_large_types", False),
        ("sql2012_large_types", True, lambda: testing.only_on("mssql >= 11")),
        id_="ia",
        argnames="deprecate_large_types",
    )
    def test_binary_reflection(self, metadata, deprecate_large_types):
        """Exercise type specification for binary types."""

        columns = [
            # column type, args, kwargs, expected ddl from reflected
            (mssql.MSBinary, [], {}, "BINARY(1)"),
            (mssql.MSBinary, [10], {}, "BINARY(10)"),
            (types.BINARY, [], {}, "BINARY(1)"),
            (types.BINARY, [10], {}, "BINARY(10)"),
            (mssql.MSVarBinary, [], {}, "VARBINARY(max)"),
            (mssql.MSVarBinary, [10], {}, "VARBINARY(10)"),
            (types.VARBINARY, [10], {}, "VARBINARY(10)"),
            (types.VARBINARY, [], {}, "VARBINARY(max)"),
            (mssql.MSImage, [], {}, "IMAGE"),
            (mssql.IMAGE, [], {}, "IMAGE"),
            (
                types.LargeBinary,
                [],
                {},
                "IMAGE" if not deprecate_large_types else "VARBINARY(max)",
            ),
        ]

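        # create the table under an engine configured for the requested
        # large-type behavior, then reflect it back and compare the types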
        engine = engines.testing_engine(
            options={"deprecate_large_types": deprecate_large_types})
        with engine.begin() as conn:
            table_args = ["test_mssql_binary", metadata]
            for index, spec in enumerate(columns):
                type_, args, kw, res = spec
                table_args.append(
                    Column("c%s" % index, type_(*args, **kw), nullable=None))
            binary_table = Table(*table_args)
            metadata.create_all(conn)
            reflected_binary = Table("test_mssql_binary",
                                     MetaData(),
                                     autoload_with=conn)
            for col, spec in zip(reflected_binary.c, columns):
                eq_(
                    col.type.compile(dialect=mssql.dialect()),
                    spec[3],
                    "column %s %s != %s" % (
                        col.key,
                        col.type.compile(dialect=conn.dialect),
                        spec[3],
                    ),
                )
                c1 = conn.dialect.type_descriptor(col.type).__class__
                c2 = conn.dialect.type_descriptor(
                    binary_table.c[col.name].type).__class__
                assert issubclass(c1, c2), (
                    "column %s: %r is not a subclass of %r"
                    % (col.key, c1, c2)
                )
                if binary_table.c[col.name].type.length:
                    testing.eq_(col.type.length,
                                binary_table.c[col.name].type.length)

    def test_autoincrement(self, metadata, connection):
        Table(
            "ai_1",
            metadata,
            Column("int_y", Integer, primary_key=True, autoincrement=True),
            Column("int_n", Integer, DefaultClause("0"), primary_key=True),
        )
        Table(
            "ai_2",
            metadata,
            Column("int_y", Integer, primary_key=True, autoincrement=True),
            Column("int_n", Integer, DefaultClause("0"), primary_key=True),
        )
        Table(
            "ai_3",
            metadata,
            Column("int_n", Integer, DefaultClause("0"), primary_key=True),
            Column("int_y", Integer, primary_key=True, autoincrement=True),
        )

        Table(
            "ai_4",
            metadata,
            Column("int_n", Integer, DefaultClause("0"), primary_key=True),
            Column("int_n2", Integer, DefaultClause("0"), primary_key=True),
        )
        Table(
            "ai_5",
            metadata,
            Column("int_y", Integer, primary_key=True, autoincrement=True),
            Column("int_n", Integer, DefaultClause("0"), primary_key=True),
        )
        Table(
            "ai_6",
            metadata,
            Column("o1", String(1), DefaultClause("x"), primary_key=True),
            Column("int_y", Integer, primary_key=True, autoincrement=True),
        )
        Table(
            "ai_7",
            metadata,
            Column("o1", String(1), DefaultClause("x"), primary_key=True),
            Column("o2", String(1), DefaultClause("x"), primary_key=True),
            Column("int_y", Integer, autoincrement=True, primary_key=True),
        )
        Table(
            "ai_8",
            metadata,
            Column("o1", String(1), DefaultClause("x"), primary_key=True),
            Column("o2", String(1), DefaultClause("x"), primary_key=True),
        )
        metadata.create_all(connection)

        table_names = [
            "ai_1",
            "ai_2",
            "ai_3",
            "ai_4",
            "ai_5",
            "ai_6",
            "ai_7",
            "ai_8",
        ]
        mr = MetaData()

        for name in table_names:
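            # reflect the table, then run the checks against the original
            # metadata's version of it (the reflected copy is discarded)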
            tbl = Table(name, mr, autoload_with=connection)
            tbl = metadata.tables[name]

            # test that the flag itself reflects appropriately
            for col in tbl.c:
                if "int_y" in col.name:
                    is_(col.autoincrement, True)
                    is_(tbl._autoincrement_column, col)
                else:
                    eq_(col.autoincrement, "auto")
                    is_not(tbl._autoincrement_column, col)

            eng = [
                engines.testing_engine(options={"implicit_returning": False}),
                engines.testing_engine(options={"implicit_returning": True}),
            ]

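            # insert once per iteration; the IDENTITY value keeps counting
            # across the per-iteration delete, so the value after insert N
            # is N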
            for counter, engine in enumerate(eng):
                connection.execute(tbl.insert())
                if "int_y" in tbl.c:
                    eq_(
                        connection.execute(select(tbl.c.int_y)).scalar(),
                        counter + 1,
                    )
                    row = connection.execute(tbl.select()).first()
                    assert list(row).count(counter + 1) == 1
                else:
                    assert 1 not in list(
                        connection.execute(tbl.select()).first())
                connection.execute(tbl.delete())