コード例 #1
0
async def batch_multiresult(
    context: Mapping[str, Any],
    conn: SAConnection,
    query: sa.sql.Select,
    obj_type: Type[_GenericSQLBasedGQLObject],
    key_list: Iterable[_Key],
    key_getter: Callable[[RowProxy], _Key],
) -> Sequence[Sequence[_GenericSQLBasedGQLObject]]:
    """
    Batched query adaptor for one-key-to-many-items resolution.

    An empty bucket is prepared for every key in ``key_list`` (keeping the
    first-seen order of keys). Each row returned by ``query`` is turned into
    an object via ``obj_type.from_row(context, row)`` and appended to the
    bucket chosen by ``key_getter(row)``.

    A row whose key is not in ``key_list`` raises ``KeyError``.

    :return: one list of objects per distinct key, in ``key_list`` order.
    """
    buckets: Dict[_Key, List[_GenericSQLBasedGQLObject]] = collections.OrderedDict(
        (key, []) for key in key_list
    )
    async for row in conn.execute(query):
        # Resolve the bucket first so an unknown key fails before conversion.
        bucket = buckets[key_getter(row)]
        bucket.append(obj_type.from_row(context, row))
    return list(buckets.values())
コード例 #2
0
    async def delete_vfolders(
        cls,
        conn: SAConnection,
        user_uuid: uuid.UUID,
        config_server,
    ) -> int:
        """
        Delete user's all virtual folders as well as their physical data.

        Physical deletion is best-effort: if removing a folder's files raises
        an ``OSError``/``IOError``, the failure is logged and the folder's DB
        row is still deleted.

        :param conn: DB connection
        :param user_uuid: user's UUID to delete virtual folders
        :param config_server: object whose awaitable ``get()`` supplies the
            ``volumes/_mount`` and ``volumes/_fsprefix`` values used to build
            the on-disk path of managed (non-``unmanaged_path``) folders

        :return: number of deleted rows
        """
        import functools
        from . import vfolders
        mount_prefix = Path(await config_server.get('volumes/_mount'))
        fs_prefix = await config_server.get('volumes/_fsprefix')
        # Strip the leading slash so the prefix joins relative to the host dir.
        fs_prefix = Path(fs_prefix.lstrip('/'))
        query = (
            sa.select([vfolders.c.id, vfolders.c.host, vfolders.c.unmanaged_path])
            .select_from(vfolders)
            .where(vfolders.c.user == user_uuid)
        )
        loop = current_loop()
        async for row in conn.execute(query):
            if row['unmanaged_path']:
                folder_path = Path(row['unmanaged_path'])
            else:
                folder_path = (mount_prefix / row['host'] / fs_prefix / row['id'].hex)
            log.info('deleting physical files: {0}', folder_path)
            try:
                # rmtree is blocking; run it off the event loop.
                await loop.run_in_executor(
                    None, functools.partial(shutil.rmtree, folder_path))
            except IOError as e:
                # Best-effort: keep going, but leave a trace instead of
                # silently ignoring the failure.
                log.warning('failed to delete physical files: {0} ({1!r})',
                            folder_path, e)
        query = (
            vfolders.delete()
            .where(vfolders.c.user == user_uuid)
        )
        result = await conn.execute(query)
        if result.rowcount > 0:
            log.info('deleted {0} user\'s virtual folders ({1})', result.rowcount, user_uuid)
        return result.rowcount
コード例 #3
0
    async def __load_projects(self, conn: SAConnection, query: str,
                              user_id: int,
                              user_groups: List[RowProxy]) -> List[Dict]:
        """
        Run ``query`` and return the projects ``user_id`` may read, converted
        to the API model.

        Rows failing the "read" permission check are skipped. Each surviving
        project is enriched with its tags and converted together with its
        owner's email.
        """
        readable_rows: List[Dict] = []  # DB model-compatible projects
        async for row in conn.execute(query):
            try:
                _check_project_permissions(row, user_id, user_groups, "read")
            except ProjectInvalidRightsError:
                pass  # no read access: skip this project
            else:
                project = dict(row.items())
                log.debug("found project: %s", project)
                readable_rows.append(project)

        # NOTE: DO NOT nest _get_tags_by_project in async loop above !!!
        # FIXME: temporary avoids inner async loops issue https://github.com/aio-libs/aiopg/issues/535
        converted: List[Dict] = []  # API model-compatible projects
        for project in readable_rows:
            project["tags"] = await self._get_tags_by_project(
                conn, project_id=project["id"])
            owner_email = await self._get_user_email(conn, project["prj_owner"])
            converted.append(_convert_to_schema_names(project, owner_email))
        return converted
コード例 #4
0
ファイル: group.py プロジェクト: inureyes/backend.ai-manager
    async def group_vfolder_mounted_to_active_kernels(
        cls,
        conn: SAConnection,
        group_id: uuid.UUID,
    ) -> bool:
        """
        Check whether any active kernel of the group mounts one of the group's
        virtual folders.

        A kernel counts as active when its status is in
        ``AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES``. Mount entries that cannot
        be parsed into a vfolder UUID are ignored.

        :param conn: DB connection
        :param group_id: group's ID

        :return: True if a virtual folder is mounted to active kernels.
        """
        from . import kernels, vfolders, AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES
        query = (
            sa.select([vfolders.c.id])
            .select_from(vfolders)
            .where(vfolders.c.group == group_id)
        )
        result = await conn.execute(query)
        rows = await result.fetchall()
        # A set gives O(1) membership tests inside the nested mount scan.
        group_vfolder_ids = {row.id for row in rows}
        if not group_vfolder_ids:
            # The group owns no vfolders, so none can be mounted.
            return False
        query = (
            sa.select([kernels.c.mounts])
            .select_from(kernels)
            .where((kernels.c.group_id == group_id) &
                   (kernels.c.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES)))
        )
        async for row in conn.execute(query):
            # Guard against a NULL mounts column (assumption: it is nullable —
            # confirm against the kernels schema).
            for _mount in row['mounts'] or ():
                try:
                    # Index 2 of a mount entry is parsed as the vfolder ID.
                    vfolder_id = uuid.UUID(_mount[2])
                except (IndexError, KeyError, TypeError, ValueError, AttributeError):
                    # Malformed entry: it cannot reference a group vfolder.
                    continue
                if vfolder_id in group_vfolder_ids:
                    return True
        return False
コード例 #5
0
async def list_projects_access_rights(
        conn: SAConnection, user_id: int) -> Dict[ProjectID, AccessRights]:
    """
    Returns access-rights of user (user_id) over all OWNED or SHARED projects

    A project is included when ``prj_owner`` equals ``user_id`` or when any of
    the user's group ids (as text) is a top-level key of its ``access_rights``
    (PostgreSQL ``jsonb_exists_any``).

    :param conn: DB connection
    :param user_id: id of the user whose rights are aggregated
    :return: mapping of project uuid to the user's aggregated access rights;
        projects with empty ``access_rights`` get ``AccessRights.all()``
    """

    user_group_ids: List[int] = await _get_user_groups_ids(conn, user_id)

    # user_id is sent as a bound parameter rather than formatted into the SQL
    # text, so its value can never change the statement itself.
    smt = text("""\
    SELECT uuid, access_rights
    FROM projects
    WHERE (
        prj_owner = :user_id
        OR jsonb_exists_any( access_rights, (
               SELECT ARRAY( SELECT gid::TEXT FROM user_to_groups WHERE uid = :user_id )
            )
        )
    )
    """).bindparams(user_id=user_id)
    projects_access_rights = {}

    async for row in conn.execute(smt):
        assert isinstance(row.access_rights, dict)
        assert isinstance(row.uuid, ProjectID)

        if row.access_rights:
            # TODO: access_rights should be directly filtered from result in stm instead calling again user_group_ids
            projects_access_rights[row.uuid] = _aggregate_access_rights(
                row.access_rights, user_group_ids)

        else:
            # backwards compatibility
            # - no access_rights defined BUT project is owned
            projects_access_rights[row.uuid] = AccessRights.all()

    return projects_access_rights
コード例 #6
0
async def _list_agents_by_sgroup(
    db_conn: SAConnection,
    sgroup_name: str,
) -> Sequence[AgentContext]:
    """
    Return the ALIVE, schedulable agents that belong to ``sgroup_name``.

    Each ``AgentContext`` is built positionally from the agent's id, addr,
    scaling_group, available_slots and occupied_slots.
    """
    selected_columns = [
        agents.c.id,
        agents.c.addr,
        agents.c.scaling_group,
        agents.c.available_slots,
        agents.c.occupied_slots,
    ]
    filter_clause = (
        (agents.c.status == AgentStatus.ALIVE)
        & (agents.c.scaling_group == sgroup_name)
        & (agents.c.schedulable == true())
    )
    query = sa.select(selected_columns).select_from(agents).where(filter_clause)
    return [
        AgentContext(
            record['id'],
            record['addr'],
            record['scaling_group'],
            record['available_slots'],
            record['occupied_slots'],
        )
        async for record in db_conn.execute(query)
    ]
コード例 #7
0
async def _list_existing_sessions(
    db_conn: SAConnection,
    sgroup: str,
) -> List[ExistingSession]:
    """
    List the sessions in scaling group ``sgroup`` whose kernels are in a
    resource-occupying status.

    Kernel rows are read in ``created_at`` order and grouped by
    ``session_id``; each session collects its kernels and the sum of their
    ``occupied_slots``.

    :param db_conn: DB connection
    :param sgroup: name of the scaling group to inspect
    :return: one ``ExistingSession`` per session id, in first-seen order
    """
    query = (
        sa.select([
            kernels.c.id,
            kernels.c.status,
            kernels.c.image,
            kernels.c.cluster_mode,
            kernels.c.cluster_size,
            kernels.c.cluster_role,
            kernels.c.cluster_idx,
            kernels.c.cluster_hostname,
            kernels.c.registry,
            kernels.c.session_id,
            kernels.c.session_type,
            kernels.c.session_name,
            kernels.c.access_key,
            kernels.c.domain_name,
            kernels.c.group_id,
            kernels.c.scaling_group,
            kernels.c.occupied_slots,
            kernels.c.resource_opts,
            kernels.c.environ,
            kernels.c.mounts,
            kernels.c.mount_map,
            kernels.c.startup_command,
            kernels.c.internal_data,
            keypairs.c.resource_policy,
        ])
        .select_from(sa.join(
            kernels, keypairs,
            keypairs.c.access_key == kernels.c.access_key
        ))
        .where(
            (kernels.c.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES)) &
            (kernels.c.scaling_group == sgroup)
        )
        .order_by(kernels.c.created_at)
    )
    items: MutableMapping[str, ExistingSession] = {}
    async for row in db_conn.execute(query):
        # Use an explicit None check so a session object is never recreated
        # (and its collected kernels lost) merely because it evaluates falsy.
        session = items.get(row['session_id'])
        if session is None:
            session = ExistingSession(
                kernels=[],
                access_key=row['access_key'],
                session_id=row['session_id'],
                session_type=row['session_type'],
                session_name=row['session_name'],
                cluster_mode=row['cluster_mode'],
                cluster_size=row['cluster_size'],
                domain_name=row['domain_name'],
                group_id=row['group_id'],
                scaling_group=row['scaling_group'],
                occupying_slots=ResourceSlot(),
            )
            items[row['session_id']] = session
        # TODO: support multi-container sessions
        session.kernels.append(KernelInfo(  # type: ignore
            kernel_id=row['id'],
            session_id=row['session_id'],
            access_key=row['access_key'],
            cluster_role=row['cluster_role'],
            cluster_idx=row['cluster_idx'],
            cluster_hostname=row['cluster_hostname'],
            image_ref=ImageRef(row['image'], [row['registry']]),
            bootstrap_script=None,
            startup_command=None,
            resource_opts=row['resource_opts'],
            requested_slots=row['occupied_slots'],
        ))
        session.occupying_slots += row['occupied_slots']  # type: ignore
    # Hand back the grouped sessions as the annotated List.
    return list(items.values())
コード例 #8
0
async def _list_pending_sessions(
    db_conn: SAConnection,
    sgroup_name: str,
) -> List[PendingSession]:
    """
    List PENDING sessions targeting ``sgroup_name`` or no scaling group.

    Kernel rows are read in ``created_at`` order and grouped by
    ``session_id``. Session-level fields (environ, mounts, scripts, ports,
    ...) are taken from the first kernel row seen for that session; requested
    slots are summed and resource options merged across all its kernels.

    :param db_conn: DB connection
    :param sgroup_name: scaling group name to match
    :return: one ``PendingSession`` per session id, in first-seen order
    """
    query = (
        sa.select([
            kernels.c.id,
            kernels.c.status,
            kernels.c.image,
            kernels.c.cluster_mode,
            kernels.c.cluster_size,
            kernels.c.cluster_role,
            kernels.c.cluster_idx,
            kernels.c.cluster_hostname,
            kernels.c.registry,
            kernels.c.session_id,
            kernels.c.session_type,
            kernels.c.session_name,
            kernels.c.access_key,
            kernels.c.domain_name,
            kernels.c.group_id,
            kernels.c.scaling_group,
            kernels.c.occupied_slots,
            kernels.c.resource_opts,
            kernels.c.environ,
            kernels.c.mounts,
            kernels.c.mount_map,
            kernels.c.bootstrap_script,
            kernels.c.startup_command,
            kernels.c.internal_data,
            kernels.c.preopen_ports,
            keypairs.c.resource_policy,
        ])
        .select_from(sa.join(
            kernels, keypairs,
            keypairs.c.access_key == kernels.c.access_key
        ))
        .where(
            (kernels.c.status == KernelStatus.PENDING) &
            (
                (kernels.c.scaling_group == sgroup_name) |
                (kernels.c.scaling_group.is_(None))
            )
        )
        .order_by(kernels.c.created_at)
    )
    # TODO: extend for multi-container sessions
    items: MutableMapping[str, PendingSession] = {}
    async for row in db_conn.execute(query):
        # Explicit None check: never rebuild a session that is merely falsy.
        session = items.get(row['session_id'])
        if session is None:
            session = PendingSession(
                kernels=[],
                access_key=row['access_key'],
                session_id=row['session_id'],
                session_type=row['session_type'],
                session_name=row['session_name'],
                cluster_mode=row['cluster_mode'],
                cluster_size=row['cluster_size'],
                domain_name=row['domain_name'],
                group_id=row['group_id'],
                scaling_group=row['scaling_group'],
                resource_policy=row['resource_policy'],
                resource_opts={},
                requested_slots=ResourceSlot(),
                internal_data=row['internal_data'],
                target_sgroup_names=[],
                # "KEY=VALUE" entries; only the first '=' separates key/value.
                # A NULL environ column is treated as empty (assumption: the
                # column is nullable — confirm against the kernels schema).
                environ={
                    k: v for k, v
                    in (entry.split('=', maxsplit=1) for entry in (row['environ'] or ()))
                },
                mounts=row['mounts'],
                mount_map=row['mount_map'],
                bootstrap_script=row['bootstrap_script'],
                startup_command=row['startup_command'],
                preopen_ports=row['preopen_ports'],
            )
            items[row['session_id']] = session
        session.kernels.append(KernelInfo(
            kernel_id=row['id'],
            session_id=row['session_id'],
            access_key=row['access_key'],
            cluster_role=row['cluster_role'],
            cluster_idx=row['cluster_idx'],
            cluster_hostname=row['cluster_hostname'],
            image_ref=ImageRef(row['image'], [row['registry']]),
            bootstrap_script=row['bootstrap_script'],
            startup_command=row['startup_command'],
            resource_opts=row['resource_opts'],
            requested_slots=row['occupied_slots'],
        ))
        session.requested_slots += row['occupied_slots']  # type: ignore
        merge_resource(session.resource_opts, row['resource_opts'])  # type: ignore
    # Hand back the grouped sessions as the annotated List.
    return list(items.values())
コード例 #9
0
    async def migrate_shared_vfolders(
        cls,
        conn: SAConnection,
        deleted_user_uuid: uuid.UUID,
        target_user_uuid: uuid.UUID,
        target_user_email: str,
    ) -> int:
        """
        Migrate shared virtual folders' ownership to a target user.

        A folder owned by the deleted user is migrated when it has at least one
        ``vfolder_permissions`` entry. If its name collides with one of the
        target user's existing folders (or with a name already assigned to
        another folder migrated in this call), a random suffix is appended.

        The target user's own invitations to, and permissions on, the migrated
        folders are removed, since the target user becomes their owner.

        :param conn: DB connection
        :param deleted_user_uuid: user's UUID who will be deleted
        :param target_user_uuid: user's UUID who will get the ownership of virtual folders
        :param target_user_email: target user's email, used to match invitations

        :return: number of migrated virtual folders (rows updated)
        """
        from . import vfolders, vfolder_invitations, vfolder_permissions
        # Gather target user's virtual folders' names (a set for O(1) lookups).
        query = (
            sa.select([vfolders.c.name])
            .select_from(vfolders)
            .where(vfolders.c.user == target_user_uuid)
        )
        taken_names = {row.name async for row in conn.execute(query)}

        # Migrate shared virtual folders.
        # If virtual folder's name collides with target user's folder,
        # append random string to the name of the migrating folder.
        j = vfolder_permissions.join(
            vfolders,
            vfolder_permissions.c.vfolder == vfolders.c.id
        )
        query = (
            sa.select([vfolders.c.id, vfolders.c.name])
            .select_from(j)
            .where(vfolders.c.user == deleted_user_uuid)
        )
        migrate_updates = []
        seen_vfolder_ids = set()
        async for row in conn.execute(query):
            # The join yields one row per permission entry, so a folder shared
            # with several users shows up several times; migrate it only once.
            if row.id in seen_vfolder_ids:
                continue
            seen_vfolder_ids.add(row.id)
            name = row.name
            if name in taken_names:
                name += f'-{uuid.uuid4().hex[:10]}'
            # Reserve the final name so later migrating folders cannot reuse it.
            taken_names.add(name)
            migrate_updates.append({'vid': row.id, 'vname': name})
        if not migrate_updates:
            return 0

        # Remove invitations and vfolder_permissions from target user.
        # Target user will be the new owner, and it does not make sense to have
        # invitation and shared permission for its own folder.
        migrate_vfolder_ids = [item['vid'] for item in migrate_updates]
        query = (
            vfolder_invitations.delete()
            .where((vfolder_invitations.c.invitee == target_user_email) &
                   (vfolder_invitations.c.vfolder.in_(migrate_vfolder_ids)))
        )
        await conn.execute(query)
        query = (
            vfolder_permissions.delete()
            .where((vfolder_permissions.c.user == target_user_uuid) &
                   (vfolder_permissions.c.vfolder.in_(migrate_vfolder_ids)))
        )
        await conn.execute(query)

        rowcount = 0
        for item in migrate_updates:
            query = (
                vfolders.update()
                .values(
                    user=target_user_uuid,
                    name=item['vname'],
                )
                .where(vfolders.c.id == item['vid'])
            )
            result = await conn.execute(query)
            rowcount += result.rowcount
        if rowcount > 0:
            log.info('{0} shared folders detected. migrated to user {1}',
                     rowcount, target_user_uuid)
        return rowcount
コード例 #10
0
async def _list_pending_sessions(
    db_conn: SAConnection,
    sgroup_name: str,
) -> List[PendingSession]:
    """
    Return one ``PendingSession`` per PENDING kernel whose scaling group is
    ``sgroup_name`` or unset, ordered by kernel creation time.
    """
    wanted_columns = [
        kernels.c.id,
        kernels.c.status,
        kernels.c.image,
        kernels.c.registry,
        kernels.c.sess_type,
        kernels.c.sess_id,
        kernels.c.access_key,
        kernels.c.domain_name,
        kernels.c.group_id,
        kernels.c.scaling_group,
        kernels.c.occupied_slots,
        kernels.c.resource_opts,
        kernels.c.environ,
        kernels.c.mounts,
        kernels.c.mount_map,
        kernels.c.bootstrap_script,
        kernels.c.startup_command,
        kernels.c.internal_data,
        kernels.c.preopen_ports,
        keypairs.c.resource_policy,
    ]
    kernel_keypair_join = sa.join(
        kernels, keypairs,
        keypairs.c.access_key == kernels.c.access_key,
    )
    sgroup_matches = (
        (kernels.c.scaling_group == sgroup_name)
        | (kernels.c.scaling_group.is_(None))
    )
    query = (
        sa.select(wanted_columns)
        .select_from(kernel_keypair_join)
        .where((kernels.c.status == KernelStatus.PENDING) & sgroup_matches)
        .order_by(kernels.c.created_at)
    )

    def _row_to_pending(rec):
        # Build a PendingSession from one joined kernel/keypair row.
        return PendingSession(
            kernel_id=rec['id'],
            access_key=rec['access_key'],
            session_type=rec['sess_type'],
            session_name=rec['sess_id'],
            domain_name=rec['domain_name'],
            group_id=rec['group_id'],
            scaling_group=rec['scaling_group'],
            image_ref=ImageRef(rec['image'], [rec['registry']]),
            resource_policy=rec['resource_policy'],
            resource_opts=rec['resource_opts'],
            requested_slots=rec['occupied_slots'],
            internal_data=rec['internal_data'],
            target_sgroup_names=[],
            # "KEY=VALUE" strings; split on the first '=' only.
            environ={
                name: value
                for name, value in (pair.split('=', maxsplit=1) for pair in rec['environ'])
            },
            mounts=rec['mounts'],
            mount_map=rec['mount_map'],
            bootstrap_script=rec['bootstrap_script'],
            startup_command=rec['startup_command'],
            preopen_ports=rec['preopen_ports'],
        )

    return [_row_to_pending(rec) async for rec in db_conn.execute(query)]
コード例 #11
0
 async def _get_tags_by_project(self, conn: SAConnection,
                                project_id: str) -> List:
     """Return the ``tag_id`` of every ``study_tags`` row whose study is ``project_id``."""
     stmt = sa.select([study_tags.c.tag_id]).where(
         study_tags.c.study_id == project_id
     )
     tag_ids = []
     async for record in conn.execute(stmt):
         tag_ids.append(record.tag_id)
     return tag_ids