Esempio n. 1
0
    # The hook handlers should accept a tuple of the user email,
    # the new user's UUID, and a dict with initial user's preferences.
    initial_user_prefs = {
        'lang':
        request.headers.get('Accept-Language', 'en-us').split(',')[0].lower(),
    }
    await request.app['hook_plugin_ctx'].notify(
        'POST_SIGNUP', (params['email'], user.uuid, initial_user_prefs))
    return web.json_response(resp_data, status=201)


@atomic
@auth_required
@check_api_params(
    t.Dict({
        tx.AliasedKey(['email', 'username']): t.String,
        t.Key('password'): t.String,
    }))
async def signout(request: web.Request, params: Any) -> web.Response:
    domain_name = request['user']['domain_name']
    log.info('AUTH.SIGNOUT(d:{}, email:{})', domain_name, params['email'])
    dbpool = request.app['dbpool']
    if request['user']['email'] != params['email']:
        raise GenericForbidden('Not the account owner')
    result = await check_credential(dbpool, domain_name, params['email'],
                                    params['password'])
    if result is None:
        raise GenericBadRequest('Invalid email and/or password')
    async with dbpool.acquire() as conn, conn.begin():
        # Inactivate the user.
        query = (users.update().values(is_active=False).where(
Esempio n. 2
0
from .utils import check_api_params
from ..manager.models.base import DataLoaderManager
from ..manager.models.gql import Mutations, Queries

# Module-level logger for the admin (GraphQL) gateway handlers, wrapped in
# BraceStyleAdapter so log calls format their messages with '{}' placeholders.
log = BraceStyleAdapter(
    logging.getLogger('ai.backend.gateway.admin'),
)


@atomic
@auth_required
@check_api_params(
    t.Dict({
        t.Key('query'):
        t.String,
        t.Key('variables', default=None):
        t.Null | t.Mapping(t.String, t.Any),
        tx.AliasedKey(['operation_name', 'operationName'], default=None):
        t.Null | t.String,
    }))
async def handle_gql(request: web.Request, params: Any) -> web.Response:
    executor = request.app['admin.gql_executor']
    schema = request.app['admin.gql_schema']
    manager_status = await request.app['config_server'].get_manager_status()
    known_slot_types = await request.app['config_server'].get_resource_slots()
    context = {
        'config': request.app['config'],
        'config_server': request.app['config_server'],
        'etcd': request.app['config_server'].etcd,
        'user': request['user'],
        'access_key': request['keypair']['access_key'],
        'dbpool': request.app['dbpool'],
        'redis_stat': request.app['redis_stat'],
Esempio n. 3
0
                                                msg['agent_id'], msg['args'])

        while True:
            try:
                await redis.execute_with_retries(lambda: _subscribe_impl())
            except asyncio.CancelledError:
                break
            except Exception:
                log.exception('EventDispatcher.subscribe(): unexpected-error')


@server_status_required(READ_ALLOWED)
@auth_required
@check_api_params(
    t.Dict({
        tx.AliasedKey(['name', 'sessionName'], default='*') >> 'session_name':
        t.String,
        t.Key('ownerAccessKey', default=None) >> 'owner_access_key':
        t.Null | t.String,
        t.Key('sessionId', default=None) >> 'session_id':
        t.Null | tx.UUID,
        # NOTE: if set, sessionId overrides sessionName and ownerAccessKey parameters.
        tx.AliasedKey(['group', 'groupName'], default='*') >> 'group_name':
        t.String,
        t.Key('scope', default='*'):
        t.Enum('*', 'session', 'kernel'),
    }))
@adefer
async def push_session_events(
    defer,
    request: web.Request,
Esempio n. 4
0
        if not ws.closed:
            await ws.send_json({
                'status': 'server-restarting',
                'msg': 'The API server is going to restart for maintenance. '
                       'Please connect again with the same run ID.',
            })
        raise
    finally:
        return ws


@server_status_required(READ_ALLOWED)
@auth_required
@check_api_params(
    t.Dict({
        tx.AliasedKey(['app', 'service']): t.String,
        # The port argument is only required to use secondary ports
        # when the target app listens multiple TCP ports.
        # Otherwise it should be omitted or set to the same value of
        # the actual port number used by the app.
        tx.AliasedKey(['port'], default=None): t.Null | t.Int[1024:65535],
        tx.AliasedKey(['envs'], default=None): t.Null | t.String,  # stringified JSON
                                                                   # e.g., '{"PASSWORD": "******"}'
        tx.AliasedKey(['arguments'], default=None): t.Null | t.String,  # stringified JSON
                                                                        # e.g., '{"-P": "12345"}'
                                                                        # The value can be one of:
                                                                        # None, str, List[str]
    }))
@adefer
async def stream_proxy(defer, request: web.Request, params: Mapping[str, Any]) -> web.StreamResponse:
    registry: AgentRegistry = request.app['registry']
from ai.backend.common.logging import BraceStyleAdapter

from .auth import auth_required
from .manager import (READ_ALLOWED, server_status_required)
from .utils import check_api_params
from ..manager.models import (
    query_allowed_sgroups, )

# Module-level logger for the scaling-group gateway handlers. Messages use
# brace-style placeholders, e.g. 'SGROUPS.LIST(u:{}, g:{}, d:{})'.
log = BraceStyleAdapter(
    logging.getLogger('ai.backend.gateway.scaling_group'),
)


@auth_required
@server_status_required(READ_ALLOWED)
@check_api_params(
    t.Dict({
        tx.AliasedKey(['group', 'group_id', 'group_name']):
        tx.UUID | t.String,
    }), )
async def list_available_sgroups(request: web.Request,
                                 params: Any) -> web.Response:
    dbpool = request.app['dbpool']
    access_key = request['keypair']['access_key']
    domain_name = request['user']['domain_name']
    group_id_or_name = params['group']
    log.info('SGROUPS.LIST(u:{}, g:{}, d:{})', access_key, group_id_or_name,
             domain_name)
    async with dbpool.acquire() as conn:
        sgroups = await query_allowed_sgroups(conn, domain_name,
                                              group_id_or_name, access_key)
        return web.json_response(
            {
Esempio n. 6
0
    groups,
    association_groups_users as agus,
    query_group_dotfiles,
    query_group_domain,
    verify_dotfile_name,
    MAXIMUM_DOTFILE_SIZE,
)

# Module-level logger for the dotfile gateway handlers. Messages use
# brace-style placeholders, e.g. 'CREATE DOTFILE (group: {0})'.
log = BraceStyleAdapter(
    logging.getLogger('ai.backend.gateway.dotfile'),
)


@server_status_required(READ_ALLOWED)
@admin_required
@check_api_params(
    t.Dict({
        tx.AliasedKey(['group', 'groupId', 'group_id']): tx.UUID | t.String,
        t.Key('domain', default=None): t.String | t.Null,
        t.Key('data'): t.String(max_length=MAXIMUM_DOTFILE_SIZE),
        t.Key('path'): t.String,
        t.Key('permission'): t.Regexp(r'^[0-7]{3}$', re.ASCII),
    }))
async def create(request: web.Request, params: Any) -> web.Response:
    log.info('CREATE DOTFILE (group: {0})', params['group'])

    dbpool = request.app['dbpool']
    group_id_or_name = params['group']
    async with dbpool.acquire() as conn, conn.begin():
        if isinstance(group_id_or_name, str):
            if params['domain'] is None:
                raise InvalidAPIParameters('Missing parameter \'domain\'')
Esempio n. 7
0
    watcher_port = await request.app['registry'].config_server.get(
        f'nodes/agents/{agent_id}/watcher_port')
    if watcher_port is None:
        watcher_port = 6009
    # TODO: watcher scheme is assumed to be http
    addr = yarl.URL(f'http://{agent_ip}:{watcher_port}')
    return {
        'addr': addr,
        'token': token,
    }


@server_status_required(READ_ALLOWED)
@superadmin_required
@check_api_params(t.Dict({
    tx.AliasedKey(['agent_id', 'agent']): t.String,
}))
async def get_watcher_status(request: web.Request,
                             params: Any) -> web.Response:
    log.info('GET_WATCHER_STATUS ()')
    watcher_info = await get_watcher_info(request, params['agent_id'])
    connector = aiohttp.TCPConnector()
    async with aiohttp.ClientSession(connector=connector) as sess:
        with _timeout(5.0):
            headers = {'X-BackendAI-Watcher-Token': watcher_info['token']}
            async with sess.get(watcher_info['addr'], headers=headers) as resp:
                if resp.status == 200:
                    data = await resp.json()
                    return web.json_response(data, status=resp.status)
                else:
                    data = await resp.text()
Esempio n. 8
0
                'status':
                'server-restarting',
                'msg':
                'The API server is going to restart for maintenance. '
                'Please connect again with the same run ID.',
            })
        raise
    finally:
        return ws


@server_status_required(READ_ALLOWED)
@auth_required
@check_api_params(
    t.Dict({
        tx.AliasedKey(['app', 'service']):
        t.String,
        # The port argument is only required to use secondary ports
        # when the target app listens multiple TCP ports.
        # Otherwise it should be omitted or set to the same value of
        # the actual port number used by the app.
        tx.AliasedKey(['port'], default=None):
        t.Null | t.Int[1024:65535],
        tx.AliasedKey(['envs'], default=None):
        t.Null | t.String,  # stringified JSON
        # e.g., '{"PASSWORD": "******"}'
        tx.AliasedKey(['arguments'], default=None):
        t.Null | t.String,  # stringified JSON
        # e.g., '{"-P": "12345"}'
        # The value can be one of:
        # None, str, List[str]
Esempio n. 9
0
    sa.Column('group_id', GUID, sa.ForeignKey('groups.id'), nullable=True),
    sa.Column('user_uuid',
              GUID,
              sa.ForeignKey('users.uuid'),
              index=True,
              nullable=False),
    sa.Column('type',
              EnumType(TemplateType),
              nullable=False,
              server_default='TASK',
              index=True),
    sa.Column('name', sa.String(length=128), nullable=True),
    sa.Column('template', pgsql.JSONB(), nullable=False))

task_template_v1 = t.Dict({
    tx.AliasedKey(['api_version', 'apiVersion']):
    t.String,
    t.Key('kind'):
    t.Enum('taskTemplate', 'task_template'),
    t.Key('metadata'):
    t.Dict({
        t.Key('name'): t.String,
        t.Key('tag', default=None): t.Null | t.String,
    }),
    t.Key('spec'):
    t.Dict({
        tx.AliasedKey(['type', 'sessionType'], default='interactive') >> 'session_type':
        tx.Enum(SessionTypes),
        t.Key('kernel'):
        t.Dict({
            t.Key('image'):
Esempio n. 10
0
from ai.backend.common import validators as tx

from .stats import StatModes
from .types import AgentBackend

# Default coredump settings for the agent. The consumer is outside this view;
# presumably these fill in missing keys of the agent's coredump config
# section -- confirm against agent_local_config_iv.
# The hyphenated keys cannot be spelled as keyword arguments, so they are
# supplied through an unpacked mapping. Insertion order is unchanged:
# enabled, path, backup-count, size-limit.
coredump_defaults = dict(
    enabled=False,
    path='./coredumps',
    **{
        'backup-count': 10,
        'size-limit': '64M',
    },
)

agent_local_config_iv = t.Dict({
    t.Key('agent'):
    t.Dict({
        tx.AliasedKey(['backend', 'mode']):
        tx.Enum(AgentBackend),
        t.Key('rpc-listen-addr', default=('', 6001)):
        tx.HostPortPair(allow_blank_host=True),
        t.Key('agent-sock-port', default=6007):
        t.Int[1024:65535],
        t.Key('id', default=None):
        t.Null | t.String,
        t.Key('region', default=None):
        t.Null | t.String,
        t.Key('instance-type', default=None):
        t.Null | t.String,
        t.Key('scaling-group', default='default'):
        t.String,
        t.Key('pid-file', default=os.devnull):
        tx.Path(type='file', allow_nonexisting=True, allow_devnull=True),
Esempio n. 11
0
from .auth import auth_required
from .exceptions import InvalidAPIParameters, TaskTemplateNotFound
from .manager import READ_ALLOWED, server_status_required
from .types import CORSOptions, Iterable, WebMiddleware
from .utils import check_api_params, get_access_key_scopes

from ..manager.models import (association_groups_users as agus, domains,
                              groups, session_templates, keypairs, users,
                              UserRole, query_accessible_session_templates,
                              TemplateType, verify_vfolder_name)

# Module-level logger for the session-template gateway handlers, wrapped in
# BraceStyleAdapter for '{}'-style message formatting.
log = BraceStyleAdapter(logging.getLogger(
    'ai.backend.gateway.session_template',
))

task_template_v1 = t.Dict({
    tx.AliasedKey(['api_version', 'apiVersion']):
    t.String,
    t.Key('kind'):
    t.Enum('taskTemplate', 'task_template'),
    t.Key('metadata'):
    t.Dict({
        t.Key('name'): t.String,
        t.Key('tag', default=None): t.Null | t.String,
    }),
    t.Key('spec'):
    t.Dict({
        tx.AliasedKey(['type', 'sessionType'], default='interactive') >> 'session_type':
        tx.Enum(SessionTypes),
        t.Key('kernel'):
        t.Dict({
            t.Key('image'):
Esempio n. 12
0
    t.Null | t.Int[1:],
    t.Key('instanceMemory', default=None):
    t.Null | tx.BinarySize,
    t.Key('instanceCores', default=None):
    t.Null | t.Int,
    t.Key('instanceGPUs', default=None):
    t.Null | t.Float,
    t.Key('instanceTPUs', default=None):
    t.Null | t.Int,
})
# Trafaret schema for version 3 of the creation ``config`` mapping.
# Every key is optional and defaults to None when omitted.
# Multi-word keys accept both their camelCase and snake_case spellings via
# tx.AliasedKey. tx.AliasedKey presumably emits the first listed alias as
# the validated key name -- confirm in ai.backend.common.validators.
creation_config_v3 = t.Dict({
    t.Key('mounts', default=None):
    t.Null | t.List(t.String),
    t.Key('environ', default=None):
    t.Null | t.Mapping(t.String, t.String),
    tx.AliasedKey(['clusterSize', 'cluster_size'], default=None):
    t.Null | t.Int[1:],
    tx.AliasedKey(['scalingGroup', 'scaling_group'], default=None):
    t.Null | t.String,
    t.Key('resources', default=None):
    t.Null | t.Mapping(t.String, t.Any),
    # Previously a plain t.Key, so a 'resourceOpts' spelling was rejected,
    # unlike the other multi-word keys above. 'resource_opts' stays first
    # so the validated output key is unchanged for existing consumers.
    tx.AliasedKey(['resource_opts', 'resourceOpts'], default=None):
    t.Null | t.Mapping(t.String, t.Any),
})


@server_status_required(ALL_ALLOWED)
@auth_required
@check_api_params(t.Dict({
    t.Key('clientSessionToken') >> 'sess_id':
    t.Regexp(r'^(?=.{4,64}$)\w[\w.-]*\w$', re.ASCII),
from ..manager.models import (
    association_groups_users as agus, domains,
    groups, session_templates, keypairs, users, UserRole,
    query_accessible_session_templates, TemplateType
)
from ..manager.models.session_template import check_cluster_template

# Module-level logger for the cluster-template gateway handlers. Messages use
# brace-style placeholders, e.g. 'CREATE (ak:{0}/{1})'.
log = BraceStyleAdapter(
    logging.getLogger('ai.backend.gateway.cluster_template'),
)


@server_status_required(READ_ALLOWED)
@auth_required
@check_api_params(t.Dict(
    {
        tx.AliasedKey(['group', 'groupName', 'group_name'], default='default'): t.String,
        tx.AliasedKey(['domain', 'domainName', 'domain_name'], default='default'): t.String,
        t.Key('owner_access_key', default=None): t.Null | t.String,
        t.Key('payload'): t.String
    }
))
async def create(request: web.Request, params: Any) -> web.Response:
    if params['domain'] is None:
        params['domain'] = request['user']['domain_name']
    requester_access_key, owner_access_key = await get_access_key_scopes(request, params)
    requester_uuid = request['user']['uuid']
    log.info('CREATE (ak:{0}/{1})',
             requester_access_key, owner_access_key if owner_access_key != requester_access_key else '*')
    user_uuid = request['user']['uuid']

    dbpool = request.app['dbpool']