Пример #1
0
 async def _read_bytes(self, num_bytes):
     try:
         data = await self._reader.readexactly(num_bytes)
     except asyncio.streams.IncompleteReadError as e:
         msg = "Lost connection to MySQL server during query"
         raise OperationalError(2013, msg) from e
     except (IOError, OSError) as e:
         msg = "Lost connection to MySQL server during query (%s)" % (e, )
         raise OperationalError(2013, msg) from e
     return data
Пример #2
0
 def _write_bytes(self, data):
     try:
         return self._writer.write(data)
     except RuntimeError as e:
         self.close()
         msg = "Lost connection to MySQL server during query (%s)" % (e, )
         raise OperationalError(2013, msg) from e
    async def _process_auth(self, plugin_name, auth_packet):
        if plugin_name == b"mysql_native_password":
            # https://dev.mysql.com/doc/internals/en/
            # secure-password-authentication.html#packet-Authentication::
            # Native41
            data = _scramble(self._password.encode('latin1'),
                             auth_packet.read_all())
        elif plugin_name == b"mysql_old_password":
            # https://dev.mysql.com/doc/internals/en/
            # old-password-authentication.html
            data = _scramble_323(self._password.encode('latin1'),
                                 auth_packet.read_all()) + b'\0'
        elif plugin_name == b"mysql_clear_password":
            # https://dev.mysql.com/doc/internals/en/
            # clear-text-authentication.html
            data = self._password.encode('latin1') + b'\0'
        else:
            raise OperationalError(
                2059, "Authentication plugin '%s' not configured" % plugin_name
            )

        self.write_packet(data)
        pkt = await self._read_packet()
        pkt.check_error()

        self._auth_plugin_used = plugin_name

        return pkt
Пример #4
0
    def _read_packet(self, packet_type=MysqlPacket):
        """Read an entire "mysql packet" in its entirety from the network
        and return a MysqlPacket type that represents the results.

        A wire packet whose payload is not shorter than ``MAX_PACKET_LEN``
        is continued by the next one, so payloads are collected until a
        shorter one arrives and are then joined into a single buffer.

        :param packet_type: wrapper class, called as
            ``packet_type(payload, self._encoding)``.
        :returns: the wrapped packet, after ``check_error()`` has run.
        :raises InternalError: if a header carries an unexpected sequence id.
        :raises OperationalError: 2006 on ``OSError``/``EOFError`` from the
            stream.
        """
        chunks = []
        try:
            while True:
                packet_header = yield from self._reader.readexactly(4)
                # Header: 3-byte little-endian payload length (read as a
                # 16-bit low part plus an 8-bit high part) + 1-byte seq id.
                btrl, btrh, packet_number = struct.unpack(
                    '<HBB', packet_header)
                bytes_to_read = btrl + (btrh << 16)

                # Outbound and inbound packets are numbered sequentially, so
                # we increment in both write_packet and read_packet. The count
                # is reset at new COMMAND PHASE.
                if packet_number != self._next_seq_id:
                    raise InternalError(
                        "Packet sequence number wrong - got %d expected %d" %
                        (packet_number, self._next_seq_id))
                self._next_seq_id = (self._next_seq_id + 1) % 256

                recv_data = yield from self._reader.readexactly(bytes_to_read)
                chunks.append(recv_data)
                if bytes_to_read < MAX_PACKET_LEN:
                    break
        except (OSError, EOFError) as exc:
            msg = "MySQL server has gone away (%s)"
            raise OperationalError(2006, msg % (exc, )) from exc
        # Join once at the end: ``buff += recv_data`` re-copied the whole
        # accumulated payload for every continuation packet.
        buff = b''.join(chunks)
        packet = packet_type(buff, self._encoding)
        packet.check_error()
        return packet
Пример #5
0
 def opener(filename):
     """Open ``filename`` for binary reading and store it on ``self``.

     ``self`` is taken from the enclosing scope.

     :raises OperationalError: 1017 if the file cannot be opened.
     """
     try:
         self._file_object = open(filename, 'rb')
     except IOError as e:
         # Chain the OS error so the real cause (missing file, permissions,
         # ...) is reported as the direct cause in the traceback.
         raise OperationalError(
             1017, "Can't find file"
             " '{0}'".format(filename)) from e
Пример #6
0
 async def _read_ok_packet(self):
     """Read the next packet and require it to be an OK packet.

     On success the packet's server status is copied to
     ``self.server_status``.

     :returns: ``True``.
     :raises OperationalError: 2014 if any other kind of packet arrives.
     """
     packet = await self._read_packet()
     if packet.is_ok_packet():
         self.server_status = OKPacketWrapper(packet).server_status
         return True
     raise OperationalError(2014, "Command Out of Sync")
Пример #7
0
    def post_auth_routine(self):
        """
        Anything that was initialized in a PyMySQL connection
        after a successful authentication.

        Applies ``sql_mode``, runs ``init_command`` (then commits) and sets
        autocommit, each only when configured.

        :raises OperationalError: 2003 if any step fails; ``self._rfile`` is
            reset and the connection's socket is closed first.
        """
        try:
            if self.sql_mode is not None:
                c = self.cursor()
                c.execute("SET sql_mode=%s", (self.sql_mode, ))

            if self.init_command is not None:
                c = self.cursor()
                c.execute(self.init_command)
                self.commit()

            if self.autocommit_mode is not None:
                self.autocommit(self.autocommit_mode)
        except Exception as e:
            self._rfile = None
            # There is no local socket variable in this method; use the one
            # stored on the connection. Referencing an undefined ``sock``
            # here turned every failure into a NameError and left the
            # socket open.
            sock = getattr(self, 'socket', None)
            self.socket = None
            if sock is not None:
                try:
                    sock.close()
                except socket.error:
                    # Best effort: the original failure is what matters.
                    pass
            msg = "Can't connect to MySQL server on %r (%s)" % (self.host, e)
            raise OperationalError(2003, msg) from e
Пример #8
0
    def _read_packet(self, packet_type=MysqlPacket):
        """Read an entire "mysql packet" in its entirety from the network
        and return a MysqlPacket type that represents the results.

        A wire packet whose payload is not shorter than ``MAX_PACKET_LEN``
        is continued by the next one, so payloads are collected until a
        shorter one arrives and are then joined into a single buffer.

        :param packet_type: wrapper class, called as
            ``packet_type(payload, self._encoding)``.
        :returns: the wrapped packet, after ``check_error()`` has run.
        :raises OperationalError: 2006 on ``OSError``/``EOFError`` from the
            stream.
        """
        chunks = []
        try:
            while True:
                packet_header = yield from self._reader.readexactly(4)
                # Header: 3-byte little-endian payload length followed by a
                # 1-byte sequence id (packet_header[3]).
                # TODO: check sequence id
                # Pad the 3-byte length to 4 bytes so it unpacks as '<I'.
                bin_length = packet_header[:3] + b'\0'
                bytes_to_read = struct.unpack('<I', bin_length)[0]
                recv_data = yield from self._reader.readexactly(bytes_to_read)
                chunks.append(recv_data)
                if bytes_to_read < MAX_PACKET_LEN:
                    break
        except (OSError, EOFError) as exc:
            msg = "MySQL server has gone away (%s)"
            raise OperationalError(2006, msg % (exc, )) from exc
        # Join once at the end: ``buff += recv_data`` re-copied the whole
        # accumulated payload for every continuation packet.
        buff = b''.join(chunks)
        packet = packet_type(buff, self._encoding)
        packet.check_error()
        return packet
Пример #9
0
def test_error_on_file_read(cursor, table_local_file):
    """A failing ``read()`` on the local file during LOAD DATA LOCAL INFILE
    must surface as ``OperationalError``.

    ``open`` is patched so the "file" is a mock whose ``read`` raises.
    """
    fake_file = MagicMock()
    fake_file.read.side_effect = OperationalError(1024, 'Error reading file')
    fake_file.close.return_value = None

    load_sql = ("LOAD DATA LOCAL INFILE 'some.txt'"
                " INTO TABLE test_load_local fields "
                "terminated by ','")

    with patch.object(builtins, 'open', return_value=fake_file):
        with pytest.raises(OperationalError):
            yield from cursor.execute(load_sql)
Пример #10
0
    async def _connect(self):
        """Open the transport, handshake, authenticate and apply session
        settings.

        A UNIX socket is used when ``_unix_socket`` is set and the host is
        ``localhost``/``127.0.0.1``. Otherwise a TCP connection is opened and
        ``_set_keep_alive()`` is called. Either connection attempt is bounded
        by ``connect_timeout``. After authentication, ``sql_mode``,
        ``init_command`` (followed by a commit) and ``autocommit_mode`` are
        applied when set.

        :raises OperationalError: 2003 wrapping any failure; the transport
            is closed and ``_reader``/``_writer`` are reset first.
        """
        # TODO: Set close callback
        # raise OperationalError(2006,
        # "MySQL server has gone away (%r)" % (e,))
        try:
            # The stream helpers are called without ``loop=``: inside a
            # coroutine they use the running loop, and the argument was
            # deprecated in Python 3.8 and removed in 3.10 (passing it raises
            # TypeError there).
            if self._unix_socket and self._host in ('localhost', '127.0.0.1'):
                self._reader, self._writer = await asyncio.wait_for(
                    asyncio.open_unix_connection(self._unix_socket),
                    timeout=self.connect_timeout)
                self.host_info = "Localhost via UNIX socket: " + \
                                 self._unix_socket
            else:
                self._reader, self._writer = await asyncio.wait_for(
                    asyncio.open_connection(self._host, self._port),
                    timeout=self.connect_timeout)
                self._set_keep_alive()
                self.host_info = "socket %s:%d" % (self._host, self._port)

            # do not set no delay in case of unix_socket
            if self._no_delay and not self._unix_socket:
                self._set_nodelay(True)

            self._next_seq_id = 0

            await self._get_server_information()
            await self._request_authentication()

            self.connected_time = self._loop.time()

            if self.sql_mode is not None:
                await self.query("SET sql_mode=%s" % (self.sql_mode,))

            if self.init_command is not None:
                await self.query(self.init_command)
                await self.commit()

            if self.autocommit_mode is not None:
                await self.autocommit(self.autocommit_mode)
        except Exception as e:
            if self._writer:
                self._writer.transport.close()
            self._reader = None
            self._writer = None
            raise OperationalError(2003,
                                   "Can't connect to MySQL server on %r" %
                                   self._host) from e
Пример #11
0
    def _read_load_local_packet(self, first_packet):
        """Serve a ``LOAD DATA LOCAL INFILE`` request from the server
        (generator-based coroutine).

        Streams the requested local file, then expects an OK packet and
        prints any warnings. ``self.filename`` holds the requested name while
        the exchange runs and is cleared after it completes successfully.

        :raises OperationalError: 2014 if the reply is not an OK packet.
        """
        request = LoadLocalPacketWrapper(first_packet)
        uploader = LoadLocalFile(request.filename, self.connection)
        self.filename = request.filename
        yield from uploader.send_data()

        reply = yield from self.connection._read_packet()
        if not reply.is_ok_packet():
            raise OperationalError(2014, "Commands Out of Sync")
        self._read_ok_packet(reply)

        if self.warning_count > 0:
            yield from self._print_warnings()
        self.filename = None
Пример #12
0
        def freader(chunk_size):
            """Read up to ``chunk_size`` bytes from ``self._file_object``.

            ``self`` is taken from the enclosing scope. On an empty read
            (end of file) the file is closed and ``self._file_object`` is
            reset to ``None``; the empty chunk is still returned.

            :raises OperationalError: 1024 if reading (or closing at EOF)
                fails; the file is closed first when it is still open.
            """
            try:
                chunk = self._file_object.read(chunk_size)

                if not chunk:
                    self._file_object.close()
                    self._file_object = None

            except Exception as e:
                # The file object may already be gone (e.g. a call after EOF);
                # only close it when present, otherwise this cleanup would
                # raise AttributeError and hide the original failure.
                file_object = self._file_object
                self._file_object = None
                if file_object is not None:
                    file_object.close()
                msg = "Error reading file {}".format(self.filename)
                raise OperationalError(1024, msg) from e
            return chunk
Пример #13
0
    async def _read_load_local_packet(self, first_packet):
        """Serve a ``LOAD DATA LOCAL INFILE`` request from the server.

        Sends the requested local file, then expects an OK packet. If sending
        fails, the server's next packet is read and discarded before the
        error is re-raised.

        :raises OperationalError: 2014 if the reply is not an OK packet.
        """
        request = LoadLocalPacketWrapper(first_packet)
        uploader = LoadLocalFile(request.filename, self.connection)
        try:
            await uploader.send_data()
        except Exception:
            # Consume the server's reply before propagating the error.
            await self.connection._read_packet()
            raise

        reply = await self.connection._read_packet()
        if reply.is_ok_packet():
            self._read_ok_packet(reply)
            return
        raise OperationalError(2014, "Commands Out of Sync")
Пример #14
0
    async def _process_auth(self, plugin_name, auth_packet):
        # These auth plugins do their own packet handling
        if plugin_name == b"caching_sha2_password":
            await self.caching_sha2_password_auth(auth_packet)
            self._auth_plugin_used = plugin_name.decode()
        elif plugin_name == b"sha256_password":
            await self.sha256_password_auth(auth_packet)
            self._auth_plugin_used = plugin_name.decode()
        else:

            if plugin_name == b"mysql_native_password":
                # https://dev.mysql.com/doc/internals/en/
                # secure-password-authentication.html#packet-Authentication::
                # Native41
                data = _auth.scramble_native_password(
                    self._password.encode('latin1'),
                    auth_packet.read_all())
            elif plugin_name == b"mysql_old_password":
                # https://dev.mysql.com/doc/internals/en/
                # old-password-authentication.html
                data = _auth.scramble_old_password(
                    self._password.encode('latin1'),
                    auth_packet.read_all()
                ) + b'\0'
            elif plugin_name == b"mysql_clear_password":
                # https://dev.mysql.com/doc/internals/en/
                # clear-text-authentication.html
                data = self._password.encode('latin1') + b'\0'
            else:
                raise OperationalError(
                    2059, "Authentication plugin '{0}'"
                          " not configured".format(plugin_name)
                )

            self.write_packet(data)
            pkt = await self._read_packet()
            pkt.check_error()

            self._auth_plugin_used = plugin_name.decode()

            return pkt
Пример #15
0
    async def sha256_password_auth(self, pkt):
        """Run the ``sha256_password`` exchange.

        With a TLS context the password is sent in clear text. Otherwise the
        password is RSA-encrypted with the server's public key. When no key
        is known yet, the key is requested on an auth-switch request. An
        empty password is sent as an empty packet.

        :param pkt: the packet that started this exchange.
        :returns: the server's final reply packet.
        :raises OperationalError: if a password is set but no public key was
            received.
        """
        if self._ssl_context:
            logger.debug("sha256: Sending plain password")
            data = self._password.encode('latin1') + b'\0'
            self.write_packet(data)
            pkt = await self._read_packet()
            pkt.check_error()
            return pkt

        if pkt.is_auth_switch_request():
            self.salt = pkt.read_all()
            if not self.server_public_key and self._password:
                # Request server public key
                logger.debug("sha256: Requesting server public key")
                self.write_packet(b'\1')
                pkt = await self._read_packet()
                pkt.check_error()

        if pkt.is_extra_auth_data():
            self.server_public_key = pkt._data[1:]
            # The message needs a placeholder for the key argument; without
            # it logging reports a formatting error instead of the key.
            logger.debug(
                "Received public key:\n%s",
                self.server_public_key.decode('ascii')
            )

        if self._password:
            if not self.server_public_key:
                raise OperationalError("Couldn't receive server's public key")

            data = _auth.sha2_rsa_encrypt(
                self._password.encode('latin1'), self.salt,
                self.server_public_key
            )
        else:
            data = b''

        self.write_packet(data)
        pkt = await self._read_packet()
        pkt.check_error()
        return pkt
Пример #16
0
    def _connect(self, **kwargs):
        """
        Filthy shim to stop a full handshake from the actual
        pymysql library so that we can intercept the connection
        and grab the salt.

        Opens a UNIX socket (when ``unix_socket`` is set and the host is
        local) or a TCP connection, enables SO_KEEPALIVE and, if requested,
        TCP_NODELAY, then reads the server greeting. Authentication is
        deliberately not performed here. ``kwargs`` is accepted but unused.

        :raises OperationalError: 2003 on any failure; ``self._rfile`` is
            reset and the socket closed first.
        """
        conn_sock = None
        try:
            use_unix = self.unix_socket and self.host in ('localhost',
                                                          '127.0.0.1')
            if use_unix:
                conn_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
                previous_timeout = conn_sock.gettimeout()
                conn_sock.settimeout(self.connect_timeout)
                conn_sock.connect(self.unix_socket)
                conn_sock.settimeout(previous_timeout)
                self.host_info = "Localhost via UNIX socket"
                if DEBUG:
                    print('connected using unix_socket')
            else:
                conn_sock = socket.create_connection(
                    (self.host, self.port), self.connect_timeout)
                self.host_info = "socket %s:%d" % (self.host, self.port)
                if DEBUG:
                    print('connected using socket')

            conn_sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            if self.no_delay:
                conn_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            self.socket = conn_sock
            self._rfile = _makefile(conn_sock, 'rb')
            self._get_server_information()
            # _request_authentication() is intentionally skipped: the caller
            # performs that step itself.
        except Exception as e:
            self._rfile = None
            if conn_sock is not None:
                try:
                    conn_sock.close()
                except socket.error:
                    pass
            raise OperationalError(
                2003,
                "Can't connect to MySQL server on %r (%s)" % (self.host, e))
Пример #17
0
    def connect(self):
        """Open the transport, handshake, authenticate and apply session
        settings (generator-based coroutine).

        A UNIX socket is used when ``_unix_socket`` is set and the host is
        ``localhost``/``127.0.0.1``; otherwise TCP. After authentication,
        ``sql_mode``, ``init_command`` (followed by a commit) and
        ``autocommit_mode`` are applied when set.

        :raises OperationalError: 2003 if an ``OSError`` occurs; the
            transport is closed and ``_reader``/``_writer`` are reset first.
        """
        # TODO: Set close callback
        # raise OperationalError(2006,
        # "MySQL server has gone away (%r)" % (e,))
        try:
            # ``loop=`` is not passed: the stream helpers use the running
            # loop, and the argument was removed in Python 3.10.
            if self._unix_socket and self._host in ('localhost', '127.0.0.1'):
                self._reader, self._writer = yield from \
                    asyncio.open_unix_connection(self._unix_socket)
                self.host_info = "Localhost via UNIX socket: " + \
                                 self._unix_socket
            else:
                self._reader, self._writer = yield from \
                    asyncio.open_connection(self._host, self._port)
                self.host_info = "socket %s:%d" % (self._host, self._port)

            if self._no_delay:
                self._set_nodelay(True)

            yield from self._get_server_information()
            yield from self._request_authentication()

            self.connected_time = self._loop.time()

            if self.sql_mode is not None:
                yield from self.query("SET sql_mode=%s" % (self.sql_mode, ))

            if self.init_command is not None:
                yield from self.query(self.init_command)
                yield from self.commit()

            if self.autocommit_mode is not None:
                yield from self.autocommit(self.autocommit_mode)
        except OSError as e:
            # Close the transport before dropping the references; otherwise
            # the underlying socket stays open after a failed connect.
            writer = getattr(self, '_writer', None)
            if writer:
                writer.transport.close()
            self._reader, self._writer = None, None
            raise OperationalError(
                2003,
                "Can't connect to MySQL server on %r (%s)" % (self._host, e)
            ) from e
Пример #18
0
 def _read_packet(self, packet_type=MysqlPacket):
     """Read an entire "mysql packet" in its entirety from the network
     and return a MysqlPacket type that represents the results.

     A wire packet whose payload is not shorter than ``MAX_PACKET_LEN`` is
     continued by the next one, so payloads are collected until a shorter
     one arrives and are then joined into a single buffer.

     :param packet_type: wrapper class, called as
         ``packet_type(payload, self._encoding)``.
     :returns: the wrapped packet, after ``check_error()`` has run.
     :raises OperationalError: 2006 on ``OSError``/``EOFError`` from the
         stream.
     """
     chunks = []
     try:
         while True:
             packet_header = yield from self._reader.readexactly(4)
             # Header: 3-byte little-endian length + 1-byte sequence id.
             btrl, btrh, _packet_number = struct.unpack(
                 '<HBB', packet_header)
             bytes_to_read = btrl + (btrh << 16)
             # TODO: check sequence id
             recv_data = yield from self._reader.readexactly(bytes_to_read)
             chunks.append(recv_data)
             if bytes_to_read < MAX_PACKET_LEN:
                 break
     except (OSError, EOFError) as exc:
         msg = "MySQL server has gone away (%s)"
         raise OperationalError(2006, msg % (exc, )) from exc
     # Join once at the end: ``buff += recv_data`` re-copied the whole
     # accumulated payload for every continuation packet.
     buff = b''.join(chunks)
     packet = packet_type(buff, self._encoding)
     packet.check_error()
     return packet
Пример #19
0
    async def caching_sha2_password_auth(self, pkt):
        """Run the ``caching_sha2_password`` exchange.

        With no password an empty response is sent. Otherwise the fast path
        (scrambled password) is tried. If the server then asks for full
        authentication, the password is sent in clear text when a TLS
        context is configured, or RSA-encrypted with the server's public key
        (requested first when not already known).

        :param pkt: the packet that started this exchange.
        :returns: the server's final reply packet, on every successful path.
        :raises OperationalError: on unexpected packets or result codes.
        """
        # No password fast path
        if not self._password:
            self.write_packet(b'')
            pkt = await self._read_packet()
            pkt.check_error()
            return pkt

        if pkt.is_auth_switch_request():
            # Try from fast auth
            logger.debug("caching sha2: Trying fast path")
            self.salt = pkt.read_all()
            scrambled = _auth.scramble_caching_sha2(
                self._password.encode('latin1'), self.salt)

            self.write_packet(scrambled)
            pkt = await self._read_packet()
            pkt.check_error()

        # else: fast auth is tried in initial handshake

        if not pkt.is_extra_auth_data():
            raise OperationalError("caching sha2: Unknown packet "
                                   "for fast auth: {0}".format(pkt._data[:1]))

        # magic numbers:
        # 2 - request public key
        # 3 - fast auth succeeded
        # 4 - need full auth

        pkt.advance(1)
        n = pkt.read_uint8()

        if n == 3:
            logger.debug("caching sha2: succeeded by fast path.")
            pkt = await self._read_packet()
            pkt.check_error()  # pkt must be OK packet
            return pkt

        if n != 4:
            raise OperationalError("caching sha2: Unknown "
                                   "result for fast auth: {0}".format(n))

        logger.debug("caching sha2: Trying full auth...")

        if self._ssl_context:
            logger.debug("caching sha2: Sending plain "
                         "password via secure connection")
            self.write_packet(self._password.encode('latin1') + b'\0')
            pkt = await self._read_packet()
            pkt.check_error()
            return pkt

        if not self.server_public_key:
            self.write_packet(b'\x02')
            pkt = await self._read_packet()  # Request public key
            pkt.check_error()

            if not pkt.is_extra_auth_data():
                raise OperationalError("caching sha2: Unknown packet "
                                       "for public key: {0}".format(
                                           pkt._data[:1]))

            self.server_public_key = pkt._data[1:]
            logger.debug(self.server_public_key.decode('ascii'))

        data = _auth.sha2_rsa_encrypt(self._password.encode('latin1'),
                                      self.salt, self.server_public_key)
        self.write_packet(data)
        pkt = await self._read_packet()
        pkt.check_error()
        # Return the reply like every other success path (it previously
        # fell off the end and returned None).
        return pkt
Пример #20
0
    async def _request_authentication(self):
        """Send the HandshakeResponse and drive authentication to completion.

        Optionally upgrades the connection to TLS first. It then builds the
        initial auth response for the selected plugin (the client-configured
        one, falling back to the server's) and handles an auth-switch request
        or extra auth data sent back by the server.

        :raises ValueError: if no user name is configured.
        :raises RuntimeError: if TLS is requested but the transport does not
            expose its socket.
        :raises OperationalError: on extra auth data for a plugin that does
            not expect it, or from the plugin handlers called here.
        """
        # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
        if self.server_version.split('.', 1)[0] == 'Inception2':
            self.client_flag |= CLIENT.MULTI_RESULTS
        elif int(self.server_version.split('.', 1)[0]) >= 5:
            self.client_flag |= CLIENT.MULTI_RESULTS

        if self.user is None:
            raise ValueError("Did not specify a username")

        if self._ssl_context:
            # capablities, max packet, charset
            data = struct.pack('<IIB', self.client_flag, 16777216, 33)
            data += b'\x00' * (32 - len(data))

            self.write_packet(data)

            # Stop sending events to data_received
            self._writer.transport.pause_reading()

            # Get the raw socket from the transport
            raw_sock = self._writer.transport.get_extra_info('socket',
                                                             default=None)
            if raw_sock is None:
                raise RuntimeError("Transport does not expose socket instance")

            raw_sock = raw_sock.dup()
            self._writer.transport.close()
            # MySQL expects TLS negotiation to happen in the middle of a
            # TCP connection not at start. Passing in a socket to
            # open_connection will cause it to negotiate TLS on an existing
            # connection not initiate a new one.
            # ``loop=`` is not passed: open_connection uses the running loop,
            # and the argument was removed in Python 3.10.
            self._reader, self._writer = await asyncio.open_connection(
                sock=raw_sock,
                ssl=self._ssl_context,
                server_hostname=self._host)

        charset_id = charset_by_name(self.charset).id
        if isinstance(self.user, str):
            _user = self.user.encode(self.encoding)
        else:
            _user = self.user

        data_init = struct.pack('<iIB23s', self.client_flag, MAX_PACKET_LEN,
                                charset_id, b'')

        data = data_init + _user + b'\0'

        authresp = b''

        auth_plugin = self._client_auth_plugin
        if not self._client_auth_plugin:
            # Contains the auth plugin from handshake
            auth_plugin = self._server_auth_plugin

        if auth_plugin in ('', 'mysql_native_password'):
            authresp = _auth.scramble_native_password(
                self._password.encode('latin1'), self.salt)
        elif auth_plugin == 'caching_sha2_password':
            if self._password:
                authresp = _auth.scramble_caching_sha2(
                    self._password.encode('latin1'), self.salt)
            # Else: empty password
        elif auth_plugin == 'sha256_password':
            if self._ssl_context and self.server_capabilities & CLIENT.SSL:
                authresp = self._password.encode('latin1') + b'\0'
            elif self._password:
                authresp = b'\1'  # request public key
            else:
                authresp = b'\0'  # empty password

        elif auth_plugin == 'mysql_clear_password':
            # ('' is already handled by the native-password branch above.)
            authresp = self._password.encode('latin1') + b'\0'

        if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA:
            data += lenenc_int(len(authresp)) + authresp
        elif self.server_capabilities & CLIENT.SECURE_CONNECTION:
            data += struct.pack('B', len(authresp)) + authresp
        else:  # pragma: no cover
            # not testing against servers without secure auth (>=5.0)
            data += authresp + b'\0'

        if self._db and self.server_capabilities & CLIENT.CONNECT_WITH_DB:

            if isinstance(self._db, str):
                db = self._db.encode(self.encoding)
            else:
                db = self._db
            data += db + b'\0'

        if self.server_capabilities & CLIENT.PLUGIN_AUTH:
            name = auth_plugin
            if isinstance(name, str):
                name = name.encode('ascii')
            data += name + b'\0'

        self._auth_plugin_used = auth_plugin

        # Sends the server a few pieces of client info
        if self.server_capabilities & CLIENT.CONNECT_ATTRS:
            connect_attrs = b''
            for k, v in self._connect_attrs.items():
                k, v = k.encode('utf8'), v.encode('utf8')
                # The protocol encodes these lengths (and the total below) as
                # length-encoded integers. A single 'B' byte corrupts the
                # packet once a length reaches 251 (0xfb+ are lenenc markers)
                # and raises struct.error above 255.
                connect_attrs += lenenc_int(len(k)) + k
                connect_attrs += lenenc_int(len(v)) + v
            data += lenenc_int(len(connect_attrs)) + connect_attrs

        self.write_packet(data)
        auth_packet = await self._read_packet()

        # if authentication method isn't accepted the first byte
        # will have the octet 254
        if auth_packet.is_auth_switch_request():
            # https://dev.mysql.com/doc/internals/en/
            # connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
            auth_packet.read_uint8()  # 0xfe packet identifier
            plugin_name = auth_packet.read_string()
            if (self.server_capabilities & CLIENT.PLUGIN_AUTH
                    and plugin_name is not None):
                await self._process_auth(plugin_name, auth_packet)
            else:
                # send legacy handshake
                data = _auth.scramble_old_password(
                    self._password.encode('latin1'),
                    auth_packet.read_all()) + b'\0'
                self.write_packet(data)
                await self._read_packet()
        elif auth_packet.is_extra_auth_data():
            if auth_plugin == "caching_sha2_password":
                await self.caching_sha2_password_auth(auth_packet)
            elif auth_plugin == "sha256_password":
                await self.sha256_password_auth(auth_packet)
            else:
                # Format the message; passing auth_plugin as a second
                # exception argument left the "%r" unformatted.
                raise OperationalError(
                    "Received extra packet "
                    "for auth method %r" % (auth_plugin,))
Пример #21
0
 def test_sql_or_recoverable_error_sql(self, mock_rollback):
     """Error 2013 counts as recoverable, and the mocked rollback is called."""
     lost_connection = OperationalError(2013, 'meow')
     recoverable = retry.sql_or_recoverable_error(lost_connection)
     assert recoverable
     mock_rollback.assert_called()
Пример #22
0
 def opener(filename):
     """Open ``filename`` in binary mode and keep it on ``self._file_object``.

     ``self`` is taken from the enclosing scope.

     :raises OperationalError: 1017 when the file cannot be opened, chained
         to the underlying error.
     """
     try:
         handle = open(filename, 'rb')
     except IOError as err:
         raise OperationalError(
             1017, "Can't find file '{0}'".format(filename)) from err
     self._file_object = handle