from ai.backend.common.logging import BraceStyleAdapter
from ai.backend.common.plugin import AbstractPlugin, BasePluginContext

from .exception import (
    InsufficientResource,
    InvalidResourceArgument,
    InvalidResourceCombination,
    NotMultipleOfQuantum,
)
from .stats import StatContext, NodeMeasurement, ContainerMeasurement
from .types import Container as SessionContainer

if TYPE_CHECKING:
    # Imported for type annotations only; avoids the runtime import cost.
    from aiofiles.threadpool.text import AsyncTextIOWrapper
    from io import TextIOWrapper

log = BraceStyleAdapter(logging.getLogger('ai.backend.agent.resources'))

# Slot name -> slot value-type registry.  Empty here; presumably populated
# later by compute-device plugins -- TODO confirm against plugin loader.
known_slot_types: Mapping[SlotName, SlotTypes] = {}


class FractionAllocationStrategy(enum.Enum):
    """
    Strategy for distributing fractional resource requests over devices.

    NOTE(review): semantics inferred from the names -- FILL presumably packs
    one device before moving to the next, EVENLY presumably spreads the
    request across devices.  Confirm against the allocator implementation.
    """
    FILL = 0
    EVENLY = 1


@attr.s(auto_attribs=True, slots=True)
class KernelResourceSpec:
    """
    This struct-like object stores the kernel resource allocation information
    with serialization and deserialization.
from yarl import URL
import zmq

from ai.backend.common import msgpack
from ai.backend.common.logging import BraceStyleAdapter

from ..gateway.exceptions import (
    BackendError, InvalidAPIParameters,
    InstanceNotAvailable, InstanceNotFound,
    KernelNotFound, KernelAlreadyExists,
    KernelCreationFailed, KernelDestructionFailed,
    KernelExecutionFailed, KernelRestartFailed,
    AgentError)
from .models import (
    agents, kernels, keypairs,
    ResourceSlot, AgentStatus, KernelStatus)

__all__ = ['AgentRegistry', 'InstanceNotFound']

log = BraceStyleAdapter(logging.getLogger('ai.backend.manager.registry'))

# Cache of RPC peer connections keyed by agent address (see RPCContext).
agent_peers = {}


@aiotools.actxmgr
async def RPCContext(addr, timeout=None):
    """
    Async context manager providing an RPC peer for *addr*, reusing a
    cached connection from ``agent_peers`` when one exists.
    (Body continues beyond this excerpt.)
    """
    # NOTE(review): presumably these exception types are re-raised to the
    # caller as-is rather than wrapped -- confirm in the remainder of the body.
    preserved_exceptions = (
        NotFoundError,
        ParametersError,
        asyncio.TimeoutError,
        asyncio.CancelledError,
        asyncio.InvalidStateError,
    )
    global agent_peers
    peer = agent_peers.get(addr, None)
    Sequence,
    TYPE_CHECKING,
)

import graphene
import sqlalchemy as sa
from graphql.execution.executors.asyncio import AsyncioExecutor

from ai.backend.common.logging import BraceStyleAdapter

from .user import UserRole
from .base import BigInt, KVPair, ResourceLimit

if TYPE_CHECKING:
    # Type-checking-only import to avoid a runtime import cycle.
    from ..background import ProgressReporter

log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.admin'))

__all__ = (
    'Image',
    'PreloadImage',
    'RescanImages',
    'ForgetImage',
    'AliasImage',
    'DealiasImage',
)


class Image(graphene.ObjectType):
    # GraphQL object type describing a container image.
    # (More fields/resolvers continue beyond this excerpt.)
    name = graphene.String()
    humanized_name = graphene.String()
    tag = graphene.String()
import logging
import os
from setproctitle import setproctitle
import subprocess
import sys
from typing import Any, Mapping
from pathlib import Path

import attr
import click

from ai.backend.common.cli import LazyGroup
from ai.backend.common.logging import Logger, BraceStyleAdapter
from ai.backend.gateway.config import load as load_config

log = BraceStyleAdapter(logging.getLogger('ai.backend.manager.cli'))


@attr.s(auto_attribs=True, frozen=True)
class CLIContext:
    """Immutable bundle of objects shared by the manager CLI subcommands."""
    logger: Logger
    config: Mapping[str, Any]


# Root click command group for the manager CLI.
# (The decorator/option chain continues beyond this excerpt.)
@click.group(invoke_without_command=True, context_settings={'help_option_names': ['-h', '--help']})
@click.option(
    '-f', '--config-path', '--config', type=Path,
    InvalidAuthParameters, InvalidAPIParameters, RejectedByHook,
)
from ..manager.models import (
    keypairs, keypair_resource_policies, users,
)
from ..manager.models.user import UserRole, UserStatus, INACTIVE_USER_STATUSES, check_credential
from ..manager.models.keypair import generate_keypair as _gen_keypair, generate_ssh_keypair
from ..manager.models.group import association_groups_users, groups
from .types import CORSOptions, WebMiddleware
from .utils import check_api_params, set_handler_attr, get_handler_attr

log: Final = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.auth'))

# Timezone-abbreviation -> UTC offset in seconds.
# (The table continues beyond this excerpt.)
whois_timezone_info: Final = {
    "A": 1 * 3600,
    "ACDT": 10.5 * 3600,
    "ACST": 9.5 * 3600,
    "ACT": -5 * 3600,
    "ACWST": 8.75 * 3600,
    "ADT": 4 * 3600,
    "AEDT": 11 * 3600,
    "AEST": 10 * 3600,
    "AET": 10 * 3600,
    "AFT": 4.5 * 3600,
    "AKDT": -8 * 3600,
    "AKST": -9 * 3600,
    "ALMT": 6 * 3600,
import logging
import os
from pathlib import Path
from typing import Dict

from ai.backend.common.logging import BraceStyleAdapter

log = BraceStyleAdapter(logging.getLogger(__name__))

# the names of following AWS variables follow boto3 convention.
s3_access_key = os.environ.get('AWS_ACCESS_KEY_ID', 'dummy-access-key')
s3_secret_key = os.environ.get('AWS_SECRET_ACCESS_KEY', 'dummy-secret-key')
s3_region = os.environ.get('AWS_REGION', 'ap-northeast-1')
s3_bucket = os.environ.get('AWS_S3_BUCKET', 'codeonweb')
s3_bucket_path = os.environ.get('AWS_S3_BUCKET_PATH', 'bucket')

if s3_access_key == 'dummy-access-key':
    # No real credentials were supplied via the environment.
    log.info('Automatic ~/.output file S3 uploads is disabled.')


def relpath(path, base):
    """Return *path* relative to *base*, resolving both to absolute paths first."""
    return Path(path).resolve().relative_to(Path(base).resolve())


def scandir(root: Path, allowed_max_size: int):
    '''
    Scans a directory recursively and returns a dictionary of all files and
    their last modified time.
    '''
    file_stats: Dict[Path, float] = dict()
    if not isinstance(root, Path):
    resource_presets,
    domains, groups, kernels, users,
    AgentStatus,
    association_groups_users,
    query_allowed_sgroups,
    AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES,
    RESOURCE_USAGE_KERNEL_STATUSES,
    LIVE_STATUS,
)
from .types import CORSOptions, WebMiddleware
from .utils import check_api_params

log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.kernel'))

# Parse JSON numbers as Decimal to avoid binary-float precision loss.
_json_loads = functools.partial(json.loads, parse_float=Decimal)


@auth_required
@atomic
async def list_presets(request) -> web.Response:
    '''
    Returns the list of all resource presets.
    '''
    log.info('LIST_PRESETS (ak:{})', request['keypair']['access_key'])
    # Refresh the known resource slot types before building the response.
    await request.app['registry'].config_server.get_resource_slots()
    async with request.app['dbpool'].acquire() as conn, conn.begin():
        query = (sa.select([resource_presets]).select_from(resource_presets))
        # TODO: uncomment when we implement scaling group.
    InternalServerError,
    InvalidAPIParameters,
    SessionNotFound,
    TooManySessionsMatched,
)
from .manager import READ_ALLOWED, server_status_required
from .types import CORSOptions, WebMiddleware
from .utils import check_api_params, call_non_bursty
from .wsproxy import TCPProxy
from ..manager.defs import DEFAULT_ROLE
from ..manager.models import kernels

if TYPE_CHECKING:
    # Type-checking-only imports to avoid runtime import cycles.
    from .config import LocalConfig
    from ..manager.registry import AgentRegistry

log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.stream'))


@server_status_required(READ_ALLOWED)
@auth_required
@adefer
async def stream_pty(defer, request: web.Request) -> web.StreamResponse:
    """
    Websocket handler for a compute session's pseudo-terminal stream
    (inferred from the name).  (Body continues beyond this excerpt.)
    """
    app = request.app
    local_config = app['local_config']
    registry = app['registry']
    session_name = request.match_info['session_name']
    access_key = request['keypair']['access_key']
    # Also fetch the container's stdin/stdout port columns with the session.
    extra_fields = (kernels.c.stdin_port, kernels.c.stdout_port)
    api_version = request['api_version']
    try:
        compute_session = await asyncio.shield(
from ai.backend.common import validators as tx
from ai.backend.common.logging import BraceStyleAdapter
from ai.backend.common.types import SessionTypes

from .auth import auth_required
from .exceptions import InvalidAPIParameters, TaskTemplateNotFound
from .manager import READ_ALLOWED, server_status_required
from .types import CORSOptions, Iterable, WebMiddleware
from .utils import check_api_params, get_access_key_scopes

from ..manager.models import (association_groups_users as agus, domains,
                              groups, session_templates, keypairs,
                              users, UserRole,
                              query_accessible_session_templates,
                              TemplateType, verify_vfolder_name)

log = BraceStyleAdapter(
    logging.getLogger('ai.backend.gateway.session_template'))

# Trafaret validation schema for v1 task-template payloads.
# (The schema continues beyond this excerpt.)
task_template_v1 = t.Dict({
    tx.AliasedKey(['api_version', 'apiVersion']): t.String,
    t.Key('kind'): t.Enum('taskTemplate', 'task_template'),
    t.Key('metadata'): t.Dict({
        t.Key('name'): t.String,
        t.Key('tag', default=None): t.Null | t.String,
    }),
    t.Key('spec'): t.Dict({
        # camelCase and snake_case aliases are both accepted and renamed
        # to 'session_type'.
        tx.AliasedKey(['type', 'sessionType'], default='interactive') >> 'session_type':
            tx.Enum(SessionTypes),
    FrozenSet,
    Sequence,
    Tuple,
)

from aiodocker.docker import Docker, DockerVolume
from aiodocker.exceptions import DockerError
from aiotools import TaskGroup

from ai.backend.common.docker import ImageRef
from ai.backend.common.logging import BraceStyleAdapter
from ai.backend.common.utils import current_loop

from ..resources import KernelResourceSpec
from ..kernel import AbstractKernel, AbstractCodeRunner

log = BraceStyleAdapter(logging.getLogger(__name__))


class DockerKernel(AbstractKernel):
    """AbstractKernel implementation backed by a Docker container
    (class body continues beyond this excerpt)."""
    # FIXME: apply TypedDict to data in Python 3.8

    def __init__(
        self,
        kernel_id: str,
        image: ImageRef,
        version: int, *,
        agent_config: Mapping[str, Any],
        resource_spec: KernelResourceSpec,
        service_ports: Any,  # TODO: type-annotation
)

from .utils import (
    remove_exponent,
)

if TYPE_CHECKING:
    from .agent import AbstractAgent

__all__ = (
    'StatContext', 'StatModes', 'MetricTypes',
    'NodeMeasurement', 'ContainerMeasurement', 'Measurement',
)

log = BraceStyleAdapter(logging.getLogger('ai.backend.agent.stats'))


def check_cgroup_available():
    """
    Check if the host OS provides cgroups.
    """
    # cgroup access requires running on a Linux host directly,
    # i.e. not inside a container.
    return (not is_containerized() and sys.platform.startswith('linux'))


class StatModes(enum.Enum):
    # Source used to collect container statistics.
    CGROUP = 'cgroup'
    DOCKER = 'docker'

    @staticmethod
    def get_preferred_mode():
import aiofiles
import aiojobs.aiohttp
from aiohttp import web
import aiotools
import click
from setproctitle import setproctitle
import trafaret as t

from ai.backend.common import config, utils, validators as tx
from ai.backend.common.etcd import AsyncEtcd, ConfigScopes
from ai.backend.common.logging import Logger, BraceStyleAdapter
from ai.backend.common.utils import Fstab

from . import __version__ as VERSION

log = BraceStyleAdapter(logging.getLogger('ai.backend.agent.watcher'))

# NOTE(review): presumably gates whether the watcher is allowed to shut the
# agent down -- confirm where this flag is toggled.
shutdown_enabled = False


@web.middleware
async def auth_middleware(request, handler):
    """
    aiohttp middleware that only lets requests through when they carry the
    expected ``X-BackendAI-Watcher-Token`` header.
    (Body continues beyond this excerpt.)
    """
    token = request.headers.get('X-BackendAI-Watcher-Token', None)
    if token == request.app['token']:
        try:
            return (await handler(request))
        except FileNotFoundError as e:
            # systemctl unit files are missing; report gracefully with 200.
            log.info(repr(e))
            message = 'Agent is not loaded with systemctl.'
            return web.json_response({'message': message}, status=200)
        except Exception as e:
)
from ai.backend.common.exception import UnknownImageReference
from ai.backend.common.etcd import (
    quote as etcd_quote,
    unquote as etcd_unquote,
    ConfigScopes,
)

from .exceptions import ServerMisconfiguredError
from .manager import ManagerStatus

if TYPE_CHECKING:
    # Type-checking-only imports to avoid runtime import cycles.
    from ..manager.background import ProgressReporter
    from ..manager.container_registry import get_container_registry
    from ..manager.defs import INTRINSIC_SLOTS

log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.config'))

_max_cpu_count = os.cpu_count()
# NOTE(review): despite the name, this holds the full os.stat_result of
# server.py, not just its permission bits.
_file_perm = (Path(__file__).parent / 'server.py').stat()

DEFAULT_CHUNK_SIZE: Final = 256 * 1024  # 256 KiB
DEFAULT_INFLIGHT_CHUNKS: Final = 8

# Default values for shared-config keys; presumably used when the
# etcd-backed config has no explicit value -- TODO confirm.
shared_config_defaults = {
    'volumes/_mount': '/mnt',
    'volumes/_default_host': 'local',
    'volumes/_fsprefix': '/',
    'config/api/allow-origins': '*',
    'config/docker/image/auto_pull': 'digest',
}
from ai.backend.common.logging import BraceStyleAdapter
from ai.backend.common.types import (
    aobject,
    AgentId,
)
from ai.backend.common.utils import current_loop

from .auth import auth_required
from .defs import REDIS_STREAM_DB
from .exceptions import GenericNotFound, GenericForbidden, GroupNotFound
from .manager import READ_ALLOWED, server_status_required
from .types import CORSOptions, WebMiddleware
from .utils import check_api_params

from ..manager.models import kernels, groups, UserRole
from ..manager.types import BackgroundTaskEventArgs, Sentinel

log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.events'))

# NOTE(review): presumably used as an end-of-stream marker -- confirm usage.
sentinel: Final = Sentinel.token


class EventCallback(Protocol):
    """Structural (duck-typed) interface for event-handler callables."""

    async def __call__(self, context: Any, agent_id: AgentId,
                       event_name: str, *args) -> None:
        ...


@attr.s(auto_attribs=True, slots=True, frozen=True, eq=False, order=False)
class EventHandler:
    # Pairs a callback with the opaque context it is invoked with.
    # (More fields may follow beyond this excerpt.)
    context: Any
    callback: EventCallback
    Awaitable,
    Callable,
    Final,
    Optional,
    Literal,
    Union,
    Set,
)
import uuid

from aiojobs import Scheduler
import attr

from ai.backend.common import redis
from ai.backend.common.logging import BraceStyleAdapter

from ..gateway.events import EventDispatcher
from .types import BackgroundTaskEventArgs

log = BraceStyleAdapter(logging.getLogger('ai.backend.manager.background'))

MAX_BGTASK_ARCHIVE_PERIOD = 86400  # 24 hours

# Terminal result states of a background task.
TaskResult = Literal['task_done', 'task_cancelled', 'task_failed']


class ProgressReporter:
    """
    Tracks the current/total progress of a background task identified by
    ``task_id``; holds an event dispatcher, presumably to publish progress
    updates -- confirm in the methods beyond this excerpt.
    """
    event_dispatcher: Final[EventDispatcher]
    task_id: Final[uuid.UUID]
    total_progress: Union[int, float]
    current_progress: Union[int, float]

    def __init__(
        self,
        event_dispatcher: EventDispatcher,
from .predicates import (
    check_reserved_batch_session,
    check_concurrency,
    check_dependencies,
    check_keypair_resource_limit,
    check_group_resource_limit,
    check_domain_resource_limit,
    check_scaling_group,
)

__all__ = (
    'load_scheduler',
    'SchedulerDispatcher',
)

log = BraceStyleAdapter(logging.getLogger('ai.backend.manager.scheduler'))


def load_scheduler(name: str, scheduler_configs: Mapping[str, Any]) -> AbstractScheduler:
    """
    Instantiate the scheduler plugin registered under *name* in the
    ``backendai_scheduler_v10`` entrypoint group.

    :raises ImportError: if no entrypoint with the given name exists.
    """
    entry_prefix = 'backendai_scheduler_v10'
    for entrypoint in pkg_resources.iter_entry_points(entry_prefix):
        if entrypoint.name == name:
            log.debug('loading scheduler plugin "{}" from {}', name, entrypoint.module_name)
            scheduler_cls = entrypoint.load()
            # Pass the plugin-specific config section (empty dict if absent).
            scheduler_config = scheduler_configs.get(name, {})
            return scheduler_cls(scheduler_config)
    raise ImportError('Cannot load the scheduler plugin', name)


def merge_resource(src: MutableMapping[str, Any],
                   val: MutableMapping[str, Any]) -> None:
    for k in val.keys():
from typing import Any

from aiohttp import web
import aiohttp_cors
import trafaret as t

from ai.backend.common import validators as tx
from ai.backend.common.logging import BraceStyleAdapter

from .auth import auth_required
from .manager import (READ_ALLOWED, server_status_required)
from .utils import check_api_params
from ..manager.models import (
    query_allowed_sgroups,
)

log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.scaling_group'))


@auth_required
@server_status_required(READ_ALLOWED)
@check_api_params(
    t.Dict({
        # The group may be given as a UUID or as a name string.
        tx.AliasedKey(['group', 'group_id', 'group_name']): tx.UUID | t.String,
    }),
)
async def list_available_sgroups(request: web.Request, params: Any) -> web.Response:
    """
    List scaling groups allowed for the requester's domain and the given
    group.  (Body continues beyond this excerpt.)
    """
    dbpool = request.app['dbpool']
    access_key = request['keypair']['access_key']
    domain_name = request['user']['domain_name']
    group_id_or_name = params['group']
import asyncio
import logging
import os
from pathlib import Path

from ai.backend.common.logging import BraceStyleAdapter

import botocore, aiobotocore

log = BraceStyleAdapter(logging.getLogger('ai.backend.agent.files'))

# the names of following AWS variables follow boto3 convention.
s3_access_key = os.environ.get('AWS_ACCESS_KEY_ID', 'dummy-access-key')
s3_secret_key = os.environ.get('AWS_SECRET_ACCESS_KEY', 'dummy-secret-key')
s3_region = os.environ.get('AWS_REGION', 'ap-northeast-1')
s3_bucket = os.environ.get('AWS_S3_BUCKET', 'codeonweb')
s3_bucket_path = os.environ.get('AWS_S3_BUCKET_PATH', 'bucket')

if s3_access_key == 'dummy-access-key':
    # No real credentials were supplied via the environment.
    log.info('Automatic ~/.output file S3 uploads is disabled.')


def relpath(path, base):
    """Return *path* relative to *base*, resolving both to absolute paths first."""
    return Path(path).resolve().relative_to(Path(base).resolve())


async def upload_output_files_to_s3(initial_file_stats, final_file_stats,
                                    base_dir, prefix):
    """
    Upload files that changed between the two stat snapshots to S3.
    (Body continues beyond this excerpt.)
    """
    loop = asyncio.get_event_loop()
    output_files = []
    diff_files = diff_file_stats(initial_file_stats, final_file_stats)
import logging
import secrets
from urllib.parse import urlparse
import weakref

import aiohttp
from aiohttp import web

from ai.backend.common.logging import BraceStyleAdapter

from .auth import auth_required
from .exceptions import KernelNotFound
from .utils import not_impl_stub
from ..manager.models import kernels

log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.wsproxy'))


class WebSocketProxy():
    """
    Relays traffic between a downstream websocket client and an upstream
    connection (inferred from the field names -- the class body continues
    beyond this excerpt).
    """

    __slots__ = [
        'path', 'conn', 'down_conn',
        'upstream_buffer', 'upstream_buffer_task'
    ]

    def __init__(self, path, ws: web.WebSocketResponse):
        super(WebSocketProxy, self).__init__()
        self.path = path
        # Messages queued toward the upstream side, ordered by priority.
        self.upstream_buffer = asyncio.PriorityQueue()
        self.down_conn = ws
        # The upstream connection and its pump task are created later.
        self.conn = None
        self.upstream_buffer_task = None
import aiotools
from callosum.rpc import Peer, RPCMessage
import click
import trafaret as t

from ai.backend.common import config, msgpack
from ai.backend.common.etcd import AsyncEtcd, ConfigScopes
from ai.backend.common.logging import Logger, BraceStyleAdapter
from ai.backend.common.types import aobject
from ai.backend.common import validators as tx

from . import __version__ as VERSION
from .exception import ExecutionError

log = BraceStyleAdapter(logging.getLogger('ai.backend.storage.server'))


async def run(cmd: str) -> str:
    """
    Run *cmd* in a subshell and return its decoded stdout.

    :raises ExecutionError: if the command wrote anything to stderr.

    NOTE(review): *cmd* goes through the shell unescaped -- callers must
    never interpolate untrusted input.  Also note that ANY stderr output
    (even a benign warning) is treated as failure.
    """
    log.debug('Executing [{}]', cmd)
    proc = await create_subprocess_shell(cmd,
                                         stdout=subprocess.PIPE,
                                         stderr=subprocess.PIPE)
    out, err = await proc.communicate()
    if err:
        raise ExecutionError(err.decode())
    return out.decode()


class RPCFunctionRegistry:
    BinarySize, MetricKey,
    DeviceName, DeviceId, DeviceModelInfo,
    SlotName, SlotTypes,
)

from . import __version__
from .nvidia import libcudart, libnvml, LibraryError

__all__ = (
    'PREFIX',
    'CUDADevice',
    'CUDAPlugin',
)

PREFIX = 'cuda'

log = BraceStyleAdapter(logging.getLogger('ai.backend.accelerator.cuda'))


@attr.s(auto_attribs=True)
class CUDADevice(AbstractComputeDevice):
    """A single CUDA GPU, extending the base compute device with its
    model name and device UUID."""
    model_name: str
    uuid: str


class CUDAPlugin(AbstractComputePlugin):
    """Compute plugin exposing CUDA GPUs as allocatable devices.
    (Class body continues beyond this excerpt.)
    """
    config_watch_enabled = False

    key = DeviceName('cuda')
    # Whole-device counting slot; fractional sharing is not declared here.
    slot_types: Sequence[Tuple[SlotName, SlotTypes]] = (
        (SlotName('cuda.device'), SlotTypes('count')),
from .auth import auth_required, admin_required
from .exceptions import (InvalidAPIParameters, DotfileCreationFailed,
                         DotfileNotFound, DotfileAlreadyExists,
                         GenericForbidden, DomainNotFound)
from .manager import READ_ALLOWED, server_status_required
from .types import CORSOptions, Iterable, WebMiddleware
from .utils import check_api_params

from ..manager.models import (
    domains,
    query_domain_dotfiles,
    verify_dotfile_name,
    MAXIMUM_DOTFILE_SIZE,
)

log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.dotfile'))


@server_status_required(READ_ALLOWED)
@admin_required
@check_api_params(
    t.Dict({
        t.Key('domain'): t.String,
        t.Key('data'): t.String(max_length=MAXIMUM_DOTFILE_SIZE),
        t.Key('path'): t.String,
        # Three-digit octal permission string, e.g. '644'.
        t.Key('permission'): t.Regexp(r'^[0-7]{3}$', re.ASCII),
    }))
async def create(request: web.Request, params: Any) -> web.Response:
    """
    Create a domain-level dotfile.  (Body continues beyond this excerpt.)
    """
    log.info('CREATE DOTFILE (domain: {0})', params['domain'])
    # Non-superadmins may only manage dotfiles of their own domain.
    if not request['is_superadmin'] and request['user'][
            'domain_name'] != params['domain']:
from . import ManagerStatus
from .auth import superadmin_required
from .exceptions import (
    InstanceNotFound,
    InvalidAPIParameters,
    GenericBadRequest,
    ServerFrozen,
    ServiceUnavailable,
)
from .types import CORSOptions, WebMiddleware
from .utils import check_api_params
from ..manager.defs import DEFAULT_ROLE
from ..manager.models import agents, kernels, AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES

log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.manager'))


class SchedulerOps(enum.Enum):
    # NOTE(review): presumably operation names accepted by the scheduler
    # control API -- confirm against the handler using this enum.
    INCLUDE_AGENTS = 'include-agents'
    EXCLUDE_AGENTS = 'exclude-agents'


def server_status_required(allowed_status: FrozenSet[ManagerStatus]):
    """
    Decorator factory: wraps a handler so the request is rejected unless
    the manager's current status is in *allowed_status*.
    (Body continues beyond this excerpt.)
    """
    def decorator(handler):
        @functools.wraps(handler)
        async def wrapped(request, *args, **kwargs):
            status = await request.app['shared_config'].get_manager_status()
            if status not in allowed_status:
                if status == ManagerStatus.FROZEN:
                    raise ServerFrozen
from decimal import Decimal
import functools
import logging
import time

from aiohttp import web
import aioredis

from ai.backend.common.logging import BraceStyleAdapter

from .defs import REDIS_RLIM_DB
from .exceptions import RateLimitExceeded

log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.ratelimit'))

_time_prec = Decimal('1e-3')  # msec
_rlim_window = 60 * 15  # window length in seconds (15 minutes)

# We implement rate limiting using a rolling counter, which prevents
# last-minute and first-minute bursts between the intervals.
# (The Lua script literal is executed inside Redis; it is truncated in
# this excerpt.)
_rlim_script = '''
local access_key = KEYS[1]
local now = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local request_id = tonumber(redis.call('INCR', '__request_id'))
if request_id >= 1e12 then
    redis.call('SET', '__request_id', 1)
end
if redis.call('EXISTS', access_key) == 1 then
    redis.call('ZREMRANGEBYSCORE', access_key, 0, now - window)
'v5.20191215', # rewrote vfolder upload/download APIs to migrate to external storage proxies 'v6.20200815', ]) LATEST_REV_DATES: Final = { 1: '20160915', 2: '20170915', 3: '20181215', 4: '20190615', 5: '20191215', 6: '20200815', } LATEST_API_VERSION: Final = 'v6.20200815' log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.server')) PUBLIC_INTERFACES: Final = [ 'pidx', 'background_task_manager', 'local_config', 'shared_config', 'dbpool', 'registry', 'redis_live', 'redis_stat', 'redis_image', 'redis_stream', 'event_dispatcher', 'idle_checkers', 'storage_manager',
import logging

from aioconsole.events import run_console
from aiopg.sa import create_engine

from ai.backend.common.logging import BraceStyleAdapter

from . import register_command

log = BraceStyleAdapter(logging.getLogger(__name__))

# Holds the CLI argument namespace so that helpers invoked from within the
# interactive console (e.g. create_dbpool) can read the DB settings.
_args = None


@register_command
def shell(args):
    '''Launch an interactive Python prompt running under an async event loop.'''
    global _args
    _args = args
    run_console()


async def create_dbpool():
    """Create a small aiopg-SA engine from the stashed CLI DB settings."""
    addr = _args.db_addr
    engine = await create_engine(
        host=addr[0],
        port=addr[1],
        user=_args.db_user,
        password=_args.db_password,
        dbname=_args.db_name,
        minsize=1,
        maxsize=4,
    )
    return engine
from .auth import auth_required
from .exceptions import InvalidAPIParameters, TaskTemplateNotFound
from .manager import READ_ALLOWED, server_status_required
from .types import CORSOptions, Iterable, WebMiddleware
from .utils import check_api_params, get_access_key_scopes

from ..manager.models import (
    association_groups_users as agus, domains,
    groups, session_templates, keypairs,
    users, UserRole, query_accessible_session_templates,
    TemplateType
)
from ..manager.models.session_template import check_cluster_template

log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.cluster_template'))


@server_status_required(READ_ALLOWED)
@auth_required
@check_api_params(t.Dict(
    {
        tx.AliasedKey(['group', 'groupName', 'group_name'], default='default'): t.String,
        tx.AliasedKey(['domain', 'domainName', 'domain_name'], default='default'): t.String,
        t.Key('owner_access_key', default=None): t.Null | t.String,
        t.Key('payload'): t.String
    }
))
async def create(request: web.Request, params: Any) -> web.Response:
    """
    Create a cluster session template.  (Body continues beyond this excerpt.)
    """
    # Fall back to the requesting user's own domain when not specified.
    if params['domain'] is None:
        params['domain'] = request['user']['domain_name']