async def batch_multiresult(
    context: Mapping[str, Any],
    conn: SAConnection,
    query: sa.sql.Select,
    obj_type: Type[_GenericSQLBasedGQLObject],
    key_list: Iterable[_Key],
    key_getter: Callable[[RowProxy], _Key],
) -> Sequence[Sequence[_GenericSQLBasedGQLObject]]:
    """
    A batched query adaptor for (key -> [item]) resolving patterns.
    """
    objs_per_key: Dict[_Key, List[_GenericSQLBasedGQLObject]]
    objs_per_key = collections.OrderedDict()
    for key in key_list:
        objs_per_key[key] = list()
    async for row in conn.execute(query):
        objs_per_key[key_getter(row)].append(
            obj_type.from_row(context, row)
        )
    return [*objs_per_key.values()]
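# A minimal usage sketch for batch_multiresult, assuming a hypothetical
# `kernels` table and `ComputeKernel` GQL object type; real callers are
# GraphQL DataLoader batch functions with this same shape.
async def batch_load_by_access_key(context, conn, access_keys):
    query = (
        sa.select([kernels])
        .select_from(kernels)
        .where(kernels.c.access_key.in_(access_keys))
    )
    # Each input key maps to a (possibly empty) list of loaded objects,
    # in the same order as `access_keys`.
    return await batch_multiresult(
        context, conn, query, ComputeKernel,
        access_keys, lambda row: row['access_key'],
    )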
async def delete_vfolders(
    cls,
    conn: SAConnection,
    user_uuid: uuid.UUID,
    config_server,
) -> int:
    """
    Delete all of the user's virtual folders as well as their physical data.

    :param conn: DB connection
    :param user_uuid: UUID of the user whose virtual folders are deleted
    :return: number of deleted rows
    """
    from . import vfolders
    mount_prefix = Path(await config_server.get('volumes/_mount'))
    fs_prefix = await config_server.get('volumes/_fsprefix')
    fs_prefix = Path(fs_prefix.lstrip('/'))
    query = (
        sa.select([vfolders.c.id, vfolders.c.host, vfolders.c.unmanaged_path])
        .select_from(vfolders)
        .where(vfolders.c.user == user_uuid)
    )
    async for row in conn.execute(query):
        if row['unmanaged_path']:
            folder_path = Path(row['unmanaged_path'])
        else:
            folder_path = (mount_prefix / row['host'] / fs_prefix / row['id'].hex)
        log.info('deleting physical files: {0}', folder_path)
        try:
            loop = current_loop()
            # The lambda is awaited before the next iteration rebinds
            # folder_path, so late binding is not an issue here.
            await loop.run_in_executor(
                None, lambda: shutil.rmtree(folder_path))  # type: ignore
        except IOError:
            pass
    query = (
        vfolders.delete()
        .where(vfolders.c.user == user_uuid)
    )
    result = await conn.execute(query)
    if result.rowcount > 0:
        log.info('deleted {0} user\'s virtual folders ({1})',
                 result.rowcount, user_uuid)
    return result.rowcount
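# A hedged usage sketch (the `Users` owner class, `db_pool`, and
# `config_server` are assumed from the application context): running the
# cleanup inside a single transaction so the physical deletion and the row
# deletion are driven by one connection.
async def purge_user_vfolders(db_pool, user_uuid, config_server):
    async with db_pool.acquire() as conn, conn.begin():
        deleted = await Users.delete_vfolders(conn, user_uuid, config_server)
    return deleted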
async def __load_projects(
    self, conn: SAConnection, query: str, user_id: int,
    user_groups: List[RowProxy],
) -> List[Dict]:
    api_projects: List[Dict] = []  # API model-compatible projects
    db_projects: List[Dict] = []   # DB model-compatible projects
    async for row in conn.execute(query):
        try:
            _check_project_permissions(row, user_id, user_groups, "read")
        except ProjectInvalidRightsError:
            continue
        prj = dict(row.items())
        log.debug("found project: %s", prj)
        db_projects.append(prj)

    # NOTE: DO NOT nest _get_tags_by_project in the async loop above !!!
    # FIXME: temporarily avoids the inner-async-loops issue
    # https://github.com/aio-libs/aiopg/issues/535
    for db_prj in db_projects:
        db_prj["tags"] = await self._get_tags_by_project(
            conn, project_id=db_prj["id"])
        user_email = await self._get_user_email(conn, db_prj["prj_owner"])
        api_projects.append(_convert_to_schema_names(db_prj, user_email))

    return api_projects
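# The NOTE above exists because one aiopg connection cannot interleave two
# active result cursors (see the linked aiopg issue). A minimal sketch of
# the safe shape used here (`child_query_for` is a hypothetical helper):
# materialize the outer result fully, then do per-row queries in a second
# pass over plain dicts.
async def fetch_parents_with_children(conn, parent_query, child_query_for):
    parents = [dict(row.items()) async for row in conn.execute(parent_query)]
    for parent in parents:  # the outer cursor is already exhausted here
        parent["children"] = [
            dict(row.items())
            async for row in conn.execute(child_query_for(parent["id"]))
        ]
    return parents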
async def group_vfolder_mounted_to_active_kernels(
    cls,
    conn: SAConnection,
    group_id: uuid.UUID,
) -> bool:
    """
    Check if any active kernel is using the group's virtual folders.

    :param conn: DB connection
    :param group_id: group's ID
    :return: True if a virtual folder is mounted to active kernels
    """
    from . import kernels, vfolders, AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES
    query = (
        sa.select([vfolders.c.id])
        .select_from(vfolders)
        .where(vfolders.c.group == group_id)
    )
    result = await conn.execute(query)
    rows = await result.fetchall()
    group_vfolder_ids = [row.id for row in rows]
    query = (
        sa.select([kernels.c.mounts])
        .select_from(kernels)
        .where((kernels.c.group_id == group_id) &
               (kernels.c.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES)))
    )
    async for row in conn.execute(query):
        for _mount in row['mounts']:
            try:
                vfolder_id = uuid.UUID(_mount[2])
                if vfolder_id in group_vfolder_ids:
                    return True
            except Exception:
                pass
    return False
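# A hedged usage sketch (the `Groups` owner class and the caller's error
# handling are assumptions): using the check above as a guard before
# tearing down a group and its storage.
async def ensure_group_vfolders_unmounted(conn, group_id):
    if await Groups.group_vfolder_mounted_to_active_kernels(conn, group_id):
        raise RuntimeError(
            f'group {group_id} has virtual folders mounted to active kernels')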
async def list_projects_access_rights(
    conn: SAConnection, user_id: int,
) -> Dict[ProjectID, AccessRights]:
    """
    Returns access rights of a user (user_id) over all OWNED or SHARED projects.
    """
    user_group_ids: List[int] = await _get_user_groups_ids(conn, user_id)

    smt = text(f"""\
    SELECT uuid, access_rights
    FROM projects
    WHERE (
        prj_owner = {user_id}
        OR jsonb_exists_any(
            access_rights,
            (SELECT ARRAY(SELECT gid::TEXT FROM user_to_groups WHERE uid = {user_id}))
        )
    )
    """)
    projects_access_rights = {}

    async for row in conn.execute(smt):
        assert isinstance(row.access_rights, dict)
        assert isinstance(row.uuid, ProjectID)

        if row.access_rights:
            # TODO: access_rights should be directly filtered from the result
            # in the stmt instead of calling _get_user_groups_ids again
            projects_access_rights[row.uuid] = _aggregate_access_rights(
                row.access_rights, user_group_ids)
        else:
            # backwards compatibility:
            # no access_rights defined BUT the project is owned
            projects_access_rights[row.uuid] = AccessRights.all()

    return projects_access_rights
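# The statement above interpolates user_id into the SQL text via an
# f-string; user_id is typed as int, which bounds the injection risk, but
# the same query can be written with bound parameters. A minimal sketch
# (same tables, hypothetical helper name):
def _projects_access_stmt(user_id: int):
    return text("""\
        SELECT uuid, access_rights
        FROM projects
        WHERE (
            prj_owner = :uid
            OR jsonb_exists_any(
                access_rights,
                (SELECT ARRAY(SELECT gid::TEXT FROM user_to_groups WHERE uid = :uid))
            )
        )
    """).bindparams(uid=user_id)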
async def _list_agents_by_sgroup(
    db_conn: SAConnection,
    sgroup_name: str,
) -> Sequence[AgentContext]:
    query = (
        sa.select([
            agents.c.id,
            agents.c.addr,
            agents.c.scaling_group,
            agents.c.available_slots,
            agents.c.occupied_slots,
        ])
        .select_from(agents)
        .where(
            (agents.c.status == AgentStatus.ALIVE) &
            (agents.c.scaling_group == sgroup_name) &
            (agents.c.schedulable == true())
        )
    )
    items = []
    async for row in db_conn.execute(query):
        item = AgentContext(
            row['id'],
            row['addr'],
            row['scaling_group'],
            row['available_slots'],
            row['occupied_slots'],
        )
        items.append(item)
    return items
async def _list_existing_sessions(
    db_conn: SAConnection,
    sgroup: str,
) -> List[ExistingSession]:
    query = (
        sa.select([
            kernels.c.id,
            kernels.c.status,
            kernels.c.image,
            kernels.c.cluster_mode,
            kernels.c.cluster_size,
            kernels.c.cluster_role,
            kernels.c.cluster_idx,
            kernels.c.cluster_hostname,
            kernels.c.registry,
            kernels.c.session_id,
            kernels.c.session_type,
            kernels.c.session_name,
            kernels.c.access_key,
            kernels.c.domain_name,
            kernels.c.group_id,
            kernels.c.scaling_group,
            kernels.c.occupied_slots,
            kernels.c.resource_opts,
            kernels.c.environ,
            kernels.c.mounts,
            kernels.c.mount_map,
            kernels.c.startup_command,
            kernels.c.internal_data,
            keypairs.c.resource_policy,
        ])
        .select_from(sa.join(
            kernels, keypairs,
            keypairs.c.access_key == kernels.c.access_key
        ))
        .where(
            (kernels.c.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES)) &
            (kernels.c.scaling_group == sgroup)
        )
        .order_by(kernels.c.created_at)
    )
    items: MutableMapping[str, ExistingSession] = {}
    async for row in db_conn.execute(query):
        if _session := items.get(row['session_id']):
            session = _session
        else:
            session = ExistingSession(
                kernels=[],
                access_key=row['access_key'],
                session_id=row['session_id'],
                session_type=row['session_type'],
                session_name=row['session_name'],
                cluster_mode=row['cluster_mode'],
                cluster_size=row['cluster_size'],
                domain_name=row['domain_name'],
                group_id=row['group_id'],
                scaling_group=row['scaling_group'],
                occupying_slots=ResourceSlot(),
            )
            items[row['session_id']] = session
        # TODO: support multi-container sessions
        session.kernels.append(KernelInfo(  # type: ignore
            kernel_id=row['id'],
            session_id=row['session_id'],
            access_key=row['access_key'],
            cluster_role=row['cluster_role'],
            cluster_idx=row['cluster_idx'],
            cluster_hostname=row['cluster_hostname'],
            image_ref=ImageRef(row['image'], [row['registry']]),
            bootstrap_script=None,
            startup_command=None,
            resource_opts=row['resource_opts'],
            requested_slots=row['occupied_slots'],
        ))
        session.occupying_slots += row['occupied_slots']  # type: ignore
    return list(items.values())
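# The `items.get(...)` walrus pattern above is a row-grouping idiom: kernel
# rows arrive ordered by created_at and are folded into one session object
# per session_id. A tiny standalone demo of the same shape:
def _group_rows(rows):
    grouped = {}
    for row in rows:
        grouped.setdefault(row['session_id'], []).append(row['kernel_id'])
    return grouped

assert _group_rows([
    {'session_id': 's1', 'kernel_id': 'k1'},
    {'session_id': 's2', 'kernel_id': 'k2'},
    {'session_id': 's1', 'kernel_id': 'k3'},
]) == {'s1': ['k1', 'k3'], 's2': ['k2']}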
async def _list_pending_sessions(
    db_conn: SAConnection,
    sgroup_name: str,
) -> List[PendingSession]:
    query = (
        sa.select([
            kernels.c.id,
            kernels.c.status,
            kernels.c.image,
            kernels.c.cluster_mode,
            kernels.c.cluster_size,
            kernels.c.cluster_role,
            kernels.c.cluster_idx,
            kernels.c.cluster_hostname,
            kernels.c.registry,
            kernels.c.session_id,
            kernels.c.session_type,
            kernels.c.session_name,
            kernels.c.access_key,
            kernels.c.domain_name,
            kernels.c.group_id,
            kernels.c.scaling_group,
            kernels.c.occupied_slots,
            kernels.c.resource_opts,
            kernels.c.environ,
            kernels.c.mounts,
            kernels.c.mount_map,
            kernels.c.bootstrap_script,
            kernels.c.startup_command,
            kernels.c.internal_data,
            kernels.c.preopen_ports,
            keypairs.c.resource_policy,
        ])
        .select_from(sa.join(
            kernels, keypairs,
            keypairs.c.access_key == kernels.c.access_key
        ))
        .where(
            (kernels.c.status == KernelStatus.PENDING) &
            (
                (kernels.c.scaling_group == sgroup_name) |
                (kernels.c.scaling_group.is_(None))
            )
        )
        .order_by(kernels.c.created_at)
    )
    # TODO: extend for multi-container sessions
    items: MutableMapping[str, PendingSession] = {}
    async for row in db_conn.execute(query):
        if _session := items.get(row['session_id']):
            session = _session
        else:
            session = PendingSession(
                kernels=[],
                access_key=row['access_key'],
                session_id=row['session_id'],
                session_type=row['session_type'],
                session_name=row['session_name'],
                cluster_mode=row['cluster_mode'],
                cluster_size=row['cluster_size'],
                domain_name=row['domain_name'],
                group_id=row['group_id'],
                scaling_group=row['scaling_group'],
                resource_policy=row['resource_policy'],
                resource_opts={},
                requested_slots=ResourceSlot(),
                internal_data=row['internal_data'],
                target_sgroup_names=[],
                environ={
                    k: v for k, v in
                    map(lambda s: s.split('=', maxsplit=1), row['environ'])
                },
                mounts=row['mounts'],
                mount_map=row['mount_map'],
                bootstrap_script=row['bootstrap_script'],
                startup_command=row['startup_command'],
                preopen_ports=row['preopen_ports'],
            )
            items[row['session_id']] = session
        session.kernels.append(KernelInfo(
            kernel_id=row['id'],
            session_id=row['session_id'],
            access_key=row['access_key'],
            cluster_role=row['cluster_role'],
            cluster_idx=row['cluster_idx'],
            cluster_hostname=row['cluster_hostname'],
            image_ref=ImageRef(row['image'], [row['registry']]),
            bootstrap_script=row['bootstrap_script'],
            startup_command=row['startup_command'],
            resource_opts=row['resource_opts'],
            requested_slots=row['occupied_slots'],
        ))
        session.requested_slots += row['occupied_slots']  # type: ignore
        merge_resource(session.resource_opts, row['resource_opts'])  # type: ignore
    return list(items.values())
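# A tiny demo of the environ parsing used above: each 'KEY=VALUE' string is
# split only once from the left, so '=' characters inside the value survive.
def _parse_environ(pairs):
    return {k: v for k, v in (s.split('=', maxsplit=1) for s in pairs)}

assert _parse_environ(['PATH=/usr/bin', 'OPTS=a=b']) == \
    {'PATH': '/usr/bin', 'OPTS': 'a=b'}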
async def migrate_shared_vfolders(
    cls,
    conn: SAConnection,
    deleted_user_uuid: uuid.UUID,
    target_user_uuid: uuid.UUID,
    target_user_email: str,
) -> int:
    """
    Migrate shared virtual folders' ownership to a target user.

    If a migrating virtual folder's name collides with one of the target
    user's existing folders, a random string is appended to the migrating
    folder's name.

    :param conn: DB connection
    :param deleted_user_uuid: UUID of the user who will be deleted
    :param target_user_uuid: UUID of the user who will receive ownership
                             of the virtual folders
    :return: number of migrated rows
    """
    from . import vfolders, vfolder_invitations, vfolder_permissions
    # Gather the target user's existing virtual folder names.
    query = (
        sa.select([vfolders.c.name])
        .select_from(vfolders)
        .where(vfolders.c.user == target_user_uuid)
    )
    existing_vfolder_names = [row.name async for row in conn.execute(query)]
    # Migrate shared virtual folders.
    # If a virtual folder's name collides with one of the target user's
    # folders, append a random string to the migrating folder's name.
    j = vfolder_permissions.join(
        vfolders,
        vfolder_permissions.c.vfolder == vfolders.c.id
    )
    query = (
        sa.select([vfolders.c.id, vfolders.c.name])
        .select_from(j)
        .where(vfolders.c.user == deleted_user_uuid)
    )
    migrate_updates = []
    async for row in conn.execute(query):
        name = row.name
        if name in existing_vfolder_names:
            name += f'-{uuid.uuid4().hex[:10]}'
        migrate_updates.append({'vid': row.id, 'vname': name})
    if migrate_updates:
        # Remove invitations and vfolder_permissions from the target user.
        # The target user becomes the new owner, and it does not make sense
        # to keep invitations and shared permissions for one's own folders.
        migrate_vfolder_ids = [item['vid'] for item in migrate_updates]
        query = (
            vfolder_invitations.delete()
            .where((vfolder_invitations.c.invitee == target_user_email) &
                   (vfolder_invitations.c.vfolder.in_(migrate_vfolder_ids)))
        )
        await conn.execute(query)
        query = (
            vfolder_permissions.delete()
            .where((vfolder_permissions.c.user == target_user_uuid) &
                   (vfolder_permissions.c.vfolder.in_(migrate_vfolder_ids)))
        )
        await conn.execute(query)

        rowcount = 0
        for item in migrate_updates:
            query = (
                vfolders.update()
                .values(
                    user=target_user_uuid,
                    name=item['vname'],
                )
                .where(vfolders.c.id == item['vid'])
            )
            result = await conn.execute(query)
            rowcount += result.rowcount
        if rowcount > 0:
            log.info('{0} shared folders detected. migrated to user {1}',
                     rowcount, target_user_uuid)
        return rowcount
    else:
        return 0
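# A hedged sketch of how the two vfolder helpers above compose during user
# deletion (the `Users` owner class and argument sources are assumptions):
# shared folders are handed over first, then the remaining ones are deleted
# along with their physical data.
async def hand_over_and_purge(conn, config_server, user_uuid,
                              admin_uuid, admin_email):
    migrated = await Users.migrate_shared_vfolders(
        conn,
        deleted_user_uuid=user_uuid,
        target_user_uuid=admin_uuid,
        target_user_email=admin_email,
    )
    deleted = await Users.delete_vfolders(conn, user_uuid, config_server)
    return migrated, deleted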
async def _list_pending_sessions(
    db_conn: SAConnection,
    sgroup_name: str,
) -> List[PendingSession]:
    query = (
        sa.select([
            kernels.c.id,
            kernels.c.status,
            kernels.c.image,
            kernels.c.registry,
            kernels.c.sess_type,
            kernels.c.sess_id,
            kernels.c.access_key,
            kernels.c.domain_name,
            kernels.c.group_id,
            kernels.c.scaling_group,
            kernels.c.occupied_slots,
            kernels.c.resource_opts,
            kernels.c.environ,
            kernels.c.mounts,
            kernels.c.mount_map,
            kernels.c.bootstrap_script,
            kernels.c.startup_command,
            kernels.c.internal_data,
            kernels.c.preopen_ports,
            keypairs.c.resource_policy,
        ])
        .select_from(sa.join(
            kernels, keypairs,
            keypairs.c.access_key == kernels.c.access_key
        ))
        .where(
            (kernels.c.status == KernelStatus.PENDING) &
            (
                (kernels.c.scaling_group == sgroup_name) |
                (kernels.c.scaling_group.is_(None))
            )
        )
        .order_by(kernels.c.created_at)
    )
    items = []
    async for row in db_conn.execute(query):
        items.append(PendingSession(
            kernel_id=row['id'],
            access_key=row['access_key'],
            session_type=row['sess_type'],
            session_name=row['sess_id'],
            domain_name=row['domain_name'],
            group_id=row['group_id'],
            scaling_group=row['scaling_group'],
            image_ref=ImageRef(row['image'], [row['registry']]),
            resource_policy=row['resource_policy'],
            resource_opts=row['resource_opts'],
            requested_slots=row['occupied_slots'],
            internal_data=row['internal_data'],
            target_sgroup_names=[],
            environ={
                k: v for k, v in
                map(lambda s: s.split('=', maxsplit=1), row['environ'])
            },
            mounts=row['mounts'],
            mount_map=row['mount_map'],
            bootstrap_script=row['bootstrap_script'],
            startup_command=row['startup_command'],
            preopen_ports=row['preopen_ports'],
        ))
    return items
async def _get_tags_by_project(self, conn: SAConnection, project_id: str) -> List:
    query = sa.select([study_tags.c.tag_id]).where(
        study_tags.c.study_id == project_id)
    return [row.tag_id async for row in conn.execute(query)]
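# The async list comprehension above streams rows from aiopg's result
# proxy; it is equivalent to this explicit loop (a sketch using the same
# table, with a hypothetical standalone name):
async def _get_tags_by_project_explicit(conn, project_id):
    tag_ids = []
    query = sa.select([study_tags.c.tag_id]).where(
        study_tags.c.study_id == project_id)
    async for row in conn.execute(query):
        tag_ids.append(row.tag_id)
    return tag_ids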