예제 #1
0
파일: direct.py 프로젝트: MikeyQiu/graduate
def _connect(resolved_address, **config):
    """ Open a TCP socket to a single resolved address.

    The address family is chosen from the length of ``resolved_address``:
    a 2-tuple uses ``AF_INET`` and a 4-tuple uses ``AF_INET6``. The
    ``connection_timeout`` option applies only while connecting; the
    socket's previous timeout is restored afterwards. ``SO_KEEPALIVE`` is
    set according to the ``keep_alive`` option.

    :param resolved_address: address tuple produced by the resolver
    :param config: driver configuration; ``connection_timeout`` and
                   ``keep_alive`` are read here
    :return: connected socket object
    :raise ValueError: if the address tuple has an unsupported length
    :raise ServiceUnavailable: if connecting times out or fails
    """
    s = None
    try:
        if len(resolved_address) == 2:
            s = socket(AF_INET)
        elif len(resolved_address) == 4:
            s = socket(AF_INET6)
        else:
            raise ValueError("Unsupported address {!r}".format(resolved_address))
        t = s.gettimeout()
        s.settimeout(config.get("connection_timeout", DEFAULT_CONNECTION_TIMEOUT))
        log_debug("[#0000]  C: <OPEN> %s", resolved_address)
        s.connect(resolved_address)
        # Restore the original timeout: the connection timeout is only meant
        # to bound the connect() call itself.
        s.settimeout(t)
        s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, 1 if config.get("keep_alive", DEFAULT_KEEP_ALIVE) else 0)
    except SocketTimeout:
        log_debug("[#0000]  C: <TIMEOUT> %s", resolved_address)
        if s is not None:
            log_debug("[#0000]  C: <CLOSE> %s", resolved_address)
            s.close()
        raise ServiceUnavailable("Timed out trying to establish connection to {!r}".format(resolved_address))
    except (IOError, OSError) as error:  # TODO 2.0: remove IOError alias
        log_debug("[#0000]  C: <ERROR> %s %s", type(error).__name__, " ".join(map(repr, error.args)))
        # socket() itself may have failed, in which case there is nothing to
        # close; calling s.close() on None would mask the real error.
        if s is not None:
            log_debug("[#0000]  C: <CLOSE> %s", resolved_address)
            s.close()
        raise ServiceUnavailable("Failed to establish connection to {!r} (reason {})".format(resolved_address, error))
    else:
        return s
예제 #2
0
def connect(address, **config):
    """ Connect and perform a handshake and return a valid Connection object, assuming
    a protocol version can be agreed.

    The address is run through the custom resolver (if configured) and then
    DNS resolution; each resulting address is tried in turn until one
    connects, is secured and completes the handshake.

    :param address: unresolved address; ``address[0]`` is passed to
                    ``_secure`` as the host name
    :param config: driver configuration, forwarded to the helpers
    :return: the Connection for the first address that succeeds
    :raise ServiceUnavailable: if resolution yields no addresses at all
    :raise: the last error encountered if every resolved address failed
    """
    security_plan = SecurityPlan.build(**config)
    last_error = None
    # Establish a connection to the host and port specified
    # Catches refused connections see:
    # https://docs.python.org/2/library/errno.html
    log_debug("[#0000]  C: <RESOLVE> %s", address)
    resolver = Resolver(custom_resolver=config.get("resolver"))
    resolver.addresses.append(address)
    resolver.custom_resolve()
    resolver.dns_resolve()
    for resolved_address in resolver.addresses:
        s = None
        try:
            s = _connect(resolved_address, **config)
            s, der_encoded_server_certificate = _secure(
                s, address[0], security_plan.ssl_context, **config)
            connection = _handshake(s, resolved_address,
                                    der_encoded_server_certificate, **config)
        except Exception as error:
            last_error = error
            # Release the socket opened for this attempt so that failed
            # addresses do not leak file descriptors. Closing a socket that
            # a helper already closed is harmless.
            if s is not None:
                try:
                    s.close()
                except (IOError, OSError):  # TODO 2.0: remove IOError alias
                    pass
        else:
            return connection
    if last_error is None:
        # address is a tuple, so it must be wrapped: a bare tuple on the
        # right of % would be unpacked into several format arguments.
        raise ServiceUnavailable("Failed to resolve addresses for %s" %
                                 (address,))
    else:
        raise last_error
예제 #3
0
 def on_failure(self, metadata):
     """ Turn a failure summary into a raised exception.

     A ``code`` of ``Neo.ClientError.Security.Unauthorized`` raises
     :class:`AuthError`; any other failure raises
     :class:`ServiceUnavailable`. The ``message`` entry is used as the
     exception text, defaulting to a generic initialisation-failure message.
     """
     failure_code = metadata.get("code")
     failure_text = metadata.get("message", "Connection initialisation failed")
     is_unauthorized = failure_code == "Neo.ClientError.Security.Unauthorized"
     error_type = AuthError if is_unauthorized else ServiceUnavailable
     raise error_type(failure_text)
예제 #4
0
    def acquire_direct(self, address):
        """ Hand out a connection to ``address``, reusing a pooled one when possible.

        ``address`` should be an IP address rather than a host name.
        Stale pooled entries (closed, defunct or timed out) are discarded and
        the first idle connection is claimed. If none is idle, a new one is
        opened while the pool is below its size limit. Otherwise the caller
        waits on the pool condition until a connection is released or the
        acquisition timeout elapses.

        The pool is inspected and modified while holding ``self.lock``, so
        this may be called from several threads.

        :raise ServiceUnavailable: if the pool is closed, or if opening a new
                                   connection fails (the address is then
                                   removed from the pool)
        :raise ClientError: if no connection was obtained within the
                            acquisition timeout
        """
        if self.closed():
            raise ServiceUnavailable("Connection pool closed")
        with self.lock:
            try:
                pool = self.connections[address]
            except KeyError:
                pool = self.connections[address] = deque()

            started = perf_counter()
            while True:
                # Walk a snapshot so stale entries can be dropped in passing.
                for candidate in list(pool):
                    if candidate.closed() or candidate.defunct() or candidate.timedout():
                        pool.remove(candidate)
                    elif not candidate.in_use:
                        candidate.in_use = True
                        return candidate

                # Nothing idle: open a fresh connection if the size limit allows.
                size_limit = self._max_connection_pool_size
                if size_limit == INFINITE or len(pool) < size_limit:
                    try:
                        fresh = self.connector(
                            address, error_handler=self.connection_error_handler)
                    except ServiceUnavailable:
                        self.remove(address)
                        raise
                    fresh.pool = self
                    fresh.in_use = True
                    pool.append(fresh)
                    return fresh

                # Pool is full and busy: wait for a release within the deadline.
                remaining = self._connection_acquisition_timeout - (perf_counter() - started)
                if remaining > 0:
                    self.cond.wait(remaining)
                    # Re-measure the elapsed time: on Python 2.7 wait() does not
                    # report whether it was notified or simply timed out.
                    if self._connection_acquisition_timeout > perf_counter() - started:
                        continue
                raise ClientError(
                    "Failed to obtain a connection from pool within {!r}s".format(
                        self._connection_acquisition_timeout))
예제 #5
0
    def fetch_routing_info(self, address):
        """ Ask the router at ``address`` for its raw routing records.

        Servers whose agent string reports version 3.2 or later are queried
        with ``getRoutingTable`` and the pool's routing context; older or
        unidentified servers are queried with ``getServers``.

        :param address: router address
        :return: list of records (dicts keyed by the result fields), or
                 None if no connection could be established
        :raise ServiceUnavailable: if the server does not support routing or
                                   if routing support is broken
        """
        summary = {}
        rows = []

        def on_routing_failure(failure_metadata):
            not_found = failure_metadata.get("code") == "Neo.ClientError.Procedure.ProcedureNotFound"
            template = ("Server {!r} does not support routing" if not_found
                        else "Routing support broken on server {!r}")
            raise RoutingProtocolError(template.format(address))

        try:
            with self.acquire_direct(address) as cx:
                agent = cx.server.agent or ""
                server_version = agent.partition("/")[2]
                if server_version and Version.parse(server_version) >= Version((3, 2)):
                    statement = "CALL dbms.cluster.routing.getRoutingTable({context})"
                    parameters = {"context": self.routing_context}
                else:
                    statement = "CALL dbms.cluster.routing.getServers"
                    parameters = {}
                cx.run(statement, parameters, on_success=summary.update, on_failure=on_routing_failure)
                cx.pull_all(on_success=summary.update, on_records=rows.extend)
                cx.sync()
        except RoutingProtocolError as error:
            raise ServiceUnavailable(*error.args)
        except ServiceUnavailable:
            self.deactivate(address)
            return None
        field_names = summary.get("fields", ())
        return [dict(zip(field_names, values)) for values in rows]
예제 #6
0
    def _run_transaction(self, access_mode, unit_of_work, *args, **kwargs):
        """ Run ``unit_of_work`` inside a transaction, retrying on transient failures.

        ``unit_of_work`` is called as ``unit_of_work(tx, *args, **kwargs)``.
        Its optional ``metadata`` and ``timeout`` attributes are passed on
        when the transaction is opened. If the unit of work returns normally
        and has not set ``tx.success`` itself, the transaction is marked
        successful. If it raises, the transaction is marked unsuccessful.
        The transaction is closed in either case.

        Failures of type ServiceUnavailable, SessionExpired, ConnectionExpired
        and retriable TransientErrors are retried, with delays taken from
        ``retry_delay_generator``, until ``self._max_retry_time`` has elapsed.

        :return: whatever ``unit_of_work`` returned
        :raise TypeError: if ``unit_of_work`` is not callable
        :raise: the last retried error once retries are exhausted, or any
                non-retriable error immediately
        """
        from neobolt.exceptions import ConnectionExpired, TransientError, ServiceUnavailable

        if not callable(unit_of_work):
            raise TypeError("Unit of work is not callable")

        metadata = getattr(unit_of_work, "metadata", None)
        timeout = getattr(unit_of_work, "timeout", None)

        retry_delay = retry_delay_generator(INITIAL_RETRY_DELAY,
                                            RETRY_DELAY_MULTIPLIER,
                                            RETRY_DELAY_JITTER_FACTOR)
        errors = []
        t0 = perf_counter()
        while True:
            try:
                self._open_transaction(access_mode, metadata, timeout)
                tx = self._transaction
                try:
                    result = unit_of_work(tx, *args, **kwargs)
                except Exception:
                    tx.success = False
                    raise
                else:
                    if tx.success is None:
                        tx.success = True
                finally:
                    tx.close()
            except (ServiceUnavailable, SessionExpired,
                    ConnectionExpired) as error:
                errors.append(error)
            except TransientError as error:
                if is_retriable_transient_error(error):
                    errors.append(error)
                else:
                    raise
            else:
                return result
            t1 = perf_counter()
            if t1 - t0 > self._max_retry_time:
                break
            delay = next(retry_delay)
            # Exception args are not guaranteed to be strings (codes, nested
            # objects...). Joining them raw would raise TypeError here and
            # abort the retry loop with an unrelated error.
            log.warning("Transaction failed and will be retried in {}s "
                        "({})".format(delay, "; ".join(map(str, errors[-1].args))))
            sleep(delay)
        if errors:
            raise errors[-1]
        else:
            raise ServiceUnavailable("Transaction failed")
예제 #7
0
    def _run_transaction(self, access_mode, unit_of_work, *args, **kwargs):
        """ Execute ``unit_of_work(tx, *args, **kwargs)`` with automatic retries.

        On a normal return the transaction is marked successful, and on any
        exception it is marked failed, but only if ``tx.success`` was still
        unset. The transaction is always closed afterwards. Connectivity
        errors and retriable TransientErrors trigger a retry after a delay
        from ``retry_delay_generator``. Retrying stops once
        ``self._max_retry_time`` has been exceeded.

        :return: the unit of work's return value
        :raise TypeError: if ``unit_of_work`` is not callable
        :raise: the most recent retried error after giving up, or any
                non-retriable error straight away
        """
        from neobolt.exceptions import ConnectionExpired, TransientError, ServiceUnavailable

        if not callable(unit_of_work):
            raise TypeError("Unit of work is not callable")
        delays = retry_delay_generator(INITIAL_RETRY_DELAY,
                                       RETRY_DELAY_MULTIPLIER,
                                       RETRY_DELAY_JITTER_FACTOR)
        failures = []
        started_at = perf_counter()
        while True:
            try:
                self._open_transaction(access_mode)
                tx = self._transaction
                try:
                    outcome = unit_of_work(tx, *args, **kwargs)
                except BaseException:
                    # Respect an explicit decision made by the unit of work.
                    if tx.success is None:
                        tx.success = False
                    raise
                else:
                    if tx.success is None:
                        tx.success = True
                finally:
                    tx.close()
            except (ServiceUnavailable, SessionExpired, ConnectionExpired) as error:
                failures.append(error)
            except TransientError as error:
                if not is_retriable_transient_error(error):
                    raise
                failures.append(error)
            else:
                return outcome
            if perf_counter() - started_at > self._max_retry_time:
                break
            sleep(next(delays))
        if failures:
            raise failures[-1]
        raise ServiceUnavailable("Transaction failed")
예제 #8
0
    def update_routing_table(self):
        """ Refresh the routing table from the first router that answers.

        Candidates are tried in this order:
        1. the initial address, if the table is missing a writer;
        2. every router currently in the table;
        3. the initial address, if it was not tried in step 1 and is not
           already one of the table's routers.

        :raise ServiceUnavailable: if no candidate produced routing info
        """
        # Snapshot the routers: a refresh attempt may modify the table.
        known_routers = list(self.routing_table.routers)
        initial_tried = self.missing_writer

        if initial_tried and self.update_routing_table_from(self.initial_address):
            return
        if self.update_routing_table_from(*known_routers):
            return
        should_try_initial = not initial_tried and self.initial_address not in known_routers
        if should_try_initial and self.update_routing_table_from(self.initial_address):
            return

        # None of the routers have been successful, so just fail
        raise ServiceUnavailable("Unable to retrieve routing information")
예제 #9
0
    def commit_transaction(self):
        """ Commit the currently open transaction.

        The session disconnects and forgets the transaction whether or not
        the commit succeeds. Afterwards the returned bookmark becomes both
        the outgoing bookmark and the sole incoming bookmark.

        :returns: the ``bookmark`` entry of the commit summary, or ``None``
                  if the server sent none
        :raise: :class:`.TransactionError` if no transaction is currently open
        :raise: :class:`.ServiceUnavailable` if the commit is interrupted
                (IncompleteCommitError)
        """
        self._assert_open()
        if not self._transaction:
            raise TransactionError("No transaction to commit")
        summary = {}
        try:
            self._connection.commit(on_success=summary.update)
            self._connection.send_all()
            self._connection.fetch_all()
        except IncompleteCommitError:
            raise ServiceUnavailable("Connection closed during commit")
        finally:
            # Release the connection and clear the transaction on every path.
            self._disconnect()
            self._transaction = None
        bookmark = summary.get("bookmark")
        # NOTE(review): with no bookmark in the summary this stores (None,);
        # confirm downstream code tolerates a None bookmark.
        self._bookmarks_in = (bookmark,)
        self._bookmark_out = bookmark
        return bookmark
예제 #10
0
def _handshake(s, resolved_address, der_encoded_server_certificate, **config):
    """ Negotiate a Bolt protocol version over a connected socket.

    Sends the magic preamble followed by the four supported protocol
    versions, then waits for the server's four-byte reply. For versions
    1-3 it builds a :class:`Connection` around the socket and initialises
    it: ``init()`` for versions 1-2, ``hello()`` for version 3. Each
    negotiation failure detected here closes the socket before raising.
    Errors raised by ``init()``/``hello()`` propagate unchanged.

    :param s: connected socket (possibly TLS-wrapped)
    :param resolved_address: address ``s`` is connected to; given to the
                             Connection and used in error messages
    :param der_encoded_server_certificate: server certificate handed on to
                                           the Connection
    :param config: extra configuration passed to the Connection
    :return: an initialised Connection
    :raise ServiceUnavailable: if no reply can be read, the server closes
                               the connection, or the peer looks like HTTP
    :raise ProtocolError: if the reply is malformed, the server accepts none
                          of the offered versions, or it picks an unknown one
    """
    local_port = s.getsockname()[1]

    # Send details of the protocol versions supported
    supported_versions = [3, 2, 1, 0]
    handshake = [MAGIC_PREAMBLE] + supported_versions
    log_debug("[#%04X]  C: <MAGIC> 0x%08X", local_port, MAGIC_PREAMBLE)
    log_debug("[#%04X]  C: <HANDSHAKE> 0x%08X 0x%08X 0x%08X 0x%08X",
              local_port, *supported_versions)
    data = b"".join(struct_pack(">I", num) for num in handshake)
    s.sendall(data)

    # Handle the handshake response. Poll in one-second slices until the
    # socket becomes readable.
    # NOTE(review): there is no overall deadline; a peer that accepts the TCP
    # connection but never replies keeps this loop spinning indefinitely.
    ready_to_read = False
    while not ready_to_read:
        ready_to_read, _, _ = select((s, ), (), (), 1)
    try:
        data = s.recv(4)
    except (IOError, OSError):  # TODO 2.0: remove IOError alias
        # Close the socket on this path too, instead of leaking it.
        log_debug("[#%04X]  C: <CLOSE>", local_port)
        s.close()
        raise ServiceUnavailable(
            "Failed to read any data from server {!r} after connected".format(
                resolved_address))
    data_size = len(data)
    if data_size == 0:
        # If no data is returned after a successful select
        # response, the server has closed the connection
        log_debug("[#%04X]  S: <CLOSE>", local_port)
        s.close()
        raise ServiceUnavailable(
            "Connection to %r closed without handshake response" %
            (resolved_address, ))
    if data_size != 4:
        # Some garbled data has been received
        log_debug("[#%04X]  S: @*#!", local_port)
        s.close()
        raise ProtocolError(
            "Expected four byte Bolt handshake response from %r, received %r instead; "
            "check for incorrect port number" % (resolved_address, data))
    agreed_version, = struct_unpack(">I", data)
    log_debug("[#%04X]  S: <HANDSHAKE> 0x%08X", local_port, agreed_version)
    if agreed_version == 0:
        # The server accepted none of the offered versions. Previously this
        # path fell through and returned None, which callers would then treat
        # as a connection.
        log_debug("[#%04X]  C: <CLOSE>", local_port)
        try:
            s.shutdown(SHUT_RDWR)
        except (IOError, OSError):  # TODO 2.0: remove IOError alias
            # The peer may already have gone away; closing is what matters.
            pass
        s.close()
        raise ProtocolError(
            "Server {!r} does not support any of the offered Bolt protocol "
            "versions".format(resolved_address))
    if agreed_version in (1, 2, 3):
        connection = Connection(
            agreed_version,
            resolved_address,
            s,
            der_encoded_server_certificate=der_encoded_server_certificate,
            **config)
        # Bolt v3 replaced the INIT message with HELLO.
        if agreed_version == 3:
            connection.hello()
        else:
            connection.init()
        return connection
    if agreed_version == 0x48545450:
        # 0x48545450 is ASCII "HTTP": an HTTP server answered the handshake.
        log_debug("[#%04X]  S: <CLOSE>", local_port)
        s.close()
        raise ServiceUnavailable("Cannot to connect to Bolt service on {!r} "
                                 "(looks like HTTP)".format(resolved_address))
    log_debug("[#%04X]  S: <CLOSE>", local_port)
    s.close()
    raise ProtocolError(
        "Unknown Bolt protocol version: {}".format(agreed_version))