コード例 #1
0
    def test_mongo_special_id(self, mapper):
        """An ``immutable_id`` aliased to ``_id`` is compared as a BSON ``ObjectId``.

        Only ``=`` and ``!=`` are accepted on that field; the string/list
        operators must be rejected with ``BadRequest``.
        """
        from optimade.filtertransformers.mongo import MongoTransformer
        from bson import ObjectId

        class MyMapper(mapper("StructureMapper")):
            ALIASES = (("immutable_id", "_id"),)

        transformer = MongoTransformer(mapper=MyMapper())
        parser = LarkParser(version=self.version, variant=self.variant)

        object_id = "5cfb441f053b174410700d02"
        for filter_op, mongo_op in (("=", "$eq"), ("!=", "$ne")):
            query = parser.parse(f'immutable_id {filter_op} "{object_id}"')
            assert transformer.transform(query) == {
                "_id": {mongo_op: ObjectId(object_id)}
            }

        equality_only = (
            r".*not supported for query on field 'immutable_id', can only test for equality.*"
        )
        for op in ("CONTAINS", "STARTS WITH", "ENDS WITH", "HAS"):
            with pytest.raises(BadRequest, match=equality_only):
                transformer.transform(parser.parse(f'immutable_id {op} "abcdef"'))
コード例 #2
0
class TestParserV0_9_5:
    """Parse every ``*.inp`` fixture in ``testfile_dir`` with grammar v0.9.5."""

    @pytest.fixture(autouse=True)
    def set_up(self):
        # Fixture filters are read in sorted file-name order.
        pattern = os.path.join(testfile_dir, "*.inp")
        self.test_filters = []
        for path in sorted(glob(pattern)):
            with open(path) as handle:
                self.test_filters.append(handle.read().strip())
        self.parser = LarkParser(version=(0, 9, 5))

    def test_inputs(self):
        # "filter=number=0.0.1" is the one fixture that must be rejected.
        invalid_filter = "filter=number=0.0.1"
        for tf in self.test_filters:
            if tf != invalid_filter:
                assert isinstance(self.parser.parse(tf), Tree)
                continue
            with pytest.raises(ParserError):
                self.parser.parse(tf)

    def test_parser_version(self):
        expected_version = (0, 9, 5)
        fresh_parser = LarkParser(version=expected_version)
        assert isinstance(fresh_parser.parse(self.test_filters[0]), Tree)
        assert fresh_parser.version == expected_version

    def test_repr(self):
        # repr() must work both before and after a parse has happened.
        before = repr(self.parser)
        assert before is not None
        self.parser.parse(self.test_filters[0])
        after = repr(self.parser)
        assert after is not None
コード例 #3
0
class ParserTestV0_9_5(unittest.TestCase):
    """unittest-based checks of ``LarkParser`` with grammar v0.9.5."""

    @classmethod
    def setUpClass(cls):
        # Load every ``*.inp`` fixture once per class, in sorted order.
        cls.test_filters = []
        for fn in sorted(glob(os.path.join(testfile_dir, "*.inp"))):
            with open(fn) as f:
                cls.test_filters.append(f.read().strip())

    def setUp(self):
        self.parser = LarkParser(version=(0, 9, 5))

    def test_inputs(self):
        """Every fixture parses to a ``Tree`` except the known-invalid one."""
        for tf in self.test_filters:
            if tf == "filter=number=0.0.1":
                self.assertRaises(ParserError, self.parser.parse, tf)
            else:
                tree = self.parser.parse(tf)
                # assertTrue's second argument is only the failure message,
                # so assertTrue(tree, Tree) never checked the type.
                self.assertIsInstance(tree, Tree)

    def test_parser_version(self):
        """A freshly built parser parses a fixture and reports its version."""
        v = (0, 9, 5)
        p = LarkParser(version=v)
        self.assertIsInstance(p.parse(self.test_filters[0]), Tree)
        self.assertEqual(p.version, v)

    def test_repr(self):
        """``repr`` works before and after the parser has been used."""
        self.assertIsNotNone(repr(self.parser))
        self.parser.parse(self.test_filters[0])
        self.assertIsNotNone(repr(self.parser))
コード例 #4
0
    def test_aliased_length_operator(self, mapper):
        """Check LENGTH filters against a mapper with ALIASES and LENGTH_ALIASES.

        Fields listed in ``LENGTH_ALIASES`` become comparisons on their
        (possibly aliased) length field. Fields without a length alias (here
        ``structure_features`` and ``nsites``) fall back to an ``$exists``
        test on the array index one past the bound.
        """
        from optimade.filtertransformers.mongo import MongoTransformer

        class MyMapper(mapper("StructureMapper")):
            ALIASES = (("elements", "my_elements"), ("nelements", "nelem"))
            LENGTH_ALIASES = (
                ("chemsys", "nelements"),
                ("cartesian_site_positions", "nsites"),
                ("elements", "nelements"),
            )
            PROVIDER_FIELDS = ("chemsys",)

        transformer = MongoTransformer(mapper=MyMapper())
        parser = LarkParser(version=self.version, variant=self.variant)

        assert transformer.transform(
            parser.parse("cartesian_site_positions LENGTH <= 3")
        ) == {"nsites": {"$lte": 3}}
        assert transformer.transform(
            parser.parse("cartesian_site_positions LENGTH < 3")
        ) == {"nsites": {"$lt": 3}}
        assert transformer.transform(
            parser.parse("cartesian_site_positions LENGTH 3")
        ) == {"nsites": 3}
        # The explicit "=" form must be equivalent to the bare "LENGTH 3" form.
        assert transformer.transform(
            parser.parse("cartesian_site_positions LENGTH = 3")
        ) == {"nsites": 3}
        assert transformer.transform(
            parser.parse("cartesian_site_positions LENGTH >= 10")
        ) == {"nsites": {"$gte": 10}}

        assert transformer.transform(
            parser.parse("structure_features LENGTH > 10")
        ) == {"structure_features.11": {"$exists": True}}

        assert transformer.transform(parser.parse("nsites LENGTH > 10")) == {
            "nsites.11": {"$exists": True}
        }

        # "nelements" is itself aliased to "nelem" by ALIASES.
        assert transformer.transform(parser.parse("elements LENGTH 3")) == {"nelem": 3}

        assert transformer.transform(parser.parse('elements HAS "Ag"')) == {
            "my_elements": {"$in": ["Ag"]}
        }

        assert transformer.transform(parser.parse("chemsys LENGTH 3")) == {"nelem": 3}
コード例 #5
0
    def test_list_length_aliases(self, mapper):
        """LENGTH on list fields is rewritten to the ``nelements``/``nsites`` fields."""
        from optimade.filtertransformers.mongo import MongoTransformer

        transformer = MongoTransformer(mapper=mapper("StructureMapper")())
        parser = LarkParser(version=self.version, variant=self.variant)

        cases = [
            ("elements LENGTH 3", {"nelements": 3}),
            (
                'elements HAS "Li" AND elements LENGTH = 3',
                {"$and": [{"elements": {"$in": ["Li"]}}, {"nelements": 3}]},
            ),
            ("elements LENGTH > 3", {"nelements": {"$gt": 3}}),
            ("elements LENGTH < 3", {"nelements": {"$lt": 3}}),
            ("elements LENGTH = 3", {"nelements": 3}),
            ("cartesian_site_positions LENGTH <= 3", {"nsites": {"$lte": 3}}),
            ("cartesian_site_positions LENGTH >= 3", {"nsites": {"$gte": 3}}),
        ]
        for filter_string, expected in cases:
            assert transformer.transform(parser.parse(filter_string)) == expected
コード例 #6
0
    def test_other_provider_fields(self, mapper):
        """Fields prefixed by another provider are passed through unchanged.

        The comparison on ``_other_provider_field`` is emitted as-is.
        Presumably documents lacking that field then fail to match, so the
        field effectively behaves as ``null``.
        """
        from optimade.filtertransformers.mongo import MongoTransformer

        transformer = MongoTransformer(mapper=mapper("StructureMapper"))
        parser = LarkParser(version=self.version, variant=self.variant)

        query = transformer.transform(parser.parse("_other_provider_field > 1"))
        assert query == {"_other_provider_field": {"$gt": 1}}
コード例 #7
0
    def __init__(
        self,
        host="localhost",
        port="27017",
        username="admin",
        password="admin",
        auth_source="admin",
        db_name="MaterialsDB",
        collection_name="Data.Calculation.StaticCalculation",
    ):
        """Set up the filter pipeline, the MongoDB collection and lookup data.

        The connection settings used to be hard-coded. They are now keyword
        parameters whose defaults reproduce the previous values, so existing
        no-argument callers behave the same while deployments can supply real
        credentials without editing the source.

        Args:
            host: MongoDB host name.
            port: MongoDB port, interpolated into the connection URI.
            username: user name; percent-escaped before going into the URI.
            password: password; percent-escaped before going into the URI.
            auth_source: database used for authentication (``authSource``).
            db_name: database that holds the collection.
            collection_name: collection stored on ``self.cl``.
        """
        from urllib.parse import quote_plus

        p = LarkParser(version=(1, 0, 0), variant="default")
        t = MongoTransformer()
        # Parse an OPTIMADE filter string and turn it into a MongoDB query.
        self.transform = lambda inp: t.transform(p.parse(inp))

        # Reserved characters in credentials (e.g. "@", ":" or "/") would
        # corrupt the URI, so they are percent-escaped.
        uri = "mongodb://{}:{}@{}:{}/?authSource={}".format(
            quote_plus(username), quote_plus(password), host, port, auth_source
        )
        client = MongoClient(uri)
        self.cl = client[db_name][collection_name]

        self.lu = Lower2Upper()

        self.data = info1
        self.info = info2
コード例 #8
0
    def test_list_length_aliases(self):
        """LENGTH on list fields maps onto the StructureMapper length aliases."""
        from optimade.server.mappers import StructureMapper

        transformer = MongoTransformer(mapper=StructureMapper())
        parser = LarkParser(version=self.version, variant=self.variant)

        expectations = (
            ("elements LENGTH 3", {"nelements": 3}),
            (
                'elements HAS "Li" AND elements LENGTH = 3',
                {"$and": [{"elements": {"$in": ["Li"]}}, {"nelements": 3}]},
            ),
            ("elements LENGTH > 3", {"nelements": {"$gt": 3}}),
            ("elements LENGTH < 3", {"nelements": {"$lt": 3}}),
            ("elements LENGTH = 3", {"nelements": 3}),
            ("cartesian_site_positions LENGTH <= 3", {"nsites": {"$lte": 3}}),
            ("cartesian_site_positions LENGTH >= 3", {"nsites": {"$gte": 3}}),
        )
        for filter_string, expected in expectations:
            self.assertEqual(
                transformer.transform(parser.parse(filter_string)), expected
            )
コード例 #9
0
ファイル: test_aiida.py プロジェクト: csadorf/aiida-optimade
def test_list_length_aliases():
    """LENGTH filters on list fields use the nelements/nsites aliases in AiiDA queries."""
    from optimade.server.mappers import StructureMapper

    transformer = AiidaTransformer(mapper=StructureMapper())
    parser = LarkParser(version=VERSION, variant=VARIANT)

    expected_queries = [
        ("elements LENGTH 3", {"nelements": 3}),
        (
            'elements HAS "Li" AND elements LENGTH = 3',
            {"and": [{"elements": {"contains": ["Li"]}}, {"nelements": 3}]},
        ),
        ("elements LENGTH > 3", {"nelements": {">": 3}}),
        ("elements LENGTH < 3", {"nelements": {"<": 3}}),
        ("elements LENGTH = 3", {"nelements": 3}),
        ("cartesian_site_positions LENGTH <= 3", {"nsites": {"<=": 3}}),
        ("cartesian_site_positions LENGTH >= 3", {"nsites": {">=": 3}}),
    ]
    for filter_string, expected in expected_queries:
        assert transformer.transform(parser.parse(filter_string)) == expected
コード例 #10
0
class BaseTestFilterParser(abc.ABC):
    """Shared tests for one grammar version/variant parsed with `LarkParser`.

    Subclasses set ``version`` and may override ``variant``.
    """

    version: Tuple[int, int, int]
    variant: str = "default"

    @pytest.fixture(autouse=True)
    def set_up(self):
        self.parser = LarkParser(version=self.version, variant=self.variant)

    def test_repr(self):
        # repr() must be available both before and after a parse.
        before = repr(self.parser)
        assert before is not None
        self.parse("band_gap = 1")
        after = repr(self.parser)
        assert after is not None

    def parse(self, inp):
        """Parse *inp* with the fixture-provided parser and return the tree."""
        parser = self.parser
        return parser.parse(inp)

    def test_parser_version(self):
        parser = self.parser
        assert parser.version == self.version
        assert parser.variant == self.variant
コード例 #11
0
    def test_filtering_on_relationships(self, mapper):
        """Relationship-named nested fields are redirected to ``relationships``.

        Filters on ``references.id``/``structures.id`` become queries on
        ``relationships.<name>.data.id``. ``HAS ONLY`` additionally pins the
        ``$size`` of ``relationships.<name>.data``.
        """
        from optimade.filtertransformers.mongo import MongoTransformer

        t = MongoTransformer(mapper=mapper("StructureMapper"))
        p = LarkParser(version=self.version, variant=self.variant)

        has_only_dummy = {
            "$and": [
                {"relationships.structures.data": {"$size": 1}},
                {"relationships.structures.data.id": {"$all": ["dummy/2019"]}},
            ]
        }
        cases = [
            (
                'references.id HAS "dummy/2019"',
                {"relationships.references.data.id": {"$in": ["dummy/2019"]}},
            ),
            (
                'structures.id HAS ANY "dummy/2019", "dijkstra1968"',
                {
                    "relationships.structures.data.id": {
                        "$in": ["dummy/2019", "dijkstra1968"]
                    }
                },
            ),
            (
                'structures.id HAS ALL "dummy/2019", "dijkstra1968"',
                {
                    "relationships.structures.data.id": {
                        "$all": ["dummy/2019", "dijkstra1968"]
                    }
                },
            ),
            ('structures.id HAS ONLY "dummy/2019"', has_only_dummy),
            (
                'structures.id HAS ONLY "dummy/2019" AND structures.id HAS "dummy/2019"',
                {
                    "$and": [
                        has_only_dummy,
                        {"relationships.structures.data.id": {"$in": ["dummy/2019"]}},
                    ]
                },
            ),
        ]
        for filter_string, expected in cases:
            assert t.transform(p.parse(filter_string)) == expected
コード例 #12
0
class MongoCollection(EntryCollection):
    """Entry collection backed by a pymongo or mongomock collection.

    OPTIMADE query parameters are translated into ``find``/``count_documents``
    keyword arguments. The ``filter`` string is parsed with ``LarkParser``
    and converted by ``MongoTransformer``. Field names in the filter,
    projection and sort are mapped to backend names with
    ``resource_mapper.alias_for``.
    """

    def __init__(
        self,
        collection: Union[pymongo.collection.Collection,
                          mongomock.collection.Collection],
        resource_cls: EntryResource,
        resource_mapper: ResourceMapper,
    ):
        super().__init__(collection, resource_cls, resource_mapper)
        self.transformer = MongoTransformer()

        self.provider = CONFIG.provider["prefix"]
        # Provider-specific fields configured for this endpoint (may be empty).
        self.provider_fields = CONFIG.provider_fields.get(
            resource_mapper.ENDPOINT, [])
        # The MongoTransformer only supports v0.10.1 as the latest grammar
        self.parser = LarkParser(version=(0, 10, 1), variant="default")

    def __len__(self):
        # Uses estimated_document_count(): an estimate, not a filtered count.
        return self.collection.estimated_document_count()

    def __contains__(self, entry):
        return self.collection.count_documents(entry.dict()) > 0

    def count(self, **kwargs):
        """Count documents, forwarding only kwargs ``count_documents`` accepts.

        Other keys (e.g. ``projection`` or ``sort``) are dropped. An empty
        ``filter`` is supplied when none is given because
        ``count_documents()`` requires one.
        """
        allowed = ("filter", "skip", "limit", "hint", "maxTimeMS")
        count_kwargs = {key: value for key, value in kwargs.items() if key in allowed}
        count_kwargs.setdefault("filter", {})
        return self.collection.count_documents(**count_kwargs)

    def find(
        self, params: Union[EntryListingQueryParams, SingleEntryQueryParams]
    ) -> Tuple[List[EntryResource], NonnegativeInt, bool, set]:
        """Run the query described by *params*.

        Returns:
            ``(results, data_returned, more_data_available, excluded_fields)``.
            For ``SingleEntryQueryParams`` ``results`` is a single resource or
            ``None`` rather than a list. ``excluded_fields`` are the known
            fields not requested via ``response_fields``.

        Raises:
            HTTPException: 404 if a single-entry query matches more than one
                document, plus anything raised by ``_parse_params``.
        """
        criteria = self._parse_params(params)

        all_fields = criteria.pop("fields")
        if getattr(params, "response_fields", False):
            fields = set(params.response_fields.split(","))
        else:
            fields = all_fields.copy()

        results = [
            self.resource_cls(**self.resource_mapper.map_back(doc))
            for doc in self.collection.find(**criteria)
        ]

        nresults_now = len(results)
        if isinstance(params, EntryListingQueryParams):
            # Re-count with the page limit removed. Note that ``skip``
            # (page_offset) is still forwarded to count().
            criteria_nolimit = criteria.copy()
            criteria_nolimit.pop("limit", None)
            data_returned = self.count(**criteria_nolimit)
            more_data_available = nresults_now < data_returned
        else:
            # SingleEntryQueryParams, e.g., /structures/{entry_id}
            data_returned = nresults_now
            more_data_available = False
            if nresults_now > 1:
                raise HTTPException(
                    status_code=404,
                    detail=
                    f"Instead of a single entry, {nresults_now} entries were found",
                )
            results = results[0] if results else None

        return results, data_returned, more_data_available, all_fields - fields

    def _alias_filter(self, filter_: dict) -> dict:
        """Return a copy of *filter_* with every key passed through ``alias_for``.

        ``$and``/``$or`` values are lists of sub-filters and are processed
        element-wise. Other dict values are processed recursively. Operator
        keys such as ``$gt`` are also passed to ``alias_for`` (presumably it
        returns unknown names unchanged).
        """
        res = {}
        for key, value in filter_.items():
            if key in ["$and", "$or"]:
                res[key] = [self._alias_filter(item) for item in value]
            else:
                new_value = value
                if isinstance(value, dict):
                    new_value = self._alias_filter(value)
                res[self.resource_mapper.alias_for(key)] = new_value
        return res

    def _parse_sort_params(self, sort: str) -> List[Tuple[str, int]]:
        """Turn an OPTIMADE ``sort`` string into a pymongo sort specification.

        Each comma-separated field sorts ascending (1), or descending (-1)
        when prefixed with ``-``. Field names are mapped to their backend
        aliases, the same way as the filter and projection. Previously raw
        OPTIMADE names were used, which broke sorting on aliased fields.
        """
        sort_spec = []
        for elt in sort.split(","):
            field = elt
            sort_dir = 1
            if elt.startswith("-"):
                field = elt[1:]
                sort_dir = -1
            sort_spec.append((self.resource_mapper.alias_for(field), sort_dir))
        return sort_spec

    def _parse_params(
        self, params: Union[EntryListingQueryParams,
                            SingleEntryQueryParams]) -> dict:
        """Translate OPTIMADE query parameters into pymongo keyword arguments.

        The result always has ``filter``, ``limit``, ``fields`` (the set of
        OPTIMADE field names, popped again by ``find``) and ``projection``.
        It has ``sort`` and ``skip`` only when those parameters were given.

        Raises:
            HTTPException: 400 for a non-JSON ``response_format``; 403 when
                ``page_limit`` exceeds ``CONFIG.page_limit_max``.
        """
        cursor_kwargs = {}

        if getattr(params, "filter", False):
            tree = self.parser.parse(params.filter)
            mongo_filter = self.transformer.transform(tree)
            cursor_kwargs["filter"] = self._alias_filter(mongo_filter)
        else:
            cursor_kwargs["filter"] = {}

        if (getattr(params, "response_format", False)
                and params.response_format != "json"):
            raise HTTPException(status_code=400,
                                detail="Only 'json' response_format supported")

        if getattr(params, "page_limit", False):
            limit = params.page_limit
            if limit > CONFIG.page_limit_max:
                raise HTTPException(
                    status_code=403,  # Forbidden
                    detail=
                    f"Max allowed page_limit is {CONFIG.page_limit_max}, you requested {limit}",
                )
            cursor_kwargs["limit"] = limit
        else:
            cursor_kwargs["limit"] = CONFIG.page_limit

        # All OPTiMaDe fields
        fields = self.resource_mapper.TOP_LEVEL_NON_ATTRIBUTES_FIELDS.copy()
        fields |= self.get_attribute_fields()
        # All provider-specific fields
        fields |= {self.provider + field for field in self.provider_fields}
        cursor_kwargs["fields"] = fields
        cursor_kwargs["projection"] = [
            self.resource_mapper.alias_for(f) for f in fields
        ]

        if getattr(params, "sort", False):
            cursor_kwargs["sort"] = self._parse_sort_params(params.sort)

        if getattr(params, "page_offset", False):
            cursor_kwargs["skip"] = params.page_offset

        return cursor_kwargs
コード例 #13
0
    def test_filtering_on_relationships(self, mapper):
        """Test the nested properties with special names
        like "structures", "references" etc. are applied
        to the relationships field.

        ``HAS ONLY`` becomes a ``$not``/``$elemMatch``/``$nin`` test plus an
        existence check on element 0. Filtering a relationship on any
        sub-field other than ``id`` must raise ``NotImplementedError``.
        """
        import re

        from optimade.filtertransformers.mongo import MongoTransformer

        t = MongoTransformer(mapper=mapper("StructureMapper"))
        p = LarkParser(version=self.version, variant=self.variant)

        assert t.transform(p.parse('references.id HAS "dummy/2019"')) == {
            "relationships.references.data.id": {"$in": ["dummy/2019"]}
        }

        assert t.transform(
            p.parse('structures.id HAS ANY "dummy/2019", "dijkstra1968"')
        ) == {
            "relationships.structures.data.id": {
                "$in": ["dummy/2019", "dijkstra1968"]
            }
        }

        assert t.transform(
            p.parse('structures.id HAS ALL "dummy/2019", "dijkstra1968"')
        ) == {
            "relationships.structures.data.id": {
                "$all": ["dummy/2019", "dijkstra1968"]
            }
        }

        has_only_dummy = {
            "$and": [
                {
                    "relationships.structures.data": {
                        "$not": {"$elemMatch": {"id": {"$nin": ["dummy/2019"]}}}
                    }
                },
                {"relationships.structures.data.0": {"$exists": True}},
            ]
        }
        assert t.transform(p.parse('structures.id HAS ONLY "dummy/2019"')) == has_only_dummy

        assert t.transform(
            p.parse(
                'structures.id HAS ONLY "dummy/2019" AND structures.id HAS "dummy/2019"'
            )
        ) == {
            "$and": [
                has_only_dummy,
                {"relationships.structures.data.id": {"$in": ["dummy/2019"]}},
            ]
        }

        # Only ``id`` is supported on relationships, so the transform must
        # raise; there is no resulting query to compare against. The message
        # is matched literally (pytest treats ``match`` as a regex).
        with pytest.raises(
            NotImplementedError,
            match=re.escape(
                'Cannot filter relationships by field "doi", only "id" is supported.'
            ),
        ):
            t.transform(
                p.parse(
                    'references.doi HAS ONLY "10.123/12345" AND structures.id HAS "dummy/2019"'
                )
            )
コード例 #14
0
    def set_up(self):
        from optimade.filtertransformers.mongo import MongoTransformer

        parser = LarkParser(version=self.version, variant=self.variant)
        transformer = MongoTransformer()

        def transform(inp):
            """Parse the filter string *inp* and return its MongoDB query."""
            return transformer.transform(parser.parse(inp))

        self.transform = transform
コード例 #15
0
class TestParserV1_0_0:
    """Grammar tests for ``LarkParser`` with the v1.0.0 "default" grammar.

    Valid filters must parse to a lark ``Tree``; malformed ones must raise
    ``ParserError``.
    """

    version = (1, 0, 0)
    variant = "default"

    @pytest.fixture(autouse=True)
    def set_up(self):
        self.parser = LarkParser(version=self.version, variant=self.variant)

    def parse(self, inp):
        """Parse *inp* with the fixture-provided parser."""
        return self.parser.parse(inp)

    def test_empty(self):
        assert isinstance(self.parse(" "), Tree)

    def test_property_names(self):
        assert isinstance(self.parse("band_gap = 1"), Tree)
        assert isinstance(self.parse("cell_length_a = 1"), Tree)
        assert isinstance(self.parse("cell_volume = 1"), Tree)

        with pytest.raises(ParserError):
            self.parse("0_kvak IS KNOWN")  # starts with a number

        with pytest.raises(ParserError):
            self.parse('"foo bar" IS KNOWN')  # contains space; contains quotes

        with pytest.raises(ParserError):
            self.parse("BadLuck IS KNOWN")  # contains upper-case letters

        # database-provider-specific prefixes
        assert isinstance(self.parse("_exmpl_formula_sum = 1"), Tree)
        assert isinstance(self.parse("_exmpl_band_gap = 1"), Tree)

        # Nested property names
        assert isinstance(self.parse("identifier1.identifierd2 = 42"), Tree)

    def test_string_values(self):
        assert isinstance(self.parse('author="Sąžininga Žąsis"'), Tree)
        assert isinstance(
            self.parse('field = "!#$%&\'() * +, -./:; <= > ? @[] ^ `{|}~ % "'), Tree
        )

    def test_number_values(self):
        assert isinstance(self.parse("a = 12345"), Tree)
        assert isinstance(self.parse("b = +12"), Tree)
        assert isinstance(self.parse("c = -34"), Tree)
        assert isinstance(self.parse("d = 1.2"), Tree)
        assert isinstance(self.parse("e = .2E7"), Tree)
        assert isinstance(self.parse("f = -.2E+7"), Tree)
        assert isinstance(self.parse("g = +10.01E-10"), Tree)
        assert isinstance(self.parse("h = 6.03e23"), Tree)
        assert isinstance(self.parse("i = .1E1"), Tree)
        assert isinstance(self.parse("j = -.1e1"), Tree)
        assert isinstance(self.parse("k = 1.e-12"), Tree)
        assert isinstance(self.parse("l = -.1e-12"), Tree)
        assert isinstance(self.parse("m = 1000000000.E1000000000"), Tree)

        with pytest.raises(ParserError):
            self.parse("number=1.234D12")
        with pytest.raises(ParserError):
            self.parse("number=.e1")
        with pytest.raises(ParserError):
            self.parse("number= -.E1")
        with pytest.raises(ParserError):
            self.parse("number=+.E2")
        with pytest.raises(ParserError):
            self.parse("number=1.23E+++")
        with pytest.raises(ParserError):
            self.parse("number=+-123")
        with pytest.raises(ParserError):
            self.parse("number=0.0.1")

    def test_operators(self):
        # Basic boolean operations
        assert isinstance(
            self.parse(
                'NOT ( chemical_formula_hill = "Al" AND chemical_formula_anonymous = "A" OR '
                'chemical_formula_anonymous = "H2O" AND NOT chemical_formula_hill = "Ti" )'
            ),
            Tree,
        )

        # Numeric and String comparisons
        assert isinstance(self.parse("nelements > 3"), Tree)
        assert isinstance(
            self.parse(
                'chemical_formula_hill = "H2O" AND chemical_formula_anonymous != "AB"'
            ),
            Tree,
        )
        assert isinstance(
            self.parse(
                "_exmpl_aax <= +.1e8 OR nelements >= 10 AND "
                'NOT ( _exmpl_x != "Some string" OR NOT _exmpl_a = 7)'
            ),
            Tree,
        )
        assert isinstance(self.parse('_exmpl_spacegroup="P2"'), Tree)
        assert isinstance(self.parse("_exmpl_cell_volume<100.0"), Tree)
        assert isinstance(
            self.parse("_exmpl_bandgap > 5.0 AND _exmpl_molecular_weight < 350"), Tree
        )
        assert isinstance(
            self.parse('_exmpl_melting_point<300 AND nelements=4 AND elements="Si,O2"'),
            Tree,
        )
        assert isinstance(self.parse("_exmpl_some_string_property = 42"), Tree)
        assert isinstance(self.parse("5 < _exmpl_a"), Tree)

        # OPTIONAL
        assert isinstance(
            self.parse("((NOT (_exmpl_a>_exmpl_b)) AND _exmpl_x>0)"), Tree
        )
        assert isinstance(self.parse("5 < 7"), Tree)

    def test_id(self):
        assert isinstance(self.parse('id="example/1"'), Tree)
        assert isinstance(self.parse('"example/1" = id'), Tree)
        assert isinstance(self.parse('id="test/2" OR "example/1" = id'), Tree)

    def test_string_operations(self):
        #  Substring comparisons
        assert isinstance(
            self.parse(
                'chemical_formula_anonymous CONTAINS "C2" AND '
                'chemical_formula_anonymous STARTS WITH "A2"'
            ),
            Tree,
        )
        assert isinstance(
            self.parse(
                'chemical_formula_anonymous STARTS "B2" AND '
                'chemical_formula_anonymous ENDS WITH "D2"'
            ),
            Tree,
        )

    def test_list_properties(self):
        # Comparisons of list properties
        assert isinstance(self.parse("list HAS < 3"), Tree)
        assert isinstance(self.parse("list HAS ALL < 3, > 3"), Tree)
        assert isinstance(self.parse("list:list HAS >=2:<=5"), Tree)
        assert isinstance(
            self.parse(
                'elements HAS "H" AND elements HAS ALL "H","He","Ga","Ta" AND elements HAS '
                'ONLY "H","He","Ga","Ta" AND elements HAS ANY "H", "He", "Ga", "Ta"'
            ),
            Tree,
        )

        # OPTIONAL:
        assert isinstance(self.parse('elements HAS ONLY "H","He","Ga","Ta"'), Tree)
        assert isinstance(self.parse('elements HAS ALL "H","He","Ga","Ta"'), Tree)
        assert isinstance(self.parse('elements HAS ANY "H","He","Ga","Ta"'), Tree)
        assert isinstance(
            self.parse(
                'elements:_exmpl_element_counts HAS "H":6 AND '
                'elements:_exmpl_element_counts HAS ALL "H":6,"He":7 AND '
                'elements:_exmpl_element_counts HAS ONLY "H":6 AND '
                'elements:_exmpl_element_counts HAS ANY "H":6,"He":7 AND '
                'elements:_exmpl_element_counts HAS ONLY "H":6,"He":7'
            ),
            Tree,
        )
        assert isinstance(
            self.parse(
                "_exmpl_element_counts HAS < 3 AND "
                "_exmpl_element_counts HAS ANY > 3, = 6, 4, != 8"
            ),
            Tree,
        )
        assert isinstance(
            self.parse(
                "elements:_exmpl_element_counts:_exmpl_element_weights "
                'HAS ANY > 3:"He":>55.3 , = 6:>"Ti":<37.6 , 8:<"Ga":0'
            ),
            Tree,
        )

    def test_properties(self):
        #  Filtering on Properties with unknown value
        assert isinstance(
            self.parse(
                "chemical_formula_hill IS KNOWN AND "
                "NOT chemical_formula_anonymous IS UNKNOWN"
            ),
            Tree,
        )

    def test_precedence(self):
        assert isinstance(self.parse('NOT a > b OR c = 100 AND f = "C2 H6"'), Tree)
        assert isinstance(
            self.parse('(NOT (a > b)) OR ( (c = 100) AND (f = "C2 H6") )'), Tree
        )
        assert isinstance(self.parse("a >= 0 AND NOT b < c OR c = 0"), Tree)
        assert isinstance(self.parse("((a >= 0) AND (NOT (b < c))) OR (c = 0)"), Tree)

    def test_special_cases(self):
        assert isinstance(self.parse("te < st"), Tree)
        assert isinstance(self.parse('spacegroup="P2"'), Tree)
        assert isinstance(self.parse("_cod_cell_volume<100.0"), Tree)
        assert isinstance(
            self.parse("_mp_bandgap > 5.0 AND _cod_molecular_weight < 350"), Tree
        )
        assert isinstance(
            self.parse('_cod_melting_point<300 AND nelements=4 AND elements="Si,O2"'),
            Tree,
        )
        assert isinstance(self.parse("key=value"), Tree)
        assert isinstance(self.parse('author=" someone "'), Tree)
        assert isinstance(self.parse('author=" som\neone "'), Tree)
        assert isinstance(
            self.parse(
                "number=0.ANDnumber=.0ANDnumber=0.0ANDnumber=+0AND_n_u_m_b_e_r_=-0AND"
                "number=0e1ANDnumber=0e-1ANDnumber=0e+1"
            ),
            Tree,
        )

        assert isinstance(
            self.parse("NOTice=val"), Tree
        )  # property (ice) != property (val)
        assert isinstance(
            self.parse('NOTice="val"'), Tree
        )  # property (ice) != value ("val")
        assert isinstance(
            self.parse('"NOTice"=val'), Tree
        )  # value ("NOTice") = property (val)

        with pytest.raises(ParserError):
            self.parse("NOTICE=val")  # not valid property or value (NOTICE)
        with pytest.raises(ParserError):
            self.parse('"NOTICE"=Val')  # not valid property (Val)

    def test_parser_version(self):
        assert self.parser.version == self.version
        assert self.parser.variant == self.variant

    def test_repr(self):
        assert repr(self.parser) is not None
        self.parser.parse('key="value"')
        assert repr(self.parser) is not None
コード例 #16
0
 def setUp(self):
     """Build ``self.transform``: OPTIMADE filter string -> MongoDB query."""
     parser = LarkParser(version=self.version, variant=self.variant)
     transformer = MongoTransformer()

     def transform(inp):
         return transformer.transform(parser.parse(inp))

     self.transform = transform
コード例 #17
0
class Lark2Django:
    """Translate OPTIMADE filters parsed with the v0.9.7 grammar into Django `Q` objects.

    `parse_raw_q` turns a filter string into a Lark tree; `evaluate` walks that
    tree recursively and combines the comparison/conjunction/negation helpers
    registered in `self.opers`.
    """

    def __init__(self):
        # Token text -> method building the matching Django query fragment.
        self.opers = {
            "=": self.eq,
            ">": self.gt,
            ">=": self.ge,
            "<": self.lt,
            "<=": self.le,
            "!=": self.ne,
            "OR": self.or_,
            "AND": self.and_,
            "NOT": self.not_,
        }
        self.parser = LarkParser(version=(0, 9, 7))

    def parse_raw_q(self, raw_query):
        """Parse `raw_query` with the v0.9.7 grammar and return the Lark tree."""
        return self.parser.parse(raw_query)

    def eq(self, a, b):
        """`a = b` -> `Q(a=b)`."""
        return Q(**{a: b})

    def gt(self, a, b):
        """`a > b` -> `Q(a__gt=b)`."""
        return Q(**{a + "__gt": b})

    def ge(self, a, b):
        """`a >= b` -> `Q(a__gte=b)`."""
        return Q(**{a + "__gte": b})

    def lt(self, a, b):
        """`a < b` -> `Q(a__lt=b)`."""
        return Q(**{a + "__lt": b})

    def le(self, a, b):
        """`a <= b` -> `Q(a__lte=b)`."""
        return Q(**{a + "__lte": b})

    def ne(self, a, b):
        """`a != b` -> `~Q(a=b)`."""
        return ~Q(**{a: b})

    def not_(self, a):
        """Negate an already-built query (`~a`)."""
        return ~a

    def and_(self, a, b):
        """Combine two queries with `&`."""
        return operator.and_(a, b)

    def or_(self, a, b):
        """Combine two queries with `|`."""
        return operator.or_(a, b)

    def evaluate(self, parse_Tree):
        """Recursively convert a Lark tree or token into a Django query.

        Trees with one child are unwrapped. Trees with two children are treated
        as `<unary operator> <operand>` (e.g. NOT). Trees with three children are
        either a `comparison` (`property OPERATOR value`, with the property mapped
        through `django_db_keys`) or a binary `left CONJUNCTION right`.
        `VALUE` tokens evaluate to their text; NOT/CONJUNCTION/OPERATOR tokens
        evaluate to the matching method from `self.opers`.

        Raises:
            DjangoQueryError: For unknown properties, trees with more than three
                children, unsupported token types, or objects that are neither a
                Lark `Tree` nor a `Token`.
        """
        if isinstance(parse_Tree, Tree):
            children = parse_Tree.children
            if len(children) == 1:
                return self.evaluate(children[0])
            elif len(children) == 2:
                op_fn = self.evaluate(children[0])
                return op_fn(self.evaluate(children[1]))
            elif len(children) == 3:
                if parse_Tree.data == "comparison":
                    db_prop = self.evaluate(children[0])
                    op_fn = self.evaluate(children[1])

                    if db_prop in django_db_keys.keys():
                        return op_fn(
                            django_db_keys[db_prop], self.evaluate(children[2])
                        )
                    else:
                        raise DjangoQueryError(
                            "Unknown property is queried : " + (db_prop)
                        )

                else:
                    op_fn = self.evaluate(children[1])
                    return op_fn(self.evaluate(children[0]), self.evaluate(children[2]))
            else:
                raise DjangoQueryError("Not compatible format. Tree has >3 children")

        elif isinstance(parse_Tree, Token):
            if parse_Tree.type == "VALUE":
                return parse_Tree.value
            elif parse_Tree.type in ["NOT", "CONJUNCTION", "OPERATOR"]:
                return self.opers[parse_Tree.value]
            # Fail loudly here instead of returning None, which would otherwise
            # surface later as an opaque "'NoneType' object is not callable".
            raise DjangoQueryError(
                f"Unsupported token type {parse_Tree.type!r} ({parse_Tree.value!r})"
            )
        else:
            raise DjangoQueryError("Not a Lark Tree or Token")
コード例 #18
0
class EntryCollection(ABC):
    """Backend-agnostic base class for querying collections of
    [`EntryResource`][optimade.models.entries.EntryResource]s."""
    def __init__(
        self,
        collection,
        resource_cls: EntryResource,
        resource_mapper: BaseResourceMapper,
        transformer: Transformer,
    ):
        """Initialize the collection for the given parameters.

        Parameters:
            collection: The backend-specific collection.
            resource_cls (EntryResource): The `EntryResource` model
                that is stored by the collection.
            resource_mapper (BaseResourceMapper): A resource mapper
                object that handles aliases and format changes between
                deserialization and response.
            transformer (Transformer): The Lark `Transformer` used to
                interpret the filter.

        """
        self.collection = collection
        self.parser = LarkParser()
        self.resource_cls = resource_cls
        self.resource_mapper = resource_mapper
        self.transformer = transformer

        self.provider_prefix = CONFIG.provider.prefix
        self.provider_fields = CONFIG.provider_fields.get(
            resource_mapper.ENDPOINT, [])

    @abstractmethod
    def __len__(self) -> int:
        """Returns the total number of entries in the collection."""

    @abstractmethod
    def count(self, **kwargs) -> int:
        """Returns the number of entries matching the query specified
        by the keyword arguments.

        Parameters:
            kwargs (dict): Query parameters as keyword arguments.

        """

    @abstractmethod
    def find(
        self, params: EntryListingQueryParams
    ) -> Tuple[List[EntryResource], int, bool, set]:
        """
        Fetches results and indicates if more data is available.

        Also gives the total number of data available in the absence of `page_limit`.
        See [`EntryListingQueryParams`][optimade.server.query_params.EntryListingQueryParams]
        for more information.

        Parameters:
            params (EntryListingQueryParams): entry listing URL query params

        Returns:
            (`results`, `data_returned`, `more_data_available`, `fields`).

        """

    @property
    def all_fields(self) -> set:
        """Get the set of all fields handled in this collection,
        from attribute fields in the schema, provider fields and top-level OPTIMADE fields.

        Note that this is recomputed (including a schema lookup) on every access,
        so callers needing it more than once should store the result.

        Returns:
            All fields handled in this collection.

        """
        # All OPTIMADE fields
        fields = self.resource_mapper.TOP_LEVEL_NON_ATTRIBUTES_FIELDS.copy()
        fields |= self.get_attribute_fields()
        # All provider-specific fields
        fields |= {
            f"_{self.provider_prefix}_{field_name}"
            for field_name in self.provider_fields
        }

        return fields

    def get_attribute_fields(self) -> set:
        """Get the set of attribute fields from the schema of the
        resource class, resolving references along the way.

        The `attributes` sub-schema is copied before `allOf` entries are merged
        into it, because `resource_cls.schema()` may return a cached dictionary
        (pydantic v1 caches model schemas) that must not be mutated in place.

        Returns:
            Property names.

        """
        schema = self.resource_cls.schema()
        # Shallow copy: merging `allOf` must not modify the shared schema.
        attributes = dict(schema["properties"]["attributes"])
        if "allOf" in attributes:
            for dict_ in attributes.pop("allOf"):
                attributes.update(dict_)
        if "$ref" in attributes:
            # Follow a local JSON reference such as "#/definitions/<Name>";
            # the leading "#" component is dropped by the [1:] slice.
            path = attributes["$ref"].split("/")[1:]
            attributes = schema
            for next_key in path:
                attributes = attributes[next_key]
        return set(attributes["properties"].keys())

    def handle_query_params(
        self, params: Union[EntryListingQueryParams,
                            SingleEntryQueryParams]) -> dict:
        """Parse and interpret the backend-agnostic query parameter models into a dictionary
        that can be used by the specific backend.

        Note:
            Currently this method returns the pymongo interpretation of the parameters,
            which will need modification for other backends.

        Parameters:
            params (Union[EntryListingQueryParams, SingleEntryQueryParams]): The initialized query parameter model from the server.

        Raises:
            Forbidden: If too large of a page limit is provided.
            BadRequest: If an invalid request is made, e.g., with incorrect fields
                or response format.

        Returns:
            A dictionary representation of the query parameters, ready to be used by pymongo.

        """
        cursor_kwargs = {}

        if getattr(params, "filter", False):
            tree = self.parser.parse(params.filter)
            cursor_kwargs["filter"] = self.transformer.transform(tree)
        else:
            cursor_kwargs["filter"] = {}

        if (getattr(params, "response_format", False)
                and params.response_format != "json"):
            raise BadRequest(
                detail=
                f"Response format {params.response_format} is not supported, please use response_format='json'"
            )

        if getattr(params, "page_limit", False):
            limit = params.page_limit
            if limit > CONFIG.page_limit_max:
                raise Forbidden(
                    detail=
                    f"Max allowed page_limit is {CONFIG.page_limit_max}, you requested {limit}",
                )
            cursor_kwargs["limit"] = limit
        else:
            cursor_kwargs["limit"] = CONFIG.page_limit

        # Compute the (schema-derived) field set once and reuse it.
        all_fields = self.all_fields
        cursor_kwargs["fields"] = all_fields
        cursor_kwargs["projection"] = [
            self.resource_mapper.alias_for(f) for f in all_fields
        ]

        if getattr(params, "sort", False):
            cursor_kwargs["sort"] = self.parse_sort_params(params.sort)

        if getattr(params, "page_offset", False):
            cursor_kwargs["skip"] = params.page_offset

        return cursor_kwargs

    def parse_sort_params(self, sort_params) -> Tuple[Tuple[str, int], ...]:
        """Handles any sort parameters passed to the collection,
        resolving aliases and dealing with any invalid fields.

        Parameters:
            sort_params (str): Comma-separated field names; a leading "-"
                requests descending order for that field.

        Raises:
            BadRequest: if an invalid sort is requested.

        Returns:
            A tuple of pairs containing the aliased field name and
            sort direction encoded as 1 (ascending) or -1 (descending).
            Unknown fields that only trigger a warning are left out.

        """
        sort_spec = []
        for field in sort_params.split(","):
            sort_dir = 1
            if field.startswith("-"):
                field = field[1:]
                sort_dir = -1
            aliased_field = self.resource_mapper.alias_for(field)
            sort_spec.append((aliased_field, sort_dir))

        # Computed once instead of once per sort field.
        all_fields = self.all_fields
        unknown_fields = [
            field for field, _ in sort_spec
            if self.resource_mapper.alias_of(field) not in all_fields
        ]

        if unknown_fields:
            error_detail = "Unable to sort on unknown field{} '{}'".format(
                "s" if len(unknown_fields) > 1 else "",
                "', '".join(unknown_fields),
            )

            # If all unknown fields are "other" provider-specific, then only provide a warning
            if all((re.match(r"_[a-z_0-9]+_[a-z_0-9]*", field)
                    and not field.startswith(f"_{self.provider_prefix}_"))
                   for field in unknown_fields):
                warnings.warn(error_detail, FieldValueNotRecognized)

            # Otherwise, if all fields are unknown, or some fields are unknown and do not
            # have other provider prefixes, then return 400: Bad Request
            else:
                raise BadRequest(detail=error_detail)

        # If at least one valid field has been provided for sorting, then use that
        sort_spec = tuple((field, sort_dir) for field, sort_dir in sort_spec
                          if field not in unknown_fields)

        return sort_spec
コード例 #19
0
class MongoCollection(EntryCollection):
    """Entry collection backed by a pymongo (or mongomock) collection.

    Filters are parsed with the v0.10.1 grammar and converted to MongoDB
    queries by a `MongoTransformer` that uses the resource mapper for aliasing.
    """
    def __init__(
        self,
        collection: Union[pymongo.collection.Collection,
                          mongomock.collection.Collection],
        resource_cls: EntryResource,
        resource_mapper: BaseResourceMapper,
    ):
        """Initialize the collection.

        Parameters:
            collection: The pymongo/mongomock collection to query.
            resource_cls: The `EntryResource` model built from each document.
            resource_mapper: Mapper handling aliases between OPTIMADE field
                names and database field names.

        Raises:
            RuntimeError: If any alias or length alias starts with '$'.

        """
        super().__init__(collection, resource_cls, resource_mapper)
        self.transformer = MongoTransformer(mapper=resource_mapper)

        self.provider_prefix = CONFIG.provider.prefix
        self.provider_fields = CONFIG.provider_fields.get(
            resource_mapper.ENDPOINT, [])
        self.parser = LarkParser(
            version=(0, 10, 1), variant="default"
        )  # The MongoTransformer only supports v0.10.1 as the latest grammar

        # check aliases do not clash with mongo operators
        self._check_aliases(self.resource_mapper.all_aliases())
        self._check_aliases(self.resource_mapper.all_length_aliases())

    def __len__(self):
        """Estimated number of documents in the underlying collection."""
        return self.collection.estimated_document_count()

    def __contains__(self, entry):
        """True if at least one document matches `entry.dict()` exactly."""
        return self.collection.count_documents(entry.dict()) > 0

    def count(self, **kwargs):
        """Count documents matching `kwargs["filter"]` (empty filter if absent).

        Only the "filter", "skip", "limit", "hint" and "maxTimeMS" keyword
        arguments are forwarded to `count_documents`; all others are dropped.
        """
        for k in list(kwargs.keys()):
            if k not in ("filter", "skip", "limit", "hint", "maxTimeMS"):
                del kwargs[k]
        if "filter" not in kwargs:  # "filter" is needed for count_documents()
            kwargs["filter"] = {}
        return self.collection.count_documents(**kwargs)

    def find(
        self, params: Union[EntryListingQueryParams, SingleEntryQueryParams]
    ) -> Tuple[List[EntryResource], int, bool, set]:
        """Run the query described by `params`.

        Returns:
            (`results`, `data_returned`, `more_data_available`, `excluded_fields`)
            where `results` is a list for listing queries and a single resource
            (or None) for single-entry queries, and `excluded_fields` is the set
            of known fields not requested via `response_fields`.

        Raises:
            HTTPException: 404 if a single-entry query matches several documents.

        """
        criteria = self._parse_params(params)

        all_fields = criteria.pop("fields")
        if getattr(params, "response_fields", False):
            fields = set(params.response_fields.split(","))
            fields |= self.resource_mapper.get_required_fields()
        else:
            fields = all_fields.copy()

        results = []
        for doc in self.collection.find(**criteria):
            results.append(
                self.resource_cls(**self.resource_mapper.map_back(doc)))

        nresults_now = len(results)
        if isinstance(params, EntryListingQueryParams):
            criteria_nolimit = criteria.copy()
            criteria_nolimit.pop("limit", None)
            data_returned = self.count(**criteria_nolimit)
            more_data_available = nresults_now < data_returned
        else:
            # SingleEntryQueryParams, e.g., /structures/{entry_id}
            data_returned = nresults_now
            more_data_available = False
            if nresults_now > 1:
                raise HTTPException(
                    status_code=404,
                    detail=
                    f"Instead of a single entry, {nresults_now} entries were found",
                )
            results = results[0] if results else None

        return results, data_returned, more_data_available, all_fields - fields

    def _parse_params(
        self, params: Union[EntryListingQueryParams,
                            SingleEntryQueryParams]) -> dict:
        """Translate query parameters into keyword arguments for `collection.find`.

        Raises:
            HTTPException: 400 for a non-json response format, 403 for a
                page_limit above `CONFIG.page_limit_max`.

        """
        cursor_kwargs = {}

        if getattr(params, "filter", False):
            tree = self.parser.parse(params.filter)
            cursor_kwargs["filter"] = self.transformer.transform(tree)
        else:
            cursor_kwargs["filter"] = {}

        if (getattr(params, "response_format", False)
                and params.response_format != "json"):
            raise HTTPException(status_code=400,
                                detail="Only 'json' response_format supported")

        if getattr(params, "page_limit", False):
            limit = params.page_limit
            if limit > CONFIG.page_limit_max:
                raise HTTPException(
                    status_code=403,  # Forbidden
                    detail=
                    f"Max allowed page_limit is {CONFIG.page_limit_max}, you requested {limit}",
                )
            cursor_kwargs["limit"] = limit
        else:
            cursor_kwargs["limit"] = CONFIG.page_limit

        # All OPTIMADE fields
        fields = self.resource_mapper.TOP_LEVEL_NON_ATTRIBUTES_FIELDS.copy()
        fields |= self.get_attribute_fields()
        # All provider-specific fields
        fields |= {
            f"_{self.provider_prefix}_{field_name}"
            for field_name in self.provider_fields
        }
        cursor_kwargs["fields"] = fields
        cursor_kwargs["projection"] = [
            self.resource_mapper.alias_for(f) for f in fields
        ]

        if getattr(params, "sort", False):
            sort_spec = []
            for elt in params.sort.split(","):
                field = elt
                sort_dir = 1
                if elt.startswith("-"):
                    field = field[1:]
                    sort_dir = -1
                # Sort on the database field name, consistent with the aliased
                # projection above and the mapper-aware filter transformer.
                sort_spec.append(
                    (self.resource_mapper.alias_for(field), sort_dir))
            cursor_kwargs["sort"] = sort_spec

        if getattr(params, "page_offset", False):
            cursor_kwargs["skip"] = params.page_offset

        return cursor_kwargs

    def _check_aliases(self, aliases):
        """ Check that aliases do not clash with mongo keywords. """
        if any(alias[0].startswith("$") or alias[1].startswith("$")
               for alias in aliases):
            raise RuntimeError(
                f"Cannot define an alias starting with a '$': {aliases}")
コード例 #20
0
class AiidaCollection:
    """Collection of AiiDA entities exposed through the OPTiMaDe API.

    Filters are parsed with a Lark grammar, transformed into AiiDA
    QueryBuilder filters and aliased from OPTiMaDe field names to AiiDA field
    names via the resource mapper. The numbers of available and returned
    entries are cached on the instance.
    """

    # Maps the JSON-schema "format" (checked first) or "type" of an OPTiMaDe
    # property to the AiiDA QueryBuilder cast code used when ordering on it.
    # JSON schema describes floating-point values with the type "number".
    CAST_MAPPING = {
        "string": "t",
        "float": "f",
        "number": "f",
        "integer": "i",
        "boolean": "b",
        "date-time": "d",
    }

    def __init__(
        self,
        collection: orm.entities.Collection,
        resource_cls: EntryResource,
        resource_mapper: ResourceMapper,
    ):
        """Initialize the collection.

        Parameters:
            collection: The AiiDA entity collection to query.
            resource_cls: The `EntryResource` model built from each entity.
            resource_mapper: Mapper handling aliases between OPTiMaDe and
                AiiDA field names.

        """
        self.collection = collection
        self.resource_cls = resource_cls
        self.resource_mapper = resource_mapper

        self.transformer = AiidaTransformerV0_10_1()
        self.provider = CONFIG.provider["prefix"]
        self.provider_fields = CONFIG.provider_fields[resource_mapper.ENDPOINT]
        self.page_limit = CONFIG.page_limit
        self.db_page_limit = CONFIG.db_page_limit
        # NOTE(review): the transformer targets v0.10.1 while the parser is
        # built for v0.10.0 — confirm this pairing is intended.
        self.parser = LarkParser(version=(0, 10, 0))

        # "Cache"
        self._data_available: int = None
        self._data_returned: int = None
        self._filter_fields: set = None
        self._latest_filter: dict = None

    def get_attribute_fields(self) -> set:
        """Return the attribute names of `resource_cls`, resolving `allOf`/`$ref`.

        The `attributes` sub-schema is copied before `allOf` entries are merged
        into it, because `resource_cls.schema()` may return a cached dictionary
        (pydantic v1 caches model schemas) that must not be mutated in place.
        """
        schema = self.resource_cls.schema()
        # Shallow copy: merging `allOf` must not modify the shared schema.
        attributes = dict(schema["properties"]["attributes"])
        if "allOf" in attributes:
            for dict_ in attributes.pop("allOf"):
                attributes.update(dict_)
        if "$ref" in attributes:
            # Follow a local JSON reference such as "#/definitions/<Name>".
            path = attributes["$ref"].split("/")[1:]
            attributes = schema
            for next_key in path:
                attributes = attributes[next_key]
        return set(attributes["properties"].keys())

    def _get_all_fields(self) -> set:
        """Return every OPTiMaDe field served: id, type, schema attributes and provider fields."""
        fields = {"id", "type"}
        fields |= self.get_attribute_fields()
        fields |= {self.provider + _ for _ in self.provider_fields}
        return fields

    @staticmethod
    def _find(backend: orm.implementation.Backend, entity_type: orm.Entity,
              **kwargs) -> orm.QueryBuilder:
        """Build (without executing) a QueryBuilder over `entity_type`.

        Parameters:
            backend: The AiiDA backend to query.
            entity_type: The AiiDA entity class appended to the query.
            **kwargs: Any of "filters", "order_by", "limit", "project", "offset".

        Raises:
            ValueError: If any other keyword is supplied.

        Returns:
            The configured query, ordered by ascending "id" unless `order_by`
            is given.

        """
        for key in kwargs:
            if key not in {
                    "filters", "order_by", "limit", "project", "offset"
            }:
                raise ValueError(
                    f"You supplied key {key}. _find() only takes the keys: "
                    '"filters", "order_by", "limit", "project", "offset"')

        filters = kwargs.get("filters", {})
        # Deterministic default ordering when none (or an empty one) is given.
        order_by = {entity_type: kwargs.get("order_by", None) or {"id": "asc"}}
        limit = kwargs.get("limit", None)
        offset = kwargs.get("offset", None)
        project = kwargs.get("project", [])

        query = orm.QueryBuilder(backend=backend, limit=limit, offset=offset)
        query.append(entity_type, project=project, filters=filters)
        query.order_by(order_by)

        return query

    def _find_all(self, backend: orm.implementation.Backend,
                  **kwargs) -> list:
        """Run a `_find` query over this collection's entity type and return all rows."""
        query = self._find(backend, self.collection.entity_type, **kwargs)
        res = query.all()
        del query
        return res

    def count(self, backend: orm.implementation.Backend, **kwargs):  # pylint: disable=arguments-differ
        """Count entities matching a `_find` query built from `kwargs`."""
        query = self._find(backend, self.collection.entity_type, **kwargs)
        res = query.count()
        del query
        return res

    @property
    def data_available(self) -> int:
        """Total number of entities; `set_data_available` must be called first."""
        if self._data_available is None:
            raise CausationError(
                "data_available MUST be set before it can be retrieved.")
        return self._data_available

    def set_data_available(self, backend: orm.implementation.Backend):
        """Set _data_available if it has not yet been set"""
        if self._data_available is None:
            self._data_available = self.count(backend)

    @property
    def data_returned(self) -> int:
        """Number of entities matching the latest filter; `set_data_returned` must be called first."""
        if self._data_returned is None:
            raise CausationError(
                "data_returned MUST be set before it can be retrieved.")
        return self._data_returned

    def set_data_returned(self, backend: orm.implementation.Backend,
                          **criteria):
        """Set _data_returned if it has not yet been set or new filter does not equal latest filter.

        "limit" and "offset" are ignored so the count covers all matches.

        NB! Nested lists in filters are not accounted for.
        """
        if self._data_returned is None or (
                self._latest_filter is not None
                and criteria.get("filters", {}) != self._latest_filter):
            for key in ("limit", "offset"):
                criteria.pop(key, None)
            self._latest_filter = criteria.get("filters", {})
            self._data_returned = self.count(backend, **criteria)

    def find(  # pylint: disable=arguments-differ
        self,
        backend: orm.implementation.Backend,
        params: Union[EntryListingQueryParams, SingleEntryQueryParams],
    ) -> Tuple[List[EntryResource], NonnegativeInt, bool, NonnegativeInt, set]:
        """Run the query described by `params`.

        Returns:
            (`results`, `data_returned`, `more_data_available`, `data_available`,
            `excluded_fields`); `results` is a single resource (or None) for
            single-entry queries.

        Raises:
            HTTPException: 404 if a single-entry query matches several entities.

        """
        self.set_data_available(backend)
        criteria = self._parse_params(params)

        all_fields = criteria.pop("fields")
        if getattr(params, "response_fields", False):
            fields = set(params.response_fields.split(","))
        else:
            fields = all_fields.copy()

        # Make sure extras used by the filter exist before querying on them.
        if criteria.get("filters", {}) and self._get_extras_filter_fields():
            self._check_and_calculate_entities(backend)

        self.set_data_returned(backend, **criteria)

        entities = self._find_all(backend, **criteria)
        results = []
        for entity in entities:
            results.append(
                self.resource_cls(**self.resource_mapper.map_back(
                    dict(zip(criteria["project"], entity)))))

        if isinstance(params, EntryListingQueryParams):
            criteria_no_limit = criteria.copy()
            criteria_no_limit.pop("limit", None)
            more_data_available = len(results) < self.count(
                backend, **criteria_no_limit)
        else:
            more_data_available = False
            if len(results) > 1:
                raise HTTPException(
                    status_code=404,
                    detail=
                    f"Instead of a single entry, {len(results)} entries were found",
                )

        if isinstance(params, SingleEntryQueryParams):
            results = results[0] if results else None

        return (
            results,
            self.data_returned,
            more_data_available,
            self.data_available,
            all_fields - fields,
        )

    def _alias_filter(self, filters: Any) -> Union[dict, list]:
        """Get aliased field names in nested filter query.

        I.e. turn OPTiMaDe field names into AiiDA field names. Every aliased
        dict key (including operator keys) is recorded in `self._filter_fields`.

        Raises:
            NotImplementedError: If `filters` is neither a dict nor a list.
        """
        if isinstance(filters, dict):
            res = {}
            for key, value in filters.items():
                new_value = value
                if isinstance(value, (dict, list)):
                    new_value = self._alias_filter(value)
                aliased_key = self.resource_mapper.alias_for(key)
                res[aliased_key] = new_value
                self._filter_fields.add(aliased_key)
        elif isinstance(filters, list):
            res = []
            for item in filters:
                new_value = item
                if isinstance(item, (dict, list)):
                    new_value = self._alias_filter(item)
                res.append(new_value)
        else:
            raise NotImplementedError(
                "_alias_filter can only handle dict and list objects")
        return res

    def _parse_params(
        self, params: Union[EntryListingQueryParams, SingleEntryQueryParams]
    ) -> dict:
        """Parse query parameters and transform them into AiiDA QueryBuilder concepts

        Raises:
            HTTPException: 400 for a non-json response format or an unknown
                sort field, 403 for a page_limit above `db_page_limit`.
            TypeError: When sorting on a field with an array value type.
        """
        cursor_kwargs = {}

        # filter
        if getattr(params, "filter", False):
            aiida_filter = self.transformer.transform(
                self.parser.parse(params.filter))
            self._filter_fields = set()
            cursor_kwargs["filters"] = self._alias_filter(aiida_filter)

        # response_format
        if (getattr(params, "response_format", False)
                and params.response_format != "json"):
            raise HTTPException(status_code=400,
                                detail="Only 'json' response_format supported")

        # page_limit (a falsy/absent page_limit leaves the query unlimited)
        if getattr(params, "page_limit", False):
            limit = params.page_limit
            if limit > self.db_page_limit:
                raise HTTPException(
                    status_code=403,
                    detail=
                    f"Max allowed page_limit is {self.db_page_limit}, you requested {limit}",
                )
            cursor_kwargs["limit"] = limit

        # response_fields
        fields = self._get_all_fields()
        cursor_kwargs["fields"] = fields
        cursor_kwargs["project"] = list(
            {self.resource_mapper.alias_for(f)
             for f in fields})

        # sort
        # NOTE: sorting only works for extras fields for the nodes already with calculated extras.
        #       To calculate all extras, make a single filter query using any extra field.
        if getattr(params, "sort", False):
            # The schema-derived properties do not depend on the sort field.
            _, properties = retrieve_queryable_properties(
                self.resource_cls.schema(), {"id", "type", "attributes"})
            sort_spec = []
            for entity_property in params.sort.split(","):
                field = entity_property
                sort_direction = "asc"
                if entity_property.startswith("-"):
                    field = field[1:]
                    sort_direction = "desc"
                aliased_field = self.resource_mapper.alias_for(field)

                if field not in properties:
                    raise HTTPException(
                        status_code=400,
                        detail=f"Unable to sort on unknown field '{field}'",
                    )
                field_type = properties[field].get(
                    "format", properties[field].get("type", ""))
                if field_type == "array":
                    raise TypeError(
                        "Cannot sort on a field with a list value type")

                sort_spec.append({
                    aliased_field: {
                        "order": sort_direction,
                        "cast": self.CAST_MAPPING[field_type],
                    }
                })
            cursor_kwargs["order_by"] = sort_spec

        # page_offset
        if getattr(params, "page_offset", False):
            cursor_kwargs["offset"] = params.page_offset

        return cursor_kwargs

    def _get_extras_filter_fields(self) -> set:
        """Return filtered-on field names under `PROJECT_PREFIX`, with the prefix stripped."""
        return {
            field[len(self.resource_mapper.PROJECT_PREFIX):]
            for field in (self._filter_fields or ())
            if field.startswith(self.resource_mapper.PROJECT_PREFIX)
        }

    def _check_and_calculate_entities(self,
                                      backend: orm.implementation.Backend):
        """Check all entities have OPTiMaDe extras, else calculate them

        For a bit of optimization, we only care about a field if it has specifically been queried for using "filter".
        """
        # Non-empty components of PROJECT_PREFIX; assumes at least two
        # components (e.g. "<extras column>.<namespace>.") — TODO confirm.
        extras_keys = [
            key for key in self.resource_mapper.PROJECT_PREFIX.split(".")
            if key
        ]
        # One "key is missing" condition per filtered-on extras field.
        missing_field_filters = [{
            "!has_key": field
        } for field in self._get_extras_filter_fields()]

        # Entities lacking the whole extras namespace...
        or_filters = [{extras_keys[0]: {"!has_key": extras_keys[1]}}]
        # ...or lacking any of the filtered-on fields inside it.
        if missing_field_filters:
            or_filters.append(
                {".".join(extras_keys): {
                    "or": missing_field_filters
                }})

        necessary_entities_qb = orm.QueryBuilder(backend=backend).append(
            self.collection.entity_type,
            filters={"or": or_filters},
            project="id",
        )

        if necessary_entities_qb.count() > 0:
            # Necessary entities for the OPTiMaDe query exist with unknown OPTiMaDe fields.
            necessary_entity_ids = [
                pk[0] for pk in necessary_entities_qb.iterall()
            ]

            # Create the missing OPTiMaDe fields:
            fields = list({
                self.resource_mapper.alias_for(f)
                for f in self._get_all_fields()
            })

            entities = self._find_all(
                backend,
                filters={"id": {
                    "in": necessary_entity_ids
                }},
                project=fields)
            for entity in entities:
                # The instance is discarded; presumably building it makes the
                # mapper compute and store the missing extras — TODO confirm.
                self.resource_cls(
                    **self.resource_mapper.map_back(dict(zip(fields, entity))))
コード例 #21
0
 def test_parser_version(self):
     """A parser built for grammar 0.9.5 parses the first filter and reports that version."""
     expected_version = (0, 9, 5)
     parser = LarkParser(version=expected_version)
     first_filter = self.test_filters[0]
     self.assertIsInstance(parser.parse(first_filter), Tree)
     self.assertEqual(parser.version, expected_version)
コード例 #22
0
class DjangoTransformer:
    """Filter transformer for implementations using Django.

    !!! warning "Warning"
        This transformer is deprecated as it only supports
        the 0.9.7 grammar version, and works differently from other
        filter transformers in this package.

    `parse_raw_q` turns a filter string into a Lark tree; `evaluate` walks that
    tree recursively and builds a Django `Q` object from it.

    """
    def __init__(self):
        # Token text -> method building the matching Django query fragment.
        self.opers = {
            "=": self.eq,
            ">": self.gt,
            ">=": self.ge,
            "<": self.lt,
            "<=": self.le,
            "!=": self.ne,
            "OR": self.or_,
            "AND": self.and_,
            "NOT": self.not_,
        }
        self.parser = LarkParser(version=(0, 9, 7))

    def parse_raw_q(self, raw_query):
        """Parse `raw_query` with the v0.9.7 grammar and return the Lark tree."""
        return self.parser.parse(raw_query)

    def eq(self, a, b):
        """`a = b` -> `Q(a=b)`."""
        return Q(**{a: b})

    def gt(self, a, b):
        """`a > b` -> `Q(a__gt=b)`."""
        return Q(**{a + "__gt": b})

    def ge(self, a, b):
        """`a >= b` -> `Q(a__gte=b)`."""
        return Q(**{a + "__gte": b})

    def lt(self, a, b):
        """`a < b` -> `Q(a__lt=b)`."""
        return Q(**{a + "__lt": b})

    def le(self, a, b):
        """`a <= b` -> `Q(a__lte=b)`."""
        return Q(**{a + "__lte": b})

    def ne(self, a, b):
        """`a != b` -> `~Q(a=b)`."""
        return ~Q(**{a: b})

    def not_(self, a):
        """Negate an already-built query (`~a`)."""
        return ~a

    def and_(self, a, b):
        """Combine two queries with `&`."""
        return operator.and_(a, b)

    def or_(self, a, b):
        """Combine two queries with `|`."""
        return operator.or_(a, b)

    def evaluate(self, parse_Tree):
        """Recursively convert a Lark tree or token into a Django query.

        Trees with one child are unwrapped. Trees with two children are treated
        as `<unary operator> <operand>` (e.g. NOT). Trees with three children are
        either a `comparison` (`property OPERATOR value`, with the property mapped
        through `django_db_keys`) or a binary `left CONJUNCTION right`.
        `VALUE` tokens evaluate to their text; NOT/CONJUNCTION/OPERATOR tokens
        evaluate to the matching method from `self.opers`.

        Raises:
            DjangoQueryError: For unknown properties, trees with more than three
                children, unsupported token types, or objects that are neither a
                Lark `Tree` nor a `Token`.

        """
        if isinstance(parse_Tree, Tree):
            children = parse_Tree.children
            if len(children) == 1:
                return self.evaluate(children[0])
            elif len(children) == 2:
                op_fn = self.evaluate(children[0])
                return op_fn(self.evaluate(children[1]))
            elif len(children) == 3:
                if parse_Tree.data == "comparison":
                    db_prop = self.evaluate(children[0])
                    op_fn = self.evaluate(children[1])

                    if db_prop in django_db_keys.keys():
                        return op_fn(django_db_keys[db_prop],
                                     self.evaluate(children[2]))
                    else:
                        raise DjangoQueryError(
                            "Unknown property is queried : " + (db_prop))

                else:
                    op_fn = self.evaluate(children[1])
                    return op_fn(self.evaluate(children[0]),
                                 self.evaluate(children[2]))
            else:
                raise DjangoQueryError(
                    "Not compatible format. Tree has >3 children")

        elif isinstance(parse_Tree, Token):
            if parse_Tree.type == "VALUE":
                return parse_Tree.value
            elif parse_Tree.type in ["NOT", "CONJUNCTION", "OPERATOR"]:
                return self.opers[parse_Tree.value]
            # Fail loudly here instead of returning None, which would otherwise
            # surface later as an opaque "'NoneType' object is not callable".
            raise DjangoQueryError(
                f"Unsupported token type {parse_Tree.type!r} ({parse_Tree.value!r})"
            )
        else:
            raise DjangoQueryError("Not a Lark Tree or Token")
コード例 #23
0
class EntryCollection(ABC):
    """Backend-agnostic base class for querying collections of
    [`EntryResource`][optimade.models.entries.EntryResource]s."""
    def __init__(
        self,
        resource_cls: EntryResource,
        resource_mapper: BaseResourceMapper,
        transformer: Transformer,
    ):
        """Initialize the collection for the given parameters.

        Parameters:
            resource_cls (EntryResource): The `EntryResource` model
                that is stored by the collection.
            resource_mapper (BaseResourceMapper): A resource mapper
                object that handles aliases and format changes between
                deserialization and response.
            transformer (Transformer): The Lark `Transformer` used to
                interpret the filter.

        """
        self.parser = LarkParser()
        self.resource_cls = resource_cls
        self.resource_mapper = resource_mapper
        self.transformer = transformer

        self.provider_prefix = CONFIG.provider.prefix
        # Configured provider fields may be bare names or dicts carrying a "name" key.
        self.provider_fields = [
            field if isinstance(field, str) else field["name"] for field in
            CONFIG.provider_fields.get(resource_mapper.ENDPOINT, [])
        ]

        # Filled lazily by the `all_fields` property; None until first requested.
        self._all_fields: Set[str] = None

    @abstractmethod
    def __len__(self) -> int:
        """Returns the total number of entries in the collection."""

    @abstractmethod
    def insert(self, data: List[EntryResource]) -> None:
        """Add the given entries to the underlying database.

        Arguments:
            data: The entry resource objects to add to the database.

        """

    @abstractmethod
    def count(self, **kwargs) -> int:
        """Returns the number of entries matching the query specified
        by the keyword arguments.

        Parameters:
            kwargs (dict): Query parameters as keyword arguments.

        """

    def find(
        self, params: Union[EntryListingQueryParams, SingleEntryQueryParams]
    ) -> Tuple[Union[List[EntryResource], EntryResource, None], int, bool,
               Set[str], Set[str]]:
        """
        Fetches results and indicates if more data is available.

        Also gives the total number of data available in the absence of `page_limit`.
        See [`EntryListingQueryParams`][optimade.server.query_params.EntryListingQueryParams]
        for more information.

        Parameters:
            params: Entry listing URL query params.

        Raises:
            NotFound: If a single entry was requested but more than one matched.
            BadRequest: If unrecognised OPTIMADE fields are requested in `response_fields`.

        Returns:
            A tuple of various relevant values:
            (`results`, `data_returned`, `more_data_available`, `exclude_fields`, `include_fields`).

        """
        criteria = self.handle_query_params(params)
        single_entry = isinstance(params, SingleEntryQueryParams)
        # "fields" is not a backend query argument, so remove it before querying.
        response_fields = criteria.pop("fields")

        results, data_returned, more_data_available = self._run_db_query(
            criteria, single_entry)

        if single_entry:
            results = results[0] if results else None

            if data_returned > 1:
                raise NotFound(
                    detail=
                    f"Instead of a single entry, {data_returned} entries were found",
                )

        exclude_fields = self.all_fields - response_fields
        include_fields = (response_fields -
                          self.resource_mapper.TOP_LEVEL_NON_ATTRIBUTES_FIELDS)

        # Unknown fields with a supported provider prefix only warn; unknown
        # unprefixed (i.e. OPTIMADE) fields are a client error.
        bad_optimade_fields = set()
        bad_provider_fields = set()
        for field in include_fields:
            if field not in self.resource_mapper.ALL_ATTRIBUTES:
                if field.startswith("_"):
                    if any(
                            field.startswith(f"_{prefix}_") for prefix in
                            self.resource_mapper.SUPPORTED_PREFIXES):
                        bad_provider_fields.add(field)
                else:
                    bad_optimade_fields.add(field)

        if bad_provider_fields:
            warnings.warn(
                message=
                f"Unrecognised field(s) for this provider requested in `response_fields`: {bad_provider_fields}.",
                category=UnknownProviderProperty,
            )

        if bad_optimade_fields:
            raise BadRequest(
                detail=
                f"Unrecognised OPTIMADE field(s) in requested `response_fields`: {bad_optimade_fields}."
            )

        if results:
            results = self.resource_mapper.deserialize(results)

        return (
            results,
            data_returned,
            more_data_available,
            exclude_fields,
            include_fields,
        )

    @abstractmethod
    def _run_db_query(self,
                      criteria: Dict[str, Any],
                      single_entry: bool = False
                      ) -> Tuple[List[Dict[str, Any]], int, bool]:
        """Run the query on the backend and collect the results.

        Arguments:
            criteria: A dictionary representation of the query parameters.
            single_entry: Whether or not the caller is expecting a single entry response.

        Returns:
            The list of entries from the database (without any re-mapping), the total number of
            entries matching the query and a boolean for whether or not there is more data available.

        """

    @property
    def all_fields(self) -> Set[str]:
        """Get the set of all fields handled in this collection,
        from attribute fields in the schema, provider fields and top-level OPTIMADE fields.

        The set of all fields are lazily created and then cached.
        This means the set is created the first time the property is requested and then cached.

        Returns:
            All fields handled in this collection.

        """
        if not self._all_fields:
            # All OPTIMADE fields
            self._all_fields = (
                self.resource_mapper.TOP_LEVEL_NON_ATTRIBUTES_FIELDS.copy())
            self._all_fields |= self.get_attribute_fields()
            # All provider-specific fields
            self._all_fields |= {
                f"_{self.provider_prefix}_{field_name}"
                for field_name in self.provider_fields
            }

        return self._all_fields

    def get_attribute_fields(self) -> Set[str]:
        """Get the set of attribute fields

        Return only the _first-level_ attribute fields from the schema of the resource class,
        resolving references along the way if needed.

        Note:
            It is not needed to take care of other special OpenAPI schema keys than `allOf`,
            since only `allOf` will be found in this context.
            Other special keys can be found in [the Swagger documentation](https://swagger.io/docs/specification/data-models/oneof-anyof-allof-not/).

        Returns:
            Property names.

        """

        schema = self.resource_cls.schema()
        attributes = schema["properties"]["attributes"]
        if "allOf" in attributes:
            # Merge the `allOf` sub-schemas into a *new* dict instead of popping
            # from `attributes` in place: `schema()` may return a cached object
            # (pydantic v1 caches model schemas), so mutating it would silently
            # alter the schema seen by every later caller.
            merged = {
                key: value
                for key, value in attributes.items() if key != "allOf"
            }
            for dict_ in attributes["allOf"]:
                merged.update(dict_)
            attributes = merged
        if "$ref" in attributes:
            # Resolve a local reference such as "#/definitions/Foo" by walking
            # the schema one key at a time (read-only).
            path = attributes["$ref"].split("/")[1:]
            attributes = schema
            for next_key in path:
                attributes = attributes[next_key]
        return set(attributes["properties"].keys())

    def handle_query_params(
        self, params: Union[EntryListingQueryParams, SingleEntryQueryParams]
    ) -> Dict[str, Any]:
        """Parse and interpret the backend-agnostic query parameter models into a dictionary
        that can be used by the specific backend.

        Note:
            Currently this method returns the pymongo interpretation of the parameters,
            which will need modification for other backends.

        Parameters:
            params (Union[EntryListingQueryParams, SingleEntryQueryParams]): The initialized query parameter model from the server.

        Raises:
            Forbidden: If too large of a page limit is provided.
            BadRequest: If an invalid request is made, e.g., with incorrect fields
                or response format.

        Returns:
            A dictionary representation of the query parameters.

        """
        cursor_kwargs = {}

        # filter
        if getattr(params, "filter", False):
            cursor_kwargs["filter"] = self.transformer.transform(
                self.parser.parse(params.filter))
        else:
            cursor_kwargs["filter"] = {}

        # response_format
        if (getattr(params, "response_format", False)
                and params.response_format != "json"):
            raise BadRequest(
                detail=
                f"Response format {params.response_format} is not supported, please use response_format='json'"
            )

        # page_limit
        if getattr(params, "page_limit", False):
            limit = params.page_limit
            if limit > CONFIG.page_limit_max:
                raise Forbidden(
                    detail=
                    f"Max allowed page_limit is {CONFIG.page_limit_max}, you requested {limit}",
                )
            cursor_kwargs["limit"] = limit
        else:
            cursor_kwargs["limit"] = CONFIG.page_limit

        # response_fields
        cursor_kwargs["projection"] = {
            f"{self.resource_mapper.get_backend_field(f)}": True
            for f in self.all_fields
        }

        # Exclude the backend's own "_id" unless one of our fields maps to it.
        if "_id" not in cursor_kwargs["projection"]:
            cursor_kwargs["projection"]["_id"] = False

        if getattr(params, "response_fields", False):
            response_fields = set(params.response_fields.split(","))
            response_fields |= self.resource_mapper.get_required_fields()
        else:
            response_fields = self.all_fields.copy()

        cursor_kwargs["fields"] = response_fields

        # sort
        if getattr(params, "sort", False):
            cursor_kwargs["sort"] = self.parse_sort_params(params.sort)

        # page_offset
        if getattr(params, "page_offset", False):
            cursor_kwargs["skip"] = params.page_offset

        return cursor_kwargs

    def parse_sort_params(self,
                          sort_params: str) -> Tuple[Tuple[str, int], ...]:
        """Handles any sort parameters passed to the collection,
        resolving aliases and dealing with any invalid fields.

        Raises:
            BadRequest: if an invalid sort is requested.

        Returns:
            A tuple of tuples containing the aliased field name and
            sort direction encoded as 1 (ascending) or -1 (descending).

        """
        sort_spec = []
        for field in sort_params.split(","):
            sort_dir = 1
            if field.startswith("-"):
                field = field[1:]
                sort_dir = -1
            aliased_field = self.resource_mapper.get_backend_field(field)
            sort_spec.append((aliased_field, sort_dir))

        unknown_fields = [
            field for field, _ in sort_spec
            if self.resource_mapper.get_optimade_field(field) not in
            self.all_fields
        ]

        if unknown_fields:
            error_detail = "Unable to sort on unknown field{} '{}'".format(
                "s" if len(unknown_fields) > 1 else "",
                "', '".join(unknown_fields),
            )

            # If all unknown fields are "other" provider-specific, then only provide a warning
            if all((re.match(r"_[a-z_0-9]+_[a-z_0-9]*", field)
                    and not field.startswith(f"_{self.provider_prefix}_"))
                   for field in unknown_fields):
                warnings.warn(error_detail, FieldValueNotRecognized)

            # Otherwise, if all fields are unknown, or some fields are unknown and do not
            # have other provider prefixes, then return 400: Bad Request
            else:
                raise BadRequest(detail=error_detail)

        # If at least one valid field has been provided for sorting, then use that
        sort_spec = tuple((field, sort_dir) for field, sort_dir in sort_spec
                          if field not in unknown_fields)

        return sort_spec
コード例 #24
0
    def test_suspected_timestamp_fields(self, mapper):
        """Timestamp strings compared against `last_modified` become UTC datetimes, also through aliases."""
        import datetime
        import bson.tz_util
        from optimade.filtertransformers.mongo import MongoTransformer
        from optimade.server.warnings import TimestampNotRFCCompliant

        # Both spellings (with and without trailing "Z") must map to the same datetime.
        compliant_dates = ("2019-06-08T04:13:37Z", "2019-06-08T04:13:37")
        fractional_date = "2019-06-08T04:13:37.123Z"

        expected = datetime.datetime(2019, 6, 8, 4, 13, 37, 0,
                                     tzinfo=bson.tz_util.utc)

        for date_string in compliant_dates:
            query = self.transform(f'last_modified > "{date_string}"')
            assert query == {"last_modified": {"$gt": expected}}

        # Sub-second precision is still converted, but must emit a warning.
        with pytest.warns(TimestampNotRFCCompliant):
            assert self.transform(
                f'last_modified > "{fractional_date}"') == {
                    "last_modified": {
                        "$gt": expected.replace(microsecond=123000)
                    }
                }

        class MyMapper(mapper("StructureMapper")):
            ALIASES = (("last_modified", "ctime"), )

        aliased_transformer = MongoTransformer(mapper=MyMapper)
        aliased_parser = LarkParser(version=self.version, variant=self.variant)

        # With the alias, the same comparisons target the backend field "ctime".
        for date_string in compliant_dates:
            tree = aliased_parser.parse(f'last_modified > "{date_string}"')
            assert aliased_transformer.transform(tree) == {
                "ctime": {
                    "$gt": expected
                }
            }
コード例 #25
0
 def test_parser_version(self):
     """A parser built for an explicit version parses input and reports that version."""
     requested = (0, 9, 5)
     parser = LarkParser(version=requested)
     parsed = parser.parse(self.test_filters[0])
     assert isinstance(parsed, Tree)
     assert parser.version == requested
コード例 #26
0
    def test_aliases(self, mapper):
        """Aliases rename the queried fields (l-values) but never the compared values (r-values)."""
        from optimade.filtertransformers.mongo import MongoTransformer

        class MyStructureMapper(mapper("StructureMapper")):
            ALIASES = (
                ("elements", "my_elements"),
                ("A", "D"),
                ("property_a", "D"),
                ("B", "E"),
                ("C", "F"),
                ("_exmpl_nested_field", "nested_field"),
            )

            PROVIDER_FIELDS = ("D", "E", "F", "nested_field")

        # A distinct name avoids shadowing the `mapper` fixture argument.
        structure_mapper = MyStructureMapper
        transformer = MongoTransformer(mapper=structure_mapper)
        parser = LarkParser(version=self.version, variant=self.variant)

        assert structure_mapper.get_backend_field("elements") == "my_elements"

        # (filter, expected query): "A", "B", "C", "D" appear as values and must
        # stay untouched even though they are also alias keys.
        cases = [
            ('elements HAS "A"', {"my_elements": {"$in": ["A"]}}),
            (
                'elements HAS ANY "A","B","C" AND elements HAS "D"',
                {
                    "$and": [
                        {"my_elements": {"$in": ["A", "B", "C"]}},
                        {"my_elements": {"$in": ["D"]}},
                    ]
                },
            ),
            ('elements = "A"', {"my_elements": {"$eq": "A"}}),
            ('property_a HAS "B"', {"D": {"$in": ["B"]}}),
            (
                "_exmpl_nested_field.sub_property > 1234.5",
                {"nested_field.sub_property": {"$gt": 1234.5}},
            ),
            (
                "_exmpl_nested_field.sub_property.x IS UNKNOWN",
                {
                    "$or": [
                        {"nested_field.sub_property.x": {"$exists": False}},
                        {"nested_field.sub_property.x": {"$eq": None}},
                    ]
                },
            ),
        ]

        for filter_text, expected in cases:
            assert transformer.transform(parser.parse(filter_text)) == expected
コード例 #27
0
class TestTransformer:
    """Smoke tests: filters parsed with the "elastic" grammar variant must
    transform into a non-None result."""

    @pytest.fixture(autouse=True)
    def set_up(self):
        from optimade.filtertransformers.elasticsearch import Transformer, Quantity

        self.parser = LarkParser(version=(0, 10, 0), variant="elastic")

        n_elements = Quantity(name="nelements")
        only_elements = Quantity(name="elements_only")
        ratios = Quantity(name="elements_ratios")
        ratios.nested_quantity = ratios
        element_list = Quantity(
            name="elements",
            length_quantity=n_elements,
            has_only_quantity=only_elements,
            nested_quantity=ratios,
        )
        dim_types = Quantity(name="dimension_types")
        dim_types.length_quantity = dim_types

        self.transformer = Transformer(quantities=[
            n_elements,
            only_elements,
            ratios,
            element_list,
            dim_types,
            Quantity(name="chemical_formula_reduced"),
        ])

    def test_parse_n_transform(self):
        # Each entry pairs a filter with a number; only the filter is used by
        # this test.
        cases = [
            ("nelements > 1", 4),
            ("nelements >= 2", 4),
            ("nelements > 2", 1),
            ("nelements < 4", 4),
            ("nelements < 3", 3),
            ("nelements <= 3", 4),
            ("nelements != 2", 1),
            ("1 < nelements", 4),
            ('elements HAS "H"', 4),
            ('elements HAS ALL "H", "O"', 4),
            ('elements HAS ALL "H", "C"', 1),
            ('elements HAS ANY "H", "C"', 4),
            ('elements HAS ANY "C"', 1),
            ('elements HAS ONLY "C"', 0),
            ('elements HAS ONLY "H", "O"', 3),
            ('elements:elements_ratios HAS "H":>0.66', 2),
            ('elements:elements_ratios HAS ALL "O":>0.33', 3),
            ('elements:elements_ratios HAS ALL "O":>0.33,"O":<0.34', 2),
            ("elements IS KNOWN", 4),
            ("elements IS UNKNOWN", 0),
            ('chemical_formula_reduced = "H2O"', 2),
            ('chemical_formula_reduced CONTAINS "H2"', 3),
            ('chemical_formula_reduced CONTAINS "H"', 4),
            ('chemical_formula_reduced CONTAINS "C"', 1),
            ('chemical_formula_reduced STARTS "H2"', 3),
            ('chemical_formula_reduced STARTS WITH "H2"', 3),
            ('chemical_formula_reduced ENDS WITH "C"', 1),
            ('chemical_formula_reduced ENDS "C"', 1),
            ("LENGTH elements = 2", 3),
            ("LENGTH elements = 3", 1),
            ("LENGTH dimension_types = 0", 3),
            ("LENGTH dimension_types = 1", 1),
            ("nelements = 2 AND LENGTH dimension_types = 1", 1),
            ("nelements = 3 AND LENGTH dimension_types = 1", 0),
            ("nelements = 3 OR LENGTH dimension_types = 1", 2),
            ("nelements > 1 OR LENGTH dimension_types = 1 AND nelements = 2",
             4),
            ("(nelements > 1 OR LENGTH dimension_types = 1) AND nelements = 2",
             3),
            ("NOT LENGTH dimension_types = 1", 3),
        ]

        for filter_text, _unused_count in cases:
            parsed = self.parser.parse(filter_text)
            assert self.transformer.transform(parsed) is not None
コード例 #28
0
    # NOTE(review): every assignment below rebinds `f`, so only the final value
    # ("ghf.flk<gh") reaches `p.parse(f)`. The earlier strings and their
    # "# -> error" / "valid but wrong" remarks record the author's expectations
    # for other candidate filters; nothing in this excerpt checks them.
    # `p` and `t` are not defined in this excerpt -- presumably a parser and a
    # transformer created earlier in the original script; confirm.
    f = "list HAS < 3, > 4"  # -> error
    f = "list HAS ALL < 3, > 4"

    # multiple lists

    f = "list1:list2 HAS < 3 : > 4"
    f = "list1:list2 HAS ALL < 3 : > 4"

    f = "list1:list2 HAS < 3 : > 4, < 2 : > 5"  # -> error
    f = "list1:list2 HAS ALL < 3 : > 4, < 2 : > 5"
    f = "list1:list2 HAS ALL < 3, < 2 : > 4, > 5"  # -> error

    # f = 'list1:list2 HAS < 3, > 4'  # -> error
    # f = 'list1:list2 HAS ALL < 3, > 4'  # -> error

    f = 'elements:elements_ratios HAS ALL "Al":>0.3333, "Al":<0.3334'
    f = 'elements:elements_ratios HAS ALL "Al":>0.3333 AND elements_ratio<0.3334'
    f = 'elements:elements_ratios HAS ALL "Al" : >0.3333, <0.3334'  # -> error

    f = "list1:list2 HAS ALL < 3 : > 4, < 2 : > 5 : > 4, < 2 : > 5"  # valid but wrong
    f = "ghf.flk<gh"  # valid but wrong

    # f = ''

    # Parse the last assigned filter and dump the tree (raw and indented).
    tree = p.parse(f)
    print(tree)
    print(tree.pretty())

    # Run the transformer on the tree; the return value is discarded.
    t.transform(tree)
コード例 #29
0
ファイル: test_aiida.py プロジェクト: csadorf/aiida-optimade
def test_aliased_length_operator():
    """Test that LENGTH queries honour the mapper's LENGTH_ALIASES and ALIASES.

    `MyMapper` maps the LENGTH of `cartesian_site_positions` to `nsites`, and
    the LENGTH of `elements` and `chemsys` to `nelements` (itself aliased to the
    backend field `nelem`). For these, the comparison is expected to be applied
    directly to the aliased field. For `structure_features` and `nsites`, which
    have no LENGTH alias, the expected query for `>` uses the `longer` operator
    instead. Plain `HAS` queries still use the ordinary field alias
    (`elements` -> `my_elements`).
    """
    from optimade.server.mappers import StructureMapper

    class MyMapper(StructureMapper):
        """Test mapper with LENGTH_ALIASES"""

        ALIASES = (("elements", "my_elements"), ("nelements", "nelem"))
        LENGTH_ALIASES = (
            ("chemsys", "nelements"),
            ("cartesian_site_positions", "nsites"),
            ("elements", "nelements"),
        )
        PROVIDER_FIELDS = ("chemsys", )

    transformer = AiidaTransformer(mapper=MyMapper())
    parser = LarkParser(version=VERSION, variant=VARIANT)

    assert transformer.transform(
        parser.parse("cartesian_site_positions LENGTH <= 3")) == {
            "nsites": {
                "<=": 3
            }
        }
    assert transformer.transform(
        parser.parse("cartesian_site_positions LENGTH < 3")) == {
            "nsites": {
                "<": 3
            }
        }
    assert transformer.transform(
        parser.parse("cartesian_site_positions LENGTH 3")) == ({
            "nsites": 3
        })
    # NOTE(review): this assertion repeats the previous one verbatim; a
    # different case (e.g. an explicit `LENGTH = 3`) may have been intended.
    assert transformer.transform(
        parser.parse("cartesian_site_positions LENGTH 3")) == ({
            "nsites": 3
        })
    assert transformer.transform(
        parser.parse("cartesian_site_positions LENGTH >= 10")) == {
            "nsites": {
                ">=": 10
            }
        }
    assert transformer.transform(
        parser.parse("structure_features LENGTH > 10")) == ({
            "structure_features": {
                "longer": 10
            }
        })
    assert transformer.transform(parser.parse("nsites LENGTH > 10")) == ({
        "nsites": {
            "longer": 10
        }
    })
    assert transformer.transform(parser.parse("elements LENGTH 3")) == {
        "nelem": 3
    }
    assert transformer.transform(parser.parse('elements HAS "Ag"')) == ({
        "my_elements": {
            "contains": ["Ag"]
        }
    })
    assert transformer.transform(parser.parse("chemsys LENGTH 3")) == {
        "nelem": 3
    }