示例#1
0
    def test_wrapper(self):
        """
        Exercise `ctds.SqlVarChar` with bytes, unicode and NULL values.

        For each value the wrapper must expose the very same object, report
        the expected size and `repr`, and yield the expected type, value and
        max length through `self.parameter_type`. A long string must also
        round-trip unchanged through a plain `SELECT`.
        """
        with self.connect() as connection:
            with connection.cursor() as cursor:
                for value in (b'1234', unicode_('1234'), None):
                    is_null = value is None
                    wrapper = ctds.SqlVarChar(value)

                    # The wrapper references the original object itself.
                    self.assertEqual(id(value), id(wrapper.value))
                    # A NULL value is expected to report a size of 1.
                    self.assertEqual(wrapper.size,
                                     1 if is_null else len(value))

                    row = self.parameter_type(cursor, wrapper)
                    self.assertEqual(row.Type, None if is_null else 'varchar')

                    # Bytes input is compared against its decoded text.
                    if isinstance(value, bytes):
                        expected = wrapper.value.decode()
                    else:
                        expected = wrapper.value
                    self.assertEqual(row.Value, expected)
                    self.assertEqual(row.MaxLength,
                                     None if is_null else len(value))

                    description = 'ctds.SqlVarChar({0!r}, size={1})'.format(
                        value, wrapper.size)
                    self.assertEqual(repr(wrapper), description)
                    self.assertEqual(repr(wrapper), str(wrapper))

                # A long (54321 character) string must round-trip unchanged.
                long_text = unicode_(b'*', encoding='utf-8') * 54321
                cursor.execute('SELECT :0', (ctds.SqlVarChar(long_text), ))
                fetched = cursor.fetchone()[0]
                self.assertEqual(len(long_text), len(fetched))
                self.assertEqual(long_text, fetched)
示例#2
0
    def test_insert_dict(self):
        """
        Bulk insert rows supplied as dictionaries keyed by column name and
        verify that every row is read back with the expected values.

        The work happens in a transaction that is always rolled back.
        """
        table = self.test_insert_dict.__name__
        count = 100
        template = unicode_(b'this is row {0} \xc2\xbd', encoding='utf-8')
        wide_text = unicode_(b'\xe3\x83\x9b', encoding='utf-8') * 100

        def date_for(index):
            # Only odd rows carry a date; even rows store NULL.
            if index % 2:
                return datetime.datetime(2000 + index, 1, 1)
            return None

        with self.connect(autocommit=False) as connection:
            try:
                with connection.cursor() as cursor:
                    cursor.execute('''
                        CREATE TABLE {0}
                        (
                            PrimaryKey INT NOT NULL PRIMARY KEY,
                            Date       DATETIME NULL,
                            /*
                                FreeTDS' bulk insert implementation doesn't seem to work
                                properly with *VARCHAR(MAX) columns.
                            */
                            String     VARCHAR(1000) COLLATE SQL_Latin1_General_CP1_CI_AS NOT NULL,
                            Unicode    NVARCHAR(100) NULL,
                            Bytes      VARBINARY(1000) NULL,
                            Decimal    DECIMAL(7,3) NOT NULL
                        )
                        '''.format(table))

                # Rows are generated lazily; string columns are pre-encoded
                # (latin-1 / utf-16le) and wrapped in SqlVarChar.
                payload = (
                    {
                        'Bytes': bytes(index + 1),
                        'Date': date_for(index),
                        'Decimal': Decimal(str(index + .125)),
                        'PrimaryKey': index,
                        'String': ctds.SqlVarChar(
                            template.format(index).encode('latin-1')),
                        'Unicode': ctds.SqlVarChar(
                            wide_text.encode('utf-16le')),
                    }
                    for index in range(0, count)
                )
                inserted = connection.bulk_insert(table, payload)

                self.assertEqual(inserted, count)

                with connection.cursor() as cursor:
                    cursor.execute('SELECT * FROM {0}'.format(table))
                    fetched = [tuple(row) for row in cursor.fetchall()]
                    expected = [
                        (
                            index,
                            date_for(index),
                            template.format(index),
                            wide_text,
                            bytes(index + 1),
                            Decimal(str(index + .125)),
                        )
                        for index in range(0, count)
                    ]
                    self.assertEqual(fetched, expected)

            finally:
                connection.rollback()
示例#3
0
    def test_insert_empty_string(self):
        """
        Bulk insert a non-empty and an empty string, then verify what is
        read back.

        The expected value for the empty string depends on
        `self.bcp_empty_string_supported`: an empty string when set, NULL
        otherwise. The work happens in a transaction that is rolled back.
        """
        table = self.test_insert_empty_string.__name__
        half = unicode_(b'\xc2\xbd', encoding='utf-8')
        warning_text = (
            '"" converted to NULL for compatibility with FreeTDS. '
            'Please update to a recent version of FreeTDS. '
        )

        with self.connect(autocommit=False) as connection:
            try:
                with connection.cursor() as cursor:
                    cursor.execute('''
                        CREATE TABLE {0}
                        (
                            String VARCHAR(1000) COLLATE SQL_Latin1_General_CP1_CI_AS
                        )
                        '''.format(table))

                # Record any warnings raised while inserting.
                with warnings.catch_warnings(record=True) as warns:
                    connection.bulk_insert(table, (
                        (ctds.SqlVarChar(half.encode('latin-1')), ),
                        (ctds.SqlVarChar(unicode_('').encode('latin-1')), ),
                    ))
                if self.bcp_empty_string_supported:
                    # Every recorded warning, if any, must be the NULL
                    # conversion notice.
                    messages = [str(warn.message) for warn in warns]
                    self.assertEqual(messages, [warning_text] * len(warns))

                with connection.cursor() as cursor:
                    cursor.execute('''
                        SELECT
                            String,
                            CONVERT(VARBINARY(1000), String) AS Bytes
                        FROM
                            {0}
                        '''.format(table))
                    fetched = [tuple(row) for row in cursor.fetchall()]
                    if self.bcp_empty_string_supported:
                        empty = unicode_('')
                    else:
                        empty = None
                    self.assertEqual(fetched, [
                        (half, b'\xbd'),
                        (empty, None),
                    ])

            finally:
                connection.rollback()
示例#4
0
    def test_insert_invalid_encoding(self):
        """
        Bulk insert UTF-8 encoded bytes into a Latin-1 collated column.

        The stored bytes are expected to be unchanged, and the string read
        back is expected to equal those bytes decoded as Latin-1. The work
        happens in a transaction that is rolled back.
        """
        table = self.test_insert_invalid_encoding.__name__

        with self.connect(autocommit=False) as connection:
            try:
                with connection.cursor() as cursor:
                    cursor.execute('''
                        CREATE TABLE {0}
                        (
                            String VARCHAR(1000) COLLATE SQL_Latin1_General_CP1_CI_AS
                        )
                        '''.format(table))

                # Encoded with UTF-8, which does not match the column collation.
                payload = unicode_(b'\xc2\xbd',
                                   encoding='utf-8').encode('utf-8')
                connection.bulk_insert(table,
                                       ((ctds.SqlVarChar(payload), ), ))

                with connection.cursor() as cursor:
                    cursor.execute('''
                        SELECT
                            String,
                            CONVERT(VARBINARY(1000), String) AS Bytes
                        FROM
                            {0}
                        '''.format(table))
                    fetched = [tuple(row) for row in cursor.fetchall()]
                    expected_row = (
                        unicode_(b'\xc2\xbd', encoding='latin-1'),
                        b'\xc2\xbd',
                    )
                    self.assertEqual(fetched, [expected_row])

            finally:
                connection.rollback()
示例#5
0
 def test_size(self):
     """
     Exercise `ctds.SqlVarChar` with an explicit `size` argument.

     The wrapper must report the requested size, and the value seen
     through `self.parameter_type` must be cut to at most `size` characters
     with a max length equal to `size`. NULL values report no type, value
     or max length.
     """
     # The parameter_type method does not work with NVARCHAR(MAX) and
     # will fail with "Operand type clash: varchar(max) is incompatible with sql_variant"
     # Therefore, limit input sizes to 8000 or less.
     cases = (
         (b'1234', 14),
         (b'1234', 1),
         (unicode_('*'), 5000),
         (unicode_('*' * 5000), 5000),
         (unicode_('*' * 8000), 8000),
         (None, 14),
     )
     with self.connect() as connection:
         with connection.cursor() as cursor:
             for value, size in cases:
                 wrapper = ctds.SqlVarChar(value, size=size)
                 self.assertEqual(id(value), id(wrapper.value))
                 self.assertEqual(wrapper.size, size)

                 row = self.parameter_type(cursor, wrapper)
                 if value is None:
                     self.assertEqual(row.Type, None)
                     self.assertEqual(row.Value, None)
                     self.assertEqual(row.MaxLength, None)
                 else:
                     self.assertEqual(row.Type, 'varchar')
                     # Bytes input is compared against its decoded text.
                     text = wrapper.value
                     if isinstance(value, bytes):
                         text = text.decode()
                     self.assertEqual(len(row.Value),
                                      min(size, len(value)))
                     self.assertEqual(row.Value, text[:size])
                     self.assertEqual(row.MaxLength, size)
示例#6
0
    def test_varchar(self):
        """
        Round-trip several values through `ctds.SqlVarChar`, with and
        without an explicit size.

        For every value/size combination the wrapper must reference the
        original object, report the expected size and TDS type, and the
        value selected back must equal the input cut to the expected size.
        """
        with self.connect() as connection:
            with connection.cursor() as cursor:
                for value in (unicode_(''), None, unicode_(' '),
                              unicode_('one'),
                              unicode_(b'hola \xc2\xa9', encoding='utf-8')):
                    for size in (None, 1, 3, 500):
                        kwargs = {}
                        if size is not None:
                            kwargs['size'] = size
                            expected_size = size
                        elif value is None:
                            expected_size = 1
                        else:
                            # Without an explicit size, the UTF-8 byte
                            # length is expected (and never less than 1).
                            expected_size = max(
                                1, len(value.encode('utf-8')))

                        varchar = ctds.SqlVarChar(value, **kwargs)
                        self.assertEqual(id(varchar.value), id(value))
                        self.assertEqual(varchar.size, expected_size)
                        self.assertEqual(varchar.tdstype, ctds.VARCHAR)

                        cursor.execute(
                            '''
                            SELECT :0
                            ''', (varchar, ))

                        # $future: fix this once supported by FreeTDS
                        # Currently FreeTDS (really the db-lib API) will
                        # turn empty string to NULL
                        #
                        # The expectation lives in its own variable: rebinding
                        # the loop variable `value` here would leak into the
                        # remaining `size` iterations and silently replace the
                        # empty-string case with a second NULL case.
                        if value is None or (value == '' and
                                             self.use_sp_executesql):
                            expected = None
                        else:
                            expected = value[0:expected_size]
                        self.assertEqual(
                            [tuple(row) for row in cursor.fetchall()],
                            [(expected, )])
示例#7
0
    def _encode_result(self, rows, target_fields=None):
        """
        Prepare rows for a ctds bulk insert.

        Every `str` column is encoded with `self.dest_character_encoding`
        and wrapped in `ctds.SqlVarChar`; all other columns are passed
        through untouched.

        :param rows: iterable of rows, each an iterable of column values
        :param target_fields: column names used as dictionary keys; needed
            only when `self.bulk_insert_dict_rows` is set
        :return: a list of tuples, or a list of dicts keyed by
            `target_fields` when `self.bulk_insert_dict_rows` is set
        """
        # Materialise each row as a tuple. Returning bare generator
        # expressions would hand the caller one-shot rows that come back
        # empty on any second pass over them.
        result = [
            tuple(
                ctds.SqlVarChar(col.encode(self.dest_character_encoding))
                if isinstance(col, str) else col
                for col in row)
            for row in rows
        ]

        if self.bulk_insert_dict_rows:
            result = [dict(zip(target_fields, row)) for row in result]

        return result
示例#8
0
    def test_insert_tablock(self):
        """
        Bulk insert rows with the `tablock` option enabled and verify that
        both the reported and the stored row counts match the rows supplied.

        The work happens in a transaction that is always rolled back.
        """
        # Name the table after this test itself instead of borrowing the
        # name of a different test method, so the test does not depend on
        # that method existing.
        table = self.test_insert_tablock.__name__
        rows = 100

        with self.connect(autocommit=False) as connection:
            try:
                with connection.cursor() as cursor:
                    cursor.execute('''
                        CREATE TABLE {0}
                        (
                            PrimaryKey  INT NOT NULL PRIMARY KEY,
                            Date        DATETIME,
                            String      VARCHAR(1000),
                            Unicode     NVARCHAR(1000),
                            Bytes       VARBINARY(1000),
                            Decimal     DECIMAL(7,3)
                        )
                        '''.format(table))

                # String columns are pre-encoded (latin-1 / utf-16le) and
                # wrapped in SqlVarChar before being handed to bulk_insert.
                inserted = connection.bulk_insert(
                    table,
                    ((ix, datetime.datetime(2000 + ix, 1, 1),
                      ctds.SqlVarChar(
                          unicode_(
                              b'this is row {0} \xc2\xbd',
                              encoding='utf-8').format(ix).encode('latin-1')),
                      ctds.SqlVarChar(
                          (unicode_(b'\xe3\x83\x9b',
                                    encoding='utf-8')).encode('utf-16le')),
                      bytes(ix + 1), Decimal(str(ix + .125)))
                     for ix in range(0, rows)),
                    tablock=True)
                self.assertEqual(inserted, rows)

                with connection.cursor() as cursor:
                    cursor.execute('SELECT COUNT(1) FROM {0}'.format(table))
                    self.assertEqual(cursor.fetchone()[0], rows)

            finally:
                connection.rollback()
 def test_wrapper(self):
     """
     Exercise `ctds.SqlVarChar` with bytes, unicode and NULL values.

     For each value the wrapper must expose the very same object and
     report the expected size, and `self.parameter_type` must yield the
     expected type, value and max length.
     """
     with self.connect() as connection:
         with connection.cursor() as cursor:
             for value in (b'1234', unicode_('1234'), None):
                 is_null = value is None
                 wrapper = ctds.SqlVarChar(value)

                 # The wrapper references the original object itself.
                 self.assertEqual(id(value), id(wrapper.value))
                 # A NULL value is expected to report a size of 1.
                 self.assertEqual(wrapper.size,
                                  1 if is_null else len(value))

                 row = self.parameter_type(cursor, wrapper)
                 self.assertEqual(row.Type, None if is_null else 'varchar')

                 # Bytes input is compared against its decoded text.
                 if isinstance(value, bytes):
                     expected = wrapper.value.decode()
                 else:
                     expected = wrapper.value
                 self.assertEqual(row.Value, expected)
                 self.assertEqual(row.MaxLength,
                                  None if is_null else len(value))
示例#10
0
    def bulk_insert_rows_ctds(self,
                              table,
                              rows,
                              target_fields,
                              commit_every=5000):
        """
        Bulk insert rows into a table through a ctds connection.

        String columns are encoded as UTF-16LE and wrapped in
        ``ctds.SqlVarChar``; every row is sent as a dict keyed by
        ``target_fields``. Nothing is done when ``rows`` is empty.

        :param table: Name of the target table
        :type  table: str
        :param rows: The rows to insert into the table, data types being correct is important
        :type  rows: iterable of tuples
        :param target_fields: The names of the columns to fill in the table
        :type  target_fields: iterable of strings
        :param commit_every: An optional batch size.
        :type  commit_every: int
        :raises AirflowException: if the insert fails with a database error,
            or if fewer rows were saved than were supplied
        """
        expected = len(rows)
        if not expected:
            return

        with closing(self.get_ctds_conn()) as conn:
            data = [
                dict(zip(target_fields,
                         (ctds.SqlVarChar(col.encode('utf-16le'))
                          if isinstance(col, basestring) else col
                          for col in row)))
                for row in rows
            ]

            # Only the insert itself is guarded: when it raises there is no
            # saved-row count to report, so the handler must not refer to one.
            try:
                rows_saved = conn.bulk_insert(table=table,
                                              rows=data,
                                              batch_size=commit_every,
                                              tablock=True)
            except _tds.DatabaseError as err:
                self.log.error('Table: {}'.format(table))
                pprint(data)
                raise AirflowException(
                    'ERROR DatabaseError: {}'.format(err))

            if rows_saved != expected:
                self.log.error('Table: {}'.format(table))
                pprint(data)
                raise AirflowException(
                    'ERROR bulk_insert only = {} should have been {}'.
                    format(rows_saved, expected))
示例#11
0
    def test_nvarchar(self):
        """
        Pass a unicode string through a stored procedure that copies its
        NVARCHAR input parameter to an NVARCHAR output parameter, and check
        the value returned by `cursor.callproc`.

        The expected output depends on `self.use_utf16`,
        `self.use_sp_executesql` and `self.UCS4_SUPPORTED`: either the input
        comes back unchanged, or some codepoints are expected to be replaced
        (with warnings checked in that case).
        """
        with self.connect() as connection:
            with connection.cursor() as cursor:
                sproc = self.test_nvarchar.__name__
                with self.stored_procedure(
                        cursor, sproc, '''
                        @pVarChar NVARCHAR(256),
                        @pVarCharOut NVARCHAR(256) OUTPUT
                    AS
                        SET @pVarCharOut = @pVarChar;
                    '''):

                    # Template with three slots, filled below with the
                    # snowman, catface and flower characters.
                    format_ = (unichr_(191) + unicode_(' 8 ') + unichr_(247) +
                               unicode_(' 2 = 4 ? {0} {1} {2}'))
                    snowman = unichr_(9731)

                    # Python must be built with UCS4 support to test the large codepoints.
                    catface = (unichr_(128568) if self.UCS4_SUPPORTED else
                               self.UNICODE_REPLACEMENT)
                    flower = (unichr_(127802) if self.UCS4_SUPPORTED else
                              self.UNICODE_REPLACEMENT)

                    # Older versions of SQL server don't support passing codepoints outside
                    # of the server's code page. SQL Server defaults to latin-1, so assume
                    # non-latin-1 codepoints won't be supported.
                    if not self.use_sp_executesql:  # pragma: nocover
                        catface = unicode_('?')
                        snowman = unicode_('?')
                        flower = unicode_('?')

                    # First argument is the input string; second is the
                    # output parameter placeholder (NULL value, size 256).
                    inputs = (
                        format_.format(snowman, catface, flower),
                        ctds.Parameter(ctds.SqlVarChar(None, size=256),
                                       output=True),
                    )

                    # If the connection supports UTF-16, unicode codepoints outside of the UCS-2
                    # range are supported and not replaced by ctds.
                    if self.use_utf16:
                        outputs = cursor.callproc(sproc, inputs)
                        self.assertEqual(inputs[0], outputs[1])
                    else:  # pragma: nocover
                        # The catface is not representable in UCS-2, and therefore is replaced.
                        with warnings.catch_warnings(record=True) as warns:
                            outputs = cursor.callproc(sproc, inputs)
                            if ord(catface) > 2**16:
                                # One warning is expected per replaced
                                # codepoint (catface and flower).
                                self.assertEqual(len(warns), 2)
                                msg = unicode_('Unicode codepoint U+{0:08X} is not representable in UCS-2; replaced with U+{1:04X}')  # pylint: disable=line-too-long
                                self.assertEqual(
                                    [str(warn.message) for warn in warns], [
                                        msg.format(
                                            ord(char),
                                            ord(self.UNICODE_REPLACEMENT))
                                        for char in (catface, flower)
                                    ])
                                self.assertEqual(warns[0].category,
                                                 ctds.Warning)
                            else:
                                self.assertEqual(len(warns),
                                                 0)  # pragma: nocover

                        # The catface and flower slots are expected to hold
                        # the replacement character, or '?' when
                        # sp_executesql is not in use.
                        self.assertEqual(
                            format_.format(
                                snowman, self.UNICODE_REPLACEMENT
                                if self.use_sp_executesql else unicode_('?'),
                                self.UNICODE_REPLACEMENT
                                if self.use_sp_executesql else unicode_('?')),
                            outputs[1])

                    # The input argument comes back as the very same object,
                    # while the output parameter comes back as a different one.
                    self.assertEqual(id(inputs[0]), id(outputs[0]))
                    self.assertNotEqual(id(inputs[1]), id(outputs[1]))