Example #1
0
 def convert_pseudotype(self, obj):
     """Translate a ReQL pseudo-type object into its Python representation.

     The ``$reql_type$`` key selects the handling.  TIME, GROUPED_DATA and
     BINARY objects consult their run option in ``self.reql_format_opts``
     (``time_format``, ``group_format``, ``binary_format``): when it is unset
     or ``'native'`` the matching ``convert_*`` method is called, ``'raw'``
     returns the object untouched, and any other value is rejected.
     GEOMETRY objects and objects without ``$reql_type$`` are returned as-is.

     Raises:
         ReqlDriverError: unknown pseudo-type or unknown format option.
     """
     pseudo_type = obj.get('$reql_type$')
     if pseudo_type is None or pseudo_type == 'GEOMETRY':
         # Plain object, or GEOMETRY (no special support): return it raw.
         return obj

     # (pseudo-type, run option, converter method, bad-option message)
     handlers = (
         ('TIME', 'time_format', 'convert_time',
          "Unknown time_format run option \"%s\"."),
         ('GROUPED_DATA', 'group_format', 'convert_grouped_data',
          "Unknown group_format run option \"%s\"."),
         ('BINARY', 'binary_format', 'convert_binary',
          "Unknown binary_format run option \"%s\"."),
     )
     for type_name, option, method_name, bad_option_msg in handlers:
         if pseudo_type == type_name:
             break
     else:
         raise ReqlDriverError("Unknown pseudo-type %s" % pseudo_type)

     fmt = self.reql_format_opts.get(option)
     if fmt is None or fmt == 'native':
         return getattr(self, method_name)(obj)
     if fmt != 'raw':
         raise ReqlDriverError(bad_option_msg % fmt)
     # The raw format was requested: hand back the original object.
     return obj
    def connect(self, timeout):
        """Connect through Twisted and wait until the handshake completes.

        Written as a generator using ``yield`` and ``returnValue``, so it is
        presumably decorated with ``inlineCallbacks`` where it is defined
        (decorator not visible here). On success ``self._connection`` is set
        and the parent connection object is returned via ``returnValue``.

        Raises:
            ReqlDriverError: the connection could not be opened, or the
                handshake was interrupted by an error other than auth/timeout.
            ReqlAuthError: re-raised unchanged from the handshake.
            ReqlTimeoutError: the handshake timed out (re-raised with this
                connection's host and port).
        """
        factory = DatabaseProtoFactory(timeout, self._handleResponse,
                                       self._parent.handshake)

        # Connect to the server; the handshake payload is sent as part of this.
        try:
            protocol = yield self._connectTimeout(factory, timeout)
        except Exception as err:
            raise ReqlDriverError(
                'Could not connect to {p.host}:{p.port}. Error: {exc}'.format(
                    p=self._parent, exc=str(err)))

        # Block until the handshake exchange has finished.
        try:
            yield protocol.wait_for_handshake
        except ReqlAuthError:
            raise
        except ReqlTimeoutError:
            raise ReqlTimeoutError(self._parent.host, self._parent.port)
        except Exception as err:
            raise ReqlDriverError(
                'Connection interrupted during handshake with {p.host}:{p.port}. Error: {exc}'
                .format(p=self._parent, exc=str(err)))

        self._connection = protocol
        returnValue(self._parent)
 def recvall(self, length):
     """Read from the socket until ``length`` bytes are available; return them.

     Bytes previously stashed in ``self._read_buffer`` form the start of the
     result, and the stash is cleared so they are returned exactly once.  If
     the stash alone already holds ``length`` bytes or more it is returned
     as-is.

     Raises:
         ReqlTimeoutError: propagated from the socket; the bytes read so far
             are stashed in ``self._read_buffer`` so a retry can resume.
         ReqlDriverError: the peer closed the connection or the read failed;
             the wrapper is closed before raising.
     """
     res = b'' if self._read_buffer is None else self._read_buffer
     # The stashed bytes are consumed by this call; leaving them in place
     # would make every later call return them again.
     self._read_buffer = None
     while len(res) < length:
         while True:
             try:
                 chunk = self._socket.recv(length - len(res))
                 break
             except ReqlTimeoutError:
                 # Keep what has been read so far instead of dropping it.
                 self._read_buffer = res
                 raise
             except IOError as ex:
                 if ex.errno == errno.ECONNRESET:
                     self.close()
                     raise ReqlDriverError("Connection is closed.")
                 elif ex.errno != errno.EINTR:
                     self.close()
                     raise ReqlDriverError(
                         'Connection interrupted receiving from %s:%s - %s'
                         % (self.host, self.port, str(ex)))
                 # EINTR: recv was interrupted by a signal; retry it.
             except Exception as ex:
                 self.close()
                 raise ReqlDriverError('Error receiving from %s:%s - %s' %
                                       (self.host, self.port, str(ex)))
         if not chunk:
             # An empty recv() result means the peer closed the connection.
             self.close()
             raise ReqlDriverError("Connection is closed.")
         res += chunk
     return res
 def next_message(self, response):
     """Advance the V0.4 handshake by one step and return the next request.

     State 0 (``response`` must be None): moves to state 1 and returns
     ``VERSION`` and the auth-key length packed as two little-endian 32-bit
     unsigned ints, then the auth key, then ``PROTOCOL`` packed the same way.
     State 1 (``response`` must be the server reply): moves to state 2 and
     returns None when the reply is ``b"SUCCESS"``; otherwise raises.

     Raises:
         ReqlAuthError: the server reported an incorrect authorization key.
         ReqlDriverError: a response was given/missing when it should not
             be, the server sent another error, or the state is unknown.
     """
     def decode_server_message(raw):
         # Best effort: decode as UTF-8 ignoring bad bytes; if decode() does
         # not accept the keyword argument, retry strictly, and fall back to
         # repr() when the bytes are not valid UTF-8.
         try:
             return raw.decode("utf-8", errors="ignore").strip()
         except TypeError:
             try:
                 return raw.decode("utf-8").strip()
             except UnicodeError:
                 return repr(raw).strip()

     state = self._state
     if state == 0:
         if response is not None:
             raise ReqlDriverError("Unexpected response")
         self._state = 1
         return (struct.pack("<2L", self.VERSION, len(self._auth_key))
                 + self._auth_key
                 + struct.pack("<L", self.PROTOCOL))

     if state != 1:
         raise ReqlDriverError("Unexpected handshake state")

     if response is None:
         raise ReqlDriverError("Expected response")
     self._state = 2
     if response == b"SUCCESS":
         return None

     message = decode_server_message(response)
     if message == "ERROR: Incorrect authorization key.":
         raise ReqlAuthError("Incorrect authentication key.",
                             self._host, self._port)
     raise ReqlDriverError(
         "Server dropped connection with message: \"%s\"" % (message, ))
Example #5
0
    def run(self, c=None, **global_optargs):
        """Run this query on connection ``c``, or on the REPL default.

        When ``c`` is omitted, ``Repl.get()`` supplies the connection
        registered with ``repl()`` for the current thread.  The result is
        whatever ``c._start(self, **global_optargs)`` returns.

        Raises:
            ReqlDriverError: no connection was given and none is registered
                for this thread (the message mentions it when ``repl()`` was
                used on another thread).
        """
        conn = Repl.get() if c is None else c
        if conn is not None:
            return conn._start(self, **global_optargs)

        if not Repl.repl_active:
            raise ReqlDriverError("RqlQuery.run must be given a connection to run on.")
        raise ReqlDriverError(
            "RqlQuery.run must be given a connection to run on. A default connection has been set with "
            "`repl()` on another thread, but not this one."
        )
    def close(self, noreply_wait=False, token=None, exception=None):
        """Close the Twisted connection and fail every outstanding request.

        Marks the connection as closing, errors every cached cursor and
        errbacks every user-query deferred that has not fired yet with a
        ``ReqlDriverError``, then resets both caches.  With ``noreply_wait`` a
        NOREPLY_WAIT query (using ``token``) is run first.  The transport is
        told to lose the connection once that query settles, whether it
        succeeded or failed.

        Returns:
            A Deferred carrying the NOREPLY_WAIT outcome (or None), fired
            after the transport has been told to disconnect.
        """
        self._closing = True
        reason = "Connection is closed"
        if exception is not None:
            reason = "Connection is closed (reason: {exc})".format(
                exc=str(exception))

        # Iterate over a snapshot of the cursor cache.
        for cursor in list(self._cursor_cache.values()):
            cursor._error(reason)

        for _query, pending in self._user_queries.values():
            if pending.called:
                continue
            pending.errback(fail=ReqlDriverError(reason))

        self._user_queries = {}
        self._cursor_cache = {}

        if noreply_wait:
            d = self.run_query(Query(pQuery.NOREPLY_WAIT, token, None, None), False)
        else:
            d = defer.succeed(None)

        def drop_transport(result):
            # Pass the result through unchanged to the caller's callbacks.
            self._connection.transport.loseConnection()
            return result

        return d.addBoth(drop_transport)
 def _handleResponse(self, token, data):
     try:
         cursor = self._cursor_cache.get(token)
         if cursor is not None:
             cursor._extend(data)
         elif token in self._user_queries:
             query, deferred = self._user_queries[token]
             res = Response(token, data,
                            self._parent._get_json_decoder(query))
             if res.type == pResponse.SUCCESS_ATOM:
                 deferred.callback(maybe_profile(res.data[0], res))
             elif res.type in (pResponse.SUCCESS_SEQUENCE,
                               pResponse.SUCCESS_PARTIAL):
                 cursor = TwistedCursor(self, query, res)
                 deferred.callback(maybe_profile(cursor, res))
             elif res.type == pResponse.WAIT_COMPLETE:
                 deferred.callback(None)
             elif res.type == pResponse.SERVER_INFO:
                 deferred.callback(res.data[0])
             else:
                 deferred.errback(res.make_error(query))
             del self._user_queries[token]
         elif not self._closing:
             raise ReqlDriverError("Unexpected response received.")
     except Exception as e:
         if not self._closing:
             self.close(exception=e)
    def close(self, noreply_wait=False, token=None, exception=None):
        """Close the Tornado connection and fail every outstanding request.

        Marks the connection as closing, errors every cached cursor and gives
        every unresolved user-query future a ``ReqlDriverError``, then resets
        both caches.  With ``noreply_wait`` a NOREPLY_WAIT query (using
        ``token``) is run and waited for before the stream is closed.
        Written as a generator finishing with ``gen.Return(None)``, so it is
        presumably a Tornado coroutine where it is defined (decorator not
        visible here).

        Args:
            noreply_wait: run a NOREPLY_WAIT query before closing the stream.
            token: token for that NOREPLY_WAIT query.
            exception: error that triggered the close, if any; its text is
                included in the message delivered to cursors and futures.
        """
        self._closing = True
        if exception is not None:
            err_message = "Connection is closed (%s)." % str(exception)
        else:
            err_message = "Connection is closed."

        # Cursors may remove themselves when errored, so copy a list of them
        for cursor in list(self._cursor_cache.values()):
            cursor._error(err_message)

        for query, future in list(self._user_queries.values()):
            # A future can already be resolved (e.g. cancelled by its caller).
            # asyncio-style futures (Tornado >= 5) raise InvalidStateError when
            # resolved twice, which would abort the close half-way.
            if not future.done():
                future.set_exception(ReqlDriverError(err_message))

        self._user_queries = {}
        self._cursor_cache = {}

        if noreply_wait:
            noreply = Query(pQuery.NOREPLY_WAIT, token, None, None)
            yield self.run_query(noreply, False)

        try:
            self._stream.close()
        except iostream.StreamClosedError:
            pass
        raise gen.Return(None)
Example #9
0
 def _wait_to_timeout(wait):
     if isinstance(wait, bool):
         return None if wait else 0
     elif isinstance(wait, numbers.Real) and wait >= 0:
         return wait
     else:
         raise ReqlDriverError("Invalid wait timeout '%s'" % str(wait))
    def close(self, noreply_wait=False, token=None, exception=None):
        """Close the asyncio connection and fail every outstanding request.

        Generator-based coroutine (uses ``yield from``).  It does the following:
        - marks the connection as closing;
        - errors every cached cursor;
        - gives every unresolved user-query future a ``ReqlDriverError``;
        - resets both caches.

        With ``noreply_wait`` a NOREPLY_WAIT query (using ``token``) is run
        first.  The stream writer is then closed, and the reader task is
        awaited unless ``exception`` was given.

        Returns:
            None.
        """
        self._closing = True
        if exception is None:
            err_message = "Connection is closed."
        else:
            err_message = "Connection is closed (%s)." % str(exception)

        # Erroring a cursor can remove it from the cache; walk a snapshot.
        for cursor in list(self._cursor_cache.values()):
            cursor._error(err_message)

        for _query, future in self._user_queries.values():
            if future.done():
                continue
            future.set_exception(ReqlDriverError(err_message))

        self._user_queries = {}
        self._cursor_cache = {}

        if noreply_wait:
            yield from self.run_query(
                Query(pQuery.NOREPLY_WAIT, token, None, None), False)

        self._streamwriter.close()
        # A non-None exception means we were called from the reader task
        # itself; awaiting that task here would deadlock.
        reader_task = self._reader_task
        if reader_task and exception is None:
            yield from reader_task

        return None
 def __init__(self, *args, **kwargs):
     """Initialise the connection with ``ConnectionInstance`` as its
     instance class, then coerce ``self.port`` to ``int``.

     All arguments are forwarded unchanged to the base initializer.

     Raises:
         ReqlDriverError: ``self.port`` cannot be converted with ``int()``.
     """
     super(Connection, self).__init__(ConnectionInstance, *args, **kwargs)
     raw_port = self.port
     try:
         self.port = int(raw_port)
     except ValueError:
         raise ReqlDriverError("Could not convert port %s to an integer." %
                               raw_port)
 def dataReceived(self, data):
     """Dispatch received bytes to the handler for the current state.

     Data arriving while the protocol is not open is ignored.  If the handler
     fails, the transport is dropped and a ``ReqlDriverError`` describing the
     failure is raised (chained to the original exception).
     """
     try:
         if self._open:
             self._handlers[self.state](data)
     except Exception as e:
         self.transport.loseConnection()
         # Note the trailing space in the first literal: adjacent literals are
         # joined verbatim, so without it the message read "data.Error:".
         raise ReqlDriverError('Driver failed to handle received data. '
                               'Error: {exc}. Dropping the connection.'.format(exc=str(e))) from e
Example #13
0
 def convert_grouped_data(self, obj):
     """Convert a GROUPED_DATA pseudo-type object into a ``dict``.

     ``obj['data']`` is iterated as (key, value) pairs.  Each key is passed
     through ``recursively_make_hashable`` so that it can be used as a dict key.

     Raises:
         ReqlDriverError: ``obj`` has no ``"data"`` field.
     """
     if 'data' in obj:
         return {recursively_make_hashable(key): value
                 for key, value in obj['data']}
     raise ReqlDriverError(
         'pseudo-type GROUPED_DATA object %s does not have the expected '
         'field "data".' % json.dumps(obj))
 def sendall(self, data):
     """Send every byte of ``data``, looping over partial ``send`` calls.

     A send interrupted by a signal (EINTR) is retried.

     Raises:
         ReqlDriverError: the connection was reset or the send failed; the
             wrapper is closed before raising.
     """
     sent = 0
     while sent < len(data):
         try:
             sent += self._socket.send(data[sent:])
         except IOError as ex:
             if ex.errno == errno.EINTR:
                 continue  # interrupted by a signal: retry the same slice
             self.close()
             if ex.errno == errno.ECONNRESET:
                 raise ReqlDriverError("Connection is closed.")
             raise ReqlDriverError(
                 'Connection interrupted sending to %s:%s - %s' %
                 (self.host, self.port, str(ex)))
         except Exception as ex:
             self.close()
             raise ReqlDriverError('Error sending to %s:%s - %s' %
                                   (self.host, self.port, str(ex)))
Example #15
0
 def make_error(self, query):
     """Return (not raise) the exception matching this error response.

     The exception class depends on the response type:
     - CLIENT_ERROR: ``ReqlDriverError``;
     - COMPILE_ERROR: ``ReqlServerCompileError``;
     - RUNTIME_ERROR: the class selected by ``self.error_type``, or
       ``ReqlRuntimeError`` when it is not recognised.

     These are built from ``self.data[0]``, ``query.term`` and
     ``self.backtrace``.  Any other response type yields a
     ``ReqlDriverError`` naming the unknown type.
     """
     response_type = self.type
     if response_type == pResponse.CLIENT_ERROR:
         error_cls = ReqlDriverError
     elif response_type == pResponse.COMPILE_ERROR:
         error_cls = ReqlServerCompileError
     elif response_type == pResponse.RUNTIME_ERROR:
         error_cls = {
             pErrorType.INTERNAL: ReqlInternalError,
             pErrorType.RESOURCE_LIMIT: ReqlResourceLimitError,
             pErrorType.QUERY_LOGIC: ReqlQueryLogicError,
             pErrorType.NON_EXISTENCE: ReqlNonExistenceError,
             pErrorType.OP_FAILED: ReqlOpFailedError,
             pErrorType.OP_INDETERMINATE: ReqlOpIndeterminateError,
             pErrorType.USER: ReqlUserError,
             pErrorType.PERMISSION_ERROR: ReqlPermissionError
         }.get(self.error_type, ReqlRuntimeError)
     else:
         return ReqlDriverError(("Unknown Response type %d encountered" +
                                 " in a response.") % response_type)
     return error_cls(self.data[0], query.term, self.backtrace)
Example #16
0
    def convert_time(self, obj):
        """Convert a TIME pseudo-type object into a ``datetime.datetime``.

        ``obj['epoch_time']`` is a POSIX timestamp.  With a ``timezone`` field
        the result is timezone-aware, using ``RqlTzinfo(obj['timezone'])``.
        Without one it is a naive datetime expressed in UTC.

        Raises:
            ReqlDriverError: ``obj`` has no ``epoch_time`` field.
        """
        if 'epoch_time' not in obj:
            raise ReqlDriverError(('pseudo-type TIME object %s does not ' +
                                   'have expected field "epoch_time".')
                                  % json.dumps(obj))

        epoch_time = obj['epoch_time']
        if 'timezone' in obj:
            return datetime.datetime.fromtimestamp(epoch_time,
                                                   RqlTzinfo(obj['timezone']))
        # datetime.utcfromtimestamp() is deprecated since Python 3.12: build
        # an aware UTC value and drop tzinfo to keep returning a naive datetime.
        return datetime.datetime.fromtimestamp(
            epoch_time, datetime.timezone.utc).replace(tzinfo=None)
Example #17
0
    def __init__(self, conn_type, host, port, db, auth_key, user, password, timeout, ssl, _handshake_version, **kwargs):
        """Record connection settings and build the handshake object.

        Args:
            conn_type: stored as ``self._conn_type``.
            host: stored as ``self.host``.
            port: coerced to ``int`` and stored as ``self.port``.
            db: stored as ``self.db``.
            auth_key: credential; at most one of ``auth_key`` and
                ``password`` may be given.  Whichever is given is used for
                both, and both become ``''`` when neither is.
            password: see ``auth_key``.
            user: passed to the V1.0 handshake.
            timeout: stored as ``self.connect_timeout``.
            ssl: stored as ``self.ssl``.
            _handshake_version: ``4`` selects ``HandshakeV0_4``; any other
                value selects ``HandshakeV1_0``.
            **kwargs: kept as ``self._child_kwargs``.  ``json_encoder`` and
                ``json_decoder`` are popped from it and stored as
                ``self._json_encoder`` / ``self._json_decoder``.

        Raises:
            ReqlDriverError: ``port`` is not convertible to ``int``, or both
                ``auth_key`` and ``password`` are set.
        """
        self.db = db
        self.host = host
        try:
            self.port = int(port)
        except ValueError:
            raise ReqlDriverError("Could not convert port %r to an integer." % port)
        self.connect_timeout = timeout
        self.ssl = ssl

        self._conn_type = conn_type
        # Same dict object as kwargs: the json_* pops below also affect it.
        self._child_kwargs = kwargs
        self._instance = None
        self._next_token = 0

        for option, attribute in (('json_encoder', '_json_encoder'),
                                  ('json_decoder', '_json_decoder')):
            if option in kwargs:
                setattr(self, attribute, kwargs.pop(option))

        if auth_key is not None and password is not None:
            raise ReqlDriverError("`auth_key` and `password` are both set.")
        # Whichever credential was supplied doubles as the other one.
        if auth_key is None:
            auth_key = '' if password is None else password
        if password is None:
            password = auth_key

        if _handshake_version == 4:
            self.handshake = HandshakeV0_4(self.host, self.port, auth_key)
        else:
            self.handshake = HandshakeV1_0(
                self._json_decoder(), self._json_encoder(),
                self.host, self.port, user, password)
Example #18
0
    def recvall(self, length, deadline):
        """Read until ``length`` bytes are available, honouring ``deadline``.

        ``deadline`` is an absolute ``time.time()`` value; ``None`` means no
        limit.  Bytes stashed in ``self._read_buffer`` by an earlier timed-out
        call form the start of the result, and the stash is cleared.  On
        timeout the bytes read so far are stashed again so a later call can
        resume.  The socket timeout is recomputed before every ``recv`` and
        reset to ``None`` before returning or raising a timeout.

        Raises:
            ReqlTimeoutError: the deadline passed before ``length`` bytes
                arrived (the connection stays open).
            ReqlDriverError: the peer closed the connection or the read
                failed; the wrapper is closed first.
        """
        res = b'' if self._read_buffer is None else self._read_buffer
        # The stash now lives in `res`; it is re-stashed only on timeout.
        self._read_buffer = None
        while len(res) < length:
            while True:
                # Recompute before each recv so the deadline bounds the whole
                # read, not only the first chunk.
                timeout = None if deadline is None else max(0, deadline - time.time())
                self._socket.settimeout(timeout)
                try:
                    chunk = self._socket.recv(length - len(res))
                    break
                except socket.timeout:
                    self._read_buffer = res
                    self._socket.settimeout(None)
                    raise ReqlTimeoutError(self.host, self.port)
                except IOError as ex:
                    if ex.errno == errno.EINTR:
                        # Interrupted by a signal: retry with a fresh timeout.
                        continue
                    if ex.errno == errno.EWOULDBLOCK:
                        # A timeout of 0 puts the socket in non-blocking mode,
                        # so an expired deadline surfaces here; treat it like
                        # socket.timeout.
                        self._read_buffer = res
                        self._socket.settimeout(None)
                        raise ReqlTimeoutError(self.host, self.port)
                    self.close()
                    if ex.errno == errno.ECONNRESET:
                        raise ReqlDriverError("Connection is closed.")
                    raise ReqlDriverError(('Connection interrupted ' +
                                           'receiving from %s:%s - %s') %
                                          (self.host, self.port, str(ex)))
                except Exception as ex:
                    self.close()
                    raise ReqlDriverError('Error receiving from %s:%s - %s' %
                                          (self.host, self.port, str(ex)))

            if len(chunk) == 0:
                # An empty recv() result means the peer closed the connection.
                self.close()
                raise ReqlDriverError("Connection is closed.")
            res += chunk
        self._socket.settimeout(None)
        return res
Example #19
0
    def _read_response(self, query, deadline=None):
        token = query.token
        # We may get an async continue result, in which case we save
        # it and read the next response
        while True:
            try:
                # The first 8 bytes give the corresponding query token
                # of this response.  The next 4 bytes give the
                # expected length of this response.
                if self._header_in_progress is None:
                    self._header_in_progress \
                        = self._socket.recvall(12, deadline)
                (res_token, res_len,) \
                    = struct.unpack("<qL", self._header_in_progress)
                res_buf = self._socket.recvall(res_len, deadline)
                self._header_in_progress = None
            except KeyboardInterrupt as ex:
                # Cancel outstanding queries by dropping this connection,
                # then create a new connection for the user's convenience.
                self._parent.reconnect(noreply_wait=False)
                raise ex

            res = None

            cursor = self._cursor_cache.get(res_token)
            if cursor is not None:
                # Construct response
                cursor._extend(res_buf)
                if res_token == token:
                    return res
            elif res_token == token:
                return Response(
                    res_token, res_buf,
                    self._parent._get_json_decoder(query))
            elif not self._closing:
                # This response is corrupted or not intended for us
                self.close()
                raise ReqlDriverError("Unexpected response received.")
    def _reader(self):
        try:
            while True:
                buf = yield self._stream.read_bytes(12)
                (
                    token,
                    length,
                ) = struct.unpack("<qL", buf)
                buf = yield self._stream.read_bytes(length)

                cursor = self._cursor_cache.get(token)
                if cursor is not None:
                    cursor._extend(buf)
                elif token in self._user_queries:
                    # Do not pop the query from the dict until later, so
                    # we don't lose track of it in case of an exception
                    query, future = self._user_queries[token]
                    res = Response(token, buf,
                                   self._parent._get_json_decoder(query))
                    if res.type == pResponse.SUCCESS_ATOM:
                        future.set_result(maybe_profile(res.data[0], res))
                    elif res.type in (pResponse.SUCCESS_SEQUENCE,
                                      pResponse.SUCCESS_PARTIAL):
                        cursor = TornadoCursor(self, query, res)
                        future.set_result(maybe_profile(cursor, res))
                    elif res.type == pResponse.WAIT_COMPLETE:
                        future.set_result(None)
                    elif res.type == pResponse.SERVER_INFO:
                        future.set_result(res.data[0])
                    else:
                        future.set_exception(res.make_error(query))
                    del self._user_queries[token]
                elif not self._closing:
                    raise ReqlDriverError("Unexpected response received.")
        except Exception as ex:
            if not self._closing:
                yield self.close(exception=ex)
Example #21
0
 def check_open(self):
     """Raise ``ReqlDriverError`` unless a connection instance exists and is open."""
     instance = self._instance
     if instance is not None and instance.is_open():
         return
     raise ReqlDriverError('Connection is closed.')
    def connect(self, timeout):
        """Open the Tornado stream to the server and run the handshake.

        Generator finishing with ``gen.Return``, so it is presumably a Tornado
        coroutine where it is defined (decorator not visible here).  Connect
        and handshake must both finish before an absolute deadline of
        ``timeout`` seconds from now on ``self._io_loop``'s clock; ``None``
        means no deadline.  On success ``self._reader`` is scheduled on the
        IO loop and the parent connection is returned.

        Raises:
            ReqlDriverError: the TCP/TLS connection failed, or the handshake
                was interrupted by an error other than auth/timeout.
            ReqlAuthError: the handshake rejected the credentials.
            ReqlTimeoutError: the handshake did not finish in time.
        """
        deadline = None if timeout is None else self._io_loop.time() + timeout

        try:
            if len(self._parent.ssl) > 0:
                ssl_options = {}
                if self._parent.ssl["ca_certs"]:
                    ssl_options['ca_certs'] = self._parent.ssl["ca_certs"]
                    ssl_options['cert_reqs'] = 2  # ssl.CERT_REQUIRED
                stream_future = TCPClient().connect(self._parent.host,
                                                    self._parent.port,
                                                    ssl_options=ssl_options)
            else:
                stream_future = TCPClient().connect(self._parent.host,
                                                    self._parent.port)

            self._stream = yield with_absolute_timeout(
                deadline,
                stream_future,
                io_loop=self._io_loop,
                quiet_exceptions=(iostream.StreamClosedError,))
        except Exception as err:
            raise ReqlDriverError(
                'Could not connect to %s:%s. Error: %s' %
                (self._parent.host, self._parent.port, str(err)))

        self._stream.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY,
                                       1)
        self._stream.socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE,
                                       1)

        try:
            self._parent.handshake.reset()
            response = None
            while True:
                request = self._parent.handshake.next_message(response)
                if request is None:
                    break
                # An empty request happens in the `V1_0` protocol, where two
                # requests are sent at once as an optimization and each reply
                # is then read separately.  Test emptiness by value: an `is`
                # comparison against a string literal depends on interning.
                if request:
                    self._stream.write(request)

                response = yield with_absolute_timeout(
                    deadline,
                    self._stream.read_until(b'\0'),
                    io_loop=self._io_loop,
                    quiet_exceptions=(iostream.StreamClosedError,))
                # Drop the NUL terminator that read_until() includes.
                response = response[:-1]
        except ReqlAuthError:
            try:
                self._stream.close()
            except iostream.StreamClosedError:
                pass
            raise
        except ReqlTimeoutError:
            try:
                self._stream.close()
            except iostream.StreamClosedError:
                pass
            raise ReqlTimeoutError(self._parent.host, self._parent.port)
        except Exception as err:
            try:
                self._stream.close()
            except iostream.StreamClosedError:
                pass
            raise ReqlDriverError(
                'Connection interrupted during handshake with %s:%s. Error: %s'
                % (self._parent.host, self._parent.port, str(err)))

        # Start a parallel function to perform reads
        self._io_loop.add_callback(self._reader)
        raise gen.Return(self._parent)
    def __init__(self, parent):
        """Open a blocking (optionally TLS) socket and run the handshake.

        ``host``, ``port`` and ``ssl`` are read from ``parent._parent``, which
        also provides the ``handshake`` object driving the exchange.  Each
        server reply is a NUL-terminated string, read one byte at a time.

        Raises:
            ReqlAuthError: propagated after closing the socket.
            ReqlTimeoutError: propagated after closing the socket.
            ReqlDriverError: TLS setup, handshake I/O or the connection
                itself failed; the socket is closed first.
        """
        self.host = parent._parent.host
        self.port = parent._parent.port
        self._read_buffer = None
        self._socket = None
        self.ssl = parent._parent.ssl

        try:
            self._socket = socket.create_connection((self.host, self.port))
            self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)

            if len(self.ssl) > 0:
                # Only the legacy ssl.wrap_socket path needs an explicit
                # hostname check: an SSLContext with check_hostname=True
                # verifies the name during the TLS handshake (and
                # ssl.match_hostname no longer exists in Python 3.12+).
                needs_hostname_check = False
                try:
                    if hasattr(ssl, 'SSLContext'):  # Python2.7 and 3.2+, or backports.ssl
                        ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
                        if hasattr(ssl_context, "options"):
                            ssl_context.options |= getattr(
                                ssl, "OP_NO_SSLv2", 0)
                            ssl_context.options |= getattr(
                                ssl, "OP_NO_SSLv3", 0)
                        # Configure the local context (it is not stored on self).
                        ssl_context.verify_mode = ssl.CERT_REQUIRED
                        ssl_context.check_hostname = True
                        ssl_context.load_verify_locations(
                            self.ssl["ca_certs"])
                        self._socket = ssl_context.wrap_socket(
                            self._socket, server_hostname=self.host)
                    else:  # this does not disable SSLv2 or SSLv3
                        self._socket = ssl.wrap_socket(
                            self._socket,
                            cert_reqs=ssl.CERT_REQUIRED,
                            ssl_version=ssl.PROTOCOL_SSLv23,
                            ca_certs=self.ssl["ca_certs"])
                        needs_hostname_check = True
                except IOError as exc:
                    self._socket.close()
                    raise ReqlDriverError(
                        "SSL handshake failed (see server log for more information): %s"
                        % str(exc))
                if needs_hostname_check:
                    try:
                        ssl.match_hostname(self._socket.getpeercert(),
                                           hostname=self.host)
                    except ssl.CertificateError:
                        self._socket.close()
                        raise

            parent._parent.handshake.reset()
            response = None
            while True:
                request = parent._parent.handshake.next_message(response)
                if request is None:
                    break
                # An empty request happens in the `V1_0` protocol, where two
                # requests are sent at once as an optimization and each reply
                # is then read separately.  Test emptiness by value: an `is`
                # comparison against a string literal depends on interning.
                if request:
                    self.sendall(request)

                # The response from the server is a null-terminated string
                response = b''
                while True:
                    char = self.recvall(1)
                    if char == b'\0':
                        break
                    response += char
        except (ReqlAuthError, ReqlTimeoutError):
            self.close()
            raise
        except ReqlDriverError as ex:
            self.close()
            error = str(ex) \
                .replace('receiving from', 'during handshake with') \
                .replace('sending to', 'during handshake with')
            raise ReqlDriverError(error)
        except Exception as ex:
            self.close()
            raise ReqlDriverError("Could not connect to %s:%s. Error: %s" %
                                  (self.host, self.port, ex))
Example #24
0
    def __init__(self, parent, timeout):
        """Open a blocking (optionally TLS) socket and run the handshake.

        ``host``, ``port`` and ``ssl`` are read from ``parent._parent``, which
        also provides the ``handshake`` object driving the exchange.
        ``timeout`` (seconds) is passed to ``socket.create_connection``.  It
        also sets the absolute deadline for reading handshake replies.
        ``None`` gives a blocking socket and no reply deadline.

        Raises:
            ReqlAuthError: propagated after closing the socket.
            ReqlTimeoutError: the connect or handshake timed out.
            ReqlDriverError: TLS setup, handshake I/O or the connection
                failed; the socket is closed first.
        """
        self.host = parent._parent.host
        self.port = parent._parent.port
        self._read_buffer = None
        self._socket = None
        self.ssl = parent._parent.ssl

        # create_connection() accepts timeout=None; mirror that here instead
        # of failing with a TypeError on `time.time() + None`.
        deadline = None if timeout is None else time.time() + timeout

        try:
            self._socket = socket.create_connection((self.host, self.port), timeout)
            self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)

            if len(self.ssl) > 0:
                try:
                    if hasattr(ssl, 'SSLContext'):  # Python2.7 and 3.2+, or backports.ssl
                        ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
                        if hasattr(ssl_context, "options"):
                            ssl_context.options |= getattr(ssl, "OP_NO_SSLv2", 0)
                            ssl_context.options |= getattr(ssl, "OP_NO_SSLv3", 0)
                        ssl_context.verify_mode = ssl.CERT_REQUIRED
                        ssl_context.check_hostname = True  # redundant with match_hostname
                        ssl_context.load_verify_locations(self.ssl["ca_certs"])
                        self._socket = ssl_context.wrap_socket(self._socket, server_hostname=self.host)
                    else:  # this does not disable SSLv2 or SSLv3
                        self._socket = ssl.wrap_socket(
                            self._socket, cert_reqs=ssl.CERT_REQUIRED, ssl_version=ssl.PROTOCOL_SSLv23,
                            ca_certs=self.ssl["ca_certs"])
                except IOError as err:
                    self._socket.close()

                    if 'EOF occurred in violation of protocol' in str(
                            err) or 'sslv3 alert handshake failure' in str(err):
                        # probably on an older version of OpenSSL
                        raise ReqlDriverError(
                            "SSL handshake failed, likely because Python is linked against an old version of OpenSSL "
                            "that does not support either TLSv1.2 or any of the allowed ciphers. This can be worked "
                            "around by lowering the security setting on the server with the options "
                            "`--tls-min-protocol TLSv1 --tls-ciphers "
                            "EECDH+AESGCM:EDH+AESGCM:AES256+EECDH:AES256+EDH:AES256-SHA` (see server log for more "
                            "information): %s" % str(err)
                        )
                    else:
                        raise ReqlDriverError(
                            "SSL handshake failed (see server log for more information): %s" %
                            str(err))
                try:
                    match_hostname(self._socket.getpeercert(), hostname=self.host)
                except CertificateError:
                    self._socket.close()
                    raise

            parent._parent.handshake.reset()
            response = None
            while True:
                request = parent._parent.handshake.next_message(response)
                if request is None:
                    break
                # An empty request happens in the `V1_0` protocol, where two
                # requests are sent at once as an optimization and each reply
                # is then read separately.  Test emptiness by value: an `is`
                # comparison against a string literal depends on interning.
                if request:
                    self.sendall(request)

                # The response from the server is a null-terminated string
                response = b''
                while True:
                    char = self.recvall(1, deadline)
                    if char == b'\0':
                        break
                    response += char
        except (ReqlAuthError, ReqlTimeoutError):
            self.close()
            raise
        except ReqlDriverError as ex:
            self.close()
            error = str(ex)\
                .replace('receiving from', 'during handshake with')\
                .replace('sending to', 'during handshake with')
            raise ReqlDriverError(error)
        except socket.timeout:
            self.close()
            raise ReqlTimeoutError(self.host, self.port)
        except Exception as ex:
            self.close()
            raise ReqlDriverError("Could not connect to %s:%s. Error: %s" %
                                  (self.host, self.port, str(ex)))
Example #25
0
 def convert_binary(obj):
     """Decode a BINARY pseudo-type object into an ``RqlBinary``.

     ``obj['data']`` is base64 text.  It is UTF-8 encoded, base64-decoded,
     and the resulting bytes are wrapped in ``RqlBinary``.

     Raises:
         ReqlDriverError: ``obj`` has no ``"data"`` field.
     """
     if 'data' in obj:
         payload = obj['data'].encode('utf-8')
         return RqlBinary(base64.b64decode(payload))
     raise ReqlDriverError('pseudo-type BINARY object %s does not have '
                           'the expected field "data".' % json.dumps(obj))
    def connect(self, timeout):
        """Open the (optionally TLS) stream to the server and run the handshake.

        This is a generator-based coroutine (it uses ``yield from``) driven on
        ``self._io_loop``.

        :param timeout: seconds to wait for each handshake reply; passed to
            ``asyncio.wait_for``.
        :returns: ``self._parent`` once the handshake finished and the
            background reader task has been scheduled.
        :raises ReqlAuthError: re-raised unchanged when authentication fails.
        :raises ReqlDriverError: on any other connection or handshake failure
            (including handshake timeouts).
        """
        try:
            ssl_context = None
            if len(self._parent.ssl) > 0:
                ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
                # Opt out of the legacy SSLv2/SSLv3 protocols when the
                # running ssl module exposes the corresponding flags.
                if hasattr(ssl_context, "options"):
                    ssl_context.options |= getattr(ssl, "OP_NO_SSLv2", 0)
                    ssl_context.options |= getattr(ssl, "OP_NO_SSLv3", 0)
                ssl_context.verify_mode = ssl.CERT_REQUIRED
                ssl_context.check_hostname = True  # redundant with match_hostname
                ssl_context.load_verify_locations(self._parent.ssl["ca_certs"])

            self._streamreader, self._streamwriter = yield from asyncio.open_connection(
                self._parent.host,
                self._parent.port,
                loop=self._io_loop,
                ssl=ssl_context)
            # Disable Nagle's algorithm and enable TCP keepalive.
            self._streamwriter.get_extra_info('socket').setsockopt(
                socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            self._streamwriter.get_extra_info('socket').setsockopt(
                socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
        except Exception as err:
            raise ReqlDriverError(
                'Could not connect to %s:%s. Error: %s' %
                (self._parent.host, self._parent.port, str(err)))

        try:
            self._parent.handshake.reset()
            response = None
            with translate_timeout_errors():
                while True:
                    request = self._parent.handshake.next_message(response)
                    if request is None:
                        break
                    # The `V1_0` protocol sends two requests at once as an
                    # optimization and then returns an empty request so the
                    # second reply can be read separately; nothing is written
                    # in that case. (Truthiness test instead of the previous
                    # `is not ""` identity comparison against a literal.)
                    if request:
                        self._streamwriter.write(request)

                    response = yield from asyncio.wait_for(
                        _read_until(self._streamreader, b'\0'),
                        timeout,
                        loop=self._io_loop,
                    )
                    # Drop the final byte; assumes _read_until returns the
                    # data including the b'\0' terminator -- TODO confirm.
                    response = response[:-1]
        except ReqlAuthError:
            yield from self.close()
            raise
        except ReqlTimeoutError as err:
            yield from self.close()
            raise ReqlDriverError(
                'Connection interrupted during handshake with %s:%s. Error: %s'
                % (self._parent.host, self._parent.port, str(err)))
        except Exception as err:
            yield from self.close()
            raise ReqlDriverError(
                'Could not connect to %s:%s. Error: %s' %
                (self._parent.host, self._parent.port, str(err)))

        # Start a parallel function to perform reads
        #  store a reference to it so it doesn't get destroyed
        self._reader_task = asyncio.ensure_future(self._reader(),
                                                  loop=self._io_loop)
        return self._parent
Exemple #27
0
 def __iter__(*args, **kwargs):
     """Reject Python-level iteration over an unevaluated query.

     Always raises ``ReqlDriverError`` with a hint to call ``run`` first, or
     to use ``map``/``for_each`` for iteration inside the query.
     """
     hint = ("__iter__ called on an RqlQuery object.\n"
             "To iterate over the results of a query, call run first.\n"
             "To iterate inside a query, use map or for_each.")
     raise ReqlDriverError(hint)
    def next_message(self, response):
        """Advance the SCRAM-SHA-256 handshake and return the next request.

        The handshake is a state machine keyed on ``self._state``:

        * 0 -> 1: create the client nonce and return the protocol version
          header together with the initial JSON (client-first-message); both
          are sent at once as an optimization.
        * 1 -> 2: validate the server's protocol-version reply and return
          ``""`` so the caller only reads the next reply.
        * 2 -> 3: parse the server-first-message, compute the client proof,
          remember the expected server signature and return the
          client-final-message.
        * 3 -> 4: verify the server signature and return ``None`` to signal
          that the handshake is complete.

        :param response: the previous server reply as bytes without its NUL
            terminator, or ``None`` on the first call.
        :returns: bytes to send, ``""`` when nothing needs sending, or ``None``
            once the handshake is done.
        :raises ReqlAuthError: when the server reports error codes 10-20, the
            server nonce does not extend ours, or the server signature is wrong.
        :raises ReqlDriverError: for other server errors, missing keys, an
            unsupported protocol version or an unexpected state.
        """
        if self._state == 0:
            if response is not None:
                raise ReqlDriverError("Unexpected response")

            # Using base64 encoding for printable characters
            self._r = base64.standard_b64encode(
                bytes(bytearray(
                    self._random.getrandbits(8) for _ in range(18))))

            self._client_first_message_bare = b"n=" + self._username + b",r=" + self._r

            # Here we send the version as well as the initial JSON as an optimization
            self._state = 1
            return struct.pack("<L", self.VERSION) + \
                self._json_encoder.encode({
                    "protocol_version": self._protocol_version,
                    "authentication_method": "SCRAM-SHA-256",
                    "authentication":
                        (b"n,," + self._client_first_message_bare).decode("ascii")
                }).encode("utf-8") + \
                b'\0'
        elif self._state == 1:
            response = response.decode("utf-8")
            if response.startswith("ERROR"):
                raise ReqlDriverError(
                    "Received an unexpected reply. You may be attempting to connect to a RethinkDB server that is too "
                    "old for this driver.  The minimum supported server version is 2.3.0."
                )
            reply = self._json_decoder.decode(response)
            try:
                self._raise_on_handshake_error(reply)

                min_version = reply["min_protocol_version"]
                max_version = reply["max_protocol_version"]
                if not min_version <= self._protocol_version <= max_version:
                    raise ReqlDriverError(
                        "Unsupported protocol version %d, expected between %d and %d"
                        % (self._protocol_version, min_version, max_version))
            except KeyError as key_error:
                raise ReqlDriverError("Missing key: %s" % (key_error, ))

            # We've already sent the initial JSON above, and only support a single
            # protocol version at the moment thus we simply read the next response
            self._state = 2
            return ""
        elif self._state == 2:
            reply = self._json_decoder.decode(response.decode("utf-8"))
            try:
                self._raise_on_handshake_error(reply)

                server_first_message = reply["authentication"].encode("ascii")
                authentication = dict(
                    x.split(b"=", 1) for x in server_first_message.split(b","))

                # The server nonce must extend the client nonce we sent.
                server_nonce = authentication[b"r"]
                if not server_nonce.startswith(self._r):
                    raise ReqlAuthError("Invalid nonce from server",
                                        self._host, self._port)
                salt = base64.standard_b64decode(authentication[b"s"])
                iterations = int(authentication[b"i"])
            except KeyError as key_error:
                raise ReqlDriverError("Missing key: %s" % (key_error, ))

            client_final_message_without_proof = b"c=biws,r=" + server_nonce

            # SaltedPassword := Hi(Normalize(password), salt, i)
            salted_password = self._pbkdf2_hmac("sha256", self._password, salt,
                                                iterations)

            # ClientKey := HMAC(SaltedPassword, "Client Key")
            client_key = hmac.new(salted_password, b"Client Key",
                                  hashlib.sha256).digest()

            # StoredKey := H(ClientKey)
            stored_key = hashlib.sha256(client_key).digest()

            # AuthMessage := client-first-message-bare + "," +
            #                server-first-message + "," +
            #                client-final-message-without-proof
            auth_message = b",".join(
                (self._client_first_message_bare, server_first_message,
                 client_final_message_without_proof))

            # ClientSignature := HMAC(StoredKey, AuthMessage)
            client_signature = hmac.new(stored_key, auth_message,
                                        hashlib.sha256).digest()

            # ClientProof := ClientKey XOR ClientSignature
            client_proof = struct.pack(
                "32B",
                *(key_byte ^ sig_byte
                  for key_byte, sig_byte in zip(
                      struct.unpack("32B", client_key),
                      struct.unpack("32B", client_signature))))

            # ServerKey := HMAC(SaltedPassword, "Server Key")
            server_key = hmac.new(salted_password, b"Server Key",
                                  hashlib.sha256).digest()

            # ServerSignature := HMAC(ServerKey, AuthMessage)
            self._server_signature = hmac.new(server_key, auth_message,
                                              hashlib.sha256).digest()

            self._state = 3
            return self._json_encoder.encode({
                "authentication": (
                    client_final_message_without_proof +
                    b",p=" + base64.standard_b64encode(client_proof)
                ).decode("ascii")
            }).encode("utf-8") + \
                b'\0'
        elif self._state == 3:
            reply = self._json_decoder.decode(response.decode("utf-8"))
            try:
                self._raise_on_handshake_error(reply)

                authentication = dict(
                    x.split(b"=", 1) for x in reply["authentication"].encode(
                        "ascii").split(b","))

                received_signature = base64.standard_b64decode(authentication[b"v"])
            except KeyError as key_error:
                raise ReqlDriverError("Missing key: %s" % (key_error, ))

            # Compare via self._compare_digest (the method used throughout
            # for signature checks) rather than ==.
            if not self._compare_digest(received_signature, self._server_signature):
                raise ReqlAuthError("Invalid server signature", self._host,
                                    self._port)

            self._state = 4
            return None
        else:
            raise ReqlDriverError("Unexpected handshake state")

    def _raise_on_handshake_error(self, reply):
        """Raise if a decoded handshake reply has ``success`` set to False.

        Error codes 10-20 raise ``ReqlAuthError``; any other failure raises
        ``ReqlDriverError`` with the server's message. A missing ``success``,
        ``error_code`` or ``error`` key propagates as ``KeyError`` so the
        caller can report it as a missing key.
        """
        if reply["success"] is False:
            if 10 <= reply["error_code"] <= 20:
                raise ReqlAuthError(reply["error"], self._host, self._port)
            raise ReqlDriverError(reply["error"])