def test_mongo_special_id(self, mapper):
    """Check that ``immutable_id`` values are converted to BSON ObjectIds
    and that only equality-style operators are accepted on that field.
    """
    from optimade.filtertransformers.mongo import MongoTransformer
    from bson import ObjectId

    class MyMapper(mapper("StructureMapper")):
        ALIASES = (("immutable_id", "_id"), )

    transformer = MongoTransformer(mapper=MyMapper())
    parser = LarkParser(version=self.version, variant=self.variant)

    oid = "5cfb441f053b174410700d02"
    # Equality and inequality both coerce the string value to an ObjectId.
    result = transformer.transform(parser.parse(f'immutable_id = "{oid}"'))
    assert result == {"_id": {"$eq": ObjectId(oid)}}
    result = transformer.transform(parser.parse(f'immutable_id != "{oid}"'))
    assert result == {"_id": {"$ne": ObjectId(oid)}}

    # Every non-equality operator on immutable_id must raise BadRequest.
    for op in ("CONTAINS", "STARTS WITH", "ENDS WITH", "HAS"):
        with pytest.raises(
            BadRequest,
            match=r".*not supported for query on field 'immutable_id', can only test for equality.*",
        ):
            transformer.transform(parser.parse(f'immutable_id {op} "abcdef"'))
class TestParserV0_9_5:
    """Pytest-style checks of the v0.9.5 grammar against the ``*.inp`` fixtures."""

    @pytest.fixture(autouse=True)
    def set_up(self):
        # Load every fixture filter file, stripped, in a deterministic order.
        self.test_filters = []
        for filename in sorted(glob(os.path.join(testfile_dir, "*.inp"))):
            with open(filename) as handle:
                self.test_filters.append(handle.read().strip())
        self.parser = LarkParser(version=(0, 9, 5))

    def test_inputs(self):
        # Every fixture must parse, except the known-invalid version filter.
        for test_filter in self.test_filters:
            if test_filter == "filter=number=0.0.1":
                with pytest.raises(ParserError):
                    self.parser.parse(test_filter)
            else:
                tree = self.parser.parse(test_filter)
                assert isinstance(tree, Tree)

    def test_parser_version(self):
        expected_version = (0, 9, 5)
        fresh_parser = LarkParser(version=expected_version)
        assert isinstance(fresh_parser.parse(self.test_filters[0]), Tree)
        assert fresh_parser.version == expected_version

    def test_repr(self):
        # repr must be available both before and after a parse.
        assert repr(self.parser) is not None
        self.parser.parse(self.test_filters[0])
        assert repr(self.parser) is not None
class ParserTestV0_9_5(unittest.TestCase):
    """unittest-style checks of the v0.9.5 grammar against the ``*.inp`` fixtures."""

    @classmethod
    def setUpClass(cls):
        # Load every fixture filter file once for the whole test case.
        cls.test_filters = []
        for fn in sorted(glob(os.path.join(testfile_dir, "*.inp"))):
            with open(fn) as f:
                cls.test_filters.append(f.read().strip())

    def setUp(self):
        self.parser = LarkParser(version=(0, 9, 5))

    def test_inputs(self):
        for tf in self.test_filters:
            if tf == "filter=number=0.0.1":
                self.assertRaises(ParserError, self.parser.parse, tf)
            else:
                tree = self.parser.parse(tf)
                # BUG FIX: the original used ``assertTrue(tree, Tree)``, which
                # only checks truthiness (``Tree`` was silently used as the
                # failure message). Assert the actual type instead.
                self.assertIsInstance(tree, Tree)

    def test_parser_version(self):
        v = (0, 9, 5)
        p = LarkParser(version=v)
        self.assertIsInstance(p.parse(self.test_filters[0]), Tree)
        self.assertEqual(p.version, v)

    def test_repr(self):
        # repr must be available both before and after a parse.
        self.assertIsNotNone(repr(self.parser))
        self.parser.parse(self.test_filters[0])
        self.assertIsNotNone(repr(self.parser))
def test_aliased_length_operator(self, mapper):
    """Check LENGTH queries honour both LENGTH_ALIASES and ALIASES,
    including chained aliasing (``chemsys`` -> ``nelements`` -> ``nelem``).
    """
    from optimade.filtertransformers.mongo import MongoTransformer

    class MyMapper(mapper("StructureMapper")):
        ALIASES = (("elements", "my_elements"), ("nelements", "nelem"))
        LENGTH_ALIASES = (
            ("chemsys", "nelements"),
            ("cartesian_site_positions", "nsites"),
            ("elements", "nelements"),
        )
        PROVIDER_FIELDS = ("chemsys",)

    transformer = MongoTransformer(mapper=MyMapper())
    parser = LarkParser(version=self.version, variant=self.variant)

    # LENGTH with each comparison operator maps onto the aliased length field.
    assert transformer.transform(
        parser.parse("cartesian_site_positions LENGTH <= 3")
    ) == {"nsites": {"$lte": 3}}
    assert transformer.transform(
        parser.parse("cartesian_site_positions LENGTH < 3")
    ) == {"nsites": {"$lt": 3}}
    # FIX: the original asserted this exact query twice; the duplicate is removed.
    assert transformer.transform(
        parser.parse("cartesian_site_positions LENGTH 3")
    ) == {"nsites": 3}
    assert transformer.transform(
        parser.parse("cartesian_site_positions LENGTH >= 10")
    ) == {"nsites": {"$gte": 10}}

    # Fields without a LENGTH alias fall back to an index-existence query.
    assert transformer.transform(
        parser.parse("structure_features LENGTH > 10")
    ) == {"structure_features.11": {"$exists": True}}
    assert transformer.transform(parser.parse("nsites LENGTH > 10")) == {
        "nsites.11": {"$exists": True}
    }

    # The LENGTH alias target is itself run through ALIASES: nelements -> nelem.
    assert transformer.transform(parser.parse("elements LENGTH 3")) == {"nelem": 3}
    assert transformer.transform(parser.parse('elements HAS "Ag"')) == {
        "my_elements": {"$in": ["Ag"]}
    }
    assert transformer.transform(parser.parse("chemsys LENGTH 3")) == {"nelem": 3}
def test_list_length_aliases(self, mapper):
    """LENGTH filters on aliased list properties map to their scalar length fields."""
    from optimade.filtertransformers.mongo import MongoTransformer

    transformer = MongoTransformer(mapper=mapper("StructureMapper")())
    parser = LarkParser(version=self.version, variant=self.variant)

    def query(filter_string):
        # Parse an OPTIMADE filter string and transform it to a Mongo query dict.
        return transformer.transform(parser.parse(filter_string))

    assert query("elements LENGTH 3") == {"nelements": 3}
    assert query('elements HAS "Li" AND elements LENGTH = 3') == {
        "$and": [{"elements": {"$in": ["Li"]}}, {"nelements": 3}]
    }
    assert query("elements LENGTH > 3") == {"nelements": {"$gt": 3}}
    assert query("elements LENGTH < 3") == {"nelements": {"$lt": 3}}
    assert query("elements LENGTH = 3") == {"nelements": 3}
    assert query("cartesian_site_positions LENGTH <= 3") == {"nsites": {"$lte": 3}}
    assert query("cartesian_site_positions LENGTH >= 3") == {"nsites": {"$gte": 3}}
def test_other_provider_fields(self, mapper):
    """Test that fields with other-provider prefixes are passed through
    to the query unchanged: the comparison is applied verbatim on the
    prefixed field name (no special-casing by this transformer).
    """
    from optimade.filtertransformers.mongo import MongoTransformer

    t = MongoTransformer(mapper=mapper("StructureMapper"))
    p = LarkParser(version=self.version, variant=self.variant)
    assert t.transform(p.parse("_other_provider_field > 1")) == {
        "_other_provider_field": {
            "$gt": 1
        }
    }
def __init__(self):
    # Build the filter pipeline: OPTIMADE filter string -> Lark tree -> Mongo query.
    p = LarkParser(version=(1, 0, 0), variant="default")
    t = MongoTransformer()
    self.transform = lambda inp: t.transform(p.parse(inp))
    # SECURITY(review): hard-coded "admin"/"admin" credentials for a local
    # MongoDB — move to configuration/environment before any shared deployment.
    client = MongoClient('mongodb://{}:{}@{}:{}/?authSource={}'.format(
        "admin", "admin", "localhost", "27017", "admin"))
    db = client["MaterialsDB"]
    # Collection holding static-calculation documents.
    self.cl = db["Data.Calculation.StaticCalculation"]
    self.lu = Lower2Upper()
    # presumably info1/info2 are module-level metadata mappings — TODO confirm
    self.data = info1
    self.info = info2
def test_list_length_aliases(self):
    """LENGTH filters on list properties map to their scalar length fields."""
    from optimade.server.mappers import StructureMapper

    transformer = MongoTransformer(mapper=StructureMapper())
    parser = LarkParser(version=self.version, variant=self.variant)

    def query(filter_string):
        # Parse an OPTIMADE filter string and transform it to a Mongo query dict.
        return transformer.transform(parser.parse(filter_string))

    self.assertEqual(query("elements LENGTH 3"), {"nelements": 3})
    self.assertEqual(
        query('elements HAS "Li" AND elements LENGTH = 3'),
        {"$and": [{"elements": {"$in": ["Li"]}}, {"nelements": 3}]},
    )
    self.assertEqual(query("elements LENGTH > 3"), {"nelements": {"$gt": 3}})
    self.assertEqual(query("elements LENGTH < 3"), {"nelements": {"$lt": 3}})
    self.assertEqual(query("elements LENGTH = 3"), {"nelements": 3})
    self.assertEqual(
        query("cartesian_site_positions LENGTH <= 3"), {"nsites": {"$lte": 3}}
    )
    self.assertEqual(
        query("cartesian_site_positions LENGTH >= 3"), {"nsites": {"$gte": 3}}
    )
def test_list_length_aliases():
    """Check LENGTH aliases for lists"""
    from optimade.server.mappers import StructureMapper

    transformer = AiidaTransformer(mapper=StructureMapper())
    parser = LarkParser(version=VERSION, variant=VARIANT)

    def query(filter_string):
        # Parse and transform a filter into the AiiDA query representation.
        return transformer.transform(parser.parse(filter_string))

    assert query("elements LENGTH 3") == {"nelements": 3}
    assert query('elements HAS "Li" AND elements LENGTH = 3') == {
        "and": [{"elements": {"contains": ["Li"]}}, {"nelements": 3}]
    }
    assert query("elements LENGTH > 3") == {"nelements": {">": 3}}
    assert query("elements LENGTH < 3") == {"nelements": {"<": 3}}
    assert query("elements LENGTH = 3") == {"nelements": 3}
    assert query("cartesian_site_positions LENGTH <= 3") == {"nsites": {"<=": 3}}
    assert query("cartesian_site_positions LENGTH >= 3") == {"nsites": {">=": 3}}
class BaseTestFilterParser(abc.ABC):
    """Base class for parsing different versions of the grammar using `LarkParser`."""

    # Subclasses must set the grammar version; variant defaults to "default".
    version: Tuple[int, int, int]
    variant: str = "default"

    @pytest.fixture(autouse=True)
    def set_up(self):
        self.parser = LarkParser(version=self.version, variant=self.variant)

    def test_repr(self):
        # repr must work both before and after a parse has been performed.
        assert repr(self.parser) is not None
        self.parse("band_gap = 1")
        assert repr(self.parser) is not None

    def parse(self, inp):
        """Parse the given filter string with the configured parser."""
        return self.parser.parse(inp)

    def test_parser_version(self):
        assert self.parser.version == self.version
        assert self.parser.variant == self.variant
def test_filtering_on_relationships(self, mapper):
    """Test the nested properties with special names
    like "structures", "references" etc. are applied
    to the relationships field.
    """
    from optimade.filtertransformers.mongo import MongoTransformer

    t = MongoTransformer(mapper=mapper("StructureMapper"))
    p = LarkParser(version=self.version, variant=self.variant)

    # HAS on a relationship id becomes membership on the nested data array.
    assert t.transform(p.parse('references.id HAS "dummy/2019"')) == {
        "relationships.references.data.id": {
            "$in": ["dummy/2019"]
        }
    }
    assert t.transform(
        p.parse('structures.id HAS ANY "dummy/2019", "dijkstra1968"')) == {
            "relationships.structures.data.id": {
                "$in": ["dummy/2019", "dijkstra1968"]
            }
        }
    assert t.transform(
        p.parse('structures.id HAS ALL "dummy/2019", "dijkstra1968"')) == {
            "relationships.structures.data.id": {
                "$all": ["dummy/2019", "dijkstra1968"]
            }
        }
    # HAS ONLY with one value: exactly one relationship entry AND it matches.
    assert t.transform(p.parse('structures.id HAS ONLY "dummy/2019"')) == {
        "$and": [
            {
                "relationships.structures.data": {
                    "$size": 1
                }
            },
            {
                "relationships.structures.data.id": {
                    "$all": ["dummy/2019"]
                }
            },
        ]
    }
    # Composite filters nest the HAS ONLY expansion inside the outer $and.
    assert t.transform(
        p.parse(
            'structures.id HAS ONLY "dummy/2019" AND structures.id HAS "dummy/2019"'
        )) == {
            "$and": [
                {
                    "$and": [
                        {
                            "relationships.structures.data": {
                                "$size": 1,
                            }
                        },
                        {
                            "relationships.structures.data.id": {
                                "$all": ["dummy/2019"]
                            }
                        },
                    ]
                },
                {
                    "relationships.structures.data.id": {
                        "$in": ["dummy/2019"]
                    }
                },
            ],
        }
class MongoCollection(EntryCollection):
    """`EntryCollection` backed by a pymongo (or mongomock) Mongo collection."""

    def __init__(
        self,
        collection: Union[pymongo.collection.Collection,
                          mongomock.collection.Collection],
        resource_cls: EntryResource,
        resource_mapper: ResourceMapper,
    ):
        """Set up the filter parser/transformer and provider configuration.

        Parameters:
            collection: The underlying pymongo/mongomock collection.
            resource_cls: The `EntryResource` model stored by the collection.
            resource_mapper: Handles aliases and format changes between
                deserialization and response.

        """
        super().__init__(collection, resource_cls, resource_mapper)
        self.transformer = MongoTransformer()

        self.provider = CONFIG.provider["prefix"]
        self.provider_fields = CONFIG.provider_fields.get(
            resource_mapper.ENDPOINT, [])
        self.parser = LarkParser(
            version=(0, 10, 1), variant="default"
        )  # The MongoTransformer only supports v0.10.1 as the latest grammar

    def __len__(self):
        # estimated_document_count avoids a full collection scan.
        return self.collection.estimated_document_count()

    def __contains__(self, entry):
        return self.collection.count_documents(entry.dict()) > 0

    def count(self, **kwargs):
        """Count documents, forwarding only kwargs valid for ``count_documents``."""
        for k in list(kwargs.keys()):
            if k not in ("filter", "skip", "limit", "hint", "maxTimeMS"):
                del kwargs[k]
        if "filter" not in kwargs:  # "filter" is needed for count_documents()
            kwargs["filter"] = {}
        return self.collection.count_documents(**kwargs)

    def find(
        self, params: Union[EntryListingQueryParams, SingleEntryQueryParams]
    ) -> Tuple[List[EntryResource], NonnegativeInt, bool, set]:
        """Run the query derived from ``params`` and map hits to resources.

        Returns:
            (results, data_returned, more_data_available, fields_to_exclude)

        """
        criteria = self._parse_params(params)

        all_fields = criteria.pop("fields")
        if getattr(params, "response_fields", False):
            fields = set(params.response_fields.split(","))
        else:
            fields = all_fields.copy()

        results = []
        for doc in self.collection.find(**criteria):
            results.append(
                self.resource_cls(**self.resource_mapper.map_back(doc)))
        nresults_now = len(results)
        if isinstance(params, EntryListingQueryParams):
            # Re-count without the page limit to know the total match count.
            criteria_nolimit = criteria.copy()
            criteria_nolimit.pop("limit", None)
            data_returned = self.count(**criteria_nolimit)
            more_data_available = nresults_now < data_returned
        else:
            # SingleEntryQueryParams, e.g., /structures/{entry_id}
            data_returned = nresults_now
            more_data_available = False
            if nresults_now > 1:
                raise HTTPException(
                    status_code=404,
                    detail=
                    f"Instead of a single entry, {nresults_now} entries were found",
                )
            results = results[0] if results else None

        return results, data_returned, more_data_available, all_fields - fields

    def _alias_filter(self, filter_: dict) -> dict:
        """Recursively replace field names in a Mongo filter with their aliases."""
        res = {}
        for key, value in filter_.items():
            if key in ["$and", "$or"]:
                # Logical operators hold lists of sub-filters; recurse into each.
                res[key] = [self._alias_filter(item) for item in value]
            else:
                new_value = value
                if isinstance(value, dict):
                    new_value = self._alias_filter(value)
                res[self.resource_mapper.alias_for(key)] = new_value
        return res

    def _parse_params(
            self, params: Union[EntryListingQueryParams,
                                SingleEntryQueryParams]) -> dict:
        """Convert the request query parameters into pymongo ``find`` kwargs."""
        cursor_kwargs = {}
        if getattr(params, "filter", False):
            tree = self.parser.parse(params.filter)
            mongo_filter = self.transformer.transform(tree)
            cursor_kwargs["filter"] = self._alias_filter(mongo_filter)
        else:
            cursor_kwargs["filter"] = {}

        if (getattr(params, "response_format", False)
                and params.response_format != "json"):
            raise HTTPException(status_code=400,
                                detail="Only 'json' response_format supported")

        if getattr(params, "page_limit", False):
            limit = params.page_limit
            if limit > CONFIG.page_limit_max:
                raise HTTPException(
                    status_code=403,  # Forbidden
                    detail=
                    f"Max allowed page_limit is {CONFIG.page_limit_max}, you requested {limit}",
                )
            cursor_kwargs["limit"] = limit
        else:
            cursor_kwargs["limit"] = CONFIG.page_limit

        # All OPTiMaDe fields
        fields = self.resource_mapper.TOP_LEVEL_NON_ATTRIBUTES_FIELDS.copy()
        fields |= self.get_attribute_fields()
        # All provider-specific fields
        fields |= {self.provider + _ for _ in self.provider_fields}
        cursor_kwargs["fields"] = fields
        cursor_kwargs["projection"] = [
            self.resource_mapper.alias_for(f) for f in fields
        ]

        if getattr(params, "sort", False):
            sort_spec = []
            for elt in params.sort.split(","):
                field = elt
                sort_dir = 1
                if elt.startswith("-"):
                    # A leading "-" requests descending order on that field.
                    field = field[1:]
                    sort_dir = -1
                sort_spec.append((field, sort_dir))
            cursor_kwargs["sort"] = sort_spec

        if getattr(params, "page_offset", False):
            cursor_kwargs["skip"] = params.page_offset

        return cursor_kwargs
def test_filtering_on_relationships(self, mapper):
    """Test the nested properties with special names
    like "structures", "references" etc. are applied
    to the relationships field.
    """
    from optimade.filtertransformers.mongo import MongoTransformer

    t = MongoTransformer(mapper=mapper("StructureMapper"))
    p = LarkParser(version=self.version, variant=self.variant)

    # HAS on a relationship id becomes membership on the nested data array.
    assert t.transform(p.parse('references.id HAS "dummy/2019"')) == {
        "relationships.references.data.id": {
            "$in": ["dummy/2019"]
        }
    }
    assert t.transform(
        p.parse('structures.id HAS ANY "dummy/2019", "dijkstra1968"')) == {
            "relationships.structures.data.id": {
                "$in": ["dummy/2019", "dijkstra1968"]
            }
        }
    assert t.transform(
        p.parse('structures.id HAS ALL "dummy/2019", "dijkstra1968"')) == {
            "relationships.structures.data.id": {
                "$all": ["dummy/2019", "dijkstra1968"]
            }
        }
    # HAS ONLY expands to "no element with an id outside the list" AND
    # "the data array is non-empty" (element 0 exists).
    assert t.transform(p.parse('structures.id HAS ONLY "dummy/2019"')) == {
        "$and": [
            {
                "relationships.structures.data": {
                    "$not": {
                        "$elemMatch": {
                            "id": {
                                "$nin": ["dummy/2019"]
                            }
                        }
                    }
                }
            },
            {
                "relationships.structures.data.0": {
                    "$exists": True
                }
            },
        ]
    }
    # Composite filters nest the HAS ONLY expansion inside the outer $and.
    assert t.transform(
        p.parse(
            'structures.id HAS ONLY "dummy/2019" AND structures.id HAS "dummy/2019"'
        )) == {
            "$and": [
                {
                    "$and": [
                        {
                            "relationships.structures.data": {
                                "$not": {
                                    "$elemMatch": {
                                        "id": {
                                            "$nin": ["dummy/2019"]
                                        }
                                    }
                                }
                            }
                        },
                        {
                            "relationships.structures.data.0": {
                                "$exists": True
                            }
                        },
                    ]
                },
                {
                    "relationships.structures.data.id": {
                        "$in": ["dummy/2019"]
                    }
                },
            ]
        }
    # Only "id" may be filtered on for relationships; any other field raises.
    with pytest.raises(
            NotImplementedError,
            match=
            'Cannot filter relationships by field "doi", only "id" is supported.',
    ):
        assert t.transform(
            p.parse(
                'references.doi HAS ONLY "10.123/12345" AND structures.id HAS "dummy/2019"'
            )) == {
                "$and": [
                    {
                        "$and": [
                            {
                                "relationships.references.data": {
                                    "$not": {
                                        "$elemMatch": {
                                            "doi": {
                                                "$nin": ["10.123/12345"]
                                            }
                                        }
                                    }
                                }
                            },
                            {
                                "relationships.references.data.0": {
                                    "$exists": True
                                }
                            },
                        ]
                    },
                    {
                        "relationships.structures.data.id": {
                            "$in": ["dummy/2019"]
                        }
                    },
                ]
            }
def set_up(self):
    """Expose ``self.transform``: parse a filter string, then Mongo-transform it."""
    from optimade.filtertransformers.mongo import MongoTransformer

    parser = LarkParser(version=self.version, variant=self.variant)
    transformer = MongoTransformer()

    def _transform(query):
        return transformer.transform(parser.parse(query))

    self.transform = _transform
class TestParserV1_0_0:
    """Exercise the v1.0.0 OPTIMADE filter grammar through `LarkParser`."""

    version = (1, 0, 0)
    variant = "default"

    @pytest.fixture(autouse=True)
    def set_up(self):
        # Fresh parser for every test.
        self.parser = LarkParser(version=self.version, variant=self.variant)

    def parse(self, inp):
        """Parse a filter string and return the resulting Lark tree."""
        return self.parser.parse(inp)

    def test_empty(self):
        assert isinstance(self.parse(" "), Tree)

    def test_property_names(self):
        assert isinstance(self.parse("band_gap = 1"), Tree)
        assert isinstance(self.parse("cell_length_a = 1"), Tree)
        assert isinstance(self.parse("cell_volume = 1"), Tree)

        with pytest.raises(ParserError):
            self.parse("0_kvak IS KNOWN")  # starts with a number

        with pytest.raises(ParserError):
            self.parse('"foo bar" IS KNOWN')  # contains space; contains quotes

        with pytest.raises(ParserError):
            self.parse("BadLuck IS KNOWN")  # contains upper-case letters

        # database-provider-specific prefixes
        assert isinstance(self.parse("_exmpl_formula_sum = 1"), Tree)
        assert isinstance(self.parse("_exmpl_band_gap = 1"), Tree)

        # Nested property names
        assert isinstance(self.parse("identifier1.identifierd2 = 42"), Tree)

    def test_string_values(self):
        assert isinstance(self.parse('author="Sąžininga Žąsis"'), Tree)
        assert isinstance(
            self.parse('field = "!#$%&\'() * +, -./:; <= > ? @[] ^ `{|}~ % "'), Tree
        )

    def test_number_values(self):
        # Valid integer/float/exponent forms.
        assert isinstance(self.parse("a = 12345"), Tree)
        assert isinstance(self.parse("b = +12"), Tree)
        assert isinstance(self.parse("c = -34"), Tree)
        assert isinstance(self.parse("d = 1.2"), Tree)
        assert isinstance(self.parse("e = .2E7"), Tree)
        assert isinstance(self.parse("f = -.2E+7"), Tree)
        assert isinstance(self.parse("g = +10.01E-10"), Tree)
        assert isinstance(self.parse("h = 6.03e23"), Tree)
        assert isinstance(self.parse("i = .1E1"), Tree)
        assert isinstance(self.parse("j = -.1e1"), Tree)
        assert isinstance(self.parse("k = 1.e-12"), Tree)
        assert isinstance(self.parse("l = -.1e-12"), Tree)
        assert isinstance(self.parse("m = 1000000000.E1000000000"), Tree)

        # Malformed numbers must be rejected.
        with pytest.raises(ParserError):
            self.parse("number=1.234D12")
        with pytest.raises(ParserError):
            self.parse("number=.e1")
        with pytest.raises(ParserError):
            self.parse("number= -.E1")
        with pytest.raises(ParserError):
            self.parse("number=+.E2")
        with pytest.raises(ParserError):
            self.parse("number=1.23E+++")
        with pytest.raises(ParserError):
            self.parse("number=+-123")
        with pytest.raises(ParserError):
            self.parse("number=0.0.1")

    def test_operators(self):
        # Basic boolean operations
        assert isinstance(
            self.parse(
                'NOT ( chemical_formula_hill = "Al" AND chemical_formula_anonymous = "A" OR '
                'chemical_formula_anonymous = "H2O" AND NOT chemical_formula_hill = "Ti" )'
            ),
            Tree,
        )

        # Numeric and String comparisons
        assert isinstance(self.parse("nelements > 3"), Tree)
        assert isinstance(
            self.parse(
                'chemical_formula_hill = "H2O" AND chemical_formula_anonymous != "AB"'
            ),
            Tree,
        )
        assert isinstance(
            self.parse(
                "_exmpl_aax <= +.1e8 OR nelements >= 10 AND "
                'NOT ( _exmpl_x != "Some string" OR NOT _exmpl_a = 7)'
            ),
            Tree,
        )
        assert isinstance(self.parse('_exmpl_spacegroup="P2"'), Tree)
        assert isinstance(self.parse("_exmpl_cell_volume<100.0"), Tree)
        assert isinstance(
            self.parse("_exmpl_bandgap > 5.0 AND _exmpl_molecular_weight < 350"), Tree
        )
        assert isinstance(
            self.parse('_exmpl_melting_point<300 AND nelements=4 AND elements="Si,O2"'),
            Tree,
        )
        assert isinstance(self.parse("_exmpl_some_string_property = 42"), Tree)
        assert isinstance(self.parse("5 < _exmpl_a"), Tree)

        # OPTIONAL
        assert isinstance(
            self.parse("((NOT (_exmpl_a>_exmpl_b)) AND _exmpl_x>0)"), Tree
        )
        assert isinstance(self.parse("5 < 7"), Tree)

    def test_id(self):
        assert isinstance(self.parse('id="example/1"'), Tree)
        assert isinstance(self.parse('"example/1" = id'), Tree)
        assert isinstance(self.parse('id="test/2" OR "example/1" = id'), Tree)

    def test_string_operations(self):
        # Substring comparisons
        assert isinstance(
            self.parse(
                'chemical_formula_anonymous CONTAINS "C2" AND '
                'chemical_formula_anonymous STARTS WITH "A2"'
            ),
            Tree,
        )
        assert isinstance(
            self.parse(
                'chemical_formula_anonymous STARTS "B2" AND '
                'chemical_formula_anonymous ENDS WITH "D2"'
            ),
            Tree,
        )

    def test_list_properties(self):
        # Comparisons of list properties
        assert isinstance(self.parse("list HAS < 3"), Tree)
        assert isinstance(self.parse("list HAS ALL < 3, > 3"), Tree)
        assert isinstance(self.parse("list:list HAS >=2:<=5"), Tree)
        assert isinstance(
            self.parse(
                'elements HAS "H" AND elements HAS ALL "H","He","Ga","Ta" AND elements HAS '
                'ONLY "H","He","Ga","Ta" AND elements HAS ANY "H", "He", "Ga", "Ta"'
            ),
            Tree,
        )

        # OPTIONAL:
        assert isinstance(self.parse('elements HAS ONLY "H","He","Ga","Ta"'), Tree)
        assert isinstance(self.parse('elements HAS ALL "H","He","Ga","Ta"'), Tree)
        assert isinstance(self.parse('elements HAS ANY "H","He","Ga","Ta"'), Tree)
        assert isinstance(
            self.parse(
                'elements:_exmpl_element_counts HAS "H":6 AND '
                'elements:_exmpl_element_counts HAS ALL "H":6,"He":7 AND '
                'elements:_exmpl_element_counts HAS ONLY "H":6 AND '
                'elements:_exmpl_element_counts HAS ANY "H":6,"He":7 AND '
                'elements:_exmpl_element_counts HAS ONLY "H":6,"He":7'
            ),
            Tree,
        )
        assert isinstance(
            self.parse(
                "_exmpl_element_counts HAS < 3 AND "
                "_exmpl_element_counts HAS ANY > 3, = 6, 4, != 8"
            ),
            Tree,
        )
        assert isinstance(
            self.parse(
                "elements:_exmpl_element_counts:_exmpl_element_weights "
                'HAS ANY > 3:"He":>55.3 , = 6:>"Ti":<37.6 , 8:<"Ga":0'
            ),
            Tree,
        )

    def test_properties(self):
        # Filtering on Properties with unknown value
        assert isinstance(
            self.parse(
                "chemical_formula_hill IS KNOWN AND "
                "NOT chemical_formula_anonymous IS UNKNOWN"
            ),
            Tree,
        )

    def test_precedence(self):
        assert isinstance(self.parse('NOT a > b OR c = 100 AND f = "C2 H6"'), Tree)
        assert isinstance(
            self.parse('(NOT (a > b)) OR ( (c = 100) AND (f = "C2 H6") )'), Tree
        )
        assert isinstance(self.parse("a >= 0 AND NOT b < c OR c = 0"), Tree)
        assert isinstance(self.parse("((a >= 0) AND (NOT (b < c))) OR (c = 0)"), Tree)

    def test_special_cases(self):
        assert isinstance(self.parse("te < st"), Tree)
        assert isinstance(self.parse('spacegroup="P2"'), Tree)
        assert isinstance(self.parse("_cod_cell_volume<100.0"), Tree)
        assert isinstance(
            self.parse("_mp_bandgap > 5.0 AND _cod_molecular_weight < 350"), Tree
        )
        assert isinstance(
            self.parse('_cod_melting_point<300 AND nelements=4 AND elements="Si,O2"'),
            Tree,
        )
        assert isinstance(self.parse("key=value"), Tree)
        assert isinstance(self.parse('author=" someone "'), Tree)
        assert isinstance(self.parse('author=" som\neone "'), Tree)
        assert isinstance(
            self.parse(
                "number=0.ANDnumber=.0ANDnumber=0.0ANDnumber=+0AND_n_u_m_b_e_r_=-0AND"
                "number=0e1ANDnumber=0e-1ANDnumber=0e+1"
            ),
            Tree,
        )
        assert isinstance(
            self.parse("NOTice=val"), Tree
        )  # property (ice) != property (val)
        assert isinstance(
            self.parse('NOTice="val"'), Tree
        )  # property (ice) != value ("val")
        assert isinstance(
            self.parse('"NOTice"=val'), Tree
        )  # value ("NOTice") = property (val)

        with pytest.raises(ParserError):
            self.parse("NOTICE=val")  # not valid property or value (NOTICE)
        with pytest.raises(ParserError):
            self.parse('"NOTICE"=Val')  # not valid property (Val)
        # NOTE(review): this repeats the first NOTICE check above — presumably a
        # copy-paste duplicate; kept verbatim here.
        with pytest.raises(ParserError):
            self.parse("NOTICE=val")  # not valid property or value (NOTICE)

    def test_parser_version(self):
        assert self.parser.version == self.version
        assert self.parser.variant == self.variant

    def test_repr(self):
        # repr must be available both before and after a parse.
        assert repr(self.parser) is not None
        self.parser.parse('key="value"')
        assert repr(self.parser) is not None
def setUp(self):
    """Expose ``self.transform``: parse a filter string, then Mongo-transform it."""
    parser = LarkParser(version=self.version, variant=self.variant)
    transformer = MongoTransformer()
    self.transform = lambda query: transformer.transform(parser.parse(query))
class Lark2Django:
    """Translate a parsed OPTIMADE filter (Lark tree) into a Django ``Q`` object."""

    def __init__(self):
        # Dispatch table from grammar operator tokens to Q-building methods.
        self.opers = {
            "=": self.eq,
            ">": self.gt,
            ">=": self.ge,
            "<": self.lt,
            "<=": self.le,
            "!=": self.ne,
            "OR": self.or_,
            "AND": self.and_,
            "NOT": self.not_,
        }
        self.parser = LarkParser(version=(0, 9, 7))

    def parse_raw_q(self, raw_query):
        """Parse a raw filter string into a Lark tree."""
        return self.parser.parse(raw_query)

    def eq(self, a, b):
        return Q(**{a: b})

    def gt(self, a, b):
        return Q(**{a + "__gt": b})

    def ge(self, a, b):
        return Q(**{a + "__gte": b})

    def lt(self, a, b):
        return Q(**{a + "__lt": b})

    def le(self, a, b):
        return Q(**{a + "__lte": b})

    def ne(self, a, b):
        return ~Q(**{a: b})

    def not_(self, a):
        return ~a

    def and_(self, a, b):
        return operator.and_(a, b)

    def or_(self, a, b):
        return operator.or_(a, b)

    def evaluate(self, parse_Tree):
        """Recursively fold a Lark tree/token into a Django ``Q`` expression.

        Raises:
            DjangoQueryError: For unknown properties, malformed trees, or
                non-Tree/non-Token input.
        """
        if isinstance(parse_Tree, Tree):
            children = parse_Tree.children
            if len(children) == 1:
                return self.evaluate(children[0])
            if len(children) == 2:
                # Unary node: NOT <expr>
                op_fn = self.evaluate(children[0])
                return op_fn(self.evaluate(children[1]))
            if len(children) == 3:
                if parse_Tree.data == "comparison":
                    # <property> <operator> <value>
                    db_prop = self.evaluate(children[0])
                    op_fn = self.evaluate(children[1])
                    # Idiom fix: membership test directly on the dict
                    # (original used ``in django_db_keys.keys()``).
                    if db_prop in django_db_keys:
                        return op_fn(
                            django_db_keys[db_prop], self.evaluate(children[2])
                        )
                    raise DjangoQueryError(
                        "Unknown property is queried : " + db_prop
                    )
                # Binary boolean node: <expr> <AND|OR> <expr>
                op_fn = self.evaluate(children[1])
                return op_fn(self.evaluate(children[0]), self.evaluate(children[2]))
            raise DjangoQueryError("Not compatible format. Tree has >3 children")
        if isinstance(parse_Tree, Token):
            if parse_Tree.type == "VALUE":
                return parse_Tree.value
            if parse_Tree.type in ("NOT", "CONJUNCTION", "OPERATOR"):
                return self.opers[parse_Tree.value]
            # Preserved from the original: tokens of any other type fall
            # through and yield None rather than raising.
            return None
        raise DjangoQueryError("Not a Lark Tree or Token")
class EntryCollection(ABC):
    """Backend-agnostic base class for querying collections of
    [`EntryResource`][optimade.models.entries.EntryResource]s."""

    def __init__(
        self,
        collection,
        resource_cls: EntryResource,
        resource_mapper: BaseResourceMapper,
        transformer: Transformer,
    ):
        """Initialize the collection for the given parameters.

        Parameters:
            collection: The backend-specific collection.
            resource_cls (EntryResource): The `EntryResource` model
                that is stored by the collection.
            resource_mapper (BaseResourceMapper): A resource mapper
                object that handles aliases and format changes between
                deserialization and response.
            transformer (Transformer): The Lark `Transformer` used to
                interpret the filter.

        """
        self.collection = collection
        self.parser = LarkParser()
        self.resource_cls = resource_cls
        self.resource_mapper = resource_mapper
        self.transformer = transformer

        self.provider_prefix = CONFIG.provider.prefix
        self.provider_fields = CONFIG.provider_fields.get(
            resource_mapper.ENDPOINT, [])

    @abstractmethod
    def __len__(self) -> int:
        """ Returns the total number of entries in the collection. """

    @abstractmethod
    def count(self, **kwargs) -> int:
        """Returns the number of entries matching the query specified by the
        keyword arguments.

        Parameters:
            kwargs (dict): Query parameters as keyword arguments.

        """

    @abstractmethod
    def find(
        self, params: EntryListingQueryParams
    ) -> Tuple[List[EntryResource], int, bool, set]:
        """ Fetches results and indicates if more data is available.

        Also gives the total number of data available in the absence of
        `page_limit`.

        See [`EntryListingQueryParams`][optimade.server.query_params.EntryListingQueryParams]
        for more information.

        Parameters:
            params (EntryListingQueryParams): entry listing URL query params

        Returns:
            (`results`, `data_returned`, `more_data_available`, `fields`).

        """

    @property
    def all_fields(self) -> set:
        """Get the set of all fields handled in this collection, from attribute
        fields in the schema, provider fields and top-level OPTIMADE fields.

        Returns:
            All fields handled in this collection.

        """
        # All OPTIMADE fields
        fields = self.resource_mapper.TOP_LEVEL_NON_ATTRIBUTES_FIELDS.copy()
        fields |= self.get_attribute_fields()
        # All provider-specific fields
        fields |= {
            f"_{self.provider_prefix}_{field_name}"
            for field_name in self.provider_fields
        }
        return fields

    def get_attribute_fields(self) -> set:
        """Get the set of attribute fields from the schema of the resource class,
        resolving references along the way.

        Returns:
            Property names.

        """
        schema = self.resource_cls.schema()
        attributes = schema["properties"]["attributes"]
        if "allOf" in attributes:
            # Merge all composed sub-schemas into one mapping.
            allOf = attributes.pop("allOf")
            for dict_ in allOf:
                attributes.update(dict_)
        if "$ref" in attributes:
            # Follow an internal JSON-schema reference, e.g. "#/definitions/X".
            path = attributes["$ref"].split("/")[1:]
            attributes = schema.copy()
            while path:
                next_key = path.pop(0)
                attributes = attributes[next_key]
        return set(attributes["properties"].keys())

    def handle_query_params(
            self, params: Union[EntryListingQueryParams,
                                SingleEntryQueryParams]) -> dict:
        """Parse and interpret the backend-agnostic query parameter models into a
        dictionary that can be used by the specific backend.

        Note:
            Currently this method returns the pymongo interpretation of the
            parameters, which will need modification for other backends.

        Parameters:
            params (Union[EntryListingQueryParams, SingleEntryQueryParams]):
                The initialized query parameter model from the server.

        Raises:
            Forbidden: If too large of a page limit is provided.
            BadRequest: If an invalid request is made, e.g., with incorrect
                fields or response format.

        Returns:
            A dictionary representation of the query parameters, ready to be
            used by pymongo.

        """
        cursor_kwargs = {}
        if getattr(params, "filter", False):
            tree = self.parser.parse(params.filter)
            cursor_kwargs["filter"] = self.transformer.transform(tree)
        else:
            cursor_kwargs["filter"] = {}

        if (getattr(params, "response_format", False)
                and params.response_format != "json"):
            raise BadRequest(
                detail=
                f"Response format {params.response_format} is not supported, please use response_format='json'"
            )

        if getattr(params, "page_limit", False):
            limit = params.page_limit
            if limit > CONFIG.page_limit_max:
                raise Forbidden(
                    detail=
                    f"Max allowed page_limit is {CONFIG.page_limit_max}, you requested {limit}",
                )
            cursor_kwargs["limit"] = limit
        else:
            cursor_kwargs["limit"] = CONFIG.page_limit

        cursor_kwargs["fields"] = self.all_fields
        cursor_kwargs["projection"] = [
            self.resource_mapper.alias_for(f) for f in self.all_fields
        ]

        if getattr(params, "sort", False):
            cursor_kwargs["sort"] = self.parse_sort_params(params.sort)

        if getattr(params, "page_offset", False):
            cursor_kwargs["skip"] = params.page_offset

        return cursor_kwargs

    def parse_sort_params(self, sort_params) -> List[Tuple[str, int]]:
        """Handles any sort parameters passed to the collection, resolving
        aliases and dealing with any invalid fields.

        Raises:
            BadRequest: if an invalid sort is requested.

        Returns:
            A list of tuples containing the aliased field name and sort
            direction encoded as 1 (ascending) or -1 (descending).

        """
        sort_spec = []
        for field in sort_params.split(","):
            sort_dir = 1
            if field.startswith("-"):
                # A leading "-" requests descending order on that field.
                field = field[1:]
                sort_dir = -1
            aliased_field = self.resource_mapper.alias_for(field)
            sort_spec.append((aliased_field, sort_dir))

        unknown_fields = [
            field for field, _ in sort_spec
            if self.resource_mapper.alias_of(field) not in self.all_fields
        ]

        if unknown_fields:
            error_detail = "Unable to sort on unknown field{} '{}'".format(
                "s" if len(unknown_fields) > 1 else "",
                "', '".join(unknown_fields),
            )

            # If all unknown fields are "other" provider-specific, then only provide a warning
            if all((re.match(r"_[a-z_0-9]+_[a-z_0-9]*", field)
                    and not field.startswith(f"_{self.provider_prefix}_"))
                   for field in unknown_fields):
                warnings.warn(error_detail, FieldValueNotRecognized)

            # Otherwise, if all fields are unknown, or some fields are unknown and do not
            # have other provider prefixes, then return 400: Bad Request
            else:
                raise BadRequest(detail=error_detail)

        # If at least one valid field has been provided for sorting, then use that
        sort_spec = tuple((field, sort_dir) for field, sort_dir in sort_spec
                          if field not in unknown_fields)

        return sort_spec
class MongoCollection(EntryCollection):
    """Entry collection backed by a MongoDB (or mongomock) collection.

    Translates OPTIMADE query parameters into pymongo cursor keyword
    arguments via a `MongoTransformer`.
    """

    def __init__(
        self,
        collection: Union[pymongo.collection.Collection,
                          mongomock.collection.Collection],
        resource_cls: EntryResource,
        resource_mapper: BaseResourceMapper,
    ):
        """Store the collection handle and build the filter machinery.

        Parameters:
            collection: The underlying pymongo/mongomock collection.
            resource_cls: The `EntryResource` model stored by the collection.
            resource_mapper: Mapper that handles aliases between OPTIMADE and
                backend field names.
        """
        super().__init__(collection, resource_cls, resource_mapper)
        self.transformer = MongoTransformer(mapper=resource_mapper)

        self.provider_prefix = CONFIG.provider.prefix
        self.provider_fields = CONFIG.provider_fields.get(
            resource_mapper.ENDPOINT, [])
        self.parser = LarkParser(
            version=(0, 10, 1), variant="default"
        )  # The MongoTransformer only supports v0.10.1 as the latest grammar

        # check aliases do not clash with mongo operators
        self._check_aliases(self.resource_mapper.all_aliases())
        self._check_aliases(self.resource_mapper.all_length_aliases())

    def __len__(self):
        # Fast, approximate count based on collection metadata.
        return self.collection.estimated_document_count()

    def __contains__(self, entry):
        return self.collection.count_documents(entry.dict()) > 0

    def count(self, **kwargs):
        """Count documents, forwarding only kwargs `count_documents` accepts."""
        for k in list(kwargs.keys()):
            if k not in ("filter", "skip", "limit", "hint", "maxTimeMS"):
                del kwargs[k]
        if "filter" not in kwargs:  # "filter" is needed for count_documents()
            kwargs["filter"] = {}
        return self.collection.count_documents(**kwargs)

    def find(
        self, params: Union[EntryListingQueryParams, SingleEntryQueryParams]
    ) -> Tuple[List[EntryResource], int, bool, set]:
        """Run the query described by `params` against the Mongo collection.

        Returns:
            A tuple of (`results`, `data_returned`, `more_data_available`,
            fields to exclude from the response).
        """
        criteria = self._parse_params(params)

        all_fields = criteria.pop("fields")
        if getattr(params, "response_fields", False):
            fields = set(params.response_fields.split(","))
            fields |= self.resource_mapper.get_required_fields()
        else:
            fields = all_fields.copy()

        results = []
        for doc in self.collection.find(**criteria):
            results.append(
                self.resource_cls(**self.resource_mapper.map_back(doc)))

        nresults_now = len(results)
        if isinstance(params, EntryListingQueryParams):
            # Re-count without the page limit to learn the total match count.
            criteria_nolimit = criteria.copy()
            criteria_nolimit.pop("limit", None)
            data_returned = self.count(**criteria_nolimit)
            more_data_available = nresults_now < data_returned
        else:
            # SingleEntryQueryParams, e.g., /structures/{entry_id}
            data_returned = nresults_now
            more_data_available = False
            if nresults_now > 1:
                raise HTTPException(
                    status_code=404,
                    detail=
                    f"Instead of a single entry, {nresults_now} entries were found",
                )
            results = results[0] if results else None

        return results, data_returned, more_data_available, all_fields - fields

    def _parse_params(
            self, params: Union[EntryListingQueryParams,
                                SingleEntryQueryParams]) -> dict:
        """Turn the URL query parameter models into pymongo cursor kwargs."""
        cursor_kwargs = {}

        if getattr(params, "filter", False):
            tree = self.parser.parse(params.filter)
            cursor_kwargs["filter"] = self.transformer.transform(tree)
        else:
            cursor_kwargs["filter"] = {}

        if (getattr(params, "response_format", False)
                and params.response_format != "json"):
            raise HTTPException(status_code=400,
                                detail="Only 'json' response_format supported")

        if getattr(params, "page_limit", False):
            limit = params.page_limit
            if limit > CONFIG.page_limit_max:
                raise HTTPException(
                    status_code=403,  # Forbidden
                    detail=
                    f"Max allowed page_limit is {CONFIG.page_limit_max}, you requested {limit}",
                )
            cursor_kwargs["limit"] = limit
        else:
            cursor_kwargs["limit"] = CONFIG.page_limit

        # All OPTIMADE fields
        fields = self.resource_mapper.TOP_LEVEL_NON_ATTRIBUTES_FIELDS.copy()
        fields |= self.get_attribute_fields()
        # All provider-specific fields
        fields |= {
            f"_{self.provider_prefix}_{field_name}"
            for field_name in self.provider_fields
        }
        cursor_kwargs["fields"] = fields
        cursor_kwargs["projection"] = [
            self.resource_mapper.alias_for(f) for f in fields
        ]

        if getattr(params, "sort", False):
            # "-field" means descending (-1); plain "field" ascending (1).
            sort_spec = []
            for elt in params.sort.split(","):
                field = elt
                sort_dir = 1
                if elt.startswith("-"):
                    field = field[1:]
                    sort_dir = -1
                sort_spec.append((field, sort_dir))
            cursor_kwargs["sort"] = sort_spec

        if getattr(params, "page_offset", False):
            cursor_kwargs["skip"] = params.page_offset

        return cursor_kwargs

    def _check_aliases(self, aliases):
        """ Check that aliases do not clash with mongo keywords. """
        # A leading "$" on either side of an alias pair would be interpreted
        # by MongoDB as an operator, so reject such aliases outright.
        if any(alias[0].startswith("$") or alias[1].startswith("$")
               for alias in aliases):
            raise RuntimeError(
                f"Cannot define an alias starting with a '$': {aliases}")
class AiidaCollection:
    """Collection of AiiDA entities

    Wraps an AiiDA ``orm`` entity collection and translates OPTIMADE query
    parameters into ``orm.QueryBuilder`` queries.
    """

    # Maps JSON-schema value types onto AiiDA QueryBuilder "cast" codes,
    # used when ordering on extras fields.
    CAST_MAPPING = {
        "string": "t",
        "float": "f",
        "integer": "i",
        "boolean": "b",
        "date-time": "d",
    }

    def __init__(
        self,
        collection: orm.entities.Collection,
        resource_cls: EntryResource,
        resource_mapper: ResourceMapper,
    ):
        """Set up the collection, transformer, parser and config-derived limits.

        Parameters:
            collection: The AiiDA ORM entity collection to query.
            resource_cls: The `EntryResource` model stored by the collection.
            resource_mapper: Mapper handling aliases between OPTIMADE and
                AiiDA field names.
        """
        self.collection = collection
        self.resource_cls = resource_cls
        self.resource_mapper = resource_mapper
        self.transformer = AiidaTransformerV0_10_1()

        self.provider = CONFIG.provider["prefix"]
        self.provider_fields = CONFIG.provider_fields[resource_mapper.ENDPOINT]
        self.page_limit = CONFIG.page_limit
        self.db_page_limit = CONFIG.db_page_limit
        # FIX: the parser was previously assigned twice (an unversioned
        # LarkParser() that was immediately overwritten); keep only the
        # versioned parser that is actually used.
        self.parser = LarkParser(version=(0, 10, 0))

        # "Cache" - lazily populated per query.
        self._data_available: int = None
        self._data_returned: int = None
        self._filter_fields: set = None
        self._latest_filter: dict = None

    def get_attribute_fields(self) -> set:
        """Return the first-level attribute field names from the resource schema,
        resolving `allOf`/`$ref` indirections along the way."""
        schema = self.resource_cls.schema()
        attributes = schema["properties"]["attributes"]
        if "allOf" in attributes:
            allOf = attributes.pop("allOf")
            for dict_ in allOf:
                attributes.update(dict_)
        if "$ref" in attributes:
            # Follow the local JSON reference, e.g. "#/definitions/Foo".
            path = attributes["$ref"].split("/")[1:]
            attributes = schema.copy()
            while path:
                next_key = path.pop(0)
                attributes = attributes[next_key]
        return set(attributes["properties"].keys())

    @staticmethod
    def _find(backend: orm.implementation.Backend, entity_type: orm.Entity,
              **kwargs) -> orm.QueryBuilder:
        """Build a QueryBuilder for `entity_type`, accepting only the
        whitelisted kwargs; default ordering is ascending by id."""
        for key in kwargs:
            if key not in {
                    "filters", "order_by", "limit", "project", "offset"
            }:
                raise ValueError(
                    f"You supplied key {key}. _find() only takes the keys: "
                    '"filters", "order_by", "limit", "project", "offset"')

        filters = kwargs.get("filters", {})
        order_by = kwargs.get("order_by", None)
        order_by = {
            entity_type: order_by
        } if order_by else {
            entity_type: {
                "id": "asc"
            }
        }
        limit = kwargs.get("limit", None)
        offset = kwargs.get("offset", None)
        project = kwargs.get("project", [])

        query = orm.QueryBuilder(backend=backend, limit=limit, offset=offset)
        query.append(entity_type, project=project, filters=filters)
        query.order_by(order_by)

        return query

    def _find_all(self, backend: orm.implementation.Backend,
                  **kwargs) -> orm.QueryBuilder:
        """Run the query and return all rows."""
        query = self._find(backend, self.collection.entity_type, **kwargs)
        res = query.all()
        del query
        return res

    def count(self, backend: orm.implementation.Backend, **kwargs):  # pylint: disable=arguments-differ
        """Return the number of entities matching the query kwargs."""
        query = self._find(backend, self.collection.entity_type, **kwargs)
        res = query.count()
        del query
        return res

    @property
    def data_available(self) -> int:
        """Total number of entities in the collection.

        Raises:
            CausationError: if accessed before `set_data_available`.
        """
        if self._data_available is None:
            raise CausationError(
                "data_available MUST be set before it can be retrieved.")
        return self._data_available

    def set_data_available(self, backend: orm.implementation.Backend):
        """Set _data_available if it has not yet been set"""
        # FIX: use an explicit `is None` check (mirroring set_data_returned)
        # so a legitimately cached count of 0 is not recomputed on every call.
        if self._data_available is None:
            self._data_available = self.count(backend)

    @property
    def data_returned(self) -> int:
        """Number of entities matching the latest filter.

        Raises:
            CausationError: if accessed before `set_data_returned`.
        """
        if self._data_returned is None:
            raise CausationError(
                "data_returned MUST be set before it can be retrieved.")
        return self._data_returned

    def set_data_returned(self, backend: orm.implementation.Backend,
                          **criteria):
        """Set _data_returned if it has not yet been set or
        new filter does not equal latest filter.

        NB! Nested lists in filters are not accounted for.
        """
        if self._data_returned is None or (
                self._latest_filter is not None
                and criteria.get("filters", {}) != self._latest_filter):
            # Pagination kwargs must not affect the total count.
            for key in ("limit", "offset"):
                criteria.pop(key, None)
            self._latest_filter = criteria.get("filters", {})
            self._data_returned = self.count(backend, **criteria)

    def find(  # pylint: disable=arguments-differ
        self,
        backend: orm.implementation.Backend,
        params: Union[EntryListingQueryParams, SingleEntryQueryParams],
    ) -> Tuple[List[EntryResource], NonnegativeInt, bool, NonnegativeInt,
               set]:
        """Run the OPTIMADE query against the AiiDA backend.

        Returns:
            A tuple of (`results`, `data_returned`, `more_data_available`,
            `data_available`, fields to exclude from the response).
        """
        self.set_data_available(backend)
        criteria = self._parse_params(params)

        all_fields = criteria.pop("fields")
        if getattr(params, "response_fields", False):
            fields = set(params.response_fields.split(","))
        else:
            fields = all_fields.copy()

        # Filtering on extras requires those extras to exist on all entities.
        if criteria.get("filters", {}) and self._get_extras_filter_fields():
            self._check_and_calculate_entities(backend)

        self.set_data_returned(backend, **criteria)

        entities = self._find_all(backend, **criteria)
        results = []
        for entity in entities:
            results.append(
                self.resource_cls(**self.resource_mapper.map_back(
                    dict(zip(criteria["project"], entity)))))

        if isinstance(params, EntryListingQueryParams):
            criteria_no_limit = criteria.copy()
            criteria_no_limit.pop("limit", None)
            more_data_available = len(results) < self.count(
                backend, **criteria_no_limit)
        else:
            more_data_available = False
            if len(results) > 1:
                raise HTTPException(
                    status_code=404,
                    detail=
                    f"Instead of a single entry, {len(results)} entries were found",
                )

        if isinstance(params, SingleEntryQueryParams):
            results = results[0] if results else None

        return (
            results,
            self.data_returned,
            more_data_available,
            self.data_available,
            all_fields - fields,
        )

    def _alias_filter(self, filters: Any) -> Union[dict, list]:
        """Get aliased field names in nested filter query.

        I.e. turn OPTiMaDe field names into AiiDA field names.
        Also records every aliased key in `self._filter_fields`.
        """
        if isinstance(filters, dict):
            res = {}
            for key, value in filters.items():
                new_value = value
                if isinstance(value, (dict, list)):
                    new_value = self._alias_filter(value)
                aliased_key = self.resource_mapper.alias_for(key)
                res[aliased_key] = new_value
                self._filter_fields.add(aliased_key)
        elif isinstance(filters, list):
            res = []
            for item in filters:
                new_value = item
                if isinstance(item, (dict, list)):
                    new_value = self._alias_filter(item)
                res.append(new_value)
        else:
            raise NotImplementedError(
                "_alias_filter can only handle dict and list objects")
        return res

    def _parse_params(self, params: EntryListingQueryParams) -> dict:
        """Parse query parameters and transform them into AiiDA QueryBuilder concepts"""
        cursor_kwargs = {}

        # filter
        if getattr(params, "filter", False):
            aiida_filter = self.transformer.transform(
                self.parser.parse(params.filter))
            self._filter_fields = set()
            cursor_kwargs["filters"] = self._alias_filter(aiida_filter)

        # response_format
        if (getattr(params, "response_format", False)
                and params.response_format != "json"):
            raise HTTPException(status_code=400,
                                detail="Only 'json' response_format supported")

        # page_limit
        if getattr(params, "page_limit", False):
            limit = self.page_limit
            if params.page_limit != self.page_limit:
                limit = params.page_limit
            if limit > self.db_page_limit:
                raise HTTPException(
                    status_code=403,
                    detail=
                    f"Max allowed page_limit is {self.db_page_limit}, you requested {limit}",
                )
            if limit == 0:
                limit = self.page_limit
            cursor_kwargs["limit"] = limit

        # response_fields
        # All OPTiMaDe fields
        fields = {"id", "type"}
        fields |= self.get_attribute_fields()
        # All provider-specific fields
        fields |= {self.provider + _ for _ in self.provider_fields}
        cursor_kwargs["fields"] = fields
        cursor_kwargs["project"] = list(
            {self.resource_mapper.alias_for(f) for f in fields})

        # sort
        # NOTE: sorting only works for extras fields for the nodes already with calculated extras.
        # To calculate all extras, make a single filter query using any extra field.
        if getattr(params, "sort", False):
            sort_spec = []
            for entity_property in params.sort.split(","):
                field = entity_property
                sort_direction = "asc"
                if entity_property.startswith("-"):
                    field = field[1:]
                    sort_direction = "desc"

                aliased_field = self.resource_mapper.alias_for(field)

                _, properties = retrieve_queryable_properties(
                    self.resource_cls.schema(), {"id", "type", "attributes"})
                field_type = properties[field].get(
                    "format", properties[field].get("type", ""))
                if field_type == "array":
                    raise TypeError(
                        "Cannot sort on a field with a list value type")

                sort_spec.append({
                    aliased_field: {
                        "order": sort_direction,
                        "cast": self.CAST_MAPPING[field_type],
                    }
                })
            cursor_kwargs["order_by"] = sort_spec

        # page_offset
        if getattr(params, "page_offset", False):
            cursor_kwargs["offset"] = params.page_offset

        return cursor_kwargs

    def _get_extras_filter_fields(self) -> set:
        """Return the extras field names (prefix stripped) used in the filter."""
        return {
            field[len(self.resource_mapper.PROJECT_PREFIX):]
            for field in self._filter_fields
            if field.startswith(self.resource_mapper.PROJECT_PREFIX)
        }

    def _check_and_calculate_entities(self,
                                      backend: orm.implementation.Backend):
        """Check all entities have OPTiMaDe extras, else calculate them

        For a bit of optimization, we only care about a field if it has
        specifically been queried for using "filter".
        """
        extras_keys = [
            key for key in self.resource_mapper.PROJECT_PREFIX.split(".")
            if key
        ]
        filter_fields = [{
            "!has_key": field
        } for field in self._get_extras_filter_fields()]
        necessary_entities_qb = orm.QueryBuilder().append(
            self.collection.entity_type,
            filters={
                "or": [
                    {
                        extras_keys[0]: {
                            "!has_key": extras_keys[1]
                        }
                    },
                    {
                        ".".join(extras_keys): {
                            "or": filter_fields
                        }
                    },
                ]
            },
            project="id",
        )

        if necessary_entities_qb.count() > 0:
            # Necessary entities for the OPTiMaDe query exist with unknown OPTiMaDe fields.
            necessary_entity_ids = [
                pk[0] for pk in necessary_entities_qb.iterall()
            ]

            # Create the missing OPTiMaDe fields:
            # All OPTiMaDe fields
            fields = {"id", "type"}
            fields |= self.get_attribute_fields()
            # All provider-specific fields
            fields |= {self.provider + _ for _ in self.provider_fields}
            fields = list({self.resource_mapper.alias_for(f) for f in fields})

            entities = self._find_all(
                backend,
                filters={"id": {
                    "in": necessary_entity_ids
                }},
                project=fields)
            for entity in entities:
                # Instantiating the resource triggers calculation/storage of
                # the missing OPTiMaDe extras via the mapper.
                self.resource_cls(
                    **self.resource_mapper.map_back(dict(zip(fields, entity))))
def test_parser_version(self): v = (0, 9, 5) p = LarkParser(version=v) self.assertIsInstance(p.parse(self.test_filters[0]), Tree) self.assertEqual(p.version, v)
class DjangoTransformer: """Filter transformer for implementations using Django. !!! warning "Warning" This transformer is deprecated as it only supports the 0.9.7 grammar version, and works different to other filter transformers in this package. """ def __init__(self): self.opers = { "=": self.eq, ">": self.gt, ">=": self.ge, "<": self.lt, "<=": self.le, "!=": self.ne, "OR": self.or_, "AND": self.and_, "NOT": self.not_, } self.parser = LarkParser(version=(0, 9, 7)) def parse_raw_q(self, raw_query): return self.parser.parse(raw_query) def eq(self, a, b): return Q(**{a: b}) def gt(self, a, b): return Q(**{a + "__gt": b}) def ge(self, a, b): return Q(**{a + "__gte": b}) def lt(self, a, b): return Q(**{a + "__lt": b}) def le(self, a, b): return Q(**{a + "__lte": b}) def ne(self, a, b): return ~Q(**{a: b}) def not_(self, a): return ~a def and_(self, a, b): return operator.and_(a, b) def or_(self, a, b): return operator.or_(a, b) def evaluate(self, parse_Tree): if isinstance(parse_Tree, Tree): children = parse_Tree.children if len(children) == 1: return self.evaluate(children[0]) elif len(children) == 2: op_fn = self.evaluate(children[0]) return op_fn(self.evaluate(children[1])) elif len(children) == 3: if parse_Tree.data == "comparison": db_prop = self.evaluate(children[0]) op_fn = self.evaluate(children[1]) if db_prop in django_db_keys.keys(): return op_fn(django_db_keys[db_prop], self.evaluate(children[2])) else: raise DjangoQueryError( "Unknown property is queried : " + (db_prop)) else: op_fn = self.evaluate(children[1]) return op_fn(self.evaluate(children[0]), self.evaluate(children[2])) else: raise DjangoQueryError( "Not compatible format. Tree has >3 children") elif isinstance(parse_Tree, Token): if parse_Tree.type == "VALUE": return parse_Tree.value elif parse_Tree.type in ["NOT", "CONJUNCTION", "OPERATOR"]: return self.opers[parse_Tree.value] else: raise DjangoQueryError("Not a Lark Tree or Token")
class EntryCollection(ABC):
    """Backend-agnostic base class for querying collections of
    [`EntryResource`][optimade.models.entries.EntryResource]s."""

    def __init__(
        self,
        resource_cls: EntryResource,
        resource_mapper: BaseResourceMapper,
        transformer: Transformer,
    ):
        """Initialize the collection for the given parameters.

        Parameters:
            resource_cls (EntryResource): The `EntryResource` model
                that is stored by the collection.
            resource_mapper (BaseResourceMapper): A resource mapper
                object that handles aliases and format changes between
                deserialization and response.
            transformer (Transformer): The Lark `Transformer` used to
                interpret the filter.

        """
        self.parser = LarkParser()
        self.resource_cls = resource_cls
        self.resource_mapper = resource_mapper
        self.transformer = transformer

        self.provider_prefix = CONFIG.provider.prefix
        # Config entries may be plain names or {"name": ...} dictionaries.
        self.provider_fields = [
            field if isinstance(field, str) else field["name"] for field in
            CONFIG.provider_fields.get(resource_mapper.ENDPOINT, [])
        ]

        # Lazily-populated cache backing the `all_fields` property.
        self._all_fields: Set[str] = None

    @abstractmethod
    def __len__(self) -> int:
        """Returns the total number of entries in the collection."""

    @abstractmethod
    def insert(self, data: List[EntryResource]) -> None:
        """Add the given entries to the underlying database.

        Arguments:
            data: The entry resource objects to add to the database.

        """

    @abstractmethod
    def count(self, **kwargs) -> int:
        """Returns the number of entries matching the query specified
        by the keyword arguments.

        Parameters:
            kwargs (dict): Query parameters as keyword arguments.

        """

    def find(
        self, params: Union[EntryListingQueryParams, SingleEntryQueryParams]
    ) -> Tuple[Union[List[EntryResource], EntryResource, None], int, bool,
               Set[str], Set[str]]:
        """
        Fetches results and indicates if more data is available.

        Also gives the total number of data available in the absence of `page_limit`.
        See [`EntryListingQueryParams`][optimade.server.query_params.EntryListingQueryParams]
        for more information.

        Parameters:
            params: Entry listing URL query params.

        Returns:
            A tuple of various relevant values:
            (`results`, `data_returned`, `more_data_available`, `exclude_fields`, `include_fields`).

        """
        criteria = self.handle_query_params(params)
        single_entry = isinstance(params, SingleEntryQueryParams)
        response_fields = criteria.pop("fields")

        results, data_returned, more_data_available = self._run_db_query(
            criteria, single_entry)

        if single_entry:
            results = results[0] if results else None

            if data_returned > 1:
                raise NotFound(
                    detail=
                    f"Instead of a single entry, {data_returned} entries were found",
                )

        exclude_fields = self.all_fields - response_fields
        include_fields = (response_fields -
                          self.resource_mapper.TOP_LEVEL_NON_ATTRIBUTES_FIELDS)

        # Requested fields with a *supported* provider prefix only warrant a
        # warning; unknown bare OPTIMADE fields are a 400.  Fields with an
        # unrecognised "_" prefix are ignored entirely.
        bad_optimade_fields = set()
        bad_provider_fields = set()
        for field in include_fields:
            if field not in self.resource_mapper.ALL_ATTRIBUTES:
                if field.startswith("_"):
                    if any(
                            field.startswith(f"_{prefix}_") for prefix in
                            self.resource_mapper.SUPPORTED_PREFIXES):
                        bad_provider_fields.add(field)
                else:
                    bad_optimade_fields.add(field)

        if bad_provider_fields:
            warnings.warn(
                message=
                f"Unrecognised field(s) for this provider requested in `response_fields`: {bad_provider_fields}.",
                category=UnknownProviderProperty,
            )

        if bad_optimade_fields:
            raise BadRequest(
                detail=
                f"Unrecognised OPTIMADE field(s) in requested `response_fields`: {bad_optimade_fields}."
            )

        if results:
            results = self.resource_mapper.deserialize(results)

        return (
            results,
            data_returned,
            more_data_available,
            exclude_fields,
            include_fields,
        )

    @abstractmethod
    def _run_db_query(
            self,
            criteria: Dict[str, Any],
            single_entry: bool = False) -> Tuple[List[Dict[str, Any]], int, bool]:
        """Run the query on the backend and collect the results.

        Arguments:
            criteria: A dictionary representation of the query parameters.
            single_entry: Whether or not the caller is expecting a single entry response.

        Returns:
            The list of entries from the database (without any re-mapping), the
            total number of entries matching the query and a boolean for whether
            or not there is more data available.

        """

    @property
    def all_fields(self) -> Set[str]:
        """Get the set of all fields handled in this collection,
        from attribute fields in the schema, provider fields and top-level OPTIMADE fields.

        The set of all fields are lazily created and then cached.
        This means the set is created the first time the property is requested and then cached.

        Returns:
            All fields handled in this collection.

        """
        if not self._all_fields:
            # All OPTIMADE fields
            self._all_fields = (
                self.resource_mapper.TOP_LEVEL_NON_ATTRIBUTES_FIELDS.copy())
            self._all_fields |= self.get_attribute_fields()
            # All provider-specific fields
            self._all_fields |= {
                f"_{self.provider_prefix}_{field_name}"
                for field_name in self.provider_fields
            }

        return self._all_fields

    def get_attribute_fields(self) -> Set[str]:
        """Get the set of attribute fields

        Return only the _first-level_ attribute fields from the schema of the resource class,
        resolving references along the way if needed.

        Note:
            It is not needed to take care of other special OpenAPI schema keys than `allOf`,
            since only `allOf` will be found in this context.
            Other special keys can be found in [the Swagger documentation](https://swagger.io/docs/specification/data-models/oneof-anyof-allof-not/).

        Returns:
            Property names.

        """
        schema = self.resource_cls.schema()
        attributes = schema["properties"]["attributes"]
        if "allOf" in attributes:
            allOf = attributes.pop("allOf")
            for dict_ in allOf:
                attributes.update(dict_)
        if "$ref" in attributes:
            # Follow a local JSON reference, e.g. "#/definitions/Foo".
            path = attributes["$ref"].split("/")[1:]
            attributes = schema.copy()
            while path:
                next_key = path.pop(0)
                attributes = attributes[next_key]
        return set(attributes["properties"].keys())

    def handle_query_params(
        self, params: Union[EntryListingQueryParams, SingleEntryQueryParams]
    ) -> Dict[str, Any]:
        """Parse and interpret the backend-agnostic query parameter models into a dictionary
        that can be used by the specific backend.

        Note:
            Currently this method returns the pymongo interpretation of the parameters,
            which will need modification for modified for other backends.

        Parameters:
            params (Union[EntryListingQueryParams, SingleEntryQueryParams]): The initialized query parameter model from the server.

        Raises:
            Forbidden: If too large of a page limit is provided.
            BadRequest: If an invalid request is made, e.g., with incorrect fields
                or response format.

        Returns:
            A dictionary representation of the query parameters.

        """
        cursor_kwargs = {}

        # filter
        if getattr(params, "filter", False):
            cursor_kwargs["filter"] = self.transformer.transform(
                self.parser.parse(params.filter))
        else:
            cursor_kwargs["filter"] = {}

        # response_format
        if (getattr(params, "response_format", False)
                and params.response_format != "json"):
            raise BadRequest(
                detail=
                f"Response format {params.response_format} is not supported, please use response_format='json'"
            )

        # page_limit
        if getattr(params, "page_limit", False):
            limit = params.page_limit
            if limit > CONFIG.page_limit_max:
                raise Forbidden(
                    detail=
                    f"Max allowed page_limit is {CONFIG.page_limit_max}, you requested {limit}",
                )
            cursor_kwargs["limit"] = limit
        else:
            cursor_kwargs["limit"] = CONFIG.page_limit

        # response_fields
        # Project all known fields under their backend names; suppress Mongo's
        # implicit `_id` unless it is an aliased field.
        cursor_kwargs["projection"] = {
            f"{self.resource_mapper.get_backend_field(f)}": True
            for f in self.all_fields
        }
        if "_id" not in cursor_kwargs["projection"]:
            cursor_kwargs["projection"]["_id"] = False

        if getattr(params, "response_fields", False):
            response_fields = set(params.response_fields.split(","))
            response_fields |= self.resource_mapper.get_required_fields()
        else:
            response_fields = self.all_fields.copy()

        cursor_kwargs["fields"] = response_fields

        # sort
        if getattr(params, "sort", False):
            cursor_kwargs["sort"] = self.parse_sort_params(params.sort)

        # page_offset
        if getattr(params, "page_offset", False):
            cursor_kwargs["skip"] = params.page_offset

        return cursor_kwargs

    def parse_sort_params(self, sort_params: str) -> Tuple[Tuple[str, int]]:
        """Handles any sort parameters passed to the collection,
        resolving aliases and dealing with any invalid fields.

        Raises:
            BadRequest: if an invalid sort is requested.

        Returns:
            A tuple of tuples containing the aliased field name and
            sort direction encoded as 1 (ascending) or -1 (descending).

        """
        # A leading "-" requests a descending sort on that field.
        sort_spec = []
        for field in sort_params.split(","):
            sort_dir = 1
            if field.startswith("-"):
                field = field[1:]
                sort_dir = -1
            aliased_field = self.resource_mapper.get_backend_field(field)
            sort_spec.append((aliased_field, sort_dir))

        unknown_fields = [
            field for field, _ in sort_spec
            if self.resource_mapper.get_optimade_field(field) not in
            self.all_fields
        ]

        if unknown_fields:
            error_detail = "Unable to sort on unknown field{} '{}'".format(
                "s" if len(unknown_fields) > 1 else "",
                "', '".join(unknown_fields),
            )

            # If all unknown fields are "other" provider-specific, then only provide a warning
            if all((re.match(r"_[a-z_0-9]+_[a-z_0-9]*", field)
                    and not field.startswith(f"_{self.provider_prefix}_"))
                   for field in unknown_fields):
                warnings.warn(error_detail, FieldValueNotRecognized)

            # Otherwise, if all fields are unknown, or some fields are unknown and do not
            # have other provider prefixes, then return 400: Bad Request
            else:
                raise BadRequest(detail=error_detail)

            # If at least one valid field has been provided for sorting, then use that
            sort_spec = tuple((field, sort_dir)
                              for field, sort_dir in sort_spec
                              if field not in unknown_fields)

        return sort_spec
def test_suspected_timestamp_fields(self, mapper): import datetime import bson.tz_util from optimade.filtertransformers.mongo import MongoTransformer from optimade.server.warnings import TimestampNotRFCCompliant example_RFC3339_date = "2019-06-08T04:13:37Z" example_RFC3339_date_2 = "2019-06-08T04:13:37" example_non_RFC3339_date = "2019-06-08T04:13:37.123Z" expected_datetime = datetime.datetime( year=2019, month=6, day=8, hour=4, minute=13, second=37, microsecond=0, tzinfo=bson.tz_util.utc, ) assert self.transform(f'last_modified > "{example_RFC3339_date}"') == { "last_modified": { "$gt": expected_datetime } } assert self.transform( f'last_modified > "{example_RFC3339_date_2}"') == { "last_modified": { "$gt": expected_datetime } } non_rfc_datetime = expected_datetime.replace(microsecond=123000) with pytest.warns(TimestampNotRFCCompliant): assert self.transform( f'last_modified > "{example_non_RFC3339_date}"') == { "last_modified": { "$gt": non_rfc_datetime } } class MyMapper(mapper("StructureMapper")): ALIASES = (("last_modified", "ctime"), ) transformer = MongoTransformer(mapper=MyMapper) parser = LarkParser(version=self.version, variant=self.variant) assert transformer.transform( parser.parse(f'last_modified > "{example_RFC3339_date}"')) == { "ctime": { "$gt": expected_datetime } } assert transformer.transform( parser.parse(f'last_modified > "{example_RFC3339_date_2}"')) == { "ctime": { "$gt": expected_datetime } }
def test_parser_version(self): v = (0, 9, 5) p = LarkParser(version=v) assert isinstance(p.parse(self.test_filters[0]), Tree) assert p.version == v
def test_aliases(self, mapper): """Test that valid aliases are allowed, but do not affect r-values.""" from optimade.filtertransformers.mongo import MongoTransformer class MyStructureMapper(mapper("StructureMapper")): ALIASES = ( ("elements", "my_elements"), ("A", "D"), ("property_a", "D"), ("B", "E"), ("C", "F"), ("_exmpl_nested_field", "nested_field"), ) PROVIDER_FIELDS = ("D", "E", "F", "nested_field") mapper = MyStructureMapper t = MongoTransformer(mapper=mapper) p = LarkParser(version=self.version, variant=self.variant) assert mapper.get_backend_field("elements") == "my_elements" test_filter = 'elements HAS "A"' assert t.transform(p.parse(test_filter)) == { "my_elements": { "$in": ["A"] } } test_filter = 'elements HAS ANY "A","B","C" AND elements HAS "D"' assert t.transform(p.parse(test_filter)) == { "$and": [ { "my_elements": { "$in": ["A", "B", "C"] } }, { "my_elements": { "$in": ["D"] } }, ] } test_filter = 'elements = "A"' assert t.transform(p.parse(test_filter)) == { "my_elements": { "$eq": "A" } } test_filter = 'property_a HAS "B"' assert t.transform(p.parse(test_filter)) == {"D": {"$in": ["B"]}} test_filter = "_exmpl_nested_field.sub_property > 1234.5" assert t.transform(p.parse(test_filter)) == { "nested_field.sub_property": { "$gt": 1234.5 } } test_filter = "_exmpl_nested_field.sub_property.x IS UNKNOWN" assert t.transform(p.parse(test_filter)) == { "$or": [ { "nested_field.sub_property.x": { "$exists": False } }, { "nested_field.sub_property.x": { "$eq": None } }, ] }
class TestTransformer:
    """Tests for the elasticsearch filter transformer (v0.10.0 grammar)."""

    @pytest.fixture(autouse=True)
    def set_up(self):
        # Build a transformer over a small set of quantities mirroring the
        # OPTIMADE structure fields exercised by the queries below.
        from optimade.filtertransformers.elasticsearch import Transformer, Quantity

        self.parser = LarkParser(version=(0, 10, 0), variant="elastic")

        nelements = Quantity(name="nelements")
        elements_only = Quantity(name="elements_only")
        elements_ratios = Quantity(name="elements_ratios")
        # elements_ratios nests onto itself, enabling zip-style
        # "elements:elements_ratios" queries.
        elements_ratios.nested_quantity = elements_ratios
        elements = Quantity(
            name="elements",
            length_quantity=nelements,
            has_only_quantity=elements_only,
            nested_quantity=elements_ratios,
        )
        dimension_types = Quantity(name="dimension_types")
        # LENGTH dimension_types resolves onto the quantity itself.
        dimension_types.length_quantity = dimension_types

        quantities = [
            nelements,
            elements_only,
            elements_ratios,
            elements,
            dimension_types,
            Quantity(name="chemical_formula_reduced"),
        ]

        self.transformer = Transformer(quantities=quantities)

    def test_parse_n_transform(self):
        # (filter string, expected hit count). Only parse/transform success is
        # asserted here; the counts document the expected backend results.
        queries = [
            ("nelements > 1", 4),
            ("nelements >= 2", 4),
            ("nelements > 2", 1),
            ("nelements < 4", 4),
            ("nelements < 3", 3),
            ("nelements <= 3", 4),
            ("nelements != 2", 1),
            ("1 < nelements", 4),
            ('elements HAS "H"', 4),
            ('elements HAS ALL "H", "O"', 4),
            ('elements HAS ALL "H", "C"', 1),
            ('elements HAS ANY "H", "C"', 4),
            ('elements HAS ANY "C"', 1),
            ('elements HAS ONLY "C"', 0),
            ('elements HAS ONLY "H", "O"', 3),
            ('elements:elements_ratios HAS "H":>0.66', 2),
            ('elements:elements_ratios HAS ALL "O":>0.33', 3),
            ('elements:elements_ratios HAS ALL "O":>0.33,"O":<0.34', 2),
            ("elements IS KNOWN", 4),
            ("elements IS UNKNOWN", 0),
            ('chemical_formula_reduced = "H2O"', 2),
            ('chemical_formula_reduced CONTAINS "H2"', 3),
            ('chemical_formula_reduced CONTAINS "H"', 4),
            ('chemical_formula_reduced CONTAINS "C"', 1),
            ('chemical_formula_reduced STARTS "H2"', 3),
            ('chemical_formula_reduced STARTS WITH "H2"', 3),
            ('chemical_formula_reduced ENDS WITH "C"', 1),
            ('chemical_formula_reduced ENDS "C"', 1),
            ("LENGTH elements = 2", 3),
            ("LENGTH elements = 3", 1),
            ("LENGTH dimension_types = 0", 3),
            ("LENGTH dimension_types = 1", 1),
            ("nelements = 2 AND LENGTH dimension_types = 1", 1),
            ("nelements = 3 AND LENGTH dimension_types = 1", 0),
            ("nelements = 3 OR LENGTH dimension_types = 1", 2),
            ("nelements > 1 OR LENGTH dimension_types = 1 AND nelements = 2",
             4),
            ("(nelements > 1 OR LENGTH dimension_types = 1) AND nelements = 2",
             3),
            ("NOT LENGTH dimension_types = 1", 3),
        ]

        for query, _ in queries:
            tree = self.parser.parse(query)
            result = self.transformer.transform(tree)
            assert result is not None
f = "list HAS < 3, > 4" # -> error f = "list HAS ALL < 3, > 4" # multiple lists f = "list1:list2 HAS < 3 : > 4" f = "list1:list2 HAS ALL < 3 : > 4" f = "list1:list2 HAS < 3 : > 4, < 2 : > 5" # -> error f = "list1:list2 HAS ALL < 3 : > 4, < 2 : > 5" f = "list1:list2 HAS ALL < 3, < 2 : > 4, > 5" # -> error # f = 'list1:list2 HAS < 3, > 4' # -> error # f = 'list1:list2 HAS ALL < 3, > 4' # -> error f = 'elements:elements_ratios HAS ALL "Al":>0.3333, "Al":<0.3334' f = 'elements:elements_ratios HAS ALL "Al":>0.3333 AND elements_ratio<0.3334' f = 'elements:elements_ratios HAS ALL "Al" : >0.3333, <0.3334' # -> error f = "list1:list2 HAS ALL < 3 : > 4, < 2 : > 5 : > 4, < 2 : > 5" # valid but wrong f = "ghf.flk<gh" # valid but wrong # f = '' tree = p.parse(f) print(tree) print(tree.pretty()) t.transform(tree)
def test_aliased_length_operator(): """Test LENGTH operator alias""" from optimade.server.mappers import StructureMapper class MyMapper(StructureMapper): """Test mapper with LENGTH_ALIASES""" ALIASES = (("elements", "my_elements"), ("nelements", "nelem")) LENGTH_ALIASES = ( ("chemsys", "nelements"), ("cartesian_site_positions", "nsites"), ("elements", "nelements"), ) PROVIDER_FIELDS = ("chemsys", ) transformer = AiidaTransformer(mapper=MyMapper()) parser = LarkParser(version=VERSION, variant=VARIANT) assert transformer.transform( parser.parse("cartesian_site_positions LENGTH <= 3")) == { "nsites": { "<=": 3 } } assert transformer.transform( parser.parse("cartesian_site_positions LENGTH < 3")) == { "nsites": { "<": 3 } } assert transformer.transform( parser.parse("cartesian_site_positions LENGTH 3")) == ({ "nsites": 3 }) assert transformer.transform( parser.parse("cartesian_site_positions LENGTH 3")) == ({ "nsites": 3 }) assert transformer.transform( parser.parse("cartesian_site_positions LENGTH >= 10")) == { "nsites": { ">=": 10 } } assert transformer.transform( parser.parse("structure_features LENGTH > 10")) == ({ "structure_features": { "longer": 10 } }) assert transformer.transform(parser.parse("nsites LENGTH > 10")) == ({ "nsites": { "longer": 10 } }) assert transformer.transform(parser.parse("elements LENGTH 3")) == { "nelem": 3 } assert transformer.transform(parser.parse('elements HAS "Ag"')) == ({ "my_elements": { "contains": ["Ag"] } }) assert transformer.transform(parser.parse("chemsys LENGTH 3")) == { "nelem": 3 }