def setUp(self):
    """Build a "red" homeserver with MAU limiting enabled and register/sync servlets."""
    reactor = ThreadedMemoryReactorClock()
    clock = Clock(reactor)
    self.reactor = reactor
    self.clock = clock

    self.hs = setup_test_homeserver(
        self.addCleanup,
        "red",
        http_client=None,
        clock=clock,
        reactor=reactor,
        federation_client=Mock(),
        ratelimiter=NonCallableMock(spec_set=["send_message"]),
    )
    self.store = self.hs.get_datastore()

    # Config overrides, applied in order: no 3PID/captcha requirement for
    # registration, MAU limiting on with a cap of 2 users and no trial period,
    # and a server-notices identity/room configured.
    config_overrides = (
        ("registrations_require_3pid", []),
        ("enable_registration_captcha", False),
        ("recaptcha_public_key", []),
        ("limit_usage_by_mau", True),
        ("hs_disabled", False),
        ("max_mau_value", 2),
        ("mau_trial_days", 0),
        ("server_notices_mxid", "@server:red"),
        ("server_notices_mxid_display_name", None),
        ("server_notices_mxid_avatar_url", None),
        ("server_notices_room_name", "Test Server Notice Room"),
    )
    for option_name, option_value in config_overrides:
        setattr(self.hs.config, option_name, option_value)

    self.resource = JsonResource(self.hs)
    for servlet_module in (register, sync):
        servlet_module.register_servlets(self.hs, self.resource)
def setUp(self):
    """Build a MatrixFederationAgent whose .well-known lookups go through a
    WellKnownResolver backed by two reactor-clocked TTL caches."""
    self.reactor = ThreadedMemoryReactorClock()
    self.mock_resolver = Mock()

    # Default test config, extended to trust the test CA certificate.
    raw_config = default_config("test", parse=False)
    raw_config["federation_custom_ca_list"] = [get_test_ca_cert_file()]
    config = HomeServerConfig()
    self._config = config
    config.parse_config_dict(raw_config, "", "")

    self.tls_factory = ClientTLSOptionsFactory(config)

    # Both caches use the fake reactor's clock so tests can advance time.
    reactor_seconds = self.reactor.seconds
    self.well_known_cache = TTLCache("test_cache", timer=reactor_seconds)
    self.had_well_known_cache = TTLCache("test_cache", timer=reactor_seconds)

    well_known_http_agent = Agent(self.reactor, contextFactory=self.tls_factory)
    self.well_known_resolver = WellKnownResolver(
        self.reactor,
        well_known_http_agent,
        well_known_cache=self.well_known_cache,
        had_well_known_cache=self.had_well_known_cache,
    )

    self.agent = MatrixFederationAgent(
        reactor=self.reactor,
        tls_client_options_factory=self.tls_factory,
        _srv_resolver=self.mock_resolver,
        _well_known_resolver=self.well_known_resolver,
    )
def setUp(self):
    """Create a test homeserver driven by a threaded in-memory reactor."""
    reactor = ThreadedMemoryReactorClock()
    hs_clock = Clock(reactor)
    self.reactor, self.hs_clock = reactor, hs_clock
    self.homeserver = setup_test_homeserver(
        self.addCleanup, http_client=None, clock=hs_clock, reactor=reactor
    )
def setUp(self):
    """Prepare a homeserver with a shared registration secret and the admin
    servlets registered on a JsonResource."""
    reactor = ThreadedMemoryReactorClock()
    self.clock = reactor
    self.hs_clock = Clock(reactor)
    self.url = "/_matrix/client/r0/admin/register"

    # Stand-in collaborators stored on the test case. They are not wired into
    # the homeserver here; presumably individual tests use them — verify.
    self.registration_handler = Mock()
    self.identity_handler = Mock()
    self.login_handler = Mock()
    self.device_handler = Mock()
    self.device_handler.check_device_registered = Mock(return_value="FAKE")

    self.datastore = Mock(return_value=Mock())
    self.datastore.get_current_state_deltas = Mock(return_value=[])

    self.secrets = Mock()

    hs = setup_test_homeserver(
        self.addCleanup, http_client=None, clock=self.hs_clock, reactor=reactor
    )
    self.hs = hs
    hs.config.registration_shared_secret = u"shared"
    hs.get_media_repository = Mock()
    hs.get_deactivate_account_handler = Mock()

    self.resource = JsonResource(hs)
    register_servlets(hs, self.resource)
def setUp(self):
    """Build a MatrixFederationAgent using FederationPolicyForHTTPS, an empty IP
    blacklist and a WellKnownResolver backed by reactor-clocked TTL caches."""
    self.reactor = ThreadedMemoryReactorClock()
    self.mock_resolver = Mock()

    # Default test config, extended to trust the test CA certificate.
    raw_config = default_config("test", parse=False)
    raw_config["federation_custom_ca_list"] = [get_test_ca_cert_file()]
    config = HomeServerConfig()
    self._config = config
    config.parse_config_dict(raw_config, "", "")

    self.tls_factory = FederationPolicyForHTTPS(config)

    reactor_seconds = self.reactor.seconds
    self.well_known_cache = TTLCache("test_cache", timer=reactor_seconds)
    self.had_well_known_cache = TTLCache("test_cache", timer=reactor_seconds)

    well_known_http_agent = Agent(self.reactor, contextFactory=self.tls_factory)
    self.well_known_resolver = WellKnownResolver(
        self.reactor,
        well_known_http_agent,
        b"test-agent",
        well_known_cache=self.well_known_cache,
        had_well_known_cache=self.had_well_known_cache,
    )

    self.agent = MatrixFederationAgent(
        reactor=self.reactor,
        tls_client_options_factory=self.tls_factory,
        # Unused, because an explicit _well_known_resolver is supplied.
        user_agent="test-agent",
        ip_blacklist=IPSet(),
        _srv_resolver=self.mock_resolver,
        _well_known_resolver=self.well_known_resolver,
    )
def setUp(self):
    """Build a MatrixFederationAgent with a mocked SRV resolver and a
    reactor-clocked .well-known cache."""
    self.reactor = ThreadedMemoryReactorClock()
    self.mock_resolver = Mock()
    self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds)

    tls_options = ClientTLSOptionsFactory(None)
    well_known_policy = TrustingTLSPolicyForHTTPS()
    self.agent = MatrixFederationAgent(
        reactor=self.reactor,
        tls_client_options_factory=tls_options,
        _well_known_tls_policy=well_known_policy,
        _srv_resolver=self.mock_resolver,
        _well_known_cache=self.well_known_cache,
    )
def setUp(self):
    """Create a homeserver whose auth always resolves to ``self.USER_ID``, then
    register every servlet module listed in ``self.TO_REGISTER``."""
    reactor = MemoryReactorClock()
    self.clock = reactor
    self.hs_clock = Clock(reactor)
    self.hs = setup_test_homeserver(
        self.addCleanup, http_client=None, clock=self.hs_clock, reactor=reactor
    )
    self.auth = self.hs.get_auth()

    # Both stubs build a fresh UserID on every call.
    def get_user_by_access_token(token=None, allow_guest=False):
        return dict(
            user=UserID.from_string(self.USER_ID),
            token_id=1,
            is_guest=False,
        )

    def get_user_by_req(request, allow_guest=False, rights="access"):
        return synapse.types.create_requester(
            UserID.from_string(self.USER_ID), 1, False, None
        )

    self.auth.get_user_by_access_token = get_user_by_access_token
    self.auth.get_user_by_req = get_user_by_req

    self.store = self.hs.get_datastore()
    self.filtering = self.hs.get_filtering()

    self.resource = JsonResource(self.hs)
    for servlet_module in self.TO_REGISTER:
        servlet_module.register_servlets(self.hs, self.resource)
def setUp(self):
    """Create an OptionsResource with a single leaf child mounted at ``b"res"``."""
    self.reactor = ThreadedMemoryReactorClock()

    class DummyResource(Resource):
        """Leaf resource whose response body is the request path."""

        isLeaf = True

        def render(self, request):
            return request.path

    root = OptionsResource()
    root.putChild(b"res", DummyResource())
    self.resource = root
def setUp(self):
    """Build a MatrixFederationAgent whose TLS options come from a parsed test
    config that trusts the test CA certificate."""
    self.reactor = ThreadedMemoryReactorClock()
    self.mock_resolver = Mock()
    self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds)

    raw_config = default_config("test", parse=False)
    raw_config["federation_custom_ca_list"] = [get_test_ca_cert_file()]
    config = HomeServerConfig()
    self._config = config
    config.parse_config_dict(raw_config)

    tls_options = ClientTLSOptionsFactory(config)
    self.agent = MatrixFederationAgent(
        reactor=self.reactor,
        tls_client_options_factory=tls_options,
        _well_known_tls_policy=TrustingTLSPolicyForHTTPS(),
        _srv_resolver=self.mock_resolver,
        _well_known_cache=self.well_known_cache,
    )
def setUp(self) -> None:
    """Create an OptionsResource with a single leaf child mounted at ``b"res"``."""
    self.reactor = ThreadedMemoryReactorClock()

    class DummyResource(Resource):
        """Leaf resource whose response body is the request path."""

        isLeaf = True

        def render(self, request: SynapseRequest) -> bytes:
            # Type-ignore: mypy thinks request.path is Optional[Any], not bytes.
            return request.path  # type: ignore[return-value]

    root = OptionsResource()
    root.putChild(b"res", DummyResource())
    self.resource = root
def setUp(self):
    """Build a MatrixFederationAgent from the default parsed test config, with a
    mocked SRV resolver and a reactor-clocked .well-known cache."""
    self.reactor = ThreadedMemoryReactorClock()
    self.mock_resolver = Mock()
    self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds)

    parsed_config = default_config("test", parse=True)
    tls_options = ClientTLSOptionsFactory(parsed_config)
    self.agent = MatrixFederationAgent(
        reactor=self.reactor,
        tls_client_options_factory=tls_options,
        _well_known_tls_policy=TrustingTLSPolicyForHTTPS(),
        _srv_resolver=self.mock_resolver,
        _well_known_cache=self.well_known_cache,
    )
def setUp(self):
    """Create a "red" homeserver with stubbed auth and room servlets, plus a
    RestHelper acting as ``self.user_id``."""
    reactor = ThreadedMemoryReactorClock()
    self.clock = reactor
    self.hs_clock = Clock(reactor)
    self.hs = setup_test_homeserver(
        self.addCleanup,
        "red",
        http_client=None,
        clock=self.hs_clock,
        reactor=reactor,
        federation_client=Mock(),
        ratelimiter=NonCallableMock(spec_set=["send_message"]),
    )
    self.ratelimiter = self.hs.get_ratelimiter()
    self.ratelimiter.send_message.return_value = (True, 0)

    self.hs.get_federation_handler = Mock(return_value=Mock())

    # These stubs read self.helper when *called*, so it is fine that
    # self.helper is only assigned at the end of this method.
    def get_user_by_access_token(token=None, allow_guest=False):
        return dict(
            user=UserID.from_string(self.helper.auth_user_id),
            token_id=1,
            is_guest=False,
        )

    def get_user_by_req(request, allow_guest=False, rights="access"):
        return synapse.types.create_requester(
            UserID.from_string(self.helper.auth_user_id), 1, False, None
        )

    auth_overrides = (
        ("get_user_by_req", get_user_by_req),
        ("get_user_by_access_token", get_user_by_access_token),
        ("get_access_token_from_request", Mock(return_value=b"1234")),
    )
    for attr_name, replacement in auth_overrides:
        setattr(self.hs.get_auth(), attr_name, replacement)

    def _insert_client_ip(*args, **kwargs):
        return defer.succeed(None)

    self.hs.get_datastore().insert_client_ip = _insert_client_ip

    room_servlets = synapse.rest.client.v1.room
    self.resource = JsonResource(self.hs)
    room_servlets.register_servlets(self.hs, self.resource)
    room_servlets.register_deprecated_servlets(self.hs, self.resource)

    self.helper = RestHelper(self.hs, self.resource, self.user_id)
class MatrixFederationAgentTests(TestCase):
    """Tests for MatrixFederationAgent's server-name resolution: explicit ports and
    IPs, SRV records, .well-known delegation (including redirects and caching)
    and IDNA host names."""

    def setUp(self):
        self.reactor = ThreadedMemoryReactorClock()

        self.mock_resolver = Mock()

        self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds)

        self.agent = MatrixFederationAgent(
            reactor=self.reactor,
            tls_client_options_factory=ClientTLSOptionsFactory(
                default_config("test", parse=True)
            ),
            _well_known_tls_policy=TrustingTLSPolicyForHTTPS(),
            _srv_resolver=self.mock_resolver,
            _well_known_cache=self.well_known_cache,
        )

    def _make_connection(self, client_factory, expected_sni):
        """Builds a test server, and completes the outgoing client connection

        Args:
            client_factory (IProtocolFactory): the factory for the outgoing
                connection, as captured from ``self.reactor.tcpClients``
            expected_sni (bytes|None): SNI that the client is expected to send

        Returns:
            HTTPChannel: the test server
        """

        # build the test server
        server_tls_protocol = _build_test_server()

        # now, tell the client protocol factory to build the client protocol (it will be a
        # _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
        # HTTP11ClientProtocol) and wire the output of said protocol up to the server via
        # a FakeTransport.
        #
        # Normally this would be done by the TCP socket code in Twisted, but we are
        # stubbing that out here.
        client_protocol = client_factory.buildProtocol(None)
        client_protocol.makeConnection(
            FakeTransport(server_tls_protocol, self.reactor, client_protocol)
        )

        # tell the server tls protocol to send its stuff back to the client, too
        server_tls_protocol.makeConnection(
            FakeTransport(client_protocol, self.reactor, server_tls_protocol)
        )

        # give the reactor a pump to get the TLS juices flowing.
        self.reactor.pump((0.1,))

        # check the SNI
        server_name = server_tls_protocol._tlsConnection.get_servername()
        self.assertEqual(
            server_name,
            expected_sni,
            "Expected SNI %s but got %s" % (expected_sni, server_name),
        )

        # fish the test server back out of the server-side TLS protocol.
        return server_tls_protocol.wrappedProtocol

    @defer.inlineCallbacks
    def _make_get_request(self, uri):
        """
        Sends a simple GET request via the agent, and checks its logcontext management
        """
        with LoggingContext("one") as context:
            fetch_d = self.agent.request(b'GET', uri)

            # Nothing happened yet
            self.assertNoResult(fetch_d)

            # should have reset logcontext to the sentinel
            _check_logcontext(LoggingContext.sentinel)

            try:
                fetch_res = yield fetch_d
                defer.returnValue(fetch_res)
            except Exception as e:
                logger.info("Fetch of %s failed: %s", uri.decode("ascii"), e)
                raise
            finally:
                _check_logcontext(context)

    def _handle_well_known_connection(
        self, client_factory, expected_sni, content, response_headers=None
    ):
        """Handle an outgoing HTTPs connection: wire it up to a server, check that the
        request is for a .well-known, and send the response.

        Args:
            client_factory (IProtocolFactory): outgoing connection
            expected_sni (bytes): SNI that we expect the outgoing connection to send
            content (bytes): content to send back as the .well-known
            response_headers (dict[bytes, bytes]|None): extra headers to set on the
                response; defaults to none.
        Returns:
            HTTPChannel: server impl
        """
        # make the connection for .well-known
        well_known_server = self._make_connection(
            client_factory, expected_sni=expected_sni
        )
        # check the .well-known request and send a response
        self.assertEqual(len(well_known_server.requests), 1)
        request = well_known_server.requests[0]
        self._send_well_known_response(request, content, headers=response_headers)
        return well_known_server

    def _send_well_known_response(self, request, content, headers=None):
        """Check that an incoming request looks like a valid .well-known request, and
        send back the response.

        Args:
            request: the server-side request captured by the test server
            content (bytes): body to send back
            headers (dict[bytes, bytes]|None): extra response headers; defaults to none.
        """
        # A None sentinel avoids a shared mutable default argument.
        if headers is None:
            headers = {}

        self.assertEqual(request.method, b'GET')
        self.assertEqual(request.path, b'/.well-known/matrix/server')
        self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv'])

        # send back a response
        for k, v in headers.items():
            request.setHeader(k, v)
        request.write(content)
        request.finish()

        self.reactor.pump((0.1,))

    def test_get(self):
        """
        happy-path test of a GET request with an explicit port
        """
        self.reactor.lookups["testserv"] = "1.2.3.4"
        test_d = self._make_get_request(b"matrix://testserv:8448/foo/bar")

        # Nothing happened yet
        self.assertNoResult(test_d)

        # Make sure treq is trying to connect
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 8448)

        # make a test server, and wire up the client
        http_server = self._make_connection(client_factory, expected_sni=b"testserv")

        self.assertEqual(len(http_server.requests), 1)
        request = http_server.requests[0]
        self.assertEqual(request.method, b'GET')
        self.assertEqual(request.path, b'/foo/bar')
        self.assertEqual(
            request.requestHeaders.getRawHeaders(b'host'), [b'testserv:8448']
        )
        content = request.content.read()
        self.assertEqual(content, b'')

        # Deferred is still without a result
        self.assertNoResult(test_d)

        # send the headers (an empty bytes write flushes them without a body)
        request.responseHeaders.setRawHeaders(b'Content-Type', [b'application/json'])
        request.write(b'')

        self.reactor.pump((0.1,))

        response = self.successResultOf(test_d)

        # that should give us a Response object
        self.assertEqual(response.code, 200)

        # Send the body
        request.write(b'{ "a": 1 }')
        request.finish()

        self.reactor.pump((0.1,))

        # check it can be read
        body = self.successResultOf(treq.json_content(response))
        self.assertEqual(body, {"a": 1})

    def test_get_ip_address(self):
        """
        Test the behaviour when the server name contains an explicit IP (with no port)
        """

        # there will be a getaddrinfo on the IP
        self.reactor.lookups["1.2.3.4"] = "1.2.3.4"

        test_d = self._make_get_request(b"matrix://1.2.3.4/foo/bar")

        # Nothing happened yet
        self.assertNoResult(test_d)

        # Make sure treq is trying to connect
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 8448)

        # make a test server, and wire up the client
        http_server = self._make_connection(client_factory, expected_sni=None)

        self.assertEqual(len(http_server.requests), 1)
        request = http_server.requests[0]
        self.assertEqual(request.method, b'GET')
        self.assertEqual(request.path, b'/foo/bar')
        self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'1.2.3.4'])

        # finish the request
        request.finish()
        self.reactor.pump((0.1,))
        self.successResultOf(test_d)

    def test_get_ipv6_address(self):
        """
        Test the behaviour when the server name contains an explicit IPv6 address
        (with no port)
        """

        # there will be a getaddrinfo on the IP
        self.reactor.lookups["::1"] = "::1"

        test_d = self._make_get_request(b"matrix://[::1]/foo/bar")

        # Nothing happened yet
        self.assertNoResult(test_d)

        # Make sure treq is trying to connect
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, '::1')
        self.assertEqual(port, 8448)

        # make a test server, and wire up the client
        http_server = self._make_connection(client_factory, expected_sni=None)

        self.assertEqual(len(http_server.requests), 1)
        request = http_server.requests[0]
        self.assertEqual(request.method, b'GET')
        self.assertEqual(request.path, b'/foo/bar')
        self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'[::1]'])

        # finish the request
        request.finish()
        self.reactor.pump((0.1,))
        self.successResultOf(test_d)

    def test_get_ipv6_address_with_port(self):
        """
        Test the behaviour when the server name contains an explicit IPv6 address
        (with explicit port)
        """

        # there will be a getaddrinfo on the IP
        self.reactor.lookups["::1"] = "::1"

        test_d = self._make_get_request(b"matrix://[::1]:80/foo/bar")

        # Nothing happened yet
        self.assertNoResult(test_d)

        # Make sure treq is trying to connect
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, '::1')
        self.assertEqual(port, 80)

        # make a test server, and wire up the client
        http_server = self._make_connection(client_factory, expected_sni=None)

        self.assertEqual(len(http_server.requests), 1)
        request = http_server.requests[0]
        self.assertEqual(request.method, b'GET')
        self.assertEqual(request.path, b'/foo/bar')
        self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'[::1]:80'])

        # finish the request
        request.finish()
        self.reactor.pump((0.1,))
        self.successResultOf(test_d)

    def test_get_no_srv_no_well_known(self):
        """
        Test the behaviour when the server name has no port, no SRV, and no well-known
        """

        self.mock_resolver.resolve_service.side_effect = lambda _: []
        self.reactor.lookups["testserv"] = "1.2.3.4"

        test_d = self._make_get_request(b"matrix://testserv/foo/bar")

        # Nothing happened yet
        self.assertNoResult(test_d)

        # No SRV record lookup yet
        self.mock_resolver.resolve_service.assert_not_called()

        # there should be an attempt to connect on port 443 for the .well-known
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 443)

        # make the .well-known connection attempt fail
        client_factory.clientConnectionFailed(None, Exception("nope"))

        # attemptdelay on the hostnameendpoint is 0.3, so takes that long before the
        # .well-known request fails.
        self.reactor.pump((0.4,))

        # now there should be a SRV lookup
        self.mock_resolver.resolve_service.assert_called_once_with(
            b"_matrix._tcp.testserv"
        )

        # we should fall back to a direct connection
        self.assertEqual(len(clients), 2)
        (host, port, client_factory, _timeout, _bindAddress) = clients[1]
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 8448)

        # make a test server, and wire up the client
        http_server = self._make_connection(client_factory, expected_sni=b'testserv')

        self.assertEqual(len(http_server.requests), 1)
        request = http_server.requests[0]
        self.assertEqual(request.method, b'GET')
        self.assertEqual(request.path, b'/foo/bar')
        self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv'])

        # finish the request
        request.finish()
        self.reactor.pump((0.1,))
        self.successResultOf(test_d)

    def test_get_well_known(self):
        """Test the behaviour when the .well-known delegates elsewhere
        """

        self.mock_resolver.resolve_service.side_effect = lambda _: []
        self.reactor.lookups["testserv"] = "1.2.3.4"
        self.reactor.lookups["target-server"] = "1::f"

        test_d = self._make_get_request(b"matrix://testserv/foo/bar")

        # Nothing happened yet
        self.assertNoResult(test_d)

        # there should be an attempt to connect on port 443 for the .well-known
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 443)

        self._handle_well_known_connection(
            client_factory,
            expected_sni=b"testserv",
            content=b'{ "m.server": "target-server" }',
        )

        # there should be a SRV lookup
        self.mock_resolver.resolve_service.assert_called_once_with(
            b"_matrix._tcp.target-server"
        )

        # now we should get a connection to the target server
        self.assertEqual(len(clients), 2)
        (host, port, client_factory, _timeout, _bindAddress) = clients[1]
        self.assertEqual(host, '1::f')
        self.assertEqual(port, 8448)

        # make a test server, and wire up the client
        http_server = self._make_connection(
            client_factory, expected_sni=b'target-server'
        )

        self.assertEqual(len(http_server.requests), 1)
        request = http_server.requests[0]
        self.assertEqual(request.method, b'GET')
        self.assertEqual(request.path, b'/foo/bar')
        self.assertEqual(
            request.requestHeaders.getRawHeaders(b'host'), [b'target-server']
        )

        # finish the request
        request.finish()
        self.reactor.pump((0.1,))
        self.successResultOf(test_d)

        self.assertEqual(self.well_known_cache[b"testserv"], b"target-server")

        # check the cache expires
        self.reactor.pump((25 * 3600,))
        self.well_known_cache.expire()
        self.assertNotIn(b"testserv", self.well_known_cache)

    def test_get_well_known_redirect(self):
        """Test the behaviour when the server name has no port and no SRV record, but
        the .well-known has a 302 redirect
        """
        self.mock_resolver.resolve_service.side_effect = lambda _: []
        self.reactor.lookups["testserv"] = "1.2.3.4"
        self.reactor.lookups["target-server"] = "1::f"

        test_d = self._make_get_request(b"matrix://testserv/foo/bar")

        # Nothing happened yet
        self.assertNoResult(test_d)

        # there should be an attempt to connect on port 443 for the .well-known
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 443)

        redirect_server = self._make_connection(
            client_factory, expected_sni=b"testserv"
        )

        # send a 302 redirect
        self.assertEqual(len(redirect_server.requests), 1)
        request = redirect_server.requests[0]
        request.redirect(b'https://testserv/even_better_known')
        request.finish()

        self.reactor.pump((0.1,))

        # now there should be another connection
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 443)

        well_known_server = self._make_connection(
            client_factory, expected_sni=b"testserv"
        )

        self.assertEqual(len(well_known_server.requests), 1, "No request after 302")
        request = well_known_server.requests[0]
        self.assertEqual(request.method, b'GET')
        self.assertEqual(request.path, b'/even_better_known')
        request.write(b'{ "m.server": "target-server" }')
        request.finish()

        self.reactor.pump((0.1,))

        # there should be a SRV lookup
        self.mock_resolver.resolve_service.assert_called_once_with(
            b"_matrix._tcp.target-server"
        )

        # now we should get a connection to the target server
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, '1::f')
        self.assertEqual(port, 8448)

        # make a test server, and wire up the client
        http_server = self._make_connection(
            client_factory, expected_sni=b'target-server'
        )

        self.assertEqual(len(http_server.requests), 1)
        request = http_server.requests[0]
        self.assertEqual(request.method, b'GET')
        self.assertEqual(request.path, b'/foo/bar')
        self.assertEqual(
            request.requestHeaders.getRawHeaders(b'host'), [b'target-server']
        )

        # finish the request
        request.finish()
        self.reactor.pump((0.1,))
        self.successResultOf(test_d)

        self.assertEqual(self.well_known_cache[b"testserv"], b"target-server")

        # check the cache expires
        self.reactor.pump((25 * 3600,))
        self.well_known_cache.expire()
        self.assertNotIn(b"testserv", self.well_known_cache)

    def test_get_invalid_well_known(self):
        """
        Test the behaviour when the server name has an *invalid* well-known (and no SRV)
        """

        self.mock_resolver.resolve_service.side_effect = lambda _: []
        self.reactor.lookups["testserv"] = "1.2.3.4"

        test_d = self._make_get_request(b"matrix://testserv/foo/bar")

        # Nothing happened yet
        self.assertNoResult(test_d)

        # No SRV record lookup yet
        self.mock_resolver.resolve_service.assert_not_called()

        # there should be an attempt to connect on port 443 for the .well-known
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 443)

        self._handle_well_known_connection(
            client_factory, expected_sni=b"testserv", content=b'NOT JSON'
        )

        # now there should be a SRV lookup
        self.mock_resolver.resolve_service.assert_called_once_with(
            b"_matrix._tcp.testserv"
        )

        # we should fall back to a direct connection
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 8448)

        # make a test server, and wire up the client
        http_server = self._make_connection(client_factory, expected_sni=b'testserv')

        self.assertEqual(len(http_server.requests), 1)
        request = http_server.requests[0]
        self.assertEqual(request.method, b'GET')
        self.assertEqual(request.path, b'/foo/bar')
        self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv'])

        # finish the request
        request.finish()
        self.reactor.pump((0.1,))
        self.successResultOf(test_d)

    def test_get_hostname_srv(self):
        """
        Test the behaviour when there is a single SRV record
        """
        self.mock_resolver.resolve_service.side_effect = lambda _: [
            Server(host=b"srvtarget", port=8443)
        ]
        self.reactor.lookups["srvtarget"] = "1.2.3.4"

        test_d = self._make_get_request(b"matrix://testserv/foo/bar")

        # Nothing happened yet
        self.assertNoResult(test_d)

        # the request for a .well-known will have failed with a DNS lookup error
        # ("testserv" has no entry in reactor.lookups).
        self.mock_resolver.resolve_service.assert_called_once_with(
            b"_matrix._tcp.testserv"
        )

        # Make sure treq is trying to connect
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 8443)

        # make a test server, and wire up the client
        http_server = self._make_connection(client_factory, expected_sni=b'testserv')

        self.assertEqual(len(http_server.requests), 1)
        request = http_server.requests[0]
        self.assertEqual(request.method, b'GET')
        self.assertEqual(request.path, b'/foo/bar')
        self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv'])

        # finish the request
        request.finish()
        self.reactor.pump((0.1,))
        self.successResultOf(test_d)

    def test_get_well_known_srv(self):
        """Test the behaviour when the .well-known delegates to a server name that
        has a SRV record.
        """
        self.reactor.lookups["testserv"] = "1.2.3.4"
        self.reactor.lookups["srvtarget"] = "5.6.7.8"

        test_d = self._make_get_request(b"matrix://testserv/foo/bar")

        # Nothing happened yet
        self.assertNoResult(test_d)

        # there should be an attempt to connect on port 443 for the .well-known
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 443)

        self.mock_resolver.resolve_service.side_effect = lambda _: [
            Server(host=b"srvtarget", port=8443)
        ]

        self._handle_well_known_connection(
            client_factory,
            expected_sni=b"testserv",
            content=b'{ "m.server": "target-server" }',
        )

        # there should be a SRV lookup
        self.mock_resolver.resolve_service.assert_called_once_with(
            b"_matrix._tcp.target-server"
        )

        # now we should get a connection to the target of the SRV record
        self.assertEqual(len(clients), 2)
        (host, port, client_factory, _timeout, _bindAddress) = clients[1]
        self.assertEqual(host, '5.6.7.8')
        self.assertEqual(port, 8443)

        # make a test server, and wire up the client
        http_server = self._make_connection(
            client_factory, expected_sni=b'target-server'
        )

        self.assertEqual(len(http_server.requests), 1)
        request = http_server.requests[0]
        self.assertEqual(request.method, b'GET')
        self.assertEqual(request.path, b'/foo/bar')
        self.assertEqual(
            request.requestHeaders.getRawHeaders(b'host'), [b'target-server']
        )

        # finish the request
        request.finish()
        self.reactor.pump((0.1,))
        self.successResultOf(test_d)

    def test_idna_servername(self):
        """test the behaviour when the server name has idna chars in"""

        self.mock_resolver.resolve_service.side_effect = lambda _: []

        # the resolver is always called with the IDNA hostname as a native string.
        self.reactor.lookups["xn--bcher-kva.com"] = "1.2.3.4"

        # this is idna for bücher.com
        test_d = self._make_get_request(b"matrix://xn--bcher-kva.com/foo/bar")

        # Nothing happened yet
        self.assertNoResult(test_d)

        # No SRV record lookup yet
        self.mock_resolver.resolve_service.assert_not_called()

        # there should be an attempt to connect on port 443 for the .well-known
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 443)

        # make the .well-known connection attempt fail
        client_factory.clientConnectionFailed(None, Exception("nope"))

        # attemptdelay on the hostnameendpoint is 0.3, so takes that long before the
        # .well-known request fails.
        self.reactor.pump((0.4,))

        # now there should have been a SRV lookup
        self.mock_resolver.resolve_service.assert_called_once_with(
            b"_matrix._tcp.xn--bcher-kva.com"
        )

        # We should fall back to port 8448
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 2)
        (host, port, client_factory, _timeout, _bindAddress) = clients[1]
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 8448)

        # make a test server, and wire up the client
        http_server = self._make_connection(
            client_factory, expected_sni=b'xn--bcher-kva.com'
        )

        self.assertEqual(len(http_server.requests), 1)
        request = http_server.requests[0]
        self.assertEqual(request.method, b'GET')
        self.assertEqual(request.path, b'/foo/bar')
        self.assertEqual(
            request.requestHeaders.getRawHeaders(b'host'), [b'xn--bcher-kva.com']
        )

        # finish the request
        request.finish()
        self.reactor.pump((0.1,))
        self.successResultOf(test_d)

    def test_idna_srv_target(self):
        """test the behaviour when the target of a SRV record has idna chars"""

        self.mock_resolver.resolve_service.side_effect = lambda _: [
            Server(host=b"xn--trget-3qa.com", port=8443)  # târget.com
        ]
        self.reactor.lookups["xn--trget-3qa.com"] = "1.2.3.4"

        test_d = self._make_get_request(b"matrix://xn--bcher-kva.com/foo/bar")

        # Nothing happened yet
        self.assertNoResult(test_d)

        self.mock_resolver.resolve_service.assert_called_once_with(
            b"_matrix._tcp.xn--bcher-kva.com"
        )

        # Make sure treq is trying to connect
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 8443)

        # make a test server, and wire up the client
        http_server = self._make_connection(
            client_factory, expected_sni=b'xn--bcher-kva.com'
        )

        self.assertEqual(len(http_server.requests), 1)
        request = http_server.requests[0]
        self.assertEqual(request.method, b'GET')
        self.assertEqual(request.path, b'/foo/bar')
        self.assertEqual(
            request.requestHeaders.getRawHeaders(b'host'), [b'xn--bcher-kva.com']
        )

        # finish the request
        request.finish()
        self.reactor.pump((0.1,))
        self.successResultOf(test_d)

    @defer.inlineCallbacks
    def do_get_well_known(self, serv):
        """Call the agent's ``_get_well_known`` for ``serv``, logging the outcome.

        Failures are logged and re-raised.
        """
        try:
            result = yield self.agent._get_well_known(serv)
            logger.info("Result from well-known fetch: %s", result)
        except Exception as e:
            logger.warning("Error fetching well-known: %s", e)
            raise
        defer.returnValue(result)

    def test_well_known_cache(self):
        """A .well-known result with ``Cache-Control: max-age=10`` is served from
        cache until 10 seconds pass, after which it is fetched again."""
        self.reactor.lookups["testserv"] = "1.2.3.4"

        fetch_d = self.do_get_well_known(b'testserv')

        # there should be an attempt to connect on port 443 for the .well-known
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 443)

        well_known_server = self._handle_well_known_connection(
            client_factory,
            expected_sni=b"testserv",
            response_headers={b'Cache-Control': b'max-age=10'},
            content=b'{ "m.server": "target-server" }',
        )

        r = self.successResultOf(fetch_d)
        self.assertEqual(r, b'target-server')

        # close the tcp connection
        well_known_server.loseConnection()

        # repeat the request: it should hit the cache
        fetch_d = self.do_get_well_known(b'testserv')
        r = self.successResultOf(fetch_d)
        self.assertEqual(r, b'target-server')

        # expire the cache
        self.reactor.pump((10.0,))

        # now it should connect again
        fetch_d = self.do_get_well_known(b'testserv')

        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 443)

        self._handle_well_known_connection(
            client_factory,
            expected_sni=b"testserv",
            content=b'{ "m.server": "other-server" }',
        )

        r = self.successResultOf(fetch_d)
        self.assertEqual(r, b'other-server')
class MessageAcceptTests(unittest.TestCase):
    """Checks which incoming federation PDUs the federation handler accepts.

    ``setUp`` builds a homeserver driven by a fake reactor, creates a public
    room, and delivers a join for ``@baduser:test.serv`` via
    ``on_receive_pdu``, so every test starts with ``$join:test.serv`` as the
    room's latest event.
    """

    def setUp(self):
        self.http_client = Mock()
        self.reactor = ThreadedMemoryReactorClock()
        self.hs_clock = Clock(self.reactor)
        self.homeserver = setup_test_homeserver(
            self.addCleanup,
            http_client=self.http_client,
            clock=self.hs_clock,
            reactor=self.reactor,
        )

        # Create a public room as a local user, then tick the fake reactor so
        # the room-creation Deferred gets a chance to fire.
        requester = Requester(UserID("us", "test"), None, False, None, None)
        creation_handler = self.homeserver.get_room_creation_handler()
        creation_d = creation_handler.create_room(
            requester,
            creation_handler.PRESETS_DICT["public_chat"],
            ratelimit=False,
        )
        self.reactor.advance(0.1)
        self.room_id = self.successResultOf(creation_d)["room_id"]

        # The injected join must reference the room's current latest event.
        parent_id = self._latest_event_id()

        join_event = FrozenEvent(
            {
                "room_id": self.room_id,
                "sender": "@baduser:test.serv",
                "state_key": "@baduser:test.serv",
                "event_id": "$join:test.serv",
                "depth": 1000,
                "origin_server_ts": 1,
                "type": "m.room.member",
                "origin": "test.servx",
                "content": {"membership": "join"},
                "auth_events": [],
                "prev_state": [(parent_id, {})],
                "prev_events": [(parent_id, {})],
            }
        )

        # Stub out auth and signature/hash checking so the hand-built events
        # used by these tests are accepted as-is.
        self.handler = self.homeserver.get_handlers().federation_handler
        self.handler.do_auth = lambda *args, **kwargs: succeed(True)
        self.client = self.homeserver.get_federation_client()
        self.client._check_sigs_and_hash_and_fetch = (
            lambda dest, pdus, **kwargs: succeed(pdus)
        )

        # Deliver the join; a result of None is the success case, not an error.
        join_d = self.handler.on_receive_pdu(
            "test.serv", join_event, sent_to_us_directly=True
        )
        self.reactor.advance(1)
        self.assertEqual(self.successResultOf(join_d), None)

        # The join should now be the room's latest event.
        self.assertEqual(self._latest_event_id(), "$join:test.serv")

    def _latest_event_id(self):
        """Return the first entry of the room's latest event IDs.

        ``maybeDeferred`` wraps the datastore call so that a plain return value
        and a Deferred are handled the same way. ``successResultOf`` then
        requires the result to be available immediately.
        """
        latest_ids = self.successResultOf(
            maybeDeferred(
                self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
            )
        )
        return latest_ids[0]

    def test_cant_hide_direct_ancestors(self):
        """
        A remote server that sends an event must be able to supply every
        prev_event that event directly references. If it cannot, the event is
        rejected with a 403 and the room's latest event is left unchanged.
        """

        def fake_post_json(destination, path, data, headers=None, timeout=0):
            # Answer get_missing_events queries with no events at all, so the
            # handler can never learn about the fabricated prev_event.
            if not path.startswith("/_matrix/federation/v1/get_missing_events/"):
                return None
            return {"events": []}

        self.http_client.post_json = fake_post_json

        parent_id = self._latest_event_id()

        # Build an event whose prev_events include one ("two:test.serv") that
        # was never delivered to us.
        lying_event = FrozenEvent(
            {
                "room_id": self.room_id,
                "sender": "@baduser:test.serv",
                "event_id": "one:test.serv",
                "depth": 1000,
                "origin_server_ts": 1,
                "type": "m.room.message",
                "origin": "test.serv",
                "content": {"body": "hewwo?"},
                "auth_events": [],
                "prev_events": [("two:test.serv", {}), (parent_id, {})],
            }
        )

        with LoggingContext(request="lying_event"):
            receive_d = self.handler.on_receive_pdu(
                "test.serv", lying_event, sent_to_us_directly=True
            )

            # Advance the fake reactor so the pending database fetches complete.
            self.reactor.advance(1)

        # The handler must refuse the event...
        failure = self.failureResultOf(receive_d)
        expected_message = (
            "ERROR 403: Your server isn't divulging details about prev_events "
            "referenced in this event."
        )
        self.assertEqual(failure.value.args[0], expected_message)

        # ...and the join must still be the room's latest event.
        self.assertEqual(self._latest_event_id(), "$join:test.serv")
class MatrixFederationAgentTests(TestCase):
    """Tests for MatrixFederationAgent, driven by a fake reactor.

    Outgoing TCP connections are recorded in ``self.reactor.tcpClients``.
    ``_make_connection`` wires them by hand to an in-memory TLS test server,
    so each test can inspect the requests the agent makes and answer them.
    """

    def setUp(self):
        self.reactor = ThreadedMemoryReactorClock()

        self.mock_resolver = Mock()

        self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds)

        config_dict = default_config("test", parse=False)
        config_dict["federation_custom_ca_list"] = [get_test_ca_cert_file()]
        # config_dict["trusted_key_servers"] = []

        self._config = config = HomeServerConfig()
        config.parse_config_dict(config_dict)

        # _well_known_tls_policy trusts any certificate, so the .well-known
        # fetch accepts the test server's cert (which is signed by a test CA).
        self.agent = MatrixFederationAgent(
            reactor=self.reactor,
            tls_client_options_factory=ClientTLSOptionsFactory(config),
            _well_known_tls_policy=TrustingTLSPolicyForHTTPS(),
            _srv_resolver=self.mock_resolver,
            _well_known_cache=self.well_known_cache,
        )

    def _make_connection(self, client_factory, expected_sni):
        """Builds a test server, and completes the outgoing client connection

        Args:
            client_factory (IProtocolFactory): the factory the agent is using
                for the outbound connection; it is used to build the client
                protocol.
            expected_sni (bytes|None): the SNI the client should send. This is
                compared against what the server-side TLS connection received.

        Returns:
            HTTPChannel: the test server
        """

        # build the test server
        server_tls_protocol = _build_test_server(get_connection_factory())

        # now, tell the client protocol factory to build the client protocol (it will be a
        # _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
        # HTTP11ClientProtocol) and wire the output of said protocol up to the server via
        # a FakeTransport.
        #
        # Normally this would be done by the TCP socket code in Twisted, but we are
        # stubbing that out here.
        client_protocol = client_factory.buildProtocol(None)
        client_protocol.makeConnection(
            FakeTransport(server_tls_protocol, self.reactor, client_protocol)
        )

        # tell the server tls protocol to send its stuff back to the client, too
        server_tls_protocol.makeConnection(
            FakeTransport(client_protocol, self.reactor, server_tls_protocol)
        )

        # give the reactor a pump to get the TLS juices flowing.
        self.reactor.pump((0.1,))

        # check the SNI
        server_name = server_tls_protocol._tlsConnection.get_servername()
        self.assertEqual(
            server_name,
            expected_sni,
            "Expected SNI %s but got %s" % (expected_sni, server_name),
        )

        # fish the test server back out of the server-side TLS protocol.
        return server_tls_protocol.wrappedProtocol

    @defer.inlineCallbacks
    def _make_get_request(self, uri):
        """
        Sends a simple GET request via the agent, and checks its logcontext management

        Args:
            uri (bytes): the URI to request.

        Returns:
            Deferred: fires with the agent's response. The logcontext is
            checked to be reset to the sentinel while the request is pending,
            and restored to the calling context afterwards.
        """
        with LoggingContext("one") as context:
            fetch_d = self.agent.request(b'GET', uri)

            # Nothing happened yet
            self.assertNoResult(fetch_d)

            # should have reset logcontext to the sentinel
            _check_logcontext(LoggingContext.sentinel)

            try:
                fetch_res = yield fetch_d
                defer.returnValue(fetch_res)
            except Exception as e:
                logger.info("Fetch of %s failed: %s", uri.decode("ascii"), e)
                raise
            finally:
                _check_logcontext(context)

    def _handle_well_known_connection(
        self, client_factory, expected_sni, content, response_headers=None
    ):
        """Handle an outgoing HTTPs connection: wire it up to a server, check that the
        request is for a .well-known, and send the response.

        Args:
            client_factory (IProtocolFactory): outgoing connection
            expected_sni (bytes): SNI that we expect the outgoing connection to send
            content (bytes): content to send back as the .well-known
            response_headers (dict|None): extra headers to set on the response.
                None (the default) means no extra headers.
        Returns:
            HTTPChannel: server impl
        """
        # make the connection for .well-known
        well_known_server = self._make_connection(
            client_factory, expected_sni=expected_sni
        )
        # check the .well-known request and send a response
        self.assertEqual(len(well_known_server.requests), 1)
        request = well_known_server.requests[0]
        self._send_well_known_response(request, content, headers=response_headers)
        return well_known_server

    def _send_well_known_response(self, request, content, headers=None):
        """Check that an incoming request looks like a valid .well-known request, and
        send back the response.

        The request must be a GET for /.well-known/matrix/server with a Host
        header of "testserv".

        Args:
            request: the server-side request to check and answer.
            content (bytes): the response body.
            headers (dict|None): extra response headers (name -> value) to set
                before writing the body. None (the default) means none.
        """
        # A None default avoids sharing one mutable dict between calls.
        if headers is None:
            headers = {}

        self.assertEqual(request.method, b'GET')
        self.assertEqual(request.path, b'/.well-known/matrix/server')
        self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv'])

        # send back a response
        for k, v in headers.items():
            request.setHeader(k, v)
        request.write(content)
        request.finish()

        self.reactor.pump((0.1,))

    def test_get(self):
        """
        happy-path test of a GET request with an explicit port
        """
        self.reactor.lookups["testserv"] = "1.2.3.4"
        test_d = self._make_get_request(b"matrix://testserv:8448/foo/bar")

        # Nothing happened yet
        self.assertNoResult(test_d)

        # Make sure treq is trying to connect
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 8448)

        # make a test server, and wire up the client
        http_server = self._make_connection(client_factory, expected_sni=b"testserv")

        self.assertEqual(len(http_server.requests), 1)
        request = http_server.requests[0]
        self.assertEqual(request.method, b'GET')
        self.assertEqual(request.path, b'/foo/bar')
        self.assertEqual(
            request.requestHeaders.getRawHeaders(b'host'), [b'testserv:8448']
        )
        content = request.content.read()
        self.assertEqual(content, b'')

        # Deferred is still without a result
        self.assertNoResult(test_d)

        # send the headers. Request.write takes bytes, so write an empty bytes
        # object to flush the headers without sending any body yet.
        request.responseHeaders.setRawHeaders(b'Content-Type', [b'application/json'])
        request.write(b'')

        self.reactor.pump((0.1,))

        response = self.successResultOf(test_d)

        # that should give us a Response object
        self.assertEqual(response.code, 200)

        # Send the body
        request.write('{ "a": 1 }'.encode('ascii'))
        request.finish()

        self.reactor.pump((0.1,))

        # check it can be read (named json_body to avoid shadowing the json module)
        json_body = self.successResultOf(treq.json_content(response))
        self.assertEqual(json_body, {"a": 1})

    def test_get_ip_address(self):
        """
        Test the behaviour when the server name contains an explicit IP (with no port)
        """
        # there will be a getaddrinfo on the IP
        self.reactor.lookups["1.2.3.4"] = "1.2.3.4"

        test_d = self._make_get_request(b"matrix://1.2.3.4/foo/bar")

        # Nothing happened yet
        self.assertNoResult(test_d)

        # Make sure treq is trying to connect
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 8448)

        # make a test server, and wire up the client
        http_server = self._make_connection(client_factory, expected_sni=None)

        self.assertEqual(len(http_server.requests), 1)
        request = http_server.requests[0]
        self.assertEqual(request.method, b'GET')
        self.assertEqual(request.path, b'/foo/bar')
        self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'1.2.3.4'])

        # finish the request
        request.finish()
        self.reactor.pump((0.1,))
        self.successResultOf(test_d)

    def test_get_ipv6_address(self):
        """
        Test the behaviour when the server name contains an explicit IPv6 address
        (with no port)
        """

        # there will be a getaddrinfo on the IP
        self.reactor.lookups["::1"] = "::1"

        test_d = self._make_get_request(b"matrix://[::1]/foo/bar")

        # Nothing happened yet
        self.assertNoResult(test_d)

        # Make sure treq is trying to connect
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, '::1')
        self.assertEqual(port, 8448)

        # make a test server, and wire up the client
        http_server = self._make_connection(client_factory, expected_sni=None)

        self.assertEqual(len(http_server.requests), 1)
        request = http_server.requests[0]
        self.assertEqual(request.method, b'GET')
        self.assertEqual(request.path, b'/foo/bar')
        self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'[::1]'])

        # finish the request
        request.finish()
        self.reactor.pump((0.1,))
        self.successResultOf(test_d)

    def test_get_ipv6_address_with_port(self):
        """
        Test the behaviour when the server name contains an explicit IPv6 address
        (with explicit port)
        """

        # there will be a getaddrinfo on the IP
        self.reactor.lookups["::1"] = "::1"

        test_d = self._make_get_request(b"matrix://[::1]:80/foo/bar")

        # Nothing happened yet
        self.assertNoResult(test_d)

        # Make sure treq is trying to connect
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, '::1')
        self.assertEqual(port, 80)

        # make a test server, and wire up the client
        http_server = self._make_connection(client_factory, expected_sni=None)

        self.assertEqual(len(http_server.requests), 1)
        request = http_server.requests[0]
        self.assertEqual(request.method, b'GET')
        self.assertEqual(request.path, b'/foo/bar')
        self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'[::1]:80'])

        # finish the request
        request.finish()
        self.reactor.pump((0.1,))
        self.successResultOf(test_d)

    def test_get_hostname_bad_cert(self):
        """
        Test the behaviour when the certificate on the server doesn't match the hostname
        """
        self.mock_resolver.resolve_service.side_effect = lambda _: []
        self.reactor.lookups["testserv1"] = "1.2.3.4"

        test_d = self._make_get_request(b"matrix://testserv1/foo/bar")

        # Nothing happened yet
        self.assertNoResult(test_d)

        # No SRV record lookup yet
        self.mock_resolver.resolve_service.assert_not_called()

        # there should be an attempt to connect on port 443 for the .well-known
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 443)

        # fonx the connection
        client_factory.clientConnectionFailed(None, Exception("nope"))

        # attemptdelay on the hostnameendpoint is 0.3, so takes that long before the
        # .well-known request fails.
        self.reactor.pump((0.4,))

        # now there should be a SRV lookup
        self.mock_resolver.resolve_service.assert_called_once_with(
            b"_matrix._tcp.testserv1"
        )

        # we should fall back to a direct connection
        self.assertEqual(len(clients), 2)
        (host, port, client_factory, _timeout, _bindAddress) = clients[1]
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 8448)

        # make a test server, and wire up the client
        http_server = self._make_connection(client_factory, expected_sni=b'testserv1')

        # there should be no requests
        self.assertEqual(len(http_server.requests), 0)

        # ... and the request should have failed
        e = self.failureResultOf(test_d, ResponseNeverReceived)
        failure_reason = e.value.reasons[0]
        self.assertIsInstance(failure_reason.value, VerificationError)

    def test_get_ip_address_bad_cert(self):
        """
        Test the behaviour when the server name contains an explicit IP, but
        the server cert doesn't cover it
        """
        # there will be a getaddrinfo on the IP
        self.reactor.lookups["1.2.3.5"] = "1.2.3.5"

        test_d = self._make_get_request(b"matrix://1.2.3.5/foo/bar")

        # Nothing happened yet
        self.assertNoResult(test_d)

        # Make sure treq is trying to connect
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, '1.2.3.5')
        self.assertEqual(port, 8448)

        # make a test server, and wire up the client
        http_server = self._make_connection(client_factory, expected_sni=None)

        # there should be no requests
        self.assertEqual(len(http_server.requests), 0)

        # ... and the request should have failed
        e = self.failureResultOf(test_d, ResponseNeverReceived)
        failure_reason = e.value.reasons[0]
        self.assertIsInstance(failure_reason.value, VerificationError)

    def test_get_no_srv_no_well_known(self):
        """
        Test the behaviour when the server name has no port, no SRV, and no well-known
        """

        self.mock_resolver.resolve_service.side_effect = lambda _: []
        self.reactor.lookups["testserv"] = "1.2.3.4"

        test_d = self._make_get_request(b"matrix://testserv/foo/bar")

        # Nothing happened yet
        self.assertNoResult(test_d)

        # No SRV record lookup yet
        self.mock_resolver.resolve_service.assert_not_called()

        # there should be an attempt to connect on port 443 for the .well-known
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 443)

        # fonx the connection
        client_factory.clientConnectionFailed(None, Exception("nope"))

        # attemptdelay on the hostnameendpoint is 0.3, so takes that long before the
        # .well-known request fails.
        self.reactor.pump((0.4,))

        # now there should be a SRV lookup
        self.mock_resolver.resolve_service.assert_called_once_with(
            b"_matrix._tcp.testserv"
        )

        # we should fall back to a direct connection
        self.assertEqual(len(clients), 2)
        (host, port, client_factory, _timeout, _bindAddress) = clients[1]
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 8448)

        # make a test server, and wire up the client
        http_server = self._make_connection(client_factory, expected_sni=b'testserv')

        self.assertEqual(len(http_server.requests), 1)
        request = http_server.requests[0]
        self.assertEqual(request.method, b'GET')
        self.assertEqual(request.path, b'/foo/bar')
        self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv'])

        # finish the request
        request.finish()
        self.reactor.pump((0.1,))
        self.successResultOf(test_d)

    def test_get_well_known(self):
        """Test the behaviour when the .well-known delegates elsewhere
        """

        self.mock_resolver.resolve_service.side_effect = lambda _: []
        self.reactor.lookups["testserv"] = "1.2.3.4"
        self.reactor.lookups["target-server"] = "1::f"

        test_d = self._make_get_request(b"matrix://testserv/foo/bar")

        # Nothing happened yet
        self.assertNoResult(test_d)

        # there should be an attempt to connect on port 443 for the .well-known
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 443)

        self._handle_well_known_connection(
            client_factory,
            expected_sni=b"testserv",
            content=b'{ "m.server": "target-server" }',
        )

        # there should be a SRV lookup
        self.mock_resolver.resolve_service.assert_called_once_with(
            b"_matrix._tcp.target-server"
        )

        # now we should get a connection to the target server
        self.assertEqual(len(clients), 2)
        (host, port, client_factory, _timeout, _bindAddress) = clients[1]
        self.assertEqual(host, '1::f')
        self.assertEqual(port, 8448)

        # make a test server, and wire up the client
        http_server = self._make_connection(
            client_factory, expected_sni=b'target-server'
        )

        self.assertEqual(len(http_server.requests), 1)
        request = http_server.requests[0]
        self.assertEqual(request.method, b'GET')
        self.assertEqual(request.path, b'/foo/bar')
        self.assertEqual(
            request.requestHeaders.getRawHeaders(b'host'), [b'target-server']
        )

        # finish the request
        request.finish()
        self.reactor.pump((0.1,))
        self.successResultOf(test_d)

        self.assertEqual(self.well_known_cache[b"testserv"], b"target-server")

        # check the cache expires
        self.reactor.pump((25 * 3600,))
        self.well_known_cache.expire()
        self.assertNotIn(b"testserv", self.well_known_cache)

    def test_get_well_known_redirect(self):
        """Test the behaviour when the server name has no port and no SRV record, but
        the .well-known has a 302 redirect
        """
        self.mock_resolver.resolve_service.side_effect = lambda _: []
        self.reactor.lookups["testserv"] = "1.2.3.4"
        self.reactor.lookups["target-server"] = "1::f"

        test_d = self._make_get_request(b"matrix://testserv/foo/bar")

        # Nothing happened yet
        self.assertNoResult(test_d)

        # there should be an attempt to connect on port 443 for the .well-known
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 443)

        redirect_server = self._make_connection(
            client_factory, expected_sni=b"testserv"
        )

        # send a 302 redirect
        self.assertEqual(len(redirect_server.requests), 1)
        request = redirect_server.requests[0]
        request.redirect(b'https://testserv/even_better_known')
        request.finish()

        self.reactor.pump((0.1,))

        # now there should be another connection
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 443)

        well_known_server = self._make_connection(
            client_factory, expected_sni=b"testserv"
        )

        self.assertEqual(len(well_known_server.requests), 1, "No request after 302")
        request = well_known_server.requests[0]
        self.assertEqual(request.method, b'GET')
        self.assertEqual(request.path, b'/even_better_known')
        request.write(b'{ "m.server": "target-server" }')
        request.finish()

        self.reactor.pump((0.1,))

        # there should be a SRV lookup
        self.mock_resolver.resolve_service.assert_called_once_with(
            b"_matrix._tcp.target-server"
        )

        # now we should get a connection to the target server
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, '1::f')
        self.assertEqual(port, 8448)

        # make a test server, and wire up the client
        http_server = self._make_connection(
            client_factory, expected_sni=b'target-server'
        )

        self.assertEqual(len(http_server.requests), 1)
        request = http_server.requests[0]
        self.assertEqual(request.method, b'GET')
        self.assertEqual(request.path, b'/foo/bar')
        self.assertEqual(
            request.requestHeaders.getRawHeaders(b'host'), [b'target-server']
        )

        # finish the request
        request.finish()
        self.reactor.pump((0.1,))
        self.successResultOf(test_d)

        self.assertEqual(self.well_known_cache[b"testserv"], b"target-server")

        # check the cache expires
        self.reactor.pump((25 * 3600,))
        self.well_known_cache.expire()
        self.assertNotIn(b"testserv", self.well_known_cache)

    def test_get_invalid_well_known(self):
        """
        Test the behaviour when the server name has an *invalid* well-known (and no SRV)
        """

        self.mock_resolver.resolve_service.side_effect = lambda _: []
        self.reactor.lookups["testserv"] = "1.2.3.4"

        test_d = self._make_get_request(b"matrix://testserv/foo/bar")

        # Nothing happened yet
        self.assertNoResult(test_d)

        # No SRV record lookup yet
        self.mock_resolver.resolve_service.assert_not_called()

        # there should be an attempt to connect on port 443 for the .well-known
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 443)

        self._handle_well_known_connection(
            client_factory, expected_sni=b"testserv", content=b'NOT JSON'
        )

        # now there should be a SRV lookup
        self.mock_resolver.resolve_service.assert_called_once_with(
            b"_matrix._tcp.testserv"
        )

        # we should fall back to a direct connection
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 8448)

        # make a test server, and wire up the client
        http_server = self._make_connection(client_factory, expected_sni=b'testserv')

        self.assertEqual(len(http_server.requests), 1)
        request = http_server.requests[0]
        self.assertEqual(request.method, b'GET')
        self.assertEqual(request.path, b'/foo/bar')
        self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv'])

        # finish the request
        request.finish()
        self.reactor.pump((0.1,))
        self.successResultOf(test_d)

    def test_get_well_known_unsigned_cert(self):
        """Test the behaviour when the .well-known server presents a cert
        not signed by a CA
        """

        # we use the same test server as the other tests, but use an agent
        # with _well_known_tls_policy left to the default, which will not
        # trust it (since the presented cert is signed by a test CA)

        self.mock_resolver.resolve_service.side_effect = lambda _: []
        self.reactor.lookups["testserv"] = "1.2.3.4"

        agent = MatrixFederationAgent(
            reactor=self.reactor,
            tls_client_options_factory=ClientTLSOptionsFactory(self._config),
            _srv_resolver=self.mock_resolver,
            _well_known_cache=self.well_known_cache,
        )

        test_d = agent.request(b"GET", b"matrix://testserv/foo/bar")

        # Nothing happened yet
        self.assertNoResult(test_d)

        # there should be an attempt to connect on port 443 for the .well-known
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 443)

        http_proto = self._make_connection(client_factory, expected_sni=b"testserv")

        # there should be no requests
        self.assertEqual(len(http_proto.requests), 0)

        # and there should be a SRV lookup instead
        self.mock_resolver.resolve_service.assert_called_once_with(
            b"_matrix._tcp.testserv"
        )

    def test_get_hostname_srv(self):
        """
        Test the behaviour when there is a single SRV record
        """
        self.mock_resolver.resolve_service.side_effect = lambda _: [
            Server(host=b"srvtarget", port=8443)
        ]
        self.reactor.lookups["srvtarget"] = "1.2.3.4"

        test_d = self._make_get_request(b"matrix://testserv/foo/bar")

        # Nothing happened yet
        self.assertNoResult(test_d)

        # the request for a .well-known will have failed with a DNS lookup error.
        self.mock_resolver.resolve_service.assert_called_once_with(
            b"_matrix._tcp.testserv"
        )

        # Make sure treq is trying to connect
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 8443)

        # make a test server, and wire up the client
        http_server = self._make_connection(client_factory, expected_sni=b'testserv')

        self.assertEqual(len(http_server.requests), 1)
        request = http_server.requests[0]
        self.assertEqual(request.method, b'GET')
        self.assertEqual(request.path, b'/foo/bar')
        self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv'])

        # finish the request
        request.finish()
        self.reactor.pump((0.1,))
        self.successResultOf(test_d)

    def test_get_well_known_srv(self):
        """Test the behaviour when the .well-known redirects to a place where there
        is a SRV.
        """
        self.reactor.lookups["testserv"] = "1.2.3.4"
        self.reactor.lookups["srvtarget"] = "5.6.7.8"

        test_d = self._make_get_request(b"matrix://testserv/foo/bar")

        # Nothing happened yet
        self.assertNoResult(test_d)

        # there should be an attempt to connect on port 443 for the .well-known
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 443)

        self.mock_resolver.resolve_service.side_effect = lambda _: [
            Server(host=b"srvtarget", port=8443)
        ]

        self._handle_well_known_connection(
            client_factory,
            expected_sni=b"testserv",
            content=b'{ "m.server": "target-server" }',
        )

        # there should be a SRV lookup
        self.mock_resolver.resolve_service.assert_called_once_with(
            b"_matrix._tcp.target-server"
        )

        # now we should get a connection to the target of the SRV record
        self.assertEqual(len(clients), 2)
        (host, port, client_factory, _timeout, _bindAddress) = clients[1]
        self.assertEqual(host, '5.6.7.8')
        self.assertEqual(port, 8443)

        # make a test server, and wire up the client
        http_server = self._make_connection(
            client_factory, expected_sni=b'target-server'
        )

        self.assertEqual(len(http_server.requests), 1)
        request = http_server.requests[0]
        self.assertEqual(request.method, b'GET')
        self.assertEqual(request.path, b'/foo/bar')
        self.assertEqual(
            request.requestHeaders.getRawHeaders(b'host'), [b'target-server']
        )

        # finish the request
        request.finish()
        self.reactor.pump((0.1,))
        self.successResultOf(test_d)

    def test_idna_servername(self):
        """test the behaviour when the server name has idna chars in"""

        self.mock_resolver.resolve_service.side_effect = lambda _: []

        # the resolver is always called with the IDNA hostname as a native string.
        self.reactor.lookups["xn--bcher-kva.com"] = "1.2.3.4"

        # this is idna for bücher.com
        test_d = self._make_get_request(b"matrix://xn--bcher-kva.com/foo/bar")

        # Nothing happened yet
        self.assertNoResult(test_d)

        # No SRV record lookup yet
        self.mock_resolver.resolve_service.assert_not_called()

        # there should be an attempt to connect on port 443 for the .well-known
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 443)

        # fonx the connection
        client_factory.clientConnectionFailed(None, Exception("nope"))

        # attemptdelay on the hostnameendpoint is 0.3, so takes that long before the
        # .well-known request fails.
        self.reactor.pump((0.4,))

        # now there should have been a SRV lookup
        self.mock_resolver.resolve_service.assert_called_once_with(
            b"_matrix._tcp.xn--bcher-kva.com"
        )

        # We should fall back to port 8448
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 2)
        (host, port, client_factory, _timeout, _bindAddress) = clients[1]
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 8448)

        # make a test server, and wire up the client
        http_server = self._make_connection(
            client_factory, expected_sni=b'xn--bcher-kva.com'
        )

        self.assertEqual(len(http_server.requests), 1)
        request = http_server.requests[0]
        self.assertEqual(request.method, b'GET')
        self.assertEqual(request.path, b'/foo/bar')
        self.assertEqual(
            request.requestHeaders.getRawHeaders(b'host'), [b'xn--bcher-kva.com']
        )

        # finish the request
        request.finish()
        self.reactor.pump((0.1,))
        self.successResultOf(test_d)

    def test_idna_srv_target(self):
        """test the behaviour when the target of a SRV record has idna chars"""

        self.mock_resolver.resolve_service.side_effect = lambda _: [
            Server(host=b"xn--trget-3qa.com", port=8443)  # târget.com
        ]
        self.reactor.lookups["xn--trget-3qa.com"] = "1.2.3.4"

        test_d = self._make_get_request(b"matrix://xn--bcher-kva.com/foo/bar")

        # Nothing happened yet
        self.assertNoResult(test_d)

        self.mock_resolver.resolve_service.assert_called_once_with(
            b"_matrix._tcp.xn--bcher-kva.com"
        )

        # Make sure treq is trying to connect
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 8443)

        # make a test server, and wire up the client
        http_server = self._make_connection(
            client_factory, expected_sni=b'xn--bcher-kva.com'
        )

        self.assertEqual(len(http_server.requests), 1)
        request = http_server.requests[0]
        self.assertEqual(request.method, b'GET')
        self.assertEqual(request.path, b'/foo/bar')
        self.assertEqual(
            request.requestHeaders.getRawHeaders(b'host'), [b'xn--bcher-kva.com']
        )

        # finish the request
        request.finish()
        self.reactor.pump((0.1,))
        self.successResultOf(test_d)

    @defer.inlineCallbacks
    def do_get_well_known(self, serv):
        """Fetch the .well-known for `serv` via the agent, logging the outcome.

        Args:
            serv (bytes): the server name whose .well-known should be fetched.

        Returns:
            Deferred: fires with whatever `self.agent._get_well_known` returns.
            Errors are logged and re-raised.
        """
        try:
            result = yield self.agent._get_well_known(serv)
            logger.info("Result from well-known fetch: %s", result)
        except Exception as e:
            logger.warning("Error fetching well-known: %s", e)
            raise
        defer.returnValue(result)

    def test_well_known_cache(self):
        self.reactor.lookups["testserv"] = "1.2.3.4"

        fetch_d = self.do_get_well_known(b'testserv')

        # there should be an attempt to connect on port 443 for the .well-known
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 443)

        well_known_server = self._handle_well_known_connection(
            client_factory,
            expected_sni=b"testserv",
            response_headers={b'Cache-Control': b'max-age=10'},
            content=b'{ "m.server": "target-server" }',
        )

        r = self.successResultOf(fetch_d)
        self.assertEqual(r, b'target-server')

        # close the tcp connection
        well_known_server.loseConnection()

        # repeat the request: it should hit the cache
        fetch_d = self.do_get_well_known(b'testserv')
        r = self.successResultOf(fetch_d)
        self.assertEqual(r, b'target-server')

        # expire the cache
        self.reactor.pump((10.0,))

        # now it should connect again
        fetch_d = self.do_get_well_known(b'testserv')

        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 443)

        self._handle_well_known_connection(
            client_factory,
            expected_sni=b"testserv",
            content=b'{ "m.server": "other-server" }',
        )

        r = self.successResultOf(fetch_d)
        self.assertEqual(r, b'other-server')
class MatrixFederationAgentTests(TestCase):
    """Tests for ``ProxyAgent``.

    Covers direct connections (including ``no_proxy`` exemptions) and requests
    routed through a proxy configured via the ``http_proxy`` / ``https_proxy``
    environment variables: plain HTTP requests sent to the proxy, and HTTPS
    requests tunnelled through it with CONNECT.

    NOTE(review): despite its name, this class exercises ProxyAgent rather than
    MatrixFederationAgent; the name is left unchanged so test selection by name
    keeps working.
    """

    def setUp(self):
        self.reactor = ThreadedMemoryReactorClock()

    def _make_connection(
        self, client_factory, server_factory, ssl=False, expected_sni=None
    ):
        """Builds a test server, and completes the outgoing client connection

        Args:
            client_factory (interfaces.IProtocolFactory): the factory that the
                application is trying to use to make the outbound connection. We
                will invoke it to build the client Protocol

            server_factory (interfaces.IProtocolFactory): a factory to build the
                server-side protocol

            ssl (bool): If true, we will expect an ssl connection and wrap
                server_factory with a TLSMemoryBIOFactory

            expected_sni (bytes|None): the expected SNI value. Can only be
                checked when ``ssl`` is true.

        Returns:
            IProtocol: the server Protocol returned by server_factory

        Raises:
            ValueError: if ``expected_sni`` is given without ``ssl``.
        """
        if expected_sni is not None and not ssl:
            # the SNI is read from the server-side TLS connection, which only
            # exists when ssl is true
            raise ValueError("expected_sni can only be checked when ssl=True")

        if ssl:
            server_factory = _wrap_server_factory_for_tls(server_factory)

        server_protocol = server_factory.buildProtocol(None)

        # now, tell the client protocol factory to build the client protocol,
        # and wire the output of said protocol up to the server via
        # a FakeTransport.
        #
        # Normally this would be done by the TCP socket code in Twisted, but we are
        # stubbing that out here.
        client_protocol = client_factory.buildProtocol(None)
        client_protocol.makeConnection(
            FakeTransport(server_protocol, self.reactor, client_protocol)
        )

        # tell the server protocol to send its stuff back to the client, too
        server_protocol.makeConnection(
            FakeTransport(client_protocol, self.reactor, server_protocol)
        )

        if ssl:
            http_protocol = server_protocol.wrappedProtocol
            tls_connection = server_protocol._tlsConnection
        else:
            http_protocol = server_protocol
            tls_connection = None

        # give the reactor a pump to get the TLS juices flowing (if needed)
        self.reactor.advance(0)

        if expected_sni is not None:
            server_name = tls_connection.get_servername()
            self.assertEqual(
                server_name,
                expected_sni,
                "Expected SNI %s but got %s" % (expected_sni, server_name),
            )

        return http_protocol

    def _test_request_direct_connection(self, agent, scheme, hostname, path):
        """Runs a test case for a direct connection not going through a proxy.

        Args:
            agent (ProxyAgent): the proxy agent being tested

            scheme (bytes): expected to be either "http" or "https"

            hostname (bytes): the hostname to connect to in the test

            path (bytes): the path to connect to in the test
        """
        is_https = scheme == b"https"

        self.reactor.lookups[hostname.decode()] = "1.2.3.4"
        d = agent.request(b"GET", scheme + b"://" + hostname + b"/" + path)

        # there should be a pending TCP connection
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, "1.2.3.4")
        self.assertEqual(port, 443 if is_https else 80)

        # make a test server, and wire up the client
        http_server = self._make_connection(
            client_factory,
            _get_test_protocol_factory(),
            ssl=is_https,
            expected_sni=hostname if is_https else None,
        )

        # the FakeTransport is async, so we need to pump the reactor
        self.reactor.advance(0)

        # now there should be a pending request
        self.assertEqual(len(http_server.requests), 1)

        request = http_server.requests[0]
        self.assertEqual(request.method, b"GET")
        self.assertEqual(request.path, b"/" + path)
        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [hostname])
        request.write(b"result")
        request.finish()

        self.reactor.advance(0)

        resp = self.successResultOf(d)
        body = self.successResultOf(treq.content(resp))
        self.assertEqual(body, b"result")

    def _test_request_via_http_proxy(self, agent):
        """Runs a test case for a plain-HTTP request to http://test.com which is
        expected to be sent to the HTTP proxy at proxy.com:8888 (resolved to
        1.2.3.5).

        The caller must configure ``http_proxy=proxy.com:8888`` in the
        environment.

        Args:
            agent (ProxyAgent): the proxy agent being tested
        """
        self.reactor.lookups["proxy.com"] = "1.2.3.5"
        d = agent.request(b"GET", b"http://test.com")

        # there should be a pending TCP connection, to the proxy
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, "1.2.3.5")
        self.assertEqual(port, 8888)

        # make a test server, and wire up the client
        http_server = self._make_connection(
            client_factory, _get_test_protocol_factory()
        )

        # the FakeTransport is async, so we need to pump the reactor
        self.reactor.advance(0)

        # now there should be a pending request
        self.assertEqual(len(http_server.requests), 1)

        request = http_server.requests[0]
        self.assertEqual(request.method, b"GET")
        # the request to the proxy carries the absolute URI as its path
        self.assertEqual(request.path, b"http://test.com")
        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
        request.write(b"result")
        request.finish()

        self.reactor.advance(0)

        resp = self.successResultOf(d)
        body = self.successResultOf(treq.content(resp))
        self.assertEqual(body, b"result")

    def _test_request_via_https_proxy(self, agent):
        """Runs a test case for an HTTPS request to https://test.com/abc which is
        expected to be tunnelled with CONNECT through the proxy at proxy.com
        (resolved to 1.2.3.5).

        The caller must configure an https proxy of "proxy.com" (no explicit
        port) in the environment; the test expects the connection on port 1080.

        Args:
            agent (ProxyAgent): the proxy agent being tested
        """
        self.reactor.lookups["proxy.com"] = "1.2.3.5"
        d = agent.request(b"GET", b"https://test.com/abc")

        # there should be a pending TCP connection, to the proxy
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, "1.2.3.5")
        self.assertEqual(port, 1080)

        # make a test HTTP server, and wire up the client
        proxy_server = self._make_connection(
            client_factory, _get_test_protocol_factory()
        )

        # fish the transports back out so that we can do the old switcheroo
        s2c_transport = proxy_server.transport
        client_protocol = s2c_transport.other
        c2s_transport = client_protocol.transport

        # the FakeTransport is async, so we need to pump the reactor
        self.reactor.advance(0)

        # now there should be a pending CONNECT request
        self.assertEqual(len(proxy_server.requests), 1)

        request = proxy_server.requests[0]
        self.assertEqual(request.method, b"CONNECT")
        self.assertEqual(request.path, b"test.com:443")

        # tell the proxy server not to close the connection
        proxy_server.persistent = True

        # this just stops the http Request trying to do a chunked response
        # request.setHeader(b"Content-Length", b"0")
        request.finish()

        # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
        ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
        ssl_protocol = ssl_factory.buildProtocol(None)
        http_server = ssl_protocol.wrappedProtocol

        ssl_protocol.makeConnection(
            FakeTransport(client_protocol, self.reactor, ssl_protocol)
        )
        c2s_transport.other = ssl_protocol

        self.reactor.advance(0)

        server_name = ssl_protocol._tlsConnection.get_servername()
        expected_sni = b"test.com"
        self.assertEqual(
            server_name,
            expected_sni,
            "Expected SNI %s but got %s" % (expected_sni, server_name),
        )

        # now there should be a pending request
        self.assertEqual(len(http_server.requests), 1)

        request = http_server.requests[0]
        self.assertEqual(request.method, b"GET")
        self.assertEqual(request.path, b"/abc")
        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
        request.write(b"result")
        request.finish()

        self.reactor.advance(0)

        resp = self.successResultOf(d)
        body = self.successResultOf(treq.content(resp))
        self.assertEqual(body, b"result")

    def test_http_request(self):
        agent = ProxyAgent(self.reactor)
        self._test_request_direct_connection(agent, b"http", b"test.com", b"")

    def test_https_request(self):
        agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy())
        self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")

    def test_http_request_use_proxy_empty_environment(self):
        agent = ProxyAgent(self.reactor, use_proxy=True)
        self._test_request_direct_connection(agent, b"http", b"test.com", b"")

    @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "NO_PROXY": "test.com"})
    def test_http_request_via_uppercase_no_proxy(self):
        agent = ProxyAgent(self.reactor, use_proxy=True)
        self._test_request_direct_connection(agent, b"http", b"test.com", b"")

    @patch.dict(
        os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "test.com,unused.com"}
    )
    def test_http_request_via_no_proxy(self):
        agent = ProxyAgent(self.reactor, use_proxy=True)
        self._test_request_direct_connection(agent, b"http", b"test.com", b"")

    @patch.dict(
        os.environ, {"https_proxy": "proxy.com", "no_proxy": "test.com,unused.com"}
    )
    def test_https_request_via_no_proxy(self):
        agent = ProxyAgent(
            self.reactor,
            contextFactory=get_test_https_policy(),
            use_proxy=True,
        )
        self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")

    @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "*"})
    def test_http_request_via_no_proxy_star(self):
        agent = ProxyAgent(self.reactor, use_proxy=True)
        self._test_request_direct_connection(agent, b"http", b"test.com", b"")

    @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "*"})
    def test_https_request_via_no_proxy_star(self):
        agent = ProxyAgent(
            self.reactor,
            contextFactory=get_test_https_policy(),
            use_proxy=True,
        )
        self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")

    @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "unused.com"})
    def test_http_request_via_proxy(self):
        agent = ProxyAgent(self.reactor, use_proxy=True)
        self._test_request_via_http_proxy(agent)

    @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"})
    def test_https_request_via_proxy(self):
        agent = ProxyAgent(
            self.reactor,
            contextFactory=get_test_https_policy(),
            use_proxy=True,
        )
        self._test_request_via_https_proxy(agent)

    @patch.dict(os.environ, {"http_proxy": "proxy.com:8888"})
    def test_http_request_via_proxy_with_blacklist(self):
        # The blacklist includes the configured proxy IP (1.2.3.5 is in
        # 1.0.0.0/8); the unwrapped reactor is also passed in, and the request
        # is still expected to reach the proxy.
        agent = ProxyAgent(
            BlacklistingReactorWrapper(
                self.reactor, ip_whitelist=None, ip_blacklist=IPSet(["1.0.0.0/8"])
            ),
            self.reactor,
            use_proxy=True,
        )
        self._test_request_via_http_proxy(agent)

    @patch.dict(os.environ, {"HTTPS_PROXY": "proxy.com"})
    def test_https_request_via_uppercase_proxy_with_blacklist(self):
        # The blacklist includes the configured proxy IP (1.2.3.5 is in
        # 1.0.0.0/8); the unwrapped reactor is also passed in, and the request
        # is still expected to be tunnelled through the proxy.
        agent = ProxyAgent(
            BlacklistingReactorWrapper(
                self.reactor, ip_whitelist=None, ip_blacklist=IPSet(["1.0.0.0/8"])
            ),
            self.reactor,
            contextFactory=get_test_https_policy(),
            use_proxy=True,
        )
        self._test_request_via_https_proxy(agent)
class MatrixFederationAgentTests(unittest.TestCase): def setUp(self): self.reactor = ThreadedMemoryReactorClock() self.mock_resolver = Mock() config_dict = default_config("test", parse=False) config_dict["federation_custom_ca_list"] = [get_test_ca_cert_file()] self._config = config = HomeServerConfig() config.parse_config_dict(config_dict, "", "") self.tls_factory = FederationPolicyForHTTPS(config) self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds) self.had_well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds) self.well_known_resolver = WellKnownResolver( self.reactor, Agent(self.reactor, contextFactory=self.tls_factory), b"test-agent", well_known_cache=self.well_known_cache, had_well_known_cache=self.had_well_known_cache, ) self.agent = MatrixFederationAgent( reactor=self.reactor, tls_client_options_factory=self.tls_factory, user_agent= "test-agent", # Note that this is unused since _well_known_resolver is provided. ip_blacklist=IPSet(), _srv_resolver=self.mock_resolver, _well_known_resolver=self.well_known_resolver, ) def _make_connection(self, client_factory, expected_sni): """Builds a test server, and completes the outgoing client connection Returns: HTTPChannel: the test server """ # build the test server server_tls_protocol = _build_test_server(get_connection_factory()) # now, tell the client protocol factory to build the client protocol (it will be a # _WrappingProtocol, around a TLSMemoryBIOProtocol, around an # HTTP11ClientProtocol) and wire the output of said protocol up to the server via # a FakeTransport. # # Normally this would be done by the TCP socket code in Twisted, but we are # stubbing that out here. 
client_protocol = client_factory.buildProtocol(None) client_protocol.makeConnection( FakeTransport(server_tls_protocol, self.reactor, client_protocol)) # tell the server tls protocol to send its stuff back to the client, too server_tls_protocol.makeConnection( FakeTransport(client_protocol, self.reactor, server_tls_protocol)) # grab a hold of the TLS connection, in case it gets torn down server_tls_connection = server_tls_protocol._tlsConnection # fish the test server back out of the server-side TLS protocol. http_protocol = server_tls_protocol.wrappedProtocol # give the reactor a pump to get the TLS juices flowing. self.reactor.pump((0.1, )) # check the SNI server_name = server_tls_connection.get_servername() self.assertEqual( server_name, expected_sni, "Expected SNI %s but got %s" % (expected_sni, server_name), ) return http_protocol @defer.inlineCallbacks def _make_get_request(self, uri): """ Sends a simple GET request via the agent, and checks its logcontext management """ with LoggingContext("one") as context: fetch_d = self.agent.request(b"GET", uri) # Nothing happened yet self.assertNoResult(fetch_d) # should have reset logcontext to the sentinel _check_logcontext(SENTINEL_CONTEXT) try: fetch_res = yield fetch_d return fetch_res except Exception as e: logger.info("Fetch of %s failed: %s", uri.decode("ascii"), e) raise finally: _check_logcontext(context) def _handle_well_known_connection(self, client_factory, expected_sni, content, response_headers={}): """Handle an outgoing HTTPs connection: wire it up to a server, check that the request is for a .well-known, and send the response. 
Args: client_factory (IProtocolFactory): outgoing connection expected_sni (bytes): SNI that we expect the outgoing connection to send content (bytes): content to send back as the .well-known Returns: HTTPChannel: server impl """ # make the connection for .well-known well_known_server = self._make_connection(client_factory, expected_sni=expected_sni) # check the .well-known request and send a response self.assertEqual(len(well_known_server.requests), 1) request = well_known_server.requests[0] self.assertEqual(request.requestHeaders.getRawHeaders(b"user-agent"), [b"test-agent"]) self._send_well_known_response(request, content, headers=response_headers) return well_known_server def _send_well_known_response(self, request, content, headers={}): """Check that an incoming request looks like a valid .well-known request, and send back the response. """ self.assertEqual(request.method, b"GET") self.assertEqual(request.path, b"/.well-known/matrix/server") self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"]) # send back a response for k, v in headers.items(): request.setHeader(k, v) request.write(content) request.finish() self.reactor.pump((0.1, )) def test_get(self): """ happy-path test of a GET request with an explicit port """ self.reactor.lookups["testserv"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://testserv:8448/foo/bar") # Nothing happened yet self.assertNoResult(test_d) # Make sure treq is trying to connect clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8448) # make a test server, and wire up the client http_server = self._make_connection(client_factory, expected_sni=b"testserv") self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b"GET") self.assertEqual(request.path, b"/foo/bar") 
self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv:8448"]) self.assertEqual(request.requestHeaders.getRawHeaders(b"user-agent"), [b"test-agent"]) content = request.content.read() self.assertEqual(content, b"") # Deferred is still without a result self.assertNoResult(test_d) # send the headers request.responseHeaders.setRawHeaders(b"Content-Type", [b"application/json"]) request.write("") self.reactor.pump((0.1, )) response = self.successResultOf(test_d) # that should give us a Response object self.assertEqual(response.code, 200) # Send the body request.write('{ "a": 1 }'.encode("ascii")) request.finish() self.reactor.pump((0.1, )) # check it can be read json = self.successResultOf(treq.json_content(response)) self.assertEqual(json, {"a": 1}) def test_get_ip_address(self): """ Test the behaviour when the server name contains an explicit IP (with no port) """ # there will be a getaddrinfo on the IP self.reactor.lookups["1.2.3.4"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://1.2.3.4/foo/bar") # Nothing happened yet self.assertNoResult(test_d) # Make sure treq is trying to connect clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8448) # make a test server, and wire up the client http_server = self._make_connection(client_factory, expected_sni=None) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b"GET") self.assertEqual(request.path, b"/foo/bar") self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"1.2.3.4"]) # finish the request request.finish() self.reactor.pump((0.1, )) self.successResultOf(test_d) def test_get_ipv6_address(self): """ Test the behaviour when the server name contains an explicit IPv6 address (with no port) """ # there will be a getaddrinfo on the IP self.reactor.lookups["::1"] = "::1" test_d = 
self._make_get_request(b"matrix://[::1]/foo/bar") # Nothing happened yet self.assertNoResult(test_d) # Make sure treq is trying to connect clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] self.assertEqual(host, "::1") self.assertEqual(port, 8448) # make a test server, and wire up the client http_server = self._make_connection(client_factory, expected_sni=None) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b"GET") self.assertEqual(request.path, b"/foo/bar") self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"[::1]"]) # finish the request request.finish() self.reactor.pump((0.1, )) self.successResultOf(test_d) def test_get_ipv6_address_with_port(self): """ Test the behaviour when the server name contains an explicit IPv6 address (with explicit port) """ # there will be a getaddrinfo on the IP self.reactor.lookups["::1"] = "::1" test_d = self._make_get_request(b"matrix://[::1]:80/foo/bar") # Nothing happened yet self.assertNoResult(test_d) # Make sure treq is trying to connect clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] self.assertEqual(host, "::1") self.assertEqual(port, 80) # make a test server, and wire up the client http_server = self._make_connection(client_factory, expected_sni=None) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b"GET") self.assertEqual(request.path, b"/foo/bar") self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"[::1]:80"]) # finish the request request.finish() self.reactor.pump((0.1, )) self.successResultOf(test_d) def test_get_hostname_bad_cert(self): """ Test the behaviour when the certificate on the server doesn't match the hostname """ self.mock_resolver.resolve_service.side_effect = 
generate_resolve_service( []) self.reactor.lookups["testserv1"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://testserv1/foo/bar") # Nothing happened yet self.assertNoResult(test_d) # No SRV record lookup yet self.mock_resolver.resolve_service.assert_not_called() # there should be an attempt to connect on port 443 for the .well-known clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) # fonx the connection client_factory.clientConnectionFailed(None, Exception("nope")) # attemptdelay on the hostnameendpoint is 0.3, so takes that long before the # .well-known request fails. self.reactor.pump((0.4, )) # now there should be a SRV lookup self.mock_resolver.resolve_service.assert_called_once_with( b"_matrix._tcp.testserv1") # we should fall back to a direct connection self.assertEqual(len(clients), 2) (host, port, client_factory, _timeout, _bindAddress) = clients[1] self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8448) # make a test server, and wire up the client http_server = self._make_connection(client_factory, expected_sni=b"testserv1") # there should be no requests self.assertEqual(len(http_server.requests), 0) # ... 
and the request should have failed e = self.failureResultOf(test_d, ResponseNeverReceived) failure_reason = e.value.reasons[0] self.assertIsInstance(failure_reason.value, VerificationError) def test_get_ip_address_bad_cert(self): """ Test the behaviour when the server name contains an explicit IP, but the server cert doesn't cover it """ # there will be a getaddrinfo on the IP self.reactor.lookups["1.2.3.5"] = "1.2.3.5" test_d = self._make_get_request(b"matrix://1.2.3.5/foo/bar") # Nothing happened yet self.assertNoResult(test_d) # Make sure treq is trying to connect clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] self.assertEqual(host, "1.2.3.5") self.assertEqual(port, 8448) # make a test server, and wire up the client http_server = self._make_connection(client_factory, expected_sni=None) # there should be no requests self.assertEqual(len(http_server.requests), 0) # ... and the request should have failed e = self.failureResultOf(test_d, ResponseNeverReceived) failure_reason = e.value.reasons[0] self.assertIsInstance(failure_reason.value, VerificationError) def test_get_no_srv_no_well_known(self): """ Test the behaviour when the server name has no port, no SRV, and no well-known """ self.mock_resolver.resolve_service.side_effect = generate_resolve_service( []) self.reactor.lookups["testserv"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://testserv/foo/bar") # Nothing happened yet self.assertNoResult(test_d) # No SRV record lookup yet self.mock_resolver.resolve_service.assert_not_called() # there should be an attempt to connect on port 443 for the .well-known clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) # fonx the connection client_factory.clientConnectionFailed(None, Exception("nope")) # attemptdelay on the hostnameendpoint 
is 0.3, so takes that long before the # .well-known request fails. self.reactor.pump((0.4, )) # now there should be a SRV lookup self.mock_resolver.resolve_service.assert_called_once_with( b"_matrix._tcp.testserv") # we should fall back to a direct connection self.assertEqual(len(clients), 2) (host, port, client_factory, _timeout, _bindAddress) = clients[1] self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8448) # make a test server, and wire up the client http_server = self._make_connection(client_factory, expected_sni=b"testserv") self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b"GET") self.assertEqual(request.path, b"/foo/bar") self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"]) # finish the request request.finish() self.reactor.pump((0.1, )) self.successResultOf(test_d) def test_get_well_known(self): """Test the behaviour when the .well-known delegates elsewhere """ self.mock_resolver.resolve_service.side_effect = generate_resolve_service( []) self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["target-server"] = "1::f" test_d = self._make_get_request(b"matrix://testserv/foo/bar") # Nothing happened yet self.assertNoResult(test_d) # there should be an attempt to connect on port 443 for the .well-known clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) self._handle_well_known_connection( client_factory, expected_sni=b"testserv", content=b'{ "m.server": "target-server" }', ) # there should be a SRV lookup self.mock_resolver.resolve_service.assert_called_once_with( b"_matrix._tcp.target-server") # now we should get a connection to the target server self.assertEqual(len(clients), 2) (host, port, client_factory, _timeout, _bindAddress) = clients[1] self.assertEqual(host, "1::f") self.assertEqual(port, 8448) # 
make a test server, and wire up the client http_server = self._make_connection(client_factory, expected_sni=b"target-server") self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b"GET") self.assertEqual(request.path, b"/foo/bar") self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"target-server"]) # finish the request request.finish() self.reactor.pump((0.1, )) self.successResultOf(test_d) self.assertEqual(self.well_known_cache[b"testserv"], b"target-server") # check the cache expires self.reactor.pump((48 * 3600, )) self.well_known_cache.expire() self.assertNotIn(b"testserv", self.well_known_cache) def test_get_well_known_redirect(self): """Test the behaviour when the server name has no port and no SRV record, but the .well-known has a 300 redirect """ self.mock_resolver.resolve_service.side_effect = generate_resolve_service( []) self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["target-server"] = "1::f" test_d = self._make_get_request(b"matrix://testserv/foo/bar") # Nothing happened yet self.assertNoResult(test_d) # there should be an attempt to connect on port 443 for the .well-known clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop() self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) redirect_server = self._make_connection(client_factory, expected_sni=b"testserv") # send a 302 redirect self.assertEqual(len(redirect_server.requests), 1) request = redirect_server.requests[0] request.redirect(b"https://testserv/even_better_known") request.finish() self.reactor.pump((0.1, )) # now there should be another connection clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop() self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) well_known_server = self._make_connection(client_factory, 
expected_sni=b"testserv") self.assertEqual(len(well_known_server.requests), 1, "No request after 302") request = well_known_server.requests[0] self.assertEqual(request.method, b"GET") self.assertEqual(request.path, b"/even_better_known") request.write(b'{ "m.server": "target-server" }') request.finish() self.reactor.pump((0.1, )) # there should be a SRV lookup self.mock_resolver.resolve_service.assert_called_once_with( b"_matrix._tcp.target-server") # now we should get a connection to the target server self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] self.assertEqual(host, "1::f") self.assertEqual(port, 8448) # make a test server, and wire up the client http_server = self._make_connection(client_factory, expected_sni=b"target-server") self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b"GET") self.assertEqual(request.path, b"/foo/bar") self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"target-server"]) # finish the request request.finish() self.reactor.pump((0.1, )) self.successResultOf(test_d) self.assertEqual(self.well_known_cache[b"testserv"], b"target-server") # check the cache expires self.reactor.pump((48 * 3600, )) self.well_known_cache.expire() self.assertNotIn(b"testserv", self.well_known_cache) def test_get_invalid_well_known(self): """ Test the behaviour when the server name has an *invalid* well-known (and no SRV) """ self.mock_resolver.resolve_service.side_effect = generate_resolve_service( []) self.reactor.lookups["testserv"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://testserv/foo/bar") # Nothing happened yet self.assertNoResult(test_d) # No SRV record lookup yet self.mock_resolver.resolve_service.assert_not_called() # there should be an attempt to connect on port 443 for the .well-known clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = 
clients.pop() self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) self._handle_well_known_connection(client_factory, expected_sni=b"testserv", content=b"NOT JSON") # now there should be a SRV lookup self.mock_resolver.resolve_service.assert_called_once_with( b"_matrix._tcp.testserv") # we should fall back to a direct connection self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop() self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8448) # make a test server, and wire up the client http_server = self._make_connection(client_factory, expected_sni=b"testserv") self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b"GET") self.assertEqual(request.path, b"/foo/bar") self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"]) # finish the request request.finish() self.reactor.pump((0.1, )) self.successResultOf(test_d) def test_get_well_known_unsigned_cert(self): """Test the behaviour when the .well-known server presents a cert not signed by a CA """ # we use the same test server as the other tests, but use an agent with # the config left to the default, which will not trust it (since the # presented cert is signed by a test CA) self.mock_resolver.resolve_service.side_effect = generate_resolve_service( []) self.reactor.lookups["testserv"] = "1.2.3.4" config = default_config("test", parse=True) # Build a new agent and WellKnownResolver with a different tls factory tls_factory = FederationPolicyForHTTPS(config) agent = MatrixFederationAgent( reactor=self.reactor, tls_client_options_factory=tls_factory, user_agent= b"test-agent", # This is unused since _well_known_resolver is passed below. 
ip_blacklist=IPSet(), _srv_resolver=self.mock_resolver, _well_known_resolver=WellKnownResolver( self.reactor, Agent(self.reactor, contextFactory=tls_factory), b"test-agent", well_known_cache=self.well_known_cache, had_well_known_cache=self.had_well_known_cache, ), ) test_d = agent.request(b"GET", b"matrix://testserv/foo/bar") # Nothing happened yet self.assertNoResult(test_d) # there should be an attempt to connect on port 443 for the .well-known clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) http_proto = self._make_connection(client_factory, expected_sni=b"testserv") # there should be no requests self.assertEqual(len(http_proto.requests), 0) # and there should be a SRV lookup instead self.mock_resolver.resolve_service.assert_called_once_with( b"_matrix._tcp.testserv") def test_get_hostname_srv(self): """ Test the behaviour when there is a single SRV record """ self.mock_resolver.resolve_service.side_effect = generate_resolve_service( [Server(host=b"srvtarget", port=8443)]) self.reactor.lookups["srvtarget"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://testserv/foo/bar") # Nothing happened yet self.assertNoResult(test_d) # the request for a .well-known will have failed with a DNS lookup error. 
self.mock_resolver.resolve_service.assert_called_once_with( b"_matrix._tcp.testserv") # Make sure treq is trying to connect clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8443) # make a test server, and wire up the client http_server = self._make_connection(client_factory, expected_sni=b"testserv") self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b"GET") self.assertEqual(request.path, b"/foo/bar") self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"]) # finish the request request.finish() self.reactor.pump((0.1, )) self.successResultOf(test_d) def test_get_well_known_srv(self): """Test the behaviour when the .well-known redirects to a place where there is a SRV. """ self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["srvtarget"] = "5.6.7.8" test_d = self._make_get_request(b"matrix://testserv/foo/bar") # Nothing happened yet self.assertNoResult(test_d) # there should be an attempt to connect on port 443 for the .well-known clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) self.mock_resolver.resolve_service.side_effect = generate_resolve_service( [Server(host=b"srvtarget", port=8443)]) self._handle_well_known_connection( client_factory, expected_sni=b"testserv", content=b'{ "m.server": "target-server" }', ) # there should be a SRV lookup self.mock_resolver.resolve_service.assert_called_once_with( b"_matrix._tcp.target-server") # now we should get a connection to the target of the SRV record self.assertEqual(len(clients), 2) (host, port, client_factory, _timeout, _bindAddress) = clients[1] self.assertEqual(host, "5.6.7.8") self.assertEqual(port, 8443) # make a test server, and wire 
up the client http_server = self._make_connection(client_factory, expected_sni=b"target-server") self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b"GET") self.assertEqual(request.path, b"/foo/bar") self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"target-server"]) # finish the request request.finish() self.reactor.pump((0.1, )) self.successResultOf(test_d) def test_idna_servername(self): """test the behaviour when the server name has idna chars in""" self.mock_resolver.resolve_service.side_effect = generate_resolve_service( []) # the resolver is always called with the IDNA hostname as a native string. self.reactor.lookups["xn--bcher-kva.com"] = "1.2.3.4" # this is idna for bücher.com test_d = self._make_get_request(b"matrix://xn--bcher-kva.com/foo/bar") # Nothing happened yet self.assertNoResult(test_d) # No SRV record lookup yet self.mock_resolver.resolve_service.assert_not_called() # there should be an attempt to connect on port 443 for the .well-known clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) # fonx the connection client_factory.clientConnectionFailed(None, Exception("nope")) # attemptdelay on the hostnameendpoint is 0.3, so takes that long before the # .well-known request fails. 
self.reactor.pump((0.4, )) # now there should have been a SRV lookup self.mock_resolver.resolve_service.assert_called_once_with( b"_matrix._tcp.xn--bcher-kva.com") # We should fall back to port 8448 clients = self.reactor.tcpClients self.assertEqual(len(clients), 2) (host, port, client_factory, _timeout, _bindAddress) = clients[1] self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8448) # make a test server, and wire up the client http_server = self._make_connection(client_factory, expected_sni=b"xn--bcher-kva.com") self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b"GET") self.assertEqual(request.path, b"/foo/bar") self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"xn--bcher-kva.com"]) # finish the request request.finish() self.reactor.pump((0.1, )) self.successResultOf(test_d) def test_idna_srv_target(self): """test the behaviour when the target of a SRV record has idna chars""" self.mock_resolver.resolve_service.side_effect = generate_resolve_service( [Server(host=b"xn--trget-3qa.com", port=8443)] # târget.com ) self.reactor.lookups["xn--trget-3qa.com"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://xn--bcher-kva.com/foo/bar") # Nothing happened yet self.assertNoResult(test_d) self.mock_resolver.resolve_service.assert_called_once_with( b"_matrix._tcp.xn--bcher-kva.com") # Make sure treq is trying to connect clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8443) # make a test server, and wire up the client http_server = self._make_connection(client_factory, expected_sni=b"xn--bcher-kva.com") self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b"GET") self.assertEqual(request.path, b"/foo/bar") self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), 
[b"xn--bcher-kva.com"]) # finish the request request.finish() self.reactor.pump((0.1, )) self.successResultOf(test_d) def test_well_known_cache(self): self.reactor.lookups["testserv"] = "1.2.3.4" fetch_d = defer.ensureDeferred( self.well_known_resolver.get_well_known(b"testserv")) # there should be an attempt to connect on port 443 for the .well-known clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) well_known_server = self._handle_well_known_connection( client_factory, expected_sni=b"testserv", response_headers={b"Cache-Control": b"max-age=1000"}, content=b'{ "m.server": "target-server" }', ) r = self.successResultOf(fetch_d) self.assertEqual(r.delegated_server, b"target-server") # close the tcp connection well_known_server.loseConnection() # repeat the request: it should hit the cache fetch_d = defer.ensureDeferred( self.well_known_resolver.get_well_known(b"testserv")) r = self.successResultOf(fetch_d) self.assertEqual(r.delegated_server, b"target-server") # expire the cache self.reactor.pump((1000.0, )) # now it should connect again fetch_d = defer.ensureDeferred( self.well_known_resolver.get_well_known(b"testserv")) self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) self._handle_well_known_connection( client_factory, expected_sni=b"testserv", content=b'{ "m.server": "other-server" }', ) r = self.successResultOf(fetch_d) self.assertEqual(r.delegated_server, b"other-server") def test_well_known_cache_with_temp_failure(self): """Test that we refetch well-known before the cache expires, and that it ignores transient errors. 
""" self.reactor.lookups["testserv"] = "1.2.3.4" fetch_d = defer.ensureDeferred( self.well_known_resolver.get_well_known(b"testserv")) # there should be an attempt to connect on port 443 for the .well-known clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) well_known_server = self._handle_well_known_connection( client_factory, expected_sni=b"testserv", response_headers={b"Cache-Control": b"max-age=1000"}, content=b'{ "m.server": "target-server" }', ) r = self.successResultOf(fetch_d) self.assertEqual(r.delegated_server, b"target-server") # close the tcp connection well_known_server.loseConnection() # Get close to the cache expiry, this will cause the resolver to do # another lookup. self.reactor.pump((900.0, )) fetch_d = defer.ensureDeferred( self.well_known_resolver.get_well_known(b"testserv")) # The resolver may retry a few times, so fonx all requests that come along attempts = 0 while self.reactor.tcpClients: clients = self.reactor.tcpClients (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) attempts += 1 # fonx the connection attempt, this will be treated as a temporary # failure. client_factory.clientConnectionFailed(None, Exception("nope")) # There's a few sleeps involved, so we have to pump the reactor a # bit. self.reactor.pump((1.0, 1.0)) # We expect to see more than one attempt as there was previously a valid # well known. self.assertGreater(attempts, 1) # Resolver should return cached value, despite the lookup failing. r = self.successResultOf(fetch_d) self.assertEqual(r.delegated_server, b"target-server") # Expire both caches and repeat the request self.reactor.pump((10000.0, )) # Repated the request, this time it should fail if the lookup fails. 
fetch_d = defer.ensureDeferred( self.well_known_resolver.get_well_known(b"testserv")) clients = self.reactor.tcpClients (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) client_factory.clientConnectionFailed(None, Exception("nope")) self.reactor.pump((0.4, )) r = self.successResultOf(fetch_d) self.assertEqual(r.delegated_server, None) def test_well_known_too_large(self): """A well-known query that returns a result which is too large should be rejected.""" self.reactor.lookups["testserv"] = "1.2.3.4" fetch_d = defer.ensureDeferred( self.well_known_resolver.get_well_known(b"testserv")) # there should be an attempt to connect on port 443 for the .well-known clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) self._handle_well_known_connection( client_factory, expected_sni=b"testserv", response_headers={b"Cache-Control": b"max-age=1000"}, content=b'{ "m.server": "' + (b"a" * WELL_KNOWN_MAX_SIZE) + b'" }', ) # The result is sucessful, but disabled delegation. r = self.successResultOf(fetch_d) self.assertIsNone(r.delegated_server) def test_srv_fallbacks(self): """Test that other SRV results are tried if the first one fails. 
""" self.mock_resolver.resolve_service.side_effect = generate_resolve_service( [ Server(host=b"target.com", port=8443), Server(host=b"target.com", port=8444), ]) self.reactor.lookups["target.com"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://testserv/foo/bar") # Nothing happened yet self.assertNoResult(test_d) self.mock_resolver.resolve_service.assert_called_once_with( b"_matrix._tcp.testserv") # We should see an attempt to connect to the first server clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8443) # Fonx the connection client_factory.clientConnectionFailed(None, Exception("nope")) # There's a 300ms delay in HostnameEndpoint self.reactor.pump((0.4, )) # Hasn't failed yet self.assertNoResult(test_d) # We shouldnow see an attempt to connect to the second server clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8444) # make a test server, and wire up the client http_server = self._make_connection(client_factory, expected_sni=b"testserv") self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b"GET") self.assertEqual(request.path, b"/foo/bar") self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"]) # finish the request request.finish() self.reactor.pump((0.1, )) self.successResultOf(test_d)
def setUp(self) -> None:
    """Build a fake reactor and clock, then expose a
    CancellableDirectServeHtmlResource through a FakeSite for requests."""
    reactor = ThreadedMemoryReactorClock()
    self.reactor = reactor

    clock = Clock(reactor)
    self.clock = clock

    self.resource = CancellableDirectServeHtmlResource(clock)
    self.site = FakeSite(self.resource, reactor)
class FilterTestCase(unittest.TestCase):
    """Tests for the client ``/_matrix/client/r0/user/{userId}/filter``
    endpoints (creating and fetching user filters)."""

    USER_ID = "@apple:test"
    EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}}
    EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}'
    # Servlet providers whose ``register_servlets`` is called in setUp.
    TO_REGISTER = [filter]

    def setUp(self):
        # The memory reactor also serves as the reactor the homeserver runs on.
        self.clock = MemoryReactorClock()
        self.hs_clock = Clock(self.clock)

        self.hs = setup_test_homeserver(
            self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock
        )

        self.auth = self.hs.get_auth()

        def get_user_by_access_token(token=None, allow_guest=False):
            return {
                "user": UserID.from_string(self.USER_ID),
                "token_id": 1,
                "is_guest": False,
            }

        def get_user_by_req(request, allow_guest=False, rights="access"):
            return synapse.types.create_requester(
                UserID.from_string(self.USER_ID), 1, False, None
            )

        # Stub out auth so every request is treated as coming from USER_ID.
        self.auth.get_user_by_access_token = get_user_by_access_token
        self.auth.get_user_by_req = get_user_by_req

        self.store = self.hs.get_datastore()
        self.filtering = self.hs.get_filtering()
        self.resource = JsonResource(self.hs)

        for r in self.TO_REGISTER:
            r.register_servlets(self.hs, self.resource)

    def test_add_filter(self):
        request, channel = make_request(
            "POST",
            "/_matrix/client/r0/user/%s/filter" % (self.USER_ID,),
            self.EXAMPLE_FILTER_JSON,
        )
        render(request, self.resource, self.clock)

        self.assertEqual(channel.result["code"], b"200")
        self.assertEqual(channel.json_body, {"filter_id": "0"})

        # Named to avoid shadowing the builtin ``filter``.
        stored_filter_d = self.store.get_user_filter(user_localpart="apple", filter_id=0)
        self.clock.advance(0)
        self.assertEqual(stored_filter_d.result, self.EXAMPLE_FILTER)

    def test_add_filter_for_other_user(self):
        request, channel = make_request(
            "POST",
            "/_matrix/client/r0/user/%s/filter" % ("@watermelon:test",),
            self.EXAMPLE_FILTER_JSON,
        )
        render(request, self.resource, self.clock)

        self.assertEqual(channel.result["code"], b"403")
        self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)

    def test_add_filter_non_local_user(self):
        # Temporarily make the homeserver treat every user as remote; restore it
        # even if the request raises.
        _is_mine = self.hs.is_mine
        self.hs.is_mine = lambda target_user: False
        try:
            request, channel = make_request(
                "POST",
                "/_matrix/client/r0/user/%s/filter" % (self.USER_ID,),
                self.EXAMPLE_FILTER_JSON,
            )
            render(request, self.resource, self.clock)
        finally:
            self.hs.is_mine = _is_mine

        self.assertEqual(channel.result["code"], b"403")
        self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)

    def test_get_filter(self):
        filter_id_d = self.filtering.add_user_filter(
            user_localpart="apple", user_filter=self.EXAMPLE_FILTER
        )
        self.clock.advance(1)
        filter_id = filter_id_d.result

        request, channel = make_request(
            "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.USER_ID, filter_id)
        )
        render(request, self.resource, self.clock)

        self.assertEqual(channel.result["code"], b"200")
        self.assertEqual(channel.json_body, self.EXAMPLE_FILTER)

    def test_get_filter_non_existant(self):
        request, channel = make_request(
            "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.USER_ID,)
        )
        render(request, self.resource, self.clock)

        self.assertEqual(channel.result["code"], b"400")
        self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)

    # Currently invalid params do not have an appropriate errcode
    # in errors.py
    def test_get_filter_invalid_id(self):
        request, channel = make_request(
            "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.USER_ID,)
        )
        render(request, self.resource, self.clock)

        self.assertEqual(channel.result["code"], b"400")

    # No ID also returns an invalid_id error
    def test_get_filter_no_id(self):
        request, channel = make_request(
            "GET", "/_matrix/client/r0/user/%s/filter/" % (self.USER_ID,)
        )
        render(request, self.resource, self.clock)

        self.assertEqual(channel.result["code"], b"400")
class TestMauLimit(unittest.TestCase):
    """End-to-end tests of the monthly-active-user (MAU) limit, driven through
    the ``/register`` and ``/sync`` servlets."""

    def setUp(self):
        self.reactor = ThreadedMemoryReactorClock()
        self.clock = Clock(self.reactor)

        self.hs = setup_test_homeserver(
            self.addCleanup,
            "red",
            http_client=None,
            clock=self.clock,
            reactor=self.reactor,
            federation_client=Mock(),
            ratelimiter=NonCallableMock(spec_set=["send_message"]),
        )

        self.store = self.hs.get_datastore()

        self.hs.config.registrations_require_3pid = []
        self.hs.config.enable_registration_captcha = False
        self.hs.config.recaptcha_public_key = []

        # MAU limiting on with a cap of two users; individual tests override
        # mau_trial_days where needed.
        self.hs.config.limit_usage_by_mau = True
        self.hs.config.hs_disabled = False
        self.hs.config.max_mau_value = 2
        self.hs.config.mau_trial_days = 0
        self.hs.config.server_notices_mxid = "@server:red"
        self.hs.config.server_notices_mxid_display_name = None
        self.hs.config.server_notices_mxid_avatar_url = None
        self.hs.config.server_notices_room_name = "Test Server Notice Room"

        self.resource = JsonResource(self.hs)
        register.register_servlets(self.hs, self.resource)
        sync.register_servlets(self.hs, self.resource)

    def test_simple_deny_mau(self):
        # Create and sync so that the MAU counts get updated
        token1 = self.create_user("kermit1")
        self.do_sync_for_user(token1)
        token2 = self.create_user("kermit2")
        self.do_sync_for_user(token2)

        # We've created and activated two users, we shouldn't be able to
        # register new users
        with self.assertRaises(SynapseError) as cm:
            self.create_user("kermit3")

        self._assert_resource_limit_exceeded(cm.exception)

    def test_allowed_after_a_month_mau(self):
        # Create and sync so that the MAU counts get updated
        token1 = self.create_user("kermit1")
        self.do_sync_for_user(token1)
        token2 = self.create_user("kermit2")
        self.do_sync_for_user(token2)

        # Advance time by 31 days
        self.reactor.advance(31 * 24 * 60 * 60)

        self.store.reap_monthly_active_users()

        self.reactor.advance(0)

        # We should be able to register more users
        token3 = self.create_user("kermit3")
        self.do_sync_for_user(token3)

    def test_trial_delay(self):
        self.hs.config.mau_trial_days = 1

        # We should be able to register more than the limit initially
        token1 = self.create_user("kermit1")
        self.do_sync_for_user(token1)
        token2 = self.create_user("kermit2")
        self.do_sync_for_user(token2)
        token3 = self.create_user("kermit3")
        self.do_sync_for_user(token3)

        # Advance time by 2 days
        self.reactor.advance(2 * 24 * 60 * 60)

        # Two users should be able to sync
        self.do_sync_for_user(token1)
        self.do_sync_for_user(token2)

        # But the third should fail
        with self.assertRaises(SynapseError) as cm:
            self.do_sync_for_user(token3)

        self._assert_resource_limit_exceeded(cm.exception)

        # And new registrations are now denied too
        with self.assertRaises(SynapseError) as cm:
            self.create_user("kermit4")

        self._assert_resource_limit_exceeded(cm.exception)

    def test_trial_users_cant_come_back(self):
        self.hs.config.mau_trial_days = 1

        # We should be able to register more than the limit initially
        token1 = self.create_user("kermit1")
        self.do_sync_for_user(token1)
        token2 = self.create_user("kermit2")
        self.do_sync_for_user(token2)
        token3 = self.create_user("kermit3")
        self.do_sync_for_user(token3)

        # Advance time by 2 days
        self.reactor.advance(2 * 24 * 60 * 60)

        # Two users should be able to sync
        self.do_sync_for_user(token1)
        self.do_sync_for_user(token2)

        # Advance by 2 months so everyone falls out of MAU
        self.reactor.advance(60 * 24 * 60 * 60)
        self.store.reap_monthly_active_users()
        self.reactor.advance(0)

        # We can create as many new users as we want
        token4 = self.create_user("kermit4")
        self.do_sync_for_user(token4)
        token5 = self.create_user("kermit5")
        self.do_sync_for_user(token5)
        token6 = self.create_user("kermit6")
        self.do_sync_for_user(token6)

        # users 2 and 3 can come back to bring us back up to MAU limit
        self.do_sync_for_user(token2)
        self.do_sync_for_user(token3)

        # New trial users can still sync
        self.do_sync_for_user(token4)
        self.do_sync_for_user(token5)
        self.do_sync_for_user(token6)

        # But the old user can't
        with self.assertRaises(SynapseError) as cm:
            self.do_sync_for_user(token1)

        self._assert_resource_limit_exceeded(cm.exception)

    def create_user(self, localpart):
        """Register ``localpart`` via ``/register`` using dummy auth.

        Returns:
            str: the access token from the registration response.

        Raises:
            SynapseError: if the registration request did not return 200.
        """
        request_data = json.dumps(
            {
                "username": localpart,
                "password": "******",
                "auth": {"type": LoginType.DUMMY},
            }
        )

        request, channel = make_request(b"POST", b"/register", request_data)
        render(request, self.resource, self.reactor)
        self._raise_for_error(channel)

        access_token = channel.json_body["access_token"]
        return access_token

    def do_sync_for_user(self, token):
        """Perform a ``/sync`` as the user owning ``token``.

        Raises:
            SynapseError: if the sync request did not return 200.
        """
        request, channel = make_request(b"GET", b"/sync", access_token=token)
        render(request, self.resource, self.reactor)
        self._raise_for_error(channel)

    def _raise_for_error(self, channel):
        """Convert a non-200 response recorded on ``channel`` into the error
        produced by ``HttpResponseException.to_synapse_error()`` and raise it."""
        if channel.result["code"] != b"200":
            raise HttpResponseException(
                int(channel.result["code"]),
                channel.result["reason"],
                channel.result["body"],
            ).to_synapse_error()

    def _assert_resource_limit_exceeded(self, e):
        """Assert that ``e`` is a 403 with errcode ``Codes.RESOURCE_LIMIT_EXCEEDED``."""
        self.assertEqual(e.code, 403)
        self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
class MessageAcceptTests(unittest.HomeserverTestCase):
    """Tests for how the federation handler accepts incoming PDUs, and for
    resyncing remote users' device lists."""

    def setUp(self):
        self.http_client = Mock()
        self.reactor = ThreadedMemoryReactorClock()
        self.hs_clock = Clock(self.reactor)
        self.homeserver = setup_test_homeserver(
            self.addCleanup,
            http_client=self.http_client,
            clock=self.hs_clock,
            reactor=self.reactor,
        )

        user_id = UserID("us", "test")
        our_user = Requester(user_id, None, False, False, None, None)
        room_creator = self.homeserver.get_room_creation_handler()
        self.room_id = self.get_success(
            room_creator.create_room(
                our_user, room_creator._presets_dict["public_chat"], ratelimit=False
            )
        )[0]["room_id"]

        self.store = self.homeserver.get_datastore()

        # Figure out what the most recent event is
        most_recent = self.get_success(
            self.store.get_latest_event_ids_in_room(self.room_id)
        )[0]

        join_event = make_event_from_dict(
            {
                "room_id": self.room_id,
                "sender": "@baduser:test.serv",
                "state_key": "@baduser:test.serv",
                "event_id": "$join:test.serv",
                "depth": 1000,
                "origin_server_ts": 1,
                "type": "m.room.member",
                "origin": "test.servx",
                "content": {"membership": "join"},
                "auth_events": [],
                "prev_state": [(most_recent, {})],
                "prev_events": [(most_recent, {})],
            }
        )

        self.handler = self.homeserver.get_handlers().federation_handler
        # Bypass auth: do_auth now just returns the context it was given.
        self.handler.do_auth = lambda origin, event, context, auth_events: succeed(
            context
        )
        self.client = self.homeserver.get_federation_client()
        # Bypass signature/hash checking: fetched PDUs are returned unchanged.
        self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(
            pdus
        )

        # Send the join, it should return None (which is not an error)
        self.assertIsNone(
            self.get_success(
                self.handler.on_receive_pdu(
                    "test.serv", join_event, sent_to_us_directly=True
                )
            )
        )

        # Make sure we actually joined the room
        self.assertEqual(
            self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))[0],
            "$join:test.serv",
        )

    def test_cant_hide_direct_ancestors(self):
        """
        If you send a message, you must be able to provide the direct
        prev_events that said event references.
        """

        async def post_json(destination, path, data, headers=None, timeout=0):
            # If it asks us for new missing events, give them NOTHING
            if path.startswith("/_matrix/federation/v1/get_missing_events/"):
                return {"events": []}

        self.http_client.post_json = post_json

        # Figure out what the most recent event is
        most_recent = self.get_success(
            self.store.get_latest_event_ids_in_room(self.room_id)
        )[0]

        # Now lie about an event
        lying_event = make_event_from_dict(
            {
                "room_id": self.room_id,
                "sender": "@baduser:test.serv",
                "event_id": "one:test.serv",
                "depth": 1000,
                "origin_server_ts": 1,
                "type": "m.room.message",
                "origin": "test.serv",
                "content": {"body": "hewwo?"},
                "auth_events": [],
                "prev_events": [("two:test.serv", {}), (most_recent, {})],
            }
        )

        with LoggingContext(request="lying_event"):
            failure = self.get_failure(
                self.handler.on_receive_pdu(
                    "test.serv", lying_event, sent_to_us_directly=True
                ),
                FederationError,
            )

        # on_receive_pdu should throw an error
        self.assertEqual(
            failure.value.args[0],
            (
                "ERROR 403: Your server isn't divulging details about prev_events "
                "referenced in this event."
            ),
        )

        # Make sure the invalid event isn't there
        extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))
        self.assertEqual(extrem[0], "$join:test.serv")

    def test_retry_device_list_resync(self):
        """Tests that device lists are marked as stale if they couldn't be synced, and
        that stale device lists are retried periodically.
        """
        remote_user_id = "@john:test_remote"
        remote_origin = "test_remote"

        # Track the number of attempts to resync the user's device list.
        self.resync_attempts = 0

        # When this function is called, increment the number of resync attempts (only if
        # we're querying devices for the right user ID), then raise a
        # NotRetryingDestination error to fail the resync gracefully.
        def query_user_devices(destination, user_id):
            if user_id == remote_user_id:
                self.resync_attempts += 1

            raise NotRetryingDestination(0, 0, destination)

        # Register the mock on the federation client.
        federation_client = self.homeserver.get_federation_client()
        federation_client.query_user_devices = Mock(side_effect=query_user_devices)

        # Register a mock on the store so that the incoming update doesn't fail because
        # we don't share a room with the user.
        store = self.homeserver.get_datastore()
        store.get_rooms_for_user = Mock(
            return_value=make_awaitable(["!someroom:test"])
        )

        # Manually inject a fake device list update. We need this update to include at
        # least one prev_id so that the user's device list will need to be retried.
        device_list_updater = self.homeserver.get_device_handler().device_list_updater
        self.get_success(
            device_list_updater.incoming_device_list_update(
                origin=remote_origin,
                edu_content={
                    "deleted": False,
                    "device_display_name": "Mobile",
                    "device_id": "QBUAZIFURK",
                    "prev_id": [5],
                    "stream_id": 6,
                    "user_id": remote_user_id,
                },
            )
        )

        # Check that there was one resync attempt.
        self.assertEqual(self.resync_attempts, 1)

        # Check that the resync attempt failed and caused the user's device list to be
        # marked as stale.
        need_resync = self.get_success(
            store.get_user_ids_requiring_device_list_resync()
        )
        self.assertIn(remote_user_id, need_resync)

        # Check that waiting for 30 seconds caused Synapse to retry resyncing the device
        # list.
        self.reactor.advance(30)
        self.assertEqual(self.resync_attempts, 2)

    def test_cross_signing_keys_retry(self):
        """Tests that resyncing a device list correctly processes cross-signing keys from
        the remote server.
        """
        remote_user_id = "@john:test_remote"
        remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
        remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"

        # Register mock device list retrieval on the federation client.
        federation_client = self.homeserver.get_federation_client()
        federation_client.query_user_devices = Mock(
            return_value=succeed(
                {
                    "user_id": remote_user_id,
                    "stream_id": 1,
                    "devices": [],
                    "master_key": {
                        "user_id": remote_user_id,
                        "usage": ["master"],
                        "keys": {"ed25519:" + remote_master_key: remote_master_key},
                    },
                    "self_signing_key": {
                        "user_id": remote_user_id,
                        "usage": ["self_signing"],
                        "keys": {
                            "ed25519:"
                            + remote_self_signing_key: remote_self_signing_key
                        },
                    },
                }
            )
        )

        # Resync the device list.
        device_handler = self.homeserver.get_device_handler()
        self.get_success(
            device_handler.device_list_updater.user_device_resync(remote_user_id),
        )

        # Retrieve the cross-signing keys for this user.
        keys = self.get_success(
            self.store.get_e2e_cross_signing_keys_bulk(user_ids=[remote_user_id]),
        )
        self.assertIn(remote_user_id, keys)

        # Check that the master key is the one returned by the mock.
        master_key = keys[remote_user_id]["master"]
        self.assertEqual(len(master_key["keys"]), 1)
        self.assertIn("ed25519:" + remote_master_key, master_key["keys"].keys())
        self.assertIn(remote_master_key, master_key["keys"].values())

        # Check that the self-signing key is the one returned by the mock.
        self_signing_key = keys[remote_user_id]["self_signing"]
        self.assertEqual(len(self_signing_key["keys"]), 1)
        self.assertIn(
            "ed25519:" + remote_self_signing_key, self_signing_key["keys"].keys()
        )
        self.assertIn(remote_self_signing_key, self_signing_key["keys"].values())
class MatrixFederationAgentTests(TestCase):
    """Tests for ``ProxyAgent``, driven by an in-memory reactor.

    Despite the class name, every test here exercises ``ProxyAgent``: direct
    HTTP and HTTPS requests, and requests routed through an HTTP or HTTPS
    proxy. Each test issues a request, checks the outbound TCP connection the
    agent asked the reactor for, wires it to a fake server with
    ``_make_connection``, and then answers the request.
    """

    def setUp(self):
        self.reactor = ThreadedMemoryReactorClock()

    def _make_connection(
        self, client_factory, server_factory, ssl=False, expected_sni=None
    ):
        """Builds a test server, and completes the outgoing client connection

        Args:
            client_factory (interfaces.IProtocolFactory): the factory that the
                application is trying to use to make the outbound connection. We will
                invoke it to build the client Protocol

            server_factory (interfaces.IProtocolFactory): a factory to build the
                server-side protocol

            ssl (bool): If true, we will expect an ssl connection and wrap
                server_factory with a TLSMemoryBIOFactory

            expected_sni (bytes|None): the expected SNI value. Only meaningful
                together with ssl=True, since without TLS there is no server
                name to inspect.

        Returns:
            IProtocol: the server Protocol returned by server_factory

        Raises:
            ValueError: if expected_sni is given without ssl=True.
        """
        if expected_sni is not None and not ssl:
            # Without TLS there is no TLS connection to read the SNI from; fail
            # with a clear message instead of an AttributeError on None below.
            raise ValueError("expected_sni can only be checked when ssl=True")

        if ssl:
            server_factory = _wrap_server_factory_for_tls(server_factory)

        server_protocol = server_factory.buildProtocol(None)

        # now, tell the client protocol factory to build the client protocol,
        # and wire the output of said protocol up to the server via
        # a FakeTransport.
        #
        # Normally this would be done by the TCP socket code in Twisted, but we are
        # stubbing that out here.
        client_protocol = client_factory.buildProtocol(None)
        client_protocol.makeConnection(
            FakeTransport(server_protocol, self.reactor, client_protocol)
        )

        # tell the server protocol to send its stuff back to the client, too
        server_protocol.makeConnection(
            FakeTransport(client_protocol, self.reactor, server_protocol)
        )

        if ssl:
            http_protocol = server_protocol.wrappedProtocol
            tls_connection = server_protocol._tlsConnection
        else:
            http_protocol = server_protocol
            tls_connection = None

        # give the reactor a pump to get the TLS juices flowing (if needed)
        self.reactor.advance(0)

        if expected_sni is not None:
            server_name = tls_connection.get_servername()
            self.assertEqual(
                server_name,
                expected_sni,
                "Expected SNI %s but got %s" % (expected_sni, server_name),
            )

        return http_protocol

    def test_http_request(self):
        agent = ProxyAgent(self.reactor)

        self.reactor.lookups["test.com"] = "1.2.3.4"
        d = agent.request(b"GET", b"http://test.com")

        # there should be a pending TCP connection
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, "1.2.3.4")
        self.assertEqual(port, 80)

        # make a test server, and wire up the client
        http_server = self._make_connection(
            client_factory, _get_test_protocol_factory()
        )

        # the FakeTransport is async, so we need to pump the reactor
        self.reactor.advance(0)

        # now there should be a pending request
        self.assertEqual(len(http_server.requests), 1)

        request = http_server.requests[0]
        self.assertEqual(request.method, b"GET")
        self.assertEqual(request.path, b"/")
        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
        request.write(b"result")
        request.finish()

        self.reactor.advance(0)

        resp = self.successResultOf(d)
        body = self.successResultOf(treq.content(resp))
        self.assertEqual(body, b"result")

    def test_https_request(self):
        agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy())

        self.reactor.lookups["test.com"] = "1.2.3.4"
        d = agent.request(b"GET", b"https://test.com/abc")

        # there should be a pending TCP connection
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, "1.2.3.4")
        self.assertEqual(port, 443)

        # make a test server, and wire up the client
        http_server = self._make_connection(
            client_factory,
            _get_test_protocol_factory(),
            ssl=True,
            expected_sni=b"test.com",
        )

        # the FakeTransport is async, so we need to pump the reactor
        self.reactor.advance(0)

        # now there should be a pending request
        self.assertEqual(len(http_server.requests), 1)

        request = http_server.requests[0]
        self.assertEqual(request.method, b"GET")
        self.assertEqual(request.path, b"/abc")
        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
        request.write(b"result")
        request.finish()

        self.reactor.advance(0)

        resp = self.successResultOf(d)
        body = self.successResultOf(treq.content(resp))
        self.assertEqual(body, b"result")

    def test_http_request_via_proxy(self):
        agent = ProxyAgent(self.reactor, http_proxy=b"proxy.com:8888")

        self.reactor.lookups["proxy.com"] = "1.2.3.5"
        d = agent.request(b"GET", b"http://test.com")

        # there should be a pending TCP connection
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, "1.2.3.5")
        self.assertEqual(port, 8888)

        # make a test server, and wire up the client
        http_server = self._make_connection(
            client_factory, _get_test_protocol_factory()
        )

        # the FakeTransport is async, so we need to pump the reactor
        self.reactor.advance(0)

        # now there should be a pending request
        self.assertEqual(len(http_server.requests), 1)

        request = http_server.requests[0]
        self.assertEqual(request.method, b"GET")
        # via an HTTP proxy, the request line carries the absolute URI
        self.assertEqual(request.path, b"http://test.com")
        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
        request.write(b"result")
        request.finish()

        self.reactor.advance(0)

        resp = self.successResultOf(d)
        body = self.successResultOf(treq.content(resp))
        self.assertEqual(body, b"result")

    def test_https_request_via_proxy(self):
        agent = ProxyAgent(
            self.reactor,
            contextFactory=get_test_https_policy(),
            https_proxy=b"proxy.com",
        )

        self.reactor.lookups["proxy.com"] = "1.2.3.5"
        d = agent.request(b"GET", b"https://test.com/abc")

        # there should be a pending TCP connection
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, "1.2.3.5")
        # https_proxy was given without a port, so 1080 is expected here
        self.assertEqual(port, 1080)

        # make a test HTTP server, and wire up the client
        proxy_server = self._make_connection(
            client_factory, _get_test_protocol_factory()
        )

        # fish the transports back out so that we can do the old switcheroo
        s2c_transport = proxy_server.transport
        client_protocol = s2c_transport.other
        c2s_transport = client_protocol.transport

        # the FakeTransport is async, so we need to pump the reactor
        self.reactor.advance(0)

        # now there should be a pending CONNECT request
        self.assertEqual(len(proxy_server.requests), 1)

        request = proxy_server.requests[0]
        self.assertEqual(request.method, b"CONNECT")
        self.assertEqual(request.path, b"test.com:443")

        # tell the proxy server not to close the connection
        proxy_server.persistent = True

        # this just stops the http Request trying to do a chunked response
        # request.setHeader(b"Content-Length", b"0")
        request.finish()

        # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
        ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
        ssl_protocol = ssl_factory.buildProtocol(None)
        http_server = ssl_protocol.wrappedProtocol

        ssl_protocol.makeConnection(
            FakeTransport(client_protocol, self.reactor, ssl_protocol)
        )
        c2s_transport.other = ssl_protocol

        self.reactor.advance(0)

        server_name = ssl_protocol._tlsConnection.get_servername()
        expected_sni = b"test.com"
        self.assertEqual(
            server_name,
            expected_sni,
            "Expected SNI %s but got %s" % (expected_sni, server_name),
        )

        # now there should be a pending request
        self.assertEqual(len(http_server.requests), 1)

        request = http_server.requests[0]
        self.assertEqual(request.method, b"GET")
        self.assertEqual(request.path, b"/abc")
        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
        request.write(b"result")
        request.finish()

        self.reactor.advance(0)

        resp = self.successResultOf(d)
        body = self.successResultOf(treq.content(resp))
        self.assertEqual(body, b"result")
class UserRegisterTestCase(unittest.TestCase):
    """Tests for the shared-secret registration endpoint at
    ``/_matrix/client/r0/admin/register``."""

    def setUp(self):
        self.clock = ThreadedMemoryReactorClock()
        self.hs_clock = Clock(self.clock)

        self.url = "/_matrix/client/r0/admin/register"

        self.registration_handler = Mock()
        self.identity_handler = Mock()
        self.login_handler = Mock()
        self.device_handler = Mock()
        self.device_handler.check_device_registered = Mock(return_value="FAKE")

        self.datastore = Mock(return_value=Mock())
        self.datastore.get_current_state_deltas = Mock(return_value=[])

        self.secrets = Mock()

        self.hs = setup_test_homeserver(
            self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock
        )

        self.hs.config.registration_shared_secret = u"shared"

        self.hs.get_media_repository = Mock()
        self.hs.get_deactivate_account_handler = Mock()

        self.resource = JsonResource(self.hs)
        register_servlets(self.hs, self.resource)

    def _fetch_nonce(self):
        """GET the endpoint and return the ``nonce`` field of its JSON body."""
        request, channel = make_request("GET", self.url)
        render(request, self.resource, self.clock)
        return channel.json_body["nonce"]

    def _post_json(self, body):
        """POST *body* to the endpoint as UTF-8 JSON and return the channel."""
        request, channel = make_request(
            "POST", self.url, json.dumps(body).encode('utf8')
        )
        render(request, self.resource, self.clock)
        return channel

    def _assert_error(self, channel, code, error):
        """Assert the response has HTTP status *code* and JSON ``error`` *error*."""
        self.assertEqual(code, int(channel.result["code"]), msg=channel.result["body"])
        self.assertEqual(error, channel.json_body["error"])

    @staticmethod
    def _mac(nonce_bytes):
        """Return the hex HMAC-SHA1 of ``nonce_bytes + b"\\x00bob\\x00abc123\\x00admin"``.

        The key is ``b"shared"``, the same secret configured in setUp.
        """
        mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
        mac.update(nonce_bytes + b"\x00bob\x00abc123\x00admin")
        return mac.hexdigest()

    def test_disabled(self):
        """
        If there is no shared secret, registration through this method will be
        prevented.
        """
        self.hs.config.registration_shared_secret = None

        channel = self._post_json({})

        self._assert_error(channel, 400, 'Shared secret registration is not enabled')

    def test_get_nonce(self):
        """
        Calling GET on the endpoint will return a randomised nonce, using the
        homeserver's secrets provider.
        """
        secrets = Mock()
        secrets.token_hex = Mock(return_value="abcd")

        self.hs.get_secrets = Mock(return_value=secrets)

        request, channel = make_request("GET", self.url)
        render(request, self.resource, self.clock)

        self.assertEqual(channel.json_body, {"nonce": "abcd"})

    def test_expired_nonce(self):
        """
        Calling GET on the endpoint will return a randomised nonce, which will
        only last for SALT_TIMEOUT (60s).
        """
        nonce = self._fetch_nonce()

        # 59 seconds
        self.clock.advance(59)

        body = {"nonce": nonce}
        channel = self._post_json(body)
        # The nonce is still accepted, so validation fails on the next field.
        self._assert_error(channel, 400, 'username must be specified')

        # 61 seconds
        self.clock.advance(2)

        channel = self._post_json(body)
        self._assert_error(channel, 400, 'unrecognised nonce')

    def test_register_incorrect_nonce(self):
        """
        Only the provided nonce can be used, as it's checked in the MAC.
        """
        nonce = self._fetch_nonce()

        # Sign a different nonce from the one we send.
        channel = self._post_json({
            "nonce": nonce,
            "username": "******",
            "password": "******",
            "admin": True,
            "mac": self._mac(b"notthenonce"),
        })

        self._assert_error(channel, 403, "HMAC incorrect")

    def test_register_correct_nonce(self):
        """
        When the correct nonce is provided, and the right key is provided, the
        user is registered.
        """
        nonce = self._fetch_nonce()

        channel = self._post_json({
            "nonce": nonce,
            "username": "******",
            "password": "******",
            "admin": True,
            "mac": self._mac(nonce.encode('ascii')),
        })

        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
        self.assertEqual("@bob:test", channel.json_body["user_id"])

    def test_nonce_reuse(self):
        """
        A nonce can only be used once: after a successful registration, sending
        the same request again is rejected as an unrecognised nonce.
        """
        nonce = self._fetch_nonce()

        body = {
            "nonce": nonce,
            "username": "******",
            "password": "******",
            "admin": True,
            "mac": self._mac(nonce.encode('ascii')),
        }

        channel = self._post_json(body)
        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
        self.assertEqual("@bob:test", channel.json_body["user_id"])

        # Now, try and reuse it
        channel = self._post_json(body)
        self._assert_error(channel, 400, 'unrecognised nonce')

    def test_missing_parts(self):
        """
        Synapse will complain if you don't give nonce, username, password, and
        mac.  Admin is optional.  Additional checks are done for length and
        type.
        """
        #
        # Nonce check
        #

        # Must be present
        channel = self._post_json({})
        self._assert_error(channel, 400, 'nonce must be specified')

        #
        # Username checks
        #

        # Must be present
        channel = self._post_json({"nonce": self._fetch_nonce()})
        self._assert_error(channel, 400, 'username must be specified')

        # Must be a string
        channel = self._post_json({"nonce": self._fetch_nonce(), "username": 1234})
        self._assert_error(channel, 400, 'Invalid username')

        # Must not have null bytes
        channel = self._post_json(
            {"nonce": self._fetch_nonce(), "username": u"abcd\u0000"}
        )
        self._assert_error(channel, 400, 'Invalid username')

        # Must not be too long
        channel = self._post_json(
            {"nonce": self._fetch_nonce(), "username": "******" * 1000}
        )
        self._assert_error(channel, 400, 'Invalid username')

        #
        # Password checks
        #

        # Must be present
        channel = self._post_json({"nonce": self._fetch_nonce(), "username": "******"})
        self._assert_error(channel, 400, 'password must be specified')

        # Must be a string
        channel = self._post_json({
            "nonce": self._fetch_nonce(),
            "username": "******",
            "password": 1234
        })
        self._assert_error(channel, 400, 'Invalid password')

        # Must not have null bytes
        channel = self._post_json({
            "nonce": self._fetch_nonce(),
            "username": "******",
            "password": u"abcd\u0000"
        })
        self._assert_error(channel, 400, 'Invalid password')

        # Super long
        channel = self._post_json({
            "nonce": self._fetch_nonce(),
            "username": "******",
            "password": "******" * 1000
        })
        self._assert_error(channel, 400, 'Invalid password')
def setUp(self):
    """Create a homeserver with a public room and inject a remote join into it.

    A room is created by ``@us:test``. A join event for
    ``@baduser:test.serv`` is then received over federation on top of the
    room's latest event. Auth and signature checks are stubbed out, and the
    test asserts that the join became the room's latest event.
    """
    self.http_client = Mock()
    self.reactor = ThreadedMemoryReactorClock()
    self.hs_clock = Clock(self.reactor)
    self.homeserver = setup_test_homeserver(
        self.addCleanup,
        federation_http_client=self.http_client,
        clock=self.hs_clock,
        reactor=self.reactor,
    )

    user_id = UserID("us", "test")
    our_user = create_requester(user_id)
    room_creator = self.homeserver.get_room_creation_handler()
    self.room_id = self.get_success(
        room_creator.create_room(
            our_user, room_creator._presets_dict["public_chat"], ratelimit=False
        )
    )[0]["room_id"]

    self.store = self.homeserver.get_datastore()

    # Figure out what the most recent event is
    most_recent = self.get_success(
        self.store.get_latest_event_ids_in_room(self.room_id)
    )[0]

    join_event = make_event_from_dict(
        {
            "room_id": self.room_id,
            "sender": "@baduser:test.serv",
            "state_key": "@baduser:test.serv",
            "event_id": "$join:test.serv",
            "depth": 1000,
            "origin_server_ts": 1,
            "type": "m.room.member",
            "origin": "test.servx",
            "content": {"membership": "join"},
            "auth_events": [],
            "prev_state": [(most_recent, {})],
            "prev_events": [(most_recent, {})],
        }
    )

    self.handler = self.homeserver.get_federation_handler()
    # Stub do_auth so it simply resolves to the context it was given.
    self.handler.do_auth = lambda origin, event, context, auth_events: succeed(
        context
    )
    self.client = self.homeserver.get_federation_client()
    # Stub signature/hash checking so fetched PDUs are returned unchanged.
    self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(
        pdus
    )

    # Send the join, it should return None (which is not an error)
    self.assertIsNone(
        self.get_success(
            self.handler.on_receive_pdu(
                "test.serv", join_event, sent_to_us_directly=True
            )
        )
    )

    # Make sure we actually joined the room
    self.assertEqual(
        self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))[0],
        "$join:test.serv",
    )
def setUp(self):
    """Build a test homeserver driven by a fresh in-memory reactor clock."""
    reactor = ThreadedMemoryReactorClock()
    self.reactor = reactor
    self.hs_clock = Clock(reactor)
    self.homeserver = setup_test_homeserver(
        self.addCleanup,
        http_client=None,
        clock=self.hs_clock,
        reactor=reactor,
    )
class JsonResourceTests(unittest.TestCase):
    """Tests for how JsonResource dispatches requests and maps errors to responses."""

    def setUp(self):
        self.reactor = ThreadedMemoryReactorClock()
        self.hs_clock = Clock(self.reactor)
        self.homeserver = setup_test_homeserver(
            self.addCleanup,
            http_client=None,
            clock=self.hs_clock,
            reactor=self.reactor,
        )

    def _render_get(self, resource, path):
        """GET *path* on *resource*, render it, and return ``(request, channel)``."""
        request, channel = make_request(self.reactor, b"GET", path)
        render(request, resource, self.reactor)
        return request, channel

    def test_handler_for_request(self):
        """
        JsonResource.handler_for_request gives correctly decoded URL args to
        the callback, while Twisted will give the raw bytes of URL query
        arguments.
        """
        captured = {}

        def _handler(request, **kwargs):
            captured.update(kwargs)
            return (200, kwargs)

        resource = JsonResource(self.homeserver)
        resource.register_paths(
            "GET", [re.compile("^/_matrix/foo/(?P<room_id>[^/]*)$")], _handler
        )

        request, _channel = self._render_get(
            resource, b"/_matrix/foo/%E2%98%83?a=%E2%98%83"
        )

        snowman = u"\N{SNOWMAN}"
        # Twisted leaves query args as raw bytes; path args are decoded.
        self.assertEqual(request.args, {b'a': [snowman.encode('utf8')]})
        self.assertEqual(captured, {u"room_id": snowman})

    def test_callback_direct_exception(self):
        """
        If the web callback raises an uncaught exception, it will be translated
        into a 500.
        """

        def _handler(request, **kwargs):
            raise Exception("boo")

        resource = JsonResource(self.homeserver)
        resource.register_paths("GET", [re.compile("^/_matrix/foo$")], _handler)

        _request, channel = self._render_get(resource, b"/_matrix/foo")

        self.assertEqual(channel.result["code"], b'500')

    def test_callback_indirect_exception(self):
        """
        If the web callback raises an uncaught exception in a Deferred, it will
        be translated into a 500.
        """

        def _explode(*args):
            raise Exception("boo")

        def _handler(request, **kwargs):
            deferred = Deferred()
            deferred.addCallback(_explode)
            # Fire later, so the failure happens asynchronously.
            self.reactor.callLater(1, deferred.callback, True)
            return make_deferred_yieldable(deferred)

        resource = JsonResource(self.homeserver)
        resource.register_paths("GET", [re.compile("^/_matrix/foo$")], _handler)

        _request, channel = self._render_get(resource, b"/_matrix/foo")

        self.assertEqual(channel.result["code"], b'500')

    def test_callback_synapseerror(self):
        """
        If the web callback raises a SynapseError, it returns the appropriate
        status code and message set in it.
        """

        def _handler(request, **kwargs):
            raise SynapseError(403, "Forbidden!!one!", Codes.FORBIDDEN)

        resource = JsonResource(self.homeserver)
        resource.register_paths("GET", [re.compile("^/_matrix/foo$")], _handler)

        _request, channel = self._render_get(resource, b"/_matrix/foo")

        self.assertEqual(channel.result["code"], b'403')
        self.assertEqual(channel.json_body["error"], "Forbidden!!one!")
        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")

    def test_no_handler(self):
        """
        If there is no handler to process the request, Synapse will return 400.
        """

        def _handler(request, **kwargs):
            """
            Not ever actually called!
            """
            self.fail("shouldn't ever get here")

        resource = JsonResource(self.homeserver)
        resource.register_paths("GET", [re.compile("^/_matrix/foo$")], _handler)

        # "/_matrix/foobar" does not match the anchored pattern above.
        _request, channel = self._render_get(resource, b"/_matrix/foobar")

        self.assertEqual(channel.result["code"], b'400')
        self.assertEqual(channel.json_body["error"], "Unrecognized request")
        self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
class JsonResourceTests(unittest.TestCase):
    """Exercise JsonResource routing and its translation of callback errors."""

    def setUp(self):
        self.reactor = ThreadedMemoryReactorClock()
        self.hs_clock = Clock(self.reactor)
        self.homeserver = setup_test_homeserver(
            self.addCleanup,
            http_client=None,
            clock=self.hs_clock,
            reactor=self.reactor,
        )

    def _resource_for(self, pattern, callback):
        """Return a JsonResource that routes GETs matching *pattern* to *callback*."""
        resource = JsonResource(self.homeserver)
        resource.register_paths("GET", [re.compile(pattern)], callback)
        return resource

    def test_handler_for_request(self):
        """
        JsonResource.handler_for_request gives correctly decoded URL args to
        the callback, while Twisted will give the raw bytes of URL query
        arguments.
        """
        seen_kwargs = {}

        def _on_get(request, **kwargs):
            seen_kwargs.update(kwargs)
            return (200, kwargs)

        resource = self._resource_for("^/_matrix/foo/(?P<room_id>[^/]*)$", _on_get)

        req, channel = make_request(
            self.reactor, b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83"
        )
        render(req, resource, self.reactor)

        expected_arg = "\N{SNOWMAN}".encode("utf8")
        self.assertEqual(req.args, {b"a": [expected_arg]})
        self.assertEqual(seen_kwargs, {"room_id": "\N{SNOWMAN}"})

    def test_callback_direct_exception(self):
        """
        If the web callback raises an uncaught exception, it will be translated
        into a 500.
        """

        def _on_get(request, **kwargs):
            raise Exception("boo")

        resource = self._resource_for("^/_matrix/foo$", _on_get)

        req, channel = make_request(self.reactor, b"GET", b"/_matrix/foo")
        render(req, resource, self.reactor)

        self.assertEqual(channel.result["code"], b"500")

    def test_callback_indirect_exception(self):
        """
        If the web callback raises an uncaught exception in a Deferred, it will
        be translated into a 500.
        """

        def _raise_boo(*args):
            raise Exception("boo")

        def _on_get(request, **kwargs):
            pending = Deferred()
            pending.addCallback(_raise_boo)
            self.reactor.callLater(1, pending.callback, True)
            return make_deferred_yieldable(pending)

        resource = self._resource_for("^/_matrix/foo$", _on_get)

        req, channel = make_request(self.reactor, b"GET", b"/_matrix/foo")
        render(req, resource, self.reactor)

        self.assertEqual(channel.result["code"], b"500")

    def test_callback_synapseerror(self):
        """
        If the web callback raises a SynapseError, it returns the appropriate
        status code and message set in it.
        """

        def _on_get(request, **kwargs):
            raise SynapseError(403, "Forbidden!!one!", Codes.FORBIDDEN)

        resource = self._resource_for("^/_matrix/foo$", _on_get)

        req, channel = make_request(self.reactor, b"GET", b"/_matrix/foo")
        render(req, resource, self.reactor)

        body = channel.json_body
        self.assertEqual(channel.result["code"], b"403")
        self.assertEqual(body["error"], "Forbidden!!one!")
        self.assertEqual(body["errcode"], "M_FORBIDDEN")

    def test_no_handler(self):
        """
        If there is no handler to process the request, Synapse will return 400.
        """

        def _on_get(request, **kwargs):
            """
            Not ever actually called!
            """
            self.fail("shouldn't ever get here")

        resource = self._resource_for("^/_matrix/foo$", _on_get)

        req, channel = make_request(self.reactor, b"GET", b"/_matrix/foobar")
        render(req, resource, self.reactor)

        body = channel.json_body
        self.assertEqual(channel.result["code"], b"400")
        self.assertEqual(body["error"], "Unrecognized request")
        self.assertEqual(body["errcode"], "M_UNRECOGNIZED")
def setUp(self):
    """Create a homeserver and a public room, then receive a remote join into it.

    The join for ``@baduser:test.serv`` is built on the room's latest event.
    ``do_auth`` and signature checking are stubbed out. The test asserts that
    the join is accepted and becomes the room's latest event.
    """
    self.http_client = Mock()
    self.reactor = ThreadedMemoryReactorClock()
    self.hs_clock = Clock(self.reactor)
    self.homeserver = setup_test_homeserver(
        self.addCleanup,
        http_client=self.http_client,
        clock=self.hs_clock,
        reactor=self.reactor,
    )

    def latest_event_ids():
        # maybeDeferred lets successResultOf read the store's result.
        return self.successResultOf(
            maybeDeferred(
                self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
            )
        )

    requester = Requester(UserID("us", "test"), None, False, None, None)
    creator = self.homeserver.get_room_creation_handler()
    pending_room = creator.create_room(
        requester, creator.PRESETS_DICT["public_chat"], ratelimit=False
    )
    self.reactor.advance(0.1)
    self.room_id = self.successResultOf(pending_room)["room_id"]

    # The join is placed on top of the newest event currently in the room.
    newest = latest_event_ids()[0]

    join_event = FrozenEvent(
        {
            "room_id": self.room_id,
            "sender": "@baduser:test.serv",
            "state_key": "@baduser:test.serv",
            "event_id": "$join:test.serv",
            "depth": 1000,
            "origin_server_ts": 1,
            "type": "m.room.member",
            "origin": "test.servx",
            "content": {"membership": "join"},
            "auth_events": [],
            "prev_state": [(newest, {})],
            "prev_events": [(newest, {})],
        }
    )

    self.handler = self.homeserver.get_handlers().federation_handler
    # do_auth always resolves to True in these tests.
    self.handler.do_auth = lambda *a, **b: succeed(True)
    self.client = self.homeserver.get_federation_client()
    # PDUs pass through signature/hash checking unchanged.
    self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(
        pdus
    )

    # Send the join, it should return None (which is not an error)
    pending_join = self.handler.on_receive_pdu(
        "test.serv", join_event, sent_to_us_directly=True
    )
    self.reactor.advance(1)
    self.assertEqual(self.successResultOf(pending_join), None)

    # Make sure we actually joined the room
    self.assertEqual(latest_event_ids()[0], "$join:test.serv")
def setUp(self) -> None:
    """Give each test its own fresh in-memory reactor clock."""
    fresh_reactor = ThreadedMemoryReactorClock()
    self.reactor = fresh_reactor
class MessageAcceptTests(unittest.TestCase):
    """Tests for how the federation handler treats incoming PDUs whose
    ancestry the sending server does not (fully) disclose."""

    def setUp(self):
        self.http_client = Mock()
        self.reactor = ThreadedMemoryReactorClock()
        self.hs_clock = Clock(self.reactor)
        self.homeserver = setup_test_homeserver(
            self.addCleanup,
            http_client=self.http_client,
            clock=self.hs_clock,
            reactor=self.reactor,
        )

        user_id = UserID("us", "test")
        our_user = Requester(user_id, None, False, None, None)
        room_creator = self.homeserver.get_room_creation_handler()
        room = room_creator.create_room(
            our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False
        )
        self.reactor.advance(0.1)
        self.room_id = self.successResultOf(room)["room_id"]

        # Figure out what the most recent event is
        most_recent = self._latest_event_ids()[0]

        join_event = FrozenEvent(
            {
                "room_id": self.room_id,
                "sender": "@baduser:test.serv",
                "state_key": "@baduser:test.serv",
                "event_id": "$join:test.serv",
                "depth": 1000,
                "origin_server_ts": 1,
                "type": "m.room.member",
                "origin": "test.servx",
                "content": {"membership": "join"},
                "auth_events": [],
                "prev_state": [(most_recent, {})],
                "prev_events": [(most_recent, {})],
            }
        )

        self.handler = self.homeserver.get_handlers().federation_handler
        # Stub do_auth so it always resolves to True.
        self.handler.do_auth = lambda *a, **b: succeed(True)
        self.client = self.homeserver.get_federation_client()
        # Stub signature/hash checking so fetched PDUs are returned unchanged.
        self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(
            pdus
        )

        # Send the join, it should return None (which is not an error)
        d = self.handler.on_receive_pdu(
            "test.serv", join_event, sent_to_us_directly=True
        )
        self.reactor.advance(1)
        self.assertIsNone(self.successResultOf(d))

        # Make sure we actually joined the room
        self.assertEqual(self._latest_event_ids()[0], "$join:test.serv")

    def _latest_event_ids(self):
        """Return the result of ``get_latest_event_ids_in_room(self.room_id)``.

        The store call is wrapped in ``maybeDeferred`` so that its result can
        be read synchronously with ``successResultOf``.
        """
        return self.successResultOf(
            maybeDeferred(
                self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
            )
        )

    def test_cant_hide_direct_ancestors(self):
        """
        If you send a message, you must be able to provide the direct
        prev_events that said event references.
        """

        def post_json(destination, path, data, headers=None, timeout=0):
            # If it asks us for new missing events, give them NOTHING
            if path.startswith("/_matrix/federation/v1/get_missing_events/"):
                return {"events": []}

        self.http_client.post_json = post_json

        # Figure out what the most recent event is
        most_recent = self._latest_event_ids()[0]

        # Now lie about an event
        lying_event = FrozenEvent(
            {
                "room_id": self.room_id,
                "sender": "@baduser:test.serv",
                "event_id": "one:test.serv",
                "depth": 1000,
                "origin_server_ts": 1,
                "type": "m.room.message",
                "origin": "test.serv",
                "content": "hewwo?",
                "auth_events": [],
                "prev_events": [("two:test.serv", {}), (most_recent, {})],
            }
        )

        d = self.handler.on_receive_pdu(
            "test.serv", lying_event, sent_to_us_directly=True
        )

        # Step the reactor, so the database fetches come back
        self.reactor.advance(1)

        # on_receive_pdu should throw an error
        failure = self.failureResultOf(d)
        self.assertEqual(
            failure.value.args[0],
            (
                "ERROR 403: Your server isn't divulging details about prev_events "
                "referenced in this event."
            ),
        )

        # Make sure the invalid event isn't there
        self.assertEqual(self._latest_event_ids()[0], "$join:test.serv")

    def test_cant_hide_past_history(self):
        """
        If you send a message whose prev_events we cannot fully walk back, the
        event is still accepted, and our own membership is still present in the
        resulting room state.

        Here "three" is returned via get_missing_events, but its own prev_event
        "four" is never supplied. The state is instead taken from the remote's
        /state_ids response, and it must still include ``@us:test``'s
        membership.
        """

        def post_json(destination, path, data, headers=None, timeout=0):
            if path.startswith("/_matrix/federation/v1/get_missing_events/"):
                return {
                    "events": [
                        {
                            "room_id": self.room_id,
                            "sender": "@baduser:test.serv",
                            "event_id": "three:test.serv",
                            "depth": 1000,
                            "origin_server_ts": 1,
                            "type": "m.room.message",
                            "origin": "test.serv",
                            "content": "hewwo?",
                            "auth_events": [],
                            "prev_events": [("four:test.serv", {})],
                        }
                    ]
                }

        self.http_client.post_json = post_json

        def get_json(destination, path, args, headers=None):
            if path.startswith("/_matrix/federation/v1/state_ids/"):
                state_ids = self.successResultOf(
                    self.homeserver.datastore.get_state_ids_for_event("one:test.serv")
                )

                return succeed(
                    {
                        "pdu_ids": [
                            event_id
                            for key, event_id in state_ids.items()
                            if key == ("m.room.member", "@us:test")
                        ],
                        "auth_chain_ids": list(state_ids.values()),
                    }
                )

        self.http_client.get_json = get_json

        # Figure out what the most recent event is
        most_recent = self._latest_event_ids()[0]

        # Make a good event
        good_event = FrozenEvent(
            {
                "room_id": self.room_id,
                "sender": "@baduser:test.serv",
                "event_id": "one:test.serv",
                "depth": 1000,
                "origin_server_ts": 1,
                "type": "m.room.message",
                "origin": "test.serv",
                "content": "hewwo?",
                "auth_events": [],
                "prev_events": [(most_recent, {})],
            }
        )

        d = self.handler.on_receive_pdu(
            "test.serv", good_event, sent_to_us_directly=True
        )
        self.reactor.advance(1)
        self.assertIsNone(self.successResultOf(d))

        bad_event = FrozenEvent(
            {
                "room_id": self.room_id,
                "sender": "@baduser:test.serv",
                "event_id": "two:test.serv",
                "depth": 1000,
                "origin_server_ts": 1,
                "type": "m.room.message",
                "origin": "test.serv",
                "content": "hewwo?",
                "auth_events": [],
                "prev_events": [("one:test.serv", {}), ("three:test.serv", {})],
            }
        )

        d = self.handler.on_receive_pdu(
            "test.serv", bad_event, sent_to_us_directly=True
        )
        self.reactor.advance(1)

        self.assertEqual(self._latest_event_ids()[0], "two:test.serv")

        state = self.homeserver.get_state_handler().get_current_state_ids(
            self.room_id
        )
        self.reactor.advance(1)
        self.assertIn(
            ("m.room.member", "@us:test"), self.successResultOf(state).keys()
        )
class JsonResourceTests(unittest.TestCase):
    """Tests for how `JsonResource` dispatches requests to registered callbacks
    and turns callback results and exceptions into HTTP responses."""

    def setUp(self) -> None:
        self.reactor = ThreadedMemoryReactorClock()
        self.hs_clock = Clock(self.reactor)
        self.homeserver = setup_test_homeserver(
            self.addCleanup,
            federation_http_client=None,
            clock=self.hs_clock,
            reactor=self.reactor,
        )

    def test_handler_for_request(self) -> None:
        """
        JsonResource.handler_for_request gives correctly decoded URL args to
        the callback, while Twisted will give the raw bytes of URL query
        arguments.
        """
        got_kwargs: Dict[str, object] = {}

        def _callback(
            request: SynapseRequest, **kwargs: object
        ) -> Tuple[int, Dict[str, object]]:
            got_kwargs.update(kwargs)
            return 200, kwargs

        res = JsonResource(self.homeserver)
        res.register_paths(
            "GET",
            [re.compile("^/_matrix/foo/(?P<room_id>[^/]*)$")],
            _callback,
            "test_servlet",
        )

        make_request(
            self.reactor,
            FakeSite(res, self.reactor),
            b"GET",
            b"/_matrix/foo/%E2%98%83?a=%E2%98%83",
        )

        # Only the (decoded) path parameter reaches the callback as a kwarg;
        # the query-string argument `a` does not.
        self.assertEqual(got_kwargs, {"room_id": "\N{SNOWMAN}"})

    def test_callback_direct_exception(self) -> None:
        """
        If the web callback raises an uncaught exception, it will be translated
        into a 500.
        """

        def _callback(request: SynapseRequest, **kwargs: object) -> NoReturn:
            raise Exception("boo")

        res = JsonResource(self.homeserver)
        res.register_paths(
            "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
        )

        channel = make_request(
            self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo"
        )

        self.assertEqual(channel.result["code"], b"500")

    def test_callback_indirect_exception(self) -> None:
        """
        If the web callback raises an uncaught exception in a Deferred, it will
        be translated into a 500.
        """

        def _throw(*args: object) -> NoReturn:
            raise Exception("boo")

        def _callback(request: SynapseRequest, **kwargs: object) -> "Deferred[None]":
            d: "Deferred[None]" = Deferred()
            d.addCallback(_throw)
            # Fire the Deferred later, so the failure happens asynchronously
            # rather than during the initial call to the callback.
            self.reactor.callLater(0.5, d.callback, True)
            return make_deferred_yieldable(d)

        res = JsonResource(self.homeserver)
        res.register_paths(
            "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
        )

        channel = make_request(
            self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo"
        )

        self.assertEqual(channel.result["code"], b"500")

    def test_callback_synapseerror(self) -> None:
        """
        If the web callback raises a SynapseError, it returns the appropriate
        status code and message set in it.
        """

        def _callback(request: SynapseRequest, **kwargs: object) -> NoReturn:
            raise SynapseError(403, "Forbidden!!one!", Codes.FORBIDDEN)

        res = JsonResource(self.homeserver)
        res.register_paths(
            "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
        )

        channel = make_request(
            self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo"
        )

        self.assertEqual(channel.result["code"], b"403")
        self.assertEqual(channel.json_body["error"], "Forbidden!!one!")
        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")

    def test_no_handler(self) -> None:
        """
        If there is no handler to process the request, Synapse will return 400.
        """

        def _callback(request: SynapseRequest, **kwargs: object) -> None:
            """
            Not ever actually called!
            """
            self.fail("shouldn't ever get here")

        res = JsonResource(self.homeserver)
        res.register_paths(
            "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
        )

        # The registered pattern is anchored, so /_matrix/foobar does not match.
        channel = make_request(
            self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foobar"
        )

        self.assertEqual(channel.result["code"], b"400")
        self.assertEqual(channel.json_body["error"], "Unrecognized request")
        self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")

    def test_head_request(self) -> None:
        """
        A HEAD request to a path that was registered only for GET is answered
        with a 200 and no response body.
        """

        def _callback(
            request: SynapseRequest, **kwargs: object
        ) -> Tuple[int, Dict[str, object]]:
            return 200, {"result": True}

        res = JsonResource(self.homeserver)
        res.register_paths(
            "GET",
            [re.compile("^/_matrix/foo$")],
            _callback,
            "test_servlet",
        )

        # The path was registered as GET, but this is a HEAD request.
        channel = make_request(
            self.reactor, FakeSite(res, self.reactor), b"HEAD", b"/_matrix/foo"
        )

        self.assertEqual(channel.result["code"], b"200")
        self.assertNotIn("body", channel.result)
def setUp(self):
    """Build a homeserver with a public room created by @us:test, deliver a
    federated join from @baduser:test.serv, and check that the join ends up
    as the room's latest event."""
    self.http_client = Mock()
    self.reactor = ThreadedMemoryReactorClock()
    self.hs_clock = Clock(self.reactor)
    self.homeserver = setup_test_homeserver(
        self.addCleanup,
        http_client=self.http_client,
        clock=self.hs_clock,
        reactor=self.reactor,
    )

    def latest_event_id():
        # First entry of the room's latest event ids, as reported by the store.
        pending = maybeDeferred(
            self.homeserver.datastore.get_latest_event_ids_in_room,
            self.room_id,
        )
        return self.successResultOf(pending)[0]

    # Create a public room as a local user, with ratelimiting disabled.
    local_user = UserID("us", "test")
    requester = Requester(local_user, None, False, None, None)
    creator = self.homeserver.get_room_creation_handler()
    creation = creator.create_room(
        requester, creator.PRESETS_DICT["public_chat"], ratelimit=False
    )
    self.reactor.advance(0.1)
    self.room_id = self.successResultOf(creation)["room_id"]

    # The join will reference whatever is currently the room's newest event.
    parent_id = latest_event_id()

    join_event = FrozenEvent(
        {
            "room_id": self.room_id,
            "sender": "@baduser:test.serv",
            "state_key": "@baduser:test.serv",
            "event_id": "$join:test.serv",
            "depth": 1000,
            "origin_server_ts": 1,
            "type": "m.room.member",
            "origin": "test.servx",
            "content": {"membership": "join"},
            "auth_events": [],
            "prev_state": [(parent_id, {})],
            "prev_events": [(parent_id, {})],
        }
    )

    # Stub out auth (always succeeds) and signature/hash checking (hands the
    # PDUs back unchanged).
    self.handler = self.homeserver.get_handlers().federation_handler
    self.handler.do_auth = lambda *args, **kwargs: succeed(True)
    self.client = self.homeserver.get_federation_client()
    self.client._check_sigs_and_hash_and_fetch = (
        lambda dest, pdus, **kwargs: succeed(pdus)
    )

    # Receiving the join should resolve to None, which is not an error.
    received = self.handler.on_receive_pdu(
        "test.serv", join_event, sent_to_us_directly=True
    )
    self.reactor.advance(1)
    self.assertEqual(self.successResultOf(received), None)

    # The join must now be the room's latest event.
    self.assertEqual(latest_event_id(), "$join:test.serv")
def _test_disconnect(
    self,
    reactor: ThreadedMemoryReactorClock,
    channel: FakeChannel,
    expect_cancellation: bool,
    expected_body: Union[bytes, JsonDict],
    expected_code: Optional[int] = None,
) -> None:
    """Disconnect an in-flight request, then verify the response it produced.

    Args:
        reactor: The twisted reactor running the request handler.
        channel: The `FakeChannel` for the request.
        expect_cancellation: `True` if disconnecting should cancel request
            processing, `False` if the request should run to completion anyway.
        expected_body: The response body the request should produce.
        expected_code: The status code the request should produce. When omitted,
            `200` or `499` is used depending on `expect_cancellation`.
    """
    if expected_code is None:
        expected_code = (
            HTTP_STATUS_REQUEST_CANCELLED if expect_cancellation else HTTPStatus.OK
        )

    request = channel.request
    self.assertFalse(
        channel.is_finished(),
        "Request finished before we could disconnect - "
        "was `await_result=False` passed to `make_request`?",
    )

    # Disconnecting the request also disconnects the channel, so the response is
    # captured by wrapping the matching respond_* helper in a mock instead.
    respond_method: Callable[..., Any] = (
        respond_with_html_bytes
        if isinstance(expected_body, bytes)
        else respond_with_json
    )
    patch_target = f"synapse.http.server.{respond_method.__name__}"

    with mock.patch(patch_target, wraps=respond_method) as respond_mock:
        request.connectionLost(reason=ConnectionDone())

        if not expect_cancellation:
            # Nothing may be sent yet; advance the clock so the handler runs to
            # completion.
            respond_mock.assert_not_called()
            reactor.pump([1.0])

        # Whether cancelled or completed, exactly one response has now been sent.
        respond_mock.assert_called_once()
        positional, _kwargs = respond_mock.call_args
        sent_code, sent_body = positional[1], positional[2]
        self.assertEqual(sent_code, expected_code)
        self.assertEqual(request.code, expected_code)
        self.assertEqual(sent_body, expected_body)
class MatrixFederationAgentTests(TestCase):
    """Tests for `ProxyAgent`: direct connections, `no_proxy` handling, requests
    through HTTP and HTTPS proxies (with and without credentials), IP
    blacklisting, and parsing of the proxy URL scheme."""

    def setUp(self):
        self.reactor = ThreadedMemoryReactorClock()

    def _make_connection(
        self,
        client_factory: IProtocolFactory,
        server_factory: IProtocolFactory,
        ssl: bool = False,
        expected_sni: Optional[bytes] = None,
        tls_sanlist: Optional[Iterable[bytes]] = None,
    ) -> IProtocol:
        """Builds a test server, and completes the outgoing client connection

        Args:
            client_factory: the factory that the application is trying to use
                to make the outbound connection. We will invoke it to build the
                client Protocol
            server_factory: a factory to build the server-side protocol
            ssl: If true, we will expect an ssl connection and wrap
                server_factory with a TLSMemoryBIOFactory
            expected_sni: the expected SNI value. Only meaningful with ssl=True.
            tls_sanlist: list of SAN entries for the TLS cert presented by the
                server. Defaults to [b'DNS:test.com']

        Returns:
            the server-side application Protocol built by server_factory. When
            ssl is set, this is the protocol wrapped inside the TLS layer.
        """
        if ssl:
            server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist)

        server_protocol = server_factory.buildProtocol(None)

        # now, tell the client protocol factory to build the client protocol,
        # and wire the output of said protocol up to the server via
        # a FakeTransport.
        #
        # Normally this would be done by the TCP socket code in Twisted, but we are
        # stubbing that out here.
        client_protocol = client_factory.buildProtocol(None)
        client_protocol.makeConnection(
            FakeTransport(server_protocol, self.reactor, client_protocol)
        )

        # tell the server protocol to send its stuff back to the client, too
        server_protocol.makeConnection(
            FakeTransport(client_protocol, self.reactor, server_protocol)
        )

        if ssl:
            http_protocol = server_protocol.wrappedProtocol
            tls_connection = server_protocol._tlsConnection
        else:
            http_protocol = server_protocol
            tls_connection = None

        # give the reactor a pump to get the TLS juices flowing (if needed)
        self.reactor.advance(0)

        if expected_sni is not None:
            # Fail clearly, rather than with an AttributeError on None, if a
            # caller asks for an SNI check on a plaintext connection.
            self.assertIsNotNone(tls_connection, "expected_sni requires ssl=True")
            server_name = tls_connection.get_servername()
            self.assertEqual(
                server_name,
                expected_sni,
                f"Expected SNI {expected_sni!s} but got {server_name!s}",
            )

        return http_protocol

    def _check_proxy_auth_header(self, request, expected_auth_credentials):
        """Assert that `request` carries the expected Proxy-Authorization header.

        Args:
            request: a request received by one of the test servers
            expected_auth_credentials: the credentials (e.g. b"bob:pinkponies")
                that should have been sent using HTTP Basic auth, or None if no
                Proxy-Authorization header should be present at all
        """
        proxy_auth_header_values = request.requestHeaders.getRawHeaders(
            b"Proxy-Authorization"
        )

        if expected_auth_credentials is None:
            # Check that the Proxy-Authorization header has not been supplied
            self.assertIsNone(proxy_auth_header_values)
            return

        # Compute the correct header value for Proxy-Authorization
        encoded_credentials = base64.b64encode(expected_auth_credentials)
        expected_header_value = b"Basic " + encoded_credentials

        # Validate the header's value. Check for presence first, so that a missing
        # header fails with a clear message instead of a TypeError from `in None`.
        self.assertIsNotNone(
            proxy_auth_header_values, "no Proxy-Authorization header was sent"
        )
        self.assertIn(expected_header_value, proxy_auth_header_values)

    def _test_request_direct_connection(
        self,
        agent: ProxyAgent,
        scheme: bytes,
        hostname: bytes,
        path: bytes,
    ):
        """Runs a test case for a direct connection not going through a proxy.

        Args:
            agent: the proxy agent being tested
            scheme: expected to be either b"http" or b"https"
            hostname: the hostname to connect to in the test
            path: the path to connect to in the test
        """
        is_https = scheme == b"https"

        self.reactor.lookups[hostname.decode()] = "1.2.3.4"
        d = agent.request(b"GET", scheme + b"://" + hostname + b"/" + path)

        # there should be a pending TCP connection
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, "1.2.3.4")
        self.assertEqual(port, 443 if is_https else 80)

        # make a test server, and wire up the client
        http_server = self._make_connection(
            client_factory,
            _get_test_protocol_factory(),
            ssl=is_https,
            expected_sni=hostname if is_https else None,
        )

        # the FakeTransport is async, so we need to pump the reactor
        self.reactor.advance(0)

        # now there should be a pending request
        self.assertEqual(len(http_server.requests), 1)

        request = http_server.requests[0]
        self.assertEqual(request.method, b"GET")
        self.assertEqual(request.path, b"/" + path)
        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [hostname])
        request.write(b"result")
        request.finish()

        self.reactor.advance(0)

        resp = self.successResultOf(d)
        body = self.successResultOf(treq.content(resp))
        self.assertEqual(body, b"result")

    def test_http_request(self):
        agent = ProxyAgent(self.reactor)
        self._test_request_direct_connection(agent, b"http", b"test.com", b"")

    def test_https_request(self):
        agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy())
        self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")

    def test_http_request_use_proxy_empty_environment(self):
        agent = ProxyAgent(self.reactor, use_proxy=True)
        self._test_request_direct_connection(agent, b"http", b"test.com", b"")

    @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "NO_PROXY": "test.com"})
    def test_http_request_via_uppercase_no_proxy(self):
        agent = ProxyAgent(self.reactor, use_proxy=True)
        self._test_request_direct_connection(agent, b"http", b"test.com", b"")

    @patch.dict(
        os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "test.com,unused.com"}
    )
    def test_http_request_via_no_proxy(self):
        agent = ProxyAgent(self.reactor, use_proxy=True)
        self._test_request_direct_connection(agent, b"http", b"test.com", b"")

    @patch.dict(
        os.environ, {"https_proxy": "proxy.com", "no_proxy": "test.com,unused.com"}
    )
    def test_https_request_via_no_proxy(self):
        agent = ProxyAgent(
            self.reactor,
            contextFactory=get_test_https_policy(),
            use_proxy=True,
        )
        self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")

    @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "*"})
    def test_http_request_via_no_proxy_star(self):
        agent = ProxyAgent(self.reactor, use_proxy=True)
        self._test_request_direct_connection(agent, b"http", b"test.com", b"")

    @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "*"})
    def test_https_request_via_no_proxy_star(self):
        agent = ProxyAgent(
            self.reactor,
            contextFactory=get_test_https_policy(),
            use_proxy=True,
        )
        self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")

    @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "unused.com"})
    def test_http_request_via_proxy(self):
        """
        Tests that requests can be made through a proxy.
        """
        self._do_http_request_via_proxy(
            expect_proxy_ssl=False, expected_auth_credentials=None
        )

    @patch.dict(
        os.environ,
        {"http_proxy": "bob:pinkponies@proxy.com:8888", "no_proxy": "unused.com"},
    )
    def test_http_request_via_proxy_with_auth(self):
        """
        Tests that authenticated requests can be made through a proxy.
        """
        self._do_http_request_via_proxy(
            expect_proxy_ssl=False, expected_auth_credentials=b"bob:pinkponies"
        )

    @patch.dict(
        os.environ, {"http_proxy": "https://proxy.com:8888", "no_proxy": "unused.com"}
    )
    def test_http_request_via_https_proxy(self):
        """Tests that requests can be made through a proxy reached over https."""
        self._do_http_request_via_proxy(
            expect_proxy_ssl=True, expected_auth_credentials=None
        )

    @patch.dict(
        os.environ,
        {
            "http_proxy": "https://bob:pinkponies@proxy.com:8888",
            "no_proxy": "unused.com",
        },
    )
    def test_http_request_via_https_proxy_with_auth(self):
        """Tests that authenticated requests can be made through a proxy reached
        over https."""
        self._do_http_request_via_proxy(
            expect_proxy_ssl=True, expected_auth_credentials=b"bob:pinkponies"
        )

    @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"})
    def test_https_request_via_proxy(self):
        """Tests that TLS-encrypted requests can be made through a proxy"""
        self._do_https_request_via_proxy(
            expect_proxy_ssl=False, expected_auth_credentials=None
        )

    @patch.dict(
        os.environ,
        {"https_proxy": "bob:pinkponies@proxy.com", "no_proxy": "unused.com"},
    )
    def test_https_request_via_proxy_with_auth(self):
        """Tests that authenticated, TLS-encrypted requests can be made through a proxy"""
        self._do_https_request_via_proxy(
            expect_proxy_ssl=False, expected_auth_credentials=b"bob:pinkponies"
        )

    @patch.dict(
        os.environ, {"https_proxy": "https://proxy.com", "no_proxy": "unused.com"}
    )
    def test_https_request_via_https_proxy(self):
        """Tests that TLS-encrypted requests can be made through a proxy reached
        over https"""
        self._do_https_request_via_proxy(
            expect_proxy_ssl=True, expected_auth_credentials=None
        )

    @patch.dict(
        os.environ,
        {"https_proxy": "https://bob:pinkponies@proxy.com", "no_proxy": "unused.com"},
    )
    def test_https_request_via_https_proxy_with_auth(self):
        """Tests that authenticated, TLS-encrypted requests can be made through a
        proxy reached over https"""
        self._do_https_request_via_proxy(
            expect_proxy_ssl=True, expected_auth_credentials=b"bob:pinkponies"
        )

    def _do_http_request_via_proxy(
        self,
        expect_proxy_ssl: bool = False,
        expected_auth_credentials: Optional[bytes] = None,
    ):
        """Send a http request via an agent and check that it is correctly
        received at the proxy. The proxy can use either http or https.

        Args:
            expect_proxy_ssl: True if we expect the request to connect via https
                to the proxy
            expected_auth_credentials: credentials to authenticate at the proxy
        """
        if expect_proxy_ssl:
            agent = ProxyAgent(
                self.reactor, use_proxy=True, contextFactory=get_test_https_policy()
            )
        else:
            agent = ProxyAgent(self.reactor, use_proxy=True)

        self.reactor.lookups["proxy.com"] = "1.2.3.5"
        d = agent.request(b"GET", b"http://test.com")

        # there should be a pending TCP connection
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, "1.2.3.5")
        self.assertEqual(port, 8888)

        # make a test server, and wire up the client
        http_server = self._make_connection(
            client_factory,
            _get_test_protocol_factory(),
            ssl=expect_proxy_ssl,
            tls_sanlist=[b"DNS:proxy.com"] if expect_proxy_ssl else None,
            expected_sni=b"proxy.com" if expect_proxy_ssl else None,
        )

        # the FakeTransport is async, so we need to pump the reactor
        self.reactor.advance(0)

        # now there should be a pending request
        self.assertEqual(len(http_server.requests), 1)

        request = http_server.requests[0]

        # Check whether auth credentials have been supplied to the proxy
        self._check_proxy_auth_header(request, expected_auth_credentials)

        # A plain-http request through a proxy uses the absolute URL as the path.
        self.assertEqual(request.method, b"GET")
        self.assertEqual(request.path, b"http://test.com")
        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
        request.write(b"result")
        request.finish()

        self.reactor.advance(0)

        resp = self.successResultOf(d)
        body = self.successResultOf(treq.content(resp))
        self.assertEqual(body, b"result")

    def _do_https_request_via_proxy(
        self,
        expect_proxy_ssl: bool = False,
        expected_auth_credentials: Optional[bytes] = None,
    ):
        """Send a https request via an agent and check that it is correctly
        received at the proxy and client. The proxy can use either http or https.

        Args:
            expect_proxy_ssl: True if we expect the request to connect via https
                to the proxy
            expected_auth_credentials: credentials to authenticate at the proxy
        """
        agent = ProxyAgent(
            self.reactor,
            contextFactory=get_test_https_policy(),
            use_proxy=True,
        )

        self.reactor.lookups["proxy.com"] = "1.2.3.5"
        d = agent.request(b"GET", b"https://test.com/abc")

        # there should be a pending TCP connection
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, "1.2.3.5")
        self.assertEqual(port, 1080)

        # make a test server to act as the proxy, and wire up the client
        proxy_server = self._make_connection(
            client_factory,
            _get_test_protocol_factory(),
            ssl=expect_proxy_ssl,
            tls_sanlist=[b"DNS:proxy.com"] if expect_proxy_ssl else None,
            expected_sni=b"proxy.com" if expect_proxy_ssl else None,
        )
        assert isinstance(proxy_server, HTTPChannel)

        # now there should be a pending CONNECT request
        self.assertEqual(len(proxy_server.requests), 1)

        request = proxy_server.requests[0]
        self.assertEqual(request.method, b"CONNECT")
        self.assertEqual(request.path, b"test.com:443")

        # Check whether auth credentials have been supplied to the proxy
        self._check_proxy_auth_header(request, expected_auth_credentials)

        # tell the proxy server not to close the connection
        proxy_server.persistent = True

        request.finish()

        # now we make another test server to act as the upstream HTTP server.
        server_ssl_protocol = _wrap_server_factory_for_tls(
            _get_test_protocol_factory()
        ).buildProtocol(None)

        # Tell the HTTP server to send outgoing traffic back via the proxy's transport.
        proxy_server_transport = proxy_server.transport
        server_ssl_protocol.makeConnection(proxy_server_transport)

        # ... and replace the protocol on the proxy's transport with the
        # TLSMemoryBIOProtocol for the test server, so that incoming traffic
        # to the proxy gets sent over to the HTTP(s) server.
        #
        # This needs a bit of gut-wrenching, which is different depending on whether
        # the proxy is using TLS or not.
        #
        # (an alternative, possibly more elegant, approach would be to use a custom
        # Protocol to implement the proxy, which starts out by forwarding to an
        # HTTPChannel (to implement the CONNECT command) and can then be switched
        # into a mode where it forwards its traffic to another Protocol.)
        if expect_proxy_ssl:
            assert isinstance(proxy_server_transport, TLSMemoryBIOProtocol)
            proxy_server_transport.wrappedProtocol = server_ssl_protocol
        else:
            assert isinstance(proxy_server_transport, FakeTransport)
            client_protocol = proxy_server_transport.other
            c2s_transport = client_protocol.transport
            c2s_transport.other = server_ssl_protocol

        self.reactor.advance(0)

        server_name = server_ssl_protocol._tlsConnection.get_servername()
        expected_sni = b"test.com"
        self.assertEqual(
            server_name,
            expected_sni,
            f"Expected SNI {expected_sni!s} but got {server_name!s}",
        )

        # now there should be a pending request
        http_server = server_ssl_protocol.wrappedProtocol
        self.assertEqual(len(http_server.requests), 1)

        request = http_server.requests[0]
        self.assertEqual(request.method, b"GET")
        self.assertEqual(request.path, b"/abc")
        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])

        # Check that the destination server DID NOT receive proxy credentials
        self._check_proxy_auth_header(request, None)

        request.write(b"result")
        request.finish()

        self.reactor.advance(0)

        resp = self.successResultOf(d)
        body = self.successResultOf(treq.content(resp))
        self.assertEqual(body, b"result")

    @patch.dict(os.environ, {"http_proxy": "proxy.com:8888"})
    def test_http_request_via_proxy_with_blacklist(self):
        # The blacklist includes the configured proxy IP, yet the test below
        # expects the connection to the proxy to go ahead.
        agent = ProxyAgent(
            BlacklistingReactorWrapper(
                self.reactor, ip_whitelist=None, ip_blacklist=IPSet(["1.0.0.0/8"])
            ),
            self.reactor,
            use_proxy=True,
        )

        self.reactor.lookups["proxy.com"] = "1.2.3.5"
        d = agent.request(b"GET", b"http://test.com")

        # there should be a pending TCP connection
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, "1.2.3.5")
        self.assertEqual(port, 8888)

        # make a test server, and wire up the client
        http_server = self._make_connection(
            client_factory, _get_test_protocol_factory()
        )

        # the FakeTransport is async, so we need to pump the reactor
        self.reactor.advance(0)

        # now there should be a pending request
        self.assertEqual(len(http_server.requests), 1)

        request = http_server.requests[0]
        self.assertEqual(request.method, b"GET")
        self.assertEqual(request.path, b"http://test.com")
        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
        request.write(b"result")
        request.finish()

        self.reactor.advance(0)

        resp = self.successResultOf(d)
        body = self.successResultOf(treq.content(resp))
        self.assertEqual(body, b"result")

    @patch.dict(os.environ, {"HTTPS_PROXY": "proxy.com"})
    def test_https_request_via_uppercase_proxy_with_blacklist(self):
        # The blacklist includes the configured proxy IP, yet the test below
        # expects the connection to the proxy to go ahead.
        agent = ProxyAgent(
            BlacklistingReactorWrapper(
                self.reactor, ip_whitelist=None, ip_blacklist=IPSet(["1.0.0.0/8"])
            ),
            self.reactor,
            contextFactory=get_test_https_policy(),
            use_proxy=True,
        )

        self.reactor.lookups["proxy.com"] = "1.2.3.5"
        d = agent.request(b"GET", b"https://test.com/abc")

        # there should be a pending TCP connection
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
        self.assertEqual(host, "1.2.3.5")
        self.assertEqual(port, 1080)

        # make a test HTTP server, and wire up the client
        proxy_server = self._make_connection(
            client_factory, _get_test_protocol_factory()
        )

        # fish the transports back out so that we can do the old switcheroo
        s2c_transport = proxy_server.transport
        client_protocol = s2c_transport.other
        c2s_transport = client_protocol.transport

        # the FakeTransport is async, so we need to pump the reactor
        self.reactor.advance(0)

        # now there should be a pending CONNECT request
        self.assertEqual(len(proxy_server.requests), 1)

        request = proxy_server.requests[0]
        self.assertEqual(request.method, b"CONNECT")
        self.assertEqual(request.path, b"test.com:443")

        # tell the proxy server not to close the connection
        proxy_server.persistent = True

        request.finish()

        # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
        ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
        ssl_protocol = ssl_factory.buildProtocol(None)
        http_server = ssl_protocol.wrappedProtocol

        ssl_protocol.makeConnection(
            FakeTransport(client_protocol, self.reactor, ssl_protocol)
        )
        c2s_transport.other = ssl_protocol

        self.reactor.advance(0)

        server_name = ssl_protocol._tlsConnection.get_servername()
        expected_sni = b"test.com"
        self.assertEqual(
            server_name,
            expected_sni,
            f"Expected SNI {expected_sni!s} but got {server_name!s}",
        )

        # now there should be a pending request
        self.assertEqual(len(http_server.requests), 1)

        request = http_server.requests[0]
        self.assertEqual(request.method, b"GET")
        self.assertEqual(request.path, b"/abc")
        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
        request.write(b"result")
        request.finish()

        self.reactor.advance(0)

        resp = self.successResultOf(d)
        body = self.successResultOf(treq.content(resp))
        self.assertEqual(body, b"result")

    @patch.dict(os.environ, {"http_proxy": "proxy.com:8888"})
    def test_proxy_with_no_scheme(self):
        http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True)
        self.assertIsInstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint)
        self.assertEqual(http_proxy_agent.http_proxy_endpoint._hostStr, "proxy.com")
        self.assertEqual(http_proxy_agent.http_proxy_endpoint._port, 8888)

    @patch.dict(os.environ, {"http_proxy": "socks://proxy.com:8888"})
    def test_proxy_with_unsupported_scheme(self):
        with self.assertRaises(ValueError):
            ProxyAgent(self.reactor, use_proxy=True)

    @patch.dict(os.environ, {"http_proxy": "http://proxy.com:8888"})
    def test_proxy_with_http_scheme(self):
        http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True)
        self.assertIsInstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint)
        self.assertEqual(http_proxy_agent.http_proxy_endpoint._hostStr, "proxy.com")
        self.assertEqual(http_proxy_agent.http_proxy_endpoint._port, 8888)

    @patch.dict(os.environ, {"http_proxy": "https://proxy.com:8888"})
    def test_proxy_with_https_scheme(self):
        https_proxy_agent = ProxyAgent(self.reactor, use_proxy=True)
        self.assertIsInstance(https_proxy_agent.http_proxy_endpoint, _WrapperEndpoint)
        self.assertEqual(
            https_proxy_agent.http_proxy_endpoint._wrappedEndpoint._hostStr,
            "proxy.com",
        )
        self.assertEqual(
            https_proxy_agent.http_proxy_endpoint._wrappedEndpoint._port, 8888
        )