class ThreadedUrllib2TestMixin(object):
    """Test mixin that fetches pages with blocking ``urlopen`` calls run via
    ``deferToThread``, letting at most two fetches run at the same time.
    """

    def setUp(self):
        # Two tokens: at most two getPage() calls are in flight at once.
        self._semaphore = DeferredSemaphore(2)

    def tearDown(self):
        pass

    def getPages(self, count, url):
        """Fetch ``url`` ``count`` times.

        Returns a Deferred that fires with the list of page bodies.
        """
        return gatherResults([self.getPage(url) for _ in range(count)])

    @inlineCallbacks
    def getPage(self, url):
        """Fetch ``url`` in a worker thread once a semaphore token is free.

        Returns a Deferred that fires with the page body.
        """
        yield self._semaphore.acquire()
        try:
            page = yield deferToThread(self._openPage, url)
        finally:
            # Release even when the fetch fails; otherwise every failed
            # request would permanently use up one of the two tokens.
            self._semaphore.release()
        returnValue(page)

    def _openPage(self, url):
        """Blocking fetch of ``url``; called from a worker thread."""
        log.msg("Opening url: %r" % url)
        return urlopen(url).read()

    @inlineCallbacks
    def getPageLength(self, url):
        """Return (via Deferred) the length of the body fetched from ``url``."""
        response = yield self.getPage(url)
        returnValue(len(response))
Beispiel #2
0
class PlotlyStreamProducer(object):
    """Implements a producer that copies from a buffer to a plot.ly
    connection.

    Each buffered value is written to the consumer as a JSON document
    followed by a newline. A ``None`` value is written as a bare newline;
    this is what the keepalive timer enqueues.
    """
    implements(IBodyProducer)
    length = UNKNOWN_LENGTH

    def __init__(self, buffer, start_callback=None):
        """
        Args:
            buffer: list-like object (``append``, ``pop(0)`` and ``len`` are
                used) holding the values waiting to be sent.
            start_callback: optional Deferred that is fired with ``None``
                when ``startProducing`` is called.
        """
        self.buffer = buffer
        self._done = False
        # Signals "there is something to send". It is acquired right away so
        # the producing loop blocks after its first pass until flush().
        self._flush = DeferredSemaphore(1)
        # Held while paused and while the loop is draining the buffer.
        self._waiter = DeferredSemaphore(1)
        self._flush.acquire()
        self._started = start_callback
        self._keepalive = LoopingCall(self._send_keepalive)

    @inlineCallbacks
    def startProducing(self, consumer):
        """Drain the buffer into ``consumer`` every time ``flush()`` is
        called, until ``stopProducing()`` has been called.
        """
        self._keepalive.start(60)
        # start_callback is optional (defaults to None), so only fire it
        # when the caller actually supplied one.
        if self._started is not None:
            self._started.callback(None)
        while True:
            # if paused, this will block
            yield self._waiter.acquire()
            while len(self.buffer):
                v = self.buffer.pop(0)
                if v is not None:
                    consumer.write(json.dumps(v))
                consumer.write("\n")
            yield self._waiter.release()

            if self._done:
                return
            # Wait for the next flush() before draining again.
            yield self._flush.acquire()

    def pauseProducing(self):
        """Take the waiter token so the producing loop blocks."""
        return self._waiter.acquire()

    def resumeProducing(self):
        """Give the waiter token back so the producing loop can continue."""
        return self._waiter.release()

    def stopProducing(self):
        """Mark the producer as done and stop the keepalive timer."""
        self._done = True
        if self._keepalive.running:
            self._keepalive.stop()

    def _send_keepalive(self):
        # A None entry is written as an empty line by startProducing().
        self.buffer.append(None)
        self.flush()

    def flush(self):
        """Wake the producing loop if it is waiting for data."""
        # Only release when no token is available, so repeated flush() calls
        # never release the semaphore more often than it was acquired.
        if self._flush.tokens == 0:
            self._flush.release()
class TwistedWebTestMixin(object):
    """Test mixin that fetches pages with ``tx_getPage``, letting at most
    two fetches run at the same time.
    """

    def setUp(self):
        # Two tokens: at most two getPage() calls are in flight at once.
        self._semaphore = DeferredSemaphore(2)

    def tearDown(self):
        pass

    def getPages(self, count, url):
        """Fetch ``url`` ``count`` times.

        Returns a Deferred that fires with the list of page bodies.
        """
        # Not decorated with @inlineCallbacks: this is a plain function that
        # returns a Deferred, not a generator, and inlineCallbacks only
        # works on generator functions.
        return gatherResults([self.getPage(url) for _ in range(count)])

    @inlineCallbacks
    def getPage(self, url):
        """Fetch ``url`` once a semaphore token is free.

        Returns a Deferred that fires with the page body.
        """
        yield self._semaphore.acquire()
        try:
            page = yield tx_getPage(url)
        finally:
            # Release even when the fetch fails; otherwise every failed
            # request would permanently use up one of the two tokens.
            self._semaphore.release()
        returnValue(page)

    @inlineCallbacks
    def getPageLength(self, url):
        """Return (via Deferred) the length of the body fetched from ``url``."""
        response = yield self.getPage(url)
        returnValue(len(response))
Beispiel #4
0
class GcmPushkin(ConcurrencyLimitedPushkin):
    """
    Pushkin that relays notifications to Google/Firebase Cloud Messaging.
    """

    # Configuration keys this pushkin recognises, merged with the keys of the
    # base class. __init__ logs a warning for any configured key that is not
    # in this set.
    UNDERSTOOD_CONFIG_FIELDS = {
        "type",
        "api_key",
        "fcm_options",
        "max_connections",
    } | ConcurrencyLimitedPushkin.UNDERSTOOD_CONFIG_FIELDS

    def __init__(self, name: str, sygnal: "Sygnal", config: Dict[str,
                                                                 Any]) -> None:
        """
        Set up the HTTP machinery and read this pushkin's configuration.

        Args:
            name: Name of this pushkin.
            sygnal: The Sygnal instance; its reactor and global ``proxy``
                setting are used for outgoing requests.
            config: Configuration dictionary for this pushkin.

        Raises:
            PushkinSetupException: if no API key is configured or
                ``fcm_options`` is not a dictionary.
        """
        super().__init__(name, sygnal, config)

        unknown_fields = set(self.cfg.keys()) - self.UNDERSTOOD_CONFIG_FIELDS
        if unknown_fields:
            logger.warning(
                "The following configuration fields are not understood: %s",
                unknown_fields,
            )

        self.http_pool = HTTPConnectionPool(reactor=sygnal.reactor)
        self.max_connections = self.get_config(
            "max_connections", int, DEFAULT_MAX_CONNECTIONS
        )

        # The semaphore is what limits concurrency; the pool setting only
        # controls how many connections are kept alive per host.
        self.connection_semaphore = DeferredSemaphore(self.max_connections)
        self.http_pool.maxPersistentPerHost = self.max_connections

        self.http_agent = ProxyAgent(
            reactor=sygnal.reactor,
            pool=self.http_pool,
            contextFactory=ClientTLSOptionsFactory(),
            # use the Sygnal global proxy configuration
            proxy_url_str=sygnal.config.get("proxy"),
        )

        self.api_key = self.get_config("api_key", str)
        if not self.api_key:
            raise PushkinSetupException("No API key set in config")

        # The fcm_options config dictionary is the starting point for every
        # request body, so the Sygnal admin can choose custom FCM options
        # (e.g. content_available).
        self.base_request_body = self.get_config("fcm_options", dict, {})
        if not isinstance(self.base_request_body, dict):
            raise PushkinSetupException(
                "Config field fcm_options, if set, must be a dictionary of options"
            )

    @classmethod
    async def create(cls, name: str, sygnal: "Sygnal",
                     config: Dict[str, Any]) -> "GcmPushkin":
        """
        Asynchronous factory for this pushkin.

        Nothing needs to be awaited here, so this simply calls the regular
        ``__init__`` constructor with the given arguments.

        Returns:
            an instance of this Pushkin
        """
        pushkin = cls(name, sygnal, config)
        return pushkin

    async def _perform_http_request(
            self, body: Dict,
            headers: Dict[AnyStr, List[AnyStr]]) -> Tuple[IResponse, str]:
        """
        POST the given body to the FCM server.

        Args:
            body: Request body. Will be JSON-encoded.
            headers: HTTP headers to send.

        Returns:
            A tuple of the response object and the decoded response body.

        Raises:
            TemporaryNotificationDispatchException: if sending the request or
                reading the response body raises any exception.
        """
        encoded_body = json.dumps(body).encode()
        body_producer = FileBodyProducer(BytesIO(encoded_body))

        # The semaphore is what actually limits the number of concurrent
        # requests: the HTTPConnectionPool does not limit anything, it would
        # just create more (unpooled) connections.
        with QUEUE_TIME_HISTOGRAM.time(), \
                PENDING_REQUESTS_GAUGE.track_inprogress():
            await self.connection_semaphore.acquire()

        try:
            with SEND_TIME_HISTOGRAM.time(), \
                    ACTIVE_REQUESTS_GAUGE.track_inprogress():
                response = await self.http_agent.request(
                    b"POST",
                    GCM_URL,
                    headers=Headers(headers),
                    bodyProducer=body_producer,
                )
                raw_body = await readBody(response)
                response_text = raw_body.decode()
        except Exception as exception:
            raise TemporaryNotificationDispatchException(
                "GCM request failure") from exception
        finally:
            self.connection_semaphore.release()
        return response, response_text

    async def _request_dispatch(
        self,
        n: Notification,
        log: NotificationLoggerAdapter,
        body: dict,
        headers: Dict[AnyStr, List[AnyStr]],
        pushkeys: List[str],
        span: Span,
    ) -> Tuple[List[str], List[str]]:
        """
        Send one request to GCM and interpret the response.

        Args:
            n: The notification being dispatched (used for log output only).
            log: Logger to write to.
            body: Request body to send.
            headers: HTTP headers to send.
            pushkeys: The pushkeys addressed by ``body``, in the same order
                as the per-device results in the response.
            span: Tracing span for this attempt.

        Returns:
            A tuple ``(failed, retry)``: pushkeys that failed permanently and
            should be rejected, and pushkeys that failed temporarily and
            should be tried again.

        Raises:
            TemporaryNotificationDispatchException: on a 5xx response (or a
                failure raised by ``_perform_http_request``).
            NotificationDispatchException: on 400, 401, an unknown status
                code, or a 2xx response that cannot be interpreted.
        """
        poke_start_time = time.time()

        failed = []

        response, response_text = await self._perform_http_request(
            body, headers)

        RESPONSE_STATUS_CODES_COUNTER.labels(pushkin=self.name,
                                             code=response.code).inc()

        log.debug("GCM request took %f seconds", time.time() - poke_start_time)

        span.set_tag(tags.HTTP_STATUS_CODE, response.code)

        if 500 <= response.code < 600:
            log.debug("%d from server, waiting to try again", response.code)

            retry_after = None

            for header_value in response.headers.getRawHeaders(b"retry-after",
                                                               default=[]):
                try:
                    retry_after = int(header_value)
                except ValueError:
                    # Retry-After may also be an HTTP-date. We only
                    # understand the delay-in-seconds form; for anything else
                    # the caller falls back to its own backoff.
                    log.debug("Ignoring unparseable Retry-After header: %r",
                              header_value)
                    continue
                span.log_kv({
                    "event": "gcm_retry_after",
                    "retry_after": retry_after
                })

            raise TemporaryNotificationDispatchException(
                "GCM server error, hopefully temporary.",
                custom_retry_delay=retry_after)
        elif response.code == 400:
            log.error(
                "%d from server, we have sent something invalid! Error: %r",
                response.code,
                response_text,
            )
            # permanent failure: give up
            raise NotificationDispatchException("Invalid request")
        elif response.code == 401:
            log.error("401 from server! Our API key is invalid? Error: %r",
                      response_text)
            # permanent failure: give up
            raise NotificationDispatchException("Not authorised to push")
        elif response.code == 404:
            # assume they're all failed
            log.info("Reg IDs %r get 404 response; assuming unregistered",
                     pushkeys)
            return pushkeys, []
        elif 200 <= response.code < 300:
            try:
                resp_object = json_decoder.decode(response_text)
            except ValueError:
                raise NotificationDispatchException(
                    "Invalid JSON response from GCM.")
            if not isinstance(resp_object,
                              dict) or "results" not in resp_object:
                log.error(
                    "%d from server but response contained no 'results' key: %r",
                    response.code,
                    response_text,
                )
                # Without per-device results there is nothing to interpret;
                # fail explicitly instead of hitting a KeyError below.
                raise NotificationDispatchException(
                    "Invalid response from GCM: no 'results' key.")

            results = resp_object["results"]
            if len(results) < len(pushkeys):
                log.error(
                    "Sent %d notifications but only got %d responses!",
                    len(n.devices),
                    len(results),
                )
                span.log_kv({
                    logs.EVENT: "gcm_response_mismatch",
                    "num_devices": len(n.devices),
                    "num_results": len(results),
                })

            # determine which pushkeys to retry or forget about
            new_pushkeys = []
            # zip() stops at the shorter sequence, so surplus results cannot
            # index past the end of pushkeys.
            for pushkey, result in zip(pushkeys, results):
                if "error" in result:
                    log.warning("Error for pushkey %s: %s", pushkey,
                                result["error"])
                    span.set_tag("gcm_error", result["error"])
                    if result["error"] in BAD_PUSHKEY_FAILURE_CODES:
                        log.info(
                            "Reg ID %r has permanently failed with code %r: "
                            "rejecting upstream",
                            pushkey,
                            result["error"],
                        )
                        failed.append(pushkey)
                    elif result["error"] in BAD_MESSAGE_FAILURE_CODES:
                        log.info(
                            "Message for reg ID %r has permanently failed with code %r",
                            pushkey,
                            result["error"],
                        )
                    else:
                        log.info(
                            "Reg ID %r has temporarily failed with code %r",
                            pushkey,
                            result["error"],
                        )
                        new_pushkeys.append(pushkey)
            return failed, new_pushkeys
        else:
            raise NotificationDispatchException(
                f"Unknown GCM response code {response.code}")

    async def _dispatch_notification_unlimited(
            self, n: Notification, device: Device,
            context: NotificationContext) -> List[str]:
        """
        Send notification ``n`` to every device in it that this pushkin
        handles, retrying temporary failures up to ``MAX_TRIES`` attempts.

        Args:
            n: The notification to send.
            device: The device this call was made for.
            context: Request context (request ID and tracing span).

        Returns:
            The list of pushkeys that should be rejected.
        """
        log = NotificationLoggerAdapter(logger,
                                        {"request_id": context.request_id})

        # `_dispatch_notification_unlimited` gets called once for each device in the
        # `Notification` with a matching app ID. We do something a little dirty and
        # perform all of our dispatches the first time we get called for a
        # `Notification` and do nothing for the rest of the times we get called.
        pushkeys = [
            device.pushkey for device in n.devices
            if self.handles_appid(device.app_id)
        ]
        # `pushkeys` ought to never be empty here. At the very least it should contain
        # `device`'s pushkey. Guard anyway so an unexpected empty list cannot
        # raise an IndexError.

        if not pushkeys or pushkeys[0] != device.pushkey:
            # We've already been asked to dispatch for this `Notification` and have
            # previously sent out the notification to all devices.
            return []

        # The pushkey is kind of secret because you can use it to send push
        # to someone.
        # span_tags = {"pushkeys": pushkeys}
        span_tags = {"gcm_num_devices": len(pushkeys)}

        with self.sygnal.tracer.start_span(
                "gcm_dispatch", tags=span_tags,
                child_of=context.opentracing_span) as span_parent:
            # TODO: Implement collapse_key to queue only one message per room.
            failed: List[str] = []

            data = GcmPushkin._build_data(n, device)

            # Reject pushkey if default_payload is misconfigured. There is
            # no payload to send in that case, so stop here rather than
            # posting a request whose "data" is null.
            if data is None:
                failed.append(device.pushkey)
                span_parent.set_tag("gcm_num_failed", len(failed))
                return failed

            headers = {
                "User-Agent": ["sygnal"],
                "Content-Type": ["application/json"],
                "Authorization": ["key=%s" % (self.api_key, )],
            }

            body = self.base_request_body.copy()
            body["data"] = data
            body["priority"] = "normal" if n.prio == "low" else "high"

            for retry_number in range(0, MAX_TRIES):
                if len(pushkeys) == 1:
                    body["to"] = pushkeys[0]
                    # A previous attempt may have addressed several devices;
                    # drop that list so the request does not carry both
                    # "to" and a stale "registration_ids".
                    body.pop("registration_ids", None)
                else:
                    body["registration_ids"] = pushkeys

                log.info("Sending (attempt %i) => %r", retry_number, pushkeys)

                try:
                    span_tags = {"retry_num": retry_number}

                    with self.sygnal.tracer.start_span(
                            "gcm_dispatch_try",
                            tags=span_tags,
                            child_of=span_parent) as span:
                        new_failed, new_pushkeys = await self._request_dispatch(
                            n, log, body, headers, pushkeys, span)
                    pushkeys = new_pushkeys
                    failed += new_failed

                    if len(pushkeys) == 0:
                        break
                except TemporaryNotificationDispatchException as exc:
                    # Exponential backoff unless the server told us how long
                    # to wait.
                    retry_delay = RETRY_DELAY_BASE * (2**retry_number)
                    if exc.custom_retry_delay is not None:
                        retry_delay = exc.custom_retry_delay

                    log.warning(
                        "Temporary failure, will retry in %d seconds",
                        retry_delay,
                        exc_info=True,
                    )

                    span_parent.log_kv({
                        "event": "temporary_fail",
                        "retrying_in": retry_delay
                    })

                    await twisted_sleep(retry_delay,
                                        twisted_reactor=self.sygnal.reactor)

            if len(pushkeys) > 0:
                log.info("Gave up retrying reg IDs: %r", pushkeys)
            # Count the number of failed devices.
            span_parent.set_tag("gcm_num_failed", len(failed))
            return failed

    @staticmethod
    def _build_data(n: Notification,
                    device: Device) -> Optional[Dict[str, Any]]:
        """
        Assemble the ``data`` payload for a notification.

        The payload starts from the device's ``default_payload`` (if any),
        then gets the notification's fields, the priority and, when present,
        the unread / missed call counts.

        Args:
            n: Notification to build the payload for.
            device: Device information to which the constructed payload
            will be sent.

        Returns:
            JSON-compatible dict or None if the default_payload is misconfigured
        """
        payload: Dict[str, Any] = {}

        if device.data:
            default_payload = device.data.get("default_payload", {})
            if not isinstance(default_payload, dict):
                logger.error(
                    "default_payload was misconfigured, this value must be a dict."
                )
                return None
            payload.update(default_payload)

        copied_fields = (
            "event_id",
            "type",
            "sender",
            "room_name",
            "room_alias",
            "membership",
            "sender_display_name",
            "content",
            "room_id",
        )
        for field in copied_fields:
            if not hasattr(n, field):
                continue
            payload[field] = getattr(n, field)
            # Cut over-long fields down to a sensible maximum length: GCM
            # rejects the request if the whole body is too long.
            value = payload[field]
            if value is not None and len(value) > MAX_BYTES_PER_FIELD:
                payload[field] = value[0:MAX_BYTES_PER_FIELD]

        payload["prio"] = "normal" if n.prio == "low" else "high"

        if getattr(n, "counts", None):
            payload["unread"] = n.counts.unread
            payload["missed_calls"] = n.counts.missed_calls

        return payload
Beispiel #5
0
class WebpushPushkin(ConcurrencyLimitedPushkin):
    """
    Pushkin that relays notifications to Web Push endpoints, signing each
    request with the configured VAPID private key.
    """

    # Configuration keys this pushkin recognises, merged with the keys of the
    # base class. __init__ logs a warning for any configured key that is not
    # in this set.
    UNDERSTOOD_CONFIG_FIELDS = {
        "type",
        "max_connections",
        "vapid_private_key",
        "vapid_contact_email",
        "allowed_endpoints",
        "ttl",
    } | ConcurrencyLimitedPushkin.UNDERSTOOD_CONFIG_FIELDS

    def __init__(self, name, sygnal, config):
        """
        Set up the HTTP machinery and read this pushkin's configuration.

        Args:
            name: Name of this pushkin.
            sygnal: The Sygnal instance; its reactor and global ``proxy``
                setting are used for outgoing requests.
            config: Configuration dictionary for this pushkin.

        Raises:
            PushkinSetupException: if ``allowed_endpoints`` is set but not a
                list, if the VAPID key or contact email is missing or
                invalid, or if ``ttl`` is not an int.
        """
        super().__init__(name, sygnal, config)

        unknown_fields = self.cfg.keys() - self.UNDERSTOOD_CONFIG_FIELDS
        if unknown_fields:
            logger.warning(
                "The following configuration fields are not understood: %s",
                unknown_fields,
            )

        self.http_pool = HTTPConnectionPool(reactor=sygnal.reactor)
        self.max_connections = self.get_config(
            "max_connections", DEFAULT_MAX_CONNECTIONS
        )
        self.connection_semaphore = DeferredSemaphore(self.max_connections)
        self.http_pool.maxPersistentPerHost = self.max_connections

        self.http_agent = ProxyAgent(
            reactor=sygnal.reactor,
            pool=self.http_pool,
            contextFactory=ClientTLSOptionsFactory(),
            # use the Sygnal global proxy configuration
            proxy_url_str=sygnal.config.get("proxy"),
        )
        self.http_agent_wrapper = HttpAgentWrapper(self.http_agent)

        # None means "no restriction"; otherwise a list of compiled patterns
        # built from the configured globs.
        self.allowed_endpoints = None  # type: Optional[List[Pattern]]
        endpoint_globs = self.get_config("allowed_endpoints")
        if endpoint_globs:
            if not isinstance(endpoint_globs, list):
                raise PushkinSetupException(
                    "'allowed_endpoints' should be a list or not set")
            self.allowed_endpoints = [
                glob_to_regex(glob) for glob in endpoint_globs
            ]

        privkey_filename = self.get_config("vapid_private_key")
        if not privkey_filename:
            raise PushkinSetupException(
                "'vapid_private_key' not set in config")
        if not os.path.exists(privkey_filename):
            raise PushkinSetupException(
                "path in 'vapid_private_key' does not exist")
        try:
            self.vapid_private_key = Vapid.from_file(
                private_key_file=privkey_filename)
        except VapidException as e:
            raise PushkinSetupException(
                "invalid 'vapid_private_key' file") from e

        self.vapid_contact_email = self.get_config("vapid_contact_email")
        if not self.vapid_contact_email:
            raise PushkinSetupException(
                "'vapid_contact_email' not set in config")

        self.ttl = self.get_config("ttl", DEFAULT_TTL)
        if not isinstance(self.ttl, int):
            raise PushkinSetupException("'ttl' must be an int if set")

    async def _dispatch_notification_unlimited(self, n, device, context):
        """
        Send notification ``n`` to the Web Push subscription described by
        ``device``.

        The subscription is made up of ``device.pushkey`` (used as the
        p256dh key) plus the ``endpoint`` and ``auth`` entries of
        ``device.data``.

        Args:
            n: The notification to send.
            device: The device to send it to.
            context: Request context (not used here).

        Returns:
            ``[device.pushkey]`` if the pushkey should be rejected
            (malformed or incomplete subscription info, or a 4xx response
            from the gateway), otherwise an empty list.
        """
        p256dh = device.pushkey
        if not isinstance(device.data, dict):
            logger.warning("Rejecting pushkey %s; device.data is not a dict",
                           device.pushkey)
            return [device.pushkey]

        endpoint = device.data.get("endpoint")
        auth = device.data.get("auth")

        # Validate the subscription before parsing the endpoint: a missing
        # endpoint must lead to a rejection, not to an error while matching
        # its (non-existent) domain against allowed_endpoints.
        if not p256dh or not endpoint or not auth:
            logger.warning(
                "Rejecting pushkey; subscription info incomplete "
                "(p256dh: %s, endpoint: %s, auth: %s)",
                p256dh,
                endpoint,
                auth,
            )
            return [device.pushkey]

        endpoint_domain = urlparse(endpoint).netloc
        if self.allowed_endpoints:
            allowed = any(
                regex.fullmatch(endpoint_domain)
                for regex in self.allowed_endpoints)
            if not allowed:
                logger.error(
                    "push gateway %s is not in allowed_endpoints, blocking request",
                    endpoint_domain,
                )
                # abort, but don't reject push key
                return []

        subscription_info = {
            "endpoint": endpoint,
            "keys": {
                "p256dh": p256dh,
                "auth": auth
            },
        }
        payload = WebpushPushkin._build_payload(n, device)
        data = json.dumps(payload)

        # note that webpush modifies vapid_claims, so make sure it's only used once
        vapid_claims = {
            "sub": "mailto:{}".format(self.vapid_contact_email),
        }
        # we use the semaphore to actually limit the number of concurrent
        # requests, since the HTTPConnectionPool will actually just lead to more
        # requests being created but not pooled – it does not perform limiting.
        with QUEUE_TIME_HISTOGRAM.time():
            with PENDING_REQUESTS_GAUGE.track_inprogress():
                await self.connection_semaphore.acquire()

        try:
            with SEND_TIME_HISTOGRAM.time():
                with ACTIVE_REQUESTS_GAUGE.track_inprogress():
                    response_wrapper = webpush(
                        subscription_info=subscription_info,
                        data=data,
                        ttl=self.ttl,
                        vapid_private_key=self.vapid_private_key,
                        vapid_claims=vapid_claims,
                        requests_session=self.http_agent_wrapper,
                    )
                    response = await response_wrapper.deferred
                    response_text = (await readBody(response)).decode()
        finally:
            self.connection_semaphore.release()

        # assume 4xx is permanent and 5xx is temporary
        if 400 <= response.code < 500:
            logger.warning(
                "Rejecting pushkey %s; gateway %s failed with %d: %s",
                device.pushkey,
                endpoint_domain,
                response.code,
                response_text,
            )
            return [device.pushkey]
        return []

    @staticmethod
    def _build_payload(n, device):
        """
        Assemble the payload for a notification.

        The payload starts from the device's ``default_payload`` (when that
        is a dict), then gets every truthy notification field and every
        count that is not None.

        Args:
            n: Notification to build the payload for.
            device (Device): Device information to which the constructed payload
            will be sent.

        Returns:
            JSON-compatible dict
        """
        default_payload = device.data.get("default_payload")
        if isinstance(default_payload, dict):
            payload = dict(default_payload)
        else:
            payload = {}

        notification_fields = (
            "room_id",
            "room_name",
            "room_alias",
            "membership",
            "event_id",
            "sender",
            "sender_display_name",
            "user_is_target",
            "type",
            "content",
        )
        for field in notification_fields:
            field_value = getattr(n, field, None)
            # Falsy values (None, empty strings, ...) are left out.
            if not field_value:
                continue
            payload[field] = field_value

        counts = getattr(n, "counts", None)
        if counts is None:
            return payload

        for count_name in ("unread", "missed_calls"):
            count_value = getattr(counts, count_name, None)
            # Unlike the fields above, a count of 0 is kept.
            if count_value is not None:
                payload[count_name] = count_value

        return payload
Beispiel #6
0
class IndxConnectionPool:
    """ A wrapper for txpostgres connection pools, which auto-reconnects. """
    def __init__(self):
        """ Start with no pools, no known connection strings and no subscribers. """
        logging.debug("IndxConnectionPool starting. ")
        # Lock (a single token) taken by _connect while it changes pool state.
        self.semaphore = DeferredSemaphore(1)
        # connection string -> DBConnectionPool
        self.connections = {}
        # database name -> list of connection strings
        self.conn_strs = {}
        # keyed by database name
        self.subscribers = {}

    def removeAll(self, db_name):
        """ Remove all connections for a named database - used before deleting that database.

            Returns a DeferredList over the results of closing every in-use
            and free connection of that database.
        """
        logging.debug("IndxConnectionPool removeAll {0}".format(db_name))
        d_list = []
        for conn_str in self.conn_strs.pop(db_name, []):
            # A connection string can be listed without having a pool: pools
            # are only created on the first query, and the same string may
            # have been registered more than once. Skip those instead of
            # raising KeyError.
            pool = self.connections.pop(conn_str, None)
            if pool is None:
                continue
            for conn in pool.getInuse():
                d_list.append(conn.close())
            for conn in pool.getFree():
                d_list.append(conn.close())

        return DeferredList(d_list)

    def connect(self, db_name, db_user, db_pass, db_host, db_port):
        """ Returns an IndxConnection (Actual connection and pool made when query is made).

            The result is a Deferred that has already been fired with the
            IndxConnection.
        """

        return_d = Deferred()
        # Fill in the user and password from the arguments; the logged
        # variant gets "XXXX" in place of the real password.
        conn_str_template = "dbname='{0}' user='{1}' password='{2}' host='{3}' port='{4}' application_name='{5}'"
        log_conn_str = conn_str_template.format(
            db_name, db_user, "XXXX", db_host, db_port,
            indx_pg2.APPLICATION_NAME)
        conn_str = conn_str_template.format(
            db_name, db_user, db_pass, db_host, db_port,
            indx_pg2.APPLICATION_NAME)
        logging.debug("IndxConnectionPool connect: {0}".format(log_conn_str))

        # Register each connection string only once per database, so the
        # list does not grow with every connect() call.
        known_conn_strs = self.conn_strs.setdefault(db_name, [])
        if conn_str not in known_conn_strs:
            known_conn_strs.append(conn_str)

        def free_cb(conn):
            """ Called back when this IndxConnection has finished querying, so
                we put the real connection back into the pool. """
            logging.debug("IndxConnectionPool free_cb, conn: {0}".format(conn))

            self.connections[conn_str].freeConnection(
                conn)  # no dealing with callbacks, just carry on

        def alloc_cb(conn_str):
            # a query was called - allocate a connection now and pass it back
            return self._connect(conn_str)

        indx_connection = IndxConnection(conn_str, alloc_cb, free_cb)
        return_d.callback(indx_connection)
        return return_d

    def _connect(self, conn_str):
        """ Connect and return a free Connection.
            Figures out whether to make new connections, use the pool, or wait in a queue.

            Returns a Deferred that fires with the connection once one has
            been moved to the pool's 'inuse' list, or errbacks if acquiring
            the lock or opening a new connection fails.
        """
        logging.debug("IndxConnectionPool _connect ({0})".format(conn_str))
        return_d = Deferred()

        def err_cb(failure):
            # Shared error path: log, release the lock and pass the failure
            # on to the caller.
            logging.error(
                "IndxConnectionPool _connect err_cb: {0}".format(failure))
            self.semaphore.release()
            return_d.errback(failure)

        def succeed_cb(empty):
            # Runs with the semaphore held; every branch below releases it
            # itself (the "new connection" branch does so in connected_cb or
            # err_cb).
            logging.debug("IndxConnectionPool _connect succeed_cb")
            # TODO pass a Connection back

            if len(self.connections[conn_str].getFree()) > 0:
                # free connection, use it straight away
                conn = self.connections[conn_str].getFree().pop()

                self.connections[conn_str].getInuse().append(conn)
                self.semaphore.release()
                return_d.callback(conn)
                return

            if len(self.connections[conn_str].getInuse()) < MAX_CONNS:
                # not at max connections for this conn_str

                # create a new one
                d = self._newConnection(conn_str)

                def connected_cb(indx_conn):
                    logging.debug(
                        "IndxConnectionPool _connect connected_cb ({0})".
                        format(indx_conn))
                    # _newConnection put the connection in 'free'; move it
                    # to 'inuse' before handing it out.
                    self.connections[conn_str].getFree().remove(indx_conn)
                    self.connections[conn_str].getInuse().append(indx_conn)
                    self.semaphore.release()
                    return_d.callback(indx_conn)
                    return

                d.addCallbacks(connected_cb, err_cb)
                return

            # wait for a connection
            def wait_cb(conn):
                # Called by DBConnectionPool.freeConnection when a
                # connection is handed over to this waiter.
                logging.debug(
                    "IndxConnectionPool _connect wait_cb ({0})".format(conn))
                # already put in 'inuse'
                return_d.callback(conn)
                return

            self.semaphore.release()
            self.connections[conn_str].getWaiting().append(wait_cb)
            return

        def locked_cb(empty):
            # The semaphore has been acquired at this point.
            logging.debug("IndxConnectionPool _connect locked_cb")
            if conn_str not in self.connections:
                # First use of this connection string: build its pool first.
                self._newConnections(conn_str).addCallbacks(succeed_cb, err_cb)
            else:
                # NOTE(review): this runs succeed_cb via deferToThread, so
                # the pool lists are modified and return_d is fired outside
                # the calling context - confirm this is intended (the direct
                # call below was commented out).
                threads.deferToThread(succeed_cb, None)


#                succeed_cb(None)

        self.semaphore.acquire().addCallbacks(locked_cb, err_cb)
        return return_d

    def _closeOldConnection(self):
        """ Close the oldest connection, so we can open a new one up.

            Pools are visited from least to most recently used, calling
            removeAll on each, until the running count reaches
            REMOVE_AT_ONCE or no pools are left. Returns a Deferred that
            fires with None when done.
        """
        # is already in a semaphore lock, from _newConnection
        logging.debug("IndxConnectionPool _closeOldConnection")

        ### we could force quit them through postgresql like this - but instead we kill them from inside
        #query = "SELECT * FROM pg_stat_activity WHERE state = 'idle' AND application_name = %s AND query != 'LISTEN wb_new_version' ORDER BY state_change LIMIT 1;"
        #params = [indx_pg2.APPLICATION_NAME]

        return_d = Deferred()

        def err_cb(failure):
            return_d.errback(failure)

        # Oldest pool first. sorted() is stable, so pools with the same
        # timestamp keep their dictionary order, and (unlike calling .sort()
        # on dict.keys()) it works on both Python 2 and Python 3.
        pool_queue = sorted(self.connections.values(),
                            key=lambda dbpool: dbpool.getTime())

        def removed_cb(count):
            # `count` is the running total passed from one removeAll call to
            # the next; presumably the number of connections removed so far
            # (removeAll's result is not visible here - TODO confirm).
            if count < REMOVE_AT_ONCE and len(pool_queue) > 0:
                pool = pool_queue.pop(0)
                # Kept from the original: the result is unused, but the call
                # refreshes the pool's last-used time.
                pool.getFree()
                pool.removeAll(count).addCallbacks(removed_cb, err_cb)
            else:
                return_d.callback(None)

        removed_cb(0)
        return return_d

    def _newConnection(self, conn_str):
        """ Makes a new connection to the DB
            and then puts it in the 'free' pool of this conn_str.

            Returns a Deferred that fires with the new connection. If
            connecting fails, old connections are closed with
            _closeOldConnection and the connection attempt is repeated.
        """
        logging.debug("IndxConnectionPool _newConnection")
        # lock with the semaphore before calling this
        return_d = Deferred()

        def close_old_cb(failure):
            # Exception is listed too, so this traps every ordinary failure,
            # not only psycopg2.OperationalError.
            failure.trap(psycopg2.OperationalError, Exception)
            # couldn't connect, so close an old connection first
            logging.error(
                "IndxConnectionPool error close_old_cb: {0} - state of conns is: {1}"
                .format(failure.value, self.connections))

            logging.error("IndxConnectionPool connections: {0}".format(
                "\n".join(
                    map(lambda name: self.connections[name].__str__(),
                        self.connections))))

            def closed_cb(empty):
                # closed, so try connecting again
                # NOTE(review): nothing here limits how many times the
                # connection attempt is repeated - confirm that is acceptable.
                self._newConnection(conn_str).addCallbacks(
                    return_d.callback, return_d.errback)

            closed_d = self._closeOldConnection()
            closed_d.addCallbacks(closed_cb, return_d.errback)

        try:
            # try to connect
            def connected_cb(connection):
                logging.debug(
                    "IndxConnectionPool _newConnection connected_cb, connection: {0}"
                    .format(connection))
                # New connections start out in the 'free' list; callers move
                # them to 'inuse' when they hand them out.
                self.connections[conn_str].getFree().append(connection)
                return_d.callback(connection)

            conn = txpostgres.Connection()
            connection_d = conn.connect(conn_str)
            connection_d.addCallbacks(connected_cb, close_old_cb)
        except Exception as e:
            # close an old connection first
            logging.debug(
                "IndxConnectionPool Exception, going to call close_old_cb: ({0})"
                .format(e))
            close_old_cb(Failure(e))

        return return_d

    def _newConnections(self, conn_str):
        """ Create the pool for conn_str and open MIN_CONNS connections in it.

            Returns a Deferred chained to a DeferredList over the individual
            connection attempts.
        """
        # the caller must hold the semaphore while this runs
        logging.debug("IndxConnectionPool _newConnections")
        result_d = Deferred()

        self.connections[conn_str] = DBConnectionPool(conn_str)

        try:
            pending = [
                self._newConnection(conn_str) for _ in range(MIN_CONNS)
            ]
            all_connected = DeferredList(pending)
            all_connected.addCallbacks(result_d.callback, result_d.errback)
        except Exception as e:
            logging.error(
                "IndxConnectionPool error in _newConnections: {0}".format(e))
            result_d.errback(Failure(e))

        return result_d
Beispiel #7
0
class DBConnectionPool():
    """ A pool of DB connections for a specific connection string / DB.

        Connections are tracked in three lists: `free` (idle), `inuse`
        (handed out) and `waiting` (callbacks of callers waiting for a
        connection). freeConnection() and removeAll() do their list changes
        while holding a one-token DeferredSemaphore.
    """
    def __init__(self, conn_str):
        """ Create an empty pool. `conn_str` is accepted but not stored. """
        self.waiting = []
        self.inuse = []
        self.free = []

        self.semaphore = DeferredSemaphore(1)
        self.updateTime()

    def __unicode__(self):
        """ Same text as __str__. """
        return self.__str__()

    def __str__(self):
        """ One-line dump of the three lists, the semaphore and lastused. """
        return "waiting: {0}, inuse: {1}, free: {2}, semaphore: {3}, lastused: {4}".format(
            self.waiting, self.inuse, self.free, self.semaphore, self.lastused)

    def updateTime(self):
        """ Record "now" (seconds since the epoch) as the last-used time. """
        # time.time() already is epoch time. The previous
        # time.mktime(time.gmtime()) fed a UTC struct_time to a function that
        # expects local time, skewing the stamp by the local UTC offset (and
        # making it jump across DST changes).
        self.lastused = time.time()

    def getTime(self):
        """ Return the last-used timestamp without refreshing it. """
        return self.lastused

    def getWaiting(self):
        """ Return the list of waiting callbacks (refreshes lastused). """
        self.updateTime()
        return self.waiting

    def getInuse(self):
        """ Return the list of handed-out connections (refreshes lastused). """
        self.updateTime()
        return self.inuse

    def getFree(self):
        """ Return the list of idle connections (refreshes lastused). """
        self.updateTime()
        return self.free

    def freeConnection(self, conn):
        """ Free a connection from this DBPool.

            Once the semaphore is held, `conn` is removed from `inuse`. If a
            callback is waiting, `conn` is put straight back into `inuse` and
            passed to that callback; otherwise it goes to `free`. Nothing is
            returned; the work happens when the semaphore is acquired.
        """
        def locked_cb(empty):
            logging.debug("DBConnectionPool locked_cb")
            try:
                self.getInuse().remove(conn)
            except ValueError:
                # conn is not checked out from this pool (e.g. freed twice).
                # Previously this exception escaped with the semaphore still
                # held, blocking every later freeConnection()/removeAll().
                logging.error(
                    "DBConnectionPool free, connection not in use: {0}".format(
                        conn))
                self.semaphore.release()
                return

            if len(self.getWaiting()) > 0:
                # NOTE(review): pop() serves the most recent waiter first
                # (LIFO); confirm whether FIFO was intended.
                callback = self.getWaiting().pop()
                self.getInuse().append(conn)
                # release before calling out, the callback may re-enter
                self.semaphore.release()
                callback(conn)
            else:
                self.getFree().append(conn)
                self.semaphore.release()

        def err_cb(failure):
            failure.trap(Exception)
            logging.error("DBConnectionPool free, err_cb: {0}".format(
                failure.value))
            # NOTE(review): reached only if acquire() itself failed, in which
            # case the token was presumably never taken - confirm that
            # releasing here is intended.
            self.semaphore.release()

        self.semaphore.acquire().addCallbacks(locked_cb, err_cb)

    def removeAll(self, count):
        """ Remove all free connections (usually because they're old and we're
            in a freeing up period).

            `count` is a running total of connections closed so far; the
            returned Deferred fires with `count` plus the number closed here,
            or errbacks if acquiring the semaphore or closing fails.
        """
        logging.debug(
            "DBConnectionPool removeAll called, count: {0}".format(count))
        return_d = Deferred()
        self.updateTime()

        def err_cb(failure):
            self.semaphore.release()
            return_d.errback(failure)

        def locked_cb(count):
            # immediately close the free connections
            try:
                while len(self.free) > 0:
                    conn = self.free.pop(0)
                    conn.close()
                    count += 1
            except Exception as e:
                # A failing close() used to leave the semaphore locked and
                # return_d unfired; report the error to the caller instead.
                self.semaphore.release()
                return_d.errback(e)
                return

            self.semaphore.release()
            return_d.callback(count)

        self.semaphore.acquire().addCallbacks(lambda s: locked_cb(count),
                                              err_cb)
        return return_d
Beispiel #8
0
class RateLimitedClient(object):
    """A Web client with per-second request limit.

    At most ``max_concurrency`` requests are in flight (DeferredSemaphore),
    and each request is delayed so that it starts no sooner than
    ``1 / rate_limit`` seconds after the previous one.  When a request fails
    with one of the statuses accepted by trap_throttling(), a grace period is
    entered and the request is retried automatically.
    """

    # Max number of requests per second (can be < 1.0)
    rate_limit = None
    # Grace delay (seconds) when the server throttles us
    grace_delay = 30
    # Max number of parallel requests
    max_concurrency = 5

    def __init__(self, time=None):
        """
        :param time: object providing ``seconds()`` and ``callLater()``;
            defaults to the global reactor.
        """
        self.sem = DeferredSemaphore(self.max_concurrency)
        self.grace_deferred = None
        self.logger = logging.getLogger("webclient")
        self.time = time or reactor
        self.last_request = 0.0

    def _enable_grace_delay(self, delay):
        """Start a grace period of *delay* seconds unless one is running.

        While the period is active ``self.grace_deferred`` is a Deferred
        that fires (with None) when the period expires.
        """
        if self.grace_deferred:
            # Already enabled by an earlier concurrent request
            return
        self.grace_deferred = Deferred()

        def expire():
            g = self.grace_deferred
            self.grace_deferred = None
            g.callback(None)

        # Schedule on the injected clock and honour the requested delay; the
        # previous code always used the global reactor and self.grace_delay,
        # ignoring both the `time` constructor argument and `delay`.
        self.time.callLater(delay, expire)

    def _delay_if_necessary(self, func, *args, **kwargs):
        """Call ``func(*args, **kwargs)`` once no delay applies.

        Waits for the current grace period if there is one, otherwise for
        the remainder of the rate-limit interval, otherwise not at all.
        Returns a Deferred firing with the result of *func*.
        """
        d = Deferred()
        d.addCallback(lambda _: func(*args, **kwargs))
        trigger = None
        if self.grace_deferred:
            trigger = self.grace_deferred
        elif self.rate_limit:
            delay = (self.last_request + 1.0 / self.rate_limit) - self.time.seconds()
            if delay > 0:
                self.logger.debug("inserting rate limit delay of %.1f", delay)
                trigger = Deferred()
                self.time.callLater(delay, trigger.callback, None)
        # No trigger means "fire immediately"
        (trigger or maybeDeferred(lambda: None)).chainDeferred(d)
        return d

    def get_page(self, url, *args, **kwargs):
        """Fetch *url* with getPage(), honouring the concurrency and rate
        limits and retrying automatically when the server throttles us.

        Extra arguments are passed through to getPage().  Returns a Deferred
        firing with the page body; HTTP errors that are not throttling
        statuses are propagated as failures.
        """
        if isinstance(url, unicode):
            url = url.encode("utf8")

        def schedule_request(_):
            return self._delay_if_necessary(issue_request, None)

        def issue_request(_):
            self.last_request = self.time.seconds()
            self.logger.debug("fetching %r", url)
            return getPage(url, *args, **kwargs)

        def handle_success(value):
            self.sem.release()
            self.logger.debug("got %d bytes for %r", len(value), url)
            return value

        def handle_error(failure):
            self.sem.release()
            failure.trap(HTTPError)
            self.logger.debug("got HTTP error %s", failure.value)
            # re-raises unless the status means "throttled"
            self.trap_throttling(failure)
            delay = self.grace_delay
            self.logger.warning("we are throttled, delaying by %.1f seconds", delay)
            self._enable_grace_delay(delay)
            # auto-retry
            return do_get_page()

        def do_get_page():
            # We acquire the semaphore *before* seeing if we should delay
            # the request, so that we avoid pounding on the server when
            # the grace period is entered.
            d = self.sem.acquire()
            d.addCallback(schedule_request)
            d.addCallbacks(handle_success, handle_error)
            return d

        return do_get_page()

    def trap_throttling(self, failure):
        """Trap HTTP failures and return if we are
        throttled by the distant site, else re-raise.

        The statuses "400", "420", "500" and "503" count as throttling.
        """
        e = failure.value
        if e.status in ("400", "420", "500", "503"):
            return
        failure.raiseException()
Beispiel #9
0
class GcmPushkin(Pushkin):
    """
    Pushkin that relays notifications to Google/Firebase Cloud Messaging.
    """

    # "max_connections" is read in __init__, so it has to be listed here;
    # otherwise every config that sets it triggers the "not understood"
    # warning below.
    UNDERSTOOD_CONFIG_FIELDS = {"type", "api_key", "max_connections"}

    def __init__(self, name, sygnal, config, canonical_reg_id_store):
        """
        Set up the HTTP agent, the connection limiter and the API key.

        Args:
            name: pushkin name, passed to the base class.
            sygnal: application object; its `reactor` and `database`
                attributes are used here.
            config: pushkin configuration, passed to the base class.
            canonical_reg_id_store: store used to look up and record
                canonical registration IDs.

        Raises:
            PushkinSetupException: if no `api_key` is configured.
        """
        super(GcmPushkin, self).__init__(name, sygnal, config)

        nonunderstood = set(self.cfg.keys()).difference(
            self.UNDERSTOOD_CONFIG_FIELDS)
        if len(nonunderstood) > 0:
            logger.warning(
                "The following configuration fields are not understood: %s",
                nonunderstood,
            )

        self.http_pool = HTTPConnectionPool(reactor=sygnal.reactor)
        self.max_connections = self.get_config("max_connections",
                                               DEFAULT_MAX_CONNECTIONS)
        self.connection_semaphore = DeferredSemaphore(self.max_connections)
        self.http_pool.maxPersistentPerHost = self.max_connections

        tls_client_options_factory = ClientTLSOptionsFactory()

        self.http_agent = Agent(
            reactor=sygnal.reactor,
            pool=self.http_pool,
            contextFactory=tls_client_options_factory,
        )

        self.db = sygnal.database
        self.canonical_reg_id_store = canonical_reg_id_store

        self.api_key = self.get_config("api_key")
        if not self.api_key:
            raise PushkinSetupException("No API key set in config")

    @classmethod
    async def create(cls, name, sygnal, config):
        """
        Asynchronously construct a GcmPushkin: set up the CanonicalRegIdStore
        against sygnal's database, then invoke the regular constructor.

        Returns:
            an instance of this Pushkin
        """
        logger.debug("About to set up CanonicalRegId Store")
        canonical_reg_id_store = CanonicalRegIdStore()
        await canonical_reg_id_store.setup(sygnal.database,
                                           sygnal.database_engine)
        logger.debug("Finished setting up CanonicalRegId Store")

        return cls(name, sygnal, config, canonical_reg_id_store)

    async def _perform_http_request(self, body, headers):
        """
        Perform an HTTP request to the FCM server with the body and headers
        specified.
        Args:
            body (nested dict): Body. Will be JSON-encoded.
            headers (Headers): HTTP Headers.

        Returns:
            a tuple (response, response_text): the response object and its
            decoded body.

        Raises:
            TemporaryNotificationDispatchException: if sending the request or
                reading the body raises.
        """
        body_producer = FileBodyProducer(BytesIO(json.dumps(body).encode()))

        # we use the semaphore to actually limit the number of concurrent
        # requests, since the HTTPConnectionPool will actually just lead to more
        # requests being created but not pooled – it does not perform limiting.
        with QUEUE_TIME_HISTOGRAM.time():
            with PENDING_REQUESTS_GAUGE.track_inprogress():
                await self.connection_semaphore.acquire()

        try:
            with SEND_TIME_HISTOGRAM.time():
                with ACTIVE_REQUESTS_GAUGE.track_inprogress():
                    response = await self.http_agent.request(
                        b"POST",
                        GCM_URL,
                        headers=Headers(headers),
                        bodyProducer=body_producer,
                    )
                    response_text = (await readBody(response)).decode()
        except Exception as exception:
            raise TemporaryNotificationDispatchException(
                "GCM request failure") from exception
        finally:
            self.connection_semaphore.release()
        return response, response_text

    async def _request_dispatch(self, n, log, body, headers, pushkeys, span):
        """
        Send one request and classify the outcome per pushkey.

        Args:
            n: the notification (only `n.devices` is read, for logging).
            log: logger to use.
            body: request body, see _perform_http_request.
            headers: request headers, see _perform_http_request.
            pushkeys: the registration IDs the body was addressed to, in the
                same order.
            span: tracing span to tag.

        Returns:
            a tuple (failed, new_pushkeys): pushkeys to reject upstream and
            pushkeys worth retrying.

        Raises:
            TemporaryNotificationDispatchException: on 5xx responses.
            NotificationDispatchException: on 400/401, unknown status codes,
                or a 2xx response whose body cannot be used.
        """
        poke_start_time = time.time()

        failed = []

        response, response_text = await self._perform_http_request(
            body, headers)

        RESPONSE_STATUS_CODES_COUNTER.labels(pushkin=self.name,
                                             code=response.code).inc()

        log.debug("GCM request took %f seconds", time.time() - poke_start_time)

        span.set_tag(tags.HTTP_STATUS_CODE, response.code)

        if 500 <= response.code < 600:
            log.debug("%d from server, waiting to try again", response.code)

            retry_after = None

            for header_value in response.headers.getRawHeaders(b"retry-after",
                                                               default=[]):
                try:
                    retry_after = int(header_value)
                except ValueError:
                    # Retry-After may also be an HTTP-date; ignore that form
                    # and let the caller fall back to its own backoff.
                    continue
                span.log_kv({
                    "event": "gcm_retry_after",
                    "retry_after": retry_after
                })

            raise TemporaryNotificationDispatchException(
                "GCM server error, hopefully temporary.",
                custom_retry_delay=retry_after)
        elif response.code == 400:
            log.error(
                "%d from server, we have sent something invalid! Error: %r",
                response.code,
                response_text,
            )
            # permanent failure: give up
            raise NotificationDispatchException("Invalid request")
        elif response.code == 401:
            log.error("401 from server! Our API key is invalid? Error: %r",
                      response_text)
            # permanent failure: give up
            raise NotificationDispatchException("Not authorised to push")
        elif response.code == 404:
            # assume they're all failed
            log.info("Reg IDs %r get 404 response; assuming unregistered",
                     pushkeys)
            return pushkeys, []
        elif 200 <= response.code < 300:
            try:
                resp_object = json.loads(response_text)
            except JSONDecodeError:
                raise NotificationDispatchException(
                    "Invalid JSON response from GCM.")
            if "results" not in resp_object:
                log.error(
                    "%d from server but response contained no 'results' key: %r",
                    response.code,
                    response_text,
                )
                # Without 'results' nothing below can work; previously this
                # fell through and died with a KeyError.
                raise NotificationDispatchException(
                    "Invalid response from GCM: no 'results' key.")
            if len(resp_object["results"]) < len(pushkeys):
                log.error(
                    "Sent %d notifications but only got %d responses!",
                    len(n.devices),
                    len(resp_object["results"]),
                )
                span.log_kv({
                    logs.EVENT: "gcm_response_mismatch",
                    "num_devices": len(n.devices),
                    "num_results": len(resp_object["results"]),
                })

            # determine which pushkeys to retry or forget about
            new_pushkeys = []
            for i, result in enumerate(resp_object["results"]):
                span.set_tag("gcm_regid_updated", "registration_id" in result)
                if "registration_id" in result:
                    await self.canonical_reg_id_store.set_canonical_id(
                        pushkeys[i], result["registration_id"])
                if "error" in result:
                    log.warning("Error for pushkey %s: %s", pushkeys[i],
                                result["error"])
                    span.set_tag("gcm_error", result["error"])
                    if result["error"] in BAD_PUSHKEY_FAILURE_CODES:
                        log.info(
                            "Reg ID %r has permanently failed with code %r: "
                            "rejecting upstream",
                            pushkeys[i],
                            result["error"],
                        )
                        failed.append(pushkeys[i])
                    elif result["error"] in BAD_MESSAGE_FAILURE_CODES:
                        log.info(
                            "Message for reg ID %r has permanently failed with code %r",
                            pushkeys[i],
                            result["error"],
                        )
                    else:
                        log.info(
                            "Reg ID %r has temporarily failed with code %r",
                            pushkeys[i],
                            result["error"],
                        )
                        new_pushkeys.append(pushkeys[i])
            return failed, new_pushkeys
        else:
            raise NotificationDispatchException(
                f"Unknown GCM response code {response.code}")

    async def dispatch_notification(self, n, device, context):
        """
        Send notification `n` to all of its devices handled by this pushkin.

        The request is only made when `device` is the first such device, so
        that one call covers all of them; other calls return [] immediately.
        Temporary failures are retried up to MAX_TRIES times.

        Returns:
            list of pushkeys that failed permanently.
        """
        log = NotificationLoggerAdapter(logger,
                                        {"request_id": context.request_id})

        pushkeys = [
            device.pushkey for device in n.devices
            if device.app_id == self.name
        ]
        # Resolve canonical IDs for all pushkeys

        if not pushkeys or pushkeys[0] != device.pushkey:
            # Only send notifications once, to all devices at once.
            # (Also covers the case where no device matches this pushkin,
            # which used to raise IndexError.)
            return []

        # The pushkey is kind of secret because you can use it to send push
        # to someone.
        # span_tags = {"pushkeys": pushkeys}
        span_tags = {"gcm_num_devices": len(pushkeys)}

        with self.sygnal.tracer.start_span(
                "gcm_dispatch", tags=span_tags,
                child_of=context.opentracing_span) as span_parent:
            reg_id_mappings = await self.canonical_reg_id_store.get_canonical_ids(
                pushkeys)

            reg_id_mappings = {
                reg_id: canonical_reg_id or reg_id
                for (reg_id, canonical_reg_id) in reg_id_mappings.items()
            }

            inverse_reg_id_mappings = {
                v: k
                for (k, v) in reg_id_mappings.items()
            }

            data = GcmPushkin._build_data(n)
            headers = {
                b"User-Agent": ["sygnal"],
                b"Content-Type": ["application/json"],
                b"Authorization": ["key=%s" % (self.api_key, )],
            }

            # count the number of remapped registration IDs in the request
            span_parent.set_tag(
                "gcm_num_remapped_reg_ids_used",
                [k != v for (k, v) in reg_id_mappings.items()].count(True),
            )

            # TODO: Implement collapse_key to queue only one message per room.
            failed = []

            body = {
                "data": data,
                "priority": "normal" if n.prio == "low" else "high"
            }

            for retry_number in range(0, MAX_TRIES):
                mapped_pushkeys = [reg_id_mappings[pk] for pk in pushkeys]

                # `body` is reused across tries, so drop the addressing key
                # left over from a previous try with a different number of
                # recipients.
                if len(pushkeys) == 1:
                    body["to"] = mapped_pushkeys[0]
                    body.pop("registration_ids", None)
                else:
                    body["registration_ids"] = mapped_pushkeys
                    body.pop("to", None)

                log.info("Sending (attempt %i) => %r", retry_number,
                         mapped_pushkeys)

                try:
                    span_tags = {"retry_num": retry_number}

                    with self.sygnal.tracer.start_span(
                            "gcm_dispatch_try",
                            tags=span_tags,
                            child_of=span_parent) as span:
                        new_failed, new_pushkeys = await self._request_dispatch(
                            n, log, body, headers, mapped_pushkeys, span)
                    # _request_dispatch works on canonical IDs; translate the
                    # ones to retry back to original pushkeys, because the
                    # top of the loop maps them through reg_id_mappings again
                    # (which is keyed by original pushkey).
                    pushkeys = [
                        inverse_reg_id_mappings[canonical_pk]
                        for canonical_pk in new_pushkeys
                    ]
                    failed += [
                        inverse_reg_id_mappings[canonical_pk]
                        for canonical_pk in new_failed
                    ]
                    if len(pushkeys) == 0:
                        break
                except TemporaryNotificationDispatchException as exc:
                    retry_delay = RETRY_DELAY_BASE * (2**retry_number)
                    if exc.custom_retry_delay is not None:
                        retry_delay = exc.custom_retry_delay

                    log.warning(
                        "Temporary failure, will retry in %d seconds",
                        retry_delay,
                        exc_info=True,
                    )

                    span_parent.log_kv({
                        "event": "temporary_fail",
                        "retrying_in": retry_delay
                    })

                    await twisted_sleep(retry_delay,
                                        twisted_reactor=self.sygnal.reactor)

            if len(pushkeys) > 0:
                log.info("Gave up retrying reg IDs: %r", pushkeys)
            # Count the number of failed devices.
            span_parent.set_tag("gcm_num_failed", len(failed))
            return failed

    @staticmethod
    def _build_data(n):
        """
        Build the payload data to be sent.
        Args:
            n: Notification to build the payload for.

        Returns:
            JSON-compatible dict
        """
        data = {}
        for attr in [
                "event_id",
                "type",
                "sender",
                "room_name",
                "room_alias",
                "membership",
                "sender_display_name",
                "content",
                "room_id",
        ]:
            if hasattr(n, attr):
                data[attr] = getattr(n, attr)
                # Truncate fields to a sensible maximum length. If the whole
                # body is too long, GCM will reject it.
                if data[attr] is not None and len(
                        data[attr]) > MAX_BYTES_PER_FIELD:
                    data[attr] = data[attr][0:MAX_BYTES_PER_FIELD]

        data["prio"] = "high"
        if n.prio == "low":
            data["prio"] = "normal"

        if getattr(n, "counts", None):
            data["unread"] = n.counts.unread
            data["missed_calls"] = n.counts.missed_calls

        return data
class AggregationResponseCache(object):
    '''
    This holds all the responses being aggregated for a single destination.
    
    One of the main challenges here is to make sure while we're sending the responses,
    we don't get a new response in and not send it.
    '''


    def __init__(self, numSecondsToWait, numMessagesToWaitFor, chordNode):
        '''
        Constructor

        numSecondsToWait     -- delay given to reactor.callLater before a
                                partial batch is flushed.
        numMessagesToWaitFor -- number of queued messages at which the batch
                                is sent immediately.
        chordNode            -- node object used to send the batch and the ACKs.
        '''
        self.numSecondsToWait = numSecondsToWait
        self.numMessagesToWaitFor = numMessagesToWaitFor
        self.chordNode = chordNode
        self.semaphore = DeferredSemaphore(1)
        self.messageList = [] # Holds tuples of (message, envelope)
        
        # Handle of the pending flush timer (None until one is scheduled)
        self.timerID = None
        
    def addResponse(self, message, envelope):
        '''We use a semaphore to ensure we don't modify the list while sending.'''
        d = self.semaphore.acquire()
        d.addCallback(self._addResponse, message, envelope)
        
    def _addResponse(self, dummy_defResult, message, envelope):
        '''This is called only once we have the semaphore.

        Queues the message; sends the batch if it is now full, otherwise
        makes sure a flush timer is running. The semaphore is released on
        both paths (by _sendResponse in the first case).
        '''         
        self.messageList.append ( (message, envelope) )
        
        print("DEBUG: AggRespCache: %s  adding message %s " % (self.chordNode.nodeLocation.port, message))
        
        if len(self.messageList) >= self.numMessagesToWaitFor:
            # Send it!
            self._sendResponse()
        else:
            # Make sure a timer is running
            if self.timerID is None or not self.timerID.active():
                self.timerID = reactor.callLater(self.numSecondsToWait, self.sendResponse)
            
            # We're done.
            self.semaphore.release()    
        
            
    def sendResponse(self):
        '''Acquire the semaphore, then send whatever is queued.

        Unlike _sendResponse, the caller must NOT already hold the semaphore.
        '''
        d = self.semaphore.acquire()
        d.addCallback(self._sendResponse)
        
    
    def _sendResponse(self, dummy_deferResult=None):
        '''Send the response but only after acquiring the semaphore

        Takes the queued messages, releases the semaphore, cancels a pending
        flush timer and sends the batch; ACKs are sent once the send finishes.
        '''
        # Take the current list and start a new one
        messagesListCopy = self.messageList
        self.messageList = []
        
        # Release the semaphore
        self.semaphore.release()
        
        # Stop the timer if it's still going
        if self.timerID is not None and self.timerID.active():
            self.timerID.cancel()
            self.timerID = None
        
        print("DEBUG: AggResponseCache-Sending %d Messages %s" % (len(messagesListCopy), self.chordNode.nodeLocation.port))
        
        # Send a P2P message to the dest with all the responses
        d = self.chordNode.sendSyncMultipleMessage(messagesListCopy, 'p2p') # Will this break message authentication?
        d.addCallback(self.sendAcks, messagesListCopy)
        d.addErrback(self.sendResponseFailed)

#     def emptyMessageList(self, _):
#         self.messageList = []
        
    def sendAcks(self, resultsDict, messageList):
        '''Send an ACK for every aggregated message.

        resultsDict maps msgID to the status to report; a msgID missing from
        it is acknowledged with status False.
        '''
        # Send ACK messages to the nodes for which we aggregated
        
        for (_message, envelope) in messageList:
            # Get the status to return
            msgID = envelope['msgID']
            if msgID not in resultsDict:
                status = False
            else:
                status = resultsDict[msgID]

            d = self.chordNode.sendSingleAck(msgID, envelope['source'], status)
            d.addErrback(self.sendAckFailed, envelope['source'])
                        
            
    def sendAckFailed(self, fail, sourceNode):
        '''Errback: log a failed ACK together with the node it was meant for.'''
        # log.err takes the failure first and the explanation second; the
        # arguments were previously swapped, so the failure itself was not
        # logged as an error.
        log.err(fail, "We failed to SendAck for source %s" % sourceNode)
            
        
    def sendResponseFailed(self, theFailure):
        '''Errback: log a failure of the batch send (or of sendAcks).'''
        log.err(theFailure)    
        
        
        

        
Beispiel #11
0
class IndxConnectionPool:
    """ A wrapper for txpostgres connection pools, which auto-reconnects.

        Keeps one DBConnectionPool per connection string and serialises pool
        bookkeeping with a one-token DeferredSemaphore.
    """

    def __init__(self):
        logging.debug("IndxConnectionPool starting. ")
        self.connections = {} # by connection string
        self.conn_strs = {} # by db_name
        self.semaphore = DeferredSemaphore(1)
        self.subscribers = {} # by db name

    def removeAll(self, db_name):
        """ Remove all connections for a named database - used before deleting that database.

            Returns a DeferredList over whatever Deferreds the close() calls
            returned.
        """
        logging.debug("IndxConnectionPool removeAll {0}".format(db_name))
        d_list = []
        if db_name in self.conn_strs:
            for conn_str in self.conn_strs[db_name]:
                # A conn_str may be registered without a pool (pools are only
                # created on first query) or be listed more than once, so a
                # plain lookup/del could raise KeyError.
                pool = self.connections.pop(conn_str, None)
                if pool is None:
                    continue
                for conns in (pool.getInuse(), pool.getFree()):
                    for conn in conns:
                        closed = conn.close()
                        # NOTE(review): close() does not appear to be
                        # guaranteed to return a Deferred (txpostgres
                        # documents it as returning None - confirm);
                        # DeferredList cannot take None.
                        if closed is not None:
                            d_list.append(closed)
            del self.conn_strs[db_name]

        dl = DeferredList(d_list)
        return dl

    def connect(self, db_name, db_user, db_pass, db_host, db_port):
        """ Returns an IndxConnection (Actual connection and pool made when query is made).

            The result is delivered through an already-fired Deferred.
        """

        return_d = Deferred()
        # The user/password placeholders had been replaced by literal
        # asterisks, so db_user/db_pass were passed to format() but never
        # made it into the connection string.
        conn_template = "dbname='{0}' user='{1}' password='{2}' host='{3}' port='{4}' application_name='{5}'"
        # same string with the password masked, safe to log
        log_conn_str = conn_template.format(db_name, db_user, "XXXX", db_host, db_port, indx_pg2.APPLICATION_NAME)
        conn_str = conn_template.format(db_name, db_user, db_pass, db_host, db_port, indx_pg2.APPLICATION_NAME)
        logging.debug("IndxConnectionPool connect: {0}".format(log_conn_str))

        if db_name not in self.conn_strs:
            self.conn_strs[db_name] = []
        # register each connection string once; connect() is called per
        # IndxConnection and would otherwise grow this list without bound
        if conn_str not in self.conn_strs[db_name]:
            self.conn_strs[db_name].append(conn_str)

        def free_cb(conn):
            """ Called back when this IndxConnection has finished querying, so
                we put the real connection back into the pool. """
            logging.debug("IndxConnectionPool free_cb, conn: {0}".format(conn))

            self.connections[conn_str].freeConnection(conn) # no dealing with callbacks, just carry on


        def alloc_cb(conn_str):
            # a query was called - allocate a connection now and pass it back
            return self._connect(conn_str)
 
        indx_connection = IndxConnection(conn_str, alloc_cb, free_cb)
        return_d.callback(indx_connection)
        return return_d


    def _connect(self, conn_str):
        """ Connect and return a free Connection.
            Figures out whether to make new connections, use the pool, or wait in a queue.

            Returns a Deferred that fires with the connection, which has
            already been placed in the pool's 'inuse' list.
        """
        logging.debug("IndxConnectionPool _connect ({0})".format(conn_str))
        return_d = Deferred()

        def err_cb(failure):
            logging.error("IndxConnectionPool _connect err_cb: {0}".format(failure))
            self.semaphore.release()
            return_d.errback(failure)

        def succeed_cb(empty):
            logging.debug("IndxConnectionPool _connect succeed_cb")
            # TODO pass a Connection back
            
            if len(self.connections[conn_str].getFree()) > 0:
                # free connection, use it straight away
                conn = self.connections[conn_str].getFree().pop()

                self.connections[conn_str].getInuse().append(conn)
                self.semaphore.release()
                return_d.callback(conn)
                return

            if len(self.connections[conn_str].getInuse()) < MAX_CONNS:
                # not at max connections for this conn_str
                
                # create a new one
                d = self._newConnection(conn_str)

                def connected_cb(indx_conn):
                    logging.debug("IndxConnectionPool _connect connected_cb ({0})".format(indx_conn))
                    # _newConnection put it in 'free'; move it to 'inuse'
                    self.connections[conn_str].getFree().remove(indx_conn)
                    self.connections[conn_str].getInuse().append(indx_conn)
                    self.semaphore.release()
                    return_d.callback(indx_conn)
                    return

                d.addCallbacks(connected_cb, err_cb)
                return

            # wait for a connection
            def wait_cb(conn):
                logging.debug("IndxConnectionPool _connect wait_cb ({0})".format(conn))
                # already put in 'inuse'
                return_d.callback(conn)
                return

            self.semaphore.release()
            self.connections[conn_str].getWaiting().append(wait_cb)
            return

        def locked_cb(empty):
            logging.debug("IndxConnectionPool _connect locked_cb")
            if conn_str not in self.connections:
                self._newConnections(conn_str).addCallbacks(succeed_cb, err_cb)
            else:
                # NOTE(review): this runs succeed_cb (which fires Deferreds
                # and edits the pool lists) on a worker thread - confirm that
                # is intended; the direct call is kept below for reference.
                threads.deferToThread(succeed_cb, None)
#                succeed_cb(None)

        self.semaphore.acquire().addCallbacks(locked_cb, err_cb)
        return return_d

    def _closeOldConnection(self):
        """ Close the oldest connection, so we can open a new one up.

            Pools are visited from least to most recently used; their free
            connections are closed until at least REMOVE_AT_ONCE have been
            closed or no pools remain. Returns a Deferred firing with None.
        """
        # is already in a semaphore lock, from _newConnection
        logging.debug("IndxConnectionPool _closeOldConnection")

        ### we could force quite them through postgresql like this - but instead we kill them from inside
        #query = "SELECT * FROM pg_stat_activity WHERE state = 'idle' AND application_name = %s AND query != 'LISTEN wb_new_version' ORDER BY state_change LIMIT 1;"
        #params = [indx_pg2.APPLICATION_NAME]

        return_d = Deferred()

        def err_cb(failure):
            return_d.errback(failure)

        # group the pools by their last-used time
        ages = {}
        for dbpool in self.connections.values():
            ages.setdefault(dbpool.getTime(), []).append(dbpool)

        # oldest first; sorted() works on both Python 2 and 3, whereas
        # keys().sort() only exists on Python 2
        pool_queue = []
        for timekey in sorted(ages):
            pool_queue.extend(ages[timekey])

        def removed_cb(count):

            if count < REMOVE_AT_ONCE and len(pool_queue) > 0:
                pool = pool_queue.pop(0)
                pool.getFree()
                pool.removeAll(count).addCallbacks(removed_cb, err_cb)
            else:
                return_d.callback(None)
        
        removed_cb(0)
        return return_d

    def _newConnection(self, conn_str):
        """ Makes a new connection to the DB
            and then puts it in the 'free' pool of this conn_str.

            If connecting fails, old connections are closed via
            _closeOldConnection and the connect is attempted again.
        """
        logging.debug("IndxConnectionPool _newConnection")
        # lock with the semaphore before calling this
        return_d = Deferred()

        def close_old_cb(failure):
            failure.trap(psycopg2.OperationalError, Exception)
            # couldn't connect, so close an old connection first
            logging.error("IndxConnectionPool error close_old_cb: {0} - state of conns is: {1}".format(failure.value, self.connections))

            logging.error("IndxConnectionPool connections: {0}".format("\n".join(map(lambda name: self.connections[name].__str__(), self.connections))))

            def closed_cb(empty):
                # closed, so try connecting again
                self._newConnection(conn_str).addCallbacks(return_d.callback, return_d.errback)

            closed_d = self._closeOldConnection()
            closed_d.addCallbacks(closed_cb, return_d.errback)

        try:
            # try to connect
            def connected_cb(connection):
                logging.debug("IndxConnectionPool _newConnection connected_cb, connection: {0}".format(connection))
                self.connections[conn_str].getFree().append(connection)
                return_d.callback(connection)

            conn = txpostgres.Connection()
            connection_d = conn.connect(conn_str)
            connection_d.addCallbacks(connected_cb, close_old_cb)
        except Exception as e:
            # close an old connection first
            logging.debug("IndxConnectionPool Exception, going to call close_old_cb: ({0})".format(e))
            close_old_cb(Failure(e))

        return return_d

    def _newConnections(self, conn_str):
        """ Make a pool of new connections.

            Registers a new DBConnectionPool for conn_str and starts
            MIN_CONNS connections; the returned Deferred is chained to the
            DeferredList of those attempts.
        """
        # lock with the semaphore before calling this
        logging.debug("IndxConnectionPool _newConnections")
        return_d = Deferred()

        self.connections[conn_str] = DBConnectionPool(conn_str)

        try:
            d_list = []
            for i in range(MIN_CONNS):
                connection_d = self._newConnection(conn_str) 
                d_list.append(connection_d)

            dl = DeferredList(d_list)
            dl.addCallbacks(return_d.callback, return_d.errback)

        except Exception as e:
            logging.error("IndxConnectionPool error in _newConnections: {0}".format(e))
            return_d.errback(Failure(e))

        return return_d
Beispiel #12
0
class DBConnectionPool():
    """ A pool of DB connections for a specific connection string / DB.

    Connections are tracked in three lists: `free` (idle), `inuse` (handed
    out) and `waiting` (callables that are invoked with a connection as soon
    as one is freed). `semaphore` is a single-token DeferredSemaphore used as
    a lock around changes to those lists, and `lastused` is refreshed every
    time one of the lists is accessed through its getter.
    """

    def __init__(self, conn_str):
        # conn_str is part of the constructor signature but is not stored.
        self.waiting = []
        self.inuse = []
        self.free = []

        self.semaphore = DeferredSemaphore(1)
        self.updateTime()

    def __unicode__(self):
        return self.__str__()

    def __str__(self):
        return "waiting: {0}, inuse: {1}, free: {2}, semaphore: {3}, lastused: {4}".format(self.waiting, self.inuse, self.free, self.semaphore, self.lastused)

    def updateTime(self):
        """ Record "now" as the last time this pool was used. """
        # NOTE(review): mktime() interprets its argument as local time, so
        # feeding it gmtime() is offset from the true epoch outside UTC.
        # Left as is because other code may compute "now" the same way
        # before comparing against getTime() - confirm before changing.
        self.lastused = time.mktime(time.gmtime()) # epoch time

    def getTime(self):
        """ Return the timestamp stored by the last updateTime() call. """
        return self.lastused

    def getWaiting(self):
        """ Return the list of waiting callbacks (and mark the pool used). """
        self.updateTime()
        return self.waiting

    def getInuse(self):
        """ Return the list of in-use connections (and mark the pool used). """
        self.updateTime()
        return self.inuse

    def getFree(self):
        """ Return the list of free connections (and mark the pool used). """
        self.updateTime()
        return self.free

    def freeConnection(self, conn):
        """ Free a connection from this DBPool.

        Under the pool lock, `conn` is removed from the in-use list. If a
        callback is waiting, the connection is marked in use again and passed
        to that callback (after the lock is released); otherwise it is put on
        the free list.

        The lock is released even when the bookkeeping raises (for example
        ValueError because `conn` is not in the in-use list), so a bad call
        can no longer block every later user of the pool.
        """

        def locked_cb(empty):
            logging.debug("DBConnectionPool locked_cb")
            waiter = None
            try:
                self.getInuse().remove(conn)

                if len(self.getWaiting()) > 0:
                    # hand the connection straight over: it stays "in use"
                    waiter = self.getWaiting().pop()
                    self.getInuse().append(conn)
                else:
                    self.getFree().append(conn)
            finally:
                # always give the lock back, whatever happened above
                self.semaphore.release()

            if waiter is not None:
                # invoked outside the lock, as before
                waiter(conn)

        def err_cb(failure):
            # only reached when acquiring the semaphore itself failed
            failure.trap(Exception)
            logging.error("DBConnectionPool free, err_cb: {0}".format(failure.value))
            self.semaphore.release()

        self.semaphore.acquire().addCallbacks(locked_cb, err_cb)


    def removeAll(self, count):
        """ Remove all free connections (usually because they're old and we're
            in a freeing up period).

            `count` is the running tally to start from. The returned Deferred
            fires with `count` plus the number of connections closed here, or
            errbacks if acquiring the lock or closing a connection failed.
        """
        logging.debug("DBConnectionPool removeAll called, count: {0}".format(count))
        return_d = Deferred()
        self.updateTime()

        def err_cb(failure):
            # only reached when acquiring the semaphore itself failed
            self.semaphore.release()
            return_d.errback(failure)

        def locked_cb(count):
            # immediately close the free connections
            try:
                while self.free:
                    conn = self.free.pop(0)
                    conn.close()
                    count += 1
            except Exception:
                # Do not leave the pool locked and the caller waiting forever:
                # release, then report the current exception on return_d.
                self.semaphore.release()
                return_d.errback()
                return

            self.semaphore.release()
            return_d.callback(count)

        self.semaphore.acquire().addCallbacks(lambda s: locked_cb(count), err_cb)
        return return_d
Beispiel #13
0
class WebpushPushkin(ConcurrencyLimitedPushkin):
    """
    Pushkin that relays notifications to Web Push endpoints.

    Requests are built with ``webpush`` using the VAPID private key and
    contact e-mail from the pushkin config, and are sent through a
    proxy-aware HTTP agent. The number of requests in flight is capped by
    ``connection_semaphore``, which holds ``max_connections`` tokens.
    """

    UNDERSTOOD_CONFIG_FIELDS = {
        "type",
        "max_connections",
        "vapid_private_key",
        "vapid_contact_email",
        "allowed_endpoints",
        "ttl",
    } | ConcurrencyLimitedPushkin.UNDERSTOOD_CONFIG_FIELDS

    def __init__(self, name: str, sygnal: "Sygnal", config: Dict[str, Any]):
        """
        Set up the HTTP machinery and load the VAPID credentials.

        Args:
            name: passed through to the base pushkin.
            sygnal: the Sygnal instance; supplies the reactor and the global
                ``proxy`` setting.
            config: this pushkin's configuration.

        Raises:
            PushkinSetupException: if ``vapid_private_key`` is missing, does
                not exist on disk or cannot be loaded, or if
                ``vapid_contact_email`` is missing.
        """
        super().__init__(name, sygnal, config)

        nonunderstood = self.cfg.keys() - self.UNDERSTOOD_CONFIG_FIELDS
        if nonunderstood:
            logger.warning(
                "The following configuration fields are not understood: %s",
                nonunderstood,
            )

        self.http_pool = HTTPConnectionPool(reactor=sygnal.reactor)
        self.max_connections = self.get_config(
            "max_connections", int, DEFAULT_MAX_CONNECTIONS
        )
        self.connection_semaphore = DeferredSemaphore(self.max_connections)
        self.http_pool.maxPersistentPerHost = self.max_connections

        tls_client_options_factory = ClientTLSOptionsFactory()

        # use the Sygnal global proxy configuration
        proxy_url = sygnal.config.get("proxy")

        self.http_agent = ProxyAgent(
            reactor=sygnal.reactor,
            pool=self.http_pool,
            contextFactory=tls_client_options_factory,
            proxy_url_str=proxy_url,
        )
        self.http_request_factory = HttpRequestFactory()

        # None means "no restriction"; otherwise a list of compiled patterns
        # that the endpoint's domain must fully match.
        self.allowed_endpoints: Optional[List[Pattern[str]]] = None
        allowed_endpoints = self.get_config("allowed_endpoints", list)
        if allowed_endpoints:
            self.allowed_endpoints = list(map(glob_to_regex, allowed_endpoints))

        privkey_filename = self.get_config("vapid_private_key", str)
        if not privkey_filename:
            raise PushkinSetupException("'vapid_private_key' not set in config")
        if not os.path.exists(privkey_filename):
            raise PushkinSetupException("path in 'vapid_private_key' does not exist")
        try:
            self.vapid_private_key = Vapid.from_file(private_key_file=privkey_filename)
        except VapidException as e:
            raise PushkinSetupException("invalid 'vapid_private_key' file") from e
        self.vapid_contact_email = self.get_config("vapid_contact_email", str)
        if not self.vapid_contact_email:
            raise PushkinSetupException("'vapid_contact_email' not set in config")
        self.ttl = self.get_config("ttl", int, DEFAULT_TTL)

    async def _dispatch_notification_unlimited(
        self, n: Notification, device: Device, context: NotificationContext
    ) -> List[str]:
        """
        Send one notification to one device's web push subscription.

        Args:
            n: the notification to relay.
            device: the target device; ``pushkey`` is used as the p256dh key
                and ``data`` must carry ``endpoint`` and ``auth``.
            context: not used by this method.

        Returns:
            A list containing the device's pushkey if it should be rejected,
            otherwise an empty list.
        """
        p256dh = device.pushkey
        if not isinstance(device.data, dict):
            logger.warning(
                "Rejecting pushkey %s; device.data is not a dict", device.pushkey
            )
            return [device.pushkey]

        # drop notifications without an event id if requested,
        # see https://github.com/matrix-org/sygnal/issues/186
        if device.data.get("events_only") is True and not n.event_id:
            return []

        endpoint = device.data.get("endpoint")
        auth = device.data.get("auth")

        if not p256dh or not isinstance(endpoint, str) or not isinstance(auth, str):
            logger.warning(
                "Rejecting pushkey; subscription info incomplete or invalid "
                + "(p256dh: %s, endpoint: %r, auth: %r)",
                p256dh,
                endpoint,
                auth,
            )
            return [device.pushkey]

        endpoint_domain = urlparse(endpoint).netloc
        if self.allowed_endpoints:
            allowed = any(
                regex.fullmatch(endpoint_domain) for regex in self.allowed_endpoints
            )
            if not allowed:
                logger.error(
                    "push gateway %s is not in allowed_endpoints, blocking request",
                    endpoint_domain,
                )
                # abort, but don't reject push key
                return []

        subscription_info = {
            "endpoint": endpoint,
            "keys": {"p256dh": p256dh, "auth": auth},
        }
        payload = WebpushPushkin._build_payload(n, device)
        data = json.dumps(payload)

        # web push only supports normal and low priority, so assume normal if absent
        low_priority = n.prio == "low"
        # allow dropping earlier notifications in the same room if requested
        topic = None
        if n.room_id and device.data.get("only_last_per_room") is True:
            # ask for a 22 byte hash, so the base64 of it is 32,
            # the limit webpush allows for the topic
            topic = urlsafe_b64encode(
                blake2s(n.room_id.encode(), digest_size=22).digest()
            )

        # note that webpush modifies vapid_claims, so make sure it's only used once
        vapid_claims = {
            "sub": "mailto:{}".format(self.vapid_contact_email),
        }
        # we use the semaphore to actually limit the number of concurrent
        # requests, since the HTTPConnectionPool will actually just lead to more
        # requests being created but not pooled – it does not perform limiting.
        with QUEUE_TIME_HISTOGRAM.time():
            with PENDING_REQUESTS_GAUGE.track_inprogress():
                await self.connection_semaphore.acquire()
        try:
            with SEND_TIME_HISTOGRAM.time():
                with ACTIVE_REQUESTS_GAUGE.track_inprogress():
                    request = webpush(
                        subscription_info=subscription_info,
                        data=data,
                        ttl=self.ttl,
                        vapid_private_key=self.vapid_private_key,
                        vapid_claims=vapid_claims,
                        requests_session=self.http_request_factory,
                    )
                    response = await request.execute(
                        self.http_agent, low_priority, topic
                    )
                    response_text = (await readBody(response)).decode()
        finally:
            # give the slot back whether or not the request succeeded
            self.connection_semaphore.release()

        reject_pushkey = self._handle_response(
            response, response_text, device.pushkey, endpoint_domain
        )
        if reject_pushkey:
            return [device.pushkey]
        return []

    @staticmethod
    def _build_payload(n: Notification, device: Device) -> Dict[str, Any]:
        """
        Build the payload data to be sent.

        The payload starts from the device's ``default_payload`` (when that
        is a dict), then gains every truthy notification attribute from a
        fixed list, the ``unread`` / ``missed_calls`` counts that are not
        None, and a trimmed copy of the event content.

        Args:
            n: Notification to build the payload for.
            device: Device information to which the constructed payload
            will be sent.

        Returns:
            JSON-compatible dict
        """
        payload = {}

        if device.data:
            default_payload = device.data.get("default_payload")
            if isinstance(default_payload, dict):
                payload.update(default_payload)

        for attr in [
            "room_id",
            "room_name",
            "room_alias",
            "membership",
            "event_id",
            "sender",
            "sender_display_name",
            "user_is_target",
            "type",
        ]:
            value = getattr(n, attr, None)
            if value:
                payload[attr] = value

        counts = getattr(n, "counts", None)
        if counts is not None:
            for attr in ["unread", "missed_calls"]:
                count_value = getattr(counts, attr, None)
                if count_value is not None:
                    payload[attr] = count_value

        if n.content and isinstance(n.content, dict):
            # work on a copy so the notification itself is left untouched
            content = n.content.copy()
            # we can't show formatted_body in a notification anyway on web
            # so remove it
            content.pop("formatted_body", None)
            body = content.get("body")
            # make some attempts to not go over the max payload length
            if isinstance(body, str) and len(body) > MAX_BODY_LENGTH:
                content["body"] = body[0 : MAX_BODY_LENGTH - 1] + "…"
            ciphertext = content.get("ciphertext")
            if isinstance(ciphertext, str) and len(ciphertext) > MAX_CIPHERTEXT_LENGTH:
                content.pop("ciphertext", None)
            payload["content"] = content

        return payload

    def _handle_response(
        self,
        response: IResponse,
        response_text: str,
        pushkey: str,
        endpoint_domain: str,
    ) -> bool:
        """
        Logs and determines the outcome of the response

        Args:
            response: the HTTP response from the push endpoint.
            response_text: the decoded response body, used for logging.
            pushkey: the pushkey the request was sent for.
            endpoint_domain: the endpoint's domain, used for logging.

        Returns:
            Boolean whether the pushkey should be rejected
        """
        ttl_response_headers = response.headers.getRawHeaders(b"TTL")
        if ttl_response_headers:
            try:
                ttl_given = int(ttl_response_headers[0])
                if ttl_given != self.ttl:
                    logger.info(
                        "requested TTL of %d to endpoint %s but got %d",
                        self.ttl,
                        endpoint_domain,
                        ttl_given,
                    )
            except ValueError:
                # an unparseable TTL header is only informational; ignore it
                pass
        # permanent errors
        if response.code == 404 or response.code == 410:
            logger.warning(
                "Rejecting pushkey %s; subscription is invalid on %s: %d: %s",
                pushkey,
                endpoint_domain,
                response.code,
                response_text,
            )
            return True
        # and temporary ones
        if response.code >= 400:
            logger.warning(
                "webpush request failed for pushkey %s; %s responded with %d: %s",
                pushkey,
                endpoint_domain,
                response.code,
                response_text,
            )
        elif response.code != 201:
            logger.info(
                "webpush request for pushkey %s didn't respond with 201; "
                + "%s responded with %d: %s",
                pushkey,
                endpoint_domain,
                response.code,
                response_text,
            )
        return False
Beispiel #14
0
class BaseQtWebKitMiddleware(object):
    nam_cls = ScrapyNetworkAccessManager

    @classmethod
    def from_crawler(cls, crawler):
        """Build the middleware from a crawler's settings.

        Settings read: QTWEBKIT_COOKIES_ENABLED, COOKIES_DEBUG,
        QTWEBKIT_QT_PLATFORM (the value "default" is turned into None),
        QTWEBKIT_SHOW_WINDOW, QTWEBKIT_ENABLE_DEV_TOOLS and
        QTWEBKIT_PAGE_LIMIT.

        Returns the new middleware instance.
        """
        settings = crawler.settings

        if settings.getbool('QTWEBKIT_COOKIES_ENABLED', False):
            cookies_middleware = CookiesMiddleware(
                settings.getbool('COOKIES_DEBUG')
            )
        else:
            cookies_middleware = None

        qt_platform = settings.get("QTWEBKIT_QT_PLATFORM", "minimal")
        if qt_platform == "default":
            qt_platform = None

        ext = cls(
            crawler,
            show_window=settings.getbool("QTWEBKIT_SHOW_WINDOW", False),
            qt_platform=qt_platform,
            # getbool, like the other flags: a plain get() would hand strings
            # such as "0" or "False" through unchanged, and those are truthy.
            enable_webkit_dev_tools=settings.getbool(
                "QTWEBKIT_ENABLE_DEV_TOOLS", False
            ),
            page_limit=settings.getint("QTWEBKIT_PAGE_LIMIT", 4),
            cookies_middleware=cookies_middleware
        )

        return ext

    @staticmethod
    def engine_stopped():
        """Quit the QApplication, if one currently exists."""
        app = QApplication.instance()
        if app:
            app.quit()

    def __init__(self, crawler, show_window=False, qt_platform="minimal",
                 enable_webkit_dev_tools=False, page_limit=4,
                 cookies_middleware=None):
        """Store the configuration and create the page-limiting semaphore.

        `page_limit` is the number of tokens of the semaphore acquired before
        a page is created; None selects a DummySemaphore instead.
        """
        super(BaseQtWebKitMiddleware, self).__init__()
        self._crawler = crawler
        self.show_window = show_window
        self.qt_platform = qt_platform
        self.enable_webkit_dev_tools = enable_webkit_dev_tools
        self.cookies_middleware = cookies_middleware
        # Pages that are currently open; see create_page() / _close_page().
        self._references = set()

        # NOTE(review): presumably zero capacities switch off WebKit's object
        # cache whenever more than one page may be open - confirm.
        if page_limit != 1 and QWebSettings is not None:
            QWebSettings.setObjectCacheCapacities(0, 0, 0)

        self.semaphore = (DummySemaphore() if page_limit is None
                          else DeferredSemaphore(page_limit))

    @staticmethod
    def _schedule_qt_event_loop(app):
        """

        Drive a QApplication's event processing from Twisted. Should be
        called at most once per QApplication.

        app.processEvents is invoked from a LoopingCall every 0.02 seconds,
        and the loop is stopped when the application is about to quit.

        """
        # XXX: This is ugly but I don't know another way to do it.
        event_pump = LoopingCall(app.processEvents)
        # Second argument False: do not run the first call immediately.
        event_pump.start(0.02, False)
        app.aboutToQuit.connect(event_pump.stop)

    def _setup_page(self, page, extra_settings):
        """Apply the fixed QWebSettings attributes, then `extra_settings`.

        `extra_settings` maps QWebSettings attributes to values and is
        applied last, so it can override any of the fixed attributes.
        """
        web_settings = page.settings()

        fixed_attributes = (
            (QWebSettings.JavaEnabled, False),
            (QWebSettings.PluginsEnabled, False),
            (QWebSettings.PrivateBrowsingEnabled, True),
            (QWebSettings.LocalStorageEnabled, True),
            (QWebSettings.LocalContentCanAccessRemoteUrls, True),
            (QWebSettings.LocalContentCanAccessFileUrls, True),
            (QWebSettings.NotificationsEnabled, False),
            (QWebSettings.DeveloperExtrasEnabled,
             self.enable_webkit_dev_tools),
        )
        for attribute, value in fixed_attributes:
            web_settings.setAttribute(attribute, value)

        for attribute, value in extra_settings.items():
            web_settings.setAttribute(attribute, value)

    @staticmethod
    def _make_qt_request(scrapy_request):
        """Translate a Scrapy request into Qt terms.

        Returns a (QNetworkRequest, operation, body) tuple. Methods without
        an entry in HTTP_METHOD_TO_QT_OPERATION are sent as a custom
        operation carrying the method name as the custom verb.
        """
        network_request = QNetworkRequest(QUrl(scrapy_request.url))
        # Multiple values of one header are joined into a single raw header.
        for name, raw_values in scrapy_request.headers.items():
            network_request.setRawHeader(name, b', '.join(raw_values))

        method = scrapy_request.method
        try:
            qt_operation = HTTP_METHOD_TO_QT_OPERATION[method]
        except KeyError:
            qt_operation = QNetworkAccessManager.CustomOperation
            network_request.setAttribute(QNetworkRequest.CustomVerbAttribute,
                                         method)

        network_request.setAttribute(
            QNetworkRequest.CacheSaveControlAttribute, False
        )

        return network_request, qt_operation, QByteArray(scrapy_request.body)

    @inlineCallbacks
    def process_request(self, request, spider):
        """Handle QtWebKitRequest objects; other requests yield None.

        A request that already carries a webpage continues on that page.
        Otherwise a semaphore slot is acquired and a new page is created and
        loaded. The cookies middleware, if any, sees every request first.
        """
        if self.cookies_middleware:
            yield self.cookies_middleware.process_request(request, spider)

        if not isinstance(request, QtWebKitRequest):
            return

        existing_page = request.webpage
        if existing_page:
            # Continue processing with the existing webpage object; the
            # request handed on no longer carries the page itself.
            request = request.replace(webpage=None)
            existing_page.networkAccessManager().request = request
            returnValue(self._handle_page_request(spider, request,
                                                  existing_page))
        else:
            yield self.semaphore.acquire()
            page_response = yield self.create_page(request, spider)
            returnValue(page_response)

    def process_response(self, request, response, spider):
        """Pass the response through the cookies middleware, if there is one."""
        if not self.cookies_middleware:
            return response
        return self.cookies_middleware.process_response(request, response,
                                                        spider)

    def ensure_qapplication(self):
        """Create and setup a QApplication if one does not already exist."""
        if QApplication.instance():
            return

        argv = ["scrapy"]
        if self.qt_platform is not None:
            argv += ["-platform", self.qt_platform]

        app = QApplication(argv)
        self._schedule_qt_event_loop(app)
        _QApplicationStopper(self._crawler.signals, app)

    def create_page(self, request, spider):
        """

        Create a webpage object, start loading the request on it, and return
        a deferred that fires with the result of _handle_page_request once
        the page's load_finished_with_error signal is emitted.

        """

        self.ensure_qapplication()

        page = WebPage()
        page_settings = request.meta.get('qwebsettings_settings', {})
        self._setup_page(page, page_settings)
        # Keep a reference so the page stays alive until _close_page().
        self._references.add(page)

        if self.show_window:
            view = QWebView()
            view.setPage(page)
            page.webview = view
            self._add_webview_to_window(view, spider.name)

        if request.meta.get('qtwebkit_user_agent', False):
            # Use the user agent the page itself reports for this URL.
            request.headers['User-Agent'] = page.userAgentForUrl(
                QUrl(request.url)
            )

        user_agent = request.headers.get('User-Agent')
        manager = self.nam_cls(spider, request, user_agent, parent=page)
        if ((self.cookies_middleware and
             'dont_merge_cookies' not in request.meta)):
            jar = ScrapyAwareCookieJar(self.cookies_middleware,
                                       request.meta.get("cookiejar"),
                                       parent=manager)
            manager.setCookieJar(jar)
        page.setNetworkAccessManager(manager)

        # Hook the signal up before starting the load.
        load_d = deferred_for_signal(page.load_finished_with_error)
        load_d.addCallback(partial(self._handle_page_request, spider, request,
                                   page))
        page.mainFrame().load(*self._make_qt_request(request))
        return load_d

    def _add_webview_to_window(self, webview, title=""):
        """Do nothing; called by create_page() when show_window is set."""

    def _remove_webview_from_window(self, webview):
        """Do nothing; called by _close_page() for pages that have a webview."""

    def _handle_page_request(self, spider, request, webpage,
                             load_result=(True, None)):
        """

        Handle a request for a web page, either a page load or a request to
        continue using an existing page object.

        `load_result` is an (ok, error) pair. The default (True, None) is
        used when no page load was requested, in which case there is no error
        object to read the status, URL or headers from.

        Returns the response, or a Failure wrapping whatever went wrong. When
        the load itself failed, the page is closed first so that its
        reference and its semaphore slot are not leaked.

        """

        try:
            ok, error = load_result

            if ok:
                # The error object is not available if a page load was not
                # requested.
                if error and error.domain == QWebPage.Http:
                    status = error.error
                else:
                    status = 200
                if error:
                    url = error.url
                else:
                    url = webpage.mainFrame().url()
                # Without an error object there are no headers to copy;
                # reading error.headers would raise AttributeError.
                headers = error.headers if error is not None else None

                qwebpage_response = request.meta.get('qwebpage_response', False)
                if qwebpage_response:
                    respcls = QtWebKitResponse
                else:
                    respcls = HtmlResponse

                response = respcls(status=status,
                                   url=url.toString(),
                                   headers=headers,
                                   body=webpage.mainFrame().toHtml(),
                                   encoding='utf-8',
                                   request=request)

                if qwebpage_response:
                    # The page stays open; _request_callback closes it after
                    # the spider callback has run.
                    response.webpage = webpage
                    request.callback = partial(self._request_callback, spider,
                                               request.callback or 'parse')
                else:
                    self._close_page(webpage)

            else:
                exc = self._exception_from_errorpageextensionoption(error)
                # Nothing will ever use this page: close it, otherwise it
                # keeps its semaphore slot for good.
                if webpage in self._references:
                    self._close_page(webpage)
                raise exc

        except Exception as err:
            response = Failure(err)

        return response

    @inlineCallbacks
    def _request_callback(self, spider, original_callback, response):
        """

        Run the spider callback for a response that carries a webpage, then
        close the page (lose the reference to it so it is garbage collected).

        `original_callback` may be a callable or the name of a spider method.
        The callback can keep the page open by setting the response's
        should_close_webpage attribute to a false value, for example when the
        page is stored somewhere else (e.g. request meta) for later use. In
        that case the page gets a close_page() function, created here, which
        must be called manually at some point.

        """

        callback = original_callback
        if isinstance(callback, basestring):
            callback = getattr(spider, callback)

        page = response.webpage
        response.should_close_webpage = True
        try:
            outcome = yield maybeDeferred(callback, response)
            returnValue(arg_to_iter(outcome))
        finally:
            # FIXME: sometimes this section is reached before the wrapped
            # callback finishes, when it returns a Deferred.
            if not response.should_close_webpage:
                page.close_page = partial(self._close_page, page)
                page.close_page.__doc__ = ("Lose the reference to the webpage "
                                           "object and allow it to be "
                                           "garbage collected.")
            else:
                self._close_page(page)

    def _close_page(self, webpage):
        """Drop the reference to `webpage` and release its semaphore slot."""
        self._references.remove(webpage)
        # Resetting the main frame URL prevents it from making any more
        # requests, which would cause Qt errors after the webpage is deleted.
        webpage.mainFrame().setUrl(QUrl())
        view = webpage.webview
        if view is not None:
            self._remove_webview_from_window(view)
        self.semaphore.release()

    # Exception class raised for each QNetworkReply error code. Codes that
    # are not listed here fall back to ConnectError in
    # _exception_from_errorpageextensionoption().
    _qt_error_exc_mapping = {
        QNetworkReply.ConnectionRefusedError: ConnectionRefusedError,
        QNetworkReply.RemoteHostClosedError: ConnectionLost,
        QNetworkReply.HostNotFoundError: DNSLookupError,
        QNetworkReply.TimeoutError: TimeoutError,
        QNetworkReply.OperationCanceledError: ConnectingCancelledError,
        QNetworkReply.SslHandshakeFailedError: SSLError,
        QNetworkReply.ProtocolUnknownError: NotSupported
    }

    def _exception_from_errorpageextensionoption(self, option):
        """Build (not raise) the exception that represents a page load error.

        Errors in the QWebPage.QtNetwork domain are looked up in
        _qt_error_exc_mapping, with ConnectError for unlisted codes. Errors
        from any other domain (QWebPage.WebKit included) become a plain
        Exception. The exception is constructed with option.errorString.
        """
        exc_cls = Exception
        if option.domain == QWebPage.QtNetwork:
            exc_cls = self._qt_error_exc_mapping.get(option.error,
                                                     ConnectError)

        return exc_cls(option.errorString)