def __init__(self, hs):
    """Populate the instance from ``hs``.

    Collects the auth handler, clock and datastore from ``hs``, builds a
    federation HTTP client and the media file-path helper, copies the
    media-related config values, and schedules
    ``_update_recently_accessed_remotes`` as a looping call on the clock.
    """
    # Collaborators obtained from hs.
    self.auth = hs.get_auth()
    self.client = MatrixFederationHttpClient(hs)
    self.clock = hs.get_clock()
    self.server_name = hs.hostname
    self.store = hs.get_datastore()

    # Configuration values.
    config = hs.config
    self.max_upload_size = config.max_upload_size
    self.max_image_pixels = config.max_image_pixels

    # On-disk locations: primary store, path helper, and backup store.
    self.primary_base_path = config.media_store_path
    self.filepaths = MediaFilePaths(self.primary_base_path)
    self.backup_base_path = config.backup_media_store_path
    self.synchronous_backup_media_store = (
        config.synchronous_backup_media_store
    )

    self.dynamic_thumbnails = config.dynamic_thumbnails
    self.thumbnail_requirements = config.thumbnail_requirements

    self.remote_media_linearizer = Linearizer(name="media_remote")

    # Filled in elsewhere; handed to the looping call scheduled below.
    self.recently_accessed_remotes = set()
    self.clock.looping_call(
        self._update_recently_accessed_remotes,
        UPDATE_RECENTLY_ACCESSED_REMOTES_TS,
    )
def __init__(self, hs, filepaths):
    """Initialise the resource from ``hs`` and a shared ``filepaths`` helper.

    Runs the ``Resource`` base initialiser first, then stores the auth
    handler, federation HTTP client, clock, datastore, size limits and
    version string taken from ``hs``.

    Args:
        hs: object providing ``get_auth``, ``get_clock``, ``get_datastore``,
            ``hostname``, ``config`` and ``version_string``.
        filepaths: path helper stored as ``self.filepaths``.
    """
    Resource.__init__(self)

    self.auth = hs.get_auth()
    self.client = MatrixFederationHttpClient(hs)
    self.clock = hs.get_clock()
    self.server_name = hs.hostname
    self.store = hs.get_datastore()

    config = hs.config
    self.max_upload_size = config.max_upload_size
    self.max_image_pixels = config.max_image_pixels

    self.filepaths = filepaths
    self.version_string = hs.version_string

    # Starts empty; populated elsewhere.
    self.downloads = {}
def __init__(self, hs, filepaths):
    """Initialise from ``hs`` and a shared ``filepaths`` helper.

    Stores the auth handler, federation HTTP client, clock, datastore,
    size limits and thumbnail configuration taken from ``hs``.

    Args:
        hs: object providing ``get_auth``, ``get_clock``, ``get_datastore``,
            ``hostname`` and ``config``.
        filepaths: path helper stored as ``self.filepaths``.
    """
    self.auth = hs.get_auth()
    self.client = MatrixFederationHttpClient(hs)
    self.clock = hs.get_clock()
    self.server_name = hs.hostname
    self.store = hs.get_datastore()

    config = hs.config
    self.max_upload_size = config.max_upload_size
    self.max_image_pixels = config.max_image_pixels

    self.filepaths = filepaths

    # Starts empty; populated elsewhere.
    self.downloads = {}

    self.dynamic_thumbnails = config.dynamic_thumbnails
    self.thumbnail_requirements = config.thumbnail_requirements
def __init__(self, hs):
    """Populate the instance from ``hs`` and set up media storage.

    Besides copying collaborators and config values off ``hs``, this wraps
    every configured storage provider in a ``StorageProviderWrapper``,
    builds the ``MediaStorage`` that uses them, and schedules
    ``_update_recently_accessed`` as a looping call on the clock.
    """
    self.hs = hs
    self.auth = hs.get_auth()
    self.client = MatrixFederationHttpClient(hs)
    self.clock = hs.get_clock()
    self.server_name = hs.hostname
    self.store = hs.get_datastore()

    config = hs.config
    self.max_upload_size = config.max_upload_size
    self.max_image_pixels = config.max_image_pixels

    self.primary_base_path = config.media_store_path
    self.filepaths = MediaFilePaths(self.primary_base_path)

    self.dynamic_thumbnails = config.dynamic_thumbnails
    self.thumbnail_requirements = config.thumbnail_requirements

    self.remote_media_linearizer = Linearizer(name="media_remote")

    self.recently_accessed_remotes = set()
    self.recently_accessed_locals = set()

    self.federation_domain_whitelist = config.federation_domain_whitelist

    # List of StorageProviders where we should search for media and
    # potentially upload to. Each configured backend class is instantiated
    # and wrapped with its store_local / store_remote / store_synchronous
    # settings.
    storage_providers = [
        StorageProviderWrapper(
            clz(hs, provider_config),
            store_local=wrapper_config.store_local,
            store_remote=wrapper_config.store_remote,
            store_synchronous=wrapper_config.store_synchronous,
        )
        for clz, provider_config, wrapper_config
        in config.media_storage_providers
    ]

    self.media_storage = MediaStorage(
        self.hs,
        self.primary_base_path,
        self.filepaths,
        storage_providers,
    )

    self.clock.looping_call(
        self._update_recently_accessed,
        UPDATE_RECENTLY_ACCESSED_TS,
    )
def get_federation_http_client(self) -> MatrixFederationHttpClient:
    """Build an HTTP client for federation.

    A ``FederationPolicyForHTTPS`` is created from ``self.config`` and
    handed to the client together with this object.

    Returns:
        A newly constructed ``MatrixFederationHttpClient``.
    """
    return MatrixFederationHttpClient(
        self,
        context_factory.FederationPolicyForHTTPS(self.config),
    )
def __init__(self, hs):
    """Populate the instance from ``hs`` and set up media storage.

    Copies collaborators and config values off ``hs``, wraps every
    configured storage provider in a ``StorageProviderWrapper``, builds the
    ``MediaStorage`` that uses them, and schedules
    ``_start_update_recently_accessed`` as a looping call on the clock.
    """
    self.hs = hs
    self.auth = hs.get_auth()
    self.client = MatrixFederationHttpClient(hs)
    self.clock = hs.get_clock()
    self.server_name = hs.hostname
    self.store = hs.get_datastore()

    config = hs.config
    self.max_upload_size = config.max_upload_size
    self.max_image_pixels = config.max_image_pixels

    self.primary_base_path = config.media_store_path
    self.filepaths = MediaFilePaths(self.primary_base_path)

    self.dynamic_thumbnails = config.dynamic_thumbnails
    self.thumbnail_requirements = config.thumbnail_requirements

    self.remote_media_linearizer = Linearizer(name="media_remote")

    self.recently_accessed_remotes = set()
    self.recently_accessed_locals = set()

    self.federation_domain_whitelist = config.federation_domain_whitelist

    # List of StorageProviders where we should search for media and
    # potentially upload to. Each configured backend class is instantiated
    # and wrapped with its store_local / store_remote / store_synchronous
    # settings.
    storage_providers = [
        StorageProviderWrapper(
            clz(hs, provider_config),
            store_local=wrapper_config.store_local,
            store_remote=wrapper_config.store_remote,
            store_synchronous=wrapper_config.store_synchronous,
        )
        for clz, provider_config, wrapper_config
        in config.media_storage_providers
    ]

    self.media_storage = MediaStorage(
        self.hs,
        self.primary_base_path,
        self.filepaths,
        storage_providers,
    )

    self.clock.looping_call(
        self._start_update_recently_accessed,
        UPDATE_RECENTLY_ACCESSED_TS,
    )
def __init__(self, hs):
    """Populate the instance from ``hs``.

    Stores the auth handler, federation HTTP client, clock, datastore,
    size limits and thumbnail configuration, builds the media file-path
    helper from the configured media store path, and schedules
    ``_update_recently_accessed_remotes`` as a looping call on the clock.
    """
    self.auth = hs.get_auth()
    self.client = MatrixFederationHttpClient(hs)
    self.clock = hs.get_clock()
    self.server_name = hs.hostname
    self.store = hs.get_datastore()

    config = hs.config
    self.max_upload_size = config.max_upload_size
    self.max_image_pixels = config.max_image_pixels
    self.filepaths = MediaFilePaths(config.media_store_path)

    self.dynamic_thumbnails = config.dynamic_thumbnails
    self.thumbnail_requirements = config.thumbnail_requirements

    self.remote_media_linearizer = Linearizer()

    # Filled in elsewhere; handed to the looping call scheduled below.
    self.recently_accessed_remotes = set()
    self.clock.looping_call(
        self._update_recently_accessed_remotes,
        UPDATE_RECENTLY_ACCESSED_REMOTES_TS,
    )
class FederationClientTests(HomeserverTestCase):
    """Tests for ``MatrixFederationHttpClient`` driven through the test reactor."""

    def make_homeserver(self, reactor, clock):
        """Create the test homeserver with no TLS client options factory."""
        homeserver = self.setup_test_homeserver(reactor=reactor, clock=clock)
        homeserver.tls_client_options_factory = None
        return homeserver

    def prepare(self, reactor, clock, homeserver):
        """Build the client under test and register a lookup for ``testserv``."""
        self.cl = MatrixFederationHttpClient(self.hs)
        self.reactor.lookups["testserv"] = "1.2.3.4"

    def test_dns_error(self):
        """
        If the DNS lookup fails, the error bubbles up to the caller.
        """
        # "testserv2" has no entry in the reactor's lookups.
        deferred = self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000)
        self.pump()

        failure = self.failureResultOf(deferred)
        self.assertIsInstance(failure.value, DNSLookupError)

    def test_client_never_connect(self):
        """
        If the HTTP request never connects and times out, it fails with a
        ConnectingCancelledError or TimeoutError.
        """
        deferred = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
        self.pump()

        # Nothing has happened yet
        self.assertFalse(deferred.called)

        # Exactly one connection attempt is pending, to the looked-up address
        pending = self.reactor.tcpClients
        self.assertEqual(len(pending), 1)
        attempt = pending[0]
        self.assertEqual(attempt[0], '1.2.3.4')
        self.assertEqual(attempt[1], 8008)

        # The deferred still has no result
        self.assertFalse(deferred.called)

        # Advance far enough for the request to time out
        self.reactor.advance(10.5)
        failure = self.failureResultOf(deferred)

        self.assertIsInstance(
            failure.value, (ConnectingCancelledError, TimeoutError)
        )

    def test_client_connect_no_response(self):
        """
        If the HTTP request connects but gets no response before timing
        out, it fails with a ResponseNeverReceived.
        """
        deferred = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
        self.pump()

        # Nothing has happened yet
        self.assertFalse(deferred.called)

        # Exactly one connection attempt is pending, to the looked-up address
        pending = self.reactor.tcpClients
        self.assertEqual(len(pending), 1)
        attempt = pending[0]
        self.assertEqual(attempt[0], '1.2.3.4')
        self.assertEqual(attempt[1], 8008)

        # Complete the connection against a mock transport
        transport = Mock()
        protocol = attempt[2].buildProtocol(None)
        protocol.makeConnection(transport)

        # The deferred still has no result
        self.assertFalse(deferred.called)

        # Advance far enough for the request to time out
        self.reactor.advance(10.5)
        failure = self.failureResultOf(deferred)

        self.assertIsInstance(failure.value, ResponseNeverReceived)

    def test_client_gets_headers(self):
        """
        Once the client receives the response headers, _send_request
        returns successfully.
        """
        request = MatrixFederationRequest(
            method="GET",
            destination="testserv:8008",
            path="foo/bar",
        )
        deferred = self.cl._send_request(request, timeout=10000)
        self.pump()

        transport = Mock()
        pending = self.reactor.tcpClients
        protocol = pending[0][2].buildProtocol(None)
        protocol.makeConnection(transport)

        # The deferred does not have a result yet
        self.assertFalse(deferred.called)

        # Feed the client the HTTP response headers
        protocol.dataReceived(b"HTTP/1.1 200 OK\r\nServer: Fake\r\n\r\n")

        # The request should now have succeeded
        response = self.successResultOf(deferred)
        self.assertEqual(response.code, 200)

    def test_client_headers_no_body(self):
        """
        If the response headers arrive but no body follows before the
        timeout, the request fails with a TimeoutError.
        """
        deferred = self.cl.post_json("testserv:8008", "foo/bar", timeout=10000)
        self.pump()

        transport = Mock()
        pending = self.reactor.tcpClients
        protocol = pending[0][2].buildProtocol(None)
        protocol.makeConnection(transport)

        # The deferred does not have a result yet
        self.assertFalse(deferred.called)

        # Feed the client the response headers only
        protocol.dataReceived(
            (b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n"
             b"Server: Fake\r\n\r\n")
        )

        # Advance far enough for the request to time out
        self.reactor.advance(10.5)
        failure = self.failureResultOf(deferred)

        self.assertIsInstance(failure.value, TimeoutError)

    def test_client_sends_body(self):
        """
        The JSON passed as ``data`` to post_json reaches the server as the
        request body.
        """
        self.cl.post_json(
            "testserv:8008", "foo/bar", timeout=10000,
            data={"a": "b"}
        )
        self.pump()

        pending = self.reactor.tcpClients
        self.assertEqual(len(pending), 1)
        protocol = pending[0][2].buildProtocol(None)

        # Wire the client protocol to an in-memory HTTP server channel
        server = HTTPChannel()
        protocol.makeConnection(FakeTransport(server, self.reactor))
        server.makeConnection(FakeTransport(protocol, self.reactor))

        self.pump(0.1)

        self.assertEqual(len(server.requests), 1)
        received = server.requests[0]
        body = received.content.read()
        self.assertEqual(body, b'{"a":"b"}')
class MediaRepository(object):
    """Stores local uploads and cached remote media on disk, and generates
    and records thumbnails for both.
    """

    def __init__(self, hs):
        """Populate the instance from ``hs`` and schedule the periodic
        ``_update_recently_accessed_remotes`` call on the clock.
        """
        self.auth = hs.get_auth()
        self.client = MatrixFederationHttpClient(hs)
        self.clock = hs.get_clock()
        self.server_name = hs.hostname
        self.store = hs.get_datastore()
        self.max_upload_size = hs.config.max_upload_size
        self.max_image_pixels = hs.config.max_image_pixels
        self.filepaths = MediaFilePaths(hs.config.media_store_path)
        self.dynamic_thumbnails = hs.config.dynamic_thumbnails
        self.thumbnail_requirements = hs.config.thumbnail_requirements

        self.remote_media_linearizer = Linearizer()

        # (server_name, media_id) pairs accessed since the last flush.
        self.recently_accessed_remotes = set()

        self.clock.looping_call(
            self._update_recently_accessed_remotes,
            UPDATE_RECENTLY_ACCESSED_REMOTES_TS
        )

    @defer.inlineCallbacks
    def _update_recently_accessed_remotes(self):
        """Flush the recently-accessed set to the store.

        The set is swapped for a fresh one before yielding, so accesses
        recorded while the store call is in flight go into the next batch.
        """
        media = self.recently_accessed_remotes
        self.recently_accessed_remotes = set()

        yield self.store.update_cached_last_access_time(
            media, self.clock.time_msec()
        )

    @staticmethod
    def _makedirs(filepath):
        """Ensure the parent directory of ``filepath`` exists.

        This is also called from worker threads during thumbnail generation,
        so another caller may create the directory between any existence
        check and the ``makedirs`` call. Rather than checking first, attempt
        the creation and treat "already exists" as success.

        Raises:
            OSError: for any failure other than the path already existing.
        """
        dirname = os.path.dirname(filepath)
        try:
            os.makedirs(dirname)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise

    @defer.inlineCallbacks
    def create_content(self, media_type, upload_name, content, content_length,
                       auth_user):
        """Store uploaded content as local media and return its mxc URL.

        Args:
            media_type: recorded as the media type of the upload.
            upload_name: recorded as the upload's name.
            content: data written verbatim to the media file.
            content_length: recorded as the media length.
            auth_user: recorded as the uploading user's id.

        Returns:
            Deferred[str]: ``mxc://<server_name>/<media_id>``.
        """
        media_id = random_string(24)

        fname = self.filepaths.local_media_filepath(media_id)
        self._makedirs(fname)

        # This shouldn't block for very long because the content will have
        # already been uploaded at this point.
        with open(fname, "wb") as f:
            f.write(content)

        yield self.store.store_local_media(
            media_id=media_id,
            media_type=media_type,
            time_now_ms=self.clock.time_msec(),
            upload_name=upload_name,
            media_length=content_length,
            user_id=auth_user,
        )
        media_info = {
            "media_type": media_type,
            "media_length": content_length,
        }

        yield self._generate_local_thumbnails(media_id, media_info)

        defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))

    @defer.inlineCallbacks
    def get_remote_media(self, server_name, media_id):
        """Return the media info for a remote file, fetching it if needed.

        Requests for the same ``(server_name, media_id)`` are queued on the
        linearizer so they are processed one at a time.

        Returns:
            Deferred[dict]: the media info.
        """
        key = (server_name, media_id)
        with (yield self.remote_media_linearizer.queue(key)):
            media_info = yield self._get_remote_media_impl(server_name, media_id)
        defer.returnValue(media_info)

    @defer.inlineCallbacks
    def _get_remote_media_impl(self, server_name, media_id):
        """Look the media up in the store; download it when there is no entry.

        When an entry exists, the access is recorded both in the in-memory
        set and directly in the store.
        """
        media_info = yield self.store.get_cached_remote_media(
            server_name, media_id
        )
        if not media_info:
            media_info = yield self._download_remote_file(
                server_name, media_id
            )
        else:
            self.recently_accessed_remotes.add((server_name, media_id))
            yield self.store.update_cached_last_access_time(
                [(server_name, media_id)], self.clock.time_msec()
            )
        defer.returnValue(media_info)

    @defer.inlineCallbacks
    def _download_remote_file(self, server_name, media_id):
        """Download a remote file, record it in the store and thumbnail it.

        The file is written under a freshly generated local file id. If
        anything fails before the store entry has been written, the partially
        written file is removed and the error re-raised.

        Returns:
            Deferred[dict]: media info with ``media_type``, ``media_length``,
            ``upload_name``, ``created_ts`` and ``filesystem_id`` keys.

        Raises:
            SynapseError: (502) when fetching the file fails.
        """
        file_id = random_string(24)

        fname = self.filepaths.remote_media_filepath(
            server_name, file_id
        )
        self._makedirs(fname)

        try:
            with open(fname, "wb") as f:
                request_path = "/".join((
                    "/_matrix/media/v1/download", server_name, media_id,
                ))
                try:
                    length, headers = yield self.client.get_file(
                        server_name, request_path, output_stream=f,
                        max_size=self.max_upload_size,
                    )
                except Exception as e:
                    logger.warn("Failed to fetch remoted media %r", e)
                    raise SynapseError(502, "Failed to fetch remoted media")

            media_type = headers["Content-Type"][0]
            time_now_ms = self.clock.time_msec()

            content_disposition = headers.get("Content-Disposition", None)
            if content_disposition:
                _, params = cgi.parse_header(content_disposition[0],)
                upload_name = None

                # First check if there is a valid UTF-8 filename
                upload_name_utf8 = params.get("filename*", None)
                if upload_name_utf8:
                    if upload_name_utf8.lower().startswith("utf-8''"):
                        # Strip the 7-character "utf-8''" prefix.
                        upload_name = upload_name_utf8[7:]

                # If there isn't check for an ascii name.
                if not upload_name:
                    upload_name_ascii = params.get("filename", None)
                    if upload_name_ascii and is_ascii(upload_name_ascii):
                        upload_name = upload_name_ascii

                if upload_name:
                    upload_name = urlparse.unquote(upload_name)
                    try:
                        upload_name = upload_name.decode("utf-8")
                    except UnicodeDecodeError:
                        upload_name = None
            else:
                upload_name = None

            # Use the same timestamp for the store entry and the returned
            # media_info so that the two always agree on "created_ts".
            yield self.store.store_cached_remote_media(
                origin=server_name,
                media_id=media_id,
                media_type=media_type,
                time_now_ms=time_now_ms,
                upload_name=upload_name,
                media_length=length,
                filesystem_id=file_id,
            )
        except:  # noqa: E722
            # Deliberately bare: whatever went wrong, remove the partial
            # file and then re-raise the original error unchanged.
            os.remove(fname)
            raise

        media_info = {
            "media_type": media_type,
            "media_length": length,
            "upload_name": upload_name,
            "created_ts": time_now_ms,
            "filesystem_id": file_id,
        }

        yield self._generate_remote_thumbnails(
            server_name, media_id, media_info
        )

        defer.returnValue(media_info)

    def _get_thumbnail_requirements(self, media_type):
        """Return the configured thumbnail requirements for ``media_type``,
        or an empty tuple when none are configured.
        """
        return self.thumbnail_requirements.get(media_type, ())

    def _generate_thumbnail(self, input_path, t_path, t_width, t_height,
                            t_method, t_type):
        """Write a single thumbnail of ``input_path`` to ``t_path``.

        Returns:
            The value returned by the thumbnailer's ``crop``/``scale`` call,
            or None when the image has ``max_image_pixels`` or more pixels or
            ``t_method`` is neither "crop" nor "scale".
        """
        thumbnailer = Thumbnailer(input_path)
        m_width = thumbnailer.width
        m_height = thumbnailer.height

        if m_width * m_height >= self.max_image_pixels:
            logger.info(
                "Image too large to thumbnail %r x %r > %r",
                m_width, m_height, self.max_image_pixels
            )
            return

        if t_method == "crop":
            t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
        elif t_method == "scale":
            t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
        else:
            t_len = None

        return t_len

    @defer.inlineCallbacks
    def generate_local_exact_thumbnail(self, media_id, t_width, t_height,
                                       t_method, t_type):
        """Generate one thumbnail of a local file at the exact requested size.

        The thumbnail is produced on a worker thread and recorded in the
        store only when generation returned a truthy length.

        Returns:
            Deferred[str]: the thumbnail path (returned whether or not a
            thumbnail was actually produced).
        """
        input_path = self.filepaths.local_media_filepath(media_id)

        t_path = self.filepaths.local_media_thumbnail(
            media_id, t_width, t_height, t_type, t_method
        )
        self._makedirs(t_path)

        t_len = yield preserve_context_over_fn(
            threads.deferToThread,
            self._generate_thumbnail,
            input_path, t_path, t_width, t_height, t_method, t_type
        )

        if t_len:
            yield self.store.store_local_thumbnail(
                media_id, t_width, t_height, t_type, t_method, t_len
            )

            defer.returnValue(t_path)

    @defer.inlineCallbacks
    def generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
                                        t_width, t_height, t_method, t_type):
        """Generate one thumbnail of a cached remote file at the exact
        requested size.

        The thumbnail is produced on a worker thread and recorded in the
        store only when generation returned a truthy length.

        Returns:
            Deferred[str]: the thumbnail path (returned whether or not a
            thumbnail was actually produced).
        """
        input_path = self.filepaths.remote_media_filepath(server_name, file_id)

        t_path = self.filepaths.remote_media_thumbnail(
            server_name, file_id, t_width, t_height, t_type, t_method
        )
        self._makedirs(t_path)

        t_len = yield preserve_context_over_fn(
            threads.deferToThread,
            self._generate_thumbnail,
            input_path, t_path, t_width, t_height, t_method, t_type
        )

        if t_len:
            yield self.store.store_remote_media_thumbnail(
                server_name, media_id, file_id,
                t_width, t_height, t_type, t_method, t_len
            )

            defer.returnValue(t_path)

    @defer.inlineCallbacks
    def _generate_local_thumbnails(self, media_id, media_info):
        """Generate and record every required thumbnail for a local file.

        Returns:
            Deferred[dict|None]: dict with the original image's "width" and
            "height", or None when the media type has no thumbnail
            requirements or the image has ``max_image_pixels`` or more pixels.
        """
        media_type = media_info["media_type"]
        requirements = self._get_thumbnail_requirements(media_type)
        if not requirements:
            return

        input_path = self.filepaths.local_media_filepath(media_id)
        thumbnailer = Thumbnailer(input_path)
        m_width = thumbnailer.width
        m_height = thumbnailer.height

        if m_width * m_height >= self.max_image_pixels:
            logger.info(
                "Image too large to thumbnail %r x %r > %r",
                m_width, m_height, self.max_image_pixels
            )
            return

        # Filled in by generate_thumbnails on the worker thread; the rows are
        # written to the store afterwards from this generator.
        local_thumbnails = []

        def generate_thumbnails():
            scales = set()
            crops = set()
            for r_width, r_height, r_method, r_type in requirements:
                if r_method == "scale":
                    t_width, t_height = thumbnailer.aspect(r_width, r_height)
                    # Never produce a dimension larger than the original.
                    scales.add((
                        min(m_width, t_width), min(m_height, t_height), r_type,
                    ))
                elif r_method == "crop":
                    crops.add((r_width, r_height, r_type))

            for t_width, t_height, t_type in scales:
                t_method = "scale"
                t_path = self.filepaths.local_media_thumbnail(
                    media_id, t_width, t_height, t_type, t_method
                )
                self._makedirs(t_path)
                t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)

                local_thumbnails.append((
                    media_id, t_width, t_height, t_type, t_method, t_len
                ))

            for t_width, t_height, t_type in crops:
                if (t_width, t_height, t_type) in scales:
                    # If the aspect ratio of the cropped thumbnail matches a purely
                    # scaled one then there is no point in calculating a separate
                    # thumbnail.
                    continue
                t_method = "crop"
                t_path = self.filepaths.local_media_thumbnail(
                    media_id, t_width, t_height, t_type, t_method
                )
                self._makedirs(t_path)
                t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
                local_thumbnails.append((
                    media_id, t_width, t_height, t_type, t_method, t_len
                ))

        yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)

        for thumbnail_row in local_thumbnails:
            yield self.store.store_local_thumbnail(*thumbnail_row)

        defer.returnValue({
            "width": m_width,
            "height": m_height,
        })

    @defer.inlineCallbacks
    def _generate_remote_thumbnails(self, server_name, media_id, media_info):
        """Generate and record every required thumbnail for a cached remote
        file.

        Returns:
            Deferred[dict|None]: dict with the original image's "width" and
            "height", or None when the media type has no thumbnail
            requirements. The dict is returned even when the image was too
            large and no thumbnails were produced.
        """
        media_type = media_info["media_type"]
        file_id = media_info["filesystem_id"]
        requirements = self._get_thumbnail_requirements(media_type)
        if not requirements:
            return

        # Filled in by generate_thumbnails on the worker thread; the rows are
        # written to the store afterwards from this generator.
        remote_thumbnails = []

        input_path = self.filepaths.remote_media_filepath(server_name, file_id)
        thumbnailer = Thumbnailer(input_path)
        m_width = thumbnailer.width
        m_height = thumbnailer.height

        def generate_thumbnails():
            if m_width * m_height >= self.max_image_pixels:
                logger.info(
                    "Image too large to thumbnail %r x %r > %r",
                    m_width, m_height, self.max_image_pixels
                )
                return

            scales = set()
            crops = set()
            for r_width, r_height, r_method, r_type in requirements:
                if r_method == "scale":
                    t_width, t_height = thumbnailer.aspect(r_width, r_height)
                    # Never produce a dimension larger than the original.
                    scales.add((
                        min(m_width, t_width), min(m_height, t_height), r_type,
                    ))
                elif r_method == "crop":
                    crops.add((r_width, r_height, r_type))

            for t_width, t_height, t_type in scales:
                t_method = "scale"
                t_path = self.filepaths.remote_media_thumbnail(
                    server_name, file_id, t_width, t_height, t_type, t_method
                )
                self._makedirs(t_path)
                t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
                remote_thumbnails.append([
                    server_name, media_id, file_id,
                    t_width, t_height, t_type, t_method, t_len
                ])

            for t_width, t_height, t_type in crops:
                if (t_width, t_height, t_type) in scales:
                    # If the aspect ratio of the cropped thumbnail matches a purely
                    # scaled one then there is no point in calculating a separate
                    # thumbnail.
                    continue
                t_method = "crop"
                t_path = self.filepaths.remote_media_thumbnail(
                    server_name, file_id, t_width, t_height, t_type, t_method
                )
                self._makedirs(t_path)
                t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
                remote_thumbnails.append([
                    server_name, media_id, file_id,
                    t_width, t_height, t_type, t_method, t_len
                ])

        yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)

        for thumbnail_row in remote_thumbnails:
            yield self.store.store_remote_media_thumbnail(*thumbnail_row)

        defer.returnValue({
            "width": m_width,
            "height": m_height,
        })

    @defer.inlineCallbacks
    def delete_old_remote_media(self, before_ts):
        """Delete the cached remote media the store returns for ``before_ts``.

        For each entry the file and its thumbnail directory are removed and
        the store row deleted, while holding the linearizer for that media.
        An entry whose file is already missing is still cleaned up; any other
        removal error leaves the entry in place and moves on to the next one.

        Returns:
            Deferred[dict]: ``{"deleted": <number of entries removed>}``.
        """
        old_media = yield self.store.get_remote_media_before(before_ts)

        deleted = 0

        for media in old_media:
            origin = media["media_origin"]
            media_id = media["media_id"]
            file_id = media["filesystem_id"]
            key = (origin, media_id)

            logger.info("Deleting: %r", key)

            with (yield self.remote_media_linearizer.queue(key)):
                full_path = self.filepaths.remote_media_filepath(origin, file_id)
                try:
                    os.remove(full_path)
                except OSError as e:
                    logger.warn("Failed to remove file: %r", full_path)
                    # A missing file is fine; anything else means we keep the
                    # store entry and skip this media.
                    if e.errno != errno.ENOENT:
                        continue

                thumbnail_dir = self.filepaths.remote_media_thumbnail_dir(
                    origin, file_id
                )
                shutil.rmtree(thumbnail_dir, ignore_errors=True)

                yield self.store.delete_remote_media(origin, media_id)
                deleted += 1

        defer.returnValue({"deleted": deleted})
class MediaRepository(object): def __init__(self, hs): self.auth = hs.get_auth() self.client = MatrixFederationHttpClient(hs) self.clock = hs.get_clock() self.server_name = hs.hostname self.store = hs.get_datastore() self.max_upload_size = hs.config.max_upload_size self.max_image_pixels = hs.config.max_image_pixels self.primary_base_path = hs.config.media_store_path self.filepaths = MediaFilePaths(self.primary_base_path) self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.thumbnail_requirements = hs.config.thumbnail_requirements self.remote_media_linearizer = Linearizer(name="media_remote") self.recently_accessed_remotes = set() self.recently_accessed_locals = set() self.federation_domain_whitelist = hs.config.federation_domain_whitelist # List of StorageProviders where we should search for media and # potentially upload to. storage_providers = [] for clz, provider_config, wrapper_config in hs.config.media_storage_providers: backend = clz(hs, provider_config) provider = StorageProviderWrapper( backend, store_local=wrapper_config.store_local, store_remote=wrapper_config.store_remote, store_synchronous=wrapper_config.store_synchronous, ) storage_providers.append(provider) self.media_storage = MediaStorage( self.primary_base_path, self.filepaths, storage_providers, ) self.clock.looping_call( self._update_recently_accessed, UPDATE_RECENTLY_ACCESSED_TS, ) @defer.inlineCallbacks def _update_recently_accessed(self): remote_media = self.recently_accessed_remotes self.recently_accessed_remotes = set() local_media = self.recently_accessed_locals self.recently_accessed_locals = set() yield self.store.update_cached_last_access_time( local_media, remote_media, self.clock.time_msec()) def mark_recently_accessed(self, server_name, media_id): """Mark the given media as recently accessed. 
Args: server_name (str|None): Origin server of media, or None if local media_id (str): The media ID of the content """ if server_name: self.recently_accessed_remotes.add((server_name, media_id)) else: self.recently_accessed_locals.add(media_id) @defer.inlineCallbacks def create_content(self, media_type, upload_name, content, content_length, auth_user): """Store uploaded content for a local user and return the mxc URL Args: media_type(str): The content type of the file upload_name(str): The name of the file content: A file like object that is the content to store content_length(int): The length of the content auth_user(str): The user_id of the uploader Returns: Deferred[str]: The mxc url of the stored content """ media_id = random_string(24) file_info = FileInfo( server_name=None, file_id=media_id, ) fname = yield self.media_storage.store_file(content, file_info) logger.info("Stored local media in file %r", fname) yield self.store.store_local_media( media_id=media_id, media_type=media_type, time_now_ms=self.clock.time_msec(), upload_name=upload_name, media_length=content_length, user_id=auth_user, ) yield self._generate_thumbnails( None, media_id, media_id, media_type, ) defer.returnValue("mxc://%s/%s" % (self.server_name, media_id)) @defer.inlineCallbacks def get_local_media(self, request, media_id, name): """Responds to reqests for local media, if exists, or returns 404. Args: request(twisted.web.http.Request) media_id (str): The media ID of the content. (This is the same as the file_id for local content.) name (str|None): Optional name that, if specified, will be used as the filename in the Content-Disposition header of the response. 
Returns: Deferred: Resolves once a response has successfully been written to request """ media_info = yield self.store.get_local_media(media_id) if not media_info or media_info["quarantined_by"]: respond_404(request) return self.mark_recently_accessed(None, media_id) media_type = media_info["media_type"] media_length = media_info["media_length"] upload_name = name if name else media_info["upload_name"] url_cache = media_info["url_cache"] file_info = FileInfo( None, media_id, url_cache=url_cache, ) responder = yield self.media_storage.fetch_media(file_info) yield respond_with_responder( request, responder, media_type, media_length, upload_name, ) @defer.inlineCallbacks def get_remote_media(self, request, server_name, media_id, name): """Respond to requests for remote media. Args: request(twisted.web.http.Request) server_name (str): Remote server_name where the media originated. media_id (str): The media ID of the content (as defined by the remote server). name (str|None): Optional name that, if specified, will be used as the filename in the Content-Disposition header of the response. 
Returns: Deferred: Resolves once a response has successfully been written to request """ if (self.federation_domain_whitelist is not None and server_name not in self.federation_domain_whitelist): raise FederationDeniedError(server_name) self.mark_recently_accessed(server_name, media_id) # We linearize here to ensure that we don't try and download remote # media multiple times concurrently key = (server_name, media_id) with (yield self.remote_media_linearizer.queue(key)): responder, media_info = yield self._get_remote_media_impl( server_name, media_id, ) # We deliberately stream the file outside the lock if responder: media_type = media_info["media_type"] media_length = media_info["media_length"] upload_name = name if name else media_info["upload_name"] yield respond_with_responder( request, responder, media_type, media_length, upload_name, ) else: respond_404(request) @defer.inlineCallbacks def get_remote_media_info(self, server_name, media_id): """Gets the media info associated with the remote file, downloading if necessary. Args: server_name (str): Remote server_name where the media originated. media_id (str): The media ID of the content (as defined by the remote server). Returns: Deferred[dict]: The media_info of the file """ if (self.federation_domain_whitelist is not None and server_name not in self.federation_domain_whitelist): raise FederationDeniedError(server_name) # We linearize here to ensure that we don't try and download remote # media multiple times concurrently key = (server_name, media_id) with (yield self.remote_media_linearizer.queue(key)): responder, media_info = yield self._get_remote_media_impl( server_name, media_id, ) # Ensure we actually use the responder so that it releases resources if responder: with responder: pass defer.returnValue(media_info) @defer.inlineCallbacks def _get_remote_media_impl(self, server_name, media_id): """Looks for media in local cache, if not there then attempt to download from remote server. 
Args: server_name (str): Remote server_name where the media originated. media_id (str): The media ID of the content (as defined by the remote server). Returns: Deferred[(Responder, media_info)] """ media_info = yield self.store.get_cached_remote_media( server_name, media_id) # file_id is the ID we use to track the file locally. If we've already # seen the file then reuse the existing ID, otherwise genereate a new # one. if media_info: file_id = media_info["filesystem_id"] else: file_id = random_string(24) file_info = FileInfo(server_name, file_id) # If we have an entry in the DB, try and look for it if media_info: if media_info["quarantined_by"]: logger.info("Media is quarantined") raise NotFoundError() responder = yield self.media_storage.fetch_media(file_info) if responder: defer.returnValue((responder, media_info)) # Failed to find the file anywhere, lets download it. media_info = yield self._download_remote_file(server_name, media_id, file_id) responder = yield self.media_storage.fetch_media(file_info) defer.returnValue((responder, media_info)) @defer.inlineCallbacks def _download_remote_file(self, server_name, media_id, file_id): """Attempt to download the remote file from the given server name, using the given file_id as the local id. Args: server_name (str): Originating server media_id (str): The media ID of the content (as defined by the remote server). This is different than the file_id, which is locally generated. file_id (str): Local file ID Returns: Deferred[MediaInfo] """ file_info = FileInfo( server_name=server_name, file_id=file_id, ) with self.media_storage.store_into_file(file_info) as (f, fname, finish): request_path = "/".join(( "/_matrix/media/v1/download", server_name, media_id, )) try: length, headers = yield self.client.get_file( server_name, request_path, output_stream=f, max_size=self.max_upload_size, args={ # tell the remote server to 404 if it doesn't # recognise the server_name, to make sure we don't # end up with a routing loop. 
"allow_remote": "false", }) except twisted.internet.error.DNSLookupError as e: logger.warn("HTTP error fetching remote media %s/%s: %r", server_name, media_id, e) raise NotFoundError() except HttpResponseException as e: logger.warn("HTTP error fetching remote media %s/%s: %s", server_name, media_id, e.response) if e.code == twisted.web.http.NOT_FOUND: raise SynapseError.from_http_response_exception(e) raise SynapseError(502, "Failed to fetch remote media") except SynapseError: logger.exception("Failed to fetch remote media %s/%s", server_name, media_id) raise except NotRetryingDestination: logger.warn("Not retrying destination %r", server_name) raise SynapseError(502, "Failed to fetch remote media") except Exception: logger.exception("Failed to fetch remote media %s/%s", server_name, media_id) raise SynapseError(502, "Failed to fetch remote media") yield finish() media_type = headers["Content-Type"][0] time_now_ms = self.clock.time_msec() content_disposition = headers.get("Content-Disposition", None) if content_disposition: _, params = cgi.parse_header(content_disposition[0], ) upload_name = None # First check if there is a valid UTF-8 filename upload_name_utf8 = params.get("filename*", None) if upload_name_utf8: if upload_name_utf8.lower().startswith("utf-8''"): upload_name = upload_name_utf8[7:] # If there isn't check for an ascii name. 
if not upload_name: upload_name_ascii = params.get("filename", None) if upload_name_ascii and is_ascii(upload_name_ascii): upload_name = upload_name_ascii if upload_name: upload_name = urlparse.unquote(upload_name) try: upload_name = upload_name.decode("utf-8") except UnicodeDecodeError: upload_name = None else: upload_name = None logger.info("Stored remote media in file %r", fname) yield self.store.store_cached_remote_media( origin=server_name, media_id=media_id, media_type=media_type, time_now_ms=self.clock.time_msec(), upload_name=upload_name, media_length=length, filesystem_id=file_id, ) media_info = { "media_type": media_type, "media_length": length, "upload_name": upload_name, "created_ts": time_now_ms, "filesystem_id": file_id, } yield self._generate_thumbnails( server_name, media_id, file_id, media_type, ) defer.returnValue(media_info) def _get_thumbnail_requirements(self, media_type): return self.thumbnail_requirements.get(media_type, ()) def _generate_thumbnail(self, thumbnailer, t_width, t_height, t_method, t_type): m_width = thumbnailer.width m_height = thumbnailer.height if m_width * m_height >= self.max_image_pixels: logger.info("Image too large to thumbnail %r x %r > %r", m_width, m_height, self.max_image_pixels) return if t_method == "crop": t_byte_source = thumbnailer.crop(t_width, t_height, t_type) elif t_method == "scale": t_width, t_height = thumbnailer.aspect(t_width, t_height) t_width = min(m_width, t_width) t_height = min(m_height, t_height) t_byte_source = thumbnailer.scale(t_width, t_height, t_type) else: t_byte_source = None return t_byte_source @defer.inlineCallbacks def generate_local_exact_thumbnail(self, media_id, t_width, t_height, t_method, t_type, url_cache): input_path = yield self.media_storage.ensure_media_is_in_local_cache( FileInfo( None, media_id, url_cache=url_cache, )) thumbnailer = Thumbnailer(input_path) t_byte_source = yield make_deferred_yieldable( threads.deferToThread(self._generate_thumbnail, thumbnailer, t_width, 
t_height, t_method, t_type)) if t_byte_source: try: file_info = FileInfo( server_name=None, file_id=media_id, url_cache=url_cache, thumbnail=True, thumbnail_width=t_width, thumbnail_height=t_height, thumbnail_method=t_method, thumbnail_type=t_type, ) output_path = yield self.media_storage.store_file( t_byte_source, file_info, ) finally: t_byte_source.close() logger.info("Stored thumbnail in file %r", output_path) t_len = os.path.getsize(output_path) yield self.store.store_local_thumbnail(media_id, t_width, t_height, t_type, t_method, t_len) defer.returnValue(output_path) @defer.inlineCallbacks def generate_remote_exact_thumbnail(self, server_name, file_id, media_id, t_width, t_height, t_method, t_type): input_path = yield self.media_storage.ensure_media_is_in_local_cache( FileInfo( server_name, file_id, url_cache=False, )) thumbnailer = Thumbnailer(input_path) t_byte_source = yield make_deferred_yieldable( threads.deferToThread(self._generate_thumbnail, thumbnailer, t_width, t_height, t_method, t_type)) if t_byte_source: try: file_info = FileInfo( server_name=server_name, file_id=media_id, thumbnail=True, thumbnail_width=t_width, thumbnail_height=t_height, thumbnail_method=t_method, thumbnail_type=t_type, ) output_path = yield self.media_storage.store_file( t_byte_source, file_info, ) finally: t_byte_source.close() logger.info("Stored thumbnail in file %r", output_path) t_len = os.path.getsize(output_path) yield self.store.store_remote_media_thumbnail( server_name, media_id, file_id, t_width, t_height, t_type, t_method, t_len) defer.returnValue(output_path) @defer.inlineCallbacks def _generate_thumbnails(self, server_name, media_id, file_id, media_type, url_cache=False): """Generate and store thumbnails for an image. Args: server_name (str|None): The server name if remote media, else None if local media_id (str): The media ID of the content. 
(This is the same as the file_id for local content) file_id (str): Local file ID media_type (str): The content type of the file url_cache (bool): If we are thumbnailing images downloaded for the URL cache, used exclusively by the url previewer Returns: Deferred[dict]: Dict with "width" and "height" keys of original image """ requirements = self._get_thumbnail_requirements(media_type) if not requirements: return input_path = yield self.media_storage.ensure_media_is_in_local_cache( FileInfo( server_name, file_id, url_cache=url_cache, )) thumbnailer = Thumbnailer(input_path) m_width = thumbnailer.width m_height = thumbnailer.height if m_width * m_height >= self.max_image_pixels: logger.info("Image too large to thumbnail %r x %r > %r", m_width, m_height, self.max_image_pixels) return # We deduplicate the thumbnail sizes by ignoring the cropped versions if # they have the same dimensions of a scaled one. thumbnails = {} for r_width, r_height, r_method, r_type in requirements: if r_method == "crop": thumbnails.setdefault((r_width, r_height, r_type), r_method) elif r_method == "scale": t_width, t_height = thumbnailer.aspect(r_width, r_height) t_width = min(m_width, t_width) t_height = min(m_height, t_height) thumbnails[(t_width, t_height, r_type)] = r_method # Now we generate the thumbnails for each dimension, store it for (t_width, t_height, t_type), t_method in thumbnails.iteritems(): # Generate the thumbnail if t_method == "crop": t_byte_source = yield make_deferred_yieldable( threads.deferToThread( thumbnailer.crop, t_width, t_height, t_type, )) elif t_method == "scale": t_byte_source = yield make_deferred_yieldable( threads.deferToThread( thumbnailer.scale, t_width, t_height, t_type, )) else: logger.error("Unrecognized method: %r", t_method) continue if not t_byte_source: continue try: file_info = FileInfo( server_name=server_name, file_id=file_id, thumbnail=True, thumbnail_width=t_width, thumbnail_height=t_height, thumbnail_method=t_method, thumbnail_type=t_type, 
url_cache=url_cache, ) output_path = yield self.media_storage.store_file( t_byte_source, file_info, ) finally: t_byte_source.close() t_len = os.path.getsize(output_path) # Write to database if server_name: yield self.store.store_remote_media_thumbnail( server_name, media_id, file_id, t_width, t_height, t_type, t_method, t_len) else: yield self.store.store_local_thumbnail(media_id, t_width, t_height, t_type, t_method, t_len) defer.returnValue({ "width": m_width, "height": m_height, }) @defer.inlineCallbacks def delete_old_remote_media(self, before_ts): old_media = yield self.store.get_remote_media_before(before_ts) deleted = 0 for media in old_media: origin = media["media_origin"] media_id = media["media_id"] file_id = media["filesystem_id"] key = (origin, media_id) logger.info("Deleting: %r", key) # TODO: Should we delete from the backup store with (yield self.remote_media_linearizer.queue(key)): full_path = self.filepaths.remote_media_filepath( origin, file_id) try: os.remove(full_path) except OSError as e: logger.warn("Failed to remove file: %r", full_path) if e.errno == errno.ENOENT: pass else: continue thumbnail_dir = self.filepaths.remote_media_thumbnail_dir( origin, file_id) shutil.rmtree(thumbnail_dir, ignore_errors=True) yield self.store.delete_remote_media(origin, media_id) deleted += 1 defer.returnValue({"deleted": deleted})
def build_http_client(self):
    """Build a ``MatrixFederationHttpClient``, handing it this object.

    Returns:
        MatrixFederationHttpClient: a newly constructed client.
    """
    http_client = MatrixFederationHttpClient(self)
    return http_client
class MediaRepository(object):
    """Stores local and remote media on disk (with an optional backup store),
    records it in the datastore and generates thumbnails for it.
    """

    def __init__(self, hs):
        self.auth = hs.get_auth()
        self.client = MatrixFederationHttpClient(hs)
        self.clock = hs.get_clock()
        self.server_name = hs.hostname
        self.store = hs.get_datastore()
        self.max_upload_size = hs.config.max_upload_size
        self.max_image_pixels = hs.config.max_image_pixels

        self.primary_base_path = hs.config.media_store_path
        self.filepaths = MediaFilePaths(self.primary_base_path)

        # Falsy when no backup store is configured (see copy_to_backup).
        self.backup_base_path = hs.config.backup_media_store_path

        self.synchronous_backup_media_store = hs.config.synchronous_backup_media_store

        self.dynamic_thumbnails = hs.config.dynamic_thumbnails
        self.thumbnail_requirements = hs.config.thumbnail_requirements

        self.remote_media_linearizer = Linearizer(name="media_remote")

        # Set of (server_name, media_id) pairs, flushed periodically by
        # _update_recently_accessed_remotes.
        self.recently_accessed_remotes = set()

        self.clock.looping_call(
            self._update_recently_accessed_remotes,
            UPDATE_RECENTLY_ACCESSED_REMOTES_TS,
        )

    @defer.inlineCallbacks
    def _update_recently_accessed_remotes(self):
        """Flush the set of recently accessed remote media to the datastore
        and start collecting a fresh set.
        """
        media = self.recently_accessed_remotes
        self.recently_accessed_remotes = set()

        yield self.store.update_cached_last_access_time(
            media, self.clock.time_msec()
        )

    @staticmethod
    def _makedirs(filepath):
        """Ensure the parent directory of `filepath` exists.

        Args:
            filepath (str): Path of a file whose directory should be created
        """
        dirname = os.path.dirname(filepath)
        # Attempt the creation and tolerate "already exists" rather than
        # checking first: this is called from several threads, so a check
        # followed by a create can race and fail spuriously.
        try:
            os.makedirs(dirname)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise

    @staticmethod
    def _write_file_synchronously(source, fname):
        """Write `source` to the path `fname` synchronously. Should be called
        from a thread.

        Args:
            source: A file like object to be written
            fname (str): Path to write to
        """
        MediaRepository._makedirs(fname)
        source.seek(0)  # Ensure we read from the start of the file
        with open(fname, "wb") as f:
            shutil.copyfileobj(source, f)

    @staticmethod
    def _copy_file_synchronously(src_fname, dest_fname):
        """Copy `src_fname` to `dest_fname` synchronously, creating the
        destination directory first. Should be called from a thread.

        Args:
            src_fname (str): Path of the file to copy
            dest_fname (str): Path to copy the file to
        """
        # shutil.copyfile does not create missing parent directories.
        MediaRepository._makedirs(dest_fname)
        shutil.copyfile(src_fname, dest_fname)

    @defer.inlineCallbacks
    def write_to_file_and_backup(self, source, path):
        """Write `source` to the on disk media store, and also the backup store
        if configured.

        Args:
            source: A file like object that should be written
            path (str): Relative path to write file to

        Returns:
            Deferred[str]: the file path written to in the primary media store
        """
        fname = os.path.join(self.primary_base_path, path)

        # Write to the main repository
        yield make_deferred_yieldable(threads.deferToThread(
            self._write_file_synchronously, source, fname,
        ))

        # Write to backup repository
        yield self.copy_to_backup(path)

        defer.returnValue(fname)

    @defer.inlineCallbacks
    def copy_to_backup(self, path):
        """Copy a file from the primary to backup media store, if configured.

        Args:
            path(str): Relative path to write file to
        """
        if self.backup_base_path:
            primary_fname = os.path.join(self.primary_base_path, path)
            backup_fname = os.path.join(self.backup_base_path, path)

            # We can either wait for successful writing to the backup repository
            # or write in the background and immediately return
            if self.synchronous_backup_media_store:
                yield make_deferred_yieldable(threads.deferToThread(
                    self._copy_file_synchronously, primary_fname, backup_fname,
                ))
            else:
                preserve_fn(threads.deferToThread)(
                    self._copy_file_synchronously, primary_fname, backup_fname,
                )

    @defer.inlineCallbacks
    def create_content(self, media_type, upload_name, content, content_length,
                       auth_user):
        """Store uploaded content for a local user and return the mxc URL

        Args:
            media_type(str): The content type of the file
            upload_name(str): The name of the file
            content: A file like object that is the content to store
            content_length(int): The length of the content
            auth_user(str): The user_id of the uploader

        Returns:
            Deferred[str]: The mxc url of the stored content
        """
        media_id = random_string(24)

        fname = yield self.write_to_file_and_backup(
            content, self.filepaths.local_media_filepath_rel(media_id)
        )

        logger.info("Stored local media in file %r", fname)

        yield self.store.store_local_media(
            media_id=media_id,
            media_type=media_type,
            time_now_ms=self.clock.time_msec(),
            upload_name=upload_name,
            media_length=content_length,
            user_id=auth_user,
        )
        media_info = {
            "media_type": media_type,
            "media_length": content_length,
        }

        yield self._generate_thumbnails(None, media_id, media_info)

        defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))

    @defer.inlineCallbacks
    def get_remote_media(self, server_name, media_id):
        """Get the media info for a remote file, downloading it if it is not
        cached. Requests for the same (server_name, media_id) are linearized.

        Args:
            server_name (str): Server the media originated from
            media_id (str): The media ID of the content

        Returns:
            Deferred[dict]: The media info of the file
        """
        key = (server_name, media_id)
        with (yield self.remote_media_linearizer.queue(key)):
            media_info = yield self._get_remote_media_impl(
                server_name, media_id
            )
        defer.returnValue(media_info)

    @defer.inlineCallbacks
    def _get_remote_media_impl(self, server_name, media_id):
        """Look up remote media in the datastore cache, downloading it when
        there is no cached entry.

        Raises:
            NotFoundError: if the cached entry is quarantined
        """
        media_info = yield self.store.get_cached_remote_media(
            server_name, media_id
        )
        if not media_info:
            media_info = yield self._download_remote_file(
                server_name, media_id
            )
        elif media_info["quarantined_by"]:
            raise NotFoundError()
        else:
            self.recently_accessed_remotes.add((server_name, media_id))
            yield self.store.update_cached_last_access_time(
                [(server_name, media_id)], self.clock.time_msec()
            )
        defer.returnValue(media_info)

    @defer.inlineCallbacks
    def _download_remote_file(self, server_name, media_id):
        """Download a remote file into the primary media store (and backup
        store if configured), record it in the datastore and thumbnail it.

        Args:
            server_name (str): Server to download from
            media_id (str): The media ID of the content

        Returns:
            Deferred[dict]: The media info of the downloaded file

        Raises:
            NotFoundError: on a DNS lookup failure
            SynapseError: on any other failure to fetch the media
        """
        file_id = random_string(24)

        fpath = self.filepaths.remote_media_filepath_rel(
            server_name, file_id
        )
        fname = os.path.join(self.primary_base_path, fpath)
        self._makedirs(fname)

        try:
            with open(fname, "wb") as f:
                request_path = "/".join((
                    "/_matrix/media/v1/download", server_name, media_id,
                ))
                try:
                    length, headers = yield self.client.get_file(
                        server_name, request_path, output_stream=f,
                        max_size=self.max_upload_size, args={
                            # tell the remote server to 404 if it doesn't
                            # recognise the server_name, to make sure we don't
                            # end up with a routing loop.
                            "allow_remote": "false",
                        }
                    )
                except twisted.internet.error.DNSLookupError as e:
                    logger.warn("HTTP error fetching remote media %s/%s: %r",
                                server_name, media_id, e)
                    raise NotFoundError()

                except HttpResponseException as e:
                    logger.warn("HTTP error fetching remote media %s/%s: %s",
                                server_name, media_id, e.response)
                    if e.code == twisted.web.http.NOT_FOUND:
                        raise SynapseError.from_http_response_exception(e)
                    raise SynapseError(502, "Failed to fetch remote media")

                except SynapseError:
                    logger.exception("Failed to fetch remote media %s/%s",
                                     server_name, media_id)
                    raise
                except NotRetryingDestination:
                    logger.warn("Not retrying destination %r", server_name)
                    raise SynapseError(502, "Failed to fetch remote media")
                except Exception:
                    logger.exception("Failed to fetch remote media %s/%s",
                                     server_name, media_id)
                    raise SynapseError(502, "Failed to fetch remote media")

            # The file is closed at this point, so the copy sees all the data.
            yield self.copy_to_backup(fpath)

            media_type = headers["Content-Type"][0]
            time_now_ms = self.clock.time_msec()

            content_disposition = headers.get("Content-Disposition", None)
            if content_disposition:
                _, params = cgi.parse_header(content_disposition[0],)
                upload_name = None

                # First check if there is a valid UTF-8 filename
                upload_name_utf8 = params.get("filename*", None)
                if upload_name_utf8:
                    if upload_name_utf8.lower().startswith("utf-8''"):
                        upload_name = upload_name_utf8[7:]

                # If there isn't check for an ascii name.
                if not upload_name:
                    upload_name_ascii = params.get("filename", None)
                    if upload_name_ascii and is_ascii(upload_name_ascii):
                        upload_name = upload_name_ascii

                if upload_name:
                    upload_name = urlparse.unquote(upload_name)
                    try:
                        upload_name = upload_name.decode("utf-8")
                    except UnicodeDecodeError:
                        upload_name = None
            else:
                upload_name = None

            logger.info("Stored remote media in file %r", fname)

            yield self.store.store_cached_remote_media(
                origin=server_name,
                media_id=media_id,
                media_type=media_type,
                time_now_ms=self.clock.time_msec(),
                upload_name=upload_name,
                media_length=length,
                filesystem_id=file_id,
            )
        except Exception:
            # Don't leave a partial or unrecorded file in the primary store.
            os.remove(fname)
            raise

        media_info = {
            "media_type": media_type,
            "media_length": length,
            "upload_name": upload_name,
            "created_ts": time_now_ms,
            "filesystem_id": file_id,
        }

        yield self._generate_thumbnails(server_name, media_id, media_info)

        defer.returnValue(media_info)

    def _get_thumbnail_requirements(self, media_type):
        """Return the configured thumbnail requirements for `media_type`, or
        an empty tuple if there are none.
        """
        return self.thumbnail_requirements.get(media_type, ())

    def _generate_thumbnail(self, thumbnailer, t_width, t_height, t_method,
                            t_type):
        """Generate a single thumbnail with the given thumbnailer.

        Returns:
            The thumbnail byte source, or None if the image has
            max_image_pixels or more pixels or the method is not "crop" or
            "scale".
        """
        m_width = thumbnailer.width
        m_height = thumbnailer.height

        if m_width * m_height >= self.max_image_pixels:
            logger.info(
                "Image too large to thumbnail %r x %r > %r",
                m_width, m_height, self.max_image_pixels
            )
            return

        if t_method == "crop":
            t_byte_source = thumbnailer.crop(t_width, t_height, t_type)
        elif t_method == "scale":
            t_width, t_height = thumbnailer.aspect(t_width, t_height)
            # Never scale up beyond the size of the original image
            t_width = min(m_width, t_width)
            t_height = min(m_height, t_height)
            t_byte_source = thumbnailer.scale(t_width, t_height, t_type)
        else:
            t_byte_source = None

        return t_byte_source

    @defer.inlineCallbacks
    def generate_local_exact_thumbnail(self, media_id, t_width, t_height,
                                       t_method, t_type):
        """Generate, store and record one thumbnail of a local media file.

        Returns:
            Deferred[str|None]: Path of the stored thumbnail, or None if no
            thumbnail was generated
        """
        input_path = self.filepaths.local_media_filepath(media_id)

        thumbnailer = Thumbnailer(input_path)
        t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
            self._generate_thumbnail,
            thumbnailer, t_width, t_height, t_method, t_type
        ))

        if t_byte_source:
            try:
                output_path = yield self.write_to_file_and_backup(
                    t_byte_source,
                    self.filepaths.local_media_thumbnail_rel(
                        media_id, t_width, t_height, t_type, t_method
                    )
                )
            finally:
                t_byte_source.close()

            logger.info("Stored thumbnail in file %r", output_path)

            t_len = os.path.getsize(output_path)

            yield self.store.store_local_thumbnail(
                media_id, t_width, t_height, t_type, t_method, t_len
            )

            defer.returnValue(output_path)

    @defer.inlineCallbacks
    def generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
                                        t_width, t_height, t_method, t_type):
        """Generate, store and record one thumbnail of a remote media file.

        Returns:
            Deferred[str|None]: Path of the stored thumbnail, or None if no
            thumbnail was generated
        """
        input_path = self.filepaths.remote_media_filepath(server_name, file_id)

        thumbnailer = Thumbnailer(input_path)
        t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
            self._generate_thumbnail,
            thumbnailer, t_width, t_height, t_method, t_type
        ))

        if t_byte_source:
            try:
                output_path = yield self.write_to_file_and_backup(
                    t_byte_source,
                    self.filepaths.remote_media_thumbnail_rel(
                        server_name, file_id, t_width, t_height, t_type, t_method
                    )
                )
            finally:
                t_byte_source.close()

            logger.info("Stored thumbnail in file %r", output_path)

            t_len = os.path.getsize(output_path)

            yield self.store.store_remote_media_thumbnail(
                server_name, media_id, file_id,
                t_width, t_height, t_type, t_method, t_len
            )

            defer.returnValue(output_path)

    @defer.inlineCallbacks
    def _generate_thumbnails(self, server_name, media_id, media_info,
                             url_cache=False):
        """Generate and store thumbnails for an image.

        Args:
            server_name(str|None): The server name if remote media, else None
                if local
            media_id(str)
            media_info(dict)
            url_cache(bool): If we are thumbnailing images downloaded for the
                URL cache, used exclusively by the url previewer

        Returns:
            Deferred[dict]: Dict with "width" and "height" keys of original
            image
        """
        media_type = media_info["media_type"]
        file_id = media_info.get("filesystem_id")
        requirements = self._get_thumbnail_requirements(media_type)
        if not requirements:
            return

        if server_name:
            input_path = self.filepaths.remote_media_filepath(
                server_name, file_id
            )
        elif url_cache:
            input_path = self.filepaths.url_cache_filepath(media_id)
        else:
            input_path = self.filepaths.local_media_filepath(media_id)

        thumbnailer = Thumbnailer(input_path)
        m_width = thumbnailer.width
        m_height = thumbnailer.height

        if m_width * m_height >= self.max_image_pixels:
            logger.info(
                "Image too large to thumbnail %r x %r > %r",
                m_width, m_height, self.max_image_pixels
            )
            return

        # We deduplicate the thumbnail sizes by ignoring the cropped versions if
        # they have the same dimensions of a scaled one.
        thumbnails = {}
        for r_width, r_height, r_method, r_type in requirements:
            if r_method == "crop":
                thumbnails.setdefault((r_width, r_height, r_type), r_method)
            elif r_method == "scale":
                t_width, t_height = thumbnailer.aspect(r_width, r_height)
                t_width = min(m_width, t_width)
                t_height = min(m_height, t_height)
                thumbnails[(t_width, t_height, r_type)] = r_method

        # Now we generate the thumbnails for each dimension, store it
        for (t_width, t_height, t_type), t_method in thumbnails.iteritems():
            # Work out the correct file name for thumbnail
            if server_name:
                file_path = self.filepaths.remote_media_thumbnail_rel(
                    server_name, file_id, t_width, t_height, t_type, t_method
                )
            elif url_cache:
                file_path = self.filepaths.url_cache_thumbnail_rel(
                    media_id, t_width, t_height, t_type, t_method
                )
            else:
                file_path = self.filepaths.local_media_thumbnail_rel(
                    media_id, t_width, t_height, t_type, t_method
                )

            # Generate the thumbnail
            if t_method == "crop":
                t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
                    thumbnailer.crop,
                    t_width, t_height, t_type,
                ))
            elif t_method == "scale":
                t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
                    thumbnailer.scale,
                    t_width, t_height, t_type,
                ))
            else:
                logger.error("Unrecognized method: %r", t_method)
                continue

            if not t_byte_source:
                continue

            try:
                # Write to disk
                output_path = yield self.write_to_file_and_backup(
                    t_byte_source, file_path,
                )
            finally:
                t_byte_source.close()

            t_len = os.path.getsize(output_path)

            # Write to database
            if server_name:
                yield self.store.store_remote_media_thumbnail(
                    server_name, media_id, file_id,
                    t_width, t_height, t_type, t_method, t_len
                )
            else:
                yield self.store.store_local_thumbnail(
                    media_id, t_width, t_height, t_type, t_method, t_len
                )

        defer.returnValue({
            "width": m_width,
            "height": m_height,
        })

    @defer.inlineCallbacks
    def delete_old_remote_media(self, before_ts):
        """Delete remote media files, thumbnails and datastore entries for
        the media returned by the store's get_remote_media_before(before_ts).

        Returns:
            Deferred[dict]: Dict with a "deleted" key holding the number of
            media entries removed
        """
        old_media = yield self.store.get_remote_media_before(before_ts)

        deleted = 0

        for media in old_media:
            origin = media["media_origin"]
            media_id = media["media_id"]
            file_id = media["filesystem_id"]
            key = (origin, media_id)

            logger.info("Deleting: %r", key)

            # TODO: Should we delete from the backup store

            with (yield self.remote_media_linearizer.queue(key)):
                full_path = self.filepaths.remote_media_filepath(origin, file_id)
                try:
                    os.remove(full_path)
                except OSError as e:
                    logger.warn("Failed to remove file: %r", full_path)
                    # A file that is already gone is fine; on any other
                    # error leave the datastore entry in place.
                    if e.errno == errno.ENOENT:
                        pass
                    else:
                        continue

                thumbnail_dir = self.filepaths.remote_media_thumbnail_dir(
                    origin, file_id
                )
                shutil.rmtree(thumbnail_dir, ignore_errors=True)

                yield self.store.delete_remote_media(origin, media_id)
                deleted += 1

        defer.returnValue({"deleted": deleted})
class MediaRepository(object):
    """Serves, downloads and thumbnails local and remote media, storing files
    through a MediaStorage and recording them in the datastore.
    """

    def __init__(self, hs):
        self.auth = hs.get_auth()
        self.client = MatrixFederationHttpClient(hs)
        self.clock = hs.get_clock()
        self.server_name = hs.hostname
        self.store = hs.get_datastore()
        self.max_upload_size = hs.config.max_upload_size
        self.max_image_pixels = hs.config.max_image_pixels

        self.primary_base_path = hs.config.media_store_path
        self.filepaths = MediaFilePaths(self.primary_base_path)

        self.dynamic_thumbnails = hs.config.dynamic_thumbnails
        self.thumbnail_requirements = hs.config.thumbnail_requirements

        self.remote_media_linearizer = Linearizer(name="media_remote")

        # Media accessed since the last flush by _update_recently_accessed:
        # (server_name, media_id) pairs for remotes, media IDs for locals.
        self.recently_accessed_remotes = set()
        self.recently_accessed_locals = set()

        # None means no whitelist is applied (see get_remote_media).
        self.federation_domain_whitelist = hs.config.federation_domain_whitelist

        # List of StorageProviders where we should search for media and
        # potentially upload to.
        storage_providers = []

        for clz, provider_config, wrapper_config in hs.config.media_storage_providers:
            backend = clz(hs, provider_config)
            provider = StorageProviderWrapper(
                backend,
                store_local=wrapper_config.store_local,
                store_remote=wrapper_config.store_remote,
                store_synchronous=wrapper_config.store_synchronous,
            )
            storage_providers.append(provider)

        self.media_storage = MediaStorage(
            self.primary_base_path, self.filepaths, storage_providers,
        )

        self.clock.looping_call(
            self._update_recently_accessed,
            UPDATE_RECENTLY_ACCESSED_TS,
        )

    @defer.inlineCallbacks
    def _update_recently_accessed(self):
        """Flush the recently accessed local and remote media to the
        datastore and start collecting fresh sets.
        """
        remote_media = self.recently_accessed_remotes
        self.recently_accessed_remotes = set()

        local_media = self.recently_accessed_locals
        self.recently_accessed_locals = set()

        yield self.store.update_cached_last_access_time(
            local_media, remote_media, self.clock.time_msec()
        )

    def mark_recently_accessed(self, server_name, media_id):
        """Mark the given media as recently accessed.

        Args:
            server_name (str|None): Origin server of media, or None if local
            media_id (str): The media ID of the content
        """
        if server_name:
            self.recently_accessed_remotes.add((server_name, media_id))
        else:
            self.recently_accessed_locals.add(media_id)

    @defer.inlineCallbacks
    def create_content(self, media_type, upload_name, content, content_length,
                       auth_user):
        """Store uploaded content for a local user and return the mxc URL

        Args:
            media_type(str): The content type of the file
            upload_name(str): The name of the file
            content: A file like object that is the content to store
            content_length(int): The length of the content
            auth_user(str): The user_id of the uploader

        Returns:
            Deferred[str]: The mxc url of the stored content
        """
        media_id = random_string(24)

        file_info = FileInfo(
            server_name=None,
            file_id=media_id,
        )

        fname = yield self.media_storage.store_file(content, file_info)

        logger.info("Stored local media in file %r", fname)

        yield self.store.store_local_media(
            media_id=media_id,
            media_type=media_type,
            time_now_ms=self.clock.time_msec(),
            upload_name=upload_name,
            media_length=content_length,
            user_id=auth_user,
        )

        # For local content the file_id is the media_id.
        yield self._generate_thumbnails(
            None, media_id, media_id, media_type,
        )

        defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))

    @defer.inlineCallbacks
    def get_local_media(self, request, media_id, name):
        """Responds to requests for local media, if exists, or returns 404.

        Args:
            request(twisted.web.http.Request)
            media_id (str): The media ID of the content. (This is the same as
                the file_id for local content.)
            name (str|None): Optional name that, if specified, will be used as
                the filename in the Content-Disposition header of the response.

        Returns:
            Deferred: Resolves once a response has successfully been written
                to request
        """
        media_info = yield self.store.get_local_media(media_id)
        if not media_info or media_info["quarantined_by"]:
            respond_404(request)
            return

        self.mark_recently_accessed(None, media_id)

        media_type = media_info["media_type"]
        media_length = media_info["media_length"]
        upload_name = name if name else media_info["upload_name"]
        url_cache = media_info["url_cache"]

        file_info = FileInfo(
            None, media_id,
            url_cache=url_cache,
        )

        responder = yield self.media_storage.fetch_media(file_info)
        yield respond_with_responder(
            request, responder, media_type, media_length, upload_name,
        )

    @defer.inlineCallbacks
    def get_remote_media(self, request, server_name, media_id, name):
        """Respond to requests for remote media.

        Args:
            request(twisted.web.http.Request)
            server_name (str): Remote server_name where the media originated.
            media_id (str): The media ID of the content (as defined by the
                remote server).
            name (str|None): Optional name that, if specified, will be used as
                the filename in the Content-Disposition header of the response.

        Returns:
            Deferred: Resolves once a response has successfully been written
                to request

        Raises:
            FederationDeniedError: if a whitelist is configured and
                server_name is not in it
        """
        if (
            self.federation_domain_whitelist is not None and
            server_name not in self.federation_domain_whitelist
        ):
            raise FederationDeniedError(server_name)

        self.mark_recently_accessed(server_name, media_id)

        # We linearize here to ensure that we don't try and download remote
        # media multiple times concurrently
        key = (server_name, media_id)
        with (yield self.remote_media_linearizer.queue(key)):
            responder, media_info = yield self._get_remote_media_impl(
                server_name, media_id,
            )

        # We deliberately stream the file outside the lock
        if responder:
            media_type = media_info["media_type"]
            media_length = media_info["media_length"]
            upload_name = name if name else media_info["upload_name"]
            yield respond_with_responder(
                request, responder, media_type, media_length, upload_name,
            )
        else:
            respond_404(request)

    @defer.inlineCallbacks
    def get_remote_media_info(self, server_name, media_id):
        """Gets the media info associated with the remote file, downloading
        if necessary.

        Args:
            server_name (str): Remote server_name where the media originated.
            media_id (str): The media ID of the content (as defined by the
                remote server).

        Returns:
            Deferred[dict]: The media_info of the file

        Raises:
            FederationDeniedError: if a whitelist is configured and
                server_name is not in it
        """
        if (
            self.federation_domain_whitelist is not None and
            server_name not in self.federation_domain_whitelist
        ):
            raise FederationDeniedError(server_name)

        # We linearize here to ensure that we don't try and download remote
        # media multiple times concurrently
        key = (server_name, media_id)
        with (yield self.remote_media_linearizer.queue(key)):
            responder, media_info = yield self._get_remote_media_impl(
                server_name, media_id,
            )

        # Ensure we actually use the responder so that it releases resources
        if responder:
            with responder:
                pass

        defer.returnValue(media_info)

    @defer.inlineCallbacks
    def _get_remote_media_impl(self, server_name, media_id):
        """Looks for media in local cache, if not there then attempt to
        download from remote server.

        Args:
            server_name (str): Remote server_name where the media originated.
            media_id (str): The media ID of the content (as defined by the
                remote server).

        Returns:
            Deferred[(Responder, media_info)]

        Raises:
            NotFoundError: if the cached entry is quarantined
        """
        media_info = yield self.store.get_cached_remote_media(
            server_name, media_id
        )

        # file_id is the ID we use to track the file locally. If we've already
        # seen the file then reuse the existing ID, otherwise generate a new
        # one.
        if media_info:
            file_id = media_info["filesystem_id"]
        else:
            file_id = random_string(24)

        file_info = FileInfo(server_name, file_id)

        # If we have an entry in the DB, try and look for it
        if media_info:
            if media_info["quarantined_by"]:
                logger.info("Media is quarantined")
                raise NotFoundError()

            responder = yield self.media_storage.fetch_media(file_info)
            if responder:
                defer.returnValue((responder, media_info))

        # Failed to find the file anywhere, lets download it.

        media_info = yield self._download_remote_file(
            server_name, media_id, file_id
        )

        responder = yield self.media_storage.fetch_media(file_info)
        defer.returnValue((responder, media_info))

    @defer.inlineCallbacks
    def _download_remote_file(self, server_name, media_id, file_id):
        """Attempt to download the remote file from the given server name,
        using the given file_id as the local id.

        Args:
            server_name (str): Originating server
            media_id (str): The media ID of the content (as defined by the
                remote server). This is different than the file_id, which is
                locally generated.
            file_id (str): Local file ID

        Returns:
            Deferred[MediaInfo]

        Raises:
            NotFoundError: on a DNS lookup failure
            SynapseError: on any other failure to fetch the media
        """

        file_info = FileInfo(
            server_name=server_name,
            file_id=file_id,
        )

        with self.media_storage.store_into_file(file_info) as (f, fname, finish):
            request_path = "/".join((
                "/_matrix/media/v1/download", server_name, media_id,
            ))
            try:
                length, headers = yield self.client.get_file(
                    server_name, request_path, output_stream=f,
                    max_size=self.max_upload_size, args={
                        # tell the remote server to 404 if it doesn't
                        # recognise the server_name, to make sure we don't
                        # end up with a routing loop.
                        "allow_remote": "false",
                    }
                )
            except twisted.internet.error.DNSLookupError as e:
                logger.warn("HTTP error fetching remote media %s/%s: %r",
                            server_name, media_id, e)
                raise NotFoundError()

            except HttpResponseException as e:
                logger.warn("HTTP error fetching remote media %s/%s: %s",
                            server_name, media_id, e.response)
                if e.code == twisted.web.http.NOT_FOUND:
                    raise SynapseError.from_http_response_exception(e)
                raise SynapseError(502, "Failed to fetch remote media")

            except SynapseError:
                logger.exception("Failed to fetch remote media %s/%s",
                                 server_name, media_id)
                raise
            except NotRetryingDestination:
                logger.warn("Not retrying destination %r", server_name)
                raise SynapseError(502, "Failed to fetch remote media")
            except Exception:
                logger.exception("Failed to fetch remote media %s/%s",
                                 server_name, media_id)
                raise SynapseError(502, "Failed to fetch remote media")

            yield finish()

        media_type = headers["Content-Type"][0]

        time_now_ms = self.clock.time_msec()

        content_disposition = headers.get("Content-Disposition", None)
        if content_disposition:
            _, params = cgi.parse_header(content_disposition[0],)
            upload_name = None

            # First check if there is a valid UTF-8 filename
            upload_name_utf8 = params.get("filename*", None)
            if upload_name_utf8:
                if upload_name_utf8.lower().startswith("utf-8''"):
                    upload_name = upload_name_utf8[7:]

            # If there isn't check for an ascii name.
            if not upload_name:
                upload_name_ascii = params.get("filename", None)
                if upload_name_ascii and is_ascii(upload_name_ascii):
                    upload_name = upload_name_ascii

            if upload_name:
                upload_name = urlparse.unquote(upload_name)
                try:
                    upload_name = upload_name.decode("utf-8")
                except UnicodeDecodeError:
                    upload_name = None
        else:
            upload_name = None

        logger.info("Stored remote media in file %r", fname)

        yield self.store.store_cached_remote_media(
            origin=server_name,
            media_id=media_id,
            media_type=media_type,
            time_now_ms=self.clock.time_msec(),
            upload_name=upload_name,
            media_length=length,
            filesystem_id=file_id,
        )

        media_info = {
            "media_type": media_type,
            "media_length": length,
            "upload_name": upload_name,
            "created_ts": time_now_ms,
            "filesystem_id": file_id,
        }

        yield self._generate_thumbnails(
            server_name, media_id, file_id, media_type,
        )

        defer.returnValue(media_info)

    def _get_thumbnail_requirements(self, media_type):
        """Return the configured thumbnail requirements for `media_type`, or
        an empty tuple if there are none.
        """
        return self.thumbnail_requirements.get(media_type, ())

    def _generate_thumbnail(self, thumbnailer, t_width, t_height, t_method,
                            t_type):
        """Generate a single thumbnail with the given thumbnailer.

        Returns:
            The thumbnail byte source, or None if the image has
            max_image_pixels or more pixels or the method is not "crop" or
            "scale".
        """
        m_width = thumbnailer.width
        m_height = thumbnailer.height

        if m_width * m_height >= self.max_image_pixels:
            logger.info(
                "Image too large to thumbnail %r x %r > %r",
                m_width, m_height, self.max_image_pixels
            )
            return

        if t_method == "crop":
            t_byte_source = thumbnailer.crop(t_width, t_height, t_type)
        elif t_method == "scale":
            t_width, t_height = thumbnailer.aspect(t_width, t_height)
            # Never scale up beyond the size of the original image
            t_width = min(m_width, t_width)
            t_height = min(m_height, t_height)
            t_byte_source = thumbnailer.scale(t_width, t_height, t_type)
        else:
            t_byte_source = None

        return t_byte_source

    @defer.inlineCallbacks
    def generate_local_exact_thumbnail(self, media_id, t_width, t_height,
                                       t_method, t_type, url_cache):
        """Generate, store and record one thumbnail of a local media file.

        Returns:
            Deferred[str|None]: Path of the stored thumbnail, or None if no
            thumbnail was generated
        """
        input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
            None, media_id, url_cache=url_cache,
        ))

        thumbnailer = Thumbnailer(input_path)
        t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
            self._generate_thumbnail,
            thumbnailer, t_width, t_height, t_method, t_type
        ))

        if t_byte_source:
            try:
                file_info = FileInfo(
                    server_name=None,
                    file_id=media_id,
                    url_cache=url_cache,
                    thumbnail=True,
                    thumbnail_width=t_width,
                    thumbnail_height=t_height,
                    thumbnail_method=t_method,
                    thumbnail_type=t_type,
                )

                output_path = yield self.media_storage.store_file(
                    t_byte_source, file_info,
                )
            finally:
                t_byte_source.close()

            logger.info("Stored thumbnail in file %r", output_path)

            t_len = os.path.getsize(output_path)

            yield self.store.store_local_thumbnail(
                media_id, t_width, t_height, t_type, t_method, t_len
            )

            defer.returnValue(output_path)

    @defer.inlineCallbacks
    def generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
                                        t_width, t_height, t_method, t_type):
        """Generate, store and record one thumbnail of a remote media file.

        Args:
            server_name (str): Server the media originated from
            file_id (str): Local file ID of the media
            media_id (str): The media ID of the content (as defined by the
                remote server)

        Returns:
            Deferred[str|None]: Path of the stored thumbnail, or None if no
            thumbnail was generated
        """
        input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
            server_name, file_id, url_cache=False,
        ))

        thumbnailer = Thumbnailer(input_path)
        t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
            self._generate_thumbnail,
            thumbnailer, t_width, t_height, t_method, t_type
        ))

        if t_byte_source:
            try:
                # Remote files are tracked locally by file_id, not by the
                # remote server's media_id: the source media above is read by
                # file_id, so the thumbnail must be stored under it too.
                file_info = FileInfo(
                    server_name=server_name,
                    file_id=file_id,
                    thumbnail=True,
                    thumbnail_width=t_width,
                    thumbnail_height=t_height,
                    thumbnail_method=t_method,
                    thumbnail_type=t_type,
                )

                output_path = yield self.media_storage.store_file(
                    t_byte_source, file_info,
                )
            finally:
                t_byte_source.close()

            logger.info("Stored thumbnail in file %r", output_path)

            t_len = os.path.getsize(output_path)

            yield self.store.store_remote_media_thumbnail(
                server_name, media_id, file_id,
                t_width, t_height, t_type, t_method, t_len
            )

            defer.returnValue(output_path)

    @defer.inlineCallbacks
    def _generate_thumbnails(self, server_name, media_id, file_id, media_type,
                             url_cache=False):
        """Generate and store thumbnails for an image.

        Args:
            server_name (str|None): The server name if remote media, else None
                if local
            media_id (str): The media ID of the content. (This is the same as
                the file_id for local content)
            file_id (str): Local file ID
            media_type (str): The content type of the file
            url_cache (bool): If we are thumbnailing images downloaded for the
                URL cache, used exclusively by the url previewer

        Returns:
            Deferred[dict]: Dict with "width" and "height" keys of original
            image
        """
        requirements = self._get_thumbnail_requirements(media_type)
        if not requirements:
            return

        input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
            server_name, file_id, url_cache=url_cache,
        ))

        thumbnailer = Thumbnailer(input_path)
        m_width = thumbnailer.width
        m_height = thumbnailer.height

        if m_width * m_height >= self.max_image_pixels:
            logger.info(
                "Image too large to thumbnail %r x %r > %r",
                m_width, m_height, self.max_image_pixels
            )
            return

        # We deduplicate the thumbnail sizes by ignoring the cropped versions if
        # they have the same dimensions of a scaled one.
        thumbnails = {}
        for r_width, r_height, r_method, r_type in requirements:
            if r_method == "crop":
                thumbnails.setdefault((r_width, r_height, r_type), r_method)
            elif r_method == "scale":
                t_width, t_height = thumbnailer.aspect(r_width, r_height)
                t_width = min(m_width, t_width)
                t_height = min(m_height, t_height)
                thumbnails[(t_width, t_height, r_type)] = r_method

        # Now we generate the thumbnails for each dimension, store it
        for (t_width, t_height, t_type), t_method in thumbnails.iteritems():
            # Generate the thumbnail
            if t_method == "crop":
                t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
                    thumbnailer.crop,
                    t_width, t_height, t_type,
                ))
            elif t_method == "scale":
                t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
                    thumbnailer.scale,
                    t_width, t_height, t_type,
                ))
            else:
                logger.error("Unrecognized method: %r", t_method)
                continue

            if not t_byte_source:
                continue

            try:
                file_info = FileInfo(
                    server_name=server_name,
                    file_id=file_id,
                    thumbnail=True,
                    thumbnail_width=t_width,
                    thumbnail_height=t_height,
                    thumbnail_method=t_method,
                    thumbnail_type=t_type,
                    url_cache=url_cache,
                )

                output_path = yield self.media_storage.store_file(
                    t_byte_source, file_info,
                )
            finally:
                t_byte_source.close()

            t_len = os.path.getsize(output_path)

            # Write to database
            if server_name:
                yield self.store.store_remote_media_thumbnail(
                    server_name, media_id, file_id,
                    t_width, t_height, t_type, t_method, t_len
                )
            else:
                yield self.store.store_local_thumbnail(
                    media_id, t_width, t_height, t_type, t_method, t_len
                )

        defer.returnValue({
            "width": m_width,
            "height": m_height,
        })

    @defer.inlineCallbacks
    def delete_old_remote_media(self, before_ts):
        """Delete remote media files, thumbnails and datastore entries for
        the media returned by the store's get_remote_media_before(before_ts).

        Returns:
            Deferred[dict]: Dict with a "deleted" key holding the number of
            media entries removed
        """
        old_media = yield self.store.get_remote_media_before(before_ts)

        deleted = 0

        for media in old_media:
            origin = media["media_origin"]
            media_id = media["media_id"]
            file_id = media["filesystem_id"]
            key = (origin, media_id)

            logger.info("Deleting: %r", key)

            # TODO: Should we delete from the backup store

            with (yield self.remote_media_linearizer.queue(key)):
                full_path = self.filepaths.remote_media_filepath(origin, file_id)
                try:
                    os.remove(full_path)
                except OSError as e:
                    logger.warn("Failed to remove file: %r", full_path)
                    # A file that is already gone is fine; on any other
                    # error leave the datastore entry in place.
                    if e.errno == errno.ENOENT:
                        pass
                    else:
                        continue

                thumbnail_dir = self.filepaths.remote_media_thumbnail_dir(
                    origin, file_id
                )
                shutil.rmtree(thumbnail_dir, ignore_errors=True)

                yield self.store.delete_remote_media(origin, media_id)
                deleted += 1

        defer.returnValue({"deleted": deleted})
class FederationClientTests(HomeserverTestCase):
    """
    Tests for MatrixFederationHttpClient, driven through the test reactor.
    """

    def make_homeserver(self, reactor, clock):
        """
        Build the homeserver under test, with its TLS client options factory
        set to None.
        """
        homeserver = self.setup_test_homeserver(reactor=reactor, clock=clock)
        homeserver.tls_client_options_factory = None
        return homeserver

    def prepare(self, reactor, clock, homeserver):
        """
        Create the client under test and register a lookup entry for
        "testserv" with the reactor.
        """
        self.cl = MatrixFederationHttpClient(self.hs)
        self.reactor.lookups["testserv"] = "1.2.3.4"

    def test_dns_error(self):
        """
        A request to a host with no lookup entry fails with RequestSendFailed
        wrapping a DNSLookupError.
        """
        request_d = self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000)
        self.pump()

        failure = self.failureResultOf(request_d)
        self.assertIsInstance(failure.value, RequestSendFailed)
        self.assertIsInstance(failure.value.inner_exception, DNSLookupError)

    def test_client_never_connect(self):
        """
        A request whose TCP connection is never completed fails, once the
        timeout has passed, with RequestSendFailed wrapping a
        ConnectingCancelledError or TimeoutError.
        """
        request_d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
        self.pump()

        # Still pending at this point
        self.assertFalse(request_d.called)

        # Exactly one connection attempt, to the address we registered
        attempts = self.reactor.tcpClients
        self.assertEqual(len(attempts), 1)
        attempt = attempts[0]
        self.assertEqual(attempt[0], '1.2.3.4')
        self.assertEqual(attempt[1], 8008)

        # ... and still pending
        self.assertFalse(request_d.called)

        # Step the clock past the 10s timeout
        self.reactor.advance(10.5)

        failure = self.failureResultOf(request_d)
        self.assertIsInstance(failure.value, RequestSendFailed)
        acceptable_causes = (ConnectingCancelledError, TimeoutError)
        self.assertIsInstance(failure.value.inner_exception, acceptable_causes)

    def test_client_connect_no_response(self):
        """
        A request whose connection is established but never answered fails,
        once the timeout has passed, with RequestSendFailed wrapping a
        ResponseNeverReceived.
        """
        request_d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
        self.pump()

        # Still pending at this point
        self.assertFalse(request_d.called)

        # Exactly one connection attempt, to the address we registered
        attempts = self.reactor.tcpClients
        self.assertEqual(len(attempts), 1)
        attempt = attempts[0]
        self.assertEqual(attempt[0], '1.2.3.4')
        self.assertEqual(attempt[1], 8008)

        # Complete the connection, but onto a transport that never replies
        silent_transport = Mock()
        protocol = attempt[2].buildProtocol(None)
        protocol.makeConnection(silent_transport)

        # ... and still pending
        self.assertFalse(request_d.called)

        # Step the clock past the 10s timeout
        self.reactor.advance(10.5)

        failure = self.failureResultOf(request_d)
        self.assertIsInstance(failure.value, RequestSendFailed)
        self.assertIsInstance(failure.value.inner_exception, ResponseNeverReceived)

    def test_client_gets_headers(self):
        """
        _send_request completes successfully as soon as the response headers
        have been received.
        """
        outgoing = MatrixFederationRequest(
            method="GET",
            destination="testserv:8008",
            path="foo/bar",
        )
        request_d = self.cl._send_request(outgoing, timeout=10000)

        self.pump()

        transport = Mock()
        factory = self.reactor.tcpClients[0][2]
        protocol = factory.buildProtocol(None)
        protocol.makeConnection(transport)

        # No result until something comes back
        self.assertFalse(request_d.called)

        # Feed in a header-only response
        protocol.dataReceived(b"HTTP/1.1 200 OK\r\nServer: Fake\r\n\r\n")

        # The deferred resolves with the response object
        response = self.successResultOf(request_d)
        self.assertEqual(response.code, 200)

    def test_client_headers_no_body(self):
        """
        If the response headers arrive but the body never does, post_json
        fails with a TimeoutError once the timeout has passed.
        """
        request_d = self.cl.post_json("testserv:8008", "foo/bar", timeout=10000)

        self.pump()

        transport = Mock()
        factory = self.reactor.tcpClients[0][2]
        protocol = factory.buildProtocol(None)
        protocol.makeConnection(transport)

        # No result until something comes back
        self.assertFalse(request_d.called)

        # Feed in the headers only; no body follows
        headers_only = (
            b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n"
            b"Server: Fake\r\n\r\n"
        )
        protocol.dataReceived(headers_only)

        # Step the clock past the 10s timeout
        self.reactor.advance(10.5)

        failure = self.failureResultOf(request_d)
        self.assertIsInstance(failure.value, TimeoutError)

    def test_client_sends_body(self):
        """
        The data passed to post_json reaches the server as the JSON request
        body.
        """
        self.cl.post_json("testserv:8008", "foo/bar", timeout=10000, data={"a": "b"})

        self.pump()

        attempts = self.reactor.tcpClients
        self.assertEqual(len(attempts), 1)

        # Wire the client protocol to a real HTTP channel in both directions
        client_protocol = attempts[0][2].buildProtocol(None)
        http_server = HTTPChannel()

        client_protocol.makeConnection(FakeTransport(http_server, self.reactor))
        http_server.makeConnection(FakeTransport(client_protocol, self.reactor))

        self.pump(0.1)

        received = http_server.requests
        self.assertEqual(len(received), 1)
        body = received[0].content.read()
        self.assertEqual(body, b'{"a":"b"}')
class BaseMediaResource(Resource):
    """Base class for media resources.

    Provides the shared helpers for fetching and caching remote media,
    streaming a file back to a request, and generating thumbnails for local
    and remote media.
    """

    isLeaf = True

    def __init__(self, hs, filepaths):
        """
        Args:
            hs: The homeserver; auth, clock, datastore and media-related
                config are read from it.
            filepaths: Object mapping media and thumbnails to filesystem
                paths.
        """
        Resource.__init__(self)
        self.auth = hs.get_auth()
        self.client = MatrixFederationHttpClient(hs)
        self.clock = hs.get_clock()
        self.server_name = hs.hostname
        self.store = hs.get_datastore()
        self.max_upload_size = hs.config.max_upload_size
        self.max_image_pixels = hs.config.max_image_pixels
        self.filepaths = filepaths
        self.version_string = hs.version_string
        # In-flight remote downloads, keyed by (server_name, media_id).
        self.downloads = {}
        self.dynamic_thumbnails = hs.config.dynamic_thumbnails
        self.thumbnail_requirements = hs.config.thumbnail_requirements

    def _respond_404(self, request):
        """Send a JSON "not found" error (with CORS headers) for ``request``."""
        respond_with_json(
            request, 404,
            cs_error(
                "Not found %r" % (request.postpath,),
                code=Codes.NOT_FOUND,
            ),
            send_cors=True
        )

    @staticmethod
    def _makedirs(filepath):
        """Ensure the directory that will contain ``filepath`` exists.

        Args:
            filepath (str): Path of a file about to be written.

        Raises:
            OSError: If the directory cannot be created for any reason other
                than it already existing.
        """
        import errno

        dirname = os.path.dirname(filepath)
        if not dirname:
            # Bare filename: there is no parent directory to create.
            return

        # Create-then-handle rather than check-then-create: two concurrent
        # callers targeting the same directory would otherwise race between
        # the existence check and makedirs, and one of them would fail.
        try:
            os.makedirs(dirname)
        except OSError as e:
            if e.errno != errno.EEXIST or not os.path.isdir(dirname):
                raise

    def _get_remote_media(self, server_name, media_id):
        """Fetch the media info for a piece of remote media.

        Concurrent requests for the same (server_name, media_id) share a
        single lookup/download.

        Args:
            server_name (str): Origin server of the media.
            media_id (str): The media ID on the origin server.

        Returns:
            Deferred: Resolves to the media info.
        """
        key = (server_name, media_id)
        download = self.downloads.get(key)
        if download is None:
            download = self._get_remote_media_impl(server_name, media_id)
            download = ObservableDeferred(
                download,
                consumeErrors=True
            )
            self.downloads[key] = download

            # Forget the in-flight entry once it finishes, success or failure.
            @download.addBoth
            def callback(media_info):
                del self.downloads[key]
                return media_info
        return download.observe()

    @defer.inlineCallbacks
    def _get_remote_media_impl(self, server_name, media_id):
        """Return cached info for remote media, downloading it if not cached.

        Args:
            server_name (str): Origin server of the media.
            media_id (str): The media ID on the origin server.

        Returns:
            Deferred: Resolves to the media info.
        """
        media_info = yield self.store.get_cached_remote_media(
            server_name, media_id
        )
        if not media_info:
            media_info = yield self._download_remote_file(
                server_name, media_id
            )
        defer.returnValue(media_info)

    @defer.inlineCallbacks
    def _download_remote_file(self, server_name, media_id):
        """Download remote media to disk, record it and thumbnail it.

        Args:
            server_name (str): Origin server of the media.
            media_id (str): The media ID on the origin server.

        Returns:
            Deferred[dict]: Media info with the keys "media_type",
            "media_length", "upload_name", "created_ts" and "filesystem_id".
        """
        file_id = random_string(24)

        fname = self.filepaths.remote_media_filepath(
            server_name, file_id
        )
        self._makedirs(fname)

        try:
            with open(fname, "wb") as f:
                request_path = "/".join((
                    "/_matrix/media/v1/download", server_name, media_id,
                ))
                length, headers = yield self.client.get_file(
                    server_name, request_path, output_stream=f,
                    max_size=self.max_upload_size,
                )
            media_type = headers["Content-Type"][0]
            time_now_ms = self.clock.time_msec()

            content_disposition = headers.get("Content-Disposition", None)
            if content_disposition:
                _, params = cgi.parse_header(content_disposition[0],)
                upload_name = None

                # First check if there is a valid UTF-8 filename
                upload_name_utf8 = params.get("filename*", None)
                if upload_name_utf8:
                    if upload_name_utf8.lower().startswith("utf-8''"):
                        upload_name = upload_name_utf8[7:]

                # If there isn't check for an ascii name.
                if not upload_name:
                    upload_name_ascii = params.get("filename", None)
                    if upload_name_ascii and is_ascii(upload_name_ascii):
                        upload_name = upload_name_ascii

                if upload_name:
                    upload_name = urlparse.unquote(upload_name)
                    try:
                        upload_name = upload_name.decode("utf-8")
                    except UnicodeDecodeError:
                        upload_name = None
            else:
                upload_name = None

            # Persist the same timestamp that is returned as "created_ts"
            # below, rather than reading the clock a second time.
            yield self.store.store_cached_remote_media(
                origin=server_name,
                media_id=media_id,
                media_type=media_type,
                time_now_ms=time_now_ms,
                upload_name=upload_name,
                media_length=length,
                filesystem_id=file_id,
            )
        except:
            # Deliberately broad: whatever went wrong, do not leave the
            # downloaded file behind. The error is re-raised.
            os.remove(fname)
            raise

        media_info = {
            "media_type": media_type,
            "media_length": length,
            "upload_name": upload_name,
            "created_ts": time_now_ms,
            "filesystem_id": file_id,
        }

        yield self._generate_remote_thumbnails(
            server_name, media_id, media_info
        )

        defer.returnValue(media_info)

    @defer.inlineCallbacks
    def _respond_with_file(self, request, media_type, file_path,
                           file_size=None, upload_name=None):
        """Stream a file from disk as the response, or 404 if it is missing.

        Args:
            request: The request to respond to.
            media_type (str): Value for the Content-Type header.
            file_path (str): Path of the file to send.
            file_size (int|None): Size for the Content-Length header; read
                from the filesystem when None.
            upload_name (unicode|None): Original file name; when set it is
                sent in a Content-Disposition header.
        """
        logger.debug("Responding with %r", file_path)

        if os.path.isfile(file_path):
            request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
            if upload_name:
                # Plain "filename=" for ASCII names, the "filename*=utf-8''"
                # form otherwise.
                if is_ascii(upload_name):
                    request.setHeader(
                        b"Content-Disposition",
                        b"inline; filename=%s" % (
                            urllib.quote(upload_name.encode("utf-8")),
                        ),
                    )
                else:
                    request.setHeader(
                        b"Content-Disposition",
                        b"inline; filename*=utf-8''%s" % (
                            urllib.quote(upload_name.encode("utf-8")),
                        ),
                    )

            # cache for at least a day.
            # XXX: we might want to turn this off for data we don't want to
            # recommend caching as it's sensitive or private - or at least
            # select private. don't bother setting Expires as all our
            # clients are smart enough to be happy with Cache-Control
            request.setHeader(
                b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
            )
            if file_size is None:
                stat = os.stat(file_path)
                file_size = stat.st_size

            request.setHeader(
                b"Content-Length", b"%d" % (file_size,)
            )

            with open(file_path, "rb") as f:
                yield FileSender().beginFileTransfer(f, request)

            finish_request(request)
        else:
            self._respond_404(request)

    def _get_thumbnail_requirements(self, media_type):
        """Return the configured thumbnail requirements for ``media_type``,
        or an empty tuple if there are none."""
        return self.thumbnail_requirements.get(media_type, ())

    def _generate_thumbnail(self, input_path, t_path, t_width, t_height,
                            t_method, t_type):
        """Write a single thumbnail of ``input_path`` to ``t_path``.

        Args:
            input_path (str): Path of the source image.
            t_path (str): Path to write the thumbnail to.
            t_width (int): Requested thumbnail width.
            t_height (int): Requested thumbnail height.
            t_method (str): "crop" or "scale".
            t_type (str): Requested thumbnail type.

        Returns:
            Whatever ``Thumbnailer.crop`` / ``Thumbnailer.scale`` returns
            (callers store it as the thumbnail length), or None if the image
            has ``max_image_pixels`` or more pixels or ``t_method`` is not
            recognised.
        """
        thumbnailer = Thumbnailer(input_path)
        m_width = thumbnailer.width
        m_height = thumbnailer.height

        if m_width * m_height >= self.max_image_pixels:
            logger.info(
                "Image too large to thumbnail %r x %r > %r",
                m_width, m_height, self.max_image_pixels
            )
            return

        if t_method == "crop":
            t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
        elif t_method == "scale":
            t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
        else:
            t_len = None

        return t_len

    @defer.inlineCallbacks
    def _generate_local_exact_thumbnail(self, media_id, t_width, t_height,
                                        t_method, t_type):
        """Generate and record one thumbnail of a piece of local media.

        Returns:
            Deferred[str|None]: Path of the thumbnail, or None if none was
            generated.
        """
        input_path = self.filepaths.local_media_filepath(media_id)

        t_path = self.filepaths.local_media_thumbnail(
            media_id, t_width, t_height, t_type, t_method
        )
        self._makedirs(t_path)

        # The image processing runs in a thread.
        t_len = yield preserve_context_over_fn(
            threads.deferToThread,
            self._generate_thumbnail,
            input_path, t_path, t_width, t_height, t_method, t_type
        )

        if t_len:
            yield self.store.store_local_thumbnail(
                media_id, t_width, t_height, t_type, t_method, t_len
            )

            defer.returnValue(t_path)

    @defer.inlineCallbacks
    def _generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
                                         t_width, t_height, t_method, t_type):
        """Generate and record one thumbnail of cached remote media.

        Returns:
            Deferred[str|None]: Path of the thumbnail, or None if none was
            generated.
        """
        input_path = self.filepaths.remote_media_filepath(server_name, file_id)

        t_path = self.filepaths.remote_media_thumbnail(
            server_name, file_id, t_width, t_height, t_type, t_method
        )
        self._makedirs(t_path)

        # The image processing runs in a thread.
        t_len = yield preserve_context_over_fn(
            threads.deferToThread,
            self._generate_thumbnail,
            input_path, t_path, t_width, t_height, t_method, t_type
        )

        if t_len:
            yield self.store.store_remote_media_thumbnail(
                server_name, media_id, file_id,
                t_width, t_height, t_type, t_method, t_len
            )

            defer.returnValue(t_path)

    @defer.inlineCallbacks
    def _generate_local_thumbnails(self, media_id, media_info):
        """Generate and record every configured thumbnail for local media.

        Args:
            media_id (str): The media ID.
            media_info (dict): Must contain "media_type".

        Returns:
            Deferred[dict|None]: Dict with "width" and "height" of the
            original image, or None if there are no requirements for the
            media type or the image is too large to thumbnail.
        """
        media_type = media_info["media_type"]
        requirements = self._get_thumbnail_requirements(media_type)
        if not requirements:
            return

        input_path = self.filepaths.local_media_filepath(media_id)
        thumbnailer = Thumbnailer(input_path)
        m_width = thumbnailer.width
        m_height = thumbnailer.height

        if m_width * m_height >= self.max_image_pixels:
            logger.info(
                "Image too large to thumbnail %r x %r > %r",
                m_width, m_height, self.max_image_pixels
            )
            return

        # Filled in by generate_thumbnails (in a thread), then written to the
        # database below.
        local_thumbnails = []

        def generate_thumbnails():
            scales = set()
            crops = set()
            for r_width, r_height, r_method, r_type in requirements:
                if r_method == "scale":
                    t_width, t_height = thumbnailer.aspect(r_width, r_height)
                    scales.add((
                        min(m_width, t_width), min(m_height, t_height), r_type,
                    ))
                elif r_method == "crop":
                    crops.add((r_width, r_height, r_type))

            for t_width, t_height, t_type in scales:
                t_method = "scale"
                t_path = self.filepaths.local_media_thumbnail(
                    media_id, t_width, t_height, t_type, t_method
                )
                self._makedirs(t_path)
                t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)

                local_thumbnails.append((
                    media_id, t_width, t_height, t_type, t_method, t_len
                ))

            for t_width, t_height, t_type in crops:
                if (t_width, t_height, t_type) in scales:
                    # If the aspect ratio of the cropped thumbnail matches a purely
                    # scaled one then there is no point in calculating a separate
                    # thumbnail.
                    continue
                t_method = "crop"
                t_path = self.filepaths.local_media_thumbnail(
                    media_id, t_width, t_height, t_type, t_method
                )
                self._makedirs(t_path)
                t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
                local_thumbnails.append((
                    media_id, t_width, t_height, t_type, t_method, t_len
                ))

        yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)

        for thumbnail_row in local_thumbnails:
            yield self.store.store_local_thumbnail(*thumbnail_row)

        defer.returnValue({
            "width": m_width,
            "height": m_height,
        })

    @defer.inlineCallbacks
    def _generate_remote_thumbnails(self, server_name, media_id, media_info):
        """Generate and record every configured thumbnail for remote media.

        Args:
            server_name (str): Origin server of the media.
            media_id (str): The media ID on the origin server.
            media_info (dict): Must contain "media_type" and "filesystem_id".

        Returns:
            Deferred[dict|None]: Dict with "width" and "height" of the
            original image, or None if there are no requirements for the
            media type. Unlike the local variant, an image that is too large
            to thumbnail still yields the dict; it just gets no thumbnails.
        """
        media_type = media_info["media_type"]
        file_id = media_info["filesystem_id"]
        requirements = self._get_thumbnail_requirements(media_type)
        if not requirements:
            return

        # Filled in by generate_thumbnails (in a thread), then written to the
        # database below.
        remote_thumbnails = []

        input_path = self.filepaths.remote_media_filepath(server_name, file_id)
        thumbnailer = Thumbnailer(input_path)
        m_width = thumbnailer.width
        m_height = thumbnailer.height

        def generate_thumbnails():
            if m_width * m_height >= self.max_image_pixels:
                logger.info(
                    "Image too large to thumbnail %r x %r > %r",
                    m_width, m_height, self.max_image_pixels
                )
                return

            scales = set()
            crops = set()
            for r_width, r_height, r_method, r_type in requirements:
                if r_method == "scale":
                    t_width, t_height = thumbnailer.aspect(r_width, r_height)
                    scales.add((
                        min(m_width, t_width), min(m_height, t_height), r_type,
                    ))
                elif r_method == "crop":
                    crops.add((r_width, r_height, r_type))

            for t_width, t_height, t_type in scales:
                t_method = "scale"
                t_path = self.filepaths.remote_media_thumbnail(
                    server_name, file_id, t_width, t_height, t_type, t_method
                )
                self._makedirs(t_path)
                t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
                remote_thumbnails.append([
                    server_name, media_id, file_id,
                    t_width, t_height, t_type, t_method, t_len
                ])

            for t_width, t_height, t_type in crops:
                if (t_width, t_height, t_type) in scales:
                    # If the aspect ratio of the cropped thumbnail matches a purely
                    # scaled one then there is no point in calculating a separate
                    # thumbnail.
                    continue
                t_method = "crop"
                t_path = self.filepaths.remote_media_thumbnail(
                    server_name, file_id, t_width, t_height, t_type, t_method
                )
                self._makedirs(t_path)
                t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
                remote_thumbnails.append([
                    server_name, media_id, file_id,
                    t_width, t_height, t_type, t_method, t_len
                ])

        yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)

        for thumbnail_row in remote_thumbnails:
            yield self.store.store_remote_media_thumbnail(*thumbnail_row)

        defer.returnValue({
            "width": m_width,
            "height": m_height,
        })
class FederationClientTests(HomeserverTestCase):
    """
    Tests for MatrixFederationHttpClient, driven through the test reactor:
    each test starts a request, inspects the TCP connection the reactor
    recorded, and feeds a hand-written response (or none) back in.
    """

    def make_homeserver(self, reactor, clock):
        """Build the homeserver under test on the supplied reactor and clock."""
        hs = self.setup_test_homeserver(reactor=reactor, clock=clock)
        return hs

    def prepare(self, reactor, clock, homeserver):
        """
        Create the client under test and register a lookup entry for
        "testserv" with the reactor.
        """
        self.cl = MatrixFederationHttpClient(self.hs, None)
        self.reactor.lookups["testserv"] = "1.2.3.4"

    def test_client_get(self):
        """
        happy-path test of a GET request
        """

        @defer.inlineCallbacks
        def do_request():
            with LoggingContext("one") as context:
                fetch_d = defer.ensureDeferred(
                    self.cl.get_json("testserv:8008", "foo/bar"))

                # Nothing happened yet
                self.assertNoResult(fetch_d)

                # should have reset logcontext to the sentinel
                check_logcontext(SENTINEL_CONTEXT)

                try:
                    fetch_res = yield fetch_d
                    return fetch_res
                finally:
                    # our own logcontext should be back in place once the
                    # request completes
                    check_logcontext(context)

        test_d = do_request()

        self.pump()

        # Nothing happened yet
        self.assertNoResult(test_d)

        # Make sure treq is trying to connect
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, "1.2.3.4")
        self.assertEqual(port, 8008)

        # complete the connection and wire it up to a fake transport
        protocol = factory.buildProtocol(None)
        transport = StringTransport()
        protocol.makeConnection(transport)

        # that should have made it send the request to the transport
        self.assertRegex(transport.value(), b"^GET /foo/bar")
        self.assertRegex(transport.value(), b"Host: testserv:8008")

        # Deferred is still without a result
        self.assertNoResult(test_d)

        # Send it the HTTP response
        res_json = '{ "a": 1 }'.encode("ascii")
        protocol.dataReceived(b"HTTP/1.1 200 OK\r\n"
                              b"Server: Fake\r\n"
                              b"Content-Type: application/json\r\n"
                              b"Content-Length: %i\r\n"
                              b"\r\n"
                              b"%s" % (len(res_json), res_json))

        self.pump()

        res = self.successResultOf(test_d)

        # check the response is as expected
        self.assertEqual(res, {"a": 1})

    def test_dns_error(self):
        """
        If the DNS lookup returns an error, it will bubble up.
        """
        # "testserv2" has no entry in the reactor's lookups
        d = defer.ensureDeferred(
            self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000))
        self.pump()

        f = self.failureResultOf(d)
        self.assertIsInstance(f.value, RequestSendFailed)
        self.assertIsInstance(f.value.inner_exception, DNSLookupError)

    def test_client_connection_refused(self):
        """
        If the connection attempt fails, the request fails with a
        RequestSendFailed wrapping the connection error.
        """
        d = defer.ensureDeferred(
            self.cl.get_json("testserv:8008", "foo/bar", timeout=10000))

        self.pump()

        # Nothing happened yet
        self.assertNoResult(d)

        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, "1.2.3.4")
        self.assertEqual(port, 8008)
        # report the connection attempt as failed
        e = Exception("go away")
        factory.clientConnectionFailed(None, e)
        self.pump(0.5)

        f = self.failureResultOf(d)

        self.assertIsInstance(f.value, RequestSendFailed)
        # the very same exception object should have been wrapped
        self.assertIs(f.value.inner_exception, e)

    def test_client_never_connect(self):
        """
        If the HTTP request is not connected and is timed out, it'll give a
        ConnectingCancelledError or TimeoutError.
        """
        d = defer.ensureDeferred(
            self.cl.get_json("testserv:8008", "foo/bar", timeout=10000))

        self.pump()

        # Nothing happened yet
        self.assertNoResult(d)

        # Make sure treq is trying to connect
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        self.assertEqual(clients[0][0], "1.2.3.4")
        self.assertEqual(clients[0][1], 8008)

        # Deferred is still without a result
        self.assertNoResult(d)

        # Push by enough to time it out
        self.reactor.advance(10.5)
        f = self.failureResultOf(d)

        self.assertIsInstance(f.value, RequestSendFailed)
        self.assertIsInstance(f.value.inner_exception,
                              (ConnectingCancelledError, TimeoutError))

    def test_client_connect_no_response(self):
        """
        If the HTTP request is connected, but gets no response before being
        timed out, it'll give a ResponseNeverReceived.
        """
        d = defer.ensureDeferred(
            self.cl.get_json("testserv:8008", "foo/bar", timeout=10000))

        self.pump()

        # Nothing happened yet
        self.assertNoResult(d)

        # Make sure treq is trying to connect
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        self.assertEqual(clients[0][0], "1.2.3.4")
        self.assertEqual(clients[0][1], 8008)

        # complete the connection, but onto a mock that never answers
        conn = Mock()
        client = clients[0][2].buildProtocol(None)
        client.makeConnection(conn)

        # Deferred is still without a result
        self.assertNoResult(d)

        # Push by enough to time it out
        self.reactor.advance(10.5)
        f = self.failureResultOf(d)

        self.assertIsInstance(f.value, RequestSendFailed)
        self.assertIsInstance(f.value.inner_exception, ResponseNeverReceived)

    def test_client_ip_range_blacklist(self):
        """Ensure that Synapse does not try to connect to blacklisted IPs"""

        # Set up the ip_range blacklist
        self.hs.config.federation_ip_range_blacklist = IPSet(
            ["127.0.0.0/8", "fe80::/64"])
        self.reactor.lookups["internal"] = "127.0.0.1"
        self.reactor.lookups["internalv6"] = "fe80:0:0:0:0:8a2e:370:7337"
        self.reactor.lookups["fine"] = "10.20.30.40"
        # a fresh client, created after the blacklist config was set
        cl = MatrixFederationHttpClient(self.hs, None)

        # Try making a GET request to a blacklisted IPv4 address
        # ------------------------------------------------------
        # Make the request
        d = defer.ensureDeferred(
            cl.get_json("internal:8008", "foo/bar", timeout=10000))

        # Nothing happened yet
        self.assertNoResult(d)

        self.pump(1)

        # Check that it was unable to resolve the address
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 0)

        f = self.failureResultOf(d)
        self.assertIsInstance(f.value, RequestSendFailed)
        self.assertIsInstance(f.value.inner_exception, DNSLookupError)

        # Try making a POST request to a blacklisted IPv6 address
        # -------------------------------------------------------
        # Make the request
        d = defer.ensureDeferred(
            cl.post_json("internalv6:8008", "foo/bar", timeout=10000))

        # Nothing has happened yet
        self.assertNoResult(d)

        # Move the reactor forwards
        self.pump(1)

        # Check that it was unable to resolve the address
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 0)

        # Check that it was due to a blacklisted DNS lookup
        f = self.failureResultOf(d, RequestSendFailed)
        self.assertIsInstance(f.value.inner_exception, DNSLookupError)

        # Try making a POST request to a non-blacklisted IPv4 address
        # -----------------------------------------------------------
        # Make the request
        d = defer.ensureDeferred(
            cl.post_json("fine:8008", "foo/bar", timeout=10000))

        # Nothing has happened yet
        self.assertNoResult(d)

        # Move the reactor forwards
        self.pump(1)

        # Check that it was able to resolve the address
        clients = self.reactor.tcpClients
        self.assertNotEqual(len(clients), 0)

        # The request still fails, because this test never completes the
        # connection to the resolved address
        f = self.failureResultOf(d, RequestSendFailed)
        self.assertIsInstance(f.value.inner_exception,
                              ConnectingCancelledError)

    def test_client_gets_headers(self):
        """
        Once the client gets the headers, _request returns successfully.
        """
        request = MatrixFederationRequest(method="GET",
                                          destination="testserv:8008",
                                          path="foo/bar")
        d = defer.ensureDeferred(self.cl._send_request(request, timeout=10000))

        self.pump()

        conn = Mock()
        clients = self.reactor.tcpClients
        client = clients[0][2].buildProtocol(None)
        client.makeConnection(conn)

        # Deferred does not have a result
        self.assertNoResult(d)

        # Send it the HTTP response
        client.dataReceived(b"HTTP/1.1 200 OK\r\nServer: Fake\r\n\r\n")

        # We should get a successful response
        r = self.successResultOf(d)
        self.assertEqual(r.code, 200)

    @parameterized.expand(["get_json", "post_json", "delete_json", "put_json"])
    def test_timeout_reading_body(self, method_name: str):
        """
        If the response headers arrive but the body never does before the
        request times out, it'll give a RequestSendFailed with can_retry.
        """
        method = getattr(self.cl, method_name)
        d = defer.ensureDeferred(
            method("testserv:8008", "foo/bar", timeout=10000))

        self.pump()

        conn = Mock()
        clients = self.reactor.tcpClients
        client = clients[0][2].buildProtocol(None)
        client.makeConnection(conn)

        # Deferred does not have a result
        self.assertNoResult(d)

        # Send it the HTTP response headers only; no body follows
        client.dataReceived(
            (b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n"
             b"Server: Fake\r\n\r\n"))

        # Push by enough to time it out
        self.reactor.advance(10.5)
        f = self.failureResultOf(d)

        self.assertIsInstance(f.value, RequestSendFailed)
        self.assertTrue(f.value.can_retry)
        self.assertIsInstance(f.value.inner_exception, defer.TimeoutError)

    def test_client_requires_trailing_slashes(self):
        """
        If a request is made to a remote server and the server rejects it
        with a 400/M_UNRECOGNIZED because it requires a trailing slash, we
        need to retry the request with a trailing slash. Workaround for
        Synapse <= v0.99.3, explained in #3622.
        """
        d = defer.ensureDeferred(
            self.cl.get_json("testserv:8008",
                             "foo/bar",
                             try_trailing_slash_on_400=True))

        # Send the request
        self.pump()

        # there should have been a call to connectTCP
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (_host, _port, factory, _timeout, _bindAddress) = clients[0]

        # complete the connection and wire it up to a fake transport
        client = factory.buildProtocol(None)
        conn = StringTransport()
        client.makeConnection(conn)

        # that should have made it send the request to the connection
        self.assertRegex(conn.value(), b"^GET /foo/bar")

        # Clear the original request data before sending a response
        conn.clear()

        # Send the HTTP response
        client.dataReceived(
            b"HTTP/1.1 400 Bad Request\r\n"
            b"Content-Type: application/json\r\n"
            b"Content-Length: 59\r\n"
            b"\r\n"
            b'{"errcode":"M_UNRECOGNIZED","error":"Unrecognized request"}')

        # We should get another request with a trailing slash
        self.assertRegex(conn.value(), b"^GET /foo/bar/")

        # Send a happy response this time
        client.dataReceived(b"HTTP/1.1 200 OK\r\n"
                            b"Content-Type: application/json\r\n"
                            b"Content-Length: 2\r\n"
                            b"\r\n"
                            b"{}")

        # We should get a successful response
        r = self.successResultOf(d)
        self.assertEqual(r, {})

    def test_client_does_not_retry_on_400_plus(self):
        """
        Another test for trailing slashes but now test that we don't retry on
        trailing slashes on a non-400/M_UNRECOGNIZED response.

        See test_client_requires_trailing_slashes() for context.
        """
        d = defer.ensureDeferred(
            self.cl.get_json("testserv:8008",
                             "foo/bar",
                             try_trailing_slash_on_400=True))

        # Send the request
        self.pump()

        # there should have been a call to connectTCP
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (_host, _port, factory, _timeout, _bindAddress) = clients[0]

        # complete the connection and wire it up to a fake transport
        client = factory.buildProtocol(None)
        conn = StringTransport()
        client.makeConnection(conn)

        # that should have made it send the request to the connection
        self.assertRegex(conn.value(), b"^GET /foo/bar")

        # Clear the original request data before sending a response
        conn.clear()

        # Send the HTTP response
        client.dataReceived(b"HTTP/1.1 404 Not Found\r\n"
                            b"Content-Type: application/json\r\n"
                            b"Content-Length: 2\r\n"
                            b"\r\n"
                            b"{}")

        # We should not get another request
        self.assertEqual(conn.value(), b"")

        # We should get a 404 failure response
        self.failureResultOf(d)

    def test_client_sends_body(self):
        """
        The data passed to post_json reaches the server as the JSON request
        body.
        """
        defer.ensureDeferred(
            self.cl.post_json("testserv:8008",
                              "foo/bar",
                              timeout=10000,
                              data={"a": "b"}))

        self.pump()

        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        client = clients[0][2].buildProtocol(None)
        server = HTTPChannel()

        # wire the client protocol and a real HTTP channel to each other
        client.makeConnection(FakeTransport(server, self.reactor))
        server.makeConnection(FakeTransport(client, self.reactor))

        self.pump(0.1)

        self.assertEqual(len(server.requests), 1)
        request = server.requests[0]
        content = request.content.read()
        self.assertEqual(content, b'{"a":"b"}')

    def test_closes_connection(self):
        """Check that the client closes unused HTTP connections"""
        d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar"))

        self.pump()

        # there should have been a call to connectTCP
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (_host, _port, factory, _timeout, _bindAddress) = clients[0]

        # complete the connection and wire it up to a fake transport
        client = factory.buildProtocol(None)
        conn = StringTransport()
        client.makeConnection(conn)

        # that should have made it send the request to the connection
        self.assertRegex(conn.value(), b"^GET /foo/bar")

        # Send the HTTP response
        client.dataReceived(b"HTTP/1.1 200 OK\r\n"
                            b"Content-Type: application/json\r\n"
                            b"Content-Length: 2\r\n"
                            b"\r\n"
                            b"{}")

        # We should get a successful response
        r = self.successResultOf(d)
        self.assertEqual(r, {})

        # the connection is not closed straight after the response...
        self.assertFalse(conn.disconnecting)

        # wait for a while
        self.reactor.advance(120)

        # ... but it is once it has sat unused for that long
        self.assertTrue(conn.disconnecting)

    @parameterized.expand([(b"", ), (b"foo", ), (b'{"a": Infinity}', )])
    def test_json_error(self, return_value):
        """
        Test what happens if invalid JSON is returned from the remote endpoint.
        """

        test_d = defer.ensureDeferred(
            self.cl.get_json("testserv:8008", "foo/bar"))

        self.pump()

        # Nothing happened yet
        self.assertNoResult(test_d)

        # Make sure treq is trying to connect
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, "1.2.3.4")
        self.assertEqual(port, 8008)

        # complete the connection and wire it up to a fake transport
        protocol = factory.buildProtocol(None)
        transport = StringTransport()
        protocol.makeConnection(transport)

        # that should have made it send the request to the transport
        self.assertRegex(transport.value(), b"^GET /foo/bar")
        self.assertRegex(transport.value(), b"Host: testserv:8008")

        # Deferred is still without a result
        self.assertNoResult(test_d)

        # Send it the HTTP response, with the parameterized body
        protocol.dataReceived(b"HTTP/1.1 200 OK\r\n"
                              b"Server: Fake\r\n"
                              b"Content-Type: application/json\r\n"
                              b"Content-Length: %i\r\n"
                              b"\r\n"
                              b"%s" % (len(return_value), return_value))

        self.pump()

        # the unparseable body should surface as a ValueError
        f = self.failureResultOf(test_d)
        self.assertIsInstance(f.value, ValueError)
class MediaRepository(object):
    """Writes local and remote media to disk and generates their thumbnails.

    File locations come from the ``filepaths`` helper; metadata about stored
    media and thumbnails is recorded through ``hs.get_datastore()``.
    """

    def __init__(self, hs, filepaths):
        """
        Args:
            hs: object providing ``get_auth()``, ``get_clock()``,
                ``get_datastore()``, ``hostname`` and ``config``.
            filepaths: helper mapping media/thumbnail ids to file paths.
        """
        self.auth = hs.get_auth()
        self.client = MatrixFederationHttpClient(hs)
        self.clock = hs.get_clock()
        self.server_name = hs.hostname
        self.store = hs.get_datastore()
        self.max_upload_size = hs.config.max_upload_size
        self.max_image_pixels = hs.config.max_image_pixels
        self.filepaths = filepaths
        # In-flight remote downloads keyed by (server_name, media_id), so that
        # concurrent requests for the same media share one download.
        self.downloads = {}
        self.dynamic_thumbnails = hs.config.dynamic_thumbnails
        self.thumbnail_requirements = hs.config.thumbnail_requirements

    @staticmethod
    def _makedirs(filepath):
        """Ensure the directory that will contain ``filepath`` exists.

        Args:
            filepath (str): path of a file about to be written.

        Raises:
            OSError: if the directory cannot be created for any reason other
                than it already existing.
        """
        # Imported locally so this block does not depend on a file-level
        # import being added alongside it.
        import errno

        dirname = os.path.dirname(filepath)
        if not dirname:
            # A bare filename has no parent directory to create.
            return

        try:
            os.makedirs(dirname)
        except OSError as e:
            # This is called from thumbnailing worker threads, so another
            # caller may create the directory at any moment. Checking for
            # existence first is racy; instead attempt the creation and
            # tolerate "already exists" when the path really is a directory.
            if e.errno != errno.EEXIST or not os.path.isdir(dirname):
                raise

    @defer.inlineCallbacks
    def create_content(self, media_type, upload_name, content, content_length,
                       auth_user):
        """Store uploaded content as new local media.

        Args:
            media_type: content type recorded for the media.
            upload_name: name recorded for the upload.
            content: the bytes written to the media file.
            content_length: length recorded for the media.
            auth_user: user recorded as the uploader.

        Returns:
            Deferred which fires with the ``mxc://`` URI of the new media.
        """
        media_id = random_string(24)

        fname = self.filepaths.local_media_filepath(media_id)
        self._makedirs(fname)

        # This shouldn't block for very long because the content will have
        # already been uploaded at this point.
        with open(fname, "wb") as f:
            f.write(content)

        yield self.store.store_local_media(
            media_id=media_id,
            media_type=media_type,
            time_now_ms=self.clock.time_msec(),
            upload_name=upload_name,
            media_length=content_length,
            user_id=auth_user,
        )
        media_info = {
            "media_type": media_type,
            "media_length": content_length,
        }

        yield self._generate_local_thumbnails(media_id, media_info)

        defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))

    def get_remote_media(self, server_name, media_id):
        """Fetch the media info for remote media, downloading it if needed.

        Concurrent calls for the same ``(server_name, media_id)`` observe the
        same in-flight download.

        Returns:
            Deferred observing the result of ``_get_remote_media_impl``.
        """
        key = (server_name, media_id)
        download = self.downloads.get(key)
        if download is None:
            download = self._get_remote_media_impl(server_name, media_id)
            download = ObservableDeferred(
                download,
                consumeErrors=True
            )
            self.downloads[key] = download

            @download.addBoth
            def callback(media_info):
                # Forget the download whether it succeeded or failed; pop()
                # rather than del so a missing key can never raise here.
                self.downloads.pop(key, None)
                return media_info
        return download.observe()

    @defer.inlineCallbacks
    def _get_remote_media_impl(self, server_name, media_id):
        """Return cached info for remote media, downloading on a cache miss."""
        media_info = yield self.store.get_cached_remote_media(
            server_name, media_id
        )
        if not media_info:
            media_info = yield self._download_remote_file(
                server_name, media_id
            )
        defer.returnValue(media_info)

    @defer.inlineCallbacks
    def _download_remote_file(self, server_name, media_id):
        """Download remote media to disk, record it and thumbnail it.

        Returns:
            Deferred which fires with a dict containing ``media_type``,
            ``media_length``, ``upload_name``, ``created_ts`` and
            ``filesystem_id``.
        """
        file_id = random_string(24)

        fname = self.filepaths.remote_media_filepath(
            server_name, file_id
        )
        self._makedirs(fname)

        try:
            with open(fname, "wb") as f:
                request_path = "/".join((
                    "/_matrix/media/v1/download", server_name, media_id,
                ))
                length, headers = yield self.client.get_file(
                    server_name, request_path, output_stream=f,
                    max_size=self.max_upload_size,
                )
            media_type = headers["Content-Type"][0]
            time_now_ms = self.clock.time_msec()

            content_disposition = headers.get("Content-Disposition", None)
            if content_disposition:
                _, params = cgi.parse_header(content_disposition[0],)
                upload_name = None

                # First check if there is a valid UTF-8 filename
                upload_name_utf8 = params.get("filename*", None)
                if upload_name_utf8:
                    if upload_name_utf8.lower().startswith("utf-8''"):
                        upload_name = upload_name_utf8[7:]

                # If there isn't check for an ascii name.
                if not upload_name:
                    upload_name_ascii = params.get("filename", None)
                    if upload_name_ascii and is_ascii(upload_name_ascii):
                        upload_name = upload_name_ascii

                if upload_name:
                    upload_name = urlparse.unquote(upload_name)
                    try:
                        upload_name = upload_name.decode("utf-8")
                    except UnicodeDecodeError:
                        upload_name = None
            else:
                upload_name = None

            # Store the same timestamp that is returned as "created_ts"
            # below, rather than reading the clock a second time.
            yield self.store.store_cached_remote_media(
                origin=server_name,
                media_id=media_id,
                media_type=media_type,
                time_now_ms=time_now_ms,
                upload_name=upload_name,
                media_length=length,
                filesystem_id=file_id,
            )
        except:
            # Deliberately broad: whatever went wrong, remove the partial
            # file and then re-raise the original error unchanged.
            os.remove(fname)
            raise

        media_info = {
            "media_type": media_type,
            "media_length": length,
            "upload_name": upload_name,
            "created_ts": time_now_ms,
            "filesystem_id": file_id,
        }

        yield self._generate_remote_thumbnails(
            server_name, media_id, media_info
        )

        defer.returnValue(media_info)

    def _get_thumbnail_requirements(self, media_type):
        """Return the configured thumbnail requirements for ``media_type``.

        Returns an empty tuple when none are configured for that type.
        """
        return self.thumbnail_requirements.get(media_type, ())

    def _generate_thumbnail(self, input_path, t_path, t_width, t_height,
                            t_method, t_type):
        """Write a single thumbnail of ``input_path`` to ``t_path``.

        Returns:
            The value returned by the thumbnailer's ``crop``/``scale`` call,
            or None if the image has too many pixels or ``t_method`` is
            neither "crop" nor "scale".
        """
        thumbnailer = Thumbnailer(input_path)
        m_width = thumbnailer.width
        m_height = thumbnailer.height

        if m_width * m_height >= self.max_image_pixels:
            logger.info(
                "Image too large to thumbnail %r x %r > %r",
                m_width, m_height, self.max_image_pixels
            )
            return

        if t_method == "crop":
            t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
        elif t_method == "scale":
            t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
        else:
            t_len = None

        return t_len

    @defer.inlineCallbacks
    def generate_local_exact_thumbnail(self, media_id, t_width, t_height,
                                       t_method, t_type):
        """Generate one thumbnail of local media at exactly the given size.

        The thumbnail is produced on a worker thread and recorded in the
        store only if one was actually written.

        Returns:
            Deferred which fires with the thumbnail's file path.
        """
        input_path = self.filepaths.local_media_filepath(media_id)

        t_path = self.filepaths.local_media_thumbnail(
            media_id, t_width, t_height, t_type, t_method
        )
        self._makedirs(t_path)

        t_len = yield preserve_context_over_fn(
            threads.deferToThread,
            self._generate_thumbnail,
            input_path, t_path, t_width, t_height, t_method, t_type
        )

        if t_len:
            yield self.store.store_local_thumbnail(
                media_id, t_width, t_height, t_type, t_method, t_len
            )

        defer.returnValue(t_path)

    @defer.inlineCallbacks
    def generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
                                        t_width, t_height, t_method, t_type):
        """Generate one thumbnail of remote media at exactly the given size.

        The thumbnail is produced on a worker thread and recorded in the
        store only if one was actually written.

        Returns:
            Deferred which fires with the thumbnail's file path.
        """
        input_path = self.filepaths.remote_media_filepath(server_name, file_id)

        t_path = self.filepaths.remote_media_thumbnail(
            server_name, file_id, t_width, t_height, t_type, t_method
        )
        self._makedirs(t_path)

        t_len = yield preserve_context_over_fn(
            threads.deferToThread,
            self._generate_thumbnail,
            input_path, t_path, t_width, t_height, t_method, t_type
        )

        if t_len:
            yield self.store.store_remote_media_thumbnail(
                server_name, media_id, file_id,
                t_width, t_height, t_type, t_method, t_len
            )

        defer.returnValue(t_path)

    @defer.inlineCallbacks
    def _generate_local_thumbnails(self, media_id, media_info):
        """Generate every configured thumbnail for a piece of local media.

        Returns:
            Deferred which fires with ``{"width": ..., "height": ...}`` of the
            source image, or with None if the media type has no thumbnail
            requirements or the image has too many pixels.
        """
        media_type = media_info["media_type"]
        requirements = self._get_thumbnail_requirements(media_type)
        if not requirements:
            return

        input_path = self.filepaths.local_media_filepath(media_id)
        thumbnailer = Thumbnailer(input_path)
        m_width = thumbnailer.width
        m_height = thumbnailer.height

        if m_width * m_height >= self.max_image_pixels:
            logger.info(
                "Image too large to thumbnail %r x %r > %r",
                m_width, m_height, self.max_image_pixels
            )
            return

        # Filled in by generate_thumbnails() on the worker thread, then
        # written to the store once that thread has finished.
        local_thumbnails = []

        def generate_thumbnails():
            scales = set()
            crops = set()
            for r_width, r_height, r_method, r_type in requirements:
                if r_method == "scale":
                    t_width, t_height = thumbnailer.aspect(r_width, r_height)
                    # Never scale up beyond the source dimensions.
                    scales.add((
                        min(m_width, t_width), min(m_height, t_height), r_type,
                    ))
                elif r_method == "crop":
                    crops.add((r_width, r_height, r_type))

            for t_width, t_height, t_type in scales:
                t_method = "scale"
                t_path = self.filepaths.local_media_thumbnail(
                    media_id, t_width, t_height, t_type, t_method
                )
                self._makedirs(t_path)
                t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)

                local_thumbnails.append((
                    media_id, t_width, t_height, t_type, t_method, t_len
                ))

            for t_width, t_height, t_type in crops:
                if (t_width, t_height, t_type) in scales:
                    # If the aspect ratio of the cropped thumbnail matches a
                    # purely scaled one then there is no point in calculating
                    # a separate thumbnail.
                    continue
                t_method = "crop"
                t_path = self.filepaths.local_media_thumbnail(
                    media_id, t_width, t_height, t_type, t_method
                )
                self._makedirs(t_path)
                t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
                local_thumbnails.append((
                    media_id, t_width, t_height, t_type, t_method, t_len
                ))

        yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)

        for thumbnail_row in local_thumbnails:
            yield self.store.store_local_thumbnail(*thumbnail_row)

        defer.returnValue({
            "width": m_width,
            "height": m_height,
        })

    @defer.inlineCallbacks
    def _generate_remote_thumbnails(self, server_name, media_id, media_info):
        """Generate every configured thumbnail for a piece of remote media.

        Returns:
            Deferred which fires with ``{"width": ..., "height": ...}`` of the
            source image, or with None if the media type has no thumbnail
            requirements. An image with too many pixels yields no thumbnails
            but still reports its dimensions.
        """
        media_type = media_info["media_type"]
        file_id = media_info["filesystem_id"]
        requirements = self._get_thumbnail_requirements(media_type)
        if not requirements:
            return

        # Filled in by generate_thumbnails() on the worker thread, then
        # written to the store once that thread has finished.
        remote_thumbnails = []

        input_path = self.filepaths.remote_media_filepath(server_name, file_id)
        thumbnailer = Thumbnailer(input_path)
        m_width = thumbnailer.width
        m_height = thumbnailer.height

        def generate_thumbnails():
            if m_width * m_height >= self.max_image_pixels:
                logger.info(
                    "Image too large to thumbnail %r x %r > %r",
                    m_width, m_height, self.max_image_pixels
                )
                return

            scales = set()
            crops = set()
            for r_width, r_height, r_method, r_type in requirements:
                if r_method == "scale":
                    t_width, t_height = thumbnailer.aspect(r_width, r_height)
                    # Never scale up beyond the source dimensions.
                    scales.add((
                        min(m_width, t_width), min(m_height, t_height), r_type,
                    ))
                elif r_method == "crop":
                    crops.add((r_width, r_height, r_type))

            for t_width, t_height, t_type in scales:
                t_method = "scale"
                t_path = self.filepaths.remote_media_thumbnail(
                    server_name, file_id, t_width, t_height, t_type, t_method
                )
                self._makedirs(t_path)
                t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
                remote_thumbnails.append([
                    server_name, media_id, file_id,
                    t_width, t_height, t_type, t_method, t_len
                ])

            for t_width, t_height, t_type in crops:
                if (t_width, t_height, t_type) in scales:
                    # If the aspect ratio of the cropped thumbnail matches a
                    # purely scaled one then there is no point in calculating
                    # a separate thumbnail.
                    continue
                t_method = "crop"
                t_path = self.filepaths.remote_media_thumbnail(
                    server_name, file_id, t_width, t_height, t_type, t_method
                )
                self._makedirs(t_path)
                t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
                remote_thumbnails.append([
                    server_name, media_id, file_id,
                    t_width, t_height, t_type, t_method, t_len
                ])

        yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)

        for thumbnail_row in remote_thumbnails:
            yield self.store.store_remote_media_thumbnail(*thumbnail_row)

        defer.returnValue({
            "width": m_width,
            "height": m_height,
        })
def build_http_client(self):
    """Construct a MatrixFederationHttpClient for this object.

    The client is given a ``ClientTLSOptionsFactory`` built from
    ``self.config``.
    """
    tls_options = context_factory.ClientTLSOptionsFactory(self.config)
    federation_client = MatrixFederationHttpClient(self, tls_options)
    return federation_client
def test_client_ip_range_blacklist(self):
    """Ensure that Synapse does not try to connect to blacklisted IPs"""

    # Set up the ip_range blacklist
    self.hs.config.federation_ip_range_blacklist = IPSet([
        "127.0.0.0/8",
        "fe80::/64",
    ])
    # Two hostnames resolving into the blacklisted ranges, one outside them.
    self.reactor.lookups["internal"] = "127.0.0.1"
    self.reactor.lookups["internalv6"] = "fe80:0:0:0:0:8a2e:370:7337"
    self.reactor.lookups["fine"] = "10.20.30.40"
    # Build a fresh client after changing the config (presumably the
    # blacklist is read when the client is constructed -- confirm).
    cl = MatrixFederationHttpClient(self.hs, None)

    # Try making a GET request to a blacklisted IPv4 address
    # ------------------------------------------------------
    # Make the request
    d = cl.get_json("internal:8008", "foo/bar", timeout=10000)

    # Nothing happened yet
    self.assertNoResult(d)

    self.pump(1)

    # Check that it was unable to resolve the address
    clients = self.reactor.tcpClients
    self.assertEqual(len(clients), 0)

    f = self.failureResultOf(d)
    self.assertIsInstance(f.value, RequestSendFailed)
    self.assertIsInstance(f.value.inner_exception, DNSLookupError)

    # Try making a POST request to a blacklisted IPv6 address
    # -------------------------------------------------------
    # Make the request
    d = cl.post_json("internalv6:8008", "foo/bar", timeout=10000)

    # Nothing has happened yet
    self.assertNoResult(d)

    # Move the reactor forwards
    self.pump(1)

    # Check that it was unable to resolve the address
    clients = self.reactor.tcpClients
    self.assertEqual(len(clients), 0)

    # Check that it was due to a blacklisted DNS lookup
    f = self.failureResultOf(d, RequestSendFailed)
    self.assertIsInstance(f.value.inner_exception, DNSLookupError)

    # Try making a POST request to a non-blacklisted IPv4 address
    # -----------------------------------------------------------
    # Make the request
    d = cl.post_json("fine:8008", "foo/bar", timeout=10000)

    # Nothing has happened yet
    self.assertNoResult(d)

    # Move the reactor forwards
    self.pump(1)

    # Check that it was able to resolve the address
    clients = self.reactor.tcpClients
    self.assertNotEqual(len(clients), 0)

    # The request still fails, but this time at the connection stage rather
    # than at DNS resolution.
    f = self.failureResultOf(d, RequestSendFailed)
    self.assertIsInstance(f.value.inner_exception, ConnectingCancelledError)
def build_http_client(self):
    """Construct a MatrixFederationHttpClient for this object.

    The client is given a ``FederationPolicyForHTTPS`` built from
    ``self.config``.
    """
    https_policy = context_factory.FederationPolicyForHTTPS(self.config)
    federation_client = MatrixFederationHttpClient(self, https_policy)
    return federation_client
def prepare(self, reactor, clock, homeserver):
    """Set up the federation client under test and its fake DNS entry.

    Stores a new ``MatrixFederationHttpClient`` on ``self.cl`` and makes the
    hostname "testserv" resolve to 1.2.3.4 in the test reactor.
    """
    federation_client = MatrixFederationHttpClient(self.hs, None)
    self.cl = federation_client

    dns_table = self.reactor.lookups
    dns_table["testserv"] = "1.2.3.4"
class FederationClientTests(HomeserverTestCase):
    """Tests for MatrixFederationHttpClient driven through the test reactor."""

    def make_homeserver(self, reactor, clock):
        """Create the test homeserver using the given reactor and clock."""
        hs = self.setup_test_homeserver(reactor=reactor, clock=clock)
        return hs

    def prepare(self, reactor, clock, homeserver):
        """Create the client under test and register "testserv" in fake DNS."""
        self.cl = MatrixFederationHttpClient(self.hs, None)
        self.reactor.lookups["testserv"] = "1.2.3.4"

    def test_client_get(self):
        """
        happy-path test of a GET request
        """

        @defer.inlineCallbacks
        def do_request():
            with LoggingContext("one") as context:
                fetch_d = self.cl.get_json("testserv:8008", "foo/bar")

                # Nothing happened yet
                self.assertNoResult(fetch_d)

                # should have reset logcontext to the sentinel
                check_logcontext(LoggingContext.sentinel)

                try:
                    fetch_res = yield fetch_d
                    defer.returnValue(fetch_res)
                finally:
                    # the original logcontext must be restored after the yield
                    check_logcontext(context)

        test_d = do_request()

        self.pump()

        # Nothing happened yet
        self.assertNoResult(test_d)

        # Make sure treq is trying to connect
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 8008)

        # complete the connection and wire it up to a fake transport
        protocol = factory.buildProtocol(None)
        transport = StringTransport()
        protocol.makeConnection(transport)

        # that should have made it send the request to the transport
        self.assertRegex(transport.value(), b"^GET /foo/bar")
        self.assertRegex(transport.value(), b"Host: testserv:8008")

        # Deferred is still without a result
        self.assertNoResult(test_d)

        # Send it the HTTP response
        res_json = '{ "a": 1 }'.encode('ascii')
        protocol.dataReceived(
            b"HTTP/1.1 200 OK\r\n"
            b"Server: Fake\r\n"
            b"Content-Type: application/json\r\n"
            b"Content-Length: %i\r\n"
            b"\r\n"
            b"%s" % (len(res_json), res_json)
        )

        self.pump()

        res = self.successResultOf(test_d)

        # check the response is as expected
        self.assertEqual(res, {"a": 1})

    def test_dns_error(self):
        """
        If the DNS lookup returns an error, it will bubble up.
        """
        # "testserv2" has no entry in the reactor's lookup table.
        d = self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000)
        self.pump()

        f = self.failureResultOf(d)
        self.assertIsInstance(f.value, RequestSendFailed)
        self.assertIsInstance(f.value.inner_exception, DNSLookupError)

    def test_client_connection_refused(self):
        """
        If the connection attempt fails, the request fails with a
        RequestSendFailed wrapping the original connection error.
        """
        d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)

        self.pump()

        # Nothing happened yet
        self.assertNoResult(d)

        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 8008)
        # Report a connection failure to the client's factory.
        e = Exception("go away")
        factory.clientConnectionFailed(None, e)
        self.pump(0.5)

        f = self.failureResultOf(d)

        self.assertIsInstance(f.value, RequestSendFailed)
        self.assertIs(f.value.inner_exception, e)

    def test_client_never_connect(self):
        """
        If the HTTP request is not connected and is timed out, it'll give a
        ConnectingCancelledError or TimeoutError.
        """
        d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)

        self.pump()

        # Nothing happened yet
        self.assertNoResult(d)

        # Make sure treq is trying to connect
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        self.assertEqual(clients[0][0], '1.2.3.4')
        self.assertEqual(clients[0][1], 8008)

        # Deferred is still without a result
        self.assertNoResult(d)

        # Push by enough to time it out
        self.reactor.advance(10.5)
        f = self.failureResultOf(d)

        self.assertIsInstance(f.value, RequestSendFailed)
        self.assertIsInstance(
            f.value.inner_exception, (ConnectingCancelledError, TimeoutError)
        )

    def test_client_connect_no_response(self):
        """
        If the HTTP request is connected, but gets no response before being
        timed out, it'll give a ResponseNeverReceived.
        """
        d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)

        self.pump()

        # Nothing happened yet
        self.assertNoResult(d)

        # Make sure treq is trying to connect
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        self.assertEqual(clients[0][0], '1.2.3.4')
        self.assertEqual(clients[0][1], 8008)

        # Complete the connection, but never send any response data.
        conn = Mock()
        client = clients[0][2].buildProtocol(None)
        client.makeConnection(conn)

        # Deferred is still without a result
        self.assertNoResult(d)

        # Push by enough to time it out
        self.reactor.advance(10.5)
        f = self.failureResultOf(d)

        self.assertIsInstance(f.value, RequestSendFailed)
        self.assertIsInstance(f.value.inner_exception, ResponseNeverReceived)

    def test_client_ip_range_blacklist(self):
        """Ensure that Synapse does not try to connect to blacklisted IPs"""

        # Set up the ip_range blacklist
        self.hs.config.federation_ip_range_blacklist = IPSet([
            "127.0.0.0/8",
            "fe80::/64",
        ])
        # Two hostnames resolving into the blacklisted ranges, one outside.
        self.reactor.lookups["internal"] = "127.0.0.1"
        self.reactor.lookups["internalv6"] = "fe80:0:0:0:0:8a2e:370:7337"
        self.reactor.lookups["fine"] = "10.20.30.40"
        # Build a fresh client after changing the config (presumably the
        # blacklist is read when the client is constructed -- confirm).
        cl = MatrixFederationHttpClient(self.hs, None)

        # Try making a GET request to a blacklisted IPv4 address
        # ------------------------------------------------------
        # Make the request
        d = cl.get_json("internal:8008", "foo/bar", timeout=10000)

        # Nothing happened yet
        self.assertNoResult(d)

        self.pump(1)

        # Check that it was unable to resolve the address
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 0)

        f = self.failureResultOf(d)
        self.assertIsInstance(f.value, RequestSendFailed)
        self.assertIsInstance(f.value.inner_exception, DNSLookupError)

        # Try making a POST request to a blacklisted IPv6 address
        # -------------------------------------------------------
        # Make the request
        d = cl.post_json("internalv6:8008", "foo/bar", timeout=10000)

        # Nothing has happened yet
        self.assertNoResult(d)

        # Move the reactor forwards
        self.pump(1)

        # Check that it was unable to resolve the address
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 0)

        # Check that it was due to a blacklisted DNS lookup
        f = self.failureResultOf(d, RequestSendFailed)
        self.assertIsInstance(f.value.inner_exception, DNSLookupError)

        # Try making a POST request to a non-blacklisted IPv4 address
        # -----------------------------------------------------------
        # Make the request
        d = cl.post_json("fine:8008", "foo/bar", timeout=10000)

        # Nothing has happened yet
        self.assertNoResult(d)

        # Move the reactor forwards
        self.pump(1)

        # Check that it was able to resolve the address
        clients = self.reactor.tcpClients
        self.assertNotEqual(len(clients), 0)

        # The request still fails, but this time at the connection stage
        # rather than at DNS resolution.
        f = self.failureResultOf(d, RequestSendFailed)
        self.assertIsInstance(f.value.inner_exception, ConnectingCancelledError)

    def test_client_gets_headers(self):
        """
        Once the client gets the headers, _send_request returns successfully.
        """
        request = MatrixFederationRequest(
            method="GET", destination="testserv:8008", path="foo/bar"
        )
        d = self.cl._send_request(request, timeout=10000)

        self.pump()

        conn = Mock()
        clients = self.reactor.tcpClients
        client = clients[0][2].buildProtocol(None)
        client.makeConnection(conn)

        # Deferred does not have a result
        self.assertNoResult(d)

        # Send it the HTTP response (headers only, no body)
        client.dataReceived(b"HTTP/1.1 200 OK\r\nServer: Fake\r\n\r\n")

        # We should get a successful response
        r = self.successResultOf(d)
        self.assertEqual(r.code, 200)

    def test_client_headers_no_body(self):
        """
        If the response headers are received but no body arrives before the
        request times out, it'll give a TimeoutError.
        """
        d = self.cl.post_json("testserv:8008", "foo/bar", timeout=10000)

        self.pump()

        conn = Mock()
        clients = self.reactor.tcpClients
        client = clients[0][2].buildProtocol(None)
        client.makeConnection(conn)

        # Deferred does not have a result
        self.assertNoResult(d)

        # Send it the HTTP response headers, but no body
        client.dataReceived(
            (
                b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n"
                b"Server: Fake\r\n\r\n"
            )
        )

        # Push by enough to time it out
        self.reactor.advance(10.5)
        f = self.failureResultOf(d)

        self.assertIsInstance(f.value, TimeoutError)

    def test_client_requires_trailing_slashes(self):
        """
        If the remote server rejects a request with a 400 M_UNRECOGNIZED and
        try_trailing_slash_on_400 is set, the request is retried with a
        trailing slash.

        Workaround for Synapse <= v0.99.3, explained in #3622.
        """
        d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)

        # Send the request
        self.pump()

        # there should have been a call to connectTCP
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (_host, _port, factory, _timeout, _bindAddress) = clients[0]

        # complete the connection and wire it up to a fake transport
        client = factory.buildProtocol(None)
        conn = StringTransport()
        client.makeConnection(conn)

        # that should have made it send the request to the connection
        self.assertRegex(conn.value(), b"^GET /foo/bar")

        # Clear the original request data before sending a response
        conn.clear()

        # Send the HTTP response
        client.dataReceived(
            b"HTTP/1.1 400 Bad Request\r\n"
            b"Content-Type: application/json\r\n"
            b"Content-Length: 59\r\n"
            b"\r\n"
            b'{"errcode":"M_UNRECOGNIZED","error":"Unrecognized request"}'
        )

        # We should get another request with a trailing slash
        self.assertRegex(conn.value(), b"^GET /foo/bar/")

        # Send a happy response this time
        client.dataReceived(
            b"HTTP/1.1 200 OK\r\n"
            b"Content-Type: application/json\r\n"
            b"Content-Length: 2\r\n"
            b"\r\n"
            b'{}'
        )

        # We should get a successful response
        r = self.successResultOf(d)
        self.assertEqual(r, {})

    def test_client_does_not_retry_on_400_plus(self):
        """
        Another test for trailing slashes but now test that we don't retry on
        trailing slashes on a non-400/M_UNRECOGNIZED response.

        See test_client_requires_trailing_slashes() for context.
        """
        d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)

        # Send the request
        self.pump()

        # there should have been a call to connectTCP
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (_host, _port, factory, _timeout, _bindAddress) = clients[0]

        # complete the connection and wire it up to a fake transport
        client = factory.buildProtocol(None)
        conn = StringTransport()
        client.makeConnection(conn)

        # that should have made it send the request to the connection
        self.assertRegex(conn.value(), b"^GET /foo/bar")

        # Clear the original request data before sending a response
        conn.clear()

        # Send the HTTP response
        client.dataReceived(
            b"HTTP/1.1 404 Not Found\r\n"
            b"Content-Type: application/json\r\n"
            b"Content-Length: 2\r\n"
            b"\r\n"
            b"{}"
        )

        # We should not get another request
        self.assertEqual(conn.value(), b"")

        # We should get a 404 failure response
        self.failureResultOf(d)

    def test_client_sends_body(self):
        """Check that post_json sends the given data as the request body."""
        self.cl.post_json("testserv:8008", "foo/bar", timeout=10000, data={"a": "b"})

        self.pump()

        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        client = clients[0][2].buildProtocol(None)
        server = HTTPChannel()

        # Wire the client and a server-side HTTPChannel to each other.
        client.makeConnection(FakeTransport(server, self.reactor))
        server.makeConnection(FakeTransport(client, self.reactor))

        self.pump(0.1)

        self.assertEqual(len(server.requests), 1)
        request = server.requests[0]
        content = request.content.read()
        self.assertEqual(content, b'{"a":"b"}')

    def test_closes_connection(self):
        """Check that the client closes unused HTTP connections"""
        d = self.cl.get_json("testserv:8008", "foo/bar")

        self.pump()

        # there should have been a call to connectTCP
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (_host, _port, factory, _timeout, _bindAddress) = clients[0]

        # complete the connection and wire it up to a fake transport
        client = factory.buildProtocol(None)
        conn = StringTransport()
        client.makeConnection(conn)

        # that should have made it send the request to the connection
        self.assertRegex(conn.value(), b"^GET /foo/bar")

        # Send the HTTP response
        client.dataReceived(
            b"HTTP/1.1 200 OK\r\n"
            b"Content-Type: application/json\r\n"
            b"Content-Length: 2\r\n"
            b"\r\n"
            b"{}"
        )

        # We should get a successful response
        r = self.successResultOf(d)
        self.assertEqual(r, {})

        # The connection is kept open immediately after the response...
        self.assertFalse(conn.disconnecting)

        # wait for a while
        self.pump(120)

        # ...and is being closed once it has sat unused.
        self.assertTrue(conn.disconnecting)
def test_client_ip_range_blacklist(self):
    """Ensure that Synapse does not try to connect to blacklisted IPs"""

    # Set up the ip_range blacklist
    self.hs.config.federation_ip_range_blacklist = IPSet(
        ["127.0.0.0/8", "fe80::/64"])
    # Two hostnames resolving into the blacklisted ranges, one outside them.
    self.reactor.lookups["internal"] = "127.0.0.1"
    self.reactor.lookups["internalv6"] = "fe80:0:0:0:0:8a2e:370:7337"
    self.reactor.lookups["fine"] = "10.20.30.40"
    # Build a fresh client after changing the config (presumably the
    # blacklist is read when the client is constructed -- confirm).
    cl = MatrixFederationHttpClient(self.hs, None)

    # Try making a GET request to a blacklisted IPv4 address
    # ------------------------------------------------------
    # Make the request
    d = defer.ensureDeferred(
        cl.get_json("internal:8008", "foo/bar", timeout=10000))

    # Nothing happened yet
    self.assertNoResult(d)

    self.pump(1)

    # Check that it was unable to resolve the address
    clients = self.reactor.tcpClients
    self.assertEqual(len(clients), 0)

    f = self.failureResultOf(d)
    self.assertIsInstance(f.value, RequestSendFailed)
    self.assertIsInstance(f.value.inner_exception, DNSLookupError)

    # Try making a POST request to a blacklisted IPv6 address
    # -------------------------------------------------------
    # Make the request
    d = defer.ensureDeferred(
        cl.post_json("internalv6:8008", "foo/bar", timeout=10000))

    # Nothing has happened yet
    self.assertNoResult(d)

    # Move the reactor forwards
    self.pump(1)

    # Check that it was unable to resolve the address
    clients = self.reactor.tcpClients
    self.assertEqual(len(clients), 0)

    # Check that it was due to a blacklisted DNS lookup
    f = self.failureResultOf(d, RequestSendFailed)
    self.assertIsInstance(f.value.inner_exception, DNSLookupError)

    # Try making a POST request to a non-blacklisted IPv4 address
    # -----------------------------------------------------------
    # Make the request
    d = defer.ensureDeferred(
        cl.post_json("fine:8008", "foo/bar", timeout=10000))

    # Nothing has happened yet
    self.assertNoResult(d)

    # Move the reactor forwards
    self.pump(1)

    # Check that it was able to resolve the address
    clients = self.reactor.tcpClients
    self.assertNotEqual(len(clients), 0)

    # The request still fails, but this time at the connection stage rather
    # than at DNS resolution.
    f = self.failureResultOf(d, RequestSendFailed)
    self.assertIsInstance(f.value.inner_exception,
                          ConnectingCancelledError)
class FederationClientTests(HomeserverTestCase):
    """Tests for MatrixFederationHttpClient driven through the test reactor."""

    def make_homeserver(self, reactor, clock):
        """Create the test homeserver using the given reactor and clock."""
        hs = self.setup_test_homeserver(reactor=reactor, clock=clock)
        return hs

    def prepare(self, reactor, clock, homeserver):
        """Create the client under test and register "testserv" in fake DNS."""
        self.cl = MatrixFederationHttpClient(self.hs, None)
        self.reactor.lookups["testserv"] = "1.2.3.4"

    def test_client_get(self):
        """
        happy-path test of a GET request
        """
        @defer.inlineCallbacks
        def do_request():
            with LoggingContext("one") as context:
                fetch_d = self.cl.get_json("testserv:8008", "foo/bar")

                # Nothing happened yet
                self.assertNoResult(fetch_d)

                # should have reset logcontext to the sentinel
                check_logcontext(LoggingContext.sentinel)

                try:
                    fetch_res = yield fetch_d
                    defer.returnValue(fetch_res)
                finally:
                    # the original logcontext must be restored after the yield
                    check_logcontext(context)

        test_d = do_request()

        self.pump()

        # Nothing happened yet
        self.assertNoResult(test_d)

        # Make sure treq is trying to connect
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 8008)

        # complete the connection and wire it up to a fake transport
        protocol = factory.buildProtocol(None)
        transport = StringTransport()
        protocol.makeConnection(transport)

        # that should have made it send the request to the transport
        self.assertRegex(transport.value(), b"^GET /foo/bar")
        self.assertRegex(transport.value(), b"Host: testserv:8008")

        # Deferred is still without a result
        self.assertNoResult(test_d)

        # Send it the HTTP response
        res_json = '{ "a": 1 }'.encode('ascii')
        protocol.dataReceived(b"HTTP/1.1 200 OK\r\n"
                              b"Server: Fake\r\n"
                              b"Content-Type: application/json\r\n"
                              b"Content-Length: %i\r\n"
                              b"\r\n"
                              b"%s" % (len(res_json), res_json))

        self.pump()

        res = self.successResultOf(test_d)

        # check the response is as expected
        self.assertEqual(res, {"a": 1})

    def test_dns_error(self):
        """
        If the DNS lookup returns an error, it will bubble up.
        """
        # "testserv2" has no entry in the reactor's lookup table.
        d = self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000)
        self.pump()

        f = self.failureResultOf(d)
        self.assertIsInstance(f.value, RequestSendFailed)
        self.assertIsInstance(f.value.inner_exception, DNSLookupError)

    def test_client_connection_refused(self):
        """
        If the connection attempt fails, the request fails with a
        RequestSendFailed wrapping the original connection error.
        """
        d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)

        self.pump()

        # Nothing happened yet
        self.assertNoResult(d)

        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 8008)
        # Report a connection failure to the client's factory.
        e = Exception("go away")
        factory.clientConnectionFailed(None, e)
        self.pump(0.5)

        f = self.failureResultOf(d)

        self.assertIsInstance(f.value, RequestSendFailed)
        self.assertIs(f.value.inner_exception, e)

    def test_client_never_connect(self):
        """
        If the HTTP request is not connected and is timed out, it'll give a
        ConnectingCancelledError or TimeoutError.
        """
        d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)

        self.pump()

        # Nothing happened yet
        self.assertNoResult(d)

        # Make sure treq is trying to connect
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        self.assertEqual(clients[0][0], '1.2.3.4')
        self.assertEqual(clients[0][1], 8008)

        # Deferred is still without a result
        self.assertNoResult(d)

        # Push by enough to time it out
        self.reactor.advance(10.5)
        f = self.failureResultOf(d)

        self.assertIsInstance(f.value, RequestSendFailed)
        self.assertIsInstance(f.value.inner_exception,
                              (ConnectingCancelledError, TimeoutError))

    def test_client_connect_no_response(self):
        """
        If the HTTP request is connected, but gets no response before being
        timed out, it'll give a ResponseNeverReceived.
        """
        d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)

        self.pump()

        # Nothing happened yet
        self.assertNoResult(d)

        # Make sure treq is trying to connect
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        self.assertEqual(clients[0][0], '1.2.3.4')
        self.assertEqual(clients[0][1], 8008)

        # Complete the connection, but never send any response data.
        conn = Mock()
        client = clients[0][2].buildProtocol(None)
        client.makeConnection(conn)

        # Deferred is still without a result
        self.assertNoResult(d)

        # Push by enough to time it out
        self.reactor.advance(10.5)
        f = self.failureResultOf(d)

        self.assertIsInstance(f.value, RequestSendFailed)
        self.assertIsInstance(f.value.inner_exception, ResponseNeverReceived)

    def test_client_gets_headers(self):
        """
        Once the client gets the headers, _send_request returns successfully.
        """
        request = MatrixFederationRequest(method="GET",
                                          destination="testserv:8008",
                                          path="foo/bar")
        d = self.cl._send_request(request, timeout=10000)

        self.pump()

        conn = Mock()
        clients = self.reactor.tcpClients
        client = clients[0][2].buildProtocol(None)
        client.makeConnection(conn)

        # Deferred does not have a result
        self.assertNoResult(d)

        # Send it the HTTP response (headers only, no body)
        client.dataReceived(b"HTTP/1.1 200 OK\r\nServer: Fake\r\n\r\n")

        # We should get a successful response
        r = self.successResultOf(d)
        self.assertEqual(r.code, 200)

    def test_client_headers_no_body(self):
        """
        If the response headers are received but no body arrives before the
        request times out, it'll give a TimeoutError.
        """
        d = self.cl.post_json("testserv:8008", "foo/bar", timeout=10000)

        self.pump()

        conn = Mock()
        clients = self.reactor.tcpClients
        client = clients[0][2].buildProtocol(None)
        client.makeConnection(conn)

        # Deferred does not have a result
        self.assertNoResult(d)

        # Send it the HTTP response headers, but no body
        client.dataReceived(
            (b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n"
             b"Server: Fake\r\n\r\n"))

        # Push by enough to time it out
        self.reactor.advance(10.5)
        f = self.failureResultOf(d)

        self.assertIsInstance(f.value, TimeoutError)

    def test_client_requires_trailing_slashes(self):
        """
        If the remote server rejects a request with a 400 M_UNRECOGNIZED and
        try_trailing_slash_on_400 is set, the request is retried with a
        trailing slash.

        Workaround for Synapse <= v0.99.3, explained in #3622.
        """
        d = self.cl.get_json("testserv:8008",
                             "foo/bar",
                             try_trailing_slash_on_400=True)

        # Send the request
        self.pump()

        # there should have been a call to connectTCP
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (_host, _port, factory, _timeout, _bindAddress) = clients[0]

        # complete the connection and wire it up to a fake transport
        client = factory.buildProtocol(None)
        conn = StringTransport()
        client.makeConnection(conn)

        # that should have made it send the request to the connection
        self.assertRegex(conn.value(), b"^GET /foo/bar")

        # Clear the original request data before sending a response
        conn.clear()

        # Send the HTTP response
        client.dataReceived(
            b"HTTP/1.1 400 Bad Request\r\n"
            b"Content-Type: application/json\r\n"
            b"Content-Length: 59\r\n"
            b"\r\n"
            b'{"errcode":"M_UNRECOGNIZED","error":"Unrecognized request"}')

        # We should get another request with a trailing slash
        self.assertRegex(conn.value(), b"^GET /foo/bar/")

        # Send a happy response this time
        client.dataReceived(b"HTTP/1.1 200 OK\r\n"
                            b"Content-Type: application/json\r\n"
                            b"Content-Length: 2\r\n"
                            b"\r\n"
                            b'{}')

        # We should get a successful response
        r = self.successResultOf(d)
        self.assertEqual(r, {})

    def test_client_does_not_retry_on_400_plus(self):
        """
        Another test for trailing slashes but now test that we don't retry on
        trailing slashes on a non-400/M_UNRECOGNIZED response.

        See test_client_requires_trailing_slashes() for context.
        """
        d = self.cl.get_json("testserv:8008",
                             "foo/bar",
                             try_trailing_slash_on_400=True)

        # Send the request
        self.pump()

        # there should have been a call to connectTCP
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (_host, _port, factory, _timeout, _bindAddress) = clients[0]

        # complete the connection and wire it up to a fake transport
        client = factory.buildProtocol(None)
        conn = StringTransport()
        client.makeConnection(conn)

        # that should have made it send the request to the connection
        self.assertRegex(conn.value(), b"^GET /foo/bar")

        # Clear the original request data before sending a response
        conn.clear()

        # Send the HTTP response
        client.dataReceived(b"HTTP/1.1 404 Not Found\r\n"
                            b"Content-Type: application/json\r\n"
                            b"Content-Length: 2\r\n"
                            b"\r\n"
                            b"{}")

        # We should not get another request
        self.assertEqual(conn.value(), b"")

        # We should get a 404 failure response
        self.failureResultOf(d)

    def test_client_sends_body(self):
        """Check that post_json sends the given data as the request body."""
        self.cl.post_json("testserv:8008",
                          "foo/bar",
                          timeout=10000,
                          data={"a": "b"})

        self.pump()

        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        client = clients[0][2].buildProtocol(None)
        server = HTTPChannel()

        # Wire the client and a server-side HTTPChannel to each other.
        client.makeConnection(FakeTransport(server, self.reactor))
        server.makeConnection(FakeTransport(client, self.reactor))

        self.pump(0.1)

        self.assertEqual(len(server.requests), 1)
        request = server.requests[0]
        content = request.content.read()
        self.assertEqual(content, b'{"a":"b"}')

    def test_closes_connection(self):
        """Check that the client closes unused HTTP connections"""
        d = self.cl.get_json("testserv:8008", "foo/bar")

        self.pump()

        # there should have been a call to connectTCP
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (_host, _port, factory, _timeout, _bindAddress) = clients[0]

        # complete the connection and wire it up to a fake transport
        client = factory.buildProtocol(None)
        conn = StringTransport()
        client.makeConnection(conn)

        # that should have made it send the request to the connection
        self.assertRegex(conn.value(), b"^GET /foo/bar")

        # Send the HTTP response
        client.dataReceived(b"HTTP/1.1 200 OK\r\n"
                            b"Content-Type: application/json\r\n"
                            b"Content-Length: 2\r\n"
                            b"\r\n"
                            b"{}")

        # We should get a successful response
        r = self.successResultOf(d)
        self.assertEqual(r, {})

        # The connection is kept open immediately after the response...
        self.assertFalse(conn.disconnecting)

        # wait for a while
        self.pump(120)

        # ...and is being closed once it has sat unused.
        self.assertTrue(conn.disconnecting)
class MediaRepository(object):
    """Writes media files to disk, records them in the datastore and
    generates their thumbnails.

    Local uploads go through ``create_content``; media hosted on other
    servers is fetched (at most once concurrently per item) through
    ``get_remote_media``.
    """

    def __init__(self, hs, filepaths):
        """
        Args:
            hs: the homeserver object; supplies auth, clock, datastore and
                configuration.
            filepaths: object mapping media/thumbnail identifiers to paths
                on disk.
        """
        self.auth = hs.get_auth()
        self.client = MatrixFederationHttpClient(hs)
        self.clock = hs.get_clock()
        self.server_name = hs.hostname
        self.store = hs.get_datastore()
        self.max_upload_size = hs.config.max_upload_size
        self.max_image_pixels = hs.config.max_image_pixels
        self.filepaths = filepaths
        # (server_name, media_id) -> ObservableDeferred for downloads that
        # are currently in flight.
        self.downloads = {}
        self.dynamic_thumbnails = hs.config.dynamic_thumbnails
        self.thumbnail_requirements = hs.config.thumbnail_requirements

    @staticmethod
    def _makedirs(filepath):
        """Ensure the directory that will contain `filepath` exists.

        Safe to call concurrently for the same directory: losing the race to
        create it is not treated as an error. A path with no directory
        component is a no-op.
        """
        import errno

        dirname = os.path.dirname(filepath)
        if not dirname:
            return
        try:
            os.makedirs(dirname)
        except OSError as e:
            # Checking for existence first and then creating is racy when
            # several callers (e.g. thumbnail threads) target one directory,
            # so instead attempt the creation and tolerate "already exists".
            if e.errno != errno.EEXIST:
                raise

    @staticmethod
    def _remove_quietly(path):
        """Best-effort removal of `path`; never raises OSError.

        Kept as a separate function so that an OSError handled here cannot
        replace the exception a caller is about to re-raise with a bare
        ``raise``.
        """
        import errno

        try:
            os.remove(path)
        except OSError as e:
            if e.errno != errno.ENOENT:
                logger.warning("Failed to remove %r: %s", path, e)

    @defer.inlineCallbacks
    def create_content(self, media_type, upload_name, content, content_length,
                       auth_user):
        """Store an uploaded file, record it and generate its thumbnails.

        Args:
            media_type: content type of the upload.
            upload_name: the name supplied for the upload.
            content: the bytes to write to disk.
            content_length: length recorded for the media.
            auth_user: the user the upload is recorded against.

        Returns:
            Deferred resolving to the ``mxc://`` URI of the new media.
        """
        media_id = random_string(24)

        fname = self.filepaths.local_media_filepath(media_id)
        self._makedirs(fname)

        # This shouldn't block for very long because the content will have
        # already been uploaded at this point.
        with open(fname, "wb") as f:
            f.write(content)

        yield self.store.store_local_media(
            media_id=media_id,
            media_type=media_type,
            time_now_ms=self.clock.time_msec(),
            upload_name=upload_name,
            media_length=content_length,
            user_id=auth_user,
        )
        media_info = {
            "media_type": media_type,
            "media_length": content_length,
        }

        yield self._generate_local_thumbnails(media_id, media_info)

        defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))

    def get_remote_media(self, server_name, media_id):
        """Fetch info for a remote media item, sharing in-flight downloads.

        Concurrent requests for the same (server_name, media_id) observe a
        single underlying download rather than each starting their own.

        Returns:
            Deferred resolving to the media info dict.
        """
        key = (server_name, media_id)
        download = self.downloads.get(key)
        if download is None:
            download = self._get_remote_media_impl(server_name, media_id)
            download = ObservableDeferred(
                download,
                consumeErrors=True,
            )
            self.downloads[key] = download

            # Forget the download once it finishes (successfully or not) so
            # that a later request starts afresh.
            @download.addBoth
            def callback(media_info):
                del self.downloads[key]
                return media_info
        return download.observe()

    @defer.inlineCallbacks
    def _get_remote_media_impl(self, server_name, media_id):
        """Return cached info for a remote media item, downloading it first
        if the datastore has no record of it."""
        media_info = yield self.store.get_cached_remote_media(
            server_name, media_id
        )
        if not media_info:
            media_info = yield self._download_remote_file(
                server_name, media_id
            )
        defer.returnValue(media_info)

    @defer.inlineCallbacks
    def _download_remote_file(self, server_name, media_id):
        """Download a remote media file, record it and thumbnail it.

        If the download or the datastore write fails, the partially written
        file is removed and the original error propagates.

        Returns:
            Deferred resolving to a dict with keys ``media_type``,
            ``media_length``, ``upload_name``, ``created_ts`` and
            ``filesystem_id``.
        """
        file_id = random_string(24)

        fname = self.filepaths.remote_media_filepath(
            server_name, file_id
        )
        self._makedirs(fname)

        try:
            with open(fname, "wb") as f:
                request_path = "/".join((
                    "/_matrix/media/v1/download", server_name, media_id,
                ))
                length, headers = yield self.client.get_file(
                    server_name, request_path, output_stream=f,
                    max_size=self.max_upload_size,
                )
            media_type = headers["Content-Type"][0]
            # Read the clock once so the stored record and the returned
            # media_info agree on the creation time.
            time_now_ms = self.clock.time_msec()

            content_disposition = headers.get("Content-Disposition", None)
            if content_disposition:
                _, params = cgi.parse_header(content_disposition[0],)
                upload_name = None

                # First check if there is a valid UTF-8 filename
                upload_name_utf8 = params.get("filename*", None)
                if upload_name_utf8:
                    if upload_name_utf8.lower().startswith("utf-8''"):
                        upload_name = upload_name_utf8[7:]

                # If there isn't check for an ascii name.
                if not upload_name:
                    upload_name_ascii = params.get("filename", None)
                    if upload_name_ascii and is_ascii(upload_name_ascii):
                        upload_name = upload_name_ascii

                if upload_name:
                    upload_name = urlparse.unquote(upload_name)
                    try:
                        upload_name = upload_name.decode("utf-8")
                    except UnicodeDecodeError:
                        upload_name = None
            else:
                upload_name = None

            yield self.store.store_cached_remote_media(
                origin=server_name,
                media_id=media_id,
                media_type=media_type,
                time_now_ms=time_now_ms,
                upload_name=upload_name,
                media_length=length,
                filesystem_id=file_id,
            )
        except BaseException:
            # Deliberately broad: whatever went wrong, don't leave a partial
            # file behind. The removal must not raise, otherwise it would
            # hide the real failure (e.g. when open() itself failed and there
            # is no file to remove).
            self._remove_quietly(fname)
            raise

        media_info = {
            "media_type": media_type,
            "media_length": length,
            "upload_name": upload_name,
            "created_ts": time_now_ms,
            "filesystem_id": file_id,
        }

        yield self._generate_remote_thumbnails(
            server_name, media_id, media_info
        )

        defer.returnValue(media_info)

    def _get_thumbnail_requirements(self, media_type):
        """Return the configured thumbnail requirements for `media_type`,
        or an empty tuple if there are none."""
        return self.thumbnail_requirements.get(media_type, ())

    @staticmethod
    def _plan_thumbnails(thumbnailer, requirements, m_width, m_height):
        """Work out which scaled and cropped sizes the requirements call for.

        Scaled sizes are capped at the source dimensions. Requirements with
        any other method are ignored.

        Returns:
            tuple (scales, crops) of sets of (width, height, type).
        """
        scales = set()
        crops = set()
        for r_width, r_height, r_method, r_type in requirements:
            if r_method == "scale":
                t_width, t_height = thumbnailer.aspect(r_width, r_height)
                scales.add((
                    min(m_width, t_width), min(m_height, t_height), r_type,
                ))
            elif r_method == "crop":
                crops.add((r_width, r_height, r_type))
        return scales, crops

    def _generate_thumbnail(self, input_path, t_path, t_width, t_height,
                            t_method, t_type):
        """Generate a single thumbnail of `input_path` at `t_path`.

        Returns:
            The value returned by the thumbnailer's crop/scale call, or None
            if the image has too many pixels or `t_method` is neither
            "crop" nor "scale".
        """
        thumbnailer = Thumbnailer(input_path)
        m_width = thumbnailer.width
        m_height = thumbnailer.height

        if m_width * m_height >= self.max_image_pixels:
            logger.info(
                "Image too large to thumbnail %r x %r > %r",
                m_width, m_height, self.max_image_pixels
            )
            return

        if t_method == "crop":
            t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
        elif t_method == "scale":
            t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
        else:
            t_len = None

        return t_len

    @defer.inlineCallbacks
    def generate_local_exact_thumbnail(self, media_id, t_width, t_height,
                                       t_method, t_type):
        """Generate and record one thumbnail of a local media item.

        The image work is done on a thread via deferToThread.

        Returns:
            Deferred resolving to the thumbnail path, or None if no
            thumbnail was generated.
        """
        input_path = self.filepaths.local_media_filepath(media_id)

        t_path = self.filepaths.local_media_thumbnail(
            media_id, t_width, t_height, t_type, t_method
        )
        self._makedirs(t_path)

        t_len = yield preserve_context_over_fn(
            threads.deferToThread,
            self._generate_thumbnail,
            input_path, t_path, t_width, t_height, t_method, t_type
        )

        if t_len:
            yield self.store.store_local_thumbnail(
                media_id, t_width, t_height, t_type, t_method, t_len
            )

            defer.returnValue(t_path)

    @defer.inlineCallbacks
    def generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
                                        t_width, t_height, t_method, t_type):
        """Generate and record one thumbnail of a cached remote media item.

        The image work is done on a thread via deferToThread.

        Returns:
            Deferred resolving to the thumbnail path, or None if no
            thumbnail was generated.
        """
        input_path = self.filepaths.remote_media_filepath(server_name, file_id)

        t_path = self.filepaths.remote_media_thumbnail(
            server_name, file_id, t_width, t_height, t_type, t_method
        )
        self._makedirs(t_path)

        t_len = yield preserve_context_over_fn(
            threads.deferToThread,
            self._generate_thumbnail,
            input_path, t_path, t_width, t_height, t_method, t_type
        )

        if t_len:
            yield self.store.store_remote_media_thumbnail(
                server_name, media_id, file_id,
                t_width, t_height, t_type, t_method, t_len
            )

            defer.returnValue(t_path)

    @defer.inlineCallbacks
    def _generate_local_thumbnails(self, media_id, media_info):
        """Generate and record every configured thumbnail for a local item.

        Returns:
            Deferred resolving to a dict with the source ``width`` and
            ``height``, or None if the media type has no thumbnail
            requirements or the image has too many pixels.
        """
        media_type = media_info["media_type"]
        requirements = self._get_thumbnail_requirements(media_type)
        if not requirements:
            return

        input_path = self.filepaths.local_media_filepath(media_id)
        thumbnailer = Thumbnailer(input_path)
        m_width = thumbnailer.width
        m_height = thumbnailer.height

        if m_width * m_height >= self.max_image_pixels:
            logger.info(
                "Image too large to thumbnail %r x %r > %r",
                m_width, m_height, self.max_image_pixels
            )
            return

        local_thumbnails = []

        def generate_thumbnails():
            scales, crops = self._plan_thumbnails(
                thumbnailer, requirements, m_width, m_height
            )

            for t_width, t_height, t_type in scales:
                t_method = "scale"
                t_path = self.filepaths.local_media_thumbnail(
                    media_id, t_width, t_height, t_type, t_method
                )
                self._makedirs(t_path)
                t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)

                local_thumbnails.append((
                    media_id, t_width, t_height, t_type, t_method, t_len
                ))

            for t_width, t_height, t_type in crops:
                if (t_width, t_height, t_type) in scales:
                    # If the aspect ratio of the cropped thumbnail matches a purely
                    # scaled one then there is no point in calculating a separate
                    # thumbnail.
                    continue
                t_method = "crop"
                t_path = self.filepaths.local_media_thumbnail(
                    media_id, t_width, t_height, t_type, t_method
                )
                self._makedirs(t_path)
                t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
                local_thumbnails.append((
                    media_id, t_width, t_height, t_type, t_method, t_len
                ))

        yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)

        # The thread only collects results; the datastore writes happen here.
        for thumbnail_row in local_thumbnails:
            yield self.store.store_local_thumbnail(*thumbnail_row)

        defer.returnValue({
            "width": m_width,
            "height": m_height,
        })

    @defer.inlineCallbacks
    def _generate_remote_thumbnails(self, server_name, media_id, media_info):
        """Generate and record every configured thumbnail for a remote item.

        Unlike the local variant, an image with too many pixels still
        resolves to the dimensions dict; it just produces no thumbnails.

        Returns:
            Deferred resolving to a dict with the source ``width`` and
            ``height``, or None if the media type has no thumbnail
            requirements.
        """
        media_type = media_info["media_type"]
        file_id = media_info["filesystem_id"]
        requirements = self._get_thumbnail_requirements(media_type)
        if not requirements:
            return

        remote_thumbnails = []

        input_path = self.filepaths.remote_media_filepath(server_name, file_id)
        thumbnailer = Thumbnailer(input_path)
        m_width = thumbnailer.width
        m_height = thumbnailer.height

        def generate_thumbnails():
            if m_width * m_height >= self.max_image_pixels:
                logger.info(
                    "Image too large to thumbnail %r x %r > %r",
                    m_width, m_height, self.max_image_pixels
                )
                return

            scales, crops = self._plan_thumbnails(
                thumbnailer, requirements, m_width, m_height
            )

            for t_width, t_height, t_type in scales:
                t_method = "scale"
                t_path = self.filepaths.remote_media_thumbnail(
                    server_name, file_id, t_width, t_height, t_type, t_method
                )
                self._makedirs(t_path)
                t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
                remote_thumbnails.append([
                    server_name, media_id, file_id,
                    t_width, t_height, t_type, t_method, t_len
                ])

            for t_width, t_height, t_type in crops:
                if (t_width, t_height, t_type) in scales:
                    # If the aspect ratio of the cropped thumbnail matches a purely
                    # scaled one then there is no point in calculating a separate
                    # thumbnail.
                    continue
                t_method = "crop"
                t_path = self.filepaths.remote_media_thumbnail(
                    server_name, file_id, t_width, t_height, t_type, t_method
                )
                self._makedirs(t_path)
                t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
                remote_thumbnails.append([
                    server_name, media_id, file_id,
                    t_width, t_height, t_type, t_method, t_len
                ])

        yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)

        # The thread only collects results; the datastore writes happen here.
        for thumbnail_row in remote_thumbnails:
            yield self.store.store_remote_media_thumbnail(*thumbnail_row)

        defer.returnValue({
            "width": m_width,
            "height": m_height,
        })