async def complete(request) -> web.Response:
    '''
    Return code completion candidates for a code fragment in the session.

    Expects a JSON body with optional "code" and "options" keys and responds
    with {"result": ...} produced by the registry's completion call.

    :raises InvalidAPIParameters: on timeout, malformed JSON, or bad params.
    :raises BackendError: propagated from the registry layer.
    '''
    resp = {
        'result': {
            'status': 'finished',
            'completions': [],
        }
    }
    registry = request.app['registry']
    sess_id = request.match_info['sess_id']
    access_key = request['keypair']['access_key']
    try:
        with _timeout(2):
            params = await request.json(loads=json.loads)
        log.info('COMPLETE(u:{0}, k:{1})', access_key, sess_id)
    except (asyncio.TimeoutError, json.decoder.JSONDecodeError):
        log.warning('COMPLETE: invalid/missing parameters')
        raise InvalidAPIParameters
    try:
        code = params.get('code', '')
        opts = params.get('options', None) or {}
        await registry.increment_session_usage(sess_id, access_key)
        # CONSISTENCY: reuse the `registry` reference obtained above instead
        # of re-reading request.app['registry'] a second time.
        resp['result'] = await registry.get_completions(sess_id, code, opts)
    except AssertionError:
        log.warning('COMPLETE: invalid/missing parameters')
        raise InvalidAPIParameters
    except BackendError:
        log.exception('COMPLETE: exception')
        raise
    return web.json_response(resp, status=200)
async def execute_snippet(request):
    '''
    Execute a code snippet in the kernel named in the URL path.

    API v1 supports query mode only; API v2 supports query/batch modes with
    an optional "options" object.

    :raises InvalidAPIParameters: on timeout, malformed JSON, or an
        unsupported API version.
    '''
    resp = {}
    kern_id = request.match_info['kernel_id']
    try:
        with _timeout(2):
            params = await request.json()
        log.info(f"EXECUTE(u:{request['keypair']['access_key']}, k:{kern_id})")
    except (asyncio.TimeoutError, json.decoder.JSONDecodeError):
        log.warning('EXECUTE: invalid/missing parameters')
        raise InvalidAPIParameters
    try:
        await request.app['registry'].increment_kernel_usage(kern_id)
        api_version = request['api_version']
        if api_version == 1:
            mode = 'query'
            code = params['code']
            opts = {}
        elif api_version == 2:
            mode = params['mode']
            code = params.get('code', '')
            assert mode in ('query', 'batch')
            opts = params.get('options', None) or {}
        else:
            # BUGFIX: an unsupported API version previously fell through and
            # crashed with UnboundLocalError on `mode`; reject it explicitly.
            raise InvalidAPIParameters
        resp['result'] = await request.app['registry'].execute_snippet(
            kern_id, api_version, mode, code, opts)
    except Backend.AiError:
        log.exception('EXECUTE_SNIPPET: API Internal Error')
        raise
    return web.Response(status=200, content_type=_json_type,
                        text=json.dumps(resp))
async def list_files(request) -> web.Response:
    '''List files under the requested path inside the session's container.'''
    try:
        access_key = request['keypair']['access_key']
        sess_id = request.match_info['sess_id']
        with _timeout(2):
            params = await request.json(loads=json.loads)
        path = params.get('path', '.')
        log.info('LIST_FILES (u:{0}, token:{1})', access_key, sess_id)
    except (asyncio.TimeoutError, AssertionError,
            json.decoder.JSONDecodeError) as e:
        log.warning('LIST_FILES: invalid/missing parameters, {0!r}', e)
        raise InvalidAPIParameters(extra_msg=str(e.args[0]))
    payload = {}
    try:
        registry = request.app['registry']
        await registry.increment_session_usage(sess_id, access_key)
        listing = await registry.list_files(sess_id, access_key, path)
        payload.update(listing)
        log.debug('container file list for {0} retrieved', path)
    except BackendError:
        log.exception('LIST_FILES: exception')
        raise
    except Exception:
        # Unexpected failures are reported to the error monitor and mapped
        # to a generic 500 response.
        request.app['error_monitor'].capture_exception()
        log.exception('LIST_FILES: unexpected error!')
        raise InternalServerError
    return web.json_response(payload, status=200)
async def RPCContext(addr, timeout=10):
    '''
    Async generator context that yields an RPC peer connected to *addr*.

    Exceptions raised while the peer is in use are translated: agent-side
    GenericError becomes AgentError; a small set of well-known exceptions
    pass through unchanged; anything else is wrapped into AgentError with
    the original traceback. The connection is always closed on exit.
    '''
    preserved_exceptions = (
        NotFoundError,
        ParametersError,
        asyncio.TimeoutError,
        asyncio.CancelledError,
        asyncio.InvalidStateError,
    )
    server = None
    try:
        server = await aiozmq.rpc.connect_rpc(
            connect=addr,
            error_table={
                'concurrent.futures._base.TimeoutError': asyncio.TimeoutError,
            })
        # Drop unsent messages quickly on close instead of blocking.
        server.transport.setsockopt(zmq.LINGER, 50)
        with _timeout(timeout):
            yield server
    except Exception:
        # BUGFIX: this was a bare `except:`, which also captured
        # BaseException (KeyboardInterrupt, GeneratorExit, ...) and wrapped
        # it into AgentError; those now propagate untouched.  This also
        # matches the narrower handler used by the newer RPCContext variant.
        exc_type, exc, tb = sys.exc_info()
        if issubclass(exc_type, GenericError):
            e = AgentError(exc.args[0], exc.args[1])
            raise e.with_traceback(tb)
        elif issubclass(exc_type, preserved_exceptions):
            raise
        else:
            e = AgentError(exc_type, exc.args)
            raise e.with_traceback(tb)
    finally:
        if server:
            server.close()
async def download_files(request) -> web.Response:
    '''Fetch up to five files from the session container and return them
    as a multipart response of tar archives.'''
    try:
        registry = request.app['registry']
        sess_id = request.match_info['sess_id']
        access_key = request['keypair']['access_key']
        with _timeout(2):
            params = await request.json(loads=json.loads)
        assert params.get('files'), 'no file(s) specified!'
        files = params.get('files')
        log.info('DOWNLOAD_FILE (u:{0}, token:{1})', access_key, sess_id)
    except (asyncio.TimeoutError, AssertionError,
            json.decoder.JSONDecodeError) as e:
        log.warning('DOWNLOAD_FILE: invalid/missing parameters, {0!r}', e)
        raise InvalidAPIParameters(extra_msg=str(e.args[0]))
    try:
        assert len(files) <= 5, 'Too many files'
        await registry.increment_session_usage(sess_id, access_key)
        # TODO: read the file contents in chunks instead of all at once.
        downloads = [registry.download_file(sess_id, access_key, fname)
                     for fname in files]
        results = await asyncio.gather(*downloads)
        log.debug('file(s) inside container retrieved')
    except BackendError:
        log.exception('DOWNLOAD_FILE: exception')
        raise
    except Exception:
        request.app['error_monitor'].capture_exception()
        log.exception('DOWNLOAD_FILE: unexpected error!')
        raise InternalServerError
    with aiohttp.MultipartWriter('mixed') as mpwriter:
        for tarbytes in results:
            mpwriter.append(tarbytes)
        return web.Response(body=mpwriter, status=200)
async def asend(self, *, sess=None, timeout=10.0):
    '''
    Sends the request to the server.

    This method is a coroutine.

    :param sess: an existing aiohttp.ClientSession to use; a new one is
        created when None.
    :param timeout: seconds to wait for the whole request/response cycle.
    :returns: a Response wrapping status, reason, body, content type, and
        body length.
    :raises BackendClientError: when any step of the HTTP exchange fails.
    '''
    assert self.method in self._allowed_methods
    if sess is None:
        sess = aiohttp.ClientSession()
    else:
        assert isinstance(sess, aiohttp.ClientSession)
    # NOTE(review): `with sess:` closes the session on exit -- including a
    # caller-supplied one; confirm callers expect their session consumed.
    with sess:
        if self.content_type == 'multipart/form-data':
            # Build a multipart body with one attachment part per file.
            with aiohttp.MultipartWriter('mixed') as mpwriter:
                for file in self._content:
                    part = mpwriter.append(file.file)
                    part.set_content_disposition('attachment',
                                                 filename=file.filename)
            data = mpwriter
        else:
            data = self._content
        # Sign after the payload is finalized so the signature matches it.
        self._sign()
        reqfunc = getattr(sess, self.method.lower())
        try:
            with _timeout(timeout):
                resp = await reqfunc(self.build_url(),
                                     data=data,
                                     headers=self.headers)
                async with resp:
                    body = await resp.read()
                    return Response(resp.status, resp.reason,
                                    body, resp.content_type,
                                    len(body))
        except Exception as e:
            msg = 'Request to the API endpoint has failed.'
            raise BackendClientError(msg) from e
async def __aenter__(self):
    '''Connect the RPC peer for this context and arm the call timeout.'''
    peer = await aiozmq.rpc.connect_rpc(
        connect=self.addr,
        error_table={
            'concurrent.futures._base.TimeoutError': asyncio.TimeoutError,
        })
    # Avoid lingering on unsent messages when the socket closes.
    peer.transport.setsockopt(zmq.LINGER, 50)
    self.server = peer
    self.call = peer.call
    self.t = _timeout(self.timeout)
    self.t.__enter__()
    return self
async def get_watcher_status(request: web.Request, params: Any) -> web.Response:
    '''Proxy a status query to the watcher daemon of the given agent.'''
    log.info('GET_WATCHER_STATUS ()')
    watcher_info = await get_watcher_info(request, params['agent_id'])
    connector = aiohttp.TCPConnector()
    async with aiohttp.ClientSession(connector=connector) as sess:
        with _timeout(5.0):
            headers = {'X-BackendAI-Watcher-Token': watcher_info['token']}
            async with sess.get(watcher_info['addr'], headers=headers) as resp:
                if resp.status == 200:
                    data = await resp.json()
                    return web.json_response(data, status=resp.status)
                # Non-200: relay the watcher's error body verbatim.
                data = await resp.text()
                return web.Response(text=data, status=resp.status)
async def watcher_agent_restart(request: web.Request, params: Any) -> web.Response:
    '''Ask the agent's watcher daemon to restart the agent service.'''
    log.info('WATCHER_AGENT_RESTART ()')
    watcher_info = await get_watcher_info(request, params['agent_id'])
    connector = aiohttp.TCPConnector()
    async with aiohttp.ClientSession(connector=connector) as sess:
        with _timeout(20.0):
            watcher_url = watcher_info['addr'] / 'agent/restart'
            headers = {'X-BackendAI-Watcher-Token': watcher_info['token']}
            async with sess.post(watcher_url, headers=headers) as resp:
                if resp.status != 200:
                    # Relay the watcher's error body verbatim.
                    data = await resp.text()
                    return web.Response(text=data, status=resp.status)
                data = await resp.json()
                return web.json_response(data, status=resp.status)
async def create(request):
    '''
    Get or create a compute kernel for the given client session token,
    enforcing the keypair's concurrency quota inside one DB transaction.

    :raises InvalidAPIParameters: on timeout, malformed JSON, bad token
        length, or an unsupported API version.
    :raises QuotaExceeded: when the keypair's concurrency limit is reached.
    '''
    try:
        with _timeout(2):
            params = await request.json(loads=json.loads)
        log.info(f"GET_OR_CREATE (u:{request['keypair']['access_key']}, "
                 f"lang:{params['lang']}, token:{params['clientSessionToken']})")
        assert 8 <= len(params['clientSessionToken']) <= 80
    except (asyncio.TimeoutError, AssertionError, KeyError,
            json.decoder.JSONDecodeError) as e:
        log.warning(f'GET_OR_CREATE: invalid/missing parameters, {e!r}')
        raise InvalidAPIParameters
    resp = {}
    try:
        access_key = request['keypair']['access_key']
        concurrency_limit = request['keypair']['concurrency_limit']
        async with request.app['dbpool'].acquire() as conn, conn.begin():
            # Row-lock the keypair (FOR UPDATE) so concurrent creates
            # cannot both pass the quota check below.
            query = (sa.select([keypairs.c.concurrency_used], for_update=True)
                     .select_from(keypairs)
                     .where(keypairs.c.access_key == access_key))
            concurrency_used = await conn.scalar(query)
            log.debug(f'access_key: {access_key} '
                      f'({concurrency_used} / {concurrency_limit})')
            if concurrency_used >= concurrency_limit:
                raise QuotaExceeded
            if request['api_version'] == 1:
                limits = params.get('resourceLimits', None)
                mounts = None
            elif request['api_version'] in (2, 3):
                limits = params.get('limits', None)
                mounts = params.get('mounts', None)
            else:
                # BUGFIX: unknown API versions previously fell through and
                # crashed with UnboundLocalError on `limits`.
                raise InvalidAPIParameters
            kernel, created = await request.app['registry'].get_or_create_kernel(
                params['clientSessionToken'], params['lang'], access_key,
                limits, mounts, conn=conn)
            resp['kernelId'] = str(kernel['sess_id'])
            if created:
                # Account the new kernel against the keypair's quota.
                query = (sa.update(keypairs)
                         .values(concurrency_used=keypairs.c.concurrency_used + 1)
                         .where(keypairs.c.access_key == access_key))
                await conn.execute(query)
    except BackendError:
        log.exception('GET_OR_CREATE: exception')
        raise
    return web.json_response(resp, status=201, dumps=json.dumps)
async def watcher_agent_stop(request: web.Request, params: Any) -> web.Response:
    '''Ask the agent's watcher daemon to stop the agent service.'''
    access_key = request['keypair']['access_key']
    user_uuid = request['user']['uuid']
    log.info('WATCHER_AGENT_STOP (u:[{0}], ak:[{1}])', user_uuid, access_key)
    watcher_info = await get_watcher_info(request, params['agent_id'])
    connector = aiohttp.TCPConnector()
    async with aiohttp.ClientSession(connector=connector) as sess:
        with _timeout(20.0):
            watcher_url = watcher_info['addr'] / 'agent/stop'
            headers = {'X-BackendAI-Watcher-Token': watcher_info['token']}
            async with sess.post(watcher_url, headers=headers) as resp:
                if resp.status != 200:
                    # Relay the watcher's error body verbatim.
                    data = await resp.text()
                    return web.Response(text=data, status=resp.status)
                data = await resp.json()
                return web.json_response(data, status=resp.status)
async def create(request):
    '''
    Get or create a compute kernel for the given client session token,
    enforcing the keypair's concurrency quota inside one DB transaction.

    :raises InvalidAPIParameters: on timeout, malformed JSON, bad token
        length, or an unsupported API version.
    :raises QuotaExceeded: when the keypair's concurrency limit is reached.
    '''
    try:
        with _timeout(2):
            params = await request.json()
        log.info(f"GET_OR_CREATE (u:{request['keypair']['access_key']}, "
                 f"lang:{params['lang']}, token:{params['clientSessionToken']})")
        assert 8 <= len(params['clientSessionToken']) <= 80
    except (asyncio.TimeoutError, AssertionError, KeyError,
            json.decoder.JSONDecodeError) as e:
        log.warning(f'GET_OR_CREATE: invalid/missing parameters, {e!r}')
        raise InvalidAPIParameters
    resp = {}
    try:
        access_key = request['keypair']['access_key']
        concurrency_limit = request['keypair']['concurrency_limit']
        async with request.app.dbpool.acquire() as conn, conn.transaction():
            # Row-lock the keypair (FOR UPDATE) so concurrent creates
            # cannot both pass the quota check below.
            query = (sa.select([KeyPair.c.concurrency_used], for_update=True)
                     .select_from(KeyPair)
                     .where(KeyPair.c.access_key == access_key))
            concurrency_used = await conn.fetchval(query)
            log.debug(f'access_key: {access_key} '
                      f'({concurrency_used} / {concurrency_limit})')
            if concurrency_used >= concurrency_limit:
                raise QuotaExceeded
            if request['api_version'] == 1:
                limits = params.get('resourceLimits', None)
                mounts = None
            elif request['api_version'] == 2:
                limits = params.get('limits', None)
                mounts = params.get('mounts', None)
            else:
                # BUGFIX: unknown API versions previously fell through and
                # crashed with UnboundLocalError on `limits`.
                raise InvalidAPIParameters
            kernel, created = await request.app['registry'].get_or_create_kernel(
                params['clientSessionToken'], params['lang'], access_key,
                limits, mounts)
            resp['kernelId'] = kernel.id
            if created:
                # Account the new kernel against the keypair's quota.
                query = (sa.update(KeyPair)
                         .values(concurrency_used=KeyPair.c.concurrency_used + 1)
                         .where(KeyPair.c.access_key == access_key))
                await conn.execute(query)
    except Backend.AiError:
        log.exception('GET_OR_CREATE: API Internal Error')
        raise
    return web.Response(status=201, content_type=_json_type,
                        text=json.dumps(resp))
async def RPCContext(addr, timeout=None):
    '''
    Async generator context that yields a cached RPC peer for *addr*,
    translating agent-side errors into AgentError.

    Peers are cached in the module-level `agent_peers` dict and reused
    across calls; connections are NOT closed here.
    '''
    preserved_exceptions = (
        NotFoundError,
        ParametersError,
        asyncio.TimeoutError,
        asyncio.CancelledError,
        asyncio.InvalidStateError,
    )
    global agent_peers
    # Reuse an existing connection to this address if one is cached.
    peer = agent_peers.get(addr, None)
    if peer is None:
        peer = await aiozmq.rpc.connect_rpc(
            connect=addr,
            error_table={
                'concurrent.futures._base.TimeoutError': asyncio.TimeoutError,
            })
        # Allow up to 1s for unsent messages to flush on close.
        peer.transport.setsockopt(zmq.LINGER, 1000)
        agent_peers[addr] = peer
    try:
        with _timeout(timeout):
            yield peer
    except Exception:
        exc_type, exc, tb = sys.exc_info()
        if issubclass(exc_type, GenericError):
            # Remote-side error forwarded by aiozmq: re-wrap, keeping
            # the original traceback.
            e = AgentError(exc.args[0], exc.args[1])
            raise e.with_traceback(tb)
        elif issubclass(exc_type, TypeError):
            if exc.args[0] == "'NoneType' object is not iterable":
                # Deliberately swallowed (no re-raise): this specific
                # TypeError signature means the agent-side result vanished
                # mid-flight.
                log.warning('The agent has cancelled the operation '
                            'or the kernel has terminated too quickly.')
                # In this case, you may need to use "--debug-skip-container-deletion"
                # CLI option in the agent and check out the container logs via
                # "docker logs" command to see what actually happened.
            else:
                e = AgentError(exc_type, exc.args)
                raise e.with_traceback(tb)
        elif issubclass(exc_type, preserved_exceptions):
            # Well-known exceptions pass through unchanged.
            raise
        else:
            e = AgentError(exc_type, exc.args)
            raise e.with_traceback(tb)
async def execute(request):
    '''
    Execute code in the target session.

    API v1 supports query mode only; v2/v3 support query/batch/complete
    ("complete" is kept for legacy clients and routed to get_completions).

    :raises InvalidAPIParameters: on timeout, malformed JSON, missing
        fields, or an unsupported API version.
    '''
    resp = {}
    sess_id = request.match_info['sess_id']
    try:
        with _timeout(2):
            params = await request.json(loads=json.loads)
        log.info(f"EXECUTE(u:{request['keypair']['access_key']}, k:{sess_id})")
    except (asyncio.TimeoutError, json.decoder.JSONDecodeError):
        log.warning('EXECUTE: invalid/missing parameters')
        raise InvalidAPIParameters
    try:
        await request.app['registry'].increment_session_usage(sess_id)
        api_version = request['api_version']
        if api_version == 1:
            mode = 'query'
            code = params['code']
            run_id = params.get('runId', secrets.token_hex(8))
            opts = {}
        elif api_version in (2, 3):
            # CLEANUP: the original assigned `mode = params['mode']` twice.
            mode = params['mode']
            code = params.get('code', '')
            run_id = params.get('runId', secrets.token_hex(8))
            assert mode in ('query', 'batch', 'complete')
            opts = params.get('options', None) or {}
        else:
            # BUGFIX: unknown API versions previously fell through and
            # crashed with UnboundLocalError on `mode`.
            raise InvalidAPIParameters
        if mode == 'complete':
            # For legacy
            resp['result'] = await request.app['registry'].get_completions(
                sess_id, 'query', code, opts)
        else:
            resp['result'] = await request.app['registry'].execute(
                sess_id, api_version, run_id, mode, code, opts)
    except (AssertionError, KeyError):
        # KeyError covers a missing 'code' (v1) or 'mode' (v2/v3) field,
        # which previously escaped as an internal server error.
        log.warning('EXECUTE: invalid/missing parameters')
        raise InvalidAPIParameters
    except BackendError:
        log.exception('EXECUTE: exception')
        raise
    return web.json_response(resp, status=200, dumps=json.dumps)
async def _query(self, path, method='GET', params=None, timeout=None,
                 data=None, headers=None, **kwargs):
    '''
    Get the response object by performing the HTTP request.

    The caller is responsible to finalize the response object.

    :raises asyncio.TimeoutError: when the request exceeds *timeout*.
    :raises DockerError: for 4xx/5xx responses, carrying the decoded
        JSON error payload.
    '''
    url = self._endpoint(path)
    # CLEANUP: removed a no-op `except asyncio.TimeoutError: raise`
    # handler -- the timeout now propagates naturally.
    with _timeout(timeout):
        response = await self.session.request(
            method, url,
            params=httpize(params),
            headers=headers,
            data=data,
            **kwargs)
    if (response.status // 100) in [4, 5]:
        # Error responses: consume and close the body, then surface the
        # decoded JSON payload via DockerError.
        what = await response.read()
        response.close()
        raise DockerError(response.status, json.loads(what.decode('utf8')))
    return response
async def complete(request):
    '''Return code completion candidates for the given session.'''
    sess_id = request.match_info['sess_id']
    resp = {}
    try:
        with _timeout(2):
            params = await request.json(loads=json.loads)
        log.info(f"COMPLETE(u:{request['keypair']['access_key']}, k:{sess_id})")
    except (asyncio.TimeoutError, json.decoder.JSONDecodeError):
        log.warning('COMPLETE: invalid/missing parameters')
        raise InvalidAPIParameters
    try:
        snippet = params.get('code', '')
        options = params.get('options', None) or {}
        registry = request.app['registry']
        await registry.increment_session_usage(sess_id)
        resp['result'] = await registry.get_completions(
            sess_id, 'query', snippet, options)
    except AssertionError:
        log.warning('COMPLETE: invalid/missing parameters')
        raise InvalidAPIParameters
    except BackendError:
        log.exception('COMPLETE: exception')
        raise
    return web.json_response(resp, status=200, dumps=json.dumps)
async def inner_timeout():
    # Run the wrapped coroutine under the closure's timeout context.
    # (`timeout` and `coroutine` are captured from the enclosing scope.)
    async with _timeout(timeout):
        return await coroutine
    # NOTE(review): this is reachable only if the timeout context manager
    # suppresses the in-flight exception/cancellation -- TODO confirm the
    # `_timeout` implementation's semantics; otherwise it is dead code.
    return None
async def create(request):
    '''
    Get or create a compute kernel for the given client session token,
    enforcing the keypair's concurrency quota and resolving any requested
    vfolder mounts, all within one DB transaction.
    '''
    # Parse and validate the request body.
    try:
        with _timeout(2):
            params = await request.json(loads=json.loads)
        log.info(
            f"GET_OR_CREATE (u:{request['keypair']['access_key']}, "
            f"lang:{params['lang']}, token:{params['clientSessionToken']})")
        assert 8 <= len(params['clientSessionToken']) <= 64, \
            'clientSessionToken is too short or long (8 to 64 bytes required)!'
    except (asyncio.TimeoutError, AssertionError, KeyError,
            json.decoder.JSONDecodeError) as e:
        log.warning(f'GET_OR_CREATE: invalid/missing parameters, {e!r}')
        raise InvalidAPIParameters
    resp = {}
    try:
        access_key = request['keypair']['access_key']
        concurrency_limit = request['keypair']['concurrency_limit']
        async with request.app['dbpool'].acquire() as conn, conn.begin():
            # Row-lock the keypair (FOR UPDATE) so concurrent creates
            # cannot both pass the quota check below.
            query = (sa.select([keypairs.c.concurrency_used],
                               for_update=True).select_from(keypairs).where(
                keypairs.c.access_key == access_key))
            concurrency_used = await conn.scalar(query)
            log.debug(f'access_key: {access_key} '
                      f'({concurrency_used} / {concurrency_limit})')
            if concurrency_used >= concurrency_limit:
                raise QuotaExceeded
            # Default creation config; API v2/v3 may override via 'config'.
            creation_config = {
                'mounts': None,
                'environ': None,
                'clusterSize': None,
                'instanceMemory': None,
                'instanceCores': None,
                'instanceGPUs': None,
            }
            if request['api_version'] == 1:
                # custom resource limit unsupported
                pass
            elif request['api_version'] in (2, 3):
                creation_config.update(params.get('config', {}))
                # sanity check for vfolders
                if creation_config['mounts']:
                    # Resolve each requested mount name to the owner's
                    # vfolder row; unknown names abort the whole request.
                    mount_details = []
                    for mount in creation_config['mounts']:
                        query = (sa.select('*').select_from(
                            vfolders).where((vfolders.c.belongs_to == access_key) &
                                            (vfolders.c.name == mount)))
                        result = await conn.execute(query)
                        row = await result.first()
                        if row is None:
                            raise FolderNotFound
                        else:
                            mount_details.append([row.name, row.host, row.id.hex])
                    creation_config['mounts'] = mount_details
            kernel, created = await request.app[
                'registry'].get_or_create_kernel(params['clientSessionToken'],
                                                 params['lang'], access_key,
                                                 creation_config, conn=conn)
            resp['kernelId'] = str(kernel['sess_id'])
            if created:
                # Account the new kernel against the keypair's quota.
                query = (sa.update(keypairs).values(
                    concurrency_used=keypairs.c.concurrency_used + 1).where(
                    keypairs.c.access_key == access_key))
                await conn.execute(query)
    except BackendError:
        log.exception('GET_OR_CREATE: exception')
        raise
    return web.json_response(resp, status=201, dumps=json.dumps)