async def operate_problem(self, problem: "Problem", operation: Operation, position: Optional[int] = None) -> None:
    """Add, check, move or remove ``problem`` in this problem set.

    Args:
        problem: The problem to operate on; must be in this set's domain.
        operation: ``Create`` adds a new link, ``Read`` only checks that a
            link exists, ``Update`` moves an existing link and ``Delete``
            removes it.
        position: Index in ``self.problem_problem_set_links`` at which the
            link is (re)inserted for ``Create``/``Update``; ``None`` appends.

    Raises:
        BizError: ``IntegrityError`` if the problem belongs to another
            domain, is already linked (``Create``), or is not linked
            (any other operation).
    """
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would silently allow cross-domain links.
    if problem.domain_id != self.domain_id:
        raise BizError(ErrorCode.IntegrityError, "problem not in the same domain")
    link = await ProblemProblemSetLink.one_or_none(problem_set_id=self.id, problem_id=problem.id)
    if operation == Operation.Create:
        if link is not None:
            raise BizError(ErrorCode.IntegrityError, "problem already added")
        link = ProblemProblemSetLink(problem_set_id=self.id, problem_id=problem.id)
    else:
        if link is None:
            raise BizError(ErrorCode.IntegrityError, "problem not added")
    if operation == Operation.Read:
        return
    # Update is implemented as remove + re-insert at the requested position.
    if operation in (Operation.Update, Operation.Delete):
        self.problem_problem_set_links.remove(link)
    if operation in (Operation.Create, Operation.Update):
        if position is None:
            self.problem_problem_set_links.append(link)
        else:
            self.problem_problem_set_links.insert(position, link)
    if operation == Operation.Delete:
        # commit=False defers the commit; presumably save_model below commits
        # it together with the list change — confirm.
        await link.delete_model(commit=False)
    await self.save_model()
async def claim_record_by_judge(
    judge_claim: schemas.JudgeClaim,
    record: models.Record = Depends(parse_record_judger),
    user: models.User = Depends(parse_user_from_auth),
) -> StandardResponse[schemas.JudgeCredentials]:
    """Let a judger claim ``record`` and hand back LakeFS credentials for it.

    The claim succeeds only if ``judge_claim.task_id`` equals the record's
    current ``task_id``. The record is then assigned to ``user`` and its
    state is set to ``fetched`` and saved. Read access to the problem-config
    repo and full access to the record repo are ensured for the user. The
    response carries the user's LakeFS access key together with the repo
    names and commit ids the judger needs.

    Raises:
        BizError: ``ErrorCode.Error`` if the task id does not match, or if
            the record has no problem or no problem config.
    """
    # task_id can only be obtained by listening to the celery task queue
    # we give the worker with task_id the chance to claim the record
    # celery tasks can be retried, only one worker can hold the task_id at the same time
    # if a rejudge is scheduled, task_id changes, so previous task will be ineffective
    # TODO: we probably need a lock to handle race condition of rejudge and claim
    if record.task_id is None or record.task_id != judge_claim.task_id:
        raise BizError(ErrorCode.Error)
    # Disabled state check, kept for reference:
    # if record.state not in (schemas.RecordState.queueing, schemas.RecordState.retrying):
    #     raise BizError(ErrorCode.Error)
    # we can mark task failed if no problem config is available
    if record.problem_config is None or record.problem is None:
        raise BizError(ErrorCode.Error)

    # we always reset the state to "fetched", for both first attempt and retries
    # NOTE(review): the state is saved before the LakeFS setup below; if that
    # setup fails, the record stays "fetched" and is assigned to this judger.
    record.judger_id = user.id
    record.state = schemas.RecordState.fetched
    await record.save_model()
    logger.info("judger claim record: {}", record)

    # initialize the permission of the judger to lakefs
    # the user have read access to all problems in the problem group,
    # actually only the access to one branch is necessary,
    # but it will create too many policies, so we grant all for simplicity
    # the user have read/write access to all records in the problem,
    # because the judger will write test result to the repo
    # NOTE(review): parse_record_judger already joinedloads `problem`, so this
    # refetch may be redundant — confirm before removing.
    await record.fetch_related("problem")
    access_key = await models.UserAccessKey.get_lakefs_access_key(user)
    lakefs_problem_config = LakeFSProblemConfig(record.problem)
    lakefs_record = LakeFSRecord(record.problem, record)

    # The policy calls are synchronous, so they run in a worker thread
    # (presumably to avoid blocking the event loop).
    def sync_func() -> None:
        lakefs_problem_config.ensure_user_policy(user, "read")
        lakefs_record.ensure_user_policy(user, "all")

    await run_in_threadpool(sync_func)

    judge_credentials = schemas.JudgeCredentials(
        access_key_id=access_key.access_key_id,
        secret_access_key=access_key.secret_access_key,
        problem_config_repo_name=lakefs_problem_config.repo_name,
        problem_config_commit_id=record.problem_config.commit_id,
        record_repo_name=lakefs_record.repo_name,
        record_commit_id=record.commit_id,
    )
    return StandardResponse(judge_credentials)
async def submit(
    cls,
    *,
    background_tasks: BackgroundTasks,
    celery_app: Celery,
    problem_submit: ProblemSolutionSubmit,
    problem_set: Optional["ProblemSet"],
    problem: "Problem",
    user: "******",
) -> "Record":
    """Create a record for a solution submission and schedule its upload.

    The record is persisted against the problem's latest config, and the
    problem's ``num_submit`` counter is incremented. The user's latest-record
    cache entry is refreshed, and ``record.upload`` is queued as a background
    task.

    Raises:
        BizError: ``ProblemConfigNotFoundError`` if the problem has no
            config; ``Error`` if an archive submission carries no file.
    """
    latest_config = await problem.get_latest_problem_config()
    if latest_config is None:
        raise BizError(ErrorCode.ProblemConfigNotFoundError)

    missing_archive = (
        problem_submit.code_type == RecordCodeType.archive
        and problem_submit.file is None
    )
    if missing_archive:
        raise BizError(ErrorCode.Error)

    set_id = None if not problem_set else problem_set.id
    new_record = cls(
        domain_id=problem.domain_id,
        problem_set_id=set_id,
        problem_id=problem.id,
        problem_config_id=latest_config.id,
        committer_id=user.id,
    )
    # Stage the record without committing; the commit issued while saving the
    # problem below presumably persists both — confirm.
    await new_record.save_model(commit=False, refresh=False)
    # NOTE(review): read-modify-write counter; concurrent submissions may
    # lose increments depending on isolation level.
    problem.num_submit += 1
    await problem.save_model(commit=True, refresh=True)
    await new_record.refresh_model()

    cache_key = cls.get_user_latest_record_key(set_id, problem.id, user.id)
    preview = RecordPreview(
        id=new_record.id,
        state=new_record.state,
        created_at=new_record.created_at,
    )
    await get_redis_cache().set(cache_key, preview, namespace="user_latest_records")

    background_tasks.add_task(
        new_record.upload,
        celery_app=celery_app,
        problem_submit=problem_submit,
        problem=problem,
    )
    return new_record
async def set_root_user(
    user: models.User = Depends(parse_user_from_auth),
    session: AsyncSession = Depends(db_session_dependency),
) -> StandardResponse[schemas.User]:
    """Promote the calling user to ROOT, only if no ROOT user exists yet.

    Raises:
        BizError: ``Error`` if a ROOT user already exists;
            ``UserNotFoundError`` if the caller is not in the database.
    """
    root_users = await models.User.all(role=DefaultRole.ROOT)
    # Truthiness instead of `!= []`: a tuple or other empty sequence would
    # never compare equal to `[]` and would wrongly block the bootstrap.
    if root_users:
        raise BizError(ErrorCode.Error)
    current_user = await models.User.one_or_none(id=user.id)
    if current_user is None:
        raise BizError(ErrorCode.UserNotFoundError)
    current_user.role = DefaultRole.ROOT
    session.sync_session.add(current_user)
    await session.commit()
    await session.refresh(current_user)
    return StandardResponse(current_user)
def ensure_policy(self, permission: Literal["read", "all"]) -> models.Policy:
    """Return the LakeFS policy granting ``permission`` on this repo, creating it if absent.

    The policy id is ``"<repo_name>-<permission>"``. ``"read"`` allows
    ``fs:List*``/``fs:Read*`` and ``"all"`` allows ``fs:*`` on every path
    of the repository.

    Raises:
        BizError: ``InternalServerError`` for any other ``permission`` value.
    """
    if permission not in ("read", "all"):
        raise BizError(ErrorCode.InternalServerError, f"permission not defined: {permission}")
    client = get_lakefs_client()
    policy_id = f"{self.repo_name}-{permission}"
    try:
        policy = client.auth.get_policy(policy_id=policy_id)
    except LakeFSApiException:
        # NOTE(review): any API error, not only "not found", leads to a
        # creation attempt here — confirm this is intended.
        allowed_actions = (
            ["fs:List*", "fs:Read*"] if permission == "read" else ["fs:*"]
        )
        statement = models.Statement(
            effect="allow",
            resource=f"arn:lakefs:fs:::repository/{self.repo_name}/*",
            action=allowed_actions,
        )
        policy = client.auth.create_policy(
            policy=models.Policy(id=policy_id, statement=[statement])
        )
        logger.info(f"LakeFS create policy: {policy_id}")
    return policy
async def get_lakefs_access_key(cls, user: "******") -> "UserAccessKey":
    """Return the user's stored LakeFS access key, creating or updating it as needed.

    The LakeFS user and credentials are ensured in a worker thread. If no
    credentials come back, the stored key is reused. Otherwise the stored key
    is created or overwritten with the returned pair and saved.

    Raises:
        BizError: ``Error`` if there is neither a stored key nor fresh
            credentials.
    """
    stored_key = await cls.one_or_none(service="lakefs", user_id=user.id)
    known_key_id = stored_key.access_key_id if stored_key else None

    def provision() -> Optional[CredentialsWithSecret]:
        ensure_user(user.id)
        return ensure_credentials(user.id, known_key_id)

    credentials = await run_in_threadpool(provision)
    if credentials is None:
        # Presumably ensure_credentials returns None when the known key is
        # still valid — confirm.
        if stored_key is None:
            raise BizError(ErrorCode.Error)
        return stored_key

    if stored_key is None:
        stored_key = cls(
            service="lakefs",
            access_key_id=credentials.access_key_id,
            secret_access_key=credentials.secret_access_key,
            user_id=user.id,
        )
    else:
        stored_key.access_key_id = credentials.access_key_id
        stored_key.secret_access_key = credentials.secret_access_key
    await stored_key.save_model()
    return stored_key
async def general_exception_handler(
    request: Request, exc: Exception
) -> JSONResponse:  # pragma: no cover
    """Turn an unhandled exception into an ``InternalServerError`` response.

    The exception is logged with its traceback, and ``str(exc)`` becomes the
    response message.
    """
    # NOTE(review): str(exc) is sent to the client and may expose internals.
    logger.exception(f"Unexpected Error: {exc.__class__.__name__}")
    wrapped = BizError(ErrorCode.InternalServerError, str(exc))
    return business_exception_response(wrapped)
async def update_record_state_by_judge(
    record: models.Record = Depends(parse_record_judger),
    user: models.User = Depends(parse_user_from_auth),
) -> StandardResponse[schemas.Record]:
    """Update the state of a record claimed by the calling judger.

    The caller must be the record's judger, and the record must be in the
    fetched, compiling, running or judging state.

    Raises:
        BizError: ``Error`` when either condition fails.
    """
    active_states = (
        schemas.RecordState.fetched,
        schemas.RecordState.compiling,
        schemas.RecordState.running,
        schemas.RecordState.judging,
    )
    if record.judger_id != user.id or record.state not in active_states:
        raise BizError(ErrorCode.Error)
    # NOTE(review): no target state is accepted; the record is always reset
    # to `fetched`. This looks like a placeholder — confirm intended behavior.
    record.state = schemas.RecordState.fetched
    await record.save_model()
    return StandardResponse(record)
async def parse_problem_group(problem_group: str = Path(...)) -> models.ProblemGroup:
    """Look up a problem group by its id.

    Raises:
        BizError: ``ProblemGroupNotFoundError`` when no group matches.
    """
    group = await models.ProblemGroup.one_or_none(id=problem_group)
    if not group:
        raise BizError(ErrorCode.ProblemGroupNotFoundError)
    return group
async def get_domain(
    domain: str = Path(..., description="url or id of the domain"),
) -> Domain:
    """Resolve a domain from its url or id.

    Raises:
        BizError: ``DomainNotFoundError`` when nothing matches.
    """
    found = await Domain.find_by_url_or_id(domain)
    if found is not None:
        return found
    raise BizError(ErrorCode.DomainNotFoundError)
async def create(
    cls,
    user_create: "UserCreate",
    jwt_access_token: Optional["JWTAccessToken"],
    register_ip: str,
) -> "User":
    """Register a new user, optionally binding an OAuth account.

    When ``user_create.oauth_name`` is set, ``jwt_access_token`` must be an
    ``oauth`` token for that provider and account id. The user is then
    built from the stored OAuth account, and that account is linked to it.
    Otherwise a plain password user is built. Everything is added and
    committed in one session, and the refreshed user is returned.

    Raises:
        BizError: ``UserRegisterError`` if the OAuth token does not match
            the requested account, or if the chosen builder rejects
            ``user_create``.
    """
    oauth_account: Optional[UserOAuthAccount] = None
    if not user_create.oauth_name:
        user = cls._create_user(user_create, register_ip)
    else:
        token_matches = (
            jwt_access_token is not None
            and jwt_access_token.category == "oauth"
            and jwt_access_token.oauth_name == user_create.oauth_name
            and jwt_access_token.id == user_create.oauth_account_id
        )
        if not token_matches:
            raise BizError(ErrorCode.UserRegisterError, "oauth account not matched")
        user, oauth_account = await cls._create_user_by_oauth(
            user_create, jwt_access_token, register_ip)

    async with db_session() as session:
        session.sync_session.add(user)
        if oauth_account:  # pragma: no cover
            # NOTE(review): relies on user.id being set before the flush
            # (e.g. a client-side default) — confirm.
            oauth_account.user_id = user.id
            session.sync_session.add(oauth_account)
        await session.commit()
        await session.refresh(user)
    return user
async def reset_password(self, current_password: str, new_password: str) -> None:
    """Replace the password hash with one for ``new_password`` and save.

    If a hash is already set, ``current_password`` must verify against it.
    Accounts with an empty hash skip that check; OAuth registration without
    a password, for example, stores an empty hash.

    Raises:
        BizError: ``UsernamePasswordError`` if verification fails.
    """
    has_password = bool(self.hashed_password)
    if has_password and not self.verify_password(current_password):
        raise BizError(ErrorCode.UsernamePasswordError, "incorrect password")
    self.hashed_password = self._generate_password_hash(new_password)
    await self.save_model()
def delete_directory(self, file_path: Path, recursive: bool) -> FileInfo:
    """Delete the directory at ``file_path`` from ``self.storage``.

    Args:
        file_path: Directory to delete.
        recursive: If true, call ``delete_tree``; otherwise ``delete_dir``.

    Returns:
        The value returned by the storage delete call.

    Raises:
        BizError: ``ProblemConfigUpdateError`` wrapping any ``ElephantError``.
    """
    try:
        if recursive:
            return self.storage.delete_tree(file_path)
        return self.storage.delete_dir(file_path)
    except ElephantError as e:
        # Chain explicitly so the storage error is reported as the cause.
        raise BizError(ErrorCode.ProblemConfigUpdateError, str(e)) from e
def parse_problem(
    problem: models.Problem = Depends(parse_problem_without_validation),
    auth: Authentication = Depends(),
) -> models.Problem:
    """Return ``problem`` if the caller may see it.

    Non-hidden problems are always returned. Hidden ones are returned only
    when the caller has ``view_hidden`` on ``DOMAIN_PROBLEM``; otherwise
    the problem is reported as not found, presumably so its existence is not
    revealed.
    """
    may_view = not problem.hidden or auth.check(
        ScopeType.DOMAIN_PROBLEM, PermissionType.view_hidden
    )
    if not may_view:
        raise BizError(ErrorCode.ProblemNotFoundError)
    return problem
async def parse_problem_without_validation(
    problem: str = Path(..., description="url or id of the problem"),
    domain: models.Domain = Depends(parse_domain_from_auth),
) -> models.Problem:
    """Resolve a problem by url or id within ``domain``, ignoring visibility.

    Hidden problems are returned as well; visibility is checked by
    ``parse_problem``.

    Raises:
        BizError: ``ProblemNotFoundError`` when nothing matches.
    """
    found = await models.Problem.find_by_domain_url_or_id(domain, problem)
    if not found:
        raise BizError(ErrorCode.ProblemNotFoundError)
    return found
async def parse_domain_role(
    role: NoneEmptyLongStr = Path(..., description="name of the domain role"),
    domain: models.Domain = Depends(parse_domain_from_auth),
) -> models.DomainRole:
    """Look up the role named ``role`` in ``domain``.

    Raises:
        BizError: ``DomainRoleNotFoundError`` when the role does not exist.
    """
    found = await models.DomainRole.one_or_none(domain_id=domain.id, role=role)
    if found is not None:
        return found
    raise BizError(ErrorCode.DomainRoleNotFoundError)
async def parse_domain_from_auth(
    domain_auth: DomainAuthentication = Depends(),
) -> models.Domain:
    """Return the domain resolved during authentication, if visible to the caller.

    A hidden domain is visible to its members and to callers with
    ``view_hidden`` on ``SITE_DOMAIN``.

    Raises:
        BizError: ``DomainNotFoundError`` if there is no domain, or if it is
            hidden from the caller.
    """
    auth = domain_auth.auth
    found = auth.domain
    if found is None:
        raise BizError(ErrorCode.DomainNotFoundError)
    hidden_from_caller = (
        found.hidden
        and auth.domain_user is None
        and not auth.check(ScopeType.SITE_DOMAIN, PermissionType.view_hidden)
    )
    if hidden_from_caller:
        raise BizError(ErrorCode.DomainNotFoundError)
    return found
async def submit_record_by_judge(
    record_result: schemas.RecordResult,
    record: models.Record = Depends(parse_record_judger),
    user: models.User = Depends(parse_user_from_auth),
) -> StandardResponse[Empty]:
    """Accept a judging result for ``record``.

    Only records in the ``fetched`` state are accepted.

    Raises:
        BizError: ``Error`` for any other state.
    """
    # NOTE(review): `record_result` and `user` are unused and nothing is
    # persisted. This looks like a stub — confirm before relying on it.
    if record.state == schemas.RecordState.fetched:
        return StandardResponse()
    raise BizError(ErrorCode.Error)
def get_config(self, ref: str) -> Dict[str, Any]:
    """Download and parse ``config.json`` at ``ref``.

    Args:
        ref: Repository ref to read the config from.

    Returns:
        The parsed JSON content.

    Raises:
        BizError: ``ProblemConfigValidationError`` if the file cannot be
            downloaded, or if it is not valid JSON.
    """
    try:
        result = self.download_file(Path("config.json"), ref)
        raw = result.read()
    except ElephantError as e:
        raise BizError(
            ErrorCode.ProblemConfigValidationError,
            "config.json not found in problem config.",
        ) from e
    try:
        return orjson.loads(raw)
    except orjson.JSONDecodeError as e:
        # Malformed JSON is a config validation problem, not a server crash.
        raise BizError(
            ErrorCode.ProblemConfigValidationError,
            "config.json is not valid JSON.",
        ) from e
async def parse_uid(
    uid: str = Query("me", description="'me' or id of the user"),
    auth: Authentication = Depends(),
) -> models.User:
    """Resolve the ``uid`` query parameter to a user.

    ``"me"`` yields the user built from the caller's token; any other value
    is looked up by id.

    Raises:
        BizError: ``UserNotFoundError`` when no user has that id.
    """
    if uid != "me":
        found = await models.User.one_or_none(id=uid)
        if not found:
            raise BizError(ErrorCode.UserNotFoundError)
        return found
    return parse_user_from_auth(auth)
async def parse_domain_invitation(
    invitation: str = Path(..., description="url or id of the domain invitation"),
    domain: models.Domain = Depends(parse_domain_from_auth),
) -> models.DomainInvitation:
    """Resolve an invitation by url or id within ``domain``.

    Raises:
        BizError: ``DomainInvitationBadRequestError`` when nothing matches.
    """
    # NOTE(review): sibling resolvers raise a *NotFoundError for a missing
    # entity; confirm BadRequest is intended here.
    found = await models.DomainInvitation.find_by_domain_url_or_id(domain, invitation)
    if not found:
        raise BizError(ErrorCode.DomainInvitationBadRequestError)
    return found
async def _create_user_by_oauth(
    cls,
    user_create: "UserCreate",
    jwt_access_token: "JWTAccessToken",
    register_ip: str,
) -> Tuple["User", "UserOAuthAccount"]:
    """Build an active, unsaved user from a stored OAuth account.

    The OAuth account is found by the token's provider name and account id.
    The username defaults to the account name. The email is always the
    account's email; a differing ``user_create.email`` is rejected. Student
    id and real name come from the token.

    Returns:
        The new ``User`` and the matching ``UserOAuthAccount``.

    Raises:
        BizError: ``UserRegisterError`` if the account is unknown, no
            username can be determined, or the email differs.
    """
    oauth_account = await UserOAuthAccount.one_or_none(
        oauth_name=jwt_access_token.oauth_name,
        account_id=jwt_access_token.id,
    )
    if oauth_account is None:
        raise BizError(ErrorCode.UserRegisterError, "oauth account not matched")

    # Fall back to the OAuth account name when no username was supplied.
    username = user_create.username or oauth_account.account_name
    if not username:
        raise BizError(ErrorCode.UserRegisterError, "username not provided")

    primary_email = oauth_account.account_email
    if user_create.email and user_create.email != primary_email:
        raise BizError(
            ErrorCode.UserRegisterError,
            "email must be same as the primary email of oauth account",
        )

    # Registering through OAuth may omit the password; an empty hash is
    # stored in that case.
    hashed_password = (
        cls._generate_password_hash(user_create.password)
        if user_create.password
        else ""  # pragma: no cover
    )
    new_user = User(
        username=username,
        email=primary_email,
        student_id=jwt_access_token.student_id,
        real_name=jwt_access_token.real_name,
        is_active=True,
        hashed_password=hashed_password,
        register_ip=register_ip,
        login_ip=register_ip,
    )
    return new_user, oauth_account
async def login(
    request: Request,
    response: Response,
    auth_parameters: AuthParams = Depends(auth_parameters_dependency),
    auth_jwt: AuthJWT = Depends(AuthJWT),
    credentials: OAuth2PasswordRequestForm = Depends(),
) -> schemas.StandardResponse[schemas.AuthTokens]:
    """Authenticate with username/password and issue access/refresh tokens.

    On success, the user's login time (UTC) and login IP are recorded.

    Raises:
        BizError: ``UsernamePasswordError`` if the user does not exist or
            the password is wrong.
    """
    # NOTE(review): the two distinct messages allow username enumeration;
    # consider a single message if clients do not depend on them.
    user = await models.User.one_or_none(username=credentials.username)
    if not user:
        raise BizError(ErrorCode.UsernamePasswordError, "user not found")
    if not user.verify_password(credentials.password):
        raise BizError(ErrorCode.UsernamePasswordError, "incorrect password")
    user.login_at = datetime.now(tz=timezone.utc)
    # Starlette types `Request.client` as Optional; it is None when the ASGI
    # scope has no client address, so guard instead of raising AttributeError.
    client = request.client
    user.login_ip = client.host if client is not None else ""
    await user.save_model()
    logger.info(f"user login: {user}")
    access_token, refresh_token = auth_jwt_encode_user(auth_jwt, user=user)
    return await get_login_response(
        request, response, auth_jwt, auth_parameters, access_token, refresh_token
    )
async def parse_problem_problem_set_link(
    problem_set: models.ProblemSet = Depends(parse_problem_set),
    problem: models.Problem = Depends(parse_problem_without_validation),
) -> models.ProblemProblemSetLink:
    """Load the link between ``problem_set`` and ``problem``.

    The already-resolved set and problem objects are attached to the link.

    Raises:
        BizError: ``ProblemNotFoundError`` if the problem is not in the set.
    """
    link = await models.ProblemProblemSetLink.one_or_none(
        problem_set_id=problem_set.id, problem_id=problem.id
    )
    if link is None:
        raise BizError(ErrorCode.ProblemNotFoundError)
    link.problem_set = problem_set
    link.problem = problem
    return link
async def parse_record_judger(record: UUID = Path(...)) -> models.Record:
    """Load a record for judger endpoints, joining its problem and problem config.

    The id is typed as ``UUID``, matching ``parse_record``. Malformed ids are
    therefore rejected by request validation instead of reaching the query.

    Raises:
        BizError: ``RecordNotFoundError`` if no record has this id.
    """
    statement = (
        models.Record.sql_select()
        .where(models.Record.id == record)
        .options(
            joinedload(models.Record.problem),
            joinedload(models.Record.problem_config),
        )
    )
    result = await models.Record.session_exec(statement)
    record_model = result.one_or_none()
    if record_model:
        return record_model
    raise BizError(ErrorCode.RecordNotFoundError)
def get_file_info(self, file_path: Path, ref: Optional[str] = None) -> FileInfo:
    """Return storage metadata for ``file_path``.

    Args:
        file_path: Path of the file inside the repository.
        ref: If ``None``, ``self.storage`` is used; otherwise the storage
            comes from ``self._get_storage(ref)``.

    Raises:
        BizError: ``ProblemConfigDownloadError`` wrapping any ``ElephantError``.
    """
    try:
        storage = self.storage if ref is None else self._get_storage(ref)
        return storage.getinfo(file_path)
    except ElephantError as e:
        # Chain explicitly so the storage error is reported as the cause.
        raise BizError(ErrorCode.ProblemConfigDownloadError, str(e)) from e
def parse_user_from_auth(auth: Authentication = Depends()) -> models.User:
    """Build a ``User`` from the caller's JWT claims.

    The model is populated only from token fields; this function performs
    no database query.

    Raises:
        BizError: ``UserNotFoundError`` if the token is not a user token.
    """
    token = auth.jwt
    if token.category == "user":
        return models.User(
            id=token.id,
            username=token.username,
            email=token.email,
            student_id=token.student_id,
            real_name=token.real_name,
            role=token.role,
            is_active=token.is_active,
        )
    raise BizError(ErrorCode.UserNotFoundError)
async def parse_problem_config(
    config: str = Path(..., description="'latest' or id of the config"),
    problem: models.Problem = Depends(parse_problem),
) -> models.ProblemConfig:
    """Resolve ``config`` to one of ``problem``'s configs.

    ``"latest"`` selects the most recent config; any other value is looked
    up by id within the problem.

    Raises:
        BizError: ``ProblemConfigNotFoundError`` when nothing matches.
    """
    wants_latest = config == "latest"
    config_model = (
        await problem.get_latest_problem_config()
        if wants_latest
        else await models.ProblemConfig.one_or_none(problem_id=problem.id, id=config)
    )
    if not config_model:
        raise BizError(ErrorCode.ProblemConfigNotFoundError)
    return config_model
async def parse_record(
    record: UUID = Path(...),
    domain_auth: DomainAuthentication = Depends(),
    user: models.User = Depends(parse_user_from_auth),
) -> models.Record:
    """Load a record the caller is allowed to view.

    Callers can always view their own records; other records require
    ``view`` on ``DOMAIN_RECORD``.

    Raises:
        BizError: ``RecordNotFoundError`` if the record is missing or not
            viewable.
    """
    record_model = await models.Record.one_or_none(id=record)
    if record_model:
        is_owner = record_model.committer_id == user.id
        if is_owner or domain_auth.auth.check(ScopeType.DOMAIN_RECORD, PermissionType.view):
            return record_model
    raise BizError(ErrorCode.RecordNotFoundError)
def _create_user(cls, user_create: "UserCreate", register_ip: str) -> "User":
    """Build an inactive, unsaved password-based user from ``user_create``.

    Password, username and email are required, and are checked in that
    order.

    Raises:
        BizError: ``UserRegisterError`` naming the first missing field.
    """
    required_fields = (
        (user_create.password, "password not provided"),
        (user_create.username, "username not provided"),
        (user_create.email, "email not provided"),
    )
    for value, message in required_fields:
        if not value:
            raise BizError(ErrorCode.UserRegisterError, message)
    return User(
        username=user_create.username,
        email=user_create.email,
        student_id="",
        real_name="",
        is_active=False,
        hashed_password=cls._generate_password_hash(user_create.password),
        register_ip=register_ip,
        login_ip=register_ip,
    )