def assert_json_record_result(response, json, record):
    """Check that a single-record API response reflects ``record``.

    Asserts a 200 status, equality of each scalar record field, a fresh
    record timestamp, and that ``json['tags']`` matches the union of the
    record's own tags and its collection's tags (paired by ``tag_id``).
    Note that ``json['tags']`` is sorted in place.
    """
    assert response.status_code == 200

    field_pairs = (
        ('id', record.id),
        ('doi', record.doi),
        ('sid', record.sid),
        ('collection_id', record.collection_id),
        ('schema_id', record.schema_id),
        ('metadata', record.metadata_),
    )
    for key, expected in field_pairs:
        assert json[key] == expected
    assert_new_timestamp(datetime.fromisoformat(json['timestamp']))

    record_tags = Session.execute(
        select(RecordTag).where(RecordTag.record_id == record.id)
    ).scalars().all()
    collection_tags = Session.execute(
        select(CollectionTag).where(CollectionTag.collection_id == record.collection_id)
    ).scalars().all()
    expected_tags = sorted(record_tags + collection_tags, key=lambda tag: tag.tag_id)

    returned_tags = json['tags']
    assert len(returned_tags) == len(expected_tags)
    returned_tags.sort(key=lambda tag: tag['tag_id'])

    for got, want in zip(returned_tags, expected_tags):
        assert got['tag_id'] == want.tag_id
        assert got['user_id'] == want.user_id
        assert got['user_name'] == want.user.name
        assert got['data'] == want.data
        assert_new_timestamp(datetime.fromisoformat(got['timestamp']))
async def get_catalog_record(
        record_id: str,
        catalog_id: str,
        auth: Authorized = Depends(Authorize(ODPScope.RECORD_READ)),
):
    """Look up the catalog_record entry for a (catalog, record) pair.

    Access is gated by ``Authorize(ODPScope.RECORD_READ)``.

    :param record_id: the record id
    :param catalog_id: the catalog id
    :raises HTTPException: 404 if no catalog_record exists for the pair
    """
    # CatalogRecord is fetched by a composite identity tuple, given here in
    # (catalog_id, record_id) order.
    if not (catalog_record := Session.get(CatalogRecord, (catalog_id, record_id))):
        raise HTTPException(HTTP_404_NOT_FOUND)
def _untag_record(
        record_id: str,
        tag_instance_id: str,
        auth: Authorized,
        ignore_user_id: bool = False,
) -> None:
    """Shared helper for removing a tag instance from a record.

    :param record_id: id of the record whose tag instance is targeted
    :param tag_instance_id: id of the tag instance
    :param auth: the caller's authorization result
    :param ignore_user_id: presumably permits acting on tag instances
        owned by other users — TODO confirm against the full body
    :raises HTTPException: 404 if the record does not exist
    """
    if not (record := Session.get(Record, record_id)):
        raise HTTPException(HTTP_404_NOT_FOUND)
async def tag_record(
        record_id: str,
        tag_instance_in: TagInstanceModelIn,
        tag_schema: JSONSchema = Depends(get_tag_schema),
        auth: Authorized = Depends(TagAuthorize()),
):
    """Apply a tag instance to a record.

    :param record_id: id of the record to tag
    :param tag_instance_in: the incoming tag instance
    :param tag_schema: JSON schema resolved by ``get_tag_schema``
    :param auth: authorization resolved by ``TagAuthorize()``
    :raises HTTPException: 404 if the record does not exist
    """
    # NOTE(review): unlike tag_collection, no auth.collection_ids check is
    # visible before this lookup; presumably it is applied against
    # record.collection_id further down — confirm.
    if not (record := Session.get(Record, record_id)):
        raise HTTPException(HTTP_404_NOT_FOUND)
def commit_transaction(response):
    """End the DB transaction according to the response status.

    Status codes in [200, 400) commit the session; all others roll it
    back. The response object is returned unchanged.
    """
    status = response.status_code
    if status < 200 or status >= 400:
        Session.rollback()
    else:
        Session.commit()
    return response
def login_callback(self):
    """Save the token and log the user in."""
    # Obtain the access token for this callback from the Hydra OAuth client
    # (Authlib-style authorize_access_token; presumably reads the callback
    # request's authorization code — confirm).
    token = self.oauth.hydra.authorize_access_token()
    userinfo = self.oauth.hydra.userinfo()
    # The OIDC 'sub' claim is used as the ODP user id.
    user_id = userinfo['sub']

    # Reuse the stored token row for this (client_id, user_id) pair, or
    # create a new one if none exists yet.
    if not (token_model := Session.get(OAuth2Token, (self.client_id, user_id))):
        token_model = OAuth2Token(client_id=self.client_id, user_id=user_id)
def _select_records(self) -> list[tuple[str, datetime]]:
    """Select records to be evaluated for publication to, or retraction
    from, a catalog.

    A record is selected if:

    * there is no corresponding catalog_record entry; or
    * the record has any embargo tags; or
    * catalog_record.timestamp is less than any of the following:

      * catalog.schema.timestamp
      * collection.timestamp
      * record.timestamp

    :return: a list of (record_id, timestamp) tuples, where timestamp
        is that of the latest contributing change
    """
    # NOTE(review): Session.get returns None for an unknown catalog id, in
    # which case catalog.schema below raises AttributeError.
    catalog = Session.get(Catalog, self.catalog_id)

    # One row per record: its id and the latest of the catalog schema's,
    # its collection's and its own timestamps. catalog.schema.timestamp is
    # an already-loaded Python value, bound into GREATEST() as a parameter.
    records_subq = (
        select(
            Record.id.label('record_id'),
            func.greatest(
                catalog.schema.timestamp,
                Collection.timestamp,
                Record.timestamp,
            ).label('max_timestamp')
        ).
        join(Collection).
        subquery()
    )

    # Existing catalog_record entries for this catalog only.
    catalog_records_subq = (
        select(
            CatalogRecord.record_id,
            CatalogRecord.timestamp
        ).
        where(CatalogRecord.catalog_id == self.catalog_id).
        subquery()
    )

    # Left outer join so records with no catalog_record entry are retained
    # (their catalog_records_subq columns are NULL). The ON clause is
    # inferred by SQLAlchemy — presumably from the catalog_record.record_id
    # foreign key; confirm if the schema changes.
    stmt = (
        select(
            records_subq.c.record_id,
            records_subq.c.max_timestamp
        ).
        outerjoin_from(records_subq, catalog_records_subq).
        where(or_(
            # `== None` (not `is None`) is required: SQLAlchemy renders it
            # as IS NULL.
            catalog_records_subq.c.record_id == None,
            catalog_records_subq.c.timestamp < records_subq.c.max_timestamp,
            # Embargoed records. Records lacking a catalog_record entry have
            # a NULL record_id here, but are already selected above.
            catalog_records_subq.c.record_id.in_(
                select(RecordTag.record_id).
                where(RecordTag.tag_id == ODPRecordTag.EMBARGO)
            )
        ))
    )
    return Session.execute(stmt).all()
async def update_collection(
        collection_in: CollectionModelIn,
        auth: Authorized = Depends(Authorize(ODPScope.COLLECTION_ADMIN)),
):
    """Handle an update to an existing collection.

    Access is gated by ``Authorize(ODPScope.COLLECTION_ADMIN)``.

    :param collection_in: the incoming collection; ``collection_in.id``
        identifies the collection to update
    :raises HTTPException: 403 if the caller is restricted to specific
        collections and ``collection_in.id`` is not one of them;
        404 if no collection with that id exists
    """
    # auth.collection_ids == '*' skips the per-collection check; otherwise
    # it is the set of collection ids the caller may act on.
    if auth.collection_ids != '*' and collection_in.id not in auth.collection_ids:
        raise HTTPException(HTTP_403_FORBIDDEN)

    if not (collection := Session.get(Collection, collection_in.id)):
        raise HTTPException(HTTP_404_NOT_FOUND)
async def get_new_doi(
        collection_id: str,
        auth: Authorized = Depends(Authorize(ODPScope.COLLECTION_READ)),
):
    """Handle a request for a new DOI within a collection.

    Access is gated by ``Authorize(ODPScope.COLLECTION_READ)``.

    :param collection_id: the collection id
    :raises HTTPException: 403 if the caller is restricted to specific
        collections and ``collection_id`` is not one of them;
        404 if the collection does not exist
    """
    # auth.collection_ids == '*' skips the per-collection check.
    if auth.collection_ids != '*' and collection_id not in auth.collection_ids:
        raise HTTPException(HTTP_403_FORBIDDEN)

    if not (collection := Session.get(Collection, collection_id)):
        raise HTTPException(HTTP_404_NOT_FOUND)
def init_roles():
    """Create or update roles from ``roles.yml`` in the data directory.

    Each top-level key is a role id; its ``scopes`` list replaces the
    role's scopes with the ODP-type scopes of those ids.
    """
    roles_file = datadir / 'roles.yml'
    with open(roles_file) as stream:
        spec_by_role = yaml.safe_load(stream)

    for role_id, spec in spec_by_role.items():
        role = Session.get(Role, role_id)
        if not role:
            role = Role(id=role_id)

        # NOTE(review): an unknown scope id yields None in this list.
        role.scopes = [
            Session.get(Scope, (scope_id, ScopeType.odp))
            for scope_id in spec['scopes']
        ]
        role.save()
async def update_role(
        role_in: RoleModelIn,
        auth: Authorized = Depends(Authorize(ODPScope.ROLE_ADMIN)),
):
    """Handle an update to an existing role.

    Access is gated by ``Authorize(ODPScope.ROLE_ADMIN)``.

    :param role_in: the incoming role; ``role_in.id`` identifies the role
    :raises HTTPException: 403 if the caller is restricted to specific
        collections and ``role_in.collection_id`` is not one of them;
        404 if no role with that id exists
    """
    # NOTE(review): this checks the *incoming* collection_id only; whether
    # the existing role's collection is also checked is not visible here.
    if auth.collection_ids != '*' and role_in.collection_id not in auth.collection_ids:
        raise HTTPException(HTTP_403_FORBIDDEN)

    if not (role := Session.get(Role, role_in.id)):
        raise HTTPException(HTTP_404_NOT_FOUND)
async def update_client(
        client_in: ClientModelIn,
        auth: Authorized = Depends(Authorize(ODPScope.CLIENT_ADMIN)),
):
    """Handle an update to an existing client.

    Access is gated by ``Authorize(ODPScope.CLIENT_ADMIN)``.

    :param client_in: the incoming client; ``client_in.id`` identifies it
    :raises HTTPException: 403 if the caller is restricted to specific
        collections and ``client_in.collection_id`` is not one of them;
        404 if no client with that id exists
    """
    # NOTE(review): this checks the *incoming* collection_id only; whether
    # the existing client's collection is also checked is not visible here.
    if auth.collection_ids != '*' and client_in.collection_id not in auth.collection_ids:
        raise HTTPException(HTTP_403_FORBIDDEN)

    if not (client := Session.get(Client, client_in.id)):
        raise HTTPException(HTTP_404_NOT_FOUND)
async def tag_collection(
        collection_id: str,
        tag_instance_in: TagInstanceModelIn,
        tag_schema: JSONSchema = Depends(get_tag_schema),
        auth: Authorized = Depends(TagAuthorize()),
):
    """Apply a tag instance to a collection.

    :param collection_id: id of the collection to tag
    :param tag_instance_in: the incoming tag instance
    :param tag_schema: JSON schema resolved by ``get_tag_schema``
    :param auth: authorization resolved by ``TagAuthorize()``
    :raises HTTPException: 403 if the caller is restricted to specific
        collections and ``collection_id`` is not one of them;
        404 if the collection does not exist
    """
    # auth.collection_ids == '*' skips the per-collection check.
    if auth.collection_ids != '*' and collection_id not in auth.collection_ids:
        raise HTTPException(HTTP_403_FORBIDDEN)

    if not (collection := Session.get(Collection, collection_id)):
        raise HTTPException(HTTP_404_NOT_FOUND)
async def update_record(
        record_id: str,
        record_in: RecordModelIn,
        metadata_schema: JSONSchema = Depends(get_metadata_schema),
        auth: Authorized = Depends(Authorize(ODPScope.RECORD_WRITE)),
):
    """Handle an update to an existing record.

    Access is gated by ``Authorize(ODPScope.RECORD_WRITE)``.

    :param record_id: id of the record to update
    :param record_in: the incoming record data
    :param metadata_schema: JSON schema resolved by ``get_metadata_schema``
    :raises HTTPException: 403 if the caller is restricted to specific
        collections and ``record_in.collection_id`` is not one of them;
        404 if the record does not exist
    """
    # NOTE(review): only the *target* collection (record_in.collection_id)
    # is checked here. Unless record.collection_id is checked further down,
    # a caller could move a record out of a collection they cannot access.
    if auth.collection_ids != '*' and record_in.collection_id not in auth.collection_ids:
        raise HTTPException(HTTP_403_FORBIDDEN)

    if not (record := Session.get(Record, record_id)):
        raise HTTPException(HTTP_404_NOT_FOUND)
def _untag_collection(
        collection_id: str,
        tag_instance_id: str,
        auth: Authorized,
        ignore_user_id: bool = False,
) -> None:
    """Shared helper for removing a tag instance from a collection.

    :param collection_id: id of the collection whose tag instance is targeted
    :param tag_instance_id: id of the tag instance
    :param auth: the caller's authorization result
    :param ignore_user_id: presumably permits acting on tag instances
        owned by other users — TODO confirm against the full body
    :raises HTTPException: 403 if the caller is restricted to specific
        collections and ``collection_id`` is not one of them;
        404 if the collection does not exist
    """
    # auth.collection_ids == '*' skips the per-collection check.
    if auth.collection_ids != '*' and collection_id not in auth.collection_ids:
        raise HTTPException(HTTP_403_FORBIDDEN)

    if not (collection := Session.get(Collection, collection_id)):
        raise HTTPException(HTTP_404_NOT_FOUND)
async def get_collection(
        collection_id: str,
        auth: Authorized = Depends(Authorize(ODPScope.COLLECTION_READ)),
):
    """Fetch a collection together with its record count.

    Access is gated by ``Authorize(ODPScope.COLLECTION_READ)``.

    :param collection_id: the collection id
    :raises HTTPException: 403 if the caller is restricted to specific
        collections and ``collection_id`` is not one of them;
        404 if the collection does not exist
    """
    if auth.collection_ids != '*' and collection_id not in auth.collection_ids:
        raise HTTPException(HTTP_403_FORBIDDEN)

    # Outer join so a collection with no records still yields a row;
    # COUNT(record.id) ignores the NULLs and gives 0 in that case.
    stmt = (
        select(Collection, func.count(Record.id)).
        outerjoin(Record).
        where(Collection.id == collection_id).
        group_by(Collection)
    )

    # result is a (Collection, count) row, or None if no such collection.
    if not (result := Session.execute(stmt).one_or_none()):
        raise HTTPException(HTTP_404_NOT_FOUND)
def init_cli_client():
    """Create or update the ODP CLI client (``ODP_CLI_CLIENT_ID``).

    The local client row is granted every ODP scope, and the matching
    client is registered with Hydra for the client-credentials grant.
    """
    all_scope_ids = [odp_scope.value for odp_scope in ODPScope]

    client = Session.get(Client, ODP_CLI_CLIENT_ID)
    if not client:
        client = Client(id=ODP_CLI_CLIENT_ID)
    client.scopes = [
        Session.get(Scope, (scope_id, ScopeType.odp))
        for scope_id in all_scope_ids
    ]
    client.save()

    hydra_admin_api.create_or_update_client(
        id=ODP_CLI_CLIENT_ID,
        name=ODP_CLI_CLIENT_NAME,
        secret=ODP_CLI_CLIENT_SECRET,
        scope_ids=all_scope_ids,
        grant_types=[GrantType.CLIENT_CREDENTIALS],
    )
async def create_project(
        project_in: ProjectModelIn,
):
    """Create a project from ``project_in``, linking the listed collections.

    :raises HTTPException: 409 if a project with the same id already exists
    """
    if Session.get(Project, project_in.id):
        raise HTTPException(HTTP_409_CONFLICT, 'Project id is already in use')

    # NOTE(review): an unknown collection id yields None in this list.
    linked_collections = [
        Session.get(Collection, cid)
        for cid in project_in.collection_ids
    ]
    project = Project(
        id=project_in.id,
        name=project_in.name,
        collections=linked_collections,
    )
    project.save()
def assert_db_state(records):
    """Verify that the DB record table contains the given record batch.

    DB rows and ``records`` are paired after ordering both by id. The
    caller's ``records`` sequence is not reordered.
    """
    # Expire cached ORM state so attribute access reloads from the DB.
    Session.expire_all()
    result = sorted(Session.execute(select(Record)).scalars().all(), key=lambda r: r.id)
    expected = sorted(records, key=lambda r: r.id)

    assert len(result) == len(expected)
    for row, record in zip(result, expected):
        assert row.id == record.id
        assert row.doi == record.doi
        assert row.sid == record.sid
        assert row.metadata_ == record.metadata_
        assert_new_timestamp(row.timestamp)
        assert row.collection_id == record.collection_id
        assert row.schema_id == record.schema_id
        assert row.schema_type == record.schema_type
def assert_db_state(collections):
    """Verify that the DB collection table contains the given collection batch.

    DB rows and ``collections`` are paired after ordering both by id. The
    caller's ``collections`` sequence is not reordered.
    """
    # Expire cached ORM state so attribute access reloads from the DB.
    Session.expire_all()
    result = sorted(Session.execute(select(Collection)).scalars().all(), key=lambda c: c.id)
    expected = sorted(collections, key=lambda c: c.id)

    assert len(result) == len(expected)
    for row, collection in zip(result, expected):
        assert row.id == collection.id
        assert row.name == collection.name
        assert row.doi_key == collection.doi_key
        assert row.provider_id == collection.provider_id
        assert_new_timestamp(row.timestamp)
        assert project_ids(row) == project_ids(collection)
        assert client_ids(row) == client_ids(collection)
        assert role_ids(row) == role_ids(collection)
def test_create_role_with_scopes():
    """A role created with ODP and client scopes gets one role_scope row per scope."""
    scopes = ScopeFactory.create_batch(5, type='odp') + ScopeFactory.create_batch(5, type='client')
    role = RoleFactory(scopes=scopes)
    result = Session.execute(select(RoleScope)).scalars()

    # The query has no ORDER BY, so row order is unspecified; compare the
    # two sides independently of order.
    assert sorted((row.role_id, row.scope_id, row.scope_type) for row in result) \
        == sorted((role.id, scope.id, scope.type) for scope in scopes)
def validate_auto_login(user_id):
    """
    Validate a login request for which Hydra has indicated that the user
    is already authenticated, returning the user object on success.

    An ``ODPIdentityError`` is raised if the login cannot be permitted
    for any reason.

    :param user_id: the user id
    :return: the ``User`` object for ``user_id``

    :raises ODPUserNotFound: if the user account associated with this id
        has been deleted
    :raises ODPAccountLocked: if the user account has been temporarily locked
    :raises ODPAccountDisabled: if the user account has been deactivated
    :raises ODPEmailNotVerified: if the user changed their email address
        since their last login, but have not yet verified it
    """
    user = Session.get(User, user_id)
    if not user:
        raise x.ODPUserNotFound

    if is_account_locked(user_id):
        raise x.ODPAccountLocked

    if not user.active:
        raise x.ODPAccountDisabled

    if not user.verified:
        raise x.ODPEmailNotVerified

    return user
def init_dap_client():
    """Create or update the Data Access Portal client.

    The local client row is granted the OpenID and offline-access OAuth
    scopes, and the matching Hydra client is registered for the
    authorization-code and refresh-token grants with the DAP login/logout
    redirect URLs.
    """
    dap_scope_ids = (HydraScope.OPENID, HydraScope.OFFLINE_ACCESS)

    client = Session.get(Client, ODP_UI_DAP_CLIENT_ID)
    if not client:
        client = Client(id=ODP_UI_DAP_CLIENT_ID)
    client.scopes = [
        Session.get(Scope, (scope_id, ScopeType.oauth))
        for scope_id in dap_scope_ids
    ]
    client.save()

    hydra_admin_api.create_or_update_client(
        id=ODP_UI_DAP_CLIENT_ID,
        name=ODP_UI_DAP_CLIENT_NAME,
        secret=ODP_UI_DAP_CLIENT_SECRET,
        scope_ids=list(dap_scope_ids),
        grant_types=[GrantType.AUTHORIZATION_CODE, GrantType.REFRESH_TOKEN],
        response_types=[ResponseType.CODE],
        redirect_uris=[ODP_UI_DAP_LOGGED_IN_URL],
        post_logout_redirect_uris=[ODP_UI_DAP_LOGGED_OUT_URL],
    )
def select_scopes(
        scope_ids: list[str],
        scope_types: list[ScopeType] = None,
) -> list[Scope]:
    """Select Scope objects given a list of ids, optionally constrained
    to the given types.

    :param scope_ids: scope ids to look up
    :param scope_types: if not None, only scopes of these types match
    """
    scopes = []
    invalid_ids = []  # ids with no matching scope
    for scope_id in scope_ids:
        # One query per requested id.
        stmt = select(Scope).where(Scope.id == scope_id)
        if scope_types is not None:
            stmt = stmt.where(Scope.type.in_(scope_types))

        # NOTE(review): Scope is looked up elsewhere by an (id, type) key, so
        # one id may exist under several types; scalar_one_or_none() then
        # raises MultipleResultsFound when scope_types is None or lists more
        # than one of those types.
        if scope := Session.execute(stmt).scalar_one_or_none():
            scopes += [scope]
        else:
            invalid_ids += [scope_id]
async def create_provider(provider_in: ProviderModelIn, ):
    """Create a provider from ``provider_in``.

    :raises HTTPException: 409 if a provider with the same id already exists
    """
    already_exists = Session.get(Provider, provider_in.id)
    if already_exists:
        raise HTTPException(HTTP_409_CONFLICT, 'Provider id is already in use')

    provider = Provider(id=provider_in.id, name=provider_in.name)
    provider.save()
def assert_db_tag_state(collection_id, *collection_tags):
    """Verify that the collection_tag table contains the given collection tags.

    DB rows, ordered by timestamp, are paired with ``collection_tags`` in
    the order given. Each expected tag is either a ``CollectionTag``
    (tag_id, user_id and data compared) or a dict with ``tag_id`` and
    ``data`` keys, for which the row's user_id must be None.
    """
    Session.expire_all()
    rows = sorted(
        Session.execute(select(CollectionTag)).scalars().all(),
        key=lambda r: r.timestamp,
    )
    assert len(rows) == len(collection_tags)

    for row, expected in zip(rows, collection_tags):
        assert row.collection_id == collection_id
        assert_new_timestamp(row.timestamp)

        if not isinstance(expected, CollectionTag):
            assert row.tag_id == expected['tag_id']
            assert row.user_id is None
            assert row.data == expected['data']
            continue

        assert row.tag_id == expected.tag_id
        assert row.user_id == expected.user_id
        assert row.data == expected.data
def get_user_profile(user_id):
    """
    Return the user's ``name`` and ``picture`` attributes as a dict.

    :param user_id: the user id
    """
    user = Session.get(User, user_id)
    return {attr: getattr(user, attr) for attr in ('name', 'picture')}
async def __call__(self, request: Request, tag_instance_id: str) -> Authorized:
    """Resolve the scope of the tag behind ``tag_instance_id``.

    The lookup uses the record_tag or collection_tag table, depending on
    ``self.tag_type``.

    :param request: the incoming request
    :param tag_instance_id: id of a record_tag / collection_tag row
    :raises HTTPException: 404 if no tag instance (or scope) is found
    """
    if self.tag_type == TagType.record:
        stmt = (
            select(Tag.scope_id).
            join(RecordTag).
            where(RecordTag.id == tag_instance_id)
        )
    elif self.tag_type == TagType.collection:
        stmt = (
            select(Tag.scope_id).
            join(CollectionTag).
            where(CollectionTag.id == tag_instance_id)
        )
    else:
        # Guard for an unsupported tag type. Note that asserts are removed
        # under `python -O`, after which `stmt` would be unbound below.
        assert False

    if not (tag_scope_id := Session.execute(stmt).scalar_one_or_none()):
        raise HTTPException(HTTP_404_NOT_FOUND)
def update_user_password(user_id, password):
    """
    Replace a user's stored password hash.

    The plain-text password is hashed with the module-level ``ph`` hasher
    before being saved; it is never stored as given.

    :param user_id: the user id
    :param password: the input plain-text password
    """
    account = Session.get(User, user_id)
    hashed = ph.hash(password)
    account.password = hashed
    account.save()
def update_user_verified(user_id, verified):
    """
    Set and save a user's ``verified`` flag.

    :param user_id: the user id
    :param verified: True/False
    """
    account = Session.get(User, user_id)
    account.verified = verified
    account.save()