Exemple #1
0
    async def create_connection(self, with_db: bool) -> None:
        """Create the aiomysql pool and apply the per-session settings.

        The pool parameters are stored on ``self._template`` (the password is
        passed separately and is not part of the template).  When the created
        pool is an ``aiomysql.Pool``, one connection is used to set the default
        storage engine (if ``self.storage_engine`` is set) and the session
        time zone.

        :param with_db: When true the pool connects to ``self.database``;
            otherwise ``db`` is passed as ``None``.
        :raises DBConnectionError: If ``self.charset`` is unknown, or if the
            driver raises ``pymysql.err.OperationalError``.
        """
        if charset_by_name(self.charset) is None:  # type: ignore
            raise DBConnectionError(f"Unknown charset {self.charset}")
        self._template = {
            "host": self.host,
            "port": self.port,
            "user": self.user,
            "db": self.database if with_db else None,
            "autocommit": True,
            "charset": self.charset,
            "minsize": self.pool_minsize,
            "maxsize": self.pool_maxsize,
            **self.extra,
        }
        try:
            self._pool = await aiomysql.create_pool(password=self.password, **self._template)

            if isinstance(self._pool, aiomysql.Pool):
                async with self.acquire_connection() as connection:
                    async with connection.cursor() as cursor:
                        if self.storage_engine:
                            await cursor.execute(
                                f"SET default_storage_engine='{self.storage_engine}';"
                            )
                            # Any engine other than InnoDB is flagged as non-transactional.
                            if self.storage_engine.lower() != "innodb":  # pragma: nobranch
                                self.capabilities.__dict__["supports_transactions"] = False
                        # timedelta.seconds is never negative (a -5h offset is stored
                        # as days=-1, seconds=68400), so use total_seconds() and
                        # carry the sign explicitly.
                        offset = int(timezone.now().utcoffset().total_seconds())  # type: ignore
                        sign = "-" if offset < 0 else "+"
                        hours, remainder = divmod(abs(offset), 3600)
                        tz = f"{sign}{hours:d}:{remainder // 60:02d}"
                        await cursor.execute(f"SET SESSION time_zone='{tz}';")
            self.log.debug("Created connection %s pool with params: %s", self._pool, self._template)
        except pymysql.err.OperationalError as exc:
            # Chain the driver error so the root cause stays in the traceback.
            raise DBConnectionError(
                f"Can't connect to MySQL server: {self._template}"
            ) from exc
Exemple #2
0
 def set_charset(self, charset):
     """Switch the connection to ``charset`` by issuing ``SET NAMES``.

     Generator-based coroutine.  The charset is looked up through
     ``charset_by_name`` before any command is sent, and the cached
     charset/encoding are updated only after the reply packet was read.
     """
     # Look the charset up first so a name that cannot be resolved stops
     # us before anything is written to the server.
     resolved_encoding = charset_by_name(charset).encoding
     statement = "SET NAMES %s" % self.escape(charset)
     yield from self._execute_command(COM_QUERY, statement)
     yield from self._read_packet()
     self._charset, self._encoding = charset, resolved_encoding
Exemple #3
0
 async def set_charset(self, charset):
     """Change the connection charset with a ``SET NAMES`` query.

     The encoding for ``charset`` is resolved via ``charset_by_name`` up
     front; ``self._charset`` and ``self._encoding`` are stored once the
     command was sent and its reply packet consumed.
     """
     # Resolving the encoding first keeps an unresolvable charset name from
     # ever reaching the server.
     new_encoding = charset_by_name(charset).encoding
     query = "SET NAMES %s" % self.escape(charset)
     await self._execute_command(COMMAND.COM_QUERY, query)
     await self._read_packet()
     self._charset, self._encoding = charset, new_encoding
Exemple #4
0
    def _request_authentication(self):
        """Send the login packet and read the server's reply.

        Generator-based coroutine; drive it with ``yield from``.  The packet
        consists of a fixed header (client flags, the value 1, the charset id
        and 23 zero bytes), the user name, the scrambled password and, when
        ``self._db`` is set, the database name.  If the reply is an EOF
        packet, the legacy ``_scramble_323`` response is sent and a second
        reply is read.

        :raises ValueError: If ``self._user`` is ``None``.
        """
        self.client_flag |= CAPABILITIES
        # Compare the numeric major version: a startswith('5') test skips every
        # server whose major version is not exactly 5 (e.g. "8.0" or "10.x").
        if int(self.server_version.split('.', 1)[0]) >= 5:
            self.client_flag |= MULTI_RESULTS

        if self._user is None:
            raise ValueError("Did not specify a username")

        charset_id = charset_by_name(self._charset).id
        user = self._user
        if isinstance(self._user, str):
            user = self._user.encode(self._encoding)

        data_init = (struct.pack('<i', self.client_flag) +
                     struct.pack("<I", 1) + int2byte(charset_id) +
                     int2byte(0) * 23)

        next_packet = 1

        data = data_init + user + b'\0' + _scramble(
            self._password.encode('latin1'), self.salt)

        if self._db:
            db = self._db
            if isinstance(self._db, str):
                db = self._db.encode(self._encoding)
            data += db + int2byte(0)

        # Prefix the payload with its 3-byte length and the sequence number.
        data = pack_int24(len(data)) + int2byte(next_packet) + data
        # Sequence number used by the legacy handshake packet below.
        next_packet += 2
        # logger.debug(dump_packet(data))
        self._write_bytes(data)

        auth_packet = yield from self._read_packet()

        # if old_passwords is enabled the packet will be 1 byte long and
        # have the octet 254

        if auth_packet.is_eof_packet():
            # send legacy handshake
            data = _scramble_323(self._password.encode('latin1'),
                                 self.salt) + b'\0'
            data = pack_int24(len(data)) + int2byte(next_packet) + data
            self._write_bytes(data)
            # _read_packet is a coroutine (see the first read above); without
            # "yield from" the reply would never actually be read.
            auth_packet = yield from self._read_packet()
Exemple #5
0
    def _request_authentication(self):
        """Send the login packet and read the server's reply.

        Generator-based coroutine; drive it with ``yield from``.  The packet
        is a fixed header packed as ``'<iIB23s'`` (client flags, the value 1,
        the charset id, zero padding) followed by the user name, the scrambled
        password and, when ``self._db`` is set, the database name.  If the
        reply is an EOF packet, the legacy ``_scramble_323`` response is sent
        and a second reply is read.

        :raises ValueError: If ``self._user`` is ``None``.
        """
        self.client_flag |= CAPABILITIES
        if int(self.server_version.split('.', 1)[0]) >= 5:
            self.client_flag |= MULTI_RESULTS

        if self._user is None:
            raise ValueError("Did not specify a username")

        charset_id = charset_by_name(self._charset).id
        user = self._user
        if isinstance(self._user, str):
            user = self._user.encode(self._encoding)

        data_init = struct.pack('<iIB23s', self.client_flag, 1,
                                charset_id, b'')

        next_packet = 1

        data = data_init + user + b'\0' + _scramble(
            self._password.encode('latin1'), self.salt)

        if self._db:
            db = self._db
            if isinstance(self._db, str):
                db = self._db.encode(self._encoding)
            data += db + int2byte(0)

        # Prefix the payload with its 3-byte length and the sequence number.
        data = pack_int24(len(data)) + int2byte(next_packet) + data
        # Sequence number used by the legacy handshake packet below.
        next_packet += 2
        # logger.debug(dump_packet(data))
        self._write_bytes(data)

        auth_packet = yield from self._read_packet()

        # if old_passwords is enabled the packet will be 1 byte long and
        # have the octet 254

        if auth_packet.is_eof_packet():
            # send legacy handshake
            data = _scramble_323(self._password.encode('latin1'),
                                 self.salt) + b'\0'
            data = pack_int24(len(data)) + int2byte(next_packet) + data
            self._write_bytes(data)
            # _read_packet is a coroutine (see the first read above); without
            # "yield from" the reply would never actually be read.
            auth_packet = yield from self._read_packet()
Exemple #6
0
    def _request_authentication(self):
        """Send the login packet and read the server's reply.

        Generator-based coroutine; drive it with ``yield from``.  The payload
        is a fixed header packed as ``'<iIB23s'`` (client flags, the value 1,
        the charset id, zero padding) followed by the user name, the scrambled
        password and, when ``self._db`` is set, the database name; it is sent
        through ``self.write_packet``.  If the reply is an EOF packet, the
        legacy ``_scramble_323`` response is sent and a second reply is read.

        :raises ValueError: If ``self._user`` is ``None``.
        """
        self.client_flag |= CLIENT.CAPABILITIES
        if int(self.server_version.split('.', 1)[0]) >= 5:
            self.client_flag |= CLIENT.MULTI_RESULTS

        if self._user is None:
            raise ValueError("Did not specify a username")

        charset_id = charset_by_name(self._charset).id
        user = self._user
        if isinstance(self._user, str):
            user = self._user.encode(self._encoding)

        data_init = struct.pack('<iIB23s', self.client_flag, 1, charset_id,
                                b'')

        data = data_init + user + b'\0' + _scramble(
            self._password.encode('latin1'), self.salt)

        # TODO: sync auth code with changes in PyMySQL
        if self._db:
            db = self._db
            if isinstance(self._db, str):
                db = self._db.encode(self._encoding)
            data += db + int2byte(0)

        self.write_packet(data)
        auth_packet = yield from self._read_packet()

        # if old_passwords is enabled the packet will be 1 byte long and
        # have the octet 254

        if auth_packet.is_eof_packet():
            # send legacy handshake
            data = _scramble_323(self._password.encode('latin1'),
                                 self.salt) + b'\0'
            self.write_packet(data)
            # _read_packet is a coroutine (see the first read above); without
            # "yield from" the reply would never actually be read.
            auth_packet = yield from self._read_packet()
Exemple #7
0
    def _read_packet(self, packet_type=FilePacket):
        """Build a packet object from the next event of the underlying file.

        Reads the event header via ``self.read_header()``, seeks
        ``self._rfile`` to ``self.log_pos`` and assembles the raw payload:
        a zero byte followed by ``event_length`` bytes when the length is
        non-zero, otherwise the fixed bytes ``fe 00 00 02 00``.

        :param packet_type: Callable invoked with the raw bytes and an
            encoding name to construct the packet.
        :return: The constructed packet.
        :raise: Whatever ``packet.raise_for_error()`` raises when the packet
            reports itself as an error packet.
        """
        type_code, event_length, timestamp = self.read_header()
        if DEBUG:
            print(type_code, event_length, timestamp)
        self._rfile.seek(self.log_pos)

        if not event_length:
            # No event data: hand back the fixed five-byte marker instead.
            payload = bytearray(b'\xfe\x00\x00\x02\x00')
        else:
            payload = bytearray(b'\0')
            payload += self._read_bytes(event_length)

        encoding = charset_by_name('utf8').encoding
        packet = packet_type(bytes(payload), encoding)
        if packet.is_error_packet():
            packet.raise_for_error()
        return packet
Exemple #8
0
    def __init__(self, host="localhost", user=None, password="",
                 db=None, port=3306, unix_socket=None,
                 charset='', sql_mode=None,
                 read_default_file=None, conv=decoders, use_unicode=None,
                 client_flag=0, cursorclass=Cursor, init_command=None,
                 connect_timeout=None, read_default_group=None,
                 no_delay=None, autocommit=False, echo=False,
                 local_infile=False, loop=None):
        """
        Record the settings for a MySQL connection. No network I/O happens
        here; the reader/writer streams start out as ``None``.

        :param host: Host where the database server is located.
        :param user: Username to log in as; a falsy value selects
            ``DEFAULT_USER``.
        :param password: Password to use; a falsy value becomes ``""``.
        :param db: Database to use, None to not use a particular one.
        :param port: MySQL port to use.
        :param unix_socket: Optional unix socket to use rather than TCP/IP.
        :param charset: Charset to use; when empty, ``DEFAULT_CHARSET`` is
            used.
        :param sql_mode: Default SQL_MODE to use.
        :param read_default_file: my.cnf-style file whose values (user,
            password, host, database, socket, port, default-character-set)
            override the corresponding arguments when present.
        :param conv: Decoders dictionary to use instead of the default one.
        :param use_unicode: Whether or not to default to unicode strings.
            When left as None it becomes True on Python 3.
        :param client_flag: Custom flags to send to MySQL; the flags added
            by this constructor are OR-ed on top.
        :param cursorclass: Custom cursor class to use.
        :param init_command: Initial SQL statement to run when the
            connection is established.
        :param connect_timeout: Timeout before throwing an exception when
            connecting.
        :param read_default_group: Section of ``read_default_file`` to read
            (defaults to ``"client"``).
        :param no_delay: Deprecated; passing any non-None value emits a
            ``DeprecationWarning``. Defaults to True when left as None.
        :param autocommit: Autocommit mode. None means use server default.
            (default: False)
        :param echo: Stored on ``self._echo``.
        :param local_infile: When true, ``CLIENT.LOCAL_FILES`` is added to
            the client flags. (default: False)
        :param loop: asyncio loop; the current event loop when omitted.
        """
        self._loop = loop or asyncio.get_event_loop()

        if use_unicode is None and sys.version_info[0] > 2:
            use_unicode = True

        if read_default_file:
            section = read_default_group or "client"
            parser = configparser.RawConfigParser()
            parser.read(os.path.expanduser(read_default_file))

            def option(key, default):
                # The argument value is kept when the file has no such option.
                return parser.get(section, key, fallback=default)

            user = option("user", user)
            password = option("password", password)
            host = option("host", host)
            db = option("database", db)
            unix_socket = option("socket", unix_socket)
            port = int(option("port", port))
            charset = option("default-character-set", charset)

        # pymysql port
        if no_delay is None:
            no_delay = True
        else:
            warnings.warn("no_delay option is deprecated", DeprecationWarning)
            no_delay = bool(no_delay)

        # Connection target and credentials.
        self._host = host
        self._port = port
        self._unix_socket = unix_socket
        self._user = user or DEFAULT_USER
        self._password = password or ""
        self._db = db
        self._no_delay = no_delay
        self._echo = echo

        # An explicit charset implies unicode unless use_unicode overrides it.
        self._charset = charset or DEFAULT_CHARSET
        self.use_unicode = bool(charset) if use_unicode is None else use_unicode
        self._encoding = charset_by_name(self._charset).encoding

        flags = client_flag
        if local_infile:
            flags |= CLIENT.LOCAL_FILES
        flags |= CLIENT.CAPABILITIES | CLIENT.MULTI_STATEMENTS
        if self._db:
            flags |= CLIENT.CONNECT_WITH_DB
        self.client_flag = flags

        self.cursorclass = cursorclass
        self.connect_timeout = connect_timeout

        # Per-connection state, reset until a connection is made.
        self._result = None
        self._affected_rows = 0
        self.host_info = "Not connected"

        #: specified autocommit mode. None means use server default.
        self.autocommit_mode = autocommit

        self.encoders = encoders  # Need for MySQLdb compatibility.
        self.decoders = conv
        self.sql_mode = sql_mode
        self.init_command = init_command

        # asyncio StreamReader, StreamWriter
        self._reader = None
        self._writer = None
Exemple #9
0
    def _request_authentication(self):
        """Send the handshake response and handle an auth-switch request.

        Generator-based coroutine; drive it with ``yield from``.  The packet
        holds a fixed header packed as ``'<iIB23s'``, the user name, the
        authentication response (framed according to the server
        capabilities), the database name when both ``self._db`` and the
        server's ``CONNECT_WITH_DB`` capability are set, and the auth plugin
        name when the server advertises ``PLUGIN_AUTH``.

        :raises ValueError: If ``self.user`` is ``None``.
        """
        # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
        if int(self.server_version.split('.', 1)[0]) >= 5:
            self.client_flag |= CLIENT.MULTI_RESULTS

        if self.user is None:
            raise ValueError("Did not specify a username")

        charset_id = charset_by_name(self.charset).id
        # Bind _user unconditionally: a user name that is already bytes used
        # to leave it unassigned and fail below with UnboundLocalError.
        _user = self.user
        if isinstance(_user, str):
            _user = _user.encode(self.encoding)

        data_init = struct.pack('<iIB23s', self.client_flag, 1, charset_id,
                                b'')

        data = data_init + _user + b'\0'

        # Only the native-password scramble is produced here; any other
        # plugin name sends an empty response.
        authresp = b''
        if self._auth_plugin_name in ('', 'mysql_native_password'):
            authresp = _scramble(self._password.encode('latin1'), self.salt)

        if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA:
            data += lenenc_int(len(authresp)) + authresp
        elif self.server_capabilities & CLIENT.SECURE_CONNECTION:
            data += struct.pack('B', len(authresp)) + authresp
        else:  # pragma: no cover
            # not testing against servers without secure auth (>=5.0)
            data += authresp + b'\0'

        if self._db and self.server_capabilities & CLIENT.CONNECT_WITH_DB:

            if isinstance(self._db, str):
                db = self._db.encode(self.encoding)
            else:
                db = self._db
            data += db + b'\0'

        if self.server_capabilities & CLIENT.PLUGIN_AUTH:
            name = self._auth_plugin_name
            if isinstance(name, str):
                name = name.encode('ascii')
            data += name + b'\0'

        self.write_packet(data)
        auth_packet = yield from self._read_packet()

        # if authentication method isn't accepted the first byte
        # will have the octet 254
        if auth_packet.is_auth_switch_request():
            # https://dev.mysql.com/doc/internals/en/
            # connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
            auth_packet.read_uint8()  # 0xfe packet identifier
            plugin_name = auth_packet.read_string()
            if (self.server_capabilities & CLIENT.PLUGIN_AUTH
                    and plugin_name is not None):
                # NOTE(review): if _process_auth is itself a coroutine this
                # call needs "yield from" as well -- confirm against its
                # definition, which is not visible here.
                auth_packet = self._process_auth(plugin_name, auth_packet)
            else:
                # send legacy handshake
                data = _scramble_323(self._password.encode('latin1'),
                                     self.salt) + b'\0'
                self.write_packet(data)
                auth_packet = yield from self._read_packet()
 def __read_string(self, size, column):
     """Read a length-coded string for ``column`` from the current packet.

     The raw value comes from ``read_length_coded_pascal_string(size)``.
     It is returned undecoded when the column has no character set name;
     otherwise it is decoded with the encoding that ``charset_by_name``
     reports for that character set.
     """
     raw = self.packet.read_length_coded_pascal_string(size)
     charset_name = column.character_set_name
     if charset_name is None:
         return raw
     return raw.decode(charset_by_name(charset_name).encoding)
Exemple #11
0
# All functions are adapted from https://github.com/PyMySQL/PyMySQL/blob/master/pymysql/
# Everything is modified for my specific needs; apologies if it is unpleasant to read.
# Note to self: pymysql's execute function was changed too -- escaping was removed (mog).
from pymysql.converters import encoders
from pymysql.converters import _escape_unicode as escape_string
from pymysql.charset import charset_by_name
from functools import partial

# Charset name used by this module, and the encoding pymysql's charset
# table associates with it.
charset = 'utf8mb4'
encoding = charset_by_name(charset).encoding

def literal(obj):
	"""Return ``obj`` escaped by ``escape`` using pymysql's ``encoders`` mapping."""
	return escape(obj, mapping=encoders)

def _ensure_bytes(x, encoding=None):
	"""Encode ``str`` values to bytes, recursing into tuples and lists.

	:param x: A ``str``, a tuple/list (possibly nested) or any other object.
	:param encoding: Codec name used for ``str`` values; ``None`` means
		UTF-8.
	:return: ``x`` with every ``str`` encoded. Tuples and lists are rebuilt
		with their own type; any other object is returned unchanged.
	"""
	if isinstance(x, str):
		# str.encode() rejects None, so the default argument used to raise
		# TypeError; fall back to str.encode()'s own default (UTF-8).
		return x.encode() if encoding is None else x.encode(encoding)
	if isinstance(x, (tuple, list)):
		return type(x)(_ensure_bytes(v, encoding=encoding) for v in x)
	return x

def _escape_args(args):
	"""Escape query arguments while preserving the container shape.

	:param args: A tuple/list, a dict or a single object.
	:return: A tuple of ``literal``-escaped items for a tuple/list, a dict
		with the same keys and ``literal``-escaped values for a dict, and
		``escape(args)`` for anything else.
	"""
	if isinstance(args, (tuple, list)):
		return tuple(literal(arg) for arg in args)
	if isinstance(args, dict):
		return {key: literal(val) for key, val in args.items()}
	# Not a container: try escaping the object itself.
	return escape(args)

def escape(obj, mapping=None):
Exemple #12
0
    def _request_authentication(self):
        """Send the login packet (optionally upgrading to TLS first).

        Greenlet-based: when ``self.ssl`` is set, the header is sent on its
        own, ``start_tls`` is started on the socket and the current greenlet
        switches to its parent; the ``finish`` callback resumes it with the
        upgraded stream (or throws the failure into it).  Afterwards the full
        login packet is written and the reply is read; an EOF reply triggers
        the legacy ``_scramble_323`` handshake.

        Note that ``self.user`` and ``self.db`` are replaced by their encoded
        bytes form as a side effect.

        :raises ValueError: If ``self.user`` is ``None``.
        """
        self.client_flag |= CLIENT.CAPABILITIES
        if int(self.server_version.split(".", 1)[0]) >= 5:
            self.client_flag |= CLIENT.MULTI_RESULTS

        if self.user is None:
            raise ValueError("Did not specify a username")

        charset_id = charset_by_name(self.charset).id
        if isinstance(self.user, text_type):
            self.user = self.user.encode(self.encoding)

        data_init = struct.pack("<iIB23s", self.client_flag, 1, charset_id, b"")

        next_packet = 1

        if self.ssl:
            data = pack_int24(len(data_init)) + int2byte(next_packet) + data_init
            next_packet += 1

            self._write_bytes(data)

            child_gr = greenlet.getcurrent()
            main = child_gr.parent
            assert main is not None, "Execut must be running in child greenlet"

            def finish(future):
                try:
                    stream = future.result()
                    child_gr.switch(stream)
                except Exception as e:
                    child_gr.throw(e)

            # Verify the server certificate only when a CA file is configured.
            cert_reqs = ssl.CERT_NONE if self.ca is None else ssl.CERT_REQUIRED
            future = self.socket.start_tls(
                None,
                {
                    "keyfile": self.key,
                    "certfile": self.cert,
                    "ssl_version": ssl.PROTOCOL_TLSv1,
                    # The computed requirement belongs here; it used to be
                    # passed as "ca_certs" while this key was hard-coded to
                    # CERT_REQUIRED, so self.ca was never used at all.
                    "cert_reqs": cert_reqs,
                    "ca_certs": self.ca,
                },
            )
            IOLoop.current().add_future(future, finish)
            # Suspend until finish() switches back with the TLS stream.
            self.socket = main.switch()
            self._rfile = self.socket

        data = data_init + self.user + b"\0" + _scramble(self.password.encode("latin1"), self.salt)

        if self.db:
            if isinstance(self.db, text_type):
                self.db = self.db.encode(self.encoding)
            data += self.db + int2byte(0)

        data = pack_int24(len(data)) + int2byte(next_packet) + data
        # Sequence number used by the legacy handshake packet below.
        next_packet += 2

        if DEBUG:
            dump_packet(data)

        self._write_bytes(data)

        auth_packet = self._read_packet()

        # if old_passwords is enabled the packet will be 1 byte long and
        # have the octet 254

        if auth_packet.is_eof_packet():
            # send legacy handshake
            data = _scramble_323(self.password.encode("latin1"), self.salt) + b"\0"
            data = pack_int24(len(data)) + int2byte(next_packet) + data
            self._write_bytes(data)
            auth_packet = self._read_packet()
 def charset_to_encoding(name):
     """Map a charset name to its encoding, falling back to ``name`` itself.

     When ``charset_by_name`` yields a falsy result for ``name``, the name
     is returned unchanged.
     """
     info = charset_by_name(name)
     if info:
         return info.encoding
     return name
Exemple #14
0
    def __init__(self,
                 host="localhost",
                 user=None,
                 password="",
                 db=None,
                 port=3306,
                 unix_socket=None,
                 charset='',
                 sql_mode=None,
                 read_default_file=None,
                 conv=decoders,
                 use_unicode=None,
                 client_flag=0,
                 cursorclass=Cursor,
                 init_command=None,
                 connect_timeout=None,
                 read_default_group=None,
                 no_delay=False,
                 autocommit=False,
                 echo=False,
                 loop=None):
        """
        Record the settings for a MySQL connection. No network I/O happens
        here; the reader/writer streams start out as ``None``.

        :param host: Host where the database server is located.
        :param user: Username to log in as; a falsy value selects
            ``DEFAULT_USER``.
        :param password: Password to use; a falsy value becomes ``""``.
        :param db: Database to use, None to not use a particular one.
        :param port: MySQL port to use.
        :param unix_socket: Optional unix socket to use rather than TCP/IP.
        :param charset: Charset to use; when empty, ``DEFAULT_CHARSET`` is
            used.
        :param sql_mode: Default SQL_MODE to use.
        :param read_default_file: my.cnf-style file whose values (user,
            password, host, database, socket, port, default-character-set)
            override the corresponding arguments when present.
        :param conv: Decoders dictionary to use instead of the default one.
        :param use_unicode: Whether or not to default to unicode strings.
            When left as None it becomes True on Python 3.
        :param client_flag: Custom flags to send to MySQL; the flags added
            by this constructor are OR-ed on top.
        :param cursorclass: Custom cursor class to use.
        :param init_command: Initial SQL statement to run when the
            connection is established.
        :param connect_timeout: Timeout before throwing an exception when
            connecting.
        :param read_default_group: Section of ``read_default_file`` to read
            (defaults to ``"client"``).
        :param no_delay: Disable Nagle's algorithm on the socket.
        :param autocommit: Autocommit mode. None means use server default.
            (default: False)
        :param echo: Stored on ``self._echo``.
        :param loop: asyncio loop; the current event loop when omitted.
        """
        self._loop = loop or asyncio.get_event_loop()

        if use_unicode is None and sys.version_info[0] > 2:
            use_unicode = True

        if read_default_file:
            section = read_default_group or "client"
            parser = configparser.RawConfigParser()
            parser.read(os.path.expanduser(read_default_file))

            def option(key, default):
                # The argument value is kept when the file has no such option.
                return parser.get(section, key, fallback=default)

            user = option("user", user)
            password = option("password", password)
            host = option("host", host)
            db = option("database", db)
            unix_socket = option("socket", unix_socket)
            port = int(option("port", port))
            charset = option("default-character-set", charset)

        # Connection target and credentials.
        self._host = host
        self._port = port
        self._unix_socket = unix_socket
        self._user = user or DEFAULT_USER
        self._password = password or ""
        self._db = db
        self._no_delay = no_delay
        self._echo = echo

        # An explicit charset implies unicode unless use_unicode overrides it.
        self._charset = charset or DEFAULT_CHARSET
        self.use_unicode = bool(charset) if use_unicode is None else use_unicode
        self._encoding = charset_by_name(self._charset).encoding

        flags = client_flag | CAPABILITIES | MULTI_STATEMENTS
        if self._db:
            flags |= CONNECT_WITH_DB
        self.client_flag = flags

        self.cursorclass = cursorclass
        self.connect_timeout = connect_timeout

        # Per-connection state, reset until a connection is made.
        self._result = None
        self._affected_rows = 0
        self.host_info = "Not connected"

        #: specified autocommit mode. None means use server default.
        self.autocommit_mode = autocommit

        self.encoders = encoders  # Need for MySQLdb compatibility.
        self.decoders = conv
        self.sql_mode = sql_mode
        self.init_command = init_command

        # asyncio StreamReader, StreamWriter
        self._reader = None
        self._writer = None
Exemple #15
0
    def _request_authentication(self):
        """Send the handshake response, upgrading to TLS first when possible.

        Greenlet-based: when ``self.ssl`` is set and the server advertises
        ``CLIENT.SSL``, the header is sent on its own, ``start_tls`` is
        started and the current greenlet switches to its parent until the
        ``finish`` callback resumes it with the upgraded stream (or throws
        the failure into it).  The full response then carries the user name,
        the authentication response, the database name and the auth plugin
        name, depending on the server capabilities.  An auth-switch reply is
        handled through ``_process_auth`` or the legacy ``_scramble_323``
        handshake.

        Note that ``self.user`` and ``self.db`` are replaced by their encoded
        bytes form as a side effect.

        :raises ValueError: If ``self.user`` is ``None``.
        """
        if int(self.server_version.split('.', 1)[0]) >= 5:
            self.client_flag |= CLIENT.MULTI_RESULTS

        if self.user is None:
            raise ValueError("Did not specify a username")

        charset_id = charset_by_name(self.charset).id
        if isinstance(self.user, text_type):
            self.user = self.user.encode(self.encoding)

        data_init = struct.pack('<iIB23s', self.client_flag, 1, charset_id, b'')

        if self.ssl and self.server_capabilities & CLIENT.SSL:
            self.write_packet(data_init)

            child_gr = greenlet.getcurrent()
            main = child_gr.parent
            assert main is not None, "Execut must be running in child greenlet"

            def finish(future):
                # Use the public Future API rather than the private
                # ``_exc_info`` attribute, which not every Future
                # implementation provides.
                error = future.exception()
                if error is not None:
                    child_gr.throw(error)
                else:
                    child_gr.switch(future.result())

            future = self.socket.start_tls(False, self.ctx, server_hostname=self.host)
            self._loop.add_future(future, finish)
            # Suspend until finish() switches back with the TLS stream.
            self._rfile = self.socket = main.switch()

        data = data_init + self.user + b'\0'

        # Only the native-password scramble is produced here; any other
        # plugin name sends an empty response.
        authresp = b''
        if self._auth_plugin_name in ('', 'mysql_native_password'):
            authresp = _scramble(self.password.encode('latin1'), self.salt)

        if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA:
            data += lenenc_int(len(authresp)) + authresp
        elif self.server_capabilities & CLIENT.SECURE_CONNECTION:
            data += struct.pack('B', len(authresp)) + authresp
        else:  # pragma: no cover - not testing against servers without secure auth (>=5.0)
            data += authresp + b'\0'

        if self.db and self.server_capabilities & CLIENT.CONNECT_WITH_DB:
            if isinstance(self.db, text_type):
                self.db = self.db.encode(self.encoding)
            data += self.db + b'\0'

        if self.server_capabilities & CLIENT.PLUGIN_AUTH:
            name = self._auth_plugin_name
            if isinstance(name, text_type):
                name = name.encode('ascii')
            data += name + b'\0'

        self.write_packet(data)
        auth_packet = self._read_packet()

        # if authentication method isn't accepted the first byte
        # will have the octet 254
        if auth_packet.is_auth_switch_request():
            # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
            auth_packet.read_uint8() # 0xfe packet identifier
            plugin_name = auth_packet.read_string()
            if self.server_capabilities & CLIENT.PLUGIN_AUTH and plugin_name is not None:
                auth_packet = self._process_auth(plugin_name, auth_packet)
            else:
                # send legacy handshake
                data = _scramble_323(self.password.encode('latin1'), self.salt) + b'\0'
                self.write_packet(data)
                auth_packet = self._read_packet()
Exemple #16
0
    def _request_authentication(self):
        """Send the handshake response and handle an auth-switch request.

        Generator-based coroutine; drive it with ``yield from``.  The packet
        holds a fixed header packed as ``'<iIB23s'``, the user name, the
        authentication response (framed according to the server
        capabilities), the database name when both ``self._db`` and the
        server's ``CONNECT_WITH_DB`` capability are set, and the auth plugin
        name when the server advertises ``PLUGIN_AUTH``.

        :raises ValueError: If ``self.user`` is ``None``.
        """
        # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
        if int(self.server_version.split('.', 1)[0]) >= 5:
            self.client_flag |= CLIENT.MULTI_RESULTS

        if self.user is None:
            raise ValueError("Did not specify a username")

        charset_id = charset_by_name(self.charset).id
        # Bind _user unconditionally: a user name that is already bytes used
        # to leave it unassigned and fail below with UnboundLocalError.
        _user = self.user
        if isinstance(_user, str):
            _user = _user.encode(self.encoding)

        data_init = struct.pack('<iIB23s', self.client_flag, 1,
                                charset_id, b'')

        data = data_init + _user + b'\0'

        # Only the native-password scramble is produced here; any other
        # plugin name sends an empty response.
        authresp = b''
        if self._auth_plugin_name in ('', 'mysql_native_password'):
            authresp = _scramble(self._password.encode('latin1'), self.salt)

        if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA:
            data += lenenc_int(len(authresp)) + authresp
        elif self.server_capabilities & CLIENT.SECURE_CONNECTION:
            data += struct.pack('B', len(authresp)) + authresp
        else:  # pragma: no cover
            # not testing against servers without secure auth (>=5.0)
            data += authresp + b'\0'

        if self._db and self.server_capabilities & CLIENT.CONNECT_WITH_DB:

            if isinstance(self._db, str):
                db = self._db.encode(self.encoding)
            else:
                db = self._db
            data += db + b'\0'

        if self.server_capabilities & CLIENT.PLUGIN_AUTH:
            name = self._auth_plugin_name
            if isinstance(name, str):
                name = name.encode('ascii')
            data += name + b'\0'

        self.write_packet(data)
        auth_packet = yield from self._read_packet()

        # if authentication method isn't accepted the first byte
        # will have the octet 254
        if auth_packet.is_auth_switch_request():
            # https://dev.mysql.com/doc/internals/en/
            # connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
            auth_packet.read_uint8()  # 0xfe packet identifier
            plugin_name = auth_packet.read_string()
            if (self.server_capabilities & CLIENT.PLUGIN_AUTH and
                    plugin_name is not None):
                # NOTE(review): if _process_auth is itself a coroutine this
                # call needs "yield from" as well -- confirm against its
                # definition, which is not visible here.
                auth_packet = self._process_auth(plugin_name, auth_packet)
            else:
                # send legacy handshake
                data = _scramble_323(self._password.encode('latin1'),
                                     self.salt) + b'\0'
                self.write_packet(data)
                auth_packet = yield from self._read_packet()
Exemple #17
0
    def __init__(self,
                 host="localhost",
                 user=None,
                 password="",
                 db=None,
                 port=3306,
                 unix_socket=None,
                 charset='',
                 sql_mode=None,
                 read_default_file=None,
                 conv=decoders,
                 use_unicode=None,
                 client_flag=0,
                 cursorclass=Cursor,
                 init_command=None,
                 connect_timeout=None,
                 read_default_group=None,
                 no_delay=None,
                 autocommit=False,
                 echo=False,
                 local_infile=False,
                 loop=None,
                 ssl=None,
                 auth_plugin='',
                 program_name='',
                 server_public_key=None):
        """
        Establish a connection to the MySQL database. Accepts several
        arguments:

        :param host: Host where the database server is located
        :param user: Username to log in as
        :param password: Password to use.
        :param db: Database to use, None to not use a particular one.
        :param port: MySQL port to use, default is usually OK.
        :param unix_socket: Optionally, you can use a unix socket rather
        than TCP/IP.
        :param charset: Charset you want to use.
        :param sql_mode: Default SQL_MODE to use.
        :param read_default_file: Specifies  my.cnf file to read these
            parameters from under the [client] section.
        :param conv: Decoders dictionary to use instead of the default one.
            This is used to provide custom marshalling of types.
            See converters.
        :param use_unicode: Whether or not to default to unicode strings.
        :param  client_flag: Custom flags to send to MySQL. Find
            potential values in constants.CLIENT.
        :param cursorclass: Custom cursor class to use.
        :param init_command: Initial SQL statement to run when connection is
            established.
        :param connect_timeout: Timeout before throwing an exception
            when connecting.
        :param read_default_group: Group to read from in the configuration
            file.
        :param no_delay: Disable Nagle's algorithm on the socket
        :param autocommit: Autocommit mode. None means use server default.
            (default: False)
        :param local_infile: boolean to enable the use of LOAD DATA LOCAL
            command. (default: False)
        :param ssl: Optional SSL Context to force SSL
        :param auth_plugin: String to manually specify the authentication
            plugin to use, i.e you will want to use mysql_clear_password
            when using IAM authentication with Amazon RDS.
            (default: Server Default)
        :param program_name: Program name string to provide when
            handshaking with MySQL. (default: sys.argv[0])
        :param server_public_key: SHA256 authentication plugin public
            key value.
        :param loop: asyncio loop
        """
        self._loop = loop or asyncio.get_event_loop()

        if use_unicode is None and sys.version_info[0] > 2:
            use_unicode = True

        if read_default_file:
            if not read_default_group:
                read_default_group = "client"
            cfg = configparser.RawConfigParser()
            cfg.read(os.path.expanduser(read_default_file))
            _config = partial(cfg.get, read_default_group)

            user = _config("user", fallback=user)
            password = _config("password", fallback=password)
            host = _config("host", fallback=host)
            db = _config("database", fallback=db)
            unix_socket = _config("socket", fallback=unix_socket)
            port = int(_config("port", fallback=port))
            charset = _config("default-character-set", fallback=charset)

        # pymysql port
        if no_delay is not None:
            warnings.warn("no_delay option is deprecated", DeprecationWarning)
            no_delay = bool(no_delay)
        else:
            no_delay = True

        self._host = host
        self._port = port
        self._user = user or DEFAULT_USER
        self._password = password or ""
        self._db = db
        self._no_delay = no_delay
        self._echo = echo
        self._last_usage = self._loop.time()
        self._client_auth_plugin = auth_plugin
        self._server_auth_plugin = ""
        self._auth_plugin_used = ""
        self.server_public_key = server_public_key
        self.salt = None

        # TODO somehow import version from __init__.py
        self._connect_attrs = {
            '_client_name': 'aiomysql',
            '_pid': str(os.getpid()),
            '_client_version': '0.0.16',
        }
        if program_name:
            self._connect_attrs["program_name"] = program_name
        elif sys.argv:
            self._connect_attrs["program_name"] = sys.argv[0]

        self._unix_socket = unix_socket
        if charset:
            self._charset = charset
            self.use_unicode = True
        else:
            self._charset = DEFAULT_CHARSET
            self.use_unicode = False

        if use_unicode is not None:
            self.use_unicode = use_unicode

        self._ssl_context = ssl
        if ssl:
            client_flag |= CLIENT.SSL

        self._encoding = charset_by_name(self._charset).encoding

        if local_infile:
            client_flag |= CLIENT.LOCAL_FILES

        client_flag |= CLIENT.CAPABILITIES
        client_flag |= CLIENT.MULTI_STATEMENTS
        if self._db:
            client_flag |= CLIENT.CONNECT_WITH_DB
        self.client_flag = client_flag

        self.cursorclass = cursorclass
        self.connect_timeout = connect_timeout

        self._result = None
        self._affected_rows = 0
        self.host_info = "Not connected"

        #: specified autocommit mode. None means use server default.
        self.autocommit_mode = autocommit

        self.encoders = encoders  # Need for MySQLdb compatibility.
        self.decoders = conv
        self.sql_mode = sql_mode
        self.init_command = init_command

        # asyncio StreamReader, StreamWriter
        self._reader = None
        self._writer = None
        # If connection was closed for specific reason, we should show that to
        # user
        self._close_reason = None
    def _request_authentication(self):
        """Send the client authentication packet and check the server's reply.

        The payload is assembled from the client flags, the charset id, the
        user name, ``self.forwarded_auth_response`` and, when ``self.db`` is
        set, the database name.  When ``self.ssl`` is truthy the fixed-size
        prefix is first sent as a packet of its own and the socket is wrapped
        with TLS before the complete packet is written.  If the server
        answers with an EOF packet, a ``_scramble_323`` response built from
        the password and salt is sent and the following reply is checked too.

        Side effects: ``self.client_flag`` is updated, and ``self.user`` /
        ``self.db`` are replaced by their encoded form when they are text.

        :raises ValueError: if ``self.user`` is None.
        """
        def _frame(payload, sequence):
            # Prepend the packet header: payload length (pack_int24) and
            # the one-byte sequence number.
            return pack_int24(len(payload)) + int2byte(sequence) + payload

        self.client_flag |= CAPABILITIES
        if self.server_version.startswith('5'):
            self.client_flag |= MULTI_RESULTS

        if self.user is None:
            raise ValueError("Did not specify a username")

        charset_id = charset_by_name(self.charset).id
        if isinstance(self.user, text_type):
            self.user = self.user.encode(self.encoding)

        # Fixed-size prefix shared by the SSL request and the full packet:
        # client flags, the constant 1, the charset id and 23 zero bytes.
        prefix = (struct.pack('<iI', self.client_flag, 1)
                  + int2byte(charset_id)
                  + int2byte(0) * 23)

        sequence = 1

        if self.ssl:
            # Send the prefix alone, then switch the socket over to TLS.
            ssl_request = _frame(prefix, sequence)
            sequence += 1

            if DEBUG:
                dump_packet(ssl_request)

            self._write_bytes(ssl_request)
            self.socket = ssl.wrap_socket(
                self.socket,
                keyfile=self.key,
                certfile=self.cert,
                ssl_version=ssl.PROTOCOL_TLSv1,
                cert_reqs=ssl.CERT_REQUIRED,
                ca_certs=self.ca,
            )

        payload = prefix + self.user + b'\0' + self.forwarded_auth_response

        if self.db:
            if isinstance(self.db, text_type):
                self.db = self.db.encode(self.encoding)
            payload += self.db + int2byte(0)

        login_packet = _frame(payload, sequence)
        # Advance by two: the server's reply sits between this packet and
        # any follow-up packet sent below.
        sequence += 2

        if DEBUG:
            dump_packet(login_packet)

        self._write_bytes(login_packet)

        reply = MysqlPacket(self)
        reply.check_error()
        if DEBUG:
            reply.dump()

        # If old_passwords is enabled the reply is 1 byte long and holds
        # the octet 254, i.e. it looks like an EOF packet.
        if not reply.is_eof_packet():
            return

        # Send the legacy handshake.
        # TODO: is this the correct charset?
        scrambled = _scramble_323(self.password.encode(self.encoding),
                                  self.salt.encode(self.encoding)) + b'\0'
        self._write_bytes(_frame(scrambled, sequence))

        reply = MysqlPacket(self)
        reply.check_error()
        if DEBUG:
            reply.dump()
Exemple #19
0
    async def _request_authentication(self):
        """Send the HandshakeResponse packet and finish authentication.

        Steps, in order:

        1. Enable ``CLIENT.MULTI_RESULTS`` for servers whose major version
           is ``'Inception2'`` or numerically >= 5.
        2. When an SSL context is configured, send the short SSL request
           packet and reopen the stream over TLS on a duplicate of the
           existing socket.
        3. Build the handshake response (flags, max packet size, charset,
           user, auth response, optional database, plugin name, optional
           connection attributes) and write it.
        4. Read the reply and handle an auth-switch request or extra
           authentication data for the SHA-2 based plugins.

        Side effects: updates ``self.client_flag`` and
        ``self._auth_plugin_used``; may replace ``self._reader`` and
        ``self._writer``.

        :raises ValueError: if ``self.user`` is None (``int()`` may also
            raise it for a non-numeric server major version).
        :raises RuntimeError: if the transport does not expose its socket
            while switching to TLS.
        :raises OperationalError: if the server sends extra auth data for
            a plugin other than ``caching_sha2_password`` or
            ``sha256_password``.
        """
        # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
        server_major = self.server_version.split('.', 1)[0]
        if server_major == 'Inception2' or int(server_major) >= 5:
            self.client_flag |= CLIENT.MULTI_RESULTS

        if self.user is None:
            raise ValueError("Did not specify a username")

        if self._ssl_context:
            # capabilities, max packet, charset -- zero padded to 32 bytes
            data = struct.pack('<IIB', self.client_flag, 16777216, 33)
            data += b'\x00' * (32 - len(data))

            self.write_packet(data)

            # Stop sending events to data_received
            self._writer.transport.pause_reading()

            # Get the raw socket from the transport
            raw_sock = self._writer.transport.get_extra_info('socket',
                                                             default=None)
            if raw_sock is None:
                raise RuntimeError("Transport does not expose socket instance")

            raw_sock = raw_sock.dup()
            self._writer.transport.close()
            # MySQL expects TLS negotiation to happen in the middle of a
            # TCP connection not at start. Passing in a socket to
            # open_connection will cause it to negotiate TLS on an existing
            # connection not initiate a new one.
            # NOTE(review): the ``loop`` argument of asyncio.open_connection
            # was removed in Python 3.10 -- confirm the supported versions.
            self._reader, self._writer = await asyncio.open_connection(
                sock=raw_sock,
                ssl=self._ssl_context,
                loop=self._loop,
                server_hostname=self._host)

        charset_id = charset_by_name(self.charset).id
        if isinstance(self.user, str):
            _user = self.user.encode(self.encoding)
        else:
            _user = self.user

        data_init = struct.pack('<iIB23s', self.client_flag, MAX_PACKET_LEN,
                                charset_id, b'')

        data = data_init + _user + b'\0'

        authresp = b''

        auth_plugin = self._client_auth_plugin
        if not self._client_auth_plugin:
            # Contains the auth plugin from handshake
            auth_plugin = self._server_auth_plugin

        if auth_plugin in ('', 'mysql_native_password'):
            authresp = _auth.scramble_native_password(
                self._password.encode('latin1'), self.salt)
        elif auth_plugin == 'caching_sha2_password':
            if self._password:
                authresp = _auth.scramble_caching_sha2(
                    self._password.encode('latin1'), self.salt)
            # Else: empty password
        elif auth_plugin == 'sha256_password':
            if self._ssl_context and self.server_capabilities & CLIENT.SSL:
                authresp = self._password.encode('latin1') + b'\0'
            elif self._password:
                authresp = b'\1'  # request public key
            else:
                authresp = b'\0'  # empty password
        elif auth_plugin == 'mysql_clear_password':
            # An empty plugin name never reaches this branch: it is handled
            # by the native-password branch above.
            authresp = self._password.encode('latin1') + b'\0'

        if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA:
            data += lenenc_int(len(authresp)) + authresp
        elif self.server_capabilities & CLIENT.SECURE_CONNECTION:
            data += struct.pack('B', len(authresp)) + authresp
        else:  # pragma: no cover
            # not testing against servers without secure auth (>=5.0)
            data += authresp + b'\0'

        if self._db and self.server_capabilities & CLIENT.CONNECT_WITH_DB:

            if isinstance(self._db, str):
                db = self._db.encode(self.encoding)
            else:
                db = self._db
            data += db + b'\0'

        if self.server_capabilities & CLIENT.PLUGIN_AUTH:
            name = auth_plugin
            if isinstance(name, str):
                name = name.encode('ascii')
            data += name + b'\0'

        self._auth_plugin_used = auth_plugin

        # Sends the server a few pieces of client info
        if self.server_capabilities & CLIENT.CONNECT_ATTRS:
            # Keys, values and the total size are length-encoded integers in
            # the protocol.  A single raw byte is only equivalent below 251
            # and fails outright above 255, which a long program_name
            # (sys.argv[0] by default) can easily exceed.
            attr_parts = []
            for k, v in self._connect_attrs.items():
                k, v = k.encode('utf8'), v.encode('utf8')
                attr_parts.append(lenenc_int(len(k)) + k)
                attr_parts.append(lenenc_int(len(v)) + v)
            connect_attrs = b''.join(attr_parts)
            data += lenenc_int(len(connect_attrs)) + connect_attrs

        self.write_packet(data)
        auth_packet = await self._read_packet()

        # if authentication method isn't accepted the first byte
        # will have the octet 254
        if auth_packet.is_auth_switch_request():
            # https://dev.mysql.com/doc/internals/en/
            # connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
            auth_packet.read_uint8()  # 0xfe packet identifier
            plugin_name = auth_packet.read_string()
            if (self.server_capabilities & CLIENT.PLUGIN_AUTH
                    and plugin_name is not None):
                await self._process_auth(plugin_name, auth_packet)
            else:
                # send legacy handshake
                data = _auth.scramble_old_password(
                    self._password.encode('latin1'),
                    auth_packet.read_all()) + b'\0'
                self.write_packet(data)
                await self._read_packet()
        elif auth_packet.is_extra_auth_data():
            if auth_plugin == "caching_sha2_password":
                await self.caching_sha2_password_auth(auth_packet)
            elif auth_plugin == "sha256_password":
                await self.sha256_password_auth(auth_packet)
            else:
                # Interpolate the plugin name; passing it as a second
                # positional argument left the "%r" unformatted.
                raise OperationalError(
                    "Received extra packet "
                    "for auth method %r" % (auth_plugin,))
    def _request_authentication(self):
        """Send the client authentication packet and check the server's reply.

        The payload is assembled from the client flags, the charset id, the
        user name, ``self.forwarded_auth_response`` and, when ``self.db`` is
        set, the database name.  When ``self.ssl`` is truthy the fixed-size
        prefix is first sent as a packet of its own and the socket is wrapped
        with TLS before the complete packet is written.  If the server
        answers with an EOF packet, a ``_scramble_323`` response built from
        the password and salt is sent and the following reply is checked too.

        Side effects: ``self.client_flag`` is updated, and ``self.user`` /
        ``self.db`` are replaced by their encoded form when they are text.

        :raises ValueError: if ``self.user`` is None.
        """
        def _frame(payload, sequence):
            # Prepend the packet header: payload length (pack_int24) and
            # the one-byte sequence number.
            return pack_int24(len(payload)) + int2byte(sequence) + payload

        self.client_flag |= CAPABILITIES
        if self.server_version.startswith('5'):
            self.client_flag |= MULTI_RESULTS

        if self.user is None:
            raise ValueError("Did not specify a username")

        charset_id = charset_by_name(self.charset).id
        if isinstance(self.user, text_type):
            self.user = self.user.encode(self.encoding)

        # Fixed-size prefix shared by the SSL request and the full packet:
        # client flags, the constant 1, the charset id and 23 zero bytes.
        prefix = (struct.pack('<iI', self.client_flag, 1)
                  + int2byte(charset_id)
                  + int2byte(0) * 23)

        sequence = 1

        if self.ssl:
            # Send the prefix alone, then switch the socket over to TLS.
            ssl_request = _frame(prefix, sequence)
            sequence += 1

            if DEBUG:
                dump_packet(ssl_request)

            self._write_bytes(ssl_request)
            self.socket = ssl.wrap_socket(
                self.socket,
                keyfile=self.key,
                certfile=self.cert,
                ssl_version=ssl.PROTOCOL_TLSv1,
                cert_reqs=ssl.CERT_REQUIRED,
                ca_certs=self.ca,
            )

        payload = prefix + self.user + b'\0' + self.forwarded_auth_response

        if self.db:
            if isinstance(self.db, text_type):
                self.db = self.db.encode(self.encoding)
            payload += self.db + int2byte(0)

        login_packet = _frame(payload, sequence)
        # Advance by two: the server's reply sits between this packet and
        # any follow-up packet sent below.
        sequence += 2

        if DEBUG:
            dump_packet(login_packet)

        self._write_bytes(login_packet)

        reply = MysqlPacket(self)
        reply.check_error()
        if DEBUG:
            reply.dump()

        # If old_passwords is enabled the reply is 1 byte long and holds
        # the octet 254, i.e. it looks like an EOF packet.
        if not reply.is_eof_packet():
            return

        # Send the legacy handshake.
        # TODO: is this the correct charset?
        scrambled = _scramble_323(self.password.encode(self.encoding),
                                  self.salt.encode(self.encoding)) + b'\0'
        self._write_bytes(_frame(scrambled, sequence))

        reply = MysqlPacket(self)
        reply.check_error()
        if DEBUG:
            reply.dump()