async def reenter_txn(pool: SAEngine, conn: SAConnection):
    if conn is None:
        async with pool.acquire() as conn, conn.begin():
            yield conn
    else:
        async with conn.begin_nested():
            yield conn
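# Usage sketch (assumption, not part of the original source): reenter_txn is an
# async generator, so it must be wrapped as an async context manager before use.
# The original project likely applies its own decorator; contextlib's standard
# asynccontextmanager is used here as a stand-in, and `pool` / `outer_conn` are
# hypothetical objects supplied by the caller.
import contextlib

reenter_txn_cm = contextlib.asynccontextmanager(reenter_txn)

async def do_work(pool, outer_conn=None):
    # With outer_conn=None a fresh connection and top-level transaction are opened;
    # otherwise the work joins the caller's transaction through a nested SAVEPOINT.
    async with reenter_txn_cm(pool, outer_conn) as conn:
        await conn.execute("SELECT 1")  # placeholder for real statements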
async def _list_agents_by_sgroup(
    db_conn: SAConnection,
    sgroup_name: str,
) -> Sequence[AgentContext]:
    query = (
        sa.select([
            agents.c.id,
            agents.c.addr,
            agents.c.scaling_group,
            agents.c.available_slots,
            agents.c.occupied_slots,
        ], for_update=True)
        .select_from(agents)
        .where(
            (agents.c.status == AgentStatus.ALIVE) &
            (agents.c.scaling_group == sgroup_name) &
            (agents.c.schedulable == true())
        )
    )
    items = []
    async for row in db_conn.execute(query):
        item = AgentContext(
            row['id'],
            row['addr'],
            row['scaling_group'],
            row['available_slots'],
            row['occupied_slots'],
        )
        items.append(item)
    return items
async def group_vfolder_mounted_to_active_kernels(
    cls,
    conn: SAConnection,
    group_id: uuid.UUID,
) -> bool:
    """
    Check if any active kernel is using the group's virtual folders.

    :param conn: DB connection
    :param group_id: group's ID

    :return: True if a virtual folder is mounted to active kernels.
    """
    from . import kernels, vfolders, AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES
    query = (
        sa.select([vfolders.c.id])
        .select_from(vfolders)
        .where(vfolders.c.group == group_id)
    )
    result = await conn.execute(query)
    rows = await result.fetchall()
    group_vfolder_ids = [row.id for row in rows]
    query = (
        sa.select([kernels.c.mounts])
        .select_from(kernels)
        .where(
            (kernels.c.group_id == group_id) &
            (kernels.c.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES))
        )
    )
    async for row in conn.execute(query):
        for _mount in row['mounts']:
            try:
                vfolder_id = uuid.UUID(_mount[2])
            except Exception:
                # Skip malformed mount entries instead of reusing a stale vfolder_id.
                continue
            if vfolder_id in group_vfolder_ids:
                return True
    return False
async def __load_user_groups(self, conn: SAConnection, user_id: str) -> List[str]:
    user_groups: List[str] = []
    query = select([user_to_groups.c.gid]).where(user_to_groups.c.uid == user_id)
    async for row in conn.execute(query):
        user_groups.append(row[user_to_groups.c.gid])
    return user_groups
async def __load_user_groups(self, conn: SAConnection, user_id: int) -> List[RowProxy]:
    user_groups: List[RowProxy] = []
    query = (
        select([groups])
        .select_from(groups.join(user_to_groups))
        .where(user_to_groups.c.uid == user_id)
    )
    async for row in conn.execute(query):
        user_groups.append(row)
    return user_groups
async def __load_projects(
    self,
    conn: SAConnection,
    query: str,
    user_id: int,
    user_groups: List[RowProxy],
    filter_by_services: Optional[List[Dict]] = None,
) -> Tuple[List[Dict[str, Any]], List[ProjectType]]:
    api_projects: List[Dict] = []  # API model-compatible projects
    db_projects: List[Dict] = []   # DB model-compatible projects
    project_types: List[ProjectType] = []
    async for row in conn.execute(query):
        try:
            _check_project_permissions(row, user_id, user_groups, "read")
            await asyncio.get_event_loop().run_in_executor(
                None, ProjectAtDB.from_orm, row)
        except ProjectInvalidRightsError:
            continue
        except ValidationError as exc:
            log.warning(
                "project %s failed validation, please check. error: %s",
                f"{row.id=}",
                exc,
            )
            continue
        prj = dict(row.items())
        if filter_by_services:
            if not await project_uses_available_services(prj, filter_by_services):
                log.warning(
                    "Project %s will not be listed for user %s since it has no access rights"
                    " for one or more of the services that it includes.",
                    f"{row.id=}",
                    f"{user_id=}",
                )
                continue
        db_projects.append(prj)

    # NOTE: DO NOT nest _get_tags_by_project in async loop above !!!
    # FIXME: temporary avoids inner async loops issue https://github.com/aio-libs/aiopg/issues/535
    for db_prj in db_projects:
        db_prj["tags"] = await self._get_tags_by_project(conn, project_id=db_prj["id"])
        user_email = await self._get_user_email(conn, db_prj["prj_owner"])
        api_projects.append(_convert_to_schema_names(db_prj, user_email))
        project_types.append(db_prj["type"])

    return (api_projects, project_types)
async def _list_existing_sessions(
    db_conn: SAConnection,
    sgroup: str,
) -> List[ExistingSession]:
    query = (
        sa.select([
            kernels.c.id,
            kernels.c.status,
            kernels.c.image,
            kernels.c.registry,
            kernels.c.sess_type,
            kernels.c.sess_id,
            kernels.c.access_key,
            kernels.c.domain_name,
            kernels.c.group_id,
            kernels.c.scaling_group,
            kernels.c.occupied_slots,
            kernels.c.resource_opts,
            kernels.c.environ,
            kernels.c.mounts,
            kernels.c.mount_map,
            kernels.c.startup_command,
            kernels.c.internal_data,
            keypairs.c.resource_policy,
        ])
        .select_from(sa.join(
            kernels, keypairs,
            keypairs.c.access_key == kernels.c.access_key
        ))
        .where(
            (kernels.c.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES)) &
            (kernels.c.scaling_group == sgroup)
        )
        .order_by(kernels.c.created_at)
    )
    items = []
    async for row in db_conn.execute(query):
        items.append(ExistingSession(
            kernel_id=row['id'],
            access_key=row['access_key'],
            session_type=row['sess_type'],
            session_name=row['sess_id'],
            domain_name=row['domain_name'],
            group_id=row['group_id'],
            scaling_group=row['scaling_group'],
            image_ref=ImageRef(row['image'], [row['registry']]),
            occupying_slots=row['occupied_slots'],
        ))
    return items
async def _clusters_from_cluster_ids(
    conn: connection.SAConnection,
    cluster_ids: Iterable[PositiveInt],
    offset: int = 0,
    limit: Optional[int] = None,
) -> List[Cluster]:
    cluster_id_to_cluster: Dict[PositiveInt, Cluster] = {}
    async for row in conn.execute(
        sa.select([
            clusters,
            cluster_to_groups.c.gid,
            cluster_to_groups.c.read,
            cluster_to_groups.c.write,
            cluster_to_groups.c.delete,
        ])
        .select_from(
            clusters.join(
                cluster_to_groups,
                clusters.c.id == cluster_to_groups.c.cluster_id,
            )
        )
        .where(clusters.c.id.in_(cluster_ids))
        .offset(offset)
        .limit(limit)
    ):
        cluster_access_rights = {
            row[cluster_to_groups.c.gid]: ClusterAccessRights(
                **{
                    "read": row[cluster_to_groups.c.read],
                    "write": row[cluster_to_groups.c.write],
                    "delete": row[cluster_to_groups.c.delete],
                }
            )
        }
        cluster_id = row[clusters.c.id]
        if cluster_id not in cluster_id_to_cluster:
            cluster_id_to_cluster[cluster_id] = Cluster(
                id=cluster_id,
                name=row[clusters.c.name],
                description=row[clusters.c.description],
                type=row[clusters.c.type],
                owner=row[clusters.c.owner],
                endpoint=row[clusters.c.endpoint],
                authentication=row[clusters.c.authentication],
                thumbnail=row[clusters.c.thumbnail],
                access_rights=cluster_access_rights,
            )
        else:
            cluster_id_to_cluster[cluster_id].access_rights.update(
                cluster_access_rights)
    return list(cluster_id_to_cluster.values())
async def __load_projects(self, conn: SAConnection, query) -> List[Dict]:
    api_projects: List[Dict] = []  # API model-compatible projects
    db_projects: List[Dict] = []   # DB model-compatible projects
    async for row in conn.execute(query):
        prj = dict(row.items())
        log.debug("found project: %s", prj)
        db_projects.append(prj)

    # NOTE: DO NOT nest _get_tags_by_project in async loop above !!!
    # FIXME: temporary avoids inner async loops issue https://github.com/aio-libs/aiopg/issues/535
    for db_prj in db_projects:
        db_prj["tags"] = await self._get_tags_by_project(conn, project_id=db_prj["id"])
        user_email = await self._get_user_email(conn, db_prj["prj_owner"])
        api_projects.append(_convert_to_schema_names(db_prj, user_email))

    return api_projects
async def batch_multiresult(
    context: Mapping[str, Any],
    conn: SAConnection,
    query: sa.sql.Select,
    obj_type: Type[_GenericSQLBasedGQLObject],
    key_list: Iterable[_Key],
    key_getter: Callable[[RowProxy], _Key],
) -> Sequence[Sequence[_GenericSQLBasedGQLObject]]:
    """
    A batched query adaptor for (key -> [item]) resolving patterns.
    """
    objs_per_key: Dict[_Key, List[_GenericSQLBasedGQLObject]]
    objs_per_key = collections.OrderedDict()
    for key in key_list:
        objs_per_key[key] = list()
    async for row in conn.execute(query):
        objs_per_key[key_getter(row)].append(obj_type.from_row(context, row))
    return [*objs_per_key.values()]
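# Usage sketch (assumption, not from the original source): how a dataloader-style
# GraphQL resolver could call batch_multiresult to group kernel rows per access
# key. `KernelGQLObject` is a hypothetical GQL type exposing a from_row()
# classmethod; `kernels` and the `access_keys` batch come from the caller.
async def batch_load_kernels(context, conn, access_keys):
    query = (
        sa.select([kernels])
        .where(kernels.c.access_key.in_(access_keys))
        .order_by(kernels.c.created_at)
    )
    # Returns one list of KernelGQLObject per requested access key, preserving
    # the order of access_keys and yielding empty lists for keys with no rows.
    return await batch_multiresult(
        context, conn, query, KernelGQLObject,
        access_keys, lambda row: row['access_key'],
    )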
async def delete_vfolders(
    cls,
    conn: SAConnection,
    user_uuid: uuid.UUID,
    config_server,
) -> int:
    """
    Delete all of the user's virtual folders as well as their physical data.

    :param conn: DB connection
    :param user_uuid: user's UUID to delete virtual folders

    :return: number of deleted rows
    """
    from . import vfolders
    mount_prefix = Path(await config_server.get('volumes/_mount'))
    fs_prefix = await config_server.get('volumes/_fsprefix')
    fs_prefix = Path(fs_prefix.lstrip('/'))
    query = (
        sa.select([vfolders.c.id, vfolders.c.host, vfolders.c.unmanaged_path])
        .select_from(vfolders)
        .where(vfolders.c.user == user_uuid)
    )
    async for row in conn.execute(query):
        if row['unmanaged_path']:
            folder_path = Path(row['unmanaged_path'])
        else:
            folder_path = (mount_prefix / row['host'] / fs_prefix / row['id'].hex)
        log.info('deleting physical files: {0}', folder_path)
        try:
            loop = current_loop()
            await loop.run_in_executor(None, lambda: shutil.rmtree(folder_path))  # type: ignore
        except IOError:
            pass
    query = (
        vfolders.delete()
        .where(vfolders.c.user == user_uuid)
    )
    result = await conn.execute(query)
    if result.rowcount > 0:
        log.info('deleted {0} user\'s virtual folders ({1})', result.rowcount, user_uuid)
    return result.rowcount
async def list_projects_access_rights(
    conn: SAConnection, user_id: int
) -> Dict[ProjectID, AccessRights]:
    """
    Returns access-rights of user (user_id) over all OWNED or SHARED projects
    """

    user_group_ids: List[int] = await _get_user_groups_ids(conn, user_id)

    smt = text(
        f"""\
    SELECT uuid, access_rights
    FROM projects
    WHERE (
        prj_owner = {user_id}
        OR jsonb_exists_any( access_rights, (
            SELECT ARRAY( SELECT gid::TEXT
            FROM user_to_groups
            WHERE uid = {user_id}
            )
        )
        )
    )
    """
    )
    projects_access_rights = {}

    async for row in conn.execute(smt):
        assert isinstance(row.access_rights, dict)
        assert isinstance(row.uuid, ProjectID)

        if row.access_rights:
            # TODO: access_rights should be directly filtered from the result in the stm
            # instead of calling user_group_ids again
            projects_access_rights[row.uuid] = _aggregate_access_rights(
                row.access_rights, user_group_ids
            )
        else:
            # backwards compatibility
            # - no access_rights defined BUT project is owned
            projects_access_rights[row.uuid] = AccessRights.all()

    return projects_access_rights
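# Hedged alternative (not part of the original source): the raw SQL above splices
# user_id into the statement with an f-string. Since user_id is an int the
# injection risk is limited, but the same query can be written with a bound
# parameter, which also avoids re-compiling a new statement per user. This is a
# sketch under the assumption that sqlalchemy.text() bind parameters are handled
# by the aiopg.sa connection in use; `_projects_access_rights_stmt` is a
# hypothetical helper name.
from sqlalchemy import text

def _projects_access_rights_stmt(user_id: int):
    # Same query as above, with :user_id bound instead of interpolated.
    return text(
        """\
        SELECT uuid, access_rights
        FROM projects
        WHERE (
            prj_owner = :user_id
            OR jsonb_exists_any(
                access_rights,
                (SELECT ARRAY(SELECT gid::TEXT FROM user_to_groups WHERE uid = :user_id))
            )
        )
        """
    ).bindparams(user_id=user_id)

# usage inside list_projects_access_rights:
#     async for row in conn.execute(_projects_access_rights_stmt(user_id)):
#         ...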
async def migrate_shared_vfolders(
    cls,
    conn: SAConnection,
    deleted_user_uuid: uuid.UUID,
    target_user_uuid: uuid.UUID,
    target_user_email: str,
) -> int:
    """
    Migrate shared virtual folders' ownership to a target user.

    If a migrating virtual folder's name collides with one of the target user's
    existing folders, a random suffix is appended to the migrating folder's name.

    :param conn: DB connection
    :param deleted_user_uuid: user's UUID who will be deleted
    :param target_user_uuid: user's UUID who will get the ownership of virtual folders

    :return: number of migrated rows
    """
    from . import vfolders, vfolder_invitations, vfolder_permissions
    # Gather target user's virtual folders' names.
    query = (
        sa.select([vfolders.c.name])
        .select_from(vfolders)
        .where(vfolders.c.user == target_user_uuid)
    )
    existing_vfolder_names = [row.name async for row in conn.execute(query)]
    # Migrate shared virtual folders.
    # If virtual folder's name collides with target user's folder,
    # append random string to the name of the migrating folder.
    j = vfolder_permissions.join(
        vfolders,
        vfolder_permissions.c.vfolder == vfolders.c.id,
    )
    query = (
        sa.select([vfolders.c.id, vfolders.c.name])
        .select_from(j)
        .where(vfolders.c.user == deleted_user_uuid)
    )
    migrate_updates = []
    async for row in conn.execute(query):
        name = row.name
        if name in existing_vfolder_names:
            name += f'-{uuid.uuid4().hex[:10]}'
        migrate_updates.append({'vid': row.id, 'vname': name})
    if migrate_updates:
        # Remove invitations and vfolder_permissions from target user.
        # Target user will be the new owner, and it does not make sense to have
        # invitation and shared permission for its own folder.
        migrate_vfolder_ids = [item['vid'] for item in migrate_updates]
        query = (
            vfolder_invitations.delete()
            .where((vfolder_invitations.c.invitee == target_user_email) &
                   (vfolder_invitations.c.vfolder.in_(migrate_vfolder_ids)))
        )
        await conn.execute(query)
        query = (
            vfolder_permissions.delete()
            .where((vfolder_permissions.c.user == target_user_uuid) &
                   (vfolder_permissions.c.vfolder.in_(migrate_vfolder_ids)))
        )
        await conn.execute(query)

        rowcount = 0
        for item in migrate_updates:
            query = (
                vfolders.update()
                .values(
                    user=target_user_uuid,
                    name=item['vname'],
                )
                .where(vfolders.c.id == item['vid'])
            )
            result = await conn.execute(query)
            rowcount += result.rowcount
        if rowcount > 0:
            log.info('{0} shared folders detected. migrated to user {1}',
                     rowcount, target_user_uuid)
        return rowcount
    else:
        return 0
async def test_basic_workflow(project: RowProxy, conn: SAConnection):
    # git init
    async with conn.begin():
        # create repo
        repo_orm = ReposOrm(conn)
        repo_id = await repo_orm.insert(project_uuid=project.uuid)
        assert repo_id is not None
        assert isinstance(repo_id, int)

        repo_orm.set_filter(rowid=repo_id)
        repo = await repo_orm.fetch()
        assert repo
        assert repo.project_uuid == project.uuid
        assert repo.project_checksum is None
        assert repo.created == repo.modified

        # create main branch
        branches_orm = BranchesOrm(conn)
        branch_id = await branches_orm.insert(repo_id=repo.id)
        assert branch_id is not None
        assert isinstance(branch_id, int)

        branches_orm.set_filter(rowid=branch_id)
        main_branch: Optional[RowProxy] = await branches_orm.fetch()
        assert main_branch
        assert main_branch.name == "main", "Expected 'main' as default branch"
        assert main_branch.head_commit_id is None, "still not assigned"
        assert main_branch.created == main_branch.modified

        # assign head branch
        heads_orm = HeadsOrm(conn)
        await heads_orm.insert(repo_id=repo.id, head_branch_id=branch_id)

        heads_orm.set_filter(rowid=repo.id)
        head = await heads_orm.fetch()
        assert head

    #
    # create first commit -- TODO: separate tests

    # fetch a *full copy* of the project (WC)
    repo = await repo_orm.fetch("id project_uuid project_checksum")
    assert repo

    project_orm = ProjectsOrm(conn).set_filter(uuid=repo.project_uuid)
    project_wc = await project_orm.fetch()
    assert project_wc
    assert project == project_wc

    # call external function to compute checksum
    checksum = eval_checksum(project_wc.workbench)
    assert repo.project_checksum != checksum

    # take snapshot <=> git add & commit
    async with conn.begin():
        snapshot_checksum = await add_snapshot(project_wc, checksum, repo, conn)

        # get HEAD = repo.branch_id -> .head_commit_id
        assert head.repo_id == repo.id
        branches_orm.set_filter(head.head_branch_id)
        branch = await branches_orm.fetch("head_commit_id name")
        assert branch
        assert branch.name == "main"
        assert branch.head_commit_id is None, "First commit"

        # create commit
        commits_orm = CommitsOrm(conn)
        commit_id = await commits_orm.insert(
            repo_id=repo.id,
            parent_commit_id=branch.head_commit_id,
            snapshot_checksum=snapshot_checksum,
            message="first commit",
        )
        assert commit_id
        assert isinstance(commit_id, int)

        # update branch head
        await branches_orm.update(head_commit_id=commit_id)

        # update checksum cache
        await repo_orm.update(project_checksum=snapshot_checksum)

    # log history
    commits = await commits_orm.fetch_all()
    assert len(commits) == 1
    assert commits[0].id == commit_id

    # tag
    tag_orm = TagsOrm(conn)
    tag_id = await tag_orm.insert(
        repo_id=repo.id,
        commit_id=commit_id,
        name="v1",
    )
    assert tag_id is not None
    assert isinstance(tag_id, int)

    tag = await tag_orm.fetch(rowid=tag_id)
    assert tag
    assert tag.name == "v1"

    ############# NEW COMMIT #####################

    # user adds some changes
    repo = await repo_orm.fetch()
    assert repo

    project_orm.set_filter(uuid=repo.project_uuid)
    assert project_orm.is_filter_set()

    await project_orm.update(workbench={"node": {"input": 3}})

    project_wc = await project_orm.fetch("workbench ui")
    assert project_wc
    assert project.workbench != project_wc.workbench

    # get HEAD = repo.branch_id -> .head_commit_id
    head = await heads_orm.fetch()
    assert head
    branch = await branches_orm.fetch("head_commit_id", rowid=head.head_branch_id)
    assert branch
    # TODO: get subquery ... and compose
    head_commit = await commits_orm.fetch(rowid=branch.head_commit_id)
    assert head_commit

    # compare checksums between wc and HEAD
    checksum = eval_checksum(project_wc.workbench)
    assert head_commit.snapshot_checksum != checksum

    # updates wc checksum cache
    await repo_orm.update(project_checksum=checksum)

    # take snapshot = add & commit
    async with conn.begin():
        snapshot_uuid: str = await add_snapshot(project_wc, checksum, repo, conn)

        commit_id = await commits_orm.insert(
            repo_id=head_commit.repo_id,
            parent_commit_id=head_commit.id,
            snapshot_checksum=checksum,
            message="second commit",
        )
        assert commit_id
        assert isinstance(commit_id, int)

        # update branch head
        await branches_orm.update(head_commit_id=commit_id)

    # log history
    commits = await commits_orm.fetch_all()
    assert len(commits) == 2
    assert commits[1].id == commit_id
async def _list_existing_sessions(
    db_conn: SAConnection,
    sgroup: str,
) -> List[ExistingSession]:
    query = (
        sa.select([
            kernels.c.id,
            kernels.c.status,
            kernels.c.image,
            kernels.c.cluster_mode,
            kernels.c.cluster_size,
            kernels.c.cluster_role,
            kernels.c.cluster_idx,
            kernels.c.cluster_hostname,
            kernels.c.registry,
            kernels.c.session_id,
            kernels.c.session_type,
            kernels.c.session_name,
            kernels.c.access_key,
            kernels.c.domain_name,
            kernels.c.group_id,
            kernels.c.scaling_group,
            kernels.c.occupied_slots,
            kernels.c.resource_opts,
            kernels.c.environ,
            kernels.c.mounts,
            kernels.c.mount_map,
            kernels.c.startup_command,
            kernels.c.internal_data,
            keypairs.c.resource_policy,
        ])
        .select_from(sa.join(
            kernels, keypairs,
            keypairs.c.access_key == kernels.c.access_key
        ))
        .where(
            (kernels.c.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES)) &
            (kernels.c.scaling_group == sgroup)
        )
        .order_by(kernels.c.created_at)
    )
    items: MutableMapping[str, ExistingSession] = {}
    async for row in db_conn.execute(query):
        if _session := items.get(row['session_id']):
            session = _session
        else:
            session = ExistingSession(
                kernels=[],
                access_key=row['access_key'],
                session_id=row['session_id'],
                session_type=row['session_type'],
                session_name=row['session_name'],
                cluster_mode=row['cluster_mode'],
                cluster_size=row['cluster_size'],
                domain_name=row['domain_name'],
                group_id=row['group_id'],
                scaling_group=row['scaling_group'],
                occupying_slots=ResourceSlot(),
            )
            items[row['session_id']] = session
        # TODO: support multi-container sessions
        session.kernels.append(KernelInfo(  # type: ignore
            kernel_id=row['id'],
            session_id=row['session_id'],
            access_key=row['access_key'],
            cluster_role=row['cluster_role'],
            cluster_idx=row['cluster_idx'],
            cluster_hostname=row['cluster_hostname'],
            image_ref=ImageRef(row['image'], [row['registry']]),
            bootstrap_script=None,
            startup_command=None,
            resource_opts=row['resource_opts'],
            requested_slots=row['occupied_slots'],
        ))
        session.occupying_slots += row['occupied_slots']  # type: ignore
    return list(items.values())
async def _get_tags_by_project(self, conn: SAConnection, project_id: str) -> List:
    query = sa.select([study_tags.c.tag_id]).where(study_tags.c.study_id == project_id)
    return [row.tag_id async for row in conn.execute(query)]
async def _schedule_in_sgroup(db_conn: SAConnection, sgroup_name: str) -> None:
    async with db_conn.begin():
        scheduler = await self._load_scheduler(db_conn, sgroup_name)
        pending_sessions = await _list_pending_sessions(db_conn, sgroup_name)
        existing_sessions = await _list_existing_sessions(db_conn, sgroup_name)
    log.debug('running scheduler (sgroup:{}, pending:{}, existing:{})',
              sgroup_name, len(pending_sessions), len(existing_sessions))
    zero = ResourceSlot()
    while len(pending_sessions) > 0:
        async with db_conn.begin():
            candidate_agents = await _list_agents_by_sgroup(db_conn, sgroup_name)
            total_capacity = sum((ag.available_slots for ag in candidate_agents), zero)

            picked_session_id = scheduler.pick_session(
                total_capacity,
                pending_sessions,
                existing_sessions,
            )
            if picked_session_id is None:
                # no session is picked.
                # continue to next sgroup.
                return
            for picked_idx, sess_ctx in enumerate(pending_sessions):
                if sess_ctx.session_id == picked_session_id:
                    break
            else:
                # no matching entry for picked session?
                raise RuntimeError('should not reach here')
            sess_ctx = pending_sessions.pop(picked_idx)
            log_args = (
                sess_ctx.session_id,
                sess_ctx.session_type,
                sess_ctx.session_name,
                sess_ctx.access_key,
                sess_ctx.cluster_mode,
            )
            log.debug(log_fmt + 'try-scheduling', *log_args)
            session_agent_binding: Tuple[PendingSession, List[KernelAgentBinding]]

        async with db_conn.begin():
            predicates: Sequence[Awaitable[PredicateResult]] = [
                check_reserved_batch_session(db_conn, sched_ctx, sess_ctx),
                check_concurrency(db_conn, sched_ctx, sess_ctx),
                check_dependencies(db_conn, sched_ctx, sess_ctx),
                check_keypair_resource_limit(db_conn, sched_ctx, sess_ctx),
                check_group_resource_limit(db_conn, sched_ctx, sess_ctx),
                check_domain_resource_limit(db_conn, sched_ctx, sess_ctx),
                check_scaling_group(db_conn, sched_ctx, sess_ctx),
            ]
            check_results: List[Union[Exception, PredicateResult]] = []
            for check in predicates:
                try:
                    check_results.append(await check)
                except Exception as e:
                    log.exception(log_fmt + 'predicate-error', *log_args)
                    check_results.append(e)
            has_failure = False
            for result in check_results:
                if isinstance(result, Exception):
                    has_failure = True
                    continue
                if not result.passed:
                    has_failure = True
            if has_failure:
                log.debug(log_fmt + 'predicate-checks-failed', *log_args)
                await _invoke_failure_callbacks(
                    db_conn, sched_ctx, sess_ctx, check_results,
                )
                # Predicate failures are *NOT* permanent errors.
                # We need to retry the scheduling afterwards.
                continue

            if sess_ctx.cluster_mode == ClusterMode.SINGLE_NODE:
                # Assign agent resource per session.
                try:
                    agent_id = scheduler.assign_agent_for_session(candidate_agents, sess_ctx)
                    if agent_id is None:
                        raise InstanceNotAvailable
                    agent_alloc_ctx = await _reserve_agent(
                        sched_ctx, db_conn, sgroup_name, agent_id,
                        sess_ctx.requested_slots,
                    )
                except InstanceNotAvailable:
                    log.debug(log_fmt + 'no-available-instances', *log_args)
                    await _invoke_failure_callbacks(
                        db_conn, sched_ctx, sess_ctx, check_results,
                    )
                    raise
                except Exception:
                    log.exception(log_fmt + 'unexpected-error, during agent allocation',
                                  *log_args)
                    await _invoke_failure_callbacks(
                        db_conn, sched_ctx, sess_ctx, check_results,
                    )
                    raise
                query = kernels.update().values({
                    'agent': agent_alloc_ctx.agent_id,
                    'agent_addr': agent_alloc_ctx.agent_addr,
                    'scaling_group': sgroup_name,
                    'status': KernelStatus.PREPARING,
                    'status_info': 'scheduled',
                    'status_changed': datetime.now(tzutc()),
                }).where(kernels.c.session_id == sess_ctx.session_id)
                await db_conn.execute(query)
                session_agent_binding = (
                    sess_ctx,
                    [
                        KernelAgentBinding(kernel, agent_alloc_ctx)
                        for kernel in sess_ctx.kernels
                    ],
                )
            elif sess_ctx.cluster_mode == ClusterMode.MULTI_NODE:
                # Assign agent resource per kernel in the session.
                agent_query_extra_conds = None
                if len(sess_ctx.kernels) >= 2:
                    # We should use agents that support overlay networking.
                    agent_query_extra_conds = (agents.c.clusterized)
                kernel_agent_bindings = []
                for kernel in sess_ctx.kernels:
                    try:
                        agent_id = scheduler.assign_agent_for_kernel(candidate_agents, kernel)
                        if agent_id is None:
                            raise InstanceNotAvailable
                        agent_alloc_ctx = await _reserve_agent(
                            sched_ctx, db_conn, sgroup_name, agent_id,
                            kernel.requested_slots,
                            extra_conds=agent_query_extra_conds,
                        )
                    except InstanceNotAvailable:
                        log.debug(log_fmt + 'no-available-instances', *log_args)
                        await _invoke_failure_callbacks(
                            db_conn, sched_ctx, sess_ctx, check_results,
                        )
                        # continue
                        raise
                    except Exception:
                        log.exception(log_fmt + 'unexpected-error, during agent allocation',
                                      *log_args)
                        await _invoke_failure_callbacks(
                            db_conn, sched_ctx, sess_ctx, check_results,
                        )
                        # continue
                        raise
                    # TODO: if error occurs for one kernel, should we cancel all others?
                    query = kernels.update().values({
                        'agent': agent_alloc_ctx.agent_id,
                        'agent_addr': agent_alloc_ctx.agent_addr,
                        'scaling_group': sgroup_name,
                        'status': KernelStatus.PREPARING,
                        'status_info': 'scheduled',
                        'status_changed': datetime.now(tzutc()),
                    }).where(kernels.c.id == kernel.kernel_id)
                    await db_conn.execute(query)
                    kernel_agent_bindings.append(KernelAgentBinding(kernel, agent_alloc_ctx))
                session_agent_binding = (sess_ctx, kernel_agent_bindings)

        start_task_args.append(
            (
                log_args,
                sched_ctx,
                session_agent_binding,
                check_results,
            )
        )
async def _schedule_single_node_session(
    self,
    sched_ctx: SchedulingContext,
    scheduler: AbstractScheduler,
    agent_db_conn: SAConnection,
    kernel_db_conn: SAConnection,
    sgroup_name: str,
    candidate_agents: Sequence[AgentContext],
    sess_ctx: PendingSession,
    check_results: List[Tuple[str, Union[Exception, PredicateResult]]],
) -> Tuple[PendingSession, List[KernelAgentBinding]]:
    # Assign agent resource per session.
    log_fmt = _log_fmt.get()
    log_args = _log_args.get()
    try:
        agent_id = scheduler.assign_agent_for_session(candidate_agents, sess_ctx)
        if agent_id is None:
            raise InstanceNotAvailable
        async with agent_db_conn.begin():
            agent_alloc_ctx = await _reserve_agent(
                sched_ctx, agent_db_conn, sgroup_name, agent_id,
                sess_ctx.requested_slots,
            )
    except InstanceNotAvailable:
        log.debug(log_fmt + 'no-available-instances', *log_args)
        async with kernel_db_conn.begin():
            await _invoke_failure_callbacks(
                kernel_db_conn, sched_ctx, sess_ctx, check_results,
            )
            query = kernels.update().values({
                'status_info': "no-available-instances",
                'status_data': sql_json_increment(
                    kernels.c.status_data,
                    ('scheduler', 'retries'),
                    parent_updates={
                        'last_try': datetime.now(tzutc()).isoformat(),
                    }),
            }).where(kernels.c.id == sess_ctx.session_id)
            await kernel_db_conn.execute(query)
        raise
    except Exception as e:
        log.exception(
            log_fmt + 'unexpected-error, during agent allocation',
            *log_args,
        )
        async with kernel_db_conn.begin():
            await _invoke_failure_callbacks(
                kernel_db_conn, sched_ctx, sess_ctx, check_results,
            )
            query = kernels.update().values({
                'status_info': "scheduler-error",
                'status_data': convert_to_status_data(e),
            }).where(kernels.c.id == sess_ctx.session_id)
            await kernel_db_conn.execute(query)
        raise
    async with kernel_db_conn.begin():
        query = kernels.update().values({
            'agent': agent_alloc_ctx.agent_id,
            'agent_addr': agent_alloc_ctx.agent_addr,
            'scaling_group': sgroup_name,
            'status': KernelStatus.PREPARING,
            'status_info': 'scheduled',
            'status_data': {},
            'status_changed': datetime.now(tzutc()),
        }).where(kernels.c.session_id == sess_ctx.session_id)
        await kernel_db_conn.execute(query)
    return (
        sess_ctx,
        [
            KernelAgentBinding(kernel, agent_alloc_ctx)
            for kernel in sess_ctx.kernels
        ],
    )
async def _list_pending_sessions(
    db_conn: SAConnection,
    sgroup_name: str,
) -> List[PendingSession]:
    query = (
        sa.select([
            kernels.c.id,
            kernels.c.status,
            kernels.c.image,
            kernels.c.cluster_mode,
            kernels.c.cluster_size,
            kernels.c.cluster_role,
            kernels.c.cluster_idx,
            kernels.c.cluster_hostname,
            kernels.c.registry,
            kernels.c.session_id,
            kernels.c.session_type,
            kernels.c.session_name,
            kernels.c.access_key,
            kernels.c.domain_name,
            kernels.c.group_id,
            kernels.c.scaling_group,
            kernels.c.occupied_slots,
            kernels.c.resource_opts,
            kernels.c.environ,
            kernels.c.mounts,
            kernels.c.mount_map,
            kernels.c.bootstrap_script,
            kernels.c.startup_command,
            kernels.c.internal_data,
            kernels.c.preopen_ports,
            keypairs.c.resource_policy,
        ])
        .select_from(sa.join(
            kernels, keypairs,
            keypairs.c.access_key == kernels.c.access_key
        ))
        .where(
            (kernels.c.status == KernelStatus.PENDING) &
            (
                (kernels.c.scaling_group == sgroup_name) |
                (kernels.c.scaling_group.is_(None))
            )
        )
        .order_by(kernels.c.created_at)
    )
    # TODO: extend for multi-container sessions
    items: MutableMapping[str, PendingSession] = {}
    async for row in db_conn.execute(query):
        if _session := items.get(row['session_id']):
            session = _session
        else:
            session = PendingSession(
                kernels=[],
                access_key=row['access_key'],
                session_id=row['session_id'],
                session_type=row['session_type'],
                session_name=row['session_name'],
                cluster_mode=row['cluster_mode'],
                cluster_size=row['cluster_size'],
                domain_name=row['domain_name'],
                group_id=row['group_id'],
                scaling_group=row['scaling_group'],
                resource_policy=row['resource_policy'],
                resource_opts={},
                requested_slots=ResourceSlot(),
                internal_data=row['internal_data'],
                target_sgroup_names=[],
                environ={
                    k: v for k, v in
                    map(lambda s: s.split('=', maxsplit=1), row['environ'])
                },
                mounts=row['mounts'],
                mount_map=row['mount_map'],
                bootstrap_script=row['bootstrap_script'],
                startup_command=row['startup_command'],
                preopen_ports=row['preopen_ports'],
            )
            items[row['session_id']] = session
        session.kernels.append(KernelInfo(
            kernel_id=row['id'],
            session_id=row['session_id'],
            access_key=row['access_key'],
            cluster_role=row['cluster_role'],
            cluster_idx=row['cluster_idx'],
            cluster_hostname=row['cluster_hostname'],
            image_ref=ImageRef(row['image'], [row['registry']]),
            bootstrap_script=row['bootstrap_script'],
            startup_command=row['startup_command'],
            resource_opts=row['resource_opts'],
            requested_slots=row['occupied_slots'],
        ))
        session.requested_slots += row['occupied_slots']  # type: ignore
        merge_resource(session.resource_opts, row['resource_opts'])  # type: ignore
    return list(items.values())
async def _schedule_multi_node_session(
    self,
    sched_ctx: SchedulingContext,
    scheduler: AbstractScheduler,
    agent_db_conn: SAConnection,
    kernel_db_conn: SAConnection,
    sgroup_name: str,
    candidate_agents: Sequence[AgentContext],
    sess_ctx: PendingSession,
    check_results: List[Tuple[str, Union[Exception, PredicateResult]]],
) -> Tuple[PendingSession, List[KernelAgentBinding]]:
    # Assign agent resource per kernel in the session.
    log_fmt = _log_fmt.get()
    log_args = _log_args.get()
    agent_query_extra_conds = None
    kernel_agent_bindings: List[KernelAgentBinding] = []
    async with agent_db_conn.begin(isolation_level="REPEATABLE READ"):
        # This outer transaction is rolled back when any exception occurs inside,
        # including scheduling failures of a kernel.
        # It ensures that occupied_slots are recovered when there are partial
        # scheduling failures.
        for kernel in sess_ctx.kernels:
            try:
                agent_id = scheduler.assign_agent_for_kernel(candidate_agents, kernel)
                if agent_id is None:
                    raise InstanceNotAvailable
                async with agent_db_conn.begin_nested():
                    agent_alloc_ctx = await _reserve_agent(
                        sched_ctx, agent_db_conn, sgroup_name, agent_id,
                        kernel.requested_slots,
                        extra_conds=agent_query_extra_conds,
                    )
                    candidate_agents = await _list_agents_by_sgroup(
                        agent_db_conn, sgroup_name)
            except InstanceNotAvailable:
                log.debug(log_fmt + 'no-available-instances', *log_args)
                async with kernel_db_conn.begin():
                    await _invoke_failure_callbacks(
                        kernel_db_conn, sched_ctx, sess_ctx, check_results,
                    )
                    query = kernels.update().values({
                        'status_info': "no-available-instances",
                        'status_data': sql_json_increment(
                            kernels.c.status_data,
                            ('scheduler', 'retries'),
                            parent_updates={
                                'last_try': datetime.now(tzutc()).isoformat(),
                            }),
                    }).where(kernels.c.id == kernel.kernel_id)
                    await kernel_db_conn.execute(query)
                raise
            except Exception as e:
                log.exception(
                    log_fmt + 'unexpected-error, during agent allocation',
                    *log_args,
                )
                async with kernel_db_conn.begin():
                    await _invoke_failure_callbacks(
                        kernel_db_conn, sched_ctx, sess_ctx, check_results,
                    )
                    query = kernels.update().values({
                        'status_info': "scheduler-error",
                        'status_data': convert_to_status_data(e),
                    }).where(kernels.c.id == kernel.kernel_id)
                    await kernel_db_conn.execute(query)
                raise
            else:
                kernel_agent_bindings.append(KernelAgentBinding(kernel, agent_alloc_ctx))
    if len(kernel_agent_bindings) == len(sess_ctx.kernels):
        # Proceed to PREPARING only when all kernels are successfully scheduled.
        async with kernel_db_conn.begin():
            for binding in kernel_agent_bindings:
                query = kernels.update().values({
                    'agent': binding.agent_alloc_ctx.agent_id,
                    'agent_addr': binding.agent_alloc_ctx.agent_addr,
                    'scaling_group': sgroup_name,
                    'status': KernelStatus.PREPARING,
                    'status_info': 'scheduled',
                    'status_data': {},
                    'status_changed': datetime.now(tzutc()),
                }).where(kernels.c.id == binding.kernel.kernel_id)
                await kernel_db_conn.execute(query)
    return (sess_ctx, kernel_agent_bindings)
async def _list_pending_sessions(
    db_conn: SAConnection,
    sgroup_name: str,
) -> List[PendingSession]:
    query = (
        sa.select([
            kernels.c.id,
            kernels.c.status,
            kernels.c.image,
            kernels.c.registry,
            kernels.c.sess_type,
            kernels.c.sess_id,
            kernels.c.access_key,
            kernels.c.domain_name,
            kernels.c.group_id,
            kernels.c.scaling_group,
            kernels.c.occupied_slots,
            kernels.c.resource_opts,
            kernels.c.environ,
            kernels.c.mounts,
            kernels.c.mount_map,
            kernels.c.bootstrap_script,
            kernels.c.startup_command,
            kernels.c.internal_data,
            kernels.c.preopen_ports,
            keypairs.c.resource_policy,
        ])
        .select_from(sa.join(
            kernels, keypairs,
            keypairs.c.access_key == kernels.c.access_key
        ))
        .where(
            (kernels.c.status == KernelStatus.PENDING) &
            (
                (kernels.c.scaling_group == sgroup_name) |
                (kernels.c.scaling_group.is_(None))
            )
        )
        .order_by(kernels.c.created_at)
    )
    items = []
    async for row in db_conn.execute(query):
        items.append(PendingSession(
            kernel_id=row['id'],
            access_key=row['access_key'],
            session_type=row['sess_type'],
            session_name=row['sess_id'],
            domain_name=row['domain_name'],
            group_id=row['group_id'],
            scaling_group=row['scaling_group'],
            image_ref=ImageRef(row['image'], [row['registry']]),
            resource_policy=row['resource_policy'],
            resource_opts=row['resource_opts'],
            requested_slots=row['occupied_slots'],
            internal_data=row['internal_data'],
            target_sgroup_names=[],
            environ={
                k: v for k, v in
                map(lambda s: s.split('=', maxsplit=1), row['environ'])
            },
            mounts=row['mounts'],
            mount_map=row['mount_map'],
            bootstrap_script=row['bootstrap_script'],
            startup_command=row['startup_command'],
            preopen_ports=row['preopen_ports'],
        ))
    return items
async def _schedule_in_sgroup(
    self,
    sched_ctx: SchedulingContext,
    agent_db_conn: SAConnection,
    kernel_db_conn: SAConnection,
    sgroup_name: str,
) -> List[StartTaskArgs]:
    async with kernel_db_conn.begin():
        scheduler = await self._load_scheduler(kernel_db_conn, sgroup_name)
        pending_sessions = await _list_pending_sessions(kernel_db_conn, sgroup_name)
        existing_sessions = await _list_existing_sessions(kernel_db_conn, sgroup_name)
    log.debug('running scheduler (sgroup:{}, pending:{}, existing:{})',
              sgroup_name, len(pending_sessions), len(existing_sessions))
    zero = ResourceSlot()
    args_list: List[StartTaskArgs] = []
    while len(pending_sessions) > 0:
        async with agent_db_conn.begin():
            candidate_agents = await _list_agents_by_sgroup(agent_db_conn, sgroup_name)
            total_capacity = sum((ag.available_slots for ag in candidate_agents), zero)

        picked_session_id = scheduler.pick_session(
            total_capacity,
            pending_sessions,
            existing_sessions,
        )
        if picked_session_id is None:
            # no session is picked.
            # continue to next sgroup.
            return []
        for picked_idx, sess_ctx in enumerate(pending_sessions):
            if sess_ctx.session_id == picked_session_id:
                break
        else:
            # no matching entry for picked session?
            raise RuntimeError('should not reach here')
        sess_ctx = pending_sessions.pop(picked_idx)

        log_fmt = 'schedule(s:{}, type:{}, name:{}, ak:{}, cluster_mode:{}): '
        log_args = (
            sess_ctx.session_id,
            sess_ctx.session_type,
            sess_ctx.session_name,
            sess_ctx.access_key,
            sess_ctx.cluster_mode,
        )
        _log_fmt.set(log_fmt)
        _log_args.set(log_args)
        log.debug(log_fmt + 'try-scheduling', *log_args)
        session_agent_binding: Tuple[PendingSession, List[KernelAgentBinding]]

        async with kernel_db_conn.begin():
            predicates: Sequence[Tuple[str, Awaitable[PredicateResult]]] = [
                (
                    'reserved_time',
                    check_reserved_batch_session(kernel_db_conn, sched_ctx, sess_ctx),
                ),
                ('concurrency', check_concurrency(kernel_db_conn, sched_ctx, sess_ctx)),
                ('dependencies', check_dependencies(kernel_db_conn, sched_ctx, sess_ctx)),
                (
                    'keypair_resource_limit',
                    check_keypair_resource_limit(kernel_db_conn, sched_ctx, sess_ctx),
                ),
                (
                    'user_group_resource_limit',
                    check_group_resource_limit(kernel_db_conn, sched_ctx, sess_ctx),
                ),
                (
                    'domain_resource_limit',
                    check_domain_resource_limit(kernel_db_conn, sched_ctx, sess_ctx),
                ),
                (
                    'scaling_group_resource_limit',
                    check_scaling_group(kernel_db_conn, sched_ctx, sess_ctx),
                ),
            ]
            check_results: List[Tuple[str, Union[Exception, PredicateResult]]] = []
            for predicate_name, check_coro in predicates:
                try:
                    check_results.append((predicate_name, await check_coro))
                except Exception as e:
                    log.exception(log_fmt + 'predicate-error', *log_args)
                    check_results.append((predicate_name, e))

        has_failure = False
        # has_permanent_failure = False
        failed_predicates = []
        passed_predicates = []
        for predicate_name, result in check_results:
            if isinstance(result, Exception):
                has_failure = True
                failed_predicates.append({
                    'name': predicate_name,
                    'msg': repr(result),
                })
                continue
            if result.passed:
                passed_predicates.append({
                    'name': predicate_name,
                })
            else:
                failed_predicates.append({
                    'name': predicate_name,
                    'msg': result.message or "",
                })
                has_failure = True
                # if result.permanent:
                #     has_permanent_failure = True
        if has_failure:
            log.debug(log_fmt + 'predicate-checks-failed (temporary)', *log_args)
            # TODO: handle has_permanent_failure as cancellation
            #  - An early implementation of it has caused DB query blocking due to
            #    the inclusion of the kernels.status field. :(
            #    Let's fix it.
            async with kernel_db_conn.begin():
                await _invoke_failure_callbacks(
                    kernel_db_conn, sched_ctx, sess_ctx, check_results,
                )
                query = kernels.update().values({
                    'status_info': "predicate-checks-failed",
                    'status_data': sql_json_increment(
                        kernels.c.status_data,
                        ('scheduler', 'retries'),
                        parent_updates={
                            'last_try': datetime.now(tzutc()).isoformat(),
                            'failed_predicates': failed_predicates,
                            'passed_predicates': passed_predicates,
                        }),
                }).where(kernels.c.id == sess_ctx.session_id)
                await kernel_db_conn.execute(query)
            # Predicate failures are *NOT* permanent errors.
            # We need to retry the scheduling afterwards.
            continue
        else:
            async with kernel_db_conn.begin():
                query = kernels.update().values({
                    'status_data': sql_json_merge(
                        kernels.c.status_data,
                        ('scheduler',),
                        {
                            'last_try': datetime.now(tzutc()).isoformat(),
                            'failed_predicates': failed_predicates,
                            'passed_predicates': passed_predicates,
                        }),
                }).where(kernels.c.id == sess_ctx.session_id)
                await kernel_db_conn.execute(query)

        if sess_ctx.cluster_mode == ClusterMode.SINGLE_NODE:
            session_agent_binding = await self._schedule_single_node_session(
                sched_ctx,
                scheduler,
                agent_db_conn,
                kernel_db_conn,
                sgroup_name,
                candidate_agents,
                sess_ctx,
                check_results,
            )
        elif sess_ctx.cluster_mode == ClusterMode.MULTI_NODE:
            session_agent_binding = await self._schedule_multi_node_session(
                sched_ctx,
                scheduler,
                agent_db_conn,
                kernel_db_conn,
                sgroup_name,
                candidate_agents,
                sess_ctx,
                check_results,
            )
        else:
            raise RuntimeError(
                f"should not reach here; unknown cluster_mode: {sess_ctx.cluster_mode}"
            )
        args_list.append((
            log_args,
            sched_ctx,
            session_agent_binding,
            check_results,
        ))

    return args_list