Beispiel #1
0
from ai.backend.common.logging import BraceStyleAdapter
from ai.backend.common.plugin import AbstractPlugin, BasePluginContext
from .exception import (
    InsufficientResource,
    InvalidResourceArgument,
    InvalidResourceCombination,
    NotMultipleOfQuantum,
)
from .stats import StatContext, NodeMeasurement, ContainerMeasurement
from .types import Container as SessionContainer

if TYPE_CHECKING:
    from aiofiles.threadpool.text import AsyncTextIOWrapper
    from io import TextIOWrapper

log = BraceStyleAdapter(logging.getLogger('ai.backend.agent.resources'))

known_slot_types: Mapping[SlotName, SlotTypes] = {}


class FractionAllocationStrategy(enum.Enum):
    """
    Strategy selector for distributing fractional resource allocations.

    NOTE(review): the names suggest FILL packs a request onto as few devices
    as possible while EVENLY spreads it across devices. Only the member names
    and values are visible here, so confirm this against the allocator that
    consumes the enum.
    """
    FILL = 0
    EVENLY = 1


@attr.s(auto_attribs=True, slots=True)
class KernelResourceSpec:
    """
    This struct-like object stores the kernel resource allocation information
    with serialization and deserialization.
from yarl import URL
import zmq

from ai.backend.common import msgpack
from ai.backend.common.logging import BraceStyleAdapter
from ..gateway.exceptions import (
    BackendError, InvalidAPIParameters, InstanceNotAvailable, InstanceNotFound,
    KernelNotFound, KernelAlreadyExists, KernelCreationFailed,
    KernelDestructionFailed, KernelExecutionFailed, KernelRestartFailed,
    AgentError)
from .models import (agents, kernels, keypairs, ResourceSlot, AgentStatus,
                     KernelStatus)

__all__ = ['AgentRegistry', 'InstanceNotFound']

log = BraceStyleAdapter(logging.getLogger('ai.backend.manager.registry'))

# Module-level cache of RPC peers, looked up by agent address in RPCContext.
agent_peers = {}


@aiotools.actxmgr
async def RPCContext(addr, timeout=None):
    preserved_exceptions = (
        NotFoundError,
        ParametersError,
        asyncio.TimeoutError,
        asyncio.CancelledError,
        asyncio.InvalidStateError,
    )
    global agent_peers
    peer = agent_peers.get(addr, None)
Beispiel #3
0
    Sequence,
    TYPE_CHECKING,
)

import graphene
import sqlalchemy as sa
from graphql.execution.executors.asyncio import AsyncioExecutor

from ai.backend.common.logging import BraceStyleAdapter
from .user import UserRole
from .base import BigInt, KVPair, ResourceLimit

if TYPE_CHECKING:
    from ..background import ProgressReporter

log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.admin'))

__all__ = (
    'Image',
    'PreloadImage',
    'RescanImages',
    'ForgetImage',
    'AliasImage',
    'DealiasImage',
)


class Image(graphene.ObjectType):
    name = graphene.String()
    humanized_name = graphene.String()
    tag = graphene.String()
Beispiel #4
0
import logging
import os
from setproctitle import setproctitle
import subprocess
import sys
from typing import Any, Mapping
from pathlib import Path

import attr
import click

from ai.backend.common.cli import LazyGroup
from ai.backend.common.logging import Logger, BraceStyleAdapter
from ai.backend.gateway.config import load as load_config

log = BraceStyleAdapter(logging.getLogger('ai.backend.manager.cli'))


@attr.s(auto_attribs=True, frozen=True)
class CLIContext:
    """
    Immutable bundle of objects shared with the manager CLI subcommands.

    Declared as a frozen attrs class: attributes cannot be reassigned after
    construction.
    """
    # Logger instance configured for the CLI process.
    logger: Logger
    # Loaded configuration mapping. Presumably the result of ``load_config``
    # imported above; TODO confirm against the group callback that builds it.
    config: Mapping[str, Any]


@click.group(invoke_without_command=True,
             context_settings={'help_option_names': ['-h', '--help']})
@click.option(
    '-f',
    '--config-path',
    '--config',
    type=Path,
Beispiel #5
0
    InvalidAuthParameters,
    InvalidAPIParameters,
    RejectedByHook,
)
from ..manager.models import (
    keypairs,
    keypair_resource_policies,
    users,
)
from ..manager.models.user import UserRole, UserStatus, INACTIVE_USER_STATUSES, check_credential
from ..manager.models.keypair import generate_keypair as _gen_keypair, generate_ssh_keypair
from ..manager.models.group import association_groups_users, groups
from .types import CORSOptions, WebMiddleware
from .utils import check_api_params, set_handler_attr, get_handler_attr

log: Final = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.auth'))

whois_timezone_info: Final = {
    "A": 1 * 3600,
    "ACDT": 10.5 * 3600,
    "ACST": 9.5 * 3600,
    "ACT": -5 * 3600,
    "ACWST": 8.75 * 3600,
    "ADT": 4 * 3600,
    "AEDT": 11 * 3600,
    "AEST": 10 * 3600,
    "AET": 10 * 3600,
    "AFT": 4.5 * 3600,
    "AKDT": -8 * 3600,
    "AKST": -9 * 3600,
    "ALMT": 6 * 3600,
Beispiel #6
0
import logging
import os
from pathlib import Path
from typing import Dict

from ai.backend.common.logging import BraceStyleAdapter

log = BraceStyleAdapter(logging.getLogger(__name__))

# The following AWS variable names follow the boto3 convention.
# Each value is read from the environment at import time and falls back to a
# placeholder when unset.
s3_access_key = os.environ.get('AWS_ACCESS_KEY_ID', 'dummy-access-key')
s3_secret_key = os.environ.get('AWS_SECRET_ACCESS_KEY', 'dummy-secret-key')
s3_region = os.environ.get('AWS_REGION', 'ap-northeast-1')
s3_bucket = os.environ.get('AWS_S3_BUCKET', 'codeonweb')
s3_bucket_path = os.environ.get('AWS_S3_BUCKET_PATH', 'bucket')

# The placeholder access key means no real credentials were supplied.
if s3_access_key == 'dummy-access-key':
    log.info('Automatic ~/.output file S3 uploads is disabled.')


def relpath(path, base):
    '''
    Returns *path* expressed relative to *base*, after resolving both to
    absolute paths (symlinks are followed).

    Raises ValueError if the resolved path does not lie under the resolved
    base directory.
    '''
    resolved_target = Path(path).resolve()
    resolved_base = Path(base).resolve()
    return resolved_target.relative_to(resolved_base)


def scandir(root: Path, allowed_max_size: int):
    '''
    Scans a directory recursively and returns a dictionary of all files and
    their last modified time.
    '''
    file_stats: Dict[Path, float] = dict()
    if not isinstance(root, Path):
Beispiel #7
0
    resource_presets,
    domains,
    groups,
    kernels,
    users,
    AgentStatus,
    association_groups_users,
    query_allowed_sgroups,
    AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES,
    RESOURCE_USAGE_KERNEL_STATUSES,
    LIVE_STATUS,
)
from .types import CORSOptions, WebMiddleware
from .utils import check_api_params

log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.kernel'))

# JSON decoder that parses floating-point literals as Decimal, avoiding binary
# float rounding on numeric values.
_json_loads = functools.partial(json.loads, parse_float=Decimal)


@auth_required
@atomic
async def list_presets(request) -> web.Response:
    '''
    Returns the list of all resource presets.
    '''
    log.info('LIST_PRESETS (ak:{})', request['keypair']['access_key'])
    await request.app['registry'].config_server.get_resource_slots()
    async with request.app['dbpool'].acquire() as conn, conn.begin():
        query = (sa.select([resource_presets]).select_from(resource_presets))
        # TODO: uncomment when we implement scaling group.
Beispiel #8
0
    InternalServerError,
    InvalidAPIParameters,
    SessionNotFound,
    TooManySessionsMatched,
)
from .manager import READ_ALLOWED, server_status_required
from .types import CORSOptions, WebMiddleware
from .utils import check_api_params, call_non_bursty
from .wsproxy import TCPProxy
from ..manager.defs import DEFAULT_ROLE
from ..manager.models import kernels
if TYPE_CHECKING:
    from .config import LocalConfig
    from ..manager.registry import AgentRegistry

log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.stream'))


@server_status_required(READ_ALLOWED)
@auth_required
@adefer
async def stream_pty(defer, request: web.Request) -> web.StreamResponse:
    app = request.app
    local_config = app['local_config']
    registry = app['registry']
    session_name = request.match_info['session_name']
    access_key = request['keypair']['access_key']
    extra_fields = (kernels.c.stdin_port, kernels.c.stdout_port)
    api_version = request['api_version']
    try:
        compute_session = await asyncio.shield(
Beispiel #9
0
from ai.backend.common import validators as tx
from ai.backend.common.logging import BraceStyleAdapter
from ai.backend.common.types import SessionTypes

from .auth import auth_required
from .exceptions import InvalidAPIParameters, TaskTemplateNotFound
from .manager import READ_ALLOWED, server_status_required
from .types import CORSOptions, Iterable, WebMiddleware
from .utils import check_api_params, get_access_key_scopes

from ..manager.models import (association_groups_users as agus, domains,
                              groups, session_templates, keypairs, users,
                              UserRole, query_accessible_session_templates,
                              TemplateType, verify_vfolder_name)

log = BraceStyleAdapter(
    logging.getLogger('ai.backend.gateway.session_template'))

task_template_v1 = t.Dict({
    tx.AliasedKey(['api_version', 'apiVersion']):
    t.String,
    t.Key('kind'):
    t.Enum('taskTemplate', 'task_template'),
    t.Key('metadata'):
    t.Dict({
        t.Key('name'): t.String,
        t.Key('tag', default=None): t.Null | t.String,
    }),
    t.Key('spec'):
    t.Dict({
        tx.AliasedKey(['type', 'sessionType'], default='interactive') >> 'session_type':
        tx.Enum(SessionTypes),
Beispiel #10
0
    FrozenSet,
    Sequence,
    Tuple,
)

from aiodocker.docker import Docker, DockerVolume
from aiodocker.exceptions import DockerError
from aiotools import TaskGroup

from ai.backend.common.docker import ImageRef
from ai.backend.common.logging import BraceStyleAdapter
from ai.backend.common.utils import current_loop
from ..resources import KernelResourceSpec
from ..kernel import AbstractKernel, AbstractCodeRunner

log = BraceStyleAdapter(logging.getLogger(__name__))


class DockerKernel(AbstractKernel):

    # FIXME: apply TypedDict to data in Python 3.8

    def __init__(
            self,
            kernel_id: str,
            image: ImageRef,
            version: int,
            *,
            agent_config: Mapping[str, Any],
            resource_spec: KernelResourceSpec,
            service_ports: Any,  # TODO: type-annotation
Beispiel #11
0
)
from .utils import (
    remove_exponent, )
if TYPE_CHECKING:
    from .agent import AbstractAgent

__all__ = (
    'StatContext',
    'StatModes',
    'MetricTypes',
    'NodeMeasurement',
    'ContainerMeasurement',
    'Measurement',
)

log = BraceStyleAdapter(logging.getLogger('ai.backend.agent.stats'))


def check_cgroup_available():
    """
    Tell whether cgroup-based statistics can be collected on this host.

    Returns False when running inside a container; otherwise returns True
    only on Linux platforms.
    """
    if is_containerized():
        # Inside a container we do not treat the host's cgroups as usable.
        return False
    return sys.platform.startswith('linux')


class StatModes(enum.Enum):
    CGROUP = 'cgroup'
    DOCKER = 'docker'

    @staticmethod
    def get_preferred_mode():
Beispiel #12
0
import aiofiles
import aiojobs.aiohttp
from aiohttp import web
import aiotools
import click
from setproctitle import setproctitle
import trafaret as t

from ai.backend.common import config, utils, validators as tx
from ai.backend.common.etcd import AsyncEtcd, ConfigScopes
from ai.backend.common.logging import Logger, BraceStyleAdapter
from ai.backend.common.utils import Fstab
from . import __version__ as VERSION

log = BraceStyleAdapter(logging.getLogger('ai.backend.agent.watcher'))

shutdown_enabled = False


@web.middleware
async def auth_middleware(request, handler):
    token = request.headers.get('X-BackendAI-Watcher-Token', None)
    if token == request.app['token']:
        try:
            return (await handler(request))
        except FileNotFoundError as e:
            log.info(repr(e))
            message = 'Agent is not loaded with systemctl.'
            return web.json_response({'message': message}, status=200)
        except Exception as e:
Beispiel #13
0
)
from ai.backend.common.exception import UnknownImageReference
from ai.backend.common.etcd import (
    quote as etcd_quote,
    unquote as etcd_unquote,
    ConfigScopes,
)

from .exceptions import ServerMisconfiguredError
from .manager import ManagerStatus
if TYPE_CHECKING:
    from ..manager.background import ProgressReporter
from ..manager.container_registry import get_container_registry
from ..manager.defs import INTRINSIC_SLOTS

log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.config'))

# Host CPU count; os.cpu_count() returns None when it cannot be determined.
_max_cpu_count = os.cpu_count()
# stat() result of the sibling server.py, taken at import time (raises
# FileNotFoundError if that file is missing). Presumably used as a reference
# for file ownership/permissions; TODO confirm against its users.
_file_perm = (Path(__file__).parent / 'server.py').stat()

DEFAULT_CHUNK_SIZE: Final = 256 * 1024  # 256 KiB
DEFAULT_INFLIGHT_CHUNKS: Final = 8

# Default values for shared-config keys, presumably applied when a key is
# absent from the shared config store. TODO confirm.
shared_config_defaults = {
    'volumes/_mount': '/mnt',
    'volumes/_default_host': 'local',
    'volumes/_fsprefix': '/',
    'config/api/allow-origins': '*',
    'config/docker/image/auto_pull': 'digest',
}
Beispiel #14
0
from ai.backend.common.logging import BraceStyleAdapter
from ai.backend.common.types import (
    aobject,
    AgentId,
)
from ai.backend.common.utils import current_loop
from .auth import auth_required
from .defs import REDIS_STREAM_DB
from .exceptions import GenericNotFound, GenericForbidden, GroupNotFound
from .manager import READ_ALLOWED, server_status_required
from .types import CORSOptions, WebMiddleware
from .utils import check_api_params
from ..manager.models import kernels, groups, UserRole
from ..manager.types import BackgroundTaskEventArgs, Sentinel

log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.events'))

sentinel: Final = Sentinel.token


class EventCallback(Protocol):
    """
    Structural type (typing.Protocol) for event handler callbacks.

    Any object whose ``__call__`` is a coroutine function with this signature
    conforms, without needing to subclass EventCallback.
    """
    async def __call__(self, context: Any, agent_id: AgentId, event_name: str,
                       *args) -> None:
        # The extra positional ``args`` presumably carry the event-specific
        # payload; TODO confirm against the dispatcher that invokes handlers.
        ...


@attr.s(auto_attribs=True, slots=True, frozen=True, eq=False, order=False)
class EventHandler:
    """
    Immutable pairing of a context object with an event callback.

    The callback is presumably invoked with this context as its first
    argument, matching EventCallback's signature; TODO confirm.

    ``eq=False`` makes instances compare and hash by identity. Two handlers
    holding the same callback and context therefore remain distinct.
    """
    context: Any
    callback: EventCallback
Beispiel #15
0
    Awaitable, Callable, Final, Optional,
    Literal, Union,
    Set,
)
import uuid

from aiojobs import Scheduler
import attr

from ai.backend.common import redis
from ai.backend.common.logging import BraceStyleAdapter

from ..gateway.events import EventDispatcher
from .types import BackgroundTaskEventArgs

log = BraceStyleAdapter(logging.getLogger('ai.backend.manager.background'))

MAX_BGTASK_ARCHIVE_PERIOD = 86400  # 24 hours, expressed in seconds

# The possible terminal result names of a background task.
TaskResult = Literal['task_done', 'task_cancelled', 'task_failed']


class ProgressReporter:
    event_dispatcher: Final[EventDispatcher]
    task_id: Final[uuid.UUID]
    total_progress: Union[int, float]
    current_progress: Union[int, float]

    def __init__(
        self,
        event_dispatcher: EventDispatcher,
Beispiel #16
0
from .predicates import (
    check_reserved_batch_session,
    check_concurrency,
    check_dependencies,
    check_keypair_resource_limit,
    check_group_resource_limit,
    check_domain_resource_limit,
    check_scaling_group,
)

__all__ = (
    'load_scheduler',
    'SchedulerDispatcher',
)

log = BraceStyleAdapter(logging.getLogger('ai.backend.manager.scheduler'))


def load_scheduler(name: str, scheduler_configs: Mapping[str, Any]) -> AbstractScheduler:
    """
    Instantiate the scheduler plugin registered under *name*.

    Scans the ``backendai_scheduler_v10`` entry-point group. The first entry
    point whose name equals *name* is loaded, and the resulting class is
    constructed with its per-scheduler configuration. An empty dict is used
    when *scheduler_configs* has no entry for *name*.

    :param name: entry-point name of the scheduler plugin.
    :param scheduler_configs: mapping from scheduler name to its config.
    :returns: the constructed scheduler instance.
    :raises ImportError: if no entry point with the given name is registered.
    """
    entry_prefix = 'backendai_scheduler_v10'
    for entrypoint in pkg_resources.iter_entry_points(entry_prefix):
        if entrypoint.name == name:
            log.debug('loading scheduler plugin "{}" from {}', name, entrypoint.module_name)
            scheduler_cls = entrypoint.load()
            scheduler_config = scheduler_configs.get(name, {})
            return scheduler_cls(scheduler_config)
    # Use a single formatted message so str(exc) reads naturally instead of
    # rendering an argument tuple; the plugin name stays available as exc.name.
    raise ImportError(f'Cannot load the scheduler plugin: {name}', name=name)


def merge_resource(src: MutableMapping[str, Any], val: MutableMapping[str, Any]) -> None:
    for k in val.keys():
from typing import Any

from aiohttp import web
import aiohttp_cors
import trafaret as t

from ai.backend.common import validators as tx
from ai.backend.common.logging import BraceStyleAdapter

from .auth import auth_required
from .manager import (READ_ALLOWED, server_status_required)
from .utils import check_api_params
from ..manager.models import (
    query_allowed_sgroups, )

log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.scaling_group'))


@auth_required
@server_status_required(READ_ALLOWED)
@check_api_params(
    t.Dict({
        tx.AliasedKey(['group', 'group_id', 'group_name']):
        tx.UUID | t.String,
    }), )
async def list_available_sgroups(request: web.Request,
                                 params: Any) -> web.Response:
    dbpool = request.app['dbpool']
    access_key = request['keypair']['access_key']
    domain_name = request['user']['domain_name']
    group_id_or_name = params['group']
Beispiel #18
0
import asyncio
import logging
import os
from pathlib import Path

from ai.backend.common.logging import BraceStyleAdapter
import botocore, aiobotocore

log = BraceStyleAdapter(logging.getLogger('ai.backend.agent.files'))

# The following AWS variable names follow the boto3 convention.
# Each value is read from the environment at import time and falls back to a
# placeholder when unset.
s3_access_key = os.environ.get('AWS_ACCESS_KEY_ID', 'dummy-access-key')
s3_secret_key = os.environ.get('AWS_SECRET_ACCESS_KEY', 'dummy-secret-key')
s3_region = os.environ.get('AWS_REGION', 'ap-northeast-1')
s3_bucket = os.environ.get('AWS_S3_BUCKET', 'codeonweb')
s3_bucket_path = os.environ.get('AWS_S3_BUCKET_PATH', 'bucket')

# The placeholder access key means no real credentials were supplied.
if s3_access_key == 'dummy-access-key':
    log.info('Automatic ~/.output file S3 uploads is disabled.')


def relpath(path, base):
    '''
    Returns *path* expressed relative to *base*, after resolving both to
    absolute paths (symlinks are followed).

    Raises ValueError if the resolved path does not lie under the resolved
    base directory.
    '''
    resolved_target = Path(path).resolve()
    resolved_base = Path(base).resolve()
    return resolved_target.relative_to(resolved_base)


async def upload_output_files_to_s3(initial_file_stats,
                                    final_file_stats,
                                    base_dir, prefix):
    loop = asyncio.get_event_loop()
    output_files = []
    diff_files = diff_file_stats(initial_file_stats, final_file_stats)
Beispiel #19
0
import logging
import secrets
from urllib.parse import urlparse
import weakref

import aiohttp
from aiohttp import web

from ai.backend.common.logging import BraceStyleAdapter

from .auth import auth_required
from .exceptions import KernelNotFound
from .utils import not_impl_stub
from ..manager.models import kernels

log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.wsproxy'))


class WebSocketProxy():
    __slots__ = [
        'path', 'conn', 'down_conn', 'upstream_buffer', 'upstream_buffer_task'
    ]

    def __init__(self, path, ws: web.WebSocketResponse):
        super(WebSocketProxy, self).__init__()
        self.path = path
        self.upstream_buffer = asyncio.PriorityQueue()
        self.down_conn = ws
        self.conn = None
        self.upstream_buffer_task = None
Beispiel #20
0
import aiotools
from callosum.rpc import Peer, RPCMessage
import click
import trafaret as t

from ai.backend.common import config, msgpack
from ai.backend.common.etcd import AsyncEtcd, ConfigScopes
from ai.backend.common.logging import Logger, BraceStyleAdapter
from ai.backend.common.types import aobject
from ai.backend.common import validators as tx

from . import __version__ as VERSION
from .exception import ExecutionError

log = BraceStyleAdapter(logging.getLogger('ai.backend.storage.server'))


async def run(cmd: str) -> str:
    """
    Run *cmd* through the shell and return its decoded standard output.

    :param cmd: a shell command line. It is handed to the shell verbatim,
        so callers must not interpolate untrusted input into it.
    :returns: the command's stdout decoded as UTF-8.
    :raises ExecutionError: if the command writes anything to stderr. The
        error text is decoded with replacement characters, so non-UTF-8
        diagnostics still surface as ExecutionError rather than as a
        UnicodeDecodeError.
    """
    log.debug('Executing [{}]', cmd)
    proc = await create_subprocess_shell(cmd,
                                         stdout=subprocess.PIPE,
                                         stderr=subprocess.PIPE)
    out, err = await proc.communicate()
    if err:
        # NOTE(review): failure is detected from stderr output only. A
        # non-zero exit status with empty stderr is not reported as an error.
        raise ExecutionError(err.decode(errors='replace'))
    return out.decode()


class RPCFunctionRegistry:
    BinarySize, MetricKey,
    DeviceName, DeviceId, DeviceModelInfo,
    SlotName, SlotTypes,
)
from . import __version__
from .nvidia import libcudart, libnvml, LibraryError

__all__ = (
    'PREFIX',
    'CUDADevice',
    'CUDAPlugin',
)

PREFIX = 'cuda'

log = BraceStyleAdapter(logging.getLogger('ai.backend.accelerator.cuda'))


@attr.s(auto_attribs=True)
class CUDADevice(AbstractComputeDevice):
    """
    Compute-device record for the ``cuda`` accelerator plugin.

    Extends AbstractComputeDevice with the device model name and UUID.
    """
    # Human-readable device model name.
    model_name: str
    # Device UUID string (presumably as reported by the NVIDIA libraries;
    # TODO confirm).
    uuid: str


class CUDAPlugin(AbstractComputePlugin):

    config_watch_enabled = False

    key = DeviceName('cuda')
    slot_types: Sequence[Tuple[SlotName, SlotTypes]] = (
        (SlotName('cuda.device'), SlotTypes('count')),
Beispiel #22
0
from .auth import auth_required, admin_required
from .exceptions import (InvalidAPIParameters, DotfileCreationFailed,
                         DotfileNotFound, DotfileAlreadyExists,
                         GenericForbidden, DomainNotFound)
from .manager import READ_ALLOWED, server_status_required
from .types import CORSOptions, Iterable, WebMiddleware
from .utils import check_api_params

from ..manager.models import (
    domains,
    query_domain_dotfiles,
    verify_dotfile_name,
    MAXIMUM_DOTFILE_SIZE,
)

log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.dotfile'))


@server_status_required(READ_ALLOWED)
@admin_required
@check_api_params(
    t.Dict({
        t.Key('domain'): t.String,
        t.Key('data'): t.String(max_length=MAXIMUM_DOTFILE_SIZE),
        t.Key('path'): t.String,
        t.Key('permission'): t.Regexp(r'^[0-7]{3}$', re.ASCII),
    }))
async def create(request: web.Request, params: Any) -> web.Response:
    log.info('CREATE DOTFILE (domain: {0})', params['domain'])
    if not request['is_superadmin'] and request['user'][
            'domain_name'] != params['domain']:
Beispiel #23
0
from . import ManagerStatus
from .auth import superadmin_required
from .exceptions import (
    InstanceNotFound,
    InvalidAPIParameters,
    GenericBadRequest,
    ServerFrozen,
    ServiceUnavailable,
)
from .types import CORSOptions, WebMiddleware
from .utils import check_api_params
from ..manager.defs import DEFAULT_ROLE
from ..manager.models import agents, kernels, AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES

log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.manager'))


class SchedulerOps(enum.Enum):
    """
    Scheduler control operations, identified by kebab-case string values.

    NOTE(review): the values are presumably the operation names accepted by
    a scheduler-control API handler in this module; confirm.
    """
    INCLUDE_AGENTS = 'include-agents'
    EXCLUDE_AGENTS = 'exclude-agents'


def server_status_required(allowed_status: FrozenSet[ManagerStatus]):
    def decorator(handler):
        @functools.wraps(handler)
        async def wrapped(request, *args, **kwargs):
            status = await request.app['shared_config'].get_manager_status()
            if status not in allowed_status:
                if status == ManagerStatus.FROZEN:
                    raise ServerFrozen
Beispiel #24
0
from decimal import Decimal
import functools
import logging
import time

from aiohttp import web
import aioredis

from ai.backend.common.logging import BraceStyleAdapter

from .defs import REDIS_RLIM_DB
from .exceptions import RateLimitExceeded

log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.ratelimit'))

_time_prec = Decimal('1e-3')  # msec
# Rolling rate-limit window: 900, presumably seconds (15 minutes); TODO confirm.
_rlim_window = 60 * 15

# We implement rate limiting using a rolling counter, which prevents
# last-minute and first-minute bursts between the intervals.

_rlim_script = '''
local access_key = KEYS[1]
local now = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local request_id = tonumber(redis.call('INCR', '__request_id'))
if request_id >= 1e12 then
    redis.call('SET', '__request_id', 1)
end
if redis.call('EXISTS', access_key) == 1 then
    redis.call('ZREMRANGEBYSCORE', access_key, 0, now - window)
Beispiel #25
0
    'v5.20191215',

    # rewrote vfolder upload/download APIs to migrate to external storage proxies
    'v6.20200815',
])
# Maps each major API version number to the date string of its latest revision.
LATEST_REV_DATES: Final = {
    1: '20160915',
    2: '20170915',
    3: '20181215',
    4: '20190615',
    5: '20191215',
    6: '20200815',
}
# Must be kept in sync with the newest entry of LATEST_REV_DATES and with the
# last entry of the version list above.
LATEST_API_VERSION: Final = 'v6.20200815'

log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.server'))

PUBLIC_INTERFACES: Final = [
    'pidx',
    'background_task_manager',
    'local_config',
    'shared_config',
    'dbpool',
    'registry',
    'redis_live',
    'redis_stat',
    'redis_image',
    'redis_stream',
    'event_dispatcher',
    'idle_checkers',
    'storage_manager',
Beispiel #26
0
import logging

from aioconsole.events import run_console
from aiopg.sa import create_engine

from ai.backend.common.logging import BraceStyleAdapter

from . import register_command

log = BraceStyleAdapter(logging.getLogger(__name__))
_args = None


@register_command
def shell(args):
    '''Launch an interactive Python prompt running under an async event loop.'''
    global _args
    # Stash the parsed CLI arguments module-wide before starting the console,
    # so helpers such as create_dbpool() can read them during the session.
    _args = args
    run_console()


async def create_dbpool():
    '''
    Create an aiopg SQLAlchemy engine from the CLI arguments stashed by shell().

    The engine uses the database address, user, password and name from those
    arguments, with minsize=1 and maxsize=4 pooled connections.

    :raises RuntimeError: if called before shell() has recorded the CLI
        arguments (the module-level ``_args`` is still None).
    '''
    if _args is None:
        # Fail with an explicit message instead of an opaque AttributeError
        # on None when the console was not started via the shell command.
        raise RuntimeError('create_dbpool() called before shell() stored the CLI arguments')
    return await create_engine(host=_args.db_addr[0],
                               port=_args.db_addr[1],
                               user=_args.db_user,
                               password=_args.db_password,
                               dbname=_args.db_name,
                               minsize=1,
                               maxsize=4)
from .auth import auth_required
from .exceptions import InvalidAPIParameters, TaskTemplateNotFound
from .manager import READ_ALLOWED, server_status_required
from .types import CORSOptions, Iterable, WebMiddleware
from .utils import check_api_params, get_access_key_scopes


from ..manager.models import (
    association_groups_users as agus, domains,
    groups, session_templates, keypairs, users, UserRole,
    query_accessible_session_templates, TemplateType
)
from ..manager.models.session_template import check_cluster_template

log = BraceStyleAdapter(logging.getLogger('ai.backend.gateway.cluster_template'))


@server_status_required(READ_ALLOWED)
@auth_required
@check_api_params(t.Dict(
    {
        tx.AliasedKey(['group', 'groupName', 'group_name'], default='default'): t.String,
        tx.AliasedKey(['domain', 'domainName', 'domain_name'], default='default'): t.String,
        t.Key('owner_access_key', default=None): t.Null | t.String,
        t.Key('payload'): t.String
    }
))
async def create(request: web.Request, params: Any) -> web.Response:
    if params['domain'] is None:
        params['domain'] = request['user']['domain_name']