def assert_json_record_result(response, json, record):
    """Assert that an API response body describes ``record``.

    Checks the HTTP status, the record's scalar fields and timestamp, and that
    ``json['tags']`` matches the DB tags of the record together with the tags
    of the record's collection, paired up by ``tag_id``.

    Note: ``json['tags']`` is sorted by ``tag_id`` in place.
    """
    assert response.status_code == 200

    # (response key, record attribute) pairs, checked in this order
    for key, attr in (
            ('id', 'id'),
            ('doi', 'doi'),
            ('sid', 'sid'),
            ('collection_id', 'collection_id'),
            ('schema_id', 'schema_id'),
            ('metadata', 'metadata_'),
    ):
        assert json[key] == getattr(record, attr)
    assert_new_timestamp(datetime.fromisoformat(json['timestamp']))

    json_tags = json['tags']
    record_tags = Session.execute(
        select(RecordTag).where(RecordTag.record_id == record.id)
    ).scalars().all()
    collection_tags = Session.execute(
        select(CollectionTag).where(CollectionTag.collection_id == record.collection_id)
    ).scalars().all()
    db_tags = record_tags + collection_tags
    assert len(json_tags) == len(db_tags)

    # Pair API tags with DB tags by sorting both sides on tag_id.
    json_tags.sort(key=lambda t: t['tag_id'])
    db_tags.sort(key=lambda t: t.tag_id)
    for api_tag, db_tag in zip(json_tags, db_tags):
        assert api_tag['tag_id'] == db_tag.tag_id
        assert api_tag['user_id'] == db_tag.user_id
        assert api_tag['user_name'] == db_tag.user.name
        assert api_tag['data'] == db_tag.data
        assert_new_timestamp(datetime.fromisoformat(api_tag['timestamp']))
async def get_collection(
        collection_id: str,
        auth: Authorized = Depends(Authorize(ODPScope.COLLECTION_READ)),
):
    # Fetch a single collection (with its record count).
    # Callers restricted to specific collections may only read those; a value of
    # '*' for auth.collection_ids skips this check.
    if auth.collection_ids != '*' and collection_id not in auth.collection_ids:
        raise HTTPException(HTTP_403_FORBIDDEN)

    # Select the collection together with the number of its records. The outer
    # join keeps collections that have no records; count(Record.id) ignores the
    # resulting NULLs, giving 0 for such collections.
    stmt = (select(Collection, func.count(Record.id)).outerjoin(Record).where(
        Collection.id == collection_id).group_by(Collection))

    if not (result := Session.execute(stmt).one_or_none()):
        raise HTTPException(HTTP_404_NOT_FOUND)
    # NOTE(review): the visible code ends here without using `result`; the
    # construction of the response is not shown in this view - confirm.
def assert_db_state(clients):
    """Assert that the client table holds exactly ``clients``.

    The built-in 'odp.test' client is ignored. Clients are compared as
    (id, scope ids, collection id) triples, irrespective of order.
    """
    # Drop cached ORM state so attributes are re-read from the DB.
    Session.expire_all()
    stmt = select(Client).where(Client.id != 'odp.test')
    db_clients = Session.execute(stmt).scalars().all()
    actual = {(c.id, scope_ids(c), c.collection_id) for c in db_clients}
    expected = {(c.id, scope_ids(c), c.collection_id) for c in clients}
    assert actual == expected
def test_create_role_with_scopes():
    """Creating a role with scopes produces one RoleScope row per scope."""
    scopes = ScopeFactory.create_batch(5, type='odp') + ScopeFactory.create_batch(5, type='client')
    role = RoleFactory(scopes=scopes)
    rows = Session.execute(select(RoleScope)).scalars().all()
    # A SELECT without ORDER BY has no guaranteed row order, so compare as sets;
    # the length check keeps the duplicate/missing-row detection a list had.
    assert len(rows) == len(scopes)
    assert {(row.role_id, row.scope_id, row.scope_type) for row in rows} \
        == {(role.id, scope.id, scope.type) for scope in scopes}
def select_scopes(
        scope_ids: list[str],
        scope_types: list[ScopeType] = None,
) -> list[Scope]:
    """Select Scope objects given a list of ids, optionally constrained to the given types.

    :param scope_ids: ids of the scopes to look up
    :param scope_types: if not None, only scopes whose type is in this list match

    Ids that match no scope (including those excluded by ``scope_types``) are
    collected in ``invalid_ids``.

    NOTE(review): as visible here, the function ends without returning
    ``scopes`` or reporting ``invalid_ids``; the remainder is not shown.
    """
    scopes = []
    invalid_ids = []
    for scope_id in scope_ids:
        # One query per id.
        stmt = select(Scope).where(Scope.id == scope_id)
        if scope_types is not None:
            stmt = stmt.where(Scope.type.in_(scope_types))
        # NOTE(review): scalar_one_or_none() raises if more than one Scope row
        # matches this id (e.g. if scope_types is None and ids are not unique).
        if scope := Session.execute(stmt).scalar_one_or_none():
            scopes += [scope]
        else:
            invalid_ids += [scope_id]
def _select_records(self) -> list[tuple[str, datetime]]:
    """Find the records that need (re-)evaluation for publication to, or
    retraction from, this catalog.

    A record qualifies when any of the following holds:

    * it has never been evaluated for this catalog (no catalog_record row);
    * it carries an embargo tag;
    * its catalog_record timestamp is older than the newest of the catalog
      schema's timestamp, its collection's timestamp and its own timestamp.

    :return: (record_id, timestamp) rows, where timestamp is that of the most
        recent of the contributing changes listed above
    """
    catalog = Session.get(Catalog, self.catalog_id)

    # Per record: the latest of the catalog schema, collection and record timestamps.
    latest_change = select(
        Record.id.label('record_id'),
        func.greatest(
            catalog.schema.timestamp,
            Collection.timestamp,
            Record.timestamp,
        ).label('max_timestamp'),
    ).join(Collection).subquery()

    # The evaluation state already recorded for this catalog.
    evaluated = select(
        CatalogRecord.record_id,
        CatalogRecord.timestamp,
    ).where(
        CatalogRecord.catalog_id == self.catalog_id
    ).subquery()

    embargoed_ids = select(RecordTag.record_id).where(
        RecordTag.tag_id == ODPRecordTag.EMBARGO
    )

    needs_evaluation = or_(
        evaluated.c.record_id.is_(None),
        evaluated.c.timestamp < latest_change.c.max_timestamp,
        evaluated.c.record_id.in_(embargoed_ids),
    )

    stmt = select(
        latest_change.c.record_id,
        latest_change.c.max_timestamp,
    ).outerjoin_from(latest_change, evaluated).where(needs_evaluation)

    return Session.execute(stmt).all()
async def __call__(self, request: Request, tag_instance_id: str) -> Authorized:
    """Look up the scope of the tag applied by a tag instance.

    Depending on ``self.tag_type``, ``tag_instance_id`` is treated as a
    RecordTag id or a CollectionTag id, and the scope_id of the corresponding
    Tag is fetched. A 404 is raised if the lookup yields nothing.

    NOTE(review): the visible code ends after the lookup; how ``tag_scope_id``
    is used to authorize the request is not shown in this view.
    """
    if self.tag_type == TagType.record:
        stmt = (
            select(Tag.scope_id).
            join(RecordTag).
            where(RecordTag.id == tag_instance_id)
        )
    elif self.tag_type == TagType.collection:
        stmt = (
            select(Tag.scope_id).
            join(CollectionTag).
            where(CollectionTag.id == tag_instance_id)
        )
    else:
        # NOTE(review): presumably unreachable. Under `python -O` this assert is
        # stripped, and `stmt` would then be unbound below.
        assert False

    if not (tag_scope_id := Session.execute(stmt).scalar_one_or_none()):
        raise HTTPException(HTTP_404_NOT_FOUND)
def test_db_setup():
    """Check system scope and admin role initialisation.

    init_system_scopes() must create exactly the ODPScope scopes. init_admin_role()
    must create the admin role (with no collection) holding every ODPScope scope,
    and none of the unrelated scopes created in between.
    """
    migrate.systemdata.init_system_scopes()
    Session.commit()
    # A SELECT without ORDER BY has no guaranteed row order, so compare sorted ids.
    result = Session.execute(select(Scope)).scalars()
    assert sorted(row.id for row in result) == sorted(s.value for s in ODPScope)

    # create a batch of arbitrary scopes, which should not be assigned to the sysadmin
    ScopeFactory.create_batch(5)

    migrate.systemdata.init_admin_role()
    Session.commit()
    result = Session.execute(select(Role)).scalar_one()
    assert (result.id, result.collection_id) == (migrate.systemdata.ODP_ADMIN_ROLE, None)

    # Order-insensitive comparison; the length check preserves duplicate detection.
    rows = Session.execute(select(RoleScope)).scalars().all()
    expected = [(migrate.systemdata.ODP_ADMIN_ROLE, s.value, ScopeType.odp) for s in ODPScope]
    assert len(rows) == len(expected)
    assert {(row.role_id, row.scope_id, row.scope_type) for row in rows} == set(expected)
def init_admin_user():
    """Create an admin user if one is not found.

    An admin user is considered to exist if any UserRole row has the
    ODP_ADMIN_ROLE role. Otherwise the user is prompted on stdin for a full
    name and an email address; each prompt repeats until non-empty input is
    given.

    NOTE(review): the visible code ends after reading the email; creating the
    user and assigning the role is not shown in this view.
    """
    if not Session.execute(
            select(UserRole).where(UserRole.role_id == ODP_ADMIN_ROLE)
    ).first():
        print('Creating an admin user...')
        # Re-prompt until a non-empty value is entered.
        while not (name := input('Full name: ')):
            pass
        while not (email := input('Email: ')):
            pass
def main():
    """Run the registered publisher for every catalog in the DB.

    The first exception aborts the whole run. It is logged at CRITICAL level
    together with its traceback, and is not re-raised.
    """
    logger.info('PUBLISHING STARTED')
    try:
        for catalog_id in Session.execute(select(Catalog.id)).scalars():
            # A catalog with no entry in `publishers` aborts the run
            # (presumably via KeyError, if `publishers` is a dict).
            publisher = publishers[catalog_id]
            publisher(catalog_id).run()
        logger.info('PUBLISHING FINISHED')
    except Exception as e:
        # Top-level boundary: log rather than propagate, but keep the traceback.
        # str(e) alone is often uninformative (a KeyError renders as just the key).
        logger.critical(f'PUBLISHING ABORTED: {str(e)}', exc_info=True)
def _update_token(self, hydra, token, refresh_token=None, access_token=None): if refresh_token: token_model = Session.execute( select(OAuth2Token). where(OAuth2Token.client_id == self.client_id). where(OAuth2Token.refresh_token == refresh_token) ).scalar_one() elif access_token: token_model = Session.execute( select(OAuth2Token). where(OAuth2Token.client_id == self.client_id). where(OAuth2Token.access_token == access_token) ).scalar_one() else: return token_model.access_token = token.get('access_token') token_model.refresh_token = token.get('refresh_token') token_model.expires_at = token.get('expires_at') token_model.save()
def paginate(
        self,
        query: Select,
        item_factory: Callable[[Row], ModelT],
        *,
        sort_model: Base = None,
        custom_sort: str = None,
) -> Page[ModelT]:
    """Execute ``query`` and return the requested page of results.

    :param query: the select statement to paginate
    :param item_factory: converts each result row into an output model
    :param sort_model: if given, sort by the attribute of this model named by ``self.sort``
    :param custom_sort: if given (and ``sort_model`` is not), a raw SQL ORDER BY expression;
        it is inserted verbatim via ``text()``, so it must never come from user input
    :return: the page of items plus total count, page number and page count
    :raises HTTPException: 422 if the sort column cannot be resolved
    """
    total = Session.execute(
        select(func.count()).
        select_from(query.subquery())
    ).scalar_one()

    # A falsy page size (0 / None) puts all results on a single page.
    limit = self.size or total

    try:
        if sort_model:
            sort_col = getattr(sort_model, self.sort)
        elif custom_sort:
            sort_col = text(custom_sort)
        else:
            sort_col = self.sort
        # An unresolvable sort column surfaces as AttributeError (getattr above)
        # or as CompileError when the statement is compiled for execution.
        rows = Session.execute(
            query.
            order_by(sort_col).
            offset(limit * (self.page - 1)).
            limit(limit)
        ).all()
    except (AttributeError, CompileError):
        raise HTTPException(HTTP_422_UNPROCESSABLE_ENTITY, 'Invalid sort column')

    # Built outside the try block so that errors raised by item_factory are not
    # misreported to the client as an invalid sort column.
    items = [item_factory(row) for row in rows]

    return Page(
        items=items,
        total=total,
        page=self.page,
        pages=ceil(total / limit) if limit else 0,
    )
def _create_record(
        record_in: RecordModelIn,
        metadata_schema: JSONSchema,
        auth: Authorized,
        ignore_collection_tags: bool = False,
) -> RecordModel:
    """Validate a new-record request and build the Record object.

    :param record_in: the incoming record data
    :param metadata_schema: schema used to compute the record's validity
    :param auth: the caller's authorization; limits the permitted collections
    :param ignore_collection_tags: if True, skip the frozen-collection check
    :raises HTTPException: 403 if the caller may not write to the collection;
        422 if the collection is frozen; 409 if the DOI or SID is already in use

    NOTE(review): the visible code ends after constructing ``record``; adding it
    to the session and building the returned RecordModel are not shown here.
    """
    # '*' grants access to all collections.
    if auth.collection_ids != '*' and record_in.collection_id not in auth.collection_ids:
        raise HTTPException(HTTP_403_FORBIDDEN)

    # A collection carrying the FROZEN tag accepts no new records.
    if not ignore_collection_tags and Session.execute(
            select(CollectionTag).where(
                CollectionTag.collection_id == record_in.collection_id).where(
                CollectionTag.tag_id == ODPCollectionTag.FROZEN)).first() is not None:
        raise HTTPException(HTTP_422_UNPROCESSABLE_ENTITY, 'A record cannot be added to a frozen collection')

    # NOTE(review): these check-then-insert uniqueness tests can race with a
    # concurrent insert; presumably a DB unique constraint is the real guard - confirm.
    if record_in.doi and Session.execute(
            select(Record).where(
                Record.doi == record_in.doi)).first() is not None:
        raise HTTPException(HTTP_409_CONFLICT, 'DOI is already in use')

    if record_in.sid and Session.execute(
            select(Record).where(
                Record.sid == record_in.sid)).first() is not None:
        raise HTTPException(HTTP_409_CONFLICT, 'SID is already in use')

    record = Record(
        doi=record_in.doi,
        sid=record_in.sid,
        collection_id=record_in.collection_id,
        schema_id=record_in.schema_id,
        schema_type=SchemaType.metadata,
        metadata_=record_in.metadata,
        validity=get_validity(record_in.metadata, metadata_schema),
        # `timestamp` is bound here as well, presumably for reuse later in the
        # function (not visible in this view).
        timestamp=(timestamp := datetime.now(timezone.utc)),
    )
def assert_db_state(collections):
    """Assert that the collection table holds exactly ``collections``.

    DB rows and expected collections are paired by id and compared field by
    field, including their project, client and role associations. Each row's
    timestamp must pass ``assert_new_timestamp``.

    Note: ``collections`` is sorted by id in place.
    """
    # Drop cached ORM state so attributes are re-read from the DB.
    Session.expire_all()
    db_rows = sorted(
        Session.execute(select(Collection)).scalars().all(),
        key=lambda row: row.id,
    )
    collections.sort(key=lambda c: c.id)
    assert len(db_rows) == len(collections)

    for row, expected in zip(db_rows, collections):
        assert row.id == expected.id
        assert row.name == expected.name
        assert row.doi_key == expected.doi_key
        assert row.provider_id == expected.provider_id
        assert_new_timestamp(row.timestamp)
        assert project_ids(row) == project_ids(expected)
        assert client_ids(row) == client_ids(expected)
        assert role_ids(row) == role_ids(expected)
def assert_db_state(records):
    """Assert that the record table holds exactly ``records``.

    DB rows and expected records are paired by id and compared field by field.
    Each row's timestamp must pass ``assert_new_timestamp``.

    Note: ``records`` is sorted by id in place.
    """
    # Drop cached ORM state so attributes are re-read from the DB.
    Session.expire_all()
    db_rows = sorted(
        Session.execute(select(Record)).scalars().all(),
        key=lambda row: row.id,
    )
    records.sort(key=lambda r: r.id)
    assert len(db_rows) == len(records)

    for row, expected in zip(db_rows, records):
        assert row.id == expected.id
        assert row.doi == expected.doi
        assert row.sid == expected.sid
        assert row.metadata_ == expected.metadata_
        assert_new_timestamp(row.timestamp)
        assert row.collection_id == expected.collection_id
        assert row.schema_id == expected.schema_id
        assert row.schema_type == expected.schema_type
async def get_tag(
        tag_id: str,
):
    # Return the tag definition with its resolved JSON schema; 404 if unknown.
    stmt = select(Tag).where(Tag.id == tag_id)
    if not (tag := Session.execute(stmt).scalar_one_or_none()):
        raise HTTPException(HTTP_404_NOT_FOUND)

    schema_uri = tag.schema.uri
    return TagModel(
        id=tag.id,
        cardinality=tag.cardinality,
        public=tag.public,
        scope_id=tag.scope_id,
        schema_id=tag.schema_id,
        schema_uri=schema_uri,
        schema_=schema_catalog.get_schema(URI(schema_uri)).value,
    )
def assert_tag_audit_log(*entries):
    """Assert that the collection tag audit table holds one row per entry.

    Each entry is a dict with 'command', 'collection_id' and 'collection_tag'
    keys. Rows are compared positionally, in the order the query returns them.
    Every row must come from the 'odp.test' client with no user.

    NOTE(review): the query has no ORDER BY, so positional matching relies on
    the DB returning audit rows in insertion order.
    """
    audit_rows = Session.execute(select(CollectionTagAudit)).scalars().all()
    assert len(audit_rows) == len(entries)

    for row, entry in zip(audit_rows, entries):
        assert row.client_id == 'odp.test'
        assert row.user_id is None
        assert row.command == entry['command']
        assert_new_timestamp(row.timestamp)
        assert row._collection_id == entry['collection_id']

        tag = entry['collection_tag']
        assert row._tag_id == tag['tag_id']
        assert row._user_id == tag.get('user_id')
        # Deleted tags record no data; inserts/updates record the tag data.
        if row.command == 'delete':
            assert row._data is None
        elif row.command in ('insert', 'update'):
            assert row._data == tag['data']
        else:
            assert False
async def get_schema(
        schema_id: str,
):
    # Return the schema definition with its resolved JSON content; 404 if unknown.
    stmt = select(Schema).where(Schema.id == schema_id)
    if not (schema := Session.execute(stmt).scalar_one_or_none()):
        raise HTTPException(HTTP_404_NOT_FOUND)

    resolved = schema_catalog.get_schema(URI(schema.uri))
    return SchemaModel(
        id=schema.id,
        type=schema.type,
        uri=schema.uri,
        schema_=resolved.value,
    )
def assert_db_tag_state(collection_id, *collection_tags):
    """Assert that the collection_tag table holds exactly ``collection_tags``.

    DB rows, ordered by timestamp, are compared positionally with the given
    tags. All rows must belong to ``collection_id``. An expected tag may be a
    CollectionTag (tag id, user id and data compared) or a dict with 'tag_id'
    and 'data' keys (the row must then have no user).
    """
    # Drop cached ORM state so attributes are re-read from the DB.
    Session.expire_all()
    db_rows = sorted(
        Session.execute(select(CollectionTag)).scalars().all(),
        key=lambda row: row.timestamp,
    )
    assert len(db_rows) == len(collection_tags)

    for row, expected in zip(db_rows, collection_tags):
        assert row.collection_id == collection_id
        assert_new_timestamp(row.timestamp)
        if not isinstance(expected, CollectionTag):
            assert row.tag_id == expected['tag_id']
            assert row.user_id is None
            assert row.data == expected['data']
            continue
        assert row.tag_id == expected.tag_id
        assert row.user_id == expected.user_id
        assert row.data == expected.data
def assert_audit_log(command, collection=None, collection_id=None):
    """Assert that the collection audit table records ``command``.

    The audit row must come from the 'odp.test' client with no user. For
    'insert'/'update' its audited fields must match ``collection``. For
    'delete' only ``collection_id`` is recorded and the other fields are null.
    """
    entry = Session.execute(select(CollectionAudit)).scalar_one_or_none()
    assert entry.client_id == 'odp.test'
    assert entry.user_id is None
    assert entry.command == command
    assert_new_timestamp(entry.timestamp)

    audited = (entry._id, entry._name, entry._doi_key, entry._provider_id)
    if command in ('insert', 'update'):
        assert audited == (collection.id, collection.name, collection.doi_key, collection.provider_id)
    elif command == 'delete':
        assert audited == (collection_id, None, None, None)
    else:
        assert False
def assert_audit_log(command, record=None, record_id=None):
    """Assert that the record audit table records ``command``.

    The audit row must come from the 'odp.test' client with no user. For
    'insert'/'update' its audited fields must match ``record``. For 'delete'
    only ``record_id`` is recorded and the other fields are null.
    """
    entry = Session.execute(select(RecordAudit)).scalar_one_or_none()
    assert entry.client_id == 'odp.test'
    assert entry.user_id is None
    assert entry.command == command
    assert_new_timestamp(entry.timestamp)

    audited = (
        entry._id,
        entry._doi,
        entry._sid,
        entry._metadata,
        entry._collection_id,
        entry._schema_id,
    )
    if command in ('insert', 'update'):
        assert audited == (
            record.id,
            record.doi,
            record.sid,
            record.metadata_,
            record.collection_id,
            record.schema_id,
        )
    elif command == 'delete':
        assert audited == (record_id, None, None, None, None, None)
    else:
        assert False
def test_create_collection():
    """CollectionFactory persists one collection joined to its provider."""
    collection = CollectionFactory()
    row = Session.execute(select(Collection, Provider).join(Provider)).one()
    db_collection = row.Collection
    db_provider = row.Provider
    actual = (
        db_collection.id,
        db_collection.name,
        db_collection.doi_key,
        db_collection.provider_id,
        db_provider.name,
    )
    expected = (
        collection.id,
        collection.name,
        collection.doi_key,
        collection.provider.id,
        collection.provider.name,
    )
    assert actual == expected
def assert_db_state(scopes):
    """Assert that the scope table holds exactly the (id, type) pairs of ``scopes``, in any order."""
    # Drop cached ORM state so attributes are re-read from the DB.
    Session.expire_all()
    db_scopes = Session.execute(select(Scope)).scalars().all()
    actual = {(s.id, s.type) for s in db_scopes}
    expected = {(s.id, s.type) for s in scopes}
    assert actual == expected
def get_user_by_email(email: str) -> Optional[User]:
    """Return the user with the given email address, or None if there is none.

    Raises (via ``scalar_one_or_none``) if more than one user has that email.
    """
    stmt = select(User).where(User.email == email)
    result = Session.execute(stmt)
    return result.scalar_one_or_none()
def test_create_record():
    """RecordFactory persists exactly one record with the factory's field values."""
    record = RecordFactory()
    db_record = Session.execute(select(Record)).scalar_one()
    actual = (
        db_record.id,
        db_record.doi,
        db_record.sid,
        db_record.metadata_,
        db_record.validity,
        db_record.collection_id,
        db_record.schema_id,
        db_record.schema_type,
    )
    expected = (
        record.id,
        record.doi,
        record.sid,
        record.metadata_,
        record.validity,
        record.collection.id,
        record.schema.id,
        record.schema.type,
    )
    assert actual == expected
def test_create_provider():
    """ProviderFactory persists exactly one provider with the factory's id and name."""
    provider = ProviderFactory()
    db_provider = Session.execute(select(Provider)).scalar_one()
    actual = (db_provider.id, db_provider.name)
    assert actual == (provider.id, provider.name)
def test_create_project_with_collections():
    """Creating a project with collections produces one ProjectCollection row per collection."""
    collections = CollectionFactory.create_batch(5)
    project = ProjectFactory(collections=collections)
    rows = Session.execute(select(ProjectCollection)).scalars().all()
    # A SELECT without ORDER BY has no guaranteed row order, so compare as sets;
    # the length check keeps the duplicate/missing-row detection a list had.
    assert len(rows) == len(collections)
    assert {(row.project_id, row.collection_id) for row in rows} \
        == {(project.id, collection.id) for collection in collections}
def test_create_project():
    """ProjectFactory persists exactly one project with the factory's id and name."""
    project = ProjectFactory()
    db_project = Session.execute(select(Project)).scalar_one()
    actual = (db_project.id, db_project.name)
    assert actual == (project.id, project.name)
def test_create_collection_tag():
    """CollectionTagFactory persists exactly one collection tag linked to its collection and tag."""
    collection_tag = CollectionTagFactory()
    stmt = select(CollectionTag).join(Collection).join(Tag)
    db_tag = Session.execute(stmt).scalar_one()
    actual = (
        db_tag.collection_id,
        db_tag.tag_id,
        db_tag.tag_type,
        db_tag.user_id,
        db_tag.data,
    )
    expected = (
        collection_tag.collection.id,
        collection_tag.tag.id,
        'collection',
        collection_tag.user.id,
        collection_tag.data,
    )
    assert actual == expected
def test_create_client_with_scopes():
    """Creating a client with scopes produces one ClientScope row per scope."""
    scopes = ScopeFactory.create_batch(5)
    client = ClientFactory(scopes=scopes)
    rows = Session.execute(select(ClientScope)).scalars().all()
    # A SELECT without ORDER BY has no guaranteed row order, so compare as sets;
    # the length check keeps the duplicate/missing-row detection a list had.
    assert len(rows) == len(scopes)
    assert {(row.client_id, row.scope_id, row.scope_type) for row in rows} \
        == {(client.id, scope.id, scope.type) for scope in scopes}