Code example #1
0
async def wait_for_url(client: aiohttp.ClientSession, url: str, method: str,
                       interval: float = 0.1):
    """Poll ``url`` with OPTIONS requests until ``method`` is advertised.

    The URL is considered ready once an OPTIONS response carries an ``Allow``
    header whose comma-separated method list contains ``method`` exactly.
    Connection failures and responses without an ``Allow`` header are treated
    as "not ready yet" and retried; this coroutine loops until success.

    :param client: session used to issue the OPTIONS requests.
    :param url: URL to probe.
    :param method: HTTP method that must be listed in ``Allow``.
    :param interval: seconds to sleep between attempts (default 0.1).
    """
    while True:
        try:
            async with client.options(url) as response:
                # A missing header (e.g. a 404 before the route is registered)
                # means "not ready" rather than a KeyError that aborts the wait.
                allow = response.headers.get('Allow', '')
                # Compare whole tokens instead of a substring of the raw header.
                allowed = {token.strip() for token in allow.split(',')}
                if method in allowed:
                    return
        except aiohttp.ClientConnectorError:
            # Server is not accepting connections yet; retry after a pause.
            pass
        await asyncio.sleep(interval)
Code example #2
0
File: server.py  Project: jh3mail/susnote
class Client:
    """Wrapper around an aiohttp ``ClientSession`` that resolves relative URLs.

    Request helpers prefix relative URLs with the optional base ``url`` given
    to the constructor (see ``handler_url``) and return the session's request
    object unchanged, so callers keep using ``async with client.get(...)``.
    Using the client as an async context manager closes the session on exit.
    """

    _client = None

    def __init__(self, loop, url=None):
        """Create the session on ``loop``; ``url`` is an optional base URL."""
        self._client = ClientSession(loop=loop)
        self._url = url

    @property
    def cli(self):
        """The wrapped ``ClientSession``."""
        return self._client

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Close the wrapped session when the ``async with`` block ends.

        This used to be a no-op, which leaked the session. Returning ``None``
        lets any exception raised inside the block propagate.
        """
        import inspect

        result = self._client.close()
        # Await only if close() handed back an awaitable, so this works
        # whether or not the installed aiohttp makes close() a coroutine.
        if inspect.isawaitable(result):
            await result

    def handler_url(self, url):
        """Resolve ``url`` against the base URL.

        URLs starting with ``"http"`` are returned unchanged; otherwise the
        base URL (if one was given) is prepended by plain concatenation.
        """
        if url.startswith("http"):
            return url
        if self._url:
            return "{}{}".format(self._url, url)
        return url

    def request(self, method, url, *args, **kwargs):
        return self._client.request(method, self.handler_url(url), *args,
                                    **kwargs)

    def get(self, url, allow_redirects=True, **kwargs):
        # Forward the caller's choice; it used to be hard-coded to True.
        return self._client.get(self.handler_url(url),
                                allow_redirects=allow_redirects,
                                **kwargs)

    def post(self, url, data=None, **kwargs):
        return self._client.post(self.handler_url(url), data=data, **kwargs)

    def put(self, url, data=None, **kwargs):
        return self._client.put(self.handler_url(url), data=data, **kwargs)

    def delete(self, url, **kwargs):
        return self._client.delete(self.handler_url(url), **kwargs)

    def head(self, url, allow_redirects=False, **kwargs):
        return self._client.head(self.handler_url(url),
                                 allow_redirects=allow_redirects,
                                 **kwargs)

    def options(self, url, allow_redirects=True, **kwargs):
        return self._client.options(self.handler_url(url),
                                    allow_redirects=allow_redirects,
                                    **kwargs)

    def close(self):
        """Close the wrapped session.

        Returns whatever ``ClientSession.close()`` returns, so callers can
        ``await client.close()`` when that is a coroutine (it used to be
        discarded, leaving the close never awaited).
        """
        return self._client.close()
Code example #3
0
class TestAdminServer(AsyncTestCase):
    """Tests for ``AdminServer`` and its middlewares.

    Tests that need a live server build one with ``get_admin_server`` (which
    picks a fresh ``unused_port()`` and stores it in ``self.port``) and talk
    to it through the shared ``self.client_session``.
    """

    async def setUp(self):
        """Reset recorded messages/webhooks and open the HTTP client session."""
        self.message_results = []
        self.webhook_results = []
        self.port = 0

        self.connector = TCPConnector(limit=16, limit_per_host=4)
        self.client_session = ClientSession(cookie_jar=DummyCookieJar(),
                                            connector=self.connector)

    async def tearDown(self):
        """Close the client session opened in ``setUp`` if it is still open."""
        if self.client_session:
            await self.client_session.close()
            self.client_session = None

    async def test_debug_middleware(self):
        """With debug logging enabled, debug_middleware emits 3 debug lines."""
        with async_mock.patch.object(test_module, "LOGGER",
                                     async_mock.MagicMock()) as mock_logger:
            mock_logger.isEnabledFor = async_mock.MagicMock(return_value=True)
            mock_logger.debug = async_mock.MagicMock()

            request = async_mock.MagicMock(
                method="GET",
                path_qs="/hello/world?a=1&b=2",
                match_info={"match": "info"},
                text=async_mock.CoroutineMock(return_value="abc123"),
            )
            handler = async_mock.CoroutineMock()

            await test_module.debug_middleware(request, handler)
            mock_logger.isEnabledFor.assert_called_once()
            assert mock_logger.debug.call_count == 3

    async def test_ready_middleware(self):
        """ready_middleware rejects requests until ready, then passes through.

        Once the app is ready, the handler's return value and the exceptions
        it raises are all propagated unchanged.
        """
        with async_mock.patch.object(test_module, "LOGGER",
                                     async_mock.MagicMock()) as mock_logger:
            mock_logger.isEnabledFor = async_mock.MagicMock(return_value=True)
            mock_logger.debug = async_mock.MagicMock()
            mock_logger.info = async_mock.MagicMock()
            mock_logger.error = async_mock.MagicMock()

            request = async_mock.MagicMock(
                rel_url="/", app=async_mock.MagicMock(_state={"ready": False}))
            handler = async_mock.CoroutineMock(return_value="OK")
            with self.assertRaises(test_module.web.HTTPServiceUnavailable):
                await test_module.ready_middleware(request, handler)

            request.app._state["ready"] = True
            assert await test_module.ready_middleware(request, handler) == "OK"

            request.app._state["ready"] = True
            handler = async_mock.CoroutineMock(
                side_effect=test_module.LedgerConfigError("Bad config"))
            with self.assertRaises(test_module.LedgerConfigError):
                await test_module.ready_middleware(request, handler)

            request.app._state["ready"] = True
            handler = async_mock.CoroutineMock(
                side_effect=test_module.web.HTTPFound(location="/api/doc"))
            with self.assertRaises(test_module.web.HTTPFound):
                await test_module.ready_middleware(request, handler)

            request.app._state["ready"] = True
            handler = async_mock.CoroutineMock(
                side_effect=test_module.asyncio.CancelledError("Cancelled"))
            with self.assertRaises(test_module.asyncio.CancelledError):
                await test_module.ready_middleware(request, handler)

            request.app._state["ready"] = True
            handler = async_mock.CoroutineMock(
                side_effect=KeyError("No such thing"))
            with self.assertRaises(KeyError):
                await test_module.ready_middleware(request, handler)

    def get_admin_server(self,
                         settings: dict = None,
                         context: InjectionContext = None) -> AdminServer:
        """Build an ``AdminServer`` on a fresh port (saved in ``self.port``).

        ``settings`` may carry the test-only key ``"task_queue"``. It is popped
        from the caller's dict after the settings were applied to the context.
        When it is truthy, the server gets a ``TaskQueue``. Otherwise it gets a
        mocked ``conductor_stats`` instead.
        """
        if not context:
            context = InjectionContext()
        if settings:
            context.update_settings(settings)

        # middleware is task queue xor collector: cover both over test suite
        task_queue = (settings or {}).pop("task_queue", None)

        plugin_registry = async_mock.MagicMock(test_module.PluginRegistry,
                                               autospec=True)
        plugin_registry.post_process_routes = async_mock.MagicMock()
        context.injector.bind_instance(test_module.PluginRegistry,
                                       plugin_registry)

        collector = Collector()
        context.injector.bind_instance(test_module.Collector, collector)

        profile = InMemoryProfile.test_profile()

        self.port = unused_port()
        return AdminServer(
            "0.0.0.0",
            self.port,
            context,
            profile,
            self.outbound_message_router,
            self.webhook_router,
            conductor_stop=async_mock.CoroutineMock(),
            task_queue=TaskQueue(max_active=4) if task_queue else None,
            conductor_stats=(None if task_queue else async_mock.CoroutineMock(
                return_value={"a": 1})),
        )

    async def outbound_message_router(self, *args):
        """Record outbound messages sent by the server under test."""
        self.message_results.append(args)

    def webhook_router(self, *args):
        """Record webhooks emitted by the server under test."""
        self.webhook_results.append(args)

    async def test_start_stop(self):
        """Bad security settings make start() fail with AssertionError.

        A valid server starts and stops. An OSError from ``TCPSite.start``
        surfaces as ``AdminSetupError``.
        """
        with self.assertRaises(AssertionError):
            await self.get_admin_server().start()

        settings = {"admin.admin_insecure_mode": False}
        with self.assertRaises(AssertionError):
            await self.get_admin_server(settings).start()

        settings = {
            "admin.admin_insecure_mode": True,
            "admin.admin_api_key": "test-api-key",
        }
        with self.assertRaises(AssertionError):
            await self.get_admin_server(settings).start()

        settings = {
            "admin.admin_insecure_mode": False,
            "admin.admin_client_max_request_size": 4,
            "admin.admin_api_key": "test-api-key",
        }
        server = self.get_admin_server(settings)
        await server.start()
        # The setting is given in MB; aiohttp stores bytes.
        assert server.app._client_max_size == 4 * 1024 * 1024
        with async_mock.patch.object(server, "websocket_queues",
                                     async_mock.MagicMock()) as mock_wsq:
            mock_wsq.values = async_mock.MagicMock(return_value=[
                async_mock.MagicMock(stop=async_mock.MagicMock())
            ])
            await server.stop()

        with async_mock.patch.object(web.TCPSite, "start",
                                     async_mock.CoroutineMock()) as mock_start:
            mock_start.side_effect = OSError("Failure to launch")
            with self.assertRaises(AdminSetupError):
                await self.get_admin_server(settings).start()

    async def test_import_routes(self):
        # this test just imports all default admin routes
        # for routes with associated tests, this shouldn't make a difference in coverage
        context = InjectionContext()
        context.injector.bind_instance(ProtocolRegistry, ProtocolRegistry())
        context.injector.bind_instance(GoalCodeRegistry, GoalCodeRegistry())
        await DefaultContextBuilder().load_plugins(context)
        server = self.get_admin_server({"admin.admin_insecure_mode": True},
                                       context)
        await server.make_application()

    async def test_import_routes_multitenant_middleware(self):
        # imports all default admin routes
        context = InjectionContext()
        context.injector.bind_instance(ProtocolRegistry, ProtocolRegistry())
        context.injector.bind_instance(GoalCodeRegistry, GoalCodeRegistry())
        profile = InMemoryProfile.test_profile()
        context.injector.bind_instance(
            test_module.BaseMultitenantManager,
            test_module.BaseMultitenantManager(profile),
        )
        await DefaultContextBuilder().load_plugins(context)
        server = self.get_admin_server(
            {
                "admin.admin_insecure_mode": False,
                "admin.admin_api_key": "test-api-key",
            },
            context,
        )

        # cover multitenancy start code
        app = await server.make_application()
        app["swagger_dict"] = {}
        await server.on_startup(app)

        # multitenant authz
        [mt_authz_middle] = [
            m for m in app.middlewares
            if ".check_multitenant_authorization" in str(m)
        ]

        mock_request = async_mock.MagicMock(
            method="GET",
            headers={"Authorization": "Bearer ..."},
            path="/multitenancy/etc",
            text=async_mock.CoroutineMock(return_value="abc123"),
        )
        with self.assertRaises(test_module.web.HTTPUnauthorized):
            await mt_authz_middle(mock_request, None)

        mock_request = async_mock.MagicMock(
            method="GET",
            headers={},
            path="/protected/non-multitenancy/non-server",
            text=async_mock.CoroutineMock(return_value="abc123"),
        )
        with self.assertRaises(test_module.web.HTTPUnauthorized):
            await mt_authz_middle(mock_request, None)

        mock_request = async_mock.MagicMock(
            method="GET",
            headers={"Authorization": "Bearer ..."},
            path="/protected/non-multitenancy/non-server",
            text=async_mock.CoroutineMock(return_value="abc123"),
        )
        mock_handler = async_mock.CoroutineMock()
        await mt_authz_middle(mock_request, mock_handler)
        # ``called_once_with`` is not a mock assertion: it returned a truthy
        # auto-created attribute, so the old ``assert`` could never fail.
        mock_handler.assert_called_once_with(mock_request)

        # multitenant setup context exception paths
        [setup_ctx_middle
         ] = [m for m in app.middlewares if ".setup_context" in str(m)]

        mock_request = async_mock.MagicMock(
            method="GET",
            headers={"Authorization": "Non-bearer ..."},
            path="/protected/non-multitenancy/non-server",
            text=async_mock.CoroutineMock(return_value="abc123"),
        )
        with self.assertRaises(test_module.web.HTTPUnauthorized):
            await setup_ctx_middle(mock_request, None)

        mock_request = async_mock.MagicMock(
            method="GET",
            headers={"Authorization": "Bearer ..."},
            path="/protected/non-multitenancy/non-server",
            text=async_mock.CoroutineMock(return_value="abc123"),
        )
        with async_mock.patch.object(
                server.multitenant_manager,
                "get_profile_for_token",
                async_mock.CoroutineMock(),
        ) as mock_get_profile:
            mock_get_profile.side_effect = [
                test_module.MultitenantManagerError("corrupt token"),
                test_module.StorageNotFoundError("out of memory"),
            ]
            # One iteration per side effect above; both map to 401.
            for _ in range(2):
                with self.assertRaises(test_module.web.HTTPUnauthorized):
                    await setup_ctx_middle(mock_request, None)

    async def test_register_external_plugin_x(self):
        """Loading a non-existent external plugin raises ValueError."""
        context = InjectionContext()
        context.injector.bind_instance(ProtocolRegistry, ProtocolRegistry())
        with self.assertRaises(ValueError):
            builder = DefaultContextBuilder(
                settings={"external_plugins": "aries_cloudagent.nosuchmodule"})
            await builder.load_plugins(context)

    async def test_visit_insecure_mode(self):
        """In insecure mode, endpoints and the websocket need no API key."""
        settings = {"admin.admin_insecure_mode": True, "task_queue": True}
        server = self.get_admin_server(settings)
        await server.start()

        async with self.client_session.post(
                f"http://127.0.0.1:{self.port}/status/reset",
                headers={}) as response:
            assert response.status == 200

        async with self.client_session.ws_connect(
                f"http://127.0.0.1:{self.port}/ws") as ws:
            result = await ws.receive_json()
            assert result["topic"] == "settings"

        for path in (
                "",
                "plugins",
                "status",
                "status/live",
                "status/ready",
                "shutdown",  # mock conductor has magic-mock stop()
        ):
            async with self.client_session.get(
                    f"http://127.0.0.1:{self.port}/{path}",
                    headers={}) as response:
                assert response.status == 200

        await server.stop()

    async def test_visit_secure_mode(self):
        """In secure mode the API key is enforced, except for CORS preflight."""
        settings = {
            "admin.admin_insecure_mode": False,
            "admin.admin_api_key": "test-api-key",
        }
        server = self.get_admin_server(settings)
        await server.start()

        async with self.client_session.get(
                f"http://127.0.0.1:{self.port}/status",
                headers={"x-api-key": "wrong-key"}) as response:
            assert response.status == 401

        async with self.client_session.get(
                f"http://127.0.0.1:{self.port}/status",
                headers={"x-api-key": "test-api-key"},
        ) as response:
            assert response.status == 200

        # Make sure that OPTIONS requests used by browsers for CORS
        # are allowed without a x-api-key even when x-api-key security is enabled
        async with self.client_session.options(
                f"http://127.0.0.1:{self.port}/status",
                headers={
                    "Access-Control-Request-Headers": "x-api-key",
                    "Access-Control-Request-Method": "GET",
                    "Connection": "keep-alive",
                    "Host": f"http://127.0.0.1:{self.port}/status",
                    "Origin": "http://localhost:3000",
                    "Referer": "http://localhost:3000/",
                    "Sec-Fetch-Dest": "empty",
                    "Sec-Fetch-Mode": "cors",
                    "Sec-Fetch-Site": "same-site",
                },
        ) as response:
            assert response.status == 200
            assert response.headers[
                "Access-Control-Allow-Credentials"] == "true"
            assert response.headers[
                "Access-Control-Allow-Headers"] == "X-API-KEY"
            assert response.headers["Access-Control-Allow-Methods"] == "GET"
            assert (response.headers["Access-Control-Allow-Origin"] ==
                    "http://localhost:3000")

        async with self.client_session.ws_connect(
                f"http://127.0.0.1:{self.port}/ws",
                headers={"x-api-key": "test-api-key"}) as ws:
            result = await ws.receive_json()
            assert result["topic"] == "settings"

        await server.stop()

    async def test_query_config(self):
        """/status/config hides secrets and strips webhook URL fragments."""
        settings = {
            "admin.admin_insecure_mode": False,
            "admin.admin_api_key": "test-api-key",
            "admin.webhook_urls": [
                "localhost:8123/abc#secret", "localhost:8123/def"
            ],
            "multitenant.jwt_secret": "abc123",
            "wallet.key": "abc123",
            "wallet.rekey": "def456",
            "wallet.seed": "00000000000000000000000000000000",
            # Must match the key asserted below; it used to be seeded as
            # "wallet.storage.creds", which made the redaction check vacuous.
            "wallet.storage_creds": "secret",
        }
        server = self.get_admin_server(settings)
        await server.start()

        async with self.client_session.get(
                f"http://127.0.0.1:{self.port}/status/config",
                headers={"x-api-key": "test-api-key"},
        ) as response:
            assert response.status == 200
            config = json.loads(await response.text())["config"]
            assert "admin.admin_insecure_mode" in config
            assert all(k not in config for k in [
                "admin.admin_api_key",
                "multitenant.jwt_secret",
                "wallet.key",
                "wallet.rekey",
                "wallet.seed",
                "wallet.storage_creds",
            ])
            assert config["admin.webhook_urls"] == [
                "localhost:8123/abc",
                "localhost:8123/def",
            ]

        # Every other live-server test stops its server; this one never did.
        await server.stop()

    async def test_visit_shutting_down(self):
        """After /shutdown, /status returns 503 while /status/live stays 200."""
        settings = {
            "admin.admin_insecure_mode": True,
        }
        server = self.get_admin_server(settings)
        await server.start()

        async with self.client_session.get(
                f"http://127.0.0.1:{self.port}/shutdown",
                headers={}) as response:
            assert response.status == 200

        async with self.client_session.get(
                f"http://127.0.0.1:{self.port}/status",
                headers={}) as response:
            assert response.status == 503

        async with self.client_session.get(
                f"http://127.0.0.1:{self.port}/status/live",
                headers={}) as response:
            assert response.status == 200
        await server.stop()

    async def test_server_health_state(self):
        """Live/ready probes report healthy until notify_fatal_error()."""
        settings = {
            "admin.admin_insecure_mode": True,
        }
        server = self.get_admin_server(settings)
        await server.start()

        async with self.client_session.get(
                f"http://127.0.0.1:{self.port}/status/live",
                headers={}) as response:
            assert response.status == 200
            response_json = await response.json()
            assert response_json["alive"]

        async with self.client_session.get(
                f"http://127.0.0.1:{self.port}/status/ready",
                headers={}) as response:
            assert response.status == 200
            response_json = await response.json()
            assert response_json["ready"]

        server.notify_fatal_error()
        async with self.client_session.get(
                f"http://127.0.0.1:{self.port}/status/live",
                headers={}) as response:
            assert response.status == 503

        async with self.client_session.get(
                f"http://127.0.0.1:{self.port}/status/ready",
                headers={}) as response:
            assert response.status == 503
        await server.stop()
Code example #4
0
class AioHttpClient(HttpClient):
    """HTTP client backed by a single aiohttp ``ClientSession``.

    Each request helper delegates to the matching ``ClientSession`` method and
    returns its result unchanged, so callers use it exactly like aiohttp
    (``async with client.get(url) as resp: ...``). Using the client as an
    async context manager closes the session on exit.
    """

    def __init__(self,
                 *,
                 connector=None,
                 loop=None,
                 cookies=None,
                 headers=None,
                 skip_auto_headers=None,
                 auth=None,
                 json_serialize=json.dumps,
                 request_class=ClientRequest,
                 response_class=ClientResponse,
                 ws_response_class=ClientWebSocketResponse,
                 version=http.HttpVersion11,
                 cookie_jar=None,
                 connector_owner=True,
                 raise_for_status=False,
                 read_timeout=sentinel,
                 conn_timeout=None,
                 auto_decompress=True,
                 trust_env=False,
                 **kwargs):
        """
        Create the client and the ``ClientSession`` that performs its HTTP
        requests and manages their connections.

        Every named keyword argument is forwarded unchanged to
        ``ClientSession``; any remaining ``**kwargs`` go to the ``HttpClient``
        base class. For details of the parameters see
        http://aiohttp.readthedocs.io/en/stable/client_advanced.html#client-session
        """
        super(AioHttpClient, self).__init__(**kwargs)
        self.client = ClientSession(connector=connector,
                                    loop=loop,
                                    cookies=cookies,
                                    headers=headers,
                                    skip_auto_headers=skip_auto_headers,
                                    auth=auth,
                                    json_serialize=json_serialize,
                                    request_class=request_class,
                                    response_class=response_class,
                                    ws_response_class=ws_response_class,
                                    version=version,
                                    cookie_jar=cookie_jar,
                                    connector_owner=connector_owner,
                                    raise_for_status=raise_for_status,
                                    read_timeout=read_timeout,
                                    conn_timeout=conn_timeout,
                                    auto_decompress=auto_decompress,
                                    trust_env=trust_env)

    # The helpers below used to accept ``*args`` and silently discard them,
    # hiding caller mistakes (e.g. ``post(url, payload)`` sent no body).
    # They are now forwarded, so the session either uses them or raises
    # ``TypeError``. Calls without extra positional arguments are unaffected.

    def request(self, method, url, *args, **kwargs):
        return self.client.request(method, url, *args, **kwargs)

    def get(self, url, *args, **kwargs):
        return self.client.get(url, *args, **kwargs)

    def post(self, url, *args, data=None, **kwargs):
        return self.client.post(url, *args, data=data, **kwargs)

    def put(self, url, *args, data=None, **kwargs):
        return self.client.put(url, *args, data=data, **kwargs)

    def delete(self, url, *args, **kwargs):
        return self.client.delete(url, *args, **kwargs)

    def options(self, url, *args, **kwargs):
        return self.client.options(url, *args, **kwargs)

    def head(self, url, *args, **kwargs):
        return self.client.head(url, *args, **kwargs)

    def patch(self, url, *args, data=None, **kwargs):
        return self.client.patch(url, *args, data=data, **kwargs)

    async def close(self):
        """Close the underlying ``ClientSession``."""
        await self.client.close()

    async def get_response(self, response):
        """Read ``response``'s body as text and wrap it in a ``Response``.

        The wrapper copies the URL, status, charset, content type/length,
        reason and headers, and adds ``selector``: the text parsed with
        ``etree.HTML``.
        """
        text = await response.text()
        return Response(url=response.url,
                        status=response.status,
                        charset=response.charset,
                        content_type=response.content_type,
                        content_length=response.content_length,
                        reason=response.reason,
                        headers=response.headers,
                        text=text,
                        selector=etree.HTML(text))

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()