Example #1
0
class Domain(graphene.ObjectType):
    """GraphQL object type for a row of the ``domains`` table."""

    name = graphene.String()
    description = graphene.String()
    is_active = graphene.Boolean()
    created_at = GQLDateTime()
    modified_at = GQLDateTime()
    total_resource_slots = graphene.JSONString()
    allowed_vfolder_hosts = graphene.List(lambda: graphene.String)
    allowed_docker_registries = graphene.List(lambda: graphene.String)
    integration_id = graphene.String()

    # Dynamic fields.
    scaling_groups = graphene.List(lambda: graphene.String)

    async def resolve_scaling_groups(self, info):
        """Return the names of the scaling groups loaded for this domain."""
        sgroups = await ScalingGroup.load_by_domain(info.context, self.name)
        return [sg.name for sg in sgroups]

    @classmethod
    def from_row(cls, context, row):
        """Build a Domain from a ``domains`` row; ``None`` passes through as ``None``.

        ``total_resource_slots`` is converted with its ``to_json()`` method
        before being handed to the JSONString field.
        """
        if row is None:
            return None
        return cls(
            name=row['name'],
            description=row['description'],
            is_active=row['is_active'],
            created_at=row['created_at'],
            modified_at=row['modified_at'],
            total_resource_slots=row['total_resource_slots'].to_json(),
            allowed_vfolder_hosts=row['allowed_vfolder_hosts'],
            allowed_docker_registries=row['allowed_docker_registries'],
            integration_id=row['integration_id'],
        )

    @classmethod
    async def load_all(cls, context, *, is_active=None):
        """Load every domain, optionally only those whose ``is_active`` equals *is_active*."""
        async with context['dbpool'].acquire() as conn:
            query = sa.select([domains]).select_from(domains)
            if is_active is not None:
                query = query.where(domains.c.is_active == is_active)
            return [
                cls.from_row(context, row) async for row in conn.execute(query)
            ]

    @classmethod
    async def batch_load_by_name(cls, context, names=None, *, is_active=None):
        """Batch-load domains by name (dataloader entry point).

        The result is produced by ``batch_result`` keyed on each row's
        ``name``, so it presumably lines up with *names* (verify against
        ``batch_result``). When *is_active* is given, only domains with a
        matching ``is_active`` flag are returned -- this argument used to be
        accepted but silently ignored.
        """
        async with context['dbpool'].acquire() as conn:
            query = (
                sa.select([domains])
                .select_from(domains)
                .where(domains.c.name.in_(names))
            )
            if is_active is not None:
                query = query.where(domains.c.is_active == is_active)
            return await batch_result(
                context,
                conn,
                query,
                cls,
                names,
                lambda row: row['name'],
            )
Example #2
0
class VirtualFolder(graphene.ObjectType):
    """GraphQL object type for a row of the ``vfolders`` table."""

    id = graphene.UUID()
    host = graphene.String()
    name = graphene.String()
    max_files = graphene.Int()
    max_size = graphene.Int()
    created_at = GQLDateTime()
    last_used = GQLDateTime()

    num_files = graphene.Int()
    cur_size = graphene.Int()
    # num_attached = graphene.Int()  # re-enable together with the group-by count in batch_load

    @classmethod
    def from_row(cls, row):
        """Build a VirtualFolder from a ``vfolders`` row; ``None`` passes through.

        ``num_attached`` is deliberately not passed: the field is disabled
        above, graphene.ObjectType rejects unknown keyword arguments, and the
        query in ``batch_load`` does not compute such a column.
        """
        if row is None:
            return None
        return cls(
            id=row.id,
            host=row.host,
            name=row.name,
            max_files=row.max_files,
            max_size=row.max_size,    # in KiB
            created_at=row.created_at,
            last_used=row.last_used,
        )

    async def resolve_num_files(self, info):
        # TODO: measure on-the-fly
        return 0

    async def resolve_cur_size(self, info):
        # TODO: measure on-the-fly
        return 0

    @staticmethod
    async def batch_load(dbpool, access_keys):
        """Load the vfolders owned by each access key, newest first.

        Returns a tuple with one list per entry of *access_keys*, in the same
        order; keys without folders map to an empty list.
        """
        async with dbpool.acquire() as conn:
            # TODO: num_attached count group-by
            query = (sa.select('*')
                       .select_from(vfolders)
                       .where(vfolders.c.belongs_to.in_(access_keys))
                       .order_by(sa.desc(vfolders.c.created_at)))
            objs_per_key = OrderedDict((k, []) for k in access_keys)
            async for row in conn.execute(query):
                objs_per_key[row.belongs_to].append(VirtualFolder.from_row(row))
        return tuple(objs_per_key.values())
Example #3
0
class Agent(graphene.ObjectType):
    """GraphQL object type for a row of the ``agents`` table plus live statistics."""

    class Meta:
        interfaces = (Item, )

    status = graphene.String()
    status_changed = GQLDateTime()
    region = graphene.String()
    scaling_group = graphene.String()
    schedulable = graphene.Boolean()
    available_slots = graphene.JSONString()
    occupied_slots = graphene.JSONString()
    addr = graphene.String()
    first_contact = GQLDateTime()
    lost_at = GQLDateTime()
    live_stat = graphene.JSONString()
    version = graphene.String()
    compute_plugins = graphene.JSONString()

    # Legacy fields
    mem_slots = graphene.Int()
    cpu_slots = graphene.Float()
    gpu_slots = graphene.Float()
    tpu_slots = graphene.Float()
    used_mem_slots = graphene.Int()
    used_cpu_slots = graphene.Float()
    used_gpu_slots = graphene.Float()
    used_tpu_slots = graphene.Float()
    cpu_cur_pct = graphene.Float()
    mem_cur_bytes = graphene.Float()

    compute_containers = graphene.List(
        'ai.backend.manager.models.ComputeContainer',
        status=graphene.String())

    @classmethod
    def from_row(
        cls,
        context: Mapping[str, Any],
        row: RowProxy,
    ) -> Agent:
        """Build an Agent from an ``agents`` row.

        The legacy ``*_slots`` fields are derived from the ``available_slots``
        / ``occupied_slots`` mappings; memory is reported in MiB.
        """
        mega = 2 ** 20
        available = row['available_slots']
        occupied = row['occupied_slots']
        return cls(
            id=row['id'],
            status=row['status'].name,
            status_changed=row['status_changed'],
            region=row['region'],
            scaling_group=row['scaling_group'],
            schedulable=row['schedulable'],
            available_slots=available.to_json(),
            occupied_slots=occupied.to_json(),
            addr=row['addr'],
            first_contact=row['first_contact'],
            lost_at=row['lost_at'],
            version=row['version'],
            compute_plugins=row['compute_plugins'],
            # legacy fields
            mem_slots=BinarySize.from_str(available['mem']) // mega,
            cpu_slots=available['cpu'],
            gpu_slots=available.get('cuda.device', 0),
            tpu_slots=available.get('tpu.device', 0),
            used_mem_slots=BinarySize.from_str(occupied.get('mem', 0)) // mega,
            used_cpu_slots=float(occupied.get('cpu', 0)),
            used_gpu_slots=float(occupied.get('cuda.device', 0)),
            used_tpu_slots=float(occupied.get('tpu.device', 0)),
        )

    async def _fetch_live_stat(self, info):
        """Fetch and msgpack-decode this agent's live stats from ``redis_stat``.

        Returns ``None`` when nothing is stored under ``str(self.id)``.
        """
        rs = info.context['redis_stat']
        raw = await redis.execute_with_retries(
            lambda: rs.get(str(self.id), encoding=None))
        if raw is None:
            return None
        return msgpack.unpackb(raw)

    async def resolve_live_stat(self, info):
        return await self._fetch_live_stat(info)

    async def resolve_cpu_cur_pct(self, info):
        """Current CPU utilization from live stats; 0.0 if missing or malformed."""
        live_stat = await self._fetch_live_stat(info)
        if live_stat is None:
            return 0.0
        try:
            return float(live_stat['node']['cpu_util']['pct'])
        except (KeyError, TypeError, ValueError):
            return 0.0

    async def resolve_mem_cur_bytes(self, info):
        """Current memory usage from live stats; 0 if missing or malformed."""
        live_stat = await self._fetch_live_stat(info)
        if live_stat is None:
            return 0
        try:
            return int(live_stat['node']['mem']['current'])
        except (KeyError, TypeError, ValueError):
            return 0

    async def resolve_computations(self, info, status=None):
        '''
        Retrieves all children worker sessions run by this agent.
        '''
        # NOTE(review): graphene looks up ``resolve_<field>``; the field declared
        # above is ``compute_containers``, so this resolver appears not to be
        # bound to any field of this type -- confirm the intended field/loader.
        manager = info.context['dlmgr']
        loader = manager.get_loader('Computation.by_agent_id', status=status)
        return await loader.load(self.id)

    @staticmethod
    def _apply_filters(query, *, scaling_group=None, status=None):
        """Narrow *query* by scaling group and/or an ``AgentStatus`` member name."""
        if scaling_group is not None:
            query = query.where(agents.c.scaling_group == scaling_group)
        if status is not None:
            query = query.where(agents.c.status == AgentStatus[status])
        return query

    @staticmethod
    async def load_count(
        context, *,
        scaling_group=None,
        status=None,
    ) -> int:
        """Count agents matching the optional scaling-group/status filters."""
        async with context['dbpool'].acquire() as conn:
            query = (
                sa.select([sa.func.count(agents.c.id)])
                .select_from(agents)
                .as_scalar()
            )
            query = Agent._apply_filters(
                query, scaling_group=scaling_group, status=status)
            result = await conn.execute(query)
            count = await result.fetchone()
            return count[0]

    @classmethod
    async def load_slice(
        cls, context, limit, offset, *,
        scaling_group=None,
        status=None,
        order_key=None,
        order_asc=True,
    ) -> Sequence[Agent]:
        """Load one page of agents, ordered by *order_key* (default: ``id``)."""
        async with context['dbpool'].acquire() as conn:
            # TODO: optimization for pagination using subquery, join
            if order_key is None:
                _ordering = agents.c.id
            else:
                _order_func = sa.asc if order_asc else sa.desc
                _ordering = _order_func(getattr(agents.c, order_key))
            query = (
                sa.select([agents])
                .select_from(agents)
                .order_by(_ordering)
                .limit(limit)
                .offset(offset)
            )
            query = cls._apply_filters(
                query, scaling_group=scaling_group, status=status)
            return [
                cls.from_row(context, row) async for row in conn.execute(query)
            ]

    @classmethod
    async def load_all(
        cls, context, *,
        scaling_group=None,
        status=None,
    ) -> Sequence[Agent]:
        """Load every agent matching the optional scaling-group/status filters."""
        async with context['dbpool'].acquire() as conn:
            query = sa.select([agents]).select_from(agents)
            query = cls._apply_filters(
                query, scaling_group=scaling_group, status=status)
            return [
                cls.from_row(context, row) async for row in conn.execute(query)
            ]

    @classmethod
    async def batch_load(
        cls, context, agent_ids, *,
        status=None,
    ) -> Sequence[Optional[Agent]]:
        """Batch-load agents by ID (dataloader entry point) via ``batch_result``."""
        async with context['dbpool'].acquire() as conn:
            query = (sa.select([agents])
                       .select_from(agents)
                       .where(agents.c.id.in_(agent_ids))
                       .order_by(agents.c.id))
            query = cls._apply_filters(query, status=status)
            return await batch_result(
                context, conn, query, cls,
                agent_ids, lambda row: row['id'],
            )
Example #4
0
class User(graphene.ObjectType):
    """GraphQL object type for a row of the ``users`` table."""

    class Meta:
        interfaces = (Item, )

    uuid = graphene.UUID()  # legacy
    username = graphene.String()
    email = graphene.String()
    password = graphene.String()
    need_password_change = graphene.Boolean()
    full_name = graphene.String()
    description = graphene.String()
    is_active = graphene.Boolean()
    status = graphene.String()
    status_info = graphene.String()
    created_at = GQLDateTime()
    modified_at = GQLDateTime()
    domain_name = graphene.String()
    role = graphene.String()

    groups = graphene.List(lambda: UserGroup)

    async def resolve_groups(
        self,
        info: graphene.ResolveInfo,
    ) -> Iterable[UserGroup]:
        manager = info.context['dlmgr']
        loader = manager.get_loader('UserGroup.by_user_id')
        return await loader.load(self.id)

    @classmethod
    def from_row(
        cls,
        context: Mapping[str, Any],
        row: RowProxy,
    ) -> User:
        """Build a User from a ``users`` row (the password is never exposed)."""
        return cls(
            id=row['uuid'],
            uuid=row['uuid'],
            username=row['username'],
            email=row['email'],
            need_password_change=row['need_password_change'],
            full_name=row['full_name'],
            description=row['description'],
            is_active=bool(row['status'] == UserStatus.ACTIVE),  # legacy
            status=row['status'],
            status_info=row['status_info'],
            created_at=row['created_at'],
            modified_at=row['modified_at'],
            domain_name=row['domain_name'],
            role=row['role'],
        )

    @staticmethod
    def _select_users(columns, group_id=None):
        """Build ``SELECT columns FROM users``, limited to one group's members if *group_id* is set."""
        if group_id is None:
            return sa.select(columns).select_from(users)
        from .group import association_groups_users as agus
        j = users.join(agus, agus.c.user_id == users.c.uuid)
        return (
            sa.select(columns)
            .select_from(j)
            .where(agus.c.group_id == group_id)
        )

    @staticmethod
    def _filter_status(query, *, status=None, is_active=None):
        """Filter *query* by an exact ``UserStatus``, or else by the is_active status set.

        *is_active* is only considered when *status* is ``None``.
        """
        if status is not None:
            return query.where(users.c.status == UserStatus(status))
        if is_active is not None:
            _statuses = ACTIVE_USER_STATUSES if is_active else INACTIVE_USER_STATUSES
            return query.where(users.c.status.in_(_statuses))
        return query

    @classmethod
    async def load_all(
        cls, context, *,
        domain_name=None,
        group_id=None,
        is_active=None,
        status=None,
        limit=None,
    ) -> Sequence[User]:
        """
        Load users, optionally filtered by domain, group membership and status.

        Non-superadmin requesters (``context['user']['role']``) are always
        restricted to their own domain.
        """
        async with context['dbpool'].acquire() as conn:
            query = cls._select_users([users], group_id)
            if context['user']['role'] != UserRole.SUPERADMIN:
                query = query.where(users.c.domain_name == context['user']['domain_name'])
            if domain_name is not None:
                query = query.where(users.c.domain_name == domain_name)
            query = cls._filter_status(query, status=status, is_active=is_active)
            if limit is not None:
                query = query.limit(limit)
            return [cls.from_row(context, row) async for row in conn.execute(query)]

    @staticmethod
    async def load_count(
        context, *,
        domain_name=None,
        group_id=None,
        is_active=None,
        status=None,
    ) -> int:
        """Count users matching the optional domain/group/status filters."""
        # NOTE(review): unlike load_all, there is no implicit scoping to the
        # requester's domain here; callers must pass domain_name -- confirm.
        async with context['dbpool'].acquire() as conn:
            query = User._select_users(
                [sa.func.count(users.c.uuid)], group_id).as_scalar()
            if domain_name is not None:
                query = query.where(users.c.domain_name == domain_name)
            query = User._filter_status(query, status=status, is_active=is_active)
            result = await conn.execute(query)
            count = await result.fetchone()
            return count[0]

    @classmethod
    async def load_slice(
        cls, context, limit, offset, *,
        domain_name=None,
        group_id=None,
        is_active=None,
        status=None,
        order_key=None,
        order_asc=True,
    ) -> Sequence[User]:
        """Load one page of users, newest first unless *order_key* is given."""
        # NOTE(review): same as load_count -- no implicit domain scoping here.
        async with context['dbpool'].acquire() as conn:
            if order_key is None:
                _ordering = sa.desc(users.c.created_at)
            else:
                _order_func = sa.asc if order_asc else sa.desc
                _ordering = _order_func(getattr(users.c, order_key))
            query = (
                cls._select_users([users], group_id)
                .order_by(_ordering)
                .limit(limit)
                .offset(offset)
            )
            if domain_name is not None:
                query = query.where(users.c.domain_name == domain_name)
            query = cls._filter_status(query, status=status, is_active=is_active)
            return [
                cls.from_row(context, row) async for row in conn.execute(query)
            ]

    @classmethod
    async def batch_load_by_email(
        cls, context, emails=None, *,
        domain_name=None,
        is_active=None,
        status=None,
    ) -> Sequence[Optional[User]]:
        """Batch-load users by email (dataloader entry point) via ``batch_result``."""
        async with context['dbpool'].acquire() as conn:
            query = (
                sa.select([users])
                .select_from(users)
                .where(users.c.email.in_(emails))
            )
            if domain_name is not None:
                query = query.where(users.c.domain_name == domain_name)
            query = cls._filter_status(query, status=status, is_active=is_active)
            return await batch_result(
                context, conn, query, cls,
                emails, lambda row: row['email'],
            )

    @classmethod
    async def batch_load_by_uuid(
        cls, context, user_ids=None, *,
        domain_name=None,
        is_active=None,
        status=None,
    ) -> Sequence[Optional[User]]:
        """Batch-load users by UUID (dataloader entry point) via ``batch_result``."""
        async with context['dbpool'].acquire() as conn:
            query = (
                sa.select([users])
                .select_from(users)
                .where(users.c.uuid.in_(user_ids))
            )
            if domain_name is not None:
                query = query.where(users.c.domain_name == domain_name)
            query = cls._filter_status(query, status=status, is_active=is_active)
            return await batch_result(
                context, conn, query, cls,
                user_ids, lambda row: row['uuid'],
            )
Example #5
0
class Agent(graphene.ObjectType):
    """GraphQL object type for a row of the ``agents`` table plus live statistics."""

    class Meta:
        interfaces = (Item, )

    id = graphene.ID()
    status = graphene.String()
    region = graphene.String()
    scaling_group = graphene.String()
    available_slots = graphene.JSONString()
    occupied_slots = graphene.JSONString()
    addr = graphene.String()
    first_contact = GQLDateTime()
    lost_at = GQLDateTime()
    live_stat = graphene.JSONString()
    version = graphene.String()
    compute_plugins = graphene.JSONString()

    # Legacy fields
    mem_slots = graphene.Int()
    cpu_slots = graphene.Float()
    gpu_slots = graphene.Float()
    tpu_slots = graphene.Float()
    used_mem_slots = graphene.Int()
    used_cpu_slots = graphene.Float()
    used_gpu_slots = graphene.Float()
    used_tpu_slots = graphene.Float()
    cpu_cur_pct = graphene.Float()
    mem_cur_bytes = graphene.Float()

    computations = graphene.List('ai.backend.manager.models.Computation',
                                 status=graphene.String())

    @classmethod
    def from_row(cls, context, row):
        """Build an Agent from an ``agents`` row; ``None`` passes through.

        The legacy ``*_slots`` fields are derived from the ``available_slots``
        / ``occupied_slots`` mappings; memory is reported in MiB.
        """
        if row is None:
            return None
        mega = 2 ** 20
        available = row['available_slots']
        occupied = row['occupied_slots']
        return cls(
            id=row['id'],
            status=row['status'].name,
            region=row['region'],
            scaling_group=row['scaling_group'],
            available_slots=available.to_json(),
            occupied_slots=occupied.to_json(),
            addr=row['addr'],
            first_contact=row['first_contact'],
            lost_at=row['lost_at'],
            version=row['version'],
            compute_plugins=row['compute_plugins'],
            # legacy fields
            mem_slots=BinarySize.from_str(available['mem']) // mega,
            cpu_slots=available['cpu'],
            gpu_slots=available.get('cuda.device', 0),
            tpu_slots=available.get('tpu.device', 0),
            used_mem_slots=BinarySize.from_str(occupied.get('mem', 0)) // mega,
            used_cpu_slots=float(occupied.get('cpu', 0)),
            used_gpu_slots=float(occupied.get('cuda.device', 0)),
            used_tpu_slots=float(occupied.get('tpu.device', 0)),
        )

    async def _get_live_stat(self, info):
        """Fetch and msgpack-decode this agent's live stats from ``redis_stat``.

        Returns ``None`` when nothing is stored under ``str(self.id)``.
        """
        rs = info.context['redis_stat']
        raw = await rs.get(str(self.id), encoding=None)
        if raw is None:
            return None
        return msgpack.unpackb(raw)

    async def resolve_live_stat(self, info):
        return await self._get_live_stat(info)

    async def resolve_cpu_cur_pct(self, info):
        """Current CPU utilization; 0.0 when stats are absent or malformed.

        Previously this indexed into ``None`` (TypeError) for agents without
        stored stats.
        """
        live_stat = await self._get_live_stat(info)
        if live_stat is None:
            return 0.0
        try:
            return float(live_stat['node']['cpu_util']['pct'])
        except (KeyError, TypeError, ValueError):
            return 0.0

    async def resolve_mem_cur_bytes(self, info):
        """Current memory usage; 0.0 when stats are absent or malformed."""
        live_stat = await self._get_live_stat(info)
        if live_stat is None:
            return 0.0
        try:
            return float(live_stat['node']['mem']['current'])
        except (KeyError, TypeError, ValueError):
            return 0.0

    async def resolve_computations(self, info, status=None):
        '''
        Retrieves all children worker sessions run by this agent.
        '''
        manager = info.context['dlmgr']
        loader = manager.get_loader('Computation.by_agent_id', status=status)
        return await loader.load(self.id)

    @staticmethod
    def _apply_filters(query, *, scaling_group=None, status=None):
        """Narrow *query* by scaling group and/or an ``AgentStatus`` member name."""
        if scaling_group is not None:
            query = query.where(agents.c.scaling_group == scaling_group)
        if status is not None:
            query = query.where(agents.c.status == AgentStatus[status])
        return query

    @staticmethod
    async def load_count(context, *, scaling_group=None, status=None):
        """Count agents matching the optional scaling-group/status filters."""
        async with context['dbpool'].acquire() as conn:
            query = (
                sa.select([sa.func.count(agents.c.id)])
                .select_from(agents)
                .as_scalar()
            )
            query = Agent._apply_filters(
                query, scaling_group=scaling_group, status=status)
            result = await conn.execute(query)
            count = await result.fetchone()
            return count[0]

    @staticmethod
    async def load_slice(context,
                         limit,
                         offset,
                         *,
                         scaling_group=None,
                         status=None,
                         order_key=None,
                         order_asc=True):
        """Load one page of agents, ordered by *order_key* (default: ``id``)."""
        async with context['dbpool'].acquire() as conn:
            # TODO: optimization for pagination using subquery, join
            if order_key is None:
                _ordering = agents.c.id
            else:
                _order_func = sa.asc if order_asc else sa.desc
                _ordering = _order_func(getattr(agents.c, order_key))
            query = (
                sa.select([agents])
                .select_from(agents)
                .order_by(_ordering)
                .limit(limit)
                .offset(offset)
            )
            query = Agent._apply_filters(
                query, scaling_group=scaling_group, status=status)
            result = await conn.execute(query)
            rows = await result.fetchall()
            return [Agent.from_row(context, r) for r in rows]

    @staticmethod
    async def load_all(context, *, scaling_group=None, status=None):
        """Load every agent matching the optional scaling-group/status filters."""
        async with context['dbpool'].acquire() as conn:
            query = sa.select([agents]).select_from(agents)
            query = Agent._apply_filters(
                query, scaling_group=scaling_group, status=status)
            result = await conn.execute(query)
            rows = await result.fetchall()
            return [Agent.from_row(context, r) for r in rows]

    @staticmethod
    async def batch_load(context, agent_ids, *, status=None):
        """Batch-load agents by ID; returns a tuple aligned with *agent_ids*.

        IDs without a matching (and status-filtered) row map to ``None``.
        """
        async with context['dbpool'].acquire() as conn:
            query = (
                sa.select([agents])
                .select_from(agents)
                .where(agents.c.id.in_(agent_ids))
                .order_by(agents.c.id)
            )
            query = Agent._apply_filters(query, status=status)
            objs_per_key = OrderedDict((k, None) for k in agent_ids)
            async for row in conn.execute(query):
                objs_per_key[row.id] = Agent.from_row(context, row)
        return tuple(objs_per_key.values())
Example #6
0
class ComputeSession(graphene.ObjectType):
    """GraphQL object type for a compute session (its ``master``-role kernel row)."""

    class Meta:
        interfaces = (Item, )

    # identity
    tag = graphene.String()
    name = graphene.String()
    type = graphene.String()

    # image
    image = graphene.String()     # image for the master
    registry = graphene.String()  # image registry for the master
    cluster_template = graphene.String()

    # ownership
    domain_name = graphene.String()
    group_name = graphene.String()
    group_id = graphene.UUID()
    user_email = graphene.String()
    user_id = graphene.UUID()
    access_key = graphene.String()
    created_user_email = graphene.String()
    created_user_id = graphene.UUID()

    # status
    status = graphene.String()
    status_changed = GQLDateTime()
    status_info = graphene.String()
    created_at = GQLDateTime()
    terminated_at = GQLDateTime()
    starts_at = GQLDateTime()
    startup_command = graphene.String()
    result = graphene.String()

    # resources
    resource_opts = graphene.JSONString()
    scaling_group = graphene.String()
    service_ports = graphene.JSONString()
    mounts = graphene.List(lambda: graphene.String)
    occupied_slots = graphene.JSONString()

    # statistics
    num_queries = BigInt()

    # owned containers (aka kernels)
    containers = graphene.List(lambda: ComputeContainer)

    # relations
    dependencies = graphene.List(lambda: ComputeSession)

    @classmethod
    def parse_row(cls, context, row):
        """Map a kernels+groups.name+users.email row to constructor keyword arguments.

        The row must include ``groups.c.name`` (as ``name``) and
        ``users.c.email`` (as ``email``) in addition to the kernel columns.
        """
        assert row is not None
        return {
            # identity
            'id': row['id'],
            'tag': row['tag'],
            'name': row['sess_id'],
            'type': row['sess_type'].name,

            # image
            'image': row['image'],
            'registry': row['registry'],
            'cluster_template': None,  # TODO: implement

            # ownership
            'domain_name': row['domain_name'],
            'group_name': row['name'],  # group.name (group is omitted since use_labels=True is not used)
            'group_id': row['group_id'],
            'user_email': row['email'],
            'user_id': row['user_uuid'],
            'access_key': row['access_key'],
            'created_user_email': None,  # TODO: implement
            'created_user_id': None,     # TODO: implement

            # status
            'status': row['status'].name,
            'status_changed': row['status_changed'],
            'status_info': row['status_info'],
            'created_at': row['created_at'],
            'terminated_at': row['terminated_at'],
            'starts_at': row['starts_at'],
            'startup_command': row['startup_command'],
            'result': row['result'].name,

            # resources
            'resource_opts': row['resource_opts'],
            'scaling_group': row['scaling_group'],
            'service_ports': row['service_ports'],
            'mounts': row['mounts'],
            'occupied_slots': row['occupied_slots'].to_json(),  # TODO: sum of owned containers

            # statistics
            'num_queries': row['num_queries'],
        }

    @classmethod
    def from_row(cls, context: Mapping[str, Any], row: RowProxy) -> Optional[ComputeSession]:
        if row is None:
            return None
        props = cls.parse_row(context, row)
        return cls(**props)

    async def resolve_containers(
        self,
        info: graphene.ResolveInfo,
    ) -> Iterable[ComputeContainer]:
        manager = info.context['dlmgr']
        loader = manager.get_loader('ComputeContainer.by_session')
        return await loader.load(self.id)

    async def resolve_dependencies(
        self,
        info: graphene.ResolveInfo,
    ) -> Iterable[ComputeSession]:
        manager = info.context['dlmgr']
        loader = manager.get_loader('ComputeSession.by_dependency')
        return await loader.load(self.id)

    @staticmethod
    def _to_status_list(status):
        """Normalize a status filter into a list of ``KernelStatus`` (or ``None``).

        Accepts ``None`` (no filter), a comma-separated string of status
        names, a single ``KernelStatus``, or any iterable of ``KernelStatus``
        members and/or status names. Other input types used to leave
        ``status_list`` unbound and crash with UnboundLocalError.
        """
        if status is None:
            return None
        if isinstance(status, str):
            return [KernelStatus[s] for s in status.split(',')]
        if isinstance(status, KernelStatus):
            return [status]
        return [s if isinstance(s, KernelStatus) else KernelStatus[s] for s in status]

    @classmethod
    async def load_count(cls, context, *,
                         domain_name=None, group_id=None, access_key=None,
                         status=None):
        """Count master kernels (sessions) matching the optional filters."""
        # NOTE(review): unlike load_slice, this does not inner-join groups/users,
        # so the count may include sessions load_slice would omit -- confirm.
        status_list = cls._to_status_list(status)
        async with context['dbpool'].acquire() as conn:
            query = (
                sa.select([sa.func.count(kernels.c.id)])
                .select_from(kernels)
                .where(kernels.c.role == 'master')
                .as_scalar()
            )
            if domain_name is not None:
                query = query.where(kernels.c.domain_name == domain_name)
            if group_id is not None:
                query = query.where(kernels.c.group_id == group_id)
            if access_key is not None:
                query = query.where(kernels.c.access_key == access_key)
            if status_list is not None:
                query = query.where(kernels.c.status.in_(status_list))
            result = await conn.execute(query)
            count = await result.fetchone()
            return count[0]

    @classmethod
    async def load_slice(cls, context, limit, offset, *,
                         domain_name=None, group_id=None, access_key=None,
                         status=None,
                         order_key=None, order_asc=None):
        """Load one page of sessions (master kernels) with group name and user email."""
        status_list = cls._to_status_list(status)
        async with context['dbpool'].acquire() as conn:
            if order_key is None:
                _ordering = DEFAULT_SESSION_ORDERING
            else:
                _order_func = sa.asc if order_asc else sa.desc
                _ordering = [_order_func(getattr(kernels.c, order_key))]
            j = (
                kernels
                .join(groups, groups.c.id == kernels.c.group_id)
                .join(users, users.c.uuid == kernels.c.user_uuid)
            )
            query = (
                sa.select([kernels, groups.c.name, users.c.email])
                .select_from(j)
                .where(kernels.c.role == 'master')
                .order_by(*_ordering)
                .limit(limit)
                .offset(offset)
            )
            if domain_name is not None:
                query = query.where(kernels.c.domain_name == domain_name)
            if group_id is not None:
                query = query.where(kernels.c.group_id == group_id)
            if access_key is not None:
                query = query.where(kernels.c.access_key == access_key)
            if status_list is not None:
                query = query.where(kernels.c.status.in_(status_list))
            return [cls.from_row(context, r) async for r in conn.execute(query)]

    @classmethod
    async def batch_load_by_dependency(cls, context, session_ids):
        """Batch-load, for each of *session_ids*, the sessions it depends on.

        Rows are bucketed by the *dependent* session
        (``kernel_dependencies.kernel_id``), not by the depended-on kernel's
        own ``id``; the groups/users joins supply the ``name``/``email``
        columns that ``parse_row`` requires.
        """
        async with context['dbpool'].acquire() as conn:
            j = (
                kernels
                .join(kernel_dependencies,
                      kernels.c.id == kernel_dependencies.c.depends_on)
                .join(groups, groups.c.id == kernels.c.group_id)
                .join(users, users.c.uuid == kernels.c.user_uuid)
            )
            query = (
                sa.select([
                    kernels, groups.c.name, users.c.email,
                    kernel_dependencies.c.kernel_id.label('dependent_session_id'),
                ])
                .select_from(j)
                .where(
                    (kernels.c.role == 'master') &
                    (kernel_dependencies.c.kernel_id.in_(session_ids))
                )
            )
            return await batch_multiresult(
                context, conn, query, cls,
                session_ids, lambda row: row['dependent_session_id'],
            )

    @classmethod
    async def batch_load_detail(cls, context, session_ids, *,
                                domain_name=None, access_key=None):
        """Batch-load sessions by ID (with group name / user email) via ``batch_result``."""
        async with context['dbpool'].acquire() as conn:
            j = (
                kernels
                .join(groups, groups.c.id == kernels.c.group_id)
                .join(users, users.c.uuid == kernels.c.user_uuid)
            )
            query = (
                sa.select([kernels, groups.c.name, users.c.email])
                .select_from(j)
                .where(
                    (kernels.c.role == 'master') &
                    (kernels.c.id.in_(session_ids))
                ))
            if domain_name is not None:
                query = query.where(kernels.c.domain_name == domain_name)
            if access_key is not None:
                query = query.where(kernels.c.access_key == access_key)
            return await batch_result(
                context, conn, query, cls,
                session_ids, lambda row: row['id'],
            )
Example #7
0
class KeyPair(graphene.ObjectType):
    # Plain columns of the ``keypairs`` table.
    user_id = graphene.String()
    access_key = graphene.String()
    secret_key = graphene.String()
    is_active = graphene.Boolean()
    is_admin = graphene.Boolean()
    resource_policy = graphene.String()
    created_at = GQLDateTime()
    last_used = GQLDateTime()
    concurrency_limit = graphene.Int()
    concurrency_used = graphene.Int()
    rate_limit = graphene.Int()
    num_queries = graphene.Int()

    # Related objects, resolved through the data-loader manager.
    vfolders = graphene.List('ai.backend.manager.models.VirtualFolder')
    compute_sessions = graphene.List(
        'ai.backend.manager.models.ComputeSession',
        status=graphene.String(),
    )

    @classmethod
    def from_row(cls, row):
        """Build a KeyPair from a ``keypairs`` row; ``None`` maps to ``None``."""
        if row is None:
            return None
        columns = (
            'user_id', 'access_key', 'secret_key', 'is_active', 'is_admin',
            'resource_policy', 'created_at', 'last_used',
            'concurrency_limit', 'concurrency_used', 'rate_limit',
            'num_queries',
        )
        return cls(**{column: row[column] for column in columns})

    async def resolve_vfolders(self, info):
        """Load this keypair's virtual folders, keyed by access key."""
        vfolder_loader = info.context['dlmgr'].get_loader('VirtualFolder')
        return await vfolder_loader.load(self.access_key)

    async def resolve_compute_sessions(self, info, status=None):
        """
        Load this keypair's compute sessions, keyed by access key.

        ``status``, when given, must be a ``KernelStatus`` member name;
        an unknown name raises ``KeyError``.
        """
        dlmgr = info.context['dlmgr']
        from . import KernelStatus  # noqa: avoid circular imports
        wanted = None if status is None else KernelStatus[status]
        session_loader = dlmgr.get_loader('ComputeSession', status=wanted)
        return await session_loader.load(self.access_key)

    @staticmethod
    async def batch_load_by_uid(dbpool, user_ids, *, is_active=None):
        """
        Return a tuple with one list of KeyPairs per entry of ``user_ids``
        (same order), optionally restricted by ``is_active``.
        """
        async with dbpool.acquire() as conn:
            query = (
                sa.select('*')
                .select_from(keypairs)
                .where(keypairs.c.user_id.in_(user_ids))
            )
            if is_active is not None:
                query = query.where(keypairs.c.is_active == is_active)
            by_user = OrderedDict((uid, []) for uid in user_ids)
            async for row in conn.execute(query):
                by_user[row.user_id].append(KeyPair.from_row(row))
        return tuple(by_user.values())

    @staticmethod
    async def batch_load_by_ak(dbpool, access_keys):
        """
        Return a tuple with one KeyPair (or ``None`` if absent) per entry of
        ``access_keys``, in the same order.
        """
        async with dbpool.acquire() as conn:
            query = (
                sa.select('*')
                .select_from(keypairs)
                .where(keypairs.c.access_key.in_(access_keys))
            )
            # An access key identifies at most one keypair, so each slot holds
            # a single object (or stays None) instead of a list.
            by_key = OrderedDict.fromkeys(access_keys)
            async for row in conn.execute(query):
                by_key[row.access_key] = KeyPair.from_row(row)
        return tuple(by_key.values())
Example #8
0
class SessionCommons:

    sess_id = graphene.String()
    id = graphene.UUID()
    role = graphene.String()
    lang = graphene.String()

    status = graphene.String()
    status_info = graphene.String()
    created_at = GQLDateTime()
    terminated_at = GQLDateTime()

    agent = graphene.String()
    container_id = graphene.String()

    mem_slot = graphene.Int()
    cpu_slot = graphene.Int()
    gpu_slot = graphene.Int()

    num_queries = graphene.Int()
    cpu_used = graphene.Int()
    mem_max_bytes = graphene.Int()
    mem_cur_bytes = graphene.Int()
    net_rx_bytes = graphene.Int()
    net_tx_bytes = graphene.Int()
    io_read_bytes = graphene.Int()
    io_write_bytes = graphene.Int()
    io_max_scratch_size = graphene.Int()
    io_cur_scratch_size = graphene.Int()

    async def _live_stat(self, info, stat_name, convert):
        """
        Read one field of this kernel's Redis stat hash (keyed by ``str(self.id)``)
        and convert it with ``convert``; a missing field yields 0.
        """
        async with info.context['redis_stat_pool'].get() as redis_conn:
            raw = await redis_conn.hget(str(self.id), stat_name)
            return 0 if raw is None else convert(raw)

    # Each resolver below reads Redis only while the kernel status is in
    # LIVE_STATUS.  Otherwise it falls back to the stored column (None -> 0),
    # or to 0 for the "current" gauges, which have no stored column.

    async def resolve_cpu_used(self, info):
        if self.status in LIVE_STATUS:
            return await self._live_stat(info, 'cpu_used', float)
        return zero_if_none(self.cpu_used)

    async def resolve_mem_max_bytes(self, info):
        if self.status in LIVE_STATUS:
            return await self._live_stat(info, 'mem_max_bytes', int)
        return zero_if_none(self.mem_max_bytes)

    async def resolve_mem_cur_bytes(self, info):
        if self.status in LIVE_STATUS:
            return await self._live_stat(info, 'mem_cur_bytes', int)
        return 0

    async def resolve_net_rx_bytes(self, info):
        if self.status in LIVE_STATUS:
            return await self._live_stat(info, 'net_rx_bytes', int)
        return zero_if_none(self.net_rx_bytes)

    async def resolve_net_tx_bytes(self, info):
        if self.status in LIVE_STATUS:
            return await self._live_stat(info, 'net_tx_bytes', int)
        return zero_if_none(self.net_tx_bytes)

    async def resolve_io_read_bytes(self, info):
        if self.status in LIVE_STATUS:
            return await self._live_stat(info, 'io_read_bytes', int)
        return zero_if_none(self.io_read_bytes)

    async def resolve_io_write_bytes(self, info):
        if self.status in LIVE_STATUS:
            return await self._live_stat(info, 'io_write_bytes', int)
        return zero_if_none(self.io_write_bytes)

    async def resolve_io_max_scratch_size(self, info):
        if self.status in LIVE_STATUS:
            return await self._live_stat(info, 'io_max_scratch_size', int)
        return zero_if_none(self.io_max_scratch_size)

    async def resolve_io_cur_scratch_size(self, info):
        if self.status in LIVE_STATUS:
            return await self._live_stat(info, 'io_cur_scratch_size', int)
        return 0

    @classmethod
    def from_row(cls, row):
        """Build an instance from a kernel row (attribute access); None -> None."""
        if row is None:
            return None
        copied_columns = (
            'sess_id', 'id', 'role', 'lang', 'status', 'status_info',
            'created_at', 'terminated_at', 'agent', 'container_id',
            'mem_slot', 'cpu_slot', 'gpu_slot', 'num_queries',
            # live statistics
            # NOTE: currently graphene always uses resolve methods!
            'cpu_used', 'mem_max_bytes', 'net_rx_bytes', 'net_tx_bytes',
            'io_read_bytes', 'io_write_bytes', 'io_max_scratch_size',
        )
        props = {name: getattr(row, name) for name in copied_columns}
        # "Current" gauges are not read from the row at all.
        props['mem_cur_bytes'] = 0
        props['io_cur_scratch_size'] = 0
        return cls(**props)
Example #9
0
class KeyPair(graphene.ObjectType):
    # GraphQL view of a ``keypairs`` row.  The loaders below join ``keypairs``
    # with ``users`` on ``keypairs.user == users.uuid`` so results can be
    # filtered by the owner's domain.
    class Meta:
        interfaces = (Item, )

    user_id = graphene.String()
    access_key = graphene.String()
    secret_key = graphene.String()
    is_active = graphene.Boolean()
    is_admin = graphene.Boolean()
    resource_policy = graphene.String()
    created_at = GQLDateTime()
    last_used = GQLDateTime()
    concurrency_used = graphene.Int()
    rate_limit = graphene.Int()
    num_queries = graphene.Int()
    user = graphene.UUID()

    ssh_public_key = graphene.String()

    vfolders = graphene.List('ai.backend.manager.models.VirtualFolder')
    compute_sessions = graphene.List(
        'ai.backend.manager.models.ComputeSession',
        status=graphene.String(),
    )

    # Deprecated
    concurrency_limit = graphene.Int(
        deprecation_reason='Moved to KeyPairResourcePolicy object as '
        'max_concurrent_sessions field.')

    @classmethod
    def from_row(
        cls,
        context: Mapping[str, Any],
        row: RowProxy,
    ) -> KeyPair:
        """
        Build a KeyPair from a ``keypairs`` row.

        ``id`` (from the ``Item`` interface) is the access key.
        ``concurrency_limit`` is always 0 because the value now lives in the
        resource policy.  ``context`` is unused here.
        """
        return cls(
            id=row['access_key'],
            user_id=row['user_id'],
            access_key=row['access_key'],
            secret_key=row['secret_key'],
            is_active=row['is_active'],
            is_admin=row['is_admin'],
            resource_policy=row['resource_policy'],
            created_at=row['created_at'],
            last_used=row['last_used'],
            concurrency_limit=0,  # moved to resource policy
            concurrency_used=row['concurrency_used'],
            rate_limit=row['rate_limit'],
            num_queries=row['num_queries'],
            user=row['user'],
            ssh_public_key=row['ssh_public_key'],
        )

    async def resolve_vfolders(self, info):
        """Load this keypair's virtual folders, keyed by access key."""
        manager = info.context['dlmgr']
        loader = manager.get_loader('VirtualFolder')
        return await loader.load(self.access_key)

    async def resolve_compute_sessions(self, info, status=None):
        """
        Load this keypair's compute sessions, keyed by access key.

        ``status``, when given, must be a ``KernelStatus`` member name
        (an unknown name raises ``KeyError``).
        """
        manager = info.context['dlmgr']
        from . import KernelStatus  # noqa: avoid circular imports
        if status is not None:
            status = KernelStatus[status]
        loader = manager.get_loader('ComputeSession', status=status)
        return await loader.load(self.access_key)

    @classmethod
    async def load_all(
        cls,
        context,
        *,
        domain_name=None,
        is_active=None,
        limit=None,
    ) -> Sequence[KeyPair]:
        """
        Load keypairs, optionally filtered by the owner's domain and the
        active flag, and capped at ``limit`` rows.

        No ORDER BY is applied, so the rows kept under ``limit`` are not
        deterministic.
        """
        from .user import users
        async with context['dbpool'].acquire() as conn:
            j = sa.join(keypairs, users, keypairs.c.user == users.c.uuid)
            query = (sa.select([keypairs]).select_from(j))
            if domain_name is not None:
                query = query.where(users.c.domain_name == domain_name)
            if is_active is not None:
                query = query.where(keypairs.c.is_active == is_active)
            if limit is not None:
                query = query.limit(limit)
            return [
                cls.from_row(context, row) async for row in conn.execute(query)
            ]

    @staticmethod
    async def load_count(
        context,
        *,
        domain_name=None,
        email=None,
        is_active=None,
    ) -> int:
        """
        Count keypairs, optionally filtered by the owner's domain, the owner's
        e-mail (stored in ``keypairs.user_id``) and the active flag.
        """
        from .user import users
        async with context['dbpool'].acquire() as conn:
            j = sa.join(keypairs, users, keypairs.c.user == users.c.uuid)
            # The count is executed directly as a statement, so a plain SELECT
            # is used.  Wrapping it in a scalar subquery is unnecessary, and
            # ``as_scalar()`` is deprecated since SQLAlchemy 1.4.
            query = (
                sa.select([sa.func.count(keypairs.c.access_key)])
                .select_from(j)
            )
            if domain_name is not None:
                query = query.where(users.c.domain_name == domain_name)
            if email is not None:
                query = query.where(keypairs.c.user_id == email)
            if is_active is not None:
                query = query.where(keypairs.c.is_active == is_active)
            result = await conn.execute(query)
            count = await result.fetchone()
            return count[0]

    @classmethod
    async def load_slice(
        cls,
        context,
        limit,
        offset,
        *,
        domain_name=None,
        email=None,
        is_active=None,
        order_key=None,
        order_asc=True,
    ) -> Sequence[KeyPair]:
        """
        Load one page (``limit``/``offset``) of keypairs with the same filters
        as ``load_count``.

        Ordering defaults to newest ``created_at`` first.  Otherwise it uses the
        ``keypairs`` column named ``order_key``, ascending if ``order_asc``.
        ``order_key`` is looked up with ``getattr`` on ``keypairs.c``, so an
        unknown column name raises ``AttributeError``.
        """
        from .user import users
        async with context['dbpool'].acquire() as conn:
            if order_key is None:
                _ordering = sa.desc(keypairs.c.created_at)
            else:
                _order_func = sa.asc if order_asc else sa.desc
                _ordering = _order_func(getattr(keypairs.c, order_key))
            j = sa.join(keypairs, users, keypairs.c.user == users.c.uuid)
            query = (
                sa.select([keypairs])
                .select_from(j)
                .order_by(_ordering)
                .limit(limit)
                .offset(offset)
            )
            if domain_name is not None:
                query = query.where(users.c.domain_name == domain_name)
            if email is not None:
                query = query.where(keypairs.c.user_id == email)
            if is_active is not None:
                query = query.where(keypairs.c.is_active == is_active)
            return [
                cls.from_row(context, row) async for row in conn.execute(query)
            ]

    @classmethod
    async def batch_load_by_email(
        cls,
        context,
        user_ids,
        *,
        domain_name=None,
        is_active=None,
    ) -> Sequence[Sequence[Optional[KeyPair]]]:
        """
        Batch-load keypairs grouped by owner e-mail (``keypairs.user_id``),
        with one group per entry of ``user_ids`` via ``batch_multiresult``.
        """
        from .user import users
        async with context['dbpool'].acquire() as conn:
            j = sa.join(keypairs, users, keypairs.c.user == users.c.uuid)
            query = (sa.select([keypairs]).select_from(j).where(
                keypairs.c.user_id.in_(user_ids)))
            if domain_name is not None:
                query = query.where(users.c.domain_name == domain_name)
            if is_active is not None:
                query = query.where(keypairs.c.is_active == is_active)
            return await batch_multiresult(
                context,
                conn,
                query,
                cls,
                user_ids,
                lambda row: row['user_id'],
            )

    @classmethod
    async def batch_load_by_ak(
        cls,
        context,
        access_keys,
        *,
        domain_name=None,
    ) -> Sequence[Optional[KeyPair]]:
        """
        Batch-load keypairs by access key, one result per entry of
        ``access_keys`` via ``batch_result``, optionally restricted to the
        owner's domain.
        """
        from .user import users
        async with context['dbpool'].acquire() as conn:
            j = sa.join(keypairs, users, keypairs.c.user == users.c.uuid)
            query = (sa.select([keypairs]).select_from(j).where(
                keypairs.c.access_key.in_(access_keys)))
            if domain_name is not None:
                query = query.where(users.c.domain_name == domain_name)
            return await batch_result(
                context,
                conn,
                query,
                cls,
                access_keys,
                lambda row: row['access_key'],
            )
Example #10
0
class User(graphene.ObjectType):
    # GraphQL view of a ``users`` row plus the groups the user belongs to.
    # ``password`` is declared but never populated by ``from_row``.
    uuid = graphene.UUID()
    username = graphene.String()
    email = graphene.String()
    password = graphene.String()
    need_password_change = graphene.Boolean()
    full_name = graphene.String()
    description = graphene.String()
    is_active = graphene.Boolean()
    created_at = GQLDateTime()
    domain_name = graphene.String()
    role = graphene.String()
    # Dynamic fields
    groups = graphene.List(lambda: UserGroup)

    @classmethod
    def from_row(cls, row):
        """
        Build a User from a ``users`` row, or return ``None`` for ``None``.

        If the row also carries non-NULL group columns ``id`` and ``name`` (as
        produced by ``_select_with_groups``), that group becomes the first
        entry of ``groups``.  Otherwise ``groups`` is ``None``.
        """
        if row is None:
            return None
        has_group = (
            'id' in row and row['id'] is not None and
            'name' in row and row['name'] is not None
        )
        groups = [UserGroup(id=row['id'], name=row['name'])] if has_group else None
        return cls(
            uuid=row['uuid'],
            username=row['username'],
            email=row['email'],
            need_password_change=row['need_password_change'],
            full_name=row['full_name'],
            description=row['description'],
            is_active=row['is_active'],
            created_at=row['created_at'],
            domain_name=row['domain_name'],
            role=row['role'],
            # Dynamic fields
            groups=groups,  # group information
        )

    @staticmethod
    def _select_with_groups():
        """
        Return ``(query, groups_table)``.

        ``query`` selects all ``users`` columns plus group ``name`` and ``id``.
        It LEFT OUTER JOINs through the user/group association table, so users
        without any group still appear (with NULL group columns), and a user in
        N groups appears in N rows.
        """
        from .group import groups, association_groups_users as agus
        j = (
            users
            .join(agus, agus.c.user_id == users.c.uuid, isouter=True)
            .join(groups, agus.c.group_id == groups.c.id, isouter=True)
        )
        query = sa.select([users, groups.c.name, groups.c.id]).select_from(j)
        return query, groups

    @staticmethod
    def _add_group(user, row):
        """Append the group carried by ``row`` to an already-built ``user``."""
        if row['id'] is None or row['name'] is None:
            # Outer-join row without a group: nothing to attach.
            return
        if user.groups is None:
            user.groups = []
        user.groups.append(UserGroup(id=row['id'], name=row['name']))

    @staticmethod
    async def load_all(context,
                       *,
                       domain_name=None,
                       group_id=None,
                       is_active=None):
        '''
        Load user's information. Group names associated with the user are also returned.

        Non-superadmin callers only see users of their own domain.  Rows are
        merged per e-mail so each user appears once with all of its groups.
        '''
        query, groups = User._select_with_groups()
        if context['user']['role'] != UserRole.SUPERADMIN:
            query = query.where(
                users.c.domain_name == context['user']['domain_name'])
        if domain_name is not None:
            query = query.where(users.c.domain_name == domain_name)
        if group_id is not None:
            query = query.where(groups.c.id == group_id)
        if is_active is not None:
            query = query.where(users.c.is_active == is_active)
        objs_per_key = OrderedDict()
        async with context['dbpool'].acquire() as conn:
            async for row in conn.execute(query):
                user = objs_per_key.get(row['email'])
                if user is None:
                    objs_per_key[row['email']] = User.from_row(row)
                else:
                    # Same user already seen: only add the extra group.
                    User._add_group(user, row)
        return list(objs_per_key.values())

    @staticmethod
    async def batch_load_by_email(context,
                                  emails=None,
                                  *,
                                  domain_name=None,
                                  is_active=None):
        """
        Batch-load users by e-mail for a data loader.

        Returns a tuple aligned with ``emails``.  Each entry is the matching
        User (with all of its groups) or ``None``.  ``domain_name`` and
        ``is_active`` narrow the match when given.
        """
        query, _ = User._select_with_groups()
        query = query.where(users.c.email.in_(emails))
        if domain_name is not None:
            query = query.where(users.c.domain_name == domain_name)
        if is_active is not None:
            query = query.where(users.c.is_active == is_active)
        # For each email, there is only one user, so each slot holds a single
        # object (or stays None) rather than a list.
        objs_per_key = OrderedDict.fromkeys(emails)
        async with context['dbpool'].acquire() as conn:
            async for row in conn.execute(query):
                key = row['email']
                user = objs_per_key[key]
                if user is None:
                    objs_per_key[key] = User.from_row(row)
                else:
                    User._add_group(user, row)
        return tuple(objs_per_key.values())

    @staticmethod
    async def batch_load_by_uuid(context,
                                 user_ids=None,
                                 *,
                                 domain_name=None,
                                 is_active=None):
        """
        Batch-load users by UUID for a data loader.

        Returns a tuple aligned with ``user_ids``.  Each entry is the matching
        User (with all of its groups) or ``None``.  Rows are matched back via
        ``str(row['uuid'])``, so ``user_ids`` are expected to be strings.
        """
        query, _ = User._select_with_groups()
        query = query.where(users.c.uuid.in_(user_ids))
        if domain_name is not None:
            query = query.where(users.c.domain_name == domain_name)
        if is_active is not None:
            query = query.where(users.c.is_active == is_active)
        # For each uuid, there is only one user, so each slot holds a single
        # object (or stays None) rather than a list.
        objs_per_key = OrderedDict.fromkeys(user_ids)
        async with context['dbpool'].acquire() as conn:
            async for row in conn.execute(query):
                key = str(row['uuid'])
                user = objs_per_key[key]
                if user is None:
                    objs_per_key[key] = User.from_row(row)
                else:
                    User._add_group(user, row)
        return tuple(objs_per_key.values())
Example #11
0
class Group(graphene.ObjectType):
    # GraphQL view of a ``groups`` row.
    id = graphene.UUID()
    name = graphene.String()
    description = graphene.String()
    is_active = graphene.Boolean()
    created_at = GQLDateTime()
    modified_at = GQLDateTime()
    domain_name = graphene.String()
    total_resource_slots = graphene.JSONString()
    allowed_vfolder_hosts = graphene.List(lambda: graphene.String)
    integration_id = graphene.String()

    @classmethod
    def from_row(cls, row):
        """
        Build a Group from a ``groups`` row (``None`` -> ``None``).

        ``total_resource_slots`` is serialized with its ``to_json()`` method;
        every other field is copied as-is.
        """
        if row is None:
            return None
        copied = ('id', 'name', 'description', 'is_active', 'created_at',
                  'modified_at', 'domain_name')
        props = {key: row[key] for key in copied}
        props['total_resource_slots'] = row['total_resource_slots'].to_json()
        props['allowed_vfolder_hosts'] = row['allowed_vfolder_hosts']
        props['integration_id'] = row['integration_id']
        return cls(**props)

    @staticmethod
    async def load_all(context, *, domain_name=None, is_active=None):
        """List groups, optionally filtered by domain and active flag."""
        async with context['dbpool'].acquire() as conn:
            query = sa.select([groups]).select_from(groups)
            if domain_name is not None:
                query = query.where(groups.c.domain_name == domain_name)
            if is_active is not None:
                query = query.where(groups.c.is_active == is_active)
            found = [Group.from_row(row) async for row in conn.execute(query)]
        return found

    @staticmethod
    async def batch_load_by_id(context, ids, *, domain_name=None):
        """
        Return a list aligned with ``ids`` holding the matching Group or
        ``None``.  Rows are matched back via ``str(row.id)``.
        """
        async with context['dbpool'].acquire() as conn:
            query = (
                sa.select([groups])
                .select_from(groups)
                .where(groups.c.id.in_(ids))
            )
            if domain_name is not None:
                query = query.where(groups.c.domain_name == domain_name)
            by_id = OrderedDict.fromkeys(ids)
            async for row in conn.execute(query):
                by_id[str(row.id)] = Group.from_row(row)
        return list(by_id.values())

    @staticmethod
    async def get_groups_for_user(context, user_id):
        """List the groups linked to ``user_id`` through the association table."""
        async with context['dbpool'].acquire() as conn:
            membership = sa.join(
                groups, association_groups_users,
                groups.c.id == association_groups_users.c.group_id,
            )
            query = (
                sa.select([groups])
                .select_from(membership)
                .where(association_groups_users.c.user_id == user_id)
            )
            return [Group.from_row(row) async for row in conn.execute(query)]
Example #12
0
class Domain(graphene.ObjectType):
    # GraphQL view of a ``domains`` row plus its scaling groups.
    name = graphene.String()
    description = graphene.String()
    is_active = graphene.Boolean()
    created_at = GQLDateTime()
    modified_at = GQLDateTime()
    total_resource_slots = graphene.JSONString()
    allowed_vfolder_hosts = graphene.List(lambda: graphene.String)
    allowed_docker_registries = graphene.List(lambda: graphene.String)
    integration_id = graphene.String()
    # Dynamic fields.
    scaling_groups = graphene.List(lambda: graphene.String)

    @classmethod
    def from_row(cls, row):
        """
        Build a Domain from a ``domains`` row (``None`` -> ``None``).

        If the row carries a non-NULL ``scaling_group`` column, it seeds
        ``scaling_groups``; otherwise the list starts empty.
        """
        if row is None:
            return None
        # The loaders LEFT OUTER JOIN ``sgroups_for_domains``, so a domain
        # without scaling groups comes back with scaling_group = NULL.  It
        # must become [] rather than [None].
        sgroup = row['scaling_group'] if 'scaling_group' in row else None
        return cls(
            name=row['name'],
            description=row['description'],
            is_active=row['is_active'],
            created_at=row['created_at'],
            modified_at=row['modified_at'],
            total_resource_slots=row['total_resource_slots'].to_json(),
            allowed_vfolder_hosts=row['allowed_vfolder_hosts'],
            allowed_docker_registries=row['allowed_docker_registries'],
            integration_id=row['integration_id'],
            # Dynamic fields.
            scaling_groups=[sgroup] if sgroup is not None else [],
        )

    @staticmethod
    async def load_all(context, *, is_active=None):
        """
        List domains (optionally filtered by ``is_active``), one Domain per
        name with all of its scaling groups merged in.
        """
        async with context['dbpool'].acquire() as conn:
            j = sa.join(domains, sgroups_for_domains,
                        domains.c.name == sgroups_for_domains.c.domain,
                        isouter=True)
            query = (sa.select([domains, sgroups_for_domains.c.scaling_group])
                       .select_from(j))
            if is_active is not None:
                query = query.where(domains.c.is_active == is_active)
            objs_per_key = OrderedDict()
            async for row in conn.execute(query):
                domain = objs_per_key.get(row.name)
                if domain is None:
                    objs_per_key[row.name] = Domain.from_row(row)
                elif row.scaling_group is not None:
                    # Same domain already saved: just add the sgroup.
                    domain.scaling_groups.append(row.scaling_group)
            objs = list(objs_per_key.values())
        return objs

    @staticmethod
    async def batch_load_by_name(context, names=None, *, is_active=None):
        """
        Batch-load domains by name for a data loader.

        Returns a tuple aligned with ``names``.  Each entry is the matching
        Domain (with all of its scaling groups) or ``None``.  ``is_active``
        narrows the match when given.
        """
        async with context['dbpool'].acquire() as conn:
            j = sa.join(domains, sgroups_for_domains,
                        domains.c.name == sgroups_for_domains.c.domain,
                        isouter=True)
            query = (sa.select([domains, sgroups_for_domains.c.scaling_group])
                       .select_from(j)
                       .where(domains.c.name.in_(names)))
            if is_active is not None:
                query = query.where(domains.c.is_active == is_active)
            # For each name, there is only one domain, so each slot holds a
            # single object (or stays None) rather than a list.
            objs_per_key = OrderedDict.fromkeys(names)
            async for row in conn.execute(query):
                domain = objs_per_key[row.name]
                if domain is None:
                    objs_per_key[row.name] = Domain.from_row(row)
                elif row.scaling_group is not None:
                    domain.scaling_groups.append(row.scaling_group)
        return tuple(objs_per_key.values())
Example #13
0
class SessionCommons:
    # Shared GraphQL fields and resolvers for compute-session objects built
    # from kernel rows (see ``parse_row``).
    sess_id = graphene.String()
    id = graphene.ID()
    role = graphene.String()
    image = graphene.String()
    registry = graphene.String()
    domain_name = graphene.String()
    group_name = graphene.String()
    group_id = graphene.UUID()
    scaling_group = graphene.String()
    user_uuid = graphene.UUID()
    access_key = graphene.String()

    status = graphene.String()
    status_info = graphene.String()
    created_at = GQLDateTime()
    terminated_at = GQLDateTime()

    # hidable fields by configuration
    agent = graphene.String()
    container_id = graphene.String()

    service_ports = graphene.JSONString()

    occupied_slots = graphene.JSONString()
    occupied_shares = graphene.JSONString()
    mounts = graphene.List(lambda: graphene.List(lambda: graphene.String))

    num_queries = BigInt()
    live_stat = graphene.JSONString()
    last_stat = graphene.JSONString()

    user_email = graphene.String()

    # Legacy fields
    lang = graphene.String()
    mem_slot = graphene.Int()
    cpu_slot = graphene.Float()
    gpu_slot = graphene.Float()
    tpu_slot = graphene.Float()
    cpu_used = BigInt()
    cpu_using = graphene.Float()
    mem_max_bytes = BigInt()
    mem_cur_bytes = BigInt()
    net_rx_bytes = BigInt()
    net_tx_bytes = BigInt()
    io_read_bytes = BigInt()
    io_write_bytes = BigInt()
    io_max_scratch_size = BigInt()
    io_cur_scratch_size = BigInt()

    @classmethod
    async def _resolve_live_stat(cls, redis_stat, kernel_id):
        """Fetch the msgpack-encoded live stat for ``kernel_id`` from Redis, or None."""
        cstat = await redis_stat.get(kernel_id, encoding=None)
        if cstat is not None:
            cstat = msgpack.unpackb(cstat)
        return cstat

    async def resolve_live_stat(self, info):
        rs = info.context['redis_stat']
        return await type(self)._resolve_live_stat(rs, str(self.id))

    @staticmethod
    def _coerce_metric(value, convert_type):
        """
        Convert a stored metric value with ``convert_type`` (``int``/``float``).

        If ``convert_type`` rejects the value with ``ValueError`` (e.g.
        ``int('1234.5')`` or ``int('1e3')``), the value is parsed as a
        :class:`decimal.Decimal` first, so ``int`` truncates toward zero.
        Values that are not numeric still raise ``ValueError``.
        """
        try:
            return convert_type(value)
        except ValueError:
            pass
        from decimal import Decimal, InvalidOperation
        text = value.decode() if isinstance(value, (bytes, bytearray)) else value
        try:
            return convert_type(Decimal(text))
        except (InvalidOperation, OverflowError):
            raise ValueError(f'cannot convert metric value {value!r}') from None

    async def _resolve_legacy_metric(self, info, metric_key, metric_field, convert_type):
        """
        Resolve one legacy scalar from the structured stat.

        Reads ``stat[metric_key][metric_field]`` from ``self.last_stat`` when
        the status is not in ``LIVE_STATUS``, and from the live stat in Redis
        otherwise.  Any missing level yields ``convert_type(0)``.  Returns
        ``None`` for objects that have no ``status`` attribute.
        """
        if not hasattr(self, 'status'):
            return None
        rs = info.context['redis_stat']
        # NOTE(review): ``parse_row`` stores ``status`` as the enum *name*
        # (a str).  Confirm LIVE_STATUS holds names as well; if it holds enum
        # members this test never matches and live kernels are always served
        # from ``last_stat``.
        if self.status not in LIVE_STATUS:
            stat = self.last_stat
        else:
            stat = await type(self)._resolve_live_stat(rs, str(self.id))
        if stat is None:
            return convert_type(0)
        metric = stat.get(metric_key)
        if metric is None:
            return convert_type(0)
        value = metric.get(metric_field)
        if value is None:
            return convert_type(0)
        return self._coerce_metric(value, convert_type)

    async def resolve_cpu_used(self, info):
        return await self._resolve_legacy_metric(info, 'cpu_used', 'current', float)

    async def resolve_cpu_using(self, info):
        return await self._resolve_legacy_metric(info, 'cpu_util', 'pct', float)

    async def resolve_mem_max_bytes(self, info):
        return await self._resolve_legacy_metric(info, 'mem', 'stats.max', int)

    async def resolve_mem_cur_bytes(self, info):
        return await self._resolve_legacy_metric(info, 'mem', 'current', int)

    async def resolve_net_rx_bytes(self, info):
        return await self._resolve_legacy_metric(info, 'net_rx', 'stats.rate', int)

    async def resolve_net_tx_bytes(self, info):
        return await self._resolve_legacy_metric(info, 'net_tx', 'stats.rate', int)

    async def resolve_io_read_bytes(self, info):
        return await self._resolve_legacy_metric(info, 'io_read', 'current', int)

    async def resolve_io_write_bytes(self, info):
        return await self._resolve_legacy_metric(info, 'io_write', 'current', int)

    async def resolve_io_max_scratch_size(self, info):
        return await self._resolve_legacy_metric(info, 'io_scratch_size', 'stats.max', int)

    async def resolve_io_cur_scratch_size(self, info):
        return await self._resolve_legacy_metric(info, 'io_scratch_size', 'current', int)

    @classmethod
    def parse_row(cls, context, row):
        """
        Map a kernel row (joined with group ``name`` and user ``email``) to
        constructor kwargs.

        ``agent``/``container_id`` are nulled for non-superadmins when
        ``config['manager']['hide-agents']`` is set.  Legacy metric fields are
        placeholders; their resolvers compute the real values.
        """
        assert row is not None
        from .user import UserRole
        mega = 2 ** 20
        is_superadmin = (context['user']['role'] == UserRole.SUPERADMIN)
        if is_superadmin:
            hide_agents = False
        else:
            hide_agents = context['config']['manager']['hide-agents']
        return {
            'sess_id': row['sess_id'],
            'id': row['id'],
            'role': row['role'],
            'image': row['image'],
            'registry': row['registry'],
            'domain_name': row['domain_name'],
            'group_name': row['name'],  # group.name (group is omitted since use_labels=True is not used)
            'group_id': row['group_id'],
            'scaling_group': row['scaling_group'],
            'user_uuid': row['user_uuid'],
            'access_key': row['access_key'],
            'status': row['status'].name,
            'status_info': row['status_info'],
            'created_at': row['created_at'],
            'terminated_at': row['terminated_at'],
            'service_ports': row['service_ports'],
            'occupied_slots': row['occupied_slots'].to_json(),
            'occupied_shares': row['occupied_shares'],
            'mounts': row['mounts'],
            'num_queries': row['num_queries'],
            # optionally hidden
            'agent': row['agent'] if not hide_agents else None,
            'container_id': row['container_id'] if not hide_agents else None,
            # live_stat is resolved by Graphene
            'last_stat': row['last_stat'],
            'user_email': row['email'],
            # Legacy fields
            # NOTE: currently graphene always uses resolve methods!
            'cpu_used': 0,
            'mem_max_bytes': 0,
            'mem_cur_bytes': 0,
            'net_rx_bytes': 0,
            'net_tx_bytes': 0,
            'io_read_bytes': 0,
            'io_write_bytes': 0,
            'io_max_scratch_size': 0,
            'io_cur_scratch_size': 0,
            'lang': row['image'],
            'mem_slot': BinarySize.from_str(
                row['occupied_slots'].get('mem', 0)) // mega,
            'cpu_slot': float(row['occupied_slots'].get('cpu', 0)),
            'gpu_slot': float(row['occupied_slots'].get('cuda.device', 0)),
            'tpu_slot': float(row['occupied_slots'].get('tpu.device', 0)),
        }

    @classmethod
    def from_row(cls, context, row):
        """Build an instance via ``parse_row``; ``None`` maps to ``None``."""
        if row is None:
            return None
        props = cls.parse_row(context, row)
        return cls(**props)
class ScalingGroup(graphene.ObjectType):
    # GraphQL view of a ``scaling_groups`` row.  Loaders either scan the
    # table directly or reach it through an association table whose
    # ``scaling_group`` column is joined against ``scaling_groups.name``.
    name = graphene.String()
    description = graphene.String()
    is_active = graphene.Boolean()
    created_at = GQLDateTime()
    driver = graphene.String()
    driver_opts = graphene.JSONString()
    scheduler = graphene.String()
    scheduler_opts = graphene.JSONString()

    @classmethod
    def from_row(cls, context, row):
        """Build an instance from a ``scaling_groups`` row; ``None`` stays ``None``."""
        if row is None:
            return None
        # Every exposed field maps to the row column of the same name.
        columns = (
            'name', 'description', 'is_active', 'created_at',
            'driver', 'driver_opts', 'scheduler', 'scheduler_opts',
        )
        return cls(**{column: row[column] for column in columns})

    @classmethod
    async def _load_through(cls, context, assoc_table, criterion, is_active):
        """Load scaling groups joined via *assoc_table* and matching *criterion*.

        *assoc_table*'s ``scaling_group`` column is joined against
        ``scaling_groups.name``.  When *is_active* is not ``None`` the result
        is further restricted to ``scaling_groups.is_active == is_active``.
        """
        async with context['dbpool'].acquire() as conn:
            joined = sa.join(
                scaling_groups, assoc_table,
                scaling_groups.c.name == assoc_table.c.scaling_group)
            query = (
                sa.select([scaling_groups])
                .select_from(joined)
                .where(criterion)
            )
            if is_active is not None:
                query = query.where(scaling_groups.c.is_active == is_active)
            return [cls.from_row(context, r) async for r in conn.execute(query)]

    @classmethod
    async def load_all(cls, context, *, is_active=None):
        """Load every scaling group, optionally only those with *is_active*."""
        async with context['dbpool'].acquire() as conn:
            stmt = sa.select([scaling_groups]).select_from(scaling_groups)
            if is_active is not None:
                stmt = stmt.where(scaling_groups.c.is_active == is_active)
            return [cls.from_row(context, r) async for r in conn.execute(stmt)]

    @classmethod
    async def load_by_domain(cls, context, domain, *, is_active=None):
        """Load scaling groups associated with *domain*."""
        return await cls._load_through(
            context, sgroups_for_domains,
            sgroups_for_domains.c.domain == domain, is_active)

    @classmethod
    async def load_by_group(cls, context, group, *, is_active=None):
        """Load scaling groups associated with *group*."""
        return await cls._load_through(
            context, sgroups_for_groups,
            sgroups_for_groups.c.group == group, is_active)

    @classmethod
    async def load_by_keypair(cls, context, access_key, *, is_active=None):
        """Load scaling groups associated with the keypair *access_key*."""
        return await cls._load_through(
            context, sgroups_for_keypairs,
            sgroups_for_keypairs.c.access_key == access_key, is_active)

    @classmethod
    async def batch_load_by_name(cls, context, names):
        """Batch-load scaling groups by name.

        Result alignment is delegated to ``batch_result``, keyed on the
        row's ``name`` column.
        """
        async with context['dbpool'].acquire() as conn:
            wanted = scaling_groups.c.name.in_(names)
            stmt = sa.select([scaling_groups]).select_from(scaling_groups).where(wanted)
            return await batch_result(
                context, conn, stmt, cls,
                names, lambda r: r['name'],
            )
Example #15
0
class SessionCommons:
    """Field definitions and resolvers shared by compute-session GraphQL types.

    Usage counters are resolved per request.  While ``status`` is in
    ``LIVE_STATUS`` they are read from the ``redis_stat`` client in the
    GraphQL context (hash key ``str(self.id)``).  Otherwise the value stored
    on the object is reported through ``zero_if_none``, except for the
    ``mem_cur_bytes`` and ``io_cur_scratch_size`` gauges, which report 0.
    """

    sess_id = graphene.String()
    id = graphene.UUID()
    role = graphene.String()
    lang = graphene.String()

    status = graphene.String()
    status_info = graphene.String()
    created_at = GQLDateTime()
    terminated_at = GQLDateTime()

    agent = graphene.String()
    container_id = graphene.String()

    mem_slot = graphene.Int()
    cpu_slot = graphene.Float()
    gpu_slot = graphene.Float()

    num_queries = graphene.Int()
    cpu_used = graphene.Int()
    mem_max_bytes = graphene.Int()
    mem_cur_bytes = graphene.Int()
    net_rx_bytes = graphene.Int()
    net_tx_bytes = graphene.Int()
    io_read_bytes = graphene.Int()
    io_write_bytes = graphene.Int()
    io_max_scratch_size = graphene.Int()
    io_cur_scratch_size = graphene.Int()

    @staticmethod
    async def _live_stat(obj, info, stat_key, convert, *, report_stored):
        """Resolve the usage counter *stat_key* for *obj*.

        Returns ``None`` when *obj* has no ``status`` attribute.  For a
        non-live status, returns ``zero_if_none(obj.<stat_key>)`` when
        *report_stored* is true and 0 otherwise.  For a live status, reads
        the field from Redis and applies *convert*, or returns 0 if absent.
        """
        # Called as SessionCommons._live_stat(self, ...) rather than through
        # self, so it does not rely on the resolver root having this method.
        if not hasattr(obj, 'status'):
            return None
        if obj.status not in LIVE_STATUS:
            if report_stored:
                return zero_if_none(getattr(obj, stat_key))
            return 0
        redis_stat = info.context['redis_stat']
        raw = await redis_stat.hget(str(obj.id), stat_key)
        return 0 if raw is None else convert(raw)

    async def resolve_cpu_used(self, info):
        return await SessionCommons._live_stat(
            self, info, 'cpu_used', float, report_stored=True)

    async def resolve_mem_max_bytes(self, info):
        return await SessionCommons._live_stat(
            self, info, 'mem_max_bytes', int, report_stored=True)

    async def resolve_mem_cur_bytes(self, info):
        return await SessionCommons._live_stat(
            self, info, 'mem_cur_bytes', int, report_stored=False)

    async def resolve_net_rx_bytes(self, info):
        return await SessionCommons._live_stat(
            self, info, 'net_rx_bytes', int, report_stored=True)

    async def resolve_net_tx_bytes(self, info):
        return await SessionCommons._live_stat(
            self, info, 'net_tx_bytes', int, report_stored=True)

    async def resolve_io_read_bytes(self, info):
        return await SessionCommons._live_stat(
            self, info, 'io_read_bytes', int, report_stored=True)

    async def resolve_io_write_bytes(self, info):
        return await SessionCommons._live_stat(
            self, info, 'io_write_bytes', int, report_stored=True)

    async def resolve_io_max_scratch_size(self, info):
        return await SessionCommons._live_stat(
            self, info, 'io_max_scratch_size', int, report_stored=True)

    async def resolve_io_cur_scratch_size(self, info):
        return await SessionCommons._live_stat(
            self, info, 'io_cur_scratch_size', int, report_stored=False)

    @classmethod
    def parse_row(cls, row):
        """Map a session row to constructor keyword arguments.

        Every key is copied from the row column of the same name, except
        ``mem_cur_bytes`` and ``io_cur_scratch_size``, which are set to 0.
        NOTE: currently graphene always uses the resolve methods for the
        live statistics, so these values are only the stored fallbacks.
        """
        assert row is not None
        always_zero = ('mem_cur_bytes', 'io_cur_scratch_size')
        keys = (
            'sess_id', 'id', 'role', 'lang',
            'status', 'status_info', 'created_at', 'terminated_at',
            'agent', 'container_id',
            'mem_slot', 'cpu_slot', 'gpu_slot',
            'num_queries',
            # live statistics
            'cpu_used', 'mem_max_bytes', 'mem_cur_bytes',
            'net_rx_bytes', 'net_tx_bytes',
            'io_read_bytes', 'io_write_bytes',
            'io_max_scratch_size', 'io_cur_scratch_size',
        )
        return {key: 0 if key in always_zero else row[key] for key in keys}

    @classmethod
    def from_row(cls, row):
        """Build an instance from a session row, or return ``None`` for no row."""
        return None if row is None else cls(**cls.parse_row(row))
Example #16
0
class KeyPairResourcePolicy(graphene.ObjectType):
    # GraphQL view of a ``keypair_resource_policies`` row.  The batch_load_*
    # helpers return tuples aligned with their input keys, holding ``None``
    # where nothing matched.
    name = graphene.String()
    created_at = GQLDateTime()
    default_for_unspecified = graphene.String()
    total_resource_slots = graphene.JSONString()
    max_concurrent_sessions = graphene.Int()
    max_containers_per_session = graphene.Int()
    idle_timeout = BigInt()
    max_vfolder_count = graphene.Int()
    max_vfolder_size = BigInt()
    allowed_vfolder_hosts = graphene.List(lambda: graphene.String)

    @classmethod
    def from_row(cls, context, row):
        """Build an instance from a policy row, or return ``None`` for no row.

        ``default_for_unspecified`` is exposed via its ``.name`` and
        ``total_resource_slots`` via its ``.to_json()``; other fields are
        copied from the columns of the same name.
        """
        if row is None:
            return None
        return cls(
            name=row['name'],
            created_at=row['created_at'],
            default_for_unspecified=row['default_for_unspecified'].name,
            total_resource_slots=row['total_resource_slots'].to_json(),
            max_concurrent_sessions=row['max_concurrent_sessions'],
            max_containers_per_session=row['max_containers_per_session'],
            idle_timeout=row['idle_timeout'],
            max_vfolder_count=row['max_vfolder_count'],
            max_vfolder_size=row['max_vfolder_size'],
            allowed_vfolder_hosts=row['allowed_vfolder_hosts'],
        )

    @classmethod
    async def load_all(cls, context):
        """Load every keypair resource policy."""
        async with context['dbpool'].acquire() as conn:
            query = (sa.select([keypair_resource_policies])
                     .select_from(keypair_resource_policies))
            result = await conn.execute(query)
            rows = await result.fetchall()
            return [cls.from_row(context, r) for r in rows]

    @classmethod
    async def load_all_user(cls, context, access_key):
        """Load the policies of all keypairs owned by *access_key*'s user.

        Returns an empty list when *access_key* matches no keypair.
        """
        async with context['dbpool'].acquire() as conn:
            query = (sa.select([keypairs.c.user_id])
                     .select_from(keypairs)
                     .where(keypairs.c.access_key == access_key))
            result = await conn.execute(query)
            row = await result.fetchone()
            if row is None:
                # Unknown access key: there is no owner to look policies up
                # for (indexing ``None`` here would raise TypeError).
                return []
            user_id = row['user_id']
            j = sa.join(
                keypairs, keypair_resource_policies,
                keypairs.c.resource_policy == keypair_resource_policies.c.name)
            query = (sa.select([keypair_resource_policies])
                     .select_from(j)
                     .where(keypairs.c.user_id == user_id))
            result = await conn.execute(query)
            rows = await result.fetchall()
            return [cls.from_row(context, r) for r in rows]

    @classmethod
    async def batch_load_by_name(cls, context, names):
        """Batch-load policies by name; one slot per name, ``None`` if missing."""
        async with context['dbpool'].acquire() as conn:
            query = (sa.select([keypair_resource_policies])
                     .select_from(keypair_resource_policies)
                     .where(keypair_resource_policies.c.name.in_(names))
                     .order_by(keypair_resource_policies.c.name))
            objs_per_key = OrderedDict((k, None) for k in names)
            async for row in conn.execute(query):
                objs_per_key[row['name']] = cls.from_row(context, row)
        return tuple(objs_per_key.values())

    @classmethod
    async def batch_load_by_name_user(cls, context, names):
        """Like ``batch_load_by_name`` but only policies used by the
        requesting keypair (``context['access_key']``) are returned."""
        async with context['dbpool'].acquire() as conn:
            access_key = context['access_key']
            j = sa.join(
                keypairs, keypair_resource_policies,
                keypairs.c.resource_policy == keypair_resource_policies.c.name)
            query = (sa.select([keypair_resource_policies])
                     .select_from(j)
                     .where((keypair_resource_policies.c.name.in_(names))
                            & (keypairs.c.access_key == access_key))
                     .order_by(keypair_resource_policies.c.name))
            objs_per_key = OrderedDict((k, None) for k in names)
            async for row in conn.execute(query):
                objs_per_key[row['name']] = cls.from_row(context, row)
        return tuple(objs_per_key.values())

    @classmethod
    async def batch_load_by_ak(cls, context, access_keys):
        """Batch-load the resource policy applied to each access key.

        Returns a tuple aligned with *access_keys*: one slot per (distinct)
        access key, holding its policy or ``None`` when the key has no
        matching keypair/policy.  This mirrors the other batch loaders here
        and the DataLoader requirement of one result per requested key.
        """
        async with context['dbpool'].acquire() as conn:
            j = sa.join(
                keypairs, keypair_resource_policies,
                keypairs.c.resource_policy == keypair_resource_policies.c.name)
            # Select the access key too: several keypairs may share a policy,
            # so results must be attributed per access key, not per policy name.
            query = (sa.select([keypair_resource_policies, keypairs.c.access_key])
                     .select_from(j)
                     .where(keypairs.c.access_key.in_(access_keys)))
            objs_per_key = OrderedDict((k, None) for k in access_keys)
            async for row in conn.execute(query):
                objs_per_key[row['access_key']] = cls.from_row(context, row)
        return tuple(objs_per_key.values())
Example #17
0
class VirtualFolder(graphene.ObjectType):
    # GraphQL view of a ``vfolders`` row.  Listing queries join ``users`` on
    # the owner so they can be scoped by the owner's domain.
    class Meta:
        interfaces = (Item, )

    host = graphene.String()
    name = graphene.String()
    user = graphene.UUID()       # User.id
    group = graphene.UUID()      # Group.id
    creator = graphene.String()  # User.email
    unmanaged_path = graphene.String()
    usage_mode = graphene.String()
    permission = graphene.String()
    ownership_type = graphene.String()
    max_files = graphene.Int()
    max_size = graphene.Int()
    created_at = GQLDateTime()
    last_used = GQLDateTime()

    num_files = graphene.Int()
    cur_size = BigInt()
    # num_attached = graphene.Int()
    cloneable = graphene.Boolean()

    @classmethod
    def from_row(cls, context, row):
        """Build an instance from a ``vfolders`` row, or return ``None``."""
        if row is None:
            return None
        # One-to-one column copies; max_size is in KiB.
        # (num_attached is not populated; see the commented-out field.)
        columns = (
            'id', 'host', 'name', 'user', 'group', 'creator',
            'unmanaged_path', 'usage_mode', 'permission', 'ownership_type',
            'max_files', 'max_size', 'created_at', 'last_used', 'cloneable',
        )
        return cls(**{col: row[col] for col in columns})

    async def resolve_num_files(self, info):
        # TODO: measure on-the-fly
        return 0

    async def resolve_cur_size(self, info):
        # TODO: measure on-the-fly
        return 0

    @staticmethod
    def _narrow(query, users, domain_name, group_id, user_id):
        # Apply the optional owner-domain / group / owner filters, in that
        # order, skipping any that are None.
        if domain_name is not None:
            query = query.where(users.c.domain_name == domain_name)
        if group_id is not None:
            query = query.where(vfolders.c.group == group_id)
        if user_id is not None:
            query = query.where(vfolders.c.user == user_id)
        return query

    @classmethod
    async def load_count(cls, context, *,
                         domain_name=None, group_id=None, user_id=None):
        """Count vfolders, optionally scoped by owner domain, group, or owner."""
        from .user import users
        async with context['dbpool'].acquire() as conn:
            owner_join = sa.join(vfolders, users, vfolders.c.user == users.c.uuid)
            count_query = (
                sa.select([sa.func.count(vfolders.c.id)])
                .select_from(owner_join)
                .as_scalar()
            )
            count_query = cls._narrow(
                count_query, users, domain_name, group_id, user_id)
            result = await conn.execute(count_query)
            return await result.scalar()

    @classmethod
    async def load_slice(cls, context, limit, offset, *,
                         domain_name=None, group_id=None, user_id=None,
                         order_key=None, order_asc=None):
        """Load one page (*limit*/*offset*) of vfolders.

        Without *order_key* rows are ordered by ``created_at``.  With an
        explicit *order_key*, a truthy *order_asc* sorts ascending and any
        other value descending.
        """
        from .user import users
        async with context['dbpool'].acquire() as conn:
            if order_key is not None:
                direction = sa.asc if order_asc else sa.desc
                ordering = direction(getattr(vfolders.c, order_key))
            else:
                ordering = vfolders.c.created_at
            owner_join = sa.join(vfolders, users, vfolders.c.user == users.c.uuid)
            page_query = (
                sa.select([vfolders])
                .select_from(owner_join)
                .order_by(ordering)
                .limit(limit)
                .offset(offset)
            )
            page_query = cls._narrow(
                page_query, users, domain_name, group_id, user_id)
            return [cls.from_row(context, r) async for r in conn.execute(page_query)]

    @classmethod
    async def batch_load_by_user(cls, context, user_uuids, *,
                                 domain_name=None, group_id=None):
        """Batch-load vfolders per owner UUID, newest first.

        Grouping is delegated to ``batch_multiresult`` keyed on ``row['user']``.
        """
        from .user import users
        async with context['dbpool'].acquire() as conn:
            # TODO: num_attached count group-by
            owner_join = sa.join(vfolders, users, vfolders.c.user == users.c.uuid)
            owned_query = (
                sa.select([vfolders])
                .select_from(owner_join)
                .where(vfolders.c.user.in_(user_uuids))
                .order_by(sa.desc(vfolders.c.created_at))
            )
            owned_query = cls._narrow(
                owned_query, users, domain_name, group_id, None)
            return await batch_multiresult(
                context, conn, owned_query, cls,
                user_uuids, lambda r: r['user']
            )
Example #18
0
class KeyPair(graphene.ObjectType):
    # GraphQL view of a ``keypairs`` row.  Listing helpers join ``users`` on
    # the keypair's owner so results can be scoped by the owner's domain.
    user_id = graphene.String()
    access_key = graphene.String()
    secret_key = graphene.String()
    is_active = graphene.Boolean()
    is_admin = graphene.Boolean()
    resource_policy = graphene.String()
    created_at = GQLDateTime()
    last_used = GQLDateTime()
    concurrency_used = graphene.Int()
    rate_limit = graphene.Int()
    num_queries = graphene.Int()
    user = graphene.UUID()

    vfolders = graphene.List('ai.backend.manager.models.VirtualFolder')
    compute_sessions = graphene.List(
        'ai.backend.manager.models.ComputeSession',
        status=graphene.String(),
    )

    # Deprecated
    concurrency_limit = graphene.Int(
        deprecation_reason='Moved to KeyPairResourcePolicy object as '
        'max_concurrent_sessions field.')

    @classmethod
    def from_row(cls, row):
        """Build a KeyPair from a ``keypairs`` row, or return ``None``."""
        if row is None:
            return None
        columns = (
            'user_id', 'access_key', 'secret_key', 'is_active', 'is_admin',
            'resource_policy', 'created_at', 'last_used',
            'concurrency_used', 'rate_limit', 'num_queries', 'user',
        )
        props = {col: row[col] for col in columns}
        # concurrency_limit moved to the resource policy; always reported as 0.
        return cls(concurrency_limit=0, **props)

    async def resolve_vfolders(self, info):
        loader = info.context['dlmgr'].get_loader('VirtualFolder')
        return await loader.load(self.access_key)

    async def resolve_compute_sessions(self, info, status=None):
        manager = info.context['dlmgr']
        from . import KernelStatus  # noqa: avoid circular imports
        # The GraphQL argument is a KernelStatus member name.
        kernel_status = None if status is None else KernelStatus[status]
        loader = manager.get_loader('ComputeSession', status=kernel_status)
        return await loader.load(self.access_key)

    @staticmethod
    async def load_all(context, *, domain_name=None, is_active=None):
        """Load all keypairs, optionally filtered by owner domain and is_active."""
        from .user import users
        async with context['dbpool'].acquire() as conn:
            owner_join = sa.join(keypairs, users, keypairs.c.user == users.c.uuid)
            query = sa.select([keypairs]).select_from(owner_join)
            if domain_name is not None:
                query = query.where(users.c.domain_name == domain_name)
            if is_active is not None:
                query = query.where(keypairs.c.is_active == is_active)
            return [KeyPair.from_row(row) async for row in conn.execute(query)]

    @staticmethod
    async def batch_load_by_email(context,
                                  user_ids,
                                  *,
                                  domain_name=None,
                                  is_active=None):
        """Batch-load keypairs grouped by ``keypairs.user_id``.

        Returns a tuple with one list per distinct entry of *user_ids*, in
        request order; ids without keypairs get an empty list.
        """
        from .user import users
        async with context['dbpool'].acquire() as conn:
            owner_join = sa.join(keypairs, users, keypairs.c.user == users.c.uuid)
            query = (sa.select([keypairs])
                     .select_from(owner_join)
                     .where(keypairs.c.user_id.in_(user_ids)))
            if domain_name is not None:
                query = query.where(users.c.domain_name == domain_name)
            if is_active is not None:
                query = query.where(keypairs.c.is_active == is_active)
            grouped = OrderedDict((uid, []) for uid in user_ids)
            async for row in conn.execute(query):
                keypair = KeyPair.from_row(row)
                grouped[row.user_id].append(keypair)
        return tuple(grouped.values())

    @staticmethod
    async def batch_load_by_ak(context, access_keys, *, domain_name=None):
        """Batch-load keypairs by access key; ``None`` where no keypair matched."""
        async with context['dbpool'].acquire() as conn:
            from .user import users
            owner_join = sa.join(keypairs, users, keypairs.c.user == users.c.uuid)
            query = (sa.select([keypairs])
                     .select_from(owner_join)
                     .where(keypairs.c.access_key.in_(access_keys)))
            if domain_name is not None:
                query = query.where(users.c.domain_name == domain_name)
            # For each access key there is only one keypair, so each slot
            # holds a single object (or None) rather than a list.
            by_ak = OrderedDict((ak, None) for ak in access_keys)
            async for row in conn.execute(query):
                by_ak[row.access_key] = KeyPair.from_row(row)
        return tuple(by_ak.values())
class ScalingGroup(graphene.ObjectType):
    # GraphQL view of a ``scaling_groups`` row.  Loaders are static methods;
    # association-based loaders join on ``<assoc>.scaling_group``.
    name = graphene.String()
    description = graphene.String()
    is_active = graphene.Boolean()
    created_at = GQLDateTime()
    driver = graphene.String()
    driver_opts = graphene.JSONString()
    scheduler = graphene.String()
    scheduler_opts = graphene.JSONString()

    @classmethod
    def from_row(cls, row):
        """Build an instance from a ``scaling_groups`` row; ``None`` stays ``None``."""
        if row is None:
            return None
        columns = (
            'name', 'description', 'is_active', 'created_at',
            'driver', 'driver_opts', 'scheduler', 'scheduler_opts',
        )
        return cls(**{col: row[col] for col in columns})

    @staticmethod
    async def _load_through(context, assoc_table, criterion, is_active):
        """Load scaling groups joined via *assoc_table* and matching *criterion*.

        When *is_active* is not ``None`` the result is also restricted to
        ``scaling_groups.is_active == is_active``.
        """
        async with context['dbpool'].acquire() as conn:
            joined = sa.join(
                scaling_groups, assoc_table,
                scaling_groups.c.name == assoc_table.c.scaling_group)
            query = sa.select([scaling_groups]).select_from(joined).where(criterion)
            if is_active is not None:
                query = query.where(scaling_groups.c.is_active == is_active)
            return [ScalingGroup.from_row(r) async for r in conn.execute(query)]

    @staticmethod
    async def load_all(context, *, is_active=None):
        """Load every scaling group, optionally only those with *is_active*."""
        async with context['dbpool'].acquire() as conn:
            query = sa.select([scaling_groups]).select_from(scaling_groups)
            if is_active is not None:
                query = query.where(scaling_groups.c.is_active == is_active)
            return [ScalingGroup.from_row(r) async for r in conn.execute(query)]

    @staticmethod
    async def load_by_domain(context, domain, *, is_active=None):
        """Load scaling groups associated with *domain*."""
        return await ScalingGroup._load_through(
            context, sgroups_for_domains,
            sgroups_for_domains.c.domain == domain, is_active)

    @staticmethod
    async def load_by_group(context, group, *, is_active=None):
        """Load scaling groups associated with *group*."""
        return await ScalingGroup._load_through(
            context, sgroups_for_groups,
            sgroups_for_groups.c.group == group, is_active)

    @staticmethod
    async def load_by_keypair(context, access_key, *, is_active=None):
        """Load scaling groups associated with the keypair *access_key*."""
        return await ScalingGroup._load_through(
            context, sgroups_for_keypairs,
            sgroups_for_keypairs.c.access_key == access_key, is_active)

    @staticmethod
    async def batch_load_by_name(context, names):
        """Batch-load scaling groups by name; ``None`` where no row matched."""
        async with context['dbpool'].acquire() as conn:
            query = (sa.select([scaling_groups])
                     .select_from(scaling_groups)
                     .where(scaling_groups.c.name.in_(names)))
            by_name = OrderedDict((n, None) for n in names)
            async for row in conn.execute(query):
                by_name[row.name] = ScalingGroup.from_row(row)
        return tuple(by_name.values())
Example #20
0
class Group(graphene.ObjectType):
    # GraphQL view of a ``groups`` row.  ``scaling_groups`` is resolved on
    # demand from the scaling groups associated with this group's id.
    id = graphene.UUID()
    name = graphene.String()
    description = graphene.String()
    is_active = graphene.Boolean()
    created_at = GQLDateTime()
    modified_at = GQLDateTime()
    domain_name = graphene.String()
    total_resource_slots = graphene.JSONString()
    allowed_vfolder_hosts = graphene.List(lambda: graphene.String)
    integration_id = graphene.String()

    scaling_groups = graphene.List(lambda: graphene.String)

    @classmethod
    def from_row(cls, context: Mapping[str, Any],
                 row: Optional[RowProxy]) -> Optional['Group']:
        """Build a Group from a ``groups`` row, or return ``None`` for no row.

        The return annotation is quoted because ``Group`` is not yet bound
        while its own class body executes; an unquoted name only works when
        the module enables postponed annotation evaluation (PEP 563).
        """
        if row is None:
            return None
        return cls(
            id=row['id'],
            name=row['name'],
            description=row['description'],
            is_active=row['is_active'],
            created_at=row['created_at'],
            modified_at=row['modified_at'],
            domain_name=row['domain_name'],
            total_resource_slots=row['total_resource_slots'].to_json(),
            allowed_vfolder_hosts=row['allowed_vfolder_hosts'],
            integration_id=row['integration_id'],
        )

    async def resolve_scaling_groups(self, info):
        # Names of the scaling groups linked to this group.
        from .scaling_group import ScalingGroup
        sgroups = await ScalingGroup.load_by_group(info.context, self.id)
        return [sg.name for sg in sgroups]

    @classmethod
    async def load_all(cls, context, *, domain_name=None, is_active=None):
        """Load all groups, optionally filtered by domain and is_active."""
        async with context['dbpool'].acquire() as conn:
            query = (sa.select([groups]).select_from(groups))
            if domain_name is not None:
                query = query.where(groups.c.domain_name == domain_name)
            if is_active is not None:
                query = query.where(groups.c.is_active == is_active)
            return [
                cls.from_row(context, row) async for row in conn.execute(query)
            ]

    @classmethod
    async def batch_load_by_id(cls, context, group_ids, *, domain_name=None):
        """Batch-load groups by id (optionally within *domain_name*).

        Result alignment is delegated to ``batch_result`` keyed on ``row['id']``.
        """
        async with context['dbpool'].acquire() as conn:
            query = (sa.select([groups]).select_from(groups).where(
                groups.c.id.in_(group_ids)))
            if domain_name is not None:
                query = query.where(groups.c.domain_name == domain_name)
            return await batch_result(
                context,
                conn,
                query,
                cls,
                group_ids,
                lambda row: row['id'],
            )

    @classmethod
    async def get_groups_for_user(cls, context, user_id):
        """Load every group that *user_id* is associated with."""
        async with context['dbpool'].acquire() as conn:
            j = sa.join(groups, association_groups_users,
                        groups.c.id == association_groups_users.c.group_id)
            query = (sa.select([groups]).select_from(j).where(
                association_groups_users.c.user_id == user_id))
            return [
                cls.from_row(context, row) async for row in conn.execute(query)
            ]
Example #21
0
class ComputeContainer(graphene.ObjectType):
    """
    A single container (one ``kernels`` row) that belongs to a compute session.

    ``agent`` and ``container_id`` are masked to ``None`` for non-superadmin
    users whenever the manager configuration sets ``hide-agents``.
    ``live_stat`` is not filled in by :meth:`parse_row`; it is computed on
    demand by :meth:`resolve_live_stat`.
    """
    class Meta:
        interfaces = (Item, )

    # identity
    role = graphene.String()
    hostname = graphene.String()
    session_id = graphene.UUID()  # owner session

    # image
    image = graphene.String()
    registry = graphene.String()

    # status
    status = graphene.String()
    status_changed = GQLDateTime()
    status_info = graphene.String()
    created_at = GQLDateTime()
    terminated_at = GQLDateTime()
    starts_at = GQLDateTime()

    # resources
    agent = graphene.String()
    container_id = graphene.String()
    resource_opts = graphene.JSONString()
    occupied_slots = graphene.JSONString()
    live_stat = graphene.JSONString()
    last_stat = graphene.JSONString()

    @classmethod
    def parse_row(cls, context, row):
        """
        Translate a ``kernels`` row into constructor keyword arguments.

        :param context: GraphQL context; ``context['user']['role']`` and
            ``context['config']['manager']['hide-agents']`` are consulted.
        :param row: a non-``None`` result row containing the kernel columns.
        :returns: a dict suitable for ``cls(**props)``.
        """
        assert row is not None
        from .user import UserRole
        # Superadmins always see agent placement; everyone else obeys the config.
        hide_agents = (
            False
            if context['user']['role'] == UserRole.SUPERADMIN
            else context['config']['manager']['hide-agents']
        )
        return dict(
            # identity
            id=row['id'],
            role=row['role'],
            hostname=None,         # TODO: implement
            session_id=row['id'],  # master container's ID == session ID
            # image
            image=row['image'],
            registry=row['registry'],
            # status
            status=row['status'].name,
            status_changed=row['status_changed'],
            status_info=row['status_info'],
            created_at=row['created_at'],
            terminated_at=row['terminated_at'],
            starts_at=row['starts_at'],
            occupied_slots=row['occupied_slots'].to_json(),
            # resources (placement details may be masked)
            agent=None if hide_agents else row['agent'],
            container_id=None if hide_agents else row['container_id'],
            resource_opts=row['resource_opts'],
            # statistics: live_stat is resolved lazily by resolve_live_stat()
            last_stat=row['last_stat'],
        )

    @classmethod
    def from_row(cls, context, row):
        """Build an instance from *row*, or return ``None`` when *row* is ``None``."""
        return None if row is None else cls(**cls.parse_row(context, row))

    async def resolve_live_stat(self, info: graphene.ResolveInfo):
        """
        Return the container's statistics.

        For statuses in ``LIVE_STATUS`` the msgpack-encoded value stored in
        Redis under ``str(self.id)`` is fetched and decoded (``None`` if the
        key is missing); otherwise the persisted ``last_stat`` is returned.
        """
        if not hasattr(self, 'status'):
            return None
        redis_stat = info.context['redis_stat']
        if self.status not in LIVE_STATUS:
            return self.last_stat
        packed = await redis.execute_with_retries(
            lambda: redis_stat.get(str(self.id), encoding=None))
        return None if packed is None else msgpack.unpackb(packed)

    @classmethod
    async def load_count(cls, context, session_id, *,
                         role=None,
                         domain_name=None, group_id=None, access_key=None):
        """
        Count the kernels of *session_id* that match the optional filters.

        Each keyword filter is applied as an equality condition only when it
        is not ``None``.
        """
        async with context['dbpool'].acquire() as conn:
            query = (
                sa.select([sa.func.count(kernels.c.id)])
                .select_from(kernels)
                # TODO: use "owner session ID" when we implement multi-container session
                .where(kernels.c.id == session_id)
                .as_scalar()
            )
            optional_filters = (
                (kernels.c.role, role),
                (kernels.c.domain_name, domain_name),
                (kernels.c.group_id, group_id),
                (kernels.c.access_key, access_key),
            )
            for column, wanted in optional_filters:
                if wanted is not None:
                    query = query.where(column == wanted)
            result = await conn.execute(query)
            first_row = await result.fetchone()
            return first_row[0]

    @classmethod
    async def load_slice(cls, context, limit, offset, session_id, *,
                         role=None,
                         domain_name=None, group_id=None, access_key=None,
                         order_key=None, order_asc=None):
        """
        Load one page of the kernels of *session_id*.

        Ordering defaults to ``DEFAULT_SESSION_ORDERING``; when *order_key* is
        given, that kernel column is used ascending if *order_asc* is truthy,
        descending otherwise.  Keyword filters apply only when not ``None``.
        """
        async with context['dbpool'].acquire() as conn:
            if order_key is None:
                ordering = DEFAULT_SESSION_ORDERING
            else:
                direction = sa.asc if order_asc else sa.desc
                ordering = [direction(getattr(kernels.c, order_key))]
            joined = (
                kernels
                .join(groups, groups.c.id == kernels.c.group_id)
                .join(users, users.c.uuid == kernels.c.user_uuid)
            )
            query = (
                sa.select([kernels, groups.c.name, users.c.email])
                .select_from(joined)
                # TODO: use "owner session ID" when we implement multi-container session
                .where(kernels.c.id == session_id)
                .order_by(*ordering)
                .limit(limit)
                .offset(offset)
            )
            optional_filters = (
                (kernels.c.role, role),
                (kernels.c.domain_name, domain_name),
                (kernels.c.group_id, group_id),
                (kernels.c.access_key, access_key),
            )
            for column, wanted in optional_filters:
                if wanted is not None:
                    query = query.where(column == wanted)
            return [cls.from_row(context, record)
                    async for record in conn.execute(query)]

    @classmethod
    async def batch_load_by_session(cls, context, session_ids):
        """Batch-load kernels for several sessions via ``batch_multiresult``, keyed by ``id``."""
        async with context['dbpool'].acquire() as conn:
            query = (
                sa.select([kernels])
                .select_from(kernels)
                # TODO: use "owner session ID" when we implement multi-container session
                .where(kernels.c.id.in_(session_ids))
            )

            def key_of(record):
                return record['id']

            return await batch_multiresult(context, conn, query, cls,
                                           session_ids, key_of)

    @classmethod
    async def batch_load_detail(cls, context, container_ids, *,
                                domain_name=None, access_key=None):
        """
        Batch-load containers by id via ``batch_result``.

        Rows are joined with ``groups`` and ``users``; *domain_name* and
        *access_key* are applied as equality filters only when not ``None``.
        """
        async with context['dbpool'].acquire() as conn:
            joined = (
                kernels
                .join(groups, groups.c.id == kernels.c.group_id)
                .join(users, users.c.uuid == kernels.c.user_uuid)
            )
            query = (
                sa.select([kernels, groups.c.name, users.c.email])
                .select_from(joined)
                .where(kernels.c.id.in_(container_ids))
            )
            optional_filters = (
                (kernels.c.domain_name, domain_name),
                (kernels.c.access_key, access_key),
            )
            for column, wanted in optional_filters:
                if wanted is not None:
                    query = query.where(column == wanted)

            def key_of(record):
                return record['id']

            return await batch_result(context, conn, query, cls,
                                      container_ids, key_of)
Example #22
0
class Agent(graphene.ObjectType):
    """
    GraphQL object type describing one row of the ``agents`` table.

    ``computations`` is resolved lazily through the ``Computation.by_agent_id``
    data loader.
    """

    id = graphene.String()
    status = graphene.String()
    region = graphene.String()
    mem_slots = graphene.Int()
    cpu_slots = graphene.Int()
    gpu_slots = graphene.Int()
    used_mem_slots = graphene.Int()
    used_cpu_slots = graphene.Int()
    used_gpu_slots = graphene.Int()
    addr = graphene.String()
    first_contact = GQLDateTime()
    lost_at = GQLDateTime()

    computations = graphene.List('ai.backend.manager.models.Computation',
                                 status=graphene.String())

    @classmethod
    def from_row(cls, row):
        """
        Build an :class:`Agent` from a result row.

        Columns are read by attribute access on *row*.  Returns ``None`` when
        *row* is ``None``.
        """
        if row is None:
            return None
        return cls(
            id=row.id,
            status=row.status,
            region=row.region,
            mem_slots=row.mem_slots,
            cpu_slots=row.cpu_slots,
            gpu_slots=row.gpu_slots,
            used_mem_slots=row.used_mem_slots,
            used_cpu_slots=row.used_cpu_slots,
            used_gpu_slots=row.used_gpu_slots,
            addr=row.addr,
            first_contact=row.first_contact,
            lost_at=row.lost_at,
        )

    async def resolve_computations(self, info, status=None):
        '''
        Retrieves all children worker sessions run by this agent.

        :param status: optional status filter forwarded to the
            ``Computation.by_agent_id`` loader.
        '''
        manager = info.context['dlmgr']
        loader = manager.get_loader('Computation.by_agent_id', status=status)
        return await loader.load(self.id)

    @staticmethod
    async def load_all(dbpool, status=None):
        """
        Load every agent, optionally only those in a given status.

        :param dbpool: connection pool providing ``acquire()``.
        :param status: an ``AgentStatus`` member *name*; raises ``KeyError``
            for unknown names.
        :returns: a list of :class:`Agent`.
        """
        async with dbpool.acquire() as conn:
            query = sa.select('*').select_from(agents)
            if status is not None:
                query = query.where(agents.c.status == AgentStatus[status])
            result = await conn.execute(query)
            rows = await result.fetchall()
            return [Agent.from_row(r) for r in rows]

    @staticmethod
    async def batch_load(dbpool, agent_ids):
        """
        Data-loader batch function: load agents by id.

        Returns a tuple with exactly one entry per element of *agent_ids*, in
        the same order -- the matching :class:`Agent`, or ``None`` when no row
        exists.  Duplicated ids each receive their own entry, as the
        DataLoader contract requires results aligned 1:1 with the keys.
        """
        async with dbpool.acquire() as conn:
            query = (sa.select('*').select_from(agents).where(
                agents.c.id.in_(agent_ids)).order_by(agents.c.id))
            agents_by_id = {}
            async for row in conn.execute(query):
                agents_by_id[row.id] = Agent.from_row(row)
        # Map back over the requested keys so the result length always
        # matches len(agent_ids), regardless of duplicates or missing rows.
        return tuple(agents_by_id.get(k) for k in agent_ids)
Example #23
0
class LegacyComputeSession(graphene.ObjectType):
    """
    Represents a master session.

    Instances are built from ``kernels`` rows whose ``role`` is ``'master'``.
    The object-returning loaders below join each row with ``groups.name`` and
    ``users.email``.  Besides the current fields it exposes legacy aliases
    (``sess_id``, ``lang``, ``*_slot``, per-metric counters) derived either
    from the row or from the kernel's statistics.
    """
    class Meta:
        interfaces = (Item, )

    tag = graphene.String()  # Only for ComputeSession
    sess_id = graphene.String()    # legacy
    sess_type = graphene.String()  # legacy
    session_name = graphene.String()
    session_type = graphene.String()
    role = graphene.String()
    image = graphene.String()
    registry = graphene.String()
    domain_name = graphene.String()
    group_name = graphene.String()
    group_id = graphene.UUID()
    scaling_group = graphene.String()
    user_uuid = graphene.UUID()
    access_key = graphene.String()

    status = graphene.String()
    status_changed = GQLDateTime()
    status_info = graphene.String()
    created_at = GQLDateTime()
    terminated_at = GQLDateTime()
    startup_command = graphene.String()
    result = graphene.String()

    # hidable fields by configuration
    agent = graphene.String()
    container_id = graphene.String()

    service_ports = graphene.JSONString()

    occupied_slots = graphene.JSONString()
    occupied_shares = graphene.JSONString()
    mounts = graphene.List(lambda: graphene.List(lambda: graphene.String))
    resource_opts = graphene.JSONString()

    num_queries = BigInt()
    live_stat = graphene.JSONString()
    last_stat = graphene.JSONString()

    user_email = graphene.String()

    # Legacy fields
    lang = graphene.String()
    mem_slot = graphene.Int()
    cpu_slot = graphene.Float()
    gpu_slot = graphene.Float()
    tpu_slot = graphene.Float()
    cpu_used = BigInt()
    cpu_using = graphene.Float()
    mem_max_bytes = BigInt()
    mem_cur_bytes = BigInt()
    net_rx_bytes = BigInt()
    net_tx_bytes = BigInt()
    io_read_bytes = BigInt()
    io_write_bytes = BigInt()
    io_max_scratch_size = BigInt()
    io_cur_scratch_size = BigInt()

    @staticmethod
    def _parse_status_list(status):
        """
        Normalize a status filter into a list of ``KernelStatus`` members.

        Accepts ``None`` (returns ``None``, meaning "no filter"), a
        comma-separated string of status names, a single ``KernelStatus``, or
        an iterable whose items are ``KernelStatus`` members or status names.
        Unknown names raise ``KeyError``.
        """
        if status is None:
            return None
        if isinstance(status, str):
            return [KernelStatus[s] for s in status.split(',')]
        if isinstance(status, KernelStatus):
            return [status]
        return [s if isinstance(s, KernelStatus) else KernelStatus[s]
                for s in status]

    @classmethod
    async def _resolve_live_stat(cls, redis_stat, kernel_id):
        """Fetch the msgpack-encoded stats of *kernel_id* from Redis and decode them (``None`` if absent)."""
        cstat = await redis.execute_with_retries(
            lambda: redis_stat.get(kernel_id, encoding=None))
        if cstat is not None:
            cstat = msgpack.unpackb(cstat)
        return cstat

    async def resolve_live_stat(self, info: graphene.ResolveInfo):
        """Return live stats from Redis for statuses in ``LIVE_STATUS``, else ``last_stat``."""
        if not hasattr(self, 'status'):
            return None
        rs = info.context['redis_stat']
        if self.status not in LIVE_STATUS:
            return self.last_stat
        else:
            return await type(self)._resolve_live_stat(rs, str(self.id))

    async def _resolve_legacy_metric(
        self, info: graphene.ResolveInfo,
        metric_key, metric_field, convert_type,
    ):
        """
        Extract ``stat[metric_key][metric_field]`` and convert it.

        The stat source is ``last_stat`` for non-live statuses and the Redis
        live stats otherwise.  Any missing level yields ``convert_type(0)``.
        Returns ``None`` when the instance has no ``status`` attribute.
        """
        if not hasattr(self, 'status'):
            return None
        rs = info.context['redis_stat']
        if self.status not in LIVE_STATUS:
            stat = self.last_stat
        else:
            stat = await type(self)._resolve_live_stat(rs, str(self.id))
        if stat is None:
            return convert_type(0)
        metric = stat.get(metric_key)
        if metric is None:
            return convert_type(0)
        value = metric.get(metric_field)
        if value is None:
            return convert_type(0)
        return convert_type(value)

    async def resolve_cpu_used(self, info: graphene.ResolveInfo):
        return await self._resolve_legacy_metric(info, 'cpu_used', 'current', float)

    async def resolve_cpu_using(self, info: graphene.ResolveInfo):
        return await self._resolve_legacy_metric(info, 'cpu_util', 'pct', float)

    async def resolve_mem_max_bytes(self, info: graphene.ResolveInfo):
        return await self._resolve_legacy_metric(info, 'mem', 'stats.max', int)

    async def resolve_mem_cur_bytes(self, info: graphene.ResolveInfo):
        return await self._resolve_legacy_metric(info, 'mem', 'current', int)

    async def resolve_net_rx_bytes(self, info: graphene.ResolveInfo):
        return await self._resolve_legacy_metric(info, 'net_rx', 'stats.rate', int)

    async def resolve_net_tx_bytes(self, info: graphene.ResolveInfo):
        return await self._resolve_legacy_metric(info, 'net_tx', 'stats.rate', int)

    async def resolve_io_read_bytes(self, info: graphene.ResolveInfo):
        return await self._resolve_legacy_metric(info, 'io_read', 'current', int)

    async def resolve_io_write_bytes(self, info: graphene.ResolveInfo):
        return await self._resolve_legacy_metric(info, 'io_write', 'current', int)

    async def resolve_io_max_scratch_size(self, info: graphene.ResolveInfo):
        return await self._resolve_legacy_metric(info, 'io_scratch_size', 'stats.max', int)

    async def resolve_io_cur_scratch_size(self, info: graphene.ResolveInfo):
        return await self._resolve_legacy_metric(info, 'io_scratch_size', 'current', int)

    @classmethod
    def parse_row(cls, context, row):
        """
        Translate a joined ``kernels`` row into constructor keyword arguments.

        The row must also carry ``name`` (group name) and ``email`` (user
        email) from the joins.  ``agent``/``container_id`` are masked for
        non-superadmins when ``context['config']['manager']['hide-agents']``
        is set.
        """
        assert row is not None
        from .user import UserRole
        mega = 2 ** 20
        is_superadmin = (context['user']['role'] == UserRole.SUPERADMIN)
        if is_superadmin:
            hide_agents = False
        else:
            hide_agents = context['config']['manager']['hide-agents']
        return {
            'sess_id': row['sess_id'],           # legacy, will be deprecated
            'sess_type': row['sess_type'].name,  # legacy, will be deprecated
            'session_name': row['sess_id'],
            'session_type': row['sess_type'].name,
            'id': row['id'],                     # legacy, will be replaced with session UUID
            'role': row['role'],
            'tag': row['tag'],
            'image': row['image'],
            'registry': row['registry'],
            'domain_name': row['domain_name'],
            'group_name': row['name'],  # group.name (group is omitted since use_labels=True is not used)
            'group_id': row['group_id'],
            'scaling_group': row['scaling_group'],
            'user_uuid': row['user_uuid'],
            'access_key': row['access_key'],
            'status': row['status'].name,
            'status_changed': row['status_changed'],
            'status_info': row['status_info'],
            'created_at': row['created_at'],
            'terminated_at': row['terminated_at'],
            'startup_command': row['startup_command'],
            'result': row['result'].name,
            'service_ports': row['service_ports'],
            'occupied_slots': row['occupied_slots'].to_json(),
            'mounts': row['mounts'],
            'resource_opts': row['resource_opts'],
            'num_queries': row['num_queries'],
            # optionally hidden
            'agent': row['agent'] if not hide_agents else None,
            'container_id': row['container_id'] if not hide_agents else None,
            # live_stat is resolved by Graphene
            'last_stat': row['last_stat'],
            'user_email': row['email'],
            # Legacy fields
            # NOTE: currently graphene always uses resolve methods!
            'cpu_used': 0,
            'mem_max_bytes': 0,
            'mem_cur_bytes': 0,
            'net_rx_bytes': 0,
            'net_tx_bytes': 0,
            'io_read_bytes': 0,
            'io_write_bytes': 0,
            'io_max_scratch_size': 0,
            'io_cur_scratch_size': 0,
            'lang': row['image'],
            'occupied_shares': row['occupied_shares'],
            'mem_slot': BinarySize.from_str(
                row['occupied_slots'].get('mem', 0)) // mega,
            'cpu_slot': float(row['occupied_slots'].get('cpu', 0)),
            'gpu_slot': float(row['occupied_slots'].get('cuda.device', 0)),
            'tpu_slot': float(row['occupied_slots'].get('tpu.device', 0)),
        }

    @classmethod
    def from_row(cls, context, row):
        """Build an instance from *row*, or return ``None`` when *row* is ``None``."""
        if row is None:
            return None
        props = cls.parse_row(context, row)
        return cls(**props)

    @classmethod
    async def load_count(cls, context, *,
                         domain_name=None, group_id=None, access_key=None,
                         status=None):
        """
        Count master sessions matching the optional filters.

        *status* may be a comma-separated string of names, a ``KernelStatus``,
        or an iterable of either; ``None`` disables status filtering.
        """
        status_list = cls._parse_status_list(status)
        async with context['dbpool'].acquire() as conn:
            query = (
                sa.select([sa.func.count(kernels.c.sess_id)])
                .select_from(kernels)
                .where(kernels.c.role == 'master')
                .as_scalar()
            )
            if domain_name is not None:
                query = query.where(kernels.c.domain_name == domain_name)
            if group_id is not None:
                query = query.where(kernels.c.group_id == group_id)
            if access_key is not None:
                query = query.where(kernels.c.access_key == access_key)
            if status_list is not None:
                query = query.where(kernels.c.status.in_(status_list))
            result = await conn.execute(query)
            count = await result.fetchone()
            return count[0]

    @classmethod
    async def load_slice(cls, context, limit, offset, *,
                         domain_name=None, group_id=None, access_key=None,
                         status=None,
                         order_key=None, order_asc=None):
        """
        Load one page of master sessions.

        Ordering defaults to ``DEFAULT_SESSION_ORDERING``; otherwise the
        kernel column named *order_key* is used, ascending if *order_asc* is
        truthy.  *status* accepts the same forms as in :meth:`load_count`.
        """
        status_list = cls._parse_status_list(status)
        async with context['dbpool'].acquire() as conn:
            if order_key is None:
                _ordering = DEFAULT_SESSION_ORDERING
            else:
                _order_func = sa.asc if order_asc else sa.desc
                _ordering = [_order_func(getattr(kernels.c, order_key))]
            j = (kernels.join(groups, groups.c.id == kernels.c.group_id)
                        .join(users, users.c.uuid == kernels.c.user_uuid))
            query = (
                sa.select([kernels, groups.c.name, users.c.email])
                .select_from(j)
                .where(kernels.c.role == 'master')
                .order_by(*_ordering)
                .limit(limit)
                .offset(offset)
            )
            if domain_name is not None:
                query = query.where(kernels.c.domain_name == domain_name)
            if group_id is not None:
                query = query.where(kernels.c.group_id == group_id)
            if access_key is not None:
                query = query.where(kernels.c.access_key == access_key)
            if status_list is not None:
                query = query.where(kernels.c.status.in_(status_list))
            return [cls.from_row(context, r) async for r in conn.execute(query)]

    @classmethod
    async def batch_load(cls, context, access_keys, *,
                         domain_name=None, group_id=None,
                         status=None):
        """
        Batch-load master sessions for several access keys via ``batch_result``.

        The query is ordered by the most recent of ``created_at``,
        ``terminated_at`` and ``status_changed`` (descending) and capped at
        100 rows for the whole batch.  *status* accepts the same forms as in
        :meth:`load_count`.
        """
        status_list = cls._parse_status_list(status)
        async with context['dbpool'].acquire() as conn:
            j = (kernels.join(groups, groups.c.id == kernels.c.group_id)
                        .join(users, users.c.uuid == kernels.c.user_uuid))
            query = (
                sa.select([kernels, groups.c.name, users.c.email])
                .select_from(j)
                .where(
                    (kernels.c.access_key.in_(access_keys)) &
                    (kernels.c.role == 'master')
                )
                .order_by(
                    sa.desc(sa.func.greatest(
                        kernels.c.created_at,
                        kernels.c.terminated_at,
                        kernels.c.status_changed,
                    ))
                )
                .limit(100))
            if domain_name is not None:
                query = query.where(kernels.c.domain_name == domain_name)
            if group_id is not None:
                query = query.where(kernels.c.group_id == group_id)
            if status_list is not None:
                query = query.where(kernels.c.status.in_(status_list))
            return await batch_result(
                context, conn, query, cls,
                access_keys, lambda row: row['access_key'],
            )

    @classmethod
    async def batch_load_detail(cls, context, sess_ids, *,
                                domain_name=None, access_key=None,
                                status=None):
        """
        Batch-load master sessions by ``sess_id`` via ``batch_multiresult``.

        Unlike the other loaders, a ``None`` *status* defaults to RUNNING
        sessions only; other forms are parsed as in :meth:`load_count`, and
        an empty status list disables the filter.
        """
        if status is None:
            status_list = [KernelStatus['RUNNING']]
        else:
            status_list = cls._parse_status_list(status)
        async with context['dbpool'].acquire() as conn:
            j = (kernels.join(groups, groups.c.id == kernels.c.group_id)
                        .join(users, users.c.uuid == kernels.c.user_uuid))
            query = (sa.select([kernels, groups.c.name, users.c.email])
                       .select_from(j)
                       .where((kernels.c.role == 'master') &
                              (kernels.c.sess_id.in_(sess_ids))))
            if domain_name is not None:
                query = query.where(kernels.c.domain_name == domain_name)
            if access_key is not None:
                query = query.where(kernels.c.access_key == access_key)
            if status_list:
                query = query.where(kernels.c.status.in_(status_list))
            return await batch_multiresult(
                context, conn, query, cls,
                sess_ids, lambda row: row['sess_id'],
            )
Example #24
0
class VirtualFolder(graphene.ObjectType):
    """
    GraphQL object type for a virtual folder (one row of ``vfolders``).

    ``num_files`` and ``cur_size`` are placeholders that currently always
    resolve to ``0``.  The loaders inner-join ``vfolders`` with ``users`` on
    the folder's ``user`` column, so only user-owned folders are returned.
    """
    class Meta:
        interfaces = (Item, )

    host = graphene.String()
    name = graphene.String()
    max_files = graphene.Int()
    max_size = graphene.Int()
    created_at = GQLDateTime()
    last_used = GQLDateTime()

    num_files = graphene.Int()
    cur_size = graphene.Int()
    # num_attached = graphene.Int()

    @classmethod
    def from_row(cls, row):
        """Build a :class:`VirtualFolder` from *row*, or return ``None`` when *row* is ``None``."""
        if row is None:
            return None
        return cls(
            id=row['id'],
            host=row['host'],
            name=row['name'],
            max_files=row['max_files'],
            max_size=row['max_size'],  # in KiB
            created_at=row['created_at'],
            last_used=row['last_used'],
            # num_attached=row['num_attached'],
        )

    async def resolve_num_files(self, info):
        # TODO: measure on-the-fly
        return 0

    async def resolve_cur_size(self, info):
        # TODO: measure on-the-fly
        return 0

    @staticmethod
    async def load_count(context,
                         *,
                         domain_name=None,
                         group_id=None,
                         user_id=None):
        """Count virtual folders matching the optional domain/group/user filters."""
        from .user import users
        async with context['dbpool'].acquire() as conn:
            j = sa.join(vfolders, users, vfolders.c.user == users.c.uuid)
            query = (sa.select([sa.func.count(vfolders.c.id)
                                ]).select_from(j).as_scalar())
            if domain_name is not None:
                query = query.where(users.c.domain_name == domain_name)
            if group_id is not None:
                query = query.where(vfolders.c.group == group_id)
            if user_id is not None:
                query = query.where(vfolders.c.user == user_id)
            result = await conn.execute(query)
            count = await result.fetchone()
            return count[0]

    @staticmethod
    async def load_slice(context,
                         limit,
                         offset,
                         *,
                         domain_name=None,
                         group_id=None,
                         user_id=None,
                         order_key=None,
                         order_asc=None):
        """
        Load one page of virtual folders.

        Ordering defaults to ``created_at`` (SQL default direction); when
        *order_key* is given, that ``vfolders`` column is used ascending if
        *order_asc* is truthy, descending otherwise.
        """
        from .user import users
        async with context['dbpool'].acquire() as conn:
            if order_key is None:
                _ordering = vfolders.c.created_at
            else:
                _order_func = sa.asc if order_asc else sa.desc
                _ordering = _order_func(getattr(vfolders.c, order_key))
            j = sa.join(vfolders, users, vfolders.c.user == users.c.uuid)
            query = (sa.select([
                vfolders
            ]).select_from(j).order_by(_ordering).limit(limit).offset(offset))
            if domain_name is not None:
                query = query.where(users.c.domain_name == domain_name)
            if group_id is not None:
                query = query.where(vfolders.c.group == group_id)
            if user_id is not None:
                query = query.where(vfolders.c.user == user_id)
            result = await conn.execute(query)
            rows = await result.fetchall()
            # from_row takes only the row; passing context raised TypeError.
            return [VirtualFolder.from_row(r) for r in rows]

    @staticmethod
    async def load_all(context,
                       *,
                       domain_name=None,
                       group_id=None,
                       user_id=None):
        """Load all matching virtual folders, newest ``created_at`` first."""
        from .user import users
        async with context['dbpool'].acquire() as conn:
            j = sa.join(vfolders, users, vfolders.c.user == users.c.uuid)
            query = (sa.select([vfolders]).select_from(j).order_by(
                sa.desc(vfolders.c.created_at)))
            if domain_name is not None:
                query = query.where(users.c.domain_name == domain_name)
            if group_id is not None:
                query = query.where(vfolders.c.group == group_id)
            if user_id is not None:
                query = query.where(vfolders.c.user == user_id)
            return [VirtualFolder.from_row(row)
                    async for row in conn.execute(query)]

    @staticmethod
    async def batch_load_by_user(context,
                                 user_uuids,
                                 *,
                                 domain_name=None,
                                 group_id=None):
        """
        Data-loader batch function: load virtual folders owned by each user.

        Returns a tuple with exactly one list per element of *user_uuids*, in
        the same order (empty when the user has no matching folder).  Each
        list is ordered newest ``created_at`` first.
        """
        from .user import users
        async with context['dbpool'].acquire() as conn:
            # TODO: num_attached count group-by
            j = sa.join(vfolders, users, vfolders.c.user == users.c.uuid)
            query = (sa.select([vfolders]).select_from(j).where(
                vfolders.c.user.in_(user_uuids)).order_by(
                    sa.desc(vfolders.c.created_at)))
            if domain_name is not None:
                query = query.where(users.c.domain_name == domain_name)
            if group_id is not None:
                query = query.where(vfolders.c.group == group_id)
            objs_per_key = {u: [] for u in user_uuids}
            async for row in conn.execute(query):
                objs_per_key[row.user].append(VirtualFolder.from_row(row))
        # Map back over the requested keys so the result stays aligned 1:1
        # with user_uuids even if a key is repeated.
        return tuple(objs_per_key[u] for u in user_uuids)