Beispiel #1
0
    def on_GET(self, request):
        """Report the capabilities this server advertises to the requester.

        Guests are permitted. The response lists the default and available
        room versions, and whether changing the password is enabled for the
        requesting user (enabled when their stored ``password_hash`` is
        truthy).

        Returns:
            (via ``defer.returnValue``) a ``(200, response)`` tuple.
        """
        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
        user = yield self.store.get_user_by_id(requester.user.to_string())
        has_password = bool(user["password_hash"])

        default_version = DEFAULT_ROOM_VERSION.identifier
        available = {}
        for version in KNOWN_ROOM_VERSIONS.values():
            available[version.identifier] = version.disposition

        capabilities = {
            "m.room_versions": {
                "default": default_version,
                "available": available,
            },
            "m.change_password": {"enabled": has_password},
        }
        defer.returnValue((200, {"capabilities": capabilities}))
Beispiel #2
0
def room_version_to_event_format(room_version):
    """Look up the event format used by a room version.

    Args:
        room_version (str): identifier of the room version, looked up in
            KNOWN_ROOM_VERSIONS.

    Returns:
        int: the ``event_format`` of the matching known room version.

    Raises:
        UnsupportedRoomVersionError: if the identifier is not a known room
            version.
    """
    known_version = KNOWN_ROOM_VERSIONS.get(room_version)
    if known_version:
        return known_version.event_format

    # An unknown identifier can occur when support for a room version has
    # been withdrawn.
    raise UnsupportedRoomVersionError()
Beispiel #3
0
    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
        """Return the capabilities advertised to the authenticated user.

        Guests are permitted. Advertises the default and available room
        versions, whether password/displayname/avatar/3PID changes are
        enabled, and, when the corresponding experimental flags are set, the
        MSC3244 room capabilities and MSC3720 account-status support.

        Returns:
            ``(HTTPStatus.OK, {"capabilities": {...}})``
        """
        await self.auth.get_user_by_req(request, allow_guest=True)
        can_change_password = self.auth_handler.can_change_password()

        room_versions: JsonDict = {
            "default": self.config.server.default_room_version.identifier,
            "available": {
                version.identifier: version.disposition
                for version in KNOWN_ROOM_VERSIONS.values()
            },
        }

        registration = self.config.registration
        capabilities: JsonDict = {
            "m.room_versions": room_versions,
            "m.change_password": {"enabled": can_change_password},
            "m.set_displayname": {
                "enabled": registration.enable_set_displayname
            },
            "m.set_avatar_url": {
                "enabled": registration.enable_set_avatar_url
            },
            "m.3pid_changes": {
                "enabled": registration.enable_3pid_changes
            },
        }

        experimental = self.config.experimental
        if experimental.msc3244_enabled:
            # Appended as the last key of "m.room_versions".
            room_versions[
                "org.matrix.msc3244.room_capabilities"] = MSC3244_CAPABILITIES
        if experimental.msc3720_enabled:
            capabilities["org.matrix.msc3720.account_status"] = {
                "enabled": True,
            }

        return HTTPStatus.OK, {"capabilities": capabilities}
Beispiel #4
0
    def get_room_version_txn(
        self, txn: LoggingTransaction, room_id: str
    ) -> RoomVersion:
        """Look up the RoomVersion object for a room within a transaction.

        Args:
            txn: Transaction object
            room_id: the room whose version is wanted

        Returns:
            The entry of KNOWN_ROOM_VERSIONS matching the room's stored
            version identifier.

        Raises:
            NotFoundError: if the room is unknown
            UnsupportedRoomVersionError: if the room's version identifier is
                not in KNOWN_ROOM_VERSIONS, typically because support for
                that version has been removed from Synapse.
        """
        version_id = self.get_room_version_id_txn(txn, room_id)
        room_version = KNOWN_ROOM_VERSIONS.get(version_id)
        if room_version:
            return room_version

        raise UnsupportedRoomVersionError(
            "Room %s uses a room version %s which is no longer supported"
            % (room_id, version_id)
        )
Beispiel #5
0
def check_redaction(room_version, event, auth_events):
    """Decide whether the sender of a redaction may redact its target.

    Args:
        room_version: identifier of the room's version, looked up in
            KNOWN_ROOM_VERSIONS.
        event: the redaction event being checked.
        auth_events: the auth events used to look up power levels.

    Returns:
        False if the sender's power level is at least the room's "redact"
        level (default 50), so no further checks are needed.
        True if the redaction is allowed only when the target event was sent
        by the same sender. For v1-format rooms this is returned when the
        redaction's and the target's event IDs share a domain. For other
        formats the event is flagged with ``recheck_redaction`` instead.

    Raises:
        RuntimeError: if the room version is not recognised.
        AuthError: if the sender is definitely not allowed to redact the
            target event.
    """
    sender_level = get_user_power_level(event.user_id, auth_events)
    required_level = _get_named_level(auth_events, "redact", 50)
    if sender_level >= required_level:
        return False

    known_version = KNOWN_ROOM_VERSIONS.get(room_version)
    if not known_version:
        raise RuntimeError("Unrecognized room version %r" % (room_version,))

    if known_version.event_format != EventFormatVersions.V1:
        # Event-ID domains are not compared for these formats; the event is
        # flagged so the redaction can be rechecked later.
        event.internal_metadata.recheck_redaction = True
        return True

    redacter_domain = get_domain_from_id(event.event_id)
    redactee_domain = get_domain_from_id(event.redacts)
    if redacter_domain == redactee_domain:
        return True

    raise AuthError(403, "You don't have permission to redact events")
Beispiel #6
0
        async def send_request(
                destination: str) -> Tuple[str, EventBase, RoomVersion]:
            """Fetch a membership event template from ``destination`` and
            build a local event from it.

            Uses ``room_id``, ``user_id``, ``membership``, ``params`` and
            ``content`` from the enclosing scope.

            Returns:
                ``(destination, event, room_version)``

            Raises:
                UnsupportedRoomVersionError: if the remote reports a room
                    version that is not in KNOWN_ROOM_VERSIONS.
                InvalidResponseError: if the response's ``event`` field, or
                    that event's ``content`` field, is not a JSON object.
            """
            ret = await self.transport_layer.make_membership_event(
                destination, room_id, user_id, membership, params)

            # Note: If not supplied, the room version may be either v1 or v2,
            # however either way the event format version will be v1.
            room_version_id = ret.get("room_version",
                                      RoomVersions.V1.identifier)
            room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
            if not room_version:
                raise UnsupportedRoomVersionError()

            pdu_dict = ret.get("event", None)
            if not isinstance(pdu_dict, dict):
                raise InvalidResponseError("Bad 'event' field in response")

            logger.debug("Got response to make_%s: %s", membership, pdu_dict)

            # The template comes from a remote server. Validate "content"
            # before merging our own content into it, so a malformed response
            # produces a clear InvalidResponseError rather than a
            # KeyError/AttributeError.
            pdu_content = pdu_dict.get("content")
            if not isinstance(pdu_content, dict):
                raise InvalidResponseError("Bad 'content' field in response")
            pdu_content.update(content)

            # The protoevent received over the JSON wire may not have all
            # the required fields. Lets just gloss over that because
            # there's some we never care about
            pdu_dict.setdefault("prev_state", [])

            ev = builder.create_local_event_from_event_dict(
                self._clock,
                self.hostname,
                self.signing_key,
                room_version=room_version,
                event_dict=pdu_dict,
            )

            return destination, ev, room_version
Beispiel #7
0
    async def on_POST(self, request, room_id):
        """Upgrade ``room_id`` to the room version named in the request body.

        The JSON body must contain ``new_version``, the identifier of the
        room version to upgrade to.

        Returns:
            ``(200, {"replacement_room": <new room id>})``

        Raises:
            SynapseError: 400/BAD_JSON if ``new_version`` is not a string;
                400/UNSUPPORTED_ROOM_VERSION if it is not a known room
                version. Errors from authentication, body parsing and the
                upgrade itself propagate unchanged.
        """
        requester = await self._auth.get_user_by_req(request)

        content = parse_json_object_from_request(request)
        assert_params_in_dict(content, ("new_version", ))

        new_version_id = content["new_version"]
        # Reject non-strings up front. An unhashable value (list/object) would
        # otherwise make the dict lookup below raise TypeError and surface as
        # a 500 rather than a client error.
        if not isinstance(new_version_id, str):
            raise SynapseError(400, "new_version must be a string",
                               Codes.BAD_JSON)

        new_version = KNOWN_ROOM_VERSIONS.get(new_version_id)
        if new_version is None:
            raise SynapseError(
                400,
                "Your homeserver does not support this room version",
                Codes.UNSUPPORTED_ROOM_VERSION,
            )

        try:
            new_room_id = await self._room_creation_handler.upgrade_room(
                requester, room_id, new_version)
        except ShadowBanError:
            # Don't reveal the ban: respond with a random room ID instead of
            # an error.
            new_room_id = stringutils.random_string(18)

        ret = {"replacement_room": new_room_id}

        return 200, ret
Beispiel #8
0
    def read_config(self, config, **kwargs):
        """Read the core server settings from the parsed homeserver config.

        Args:
            config (dict): the parsed homeserver configuration.
            **kwargs: not used by this method.

        Raises:
            ConfigError: if the configuration is invalid. Examples include an
                unparseable server_name, conflicting public-rooms options, an
                unknown default_room_version, an invalid
                federation_ip_range_blacklist, a listener without an integer
                'port', or no_tls combined with bind_port.
        """
        self.server_name = config["server_name"]
        self.server_context = config.get("server_context", None)

        try:
            parse_and_validate_server_name(self.server_name)
        except ValueError as e:
            raise ConfigError(str(e))

        self.pid_file = self.abspath(config.get("pid_file"))
        self.web_client_location = config.get("web_client_location", None)
        self.soft_file_limit = config.get("soft_file_limit", 0)
        self.daemonize = config.get("daemonize")
        self.print_pidfile = config.get("print_pidfile")
        self.user_agent_suffix = config.get("user_agent_suffix")
        self.use_frozen_dicts = config.get("use_frozen_dicts", False)
        self.public_baseurl = config.get("public_baseurl")

        # Whether to send federation traffic out in this process. This only
        # applies to some federation traffic, and so shouldn't be used to
        # "disable" federation
        self.send_federation = config.get("send_federation", True)

        # Whether to enable user presence.
        self.use_presence = config.get("use_presence", True)

        # Whether to update the user directory or not. This should be set to
        # false only if we are updating the user directory in a worker
        self.update_user_directory = config.get("update_user_directory", True)

        # whether to enable the media repository endpoints. This should be set
        # to false if the media repository is running as a separate endpoint;
        # doing so ensures that we will not run cache cleanup jobs on the
        # master, potentially causing inconsistency.
        self.enable_media_repo = config.get("enable_media_repo", True)

        # Whether to require authentication to retrieve profile data (avatars,
        # display names) of other users through the client API.
        self.require_auth_for_profile_requests = config.get(
            "require_auth_for_profile_requests", False)

        if "restrict_public_rooms_to_local_users" in config and (
                "allow_public_rooms_without_auth" in config
                or "allow_public_rooms_over_federation" in config):
            raise ConfigError(
                "Can't use 'restrict_public_rooms_to_local_users' if"
                " 'allow_public_rooms_without_auth' and/or"
                " 'allow_public_rooms_over_federation' is set.")

        # Check if the legacy "restrict_public_rooms_to_local_users" flag is set. This
        # flag is now obsolete but we need to check it for backward-compatibility.
        if config.get("restrict_public_rooms_to_local_users", False):
            self.allow_public_rooms_without_auth = False
            self.allow_public_rooms_over_federation = False
        else:
            # If set to 'False', requires authentication to access the server's public
            # rooms directory through the client API. Defaults to 'True'.
            self.allow_public_rooms_without_auth = config.get(
                "allow_public_rooms_without_auth", True)
            # If set to 'False', forbids any other homeserver to fetch the server's public
            # rooms directory via federation. Defaults to 'True'.
            self.allow_public_rooms_over_federation = config.get(
                "allow_public_rooms_over_federation", True)

        default_room_version = config.get("default_room_version",
                                          DEFAULT_ROOM_VERSION)

        # Ensure room version is a str
        default_room_version = str(default_room_version)

        if default_room_version not in KNOWN_ROOM_VERSIONS:
            raise ConfigError(
                "Unknown default_room_version: %s, known room versions: %s" %
                (default_room_version, list(KNOWN_ROOM_VERSIONS.keys())))

        # Get the actual room version object rather than just the identifier
        self.default_room_version = KNOWN_ROOM_VERSIONS[default_room_version]

        # whether to enable search. If disabled, new entries will not be inserted
        # into the search tables and they will not be indexed. Users will receive
        # errors when attempting to search for messages.
        self.enable_search = config.get("enable_search", True)

        self.filter_timeline_limit = config.get("filter_timeline_limit", -1)

        # Whether we should block invites sent to users on this server
        # (other than those sent by local server admins)
        self.block_non_admin_invites = config.get("block_non_admin_invites",
                                                  False)

        # Whether to enable experimental MSC1849 (aka relations) support
        self.experimental_msc1849_support_enabled = config.get(
            "experimental_msc1849_support_enabled", True)

        # Options to control access by tracking MAU
        self.limit_usage_by_mau = config.get("limit_usage_by_mau", False)
        self.max_mau_value = 0
        if self.limit_usage_by_mau:
            self.max_mau_value = config.get("max_mau_value", 0)
        self.mau_stats_only = config.get("mau_stats_only", False)

        self.mau_limits_reserved_threepids = config.get(
            "mau_limit_reserved_threepids", [])

        self.mau_trial_days = config.get("mau_trial_days", 0)

        # Options to disable HS
        self.hs_disabled = config.get("hs_disabled", False)
        self.hs_disabled_message = config.get("hs_disabled_message", "")
        self.hs_disabled_limit_type = config.get("hs_disabled_limit_type", "")

        # Admin uri to direct users at should their instance become blocked
        # due to resource constraints
        self.admin_contact = config.get("admin_contact", None)

        # FIXME: federation_domain_whitelist needs sytests
        self.federation_domain_whitelist = None
        federation_domain_whitelist = config.get("federation_domain_whitelist",
                                                 None)

        if federation_domain_whitelist is not None:
            # turn the whitelist into a hash for speed of lookup
            self.federation_domain_whitelist = {
                domain: True for domain in federation_domain_whitelist
            }

        self.federation_ip_range_blacklist = config.get(
            "federation_ip_range_blacklist", [])

        # Attempt to create an IPSet from the given ranges
        try:
            self.federation_ip_range_blacklist = IPSet(
                self.federation_ip_range_blacklist)

            # Always blacklist 0.0.0.0, ::
            self.federation_ip_range_blacklist.update(["0.0.0.0", "::"])
        except Exception as e:
            raise ConfigError("Invalid range(s) provided in "
                              "federation_ip_range_blacklist: %s" % e)

        if self.public_baseurl is not None:
            if self.public_baseurl[-1] != "/":
                self.public_baseurl += "/"
        self.start_pushers = config.get("start_pushers", True)

        # (undocumented) option for torturing the worker-mode replication a bit,
        # for testing. The value defines the number of milliseconds to pause before
        # sending out any replication updates.
        self.replication_torture_level = config.get(
            "replication_torture_level")

        # Whether to require a user to be in the room to add an alias to it.
        # Defaults to True.
        self.require_membership_for_aliases = config.get(
            "require_membership_for_aliases", True)

        # Whether to allow per-room membership profiles through the send of membership
        # events with profile information that differ from the target's global profile.
        self.allow_per_room_profiles = config.get("allow_per_room_profiles",
                                                  True)

        self.listeners = []
        for listener in config.get("listeners", []):
            if not isinstance(listener.get("port", None), int):
                raise ConfigError(
                    "Listener configuration is lacking a valid 'port' option")

            if listener.setdefault("tls", False):
                # no_tls is not really supported any more, but let's grandfather it in
                # here.
                if config.get("no_tls", False):
                    # The port was validated as an int above, so %i is safe.
                    logger.info(
                        "Ignoring TLS-enabled listener on port %i due to no_tls",
                        listener["port"],
                    )
                    continue

            bind_address = listener.pop("bind_address", None)
            bind_addresses = listener.setdefault("bind_addresses", [])

            # if bind_address was specified, add it to the list of addresses
            if bind_address:
                bind_addresses.append(bind_address)

            # if we still have an empty list of addresses, use the default list
            if not bind_addresses:
                if listener["type"] == "metrics":
                    # the metrics listener doesn't support IPv6
                    bind_addresses.append("0.0.0.0")
                else:
                    bind_addresses.extend(DEFAULT_BIND_ADDRESSES)

            self.listeners.append(listener)

        if not self.web_client_location:
            _warn_if_webclient_configured(self.listeners)

        self.gc_thresholds = read_gc_thresholds(
            config.get("gc_thresholds", None))

        # Legacy bind_port option: replaces any configured listeners with a
        # TLS HTTP listener and, optionally, a non-TLS one.
        bind_port = config.get("bind_port")
        if bind_port:
            if config.get("no_tls", False):
                raise ConfigError("no_tls is incompatible with bind_port")

            self.listeners = []
            bind_host = config.get("bind_host", "")
            gzip_responses = config.get("gzip_responses", True)

            self.listeners.append({
                "port": bind_port,
                "bind_addresses": [bind_host],
                "tls": True,
                "type": "http",
                "resources": [
                    {"names": ["client"], "compress": gzip_responses},
                    {"names": ["federation"], "compress": False},
                ],
            })

            unsecure_port = config.get("unsecure_port", bind_port - 400)
            if unsecure_port:
                self.listeners.append({
                    "port": unsecure_port,
                    "bind_addresses": [bind_host],
                    "tls": False,
                    "type": "http",
                    "resources": [
                        {"names": ["client"], "compress": gzip_responses},
                        {"names": ["federation"], "compress": False},
                    ],
                })

        manhole = config.get("manhole")
        if manhole:
            self.listeners.append({
                "port": manhole,
                "bind_addresses": ["127.0.0.1"],
                "type": "manhole",
                "tls": False,
            })

        metrics_port = config.get("metrics_port")
        if metrics_port:
            # Logger.warn is a deprecated alias of Logger.warning.
            logger.warning((
                "The metrics_port configuration option is deprecated in Synapse 0.31 "
                "in favour of a listener. Please see "
                "http://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.rst"
                " on how to configure the new listener."))

            self.listeners.append({
                "port": metrics_port,
                "bind_addresses": [
                    config.get("metrics_bind_host", "127.0.0.1")
                ],
                "tls": False,
                "type": "http",
                "resources": [{"names": ["metrics"], "compress": False}],
            })

        _check_resource_config(self.listeners)

        # An experimental option to try and periodically clean up extremities
        # by sending dummy events.
        self.cleanup_extremities_with_dummy_events = config.get(
            "cleanup_extremities_with_dummy_events", False)
Beispiel #9
0
    def _do_send_invite(self, destination, pdu, room_version):
        """Actually sends the invite, first trying v2 API and falling back to
        v1 API if necessary.

        Args:
            destination (str): Target server
            pdu (FrozenEvent): the invite event to send
            room_version (str): identifier of the room's version

        Returns:
            dict: The event as a dict as returned by the remote server

        Raises:
            The error produced by ``e.to_synapse_error()`` if the v2 request
            fails with a 403, or with a 400/404 whose errcode is not
            ``Codes.UNKNOWN``.
            SynapseError: 400/UNSUPPORTED_ROOM_VERSION if the remote appears
                not to support the v2 API and the room's event format is not
                (or cannot be confirmed to be) v1.
            HttpResponseException: for any other HTTP error from the v2
                request.
        """
        time_now = self._clock.time_msec()

        try:
            content = yield self.transport_layer.send_invite_v2(
                destination=destination,
                room_id=pdu.room_id,
                event_id=pdu.event_id,
                content={
                    "event": pdu.get_pdu_json(time_now),
                    "room_version": room_version,
                    "invite_room_state": pdu.unsigned.get("invite_room_state", []),
                },
            )
            return content
        except HttpResponseException as e:
            if e.code in [400, 404]:
                err = e.to_synapse_error()

                # If we receive an error response that isn't a generic error, we
                # assume that the remote understands the v2 invite API and this
                # is a legitimate error.
                if err.errcode != Codes.UNKNOWN:
                    raise err

                # Otherwise, we assume that the remote server doesn't understand
                # the v2 invite API. That's ok provided the room uses old-style event
                # IDs. If we don't recognise the room version we can't confirm
                # that, so refuse the fallback rather than failing with an
                # AttributeError on None.
                v = KNOWN_ROOM_VERSIONS.get(room_version)
                if v is None or v.event_format != EventFormatVersions.V1:
                    raise SynapseError(
                        400,
                        "User's homeserver does not support this room version",
                        Codes.UNSUPPORTED_ROOM_VERSION,
                    )
            elif e.code == 403:
                raise e.to_synapse_error()
            else:
                raise

        # Didn't work, try v1 API.
        # Note the v1 API returns a tuple of `(200, content)`

        _, content = yield self.transport_layer.send_invite_v1(
            destination=destination,
            room_id=pdu.room_id,
            event_id=pdu.event_id,
            content=pdu.get_pdu_json(time_now),
        )
        return content
Beispiel #10
0
    def create_room(self,
                    requester,
                    config,
                    ratelimit=True,
                    creator_join_profile=None):
        """Create a new room on behalf of ``requester``.

        Args:
            requester (synapse.types.Requester):
                The user who requested the room creation.
            config (dict) : A dict of configuration options (the room creation
                request). Keys read here include ``room_version``,
                ``room_alias_name``, ``invite``, ``invite_3pid``,
                ``power_level_content_override``, ``visibility``, ``preset``,
                ``initial_state``, ``creation_content``, ``name``, ``topic``
                and ``is_direct``. NOTE: if ``creation_content`` is supplied,
                that dict is modified in place (its ``room_version`` is set).
            ratelimit (bool): set to False to disable the rate limiter

            creator_join_profile (dict|None):
                Set to override the displayname and avatar for the creating
                user in this room. If unset, displayname and avatar will be
                derived from the user's profile. If set, should contain the
                values to go in the body of the 'join' event (typically
                `avatar_url` and/or `displayname`).

        Returns:
            Deferred[dict]:
                a dict containing the keys `room_id` and, if an alias was
                requested, `room_alias`.
        Raises:
            SynapseError if the room ID couldn't be stored, or something went
            horribly wrong. It is also raised with 403 if a non-admin is
            refused by the spam checker. It is raised with 400 for: a
            non-string or unsupported room_version; an alias containing
            whitespace or already taken; an invalid invitee user ID; or a
            power_level_content_override whose 'users' omits the creator.
            ResourceLimitError if server is blocked to some resource being
            exceeded
        """
        user_id = requester.user.to_string()

        yield self.auth.check_auth_blocking(user_id)

        if (self._server_notices_mxid is not None
                and requester.user.to_string() == self._server_notices_mxid):
            # allow the server notices mxid to create rooms
            is_requester_admin = True
        else:
            is_requester_admin = yield self.auth.is_server_admin(
                requester.user)

        # Check whether the third party rules allows/changes the room create
        # request.
        yield self.third_party_event_rules.on_create_room(
            requester, config, is_requester_admin=is_requester_admin)

        # Server admins bypass the spam checker.
        if not is_requester_admin and not self.spam_checker.user_may_create_room(
                user_id):
            raise SynapseError(403, "You are not permitted to create rooms")

        if ratelimit:
            yield self.ratelimit(requester)

        # Fall back to the server's configured default room version.
        room_version_id = config.get(
            "room_version", self.config.default_room_version.identifier)

        if not isinstance(room_version_id, string_types):
            raise SynapseError(400, "room_version must be a string",
                               Codes.BAD_JSON)

        room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
        if room_version is None:
            raise SynapseError(
                400,
                "Your homeserver does not support this room version",
                Codes.UNSUPPORTED_ROOM_VERSION,
            )

        # Validate the requested alias (and that it is free) before the room
        # is created.
        if "room_alias_name" in config:
            for wchar in string.whitespace:
                if wchar in config["room_alias_name"]:
                    raise SynapseError(400, "Invalid characters in room alias")

            room_alias = RoomAlias(config["room_alias_name"], self.hs.hostname)
            mapping = yield self.store.get_association_from_room_alias(
                room_alias)

            if mapping:
                raise SynapseError(400, "Room alias already taken",
                                   Codes.ROOM_IN_USE)
        else:
            room_alias = None

        # Reject malformed invitee user IDs up front, before the room exists.
        invite_list = config.get("invite", [])
        for i in invite_list:
            try:
                uid = UserID.from_string(i)
                parse_and_validate_server_name(uid.domain)
            except Exception:
                raise SynapseError(400, "Invalid user_id: %s" % (i, ))

        yield self.event_creation_handler.assert_accepted_privacy_policy(
            requester)

        # If the override supplies a 'users' map it must include the creator
        # (presumably so the creator is not left without a power level).
        power_level_content_override = config.get(
            "power_level_content_override")
        if (power_level_content_override
                and "users" in power_level_content_override
                and user_id not in power_level_content_override["users"]):
            raise SynapseError(
                400,
                "Not a valid power_level_content_override: 'users' did not contain %s"
                % (user_id, ),
            )

        invite_3pid_list = config.get("invite_3pid", [])

        visibility = config.get("visibility", None)
        is_public = visibility == "public"

        room_id = yield self._generate_room_id(
            creator_id=user_id,
            is_public=is_public,
            room_version=room_version,
        )

        directory_handler = self.hs.get_handlers().directory_handler
        if room_alias:
            yield directory_handler.create_association(
                requester=requester,
                room_id=room_id,
                room_alias=room_alias,
                servers=[self.hs.hostname],
                send_event=False,
                check_membership=False,
            )

        # Without an explicit preset, only visibility == "private" selects
        # PRIVATE_CHAT; any other value (including unset) selects PUBLIC_CHAT.
        preset_config = config.get(
            "preset",
            RoomCreationPreset.PRIVATE_CHAT
            if visibility == "private" else RoomCreationPreset.PUBLIC_CHAT,
        )

        raw_initial_state = config.get("initial_state", [])

        # Keyed by (event type, state_key); a later entry with the same key
        # replaces an earlier one.
        initial_state = OrderedDict()
        for val in raw_initial_state:
            initial_state[(val["type"], val.get("state_key",
                                                ""))] = val["content"]

        creation_content = config.get("creation_content", {})

        # override any attempt to set room versions via the creation_content
        # (NOTE: when supplied, this mutates config["creation_content"] itself)
        creation_content["room_version"] = room_version.identifier

        yield self._send_events_for_new_room(
            requester,
            room_id,
            preset_config=preset_config,
            invite_list=invite_list,
            initial_state=initial_state,
            creation_content=creation_content,
            room_alias=room_alias,
            power_level_content_override=power_level_content_override,
            creator_join_profile=creator_join_profile,
        )

        if "name" in config:
            name = config["name"]
            yield self.event_creation_handler.create_and_send_nonmember_event(
                requester,
                {
                    "type": EventTypes.Name,
                    "room_id": room_id,
                    "sender": user_id,
                    "state_key": "",
                    "content": {
                        "name": name
                    },
                },
                ratelimit=False,
            )

        if "topic" in config:
            topic = config["topic"]
            yield self.event_creation_handler.create_and_send_nonmember_event(
                requester,
                {
                    "type": EventTypes.Topic,
                    "room_id": room_id,
                    "sender": user_id,
                    "state_key": "",
                    "content": {
                        "topic": topic
                    },
                },
                ratelimit=False,
            )

        # Every invite carries the same top-level is_direct flag (if truthy).
        for invitee in invite_list:
            content = {}
            is_direct = config.get("is_direct", None)
            if is_direct:
                content["is_direct"] = is_direct

            yield self.room_member_handler.update_membership(
                requester,
                UserID.from_string(invitee),
                room_id,
                "invite",
                ratelimit=False,
                content=content,
            )

        for invite_3pid in invite_3pid_list:
            id_server = invite_3pid["id_server"]
            id_access_token = invite_3pid.get("id_access_token")  # optional
            address = invite_3pid["address"]
            medium = invite_3pid["medium"]
            yield self.hs.get_room_member_handler().do_3pid_invite(
                room_id,
                requester.user,
                medium,
                address,
                id_server,
                requester,
                txn_id=None,
                id_access_token=id_access_token,
            )

        result = {"room_id": room_id}

        # The alias association was created with send_event=False above; the
        # alias update event is sent only now, after the room is populated.
        if room_alias:
            result["room_alias"] = room_alias.to_string()
            yield directory_handler.send_room_alias_update_event(
                requester, room_id)

        return result
Beispiel #11
0
    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
        """Read and validate the server-level options from ``config``.

        Every option is stored as an attribute on ``self``. Options that are
        absent from ``config`` fall back to the defaults given below.

        Args:
            config: the parsed configuration dictionary.
            **kwargs: accepted for interface compatibility; not used here.

        Raises:
            ConfigError: if an option has an invalid value, or if mutually
                exclusive options are combined.
        """
        self.server_name = config["server_name"]
        self.server_context = config.get("server_context", None)

        try:
            parse_and_validate_server_name(self.server_name)
        except ValueError as e:
            raise ConfigError(str(e), ("server_name",)) from e

        self.pid_file = self.abspath(config.get("pid_file"))
        self.soft_file_limit = config.get("soft_file_limit", 0)
        self.daemonize = bool(config.get("daemonize"))
        self.print_pidfile = bool(config.get("print_pidfile"))
        self.user_agent_suffix = config.get("user_agent_suffix")
        self.use_frozen_dicts = config.get("use_frozen_dicts", False)
        self.serve_server_wellknown = config.get("serve_server_wellknown", False)

        # Whether we should serve a "client well-known":
        #  (a) at .well-known/matrix/client on our client HTTP listener
        #  (b) in the response to /login
        #
        # ... which together help ensure that clients use our public_baseurl instead of
        # whatever they were told by the user.
        #
        # For the sake of backwards compatibility with existing installations, this is
        # True if public_baseurl is specified explicitly, and otherwise False. (The
        # reasoning here is that we have no way of knowing that the default
        # public_baseurl is actually correct for existing installations - many things
        # will not work correctly, but that's (probably?) better than sending clients
        # to a completely broken URL.)
        self.serve_client_wellknown = False

        public_baseurl = config.get("public_baseurl")
        if public_baseurl is None:
            public_baseurl = f"https://{self.server_name}/"
            logger.info("Using default public_baseurl %s", public_baseurl)
        else:
            self.serve_client_wellknown = True
            # endswith() rather than indexing: an empty string must reach the
            # scheme check below (ConfigError) instead of raising IndexError.
            if not public_baseurl.endswith("/"):
                public_baseurl += "/"
        self.public_baseurl = public_baseurl

        # check that public_baseurl is valid
        try:
            splits = urllib.parse.urlsplit(self.public_baseurl)
        except Exception as e:
            raise ConfigError(f"Unable to parse URL: {e}", ("public_baseurl",)) from e
        if splits.scheme not in ("https", "http"):
            raise ConfigError(
                f"Invalid scheme '{splits.scheme}': only https and http are supported",
                ("public_baseurl",),
            )
        if splits.query or splits.fragment:
            raise ConfigError(
                "public_baseurl cannot contain query parameters or a #-fragment",
                ("public_baseurl",),
            )

        self.extra_well_known_client_content = config.get(
            "extra_well_known_client_content", {}
        )

        # The error messages name the option exactly as it appears in the
        # config file, so operators can find it.
        if not isinstance(self.extra_well_known_client_content, dict):
            raise ConfigError(
                "extra_well_known_client_content must be a dictionary of key-value pairs",
                ("extra_well_known_client_content",),
            )

        if "m.homeserver" in self.extra_well_known_client_content:
            raise ConfigError(
                "m.homeserver is not supported in extra_well_known_client_content, "
                "use public_baseurl in base config instead.",
                ("extra_well_known_client_content",),
            )
        if "m.identity_server" in self.extra_well_known_client_content:
            raise ConfigError(
                "m.identity_server is not supported in extra_well_known_client_content, "
                "use default_identity_server in base config instead.",
                ("extra_well_known_client_content",),
            )

        # Whether to enable user presence.
        # The newer `presence.enabled` option takes precedence over the legacy
        # top-level `use_presence` option.
        presence_config = config.get("presence") or {}
        self.use_presence = presence_config.get("enabled")
        if self.use_presence is None:
            self.use_presence = config.get("use_presence", True)

        # Custom presence router module
        # This is the legacy way of configuring it (the config should now be put in the modules section)
        self.presence_router_module_class = None
        self.presence_router_config = None
        presence_router_config = presence_config.get("presence_router")
        if presence_router_config:
            (
                self.presence_router_module_class,
                self.presence_router_config,
            ) = load_module(presence_router_config, ("presence", "presence_router"))

        # whether to enable the media repository endpoints. This should be set
        # to false if the media repository is running as a separate endpoint;
        # doing so ensures that we will not run cache cleanup jobs on the
        # master, potentially causing inconsistency.
        self.enable_media_repo = config.get("enable_media_repo", True)

        # Whether to require authentication to retrieve profile data (avatars,
        # display names) of other users through the client API.
        self.require_auth_for_profile_requests = config.get(
            "require_auth_for_profile_requests", False
        )

        # Whether to require sharing a room with a user to retrieve their
        # profile data
        self.limit_profile_requests_to_users_who_share_rooms = config.get(
            "limit_profile_requests_to_users_who_share_rooms",
            False,
        )

        # Whether to retrieve and display profile data for a user when they
        # are invited to a room
        self.include_profile_data_on_invite = config.get(
            "include_profile_data_on_invite", True
        )

        if "restrict_public_rooms_to_local_users" in config and (
            "allow_public_rooms_without_auth" in config
            or "allow_public_rooms_over_federation" in config
        ):
            raise ConfigError(
                "Can't use 'restrict_public_rooms_to_local_users' if"
                " 'allow_public_rooms_without_auth' and/or"
                " 'allow_public_rooms_over_federation' is set."
            )

        # Check if the legacy "restrict_public_rooms_to_local_users" flag is set. This
        # flag is now obsolete but we need to check it for backward-compatibility.
        if config.get("restrict_public_rooms_to_local_users", False):
            self.allow_public_rooms_without_auth = False
            self.allow_public_rooms_over_federation = False
        else:
            # If set to 'true', removes the need for authentication to access the server's
            # public rooms directory through the client API, meaning that anyone can
            # query the room directory. Defaults to 'false'.
            self.allow_public_rooms_without_auth = config.get(
                "allow_public_rooms_without_auth", False
            )
            # If set to 'true', allows any other homeserver to fetch the server's public
            # rooms directory via federation. Defaults to 'false'.
            self.allow_public_rooms_over_federation = config.get(
                "allow_public_rooms_over_federation", False
            )

        default_room_version = config.get("default_room_version", DEFAULT_ROOM_VERSION)

        # Ensure room version is a str
        default_room_version = str(default_room_version)

        if default_room_version not in KNOWN_ROOM_VERSIONS:
            raise ConfigError(
                "Unknown default_room_version: %s, known room versions: %s"
                % (default_room_version, list(KNOWN_ROOM_VERSIONS.keys()))
            )

        # Get the actual room version object rather than just the identifier
        self.default_room_version = KNOWN_ROOM_VERSIONS[default_room_version]

        # whether to enable search. If disabled, new entries will not be inserted
        # into the search tables and they will not be indexed. Users will receive
        # errors when attempting to search for messages.
        self.enable_search = config.get("enable_search", True)

        self.filter_timeline_limit = config.get("filter_timeline_limit", 100)

        # Whether we should block invites sent to users on this server
        # (other than those sent by local server admins)
        self.block_non_admin_invites = config.get("block_non_admin_invites", False)

        # Options to control access by tracking MAU
        self.limit_usage_by_mau = config.get("limit_usage_by_mau", False)
        self.max_mau_value = 0
        if self.limit_usage_by_mau:
            self.max_mau_value = config.get("max_mau_value", 0)
        self.mau_stats_only = config.get("mau_stats_only", False)

        self.mau_limits_reserved_threepids = config.get(
            "mau_limit_reserved_threepids", []
        )

        self.mau_trial_days = config.get("mau_trial_days", 0)
        self.mau_appservice_trial_days = config.get("mau_appservice_trial_days", {})
        self.mau_limit_alerting = config.get("mau_limit_alerting", True)

        # How long to keep redacted events in the database in unredacted form
        # before redacting them.
        redaction_retention_period = config.get("redaction_retention_period", "7d")
        if redaction_retention_period is not None:
            self.redaction_retention_period: Optional[int] = self.parse_duration(
                redaction_retention_period
            )
        else:
            self.redaction_retention_period = None

        # How long to keep entries in the `users_ips` table.
        user_ips_max_age = config.get("user_ips_max_age", "28d")
        if user_ips_max_age is not None:
            self.user_ips_max_age: Optional[int] = self.parse_duration(user_ips_max_age)
        else:
            self.user_ips_max_age = None

        # Options to disable HS
        self.hs_disabled = config.get("hs_disabled", False)
        self.hs_disabled_message = config.get("hs_disabled_message", "")

        # Admin uri to direct users at should their instance become blocked
        # due to resource constraints
        self.admin_contact = config.get("admin_contact", None)

        ip_range_blacklist = config.get(
            "ip_range_blacklist", DEFAULT_IP_RANGE_BLACKLIST
        )

        # Attempt to create an IPSet from the given ranges

        # Always blacklist 0.0.0.0, ::
        self.ip_range_blacklist = generate_ip_set(
            ip_range_blacklist, ["0.0.0.0", "::"], config_path=("ip_range_blacklist",)
        )

        self.ip_range_whitelist = generate_ip_set(
            config.get("ip_range_whitelist", ()), config_path=("ip_range_whitelist",)
        )
        # The federation_ip_range_blacklist is used for backwards-compatibility
        # and only applies to federation and identity servers.
        if "federation_ip_range_blacklist" in config:
            # Always blacklist 0.0.0.0, ::
            self.federation_ip_range_blacklist = generate_ip_set(
                config["federation_ip_range_blacklist"],
                ["0.0.0.0", "::"],
                config_path=("federation_ip_range_blacklist",),
            )
            # 'federation_ip_range_whitelist' was never a supported configuration option.
            self.federation_ip_range_whitelist = None
        else:
            # No backwards-compatibility required, as federation_ip_range_blacklist
            # is not given. Default to ip_range_blacklist and ip_range_whitelist.
            self.federation_ip_range_blacklist = self.ip_range_blacklist
            self.federation_ip_range_whitelist = self.ip_range_whitelist

        # (undocumented) option for torturing the worker-mode replication a bit,
        # for testing. The value defines the number of milliseconds to pause before
        # sending out any replication updates.
        self.replication_torture_level = config.get("replication_torture_level")

        # Whether to require a user to be in the room to add an alias to it.
        # Defaults to True.
        self.require_membership_for_aliases = config.get(
            "require_membership_for_aliases", True
        )

        # Whether to allow per-room membership profiles through the send of membership
        # events with profile information that differ from the target's global profile.
        self.allow_per_room_profiles = config.get("allow_per_room_profiles", True)

        # The maximum size an avatar can have, in bytes.
        self.max_avatar_size = config.get("max_avatar_size")
        if self.max_avatar_size is not None:
            self.max_avatar_size = self.parse_size(self.max_avatar_size)

        # The MIME types allowed for an avatar.
        self.allowed_avatar_mimetypes = config.get("allowed_avatar_mimetypes")
        if self.allowed_avatar_mimetypes and not isinstance(
            self.allowed_avatar_mimetypes,
            list,
        ):
            raise ConfigError("allowed_avatar_mimetypes must be a list")

        self.listeners = [parse_listener_def(x) for x in config.get("listeners", [])]

        # no_tls is not really supported any more, but let's grandfather it in
        # here.
        if config.get("no_tls", False):
            l2 = []
            for listener in self.listeners:
                if listener.tls:
                    logger.info(
                        "Ignoring TLS-enabled listener on port %i due to no_tls",
                        listener.port,
                    )
                else:
                    l2.append(listener)
            self.listeners = l2

        self.web_client_location = config.get("web_client_location", None)
        # Non-HTTP(S) web client location is not supported.
        if self.web_client_location and not self.web_client_location.startswith(
            ("http://", "https://")
        ):
            raise ConfigError("web_client_location must point to a HTTP(S) URL.")

        self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None))
        self.gc_seconds = self.read_gc_intervals(config.get("gc_min_interval", None))

        self.limit_remote_rooms = LimitRemoteRoomsConfig(
            **(config.get("limit_remote_rooms") or {})
        )

        # Legacy `bind_port` option: replaces any configured listeners with a
        # TLS listener on bind_port and, unless disabled, a plain HTTP listener
        # on `unsecure_port` (default bind_port - 400).
        bind_port = config.get("bind_port")
        if bind_port:
            if config.get("no_tls", False):
                raise ConfigError("no_tls is incompatible with bind_port")

            self.listeners = []
            bind_host = config.get("bind_host", "")
            gzip_responses = config.get("gzip_responses", True)

            http_options = HttpListenerConfig(
                resources=[
                    HttpResourceConfig(names=["client"], compress=gzip_responses),
                    HttpResourceConfig(names=["federation"]),
                ],
            )

            self.listeners.append(
                ListenerConfig(
                    port=bind_port,
                    bind_addresses=[bind_host],
                    tls=True,
                    type="http",
                    http_options=http_options,
                )
            )

            unsecure_port = config.get("unsecure_port", bind_port - 400)
            if unsecure_port:
                self.listeners.append(
                    ListenerConfig(
                        port=unsecure_port,
                        bind_addresses=[bind_host],
                        tls=False,
                        type="http",
                        http_options=http_options,
                    )
                )

        manhole = config.get("manhole")
        if manhole:
            self.listeners.append(
                ListenerConfig(
                    port=manhole,
                    bind_addresses=["127.0.0.1"],
                    type="manhole",
                )
            )

        manhole_settings = config.get("manhole_settings") or {}
        validate_config(
            _MANHOLE_SETTINGS_SCHEMA, manhole_settings, ("manhole_settings",)
        )

        manhole_username = manhole_settings.get("username", "matrix")
        manhole_password = manhole_settings.get("password", "rabbithole")
        manhole_priv_key_path = manhole_settings.get("ssh_priv_key_path")
        manhole_pub_key_path = manhole_settings.get("ssh_pub_key_path")

        manhole_priv_key = None
        if manhole_priv_key_path is not None:
            try:
                manhole_priv_key = Key.fromFile(manhole_priv_key_path)
            except Exception as e:
                raise ConfigError(
                    f"Failed to read manhole private key file {manhole_priv_key_path}"
                ) from e

        manhole_pub_key = None
        if manhole_pub_key_path is not None:
            try:
                manhole_pub_key = Key.fromFile(manhole_pub_key_path)
            except Exception as e:
                raise ConfigError(
                    f"Failed to read manhole public key file {manhole_pub_key_path}"
                ) from e

        self.manhole_settings = ManholeConfig(
            username=manhole_username,
            password=manhole_password,
            priv_key=manhole_priv_key,
            pub_key=manhole_pub_key,
        )

        metrics_port = config.get("metrics_port")
        if metrics_port:
            logger.warning(METRICS_PORT_WARNING)

            self.listeners.append(
                ListenerConfig(
                    port=metrics_port,
                    bind_addresses=[config.get("metrics_bind_host", "127.0.0.1")],
                    type="http",
                    http_options=HttpListenerConfig(
                        resources=[HttpResourceConfig(names=["metrics"])]
                    ),
                )
            )

        self.cleanup_extremities_with_dummy_events = config.get(
            "cleanup_extremities_with_dummy_events", True
        )

        # The number of forward extremities in a room needed to send a dummy event.
        self.dummy_events_threshold = config.get("dummy_events_threshold", 10)

        self.enable_ephemeral_messages = config.get("enable_ephemeral_messages", False)

        # Inhibits the /requestToken endpoints from returning an error that might leak
        # information about whether an e-mail address is in use or not on this
        # homeserver, and instead return a 200 with a fake sid if this kind of error is
        # met, without sending anything.
        # This is a compromise between sending an email, which could be a spam vector,
        # and letting the client know which email address is bound to an account and
        # which one isn't.
        self.request_token_inhibit_3pid_errors = config.get(
            "request_token_inhibit_3pid_errors",
            False,
        )

        # Whitelist of domain names that given next_link parameters must have
        next_link_domain_whitelist: Optional[List[str]] = config.get(
            "next_link_domain_whitelist"
        )

        self.next_link_domain_whitelist: Optional[Set[str]] = None
        if next_link_domain_whitelist is not None:
            if not isinstance(next_link_domain_whitelist, list):
                raise ConfigError("'next_link_domain_whitelist' must be a list")

            # Turn the list into a set to improve lookup speed.
            self.next_link_domain_whitelist = set(next_link_domain_whitelist)

        templates_config = config.get("templates") or {}
        if not isinstance(templates_config, dict):
            raise ConfigError("The 'templates' section must be a dictionary")

        self.custom_template_directory: Optional[str] = templates_config.get(
            "custom_template_directory"
        )
        if self.custom_template_directory is not None and not isinstance(
            self.custom_template_directory, str
        ):
            raise ConfigError("'custom_template_directory' must be a string")

        self.use_account_validity_in_account_status: bool = (
            config.get("use_account_validity_in_account_status") or False
        )

        self.rooms_to_exclude_from_sync: List[str] = (
            config.get("exclude_rooms_from_sync") or []
        )

        # Falsy values (e.g. an empty string) are treated as "not set".
        delete_stale_devices_after: Optional[str] = (
            config.get("delete_stale_devices_after") or None
        )

        if delete_stale_devices_after is not None:
            self.delete_stale_devices_after: Optional[int] = self.parse_duration(
                delete_stale_devices_after
            )
        else:
            self.delete_stale_devices_after = None
Beispiel #12
0
    def read_config(self, config, **kwargs):
        self.server_name = config["server_name"]
        self.server_context = config.get("server_context", None)

        try:
            parse_and_validate_server_name(self.server_name)
        except ValueError as e:
            raise ConfigError(str(e))

        self.pid_file = self.abspath(config.get("pid_file"))
        self.web_client_location = config.get("web_client_location", None)
        self.soft_file_limit = config.get("soft_file_limit", 0)
        self.daemonize = config.get("daemonize")
        self.print_pidfile = config.get("print_pidfile")
        self.user_agent_suffix = config.get("user_agent_suffix")
        self.use_frozen_dicts = config.get("use_frozen_dicts", False)
        self.public_baseurl = config.get("public_baseurl")

        # Whether to send federation traffic out in this process. This only
        # applies to some federation traffic, and so shouldn't be used to
        # "disable" federation
        self.send_federation = config.get("send_federation", True)

        # Whether to enable user presence.
        self.use_presence = config.get("use_presence", True)

        # Whether to update the user directory or not. This should be set to
        # false only if we are updating the user directory in a worker
        self.update_user_directory = config.get("update_user_directory", True)

        # whether to enable the media repository endpoints. This should be set
        # to false if the media repository is running as a separate endpoint;
        # doing so ensures that we will not run cache cleanup jobs on the
        # master, potentially causing inconsistency.
        self.enable_media_repo = config.get("enable_media_repo", True)

        # Whether to require authentication to retrieve profile data (avatars,
        # display names) of other users through the client API.
        self.require_auth_for_profile_requests = config.get(
            "require_auth_for_profile_requests", False
        )

        # Whether to require sharing a room with a user to retrieve their
        # profile data
        self.limit_profile_requests_to_users_who_share_rooms = config.get(
            "limit_profile_requests_to_users_who_share_rooms", False,
        )

        if "restrict_public_rooms_to_local_users" in config and (
            "allow_public_rooms_without_auth" in config
            or "allow_public_rooms_over_federation" in config
        ):
            raise ConfigError(
                "Can't use 'restrict_public_rooms_to_local_users' if"
                " 'allow_public_rooms_without_auth' and/or"
                " 'allow_public_rooms_over_federation' is set."
            )

        # Check if the legacy "restrict_public_rooms_to_local_users" flag is set. This
        # flag is now obsolete but we need to check it for backward-compatibility.
        if config.get("restrict_public_rooms_to_local_users", False):
            self.allow_public_rooms_without_auth = False
            self.allow_public_rooms_over_federation = False
        else:
            # If set to 'true', removes the need for authentication to access the server's
            # public rooms directory through the client API, meaning that anyone can
            # query the room directory. Defaults to 'false'.
            self.allow_public_rooms_without_auth = config.get(
                "allow_public_rooms_without_auth", False
            )
            # If set to 'true', allows any other homeserver to fetch the server's public
            # rooms directory via federation. Defaults to 'false'.
            self.allow_public_rooms_over_federation = config.get(
                "allow_public_rooms_over_federation", False
            )

        default_room_version = config.get("default_room_version", DEFAULT_ROOM_VERSION)

        # Ensure room version is a str
        default_room_version = str(default_room_version)

        if default_room_version not in KNOWN_ROOM_VERSIONS:
            raise ConfigError(
                "Unknown default_room_version: %s, known room versions: %s"
                % (default_room_version, list(KNOWN_ROOM_VERSIONS.keys()))
            )

        # Get the actual room version object rather than just the identifier
        self.default_room_version = KNOWN_ROOM_VERSIONS[default_room_version]

        # whether to enable search. If disabled, new entries will not be inserted
        # into the search tables and they will not be indexed. Users will receive
        # errors when attempting to search for messages.
        self.enable_search = config.get("enable_search", True)

        self.filter_timeline_limit = config.get("filter_timeline_limit", -1)

        # Whether we should block invites sent to users on this server
        # (other than those sent by local server admins)
        self.block_non_admin_invites = config.get("block_non_admin_invites", False)

        # Whether to enable experimental MSC1849 (aka relations) support
        self.experimental_msc1849_support_enabled = config.get(
            "experimental_msc1849_support_enabled", True
        )

        # Options to control access by tracking MAU
        self.limit_usage_by_mau = config.get("limit_usage_by_mau", False)
        self.max_mau_value = 0
        if self.limit_usage_by_mau:
            self.max_mau_value = config.get("max_mau_value", 0)
        self.mau_stats_only = config.get("mau_stats_only", False)

        self.mau_limits_reserved_threepids = config.get(
            "mau_limit_reserved_threepids", []
        )

        self.mau_trial_days = config.get("mau_trial_days", 0)
        self.mau_limit_alerting = config.get("mau_limit_alerting", True)

        # How long to keep redacted events in the database in unredacted form
        # before redacting them.
        redaction_retention_period = config.get("redaction_retention_period", "7d")
        if redaction_retention_period is not None:
            self.redaction_retention_period = self.parse_duration(
                redaction_retention_period
            )
        else:
            self.redaction_retention_period = None

        # How long to keep entries in the `users_ips` table.
        user_ips_max_age = config.get("user_ips_max_age", "28d")
        if user_ips_max_age is not None:
            self.user_ips_max_age = self.parse_duration(user_ips_max_age)
        else:
            self.user_ips_max_age = None

        # Options to disable HS
        self.hs_disabled = config.get("hs_disabled", False)
        self.hs_disabled_message = config.get("hs_disabled_message", "")

        # Admin uri to direct users at should their instance become blocked
        # due to resource constraints
        self.admin_contact = config.get("admin_contact", None)

        # FIXME: federation_domain_whitelist needs sytests
        self.federation_domain_whitelist = None  # type: Optional[dict]
        federation_domain_whitelist = config.get("federation_domain_whitelist", None)

        if federation_domain_whitelist is not None:
            # turn the whitelist into a hash for speed of lookup
            self.federation_domain_whitelist = {}

            for domain in federation_domain_whitelist:
                self.federation_domain_whitelist[domain] = True

        self.federation_ip_range_blacklist = config.get(
            "federation_ip_range_blacklist", []
        )

        # Attempt to create an IPSet from the given ranges
        try:
            self.federation_ip_range_blacklist = IPSet(
                self.federation_ip_range_blacklist
            )

            # Always blacklist 0.0.0.0, ::
            self.federation_ip_range_blacklist.update(["0.0.0.0", "::"])
        except Exception as e:
            raise ConfigError(
                "Invalid range(s) provided in federation_ip_range_blacklist: %s" % e
            )

        if self.public_baseurl is not None:
            if self.public_baseurl[-1] != "/":
                self.public_baseurl += "/"
        self.start_pushers = config.get("start_pushers", True)

        # (undocumented) option for torturing the worker-mode replication a bit,
        # for testing. The value defines the number of milliseconds to pause before
        # sending out any replication updates.
        self.replication_torture_level = config.get("replication_torture_level")

        # Whether to require a user to be in the room to add an alias to it.
        # Defaults to True.
        self.require_membership_for_aliases = config.get(
            "require_membership_for_aliases", True
        )

        # Whether to allow per-room membership profiles through the send of membership
        # events with profile information that differ from the target's global profile.
        self.allow_per_room_profiles = config.get("allow_per_room_profiles", True)

        retention_config = config.get("retention")
        if retention_config is None:
            retention_config = {}

        self.retention_enabled = retention_config.get("enabled", False)

        retention_default_policy = retention_config.get("default_policy")

        if retention_default_policy is not None:
            self.retention_default_min_lifetime = retention_default_policy.get(
                "min_lifetime"
            )
            if self.retention_default_min_lifetime is not None:
                self.retention_default_min_lifetime = self.parse_duration(
                    self.retention_default_min_lifetime
                )

            self.retention_default_max_lifetime = retention_default_policy.get(
                "max_lifetime"
            )
            if self.retention_default_max_lifetime is not None:
                self.retention_default_max_lifetime = self.parse_duration(
                    self.retention_default_max_lifetime
                )

            if (
                self.retention_default_min_lifetime is not None
                and self.retention_default_max_lifetime is not None
                and (
                    self.retention_default_min_lifetime
                    > self.retention_default_max_lifetime
                )
            ):
                raise ConfigError(
                    "The default retention policy's 'min_lifetime' can not be greater"
                    " than its 'max_lifetime'"
                )
        else:
            self.retention_default_min_lifetime = None
            self.retention_default_max_lifetime = None

        if self.retention_enabled:
            logger.info(
                "Message retention policies support enabled with the following default"
                " policy: min_lifetime = %s ; max_lifetime = %s",
                self.retention_default_min_lifetime,
                self.retention_default_max_lifetime,
            )

        self.retention_allowed_lifetime_min = retention_config.get(
            "allowed_lifetime_min"
        )
        if self.retention_allowed_lifetime_min is not None:
            self.retention_allowed_lifetime_min = self.parse_duration(
                self.retention_allowed_lifetime_min
            )

        self.retention_allowed_lifetime_max = retention_config.get(
            "allowed_lifetime_max"
        )
        if self.retention_allowed_lifetime_max is not None:
            self.retention_allowed_lifetime_max = self.parse_duration(
                self.retention_allowed_lifetime_max
            )

        if (
            self.retention_allowed_lifetime_min is not None
            and self.retention_allowed_lifetime_max is not None
            and self.retention_allowed_lifetime_min
            > self.retention_allowed_lifetime_max
        ):
            raise ConfigError(
                "Invalid retention policy limits: 'allowed_lifetime_min' can not be"
                " greater than 'allowed_lifetime_max'"
            )

        self.retention_purge_jobs = []  # type: List[Dict[str, Optional[int]]]
        for purge_job_config in retention_config.get("purge_jobs", []):
            interval_config = purge_job_config.get("interval")

            if interval_config is None:
                raise ConfigError(
                    "A retention policy's purge jobs configuration must have the"
                    " 'interval' key set."
                )

            interval = self.parse_duration(interval_config)

            shortest_max_lifetime = purge_job_config.get("shortest_max_lifetime")

            if shortest_max_lifetime is not None:
                shortest_max_lifetime = self.parse_duration(shortest_max_lifetime)

            longest_max_lifetime = purge_job_config.get("longest_max_lifetime")

            if longest_max_lifetime is not None:
                longest_max_lifetime = self.parse_duration(longest_max_lifetime)

            if (
                shortest_max_lifetime is not None
                and longest_max_lifetime is not None
                and shortest_max_lifetime > longest_max_lifetime
            ):
                raise ConfigError(
                    "A retention policy's purge jobs configuration's"
                    " 'shortest_max_lifetime' value can not be greater than its"
                    " 'longest_max_lifetime' value."
                )

            self.retention_purge_jobs.append(
                {
                    "interval": interval,
                    "shortest_max_lifetime": shortest_max_lifetime,
                    "longest_max_lifetime": longest_max_lifetime,
                }
            )

        if not self.retention_purge_jobs:
            self.retention_purge_jobs = [
                {
                    "interval": self.parse_duration("1d"),
                    "shortest_max_lifetime": None,
                    "longest_max_lifetime": None,
                }
            ]

        self.listeners = []  # type: List[dict]
        for listener in config.get("listeners", []):
            if not isinstance(listener.get("port", None), int):
                raise ConfigError(
                    "Listener configuration is lacking a valid 'port' option"
                )

            if listener.setdefault("tls", False):
                # no_tls is not really supported any more, but let's grandfather it in
                # here.
                if config.get("no_tls", False):
                    logger.info(
                        "Ignoring TLS-enabled listener on port %i due to no_tls"
                    )
                    continue

            bind_address = listener.pop("bind_address", None)
            bind_addresses = listener.setdefault("bind_addresses", [])

            # if bind_address was specified, add it to the list of addresses
            if bind_address:
                bind_addresses.append(bind_address)

            # if we still have an empty list of addresses, use the default list
            if not bind_addresses:
                if listener["type"] == "metrics":
                    # the metrics listener doesn't support IPv6
                    bind_addresses.append("0.0.0.0")
                else:
                    bind_addresses.extend(DEFAULT_BIND_ADDRESSES)

            self.listeners.append(listener)

        if not self.web_client_location:
            _warn_if_webclient_configured(self.listeners)

        self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None))

        @attr.s
        class LimitRemoteRoomsConfig(object):
            enabled = attr.ib(
                validator=attr.validators.instance_of(bool), default=False
            )
            complexity = attr.ib(
                validator=attr.validators.instance_of(
                    (float, int)  # type: ignore[arg-type] # noqa
                ),
                default=1.0,
            )
            complexity_error = attr.ib(
                validator=attr.validators.instance_of(str),
                default=ROOM_COMPLEXITY_TOO_GREAT,
            )

        self.limit_remote_rooms = LimitRemoteRoomsConfig(
            **config.get("limit_remote_rooms", {})
        )

        bind_port = config.get("bind_port")
        if bind_port:
            if config.get("no_tls", False):
                raise ConfigError("no_tls is incompatible with bind_port")

            self.listeners = []
            bind_host = config.get("bind_host", "")
            gzip_responses = config.get("gzip_responses", True)

            self.listeners.append(
                {
                    "port": bind_port,
                    "bind_addresses": [bind_host],
                    "tls": True,
                    "type": "http",
                    "resources": [
                        {"names": ["client"], "compress": gzip_responses},
                        {"names": ["federation"], "compress": False},
                    ],
                }
            )

            unsecure_port = config.get("unsecure_port", bind_port - 400)
            if unsecure_port:
                self.listeners.append(
                    {
                        "port": unsecure_port,
                        "bind_addresses": [bind_host],
                        "tls": False,
                        "type": "http",
                        "resources": [
                            {"names": ["client"], "compress": gzip_responses},
                            {"names": ["federation"], "compress": False},
                        ],
                    }
                )

        manhole = config.get("manhole")
        if manhole:
            self.listeners.append(
                {
                    "port": manhole,
                    "bind_addresses": ["127.0.0.1"],
                    "type": "manhole",
                    "tls": False,
                }
            )

        metrics_port = config.get("metrics_port")
        if metrics_port:
            logger.warning(METRICS_PORT_WARNING)

            self.listeners.append(
                {
                    "port": metrics_port,
                    "bind_addresses": [config.get("metrics_bind_host", "127.0.0.1")],
                    "tls": False,
                    "type": "http",
                    "resources": [{"names": ["metrics"], "compress": False}],
                }
            )

        _check_resource_config(self.listeners)

        self.cleanup_extremities_with_dummy_events = config.get(
            "cleanup_extremities_with_dummy_events", True
        )

        # The number of forward extremities in a room needed to send a dummy event.
        self.dummy_events_threshold = config.get("dummy_events_threshold", 10)

        self.enable_ephemeral_messages = config.get("enable_ephemeral_messages", False)

        # Inhibits the /requestToken endpoints from returning an error that might leak
        # information about whether an e-mail address is in use or not on this
        # homeserver, and instead return a 200 with a fake sid if this kind of error is
        # met, without sending anything.
        # This is a compromise between sending an email, which could be a spam vector,
        # and letting the client know which email address is bound to an account and
        # which one isn't.
        self.request_token_inhibit_3pid_errors = config.get(
            "request_token_inhibit_3pid_errors", False,
        )
Beispiel #13
0
    async def _rejected_events_metadata(self, progress: dict,
                                        batch_size: int) -> int:
        """Background update that backfills metadata rows for rejected events.

        Walks the `rejections` table in event ID order, resuming after the
        `last_event_id` stored in `progress`, and for each rejected *state*
        event adds whichever of its `state_events` / `event_auth` rows are
        missing. Events whose room version is no longer known are skipped.

        Args:
            progress: progress dict of this background update; its optional
                "last_event_id" key marks where the previous batch stopped.
            batch_size: maximum number of rejected events to examine.

        Returns:
            How many rejected events were examined. The background update is
            marked as finished once a batch comes back short (or empty).
        """
        update_name = "rejected_events_metadata"
        last_event_id = progress.get("last_event_id", "")

        def fetch_batch(
            txn: Cursor, ) -> List[Tuple[str, str, JsonDict, bool, bool]]:
            # One row per rejected event after `last_event_id`:
            #   (event_id, room version, parsed event JSON,
            #    has a state_events row?, has event_auth rows?)
            # Events without a `rooms` entry are assumed to be in V1 rooms,
            # hence the COALESCE.
            sql = """
                SELECT DISTINCT
                    event_id,
                    COALESCE(room_version, '1'),
                    json,
                    state_events.event_id IS NOT NULL,
                    event_auth.event_id IS NOT NULL
                FROM rejections
                INNER JOIN event_json USING (event_id)
                LEFT JOIN rooms USING (room_id)
                LEFT JOIN state_events USING (event_id)
                LEFT JOIN event_auth USING (event_id)
                WHERE event_id > ?
                ORDER BY event_id
                LIMIT ?
            """
            txn.execute(sql, (last_event_id, batch_size))
            return [
                (event_id, version, db_to_json(raw_json), in_state, in_auth)
                for event_id, version, raw_json, in_state, in_auth in txn
            ]  # type: ignore

        results = await self.db_pool.runInteraction(
            desc="_rejected_events_metadata_get", func=fetch_batch)

        if not results:
            await self.db_pool.updates._end_background_update(update_name)
            return 0

        new_state_rows = []
        new_auth_rows = []
        for event_id, room_version, event_json, has_state, has_event_auth in results:
            # Advance the resume point for every row, including skipped ones.
            last_event_id = event_id

            if has_state and has_event_auth:
                # Both kinds of metadata are already present.
                continue

            version_obj = KNOWN_ROOM_VERSIONS.get(room_version)
            if not version_obj:
                # Support for this room version is gone; ignore the event.
                logger.info(
                    "Ignoring event with unknown room version %r: %r",
                    room_version,
                    event_id,
                )
                continue

            event = make_event_from_dict(event_json, version_obj)
            if not event.is_state():
                continue

            if not has_state:
                new_state_rows.append({
                    "event_id": event.event_id,
                    "room_id": event.room_id,
                    "type": event.type,
                    "state_key": event.state_key,
                })

            if not has_event_auth:
                # Older events can list the same auth event twice; de-dupe
                # because event_auth has a unique constraint.
                new_auth_rows.extend(
                    {
                        "room_id": event.room_id,
                        "event_id": event.event_id,
                        "auth_id": auth_id,
                    }
                    for auth_id in set(event.auth_event_ids())
                )

        pending_inserts = (
            ("state_events", new_state_rows,
             "_rejected_events_metadata_state_events"),
            ("event_auth", new_auth_rows,
             "_rejected_events_metadata_event_auth"),
        )
        for table, rows, desc in pending_inserts:
            if rows:
                await self.db_pool.simple_insert_many(
                    table=table, values=rows, desc=desc)

        await self.db_pool.updates._background_update_progress(
            update_name, {"last_event_id": last_event_id})

        batch_count = len(results)
        if batch_count < batch_size:
            await self.db_pool.updates._end_background_update(update_name)
        return batch_count
Beispiel #14
0
def _check_sigs_on_pdus(keyring, room_version, pdus):
    """Kick off signature checks for a collection of PDUs.

    Each PDU must be signed by:

    (a) its sender's server. Invites created from a 3pid invite are exempt:
        their sender has to match that of the original 3pid invite, but the
        event itself may come from a different homeserver. This is a wart --
        redacting such an invite makes it look like a regular invite without
        a valid signature, although signatures are supposed to stay valid
        whether or not an event has been redacted.

    (b) for room versions using `EventFormatVersions.V1` (the old-style,
        non-hash event IDs of V1 and V2 rooms), also the server named in the
        event ID, whenever that differs from the sender's server.

    Signatures are verified over the pruned (redacted) JSON of each event.

    Args:
        keyring (synapse.crypto.Keyring): keyring used to verify signatures
        room_version (str): the room version of the PDUs
        pdus (Collection[EventBase]): the events to check

    Returns:
        List[Deferred]: one Deferred per entry of `pdus`, in the same order,
            which succeeds if the signatures are valid and otherwise fails
            with a 403 SynapseError (Codes.FORBIDDEN).

    Raises:
        RuntimeError: if `room_version` is not a known room version.
    """
    checks = [
        PduToCheckSig(
            pdu=pdu,
            redacted_pdu_json=prune_event(pdu).get_pdu_json(),
            sender_domain=get_domain_from_id(pdu.sender),
            deferreds=[],
        )
        for pdu in pdus
    ]

    version = KNOWN_ROOM_VERSIONS.get(room_version)
    if not version:
        raise RuntimeError("Unrecognized room version %s" % (room_version,))

    def key_validity_ts(check):
        # Room versions that enforce key validity pass the event's
        # origin_server_ts to the keyring; others pass 0 (presumably "no
        # validity constraint" -- see the keyring for exact semantics).
        return check.pdu.origin_server_ts if version.enforce_key_validity else 0

    def verify(targets, domain_of):
        # Ask the keyring to check each target against the given domain's keys.
        return keyring.verify_json_objects_for_server(
            [
                (
                    domain_of(check),
                    check.redacted_pdu_json,
                    key_validity_ts(check),
                    check.pdu.event_id,
                )
                for check in targets
            ]
        )

    def attach(targets, deferreds, errback):
        # Wire each verification result to its PDU, converting failures.
        for check, deferred in zip(targets, deferreds):
            deferred.addErrback(errback, check)
            check.deferreds.append(deferred)

    def on_sender_sig_failure(e, pdu_to_check):
        errmsg = "event id %s: unable to verify signature for sender %s: %s" % (
            pdu_to_check.pdu.event_id,
            pdu_to_check.sender_domain,
            e.getErrorMessage(),
        )
        raise SynapseError(403, errmsg, Codes.FORBIDDEN)

    # (a) sender's server, skipping invites that stem from a 3pid invite.
    sender_checks = [c for c in checks if not _is_invite_via_3pid(c.pdu)]
    attach(
        sender_checks,
        verify(sender_checks, lambda check: check.sender_domain),
        on_sender_sig_failure,
    )

    # (b) event-ID domain; only V1-format event IDs carry a domain at all.
    if version.event_format == EventFormatVersions.V1:

        def on_event_id_sig_failure(e, pdu_to_check):
            errmsg = (
                "event id %s: unable to verify signature for event id domain: %s"
                % (pdu_to_check.pdu.event_id, e.getErrorMessage())
            )
            raise SynapseError(403, errmsg, Codes.FORBIDDEN)

        event_id_checks = [
            c
            for c in checks
            if c.sender_domain != get_domain_from_id(c.pdu.event_id)
        ]
        attach(
            event_id_checks,
            verify(
                event_id_checks,
                lambda check: get_domain_from_id(check.pdu.event_id),
            ),
            on_event_id_sig_failure,
        )

    # Collapse each PDU's list of Deferreds into a single Deferred.
    return [_flatten_deferred_list(c.deferreds) for c in checks]
Beispiel #15
0
def main() -> None:
    """Command-line entry point: sign a JSON object and write it out.

    The server name and signing key come from the command line. When either
    is missing, the script falls back to the Synapse config file via
    `read_args_from_config`. The JSON to sign comes from the positional
    argument, or else from `--input` (stdin by default).

    With `--sign-event-room-version`, the object is signed as an event for
    that room version: a 'hashes' object is added and the event is redacted
    before signing. Otherwise it is signed as raw JSON. The result is written
    to `--output` (stdout by default).

    Exits with status 1, after printing a message to stderr, if:
    - no server name or signing key can be determined;
    - the key source yields no keys;
    - the input is not valid JSON, or is not a JSON object;
    - the room version is unknown.
    """
    parser = argparse.ArgumentParser(
        description="""Adds a signature to a JSON object.

Example usage:

    $ scripts-dev/sign_json.py -N test -k localhost.signing.key "{}"
    {"signatures":{"test":{"ed25519:a_ZnZh":"LmPnml6iM0iR..."}}}
""",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    parser.add_argument(
        "-N",
        "--server-name",
        help="Name to give as the local homeserver. If unspecified, will be "
        "read from the config file.",
    )

    parser.add_argument(
        "-k",
        "--signing-key-path",
        help="Path to the file containing the private ed25519 key to sign the "
        "request with.",
    )

    parser.add_argument(
        "-K",
        "--signing-key",
        help="The private ed25519 key to sign the request with.",
    )

    parser.add_argument(
        "-c",
        "--config",
        default="homeserver.yaml",
        help=
        ("Path to synapse config file, from which the server name and/or signing "
         "key path will be read. Ignored if --server-name and --signing-key(-path) "
         "are both given."),
    )

    parser.add_argument(
        "--sign-event-room-version",
        type=str,
        help=
        ("Sign the JSON as an event for the given room version, rather than raw JSON. "
         "This means that we will add a 'hashes' object, and redact the event before "
         "signing."),
    )

    input_args = parser.add_mutually_exclusive_group()

    input_args.add_argument("input_data",
                            nargs="?",
                            help="Raw JSON to be signed.")

    input_args.add_argument(
        "-i",
        "--input",
        type=argparse.FileType("r"),
        default=sys.stdin,
        help=
        ("A file from which to read the JSON to be signed. If neither --input nor "
         "input_data are given, JSON will be read from stdin."),
    )

    parser.add_argument(
        "-o",
        "--output",
        type=argparse.FileType("w"),
        default=sys.stdout,
        help="Where to write the signed JSON. Defaults to stdout.",
    )

    args = parser.parse_args()

    if not args.server_name or not (args.signing_key_path or args.signing_key):
        read_args_from_config(args)

    # The config file may not supply everything we need. Fail with a clear
    # message rather than a traceback (open(None)) or a signature filed under
    # a `None` server name.
    if not args.server_name:
        print("No server name given: pass --server-name, or a config file "
              "that provides one", file=sys.stderr)
        sys.exit(1)

    if not (args.signing_key or args.signing_key_path):
        print("No signing key given: pass --signing-key or --signing-key-path, "
              "or a config file that provides one", file=sys.stderr)
        sys.exit(1)

    if args.signing_key:
        keys = read_signing_keys([args.signing_key])
    else:
        with open(args.signing_key_path) as f:
            keys = read_signing_keys(f)

    # Guard against an empty key source (presumably read_signing_keys returns
    # an empty list for it), which would otherwise fail with an IndexError on
    # keys[0] below.
    if not keys:
        print("No signing keys found", file=sys.stderr)
        sys.exit(1)

    json_to_sign = args.input_data
    if json_to_sign is None:
        json_to_sign = args.input.read()

    try:
        obj = json.loads(json_to_sign)
    except JSONDecodeError as e:
        print("Unable to parse input as JSON: %s" % e, file=sys.stderr)
        sys.exit(1)

    if not isinstance(obj, dict):
        print("Input json was not an object", file=sys.stderr)
        sys.exit(1)

    if args.sign_event_room_version:
        room_version = KNOWN_ROOM_VERSIONS.get(args.sign_event_room_version)
        if not room_version:
            print(f"Unknown room version {args.sign_event_room_version}",
                  file=sys.stderr)
            sys.exit(1)
        add_hashes_and_signatures(room_version, obj, args.server_name, keys[0])
    else:
        sign_json(obj, args.server_name, keys[0])

    for c in json_encoder.iterencode(obj):
        args.output.write(c)
    args.output.write("\n")
Beispiel #16
0
    async def _get_events_from_db(self, event_ids, allow_rejected=False):
        """Fetch a bunch of events from the database.

        Returned events will be added to the cache for future lookups.

        Unknown events are omitted from the response. The following events
        are also omitted, and each omission is logged:
        - events whose JSON or internal metadata cannot be parsed;
        - events in a room with an unrecognised room version;
        - events whose stored format version does not match their room
          version.

        Args:
            event_ids (Iterable[str]): The event_ids of the events to fetch.
                Any iterable is accepted, including one-shot iterators.

            allow_rejected (bool): Whether to include rejected events. If False,
                rejected events are omitted from the response.

        Returns:
            Dict[str, _EventCacheEntry]:
                map from event id to result. May return extra events which
                weren't asked for (redactions of the requested events are
                fetched as well).

        Raises:
            Exception: if a non-membership event belongs to a room with no
                recorded room version.
        """
        fetched_events = {}
        # Materialise the input up front. It is handed to _enqueue_events and
        # then iterated again below, so a one-shot iterator (e.g. a generator)
        # could be consumed before the loop sees it, silently dropping every
        # requested event.
        events_to_fetch = list(event_ids)

        while events_to_fetch:
            row_map = await self._enqueue_events(events_to_fetch)

            # we need to recursively fetch any redactions of those events
            redaction_ids = set()
            for event_id in events_to_fetch:
                # None when _enqueue_events returned no row for this event.
                row = row_map.get(event_id)
                fetched_events[event_id] = row
                if row:
                    redaction_ids.update(row["redactions"])

            # Only go round again for redactions we have not fetched yet.
            events_to_fetch = redaction_ids.difference(fetched_events.keys())
            if events_to_fetch:
                logger.debug("Also fetching redaction events %s",
                             events_to_fetch)

        # build a map from event_id to EventBase
        event_map = {}
        for event_id, row in fetched_events.items():
            if not row:
                continue
            assert row["event_id"] == event_id

            rejected_reason = row["rejected_reason"]

            if not allow_rejected and rejected_reason:
                continue

            # If the event or metadata cannot be parsed, log the error and act
            # as if the event is unknown.
            try:
                d = db_to_json(row["json"])
            except ValueError:
                logger.error("Unable to parse json from event: %s", event_id)
                continue
            try:
                internal_metadata = db_to_json(row["internal_metadata"])
            except ValueError:
                logger.error(
                    "Unable to parse internal_metadata from event: %s",
                    event_id)
                continue

            format_version = row["format_version"]
            if format_version is None:
                # This means that we stored the event before we had the concept
                # of a event format version, so it must be a V1 event.
                format_version = EventFormatVersions.V1

            room_version_id = row["room_version_id"]

            if not room_version_id:
                # this should only happen for out-of-band membership events which
                # arrived before #6983 landed. For all other events, we should have
                # an entry in the 'rooms' table.
                #
                # However, the 'out_of_band_membership' flag is unreliable for older
                # invites, so just accept it for all membership events.
                #
                if d["type"] != EventTypes.Member:
                    raise Exception("Room %s for event %s is unknown" %
                                    (d["room_id"], event_id))

                # so, assuming this is an out-of-band-invite that arrived before #6983
                # landed, we know that the room version must be v5 or earlier (because
                # v6 hadn't been invented at that point, so invites from such rooms
                # would have been rejected.)
                #
                # The main reason we need to know the room version here (other than
                # choosing the right python Event class) is in case the event later has
                # to be redacted - and all the room versions up to v5 used the same
                # redaction algorithm.
                #
                # So, the following approximations should be adequate.

                if format_version == EventFormatVersions.V1:
                    # if it's event format v1 then it must be room v1 or v2
                    room_version = RoomVersions.V1
                elif format_version == EventFormatVersions.V2:
                    # if it's event format v2 then it must be room v3
                    room_version = RoomVersions.V3
                else:
                    # if it's event format v3 then it must be room v4 or v5
                    room_version = RoomVersions.V5
            else:
                room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
                if not room_version:
                    # e.g. support for this room version has been withdrawn.
                    logger.warning(
                        "Event %s in room %s has unknown room version %s",
                        event_id,
                        d["room_id"],
                        room_version_id,
                    )
                    continue

                if room_version.event_format != format_version:
                    logger.error(
                        "Event %s in room %s with version %s has wrong format: "
                        "expected %s, was %s",
                        event_id,
                        d["room_id"],
                        room_version_id,
                        room_version.event_format,
                        format_version,
                    )
                    continue

            original_ev = make_event_from_dict(
                event_dict=d,
                room_version=room_version,
                internal_metadata_dict=internal_metadata,
                rejected_reason=rejected_reason,
            )

            event_map[event_id] = original_ev

        # finally, we can decide whether each one needs redacting, and build
        # the cache entries.
        result_map = {}
        for event_id, original_ev in event_map.items():
            redactions = fetched_events[event_id]["redactions"]
            redacted_event = self._maybe_redact_event_row(
                original_ev, redactions, event_map)

            cache_entry = _EventCacheEntry(event=original_ev,
                                           redacted_event=redacted_event)

            self._get_event_cache.prefill((event_id, ), cache_entry)
            result_map[event_id] = cache_entry

        return result_map
Beispiel #17
0
class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
    def prepare(self, reactor, clock, hs):
        # Every test below talks to the main datastore directly.
        datastores = hs.get_datastores()
        self.store = datastores.main

    def test_get_prev_events_for_room(self):
        room_id = "@ROOM:local"
        insert_event_sql = (
            "INSERT INTO events ("
            "   room_id, event_id, type, depth, topological_ordering,"
            "   content, processed, outlier, stream_ordering) "
            "VALUES (?, ?, 'm.test', ?, ?, 'test', ?, ?, ?)")
        insert_extremity_sql = (
            "INSERT INTO event_forward_extremities (room_id, event_id) "
            "VALUES (?, ?)")

        def insert_event(txn, i):
            # Event number `i` gets depth/ordering `i` and is also recorded
            # as a forward extremity of the room.
            event_id = "$event_%i:local" % i
            txn.execute(insert_event_sql,
                        (room_id, event_id, i, i, True, False, i))
            txn.execute(insert_extremity_sql, (room_id, event_id))

        for n in range(20):
            self.get_success(
                self.store.db_pool.runInteraction("insert", insert_event, n))

        # Expect the last ten events inserted, newest first.
        prev_events = self.get_success(
            self.store.get_prev_events_for_room(room_id))
        self.assertEqual(10, len(prev_events))
        for position, n in enumerate(range(19, 9, -1)):
            self.assertEqual("$event_%i:local" % n, prev_events[position])

    def test_get_rooms_with_many_extremities(self):
        room1 = "#room1"
        room2 = "#room2"
        room3 = "#room3"

        def insert_event(txn, i, room_id):
            # Only the forward-extremity row matters for this query.
            txn.execute(
                ("INSERT INTO event_forward_extremities (room_id, event_id) "
                 "VALUES (?, ?)"),
                (room_id, "$event_%i:local" % i),
            )

        # Give every room the same 20 forward extremities.
        for n in range(20):
            for room_id in (room1, room2, room3):
                self.get_success(
                    self.store.db_pool.runInteraction(
                        "insert", insert_event, n, room_id))

        # With nothing excluded, all three rooms qualify.
        rooms = self.get_success(
            self.store.get_rooms_with_many_extremities(5, 5, []))
        self.assertEqual(len(rooms), 3)

        # Excluded rooms must be filtered out.
        rooms = self.get_success(
            self.store.get_rooms_with_many_extremities(5, 5, [room1]))
        self.assertIn(room2, rooms)
        self.assertIn(room3, rooms)
        self.assertEqual(len(rooms), 2)

        rooms = self.get_success(
            self.store.get_rooms_with_many_extremities(5, 5, [room1, room2]))
        self.assertEqual(rooms, [room3])

        # Exclusion and limit together: one of the remaining rooms.
        rooms = self.get_success(
            self.store.get_rooms_with_many_extremities(5, 1, [room1]))
        self.assertIn(rooms, ([room2], [room3]))

    def _setup_auth_chain(self, use_chain_cover_index: bool) -> str:
        """Create a test room and populate it with a fixed auth graph.

        Args:
            use_chain_cover_index: stored as the room's `has_auth_chain_index`
                flag, so callers can exercise both code paths.

        Returns:
            The ID of the room that was set up.
        """
        room_id = "@ROOM:local"

        # The silly auth graph we use to test the auth difference algorithm,
        # where the top are the most recent events.
        #
        #   A   B
        #    \ /
        #  D  E
        #  \  |
        #   ` F   C
        #     |  /|
        #     G ´ |
        #     | \ |
        #     H   I
        #     |   |
        #     K   J

        # event -> the events in its auth list
        auth_graph = {
            "a": ["e"],
            "b": ["e"],
            "c": ["g", "i"],
            "d": ["f"],
            "e": ["f"],
            "f": ["g"],
            "g": ["h", "i"],
            "h": ["k"],
            "i": ["j"],
            "k": [],
            "j": [],
        }

        # event -> depth (also used as its topological ordering)
        depth_map = {
            "a": 7,
            "b": 7,
            "c": 4,
            "d": 6,
            "e": 6,
            "f": 5,
            "g": 3,
            "h": 2,
            "i": 2,
            "k": 1,
            "j": 1,
        }

        def store_room(txn):
            # Register the room, flagging whether it may have a cover index.
            self.store.db_pool.simple_insert_txn(
                txn,
                "rooms",
                {
                    "room_id": room_id,
                    "creator": "room_creator_user_id",
                    "is_public": True,
                    "room_version": "6",
                    "has_auth_chain_index": use_chain_cover_index,
                },
            )

        self.get_success(
            self.store.db_pool.runInteraction("store_room", store_room))

        def insert_events(txn):
            # Writing rows straight into the tables is far simpler than
            # building and persisting real events.
            for stream_ordering, event_id in enumerate(auth_graph, start=1):
                depth = depth_map[event_id]
                self.store.db_pool.simple_insert_txn(
                    txn,
                    table="events",
                    values={
                        "event_id": event_id,
                        "room_id": room_id,
                        "depth": depth,
                        "topological_ordering": depth,
                        "type": "m.test",
                        "processed": True,
                        "outlier": False,
                        "stream_ordering": stream_ordering,
                    },
                )

            fake_events = [
                FakeEvent(event_id, room_id, auth_ids)
                for event_id, auth_ids in auth_graph.items()
            ]
            self.hs.datastores.persist_events._persist_event_auth_chain_txn(
                txn, fake_events)

        self.get_success(
            self.store.db_pool.runInteraction("insert", insert_events))

        return room_id

    @parameterized.expand([(True,), (False,)])
    def test_auth_chain_ids(self, use_chain_cover_index: bool):
        room_id = self._setup_auth_chain(use_chain_cover_index)

        def chain_of(event_ids, **kwargs):
            # Auth chain of `event_ids` in the test room (extra kwargs such
            # as include_given are forwarded unchanged).
            return self.get_success(
                self.store.get_auth_chain_ids(room_id, event_ids, **kwargs))

        below_e = ["e", "f", "g", "h", "i", "j", "k"]

        # Single events with a non-trivial chain (compared ignoring order).
        # a/b share a chain, as do d/e.
        for event_ids, expected in (
            (["a"], below_e),
            (["b"], below_e),
            (["a", "b"], below_e),
            (["c"], ["g", "h", "i", "j", "k"]),
            (["d"], ["f", "g", "h", "i", "j", "k"]),
            (["e"], ["f", "g", "h", "i", "j", "k"]),
            (["f"], ["g", "h", "i", "j", "k"]),
            (["g"], ["h", "i", "j", "k"]),
        ):
            self.assertCountEqual(chain_of(event_ids), expected)

        # Near the bottom of the graph the result must be exactly this set;
        # j and k have no parents at all.
        for event_ids, expected in (
            (["h"], {"k"}),
            (["i"], {"j"}),
            (["j"], set()),
            (["k"], set()),
        ):
            self.assertEqual(chain_of(event_ids), expected)

        # More complex input sequences. For ["b", "e"], e is returned even
        # though include_given is false, because it is in b's auth chain.
        for event_ids, expected in (
            (["b", "c", "d"], below_e),
            (["h", "i"], ["k", "j"]),
            (["b", "e"], below_e),
        ):
            self.assertCountEqual(chain_of(event_ids), expected)

        # include_given adds the queried events themselves.
        self.assertCountEqual(chain_of(["i"], include_given=True), ["i", "j"])

    @parameterized.expand([(True, ), (False, )])
    def test_auth_difference(self, use_chain_cover_index: bool):
        """Check `get_auth_chain_difference` for assorted combinations of
        state sets over the graph built by `_setup_auth_chain`.
        """
        room_id = self._setup_auth_chain(use_chain_cover_index)

        def difference_of(*state_sets):
            # Each positional argument is one state set (a set of event IDs);
            # the store is handed them as a list, in argument order.
            return self.get_success(
                self.store.get_auth_chain_difference(room_id,
                                                     list(state_sets)))

        # Now actually test that various combinations give the right result:
        self.assertSetEqual(difference_of({"a"}, {"b"}), {"a", "b"})
        self.assertSetEqual(difference_of({"a"}, {"b"}, {"c"}),
                            {"a", "b", "c", "e", "f"})
        self.assertSetEqual(difference_of({"a", "c"}, {"b"}),
                            {"a", "b", "c"})
        self.assertSetEqual(difference_of({"a", "c"}, {"b", "c"}),
                            {"a", "b"})
        self.assertSetEqual(difference_of({"a"}, {"b"}, {"d"}),
                            {"a", "b", "d", "e"})
        self.assertSetEqual(difference_of({"a"}, {"b"}, {"c"}, {"d"}),
                            {"a", "b", "c", "d", "e", "f"})
        self.assertSetEqual(difference_of({"a"}, {"b"}, {"e"}), {"a", "b"})
        self.assertSetEqual(difference_of({"a"}), set())

    def test_auth_difference_partial_cover(self):
        """Test that we correctly handle rooms where not all events have a chain
        cover calculated. This can happen in some obscure edge cases, including
        during the background update that calculates the chain cover for old
        rooms.

        Every event except 'b' has its auth chain persisted while the room is
        flagged as having an auth chain index; 'b' is persisted while that flag
        is temporarily cleared.
        """

        room_id = "@ROOM:local"

        # The silly auth graph we use to test the auth difference algorithm,
        # where the top are the most recent events.
        #
        #   A   B
        #    \ /
        #  D  E
        #  \  |
        #   ` F   C
        #     |  /|
        #     G ´ |
        #     | \ |
        #     H   I
        #     |   |
        #     K   J

        auth_graph = {
            "a": ["e"],
            "b": ["e"],
            "c": ["g", "i"],
            "d": ["f"],
            "e": ["f"],
            "f": ["g"],
            "g": ["h", "i"],
            "h": ["k"],
            "i": ["j"],
            "k": [],
            "j": [],
        }

        depth_map = {
            "a": 7,
            "b": 7,
            "c": 4,
            "d": 6,
            "e": 6,
            "f": 5,
            "g": 3,
            "h": 2,
            "i": 2,
            "k": 1,
            "j": 1,
        }

        def set_chain_index_flag(txn, enabled):
            # Flip the room's `has_auth_chain_index` column inside `txn`.
            self.store.db_pool.simple_update_txn(
                txn,
                table="rooms",
                keyvalues={"room_id": room_id},
                updatevalues={"has_auth_chain_index": enabled},
            )

        # We rudely fiddle with the appropriate tables directly, as that's much
        # easier than constructing events properly.
        def insert_event(txn):
            # First insert the room and mark it as having a chain cover.
            self.store.db_pool.simple_insert_txn(
                txn,
                "rooms",
                {
                    "room_id": room_id,
                    "creator": "room_creator_user_id",
                    "is_public": True,
                    "room_version": "6",
                    "has_auth_chain_index": True,
                },
            )

            # Stream orderings are handed out 1, 2, 3, ... in graph order.
            for stream_ordering, event_id in enumerate(auth_graph, start=1):
                depth = depth_map[event_id]
                self.store.db_pool.simple_insert_txn(
                    txn,
                    table="events",
                    values={
                        "event_id": event_id,
                        "room_id": room_id,
                        "depth": depth,
                        "topological_ordering": depth,
                        "type": "m.test",
                        "processed": True,
                        "outlier": False,
                        "stream_ordering": stream_ordering,
                    },
                )

            persist_store = self.hs.datastores.persist_events

            # Insert all events apart from 'B'
            persist_store._persist_event_auth_chain_txn(
                txn,
                [
                    FakeEvent(event_id, room_id, parents)
                    for event_id, parents in auth_graph.items()
                    if event_id != "b"
                ],
            )

            # Now we insert the event 'B' without a chain cover, by temporarily
            # pretending the room doesn't have a chain cover.
            set_chain_index_flag(txn, False)
            persist_store._persist_event_auth_chain_txn(
                txn,
                [FakeEvent("b", room_id, auth_graph["b"])],
            )
            set_chain_index_flag(txn, True)

        self.get_success(
            self.store.db_pool.runInteraction(
                "insert",
                insert_event,
            ))

        def difference_of(*state_sets):
            # Each positional argument is one state set (a set of event IDs).
            return self.get_success(
                self.store.get_auth_chain_difference(room_id,
                                                     list(state_sets)))

        # Now actually test that various combinations give the right result:
        self.assertSetEqual(difference_of({"a"}, {"b"}), {"a", "b"})
        self.assertSetEqual(difference_of({"a"}, {"b"}, {"c"}),
                            {"a", "b", "c", "e", "f"})
        self.assertSetEqual(difference_of({"a", "c"}, {"b"}),
                            {"a", "b", "c"})
        self.assertSetEqual(difference_of({"a", "c"}, {"b", "c"}),
                            {"a", "b"})
        self.assertSetEqual(difference_of({"a"}, {"b"}, {"d"}),
                            {"a", "b", "d", "e"})
        self.assertSetEqual(difference_of({"a"}, {"b"}, {"c"}, {"d"}),
                            {"a", "b", "c", "d", "e", "f"})
        self.assertSetEqual(difference_of({"a"}, {"b"}, {"e"}), {"a", "b"})
        self.assertSetEqual(difference_of({"a"}), set())

    @parameterized.expand([(room_version, )
                           for room_version in KNOWN_ROOM_VERSIONS.values()])
    def test_prune_inbound_federation_queue(self, room_version: RoomVersion):
        """Test that pruning of inbound federation queues works.

        Stages 500 events for one room, each listing its predecessor in
        `prev_events`, then checks that pruning leaves only the newest one.
        """

        room_id = "some_room_id"
        desc = "test_prune_inbound_federation_queue"

        def prev_event_format(
                prev_event_id: str) -> Union[Tuple[str, dict], str]:
            """Account for differences in prev_events format across room
            versions: V1-format rooms get an ``(event_id, {})`` pair, every
            other format gets the bare event ID.
            """
            if room_version.event_format != EventFormatVersions.V1:
                return prev_event_id
            return prev_event_id, {}

        # Event N points back at event N - 1 (event 1 points at a
        # nonexistent event 0).
        staged_rows = [(
            "some_origin",
            room_id,
            0,
            f"$fake_event_id_{n}",
            json_encoder.encode({
                "prev_events":
                [prev_event_format(f"$fake_event_id_{n - 1}")]
            }),
            "{}",
        ) for n in range(1, 501)]

        self.get_success(
            self.store.db_pool.simple_insert_many(
                table="federation_inbound_events_staging",
                keys=(
                    "origin",
                    "room_id",
                    "received_ts",
                    "event_id",
                    "event_json",
                    "internal_metadata",
                ),
                values=staged_rows,
                desc=desc,
            ))

        # The first prune should report that it pruned something; an
        # immediate second prune should find nothing left to do.
        first_prune = self.get_success(
            self.store.prune_staged_events_in_room(room_id, room_version))
        self.assertTrue(first_prune)

        second_prune = self.get_success(
            self.store.prune_staged_events_in_room(room_id, room_version))
        self.assertFalse(second_prune)

        # Exactly one event should remain in the queue, and it should be the
        # last one we staged.
        remaining = self.get_success(
            self.store.db_pool.simple_select_one_onecol(
                table="federation_inbound_events_staging",
                keyvalues={"room_id": room_id},
                retcol="COUNT(*)",
                desc=desc,
            ))
        self.assertEqual(remaining, 1)

        next_staged = self.get_success(
            self.store.get_next_staged_event_id_for_room(room_id))
        _, event_id = next_staged
        self.assertEqual(event_id, "$fake_event_id_500")
Beispiel #18
0
    def _do_send_invite(self, destination, pdu, room_version):
        """Actually sends the invite, first trying v2 API and falling back to
        v1 API if necessary.

        The v1 fallback is only attempted when the remote answers the v2
        request with a 400/404 whose error code is `Codes.UNKNOWN`, and only
        for rooms whose version uses the V1 event format.

        Args:
            destination (str): Target server
            pdu (FrozenEvent): The invite event to send.
            room_version (str): Identifier of the room's version; it is looked
                up in `KNOWN_ROOM_VERSIONS`.

        Returns:
            dict: The event as a dict as returned by the remote server

        Raises:
            SynapseError: with `Codes.UNSUPPORTED_ROOM_VERSION` if the remote
                does not appear to support the v2 invite API and the room
                version is unknown to us or does not use V1-format events.
            Errors from the remote: a 400/404 with a non-generic error code,
                or a 403, is re-raised as converted by `to_synapse_error()`;
                any other `HttpResponseException` is re-raised unchanged.
        """
        time_now = self._clock.time_msec()

        try:
            content = yield self.transport_layer.send_invite_v2(
                destination=destination,
                room_id=pdu.room_id,
                event_id=pdu.event_id,
                content={
                    "event": pdu.get_pdu_json(time_now),
                    "room_version": room_version,
                    "invite_room_state": pdu.unsigned.get("invite_room_state", []),
                },
            )
            defer.returnValue(content)
        except HttpResponseException as e:
            if e.code in [400, 404]:
                err = e.to_synapse_error()

                # If we receive an error response that isn't a generic error, we
                # assume that the remote understands the v2 invite API and this
                # is a legitimate error.
                if err.errcode != Codes.UNKNOWN:
                    raise err

                # Otherwise, we assume that the remote server doesn't understand
                # the v2 invite API. That's ok provided the room uses old-style event
                # IDs.
                v = KNOWN_ROOM_VERSIONS.get(room_version)
                # An unrecognised room version has no known event format, so we
                # cannot tell whether the v1 API would be valid. Treat it as
                # unsupported instead of crashing with an AttributeError on None.
                if v is None or v.event_format != EventFormatVersions.V1:
                    raise SynapseError(
                        400,
                        "User's homeserver does not support this room version",
                        Codes.UNSUPPORTED_ROOM_VERSION,
                    )
            elif e.code == 403:
                raise e.to_synapse_error()
            else:
                raise

        # Didn't work, try v1 API.
        # Note the v1 API returns a tuple of `(200, content)`

        _, content = yield self.transport_layer.send_invite_v1(
            destination=destination,
            room_id=pdu.room_id,
            event_id=pdu.event_id,
            content=pdu.get_pdu_json(time_now),
        )
        defer.returnValue(content)
Beispiel #19
0
    def _get_events_from_db(self, event_ids, allow_rejected=False):
        """Fetch a bunch of events from the database.

        Returned events will be added to the cache for future lookups.

        Unknown events are omitted from the response.

        Args:
            event_ids (Iterable[str]): The event_ids of the events to fetch.
                Any iterable is accepted, including a one-shot iterator; it is
                copied into a list before use.

            allow_rejected (bool): Whether to include rejected events. If False,
                rejected events are omitted from the response.

        Returns:
            Deferred[Dict[str, _EventCacheEntry]]:
                map from event id to result. May return extra events which
                weren't asked for (the redaction events fetched along the way).
        """
        fetched_events = {}
        # Materialise the input. It is truth-tested, handed to
        # `_enqueue_events` (presumably iterated there), and then iterated
        # again below. A one-shot iterator would be exhausted by the first
        # pass, and every requested event would be silently dropped.
        events_to_fetch = list(event_ids)

        while events_to_fetch:
            row_map = yield self._enqueue_events(events_to_fetch)

            # we need to recursively fetch any redactions of those events
            redaction_ids = set()
            for event_id in events_to_fetch:
                row = row_map.get(event_id)
                # Record misses as None so they are not re-requested below.
                fetched_events[event_id] = row
                if row:
                    redaction_ids.update(row["redactions"])

            events_to_fetch = redaction_ids.difference(fetched_events.keys())
            if events_to_fetch:
                logger.debug("Also fetching redaction events %s",
                             events_to_fetch)

        # build a map from event_id to EventBase
        event_map = {}
        for event_id, row in fetched_events.items():
            if not row:
                continue
            assert row["event_id"] == event_id

            rejected_reason = row["rejected_reason"]

            if not allow_rejected and rejected_reason:
                continue

            d = json.loads(row["json"])
            internal_metadata = json.loads(row["internal_metadata"])

            format_version = row["format_version"]
            if format_version is None:
                # This means that we stored the event before we had the concept
                # of a event format version, so it must be a V1 event.
                format_version = EventFormatVersions.V1

            room_version_id = row["room_version_id"]

            if not room_version_id:
                # this should only happen for out-of-band membership events
                if not internal_metadata.get("out_of_band_membership"):
                    logger.warning("Room %s for event %s is unknown",
                                   d["room_id"], event_id)
                    continue

                # take a wild stab at the room version based on the event format
                if format_version == EventFormatVersions.V1:
                    room_version = RoomVersions.V1
                elif format_version == EventFormatVersions.V2:
                    room_version = RoomVersions.V3
                else:
                    room_version = RoomVersions.V5
            else:
                room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
                if not room_version:
                    logger.error(
                        "Event %s in room %s has unknown room version %s",
                        event_id,
                        d["room_id"],
                        room_version_id,
                    )
                    continue

                # The stored format must agree with the room version's format;
                # otherwise the event is skipped rather than mis-parsed.
                if room_version.event_format != format_version:
                    logger.error(
                        "Event %s in room %s with version %s has wrong format: "
                        "expected %s, was %s",
                        event_id,
                        d["room_id"],
                        room_version_id,
                        room_version.event_format,
                        format_version,
                    )
                    continue

            original_ev = make_event_from_dict(
                event_dict=d,
                room_version=room_version,
                internal_metadata_dict=internal_metadata,
                rejected_reason=rejected_reason,
            )

            event_map[event_id] = original_ev

        # finally, we can decide whether each one needs redacting, and build
        # the cache entries.
        result_map = {}
        for event_id, original_ev in event_map.items():
            redactions = fetched_events[event_id]["redactions"]
            redacted_event = self._maybe_redact_event_row(
                original_ev, redactions, event_map)

            cache_entry = _EventCacheEntry(event=original_ev,
                                           redacted_event=redacted_event)

            self._get_event_cache.prefill((event_id, ), cache_entry)
            result_map[event_id] = cache_entry

        return result_map