def __init__(self, hs):
    """Wire up the event-creation machinery from the homeserver.

    Args:
        hs (synapse.server.HomeServer): homeserver this handler belongs
            to; every collaborator below is pulled off it.
    """
    self.hs = hs
    self.server_name = hs.hostname
    self.config = hs.config

    self.auth = hs.get_auth()
    self.store = hs.get_datastore()
    self.state = hs.get_state_handler()
    self.clock = hs.get_clock()
    self.validator = EventValidator()
    self.profile_handler = hs.get_profile_handler()
    self.event_builder_factory = hs.get_event_builder_factory()
    self.ratelimiter = hs.get_ratelimiter()
    self.notifier = hs.get_notifier()
    self.http_client = hs.get_simple_http_client()

    # This is only used to get at ratelimit function, and
    # maybe_kick_guest_users
    self.base_handler = BaseHandler(hs)

    self.pusher_pool = hs.get_pusherpool()

    # We arbitrarily limit concurrent event creation for a room to 5.
    # This is to stop us from diverging history *too* much.
    self.limiter = Linearizer(max_count=5, name="room_event_creation_limit")

    self.action_generator = hs.get_action_generator()
    self.spam_checker = hs.get_spam_checker()

    # A consent-URI builder is only needed when events from un-consented
    # users are configured to be blocked.
    if self.config.block_events_without_consent_error is not None:
        self._consent_uri_builder = ConsentURIBuilder(self.config)
def __init__(self, hs):
    """Media repository setup: service wiring, a snapshot of the media
    configuration, and a periodic job tracking recently accessed remote
    media.

    Args:
        hs (synapse.server.HomeServer)
    """
    self.server_name = hs.hostname

    self.auth = hs.get_auth()
    self.client = MatrixFederationHttpClient(hs)
    self.clock = hs.get_clock()
    self.store = hs.get_datastore()

    # Snapshot of media-related configuration.
    self.max_upload_size = hs.config.max_upload_size
    self.max_image_pixels = hs.config.max_image_pixels
    self.primary_base_path = hs.config.media_store_path
    # Path helpers are rooted at the primary media store.
    self.filepaths = MediaFilePaths(self.primary_base_path)
    self.backup_base_path = hs.config.backup_media_store_path
    self.synchronous_backup_media_store = hs.config.synchronous_backup_media_store
    self.dynamic_thumbnails = hs.config.dynamic_thumbnails
    self.thumbnail_requirements = hs.config.thumbnail_requirements

    # Linearizer used by the remote-media code paths (see call sites).
    self.remote_media_linearizer = Linearizer(name="media_remote")

    self.recently_accessed_remotes = set()

    # Periodically process the recently-accessed bookkeeping.
    self.clock.looping_call(
        self._update_recently_accessed_remotes,
        UPDATE_RECENTLY_ACCESSED_REMOTES_TS,
    )
def test_cancellation(self):
    """A cancelled waiter must not block later entries in the queue."""
    linearizer = Linearizer()

    key = object()

    # First entry acquires the linearizer straight away.
    first = linearizer.queue(key)
    first_cm = yield first

    # Two more entries queue up behind it; neither may fire yet.
    second = linearizer.queue(key)
    self.assertFalse(second.called)
    third = linearizer.queue(key)
    self.assertFalse(third.called)

    # Cancel the second entry while it is still waiting.
    second.cancel()

    # Release the first entry.
    with first_cm:
        pass

    # The cancelled deferred has now resolved, with CancelledError ...
    self.assertTrue(second.called)
    try:
        yield second
        self.fail("Expected d2 to raise CancelledError")
    except CancelledError:
        pass

    # ... and the third entry can still acquire and release normally.
    with (yield third):
        pass
def __init__(self, hs):
    """
    Args:
        hs (synapse.server.HomeServer)
    """
    self.clock = hs.get_clock()
    self.store = hs.get_datastore()
    self.hs = hs

    # dict of set of event_ids -> _StateCacheEntry.
    self._state_cache = None
    # Give the linearizer a name so it is identifiable in metrics/log
    # output, consistent with the named linearizers used elsewhere.
    self.resolve_linearizer = Linearizer(name="state_resolve_lock")
def __init__(self, hs):
    """
    Args:
        hs (synapse.server.HomeServer)
    """
    super(FederationServer, self).__init__(hs)

    # Serialises incoming federation work per origin server.
    # (Name is used for identification, judging by usage elsewhere.)
    self._server_linearizer = Linearizer("fed_server")

    # We cache responses to state queries, as they take a while and often
    # come in waves.
    self._state_resp_cache = ResponseCache(hs, timeout_ms=30000)

    self.auth = hs.get_auth()
def __init__(self, hs):
    """
    Args:
        hs (synapse.server.HomeServer)
    """
    super(RoomMemberHandler, self).__init__(hs)

    # Named for metrics/debugging, consistent with the other
    # RoomMemberHandler variants which use Linearizer(name="member").
    self.member_linearizer = Linearizer(name="member")

    self.clock = hs.get_clock()

    self.distributor = hs.get_distributor()
    self.distributor.declare("user_joined_room")
    self.distributor.declare("user_left_room")
def __init__(self, store, room_id):
    """Per-room cache of hosts-to-joined-users state.

    Args:
        store: the datastore
        room_id (str): the room this cache instance covers
    """
    self.store = store
    self.room_id = room_id

    # presumably host -> joined users for that host; confirm at call
    # sites (populated outside this constructor).
    self.hosts_to_joined_users = {}

    # A fresh object() never compares equal to a real state group, so
    # the cache starts out invalid.
    self.state_group = object()

    self._len = 0

    self.linearizer = Linearizer("_JoinedHostsCache")
def __init__(self, hs):
    """
    Args:
        hs (synapse.server.HomeServer)
    """
    super(FederationServer, self).__init__(hs)

    self.auth = hs.get_auth()

    # Name the linearizers so they are distinguishable in metrics and
    # log output; "fed_server" matches the name used by the other
    # FederationServer variant in this codebase.
    self._room_pdu_linearizer = Linearizer("fed_room_pdu")
    self._server_linearizer = Linearizer("fed_server")

    # We cache responses to state queries, as they take a while and often
    # come in waves.
    self._state_resp_cache = ResponseCache(hs, timeout_ms=30000)
def __init__(self, hs, replication_client):
    """
    Args:
        hs (synapse.server.HomeServer)
        replication_client: client used to ACK federation stream
            positions back to the master.
    """
    self.store = hs.get_datastore()
    self.federation_sender = hs.get_federation_sender()
    self.replication_client = replication_client

    # Stream position we have processed up to, and the last position we
    # ACKed; both start from the persisted startup position.
    self.federation_position = self.store.federation_out_pos_startup
    self._last_ack = self.federation_position

    self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")

    self._room_serials = {}
    self._room_typing = {}
def __init__(self, hs):
    """
    Args:
        hs (synapse.server.HomeServer)
    """
    super(RoomMemberHandler, self).__init__(hs)

    self.profile_handler = hs.get_profile_handler()

    # Serialises membership changes; see users of this attribute.
    self.member_linearizer = Linearizer(name="member")

    self.clock = hs.get_clock()
    self.spam_checker = hs.get_spam_checker()

    # Signals observed elsewhere when users join/leave rooms.
    self.distributor = hs.get_distributor()
    self.distributor.declare("user_joined_room")
    self.distributor.declare("user_left_room")
def __init__(self, hs):
    """
    Args:
        hs (synapse.server.HomeServer)
    """
    super(RegistrationHandler, self).__init__(hs)

    self.auth = hs.get_auth()
    self._auth_handler = hs.get_auth_handler()
    self.profile_handler = hs.get_profile_handler()
    self.user_directory_handler = hs.get_user_directory_handler()
    self.captcha_client = CaptchaServerHttpClient(hs)
    self.macaroon_gen = hs.get_macaroon_generator()

    # Next locally generated numeric user id; None until first used.
    self._next_generated_user_id = None

    self._generate_user_id_linearizer = Linearizer(
        name="_generate_user_id_linearizer",
    )
def __init__(self, hs):
    """
    Args:
        hs (synapse.server.HomeServer)
    """
    self.hs = hs
    self.server_name = hs.hostname

    self.auth = hs.get_auth()
    self.client = MatrixFederationHttpClient(hs)
    self.clock = hs.get_clock()
    self.store = hs.get_datastore()

    # Snapshot of media-related configuration.
    self.max_upload_size = hs.config.max_upload_size
    self.max_image_pixels = hs.config.max_image_pixels
    self.primary_base_path = hs.config.media_store_path
    self.filepaths = MediaFilePaths(self.primary_base_path)
    self.dynamic_thumbnails = hs.config.dynamic_thumbnails
    self.thumbnail_requirements = hs.config.thumbnail_requirements
    self.federation_domain_whitelist = hs.config.federation_domain_whitelist

    self.remote_media_linearizer = Linearizer(name="media_remote")

    self.recently_accessed_remotes = set()
    self.recently_accessed_locals = set()

    # Wrap each configured storage backend; media may be searched for in,
    # and potentially uploaded to, any of these providers.
    storage_providers = [
        StorageProviderWrapper(
            clz(hs, provider_config),
            store_local=wrapper_config.store_local,
            store_remote=wrapper_config.store_remote,
            store_synchronous=wrapper_config.store_synchronous,
        )
        for clz, provider_config, wrapper_config
        in hs.config.media_storage_providers
    ]

    self.media_storage = MediaStorage(
        self.hs, self.primary_base_path, self.filepaths, storage_providers,
    )

    self.clock.looping_call(
        self._update_recently_accessed,
        UPDATE_RECENTLY_ACCESSED_TS,
    )
def test_linearizer(self):
    """Entries queued on the same key run strictly one after another."""
    linearizer = Linearizer()

    key = object()

    # The first entry acquires immediately.
    first = linearizer.queue(key)
    first_cm = yield first

    # The second entry must wait behind the first ...
    second = linearizer.queue(key)
    self.assertFalse(second.called)

    with first_cm:
        # ... even while the first holder is inside its context.
        self.assertFalse(second.called)

    # Once the first releases, the second can acquire and release.
    with (yield second):
        pass
def __init__(self, hs, room_id, rules_for_room_cache, room_push_rule_cache_metrics):
    """
    Args:
        hs (HomeServer)
        room_id (str)
        rules_for_room_cache(Cache): The cache object that caches these
            RoomsForUser objects.
        room_push_rule_cache_metrics (CacheMetric)
    """
    self.room_id = room_id
    self.is_mine_id = hs.is_mine_id
    self.store = hs.get_datastore()
    self.room_push_rule_cache_metrics = room_push_rule_cache_metrics

    self.linearizer = Linearizer(name="rules_for_room")

    # event_id -> (user_id, state)
    self.member_map = {}
    # user_id -> rules
    self.rules_by_user = {}

    # The last state group we updated the caches for. If the state_group of
    # a new event comes along, we know that we can just return the cached
    # result.
    # On invalidation of the rules themselves (if the user changes them),
    # we invalidate everything and set state_group to `object()`
    self.state_group = object()

    # A sequence number to keep track of when we're allowed to update the
    # cache. We bump the sequence number when we invalidate the cache. If
    # the sequence number changes while we're calculating stuff we should
    # not update the cache with it.
    self.sequence = 0

    # A cache of user_ids that we *know* aren't interesting, e.g. user_ids
    # owned by AS's, or remote users, etc. (I.e. users we will never need to
    # calculate push for)
    # These never need to be invalidated as we will never set up push for
    # them.
    self.uninteresting_user_set = set()

    # We need to be clever on the invalidating caches callbacks, as
    # otherwise the invalidation callback holds a reference to the object,
    # potentially causing it to leak.
    # To get around this we pass a function that on invalidations looks ups
    # the RoomsForUser entry in the cache, rather than keeping a reference
    # to self around in the callback.
    self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id)
class FederationSenderHandler(object):
    """Processes the replication stream and forwards the appropriate entries
    to the federation sender.
    """
    def __init__(self, hs, replication_client):
        self.store = hs.get_datastore()
        self.federation_sender = hs.get_federation_sender()
        self.replication_client = replication_client

        # Federation stream position we have processed up to.
        self.federation_position = self.store.federation_out_pos_startup
        self._fed_position_linearizer = Linearizer(
            name="_fed_position_linearizer")

        # Last position we persisted/ACKed, to avoid redundant updates.
        self._last_ack = self.federation_position

        self._room_serials = {}
        self._room_typing = {}

    def on_start(self):
        # There may be some events that are persisted but haven't been sent,
        # so send them now.
        self.federation_sender.notify_new_events(
            self.store.get_room_max_stream_ordering())

    def stream_positions(self):
        # Report how far along the federation stream we are.
        return {"federation": self.federation_position}

    def process_replication_rows(self, stream_name, token, rows):
        # The federation stream contains things that we want to send out, e.g.
        # presence, typing, etc.
        if stream_name == "federation":
            send_queue.process_rows_for_federation(self.federation_sender, rows)
            # Persist/ACK the new token off the critical path.
            run_in_background(self.update_token, token)

        # We also need to poke the federation sender when new events happen
        elif stream_name == "events":
            self.federation_sender.notify_new_events(token)

    @defer.inlineCallbacks
    def update_token(self, token):
        """Persist and ACK a new federation stream position."""
        try:
            self.federation_position = token

            # We linearize here to ensure we don't have races updating the token
            with (yield self._fed_position_linearizer.queue(None)):
                if self._last_ack < self.federation_position:
                    yield self.store.update_federation_out_pos(
                        "federation", self.federation_position)

                    # We ACK this token over replication so that the master can drop
                    # its in memory queues
                    self.replication_client.send_federation_ack(
                        self.federation_position)
                    self._last_ack = self.federation_position
        except Exception:
            logger.exception("Error updating federation stream position")
def __init__(self, hs):
    """
    Args:
        hs (synapse.server.HomeServer)
    """
    self.hs = hs
    self.config = hs.config

    self.store = hs.get_datastore()
    self.auth = hs.get_auth()
    self.state_handler = hs.get_state_handler()
    self.simple_http_client = hs.get_simple_http_client()

    self.federation_handler = hs.get_handlers().federation_handler
    self.directory_handler = hs.get_handlers().directory_handler
    self.registration_handler = hs.get_handlers().registration_handler
    self.profile_handler = hs.get_profile_handler()
    # NOTE(review): "hander" looks like a typo for "handler", but the
    # attribute name is kept since other code may reference it.
    self.event_creation_hander = hs.get_event_creation_handler()

    # Serialises membership changes; see users of this attribute.
    self.member_linearizer = Linearizer(name="member")

    self.clock = hs.get_clock()
    self.spam_checker = hs.get_spam_checker()
class FederationSenderHandler(object):
    """Processes the replication stream and forwards the appropriate entries
    to the federation sender.
    """
    def __init__(self, hs, replication_client):
        self.store = hs.get_datastore()
        self.federation_sender = hs.get_federation_sender()
        self.replication_client = replication_client

        # Federation stream position we have processed up to.
        self.federation_position = self.store.federation_out_pos_startup
        self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")

        # Last position we persisted/ACKed, to avoid redundant updates.
        self._last_ack = self.federation_position

        self._room_serials = {}
        self._room_typing = {}

    def on_start(self):
        # There may be some events that are persisted but haven't been sent,
        # so send them now.
        self.federation_sender.notify_new_events(
            self.store.get_room_max_stream_ordering()
        )

    def stream_positions(self):
        # Report how far along the federation stream we are.
        return {"federation": self.federation_position}

    def process_replication_rows(self, stream_name, token, rows):
        # The federation stream contains things that we want to send out, e.g.
        # presence, typing, etc.
        if stream_name == "federation":
            send_queue.process_rows_for_federation(self.federation_sender, rows)
            # Persist/ACK the new token off the critical path.
            run_in_background(self.update_token, token)

        # We also need to poke the federation sender when new events happen
        elif stream_name == "events":
            self.federation_sender.notify_new_events(token)

    @defer.inlineCallbacks
    def update_token(self, token):
        """Persist and ACK a new federation stream position."""
        try:
            self.federation_position = token

            # We linearize here to ensure we don't have races updating the token
            with (yield self._fed_position_linearizer.queue(None)):
                if self._last_ack < self.federation_position:
                    yield self.store.update_federation_out_pos(
                        "federation", self.federation_position
                    )

                    # We ACK this token over replication so that the master can drop
                    # its in memory queues
                    self.replication_client.send_federation_ack(self.federation_position)
                    self._last_ack = self.federation_position
        except Exception:
            logger.exception("Error updating federation stream position")
def __init__(self, hs):
    """
    Args:
        hs (synapse.server.HomeServer):
    """
    super(RegistrationHandler, self).__init__(hs)
    self.hs = hs

    self.auth = hs.get_auth()
    self._auth_handler = hs.get_auth_handler()
    self.profile_handler = hs.get_profile_handler()
    self.user_directory_handler = hs.get_user_directory_handler()
    self.captcha_client = CaptchaServerHttpClient(hs)
    self.macaroon_gen = hs.get_macaroon_generator()

    # Next locally generated numeric user id; None until first used.
    self._next_generated_user_id = None

    self._generate_user_id_linearizer = Linearizer(
        name="_generate_user_id_linearizer",
    )

    self._server_notices_mxid = hs.config.server_notices_mxid
def __init__(self, hs, device_handler):
    """
    Args:
        hs (synapse.server.HomeServer)
        device_handler: handler notified when remote device lists change
    """
    self.device_handler = device_handler

    self.store = hs.get_datastore()
    self.federation = hs.get_federation_client()
    self.clock = hs.get_clock()

    self._remote_edu_linearizer = Linearizer(name="remote_device_list")

    # user_id -> list of updates waiting to be handled.
    self._pending_updates = {}

    # Recently seen stream ids. We don't bother keeping these in the DB,
    # but they're useful to have them about to reduce the number of spurious
    # resyncs.
    self._seen_updates = ExpiringCache(
        cache_name="device_update_edu",
        clock=self.clock,
        max_len=10000,
        expiry_ms=30 * 60 * 1000,
        iterable=True,
    )
def __init__(self, hs):
    """
    Args:
        hs (synapse.server.HomeServer):
    """
    self.hs = hs
    self.config = hs.config
    self._server_notices_mxid = self.config.server_notices_mxid

    self.store = hs.get_datastore()
    self.auth = hs.get_auth()
    self.state_handler = hs.get_state_handler()
    self.simple_http_client = hs.get_simple_http_client()

    self.federation_handler = hs.get_handlers().federation_handler
    self.directory_handler = hs.get_handlers().directory_handler
    self.registration_handler = hs.get_handlers().registration_handler
    self.profile_handler = hs.get_profile_handler()
    # NOTE(review): "hander" looks like a typo for "handler", but the
    # attribute name is kept since other code may reference it.
    self.event_creation_hander = hs.get_event_creation_handler()

    # Serialises membership changes; see users of this attribute.
    self.member_linearizer = Linearizer(name="member")

    self.clock = hs.get_clock()
    self.spam_checker = hs.get_spam_checker()
def test_multiple_entries(self):
    # With max_count=3, up to three entries may hold the same key at once.
    limiter = Linearizer(max_count=3)

    key = object()

    # The first three entries acquire immediately.
    d1 = limiter.queue(key)
    cm1 = yield d1
    d2 = limiter.queue(key)
    cm2 = yield d2
    d3 = limiter.queue(key)
    cm3 = yield d3

    # The fourth and fifth entries must wait for a free slot.
    d4 = limiter.queue(key)
    self.assertFalse(d4.called)
    d5 = limiter.queue(key)
    self.assertFalse(d5.called)

    # While cm1 is still held, no waiter may proceed; cm1's slot is only
    # freed when the `with` block exits.
    with cm1:
        self.assertFalse(d4.called)
        self.assertFalse(d5.called)

    # cm1 released one slot, so exactly one waiter (d4) can acquire.
    cm4 = yield d4
    self.assertFalse(d5.called)

    # Releasing cm3 frees another slot for d5.
    with cm3:
        self.assertFalse(d5.called)

    cm5 = yield d5

    # Drain the remaining holders.
    with cm2:
        pass
    with cm4:
        pass
    with cm5:
        pass

    # With everything released, a fresh entry acquires immediately.
    d6 = limiter.queue(key)
    with (yield d6):
        pass
def __init__(self, hs):
    """
    Args:
        hs (synapse.server.HomeServer)
    """
    self.server_name = hs.hostname

    self.auth = hs.get_auth()
    self.client = MatrixFederationHttpClient(hs)
    self.clock = hs.get_clock()
    self.store = hs.get_datastore()

    # Snapshot of media-related configuration.
    self.max_upload_size = hs.config.max_upload_size
    self.max_image_pixels = hs.config.max_image_pixels
    self.primary_base_path = hs.config.media_store_path
    self.filepaths = MediaFilePaths(self.primary_base_path)
    self.dynamic_thumbnails = hs.config.dynamic_thumbnails
    self.thumbnail_requirements = hs.config.thumbnail_requirements
    self.federation_domain_whitelist = hs.config.federation_domain_whitelist

    self.remote_media_linearizer = Linearizer(name="media_remote")

    self.recently_accessed_remotes = set()
    self.recently_accessed_locals = set()

    # Wrap each configured storage backend; media may be searched for in,
    # and potentially uploaded to, any of these providers.
    storage_providers = [
        StorageProviderWrapper(
            clz(hs, provider_config),
            store_local=wrapper_config.store_local,
            store_remote=wrapper_config.store_remote,
            store_synchronous=wrapper_config.store_synchronous,
        )
        for clz, provider_config, wrapper_config
        in hs.config.media_storage_providers
    ]

    self.media_storage = MediaStorage(
        self.primary_base_path, self.filepaths, storage_providers,
    )

    self.clock.looping_call(
        self._update_recently_accessed,
        UPDATE_RECENTLY_ACCESSED_TS,
    )
def __init__(self, hs):
    """
    Args:
        hs (synapse.server.HomeServer)
    """
    super(DeviceHandler, self).__init__(hs)

    self.hs = hs
    self.state = hs.get_state_handler()
    self.federation_sender = hs.get_federation_sender()
    self.federation = hs.get_replication_layer()

    # NOTE(review): "edue" looks like a typo for "edu"; left unchanged
    # because other methods (not visible here) may reference this
    # attribute name.
    self._remote_edue_linearizer = Linearizer(name="remote_device_list")

    # Route incoming federation device-list EDUs and device queries to
    # this handler's methods.
    self.federation.register_edu_handler(
        "m.device_list_update", self._incoming_device_list_update,
    )
    self.federation.register_query_handler(
        "user_devices", self.on_federation_query_user_devices,
    )

    hs.get_distributor().observe("user_left_room", self.user_left_room)
def __init__(self, hs):
    """
    Args:
        hs (synapse.server.HomeServer)
    """
    self.auth = hs.get_auth()
    self.client = MatrixFederationHttpClient(hs)
    self.clock = hs.get_clock()
    self.server_name = hs.hostname
    self.store = hs.get_datastore()

    # Snapshot of media-related configuration.
    self.max_upload_size = hs.config.max_upload_size
    self.max_image_pixels = hs.config.max_image_pixels
    self.filepaths = MediaFilePaths(hs.config.media_store_path)
    self.dynamic_thumbnails = hs.config.dynamic_thumbnails
    self.thumbnail_requirements = hs.config.thumbnail_requirements

    # Name the linearizer (matching the other media repository variants,
    # which use name="media_remote") so it is identifiable in
    # metrics/log output.
    self.remote_media_linearizer = Linearizer(name="media_remote")

    self.recently_accessed_remotes = set()

    # Periodically process the recently-accessed bookkeeping.
    self.clock.looping_call(
        self._update_recently_accessed_remotes,
        UPDATE_RECENTLY_ACCESSED_REMOTES_TS
    )
class ReadMarkerHandler(BaseHandler):
    def __init__(self, hs):
        super(ReadMarkerHandler, self).__init__(hs)
        self.server_name = hs.config.server_name
        self.store = hs.get_datastore()
        # Serialises read-marker updates per (room_id, user_id) pair.
        self.read_marker_linearizer = Linearizer(name="read_marker")
        self.notifier = hs.get_notifier()

    @defer.inlineCallbacks
    def received_client_read_marker(self, room_id, user_id, event_id):
        """Updates the read marker for a given user in a given room if the event ID given
        is ahead in the stream relative to the current read marker.

        This uses a notifier to indicate that account data should be sent down /sync if
        the read marker has changed.
        """

        with (yield self.read_marker_linearizer.queue((room_id, user_id))):
            account_data = yield self.store.get_account_data_for_room(user_id, room_id)

            existing_read_marker = account_data.get("m.fully_read", None)

            should_update = True

            if existing_read_marker:
                # Only update if the new marker is ahead in the stream
                should_update = yield self.store.is_event_after(
                    event_id,
                    existing_read_marker['event_id']
                )

            if should_update:
                content = {
                    "event_id": event_id
                }
                max_id = yield self.store.add_account_data_to_room(
                    user_id, room_id, "m.fully_read", content
                )
                # Poke /sync streams for this user only.
                self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
def test_lots_of_queued_things(self):
    # we have one slow thing, and lots of fast things queued up behind it.
    # it should *not* explode the stack.
    linearizer = Linearizer()

    @defer.inlineCallbacks
    def func(i, sleep=False):
        with logcontext.LoggingContext("func(%s)" % i) as lc:
            with (yield linearizer.queue("")):
                # The logcontext must survive acquiring the linearizer.
                self.assertEqual(
                    logcontext.LoggingContext.current_context(), lc)
                if sleep:
                    # NOTE(review): `async .sleep` presumably refers to a
                    # module imported under the name "async" (a reserved
                    # keyword in Python >= 3.7) — confirm this still
                    # imports on the targeted Python version.
                    yield async .sleep(0)

            # ... and it must also survive releasing it.
            self.assertEqual(logcontext.LoggingContext.current_context(), lc)

    func(0, sleep=True)
    for i in range(1, 100):
        func(i)

    return func(1000)
class ReadMarkerHandler(BaseHandler):
    def __init__(self, hs):
        super(ReadMarkerHandler, self).__init__(hs)
        self.server_name = hs.config.server_name
        self.store = hs.get_datastore()
        # Serialises read-marker updates per (room_id, user_id) pair.
        self.read_marker_linearizer = Linearizer(name="read_marker")
        self.notifier = hs.get_notifier()

    @defer.inlineCallbacks
    def received_client_read_marker(self, room_id, user_id, event_id):
        """Updates the read marker for a given user in a given room if the event ID given
        is ahead in the stream relative to the current read marker.

        This uses a notifier to indicate that account data should be sent down /sync if
        the read marker has changed.
        """

        with (yield self.read_marker_linearizer.queue((room_id, user_id))):
            existing_read_marker = yield self.store.get_account_data_for_room_and_type(
                user_id, room_id, "m.fully_read",
            )

            should_update = True

            if existing_read_marker:
                # Only update if the new marker is ahead in the stream
                should_update = yield self.store.is_event_after(
                    event_id,
                    existing_read_marker['event_id']
                )

            if should_update:
                content = {
                    "event_id": event_id
                }
                max_id = yield self.store.add_account_data_to_room(
                    user_id, room_id, "m.fully_read", content
                )
                # Poke /sync streams for this user only.
                self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
class DeviceListEduUpdater(object):
    "Handles incoming device list updates from federation and updates the DB"

    def __init__(self, hs, device_handler):
        self.store = hs.get_datastore()
        self.federation = hs.get_federation_client()
        self.clock = hs.get_clock()
        self.device_handler = device_handler

        # Serialises handling of pending updates per user_id.
        self._remote_edu_linearizer = Linearizer(name="remote_device_list")

        # user_id -> list of updates waiting to be handled.
        self._pending_updates = {}

        # Recently seen stream ids. We don't bother keeping these in the DB,
        # but they're useful to have them about to reduce the number of spurious
        # resyncs.
        self._seen_updates = ExpiringCache(
            cache_name="device_update_edu",
            clock=self.clock,
            max_len=10000,
            expiry_ms=30 * 60 * 1000,
            iterable=True,
        )

    @defer.inlineCallbacks
    def incoming_device_list_update(self, origin, edu_content):
        """Called on incoming device list update from federation. Responsible
        for parsing the EDU and adding to pending updates list.

        Args:
            origin (str): server name the EDU arrived from
            edu_content (dict): the EDU payload; consumed destructively
                via pop() below.
        """
        user_id = edu_content.pop("user_id")
        device_id = edu_content.pop("device_id")
        stream_id = str(edu_content.pop("stream_id"))  # They may come as ints
        prev_ids = edu_content.pop("prev_id", [])
        prev_ids = [str(p) for p in prev_ids]  # They may come as ints

        if get_domain_from_id(user_id) != origin:
            # The origin is not allowed to speak for users on other servers.
            # TODO: Raise?
            logger.warning("Got device list update edu for %r from %r", user_id, origin)
            return

        room_ids = yield self.store.get_rooms_for_user(user_id)
        if not room_ids:
            # We don't share any rooms with this user. Ignore update, as we
            # probably won't get any further updates.
            return

        self._pending_updates.setdefault(user_id, []).append(
            (device_id, stream_id, prev_ids, edu_content)
        )

        yield self._handle_device_updates(user_id)

    @measure_func("_incoming_device_list_update")
    @defer.inlineCallbacks
    def _handle_device_updates(self, user_id):
        "Actually handle pending updates."

        with (yield self._remote_edu_linearizer.queue(user_id)):
            pending_updates = self._pending_updates.pop(user_id, [])
            if not pending_updates:
                # This can happen since we batch updates
                return

            # Given a list of updates we check if we need to resync. This
            # happens if we've missed updates.
            resync = yield self._need_to_do_resync(user_id, pending_updates)

            if resync:
                # Fetch all devices for the user.
                origin = get_domain_from_id(user_id)
                try:
                    result = yield self.federation.query_user_devices(origin, user_id)
                except NotRetryingDestination:
                    # TODO: Remember that we are now out of sync and try again
                    # later
                    logger.warn(
                        "Failed to handle device list update for %s,"
                        " we're not retrying the remote",
                        user_id,
                    )
                    # We abort on exceptions rather than accepting the update
                    # as otherwise synapse will 'forget' that its device list
                    # is out of date. If we bail then we will retry the resync
                    # next time we get a device list update for this user_id.
                    # This makes it more likely that the device lists will
                    # eventually become consistent.
                    return
                except FederationDeniedError as e:
                    logger.info(e)
                    return
                except Exception:
                    # TODO: Remember that we are now out of sync and try again
                    # later
                    logger.exception(
                        "Failed to handle device list update for %s", user_id
                    )
                    return

                stream_id = result["stream_id"]
                devices = result["devices"]
                yield self.store.update_remote_device_list_cache(
                    user_id, devices, stream_id,
                )
                device_ids = [device["device_id"] for device in devices]
                yield self.device_handler.notify_device_update(user_id, device_ids)
            else:
                # Simply update the single device, since we know that is the only
                # change (because of the single prev_id matching the current cache)
                for device_id, stream_id, prev_ids, content in pending_updates:
                    yield self.store.update_remote_device_list_cache_entry(
                        user_id, device_id, content, stream_id,
                    )

                yield self.device_handler.notify_device_update(
                    user_id, [device_id for device_id, _, _, _ in pending_updates]
                )

            # Record the stream ids we've now handled, so future updates can
            # be applied as deltas rather than forcing a resync.
            self._seen_updates.setdefault(user_id, set()).update(
                stream_id for _, stream_id, _, _ in pending_updates
            )

    @defer.inlineCallbacks
    def _need_to_do_resync(self, user_id, updates):
        """Given a list of updates for a user figure out if we need to do a
        full resync, or whether we have enough data that we can just apply
        the delta.
        """
        seen_updates = self._seen_updates.get(user_id, set())

        extremity = yield self.store.get_device_list_last_stream_id_for_remote(
            user_id
        )

        stream_id_in_updates = set()  # stream_ids in updates list
        for _, stream_id, prev_ids, _ in updates:
            if not prev_ids:
                # We always do a resync if there are no previous IDs
                defer.returnValue(True)

            # Every prev_id must be accounted for by the cached extremity,
            # a previously seen update, or an earlier update in this batch;
            # otherwise we have a gap and must resync.
            for prev_id in prev_ids:
                if prev_id == extremity:
                    continue
                elif prev_id in seen_updates:
                    continue
                elif prev_id in stream_id_in_updates:
                    continue
                else:
                    defer.returnValue(True)

            stream_id_in_updates.add(stream_id)

        defer.returnValue(False)
class RoomMemberHandler(BaseHandler):
    """Handles joins, leaves, kicks, bans, invites (including third-party
    invites) for rooms, both for local users and via federation.
    """
    # TODO(paul): This handler currently contains a messy conflation of
    # low-level API that works on UserID objects and so on, and REST-level
    # API that takes ID strings and returns pagination chunks. These concerns
    # ought to be separated out a lot better.

    def __init__(self, hs):
        super(RoomMemberHandler, self).__init__(hs)

        # Serialises concurrent membership changes for the same (target, room)
        # pair; see update_membership.
        self.member_linearizer = Linearizer()

        self.clock = hs.get_clock()

        self.distributor = hs.get_distributor()
        self.distributor.declare("user_joined_room")
        self.distributor.declare("user_left_room")

    @defer.inlineCallbacks
    def _local_membership_update(
        self, requester, target, room_id, membership,
        prev_event_ids,
        txn_id=None,
        ratelimit=True,
    ):
        """Create and send a membership event for a change initiated locally.

        Args:
            requester (Requester): The user performing the change.
            target (UserID): The user whose membership is changing.
            room_id (str): The room in question.
            membership (str): The new membership value.
            prev_event_ids (list[str]): Event IDs to use as the prev_events.
            txn_id (str|None): Client transaction ID, for idempotency.
            ratelimit (bool): Whether to apply rate limiting.
        """
        msg_handler = self.hs.get_handlers().message_handler

        content = {"membership": membership}
        if requester.is_guest:
            content["kind"] = "guest"

        event, context = yield msg_handler.create_event(
            {
                "type": EventTypes.Member,
                "content": content,
                "room_id": room_id,
                "sender": requester.user.to_string(),
                "state_key": target.to_string(),

                # For backwards compatibility:
                "membership": membership,
            },
            token_id=requester.access_token_id,
            txn_id=txn_id,
            prev_event_ids=prev_event_ids,
        )

        yield msg_handler.handle_new_client_event(
            requester,
            event,
            context,
            extra_users=[target],
            ratelimit=ratelimit,
        )

        prev_member_event = context.current_state.get(
            (EventTypes.Member, target.to_string()),
            None
        )

        if event.membership == Membership.JOIN:
            if not prev_member_event or prev_member_event.membership != Membership.JOIN:
                # Only fire user_joined_room if the user has actually joined the
                # room. Don't bother if the user is just changing their profile
                # info.
                yield user_joined_room(self.distributor, target, room_id)
        elif event.membership == Membership.LEAVE:
            if prev_member_event and prev_member_event.membership == Membership.JOIN:
                user_left_room(self.distributor, target, room_id)

    @defer.inlineCallbacks
    def remote_join(self, remote_room_hosts, room_id, user, content):
        """Join a room that we are not currently in, via one of the given
        remote homeservers.

        Raises:
            SynapseError: if there are no candidate servers to join via.
        """
        if len(remote_room_hosts) == 0:
            raise SynapseError(404, "No known servers")

        # We don't do an auth check if we are doing an invite
        # join dance for now, since we're kinda implicitly checking
        # that we are allowed to join when we decide whether or not we
        # need to do the invite/join dance.
        yield self.hs.get_handlers().federation_handler.do_invite_join(
            remote_room_hosts,
            room_id,
            user.to_string(),
            content,
        )
        yield user_joined_room(self.distributor, user, room_id)

    def reject_remote_invite(self, user_id, room_id, remote_room_hosts):
        """Ask a remote server to reject an invite on our behalf.

        Returns:
            Deferred: resolves when the rejection has been sent.
        """
        return self.hs.get_handlers().federation_handler.do_remotely_reject_invite(
            remote_room_hosts,
            room_id,
            user_id
        )

    @defer.inlineCallbacks
    def update_membership(
            self,
            requester,
            target,
            room_id,
            action,
            txn_id=None,
            remote_room_hosts=None,
            third_party_signed=None,
            ratelimit=True,
    ):
        """Change the membership of `target` in `room_id` according to
        `action` ("join", "leave", "invite", "kick", "ban", "unban").

        Serialised per (target, room_id) via the member linearizer so that
        concurrent changes for the same pair cannot race.
        """
        key = (target, room_id,)

        with (yield self.member_linearizer.queue(key)):
            result = yield self._update_membership(
                requester,
                target,
                room_id,
                action,
                txn_id=txn_id,
                remote_room_hosts=remote_room_hosts,
                third_party_signed=third_party_signed,
                ratelimit=ratelimit,
            )

        defer.returnValue(result)

    @defer.inlineCallbacks
    def _update_membership(
            self,
            requester,
            target,
            room_id,
            action,
            txn_id=None,
            remote_room_hosts=None,
            third_party_signed=None,
            ratelimit=True,
    ):
        """Inner implementation of update_membership; must be called while
        holding the member linearizer for (target, room_id).
        """
        # "kick" and "unban" are both expressed as a "leave" membership event.
        effective_membership_state = action
        if action in ["kick", "unban"]:
            effective_membership_state = "leave"

        if third_party_signed is not None:
            replication = self.hs.get_replication_layer()
            yield replication.exchange_third_party_invite(
                third_party_signed["sender"],
                target.to_string(),
                room_id,
                third_party_signed,
            )

        if not remote_room_hosts:
            remote_room_hosts = []

        latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
        current_state = yield self.state_handler.get_current_state(
            room_id, latest_event_ids=latest_event_ids,
        )

        old_state = current_state.get((EventTypes.Member, target.to_string()))
        old_membership = old_state.content.get("membership") if old_state else None
        if action == "unban" and old_membership != "ban":
            raise SynapseError(
                403,
                "Cannot unban user who was not banned (membership=%s)" % old_membership,
                errcode=Codes.BAD_STATE
            )
        if old_membership == "ban" and action != "unban":
            raise SynapseError(
                403,
                "Cannot %s user who was banned" % (action,),
                errcode=Codes.BAD_STATE
            )

        is_host_in_room = self.is_host_in_room(current_state)

        if effective_membership_state == Membership.JOIN:
            if requester.is_guest and not self._can_guest_join(current_state):
                # This should be an auth check, but guests are a local concept,
                # so don't really fit into the general auth process.
                raise AuthError(403, "Guest access not allowed")

            if not is_host_in_room:
                # We are not in the room, so join remotely. If we were invited,
                # the inviter's server is a good candidate to join via.
                inviter = yield self.get_inviter(target.to_string(), room_id)
                if inviter and not self.hs.is_mine(inviter):
                    remote_room_hosts.append(inviter.domain)

                content = {"membership": Membership.JOIN}

                profile = self.hs.get_handlers().profile_handler
                content["displayname"] = yield profile.get_displayname(target)
                content["avatar_url"] = yield profile.get_avatar_url(target)

                if requester.is_guest:
                    content["kind"] = "guest"

                ret = yield self.remote_join(
                    remote_room_hosts, room_id, target, content
                )
                defer.returnValue(ret)

        elif effective_membership_state == Membership.LEAVE:
            if not is_host_in_room:
                # perhaps we've been invited
                inviter = yield self.get_inviter(target.to_string(), room_id)
                if not inviter:
                    raise SynapseError(404, "Not a known room")

                if self.hs.is_mine(inviter):
                    # the inviter was on our server, but has now left. Carry on
                    # with the normal rejection codepath.
                    #
                    # This is a bit of a hack, because the room might still be
                    # active on other servers.
                    pass
                else:
                    # send the rejection to the inviter's HS.
                    remote_room_hosts = remote_room_hosts + [inviter.domain]

                    try:
                        ret = yield self.reject_remote_invite(
                            target.to_string(), room_id, remote_room_hosts
                        )
                        defer.returnValue(ret)
                    except SynapseError as e:
                        # The remote rejection failed; fall back to marking the
                        # invite as rejected locally so the client isn't stuck.
                        logger.warn("Failed to reject invite: %s", e)

                        yield self.store.locally_reject_invite(
                            target.to_string(), room_id
                        )

                        defer.returnValue({})

        yield self._local_membership_update(
            requester=requester,
            target=target,
            room_id=room_id,
            membership=effective_membership_state,
            txn_id=txn_id,
            ratelimit=ratelimit,
            prev_event_ids=latest_event_ids,
        )

    @defer.inlineCallbacks
    def send_membership_event(
            self,
            requester,
            event,
            context,
            remote_room_hosts=None,
            ratelimit=True,
    ):
        """
        Change the membership status of a user in a room.

        Args:
            requester (Requester): The local user who requested the membership
                event. If None, certain checks, like whether this homeserver can
                act as the sender, will be skipped.
            event (SynapseEvent): The membership event.
            context: The context of the event.
            remote_room_hosts ([str]): Homeservers which are likely to already
                be in the room, and could be danced with in order to join this
                homeserver for the first time.
            ratelimit (bool): Whether to rate limit this request.
        Raises:
            SynapseError if there was a problem changing the membership.
        """
        remote_room_hosts = remote_room_hosts or []

        target_user = UserID.from_string(event.state_key)
        room_id = event.room_id

        if requester is not None:
            sender = UserID.from_string(event.sender)
            assert sender == requester.user, (
                "Sender (%s) must be same as requester (%s)" %
                (sender, requester.user)
            )
            assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
        else:
            requester = Requester(target_user, None, False)

        message_handler = self.hs.get_handlers().message_handler
        # If an identical state event is already pending we bail early.
        prev_event = message_handler.deduplicate_state_event(event, context)
        if prev_event is not None:
            return

        if event.membership == Membership.JOIN:
            if requester.is_guest and not self._can_guest_join(context.current_state):
                # This should be an auth check, but guests are a local concept,
                # so don't really fit into the general auth process.
                raise AuthError(403, "Guest access not allowed")

        yield message_handler.handle_new_client_event(
            requester,
            event,
            context,
            extra_users=[target_user],
            ratelimit=ratelimit,
        )

        prev_member_event = context.current_state.get(
            (EventTypes.Member, target_user.to_string()),
            None
        )

        if event.membership == Membership.JOIN:
            if not prev_member_event or prev_member_event.membership != Membership.JOIN:
                # Only fire user_joined_room if the user has actually joined the
                # room. Don't bother if the user is just changing their profile
                # info.
                yield user_joined_room(self.distributor, target_user, room_id)
        elif event.membership == Membership.LEAVE:
            if prev_member_event and prev_member_event.membership == Membership.JOIN:
                user_left_room(self.distributor, target_user, room_id)

    def _can_guest_join(self, current_state):
        """
        Returns whether a guest can join a room based on its current state.
        """
        guest_access = current_state.get((EventTypes.GuestAccess, ""), None)
        return (
            guest_access
            and guest_access.content
            and "guest_access" in guest_access.content
            and guest_access.content["guest_access"] == "can_join"
        )

    @defer.inlineCallbacks
    def lookup_room_alias(self, room_alias):
        """
        Get the room ID associated with a room alias.

        Args:
            room_alias (RoomAlias): The alias to look up.
        Returns:
            A tuple of:
                The room ID as a RoomID object.
                Hosts likely to be participating in the room ([str]).
        Raises:
            SynapseError if room alias could not be found.
        """
        directory_handler = self.hs.get_handlers().directory_handler
        mapping = yield directory_handler.get_association(room_alias)

        if not mapping:
            raise SynapseError(404, "No such room alias")

        room_id = mapping["room_id"]
        servers = mapping["servers"]

        defer.returnValue((RoomID.from_string(room_id), servers))

    @defer.inlineCallbacks
    def get_inviter(self, user_id, room_id):
        """Return the UserID of whoever invited `user_id` to `room_id`, or
        None (implicitly) if there is no outstanding invite.
        """
        invite = yield self.store.get_invite_for_user_in_room(
            user_id=user_id,
            room_id=room_id,
        )
        if invite:
            defer.returnValue(UserID.from_string(invite.sender))

    @defer.inlineCallbacks
    def do_3pid_invite(
            self,
            room_id,
            inviter,
            medium,
            address,
            id_server,
            requester,
            txn_id
    ):
        """Invite a third-party identifier (e.g. an email address) to a room.

        If the identity server already knows a Matrix ID for the 3pid we send
        a normal invite; otherwise we store a third-party invite event.
        """
        invitee = yield self._lookup_3pid(id_server, medium, address)

        if invitee:
            yield self.update_membership(
                requester,
                UserID.from_string(invitee),
                room_id,
                "invite",
                txn_id=txn_id,
            )
        else:
            yield self._make_and_store_3pid_invite(
                requester,
                id_server,
                medium,
                address,
                room_id,
                inviter,
                txn_id=txn_id
            )

    @defer.inlineCallbacks
    def _lookup_3pid(self, id_server, medium, address):
        """Looks up a 3pid in the passed identity server.

        Args:
            id_server (str): The server name (including port, if required)
                of the identity server to use.
            medium (str): The type of the third party identifier (e.g. "email").
            address (str): The third party identifier (e.g. "*****@*****.**").

        Returns:
            str: the matrix ID of the 3pid, or None if it is not recognized.
        """
        try:
            data = yield self.hs.get_simple_http_client().get_json(
                "%s%s/_matrix/identity/api/v1/lookup" % (
                    id_server_scheme, id_server,
                ),
                {
                    "medium": medium,
                    "address": address,
                }
            )

            if "mxid" in data:
                # Only trust the binding if the identity server signed it.
                if "signatures" not in data:
                    raise AuthError(401, "No signatures on 3pid binding")
                self.verify_any_signature(data, id_server)
                defer.returnValue(data["mxid"])

        except IOError as e:
            logger.warn("Error from identity server lookup: %s" % (e,))
            defer.returnValue(None)

    @defer.inlineCallbacks
    def verify_any_signature(self, data, server_hostname):
        """Check that `data` carries at least one valid signature from
        `server_hostname`, fetching the server's public keys as needed.

        Raises:
            AuthError: if no signature from the server is present, or a named
                public key cannot be fetched.
        """
        if server_hostname not in data["signatures"]:
            raise AuthError(401, "No signature from server %s" % (server_hostname,))
        for key_name, signature in data["signatures"][server_hostname].items():
            key_data = yield self.hs.get_simple_http_client().get_json(
                "%s%s/_matrix/identity/api/v1/pubkey/%s" % (
                    id_server_scheme, server_hostname, key_name,
                ),
            )
            if "public_key" not in key_data:
                raise AuthError(
                    401,
                    "No public key named %s from %s" % (key_name, server_hostname,)
                )
            verify_signed_json(
                data,
                server_hostname,
                decode_verify_key_bytes(key_name, decode_base64(key_data["public_key"]))
            )
            # One successfully verified signature is sufficient.
            return

    @defer.inlineCallbacks
    def _make_and_store_3pid_invite(
            self,
            requester,
            id_server,
            medium,
            address,
            room_id,
            user,
            txn_id
    ):
        """Ask the identity server to store an invite for the 3pid and send a
        m.room.third_party_invite event into the room.

        Gathers cosmetic room/inviter details from current room state so that
        the identity server can compose a friendly notification.
        """
        room_state = yield self.hs.get_state_handler().get_current_state(room_id)

        inviter_display_name = ""
        inviter_avatar_url = ""
        member_event = room_state.get((EventTypes.Member, user.to_string()))
        if member_event:
            inviter_display_name = member_event.content.get("displayname", "")
            inviter_avatar_url = member_event.content.get("avatar_url", "")

        canonical_room_alias = ""
        canonical_alias_event = room_state.get((EventTypes.CanonicalAlias, ""))
        if canonical_alias_event:
            canonical_room_alias = canonical_alias_event.content.get("alias", "")

        room_name = ""
        room_name_event = room_state.get((EventTypes.Name, ""))
        if room_name_event:
            room_name = room_name_event.content.get("name", "")

        room_join_rules = ""
        join_rules_event = room_state.get((EventTypes.JoinRules, ""))
        if join_rules_event:
            room_join_rules = join_rules_event.content.get("join_rule", "")

        room_avatar_url = ""
        room_avatar_event = room_state.get((EventTypes.RoomAvatar, ""))
        if room_avatar_event:
            room_avatar_url = room_avatar_event.content.get("url", "")

        token, public_keys, fallback_public_key, display_name = (
            yield self._ask_id_server_for_third_party_invite(
                id_server=id_server,
                medium=medium,
                address=address,
                room_id=room_id,
                inviter_user_id=user.to_string(),
                room_alias=canonical_room_alias,
                room_avatar_url=room_avatar_url,
                room_join_rules=room_join_rules,
                room_name=room_name,
                inviter_display_name=inviter_display_name,
                inviter_avatar_url=inviter_avatar_url
            )
        )

        msg_handler = self.hs.get_handlers().message_handler
        yield msg_handler.create_and_send_nonmember_event(
            requester,
            {
                "type": EventTypes.ThirdPartyInvite,
                "content": {
                    "display_name": display_name,
                    "public_keys": public_keys,

                    # For backwards compatibility:
                    "key_validity_url": fallback_public_key["key_validity_url"],
                    "public_key": fallback_public_key["public_key"],
                },
                "room_id": room_id,
                "sender": user.to_string(),
                "state_key": token,
            },
            txn_id=txn_id,
        )

    @defer.inlineCallbacks
    def _ask_id_server_for_third_party_invite(
            self, id_server, medium, address, room_id, inviter_user_id,
            room_alias, room_avatar_url, room_join_rules, room_name,
            inviter_display_name, inviter_avatar_url
    ):
        """
        Asks an identity server for a third party invite.

        Args:
            id_server (str): hostname + optional port for the identity server.
            medium (str): The literal string "email".
            address (str): The third party address being invited.
            room_id (str): The ID of the room to which the user is invited.
            inviter_user_id (str): The user ID of the inviter.
            room_alias (str): An alias for the room, for cosmetic notifications.
            room_avatar_url (str): The URL of the room's avatar, for cosmetic
                notifications.
            room_join_rules (str): The join rules of the email (e.g. "public").
            room_name (str): The m.room.name of the room.
            inviter_display_name (str): The current display name of the
                inviter.
            inviter_avatar_url (str): The URL of the inviter's avatar.

        Returns:
            A deferred tuple containing:
                token (str): The token which must be signed to prove
                    authenticity.
                public_keys ([{"public_key": str, "key_validity_url": str}]):
                    public_key is a base64-encoded ed25519 public key.
                fallback_public_key: One element from public_keys.
                display_name (str): A user-friendly name to represent the
                    invited user.
        """
        is_url = "%s%s/_matrix/identity/api/v1/store-invite" % (
            id_server_scheme, id_server,
        )

        invite_config = {
            "medium": medium,
            "address": address,
            "room_id": room_id,
            "room_alias": room_alias,
            "room_avatar_url": room_avatar_url,
            "room_join_rules": room_join_rules,
            "room_name": room_name,
            "sender": inviter_user_id,
            "sender_display_name": inviter_display_name,
            "sender_avatar_url": inviter_avatar_url,
        }

        if self.hs.config.invite_3pid_guest:
            # Optionally pre-register a guest account the invitee can use.
            registration_handler = self.hs.get_handlers().registration_handler
            guest_access_token = yield registration_handler.guest_access_token_for(
                medium=medium,
                address=address,
                inviter_user_id=inviter_user_id,
            )

            guest_user_info = yield self.hs.get_auth().get_user_by_access_token(
                guest_access_token
            )

            invite_config.update({
                "guest_access_token": guest_access_token,
                "guest_user_id": guest_user_info["user"].to_string(),
            })

        data = yield self.hs.get_simple_http_client().post_urlencoded_get_json(
            is_url,
            invite_config
        )
        # TODO: Check for success
        token = data["token"]
        public_keys = data.get("public_keys", [])
        if "public_key" in data:
            fallback_public_key = {
                "public_key": data["public_key"],
                "key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
                    id_server_scheme, id_server,
                ),
            }
        else:
            fallback_public_key = public_keys[0]

        if not public_keys:
            public_keys.append(fallback_public_key)
        display_name = data["display_name"]
        defer.returnValue((token, public_keys, fallback_public_key, display_name))

    @defer.inlineCallbacks
    def forget(self, user, room_id):
        """Mark a room as forgotten by the given user.

        Raises:
            SynapseError: if the user is still in the room (membership is not
                "leave").
        """
        user_id = user.to_string()

        member = yield self.state_handler.get_current_state(
            room_id=room_id,
            event_type=EventTypes.Member,
            state_key=user_id
        )
        membership = member.membership if member else None

        if membership is not None and membership != Membership.LEAVE:
            raise SynapseError(400, "User %s in room %s" % (user_id, room_id))

        if membership:
            yield self.store.forget(user_id, room_id)
class MediaRepository(object):
    """Stores and serves uploaded media (files, images) and their thumbnails,
    both for local users and content cached from remote homeservers.
    """

    def __init__(self, hs):
        self.auth = hs.get_auth()
        self.client = MatrixFederationHttpClient(hs)
        self.clock = hs.get_clock()
        self.server_name = hs.hostname
        self.store = hs.get_datastore()
        self.max_upload_size = hs.config.max_upload_size
        self.max_image_pixels = hs.config.max_image_pixels

        self.primary_base_path = hs.config.media_store_path
        self.filepaths = MediaFilePaths(self.primary_base_path)

        self.dynamic_thumbnails = hs.config.dynamic_thumbnails
        self.thumbnail_requirements = hs.config.thumbnail_requirements

        # Serialises remote-media downloads/deletions per (server, media_id).
        self.remote_media_linearizer = Linearizer(name="media_remote")

        # Media accessed since the last _update_recently_accessed flush.
        self.recently_accessed_remotes = set()
        self.recently_accessed_locals = set()

        self.federation_domain_whitelist = hs.config.federation_domain_whitelist

        # List of StorageProviders where we should search for media and
        # potentially upload to.
        storage_providers = []

        for clz, provider_config, wrapper_config in hs.config.media_storage_providers:
            backend = clz(hs, provider_config)
            provider = StorageProviderWrapper(
                backend,
                store_local=wrapper_config.store_local,
                store_remote=wrapper_config.store_remote,
                store_synchronous=wrapper_config.store_synchronous,
            )
            storage_providers.append(provider)

        self.media_storage = MediaStorage(
            self.primary_base_path, self.filepaths, storage_providers,
        )

        self.clock.looping_call(
            self._update_recently_accessed,
            UPDATE_RECENTLY_ACCESSED_TS,
        )

    @defer.inlineCallbacks
    def _update_recently_accessed(self):
        # Flush the recently-accessed sets to the database and reset them so
        # the next period starts empty.
        remote_media = self.recently_accessed_remotes
        self.recently_accessed_remotes = set()

        local_media = self.recently_accessed_locals
        self.recently_accessed_locals = set()

        yield self.store.update_cached_last_access_time(
            local_media, remote_media, self.clock.time_msec()
        )

    def mark_recently_accessed(self, server_name, media_id):
        """Mark the given media as recently accessed.

        Args:
            server_name (str|None): Origin server of media, or None if local
            media_id (str): The media ID of the content
        """
        if server_name:
            self.recently_accessed_remotes.add((server_name, media_id))
        else:
            self.recently_accessed_locals.add(media_id)

    @defer.inlineCallbacks
    def create_content(self, media_type, upload_name, content, content_length,
                       auth_user):
        """Store uploaded content for a local user and return the mxc URL

        Args:
            media_type(str): The content type of the file
            upload_name(str): The name of the file
            content: A file like object that is the content to store
            content_length(int): The length of the content
            auth_user(str): The user_id of the uploader

        Returns:
            Deferred[str]: The mxc url of the stored content
        """
        media_id = random_string(24)

        file_info = FileInfo(
            server_name=None,
            file_id=media_id,
        )

        fname = yield self.media_storage.store_file(content, file_info)

        logger.info("Stored local media in file %r", fname)

        yield self.store.store_local_media(
            media_id=media_id,
            media_type=media_type,
            time_now_ms=self.clock.time_msec(),
            upload_name=upload_name,
            media_length=content_length,
            user_id=auth_user,
        )

        # For local uploads the file_id is the media_id.
        yield self._generate_thumbnails(
            None, media_id, media_id, media_type,
        )

        defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))

    @defer.inlineCallbacks
    def get_local_media(self, request, media_id, name):
        """Responds to requests for local media, if exists, or returns 404.

        Args:
            request(twisted.web.http.Request)
            media_id (str): The media ID of the content. (This is the same as
                the file_id for local content.)
            name (str|None): Optional name that, if specified, will be used as
                the filename in the Content-Disposition header of the response.

        Returns:
            Deferred: Resolves once a response has successfully been written
                to request
        """
        media_info = yield self.store.get_local_media(media_id)
        if not media_info or media_info["quarantined_by"]:
            respond_404(request)
            return

        self.mark_recently_accessed(None, media_id)

        media_type = media_info["media_type"]
        media_length = media_info["media_length"]
        upload_name = name if name else media_info["upload_name"]
        url_cache = media_info["url_cache"]

        file_info = FileInfo(
            None, media_id,
            url_cache=url_cache,
        )

        responder = yield self.media_storage.fetch_media(file_info)
        yield respond_with_responder(
            request, responder, media_type, media_length, upload_name,
        )

    @defer.inlineCallbacks
    def get_remote_media(self, request, server_name, media_id, name):
        """Respond to requests for remote media.

        Args:
            request(twisted.web.http.Request)
            server_name (str): Remote server_name where the media originated.
            media_id (str): The media ID of the content (as defined by the
                remote server).
            name (str|None): Optional name that, if specified, will be used as
                the filename in the Content-Disposition header of the response.

        Returns:
            Deferred: Resolves once a response has successfully been written
                to request
        """
        if (
            self.federation_domain_whitelist is not None and
            server_name not in self.federation_domain_whitelist
        ):
            raise FederationDeniedError(server_name)

        self.mark_recently_accessed(server_name, media_id)

        # We linearize here to ensure that we don't try and download remote
        # media multiple times concurrently
        key = (server_name, media_id)
        with (yield self.remote_media_linearizer.queue(key)):
            responder, media_info = yield self._get_remote_media_impl(
                server_name, media_id,
            )

        # We deliberately stream the file outside the lock
        if responder:
            media_type = media_info["media_type"]
            media_length = media_info["media_length"]
            upload_name = name if name else media_info["upload_name"]
            yield respond_with_responder(
                request, responder, media_type, media_length, upload_name,
            )
        else:
            respond_404(request)

    @defer.inlineCallbacks
    def get_remote_media_info(self, server_name, media_id):
        """Gets the media info associated with the remote file, downloading
        if necessary.

        Args:
            server_name (str): Remote server_name where the media originated.
            media_id (str): The media ID of the content (as defined by the
                remote server).

        Returns:
            Deferred[dict]: The media_info of the file
        """
        if (
            self.federation_domain_whitelist is not None and
            server_name not in self.federation_domain_whitelist
        ):
            raise FederationDeniedError(server_name)

        # We linearize here to ensure that we don't try and download remote
        # media multiple times concurrently
        key = (server_name, media_id)
        with (yield self.remote_media_linearizer.queue(key)):
            responder, media_info = yield self._get_remote_media_impl(
                server_name, media_id,
            )

        # Ensure we actually use the responder so that it releases resources
        if responder:
            with responder:
                pass

        defer.returnValue(media_info)

    @defer.inlineCallbacks
    def _get_remote_media_impl(self, server_name, media_id):
        """Looks for media in local cache, if not there then attempt to
        download from remote server.

        Args:
            server_name (str): Remote server_name where the media originated.
            media_id (str): The media ID of the content (as defined by the
                remote server).

        Returns:
            Deferred[(Responder, media_info)]
        """
        media_info = yield self.store.get_cached_remote_media(
            server_name, media_id
        )

        # file_id is the ID we use to track the file locally. If we've already
        # seen the file then reuse the existing ID, otherwise generate a new
        # one.
        if media_info:
            file_id = media_info["filesystem_id"]
        else:
            file_id = random_string(24)

        file_info = FileInfo(server_name, file_id)

        # If we have an entry in the DB, try and look for it
        if media_info:
            if media_info["quarantined_by"]:
                logger.info("Media is quarantined")
                raise NotFoundError()

            responder = yield self.media_storage.fetch_media(file_info)
            if responder:
                defer.returnValue((responder, media_info))

        # Failed to find the file anywhere, lets download it.

        media_info = yield self._download_remote_file(
            server_name, media_id, file_id
        )

        responder = yield self.media_storage.fetch_media(file_info)
        defer.returnValue((responder, media_info))

    @defer.inlineCallbacks
    def _download_remote_file(self, server_name, media_id, file_id):
        """Attempt to download the remote file from the given server name,
        using the given file_id as the local id.

        Args:
            server_name (str): Originating server
            media_id (str): The media ID of the content (as defined by the
                remote server). This is different than the file_id, which is
                locally generated.
            file_id (str): Local file ID

        Returns:
            Deferred[MediaInfo]
        """
        file_info = FileInfo(
            server_name=server_name,
            file_id=file_id,
        )

        with self.media_storage.store_into_file(file_info) as (f, fname, finish):
            request_path = "/".join((
                "/_matrix/media/v1/download", server_name, media_id,
            ))
            try:
                length, headers = yield self.client.get_file(
                    server_name, request_path, output_stream=f,
                    max_size=self.max_upload_size, args={
                        # tell the remote server to 404 if it doesn't
                        # recognise the server_name, to make sure we don't
                        # end up with a routing loop.
                        "allow_remote": "false",
                    }
                )
            except twisted.internet.error.DNSLookupError as e:
                logger.warn("HTTP error fetching remote media %s/%s: %r",
                            server_name, media_id, e)
                raise NotFoundError()
            except HttpResponseException as e:
                logger.warn("HTTP error fetching remote media %s/%s: %s",
                            server_name, media_id, e.response)
                if e.code == twisted.web.http.NOT_FOUND:
                    # Propagate remote 404s as-is to the requesting client.
                    raise SynapseError.from_http_response_exception(e)
                raise SynapseError(502, "Failed to fetch remote media")
            except SynapseError:
                logger.exception("Failed to fetch remote media %s/%s",
                                 server_name, media_id)
                raise
            except NotRetryingDestination:
                logger.warn("Not retrying destination %r", server_name)
                raise SynapseError(502, "Failed to fetch remote media")
            except Exception:
                logger.exception("Failed to fetch remote media %s/%s",
                                 server_name, media_id)
                raise SynapseError(502, "Failed to fetch remote media")

            yield finish()

        media_type = headers["Content-Type"][0]

        time_now_ms = self.clock.time_msec()

        content_disposition = headers.get("Content-Disposition", None)
        if content_disposition:
            _, params = cgi.parse_header(content_disposition[0],)
            upload_name = None

            # First check if there is a valid UTF-8 filename
            upload_name_utf8 = params.get("filename*", None)
            if upload_name_utf8:
                if upload_name_utf8.lower().startswith("utf-8''"):
                    upload_name = upload_name_utf8[7:]

            # If there isn't check for an ascii name.
            if not upload_name:
                upload_name_ascii = params.get("filename", None)
                if upload_name_ascii and is_ascii(upload_name_ascii):
                    upload_name = upload_name_ascii

            if upload_name:
                upload_name = urlparse.unquote(upload_name)
                try:
                    upload_name = upload_name.decode("utf-8")
                except UnicodeDecodeError:
                    upload_name = None
        else:
            upload_name = None

        logger.info("Stored remote media in file %r", fname)

        yield self.store.store_cached_remote_media(
            origin=server_name,
            media_id=media_id,
            media_type=media_type,
            time_now_ms=self.clock.time_msec(),
            upload_name=upload_name,
            media_length=length,
            filesystem_id=file_id,
        )

        media_info = {
            "media_type": media_type,
            "media_length": length,
            "upload_name": upload_name,
            "created_ts": time_now_ms,
            "filesystem_id": file_id,
        }

        yield self._generate_thumbnails(
            server_name, media_id, file_id, media_type,
        )

        defer.returnValue(media_info)

    def _get_thumbnail_requirements(self, media_type):
        # Returns the configured thumbnail sizes for this content type, or an
        # empty tuple if the type is not thumbnailable.
        return self.thumbnail_requirements.get(media_type, ())

    def _generate_thumbnail(self, thumbnailer, t_width, t_height,
                            t_method, t_type):
        """Synchronously produce a single thumbnail byte source, or None if
        the image is too large or the method is unknown. Intended to be run
        in a thread via deferToThread.
        """
        m_width = thumbnailer.width
        m_height = thumbnailer.height

        if m_width * m_height >= self.max_image_pixels:
            logger.info(
                "Image too large to thumbnail %r x %r > %r",
                m_width, m_height, self.max_image_pixels
            )
            return

        if t_method == "crop":
            t_byte_source = thumbnailer.crop(t_width, t_height, t_type)
        elif t_method == "scale":
            t_width, t_height = thumbnailer.aspect(t_width, t_height)
            t_width = min(m_width, t_width)
            t_height = min(m_height, t_height)
            t_byte_source = thumbnailer.scale(t_width, t_height, t_type)
        else:
            t_byte_source = None

        return t_byte_source

    @defer.inlineCallbacks
    def generate_local_exact_thumbnail(self, media_id, t_width, t_height,
                                       t_method, t_type, url_cache):
        # Generate and store a thumbnail of exactly the requested dimensions
        # for local media; returns the output path (or None implicitly if the
        # thumbnail could not be generated).
        input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
            None, media_id, url_cache=url_cache,
        ))

        thumbnailer = Thumbnailer(input_path)
        t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
            self._generate_thumbnail,
            thumbnailer, t_width, t_height, t_method, t_type
        ))

        if t_byte_source:
            try:
                file_info = FileInfo(
                    server_name=None,
                    file_id=media_id,
                    url_cache=url_cache,
                    thumbnail=True,
                    thumbnail_width=t_width,
                    thumbnail_height=t_height,
                    thumbnail_method=t_method,
                    thumbnail_type=t_type,
                )

                output_path = yield self.media_storage.store_file(
                    t_byte_source, file_info,
                )
            finally:
                # Always release the image buffer, even if storing failed.
                t_byte_source.close()

            logger.info("Stored thumbnail in file %r", output_path)

            t_len = os.path.getsize(output_path)

            yield self.store.store_local_thumbnail(
                media_id, t_width, t_height, t_type, t_method, t_len
            )

            defer.returnValue(output_path)

    @defer.inlineCallbacks
    def generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
                                        t_width, t_height, t_method, t_type):
        # Same as generate_local_exact_thumbnail but for remotely-cached media.
        input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
            server_name, file_id, url_cache=False,
        ))

        thumbnailer = Thumbnailer(input_path)
        t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
            self._generate_thumbnail,
            thumbnailer, t_width, t_height, t_method, t_type
        ))

        if t_byte_source:
            try:
                file_info = FileInfo(
                    server_name=server_name,
                    file_id=media_id,
                    thumbnail=True,
                    thumbnail_width=t_width,
                    thumbnail_height=t_height,
                    thumbnail_method=t_method,
                    thumbnail_type=t_type,
                )

                output_path = yield self.media_storage.store_file(
                    t_byte_source, file_info,
                )
            finally:
                # Always release the image buffer, even if storing failed.
                t_byte_source.close()

            logger.info("Stored thumbnail in file %r", output_path)

            t_len = os.path.getsize(output_path)

            yield self.store.store_remote_media_thumbnail(
                server_name, media_id, file_id,
                t_width, t_height, t_type, t_method, t_len
            )

            defer.returnValue(output_path)

    @defer.inlineCallbacks
    def _generate_thumbnails(self, server_name, media_id, file_id, media_type,
                             url_cache=False):
        """Generate and store thumbnails for an image.

        Args:
            server_name (str|None): The server name if remote media, else None
                if local
            media_id (str): The media ID of the content. (This is the same as
                the file_id for local content)
            file_id (str): Local file ID
            media_type (str): The content type of the file
            url_cache (bool): If we are thumbnailing images downloaded for the
                URL cache, used exclusively by the url previewer

        Returns:
            Deferred[dict]: Dict with "width" and "height" keys of original
                image
        """
        requirements = self._get_thumbnail_requirements(media_type)
        if not requirements:
            return

        input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
            server_name, file_id, url_cache=url_cache,
        ))

        thumbnailer = Thumbnailer(input_path)
        m_width = thumbnailer.width
        m_height = thumbnailer.height

        if m_width * m_height >= self.max_image_pixels:
            logger.info(
                "Image too large to thumbnail %r x %r > %r",
                m_width, m_height, self.max_image_pixels
            )
            return

        # We deduplicate the thumbnail sizes by ignoring the cropped versions if
        # they have the same dimensions of a scaled one.
        thumbnails = {}
        for r_width, r_height, r_method, r_type in requirements:
            if r_method == "crop":
                thumbnails.setdefault((r_width, r_height, r_type), r_method)
            elif r_method == "scale":
                t_width, t_height = thumbnailer.aspect(r_width, r_height)
                t_width = min(m_width, t_width)
                t_height = min(m_height, t_height)
                thumbnails[(t_width, t_height, r_type)] = r_method

        # Now we generate the thumbnails for each dimension, store it
        for (t_width, t_height, t_type), t_method in thumbnails.iteritems():
            # Generate the thumbnail
            if t_method == "crop":
                t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
                    thumbnailer.crop,
                    t_width, t_height, t_type,
                ))
            elif t_method == "scale":
                t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
                    thumbnailer.scale,
                    t_width, t_height, t_type,
                ))
            else:
                logger.error("Unrecognized method: %r", t_method)
                continue

            if not t_byte_source:
                continue

            try:
                file_info = FileInfo(
                    server_name=server_name,
                    file_id=file_id,
                    thumbnail=True,
                    thumbnail_width=t_width,
                    thumbnail_height=t_height,
                    thumbnail_method=t_method,
                    thumbnail_type=t_type,
                    url_cache=url_cache,
                )

                output_path = yield self.media_storage.store_file(
                    t_byte_source, file_info,
                )
            finally:
                # Always release the image buffer, even if storing failed.
                t_byte_source.close()

            t_len = os.path.getsize(output_path)

            # Write to database
            if server_name:
                yield self.store.store_remote_media_thumbnail(
                    server_name, media_id, file_id,
                    t_width, t_height, t_type, t_method, t_len
                )
            else:
                yield self.store.store_local_thumbnail(
                    media_id, t_width, t_height, t_type, t_method, t_len
                )

        defer.returnValue({
            "width": m_width,
            "height": m_height,
        })

    @defer.inlineCallbacks
    def delete_old_remote_media(self, before_ts):
        """Delete cached copies of remote media last accessed before the given
        timestamp.

        Returns:
            Deferred[dict]: {"deleted": <number of media items removed>}
        """
        old_media = yield self.store.get_remote_media_before(before_ts)

        deleted = 0

        for media in old_media:
            origin = media["media_origin"]
            media_id = media["media_id"]
            file_id = media["filesystem_id"]
            key = (origin, media_id)

            logger.info("Deleting: %r", key)

            # TODO: Should we delete from the backup store

            with (yield self.remote_media_linearizer.queue(key)):
                full_path = self.filepaths.remote_media_filepath(origin, file_id)
                try:
                    os.remove(full_path)
                except OSError as e:
                    logger.warn("Failed to remove file: %r", full_path)
                    if e.errno == errno.ENOENT:
                        # Already gone; carry on and remove the DB entry.
                        pass
                    else:
                        continue

                thumbnail_dir = self.filepaths.remote_media_thumbnail_dir(
                    origin, file_id
                )
                shutil.rmtree(thumbnail_dir, ignore_errors=True)

                yield self.store.delete_remote_media(origin, media_id)
                deleted += 1

        defer.returnValue({"deleted": deleted})
class _JoinedHostsCache(object):
    """Cache for joined hosts in a room that is optimised to handle updates
    via state deltas.

    The cache tracks, per remote host, which of its users are joined to the
    room, and is keyed by the state group the mapping was computed for.  When
    a new state entry arrives whose ``prev_group`` matches the cached state
    group, the mapping is updated incrementally from ``delta_ids`` instead of
    being recomputed from scratch.
    """

    def __init__(self, store, room_id):
        # Datastore used to fetch membership events and joined-user lists.
        self.store = store
        self.room_id = room_id

        # Map from host name -> set of user_ids on that host joined to the
        # room.  This is the cached value that get_destinations maintains.
        self.hosts_to_joined_users = {}

        # The state group that hosts_to_joined_users was computed for.  A
        # fresh object() is a sentinel that never compares equal to a real
        # state group, so the cache starts out (and can be reset to) invalid.
        self.state_group = object()

        # Serialises cache updates so concurrent callers don't recompute or
        # interleave mutations of hosts_to_joined_users.
        self.linearizer = Linearizer("_JoinedHostsCache")

        # Total number of joined users across all hosts, kept in sync with
        # hosts_to_joined_users so that __len__ is O(1).
        self._len = 0

    @defer.inlineCallbacks
    def get_destinations(self, state_entry):
        """Get set of destinations for a state entry

        Args:
            state_entry(synapse.state._StateCacheEntry)

        Returns:
            Deferred[frozenset[str]]: the set of host names with at least one
            user joined to the room, as of the given state entry.
        """
        # Fast path: cache is already up to date for this state group, no
        # need to take the linearizer.
        if state_entry.state_group == self.state_group:
            defer.returnValue(frozenset(self.hosts_to_joined_users))

        with (yield self.linearizer.queue(())):
            if state_entry.state_group == self.state_group:
                # Another caller refreshed the cache while we were queued.
                pass
            elif state_entry.prev_group == self.state_group:
                # Incremental update: apply only the membership deltas between
                # the cached state group and the new one.
                for (typ, state_key), event_id in state_entry.delta_ids.iteritems():
                    if typ != EventTypes.Member:
                        continue

                    # For m.room.member events the state_key is the user id;
                    # its domain is the host.  intern_string deduplicates the
                    # many repeated host strings.
                    host = intern_string(get_domain_from_id(state_key))
                    user_id = state_key
                    known_joins = self.hosts_to_joined_users.setdefault(host, set())

                    event = yield self.store.get_event(event_id)
                    if event.membership == Membership.JOIN:
                        known_joins.add(user_id)
                    else:
                        # Any non-join membership (leave/ban/invite/knock)
                        # means the user no longer counts as joined.
                        known_joins.discard(user_id)

                        if not known_joins:
                            # Drop hosts with no remaining joined users so
                            # they stop appearing in the destination set.
                            self.hosts_to_joined_users.pop(host, None)
            else:
                # Cache miss with no usable delta: rebuild the whole mapping
                # from the room state.
                joined_users = yield self.store.get_joined_users_from_state(
                    self.room_id, state_entry,
                )

                self.hosts_to_joined_users = {}
                for user_id in joined_users:
                    host = intern_string(get_domain_from_id(user_id))
                    self.hosts_to_joined_users.setdefault(host, set()).add(user_id)

            if state_entry.state_group:
                self.state_group = state_entry.state_group
            else:
                # No persisted state group for this entry; reset to a fresh
                # sentinel so the next call recomputes rather than matching.
                self.state_group = object()
            self._len = sum(len(v) for v in self.hosts_to_joined_users.itervalues())
        defer.returnValue(frozenset(self.hosts_to_joined_users))

    def __len__(self):
        # Number of joined users across all hosts (precomputed in
        # get_destinations).
        return self._len
def __init__(self, hs):
    """
    Args:
        hs (synapse.server.HomeServer):
    """
    super(FederationServer, self).__init__(hs)

    # Serialises handling of incoming PDUs.  Named so that contention shows
    # up attributably in Linearizer logging, consistent with the other
    # linearizers in this codebase (e.g. "member", "state_resolve_lock");
    # previously this one was anonymous.
    self._room_pdu_linearizer = Linearizer(name="fed_server_room_pdu")
class RoomMemberHandler(object):
    # TODO(paul): This handler currently contains a messy conflation of
    # low-level API that works on UserID objects and so on, and REST-level
    # API that takes ID strings and returns pagination chunks. These concerns
    # ought to be separated out a lot better.

    __metaclass__ = abc.ABCMeta

    def __init__(self, hs):
        """
        Args:
            hs (synapse.server.HomeServer):
        """
        self.hs = hs
        self.store = hs.get_datastore()
        self.auth = hs.get_auth()
        self.state_handler = hs.get_state_handler()
        self.config = hs.config
        self.simple_http_client = hs.get_simple_http_client()

        self.federation_handler = hs.get_handlers().federation_handler
        self.directory_handler = hs.get_handlers().directory_handler
        self.registration_handler = hs.get_handlers().registration_handler
        self.profile_handler = hs.get_profile_handler()
        # NOTE(review): "hander" is a long-standing misspelling; kept because
        # subclasses and other modules may reference the attribute by name.
        self.event_creation_hander = hs.get_event_creation_handler()

        # Serialises membership changes per room (keyed on room_id in
        # update_membership) so concurrent joins/leaves don't race.
        self.member_linearizer = Linearizer(name="member")

        self.clock = hs.get_clock()
        self.spam_checker = hs.get_spam_checker()
        self._server_notices_mxid = self.config.server_notices_mxid

    @abc.abstractmethod
    def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
        """Try and join a room that this server is not in

        Args:
            requester (Requester)
            remote_room_hosts (list[str]): List of servers that can be used
                to join via.
            room_id (str): Room that we are trying to join
            user (UserID): User who is trying to join
            content (dict): A dict that should be used as the content of the
                join event.

        Returns:
            Deferred
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target):
        """Attempt to reject an invite for a room this server is not in. If we
        fail to do so we locally mark the invite as rejected.

        Args:
            requester (Requester)
            remote_room_hosts (list[str]): List of servers to use to try and
                reject invite
            room_id (str)
            target (UserID): The user rejecting the invite

        Returns:
            Deferred[dict]: A dictionary to be returned to the client, may
            include event_id etc, or nothing if we locally rejected
        """
        # FIX: the abstract signature previously omitted `requester`, even
        # though the docstring documents it and the call site in
        # _update_membership passes it first — implementations following the
        # declared signature would fail with a TypeError when invoked.
        raise NotImplementedError()

    @abc.abstractmethod
    def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id):
        """Get a guest access token for a 3PID, creating a guest account if
        one doesn't already exist.

        Args:
            requester (Requester)
            medium (str)
            address (str)
            inviter_user_id (str): The user ID who is trying to invite the
                3PID

        Returns:
            Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
            3PID guest account.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def _user_joined_room(self, target, room_id):
        """Notifies distributor on master process that the user has joined the
        room.

        Args:
            target (UserID)
            room_id (str)

        Returns:
            Deferred|None
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def _user_left_room(self, target, room_id):
        """Notifies distributor on master process that the user has left the
        room.

        Args:
            target (UserID)
            room_id (str)

        Returns:
            Deferred|None
        """
        raise NotImplementedError()

    @defer.inlineCallbacks
    def _local_membership_update(
        self, requester, target, room_id, membership,
        prev_events_and_hashes,
        txn_id=None,
        ratelimit=True,
        content=None,
    ):
        """Create and send a membership event for a user in a room we are in.

        Returns:
            Deferred[SynapseEvent]: the (possibly deduplicated) membership
            event.
        """
        if content is None:
            content = {}

        content["membership"] = membership
        if requester.is_guest:
            content["kind"] = "guest"

        event, context = yield self.event_creation_hander.create_event(
            requester,
            {
                "type": EventTypes.Member,
                "content": content,
                "room_id": room_id,
                "sender": requester.user.to_string(),
                "state_key": target.to_string(),

                # For backwards compatibility:
                "membership": membership,
            },
            token_id=requester.access_token_id,
            txn_id=txn_id,
            prev_events_and_hashes=prev_events_and_hashes,
        )

        # Check if this event matches the previous membership event for the user.
        duplicate = yield self.event_creation_hander.deduplicate_state_event(
            event, context,
        )
        if duplicate is not None:
            # Discard the new event since this membership change is a no-op.
            defer.returnValue(duplicate)

        yield self.event_creation_hander.handle_new_client_event(
            requester,
            event,
            context,
            extra_users=[target],
            ratelimit=ratelimit,
        )

        prev_member_event_id = context.prev_state_ids.get(
            (EventTypes.Member, target.to_string()),
            None
        )

        if event.membership == Membership.JOIN:
            # Only fire user_joined_room if the user has actually joined the
            # room. Don't bother if the user is just changing their profile
            # info.
            newly_joined = True
            if prev_member_event_id:
                prev_member_event = yield self.store.get_event(prev_member_event_id)
                newly_joined = prev_member_event.membership != Membership.JOIN
            if newly_joined:
                yield self._user_joined_room(target, room_id)
        elif event.membership == Membership.LEAVE:
            if prev_member_event_id:
                prev_member_event = yield self.store.get_event(prev_member_event_id)
                if prev_member_event.membership == Membership.JOIN:
                    yield self._user_left_room(target, room_id)

        defer.returnValue(event)

    @defer.inlineCallbacks
    def update_membership(
        self,
        requester,
        target,
        room_id,
        action,
        txn_id=None,
        remote_room_hosts=None,
        third_party_signed=None,
        ratelimit=True,
        content=None,
    ):
        """Public entry point for membership changes: serialises per room via
        the member linearizer, then delegates to _update_membership.
        """
        key = (room_id,)

        with (yield self.member_linearizer.queue(key)):
            result = yield self._update_membership(
                requester,
                target,
                room_id,
                action,
                txn_id=txn_id,
                remote_room_hosts=remote_room_hosts,
                third_party_signed=third_party_signed,
                ratelimit=ratelimit,
                content=content,
            )

        defer.returnValue(result)

    @defer.inlineCallbacks
    def _update_membership(
        self,
        requester,
        target,
        room_id,
        action,
        txn_id=None,
        remote_room_hosts=None,
        third_party_signed=None,
        ratelimit=True,
        content=None,
    ):
        content_specified = bool(content)
        if content is None:
            content = {}
        else:
            # We do a copy here as we potentially change some keys
            # later on.
            content = dict(content)

        effective_membership_state = action
        if action in ["kick", "unban"]:
            effective_membership_state = "leave"

        # if this is a join with a 3pid signature, we may need to turn a 3pid
        # invite into a normal invite before we can handle the join.
        if third_party_signed is not None:
            yield self.federation_handler.exchange_third_party_invite(
                third_party_signed["sender"],
                target.to_string(),
                room_id,
                third_party_signed,
            )

        if not remote_room_hosts:
            remote_room_hosts = []

        if effective_membership_state not in ("leave", "ban",):
            is_blocked = yield self.store.is_room_blocked(room_id)
            if is_blocked:
                raise SynapseError(403, "This room has been blocked on this server")

        if effective_membership_state == Membership.INVITE:
            # block any attempts to invite the server notices mxid
            if target.to_string() == self._server_notices_mxid:
                raise SynapseError(
                    http_client.FORBIDDEN,
                    "Cannot invite this user",
                )

            block_invite = False

            if (self._server_notices_mxid is not None and
                    requester.user.to_string() == self._server_notices_mxid):
                # allow the server notices mxid to send invites
                is_requester_admin = True
            else:
                is_requester_admin = yield self.auth.is_server_admin(
                    requester.user,
                )

            if not is_requester_admin:
                if self.config.block_non_admin_invites:
                    logger.info(
                        "Blocking invite: user is not admin and non-admin "
                        "invites disabled"
                    )
                    block_invite = True

                if not self.spam_checker.user_may_invite(
                    requester.user.to_string(), target.to_string(), room_id,
                ):
                    logger.info("Blocking invite due to spam checker")
                    block_invite = True

            if block_invite:
                raise SynapseError(
                    403, "Invites have been disabled on this server",
                )

        prev_events_and_hashes = yield self.store.get_prev_events_for_room(
            room_id,
        )
        latest_event_ids = (
            event_id for (event_id, _, _) in prev_events_and_hashes
        )
        current_state_ids = yield self.state_handler.get_current_state_ids(
            room_id, latest_event_ids=latest_event_ids,
        )

        old_state_id = current_state_ids.get((EventTypes.Member, target.to_string()))
        if old_state_id:
            old_state = yield self.store.get_event(old_state_id, allow_none=True)
            old_membership = old_state.content.get("membership") if old_state else None
            if action == "unban" and old_membership != "ban":
                raise SynapseError(
                    403,
                    "Cannot unban user who was not banned"
                    " (membership=%s)" % old_membership,
                    errcode=Codes.BAD_STATE
                )
            if old_membership == "ban" and action != "unban":
                raise SynapseError(
                    403,
                    "Cannot %s user who was banned" % (action,),
                    errcode=Codes.BAD_STATE
                )

            if old_state:
                same_content = content == old_state.content
                same_membership = old_membership == effective_membership_state
                same_sender = requester.user.to_string() == old_state.sender
                if same_sender and same_membership and same_content:
                    # No-op membership change: return the existing event.
                    defer.returnValue(old_state)

            # we don't allow people to reject invites to the server notice
            # room, but they can leave it once they are joined.
            if (
                old_membership == Membership.INVITE and
                effective_membership_state == Membership.LEAVE
            ):
                is_blocked = yield self._is_server_notice_room(room_id)
                if is_blocked:
                    raise SynapseError(
                        http_client.FORBIDDEN,
                        "You cannot reject this invite",
                        errcode=Codes.CANNOT_LEAVE_SERVER_NOTICE_ROOM,
                    )

        is_host_in_room = yield self._is_host_in_room(current_state_ids)

        if effective_membership_state == Membership.JOIN:
            if requester.is_guest:
                guest_can_join = yield self._can_guest_join(current_state_ids)
                if not guest_can_join:
                    # This should be an auth check, but guests are a local concept,
                    # so don't really fit into the general auth process.
                    raise AuthError(403, "Guest access not allowed")

            if not is_host_in_room:
                # We're not in the room, so this must be a remote join.
                inviter = yield self._get_inviter(target.to_string(), room_id)
                if inviter and not self.hs.is_mine(inviter):
                    remote_room_hosts.append(inviter.domain)

                content["membership"] = Membership.JOIN

                profile = self.profile_handler
                if not content_specified:
                    content["displayname"] = yield profile.get_displayname(target)
                    content["avatar_url"] = yield profile.get_avatar_url(target)

                if requester.is_guest:
                    content["kind"] = "guest"

                ret = yield self._remote_join(
                    requester, remote_room_hosts, room_id, target, content
                )
                defer.returnValue(ret)

        elif effective_membership_state == Membership.LEAVE:
            if not is_host_in_room:
                # perhaps we've been invited
                inviter = yield self._get_inviter(target.to_string(), room_id)
                if not inviter:
                    raise SynapseError(404, "Not a known room")

                if self.hs.is_mine(inviter):
                    # the inviter was on our server, but has now left. Carry on
                    # with the normal rejection codepath.
                    #
                    # This is a bit of a hack, because the room might still be
                    # active on other servers.
                    pass
                else:
                    # send the rejection to the inviter's HS.
                    remote_room_hosts = remote_room_hosts + [inviter.domain]
                    res = yield self._remote_reject_invite(
                        requester, remote_room_hosts, room_id, target,
                    )
                    defer.returnValue(res)

        res = yield self._local_membership_update(
            requester=requester,
            target=target,
            room_id=room_id,
            membership=effective_membership_state,
            txn_id=txn_id,
            ratelimit=ratelimit,
            prev_events_and_hashes=prev_events_and_hashes,
            content=content,
        )
        defer.returnValue(res)

    @defer.inlineCallbacks
    def send_membership_event(
        self,
        requester,
        event,
        context,
        remote_room_hosts=None,
        ratelimit=True,
    ):
        """
        Change the membership status of a user in a room.

        Args:
            requester (Requester): The local user who requested the membership
                event. If None, certain checks, like whether this homeserver can
                act as the sender, will be skipped.
            event (SynapseEvent): The membership event.
            context: The context of the event.
            remote_room_hosts ([str]): Homeservers which are likely to already
                be in the room, and could be danced with in order to join this
                homeserver for the first time.
            ratelimit (bool): Whether to rate limit this request.
        Raises:
            SynapseError if there was a problem changing the membership.
        """
        remote_room_hosts = remote_room_hosts or []

        target_user = UserID.from_string(event.state_key)
        room_id = event.room_id

        if requester is not None:
            sender = UserID.from_string(event.sender)
            assert sender == requester.user, (
                "Sender (%s) must be same as requester (%s)" % (sender, requester.user)
            )
            assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
        else:
            requester = synapse.types.create_requester(target_user)

        prev_event = yield self.event_creation_hander.deduplicate_state_event(
            event, context,
        )
        if prev_event is not None:
            # Duplicate of the previous membership state: nothing to do.
            return

        if event.membership == Membership.JOIN:
            if requester.is_guest:
                guest_can_join = yield self._can_guest_join(context.prev_state_ids)
                if not guest_can_join:
                    # This should be an auth check, but guests are a local concept,
                    # so don't really fit into the general auth process.
                    raise AuthError(403, "Guest access not allowed")

        if event.membership not in (Membership.LEAVE, Membership.BAN):
            is_blocked = yield self.store.is_room_blocked(room_id)
            if is_blocked:
                raise SynapseError(403, "This room has been blocked on this server")

        yield self.event_creation_hander.handle_new_client_event(
            requester,
            event,
            context,
            extra_users=[target_user],
            ratelimit=ratelimit,
        )

        prev_member_event_id = context.prev_state_ids.get(
            (EventTypes.Member, event.state_key),
            None
        )

        if event.membership == Membership.JOIN:
            # Only fire user_joined_room if the user has actually joined the
            # room. Don't bother if the user is just changing their profile
            # info.
            newly_joined = True
            if prev_member_event_id:
                prev_member_event = yield self.store.get_event(prev_member_event_id)
                newly_joined = prev_member_event.membership != Membership.JOIN
            if newly_joined:
                yield self._user_joined_room(target_user, room_id)
        elif event.membership == Membership.LEAVE:
            if prev_member_event_id:
                prev_member_event = yield self.store.get_event(prev_member_event_id)
                if prev_member_event.membership == Membership.JOIN:
                    yield self._user_left_room(target_user, room_id)

    @defer.inlineCallbacks
    def _can_guest_join(self, current_state_ids):
        """
        Returns whether a guest can join a room based on its current state.
        """
        guest_access_id = current_state_ids.get((EventTypes.GuestAccess, ""), None)
        if not guest_access_id:
            defer.returnValue(False)

        guest_access = yield self.store.get_event(guest_access_id)

        defer.returnValue(
            guest_access
            and guest_access.content
            and "guest_access" in guest_access.content
            and guest_access.content["guest_access"] == "can_join"
        )

    @defer.inlineCallbacks
    def lookup_room_alias(self, room_alias):
        """
        Get the room ID associated with a room alias.

        Args:
            room_alias (RoomAlias): The alias to look up.
        Returns:
            A tuple of:
                The room ID as a RoomID object.
                Hosts likely to be participating in the room ([str]).
        Raises:
            SynapseError if room alias could not be found.
        """
        directory_handler = self.directory_handler
        mapping = yield directory_handler.get_association(room_alias)

        if not mapping:
            raise SynapseError(404, "No such room alias")

        room_id = mapping["room_id"]
        servers = mapping["servers"]

        defer.returnValue((RoomID.from_string(room_id), servers))

    @defer.inlineCallbacks
    def _get_inviter(self, user_id, room_id):
        """Returns the UserID of whoever invited user_id to room_id, or None
        if there is no outstanding invite.
        """
        invite = yield self.store.get_invite_for_user_in_room(
            user_id=user_id,
            room_id=room_id,
        )
        if invite:
            defer.returnValue(UserID.from_string(invite.sender))

    @defer.inlineCallbacks
    def do_3pid_invite(
        self,
        room_id,
        inviter,
        medium,
        address,
        id_server,
        requester,
        txn_id
    ):
        if self.config.block_non_admin_invites:
            is_requester_admin = yield self.auth.is_server_admin(
                requester.user,
            )
            if not is_requester_admin:
                raise SynapseError(
                    403, "Invites have been disabled on this server",
                    Codes.FORBIDDEN,
                )

        invitee = yield self._lookup_3pid(
            id_server, medium, address
        )

        if invitee:
            # The 3pid is already bound to a matrix ID: send a normal invite.
            yield self.update_membership(
                requester,
                UserID.from_string(invitee),
                room_id,
                "invite",
                txn_id=txn_id,
            )
        else:
            yield self._make_and_store_3pid_invite(
                requester,
                id_server,
                medium,
                address,
                room_id,
                inviter,
                txn_id=txn_id
            )

    @defer.inlineCallbacks
    def _lookup_3pid(self, id_server, medium, address):
        """Looks up a 3pid in the passed identity server.

        Args:
            id_server (str): The server name (including port, if required)
                of the identity server to use.
            medium (str): The type of the third party identifier (e.g. "email").
            address (str): The third party identifier (e.g. "*****@*****.**").

        Returns:
            str: the matrix ID of the 3pid, or None if it is not recognized.
        """
        try:
            data = yield self.simple_http_client.get_json(
                "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,),
                {
                    "medium": medium,
                    "address": address,
                }
            )

            if "mxid" in data:
                if "signatures" not in data:
                    raise AuthError(401, "No signatures on 3pid binding")
                yield self._verify_any_signature(data, id_server)
                defer.returnValue(data["mxid"])

        except IOError as e:
            # Best-effort: treat identity-server failure as "not recognised".
            logger.warn("Error from identity server lookup: %s" % (e,))
            defer.returnValue(None)

    @defer.inlineCallbacks
    def _verify_any_signature(self, data, server_hostname):
        """Checks that `data` carries at least one valid signature from
        `server_hostname`, fetching its public keys from the identity server.

        Raises:
            AuthError if no usable signature is found.
        """
        if server_hostname not in data["signatures"]:
            raise AuthError(401, "No signature from server %s" % (server_hostname,))
        for key_name, signature in data["signatures"][server_hostname].items():
            key_data = yield self.simple_http_client.get_json(
                "%s%s/_matrix/identity/api/v1/pubkey/%s" %
                (id_server_scheme, server_hostname, key_name,),
            )
            if "public_key" not in key_data:
                raise AuthError(401, "No public key named %s from %s" %
                                (key_name, server_hostname,))
            verify_signed_json(
                data,
                server_hostname,
                decode_verify_key_bytes(key_name, decode_base64(key_data["public_key"]))
            )
            # One verified signature is sufficient.
            return

    @defer.inlineCallbacks
    def _make_and_store_3pid_invite(
        self,
        requester,
        id_server,
        medium,
        address,
        room_id,
        user,
        txn_id
    ):
        room_state = yield self.state_handler.get_current_state(room_id)

        # Gather cosmetic room/inviter details to include in the invite email.
        inviter_display_name = ""
        inviter_avatar_url = ""
        member_event = room_state.get((EventTypes.Member, user.to_string()))
        if member_event:
            inviter_display_name = member_event.content.get("displayname", "")
            inviter_avatar_url = member_event.content.get("avatar_url", "")

        canonical_room_alias = ""
        canonical_alias_event = room_state.get((EventTypes.CanonicalAlias, ""))
        if canonical_alias_event:
            canonical_room_alias = canonical_alias_event.content.get("alias", "")

        room_name = ""
        room_name_event = room_state.get((EventTypes.Name, ""))
        if room_name_event:
            room_name = room_name_event.content.get("name", "")

        room_join_rules = ""
        join_rules_event = room_state.get((EventTypes.JoinRules, ""))
        if join_rules_event:
            room_join_rules = join_rules_event.content.get("join_rule", "")

        room_avatar_url = ""
        room_avatar_event = room_state.get((EventTypes.RoomAvatar, ""))
        if room_avatar_event:
            room_avatar_url = room_avatar_event.content.get("url", "")

        token, public_keys, fallback_public_key, display_name = (
            yield self._ask_id_server_for_third_party_invite(
                requester=requester,
                id_server=id_server,
                medium=medium,
                address=address,
                room_id=room_id,
                inviter_user_id=user.to_string(),
                room_alias=canonical_room_alias,
                room_avatar_url=room_avatar_url,
                room_join_rules=room_join_rules,
                room_name=room_name,
                inviter_display_name=inviter_display_name,
                inviter_avatar_url=inviter_avatar_url
            )
        )

        yield self.event_creation_hander.create_and_send_nonmember_event(
            requester,
            {
                "type": EventTypes.ThirdPartyInvite,
                "content": {
                    "display_name": display_name,
                    "public_keys": public_keys,

                    # For backwards compatibility:
                    "key_validity_url": fallback_public_key["key_validity_url"],
                    "public_key": fallback_public_key["public_key"],
                },
                "room_id": room_id,
                "sender": user.to_string(),
                "state_key": token,
            },
            txn_id=txn_id,
        )

    @defer.inlineCallbacks
    def _ask_id_server_for_third_party_invite(
            self,
            requester,
            id_server,
            medium,
            address,
            room_id,
            inviter_user_id,
            room_alias,
            room_avatar_url,
            room_join_rules,
            room_name,
            inviter_display_name,
            inviter_avatar_url
    ):
        """
        Asks an identity server for a third party invite.

        Args:
            requester (Requester)
            id_server (str): hostname + optional port for the identity server.
            medium (str): The literal string "email".
            address (str): The third party address being invited.
            room_id (str): The ID of the room to which the user is invited.
            inviter_user_id (str): The user ID of the inviter.
            room_alias (str): An alias for the room, for cosmetic notifications.
            room_avatar_url (str): The URL of the room's avatar, for cosmetic
                notifications.
            room_join_rules (str): The join rules of the email (e.g. "public").
            room_name (str): The m.room.name of the room.
            inviter_display_name (str): The current display name of the
                inviter.
            inviter_avatar_url (str): The URL of the inviter's avatar.

        Returns:
            A deferred tuple containing:
                token (str): The token which must be signed to prove authenticity.
                public_keys ([{"public_key": str, "key_validity_url": str}]):
                    public_key is a base64-encoded ed25519 public key.
                fallback_public_key: One element from public_keys.
                display_name (str): A user-friendly name to represent the invited
                    user.
        """
        is_url = "%s%s/_matrix/identity/api/v1/store-invite" % (
            id_server_scheme, id_server,
        )

        invite_config = {
            "medium": medium,
            "address": address,
            "room_id": room_id,
            "room_alias": room_alias,
            "room_avatar_url": room_avatar_url,
            "room_join_rules": room_join_rules,
            "room_name": room_name,
            "sender": inviter_user_id,
            "sender_display_name": inviter_display_name,
            "sender_avatar_url": inviter_avatar_url,
        }

        if self.config.invite_3pid_guest:
            guest_access_token, guest_user_id = yield self.get_or_register_3pid_guest(
                requester=requester,
                medium=medium,
                address=address,
                inviter_user_id=inviter_user_id,
            )

            invite_config.update({
                "guest_access_token": guest_access_token,
                "guest_user_id": guest_user_id,
            })

        data = yield self.simple_http_client.post_urlencoded_get_json(
            is_url,
            invite_config
        )
        # TODO: Check for success
        token = data["token"]
        public_keys = data.get("public_keys", [])
        if "public_key" in data:
            # Old-style single-key response: synthesise the new-style entry.
            fallback_public_key = {
                "public_key": data["public_key"],
                "key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
                    id_server_scheme, id_server,
                ),
            }
        else:
            fallback_public_key = public_keys[0]

        if not public_keys:
            public_keys.append(fallback_public_key)
        display_name = data["display_name"]
        defer.returnValue((token, public_keys, fallback_public_key, display_name))

    @defer.inlineCallbacks
    def _is_host_in_room(self, current_state_ids):
        # Have we just created the room, and is this about to be the very
        # first member event?
        create_event_id = current_state_ids.get(("m.room.create", ""))
        if len(current_state_ids) == 1 and create_event_id:
            defer.returnValue(self.hs.is_mine_id(create_event_id))

        for etype, state_key in current_state_ids:
            if etype != EventTypes.Member or not self.hs.is_mine_id(state_key):
                continue

            event_id = current_state_ids[(etype, state_key)]
            event = yield self.store.get_event(event_id, allow_none=True)
            if not event:
                continue

            if event.membership == Membership.JOIN:
                defer.returnValue(True)

        defer.returnValue(False)

    @defer.inlineCallbacks
    def _is_server_notice_room(self, room_id):
        if self._server_notices_mxid is None:
            defer.returnValue(False)
        user_ids = yield self.store.get_users_in_room(room_id)
        defer.returnValue(self._server_notices_mxid in user_ids)
class StateResolutionHandler(object):
    """Responsible for doing state conflict resolution.

    Note that the storage layer depends on this handler, so all functions must
    be storage-independent.
    """
    def __init__(self, hs):
        self.clock = hs.get_clock()

        # dict of set of event_ids -> _StateCacheEntry.
        self._state_cache = None
        self.resolve_linearizer = Linearizer(name="state_resolve_lock")

    def start_caching(self):
        """Creates and starts the (expiring) resolution cache; until this is
        called, resolutions are computed but never cached."""
        logger.debug("start_caching")

        self._state_cache = ExpiringCache(
            cache_name="state_cache",
            clock=self.clock,
            max_len=SIZE_OF_CACHE,
            expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
            iterable=True,
            reset_expiry_on_get=True,
        )

        self._state_cache.start()

    @defer.inlineCallbacks
    @log_function
    def resolve_state_groups(
        self, room_id, state_groups_ids, event_map, state_map_factory,
    ):
        """Resolves conflicts between a set of state groups

        Always generates a new state group (unless we hit the cache), so should
        not be called for a single state group

        Args:
            room_id (str): room we are resolving for (used for logging)

            state_groups_ids (dict[int, dict[(str, str), str]]):
                 map from state group id to the state in that state group
                (where 'state' is a map from state key to event id)

            event_map(dict[str,FrozenEvent]|None):
                a dict from event_id to event, for any events that we happen to
                have in flight (eg, those currently being persisted). This will be
                used as a starting point fof finding the state we need; any missing
                events will be requested via state_map_factory.

                If None, all events will be fetched via state_map_factory.

        Returns:
            Deferred[_StateCacheEntry]: resolved state
        """
        logger.debug(
            "resolve_state_groups state_groups %s",
            state_groups_ids.keys()
        )

        group_names = frozenset(state_groups_ids.keys())

        with (yield self.resolve_linearizer.queue(group_names)):
            if self._state_cache is not None:
                cache = self._state_cache.get(group_names, None)
                if cache:
                    defer.returnValue(cache)

            logger.info(
                "Resolving state for %s with %d groups", room_id, len(state_groups_ids)
            )

            # build a map from state key to the event_ids which set that state.
            # dict[(str, str), set[str])
            state = {}
            for st in state_groups_ids.itervalues():
                for key, e_id in st.iteritems():
                    state.setdefault(key, set()).add(e_id)

            # build a map from state key to the event_ids which set that state,
            # including only those where there are state keys in conflict.
            conflicted_state = {
                k: list(v)
                for k, v in state.iteritems()
                if len(v) > 1
            }

            if conflicted_state:
                logger.info("Resolving conflicted state for %r", room_id)
                with Measure(self.clock, "state._resolve_events"):
                    new_state = yield resolve_events_with_factory(
                        state_groups_ids.values(),
                        event_map=event_map,
                        state_map_factory=state_map_factory,
                    )
            else:
                # No conflicts: each state key has exactly one event id.
                new_state = {
                    key: e_ids.pop() for key, e_ids in state.iteritems()
                }

            with Measure(self.clock, "state.create_group_ids"):
                # if the new state matches any of the input state groups, we can
                # use that state group again. Otherwise we will generate a state_id
                # which will be used as a cache key for future resolutions, but
                # not get persisted.
                state_group = None
                new_state_event_ids = frozenset(new_state.itervalues())
                for sg, events in state_groups_ids.iteritems():
                    # FIX: `events` is a dict from state key to event id;
                    # iterating it directly yields the (type, state_key) KEYS,
                    # which could never equal a set of event-id strings — the
                    # "reuse existing group" optimisation was dead code. We
                    # must compare against the event ids (the values).
                    if new_state_event_ids == frozenset(events.itervalues()):
                        state_group = sg
                        break

                # TODO: We want to create a state group for this set of events, to
                # increase cache hits, but we need to make sure that it doesn't
                # end up as a prev_group without being added to the database

                # Pick as prev_group the input group needing the smallest delta
                # to reach the new state, provided the new state only touches
                # keys that group already has.
                prev_group = None
                delta_ids = None
                for old_group, old_ids in state_groups_ids.iteritems():
                    if not set(new_state) - set(old_ids):
                        n_delta_ids = {
                            k: v
                            for k, v in new_state.iteritems()
                            if old_ids.get(k) != v
                        }
                        if not delta_ids or len(n_delta_ids) < len(delta_ids):
                            prev_group = old_group
                            delta_ids = n_delta_ids

            cache = _StateCacheEntry(
                state=new_state,
                state_group=state_group,
                prev_group=prev_group,
                delta_ids=delta_ids,
            )

            if self._state_cache is not None:
                self._state_cache[group_names] = cache

            defer.returnValue(cache)
class FederationServer(FederationBase): def __init__(self, hs): super(FederationServer, self).__init__(hs) self._room_pdu_linearizer = Linearizer() def set_handler(self, handler): """Sets the handler that the replication layer will use to communicate receipt of new PDUs from other home servers. The required methods are documented on :py:class:`.ReplicationHandler`. """ self.handler = handler def register_edu_handler(self, edu_type, handler): if edu_type in self.edu_handlers: raise KeyError("Already have an EDU handler for %s" % (edu_type,)) self.edu_handlers[edu_type] = handler def register_query_handler(self, query_type, handler): """Sets the handler callable that will be used to handle an incoming federation Query of the given type. Args: query_type (str): Category name of the query, which should match the string used by make_query. handler (callable): Invoked to handle incoming queries of this type handler is invoked as: result = handler(args) where 'args' is a dict mapping strings to strings of the query arguments. It should return a Deferred that will eventually yield an object to encode as JSON. 
""" if query_type in self.query_handlers: raise KeyError( "Already have a Query handler for %s" % (query_type,) ) self.query_handlers[query_type] = handler @defer.inlineCallbacks @log_function def on_backfill_request(self, origin, room_id, versions, limit): pdus = yield self.handler.on_backfill_request( origin, room_id, versions, limit ) defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict())) @defer.inlineCallbacks @log_function def on_incoming_transaction(self, transaction_data): transaction = Transaction(**transaction_data) received_pdus_counter.inc_by(len(transaction.pdus)) for p in transaction.pdus: if "unsigned" in p: unsigned = p["unsigned"] if "age" in unsigned: p["age"] = unsigned["age"] if "age" in p: p["age_ts"] = int(self._clock.time_msec()) - int(p["age"]) del p["age"] pdu_list = [ self.event_from_pdu_json(p) for p in transaction.pdus ] logger.debug("[%s] Got transaction", transaction.transaction_id) response = yield self.transaction_actions.have_responded(transaction) if response: logger.debug( "[%s] We've already responed to this request", transaction.transaction_id ) defer.returnValue(response) return logger.debug("[%s] Transaction is new", transaction.transaction_id) results = [] for pdu in pdu_list: try: yield self._handle_new_pdu(transaction.origin, pdu) results.append({}) except FederationError as e: self.send_failure(e, transaction.origin) results.append({"error": str(e)}) except Exception as e: results.append({"error": str(e)}) logger.exception("Failed to handle PDU") if hasattr(transaction, "edus"): for edu in (Edu(**x) for x in transaction.edus): yield self.received_edu( transaction.origin, edu.edu_type, edu.content ) for failure in getattr(transaction, "pdu_failures", []): logger.info("Got failure %r", failure) logger.debug("Returning: %s", str(results)) response = { "pdus": dict(zip( (p.event_id for p in pdu_list), results )), } yield self.transaction_actions.set_response( transaction, 200, response ) defer.returnValue((200, 
response)) @defer.inlineCallbacks def received_edu(self, origin, edu_type, content): received_edus_counter.inc() if edu_type in self.edu_handlers: try: yield self.edu_handlers[edu_type](origin, content) except SynapseError as e: logger.info("Failed to handle edu %r: %r", edu_type, e) except Exception as e: logger.exception("Failed to handle edu %r", edu_type, e) else: logger.warn("Received EDU of type %s with no handler", edu_type) @defer.inlineCallbacks @log_function def on_context_state_request(self, origin, room_id, event_id): if event_id: pdus = yield self.handler.get_state_for_pdu( origin, room_id, event_id, ) auth_chain = yield self.store.get_auth_chain( [pdu.event_id for pdu in pdus] ) for event in auth_chain: # We sign these again because there was a bug where we # incorrectly signed things the first time round if self.hs.is_mine_id(event.event_id): event.signatures.update( compute_event_signature( event, self.hs.hostname, self.hs.config.signing_key[0] ) ) else: raise NotImplementedError("Specify an event") defer.returnValue((200, { "pdus": [pdu.get_pdu_json() for pdu in pdus], "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain], })) @defer.inlineCallbacks @log_function def on_pdu_request(self, origin, event_id): pdu = yield self._get_persisted_pdu(origin, event_id) if pdu: defer.returnValue( (200, self._transaction_from_pdus([pdu]).get_dict()) ) else: defer.returnValue((404, "")) @defer.inlineCallbacks @log_function def on_pull_request(self, origin, versions): raise NotImplementedError("Pull transactions not implemented") @defer.inlineCallbacks def on_query_request(self, query_type, args): received_queries_counter.inc(query_type) if query_type in self.query_handlers: response = yield self.query_handlers[query_type](args) defer.returnValue((200, response)) else: defer.returnValue( (404, "No handler for Query type '%s'" % (query_type,)) ) @defer.inlineCallbacks def on_make_join_request(self, room_id, user_id): pdu = yield 
self.handler.on_make_join_request(room_id, user_id) time_now = self._clock.time_msec() defer.returnValue({"event": pdu.get_pdu_json(time_now)}) @defer.inlineCallbacks def on_invite_request(self, origin, content): pdu = self.event_from_pdu_json(content) ret_pdu = yield self.handler.on_invite_request(origin, pdu) time_now = self._clock.time_msec() defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)})) @defer.inlineCallbacks def on_send_join_request(self, origin, content): logger.debug("on_send_join_request: content: %s", content) pdu = self.event_from_pdu_json(content) logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures) res_pdus = yield self.handler.on_send_join_request(origin, pdu) time_now = self._clock.time_msec() defer.returnValue((200, { "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]], "auth_chain": [ p.get_pdu_json(time_now) for p in res_pdus["auth_chain"] ], })) @defer.inlineCallbacks def on_make_leave_request(self, room_id, user_id): pdu = yield self.handler.on_make_leave_request(room_id, user_id) time_now = self._clock.time_msec() defer.returnValue({"event": pdu.get_pdu_json(time_now)}) @defer.inlineCallbacks def on_send_leave_request(self, origin, content): logger.debug("on_send_leave_request: content: %s", content) pdu = self.event_from_pdu_json(content) logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures) yield self.handler.on_send_leave_request(origin, pdu) defer.returnValue((200, {})) @defer.inlineCallbacks def on_event_auth(self, origin, room_id, event_id): time_now = self._clock.time_msec() auth_pdus = yield self.handler.on_event_auth(event_id) defer.returnValue((200, { "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus], })) @defer.inlineCallbacks def on_query_auth_request(self, origin, content, event_id): """ Content is a dict with keys:: auth_chain (list): A list of events that give the auth chain. 
missing (list): A list of event_ids indicating what the other side (`origin`) think we're missing. rejects (dict): A mapping from event_id to a 2-tuple of reason string and a proof (or None) of why the event was rejected. The keys of this dict give the list of events the `origin` has rejected. Args: origin (str) content (dict) event_id (str) Returns: Deferred: Results in `dict` with the same format as `content` """ auth_chain = [ self.event_from_pdu_json(e) for e in content["auth_chain"] ] signed_auth = yield self._check_sigs_and_hash_and_fetch( origin, auth_chain, outlier=True ) ret = yield self.handler.on_query_auth( origin, event_id, signed_auth, content.get("rejects", []), content.get("missing", []), ) time_now = self._clock.time_msec() send_content = { "auth_chain": [ e.get_pdu_json(time_now) for e in ret["auth_chain"] ], "rejects": ret.get("rejects", []), "missing": ret.get("missing", []), } defer.returnValue( (200, send_content) ) @defer.inlineCallbacks @log_function def on_query_client_keys(self, origin, content): query = [] for user_id, device_ids in content.get("device_keys", {}).items(): if not device_ids: query.append((user_id, None)) else: for device_id in device_ids: query.append((user_id, device_id)) results = yield self.store.get_e2e_device_keys(query) json_result = {} for user_id, device_keys in results.items(): for device_id, json_bytes in device_keys.items(): json_result.setdefault(user_id, {})[device_id] = json.loads( json_bytes ) defer.returnValue({"device_keys": json_result}) @defer.inlineCallbacks @log_function def on_claim_client_keys(self, origin, content): query = [] for user_id, device_keys in content.get("one_time_keys", {}).items(): for device_id, algorithm in device_keys.items(): query.append((user_id, device_id, algorithm)) results = yield self.store.claim_e2e_one_time_keys(query) json_result = {} for user_id, device_keys in results.items(): for device_id, keys in device_keys.items(): for key_id, json_bytes in keys.items(): 
json_result.setdefault(user_id, {})[device_id] = { key_id: json.loads(json_bytes) } defer.returnValue({"one_time_keys": json_result}) @defer.inlineCallbacks @log_function def on_get_missing_events(self, origin, room_id, earliest_events, latest_events, limit, min_depth): logger.info( "on_get_missing_events: earliest_events: %r, latest_events: %r," " limit: %d, min_depth: %d", earliest_events, latest_events, limit, min_depth ) missing_events = yield self.handler.on_get_missing_events( origin, room_id, earliest_events, latest_events, limit, min_depth ) if len(missing_events) < 5: logger.info("Returning %d events: %r", len(missing_events), missing_events) else: logger.info("Returning %d events", len(missing_events)) time_now = self._clock.time_msec() defer.returnValue({ "events": [ev.get_pdu_json(time_now) for ev in missing_events], }) @log_function def on_openid_userinfo(self, token): ts_now_ms = self._clock.time_msec() return self.store.get_user_id_for_open_id_token(token, ts_now_ms) @log_function def _get_persisted_pdu(self, origin, event_id, do_auth=True): """ Get a PDU from the database with given origin and id. Returns: Deferred: Results in a `Pdu`. """ return self.handler.get_persisted_pdu( origin, event_id, do_auth=do_auth ) def _transaction_from_pdus(self, pdu_list): """Returns a new Transaction containing the given PDUs suitable for transmission. """ time_now = self._clock.time_msec() pdus = [p.get_pdu_json(time_now) for p in pdu_list] return Transaction( origin=self.server_name, pdus=pdus, origin_server_ts=int(time_now), destination=None, ) @defer.inlineCallbacks @log_function def _handle_new_pdu(self, origin, pdu, get_missing=True): # We reprocess pdus when we have seen them only as outliers existing = yield self._get_persisted_pdu( origin, pdu.event_id, do_auth=False ) # FIXME: Currently we fetch an event again when we already have it # if it has been marked as an outlier. 
already_seen = ( existing and ( not existing.internal_metadata.is_outlier() or pdu.internal_metadata.is_outlier() ) ) if already_seen: logger.debug("Already seen pdu %s", pdu.event_id) return # Check signature. try: pdu = yield self._check_sigs_and_hash(pdu) except SynapseError as e: raise FederationError( "ERROR", e.code, e.msg, affected=pdu.event_id, ) state = None auth_chain = [] have_seen = yield self.store.have_events( [ev for ev, _ in pdu.prev_events] ) fetch_state = False # Get missing pdus if necessary. if not pdu.internal_metadata.is_outlier(): # We only backfill backwards to the min depth. min_depth = yield self.handler.get_min_depth_for_context( pdu.room_id ) logger.debug( "_handle_new_pdu min_depth for %s: %d", pdu.room_id, min_depth ) prevs = {e_id for e_id, _ in pdu.prev_events} seen = set(have_seen.keys()) if min_depth and pdu.depth < min_depth: # This is so that we don't notify the user about this # message, to work around the fact that some events will # reference really really old events we really don't want to # send to the clients. pdu.internal_metadata.outlier = True elif min_depth and pdu.depth > min_depth: if get_missing and prevs - seen: # If we're missing stuff, ensure we only fetch stuff one # at a time. with (yield self._room_pdu_linearizer.queue(pdu.room_id)): # We recalculate seen, since it may have changed. 
have_seen = yield self.store.have_events(prevs) seen = set(have_seen.keys()) if prevs - seen: latest = yield self.store.get_latest_event_ids_in_room( pdu.room_id ) # We add the prev events that we have seen to the latest # list to ensure the remote server doesn't give them to us latest = set(latest) latest |= seen logger.info( "Missing %d events for room %r: %r...", len(prevs - seen), pdu.room_id, list(prevs - seen)[:5] ) missing_events = yield self.get_missing_events( origin, pdu.room_id, earliest_events_ids=list(latest), latest_events=[pdu], limit=10, min_depth=min_depth, ) # We want to sort these by depth so we process them and # tell clients about them in order. missing_events.sort(key=lambda x: x.depth) for e in missing_events: yield self._handle_new_pdu( origin, e, get_missing=False ) have_seen = yield self.store.have_events( [ev for ev, _ in pdu.prev_events] ) prevs = {e_id for e_id, _ in pdu.prev_events} seen = set(have_seen.keys()) if prevs - seen: logger.info( "Still missing %d events for room %r: %r...", len(prevs - seen), pdu.room_id, list(prevs - seen)[:5] ) fetch_state = True if fetch_state: # We need to get the state at this event, since we haven't # processed all the prev events. 
logger.debug( "_handle_new_pdu getting state for %s", pdu.room_id ) try: state, auth_chain = yield self.get_state_for_room( origin, pdu.room_id, pdu.event_id, ) except: logger.warn("Failed to get state for event: %s", pdu.event_id) yield self.handler.on_receive_pdu( origin, pdu, state=state, auth_chain=auth_chain, ) def __str__(self): return "<ReplicationLayer(%s)>" % self.server_name def event_from_pdu_json(self, pdu_json, outlier=False): event = FrozenEvent( pdu_json ) event.internal_metadata.outlier = outlier return event @defer.inlineCallbacks def exchange_third_party_invite( self, sender_user_id, target_user_id, room_id, signed, ): ret = yield self.handler.exchange_third_party_invite( sender_user_id, target_user_id, room_id, signed, ) defer.returnValue(ret) @defer.inlineCallbacks def on_exchange_third_party_invite_request(self, origin, room_id, event_dict): ret = yield self.handler.on_exchange_third_party_invite_request( origin, room_id, event_dict ) defer.returnValue(ret)
class DeviceListEduUpdater(object):
    "Handles incoming device list updates from federation and updates the DB"

    def __init__(self, hs, device_handler):
        self.store = hs.get_datastore()
        self.federation = hs.get_federation_client()
        self.clock = hs.get_clock()
        self.device_handler = device_handler

        # Serialises update handling per user_id, so batches for the same
        # user are processed one at a time.
        self._remote_edu_linearizer = Linearizer(name="remote_device_list")

        # user_id -> list of updates waiting to be handled.
        self._pending_updates = {}

        # Recently seen stream ids. We don't bother keeping these in the DB,
        # but they're useful to have them about to reduce the number of spurious
        # resyncs.
        self._seen_updates = ExpiringCache(
            cache_name="device_update_edu",
            clock=self.clock,
            max_len=10000,
            expiry_ms=30 * 60 * 1000,
            iterable=True,
        )

    @defer.inlineCallbacks
    def incoming_device_list_update(self, origin, edu_content):
        """Called on incoming device list update from federation. Responsible
        for parsing the EDU and adding to pending updates list.

        Updates claiming to be for a user not belonging to `origin`, or for a
        user we share no rooms with, are dropped.
        """
        user_id = edu_content.pop("user_id")
        device_id = edu_content.pop("device_id")
        stream_id = str(edu_content.pop("stream_id"))  # They may come as ints
        prev_ids = edu_content.pop("prev_id", [])
        prev_ids = [str(p) for p in prev_ids]  # They may come as ints

        if get_domain_from_id(user_id) != origin:
            # A server may only send updates about its own users.
            # TODO: Raise?
            logger.warning("Got device list update edu for %r from %r", user_id, origin)
            return

        room_ids = yield self.store.get_rooms_for_user(user_id)
        if not room_ids:
            # We don't share any rooms with this user. Ignore update, as we
            # probably won't get any further updates.
            return

        self._pending_updates.setdefault(user_id, []).append(
            (device_id, stream_id, prev_ids, edu_content)
        )

        yield self._handle_device_updates(user_id)

    @measure_func("_incoming_device_list_update")
    @defer.inlineCallbacks
    def _handle_device_updates(self, user_id):
        "Actually handle pending updates."

        with (yield self._remote_edu_linearizer.queue(user_id)):
            pending_updates = self._pending_updates.pop(user_id, [])
            if not pending_updates:
                # This can happen since we batch updates
                return

            # Given a list of updates we check if we need to resync. This
            # happens if we've missed updates.
            resync = yield self._need_to_do_resync(user_id, pending_updates)

            if resync:
                # Fetch all devices for the user.
                origin = get_domain_from_id(user_id)
                try:
                    result = yield self.federation.query_user_devices(
                        origin, user_id)
                except NotRetryingDestination:
                    # TODO: Remember that we are now out of sync and try again
                    # later
                    logger.warn(
                        "Failed to handle device list update for %s,"
                        " we're not retrying the remote",
                        user_id,
                    )
                    # We abort on exceptions rather than accepting the update
                    # as otherwise synapse will 'forget' that its device list
                    # is out of date. If we bail then we will retry the resync
                    # next time we get a device list update for this user_id.
                    # This makes it more likely that the device lists will
                    # eventually become consistent.
                    return
                except FederationDeniedError as e:
                    logger.info(e)
                    return
                except Exception:
                    # TODO: Remember that we are now out of sync and try again
                    # later
                    logger.exception(
                        "Failed to handle device list update for %s", user_id)
                    return

                stream_id = result["stream_id"]
                devices = result["devices"]
                yield self.store.update_remote_device_list_cache(
                    user_id, devices, stream_id,
                )
                device_ids = [device["device_id"] for device in devices]
                yield self.device_handler.notify_device_update(
                    user_id, device_ids)
            else:
                # Simply update the single device, since we know that is the only
                # change (because of the single prev_id matching the current cache)
                for device_id, stream_id, prev_ids, content in pending_updates:
                    yield self.store.update_remote_device_list_cache_entry(
                        user_id, device_id, content, stream_id,
                    )

                yield self.device_handler.notify_device_update(
                    user_id, [device_id for device_id, _, _, _ in pending_updates])

            # Record the stream ids we've now processed, whichever path we
            # took, so _need_to_do_resync can treat them as known.
            self._seen_updates.setdefault(user_id, set()).update(
                stream_id for _, stream_id, _, _ in pending_updates)

    @defer.inlineCallbacks
    def _need_to_do_resync(self, user_id, updates):
        """Given a list of updates for a user figure out if we need to do a
        full resync, or whether we have enough data that we can just apply
        the delta.

        Returns True (via Deferred) when any update's prev_ids contain a
        stream id we haven't seen — i.e. we've missed updates.
        """
        seen_updates = self._seen_updates.get(user_id, set())

        extremity = yield self.store.get_device_list_last_stream_id_for_remote(
            user_id)

        stream_id_in_updates = set()  # stream_ids in updates list
        for _, stream_id, prev_ids, _ in updates:
            if not prev_ids:
                # We always do a resync if there are no previous IDs
                defer.returnValue(True)

            for prev_id in prev_ids:
                if prev_id == extremity:
                    continue
                elif prev_id in seen_updates:
                    continue
                elif prev_id in stream_id_in_updates:
                    continue
                else:
                    # Unknown prev_id: we've missed an update somewhere.
                    defer.returnValue(True)

            stream_id_in_updates.add(stream_id)

        defer.returnValue(False)
def __init__(self, hs):
    """Wire up the read-marker handler's dependencies from the homeserver."""
    super(ReadMarkerHandler, self).__init__(hs)

    # Serialises read-marker updates so concurrent requests don't race.
    self.read_marker_linearizer = Linearizer(name="read_marker")

    self.server_name = hs.config.server_name
    self.store = hs.get_datastore()
    self.notifier = hs.get_notifier()
def __init__(self, hs):
    """Initialise state-resolution bookkeeping for this homeserver."""
    self.clock = hs.get_clock()

    # dict of set of event_ids -> _StateCacheEntry. Starts out as None;
    # presumably enabled elsewhere before use (callers check for None).
    self._state_cache = None

    # Ensures only one resolution runs at a time for a given group set.
    self.resolve_linearizer = Linearizer(name="state_resolve_lock")
class RoomMemberHandler(BaseHandler):
    # TODO(paul): This handler currently contains a messy conflation of
    # low-level API that works on UserID objects and so on, and REST-level
    # API that takes ID strings and returns pagination chunks. These concerns
    # ought to be separated out a lot better.

    def __init__(self, hs):
        super(RoomMemberHandler, self).__init__(hs)

        # Serialises membership changes per (target, room_id) key.
        self.member_linearizer = Linearizer()

        self.clock = hs.get_clock()

        self.distributor = hs.get_distributor()
        self.distributor.declare("user_joined_room")
        self.distributor.declare("user_left_room")

    @defer.inlineCallbacks
    def _local_membership_update(
        self, requester, target, room_id, membership,
        prev_event_ids,
        txn_id=None,
        ratelimit=True,
    ):
        # Builds and sends an m.room.member state event for `target`, then
        # fires the joined/left distributor signals on actual transitions.
        msg_handler = self.hs.get_handlers().message_handler

        content = {"membership": membership}
        if requester.is_guest:
            content["kind"] = "guest"

        event, context = yield msg_handler.create_event(
            {
                "type": EventTypes.Member,
                "content": content,
                "room_id": room_id,
                "sender": requester.user.to_string(),
                "state_key": target.to_string(),

                # For backwards compatibility:
                "membership": membership,
            },
            token_id=requester.access_token_id,
            txn_id=txn_id,
            prev_event_ids=prev_event_ids,
        )

        yield msg_handler.handle_new_client_event(
            requester,
            event,
            context,
            extra_users=[target],
            ratelimit=ratelimit,
        )

        prev_member_event = context.current_state.get(
            (EventTypes.Member, target.to_string()),
            None
        )

        if event.membership == Membership.JOIN:
            if not prev_member_event or prev_member_event.membership != Membership.JOIN:
                # Only fire user_joined_room if the user has actually joined the
                # room. Don't bother if the user is just changing their profile
                # info.
                yield user_joined_room(self.distributor, target, room_id)
        elif event.membership == Membership.LEAVE:
            if prev_member_event and prev_member_event.membership == Membership.JOIN:
                user_left_room(self.distributor, target, room_id)

    @defer.inlineCallbacks
    def remote_join(self, remote_room_hosts, room_id, user, content):
        # Joins `user` to `room_id` via one of the given remote hosts.
        if len(remote_room_hosts) == 0:
            raise SynapseError(404, "No known servers")

        # We don't do an auth check if we are doing an invite
        # join dance for now, since we're kinda implicitly checking
        # that we are allowed to join when we decide whether or not we
        # need to do the invite/join dance.
        yield self.hs.get_handlers().federation_handler.do_invite_join(
            remote_room_hosts,
            room_id,
            user.to_string(),
            content,
        )
        yield user_joined_room(self.distributor, user, room_id)

    def reject_remote_invite(self, user_id, room_id, remote_room_hosts):
        # Sends a rejection of an invite over federation via the given hosts.
        return self.hs.get_handlers().federation_handler.do_remotely_reject_invite(
            remote_room_hosts,
            room_id,
            user_id
        )

    @defer.inlineCallbacks
    def update_membership(
            self,
            requester,
            target,
            room_id,
            action,
            txn_id=None,
            remote_room_hosts=None,
            third_party_signed=None,
            ratelimit=True,
    ):
        # Public entry point: serialises membership changes for the same
        # (target, room) pair, then delegates to _update_membership.
        key = (target, room_id,)

        with (yield self.member_linearizer.queue(key)):
            result = yield self._update_membership(
                requester,
                target,
                room_id,
                action,
                txn_id=txn_id,
                remote_room_hosts=remote_room_hosts,
                third_party_signed=third_party_signed,
                ratelimit=ratelimit,
            )

        defer.returnValue(result)

    @defer.inlineCallbacks
    def _update_membership(
            self,
            requester,
            target,
            room_id,
            action,
            txn_id=None,
            remote_room_hosts=None,
            third_party_signed=None,
            ratelimit=True,
    ):
        # "kick" and "unban" both result in a "leave" membership state.
        effective_membership_state = action
        if action in ["kick", "unban"]:
            effective_membership_state = "leave"

        if third_party_signed is not None:
            replication = self.hs.get_replication_layer()
            yield replication.exchange_third_party_invite(
                third_party_signed["sender"],
                target.to_string(),
                room_id,
                third_party_signed,
            )

        if not remote_room_hosts:
            remote_room_hosts = []

        latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
        current_state = yield self.state_handler.get_current_state(
            room_id, latest_event_ids=latest_event_ids,
        )

        # Validate ban/unban transitions against the target's current membership.
        old_state = current_state.get((EventTypes.Member, target.to_string()))
        old_membership = old_state.content.get("membership") if old_state else None
        if action == "unban" and old_membership != "ban":
            raise SynapseError(
                403,
                "Cannot unban user who was not banned (membership=%s)" % old_membership,
                errcode=Codes.BAD_STATE
            )
        if old_membership == "ban" and action != "unban":
            raise SynapseError(
                403,
                "Cannot %s user who was banned" % (action,),
                errcode=Codes.BAD_STATE
            )

        is_host_in_room = self.is_host_in_room(current_state)

        if effective_membership_state == Membership.JOIN:
            if requester.is_guest and not self._can_guest_join(current_state):
                # This should be an auth check, but guests are a local concept,
                # so don't really fit into the general auth process.
                raise AuthError(403, "Guest access not allowed")

            if not is_host_in_room:
                # We're not in the room, so the join has to go via federation.
                inviter = yield self.get_inviter(target.to_string(), room_id)
                if inviter and not self.hs.is_mine(inviter):
                    remote_room_hosts.append(inviter.domain)

                content = {"membership": Membership.JOIN}

                profile = self.hs.get_handlers().profile_handler
                content["displayname"] = yield profile.get_displayname(target)
                content["avatar_url"] = yield profile.get_avatar_url(target)

                if requester.is_guest:
                    content["kind"] = "guest"

                ret = yield self.remote_join(
                    remote_room_hosts, room_id, target, content
                )
                defer.returnValue(ret)

        elif effective_membership_state == Membership.LEAVE:
            if not is_host_in_room:
                # perhaps we've been invited
                inviter = yield self.get_inviter(target.to_string(), room_id)
                if not inviter:
                    raise SynapseError(404, "Not a known room")

                if self.hs.is_mine(inviter):
                    # the inviter was on our server, but has now left. Carry on
                    # with the normal rejection codepath.
                    #
                    # This is a bit of a hack, because the room might still be
                    # active on other servers.
                    pass
                else:
                    # send the rejection to the inviter's HS.
                    remote_room_hosts = remote_room_hosts + [inviter.domain]

                    try:
                        ret = yield self.reject_remote_invite(
                            target.to_string(), room_id, remote_room_hosts
                        )
                        defer.returnValue(ret)
                    except SynapseError as e:
                        # Fall back to marking the invite rejected locally.
                        logger.warn("Failed to reject invite: %s", e)

                        yield self.store.locally_reject_invite(
                            target.to_string(), room_id
                        )

                        defer.returnValue({})

        # Default path: we're in the room (or rejecting locally) — emit the
        # membership event ourselves.
        yield self._local_membership_update(
            requester=requester,
            target=target,
            room_id=room_id,
            membership=effective_membership_state,
            txn_id=txn_id,
            ratelimit=ratelimit,
            prev_event_ids=latest_event_ids,
        )

    @defer.inlineCallbacks
    def send_membership_event(
            self,
            requester,
            event,
            context,
            remote_room_hosts=None,
            ratelimit=True,
    ):
        """
        Change the membership status of a user in a room.

        Args:
            requester (Requester): The local user who requested the membership
                event. If None, certain checks, like whether this homeserver can
                act as the sender, will be skipped.
            event (SynapseEvent): The membership event.
            context: The context of the event.
            is_guest (bool): Whether the sender is a guest.
            room_hosts ([str]): Homeservers which are likely to already be in
                the room, and could be danced with in order to join this
                homeserver for the first time.
            ratelimit (bool): Whether to rate limit this request.
        Raises:
            SynapseError if there was a problem changing the membership.
        """
        remote_room_hosts = remote_room_hosts or []

        target_user = UserID.from_string(event.state_key)
        room_id = event.room_id

        if requester is not None:
            sender = UserID.from_string(event.sender)
            assert sender == requester.user, (
                "Sender (%s) must be same as requester (%s)"
                % (sender, requester.user)
            )
            assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
        else:
            requester = Requester(target_user, None, False)

        message_handler = self.hs.get_handlers().message_handler
        # If an identical state event is already pending, this is a no-op.
        prev_event = message_handler.deduplicate_state_event(event, context)
        if prev_event is not None:
            return

        if event.membership == Membership.JOIN:
            if requester.is_guest and not self._can_guest_join(context.current_state):
                # This should be an auth check, but guests are a local concept,
                # so don't really fit into the general auth process.
                raise AuthError(403, "Guest access not allowed")

        yield message_handler.handle_new_client_event(
            requester,
            event,
            context,
            extra_users=[target_user],
            ratelimit=ratelimit,
        )

        prev_member_event = context.current_state.get(
            (EventTypes.Member, target_user.to_string()),
            None
        )

        if event.membership == Membership.JOIN:
            if not prev_member_event or prev_member_event.membership != Membership.JOIN:
                # Only fire user_joined_room if the user has actually joined the
                # room. Don't bother if the user is just changing their profile
                # info.
                yield user_joined_room(self.distributor, target_user, room_id)
        elif event.membership == Membership.LEAVE:
            if prev_member_event and prev_member_event.membership == Membership.JOIN:
                user_left_room(self.distributor, target_user, room_id)

    def _can_guest_join(self, current_state):
        """
        Returns whether a guest can join a room based on its current state.
        """
        guest_access = current_state.get((EventTypes.GuestAccess, ""), None)
        return (
            guest_access
            and guest_access.content
            and "guest_access" in guest_access.content
            and guest_access.content["guest_access"] == "can_join"
        )

    @defer.inlineCallbacks
    def lookup_room_alias(self, room_alias):
        """
        Get the room ID associated with a room alias.

        Args:
            room_alias (RoomAlias): The alias to look up.
        Returns:
            A tuple of:
                The room ID as a RoomID object.
                Hosts likely to be participating in the room ([str]).
        Raises:
            SynapseError if room alias could not be found.
        """
        directory_handler = self.hs.get_handlers().directory_handler
        mapping = yield directory_handler.get_association(room_alias)

        if not mapping:
            raise SynapseError(404, "No such room alias")

        room_id = mapping["room_id"]
        servers = mapping["servers"]

        defer.returnValue((RoomID.from_string(room_id), servers))

    @defer.inlineCallbacks
    def get_inviter(self, user_id, room_id):
        # Returns the UserID of whoever invited `user_id` to `room_id`, or
        # None (implicitly) if there is no outstanding invite.
        invite = yield self.store.get_invite_for_user_in_room(
            user_id=user_id,
            room_id=room_id,
        )
        if invite:
            defer.returnValue(UserID.from_string(invite.sender))

    @defer.inlineCallbacks
    def do_3pid_invite(
            self,
            room_id,
            inviter,
            medium,
            address,
            id_server,
            requester,
            txn_id
    ):
        # If the 3pid is already bound to a matrix ID, invite that user
        # directly; otherwise store a third-party invite in the room.
        invitee = yield self._lookup_3pid(
            id_server, medium, address
        )

        if invitee:
            yield self.update_membership(
                requester,
                UserID.from_string(invitee),
                room_id,
                "invite",
                txn_id=txn_id,
            )
        else:
            yield self._make_and_store_3pid_invite(
                requester,
                id_server,
                medium,
                address,
                room_id,
                inviter,
                txn_id=txn_id
            )

    @defer.inlineCallbacks
    def _lookup_3pid(self, id_server, medium, address):
        """Looks up a 3pid in the passed identity server.

        Args:
            id_server (str): The server name (including port, if required)
                of the identity server to use.
            medium (str): The type of the third party identifier (e.g. "email").
            address (str): The third party identifier (e.g. "*****@*****.**").

        Returns:
            str: the matrix ID of the 3pid, or None if it is not recognized.
        """
        try:
            data = yield self.hs.get_simple_http_client().get_json(
                "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,),
                {
                    "medium": medium,
                    "address": address,
                }
            )

            if "mxid" in data:
                if "signatures" not in data:
                    raise AuthError(401, "No signatures on 3pid binding")
                self.verify_any_signature(data, id_server)
                defer.returnValue(data["mxid"])

        except IOError as e:
            logger.warn("Error from identity server lookup: %s" % (e,))
            defer.returnValue(None)

    @defer.inlineCallbacks
    def verify_any_signature(self, data, server_hostname):
        # Accepts the binding if any one signature from `server_hostname`
        # verifies against a key fetched from that server.
        if server_hostname not in data["signatures"]:
            raise AuthError(401, "No signature from server %s" % (server_hostname,))
        for key_name, signature in data["signatures"][server_hostname].items():
            key_data = yield self.hs.get_simple_http_client().get_json(
                "%s%s/_matrix/identity/api/v1/pubkey/%s" %
                (id_server_scheme, server_hostname, key_name,),
            )
            if "public_key" not in key_data:
                raise AuthError(401, "No public key named %s from %s" %
                                (key_name, server_hostname,))
            verify_signed_json(
                data,
                server_hostname,
                decode_verify_key_bytes(key_name, decode_base64(key_data["public_key"]))
            )
            # One valid signature is sufficient.
            return

    @defer.inlineCallbacks
    def _make_and_store_3pid_invite(
            self,
            requester,
            id_server,
            medium,
            address,
            room_id,
            user,
            txn_id
    ):
        # Gathers cosmetic room/inviter details for the invite email, asks the
        # identity server to store the invite, then sends the
        # m.room.third_party_invite state event.
        room_state = yield self.hs.get_state_handler().get_current_state(room_id)

        inviter_display_name = ""
        inviter_avatar_url = ""
        member_event = room_state.get((EventTypes.Member, user.to_string()))
        if member_event:
            inviter_display_name = member_event.content.get("displayname", "")
            inviter_avatar_url = member_event.content.get("avatar_url", "")

        canonical_room_alias = ""
        canonical_alias_event = room_state.get((EventTypes.CanonicalAlias, ""))
        if canonical_alias_event:
            canonical_room_alias = canonical_alias_event.content.get("alias", "")

        room_name = ""
        room_name_event = room_state.get((EventTypes.Name, ""))
        if room_name_event:
            room_name = room_name_event.content.get("name", "")

        room_join_rules = ""
        join_rules_event = room_state.get((EventTypes.JoinRules, ""))
        if join_rules_event:
            room_join_rules = join_rules_event.content.get("join_rule", "")

        room_avatar_url = ""
        room_avatar_event = room_state.get((EventTypes.RoomAvatar, ""))
        if room_avatar_event:
            room_avatar_url = room_avatar_event.content.get("url", "")

        token, public_keys, fallback_public_key, display_name = (
            yield self._ask_id_server_for_third_party_invite(
                id_server=id_server,
                medium=medium,
                address=address,
                room_id=room_id,
                inviter_user_id=user.to_string(),
                room_alias=canonical_room_alias,
                room_avatar_url=room_avatar_url,
                room_join_rules=room_join_rules,
                room_name=room_name,
                inviter_display_name=inviter_display_name,
                inviter_avatar_url=inviter_avatar_url
            )
        )

        msg_handler = self.hs.get_handlers().message_handler
        yield msg_handler.create_and_send_nonmember_event(
            requester,
            {
                "type": EventTypes.ThirdPartyInvite,
                "content": {
                    "display_name": display_name,
                    "public_keys": public_keys,

                    # For backwards compatibility:
                    "key_validity_url": fallback_public_key["key_validity_url"],
                    "public_key": fallback_public_key["public_key"],
                },
                "room_id": room_id,
                "sender": user.to_string(),
                "state_key": token,
            },
            txn_id=txn_id,
        )

    @defer.inlineCallbacks
    def _ask_id_server_for_third_party_invite(
            self, id_server, medium, address, room_id, inviter_user_id,
            room_alias, room_avatar_url, room_join_rules, room_name,
            inviter_display_name, inviter_avatar_url
    ):
        """
        Asks an identity server for a third party invite.

        Args:
            id_server (str): hostname + optional port for the identity server.
            medium (str): The literal string "email".
            address (str): The third party address being invited.
            room_id (str): The ID of the room to which the user is invited.
            inviter_user_id (str): The user ID of the inviter.
            room_alias (str): An alias for the room, for cosmetic notifications.
            room_avatar_url (str): The URL of the room's avatar, for cosmetic
                notifications.
            room_join_rules (str): The join rules of the email (e.g. "public").
            room_name (str): The m.room.name of the room.
            inviter_display_name (str): The current display name of the
                inviter.
            inviter_avatar_url (str): The URL of the inviter's avatar.

        Returns:
            A deferred tuple containing:
                token (str): The token which must be signed to prove
                    authenticity.
                public_keys ([{"public_key": str, "key_validity_url": str}]):
                    public_key is a base64-encoded ed25519 public key.
                fallback_public_key: One element from public_keys.
                display_name (str): A user-friendly name to represent the
                    invited user.
        """

        is_url = "%s%s/_matrix/identity/api/v1/store-invite" % (
            id_server_scheme, id_server,
        )

        invite_config = {
            "medium": medium,
            "address": address,
            "room_id": room_id,
            "room_alias": room_alias,
            "room_avatar_url": room_avatar_url,
            "room_join_rules": room_join_rules,
            "room_name": room_name,
            "sender": inviter_user_id,
            "sender_display_name": inviter_display_name,
            "sender_avatar_url": inviter_avatar_url,
        }

        if self.hs.config.invite_3pid_guest:
            # Provision a guest account the invitee can use on registration.
            registration_handler = self.hs.get_handlers().registration_handler
            guest_access_token = yield registration_handler.guest_access_token_for(
                medium=medium,
                address=address,
                inviter_user_id=inviter_user_id,
            )

            guest_user_info = yield self.hs.get_auth().get_user_by_access_token(
                guest_access_token
            )

            invite_config.update({
                "guest_access_token": guest_access_token,
                "guest_user_id": guest_user_info["user"].to_string(),
            })

        data = yield self.hs.get_simple_http_client().post_urlencoded_get_json(
            is_url,
            invite_config
        )

        # TODO: Check for success
        token = data["token"]
        public_keys = data.get("public_keys", [])
        if "public_key" in data:
            fallback_public_key = {
                "public_key": data["public_key"],
                "key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
                    id_server_scheme, id_server,
                ),
            }
        else:
            fallback_public_key = public_keys[0]

        if not public_keys:
            public_keys.append(fallback_public_key)
        display_name = data["display_name"]

        defer.returnValue((token, public_keys, fallback_public_key, display_name))

    @defer.inlineCallbacks
    def forget(self, user, room_id):
        user_id = user.to_string()
member = yield self.state_handler.get_current_state( room_id=room_id, event_type=EventTypes.Member, state_key=user_id ) membership = member.membership if member else None if membership is not None and membership != Membership.LEAVE: raise SynapseError(400, "User %s in room %s" % ( user_id, room_id )) if membership: yield self.store.forget(user_id, room_id)
class MediaRepository(object):
    """Serves, caches and thumbnails local and remote media.

    Local uploads and cached remote downloads are written through a
    MediaStorage layered over the primary media path plus any configured
    StorageProviders; per-(server, media) Linearizer keys stop concurrent
    downloads of the same remote file.
    """

    def __init__(self, hs):
        self.auth = hs.get_auth()
        self.client = MatrixFederationHttpClient(hs)
        self.clock = hs.get_clock()
        self.server_name = hs.hostname
        self.store = hs.get_datastore()
        self.max_upload_size = hs.config.max_upload_size
        self.max_image_pixels = hs.config.max_image_pixels

        self.primary_base_path = hs.config.media_store_path
        self.filepaths = MediaFilePaths(self.primary_base_path)

        self.dynamic_thumbnails = hs.config.dynamic_thumbnails
        self.thumbnail_requirements = hs.config.thumbnail_requirements

        # Serializes remote-media operations per (server_name, media_id).
        self.remote_media_linearizer = Linearizer(name="media_remote")

        # Accumulators drained periodically by _update_recently_accessed.
        self.recently_accessed_remotes = set()
        self.recently_accessed_locals = set()

        self.federation_domain_whitelist = hs.config.federation_domain_whitelist

        # List of StorageProviders where we should search for media and
        # potentially upload to.
        storage_providers = []

        for clz, provider_config, wrapper_config in hs.config.media_storage_providers:
            backend = clz(hs, provider_config)
            provider = StorageProviderWrapper(
                backend,
                store_local=wrapper_config.store_local,
                store_remote=wrapper_config.store_remote,
                store_synchronous=wrapper_config.store_synchronous,
            )
            storage_providers.append(provider)

        self.media_storage = MediaStorage(
            self.primary_base_path, self.filepaths, storage_providers,
        )

        self.clock.looping_call(
            self._update_recently_accessed,
            UPDATE_RECENTLY_ACCESSED_TS,
        )

    @defer.inlineCallbacks
    def _update_recently_accessed(self):
        """Flush the recently-accessed sets to the database (looping call)."""
        # Swap-and-clear so new accesses during the DB write are not lost.
        remote_media = self.recently_accessed_remotes
        self.recently_accessed_remotes = set()

        local_media = self.recently_accessed_locals
        self.recently_accessed_locals = set()

        yield self.store.update_cached_last_access_time(
            local_media, remote_media, self.clock.time_msec())

    def mark_recently_accessed(self, server_name, media_id):
        """Mark the given media as recently accessed.

        Args:
            server_name (str|None): Origin server of media, or None if local
            media_id (str): The media ID of the content
        """
        if server_name:
            self.recently_accessed_remotes.add((server_name, media_id))
        else:
            self.recently_accessed_locals.add(media_id)

    @defer.inlineCallbacks
    def create_content(self, media_type, upload_name, content, content_length,
                       auth_user):
        """Store uploaded content for a local user and return the mxc URL

        Args:
            media_type(str): The content type of the file
            upload_name(str): The name of the file
            content: A file like object that is the content to store
            content_length(int): The length of the content
            auth_user(str): The user_id of the uploader

        Returns:
            Deferred[str]: The mxc url of the stored content
        """
        media_id = random_string(24)

        # For local media, file_id == media_id.
        file_info = FileInfo(
            server_name=None,
            file_id=media_id,
        )

        fname = yield self.media_storage.store_file(content, file_info)

        logger.info("Stored local media in file %r", fname)

        yield self.store.store_local_media(
            media_id=media_id,
            media_type=media_type,
            time_now_ms=self.clock.time_msec(),
            upload_name=upload_name,
            media_length=content_length,
            user_id=auth_user,
        )

        yield self._generate_thumbnails(
            None, media_id, media_id, media_type,
        )

        defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))

    @defer.inlineCallbacks
    def get_local_media(self, request, media_id, name):
        """Responds to requests for local media, if exists, or returns 404.

        Args:
            request(twisted.web.http.Request)
            media_id (str): The media ID of the content. (This is the same as
                the file_id for local content.)
            name (str|None): Optional name that, if specified, will be used as
                the filename in the Content-Disposition header of the response.

        Returns:
            Deferred: Resolves once a response has successfully been written
                to request
        """
        media_info = yield self.store.get_local_media(media_id)
        if not media_info or media_info["quarantined_by"]:
            respond_404(request)
            return

        self.mark_recently_accessed(None, media_id)

        media_type = media_info["media_type"]
        media_length = media_info["media_length"]
        upload_name = name if name else media_info["upload_name"]
        url_cache = media_info["url_cache"]

        file_info = FileInfo(
            None, media_id,
            url_cache=url_cache,
        )

        responder = yield self.media_storage.fetch_media(file_info)
        yield respond_with_responder(
            request, responder, media_type, media_length, upload_name,
        )

    @defer.inlineCallbacks
    def get_remote_media(self, request, server_name, media_id, name):
        """Respond to requests for remote media.

        Args:
            request(twisted.web.http.Request)
            server_name (str): Remote server_name where the media originated.
            media_id (str): The media ID of the content (as defined by the
                remote server).
            name (str|None): Optional name that, if specified, will be used as
                the filename in the Content-Disposition header of the response.

        Returns:
            Deferred: Resolves once a response has successfully been written
                to request
        """
        if (
            self.federation_domain_whitelist is not None and
            server_name not in self.federation_domain_whitelist
        ):
            raise FederationDeniedError(server_name)

        self.mark_recently_accessed(server_name, media_id)

        # We linearize here to ensure that we don't try and download remote
        # media multiple times concurrently
        key = (server_name, media_id)
        with (yield self.remote_media_linearizer.queue(key)):
            responder, media_info = yield self._get_remote_media_impl(
                server_name, media_id,
            )

        # We deliberately stream the file outside the lock
        if responder:
            media_type = media_info["media_type"]
            media_length = media_info["media_length"]
            upload_name = name if name else media_info["upload_name"]
            yield respond_with_responder(
                request, responder, media_type, media_length, upload_name,
            )
        else:
            respond_404(request)

    @defer.inlineCallbacks
    def get_remote_media_info(self, server_name, media_id):
        """Gets the media info associated with the remote file, downloading
        if necessary.

        Args:
            server_name (str): Remote server_name where the media originated.
            media_id (str): The media ID of the content (as defined by the
                remote server).

        Returns:
            Deferred[dict]: The media_info of the file
        """
        if (
            self.federation_domain_whitelist is not None and
            server_name not in self.federation_domain_whitelist
        ):
            raise FederationDeniedError(server_name)

        # We linearize here to ensure that we don't try and download remote
        # media multiple times concurrently
        key = (server_name, media_id)
        with (yield self.remote_media_linearizer.queue(key)):
            responder, media_info = yield self._get_remote_media_impl(
                server_name, media_id,
            )

        # Ensure we actually use the responder so that it releases resources
        if responder:
            with responder:
                pass

        defer.returnValue(media_info)

    @defer.inlineCallbacks
    def _get_remote_media_impl(self, server_name, media_id):
        """Looks for media in local cache, if not there then attempt to
        download from remote server.

        Args:
            server_name (str): Remote server_name where the media originated.
            media_id (str): The media ID of the content (as defined by the
                remote server).

        Returns:
            Deferred[(Responder, media_info)]
        """
        media_info = yield self.store.get_cached_remote_media(
            server_name, media_id)

        # file_id is the ID we use to track the file locally. If we've already
        # seen the file then reuse the existing ID, otherwise genereate a new
        # one.
        if media_info:
            file_id = media_info["filesystem_id"]
        else:
            file_id = random_string(24)

        file_info = FileInfo(server_name, file_id)

        # If we have an entry in the DB, try and look for it
        if media_info:
            if media_info["quarantined_by"]:
                logger.info("Media is quarantined")
                raise NotFoundError()

            responder = yield self.media_storage.fetch_media(file_info)
            if responder:
                defer.returnValue((responder, media_info))

        # Failed to find the file anywhere, lets download it.

        media_info = yield self._download_remote_file(
            server_name, media_id, file_id
        )

        responder = yield self.media_storage.fetch_media(file_info)
        defer.returnValue((responder, media_info))

    @defer.inlineCallbacks
    def _download_remote_file(self, server_name, media_id, file_id):
        """Attempt to download the remote file from the given server name,
        using the given file_id as the local id.

        Args:
            server_name (str): Originating server
            media_id (str): The media ID of the content (as defined by the
                remote server). This is different than the file_id, which is
                locally generated.
            file_id (str): Local file ID

        Returns:
            Deferred[MediaInfo]
        """

        file_info = FileInfo(
            server_name=server_name,
            file_id=file_id,
        )

        with self.media_storage.store_into_file(file_info) as (f, fname, finish):
            request_path = "/".join((
                "/_matrix/media/v1/download", server_name, media_id,
            ))
            try:
                length, headers = yield self.client.get_file(
                    server_name, request_path, output_stream=f,
                    max_size=self.max_upload_size, args={
                        # tell the remote server to 404 if it doesn't
                        # recognise the server_name, to make sure we don't
                        # end up with a routing loop.
                        "allow_remote": "false",
                    })
            except twisted.internet.error.DNSLookupError as e:
                logger.warn("HTTP error fetching remote media %s/%s: %r",
                            server_name, media_id, e)
                raise NotFoundError()
            except HttpResponseException as e:
                logger.warn("HTTP error fetching remote media %s/%s: %s",
                            server_name, media_id, e.response)
                if e.code == twisted.web.http.NOT_FOUND:
                    # Mirror the remote 404 back to our own client.
                    raise SynapseError.from_http_response_exception(e)
                raise SynapseError(502, "Failed to fetch remote media")
            except SynapseError:
                logger.exception("Failed to fetch remote media %s/%s",
                                 server_name, media_id)
                raise
            except NotRetryingDestination:
                logger.warn("Not retrying destination %r", server_name)
                raise SynapseError(502, "Failed to fetch remote media")
            except Exception:
                logger.exception("Failed to fetch remote media %s/%s",
                                 server_name, media_id)
                raise SynapseError(502, "Failed to fetch remote media")

            # Flush the download to storage providers before leaving the
            # store_into_file context.
            yield finish()

        media_type = headers["Content-Type"][0]

        time_now_ms = self.clock.time_msec()

        # Derive a display filename from Content-Disposition, preferring the
        # RFC 5987 "filename*" (UTF-8) form over plain ascii "filename".
        content_disposition = headers.get("Content-Disposition", None)
        if content_disposition:
            _, params = cgi.parse_header(content_disposition[0],)
            upload_name = None

            # First check if there is a valid UTF-8 filename
            upload_name_utf8 = params.get("filename*", None)
            if upload_name_utf8:
                if upload_name_utf8.lower().startswith("utf-8''"):
                    upload_name = upload_name_utf8[7:]

            # If there isn't check for an ascii name.
            if not upload_name:
                upload_name_ascii = params.get("filename", None)
                if upload_name_ascii and is_ascii(upload_name_ascii):
                    upload_name = upload_name_ascii

            if upload_name:
                upload_name = urlparse.unquote(upload_name)
                try:
                    upload_name = upload_name.decode("utf-8")
                except UnicodeDecodeError:
                    upload_name = None
        else:
            upload_name = None

        logger.info("Stored remote media in file %r", fname)

        yield self.store.store_cached_remote_media(
            origin=server_name,
            media_id=media_id,
            media_type=media_type,
            time_now_ms=self.clock.time_msec(),
            upload_name=upload_name,
            media_length=length,
            filesystem_id=file_id,
        )

        media_info = {
            "media_type": media_type,
            "media_length": length,
            "upload_name": upload_name,
            "created_ts": time_now_ms,
            "filesystem_id": file_id,
        }

        yield self._generate_thumbnails(
            server_name, media_id, file_id, media_type,
        )

        defer.returnValue(media_info)

    def _get_thumbnail_requirements(self, media_type):
        # Per-content-type list of (width, height, method, type) requirements;
        # empty tuple means "do not thumbnail this media type".
        return self.thumbnail_requirements.get(media_type, ())

    def _generate_thumbnail(self, thumbnailer, t_width, t_height,
                            t_method, t_type):
        """Produce one thumbnail byte stream, or None.

        Runs synchronously (callers push it onto a thread pool). Returns None
        when the source image is too large or the method is unrecognized.
        """
        m_width = thumbnailer.width
        m_height = thumbnailer.height

        if m_width * m_height >= self.max_image_pixels:
            logger.info("Image too large to thumbnail %r x %r > %r",
                        m_width, m_height, self.max_image_pixels)
            return

        if t_method == "crop":
            t_byte_source = thumbnailer.crop(t_width, t_height, t_type)
        elif t_method == "scale":
            # Clamp the scaled size to the source dimensions.
            t_width, t_height = thumbnailer.aspect(t_width, t_height)
            t_width = min(m_width, t_width)
            t_height = min(m_height, t_height)
            t_byte_source = thumbnailer.scale(t_width, t_height, t_type)
        else:
            t_byte_source = None

        return t_byte_source

    @defer.inlineCallbacks
    def generate_local_exact_thumbnail(self, media_id, t_width, t_height,
                                       t_method, t_type, url_cache):
        """Generate a single thumbnail of the exact requested size for local
        media, store it, and resolve to its output path (or None on failure).
        """
        input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
            None, media_id, url_cache=url_cache,
        ))

        thumbnailer = Thumbnailer(input_path)
        # Thumbnailing is CPU-bound, so do it off the reactor thread.
        t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
            self._generate_thumbnail,
            thumbnailer, t_width, t_height, t_method, t_type
        ))

        if t_byte_source:
            try:
                file_info = FileInfo(
                    server_name=None,
                    file_id=media_id,
                    url_cache=url_cache,
                    thumbnail=True,
                    thumbnail_width=t_width,
                    thumbnail_height=t_height,
                    thumbnail_method=t_method,
                    thumbnail_type=t_type,
                )

                output_path = yield self.media_storage.store_file(
                    t_byte_source, file_info,
                )
            finally:
                t_byte_source.close()

            logger.info("Stored thumbnail in file %r", output_path)

            t_len = os.path.getsize(output_path)

            yield self.store.store_local_thumbnail(
                media_id, t_width, t_height, t_type, t_method, t_len
            )

            defer.returnValue(output_path)

    @defer.inlineCallbacks
    def generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
                                        t_width, t_height, t_method, t_type):
        """Generate a single thumbnail of the exact requested size for cached
        remote media, store it, and resolve to its output path (or None).
        """
        input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
            server_name, file_id, url_cache=False,
        ))

        thumbnailer = Thumbnailer(input_path)
        # Thumbnailing is CPU-bound, so do it off the reactor thread.
        t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
            self._generate_thumbnail,
            thumbnailer, t_width, t_height, t_method, t_type
        ))

        if t_byte_source:
            try:
                file_info = FileInfo(
                    server_name=server_name,
                    file_id=media_id,
                    thumbnail=True,
                    thumbnail_width=t_width,
                    thumbnail_height=t_height,
                    thumbnail_method=t_method,
                    thumbnail_type=t_type,
                )

                output_path = yield self.media_storage.store_file(
                    t_byte_source, file_info,
                )
            finally:
                t_byte_source.close()

            logger.info("Stored thumbnail in file %r", output_path)

            t_len = os.path.getsize(output_path)

            yield self.store.store_remote_media_thumbnail(
                server_name, media_id, file_id,
                t_width, t_height, t_type, t_method, t_len
            )

            defer.returnValue(output_path)

    @defer.inlineCallbacks
    def _generate_thumbnails(self, server_name, media_id, file_id, media_type,
                             url_cache=False):
        """Generate and store thumbnails for an image.

        Args:
            server_name (str|None): The server name if remote media, else None if local
            media_id (str): The media ID of the content. (This is the same as
                the file_id for local content)
            file_id (str): Local file ID
            media_type (str): The content type of the file
            url_cache (bool): If we are thumbnailing images downloaded for the URL cache,
                used exclusively by the url previewer

        Returns:
            Deferred[dict]: Dict with "width" and "height" keys of original image
        """
        requirements = self._get_thumbnail_requirements(media_type)
        if not requirements:
            return

        input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
            server_name, file_id, url_cache=url_cache,
        ))

        thumbnailer = Thumbnailer(input_path)
        m_width = thumbnailer.width
        m_height = thumbnailer.height

        if m_width * m_height >= self.max_image_pixels:
            logger.info("Image too large to thumbnail %r x %r > %r",
                        m_width, m_height, self.max_image_pixels)
            return

        # We deduplicate the thumbnail sizes by ignoring the cropped versions if
        # they have the same dimensions of a scaled one.
        thumbnails = {}
        for r_width, r_height, r_method, r_type in requirements:
            if r_method == "crop":
                thumbnails.setdefault((r_width, r_height, r_type), r_method)
            elif r_method == "scale":
                t_width, t_height = thumbnailer.aspect(r_width, r_height)
                t_width = min(m_width, t_width)
                t_height = min(m_height, t_height)
                thumbnails[(t_width, t_height, r_type)] = r_method

        # Now we generate the thumbnails for each dimension, store it
        for (t_width, t_height, t_type), t_method in thumbnails.iteritems():
            # Generate the thumbnail
            if t_method == "crop":
                t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
                    thumbnailer.crop,
                    t_width, t_height, t_type,
                ))
            elif t_method == "scale":
                t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
                    thumbnailer.scale,
                    t_width, t_height, t_type,
                ))
            else:
                logger.error("Unrecognized method: %r", t_method)
                continue

            if not t_byte_source:
                continue

            try:
                file_info = FileInfo(
                    server_name=server_name,
                    file_id=file_id,
                    thumbnail=True,
                    thumbnail_width=t_width,
                    thumbnail_height=t_height,
                    thumbnail_method=t_method,
                    thumbnail_type=t_type,
                    url_cache=url_cache,
                )

                output_path = yield self.media_storage.store_file(
                    t_byte_source, file_info,
                )
            finally:
                t_byte_source.close()

            t_len = os.path.getsize(output_path)

            # Write to database
            if server_name:
                yield self.store.store_remote_media_thumbnail(
                    server_name, media_id, file_id,
                    t_width, t_height, t_type, t_method, t_len
                )
            else:
                yield self.store.store_local_thumbnail(
                    media_id, t_width, t_height, t_type, t_method, t_len
                )

        defer.returnValue({
            "width": m_width,
            "height": m_height,
        })

    @defer.inlineCallbacks
    def delete_old_remote_media(self, before_ts):
        """Delete cached remote media last accessed before `before_ts`.

        Returns:
            Deferred[dict]: {"deleted": <count of media entries removed>}
        """
        old_media = yield self.store.get_remote_media_before(before_ts)

        deleted = 0

        for media in old_media:
            origin = media["media_origin"]
            media_id = media["media_id"]
            file_id = media["filesystem_id"]
            key = (origin, media_id)

            logger.info("Deleting: %r", key)

            # TODO: Should we delete from the backup store

            with (yield self.remote_media_linearizer.queue(key)):
                full_path = self.filepaths.remote_media_filepath(origin, file_id)
                try:
                    os.remove(full_path)
                except OSError as e:
                    logger.warn("Failed to remove file: %r", full_path)
                    if e.errno == errno.ENOENT:
                        # Already gone; still clean up thumbnails and the DB row.
                        pass
                    else:
                        continue

                thumbnail_dir = self.filepaths.remote_media_thumbnail_dir(
                    origin, file_id
                )
                shutil.rmtree(thumbnail_dir, ignore_errors=True)

                yield self.store.delete_remote_media(origin, media_id)
                deleted += 1

        defer.returnValue({"deleted": deleted})
class MediaRepository(object):
    """Stores, fetches and thumbnails media directly on the local filesystem.

    Older variant that writes straight to paths from MediaFilePaths (no
    storage-provider abstraction); a per-(server, media) Linearizer prevents
    concurrent downloads of the same remote file.
    """

    def __init__(self, hs):
        self.auth = hs.get_auth()
        self.client = MatrixFederationHttpClient(hs)
        self.clock = hs.get_clock()
        self.server_name = hs.hostname
        self.store = hs.get_datastore()
        self.max_upload_size = hs.config.max_upload_size
        self.max_image_pixels = hs.config.max_image_pixels
        self.filepaths = MediaFilePaths(hs.config.media_store_path)
        self.dynamic_thumbnails = hs.config.dynamic_thumbnails
        self.thumbnail_requirements = hs.config.thumbnail_requirements

        # Serializes remote-media operations per (server_name, media_id).
        self.remote_media_linearizer = Linearizer()

        # Accumulator drained periodically by _update_recently_accessed_remotes.
        self.recently_accessed_remotes = set()

        self.clock.looping_call(
            self._update_recently_accessed_remotes,
            UPDATE_RECENTLY_ACCESSED_REMOTES_TS
        )

    @defer.inlineCallbacks
    def _update_recently_accessed_remotes(self):
        """Flush the recently-accessed set to the database (looping call)."""
        # Swap-and-clear so new accesses during the DB write are not lost.
        media = self.recently_accessed_remotes
        self.recently_accessed_remotes = set()

        yield self.store.update_cached_last_access_time(
            media, self.clock.time_msec()
        )

    @staticmethod
    def _makedirs(filepath):
        # Create the parent directory of `filepath` if it doesn't exist yet.
        dirname = os.path.dirname(filepath)
        if not os.path.exists(dirname):
            os.makedirs(dirname)

    @defer.inlineCallbacks
    def create_content(self, media_type, upload_name, content, content_length,
                       auth_user):
        """Store uploaded content for a local user; resolves to the mxc URL.

        `content` is the full upload as bytes (written in one go, not
        streamed).
        """
        media_id = random_string(24)

        fname = self.filepaths.local_media_filepath(media_id)
        self._makedirs(fname)

        # This shouldn't block for very long because the content will have
        # already been uploaded at this point.
        with open(fname, "wb") as f:
            f.write(content)

        yield self.store.store_local_media(
            media_id=media_id,
            media_type=media_type,
            time_now_ms=self.clock.time_msec(),
            upload_name=upload_name,
            media_length=content_length,
            user_id=auth_user,
        )
        media_info = {
            "media_type": media_type,
            "media_length": content_length,
        }

        yield self._generate_local_thumbnails(media_id, media_info)

        defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))

    @defer.inlineCallbacks
    def get_remote_media(self, server_name, media_id):
        """Get (downloading if needed) the media_info for remote media,
        serialized per (server_name, media_id) via the linearizer.
        """
        key = (server_name, media_id)
        with (yield self.remote_media_linearizer.queue(key)):
            media_info = yield self._get_remote_media_impl(server_name, media_id)
        defer.returnValue(media_info)

    @defer.inlineCallbacks
    def _get_remote_media_impl(self, server_name, media_id):
        # Cache hit: bump last-access; miss: download and cache.
        media_info = yield self.store.get_cached_remote_media(
            server_name, media_id
        )
        if not media_info:
            media_info = yield self._download_remote_file(
                server_name, media_id
            )
        else:
            self.recently_accessed_remotes.add((server_name, media_id))
            yield self.store.update_cached_last_access_time(
                [(server_name, media_id)], self.clock.time_msec()
            )
        defer.returnValue(media_info)

    @defer.inlineCallbacks
    def _download_remote_file(self, server_name, media_id):
        """Download remote media to a freshly generated local file_id, record
        it in the remote-media cache, and generate thumbnails.
        """
        file_id = random_string(24)

        fname = self.filepaths.remote_media_filepath(
            server_name, file_id
        )
        self._makedirs(fname)

        try:
            with open(fname, "wb") as f:
                request_path = "/".join((
                    "/_matrix/media/v1/download", server_name, media_id,
                ))
                try:
                    length, headers = yield self.client.get_file(
                        server_name, request_path, output_stream=f,
                        max_size=self.max_upload_size,
                    )
                except Exception as e:
                    logger.warn("Failed to fetch remoted media %r", e)
                    raise SynapseError(502, "Failed to fetch remoted media")

            media_type = headers["Content-Type"][0]
            time_now_ms = self.clock.time_msec()

            # Derive a display filename from Content-Disposition, preferring
            # the RFC 5987 "filename*" (UTF-8) form over plain "filename".
            content_disposition = headers.get("Content-Disposition", None)
            if content_disposition:
                _, params = cgi.parse_header(content_disposition[0],)
                upload_name = None

                # First check if there is a valid UTF-8 filename
                upload_name_utf8 = params.get("filename*", None)
                if upload_name_utf8:
                    if upload_name_utf8.lower().startswith("utf-8''"):
                        upload_name = upload_name_utf8[7:]

                # If there isn't check for an ascii name.
                if not upload_name:
                    upload_name_ascii = params.get("filename", None)
                    if upload_name_ascii and is_ascii(upload_name_ascii):
                        upload_name = upload_name_ascii

                if upload_name:
                    upload_name = urlparse.unquote(upload_name)
                    try:
                        upload_name = upload_name.decode("utf-8")
                    except UnicodeDecodeError:
                        upload_name = None
            else:
                upload_name = None

            yield self.store.store_cached_remote_media(
                origin=server_name,
                media_id=media_id,
                media_type=media_type,
                time_now_ms=self.clock.time_msec(),
                upload_name=upload_name,
                media_length=length,
                filesystem_id=file_id,
            )
        # NOTE(review): bare except intentionally re-raises after removing the
        # partial file, but it also catches BaseException (e.g. cancellation)
        # — consider `except Exception:` / `except BaseException:` explicitly.
        except:
            os.remove(fname)
            raise

        media_info = {
            "media_type": media_type,
            "media_length": length,
            "upload_name": upload_name,
            "created_ts": time_now_ms,
            "filesystem_id": file_id,
        }

        yield self._generate_remote_thumbnails(
            server_name, media_id, media_info
        )

        defer.returnValue(media_info)

    def _get_thumbnail_requirements(self, media_type):
        # Per-content-type list of (width, height, method, type) requirements;
        # empty tuple means "do not thumbnail this media type".
        return self.thumbnail_requirements.get(media_type, ())

    def _generate_thumbnail(self, input_path, t_path, t_width, t_height,
                            t_method, t_type):
        """Write one thumbnail to t_path; return its byte length, or None.

        Runs synchronously (callers push it onto a thread pool). Returns None
        when the source image is too large or the method is unrecognized.
        """
        thumbnailer = Thumbnailer(input_path)
        m_width = thumbnailer.width
        m_height = thumbnailer.height

        if m_width * m_height >= self.max_image_pixels:
            logger.info(
                "Image too large to thumbnail %r x %r > %r",
                m_width, m_height, self.max_image_pixels
            )
            return

        if t_method == "crop":
            t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
        elif t_method == "scale":
            t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
        else:
            t_len = None

        return t_len

    @defer.inlineCallbacks
    def generate_local_exact_thumbnail(self, media_id, t_width, t_height,
                                       t_method, t_type):
        """Generate and record a single exact-size thumbnail for local media;
        resolves to its path (or None if generation failed).
        """
        input_path = self.filepaths.local_media_filepath(media_id)

        t_path = self.filepaths.local_media_thumbnail(
            media_id, t_width, t_height, t_type, t_method
        )
        self._makedirs(t_path)

        # Thumbnailing is CPU-bound, so do it off the reactor thread.
        t_len = yield preserve_context_over_fn(
            threads.deferToThread,
            self._generate_thumbnail,
            input_path, t_path, t_width, t_height, t_method, t_type
        )

        if t_len:
            yield self.store.store_local_thumbnail(
                media_id, t_width, t_height, t_type, t_method, t_len
            )

            defer.returnValue(t_path)

    @defer.inlineCallbacks
    def generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
                                        t_width, t_height, t_method, t_type):
        """Generate and record a single exact-size thumbnail for cached remote
        media; resolves to its path (or None if generation failed).
        """
        input_path = self.filepaths.remote_media_filepath(server_name, file_id)

        t_path = self.filepaths.remote_media_thumbnail(
            server_name, file_id, t_width, t_height, t_type, t_method
        )
        self._makedirs(t_path)

        # Thumbnailing is CPU-bound, so do it off the reactor thread.
        t_len = yield preserve_context_over_fn(
            threads.deferToThread,
            self._generate_thumbnail,
            input_path, t_path, t_width, t_height, t_method, t_type
        )

        if t_len:
            yield self.store.store_remote_media_thumbnail(
                server_name, media_id, file_id,
                t_width, t_height, t_type, t_method, t_len
            )

            defer.returnValue(t_path)

    @defer.inlineCallbacks
    def _generate_local_thumbnails(self, media_id, media_info):
        """Generate all required thumbnails for local media and record them.

        Resolves to a dict with the original image's "width"/"height", or
        None (implicitly) if no thumbnails are required or the image is too
        large.
        """
        media_type = media_info["media_type"]
        requirements = self._get_thumbnail_requirements(media_type)
        if not requirements:
            return

        input_path = self.filepaths.local_media_filepath(media_id)
        thumbnailer = Thumbnailer(input_path)
        m_width = thumbnailer.width
        m_height = thumbnailer.height

        if m_width * m_height >= self.max_image_pixels:
            logger.info(
                "Image too large to thumbnail %r x %r > %r",
                m_width, m_height, self.max_image_pixels
            )
            return

        # Filled in by generate_thumbnails (which runs on a worker thread).
        local_thumbnails = []

        def generate_thumbnails():
            scales = set()
            crops = set()
            for r_width, r_height, r_method, r_type in requirements:
                if r_method == "scale":
                    t_width, t_height = thumbnailer.aspect(r_width, r_height)
                    scales.add((
                        min(m_width, t_width), min(m_height, t_height), r_type,
                    ))
                elif r_method == "crop":
                    crops.add((r_width, r_height, r_type))

            for t_width, t_height, t_type in scales:
                t_method = "scale"
                t_path = self.filepaths.local_media_thumbnail(
                    media_id, t_width, t_height, t_type, t_method
                )
                self._makedirs(t_path)
                t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)

                local_thumbnails.append((
                    media_id, t_width, t_height, t_type, t_method, t_len
                ))

            for t_width, t_height, t_type in crops:
                if (t_width, t_height, t_type) in scales:
                    # If the aspect ratio of the cropped thumbnail matches a purely
                    # scaled one then there is no point in calculating a separate
                    # thumbnail.
                    continue
                t_method = "crop"
                t_path = self.filepaths.local_media_thumbnail(
                    media_id, t_width, t_height, t_type, t_method
                )
                self._makedirs(t_path)
                t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
                local_thumbnails.append((
                    media_id, t_width, t_height, t_type, t_method, t_len
                ))

        yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)

        for l in local_thumbnails:
            yield self.store.store_local_thumbnail(*l)

        defer.returnValue({
            "width": m_width,
            "height": m_height,
        })

    @defer.inlineCallbacks
    def _generate_remote_thumbnails(self, server_name, media_id, media_info):
        """Generate all required thumbnails for cached remote media and record
        them; resolves to the original image's "width"/"height" dict.
        """
        media_type = media_info["media_type"]
        file_id = media_info["filesystem_id"]
        requirements = self._get_thumbnail_requirements(media_type)
        if not requirements:
            return

        # Filled in by generate_thumbnails (which runs on a worker thread).
        remote_thumbnails = []

        input_path = self.filepaths.remote_media_filepath(server_name, file_id)
        thumbnailer = Thumbnailer(input_path)
        m_width = thumbnailer.width
        m_height = thumbnailer.height

        def generate_thumbnails():
            if m_width * m_height >= self.max_image_pixels:
                logger.info(
                    "Image too large to thumbnail %r x %r > %r",
                    m_width, m_height, self.max_image_pixels
                )
                return

            scales = set()
            crops = set()
            for r_width, r_height, r_method, r_type in requirements:
                if r_method == "scale":
                    t_width, t_height = thumbnailer.aspect(r_width, r_height)
                    scales.add((
                        min(m_width, t_width), min(m_height, t_height), r_type,
                    ))
                elif r_method == "crop":
                    crops.add((r_width, r_height, r_type))

            for t_width, t_height, t_type in scales:
                t_method = "scale"
                t_path = self.filepaths.remote_media_thumbnail(
                    server_name, file_id, t_width, t_height, t_type, t_method
                )
                self._makedirs(t_path)
                t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
                remote_thumbnails.append([
                    server_name, media_id, file_id,
                    t_width, t_height, t_type, t_method, t_len
                ])

            for t_width, t_height, t_type in crops:
                if (t_width, t_height, t_type) in scales:
                    # If the aspect ratio of the cropped thumbnail matches a purely
                    # scaled one then there is no point in calculating a separate
                    # thumbnail.
                    continue
                t_method = "crop"
                t_path = self.filepaths.remote_media_thumbnail(
                    server_name, file_id, t_width, t_height, t_type, t_method
                )
                self._makedirs(t_path)
                t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
                remote_thumbnails.append([
                    server_name, media_id, file_id,
                    t_width, t_height, t_type, t_method, t_len
                ])

        yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)

        for r in remote_thumbnails:
            yield self.store.store_remote_media_thumbnail(*r)

        defer.returnValue({
            "width": m_width,
            "height": m_height,
        })

    @defer.inlineCallbacks
    def delete_old_remote_media(self, before_ts):
        """Delete cached remote media last accessed before `before_ts`.

        Resolves to {"deleted": <count of media entries removed>}.
        """
        old_media = yield self.store.get_remote_media_before(before_ts)

        deleted = 0

        for media in old_media:
            origin = media["media_origin"]
            media_id = media["media_id"]
            file_id = media["filesystem_id"]
            key = (origin, media_id)

            logger.info("Deleting: %r", key)

            with (yield self.remote_media_linearizer.queue(key)):
                full_path = self.filepaths.remote_media_filepath(origin, file_id)
                try:
                    os.remove(full_path)
                except OSError as e:
                    logger.warn("Failed to remove file: %r", full_path)
                    if e.errno == errno.ENOENT:
                        # Already gone; still clean up thumbnails and the DB row.
                        pass
                    else:
                        continue

                thumbnail_dir = self.filepaths.remote_media_thumbnail_dir(
                    origin, file_id
                )
                shutil.rmtree(thumbnail_dir, ignore_errors=True)

                yield self.store.delete_remote_media(origin, media_id)
                deleted += 1

        defer.returnValue({"deleted": deleted})
class RulesForRoom(object):
    """Caches push rules for users in a room.

    This efficiently handles users joining/leaving the room by not invalidating
    the entire cache for the room: membership is tracked per member event, so
    only changed members need to be re-fetched.
    """

    def __init__(self, hs, room_id, rules_for_room_cache, room_push_rule_cache_metrics):
        """
        Args:
            hs (HomeServer)
            room_id (str)
            rules_for_room_cache(Cache): The cache object that caches these
                RoomsForUser objects.
            room_push_rule_cache_metrics (CacheMetric)
        """
        self.room_id = room_id
        self.is_mine_id = hs.is_mine_id
        self.store = hs.get_datastore()
        self.room_push_rule_cache_metrics = room_push_rule_cache_metrics

        # Serialises cache recomputation so concurrent events for the same room
        # don't race on the caches below.
        self.linearizer = Linearizer(name="rules_for_room")

        self.member_map = {}  # event_id -> (user_id, state)
        self.rules_by_user = {}  # user_id -> rules

        # The last state group we updated the caches for. If the state_group of
        # a new event comes along, we know that we can just return the cached
        # result.
        # On invalidation of the rules themselves (if the user changes them),
        # we invalidate everything and set state_group to `object()`
        self.state_group = object()

        # A sequence number to keep track of when we're allowed to update the
        # cache. We bump the sequence number when we invalidate the cache. If
        # the sequence number changes while we're calculating stuff we should
        # not update the cache with it.
        self.sequence = 0

        # A cache of user_ids that we *know* aren't interesting, e.g. user_ids
        # owned by AS's, or remote users, etc. (I.e. users we will never need to
        # calculate push for)
        # These never need to be invalidated as we will never set up push for
        # them.
        self.uninteresting_user_set = set()

        # We need to be clever on the invalidating caches callbacks, as
        # otherwise the invalidation callback holds a reference to the object,
        # potentially causing it to leak.
        # To get around this we pass a function that on invalidations looks ups
        # the RoomsForUser entry in the cache, rather than keeping a reference
        # to self around in the callback.
        self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id)

    @defer.inlineCallbacks
    def get_rules(self, event, context):
        """Given an event context return the rules for all users who are
        currently in the room.

        Returns:
            Deferred[dict]: user_id -> push rules, for local joined users.
        """
        state_group = context.state_group

        # Fast path: cache is still valid for this state group, no lock needed.
        if state_group and self.state_group == state_group:
            logger.debug("Using cached rules for %r", self.room_id)
            self.room_push_rule_cache_metrics.inc_hits()
            defer.returnValue(self.rules_by_user)

        with (yield self.linearizer.queue(())):
            # Re-check under the linearizer: another caller may have refreshed
            # the cache while we were queued.
            if state_group and self.state_group == state_group:
                logger.debug("Using cached rules for %r", self.room_id)
                self.room_push_rule_cache_metrics.inc_hits()
                defer.returnValue(self.rules_by_user)

            self.room_push_rule_cache_metrics.inc_misses()

            ret_rules_by_user = {}
            missing_member_event_ids = {}
            if state_group and self.state_group == context.prev_group:
                # If we have a simple delta then we can reuse most of the previous
                # results.
                ret_rules_by_user = self.rules_by_user
                current_state_ids = context.delta_ids

                push_rules_delta_state_cache_metric.inc_hits()
            else:
                current_state_ids = context.current_state_ids
                push_rules_delta_state_cache_metric.inc_misses()

            push_rules_state_size_counter.inc_by(len(current_state_ids))

            logger.debug(
                "Looking for member changes in %r %r", state_group, current_state_ids
            )

            # Loop through to see which member events we've seen and have rules
            # for and which we need to fetch
            for key in current_state_ids:
                typ, user_id = key
                if typ != EventTypes.Member:
                    continue

                if user_id in self.uninteresting_user_set:
                    continue

                if not self.is_mine_id(user_id):
                    # Remote users never need local push rules.
                    self.uninteresting_user_set.add(user_id)
                    continue

                if self.store.get_if_app_services_interested_in_user(user_id):
                    # AS users don't get push either.
                    self.uninteresting_user_set.add(user_id)
                    continue

                event_id = current_state_ids[key]

                res = self.member_map.get(event_id, None)
                if res:
                    user_id, state = res
                    if state == Membership.JOIN:
                        rules = self.rules_by_user.get(user_id, None)
                        if rules:
                            ret_rules_by_user[user_id] = rules
                    continue

                # If a user has left a room we remove their push rule. If they
                # joined then we readd it later in _update_rules_with_member_event_ids
                ret_rules_by_user.pop(user_id, None)
                missing_member_event_ids[user_id] = event_id

            if missing_member_event_ids:
                # If we have some member events we haven't seen, look them up
                # and fetch push rules for them if appropriate.
                logger.debug("Found new member events %r", missing_member_event_ids)
                yield self._update_rules_with_member_event_ids(
                    ret_rules_by_user, missing_member_event_ids, state_group, event
                )
            else:
                # The push rules didn't change but lets update the cache anyway
                self.update_cache(
                    self.sequence,
                    members={},  # There were no membership changes
                    rules_by_user=ret_rules_by_user,
                    state_group=state_group
                )

        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                "Returning push rules for %r %r",
                self.room_id, ret_rules_by_user.keys(),
            )
        defer.returnValue(ret_rules_by_user)

    @defer.inlineCallbacks
    def _update_rules_with_member_event_ids(self, ret_rules_by_user, member_event_ids,
                                            state_group, event):
        """Update the partially filled rules_by_user dict by fetching rules for
        any newly joined users in the `member_event_ids` list.

        Args:
            ret_rules_by_user (dict): Partially filled dict of push rules. Gets
                updated with any new rules.
            member_event_ids (list): List of event ids for membership events that
                have happened since the last time we filled rules_by_user
            state_group: The state group we are currently computing push rules
                for. Used when updating the cache.
        """
        # Capture the sequence number *before* any async work so update_cache
        # can detect an invalidation that happened while we were fetching.
        sequence = self.sequence

        rows = yield self.store._simple_select_many_batch(
            table="room_memberships",
            column="event_id",
            iterable=member_event_ids.values(),
            retcols=('user_id', 'membership', 'event_id'),
            keyvalues={},
            batch_size=500,
            desc="_get_rules_for_member_event_ids",
        )

        members = {
            row["event_id"]: (row["user_id"], row["membership"])
            for row in rows
        }

        # If the event is a join event then it will be in current state events
        # map but not in the DB, so we have to explicitly insert it.
        if event.type == EventTypes.Member:
            for event_id in member_event_ids.itervalues():
                if event_id == event.event_id:
                    members[event_id] = (event.state_key, event.membership)

        if logger.isEnabledFor(logging.DEBUG):
            logger.debug("Found members %r: %r", self.room_id, members.values())

        interested_in_user_ids = set(
            user_id for user_id, membership in members.itervalues()
            if membership == Membership.JOIN
        )

        logger.debug("Joined: %r", interested_in_user_ids)

        if_users_with_pushers = yield self.store.get_if_users_have_pushers(
            interested_in_user_ids,
            on_invalidate=self.invalidate_all_cb,
        )

        user_ids = set(
            uid for uid, have_pusher in if_users_with_pushers.iteritems() if have_pusher
        )

        logger.debug("With pushers: %r", user_ids)

        users_with_receipts = yield self.store.get_users_with_read_receipts_in_room(
            self.room_id, on_invalidate=self.invalidate_all_cb,
        )

        logger.debug("With receipts: %r", users_with_receipts)

        # any users with pushers must be ours: they have pushers
        for uid in users_with_receipts:
            if uid in interested_in_user_ids:
                user_ids.add(uid)

        rules_by_user = yield self.store.bulk_get_push_rules(
            user_ids, on_invalidate=self.invalidate_all_cb,
        )

        ret_rules_by_user.update(
            item for item in rules_by_user.iteritems() if item[0] is not None
        )

        self.update_cache(sequence, members, ret_rules_by_user, state_group)

    def invalidate_all(self):
        # Note: Don't hand this function directly to an invalidation callback
        # as it keeps a reference to self and will stop this instance from being
        # GC'd if it gets dropped from the rules_to_user cache. Instead use
        # `self.invalidate_all_cb`
        logger.debug("Invalidating RulesForRoom for %r", self.room_id)
        self.sequence += 1
        self.state_group = object()
        self.member_map = {}
        self.rules_by_user = {}
        push_rules_invalidation_counter.inc()

    def update_cache(self, sequence, members, rules_by_user, state_group):
        # Only accept the update if no invalidation happened since `sequence`
        # was captured; otherwise the computed data may be stale.
        if sequence == self.sequence:
            self.member_map.update(members)
            self.rules_by_user = rules_by_user
            self.state_group = state_group
class MediaRepository(object):
    """Stores and serves local and remote media (uploads, downloads,
    thumbnails), writing to a primary media store and, if configured, a
    backup store.
    """

    def __init__(self, hs):
        self.auth = hs.get_auth()
        self.client = MatrixFederationHttpClient(hs)
        self.clock = hs.get_clock()
        self.server_name = hs.hostname
        self.store = hs.get_datastore()
        self.max_upload_size = hs.config.max_upload_size
        self.max_image_pixels = hs.config.max_image_pixels

        self.primary_base_path = hs.config.media_store_path
        self.filepaths = MediaFilePaths(self.primary_base_path)

        # Optional second media store; None-ish disables backup writes.
        self.backup_base_path = hs.config.backup_media_store_path

        # If True, wait for the backup write before completing requests.
        self.synchronous_backup_media_store = hs.config.synchronous_backup_media_store

        self.dynamic_thumbnails = hs.config.dynamic_thumbnails
        self.thumbnail_requirements = hs.config.thumbnail_requirements

        # Serialises operations per (origin, media_id) on remote media, so a
        # download and a deletion of the same item can't interleave.
        self.remote_media_linearizer = Linearizer(name="media_remote")

        self.recently_accessed_remotes = set()

        self.clock.looping_call(self._update_recently_accessed_remotes,
                                UPDATE_RECENTLY_ACCESSED_REMOTES_TS)

    @defer.inlineCallbacks
    def _update_recently_accessed_remotes(self):
        """Periodically flush the accumulated set of recently accessed remote
        media into the DB as last-access timestamps."""
        media = self.recently_accessed_remotes
        self.recently_accessed_remotes = set()

        yield self.store.update_cached_last_access_time(
            media, self.clock.time_msec())

    @staticmethod
    def _makedirs(filepath):
        # Create the parent directory of `filepath` if it doesn't exist yet.
        dirname = os.path.dirname(filepath)
        if not os.path.exists(dirname):
            os.makedirs(dirname)

    @staticmethod
    def _write_file_synchronously(source, fname):
        """Write `source` to the path `fname` synchronously. Should be called
        from a thread.

        Args:
            source: A file like object to be written
            fname (str): Path to write to
        """
        MediaRepository._makedirs(fname)
        source.seek(0)  # Ensure we read from the start of the file
        with open(fname, "wb") as f:
            shutil.copyfileobj(source, f)

    @defer.inlineCallbacks
    def write_to_file_and_backup(self, source, path):
        """Write `source` to the on disk media store, and also the backup store
        if configured.

        Args:
            source: A file like object that should be written
            path (str): Relative path to write file to

        Returns:
            Deferred[str]: the file path written to in the primary media store
        """
        fname = os.path.join(self.primary_base_path, path)

        # Write to the main repository
        yield make_deferred_yieldable(threads.deferToThread(
            self._write_file_synchronously, source, fname,
        ))

        # Write to backup repository
        yield self.copy_to_backup(path)

        defer.returnValue(fname)

    @defer.inlineCallbacks
    def copy_to_backup(self, path):
        """Copy a file from the primary to backup media store, if configured.

        Args:
            path(str): Relative path to write file to
        """
        if self.backup_base_path:
            primary_fname = os.path.join(self.primary_base_path, path)
            backup_fname = os.path.join(self.backup_base_path, path)

            # We can either wait for successful writing to the backup repository
            # or write in the background and immediately return
            if self.synchronous_backup_media_store:
                yield make_deferred_yieldable(threads.deferToThread(
                    shutil.copyfile, primary_fname, backup_fname,
                ))
            else:
                # Fire-and-forget copy; errors are not propagated to callers.
                preserve_fn(threads.deferToThread)(
                    shutil.copyfile, primary_fname, backup_fname,
                )

    @defer.inlineCallbacks
    def create_content(self, media_type, upload_name, content, content_length,
                       auth_user):
        """Store uploaded content for a local user and return the mxc URL

        Args:
            media_type(str): The content type of the file
            upload_name(str): The name of the file
            content: A file like object that is the content to store
            content_length(int): The length of the content
            auth_user(str): The user_id of the uploader

        Returns:
            Deferred[str]: The mxc url of the stored content
        """
        media_id = random_string(24)

        fname = yield self.write_to_file_and_backup(
            content, self.filepaths.local_media_filepath_rel(media_id))

        logger.info("Stored local media in file %r", fname)

        yield self.store.store_local_media(
            media_id=media_id,
            media_type=media_type,
            time_now_ms=self.clock.time_msec(),
            upload_name=upload_name,
            media_length=content_length,
            user_id=auth_user,
        )
        media_info = {
            "media_type": media_type,
            "media_length": content_length,
        }

        yield self._generate_thumbnails(None, media_id, media_info)

        defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))

    @defer.inlineCallbacks
    def get_remote_media(self, server_name, media_id):
        """Fetch (and cache) media from a remote server, serialised per
        (server_name, media_id) via the linearizer."""
        key = (server_name, media_id)
        with (yield self.remote_media_linearizer.queue(key)):
            media_info = yield self._get_remote_media_impl(
                server_name, media_id)
        defer.returnValue(media_info)

    @defer.inlineCallbacks
    def _get_remote_media_impl(self, server_name, media_id):
        """Return cached remote media info, downloading it if not yet cached.

        Raises:
            NotFoundError: if the media has been quarantined.
        """
        media_info = yield self.store.get_cached_remote_media(
            server_name, media_id)
        if not media_info:
            media_info = yield self._download_remote_file(
                server_name, media_id)
        elif media_info["quarantined_by"]:
            raise NotFoundError()
        else:
            # Record the access; the periodic loop batches these, but we also
            # update immediately so the timestamp is fresh.
            self.recently_accessed_remotes.add((server_name, media_id))
            yield self.store.update_cached_last_access_time(
                [(server_name, media_id)], self.clock.time_msec())
        defer.returnValue(media_info)

    @defer.inlineCallbacks
    def _download_remote_file(self, server_name, media_id):
        """Download media from a remote homeserver, store it and generate
        thumbnails.

        Returns:
            Deferred[dict]: media info dict for the downloaded file.

        Raises:
            NotFoundError: on DNS failure or a remote 404.
            SynapseError: 502 for other fetch failures.
        """
        file_id = random_string(24)

        fpath = self.filepaths.remote_media_filepath_rel(server_name, file_id)
        fname = os.path.join(self.primary_base_path, fpath)
        self._makedirs(fname)

        try:
            with open(fname, "wb") as f:
                request_path = "/".join((
                    "/_matrix/media/v1/download", server_name, media_id,
                ))
                try:
                    length, headers = yield self.client.get_file(
                        server_name, request_path, output_stream=f,
                        max_size=self.max_upload_size, args={
                            # tell the remote server to 404 if it doesn't
                            # recognise the server_name, to make sure we don't
                            # end up with a routing loop.
                            "allow_remote": "false",
                        }
                    )
                except twisted.internet.error.DNSLookupError as e:
                    logger.warn("HTTP error fetching remote media %s/%s: %r",
                                server_name, media_id, e)
                    raise NotFoundError()
                except HttpResponseException as e:
                    logger.warn("HTTP error fetching remote media %s/%s: %s",
                                server_name, media_id, e.response)
                    if e.code == twisted.web.http.NOT_FOUND:
                        raise SynapseError.from_http_response_exception(e)
                    raise SynapseError(502, "Failed to fetch remote media")
                except SynapseError:
                    logger.exception("Failed to fetch remote media %s/%s",
                                     server_name, media_id)
                    raise
                except NotRetryingDestination:
                    logger.warn("Not retrying destination %r", server_name)
                    raise SynapseError(502, "Failed to fetch remote media")
                except Exception:
                    logger.exception("Failed to fetch remote media %s/%s",
                                     server_name, media_id)
                    raise SynapseError(502, "Failed to fetch remote media")

            yield self.copy_to_backup(fpath)

            media_type = headers["Content-Type"][0]
            time_now_ms = self.clock.time_msec()

            # Work out an upload name from the Content-Disposition header,
            # preferring the RFC 5987 "filename*" UTF-8 form.
            content_disposition = headers.get("Content-Disposition", None)
            if content_disposition:
                _, params = cgi.parse_header(content_disposition[0],)
                upload_name = None

                # First check if there is a valid UTF-8 filename
                upload_name_utf8 = params.get("filename*", None)
                if upload_name_utf8:
                    if upload_name_utf8.lower().startswith("utf-8''"):
                        upload_name = upload_name_utf8[7:]

                # If there isn't check for an ascii name.
                if not upload_name:
                    upload_name_ascii = params.get("filename", None)
                    if upload_name_ascii and is_ascii(upload_name_ascii):
                        upload_name = upload_name_ascii

                if upload_name:
                    upload_name = urlparse.unquote(upload_name)
                    try:
                        upload_name = upload_name.decode("utf-8")
                    except UnicodeDecodeError:
                        upload_name = None
            else:
                upload_name = None

            logger.info("Stored remote media in file %r", fname)

            yield self.store.store_cached_remote_media(
                origin=server_name,
                media_id=media_id,
                media_type=media_type,
                time_now_ms=self.clock.time_msec(),
                upload_name=upload_name,
                media_length=length,
                filesystem_id=file_id,
            )
        except Exception:
            # Don't leave a partial/corrupt file on disk if anything failed.
            os.remove(fname)
            raise

        media_info = {
            "media_type": media_type,
            "media_length": length,
            "upload_name": upload_name,
            "created_ts": time_now_ms,
            "filesystem_id": file_id,
        }

        yield self._generate_thumbnails(server_name, media_id, media_info)

        defer.returnValue(media_info)

    def _get_thumbnail_requirements(self, media_type):
        # Returns the configured thumbnail specs for this content type, or an
        # empty tuple if the type isn't thumbnailable.
        return self.thumbnail_requirements.get(media_type, ())

    def _generate_thumbnail(self, thumbnailer, t_width, t_height,
                            t_method, t_type):
        """Produce one thumbnail's bytes (blocking; run in a thread).

        Returns the byte source from the thumbnailer, or None if the image is
        too large or the method is unrecognised.
        """
        m_width = thumbnailer.width
        m_height = thumbnailer.height

        if m_width * m_height >= self.max_image_pixels:
            logger.info(
                "Image too large to thumbnail %r x %r > %r",
                m_width, m_height, self.max_image_pixels
            )
            return

        if t_method == "crop":
            t_byte_source = thumbnailer.crop(t_width, t_height, t_type)
        elif t_method == "scale":
            # Never scale up beyond the source dimensions.
            t_width, t_height = thumbnailer.aspect(t_width, t_height)
            t_width = min(m_width, t_width)
            t_height = min(m_height, t_height)
            t_byte_source = thumbnailer.scale(t_width, t_height, t_type)
        else:
            t_byte_source = None

        return t_byte_source

    @defer.inlineCallbacks
    def generate_local_exact_thumbnail(self, media_id, t_width, t_height,
                                       t_method, t_type):
        """Generate and store a single thumbnail of exactly the requested
        dimensions for local media; returns the output path (or None)."""
        input_path = self.filepaths.local_media_filepath(media_id)

        thumbnailer = Thumbnailer(input_path)
        t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
            self._generate_thumbnail,
            thumbnailer, t_width, t_height, t_method, t_type
        ))

        if t_byte_source:
            try:
                output_path = yield self.write_to_file_and_backup(
                    t_byte_source,
                    self.filepaths.local_media_thumbnail_rel(
                        media_id, t_width, t_height, t_type, t_method
                    )
                )
            finally:
                t_byte_source.close()

            logger.info("Stored thumbnail in file %r", output_path)

            t_len = os.path.getsize(output_path)

            yield self.store.store_local_thumbnail(
                media_id, t_width, t_height, t_type, t_method, t_len
            )

            defer.returnValue(output_path)

    @defer.inlineCallbacks
    def generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
                                        t_width, t_height, t_method, t_type):
        """Generate and store a single thumbnail of exactly the requested
        dimensions for remote media; returns the output path (or None)."""
        input_path = self.filepaths.remote_media_filepath(server_name, file_id)

        thumbnailer = Thumbnailer(input_path)
        t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
            self._generate_thumbnail,
            thumbnailer, t_width, t_height, t_method, t_type
        ))

        if t_byte_source:
            try:
                output_path = yield self.write_to_file_and_backup(
                    t_byte_source,
                    self.filepaths.remote_media_thumbnail_rel(
                        server_name, file_id, t_width, t_height, t_type, t_method
                    )
                )
            finally:
                t_byte_source.close()

            logger.info("Stored thumbnail in file %r", output_path)

            t_len = os.path.getsize(output_path)

            yield self.store.store_remote_media_thumbnail(
                server_name, media_id, file_id,
                t_width, t_height, t_type, t_method, t_len
            )

            defer.returnValue(output_path)

    @defer.inlineCallbacks
    def _generate_thumbnails(self, server_name, media_id, media_info, url_cache=False):
        """Generate and store thumbnails for an image.

        Args:
            server_name(str|None): The server name if remote media, else None if local
            media_id(str)
            media_info(dict)
            url_cache(bool): If we are thumbnailing images downloaded for the URL cache,
                used exclusively by the url previewer

        Returns:
            Deferred[dict]: Dict with "width" and "height" keys of original image,
                or None if no thumbnails were required or the image was too large.
        """
        media_type = media_info["media_type"]
        file_id = media_info.get("filesystem_id")
        requirements = self._get_thumbnail_requirements(media_type)
        if not requirements:
            return

        if server_name:
            input_path = self.filepaths.remote_media_filepath(server_name, file_id)
        elif url_cache:
            input_path = self.filepaths.url_cache_filepath(media_id)
        else:
            input_path = self.filepaths.local_media_filepath(media_id)

        thumbnailer = Thumbnailer(input_path)
        m_width = thumbnailer.width
        m_height = thumbnailer.height

        if m_width * m_height >= self.max_image_pixels:
            logger.info(
                "Image too large to thumbnail %r x %r > %r",
                m_width, m_height, self.max_image_pixels
            )
            return

        # We deduplicate the thumbnail sizes by ignoring the cropped versions if
        # they have the same dimensions of a scaled one.
        thumbnails = {}
        for r_width, r_height, r_method, r_type in requirements:
            if r_method == "crop":
                thumbnails.setdefault((r_width, r_height, r_type), r_method)
            elif r_method == "scale":
                t_width, t_height = thumbnailer.aspect(r_width, r_height)
                t_width = min(m_width, t_width)
                t_height = min(m_height, t_height)
                thumbnails[(t_width, t_height, r_type)] = r_method

        # Now we generate the thumbnails for each dimension, store it
        for (t_width, t_height, t_type), t_method in thumbnails.iteritems():
            # Work out the correct file name for thumbnail
            if server_name:
                file_path = self.filepaths.remote_media_thumbnail_rel(
                    server_name, file_id, t_width, t_height, t_type, t_method)
            elif url_cache:
                file_path = self.filepaths.url_cache_thumbnail_rel(
                    media_id, t_width, t_height, t_type, t_method)
            else:
                file_path = self.filepaths.local_media_thumbnail_rel(
                    media_id, t_width, t_height, t_type, t_method)

            # Generate the thumbnail
            if t_method == "crop":
                t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
                    thumbnailer.crop,
                    t_width, t_height, t_type,
                ))
            elif t_method == "scale":
                t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
                    thumbnailer.scale,
                    t_width, t_height, t_type,
                ))
            else:
                logger.error("Unrecognized method: %r", t_method)
                continue

            if not t_byte_source:
                continue

            try:
                # Write to disk
                output_path = yield self.write_to_file_and_backup(
                    t_byte_source, file_path,
                )
            finally:
                t_byte_source.close()

            t_len = os.path.getsize(output_path)

            # Write to database
            if server_name:
                yield self.store.store_remote_media_thumbnail(
                    server_name, media_id, file_id,
                    t_width, t_height, t_type, t_method, t_len
                )
            else:
                yield self.store.store_local_thumbnail(
                    media_id, t_width, t_height, t_type, t_method, t_len
                )

        defer.returnValue({
            "width": m_width,
            "height": m_height,
        })

    @defer.inlineCallbacks
    def delete_old_remote_media(self, before_ts):
        """Delete cached remote media last accessed before `before_ts`.

        Returns:
            Deferred[dict]: {"deleted": <number of items removed>}
        """
        old_media = yield self.store.get_remote_media_before(before_ts)

        deleted = 0

        for media in old_media:
            origin = media["media_origin"]
            media_id = media["media_id"]
            file_id = media["filesystem_id"]
            key = (origin, media_id)

            logger.info("Deleting: %r", key)

            # TODO: Should we delete from the backup store

            with (yield self.remote_media_linearizer.queue(key)):
                full_path = self.filepaths.remote_media_filepath(origin, file_id)
                try:
                    os.remove(full_path)
                except OSError as e:
                    logger.warn("Failed to remove file: %r", full_path)
                    if e.errno == errno.ENOENT:
                        # Already gone; still clean up thumbnails and the DB row.
                        pass
                    else:
                        continue

                thumbnail_dir = self.filepaths.remote_media_thumbnail_dir(
                    origin, file_id
                )
                shutil.rmtree(thumbnail_dir, ignore_errors=True)

                yield self.store.delete_remote_media(origin, media_id)
                deleted += 1

        defer.returnValue({"deleted": deleted})
class RegistrationHandler(BaseHandler):
    """Handles registering new users, both interactively and on behalf of
    application services, including username validation, captcha checks and
    auto-joining configured rooms.
    """

    def __init__(self, hs):
        super(RegistrationHandler, self).__init__(hs)

        self.auth = hs.get_auth()
        self._auth_handler = hs.get_auth_handler()
        self.profile_handler = hs.get_profile_handler()
        self.user_directory_handler = hs.get_user_directory_handler()
        self.captcha_client = CaptchaServerHttpClient(hs)

        self._next_generated_user_id = None

        self.macaroon_gen = hs.get_macaroon_generator()

        # Serialises allocation of sequential numeric user IDs.
        self._generate_user_id_linearizer = Linearizer(
            name="_generate_user_id_linearizer",
        )

    @defer.inlineCallbacks
    def check_username(self, localpart, guest_access_token=None,
                       assigned_user_id=None):
        """Validate a requested localpart and check it isn't already taken.

        Raises:
            SynapseError: if the localpart is invalid or already registered.
            AuthError: if an in-use ID is requested without matching guest
                credentials.
        """
        if types.contains_invalid_mxid_characters(localpart):
            raise SynapseError(
                400,
                "User ID can only contain characters a-z, 0-9, or '=_-./'",
                Codes.INVALID_USERNAME
            )

        if not localpart:
            raise SynapseError(
                400,
                "User ID cannot be empty",
                Codes.INVALID_USERNAME
            )

        if localpart[0] == '_':
            raise SynapseError(
                400,
                "User ID may not begin with _",
                Codes.INVALID_USERNAME
            )

        user = UserID(localpart, self.hs.hostname)
        user_id = user.to_string()

        if assigned_user_id:
            if user_id == assigned_user_id:
                return
            else:
                raise SynapseError(
                    400,
                    "A different user ID has already been registered for this session",
                )

        self.check_user_id_not_appservice_exclusive(user_id)

        users = yield self.store.get_users_by_id_case_insensitive(user_id)
        if users:
            if not guest_access_token:
                raise SynapseError(
                    400,
                    "User ID already taken.",
                    errcode=Codes.USER_IN_USE,
                )
            # A guest may upgrade to a full account keeping the same localpart,
            # but only with a valid token for that same guest user.
            user_data = yield self.auth.get_user_by_access_token(guest_access_token)
            if not user_data["is_guest"] or user_data["user"].localpart != localpart:
                raise AuthError(
                    403,
                    "Cannot register taken user ID without valid guest "
                    "credentials for that user.",
                    errcode=Codes.FORBIDDEN,
                )

    @defer.inlineCallbacks
    def register(
        self,
        localpart=None,
        password=None,
        generate_token=True,
        guest_access_token=None,
        make_guest=False,
        admin=False,
    ):
        """Registers a new client on the server.

        Args:
            localpart : The local part of the user ID to register. If None,
              one will be generated.
            password (str) : The password to assign to this user so they can
              login again. This can be None which means they cannot login again
              via a password (e.g. the user is an application service user).
            generate_token (bool): Whether a new access token should be
              generated. Having this be True should be considered deprecated,
              since it offers no means of associating a device_id with the
              access_token. Instead you should call auth_handler.issue_access_token
              after registration.
        Returns:
            A tuple of (user_id, access_token).
        Raises:
            RegistrationError if there was a problem registering.
        """
        yield run_on_reactor()
        password_hash = None
        if password:
            password_hash = yield self.auth_handler().hash(password)

        if localpart:
            yield self.check_username(localpart, guest_access_token=guest_access_token)

            was_guest = guest_access_token is not None

            if not was_guest:
                try:
                    int(localpart)
                    raise RegistrationError(
                        400,
                        "Numeric user IDs are reserved for guest users."
                    )
                except ValueError:
                    pass

            user = UserID(localpart, self.hs.hostname)
            user_id = user.to_string()

            token = None
            if generate_token:
                token = self.macaroon_gen.generate_access_token(user_id)
            yield self.store.register(
                user_id=user_id,
                token=token,
                password_hash=password_hash,
                was_guest=was_guest,
                make_guest=make_guest,
                create_profile_with_localpart=(
                    # If the user was a guest then they already have a profile
                    None if was_guest else user.localpart
                ),
                admin=admin,
            )

            if self.hs.config.user_directory_search_all_users:
                profile = yield self.store.get_profileinfo(localpart)
                yield self.user_directory_handler.handle_local_profile_change(
                    user_id, profile
                )

        else:
            # autogen a sequential user ID
            attempts = 0
            token = None
            user = None
            while not user:
                localpart = yield self._generate_user_id(attempts > 0)
                user = UserID(localpart, self.hs.hostname)
                user_id = user.to_string()
                yield self.check_user_id_not_appservice_exclusive(user_id)
                if generate_token:
                    token = self.macaroon_gen.generate_access_token(user_id)
                try:
                    yield self.store.register(
                        user_id=user_id,
                        token=token,
                        password_hash=password_hash,
                        make_guest=make_guest,
                        create_profile_with_localpart=user.localpart,
                    )
                except SynapseError:
                    # if user id is taken, just generate another
                    user = None
                    user_id = None
                    token = None
                    attempts += 1

        # auto-join the user to any rooms we're supposed to dump them into
        fake_requester = create_requester(user_id)
        for r in self.hs.config.auto_join_rooms:
            try:
                yield self._join_user_to_room(fake_requester, r)
            except Exception as e:
                logger.error("Failed to join new user to %r: %r", r, e)

        # We used to generate default identicons here, but nowadays
        # we want clients to generate their own as part of their branding
        # rather than there being consistent matrix-wide ones, so we don't.
        defer.returnValue((user_id, token))

    @defer.inlineCallbacks
    def appservice_register(self, user_localpart, as_token):
        """Register a user on behalf of an application service.

        Raises:
            AuthError: if the AS token is invalid.
            SynapseError: if the localpart is outside the AS's namespaces.
        """
        user = UserID(user_localpart, self.hs.hostname)
        user_id = user.to_string()
        service = self.store.get_app_service_by_token(as_token)
        if not service:
            raise AuthError(403, "Invalid application service token.")
        if not service.is_interested_in_user(user_id):
            raise SynapseError(
                400, "Invalid user localpart for this application service.",
                errcode=Codes.EXCLUSIVE
            )

        service_id = service.id if service.is_exclusive_user(user_id) else None

        yield self.check_user_id_not_appservice_exclusive(
            user_id, allowed_appservice=service
        )

        yield self.store.register(
            user_id=user_id,
            password_hash="",
            appservice_id=service_id,
            create_profile_with_localpart=user.localpart,
        )
        defer.returnValue(user_id)

    @defer.inlineCallbacks
    def check_recaptcha(self, ip, private_key, challenge, response):
        """
        Checks a recaptcha is correct.

        Used only by c/s api v1
        """

        captcha_response = yield self._validate_captcha(
            ip,
            private_key,
            challenge,
            response
        )
        if not captcha_response["valid"]:
            logger.info("Invalid captcha entered from %s. Error: %s",
                        ip, captcha_response["error_url"])
            raise InvalidCaptchaError(
                error_url=captcha_response["error_url"]
            )
        else:
            logger.info("Valid captcha entered from %s", ip)

    @defer.inlineCallbacks
    def register_saml2(self, localpart):
        """
        Registers email_id as SAML2 Based Auth.
        """
        if types.contains_invalid_mxid_characters(localpart):
            raise SynapseError(
                400,
                "User ID can only contain characters a-z, 0-9, or '=_-./'",
            )
        user = UserID(localpart, self.hs.hostname)
        user_id = user.to_string()

        yield self.check_user_id_not_appservice_exclusive(user_id)
        token = self.macaroon_gen.generate_access_token(user_id)
        try:
            yield self.store.register(
                user_id=user_id,
                token=token,
                password_hash=None,
                create_profile_with_localpart=user.localpart,
            )
        except Exception as e:
            # NOTE(review): on registration failure this still records the
            # access token and swallows the error — looks suspect, but is
            # preserved as-is; confirm intent before changing.
            yield self.store.add_access_token_to_user(user_id, token)
            # Ignore Registration errors
            logger.exception(e)
        defer.returnValue((user_id, token))

    @defer.inlineCallbacks
    def register_email(self, threepidCreds):
        """
        Registers emails with an identity server.

        Used only by c/s api v1
        """

        for c in threepidCreds:
            logger.info("validating threepidcred sid %s on id server %s",
                        c['sid'], c['idServer'])
            try:
                identity_handler = self.hs.get_handlers().identity_handler
                threepid = yield identity_handler.threepid_from_creds(c)
            except Exception:
                logger.exception("Couldn't validate 3pid")
                raise RegistrationError(400, "Couldn't validate 3pid")

            if not threepid:
                raise RegistrationError(400, "Couldn't validate 3pid")
            logger.info("got threepid with medium '%s' and address '%s'",
                        threepid['medium'], threepid['address'])

            if not check_3pid_allowed(self.hs, threepid['medium'], threepid['address']):
                raise RegistrationError(
                    403, "Third party identifier is not allowed"
                )

    @defer.inlineCallbacks
    def bind_emails(self, user_id, threepidCreds):
        """Links emails with a user ID and informs an identity server.

        Used only by c/s api v1
        """

        # Now we have a matrix ID, bind it to the threepids we were given
        for c in threepidCreds:
            identity_handler = self.hs.get_handlers().identity_handler
            # XXX: This should be a deferred list, shouldn't it?
            yield identity_handler.bind_threepid(c, user_id)

    def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None):
        """Raise if `user_id` falls in an exclusive AS namespace (other than
        `allowed_appservice`'s own)."""
        # valid user IDs must not clash with any user ID namespaces claimed by
        # application services.
        services = self.store.get_app_services()
        interested_services = [
            s for s in services
            if s.is_interested_in_user(user_id)
            and s != allowed_appservice
        ]
        for service in interested_services:
            if service.is_exclusive_user(user_id):
                raise SynapseError(
                    400, "This user ID is reserved by an application service.",
                    errcode=Codes.EXCLUSIVE
                )

    @defer.inlineCallbacks
    def _generate_user_id(self, reseed=False):
        """Return the next sequential numeric localpart as a string.

        The linearizer plus double-check ensures only one caller reseeds from
        the DB at a time.
        """
        if reseed or self._next_generated_user_id is None:
            with (yield self._generate_user_id_linearizer.queue(())):
                if reseed or self._next_generated_user_id is None:
                    self._next_generated_user_id = (
                        yield self.store.find_next_generated_user_id_localpart()
                    )

        id = self._next_generated_user_id
        self._next_generated_user_id += 1
        defer.returnValue(str(id))

    @defer.inlineCallbacks
    def _validate_captcha(self, ip_addr, private_key, challenge, response):
        """Validates the captcha provided.

        Used only by c/s api v1

        Returns:
            dict: Containing 'valid'(bool) and 'error_url'(str) if invalid.

        """
        response = yield self._submit_captcha(ip_addr, private_key, challenge,
                                              response)
        # parse Google's response. Lovely format..
        lines = response.split('\n')
        json = {
            "valid": lines[0] == 'true',
            "error_url": "http://www.google.com/recaptcha/api/challenge?" +
                         "error=%s" % lines[1]
        }
        defer.returnValue(json)

    @defer.inlineCallbacks
    def _submit_captcha(self, ip_addr, private_key, challenge, response):
        """
        Used only by c/s api v1
        """
        data = yield self.captcha_client.post_urlencoded_get_raw(
            "http://www.google.com:80/recaptcha/api/verify",
            args={
                'privatekey': private_key,
                'remoteip': ip_addr,
                'challenge': challenge,
                'response': response
            }
        )
        defer.returnValue(data)

    @defer.inlineCallbacks
    def get_or_create_user(self, requester, localpart, displayname,
                           password_hash=None):
        """Creates a new user if the user does not exist,
        else revokes all previous access tokens and generates a new one.

        Args:
            localpart : The local part of the user ID to register. If None,
              one will be randomly generated.
        Returns:
            A tuple of (user_id, access_token).
        Raises:
            RegistrationError if there was a problem registering.
        """
        yield run_on_reactor()

        if localpart is None:
            raise SynapseError(400, "Request must include user id")

        need_register = True

        try:
            yield self.check_username(localpart)
        except SynapseError as e:
            if e.errcode == Codes.USER_IN_USE:
                need_register = False
            else:
                raise

        user = UserID(localpart, self.hs.hostname)
        user_id = user.to_string()
        token = self.macaroon_gen.generate_access_token(user_id)

        if need_register:
            yield self.store.register(
                user_id=user_id,
                token=token,
                password_hash=password_hash,
                create_profile_with_localpart=user.localpart,
            )
        else:
            yield self._auth_handler.delete_access_tokens_for_user(user_id)
            yield self.store.add_access_token_to_user(user_id=user_id, token=token)

        if displayname is not None:
            logger.info("setting user display name: %s -> %s", user_id, displayname)
            yield self.profile_handler.set_displayname(
                user, requester, displayname, by_admin=True,
            )

        defer.returnValue((user_id, token))

    def auth_handler(self):
        return self.hs.get_auth_handler()

    @defer.inlineCallbacks
    def get_or_register_3pid_guest(self, medium, address, inviter_user_id):
        """Get a guest access token for a 3PID, creating a guest account if
        one doesn't already exist.

        Args:
            medium (str)
            address (str)
            inviter_user_id (str): The user ID who is trying to invite the
                3PID

        Returns:
            Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
                3PID guest account.
        """
        access_token = yield self.store.get_3pid_guest_access_token(medium, address)
        if access_token:
            user_info = yield self.auth.get_user_by_access_token(
                access_token
            )

            defer.returnValue((user_info["user"].to_string(), access_token))

        user_id, access_token = yield self.register(
            generate_token=True,
            make_guest=True
        )
        access_token = yield self.store.save_or_get_3pid_guest_access_token(
            medium, address, access_token, inviter_user_id
        )

        defer.returnValue((user_id, access_token))

    @defer.inlineCallbacks
    def _join_user_to_room(self, requester, room_identifier):
        """Join `requester` to the room given by a room ID or room alias."""
        room_id = None
        # BUGFIX: remote_room_hosts was previously only bound in the alias
        # branch, so joining a configured plain room ID raised NameError
        # below. Default to None (update_membership treats it as "no hints").
        remote_room_hosts = None
        room_member_handler = self.hs.get_room_member_handler()
        if RoomID.is_valid(room_identifier):
            room_id = room_identifier
        elif RoomAlias.is_valid(room_identifier):
            room_alias = RoomAlias.from_string(room_identifier)
            room_id, remote_room_hosts = (
                yield room_member_handler.lookup_room_alias(room_alias)
            )
            room_id = room_id.to_string()
        else:
            raise SynapseError(400, "%s was not legal room ID or room alias" % (
                room_identifier,
            ))

        yield room_member_handler.update_membership(
            requester=requester,
            target=requester.user,
            room_id=room_id,
            remote_room_hosts=remote_room_hosts,
            action="join",
        )
class StateHandler(object):
    """ Responsible for doing state conflict resolution.

    Resolves the room state at arbitrary sets of events by fetching the
    state groups of their parents and running the (v1) state resolution
    algorithm over any conflicts, caching the results.
    """

    def __init__(self, hs):
        self.clock = hs.get_clock()
        self.store = hs.get_datastore()
        self.hs = hs

        # dict of set of event_ids -> _StateCacheEntry.
        # Populated lazily by start_caching(); None until then.
        self._state_cache = None
        # Serialises concurrent resolutions of the same set of state groups.
        self.resolve_linearizer = Linearizer()

    def start_caching(self):
        """Create and start the expiring state-resolution cache."""
        logger.debug("start_caching")

        self._state_cache = ExpiringCache(
            cache_name="state_cache",
            clock=self.clock,
            max_len=SIZE_OF_CACHE,
            expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
            # Touching an entry keeps it alive.
            reset_expiry_on_get=True,
        )

        self._state_cache.start()

    @defer.inlineCallbacks
    def get_current_state(self, room_id, event_type=None, state_key="",
                          latest_event_ids=None):
        """ Retrieves the current state for the room. This is done by
        calling `get_latest_events_in_room` to get the leading edges of the
        event graph and then resolving any of the state conflicts.

        This is equivalent to getting the state of an event that were to send
        next before receiving any new events.

        If `event_type` is specified, then the method returns only the one
        event (or None) with that `event_type` and `state_key`.

        Returns:
            map from (type, state_key) to event
        """
        if not latest_event_ids:
            latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)

        ret = yield self.resolve_state_groups(room_id, latest_event_ids)
        state = ret.state

        if event_type:
            # Single-event lookup: fetch just that event (or None).
            event_id = state.get((event_type, state_key))
            event = None
            if event_id:
                event = yield self.store.get_event(event_id, allow_none=True)
            defer.returnValue(event)
            return

        # Full state: bulk-fetch the events and drop any we couldn't load.
        state_map = yield self.store.get_events(state.values(),
                                                get_prev_content=False)
        state = {
            key: state_map[e_id] for key, e_id in state.items() if e_id in state_map
        }

        defer.returnValue(state)

    @defer.inlineCallbacks
    def get_current_state_ids(self, room_id, event_type=None, state_key="",
                              latest_event_ids=None):
        """Like get_current_state, but returns event IDs rather than events.

        If `event_type` is given, returns the single event ID (or None) for
        that (event_type, state_key); otherwise returns the full map from
        (type, state_key) to event_id.
        """
        if not latest_event_ids:
            latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)

        ret = yield self.resolve_state_groups(room_id, latest_event_ids)
        state = ret.state

        if event_type:
            defer.returnValue(state.get((event_type, state_key)))
            return

        defer.returnValue(state)

    @defer.inlineCallbacks
    def get_current_user_in_room(self, room_id, latest_event_ids=None):
        """Resolve current state and return the room's joined users.

        Returns:
            Deferred: the result of store.get_joined_users_from_state —
            presumably a map describing currently-joined users; confirm the
            exact shape against the datastore implementation.
        """
        if not latest_event_ids:
            latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
        entry = yield self.resolve_state_groups(room_id, latest_event_ids)
        joined_users = yield self.store.get_joined_users_from_state(
            room_id, entry.state_id, entry.state
        )
        defer.returnValue(joined_users)

    @defer.inlineCallbacks
    def compute_event_context(self, event, old_state=None):
        """ Fills out the context with the `current state` of the graph. The
        `current state` here is defined to be the state of the event graph
        just before the event - i.e. it never includes `event`

        If `event` has `auth_events` then this will also fill out the
        `auth_events` field on `context` from the `current_state`.

        Args:
            event (EventBase)
            old_state (iterable|None): if given, used as the prior state
                instead of resolving from the event's prev_events.
        Returns:
            an EventContext
        """
        context = EventContext()

        if event.internal_metadata.is_outlier():
            # If this is an outlier, then we know it shouldn't have any current
            # state. Certainly store.get_current_state won't return any, and
            # persisting the event won't store the state group.
            if old_state:
                context.prev_state_ids = {
                    (s.type, s.state_key): s.event_id for s in old_state
                }
                if event.is_state():
                    # NOTE(review): this branch writes current_state_events
                    # whereas every other branch writes current_state_ids —
                    # looks inconsistent; confirm against EventContext users.
                    context.current_state_events = dict(context.prev_state_ids)
                    key = (event.type, event.state_key)
                    context.current_state_events[key] = event.event_id
                else:
                    context.current_state_events = context.prev_state_ids
            else:
                context.current_state_ids = {}
                context.prev_state_ids = {}
            context.prev_state_events = []
            context.state_group = self.store.get_next_state_group()
            defer.returnValue(context)

        if old_state:
            # Caller supplied the prior state explicitly: no resolution needed.
            context.prev_state_ids = {
                (s.type, s.state_key): s.event_id for s in old_state
            }
            context.state_group = self.store.get_next_state_group()

            if event.is_state():
                key = (event.type, event.state_key)
                if key in context.prev_state_ids:
                    replaces = context.prev_state_ids[key]
                    if replaces != event.event_id:  # Paranoia check
                        event.unsigned["replaces_state"] = replaces
                context.current_state_ids = dict(context.prev_state_ids)
                context.current_state_ids[key] = event.event_id
            else:
                context.current_state_ids = context.prev_state_ids

            context.prev_state_events = []
            defer.returnValue(context)

        # General case: resolve state across the event's prev_events.
        if event.is_state():
            entry = yield self.resolve_state_groups(
                event.room_id, [e for e, _ in event.prev_events],
                event_type=event.type,
                state_key=event.state_key,
            )
        else:
            entry = yield self.resolve_state_groups(
                event.room_id, [e for e, _ in event.prev_events],
            )

        curr_state = entry.state

        context.prev_state_ids = curr_state
        if event.is_state():
            # This state event starts a new state group; record what it
            # replaces and extend the delta against the previous group.
            context.state_group = self.store.get_next_state_group()

            key = (event.type, event.state_key)
            if key in context.prev_state_ids:
                replaces = context.prev_state_ids[key]
                event.unsigned["replaces_state"] = replaces

            context.current_state_ids = dict(context.prev_state_ids)
            context.current_state_ids[key] = event.event_id

            context.prev_group = entry.prev_group
            context.delta_ids = entry.delta_ids
            if context.delta_ids is not None:
                # Copy before mutating: the entry may be cached/shared.
                context.delta_ids = dict(context.delta_ids)
                context.delta_ids[key] = event.event_id
        else:
            # Non-state event: reuse (or allocate) the resolved state group.
            if entry.state_group is None:
                entry.state_group = self.store.get_next_state_group()
                entry.state_id = entry.state_group

            context.state_group = entry.state_group
            context.current_state_ids = context.prev_state_ids
            context.prev_group = entry.prev_group
            context.delta_ids = entry.delta_ids

        context.prev_state_events = []
        defer.returnValue(context)

    @defer.inlineCallbacks
    @log_function
    def resolve_state_groups(self, room_id, event_ids, event_type=None, state_key=""):
        """ Given a list of event_ids this method fetches the state at each
        event, resolves conflicts between them and returns them.

        Returns:
            Deferred[_StateCacheEntry]: the resolved state, carrying a map
            from (type, state_key) to event_id plus the state group (if one
            and only one is involved) and the prev_group/delta_ids pair.
        """
        logger.debug("resolve_state_groups event_ids %s", event_ids)

        state_groups_ids = yield self.store.get_state_groups_ids(
            room_id, event_ids
        )

        logger.debug(
            "resolve_state_groups state_groups %s",
            state_groups_ids.keys()
        )

        group_names = frozenset(state_groups_ids.keys())
        if len(group_names) == 1:
            # Only one state group involved: nothing to resolve.
            # (py2 idiom: dict.items() returns a list, so .pop() is valid.)
            name, state_list = state_groups_ids.items().pop()

            defer.returnValue(_StateCacheEntry(
                state=state_list,
                state_group=name,
                prev_group=name,
                delta_ids={},
            ))

        # Serialise concurrent resolutions of the same group set so the
        # cache check below is not racy.
        with (yield self.resolve_linearizer.queue(group_names)):
            if self._state_cache is not None:
                cache = self._state_cache.get(group_names, None)
                if cache:
                    defer.returnValue(cache)

            logger.info(
                "Resolving state for %s with %d groups", room_id, len(state_groups_ids)
            )

            # Collect, per (type, state_key), the set of candidate event_ids.
            state = {}
            for st in state_groups_ids.values():
                for key, e_id in st.items():
                    state.setdefault(key, set()).add(e_id)

            # Keys with more than one candidate are in conflict.
            conflicted_state = {
                k: list(v)
                for k, v in state.items()
                if len(v) > 1
            }

            if conflicted_state:
                logger.info("Resolving conflicted state for %r", room_id)
                # Need full events to run the resolution algorithm.
                state_map = yield self.store.get_events(
                    [e_id for st in state_groups_ids.values() for e_id in st.values()],
                    get_prev_content=False
                )
                state_sets = [
                    [state_map[e_id] for key, e_id in st.items() if e_id in state_map]
                    for st in state_groups_ids.values()
                ]
                new_state, _ = self._resolve_events(
                    state_sets, event_type, state_key
                )
                new_state = {
                    key: e.event_id for key, e in new_state.items()
                }
            else:
                # No conflicts: each key has exactly one candidate.
                new_state = {
                    key: e_ids.pop() for key, e_ids in state.items()
                }

            # If the resolved state exactly matches an existing group, reuse
            # that group's id rather than allocating a new one.
            state_group = None
            new_state_event_ids = frozenset(new_state.values())
            for sg, events in state_groups_ids.items():
                if new_state_event_ids == frozenset(e_id for e_id in events):
                    state_group = sg
                    break
            if state_group is None:
                # Worker instances don't have access to this method, but we want
                # to set the state_group on the main instance to increase cache
                # hits.
                if hasattr(self.store, "get_next_state_group"):
                    state_group = self.store.get_next_state_group()

            # Pick the existing group whose state is a subset of the new
            # state and requires the smallest delta to reach it.
            prev_group = None
            delta_ids = None
            for old_group, old_ids in state_groups_ids.items():
                if not set(new_state.iterkeys()) - set(old_ids.iterkeys()):
                    n_delta_ids = {
                        k: v
                        for k, v in new_state.items()
                        if old_ids.get(k) != v
                    }
                    if not delta_ids or len(n_delta_ids) < len(delta_ids):
                        prev_group = old_group
                        delta_ids = n_delta_ids

            cache = _StateCacheEntry(
                state=new_state,
                state_group=state_group,
                prev_group=prev_group,
                delta_ids=delta_ids,
            )

            if self._state_cache is not None:
                self._state_cache[group_names] = cache

            defer.returnValue(cache)

    def resolve_events(self, state_sets, event):
        """Resolve conflicts between `state_sets` in the context of `event`.

        Delegates to _resolve_events, threading through the event's
        (type, state_key) when the event is itself a state event.
        """
        logger.info(
            "Resolving state for %s with %d groups", event.room_id, len(state_sets)
        )
        if event.is_state():
            return self._resolve_events(
                state_sets, event.type, event.state_key
            )
        else:
            return self._resolve_events(state_sets)

    def _resolve_events(self, state_sets, event_type=None, state_key=""):
        """
        Returns
            (dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple
            (new_state, prev_states). new_state is a map from (type, state_key)
            to event. prev_states is a list of event_ids.
        """
        with Measure(self.clock, "state._resolve_events"):
            # Map each (type, state_key) to {event_id: event} across all sets.
            state = {}
            for st in state_sets:
                for e in st:
                    state.setdefault(
                        (e.type, e.state_key),
                        {}
                    )[e.event_id] = e

            unconflicted_state = {
                k: v.values()[0] for k, v in state.items()
                if len(v.values()) == 1
            }

            conflicted_state = {
                k: v.values()
                for k, v in state.items()
                if len(v.values()) > 1
            }

            if event_type:
                # prev_states are the conflicting candidates at the position
                # the caller's event occupies.
                prev_states_events = conflicted_state.get(
                    (event_type, state_key), []
                )
                prev_states = [s.event_id for s in prev_states_events]
            else:
                prev_states = []

            # Auth events (power levels, join rules, ...) drawn only from the
            # unconflicted state.
            auth_events = {
                k: e for k, e in unconflicted_state.items()
                if k[0] in AuthEventTypes
            }

            try:
                resolved_state = self._resolve_state_events(
                    conflicted_state, auth_events
                )
            except:
                # Bare except kept as-is: logs then re-raises, so nothing is
                # swallowed.
                logger.exception("Failed to resolve state")
                raise

            new_state = unconflicted_state
            new_state.update(resolved_state)

        return new_state, prev_states

    @log_function
    def _resolve_state_events(self, conflicted_state, auth_events):
        """ This is where we actually decide which of the conflicted state to
        use.

        We resolve conflicts in the following order:
            1. power levels
            2. join rules
            3. memberships
            4. other events.

        Each round's winners are folded into auth_events before the next
        round, so later decisions are made under the freshly resolved auth
        state.
        """
        resolved_state = {}
        power_key = (EventTypes.PowerLevels, "")
        if power_key in conflicted_state:
            events = conflicted_state[power_key]
            logger.debug("Resolving conflicted power levels %r", events)
            resolved_state[power_key] = self._resolve_auth_events(
                events, auth_events)

        auth_events.update(resolved_state)

        for key, events in conflicted_state.items():
            if key[0] == EventTypes.JoinRules:
                logger.debug("Resolving conflicted join rules %r", events)
                resolved_state[key] = self._resolve_auth_events(
                    events,
                    auth_events
                )

        auth_events.update(resolved_state)

        for key, events in conflicted_state.items():
            if key[0] == EventTypes.Member:
                logger.debug("Resolving conflicted member lists %r", events)
                resolved_state[key] = self._resolve_auth_events(
                    events,
                    auth_events
                )

        auth_events.update(resolved_state)

        for key, events in conflicted_state.items():
            if key not in resolved_state:
                logger.debug("Resolving conflicted state %r:%r", key, events)
                resolved_state[key] = self._resolve_normal_events(
                    events, auth_events
                )

        return resolved_state

    def _resolve_auth_events(self, events, auth_events):
        """Pick the winning auth-type event from `events`.

        Walks the candidates from lowest to highest precedence (reverse of
        _ordered_events), accepting each one that passes the auth check
        against an auth map extended with the previous winner; returns the
        last accepted event (or the one before the first auth failure).
        """
        reverse = [i for i in reversed(self._ordered_events(events))]

        # Copy so callers' auth_events is not mutated by our additions.
        auth_events = dict(auth_events)

        prev_event = reverse[0]
        for event in reverse[1:]:
            auth_events[(prev_event.type, prev_event.state_key)] = prev_event
            try:
                # FIXME: hs.get_auth() is bad style, but we need to do it to
                # get around circular deps.
                # The signatures have already been checked at this point
                self.hs.get_auth().check(event, auth_events, do_sig_check=False)
                prev_event = event
            except AuthError:
                return prev_event

        return event

    def _resolve_normal_events(self, events, auth_events):
        """Pick the first candidate (in _ordered_events order) that passes
        the auth check; fall back to the last candidate if none pass.
        """
        for event in self._ordered_events(events):
            try:
                # FIXME: hs.get_auth() is bad style, but we need to do it to
                # get around circular deps.
                # The signatures have already been checked at this point
                self.hs.get_auth().check(event, auth_events, do_sig_check=False)
                return event
            except AuthError:
                pass

        # Use the last event (the one with the least depth) if they all fail
        # the auth check.
        return event

    def _ordered_events(self, events):
        """Sort events by descending depth, tie-broken by the SHA-1 of the
        event_id, giving a deterministic resolution order.
        """
        def key_func(e):
            return -int(e.depth), hashlib.sha1(e.event_id).hexdigest()

        return sorted(events, key=key_func)