async def test_handshake_expired_organization(backend, server_factory, expiredorg, alice, type):
    """
    Handshaking against an expired organization must be rejected.

    The client side raises `HandshakeOrganizationExpired` when processing the
    result, and the backend emits `BackendEvent.ORGANIZATION_EXPIRED`.

    `type` selects the client handshake: "invited" builds an invited-user
    handshake, any other value builds an authenticated handshake as `alice`.
    """
    if type != "invited":
        # authenticated
        client_hs = AuthenticatedClientHandshake(
            organization_id=expiredorg.organization_id,
            device_id=alice.device_id,
            user_signkey=alice.signing_key,
            root_verify_key=expiredorg.root_verify_key,
        )
    else:
        client_hs = InvitedClientHandshake(
            organization_id=expiredorg.organization_id,
            invitation_type=InvitationType.USER,
            token=uuid4(),
        )

    with backend.event_bus.listen() as spy:
        async with server_factory(backend.handle_client) as server:
            transport = await Transport.init_for_client(
                server.connection_factory(), server.addr.hostname
            )

            # challenge -> answer -> result exchange
            challenge = await transport.recv()
            await transport.send(client_hs.process_challenge_req(challenge))
            result = await transport.recv()

            with pytest.raises(HandshakeOrganizationExpired):
                client_hs.process_result_req(result)

        await spy.wait_with_timeout(BackendEvent.ORGANIZATION_EXPIRED)
def test_process_challenge_req_good_api_version(
    alice, monkeypatch, client_version, backend_version, valid
):
    """
    Check API version negotiation for a single client and a single backend version.

    `client_version` and `backend_version` are the positional arguments of
    `ApiVersion`. When `valid` is falsy, processing the challenge must raise
    `HandshakeAPIVersionError` carrying both version lists. Otherwise the
    negotiated versions are stored on the client handshake.
    """
    # Cast parameters
    client_api = ApiVersion(*client_version)
    backend_api = ApiVersion(*backend_version)

    handshake = AuthenticatedClientHandshake(
        alice.organization_id,
        alice.device_id,
        alice.signing_key,
        alice.root_verify_key,
    )
    challenge_payload = {
        "handshake": "challenge",
        "challenge": b"1234567890",
        "supported_api_versions": [backend_api],
    }
    # Restrict the client to exactly one supported version for this test
    monkeypatch.setattr(handshake, "SUPPORTED_API_VERSIONS", [client_api])

    if valid:
        # Valid versioning
        handshake.process_challenge_req(packb(challenge_payload))
        assert handshake.challenge_data["supported_api_versions"] == [backend_api]
        assert handshake.backend_api_version == backend_api
        assert handshake.client_api_version == client_api
        return

    # Invalid versioning
    with pytest.raises(HandshakeAPIVersionError) as exc_info:
        handshake.process_challenge_req(packb(challenge_payload))
    assert exc_info.value.client_versions == [client_api]
    assert exc_info.value.backend_versions == [backend_api]
async def test_handshake_unknown_organization(backend, server_factory, organization_factory, alice, type):
    """
    Handshaking with an organization produced by `organization_factory()` must
    fail with `HandshakeBadIdentity`.

    The organization is presumably unknown to `backend` because it was never
    created there (TODO: confirm against the fixture).

    `type` selects the client handshake: "invited" builds an invited-user
    handshake, any other value builds an authenticated handshake as `alice`.
    """
    bad_org = organization_factory()

    if type != "invited":
        # authenticated
        client_hs = AuthenticatedClientHandshake(
            organization_id=bad_org.organization_id,
            device_id=alice.device_id,
            user_signkey=alice.signing_key,
            root_verify_key=bad_org.root_verify_key,
        )
    else:
        client_hs = InvitedClientHandshake(
            organization_id=bad_org.organization_id,
            invitation_type=InvitationType.USER,
            token=uuid4(),
        )

    async with server_factory(backend.handle_client) as server:
        transport = await Transport.init_for_client(
            server.connection_factory(), server.addr.hostname
        )

        challenge = await transport.recv()
        await transport.send(client_hs.process_challenge_req(challenge))
        result = await transport.recv()

        with pytest.raises(HandshakeBadIdentity):
            client_hs.process_result_req(result)
def test_good_authenticated_handshake(alice):
    """
    Run an in-memory authenticated handshake between a `ServerHandshake` and
    an `AuthenticatedClientHandshake` for `alice`.

    Checks the server state transitions stalled -> challenge -> answer ->
    result, the decoded answer payload, and the negotiated client API version.
    """
    server_hs = ServerHandshake()
    client_hs = AuthenticatedClientHandshake(
        alice.organization_id,
        alice.device_id,
        alice.signing_key,
        alice.root_verify_key,
    )
    assert server_hs.state == "stalled"

    challenge = server_hs.build_challenge_req()
    assert server_hs.state == "challenge"

    server_hs.process_answer_req(client_hs.process_challenge_req(challenge))
    assert server_hs.state == "answer"
    assert server_hs.answer_type == HandshakeType.AUTHENTICATED

    expected_answer = {
        "answer": ANY,
        "client_api_version": API_V2_VERSION,
        "organization_id": alice.organization_id,
        "device_id": alice.device_id,
        "rvk": alice.root_verify_key,
    }
    assert server_hs.answer_data == expected_answer

    result = server_hs.build_result_req(alice.verify_key)
    assert server_hs.state == "result"

    client_hs.process_result_req(result)
    assert server_hs.client_api_version == API_V2_VERSION
async def _backend_sock_factory(backend, auth_as, freeze_on_transport_error=True):
    """
    Open a client transport to `backend` through `server_factory`, optionally
    perform a handshake, then yield the transport.

    `auth_as` controls the handshake:
    - falsy: no handshake, the bare transport is yielded
    - an `OrganizationID`: anonymous handshake for that organization
    - "anonymous": anonymous handshake for `coolorg`'s organization
    - "administration": administration handshake using the backend's
      configured administration token
    - anything else: authenticated handshake, using `auth_as.organization_id`,
      `.device_id`, `.signing_key` and `.root_verify_key`

    If `freeze_on_transport_error` is true, the transport is wrapped in
    `FreezeTestOnTransportError` before use.

    NOTE(review): `server_factory` and `coolorg` are not parameters. They are
    expected to come from the enclosing scope (not visible here).
    """

    def _select_handshake():
        # Same check order as always: the `isinstance` test comes before the
        # string comparisons.
        if isinstance(auth_as, OrganizationID):
            return AnonymousClientHandshake(auth_as)
        if auth_as == "anonymous":
            # TODO: for legacy test, refactorise this ?
            return AnonymousClientHandshake(coolorg.organization_id)
        if auth_as == "administration":
            return AdministrationClientHandshake(backend.config.administration_token)
        return AuthenticatedClientHandshake(
            auth_as.organization_id,
            auth_as.device_id,
            auth_as.signing_key,
            auth_as.root_verify_key,
        )

    async with server_factory(backend.handle_client) as server:
        raw_stream = server.connection_factory()
        transport = await Transport.init_for_client(raw_stream, server.addr.hostname)
        if freeze_on_transport_error:
            transport = FreezeTestOnTransportError(transport)

        if auth_as:
            # Handshake
            handshake = _select_handshake()
            answer = handshake.process_challenge_req(await transport.recv())
            await transport.send(answer)
            handshake.process_result_req(await transport.recv())

        yield transport
async def connect_as_authenticated(
    addr: BackendOrganizationAddr,
    device_id: DeviceID,
    signing_key: SigningKey,
    keepalive: Optional[int] = None,
):
    """
    Connect to the organization designated by `addr`, authenticating as `device_id`.

    Builds an `AuthenticatedClientHandshake` from the address' organization id
    and root verify key plus the given device credentials. Connection and
    handshake are delegated to `_connect`, whose result is returned.
    """
    auth_handshake = AuthenticatedClientHandshake(
        organization_id=addr.organization_id,
        device_id=device_id,
        user_signkey=signing_key,
        root_verify_key=addr.root_verify_key,
    )
    hostname, port, use_ssl = addr.hostname, addr.port, addr.use_ssl
    return await _connect(hostname, port, use_ssl, keepalive, auth_handshake)
async def _backend_sock_factory(backend, auth_as: LocalDevice, freeze_on_transport_error=True):
    """
    Open a raw transport to `backend`, authenticate as `auth_as`, then yield
    the authenticated transport.

    `freeze_on_transport_error` is forwarded as-is to
    `backend_raw_transport_factory`.
    """
    raw_transport_cm = backend_raw_transport_factory(
        backend, freeze_on_transport_error=freeze_on_transport_error
    )
    async with raw_transport_cm as transport:
        # Handshake
        client_hs = AuthenticatedClientHandshake(
            auth_as.organization_id,
            auth_as.device_id,
            auth_as.signing_key,
            auth_as.root_verify_key,
        )
        challenge = await transport.recv()
        await transport.send(client_hs.process_challenge_req(challenge))
        client_hs.process_result_req(await transport.recv())

        yield transport
async def test_authenticated_handshake_bad_rvk(backend, server_factory, alice, otherorg):
    """
    Authenticating as `alice` while presenting `otherorg`'s root verify key
    must fail with `HandshakeRVKMismatch`.
    """
    client_hs = AuthenticatedClientHandshake(
        organization_id=alice.organization_id,
        device_id=alice.device_id,
        user_signkey=alice.signing_key,
        root_verify_key=otherorg.root_verify_key,
    )

    async with server_factory(backend.handle_client) as server:
        raw_stream = await server.connection_factory()
        transport = await Transport.init_for_client(raw_stream, "127.0.0.1")

        challenge = await transport.recv()
        await transport.send(client_hs.process_challenge_req(challenge))
        result = await transport.recv()

        with pytest.raises(HandshakeRVKMismatch):
            client_hs.process_result_req(result)
async def test_authenticated_handshake_unknown_device(backend, server_factory, mallory):
    """
    Authenticating as `mallory` must fail with `HandshakeBadIdentity`.

    Per the test name, `mallory` is presumably a device the backend does not
    know (TODO: confirm against the fixture).
    """
    client_hs = AuthenticatedClientHandshake(
        organization_id=mallory.organization_id,
        device_id=mallory.device_id,
        user_signkey=mallory.signing_key,
        root_verify_key=mallory.root_verify_key,
    )

    async with server_factory(backend.handle_client) as server:
        raw_stream = await server.connection_factory()
        transport = await Transport.init_for_client(raw_stream, "127.0.0.1")

        challenge = await transport.recv()
        await transport.send(client_hs.process_challenge_req(challenge))
        result = await transport.recv()

        with pytest.raises(HandshakeBadIdentity):
            client_hs.process_result_req(result)
def test_process_challenge_req_good_multiple_api_version(
    alice,
    monkeypatch,
    client_versions,
    backend_versions,
    expected_client_version,
    expected_backend_version,
):
    """
    Check API version negotiation when both sides advertise several versions.

    Every version is given as the positional arguments of `ApiVersion`. A
    `None` `expected_client_version` means negotiation must fail with
    `HandshakeAPIVersionError`. Otherwise the client handshake must settle on
    the expected client/backend versions.
    """
    # Cast parameters
    client_api_versions = [ApiVersion(*version) for version in client_versions]
    backend_api_versions = [ApiVersion(*version) for version in backend_versions]
    expected_client = (
        ApiVersion(*expected_client_version)
        if expected_client_version
        else expected_client_version
    )
    expected_backend = (
        ApiVersion(*expected_backend_version)
        if expected_backend_version
        else expected_backend_version
    )

    handshake = AuthenticatedClientHandshake(
        alice.organization_id,
        alice.device_id,
        alice.signing_key,
        alice.root_verify_key,
    )
    challenge_payload = {
        "handshake": "challenge",
        "challenge": b"1234567890",
        "supported_api_versions": list(backend_api_versions),
        "backend_timestamp": pendulum.now(),
        "ballpark_client_early_offset": BALLPARK_CLIENT_EARLY_OFFSET,
        "ballpark_client_late_offset": BALLPARK_CLIENT_LATE_OFFSET,
    }
    monkeypatch.setattr(handshake, "SUPPORTED_API_VERSIONS", client_api_versions)

    if expected_client is not None:
        # Valid versioning
        handshake.process_challenge_req(packb(challenge_payload))
        assert handshake.challenge_data["supported_api_versions"] == backend_api_versions
        assert handshake.backend_api_version == expected_backend
        assert handshake.client_api_version == expected_client
        return

    # Invalid versioning
    with pytest.raises(HandshakeAPIVersionError) as exc_info:
        handshake.process_challenge_req(packb(challenge_payload))
    assert exc_info.value.client_versions == client_api_versions
    assert exc_info.value.backend_versions == backend_api_versions
async def test_authenticated_handshake_good(backend, server_factory, alice):
    """
    A genuine authenticated handshake as `alice` against the backend succeeds,
    and both the client and backend API versions resolve to `API_VERSION`.
    """
    client_hs = AuthenticatedClientHandshake(
        organization_id=alice.organization_id,
        device_id=alice.device_id,
        user_signkey=alice.signing_key,
        root_verify_key=alice.root_verify_key,
    )

    async with server_factory(backend.handle_client) as server:
        raw_stream = await server.connection_factory()
        transport = await Transport.init_for_client(raw_stream, "127.0.0.1")

        challenge = await transport.recv()
        await transport.send(client_hs.process_challenge_req(challenge))
        client_hs.process_result_req(await transport.recv())

        assert client_hs.client_api_version == API_VERSION
        assert client_hs.backend_api_version == API_VERSION
def test_process_challenge_req_bad_format(alice, req):
    """A challenge built from the malformed `req` payload must raise `InvalidMessageError`."""
    handshake = AuthenticatedClientHandshake(
        alice.organization_id,
        alice.device_id,
        alice.signing_key,
        alice.root_verify_key,
    )
    # `packb` stays inside the `raises` block, exactly as before, so any
    # packing failure is judged by the same expectation.
    with pytest.raises(InvalidMessageError):
        handshake.process_challenge_req(packb(req))
async def test_handle_client_coroutine_destroyed_on_client_left(
    backend, alice, close_on, clean_close, recwarn
):
    """
    Check that `backend.handle_client` returns `None` when the client leaves,
    whatever stage of the connection it leaves at.

    Args:
        close_on: stage at which the client disconnects. One of "tcp_ready",
            "before_http_request", "after_http_request", "websocket_ready",
            "handshake_started" or "handshake_done".
        clean_close: if true the client closes its stream normally; otherwise
            it resets the TCP connection (SO_LINGER with a zero linger time).
    """
    # For this test we want to use a real TCP socket (instead of relying on
    # the `tcp_stream_spy` mock fixture) to test the backend on
    outcome = None
    outcome_available = trio.Event()

    async def _handle_client_with_captured_outcome(stream):
        # Record how `handle_client` finished (return value or exception) so
        # the test can assert on it once the client is gone
        nonlocal outcome
        try:
            ret = await backend.handle_client(stream)
        except BaseException as exc:
            outcome = ("exception", exc)
            outcome_available.set()
            raise
        else:
            outcome = ("return", ret)
            outcome_available.set()
            return ret

    async with trio.open_nursery() as nursery:
        try:
            # Start server (port 0: let the OS pick a free port)
            listeners = await nursery.start(
                trio.serve_tcp, _handle_client_with_captured_outcome, 0
            )

            # Client connect to the server
            client_stream = await open_stream_to_socket_listener(listeners[0])

            async def _do_close_client():
                if clean_close:
                    await client_stream.aclose()
                else:
                    # Reset the tcp socket instead of regular clean close
                    # See https://stackoverflow.com/a/54065411
                    l_onoff = 1
                    l_linger = 0
                    client_stream.setsockopt(
                        socket.SOL_SOCKET,
                        socket.SO_LINGER,
                        struct.pack("ii", l_onoff, l_linger),
                    )
                    client_stream.socket.close()

                # Wait (at most 1s) for the backend side to finish
                with trio.fail_after(1):
                    await outcome_available.wait()

            if close_on == "tcp_ready":
                await _do_close_client()

            elif close_on == "before_http_request":
                # Send the beginning of an http request
                await client_stream.send_all(b"GET / HTTP/1.1\r\n")
                await _do_close_client()

            elif close_on == "after_http_request":
                # Compared with `==`: `("after_http_request")` has no trailing
                # comma, so it is a plain string and `in` would be a substring test
                # Send an entire http request
                await client_stream.send_all(b"GET / HTTP/1.0\r\n\r\n")
                # Peer will realize connection is closed after having sent
                # the answer for the previous request
                await _do_close_client()

            else:
                # First request doing websocket negotiation
                # `getsockname()` returns an address tuple; only the port
                # belongs in the `host:port` string
                port = listeners[0].socket.getsockname()[1]
                hostname = f"127.0.0.1:{port}"
                transport = await Transport.init_for_client(client_stream, hostname)

                if close_on == "websocket_ready":
                    await _do_close_client()
                else:
                    # Client do the handshake
                    ch = AuthenticatedClientHandshake(
                        alice.organization_id,
                        alice.device_id,
                        alice.signing_key,
                        alice.root_verify_key,
                    )
                    challenge_req = await transport.recv()
                    answer_req = ch.process_challenge_req(challenge_req)

                    if close_on == "handshake_started":
                        await _do_close_client()
                    else:
                        await transport.send(answer_req)
                        result_req = await transport.recv()
                        ch.process_result_req(result_req)

                        assert close_on == "handshake_done"  # Sanity check
                        await _do_close_client()

            # Outcome should always be the same
            assert outcome == ("return", None)

        finally:
            nursery.cancel_scope.cancel()
async def connect(
    addr: Union[BackendAddr, BackendOrganizationBootstrapAddr, BackendOrganizationAddr],
    device_id: Optional[DeviceID] = None,
    signing_key: Optional[SigningKey] = None,
    administration_token: Optional[str] = None,
    keepalive: Optional[int] = None,
) -> Transport:
    """
    Open a transport to the backend and run the handshake matching the given credentials.

    The handshake is chosen as follows:
    - `administration_token` given: administration handshake (`addr` must be
      a `BackendAddr`)
    - no `device_id`: anonymous handshake (`addr` must be an organization
      bootstrap url or an organization url)
    - otherwise: authenticated handshake (`addr` must be an organization url
      and `signing_key` must be provided)

    Args:
        addr: backend address to connect to.
        device_id: device to authenticate as.
        signing_key: signing key used for the authenticated handshake.
        administration_token: token for the administration handshake.
        keepalive: stored as-is on the returned transport (`transport.keepalive`).

    Returns:
        The transport, with the handshake completed.

    Raises:
        BackendConnectionError: invalid url format for the chosen handshake,
            or missing `signing_key`.
        BackendNotAvailable: the TCP connection cannot be opened, or the
            transport creation fails with a `TransportError`.
        Any exception raised by the handshake is re-raised after the transport
        has been closed.
    """
    if administration_token:
        if not isinstance(addr, BackendAddr):
            raise BackendConnectionError(f"Invalid url format `{addr}`")
        handshake = AdministrationClientHandshake(administration_token)

    elif not device_id:
        if isinstance(addr, BackendOrganizationBootstrapAddr):
            handshake = AnonymousClientHandshake(addr.organization_id)
        elif isinstance(addr, BackendOrganizationAddr):
            handshake = AnonymousClientHandshake(addr.organization_id, addr.root_verify_key)
        else:
            raise BackendConnectionError(
                f"Invalid url format `{addr}` "
                "(should be an organization url or organization bootstrap url)"
            )

    else:
        if not isinstance(addr, BackendOrganizationAddr):
            raise BackendConnectionError(
                f"Invalid url format `{addr}` (should be an organization url)"
            )
        if not signing_key:
            raise BackendConnectionError(f"Missing signing_key to connect as `{device_id}`")
        handshake = AuthenticatedClientHandshake(
            addr.organization_id, device_id, signing_key, addr.root_verify_key
        )

    try:
        stream = await trio.open_tcp_stream(addr.hostname, addr.port)
    except OSError as exc:
        logger.debug("Impossible to connect to backend", reason=exc)
        raise BackendNotAvailable(exc) from exc

    if addr.use_ssl:
        stream = _upgrade_stream_to_ssl(stream, addr.hostname)

    try:
        transport = await Transport.init_for_client(stream, host=addr.hostname)
    except TransportError as exc:
        logger.debug("Connection lost during transport creation", reason=exc)
        # No transport owns the stream yet, so close it here rather than
        # leaking the socket. Forceful close: never blocks on a broken peer.
        await trio.aclose_forcefully(stream)
        raise BackendNotAvailable(exc) from exc
    transport.handshake = handshake
    transport.keepalive = keepalive

    try:
        await _do_handshake(transport, handshake)
    except Exception as exc:
        transport.logger.debug("Connection lost during handshake", reason=exc)
        await transport.aclose()
        raise

    return transport