def recv_plain(self):
    """Receive packets from the MySQL server

    Reads one MySQL packet: a 4-byte header whose first 3 bytes are the
    little-endian payload length and whose 4th byte is the packet number,
    followed by the payload. The packet number is stored in
    ``self._packet_number``.

    Returns:
        The complete packet (header + payload).

    Raises:
        errors.InterfaceError: errno 2013 when the peer closes the
            connection while a packet is being read, or when the socket
            times out.
    """
    try:
        # Read the header of the MySQL packet, 4 bytes. recv(n) may return
        # fewer than n bytes, so only ask for what is still missing instead
        # of issuing one recv(1) system call per header byte.
        packet = ''
        while len(packet) < 4:
            chunk = self.sock.recv(4 - len(packet))
            if not chunk:
                raise errors.InterfaceError(errno=2013)
            packet += chunk

        # Save the packet number and total packet length from header
        self._packet_number = ord(packet[3])
        packet_totlen = struct.unpack("<I", packet[0:3] + '\x00')[0] + 4

        # Read the rest of the packet
        rest = packet_totlen - len(packet)
        while rest > 0:
            chunk = self.sock.recv(rest)
            if not chunk:
                raise errors.InterfaceError(errno=2013)
            packet += chunk
            rest = packet_totlen - len(packet)
        return packet
    except socket.timeout:
        raise errors.InterfaceError(errno=2013)
def fetchall(self):
    """Return all rows of the pending result set without conversion.

    The row count is updated and the trailing EOF information is handed
    to ``_handle_eof`` before the rows are returned.

    Raises:
        errors.InterfaceError: when there is no unread result set.
    """
    if self._have_unread_result():
        rows, eof = self._connection.get_rows()
        self._rowcount = len(rows)
        self._handle_eof(eof)
        return rows
    raise errors.InterfaceError("No result set to fetch from.")
class MySQLUnixSocket(BaseMySQLSocket):
    """MySQL socket class using UNIX sockets

    Talks to the MySQL server through its UNIX socket file.

    Args:
        unix_socket: filesystem path of the server's UNIX socket.
    """

    def __init__(self, unix_socket='/tmp/mysql.sock'):
        super(MySQLUnixSocket, self).__init__()
        self._unix_socket = unix_socket

    def get_address(self):
        """Return the UNIX socket path this instance connects to."""
        return self._unix_socket

    def open_connection(self):
        """Create ``self.sock`` and connect it to the UNIX socket.

        Raises:
            errors.InterfaceError: errno 2002 for socket errors, carrying
                the socket path and the OS errno (or the error text when
                no errno is set); any other StandardError is re-raised as
                InterfaceError with its text.
        """
        try:
            self.sock = conn = socket.socket(socket.AF_UNIX,
                                             socket.SOCK_STREAM)
            conn.settimeout(self._connection_timeout)
            conn.connect(self._unix_socket)
        except socket.error as err:
            reason = getattr(err, 'errno', None)
            if reason is None:
                reason = str(err)
            raise errors.InterfaceError(errno=2002,
                                        values=(self.get_address(), reason))
        except StandardError as err:
            raise errors.InterfaceError('%s' % err)
class MySQLTCPSocket(BaseMySQLSocket):
    """MySQL socket class using TCP/IP

    Opens a TCP/IP connection to the MySQL Server.

    Args:
        host: server host name or IP address (IPv6 literals may carry a
            ``%zone`` suffix).
        port: server TCP port.
    """

    def __init__(self, host='127.0.0.1', port=3306):
        super(MySQLTCPSocket, self).__init__()
        self.server_host = host
        self.server_port = port

    def get_address(self):
        """Return the server address formatted as ``host:port``."""
        return "%s:%s" % (self.server_host, self.server_port)

    def open_connection(self):
        """Open the TCP/IP connection to the MySQL server

        Raises:
            errors.InterfaceError: errno 2003 when the host name cannot be
                resolved, or when the socket cannot be created or
                connected (refused, unreachable, timed out, ...).
        """
        # Detect address family: a host that parses as an IPv6 literal
        # (ignoring any %zone suffix) uses AF_INET6, anything else AF_INET.
        try:
            inet_pton(socket.AF_INET6, self.server_host.split('%')[0])
            family = socket.AF_INET6
        except (socket.error, socket.gaierror):
            family = socket.AF_INET

        try:
            (family, socktype, proto, _, sockaddr) = socket.getaddrinfo(
                self.server_host, self.server_port, family,
                socket.SOCK_STREAM)[0]
            self.sock = socket.socket(family, socktype, proto)
            self.sock.settimeout(self._connection_timeout)
            self.sock.connect(sockaddr)
        except socket.gaierror as err:
            # Must come before socket.error: gaierror is a subclass of it.
            raise errors.InterfaceError(errno=2003,
                                        values=(self.server_host,
                                                err.args[1]))
        except socket.error as err:
            # Previously these escaped as raw socket errors; report them as
            # a driver error, using the OS errno when one is available.
            reason = getattr(err, 'errno', None)
            if reason is None:
                reason = str(err)
            raise errors.InterfaceError(errno=2003,
                                        values=(self.get_address(), reason))
def parse_ok(self, packet):
    """Parse a MySQL OK-packet

    Args:
        packet: the full packet including the 4-byte header; the OK
            marker (a 0x00 byte) is expected at offset 4.

    Returns:
        dict with 'field_count', 'affected_rows', 'insert_id',
        'server_status', 'warning_count' and, when data remains after
        those fields, 'info_msg'.

    Raises:
        errors.InterfaceError: when the packet is not an OK-packet or its
            fields cannot be read.
    """
    # Slice instead of packet[4] so a packet without any payload byte is
    # rejected with InterfaceError rather than an uncaught IndexError.
    if packet[4:5] != '\x00':
        raise errors.InterfaceError("Failed parsing OK packet.")
    ok = {}
    try:
        (packet, ok['field_count']) = utils.read_int(packet[4:], 1)
        (packet, ok['affected_rows']) = utils.read_lc_int(packet)
        (packet, ok['insert_id']) = utils.read_lc_int(packet)
        (packet, ok['server_status']) = utils.read_int(packet, 2)
        (packet, ok['warning_count']) = utils.read_int(packet, 2)
        if packet:
            (packet, ok['info_msg']) = utils.read_lc_string(packet)
    except ValueError:
        raise errors.InterfaceError("Failed parsing OK packet.")
    return ok
def fetchall(self):
    """Return every buffered row converted to Python types.

    Afterwards ``_next_row`` points past the last buffered row.

    Raises:
        errors.InterfaceError: when no result set is buffered.
    """
    if self._rows is None:
        raise errors.InterfaceError("No result set to fetch from.")
    converted = [self._row_to_python(row) for row in self._rows]
    self._next_row = len(self._rows)
    return converted
def switch_to_ssl(self, ca, cert, key, verify_cert=False):
    """Switch the socket to use SSL

    Wraps the current socket using TLSv1 and performs the handshake.

    Args:
        ca: CA certificate file, passed to ssl as ``ca_certs``.
        cert: client certificate file, passed as ``certfile``.
        key: client private key file, passed as ``keyfile``.
        verify_cert: when True, the server certificate is required and
            verified against ``ca`` (ssl.CERT_REQUIRED; ``ca`` must then
            be given). Defaults to False, which keeps the historical
            behaviour: ssl.CERT_NONE, i.e. the server certificate is NOT
            verified even when ``ca`` is supplied.

    Raises:
        errors.InterfaceError: errno 2048 when there is no socket, or
            "SSL error: ..." when the SSL layer fails.
        errors.NotSupportedError: when the ``ssl`` name is unavailable
            (presumably an optional import failed -- confirm at file top).
    """
    if not self.sock:
        raise errors.InterfaceError(errno=2048)
    try:
        # Every reference to ssl stays inside the try so a missing module
        # surfaces as NameError and is reported as NotSupportedError.
        if verify_cert:
            cert_reqs = ssl.CERT_REQUIRED
        else:
            cert_reqs = ssl.CERT_NONE
        self.sock = ssl.wrap_socket(self.sock, keyfile=key, certfile=cert,
                                    ca_certs=ca, cert_reqs=cert_reqs,
                                    do_handshake_on_connect=False,
                                    ssl_version=ssl.PROTOCOL_TLSv1)
        self.sock.do_handshake()
    except NameError:
        raise errors.NotSupportedError(
            "Python installation has no SSL support")
    except ssl.SSLError as err:
        raise errors.InterfaceError("SSL error: %s" % err)
def fetchall(self):
    """Fetch all rows of the pending result set, converted to Python types.

    Each row is converted with ``_row_to_python``. The row count is
    updated and the trailing EOF information is passed to
    ``_handle_eof``.

    Raises:
        errors.InterfaceError: when there is no unread result set.
    """
    if not self._have_unread_result():
        raise errors.InterfaceError("No result set to fetch from.")
    (rows, eof) = self._connection.get_rows()
    self._rowcount = len(rows)
    # Iterate the fetched rows directly rather than indexing them up to
    # self.rowcount.
    res = [self._row_to_python(row) for row in rows]
    self._handle_eof(eof)
    return res
def parse_eof(self, packet):
    """Parse a MySQL EOF-packet

    Args:
        packet: the full packet including the 4-byte header. It must
            carry the 0xfe marker at offset 4 and be at most 9 bytes long.

    Returns:
        dict with 'warning_count' and 'status_flag' (2 bytes each).

    Raises:
        errors.InterfaceError: when the packet is not an EOF-packet.
    """
    # Slice rather than index so a packet without payload is rejected
    # with InterfaceError instead of raising IndexError.
    if packet[4:5] != '\xfe' or len(packet) > 9:
        raise errors.InterfaceError("Failed parsing EOF packet.")
    res = {}
    packet = packet[5:]  # disregard the first checking byte
    (packet, res['warning_count']) = utils.read_int(packet, 2)
    (packet, res['status_flag']) = utils.read_int(packet, 2)
    return res
def _row_to_python(self, rowdata, desc=None): res = () try: if not desc: desc = self.description for idx, v in enumerate(rowdata): flddsc = desc[idx] res += (self._connection.converter.to_python(flddsc, v), ) except StandardError, e: raise errors.InterfaceError( "Failed converting row to Python types; %s" % e)
def recv_compressed(self):
    """Receive compressed packets from the MySQL server

    Packets already split off an earlier payload are served from
    ``self._packet_queue`` first. Otherwise compressed packets are read
    from the socket. Each starts with a 7-byte header: bytes 0-2 hold the
    compressed payload length and bytes 4-6 the uncompressed payload
    length (both little-endian, 3 bytes).

    Raises:
        errors.InterfaceError: errno 2013 when the connection closes in
            the middle of a header or payload, or when the socket times
            out.

    NOTE(review): the visible code ends after the read loop. The part
    that processes the collected ``packets`` (and returns one) is not
    shown here; as written, the function returns None when the loop ends
    via ``break`` or when no header byte could be read.
    """
    import utils  # NOTE(review): not used in the visible part of the body
    try:
        return self._packet_queue.popleft()
    except IndexError:
        # Queue is empty: read more data from the server.
        pass
    header = ''
    packets = []
    try:
        # Read the header one byte at a time. The byte read after the 7th
        # header byte is already the first payload byte and stays in
        # ``abyte``.
        abyte = self.sock.recv(1)
        while abyte and len(header) < 7:
            header += abyte
            abyte = self.sock.recv(1)
        while header:
            if len(header) < 7:
                # recv() returned '' before a full header was read
                raise errors.InterfaceError(errno=2013)
            zip_payload_length = struct.unpack("<I", header[0:3] + '\x00')[0]
            payload_length = struct.unpack("<I", header[4:7] + '\x00')[0]
            # Start the payload with the extra byte read by the header loop
            zip_payload = abyte
            while len(zip_payload) < zip_payload_length:
                chunk = self.sock.recv(zip_payload_length - len(zip_payload))
                if len(chunk) == 0:
                    raise errors.InterfaceError(errno=2013)
                zip_payload = zip_payload + chunk
            if payload_length == 0:
                # Uncompressed length 0: the payload is split into packets
                # as-is, without any decompression step.
                self._split_zipped_payload(zip_payload)
                return self._packet_queue.popleft()
            packets.append(header + zip_payload)
            # NOTE(review): 16384 presumably marks a payload that continues
            # in the next compressed packet -- confirm against the protocol.
            if payload_length != 16384:
                break
            # Read the next header the same way as the first one
            header = ''
            abyte = self.sock.recv(1)
            while abyte and len(header) < 7:
                header += abyte
                abyte = self.sock.recv(1)
    except socket.timeout, err:
        raise errors.InterfaceError(errno=2013)
def parse_statistics(self, packet):
    """Parse the statistics packet

    The payload (everything after the 4-byte header) is a list of
    ``label: value`` pairs separated by two spaces. Each value is stored
    as a long when possible, otherwise as a Decimal.

    Returns:
        dict mapping each label to its numeric value.

    Raises:
        errors.InterfaceError: when a pair cannot be split into label and
            value, or a value is neither an integer nor a decimal.
    """
    from decimal import InvalidOperation

    errmsg = "Failed getting COM_STATISTICS information"
    res = {}
    # Information is separated by 2 spaces
    pairs = packet[4:].split('\x20\x20')
    for pair in pairs:
        try:
            (lbl, val) = [v.strip() for v in pair.split(':', 2)]
        except ValueError:
            # Not exactly one label and one value in this pair
            raise errors.InterfaceError(errmsg)
        # It's either an integer or a decimal
        try:
            res[lbl] = long(val)
        except ValueError:
            try:
                res[lbl] = Decimal(val)
            except InvalidOperation:
                raise errors.InterfaceError(
                    "%s (%s:%s)." % (errmsg, lbl, val))
    return res
def _scramble_password(self, passwd, seed):
    """Scramble a password ready to send to MySQL

    Computes SHA1(passwd) XOR SHA1(seed + SHA1(SHA1(passwd))), byte by
    byte.

    Args:
        passwd: the plain-text password.
        seed: the scramble seed to mix in.

    Returns:
        the 20-byte scrambled password (previously it was computed but
        never returned).

    Raises:
        errors.InterfaceError: when any step of the computation fails.
    """
    try:
        hash1 = sha1(passwd).digest()
        hash2 = sha1(hash1).digest()  # Password as found in mysql.user()
        hash3 = sha1(seed + hash2).digest()
        xored = [utils.intread(h1) ^ utils.intread(h3)
                 for (h1, h3) in zip(hash1, hash3)]
        hash4 = struct.pack('20B', *xored)
    except Exception as err:
        raise errors.InterfaceError(
            'Failed scrambling password; %s' % err)
    return hash4
def _handle_result(self, result): """ Handle the result after a command was send. The result can be either an OK-packet or a dictionary containing column/eof information. Raises InterfaceError when result is not a dict() or result is invalid. """ if not isinstance(result, dict): raise errors.InterfaceError('Result was not a dict()') if 'columns' in result: # Weak test, must be column/eof information self._description = result['columns'] self._connection.unread_result = True self._handle_resultset() elif 'affected_rows' in result: # Weak test, must be an OK-packet self._connection.unread_result = False self._handle_noresultset(result) else: raise errors.InterfaceError('Invalid result')
def executemany(self, operation, seq_params):
    """Execute the given operation multiple times

    The executemany() method will execute the operation iterating
    over the list of parameters in seq_params.

    Example: Inserting 3 new employees and their phone number

    data = [
        ('Jane','555-001'),
        ('Joe', '555-002'),
        ('John', '555-003')
        ]
    stmt = "INSERT INTO employees (name, phone) VALUES (%s,%s)"
    cursor.executemany(stmt, data)

    INSERT statements are optimized by batching the data, that is
    using the MySQL multiple rows syntax. If the VALUES clause of such
    an INSERT cannot be located, the statement is executed once per
    parameter set instead. An empty seq_params for a batched INSERT
    executes nothing and sets the row count to 0.

    Results are discarded. If they are needed, consider looping over
    data using the execute() method.

    When the statement is executed once per parameter set, the row
    counts of the individual executions are summed into ``_rowcount``.

    Raises:
        errors.InternalError: when an unread result is pending, or when
            operation contains more than one statement.
        errors.InterfaceError: when ValueError/TypeError is raised while
            executing a parameter set one at a time.
    """
    if not operation:
        return
    if self._have_unread_result():
        raise errors.InternalError("Unread result found.")
    elif len(RE_SQL_SPLIT_STMTS.split(operation)) > 1:
        raise errors.InternalError(
            "executemany() does not support multiple statements")

    # Optimize INSERTs by batching them
    if re.match(RE_SQL_INSERT_STMT, operation):
        opnocom = re.sub(RE_SQL_COMMENT, '', operation)
        m = re.search(RE_SQL_INSERT_VALUES, opnocom)
        if m is not None:
            fmt = m.group(1)
            values = [fmt % self._process_params(params)
                      for params in seq_params]
            if not values:
                # Nothing to insert: "VALUES " with no rows is invalid SQL.
                self._rowcount = 0
                return None
            operation = operation.replace(fmt, ','.join(values), 1)
            return self.execute(operation)
        # No VALUES clause found: fall back to one execute() per row below
        # instead of failing with AttributeError on m.group().

    rowcnt = 0
    try:
        for params in seq_params:
            self.execute(operation, params)
            if self.with_rows and self._have_unread_result():
                self.fetchall()
            rowcnt += self._rowcount
    except (ValueError, TypeError) as err:
        raise errors.InterfaceError(
            "Failed executing the operation; %s" % err)
    # Expose the total instead of the count of the last execution only.
    self._rowcount = rowcnt
def open_connection(self):
    """Create ``self.sock`` as a UNIX stream socket and connect it.

    The socket uses ``self._connection_timeout`` and connects to
    ``self._unix_socket``.

    Raises:
        errors.InterfaceError: errno 2002 on any socket error, carrying
            ``self.get_address()`` and the OS errno (or the error text
            when no errno is set).
    """
    try:
        self.sock = conn = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        conn.settimeout(self._connection_timeout)
        conn.connect(self._unix_socket)
    except socket.error as err:
        reason = getattr(err, 'errno', None)
        if reason is None:
            reason = str(err)
        raise errors.InterfaceError(errno=2002,
                                    values=(self.get_address(), reason))
def _fetch_warnings(self): """ Fetch warnings doing a SHOW WARNINGS. Can be called after getting the result. Returns a result set or None when there were no warnings. """ res = [] try: c = self._connection.cursor() cnt = c.execute("SHOW WARNINGS") res = c.fetchall() c.close() except StandardError, e: raise errors.InterfaceError, errors.InterfaceError( "Failed getting warnings; %s" % e), sys.exc_info()[2]
# NOTE(review): this line is a fragment. It begins inside recv_plain's
# payload-reading loop and ends inside recv_compressed, so neither enclosing
# definition is fully visible here. Contents, in order:
#  - tail of recv_plain: returns the packet; a socket timeout becomes
#    InterfaceError(errno=2013); any other socket.error becomes
#    InterfaceError(errno=2055) with self.get_address() and the errno (or the
#    error text when errno is missing/None);
#  - `recv = recv_plain`: recv is an alias of recv_plain at this point;
#  - _split_zipped_payload(packet_bunch): repeatedly reads the 3-byte
#    little-endian length at the front of the buffer and appends that many
#    payload bytes plus the 4-byte header to self._packet_queue, until the
#    buffer is empty (a short final piece is queued as-is);
#  - head of recv_compressed: returns a queued packet when one is available.
if not chunk: raise errors.InterfaceError(errno=2013) packet += chunk rest = packet_totlen - len(packet) return packet except socket.timeout, err: raise errors.InterfaceError(errno=2013) except socket.error, err: try: msg = err.errno if msg is None: msg = str(err) except AttributeError: msg = str(err) raise errors.InterfaceError(errno=2055, values=(self.get_address(), msg)) recv = recv_plain def _split_zipped_payload(self, packet_bunch): """Split compressed payload""" while packet_bunch: payload_length = struct.unpack("<I", packet_bunch[0:3] + '\x00')[0] self._packet_queue.append(packet_bunch[0:payload_length + 4]) packet_bunch = packet_bunch[payload_length + 4:] def recv_compressed(self): """Receive compressed packets from the MySQL server""" import utils try: return self._packet_queue.popleft()
def fetchall(self):
    """Return all buffered rows as they were received (no conversion).

    Like the type-converting buffered fetchall, this moves ``_next_row``
    past the last buffered row, so the buffer is marked as consumed.

    Returns:
        a new list holding the buffered rows.

    Raises:
        errors.InterfaceError: when no result set is buffered.
    """
    if self._rows is None:
        raise errors.InterfaceError("No result set to fetch from.")
    self._next_row = len(self._rows)
    return list(self._rows)
def callproc(self, procname, args=()):
    """Calls a stored procedure with the given arguments

    Each argument is stored in a session variable named
    @_<procname>_arg<nr>, where <nr> is the 1-based position of the
    argument. The procedure is then CALLed with those variables, and the
    variables are SELECTed again afterwards so OUT/INOUT values can be
    read back.

    Coding Example:
      1) Defining the Stored Routine in MySQL:
      CREATE PROCEDURE multiply(IN pFac1 INT, IN pFac2 INT, OUT pProd INT)
      BEGIN
        SET pProd := pFac1 * pFac2;
      END

      2) Executing in Python:
      args = (5,5,0) # 0 is to hold pprod
      result_args = cursor.callproc('multiply', args)
      print result_args

      The last print should output the three values 5, 5 and 25 (the
      exact Python number types depend on the converter).

    Result sets produced by the procedure itself are buffered and kept
    in ``self._stored_results``.

    Returns:
        the row returned by SELECTing the argument variables, or () when
        no arguments were given.

    Raises:
        errors.Error subclasses raised while executing are re-raised
        unchanged; any other StandardError becomes errors.InterfaceError.
    """
    argfmt = "@_%s_arg%d"
    self._stored_results = []
    results = []
    try:
        argnames = []
        # Hand the raw arguments to execute(), which processes its
        # parameters itself. Running them through _process_params() first
        # quoted them twice: 5 became the string '5', None became the
        # string 'NULL' and strings kept their surrounding quotes.
        for idx, arg in enumerate(args):
            argname = argfmt % (procname, idx + 1)
            argnames.append(argname)
            setquery = "SET %s=%%s" % argname
            self.execute(setquery, (arg, ))

        call = "CALL %s(%s)" % (procname, ','.join(argnames))
        for result in self._connection.cmd_query_iter(call):
            if 'columns' in result:
                # Keep every result set of the procedure in its own
                # buffered cursor.
                tmp = MySQLCursorBuffered(self._connection._get_self())
                tmp._handle_result(result)
                results.append(tmp)

        if argnames:
            select = "SELECT %s" % ','.join(argnames)
            self.execute(select)
            self._stored_results = results
            return self.fetchone()
        else:
            self._stored_results = results
            return ()
    except errors.Error:
        raise
    except StandardError as e:
        raise errors.InterfaceError("Failed calling stored routine; %s" % e)
# NOTE(review): this line is a fragment. It starts inside execute() (after the
# parameter-formatting step) and ends inside the docstring of executemany(),
# so neither definition is fully visible here. Contents, in order:
#  - parameter substitution failures are reported as ProgrammingError
#    ("Wrong number of arguments during string formatting"); without
#    parameters the operation is used unchanged;
#  - multi=True: records the statement, resets self._executed_list and returns
#    the iterator from self._execute_iter() over cmd_query_iter(stmt);
#  - otherwise: runs cmd_query(stmt) through self._handle_result(); an
#    InterfaceError raised while the connection reports a pending next result
#    is replaced by a hint to use multi=True, other InterfaceErrors are
#    re-raised; returns None;
#  - head of executemany() and the start of its docstring example.
raise errors.ProgrammingError( "Wrong number of arguments during string formatting") else: stmt = operation if multi: self._executed = stmt self._executed_list = [] return self._execute_iter(self._connection.cmd_query_iter(stmt)) else: self._executed = stmt try: self._handle_result(self._connection.cmd_query(stmt)) except errors.InterfaceError, err: if self._connection._have_next_result: raise errors.InterfaceError( "Use multi=True when executing multiple statements") raise return None def executemany(self, operation, seq_params): """Execute the given operation multiple times The executemany() method will execute the operation iterating over the list of parameters in seq_params. Example: Inserting 3 new employees and their phone number data = [ ('Jane','555-001'), ('Joe', '555-001'), ('John', '555-003')
def _set_connection(self, connection): try: self._connection = weakref.proxy(connection) self._connection._protocol except (AttributeError, TypeError): raise errors.InterfaceError(errno=2048)