from ai.backend.common.types import (
    aobject,
    AgentId,
)
from .auth import auth_required
from .defs import REDIS_STREAM_DB
from .exceptions import GenericNotFound, GenericForbidden, GroupNotFound
from .manager import READ_ALLOWED, server_status_required
from .utils import check_api_params
from ..manager.models import kernels, groups, UserRole
from ..manager.types import BackgroundTaskEventArgs, Sentinel
if TYPE_CHECKING:
    from .types import CORSOptions, WebMiddleware
    from ..gateway.config import LocalConfig, SharedConfig

log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.events'))

sentinel: Final = Sentinel.token


class EventCallback(Protocol):
    """
    Structural type for event callbacks.

    A conforming object is an awaitable callable that takes a context
    object, the agent ID, the event name and then any number of extra
    positional arguments, and produces no result.
    """

    async def __call__(
        self,
        context: Any,
        agent_id: AgentId,
        event_name: str,
        *args,
    ) -> None:
        ...


@attr.s(auto_attribs=True, slots=True, frozen=True, eq=False, order=False)
class EventHandler:
    """
    Pairs an event callback with the context object stored alongside it.

    The attrs options make instances slotted and frozen, while ``eq=False``
    and ``order=False`` turn off the generated comparison methods.
    """
    # Opaque value kept next to the callback; presumably handed back to it
    # as its ``context`` argument -- confirm against the dispatcher.
    context: Any
    # Coroutine callable matching the EventCallback protocol.
    callback: EventCallback
Exemple #2
0
import aiofiles
import aiojobs.aiohttp
from aiohttp import web
import aiotools
import click
from setproctitle import setproctitle
import trafaret as t

from ai.backend.common import config, utils, validators as tx
from ai.backend.common.etcd import AsyncEtcd, ConfigScopes
from ai.backend.common.logging import Logger, BraceStyleAdapter
from ai.backend.common.utils import Fstab
from . import __version__ as VERSION

log = BraceStyleAdapter(logging.getLogger('ai.backend.agent.watcher'))

shutdown_enabled = False


@web.middleware
async def auth_middleware(request, handler):
    token = request.headers.get('X-BackendAI-Watcher-Token', None)
    if token == request.app['token']:
        try:
            return (await handler(request))
        except FileNotFoundError as e:
            log.info(repr(e))
            message = 'Agent is not loaded with systemctl.'
            return web.json_response({'message': message}, status=200)
        except Exception as e:
Exemple #3
0
from aiohttp import web
import aiohttp_cors
import trafaret as t

from ai.backend.common import validators as tx
from ai.backend.common.logging import BraceStyleAdapter

from .auth import auth_required
from .manager import (READ_ALLOWED, server_status_required)
from .types import CORSOptions, WebMiddleware
from .utils import check_api_params
from ..manager.models import (
    query_allowed_sgroups, )

log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.scaling_group'))


@auth_required
@server_status_required(READ_ALLOWED)
@check_api_params(
    t.Dict({
        tx.AliasedKey(['group', 'group_id', 'group_name']):
        tx.UUID | t.String,
    }), )
async def list_available_sgroups(request: web.Request,
                                 params: Any) -> web.Response:
    dbpool = request.app['dbpool']
    access_key = request['keypair']['access_key']
    domain_name = request['user']['domain_name']
    group_id_or_name = params['group']
from ..models import (
    domains,
    groups,
    kernels,
    keypairs,
    keypair_resource_policies,
    query_allowed_sgroups,
    DefaultForUnspecified,
)
from . import (
    SchedulingContext,
    PendingSession,
    PredicateResult,
)

log = BraceStyleAdapter(logging.getLogger('ai.backend.manager.scheduler'))


async def check_reserved_batch_session(
    db_conn: SAConnection,
    sched_ctx: SchedulingContext,
    sess_ctx: PendingSession,
) -> PredicateResult:
    """
    Check if a batch-type session should not be started for a certain amount of time.
    """
    if sess_ctx.session_type == SessionTypes.BATCH:
        query = (sa.select([
            kernels.c.starts_at
        ]).select_from(kernels).where(kernels.c.id == sess_ctx.session_id))
        starts_at = await db_conn.scalar(query)
from .auth import auth_required
from .exceptions import InvalidAPIParameters, TaskTemplateNotFound
from .manager import READ_ALLOWED, server_status_required
from .types import CORSOptions, Iterable, WebMiddleware
from .utils import check_api_params, get_access_key_scopes


from ..manager.models import (
    association_groups_users as agus, domains,
    groups, session_templates, keypairs, users, UserRole,
    query_accessible_session_templates, TemplateType
)
from ..manager.models.session_template import check_cluster_template

log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.cluster_template'))


@server_status_required(READ_ALLOWED)
@auth_required
@check_api_params(t.Dict(
    {
        tx.AliasedKey(['group', 'groupName', 'group_name'], default='default'): t.String,
        tx.AliasedKey(['domain', 'domainName', 'domain_name'], default='default'): t.String,
        t.Key('owner_access_key', default=None): t.Null | t.String,
        t.Key('payload'): t.String
    }
))
async def create(request: web.Request, params: Any) -> web.Response:
    if params['domain'] is None:
        params['domain'] = request['user']['domain_name']
Exemple #6
0
import logging
from typing import FrozenSet
import sqlalchemy as sa

from aiohttp import web
import aiohttp_cors
from aiojobs.aiohttp import atomic

from ai.backend.common.logging import BraceStyleAdapter

from . import ManagerStatus
from .auth import admin_required
from .exceptions import InvalidAPIParameters, ServerFrozen, ServiceUnavailable
from ..manager.models import kernels, KernelStatus

log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.manager'))


def server_status_required(allowed_status: FrozenSet[ManagerStatus]):
    """
    Guard a request handler so that it only runs in certain manager statuses.

    The wrapper defined below reads the current manager status from
    ``request.app['config_server']`` and calls the handler only when that
    status is a member of *allowed_status*.

    :param allowed_status: the statuses in which the handler may run.
    :raises ServerFrozen: from the wrapper, when the status is not allowed
        and equals ``ManagerStatus.FROZEN``.
    :raises ServiceUnavailable: from the wrapper, for any other status that
        is not allowed.

    NOTE(review): this excerpt ends right after the inner ``decorator`` is
    defined; the rest of the outer function is not visible here.
    """
    def decorator(handler):
        @functools.wraps(handler)
        async def wrapped(request, *args, **kwargs):
            # The status is fetched on every call to the wrapped handler.
            status = await request.app['config_server'].get_manager_status()
            if status not in allowed_status:
                # A frozen manager gets its own exception type; every other
                # disallowed status is reported as "service unavailable".
                if status == ManagerStatus.FROZEN:
                    raise ServerFrozen
                msg = f'Server is not in the required status: {allowed_status}'
                raise ServiceUnavailable(msg)
            return (await handler(request, *args, **kwargs))

        return wrapped
from ai.backend.common.logging import BraceStyleAdapter

from .auth import auth_required
from .exceptions import (
    BackendError,
    AppNotFound,
    KernelNotFound,
    InvalidAPIParameters,
    InternalServerError,
)
from .manager import READ_ALLOWED, server_status_required
from .utils import not_impl_stub, call_non_bursty
from .wsproxy import TCPProxy
from ..manager.models import kernels

log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.stream'))


@server_status_required(READ_ALLOWED)
@auth_required
async def stream_pty(request) -> web.StreamResponse:
    app = request.app
    registry = app['registry']
    sess_id = request.match_info['sess_id']
    access_key = request['keypair']['access_key']
    stream_key = (sess_id, access_key)
    extra_fields = (kernels.c.stdin_port, kernels.c.stdout_port)
    api_version = request['api_version']
    try:
        kernel = await asyncio.shield(
            registry.get_session(sess_id, access_key, field=extra_fields))
Exemple #8
0
from ai.backend.common import validators as tx
from ai.backend.common.logging import BraceStyleAdapter
from ai.backend.common.types import SessionTypes

from .auth import auth_required
from .exceptions import InvalidAPIParameters, TaskTemplateNotFound
from .manager import READ_ALLOWED, server_status_required
from .types import CORSOptions, Iterable, WebMiddleware
from .utils import check_api_params, get_access_key_scopes

from ..manager.models import (association_groups_users as agus, domains,
                              groups, session_templates, keypairs, users,
                              UserRole, query_accessible_session_templates,
                              TemplateType, verify_vfolder_name)

log = BraceStyleAdapter(
    logging.getLogger('ai.backend.gateway.session_template'))

task_template_v1 = t.Dict({
    tx.AliasedKey(['api_version', 'apiVersion']):
    t.String,
    t.Key('kind'):
    t.Enum('taskTemplate', 'task_template'),
    t.Key('metadata'):
    t.Dict({
        t.Key('name'): t.String,
        t.Key('tag', default=None): t.Null | t.String,
    }),
    t.Key('spec'):
    t.Dict({
        tx.AliasedKey(['type', 'sessionType'], default='interactive') >> 'session_type':
        tx.Enum(SessionTypes),
Exemple #9
0
import graphene
import sqlalchemy as sa

from ai.backend.common.logging import BraceStyleAdapter
from ai.backend.common.types import ResourceSlot
from .base import (
    metadata,
    ResourceSlotColumn,
    privileged_mutation,
    simple_db_mutate,
    simple_db_mutate_returning_item,
    set_if_set,
)
from .user import UserRole

log = BraceStyleAdapter(logging.getLogger('ai.backend.manager.models'))

__all__: Sequence[str] = (
    'resource_presets',
    'ResourcePreset',
    'CreateResourcePreset',
    'ModifyResourcePreset',
    'DeleteResourcePreset',
)

# Table definition for resource presets: each row is keyed by its ``name``
# (a string of up to 256 characters) and carries a mandatory
# ``resource_slots`` value stored through ResourceSlotColumn.
resource_presets = sa.Table(
    'resource_presets',
    metadata,
    sa.Column('name', sa.String(length=256), primary_key=True),
    sa.Column('resource_slots', ResourceSlotColumn(), nullable=False),
)
Exemple #10
0
)
from ai.backend.common.exception import UnknownImageReference
from ai.backend.common.etcd import (
    quote as etcd_quote,
    unquote as etcd_unquote,
    ConfigScopes,
)

from .exceptions import ServerMisconfiguredError
from .manager import ManagerStatus
if TYPE_CHECKING:
    from ..manager.background import ProgressReporter
from ..manager.container_registry import get_container_registry
from ..manager.defs import INTRINSIC_SLOTS

log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.config'))

_max_cpu_count = os.cpu_count()
_file_perm = (Path(__file__).parent / 'server.py').stat()

DEFAULT_CHUNK_SIZE: Final = 256 * 1024  # 256 KiB
DEFAULT_INFLIGHT_CHUNKS: Final = 8

# Default values keyed by slash-separated configuration paths.
# NOTE(review): presumably consulted when the shared configuration has no
# value stored for a key -- the lookup code is not part of this excerpt.
shared_config_defaults = {
    'volumes/_mount': '/mnt',
    'volumes/_default_host': 'local',
    'volumes/_fsprefix': '/',
    'config/api/allow-origins': '*',
    'config/docker/image/auto_pull': 'digest',
}
Exemple #11
0
from ai.backend.common.logging import BraceStyleAdapter
from ai.backend.common.plugin import AbstractPlugin, BasePluginContext
from .exception import (
    InsufficientResource,
    InvalidResourceArgument,
    InvalidResourceCombination,
    NotMultipleOfQuantum,
)
from .stats import StatContext, NodeMeasurement, ContainerMeasurement
from .types import Container as SessionContainer

if TYPE_CHECKING:
    from aiofiles.threadpool.text import AsyncTextIOWrapper
    from io import TextIOWrapper

log = BraceStyleAdapter(logging.getLogger('ai.backend.agent.resources'))

known_slot_types: Mapping[SlotName, SlotTypes] = {}


class FractionAllocationStrategy(enum.Enum):
    """
    Strategies for fractional allocation.

    NOTE(review): the code that interprets these members is not part of
    this excerpt; going by the names, FILL and EVENLY select how a
    fractional amount is distributed -- confirm against the allocator.
    """
    FILL = 0
    EVENLY = 1


@attr.s(auto_attribs=True, slots=True)
class KernelResourceSpec:
    """
    This struct-like object stores the kernel resource allocation information
    with serialization and deserialization.
    get_access_key_scopes,
)
from .manager import ALL_ALLOWED, READ_ALLOWED, server_status_required
from ..manager.models import (
    domains,
    association_groups_users as agus,
    groups,
    keypairs,
    kernels,
    vfolders,
    AgentStatus,
    KernelStatus,
    query_accessible_vfolders,
)

log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.kernel'))

_json_loads = functools.partial(json.loads, parse_float=Decimal)

# Version 1 of the creation-config schema.  All three keys are optional and
# default to None:
#   mounts      -- null or a list of strings
#   environ     -- null or a string-to-string mapping
#   clusterSize -- null or an integer of at least 1
creation_config_v1 = t.Dict({
    t.Key('mounts', default=None):
    t.Null | t.List(t.String),
    t.Key('environ', default=None):
    t.Null | t.Mapping(t.String, t.String),
    t.Key('clusterSize', default=None):
    t.Null | t.Int[1:],
})
creation_config_v2 = t.Dict({
    t.Key('mounts', default=None):
    t.Null | t.List(t.String),
    t.Key('environ', default=None):
Exemple #13
0
from ai.backend.common.logging import BraceStyleAdapter
from ai.backend.common.types import aobject, AccessKey
import ai.backend.common.validators as tx
if TYPE_CHECKING:
    from ai.backend.common.types import AgentId, SessionId

from .distributed import GlobalTimer
from .models import kernels, keypairs, keypair_resource_policies
from .models.kernel import LIVE_STATUS
from ..gateway.defs import REDIS_LIVE_DB
if TYPE_CHECKING:
    from ..gateway.config import SharedConfig
    from ..gateway.events import EventDispatcher

log = BraceStyleAdapter(logging.getLogger('ai.backend.manager.idle'))


class AppStreamingStatus(enum.Enum):
    """
    Two-state flag telling whether there are active connections.

    NOTE(review): which connections are counted, and who reports this
    status, is not visible in this excerpt.
    """
    NO_ACTIVE_CONNECTIONS = 0
    HAS_ACTIVE_CONNECTIONS = 1


class BaseIdleChecker(aobject, metaclass=ABCMeta):

    name: ClassVar[str] = "base"

    def __init__(
        self,
        dbpool: SAPool,
        shared_config: SharedConfig,
Exemple #14
0
)
from .utils import (
    remove_exponent, )
if TYPE_CHECKING:
    from .agent import AbstractAgent

__all__ = (
    'StatContext',
    'StatModes',
    'MetricTypes',
    'NodeMeasurement',
    'ContainerMeasurement',
    'Measurement',
)

log = BraceStyleAdapter(logging.getLogger('ai.backend.agent.stats'))


def check_cgroup_available():
    """
    Tell whether cgroups can be used on this host.

    The answer is always negative when the process itself runs inside a
    container; otherwise it depends on the platform being Linux.
    """
    if is_containerized():
        return False
    return sys.platform.startswith('linux')


class StatModes(enum.Enum):
    CGROUP = 'cgroup'
    DOCKER = 'docker'

    @staticmethod
    def get_preferred_mode():
Exemple #15
0
    Iterable,
    Final,
    Tuple,
)

from aiohttp import web
from aiotools import apartial

from ai.backend.common import redis
from ai.backend.common.logging import BraceStyleAdapter

from .defs import REDIS_RLIM_DB
from .exceptions import RateLimitExceeded
from .types import CORSOptions, WebRequestHandler, WebMiddleware

log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.ratelimit'))

_time_prec: Final = Decimal('1e-3')  # msec
_rlim_window: Final = 60 * 15

# We implement rate limiting using a rolling counter, which prevents
# last-minute and first-minute bursts between the intervals.

_rlim_script = '''
local access_key = KEYS[1]
local now = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local request_id = tonumber(redis.call('INCR', '__request_id'))
if request_id >= 1e12 then
    redis.call('SET', '__request_id', 1)
end
from abc import ABCMeta, abstractmethod
import asyncio
import logging
from typing import (
    Optional,
    Union,
    Awaitable,
)

import aiohttp
from aiohttp import web

from ai.backend.common.logging import BraceStyleAdapter

log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.wsproxy'))


class ServiceProxy(metaclass=ABCMeta):
    '''
    The abstract base class to implement service proxy handlers.
    '''

    __slots__ = (
        'ws',
        'host',
        'port',
        'downstream_cb',
        'upstream_cb',
        'ping_cb',
    )
Exemple #17
0
import logging

import graphene
import sqlalchemy as sa

from ai.backend.common.logging import BraceStyleAdapter
from .base import BigInt, KVPair, ResourceLimit

log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.admin'))

__all__ = (
    'Image',
    'PreloadImage',
    'RescanImages',
    'AliasImage',
    'DealiasImage',
)


class Image(graphene.ObjectType):
    name = graphene.String()
    humanized_name = graphene.String()
    tag = graphene.String()
    registry = graphene.String()
    digest = graphene.String()
    labels = graphene.List(KVPair)
    aliases = graphene.List(graphene.String)
    size_bytes = BigInt()
    resource_limits = graphene.List(ResourceLimit)
    supported_accelerators = graphene.List(graphene.String)
    installed = graphene.Boolean()
import enum
import json
import logging
from contextlib import asynccontextmanager as actxmgr
from datetime import datetime
from datetime import timezone as tz
from typing import Any, Optional, Union

import trafaret as t
from aiohttp import web

from ai.backend.common.logging import BraceStyleAdapter

log = BraceStyleAdapter(logging.getLogger(__name__))


class CheckParamSource(enum.Enum):
    """
    Selects the source that ``check_params`` reads parameters from.

    BODY is the default value of that function's ``read_from`` argument.
    Presumably BODY means the request body and QUERY the URL query string
    -- confirm against the body of ``check_params``.
    """
    BODY = 0
    QUERY = 1


def fstime2datetime(t: Union[float, int]) -> datetime:
    """
    Convert a POSIX timestamp into a timezone-aware UTC datetime.

    :param t: seconds since the Unix epoch.  Inside this function the
        parameter hides any module-level name ``t``.
    :returns: an aware ``datetime`` whose tzinfo is UTC.
    """
    # Build the aware value in one step: datetime.utcfromtimestamp() is
    # deprecated since Python 3.12 and yields a naive object that then had
    # to be patched with replace(tzinfo=...).
    return datetime.fromtimestamp(t, tz=tz.utc)


@actxmgr
async def check_params(
    request: web.Request,
    checker: Optional[t.Trafaret],
    *,
    read_from: CheckParamSource = CheckParamSource.BODY,
)
from .manager import READ_ALLOWED, server_status_required
from .types import CORSOptions, Iterable, WebMiddleware
from .utils import check_api_params, get_access_key_scopes

from ..manager.models import (
    keypairs,
    vfolders,
    query_accessible_vfolders,
    query_bootstrap_script,
    query_owned_dotfiles,
    verify_dotfile_name,
    MAXIMUM_DOTFILE_SIZE,
)

log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.dotfile'))


@server_status_required(READ_ALLOWED)
@auth_required
@check_api_params(t.Dict(
    {
        t.Key('data'): t.String(max_length=MAXIMUM_DOTFILE_SIZE),
        t.Key('path'): t.String,
        t.Key('permission'): t.Regexp(r'^[0-7]{3}$', re.ASCII),
        t.Key('owner_access_key', default=None): t.Null | t.String,
    }
))
async def create(request: web.Request, params: Any) -> web.Response:
    requester_access_key, owner_access_key = await get_access_key_scopes(request, params)
    log.info('CREATE (ak:{0}/{1})',