예제 #1
0
 async def setUp(self):
     """Start the patched test server and open an HTTP client session.

     Sets ``self.tserver`` (the entered ``ServerRunner.patch(Server)``
     context), ``self.cli`` (the ``Server`` instance being run) and
     ``self.session`` (the entered ``aiohttp.ClientSession``). The raw
     context managers are kept on ``self._tserver`` and ``self._client`` so
     they can be exited later.

     If anything fails after the server runner context has been entered,
     the runner is exited before the exception propagates, so a failed
     setup does not leave it open.
     """
     self._tserver = ServerRunner.patch(Server)
     self.tserver = await self._tserver.__aenter__()
     try:
         # NOTE(review): port=0 presumably asks the OS for a free port —
         # confirm against Server.
         self.cli = Server(port=0, insecure=True)
         await self.tserver.start(self.cli.run())
         # Set up client
         self._client = aiohttp.ClientSession()
         self.session = await self._client.__aenter__()
     except BaseException:
         # The context was entered by hand, so nothing else will unwind it
         # when setup fails part-way through.
         await self._tserver.__aexit__(None, None, None)
         raise
예제 #2
0
 async def setUp(self):
     """Start the patched test server and open an HTTP client session.

     Every context entered here is registered on ``self.exit_stack`` so a
     single exit of the stack releases them in reverse order. Sets
     ``self.tserver``, ``self.cli`` and ``self.session``.

     If any step fails after the stack has been entered, the stack is
     closed before the exception propagates, so contexts that were already
     registered do not stay open.
     """
     self.exit_stack = AsyncExitStack()
     await self.exit_stack.__aenter__()
     try:
         self.tserver = await self.exit_stack.enter_async_context(
             ServerRunner.patch(Server))
         # NOTE(review): port=0 presumably asks the OS for a free port —
         # confirm against Server.
         self.cli = Server(port=0, insecure=True)
         await self.tserver.start(self.cli.run())
         # Set up client
         self.session = await self.exit_stack.enter_async_context(
             aiohttp.ClientSession())
     except BaseException:
         # The stack was entered by hand, so it has to be unwound by hand
         # when setup fails part-way through.
         await self.exit_stack.aclose()
         raise
예제 #3
0
class TestRoutesRunning:
    """Fixture that runs a ``Server`` and offers HTTP helpers against it.

    NOTE(review): this class has no base class here; the ``setUp`` /
    ``tearDown`` names suggest it is combined with an async test case
    elsewhere — confirm against the users of this class.

    Attributes set by ``setUp``:
        tserver: the entered ``ServerRunner.patch(Server)`` context.
        cli: the ``Server`` instance being run.
        session: the entered ``aiohttp.ClientSession``.
    """

    async def setUp(self):
        """Start the patched server, then open the client session.

        If a step fails after the server runner context was entered, the
        runner is exited before the exception propagates.
        """
        self._tserver = ServerRunner.patch(Server)
        self.tserver = await self._tserver.__aenter__()
        try:
            self.cli = Server(port=0, insecure=True)
            await self.tserver.start(self.cli.run())
            # Set up client
            self._client = aiohttp.ClientSession()
            self.session = await self._client.__aenter__()
        except BaseException:
            # Entered by hand, so it must be unwound by hand on failure.
            await self._tserver.__aexit__(None, None, None)
            raise

    async def tearDown(self):
        """Close the client session, then the server runner.

        The server runner is exited even when closing the client raises.
        """
        try:
            await self._client.__aexit__(None, None, None)
        finally:
            await self._tserver.__aexit__(None, None, None)

    @property
    def url(self):
        """Base URL built from ``self.cli.addr`` and ``self.cli.port``."""
        return f"http://{self.cli.addr}:{self.cli.port}"

    async def _raise_for_error(self, r):
        """Raise ``ServerException`` with the body's "error" unless status is OK."""
        if r.status != HTTPStatus.OK:
            raise ServerException((await r.json())["error"])

    @asynccontextmanager
    async def get(self, path):
        """GET ``self.url + path`` and yield the response.

        Raises:
            ServerException: when the response status is not ``HTTPStatus.OK``;
                the message is the ``"error"`` field of the JSON body.
        """
        async with self.session.get(self.url + path) as r:
            await self._raise_for_error(r)
            yield r

    @asynccontextmanager
    async def post(self, path, *args, **kwargs):
        """POST to ``self.url + path`` and yield the response.

        Extra positional and keyword arguments are forwarded unchanged to
        ``self.session.post``.

        Raises:
            ServerException: when the response status is not ``HTTPStatus.OK``;
                the message is the ``"error"`` field of the JSON body.
        """
        async with self.session.post(self.url + path, *args, **kwargs) as r:
            await self._raise_for_error(r)
            yield r
예제 #4
0
class TestRoutesRunning:
    """Fixture that runs a ``Server`` and offers HTTP helpers against it.

    NOTE(review): this class has no base class here; the ``setUp`` /
    ``tearDown`` names suggest it is combined with an async test case
    elsewhere — confirm against the users of this class.

    The source and model helpers read ``self.num_repos``, ``self.slabel``
    and ``self.mlabel``, none of which are set in this class; the user of
    the class must provide them.
    """

    async def setUp(self):
        """Start the patched server and open the client session.

        Every context is registered on ``self.exit_stack``. If a step fails
        after the stack was entered, the stack is closed before the
        exception propagates.
        """
        self.exit_stack = AsyncExitStack()
        await self.exit_stack.__aenter__()
        try:
            self.tserver = await self.exit_stack.enter_async_context(
                ServerRunner.patch(Server)
            )
            self.cli = Server(port=0, insecure=True)
            await self.tserver.start(self.cli.run())
            # Set up client
            self.session = await self.exit_stack.enter_async_context(
                aiohttp.ClientSession()
            )
        except BaseException:
            # Entered by hand, so it must be unwound by hand on failure.
            await self.exit_stack.aclose()
            raise

    async def tearDown(self):
        """Exit ``self.exit_stack``, releasing everything ``setUp`` entered."""
        await self.exit_stack.__aexit__(None, None, None)

    @property
    def url(self):
        """Base URL built from ``self.cli.addr`` and ``self.cli.port``."""
        return f"http://{self.cli.addr}:{self.cli.port}"

    async def _raise_for_error(self, r):
        """Raise ``ServerException`` with the body's "error" unless status is OK."""
        if r.status != HTTPStatus.OK:
            raise ServerException((await r.json())["error"])

    @asynccontextmanager
    async def get(self, path):
        """GET ``self.url + path`` and yield the response.

        Raises:
            ServerException: when the response status is not ``HTTPStatus.OK``;
                the message is the ``"error"`` field of the JSON body.
        """
        async with self.session.get(self.url + path) as r:
            await self._raise_for_error(r)
            yield r

    @asynccontextmanager
    async def post(self, path, *args, **kwargs):
        """POST to ``self.url + path`` and yield the response.

        Extra positional and keyword arguments are forwarded unchanged to
        ``self.session.post``.

        Raises:
            ServerException: when the response status is not ``HTTPStatus.OK``;
                the message is the ``"error"`` field of the JSON body.
        """
        async with self.session.post(self.url + path, *args, **kwargs) as r:
            await self._raise_for_error(r)
            yield r

    @asynccontextmanager
    async def _add_memory_source(self):
        """Register an in-memory source on the running app while the context is open.

        Builds ``self.num_repos`` repos whose ``by_ten`` feature is ten times
        their index, stores the source and its context on ``self`` and in
        ``self.cli.app`` under ``self.slabel``, and also sets
        ``self.features``.
        """
        self.features = Features(DefFeature("by_ten", int, 1))
        async with MemorySource(
            MemorySourceConfig(
                repos=[
                    Repo(str(i), data={"features": {"by_ten": i * 10}})
                    for i in range(self.num_repos)
                ]
            )
        ) as source:
            self.source = self.cli.app["sources"][self.slabel] = source
            async with source() as sctx:
                self.sctx = self.cli.app["source_contexts"][self.slabel] = sctx
                yield

    @asynccontextmanager
    async def _add_fake_model(self):
        """Register a ``FakeModel`` on the running app while the context is open.

        Stores the model and its context on ``self`` and in ``self.cli.app``
        under ``self.mlabel``. Reads ``self.features``, which
        ``_add_memory_source`` sets.
        """
        async with FakeModel(BaseConfig()) as model:
            self.model = self.cli.app["models"][self.mlabel] = model
            async with model(self.features) as mctx:
                self.mctx = self.cli.app["model_contexts"][self.mlabel] = mctx
                yield
예제 #5
0
class TestRoutesRunning:
    """Fixture that runs a ``Server`` and offers HTTP helpers against it.

    NOTE(review): this class has no base class here; the ``setUp`` /
    ``tearDown`` names suggest it is combined with an async test case
    elsewhere — confirm against the users of this class.

    The source, model and scorer helpers read ``self.num_records``,
    ``self.slabel``, ``self.mlabel`` and ``self.alabel``, none of which are
    set in this class; the user of the class must provide them.
    """

    async def setUp(self):
        """Start the patched server and open the client session.

        Every context is registered on ``self.exit_stack``. If a step fails
        after the stack was entered, the stack is closed before the
        exception propagates.
        """
        self.exit_stack = contextlib.AsyncExitStack()
        await self.exit_stack.__aenter__()
        try:
            self.tserver = await self.exit_stack.enter_async_context(
                ServerRunner.patch(Server)
            )
            self.cli = Server(port=0, insecure=True)
            await self.tserver.start(self.cli.run())
            # Set up client
            self.session = await self.exit_stack.enter_async_context(
                aiohttp.ClientSession()
            )
        except BaseException:
            # Entered by hand, so it must be unwound by hand on failure.
            await self.exit_stack.aclose()
            raise

    async def tearDown(self):
        """Exit ``self.exit_stack``, releasing everything ``setUp`` entered."""
        await self.exit_stack.__aexit__(None, None, None)

    @property
    def url(self):
        """Base URL built from ``self.cli.addr`` and ``self.cli.port``."""
        return f"http://{self.cli.addr}:{self.cli.port}"

    def check_allow_caching(self, r):
        """Check that ``r`` has every ``DISALLOW_CACHING`` header and value.

        NOTE(review): the method name says "allow", but the headers checked
        come from ``DISALLOW_CACHING``; the name is kept for callers.

        Raises:
            Exception: when a header is missing from ``r.headers`` or its
                value differs from the expected one.
        """
        for header, should_be in DISALLOW_CACHING.items():
            if header not in r.headers:
                raise Exception(f"No cache header {header} not in {r.headers}")
            if r.headers[header] != should_be:
                raise Exception(
                    f"No cache header {header} should have been {should_be!r} "
                    f"but was {r.headers[header]!r}"
                )

    async def _check_response(self, r):
        """Run the cache-header check, then raise ``ServerException`` unless OK."""
        self.check_allow_caching(r)
        if r.status != HTTPStatus.OK:
            raise ServerException((await r.json())["error"])

    @contextlib.asynccontextmanager
    async def get(self, path):
        """GET ``self.url + path`` and yield the response.

        Raises:
            Exception: when the cache-header check fails.
            ServerException: when the response status is not ``HTTPStatus.OK``;
                the message is the ``"error"`` field of the JSON body.
        """
        async with self.session.get(self.url + path) as r:
            await self._check_response(r)
            yield r

    @contextlib.asynccontextmanager
    async def post(self, path, *args, **kwargs):
        """POST to ``self.url + path`` and yield the response.

        Extra positional and keyword arguments are forwarded unchanged to
        ``self.session.post``.

        Raises:
            Exception: when the cache-header check fails.
            ServerException: when the response status is not ``HTTPStatus.OK``;
                the message is the ``"error"`` field of the JSON body.
        """
        async with self.session.post(self.url + path, *args, **kwargs) as r:
            await self._check_response(r)
            yield r

    @contextlib.asynccontextmanager
    async def _add_memory_source(self):
        """Register an in-memory source on the running app while the context is open.

        Builds ``self.num_records`` records whose ``by_ten`` feature is ten
        times their index, and stores the source and its context on ``self``
        and in ``self.cli.app`` under ``self.slabel``.
        """
        async with MemorySource(
            records=[
                Record(str(i), data={"features": {"by_ten": i * 10}})
                for i in range(self.num_records)
            ]
        ) as source:
            self.source = self.cli.app["sources"][self.slabel] = source
            async with source() as sctx:
                self.sctx = self.cli.app["source_contexts"][self.slabel] = sctx
                yield

    @contextlib.asynccontextmanager
    async def _add_fake_model(self):
        """Register a ``FakeModel`` on the running app while the context is open.

        The model's ``location`` is a temporary directory that is removed
        when the context exits. The model and its context are stored on
        ``self`` and in ``self.cli.app`` under ``self.mlabel``.
        """
        with tempfile.TemporaryDirectory() as tempdir:
            async with FakeModel(
                FakeModelConfig(
                    location=tempdir,
                    features=Features(Feature("by_ten")),
                    predict=Feature("by_ten"),
                )
            ) as model:
                self.model = self.cli.app["models"][self.mlabel] = model
                async with model() as mctx:
                    self.mctx = self.cli.app["model_contexts"][self.mlabel] = mctx
                    yield

    @contextlib.asynccontextmanager
    async def _add_fake_scorer(self):
        """Register a ``FakeScorer`` on the running app while the context is open.

        The scorer and its context are stored on ``self`` and in
        ``self.cli.app`` under ``self.alabel``.
        """
        async with FakeScorer(FakeScorerConfig()) as scorer:
            self.scorer = self.cli.app["scorers"][self.alabel] = scorer
            async with scorer() as actx:
                self.actx = self.cli.app["scorer_contexts"][self.alabel] = actx
                yield