示例#1
0
    async def operate_problem(self,
                              problem: "Problem",
                              operation: Operation,
                              position: Optional[int] = None) -> None:
        """Link, check, move, or unlink ``problem`` within this problem set.

        :param problem: problem to operate on; must share this set's domain.
        :param operation: ``Create`` adds a new link, ``Read`` only verifies
            the link exists, ``Update`` moves the existing link, ``Delete``
            removes it.
        :param position: index used by ``Create``/``Update`` to place the
            link; ``None`` appends it at the end.
        :raises BizError: ``IntegrityError`` if the problem is already linked
            (for ``Create``) or is not linked (for any other operation).
        """
        assert problem.domain_id == self.domain_id
        existing = await ProblemProblemSetLink.one_or_none(
            problem_set_id=self.id, problem_id=problem.id)

        if operation != Operation.Create:
            if existing is None:
                raise BizError(ErrorCode.IntegrityError, "problem not added")
            link = existing
        elif existing is not None:
            raise BizError(ErrorCode.IntegrityError,
                           "problem already added")
        else:
            link = ProblemProblemSetLink(problem_set_id=self.id,
                                         problem_id=problem.id)

        # Read is only an existence check; nothing is modified or saved.
        if operation == Operation.Read:
            return

        detach = operation in (Operation.Update, Operation.Delete)
        attach = operation in (Operation.Create, Operation.Update)
        if detach:
            self.problem_problem_set_links.remove(link)
        if attach:
            if position is not None:
                self.problem_problem_set_links.insert(position, link)
            else:
                self.problem_problem_set_links.append(link)
        if operation == Operation.Delete:
            await link.delete_model(commit=False)
        await self.save_model()
示例#2
0
async def claim_record_by_judge(
    judge_claim: schemas.JudgeClaim,
    record: models.Record = Depends(parse_record_judger),
    user: models.User = Depends(parse_user_from_auth),
) -> StandardResponse[schemas.JudgeCredentials]:
    """Let a judger claim a record and return LakeFS credentials for it.

    The claim is accepted only when ``judge_claim.task_id`` equals the
    record's current ``task_id``. On success the record is assigned to
    ``user`` with state ``fetched``. The judger is granted "read" policy on
    the problem-config repo and "all" policy on the record repo. The
    response carries the access key plus the repo names and commit ids the
    judger should use.

    :raises BizError: ``ErrorCode.Error`` when the task id does not match,
        or when the record has no problem or problem config.
    """
    # task_id can only be obtained by listening to the celery task queue
    # we give the worker with task_id the chance to claim the record
    # celery tasks can be retried, only one worker can hold the task_id at the same time
    # if a rejudge is scheduled, task_id changes, so previous task will be ineffective
    # TODO: we probably need a lock to handle race condition of rejudge and claim
    if record.task_id is None or record.task_id != judge_claim.task_id:
        raise BizError(ErrorCode.Error)
    # if record.state not in (schemas.RecordState.queueing, schemas.RecordState.retrying):
    #     raise BizError(ErrorCode.Error)
    # we can mark task failed if no problem config is available
    if record.problem_config is None or record.problem is None:
        raise BizError(ErrorCode.Error)

    # we always reset the state to "fetched", for both first attempt and retries
    record.judger_id = user.id
    record.state = schemas.RecordState.fetched
    await record.save_model()
    logger.info("judger claim record: {}", record)

    # initialize the permission of the judger to lakefs
    # the user have read access to all problems in the problem group,
    # actually only the access to one branch is necessary,
    # but it will create too many policies, so we grant all for simplicity
    # the user have read/write access to all records in the problem,
    # because the judger will write test result to the repo
    # NOTE(review): the parse_record_judger dependency already eager-loads
    # `problem`, and it was checked to be non-None above; this refetch looks
    # redundant — confirm before removing.
    await record.fetch_related("problem")
    access_key = await models.UserAccessKey.get_lakefs_access_key(user)
    lakefs_problem_config = LakeFSProblemConfig(record.problem)
    lakefs_record = LakeFSRecord(record.problem, record)

    # The LakeFS client calls are synchronous, so run them off the event loop.
    def sync_func() -> None:
        lakefs_problem_config.ensure_user_policy(user, "read")
        lakefs_record.ensure_user_policy(user, "all")

    await run_in_threadpool(sync_func)

    judge_credentials = schemas.JudgeCredentials(
        access_key_id=access_key.access_key_id,
        secret_access_key=access_key.secret_access_key,
        problem_config_repo_name=lakefs_problem_config.repo_name,
        problem_config_commit_id=record.problem_config.commit_id,
        record_repo_name=lakefs_record.repo_name,
        record_commit_id=record.commit_id,
    )
    return StandardResponse(judge_credentials)
示例#3
0
    async def submit(
        cls,
        *,
        background_tasks: BackgroundTasks,
        celery_app: Celery,
        problem_submit: ProblemSolutionSubmit,
        problem_set: Optional["ProblemSet"],
        problem: "Problem",
        user: "******",
    ) -> "Record":
        """Create a record for a submitted solution and schedule its upload.

        Steps:

        1. Resolve the problem's latest config.
        2. Validate the submission.
        3. Save the new record (uncommitted) and the incremented
           ``problem.num_submit`` (committed).
        4. Cache a ``RecordPreview`` as the user's latest record for this
           problem (scoped by problem set, if any).
        5. Queue ``record.upload`` as a background task.

        :raises BizError: ``ProblemConfigNotFoundError`` if the problem has no
            config; ``ErrorCode.Error`` for an archive submission without a
            file.
        :returns: the newly created record.
        """
        problem_config = await problem.get_latest_problem_config()
        if problem_config is None:
            raise BizError(ErrorCode.ProblemConfigNotFoundError)

        # archive submissions must carry an uploaded file
        if (problem_submit.code_type == RecordCodeType.archive
                and problem_submit.file is None):
            raise BizError(ErrorCode.Error)

        problem_set_id = problem_set.id if problem_set else None
        record = cls(
            domain_id=problem.domain_id,
            problem_set_id=problem_set_id,
            problem_id=problem.id,
            problem_config_id=problem_config.id,
            committer_id=user.id,
        )

        # The record is saved without committing; the problem save below
        # commits. This presumably lands both in one transaction — confirm
        # that they share a session.
        await record.save_model(commit=False, refresh=False)
        # NOTE(review): read-modify-write increment; concurrent submissions
        # may lose updates unless the ORM/DB serializes them — confirm.
        problem.num_submit += 1
        await problem.save_model(commit=True, refresh=True)
        # reload the record after the commit; id/state/created_at are read below
        await record.refresh_model()

        key = cls.get_user_latest_record_key(problem_set_id, problem.id,
                                             user.id)
        value = RecordPreview(id=record.id,
                              state=record.state,
                              created_at=record.created_at)

        cache = get_redis_cache()
        await cache.set(key, value, namespace="user_latest_records")

        # upload runs after the response is sent
        background_tasks.add_task(
            record.upload,
            celery_app=celery_app,
            problem_submit=problem_submit,
            problem=problem,
        )

        return record
示例#4
0
async def set_root_user(
    user: models.User = Depends(parse_user_from_auth),
    session: AsyncSession = Depends(db_session_dependency),
) -> StandardResponse[schemas.User]:
    """Promote the authenticated user to ROOT if no ROOT user exists yet.

    :raises BizError: ``ErrorCode.Error`` if a ROOT user already exists;
        ``UserNotFoundError`` if the authenticated user is not in the database.
    :returns: the updated user.
    """
    root_users = await models.User.all(role=DefaultRole.ROOT)
    # Use sequence truthiness: comparing against a list literal only works
    # when ``all`` returns exactly a ``list``.
    if root_users:
        raise BizError(ErrorCode.Error)
    # Reload from the database: the dependency builds the user from JWT claims.
    current_user = await models.User.one_or_none(id=user.id)
    if current_user is None:
        raise BizError(ErrorCode.UserNotFoundError)
    current_user.role = DefaultRole.ROOT
    session.sync_session.add(current_user)
    await session.commit()
    await session.refresh(current_user)
    return StandardResponse(current_user)
示例#5
0
 def ensure_policy(self, permission: Literal["read",
                                             "all"]) -> models.Policy:
     """Fetch, or create if missing, the LakeFS policy for this repository.

     The policy id is ``"<repo_name>-<permission>"``. If fetching it raises
     ``LakeFSApiException``, an allow-policy scoped to
     ``repository/<repo_name>/*`` is created and returned instead.

     :param permission: ``"read"`` grants ``fs:List*``/``fs:Read*``;
         ``"all"`` grants ``fs:*``.
     :raises BizError: ``InternalServerError`` for any other permission.
     """
     if permission not in ("read", "all"):
         raise BizError(ErrorCode.InternalServerError,
                        f"permission not defined: {permission}")
     actions_for = {
         "read": ["fs:List*", "fs:Read*"],
         "all": ["fs:*"],
     }
     client = get_lakefs_client()
     policy_id = f"{self.repo_name}-{permission}"
     try:
         return client.auth.get_policy(policy_id=policy_id)
     except LakeFSApiException:
         pass

     statement = models.Statement(
         effect="allow",
         resource=f"arn:lakefs:fs:::repository/{self.repo_name}/*",
         action=actions_for[permission],
     )
     created = client.auth.create_policy(
         policy=models.Policy(id=policy_id, statement=[statement]))
     logger.info(f"LakeFS create policy: {policy_id}")
     return created
    async def get_lakefs_access_key(cls, user: "******") -> "UserAccessKey":
        """Return the user's stored LakeFS access key, provisioning if needed.

        Ensures the LakeFS user exists and asks ``ensure_credentials`` for
        credentials, passing the currently stored key id, if any. When it
        returns no credentials, the stored key is returned unchanged.
        Otherwise the returned credentials are saved, either as a new row or
        over the existing one.

        :raises BizError: ``ErrorCode.Error`` if there is neither a stored key
            nor newly issued credentials.
        """
        stored = await cls.one_or_none(service="lakefs", user_id=user.id)
        stored_key_id = None if stored is None else stored.access_key_id

        # LakeFS client calls are synchronous; keep them off the event loop.
        def provision() -> Optional[CredentialsWithSecret]:
            ensure_user(user.id)
            return ensure_credentials(user.id, stored_key_id)

        issued = await run_in_threadpool(provision)

        if issued is None:
            if stored is None:
                raise BizError(ErrorCode.Error)
            return stored

        if stored is not None:
            stored.access_key_id = issued.access_key_id
            stored.secret_access_key = issued.secret_access_key
        else:
            stored = cls(
                service="lakefs",
                access_key_id=issued.access_key_id,
                secret_access_key=issued.secret_access_key,
                user_id=user.id,
            )

        await stored.save_model()
        return stored
async def general_exception_handler(
    request: Request, exc: Exception
) -> JSONResponse:  # pragma: no cover
    """Turn an otherwise unhandled exception into an error response.

    Logs the exception with its traceback, then responds as for a
    ``BizError`` with ``InternalServerError`` and ``str(exc)`` as the
    message.

    NOTE(review): ``str(exc)`` is sent to the client and may expose
    internal details — confirm this is intended.
    """
    exc_name = type(exc).__name__
    logger.exception(f"Unexpected Error: {exc_name}")
    error = BizError(ErrorCode.InternalServerError, str(exc))
    return business_exception_response(error)
示例#8
0
async def update_record_state_by_judge(
    record: models.Record = Depends(parse_record_judger),
    user: models.User = Depends(parse_user_from_auth),
) -> StandardResponse[schemas.Record]:
    """Let the judger assigned to ``record`` update it while in progress.

    Rejects callers other than ``record.judger_id``, and records whose state
    is not one of fetched/compiling/running/judging.

    NOTE(review): the state is always written back as ``fetched`` and no
    target state is accepted from the caller; this looks unfinished —
    confirm the intended transition.

    :raises BizError: ``ErrorCode.Error`` on either rejection.
    """
    in_progress_states = (
        schemas.RecordState.fetched,
        schemas.RecordState.compiling,
        schemas.RecordState.running,
        schemas.RecordState.judging,
    )
    if record.judger_id != user.id or record.state not in in_progress_states:
        raise BizError(ErrorCode.Error)
    record.state = schemas.RecordState.fetched
    await record.save_model()
    return StandardResponse(record)
示例#9
0
async def parse_problem_group(problem_group: str = Path(
    ...)) -> models.ProblemGroup:
    """Resolve the ``problem_group`` path parameter (an id) to its model.

    :raises BizError: ``ProblemGroupNotFoundError`` if no group matches.
    """
    found = await models.ProblemGroup.one_or_none(id=problem_group)
    if not found:
        raise BizError(ErrorCode.ProblemGroupNotFoundError)
    return found
示例#10
0
async def get_domain(
    domain: str = Path(..., description="url or id of the domain"),
) -> Domain:
    """Look up a domain by its url or id.

    :raises BizError: ``DomainNotFoundError`` if nothing matches.
    """
    found = await Domain.find_by_url_or_id(domain)
    if found is not None:
        return found
    raise BizError(ErrorCode.DomainNotFoundError)
示例#11
0
    async def create(
        cls,
        user_create: "UserCreate",
        jwt_access_token: Optional["JWTAccessToken"],
        register_ip: str,
    ) -> "User":
        """Register and persist a new user.

        When ``user_create.oauth_name`` is set, ``jwt_access_token`` must be an
        "oauth" token for the same provider and account id. The user is then
        built from that OAuth account, and the account is linked to the user.
        Otherwise a plain password user is built.

        :raises BizError: ``UserRegisterError`` if the OAuth token does not
            match the registration request, or as raised by the builders.
        :returns: the committed and refreshed user.
        """
        oauth_account: Optional[UserOAuthAccount] = None
        if not user_create.oauth_name:
            user = cls._create_user(user_create, register_ip)
        else:
            token_matches = (
                jwt_access_token is not None
                and jwt_access_token.category == "oauth"
                and jwt_access_token.oauth_name == user_create.oauth_name
                and jwt_access_token.id == user_create.oauth_account_id)
            if not token_matches:
                raise BizError(ErrorCode.UserRegisterError,
                               "oauth account not matched")
            user, oauth_account = await cls._create_user_by_oauth(
                user_create, jwt_access_token, register_ip)

        async with db_session() as session:
            session.sync_session.add(user)
            if oauth_account:  # pragma: no cover
                # link the OAuth account to the newly created user
                oauth_account.user_id = user.id
                session.sync_session.add(oauth_account)
            await session.commit()
            await session.refresh(user)
            return user
示例#12
0
 async def reset_password(self, current_password: str,
                          new_password: str) -> None:
     """Set a new password, verifying the current one if a hash exists.

     Users whose stored hash is empty (OAuth registration may store ``""``)
     can set a password without ``current_password`` being checked.

     :raises BizError: ``UsernamePasswordError`` if verification fails.
     """
     has_password = bool(self.hashed_password)
     if has_password and not self.verify_password(current_password):
         raise BizError(ErrorCode.UsernamePasswordError,
                        "incorrect password")
     new_hash = self._generate_password_hash(new_password)
     self.hashed_password = new_hash
     await self.save_model()
示例#13
0
 def delete_directory(self, file_path: Path, recursive: bool) -> FileInfo:
     """Delete the directory at ``file_path`` from ``self.storage``.

     :param file_path: directory to delete.
     :param recursive: use ``delete_tree`` when true, else ``delete_dir``
         (presumably only removes an empty directory — TODO confirm against
         the storage backend).
     :returns: whatever the storage delete call returns.
     :raises BizError: ``ProblemConfigUpdateError`` carrying the storage
         error message; the original error is chained as ``__cause__``.
     """
     try:
         if recursive:
             return self.storage.delete_tree(file_path)
         return self.storage.delete_dir(file_path)
     except ElephantError as e:
         # chain so the storage traceback is preserved in logs
         raise BizError(ErrorCode.ProblemConfigUpdateError, str(e)) from e
示例#14
0
def parse_problem(
        problem: models.Problem = Depends(parse_problem_without_validation),
        auth: Authentication = Depends(),
) -> models.Problem:
    """Return ``problem`` unless it is hidden from the caller.

    A hidden problem is returned only if ``auth`` has ``view_hidden``
    permission in the ``DOMAIN_PROBLEM`` scope.

    :raises BizError: ``ProblemNotFoundError`` for a hidden problem the
        caller may not view.
    """
    if problem.hidden and not auth.check(ScopeType.DOMAIN_PROBLEM,
                                         PermissionType.view_hidden):
        raise BizError(ErrorCode.ProblemNotFoundError)
    return problem
示例#15
0
async def parse_problem_without_validation(
    problem: str = Path(..., description="url or id of the problem"),
    domain: models.Domain = Depends(parse_domain_from_auth),
) -> models.Problem:
    """Find a problem in ``domain`` by url or id, ignoring visibility.

    :raises BizError: ``ProblemNotFoundError`` if nothing matches.
    """
    found = await models.Problem.find_by_domain_url_or_id(domain, problem)
    if not found:
        raise BizError(ErrorCode.ProblemNotFoundError)
    return found
示例#16
0
async def parse_domain_role(
    role: NoneEmptyLongStr = Path(..., description="name of the domain role"),
    domain: models.Domain = Depends(parse_domain_from_auth),
) -> models.DomainRole:
    """Fetch the role named ``role`` within ``domain``.

    :raises BizError: ``DomainRoleNotFoundError`` if the role does not exist.
    """
    found = await models.DomainRole.one_or_none(domain_id=domain.id,
                                                role=role)
    if found is not None:
        return found
    raise BizError(ErrorCode.DomainRoleNotFoundError)
示例#17
0
async def parse_domain_from_auth(
        domain_auth: DomainAuthentication = Depends(), ) -> models.Domain:
    """Return the domain from the authentication context if it is visible.

    A hidden domain is visible only to its members (``domain_user`` set) or
    to callers with ``view_hidden`` in the ``SITE_DOMAIN`` scope.

    :raises BizError: ``DomainNotFoundError`` if there is no domain or it
        is hidden from the caller.
    """
    auth = domain_auth.auth
    domain = auth.domain
    if domain is None:
        raise BizError(ErrorCode.DomainNotFoundError)
    hidden_from_caller = (
        domain.hidden and auth.domain_user is None
        and not auth.check(ScopeType.SITE_DOMAIN, PermissionType.view_hidden))
    if hidden_from_caller:
        raise BizError(ErrorCode.DomainNotFoundError)
    return domain
示例#18
0
async def submit_record_by_judge(
    record_result: schemas.RecordResult,
    record: models.Record = Depends(parse_record_judger),
    user: models.User = Depends(parse_user_from_auth),
) -> StandardResponse[Empty]:
    """Accept a judging result from the judger that claimed ``record``.

    Only the judger stored in ``record.judger_id`` may submit, and only while
    the record is in the ``fetched`` state.

    NOTE(review): ``record_result`` is not persisted yet; the endpoint
    currently only validates the request.

    :raises BizError: ``ErrorCode.Error`` if the caller is not the assigned
        judger or the record is not in the ``fetched`` state.
    """
    # Same ownership rule as update_record_state_by_judge: without it any
    # authenticated user could submit a result for a record they never claimed.
    if record.judger_id != user.id:
        raise BizError(ErrorCode.Error)
    if record.state != schemas.RecordState.fetched:
        raise BizError(ErrorCode.Error)

    return StandardResponse()
示例#19
0
 def get_config(self, ref: str) -> Dict[str, Any]:
     """Download ``config.json`` at ``ref`` and parse it as JSON.

     :param ref: ref of the problem config to read from.
     :returns: the parsed JSON content.
     :raises BizError: ``ProblemConfigValidationError`` if the file cannot be
         downloaded/read, or if it is not valid JSON. The underlying error is
         chained as ``__cause__``.
     """
     try:
         result = self.download_file(Path("config.json"), ref)
         content = result.read()
     except ElephantError as e:
         raise BizError(
             ErrorCode.ProblemConfigValidationError,
             "config.json not found in problem config.",
         ) from e
     # Malformed JSON is a user-side config problem, not a server error;
     # previously orjson.JSONDecodeError escaped uncaught.
     try:
         return orjson.loads(content)
     except orjson.JSONDecodeError as e:
         raise BizError(
             ErrorCode.ProblemConfigValidationError,
             "config.json in problem config is not valid JSON.",
         ) from e
示例#20
0
async def parse_uid(
        uid: str = Query("me", description="'me' or id of the user"),
        auth: Authentication = Depends(),
) -> models.User:
    """Resolve ``uid`` to a user.

    ``"me"`` yields the user built from the caller's JWT. Any other value is
    looked up by id.

    :raises BizError: ``UserNotFoundError`` if no user has that id.
    """
    if uid != "me":
        found = await models.User.one_or_none(id=uid)
        if not found:
            raise BizError(ErrorCode.UserNotFoundError)
        return found
    return parse_user_from_auth(auth)
示例#21
0
async def parse_domain_invitation(
    invitation: str = Path(...,
                           description="url or id of the domain invitation"),
    domain: models.Domain = Depends(parse_domain_from_auth),
) -> models.DomainInvitation:
    """Find an invitation of ``domain`` by url or id.

    :raises BizError: ``DomainInvitationBadRequestError`` if nothing matches.
    """
    found = await models.DomainInvitation.find_by_domain_url_or_id(
        domain, invitation)
    if not found:
        raise BizError(ErrorCode.DomainInvitationBadRequestError)
    return found
示例#22
0
 async def _create_user_by_oauth(
     cls,
     user_create: "UserCreate",
     jwt_access_token: "JWTAccessToken",
     register_ip: str,
 ) -> Tuple["User", "UserOAuthAccount"]:
     """Build an active (unsaved) user from a stored OAuth account.

     The OAuth account is located via the token's provider and account id.
     Username falls back to the account's name. Email always comes from the
     account; a differing ``user_create.email`` is rejected. The password is
     optional (an empty hash is stored).

     :raises BizError: ``UserRegisterError`` if the account is missing, no
         username is available, or the email conflicts.
     :returns: ``(user, oauth_account)``.
     """
     account = await UserOAuthAccount.one_or_none(
         oauth_name=jwt_access_token.oauth_name,
         account_id=jwt_access_token.id,
     )
     if account is None:
         raise BizError(ErrorCode.UserRegisterError,
                        "oauth account not matched")

     username = user_create.username
     if not username:
         # fall back to the name stored on the OAuth account
         username = account.account_name
         if not username:
             raise BizError(ErrorCode.UserRegisterError,
                            "username not provided")

     requested_email = user_create.email
     if requested_email and requested_email != account.account_email:
         raise BizError(
             ErrorCode.UserRegisterError,
             "email must be same as the primary email of oauth account",
         )

     if not user_create.password:
         # register with oauth can omit password
         password_hash = ""  # pragma: no cover
     else:
         password_hash = cls._generate_password_hash(user_create.password)

     new_user = User(
         username=username,
         email=account.account_email,
         student_id=jwt_access_token.student_id,
         real_name=jwt_access_token.real_name,
         is_active=True,
         hashed_password=password_hash,
         register_ip=register_ip,
         login_ip=register_ip,
     )
     return new_user, account
示例#23
0
async def login(
    request: Request,
    response: Response,
    auth_parameters: AuthParams = Depends(auth_parameters_dependency),
    auth_jwt: AuthJWT = Depends(AuthJWT),
    credentials: OAuth2PasswordRequestForm = Depends(),
) -> schemas.StandardResponse[schemas.AuthTokens]:
    """Authenticate with username/password and return JWT tokens.

    Before building the response, stores the login time (UTC) on the user,
    plus the client IP when the server provides one.

    NOTE(review): the two failure messages differ ("user not found" vs
    "incorrect password"), which lets callers probe for existing usernames —
    confirm this is acceptable.

    :raises BizError: ``UsernamePasswordError`` for an unknown user or a
        wrong password.
    """
    user = await models.User.one_or_none(username=credentials.username)
    if not user:
        raise BizError(ErrorCode.UsernamePasswordError, "user not found")
    if not user.verify_password(credentials.password):
        raise BizError(ErrorCode.UsernamePasswordError, "incorrect password")
    user.login_at = datetime.now(tz=timezone.utc)
    # Starlette's Request.client is Optional (None when the ASGI server does
    # not report a peer); keep the previous login_ip instead of raising
    # AttributeError and failing an otherwise valid login.
    if request.client is not None:
        user.login_ip = request.client.host
    await user.save_model()
    logger.info(f"user login: {user}")
    access_token, refresh_token = auth_jwt_encode_user(auth_jwt, user=user)
    return await get_login_response(request, response, auth_jwt,
                                    auth_parameters, access_token,
                                    refresh_token)
示例#24
0
async def parse_problem_problem_set_link(
    problem_set: models.ProblemSet = Depends(parse_problem_set),
    problem: models.Problem = Depends(parse_problem_without_validation),
) -> models.ProblemProblemSetLink:
    """Fetch the link between ``problem_set`` and ``problem``.

    The already-resolved set and problem are attached to the returned link
    (presumably to avoid reloading them — confirm).

    :raises BizError: ``ProblemNotFoundError`` if the problem is not in the
        set.
    """
    link = await models.ProblemProblemSetLink.one_or_none(
        problem_set_id=problem_set.id, problem_id=problem.id)
    if link is None:
        raise BizError(ErrorCode.ProblemNotFoundError)
    link.problem_set = problem_set
    link.problem = problem
    return link
示例#25
0
async def parse_record_judger(record: str = Path(...)) -> models.Record:
    """Load a record by id with ``problem`` and ``problem_config`` joined in.

    Used by the judge endpoints, which read both relations directly.

    :raises BizError: ``RecordNotFoundError`` if no record has that id.
    """
    eager_relations = (
        joinedload(models.Record.problem),
        joinedload(models.Record.problem_config),
    )
    query = models.Record.sql_select().where(models.Record.id == record)
    query = query.options(*eager_relations)
    rows = await models.Record.session_exec(query)
    found = rows.one_or_none()
    if not found:
        raise BizError(ErrorCode.RecordNotFoundError)
    return found
示例#26
0
 def get_file_info(self,
                   file_path: Path,
                   ref: Optional[str] = None) -> FileInfo:
     """Return storage metadata for ``file_path``.

     :param file_path: path inside the storage.
     :param ref: ref to inspect; ``None`` uses ``self.storage`` directly,
         otherwise the storage for ``ref`` is obtained via ``_get_storage``.
     :raises BizError: ``ProblemConfigDownloadError`` carrying the storage
         error message; the original error is chained as ``__cause__``.
     """
     try:
         storage = self.storage if ref is None else self._get_storage(ref)
         return storage.getinfo(file_path)
     except ElephantError as e:
         # chain so the storage traceback is preserved in logs
         raise BizError(ErrorCode.ProblemConfigDownloadError, str(e)) from e
示例#27
0
def parse_user_from_auth(auth: Authentication = Depends()) -> models.User:
    """Build an in-memory ``User`` from the claims of a "user" JWT.

    No database lookup happens here; callers needing persisted state must
    reload the user.

    :raises BizError: ``UserNotFoundError`` if the token category is not
        ``"user"``.
    """
    claims = auth.jwt
    if claims.category != "user":
        raise BizError(ErrorCode.UserNotFoundError)
    return models.User(
        id=claims.id,
        username=claims.username,
        email=claims.email,
        student_id=claims.student_id,
        real_name=claims.real_name,
        role=claims.role,
        is_active=claims.is_active,
    )
示例#28
0
async def parse_problem_config(
    config: str = Path(..., description="'latest' or id of the config"),
    problem: models.Problem = Depends(parse_problem),
) -> models.ProblemConfig:
    """Resolve ``config`` ("latest" or an id) to a config of ``problem``.

    :raises BizError: ``ProblemConfigNotFoundError`` if none is found.
    """
    if config != "latest":
        found = await models.ProblemConfig.one_or_none(
            problem_id=problem.id, id=config)
    else:
        found = await problem.get_latest_problem_config()
    if not found:
        raise BizError(ErrorCode.ProblemConfigNotFoundError)
    return found
示例#29
0
async def parse_record(
        record: UUID = Path(...),
        domain_auth: DomainAuthentication = Depends(),
        user: models.User = Depends(parse_user_from_auth),
) -> models.Record:
    """Fetch a record visible to the caller.

    The caller sees the record if they committed it, or if they hold
    ``view`` permission in the ``DOMAIN_RECORD`` scope.

    :raises BizError: ``RecordNotFoundError`` if the record is missing or
        not visible.
    """
    found = await models.Record.one_or_none(id=record)
    if not found:
        raise BizError(ErrorCode.RecordNotFoundError)

    is_owner = found.committer_id == user.id
    if is_owner or domain_auth.auth.check(ScopeType.DOMAIN_RECORD,
                                          PermissionType.view):
        return found
    raise BizError(ErrorCode.RecordNotFoundError)
示例#30
0
 def _create_user(cls, user_create: "UserCreate",
                  register_ip: str) -> "User":
     """Build an inactive (unsaved) password-based user.

     Password, username and email are required, and are checked in that
     order. ``student_id`` and ``real_name`` start empty; both IPs are set to
     ``register_ip``.

     :raises BizError: ``UserRegisterError`` naming the first missing field.
     """
     required_fields = (
         (user_create.password, "password not provided"),
         (user_create.username, "username not provided"),
         (user_create.email, "email not provided"),
     )
     for value, message in required_fields:
         if not value:
             raise BizError(ErrorCode.UserRegisterError, message)

     password_hash = cls._generate_password_hash(user_create.password)
     return User(
         username=user_create.username,
         email=user_create.email,
         student_id="",
         real_name="",
         is_active=False,
         hashed_password=password_hash,
         register_ip=register_ip,
         login_ip=register_ip,
     )