Ejemplo n.º 1
0
async def complete(request) -> web.Response:
    """Handle a code-completion request for a session.

    Reads the JSON request body (bounded by a 2-second timeout), bumps the
    session usage counter through the registry and returns the registry's
    completion result as a JSON response with status 200.

    Raises:
        InvalidAPIParameters: the body could not be read or decoded in time,
            or an ``AssertionError`` surfaced while talking to the registry.
        BackendError: logged and re-raised unchanged.
    """
    # Placeholder result; it is replaced by the registry's answer on success.
    payload = {
        'result': {
            'status': 'finished',
            'completions': [],
        }
    }
    registry = request.app['registry']
    sess_id = request.match_info['sess_id']
    access_key = request['keypair']['access_key']

    try:
        with _timeout(2):
            body = await request.json(loads=json.loads)
        log.info('COMPLETE(u:{0}, k:{1})', access_key, sess_id)
    except (asyncio.TimeoutError, json.decoder.JSONDecodeError):
        log.warning('COMPLETE: invalid/missing parameters')
        raise InvalidAPIParameters

    try:
        source = body.get('code', '')
        # A null or missing "options" value is normalised to an empty dict.
        options = body.get('options', None) or {}
        await registry.increment_session_usage(sess_id, access_key)
        payload['result'] = await registry.get_completions(
            sess_id, source, options)
    except AssertionError:
        log.warning('COMPLETE: invalid/missing parameters')
        raise InvalidAPIParameters
    except BackendError:
        log.exception('COMPLETE: exception')
        raise
    return web.json_response(payload, status=200)
Ejemplo n.º 2
0
async def execute_snippet(request):
    """Execute a code snippet on a kernel and return the result as JSON.

    The JSON body is read with a 2-second timeout.  API version 1 requires
    ``code`` and always runs in ``'query'`` mode; API version 2 requires
    ``mode`` (``'query'`` or ``'batch'``) and accepts optional ``code`` and
    ``options``.

    Raises:
        InvalidAPIParameters: the body could not be read or decoded, a
            required field is missing, the mode is not allowed, or the API
            version is not supported.
        Backend.AiError: logged and re-raised unchanged.
    """
    resp = {}
    kern_id = request.match_info['kernel_id']
    try:
        with _timeout(2):
            params = await request.json()
        log.info(f"EXECUTE(u:{request['keypair']['access_key']}, k:{kern_id})")
    except (asyncio.TimeoutError, json.decoder.JSONDecodeError):
        log.warning('EXECUTE: invalid/missing parameters')
        raise InvalidAPIParameters
    try:
        await request.app['registry'].increment_kernel_usage(kern_id)
        api_version = request['api_version']
        # Validation failures are raised explicitly (rather than via
        # ``assert``) so they survive ``python -O`` and are reported as a
        # client error instead of an unhandled KeyError/UnboundLocalError.
        if api_version == 1:
            if 'code' not in params:
                raise AssertionError("'code' is required")
            mode = 'query'
            code = params['code']
            opts = {}
        elif api_version == 2:
            mode = params.get('mode')
            if mode not in ('query', 'batch'):
                raise AssertionError(f'invalid or missing mode: {mode!r}')
            code = params.get('code', '')
            opts = params.get('options', None) or {}
        else:
            raise AssertionError(f'unsupported API version: {api_version!r}')
        resp['result'] = await request.app['registry'].execute_snippet(
            kern_id, api_version, mode, code, opts)
    except AssertionError:
        log.warning('EXECUTE: invalid/missing parameters')
        raise InvalidAPIParameters
    except Backend.AiError:
        log.exception('EXECUTE_SNIPPET: API Internal Error')
        raise
    return web.Response(status=200, content_type=_json_type,
                        text=json.dumps(resp))
Ejemplo n.º 3
0
async def list_files(request) -> web.Response:
    """List files under a path inside a session's container.

    The JSON body is read with a 2-second timeout; ``path`` defaults to
    ``'.'``.  The registry's result is merged into the JSON response
    (status 200).

    Raises:
        InvalidAPIParameters: the body could not be read or decoded in time.
        BackendError: logged and re-raised unchanged.
        InternalServerError: any other exception from the registry calls,
            after reporting it to the application's error monitor.
    """
    try:
        access_key = request['keypair']['access_key']
        sess_id = request.match_info['sess_id']
        with _timeout(2):
            params = await request.json(loads=json.loads)
        path = params.get('path', '.')
        log.info('LIST_FILES (u:{0}, token:{1})', access_key, sess_id)
    except (asyncio.TimeoutError, AssertionError,
            json.decoder.JSONDecodeError) as e:
        log.warning('LIST_FILES: invalid/missing parameters, {0!r}', e)
        # A timeout (or a bare assertion) carries no args; indexing args[0]
        # unconditionally would raise IndexError inside this handler.
        reason = str(e.args[0]) if e.args else type(e).__name__
        raise InvalidAPIParameters(extra_msg=reason)
    resp = {}
    try:
        registry = request.app['registry']
        await registry.increment_session_usage(sess_id, access_key)
        result = await registry.list_files(sess_id, access_key, path)
        resp.update(result)
        log.debug('container file list for {0} retrieved', path)
    except BackendError:
        log.exception('LIST_FILES: exception')
        raise
    except Exception:
        request.app['error_monitor'].capture_exception()
        log.exception('LIST_FILES: unexpected error!')
        raise InternalServerError
    return web.json_response(resp, status=200)
Ejemplo n.º 4
0
async def RPCContext(addr, timeout=10):
    """Yield an aiozmq RPC client connected to ``addr``.

    This is an async generator that yields exactly once; it is presumably
    wrapped by an async context-manager decorator outside this view (TODO
    confirm).  The body using the yielded client runs inside
    ``_timeout(timeout)`` and the client is always closed afterwards.

    Exception translation for errors raised while connecting or in the body:
      * ``GenericError`` becomes ``AgentError(args[0], args[1])``;
      * the exceptions in ``preserved_exceptions`` are re-raised unchanged;
      * any other ``Exception`` becomes ``AgentError(type, args)``.
    """
    preserved_exceptions = (
        NotFoundError,
        ParametersError,
        asyncio.TimeoutError,
        asyncio.CancelledError,
        asyncio.InvalidStateError,
    )
    server = None
    try:
        server = await aiozmq.rpc.connect_rpc(
            connect=addr,
            error_table={
                'concurrent.futures._base.TimeoutError': asyncio.TimeoutError,
            })
        server.transport.setsockopt(zmq.LINGER, 50)
        with _timeout(timeout):
            yield server
    except Exception as exc:
        # Deliberately not a bare ``except:``.  GeneratorExit,
        # KeyboardInterrupt and SystemExit must propagate untouched;
        # wrapping them in AgentError breaks generator shutdown and
        # interpreter exit.
        tb = exc.__traceback__
        if isinstance(exc, GenericError):
            e = AgentError(exc.args[0], exc.args[1])
            raise e.with_traceback(tb)
        elif isinstance(exc, preserved_exceptions):
            raise
        else:
            e = AgentError(type(exc), exc.args)
            raise e.with_traceback(tb)
    finally:
        if server is not None:
            server.close()
Ejemplo n.º 5
0
async def download_files(request) -> web.Response:
    """Download up to five files from a session's container.

    The JSON body is read with a 2-second timeout and must carry a non-empty
    ``files`` list.  Each file is fetched concurrently through the registry
    and appended to a ``multipart/mixed`` response body (status 200).

    Raises:
        InvalidAPIParameters: the body could not be read or decoded, no
            files were given, or more than five files were requested.
        BackendError: logged and re-raised unchanged.
        InternalServerError: any other exception from the registry calls,
            after reporting it to the application's error monitor.
    """
    try:
        registry = request.app['registry']
        sess_id = request.match_info['sess_id']
        access_key = request['keypair']['access_key']
        with _timeout(2):
            params = await request.json(loads=json.loads)
        files = params.get('files')
        assert files, 'no file(s) specified!'
        log.info('DOWNLOAD_FILE (u:{0}, token:{1})', access_key, sess_id)
        # Checked here, with the other request validation, so an oversized
        # request is a client error rather than an InternalServerError that
        # also gets sent to the error monitor.
        assert len(files) <= 5, 'Too many files'
    except (asyncio.TimeoutError, AssertionError,
            json.decoder.JSONDecodeError) as e:
        log.warning('DOWNLOAD_FILE: invalid/missing parameters, {0!r}', e)
        # A timeout carries no args; indexing args[0] unconditionally would
        # raise IndexError inside this handler.
        reason = str(e.args[0]) if e.args else type(e).__name__
        raise InvalidAPIParameters(extra_msg=reason)

    try:
        await registry.increment_session_usage(sess_id, access_key)
        # TODO: Read all download file contents. Need to fix by using chunking, etc.
        results = await asyncio.gather(*map(
            functools.partial(registry.download_file, sess_id, access_key),
            files))
        log.debug('file(s) inside container retrieved')
    except BackendError:
        log.exception('DOWNLOAD_FILE: exception')
        raise
    except Exception:
        request.app['error_monitor'].capture_exception()
        log.exception('DOWNLOAD_FILE: unexpected error!')
        raise InternalServerError

    with aiohttp.MultipartWriter('mixed') as mpwriter:
        for tarbytes in results:
            mpwriter.append(tarbytes)
        return web.Response(body=mpwriter, status=200)
Ejemplo n.º 6
0
    async def asend(self, *, sess=None, timeout=10.0):
        '''
        Send this request to the server and return a ``Response``.

        This method is a coroutine.  When ``sess`` is omitted a new
        ``aiohttp.ClientSession`` is created; in both cases the session is
        used as a context manager for the duration of the call.  For
        ``multipart/form-data`` content every item of ``self._content`` is
        appended as an attachment part.

        Any exception raised while sending or reading (within ``timeout``)
        is re-raised as ``BackendClientError`` chained to the original.
        '''
        assert self.method in self._allowed_methods
        if sess is not None:
            assert isinstance(sess, aiohttp.ClientSession)
        else:
            sess = aiohttp.ClientSession()
        with sess:
            if self.content_type != 'multipart/form-data':
                payload = self._content
            else:
                with aiohttp.MultipartWriter('mixed') as payload:
                    for attachment in self._content:
                        part = payload.append(attachment.file)
                        part.set_content_disposition(
                            'attachment', filename=attachment.filename)
            self._sign()
            send = getattr(sess, self.method.lower())
            try:
                with _timeout(timeout):
                    resp = await send(self.build_url(),
                                      data=payload,
                                      headers=self.headers)
                    async with resp:
                        raw = await resp.read()
                        return Response(resp.status, resp.reason, raw,
                                        resp.content_type, len(raw))
            except Exception as e:
                raise BackendClientError(
                    'Request to the API endpoint has failed.') from e
Ejemplo n.º 7
0
 async def __aenter__(self):
     """Connect to the RPC peer at ``self.addr`` and start the timeout.

     Stores the client on ``self.server``, exposes its ``call`` attribute as
     ``self.call``, and enters a ``_timeout(self.timeout)`` context kept on
     ``self.t`` (presumably exited by the matching ``__aexit__`` - confirm).
     """
     rpc_errors = {
         'concurrent.futures._base.TimeoutError': asyncio.TimeoutError,
     }
     peer = await aiozmq.rpc.connect_rpc(connect=self.addr,
                                         error_table=rpc_errors)
     self.server = peer
     peer.transport.setsockopt(zmq.LINGER, 50)
     self.call = peer.call
     timer = _timeout(self.timeout)
     self.t = timer
     # Entered manually because the scope spans __aenter__ and the later exit.
     timer.__enter__()
     return self
Ejemplo n.º 8
0
async def get_watcher_status(request: web.Request,
                             params: Any) -> web.Response:
    """Relay a GET to the watcher of ``params['agent_id']``.

    The watcher's address and token come from ``get_watcher_info``.  A 200
    reply is relayed as JSON; any other status is relayed as plain text with
    the same status code.  The outbound call is bounded by a 5-second timeout.
    """
    log.info('GET_WATCHER_STATUS ()')
    info = await get_watcher_info(request, params['agent_id'])
    tcp = aiohttp.TCPConnector()
    async with aiohttp.ClientSession(connector=tcp) as sess:
        with _timeout(5.0):
            auth = {'X-BackendAI-Watcher-Token': info['token']}
            async with sess.get(info['addr'], headers=auth) as resp:
                if resp.status != 200:
                    body = await resp.text()
                    return web.Response(text=body, status=resp.status)
                payload = await resp.json()
                return web.json_response(payload, status=resp.status)
Ejemplo n.º 9
0
async def watcher_agent_restart(request: web.Request,
                                params: Any) -> web.Response:
    """POST to the ``agent/restart`` endpoint of an agent's watcher.

    The watcher's address and token come from ``get_watcher_info``.  A 200
    reply is relayed as JSON; any other status is relayed as plain text with
    the same status code.  The outbound call is bounded by a 20-second timeout.
    """
    log.info('WATCHER_AGENT_RESTART ()')
    info = await get_watcher_info(request, params['agent_id'])
    tcp = aiohttp.TCPConnector()
    async with aiohttp.ClientSession(connector=tcp) as sess:
        with _timeout(20.0):
            endpoint = info['addr'] / 'agent/restart'
            auth = {'X-BackendAI-Watcher-Token': info['token']}
            async with sess.post(endpoint, headers=auth) as resp:
                if resp.status != 200:
                    body = await resp.text()
                    return web.Response(text=body, status=resp.status)
                payload = await resp.json()
                return web.json_response(payload, status=resp.status)
Ejemplo n.º 10
0
async def create(request):
    """Get or create a kernel session for the requesting keypair.

    The JSON body is read with a 2-second timeout and must carry ``lang``
    and a ``clientSessionToken`` of 8 to 80 characters.  Inside a database
    transaction the keypair's ``concurrency_used`` is selected with
    ``for_update=True`` and checked against its limit, the registry is asked
    to get or create the kernel, and the counter is incremented only when a
    kernel was newly created.  Responds with status 201 and ``kernelId``.

    Raises:
        InvalidAPIParameters: the body is unreadable, malformed, or lacks
            the required fields.
        QuotaExceeded: the keypair already uses its full concurrency limit.
        BackendError: logged and re-raised unchanged.
    """
    try:
        with _timeout(2):
            params = await request.json(loads=json.loads)
        log.info(f"GET_OR_CREATE (u:{request['keypair']['access_key']}, "
                 f"lang:{params['lang']}, token:{params['clientSessionToken']})")
        assert 8 <= len(params['clientSessionToken']) <= 80
    except (asyncio.TimeoutError, AssertionError,
            KeyError, json.decoder.JSONDecodeError) as e:
        log.warning(f'GET_OR_CREATE: invalid/missing parameters, {e!r}')
        raise InvalidAPIParameters
    resp = {}
    try:
        access_key = request['keypair']['access_key']
        concurrency_limit = request['keypair']['concurrency_limit']
        async with request.app['dbpool'].acquire() as conn, conn.begin():
            query = (sa.select([keypairs.c.concurrency_used], for_update=True)
                       .select_from(keypairs)
                       .where(keypairs.c.access_key == access_key))
            concurrency_used = await conn.scalar(query)
            log.debug(f'access_key: {access_key} '
                      f'({concurrency_used} / {concurrency_limit})')
            if concurrency_used >= concurrency_limit:
                raise QuotaExceeded
            # Default to "no custom limits/mounts" so an API version outside
            # the branches below cannot leave these names unbound.
            limits = None
            mounts = None
            if request['api_version'] == 1:
                limits = params.get('resourceLimits', None)
            elif request['api_version'] in (2, 3):
                limits = params.get('limits', None)
                mounts = params.get('mounts', None)
            kernel, created = await request.app['registry'].get_or_create_kernel(
                params['clientSessionToken'],
                params['lang'], access_key,
                limits, mounts,
                conn=conn)
            resp['kernelId'] = str(kernel['sess_id'])
            if created:
                query = (sa.update(keypairs)
                           .values(concurrency_used=keypairs.c.concurrency_used + 1)
                           .where(keypairs.c.access_key == access_key))
                await conn.execute(query)
    except BackendError:
        log.exception('GET_OR_CREATE: exception')
        raise
    return web.json_response(resp, status=201, dumps=json.dumps)
Ejemplo n.º 11
0
async def watcher_agent_stop(request: web.Request,
                             params: Any) -> web.Response:
    """POST to the ``agent/stop`` endpoint of an agent's watcher.

    Logs the requesting user's UUID and access key, then looks up the
    watcher through ``get_watcher_info``.  A 200 reply is relayed as JSON;
    any other status is relayed as plain text with the same status code.
    The outbound call is bounded by a 20-second timeout.
    """
    access_key = request['keypair']['access_key']
    user_uuid = request['user']['uuid']
    log.info('WATCHER_AGENT_STOP (u:[{0}], ak:[{1}])', user_uuid, access_key)
    info = await get_watcher_info(request, params['agent_id'])
    tcp = aiohttp.TCPConnector()
    async with aiohttp.ClientSession(connector=tcp) as sess:
        with _timeout(20.0):
            endpoint = info['addr'] / 'agent/stop'
            auth = {'X-BackendAI-Watcher-Token': info['token']}
            async with sess.post(endpoint, headers=auth) as resp:
                if resp.status != 200:
                    body = await resp.text()
                    return web.Response(text=body, status=resp.status)
                payload = await resp.json()
                return web.json_response(payload, status=resp.status)
Ejemplo n.º 12
0
async def create(request):
    """Get or create a kernel session for the requesting keypair.

    The JSON body is read with a 2-second timeout and must carry ``lang``
    and a ``clientSessionToken`` of 8 to 80 characters.  Inside a database
    transaction the keypair's ``concurrency_used`` is selected with
    ``for_update=True`` and checked against its limit, the registry is asked
    to get or create the kernel, and the counter is incremented only when a
    kernel was newly created.  Responds with status 201 and ``kernelId``.

    Raises:
        InvalidAPIParameters: the body is unreadable, malformed, or lacks
            the required fields.
        QuotaExceeded: the keypair already uses its full concurrency limit.
        Backend.AiError: logged and re-raised unchanged.
    """
    try:
        with _timeout(2):
            params = await request.json()
        log.info(f"GET_OR_CREATE (u:{request['keypair']['access_key']}, "
                 f"lang:{params['lang']}, token:{params['clientSessionToken']})")
        assert 8 <= len(params['clientSessionToken']) <= 80
    except (asyncio.TimeoutError, AssertionError,
            KeyError, json.decoder.JSONDecodeError) as e:
        log.warning(f'GET_OR_CREATE: invalid/missing parameters, {e!r}')
        raise InvalidAPIParameters
    resp = {}
    try:
        access_key = request['keypair']['access_key']
        concurrency_limit = request['keypair']['concurrency_limit']
        async with request.app.dbpool.acquire() as conn, conn.transaction():
            query = (sa.select([KeyPair.c.concurrency_used], for_update=True)
                       .select_from(KeyPair)
                       .where(KeyPair.c.access_key == access_key))
            concurrency_used = await conn.fetchval(query)
            log.debug(f'access_key: {access_key} ({concurrency_used} / {concurrency_limit})')
            if concurrency_used >= concurrency_limit:
                raise QuotaExceeded
            # Default to "no custom limits/mounts" so an API version outside
            # the branches below cannot leave these names unbound.
            limits = None
            mounts = None
            if request['api_version'] == 1:
                limits = params.get('resourceLimits', None)
            elif request['api_version'] == 2:
                limits = params.get('limits', None)
                mounts = params.get('mounts', None)
            kernel, created = await request.app['registry'].get_or_create_kernel(
                params['clientSessionToken'],
                params['lang'], access_key,
                limits, mounts)
            resp['kernelId'] = kernel.id
            if created:
                query = (sa.update(KeyPair)
                           .values(concurrency_used=KeyPair.c.concurrency_used + 1)
                           .where(KeyPair.c.access_key == access_key))
                await conn.execute(query)
    except Backend.AiError:
        log.exception('GET_OR_CREATE: API Internal Error')
        raise
    return web.Response(status=201, content_type=_json_type,
                        text=json.dumps(resp))
Ejemplo n.º 13
0
async def RPCContext(addr, timeout=None):
    """Yield a cached aiozmq RPC client for ``addr``.

    This is an async generator that yields exactly once.  Clients are kept
    in the module-level ``agent_peers`` mapping and created on first use;
    this function never closes them.  The body using the yielded client runs
    inside ``_timeout(timeout)``.

    Exception translation for errors raised in the body:
      * ``GenericError`` becomes ``AgentError(args[0], args[1])``;
      * a ``TypeError`` whose message is exactly
        ``"'NoneType' object is not iterable"`` is logged as a warning and
        not re-raised; any other ``TypeError`` becomes ``AgentError``;
      * the exceptions in ``passthrough`` are re-raised unchanged;
      * any other ``Exception`` becomes ``AgentError(type, args)``.
    """
    passthrough = (
        NotFoundError,
        ParametersError,
        asyncio.TimeoutError,
        asyncio.CancelledError,
        asyncio.InvalidStateError,
    )
    global agent_peers
    peer = agent_peers.get(addr, None)
    if peer is None:
        peer = await aiozmq.rpc.connect_rpc(
            connect=addr,
            error_table={
                'concurrent.futures._base.TimeoutError': asyncio.TimeoutError,
            })
        peer.transport.setsockopt(zmq.LINGER, 1000)
        agent_peers[addr] = peer
    try:
        with _timeout(timeout):
            yield peer
    except Exception as exc:
        tb = exc.__traceback__
        if isinstance(exc, GenericError):
            raise AgentError(exc.args[0], exc.args[1]).with_traceback(tb)
        if isinstance(exc, TypeError):
            if exc.args[0] == "'NoneType' object is not iterable":
                log.warning('The agent has cancelled the operation '
                            'or the kernel has terminated too quickly.')
                # In this case, you may need to use "--debug-skip-container-deletion"
                # CLI option in the agent and check out the container logs via
                # "docker logs" command to see what actually happened.
                return
            raise AgentError(type(exc), exc.args).with_traceback(tb)
        if isinstance(exc, passthrough):
            raise
        raise AgentError(type(exc), exc.args).with_traceback(tb)
Ejemplo n.º 14
0
async def execute(request):
    """Execute code in a session, or fetch completions in legacy mode.

    The JSON body is read with a 2-second timeout.  API version 1 requires
    ``code`` and always runs in ``'query'`` mode.  API versions 2 and 3
    require ``mode`` (``'query'``, ``'batch'`` or ``'complete'``) and accept
    optional ``code`` and ``options``.  ``runId`` is optional and defaults
    to a random 8-byte hex token.

    Raises:
        InvalidAPIParameters: the body is unreadable, a required field is
            missing, the mode is not allowed, the API version is not
            supported, or an ``AssertionError`` surfaced from the registry.
        BackendError: logged and re-raised unchanged.
    """
    resp = {}
    sess_id = request.match_info['sess_id']
    try:
        with _timeout(2):
            params = await request.json(loads=json.loads)
        log.info(f"EXECUTE(u:{request['keypair']['access_key']}, k:{sess_id})")
    except (asyncio.TimeoutError, json.decoder.JSONDecodeError):
        log.warning('EXECUTE: invalid/missing parameters')
        raise InvalidAPIParameters
    try:
        await request.app['registry'].increment_session_usage(sess_id)
        api_version = request['api_version']
        # Validation failures are raised explicitly (rather than via
        # ``assert`` or a bare subscript) so they survive ``python -O`` and
        # reach the AssertionError handler below instead of escaping as
        # KeyError/UnboundLocalError.
        if api_version == 1:
            if 'code' not in params:
                raise AssertionError("'code' is required")
            mode = 'query'
            code = params['code']
            opts = {}
        elif api_version in (2, 3):
            mode = params.get('mode')
            if mode not in ('query', 'batch', 'complete'):
                raise AssertionError(f'invalid or missing mode: {mode!r}')
            code = params.get('code', '')
            opts = params.get('options', None) or {}
        else:
            raise AssertionError(f'unsupported API version: {api_version!r}')
        run_id = params.get('runId', secrets.token_hex(8))
        if mode == 'complete':
            # For legacy
            resp['result'] = await request.app['registry'].get_completions(
                sess_id, 'query', code, opts)
        else:
            resp['result'] = await request.app['registry'].execute(
                sess_id, api_version, run_id, mode, code, opts)
    except AssertionError:
        log.warning('EXECUTE: invalid/missing parameters')
        raise InvalidAPIParameters
    except BackendError:
        log.exception('EXECUTE: exception')
        raise
    return web.json_response(resp, status=200, dumps=json.dumps)
Ejemplo n.º 15
0
    async def _query(self, path, method='GET', params=None, timeout=None,
                     data=None, headers=None, **kwargs):
        '''
        Get the response object by performing the HTTP request.
        The caller is responsible to finalize the response object.

        The request itself runs inside ``_timeout(timeout)``; a timeout
        propagates as ``asyncio.TimeoutError``.

        For a 4xx/5xx status the body is read, the response is closed and
        ``DockerError(status, payload)`` is raised.  ``payload`` is the
        decoded JSON body when it parses, otherwise ``{'message': <text>}``.
        '''
        url = self._endpoint(path)
        with _timeout(timeout):
            response = await self.session.request(
                method, url,
                params=httpize(params), headers=headers,
                data=data, **kwargs)

        if (response.status // 100) in [4, 5]:
            try:
                what = await response.read()
            finally:
                # Release the connection even if reading the body fails.
                response.close()
            try:
                payload = json.loads(what.decode('utf8'))
            except ValueError:
                # Covers both invalid JSON and invalid UTF-8.  Without this
                # fallback a plain-text error body (e.g. from a proxy) would
                # surface as a decode error and hide the HTTP failure.
                # NOTE(review): assumes DockerError accepts a mapping with a
                # 'message' key, the shape of Docker Engine API error
                # bodies - confirm against DockerError.
                payload = {'message': what.decode('utf8', errors='replace')}
            raise DockerError(response.status, payload)

        return response
Ejemplo n.º 16
0
async def complete(request):
    """Handle a code-completion request for a session.

    Reads the JSON request body (bounded by a 2-second timeout), bumps the
    session usage counter and returns the registry's completion result under
    ``result`` in a JSON response with status 200.

    Raises:
        InvalidAPIParameters: the body could not be read or decoded in time,
            or an ``AssertionError`` surfaced while talking to the registry.
        BackendError: logged and re-raised unchanged.
    """
    reply = {}
    sess_id = request.match_info['sess_id']

    try:
        with _timeout(2):
            body = await request.json(loads=json.loads)
        log.info(f"COMPLETE(u:{request['keypair']['access_key']}, k:{sess_id})")
    except (asyncio.TimeoutError, json.decoder.JSONDecodeError):
        log.warning('COMPLETE: invalid/missing parameters')
        raise InvalidAPIParameters

    try:
        source = body.get('code', '')
        # A null or missing "options" value is normalised to an empty dict.
        options = body.get('options', None) or {}
        registry = request.app['registry']
        await registry.increment_session_usage(sess_id)
        reply['result'] = await registry.get_completions(
            sess_id, 'query', source, options)
    except AssertionError:
        log.warning('COMPLETE: invalid/missing parameters')
        raise InvalidAPIParameters
    except BackendError:
        log.exception('COMPLETE: exception')
        raise
    return web.json_response(reply, status=200, dumps=json.dumps)
Ejemplo n.º 17
0
 async def inner_timeout():
     """Await the enclosing scope's ``coroutine`` inside ``_timeout(timeout)``.

     Returns the coroutine's result, or ``None`` if the timeout context
     manager suppresses an exception raised while awaiting.
     """
     outcome = None
     async with _timeout(timeout):
         outcome = await coroutine
     return outcome
Ejemplo n.º 18
0
async def create(request):
    """Get or create a kernel session for the requesting keypair.

    The JSON body is read with a 2-second timeout and must carry ``lang``
    and a ``clientSessionToken`` of 8 to 64 characters.  Inside a database
    transaction the keypair's ``concurrency_used`` is selected with
    ``for_update=True`` and checked against its limit, requested vfolder
    mounts are resolved to ``[name, host, id-hex]`` triples, the registry is
    asked to get or create the kernel, and the counter is incremented only
    when a kernel was newly created.  Responds with status 201 and
    ``kernelId``.

    Raises:
        InvalidAPIParameters: the body is unreadable, malformed, or lacks
            the required fields.
        QuotaExceeded: the keypair already uses its full concurrency limit.
        FolderNotFound: a requested mount names no vfolder of this keypair.
        BackendError: logged and re-raised unchanged.
    """
    try:
        with _timeout(2):
            params = await request.json(loads=json.loads)
        log.info(
            f"GET_OR_CREATE (u:{request['keypair']['access_key']}, "
            f"lang:{params['lang']}, token:{params['clientSessionToken']})")
        assert 8 <= len(params['clientSessionToken']) <= 64, \
               'clientSessionToken is too short or long (8 to 64 bytes required)!'
    except (asyncio.TimeoutError, AssertionError, KeyError,
            json.decoder.JSONDecodeError) as e:
        log.warning(f'GET_OR_CREATE: invalid/missing parameters, {e!r}')
        raise InvalidAPIParameters
    resp = {}
    try:
        access_key = request['keypair']['access_key']
        concurrency_limit = request['keypair']['concurrency_limit']
        async with request.app['dbpool'].acquire() as conn, conn.begin():
            query = (sa.select([keypairs.c.concurrency_used],
                               for_update=True).select_from(keypairs).where(
                                   keypairs.c.access_key == access_key))
            concurrency_used = await conn.scalar(query)
            log.debug(f'access_key: {access_key} '
                      f'({concurrency_used} / {concurrency_limit})')
            if concurrency_used >= concurrency_limit:
                raise QuotaExceeded
            creation_config = {
                'mounts': None,
                'environ': None,
                'clusterSize': None,
                'instanceMemory': None,
                'instanceCores': None,
                'instanceGPUs': None,
            }
            if request['api_version'] == 1:
                # custom resource limit unsupported
                pass
            elif request['api_version'] in (2, 3):
                # ``or {}`` also covers an explicit ``"config": null``, for
                # which ``.get('config', {})`` returns None and
                # ``dict.update(None)`` raises TypeError.
                creation_config.update(params.get('config') or {})
            # sanity check for vfolders
            if creation_config['mounts']:
                mount_details = []
                for mount in creation_config['mounts']:
                    query = (sa.select('*').select_from(
                        vfolders).where((vfolders.c.belongs_to == access_key)
                                        & (vfolders.c.name == mount)))
                    result = await conn.execute(query)
                    row = await result.first()
                    if row is None:
                        raise FolderNotFound
                    else:
                        mount_details.append([row.name, row.host, row.id.hex])
                creation_config['mounts'] = mount_details
            kernel, created = await request.app[
                'registry'].get_or_create_kernel(params['clientSessionToken'],
                                                 params['lang'],
                                                 access_key,
                                                 creation_config,
                                                 conn=conn)
            resp['kernelId'] = str(kernel['sess_id'])
            if created:
                query = (sa.update(keypairs).values(
                    concurrency_used=keypairs.c.concurrency_used +
                    1).where(keypairs.c.access_key == access_key))
                await conn.execute(query)
    except BackendError:
        log.exception('GET_OR_CREATE: exception')
        raise
    return web.json_response(resp, status=201, dumps=json.dumps)