def next_message(self, response):
    """Advance the auth-key handshake by one step.

    State 0 must be called with ``response=None`` and returns the opening
    payload: ``VERSION`` and the auth-key length packed as two little-endian
    uint32s, the auth key itself, then ``PROTOCOL`` as a uint32.

    State 1 must be called with the server's reply. ``b"SUCCESS"`` ends the
    handshake (returns None). Any other reply is decoded as text, as far as
    possible, and raised.

    :raises ReqlDriverError: unexpected/missing response, unknown state, or
        the server dropped the connection.
    :raises ReqlAuthError: the server reported an incorrect authorization key.
    """
    state = self._state

    if state == 0:
        if response is not None:
            raise ReqlDriverError("Unexpected response")
        self._state = 1
        header = struct.pack("<2L", self.VERSION, len(self._auth_key))
        trailer = struct.pack("<L", self.PROTOCOL)
        return header + self._auth_key + trailer

    if state == 1:
        if response is None:
            raise ReqlDriverError("Expected response")
        self._state = 2
        if response == b"SUCCESS":
            return None

        # Error path: turn the server reply into readable text, falling back
        # step by step so that *something* useful ends up in the exception.
        try:
            message = response.decode("utf-8", errors="ignore").strip()
        except TypeError:
            # Presumably for decode() implementations that reject keyword
            # arguments.
            try:
                message = response.decode("utf-8").strip()
            except UnicodeError:
                message = repr(response).strip()

        if message == "ERROR: Incorrect authorization key.":
            raise ReqlAuthError("Incorrect authentication key.", self._host, self._port)
        raise ReqlDriverError(
            "Server dropped connection with message: \"%s\"" % (message, ))

    raise ReqlDriverError("Unexpected handshake state")
def recvall(self, length, deadline):
    """Read exactly ``length`` bytes from ``self._socket``.

    Bytes left over from an earlier call that timed out (kept in
    ``self._read_buffer``) are used first. The socket timeout is derived
    from ``deadline`` (compared against ``time.time()``; None means no
    limit) and applies until the first successful ``recv``. After that the
    socket is put back into blocking mode (``settimeout(None)``).

    :raises ReqlTimeoutError: the read timed out. The bytes read so far are
        saved in ``self._read_buffer`` so the next call can resume.
    :raises ReqlDriverError: the connection was reset, closed, or failed.
    """
    res = b"" if self._read_buffer is None else self._read_buffer
    # The carried-over bytes now live in ``res``. Clear the buffer so a
    # later call does not prepend them a second time.
    self._read_buffer = None
    timeout = None if deadline is None else max(0, deadline - time.time())
    self._socket.settimeout(timeout)
    while len(res) < length:
        while True:
            try:
                chunk = self._socket.recv(length - len(res))
                self._socket.settimeout(None)
                break
            except socket.timeout:
                self._read_buffer = res
                self._socket.settimeout(None)
                raise ReqlTimeoutError(self.host, self.port)
            except IOError as ex:
                if ex.errno == errno.ECONNRESET:
                    self.close()
                    raise ReqlDriverError("Connection is closed.")
                elif ex.errno == errno.EWOULDBLOCK:
                    # This should only happen with a timeout of 0. Treat it
                    # exactly like a timeout: keep the partial data and
                    # restore blocking mode.
                    self._read_buffer = res
                    self._socket.settimeout(None)
                    raise ReqlTimeoutError(self.host, self.port)
                elif ex.errno != errno.EINTR:
                    raise ReqlDriverError(("Connection interrupted " +
                                           "receiving from %s:%s - %s") %
                                          (self.host, self.port, str(ex)))
                # EINTR: interrupted by a signal, so retry the recv.
            except Exception as ex:
                self.close()
                raise ReqlDriverError("Error receiving from %s:%s - %s" %
                                      (self.host, self.port, str(ex)))
        if len(chunk) == 0:
            # An orderly shutdown by the peer.
            self.close()
            raise ReqlDriverError("Connection is closed.")
        res += chunk
    return res
def connect(self, timeout):
    """Open the transport to the server and wait for the handshake to finish.

    This is a generator-based coroutine (presumably driven by Twisted's
    ``inlineCallbacks``; the decorator is not visible here). It yields
    Deferreds and hands the parent connection back via ``returnValue``.

    :raises ReqlDriverError: the connection could not be made, or the
        handshake failed for a reason other than auth/timeout.
    :raises ReqlAuthError: re-raised unchanged from the handshake.
    :raises ReqlTimeoutError: the handshake timed out.
    """
    factory = DatabaseProtoFactory(timeout, self._handleResponse,
                                   self._parent.handshake)

    # Step 1: establish the connection. The factory is given the handshake
    # object and is expected to send its payload.
    try:
        protocol = yield self._connectTimeout(factory, timeout)
    except Exception as exc:
        raise ReqlDriverError(
            'Could not connect to {p.host}:{p.port}. Error: {exc}'.format(
                p=self._parent, exc=str(exc)))

    # Step 2: wait for the server's verdict on the handshake.
    try:
        yield protocol.wait_for_handshake
    except ReqlAuthError:
        raise
    except ReqlTimeoutError:
        raise ReqlTimeoutError(self._parent.host, self._parent.port)
    except Exception as exc:
        raise ReqlDriverError(
            'Connection interrupted during handshake with {p.host}:{p.port}. Error: {exc}'
            .format(p=self._parent, exc=str(exc)))

    self._connection = protocol
    returnValue(self._parent)
def __init__(self, host, port, db, auth_key, user, password, timeout, ssl, _handshake_version, **kwargs):
    """Validate connection options and build the underlying ``SocketWrapper``.

    ``json_encoder`` / ``json_decoder`` may be supplied through ``kwargs``.
    Otherwise ``ReQLEncoder()`` / ``ReQLDecoder()`` instances are used.

    :raises ReqlDriverError: ``port`` is not an integer, ``auth_key`` was
        given, or ``_handshake_version`` is not 10.
    """
    self.host = host
    try:
        self.port = int(port)
    except ValueError:
        raise ReqlDriverError('Could not convert port %r to an integer.' % port)
    self.db = db

    # Only build the default codec when the caller did not provide one.
    self.json_encoder = kwargs.pop('json_encoder') if 'json_encoder' in kwargs else ReQLEncoder()
    self.json_decoder = kwargs.pop('json_decoder') if 'json_decoder' in kwargs else ReQLDecoder()

    password = '' if password is None else password

    # Only user/password authentication over handshake 1.0 is supported.
    if auth_key is not None:
        raise ReqlDriverError('`auth_key` is not supported')
    if _handshake_version != 10:
        raise ReqlDriverError('only support handshake version 1.0')

    self._socket = SocketWrapper(
        host=self.host,
        port=self.port,
        ssl=ssl,
        timeout=timeout,
        user=user,
        password=password,
        json_encoder=self.json_encoder,
        json_decoder=self.json_decoder,
    )
    self._next_token = 0
def recvall(self, length):
    """Read exactly ``length`` bytes from ``self._socket``.

    Any bytes left over in ``self._read_buffer`` are used first, and the
    buffer is then cleared so they are not returned twice.

    :raises ReqlTimeoutError: propagated from the socket. The bytes read so
        far are saved in ``self._read_buffer`` so a retry can resume.
    :raises ReqlDriverError: the connection was reset, closed, or failed.
        The connection is closed first.
    """
    res = b"" if self._read_buffer is None else self._read_buffer
    # Consume the carried-over bytes. Previously the buffer was never
    # cleared, so stale data was prepended to every later read.
    self._read_buffer = None
    while len(res) < length:
        while True:
            try:
                chunk = self._socket.recv(length - len(res))
                break
            except ReqlTimeoutError:
                # Keep what was read so far; dropping it would desync the
                # response stream.
                self._read_buffer = res
                raise
            except IOError as ex:
                if ex.errno == errno.ECONNRESET:
                    self.close()
                    raise ReqlDriverError("Connection is closed.")
                elif ex.errno != errno.EINTR:
                    self.close()
                    raise ReqlDriverError(
                        "Connection interrupted receiving from %s:%s - %s"
                        % (self.host, self.port, str(ex)))
                # EINTR: interrupted by a signal, so retry the recv.
            except Exception as ex:
                self.close()
                raise ReqlDriverError("Error receiving from %s:%s - %s"
                                      % (self.host, self.port, str(ex)))
        if len(chunk) == 0:
            # An orderly shutdown by the peer.
            self.close()
            raise ReqlDriverError("Connection is closed.")
        res += chunk
    return res
def convert_pseudotype(self, obj):
    """Convert a ReQL pseudo-type object according to ``reql_format_opts``.

    Objects without ``$reql_type$``, GEOMETRY objects, and objects whose
    format option is ``'raw'`` are returned unchanged. A missing or
    ``'native'`` format hands the object to the matching ``convert_*``
    method.

    :raises ReqlDriverError: unknown pseudo-type or unknown format option.
    """
    reql_type = obj.get('$reql_type$')
    if reql_type is None:
        return obj
    if reql_type == 'GEOMETRY':
        # No special support for this. Just return the raw object
        return obj

    # (pseudo-type, run option controlling it, converter method name)
    dispatch = (
        ('TIME', 'time_format', 'convert_time'),
        ('GROUPED_DATA', 'group_format', 'convert_grouped_data'),
        ('BINARY', 'binary_format', 'convert_binary'),
    )
    for type_name, option, converter_name in dispatch:
        if reql_type == type_name:
            break
    else:
        raise ReqlDriverError("Unknown pseudo-type %s" % reql_type)

    requested = self.reql_format_opts.get(option)
    if requested is None or requested == 'native':
        return getattr(self, converter_name)(obj)
    if requested != 'raw':
        raise ReqlDriverError(("Unknown " + option + " run option \"%s\".") % requested)
    # Raw format requested: hand back the original object.
    return obj
def convert_pseudotype(self, obj):
    """Convert a ReQL pseudo-type object according to ``reql_format_opts``.

    Returns ``obj`` unchanged when it has no ``$reql_type$``, is GEOMETRY,
    or its format option is ``"raw"``. Otherwise the matching converter
    (``convert_time`` / ``convert_grouped_data`` / ``convert_binary``) is
    applied.

    :raises ReqlDriverError: unknown pseudo-type or unknown format option.
    """
    reql_type = obj.get("$reql_type$")
    if reql_type is None:
        return obj

    def honour_format(option, converter_name, unknown_message):
        # Apply the converter unless the user asked for the raw form.
        chosen = self.reql_format_opts.get(option)
        if chosen is None or chosen == "native":
            return getattr(self, converter_name)(obj)
        if chosen != "raw":
            raise ReqlDriverError(unknown_message % chosen)
        return obj

    if reql_type == "TIME":
        # The native format yields a python datetime object.
        return honour_format("time_format", "convert_time",
                             'Unknown time_format run option "%s".')
    if reql_type == "GROUPED_DATA":
        return honour_format("group_format", "convert_grouped_data",
                             'Unknown group_format run option "%s".')
    if reql_type == "GEOMETRY":
        # No special support for geometry; the raw object is returned.
        return obj
    if reql_type == "BINARY":
        return honour_format("binary_format", "convert_binary",
                             'Unknown binary_format run option "%s".')
    raise ReqlDriverError("Unknown pseudo-type %s" % reql_type)
async def connect(self, timeout):
    """Open the (optionally TLS-wrapped) TCP stream and run the handshake.

    On success the background reader (``self._reader_task``) is started in
    ``self._nursery`` and the parent connection object is returned.

    :param timeout: applied to each handshake read via ``_reql_timeout``.
    :raises ReqlAuthError: the server rejected the credentials.
    :raises ReqlDriverError: the connection or handshake failed. A handshake
        timeout is also reported as ``ReqlDriverError``.
    """
    try:
        ssl_context = None
        # Use truthiness rather than ``len(...) > 0`` so that an absent
        # (None) SSL config means plain TCP instead of a TypeError.
        if self._parent.ssl:
            ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
            if hasattr(ssl_context, "options"):
                ssl_context.options |= getattr(ssl, "OP_NO_SSLv2", 0)
                ssl_context.options |= getattr(ssl, "OP_NO_SSLv3", 0)
            ssl_context.verify_mode = ssl.CERT_REQUIRED
            ssl_context.check_hostname = True
            ssl_context.load_verify_locations(self._parent.ssl["ca_certs"])

        if ssl_context:
            self._stream = await trio.open_ssl_over_tcp_stream(
                self._parent.host, self._parent.port, ssl_context=ssl_context)
            socket_ = self._stream.transport_stream.socket
        else:
            self._stream = await trio.open_tcp_stream(
                self._parent.host, self._parent.port)
            socket_ = self._stream.socket

        self._sockname = socket_.getsockname()
        socket_.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
    except Exception as err:
        raise ReqlDriverError(
            'Could not connect to %s:%s. Error: %s' %
            (self._parent.host, self._parent.port, str(err)))

    try:
        self._parent.handshake.reset()
        response = None
        while True:
            request = self._parent.handshake.next_message(response)
            if request is None:
                break
            # An empty request can happen in the `V1_0` protocol, where two
            # requests are sent at once as an optimization and each reply
            # must then be read separately. Test for emptiness rather than
            # identity: ``request is not ""`` relied on string interning
            # and never matched an empty bytes object.
            if request:
                await self._send(request)
            with _reql_timeout(timeout):
                response = await self._read_until(b'\0')
            # Strip the trailing NUL terminator.
            response = response[:-1]
    except ReqlAuthError:
        await self.close()
        raise
    except ReqlTimeoutError as err:
        await self.close()
        raise ReqlDriverError(
            'Connection interrupted during handshake with %s:%s. Error: %s' %
            (self._parent.host, self._parent.port, str(err)))
    except Exception as err:
        await self.close()
        raise ReqlDriverError(
            'Could not connect to %s:%s. Error: %s' %
            (self._parent.host, self._parent.port, str(err)))

    # Start a parallel function to perform reads
    self._nursery.start_soon(self._reader_task)
    return self._parent
def run(self, c=None, **global_optargs):
    """Run this query on connection ``c``, falling back to ``Repl.get()``.

    :param c: connection to run on. When None, the default set via
        ``repl()`` is used.
    :param global_optargs: forwarded unchanged to ``c._start``.
    :returns: whatever the connection's ``_start`` returns.
    :raises ReqlDriverError: no connection was given and no usable default.
    """
    conn = Repl.get() if c is None else c
    if conn is None:
        message = "RqlQuery.run must be given a connection to run on."
        if Repl.repl_active:
            # A default exists, but it was registered on a different thread.
            message += (" A default connection has been set with "
                        "`repl()` on another thread, but not this one.")
        raise ReqlDriverError(message)
    return conn._start(self, **global_optargs)
async def close(self, noreply_wait=False, token=None, exception=None):
    """Close the connection, failing every outstanding cursor and query.

    :param noreply_wait: if true, run a NOREPLY_WAIT query (with ``token``)
        before closing the stream.
    :param token: token used for the NOREPLY_WAIT query.
    :param exception: the error that triggered the close, if any. Its text
        is included in the message given to cursors and queries. A non-None
        value also means we were called from the reader task.
    :returns: None
    """
    self._closing = True
    if exception is not None:
        err_message = "Connection is closed (%s)." % str(exception)
    else:
        err_message = "Connection is closed."

    # Cursors may remove themselves when errored, so copy a list of them
    for cursor in list(self._cursor_cache.values()):
        cursor._error(err_message)

    for _, future in self._user_queries.values():
        if not future.done():
            future.set_exception(ReqlDriverError(err_message))

    self._user_queries = {}
    self._cursor_cache = {}

    if noreply_wait:
        noreply = Query(P_QUERY.NOREPLY_WAIT, token, None, None)
        await self.run_query(noreply, False)

    try:
        await self._stream.aclose()
    except (trio.ClosedResourceError, trio.BrokenResourceError):
        pass

    # We must not wait for the reader task if we got an exception, because
    # that means we were called from it, and waiting would deadlock. The
    # previous code waited unconditionally despite this comment.
    if self._reader_ended_event and exception is None:
        await self._reader_ended_event.wait()

    return None
def _init_connection(self, response):
    """
    Build the first handshake message (protocol version plus initial JSON).

    The version and the SCRAM-SHA-256 client-first message are sent together
    as an optimization. A fresh 18-byte random nonce is generated and
    stored, together with the client-first message, for later handshake
    steps. The handshake state is advanced before returning.

    :param response: Response from the database (must be None at this stage)
    :raises: ReqlDriverError
    :return: Initial message which will be sent to the DB
    """
    if response is not None:
        raise ReqlDriverError('Unexpected response')

    rng = SystemRandom()
    nonce_bytes = bytes(bytearray(rng.getrandbits(8) for _ in range(18)))
    self._random_nonce = base64.standard_b64encode(nonce_bytes)
    self._first_client_message = chain_to_bytes('n=', self._username, ',r=', self._random_nonce)

    payload = {
        'protocol_version': self._protocol_version,
        'authentication_method': 'SCRAM-SHA-256',
        'authentication': chain_to_bytes('n,,', self._first_client_message).decode('ascii'),
    }
    # Little-endian uint32 version, UTF-8 JSON body, NUL terminator.
    initial_message = chain_to_bytes(
        struct.pack('<L', self.VERSION),
        self._json_encoder.encode(payload).encode('utf-8'),
        b'\0',
    )

    self._next_state()
    return initial_message
def next_message(self, raw_response: Optional[bytes]) -> Optional[bytes]:
    """
    Handle the next message to send or receive.

    Decodes ``raw_response`` as UTF-8 (an empty string when it is None),
    checks that ``self.state`` is valid, then runs the handler(s) that match
    the current ``HandshakeState``.

    :param raw_response: the latest server reply, or None on the first call.
    :returns: the message produced by ``__initialize_connection`` or
        ``__prepare_auth_request`` when one of those ran, otherwise None.
    :raises InvalidHandshakeStateError: ``self.state`` is not a valid state.
    :raises ReqlDriverError: a response was given in the initial state.
    """

    response: str = ""
    message: Optional[bytes] = None

    if raw_response is not None:
        response = raw_response.decode("utf-8")

    if not self.is_valid_state(self.state):
        raise InvalidHandshakeStateError("Unexpected handshake state")

    # NOTE(review): the checks below are independent `if`s, not `elif`s.
    # If a helper advances `self.state`, a later branch also runs in this
    # same call with the same `response`. Also, the return values of
    # `__read_response` / `__read_auth_response` are discarded, so those
    # states return whatever `message` holds (None unless set above).
    # Confirm this fall-through is intended.
    if self.state == HandshakeState.INITIAL_CONNECTION:
        if raw_response is not None:
            raise ReqlDriverError("Unexpected response")

        message = self.__initialize_connection()

    if self.state == HandshakeState.INITIAL_RESPONSE:
        self.__read_response(response)

    if self.state == HandshakeState.AUTH_REQUEST:
        message = self.__prepare_auth_request(response)

    if self.state == HandshakeState.AUTH_RESPONSE:
        self.__read_auth_response(response)

    return message
def _wait_to_timeout(wait): if isinstance(wait, bool): return None if wait else 0 elif isinstance(wait, numbers.Real) and wait >= 0: return wait else: raise ReqlDriverError("Invalid wait timeout '%s'" % str(wait))
def _reader(self):
    """Background loop that reads framed responses and routes them.

    Each frame is a 12-byte header (little-endian int64 token, uint32
    length) followed by ``length`` bytes of body. Bodies for tokens with an
    open cursor extend that cursor. Bodies for pending user queries resolve
    the query's future. Any failure closes the connection unless it is
    already closing.
    """
    try:
        while True:
            buf = yield from self._streamreader.readexactly(12)
            (token, length,) = struct.unpack("<qL", buf)
            buf = yield from self._streamreader.readexactly(length)

            cursor = self._cursor_cache.get(token)
            if cursor is not None:
                cursor._extend(buf)
            elif token in self._user_queries:
                # Do not pop the query from the dict until later, so
                # we don't lose track of it in case of an exception
                query, future = self._user_queries[token]
                # The caller may already have given up on this query (e.g.
                # cancelled its future). asyncio futures raise
                # InvalidStateError when resolved twice, and the handler
                # below would then close the whole connection. Just drop the
                # response instead (no cursor is created for it).
                if not future.done():
                    res = Response(token, buf,
                                   self._parent._get_json_decoder(query))
                    if res.type == pResponse.SUCCESS_ATOM:
                        future.set_result(maybe_profile(res.data[0], res))
                    elif res.type in (
                        pResponse.SUCCESS_SEQUENCE,
                        pResponse.SUCCESS_PARTIAL,
                    ):
                        cursor = AsyncioCursor(self, query, res)
                        future.set_result(maybe_profile(cursor, res))
                    elif res.type == pResponse.WAIT_COMPLETE:
                        future.set_result(None)
                    elif res.type == pResponse.SERVER_INFO:
                        future.set_result(res.data[0])
                    else:
                        future.set_exception(res.make_error(query))
                del self._user_queries[token]
            elif not self._closing:
                raise ReqlDriverError("Unexpected response received.")
    except Exception as ex:
        if not self._closing:
            yield from self.close(exception=ex)
def _read_response(self, response): """ Read response of the server. Due to we've already sent the initial JSON, and only support a single protocol version at the moment thus we simply read the next response and return an empty string as a message. :param response: Response from the database :raises: ReqlDriverError | ReqlAuthError :return: An empty string """ json_response = self._decode_json_response(response) min_protocol_version = json_response["min_protocol_version"] max_protocol_version = json_response["max_protocol_version"] if not min_protocol_version <= self._protocol_version <= max_protocol_version: raise ReqlDriverError( "Unsupported protocol version {version}, expected between {min} and {max}".format( version=self._protocol_version, min=min_protocol_version, max=max_protocol_version, ) ) self._next_state() return ""
async def connect(self):
    """Open a TCP socket to ``self.host:self.port`` and run the handshake.

    TCP_NODELAY and SO_KEEPALIVE are enabled on the socket. Handshake errors
    close the socket before propagating. Driver errors are reworded to say
    "during handshake with" instead of "receiving from"/"sending to".

    :raises ReqlAuthError: re-raised unchanged.
    :raises ReqlTimeoutError: re-raised, or raised for ``socket.timeout``.
    :raises ReqlDriverError: any other driver error during the handshake.
    """
    # NOTE(review): TLS is not wired up (an earlier ssl-context draft was
    # left commented out here). Also, awaiting ``connect`` requires an
    # async-capable socket implementation, since the stdlib
    # ``socket.connect`` returns None. Confirm which ``socket`` module is
    # bound in this file.
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    self._socket = sock
    await sock.connect((self.host, self.port))
    for level, option in ((socket.IPPROTO_TCP, socket.TCP_NODELAY),
                          (socket.SOL_SOCKET, socket.SO_KEEPALIVE)):
        sock.setsockopt(level, option, 1)

    try:
        self._handshake.reset()
        reply = None
        while True:
            outgoing = self._handshake.next_message(reply)
            if outgoing is None:
                break
            # An empty message can happen in the `V1_0` protocol, where two
            # requests are sent at once as an optimization and each reply is
            # then read separately.
            if outgoing:
                await self.sendall(outgoing)
            # Server replies are null-terminated strings; drop the NUL.
            reply = (await self.read_until(b'\0'))[:-1]
    except (ReqlAuthError, ReqlTimeoutError):
        await self.close()
        raise
    except ReqlDriverError as ex:
        await self.close()
        message = str(ex)
        for phrase in ('receiving from', 'sending to'):
            message = message.replace(phrase, 'during handshake with')
        raise ReqlDriverError(message)
    except socket.timeout:
        await self.close()
        raise ReqlTimeoutError(self.host, self.port)
def close(self, noreply_wait=False, token=None, exception=None):
    """Close the connection (generator-based coroutine, uses ``yield from``).

    Every cached cursor is errored, and every unfinished query future gets a
    ``ReqlDriverError``. Optionally a NOREPLY_WAIT query is run first. Then
    the stream writer is closed.

    :param exception: the error that caused the close, if any. When set,
        the reader task is not awaited, because we were called from it.
    :returns: None
    """
    self._closing = True
    err_message = ("Connection is closed." if exception is None
                   else "Connection is closed (%s)." % str(exception))

    # Snapshot first: erroring a cursor may remove it from the cache.
    for cursor in list(self._cursor_cache.values()):
        cursor._error(err_message)

    for entry in iter(self._user_queries.values()):
        _query, fut = entry
        if fut.done():
            continue
        fut.set_exception(ReqlDriverError(err_message))

    self._user_queries = {}
    self._cursor_cache = {}

    if noreply_wait:
        yield from self.run_query(Query(pQuery.NOREPLY_WAIT, token, None, None), False)

    self._streamwriter.close()

    # Waiting on the reader task from inside it would deadlock, so only
    # wait when no exception was passed in.
    if self._reader_task and exception is None:
        yield from self._reader_task
    return None
def close(self, noreply_wait=False, token=None, exception=None):
    """Close the connection and return a Deferred that fires once done.

    Cached cursors are errored, and pending query Deferreds that have not
    fired are errbacked with ``ReqlDriverError``. If ``noreply_wait`` is
    true, a NOREPLY_WAIT query runs first. The transport is dropped when
    that Deferred fires (on success or failure), and its result is passed
    through.
    """
    self._closing = True
    if exception is None:
        error_message = "Connection is closed"
    else:
        error_message = "Connection is closed (reason: {exc})".format(
            exc=str(exception))

    for cursor in list(self._cursor_cache.values()):
        cursor._error(error_message)

    for _query, deferred in iter(self._user_queries.values()):
        if deferred.called:
            continue
        deferred.errback(fail=ReqlDriverError(error_message))

    self._user_queries = {}
    self._cursor_cache = {}

    if noreply_wait:
        pending = self.run_query(Query(pQuery.NOREPLY_WAIT, token, None, None), False)
    else:
        pending = defer.succeed(None)

    def drop_transport(result):
        # Runs on both the callback and errback paths; passes result through.
        self._connection.transport.loseConnection()
        return result

    return pending.addBoth(drop_transport)
def close(self, noreply_wait=False, token=None, exception=None):
    """Close the connection (Tornado generator coroutine).

    Cached cursors are errored, and every unfinished query future gets a
    ``ReqlDriverError``. Optionally a NOREPLY_WAIT query runs first, then
    the stream is closed (an already-closed stream is ignored).

    :param exception: the error that caused the close, if any. Its text is
        included in the error message.
    :returns: None (via ``gen.Return``)
    """
    self._closing = True
    if exception is not None:
        err_message = "Connection is closed (%s)." % str(exception)
    else:
        err_message = "Connection is closed."

    # Cursors may remove themselves when errored, so copy a list of them
    for cursor in list(self._cursor_cache.values()):
        cursor._error(err_message)

    for _query, future in iter(self._user_queries.values()):
        # A query may already be resolved or cancelled. Setting an exception
        # on a done future can raise (asyncio-backed futures raise
        # InvalidStateError), which would abort the close half-way. The
        # asyncio and trio variants of close() guard the same way.
        if not future.done():
            future.set_exception(ReqlDriverError(err_message))

    self._user_queries = {}
    self._cursor_cache = {}

    if noreply_wait:
        noreply = Query(pQuery.NOREPLY_WAIT, token, None, None)
        yield self.run_query(noreply, False)

    try:
        self._stream.close()
    except iostream.StreamClosedError:
        pass
    raise gen.Return(None)
def _handleResponse(self, token, data):
    """Route one response body to its cursor or pending query Deferred.

    Bodies for tokens with an open cursor extend that cursor. Bodies for
    pending user queries fire the query's Deferred, and the query is then
    forgotten. Any failure closes the connection unless it is already
    closing.
    """
    try:
        cursor = self._cursor_cache.get(token)
        if cursor is not None:
            cursor._extend(data)
        elif token in self._user_queries:
            query, deferred = self._user_queries[token]
            # A Deferred that has already fired (e.g. cancelled by the
            # caller) cannot be fired again: callback()/errback() would raise
            # AlreadyCalledError, and the handler below would then close the
            # whole connection. Drop the response instead.
            if not deferred.called:
                res = Response(token, data,
                               self._parent._get_json_decoder(query))
                if res.type == pResponse.SUCCESS_ATOM:
                    deferred.callback(maybe_profile(res.data[0], res))
                elif res.type in (pResponse.SUCCESS_SEQUENCE,
                                  pResponse.SUCCESS_PARTIAL):
                    cursor = TwistedCursor(self, query, res)
                    deferred.callback(maybe_profile(cursor, res))
                elif res.type == pResponse.WAIT_COMPLETE:
                    deferred.callback(None)
                elif res.type == pResponse.SERVER_INFO:
                    deferred.callback(res.data[0])
                else:
                    deferred.errback(res.make_error(query))
            del self._user_queries[token]
        elif not self._closing:
            raise ReqlDriverError("Unexpected response received.")
    except Exception as e:
        if not self._closing:
            self.close(exception=e)
def convert_pseudo_type(self, obj: Dict[str, Any]) -> Any:
    """
    Convert pseudo-type objects using the given converter.

    Objects without a ``$reql_type$`` key, and GEOMETRY objects, are
    returned unchanged. TIME, GROUPED_DATA and BINARY objects are passed to
    ``__convert_pseudo_type`` together with their format option name and
    converter, and that call's result is returned.

    :raises: ReqlDriverError for an unknown ``$reql_type$``
    """
    reql_type = obj.get("$reql_type$")

    # If there was no pseudo_type, return the original object
    if reql_type is None:
        return obj

    # The converter results must be returned. Previously they were dropped,
    # so every known type fell through to the "Unknown pseudo-type" error.
    if reql_type == "TIME":
        return self.__convert_pseudo_type(obj, "time_format", self.convert_time)
    if reql_type == "GROUPED_DATA":
        return self.__convert_pseudo_type(obj, "group_format", self.convert_grouped_data)
    if reql_type == "BINARY":
        return self.__convert_pseudo_type(obj, "binary_format", self.convert_binary)
    if reql_type == "GEOMETRY":
        # No special support for this, just return the raw object
        return obj

    raise ReqlDriverError(f'Unknown pseudo-type "{reql_type}"')
def __init__(self, *args, **kwargs):
    """Create a connection whose instances are ``ConnectionInstance`` objects.

    All arguments are forwarded to the base initialiser. Afterwards
    ``self.port`` is coerced to ``int``.

    :raises ReqlDriverError: the port cannot be converted to an integer.
    """
    super(Connection, self).__init__(ConnectionInstance, *args, **kwargs)
    try:
        port_number = int(self.port)
    except ValueError:
        raise ReqlDriverError("Could not convert port %s to an integer." % self.port)
    self.port = port_number
def sendall(self, data):
    """Write all of ``data`` to ``self._socket``, retrying after EINTR.

    Any other failure closes the connection and raises ``ReqlDriverError``.
    """
    sent = 0
    total = len(data)
    while sent < total:
        try:
            sent += self._socket.send(data[sent:])
        except IOError as ex:
            if ex.errno == errno.EINTR:
                # Interrupted by a signal before anything was sent: retry.
                continue
            self.close()
            if ex.errno == errno.ECONNRESET:
                raise ReqlDriverError("Connection is closed.")
            raise ReqlDriverError('Connection interrupted sending to %s:%s - %s'
                                  % (self.host, self.port, str(ex)))
        except Exception as ex:
            self.close()
            raise ReqlDriverError('Error sending to %s:%s - %s'
                                  % (self.host, self.port, str(ex)))
async def read_response(self, query):
    """Read one framed response for ``query``.

    The frame header is 12 bytes: a little-endian int64 token and a uint32
    body length.

    :raises ReqlDriverError: the token does not match ``query.token``. The
        connection is closed first.
    :returns: a ``Response`` built with ``self.json_decoder``.
    """
    import inspect

    headers = await self.read_bytes(12)
    res_token, res_len = struct.unpack('<qL', headers)
    if res_token != query.token:
        # In this async class ``close`` may be a coroutine function. Calling
        # it without awaiting would never run it, which would leave the
        # connection open and emit a "never awaited" warning. Await it when
        # it returns an awaitable.
        closing = self.close()
        if inspect.isawaitable(closing):
            await closing
        raise ReqlDriverError('Unexpected response received.')
    res_buf = await self.read_bytes(res_len)
    return Response(res_token, res_buf, json_decoder=self.json_decoder)
def __init__(self, conn_type, host, port, db, auth_key, user, password, timeout, ssl, _handshake_version, **kwargs):
    """Store connection settings and prepare the V1_0 handshake.

    ``json_encoder`` / ``json_decoder`` in ``kwargs`` override the instance
    attributes ``_json_encoder`` / ``_json_decoder``. These are called to
    build the codec instances handed to ``HandshakeV1_0``. ``auth_key`` is
    accepted as an alias for ``password``. If neither is given, the password
    is the empty string.

    :raises ReqlDriverError: bad port, or both ``auth_key`` and ``password``.
    :raises NotImplementedError: ``_handshake_version == 4``.
    """
    self.db = db
    self.host = host
    try:
        self.port = int(port)
    except ValueError:
        raise ReqlDriverError("Could not convert port %r to an integer." % port)

    self.connect_timeout = timeout
    self.ssl = ssl

    self._conn_type = conn_type
    self._child_kwargs = kwargs
    self._instance = None
    self._next_token = 0

    # Pop codec overrides so they are not forwarded with the child kwargs.
    for key, attr in (("json_encoder", "_json_encoder"), ("json_decoder", "_json_decoder")):
        if key in kwargs:
            setattr(self, attr, kwargs.pop(key))

    # auth_key and password are interchangeable; at most one may be set.
    if auth_key is not None and password is not None:
        raise ReqlDriverError("`auth_key` and `password` are both set.")
    if password is None:
        password = "" if auth_key is None else auth_key

    if _handshake_version == 4:
        raise NotImplementedError("The v0.4 handshake was removed.")

    self.handshake = HandshakeV1_0(
        self._json_decoder(),
        self._json_encoder(),
        self.host,
        self.port,
        user,
        password,
    )
def convert_grouped_data(obj):
    """Turn a GROUPED_DATA pseudo-type into a plain dict.

    ``obj["data"]`` is iterated as ``(key, value)`` pairs. Each key is
    passed through ``recursively_make_hashable`` so it can be used as a dict
    key. Later duplicates win.

    :raises ReqlDriverError: ``obj`` has no ``"data"`` field.
    """
    if "data" not in obj:
        raise ReqlDriverError(
            'pseudo-type GROUPED_DATA object %s does not have the expected field "data".'
            % json.dumps(obj))
    return {recursively_make_hashable(key): value for key, value in obj["data"]}
def dataReceived(self, data):
    """Dispatch incoming bytes to the handler for the current ``state``.

    Data is ignored while ``self._open`` is false. If the handler fails, the
    transport is dropped and a ``ReqlDriverError`` describing the failure is
    raised.
    """
    try:
        if self._open:
            self._handlers[self.state](data)
    except Exception as e:
        self.transport.loseConnection()
        # Adjacent literals are concatenated, so the first one needs a
        # trailing space. Without it the message read "data.Error: ...".
        raise ReqlDriverError(
            'Driver failed to handle received data. '
            'Error: {exc}. Dropping the connection.'.format(exc=str(e)))
def make_error(self, query):
    """Build (but do not raise) the exception that matches this response.

    Client, compile and runtime errors are built from ``self.data[0]``,
    ``query.term`` and ``self.backtrace``. Runtime errors are specialised by
    ``self.error_type``, falling back to ``ReqlRuntimeError``. Any other
    response type yields a ``ReqlDriverError`` naming the type.
    """
    if self.type == pResponse.CLIENT_ERROR:
        error_cls = ReqlDriverError
    elif self.type == pResponse.COMPILE_ERROR:
        error_cls = ReqlServerCompileError
    elif self.type == pResponse.RUNTIME_ERROR:
        runtime_errors = {
            pErrorType.INTERNAL: ReqlInternalError,
            pErrorType.RESOURCE_LIMIT: ReqlResourceLimitError,
            pErrorType.QUERY_LOGIC: ReqlQueryLogicError,
            pErrorType.NON_EXISTENCE: ReqlNonExistenceError,
            pErrorType.OP_FAILED: ReqlOpFailedError,
            pErrorType.OP_INDETERMINATE: ReqlOpIndeterminateError,
            pErrorType.USER: ReqlUserError,
            pErrorType.PERMISSION_ERROR: ReqlPermissionError,
        }
        error_cls = runtime_errors.get(self.error_type, ReqlRuntimeError)
    else:
        return ReqlDriverError("Unknown Response type %d encountered in a response." % self.type)
    return error_cls(self.data[0], query.term, self.backtrace)
def convert_time(self, obj):
    """Convert a TIME pseudo-type into a ``datetime.datetime``.

    With a ``"timezone"`` field the result is aware (tz from ``RqlTzinfo``).
    Without one, a naive UTC datetime is returned (``utcfromtimestamp``).

    :raises ReqlDriverError: ``obj`` has no ``"epoch_time"`` field.
    """
    if "epoch_time" not in obj:
        raise ReqlDriverError(
            'pseudo-type TIME object %s does not have expected field "epoch_time".'
            % json.dumps(obj))
    seconds = obj["epoch_time"]
    if "timezone" not in obj:
        return datetime.datetime.utcfromtimestamp(seconds)
    return datetime.datetime.fromtimestamp(seconds, RqlTzinfo(obj["timezone"]))
def convert_binary(obj: Dict[str, Any]) -> bytes:
    """
    Convert a BINARY pseudo-type into an ``RqlBinary`` value.

    ``obj["data"]`` is treated as base64 text: it is UTF-8 encoded, then
    base64-decoded.

    :raises: ReqlDriverError when ``obj`` has no ``"data"`` field
    """
    if "data" in obj:
        encoded = obj["data"].encode("utf-8")
        return RqlBinary(base64.b64decode(encoded))

    raise ReqlDriverError(
        f"pseudo-type BINARY object {json.dumps(obj)} does not have "
        'the expected field "data".'
    )