def convert_pseudotype(self, obj):
    """Translate a ReQL pseudo-type object according to the run's format options.

    ``obj`` is a decoded JSON object. When it has a ``$reql_type$`` key, the
    matching ``*_format`` entry of ``self.reql_format_opts`` decides what
    happens:

    * option unset or ``'native'``: the matching ``convert_*`` method's
      result is returned;
    * ``'raw'``: ``obj`` is returned unchanged.

    GEOMETRY objects and objects without ``$reql_type$`` are always returned
    unchanged.

    Raises:
        ReqlDriverError: for an unknown pseudo-type or format option value.
    """
    pseudo_type = obj.get('$reql_type$')
    if pseudo_type is None:
        # Not a pseudo-type at all: nothing to translate.
        return obj
    if pseudo_type == 'GEOMETRY':
        # No special support for this. Just return the raw object
        return obj

    # Each convertible pseudo-type: (type tag, run option, native converter).
    conversions = (
        ('TIME', 'time_format', 'convert_time'),
        ('GROUPED_DATA', 'group_format', 'convert_grouped_data'),
        ('BINARY', 'binary_format', 'convert_binary'),
    )
    for type_tag, option_name, converter_name in conversions:
        if pseudo_type == type_tag:
            break
    else:
        raise ReqlDriverError("Unknown pseudo-type %s" % pseudo_type)

    requested = self.reql_format_opts.get(option_name)
    if requested is None or requested == 'native':
        return getattr(self, converter_name)(obj)
    if requested != 'raw':
        raise ReqlDriverError(
            ("Unknown " + option_name + " run option \"%s\".") % requested)
    # 'raw' was requested: hand the original object back untouched.
    return obj
def connect(self, timeout):
    """Open the Twisted connection to the server and wait for the handshake.

    Written as a generator yielding Deferreds and finishing with
    ``returnValue`` (presumably run under Twisted's inlineCallbacks; the
    decorator is outside this view). On success the protocol is stored on
    ``self._connection`` and ``self._parent`` is the result.

    Raises:
        ReqlDriverError: the TCP connection failed, or the handshake failed
            for a reason other than authentication or timeout.
        ReqlAuthError: re-raised unchanged from the handshake.
        ReqlTimeoutError: the handshake timed out (re-created with host/port).
    """
    factory = DatabaseProtoFactory(timeout, self._handleResponse,
                                   self._parent.handshake)

    # Establish the connection; the factory sends the handshake payload.
    try:
        protocol = yield self._connectTimeout(factory, timeout)
    except Exception as exc:
        raise ReqlDriverError(
            'Could not connect to {p.host}:{p.port}. Error: {exc}'.format(
                p=self._parent, exc=str(exc)))

    # Block until the protocol signals that the handshake is complete.
    try:
        yield protocol.wait_for_handshake
    except ReqlAuthError:
        raise
    except ReqlTimeoutError:
        raise ReqlTimeoutError(self._parent.host, self._parent.port)
    except Exception as exc:
        raise ReqlDriverError(
            'Connection interrupted during handshake with {p.host}:{p.port}. Error: {exc}'
            .format(p=self._parent, exc=str(exc)))

    self._connection = protocol
    returnValue(self._parent)
def recvall(self, length):
    """Block until ``length`` bytes are read from the socket and return them.

    Bytes already held in ``self._read_buffer`` form the start of the result.
    A recv interrupted by a signal (EINTR) is retried.

    Raises:
        ReqlTimeoutError: propagated unchanged from the socket.
        ReqlDriverError: the peer closed or reset the connection, or another
            socket error occurred; the wrapper is closed first.
    """
    received = self._read_buffer if self._read_buffer is not None else b''
    while len(received) < length:
        while True:
            try:
                chunk = self._socket.recv(length - len(received))
            except ReqlTimeoutError:
                raise
            except IOError as err:
                if err.errno == errno.ECONNRESET:
                    self.close()
                    raise ReqlDriverError("Connection is closed.")
                if err.errno != errno.EINTR:
                    self.close()
                    raise ReqlDriverError(
                        'Connection interrupted receiving from %s:%s - %s'
                        % (self.host, self.port, str(err)))
                # EINTR: interrupted by a signal, try the recv again.
            except Exception as err:
                self.close()
                raise ReqlDriverError('Error receiving from %s:%s - %s'
                                      % (self.host, self.port, str(err)))
            else:
                break
        # An empty read means the peer has closed the connection.
        if len(chunk) == 0:
            self.close()
            raise ReqlDriverError("Connection is closed.")
        received += chunk
    return received
def next_message(self, response):
    """Advance the V0.4 handshake by one step.

    State 0 (``response`` must be None) returns the opening packet:
    little-endian ``VERSION`` and auth-key length, the auth key bytes, then
    the ``PROTOCOL`` word. State 1 inspects the server's reply. ``b"SUCCESS"``
    finishes the handshake and None is returned.

    Raises:
        ReqlAuthError: the server rejected the authorization key.
        ReqlDriverError: the method was called out of order, or the server
            reported any other error.
    """
    state = self._state
    if state == 0:
        if response is not None:
            raise ReqlDriverError("Unexpected response")
        self._state = 1
        header = struct.pack("<2L", self.VERSION, len(self._auth_key))
        trailer = struct.pack("<L", self.PROTOCOL)
        return header + self._auth_key + trailer

    if state != 1:
        raise ReqlDriverError("Unexpected handshake state")
    if response is None:
        raise ReqlDriverError("Expected response")

    self._state = 2
    if response == b"SUCCESS":
        return None

    # Error reply: decode it as UTF-8, with fallbacks, to get readable text.
    try:
        text = response.decode("utf-8", errors="ignore").strip()
    except TypeError:
        # decode() refused the keyword argument; retry without it.
        try:
            text = response.decode("utf-8").strip()
        except UnicodeError:
            text = repr(response).strip()

    if text == "ERROR: Incorrect authorization key.":
        raise ReqlAuthError("Incorrect authentication key.", self._host, self._port)
    raise ReqlDriverError(
        "Server dropped connection with message: \"%s\"" % (text, ))
def run(self, c=None, **global_optargs):
    """Execute this query on connection ``c`` and return ``c._start``'s result.

    If ``c`` is omitted, the default connection from ``Repl.get()`` is used.

    Raises:
        ReqlDriverError: no connection was given and ``Repl.get()`` returned
            none. The message notes when a default exists on another thread.
    """
    conn = c if c is not None else Repl.get()
    if conn is not None:
        return conn._start(self, **global_optargs)

    if Repl.repl_active:
        raise ReqlDriverError(
            "RqlQuery.run must be given a connection to run on. A default connection has been set with "
            "`repl()` on another thread, but not this one."
        )
    raise ReqlDriverError("RqlQuery.run must be given a connection to run on.")
def close(self, noreply_wait=False, token=None, exception=None):
    """Close the Twisted connection, failing every outstanding cursor and query.

    Each cached cursor gets ``_error`` with a "Connection is closed" message.
    Each pending deferred that has not yet fired is errbacked with
    ReqlDriverError. If ``noreply_wait`` is true, a NOREPLY_WAIT query using
    ``token`` is run first. The transport is told to lose its connection once
    that deferred fires, whether it succeeds or fails.

    Returns:
        A Deferred that fires after the transport disconnect has been requested.
    """
    d = defer.succeed(None)
    self._closing = True
    if exception is None:
        error_message = "Connection is closed"
    else:
        error_message = "Connection is closed (reason: {exc})".format(
            exc=str(exception))

    # Iterate over a snapshot of the cursors, since erroring a cursor may
    # modify the cache.
    for cursor in list(self._cursor_cache.values()):
        cursor._error(error_message)
    for _query, deferred in self._user_queries.values():
        if deferred.called:
            continue
        deferred.errback(fail=ReqlDriverError(error_message))

    self._user_queries = {}
    self._cursor_cache = {}

    if noreply_wait:
        noreply = Query(pQuery.NOREPLY_WAIT, token, None, None)
        d = self.run_query(noreply, False)

    def _drop_transport(result):
        # Pass the result (or failure) through unchanged.
        self._connection.transport.loseConnection()
        return result

    return d.addBoth(_drop_transport)
def _handleResponse(self, token, data): try: cursor = self._cursor_cache.get(token) if cursor is not None: cursor._extend(data) elif token in self._user_queries: query, deferred = self._user_queries[token] res = Response(token, data, self._parent._get_json_decoder(query)) if res.type == pResponse.SUCCESS_ATOM: deferred.callback(maybe_profile(res.data[0], res)) elif res.type in (pResponse.SUCCESS_SEQUENCE, pResponse.SUCCESS_PARTIAL): cursor = TwistedCursor(self, query, res) deferred.callback(maybe_profile(cursor, res)) elif res.type == pResponse.WAIT_COMPLETE: deferred.callback(None) elif res.type == pResponse.SERVER_INFO: deferred.callback(res.data[0]) else: deferred.errback(res.make_error(query)) del self._user_queries[token] elif not self._closing: raise ReqlDriverError("Unexpected response received.") except Exception as e: if not self._closing: self.close(exception=e)
def close(self, noreply_wait=False, token=None, exception=None):
    """Close the Tornado connection, failing every outstanding cursor and query.

    Tornado coroutine written as a generator; it finishes with
    ``gen.Return(None)``.

    Args:
        noreply_wait: when true, run a NOREPLY_WAIT query (using ``token``)
            before closing the stream.
        token: token for that NOREPLY_WAIT query.
        exception: the error that triggered the close, if any. Its text is
            included in the message delivered to cursors and pending queries.
    """
    self._closing = True
    if exception is not None:
        err_message = "Connection is closed (%s)." % str(exception)
    else:
        err_message = "Connection is closed."

    # Cursors may remove themselves when errored, so copy a list of them
    for cursor in list(self._cursor_cache.values()):
        cursor._error(err_message)

    for query, future in iter(self._user_queries.values()):
        # Skip futures that are already resolved or cancelled. Setting an
        # exception on those can raise (asyncio-backed futures raise
        # InvalidStateError), which would abort the close before the stream
        # is shut down.
        if not future.done():
            future.set_exception(ReqlDriverError(err_message))

    self._user_queries = {}
    self._cursor_cache = {}

    if noreply_wait:
        noreply = Query(pQuery.NOREPLY_WAIT, token, None, None)
        yield self.run_query(noreply, False)

    try:
        self._stream.close()
    except iostream.StreamClosedError:
        # Already closed: nothing left to do.
        pass
    raise gen.Return(None)
def _wait_to_timeout(wait): if isinstance(wait, bool): return None if wait else 0 elif isinstance(wait, numbers.Real) and wait >= 0: return wait else: raise ReqlDriverError("Invalid wait timeout '%s'" % str(wait))
def close(self, noreply_wait=False, token=None, exception=None):
    """Close the asyncio connection, failing every outstanding cursor and query.

    Generator-based coroutine (uses ``yield from``).

    Args:
        noreply_wait: when true, run a NOREPLY_WAIT query (using ``token``)
            before closing the stream writer.
        token: token for that NOREPLY_WAIT query.
        exception: the error that triggered the close, if any. When set, the
            reader task is not awaited.

    Returns:
        None.
    """
    self._closing = True
    if exception is None:
        err_message = "Connection is closed."
    else:
        err_message = "Connection is closed (%s)." % str(exception)

    # Snapshot the cursors, since an errored cursor may remove itself.
    for cursor in list(self._cursor_cache.values()):
        cursor._error(err_message)

    for _query, future in self._user_queries.values():
        if future.done():
            continue
        future.set_exception(ReqlDriverError(err_message))

    self._user_queries = {}
    self._cursor_cache = {}

    if noreply_wait:
        noreply = Query(pQuery.NOREPLY_WAIT, token, None, None)
        yield from self.run_query(noreply, False)

    self._streamwriter.close()

    # With an exception we were called from the reader task itself, so
    # waiting on that task here would deadlock.
    must_wait_for_reader = self._reader_task and exception is None
    if must_wait_for_reader:
        yield from self._reader_task
    return None
def __init__(self, *args, **kwargs):
    """Initialise the connection, using ConnectionInstance as its instance type.

    All arguments are forwarded to the base class, after which ``self.port``
    is coerced to ``int``.

    Raises:
        ReqlDriverError: if the port cannot be converted to an integer.
    """
    super(Connection, self).__init__(ConnectionInstance, *args, **kwargs)
    try:
        port_number = int(self.port)
    except ValueError:
        raise ReqlDriverError("Could not convert port %s to an integer." % self.port)
    self.port = port_number
def dataReceived(self, data):
    """Twisted callback: pass ``data`` to the handler for the current state.

    Nothing happens once ``self._open`` is false. If the handler raises, the
    transport is dropped and a ReqlDriverError describing the failure is
    raised.
    """
    try:
        if self._open:
            self._handlers[self.state](data)
    except Exception as e:
        self.transport.loseConnection()
        # The two literals are joined at compile time. The trailing space
        # keeps the message from reading "...received data.Error: ...".
        raise ReqlDriverError('Driver failed to handle received data. '
                              'Error: {exc}. Dropping the connection.'.format(exc=str(e)))
def convert_grouped_data(self, obj):
    """Convert a GROUPED_DATA pseudo-type into a dict mapping group to value.

    ``obj['data']`` is iterated as (group, value) pairs. Each group key is
    passed through ``recursively_make_hashable`` so it can serve as a dict key.
    Later duplicate groups overwrite earlier ones.

    Raises:
        ReqlDriverError: if ``obj`` has no ``data`` field.
    """
    if 'data' not in obj:
        raise ReqlDriverError(
            ('pseudo-type GROUPED_DATA object' +
             ' %s does not have the expected field "data".') % json.dumps(obj))
    return {recursively_make_hashable(group): value
            for group, value in obj['data']}
def sendall(self, data):
    """Write all of ``data`` to the socket, handling partial sends.

    A send interrupted by a signal (EINTR) is retried.

    Raises:
        ReqlDriverError: on connection reset or any other send failure; the
            wrapper is closed first.
    """
    sent = 0
    while sent < len(data):
        try:
            written = self._socket.send(data[sent:])
        except IOError as err:
            if err.errno == errno.ECONNRESET:
                self.close()
                raise ReqlDriverError("Connection is closed.")
            if err.errno != errno.EINTR:
                self.close()
                raise ReqlDriverError(
                    ('Connection interrupted ' + 'sending to %s:%s - %s') %
                    (self.host, self.port, str(err)))
            # EINTR: nothing was written; retry from the same offset.
        except Exception as err:
            self.close()
            raise ReqlDriverError('Error sending to %s:%s - %s' %
                                  (self.host, self.port, str(err)))
        else:
            sent += written
def make_error(self, query):
    """Build (but do not raise) the exception described by this error response.

    CLIENT_ERROR and COMPILE_ERROR map to fixed classes. For RUNTIME_ERROR the
    class is chosen by ``self.error_type``, defaulting to ReqlRuntimeError.
    Any other response type yields a ReqlDriverError naming that type.
    """
    if self.type == pResponse.CLIENT_ERROR:
        return ReqlDriverError(self.data[0], query.term, self.backtrace)
    if self.type == pResponse.COMPILE_ERROR:
        return ReqlServerCompileError(self.data[0], query.term, self.backtrace)
    if self.type != pResponse.RUNTIME_ERROR:
        return ReqlDriverError(("Unknown Response type %d encountered" +
                                " in a response.") % self.type)

    runtime_error_classes = {
        pErrorType.INTERNAL: ReqlInternalError,
        pErrorType.RESOURCE_LIMIT: ReqlResourceLimitError,
        pErrorType.QUERY_LOGIC: ReqlQueryLogicError,
        pErrorType.NON_EXISTENCE: ReqlNonExistenceError,
        pErrorType.OP_FAILED: ReqlOpFailedError,
        pErrorType.OP_INDETERMINATE: ReqlOpIndeterminateError,
        pErrorType.USER: ReqlUserError,
        pErrorType.PERMISSION_ERROR: ReqlPermissionError,
    }
    error_class = runtime_error_classes.get(self.error_type, ReqlRuntimeError)
    return error_class(self.data[0], query.term, self.backtrace)
def convert_time(self, obj):
    """Convert a TIME pseudo-type into a ``datetime.datetime``.

    If ``obj`` has a ``timezone`` field, the result is timezone-aware (via
    RqlTzinfo). Otherwise it is a naive datetime expressed in UTC.

    Raises:
        ReqlDriverError: if ``epoch_time`` is missing.
    """
    if 'epoch_time' not in obj:
        raise ReqlDriverError(('pseudo-type TIME object %s does not ' +
                               'have expected field "epoch_time".') % json.dumps(obj))
    seconds = obj['epoch_time']
    if 'timezone' not in obj:
        return datetime.datetime.utcfromtimestamp(seconds)
    return datetime.datetime.fromtimestamp(seconds, RqlTzinfo(obj['timezone']))
def __init__(self, conn_type, host, port, db, auth_key, user, password, timeout, ssl, _handshake_version, **kwargs):
    """Store the connection settings and build the handshake object.

    Args:
        conn_type: instance class used later to open the connection.
        host, port, db: server location and default database. ``port`` is
            coerced to ``int``.
        auth_key, password: credentials. At most one may be given. The one
            given (or '' if neither) is used for both.
        user: user name for the V1.0 handshake.
        timeout: stored as ``connect_timeout``.
        ssl: TLS options.
        _handshake_version: 4 selects HandshakeV0_4; anything else selects
            HandshakeV1_0.
        **kwargs: kept for the instance. ``json_encoder``/``json_decoder``
            are popped out (from the same dict) and override the defaults.

    Raises:
        ReqlDriverError: the port is not an integer, or both ``auth_key`` and
            ``password`` were supplied.
    """
    self.db = db
    self.host = host
    try:
        self.port = int(port)
    except ValueError:
        raise ReqlDriverError("Could not convert port %r to an integer." % port)

    self.connect_timeout = timeout
    self.ssl = ssl
    self._conn_type = conn_type
    self._child_kwargs = kwargs
    self._instance = None
    self._next_token = 0

    # These pops also remove the keys from self._child_kwargs (same dict).
    if 'json_encoder' in kwargs:
        self._json_encoder = kwargs.pop('json_encoder')
    if 'json_decoder' in kwargs:
        self._json_decoder = kwargs.pop('json_decoder')

    if auth_key is not None and password is not None:
        raise ReqlDriverError("`auth_key` and `password` are both set.")
    # At most one credential was supplied: use it (or '') for both.
    secret = auth_key if auth_key is not None else password
    if secret is None:
        secret = ''
    auth_key = password = secret

    if _handshake_version == 4:
        self.handshake = HandshakeV0_4(self.host, self.port, auth_key)
    else:
        self.handshake = HandshakeV1_0(
            self._json_decoder(), self._json_encoder(), self.host, self.port, user, password)
def recvall(self, length, deadline):
    """Read exactly ``length`` bytes from the socket, honouring ``deadline``.

    Args:
        length: number of bytes wanted.
        deadline: absolute ``time.time()`` value after which the read times
            out, or None to wait indefinitely.

    Bytes saved in ``self._read_buffer`` by an earlier call that timed out
    are used first, so a retried read resumes where it stopped. The buffer is
    cleared as soon as it is taken, so those bytes are returned only once.

    Raises:
        ReqlTimeoutError: the deadline passed. The bytes read so far are kept
            in ``self._read_buffer`` and the connection stays open.
        ReqlDriverError: the connection was closed or reset, or another socket
            error occurred; the wrapper is closed first.
    """
    res = b'' if self._read_buffer is None else self._read_buffer
    # Take ownership of any saved partial read. It is stored again below only
    # if this call also times out.
    self._read_buffer = None
    while len(res) < length:
        while True:
            # Recompute the remaining time before every recv so the deadline
            # bounds the whole read, not just the first chunk.
            timeout = None if deadline is None else max(0, deadline - time.time())
            self._socket.settimeout(timeout)
            try:
                chunk = self._socket.recv(length - len(res))
                self._socket.settimeout(None)
                break
            except socket.timeout:
                self._read_buffer = res
                self._socket.settimeout(None)
                raise ReqlTimeoutError(self.host, self.port)
            except IOError as ex:
                if ex.errno == errno.ECONNRESET:
                    self.close()
                    raise ReqlDriverError("Connection is closed.")
                elif ex.errno in (errno.EWOULDBLOCK, errno.EAGAIN):
                    # A zero timeout (deadline already reached) makes the socket
                    # non-blocking, so "no data yet" shows up as EWOULDBLOCK
                    # instead of socket.timeout. Handle it like a timeout: keep
                    # the partial data and the connection.
                    self._read_buffer = res
                    self._socket.settimeout(None)
                    raise ReqlTimeoutError(self.host, self.port)
                elif ex.errno != errno.EINTR:
                    self.close()
                    raise ReqlDriverError(('Connection interrupted ' +
                                           'receiving from %s:%s - %s') %
                                          (self.host, self.port, str(ex)))
                # EINTR: interrupted by a signal, retry the recv.
            except Exception as ex:
                self.close()
                raise ReqlDriverError('Error receiving from %s:%s - %s' %
                                      (self.host, self.port, str(ex)))
        # An empty read means the peer has closed the connection.
        if len(chunk) == 0:
            self.close()
            raise ReqlDriverError("Connection is closed.")
        res += chunk
    return res
def _read_response(self, query, deadline=None): token = query.token # We may get an async continue result, in which case we save # it and read the next response while True: try: # The first 8 bytes give the corresponding query token # of this response. The next 4 bytes give the # expected length of this response. if self._header_in_progress is None: self._header_in_progress \ = self._socket.recvall(12, deadline) (res_token, res_len,) \ = struct.unpack("<qL", self._header_in_progress) res_buf = self._socket.recvall(res_len, deadline) self._header_in_progress = None except KeyboardInterrupt as ex: # Cancel outstanding queries by dropping this connection, # then create a new connection for the user's convenience. self._parent.reconnect(noreply_wait=False) raise ex res = None cursor = self._cursor_cache.get(res_token) if cursor is not None: # Construct response cursor._extend(res_buf) if res_token == token: return res elif res_token == token: return Response( res_token, res_buf, self._parent._get_json_decoder(query)) elif not self._closing: # This response is corrupted or not intended for us self.close() raise ReqlDriverError("Unexpected response received.")
def _reader(self): try: while True: buf = yield self._stream.read_bytes(12) ( token, length, ) = struct.unpack("<qL", buf) buf = yield self._stream.read_bytes(length) cursor = self._cursor_cache.get(token) if cursor is not None: cursor._extend(buf) elif token in self._user_queries: # Do not pop the query from the dict until later, so # we don't lose track of it in case of an exception query, future = self._user_queries[token] res = Response(token, buf, self._parent._get_json_decoder(query)) if res.type == pResponse.SUCCESS_ATOM: future.set_result(maybe_profile(res.data[0], res)) elif res.type in (pResponse.SUCCESS_SEQUENCE, pResponse.SUCCESS_PARTIAL): cursor = TornadoCursor(self, query, res) future.set_result(maybe_profile(cursor, res)) elif res.type == pResponse.WAIT_COMPLETE: future.set_result(None) elif res.type == pResponse.SERVER_INFO: future.set_result(res.data[0]) else: future.set_exception(res.make_error(query)) del self._user_queries[token] elif not self._closing: raise ReqlDriverError("Unexpected response received.") except Exception as ex: if not self._closing: yield self.close(exception=ex)
def check_open(self):
    """Raise ReqlDriverError unless an instance exists and reports itself open."""
    instance = self._instance
    if instance is not None and instance.is_open():
        return
    raise ReqlDriverError('Connection is closed.')
def connect(self, timeout):
    """Open the Tornado stream to the server and perform the handshake.

    Args:
        timeout: seconds allowed for connecting and the handshake together
            (a single deadline is computed once), or None for no limit.

    Generator-based Tornado coroutine; the result is delivered via
    ``gen.Return(self._parent)``. Once the handshake completes,
    ``self._reader`` is scheduled on the IO loop to process responses.

    Raises:
        ReqlDriverError: the connection could not be established or the
            handshake failed.
        ReqlAuthError: the server rejected the credentials.
        ReqlTimeoutError: the handshake did not finish before the deadline.
    """
    deadline = None if timeout is None else self._io_loop.time() + timeout

    try:
        if len(self._parent.ssl) > 0:
            ssl_options = {}
            if self._parent.ssl["ca_certs"]:
                ssl_options['ca_certs'] = self._parent.ssl["ca_certs"]
                ssl_options['cert_reqs'] = 2  # ssl.CERT_REQUIRED
            stream_future = TCPClient().connect(self._parent.host, self._parent.port,
                                                ssl_options=ssl_options)
        else:
            stream_future = TCPClient().connect(self._parent.host, self._parent.port)

        self._stream = yield with_absolute_timeout(
            deadline,
            stream_future,
            io_loop=self._io_loop,
            quiet_exceptions=(iostream.StreamClosedError))
    except Exception as err:
        raise ReqlDriverError(
            'Could not connect to %s:%s. Error: %s' %
            (self._parent.host, self._parent.port, str(err)))

    self._stream.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
    self._stream.socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)

    try:
        self._parent.handshake.reset()
        response = None
        while True:
            request = self._parent.handshake.next_message(response)
            if request is None:
                break
            # An empty request can occur in the `V1_0` protocol, which sends
            # two requests at once as an optimization and then reads each
            # reply separately. Compare by value: `is not ""` depends on
            # string interning and is a SyntaxWarning on Python 3.8+.
            if request != "":
                self._stream.write(request)

            response = yield with_absolute_timeout(
                deadline,
                self._stream.read_until(b'\0'),
                io_loop=self._io_loop,
                quiet_exceptions=(iostream.StreamClosedError))
            # Drop the NUL terminator.
            response = response[:-1]
    except ReqlAuthError:
        try:
            self._stream.close()
        except iostream.StreamClosedError:
            pass
        raise
    except ReqlTimeoutError:
        try:
            self._stream.close()
        except iostream.StreamClosedError:
            pass
        raise ReqlTimeoutError(self._parent.host, self._parent.port)
    except Exception as err:
        try:
            self._stream.close()
        except iostream.StreamClosedError:
            pass
        raise ReqlDriverError(
            'Connection interrupted during handshake with %s:%s. Error: %s' %
            (self._parent.host, self._parent.port, str(err)))

    # Start a parallel function to perform reads
    self._io_loop.add_callback(self._reader)
    raise gen.Return(self._parent)
def __init__(self, parent):
    """Connect a blocking socket to the server, add TLS if configured, and
    run the handshake.

    Args:
        parent: object whose ``_parent`` supplies ``host``, ``port``, ``ssl``
            (TLS options; a non-empty value enables TLS) and ``handshake``.

    Raises:
        ReqlAuthError, ReqlTimeoutError: re-raised from the handshake after
            closing.
        ReqlDriverError: any other connection, TLS, or handshake failure. A
            hostname mismatch is also reported this way.
    """
    self.host = parent._parent.host
    self.port = parent._parent.port
    self._read_buffer = None
    self._socket = None
    self.ssl = parent._parent.ssl

    try:
        self._socket = socket.create_connection((self.host, self.port))
        self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)

        if len(self.ssl) > 0:
            try:
                if hasattr(ssl, 'SSLContext'):  # Python2.7 and 3.2+, or backports.ssl
                    ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
                    if hasattr(ssl_context, "options"):
                        ssl_context.options |= getattr(ssl, "OP_NO_SSLv2", 0)
                        ssl_context.options |= getattr(ssl, "OP_NO_SSLv3", 0)
                    # Configure the local context object. There is no
                    # `self.ssl_context` attribute; using it raised
                    # AttributeError and broke every TLS connection.
                    ssl_context.verify_mode = ssl.CERT_REQUIRED
                    ssl_context.check_hostname = True  # redundant with match_hostname
                    ssl_context.load_verify_locations(self.ssl["ca_certs"])
                    self._socket = ssl_context.wrap_socket(
                        self._socket, server_hostname=self.host)
                else:
                    # this does not disable SSLv2 or SSLv3
                    self._socket = ssl.wrap_socket(
                        self._socket,
                        cert_reqs=ssl.CERT_REQUIRED,
                        ssl_version=ssl.PROTOCOL_SSLv23,
                        ca_certs=self.ssl["ca_certs"])
            except IOError as exc:
                self._socket.close()
                raise ReqlDriverError(
                    "SSL handshake failed (see server log for more information): %s" %
                    str(exc))
            try:
                ssl.match_hostname(self._socket.getpeercert(), hostname=self.host)
            except ssl.CertificateError:
                self._socket.close()
                raise

        parent._parent.handshake.reset()
        response = None
        while True:
            request = parent._parent.handshake.next_message(response)
            if request is None:
                break
            # An empty request can occur in the `V1_0` protocol, which sends
            # two requests at once as an optimization and then reads each
            # reply separately. Compare by value, not identity.
            if request != "":
                self.sendall(request)

            # The response from the server is a null-terminated string
            response = b''
            while True:
                char = self.recvall(1)
                if char == b'\0':
                    break
                response += char
    except (ReqlAuthError, ReqlTimeoutError):
        self.close()
        raise
    except ReqlDriverError as ex:
        self.close()
        error = str(ex) \
            .replace('receiving from', 'during handshake with') \
            .replace('sending to', 'during handshake with')
        raise ReqlDriverError(error)
    except Exception as ex:
        self.close()
        raise ReqlDriverError("Could not connect to %s:%s. Error: %s" %
                              (self.host, self.port, ex))
def __init__(self, parent, timeout):
    """Connect a blocking socket to the server, add TLS if configured, and
    run the handshake, all within ``timeout``.

    Args:
        parent: object whose ``_parent`` supplies ``host``, ``port``, ``ssl``
            (TLS options; a non-empty value enables TLS) and ``handshake``.
        timeout: seconds allowed for connecting plus the handshake, or None
            to wait indefinitely.

    Raises:
        ReqlAuthError: credentials rejected during the handshake.
        ReqlTimeoutError: connecting or the handshake timed out.
        ReqlDriverError: any other connection, TLS, or handshake failure. A
            hostname mismatch is also reported this way.
    """
    self.host = parent._parent.host
    self.port = parent._parent.port
    self._read_buffer = None
    self._socket = None
    self.ssl = parent._parent.ssl

    # None means "no limit". Previously `time.time() + None` raised a raw
    # TypeError here; recvall() accepts a None deadline.
    deadline = None if timeout is None else time.time() + timeout

    try:
        self._socket = socket.create_connection((self.host, self.port), timeout)
        self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)

        if len(self.ssl) > 0:
            try:
                if hasattr(ssl, 'SSLContext'):  # Python2.7 and 3.2+, or backports.ssl
                    ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
                    if hasattr(ssl_context, "options"):
                        ssl_context.options |= getattr(ssl, "OP_NO_SSLv2", 0)
                        ssl_context.options |= getattr(ssl, "OP_NO_SSLv3", 0)
                    ssl_context.verify_mode = ssl.CERT_REQUIRED
                    ssl_context.check_hostname = True  # redundant with match_hostname
                    ssl_context.load_verify_locations(self.ssl["ca_certs"])
                    self._socket = ssl_context.wrap_socket(self._socket, server_hostname=self.host)
                else:
                    # this does not disable SSLv2 or SSLv3
                    self._socket = ssl.wrap_socket(
                        self._socket,
                        cert_reqs=ssl.CERT_REQUIRED,
                        ssl_version=ssl.PROTOCOL_SSLv23,
                        ca_certs=self.ssl["ca_certs"])
            except IOError as err:
                self._socket.close()

                if 'EOF occurred in violation of protocol' in str(
                        err) or 'sslv3 alert handshake failure' in str(err):
                    # probably on an older version of OpenSSL
                    raise ReqlDriverError(
                        "SSL handshake failed, likely because Python is linked against an old version of OpenSSL "
                        "that does not support either TLSv1.2 or any of the allowed ciphers. This can be worked "
                        "around by lowering the security setting on the server with the options "
                        "`--tls-min-protocol TLSv1 --tls-ciphers "
                        "EECDH+AESGCM:EDH+AESGCM:AES256+EECDH:AES256+EDH:AES256-SHA` (see server log for more "
                        "information): %s" % str(err)
                    )
                else:
                    raise ReqlDriverError(
                        "SSL handshake failed (see server log for more information): %s" %
                        str(err))
            try:
                match_hostname(self._socket.getpeercert(), hostname=self.host)
            except CertificateError:
                self._socket.close()
                raise

        parent._parent.handshake.reset()
        response = None
        while True:
            request = parent._parent.handshake.next_message(response)
            if request is None:
                break
            # An empty request can occur in the `V1_0` protocol, which sends
            # two requests at once as an optimization and then reads each
            # reply separately. Compare by value: `is not ""` depends on
            # string interning and is a SyntaxWarning on Python 3.8+.
            if request != "":
                self.sendall(request)

            # The response from the server is a null-terminated string
            response = b''
            while True:
                char = self.recvall(1, deadline)
                if char == b'\0':
                    break
                response += char
    except (ReqlAuthError, ReqlTimeoutError):
        self.close()
        raise
    except ReqlDriverError as ex:
        self.close()
        error = str(ex)\
            .replace('receiving from', 'during handshake with')\
            .replace('sending to', 'during handshake with')
        raise ReqlDriverError(error)
    except socket.timeout as ex:
        self.close()
        raise ReqlTimeoutError(self.host, self.port)
    except Exception as ex:
        self.close()
        raise ReqlDriverError("Could not connect to %s:%s. Error: %s" %
                              (self.host, self.port, str(ex)))
def convert_binary(obj):
    """Decode a BINARY pseudo-type into an RqlBinary.

    ``obj['data']`` is base64 text. It is UTF-8 encoded and then base64-decoded.

    Raises:
        ReqlDriverError: if ``obj`` has no ``data`` field.
    """
    if 'data' not in obj:
        raise ReqlDriverError(('pseudo-type BINARY object %s does not have ' +
                               'the expected field "data".') % json.dumps(obj))
    encoded = obj['data'].encode('utf-8')
    return RqlBinary(base64.b64decode(encoded))
def connect(self, timeout):
    """Open the asyncio streams to the server and perform the handshake.

    Generator-based coroutine (uses ``yield from``). ``timeout`` bounds each
    handshake reply read via ``asyncio.wait_for``; the TCP connect itself is
    not given a timeout here. On success a reader task is started and kept in
    ``self._reader_task``, and ``self._parent`` is returned.

    Raises:
        ReqlDriverError: the connection could not be established, the
            handshake failed, or a handshake read timed out (reported as
            "Connection interrupted during handshake").
        ReqlAuthError: the server rejected the credentials.
    """
    try:
        ssl_context = None
        if len(self._parent.ssl) > 0:
            ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
            if hasattr(ssl_context, "options"):
                ssl_context.options |= getattr(ssl, "OP_NO_SSLv2", 0)
                ssl_context.options |= getattr(ssl, "OP_NO_SSLv3", 0)
            ssl_context.verify_mode = ssl.CERT_REQUIRED
            ssl_context.check_hostname = True  # redundant with match_hostname
            ssl_context.load_verify_locations(self._parent.ssl["ca_certs"])

        self._streamreader, self._streamwriter = yield from asyncio.open_connection(
            self._parent.host,
            self._parent.port,
            loop=self._io_loop,
            ssl=ssl_context,
        )
        self._streamwriter.get_extra_info('socket').setsockopt(
            socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        self._streamwriter.get_extra_info('socket').setsockopt(
            socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
    except Exception as err:
        raise ReqlDriverError(
            'Could not connect to %s:%s. Error: %s' %
            (self._parent.host, self._parent.port, str(err)))

    try:
        self._parent.handshake.reset()
        response = None
        with translate_timeout_errors():
            while True:
                request = self._parent.handshake.next_message(response)
                if request is None:
                    break
                # An empty request can occur in the `V1_0` protocol, which
                # sends two requests at once as an optimization and then
                # reads each reply separately. Compare by value: `is not ""`
                # depends on string interning and is a SyntaxWarning on 3.8+.
                if request != "":
                    self._streamwriter.write(request)

                response = yield from asyncio.wait_for(
                    _read_until(self._streamreader, b'\0'),
                    timeout,
                    loop=self._io_loop,
                )
                # Drop the NUL terminator.
                response = response[:-1]
    except ReqlAuthError:
        yield from self.close()
        raise
    except ReqlTimeoutError as err:
        yield from self.close()
        raise ReqlDriverError(
            'Connection interrupted during handshake with %s:%s. Error: %s' %
            (self._parent.host, self._parent.port, str(err)))
    except Exception as err:
        yield from self.close()
        raise ReqlDriverError(
            'Could not connect to %s:%s. Error: %s' %
            (self._parent.host, self._parent.port, str(err)))

    # Start a parallel function to perform reads
    # store a reference to it so it doesn't get destroyed
    self._reader_task = asyncio.ensure_future(self._reader(), loop=self._io_loop)
    return self._parent
def __iter__(*args, **kwargs):
    """Refuse iteration: a query must be run before its results can be iterated."""
    message = ("__iter__ called on an RqlQuery object.\n"
               "To iterate over the results of a query, call run first.\n"
               "To iterate inside a query, use map or for_each.")
    raise ReqlDriverError(message)
def next_message(self, response):
    """Advance the V1.0 (SCRAM-SHA-256) handshake state machine by one step.

    ``response`` is the server's previous reply as bytes, or None on the
    first call. The connection code in this file strips the trailing NUL
    before passing it in.

    Returns:
        The bytes to send next (NUL-terminated JSON), ``""`` when nothing
        needs to be sent but another reply must still be read, or None once
        the handshake is complete.

    States:
        0 -> send the protocol VERSION plus the SCRAM client-first message,
             built from 18 random bytes base64-encoded as the client nonce.
        1 -> validate the server's protocol-version reply (nothing sent).
        2 -> parse the server-first message (nonce, salt, iteration count)
             and send the client-final message carrying the client proof.
        3 -> verify the server signature; the handshake is done.

    Raises:
        ReqlAuthError: the server reports an error_code in 10..20, the server
            nonce does not start with ours, or the server signature does
            not match.
        ReqlDriverError: out-of-order call, other server errors, an
            unsupported protocol version, or a reply missing an expected key.
    """
    # NOTE(review): `json`, `min`, `max` and (in state 2) `r` are plain locals
    # here; `json`, `min` and `max` shadow the json module and the builtins
    # within this method only.
    if self._state == 0:
        if response is not None:
            raise ReqlDriverError("Unexpected response")

        # Using base64 encoding for printable characters
        self._r = base64.standard_b64encode(
            bytes(bytearray(
                self._random.getrandbits(8) for i in range(18))))

        # client-first-message-bare: "n=<username>,r=<client nonce>".
        self._client_first_message_bare = b"n=" + self._username + b",r=" + self._r

        # Here we send the version as well as the initial JSON as an optimization
        self._state = 1
        return struct.pack("<L", self.VERSION) + \
            self._json_encoder.encode({
                "protocol_version": self._protocol_version,
                "authentication_method": "SCRAM-SHA-256",
                "authentication": (b"n,," + self._client_first_message_bare).decode("ascii")
            }).encode("utf-8") + \
            b'\0'
    elif self._state == 1:
        response = response.decode("utf-8")
        # A reply starting with "ERROR" is not JSON; per the message below it
        # presumably comes from a server older than this protocol.
        if response.startswith("ERROR"):
            raise ReqlDriverError(
                "Received an unexpected reply. You may be attempting to connect to a RethinkDB server that is too "
                "old for this driver. The minimum supported server version is 2.3.0."
            )

        json = self._json_decoder.decode(response)
        try:
            if json["success"] is False:
                # Error codes 10..20 are treated as authentication failures.
                if 10 <= json["error_code"] <= 20:
                    raise ReqlAuthError(json["error"], self._host, self._port)
                else:
                    raise ReqlDriverError(json["error"])

            min = json["min_protocol_version"]
            max = json["max_protocol_version"]
            if not min <= self._protocol_version <= max:
                raise ReqlDriverError(
                    "Unsupported protocol version %d, expected between %d and %d" %
                    (self._protocol_version, min, max))
        except KeyError as key_error:
            raise ReqlDriverError("Missing key: %s" % (key_error, ))

        # We've already sent the initial JSON above, and only support a single
        # protocol version at the moment thus we simply read the next response
        self._state = 2
        return ""
    elif self._state == 2:
        json = self._json_decoder.decode(response.decode("utf-8"))
        server_first_message = r = salt = i = None
        try:
            if json["success"] is False:
                if 10 <= json["error_code"] <= 20:
                    raise ReqlAuthError(json["error"], self._host, self._port)
                else:
                    raise ReqlDriverError(json["error"])

            # server-first-message: comma-separated "key=value" pairs; the
            # keys used are r (combined nonce), s (base64 salt), i (iterations).
            server_first_message = json["authentication"].encode("ascii")
            authentication = dict(
                x.split(b"=", 1)
                for x in server_first_message.split(b","))

            r = authentication[b"r"]
            # The server's nonce must extend the client nonce sent in state 0.
            if not r.startswith(self._r):
                raise ReqlAuthError("Invalid nonce from server", self._host, self._port)
            salt = base64.standard_b64decode(authentication[b"s"])
            i = int(authentication[b"i"])
        except KeyError as key_error:
            raise ReqlDriverError("Missing key: %s" % (key_error, ))

        # "biws" is base64 of b"n,,", the prefix sent in state 0.
        client_final_message_without_proof = b"c=biws,r=" + r

        # SaltedPassword := Hi(Normalize(password), salt, i)
        salted_password = self._pbkdf2_hmac("sha256", self._password, salt, i)

        # ClientKey := HMAC(SaltedPassword, "Client Key")
        client_key = hmac.new(salted_password, b"Client Key", hashlib.sha256).digest()

        # StoredKey := H(ClientKey)
        stored_key = hashlib.sha256(client_key).digest()

        # AuthMessage := client-first-message-bare + "," +
        #                server-first-message + "," +
        #                client-final-message-without-proof
        auth_message = b",".join((
            self._client_first_message_bare,
            server_first_message,
            client_final_message_without_proof))

        # ClientSignature := HMAC(StoredKey, AuthMessage)
        client_signature = hmac.new(stored_key, auth_message, hashlib.sha256).digest()

        # ClientProof := ClientKey XOR ClientSignature
        # (byte-wise XOR of the two 32-byte SHA-256 digests)
        client_proof = struct.pack(
            "32B",
            *(l ^ r for l, r in zip(
                struct.unpack("32B", client_key),
                struct.unpack("32B", client_signature))))

        # ServerKey := HMAC(SaltedPassword, "Server Key")
        server_key = hmac.new(salted_password, b"Server Key", hashlib.sha256).digest()

        # ServerSignature := HMAC(ServerKey, AuthMessage)
        # Kept so state 3 can verify the server's final message.
        self._server_signature = hmac.new(server_key, auth_message, hashlib.sha256).digest()

        self._state = 3
        return self._json_encoder.encode({
            "authentication": (
                client_final_message_without_proof +
                b",p=" +
                base64.standard_b64encode(client_proof)
            ).decode("ascii")
        }).encode("utf-8") + \
            b'\0'
    elif self._state == 3:
        json = self._json_decoder.decode(response.decode("utf-8"))
        v = None
        try:
            if json["success"] is False:
                if 10 <= json["error_code"] <= 20:
                    raise ReqlAuthError(json["error"], self._host, self._port)
                else:
                    raise ReqlDriverError(json["error"])

            # server-final-message: "v=<base64 server signature>".
            authentication = dict(
                x.split(b"=", 1)
                for x in json["authentication"].encode("ascii").split(b","))

            v = base64.standard_b64decode(authentication[b"v"])
        except KeyError as key_error:
            raise ReqlDriverError("Missing key: %s" % (key_error, ))

        # Authenticate the server against the signature computed in state 2.
        if not self._compare_digest(v, self._server_signature):
            raise ReqlAuthError("Invalid server signature", self._host, self._port)

        self._state = 4
        return None
    else:
        raise ReqlDriverError("Unexpected handshake state")