コード例 #1
0
    def recv_plain(self):
        """Receive packets from the MySQL server

        Reads one packet: first the 4-byte header, whose first three bytes
        give the payload length (little-endian) and whose fourth byte is
        stored as ``self._packet_number``; then the payload until the whole
        packet has arrived.

        Returns:
            The complete packet, header included.

        Raises:
            errors.InterfaceError: errno 2013 when recv() returns no data
                before the packet is complete, or when the socket times out.
        """
        try:
            # Read the header of the MySQL packet, 4 bytes.  recv() may return
            # fewer bytes than asked for, so request whatever is still missing
            # instead of issuing one system call per byte.
            packet = ''
            while len(packet) < 4:
                chunk = self.sock.recv(4 - len(packet))
                if not chunk:
                    raise errors.InterfaceError(errno=2013)
                packet += chunk

            # Save the packet number and total packet length from header
            self._packet_number = ord(packet[3])
            packet_totlen = struct.unpack("<I", packet[0:3] + '\x00')[0] + 4

            # Read the rest of the packet
            rest = packet_totlen - len(packet)
            while rest > 0:
                chunk = self.sock.recv(rest)
                if not chunk:
                    raise errors.InterfaceError(errno=2013)
                packet += chunk
                rest = packet_totlen - len(packet)

            return packet
        except socket.timeout:
            raise errors.InterfaceError(errno=2013)
コード例 #2
0
 def fetchall(self):
     """Fetch every remaining row of the current result set.

     The rows are returned exactly as delivered by the connection's
     get_rows(); the row count is updated and the accompanying eof value
     is passed on to self._handle_eof().

     Raises errors.InterfaceError when there is no unread result.
     """
     if self._have_unread_result():
         rows, eof = self._connection.get_rows()
         self._rowcount = len(rows)
         self._handle_eof(eof)
         return rows
     raise errors.InterfaceError("No result set to fetch from.")
コード例 #3
0
class MySQLUnixSocket(BaseMySQLSocket):
    """MySQL socket class using UNIX sockets

    Opens a connection through the UNIX socket of the MySQL Server.
    """
    def __init__(self, unix_socket='/tmp/mysql.sock'):
        super(MySQLUnixSocket, self).__init__()
        # Filesystem path of the server's UNIX socket.
        self._unix_socket = unix_socket

    def get_address(self):
        """Return the path of the UNIX socket this instance connects to."""
        return self._unix_socket

    def open_connection(self):
        """Create an AF_UNIX stream socket and connect it to the server.

        Raises:
            errors.InterfaceError: errno 2002 (with the address and the
                socket error's errno, or its text when no errno is set)
                when the socket cannot be created or connected; a plain
                InterfaceError for any other StandardError.
        """
        try:
            self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            self.sock.settimeout(self._connection_timeout)
            self.sock.connect(self._unix_socket)
        except socket.error as err:
            # Prefer the numeric errno; fall back to the message text when
            # the error carries none.
            reason = getattr(err, 'errno', None)
            if reason is None:
                reason = str(err)
            raise errors.InterfaceError(errno=2002,
                                        values=(self.get_address(), reason))
        except StandardError as err:
            raise errors.InterfaceError('%s' % err)
コード例 #4
0
class MySQLTCPSocket(BaseMySQLSocket):
    """MySQL socket class using TCP/IP

    Opens a TCP/IP connection to the MySQL Server.
    """
    def __init__(self, host='127.0.0.1', port=3306):
        super(MySQLTCPSocket, self).__init__()
        self.server_host = host
        self.server_port = port

    def get_address(self):
        """Return the server address formatted as ``host:port``."""
        return "%s:%s" % (self.server_host, self.server_port)

    def open_connection(self):
        """Open the TCP/IP connection to the MySQL server

        The address family is IPv6 when server_host parses as an IPv6
        literal (any ``%scope`` suffix is ignored for that check) and IPv4
        otherwise.  Only the first address returned by getaddrinfo() is
        tried.

        Raises:
            errors.InterfaceError: errno 2003 when the host name cannot be
                resolved, or when the socket cannot be created or connected
                (e.g. refused, unreachable or timed out).
        """
        # Detect address family.  socket.gaierror is a subclass of
        # socket.error, so catching socket.error covers both.
        try:
            inet_pton(socket.AF_INET6, self.server_host.split('%')[0])
            family = socket.AF_INET6
        except socket.error:
            family = socket.AF_INET
        try:
            (family, socktype, proto, _,
             sockaddr) = socket.getaddrinfo(self.server_host, self.server_port,
                                            family, socket.SOCK_STREAM)[0]
            self.sock = socket.socket(family, socktype, proto)
            self.sock.settimeout(self._connection_timeout)
            self.sock.connect(sockaddr)
        except socket.gaierror as err:
            # err.args[1] is the resolver message; unlike err[1] this also
            # works on Python versions where exceptions are not indexable.
            raise errors.InterfaceError(errno=2003,
                                        values=(self.server_host, err.args[1]))
        except socket.error as err:
            # Connection failures used to escape as a raw socket.error while
            # the UNIX-socket variant wraps them in InterfaceError; report
            # them with the same errno as resolution failures.
            reason = getattr(err, 'errno', None)
            if reason is None:
                reason = str(err)
            raise errors.InterfaceError(errno=2003,
                                        values=(self.server_host, reason))
コード例 #5
0
    def parse_ok(self, packet):
        """Parse a MySQL OK-packet

        Args:
            packet: the full packet including its 4-byte header; the first
                byte after the header must be a zero byte.

        Returns:
            dict with 'field_count', 'affected_rows', 'insert_id',
            'server_status', 'warning_count' and, when trailing data is
            present, 'info_msg'.

        Raises:
            errors.InterfaceError: when the marker byte is missing (also for
                packets too short to contain it) or a field cannot be
                decoded.
        """
        # Slice instead of index: a truncated packet must be reported as a
        # parse failure rather than escape as an IndexError.
        if packet[4:5] != '\x00':
            raise errors.InterfaceError("Failed parsing OK packet.")

        ok = {}
        try:
            (packet, ok['field_count']) = utils.read_int(packet[4:], 1)
            (packet, ok['affected_rows']) = utils.read_lc_int(packet)
            (packet, ok['insert_id']) = utils.read_lc_int(packet)
            (packet, ok['server_status']) = utils.read_int(packet, 2)
            (packet, ok['warning_count']) = utils.read_int(packet, 2)
            if packet:
                (packet, ok['info_msg']) = utils.read_lc_string(packet)
        except ValueError:
            raise errors.InterfaceError("Failed parsing OK packet.")
        return ok
コード例 #6
0
 def fetchall(self):
     """Return all rows not fetched yet, converted to Python types.

     Rows before self._next_row have already been handed out and are not
     returned again; afterwards the buffer is marked as fully consumed.

     Raises errors.InterfaceError when there is no result set.
     """
     if self._rows is None:
         raise errors.InterfaceError("No result set to fetch from.")
     # Start at _next_row: iterating the whole buffer repeated rows that
     # earlier fetch calls had already returned.
     res = [self._row_to_python(row) for row in self._rows[self._next_row:]]
     self._next_row = len(self._rows)
     return res
コード例 #7
0
    def switch_to_ssl(self, ca, cert, key, verify_cert=False):
        """Switch the socket to use SSL

        Wraps the open socket with ssl.wrap_socket() using TLSv1 and then
        performs the handshake.

        Args:
            ca: CA certificate file, passed as ``ca_certs``.
            cert: client certificate file.
            key: client private key file.
            verify_cert: when True the server certificate must validate
                against ``ca`` (ssl.CERT_REQUIRED); ``ca`` must then be
                given.  The default False keeps the previous behavior of
                not checking the server certificate at all (ssl.CERT_NONE).
                The host name is not matched against the certificate in
                either case.

        Raises:
            errors.InterfaceError: errno 2048 when there is no socket, or
                when the SSL layer raises ssl.SSLError.
            errors.NotSupportedError: when Python has no SSL support.
        """
        if not self.sock:
            raise errors.InterfaceError(errno=2048)

        try:
            # CERT_NONE accepts any certificate and so offers no protection
            # against an impostor server; callers that supply a CA can now
            # opt in to verification.
            if verify_cert:
                cert_reqs = ssl.CERT_REQUIRED
            else:
                cert_reqs = ssl.CERT_NONE
            # NOTE(review): PROTOCOL_TLSv1 pins an outdated protocol
            # version; consider negotiating a newer one.
            self.sock = ssl.wrap_socket(self.sock,
                                        keyfile=key,
                                        certfile=cert,
                                        ca_certs=ca,
                                        cert_reqs=cert_reqs,
                                        do_handshake_on_connect=False,
                                        ssl_version=ssl.PROTOCOL_TLSv1)
            self.sock.do_handshake()
        except NameError:
            # The module-level `ssl` name is missing when Python was built
            # without SSL support.
            raise errors.NotSupportedError(
                "Python installation has no SSL support")
        except ssl.SSLError as err:
            raise errors.InterfaceError("SSL error: %s" % err)
コード例 #8
0
 def fetchall(self):
     """Fetch all remaining rows of the current result set.

     Each row is converted with self._row_to_python(); the cursor's row
     count is set to the number of rows read, and the eof value returned
     together with the rows is passed to self._handle_eof().

     Raises errors.InterfaceError when there is no unread result.
     """
     if not self._have_unread_result():
         raise errors.InterfaceError("No result set to fetch from.")
     (rows, eof) = self._connection.get_rows()
     self._rowcount = len(rows)
     # Iterate the rows directly instead of indexing through the public
     # rowcount property, which only worked while it mirrored _rowcount.
     res = [self._row_to_python(row) for row in rows]
     self._handle_eof(eof)
     return res
コード例 #9
0
    def parse_eof(self, packet):
        """Parse a MySQL EOF-packet

        Args:
            packet: the full packet including its 4-byte header; byte 4 must
                be 0xFE and the packet may be at most 9 bytes long.

        Returns:
            dict with 'warning_count' and 'status_flag', the two 2-byte
            integers that follow the 0xFE byte.

        Raises:
            errors.InterfaceError: when the 0xFE byte is missing (also for
                packets too short to contain it), the packet is too long, or
                the fields cannot be decoded.
        """
        # Slice rather than index so a truncated packet raises
        # InterfaceError instead of IndexError.
        if packet[4:5] != '\xfe' or len(packet) > 9:
            raise errors.InterfaceError("Failed parsing EOF packet.")

        res = {}
        packet = packet[5:]  # disregard the first checking byte
        try:
            (packet, res['warning_count']) = utils.read_int(packet, 2)
            (packet, res['status_flag']) = utils.read_int(packet, 2)
        except ValueError:
            # utils.read_int failures are treated as ValueError, the same
            # way parse_ok handles them.
            raise errors.InterfaceError("Failed parsing EOF packet.")
        return res
コード例 #10
0
 def _row_to_python(self, rowdata, desc=None):
     res = ()
     try:
         if not desc:
             desc = self.description
         for idx, v in enumerate(rowdata):
             flddsc = desc[idx]
             res += (self._connection.converter.to_python(flddsc, v), )
     except StandardError, e:
         raise errors.InterfaceError(
             "Failed converting row to Python types; %s" % e)
コード例 #11
0
    def recv_compressed(self):
        """Receive compressed packets from the MySQL server

        Returns a packet left in self._packet_queue by an earlier call when
        there is one.  Otherwise reads frames from the socket.  Each frame
        starts with a 7-byte header: bytes 0-2 give the length of the
        payload that follows, bytes 4-6 a second length (presumably the
        uncompressed payload length -- confirm against the protocol docs),
        both little-endian.  When that second length is 0 the payload is
        split directly into packets with _split_zipped_payload() and the
        first one is returned.  Otherwise frames are collected in
        ``packets`` for as long as the second length equals 16384.

        NOTE(review): in this copy the collected ``packets`` are never used
        and the method ends without a return statement (returning None);
        the decompression step appears to be missing -- confirm.

        Raises errors.InterfaceError (errno 2013) when a header or payload
        is cut short, or when the socket times out.
        """
        # NOTE(review): `utils` is not used in the visible part of this method.
        import utils
        try:
            return self._packet_queue.popleft()
        except IndexError:
            pass

        header = ''
        packets = []
        try:
            # Read the header one byte at a time; the loop ends with `abyte`
            # holding the first byte after the header ('' at end of stream).
            abyte = self.sock.recv(1)
            while abyte and len(header) < 7:
                header += abyte
                abyte = self.sock.recv(1)
            while header:
                if len(header) < 7:
                    raise errors.InterfaceError(errno=2013)
                zip_payload_length = struct.unpack("<I",
                                                   header[0:3] + '\x00')[0]
                payload_length = struct.unpack("<I", header[4:7] + '\x00')[0]
                # The byte already read past the header starts the payload.
                # NOTE(review): when zip_payload_length is 0 this byte is
                # still kept, i.e. one byte too many -- confirm.
                zip_payload = abyte
                while len(zip_payload) < zip_payload_length:
                    chunk = self.sock.recv(zip_payload_length -
                                           len(zip_payload))
                    if len(chunk) == 0:
                        raise errors.InterfaceError(errno=2013)
                    zip_payload = zip_payload + chunk
                if payload_length == 0:
                    # Payload is handed over as-is: split it into packets,
                    # queue them and return the first.
                    self._split_zipped_payload(zip_payload)
                    return self._packet_queue.popleft()
                packets.append(header + zip_payload)
                # Only a length of exactly 16384 makes the loop read another
                # frame (presumably a continuation -- confirm).
                if payload_length != 16384:
                    break
                header = ''
                abyte = self.sock.recv(1)
                while abyte and len(header) < 7:
                    header += abyte
                    abyte = self.sock.recv(1)
        except socket.timeout, err:
            raise errors.InterfaceError(errno=2013)
コード例 #12
0
    def parse_statistics(self, packet):
        """Parse the statistics packet

        The payload (everything after the 4-byte header) is a sequence of
        ``label: value`` pairs separated by two spaces.  Each value becomes
        a long when it parses as one, and a Decimal otherwise.

        Returns:
            dict mapping each label to its numeric value.

        Raises:
            errors.InterfaceError: when a pair has no ':' separator or a
                value is neither an integer nor a decimal number.
        """
        errmsg = "Failed getting COM_STATISTICS information"
        res = {}
        # Information is separated by 2 spaces
        pairs = packet[4:].split('\x20\x20')
        for pair in pairs:
            try:
                # maxsplit=1: only the first ':' separates label from value.
                (lbl, val) = [v.strip() for v in pair.split(':', 1)]
            except ValueError:
                raise errors.InterfaceError(errmsg)

            # It's either an integer or a decimal.  Catch only conversion
            # errors; the former bare excepts also hid unrelated failures.
            try:
                res[lbl] = long(val)
            except ValueError:
                try:
                    res[lbl] = Decimal(val)
                except (ValueError, ArithmeticError):
                    # decimal.InvalidOperation subclasses ArithmeticError.
                    raise errors.InterfaceError(
                        "%s (%s:%s)." % (errmsg, lbl, val))
        return res
コード例 #13
0
 def _scramble_password(self, passwd, seed):
     """Scramble a password ready to send to MySQL

     Computes SHA1(passwd) XOR SHA1(seed + SHA1(SHA1(passwd))) byte by
     byte and packs the 20 resulting values with struct.

     Args:
         passwd: the clear-text password.
         seed: the server-provided scramble (presumably from the handshake
             packet -- confirm at the call site).

     Returns:
         the 20-byte scrambled password.

     Raises:
         errors.InterfaceError: when hashing or packing fails.
     """
     try:
         hash1 = sha1(passwd).digest()
         hash2 = sha1(hash1).digest()  # Password as found in mysql.user()
         hash3 = sha1(seed + hash2).digest()
         xored = [utils.intread(h1) ^ utils.intread(h3)
                  for (h1, h3) in zip(hash1, hash3)]
         hash4 = struct.pack('20B', *xored)
     except Exception as err:
         raise errors.InterfaceError(
             'Failed scrambling password; %s' % err)
     # The scrambled value used to be computed and then dropped, so every
     # caller received None.
     return hash4
コード例 #14
0
    def _handle_result(self, result):
        """
        Handle the result after a command was send. The result can be either
        an OK-packet or a dictionary containing column/eof information.
        
        Raises InterfaceError when result is not a dict() or result is
        invalid.
        """
        if not isinstance(result, dict):
            raise errors.InterfaceError('Result was not a dict()')

        if 'columns' in result:
            # Weak test, must be column/eof information
            self._description = result['columns']
            self._connection.unread_result = True
            self._handle_resultset()
        elif 'affected_rows' in result:
            # Weak test, must be an OK-packet
            self._connection.unread_result = False
            self._handle_noresultset(result)
        else:
            raise errors.InterfaceError('Invalid result')
コード例 #15
0
    def executemany(self, operation, seq_params):
        """Execute the given operation multiple times
        
        The executemany() method will execute the operation iterating
        over the list of parameters in seq_params.
        
        Example: Inserting 3 new employees and their phone number
        
        data = [
            ('Jane','555-001'),
            ('Joe', '555-002'),
            ('John', '555-003')
            ]
        stmt = "INSERT INTO employees (name, phone) VALUES (%s, %s)"
        cursor.executemany(stmt, data)
        
        INSERT statements are optimized by batching the data, that is
        using the MySQL multiple rows syntax.  When seq_params is empty
        nothing is sent and the row count is set to 0.
        
        For other statements the row count ends up as the sum of the row
        counts of the individual executions.
        
        Results are discarded. If they are needed, consider looping over
        data using the execute() method.

        Raises InternalError when an unread result is pending or the
        operation holds multiple statements, and InterfaceError when the
        INSERT cannot be rewritten or an execution fails with ValueError
        or TypeError.
        """
        if not operation:
            return
        if self._have_unread_result():
            raise errors.InternalError("Unread result found.")
        elif len(RE_SQL_SPLIT_STMTS.split(operation)) > 1:
            raise errors.InternalError(
                "executemany() does not support multiple statements")

        # Optimize INSERTs by batching them
        if re.match(RE_SQL_INSERT_STMT, operation):
            opnocom = re.sub(RE_SQL_COMMENT, '', operation)
            m = re.search(RE_SQL_INSERT_VALUES, opnocom)
            if m is None:
                # Previously an AttributeError escaped from m.group().
                raise errors.InterfaceError(
                    "Failed rewriting statement for multi-row INSERT. "
                    "Check SQL syntax.")
            fmt = m.group(1)
            values = [fmt % self._process_params(params)
                      for params in seq_params]
            if not values:
                # Joining an empty list would leave "VALUES " without any
                # row, which is invalid SQL.
                self._rowcount = 0
                return None
            operation = operation.replace(fmt, ','.join(values), 1)
            return self.execute(operation)

        rowcnt = 0
        try:
            for params in seq_params:
                self.execute(operation, params)
                if self.with_rows and self._have_unread_result():
                    self.fetchall()
                rowcnt += self._rowcount
        except (ValueError, TypeError) as err:
            raise errors.InterfaceError("Failed executing the operation; %s" %
                                        err)
        # Publish the accumulated total; previously it was computed and then
        # discarded, leaving only the last execution's count.
        self._rowcount = rowcnt
コード例 #16
0
 def open_connection(self):
     """Connect to the MySQL server through its UNIX socket file.

     Raises errors.InterfaceError (errno 2002) carrying the address and the
     socket error's errno, or the error text when no errno is available.
     """
     try:
         self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
         self.sock.settimeout(self._connection_timeout)
         self.sock.connect(self._unix_socket)
     except socket.error as err:
         detail = getattr(err, 'errno', None)
         if detail is None:
             detail = str(err)
         raise errors.InterfaceError(errno=2002,
                                     values=(self.get_address(), detail))
コード例 #17
0
    def _fetch_warnings(self):
        """
        Fetch warnings doing a SHOW WARNINGS. Can be called after getting
        the result.

        Returns a result set or None when there were no warnings.
        """
        res = []
        try:
            c = self._connection.cursor()
            cnt = c.execute("SHOW WARNINGS")
            res = c.fetchall()
            c.close()
        except StandardError, e:
            raise errors.InterfaceError, errors.InterfaceError(
                "Failed getting warnings; %s" % e), sys.exc_info()[2]
コード例 #18
0
                if not chunk:
                    raise errors.InterfaceError(errno=2013)
                packet += chunk
                rest = packet_totlen - len(packet)

            return packet
        except socket.timeout, err:
            raise errors.InterfaceError(errno=2013)
        except socket.error, err:
            try:
                msg = err.errno
                if msg is None:
                    msg = str(err)
            except AttributeError:
                msg = str(err)
            raise errors.InterfaceError(errno=2055,
                                        values=(self.get_address(), msg))

    # Class-level alias: calling recv() runs recv_plain().  Presumably rebound
    # to recv_compressed when compression is enabled -- TODO confirm.
    recv = recv_plain

    def _split_zipped_payload(self, packet_bunch):
        """Split compressed payload

        Walks *packet_bunch*, expected to be a concatenation of packets,
        and appends each packet to self._packet_queue in order.  A packet
        is its 4-byte header plus the payload length stored little-endian
        in the header's first three bytes.  A final packet shorter than its
        declared length is queued as-is.
        """
        offset = 0
        total = len(packet_bunch)
        # Walk with an offset instead of re-slicing the remaining buffer on
        # every iteration, which copied the tail each time (quadratic in the
        # number of packets).
        while offset < total:
            payload_length = struct.unpack(
                "<I", packet_bunch[offset:offset + 3] + '\x00')[0]
            end = offset + payload_length + 4
            self._packet_queue.append(packet_bunch[offset:end])
            offset = end

    def recv_compressed(self):
        """Receive compressed packets from the MySQL server"""
        import utils
        try:
            return self._packet_queue.popleft()
コード例 #19
0
 def fetchall(self):
     """Return all rows not fetched yet, exactly as stored (no conversion).

     Rows before self._next_row have already been handed out and are not
     returned again; afterwards the buffer is marked as fully consumed.

     Raises errors.InterfaceError when there is no result set.
     """
     if self._rows is None:
         raise errors.InterfaceError("No result set to fetch from.")
     # Start at _next_row so rows returned by earlier fetches are not
     # repeated, and advance it so later fetches see an exhausted buffer.
     res = list(self._rows[self._next_row:])
     self._next_row = len(self._rows)
     return res
コード例 #20
0
    def callproc(self, procname, args=()):
        """Calls a stored procedure with the given arguments

        Each argument is first stored in a session user variable named
        @_<procname>_arg<nr>, where <nr> is the 1-based position of the
        argument; the procedure is then invoked with those variables, i.e.
        CALL <procname>(@_<procname>_arg1, @_<procname>_arg2, ...).

        Coding Example:
          1) Defining the Stored Routine in MySQL:
          CREATE PROCEDURE multiply(IN pFac1 INT, IN pFac2 INT, OUT pProd INT)
          BEGIN
            SET pProd := pFac1 * pFac2;
          END

          2) Executing in Python:
          args = (5,5,0) # 0 is to hold pprod
          result = cursor.callproc('multiply', args)
          print result

          The last print should output ('5', '5', 25L)

        Every result set produced by the CALL is buffered in its own
        MySQLCursorBuffered instance, collected in self._stored_results.

        Returns the row fetched by SELECT @_<procname>_arg1, ..., i.e. the
        argument variables after the call (so values the procedure assigned
        to OUT parameters can be read), or () when no arguments were given.
        errors.Error exceptions are re-raised unchanged; any other
        StandardError is raised as InterfaceError.
        """
        argfmt = "@_%s_arg%d"
        self._stored_results = []

        results = []
        try:
            # NOTE(review): the values are converted here and then passed to
            # execute() as parameters; if execute() converts its parameters
            # as well they are converted twice, which would explain the IN
            # arguments coming back as strings ('5') in the example above.
            # Confirm against _process_params()/execute().
            procargs = self._process_params(args)
            argnames = []

            for idx, arg in enumerate(procargs):
                argname = argfmt % (procname, idx + 1)
                argnames.append(argname)
                setquery = "SET %s=%%s" % argname
                self.execute(setquery, (arg, ))

            call = "CALL %s(%s)" % (procname, ','.join(argnames))

            for result in self._connection.cmd_query_iter(call):
                # Replies without 'columns' carry no result set; skip them.
                if 'columns' in result:
                    tmp = MySQLCursorBuffered(self._connection._get_self())
                    tmp._handle_result(result)
                    results.append(tmp)

            if argnames:
                # Read back the argument variables, including OUT values.
                select = "SELECT %s" % ','.join(argnames)
                self.execute(select)
                self._stored_results = results
                return self.fetchone()
            else:
                self._stored_results = results
                return ()

        except errors.Error:
            raise
        except StandardError as e:
            raise errors.InterfaceError("Failed calling stored routine; %s" %
                                        e)
コード例 #21
0
                raise errors.ProgrammingError(
                    "Wrong number of arguments during string formatting")
        else:
            stmt = operation

        if multi:
            self._executed = stmt
            self._executed_list = []
            return self._execute_iter(self._connection.cmd_query_iter(stmt))
        else:
            self._executed = stmt
            try:
                self._handle_result(self._connection.cmd_query(stmt))
            except errors.InterfaceError, err:
                if self._connection._have_next_result:
                    raise errors.InterfaceError(
                        "Use multi=True when executing multiple statements")
                raise
            return None

    def executemany(self, operation, seq_params):
        """Execute the given operation multiple times
        
        The executemany() method will execute the operation iterating
        over the list of parameters in seq_params.
        
        Example: Inserting 3 new employees and their phone number
        
        data = [
            ('Jane','555-001'),
            ('Joe', '555-001'),
            ('John', '555-003')
コード例 #22
0
 def _set_connection(self, connection):
     try:
         self._connection = weakref.proxy(connection)
         self._connection._protocol
     except (AttributeError, TypeError):
         raise errors.InterfaceError(errno=2048)