async def reenter_txn(pool: SAEngine, conn: SAConnection):
    if conn is None:
        async with pool.acquire() as conn, conn.begin():
            yield conn
    else:
        async with conn.begin_nested():
            yield conn
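# Usage sketch (assumption, not part of the original source): reenter_txn is an
# async generator, so it is presumably wrapped as an async context manager before
# use. The wrapper below and the touch_session helper / session_id parameter are
# illustrative only.
import contextlib

_reenter_txn = contextlib.asynccontextmanager(reenter_txn)

async def touch_session(pool: SAEngine, session_id, conn: SAConnection = None) -> None:
    # Reuses the caller's transaction (as a nested transaction) when a connection
    # is passed in; otherwise acquires a new connection with its own transaction.
    async with _reenter_txn(pool, conn) as txn_conn:
        query = kernels.update().values({
            'status_changed': datetime.now(tzutc()),
        }).where(kernels.c.session_id == session_id)
        await txn_conn.execute(query)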
async def _schedule_in_sgroup(db_conn: SAConnection, sgroup_name: str) -> None:
    async with db_conn.begin():
        scheduler = await self._load_scheduler(db_conn, sgroup_name)
        pending_sessions = await _list_pending_sessions(db_conn, sgroup_name)
        existing_sessions = await _list_existing_sessions(db_conn, sgroup_name)
    log.debug('running scheduler (sgroup:{}, pending:{}, existing:{})',
              sgroup_name, len(pending_sessions), len(existing_sessions))
    zero = ResourceSlot()
    while len(pending_sessions) > 0:
        async with db_conn.begin():
            candidate_agents = await _list_agents_by_sgroup(db_conn, sgroup_name)
            total_capacity = sum((ag.available_slots for ag in candidate_agents), zero)

        picked_session_id = scheduler.pick_session(
            total_capacity,
            pending_sessions,
            existing_sessions,
        )
        if picked_session_id is None:
            # no session is picked.
            # continue to next sgroup.
            return
        for picked_idx, sess_ctx in enumerate(pending_sessions):
            if sess_ctx.session_id == picked_session_id:
                break
        else:
            # no matching entry for picked session?
            raise RuntimeError('should not reach here')
        sess_ctx = pending_sessions.pop(picked_idx)
        log_args = (
            sess_ctx.session_id,
            sess_ctx.session_type,
            sess_ctx.session_name,
            sess_ctx.access_key,
            sess_ctx.cluster_mode,
        )
        log.debug(log_fmt + 'try-scheduling', *log_args)
        session_agent_binding: Tuple[PendingSession, List[KernelAgentBinding]]

        async with db_conn.begin():
            predicates: Sequence[Awaitable[PredicateResult]] = [
                check_reserved_batch_session(db_conn, sched_ctx, sess_ctx),
                check_concurrency(db_conn, sched_ctx, sess_ctx),
                check_dependencies(db_conn, sched_ctx, sess_ctx),
                check_keypair_resource_limit(db_conn, sched_ctx, sess_ctx),
                check_group_resource_limit(db_conn, sched_ctx, sess_ctx),
                check_domain_resource_limit(db_conn, sched_ctx, sess_ctx),
                check_scaling_group(db_conn, sched_ctx, sess_ctx),
            ]
            check_results: List[Union[Exception, PredicateResult]] = []
            for check in predicates:
                try:
                    check_results.append(await check)
                except Exception as e:
                    log.exception(log_fmt + 'predicate-error', *log_args)
                    check_results.append(e)
            has_failure = False
            for result in check_results:
                if isinstance(result, Exception):
                    has_failure = True
                    continue
                if not result.passed:
                    has_failure = True
            if has_failure:
                log.debug(log_fmt + 'predicate-checks-failed', *log_args)
                await _invoke_failure_callbacks(
                    db_conn, sched_ctx, sess_ctx, check_results,
                )
                # Predicate failures are *NOT* permanent errors.
                # We need to retry the scheduling afterwards.
                continue

            if sess_ctx.cluster_mode == ClusterMode.SINGLE_NODE:
                # Assign agent resource per session.
                try:
                    agent_id = scheduler.assign_agent_for_session(candidate_agents, sess_ctx)
                    if agent_id is None:
                        raise InstanceNotAvailable
                    agent_alloc_ctx = await _reserve_agent(
                        sched_ctx, db_conn,
                        sgroup_name, agent_id, sess_ctx.requested_slots,
                    )
                except InstanceNotAvailable:
                    log.debug(log_fmt + 'no-available-instances', *log_args)
                    await _invoke_failure_callbacks(
                        db_conn, sched_ctx, sess_ctx, check_results,
                    )
                    raise
                except Exception:
                    log.exception(log_fmt + 'unexpected-error, during agent allocation',
                                  *log_args)
                    await _invoke_failure_callbacks(
                        db_conn, sched_ctx, sess_ctx, check_results,
                    )
                    raise
                query = kernels.update().values({
                    'agent': agent_alloc_ctx.agent_id,
                    'agent_addr': agent_alloc_ctx.agent_addr,
                    'scaling_group': sgroup_name,
                    'status': KernelStatus.PREPARING,
                    'status_info': 'scheduled',
                    'status_changed': datetime.now(tzutc()),
                }).where(kernels.c.session_id == sess_ctx.session_id)
                await db_conn.execute(query)
                session_agent_binding = (
                    sess_ctx,
                    [
                        KernelAgentBinding(kernel, agent_alloc_ctx)
                        for kernel in sess_ctx.kernels
                    ],
                )
            elif sess_ctx.cluster_mode == ClusterMode.MULTI_NODE:
                # Assign agent resource per kernel in the session.
                agent_query_extra_conds = None
                if len(sess_ctx.kernels) >= 2:
                    # We should use agents that support overlay networking.
                    agent_query_extra_conds = (agents.c.clusterized)
                kernel_agent_bindings = []
                for kernel in sess_ctx.kernels:
                    try:
                        agent_id = scheduler.assign_agent_for_kernel(candidate_agents, kernel)
                        if agent_id is None:
                            raise InstanceNotAvailable
                        agent_alloc_ctx = await _reserve_agent(
                            sched_ctx, db_conn,
                            sgroup_name, agent_id, kernel.requested_slots,
                            extra_conds=agent_query_extra_conds,
                        )
                    except InstanceNotAvailable:
                        log.debug(log_fmt + 'no-available-instances', *log_args)
                        await _invoke_failure_callbacks(
                            db_conn, sched_ctx, sess_ctx, check_results,
                        )
                        # continue
                        raise
                    except Exception:
                        log.exception(log_fmt + 'unexpected-error, during agent allocation',
                                      *log_args)
                        await _invoke_failure_callbacks(
                            db_conn, sched_ctx, sess_ctx, check_results,
                        )
                        # continue
                        raise
                    # TODO: if error occurs for one kernel, should we cancel all others?
                    query = kernels.update().values({
                        'agent': agent_alloc_ctx.agent_id,
                        'agent_addr': agent_alloc_ctx.agent_addr,
                        'scaling_group': sgroup_name,
                        'status': KernelStatus.PREPARING,
                        'status_info': 'scheduled',
                        'status_changed': datetime.now(tzutc()),
                    }).where(kernels.c.id == kernel.kernel_id)
                    await db_conn.execute(query)
                    kernel_agent_bindings.append(KernelAgentBinding(kernel, agent_alloc_ctx))
                session_agent_binding = (sess_ctx, kernel_agent_bindings)
            start_task_args.append(
                (
                    log_args,
                    sched_ctx,
                    session_agent_binding,
                    check_results,
                )
            )
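# Sketch (assumption, not part of the original source): each predicate awaited in
# _schedule_in_sgroup takes (db_conn, sched_ctx, sess_ctx) and returns a
# PredicateResult. Constructing PredicateResult with `passed`/`message` keywords
# is inferred from how the results are consumed (result.passed, result.message)
# and may differ from the real class; the 64-kernel limit is purely illustrative.
async def check_kernel_count(
    db_conn: SAConnection,
    sched_ctx: SchedulingContext,
    sess_ctx: PendingSession,
) -> PredicateResult:
    # Hypothetical check: reject sessions requesting an unreasonably large
    # number of kernels.
    if len(sess_ctx.kernels) > 64:
        return PredicateResult(
            passed=False,
            message='Too many kernels requested in a single session.',
        )
    return PredicateResult(passed=True)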
async def test_basic_workflow(project: RowProxy, conn: SAConnection):
    # git init
    async with conn.begin():
        # create repo
        repo_orm = ReposOrm(conn)
        repo_id = await repo_orm.insert(project_uuid=project.uuid)
        assert repo_id is not None
        assert isinstance(repo_id, int)
        repo_orm.set_filter(rowid=repo_id)
        repo = await repo_orm.fetch()
        assert repo
        assert repo.project_uuid == project.uuid
        assert repo.project_checksum is None
        assert repo.created == repo.modified

        # create main branch
        branches_orm = BranchesOrm(conn)
        branch_id = await branches_orm.insert(repo_id=repo.id)
        assert branch_id is not None
        assert isinstance(branch_id, int)
        branches_orm.set_filter(rowid=branch_id)
        main_branch: Optional[RowProxy] = await branches_orm.fetch()
        assert main_branch
        assert main_branch.name == "main", "Expected 'main' as default branch"
        assert main_branch.head_commit_id is None, "still not assigned"
        assert main_branch.created == main_branch.modified

        # assign head branch
        heads_orm = HeadsOrm(conn)
        await heads_orm.insert(repo_id=repo.id, head_branch_id=branch_id)
        heads_orm.set_filter(rowid=repo.id)
        head = await heads_orm.fetch()
        assert head

    #
    # create first commit -- TODO: separate tests

    # fetch a *full copy* of the project (WC)
    repo = await repo_orm.fetch("id project_uuid project_checksum")
    assert repo

    project_orm = ProjectsOrm(conn).set_filter(uuid=repo.project_uuid)
    project_wc = await project_orm.fetch()
    assert project_wc
    assert project == project_wc

    # call external function to compute checksum
    checksum = eval_checksum(project_wc.workbench)
    assert repo.project_checksum != checksum

    # take snapshot <=> git add & commit
    async with conn.begin():
        snapshot_checksum = await add_snapshot(project_wc, checksum, repo, conn)

        # get HEAD = repo.branch_id -> .head_commit_id
        assert head.repo_id == repo.id
        branches_orm.set_filter(head.head_branch_id)
        branch = await branches_orm.fetch("head_commit_id name")
        assert branch
        assert branch.name == "main"
        assert branch.head_commit_id is None, "First commit"

        # create commit
        commits_orm = CommitsOrm(conn)
        commit_id = await commits_orm.insert(
            repo_id=repo.id,
            parent_commit_id=branch.head_commit_id,
            snapshot_checksum=snapshot_checksum,
            message="first commit",
        )
        assert commit_id
        assert isinstance(commit_id, int)

        # update branch head
        await branches_orm.update(head_commit_id=commit_id)

        # update checksum cache
        await repo_orm.update(project_checksum=snapshot_checksum)

    # log history
    commits = await commits_orm.fetch_all()
    assert len(commits) == 1
    assert commits[0].id == commit_id

    # tag
    tag_orm = TagsOrm(conn)
    tag_id = await tag_orm.insert(
        repo_id=repo.id,
        commit_id=commit_id,
        name="v1",
    )
    assert tag_id is not None
    assert isinstance(tag_id, int)

    tag = await tag_orm.fetch(rowid=tag_id)
    assert tag
    assert tag.name == "v1"

    ############# NEW COMMIT #####################

    # user adds some changes
    repo = await repo_orm.fetch()
    assert repo
    project_orm.set_filter(uuid=repo.project_uuid)
    assert project_orm.is_filter_set()

    await project_orm.update(workbench={"node": {"input": 3}})

    project_wc = await project_orm.fetch("workbench ui")
    assert project_wc
    assert project.workbench != project_wc.workbench

    # get HEAD = repo.branch_id -> .head_commit_id
    head = await heads_orm.fetch()
    assert head
    branch = await branches_orm.fetch("head_commit_id", rowid=head.head_branch_id)
    assert branch
    # TODO: get subquery ... and compose
    head_commit = await commits_orm.fetch(rowid=branch.head_commit_id)
    assert head_commit

    # compare checksums between wc and HEAD
    checksum = eval_checksum(project_wc.workbench)
    assert head_commit.snapshot_checksum != checksum

    # update wc checksum cache
    await repo_orm.update(project_checksum=checksum)

    # take snapshot = add & commit
    async with conn.begin():
        snapshot_uuid: str = await add_snapshot(project_wc, checksum, repo, conn)
        commit_id = await commits_orm.insert(
            repo_id=head_commit.repo_id,
            parent_commit_id=head_commit.id,
            snapshot_checksum=checksum,
            message="second commit",
        )
        assert commit_id
        assert isinstance(commit_id, int)

        # update branch head
        await branches_orm.update(head_commit_id=commit_id)

    # log history
    commits = await commits_orm.fetch_all()
    assert len(commits) == 2
    assert commits[1].id == commit_id
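# Sketch (assumption, not part of the original source): eval_checksum is the
# external helper used above to derive a deterministic checksum of a project's
# workbench. A minimal stand-in could hash a canonical JSON serialization; the
# real implementation may use a different serialization or digest.
import hashlib
import json

def eval_checksum(workbench) -> str:
    # Sort keys and strip whitespace so that logically equal workbench dicts
    # always serialize to the same bytes and therefore the same digest.
    canonical = json.dumps(workbench, sort_keys=True, separators=(',', ':'), default=str)
    return hashlib.sha256(canonical.encode('utf-8')).hexdigest()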
async def _schedule_multi_node_session(
    self,
    sched_ctx: SchedulingContext,
    scheduler: AbstractScheduler,
    agent_db_conn: SAConnection,
    kernel_db_conn: SAConnection,
    sgroup_name: str,
    candidate_agents: Sequence[AgentContext],
    sess_ctx: PendingSession,
    check_results: List[Tuple[str, Union[Exception, PredicateResult]]],
) -> Tuple[PendingSession, List[KernelAgentBinding]]:
    # Assign agent resource per kernel in the session.
    log_fmt = _log_fmt.get()
    log_args = _log_args.get()
    agent_query_extra_conds = None
    kernel_agent_bindings: List[KernelAgentBinding] = []
    async with agent_db_conn.begin(isolation_level="REPEATABLE READ"):
        # This outer transaction is rolled back when any exception occurs inside,
        # including scheduling failures of a kernel.
        # It ensures that occupied_slots are recovered when there are partial
        # scheduling failures.
        for kernel in sess_ctx.kernels:
            try:
                agent_id = scheduler.assign_agent_for_kernel(candidate_agents, kernel)
                if agent_id is None:
                    raise InstanceNotAvailable
                async with agent_db_conn.begin_nested():
                    agent_alloc_ctx = await _reserve_agent(
                        sched_ctx, agent_db_conn,
                        sgroup_name, agent_id, kernel.requested_slots,
                        extra_conds=agent_query_extra_conds,
                    )
                    candidate_agents = await _list_agents_by_sgroup(
                        agent_db_conn, sgroup_name)
            except InstanceNotAvailable:
                log.debug(log_fmt + 'no-available-instances', *log_args)
                async with kernel_db_conn.begin():
                    await _invoke_failure_callbacks(
                        kernel_db_conn, sched_ctx, sess_ctx, check_results,
                    )
                    query = kernels.update().values({
                        'status_info': "no-available-instances",
                        'status_data': sql_json_increment(
                            kernels.c.status_data,
                            ('scheduler', 'retries'),
                            parent_updates={
                                'last_try': datetime.now(tzutc()).isoformat(),
                            }),
                    }).where(kernels.c.id == kernel.kernel_id)
                    await kernel_db_conn.execute(query)
                raise
            except Exception as e:
                log.exception(
                    log_fmt + 'unexpected-error, during agent allocation',
                    *log_args,
                )
                async with kernel_db_conn.begin():
                    await _invoke_failure_callbacks(
                        kernel_db_conn, sched_ctx, sess_ctx, check_results,
                    )
                    query = kernels.update().values({
                        'status_info': "scheduler-error",
                        'status_data': convert_to_status_data(e),
                    }).where(kernels.c.id == kernel.kernel_id)
                    await kernel_db_conn.execute(query)
                raise
            else:
                kernel_agent_bindings.append(
                    KernelAgentBinding(kernel, agent_alloc_ctx))
    if len(kernel_agent_bindings) == len(sess_ctx.kernels):
        # Proceed to PREPARING only when all kernels are successfully scheduled.
        async with kernel_db_conn.begin():
            for binding in kernel_agent_bindings:
                query = kernels.update().values({
                    'agent': binding.agent_alloc_ctx.agent_id,
                    'agent_addr': binding.agent_alloc_ctx.agent_addr,
                    'scaling_group': sgroup_name,
                    'status': KernelStatus.PREPARING,
                    'status_info': 'scheduled',
                    'status_data': {},
                    'status_changed': datetime.now(tzutc()),
                }).where(kernels.c.id == binding.kernel.kernel_id)
                await kernel_db_conn.execute(query)
    return (sess_ctx, kernel_agent_bindings)
async def _schedule_single_node_session(
    self,
    sched_ctx: SchedulingContext,
    scheduler: AbstractScheduler,
    agent_db_conn: SAConnection,
    kernel_db_conn: SAConnection,
    sgroup_name: str,
    candidate_agents: Sequence[AgentContext],
    sess_ctx: PendingSession,
    check_results: List[Tuple[str, Union[Exception, PredicateResult]]],
) -> Tuple[PendingSession, List[KernelAgentBinding]]:
    # Assign agent resource per session.
    log_fmt = _log_fmt.get()
    log_args = _log_args.get()
    try:
        agent_id = scheduler.assign_agent_for_session(candidate_agents, sess_ctx)
        if agent_id is None:
            raise InstanceNotAvailable
        async with agent_db_conn.begin():
            agent_alloc_ctx = await _reserve_agent(
                sched_ctx, agent_db_conn,
                sgroup_name, agent_id, sess_ctx.requested_slots,
            )
    except InstanceNotAvailable:
        log.debug(log_fmt + 'no-available-instances', *log_args)
        async with kernel_db_conn.begin():
            await _invoke_failure_callbacks(
                kernel_db_conn, sched_ctx, sess_ctx, check_results,
            )
            query = kernels.update().values({
                'status_info': "no-available-instances",
                'status_data': sql_json_increment(
                    kernels.c.status_data,
                    ('scheduler', 'retries'),
                    parent_updates={
                        'last_try': datetime.now(tzutc()).isoformat(),
                    }),
            }).where(kernels.c.id == sess_ctx.session_id)
            await kernel_db_conn.execute(query)
        raise
    except Exception as e:
        log.exception(
            log_fmt + 'unexpected-error, during agent allocation',
            *log_args,
        )
        async with kernel_db_conn.begin():
            await _invoke_failure_callbacks(
                kernel_db_conn, sched_ctx, sess_ctx, check_results,
            )
            query = kernels.update().values({
                'status_info': "scheduler-error",
                'status_data': convert_to_status_data(e),
            }).where(kernels.c.id == sess_ctx.session_id)
            await kernel_db_conn.execute(query)
        raise
    async with kernel_db_conn.begin():
        query = kernels.update().values({
            'agent': agent_alloc_ctx.agent_id,
            'agent_addr': agent_alloc_ctx.agent_addr,
            'scaling_group': sgroup_name,
            'status': KernelStatus.PREPARING,
            'status_info': 'scheduled',
            'status_data': {},
            'status_changed': datetime.now(tzutc()),
        }).where(kernels.c.session_id == sess_ctx.session_id)
        await kernel_db_conn.execute(query)
    return (
        sess_ctx,
        [
            KernelAgentBinding(kernel, agent_alloc_ctx)
            for kernel in sess_ctx.kernels
        ],
    )
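# Note (assumption about sql_json_increment / convert_to_status_data semantics,
# inferred from the call sites above rather than from their definitions):
# after a "no-available-instances" failure, the kernel row's status_data is
# expected to end up shaped roughly like
#     {"scheduler": {"retries": <previous value + 1>,
#                    "last_try": "<ISO 8601 timestamp>"}}
# i.e. the counter at the JSON path ('scheduler', 'retries') is incremented and
# the keys given in parent_updates are merged into the same parent object, while
# "scheduler-error" failures store whatever convert_to_status_data(e) produces
# for the raised exception.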
async def _schedule_in_sgroup(
    self,
    sched_ctx: SchedulingContext,
    agent_db_conn: SAConnection,
    kernel_db_conn: SAConnection,
    sgroup_name: str,
) -> List[StartTaskArgs]:
    async with kernel_db_conn.begin():
        scheduler = await self._load_scheduler(kernel_db_conn, sgroup_name)
        pending_sessions = await _list_pending_sessions(kernel_db_conn, sgroup_name)
        existing_sessions = await _list_existing_sessions(kernel_db_conn, sgroup_name)
    log.debug('running scheduler (sgroup:{}, pending:{}, existing:{})',
              sgroup_name, len(pending_sessions), len(existing_sessions))
    zero = ResourceSlot()
    args_list: List[StartTaskArgs] = []
    while len(pending_sessions) > 0:
        async with agent_db_conn.begin():
            candidate_agents = await _list_agents_by_sgroup(agent_db_conn, sgroup_name)
            total_capacity = sum((ag.available_slots for ag in candidate_agents), zero)

        picked_session_id = scheduler.pick_session(
            total_capacity,
            pending_sessions,
            existing_sessions,
        )
        if picked_session_id is None:
            # no session is picked.
            # continue to next sgroup.
            return []
        for picked_idx, sess_ctx in enumerate(pending_sessions):
            if sess_ctx.session_id == picked_session_id:
                break
        else:
            # no matching entry for picked session?
            raise RuntimeError('should not reach here')
        sess_ctx = pending_sessions.pop(picked_idx)

        log_fmt = 'schedule(s:{}, type:{}, name:{}, ak:{}, cluster_mode:{}): '
        log_args = (
            sess_ctx.session_id,
            sess_ctx.session_type,
            sess_ctx.session_name,
            sess_ctx.access_key,
            sess_ctx.cluster_mode,
        )
        _log_fmt.set(log_fmt)
        _log_args.set(log_args)
        log.debug(log_fmt + 'try-scheduling', *log_args)
        session_agent_binding: Tuple[PendingSession, List[KernelAgentBinding]]

        async with kernel_db_conn.begin():
            predicates: Sequence[Tuple[str, Awaitable[PredicateResult]]] = [
                (
                    'reserved_time',
                    check_reserved_batch_session(kernel_db_conn, sched_ctx, sess_ctx),
                ),
                (
                    'concurrency',
                    check_concurrency(kernel_db_conn, sched_ctx, sess_ctx),
                ),
                (
                    'dependencies',
                    check_dependencies(kernel_db_conn, sched_ctx, sess_ctx),
                ),
                (
                    'keypair_resource_limit',
                    check_keypair_resource_limit(kernel_db_conn, sched_ctx, sess_ctx),
                ),
                (
                    'user_group_resource_limit',
                    check_group_resource_limit(kernel_db_conn, sched_ctx, sess_ctx),
                ),
                (
                    'domain_resource_limit',
                    check_domain_resource_limit(kernel_db_conn, sched_ctx, sess_ctx),
                ),
                (
                    'scaling_group_resource_limit',
                    check_scaling_group(kernel_db_conn, sched_ctx, sess_ctx),
                ),
            ]
            check_results: List[Tuple[str, Union[Exception, PredicateResult]]] = []
            for predicate_name, check_coro in predicates:
                try:
                    check_results.append((predicate_name, await check_coro))
                except Exception as e:
                    log.exception(log_fmt + 'predicate-error', *log_args)
                    check_results.append((predicate_name, e))
        has_failure = False
        # has_permanent_failure = False
        failed_predicates = []
        passed_predicates = []
        for predicate_name, result in check_results:
            if isinstance(result, Exception):
                has_failure = True
                failed_predicates.append({
                    'name': predicate_name,
                    'msg': repr(result),
                })
                continue
            if result.passed:
                passed_predicates.append({
                    'name': predicate_name,
                })
            else:
                failed_predicates.append({
                    'name': predicate_name,
                    'msg': result.message or "",
                })
                has_failure = True
                # if result.permanent:
                #     has_permanent_failure = True
        if has_failure:
            log.debug(log_fmt + 'predicate-checks-failed (temporary)', *log_args)
            # TODO: handle has_permanent_failure as cancellation
            #  - An early implementation of it has caused DB query blocking due to
            #    the inclusion of the kernels.status field. :(
            #    Let's fix it.
            async with kernel_db_conn.begin():
                await _invoke_failure_callbacks(
                    kernel_db_conn, sched_ctx, sess_ctx, check_results,
                )
                query = kernels.update().values({
                    'status_info': "predicate-checks-failed",
                    'status_data': sql_json_increment(
                        kernels.c.status_data,
                        ('scheduler', 'retries'),
                        parent_updates={
                            'last_try': datetime.now(tzutc()).isoformat(),
                            'failed_predicates': failed_predicates,
                            'passed_predicates': passed_predicates,
                        }),
                }).where(kernels.c.id == sess_ctx.session_id)
                await kernel_db_conn.execute(query)
            # Predicate failures are *NOT* permanent errors.
            # We need to retry the scheduling afterwards.
            continue
        else:
            async with kernel_db_conn.begin():
                query = kernels.update().values({
                    'status_data': sql_json_merge(
                        kernels.c.status_data,
                        ('scheduler',),
                        {
                            'last_try': datetime.now(tzutc()).isoformat(),
                            'failed_predicates': failed_predicates,
                            'passed_predicates': passed_predicates,
                        }),
                }).where(kernels.c.id == sess_ctx.session_id)
                await kernel_db_conn.execute(query)

        if sess_ctx.cluster_mode == ClusterMode.SINGLE_NODE:
            session_agent_binding = await self._schedule_single_node_session(
                sched_ctx,
                scheduler,
                agent_db_conn,
                kernel_db_conn,
                sgroup_name,
                candidate_agents,
                sess_ctx,
                check_results,
            )
        elif sess_ctx.cluster_mode == ClusterMode.MULTI_NODE:
            session_agent_binding = await self._schedule_multi_node_session(
                sched_ctx,
                scheduler,
                agent_db_conn,
                kernel_db_conn,
                sgroup_name,
                candidate_agents,
                sess_ctx,
                check_results,
            )
        else:
            raise RuntimeError(
                f"should not reach here; unknown cluster_mode: {sess_ctx.cluster_mode}"
            )
        args_list.append((
            log_args,
            sched_ctx,
            session_agent_binding,
            check_results,
        ))
    return args_list
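# Sketch (assumption, not part of the original source): the caller of
# _schedule_in_sgroup presumably iterates over the returned args_list and
# launches one start task per scheduled session. The start_session coroutine
# name below is hypothetical; only the tuple layout matches the code above.
import asyncio

async def _start_scheduled(self, args_list: List[StartTaskArgs]) -> None:
    tasks = []
    for log_args, sched_ctx, session_agent_binding, check_results in args_list:
        # One task per (session, kernel-agent bindings) pair produced above.
        tasks.append(asyncio.create_task(
            self.start_session(sched_ctx, session_agent_binding),
        ))
    await asyncio.gather(*tasks, return_exceptions=True)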