class Domain(graphene.ObjectType):
    name = graphene.String()
    description = graphene.String()
    is_active = graphene.Boolean()
    created_at = GQLDateTime()
    modified_at = GQLDateTime()
    total_resource_slots = graphene.JSONString()
    allowed_vfolder_hosts = graphene.List(lambda: graphene.String)
    allowed_docker_registries = graphene.List(lambda: graphene.String)
    integration_id = graphene.String()

    # Dynamic fields.
    scaling_groups = graphene.List(lambda: graphene.String)

    async def resolve_scaling_groups(self, info):
        sgroups = await ScalingGroup.load_by_domain(info.context, self.name)
        return [sg.name for sg in sgroups]

    @classmethod
    def from_row(cls, context, row):
        if row is None:
            return None
        return cls(
            name=row['name'],
            description=row['description'],
            is_active=row['is_active'],
            created_at=row['created_at'],
            modified_at=row['modified_at'],
            total_resource_slots=row['total_resource_slots'].to_json(),
            allowed_vfolder_hosts=row['allowed_vfolder_hosts'],
            allowed_docker_registries=row['allowed_docker_registries'],
            integration_id=row['integration_id'],
        )

    @classmethod
    async def load_all(cls, context, *, is_active=None):
        async with context['dbpool'].acquire() as conn:
            query = sa.select([domains]).select_from(domains)
            if is_active is not None:
                query = query.where(domains.c.is_active == is_active)
            return [
                cls.from_row(context, row) async for row in conn.execute(query)
            ]

    @classmethod
    async def batch_load_by_name(cls, context, names=None, *, is_active=None):
        async with context['dbpool'].acquire() as conn:
            query = (
                sa.select([domains])
                .select_from(domains)
                .where(domains.c.name.in_(names))
            )
            return await batch_result(
                context, conn, query, cls,
                names, lambda row: row['name'],
            )
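# The batch_load_* classmethods above follow the usual DataLoader contract:
# they take the GraphQL context plus a list of keys and return one result per
# key, in the same order.  A minimal sketch of wiring Domain.batch_load_by_name
# into an aiodataloader.DataLoader is shown below; the use of aiodataloader and
# the graph_ctx dict are illustrative assumptions, not necessarily how the
# manager's 'dlmgr' loader manager actually builds its loaders.
from functools import partial

from aiodataloader import DataLoader  # assumed third-party loader library


def make_domain_loader(graph_ctx):
    # graph_ctx: a dict carrying at least 'dbpool', mirroring the `context`
    # argument that batch_load_by_name() expects.
    batch_fn = partial(Domain.batch_load_by_name, graph_ctx)
    return DataLoader(batch_load_fn=batch_fn, max_batch_size=128)

# Example use inside a resolver (sketch):
#   loader = make_domain_loader(info.context)
#   domain = await loader.load('default')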
class VirtualFolder(graphene.ObjectType):
    id = graphene.UUID()
    host = graphene.String()
    name = graphene.String()
    max_files = graphene.Int()
    max_size = graphene.Int()
    created_at = GQLDateTime()
    last_used = GQLDateTime()
    num_files = graphene.Int()
    cur_size = graphene.Int()
    # num_attached = graphene.Int()

    @classmethod
    def from_row(cls, row):
        if row is None:
            return None
        return cls(
            id=row.id,
            host=row.host,
            name=row.name,
            max_files=row.max_files,
            max_size=row.max_size,  # in KiB
            created_at=row.created_at,
            last_used=row.last_used,
            num_attached=row.num_attached,
        )

    async def resolve_num_files(self, info):
        # TODO: measure on-the-fly
        return 0

    async def resolve_cur_size(self, info):
        # TODO: measure on-the-fly
        return 0

    @staticmethod
    async def batch_load(dbpool, access_keys):
        async with dbpool.acquire() as conn:
            # TODO: num_attached count group-by
            query = (sa.select('*')
                     .select_from(vfolders)
                     .where(vfolders.c.belongs_to.in_(access_keys))
                     .order_by(sa.desc(vfolders.c.created_at)))
            objs_per_key = OrderedDict()
            for k in access_keys:
                objs_per_key[k] = list()
            async for row in conn.execute(query):
                o = VirtualFolder.from_row(row)
                objs_per_key[row.belongs_to].append(o)
        return tuple(objs_per_key.values())
class Agent(graphene.ObjectType):

    class Meta:
        interfaces = (Item, )

    status = graphene.String()
    status_changed = GQLDateTime()
    region = graphene.String()
    scaling_group = graphene.String()
    schedulable = graphene.Boolean()
    available_slots = graphene.JSONString()
    occupied_slots = graphene.JSONString()
    addr = graphene.String()
    first_contact = GQLDateTime()
    lost_at = GQLDateTime()
    live_stat = graphene.JSONString()
    version = graphene.String()
    compute_plugins = graphene.JSONString()

    # Legacy fields
    mem_slots = graphene.Int()
    cpu_slots = graphene.Float()
    gpu_slots = graphene.Float()
    tpu_slots = graphene.Float()
    used_mem_slots = graphene.Int()
    used_cpu_slots = graphene.Float()
    used_gpu_slots = graphene.Float()
    used_tpu_slots = graphene.Float()
    cpu_cur_pct = graphene.Float()
    mem_cur_bytes = graphene.Float()

    compute_containers = graphene.List(
        'ai.backend.manager.models.ComputeContainer',
        status=graphene.String())

    @classmethod
    def from_row(
        cls,
        context: Mapping[str, Any],
        row: RowProxy,
    ) -> Agent:
        mega = 2 ** 20
        return cls(
            id=row['id'],
            status=row['status'].name,
            status_changed=row['status_changed'],
            region=row['region'],
            scaling_group=row['scaling_group'],
            schedulable=row['schedulable'],
            available_slots=row['available_slots'].to_json(),
            occupied_slots=row['occupied_slots'].to_json(),
            addr=row['addr'],
            first_contact=row['first_contact'],
            lost_at=row['lost_at'],
            version=row['version'],
            compute_plugins=row['compute_plugins'],
            # legacy fields
            mem_slots=BinarySize.from_str(row['available_slots']['mem']) // mega,
            cpu_slots=row['available_slots']['cpu'],
            gpu_slots=row['available_slots'].get('cuda.device', 0),
            tpu_slots=row['available_slots'].get('tpu.device', 0),
            used_mem_slots=BinarySize.from_str(
                row['occupied_slots'].get('mem', 0)) // mega,
            used_cpu_slots=float(row['occupied_slots'].get('cpu', 0)),
            used_gpu_slots=float(row['occupied_slots'].get('cuda.device', 0)),
            used_tpu_slots=float(row['occupied_slots'].get('tpu.device', 0)),
        )

    async def resolve_live_stat(self, info):
        rs = info.context['redis_stat']
        live_stat = await redis.execute_with_retries(
            lambda: rs.get(str(self.id), encoding=None))
        if live_stat is not None:
            live_stat = msgpack.unpackb(live_stat)
        return live_stat

    async def resolve_cpu_cur_pct(self, info):
        rs = info.context['redis_stat']
        live_stat = await redis.execute_with_retries(
            lambda: rs.get(str(self.id), encoding=None))
        if live_stat is not None:
            live_stat = msgpack.unpackb(live_stat)
            try:
                return float(live_stat['node']['cpu_util']['pct'])
            except (KeyError, TypeError, ValueError):
                return 0.0
        return 0.0

    async def resolve_mem_cur_bytes(self, info):
        rs = info.context['redis_stat']
        live_stat = await redis.execute_with_retries(
            lambda: rs.get(str(self.id), encoding=None))
        if live_stat is not None:
            live_stat = msgpack.unpackb(live_stat)
            try:
                return int(live_stat['node']['mem']['current'])
            except (KeyError, TypeError, ValueError):
                return 0
        return 0

    async def resolve_computations(self, info, status=None):
        '''
        Retrieves all children worker sessions run by this agent.
        '''
        manager = info.context['dlmgr']
        loader = manager.get_loader('Computation.by_agent_id', status=status)
        return await loader.load(self.id)

    @staticmethod
    async def load_count(
        context, *,
        scaling_group=None,
        status=None,
    ) -> int:
        async with context['dbpool'].acquire() as conn:
            query = (
                sa.select([sa.func.count(agents.c.id)])
                .select_from(agents)
                .as_scalar()
            )
            if scaling_group is not None:
                query = query.where(agents.c.scaling_group == scaling_group)
            if status is not None:
                status = AgentStatus[status]
                query = query.where(agents.c.status == status)
            result = await conn.execute(query)
            count = await result.fetchone()
            return count[0]

    @classmethod
    async def load_slice(
        cls, context, limit, offset, *,
        scaling_group=None,
        status=None,
        order_key=None, order_asc=True,
    ) -> Sequence[Agent]:
        async with context['dbpool'].acquire() as conn:
            # TODO: optimization for pagination using subquery, join
            if order_key is None:
                _ordering = agents.c.id
            else:
                _order_func = sa.asc if order_asc else sa.desc
                _ordering = _order_func(getattr(agents.c, order_key))
            query = (
                sa.select([agents])
                .select_from(agents)
                .order_by(_ordering)
                .limit(limit)
                .offset(offset)
            )
            if scaling_group is not None:
                query = query.where(agents.c.scaling_group == scaling_group)
            if status is not None:
                status = AgentStatus[status]
                query = query.where(agents.c.status == status)
            return [
                cls.from_row(context, row) async for row in conn.execute(query)
            ]

    @classmethod
    async def load_all(
        cls, context, *,
        scaling_group=None,
        status=None,
    ) -> Sequence[Agent]:
        async with context['dbpool'].acquire() as conn:
            query = (
                sa.select([agents])
                .select_from(agents)
            )
            if scaling_group is not None:
                query = query.where(agents.c.scaling_group == scaling_group)
            if status is not None:
                status = AgentStatus[status]
                query = query.where(agents.c.status == status)
            return [
                cls.from_row(context, row) async for row in conn.execute(query)
            ]

    @classmethod
    async def batch_load(
        cls, context, agent_ids, *,
        status=None,
    ) -> Sequence[Optional[Agent]]:
        async with context['dbpool'].acquire() as conn:
            query = (sa.select([agents])
                     .select_from(agents)
                     .where(agents.c.id.in_(agent_ids))
                     .order_by(agents.c.id))
            if status is not None:
                status = AgentStatus[status]
                query = query.where(agents.c.status == status)
            return await batch_result(
                context, conn, query, cls,
                agent_ids, lambda row: row['id'],
            )
class User(graphene.ObjectType):

    class Meta:
        interfaces = (Item, )

    uuid = graphene.UUID()  # legacy
    username = graphene.String()
    email = graphene.String()
    password = graphene.String()
    need_password_change = graphene.Boolean()
    full_name = graphene.String()
    description = graphene.String()
    is_active = graphene.Boolean()
    status = graphene.String()
    status_info = graphene.String()
    created_at = GQLDateTime()
    modified_at = GQLDateTime()
    domain_name = graphene.String()
    role = graphene.String()

    groups = graphene.List(lambda: UserGroup)

    async def resolve_groups(
        self,
        info: graphene.ResolveInfo,
    ) -> Iterable[UserGroup]:
        manager = info.context['dlmgr']
        loader = manager.get_loader('UserGroup.by_user_id')
        return await loader.load(self.id)

    @classmethod
    def from_row(
        cls,
        context: Mapping[str, Any],
        row: RowProxy,
    ) -> User:
        return cls(
            id=row['uuid'],
            uuid=row['uuid'],
            username=row['username'],
            email=row['email'],
            need_password_change=row['need_password_change'],
            full_name=row['full_name'],
            description=row['description'],
            is_active=True if row['status'] == UserStatus.ACTIVE else False,  # legacy
            status=row['status'],
            status_info=row['status_info'],
            created_at=row['created_at'],
            modified_at=row['modified_at'],
            domain_name=row['domain_name'],
            role=row['role'],
        )

    @classmethod
    async def load_all(
        cls, context, *,
        domain_name=None,
        group_id=None,
        is_active=None,
        status=None,
        limit=None,
    ) -> Sequence[User]:
        """
        Load user's information. Group names associated with the user are also returned.
        """
        async with context['dbpool'].acquire() as conn:
            if group_id is not None:
                from .group import association_groups_users as agus
                j = (users.join(agus, agus.c.user_id == users.c.uuid))
                query = (
                    sa.select([users])
                    .select_from(j)
                    .where(agus.c.group_id == group_id)
                )
            else:
                query = (
                    sa.select([users])
                    .select_from(users)
                )
            if context['user']['role'] != UserRole.SUPERADMIN:
                query = query.where(users.c.domain_name == context['user']['domain_name'])
            if domain_name is not None:
                query = query.where(users.c.domain_name == domain_name)
            if status is not None:
                query = query.where(users.c.status == UserStatus(status))
            elif is_active is not None:  # consider is_active field only if status is empty
                _statuses = ACTIVE_USER_STATUSES if is_active else INACTIVE_USER_STATUSES
                query = query.where(users.c.status.in_(_statuses))
            if limit is not None:
                query = query.limit(limit)
            return [cls.from_row(context, row) async for row in conn.execute(query)]

    @staticmethod
    async def load_count(
        context, *,
        domain_name=None,
        group_id=None,
        is_active=None,
        status=None,
    ) -> int:
        async with context['dbpool'].acquire() as conn:
            if group_id is not None:
                from .group import association_groups_users as agus
                j = (users.join(agus, agus.c.user_id == users.c.uuid))
                query = (
                    sa.select([sa.func.count(users.c.uuid)])
                    .select_from(j)
                    .where(agus.c.group_id == group_id)
                    .as_scalar()
                )
            else:
                query = (
                    sa.select([sa.func.count(users.c.uuid)])
                    .select_from(users)
                    .as_scalar()
                )
            if domain_name is not None:
                query = query.where(users.c.domain_name == domain_name)
            if status is not None:
                query = query.where(users.c.status == UserStatus(status))
            elif is_active is not None:  # consider is_active field only if status is empty
                _statuses = ACTIVE_USER_STATUSES if is_active else INACTIVE_USER_STATUSES
                query = query.where(users.c.status.in_(_statuses))
            result = await conn.execute(query)
            count = await result.fetchone()
            return count[0]

    @classmethod
    async def load_slice(
        cls, context, limit, offset, *,
        domain_name=None,
        group_id=None,
        is_active=None,
        status=None,
        order_key=None,
        order_asc=True,
    ) -> Sequence[User]:
        async with context['dbpool'].acquire() as conn:
            if order_key is None:
                _ordering = sa.desc(users.c.created_at)
            else:
                _order_func = sa.asc if order_asc else sa.desc
                _ordering = _order_func(getattr(users.c, order_key))
            if group_id is not None:
                from .group import association_groups_users as agus
                j = (users.join(agus, agus.c.user_id == users.c.uuid))
                query = (
                    sa.select([users])
                    .select_from(j)
                    .where(agus.c.group_id == group_id)
                    .order_by(_ordering)
                    .limit(limit)
                    .offset(offset)
                )
            else:
                query = (
                    sa.select([users])
                    .select_from(users)
                    .order_by(_ordering)
                    .limit(limit)
                    .offset(offset)
                )
            if domain_name is not None:
                query = query.where(users.c.domain_name == domain_name)
            if status is not None:
                query = query.where(users.c.status == UserStatus(status))
            elif is_active is not None:  # consider is_active field only if status is empty
                _statuses = ACTIVE_USER_STATUSES if is_active else INACTIVE_USER_STATUSES
                query = query.where(users.c.status.in_(_statuses))
            return [
                cls.from_row(context, row) async for row in conn.execute(query)
            ]

    @classmethod
    async def batch_load_by_email(
        cls, context, emails=None, *,
        domain_name=None,
        is_active=None,
        status=None,
    ) -> Sequence[Optional[User]]:
        async with context['dbpool'].acquire() as conn:
            query = (
                sa.select([users])
                .select_from(users)
                .where(users.c.email.in_(emails))
            )
            if domain_name is not None:
                query = query.where(users.c.domain_name == domain_name)
            if status is not None:
                query = query.where(users.c.status == UserStatus(status))
            elif is_active is not None:  # consider is_active field only if status is empty
                _statuses = ACTIVE_USER_STATUSES if is_active else INACTIVE_USER_STATUSES
                query = query.where(users.c.status.in_(_statuses))
            return await batch_result(
                context, conn, query, cls,
                emails, lambda row: row['email'],
            )

    @classmethod
    async def batch_load_by_uuid(
        cls, context, user_ids=None, *,
        domain_name=None,
        is_active=None,
        status=None,
    ) -> Sequence[Optional[User]]:
        async with context['dbpool'].acquire() as conn:
            query = (
                sa.select([users])
                .select_from(users)
                .where(users.c.uuid.in_(user_ids))
            )
            if domain_name is not None:
                query = query.where(users.c.domain_name == domain_name)
            if status is not None:
                query = query.where(users.c.status == UserStatus(status))
            elif is_active is not None:  # consider is_active field only if status is empty
                _statuses = ACTIVE_USER_STATUSES if is_active else INACTIVE_USER_STATUSES
                query = query.where(users.c.status.in_(_statuses))
            return await batch_result(
                context, conn, query, cls,
                user_ids, lambda row: row['uuid'],
            )
class Agent(graphene.ObjectType):

    class Meta:
        interfaces = (Item, )

    id = graphene.ID()
    status = graphene.String()
    region = graphene.String()
    scaling_group = graphene.String()
    available_slots = graphene.JSONString()
    occupied_slots = graphene.JSONString()
    addr = graphene.String()
    first_contact = GQLDateTime()
    lost_at = GQLDateTime()
    live_stat = graphene.JSONString()
    version = graphene.String()
    compute_plugins = graphene.JSONString()

    # Legacy fields
    mem_slots = graphene.Int()
    cpu_slots = graphene.Float()
    gpu_slots = graphene.Float()
    tpu_slots = graphene.Float()
    used_mem_slots = graphene.Int()
    used_cpu_slots = graphene.Float()
    used_gpu_slots = graphene.Float()
    used_tpu_slots = graphene.Float()
    cpu_cur_pct = graphene.Float()
    mem_cur_bytes = graphene.Float()

    computations = graphene.List('ai.backend.manager.models.Computation',
                                 status=graphene.String())

    @classmethod
    def from_row(cls, context, row):
        if row is None:
            return None
        mega = 2 ** 20
        return cls(
            id=row['id'],
            status=row['status'].name,
            region=row['region'],
            scaling_group=row['scaling_group'],
            available_slots=row['available_slots'].to_json(),
            occupied_slots=row['occupied_slots'].to_json(),
            addr=row['addr'],
            first_contact=row['first_contact'],
            lost_at=row['lost_at'],
            version=row['version'],
            compute_plugins=row['compute_plugins'],
            # legacy fields
            mem_slots=BinarySize.from_str(row['available_slots']['mem']) // mega,
            cpu_slots=row['available_slots']['cpu'],
            gpu_slots=row['available_slots'].get('cuda.device', 0),
            tpu_slots=row['available_slots'].get('tpu.device', 0),
            used_mem_slots=BinarySize.from_str(
                row['occupied_slots'].get('mem', 0)) // mega,
            used_cpu_slots=float(row['occupied_slots'].get('cpu', 0)),
            used_gpu_slots=float(row['occupied_slots'].get('cuda.device', 0)),
            used_tpu_slots=float(row['occupied_slots'].get('tpu.device', 0)),
        )

    async def resolve_live_stat(self, info):
        rs = info.context['redis_stat']
        live_stat = await rs.get(str(self.id), encoding=None)
        if live_stat is not None:
            live_stat = msgpack.unpackb(live_stat)
        return live_stat

    async def resolve_cpu_cur_pct(self, info):
        rs = info.context['redis_stat']
        live_stat = await rs.get(str(self.id), encoding=None)
        if live_stat is not None:
            live_stat = msgpack.unpackb(live_stat)
            return float(live_stat['node']['cpu_util']['pct'])

    async def resolve_mem_cur_bytes(self, info):
        rs = info.context['redis_stat']
        live_stat = await rs.get(str(self.id), encoding=None)
        if live_stat is not None:
            live_stat = msgpack.unpackb(live_stat)
            return float(live_stat['node']['mem']['current'])

    async def resolve_computations(self, info, status=None):
        '''
        Retrieves all children worker sessions run by this agent.
        '''
        manager = info.context['dlmgr']
        loader = manager.get_loader('Computation.by_agent_id', status=status)
        return await loader.load(self.id)

    @staticmethod
    async def load_count(context, *, scaling_group=None, status=None):
        async with context['dbpool'].acquire() as conn:
            query = (sa.select([sa.func.count(agents.c.id)])
                     .select_from(agents)
                     .as_scalar())
            if scaling_group is not None:
                query = query.where(agents.c.scaling_group == scaling_group)
            if status is not None:
                status = AgentStatus[status]
                query = query.where(agents.c.status == status)
            result = await conn.execute(query)
            count = await result.fetchone()
            return count[0]

    @staticmethod
    async def load_slice(context, limit, offset, *,
                         scaling_group=None, status=None,
                         order_key=None, order_asc=True):
        async with context['dbpool'].acquire() as conn:
            # TODO: optimization for pagination using subquery, join
            if order_key is None:
                _ordering = agents.c.id
            else:
                _order_func = sa.asc if order_asc else sa.desc
                _ordering = _order_func(getattr(agents.c, order_key))
            query = (sa.select([agents])
                     .select_from(agents)
                     .order_by(_ordering)
                     .limit(limit)
                     .offset(offset))
            if scaling_group is not None:
                query = query.where(agents.c.scaling_group == scaling_group)
            if status is not None:
                status = AgentStatus[status]
                query = query.where(agents.c.status == status)
            result = await conn.execute(query)
            rows = await result.fetchall()
            _agents = []
            for r in rows:
                _agent = Agent.from_row(context, r)
                _agents.append(_agent)
            return _agents

    @staticmethod
    async def load_all(context, *, scaling_group=None, status=None):
        async with context['dbpool'].acquire() as conn:
            query = (sa.select([agents]).select_from(agents))
            if scaling_group is not None:
                query = query.where(agents.c.scaling_group == scaling_group)
            if status is not None:
                status = AgentStatus[status]
                query = query.where(agents.c.status == status)
            result = await conn.execute(query)
            rows = await result.fetchall()
            _agents = []
            for r in rows:
                _agent = Agent.from_row(context, r)
                _agents.append(_agent)
            return _agents

    @staticmethod
    async def batch_load(context, agent_ids, *, status=None):
        async with context['dbpool'].acquire() as conn:
            query = (sa.select([agents])
                     .select_from(agents)
                     .where(agents.c.id.in_(agent_ids))
                     .order_by(agents.c.id))
            if status is not None:
                status = AgentStatus[status]
                query = query.where(agents.c.status == status)
            objs_per_key = OrderedDict()
            for k in agent_ids:
                objs_per_key[k] = None
            async for row in conn.execute(query):
                o = Agent.from_row(context, row)
                objs_per_key[row.id] = o
        return tuple(objs_per_key.values())
class ComputeSession(graphene.ObjectType):

    class Meta:
        interfaces = (Item, )

    # identity
    tag = graphene.String()
    name = graphene.String()
    type = graphene.String()

    # image
    image = graphene.String()     # image for the master
    registry = graphene.String()  # image registry for the master
    cluster_template = graphene.String()

    # ownership
    domain_name = graphene.String()
    group_name = graphene.String()
    group_id = graphene.UUID()
    user_email = graphene.String()
    user_id = graphene.UUID()
    access_key = graphene.String()
    created_user_email = graphene.String()
    created_user_id = graphene.UUID()

    # status
    status = graphene.String()
    status_changed = GQLDateTime()
    status_info = graphene.String()
    created_at = GQLDateTime()
    terminated_at = GQLDateTime()
    starts_at = GQLDateTime()
    startup_command = graphene.String()
    result = graphene.String()

    # resources
    resource_opts = graphene.JSONString()
    scaling_group = graphene.String()
    service_ports = graphene.JSONString()
    mounts = graphene.List(lambda: graphene.String)
    occupied_slots = graphene.JSONString()

    # statistics
    num_queries = BigInt()

    # owned containers (aka kernels)
    containers = graphene.List(lambda: ComputeContainer)

    # relations
    dependencies = graphene.List(lambda: ComputeSession)

    @classmethod
    def parse_row(cls, context, row):
        assert row is not None
        return {
            # identity
            'id': row['id'],
            'tag': row['tag'],
            'name': row['sess_id'],
            'type': row['sess_type'].name,
            # image
            'image': row['image'],
            'registry': row['registry'],
            'cluster_template': None,  # TODO: implement
            # ownership
            'domain_name': row['domain_name'],
            'group_name': row['name'],  # group.name (group is omitted since use_labels=True is not used)
            'group_id': row['group_id'],
            'user_email': row['email'],
            'user_id': row['user_uuid'],
            'access_key': row['access_key'],
            'created_user_email': None,  # TODO: implement
            'created_user_id': None,     # TODO: implement
            # status
            'status': row['status'].name,
            'status_changed': row['status_changed'],
            'status_info': row['status_info'],
            'created_at': row['created_at'],
            'terminated_at': row['terminated_at'],
            'starts_at': row['starts_at'],
            'startup_command': row['startup_command'],
            'result': row['result'].name,
            # resources
            'resource_opts': row['resource_opts'],
            'scaling_group': row['scaling_group'],
            'service_ports': row['service_ports'],
            'mounts': row['mounts'],
            'occupied_slots': row['occupied_slots'].to_json(),  # TODO: sum of owned containers
            # statistics
            'num_queries': row['num_queries'],
        }

    @classmethod
    def from_row(cls, context: Mapping[str, Any], row: RowProxy) -> Optional[ComputeSession]:
        if row is None:
            return None
        props = cls.parse_row(context, row)
        return cls(**props)

    async def resolve_containers(
        self,
        info: graphene.ResolveInfo,
    ) -> Iterable[ComputeContainer]:
        manager = info.context['dlmgr']
        loader = manager.get_loader('ComputeContainer.by_session')
        return await loader.load(self.id)

    async def resolve_dependencies(
        self,
        info: graphene.ResolveInfo,
    ) -> Iterable[ComputeSession]:
        manager = info.context['dlmgr']
        loader = manager.get_loader('ComputeSession.by_dependency')
        return await loader.load(self.id)

    @classmethod
    async def load_count(cls, context, *,
                         domain_name=None, group_id=None, access_key=None,
                         status=None):
        if isinstance(status, str):
            status_list = [KernelStatus[s] for s in status.split(',')]
        elif isinstance(status, KernelStatus):
            status_list = [status]
        async with context['dbpool'].acquire() as conn:
            query = (
                sa.select([sa.func.count(kernels.c.id)])
                .select_from(kernels)
                .where(kernels.c.role == 'master')
                .as_scalar()
            )
            if domain_name is not None:
                query = query.where(kernels.c.domain_name == domain_name)
            if group_id is not None:
                query = query.where(kernels.c.group_id == group_id)
            if access_key is not None:
                query = query.where(kernels.c.access_key == access_key)
            if status is not None:
                query = query.where(kernels.c.status.in_(status_list))
            result = await conn.execute(query)
            count = await result.fetchone()
            return count[0]

    @classmethod
    async def load_slice(cls, context, limit, offset, *,
                         domain_name=None, group_id=None, access_key=None,
                         status=None,
                         order_key=None, order_asc=None):
        if isinstance(status, str):
            status_list = [KernelStatus[s] for s in status.split(',')]
        elif isinstance(status, KernelStatus):
            status_list = [status]
        async with context['dbpool'].acquire() as conn:
            if order_key is None:
                _ordering = DEFAULT_SESSION_ORDERING
            else:
                _order_func = sa.asc if order_asc else sa.desc
                _ordering = [_order_func(getattr(kernels.c, order_key))]
            j = (
                kernels
                .join(groups, groups.c.id == kernels.c.group_id)
                .join(users, users.c.uuid == kernels.c.user_uuid)
            )
            query = (
                sa.select([kernels, groups.c.name, users.c.email])
                .select_from(j)
                .where(kernels.c.role == 'master')
                .order_by(*_ordering)
                .limit(limit)
                .offset(offset)
            )
            if domain_name is not None:
                query = query.where(kernels.c.domain_name == domain_name)
            if group_id is not None:
                query = query.where(kernels.c.group_id == group_id)
            if access_key is not None:
                query = query.where(kernels.c.access_key == access_key)
            if status is not None:
                query = query.where(kernels.c.status.in_(status_list))
            return [cls.from_row(context, r) async for r in conn.execute(query)]

    @classmethod
    async def batch_load_by_dependency(cls, context, session_ids):
        async with context['dbpool'].acquire() as conn:
            j = sa.join(
                kernels, kernel_dependencies,
                kernels.c.id == kernel_dependencies.c.depends_on,
            )
            query = (
                sa.select([kernels])
                .select_from(j)
                .where(
                    (kernels.c.role == 'master') &
                    (kernel_dependencies.c.kernel_id.in_(session_ids))
                )
            )
            return await batch_multiresult(
                context, conn, query, cls,
                session_ids, lambda row: row['id'],
            )

    @classmethod
    async def batch_load_detail(cls, context, session_ids, *,
                                domain_name=None, access_key=None):
        async with context['dbpool'].acquire() as conn:
            j = (
                kernels
                .join(groups, groups.c.id == kernels.c.group_id)
                .join(users, users.c.uuid == kernels.c.user_uuid)
            )
            query = (
                sa.select([kernels, groups.c.name, users.c.email])
                .select_from(j)
                .where(
                    (kernels.c.role == 'master') &
                    (kernels.c.id.in_(session_ids))
                ))
            if domain_name is not None:
                query = query.where(kernels.c.domain_name == domain_name)
            if access_key is not None:
                query = query.where(kernels.c.access_key == access_key)
            return await batch_result(
                context, conn, query, cls,
                session_ids, lambda row: row['id'],
            )
class KeyPair(graphene.ObjectType):
    user_id = graphene.String()
    access_key = graphene.String()
    secret_key = graphene.String()
    is_active = graphene.Boolean()
    is_admin = graphene.Boolean()
    resource_policy = graphene.String()
    created_at = GQLDateTime()
    last_used = GQLDateTime()
    concurrency_limit = graphene.Int()
    concurrency_used = graphene.Int()
    rate_limit = graphene.Int()
    num_queries = graphene.Int()
    vfolders = graphene.List('ai.backend.manager.models.VirtualFolder')
    compute_sessions = graphene.List(
        'ai.backend.manager.models.ComputeSession',
        status=graphene.String(),
    )

    @classmethod
    def from_row(cls, row):
        if row is None:
            return None
        return cls(
            user_id=row['user_id'],
            access_key=row['access_key'],
            secret_key=row['secret_key'],
            is_active=row['is_active'],
            is_admin=row['is_admin'],
            resource_policy=row['resource_policy'],
            created_at=row['created_at'],
            last_used=row['last_used'],
            concurrency_limit=row['concurrency_limit'],
            concurrency_used=row['concurrency_used'],
            rate_limit=row['rate_limit'],
            num_queries=row['num_queries'],
        )

    async def resolve_vfolders(self, info):
        manager = info.context['dlmgr']
        loader = manager.get_loader('VirtualFolder')
        return await loader.load(self.access_key)

    async def resolve_compute_sessions(self, info, status=None):
        manager = info.context['dlmgr']
        from . import KernelStatus  # noqa: avoid circular imports
        if status is not None:
            status = KernelStatus[status]
        loader = manager.get_loader('ComputeSession', status=status)
        return await loader.load(self.access_key)

    @staticmethod
    async def batch_load_by_uid(dbpool, user_ids, *, is_active=None):
        async with dbpool.acquire() as conn:
            query = (sa.select('*')
                     .select_from(keypairs)
                     .where(keypairs.c.user_id.in_(user_ids)))
            if is_active is not None:
                query = query.where(keypairs.c.is_active == is_active)
            objs_per_key = OrderedDict()
            for k in user_ids:
                objs_per_key[k] = list()
            async for row in conn.execute(query):
                o = KeyPair.from_row(row)
                objs_per_key[row.user_id].append(o)
        return tuple(objs_per_key.values())

    @staticmethod
    async def batch_load_by_ak(dbpool, access_keys):
        async with dbpool.acquire() as conn:
            query = (sa.select('*')
                     .select_from(keypairs)
                     .where(keypairs.c.access_key.in_(access_keys)))
            objs_per_key = OrderedDict()
            # For each access key, there is only one keypair.
            # So we don't build lists in objs_per_key variable.
            for k in access_keys:
                objs_per_key[k] = None
            async for row in conn.execute(query):
                o = KeyPair.from_row(row)
                objs_per_key[row.access_key] = o
        return tuple(objs_per_key.values())
class SessionCommons:
    sess_id = graphene.String()
    id = graphene.UUID()
    role = graphene.String()
    lang = graphene.String()
    status = graphene.String()
    status_info = graphene.String()
    created_at = GQLDateTime()
    terminated_at = GQLDateTime()
    agent = graphene.String()
    container_id = graphene.String()
    mem_slot = graphene.Int()
    cpu_slot = graphene.Int()
    gpu_slot = graphene.Int()
    num_queries = graphene.Int()
    cpu_used = graphene.Int()
    mem_max_bytes = graphene.Int()
    mem_cur_bytes = graphene.Int()
    net_rx_bytes = graphene.Int()
    net_tx_bytes = graphene.Int()
    io_read_bytes = graphene.Int()
    io_write_bytes = graphene.Int()
    io_max_scratch_size = graphene.Int()
    io_cur_scratch_size = graphene.Int()

    async def resolve_cpu_used(self, info):
        if self.status not in LIVE_STATUS:
            return zero_if_none(self.cpu_used)
        async with info.context['redis_stat_pool'].get() as rs:
            ret = await rs.hget(str(self.id), 'cpu_used')
            return float(ret) if ret is not None else 0

    async def resolve_mem_max_bytes(self, info):
        if self.status not in LIVE_STATUS:
            return zero_if_none(self.mem_max_bytes)
        async with info.context['redis_stat_pool'].get() as rs:
            ret = await rs.hget(str(self.id), 'mem_max_bytes')
            return int(ret) if ret is not None else 0

    async def resolve_mem_cur_bytes(self, info):
        if self.status not in LIVE_STATUS:
            return 0
        async with info.context['redis_stat_pool'].get() as rs:
            ret = await rs.hget(str(self.id), 'mem_cur_bytes')
            return int(ret) if ret is not None else 0

    async def resolve_net_rx_bytes(self, info):
        if self.status not in LIVE_STATUS:
            return zero_if_none(self.net_rx_bytes)
        async with info.context['redis_stat_pool'].get() as rs:
            ret = await rs.hget(str(self.id), 'net_rx_bytes')
            return int(ret) if ret is not None else 0

    async def resolve_net_tx_bytes(self, info):
        if self.status not in LIVE_STATUS:
            return zero_if_none(self.net_tx_bytes)
        async with info.context['redis_stat_pool'].get() as rs:
            ret = await rs.hget(str(self.id), 'net_tx_bytes')
            return int(ret) if ret is not None else 0

    async def resolve_io_read_bytes(self, info):
        if self.status not in LIVE_STATUS:
            return zero_if_none(self.io_read_bytes)
        async with info.context['redis_stat_pool'].get() as rs:
            ret = await rs.hget(str(self.id), 'io_read_bytes')
            return int(ret) if ret is not None else 0

    async def resolve_io_write_bytes(self, info):
        if self.status not in LIVE_STATUS:
            return zero_if_none(self.io_write_bytes)
        async with info.context['redis_stat_pool'].get() as rs:
            ret = await rs.hget(str(self.id), 'io_write_bytes')
            return int(ret) if ret is not None else 0

    async def resolve_io_max_scratch_size(self, info):
        if self.status not in LIVE_STATUS:
            return zero_if_none(self.io_max_scratch_size)
        async with info.context['redis_stat_pool'].get() as rs:
            ret = await rs.hget(str(self.id), 'io_max_scratch_size')
            return int(ret) if ret is not None else 0

    async def resolve_io_cur_scratch_size(self, info):
        if self.status not in LIVE_STATUS:
            return 0
        async with info.context['redis_stat_pool'].get() as rs:
            ret = await rs.hget(str(self.id), 'io_cur_scratch_size')
            return int(ret) if ret is not None else 0

    @classmethod
    def from_row(cls, row):
        if row is None:
            return None
        props = {
            'sess_id': row.sess_id,
            'id': row.id,
            'role': row.role,
            'lang': row.lang,
            'status': row.status,
            'status_info': row.status_info,
            'created_at': row.created_at,
            'terminated_at': row.terminated_at,
            'agent': row.agent,
            'container_id': row.container_id,
            'mem_slot': row.mem_slot,
            'cpu_slot': row.cpu_slot,
            'gpu_slot': row.gpu_slot,
            'num_queries': row.num_queries,
            # live statistics
            # NOTE: currently graphene always uses resolve methods!
            'cpu_used': row.cpu_used,
            'mem_max_bytes': row.mem_max_bytes,
            'mem_cur_bytes': 0,
            'net_rx_bytes': row.net_rx_bytes,
            'net_tx_bytes': row.net_tx_bytes,
            'io_read_bytes': row.io_read_bytes,
            'io_write_bytes': row.io_write_bytes,
            'io_max_scratch_size': row.io_max_scratch_size,
            'io_cur_scratch_size': 0,
        }
        return cls(**props)
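# Side note (not from the original source): from_row() above deliberately puts
# placeholder values such as 0 into 'mem_cur_bytes'.  This is safe because
# graphene prefers an explicit resolve_<field>() method over the attribute set
# in the constructor, as the NOTE comment says.  A minimal, self-contained
# illustration of that precedence rule:
import graphene


class _Probe(graphene.ObjectType):
    mem_cur_bytes = graphene.Int()

    def resolve_mem_cur_bytes(self, info):
        # Called by graphene instead of returning the constructor-supplied value.
        return 42

# Querying `{ memCurBytes }` against _Probe(mem_cur_bytes=0) yields 42, not 0.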
class KeyPair(graphene.ObjectType):

    class Meta:
        interfaces = (Item, )

    user_id = graphene.String()
    access_key = graphene.String()
    secret_key = graphene.String()
    is_active = graphene.Boolean()
    is_admin = graphene.Boolean()
    resource_policy = graphene.String()
    created_at = GQLDateTime()
    last_used = GQLDateTime()
    concurrency_used = graphene.Int()
    rate_limit = graphene.Int()
    num_queries = graphene.Int()
    user = graphene.UUID()
    ssh_public_key = graphene.String()

    vfolders = graphene.List('ai.backend.manager.models.VirtualFolder')
    compute_sessions = graphene.List(
        'ai.backend.manager.models.ComputeSession',
        status=graphene.String(),
    )

    # Deprecated
    concurrency_limit = graphene.Int(
        deprecation_reason='Moved to KeyPairResourcePolicy object as '
                           'max_concurrent_sessions field.')

    @classmethod
    def from_row(
        cls,
        context: Mapping[str, Any],
        row: RowProxy,
    ) -> KeyPair:
        return cls(
            id=row['access_key'],
            user_id=row['user_id'],
            access_key=row['access_key'],
            secret_key=row['secret_key'],
            is_active=row['is_active'],
            is_admin=row['is_admin'],
            resource_policy=row['resource_policy'],
            created_at=row['created_at'],
            last_used=row['last_used'],
            concurrency_limit=0,  # moved to resource policy
            concurrency_used=row['concurrency_used'],
            rate_limit=row['rate_limit'],
            num_queries=row['num_queries'],
            user=row['user'],
            ssh_public_key=row['ssh_public_key'],
        )

    async def resolve_vfolders(self, info):
        manager = info.context['dlmgr']
        loader = manager.get_loader('VirtualFolder')
        return await loader.load(self.access_key)

    async def resolve_compute_sessions(self, info, status=None):
        manager = info.context['dlmgr']
        from . import KernelStatus  # noqa: avoid circular imports
        if status is not None:
            status = KernelStatus[status]
        loader = manager.get_loader('ComputeSession', status=status)
        return await loader.load(self.access_key)

    @classmethod
    async def load_all(
        cls, context, *,
        domain_name=None,
        is_active=None,
        limit=None,
    ) -> Sequence[KeyPair]:
        from .user import users
        async with context['dbpool'].acquire() as conn:
            j = sa.join(keypairs, users, keypairs.c.user == users.c.uuid)
            query = (sa.select([keypairs]).select_from(j))
            if domain_name is not None:
                query = query.where(users.c.domain_name == domain_name)
            if is_active is not None:
                query = query.where(keypairs.c.is_active == is_active)
            if limit is not None:
                query = query.limit(limit)
            return [
                cls.from_row(context, row) async for row in conn.execute(query)
            ]

    @staticmethod
    async def load_count(
        context, *,
        domain_name=None,
        email=None,
        is_active=None,
    ) -> int:
        from .user import users
        async with context['dbpool'].acquire() as conn:
            j = sa.join(keypairs, users, keypairs.c.user == users.c.uuid)
            query = (sa.select([sa.func.count(keypairs.c.access_key)])
                     .select_from(j)
                     .as_scalar())
            if domain_name is not None:
                query = query.where(users.c.domain_name == domain_name)
            if email is not None:
                query = query.where(keypairs.c.user_id == email)
            if is_active is not None:
                query = query.where(keypairs.c.is_active == is_active)
            result = await conn.execute(query)
            count = await result.fetchone()
            return count[0]

    @classmethod
    async def load_slice(
        cls, context, limit, offset, *,
        domain_name=None,
        email=None,
        is_active=None,
        order_key=None, order_asc=True,
    ) -> Sequence[KeyPair]:
        from .user import users
        async with context['dbpool'].acquire() as conn:
            if order_key is None:
                _ordering = sa.desc(keypairs.c.created_at)
            else:
                _order_func = sa.asc if order_asc else sa.desc
                _ordering = _order_func(getattr(keypairs.c, order_key))
            j = sa.join(keypairs, users, keypairs.c.user == users.c.uuid)
            query = (sa.select([keypairs])
                     .select_from(j)
                     .order_by(_ordering)
                     .limit(limit)
                     .offset(offset))
            if domain_name is not None:
                query = query.where(users.c.domain_name == domain_name)
            if email is not None:
                query = query.where(keypairs.c.user_id == email)
            if is_active is not None:
                query = query.where(keypairs.c.is_active == is_active)
            return [
                cls.from_row(context, row) async for row in conn.execute(query)
            ]

    @classmethod
    async def batch_load_by_email(
        cls, context, user_ids, *,
        domain_name=None,
        is_active=None,
    ) -> Sequence[Sequence[Optional[KeyPair]]]:
        from .user import users
        async with context['dbpool'].acquire() as conn:
            j = sa.join(keypairs, users, keypairs.c.user == users.c.uuid)
            query = (sa.select([keypairs])
                     .select_from(j)
                     .where(keypairs.c.user_id.in_(user_ids)))
            if domain_name is not None:
                query = query.where(users.c.domain_name == domain_name)
            if is_active is not None:
                query = query.where(keypairs.c.is_active == is_active)
            return await batch_multiresult(
                context, conn, query, cls,
                user_ids, lambda row: row['user_id'],
            )

    @classmethod
    async def batch_load_by_ak(
        cls, context, access_keys, *,
        domain_name=None,
    ) -> Sequence[Optional[KeyPair]]:
        async with context['dbpool'].acquire() as conn:
            from .user import users
            j = sa.join(keypairs, users, keypairs.c.user == users.c.uuid)
            query = (sa.select([keypairs])
                     .select_from(j)
                     .where(keypairs.c.access_key.in_(access_keys)))
            if domain_name is not None:
                query = query.where(users.c.domain_name == domain_name)
            return await batch_result(
                context, conn, query, cls,
                access_keys, lambda row: row['access_key'],
            )
class User(graphene.ObjectType):
    uuid = graphene.UUID()
    username = graphene.String()
    email = graphene.String()
    password = graphene.String()
    need_password_change = graphene.Boolean()
    full_name = graphene.String()
    description = graphene.String()
    is_active = graphene.Boolean()
    created_at = GQLDateTime()
    domain_name = graphene.String()
    role = graphene.String()
    # Dynamic fields
    groups = graphene.List(lambda: UserGroup)

    @classmethod
    def from_row(cls, row):
        if row is None:
            return None
        if 'id' in row and row.id is not None and 'name' in row and row.name is not None:
            groups = [UserGroup(id=row['id'], name=row['name'])]
        else:
            groups = None
        return cls(
            uuid=row['uuid'],
            username=row['username'],
            email=row['email'],
            need_password_change=row['need_password_change'],
            full_name=row['full_name'],
            description=row['description'],
            is_active=row['is_active'],
            created_at=row['created_at'],
            domain_name=row['domain_name'],
            role=row['role'],
            # Dynamic fields
            groups=groups,  # group information
        )

    @staticmethod
    async def load_all(context, *, domain_name=None, group_id=None, is_active=None):
        '''
        Load user's information. Group names associated with the user are also returned.
        '''
        async with context['dbpool'].acquire() as conn:
            from .group import groups, association_groups_users as agus
            j = (users.join(agus, agus.c.user_id == users.c.uuid, isouter=True)
                      .join(groups, agus.c.group_id == groups.c.id, isouter=True))
            query = sa.select([users, groups.c.name, groups.c.id]).select_from(j)
            if context['user']['role'] != UserRole.SUPERADMIN:
                query = query.where(
                    users.c.domain_name == context['user']['domain_name'])
            if domain_name is not None:
                query = query.where(users.c.domain_name == domain_name)
            if group_id is not None:
                query = query.where(groups.c.id == group_id)
            if is_active is not None:
                query = query.where(users.c.is_active == is_active)
            objs_per_key = OrderedDict()
            async for row in conn.execute(query):
                if row.email in objs_per_key:
                    # If same user is already saved, just append group information.
                    objs_per_key[row.email].groups.append(
                        UserGroup(id=row.id, name=row.name))
                    continue
                o = User.from_row(row)
                objs_per_key[row.email] = o
            objs = list(objs_per_key.values())
        return objs

    @staticmethod
    async def batch_load_by_email(context, emails=None, *,
                                  domain_name=None, is_active=None):
        async with context['dbpool'].acquire() as conn:
            from .group import groups, association_groups_users as agus
            j = (users.join(agus, agus.c.user_id == users.c.uuid, isouter=True)
                      .join(groups, agus.c.group_id == groups.c.id, isouter=True))
            query = (sa.select([users, groups.c.name, groups.c.id])
                     .select_from(j)
                     .where(users.c.email.in_(emails)))
            if domain_name is not None:
                query = query.where(users.c.domain_name == domain_name)
            objs_per_key = OrderedDict()
            # For each email, there is only one user.
            # So we don't build lists in objs_per_key variable.
            for k in emails:
                objs_per_key[k] = None
            async for row in conn.execute(query):
                key = row.email
                if objs_per_key[key] is not None:
                    objs_per_key[key].groups.append(
                        UserGroup(id=row.id, name=row.name))
                    continue
                o = User.from_row(row)
                objs_per_key[key] = o
        return tuple(objs_per_key.values())

    @staticmethod
    async def batch_load_by_uuid(context, user_ids=None, *,
                                 domain_name=None, is_active=None):
        async with context['dbpool'].acquire() as conn:
            from .group import groups, association_groups_users as agus
            j = (users.join(agus, agus.c.user_id == users.c.uuid, isouter=True)
                      .join(groups, agus.c.group_id == groups.c.id, isouter=True))
            query = (sa.select([users, groups.c.name, groups.c.id])
                     .select_from(j)
                     .where(users.c.uuid.in_(user_ids)))
            if domain_name is not None:
                query = query.where(users.c.domain_name == domain_name)
            objs_per_key = OrderedDict()
            # For each uuid, there is only one user.
            # So we don't build lists in objs_per_key variable.
            for k in user_ids:
                objs_per_key[k] = None
            async for row in conn.execute(query):
                key = str(row.uuid)
                if objs_per_key[key] is not None:
                    objs_per_key[key].groups.append(
                        UserGroup(id=row.id, name=row.name))
                    continue
                o = User.from_row(row)
                objs_per_key[key] = o
        return tuple(objs_per_key.values())
class Group(graphene.ObjectType):
    id = graphene.UUID()
    name = graphene.String()
    description = graphene.String()
    is_active = graphene.Boolean()
    created_at = GQLDateTime()
    modified_at = GQLDateTime()
    domain_name = graphene.String()
    total_resource_slots = graphene.JSONString()
    allowed_vfolder_hosts = graphene.List(lambda: graphene.String)
    integration_id = graphene.String()

    @classmethod
    def from_row(cls, row):
        if row is None:
            return None
        return cls(
            id=row['id'],
            name=row['name'],
            description=row['description'],
            is_active=row['is_active'],
            created_at=row['created_at'],
            modified_at=row['modified_at'],
            domain_name=row['domain_name'],
            total_resource_slots=row['total_resource_slots'].to_json(),
            allowed_vfolder_hosts=row['allowed_vfolder_hosts'],
            integration_id=row['integration_id'],
        )

    @staticmethod
    async def load_all(context, *, domain_name=None, is_active=None):
        async with context['dbpool'].acquire() as conn:
            query = (sa.select([groups]).select_from(groups))
            if domain_name is not None:
                query = query.where(groups.c.domain_name == domain_name)
            if is_active is not None:
                query = query.where(groups.c.is_active == is_active)
            objs = []
            async for row in conn.execute(query):
                o = Group.from_row(row)
                objs.append(o)
        return objs

    @staticmethod
    async def batch_load_by_id(context, ids, *, domain_name=None):
        async with context['dbpool'].acquire() as conn:
            query = (sa.select([groups])
                     .select_from(groups)
                     .where(groups.c.id.in_(ids)))
            if domain_name is not None:
                query = query.where(groups.c.domain_name == domain_name)
            objs_per_key = OrderedDict()
            for k in ids:
                objs_per_key[k] = None
            async for row in conn.execute(query):
                o = Group.from_row(row)
                objs_per_key[str(row.id)] = o
        return [*objs_per_key.values()]

    @staticmethod
    async def get_groups_for_user(context, user_id):
        async with context['dbpool'].acquire() as conn:
            j = sa.join(groups, association_groups_users,
                        groups.c.id == association_groups_users.c.group_id)
            query = (sa.select([groups])
                     .select_from(j)
                     .where(association_groups_users.c.user_id == user_id))
            objs = []
            async for row in conn.execute(query):
                o = Group.from_row(row)
                objs.append(o)
        return objs
class Domain(graphene.ObjectType):
    name = graphene.String()
    description = graphene.String()
    is_active = graphene.Boolean()
    created_at = GQLDateTime()
    modified_at = GQLDateTime()
    total_resource_slots = graphene.JSONString()
    allowed_vfolder_hosts = graphene.List(lambda: graphene.String)
    allowed_docker_registries = graphene.List(lambda: graphene.String)
    integration_id = graphene.String()
    # Dynamic fields.
    scaling_groups = graphene.List(lambda: graphene.String)

    @classmethod
    def from_row(cls, row):
        if row is None:
            return None
        return cls(
            name=row['name'],
            description=row['description'],
            is_active=row['is_active'],
            created_at=row['created_at'],
            modified_at=row['modified_at'],
            total_resource_slots=row['total_resource_slots'].to_json(),
            allowed_vfolder_hosts=row['allowed_vfolder_hosts'],
            allowed_docker_registries=row['allowed_docker_registries'],
            integration_id=row['integration_id'],
            # Dynamic fields.
            scaling_groups=[row.scaling_group] if 'scaling_group' in row else [],
        )

    @staticmethod
    async def load_all(context, *, is_active=None):
        async with context['dbpool'].acquire() as conn:
            j = sa.join(domains, sgroups_for_domains,
                        domains.c.name == sgroups_for_domains.c.domain,
                        isouter=True)
            query = (sa.select([domains, sgroups_for_domains.c.scaling_group])
                     .select_from(j))
            if is_active is not None:
                query = query.where(domains.c.is_active == is_active)
            objs_per_key = OrderedDict()
            async for row in conn.execute(query):
                if row.name in objs_per_key:
                    # If same domain is already saved, just append sgroup information.
                    objs_per_key[row.name].scaling_groups.append(row.scaling_group)
                    continue
                o = Domain.from_row(row)
                objs_per_key[row.name] = o
            objs = list(objs_per_key.values())
        return objs

    @staticmethod
    async def batch_load_by_name(context, names=None, *, is_active=None):
        async with context['dbpool'].acquire() as conn:
            j = sa.join(domains, sgroups_for_domains,
                        domains.c.name == sgroups_for_domains.c.domain,
                        isouter=True)
            query = (sa.select([domains, sgroups_for_domains.c.scaling_group])
                     .select_from(j)
                     .where(domains.c.name.in_(names)))
            objs_per_key = OrderedDict()
            # For each name, there is only one domain.
            # So we don't build lists in objs_per_key variable.
            for k in names:
                objs_per_key[k] = None
            async for row in conn.execute(query):
                if objs_per_key[row.name] is not None:
                    objs_per_key[row.name].scaling_groups.append(row.scaling_group)
                    continue
                o = Domain.from_row(row)
                objs_per_key[row.name] = o
        return tuple(objs_per_key.values())
class SessionCommons:
    sess_id = graphene.String()
    id = graphene.ID()
    role = graphene.String()
    image = graphene.String()
    registry = graphene.String()
    domain_name = graphene.String()
    group_name = graphene.String()
    group_id = graphene.UUID()
    scaling_group = graphene.String()
    user_uuid = graphene.UUID()
    access_key = graphene.String()
    status = graphene.String()
    status_info = graphene.String()
    created_at = GQLDateTime()
    terminated_at = GQLDateTime()
    # hidable fields by configuration
    agent = graphene.String()
    container_id = graphene.String()
    service_ports = graphene.JSONString()
    occupied_slots = graphene.JSONString()
    occupied_shares = graphene.JSONString()
    mounts = graphene.List(lambda: graphene.List(lambda: graphene.String))
    num_queries = BigInt()
    live_stat = graphene.JSONString()
    last_stat = graphene.JSONString()
    user_email = graphene.String()
    # Legacy fields
    lang = graphene.String()
    mem_slot = graphene.Int()
    cpu_slot = graphene.Float()
    gpu_slot = graphene.Float()
    tpu_slot = graphene.Float()
    cpu_used = BigInt()
    cpu_using = graphene.Float()
    mem_max_bytes = BigInt()
    mem_cur_bytes = BigInt()
    net_rx_bytes = BigInt()
    net_tx_bytes = BigInt()
    io_read_bytes = BigInt()
    io_write_bytes = BigInt()
    io_max_scratch_size = BigInt()
    io_cur_scratch_size = BigInt()

    @classmethod
    async def _resolve_live_stat(cls, redis_stat, kernel_id):
        cstat = await redis_stat.get(kernel_id, encoding=None)
        if cstat is not None:
            cstat = msgpack.unpackb(cstat)
        return cstat

    async def resolve_live_stat(self, info):
        rs = info.context['redis_stat']
        return await type(self)._resolve_live_stat(rs, str(self.id))

    async def _resolve_legacy_metric(self, info, metric_key, metric_field, convert_type):
        if not hasattr(self, 'status'):
            return None
        rs = info.context['redis_stat']
        if self.status not in LIVE_STATUS:
            if self.last_stat is None:
                return convert_type(0)
            metric = self.last_stat.get(metric_key)
            if metric is None:
                return convert_type(0)
            value = metric.get(metric_field)
            if value is None:
                return convert_type(0)
            return convert_type(value)
        else:
            kstat = await type(self)._resolve_live_stat(rs, str(self.id))
            if kstat is None:
                return convert_type(0)
            metric = kstat.get(metric_key)
            if metric is None:
                return convert_type(0)
            value = metric.get(metric_field)
            if value is None:
                return convert_type(0)
            return convert_type(value)

    async def resolve_cpu_used(self, info):
        return await self._resolve_legacy_metric(info, 'cpu_used', 'current', float)

    async def resolve_cpu_using(self, info):
        return await self._resolve_legacy_metric(info, 'cpu_util', 'pct', float)

    async def resolve_mem_max_bytes(self, info):
        return await self._resolve_legacy_metric(info, 'mem', 'stats.max', int)

    async def resolve_mem_cur_bytes(self, info):
        return await self._resolve_legacy_metric(info, 'mem', 'current', int)

    async def resolve_net_rx_bytes(self, info):
        return await self._resolve_legacy_metric(info, 'net_rx', 'stats.rate', int)

    async def resolve_net_tx_bytes(self, info):
        return await self._resolve_legacy_metric(info, 'net_tx', 'stats.rate', int)

    async def resolve_io_read_bytes(self, info):
        return await self._resolve_legacy_metric(info, 'io_read', 'current', int)

    async def resolve_io_write_bytes(self, info):
        return await self._resolve_legacy_metric(info, 'io_write', 'current', int)

    async def resolve_io_max_scratch_size(self, info):
        return await self._resolve_legacy_metric(info, 'io_scratch_size', 'stats.max', int)

    async def resolve_io_cur_scratch_size(self, info):
        return await self._resolve_legacy_metric(info, 'io_scratch_size', 'current', int)

    @classmethod
    def parse_row(cls, context, row):
        assert row is not None
        from .user import UserRole
        mega = 2 ** 20
        is_superadmin = (context['user']['role'] == UserRole.SUPERADMIN)
        if is_superadmin:
            hide_agents = False
        else:
            hide_agents = context['config']['manager']['hide-agents']
        return {
            'sess_id': row['sess_id'],
            'id': row['id'],
            'role': row['role'],
            'image': row['image'],
            'registry': row['registry'],
            'domain_name': row['domain_name'],
            'group_name': row['name'],  # group.name (group is omitted since use_labels=True is not used)
            'group_id': row['group_id'],
            'scaling_group': row['scaling_group'],
            'user_uuid': row['user_uuid'],
            'access_key': row['access_key'],
            'status': row['status'].name,
            'status_info': row['status_info'],
            'created_at': row['created_at'],
            'terminated_at': row['terminated_at'],
            'service_ports': row['service_ports'],
            'occupied_slots': row['occupied_slots'].to_json(),
            'occupied_shares': row['occupied_shares'],
            'mounts': row['mounts'],
            'num_queries': row['num_queries'],
            # optionally hidden
            'agent': row['agent'] if not hide_agents else None,
            'container_id': row['container_id'] if not hide_agents else None,
            # live_stat is resolved by Graphene
            'last_stat': row['last_stat'],
            'user_email': row['email'],
            # Legacy fields
            # NOTE: currently graphene always uses resolve methods!
            'cpu_used': 0,
            'mem_max_bytes': 0,
            'mem_cur_bytes': 0,
            'net_rx_bytes': 0,
            'net_tx_bytes': 0,
            'io_read_bytes': 0,
            'io_write_bytes': 0,
            'io_max_scratch_size': 0,
            'io_cur_scratch_size': 0,
            'lang': row['image'],
            'mem_slot': BinarySize.from_str(
                row['occupied_slots'].get('mem', 0)) // mega,
            'cpu_slot': float(row['occupied_slots'].get('cpu', 0)),
            'gpu_slot': float(row['occupied_slots'].get('cuda.device', 0)),
            'tpu_slot': float(row['occupied_slots'].get('tpu.device', 0)),
        }

    @classmethod
    def from_row(cls, context, row):
        if row is None:
            return None
        props = cls.parse_row(context, row)
        return cls(**props)
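# Illustration only (not part of the original source): _resolve_legacy_metric()
# expects the last_stat / live_stat payload to be a two-level mapping of
# metric key -> metric field -> value.  The shape below is inferred solely from
# the keys and fields referenced by the resolvers above; real payloads may
# differ and usually carry more data.
_example_kernel_stat = {
    'cpu_used': {'current': '12345'},
    'cpu_util': {'pct': '37.5'},
    'mem': {'current': '104857600', 'stats.max': '209715200'},
    'net_rx': {'stats.rate': '1024'},
    'net_tx': {'stats.rate': '2048'},
    'io_read': {'current': '4096'},
    'io_write': {'current': '8192'},
    'io_scratch_size': {'current': '0', 'stats.max': '0'},
}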
class ScalingGroup(graphene.ObjectType):
    name = graphene.String()
    description = graphene.String()
    is_active = graphene.Boolean()
    created_at = GQLDateTime()
    driver = graphene.String()
    driver_opts = graphene.JSONString()
    scheduler = graphene.String()
    scheduler_opts = graphene.JSONString()

    @classmethod
    def from_row(cls, context, row):
        if row is None:
            return None
        return cls(
            name=row['name'],
            description=row['description'],
            is_active=row['is_active'],
            created_at=row['created_at'],
            driver=row['driver'],
            driver_opts=row['driver_opts'],
            scheduler=row['scheduler'],
            scheduler_opts=row['scheduler_opts'],
        )

    @classmethod
    async def load_all(cls, context, *, is_active=None):
        async with context['dbpool'].acquire() as conn:
            query = sa.select([scaling_groups]).select_from(scaling_groups)
            if is_active is not None:
                query = query.where(scaling_groups.c.is_active == is_active)
            return [cls.from_row(context, row) async for row in conn.execute(query)]

    @classmethod
    async def load_by_domain(cls, context, domain, *, is_active=None):
        async with context['dbpool'].acquire() as conn:
            j = sa.join(
                scaling_groups, sgroups_for_domains,
                scaling_groups.c.name == sgroups_for_domains.c.scaling_group)
            query = (
                sa.select([scaling_groups])
                .select_from(j)
                .where(sgroups_for_domains.c.domain == domain)
            )
            if is_active is not None:
                query = query.where(scaling_groups.c.is_active == is_active)
            return [cls.from_row(context, row) async for row in conn.execute(query)]

    @classmethod
    async def load_by_group(cls, context, group, *, is_active=None):
        async with context['dbpool'].acquire() as conn:
            j = sa.join(
                scaling_groups, sgroups_for_groups,
                scaling_groups.c.name == sgroups_for_groups.c.scaling_group)
            query = (
                sa.select([scaling_groups])
                .select_from(j)
                .where(sgroups_for_groups.c.group == group)
            )
            if is_active is not None:
                query = query.where(scaling_groups.c.is_active == is_active)
            return [cls.from_row(context, row) async for row in conn.execute(query)]

    @classmethod
    async def load_by_keypair(cls, context, access_key, *, is_active=None):
        async with context['dbpool'].acquire() as conn:
            j = sa.join(
                scaling_groups, sgroups_for_keypairs,
                scaling_groups.c.name == sgroups_for_keypairs.c.scaling_group)
            query = (
                sa.select([scaling_groups])
                .select_from(j)
                .where(sgroups_for_keypairs.c.access_key == access_key)
            )
            if is_active is not None:
                query = query.where(scaling_groups.c.is_active == is_active)
            return [cls.from_row(context, row) async for row in conn.execute(query)]

    @classmethod
    async def batch_load_by_name(cls, context, names):
        async with context['dbpool'].acquire() as conn:
            query = (sa.select([scaling_groups])
                     .select_from(scaling_groups)
                     .where(scaling_groups.c.name.in_(names)))
            return await batch_result(
                context, conn, query, cls,
                names, lambda row: row['name'],
            )
class SessionCommons: sess_id = graphene.String() id = graphene.UUID() role = graphene.String() lang = graphene.String() status = graphene.String() status_info = graphene.String() created_at = GQLDateTime() terminated_at = GQLDateTime() agent = graphene.String() container_id = graphene.String() mem_slot = graphene.Int() cpu_slot = graphene.Float() gpu_slot = graphene.Float() num_queries = graphene.Int() cpu_used = graphene.Int() mem_max_bytes = graphene.Int() mem_cur_bytes = graphene.Int() net_rx_bytes = graphene.Int() net_tx_bytes = graphene.Int() io_read_bytes = graphene.Int() io_write_bytes = graphene.Int() io_max_scratch_size = graphene.Int() io_cur_scratch_size = graphene.Int() async def resolve_cpu_used(self, info): if not hasattr(self, 'status'): return None if self.status not in LIVE_STATUS: return zero_if_none(self.cpu_used) rs = info.context['redis_stat'] ret = await rs.hget(str(self.id), 'cpu_used') return float(ret) if ret is not None else 0 async def resolve_mem_max_bytes(self, info): if not hasattr(self, 'status'): return None if self.status not in LIVE_STATUS: return zero_if_none(self.mem_max_bytes) rs = info.context['redis_stat'] ret = await rs.hget(str(self.id), 'mem_max_bytes') return int(ret) if ret is not None else 0 async def resolve_mem_cur_bytes(self, info): if not hasattr(self, 'status'): return None if self.status not in LIVE_STATUS: return 0 rs = info.context['redis_stat'] ret = await rs.hget(str(self.id), 'mem_cur_bytes') return int(ret) if ret is not None else 0 async def resolve_net_rx_bytes(self, info): if not hasattr(self, 'status'): return None if self.status not in LIVE_STATUS: return zero_if_none(self.net_rx_bytes) rs = info.context['redis_stat'] ret = await rs.hget(str(self.id), 'net_rx_bytes') return int(ret) if ret is not None else 0 async def resolve_net_tx_bytes(self, info): if not hasattr(self, 'status'): return None if self.status not in LIVE_STATUS: return zero_if_none(self.net_tx_bytes) rs = info.context['redis_stat'] ret = await rs.hget(str(self.id), 'net_tx_bytes') return int(ret) if ret is not None else 0 async def resolve_io_read_bytes(self, info): if not hasattr(self, 'status'): return None if self.status not in LIVE_STATUS: return zero_if_none(self.io_read_bytes) rs = info.context['redis_stat'] ret = await rs.hget(str(self.id), 'io_read_bytes') return int(ret) if ret is not None else 0 async def resolve_io_write_bytes(self, info): if not hasattr(self, 'status'): return None if self.status not in LIVE_STATUS: return zero_if_none(self.io_write_bytes) rs = info.context['redis_stat'] ret = await rs.hget(str(self.id), 'io_write_bytes') return int(ret) if ret is not None else 0 async def resolve_io_max_scratch_size(self, info): if not hasattr(self, 'status'): return None if self.status not in LIVE_STATUS: return zero_if_none(self.io_max_scratch_size) rs = info.context['redis_stat'] ret = await rs.hget(str(self.id), 'io_max_scratch_size') return int(ret) if ret is not None else 0 async def resolve_io_cur_scratch_size(self, info): if not hasattr(self, 'status'): return None if self.status not in LIVE_STATUS: return 0 rs = info.context['redis_stat'] ret = await rs.hget(str(self.id), 'io_cur_scratch_size') return int(ret) if ret is not None else 0 @classmethod def parse_row(cls, row): assert row is not None return { 'sess_id': row['sess_id'], 'id': row['id'], 'role': row['role'], 'lang': row['lang'], 'status': row['status'], 'status_info': row['status_info'], 'created_at': row['created_at'], 'terminated_at': row['terminated_at'], 'agent': 
row['agent'], 'container_id': row['container_id'], 'mem_slot': row['mem_slot'], 'cpu_slot': row['cpu_slot'], 'gpu_slot': row['gpu_slot'], 'num_queries': row['num_queries'], # live statistics # NOTE: currently graphene always uses resolve methods! 'cpu_used': row['cpu_used'], 'mem_max_bytes': row['mem_max_bytes'], 'mem_cur_bytes': 0, 'net_rx_bytes': row['net_rx_bytes'], 'net_tx_bytes': row['net_tx_bytes'], 'io_read_bytes': row['io_read_bytes'], 'io_write_bytes': row['io_write_bytes'], 'io_max_scratch_size': row['io_max_scratch_size'], 'io_cur_scratch_size': 0, } @classmethod def from_row(cls, row): if row is None: return None props = cls.parse_row(row) return cls(**props)
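# Illustrative sketch (not part of the original module): every per-field
# resolver above repeats the same gate -- return the stored column value for
# non-live sessions, otherwise read the latest figure from the Redis stat
# hash.  A hypothetical shared helper capturing that pattern could look like
# the following; ``converter`` is an assumption, while ``LIVE_STATUS``,
# ``zero_if_none`` and the ``hget`` call mirror the code above.
async def _example_resolve_live_field(obj, info, field, converter=int):
    if not hasattr(obj, 'status'):
        return None
    if obj.status not in LIVE_STATUS:
        return zero_if_none(getattr(obj, field))
    rs = info.context['redis_stat']
    ret = await rs.hget(str(obj.id), field)
    return converter(ret) if ret is not None else converter(0)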
class KeyPairResourcePolicy(graphene.ObjectType): name = graphene.String() created_at = GQLDateTime() default_for_unspecified = graphene.String() total_resource_slots = graphene.JSONString() max_concurrent_sessions = graphene.Int() max_containers_per_session = graphene.Int() idle_timeout = BigInt() max_vfolder_count = graphene.Int() max_vfolder_size = BigInt() allowed_vfolder_hosts = graphene.List(lambda: graphene.String) @classmethod def from_row(cls, context, row): if row is None: return None return cls( name=row['name'], created_at=row['created_at'], default_for_unspecified=row['default_for_unspecified'].name, total_resource_slots=row['total_resource_slots'].to_json(), max_concurrent_sessions=row['max_concurrent_sessions'], max_containers_per_session=row['max_containers_per_session'], idle_timeout=row['idle_timeout'], max_vfolder_count=row['max_vfolder_count'], max_vfolder_size=row['max_vfolder_size'], allowed_vfolder_hosts=row['allowed_vfolder_hosts'], ) @classmethod async def load_all(cls, context): async with context['dbpool'].acquire() as conn: query = (sa.select([keypair_resource_policies ]).select_from(keypair_resource_policies)) result = await conn.execute(query) rows = await result.fetchall() return [cls.from_row(context, r) for r in rows] @classmethod async def load_all_user(cls, context, access_key): async with context['dbpool'].acquire() as conn: query = (sa.select([keypairs.c.user_id ]).select_from(keypairs).where( keypairs.c.access_key == access_key)) result = await conn.execute(query) row = await result.fetchone() user_id = row['user_id'] j = sa.join( keypairs, keypair_resource_policies, keypairs.c.resource_policy == keypair_resource_policies.c.name) query = (sa.select([keypair_resource_policies ]).select_from(j).where( (keypairs.c.user_id == user_id))) result = await conn.execute(query) rows = await result.fetchall() return [cls.from_row(context, r) for r in rows] @classmethod async def batch_load_by_name(cls, context, names): async with context['dbpool'].acquire() as conn: query = (sa.select([ keypair_resource_policies ]).select_from(keypair_resource_policies).where( keypair_resource_policies.c.name.in_(names)).order_by( keypair_resource_policies.c.name)) objs_per_key = OrderedDict() for k in names: objs_per_key[k] = None async for row in conn.execute(query): o = cls.from_row(context, row) objs_per_key[row.name] = o return tuple(objs_per_key.values()) @classmethod async def batch_load_by_name_user(cls, context, names): async with context['dbpool'].acquire() as conn: access_key = context['access_key'] j = sa.join( keypairs, keypair_resource_policies, keypairs.c.resource_policy == keypair_resource_policies.c.name) query = (sa.select([keypair_resource_policies]).select_from( j).where((keypair_resource_policies.c.name.in_(names)) & (keypairs.c.access_key == access_key)).order_by( keypair_resource_policies.c.name)) objs_per_key = OrderedDict() for k in names: objs_per_key[k] = None async for row in conn.execute(query): o = cls.from_row(context, row) objs_per_key[row.name] = o return tuple(objs_per_key.values()) @classmethod async def batch_load_by_ak(cls, context, access_keys): async with context['dbpool'].acquire() as conn: j = sa.join( keypairs, keypair_resource_policies, keypairs.c.resource_policy == keypair_resource_policies.c.name) query = (sa.select( [keypair_resource_policies]).select_from(j).where( (keypairs.c.access_key.in_(access_keys))).order_by( keypair_resource_policies.c.name)) objs_per_key = OrderedDict() async for row in conn.execute(query): o = 
cls.from_row(context, row) objs_per_key[row.name] = o return tuple(objs_per_key.values())
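# Usage sketch (illustrative, not part of the original module):
# ``batch_load_by_name`` above is shaped for a DataLoader -- the result tuple
# has one slot per requested name, in request order, with ``None`` for names
# that do not exist.  The wrapper below is hypothetical.
async def _example_policies_by_name(context, names):
    policies = await KeyPairResourcePolicy.batch_load_by_name(context, names)
    return dict(zip(names, policies))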
class VirtualFolder(graphene.ObjectType): class Meta: interfaces = (Item, ) host = graphene.String() name = graphene.String() user = graphene.UUID() # User.id group = graphene.UUID() # Group.id creator = graphene.String() # User.email unmanaged_path = graphene.String() usage_mode = graphene.String() permission = graphene.String() ownership_type = graphene.String() max_files = graphene.Int() max_size = graphene.Int() created_at = GQLDateTime() last_used = GQLDateTime() num_files = graphene.Int() cur_size = BigInt() # num_attached = graphene.Int() cloneable = graphene.Boolean() @classmethod def from_row(cls, context, row): if row is None: return None return cls( id=row['id'], host=row['host'], name=row['name'], user=row['user'], group=row['group'], creator=row['creator'], unmanaged_path=row['unmanaged_path'], usage_mode=row['usage_mode'], permission=row['permission'], ownership_type=row['ownership_type'], max_files=row['max_files'], max_size=row['max_size'], # in KiB created_at=row['created_at'], last_used=row['last_used'], # num_attached=row['num_attached'], cloneable=row['cloneable'], ) async def resolve_num_files(self, info): # TODO: measure on-the-fly return 0 async def resolve_cur_size(self, info): # TODO: measure on-the-fly return 0 @classmethod async def load_count(cls, context, *, domain_name=None, group_id=None, user_id=None): from .user import users async with context['dbpool'].acquire() as conn: j = sa.join(vfolders, users, vfolders.c.user == users.c.uuid) query = ( sa.select([sa.func.count(vfolders.c.id)]) .select_from(j) .as_scalar() ) if domain_name is not None: query = query.where(users.c.domain_name == domain_name) if group_id is not None: query = query.where(vfolders.c.group == group_id) if user_id is not None: query = query.where(vfolders.c.user == user_id) result = await conn.execute(query) return await result.scalar() @classmethod async def load_slice(cls, context, limit, offset, *, domain_name=None, group_id=None, user_id=None, order_key=None, order_asc=None): from .user import users async with context['dbpool'].acquire() as conn: if order_key is None: _ordering = vfolders.c.created_at else: _order_func = sa.asc if order_asc else sa.desc _ordering = _order_func(getattr(vfolders.c, order_key)) j = sa.join(vfolders, users, vfolders.c.user == users.c.uuid) query = ( sa.select([vfolders]) .select_from(j) .order_by(_ordering) .limit(limit) .offset(offset) ) if domain_name is not None: query = query.where(users.c.domain_name == domain_name) if group_id is not None: query = query.where(vfolders.c.group == group_id) if user_id is not None: query = query.where(vfolders.c.user == user_id) return [cls.from_row(context, r) async for r in conn.execute(query)] @classmethod async def batch_load_by_user(cls, context, user_uuids, *, domain_name=None, group_id=None): from .user import users async with context['dbpool'].acquire() as conn: # TODO: num_attached count group-by j = sa.join(vfolders, users, vfolders.c.user == users.c.uuid) query = ( sa.select([vfolders]) .select_from(j) .where(vfolders.c.user.in_(user_uuids)) .order_by(sa.desc(vfolders.c.created_at)) ) if domain_name is not None: query = query.where(users.c.domain_name == domain_name) if group_id is not None: query = query.where(vfolders.c.group == group_id) return await batch_multiresult( context, conn, query, cls, user_uuids, lambda row: row['user'] )
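# Pagination sketch (illustrative, not part of the original module):
# ``load_count`` and ``load_slice`` above accept the same filter keywords, so
# a paginated listing for one user can combine them as below.  The helper and
# its default page size are assumptions.
async def _example_first_vfolder_page(context, user_id, limit=20):
    total = await VirtualFolder.load_count(context, user_id=user_id)
    items = await VirtualFolder.load_slice(
        context, limit, 0,
        user_id=user_id, order_key='created_at', order_asc=False)
    return total, items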
class KeyPair(graphene.ObjectType): user_id = graphene.String() access_key = graphene.String() secret_key = graphene.String() is_active = graphene.Boolean() is_admin = graphene.Boolean() resource_policy = graphene.String() created_at = GQLDateTime() last_used = GQLDateTime() concurrency_used = graphene.Int() rate_limit = graphene.Int() num_queries = graphene.Int() user = graphene.UUID() vfolders = graphene.List('ai.backend.manager.models.VirtualFolder') compute_sessions = graphene.List( 'ai.backend.manager.models.ComputeSession', status=graphene.String(), ) # Deprecated concurrency_limit = graphene.Int( deprecation_reason='Moved to KeyPairResourcePolicy object as ' 'max_concurrent_sessions field.') @classmethod def from_row(cls, row): if row is None: return None return cls( user_id=row['user_id'], access_key=row['access_key'], secret_key=row['secret_key'], is_active=row['is_active'], is_admin=row['is_admin'], resource_policy=row['resource_policy'], created_at=row['created_at'], last_used=row['last_used'], concurrency_limit=0, # moved to resource policy concurrency_used=row['concurrency_used'], rate_limit=row['rate_limit'], num_queries=row['num_queries'], user=row['user'], ) async def resolve_vfolders(self, info): manager = info.context['dlmgr'] loader = manager.get_loader('VirtualFolder') return await loader.load(self.access_key) async def resolve_compute_sessions(self, info, status=None): manager = info.context['dlmgr'] from . import KernelStatus # noqa: avoid circular imports if status is not None: status = KernelStatus[status] loader = manager.get_loader('ComputeSession', status=status) return await loader.load(self.access_key) @staticmethod async def load_all(context, *, domain_name=None, is_active=None): from .user import users async with context['dbpool'].acquire() as conn: j = sa.join(keypairs, users, keypairs.c.user == users.c.uuid) query = sa.select([keypairs]).select_from(j) if domain_name is not None: query = query.where(users.c.domain_name == domain_name) if is_active is not None: query = query.where(keypairs.c.is_active == is_active) objs = [] async for row in conn.execute(query): o = KeyPair.from_row(row) objs.append(o) return objs @staticmethod async def batch_load_by_email(context, user_ids, *, domain_name=None, is_active=None): from .user import users async with context['dbpool'].acquire() as conn: j = sa.join(keypairs, users, keypairs.c.user == users.c.uuid) query = (sa.select([keypairs]).select_from(j).where( keypairs.c.user_id.in_(user_ids))) if domain_name is not None: query = query.where(users.c.domain_name == domain_name) if is_active is not None: query = query.where(keypairs.c.is_active == is_active) objs_per_key = OrderedDict() for k in user_ids: objs_per_key[k] = list() async for row in conn.execute(query): o = KeyPair.from_row(row) objs_per_key[row.user_id].append(o) return tuple(objs_per_key.values()) @staticmethod async def batch_load_by_ak(context, access_keys, *, domain_name=None): async with context['dbpool'].acquire() as conn: from .user import users j = sa.join(keypairs, users, keypairs.c.user == users.c.uuid) query = (sa.select([keypairs]).select_from(j).where( keypairs.c.access_key.in_(access_keys))) if domain_name is not None: query = query.where(users.c.domain_name == domain_name) objs_per_key = OrderedDict() # For each access key, there is only one keypair. # So we don't build lists in objs_per_key variable. 
for k in access_keys: objs_per_key[k] = None async for row in conn.execute(query): o = KeyPair.from_row(row) objs_per_key[row.access_key] = o return tuple(objs_per_key.values())
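# Usage sketch (illustrative, not part of the original module):
# ``batch_load_by_ak`` above yields exactly one slot per requested access key
# (``None`` when the key is unknown), so a single lookup can unpack directly.
# The wrapper is hypothetical.
async def _example_get_keypair(context, access_key, *, domain_name=None):
    (kp,) = await KeyPair.batch_load_by_ak(
        context, [access_key], domain_name=domain_name)
    return kp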
class ScalingGroup(graphene.ObjectType): name = graphene.String() description = graphene.String() is_active = graphene.Boolean() created_at = GQLDateTime() driver = graphene.String() driver_opts = graphene.JSONString() scheduler = graphene.String() scheduler_opts = graphene.JSONString() @classmethod def from_row(cls, row): if row is None: return None return cls( name=row['name'], description=row['description'], is_active=row['is_active'], created_at=row['created_at'], driver=row['driver'], driver_opts=row['driver_opts'], scheduler=row['scheduler'], scheduler_opts=row['scheduler_opts'], ) @staticmethod async def load_all(context, *, is_active=None): async with context['dbpool'].acquire() as conn: query = sa.select([scaling_groups]).select_from(scaling_groups) if is_active is not None: query = query.where(scaling_groups.c.is_active == is_active) objs = [] async for row in conn.execute(query): o = ScalingGroup.from_row(row) objs.append(o) return objs @staticmethod async def load_by_domain(context, domain, *, is_active=None): async with context['dbpool'].acquire() as conn: j = sa.join( scaling_groups, sgroups_for_domains, scaling_groups.c.name == sgroups_for_domains.c.scaling_group) query = (sa.select([ scaling_groups ]).select_from(j).where(sgroups_for_domains.c.domain == domain)) if is_active is not None: query = query.where(scaling_groups.c.is_active == is_active) objs = [] async for row in conn.execute(query): o = ScalingGroup.from_row(row) objs.append(o) return objs @staticmethod async def load_by_group(context, group, *, is_active=None): async with context['dbpool'].acquire() as conn: j = sa.join( scaling_groups, sgroups_for_groups, scaling_groups.c.name == sgroups_for_groups.c.scaling_group) query = (sa.select([ scaling_groups ]).select_from(j).where(sgroups_for_groups.c.group == group)) if is_active is not None: query = query.where(scaling_groups.c.is_active == is_active) objs = [] async for row in conn.execute(query): o = ScalingGroup.from_row(row) objs.append(o) return objs @staticmethod async def load_by_keypair(context, access_key, *, is_active=None): async with context['dbpool'].acquire() as conn: j = sa.join( scaling_groups, sgroups_for_keypairs, scaling_groups.c.name == sgroups_for_keypairs.c.scaling_group) query = (sa.select([scaling_groups]).select_from(j).where( sgroups_for_keypairs.c.access_key == access_key)) if is_active is not None: query = query.where(scaling_groups.c.is_active == is_active) objs = [] async for row in conn.execute(query): o = ScalingGroup.from_row(row) objs.append(o) return objs @staticmethod async def batch_load_by_name(context, names): async with context['dbpool'].acquire() as conn: query = (sa.select([scaling_groups ]).select_from(scaling_groups).where( scaling_groups.c.name.in_(names))) objs_per_key = OrderedDict() for k in names: objs_per_key[k] = None async for row in conn.execute(query): o = ScalingGroup.from_row(row) objs_per_key[row.name] = o return tuple(objs_per_key.values())
class Group(graphene.ObjectType): id = graphene.UUID() name = graphene.String() description = graphene.String() is_active = graphene.Boolean() created_at = GQLDateTime() modified_at = GQLDateTime() domain_name = graphene.String() total_resource_slots = graphene.JSONString() allowed_vfolder_hosts = graphene.List(lambda: graphene.String) integration_id = graphene.String() scaling_groups = graphene.List(lambda: graphene.String) @classmethod def from_row(cls, context: Mapping[str, Any], row: RowProxy) -> Optional[Group]: if row is None: return None return cls( id=row['id'], name=row['name'], description=row['description'], is_active=row['is_active'], created_at=row['created_at'], modified_at=row['modified_at'], domain_name=row['domain_name'], total_resource_slots=row['total_resource_slots'].to_json(), allowed_vfolder_hosts=row['allowed_vfolder_hosts'], integration_id=row['integration_id'], ) async def resolve_scaling_groups(self, info): from .scaling_group import ScalingGroup sgroups = await ScalingGroup.load_by_group(info.context, self.id) return [sg.name for sg in sgroups] @classmethod async def load_all(cls, context, *, domain_name=None, is_active=None): async with context['dbpool'].acquire() as conn: query = (sa.select([groups]).select_from(groups)) if domain_name is not None: query = query.where(groups.c.domain_name == domain_name) if is_active is not None: query = query.where(groups.c.is_active == is_active) return [ cls.from_row(context, row) async for row in conn.execute(query) ] @classmethod async def batch_load_by_id(cls, context, group_ids, *, domain_name=None): async with context['dbpool'].acquire() as conn: query = (sa.select([groups]).select_from(groups).where( groups.c.id.in_(group_ids))) if domain_name is not None: query = query.where(groups.c.domain_name == domain_name) return await batch_result( context, conn, query, cls, group_ids, lambda row: row['id'], ) @classmethod async def get_groups_for_user(cls, context, user_id): async with context['dbpool'].acquire() as conn: j = sa.join(groups, association_groups_users, groups.c.id == association_groups_users.c.group_id) query = (sa.select([groups]).select_from(j).where( association_groups_users.c.user_id == user_id)) return [ cls.from_row(context, row) async for row in conn.execute(query) ]
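# Usage sketch (illustrative, not part of the original module): combining the
# loaders above to get the names of all active groups a user belongs to within
# a given domain.  The helper itself is hypothetical; only methods and fields
# defined in this class are used.
async def _example_user_group_names(context, user_id, domain_name):
    groups_of_user = await Group.get_groups_for_user(context, user_id)
    return sorted(g.name for g in groups_of_user
                  if g.is_active and g.domain_name == domain_name)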
class ComputeContainer(graphene.ObjectType): class Meta: interfaces = (Item, ) # identity role = graphene.String() hostname = graphene.String() session_id = graphene.UUID() # owner session # image image = graphene.String() registry = graphene.String() # status status = graphene.String() status_changed = GQLDateTime() status_info = graphene.String() created_at = GQLDateTime() terminated_at = GQLDateTime() starts_at = GQLDateTime() # resources agent = graphene.String() container_id = graphene.String() resource_opts = graphene.JSONString() occupied_slots = graphene.JSONString() live_stat = graphene.JSONString() last_stat = graphene.JSONString() @classmethod def parse_row(cls, context, row): assert row is not None from .user import UserRole is_superadmin = (context['user']['role'] == UserRole.SUPERADMIN) if is_superadmin: hide_agents = False else: hide_agents = context['config']['manager']['hide-agents'] return { # identity 'id': row['id'], 'role': row['role'], 'hostname': None, # TODO: implement 'session_id': row['id'], # master container's ID == session ID # image 'image': row['image'], 'registry': row['registry'], # status 'status': row['status'].name, 'status_changed': row['status_changed'], 'status_info': row['status_info'], 'created_at': row['created_at'], 'terminated_at': row['terminated_at'], 'starts_at': row['starts_at'], 'occupied_slots': row['occupied_slots'].to_json(), # resources 'agent': row['agent'] if not hide_agents else None, 'container_id': row['container_id'] if not hide_agents else None, 'resource_opts': row['resource_opts'], # statistics # live_stat is resolved by Graphene 'last_stat': row['last_stat'], } @classmethod def from_row(cls, context, row): if row is None: return None props = cls.parse_row(context, row) return cls(**props) async def resolve_live_stat(self, info: graphene.ResolveInfo): if not hasattr(self, 'status'): return None rs = info.context['redis_stat'] if self.status in LIVE_STATUS: raw_live_stat = await redis.execute_with_retries( lambda: rs.get(str(self.id), encoding=None)) if raw_live_stat is not None: live_stat = msgpack.unpackb(raw_live_stat) return live_stat return None else: return self.last_stat @classmethod async def load_count(cls, context, session_id, *, role=None, domain_name=None, group_id=None, access_key=None): async with context['dbpool'].acquire() as conn: query = ( sa.select([sa.func.count(kernels.c.id)]) .select_from(kernels) # TODO: use "owner session ID" when we implement multi-container session .where(kernels.c.id == session_id) .as_scalar() ) if role is not None: query = query.where(kernels.c.role == role) if domain_name is not None: query = query.where(kernels.c.domain_name == domain_name) if group_id is not None: query = query.where(kernels.c.group_id == group_id) if access_key is not None: query = query.where(kernels.c.access_key == access_key) result = await conn.execute(query) count = await result.fetchone() return count[0] @classmethod async def load_slice(cls, context, limit, offset, session_id, *, role=None, domain_name=None, group_id=None, access_key=None, order_key=None, order_asc=None): async with context['dbpool'].acquire() as conn: if order_key is None: _ordering = DEFAULT_SESSION_ORDERING else: _order_func = sa.asc if order_asc else sa.desc _ordering = [_order_func(getattr(kernels.c, order_key))] j = ( kernels .join(groups, groups.c.id == kernels.c.group_id) .join(users, users.c.uuid == kernels.c.user_uuid) ) query = ( sa.select([kernels, groups.c.name, users.c.email]) .select_from(j) # TODO: use "owner session ID" 
when we implement multi-container session .where(kernels.c.id == session_id) .order_by(*_ordering) .limit(limit) .offset(offset) ) if role is not None: query = query.where(kernels.c.role == role) if domain_name is not None: query = query.where(kernels.c.domain_name == domain_name) if group_id is not None: query = query.where(kernels.c.group_id == group_id) if access_key is not None: query = query.where(kernels.c.access_key == access_key) return [cls.from_row(context, r) async for r in conn.execute(query)] @classmethod async def batch_load_by_session(cls, context, session_ids): async with context['dbpool'].acquire() as conn: query = ( sa.select([kernels]) .select_from(kernels) # TODO: use "owner session ID" when we implement multi-container session .where(kernels.c.id.in_(session_ids)) ) return await batch_multiresult( context, conn, query, cls, session_ids, lambda row: row['id'], ) @classmethod async def batch_load_detail(cls, context, container_ids, *, domain_name=None, access_key=None): async with context['dbpool'].acquire() as conn: j = ( kernels .join(groups, groups.c.id == kernels.c.group_id) .join(users, users.c.uuid == kernels.c.user_uuid) ) query = ( sa.select([kernels, groups.c.name, users.c.email]) .select_from(j) .where( (kernels.c.id.in_(container_ids)) )) if domain_name is not None: query = query.where(kernels.c.domain_name == domain_name) if access_key is not None: query = query.where(kernels.c.access_key == access_key) return await batch_result( context, conn, query, cls, container_ids, lambda row: row['id'], )
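# Pagination sketch (illustrative, not part of the original module):
# ``load_count``/``load_slice`` above implement the usual paginated-list
# contract for the containers of one session; a hypothetical page fetcher:
async def _example_container_page(context, session_id, *, page=0, page_size=20):
    total = await ComputeContainer.load_count(context, session_id)
    items = await ComputeContainer.load_slice(
        context, page_size, page * page_size, session_id,
        order_key='created_at', order_asc=False)
    return total, items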
class Agent(graphene.ObjectType): id = graphene.String() status = graphene.String() region = graphene.String() mem_slots = graphene.Int() cpu_slots = graphene.Int() gpu_slots = graphene.Int() used_mem_slots = graphene.Int() used_cpu_slots = graphene.Int() used_gpu_slots = graphene.Int() addr = graphene.String() first_contact = GQLDateTime() lost_at = GQLDateTime() computations = graphene.List('ai.backend.manager.models.Computation', status=graphene.String()) @classmethod def from_row(cls, row): if row is None: return None return cls( id=row.id, status=row.status, region=row.region, mem_slots=row.mem_slots, cpu_slots=row.cpu_slots, gpu_slots=row.gpu_slots, used_mem_slots=row.used_mem_slots, used_cpu_slots=row.used_cpu_slots, used_gpu_slots=row.used_gpu_slots, addr=row.addr, first_contact=row.first_contact, lost_at=row.lost_at, ) async def resolve_computations(self, info, status=None): ''' Retrieves all children worker sessions run by this agent. ''' manager = info.context['dlmgr'] loader = manager.get_loader('Computation.by_agent_id', status=status) return await loader.load(self.id) @staticmethod async def load_all(dbpool, status=None): async with dbpool.acquire() as conn: query = (sa.select('*').select_from(agents)) if status is not None: status = AgentStatus[status] query = query.where(agents.c.status == status) result = await conn.execute(query) rows = await result.fetchall() return [Agent.from_row(r) for r in rows] @staticmethod async def batch_load(dbpool, agent_ids): async with dbpool.acquire() as conn: query = (sa.select('*').select_from(agents).where( agents.c.id.in_(agent_ids)).order_by(agents.c.id)) objs_per_key = OrderedDict() for k in agent_ids: objs_per_key[k] = None async for row in conn.execute(query): o = Agent.from_row(row) objs_per_key[row.id] = o return tuple(objs_per_key.values())
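# Usage sketch (illustrative, not part of the original module): this legacy
# loader takes the raw ``dbpool`` rather than a context mapping and converts
# the ``status`` string through the AgentStatus enum; 'ALIVE' is assumed here
# to be a valid member of that enum.
async def _example_list_alive_agents(dbpool):
    return await Agent.load_all(dbpool, status='ALIVE')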
class LegacyComputeSession(graphene.ObjectType): """ Represents a master session. """ class Meta: interfaces = (Item, ) tag = graphene.String() # Only for ComputeSession sess_id = graphene.String() # legacy sess_type = graphene.String() # legacy session_name = graphene.String() session_type = graphene.String() role = graphene.String() image = graphene.String() registry = graphene.String() domain_name = graphene.String() group_name = graphene.String() group_id = graphene.UUID() scaling_group = graphene.String() user_uuid = graphene.UUID() access_key = graphene.String() status = graphene.String() status_changed = GQLDateTime() status_info = graphene.String() created_at = GQLDateTime() terminated_at = GQLDateTime() startup_command = graphene.String() result = graphene.String() # hidable fields by configuration agent = graphene.String() container_id = graphene.String() service_ports = graphene.JSONString() occupied_slots = graphene.JSONString() occupied_shares = graphene.JSONString() mounts = graphene.List(lambda: graphene.List(lambda: graphene.String)) resource_opts = graphene.JSONString() num_queries = BigInt() live_stat = graphene.JSONString() last_stat = graphene.JSONString() user_email = graphene.String() # Legacy fields lang = graphene.String() mem_slot = graphene.Int() cpu_slot = graphene.Float() gpu_slot = graphene.Float() tpu_slot = graphene.Float() cpu_used = BigInt() cpu_using = graphene.Float() mem_max_bytes = BigInt() mem_cur_bytes = BigInt() net_rx_bytes = BigInt() net_tx_bytes = BigInt() io_read_bytes = BigInt() io_write_bytes = BigInt() io_max_scratch_size = BigInt() io_cur_scratch_size = BigInt() @classmethod async def _resolve_live_stat(cls, redis_stat, kernel_id): cstat = await redis.execute_with_retries( lambda: redis_stat.get(kernel_id, encoding=None)) if cstat is not None: cstat = msgpack.unpackb(cstat) return cstat async def resolve_live_stat(self, info: graphene.ResolveInfo): if not hasattr(self, 'status'): return None rs = info.context['redis_stat'] if self.status not in LIVE_STATUS: return self.last_stat else: return await type(self)._resolve_live_stat(rs, str(self.id)) async def _resolve_legacy_metric( self, info: graphene.ResolveInfo, metric_key, metric_field, convert_type, ): if not hasattr(self, 'status'): return None rs = info.context['redis_stat'] if self.status not in LIVE_STATUS: if self.last_stat is None: return convert_type(0) metric = self.last_stat.get(metric_key) if metric is None: return convert_type(0) value = metric.get(metric_field) if value is None: return convert_type(0) return convert_type(value) else: kstat = await type(self)._resolve_live_stat(rs, str(self.id)) if kstat is None: return convert_type(0) metric = kstat.get(metric_key) if metric is None: return convert_type(0) value = metric.get(metric_field) if value is None: return convert_type(0) return convert_type(value) async def resolve_cpu_used(self, info: graphene.ResolveInfo): return await self._resolve_legacy_metric(info, 'cpu_used', 'current', float) async def resolve_cpu_using(self, info: graphene.ResolveInfo): return await self._resolve_legacy_metric(info, 'cpu_util', 'pct', float) async def resolve_mem_max_bytes(self, info: graphene.ResolveInfo): return await self._resolve_legacy_metric(info, 'mem', 'stats.max', int) async def resolve_mem_cur_bytes(self, info: graphene.ResolveInfo): return await self._resolve_legacy_metric(info, 'mem', 'current', int) async def resolve_net_rx_bytes(self, info: graphene.ResolveInfo): return await self._resolve_legacy_metric(info, 'net_rx', 'stats.rate', 
int) async def resolve_net_tx_bytes(self, info: graphene.ResolveInfo): return await self._resolve_legacy_metric(info, 'net_tx', 'stats.rate', int) async def resolve_io_read_bytes(self, info: graphene.ResolveInfo): return await self._resolve_legacy_metric(info, 'io_read', 'current', int) async def resolve_io_write_bytes(self, info: graphene.ResolveInfo): return await self._resolve_legacy_metric(info, 'io_write', 'current', int) async def resolve_io_max_scratch_size(self, info: graphene.ResolveInfo): return await self._resolve_legacy_metric(info, 'io_scratch_size', 'stats.max', int) async def resolve_io_cur_scratch_size(self, info: graphene.ResolveInfo): return await self._resolve_legacy_metric(info, 'io_scratch_size', 'current', int) @classmethod def parse_row(cls, context, row): assert row is not None from .user import UserRole mega = 2 ** 20 is_superadmin = (context['user']['role'] == UserRole.SUPERADMIN) if is_superadmin: hide_agents = False else: hide_agents = context['config']['manager']['hide-agents'] return { 'sess_id': row['sess_id'], # legacy, will be deprecated 'sess_type': row['sess_type'].name, # legacy, will be deprecated 'session_name': row['sess_id'], 'session_type': row['sess_type'].name, 'id': row['id'], # legacy, will be replaced with session UUID 'role': row['role'], 'tag': row['tag'], 'image': row['image'], 'registry': row['registry'], 'domain_name': row['domain_name'], 'group_name': row['name'], # group.name (group is omitted since use_labels=True is not used) 'group_id': row['group_id'], 'scaling_group': row['scaling_group'], 'user_uuid': row['user_uuid'], 'access_key': row['access_key'], 'status': row['status'].name, 'status_changed': row['status_changed'], 'status_info': row['status_info'], 'created_at': row['created_at'], 'terminated_at': row['terminated_at'], 'startup_command': row['startup_command'], 'result': row['result'].name, 'service_ports': row['service_ports'], 'occupied_slots': row['occupied_slots'].to_json(), 'mounts': row['mounts'], 'resource_opts': row['resource_opts'], 'num_queries': row['num_queries'], # optionally hidden 'agent': row['agent'] if not hide_agents else None, 'container_id': row['container_id'] if not hide_agents else None, # live_stat is resolved by Graphene 'last_stat': row['last_stat'], 'user_email': row['email'], # Legacy fields # NOTE: currently graphene always uses resolve methods! 
'cpu_used': 0, 'mem_max_bytes': 0, 'mem_cur_bytes': 0, 'net_rx_bytes': 0, 'net_tx_bytes': 0, 'io_read_bytes': 0, 'io_write_bytes': 0, 'io_max_scratch_size': 0, 'io_cur_scratch_size': 0, 'lang': row['image'], 'occupied_shares': row['occupied_shares'], 'mem_slot': BinarySize.from_str( row['occupied_slots'].get('mem', 0)) // mega, 'cpu_slot': float(row['occupied_slots'].get('cpu', 0)), 'gpu_slot': float(row['occupied_slots'].get('cuda.device', 0)), 'tpu_slot': float(row['occupied_slots'].get('tpu.device', 0)), } @classmethod def from_row(cls, context, row): if row is None: return None props = cls.parse_row(context, row) return cls(**props) @classmethod async def load_count(cls, context, *, domain_name=None, group_id=None, access_key=None, status=None): if isinstance(status, str): status_list = [KernelStatus[s] for s in status.split(',')] elif isinstance(status, KernelStatus): status_list = [status] async with context['dbpool'].acquire() as conn: query = ( sa.select([sa.func.count(kernels.c.sess_id)]) .select_from(kernels) .where(kernels.c.role == 'master') .as_scalar() ) if domain_name is not None: query = query.where(kernels.c.domain_name == domain_name) if group_id is not None: query = query.where(kernels.c.group_id == group_id) if access_key is not None: query = query.where(kernels.c.access_key == access_key) if status is not None: query = query.where(kernels.c.status.in_(status_list)) result = await conn.execute(query) count = await result.fetchone() return count[0] @classmethod async def load_slice(cls, context, limit, offset, *, domain_name=None, group_id=None, access_key=None, status=None, order_key=None, order_asc=None): if isinstance(status, str): status_list = [KernelStatus[s] for s in status.split(',')] elif isinstance(status, KernelStatus): status_list = [status] async with context['dbpool'].acquire() as conn: if order_key is None: _ordering = DEFAULT_SESSION_ORDERING else: _order_func = sa.asc if order_asc else sa.desc _ordering = [_order_func(getattr(kernels.c, order_key))] j = (kernels.join(groups, groups.c.id == kernels.c.group_id) .join(users, users.c.uuid == kernels.c.user_uuid)) query = ( sa.select([kernels, groups.c.name, users.c.email]) .select_from(j) .where(kernels.c.role == 'master') .order_by(*_ordering) .limit(limit) .offset(offset) ) if domain_name is not None: query = query.where(kernels.c.domain_name == domain_name) if group_id is not None: query = query.where(kernels.c.group_id == group_id) if access_key is not None: query = query.where(kernels.c.access_key == access_key) if status is not None: query = query.where(kernels.c.status.in_(status_list)) return [cls.from_row(context, r) async for r in conn.execute(query)] @classmethod async def batch_load(cls, context, access_keys, *, domain_name=None, group_id=None, status=None): async with context['dbpool'].acquire() as conn: j = (kernels.join(groups, groups.c.id == kernels.c.group_id) .join(users, users.c.uuid == kernels.c.user_uuid)) query = ( sa.select([kernels, groups.c.name, users.c.email]) .select_from(j) .where( (kernels.c.access_key.in_(access_keys)) & (kernels.c.role == 'master') ) .order_by( sa.desc(sa.func.greatest( kernels.c.created_at, kernels.c.terminated_at, kernels.c.status_changed, )) ) .limit(100)) if domain_name is not None: query = query.where(kernels.c.domain_name == domain_name) if group_id is not None: query = query.where(kernels.c.group_id == group_id) if status is not None: query = query.where(kernels.c.status == status) return await batch_result( context, conn, query, cls, access_keys, 
lambda row: row['access_key'], ) @classmethod async def batch_load_detail(cls, context, sess_ids, *, domain_name=None, access_key=None, status=None): async with context['dbpool'].acquire() as conn: status_list = [] if isinstance(status, str): status_list = [KernelStatus[s] for s in status.split(',')] elif isinstance(status, KernelStatus): status_list = [status] elif status is None: status_list = [KernelStatus['RUNNING']] j = (kernels.join(groups, groups.c.id == kernels.c.group_id) .join(users, users.c.uuid == kernels.c.user_uuid)) query = (sa.select([kernels, groups.c.name, users.c.email]) .select_from(j) .where((kernels.c.role == 'master') & (kernels.c.sess_id.in_(sess_ids)))) if domain_name is not None: query = query.where(kernels.c.domain_name == domain_name) if access_key is not None: query = query.where(kernels.c.access_key == access_key) if status_list: query = query.where(kernels.c.status.in_(status_list)) return await batch_multiresult( context, conn, query, cls, sess_ids, lambda row: row['sess_id'], )
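# Usage sketch (illustrative, not part of the original module): the ``status``
# filter of the loaders above accepts either a comma-separated string of
# KernelStatus names or a KernelStatus member; both calls below request the
# same count.  The wrapper is hypothetical.
async def _example_count_running(context, access_key):
    by_name = await LegacyComputeSession.load_count(
        context, access_key=access_key, status='RUNNING')
    by_enum = await LegacyComputeSession.load_count(
        context, access_key=access_key, status=KernelStatus['RUNNING'])
    assert by_name == by_enum
    return by_name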
class VirtualFolder(graphene.ObjectType):
    class Meta:
        interfaces = (Item, )

    host = graphene.String()
    name = graphene.String()
    max_files = graphene.Int()
    max_size = graphene.Int()
    created_at = GQLDateTime()
    last_used = GQLDateTime()
    num_files = graphene.Int()
    cur_size = graphene.Int()
    # num_attached = graphene.Int()

    @classmethod
    def from_row(cls, row):
        if row is None:
            return None
        return cls(
            id=row['id'],
            host=row['host'],
            name=row['name'],
            max_files=row['max_files'],
            max_size=row['max_size'],  # in KiB
            created_at=row['created_at'],
            last_used=row['last_used'],
            # num_attached=row['num_attached'],
        )

    async def resolve_num_files(self, info):
        # TODO: measure on-the-fly
        return 0

    async def resolve_cur_size(self, info):
        # TODO: measure on-the-fly
        return 0

    @staticmethod
    async def load_count(context, *,
                         domain_name=None, group_id=None, user_id=None):
        from .user import users
        async with context['dbpool'].acquire() as conn:
            j = sa.join(vfolders, users, vfolders.c.user == users.c.uuid)
            query = (sa.select([sa.func.count(vfolders.c.id)])
                     .select_from(j)
                     .as_scalar())
            if domain_name is not None:
                query = query.where(users.c.domain_name == domain_name)
            if group_id is not None:
                query = query.where(vfolders.c.group == group_id)
            if user_id is not None:
                query = query.where(vfolders.c.user == user_id)
            result = await conn.execute(query)
            count = await result.fetchone()
            return count[0]

    @staticmethod
    async def load_slice(context, limit, offset, *,
                         domain_name=None, group_id=None, user_id=None,
                         order_key=None, order_asc=None):
        from .user import users
        async with context['dbpool'].acquire() as conn:
            if order_key is None:
                _ordering = vfolders.c.created_at
            else:
                _order_func = sa.asc if order_asc else sa.desc
                _ordering = _order_func(getattr(vfolders.c, order_key))
            j = sa.join(vfolders, users, vfolders.c.user == users.c.uuid)
            query = (sa.select([vfolders])
                     .select_from(j)
                     .order_by(_ordering)
                     .limit(limit)
                     .offset(offset))
            if domain_name is not None:
                query = query.where(users.c.domain_name == domain_name)
            if group_id is not None:
                query = query.where(vfolders.c.group == group_id)
            if user_id is not None:
                query = query.where(vfolders.c.user == user_id)
            result = await conn.execute(query)
            rows = await result.fetchall()
            return [VirtualFolder.from_row(r) for r in rows]

    @staticmethod
    async def load_all(context, *,
                       domain_name=None, group_id=None, user_id=None):
        from .user import users
        async with context['dbpool'].acquire() as conn:
            j = sa.join(vfolders, users, vfolders.c.user == users.c.uuid)
            query = (sa.select([vfolders])
                     .select_from(j)
                     .order_by(sa.desc(vfolders.c.created_at)))
            if domain_name is not None:
                query = query.where(users.c.domain_name == domain_name)
            if group_id is not None:
                query = query.where(vfolders.c.group == group_id)
            if user_id is not None:
                query = query.where(vfolders.c.user == user_id)
            objs = []
            async for row in conn.execute(query):
                o = VirtualFolder.from_row(row)
                objs.append(o)
            return objs

    @staticmethod
    async def batch_load_by_user(context, user_uuids, *,
                                 domain_name=None, group_id=None):
        from .user import users
        async with context['dbpool'].acquire() as conn:
            # TODO: num_attached count group-by
            j = sa.join(vfolders, users, vfolders.c.user == users.c.uuid)
            query = (sa.select([vfolders])
                     .select_from(j)
                     .where(vfolders.c.user.in_(user_uuids))
                     .order_by(sa.desc(vfolders.c.created_at)))
            if domain_name is not None:
                query = query.where(users.c.domain_name == domain_name)
            if group_id is not None:
                query = query.where(vfolders.c.group == group_id)
            objs_per_key = OrderedDict()
            for u in user_uuids:
                objs_per_key[u] = list()
            async for row in conn.execute(query):
                o = VirtualFolder.from_row(row)
                objs_per_key[row.user].append(o)
            return tuple(objs_per_key.values())
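# Illustrative helper (not part of the original module): ``max_size`` above is
# stored in KiB (see the from_row comment); a hypothetical convenience for
# callers that want bytes instead.
def _example_max_size_bytes(folder):
    return None if folder.max_size is None else folder.max_size * 1024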