Example #1
    def test_get_set(self):
        clock = MockClock()
        cache = ExpiringCache("test", clock, max_len=1)

        cache["key"] = "value"
        self.assertEquals(cache.get("key"), "value")
        self.assertEquals(cache["key"], "value")
Example #2
    def test_eviction(self):
        clock = MockClock()
        cache = ExpiringCache("test", clock, max_len=2)

        cache["key"] = "value"
        cache["key2"] = "value2"
        self.assertEquals(cache.get("key"), "value")
        self.assertEquals(cache.get("key2"), "value2")

        cache["key3"] = "value3"
        self.assertEquals(cache.get("key"), None)
        self.assertEquals(cache.get("key2"), "value2")
        self.assertEquals(cache.get("key3"), "value3")
Example #3
    def __init__(self, hs: "HomeServer", device_handler: DeviceHandler):
        self.store = hs.get_datastore()
        self.federation = hs.get_federation_client()
        self.clock = hs.get_clock()
        self.device_handler = device_handler

        self._remote_edu_linearizer = Linearizer(name="remote_device_list")

        # user_id -> list of updates waiting to be handled.
        self._pending_updates: Dict[
            str, List[Tuple[str, str, Iterable[str], JsonDict]]
        ] = {}

        # Recently seen stream ids. We don't bother keeping these in the DB,
        # but they're useful to have around to reduce the number of spurious
        # resyncs.
        self._seen_updates: ExpiringCache[str, Set[str]] = ExpiringCache(
            cache_name="device_update_edu",
            clock=self.clock,
            max_len=10000,
            expiry_ms=30 * 60 * 1000,
            iterable=True,
        )

        # Attempt to resync out of sync device lists every 30s.
        self._resync_retry_in_progress = False
        self.clock.looping_call(
            run_as_background_process,
            30 * 1000,
            func=self._maybe_retry_device_resync,
            desc="_maybe_retry_device_resync",
        )
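
The constructor arguments in this example recur throughout the page. The sketch below collects them in one place; the parameter descriptions are inferred from the surrounding tests and comments rather than taken from documentation, so treat them as assumptions:

from synapse.util.caches.expiringcache import ExpiringCache

from tests.utils import MockClock

clock = MockClock()

# Inferred meaning of each keyword argument, based on the examples on this page:
seen_updates = ExpiringCache(
    cache_name="device_update_edu",  # label used when registering the cache for metrics
    clock=clock,                     # clock used to timestamp entries and drive expiry
    max_len=10000,                   # size bound; the oldest entries are evicted beyond it
    expiry_ms=30 * 60 * 1000,        # entries are dropped ~30 minutes after they were last written
    iterable=True,                   # size is measured as total items across the cached values
    reset_expiry_on_get=False,       # if True, reading an entry restarts its expiry timer
)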
Example #4
    def __init__(self, hs, media_repo, media_storage):
        Resource.__init__(self)

        self.auth = hs.get_auth()
        self.clock = hs.get_clock()
        self.filepaths = media_repo.filepaths
        self.max_spider_size = hs.config.max_spider_size
        self.server_name = hs.hostname
        self.store = hs.get_datastore()
        self.client = SpiderHttpClient(hs)
        self.media_repo = media_repo
        self.primary_base_path = media_repo.primary_base_path
        self.media_storage = media_storage

        self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist

        # memory cache mapping urls to an ObservableDeferred returning
        # JSON-encoded OG metadata
        self._cache = ExpiringCache(
            cache_name="url_previews",
            clock=self.clock,
            # don't spider URLs more often than once an hour
            expiry_ms=60 * 60 * 1000,
        )

        self._cleaner_loop = self.clock.looping_call(
            self._start_expire_url_cache_data, 10 * 1000,
        )
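
The _cache comment above describes a memoisation pattern: each URL maps to an ObservableDeferred so that concurrent requests for the same preview share a single spidering attempt. A rough sketch of that lookup follows; _preview_with_cache and _do_preview are illustrative names, and the request plumbing is assumed rather than copied from the handler:

    @defer.inlineCallbacks
    def _preview_with_cache(self, url, user, ts):
        # Hypothetical helper showing the url -> ObservableDeferred pattern.
        observable = self._cache.get(url)
        if observable is None:
            # Nothing cached or in flight: start the preview and cache the
            # ObservableDeferred so concurrent callers observe the same work.
            download = run_in_background(self._do_preview, url, user, ts)
            observable = ObservableDeferred(download, consumeErrors=True)
            self._cache[url] = observable
        og = yield make_deferred_yieldable(observable.observe())
        defer.returnValue(og)

Because expiry_ms is an hour, a second preview request for the same URL inside that window is served from the cached deferred instead of spidering the page again.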
Example #5
    def __init__(self, hs, media_repo):
        Resource.__init__(self)

        self.auth = hs.get_auth()
        self.clock = hs.get_clock()
        self.version_string = hs.version_string
        self.filepaths = media_repo.filepaths
        self.max_spider_size = hs.config.max_spider_size
        self.server_name = hs.hostname
        self.store = hs.get_datastore()
        self.client = SpiderHttpClient(hs)
        self.media_repo = media_repo

        self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist

        # simple memory cache mapping urls to OG metadata
        self.cache = ExpiringCache(
            cache_name="url_previews",
            clock=self.clock,
            # don't spider URLs more often than once an hour
            expiry_ms=60 * 60 * 1000,
        )
        self.cache.start()

        self.downloads = {}
Example #6
    def __init__(self, hs: "HomeServer"):
        self.clock = hs.get_clock()

        self.resolve_linearizer = Linearizer(name="state_resolve_lock")

        # dict of set of event_ids -> _StateCacheEntry.
        self._state_cache: ExpiringCache[
            FrozenSet[int], _StateCacheEntry
        ] = ExpiringCache(
            cache_name="state_cache",
            clock=self.clock,
            max_len=100000,
            expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
            iterable=True,
            reset_expiry_on_get=True,
        )

        #
        # stuff for tracking time spent on state-res by room
        #

        # tracks the amount of work done on state res per room
        self._state_res_metrics: DefaultDict[str, _StateResMetrics] = defaultdict(
            _StateResMetrics
        )

        self.clock.looping_call(self._report_metrics, 120 * 1000)
Example #7
    def __init__(self, hs, media_repo, media_storage):
        super().__init__()

        self.auth = hs.get_auth()
        self.clock = hs.get_clock()
        self.filepaths = media_repo.filepaths
        self.max_spider_size = hs.config.max_spider_size
        self.server_name = hs.hostname
        self.store = hs.get_datastore()
        self.client = SimpleHttpClient(
            hs,
            treq_args={"browser_like_redirects": True},
            ip_whitelist=hs.config.url_preview_ip_range_whitelist,
            ip_blacklist=hs.config.url_preview_ip_range_blacklist,
            http_proxy=os.getenvb(b"http_proxy"),
            https_proxy=os.getenvb(b"HTTPS_PROXY"),
        )
        self.media_repo = media_repo
        self.primary_base_path = media_repo.primary_base_path
        self.media_storage = media_storage

        # We run the background jobs if we're the instance specified (or no
        # instance is specified, where we assume there is only one instance
        # serving media).
        instance_running_jobs = hs.config.media.media_instance_running_background_jobs
        self._worker_run_media_background_jobs = (
            instance_running_jobs is None
            or instance_running_jobs == hs.get_instance_name()
        )

        self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
        self.url_preview_accept_language = hs.config.url_preview_accept_language

        # memory cache mapping urls to an ObservableDeferred returning
        # JSON-encoded OG metadata
        self._cache = ExpiringCache(
            cache_name="url_previews",
            clock=self.clock,
            # don't spider URLs more often than once an hour
            expiry_ms=ONE_HOUR,
        )

        if self._worker_run_media_background_jobs:
            self._cleaner_loop = self.clock.looping_call(
                self._start_expire_url_cache_data, 10 * 1000
            )
Example #8
    def __init__(self, hs: "HomeServer"):
        super().__init__(hs)

        self.pdu_destination_tried = {}  # type: Dict[str, Dict[str, int]]
        self._clock.looping_call(self._clear_tried_cache, 60 * 1000)
        self.state = hs.get_state_handler()
        self.transport_layer = hs.get_federation_transport_client()

        self.hostname = hs.hostname
        self.signing_key = hs.signing_key

        self._get_pdu_cache = ExpiringCache(
            cache_name="get_pdu_cache",
            clock=self._clock,
            max_len=1000,
            expiry_ms=120 * 1000,
            reset_expiry_on_get=False,
        )  # type: ExpiringCache[str, EventBase]
Example #9
    def __init__(self, hs):
        super(FederationClient, self).__init__(hs)

        self.pdu_destination_tried = {}
        self._clock.looping_call(self._clear_tried_cache, 60 * 1000)
        self.state = hs.get_state_handler()
        self.transport_layer = hs.get_federation_transport_client()

        self.hostname = hs.hostname
        self.signing_key = hs.config.signing_key[0]

        self._get_pdu_cache = ExpiringCache(
            cache_name="get_pdu_cache",
            clock=self._clock,
            max_len=1000,
            expiry_ms=120 * 1000,
            reset_expiry_on_get=False,
        )
Example #10
    def get_eachchat_cache_for_openid(self) -> ExpiringCache:
        clock = self.clock
        expiry_ms = 4 * 60 * 1000
        cache = ExpiringCache(
            "ec_code_openid",
            clock,
            max_len=0,
            expiry_ms=expiry_ms,
            iterable=True,
        )
        return cache
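
max_len=0 is unusual among these examples. Judging by the snippets that omit max_len entirely (the url-preview and session caches), a zero max_len appears to leave the cache unbounded in size, so the four-minute expiry_ms is the only thing that removes entries. A short hypothetical sketch of that reading:

# Hypothetical usage; assumes max_len=0 means "no size limit, expire by time only".
cache = self.get_eachchat_cache_for_openid()  # as defined above
for i in range(100000):
    cache["code-%d" % i] = ["openid-token-%d" % i]  # never evicted for size
# An entry disappears roughly four minutes after it was last written.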
Example #11
    def __init__(self, hs):
        """
        Args:
            hs (synapse.server.HomeServer):
        """
        super(AuthHandler, self).__init__(hs)

        self.checkers = {}  # type: dict[str, UserInteractiveAuthChecker]
        for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
            inst = auth_checker_class(hs)
            if inst.is_enabled():
                self.checkers[inst.AUTH_TYPE] = inst

        self.bcrypt_rounds = hs.config.bcrypt_rounds

        # This is not a cache per se, but a store of all current sessions that
        # expire after N hours
        self.sessions = ExpiringCache(
            cache_name="register_sessions",
            clock=hs.get_clock(),
            expiry_ms=self.SESSION_EXPIRE_MS,
            reset_expiry_on_get=True,
        )

        account_handler = ModuleApi(hs, self)
        self.password_providers = [
            module(config=config, account_handler=account_handler)
            for module, config in hs.config.password_providers
        ]

        logger.info("Extra password_providers: %r", self.password_providers)

        self.hs = hs  # FIXME better possibility to access registrationHandler later?
        self.macaroon_gen = hs.get_macaroon_generator()
        self._password_enabled = hs.config.password_enabled

        # we keep this as a list despite the O(N^2) implication so that we can
        # keep PASSWORD first and avoid confusing clients which pick the first
        # type in the list. (NB that the spec doesn't require us to do so and
        # clients which favour types that they don't understand over those that
        # they do are technically broken)
        login_types = []
        if self._password_enabled:
            login_types.append(LoginType.PASSWORD)
        for provider in self.password_providers:
            if hasattr(provider, "get_supported_login_types"):
                for t in provider.get_supported_login_types().keys():
                    if t not in login_types:
                        login_types.append(t)
        self._supported_login_types = login_types

        # Ratelimiter for failed auth during UIA. Uses same ratelimit config
        # as per `rc_login.failed_attempts`.
        self._failed_uia_attempts_ratelimiter = Ratelimiter()

        self._clock = self.hs.get_clock()
Example #12
    def start_get_pdu_cache(self):
        self._get_pdu_cache = ExpiringCache(
            cache_name="get_pdu_cache",
            clock=self._clock,
            max_len=1000,
            expiry_ms=120 * 1000,
            reset_expiry_on_get=False,
        )

        self._get_pdu_cache.start()
Example #13
    def test_time_eviction(self):
        clock = MockClock()
        cache = ExpiringCache("test", clock, expiry_ms=1000)
        cache.start()

        cache["key"] = 1
        clock.advance_time(0.5)
        cache["key2"] = 2

        self.assertEquals(cache.get("key"), 1)
        self.assertEquals(cache.get("key2"), 2)

        clock.advance_time(0.9)
        self.assertEquals(cache.get("key"), None)
        self.assertEquals(cache.get("key2"), 2)

        clock.advance_time(1)
        self.assertEquals(cache.get("key"), None)
        self.assertEquals(cache.get("key2"), None)
Example #14
    def __init__(self, db_conn, hs):
        super(TransactionStore, self).__init__(db_conn, hs)

        self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000)

        self._destination_retry_cache = ExpiringCache(
            cache_name="get_destination_retry_timings",
            clock=self._clock,
            expiry_ms=5 * 60 * 1000,
        )
Example #15
    def __init__(self, hs):
        """
        Args:
            hs (synapse.server.HomeServer):
        """
        super(AuthHandler, self).__init__(hs)
        self.checkers = {
            LoginType.RECAPTCHA: self._check_recaptcha,
            LoginType.EMAIL_IDENTITY: self._check_email_identity,
            LoginType.MSISDN: self._check_msisdn,
            LoginType.DUMMY: self._check_dummy_auth,
            LoginType.TERMS: self._check_terms_auth,
        }
        self.bcrypt_rounds = hs.config.bcrypt_rounds

        # This is not a cache per se, but a store of all current sessions that
        # expire after N hours
        self.sessions = ExpiringCache(
            cache_name="register_sessions",
            clock=hs.get_clock(),
            expiry_ms=self.SESSION_EXPIRE_MS,
            reset_expiry_on_get=True,
        )

        account_handler = ModuleApi(hs, self)
        self.password_providers = [
            module(config=config, account_handler=account_handler)
            for module, config in hs.config.password_providers
        ]

        logger.info("Extra password_providers: %r", self.password_providers)

        self.hs = hs  # FIXME better possibility to access registrationHandler later?
        self.macaroon_gen = hs.get_macaroon_generator()
        self._password_enabled = hs.config.password_enabled

        # we keep this as a list despite the O(N^2) implication so that we can
        # keep PASSWORD first and avoid confusing clients which pick the first
        # type in the list. (NB that the spec doesn't require us to do so and
        # clients which favour types that they don't understand over those that
        # they do are technically broken)
        login_types = []
        if self._password_enabled:
            login_types.append(LoginType.PASSWORD)
        for provider in self.password_providers:
            if hasattr(provider, "get_supported_login_types"):
                for t in provider.get_supported_login_types().keys():
                    if t not in login_types:
                        login_types.append(t)
        self._supported_login_types = login_types

        self._account_ratelimiter = Ratelimiter()
        self._failed_attempts_ratelimiter = Ratelimiter()

        self._clock = self.hs.get_clock()
Example #16
    def __init__(self, db_conn, hs):
        super(DeviceInboxStore, self).__init__(db_conn, hs)

        # Map of (user_id, device_id) to the last stream_id that has been
        # deleted up to. This is so that we can no-op deletions.
        self._last_device_delete_cache = ExpiringCache(
            cache_name="last_device_delete_cache",
            clock=self._clock,
            max_len=10000,
            expiry_ms=30 * 60 * 1000,
        )
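
The comment above explains what _last_device_delete_cache is for; a sketch of the check it enables is below. The method and variable names are illustrative only, not the actual store code:

    def _delete_is_noop(self, user_id, device_id, up_to_stream_id):
        # Hypothetical helper: uses the cache above to detect deletions that
        # have already been performed for this device.
        last_deleted_up_to = self._last_device_delete_cache.get((user_id, device_id))
        return last_deleted_up_to is not None and up_to_stream_id <= last_deleted_up_to

After a successful delete, the store would record the new high-water mark with self._last_device_delete_cache[(user_id, device_id)] = up_to_stream_id, so the next delete up to the same stream_id can be skipped.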
Example #17
    def __init__(self, hs, device_handler):
        self.store = hs.get_datastore()
        self.federation = hs.get_federation_client()
        self.clock = hs.get_clock()
        self.device_handler = device_handler

        self._remote_edu_linearizer = Linearizer(name="remote_device_list")

        # user_id -> list of updates waiting to be handled.
        self._pending_updates = {}

        # Recently seen stream ids. We don't bother keeping these in the DB,
        # but they're useful to have around to reduce the number of spurious
        # resyncs.
        self._seen_updates = ExpiringCache(
            cache_name="device_update_edu",
            clock=self.clock,
            max_len=10000,
            expiry_ms=30 * 60 * 1000,
            iterable=True,
        )
Example #18
    def start_caching(self):
        logger.debug("start_caching")

        self._state_cache = ExpiringCache(
            cache_name="state_cache",
            clock=self.clock,
            max_len=SIZE_OF_CACHE,
            expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
            reset_expiry_on_get=True,
        )

        self._state_cache.start()
Example #19
    def test_time_eviction(self):
        clock = MockClock()
        cache = ExpiringCache("test", clock, expiry_ms=1000)
        cache.start()

        cache["key"] = 1
        clock.advance_time(0.5)
        cache["key2"] = 2

        self.assertEquals(cache.get("key"), 1)
        self.assertEquals(cache.get("key2"), 2)

        clock.advance_time(0.9)
        self.assertEquals(cache.get("key"), None)
        self.assertEquals(cache.get("key2"), 2)

        clock.advance_time(1)
        self.assertEquals(cache.get("key"), None)
        self.assertEquals(cache.get("key2"), None)
Example #20
    def __init__(self, db_conn, hs):
        super(DeviceInboxStore, self).__init__(db_conn, hs)

        self.register_background_index_update(
            "device_inbox_stream_index",
            index_name="device_inbox_stream_id_user_id",
            table="device_inbox",
            columns=["stream_id", "user_id"],
        )

        self.register_background_update_handler(
            self.DEVICE_INBOX_STREAM_ID,
            self._background_drop_index_device_inbox,
        )

        # Map of (user_id, device_id) to the last stream_id that has been
        # deleted up to. This is so that we can no-op deletions.
        self._last_device_delete_cache = ExpiringCache(
            cache_name="last_device_delete_cache",
            clock=self._clock,
            max_len=10000,
            expiry_ms=30 * 60 * 1000,
        )
Example #21
    def __init__(self, hs, media_repo, media_storage):
        super().__init__()

        self.auth = hs.get_auth()
        self.clock = hs.get_clock()
        self.filepaths = media_repo.filepaths
        self.max_spider_size = hs.config.max_spider_size
        self.server_name = hs.hostname
        self.store = hs.get_datastore()
        self.client = SimpleHttpClient(
            hs,
            treq_args={"browser_like_redirects": True},
            ip_whitelist=hs.config.url_preview_ip_range_whitelist,
            ip_blacklist=hs.config.url_preview_ip_range_blacklist,
            http_proxy=os.getenvb(b"http_proxy"),
            https_proxy=os.getenvb(b"HTTPS_PROXY"),
        )
        self.media_repo = media_repo
        self.primary_base_path = media_repo.primary_base_path
        self.media_storage = media_storage

        self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
        self.url_preview_accept_language = hs.config.url_preview_accept_language

        # memory cache mapping urls to an ObservableDeferred returning
        # JSON-encoded OG metadata
        self._cache = ExpiringCache(
            cache_name="url_previews",
            clock=self.clock,
            # don't spider URLs more often than once an hour
            expiry_ms=60 * 60 * 1000,
        )

        self._cleaner_loop = self.clock.looping_call(
            self._start_expire_url_cache_data, 10 * 1000
        )
Example #22
    def __init__(self, hs):
        self.clock = hs.get_clock()

        # dict of set of event_ids -> _StateCacheEntry.
        self._state_cache = None
        self.resolve_linearizer = Linearizer(name="state_resolve_lock")

        self._state_cache = ExpiringCache(
            cache_name="state_cache",
            clock=self.clock,
            max_len=SIZE_OF_CACHE,
            expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
            iterable=True,
            reset_expiry_on_get=True,
        )
Example #23
    def __init__(self, hs):
        super(FederationClient, self).__init__(hs)

        self.pdu_destination_tried = {}
        self._clock.looping_call(
            self._clear_tried_cache, 60 * 1000,
        )
        self.state = hs.get_state_handler()
        self.transport_layer = hs.get_federation_transport_client()

        self._get_pdu_cache = ExpiringCache(
            cache_name="get_pdu_cache",
            clock=self._clock,
            max_len=1000,
            expiry_ms=120 * 1000,
            reset_expiry_on_get=False,
        )
Example #24
    def __init__(self, hs):
        """
        Args:
            hs (synapse.server.HomeServer):
        """
        super(AuthHandler, self).__init__(hs)
        self.checkers = {
            LoginType.PASSWORD: self._check_password_auth,
            LoginType.RECAPTCHA: self._check_recaptcha,
            LoginType.EMAIL_IDENTITY: self._check_email_identity,
            LoginType.MSISDN: self._check_msisdn,
            LoginType.DUMMY: self._check_dummy_auth,
        }
        self.bcrypt_rounds = hs.config.bcrypt_rounds

        # This is not a cache per se, but a store of all current sessions that
        # expire after N hours
        self.sessions = ExpiringCache(
            cache_name="register_sessions",
            clock=hs.get_clock(),
            expiry_ms=self.SESSION_EXPIRE_MS,
            reset_expiry_on_get=True,
        )

        account_handler = ModuleApi(hs, self)
        self.password_providers = [
            module(config=config, account_handler=account_handler)
            for module, config in hs.config.password_providers
        ]

        logger.info("Extra password_providers: %r", self.password_providers)

        self.hs = hs  # FIXME better possibility to access registrationHandler later?
        self.macaroon_gen = hs.get_macaroon_generator()
        self._password_enabled = hs.config.password_enabled

        login_types = set()
        if self._password_enabled:
            login_types.add(LoginType.PASSWORD)
        for provider in self.password_providers:
            if hasattr(provider, "get_supported_login_types"):
                login_types.update(
                    provider.get_supported_login_types().keys()
                )
        self._supported_login_types = frozenset(login_types)
Example #25
    def test_eviction(self):
        clock = MockClock()
        cache = ExpiringCache("test", clock, max_len=2)

        cache["key"] = "value"
        cache["key2"] = "value2"
        self.assertEquals(cache.get("key"), "value")
        self.assertEquals(cache.get("key2"), "value2")

        cache["key3"] = "value3"
        self.assertEquals(cache.get("key"), None)
        self.assertEquals(cache.get("key2"), "value2")
        self.assertEquals(cache.get("key3"), "value3")
Example #26
    def test_iterable_eviction(self):
        clock = MockClock()
        cache = ExpiringCache("test", clock, max_len=5, iterable=True)

        cache["key"] = [1]
        cache["key2"] = [2, 3]
        cache["key3"] = [4, 5]

        self.assertEquals(cache.get("key"), [1])
        self.assertEquals(cache.get("key2"), [2, 3])
        self.assertEquals(cache.get("key3"), [4, 5])

        cache["key4"] = [6, 7]
        self.assertEquals(cache.get("key"), None)
        self.assertEquals(cache.get("key2"), None)
        self.assertEquals(cache.get("key3"), [4, 5])
        self.assertEquals(cache.get("key4"), [6, 7])
Example #27
    def __init__(self, database: DatabasePool, db_conn, hs):
        super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs)
        self._device_inbox_id_gen = SlavedIdTracker(
            db_conn, "device_inbox", "stream_id"
        )
        self._device_inbox_stream_cache = StreamChangeCache(
            "DeviceInboxStreamChangeCache",
            self._device_inbox_id_gen.get_current_token(),
        )
        self._device_federation_outbox_stream_cache = StreamChangeCache(
            "DeviceFederationOutboxStreamChangeCache",
            self._device_inbox_id_gen.get_current_token(),
        )

        self._last_device_delete_cache = ExpiringCache(
            cache_name="last_device_delete_cache",
            clock=self._clock,
            max_len=10000,
            expiry_ms=30 * 60 * 1000,
        )
Example #28
    def __init__(self, hs, device_handler):
        self.store = hs.get_datastore()
        self.federation = hs.get_federation_client()
        self.clock = hs.get_clock()
        self.device_handler = device_handler

        self._remote_edu_linearizer = Linearizer(name="remote_device_list")

        # user_id -> list of updates waiting to be handled.
        self._pending_updates = {}

        # Recently seen stream ids. We don't bother keeping these in the DB,
        # but they're useful to have around to reduce the number of spurious
        # resyncs.
        self._seen_updates = ExpiringCache(
            cache_name="device_update_edu",
            clock=self.clock,
            max_len=10000,
            expiry_ms=30 * 60 * 1000,
            iterable=True,
        )
Example #29
    def test_iterable_eviction(self):
        clock = MockClock()
        cache = ExpiringCache("test", clock, max_len=5, iterable=True)

        cache["key"] = [1]
        cache["key2"] = [2, 3]
        cache["key3"] = [4, 5]

        self.assertEquals(cache.get("key"), [1])
        self.assertEquals(cache.get("key2"), [2, 3])
        self.assertEquals(cache.get("key3"), [4, 5])

        cache["key4"] = [6, 7]
        self.assertEquals(cache.get("key"), None)
        self.assertEquals(cache.get("key2"), None)
        self.assertEquals(cache.get("key3"), [4, 5])
        self.assertEquals(cache.get("key4"), [6, 7])
Example #30
    def __init__(self, hs):
        """
        Args:
            hs (synapse.server.HomeServer):
        """
        super(AuthHandler, self).__init__(hs)
        self.checkers = {
            LoginType.PASSWORD: self._check_password_auth,
            LoginType.RECAPTCHA: self._check_recaptcha,
            LoginType.EMAIL_IDENTITY: self._check_email_identity,
            LoginType.MSISDN: self._check_msisdn,
            LoginType.DUMMY: self._check_dummy_auth,
        }
        self.bcrypt_rounds = hs.config.bcrypt_rounds

        # This is not a cache per se, but a store of all current sessions that
        # expire after N hours
        self.sessions = ExpiringCache(
            cache_name="register_sessions",
            clock=hs.get_clock(),
            expiry_ms=self.SESSION_EXPIRE_MS,
            reset_expiry_on_get=True,
        )

        account_handler = _AccountHandler(
            hs, check_user_exists=self.check_user_exists
        )

        self.password_providers = [
            module(config=config, account_handler=account_handler)
            for module, config in hs.config.password_providers
        ]

        logger.info("Extra password_providers: %r", self.password_providers)

        self.hs = hs  # FIXME better possibility to access registrationHandler later?
        self.device_handler = hs.get_device_handler()
        self.macaroon_gen = hs.get_macaroon_generator()
Example #31
    def __init__(self, db_conn, hs):
        super(DeviceInboxStore, self).__init__(db_conn, hs)

        self.register_background_index_update(
            "device_inbox_stream_index",
            index_name="device_inbox_stream_id_user_id",
            table="device_inbox",
            columns=["stream_id", "user_id"],
        )

        self.register_background_update_handler(
            self.DEVICE_INBOX_STREAM_ID,
            self._background_drop_index_device_inbox,
        )

        # Map of (user_id, device_id) to the last stream_id that has been
        # deleted up to. This is so that we can no-op deletions.
        self._last_device_delete_cache = ExpiringCache(
            cache_name="last_device_delete_cache",
            clock=self._clock,
            max_len=10000,
            expiry_ms=30 * 60 * 1000,
        )
Example #32
class DeviceListEduUpdater(object):
    "Handles incoming device list updates from federation and updates the DB"

    def __init__(self, hs, device_handler):
        self.store = hs.get_datastore()
        self.federation = hs.get_federation_client()
        self.clock = hs.get_clock()
        self.device_handler = device_handler

        self._remote_edu_linearizer = Linearizer(name="remote_device_list")

        # user_id -> list of updates waiting to be handled.
        self._pending_updates = {}

        # Recently seen stream ids. We don't bother keeping these in the DB,
        # but they're useful to have around to reduce the number of spurious
        # resyncs.
        self._seen_updates = ExpiringCache(
            cache_name="device_update_edu",
            clock=self.clock,
            max_len=10000,
            expiry_ms=30 * 60 * 1000,
            iterable=True,
        )

    @defer.inlineCallbacks
    def incoming_device_list_update(self, origin, edu_content):
        """Called on incoming device list update from federation. Responsible
        for parsing the EDU and adding to pending updates list.
        """

        user_id = edu_content.pop("user_id")
        device_id = edu_content.pop("device_id")
        stream_id = str(edu_content.pop("stream_id"))  # They may come as ints
        prev_ids = edu_content.pop("prev_id", [])
        prev_ids = [str(p) for p in prev_ids]   # They may come as ints

        if get_domain_from_id(user_id) != origin:
            # TODO: Raise?
            logger.warning("Got device list update edu for %r from %r", user_id, origin)
            return

        room_ids = yield self.store.get_rooms_for_user(user_id)
        if not room_ids:
            # We don't share any rooms with this user. Ignore update, as we
            # probably won't get any further updates.
            return

        self._pending_updates.setdefault(user_id, []).append(
            (device_id, stream_id, prev_ids, edu_content)
        )

        yield self._handle_device_updates(user_id)

    @measure_func("_incoming_device_list_update")
    @defer.inlineCallbacks
    def _handle_device_updates(self, user_id):
        "Actually handle pending updates."

        with (yield self._remote_edu_linearizer.queue(user_id)):
            pending_updates = self._pending_updates.pop(user_id, [])
            if not pending_updates:
                # This can happen since we batch updates
                return

            # Given a list of updates we check if we need to resync. This
            # happens if we've missed updates.
            resync = yield self._need_to_do_resync(user_id, pending_updates)

            if resync:
                # Fetch all devices for the user.
                origin = get_domain_from_id(user_id)
                try:
                    result = yield self.federation.query_user_devices(origin, user_id)
                except NotRetryingDestination:
                    # TODO: Remember that we are now out of sync and try again
                    # later
                    logger.warn(
                        "Failed to handle device list update for %s,"
                        " we're not retrying the remote",
                        user_id,
                    )
                    # We abort on exceptions rather than accepting the update
                    # as otherwise synapse will 'forget' that its device list
                    # is out of date. If we bail then we will retry the resync
                    # next time we get a device list update for this user_id.
                    # This makes it more likely that the device lists will
                    # eventually become consistent.
                    return
                except FederationDeniedError as e:
                    logger.info(e)
                    return
                except Exception:
                    # TODO: Remember that we are now out of sync and try again
                    # later
                    logger.exception(
                        "Failed to handle device list update for %s", user_id
                    )
                    return

                stream_id = result["stream_id"]
                devices = result["devices"]
                yield self.store.update_remote_device_list_cache(
                    user_id, devices, stream_id,
                )
                device_ids = [device["device_id"] for device in devices]
                yield self.device_handler.notify_device_update(user_id, device_ids)
            else:
                # Simply update the single device, since we know that is the only
                # change (because of the single prev_id matching the current cache)
                for device_id, stream_id, prev_ids, content in pending_updates:
                    yield self.store.update_remote_device_list_cache_entry(
                        user_id, device_id, content, stream_id,
                    )

                yield self.device_handler.notify_device_update(
                    user_id, [device_id for device_id, _, _, _ in pending_updates]
                )

            self._seen_updates.setdefault(user_id, set()).update(
                stream_id for _, stream_id, _, _ in pending_updates
            )

    @defer.inlineCallbacks
    def _need_to_do_resync(self, user_id, updates):
        """Given a list of updates for a user figure out if we need to do a full
        resync, or whether we have enough data that we can just apply the delta.
        """
        seen_updates = self._seen_updates.get(user_id, set())

        extremity = yield self.store.get_device_list_last_stream_id_for_remote(
            user_id
        )

        stream_id_in_updates = set()  # stream_ids in updates list
        for _, stream_id, prev_ids, _ in updates:
            if not prev_ids:
                # We always do a resync if there are no previous IDs
                defer.returnValue(True)

            for prev_id in prev_ids:
                if prev_id == extremity:
                    continue
                elif prev_id in seen_updates:
                    continue
                elif prev_id in stream_id_in_updates:
                    continue
                else:
                    defer.returnValue(True)

            stream_id_in_updates.add(stream_id)

        defer.returnValue(False)
Example #33
class FederationClient(FederationBase):
    def __init__(self, hs):
        super(FederationClient, self).__init__(hs)

        self.pdu_destination_tried = {}
        self._clock.looping_call(self._clear_tried_cache, 60 * 1000)
        self.state = hs.get_state_handler()
        self.transport_layer = hs.get_federation_transport_client()

        self.hostname = hs.hostname
        self.signing_key = hs.config.signing_key[0]

        self._get_pdu_cache = ExpiringCache(
            cache_name="get_pdu_cache",
            clock=self._clock,
            max_len=1000,
            expiry_ms=120 * 1000,
            reset_expiry_on_get=False,
        )

    def _clear_tried_cache(self):
        """Clear pdu_destination_tried cache"""
        now = self._clock.time_msec()

        old_dict = self.pdu_destination_tried
        self.pdu_destination_tried = {}

        for event_id, destination_dict in old_dict.items():
            destination_dict = {
                dest: time
                for dest, time in destination_dict.items()
                if time + PDU_RETRY_TIME_MS > now
            }
            if destination_dict:
                self.pdu_destination_tried[event_id] = destination_dict

    @log_function
    def make_query(
        self,
        destination,
        query_type,
        args,
        retry_on_dns_fail=False,
        ignore_backoff=False,
    ):
        """Sends a federation Query to a remote homeserver of the given type
        and arguments.

        Args:
            destination (str): Domain name of the remote homeserver
            query_type (str): Category of the query type; should match the
                handler name used in register_query_handler().
            args (dict): Mapping of strings to strings containing the details
                of the query request.
            ignore_backoff (bool): true to ignore the historical backoff data
                and try the request anyway.

        Returns:
            a Deferred which will eventually yield a JSON object from the
            response
        """
        sent_queries_counter.labels(query_type).inc()

        return self.transport_layer.make_query(
            destination,
            query_type,
            args,
            retry_on_dns_fail=retry_on_dns_fail,
            ignore_backoff=ignore_backoff,
        )

    @log_function
    def query_client_keys(self, destination, content, timeout):
        """Query device keys for a device hosted on a remote server.

        Args:
            destination (str): Domain name of the remote homeserver
            content (dict): The query content.

        Returns:
            a Deferred which will eventually yield a JSON object from the
            response
        """
        sent_queries_counter.labels("client_device_keys").inc()
        return self.transport_layer.query_client_keys(destination, content, timeout)

    @log_function
    def query_user_devices(self, destination, user_id, timeout=30000):
        """Query the device keys for a list of user ids hosted on a remote
        server.
        """
        sent_queries_counter.labels("user_devices").inc()
        return self.transport_layer.query_user_devices(destination, user_id, timeout)

    @log_function
    def claim_client_keys(self, destination, content, timeout):
        """Claims one-time keys for a device hosted on a remote server.

        Args:
            destination (str): Domain name of the remote homeserver
            content (dict): The query content.

        Returns:
            a Deferred which will eventually yield a JSON object from the
            response
        """
        sent_queries_counter.labels("client_one_time_keys").inc()
        return self.transport_layer.claim_client_keys(destination, content, timeout)

    @defer.inlineCallbacks
    @log_function
    def backfill(self, dest, room_id, limit, extremities):
        """Requests some more historic PDUs for the given context from the
        given destination server.

        Args:
            dest (str): The remote home server to ask.
            room_id (str): The room_id to backfill.
            limit (int): The maximum number of PDUs to return.
            extremities (list): List of PDU id and origins of the first pdus
                we have seen from the context

        Returns:
            Deferred: Results in the received PDUs.
        """
        logger.debug("backfill extrem=%s", extremities)

        # If there are no extremities then we've (probably) reached the start.
        if not extremities:
            return

        transaction_data = yield self.transport_layer.backfill(
            dest, room_id, extremities, limit
        )

        logger.debug("backfill transaction_data=%s", repr(transaction_data))

        room_version = yield self.store.get_room_version(room_id)
        format_ver = room_version_to_event_format(room_version)

        pdus = [
            event_from_pdu_json(p, format_ver, outlier=False)
            for p in transaction_data["pdus"]
        ]

        # FIXME: We should handle signature failures more gracefully.
        pdus[:] = yield make_deferred_yieldable(
            defer.gatherResults(
                self._check_sigs_and_hashes(room_version, pdus), consumeErrors=True
            ).addErrback(unwrapFirstError)
        )

        return pdus

    @defer.inlineCallbacks
    @log_function
    def get_pdu(
        self, destinations, event_id, room_version, outlier=False, timeout=None
    ):
        """Requests the PDU with given origin and ID from the remote home
        servers.

        Will attempt to get the PDU from each destination in the list until
        one succeeds.

        Args:
            destinations (list): Which home servers to query
            event_id (str): event to fetch
            room_version (str): version of the room
            outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
                it's from an arbitrary point in the context as opposed to part
                of the current block of PDUs. Defaults to `False`
            timeout (int): How long to try (in ms) each destination for before
                moving to the next destination. None indicates no timeout.

        Returns:
            Deferred: Results in the requested PDU, or None if we were unable to find
               it.
        """

        # TODO: Rate limit the number of times we try and get the same event.

        ev = self._get_pdu_cache.get(event_id)
        if ev:
            return ev

        pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})

        format_ver = room_version_to_event_format(room_version)

        signed_pdu = None
        for destination in destinations:
            now = self._clock.time_msec()
            last_attempt = pdu_attempts.get(destination, 0)
            if last_attempt + PDU_RETRY_TIME_MS > now:
                continue

            try:
                transaction_data = yield self.transport_layer.get_event(
                    destination, event_id, timeout=timeout
                )

                logger.debug(
                    "retrieved event id %s from %s: %r",
                    event_id,
                    destination,
                    transaction_data,
                )

                pdu_list = [
                    event_from_pdu_json(p, format_ver, outlier=outlier)
                    for p in transaction_data["pdus"]
                ]

                if pdu_list and pdu_list[0]:
                    pdu = pdu_list[0]

                    # Check signatures are correct.
                    signed_pdu = yield self._check_sigs_and_hash(room_version, pdu)

                    break

                pdu_attempts[destination] = now

            except SynapseError as e:
                logger.info(
                    "Failed to get PDU %s from %s because %s", event_id, destination, e
                )
                continue
            except NotRetryingDestination as e:
                logger.info(str(e))
                continue
            except FederationDeniedError as e:
                logger.info(str(e))
                continue
            except Exception as e:
                pdu_attempts[destination] = now

                logger.info(
                    "Failed to get PDU %s from %s because %s", event_id, destination, e
                )
                continue

        if signed_pdu:
            self._get_pdu_cache[event_id] = signed_pdu

        return signed_pdu

    @defer.inlineCallbacks
    @log_function
    def get_state_for_room(self, destination, room_id, event_id):
        """Requests all of the room state at a given event from a remote home server.

        Args:
            destination (str): The remote homeserver to query for the state.
            room_id (str): The id of the room we're interested in.
            event_id (str): The id of the event we want the state at.

        Returns:
            Deferred[Tuple[List[EventBase], List[EventBase]]]:
                A list of events in the state, and a list of events in the auth chain
                for the given event.
        """
        try:
            # First we try and ask for just the IDs, as that's far quicker if
            # we have most of the state and auth_chain already.
            # However, this may 404 if the other side has an old synapse.
            result = yield self.transport_layer.get_room_state_ids(
                destination, room_id, event_id=event_id
            )

            state_event_ids = result["pdu_ids"]
            auth_event_ids = result.get("auth_chain_ids", [])

            fetched_events, failed_to_fetch = yield self.get_events_from_store_or_dest(
                destination, room_id, set(state_event_ids + auth_event_ids)
            )

            if failed_to_fetch:
                logger.warning(
                    "Failed to fetch missing state/auth events for %s: %s",
                    room_id,
                    failed_to_fetch,
                )

            event_map = {ev.event_id: ev for ev in fetched_events}

            pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
            auth_chain = [
                event_map[e_id] for e_id in auth_event_ids if e_id in event_map
            ]

            auth_chain.sort(key=lambda e: e.depth)

            return pdus, auth_chain
        except HttpResponseException as e:
            if e.code == 400 or e.code == 404:
                logger.info("Failed to use get_room_state_ids API, falling back")
            else:
                raise e

        result = yield self.transport_layer.get_room_state(
            destination, room_id, event_id=event_id
        )

        room_version = yield self.store.get_room_version(room_id)
        format_ver = room_version_to_event_format(room_version)

        pdus = [
            event_from_pdu_json(p, format_ver, outlier=True) for p in result["pdus"]
        ]

        auth_chain = [
            event_from_pdu_json(p, format_ver, outlier=True)
            for p in result.get("auth_chain", [])
        ]

        seen_events = yield self.store.get_events(
            [ev.event_id for ev in itertools.chain(pdus, auth_chain)]
        )

        signed_pdus = yield self._check_sigs_and_hash_and_fetch(
            destination,
            [p for p in pdus if p.event_id not in seen_events],
            outlier=True,
            room_version=room_version,
        )
        signed_pdus.extend(
            seen_events[p.event_id] for p in pdus if p.event_id in seen_events
        )

        signed_auth = yield self._check_sigs_and_hash_and_fetch(
            destination,
            [p for p in auth_chain if p.event_id not in seen_events],
            outlier=True,
            room_version=room_version,
        )
        signed_auth.extend(
            seen_events[p.event_id] for p in auth_chain if p.event_id in seen_events
        )

        signed_auth.sort(key=lambda e: e.depth)

        return signed_pdus, signed_auth

    @defer.inlineCallbacks
    def get_events_from_store_or_dest(self, destination, room_id, event_ids):
        """Fetch events from a remote destination, checking if we already have them.

        Args:
            destination (str)
            room_id (str)
            event_ids (list)

        Returns:
            Deferred: A deferred resolving to a 2-tuple where the first is a list of
            events and the second is a list of event ids that we failed to fetch.
        """
        seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
        signed_events = list(seen_events.values())

        failed_to_fetch = set()

        missing_events = set(event_ids)
        for k in seen_events:
            missing_events.discard(k)

        if not missing_events:
            return signed_events, failed_to_fetch

        logger.debug(
            "Fetching unknown state/auth events %s for room %s",
            missing_events,
            event_ids,
        )

        room_version = yield self.store.get_room_version(room_id)

        batch_size = 20
        missing_events = list(missing_events)
        for i in range(0, len(missing_events), batch_size):
            batch = set(missing_events[i : i + batch_size])

            deferreds = [
                run_in_background(
                    self.get_pdu,
                    destinations=[destination],
                    event_id=e_id,
                    room_version=room_version,
                )
                for e_id in batch
            ]

            res = yield make_deferred_yieldable(
                defer.DeferredList(deferreds, consumeErrors=True)
            )
            for success, result in res:
                if success and result:
                    signed_events.append(result)
                    batch.discard(result.event_id)

            # We removed all events we successfully fetched from `batch`
            failed_to_fetch.update(batch)

        return signed_events, failed_to_fetch

    @defer.inlineCallbacks
    @log_function
    def get_event_auth(self, destination, room_id, event_id):
        res = yield self.transport_layer.get_event_auth(destination, room_id, event_id)

        room_version = yield self.store.get_room_version(room_id)
        format_ver = room_version_to_event_format(room_version)

        auth_chain = [
            event_from_pdu_json(p, format_ver, outlier=True) for p in res["auth_chain"]
        ]

        signed_auth = yield self._check_sigs_and_hash_and_fetch(
            destination, auth_chain, outlier=True, room_version=room_version
        )

        signed_auth.sort(key=lambda e: e.depth)

        return signed_auth

    @defer.inlineCallbacks
    def _try_destination_list(self, description, destinations, callback):
        """Try an operation on a series of servers, until it succeeds

        Args:
            description (unicode): description of the operation we're doing, for logging

            destinations (Iterable[unicode]): list of server_names to try

            callback (callable):  Function to run for each server. Passed a single
                argument: the server_name to try. May return a deferred.

                If the callback raises a CodeMessageException with a 300/400 code,
                attempts to perform the operation stop immediately and the exception is
                reraised.

                Otherwise, if the callback raises an Exception the error is logged and the
                next server tried. Normally the stacktrace is logged but this is
                suppressed if the exception is an InvalidResponseError.

        Returns:
            The [Deferred] result of callback, if it succeeds

        Raises:
            SynapseError if the chosen remote server returns a 300/400 code, or
            no servers were reachable.
        """
        for destination in destinations:
            if destination == self.server_name:
                continue

            try:
                res = yield callback(destination)
                return res
            except InvalidResponseError as e:
                logger.warn("Failed to %s via %s: %s", description, destination, e)
            except HttpResponseException as e:
                if not 500 <= e.code < 600:
                    raise e.to_synapse_error()
                else:
                    logger.warn(
                        "Failed to %s via %s: %i %s",
                        description,
                        destination,
                        e.code,
                        e.args[0],
                    )
            except Exception:
                logger.warn("Failed to %s via %s", description, destination, exc_info=1)

        raise SynapseError(502, "Failed to %s via any server" % (description,))

    def make_membership_event(
        self, destinations, room_id, user_id, membership, content, params
    ):
        """
        Creates an m.room.member event, with context, without participating in the room.

        Does so by asking one of the already participating servers to create an
        event with proper context.

        Returns a fully signed and hashed event.

        Note that this does not append any events to any graphs.

        Args:
            destinations (str): Candidate homeservers which are probably
                participating in the room.
            room_id (str): The room in which the event will happen.
            user_id (str): The user whose membership is being evented.
            membership (str): The "membership" property of the event. Must be
                one of "join" or "leave".
            content (dict): Any additional data to put into the content field
                of the event.
            params (dict[str, str|Iterable[str]]): Query parameters to include in the
                request.
        Return:
            Deferred[tuple[str, FrozenEvent, int]]: resolves to a tuple of
            `(origin, event, event_format)` where origin is the remote
            homeserver which generated the event, and event_format is one of
            `synapse.api.room_versions.EventFormatVersions`.

            Fails with a ``SynapseError`` if the chosen remote server
            returns a 300/400 code.

            Fails with a ``RuntimeError`` if no servers were reachable.
        """
        valid_memberships = {Membership.JOIN, Membership.LEAVE}
        if membership not in valid_memberships:
            raise RuntimeError(
                "make_membership_event called with membership='%s', must be one of %s"
                % (membership, ",".join(valid_memberships))
            )

        @defer.inlineCallbacks
        def send_request(destination):
            ret = yield self.transport_layer.make_membership_event(
                destination, room_id, user_id, membership, params
            )

            # Note: If not supplied, the room version may be either v1 or v2,
            # however either way the event format version will be v1.
            room_version = ret.get("room_version", RoomVersions.V1.identifier)
            event_format = room_version_to_event_format(room_version)

            pdu_dict = ret.get("event", None)
            if not isinstance(pdu_dict, dict):
                raise InvalidResponseError("Bad 'event' field in response")

            logger.debug("Got response to make_%s: %s", membership, pdu_dict)

            pdu_dict["content"].update(content)

            # The protoevent received over the JSON wire may not have all
            # the required fields. Let's just gloss over that because
            # there are some we never care about
            if "prev_state" not in pdu_dict:
                pdu_dict["prev_state"] = []

            ev = builder.create_local_event_from_event_dict(
                self._clock,
                self.hostname,
                self.signing_key,
                format_version=event_format,
                event_dict=pdu_dict,
            )

            return (destination, ev, event_format)

        return self._try_destination_list(
            "make_" + membership, destinations, send_request
        )

    def send_join(self, destinations, pdu, event_format_version):
        """Sends a join event to one of a list of homeservers.

        Doing so will cause the remote server to add the event to the graph,
        and send the event out to the rest of the federation.

        Args:
            destinations (str): Candidate homeservers which are probably
                participating in the room.
            pdu (BaseEvent): event to be sent
            event_format_version (int): The event format version

        Return:
            Deferred: resolves to a dict with members ``origin`` (a string
            giving the server the event was sent to), ``state`` (?) and
            ``auth_chain``.

            Fails with a ``SynapseError`` if the chosen remote server
            returns a 300/400 code.

            Fails with a ``RuntimeError`` if no servers were reachable.
        """

        def check_authchain_validity(signed_auth_chain):
            for e in signed_auth_chain:
                if e.type == EventTypes.Create:
                    create_event = e
                    break
            else:
                raise InvalidResponseError("no %s in auth chain" % (EventTypes.Create,))

            # the room version should be sane.
            room_version = create_event.content.get("room_version", "1")
            if room_version not in KNOWN_ROOM_VERSIONS:
                # This shouldn't be possible, because the remote server should have
                # rejected the join attempt during make_join.
                raise InvalidResponseError(
                    "room appears to have unsupported version %s" % (room_version,)
                )

        @defer.inlineCallbacks
        def send_request(destination):
            time_now = self._clock.time_msec()
            _, content = yield self.transport_layer.send_join(
                destination=destination,
                room_id=pdu.room_id,
                event_id=pdu.event_id,
                content=pdu.get_pdu_json(time_now),
            )

            logger.debug("Got content: %s", content)

            state = [
                event_from_pdu_json(p, event_format_version, outlier=True)
                for p in content.get("state", [])
            ]

            auth_chain = [
                event_from_pdu_json(p, event_format_version, outlier=True)
                for p in content.get("auth_chain", [])
            ]

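            # De-duplicate events: the state and auth chain may overlap, and
            # each event only needs its signatures and hashes checked once.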
            pdus = {p.event_id: p for p in itertools.chain(state, auth_chain)}

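            # Work out the room version from the create event in the returned
            # state, so we know how to check signatures and hashes below.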
            room_version = None
            for e in state:
                if (e.type, e.state_key) == (EventTypes.Create, ""):
                    room_version = e.content.get(
                        "room_version", RoomVersions.V1.identifier
                    )
                    break

            if room_version is None:
                # If the state doesn't have a create event then the room is
                # invalid, and it would fail auth checks anyway.
                raise SynapseError(400, "No create event in state")

            valid_pdus = yield self._check_sigs_and_hash_and_fetch(
                destination,
                list(pdus.values()),
                outlier=True,
                room_version=room_version,
            )

            valid_pdus_map = {p.event_id: p for p in valid_pdus}

            # NB: We *need* to copy to ensure that we don't have multiple
            # references being passed on, as that causes... issues.
            signed_state = [
                copy.copy(valid_pdus_map[p.event_id])
                for p in state
                if p.event_id in valid_pdus_map
            ]

            signed_auth = [
                valid_pdus_map[p.event_id]
                for p in auth_chain
                if p.event_id in valid_pdus_map
            ]

            # NB: We *need* to copy to ensure that we don't have multiple
            # references being passed on, as that causes... issues.
            for s in signed_state:
                s.internal_metadata = copy.deepcopy(s.internal_metadata)

            check_authchain_validity(signed_auth)

            return {
                "state": signed_state,
                "auth_chain": signed_auth,
                "origin": destination,
            }

        return self._try_destination_list("send_join", destinations, send_request)

    @defer.inlineCallbacks
    def send_invite(self, destination, room_id, event_id, pdu):
        room_version = yield self.store.get_room_version(room_id)

        content = yield self._do_send_invite(destination, pdu, room_version)

        pdu_dict = content["event"]

        logger.debug("Got response to send_invite: %s", pdu_dict)

        room_version = yield self.store.get_room_version(room_id)
        format_ver = room_version_to_event_format(room_version)

        pdu = event_from_pdu_json(pdu_dict, format_ver)

        # Check signatures are correct.
        pdu = yield self._check_sigs_and_hash(room_version, pdu)

        # FIXME: We should handle signature failures more gracefully.

        return pdu

    @defer.inlineCallbacks
    def _do_send_invite(self, destination, pdu, room_version):
        """Actually sends the invite, first trying v2 API and falling back to
        v1 API if necessary.

        Args:
            destination (str): Target server
            pdu (FrozenEvent)
            room_version (str)

        Returns:
            dict: The event as a dict as returned by the remote server
        """
        time_now = self._clock.time_msec()

        try:
            content = yield self.transport_layer.send_invite_v2(
                destination=destination,
                room_id=pdu.room_id,
                event_id=pdu.event_id,
                content={
                    "event": pdu.get_pdu_json(time_now),
                    "room_version": room_version,
                    "invite_room_state": pdu.unsigned.get("invite_room_state", []),
                },
            )
            return content
        except HttpResponseException as e:
            if e.code in [400, 404]:
                err = e.to_synapse_error()

                # If we receive an error response that isn't a generic error, we
                # assume that the remote understands the v2 invite API and this
                # is a legitimate error.
                if err.errcode != Codes.UNKNOWN:
                    raise err

                # Otherwise, we assume that the remote server doesn't understand
                # the v2 invite API. That's ok provided the room uses old-style event
                # IDs.
                v = KNOWN_ROOM_VERSIONS.get(room_version)
                if v.event_format != EventFormatVersions.V1:
                    raise SynapseError(
                        400,
                        "User's homeserver does not support this room version",
                        Codes.UNSUPPORTED_ROOM_VERSION,
                    )
            elif e.code == 403:
                raise e.to_synapse_error()
            else:
                raise

        # Didn't work, try v1 API.
        # Note the v1 API returns a tuple of `(200, content)`

        _, content = yield self.transport_layer.send_invite_v1(
            destination=destination,
            room_id=pdu.room_id,
            event_id=pdu.event_id,
            content=pdu.get_pdu_json(time_now),
        )
        return content

    def send_leave(self, destinations, pdu):
        """Sends a leave event to one of a list of homeservers.

        Doing so will cause the remote server to add the event to the graph,
        and send the event out to the rest of the federation.

        This is mostly useful to reject received invites.

        Args:
            destinations (Iterable[str]): Candidate homeservers which are probably
                participating in the room.
            pdu (BaseEvent): event to be sent

        Return:
            Deferred: resolves to None.

            Fails with a ``SynapseError`` if the chosen remote server
            returns a 300/400 code.

            Fails with a ``RuntimeError`` if no servers were reachable.
        """

        @defer.inlineCallbacks
        def send_request(destination):
            time_now = self._clock.time_msec()
            _, content = yield self.transport_layer.send_leave(
                destination=destination,
                room_id=pdu.room_id,
                event_id=pdu.event_id,
                content=pdu.get_pdu_json(time_now),
            )

            logger.debug("Got content: %s", content)
            return None

        return self._try_destination_list("send_leave", destinations, send_request)

    def get_public_rooms(
        self,
        destination,
        limit=None,
        since_token=None,
        search_filter=None,
        include_all_networks=False,
        third_party_instance_id=None,
    ):
        if destination == self.server_name:
            return

        return self.transport_layer.get_public_rooms(
            destination,
            limit,
            since_token,
            search_filter,
            include_all_networks=include_all_networks,
            third_party_instance_id=third_party_instance_id,
        )

    @defer.inlineCallbacks
    def query_auth(self, destination, room_id, event_id, local_auth):
        """
        Params:
            destination (str)
            room_id (str)
            event_id (str)
            local_auth (list)
        """
        time_now = self._clock.time_msec()

        send_content = {"auth_chain": [e.get_pdu_json(time_now) for e in local_auth]}

        code, content = yield self.transport_layer.send_query_auth(
            destination=destination,
            room_id=room_id,
            event_id=event_id,
            content=send_content,
        )

        room_version = yield self.store.get_room_version(room_id)
        format_ver = room_version_to_event_format(room_version)

        auth_chain = [event_from_pdu_json(e, format_ver) for e in content["auth_chain"]]

        signed_auth = yield self._check_sigs_and_hash_and_fetch(
            destination, auth_chain, outlier=True, room_version=room_version
        )

        signed_auth.sort(key=lambda e: e.depth)

        ret = {
            "auth_chain": signed_auth,
            "rejects": content.get("rejects", []),
            "missing": content.get("missing", []),
        }

        return ret

    @defer.inlineCallbacks
    def get_missing_events(
        self,
        destination,
        room_id,
        earliest_events_ids,
        latest_events,
        limit,
        min_depth,
        timeout,
    ):
        """Tries to fetch events we are missing. This is called when we receive
        an event without having received all of its ancestors.

        Args:
            destination (str)
            room_id (str)
            earliest_events_ids (list): List of event ids. Effectively the
                events we expected to receive, but haven't. `get_missing_events`
                should only return events that didn't happen before these.
            latest_events (list): List of events we have received that we don't
                have all previous events for.
            limit (int): Maximum number of events to return.
            min_depth (int): Minimum depth of events to return.
            timeout (int): Max time to wait in ms
        """
        try:
            content = yield self.transport_layer.get_missing_events(
                destination=destination,
                room_id=room_id,
                earliest_events=earliest_events_ids,
                latest_events=[e.event_id for e in latest_events],
                limit=limit,
                min_depth=min_depth,
                timeout=timeout,
            )

            room_version = yield self.store.get_room_version(room_id)
            format_ver = room_version_to_event_format(room_version)

            events = [
                event_from_pdu_json(e, format_ver) for e in content.get("events", [])
            ]

            signed_events = yield self._check_sigs_and_hash_and_fetch(
                destination, events, outlier=False, room_version=room_version
            )
        except HttpResponseException as e:
            if e.code != 400:
                raise

            # We are probably hitting an old server that doesn't support
            # get_missing_events
            signed_events = []

        return signed_events

    @defer.inlineCallbacks
    def forward_third_party_invite(self, destinations, room_id, event_dict):
        for destination in destinations:
            if destination == self.server_name:
                continue

            try:
                yield self.transport_layer.exchange_third_party_invite(
                    destination=destination, room_id=room_id, event_dict=event_dict
                )
                return None
            except CodeMessageException:
                raise
            except Exception as e:
                logger.exception(
                    "Failed to send_third_party_invite via %s: %s", destination, str(e)
                )

        raise RuntimeError("Failed to send to any server.")

    @defer.inlineCallbacks
    def get_room_complexity(self, destination, room_id):
        """
        Fetch the complexity of a remote room from another server.

        Args:
            destination (str): The remote server
            room_id (str): The room ID to ask about.

        Returns:
            Deferred[dict] or Deferred[None]: Dict contains the complexity
            metric versions, while None means we could not fetch the complexity.
        """
        try:
            complexity = yield self.transport_layer.get_room_complexity(
                destination=destination, room_id=room_id
            )
            defer.returnValue(complexity)
        except CodeMessageException as e:
            # We didn't manage to get it -- probably a 404. We are okay if other
            # servers don't give it to us.
            logger.debug(
                "Failed to fetch room complexity via %s for %s, got a %d",
                destination,
                room_id,
                e.code,
            )
        except Exception:
            logger.exception(
                "Failed to fetch room complexity via %s for %s", destination, room_id
            )

        # If we don't manage to find it, return None. It's not an error if a
        # server doesn't give it to us.
        defer.returnValue(None)
Example #34
0
class PreviewUrlResource(Resource):
    isLeaf = True

    def __init__(self, hs, media_repo):
        Resource.__init__(self)

        self.auth = hs.get_auth()
        self.clock = hs.get_clock()
        self.version_string = hs.version_string
        self.filepaths = media_repo.filepaths
        self.max_spider_size = hs.config.max_spider_size
        self.server_name = hs.hostname
        self.store = hs.get_datastore()
        self.client = SpiderHttpClient(hs)
        self.media_repo = media_repo

        self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist

        # simple memory cache mapping urls to OG metadata
        self.cache = ExpiringCache(
            cache_name="url_previews",
            clock=self.clock,
            # don't spider URLs more often than once an hour
            expiry_ms=60 * 60 * 1000,
        )
        self.cache.start()

        self.downloads = {}

    def render_GET(self, request):
        self._async_render_GET(request)
        return NOT_DONE_YET

    @request_handler()
    @defer.inlineCallbacks
    def _async_render_GET(self, request):

        # XXX: if get_user_by_req fails, what should we do in an async render?
        requester = yield self.auth.get_user_by_req(request)
        url = request.args.get("url")[0]
        if "ts" in request.args:
            ts = int(request.args.get("ts")[0])
        else:
            ts = self.clock.time_msec()

        url_tuple = urlparse.urlsplit(url)
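        # Each blacklist entry maps urlsplit() attribute names (e.g. 'netloc',
        # 'path') to patterns; a URL is blocked only if every attribute listed
        # in the entry matches ('^...' patterns are regexes, anything else is
        # treated as a glob).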
        for entry in self.url_preview_url_blacklist:
            match = True
            for attrib in entry:
                pattern = entry[attrib]
                value = getattr(url_tuple, attrib)
                logger.debug((
                    "Matching attrib '%s' with value '%s' against"
                    " pattern '%s'"
                ) % (attrib, value, pattern))

                if value is None:
                    match = False
                    continue

                if pattern.startswith('^'):
                    if not re.match(pattern, getattr(url_tuple, attrib)):
                        match = False
                        continue
                else:
                    if not fnmatch.fnmatch(getattr(url_tuple, attrib), pattern):
                        match = False
                        continue
            if match:
                logger.warn(
                    "URL %s blocked by url_blacklist entry %s", url, entry
                )
                raise SynapseError(
                    403, "URL blocked by url pattern blacklist entry",
                    Codes.UNKNOWN
                )

        # first check the memory cache - good to handle all the clients on this
        # HS thundering away to preview the same URL at the same time.
        og = self.cache.get(url)
        if og:
            respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
            return

        # then check the URL cache in the DB (which will also provide us with
        # historical previews, if we have any)
        cache_result = yield self.store.get_url_cache(url, ts)
        if (
            cache_result and
            cache_result["download_ts"] + cache_result["expires"] > ts and
            cache_result["response_code"] / 100 == 2
        ):
            respond_with_json_bytes(
                request, 200, cache_result["og"].encode('utf-8'),
                send_cors=True
            )
            return

        # Ensure only one download for a given URL is active at a time
        download = self.downloads.get(url)
        if download is None:
            download = self._download_url(url, requester.user)
            download = ObservableDeferred(
                download,
                consumeErrors=True
            )
            self.downloads[url] = download

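            # Every concurrent request for this URL observes the same
            # ObservableDeferred; the callback below removes the entry once the
            # download finishes (or fails), so a later request will re-fetch.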
            @download.addBoth
            def callback(media_info):
                del self.downloads[url]
                return media_info
        media_info = yield download.observe()

        # FIXME: we should probably update our cache now anyway, so that
        # even if the OG calculation raises, we don't keep hammering on the
        # remote server.  For now, leave it uncached to aid debugging OG
        # calculation problems

        logger.debug("got media_info of '%s'" % media_info)

        if self._is_media(media_info['media_type']):
            dims = yield self.media_repo._generate_local_thumbnails(
                media_info['filesystem_id'], media_info
            )

            og = {
                "og:description": media_info['download_name'],
                "og:image": "mxc://%s/%s" % (
                    self.server_name, media_info['filesystem_id']
                ),
                "og:image:type": media_info['media_type'],
                "matrix:image:size": media_info['media_length'],
            }

            if dims:
                og["og:image:width"] = dims['width']
                og["og:image:height"] = dims['height']
            else:
                logger.warn("Couldn't get dims for %s" % url)

            # define our OG response for this media
        elif self._is_html(media_info['media_type']):
            # TODO: somehow stop a big HTML tree from exploding synapse's RAM

            from lxml import etree

            with open(media_info['filename']) as f:
                body = f.read()

            # clobber the encoding from the content-type, or default to utf-8
            # XXX: this overrides any <meta/> or XML charset headers in the body
            # which may pose problems, but so far seems to work okay.
            match = re.match(r'.*; *charset=(.*?)(;|$)', media_info['media_type'], re.I)
            encoding = match.group(1) if match else "utf-8"

            try:
                parser = etree.HTMLParser(recover=True, encoding=encoding)
                tree = etree.fromstring(body, parser)
                og = yield self._calc_og(tree, media_info, requester)
            except UnicodeDecodeError:
                # blindly try decoding the body as utf-8, which seems to fix
                # the charset mismatches on https://google.com
                parser = etree.HTMLParser(recover=True, encoding=encoding)
                tree = etree.fromstring(body.decode('utf-8', 'ignore'), parser)
                og = yield self._calc_og(tree, media_info, requester)

        else:
            logger.warn("Failed to find any OG data in %s", url)
            og = {}

        logger.debug("Calculated OG for %s as %s" % (url, og))

        # store OG in ephemeral in-memory cache
        self.cache[url] = og

        # store OG in history-aware DB cache
        yield self.store.store_url_cache(
            url,
            media_info["response_code"],
            media_info["etag"],
            media_info["expires"],
            json.dumps(og),
            media_info["filesystem_id"],
            media_info["created_ts"],
        )

        respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)

    @defer.inlineCallbacks
    def _calc_og(self, tree, media_info, requester):
        # suck our tree into lxml and define our OG response.

        # if we see any image URLs in the OG response, then spider them
        # (although the client could choose to do this by asking for previews of those
        # URLs to avoid DoSing the server)

        # "og:type"         : "video",
        # "og:url"          : "https://www.youtube.com/watch?v=LXDBoHyjmtw",
        # "og:site_name"    : "YouTube",
        # "og:video:type"   : "application/x-shockwave-flash",
        # "og:description"  : "Fun stuff happening here",
        # "og:title"        : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon",
        # "og:image"        : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg",
        # "og:video:url"    : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1",
        # "og:video:width"  : "1280"
        # "og:video:height" : "720",
        # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",

        og = {}
        for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
            if 'content' in tag.attrib:
                og[tag.attrib['property']] = tag.attrib['content']

        # TODO: grab article: meta tags too, e.g.:

        # "article:publisher" : "https://www.facebook.com/thethudonline" />
        # "article:author" content="https://www.facebook.com/thethudonline" />
        # "article:tag" content="baby" />
        # "article:section" content="Breaking News" />
        # "article:published_time" content="2016-03-31T19:58:24+00:00" />
        # "article:modified_time" content="2016-04-01T18:31:53+00:00" />

        if 'og:title' not in og:
            # do some basic spidering of the HTML
            title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
            og['og:title'] = title[0].text.strip() if title else None

        if 'og:image' not in og:
            # TODO: extract a favicon failing all else
            meta_image = tree.xpath(
                "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
            )
            if meta_image:
                og['og:image'] = self._rebase_url(meta_image[0], media_info['uri'])
            else:
                # TODO: consider inlined CSS styles as well as width & height attribs
                images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
                images = sorted(images, key=lambda i: (
                    -1 * float(i.attrib['width']) * float(i.attrib['height'])
                ))
                if not images:
                    images = tree.xpath("//img[@src]")
                if images:
                    og['og:image'] = images[0].attrib['src']

        # pre-cache the image for posterity
        # FIXME: it might be cleaner to use the same flow as the main /preview_url
        # request itself and benefit from the same caching etc.  But for now we
        # just rely on the caching on the master request to speed things up.
        if 'og:image' in og and og['og:image']:
            image_info = yield self._download_url(
                self._rebase_url(og['og:image'], media_info['uri']), requester.user
            )

            if self._is_media(image_info['media_type']):
                # TODO: make sure we don't choke on white-on-transparent images
                dims = yield self.media_repo._generate_local_thumbnails(
                    image_info['filesystem_id'], image_info
                )
                if dims:
                    og["og:image:width"] = dims['width']
                    og["og:image:height"] = dims['height']
                else:
                    logger.warn("Couldn't get dims for %s" % og["og:image"])

                og["og:image"] = "mxc://%s/%s" % (
                    self.server_name, image_info['filesystem_id']
                )
                og["og:image:type"] = image_info['media_type']
                og["matrix:image:size"] = image_info['media_length']
            else:
                del og["og:image"]

        if 'og:description' not in og:
            meta_description = tree.xpath(
                "//*/meta"
                "[translate(@name, 'DESCRIPTION', 'description')='description']"
                "/@content")
            if meta_description:
                og['og:description'] = meta_description[0]
            else:
                # grab any text nodes which are inside the <body/> tag...
                # unless they are within an HTML5 semantic markup tag...
                # <header/>, <nav/>, <aside/>, <footer/>
                # ...or if they are within a <script/> or <style/> tag.
                # This is a very very very coarse approximation to a plain text
                # render of the page.
                text_nodes = tree.xpath("//text()[not(ancestor::header | ancestor::nav | "
                                        "ancestor::aside | ancestor::footer | "
                                        "ancestor::script | ancestor::style)]" +
                                        "[ancestor::body]")
                text = ''
                for text_node in text_nodes:
                    if len(text) < 500:
                        text += text_node + ' '
                    else:
                        break
                text = re.sub(r'[\t ]+', ' ', text)
                text = re.sub(r'[\t \r\n]*[\r\n]+', '\n', text)
                text = text.strip()[:500]
                og['og:description'] = text if text else None

        # TODO: delete the url downloads to stop diskfilling,
        # as we only ever cared about its OG
        defer.returnValue(og)

    def _rebase_url(self, url, base):
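        # Resolve a possibly relative URL against the page it came from,
        # copying any missing scheme, host or leading path from `base`.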
        base = list(urlparse.urlparse(base))
        url = list(urlparse.urlparse(url))
        if not url[0]:  # fix up schema
            url[0] = base[0] or "http"
        if not url[1]:  # fix up hostname
            url[1] = base[1]
            if not url[2].startswith('/'):
                url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2]
        return urlparse.urlunparse(url)

    @defer.inlineCallbacks
    def _download_url(self, url, user):
        # TODO: we should probably honour robots.txt... except in practice
        # we're most likely being explicitly triggered by a human rather than a
        # bot, so are we really a robot?

        # XXX: horrible duplication with base_resource's _download_remote_file()
        file_id = random_string(24)

        fname = self.filepaths.local_media_filepath(file_id)
        self.media_repo._makedirs(fname)

        try:
            with open(fname, "wb") as f:
                logger.debug("Trying to get url '%s'" % url)
                length, headers, uri, code = yield self.client.get_file(
                    url, output_stream=f, max_size=self.max_spider_size,
                )
                # FIXME: pass through 404s and other error messages nicely

            media_type = headers["Content-Type"][0]
            time_now_ms = self.clock.time_msec()

            content_disposition = headers.get("Content-Disposition", None)
            if content_disposition:
                _, params = cgi.parse_header(content_disposition[0],)
                download_name = None

                # First check if there is a valid UTF-8 filename
                download_name_utf8 = params.get("filename*", None)
                if download_name_utf8:
                    if download_name_utf8.lower().startswith("utf-8''"):
                        download_name = download_name_utf8[7:]

                # If there isn't check for an ascii name.
                if not download_name:
                    download_name_ascii = params.get("filename", None)
                    if download_name_ascii and is_ascii(download_name_ascii):
                        download_name = download_name_ascii

                if download_name:
                    download_name = urlparse.unquote(download_name)
                    try:
                        download_name = download_name.decode("utf-8")
                    except UnicodeDecodeError:
                        download_name = None
            else:
                download_name = None

            yield self.store.store_local_media(
                media_id=file_id,
                media_type=media_type,
                time_now_ms=self.clock.time_msec(),
                upload_name=download_name,
                media_length=length,
                user_id=user,
            )

        except Exception as e:
            os.remove(fname)
            raise SynapseError(
                500, ("Failed to download content: %s" % e),
                Codes.UNKNOWN
            )

        defer.returnValue({
            "media_type": media_type,
            "media_length": length,
            "download_name": download_name,
            "created_ts": time_now_ms,
            "filesystem_id": file_id,
            "filename": fname,
            "uri": uri,
            "response_code": code,
            # FIXME: we should calculate a proper expiration based on the
            # Cache-Control and Expire headers.  But for now, assume 1 hour.
            "expires": 60 * 60 * 1000,
            "etag": headers["ETag"][0] if "ETag" in headers else None,
        })

    def _is_media(self, content_type):
        if content_type.lower().startswith("image/"):
            return True

    def _is_html(self, content_type):
        content_type = content_type.lower()
        if (
            content_type.startswith("text/html") or
            content_type.startswith("application/xhtml")
        ):
            return True
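The resource above uses its ExpiringCache purely as a time-bounded, in-memory dict keyed by URL. Below is a minimal sketch of that pattern in isolation; the import path is an assumption, and `clock` stands for the homeserver clock object (`hs.get_clock()`) that the examples pass in.

from synapse.util.caches.expiringcache import ExpiringCache  # import path assumed


def make_preview_cache(clock):
    # Mirror the constructor arguments used by the resource above.
    cache = ExpiringCache(
        cache_name="url_previews",
        clock=clock,
        # don't spider URLs more often than once an hour
        expiry_ms=60 * 60 * 1000,
    )
    # Some examples in this listing call start() and some don't; it is kept
    # here to match the resource above.
    cache.start()
    return cache


# Illustrative usage: dict-style writes, .get() reads that return None once
# the entry has expired or been evicted.
#
#     cache = make_preview_cache(hs.get_clock())
#     cache[url] = og
#     og = cache.get(url)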
Example #35
0
class StateResolutionHandler(object):
    """Responsible for doing state conflict resolution.

    Note that the storage layer depends on this handler, so all functions must
    be storage-independent.
    """
    def __init__(self, hs):
        self.clock = hs.get_clock()

        # dict of set of event_ids -> _StateCacheEntry.
        self._state_cache = None
        self.resolve_linearizer = Linearizer(name="state_resolve_lock")

    def start_caching(self):
        logger.debug("start_caching")

        self._state_cache = ExpiringCache(
            cache_name="state_cache",
            clock=self.clock,
            max_len=SIZE_OF_CACHE,
            expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
            iterable=True,
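            # reset_expiry_on_get=True gives sliding expiry: each cache hit
            # pushes back the entry's eviction deadline.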
            reset_expiry_on_get=True,
        )

        self._state_cache.start()

    @defer.inlineCallbacks
    @log_function
    def resolve_state_groups(
        self, room_id, state_groups_ids, event_map, state_map_factory,
    ):
        """Resolves conflicts between a set of state groups

        Always generates a new state group (unless we hit the cache), so should
        not be called for a single state group

        Args:
            room_id (str): room we are resolving for (used for logging)
            state_groups_ids (dict[int, dict[(str, str), str]]):
                 map from state group id to the state in that state group
                (where 'state' is a map from state key to event id)

            event_map(dict[str,FrozenEvent]|None):
                a dict from event_id to event, for any events that we happen to
                have in flight (eg, those currently being persisted). This will be
                used as a starting point for finding the state we need; any missing
                events will be requested via state_map_factory.

                If None, all events will be fetched via state_map_factory.

        Returns:
            Deferred[_StateCacheEntry]: resolved state
        """
        logger.debug(
            "resolve_state_groups state_groups %s",
            state_groups_ids.keys()
        )

        group_names = frozenset(state_groups_ids.keys())
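        # The frozenset of input state group ids keys both the linearizer and
        # the ExpiringCache, so identical resolutions are de-duplicated and
        # repeated ones are served from the cache.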

        with (yield self.resolve_linearizer.queue(group_names)):
            if self._state_cache is not None:
                cache = self._state_cache.get(group_names, None)
                if cache:
                    defer.returnValue(cache)

            logger.info(
                "Resolving state for %s with %d groups", room_id, len(state_groups_ids)
            )

            # build a map from state key to the event_ids which set that state.
            # dict[(str, str), set[str]]
            state = {}
            for st in state_groups_ids.itervalues():
                for key, e_id in st.iteritems():
                    state.setdefault(key, set()).add(e_id)

            # build a map from state key to the event_ids which set that state,
            # including only those where there are state keys in conflict.
            conflicted_state = {
                k: list(v)
                for k, v in state.iteritems()
                if len(v) > 1
            }

            if conflicted_state:
                logger.info("Resolving conflicted state for %r", room_id)
                with Measure(self.clock, "state._resolve_events"):
                    new_state = yield resolve_events_with_factory(
                        state_groups_ids.values(),
                        event_map=event_map,
                        state_map_factory=state_map_factory,
                    )
            else:
                new_state = {
                    key: e_ids.pop() for key, e_ids in state.iteritems()
                }

            with Measure(self.clock, "state.create_group_ids"):
                # if the new state matches any of the input state groups, we can
                # use that state group again. Otherwise we will generate a state_id
                # which will be used as a cache key for future resolutions, but
                # not get persisted.
                state_group = None
                new_state_event_ids = frozenset(new_state.itervalues())
                for sg, events in state_groups_ids.iteritems():
                    if new_state_event_ids == frozenset(e_id for e_id in events):
                        state_group = sg
                        break

                # TODO: We want to create a state group for this set of events, to
                # increase cache hits, but we need to make sure that it doesn't
                # end up as a prev_group without being added to the database

                prev_group = None
                delta_ids = None
                for old_group, old_ids in state_groups_ids.iteritems():
                    if not set(new_state) - set(old_ids):
                        n_delta_ids = {
                            k: v
                            for k, v in new_state.iteritems()
                            if old_ids.get(k) != v
                        }
                        if not delta_ids or len(n_delta_ids) < len(delta_ids):
                            prev_group = old_group
                            delta_ids = n_delta_ids

            cache = _StateCacheEntry(
                state=new_state,
                state_group=state_group,
                prev_group=prev_group,
                delta_ids=delta_ids,
            )

            if self._state_cache is not None:
                self._state_cache[group_names] = cache

            defer.returnValue(cache)
Example #36
0
class PreviewUrlResource(Resource):
    isLeaf = True

    def __init__(self, hs, media_repo):
        Resource.__init__(self)

        self.auth = hs.get_auth()
        self.clock = hs.get_clock()
        self.version_string = hs.version_string
        self.filepaths = media_repo.filepaths
        self.max_spider_size = hs.config.max_spider_size
        self.server_name = hs.hostname
        self.store = hs.get_datastore()
        self.client = SpiderHttpClient(hs)
        self.media_repo = media_repo

        self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist

        # simple memory cache mapping urls to OG metadata
        self.cache = ExpiringCache(
            cache_name="url_previews",
            clock=self.clock,
            # don't spider URLs more often than once an hour
            expiry_ms=60 * 60 * 1000,
        )
        self.cache.start()

        self.downloads = {}

    def render_GET(self, request):
        self._async_render_GET(request)
        return NOT_DONE_YET

    @request_handler()
    @defer.inlineCallbacks
    def _async_render_GET(self, request):

        # XXX: if get_user_by_req fails, what should we do in an async render?
        requester = yield self.auth.get_user_by_req(request)
        url = request.args.get("url")[0]
        if "ts" in request.args:
            ts = int(request.args.get("ts")[0])
        else:
            ts = self.clock.time_msec()

        url_tuple = urlparse.urlsplit(url)
        for entry in self.url_preview_url_blacklist:
            match = True
            for attrib in entry:
                pattern = entry[attrib]
                value = getattr(url_tuple, attrib)
                logger.debug((
                    "Matching attrib '%s' with value '%s' against"
                    " pattern '%s'"
                ) % (attrib, value, pattern))

                if value is None:
                    match = False
                    continue

                if pattern.startswith('^'):
                    if not re.match(pattern, getattr(url_tuple, attrib)):
                        match = False
                        continue
                else:
                    if not fnmatch.fnmatch(getattr(url_tuple, attrib), pattern):
                        match = False
                        continue
            if match:
                logger.warn(
                    "URL %s blocked by url_blacklist entry %s", url, entry
                )
                raise SynapseError(
                    403, "URL blocked by url pattern blacklist entry",
                    Codes.UNKNOWN
                )

        # first check the memory cache - good to handle all the clients on this
        # HS thundering away to preview the same URL at the same time.
        og = self.cache.get(url)
        if og:
            respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
            return

        # then check the URL cache in the DB (which will also provide us with
        # historical previews, if we have any)
        cache_result = yield self.store.get_url_cache(url, ts)
        if (
            cache_result and
            cache_result["download_ts"] + cache_result["expires"] > ts and
            cache_result["response_code"] / 100 == 2
        ):
            respond_with_json_bytes(
                request, 200, cache_result["og"].encode('utf-8'),
                send_cors=True
            )
            return

        # Ensure only one download for a given URL is active at a time
        download = self.downloads.get(url)
        if download is None:
            download = self._download_url(url, requester.user)
            download = ObservableDeferred(
                download,
                consumeErrors=True
            )
            self.downloads[url] = download

            @download.addBoth
            def callback(media_info):
                del self.downloads[url]
                return media_info
        media_info = yield download.observe()

        # FIXME: we should probably update our cache now anyway, so that
        # even if the OG calculation raises, we don't keep hammering on the
        # remote server.  For now, leave it uncached to aid debugging OG
        # calculation problems

        logger.debug("got media_info of '%s'" % media_info)

        if self._is_media(media_info['media_type']):
            dims = yield self.media_repo._generate_local_thumbnails(
                media_info['filesystem_id'], media_info
            )

            og = {
                "og:description": media_info['download_name'],
                "og:image": "mxc://%s/%s" % (
                    self.server_name, media_info['filesystem_id']
                ),
                "og:image:type": media_info['media_type'],
                "matrix:image:size": media_info['media_length'],
            }

            if dims:
                og["og:image:width"] = dims['width']
                og["og:image:height"] = dims['height']
            else:
                logger.warn("Couldn't get dims for %s" % url)

            # define our OG response for this media
        elif self._is_html(media_info['media_type']):
            # TODO: somehow stop a big HTML tree from exploding synapse's RAM

            from lxml import etree

            with open(media_info['filename']) as f:
                body = f.read()

            # clobber the encoding from the content-type, or default to utf-8
            # XXX: this overrides any <meta/> or XML charset headers in the body
            # which may pose problems, but so far seems to work okay.
            match = re.match(r'.*; *charset=(.*?)(;|$)', media_info['media_type'], re.I)
            encoding = match.group(1) if match else "utf-8"

            try:
                parser = etree.HTMLParser(recover=True, encoding=encoding)
                tree = etree.fromstring(body, parser)
                og = yield self._calc_og(tree, media_info, requester)
            except UnicodeDecodeError:
                # blindly try decoding the body as utf-8, which seems to fix
                # the charset mismatches on https://google.com
                parser = etree.HTMLParser(recover=True, encoding=encoding)
                tree = etree.fromstring(body.decode('utf-8', 'ignore'), parser)
                og = yield self._calc_og(tree, media_info, requester)

        else:
            logger.warn("Failed to find any OG data in %s", url)
            og = {}

        logger.debug("Calculated OG for %s as %s" % (url, og))

        # store OG in ephemeral in-memory cache
        self.cache[url] = og

        # store OG in history-aware DB cache
        yield self.store.store_url_cache(
            url,
            media_info["response_code"],
            media_info["etag"],
            media_info["expires"],
            json.dumps(og),
            media_info["filesystem_id"],
            media_info["created_ts"],
        )

        respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)

    @defer.inlineCallbacks
    def _calc_og(self, tree, media_info, requester):
        # suck our tree into lxml and define our OG response.

        # if we see any image URLs in the OG response, then spider them
        # (although the client could choose to do this by asking for previews of those
        # URLs to avoid DoSing the server)

        # "og:type"         : "video",
        # "og:url"          : "https://www.youtube.com/watch?v=LXDBoHyjmtw",
        # "og:site_name"    : "YouTube",
        # "og:video:type"   : "application/x-shockwave-flash",
        # "og:description"  : "Fun stuff happening here",
        # "og:title"        : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon",
        # "og:image"        : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg",
        # "og:video:url"    : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1",
        # "og:video:width"  : "1280"
        # "og:video:height" : "720",
        # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",

        og = {}
        for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
            if 'content' in tag.attrib:
                og[tag.attrib['property']] = tag.attrib['content']

        # TODO: grab article: meta tags too, e.g.:

        # "article:publisher" : "https://www.facebook.com/thethudonline" />
        # "article:author" content="https://www.facebook.com/thethudonline" />
        # "article:tag" content="baby" />
        # "article:section" content="Breaking News" />
        # "article:published_time" content="2016-03-31T19:58:24+00:00" />
        # "article:modified_time" content="2016-04-01T18:31:53+00:00" />

        if 'og:title' not in og:
            # do some basic spidering of the HTML
            title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
            og['og:title'] = title[0].text.strip() if title else None

        if 'og:image' not in og:
            # TODO: extract a favicon failing all else
            meta_image = tree.xpath(
                "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
            )
            if meta_image:
                og['og:image'] = self._rebase_url(meta_image[0], media_info['uri'])
            else:
                # TODO: consider inlined CSS styles as well as width & height attribs
                images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
                images = sorted(images, key=lambda i: (
                    -1 * int(i.attrib['width']) * int(i.attrib['height'])
                ))
                if not images:
                    images = tree.xpath("//img[@src]")
                if images:
                    og['og:image'] = images[0].attrib['src']

        # pre-cache the image for posterity
        # FIXME: it might be cleaner to use the same flow as the main /preview_url request
        # itself and benefit from the same caching etc.  But for now we just rely on the
        # caching on the master request to speed things up.
        if 'og:image' in og and og['og:image']:
            image_info = yield self._download_url(
                self._rebase_url(og['og:image'], media_info['uri']), requester.user
            )

            if self._is_media(image_info['media_type']):
                # TODO: make sure we don't choke on white-on-transparent images
                dims = yield self.media_repo._generate_local_thumbnails(
                    image_info['filesystem_id'], image_info
                )
                if dims:
                    og["og:image:width"] = dims['width']
                    og["og:image:height"] = dims['height']
                else:
                    logger.warn("Couldn't get dims for %s" % og["og:image"])

                og["og:image"] = "mxc://%s/%s" % (
                    self.server_name, image_info['filesystem_id']
                )
                og["og:image:type"] = image_info['media_type']
                og["matrix:image:size"] = image_info['media_length']
            else:
                del og["og:image"]

        if 'og:description' not in og:
            meta_description = tree.xpath(
                "//*/meta"
                "[translate(@name, 'DESCRIPTION', 'description')='description']"
                "/@content")
            if meta_description:
                og['og:description'] = meta_description[0]
            else:
                # grab any text nodes which are inside the <body/> tag...
                # unless they are within an HTML5 semantic markup tag...
                # <header/>, <nav/>, <aside/>, <footer/>
                # ...or if they are within a <script/> or <style/> tag.
                # This is a very very very coarse approximation to a plain text
                # render of the page.
                text_nodes = tree.xpath("//text()[not(ancestor::header | ancestor::nav | "
                                        "ancestor::aside | ancestor::footer | "
                                        "ancestor::script | ancestor::style)]" +
                                        "[ancestor::body]")
                text = ''
                for text_node in text_nodes:
                    if len(text) < 500:
                        text += text_node + ' '
                    else:
                        break
                text = re.sub(r'[\t ]+', ' ', text)
                text = re.sub(r'[\t \r\n]*[\r\n]+', '\n', text)
                text = text.strip()[:500]
                og['og:description'] = text if text else None

        # TODO: delete the url downloads to stop diskfilling,
        # as we only ever cared about its OG
        defer.returnValue(og)

    def _rebase_url(self, url, base):
        base = list(urlparse.urlparse(base))
        url = list(urlparse.urlparse(url))
        if not url[0]:  # fix up schema
            url[0] = base[0] or "http"
        if not url[1]:  # fix up hostname
            url[1] = base[1]
            if not url[2].startswith('/'):
                url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2]
        return urlparse.urlunparse(url)

    @defer.inlineCallbacks
    def _download_url(self, url, user):
        # TODO: we should probably honour robots.txt... except in practice
        # we're most likely being explicitly triggered by a human rather than a
        # bot, so are we really a robot?

        # XXX: horrible duplication with base_resource's _download_remote_file()
        file_id = random_string(24)

        fname = self.filepaths.local_media_filepath(file_id)
        self.media_repo._makedirs(fname)

        try:
            with open(fname, "wb") as f:
                logger.debug("Trying to get url '%s'" % url)
                length, headers, uri, code = yield self.client.get_file(
                    url, output_stream=f, max_size=self.max_spider_size,
                )
                # FIXME: pass through 404s and other error messages nicely

            media_type = headers["Content-Type"][0]
            time_now_ms = self.clock.time_msec()

            content_disposition = headers.get("Content-Disposition", None)
            if content_disposition:
                _, params = cgi.parse_header(content_disposition[0],)
                download_name = None

                # First check if there is a valid UTF-8 filename
                download_name_utf8 = params.get("filename*", None)
                if download_name_utf8:
                    if download_name_utf8.lower().startswith("utf-8''"):
                        download_name = download_name_utf8[7:]

                # If there isn't check for an ascii name.
                if not download_name:
                    download_name_ascii = params.get("filename", None)
                    if download_name_ascii and is_ascii(download_name_ascii):
                        download_name = download_name_ascii

                if download_name:
                    download_name = urlparse.unquote(download_name)
                    try:
                        download_name = download_name.decode("utf-8")
                    except UnicodeDecodeError:
                        download_name = None
            else:
                download_name = None

            yield self.store.store_local_media(
                media_id=file_id,
                media_type=media_type,
                time_now_ms=self.clock.time_msec(),
                upload_name=download_name,
                media_length=length,
                user_id=user,
            )

        except Exception as e:
            os.remove(fname)
            raise SynapseError(
                500, ("Failed to download content: %s" % e),
                Codes.UNKNOWN
            )

        defer.returnValue({
            "media_type": media_type,
            "media_length": length,
            "download_name": download_name,
            "created_ts": time_now_ms,
            "filesystem_id": file_id,
            "filename": fname,
            "uri": uri,
            "response_code": code,
            # FIXME: we should calculate a proper expiration based on the
            # Cache-Control and Expire headers.  But for now, assume 1 hour.
            "expires": 60 * 60 * 1000,
            "etag": headers["ETag"][0] if "ETag" in headers else None,
        })

    def _is_media(self, content_type):
        if content_type.lower().startswith("image/"):
            return True

    def _is_html(self, content_type):
        content_type = content_type.lower()
        if (
            content_type.startswith("text/html") or
            content_type.startswith("application/xhtml")
        ):
            return True
Example #37
0
class StateResolutionHandler(object):
    """Responsible for doing state conflict resolution.

    Note that the storage layer depends on this handler, so all functions must
    be storage-independent.
    """
    def __init__(self, hs):
        self.clock = hs.get_clock()

        # dict of set of event_ids -> _StateCacheEntry.
        self._state_cache = None
        self.resolve_linearizer = Linearizer(name="state_resolve_lock")

        self._state_cache = ExpiringCache(
            cache_name="state_cache",
            clock=self.clock,
            max_len=SIZE_OF_CACHE,
            expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
            iterable=True,
            reset_expiry_on_get=True,
        )

    @defer.inlineCallbacks
    @log_function
    def resolve_state_groups(
        self, room_id, room_version, state_groups_ids, event_map, state_res_store,
    ):
        """Resolves conflicts between a set of state groups

        Always generates a new state group (unless we hit the cache), so should
        not be called for a single state group

        Args:
            room_id (str): room we are resolving for (used for logging)
            room_version (str): version of the room
            state_groups_ids (dict[int, dict[(str, str), str]]):
                 map from state group id to the state in that state group
                (where 'state' is a map from state key to event id)

            event_map(dict[str,FrozenEvent]|None):
                a dict from event_id to event, for any events that we happen to
                have in flight (eg, those currently being persisted). This will be
                used as a starting point for finding the state we need; any missing
                events will be requested via state_res_store.

                If None, all events will be fetched via state_res_store.

            state_res_store (StateResolutionStore)

        Returns:
            Deferred[_StateCacheEntry]: resolved state
        """
        logger.debug(
            "resolve_state_groups state_groups %s",
            state_groups_ids.keys()
        )

        group_names = frozenset(state_groups_ids.keys())

        with (yield self.resolve_linearizer.queue(group_names)):
            if self._state_cache is not None:
                cache = self._state_cache.get(group_names, None)
                if cache:
                    defer.returnValue(cache)

            logger.info(
                "Resolving state for %s with %d groups", room_id, len(state_groups_ids)
            )

            # start by assuming we won't have any conflicted state, and build up the new
            # state map by iterating through the state groups. If we discover a conflict,
            # we give up and instead use `resolve_events_with_store`.
            #
            # XXX: is this actually worthwhile, or should we just let
            # resolve_events_with_store do it?
            new_state = {}
            conflicted_state = False
            for st in itervalues(state_groups_ids):
                for key, e_id in iteritems(st):
                    if key in new_state:
                        conflicted_state = True
                        break
                    new_state[key] = e_id
                if conflicted_state:
                    break

            if conflicted_state:
                logger.info("Resolving conflicted state for %r", room_id)
                with Measure(self.clock, "state._resolve_events"):
                    new_state = yield resolve_events_with_store(
                        room_version,
                        list(itervalues(state_groups_ids)),
                        event_map=event_map,
                        state_res_store=state_res_store,
                    )

            # if the new state matches any of the input state groups, we can
            # use that state group again. Otherwise we will generate a state_id
            # which will be used as a cache key for future resolutions, but
            # not get persisted.

            with Measure(self.clock, "state.create_group_ids"):
                cache = _make_state_cache_entry(new_state, state_groups_ids)

            if self._state_cache is not None:
                self._state_cache[group_names] = cache

            defer.returnValue(cache)
Example #38
0
    def __init__(self, database: DatabasePool, db_conn, hs):
        super().__init__(database, db_conn, hs)

        self._instance_name = hs.get_instance_name()

        # Map of (user_id, device_id) to the last stream_id that has been
        # deleted up to, so that repeated deletions can be no-opped.
        self._last_device_delete_cache = ExpiringCache(
            cache_name="last_device_delete_cache",
            clock=self._clock,
            max_len=10000,
            expiry_ms=30 * 60 * 1000,
        )

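        # With Postgres, to-device messages may be written by more than one
        # worker, so stream ids come from a MultiWriterIdGenerator backed by a
        # database sequence; with SQLite there is a single writer and a plain
        # StreamIdGenerator is enough.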
        if isinstance(database.engine, PostgresEngine):
            self._can_write_to_device = (self._instance_name
                                         in hs.config.worker.writers.to_device)

            self._device_inbox_id_gen = MultiWriterIdGenerator(
                db_conn=db_conn,
                db=database,
                stream_name="to_device",
                instance_name=self._instance_name,
                tables=[("device_inbox", "instance_name", "stream_id")],
                sequence_name="device_inbox_sequence",
                writers=hs.config.worker.writers.to_device,
            )
        else:
            self._can_write_to_device = True
            self._device_inbox_id_gen = StreamIdGenerator(
                db_conn, "device_inbox", "stream_id")

        max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
        device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
            db_conn,
            "device_inbox",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=max_device_inbox_id,
            limit=1000,
        )
        self._device_inbox_stream_cache = StreamChangeCache(
            "DeviceInboxStreamChangeCache",
            min_device_inbox_id,
            prefilled_cache=device_inbox_prefill,
        )

        # The federation outbox and the local device inbox uses the same
        # stream_id generator.
        device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict(
            db_conn,
            "device_federation_outbox",
            entity_column="destination",
            stream_column="stream_id",
            max_value=max_device_inbox_id,
            limit=1000,
        )
        self._device_federation_outbox_stream_cache = StreamChangeCache(
            "DeviceFederationOutboxStreamChangeCache",
            min_device_outbox_id,
            prefilled_cache=device_outbox_prefill,
        )
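The `_last_device_delete_cache` above exists so that repeated deletions can be turned into no-ops. A hedged sketch of that idea, with a plain dict standing in for ExpiringCache and a hypothetical `delete_rows_up_to` storage call (neither name is a real Synapse API):

class DeviceInboxCleaner:
    def __init__(self, store):
        self._store = store
        # (user_id, device_id) -> highest stream_id we have already deleted up to
        self._last_delete = {}

    def delete_up_to(self, user_id, device_id, up_to_stream_id):
        last = self._last_delete.get((user_id, device_id), 0)
        if up_to_stream_id <= last:
            return 0  # already deleted at least this far; skip the DB round-trip
        # Hypothetical storage call standing in for the real delete query.
        deleted = self._store.delete_rows_up_to(user_id, device_id, up_to_stream_id)
        self._last_delete[(user_id, device_id)] = up_to_stream_id
        return deleted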
Example #39
0
class PreviewUrlResource(DirectServeResource):
    isLeaf = True

    def __init__(self, hs, media_repo, media_storage):
        super().__init__()

        self.auth = hs.get_auth()
        self.clock = hs.get_clock()
        self.filepaths = media_repo.filepaths
        self.max_spider_size = hs.config.max_spider_size
        self.server_name = hs.hostname
        self.store = hs.get_datastore()
        self.client = SimpleHttpClient(
            hs,
            treq_args={"browser_like_redirects": True},
            ip_whitelist=hs.config.url_preview_ip_range_whitelist,
            ip_blacklist=hs.config.url_preview_ip_range_blacklist,
        )
        self.media_repo = media_repo
        self.primary_base_path = media_repo.primary_base_path
        self.media_storage = media_storage

        self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist

        # memory cache mapping urls to an ObservableDeferred returning
        # JSON-encoded OG metadata
        self._cache = ExpiringCache(
            cache_name="url_previews",
            clock=self.clock,
            # don't spider URLs more often than once an hour
            expiry_ms=60 * 60 * 1000,
        )

        self._cleaner_loop = self.clock.looping_call(
            self._start_expire_url_cache_data, 10 * 1000)

    def render_OPTIONS(self, request):
        request.setHeader(b"Allow", b"OPTIONS, GET")
        return respond_with_json(request, 200, {}, send_cors=True)

    @wrap_json_request_handler
    async def _async_render_GET(self, request):

        # XXX: if get_user_by_req fails, what should we do in an async render?
        requester = await self.auth.get_user_by_req(request)
        url = parse_string(request, "url")
        if b"ts" in request.args:
            ts = parse_integer(request, "ts")
        else:
            ts = self.clock.time_msec()

        # XXX: we could move this into _do_preview if we wanted.
        url_tuple = urlparse.urlsplit(url)
        for entry in self.url_preview_url_blacklist:
            match = True
            for attrib in entry:
                pattern = entry[attrib]
                value = getattr(url_tuple, attrib)
                logger.debug(("Matching attrib '%s' with value '%s' against"
                              " pattern '%s'") % (attrib, value, pattern))

                if value is None:
                    match = False
                    continue

                if pattern.startswith("^"):
                    if not re.match(pattern, getattr(url_tuple, attrib)):
                        match = False
                        continue
                else:
                    if not fnmatch.fnmatch(getattr(url_tuple, attrib),
                                           pattern):
                        match = False
                        continue
            if match:
                logger.warn("URL %s blocked by url_blacklist entry %s", url,
                            entry)
                raise SynapseError(
                    403, "URL blocked by url pattern blacklist entry",
                    Codes.UNKNOWN)

        # the in-memory cache:
        # * ensures that only one request is active at a time
        # * takes load off the DB for the thundering herds
        # * also caches any failures (unlike the DB) so we don't keep
        #    requesting the same endpoint

        observable = self._cache.get(url)

        if not observable:
            download = run_in_background(self._do_preview, url, requester.user,
                                         ts)
            observable = ObservableDeferred(download, consumeErrors=True)
            self._cache[url] = observable
        else:
            logger.info("Returning cached response")

        og = await make_deferred_yieldable(
            defer.maybeDeferred(observable.observe))
        respond_with_json_bytes(request, 200, og, send_cors=True)

    @defer.inlineCallbacks
    def _do_preview(self, url, user, ts):
        """Check the db, and download the URL and build a preview

        Args:
            url (str):
            user (str):
            ts (int):

        Returns:
            Deferred[str]: json-encoded og data
        """
        # check the URL cache in the DB (which will also provide us with
        # historical previews, if we have any)
        cache_result = yield self.store.get_url_cache(url, ts)
        if (cache_result and cache_result["expires_ts"] > ts
                and cache_result["response_code"] / 100 == 2):
            # It may be stored as text in the database, not as bytes (such as
            # PostgreSQL). If so, encode it back before handing it on.
            og = cache_result["og"]
            if isinstance(og, six.text_type):
                og = og.encode("utf8")
            defer.returnValue(og)
            return

        media_info = yield self._download_url(url, user)

        logger.debug("got media_info of '%s'" % media_info)

        if _is_media(media_info["media_type"]):
            file_id = media_info["filesystem_id"]
            dims = yield self.media_repo._generate_thumbnails(
                None,
                file_id,
                file_id,
                media_info["media_type"],
                url_cache=True)

            og = {
                "og:description": media_info["download_name"],
                "og:image": "mxc://%s/%s"
                % (self.server_name, media_info["filesystem_id"]),
                "og:image:type": media_info["media_type"],
                "matrix:image:size": media_info["media_length"],
            }

            if dims:
                og["og:image:width"] = dims["width"]
                og["og:image:height"] = dims["height"]
            else:
                logger.warn("Couldn't get dims for %s" % url)

            # define our OG response for this media
        elif _is_html(media_info["media_type"]):
            # TODO: somehow stop a big HTML tree from exploding synapse's RAM

            with open(media_info["filename"], "rb") as file:
                body = file.read()

            encoding = None

            # Let's try and figure out if it has an encoding set in a meta tag.
            # Limit it to the first 1kb, since it ought to be in the meta tags
            # at the top.
            match = _charset_match.search(body[:1000])

            # If we find a match, it should take precedence over the
            # Content-Type header, so set it here.
            if match:
                encoding = match.group(1).decode("ascii")

            # If we don't find a match, we'll look at the HTTP Content-Type, and
            # if that doesn't exist, we'll fall back to UTF-8.
            if not encoding:
                match = _content_type_match.match(media_info["media_type"])
                encoding = match.group(1) if match else "utf-8"

            og = decode_and_calc_og(body, media_info["uri"], encoding)

            # pre-cache the image for posterity
            # FIXME: it might be cleaner to use the same flow as the main /preview_url
            # request itself and benefit from the same caching etc.  But for now we
            # just rely on the caching on the master request to speed things up.
            if "og:image" in og and og["og:image"]:
                image_info = yield self._download_url(
                    _rebase_url(og["og:image"], media_info["uri"]), user)

                if _is_media(image_info["media_type"]):
                    # TODO: make sure we don't choke on white-on-transparent images
                    file_id = image_info["filesystem_id"]
                    dims = yield self.media_repo._generate_thumbnails(
                        None,
                        file_id,
                        file_id,
                        image_info["media_type"],
                        url_cache=True)
                    if dims:
                        og["og:image:width"] = dims["width"]
                        og["og:image:height"] = dims["height"]
                    else:
                        logger.warn("Couldn't get dims for %s" %
                                    og["og:image"])

                    og["og:image"] = "mxc://%s/%s" % (
                        self.server_name,
                        image_info["filesystem_id"],
                    )
                    og["og:image:type"] = image_info["media_type"]
                    og["matrix:image:size"] = image_info["media_length"]
                else:
                    del og["og:image"]
        else:
            logger.warn("Failed to find any OG data in %s", url)
            og = {}

        logger.debug("Calculated OG for %s as %s" % (url, og))

        jsonog = json.dumps(og).encode("utf8")

        # store OG in history-aware DB cache
        yield self.store.store_url_cache(
            url,
            media_info["response_code"],
            media_info["etag"],
            media_info["expires"] + media_info["created_ts"],
            jsonog,
            media_info["filesystem_id"],
            media_info["created_ts"],
        )

        defer.returnValue(jsonog)

    @defer.inlineCallbacks
    def _download_url(self, url, user):
        # TODO: we should probably honour robots.txt... except in practice
        # we're most likely being explicitly triggered by a human rather than a
        # bot, so are we really a robot?

        file_id = datetime.date.today().isoformat() + "_" + random_string(16)

        file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)

        with self.media_storage.store_into_file(file_info) as (f, fname,
                                                               finish):
            try:
                logger.debug("Trying to get url '%s'" % url)
                length, headers, uri, code = yield self.client.get_file(
                    url, output_stream=f, max_size=self.max_spider_size)
            except SynapseError:
                # Pass SynapseErrors through directly, so that the servlet
                # handler will return a SynapseError to the client instead of
                # blank data or a 500.
                raise
            except DNSLookupError:
                # DNS lookup returned no results
                # Note: This will also be the case if one of the resolved IP
                # addresses is blacklisted
                raise SynapseError(
                    502,
                    "DNS resolution failure during URL preview generation",
                    Codes.UNKNOWN,
                )
            except Exception as e:
                # FIXME: pass through 404s and other error messages nicely
                logger.warn("Error downloading %s: %r", url, e)

                raise SynapseError(
                    500,
                    "Failed to download content: %s" %
                    (traceback.format_exception_only(sys.exc_info()[0], e), ),
                    Codes.UNKNOWN,
                )
            yield finish()

        try:
            if b"Content-Type" in headers:
                media_type = headers[b"Content-Type"][0].decode("ascii")
            else:
                media_type = "application/octet-stream"
            time_now_ms = self.clock.time_msec()

            download_name = get_filename_from_headers(headers)

            yield self.store.store_local_media(
                media_id=file_id,
                media_type=media_type,
                time_now_ms=self.clock.time_msec(),
                upload_name=download_name,
                media_length=length,
                user_id=user,
                url_cache=url,
            )

        except Exception as e:
            logger.error("Error handling downloaded %s: %r", url, e)
            # TODO: we really ought to delete the downloaded file in this
            # case, since we won't have recorded it in the db, and will
            # therefore not expire it.
            raise

        defer.returnValue({
            "media_type": media_type,
            "media_length": length,
            "download_name": download_name,
            "created_ts": time_now_ms,
            "filesystem_id": file_id,
            "filename": fname,
            "uri": uri,
            "response_code": code,
            # FIXME: we should calculate a proper expiration based on the
            # Cache-Control and Expire headers.  But for now, assume 1 hour.
            "expires": 60 * 60 * 1000,
            "etag": headers[b"ETag"][0] if b"ETag" in headers else None,
        })

    def _start_expire_url_cache_data(self):
        return run_as_background_process("expire_url_cache_data",
                                         self._expire_url_cache_data)

    @defer.inlineCallbacks
    def _expire_url_cache_data(self):
        """Clean up expired url cache content, media and thumbnails.
        """
        # TODO: Delete from backup media store

        now = self.clock.time_msec()

        logger.info("Running url preview cache expiry")

        if not (yield self.store.has_completed_background_updates()):
            logger.info("Still running DB updates; skipping expiry")
            return

        # First we delete expired url cache entries
        media_ids = yield self.store.get_expired_url_cache(now)

        removed_media = []
        for media_id in media_ids:
            fname = self.filepaths.url_cache_filepath(media_id)
            try:
                os.remove(fname)
            except OSError as e:
                # If the path doesn't exist, meh
                if e.errno != errno.ENOENT:
                    logger.warn("Failed to remove media: %r: %s", media_id, e)
                    continue

            removed_media.append(media_id)

            try:
                dirs = self.filepaths.url_cache_filepath_dirs_to_delete(
                    media_id)
                for dir in dirs:
                    os.rmdir(dir)
            except Exception:
                pass

        yield self.store.delete_url_cache(removed_media)

        if removed_media:
            logger.info("Deleted %d entries from url cache",
                        len(removed_media))

        # Now we delete old images associated with the url cache.
        # These may be cached for a bit on the client (i.e., they
        # may have a room open with a preview url thing open).
        # So we wait a couple of days before deleting, just in case.
        expire_before = now - 2 * 24 * 60 * 60 * 1000
        media_ids = yield self.store.get_url_cache_media_before(expire_before)

        removed_media = []
        for media_id in media_ids:
            fname = self.filepaths.url_cache_filepath(media_id)
            try:
                os.remove(fname)
            except OSError as e:
                # If the path doesn't exist, meh
                if e.errno != errno.ENOENT:
                    logger.warn("Failed to remove media: %r: %s", media_id, e)
                    continue

            try:
                dirs = self.filepaths.url_cache_filepath_dirs_to_delete(
                    media_id)
                for dir in dirs:
                    os.rmdir(dir)
            except Exception:
                pass

            thumbnail_dir = self.filepaths.url_cache_thumbnail_directory(
                media_id)
            try:
                shutil.rmtree(thumbnail_dir)
            except OSError as e:
                # If the path doesn't exist, meh
                if e.errno != errno.ENOENT:
                    logger.warn("Failed to remove media: %r: %s", media_id, e)
                    continue

            removed_media.append(media_id)

            try:
                dirs = self.filepaths.url_cache_thumbnail_dirs_to_delete(
                    media_id)
                for dir in dirs:
                    os.rmdir(dir)
            except Exception:
                pass

        yield self.store.delete_url_cache_media(removed_media)

        logger.info("Deleted %d media from url cache", len(removed_media))
Example #40
0
class PreviewUrlResource(Resource):
    isLeaf = True

    def __init__(self, hs, media_repo):
        Resource.__init__(self)

        self.auth = hs.get_auth()
        self.clock = hs.get_clock()
        self.version_string = hs.version_string
        self.filepaths = media_repo.filepaths
        self.max_spider_size = hs.config.max_spider_size
        self.server_name = hs.hostname
        self.store = hs.get_datastore()
        self.client = SpiderHttpClient(hs)
        self.media_repo = media_repo

        self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist

        # simple memory cache mapping urls to OG metadata
        self.cache = ExpiringCache(
            cache_name="url_previews",
            clock=self.clock,
            # don't spider URLs more often than once an hour
            expiry_ms=60 * 60 * 1000,
        )
        self.cache.start()

        self.downloads = {}

    def render_GET(self, request):
        self._async_render_GET(request)
        return NOT_DONE_YET

    @request_handler()
    @defer.inlineCallbacks
    def _async_render_GET(self, request):

        # XXX: if get_user_by_req fails, what should we do in an async render?
        requester = yield self.auth.get_user_by_req(request)
        url = request.args.get("url")[0]
        if "ts" in request.args:
            ts = int(request.args.get("ts")[0])
        else:
            ts = self.clock.time_msec()

        url_tuple = urlparse.urlsplit(url)
        for entry in self.url_preview_url_blacklist:
            match = True
            for attrib in entry:
                pattern = entry[attrib]
                value = getattr(url_tuple, attrib)
                logger.debug((
                    "Matching attrib '%s' with value '%s' against"
                    " pattern '%s'"
                ) % (attrib, value, pattern))

                if value is None:
                    match = False
                    continue

                if pattern.startswith('^'):
                    if not re.match(pattern, getattr(url_tuple, attrib)):
                        match = False
                        continue
                else:
                    if not fnmatch.fnmatch(getattr(url_tuple, attrib), pattern):
                        match = False
                        continue
            if match:
                logger.warn(
                    "URL %s blocked by url_blacklist entry %s", url, entry
                )
                raise SynapseError(
                    403, "URL blocked by url pattern blacklist entry",
                    Codes.UNKNOWN
                )

        # first check the memory cache - good to handle all the clients on this
        # HS thundering away to preview the same URL at the same time.
        og = self.cache.get(url)
        if og:
            respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
            return

        # then check the URL cache in the DB (which will also provide us with
        # historical previews, if we have any)
        cache_result = yield self.store.get_url_cache(url, ts)
        if (
            cache_result and
            cache_result["download_ts"] + cache_result["expires"] > ts and
            cache_result["response_code"] / 100 == 2
        ):
            respond_with_json_bytes(
                request, 200, cache_result["og"].encode('utf-8'),
                send_cors=True
            )
            return

        # Ensure only one download for a given URL is active at a time
        download = self.downloads.get(url)
        if download is None:
            download = self._download_url(url, requester.user)
            download = ObservableDeferred(
                download,
                consumeErrors=True
            )
            self.downloads[url] = download

            @download.addBoth
            def callback(media_info):
                del self.downloads[url]
                return media_info
        media_info = yield download.observe()

        # FIXME: we should probably update our cache now anyway, so that
        # even if the OG calculation raises, we don't keep hammering on the
        # remote server.  For now, leave it uncached to aid debugging OG
        # calculation problems

        logger.debug("got media_info of '%s'" % media_info)

        if _is_media(media_info['media_type']):
            dims = yield self.media_repo._generate_local_thumbnails(
                media_info['filesystem_id'], media_info
            )

            og = {
                "og:description": media_info['download_name'],
                "og:image": "mxc://%s/%s" % (
                    self.server_name, media_info['filesystem_id']
                ),
                "og:image:type": media_info['media_type'],
                "matrix:image:size": media_info['media_length'],
            }

            if dims:
                og["og:image:width"] = dims['width']
                og["og:image:height"] = dims['height']
            else:
                logger.warn("Couldn't get dims for %s" % url)

            # define our OG response for this media
        elif _is_html(media_info['media_type']):
            # TODO: somehow stop a big HTML tree from exploding synapse's RAM

            with open(media_info['filename']) as file:
                body = file.read()

            # clobber the encoding from the content-type, or default to utf-8
            # XXX: this overrides any <meta/> or XML charset headers in the body
            # which may pose problems, but so far seems to work okay.
            match = re.match(r'.*; *charset=(.*?)(;|$)', media_info['media_type'], re.I)
            encoding = match.group(1) if match else "utf-8"

            og = decode_and_calc_og(body, media_info['uri'], encoding)

            # pre-cache the image for posterity
            # FIXME: it might be cleaner to use the same flow as the main /preview_url
            # request itself and benefit from the same caching etc.  But for now we
            # just rely on the caching on the master request to speed things up.
            if 'og:image' in og and og['og:image']:
                image_info = yield self._download_url(
                    _rebase_url(og['og:image'], media_info['uri']), requester.user
                )

                if _is_media(image_info['media_type']):
                    # TODO: make sure we don't choke on white-on-transparent images
                    dims = yield self.media_repo._generate_local_thumbnails(
                        image_info['filesystem_id'], image_info
                    )
                    if dims:
                        og["og:image:width"] = dims['width']
                        og["og:image:height"] = dims['height']
                    else:
                        logger.warn("Couldn't get dims for %s" % og["og:image"])

                    og["og:image"] = "mxc://%s/%s" % (
                        self.server_name, image_info['filesystem_id']
                    )
                    og["og:image:type"] = image_info['media_type']
                    og["matrix:image:size"] = image_info['media_length']
                else:
                    del og["og:image"]
        else:
            logger.warn("Failed to find any OG data in %s", url)
            og = {}

        logger.debug("Calculated OG for %s as %s" % (url, og))

        # store OG in ephemeral in-memory cache
        self.cache[url] = og

        # store OG in history-aware DB cache
        yield self.store.store_url_cache(
            url,
            media_info["response_code"],
            media_info["etag"],
            media_info["expires"],
            json.dumps(og),
            media_info["filesystem_id"],
            media_info["created_ts"],
        )

        respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)

    @defer.inlineCallbacks
    def _download_url(self, url, user):
        # TODO: we should probably honour robots.txt... except in practice
        # we're most likely being explicitly triggered by a human rather than a
        # bot, so are we really a robot?

        # XXX: horrible duplication with base_resource's _download_remote_file()
        file_id = random_string(24)

        fname = self.filepaths.local_media_filepath(file_id)
        self.media_repo._makedirs(fname)

        try:
            with open(fname, "wb") as f:
                logger.debug("Trying to get url '%s'" % url)
                length, headers, uri, code = yield self.client.get_file(
                    url, output_stream=f, max_size=self.max_spider_size,
                )
                # FIXME: pass through 404s and other error messages nicely

            media_type = headers["Content-Type"][0]
            time_now_ms = self.clock.time_msec()

            content_disposition = headers.get("Content-Disposition", None)
            if content_disposition:
                _, params = cgi.parse_header(content_disposition[0],)
                download_name = None

                # First check if there is a valid UTF-8 filename
                download_name_utf8 = params.get("filename*", None)
                if download_name_utf8:
                    if download_name_utf8.lower().startswith("utf-8''"):
                        download_name = download_name_utf8[7:]

                # If there isn't check for an ascii name.
                if not download_name:
                    download_name_ascii = params.get("filename", None)
                    if download_name_ascii and is_ascii(download_name_ascii):
                        download_name = download_name_ascii

                if download_name:
                    download_name = urlparse.unquote(download_name)
                    try:
                        download_name = download_name.decode("utf-8")
                    except UnicodeDecodeError:
                        download_name = None
            else:
                download_name = None

            yield self.store.store_local_media(
                media_id=file_id,
                media_type=media_type,
                time_now_ms=self.clock.time_msec(),
                upload_name=download_name,
                media_length=length,
                user_id=user,
            )

        except Exception as e:
            os.remove(fname)
            raise SynapseError(
                500, ("Failed to download content: %s" % e),
                Codes.UNKNOWN
            )

        defer.returnValue({
            "media_type": media_type,
            "media_length": length,
            "download_name": download_name,
            "created_ts": time_now_ms,
            "filesystem_id": file_id,
            "filename": fname,
            "uri": uri,
            "response_code": code,
            # FIXME: we should calculate a proper expiration based on the
            # Cache-Control and Expire headers.  But for now, assume 1 hour.
            "expires": 60 * 60 * 1000,
            "etag": headers["ETag"][0] if "ETag" in headers else None,
        })
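Both preview resources around here pull the charset out of the Content-Type header with a small regex (Example #40 without quote handling, the revision below with it). A self-contained version of that check, using the quote-tolerant pattern and a hypothetical helper name:

import re

_CONTENT_TYPE_CHARSET = re.compile(r'.*; *charset="?(.*?)"?(;|$)', re.I)


def charset_from_content_type(content_type, default="utf-8"):
    """Return the charset parameter of a Content-Type value, or a default."""
    match = _CONTENT_TYPE_CHARSET.match(content_type)
    return match.group(1) if match else default


# charset_from_content_type('text/html; charset="ISO-8859-1"') -> 'ISO-8859-1'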
Example #41
0
class PreviewUrlResource(Resource):
    isLeaf = True

    def __init__(self, hs, media_repo, media_storage):
        Resource.__init__(self)

        self.auth = hs.get_auth()
        self.clock = hs.get_clock()
        self.filepaths = media_repo.filepaths
        self.max_spider_size = hs.config.max_spider_size
        self.server_name = hs.hostname
        self.store = hs.get_datastore()
        self.client = SpiderHttpClient(hs)
        self.media_repo = media_repo
        self.primary_base_path = media_repo.primary_base_path
        self.media_storage = media_storage

        self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist

        # memory cache mapping urls to an ObservableDeferred returning
        # JSON-encoded OG metadata
        self._cache = ExpiringCache(
            cache_name="url_previews",
            clock=self.clock,
            # don't spider URLs more often than once an hour
            expiry_ms=60 * 60 * 1000,
        )

        self._cleaner_loop = self.clock.looping_call(
            self._start_expire_url_cache_data, 10 * 1000,
        )

    def render_OPTIONS(self, request):
        return respond_with_json(request, 200, {}, send_cors=True)

    def render_GET(self, request):
        self._async_render_GET(request)
        return NOT_DONE_YET

    @wrap_json_request_handler
    @defer.inlineCallbacks
    def _async_render_GET(self, request):

        # XXX: if get_user_by_req fails, what should we do in an async render?
        requester = yield self.auth.get_user_by_req(request)
        url = parse_string(request, "url")
        if b"ts" in request.args:
            ts = parse_integer(request, "ts")
        else:
            ts = self.clock.time_msec()

        # XXX: we could move this into _do_preview if we wanted.
        url_tuple = urlparse.urlsplit(url)
        for entry in self.url_preview_url_blacklist:
            match = True
            for attrib in entry:
                pattern = entry[attrib]
                value = getattr(url_tuple, attrib)
                logger.debug((
                    "Matching attrib '%s' with value '%s' against"
                    " pattern '%s'"
                ) % (attrib, value, pattern))

                if value is None:
                    match = False
                    continue

                if pattern.startswith('^'):
                    if not re.match(pattern, getattr(url_tuple, attrib)):
                        match = False
                        continue
                else:
                    if not fnmatch.fnmatch(getattr(url_tuple, attrib), pattern):
                        match = False
                        continue
            if match:
                logger.warn(
                    "URL %s blocked by url_blacklist entry %s", url, entry
                )
                raise SynapseError(
                    403, "URL blocked by url pattern blacklist entry",
                    Codes.UNKNOWN
                )

        # the in-memory cache:
        # * ensures that only one request is active at a time
        # * takes load off the DB for the thundering herds
        # * also caches any failures (unlike the DB) so we don't keep
        #    requesting the same endpoint

        observable = self._cache.get(url)

        if not observable:
            download = run_in_background(
                self._do_preview,
                url, requester.user, ts,
            )
            observable = ObservableDeferred(
                download,
                consumeErrors=True
            )
            self._cache[url] = observable
        else:
            logger.info("Returning cached response")

        og = yield make_deferred_yieldable(observable.observe())
        respond_with_json_bytes(request, 200, og, send_cors=True)

    @defer.inlineCallbacks
    def _do_preview(self, url, user, ts):
        """Check the db, and download the URL and build a preview

        Args:
            url (str):
            user (str):
            ts (int):

        Returns:
            Deferred[str]: json-encoded og data
        """
        # check the URL cache in the DB (which will also provide us with
        # historical previews, if we have any)
        cache_result = yield self.store.get_url_cache(url, ts)
        if (
            cache_result and
            cache_result["expires_ts"] > ts and
            cache_result["response_code"] / 100 == 2
        ):
            # It may be stored as text in the database, not as bytes (such as
            # PostgreSQL). If so, encode it back before handing it on.
            og = cache_result["og"]
            if isinstance(og, six.text_type):
                og = og.encode('utf8')
            defer.returnValue(og)
            return

        media_info = yield self._download_url(url, user)

        logger.debug("got media_info of '%s'" % media_info)

        if _is_media(media_info['media_type']):
            file_id = media_info['filesystem_id']
            dims = yield self.media_repo._generate_thumbnails(
                None, file_id, file_id, media_info["media_type"],
                url_cache=True,
            )

            og = {
                "og:description": media_info['download_name'],
                "og:image": "mxc://%s/%s" % (
                    self.server_name, media_info['filesystem_id']
                ),
                "og:image:type": media_info['media_type'],
                "matrix:image:size": media_info['media_length'],
            }

            if dims:
                og["og:image:width"] = dims['width']
                og["og:image:height"] = dims['height']
            else:
                logger.warn("Couldn't get dims for %s" % url)

            # define our OG response for this media
        elif _is_html(media_info['media_type']):
            # TODO: somehow stop a big HTML tree from exploding synapse's RAM

            with open(media_info['filename'], 'rb') as file:
                body = file.read()

            # clobber the encoding from the content-type, or default to utf-8
            # XXX: this overrides any <meta/> or XML charset headers in the body
            # which may pose problems, but so far seems to work okay.
            match = re.match(
                r'.*; *charset="?(.*?)"?(;|$)',
                media_info['media_type'],
                re.I
            )
            encoding = match.group(1) if match else "utf-8"

            og = decode_and_calc_og(body, media_info['uri'], encoding)

            # pre-cache the image for posterity
            # FIXME: it might be cleaner to use the same flow as the main /preview_url
            # request itself and benefit from the same caching etc.  But for now we
            # just rely on the caching on the master request to speed things up.
            if 'og:image' in og and og['og:image']:
                image_info = yield self._download_url(
                    _rebase_url(og['og:image'], media_info['uri']), user
                )

                if _is_media(image_info['media_type']):
                    # TODO: make sure we don't choke on white-on-transparent images
                    file_id = image_info['filesystem_id']
                    dims = yield self.media_repo._generate_thumbnails(
                        None, file_id, file_id, image_info["media_type"],
                        url_cache=True,
                    )
                    if dims:
                        og["og:image:width"] = dims['width']
                        og["og:image:height"] = dims['height']
                    else:
                        logger.warn("Couldn't get dims for %s" % og["og:image"])

                    og["og:image"] = "mxc://%s/%s" % (
                        self.server_name, image_info['filesystem_id']
                    )
                    og["og:image:type"] = image_info['media_type']
                    og["matrix:image:size"] = image_info['media_length']
                else:
                    del og["og:image"]
        else:
            logger.warn("Failed to find any OG data in %s", url)
            og = {}

        logger.debug("Calculated OG for %s as %s" % (url, og))

        jsonog = json.dumps(og).encode('utf8')

        # store OG in history-aware DB cache
        yield self.store.store_url_cache(
            url,
            media_info["response_code"],
            media_info["etag"],
            media_info["expires"] + media_info["created_ts"],
            jsonog,
            media_info["filesystem_id"],
            media_info["created_ts"],
        )

        defer.returnValue(jsonog)

    @defer.inlineCallbacks
    def _download_url(self, url, user):
        # TODO: we should probably honour robots.txt... except in practice
        # we're most likely being explicitly triggered by a human rather than a
        # bot, so are we really a robot?

        file_id = datetime.date.today().isoformat() + '_' + random_string(16)

        file_info = FileInfo(
            server_name=None,
            file_id=file_id,
            url_cache=True,
        )

        with self.media_storage.store_into_file(file_info) as (f, fname, finish):
            try:
                logger.debug("Trying to get url '%s'" % url)
                length, headers, uri, code = yield self.client.get_file(
                    url, output_stream=f, max_size=self.max_spider_size,
                )
            except Exception as e:
                # FIXME: pass through 404s and other error messages nicely
                logger.warn("Error downloading %s: %r", url, e)
                raise SynapseError(
                    500, "Failed to download content: %s" % (
                        traceback.format_exception_only(sys.exc_info()[0], e),
                    ),
                    Codes.UNKNOWN,
                )
            yield finish()

        try:
            if b"Content-Type" in headers:
                media_type = headers[b"Content-Type"][0].decode('ascii')
            else:
                media_type = "application/octet-stream"
            time_now_ms = self.clock.time_msec()

            content_disposition = headers.get(b"Content-Disposition", None)
            if content_disposition:
                _, params = cgi.parse_header(content_disposition[0],)
                download_name = None

                # First check if there is a valid UTF-8 filename
                download_name_utf8 = params.get("filename*", None)
                if download_name_utf8:
                    if download_name_utf8.lower().startswith("utf-8''"):
                        download_name = download_name_utf8[7:]

                # If there isn't check for an ascii name.
                if not download_name:
                    download_name_ascii = params.get("filename", None)
                    if download_name_ascii and is_ascii(download_name_ascii):
                        download_name = download_name_ascii

                if download_name:
                    download_name = urlparse.unquote(download_name)
                    try:
                        download_name = download_name.decode("utf-8")
                    except UnicodeDecodeError:
                        download_name = None
            else:
                download_name = None

            yield self.store.store_local_media(
                media_id=file_id,
                media_type=media_type,
                time_now_ms=self.clock.time_msec(),
                upload_name=download_name,
                media_length=length,
                user_id=user,
                url_cache=url,
            )

        except Exception as e:
            logger.error("Error handling downloaded %s: %r", url, e)
            # TODO: we really ought to delete the downloaded file in this
            # case, since we won't have recorded it in the db, and will
            # therefore not expire it.
            raise

        defer.returnValue({
            "media_type": media_type,
            "media_length": length,
            "download_name": download_name,
            "created_ts": time_now_ms,
            "filesystem_id": file_id,
            "filename": fname,
            "uri": uri,
            "response_code": code,
            # FIXME: we should calculate a proper expiration based on the
            # Cache-Control and Expire headers.  But for now, assume 1 hour.
            "expires": 60 * 60 * 1000,
            "etag": headers["ETag"][0] if "ETag" in headers else None,
        })

    def _start_expire_url_cache_data(self):
        return run_as_background_process(
            "expire_url_cache_data", self._expire_url_cache_data,
        )

    @defer.inlineCallbacks
    def _expire_url_cache_data(self):
        """Clean up expired url cache content, media and thumbnails.
        """
        # TODO: Delete from backup media store

        now = self.clock.time_msec()

        logger.info("Running url preview cache expiry")

        if not (yield self.store.has_completed_background_updates()):
            logger.info("Still running DB updates; skipping expiry")
            return

        # First we delete expired url cache entries
        media_ids = yield self.store.get_expired_url_cache(now)

        removed_media = []
        for media_id in media_ids:
            fname = self.filepaths.url_cache_filepath(media_id)
            try:
                os.remove(fname)
            except OSError as e:
                # If the path doesn't exist, meh
                if e.errno != errno.ENOENT:
                    logger.warn("Failed to remove media: %r: %s", media_id, e)
                    continue

            removed_media.append(media_id)

            try:
                dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id)
                for dir in dirs:
                    os.rmdir(dir)
            except Exception:
                pass

        yield self.store.delete_url_cache(removed_media)

        if removed_media:
            logger.info("Deleted %d entries from url cache", len(removed_media))

        # Now we delete old images associated with the url cache.
        # These may be cached for a bit on the client (i.e., they
        # may have a room open with a preview url thing open).
        # So we wait a couple of days before deleting, just in case.
        expire_before = now - 2 * 24 * 60 * 60 * 1000
        media_ids = yield self.store.get_url_cache_media_before(expire_before)

        removed_media = []
        for media_id in media_ids:
            fname = self.filepaths.url_cache_filepath(media_id)
            try:
                os.remove(fname)
            except OSError as e:
                # If the path doesn't exist, meh
                if e.errno != errno.ENOENT:
                    logger.warn("Failed to remove media: %r: %s", media_id, e)
                    continue

            try:
                dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id)
                for dir in dirs:
                    os.rmdir(dir)
            except Exception:
                pass

            thumbnail_dir = self.filepaths.url_cache_thumbnail_directory(media_id)
            try:
                shutil.rmtree(thumbnail_dir)
            except OSError as e:
                # If the path doesn't exist, meh
                if e.errno != errno.ENOENT:
                    logger.warn("Failed to remove media: %r: %s", media_id, e)
                    continue

            removed_media.append(media_id)

            try:
                dirs = self.filepaths.url_cache_thumbnail_dirs_to_delete(media_id)
                for dir in dirs:
                    os.rmdir(dir)
            except Exception:
                pass

        yield self.store.delete_url_cache_media(removed_media)

        logger.info("Deleted %d media from url cache", len(removed_media))
Example #42
0
class StateHandler(object):
    """ Responsible for doing state conflict resolution.
    """

    def __init__(self, hs):
        self.clock = hs.get_clock()
        self.store = hs.get_datastore()
        self.hs = hs

        # dict of set of event_ids -> _StateCacheEntry.
        self._state_cache = None

    def start_caching(self):
        logger.debug("start_caching")

        self._state_cache = ExpiringCache(
            cache_name="state_cache",
            clock=self.clock,
            max_len=SIZE_OF_CACHE,
            expiry_ms=EVICTION_TIMEOUT_SECONDS*1000,
            reset_expiry_on_get=True,
        )

        self._state_cache.start()

    @defer.inlineCallbacks
    def get_current_state(self, room_id, event_type=None, state_key=""):
        """ Returns the current state for the room as a list. This is done by
        calling `get_latest_events_in_room` to get the leading edges of the
        event graph and then resolving any of the state conflicts.

        This is equivalent to the state that a new event sent into the room
        right now would see, before any further events are received.

        If `event_type` is specified, then the method returns only the one
        event (or None) with that `event_type` and `state_key`.
        """
        event_ids = yield self.store.get_latest_event_ids_in_room(room_id)

        cache = None
        if self._state_cache is not None:
            cache = self._state_cache.get(frozenset(event_ids), None)

        if cache:
            cache.ts = self.clock.time_msec()
            state = cache.state
        else:
            res = yield self.resolve_state_groups(room_id, event_ids)
            state = res[1]

        if event_type:
            defer.returnValue(state.get((event_type, state_key)))
            return

        defer.returnValue(state)

    @defer.inlineCallbacks
    def compute_event_context(self, event, old_state=None, outlier=False):
        """ Fills out the context with the `current state` of the graph. The
        `current state` here is defined to be the state of the event graph
        just before the event - i.e. it never includes `event`

        If `event` has `auth_events` then this will also fill out the
        `auth_events` field on `context` from the `current_state`.

        Args:
            event (EventBase)
        Returns:
            an EventContext
        """
        yield run_on_reactor()

        context = EventContext()

        if outlier:
            # If this is an outlier, then we know it shouldn't have any current
            # state. Certainly store.get_current_state won't return any, and
            # persisting the event won't store the state group.
            if old_state:
                context.current_state = {
                    (s.type, s.state_key): s for s in old_state
                }
            else:
                context.current_state = {}
            context.prev_state_events = []
            context.state_group = None
            defer.returnValue(context)

        if old_state:
            context.current_state = {
                (s.type, s.state_key): s for s in old_state
            }
            context.state_group = None

            if event.is_state():
                key = (event.type, event.state_key)
                if key in context.current_state:
                    replaces = context.current_state[key]
                    if replaces.event_id != event.event_id:  # Paranoia check
                        event.unsigned["replaces_state"] = replaces.event_id

            context.prev_state_events = []
            defer.returnValue(context)

        if event.is_state():
            ret = yield self.resolve_state_groups(
                event.room_id, [e for e, _ in event.prev_events],
                event_type=event.type,
                state_key=event.state_key,
            )
        else:
            ret = yield self.resolve_state_groups(
                event.room_id, [e for e, _ in event.prev_events],
            )

        group, curr_state, prev_state = ret

        context.current_state = curr_state
        context.state_group = group if not event.is_state() else None

        if event.is_state():
            key = (event.type, event.state_key)
            if key in context.current_state:
                replaces = context.current_state[key]
                event.unsigned["replaces_state"] = replaces.event_id

        context.prev_state_events = prev_state
        defer.returnValue(context)

    @defer.inlineCallbacks
    @log_function
    def resolve_state_groups(self, room_id, event_ids, event_type=None, state_key=""):
        """ Given a list of event_ids this method fetches the state at each
        event, resolves conflicts between them and returns them.

        Return format is a tuple (`state_group`, `state_events`, `prev_states`),
        where `state_group` is the name of a state group if one and only one is
        involved, otherwise `None`.
        """
        logger.debug("resolve_state_groups event_ids %s", event_ids)

        if self._state_cache is not None:
            cache = self._state_cache.get(frozenset(event_ids), None)
            if cache and cache.state_group:
                cache.ts = self.clock.time_msec()
                prev_state = cache.state.get((event_type, state_key), None)
                if prev_state:
                    prev_state = prev_state.event_id
                    prev_states = [prev_state]
                else:
                    prev_states = []
                defer.returnValue(
                    (cache.state_group, cache.state, prev_states)
                )

        state_groups = yield self.store.get_state_groups(
            room_id, event_ids
        )

        logger.debug(
            "resolve_state_groups state_groups %s",
            state_groups.keys()
        )

        group_names = set(state_groups.keys())
        if len(group_names) == 1:
            name, state_list = state_groups.items().pop()
            state = {
                (e.type, e.state_key): e
                for e in state_list
            }
            prev_state = state.get((event_type, state_key), None)
            if prev_state:
                prev_state = prev_state.event_id
                prev_states = [prev_state]
            else:
                prev_states = []

            if self._state_cache is not None:
                cache = _StateCacheEntry(
                    state=state,
                    state_group=name,
                    ts=self.clock.time_msec()
                )

                self._state_cache[frozenset(event_ids)] = cache

            defer.returnValue((name, state, prev_states))

        new_state, prev_states = self._resolve_events(
            state_groups.values(), event_type, state_key
        )

        if self._state_cache is not None:
            cache = _StateCacheEntry(
                state=new_state,
                state_group=None,
                ts=self.clock.time_msec()
            )

            self._state_cache[frozenset(event_ids)] = cache

        defer.returnValue((None, new_state, prev_states))

    def resolve_events(self, state_sets, event):
        if event.is_state():
            return self._resolve_events(
                state_sets, event.type, event.state_key
            )
        else:
            return self._resolve_events(state_sets)

    def _resolve_events(self, state_sets, event_type=None, state_key=""):
        state = {}
        for st in state_sets:
            for e in st:
                state.setdefault(
                    (e.type, e.state_key),
                    {}
                )[e.event_id] = e

        unconflicted_state = {
            k: v.values()[0] for k, v in state.items()
            if len(v.values()) == 1
        }

        conflicted_state = {
            k: v.values()
            for k, v in state.items()
            if len(v.values()) > 1
        }

        if event_type:
            prev_states_events = conflicted_state.get(
                (event_type, state_key), []
            )
            prev_states = [s.event_id for s in prev_states_events]
        else:
            prev_states = []

        auth_events = {
            k: e for k, e in unconflicted_state.items()
            if k[0] in AuthEventTypes
        }

        try:
            resolved_state = self._resolve_state_events(
                conflicted_state, auth_events
            )
        except:
            logger.exception("Failed to resolve state")
            raise

        new_state = unconflicted_state
        new_state.update(resolved_state)

        return new_state, prev_states

    @log_function
    def _resolve_state_events(self, conflicted_state, auth_events):
        """ This is where we actually decide which of the conflicted state to
        use.

        We resolve conflicts in the following order:
            1. power levels
            2. join rules
            3. memberships
            4. other events.
        """
        resolved_state = {}
        power_key = (EventTypes.PowerLevels, "")
        if power_key in conflicted_state:
            power_levels = conflicted_state[power_key]
            resolved_state[power_key] = self._resolve_auth_events(
                power_levels, auth_events
            )

        auth_events.update(resolved_state)

        for key, events in conflicted_state.items():
            if key[0] == EventTypes.JoinRules:
                resolved_state[key] = self._resolve_auth_events(
                    events,
                    auth_events
                )

        auth_events.update(resolved_state)

        for key, events in conflicted_state.items():
            if key[0] == EventTypes.Member:
                resolved_state[key] = self._resolve_auth_events(
                    events,
                    auth_events
                )

        auth_events.update(resolved_state)

        for key, events in conflicted_state.items():
            if key not in resolved_state:
                resolved_state[key] = self._resolve_normal_events(
                    events, auth_events
                )

        return resolved_state

    def _resolve_auth_events(self, events, auth_events):
        reverse = [i for i in reversed(self._ordered_events(events))]

        auth_events = dict(auth_events)

        prev_event = reverse[0]
        for event in reverse[1:]:
            auth_events[(prev_event.type, prev_event.state_key)] = prev_event
            try:
                # FIXME: hs.get_auth() is bad style, but we need to do it to
                # get around circular deps.
                self.hs.get_auth().check(event, auth_events)
                prev_event = event
            except AuthError:
                return prev_event

        return event

    def _resolve_normal_events(self, events, auth_events):
        for event in self._ordered_events(events):
            try:
                # FIXME: hs.get_auth() is bad style, but we need to do it to
                # get around circular deps.
                self.hs.get_auth().check(event, auth_events)
                return event
            except AuthError:
                pass

        # Use the last event (the one with the least depth) if they all fail
        # the auth check.
        return event

    def _ordered_events(self, events):
        def key_func(e):
            return -int(e.depth), hashlib.sha1(e.event_id).hexdigest()

        return sorted(events, key=key_func)
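As a compact restatement of the first step of `_resolve_events` above (a hypothetical helper working on event IDs rather than event objects): collect the candidate events for each (type, state_key) and split the map into unconflicted and conflicted entries.

def split_state(state_sets):
    """state_sets: iterable of {(type, state_key): event_id} mappings."""
    candidates = {}
    for state in state_sets:
        for key, event_id in state.items():
            candidates.setdefault(key, set()).add(event_id)

    unconflicted = {k: next(iter(v)) for k, v in candidates.items() if len(v) == 1}
    conflicted = {k: v for k, v in candidates.items() if len(v) > 1}
    return unconflicted, conflicted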
Example #43
0
class PreviewUrlResource(DirectServeJsonResource):
    isLeaf = True

    def __init__(self, hs, media_repo, media_storage):
        super().__init__()

        self.auth = hs.get_auth()
        self.clock = hs.get_clock()
        self.filepaths = media_repo.filepaths
        self.max_spider_size = hs.config.max_spider_size
        self.server_name = hs.hostname
        self.store = hs.get_datastore()
        self.client = SimpleHttpClient(
            hs,
            treq_args={"browser_like_redirects": True},
            ip_whitelist=hs.config.url_preview_ip_range_whitelist,
            ip_blacklist=hs.config.url_preview_ip_range_blacklist,
            http_proxy=os.getenvb(b"http_proxy"),
            https_proxy=os.getenvb(b"HTTPS_PROXY"),
        )
        self.media_repo = media_repo
        self.primary_base_path = media_repo.primary_base_path
        self.media_storage = media_storage

        # We run the background jobs if we're the instance specified (or no
        # instance is specified, where we assume there is only one instance
        # serving media).
        instance_running_jobs = hs.config.media.media_instance_running_background_jobs
        self._worker_run_media_background_jobs = (
            instance_running_jobs is None
            or instance_running_jobs == hs.get_instance_name()
        )

        self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
        self.url_preview_accept_language = hs.config.url_preview_accept_language

        # memory cache mapping urls to an ObservableDeferred returning
        # JSON-encoded OG metadata
        self._cache = ExpiringCache(
            cache_name="url_previews",
            clock=self.clock,
            # don't spider URLs more often than once an hour
            expiry_ms=ONE_HOUR,
        )

        if self._worker_run_media_background_jobs:
            self._cleaner_loop = self.clock.looping_call(
                self._start_expire_url_cache_data, 10 * 1000
            )

    async def _async_render_OPTIONS(self, request):
        request.setHeader(b"Allow", b"OPTIONS, GET")
        respond_with_json(request, 200, {}, send_cors=True)

    async def _async_render_GET(self, request):

        # XXX: if get_user_by_req fails, what should we do in an async render?
        requester = await self.auth.get_user_by_req(request)
        url = parse_string(request, "url")
        if b"ts" in request.args:
            ts = parse_integer(request, "ts")
        else:
            ts = self.clock.time_msec()

        # XXX: we could move this into _do_preview if we wanted.
        url_tuple = urlparse.urlsplit(url)
        for entry in self.url_preview_url_blacklist:
            match = True
            for attrib in entry:
                pattern = entry[attrib]
                value = getattr(url_tuple, attrib)
                logger.debug(
                    "Matching attrib '%s' with value '%s' against pattern '%s'",
                    attrib,
                    value,
                    pattern,
                )

                if value is None:
                    match = False
                    continue

                if pattern.startswith("^"):
                    if not re.match(pattern, getattr(url_tuple, attrib)):
                        match = False
                        continue
                else:
                    if not fnmatch.fnmatch(getattr(url_tuple, attrib), pattern):
                        match = False
                        continue
            if match:
                logger.warning("URL %s blocked by url_blacklist entry %s", url, entry)
                raise SynapseError(
                    403, "URL blocked by url pattern blacklist entry", Codes.UNKNOWN
                )

        # the in-memory cache:
        # * ensures that only one request is active at a time
        # * takes load off the DB for the thundering herds
        # * also caches any failures (unlike the DB) so we don't keep
        #    requesting the same endpoint

        observable = self._cache.get(url)

        if not observable:
            download = run_in_background(self._do_preview, url, requester.user, ts)
            observable = ObservableDeferred(download, consumeErrors=True)
            self._cache[url] = observable
        else:
            logger.info("Returning cached response")

        og = await make_deferred_yieldable(observable.observe())
        respond_with_json_bytes(request, 200, og, send_cors=True)

    async def _do_preview(self, url: str, user: str, ts: int) -> bytes:
        """Check the db, and download the URL and build a preview

        Args:
            url: The URL to preview.
            user: The user requesting the preview.
            ts: The timestamp requested for the preview.

        Returns:
            json-encoded og data
        """
        # check the URL cache in the DB (which will also provide us with
        # historical previews, if we have any)
        cache_result = await self.store.get_url_cache(url, ts)
        if (
            cache_result
            and cache_result["expires_ts"] > ts
            and cache_result["response_code"] / 100 == 2
        ):
            # It may be stored as text in the database, not as bytes (such as
            # PostgreSQL). If so, encode it back before handing it on.
            og = cache_result["og"]
            if isinstance(og, str):
                og = og.encode("utf8")
            return og

        media_info = await self._download_url(url, user)

        logger.debug("got media_info of '%s'", media_info)

        if _is_media(media_info["media_type"]):
            file_id = media_info["filesystem_id"]
            dims = await self.media_repo._generate_thumbnails(
                None, file_id, file_id, media_info["media_type"], url_cache=True
            )

            og = {
                "og:description": media_info["download_name"],
                "og:image": "mxc://%s/%s"
                % (self.server_name, media_info["filesystem_id"]),
                "og:image:type": media_info["media_type"],
                "matrix:image:size": media_info["media_length"],
            }

            if dims:
                og["og:image:width"] = dims["width"]
                og["og:image:height"] = dims["height"]
            else:
                logger.warning("Couldn't get dims for %s" % url)

            # define our OG response for this media
        elif _is_html(media_info["media_type"]):
            # TODO: somehow stop a big HTML tree from exploding synapse's RAM

            with open(media_info["filename"], "rb") as file:
                body = file.read()

            encoding = None

            # Let's try and figure out if it has an encoding set in a meta tag.
            # Limit it to the first 1kb, since it ought to be in the meta tags
            # at the top.
            match = _charset_match.search(body[:1000])

            # If we find a match, it should take precedence over the
            # Content-Type header, so set it here.
            if match:
                encoding = match.group(1).decode("ascii")

            # If we don't find a match, we'll look at the HTTP Content-Type, and
            # if that doesn't exist, we'll fall back to UTF-8.
            if not encoding:
                content_match = _content_type_match.match(media_info["media_type"])
                encoding = content_match.group(1) if content_match else "utf-8"

            og = decode_and_calc_og(body, media_info["uri"], encoding)

            # pre-cache the image for posterity
            # FIXME: it might be cleaner to use the same flow as the main /preview_url
            # request itself and benefit from the same caching etc.  But for now we
            # just rely on the caching on the master request to speed things up.
            if "og:image" in og and og["og:image"]:
                image_info = await self._download_url(
                    _rebase_url(og["og:image"], media_info["uri"]), user
                )

                if _is_media(image_info["media_type"]):
                    # TODO: make sure we don't choke on white-on-transparent images
                    file_id = image_info["filesystem_id"]
                    dims = await self.media_repo._generate_thumbnails(
                        None, file_id, file_id, image_info["media_type"], url_cache=True
                    )
                    if dims:
                        og["og:image:width"] = dims["width"]
                        og["og:image:height"] = dims["height"]
                    else:
                        logger.warning("Couldn't get dims for %s", og["og:image"])

                    og["og:image"] = "mxc://%s/%s" % (
                        self.server_name,
                        image_info["filesystem_id"],
                    )
                    og["og:image:type"] = image_info["media_type"]
                    og["matrix:image:size"] = image_info["media_length"]
                else:
                    del og["og:image"]
        else:
            logger.warning("Failed to find any OG data in %s", url)
            og = {}

        # filter out any stupidly long values
        keys_to_remove = []
        for k, v in og.items():
            # values can be numeric as well as strings, hence the cast to str
            if len(k) > OG_TAG_NAME_MAXLEN or len(str(v)) > OG_TAG_VALUE_MAXLEN:
                logger.warning(
                    "Pruning overlong tag %s from OG data", k[:OG_TAG_NAME_MAXLEN]
                )
                keys_to_remove.append(k)
        for k in keys_to_remove:
            del og[k]

        logger.debug("Calculated OG for %s as %s", url, og)

        jsonog = json_encoder.encode(og)

        # store OG in history-aware DB cache
        await self.store.store_url_cache(
            url,
            media_info["response_code"],
            media_info["etag"],
            media_info["expires"] + media_info["created_ts"],
            jsonog,
            media_info["filesystem_id"],
            media_info["created_ts"],
        )

        return jsonog.encode("utf8")

    def _get_oembed_url(self, url: str) -> Optional[str]:
        """
        Check whether the URL should be downloaded as oEmbed content instead.

        Params:
            url: The URL to check.

        Returns:
            A URL to use instead or None if the original URL should be used.
        """
        for url_pattern, endpoint in _oembed_patterns.items():
            if url_pattern.fullmatch(url):
                return endpoint

        # No match.
        return None

    async def _get_oembed_content(self, endpoint: str, url: str) -> OEmbedResult:
        """
        Request content from an oEmbed endpoint.

        Params:
            endpoint: The oEmbed API endpoint.
            url: The URL to pass to the API.

        Returns:
            An object representing the metadata returned.

        Raises:
            OEmbedError if fetching or parsing of the oEmbed information fails.
        """
        try:
            logger.debug("Trying to get oEmbed content for url '%s'", url)
            result = await self.client.get_json(
                endpoint,
                # TODO Specify max height / width.
                # Note that only the JSON format is supported.
                args={"url": url},
            )

            # Ensure there's a version of 1.0.
            if result.get("version") != "1.0":
                raise OEmbedError("Invalid version: %s" % (result.get("version"),))

            oembed_type = result.get("type")

            # Ensure the cache age is None or an int.
            cache_age = result.get("cache_age")
            if cache_age:
                cache_age = int(cache_age)

            oembed_result = OEmbedResult(None, None, result.get("title"), cache_age)

            # HTML content.
            if oembed_type == "rich":
                oembed_result.html = result.get("html")
                return oembed_result

            if oembed_type == "photo":
                oembed_result.url = result.get("url")
                return oembed_result

            # TODO Handle link and video types.

            if "thumbnail_url" in result:
                oembed_result.url = result.get("thumbnail_url")
                return oembed_result

            raise OEmbedError("Incompatible oEmbed information.")

        except OEmbedError as e:
            # Trap OEmbedErrors first so we can directly re-raise them.
            logger.warning("Error parsing oEmbed metadata from %s: %r", url, e)
            raise

        except Exception as e:
            # Trap any exception and let the code follow as usual.
            # FIXME: pass through 404s and other error messages nicely
            logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
            raise OEmbedError() from e

    async def _download_url(self, url: str, user):
        # TODO: we should probably honour robots.txt... except in practice
        # we're most likely being explicitly triggered by a human rather than a
        # bot, so are we really a robot?

        file_id = datetime.date.today().isoformat() + "_" + random_string(16)

        file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)

        # If this URL can be accessed via oEmbed, use that instead.
        url_to_download = url  # type: Optional[str]
        oembed_url = self._get_oembed_url(url)
        if oembed_url:
            # The result might be a new URL to download, or it might be HTML content.
            try:
                oembed_result = await self._get_oembed_content(oembed_url, url)
                if oembed_result.url:
                    url_to_download = oembed_result.url
                elif oembed_result.html:
                    url_to_download = None
            except OEmbedError:
                # If an error occurs, try doing a normal preview.
                pass

        if url_to_download:
            with self.media_storage.store_into_file(file_info) as (f, fname, finish):
                try:
                    logger.debug("Trying to get preview for url '%s'", url_to_download)
                    length, headers, uri, code = await self.client.get_file(
                        url_to_download,
                        output_stream=f,
                        max_size=self.max_spider_size,
                        headers={"Accept-Language": self.url_preview_accept_language},
                    )
                except SynapseError:
                    # Pass SynapseErrors through directly, so that the servlet
                    # handler will return a SynapseError to the client instead of
                    # blank data or a 500.
                    raise
                except DNSLookupError:
                    # DNS lookup returned no results
                    # Note: This will also be the case if one of the resolved IP
                    # addresses is blacklisted
                    raise SynapseError(
                        502,
                        "DNS resolution failure during URL preview generation",
                        Codes.UNKNOWN,
                    )
                except Exception as e:
                    # FIXME: pass through 404s and other error messages nicely
                    logger.warning("Error downloading %s: %r", url_to_download, e)

                    raise SynapseError(
                        500,
                        "Failed to download content: %s"
                        % (traceback.format_exception_only(sys.exc_info()[0], e),),
                        Codes.UNKNOWN,
                    )
                await finish()

                if b"Content-Type" in headers:
                    media_type = headers[b"Content-Type"][0].decode("ascii")
                else:
                    media_type = "application/octet-stream"

                download_name = get_filename_from_headers(headers)

                # FIXME: we should calculate a proper expiration based on the
                # Cache-Control and Expire headers.  But for now, assume 1 hour.
                expires = ONE_HOUR
                etag = (
                    headers[b"ETag"][0].decode("ascii") if b"ETag" in headers else None
                )
        else:
            # we can only get here if we did an oembed request and have an oembed_result.html
            assert oembed_result.html is not None
            assert oembed_url is not None

            html_bytes = oembed_result.html.encode("utf-8")
            with self.media_storage.store_into_file(file_info) as (f, fname, finish):
                f.write(html_bytes)
                await finish()

            media_type = "text/html"
            download_name = oembed_result.title
            length = len(html_bytes)
            # If a specific cache age was not given, assume 1 hour.
            expires = oembed_result.cache_age or ONE_HOUR
            uri = oembed_url
            code = 200
            etag = None

        try:
            time_now_ms = self.clock.time_msec()

            await self.store.store_local_media(
                media_id=file_id,
                media_type=media_type,
                time_now_ms=time_now_ms,
                upload_name=download_name,
                media_length=length,
                user_id=user,
                url_cache=url,
            )

        except Exception as e:
            logger.error("Error handling downloaded %s: %r", url, e)
            # TODO: we really ought to delete the downloaded file in this
            # case, since we won't have recorded it in the db, and will
            # therefore not expire it.
            raise

        return {
            "media_type": media_type,
            "media_length": length,
            "download_name": download_name,
            "created_ts": time_now_ms,
            "filesystem_id": file_id,
            "filename": fname,
            "uri": uri,
            "response_code": code,
            "expires": expires,
            "etag": etag,
        }

    def _start_expire_url_cache_data(self):
        return run_as_background_process(
            "expire_url_cache_data", self._expire_url_cache_data
        )

    async def _expire_url_cache_data(self):
        """Clean up expired url cache content, media and thumbnails.
        """
        # TODO: Delete from backup media store

        assert self._worker_run_media_background_jobs

        now = self.clock.time_msec()

        logger.debug("Running url preview cache expiry")

        if not (await self.store.db_pool.updates.has_completed_background_updates()):
            logger.info("Still running DB updates; skipping expiry")
            return

        # First we delete expired url cache entries
        media_ids = await self.store.get_expired_url_cache(now)

        removed_media = []
        for media_id in media_ids:
            fname = self.filepaths.url_cache_filepath(media_id)
            try:
                os.remove(fname)
            except OSError as e:
                # If the path doesn't exist, meh
                if e.errno != errno.ENOENT:
                    logger.warning("Failed to remove media: %r: %s", media_id, e)
                    continue

            removed_media.append(media_id)

            try:
                dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id)
                for dir in dirs:
                    os.rmdir(dir)
            except Exception:
                pass

        await self.store.delete_url_cache(removed_media)

        if removed_media:
            logger.info("Deleted %d entries from url cache", len(removed_media))
        else:
            logger.debug("No entries removed from url cache")

        # Now we delete old images associated with the url cache.
        # These may be cached for a bit on the client (i.e., they
        # may have a room open with a preview url thing open).
        # So we wait a couple of days before deleting, just in case.
        expire_before = now - 2 * 24 * ONE_HOUR
        media_ids = await self.store.get_url_cache_media_before(expire_before)

        removed_media = []
        for media_id in media_ids:
            fname = self.filepaths.url_cache_filepath(media_id)
            try:
                os.remove(fname)
            except OSError as e:
                # If the path doesn't exist, meh
                if e.errno != errno.ENOENT:
                    logger.warning("Failed to remove media: %r: %s", media_id, e)
                    continue

            try:
                dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id)
                for dir in dirs:
                    os.rmdir(dir)
            except Exception:
                pass

            thumbnail_dir = self.filepaths.url_cache_thumbnail_directory(media_id)
            try:
                shutil.rmtree(thumbnail_dir)
            except OSError as e:
                # If the path doesn't exist, meh
                if e.errno != errno.ENOENT:
                    logger.warning("Failed to remove media: %r: %s", media_id, e)
                    continue

            removed_media.append(media_id)

            try:
                dirs = self.filepaths.url_cache_thumbnail_dirs_to_delete(media_id)
                for dir in dirs:
                    os.rmdir(dir)
            except Exception:
                pass

        await self.store.delete_url_cache_media(removed_media)

        if removed_media:
            logger.info("Deleted %d media from url cache", len(removed_media))
        else:
            logger.debug("No media removed from url cache")
Example #44
0
class TransactionStore(SQLBaseStore):
    """A collection of queries for handling PDUs.
    """

    def __init__(self, db_conn, hs):
        super(TransactionStore, self).__init__(db_conn, hs)

        self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000)

        self._destination_retry_cache = ExpiringCache(
            cache_name="get_destination_retry_timings",
            clock=self._clock,
            expiry_ms=5 * 60 * 1000,
        )

    def get_received_txn_response(self, transaction_id, origin):
        """For an incoming transaction from a given origin, check if we have
        already responded to it. If so, return the response code and response
        body (as a dict).

        Args:
            transaction_id (str)
            origin(str)

        Returns:
            tuple: None if we have not previously responded to
            this transaction or a 2-tuple of (int, dict)
        """

        return self.runInteraction(
            "get_received_txn_response",
            self._get_received_txn_response,
            transaction_id,
            origin,
        )

    def _get_received_txn_response(self, txn, transaction_id, origin):
        result = self._simple_select_one_txn(
            txn,
            table="received_transactions",
            keyvalues={"transaction_id": transaction_id, "origin": origin},
            retcols=(
                "transaction_id",
                "origin",
                "ts",
                "response_code",
                "response_json",
                "has_been_referenced",
            ),
            allow_none=True,
        )

        if result and result["response_code"]:
            return result["response_code"], db_to_json(result["response_json"])

        else:
            return None

    def set_received_txn_response(self, transaction_id, origin, code, response_dict):
        """Persist the response we returened for an incoming transaction, and
        should return for subsequent transactions with the same transaction_id
        and origin.

        Args:
            transaction_id (str)
            origin (str)
            code (int)
            response_dict (dict)
        """

        return self._simple_insert(
            table="received_transactions",
            values={
                "transaction_id": transaction_id,
                "origin": origin,
                "response_code": code,
                "response_json": db_binary_type(encode_canonical_json(response_dict)),
                "ts": self._clock.time_msec(),
            },
            or_ignore=True,
            desc="set_received_txn_response",
        )

    def prep_send_transaction(self, transaction_id, destination, origin_server_ts):
        """Persists an outgoing transaction and calculates the values for the
        previous transaction id list.

        This should be called before sending the transaction so that it has the
        correct value for the `prev_ids` key.

        Args:
            transaction_id (str)
            destination (str)
            origin_server_ts (int)

        Returns:
            list: A list of previous transaction ids.
        """
        return defer.succeed([])

    def delivered_txn(self, transaction_id, destination, code, response_dict):
        """Persists the response for an outgoing transaction.

        Args:
            transaction_id (str)
            destination (str)
            code (int)
            response_json (str)
        """
        pass

    @defer.inlineCallbacks
    def get_destination_retry_timings(self, destination):
        """Gets the current retry timings (if any) for a given destination.

        Args:
            destination (str)

        Returns:
            None if not retrying
            Otherwise a dict for the retry scheme
        """

        result = self._destination_retry_cache.get(destination, SENTINEL)
        if result is not SENTINEL:
            defer.returnValue(result)

        result = yield self.runInteraction(
            "get_destination_retry_timings",
            self._get_destination_retry_timings,
            destination,
        )

        # We don't hugely care about race conditions between getting and
        # invalidating the cache, since we time out fairly quickly anyway.
        self._destination_retry_cache[destination] = result
        defer.returnValue(result)

    def _get_destination_retry_timings(self, txn, destination):
        result = self._simple_select_one_txn(
            txn,
            table="destinations",
            keyvalues={"destination": destination},
            retcols=("destination", "retry_last_ts", "retry_interval"),
            allow_none=True,
        )

        if result and result["retry_last_ts"] > 0:
            return result
        else:
            return None

    def set_destination_retry_timings(self, destination, retry_last_ts, retry_interval):
        """Sets the current retry timings for a given destination.
        Both timings should be zero if retrying is no longer occurring.

        Args:
            destination (str)
            retry_last_ts (int) - time of last retry attempt in unix epoch ms
            retry_interval (int) - how long until next retry in ms
        """

        self._destination_retry_cache.pop(destination, None)
        return self.runInteraction(
            "set_destination_retry_timings",
            self._set_destination_retry_timings,
            destination,
            retry_last_ts,
            retry_interval,
        )

    def _set_destination_retry_timings(
        self, txn, destination, retry_last_ts, retry_interval
    ):
        self.database_engine.lock_table(txn, "destinations")

        # We need to be careful here as the data may have changed from under us
        # due to a worker setting the timings.

        prev_row = self._simple_select_one_txn(
            txn,
            table="destinations",
            keyvalues={"destination": destination},
            retcols=("retry_last_ts", "retry_interval"),
            allow_none=True,
        )

        if not prev_row:
            self._simple_insert_txn(
                txn,
                table="destinations",
                values={
                    "destination": destination,
                    "retry_last_ts": retry_last_ts,
                    "retry_interval": retry_interval,
                },
            )
        elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval:
            self._simple_update_one_txn(
                txn,
                "destinations",
                keyvalues={"destination": destination},
                updatevalues={
                    "retry_last_ts": retry_last_ts,
                    "retry_interval": retry_interval,
                },
            )

    def get_destinations_needing_retry(self):
        """Get all destinations which are due a retry for sending a transaction.

        Returns:
            list: A list of dicts
        """

        return self.runInteraction(
            "get_destinations_needing_retry", self._get_destinations_needing_retry
        )

    def _get_destinations_needing_retry(self, txn):
        query = (
            "SELECT * FROM destinations"
            " WHERE retry_last_ts > 0 and retry_next_ts < ?"
        )

        txn.execute(query, (self._clock.time_msec(),))
        return self.cursor_to_dict(txn)

    def _start_cleanup_transactions(self):
        return run_as_background_process(
            "cleanup_transactions", self._cleanup_transactions
        )

    def _cleanup_transactions(self):
        now = self._clock.time_msec()
        month_ago = now - 30 * 24 * 60 * 60 * 1000

        def _cleanup_transactions_txn(txn):
            txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))

        return self.runInteraction("_cleanup_transactions", _cleanup_transactions_txn)
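
As a side note on get_destination_retry_timings above: it is a read-through cache in which a sentinel distinguishes "never looked up" from a legitimately cached None, and writers invalidate the entry before touching the DB. A stripped-down sketch of that pattern, using a plain dict instead of ExpiringCache and a hypothetical db_lookup, could look like:

SENTINEL = object()
_retry_cache = {}  # destination -> retry timings dict, or None


def db_lookup(destination):
    # Hypothetical stand-in for the real database query.
    return None  # e.g. no retry timings recorded for this destination


def get_destination_retry_timings(destination):
    result = _retry_cache.get(destination, SENTINEL)
    if result is not SENTINEL:
        return result  # may legitimately be a cached None
    result = db_lookup(destination)
    _retry_cache[destination] = result
    return result


def set_destination_retry_timings(destination, retry_last_ts, retry_interval):
    # Invalidate first so readers fall back to the DB rather than a stale entry.
    _retry_cache.pop(destination, None)
    # ... persist retry_last_ts / retry_interval to the DB here ...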
Example #45
0
class PreviewUrlResource(Resource):
    isLeaf = True

    def __init__(self, hs, media_repo):
        Resource.__init__(self)

        self.auth = hs.get_auth()
        self.clock = hs.get_clock()
        self.version_string = hs.version_string
        self.filepaths = media_repo.filepaths
        self.max_spider_size = hs.config.max_spider_size
        self.server_name = hs.hostname
        self.store = hs.get_datastore()
        self.client = SpiderHttpClient(hs)
        self.media_repo = media_repo

        self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist

        # simple memory cache mapping urls to OG metadata
        self.cache = ExpiringCache(
            cache_name="url_previews",
            clock=self.clock,
            # don't spider URLs more often than once an hour
            expiry_ms=60 * 60 * 1000,
        )
        self.cache.start()

        self.downloads = {}

    def render_GET(self, request):
        self._async_render_GET(request)
        return NOT_DONE_YET

    @request_handler()
    @defer.inlineCallbacks
    def _async_render_GET(self, request):

        # XXX: if get_user_by_req fails, what should we do in an async render?
        requester = yield self.auth.get_user_by_req(request)
        url = request.args.get("url")[0]
        if "ts" in request.args:
            ts = int(request.args.get("ts")[0])
        else:
            ts = self.clock.time_msec()

        url_tuple = urlparse.urlsplit(url)
        for entry in self.url_preview_url_blacklist:
            match = True
            for attrib in entry:
                pattern = entry[attrib]
                value = getattr(url_tuple, attrib)
                logger.debug(("Matching attrib '%s' with value '%s' against"
                              " pattern '%s'") % (attrib, value, pattern))

                if value is None:
                    match = False
                    continue

                if pattern.startswith('^'):
                    if not re.match(pattern, getattr(url_tuple, attrib)):
                        match = False
                        continue
                else:
                    if not fnmatch.fnmatch(getattr(url_tuple, attrib),
                                           pattern):
                        match = False
                        continue
            if match:
                logger.warn("URL %s blocked by url_blacklist entry %s", url,
                            entry)
                raise SynapseError(
                    403, "URL blocked by url pattern blacklist entry",
                    Codes.UNKNOWN)

        # first check the memory cache - good to handle all the clients on this
        # HS thundering away to preview the same URL at the same time.
        og = self.cache.get(url)
        if og:
            respond_with_json_bytes(request,
                                    200,
                                    json.dumps(og),
                                    send_cors=True)
            return

        # then check the URL cache in the DB (which will also provide us with
        # historical previews, if we have any)
        cache_result = yield self.store.get_url_cache(url, ts)
        if (cache_result
                and cache_result["download_ts"] + cache_result["expires"] > ts
                and cache_result["response_code"] / 100 == 2):
            respond_with_json_bytes(request,
                                    200,
                                    cache_result["og"].encode('utf-8'),
                                    send_cors=True)
            return

        # Ensure only one download for a given URL is active at a time
        download = self.downloads.get(url)
        if download is None:
            download = self._download_url(url, requester.user)
            download = ObservableDeferred(download, consumeErrors=True)
            self.downloads[url] = download

            @download.addBoth
            def callback(media_info):
                del self.downloads[url]
                return media_info

        media_info = yield download.observe()

        # FIXME: we should probably update our cache now anyway, so that
        # even if the OG calculation raises, we don't keep hammering on the
        # remote server.  For now, leave it uncached to aid debugging OG
        # calculation problems

        logger.debug("got media_info of '%s'" % media_info)

        if _is_media(media_info['media_type']):
            dims = yield self.media_repo._generate_local_thumbnails(
                media_info['filesystem_id'],
                media_info,
                url_cache=True,
            )

            og = {
                "og:description": media_info['download_name'],
                "og:image": "mxc://%s/%s" % (
                    self.server_name, media_info['filesystem_id']
                ),
                "og:image:type": media_info['media_type'],
                "matrix:image:size": media_info['media_length'],
            }

            if dims:
                og["og:image:width"] = dims['width']
                og["og:image:height"] = dims['height']
            else:
                logger.warn("Couldn't get dims for %s" % url)

            # define our OG response for this media
        elif _is_html(media_info['media_type']):
            # TODO: somehow stop a big HTML tree from exploding synapse's RAM

            file = open(media_info['filename'])
            body = file.read()
            file.close()

            # clobber the encoding from the content-type, or default to utf-8
            # XXX: this overrides any <meta/> or XML charset headers in the body
            # which may pose problems, but so far seems to work okay.
            match = re.match(r'.*; *charset=(.*?)(;|$)',
                             media_info['media_type'], re.I)
            encoding = match.group(1) if match else "utf-8"

            og = decode_and_calc_og(body, media_info['uri'], encoding)

            # pre-cache the image for posterity
            # FIXME: it might be cleaner to use the same flow as the main /preview_url
            # request itself and benefit from the same caching etc.  But for now we
            # just rely on the caching on the master request to speed things up.
            if 'og:image' in og and og['og:image']:
                image_info = yield self._download_url(
                    _rebase_url(og['og:image'], media_info['uri']),
                    requester.user)

                if _is_media(image_info['media_type']):
                    # TODO: make sure we don't choke on white-on-transparent images
                    dims = yield self.media_repo._generate_local_thumbnails(
                        image_info['filesystem_id'],
                        image_info,
                        url_cache=True,
                    )
                    if dims:
                        og["og:image:width"] = dims['width']
                        og["og:image:height"] = dims['height']
                    else:
                        logger.warn("Couldn't get dims for %s" %
                                    og["og:image"])

                    og["og:image"] = "mxc://%s/%s" % (
                        self.server_name, image_info['filesystem_id'])
                    og["og:image:type"] = image_info['media_type']
                    og["matrix:image:size"] = image_info['media_length']
                else:
                    del og["og:image"]
        else:
            logger.warn("Failed to find any OG data in %s", url)
            og = {}

        logger.debug("Calculated OG for %s as %s" % (url, og))

        # store OG in ephemeral in-memory cache
        self.cache[url] = og

        # store OG in history-aware DB cache
        yield self.store.store_url_cache(
            url,
            media_info["response_code"],
            media_info["etag"],
            media_info["expires"],
            json.dumps(og),
            media_info["filesystem_id"],
            media_info["created_ts"],
        )

        respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)

    @defer.inlineCallbacks
    def _download_url(self, url, user):
        # TODO: we should probably honour robots.txt... except in practice
        # we're most likely being explicitly triggered by a human rather than a
        # bot, so are we really a robot?

        # XXX: horrible duplication with base_resource's _download_remote_file()
        file_id = random_string(24)

        fname = self.filepaths.url_cache_filepath(file_id)
        self.media_repo._makedirs(fname)

        try:
            with open(fname, "wb") as f:
                logger.debug("Trying to get url '%s'" % url)
                length, headers, uri, code = yield self.client.get_file(
                    url,
                    output_stream=f,
                    max_size=self.max_spider_size,
                )
                # FIXME: pass through 404s and other error messages nicely

            media_type = headers["Content-Type"][0]
            time_now_ms = self.clock.time_msec()

            content_disposition = headers.get("Content-Disposition", None)
            if content_disposition:
                _, params = cgi.parse_header(content_disposition[0], )
                download_name = None

                # First check if there is a valid UTF-8 filename
                download_name_utf8 = params.get("filename*", None)
                if download_name_utf8:
                    if download_name_utf8.lower().startswith("utf-8''"):
                        download_name = download_name_utf8[7:]

                # If there isn't check for an ascii name.
                if not download_name:
                    download_name_ascii = params.get("filename", None)
                    if download_name_ascii and is_ascii(download_name_ascii):
                        download_name = download_name_ascii

                if download_name:
                    download_name = urlparse.unquote(download_name)
                    try:
                        download_name = download_name.decode("utf-8")
                    except UnicodeDecodeError:
                        download_name = None
            else:
                download_name = None

            yield self.store.store_local_media(
                media_id=file_id,
                media_type=media_type,
                time_now_ms=self.clock.time_msec(),
                upload_name=download_name,
                media_length=length,
                user_id=user,
                url_cache=url,
            )

        except Exception as e:
            os.remove(fname)
            raise SynapseError(500, ("Failed to download content: %s" % e),
                               Codes.UNKNOWN)

        defer.returnValue({
            "media_type": media_type,
            "media_length": length,
            "download_name": download_name,
            "created_ts": time_now_ms,
            "filesystem_id": file_id,
            "filename": fname,
            "uri": uri,
            "response_code": code,
            # FIXME: we should calculate a proper expiration based on the
            # Cache-Control and Expire headers.  But for now, assume 1 hour.
            "expires": 60 * 60 * 1000,
            "etag": headers["ETag"][0] if "ETag" in headers else None,
        })
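
Both PreviewUrlResource variants above apply url_preview_url_blacklist the same way: each entry maps urlsplit() attributes to patterns, "^"-prefixed patterns are treated as regexes, anything else as a glob, and a URL is blocked only if every attribute named by an entry matches. A self-contained sketch of that matching, with hypothetical blacklist entries:

import fnmatch
import re
from urllib.parse import urlsplit

blacklist = [
    {"netloc": "*.internal.example"},           # hypothetical entry: glob on the host
    {"scheme": "http", "path": "^/admin/.*"},   # hypothetical entry: regex on the path
]


def is_blocked(url):
    parts = urlsplit(url)
    for entry in blacklist:
        if all(
            getattr(parts, attrib) is not None
            and (
                re.match(pattern, getattr(parts, attrib))
                if pattern.startswith("^")
                else fnmatch.fnmatch(getattr(parts, attrib), pattern)
            )
            for attrib, pattern in entry.items()
        ):
            return True
    return False


print(is_blocked("http://host.internal.example/page"))  # True: matches the first entry
print(is_blocked("https://matrix.org/blog"))            # False: no entry fully matches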
Example #46
0
class DeviceInboxStore(BackgroundUpdateStore):
    DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"

    def __init__(self, db_conn, hs):
        super(DeviceInboxStore, self).__init__(db_conn, hs)

        self.register_background_index_update(
            "device_inbox_stream_index",
            index_name="device_inbox_stream_id_user_id",
            table="device_inbox",
            columns=["stream_id", "user_id"],
        )

        self.register_background_update_handler(
            self.DEVICE_INBOX_STREAM_ID,
            self._background_drop_index_device_inbox,
        )

        # Map of (user_id, device_id) to the last stream_id that has been
        # deleted up to. This is so that we can no op deletions.
        self._last_device_delete_cache = ExpiringCache(
            cache_name="last_device_delete_cache",
            clock=self._clock,
            max_len=10000,
            expiry_ms=30 * 60 * 1000,
        )

    @defer.inlineCallbacks
    def add_messages_to_device_inbox(self, local_messages_by_user_then_device,
                                     remote_messages_by_destination):
        """Used to send messages from this server.

        Args:
            local_messages_by_user_then_device(dict):
                Dictionary of user_id to device_id to message.
            remote_messages_by_destination(dict):
                Dictionary of destination server_name to the EDU JSON to send.
        Returns:
            A deferred stream_id that resolves when the messages have been
            inserted.
        """

        def add_messages_txn(txn, now_ms, stream_id):
            # Add the local messages directly to the local inbox.
            self._add_messages_to_local_device_inbox_txn(
                txn, stream_id, local_messages_by_user_then_device
            )

            # Add the remote messages to the federation outbox.
            # We'll send them to a remote server when we next send a
            # federation transaction to that destination.
            sql = (
                "INSERT INTO device_federation_outbox"
                " (destination, stream_id, queued_ts, messages_json)"
                " VALUES (?,?,?,?)"
            )
            rows = []
            for destination, edu in remote_messages_by_destination.items():
                edu_json = json.dumps(edu)
                rows.append((destination, stream_id, now_ms, edu_json))
            txn.executemany(sql, rows)

        with self._device_inbox_id_gen.get_next() as stream_id:
            now_ms = self.clock.time_msec()
            yield self.runInteraction(
                "add_messages_to_device_inbox",
                add_messages_txn,
                now_ms,
                stream_id,
            )
            for user_id in local_messages_by_user_then_device.keys():
                self._device_inbox_stream_cache.entity_has_changed(
                    user_id, stream_id
                )
            for destination in remote_messages_by_destination.keys():
                self._device_federation_outbox_stream_cache.entity_has_changed(
                    destination, stream_id
                )

        defer.returnValue(self._device_inbox_id_gen.get_current_token())

    @defer.inlineCallbacks
    def add_messages_from_remote_to_device_inbox(
        self, origin, message_id, local_messages_by_user_then_device
    ):
        def add_messages_txn(txn, now_ms, stream_id):
            # Check if we've already inserted a matching message_id for that
            # origin. This can happen if the origin doesn't receive our
            # acknowledgement from the first time we received the message.
            already_inserted = self._simple_select_one_txn(
                txn, table="device_federation_inbox",
                keyvalues={"origin": origin, "message_id": message_id},
                retcols=("message_id",),
                allow_none=True,
            )
            if already_inserted is not None:
                return

            # Add an entry for this message_id so that we know we've processed
            # it.
            self._simple_insert_txn(
                txn, table="device_federation_inbox",
                values={
                    "origin": origin,
                    "message_id": message_id,
                    "received_ts": now_ms,
                },
            )

            # Add the messages to the appropriate local device inboxes so that
            # they'll be sent to the devices when they next sync.
            self._add_messages_to_local_device_inbox_txn(
                txn, stream_id, local_messages_by_user_then_device
            )

        with self._device_inbox_id_gen.get_next() as stream_id:
            now_ms = self.clock.time_msec()
            yield self.runInteraction(
                "add_messages_from_remote_to_device_inbox",
                add_messages_txn,
                now_ms,
                stream_id,
            )
            for user_id in local_messages_by_user_then_device.keys():
                self._device_inbox_stream_cache.entity_has_changed(
                    user_id, stream_id
                )

        defer.returnValue(stream_id)

    def _add_messages_to_local_device_inbox_txn(self, txn, stream_id,
                                                messages_by_user_then_device):
        sql = (
            "UPDATE device_max_stream_id"
            " SET stream_id = ?"
            " WHERE stream_id < ?"
        )
        txn.execute(sql, (stream_id, stream_id))

        local_by_user_then_device = {}
        for user_id, messages_by_device in messages_by_user_then_device.items():
            messages_json_for_user = {}
            devices = list(messages_by_device.keys())
            if len(devices) == 1 and devices[0] == "*":
                # Handle wildcard device_ids.
                sql = (
                    "SELECT device_id FROM devices"
                    " WHERE user_id = ?"
                )
                txn.execute(sql, (user_id,))
                message_json = json.dumps(messages_by_device["*"])
                for row in txn:
                    # Add the message for all devices for this user on this
                    # server.
                    device = row[0]
                    messages_json_for_user[device] = message_json
            else:
                if not devices:
                    continue
                sql = (
                    "SELECT device_id FROM devices"
                    " WHERE user_id = ? AND device_id IN ("
                    + ",".join("?" * len(devices))
                    + ")"
                )
                # TODO: Maybe this needs to be done in batches if there are
                # too many local devices for a given user.
                txn.execute(sql, [user_id] + devices)
                for row in txn:
                    # Only insert into the local inbox if the device exists on
                    # this server
                    device = row[0]
                    message_json = json.dumps(messages_by_device[device])
                    messages_json_for_user[device] = message_json

            if messages_json_for_user:
                local_by_user_then_device[user_id] = messages_json_for_user

        if not local_by_user_then_device:
            return

        sql = (
            "INSERT INTO device_inbox"
            " (user_id, device_id, stream_id, message_json)"
            " VALUES (?,?,?,?)"
        )
        rows = []
        for user_id, messages_by_device in local_by_user_then_device.items():
            for device_id, message_json in messages_by_device.items():
                rows.append((user_id, device_id, stream_id, message_json))

        txn.executemany(sql, rows)

    def get_new_messages_for_device(
        self, user_id, device_id, last_stream_id, current_stream_id, limit=100
    ):
        """
        Args:
            user_id(str): The recipient user_id.
            device_id(str): The recipient device_id.
            current_stream_id(int): The current position of the to device
                message stream.
        Returns:
            Deferred ([dict], int): List of messages for the device and where
                in the stream the messages got to.
        """
        has_changed = self._device_inbox_stream_cache.has_entity_changed(
            user_id, last_stream_id
        )
        if not has_changed:
            return defer.succeed(([], current_stream_id))

        def get_new_messages_for_device_txn(txn):
            sql = (
                "SELECT stream_id, message_json FROM device_inbox"
                " WHERE user_id = ? AND device_id = ?"
                " AND ? < stream_id AND stream_id <= ?"
                " ORDER BY stream_id ASC"
                " LIMIT ?"
            )
            txn.execute(sql, (
                user_id, device_id, last_stream_id, current_stream_id, limit
            ))
            messages = []
            for row in txn:
                stream_pos = row[0]
                messages.append(json.loads(row[1]))
            if len(messages) < limit:
                stream_pos = current_stream_id
            return (messages, stream_pos)

        return self.runInteraction(
            "get_new_messages_for_device", get_new_messages_for_device_txn,
        )

    @defer.inlineCallbacks
    def delete_messages_for_device(self, user_id, device_id, up_to_stream_id):
        """
        Args:
            user_id(str): The recipient user_id.
            device_id(str): The recipient device_id.
            up_to_stream_id(int): Where to delete messages up to.
        Returns:
            A deferred that resolves to the number of messages deleted.
        """
        # If we have cached the last stream id we've deleted up to, we can
        # check if there is likely to be anything that needs deleting
        last_deleted_stream_id = self._last_device_delete_cache.get(
            (user_id, device_id), None
        )
        if last_deleted_stream_id:
            has_changed = self._device_inbox_stream_cache.has_entity_changed(
                user_id, last_deleted_stream_id
            )
            if not has_changed:
                defer.returnValue(0)

        def delete_messages_for_device_txn(txn):
            sql = (
                "DELETE FROM device_inbox"
                " WHERE user_id = ? AND device_id = ?"
                " AND stream_id <= ?"
            )
            txn.execute(sql, (user_id, device_id, up_to_stream_id))
            return txn.rowcount

        count = yield self.runInteraction(
            "delete_messages_for_device", delete_messages_for_device_txn
        )

        # Update the cache, ensuring that we only ever increase the value
        last_deleted_stream_id = self._last_device_delete_cache.get(
            (user_id, device_id), 0
        )
        self._last_device_delete_cache[(user_id, device_id)] = max(
            last_deleted_stream_id, up_to_stream_id
        )

        defer.returnValue(count)

    def get_all_new_device_messages(self, last_pos, current_pos, limit):
        """
        Args:
            last_pos(int):
            current_pos(int):
            limit(int):
        Returns:
            A deferred list of rows from the device inbox
        """
        if last_pos == current_pos:
            return defer.succeed([])

        def get_all_new_device_messages_txn(txn):
            # We limit like this as we might have multiple rows per stream_id, and
            # we want to make sure we always get all entries for any stream_id
            # we return.
            upper_pos = min(current_pos, last_pos + limit)
            sql = (
                "SELECT max(stream_id), user_id"
                " FROM device_inbox"
                " WHERE ? < stream_id AND stream_id <= ?"
                " GROUP BY user_id"
            )
            txn.execute(sql, (last_pos, upper_pos))
            rows = txn.fetchall()

            sql = (
                "SELECT max(stream_id), destination"
                " FROM device_federation_outbox"
                " WHERE ? < stream_id AND stream_id <= ?"
                " GROUP BY destination"
            )
            txn.execute(sql, (last_pos, upper_pos))
            rows.extend(txn)

            # Order by ascending stream ordering
            rows.sort()

            return rows

        return self.runInteraction(
            "get_all_new_device_messages", get_all_new_device_messages_txn
        )

    def get_to_device_stream_token(self):
        return self._device_inbox_id_gen.get_current_token()

    def get_new_device_msgs_for_remote(
        self, destination, last_stream_id, current_stream_id, limit=100
    ):
        """
        Args:
            destination(str): The name of the remote server.
            last_stream_id(int|long): The last position of the device message stream
                that the server sent up to.
            current_stream_id(int|long): The current position of the device
                message stream.
        Returns:
            Deferred ([dict], int|long): List of messages for the device and where
                in the stream the messages got to.
        """

        has_changed = self._device_federation_outbox_stream_cache.has_entity_changed(
            destination, last_stream_id
        )
        if not has_changed or last_stream_id == current_stream_id:
            return defer.succeed(([], current_stream_id))

        def get_new_messages_for_remote_destination_txn(txn):
            sql = (
                "SELECT stream_id, messages_json FROM device_federation_outbox"
                " WHERE destination = ?"
                " AND ? < stream_id AND stream_id <= ?"
                " ORDER BY stream_id ASC"
                " LIMIT ?"
            )
            txn.execute(sql, (
                destination, last_stream_id, current_stream_id, limit
            ))
            messages = []
            for row in txn:
                stream_pos = row[0]
                messages.append(json.loads(row[1]))
            if len(messages) < limit:
                stream_pos = current_stream_id
            return (messages, stream_pos)

        return self.runInteraction(
            "get_new_device_msgs_for_remote",
            get_new_messages_for_remote_destination_txn,
        )

    def delete_device_msgs_for_remote(self, destination, up_to_stream_id):
        """Used to delete messages when the remote destination acknowledges
        their receipt.

        Args:
            destination(str): The destination server_name
            up_to_stream_id(int): Where to delete messages up to.
        Returns:
            A deferred that resolves when the messages have been deleted.
        """
        def delete_messages_for_remote_destination_txn(txn):
            sql = (
                "DELETE FROM device_federation_outbox"
                " WHERE destination = ?"
                " AND stream_id <= ?"
            )
            txn.execute(sql, (destination, up_to_stream_id))

        return self.runInteraction(
            "delete_device_msgs_for_remote",
            delete_messages_for_remote_destination_txn
        )

    @defer.inlineCallbacks
    def _background_drop_index_device_inbox(self, progress, batch_size):
        def reindex_txn(conn):
            txn = conn.cursor()
            txn.execute(
                "DROP INDEX IF EXISTS device_inbox_stream_id"
            )
            txn.close()

        yield self.runWithConnection(reindex_txn)

        yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID)

        defer.returnValue(1)
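
The device_federation_inbox check in add_messages_from_remote_to_device_inbox above is an idempotency guard: a remote server that never saw our acknowledgement may redeliver the same (origin, message_id), and the second delivery must be a no-op. A minimal in-memory sketch of that guard, with a set and a list standing in for the two tables:

_seen_message_ids = set()  # stands in for the device_federation_inbox table
_device_inbox = []         # stands in for the device_inbox table


def add_messages_from_remote(origin, message_id, messages):
    if (origin, message_id) in _seen_message_ids:
        # Already processed; the remote just missed our acknowledgement.
        return
    _seen_message_ids.add((origin, message_id))
    _device_inbox.extend(messages)


add_messages_from_remote("remote.example", "m1", ["hello"])
add_messages_from_remote("remote.example", "m1", ["hello"])  # redelivery, ignored
print(_device_inbox)  # ['hello']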
Example #47
0
class DeviceListUpdater:
    "Handles incoming device list updates from federation and updates the DB"

    def __init__(self, hs, device_handler):
        self.store = hs.get_datastore()
        self.federation = hs.get_federation_client()
        self.clock = hs.get_clock()
        self.device_handler = device_handler

        self._remote_edu_linearizer = Linearizer(name="remote_device_list")

        # user_id -> list of updates waiting to be handled.
        self._pending_updates = {}

        # Recently seen stream ids. We don't bother keeping these in the DB,
        # but they're useful to have them about to reduce the number of spurious
        # resyncs.
        self._seen_updates = ExpiringCache(
            cache_name="device_update_edu",
            clock=self.clock,
            max_len=10000,
            expiry_ms=30 * 60 * 1000,
            iterable=True,
        )

        # Attempt to resync out of sync device lists every 30s.
        self._resync_retry_in_progress = False
        self.clock.looping_call(
            run_as_background_process,
            30 * 1000,
            func=self._maybe_retry_device_resync,
            desc="_maybe_retry_device_resync",
        )

    @trace
    async def incoming_device_list_update(self, origin, edu_content):
        """Called on incoming device list update from federation. Responsible
        for parsing the EDU and adding to pending updates list.
        """

        set_tag("origin", origin)
        set_tag("edu_content", edu_content)
        user_id = edu_content.pop("user_id")
        device_id = edu_content.pop("device_id")
        stream_id = str(edu_content.pop("stream_id"))  # They may come as ints
        prev_ids = edu_content.pop("prev_id", [])
        prev_ids = [str(p) for p in prev_ids]  # They may come as ints

        if get_domain_from_id(user_id) != origin:
            # TODO: Raise?
            logger.warning(
                "Got device list update edu for %r/%r from %r",
                user_id,
                device_id,
                origin,
            )

            set_tag("error", True)
            log_kv({
                "message": "Got a device list update edu from a user and "
                "device which does not match the origin of the request.",
                "user_id": user_id,
                "device_id": device_id,
            })
            return

        room_ids = await self.store.get_rooms_for_user(user_id)
        if not room_ids:
            # We don't share any rooms with this user. Ignore update, as we
            # probably won't get any further updates.
            set_tag("error", True)
            log_kv({
                "message": "Got an update from a user for which "
                "we don't share any rooms",
                "other user_id": user_id,
            })
            logger.warning(
                "Got device list update edu for %r/%r, but don't share a room",
                user_id,
                device_id,
            )
            return

        logger.debug("Received device list update for %r/%r", user_id,
                     device_id)

        self._pending_updates.setdefault(user_id, []).append(
            (device_id, stream_id, prev_ids, edu_content))

        await self._handle_device_updates(user_id)

    @measure_func("_incoming_device_list_update")
    async def _handle_device_updates(self, user_id):
        "Actually handle pending updates."

        with (await self._remote_edu_linearizer.queue(user_id)):
            pending_updates = self._pending_updates.pop(user_id, [])
            if not pending_updates:
                # This can happen since we batch updates
                return

            for device_id, stream_id, prev_ids, content in pending_updates:
                logger.debug(
                    "Handling update %r/%r, ID: %r, prev: %r ",
                    user_id,
                    device_id,
                    stream_id,
                    prev_ids,
                )

            # Given a list of updates we check if we need to resync. This
            # happens if we've missed updates.
            resync = await self._need_to_do_resync(user_id, pending_updates)

            if logger.isEnabledFor(logging.INFO):
                logger.info(
                    "Received device list update for %s, requiring resync: %s. Devices: %s",
                    user_id,
                    resync,
                    ", ".join(u[0] for u in pending_updates),
                )

            if resync:
                await self.user_device_resync(user_id)
            else:
                # Simply update the single device, since we know that is the only
                # change (because of the single prev_id matching the current cache)
                for device_id, stream_id, prev_ids, content in pending_updates:
                    await self.store.update_remote_device_list_cache_entry(
                        user_id, device_id, content, stream_id)

                await self.device_handler.notify_device_update(
                    user_id,
                    [device_id for device_id, _, _, _ in pending_updates])

                self._seen_updates.setdefault(user_id, set()).update(
                    stream_id for _, stream_id, _, _ in pending_updates)

    async def _need_to_do_resync(self, user_id, updates):
        """Given a list of updates for a user figure out if we need to do a full
        resync, or whether we have enough data that we can just apply the delta.
        """
        seen_updates = self._seen_updates.get(user_id, set())

        extremity = await self.store.get_device_list_last_stream_id_for_remote(
            user_id)

        logger.debug("Current extremity for %r: %r", user_id, extremity)

        stream_id_in_updates = set()  # stream_ids in updates list
        for _, stream_id, prev_ids, _ in updates:
            if not prev_ids:
                # We always do a resync if there are no previous IDs
                return True

            for prev_id in prev_ids:
                if prev_id == extremity:
                    continue
                elif prev_id in seen_updates:
                    continue
                elif prev_id in stream_id_in_updates:
                    continue
                else:
                    return True

            stream_id_in_updates.add(stream_id)

        return False

    @trace
    async def _maybe_retry_device_resync(self):
        """Retry to resync device lists that are out of sync, except if another retry is
        in progress.
        """
        if self._resync_retry_in_progress:
            return

        try:
            # Prevent another call of this function from retrying device list
            # resyncs while this one runs, so we don't send too many requests.
            self._resync_retry_in_progress = True
            # Get all of the users that need resyncing.
            need_resync = await self.store.get_user_ids_requiring_device_list_resync(
            )
            # Iterate over the set of user IDs.
            for user_id in need_resync:
                try:
                    # Try to resync the current user's devices list.
                    result = await self.user_device_resync(
                        user_id=user_id,
                        mark_failed_as_stale=False,
                    )

                    # user_device_resync only returns a result if it managed to
                    # successfully resync and update the database. Updating the table
                    # of users requiring resync isn't necessary here as
                    # user_device_resync already does it (through
                    # self.store.update_remote_device_list_cache).
                    if result:
                        logger.debug(
                            "Successfully resynced the device list for %s",
                            user_id,
                        )
                except Exception as e:
                    # If there was an issue resyncing this user, e.g. if the remote
                    # server sent a malformed result, just log the error instead of
                    # aborting all the subsequent resyncs.
                    logger.debug(
                        "Could not resync the device list for %s: %s",
                        user_id,
                        e,
                    )
        finally:
            # Allow future calls to retry resyncing out of sync device lists.
            self._resync_retry_in_progress = False

    async def user_device_resync(
            self,
            user_id: str,
            mark_failed_as_stale: bool = True) -> Optional[dict]:
        """Fetches all devices for a user and updates the device cache with them.

        Args:
            user_id: The user's id whose device_list will be updated.
            mark_failed_as_stale: Whether to mark the user's device list as stale
                if the attempt to resync failed.
        Returns:
            A dict with device info, as under the "devices" key in the result of
            this request:
            https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
        """
        logger.debug("Attempting to resync the device list for %s", user_id)
        log_kv({"message": "Doing resync to update device list."})
        # Fetch all devices for the user.
        origin = get_domain_from_id(user_id)
        try:
            result = await self.federation.query_user_devices(origin, user_id)
        except NotRetryingDestination:
            if mark_failed_as_stale:
                # Mark the remote user's device list as stale so we know we need to retry
                # it later.
                await self.store.mark_remote_user_device_cache_as_stale(user_id)

            return
        except (RequestSendFailed, HttpResponseException) as e:
            logger.warning(
                "Failed to handle device list update for %s: %s",
                user_id,
                e,
            )

            if mark_failed_as_stale:
                # Mark the remote user's device list as stale so we know we need to retry
                # it later.
                await self.store.mark_remote_user_device_cache_as_stale(user_id)

            # We abort on exceptions rather than accepting the update
            # as otherwise synapse will 'forget' that its device list
            # is out of date. If we bail then we will retry the resync
            # next time we get a device list update for this user_id.
            # This makes it more likely that the device lists will
            # eventually become consistent.
            return
        except FederationDeniedError as e:
            set_tag("error", True)
            log_kv({"reason": "FederationDeniedError"})
            logger.info(e)
            return
        except Exception as e:
            set_tag("error", True)
            log_kv({
                "message": "Exception raised by federation request",
                "exception": e
            })
            logger.exception("Failed to handle device list update for %s",
                             user_id)

            if mark_failed_as_stale:
                # Mark the remote user's device list as stale so we know we need to retry
                # it later.
                await self.store.mark_remote_user_device_cache_as_stale(user_id)

            return
        log_kv({"result": result})
        stream_id = result["stream_id"]
        devices = result["devices"]

        # Get the master key and the self-signing key for this user if provided in the
        # response (None if not in the response).
        # The response will not contain the user signing key, as this key is only used by
        # its owner, so it doesn't make sense to send it over federation.
        master_key = result.get("master_key")
        self_signing_key = result.get("self_signing_key")

        # If the remote server has more than ~1000 devices for this user
        # we assume that something is going horribly wrong (e.g. a bot
        # that logs in and creates a new device every time it tries to
        # send a message).  Maintaining lots of devices per user in the
        # cache can cause serious performance issues as if this request
        # takes more than 60s to complete, internal replication from the
        # inbound federation worker to the synapse master may time out
        # causing the inbound federation to fail and causing the remote
        # server to retry, causing a DoS.  So in this scenario we give
        # up on storing the total list of devices and only handle the
        # delta instead.
        if len(devices) > 1000:
            logger.warning(
                "Ignoring device list snapshot for %s as it has >1K devs (%d)",
                user_id,
                len(devices),
            )
            devices = []

        for device in devices:
            logger.debug(
                "Handling resync update %r/%r, ID: %r",
                user_id,
                device["device_id"],
                stream_id,
            )

        await self.store.update_remote_device_list_cache(
            user_id, devices, stream_id)
        device_ids = [device["device_id"] for device in devices]

        # Handle cross-signing keys.
        cross_signing_device_ids = await self.process_cross_signing_key_update(
            user_id,
            master_key,
            self_signing_key,
        )
        device_ids = device_ids + cross_signing_device_ids

        await self.device_handler.notify_device_update(user_id, device_ids)

        # We clobber the seen updates since we've re-synced from a given
        # point.
        self._seen_updates[user_id] = {stream_id}

        return result

    async def process_cross_signing_key_update(
        self,
        user_id: str,
        master_key: Optional[Dict[str, Any]],
        self_signing_key: Optional[Dict[str, Any]],
    ) -> list:
        """Process the given new master and self-signing key for the given remote user.

        Args:
            user_id: The ID of the user these keys are for.
            master_key: The dict of the cross-signing master key as returned by the
                remote server.
            self_signing_key: The dict of the cross-signing self-signing key as returned
                by the remote server.

        Return:
            The device IDs for the given keys.
        """
        device_ids = []

        if master_key:
            await self.store.set_e2e_cross_signing_key(user_id, "master",
                                                       master_key)
            _, verify_key = get_verify_key_from_cross_signing_key(master_key)
            # verify_key is a VerifyKey from signedjson, which uses
            # .version to denote the portion of the key ID after the
            # algorithm and colon, which is the device ID
            device_ids.append(verify_key.version)
        if self_signing_key:
            await self.store.set_e2e_cross_signing_key(user_id, "self_signing",
                                                       self_signing_key)
            _, verify_key = get_verify_key_from_cross_signing_key(
                self_signing_key)
            device_ids.append(verify_key.version)

        return device_ids
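
DeviceListUpdater._need_to_do_resync above applies a batch of updates as a delta only when every prev_id is already accounted for, either as the current extremity, a recently seen stream_id, or an earlier update in the same batch. A minimal standalone sketch of that decision follows; the helper name and sample data are illustrative, not Synapse code.

def need_resync(extremity, seen_updates, updates):
    """updates: iterable of (device_id, stream_id, prev_ids, content) tuples."""
    stream_ids_in_batch = set()
    for _, stream_id, prev_ids, _ in updates:
        if not prev_ids:
            # No prev_ids at all: continuity cannot be checked, so do a full resync.
            return True
        for prev_id in prev_ids:
            if not (prev_id == extremity
                    or prev_id in seen_updates
                    or prev_id in stream_ids_in_batch):
                # A referenced update we have never seen: there is a gap.
                return True
        stream_ids_in_batch.add(stream_id)
    return False

# A contiguous chain on top of extremity "1" can be applied as a delta...
assert need_resync("1", set(), [("DEV", "2", ["1"], {}), ("DEV", "3", ["2"], {})]) is False
# ...but an update whose prev_id "5" was never seen forces a full resync.
assert need_resync("1", set(), [("DEV", "6", ["5"], {})]) is True
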
Example #48
0
class FederationClient(FederationBase):
    def __init__(self, hs):
        super(FederationClient, self).__init__(hs)

        self.pdu_destination_tried = {}
        self._clock.looping_call(
            self._clear_tried_cache, 60 * 1000,
        )
        self.state = hs.get_state_handler()
        self.transport_layer = hs.get_federation_transport_client()

        self._get_pdu_cache = ExpiringCache(
            cache_name="get_pdu_cache",
            clock=self._clock,
            max_len=1000,
            expiry_ms=120 * 1000,
            reset_expiry_on_get=False,
        )

    def _clear_tried_cache(self):
        """Clear pdu_destination_tried cache"""
        now = self._clock.time_msec()

        old_dict = self.pdu_destination_tried
        self.pdu_destination_tried = {}

        for event_id, destination_dict in old_dict.items():
            destination_dict = {
                dest: time
                for dest, time in destination_dict.items()
                if time + PDU_RETRY_TIME_MS > now
            }
            if destination_dict:
                self.pdu_destination_tried[event_id] = destination_dict

    @log_function
    def make_query(self, destination, query_type, args,
                   retry_on_dns_fail=False, ignore_backoff=False):
        """Sends a federation Query to a remote homeserver of the given type
        and arguments.

        Args:
            destination (str): Domain name of the remote homeserver
            query_type (str): Category of the query type; should match the
                handler name used in register_query_handler().
            args (dict): Mapping of strings to strings containing the details
                of the query request.
            ignore_backoff (bool): true to ignore the historical backoff data
                and try the request anyway.

        Returns:
            a Deferred which will eventually yield a JSON object from the
            response
        """
        sent_queries_counter.labels(query_type).inc()

        return self.transport_layer.make_query(
            destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail,
            ignore_backoff=ignore_backoff,
        )

    @log_function
    def query_client_keys(self, destination, content, timeout):
        """Query device keys for a device hosted on a remote server.

        Args:
            destination (str): Domain name of the remote homeserver
            content (dict): The query content.

        Returns:
            a Deferred which will eventually yield a JSON object from the
            response
        """
        sent_queries_counter.labels("client_device_keys").inc()
        return self.transport_layer.query_client_keys(
            destination, content, timeout
        )

    @log_function
    def query_user_devices(self, destination, user_id, timeout=30000):
        """Query the device keys for a list of user ids hosted on a remote
        server.
        """
        sent_queries_counter.labels("user_devices").inc()
        return self.transport_layer.query_user_devices(
            destination, user_id, timeout
        )

    @log_function
    def claim_client_keys(self, destination, content, timeout):
        """Claims one-time keys for a device hosted on a remote server.

        Args:
            destination (str): Domain name of the remote homeserver
            content (dict): The query content.

        Returns:
            a Deferred which will eventually yield a JSON object from the
            response
        """
        sent_queries_counter.labels("client_one_time_keys").inc()
        return self.transport_layer.claim_client_keys(
            destination, content, timeout
        )

    @defer.inlineCallbacks
    @log_function
    def backfill(self, dest, context, limit, extremities):
        """Requests some more historic PDUs for the given context from the
        given destination server.

        Args:
            dest (str): The remote home server to ask.
            context (str): The context to backfill.
            limit (int): The maximum number of PDUs to return.
            extremities (list): List of PDU id and origins of the first pdus
                we have seen from the context

        Returns:
            Deferred: Results in the received PDUs.
        """
        logger.debug("backfill extrem=%s", extremities)

        # If there are no extremities then we've (probably) reached the start.
        if not extremities:
            return

        transaction_data = yield self.transport_layer.backfill(
            dest, context, extremities, limit)

        logger.debug("backfill transaction_data=%s", repr(transaction_data))

        pdus = [
            event_from_pdu_json(p, outlier=False)
            for p in transaction_data["pdus"]
        ]

        # FIXME: We should handle signature failures more gracefully.
        pdus[:] = yield logcontext.make_deferred_yieldable(defer.gatherResults(
            self._check_sigs_and_hashes(pdus),
            consumeErrors=True,
        ).addErrback(unwrapFirstError))

        defer.returnValue(pdus)

    @defer.inlineCallbacks
    @log_function
    def get_pdu(self, destinations, event_id, outlier=False, timeout=None):
        """Requests the PDU with given origin and ID from the remote home
        servers.

        Will attempt to get the PDU from each destination in the list until
        one succeeds.

        Args:
            destinations (list): Which home servers to query
            event_id (str): event to fetch
            outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
                it's from an arbitrary point in the context as opposed to part
                of the current block of PDUs. Defaults to `False`
            timeout (int): How long to try (in ms) each destination for before
                moving to the next destination. None indicates no timeout.

        Returns:
            Deferred: Results in the requested PDU.
        """

        # TODO: Rate limit the number of times we try and get the same event.

        ev = self._get_pdu_cache.get(event_id)
        if ev:
            defer.returnValue(ev)

        pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})

        signed_pdu = None
        for destination in destinations:
            now = self._clock.time_msec()
            last_attempt = pdu_attempts.get(destination, 0)
            if last_attempt + PDU_RETRY_TIME_MS > now:
                continue

            try:
                transaction_data = yield self.transport_layer.get_event(
                    destination, event_id, timeout=timeout,
                )

                logger.debug("transaction_data %r", transaction_data)

                pdu_list = [
                    event_from_pdu_json(p, outlier=outlier)
                    for p in transaction_data["pdus"]
                ]

                if pdu_list and pdu_list[0]:
                    pdu = pdu_list[0]

                    # Check signatures are correct.
                    signed_pdu = yield self._check_sigs_and_hash(pdu)

                    break

                pdu_attempts[destination] = now

            except SynapseError as e:
                logger.info(
                    "Failed to get PDU %s from %s because %s",
                    event_id, destination, e,
                )
            except NotRetryingDestination as e:
                logger.info(str(e))
                continue
            except FederationDeniedError as e:
                logger.info(str(e))
                continue
            except Exception as e:
                pdu_attempts[destination] = now

                logger.info(
                    "Failed to get PDU %s from %s because %s",
                    event_id, destination, e,
                )
                continue

        if signed_pdu:
            self._get_pdu_cache[event_id] = signed_pdu

        defer.returnValue(signed_pdu)

    @defer.inlineCallbacks
    @log_function
    def get_state_for_room(self, destination, room_id, event_id):
        """Requests all of the room state at a given event from a remote home server.

        Args:
            destination (str): The remote homeserver to query for the state.
            room_id (str): The id of the room we're interested in.
            event_id (str): The id of the event we want the state at.

        Returns:
            Deferred[Tuple[List[EventBase], List[EventBase]]]:
                A list of events in the state, and a list of events in the auth chain
                for the given event.
        """
        try:
            # First we try and ask for just the IDs, as that's far quicker if
            # we have most of the state and auth_chain already.
            # However, this may 404 if the other side has an old synapse.
            result = yield self.transport_layer.get_room_state_ids(
                destination, room_id, event_id=event_id,
            )

            state_event_ids = result["pdu_ids"]
            auth_event_ids = result.get("auth_chain_ids", [])

            fetched_events, failed_to_fetch = yield self.get_events(
                [destination], room_id, set(state_event_ids + auth_event_ids)
            )

            if failed_to_fetch:
                logger.warn("Failed to get %r", failed_to_fetch)

            event_map = {
                ev.event_id: ev for ev in fetched_events
            }

            pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
            auth_chain = [
                event_map[e_id] for e_id in auth_event_ids if e_id in event_map
            ]

            auth_chain.sort(key=lambda e: e.depth)

            defer.returnValue((pdus, auth_chain))
        except HttpResponseException as e:
            if e.code == 400 or e.code == 404:
                logger.info("Failed to use get_room_state_ids API, falling back")
            else:
                raise e

        result = yield self.transport_layer.get_room_state(
            destination, room_id, event_id=event_id,
        )

        pdus = [
            event_from_pdu_json(p, outlier=True) for p in result["pdus"]
        ]

        auth_chain = [
            event_from_pdu_json(p, outlier=True)
            for p in result.get("auth_chain", [])
        ]

        seen_events = yield self.store.get_events([
            ev.event_id for ev in itertools.chain(pdus, auth_chain)
        ])

        signed_pdus = yield self._check_sigs_and_hash_and_fetch(
            destination,
            [p for p in pdus if p.event_id not in seen_events],
            outlier=True
        )
        signed_pdus.extend(
            seen_events[p.event_id] for p in pdus if p.event_id in seen_events
        )

        signed_auth = yield self._check_sigs_and_hash_and_fetch(
            destination,
            [p for p in auth_chain if p.event_id not in seen_events],
            outlier=True
        )
        signed_auth.extend(
            seen_events[p.event_id] for p in auth_chain if p.event_id in seen_events
        )

        signed_auth.sort(key=lambda e: e.depth)

        defer.returnValue((signed_pdus, signed_auth))

    @defer.inlineCallbacks
    def get_events(self, destinations, room_id, event_ids, return_local=True):
        """Fetch events from some remote destinations, checking if we already
        have them.

        Args:
            destinations (list)
            room_id (str)
            event_ids (list)
            return_local (bool): Whether to include events we already have in
                the DB in the returned list of events

        Returns:
            Deferred: A deferred resolving to a 2-tuple where the first is a list of
            events and the second is a list of event ids that we failed to fetch.
        """
        if return_local:
            seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
            signed_events = list(seen_events.values())
        else:
            seen_events = yield self.store.have_seen_events(event_ids)
            signed_events = []

        failed_to_fetch = set()

        missing_events = set(event_ids)
        for k in seen_events:
            missing_events.discard(k)

        if not missing_events:
            defer.returnValue((signed_events, failed_to_fetch))

        def random_server_list():
            srvs = list(destinations)
            random.shuffle(srvs)
            return srvs

        batch_size = 20
        missing_events = list(missing_events)
        for i in range(0, len(missing_events), batch_size):
            batch = set(missing_events[i:i + batch_size])

            deferreds = [
                run_in_background(
                    self.get_pdu,
                    destinations=random_server_list(),
                    event_id=e_id,
                )
                for e_id in batch
            ]

            res = yield make_deferred_yieldable(
                defer.DeferredList(deferreds, consumeErrors=True)
            )
            for success, result in res:
                if success and result:
                    signed_events.append(result)
                    batch.discard(result.event_id)

            # We removed all events we successfully fetched from `batch`
            failed_to_fetch.update(batch)

        defer.returnValue((signed_events, failed_to_fetch))

    @defer.inlineCallbacks
    @log_function
    def get_event_auth(self, destination, room_id, event_id):
        res = yield self.transport_layer.get_event_auth(
            destination, room_id, event_id,
        )

        auth_chain = [
            event_from_pdu_json(p, outlier=True)
            for p in res["auth_chain"]
        ]

        signed_auth = yield self._check_sigs_and_hash_and_fetch(
            destination, auth_chain, outlier=True
        )

        signed_auth.sort(key=lambda e: e.depth)

        defer.returnValue(signed_auth)

    @defer.inlineCallbacks
    def _try_destination_list(self, description, destinations, callback):
        """Try an operation on a series of servers, until it succeeds

        Args:
            description (unicode): description of the operation we're doing, for logging

            destinations (Iterable[unicode]): list of server_names to try

            callback (callable):  Function to run for each server. Passed a single
                argument: the server_name to try. May return a deferred.

                If the callback raises a CodeMessageException with a 300/400 code,
                attempts to perform the operation stop immediately and the exception is
                reraised.

                Otherwise, if the callback raises an Exception the error is logged and the
                next server tried. Normally the stacktrace is logged but this is
                suppressed if the exception is an InvalidResponseError.

        Returns:
            The [Deferred] result of callback, if it succeeds

        Raises:
            SynapseError if the chosen remote server returns a 300/400 code.

            RuntimeError if no servers were reachable.
        """
        for destination in destinations:
            if destination == self.server_name:
                continue

            try:
                res = yield callback(destination)
                defer.returnValue(res)
            except InvalidResponseError as e:
                logger.warn(
                    "Failed to %s via %s: %s",
                    description, destination, e,
                )
            except HttpResponseException as e:
                if not 500 <= e.code < 600:
                    raise e.to_synapse_error()
                else:
                    logger.warn(
                        "Failed to %s via %s: %i %s",
                        description, destination, e.code, e.args[0],
                    )
            except Exception:
                logger.warn(
                    "Failed to %s via %s",
                    description, destination, exc_info=1,
                )

        raise RuntimeError("Failed to %s via any server" % (description, ))

    def make_membership_event(self, destinations, room_id, user_id, membership,
                              content, params):
        """
        Creates an m.room.member event, with context, without participating in the room.

        Does so by asking one of the already participating servers to create an
        event with proper context.

        Note that this does not append any events to any graphs.

        Args:
            destinations (Iterable[str]): Candidate homeservers which are probably
                participating in the room.
            room_id (str): The room in which the event will happen.
            user_id (str): The user whose membership is being evented.
            membership (str): The "membership" property of the event. Must be
                one of "join" or "leave".
            content (dict): Any additional data to put into the content field
                of the event.
            params (dict[str, str|Iterable[str]]): Query parameters to include in the
                request.
        Return:
            Deferred: resolves to a tuple of (origin (str), event (object))
            where origin is the remote homeserver which generated the event.

            Fails with a ``SynapseError`` if the chosen remote server
            returns a 300/400 code.

            Fails with a ``RuntimeError`` if no servers were reachable.
        """
        valid_memberships = {Membership.JOIN, Membership.LEAVE}
        if membership not in valid_memberships:
            raise RuntimeError(
                "make_membership_event called with membership='%s', must be one of %s" %
                (membership, ",".join(valid_memberships))
            )

        @defer.inlineCallbacks
        def send_request(destination):
            ret = yield self.transport_layer.make_membership_event(
                destination, room_id, user_id, membership, params,
            )

            pdu_dict = ret.get("event", None)
            if not isinstance(pdu_dict, dict):
                raise InvalidResponseError("Bad 'event' field in response")

            logger.debug("Got response to make_%s: %s", membership, pdu_dict)

            pdu_dict["content"].update(content)

            # The protoevent received over the JSON wire may not have all
            # the required fields. Let's just gloss over that because there
            # are some we never care about
            if "prev_state" not in pdu_dict:
                pdu_dict["prev_state"] = []

            ev = builder.EventBuilder(pdu_dict)

            defer.returnValue(
                (destination, ev)
            )

        return self._try_destination_list(
            "make_" + membership, destinations, send_request,
        )

    def send_join(self, destinations, pdu):
        """Sends a join event to one of a list of homeservers.

        Doing so will cause the remote server to add the event to the graph,
        and send the event out to the rest of the federation.

        Args:
            destinations (Iterable[str]): Candidate homeservers which are probably
                participating in the room.
            pdu (BaseEvent): event to be sent

        Return:
            Deferred: resolves to a dict with members ``origin`` (a string
            giving the server the event was sent to), ``state`` (?) and
            ``auth_chain``.

            Fails with a ``SynapseError`` if the chosen remote server
            returns a 300/400 code.

            Fails with a ``RuntimeError`` if no servers were reachable.
        """

        def check_authchain_validity(signed_auth_chain):
            for e in signed_auth_chain:
                if e.type == EventTypes.Create:
                    create_event = e
                    break
            else:
                raise InvalidResponseError(
                    "no %s in auth chain" % (EventTypes.Create,),
                )

            # the room version should be sane.
            room_version = create_event.content.get("room_version", "1")
            if room_version not in KNOWN_ROOM_VERSIONS:
                # This shouldn't be possible, because the remote server should have
                # rejected the join attempt during make_join.
                raise InvalidResponseError(
                    "room appears to have unsupported version %s" % (
                        room_version,
                    ))

        @defer.inlineCallbacks
        def send_request(destination):
            time_now = self._clock.time_msec()
            _, content = yield self.transport_layer.send_join(
                destination=destination,
                room_id=pdu.room_id,
                event_id=pdu.event_id,
                content=pdu.get_pdu_json(time_now),
            )

            logger.debug("Got content: %s", content)

            state = [
                event_from_pdu_json(p, outlier=True)
                for p in content.get("state", [])
            ]

            auth_chain = [
                event_from_pdu_json(p, outlier=True)
                for p in content.get("auth_chain", [])
            ]

            pdus = {
                p.event_id: p
                for p in itertools.chain(state, auth_chain)
            }

            valid_pdus = yield self._check_sigs_and_hash_and_fetch(
                destination, list(pdus.values()),
                outlier=True,
            )

            valid_pdus_map = {
                p.event_id: p
                for p in valid_pdus
            }

            # NB: We *need* to copy to ensure that we don't have multiple
            # references being passed on, as that causes... issues.
            signed_state = [
                copy.copy(valid_pdus_map[p.event_id])
                for p in state
                if p.event_id in valid_pdus_map
            ]

            signed_auth = [
                valid_pdus_map[p.event_id]
                for p in auth_chain
                if p.event_id in valid_pdus_map
            ]

            # NB: We *need* to copy to ensure that we don't have multiple
            # references being passed on, as that causes... issues.
            for s in signed_state:
                s.internal_metadata = copy.deepcopy(s.internal_metadata)

            check_authchain_validity(signed_auth)

            defer.returnValue({
                "state": signed_state,
                "auth_chain": signed_auth,
                "origin": destination,
            })
        return self._try_destination_list("send_join", destinations, send_request)

    @defer.inlineCallbacks
    def send_invite(self, destination, room_id, event_id, pdu):
        time_now = self._clock.time_msec()
        try:
            code, content = yield self.transport_layer.send_invite(
                destination=destination,
                room_id=room_id,
                event_id=event_id,
                content=pdu.get_pdu_json(time_now),
            )
        except HttpResponseException as e:
            if e.code == 403:
                raise e.to_synapse_error()
            raise

        pdu_dict = content["event"]

        logger.debug("Got response to send_invite: %s", pdu_dict)

        pdu = event_from_pdu_json(pdu_dict)

        # Check signatures are correct.
        pdu = yield self._check_sigs_and_hash(pdu)

        # FIXME: We should handle signature failures more gracefully.

        defer.returnValue(pdu)

    def send_leave(self, destinations, pdu):
        """Sends a leave event to one of a list of homeservers.

        Doing so will cause the remote server to add the event to the graph,
        and send the event out to the rest of the federation.

        This is mostly useful to reject received invites.

        Args:
            destinations (Iterable[str]): Candidate homeservers which are probably
                participating in the room.
            pdu (BaseEvent): event to be sent

        Return:
            Deferred: resolves to None.

            Fails with a ``SynapseError`` if the chosen remote server
            returns a 300/400 code.

            Fails with a ``RuntimeError`` if no servers were reachable.
        """
        @defer.inlineCallbacks
        def send_request(destination):
            time_now = self._clock.time_msec()
            _, content = yield self.transport_layer.send_leave(
                destination=destination,
                room_id=pdu.room_id,
                event_id=pdu.event_id,
                content=pdu.get_pdu_json(time_now),
            )

            logger.debug("Got content: %s", content)
            defer.returnValue(None)

        return self._try_destination_list("send_leave", destinations, send_request)

    def get_public_rooms(self, destination, limit=None, since_token=None,
                         search_filter=None, include_all_networks=False,
                         third_party_instance_id=None):
        if destination == self.server_name:
            return

        return self.transport_layer.get_public_rooms(
            destination, limit, since_token, search_filter,
            include_all_networks=include_all_networks,
            third_party_instance_id=third_party_instance_id,
        )

    @defer.inlineCallbacks
    def query_auth(self, destination, room_id, event_id, local_auth):
        """
        Params:
            destination (str)
            room_id (str)
            event_id (str)
            local_auth (list)
        """
        time_now = self._clock.time_msec()

        send_content = {
            "auth_chain": [e.get_pdu_json(time_now) for e in local_auth],
        }

        code, content = yield self.transport_layer.send_query_auth(
            destination=destination,
            room_id=room_id,
            event_id=event_id,
            content=send_content,
        )

        auth_chain = [
            event_from_pdu_json(e)
            for e in content["auth_chain"]
        ]

        signed_auth = yield self._check_sigs_and_hash_and_fetch(
            destination, auth_chain, outlier=True
        )

        signed_auth.sort(key=lambda e: e.depth)

        ret = {
            "auth_chain": signed_auth,
            "rejects": content.get("rejects", []),
            "missing": content.get("missing", []),
        }

        defer.returnValue(ret)

    @defer.inlineCallbacks
    def get_missing_events(self, destination, room_id, earliest_events_ids,
                           latest_events, limit, min_depth, timeout):
        """Tries to fetch events we are missing. This is called when we receive
        an event without having received all of its ancestors.

        Args:
            destination (str)
            room_id (str)
            earliest_events_ids (list): List of event ids. Effectively the
                events we expected to receive, but haven't. `get_missing_events`
                should only return events that didn't happen before these.
            latest_events (list): List of events we have received that we don't
                have all previous events for.
            limit (int): Maximum number of events to return.
            min_depth (int): Minimum depth of events to return.
            timeout (int): Max time to wait in ms
        """
        try:
            content = yield self.transport_layer.get_missing_events(
                destination=destination,
                room_id=room_id,
                earliest_events=earliest_events_ids,
                latest_events=[e.event_id for e in latest_events],
                limit=limit,
                min_depth=min_depth,
                timeout=timeout,
            )

            events = [
                event_from_pdu_json(e)
                for e in content.get("events", [])
            ]

            signed_events = yield self._check_sigs_and_hash_and_fetch(
                destination, events, outlier=False
            )
        except HttpResponseException as e:
            if not e.code == 400:
                raise

            # We are probably hitting an old server that doesn't support
            # get_missing_events
            signed_events = []

        defer.returnValue(signed_events)

    @defer.inlineCallbacks
    def forward_third_party_invite(self, destinations, room_id, event_dict):
        for destination in destinations:
            if destination == self.server_name:
                continue

            try:
                yield self.transport_layer.exchange_third_party_invite(
                    destination=destination,
                    room_id=room_id,
                    event_dict=event_dict,
                )
                defer.returnValue(None)
            except CodeMessageException:
                raise
            except Exception as e:
                logger.exception(
                    "Failed to send_third_party_invite via %s: %s",
                    destination, str(e)
                )

        raise RuntimeError("Failed to send to any server.")
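
FederationClient._clear_tried_cache above prunes the per-event map of destinations recently tried for a PDU, dropping attempts older than PDU_RETRY_TIME_MS and dropping an event entirely once no fresh attempts remain. Below is a minimal standalone sketch of the same pruning rule; the PDU_RETRY_TIME_MS value and helper name are assumptions for illustration only.

PDU_RETRY_TIME_MS = 30 * 1000  # assumed value, for the sketch only

def clear_tried_cache(pdu_destination_tried, now_ms, retry_time_ms=PDU_RETRY_TIME_MS):
    cleaned = {}
    for event_id, destination_dict in pdu_destination_tried.items():
        # Keep only attempts that are still inside the retry window.
        fresh = {
            dest: ts
            for dest, ts in destination_dict.items()
            if ts + retry_time_ms > now_ms
        }
        if fresh:
            cleaned[event_id] = fresh
    return cleaned

# The stale-only event is dropped entirely; only the recent attempt survives.
cache = {"$stale": {"a.example": 0}, "$fresh": {"a.example": 0, "b.example": 95000}}
assert clear_tried_cache(cache, now_ms=100000) == {"$fresh": {"b.example": 95000}}
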
Example #49
0
class DeviceInboxWorkerStore(SQLBaseStore):
    def __init__(self, database: DatabasePool, db_conn, hs):
        super().__init__(database, db_conn, hs)

        self._instance_name = hs.get_instance_name()

        # Map of (user_id, device_id) to the last stream_id that has been
        # deleted up to. This is so that we can no-op deletions.
        self._last_device_delete_cache = ExpiringCache(
            cache_name="last_device_delete_cache",
            clock=self._clock,
            max_len=10000,
            expiry_ms=30 * 60 * 1000,
        )

        if isinstance(database.engine, PostgresEngine):
            self._can_write_to_device = (self._instance_name
                                         in hs.config.worker.writers.to_device)

            self._device_inbox_id_gen = MultiWriterIdGenerator(
                db_conn=db_conn,
                db=database,
                stream_name="to_device",
                instance_name=self._instance_name,
                tables=[("device_inbox", "instance_name", "stream_id")],
                sequence_name="device_inbox_sequence",
                writers=hs.config.worker.writers.to_device,
            )
        else:
            self._can_write_to_device = True
            self._device_inbox_id_gen = StreamIdGenerator(
                db_conn, "device_inbox", "stream_id")

        max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
        device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
            db_conn,
            "device_inbox",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=max_device_inbox_id,
            limit=1000,
        )
        self._device_inbox_stream_cache = StreamChangeCache(
            "DeviceInboxStreamChangeCache",
            min_device_inbox_id,
            prefilled_cache=device_inbox_prefill,
        )

        # The federation outbox and the local device inbox use the same
        # stream_id generator.
        device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict(
            db_conn,
            "device_federation_outbox",
            entity_column="destination",
            stream_column="stream_id",
            max_value=max_device_inbox_id,
            limit=1000,
        )
        self._device_federation_outbox_stream_cache = StreamChangeCache(
            "DeviceFederationOutboxStreamChangeCache",
            min_device_outbox_id,
            prefilled_cache=device_outbox_prefill,
        )

    def process_replication_rows(self, stream_name, instance_name, token,
                                 rows):
        if stream_name == ToDeviceStream.NAME:
            self._device_inbox_id_gen.advance(instance_name, token)
            for row in rows:
                if row.entity.startswith("@"):
                    self._device_inbox_stream_cache.entity_has_changed(
                        row.entity, token)
                else:
                    self._device_federation_outbox_stream_cache.entity_has_changed(
                        row.entity, token)
        return super().process_replication_rows(stream_name, instance_name,
                                                token, rows)

    def get_to_device_stream_token(self):
        return self._device_inbox_id_gen.get_current_token()

    async def get_new_messages_for_device(
        self,
        user_id: str,
        device_id: str,
        last_stream_id: int,
        current_stream_id: int,
        limit: int = 100,
    ) -> Tuple[List[dict], int]:
        """
        Args:
            user_id: The recipient user_id.
            device_id: The recipient device_id.
            last_stream_id: The last stream ID checked.
            current_stream_id: The current position of the to device
                message stream.
            limit: The maximum number of messages to retrieve.

        Returns:
            A list of messages for the device and where in the stream the messages got to.
        """
        has_changed = self._device_inbox_stream_cache.has_entity_changed(
            user_id, last_stream_id)
        if not has_changed:
            return ([], current_stream_id)

        def get_new_messages_for_device_txn(txn):
            sql = ("SELECT stream_id, message_json FROM device_inbox"
                   " WHERE user_id = ? AND device_id = ?"
                   " AND ? < stream_id AND stream_id <= ?"
                   " ORDER BY stream_id ASC"
                   " LIMIT ?")
            txn.execute(
                sql,
                (user_id, device_id, last_stream_id, current_stream_id, limit))
            messages = []
            for row in txn:
                stream_pos = row[0]
                messages.append(db_to_json(row[1]))
            if len(messages) < limit:
                stream_pos = current_stream_id
            return messages, stream_pos

        return await self.db_pool.runInteraction(
            "get_new_messages_for_device", get_new_messages_for_device_txn)

    @trace
    async def delete_messages_for_device(self, user_id: str, device_id: str,
                                         up_to_stream_id: int) -> int:
        """
        Args:
            user_id: The recipient user_id.
            device_id: The recipient device_id.
            up_to_stream_id: Where to delete messages up to.

        Returns:
            The number of messages deleted.
        """
        # If we have cached the last stream id we've deleted up to, we can
        # check if there is likely to be anything that needs deleting
        last_deleted_stream_id = self._last_device_delete_cache.get(
            (user_id, device_id), None)

        set_tag("last_deleted_stream_id", last_deleted_stream_id)

        if last_deleted_stream_id:
            has_changed = self._device_inbox_stream_cache.has_entity_changed(
                user_id, last_deleted_stream_id)
            if not has_changed:
                log_kv({"message": "No changes in cache since last check"})
                return 0

        def delete_messages_for_device_txn(txn):
            sql = ("DELETE FROM device_inbox"
                   " WHERE user_id = ? AND device_id = ?"
                   " AND stream_id <= ?")
            txn.execute(sql, (user_id, device_id, up_to_stream_id))
            return txn.rowcount

        count = await self.db_pool.runInteraction(
            "delete_messages_for_device", delete_messages_for_device_txn)

        log_kv({
            "message": "deleted {} messages for device".format(count),
            "count": count
        })

        # Update the cache, ensuring that we only ever increase the value
        last_deleted_stream_id = self._last_device_delete_cache.get(
            (user_id, device_id), 0)
        self._last_device_delete_cache[(user_id, device_id)] = max(
            last_deleted_stream_id, up_to_stream_id)

        return count

    @trace
    async def get_new_device_msgs_for_remote(self, destination, last_stream_id,
                                             current_stream_id,
                                             limit) -> Tuple[List[dict], int]:
        """
        Args:
            destination(str): The name of the remote server.
            last_stream_id(int|long): The last position of the device message stream
                that the server sent up to.
            current_stream_id(int|long): The current position of the device
                message stream.
        Returns:
            A list of messages for the destination and where in the stream the messages got to.
        """

        set_tag("destination", destination)
        set_tag("last_stream_id", last_stream_id)
        set_tag("current_stream_id", current_stream_id)
        set_tag("limit", limit)

        has_changed = self._device_federation_outbox_stream_cache.has_entity_changed(
            destination, last_stream_id)
        if not has_changed or last_stream_id == current_stream_id:
            log_kv({"message": "No new messages in stream"})
            return ([], current_stream_id)

        if limit <= 0:
            # This can happen if we run out of room for EDUs in the transaction.
            return ([], last_stream_id)

        @trace
        def get_new_messages_for_remote_destination_txn(txn):
            sql = (
                "SELECT stream_id, messages_json FROM device_federation_outbox"
                " WHERE destination = ?"
                " AND ? < stream_id AND stream_id <= ?"
                " ORDER BY stream_id ASC"
                " LIMIT ?")
            txn.execute(
                sql, (destination, last_stream_id, current_stream_id, limit))
            messages = []
            for row in txn:
                stream_pos = row[0]
                messages.append(db_to_json(row[1]))
            if len(messages) < limit:
                log_kv({"message": "Set stream position to current position"})
                stream_pos = current_stream_id
            return messages, stream_pos

        return await self.db_pool.runInteraction(
            "get_new_device_msgs_for_remote",
            get_new_messages_for_remote_destination_txn,
        )

    @trace
    async def delete_device_msgs_for_remote(self, destination: str,
                                            up_to_stream_id: int) -> None:
        """Used to delete messages when the remote destination acknowledges
        their receipt.

        Args:
            destination: The destination server_name
            up_to_stream_id: Where to delete messages up to.
        """
        def delete_messages_for_remote_destination_txn(txn):
            sql = ("DELETE FROM device_federation_outbox"
                   " WHERE destination = ?"
                   " AND stream_id <= ?")
            txn.execute(sql, (destination, up_to_stream_id))

        await self.db_pool.runInteraction(
            "delete_device_msgs_for_remote",
            delete_messages_for_remote_destination_txn)

    async def get_all_new_device_messages(
            self, instance_name: str, last_id: int, current_id: int,
            limit: int) -> Tuple[List[Tuple[int, tuple]], int, bool]:
        """Get updates for to device replication stream.

        Args:
            instance_name: The writer we want to fetch updates from. Unused
                here since there is only ever one writer.
            last_id: The token to fetch updates from. Exclusive.
            current_id: The token to fetch updates up to. Inclusive.
            limit: The requested limit for the number of rows to return. The
                function may return more or fewer rows.

        Returns:
            A tuple consisting of: the updates, a token to use to fetch
            subsequent updates, and whether we returned fewer rows than exist
            between the requested tokens due to the limit.

            The token returned can be used in a subsequent call to this
            function to get further updates.

            The updates are a list of 2-tuples of stream ID and the row data
        """

        if last_id == current_id:
            return [], current_id, False

        def get_all_new_device_messages_txn(txn):
            # We limit like this as we might have multiple rows per stream_id, and
            # we want to make sure we always get all entries for any stream_id
            # we return.
            upper_pos = min(current_id, last_id + limit)
            sql = ("SELECT max(stream_id), user_id"
                   " FROM device_inbox"
                   " WHERE ? < stream_id AND stream_id <= ?"
                   " GROUP BY user_id")
            txn.execute(sql, (last_id, upper_pos))
            updates = [(row[0], row[1:]) for row in txn]

            sql = ("SELECT max(stream_id), destination"
                   " FROM device_federation_outbox"
                   " WHERE ? < stream_id AND stream_id <= ?"
                   " GROUP BY destination")
            txn.execute(sql, (last_id, upper_pos))
            updates.extend((row[0], row[1:]) for row in txn)

            # Order by ascending stream ordering
            updates.sort()

            limited = False
            upto_token = current_id
            if len(updates) >= limit:
                upto_token = updates[-1][0]
                limited = True

            return updates, upto_token, limited

        return await self.db_pool.runInteraction(
            "get_all_new_device_messages", get_all_new_device_messages_txn)

    @trace
    async def add_messages_to_device_inbox(
        self,
        local_messages_by_user_then_device: dict,
        remote_messages_by_destination: dict,
    ) -> int:
        """Used to send messages from this server.

        Args:
            local_messages_by_user_then_device:
                Dictionary of user_id to device_id to message.
            remote_messages_by_destination:
                Dictionary of destination server_name to the EDU JSON to send.

        Returns:
            The new stream_id.
        """

        assert self._can_write_to_device

        def add_messages_txn(txn, now_ms, stream_id):
            # Add the local messages directly to the local inbox.
            self._add_messages_to_local_device_inbox_txn(
                txn, stream_id, local_messages_by_user_then_device)

            # Add the remote messages to the federation outbox.
            # We'll send them to a remote server when we next send a
            # federation transaction to that destination.
            self.db_pool.simple_insert_many_txn(
                txn,
                table="device_federation_outbox",
                values=[{
                    "destination": destination,
                    "stream_id": stream_id,
                    "queued_ts": now_ms,
                    "messages_json": json_encoder.encode(edu),
                    "instance_name": self._instance_name,
                } for destination, edu in
                        remote_messages_by_destination.items()],
            )

        async with self._device_inbox_id_gen.get_next() as stream_id:
            now_ms = self.clock.time_msec()
            await self.db_pool.runInteraction("add_messages_to_device_inbox",
                                              add_messages_txn, now_ms,
                                              stream_id)
            for user_id in local_messages_by_user_then_device.keys():
                self._device_inbox_stream_cache.entity_has_changed(
                    user_id, stream_id)
            for destination in remote_messages_by_destination.keys():
                self._device_federation_outbox_stream_cache.entity_has_changed(
                    destination, stream_id)

        return self._device_inbox_id_gen.get_current_token()

    async def add_messages_from_remote_to_device_inbox(
            self, origin: str, message_id: str,
            local_messages_by_user_then_device: dict) -> int:
        assert self._can_write_to_device

        def add_messages_txn(txn, now_ms, stream_id):
            # Check if we've already inserted a matching message_id for that
            # origin. This can happen if the origin doesn't receive our
            # acknowledgement from the first time we received the message.
            already_inserted = self.db_pool.simple_select_one_txn(
                txn,
                table="device_federation_inbox",
                keyvalues={
                    "origin": origin,
                    "message_id": message_id
                },
                retcols=("message_id", ),
                allow_none=True,
            )
            if already_inserted is not None:
                return

            # Add an entry for this message_id so that we know we've processed
            # it.
            self.db_pool.simple_insert_txn(
                txn,
                table="device_federation_inbox",
                values={
                    "origin": origin,
                    "message_id": message_id,
                    "received_ts": now_ms,
                },
            )

            # Add the messages to the appropriate local device inboxes so that
            # they'll be sent to the devices when they next sync.
            self._add_messages_to_local_device_inbox_txn(
                txn, stream_id, local_messages_by_user_then_device)

        async with self._device_inbox_id_gen.get_next() as stream_id:
            now_ms = self.clock.time_msec()
            await self.db_pool.runInteraction(
                "add_messages_from_remote_to_device_inbox",
                add_messages_txn,
                now_ms,
                stream_id,
            )
            for user_id in local_messages_by_user_then_device.keys():
                self._device_inbox_stream_cache.entity_has_changed(
                    user_id, stream_id)

        return stream_id
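
# --- Standalone sketch (hypothetical) ----------------------------------------
# The (origin, message_id) check against device_federation_inbox above makes
# the handler idempotent: re-delivery of the same EDU is acknowledged but not
# processed twice. The same idea with an in-memory set:
_processed = set()

def handle_remote_to_device(origin, message_id, deliver):
    key = (origin, message_id)
    if key in _processed:
        return "duplicate"   # already handled; nothing to insert again
    _processed.add(key)
    deliver()                # e.g. write the messages into local inboxes
    return "delivered"

print(handle_remote_to_device("remote.example", "m1", lambda: None))  # delivered
print(handle_remote_to_device("remote.example", "m1", lambda: None))  # duplicate
# -----------------------------------------------------------------------------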

    def _add_messages_to_local_device_inbox_txn(self, txn, stream_id,
                                                messages_by_user_then_device):
        assert self._can_write_to_device

        local_by_user_then_device = {}
        for user_id, messages_by_device in messages_by_user_then_device.items(
        ):
            messages_json_for_user = {}
            devices = list(messages_by_device.keys())
            if len(devices) == 1 and devices[0] == "*":
                # Handle wildcard device_ids.
                devices = self.db_pool.simple_select_onecol_txn(
                    txn,
                    table="devices",
                    keyvalues={"user_id": user_id},
                    retcol="device_id",
                )

                message_json = json_encoder.encode(messages_by_device["*"])
                for device_id in devices:
                    # Add the message for all devices for this user on this
                    # server.
                    messages_json_for_user[device_id] = message_json
            else:
                if not devices:
                    continue

                rows = self.db_pool.simple_select_many_txn(
                    txn,
                    table="devices",
                    keyvalues={"user_id": user_id},
                    column="device_id",
                    iterable=devices,
                    retcols=("device_id", ),
                )

                for row in rows:
                    # Only insert into the local inbox if the device exists on
                    # this server
                    device_id = row["device_id"]
                    message_json = json_encoder.encode(
                        messages_by_device[device_id])
                    messages_json_for_user[device_id] = message_json

            if messages_json_for_user:
                local_by_user_then_device[user_id] = messages_json_for_user

        if not local_by_user_then_device:
            return

        self.db_pool.simple_insert_many_txn(
            txn,
            table="device_inbox",
            values=[{
                "user_id": user_id,
                "device_id": device_id,
                "stream_id": stream_id,
                "message_json": message_json,
                "instance_name": self._instance_name,
            } for user_id, messages_by_device in
                    local_by_user_then_device.items()
                    for device_id, message_json in messages_by_device.items()],
        )
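
# --- Standalone sketch (hypothetical) ----------------------------------------
# The wildcard handling above in miniature: a device_id of "*" fans a single
# message out to every known device for the user, otherwise only devices that
# actually exist on this server are kept. known_devices stands in for the
# devices table.
import json

def expand_messages(messages_by_device, known_devices):
    if list(messages_by_device) == ["*"]:
        payload = json.dumps(messages_by_device["*"])
        return {device_id: payload for device_id in known_devices}
    return {
        device_id: json.dumps(message)
        for device_id, message in messages_by_device.items()
        if device_id in known_devices
    }

print(expand_messages({"*": {"body": "hi"}}, ["D1", "D2"]))
print(expand_messages({"D1": {"body": "hi"}, "GONE": {}}, ["D1", "D2"]))
# -----------------------------------------------------------------------------
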
Example #50
class TransactionStore(SQLBaseStore):
    """A collection of queries for handling PDUs.
    """
    def __init__(self, db_conn, hs):
        super(TransactionStore, self).__init__(db_conn, hs)

        self._clock.looping_call(self._start_cleanup_transactions,
                                 30 * 60 * 1000)

        self._destination_retry_cache = ExpiringCache(
            cache_name="get_destination_retry_timings",
            clock=self._clock,
            expiry_ms=5 * 60 * 1000,
        )

    def get_received_txn_response(self, transaction_id, origin):
        """For an incoming transaction from a given origin, check if we have
        already responded to it. If so, return the response code and response
        body (as a dict).

        Args:
            transaction_id (str)
            origin (str)

        Returns:
            None if we have not previously responded to this transaction,
            otherwise a 2-tuple of (response_code (int), response_body (dict)).
        """

        return self.runInteraction("get_received_txn_response",
                                   self._get_received_txn_response,
                                   transaction_id, origin)

    def _get_received_txn_response(self, txn, transaction_id, origin):
        result = self._simple_select_one_txn(
            txn,
            table="received_transactions",
            keyvalues={
                "transaction_id": transaction_id,
                "origin": origin,
            },
            retcols=(
                "transaction_id",
                "origin",
                "ts",
                "response_code",
                "response_json",
                "has_been_referenced",
            ),
            allow_none=True,
        )

        if result and result["response_code"]:
            return result["response_code"], db_to_json(result["response_json"])

        else:
            return None

    def set_received_txn_response(self, transaction_id, origin, code,
                                  response_dict):
        """Persist the response we returened for an incoming transaction, and
        should return for subsequent transactions with the same transaction_id
        and origin.

        Args:
            txn
            transaction_id (str)
            origin (str)
            code (int)
            response_json (str)
        """

        return self._simple_insert(
            table="received_transactions",
            values={
                "transaction_id": transaction_id,
                "origin": origin,
                "response_code": code,
                "response_json": db_binary_type(
                    encode_canonical_json(response_dict)
                ),
                "ts": self._clock.time_msec(),
            },
            or_ignore=True,
            desc="set_received_txn_response",
        )

    def prep_send_transaction(self, transaction_id, destination,
                              origin_server_ts):
        """Persists an outgoing transaction and calculates the values for the
        previous transaction id list.

        This should be called before sending the transaction so that it has the
        correct value for the `prev_ids` key.

        Args:
            transaction_id (str)
            destination (str)
            origin_server_ts (int)

        Returns:
            list: A list of previous transaction ids.
        """
        return defer.succeed([])

    def delivered_txn(self, transaction_id, destination, code, response_dict):
        """Persists the response for an outgoing transaction.

        Args:
            transaction_id (str)
            destination (str)
            code (int)
            response_json (str)
        """
        pass

    @defer.inlineCallbacks
    def get_destination_retry_timings(self, destination):
        """Gets the current retry timings (if any) for a given destination.

        Args:
            destination (str)

        Returns:
            None if not retrying
            Otherwise a dict for the retry scheme
        """

        result = self._destination_retry_cache.get(destination, SENTINEL)
        if result is not SENTINEL:
            defer.returnValue(result)

        result = yield self.runInteraction("get_destination_retry_timings",
                                           self._get_destination_retry_timings,
                                           destination)

        # We don't hugely care about race conditions between getting and
        # invalidating the cache, since we time out fairly quickly anyway.
        self._destination_retry_cache[destination] = result
        defer.returnValue(result)

    def _get_destination_retry_timings(self, txn, destination):
        result = self._simple_select_one_txn(
            txn,
            table="destinations",
            keyvalues={
                "destination": destination,
            },
            retcols=("destination", "retry_last_ts", "retry_interval"),
            allow_none=True,
        )

        if result and result["retry_last_ts"] > 0:
            return result
        else:
            return None
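
# --- Standalone sketch (hypothetical) ----------------------------------------
# The cache-aside pattern used by get_destination_retry_timings: a sentinel
# distinguishes "not cached" from a cached None, so even "no retry scheduled"
# results avoid repeated database hits. db_fetch stands in for the query.
SENTINEL = object()
_retry_cache = {}

def lookup_retry_timings(destination, db_fetch):
    result = _retry_cache.get(destination, SENTINEL)
    if result is not SENTINEL:
        return result
    result = db_fetch(destination)        # may legitimately be None
    _retry_cache[destination] = result
    return result

calls = []
fetch = lambda d: calls.append(d) or None  # records each DB hit, returns None
print(lookup_retry_timings("remote.example", fetch))  # None (one DB hit)
print(lookup_retry_timings("remote.example", fetch))  # None (served from cache)
print(len(calls))  # 1
# -----------------------------------------------------------------------------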

    def set_destination_retry_timings(self, destination, retry_last_ts,
                                      retry_interval):
        """Sets the current retry timings for a given destination.
        Both timings should be zero if retrying is no longer occurring.

        Args:
            destination (str)
            retry_last_ts (int) - time of last retry attempt in unix epoch ms
            retry_interval (int) - how long until next retry in ms
        """

        self._destination_retry_cache.pop(destination, None)
        return self.runInteraction(
            "set_destination_retry_timings",
            self._set_destination_retry_timings,
            destination,
            retry_last_ts,
            retry_interval,
        )

    def _set_destination_retry_timings(self, txn, destination, retry_last_ts,
                                       retry_interval):
        self.database_engine.lock_table(txn, "destinations")

        # We need to be careful here as the data may have changed from under us
        # due to a worker setting the timings.

        prev_row = self._simple_select_one_txn(
            txn,
            table="destinations",
            keyvalues={
                "destination": destination,
            },
            retcols=("retry_last_ts", "retry_interval"),
            allow_none=True,
        )

        if not prev_row:
            self._simple_insert_txn(txn,
                                    table="destinations",
                                    values={
                                        "destination": destination,
                                        "retry_last_ts": retry_last_ts,
                                        "retry_interval": retry_interval,
                                    })
        elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval:
            self._simple_update_one_txn(
                txn,
                "destinations",
                keyvalues={
                    "destination": destination,
                },
                updatevalues={
                    "retry_last_ts": retry_last_ts,
                    "retry_interval": retry_interval,
                },
            )

    def get_destinations_needing_retry(self):
        """Get all destinations which are due a retry for sending a transaction.

        Returns:
            list: A list of dicts
        """

        return self.runInteraction("get_destinations_needing_retry",
                                   self._get_destinations_needing_retry)

    def _get_destinations_needing_retry(self, txn):
        query = ("SELECT * FROM destinations"
                 " WHERE retry_last_ts > 0 and retry_next_ts < ?")

        txn.execute(query, (self._clock.time_msec(), ))
        return self.cursor_to_dict(txn)

    def _start_cleanup_transactions(self):
        return run_as_background_process(
            "cleanup_transactions",
            self._cleanup_transactions,
        )

    def _cleanup_transactions(self):
        now = self._clock.time_msec()
        month_ago = now - 30 * 24 * 60 * 60 * 1000

        def _cleanup_transactions_txn(txn):
            txn.execute("DELETE FROM received_transactions WHERE ts < ?",
                        (month_ago, ))

        return self.runInteraction("_cleanup_transactions",
                                   _cleanup_transactions_txn)
Example #51
class DeviceListEduUpdater(object):
    "Handles incoming device list updates from federation and updates the DB"

    def __init__(self, hs, device_handler):
        self.store = hs.get_datastore()
        self.federation = hs.get_federation_client()
        self.clock = hs.get_clock()
        self.device_handler = device_handler

        self._remote_edu_linearizer = Linearizer(name="remote_device_list")

        # user_id -> list of updates waiting to be handled.
        self._pending_updates = {}

        # Recently seen stream ids. We don't bother keeping these in the DB,
        # but they're useful to have them about to reduce the number of spurious
        # resyncs.
        self._seen_updates = ExpiringCache(
            cache_name="device_update_edu",
            clock=self.clock,
            max_len=10000,
            expiry_ms=30 * 60 * 1000,
            iterable=True,
        )

    @defer.inlineCallbacks
    def incoming_device_list_update(self, origin, edu_content):
        """Called on incoming device list update from federation. Responsible
        for parsing the EDU and adding to pending updates list.
        """

        user_id = edu_content.pop("user_id")
        device_id = edu_content.pop("device_id")
        stream_id = str(edu_content.pop("stream_id"))  # They may come as ints
        prev_ids = edu_content.pop("prev_id", [])
        prev_ids = [str(p) for p in prev_ids]  # They may come as ints

        if get_domain_from_id(user_id) != origin:
            # TODO: Raise?
            logger.warning("Got device list update edu for %r from %r",
                           user_id, origin)
            return

        room_ids = yield self.store.get_rooms_for_user(user_id)
        if not room_ids:
            # We don't share any rooms with this user. Ignore update, as we
            # probably won't get any further updates.
            return

        self._pending_updates.setdefault(user_id, []).append(
            (device_id, stream_id, prev_ids, edu_content))

        yield self._handle_device_updates(user_id)

    @measure_func("_incoming_device_list_update")
    @defer.inlineCallbacks
    def _handle_device_updates(self, user_id):
        "Actually handle pending updates."

        with (yield self._remote_edu_linearizer.queue(user_id)):
            pending_updates = self._pending_updates.pop(user_id, [])
            if not pending_updates:
                # This can happen since we batch updates
                return

            # Given a list of updates we check if we need to resync. This
            # happens if we've missed updates.
            resync = yield self._need_to_do_resync(user_id, pending_updates)

            if resync:
                # Fetch all devices for the user.
                origin = get_domain_from_id(user_id)
                try:
                    result = yield self.federation.query_user_devices(
                        origin, user_id)
                except NotRetryingDestination:
                    # TODO: Remember that we are now out of sync and try again
                    # later
                    logger.warning(
                        "Failed to handle device list update for %s,"
                        " we're not retrying the remote",
                        user_id,
                    )
                    # We abort on exceptions rather than accepting the update
                    # as otherwise synapse will 'forget' that its device list
                    # is out of date. If we bail then we will retry the resync
                    # next time we get a device list update for this user_id.
                    # This makes it more likely that the device lists will
                    # eventually become consistent.
                    return
                except FederationDeniedError as e:
                    logger.info(e)
                    return
                except Exception:
                    # TODO: Remember that we are now out of sync and try again
                    # later
                    logger.exception(
                        "Failed to handle device list update for %s", user_id)
                    return

                stream_id = result["stream_id"]
                devices = result["devices"]
                yield self.store.update_remote_device_list_cache(
                    user_id,
                    devices,
                    stream_id,
                )
                device_ids = [device["device_id"] for device in devices]
                yield self.device_handler.notify_device_update(
                    user_id, device_ids)
            else:
                # Simply update the single device, since we know that is the only
                # change (because of the single prev_id matching the current cache)
                for device_id, stream_id, prev_ids, content in pending_updates:
                    yield self.store.update_remote_device_list_cache_entry(
                        user_id,
                        device_id,
                        content,
                        stream_id,
                    )

                yield self.device_handler.notify_device_update(
                    user_id,
                    [device_id for device_id, _, _, _ in pending_updates])

            self._seen_updates.setdefault(user_id, set()).update(
                stream_id for _, stream_id, _, _ in pending_updates)

    @defer.inlineCallbacks
    def _need_to_do_resync(self, user_id, updates):
        """Given a list of updates for a user figure out if we need to do a full
        resync, or whether we have enough data that we can just apply the delta.
        """
        seen_updates = self._seen_updates.get(user_id, set())

        extremity = yield self.store.get_device_list_last_stream_id_for_remote(
            user_id)

        stream_id_in_updates = set()  # stream_ids in updates list
        for _, stream_id, prev_ids, _ in updates:
            if not prev_ids:
                # We always do a resync if there are no previous IDs
                defer.returnValue(True)

            for prev_id in prev_ids:
                if prev_id == extremity:
                    continue
                elif prev_id in seen_updates:
                    continue
                elif prev_id in stream_id_in_updates:
                    continue
                else:
                    defer.returnValue(True)

            stream_id_in_updates.add(stream_id)

        defer.returnValue(False)
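
# --- Standalone sketch (hypothetical) ----------------------------------------
# The resync decision in _need_to_do_resync as a pure function: a delta can
# only be applied if every prev_id of every pending update is either the
# stream id we already hold (the extremity), something recently seen, or
# another update in the same batch; otherwise we must resync.
def need_resync(extremity, seen_updates, updates):
    stream_ids_in_batch = set()
    for _device_id, stream_id, prev_ids, _content in updates:
        if not prev_ids:
            return True
        for prev_id in prev_ids:
            if (prev_id == extremity or prev_id in seen_updates
                    or prev_id in stream_ids_in_batch):
                continue
            return True
        stream_ids_in_batch.add(stream_id)
    return False

updates = [("DEV", "2", ["1"], {}), ("DEV", "3", ["2"], {})]
print(need_resync("1", set(), updates))      # False: chain 1 -> 2 -> 3 is intact
print(need_resync("1", set(), updates[1:]))  # True: prev_id "2" is unknown
# -----------------------------------------------------------------------------
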
Example #52
class DeviceListUpdater(object):
    "Handles incoming device list updates from federation and updates the DB"

    def __init__(self, hs, device_handler):
        self.store = hs.get_datastore()
        self.federation = hs.get_federation_client()
        self.clock = hs.get_clock()
        self.device_handler = device_handler

        self._remote_edu_linearizer = Linearizer(name="remote_device_list")

        # user_id -> list of updates waiting to be handled.
        self._pending_updates = {}

        # Recently seen stream ids. We don't bother keeping these in the DB,
        # but they're useful to have them about to reduce the number of spurious
        # resyncs.
        self._seen_updates = ExpiringCache(
            cache_name="device_update_edu",
            clock=self.clock,
            max_len=10000,
            expiry_ms=30 * 60 * 1000,
            iterable=True,
        )

    @trace
    @defer.inlineCallbacks
    def incoming_device_list_update(self, origin, edu_content):
        """Called on incoming device list update from federation. Responsible
        for parsing the EDU and adding to pending updates list.
        """

        set_tag("origin", origin)
        set_tag("edu_content", edu_content)
        user_id = edu_content.pop("user_id")
        device_id = edu_content.pop("device_id")
        stream_id = str(edu_content.pop("stream_id"))  # They may come as ints
        prev_ids = edu_content.pop("prev_id", [])
        prev_ids = [str(p) for p in prev_ids]  # They may come as ints

        if get_domain_from_id(user_id) != origin:
            # TODO: Raise?
            logger.warning(
                "Got device list update edu for %r/%r from %r",
                user_id,
                device_id,
                origin,
            )

            set_tag("error", True)
            log_kv(
                {
                    "message": "Got a device list update edu from a user and "
                    "device which does not match the origin of the request.",
                    "user_id": user_id,
                    "device_id": device_id,
                }
            )
            return

        room_ids = yield self.store.get_rooms_for_user(user_id)
        if not room_ids:
            # We don't share any rooms with this user. Ignore update, as we
            # probably won't get any further updates.
            set_tag("error", True)
            log_kv(
                {
                    "message": "Got an update from a user for which "
                    "we don't share any rooms",
                    "other user_id": user_id,
                }
            )
            logger.warning(
                "Got device list update edu for %r/%r, but don't share a room",
                user_id,
                device_id,
            )
            return

        logger.debug("Received device list update for %r/%r", user_id, device_id)

        self._pending_updates.setdefault(user_id, []).append(
            (device_id, stream_id, prev_ids, edu_content)
        )

        yield self._handle_device_updates(user_id)

    @measure_func("_incoming_device_list_update")
    @defer.inlineCallbacks
    def _handle_device_updates(self, user_id):
        "Actually handle pending updates."

        with (yield self._remote_edu_linearizer.queue(user_id)):
            pending_updates = self._pending_updates.pop(user_id, [])
            if not pending_updates:
                # This can happen since we batch updates
                return

            for device_id, stream_id, prev_ids, content in pending_updates:
                logger.debug(
                    "Handling update %r/%r, ID: %r, prev: %r ",
                    user_id,
                    device_id,
                    stream_id,
                    prev_ids,
                )

            # Given a list of updates we check if we need to resync. This
            # happens if we've missed updates.
            resync = yield self._need_to_do_resync(user_id, pending_updates)

            if logger.isEnabledFor(logging.INFO):
                logger.info(
                    "Received device list update for %s, requiring resync: %s. Devices: %s",
                    user_id,
                    resync,
                    ", ".join(u[0] for u in pending_updates),
                )

            if resync:
                yield self.user_device_resync(user_id)
            else:
                # Simply update the single device, since we know that is the only
                # change (because of the single prev_id matching the current cache)
                for device_id, stream_id, prev_ids, content in pending_updates:
                    yield self.store.update_remote_device_list_cache_entry(
                        user_id, device_id, content, stream_id
                    )

                yield self.device_handler.notify_device_update(
                    user_id, [device_id for device_id, _, _, _ in pending_updates]
                )

                self._seen_updates.setdefault(user_id, set()).update(
                    stream_id for _, stream_id, _, _ in pending_updates
                )

    @defer.inlineCallbacks
    def _need_to_do_resync(self, user_id, updates):
        """Given a list of updates for a user figure out if we need to do a full
        resync, or whether we have enough data that we can just apply the delta.
        """
        seen_updates = self._seen_updates.get(user_id, set())

        extremity = yield self.store.get_device_list_last_stream_id_for_remote(user_id)

        logger.debug("Current extremity for %r: %r", user_id, extremity)

        stream_id_in_updates = set()  # stream_ids in updates list
        for _, stream_id, prev_ids, _ in updates:
            if not prev_ids:
                # We always do a resync if there are no previous IDs
                return True

            for prev_id in prev_ids:
                if prev_id == extremity:
                    continue
                elif prev_id in seen_updates:
                    continue
                elif prev_id in stream_id_in_updates:
                    continue
                else:
                    return True

            stream_id_in_updates.add(stream_id)

        return False

    @defer.inlineCallbacks
    def user_device_resync(self, user_id):
        """Fetches all devices for a user and updates the device cache with them.

        Args:
            user_id (str): The user's id whose device_list will be updated.
        Returns:
            Deferred[dict]: a dict with device info as under the "devices" in the result of this
            request:
            https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
        """
        log_kv({"message": "Doing resync to update device list."})
        # Fetch all devices for the user.
        origin = get_domain_from_id(user_id)
        try:
            result = yield self.federation.query_user_devices(origin, user_id)
        except (NotRetryingDestination, RequestSendFailed, HttpResponseException):
            # TODO: Remember that we are now out of sync and try again
            # later
            logger.warning("Failed to handle device list update for %s", user_id)
            # We abort on exceptions rather than accepting the update
            # as otherwise synapse will 'forget' that its device list
            # is out of date. If we bail then we will retry the resync
            # next time we get a device list update for this user_id.
            # This makes it more likely that the device lists will
            # eventually become consistent.
            return
        except FederationDeniedError as e:
            set_tag("error", True)
            log_kv({"reason": "FederationDeniedError"})
            logger.info(e)
            return
        except Exception as e:
            # TODO: Remember that we are now out of sync and try again
            # later
            set_tag("error", True)
            log_kv(
                {"message": "Exception raised by federation request", "exception": e}
            )
            logger.exception("Failed to handle device list update for %s", user_id)
            return
        log_kv({"result": result})
        stream_id = result["stream_id"]
        devices = result["devices"]

        # If the remote server has more than ~1000 devices for this user
        # we assume that something is going horribly wrong (e.g. a bot
        # that logs in and creates a new device every time it tries to
        # send a message).  Maintaining lots of devices per user in the
        # cache can cause serious performance issues as, if this request
        # takes more than 60s to complete, internal replication from the
        # inbound federation worker to the synapse master may time out
        # causing the inbound federation to fail and causing the remote
        # server to retry, causing a DoS.  So in this scenario we give
        # up on storing the total list of devices and only handle the
        # delta instead.
        if len(devices) > 1000:
            logger.warning(
                "Ignoring device list snapshot for %s as it has >1K devs (%d)",
                user_id,
                len(devices),
            )
            devices = []

        for device in devices:
            logger.debug(
                "Handling resync update %r/%r, ID: %r",
                user_id,
                device["device_id"],
                stream_id,
            )

        yield self.store.update_remote_device_list_cache(user_id, devices, stream_id)
        device_ids = [device["device_id"] for device in devices]
        yield self.device_handler.notify_device_update(user_id, device_ids)

        # We clobber the seen updates since we've re-synced from a given
        # point.
        self._seen_updates[user_id] = set([stream_id])

        defer.returnValue(result)
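
# --- Standalone sketch (hypothetical) ----------------------------------------
# The safety valve in user_device_resync in isolation: if a resync returns an
# implausibly large device list, drop the snapshot (keeping the stream
# position) rather than poisoning the cache. MAX_DEVICES mirrors the ~1000
# threshold above and is an assumption here.
MAX_DEVICES = 1000

def sanitise_resync_result(result):
    devices = result["devices"]
    if len(devices) > MAX_DEVICES:
        devices = []
    return result["stream_id"], devices

stream_id, devices = sanitise_resync_result(
    {"stream_id": 42,
     "devices": [{"device_id": "D%d" % i} for i in range(2000)]}
)
print(stream_id, len(devices))  # 42 0
# -----------------------------------------------------------------------------
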
Example #53
class FederationClient(FederationBase):
    def __init__(self, hs):
        super(FederationClient, self).__init__(hs)

    def start_get_pdu_cache(self):
        self._get_pdu_cache = ExpiringCache(
            cache_name="get_pdu_cache",
            clock=self._clock,
            max_len=1000,
            expiry_ms=120 * 1000,
            reset_expiry_on_get=False,
        )

        self._get_pdu_cache.start()

    @log_function
    def send_pdu(self, pdu, destinations):
        """Informs the replication layer about a new PDU generated within the
        home server that should be transmitted to others.

        TODO: Figure out when we should actually resolve the deferred.

        Args:
            pdu (Pdu): The new Pdu.

        Returns:
            Deferred: Completes when we have successfully processed the PDU
            and replicated it to any interested remote home servers.
        """
        order = self._order
        self._order += 1

        sent_pdus_destination_dist.inc_by(len(destinations))

        logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id)

        # TODO, add errback, etc.
        self._transaction_queue.enqueue_pdu(pdu, destinations, order)

        logger.debug("[%s] transaction_layer.enqueue_pdu... done",
                     pdu.event_id)

    @log_function
    def send_edu(self, destination, edu_type, content):
        edu = Edu(
            origin=self.server_name,
            destination=destination,
            edu_type=edu_type,
            content=content,
        )

        sent_edus_counter.inc()

        # TODO, add errback, etc.
        self._transaction_queue.enqueue_edu(edu)
        return defer.succeed(None)

    @log_function
    def send_failure(self, failure, destination):
        self._transaction_queue.enqueue_failure(failure, destination)
        return defer.succeed(None)

    @log_function
    def make_query(self,
                   destination,
                   query_type,
                   args,
                   retry_on_dns_fail=False):
        """Sends a federation Query to a remote homeserver of the given type
        and arguments.

        Args:
            destination (str): Domain name of the remote homeserver
            query_type (str): Category of the query type; should match the
                handler name used in register_query_handler().
            args (dict): Mapping of strings to strings containing the details
                of the query request.

        Returns:
            a Deferred which will eventually yield a JSON object from the
            response
        """
        sent_queries_counter.inc(query_type)

        return self.transport_layer.make_query(
            destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail)

    @log_function
    def query_client_keys(self, destination, content):
        """Query device keys for a device hosted on a remote server.

        Args:
            destination (str): Domain name of the remote homeserver
            content (dict): The query content.

        Returns:
            a Deferred which will eventually yield a JSON object from the
            response
        """
        sent_queries_counter.inc("client_device_keys")
        return self.transport_layer.query_client_keys(destination, content)

    @log_function
    def claim_client_keys(self, destination, content):
        """Claims one-time keys for a device hosted on a remote server.

        Args:
            destination (str): Domain name of the remote homeserver
            content (dict): The query content.

        Returns:
            a Deferred which will eventually yield a JSON object from the
            response
        """
        sent_queries_counter.inc("client_one_time_keys")
        return self.transport_layer.claim_client_keys(destination, content)

    @defer.inlineCallbacks
    @log_function
    def backfill(self, dest, context, limit, extremities):
        """Requests some more historic PDUs for the given context from the
        given destination server.

        Args:
            dest (str): The remote home server to ask.
            context (str): The context to backfill.
            limit (int): The maximum number of PDUs to return.
            extremities (list): List of PDU id and origins of the first pdus
                we have seen from the context

        Returns:
            Deferred: Results in the received PDUs.
        """
        logger.debug("backfill extrem=%s", extremities)

        # If there are no extremities then we've (probably) reached the start.
        if not extremities:
            return

        transaction_data = yield self.transport_layer.backfill(
            dest, context, extremities, limit)

        logger.debug("backfill transaction_data=%s", repr(transaction_data))

        pdus = [
            self.event_from_pdu_json(p, outlier=False)
            for p in transaction_data["pdus"]
        ]

        # FIXME: We should handle signature failures more gracefully.
        pdus[:] = yield defer.gatherResults(
            self._check_sigs_and_hashes(pdus),
            consumeErrors=True,
        ).addErrback(unwrapFirstError)

        defer.returnValue(pdus)

    @defer.inlineCallbacks
    @log_function
    def get_pdu(self, destinations, event_id, outlier=False, timeout=None):
        """Requests the PDU with given origin and ID from the remote home
        servers.

        Will attempt to get the PDU from each destination in the list until
        one succeeds.

        This will persist the PDU locally upon receipt.

        Args:
            destinations (list): Which home servers to query
            event_id (str): The id of the event to fetch.
            outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
                it's from an arbitrary point in the context as opposed to part
                of the current block of PDUs. Defaults to `False`
            timeout (int): How long to try (in ms) each destination for before
                moving to the next destination. None indicates no timeout.

        Returns:
            Deferred: Results in the requested PDU.
        """

        # TODO: Rate limit the number of times we try and get the same event.

        if self._get_pdu_cache:
            e = self._get_pdu_cache.get(event_id)
            if e:
                defer.returnValue(e)

        pdu = None
        for destination in destinations:
            try:
                limiter = yield get_retry_limiter(
                    destination,
                    self._clock,
                    self.store,
                )

                with limiter:
                    transaction_data = yield self.transport_layer.get_event(
                        destination,
                        event_id,
                        timeout=timeout,
                    )

                    logger.debug("transaction_data %r", transaction_data)

                    pdu_list = [
                        self.event_from_pdu_json(p, outlier=outlier)
                        for p in transaction_data["pdus"]
                    ]

                    if pdu_list and pdu_list[0]:
                        pdu = pdu_list[0]

                        # Check signatures are correct.
                        pdu = yield self._check_sigs_and_hashes([pdu])[0]

                        break

            except SynapseError as e:
                logger.info(
                    "Failed to get PDU %s from %s because %s",
                    event_id,
                    destination,
                    e,
                )
                continue
            except CodeMessageException as e:
                if 400 <= e.code < 500:
                    raise

                logger.info(
                    "Failed to get PDU %s from %s because %s",
                    event_id,
                    destination,
                    e,
                )
                continue
            except NotRetryingDestination as e:
                logger.info(e.message)
                continue
            except Exception as e:
                logger.info(
                    "Failed to get PDU %s from %s because %s",
                    event_id,
                    destination,
                    e,
                )
                continue

        if self._get_pdu_cache is not None and pdu:
            self._get_pdu_cache[event_id] = pdu

        defer.returnValue(pdu)
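
# --- Standalone sketch (hypothetical) ----------------------------------------
# get_pdu's control flow without the federation machinery: consult a cache,
# then try each destination in turn and cache the first usable result. The
# fetchers dict stands in for the transport layer and is not a real API.
_pdu_cache = {}

def get_pdu_sketch(destinations, event_id, fetchers):
    if event_id in _pdu_cache:
        return _pdu_cache[event_id]

    pdu = None
    for destination in destinations:
        try:
            pdu = fetchers[destination](event_id)
            if pdu is not None:
                break
        except Exception:
            continue   # this destination failed; try the next one

    if pdu is not None:
        _pdu_cache[event_id] = pdu
    return pdu

fetchers = {
    "a.example": lambda _eid: None,              # doesn't have the event
    "b.example": lambda eid: {"event_id": eid},  # does
}
print(get_pdu_sketch(["a.example", "b.example"], "$ev1", fetchers))
# -----------------------------------------------------------------------------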

    @defer.inlineCallbacks
    @log_function
    def get_state_for_room(self, destination, room_id, event_id):
        """Requests all of the `current` state PDUs for a given room from
        a remote home server.

        Args:
            destination (str): The remote homeserver to query for the state.
            room_id (str): The id of the room we're interested in.
            event_id (str): The id of the event we want the state at.

        Returns:
            Deferred: Results in a list of PDUs.
        """

        result = yield self.transport_layer.get_room_state(
            destination,
            room_id,
            event_id=event_id,
        )

        pdus = [
            self.event_from_pdu_json(p, outlier=True) for p in result["pdus"]
        ]

        auth_chain = [
            self.event_from_pdu_json(p, outlier=True)
            for p in result.get("auth_chain", [])
        ]

        signed_pdus = yield self._check_sigs_and_hash_and_fetch(destination,
                                                                pdus,
                                                                outlier=True)

        signed_auth = yield self._check_sigs_and_hash_and_fetch(destination,
                                                                auth_chain,
                                                                outlier=True)

        signed_auth.sort(key=lambda e: e.depth)

        defer.returnValue((signed_pdus, signed_auth))

    @defer.inlineCallbacks
    @log_function
    def get_event_auth(self, destination, room_id, event_id):
        res = yield self.transport_layer.get_event_auth(
            destination,
            room_id,
            event_id,
        )

        auth_chain = [
            self.event_from_pdu_json(p, outlier=True)
            for p in res["auth_chain"]
        ]

        signed_auth = yield self._check_sigs_and_hash_and_fetch(destination,
                                                                auth_chain,
                                                                outlier=True)

        signed_auth.sort(key=lambda e: e.depth)

        defer.returnValue(signed_auth)

    @defer.inlineCallbacks
    def make_membership_event(
        self,
        destinations,
        room_id,
        user_id,
        membership,
        content={},
    ):
        """
        Creates an m.room.member event, with context, without participating in the room.

        Does so by asking one of the already participating servers to create an
        event with proper context.

        Note that this does not append any events to any graphs.

        Args:
            destinations (str): Candidate homeservers which are probably
                participating in the room.
            room_id (str): The room in which the event will happen.
            user_id (str): The user whose membership is being evented.
            membership (str): The "membership" property of the event. Must be
                one of "join" or "leave".
            content (object): Any additional data to put into the content field
                of the event.
        Return:
            A tuple of (origin (str), event (object)) where origin is the remote
            homeserver which generated the event.
        """
        valid_memberships = {Membership.JOIN, Membership.LEAVE}
        if membership not in valid_memberships:
            raise RuntimeError(
                "make_membership_event called with membership='%s', must be one of %s"
                % (membership, ",".join(valid_memberships)))
        for destination in destinations:
            if destination == self.server_name:
                continue

            try:
                ret = yield self.transport_layer.make_membership_event(
                    destination, room_id, user_id, membership)

                pdu_dict = ret["event"]

                logger.debug("Got response to make_%s: %s", membership,
                             pdu_dict)

                pdu_dict["content"].update(content)

                # The protoevent received over the JSON wire may not have all
                # the required fields. Lets just gloss over that because
                # there's some we never care about
                if "prev_state" not in pdu_dict:
                    pdu_dict["prev_state"] = []

                defer.returnValue(
                    (destination, self.event_from_pdu_json(pdu_dict)))
                break
            except CodeMessageException:
                raise
            except Exception as e:
                logger.warn("Failed to make_%s via %s: %s", membership,
                            destination, e.message)
                raise

        raise RuntimeError("Failed to send to any server.")

    @defer.inlineCallbacks
    def send_join(self, destinations, pdu):
        for destination in destinations:
            if destination == self.server_name:
                continue

            try:
                time_now = self._clock.time_msec()
                _, content = yield self.transport_layer.send_join(
                    destination=destination,
                    room_id=pdu.room_id,
                    event_id=pdu.event_id,
                    content=pdu.get_pdu_json(time_now),
                )

                logger.debug("Got content: %s", content)

                state = [
                    self.event_from_pdu_json(p, outlier=True)
                    for p in content.get("state", [])
                ]

                auth_chain = [
                    self.event_from_pdu_json(p, outlier=True)
                    for p in content.get("auth_chain", [])
                ]

                pdus = {
                    p.event_id: p
                    for p in itertools.chain(state, auth_chain)
                }

                valid_pdus = yield self._check_sigs_and_hash_and_fetch(
                    destination,
                    pdus.values(),
                    outlier=True,
                )

                valid_pdus_map = {p.event_id: p for p in valid_pdus}

                # NB: We *need* to copy to ensure that we don't have multiple
                # references being passed on, as that causes... issues.
                signed_state = [
                    copy.copy(valid_pdus_map[p.event_id]) for p in state
                    if p.event_id in valid_pdus_map
                ]

                signed_auth = [
                    valid_pdus_map[p.event_id] for p in auth_chain
                    if p.event_id in valid_pdus_map
                ]

                # NB: We *need* to copy to ensure that we don't have multiple
                # references being passed on, as that causes... issues.
                for s in signed_state:
                    s.internal_metadata = copy.deepcopy(s.internal_metadata)

                auth_chain.sort(key=lambda e: e.depth)

                defer.returnValue({
                    "state": signed_state,
                    "auth_chain": signed_auth,
                    "origin": destination,
                })
            except CodeMessageException:
                raise
            except Exception as e:
                logger.exception("Failed to send_join via %s: %s", destination,
                                 e.message)

        raise RuntimeError("Failed to send to any server.")

    @defer.inlineCallbacks
    def send_invite(self, destination, room_id, event_id, pdu):
        time_now = self._clock.time_msec()
        code, content = yield self.transport_layer.send_invite(
            destination=destination,
            room_id=room_id,
            event_id=event_id,
            content=pdu.get_pdu_json(time_now),
        )

        pdu_dict = content["event"]

        logger.debug("Got response to send_invite: %s", pdu_dict)

        pdu = self.event_from_pdu_json(pdu_dict)

        # Check signatures are correct.
        pdu = yield self._check_sigs_and_hash(pdu)

        # FIXME: We should handle signature failures more gracefully.

        defer.returnValue(pdu)

    @defer.inlineCallbacks
    def send_leave(self, destinations, pdu):
        for destination in destinations:
            if destination == self.server_name:
                continue

            try:
                time_now = self._clock.time_msec()
                _, content = yield self.transport_layer.send_leave(
                    destination=destination,
                    room_id=pdu.room_id,
                    event_id=pdu.event_id,
                    content=pdu.get_pdu_json(time_now),
                )

                logger.debug("Got content: %s", content)
                defer.returnValue(None)
            except CodeMessageException:
                raise
            except Exception as e:
                logger.exception("Failed to send_leave via %s: %s",
                                 destination, e.message)

        raise RuntimeError("Failed to send to any server.")

    @defer.inlineCallbacks
    def get_public_rooms(self, destinations):
        results_by_server = {}

        @defer.inlineCallbacks
        def _get_result(s):
            if s == self.server_name:
                defer.returnValue(None)

            try:
                result = yield self.transport_layer.get_public_rooms(s)
                results_by_server[s] = result
            except Exception:
                logger.exception("Error getting room list from server %r", s)

        yield concurrently_execute(_get_result, destinations, 3)

        defer.returnValue(results_by_server)

    @defer.inlineCallbacks
    def query_auth(self, destination, room_id, event_id, local_auth):
        """
        Params:
            destination (str)
            room_id (str)
            event_id (str)
            local_auth (list)
        """
        time_now = self._clock.time_msec()

        send_content = {
            "auth_chain": [e.get_pdu_json(time_now) for e in local_auth],
        }

        code, content = yield self.transport_layer.send_query_auth(
            destination=destination,
            room_id=room_id,
            event_id=event_id,
            content=send_content,
        )

        auth_chain = [
            self.event_from_pdu_json(e) for e in content["auth_chain"]
        ]

        signed_auth = yield self._check_sigs_and_hash_and_fetch(destination,
                                                                auth_chain,
                                                                outlier=True)

        signed_auth.sort(key=lambda e: e.depth)

        ret = {
            "auth_chain": signed_auth,
            "rejects": content.get("rejects", []),
            "missing": content.get("missing", []),
        }

        defer.returnValue(ret)

    @defer.inlineCallbacks
    def get_missing_events(self, destination, room_id, earliest_events_ids,
                           latest_events, limit, min_depth):
        """Tries to fetch events we are missing. This is called when we receive
        an event without having received all of its ancestors.

        Args:
            destination (str)
            room_id (str)
            earliest_events_ids (list): List of event ids. Effectively the
                events we expected to receive, but haven't. `get_missing_events`
                should only return events that didn't happen before these.
            latest_events (list): List of events we have received that we don't
                have all previous events for.
            limit (int): Maximum number of events to return.
            min_depth (int): Minimum depth of events to return.
        """
        try:
            content = yield self.transport_layer.get_missing_events(
                destination=destination,
                room_id=room_id,
                earliest_events=earliest_events_ids,
                latest_events=[e.event_id for e in latest_events],
                limit=limit,
                min_depth=min_depth,
            )

            events = [
                self.event_from_pdu_json(e) for e in content.get("events", [])
            ]

            signed_events = yield self._check_sigs_and_hash_and_fetch(
                destination, events, outlier=False)

            have_gotten_all_from_destination = True
        except HttpResponseException as e:
            if e.code != 400:
                raise

            # We are probably hitting an old server that doesn't support
            # get_missing_events
            signed_events = []
            have_gotten_all_from_destination = False

        if len(signed_events) >= limit:
            defer.returnValue(signed_events)

        servers = yield self.store.get_joined_hosts_for_room(room_id)

        servers = set(servers)
        servers.discard(self.server_name)

        failed_to_fetch = set()

        while len(signed_events) < limit:
            # Are we missing any?

            seen_events = set(earliest_events_ids)
            seen_events.update(e.event_id for e in signed_events if e)

            missing_events = {}
            for e in itertools.chain(latest_events, signed_events):
                if e.depth > min_depth:
                    missing_events.update({
                        e_id: e.depth
                        for e_id, _ in e.prev_events if e_id not in seen_events
                        and e_id not in failed_to_fetch
                    })

            if not missing_events:
                break

            have_seen = yield self.store.have_events(missing_events)

            for k in have_seen:
                missing_events.pop(k, None)

            if not missing_events:
                break

            # Okay, we haven't gotten everything yet. Lets get them.
            ordered_missing = sorted(missing_events.items(),
                                     key=lambda x: x[0])

            if have_gotten_all_from_destination:
                servers.discard(destination)

            def random_server_list():
                srvs = list(servers)
                random.shuffle(srvs)
                return srvs

            deferreds = [
                self.get_pdu(
                    destinations=random_server_list(),
                    event_id=e_id,
                )
                for e_id, depth in ordered_missing[:limit - len(signed_events)]
            ]

            res = yield defer.DeferredList(deferreds, consumeErrors=True)
            for (result, val), (e_id, _) in zip(res, ordered_missing):
                if result and val:
                    signed_events.append(val)
                else:
                    failed_to_fetch.add(e_id)

        defer.returnValue(signed_events)
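
# --- Standalone sketch (hypothetical) ----------------------------------------
# The gap detection inside the loop above as a pure function: collect the
# prev_events of events above min_depth that we have neither seen nor already
# failed to fetch. Events are represented as plain dicts here.
def find_missing(latest_events, signed_events, seen, failed, min_depth):
    missing = {}
    for e in list(latest_events) + list(signed_events):
        if e["depth"] > min_depth:
            for e_id in e["prev_events"]:
                if e_id not in seen and e_id not in failed:
                    missing[e_id] = e["depth"]
    return missing

latest = [{"depth": 5, "prev_events": ["$a", "$b"]}]
signed = [{"depth": 4, "prev_events": ["$b", "$c"]}]
print(find_missing(latest, signed, seen={"$a"}, failed=set(), min_depth=3))
# {'$b': 4, '$c': 4}
# -----------------------------------------------------------------------------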

    def event_from_pdu_json(self, pdu_json, outlier=False):
        event = FrozenEvent(pdu_json)

        event.internal_metadata.outlier = outlier

        return event

    @defer.inlineCallbacks
    def forward_third_party_invite(self, destinations, room_id, event_dict):
        for destination in destinations:
            if destination == self.server_name:
                continue

            try:
                yield self.transport_layer.exchange_third_party_invite(
                    destination=destination,
                    room_id=room_id,
                    event_dict=event_dict,
                )
                defer.returnValue(None)
            except CodeMessageException:
                raise
            except Exception as e:
                logger.exception(
                    "Failed to send_third_party_invite via %s: %s",
                    destination, e.message)

        raise RuntimeError("Failed to send to any server.")
Example #54
class StateHandler(object):
    """ Responsible for doing state conflict resolution.
    """

    def __init__(self, hs):
        self.clock = hs.get_clock()
        self.store = hs.get_datastore()
        self.hs = hs

        # dict of set of event_ids -> _StateCacheEntry.
        self._state_cache = None
        self.resolve_linearizer = Linearizer()

    def start_caching(self):
        logger.debug("start_caching")

        self._state_cache = ExpiringCache(
            cache_name="state_cache",
            clock=self.clock,
            max_len=SIZE_OF_CACHE,
            expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
            reset_expiry_on_get=True,
        )

        self._state_cache.start()

    @defer.inlineCallbacks
    def get_current_state(self, room_id, event_type=None, state_key="",
                          latest_event_ids=None):
        """ Retrieves the current state for the room. This is done by
        calling `get_latest_events_in_room` to get the leading edges of the
        event graph and then resolving any of the state conflicts.

        This is equivalent to the state that a new event sent into the room
        right now would see, before any further events are received.

        If `event_type` is specified, then the method returns only the one
        event (or None) with that `event_type` and `state_key`.

        Returns:
            map from (type, state_key) to event
        """
        if not latest_event_ids:
            latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)

        ret = yield self.resolve_state_groups(room_id, latest_event_ids)
        state = ret.state

        if event_type:
            event_id = state.get((event_type, state_key))
            event = None
            if event_id:
                event = yield self.store.get_event(event_id, allow_none=True)
            defer.returnValue(event)
            return

        state_map = yield self.store.get_events(state.values(), get_prev_content=False)
        state = {
            key: state_map[e_id] for key, e_id in state.items() if e_id in state_map
        }

        defer.returnValue(state)

    @defer.inlineCallbacks
    def get_current_state_ids(self, room_id, event_type=None, state_key="",
                              latest_event_ids=None):
        if not latest_event_ids:
            latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)

        ret = yield self.resolve_state_groups(room_id, latest_event_ids)
        state = ret.state

        if event_type:
            defer.returnValue(state.get((event_type, state_key)))
            return

        defer.returnValue(state)

    @defer.inlineCallbacks
    def get_current_user_in_room(self, room_id, latest_event_ids=None):
        if not latest_event_ids:
            latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
        entry = yield self.resolve_state_groups(room_id, latest_event_ids)
        joined_users = yield self.store.get_joined_users_from_state(
            room_id, entry.state_id, entry.state
        )
        defer.returnValue(joined_users)

    @defer.inlineCallbacks
    def compute_event_context(self, event, old_state=None):
        """ Fills out the context with the `current state` of the graph. The
        `current state` here is defined to be the state of the event graph
        just before the event - i.e. it never includes `event`

        If `event` has `auth_events` then this will also fill out the
        `auth_events` field on `context` from the `current_state`.

        Args:
            event (EventBase)
        Returns:
            an EventContext
        """
        context = EventContext()

        if event.internal_metadata.is_outlier():
            # If this is an outlier, then we know it shouldn't have any current
            # state. Certainly store.get_current_state won't return any, and
            # persisting the event won't store the state group.
            if old_state:
                context.prev_state_ids = {
                    (s.type, s.state_key): s.event_id for s in old_state
                }
                if event.is_state():
                    context.current_state_ids = dict(context.prev_state_ids)
                    key = (event.type, event.state_key)
                    context.current_state_ids[key] = event.event_id
                else:
                    context.current_state_ids = context.prev_state_ids
            else:
                context.current_state_ids = {}
                context.prev_state_ids = {}
            context.prev_state_events = []
            context.state_group = self.store.get_next_state_group()
            defer.returnValue(context)

        if old_state:
            context.prev_state_ids = {
                (s.type, s.state_key): s.event_id for s in old_state
            }
            context.state_group = self.store.get_next_state_group()

            if event.is_state():
                key = (event.type, event.state_key)
                if key in context.prev_state_ids:
                    replaces = context.prev_state_ids[key]
                    if replaces != event.event_id:  # Paranoia check
                        event.unsigned["replaces_state"] = replaces
                context.current_state_ids = dict(context.prev_state_ids)
                context.current_state_ids[key] = event.event_id
            else:
                context.current_state_ids = context.prev_state_ids

            context.prev_state_events = []
            defer.returnValue(context)

        if event.is_state():
            entry = yield self.resolve_state_groups(
                event.room_id, [e for e, _ in event.prev_events],
                event_type=event.type,
                state_key=event.state_key,
            )
        else:
            entry = yield self.resolve_state_groups(
                event.room_id, [e for e, _ in event.prev_events],
            )

        curr_state = entry.state

        context.prev_state_ids = curr_state
        if event.is_state():
            context.state_group = self.store.get_next_state_group()

            key = (event.type, event.state_key)
            if key in context.prev_state_ids:
                replaces = context.prev_state_ids[key]
                event.unsigned["replaces_state"] = replaces

            context.current_state_ids = dict(context.prev_state_ids)
            context.current_state_ids[key] = event.event_id

            context.prev_group = entry.prev_group
            context.delta_ids = entry.delta_ids
            if context.delta_ids is not None:
                context.delta_ids = dict(context.delta_ids)
                context.delta_ids[key] = event.event_id
        else:
            if entry.state_group is None:
                entry.state_group = self.store.get_next_state_group()
                entry.state_id = entry.state_group

            context.state_group = entry.state_group
            context.current_state_ids = context.prev_state_ids
            context.prev_group = entry.prev_group
            context.delta_ids = entry.delta_ids

        context.prev_state_events = []
        defer.returnValue(context)

    @defer.inlineCallbacks
    @log_function
    def resolve_state_groups(self, room_id, event_ids, event_type=None, state_key=""):
        """ Given a list of event_ids this method fetches the state at each
        event, resolves conflicts between them and returns them.

        Returns:
            a Deferred _StateCacheEntry. Its `state` is a map from
            (type, state_key) to event_id, `state_group` is the name of a state
            group if one and only one is involved, and `prev_group`/`delta_ids`
            describe the delta from a previous group where one can be found.
        """
        logger.debug("resolve_state_groups event_ids %s", event_ids)

        state_groups_ids = yield self.store.get_state_groups_ids(
            room_id, event_ids
        )

        logger.debug(
            "resolve_state_groups state_groups %s",
            state_groups_ids.keys()
        )

        group_names = frozenset(state_groups_ids.keys())
        if len(group_names) == 1:
            name, state_list = state_groups_ids.items().pop()

            defer.returnValue(_StateCacheEntry(
                state=state_list,
                state_group=name,
                prev_group=name,
                delta_ids={},
            ))

        with (yield self.resolve_linearizer.queue(group_names)):
            if self._state_cache is not None:
                cache = self._state_cache.get(group_names, None)
                if cache:
                    defer.returnValue(cache)

            logger.info(
                "Resolving state for %s with %d groups", room_id, len(state_groups_ids)
            )

            state = {}
            for st in state_groups_ids.values():
                for key, e_id in st.items():
                    state.setdefault(key, set()).add(e_id)

            conflicted_state = {
                k: list(v)
                for k, v in state.items()
                if len(v) > 1
            }

            if conflicted_state:
                logger.info("Resolving conflicted state for %r", room_id)
                state_map = yield self.store.get_events(
                    [e_id for st in state_groups_ids.values() for e_id in st.values()],
                    get_prev_content=False
                )
                state_sets = [
                    [state_map[e_id] for key, e_id in st.items() if e_id in state_map]
                    for st in state_groups_ids.values()
                ]
                new_state, _ = self._resolve_events(
                    state_sets, event_type, state_key
                )
                new_state = {
                    key: e.event_id for key, e in new_state.items()
                }
            else:
                new_state = {
                    key: e_ids.pop() for key, e_ids in state.items()
                }

            state_group = None
            new_state_event_ids = frozenset(new_state.values())
            for sg, events in state_groups_ids.items():
                if new_state_event_ids == frozenset(e_id for e_id in events):
                    state_group = sg
                    break
            if state_group is None:
                # Worker instances don't have access to this method, but we want
                # to set the state_group on the main instance to increase cache
                # hits.
                if hasattr(self.store, "get_next_state_group"):
                    state_group = self.store.get_next_state_group()

            prev_group = None
            delta_ids = None
            for old_group, old_ids in state_groups_ids.items():
                if not set(new_state.iterkeys()) - set(old_ids.iterkeys()):
                    n_delta_ids = {
                        k: v
                        for k, v in new_state.items()
                        if old_ids.get(k) != v
                    }
                    if not delta_ids or len(n_delta_ids) < len(delta_ids):
                        prev_group = old_group
                        delta_ids = n_delta_ids

            cache = _StateCacheEntry(
                state=new_state,
                state_group=state_group,
                prev_group=prev_group,
                delta_ids=delta_ids,
            )

            if self._state_cache is not None:
                self._state_cache[group_names] = cache

            defer.returnValue(cache)

    def resolve_events(self, state_sets, event):
        logger.info(
            "Resolving state for %s with %d groups", event.room_id, len(state_sets)
        )
        if event.is_state():
            return self._resolve_events(
                state_sets, event.type, event.state_key
            )
        else:
            return self._resolve_events(state_sets)

    def _resolve_events(self, state_sets, event_type=None, state_key=""):
        """
        Returns
            (dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple
            (new_state, prev_states). new_state is a map from (type, state_key)
            to event. prev_states is a list of event_ids.
        """
        with Measure(self.clock, "state._resolve_events"):
            state = {}
            for st in state_sets:
                for e in st:
                    state.setdefault(
                        (e.type, e.state_key),
                        {}
                    )[e.event_id] = e

            unconflicted_state = {
                k: v.values()[0] for k, v in state.items()
                if len(v.values()) == 1
            }

            conflicted_state = {
                k: v.values()
                for k, v in state.items()
                if len(v.values()) > 1
            }

            if event_type:
                prev_states_events = conflicted_state.get(
                    (event_type, state_key), []
                )
                prev_states = [s.event_id for s in prev_states_events]
            else:
                prev_states = []

            auth_events = {
                k: e for k, e in unconflicted_state.items()
                if k[0] in AuthEventTypes
            }

            try:
                resolved_state = self._resolve_state_events(
                    conflicted_state, auth_events
                )
            except:
                logger.exception("Failed to resolve state")
                raise

            new_state = unconflicted_state
            new_state.update(resolved_state)

        return new_state, prev_states

    @log_function
    def _resolve_state_events(self, conflicted_state, auth_events):
        """ This is where we actually decide which of the conflicted state to
        use.

        We resolve conflicts in the following order:
            1. power levels
            2. join rules
            3. memberships
            4. other events.
        """
        resolved_state = {}
        power_key = (EventTypes.PowerLevels, "")
        if power_key in conflicted_state:
            events = conflicted_state[power_key]
            logger.debug("Resolving conflicted power levels %r", events)
            resolved_state[power_key] = self._resolve_auth_events(
                events, auth_events)

        auth_events.update(resolved_state)

        for key, events in conflicted_state.items():
            if key[0] == EventTypes.JoinRules:
                logger.debug("Resolving conflicted join rules %r", events)
                resolved_state[key] = self._resolve_auth_events(
                    events,
                    auth_events
                )

        auth_events.update(resolved_state)

        for key, events in conflicted_state.items():
            if key[0] == EventTypes.Member:
                logger.debug("Resolving conflicted member lists %r", events)
                resolved_state[key] = self._resolve_auth_events(
                    events,
                    auth_events
                )

        auth_events.update(resolved_state)

        for key, events in conflicted_state.items():
            if key not in resolved_state:
                logger.debug("Resolving conflicted state %r:%r", key, events)
                resolved_state[key] = self._resolve_normal_events(
                    events, auth_events
                )

        return resolved_state

    def _resolve_auth_events(self, events, auth_events):
        reverse = list(reversed(self._ordered_events(events)))

        auth_events = dict(auth_events)

        prev_event = reverse[0]
        for event in reverse[1:]:
            auth_events[(prev_event.type, prev_event.state_key)] = prev_event
            try:
                # FIXME: hs.get_auth() is bad style, but we need to do it to
                # get around circular deps.
                # The signatures have already been checked at this point
                self.hs.get_auth().check(event, auth_events, do_sig_check=False)
                prev_event = event
            except AuthError:
                return prev_event

        return event

    def _resolve_normal_events(self, events, auth_events):
        for event in self._ordered_events(events):
            try:
                # FIXME: hs.get_auth() is bad style, but we need to do it to
                # get around circular deps.
                # The signatures have already been checked at this point
                self.hs.get_auth().check(event, auth_events, do_sig_check=False)
                return event
            except AuthError:
                pass

        # Use the last event (the one with the least depth) if they all fail
        # the auth check.
        return event

    def _ordered_events(self, events):
        def key_func(e):
            return -int(e.depth), hashlib.sha1(e.event_id).hexdigest()

        return sorted(events, key=key_func)
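
To make the resolution helpers at the end of this class concrete, here is a small self-contained sketch of the unconflicted/conflicted split performed by _resolve_events and the depth-then-hash ordering used by _ordered_events. FakeEvent is a hypothetical stand-in for synapse.events.FrozenEvent, introduced only so the sketch runs on its own; the auth-based resolution of the conflicted entries (_resolve_state_events) is not reproduced here.

import hashlib
from collections import namedtuple

# Hypothetical minimal event: just the fields the helpers read.
FakeEvent = namedtuple("FakeEvent", ["event_id", "type", "state_key", "depth"])

def split_state(state_sets):
    """Return (unconflicted, conflicted) maps keyed by (type, state_key)."""
    state = {}
    for st in state_sets:
        for e in st:
            state.setdefault((e.type, e.state_key), {})[e.event_id] = e
    unconflicted = {k: next(iter(v.values())) for k, v in state.items() if len(v) == 1}
    conflicted = {k: list(v.values()) for k, v in state.items() if len(v) > 1}
    return unconflicted, conflicted

def ordered_events(events):
    # Mirrors _ordered_events: deeper events sort first; ties are broken
    # deterministically by hashing the event id (encoded to bytes so the
    # sketch runs on Python 3).
    return sorted(
        events,
        key=lambda e: (-int(e.depth), hashlib.sha1(e.event_id.encode("utf-8")).hexdigest()),
    )

a = FakeEvent("$a", "m.room.topic", "", 3)
b = FakeEvent("$b", "m.room.topic", "", 5)
c = FakeEvent("$c", "m.room.name", "", 4)
unconflicted, conflicted = split_state([[a, c], [b, c]])
assert unconflicted == {("m.room.name", ""): c}
assert ordered_events(conflicted[("m.room.topic", "")])[0] is b
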
Example #55
0
class DeviceListEduUpdater(object):
    "Handles incoming device list updates from federation and updates the DB"

    def __init__(self, hs, device_handler):
        self.store = hs.get_datastore()
        self.federation = hs.get_federation_client()
        self.clock = hs.get_clock()
        self.device_handler = device_handler

        self._remote_edu_linearizer = Linearizer(name="remote_device_list")

        # user_id -> list of updates waiting to be handled.
        self._pending_updates = {}

        # Recently seen stream ids. We don't bother keeping these in the DB,
        # but they're useful to have around to reduce the number of spurious
        # resyncs.
        self._seen_updates = ExpiringCache(
            cache_name="device_update_edu",
            clock=self.clock,
            max_len=10000,
            expiry_ms=30 * 60 * 1000,
            iterable=True,
        )

    @defer.inlineCallbacks
    def incoming_device_list_update(self, origin, edu_content):
        """Called on incoming device list update from federation. Responsible
        for parsing the EDU and adding to pending updates list.
        """

        user_id = edu_content.pop("user_id")
        device_id = edu_content.pop("device_id")
        stream_id = str(edu_content.pop("stream_id"))  # They may come as ints
        prev_ids = edu_content.pop("prev_id", [])
        prev_ids = [str(p) for p in prev_ids]   # They may come as ints

        if get_domain_from_id(user_id) != origin:
            # TODO: Raise?
            logger.warning(
                "Got device list update edu for %r/%r from %r",
                user_id, device_id, origin,
            )
            return

        room_ids = yield self.store.get_rooms_for_user(user_id)
        if not room_ids:
            # We don't share any rooms with this user. Ignore update, as we
            # probably won't get any further updates.
            logger.warning(
                "Got device list update edu for %r/%r, but don't share a room",
                user_id, device_id,
            )
            return

        logger.debug(
            "Received device list update for %r/%r", user_id, device_id,
        )

        self._pending_updates.setdefault(user_id, []).append(
            (device_id, stream_id, prev_ids, edu_content)
        )

        yield self._handle_device_updates(user_id)

    @measure_func("_incoming_device_list_update")
    @defer.inlineCallbacks
    def _handle_device_updates(self, user_id):
        "Actually handle pending updates."

        with (yield self._remote_edu_linearizer.queue(user_id)):
            pending_updates = self._pending_updates.pop(user_id, [])
            if not pending_updates:
                # This can happen since we batch updates
                return

            for device_id, stream_id, prev_ids, content in pending_updates:
                logger.debug(
                    "Handling update %r/%r, ID: %r, prev: %r ",
                    user_id, device_id, stream_id, prev_ids,
                )

            # Given a list of updates we check if we need to resync. This
            # happens if we've missed updates.
            resync = yield self._need_to_do_resync(user_id, pending_updates)

            logger.debug("Need to re-sync devices for %r? %r", user_id, resync)

            if resync:
                # Fetch all devices for the user.
                origin = get_domain_from_id(user_id)
                try:
                    result = yield self.federation.query_user_devices(origin, user_id)
                except (
                    NotRetryingDestination, RequestSendFailed, HttpResponseException,
                ):
                    # TODO: Remember that we are now out of sync and try again
                    # later
                    logger.warn(
                        "Failed to handle device list update for %s", user_id,
                    )
                    # We abort on exceptions rather than accepting the update
                    # as otherwise synapse will 'forget' that its device list
                    # is out of date. If we bail then we will retry the resync
                    # next time we get a device list update for this user_id.
                    # This makes it more likely that the device lists will
                    # eventually become consistent.
                    return
                except FederationDeniedError as e:
                    logger.info(e)
                    return
                except Exception:
                    # TODO: Remember that we are now out of sync and try again
                    # later
                    logger.exception(
                        "Failed to handle device list update for %s", user_id
                    )
                    return

                stream_id = result["stream_id"]
                devices = result["devices"]

                # If the remote server has more than ~1000 devices for this user
                # we assume that something is going horribly wrong (e.g. a bot
                # that logs in and creates a new device every time it tries to
                # send a message).  Maintaining lots of devices per user in the
                # cache can cause serious performance issues: if this request
                # takes more than 60s to complete, internal replication from the
                # inbound federation worker to the synapse master may time out,
                # causing the inbound federation request to fail and the remote
                # server to retry, resulting in a DoS.  So in this scenario we give
                # up on storing the total list of devices and only handle the
                # delta instead.
                if len(devices) > 1000:
                    logger.warn(
                        "Ignoring device list snapshot for %s as it has >1K devs (%d)",
                        user_id, len(devices)
                    )
                    devices = []

                for device in devices:
                    logger.debug(
                        "Handling resync update %r/%r, ID: %r",
                        user_id, device["device_id"], stream_id,
                    )

                yield self.store.update_remote_device_list_cache(
                    user_id, devices, stream_id,
                )
                device_ids = [device["device_id"] for device in devices]
                yield self.device_handler.notify_device_update(user_id, device_ids)

                # We clobber the seen updates since we've re-synced from a given
                # point.
                self._seen_updates[user_id] = set([stream_id])
            else:
                # Simply update the single device, since we know that is the only
                # change (because of the single prev_id matching the current cache)
                for device_id, stream_id, prev_ids, content in pending_updates:
                    yield self.store.update_remote_device_list_cache_entry(
                        user_id, device_id, content, stream_id,
                    )

                yield self.device_handler.notify_device_update(
                    user_id, [device_id for device_id, _, _, _ in pending_updates]
                )

                self._seen_updates.setdefault(user_id, set()).update(
                    stream_id for _, stream_id, _, _ in pending_updates
                )

    @defer.inlineCallbacks
    def _need_to_do_resync(self, user_id, updates):
        """Given a list of updates for a user figure out if we need to do a full
        resync, or whether we have enough data that we can just apply the delta.
        """
        seen_updates = self._seen_updates.get(user_id, set())

        extremity = yield self.store.get_device_list_last_stream_id_for_remote(
            user_id
        )

        logger.debug(
            "Current extremity for %r: %r",
            user_id, extremity,
        )

        stream_id_in_updates = set()  # stream_ids in updates list
        for _, stream_id, prev_ids, _ in updates:
            if not prev_ids:
                # We always do a resync if there are no previous IDs
                defer.returnValue(True)

            for prev_id in prev_ids:
                if prev_id == extremity:
                    continue
                elif prev_id in seen_updates:
                    continue
                elif prev_id in stream_id_in_updates:
                    continue
                else:
                    defer.returnValue(True)

            stream_id_in_updates.add(stream_id)

        defer.returnValue(False)
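
The gap detection in _need_to_do_resync boils down to a small piece of pure set logic. The sketch below isolates it as a standalone function; the signature and the (stream_id, prev_ids) tuples are simplified for illustration and are not the real method's interface.

def need_resync(extremity, seen_updates, updates):
    """Return True unless every update's prev_ids are already accounted for."""
    stream_ids_in_batch = set()
    for stream_id, prev_ids in updates:
        if not prev_ids:
            # No previous ids at all: we cannot prove continuity, so resync.
            return True
        for prev_id in prev_ids:
            # A prev_id is accounted for if it is the stored extremity, a
            # recently seen stream id, or an earlier update in this batch.
            if prev_id != extremity and prev_id not in seen_updates \
                    and prev_id not in stream_ids_in_batch:
                return True
        stream_ids_in_batch.add(stream_id)
    return False

# A contiguous batch needs no resync; a gap does.
assert need_resync("5", set(), [("6", ["5"]), ("7", ["6"])]) is False
assert need_resync("5", set(), [("8", ["7"])]) is True
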
Example #56
0
class FederationClient(FederationBase):

    def start_get_pdu_cache(self):
        self._get_pdu_cache = ExpiringCache(
            cache_name="get_pdu_cache",
            clock=self._clock,
            max_len=1000,
            expiry_ms=120 * 1000,
            reset_expiry_on_get=False,
        )

        self._get_pdu_cache.start()

    @log_function
    def send_pdu(self, pdu, destinations):
        """Informs the replication layer about a new PDU generated within the
        home server that should be transmitted to others.

        TODO: Figure out when we should actually resolve the deferred.

        Args:
            pdu (Pdu): The new Pdu.

        Returns:
            Deferred: Completes when we have successfully processed the PDU
            and replicated it to any interested remote home servers.
        """
        order = self._order
        self._order += 1

        sent_pdus_destination_dist.inc_by(len(destinations))

        logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id)

        # TODO, add errback, etc.
        self._transaction_queue.enqueue_pdu(pdu, destinations, order)

        logger.debug(
            "[%s] transaction_layer.enqueue_pdu... done",
            pdu.event_id
        )

    @log_function
    def send_edu(self, destination, edu_type, content):
        edu = Edu(
            origin=self.server_name,
            destination=destination,
            edu_type=edu_type,
            content=content,
        )

        sent_edus_counter.inc()

        # TODO, add errback, etc.
        self._transaction_queue.enqueue_edu(edu)
        return defer.succeed(None)

    @log_function
    def send_failure(self, failure, destination):
        self._transaction_queue.enqueue_failure(failure, destination)
        return defer.succeed(None)

    @log_function
    def make_query(self, destination, query_type, args,
                   retry_on_dns_fail=False):
        """Sends a federation Query to a remote homeserver of the given type
        and arguments.

        Args:
            destination (str): Domain name of the remote homeserver
            query_type (str): Category of the query type; should match the
                handler name used in register_query_handler().
            args (dict): Mapping of strings to strings containing the details
                of the query request.

        Returns:
            a Deferred which will eventually yield a JSON object from the
            response
        """
        sent_queries_counter.inc(query_type)

        return self.transport_layer.make_query(
            destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail
        )

    @log_function
    def query_client_keys(self, destination, content):
        """Query device keys for a device hosted on a remote server.

        Args:
            destination (str): Domain name of the remote homeserver
            content (dict): The query content.

        Returns:
            a Deferred which will eventually yield a JSON object from the
            response
        """
        sent_queries_counter.inc("client_device_keys")
        return self.transport_layer.query_client_keys(destination, content)

    @log_function
    def claim_client_keys(self, destination, content):
        """Claims one-time keys for a device hosted on a remote server.

        Args:
            destination (str): Domain name of the remote homeserver
            content (dict): The query content.

        Returns:
            a Deferred which will eventually yield a JSON object from the
            response
        """
        sent_queries_counter.inc("client_one_time_keys")
        return self.transport_layer.claim_client_keys(destination, content)

    @defer.inlineCallbacks
    @log_function
    def backfill(self, dest, context, limit, extremities):
        """Requests some more historic PDUs for the given context from the
        given destination server.

        Args:
            dest (str): The remote home server to ask.
            context (str): The context to backfill.
            limit (int): The maximum number of PDUs to return.
            extremities (list): List of PDU id and origins of the first pdus
                we have seen from the context

        Returns:
            Deferred: Results in the received PDUs.
        """
        logger.debug("backfill extrem=%s", extremities)

        # If there are no extremities then we've (probably) reached the start.
        if not extremities:
            return

        transaction_data = yield self.transport_layer.backfill(
            dest, context, extremities, limit)

        logger.debug("backfill transaction_data=%s", repr(transaction_data))

        pdus = [
            self.event_from_pdu_json(p, outlier=False)
            for p in transaction_data["pdus"]
        ]

        # FIXME: We should handle signature failures more gracefully.
        pdus[:] = yield defer.gatherResults(
            self._check_sigs_and_hashes(pdus),
            consumeErrors=True,
        ).addErrback(unwrapFirstError)

        defer.returnValue(pdus)

    @defer.inlineCallbacks
    @log_function
    def get_pdu(self, destinations, event_id, outlier=False, timeout=None):
        """Requests the PDU with given origin and ID from the remote home
        servers.

        Will attempt to get the PDU from each destination in the list until
        one succeeds.

        This will persist the PDU locally upon receipt.

        Args:
            destinations (list): Which home servers to query
            event_id (str): The ID of the event to fetch.
            outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
                it's from an arbitrary point in the context as opposed to part
                of the current block of PDUs. Defaults to `False`
            timeout (int): How long to try (in ms) each destination for before
                moving to the next destination. None indicates no timeout.

        Returns:
            Deferred: Results in the requested PDU.
        """

        # TODO: Rate limit the number of times we try and get the same event.

        if self._get_pdu_cache:
            e = self._get_pdu_cache.get(event_id)
            if e:
                defer.returnValue(e)

        pdu = None
        for destination in destinations:
            try:
                limiter = yield get_retry_limiter(
                    destination,
                    self._clock,
                    self.store,
                )

                with limiter:
                    transaction_data = yield self.transport_layer.get_event(
                        destination, event_id, timeout=timeout,
                    )

                    logger.debug("transaction_data %r", transaction_data)

                    pdu_list = [
                        self.event_from_pdu_json(p, outlier=outlier)
                        for p in transaction_data["pdus"]
                    ]

                    if pdu_list and pdu_list[0]:
                        pdu = pdu_list[0]

                        # Check signatures are correct.
                        pdu = yield self._check_sigs_and_hashes([pdu])[0]

                        break

            except SynapseError as e:
                logger.info(
                    "Failed to get PDU %s from %s because %s",
                    event_id, destination, e,
                )
                continue
            except CodeMessageException as e:
                if 400 <= e.code < 500:
                    raise

                logger.info(
                    "Failed to get PDU %s from %s because %s",
                    event_id, destination, e,
                )
                continue
            except NotRetryingDestination as e:
                logger.info(e.message)
                continue
            except Exception as e:
                logger.info(
                    "Failed to get PDU %s from %s because %s",
                    event_id, destination, e,
                )
                continue

        if self._get_pdu_cache is not None and pdu:
            self._get_pdu_cache[event_id] = pdu

        defer.returnValue(pdu)

    @defer.inlineCallbacks
    @log_function
    def get_state_for_room(self, destination, room_id, event_id):
        """Requests all of the `current` state PDUs for a given room from
        a remote home server.

        Args:
            destination (str): The remote homeserver to query for the state.
            room_id (str): The id of the room we're interested in.
            event_id (str): The id of the event we want the state at.

        Returns:
            Deferred: Results in a list of PDUs.
        """

        result = yield self.transport_layer.get_room_state(
            destination, room_id, event_id=event_id,
        )

        pdus = [
            self.event_from_pdu_json(p, outlier=True) for p in result["pdus"]
        ]

        auth_chain = [
            self.event_from_pdu_json(p, outlier=True)
            for p in result.get("auth_chain", [])
        ]

        signed_pdus = yield self._check_sigs_and_hash_and_fetch(
            destination, pdus, outlier=True
        )

        signed_auth = yield self._check_sigs_and_hash_and_fetch(
            destination, auth_chain, outlier=True
        )

        signed_auth.sort(key=lambda e: e.depth)

        defer.returnValue((signed_pdus, signed_auth))

    @defer.inlineCallbacks
    @log_function
    def get_event_auth(self, destination, room_id, event_id):
        res = yield self.transport_layer.get_event_auth(
            destination, room_id, event_id,
        )

        auth_chain = [
            self.event_from_pdu_json(p, outlier=True)
            for p in res["auth_chain"]
        ]

        signed_auth = yield self._check_sigs_and_hash_and_fetch(
            destination, auth_chain, outlier=True
        )

        signed_auth.sort(key=lambda e: e.depth)

        defer.returnValue(signed_auth)

    @defer.inlineCallbacks
    def make_membership_event(self, destinations, room_id, user_id, membership,
                              content={},):
        """
        Creates an m.room.member event, with context, without participating in the room.

        Does so by asking one of the already participating servers to create an
        event with proper context.

        Note that this does not append any events to any graphs.

        Args:
            destinations (list): Candidate homeservers which are probably
                participating in the room.
            room_id (str): The room in which the event will happen.
            user_id (str): The user whose membership is being evented.
            membership (str): The "membership" property of the event. Must be
                one of "join" or "leave".
            content (object): Any additional data to put into the content field
                of the event.
        Return:
            A tuple of (origin (str), event (object)) where origin is the remote
            homeserver which generated the event.
        """
        valid_memberships = {Membership.JOIN, Membership.LEAVE}
        if membership not in valid_memberships:
            raise RuntimeError(
                "make_membership_event called with membership='%s', must be one of %s" %
                (membership, ",".join(valid_memberships))
            )
        for destination in destinations:
            if destination == self.server_name:
                continue

            try:
                ret = yield self.transport_layer.make_membership_event(
                    destination, room_id, user_id, membership
                )

                pdu_dict = ret["event"]

                logger.debug("Got response to make_%s: %s", membership, pdu_dict)

                pdu_dict["content"].update(content)

                # The protoevent received over the JSON wire may not have all
                # the required fields. Let's just gloss over that, because
                # there are some we never care about.
                if "prev_state" not in pdu_dict:
                    pdu_dict["prev_state"] = []

                defer.returnValue(
                    (destination, self.event_from_pdu_json(pdu_dict))
                )
                break
            except CodeMessageException:
                raise
            except Exception as e:
                # Log and fall through to the next candidate destination; if
                # every destination fails we raise RuntimeError below.
                logger.warn(
                    "Failed to make_%s via %s: %s",
                    membership, destination, e.message
                )

        raise RuntimeError("Failed to send to any server.")

    @defer.inlineCallbacks
    def send_join(self, destinations, pdu):
        for destination in destinations:
            if destination == self.server_name:
                continue

            try:
                time_now = self._clock.time_msec()
                _, content = yield self.transport_layer.send_join(
                    destination=destination,
                    room_id=pdu.room_id,
                    event_id=pdu.event_id,
                    content=pdu.get_pdu_json(time_now),
                )

                logger.debug("Got content: %s", content)

                state = [
                    self.event_from_pdu_json(p, outlier=True)
                    for p in content.get("state", [])
                ]

                auth_chain = [
                    self.event_from_pdu_json(p, outlier=True)
                    for p in content.get("auth_chain", [])
                ]

                pdus = {
                    p.event_id: p
                    for p in itertools.chain(state, auth_chain)
                }

                valid_pdus = yield self._check_sigs_and_hash_and_fetch(
                    destination, pdus.values(),
                    outlier=True,
                )

                valid_pdus_map = {
                    p.event_id: p
                    for p in valid_pdus
                }

                # NB: We *need* to copy to ensure that we don't have multiple
                # references being passed on, as that causes... issues.
                signed_state = [
                    copy.copy(valid_pdus_map[p.event_id])
                    for p in state
                    if p.event_id in valid_pdus_map
                ]

                signed_auth = [
                    valid_pdus_map[p.event_id]
                    for p in auth_chain
                    if p.event_id in valid_pdus_map
                ]

                # NB: We *need* to copy to ensure that we don't have multiple
                # references being passed on, as that causes... issues.
                for s in signed_state:
                    s.internal_metadata = copy.deepcopy(s.internal_metadata)

                auth_chain.sort(key=lambda e: e.depth)

                defer.returnValue({
                    "state": signed_state,
                    "auth_chain": signed_auth,
                    "origin": destination,
                })
            except CodeMessageException:
                raise
            except Exception as e:
                logger.exception(
                    "Failed to send_join via %s: %s",
                    destination, e.message
                )

        raise RuntimeError("Failed to send to any server.")

    @defer.inlineCallbacks
    def send_invite(self, destination, room_id, event_id, pdu):
        time_now = self._clock.time_msec()
        code, content = yield self.transport_layer.send_invite(
            destination=destination,
            room_id=room_id,
            event_id=event_id,
            content=pdu.get_pdu_json(time_now),
        )

        pdu_dict = content["event"]

        logger.debug("Got response to send_invite: %s", pdu_dict)

        pdu = self.event_from_pdu_json(pdu_dict)

        # Check signatures are correct.
        pdu = yield self._check_sigs_and_hash(pdu)

        # FIXME: We should handle signature failures more gracefully.

        defer.returnValue(pdu)

    @defer.inlineCallbacks
    def send_leave(self, destinations, pdu):
        for destination in destinations:
            if destination == self.server_name:
                continue

            try:
                time_now = self._clock.time_msec()
                _, content = yield self.transport_layer.send_leave(
                    destination=destination,
                    room_id=pdu.room_id,
                    event_id=pdu.event_id,
                    content=pdu.get_pdu_json(time_now),
                )

                logger.debug("Got content: %s", content)
                defer.returnValue(None)
            except CodeMessageException:
                raise
            except Exception as e:
                logger.exception(
                    "Failed to send_leave via %s: %s",
                    destination, e.message
                )

        raise RuntimeError("Failed to send to any server.")

    @defer.inlineCallbacks
    def query_auth(self, destination, room_id, event_id, local_auth):
        """
        Params:
            destination (str)
            room_id (str)
            event_id (str)
            local_auth (list)
        """
        time_now = self._clock.time_msec()

        send_content = {
            "auth_chain": [e.get_pdu_json(time_now) for e in local_auth],
        }

        code, content = yield self.transport_layer.send_query_auth(
            destination=destination,
            room_id=room_id,
            event_id=event_id,
            content=send_content,
        )

        auth_chain = [
            self.event_from_pdu_json(e)
            for e in content["auth_chain"]
        ]

        signed_auth = yield self._check_sigs_and_hash_and_fetch(
            destination, auth_chain, outlier=True
        )

        signed_auth.sort(key=lambda e: e.depth)

        ret = {
            "auth_chain": signed_auth,
            "rejects": content.get("rejects", []),
            "missing": content.get("missing", []),
        }

        defer.returnValue(ret)

    @defer.inlineCallbacks
    def get_missing_events(self, destination, room_id, earliest_events_ids,
                           latest_events, limit, min_depth):
        """Tries to fetch events we are missing. This is called when we receive
        an event without having received all of its ancestors.

        Args:
            destination (str)
            room_id (str)
            earliest_events_ids (list): List of event ids. Effectively the
                events we expected to receive, but haven't. `get_missing_events`
                should only return events that didn't happen before these.
            latest_events (list): List of events we have received that we don't
                have all previous events for.
            limit (int): Maximum number of events to return.
            min_depth (int): Minimum depth of events to return.
        """
        try:
            content = yield self.transport_layer.get_missing_events(
                destination=destination,
                room_id=room_id,
                earliest_events=earliest_events_ids,
                latest_events=[e.event_id for e in latest_events],
                limit=limit,
                min_depth=min_depth,
            )

            events = [
                self.event_from_pdu_json(e)
                for e in content.get("events", [])
            ]

            signed_events = yield self._check_sigs_and_hash_and_fetch(
                destination, events, outlier=False
            )

            have_gotten_all_from_destination = True
        except HttpResponseException as e:
            if e.code != 400:
                raise

            # We are probably hitting an old server that doesn't support
            # get_missing_events
            signed_events = []
            have_gotten_all_from_destination = False

        if len(signed_events) >= limit:
            defer.returnValue(signed_events)

        servers = yield self.store.get_joined_hosts_for_room(room_id)

        servers = set(servers)
        servers.discard(self.server_name)

        failed_to_fetch = set()

        while len(signed_events) < limit:
            # Are we missing any?

            seen_events = set(earliest_events_ids)
            seen_events.update(e.event_id for e in signed_events if e)

            missing_events = {}
            for e in itertools.chain(latest_events, signed_events):
                if e.depth > min_depth:
                    missing_events.update({
                        e_id: e.depth for e_id, _ in e.prev_events
                        if e_id not in seen_events
                        and e_id not in failed_to_fetch
                    })

            if not missing_events:
                break

            have_seen = yield self.store.have_events(missing_events)

            for k in have_seen:
                missing_events.pop(k, None)

            if not missing_events:
                break

            # Okay, we haven't gotten everything yet. Lets get them.
            ordered_missing = sorted(missing_events.items(), key=lambda x: x[0])

            if have_gotten_all_from_destination:
                servers.discard(destination)

            def random_server_list():
                srvs = list(servers)
                random.shuffle(srvs)
                return srvs

            deferreds = [
                self.get_pdu(
                    destinations=random_server_list(),
                    event_id=e_id,
                )
                for e_id, depth in ordered_missing[:limit - len(signed_events)]
            ]

            res = yield defer.DeferredList(deferreds, consumeErrors=True)
            for (result, val), (e_id, _) in zip(res, ordered_missing):
                if result and val:
                    signed_events.append(val)
                else:
                    failed_to_fetch.add(e_id)

        defer.returnValue(signed_events)

    def event_from_pdu_json(self, pdu_json, outlier=False):
        event = FrozenEvent(
            pdu_json
        )

        event.internal_metadata.outlier = outlier

        return event

    @defer.inlineCallbacks
    def forward_third_party_invite(self, destinations, room_id, event_dict):
        for destination in destinations:
            if destination == self.server_name:
                continue

            try:
                yield self.transport_layer.exchange_third_party_invite(
                    destination=destination,
                    room_id=room_id,
                    event_dict=event_dict,
                )
                defer.returnValue(None)
            except CodeMessageException:
                raise
            except Exception as e:
                logger.exception(
                    "Failed to send_third_party_invite via %s: %s",
                    destination, e.message
                )

        raise RuntimeError("Failed to send to any server.")
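
The inner loop of get_missing_events (here and in example #53) keeps recomputing which prev_events are still unaccounted for. The following standalone sketch isolates that bookkeeping; SimpleEvent is a hypothetical stand-in for FrozenEvent carrying only the fields the loop reads, so the sketch runs on its own.

from collections import namedtuple

# Hypothetical minimal event: id, depth, and (event_id, hashes) prev links.
SimpleEvent = namedtuple("SimpleEvent", ["event_id", "depth", "prev_events"])

def find_missing(latest_events, fetched_events, earliest_event_ids,
                 failed_to_fetch, min_depth):
    """Map each still-missing prev_event id to the depth of the event citing it."""
    seen = set(earliest_event_ids)
    seen.update(e.event_id for e in fetched_events if e)

    missing = {}
    for e in list(latest_events) + list(fetched_events):
        if e.depth > min_depth:
            missing.update({
                prev_id: e.depth
                for prev_id, _ in e.prev_events
                if prev_id not in seen and prev_id not in failed_to_fetch
            })
    return missing

tip = SimpleEvent("$tip", depth=10, prev_events=[("$a", {}), ("$b", {})])
a = SimpleEvent("$a", depth=9, prev_events=[("$root", {})])

# We already hold $a and know about $root, so only $b is still missing.
assert find_missing([tip], [a], {"$root"}, set(), min_depth=1) == {"$b": 10}
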