def main(reactor, duration):
    """Start as many ``_run()`` coroutines as possible within *duration* seconds.

    Each coroutine is wrapped with ``defer.ensureDeferred``; the resulting
    Deferreds are not kept.  *reactor* is not used by this function.

    Returns:
        An already-fired Deferred carrying the number of coroutines started.
    """
    started = 0
    began_at = time()
    while (time() - began_at) < duration:
        defer.ensureDeferred(_run())
        started += 1
    return defer.succeed(started)
    def tearDown(self) -> None:
        """Disconnect every store in ``self.stores``, one after another.

        NOTE(review): annotated ``-> None`` but actually returns the Deferred
        wrapping the disconnect coroutine (presumably so the test runner can
        wait for it).
        """
        async def tearDown() -> None:
            for store in self.stores:
                await store.disconnect()

        # tearDown can't return a coroutine, so convert it to a Deferred
        return ensureDeferred(tearDown())
    def whenRunning(cls, config: Configuration) -> None:
        """
        Hook invoked once the reactor has started.

        Starts the store's schema upgrade and validation as a coroutine, then
        serves the web application over TCP on ``config.HostName`` /
        ``config.Port``.

        NOTE(review): annotated ``-> None`` but returns the Deferred that
        tracks the schema upgrade/validation.  The TCP listener is started
        without waiting for that Deferred.
        """
        async def prepareStore() -> None:
            await config.store.upgradeSchema()
            await config.store.validate()

        storeReady = ensureDeferred(prepareStore())

        host, port = config.HostName, config.Port

        application = Application(config=config)

        cls.log.info(
            "Setting up web service at http://{host}:{port}/",
            host=host,
            port=port,
        )

        patchCombinedLogFormatter()

        site = Site(application.router.resource())
        site.sessionFactory = IMSSession

        from twisted.internet import reactor

        reactor.listenTCP(port, site, interface=host)

        return storeReady
Exemple #4
0
def _call(instance, f, *args, **kwargs):
    if instance is not None or getattr(f, "__klein_bound__", False):
        args = (instance,) + args
    result = f(*args, **kwargs)
    if iscoroutine(result):
        result = ensureDeferred(result)
    return result
    def setUp(self) -> None:
        """Initialise ``self.names``/``self.stores`` and start the MySQL service.

        The work runs in a coroutine that is returned wrapped in a Deferred.
        NOTE(review): the ``-> None`` annotation does not match that return.
        """
        async def _asyncSetUp() -> None:
            self.names: Set[str] = set()
            self.stores: List[TestDataStore] = []

            await self.mysqlService.start()

        # setUp can't return a coroutine, so hand back a Deferred instead
        return ensureDeferred(_asyncSetUp())
def _main():
    """
    Entry point for `python -m autobahn` (this name is special) and the
    entry_point declared in setup.py.

    Runs ``_real_main`` under ``react``, wrapping its coroutine in a Deferred.
    """
    def _start(reactor):
        # react() needs a Deferred, not a bare coroutine
        return ensureDeferred(_real_main(reactor))

    react(_start)
Exemple #7
0
def maybe_coroutine(obj):
    """
    Wrap *obj* in ``defer.ensureDeferred`` if it is a coroutine on Python 3.

    Any other object (and every object on Python 2) is returned unchanged.
    Meant to be inserted into callback chains that run user code, since that
    code may be an ``async def`` on Python 3.
    """
    is_py3_coroutine = six.PY3 and asyncio.iscoroutine(obj)
    if not is_py3_coroutine:
        return obj
    return defer.ensureDeferred(obj)
Exemple #8
0
 def lineReceived(self, line):
     """Dispatch one received line as a command.

     The line is split on whitespace and the pieces are passed to ``do_cmd``
     as a tuple.  Empty and whitespace-only lines are ignored.  A fresh token
     is appended to ``self.outstanding`` and handed to both the ``_completed``
     callback and the ``_error`` errback of the resulting Deferred.
     """
     # Ignore blank lines
     if not line:
         return
     keys = line.split()
     # A whitespace-only line splits to nothing; without this check it would
     # be dispatched to do_cmd as an empty command tuple.
     if not keys:
         return
     # we really do want this to be a Deferred because lineReceived
     # isn't (can't be) an async method..
     token = object()
     self.outstanding.append(token)
     d = defer.ensureDeferred(do_cmd(self.proto, tuple(keys)))
     d.addCallback(self._completed, token)
     d.addErrback(self._error, token)
Exemple #9
0
async def run(reactor, cfg, tor, if_unused, verbose, list, build, delete):
    """Perform one circuit action, chosen by the first truthy option.

    * ``list``: list circuits via ``list_circuits``.
    * ``delete`` (non-empty): delete every given circuit concurrently via
      ``delete_circuit``; if any deletion fails, its original exception is
      re-raised.
    * ``build``: build a circuit through the comma-separated ``build`` path.

    Note: the ``list`` parameter shadows the builtin; the body never needs
    the builtin.
    """
    if list:
        await list_circuits(reactor, cfg, tor, verbose)

    elif len(delete) > 0:
        coros = [
            delete_circuit(reactor, cfg, tor, circ_id, if_unused)
            for circ_id in delete
        ]
        # consumeErrors=True: we inspect every failure ourselves below, so
        # failed deletions must not also be reported as unhandled errors.
        results = await defer.DeferredList(
            [defer.ensureDeferred(c) for c in coros], consumeErrors=True
        )
        for ok, value in results:
            if not ok:
                # value is a Failure; re-raise the underlying exception (with
                # its traceback) instead of the Failure wrapper itself.
                value.raiseException()

    elif build:
        await build_circuit(reactor, cfg, tor, build.split(','))
Exemple #10
0
 def _emit_run(self, f, args, kwargs):
     """Call ``f(*args, **kwargs)`` and emit ``'failure'`` for any error.

     A synchronous exception is emitted immediately as a ``Failure``.  When
     the call returns a coroutine (if ``iscoroutine`` is available) or a
     ``Deferred``, an errback is attached that emits ``'failure'``.  The
     errback returns None, so the failure is consumed.  Any other return
     value is ignored.
     """
     try:
         outcome = f(*args, **kwargs)
     except Exception:
         self.emit('failure', Failure())
         return

     if iscoroutine and iscoroutine(outcome):
         pending = ensureDeferred(outcome)
     elif isinstance(outcome, Deferred):
         pending = outcome
     else:
         return

     if not pending:
         return

     @pending.addErrback
     def _errback(failure):
         if failure:
             self.emit('failure', failure)
Exemple #11
0
 def as_future(self, fun, *args, **kwargs):
     """Call *fun* with the given arguments and return a Deferred.

     Coroutine functions (when ``PY3_CORO`` is set) are called directly and
     their coroutine is wrapped with ``ensureDeferred``.  Twisted does not
     handle coroutines automatically here.  Everything else goes through
     ``maybeDeferred``.
     """
     if not (PY3_CORO and iscoroutinefunction(fun)):
         return maybeDeferred(fun, *args, **kwargs)
     coro = fun(*args, **kwargs)
     return ensureDeferred(coro)
Exemple #12
0
 def get(self):
     """Run ``self.async_get()`` and return its coroutine wrapped in a Deferred."""
     pending = self.async_get()
     return defer.ensureDeferred(pending)
Exemple #13
0
    def request(
        self,
        method: bytes,
        uri: bytes,
        headers: Optional[Headers] = None,
        bodyProducer: Optional[IBodyProducer] = None,
    ) -> Generator[defer.Deferred, Any, defer.Deferred]:
        """
        Issue an HTTP request, following ``.well-known`` delegation for
        ``matrix://`` URIs.

        Args:
            method: HTTP method: GET/POST/etc
            uri: Absolute URI to be retrieved
            headers:
                HTTP headers to send with the request, or None to send no extra
                headers. A copy is used; the caller's object is not modified.
            bodyProducer:
                An object which can generate bytes to make up the body of this
                request (for example, the properly encoded contents of a file
                for a file upload). Or None if the request is to have no body.

        Returns:
            Deferred[twisted.web.iweb.IResponse]:
                fires when the header of the response has been received
                (regardless of the response status code). Fails if there is any
                problem which prevents that response from being received
                (including problems that prevent the request from being sent).
        """
        # urlparse leaves `port` as None when the URI has no explicit port.
        parsed_uri = urllib.parse.urlparse(uri)

        # There must be a valid hostname.
        assert parsed_uri.hostname

        # Delegation is resolved here rather than in the endpoint because the
        # Host header has to carry the delegated server name.
        delegated_server = None
        wants_well_known = (
            parsed_uri.scheme == b"matrix"
            and not _is_ip_literal(parsed_uri.hostname)
            and not parsed_uri.port
        )
        if wants_well_known:
            well_known_result = yield defer.ensureDeferred(
                self._well_known_resolver.get_well_known(parsed_uri.hostname)
            )
            delegated_server = well_known_result.delegated_server

        if delegated_server:
            # Point the URI at the delegated server, keeping every other
            # component as it was.
            uri = urllib.parse.urlunparse(
                parsed_uri._replace(netloc=delegated_server)
            )
            parsed_uri = urllib.parse.urlparse(uri)

        # Make sure Host matches the target server and a User-Agent is sent.
        request_headers = Headers() if headers is None else headers.copy()

        if not request_headers.hasHeader(b"host"):
            request_headers.addRawHeader(b"host", parsed_uri.netloc)
        if not request_headers.hasHeader(b"user-agent"):
            request_headers.addRawHeader(b"user-agent", self.user_agent)

        response = yield make_deferred_yieldable(
            self._agent.request(method, uri, request_headers, bodyProducer)
        )
        return response
def inline_success():
    """Create the ``run()`` coroutine, print it, wait on it, then report.

    A generator: it yields the Deferred obtained from ``ensureDeferred``.
    """
    coroutine_obj = run()
    print(coroutine_obj)
    yield ensureDeferred(coroutine_obj)
    print('Done inline')
Exemple #15
0
async def main(reactor):
    """
    Talk to the echo server through the 'agent' interface (run
    ../echo/server.py for the server, for example).

    Opens ws://localhost:9000/ws with an extra ``x-foo`` header and prints
    every message event.  It then sends one message, yields to the reactor
    once, and closes the connection, waiting until it is closed.
    """
    agent = create_client_agent(reactor)
    extra_headers = {"x-foo": "bar"}
    proto = await agent.open("ws://localhost:9000/ws", {"headers": extra_headers})

    def got_message(*args, **kw):
        print("on_message: args={} kwargs={}".format(args, kw))

    proto.on('message', got_message)

    await proto.is_open

    proto.sendMessage(b"i am a message\n")
    # Zero-delay deferLater: let the reactor turn over before closing.
    await task.deferLater(reactor, 0, lambda: None)

    proto.sendClose(code=1000, reason="byebye")

    await proto.is_closed


if __name__ == "__main__":
    from twisted.internet.defer import ensureDeferred

    def _run_main(reactor):
        return ensureDeferred(main(reactor))

    task.react(_run_main)
Exemple #16
0
 def outer(*args, **kwargs):
     """Call ``wrapped`` with the given arguments and wrap the result in a Deferred."""
     coro = wrapped(*args, **kwargs)
     return ensureDeferred(coro)
Exemple #17
0
    def test_send_receipts_with_backoff(self):
        """Send two receipts in quick succession.  The first goes out when the
        reactor is pumped.  The second is held back: it is still unsent after
        advancing the clock by 19, and sent after a further 10."""
        mock_send_transaction = (
            self.hs.get_federation_transport_client().send_transaction
        )
        mock_send_transaction.return_value = make_awaitable({})

        def expected_edus(event_id):
            # The single m.receipt EDU expected for a read receipt by
            # "user_id" in "room_id" pointing at `event_id`.
            return [
                {
                    "edu_type": "m.receipt",
                    "content": {
                        "room_id": {
                            "m.read": {
                                "user_id": {
                                    "event_ids": [event_id],
                                    "data": {"ts": 1234},
                                }
                            }
                        }
                    },
                }
            ]

        def sent_edus():
            # send_transaction's second positional argument is a callable
            # whose result holds the "edus" that were sent.
            json_cb = mock_send_transaction.call_args[0][1]
            return json_cb()["edus"]

        sender = self.hs.get_federation_sender()

        def send_receipt(event_id):
            receipt = ReadReceipt(
                "room_id", "m.read", "user_id", [event_id], {"ts": 1234}
            )
            self.successResultOf(
                defer.ensureDeferred(sender.send_read_receipt(receipt))
            )

        # The first receipt is sent as soon as the reactor is pumped.
        send_receipt("event_id")
        self.pump()
        mock_send_transaction.assert_called_once()
        self.assertEqual(sent_edus(), expected_edus("event_id"))
        mock_send_transaction.reset_mock()

        # The second one is held back until the backoff has elapsed.
        send_receipt("other_id")
        self.pump()
        mock_send_transaction.assert_not_called()

        self.reactor.advance(19)
        mock_send_transaction.assert_not_called()

        self.reactor.advance(10)
        mock_send_transaction.assert_called_once()
        self.assertEqual(sent_edus(), expected_edus("other_id"))
    def test_well_known_cache_with_temp_failure(self):
        """Test that we refetch well-known before the cache expires, and that
        it ignores transient errors.
        """
        self.reactor.lookups["testserv"] = "1.2.3.4"

        def fetch():
            return defer.ensureDeferred(
                self.well_known_resolver.get_well_known(b"testserv")
            )

        def fail_connection(factory):
            # Fail the connection attempt; the resolver treats this as a
            # temporary failure.
            factory.clientConnectionFailed(None, Exception("nope"))

        fetch_d = fetch()

        # The first lookup should try to connect to port 443 for .well-known.
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        host, port, client_factory, _timeout, _bindAddress = clients.pop(0)
        self.assertEqual(host, "1.2.3.4")
        self.assertEqual(port, 443)

        well_known_server = self._handle_well_known_connection(
            client_factory,
            expected_sni=b"testserv",
            response_headers={b"Cache-Control": b"max-age=1000"},
            content=b'{ "m.server": "target-server" }',
        )

        result = self.successResultOf(fetch_d)
        self.assertEqual(result.delegated_server, b"target-server")

        # Close the TCP connection.
        well_known_server.loseConnection()

        # Get close to the cache expiry, which makes the resolver look again.
        self.reactor.pump((900.0,))

        fetch_d = fetch()

        # The resolver may retry a few times; fail every attempt.
        attempts = 0
        while self.reactor.tcpClients:
            _host, _port, client_factory, _timeout, _bindAddress = (
                self.reactor.tcpClients.pop(0)
            )
            attempts += 1
            fail_connection(client_factory)
            # There are a few sleeps involved, so pump the reactor a bit.
            self.reactor.pump((1.0, 1.0))

        # A previously valid well-known should lead to more than one attempt.
        self.assertGreater(attempts, 1)

        # Despite the failed lookup, the cached value is returned.
        result = self.successResultOf(fetch_d)
        self.assertEqual(result.delegated_server, b"target-server")

        # Expire both caches and repeat the request.
        self.reactor.pump((10000.0,))

        # This time a failed lookup should give no delegated server.
        fetch_d = fetch()

        _host, _port, client_factory, _timeout, _bindAddress = (
            self.reactor.tcpClients.pop(0)
        )
        fail_connection(client_factory)
        self.reactor.pump((0.4,))

        result = self.successResultOf(fetch_d)
        self.assertEqual(result.delegated_server, None)
 def test_get_unread_push_actions_for_user_in_range_for_email(self):
     """Wait on get_unread_push_actions_for_user_in_range_for_email called
     with (USER_ID, 0, 1000, 20); the result itself is not inspected."""
     query = self.store.get_unread_push_actions_for_user_in_range_for_email(
         USER_ID, 0, 1000, 20
     )
     yield defer.ensureDeferred(query)
Exemple #20
0
    def setUp(self):
        """Build a test homeserver, create a public room, and have
        "@baduser:test.serv" join it via ``on_receive_pdu``.

        Leaves ``self.room_id``, ``self.store``, ``self.handler`` and
        ``self.client`` set.  ``do_auth`` and
        ``_check_sigs_and_hash_and_fetch`` are stubbed to succeed immediately.
        """
        self.http_client = Mock()
        self.reactor = ThreadedMemoryReactorClock()
        self.hs_clock = Clock(self.reactor)
        self.homeserver = setup_test_homeserver(
            self.addCleanup,
            http_client=self.http_client,
            clock=self.hs_clock,
            reactor=self.reactor,
        )

        # Create a public room as the local user ("us", "test").
        requester = Requester(UserID("us", "test"), None, False, None, None)
        room_creator = self.homeserver.get_room_creation_handler()
        public_chat = room_creator._presets_dict["public_chat"]
        room_deferred = ensureDeferred(
            room_creator.create_room(requester, public_chat, ratelimit=False)
        )
        self.reactor.advance(0.1)
        self.room_id = self.successResultOf(room_deferred)[0]["room_id"]

        self.store = self.homeserver.get_datastore()

        # Figure out what the most recent event is.
        most_recent = self.successResultOf(
            maybeDeferred(
                self.homeserver.get_datastore().get_latest_event_ids_in_room,
                self.room_id,
            )
        )[0]

        join_event = make_event_from_dict(
            {
                "room_id": self.room_id,
                "sender": "@baduser:test.serv",
                "state_key": "@baduser:test.serv",
                "event_id": "$join:test.serv",
                "depth": 1000,
                "origin_server_ts": 1,
                "type": "m.room.member",
                "origin": "test.servx",
                "content": {"membership": "join"},
                "auth_events": [],
                "prev_state": [(most_recent, {})],
                "prev_events": [(most_recent, {})],
            }
        )

        self.handler = self.homeserver.get_handlers().federation_handler
        self.handler.do_auth = (
            lambda origin, event, context, auth_events: succeed(context)
        )
        self.client = self.homeserver.get_federation_client()
        self.client._check_sigs_and_hash_and_fetch = (
            lambda dest, pdus, **k: succeed(pdus)
        )

        # Send the join; on_receive_pdu returning None is not an error.
        join_d = ensureDeferred(
            self.handler.on_receive_pdu(
                "test.serv", join_event, sent_to_us_directly=True
            )
        )
        self.reactor.advance(1)
        self.assertEqual(self.successResultOf(join_d), None)

        # The join must now be the room's latest event.
        self.assertEqual(
            self.successResultOf(
                maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
            )[0],
            "$join:test.serv",
        )
Exemple #21
0
    def _handle_request(self, request: Request) -> Union[int, bytes]:
        """
        Actually handle the request.

        Args:
            request: The request, corresponding to a POST request.

        Returns:
            Either the encoded body of a 400 error response, or NOT_DONE_YET
            once the notification has been handed to ``_handle_dispatch``
            to be processed asynchronously.
        """
        request_id = self._make_request_id()
        header_dict = {
            k.decode(): v[0].decode()
            for k, v in request.requestHeaders.getAllRawHeaders()
        }

        # extract OpenTracing scope from the HTTP headers
        span_ctx = self.sygnal.tracer.extract(Format.HTTP_HEADERS, header_dict)
        span_tags = {
            tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
            "request_id": request_id,
        }

        root_span = self.sygnal.tracer.start_span(
            "pushgateway_v1_notify", child_of=span_ctx, tags=span_tags
        )

        # if this is True, we will not close the root_span at the end of this
        # function.
        root_span_accounted_for = False

        try:
            context = NotificationContext(
                request_id, root_span, time.perf_counter()
            )

            log = NotificationLoggerAdapter(logger, {"request_id": request_id})

            try:
                body = json_decoder.decode(request.content.read().decode("utf-8"))
            except Exception as exc:
                msg = "Expected JSON request body"
                log.warning(msg, exc_info=exc)
                root_span.log_kv({logs.EVENT: "error", "error.object": exc})
                request.setResponseCode(400)
                return msg.encode()

            # The body is client-supplied: valid JSON need not be an object.
            # For a list/str/number body the `in` test or the indexing below
            # would raise out of this handler instead of producing a 400.
            if (
                not isinstance(body, dict)
                or "notification" not in body
                or not isinstance(body["notification"], dict)
            ):
                msg = "Invalid notification: expecting object in 'notification' key"
                log.warning(msg)
                root_span.log_kv({logs.EVENT: "error", "message": msg})
                request.setResponseCode(400)
                return msg.encode()

            try:
                notif = Notification(body["notification"])
            except InvalidNotificationException as e:
                log.exception("Invalid notification")
                request.setResponseCode(400)
                root_span.log_kv({logs.EVENT: "error", "error.object": e})
                return str(e).encode()

            if notif.event_id is not None:
                root_span.set_tag("event_id", notif.event_id)

            # track whether the notification was passed with content
            root_span.set_tag("has_content", notif.content is not None)

            NOTIFS_RECEIVED_COUNTER.inc()

            if len(notif.devices) == 0:
                msg = "No devices in notification"
                log.warning(msg)
                # Record the rejection on the span like the other 400 paths.
                root_span.log_kv({logs.EVENT: "error", "message": msg})
                request.setResponseCode(400)
                return msg.encode()

            # From here on this function no longer finishes root_span; it is
            # passed to _handle_dispatch (presumably finished there).
            root_span_accounted_for = True

            async def cb():
                with REQUESTS_IN_FLIGHT_GUAGE.labels(
                        self.__class__.__name__).track_inprogress():
                    await self._handle_dispatch(root_span, request, log, notif,
                                                context)

            # NOTE(review): the Deferred is not kept, so anything escaping
            # _handle_dispatch would only surface as an unhandled-error log.
            ensureDeferred(cb())

            # we have to try and send the notifications first,
            # so we can find out which ones to reject
            return NOT_DONE_YET
        except Exception as exc_val:
            root_span.set_tag(tags.ERROR, True)

            # [2] corresponds to the traceback
            trace = traceback.format_tb(sys.exc_info()[2])
            root_span.log_kv({
                logs.EVENT: tags.ERROR,
                logs.MESSAGE: str(exc_val),
                logs.ERROR_OBJECT: exc_val,
                logs.ERROR_KIND: type(exc_val),
                logs.STACK: trace,
            })
            raise
        finally:
            if not root_span_accounted_for:
                root_span.finish()
Exemple #22
0
    def test_cant_hide_direct_ancestors(self):
        """
        If you send a message, you must be able to provide the direct
        prev_events that said event references.
        """
        async def post_json(destination, path, data, headers=None, timeout=0):
            # Answer any get_missing_events request with no events at all;
            # every other path gets None.
            if path.startswith("/_matrix/federation/v1/get_missing_events/"):
                return {"events": []}

        self.http_client.post_json = post_json

        def latest_event_id():
            return self.successResultOf(
                maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
            )[0]

        most_recent = latest_event_id()

        # Now lie about an event: it references "two:test.serv", which the
        # stubbed post_json will never hand over.
        lying_event = make_event_from_dict(
            {
                "room_id": self.room_id,
                "sender": "@baduser:test.serv",
                "event_id": "one:test.serv",
                "depth": 1000,
                "origin_server_ts": 1,
                "type": "m.room.message",
                "origin": "test.serv",
                "content": {"body": "hewwo?"},
                "auth_events": [],
                "prev_events": [("two:test.serv", {}), (most_recent, {})],
            }
        )

        with LoggingContext(request="lying_event"):
            receive_d = ensureDeferred(
                self.handler.on_receive_pdu(
                    "test.serv", lying_event, sent_to_us_directly=True
                )
            )
            # Step the reactor, so the database fetches come back.
            self.reactor.advance(1)

        # on_receive_pdu should have failed with a 403.
        failure = self.failureResultOf(receive_d)
        self.assertEqual(
            failure.value.args[0],
            (
                "ERROR 403: Your server isn't divulging details about prev_events "
                "referenced in this event."
            ),
        )

        # Make sure the invalid event isn't there.
        self.assertEqual(latest_event_id(), "$join:test.serv")
 def wrapper(*args: Any, **kwargs: Any) -> Any:
     """Call ``f`` and wrap whatever it returns with ``ensureDeferred``."""
     return ensureDeferred(f(*args, **kwargs))
Exemple #24
0
 def post(self):
     """Run ``self.async_post()`` and return its coroutine wrapped in a Deferred."""
     pending = self.async_post()
     return defer.ensureDeferred(pending)
Exemple #25
0
 def test_authentication(self):
     """Run the async authentication test and return it as a Deferred."""
     coro = self.async_test_authentication()
     return ensureDeferred(coro)
    def test_find_first_stream_ordering_after_ts(self):
        """Check find_first_stream_ordering_after_ts against hand-built events.

        Per the expectations below, it returns the lowest stream ordering
        among events whose received_ts is at or after the given timestamp.
        When no event qualifies, it returns one more than the highest stream
        ordering, or 0 for an empty table.
        """
        def add_event(so, ts):
            return defer.ensureDeferred(
                self.store.db_pool.simple_insert(
                    "events",
                    {
                        "stream_ordering": so,
                        "received_ts": ts,
                        "event_id": "event%i" % so,
                        "type": "",
                        "room_id": "",
                        "content": "",
                        "processed": True,
                        "outlier": False,
                        "topological_ordering": 0,
                        "depth": 0,
                    },
                )
            )

        def find_first(ts):
            return defer.ensureDeferred(
                self.store.find_first_stream_ordering_after_ts(ts)
            )

        # start with the base case where there are no events in the table
        r = yield find_first(11)
        self.assertEqual(r, 0)

        # now with one event
        yield add_event(2, 10)
        r = yield find_first(9)
        self.assertEqual(r, 2)
        r = yield find_first(10)
        self.assertEqual(r, 2)
        r = yield find_first(11)
        self.assertEqual(r, 3)

        # add a bunch of dummy events to the events table
        for (stream_ordering, ts) in (
            (3, 110),
            (4, 120),
            (5, 120),
            (10, 130),
            (20, 140),
        ):
            yield add_event(stream_ordering, ts)

        r = yield find_first(110)
        self.assertEqual(r, 3, "First event after 110ms should be 3, was %i" % r)

        # 4 and 5 are both after 120: we want 4 rather than 5
        r = yield find_first(120)
        self.assertEqual(r, 4, "First event after 120ms should be 4, was %i" % r)

        r = yield find_first(129)
        self.assertEqual(r, 10, "First event after 129ms should be 10, was %i" % r)

        # check we can get the last event
        r = yield find_first(140)
        self.assertEqual(r, 20, "First event after 140ms should be 20, was %i" % r)

        # off the end
        r = yield find_first(160)
        self.assertEqual(r, 21)

        # check we can find an event at ordering zero
        yield add_event(0, 5)
        r = yield find_first(1)
        self.assertEqual(r, 0)
Exemple #27
0
    def _populate_user_directory_process_rooms(self, progress, batch_size):
        """
        Populate the user directory from the next batch of rooms in the
        temporary rooms table.

        A generator that yields Deferreds (presumably driven by
        ``inlineCallbacks`` via a decorator outside this view).

        Args:
            progress (dict): background-update progress. If empty, the user
                directory is wiped first. Its "remaining" key is set to the
                number of rooms left in the temporary table, then decremented
                and persisted after each room.
            batch_size (int): Maximum number of state events to process
                per cycle. Processing stops after the first room that takes
                the running event count above this value.

        Returns:
            int: 1 if the temporary table is empty (the background update is
            then ended), otherwise the sum of the event counts of the rooms
            processed in this call.
        """
        state = self.hs.get_state_handler()

        # An empty progress dict (presumably the first run) means we start
        # from an empty user directory.
        if not progress:
            yield self.delete_all_from_user_dir()

        def _get_next_batch(txn):
            # Only fetch 250 rooms, so we don't fetch too many at once, even
            # if those 250 rooms have less than batch_size state events.
            sql = """
                SELECT room_id, events FROM %s
                ORDER BY events DESC
                LIMIT 250
            """ % (TEMP_TABLE + "_rooms", )
            txn.execute(sql)
            rooms_to_work_on = txn.fetchall()

            if not rooms_to_work_on:
                return None

            # Get how many are left to process, so we can give status on how
            # far we are in processing
            txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms")
            progress["remaining"] = txn.fetchone()[0]

            return rooms_to_work_on

        rooms_to_work_on = yield self.db.runInteraction(
            "populate_user_directory_temp_read", _get_next_batch)

        # No more rooms -- complete the transaction.
        if not rooms_to_work_on:
            yield self.db.updates._end_background_update(
                "populate_user_directory_process_rooms")
            return 1

        # Lazy %-args: the message is only formatted if debug logging is on.
        logger.debug(
            "Processing the next %d rooms of %d remaining",
            len(rooms_to_work_on),
            progress["remaining"],
        )

        processed_event_count = 0

        for room_id, event_count in rooms_to_work_on:
            is_in_room = yield self.is_host_joined(room_id, self.server_name)

            if is_in_room:
                is_public = yield self.is_room_world_readable_or_publicly_joinable(
                    room_id)

                users_with_profile = yield defer.ensureDeferred(
                    state.get_current_users_in_room(room_id))
                user_ids = set(users_with_profile)

                # Update each user in the user directory.
                for user_id, profile in users_with_profile.items():
                    yield self.update_profile_in_user_dir(
                        user_id, profile.display_name, profile.avatar_url)

                to_insert = set()

                if is_public:
                    for user_id in user_ids:
                        if self.get_if_app_services_interested_in_user(
                                user_id):
                            continue

                        to_insert.add(user_id)

                    if to_insert:
                        yield self.add_users_in_public_rooms(
                            room_id, to_insert)
                        to_insert.clear()
                else:
                    for user_id in user_ids:
                        if not self.hs.is_mine_id(user_id):
                            continue

                        if self.get_if_app_services_interested_in_user(
                                user_id):
                            continue

                        for other_user_id in user_ids:
                            if user_id == other_user_id:
                                continue

                            user_set = (user_id, other_user_id)
                            to_insert.add(user_set)

                            # If it gets too big, stop and write to the database
                            # to prevent storing too much in RAM.
                            if len(to_insert) >= self.SHARE_PRIVATE_WORKING_SET:
                                yield self.add_users_who_share_private_room(
                                    room_id, to_insert)
                                to_insert.clear()

                    if to_insert:
                        yield self.add_users_who_share_private_room(
                            room_id, to_insert)
                        to_insert.clear()

            # We've finished a room. Delete it from the table.
            yield self.db.simple_delete_one(TEMP_TABLE + "_rooms",
                                            {"room_id": room_id})
            # Update the remaining counter.
            progress["remaining"] -= 1
            yield self.db.runInteraction(
                "populate_user_directory",
                self.db.updates._background_update_progress_txn,
                "populate_user_directory_process_rooms",
                progress,
            )

            processed_event_count += event_count

            if processed_event_count > batch_size:
                # Don't process any more rooms, we've hit our batch size.
                return processed_event_count

        return processed_event_count
 def _rotate(stream):
     """Run ``_rotate_notifs_before_txn`` for *stream* via ``runInteraction``.

     The interaction is given an empty name string; its result is wrapped
     in a Deferred.
     """
     interaction = self.store.db_pool.runInteraction(
         "", self.store._rotate_notifs_before_txn, stream
     )
     return defer.ensureDeferred(interaction)
Exemple #29
0
 def wrapper(*args, **kwds):
     """Call ``func`` with the given arguments and wrap the result in a Deferred."""
     coro = func(*args, **kwds)
     return ensureDeferred(coro)
 def wrap(self, *args, **kwargs):
     """Pass every argument except ``self`` straight to ``ensureDeferred``."""
     wrapped_result = ensureDeferred(*args, **kwargs)
     return wrapped_result
 def test_context(self):
     """Run the async context test and return it as a Deferred."""
     coro = self._context_test()
     return ensureDeferred(coro)
Exemple #32
0
    def trigger(
        self, http_method, path, content, mock_request, federation_auth_origin=None
    ):
        """ Fire an HTTP event.

        A generator (presumably driven by ``inlineCallbacks`` via a decorator
        outside this view).

        Args:
            http_method : The HTTP method
            path : The HTTP path, appended to ``self.prefix``. An optional
                   ``?query`` part is parsed into ``mock_request.args``.
            content : The HTTP body
            mock_request : Mocked request to pass to the event so it can get
                           content.
            federation_auth_origin (bytes|None): domain to authenticate as, for federation
        Returns:
            A tuple of (code, response)
        Raises:
            KeyError If no event is found which will handle the path.
        """
        path = self.prefix + path

        # annoyingly we return a twisted http request which has chained calls
        # to get at the http content, hence mock it here.
        mock_content = Mock()
        config = {"read.return_value": content}
        mock_content.configure_mock(**config)
        mock_request.content = mock_content

        mock_request.method = http_method.encode("ascii")
        mock_request.uri = path.encode("ascii")

        mock_request.getClientIP.return_value = "-"

        headers = {}
        if federation_auth_origin is not None:
            headers[b"Authorization"] = [
                b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,)
            ]
        mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)

        # return the right path if the event requires it
        mock_request.path = path

        # add in query params to the right place. A path without '?' has no
        # query string and leaves mock_request.args untouched; checking
        # explicitly avoids a catch-all `except` that would hide real errors.
        path_parts = path.split("?")
        if len(path_parts) > 1:
            mock_request.args = urlparse.parse_qs(path_parts[1])
            mock_request.path = path_parts[0]
            path = mock_request.path

        if isinstance(path, bytes):
            path = path.decode("utf8")

        for (method, pattern, func) in self.callbacks:
            if http_method != method:
                continue

            matcher = pattern.match(path)
            if matcher:
                try:
                    args = [urlparse.unquote(u) for u in matcher.groups()]

                    (code, response) = yield defer.ensureDeferred(
                        func(mock_request, *args)
                    )
                    return code, response
                except CodeMessageException as e:
                    return (e.code, cs_error(e.msg, code=e.errcode))

        raise KeyError("No event can handle %s" % path)
Exemple #33
0
def _flattenElement(request, root, write, slotData, renderFactory,
                    dataEscaper):
    """
    Make C{root} slightly more flat by yielding all its immediate contents as
    strings, deferreds or generators that are recursive calls to itself.

    @param request: A request object which will be passed to
        L{IRenderable.render}.

    @param root: An object to be made flatter.  This may be of type C{unicode},
        L{str}, L{slot}, L{Tag <twisted.web.template.Tag>}, L{tuple}, L{list},
        L{types.GeneratorType}, L{Deferred}, or an object that implements
        L{IRenderable}.

    @param write: A callable which will be invoked with each L{bytes} produced
        by flattening C{root}.

    @param slotData: A L{list} of L{dict} mapping L{str} slot names to data
        with which those slots will be replaced.  It is used as a stack: each
        L{Tag} pushes its own C{slotData} while its attributes and children
        are flattened, and pops it again afterwards so that it does not leak
        into later siblings.

    @param renderFactory: If not L{None}, an object that provides
        L{IRenderable}.

    @param dataEscaper: A 1-argument callable which takes L{bytes} or
        L{unicode} and returns L{bytes}, quoted as appropriate for the
        rendering context.  This is really only one of two values:
        L{attributeEscapingDoneOutside} or L{escapeForContent}, depending on
        whether the rendering context is within an attribute or not.  See the
        explanation in L{writeWithAttributeEscaping}.

    @return: An iterator that eventually yields L{bytes} that should be written
        to the output.  However it may also yield other iterators or
        L{Deferred}s; if it yields another iterator, the caller will iterate
        it; if it yields a L{Deferred}, the result of that L{Deferred} will
        either be L{bytes}, in which case it's written, or another generator,
        in which case it is iterated.  See L{_flattenTree} for the trampoline
        that consumes said values.
    @rtype: An iterator which yields L{bytes}, L{Deferred}, and more iterators
        of the same type.

    @raise ValueError: If C{root} is a L{Tag} naming a render method but
        C{renderFactory} is L{None}.
    @raise UnsupportedType: If C{root} is none of the supported types.
    """
    def keepGoing(newRoot, dataEscaper=dataEscaper,
                  renderFactory=renderFactory, write=write):
        # Recurse with the same request and the shared slotData stack; callers
        # override the escaper, render factory or writer when the context
        # changes (attribute values, IRenderable children).
        return _flattenElement(request, newRoot, write, slotData,
                               renderFactory, dataEscaper)
    if isinstance(root, (bytes, unicode)):
        write(dataEscaper(root))
    elif isinstance(root, slot):
        slotValue = _getSlotValue(root.name, slotData, root.default)
        yield keepGoing(slotValue)
    elif isinstance(root, CDATA):
        write(b'<![CDATA[')
        write(escapedCDATA(root.data))
        write(b']]>')
    elif isinstance(root, Comment):
        write(b'<!--')
        write(escapedComment(root.data))
        write(b'-->')
    elif isinstance(root, Tag):
        # Push this tag's slot frame so slots in its attributes/children can
        # resolve against it.  Every exit path of this branch pops it again.
        slotData.append(root.slotData)
        if root.render is not None:
            rendererName = root.render
            if renderFactory is None:
                # Without this check the failure is an opaque AttributeError
                # on None.
                raise ValueError(
                    'Tag wants to be rendered by method "%s" but is not '
                    'contained in any IRenderable' % (rendererName,))
            rootClone = root.clone(False)
            rootClone.render = None
            renderMethod = renderFactory.lookupRenderMethod(rendererName)
            result = renderMethod(request, rootClone)
            yield keepGoing(result)
            slotData.pop()
            return

        if not root.tagName:
            # A tag without a name contributes only its children.
            yield keepGoing(root.children)
            slotData.pop()
            return

        write(b'<')
        if isinstance(root.tagName, unicode):
            tagName = root.tagName.encode('ascii')
        else:
            tagName = root.tagName
        write(tagName)
        for k, v in iteritems(root.attributes):
            if isinstance(k, unicode):
                k = k.encode('ascii')
            write(b' ' + k + b'="')
            # Serialize the contents of the attribute, wrapping the results of
            # that serialization so that _everything_ is quoted.
            yield keepGoing(
                v,
                attributeEscapingDoneOutside,
                write=writeWithAttributeEscaping(write))
            write(b'"')
        if root.children or nativeString(tagName) not in voidElements:
            write(b'>')
            # Regardless of whether we're in an attribute or not, switch back
            # to the escapeForContent dataEscaper.  The contents of a tag must
            # be quoted no matter what; in the top-level document, just so
            # they're valid, and if they're within an attribute, they have to
            # be quoted so that after applying the *un*-quoting required to re-
            # parse the tag within the attribute, all the quoting is still
            # correct.
            yield keepGoing(root.children, escapeForContent)
            write(b'</' + tagName + b'>')
        else:
            write(b' />')
        # The trampoline finishes each yielded child generator before resuming
        # this one, so popping here happens after all nested content is done.
        slotData.pop()

    elif isinstance(root, (tuple, list, GeneratorType)):
        for element in root:
            yield keepGoing(element)
    elif isinstance(root, CharRef):
        escaped = '&#%d;' % (root.ordinal,)
        write(escaped.encode('ascii'))
    elif isinstance(root, Deferred):
        yield root.addCallback(lambda result: (result, keepGoing(result)))
    elif iscoroutine(root):
        d = ensureDeferred(root)
        yield d.addCallback(lambda result: (result, keepGoing(result)))
    elif IRenderable.providedBy(root):
        result = root.render(request)
        yield keepGoing(result, renderFactory=root)
    else:
        raise UnsupportedType(root)
Exemple #34
0
def setup(config_options):
    """
    Load the homeserver configuration, build and set up the
    ``SynapseHomeServer``, and schedule the remaining startup work to run once
    the reactor is running.

    Args:
        config_options: The options passed to Synapse. Usually
            `sys.argv[1:]`.

    Returns:
        HomeServer

    Exits the process with status 1 if the configuration cannot be loaded,
    and with status 0 if no config is returned (config-generation mode).
    """
    try:
        config = HomeServerConfig.load_or_generate_config(
            "Synapse Homeserver", config_options)
    except ConfigError as e:
        sys.stderr.write("\nERROR: %s\n" % (e, ))
        sys.exit(1)

    if not config:
        # If a config isn't returned, and an exception isn't raised, we're just
        # generating config files and shouldn't try to continue.
        sys.exit(0)

    events.USE_FROZEN_DICTS = config.use_frozen_dicts

    hs = SynapseHomeServer(
        config.server_name,
        config=config,
        version_string="Synapse/" + get_version_string(synapse),
    )

    synapse.config.logger.setup_logging(hs, config, use_worker_options=False)

    logger.info("Setting up server")

    try:
        hs.setup()
    except IncorrectDatabaseSetup as e:
        quit_with_error(str(e))
    except UpgradeDatabaseException as e:
        quit_with_error("Failed to upgrade database: %s" % (e, ))

    hs.setup_master()

    async def do_acme() -> bool:
        """
        Reprovision an ACME certificate, if it's required.

        Returns:
            Whether the cert has been updated.
        """
        acme = hs.get_acme_handler()

        # Check how long the certificate is active for.
        cert_days_remaining = hs.config.is_disk_cert_valid(
            allow_self_signed=False)

        # We want to reprovision if cert_days_remaining is None (meaning no
        # certificate exists), or the days remaining number it returns
        # is less than our re-registration threshold.
        provision = (
            cert_days_remaining is None
            or cert_days_remaining < hs.config.acme_reprovision_threshold
        )

        if provision:
            await acme.provision_certificate()

        return provision

    async def reprovision_acme() -> None:
        """
        Provision a certificate from ACME, if required, and reload the TLS
        certificate if it's renewed.
        """
        reprovisioned = await do_acme()
        if reprovisioned:
            _base.refresh_certificate(hs)

    def reprovision_acme_deferred():
        # Give the looping call a plain function that returns a Deferred, so
        # the coroutine is actually driven even if the clock's looping call
        # does not await coroutine results itself.
        return defer.ensureDeferred(reprovision_acme())

    async def start():
        """
        Finish startup once the reactor is running: ACME provisioning and OIDC
        metadata loading (when enabled), starting the listeners, and kicking
        off background database updates. On any failure the traceback is
        printed, the reactor is stopped and the process exits with status 1.
        """
        try:
            # Run the ACME provisioning code, if it's enabled.
            if hs.config.acme_enabled:
                acme = hs.get_acme_handler()
                # Start up the webservices which we will respond to ACME
                # challenges with, and then provision.
                await acme.start_listening()
                await do_acme()

                # Check if it needs to be reprovisioned every day
                # (the interval is given in milliseconds).
                hs.get_clock().looping_call(
                    reprovision_acme_deferred, 24 * 60 * 60 * 1000
                )

            # Load the OIDC provider metadata, if OIDC is enabled.
            if hs.config.oidc_enabled:
                oidc = hs.get_oidc_handler()
                # Loading the provider metadata also ensures the provider config is valid.
                await oidc.load_metadata()
                await oidc.load_jwks()

            _base.start(hs, config.listeners)

            hs.get_datastore().db_pool.updates.start_doing_background_updates()
        except Exception:
            # Print the exception and bail out.
            print("Error during startup:", file=sys.stderr)

            # this gives better tracebacks than traceback.print_exc()
            Failure().printTraceback(file=sys.stderr)

            if reactor.running:
                reactor.stop()
            sys.exit(1)

    reactor.callWhenRunning(lambda: defer.ensureDeferred(start()))

    return hs
Exemple #35
0
def start(config_options: List[str]) -> None:
    """
    Entry point for the Synapse admin command.

    Parses ``config_options`` (homeserver config options plus an admin
    sub-command), forces the worker config to ``synapse.app.admin_cmd`` with
    background processes disabled, sets up an ``AdminCmdServer``, and runs the
    selected sub-command's ``func`` under ``task.react``.

    Args:
        config_options: command-line arguments; presumably ``sys.argv[1:]``.

    Exits the process with status 1 if the configuration cannot be loaded.
    """
    parser = argparse.ArgumentParser(description="Synapse Admin Command")
    HomeServerConfig.add_arguments_to_parser(parser)

    subparser = parser.add_subparsers(
        title="Admin Commands",
        required=True,
        dest="command",
        metavar="<admin_command>",
        help="The admin command to perform.",
    )
    export_data_parser = subparser.add_parser(
        "export-data", help="Export all data for a user")
    export_data_parser.add_argument("user_id", help="User to extract data from")
    export_data_parser.add_argument(
        "--output-directory",
        action="store",
        metavar="DIRECTORY",
        required=False,
        help=(
            "The directory to store the exported data in. Must be empty. Defaults"
            " to creating a temp directory."
        ),
    )
    export_data_parser.set_defaults(func=export_data_command)

    try:
        config, args = HomeServerConfig.load_config_with_parser(
            parser, config_options)
    except ConfigError as e:
        sys.stderr.write("\n" + str(e) + "\n")
        sys.exit(1)

    if config.worker.worker_app is not None:
        assert config.worker.worker_app == "synapse.app.admin_cmd"

    # Update the config with some basic overrides so that we don't have to
    # specify a full worker config.
    config.worker.worker_app = "synapse.app.admin_cmd"

    if not config.worker.worker_daemonize and not config.worker.worker_log_config:
        # Since we're meant to be run as a "command" let's not redirect stdio
        # unless we've actually set log config.
        config.logging.no_redirect_stdio = True

    # Explicitly disable background processes
    config.worker.should_update_user_directory = False
    config.worker.run_background_tasks = False
    config.worker.start_pushers = False
    config.worker.pusher_shard_config.instances = []
    config.worker.send_federation = False
    config.worker.federation_shard_config.instances = []

    synapse.events.USE_FROZEN_DICTS = config.server.use_frozen_dicts

    ss = AdminCmdServer(
        config.server.server_name,
        config=config,
        version_string="Synapse/" +
        get_distribution_version_string("matrix-synapse"),
    )

    setup_logging(ss, config, use_worker_options=True)

    ss.setup()

    # We use task.react as the basic run command as it correctly handles tearing
    # down the reactor when the deferreds resolve and setting the return value.
    # We also make sure that `_base.start` gets run before we actually run the
    # command.
    async def run() -> None:
        with LoggingContext("command"):
            await _base.start(ss)
            await args.func(ss, args)

    _base.start_worker_reactor(
        "synapse-admin-cmd",
        config,
        run_command=lambda: task.react(lambda _reactor: defer.ensureDeferred(
            run())),
    )
    @defer.inlineCallbacks
    def test_count_aggregation(self):
        """
        Check the notification/highlight counts for one user in one room as
        push actions are injected, rotated, marked read, and after every row
        of ``event_push_actions`` is deleted.

        The body yields Deferreds, so it must be driven by
        ``defer.inlineCallbacks``; without the decorator the test runner would
        get a bare generator and none of the assertions would execute.
        """
        room_id = "!foo:example.com"
        user_id = "@user1235:example.com"

        @defer.inlineCallbacks
        def _assert_counts(notify_count, highlight_count):
            # Fetch the counts for (room_id, user_id) from stream position 0.
            counts = yield defer.ensureDeferred(
                self.store.db_pool.runInteraction(
                    "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
                )
            )
            self.assertEqual(
                counts,
                {
                    "notify_count": notify_count,
                    "unread_count": 0,  # Unread counts are tested in the sync tests.
                    "highlight_count": highlight_count,
                },
            )

        @defer.inlineCallbacks
        def _inject_actions(stream, action):
            # Fake an event at the given stream ordering/depth, stage a push
            # action for our user, then persist it.
            event = Mock()
            event.room_id = room_id
            event.event_id = "$test:example.com"
            event.internal_metadata.stream_ordering = stream
            event.depth = stream

            yield defer.ensureDeferred(
                self.store.add_push_actions_to_staging(
                    event.event_id,
                    {user_id: action},
                    False,
                )
            )
            yield defer.ensureDeferred(
                self.store.db_pool.runInteraction(
                    "",
                    self.persist_events_store._set_push_actions_for_event_and_users_txn,
                    [(event, None)],
                    [(event, None)],
                )
            )

        def _rotate(stream):
            return defer.ensureDeferred(
                self.store.db_pool.runInteraction(
                    "", self.store._rotate_notifs_before_txn, stream
                )
            )

        def _mark_read(stream):
            return defer.ensureDeferred(
                self.store.db_pool.runInteraction(
                    "",
                    self.store._remove_old_push_actions_before_txn,
                    room_id,
                    user_id,
                    stream,
                )
            )

        yield _assert_counts(0, 0)
        yield _inject_actions(1, PlAIN_NOTIF)
        yield _assert_counts(1, 0)
        yield _rotate(2)
        yield _assert_counts(1, 0)

        yield _inject_actions(3, PlAIN_NOTIF)
        yield _assert_counts(2, 0)
        yield _rotate(4)
        yield _assert_counts(2, 0)

        yield _inject_actions(5, PlAIN_NOTIF)
        yield _mark_read(3)
        yield _assert_counts(1, 0)

        yield _mark_read(5)
        yield _assert_counts(0, 0)

        yield _inject_actions(6, PlAIN_NOTIF)
        yield _rotate(7)

        # keyvalues {"1": 1} matches every row, i.e. empty event_push_actions;
        # the count must still be reported after the rotation above.
        yield defer.ensureDeferred(
            self.store.db_pool.simple_delete(
                table="event_push_actions", keyvalues={"1": 1}, desc=""
            )
        )

        yield _assert_counts(1, 0)

        yield _mark_read(7)
        yield _assert_counts(0, 0)

        yield _inject_actions(8, HIGHLIGHT)
        yield _assert_counts(1, 1)
        yield _rotate(9)
        yield _assert_counts(1, 1)
        yield _rotate(10)
        yield _assert_counts(1, 1)
Exemple #37
0
 def render(self, request):
     """
     Twisted's per-request render hook.

     Kicks off ``self._async_render(request)`` wrapped in a Deferred (which is
     not kept) and immediately returns NOT_DONE_YET; presumably
     ``_async_render`` is what finishes the request later.
     """
     pending = self._async_render(request)
     defer.ensureDeferred(pending)
     return NOT_DONE_YET
def twisted_main(reactor):
    """Run ``main`` against a ``TwistedBackend`` for *reactor*, as a Deferred."""
    backend = TwistedBackend(reactor)
    return ensureDeferred(main(backend))
def ensureDeferred(fun, *args, **kw):
    """
    Wrap an asynchronous computation in a Deferred via ``defer.ensureDeferred``.

    If ``fun`` is callable it is invoked as ``fun(*args, **kw)`` and its
    result (a coroutine or Deferred) is wrapped; this is the original calling
    convention.  Other code in this module calls ``ensureDeferred(coro())``
    with an already-created coroutine, so a non-callable ``fun`` (coroutine or
    Deferred) is now wrapped directly instead of failing with
    "object is not callable".

    Raises:
        TypeError: if ``fun`` is not callable but extra arguments were given.
    """
    if callable(fun):
        return defer.ensureDeferred(fun(*args, **kw))
    if args or kw:
        raise TypeError(
            "ensureDeferred: got arguments but %r is not callable" % (fun,))
    return defer.ensureDeferred(fun)
def main():
    """Hand ``_main`` to ``react``, wrapping its coroutine in a Deferred."""
    def _run(reactor):
        return ensureDeferred(_main(reactor))

    return react(_run)
Exemple #41
0
 def _go():
     """
     Start ``_startup(reactor)``; failures are passed to ``_the_bad_stuff``
     and the reactor is stopped afterwards whatever the outcome.
     """
     startup = defer.ensureDeferred(_startup(reactor))
     startup.addErrback(_the_bad_stuff)

     def _stop_reactor(_ignored):
         return reactor.stop()

     # Added after the errback so it runs on both success and failure paths.
     startup.addBoth(_stop_reactor)
Exemple #42
0
 def lineReceived(self, line):
     """
     Called by LineOnlyReceiver once per incoming line.

     Stores the line on ``self.request`` and returns a Deferred for
     ``self._handle_request_noblock()``.
     """
     self.request = line
     handling = self._handle_request_noblock()
     return ensureDeferred(handling)