# The hook handlers should accept a tuple of the user email, # the new user's UUID, and a dict with initial user's preferences. initial_user_prefs = { 'lang': request.headers.get('Accept-Language', 'en-us').split(',')[0].lower(), } await request.app['hook_plugin_ctx'].notify( 'POST_SIGNUP', (params['email'], user.uuid, initial_user_prefs)) return web.json_response(resp_data, status=201) @atomic @auth_required @check_api_params( t.Dict({ tx.AliasedKey(['email', 'username']): t.String, t.Key('password'): t.String, })) async def signout(request: web.Request, params: Any) -> web.Response: domain_name = request['user']['domain_name'] log.info('AUTH.SIGNOUT(d:{}, email:{})', domain_name, params['email']) dbpool = request.app['dbpool'] if request['user']['email'] != params['email']: raise GenericForbidden('Not the account owner') result = await check_credential(dbpool, domain_name, params['email'], params['password']) if result is None: raise GenericBadRequest('Invalid email and/or password') async with dbpool.acquire() as conn, conn.begin(): # Inactivate the user. query = (users.update().values(is_active=False).where(
from .utils import check_api_params from ..manager.models.base import DataLoaderManager from ..manager.models.gql import Mutations, Queries log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.admin')) @atomic @auth_required @check_api_params( t.Dict({ t.Key('query'): t.String, t.Key('variables', default=None): t.Null | t.Mapping(t.String, t.Any), tx.AliasedKey(['operation_name', 'operationName'], default=None): t.Null | t.String, })) async def handle_gql(request: web.Request, params: Any) -> web.Response: executor = request.app['admin.gql_executor'] schema = request.app['admin.gql_schema'] manager_status = await request.app['config_server'].get_manager_status() known_slot_types = await request.app['config_server'].get_resource_slots() context = { 'config': request.app['config'], 'config_server': request.app['config_server'], 'etcd': request.app['config_server'].etcd, 'user': request['user'], 'access_key': request['keypair']['access_key'], 'dbpool': request.app['dbpool'], 'redis_stat': request.app['redis_stat'],
msg['agent_id'], msg['args']) while True: try: await redis.execute_with_retries(lambda: _subscribe_impl()) except asyncio.CancelledError: break except Exception: log.exception('EventDispatcher.subscribe(): unexpected-error') @server_status_required(READ_ALLOWED) @auth_required @check_api_params( t.Dict({ tx.AliasedKey(['name', 'sessionName'], default='*') >> 'session_name': t.String, t.Key('ownerAccessKey', default=None) >> 'owner_access_key': t.Null | t.String, t.Key('sessionId', default=None) >> 'session_id': t.Null | tx.UUID, # NOTE: if set, sessionId overrides sessionName and ownerAccessKey parameters. tx.AliasedKey(['group', 'groupName'], default='*') >> 'group_name': t.String, t.Key('scope', default='*'): t.Enum('*', 'session', 'kernel'), })) @adefer async def push_session_events( defer, request: web.Request,
if not ws.closed: await ws.send_json({ 'status': 'server-restarting', 'msg': 'The API server is going to restart for maintenance. ' 'Please connect again with the same run ID.', }) raise finally: return ws @server_status_required(READ_ALLOWED) @auth_required @check_api_params( t.Dict({ tx.AliasedKey(['app', 'service']): t.String, # The port argument is only required to use secondary ports # when the target app listens multiple TCP ports. # Otherwise it should be omitted or set to the same value of # the actual port number used by the app. tx.AliasedKey(['port'], default=None): t.Null | t.Int[1024:65535], tx.AliasedKey(['envs'], default=None): t.Null | t.String, # stringified JSON # e.g., '{"PASSWORD": "******"}' tx.AliasedKey(['arguments'], default=None): t.Null | t.String, # stringified JSON # e.g., '{"-P": "12345"}' # The value can be one of: # None, str, List[str] })) @adefer async def stream_proxy(defer, request: web.Request, params: Mapping[str, Any]) -> web.StreamResponse: registry: AgentRegistry = request.app['registry']
from ai.backend.common.logging import BraceStyleAdapter from .auth import auth_required from .manager import (READ_ALLOWED, server_status_required) from .utils import check_api_params from ..manager.models import ( query_allowed_sgroups, ) log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.scaling_group')) @auth_required @server_status_required(READ_ALLOWED) @check_api_params( t.Dict({ tx.AliasedKey(['group', 'group_id', 'group_name']): tx.UUID | t.String, }), ) async def list_available_sgroups(request: web.Request, params: Any) -> web.Response: dbpool = request.app['dbpool'] access_key = request['keypair']['access_key'] domain_name = request['user']['domain_name'] group_id_or_name = params['group'] log.info('SGROUPS.LIST(u:{}, g:{}, d:{})', access_key, group_id_or_name, domain_name) async with dbpool.acquire() as conn: sgroups = await query_allowed_sgroups(conn, domain_name, group_id_or_name, access_key) return web.json_response( {
groups, association_groups_users as agus, query_group_dotfiles, query_group_domain, verify_dotfile_name, MAXIMUM_DOTFILE_SIZE, ) log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.dotfile')) @server_status_required(READ_ALLOWED) @admin_required @check_api_params( t.Dict({ tx.AliasedKey(['group', 'groupId', 'group_id']): tx.UUID | t.String, t.Key('domain', default=None): t.String | t.Null, t.Key('data'): t.String(max_length=MAXIMUM_DOTFILE_SIZE), t.Key('path'): t.String, t.Key('permission'): t.Regexp(r'^[0-7]{3}$', re.ASCII), })) async def create(request: web.Request, params: Any) -> web.Response: log.info('CREATE DOTFILE (group: {0})', params['group']) dbpool = request.app['dbpool'] group_id_or_name = params['group'] async with dbpool.acquire() as conn, conn.begin(): if isinstance(group_id_or_name, str): if params['domain'] is None: raise InvalidAPIParameters('Missing parameter \'domain\'')
watcher_port = await request.app['registry'].config_server.get( f'nodes/agents/{agent_id}/watcher_port') if watcher_port is None: watcher_port = 6009 # TODO: watcher scheme is assumed to be http addr = yarl.URL(f'http://{agent_ip}:{watcher_port}') return { 'addr': addr, 'token': token, } @server_status_required(READ_ALLOWED) @superadmin_required @check_api_params(t.Dict({ tx.AliasedKey(['agent_id', 'agent']): t.String, })) async def get_watcher_status(request: web.Request, params: Any) -> web.Response: log.info('GET_WATCHER_STATUS ()') watcher_info = await get_watcher_info(request, params['agent_id']) connector = aiohttp.TCPConnector() async with aiohttp.ClientSession(connector=connector) as sess: with _timeout(5.0): headers = {'X-BackendAI-Watcher-Token': watcher_info['token']} async with sess.get(watcher_info['addr'], headers=headers) as resp: if resp.status == 200: data = await resp.json() return web.json_response(data, status=resp.status) else: data = await resp.text()
'status': 'server-restarting', 'msg': 'The API server is going to restart for maintenance. ' 'Please connect again with the same run ID.', }) raise finally: return ws @server_status_required(READ_ALLOWED) @auth_required @check_api_params( t.Dict({ tx.AliasedKey(['app', 'service']): t.String, # The port argument is only required to use secondary ports # when the target app listens multiple TCP ports. # Otherwise it should be omitted or set to the same value of # the actual port number used by the app. tx.AliasedKey(['port'], default=None): t.Null | t.Int[1024:65535], tx.AliasedKey(['envs'], default=None): t.Null | t.String, # stringified JSON # e.g., '{"PASSWORD": "******"}' tx.AliasedKey(['arguments'], default=None): t.Null | t.String, # stringified JSON # e.g., '{"-P": "12345"}' # The value can be one of: # None, str, List[str]
sa.Column('group_id', GUID, sa.ForeignKey('groups.id'), nullable=True), sa.Column('user_uuid', GUID, sa.ForeignKey('users.uuid'), index=True, nullable=False), sa.Column('type', EnumType(TemplateType), nullable=False, server_default='TASK', index=True), sa.Column('name', sa.String(length=128), nullable=True), sa.Column('template', pgsql.JSONB(), nullable=False)) task_template_v1 = t.Dict({ tx.AliasedKey(['api_version', 'apiVersion']): t.String, t.Key('kind'): t.Enum('taskTemplate', 'task_template'), t.Key('metadata'): t.Dict({ t.Key('name'): t.String, t.Key('tag', default=None): t.Null | t.String, }), t.Key('spec'): t.Dict({ tx.AliasedKey(['type', 'sessionType'], default='interactive') >> 'session_type': tx.Enum(SessionTypes), t.Key('kernel'): t.Dict({ t.Key('image'):
from ai.backend.common import validators as tx from .stats import StatModes from .types import AgentBackend coredump_defaults = { 'enabled': False, 'path': './coredumps', 'backup-count': 10, 'size-limit': '64M', } agent_local_config_iv = t.Dict({ t.Key('agent'): t.Dict({ tx.AliasedKey(['backend', 'mode']): tx.Enum(AgentBackend), t.Key('rpc-listen-addr', default=('', 6001)): tx.HostPortPair(allow_blank_host=True), t.Key('agent-sock-port', default=6007): t.Int[1024:65535], t.Key('id', default=None): t.Null | t.String, t.Key('region', default=None): t.Null | t.String, t.Key('instance-type', default=None): t.Null | t.String, t.Key('scaling-group', default='default'): t.String, t.Key('pid-file', default=os.devnull): tx.Path(type='file', allow_nonexisting=True, allow_devnull=True),
from .auth import auth_required from .exceptions import InvalidAPIParameters, TaskTemplateNotFound from .manager import READ_ALLOWED, server_status_required from .types import CORSOptions, Iterable, WebMiddleware from .utils import check_api_params, get_access_key_scopes from ..manager.models import (association_groups_users as agus, domains, groups, session_templates, keypairs, users, UserRole, query_accessible_session_templates, TemplateType, verify_vfolder_name) log = BraceStyleAdapter( logging.getLogger('ai.backend.gateway.session_template')) task_template_v1 = t.Dict({ tx.AliasedKey(['api_version', 'apiVersion']): t.String, t.Key('kind'): t.Enum('taskTemplate', 'task_template'), t.Key('metadata'): t.Dict({ t.Key('name'): t.String, t.Key('tag', default=None): t.Null | t.String, }), t.Key('spec'): t.Dict({ tx.AliasedKey(['type', 'sessionType'], default='interactive') >> 'session_type': tx.Enum(SessionTypes), t.Key('kernel'): t.Dict({ t.Key('image'):
t.Null | t.Int[1:], t.Key('instanceMemory', default=None): t.Null | tx.BinarySize, t.Key('instanceCores', default=None): t.Null | t.Int, t.Key('instanceGPUs', default=None): t.Null | t.Float, t.Key('instanceTPUs', default=None): t.Null | t.Int, }) creation_config_v3 = t.Dict({ t.Key('mounts', default=None): t.Null | t.List(t.String), t.Key('environ', default=None): t.Null | t.Mapping(t.String, t.String), tx.AliasedKey(['clusterSize', 'cluster_size'], default=None): t.Null | t.Int[1:], tx.AliasedKey(['scalingGroup', 'scaling_group'], default=None): t.Null | t.String, t.Key('resources', default=None): t.Null | t.Mapping(t.String, t.Any), t.Key('resource_opts', default=None): t.Null | t.Mapping(t.String, t.Any), }) @server_status_required(ALL_ALLOWED) @auth_required @check_api_params(t.Dict({ t.Key('clientSessionToken') >> 'sess_id': t.Regexp(r'^(?=.{4,64}$)\w[\w.-]*\w$', re.ASCII),
from ..manager.models import ( association_groups_users as agus, domains, groups, session_templates, keypairs, users, UserRole, query_accessible_session_templates, TemplateType ) from ..manager.models.session_template import check_cluster_template log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.cluster_template')) @server_status_required(READ_ALLOWED) @auth_required @check_api_params(t.Dict( { tx.AliasedKey(['group', 'groupName', 'group_name'], default='default'): t.String, tx.AliasedKey(['domain', 'domainName', 'domain_name'], default='default'): t.String, t.Key('owner_access_key', default=None): t.Null | t.String, t.Key('payload'): t.String } )) async def create(request: web.Request, params: Any) -> web.Response: if params['domain'] is None: params['domain'] = request['user']['domain_name'] requester_access_key, owner_access_key = await get_access_key_scopes(request, params) requester_uuid = request['user']['uuid'] log.info('CREATE (ak:{0}/{1})', requester_access_key, owner_access_key if owner_access_key != requester_access_key else '*') user_uuid = request['user']['uuid'] dbpool = request.app['dbpool']