예제 #1
0
def search_resource(eos_store):
    """Create the ``Search`` resource serving ``SearchDoc`` documents.

    Args:
        eos_store: The store the search endpoint reads from.

    Returns:
        A ``Resource`` tagged "Search" with the query operators listed below;
        sparse-field projection defaults to ``task_id``.
    """
    operators = [
        FormulaQuery(),
        MinMaxQuery(),
        SymmetryQuery(),
        ThermoEnergySearchQuery(),
        IsStableQuery(),
        SearchBandGapQuery(),
        BulkModulusQuery(),
        ShearModulusQuery(),
        PoissonQuery(),
        DielectricQuery(),
        PiezoelectricQuery(),
        SurfaceMinMaxQuery(),
        SearchTaskIDsQuery(),
        HasPropsQuery(),
        DeprecationQuery(),
        PaginationQuery(),
        SparseFieldsQuery(SearchDoc, default_fields=["task_id"]),
    ]
    return Resource(
        eos_store, SearchDoc, query_operators=operators, tags=["Search"]
    )
예제 #2
0
def xas_resource(xas_store):
    """Create the ``XAS`` resource serving ``XASDoc`` documents.

    Args:
        xas_store: The store holding the XAS documents.

    Returns:
        A ``Resource`` tagged "XAS" with formula and XAS filters, pagination,
        and a default projection of the identifier/metadata fields below.
    """
    default_fields = [
        "xas_id",
        "task_id",
        "edge",
        "absorbing_element",
        "formula_pretty",
        "spectrum_type",
        "last_updated",
    ]
    operators = [
        FormulaQuery(),
        XASQuery(),
        PaginationQuery(),
        SparseFieldsQuery(XASDoc, default_fields=default_fields),
    ]
    return Resource(xas_store, XASDoc, query_operators=operators, tags=["XAS"])
예제 #3
0
파일: resources.py 프로젝트: jmmshn/api
    def custom_charge_density_endpoint_prep(self):
        """Register ``GET /{task_id}/`` serving ``ChgcarDataDoc`` data from S3.

        Attaches ``s3_store`` (taken from the enclosing scope) to the resource
        as ``self.s3`` and adds a route that looks the object up by ``task_id``.
        Responds with 404 when no matching object exists.
        """
        self.s3 = s3_store
        model = ChgcarDataDoc
        model_name = model.__name__
        key_name = "task_id"

        field_input = SparseFieldsQuery(
            model, [key_name, self.s3.last_updated_field]).query

        async def get_chgcar_data(
                material_id: str = Path(
                    ...,
                    alias=key_name,
                    title=
                    f"The Material ID ({key_name}) associated with the {model_name}",
                ),
                fields: STORE_PARAMS = Depends(field_input),
        ):
            self.s3.connect()

            # query_one returns None when nothing matches, so guard before .get().
            key_doc = self.s3.query_one(
                criteria={key_name: material_id},
                properties=[key_name],
            )
            chgcar_key = key_doc.get(key_name) if key_doc is not None else None

            if chgcar_key is None:
                # Report the id the caller asked for (chgcar_key is None here).
                raise HTTPException(
                    status_code=404,
                    detail=
                    f"Charge density data with {key_name} = {material_id} not found",
                )

            item = self.s3.query_one({key_name: chgcar_key},
                                     properties=fields["properties"])

            if item is None:
                raise HTTPException(
                    status_code=404,
                    detail=
                    f"Charge density data with {key_name} = {material_id} not found",
                )

            return item

        # An f-string placed first in a function body is NOT a docstring, so the
        # description is attached explicitly; FastAPI reads __doc__ when the
        # route is registered below.
        get_chgcar_data.__doc__ = (
            "Gets a document by the primary key in the store\n\n"
            "Args:\n"
            f"    material_id: The Materials Project ID ({key_name}) of a single {model_name}\n\n"
            "Returns:\n"
            f"    a single {model_name} document")

        self.router.get(
            f"/{{{key_name}}}/",
            response_description=f"Get an {model_name} by {key_name}",
            response_model=model,
            response_model_exclude_unset=True,
            tags=self.tags,
        )(get_chgcar_data)
예제 #4
0
    def custom_dos_endpoint_prep(self):
        """Register ``GET /object/`` serving ``DOSObjectReturn`` objects from S3.

        Attaches ``s3_store`` (from the enclosing scope) as ``self.s3``. The
        route reads ``total.task_id`` for the requested key from ``self.store``,
        then fetches the object from S3 by that ``task_id``. Responds with 404
        when either lookup finds nothing.
        """
        self.s3 = s3_store
        model = DOSObjectReturn
        model_name = model.__name__
        key_name = self.s3.key

        field_input = SparseFieldsQuery(
            model, [self.s3.key, self.s3.last_updated_field]
        ).query

        async def get_object(
            key: str = Query(
                ..., alias=key_name, title=f"The {key_name} of the {model_name} to get",
            ),
            fields: STORE_PARAMS = Depends(field_input),
        ):
            self.store.connect()

            self.s3.connect()

            dos_entry = self.store.query_one(
                criteria={self.store.key: key}, properties=["total.task_id"]
            )

            # The document, or its "total" sub-document, may be missing; treat
            # either as "not found" instead of raising AttributeError.
            total = (dos_entry or {}).get("total") or {}
            dos_task = total.get("task_id")

            item = None
            if dos_task is not None:
                item = self.s3.query_one(
                    {"task_id": dos_task}, properties=fields["properties"]
                )

            if item is None:
                raise HTTPException(
                    status_code=404,
                    detail=f"DOS with {self.store.key} = {key} not found",
                )

            return item

        # An f-string is never a docstring; attach the description explicitly so
        # FastAPI picks it up at registration time.
        get_object.__doc__ = (
            "Gets a document by the primary key in the store\n\n"
            "Args:\n"
            f"    {key_name}: the id of a single {model_name}\n\n"
            "Returns:\n"
            f"    a single {model_name} document"
        )

        self.router.get(
            "/object/",
            response_description=f"Get an {model_name} by {key_name}",
            response_model=model,
            response_model_exclude_unset=True,
            tags=self.tags,
        )(get_object)
예제 #5
0
파일: resources.py 프로젝트: jmmshn/api
def dois_resource(dois_store):
    """Create the ``DOIs`` resource serving ``DOIDoc`` documents.

    Only pagination and sparse-field projection (defaulting to ``task_id`` and
    ``doi``) are offered, and the default search route is disabled.
    """
    return Resource(
        dois_store,
        DOIDoc,
        tags=["DOIs"],
        query_operators=[
            PaginationQuery(),
            SparseFieldsQuery(DOIDoc, default_fields=["task_id", "doi"]),
        ],
        enable_default_search=False,
    )
예제 #6
0
def wulff_resource(wulff_store):
    """Create the ``WulffDoc`` resource, tagged "Surface Properties".

    Only pagination and sparse-field projection (default ``task_id``) are
    offered, and the default search route is disabled.
    """
    operators = [
        PaginationQuery(),
        SparseFieldsQuery(WulffDoc, default_fields=["task_id"]),
    ]
    return Resource(
        wulff_store,
        WulffDoc,
        query_operators=operators,
        tags=["Surface Properties"],
        enable_default_search=False,
    )
예제 #7
0
파일: resources.py 프로젝트: jmmshn/api
def similarity_resource(similarity_store):
    """Create the ``Similarity`` resource serving ``SimilarityDoc`` documents.

    Pagination plus sparse-field projection (default ``task_id``); the
    default search route is disabled.
    """
    doc_model = SimilarityDoc
    return Resource(
        similarity_store,
        doc_model,
        query_operators=[
            PaginationQuery(),
            SparseFieldsQuery(doc_model, default_fields=["task_id"]),
        ],
        tags=["Similarity"],
        enable_default_search=False,
    )
예제 #8
0
파일: resources.py 프로젝트: jmmshn/api
def fermi_resource(fermi_store):
    """Create the ``FermiDoc`` resource, tagged "Electronic Structure".

    Offers pagination and sparse-field projection, defaulting to ``task_id``
    and ``last_updated``.
    """
    pagination = PaginationQuery()
    projection = SparseFieldsQuery(
        FermiDoc, default_fields=["task_id", "last_updated"]
    )
    return Resource(
        fermi_store,
        FermiDoc,
        query_operators=[pagination, projection],
        tags=["Electronic Structure"],
    )
예제 #9
0
def piezo_resource(piezo_store):
    """Create the ``Piezoelectric`` resource serving ``PiezoDoc`` documents.

    Supports ``PiezoelectricQuery`` filters, pagination and sparse-field
    projection defaulting to ``task_id`` and ``last_updated``.
    """
    operators = [PiezoelectricQuery(), PaginationQuery()]
    operators.append(
        SparseFieldsQuery(PiezoDoc, default_fields=["task_id", "last_updated"])
    )
    return Resource(
        piezo_store, PiezoDoc, query_operators=operators, tags=["Piezoelectric"]
    )
예제 #10
0
def eos_resource(eos_store):
    """Create the ``EOS`` resource serving ``EOSDoc`` documents.

    Supports energy/volume filtering, sorting, pagination and sparse-field
    projection (default ``task_id``).
    """
    operators = [
        EnergyVolumeQuery(),
        SortQuery(),
        PaginationQuery(),
        SparseFieldsQuery(EOSDoc, default_fields=["task_id"]),
    ]
    return Resource(eos_store, EOSDoc, query_operators=operators, tags=["EOS"])
예제 #11
0
def dielectric_resource(dielectric_store):
    """Create the ``Dielectric`` resource serving ``DielectricDoc`` documents.

    Supports ``DielectricQuery`` filters, pagination and sparse-field
    projection defaulting to ``task_id`` and ``last_updated``.
    """
    default_fields = ["task_id", "last_updated"]
    operators = [
        DielectricQuery(),
        PaginationQuery(),
        SparseFieldsQuery(DielectricDoc, default_fields=default_fields),
    ]
    return Resource(
        dielectric_store,
        DielectricDoc,
        query_operators=operators,
        tags=["Dielectric"],
    )
예제 #12
0
def surface_props_resource(surface_prop_store):
    """Create the ``SurfacePropDoc`` resource, tagged "Surface Properties".

    Supports surface min/max and reconstruction filters, pagination and
    sparse-field projection (default ``task_id``).
    """
    return Resource(
        surface_prop_store,
        SurfacePropDoc,
        query_operators=[
            SurfaceMinMaxQuery(),
            ReconstructedQuery(),
            PaginationQuery(),
            SparseFieldsQuery(SurfacePropDoc, default_fields=["task_id"]),
        ],
        tags=["Surface Properties"],
    )
예제 #13
0
def magnetism_resource(magnetism_store):
    """Create the ``Magnetism`` resource serving ``MagnetismDoc`` documents.

    Supports ``MagneticQuery`` filters, pagination and sparse-field projection
    defaulting to ``task_id`` and ``last_updated``.
    """
    projection = SparseFieldsQuery
    operators = [
        MagneticQuery(),
        PaginationQuery(),
        projection(MagnetismDoc, default_fields=["task_id", "last_updated"]),
    ]
    return Resource(
        magnetism_store,
        MagnetismDoc,
        query_operators=operators,
        tags=["Magnetism"],
    )
예제 #14
0
def substrates_resource(substrates_store):
    """Create the ``Substrates`` resource serving ``SubstratesDoc`` documents.

    Supports substrate-structure and energy/area filters, pagination and
    sparse-field projection (default ``film_id`` and ``sub_id``). The
    get-by-key route is disabled.
    """
    operators = [
        SubstrateStructureQuery(),
        EnergyAreaQuery(),
        PaginationQuery(),
        SparseFieldsQuery(SubstratesDoc, default_fields=["film_id", "sub_id"]),
    ]
    return Resource(
        substrates_store,
        SubstratesDoc,
        query_operators=operators,
        tags=["Substrates"],
        enable_get_by_key=False,
    )
예제 #15
0
def task_resource(task_store):
    """Create the ``Tasks`` resource serving ``TaskDoc`` documents.

    Supports formula filtering, pagination and sparse-field projection
    defaulting to ``task_id``, ``formula_pretty`` and ``last_updated``.
    """
    projected = ["task_id", "formula_pretty", "last_updated"]
    return Resource(
        task_store,
        TaskDoc,
        query_operators=[
            FormulaQuery(),
            PaginationQuery(),
            SparseFieldsQuery(TaskDoc, default_fields=projected),
        ],
        tags=["Tasks"],
    )
예제 #16
0
def gb_resource(gb_store):
    """Create the ``GBDoc`` resource, tagged "Grain Boundaries".

    Supports grain-boundary task-id, energy and structure filters, pagination
    and sparse-field projection (default ``task_id`` and ``last_updated``).
    The get-by-key route is disabled.
    """
    operators = [
        GBTaskIDQuery(),
        GBEnergyQuery(),
        GBStructureQuery(),
        PaginationQuery(),
        SparseFieldsQuery(GBDoc, default_fields=["task_id", "last_updated"]),
    ]
    return Resource(
        gb_store,
        GBDoc,
        query_operators=operators,
        tags=["Grain Boundaries"],
        enable_get_by_key=False,
    )
예제 #17
0
def thermo_resource(thermo_store):
    """Create the ``Thermo`` resource serving ``ThermoDoc`` documents.

    Includes ``VersionQuery`` (which makes the resource versioned), chemical,
    stability and energy filters, pagination and sparse-field projection
    defaulting to ``task_id`` and ``last_updated``.
    """
    operators = [
        VersionQuery(),
        ThermoChemicalQuery(),
        IsStableQuery(),
        ThermoEnergyQuery(),
        PaginationQuery(),
        SparseFieldsQuery(ThermoDoc, default_fields=["task_id", "last_updated"]),
    ]
    return Resource(thermo_store, ThermoDoc, query_operators=operators, tags=["Thermo"])
예제 #18
0
파일: resources.py 프로젝트: jmmshn/api
def molecules_resource(molecules_store):
    """Create the ``Molecules`` resource serving ``MoleculesDoc`` documents.

    Supports molecule base/element/formula filters, task-id search, sorting,
    pagination and sparse-field projection (default ``task_id``).
    """
    filters = [
        MoleculeBaseQuery(),
        MoleculeElementsQuery(),
        MoleculeFormulaQuery(),
        SearchTaskIDsQuery(),
    ]
    filters += [
        SortQuery(),
        PaginationQuery(),
        SparseFieldsQuery(MoleculesDoc, default_fields=["task_id"]),
    ]
    return Resource(
        molecules_store, MoleculesDoc, query_operators=filters, tags=["Molecules"]
    )
예제 #19
0
def elasticity_resource(elasticity_store):
    """Create the ``Elasticity`` resource serving ``ElasticityDoc`` documents.

    Supports chemical-system, bulk/shear modulus and Poisson filters,
    pagination and sparse-field projection defaulting to ``task_id`` and
    ``pretty_formula``.
    """
    operators = [
        ChemsysQuery(),
        BulkModulusQuery(),
        ShearModulusQuery(),
        PoissonQuery(),
        PaginationQuery(),
        SparseFieldsQuery(
            ElasticityDoc, default_fields=["task_id", "pretty_formula"]
        ),
    ]
    return Resource(
        elasticity_store,
        ElasticityDoc,
        query_operators=operators,
        tags=["Elasticity"],
    )
예제 #20
0
def insertion_electrodes_resource(insertion_electrodes_store):
    """Create the ``InsertionElectrodeDoc`` resource, tagged "Electrodes".

    Supports voltage-step and insertion-electrode filters, sorting, pagination
    and sparse-field projection defaulting to ``task_id`` and ``last_updated``.
    """
    doc_model = InsertionElectrodeDoc
    operators = [
        VoltageStepQuery(),
        InsertionVoltageStepQuery(),
        InsertionElectrodeQuery(),
        SortQuery(),
        PaginationQuery(),
        SparseFieldsQuery(doc_model, default_fields=["task_id", "last_updated"]),
    ]
    return Resource(
        insertion_electrodes_store,
        doc_model,
        query_operators=operators,
        tags=["Electrodes"],
    )
예제 #21
0
def bs_resource(bs_store, s3_store):
    """Build the band-structure resource (``BSDoc``) plus an S3 object route.

    Args:
        bs_store: Store queried by the default routes; the custom route reads
            ``<path_type>.task_id`` from it for the requested key.
        s3_store: Store from which the band structure object is fetched by
            ``task_id``.

    Returns:
        A ``Resource`` tagged "Electronic Structure" with an extra
        ``GET /object/`` route.
    """

    def custom_bs_endpoint_prep(self):
        # Attach the S3 store and register GET /object/.
        self.s3 = s3_store
        model = BSObjectReturn
        model_name = model.__name__
        key_name = self.s3.key

        field_input = SparseFieldsQuery(
            model, [self.s3.key, self.s3.last_updated_field]).query

        async def get_object(
                key: str = Query(
                    ...,
                    alias=key_name,
                    title=f"The {key_name} of the {model_name} to get",
                ),
                path_type: BSPathType = Query(
                    ...,
                    title=
                    "The k-path convention type for the band structure object",
                ),
                fields: STORE_PARAMS = Depends(field_input),
        ):
            self.store.connect()

            self.s3.connect()

            path_name = str(path_type.name)

            bs_entry = self.store.query_one(
                criteria={self.store.key: key},
                properties=[f"{path_name}.task_id"],
            )

            # The document, or its sub-document for this path type, may be
            # missing; treat both as "not found" instead of raising
            # AttributeError on None.
            path_doc = (bs_entry or {}).get(path_name) or {}
            bs_task = path_doc.get("task_id")

            item = None
            if bs_task is not None:
                item = self.s3.query_one({"task_id": bs_task},
                                         properties=fields["properties"])

            if item is None:
                raise HTTPException(
                    status_code=404,
                    detail=
                    f"Band structure with {self.store.key} = {key} not found",
                )

            return item

        # An f-string is never a docstring; attach the description explicitly so
        # FastAPI picks it up at registration time.
        get_object.__doc__ = (
            "Gets a document by the primary key in the store\n\n"
            "Args:\n"
            f"    {key_name}: the id of a single {model_name}\n\n"
            "Returns:\n"
            f"    a single {model_name} document")

        self.router.get(
            "/object/",
            response_description=f"Get an {model_name} by {key_name}",
            response_model=model,
            response_model_exclude_unset=True,
            tags=self.tags,
        )(get_object)

    resource = Resource(
        bs_store,
        BSDoc,
        query_operators=[
            BSDataQuery(),
            FormulaQuery(),
            MinMaxQuery(),
            PaginationQuery(),
            SparseFieldsQuery(BSDoc,
                              default_fields=["task_id", "last_updated"]),
        ],
        tags=["Electronic Structure"],
        custom_endpoint_funcs=[custom_bs_endpoint_prep],
    )

    return resource
예제 #22
0
파일: resource.py 프로젝트: jmmshn/api
    def __init__(
        self,
        store: Store,
        model: Union[Type[BaseModel], str],
        tags: Optional[List[str]] = None,
        query_operators: Optional[List[QueryOperator]] = None,
        route_class: Optional[Type[APIRoute]] = None,
        key_fields: Optional[List[str]] = None,
        custom_endpoint_funcs: Optional[List[Callable]] = None,
        enable_get_by_key: bool = True,
        enable_default_search: bool = True,
    ):
        """
        Args:
            store: The Maggma Store to get data from
            model: the pydantic model to apply to the documents from the Store
                This can be a string with a full python path to a model or
                an actual pydantic Model if this is being instantiated in python
                code. Serializing this via Monty will autoconvert the pydantic model
                into a python path string
            tags: list of tags for the Endpoint
            query_operators: operators for the query language. Defaults to
                pagination plus sparse fields projecting the store key and
                last-updated field.
            route_class: Custom APIRoute class to define post-processing or custom validation
                of response data
            key_fields: List of fields to always project. Default uses SparseFieldsQuery
                to allow users to define these on-the-fly.
            custom_endpoint_funcs: Custom endpoint preparation functions to be used
            enable_get_by_key: Enable default key route for endpoint.
            enable_default_search: Enable default endpoint search behavior.

        Raises:
            ValueError: if ``model`` is not (or does not import to) a pydantic
                ``BaseModel`` subclass.
        """
        self.store = store
        self.tags = tags or []
        self.key_fields = key_fields
        self.cep = custom_endpoint_funcs
        self.enable_get_by_key = enable_get_by_key
        self.enable_default_search = enable_default_search

        if isinstance(model, str):
            # "package.module.ClassName" -> import ClassName from package.module
            module_path, _, class_name = model.rpartition(".")
            model = dynamic_import(module_path, class_name)

        # An explicit check (not ``assert``) so validation survives ``python -O``;
        # the isinstance guard stops issubclass() raising TypeError when the
        # import yields something that is not a class.
        if not (isinstance(model, type) and issubclass(model, BaseModel)):
            raise ValueError("The resource model has to be a PyDantic Model")
        self.model = model

        self.query_operators = (
            query_operators if query_operators is not None else [
                PaginationQuery(),
                SparseFieldsQuery(
                    self.model,
                    default_fields=[
                        self.store.key, self.store.last_updated_field
                    ],
                ),
            ])

        # A VersionQuery among the operators switches the key route to its
        # versioned variant.
        self.versioned = any(
            isinstance(qop_entry, VersionQuery)
            for qop_entry in self.query_operators)

        if route_class is not None:
            self.router = APIRouter(route_class=route_class)
        else:
            self.router = APIRouter()
        self.response_model = Response[self.model]  # type: ignore
        self.prepare_endpoint()
예제 #23
0
파일: resource.py 프로젝트: jmmshn/api
    def build_get_by_key(self):
        """Register ``GET /{key}/`` on ``self.router`` for single-document lookup.

        When ``self.versioned`` is set, the route also accepts a ``version``
        query parameter. That parameter (or the ``DB_VERSION`` environment
        variable) selects the collection ``<prefix>_<version>`` before querying.
        Responds with 404 when no document matches.
        """
        key_name = self.store.key
        model_name = self.model.__name__

        if self.key_fields is None:
            field_input = SparseFieldsQuery(
                self.model,
                [self.store.key, self.store.last_updated_field]).query
        else:

            def field_input():
                return {"properties": self.key_fields}

        # Extra criteria that restrict some document types to the "core" sandbox.
        # NOTE(review): TaskDoc uses "sbxn" while the others use "_sbxn" --
        # confirm this matches the collection schemas.
        sandbox_criteria = {
            "MaterialsCoreDoc": {"_sbxn": "core"},
            "TaskDoc": {"sbxn": "core"},
            "ThermoDoc": {"_sbxn": "core"},
        }

        def _lookup(key, properties):
            """Fetch one document by key (plus sandbox criteria) or raise 404."""
            crit = {self.store.key: key}
            crit.update(sandbox_criteria.get(model_name, {}))

            item = self.store.query_one(criteria=crit, properties=properties)

            if item is None:
                raise HTTPException(
                    status_code=404,
                    detail=f"Item with {self.store.key} = {key} not found",
                )

            return {"data": [item]}

        if not self.versioned:

            async def get_by_key(
                    key: str = Path(
                        ...,
                        alias=key_name,
                        title=f"The {key_name} of the {model_name} to get",
                    ),
                    fields: STORE_PARAMS = Depends(field_input),
            ):
                self.store.connect()
                return _lookup(key, fields["properties"])

            endpoint = get_by_key

        else:

            async def get_by_key_versioned(
                key: str = Path(
                    ...,
                    alias=key_name,
                    title=f"The {key_name} of the {model_name} to get",
                ),
                fields: STORE_PARAMS = Depends(field_input),
                version: str = Query(
                    None,
                    description=
                    "Database version to query on formatted as YYYY.MM.DD",
                ),
            ):
                if version is not None:
                    version = version.replace(".", "_")
                else:
                    version = os.environ.get("DB_VERSION")

                # Point the shared store at the requested version's collection.
                prefix = self.store.collection_name.split("_")[0]
                self.store.collection_name = f"{prefix}_{version}"

                self.store.connect(force_reset=True)
                return _lookup(key, fields["properties"])

            endpoint = get_by_key_versioned

        # An f-string is never a docstring, so the originals left the endpoint
        # description empty; set __doc__ explicitly before FastAPI reads it.
        endpoint.__doc__ = (
            "Gets a document by the primary key in the store\n\n"
            "Args:\n"
            f"    {key_name}: the id of a single {model_name}\n\n"
            "Returns:\n"
            f"    a single {model_name} document")

        self.router.get(
            f"/{{{key_name}}}/",
            response_description=f"Get an {model_name} by {key_name}",
            response_model=self.response_model,
            response_model_exclude_unset=True,
            tags=self.tags,
        )(endpoint)
예제 #24
0
def materials_resource(materials_store):
    """Build the ``Materials`` resource (``MaterialsCoreDoc``).

    Besides the query operators below, it registers two custom routes:
    ``GET /versions/`` and ``POST /find_structure/``.
    """

    def custom_version_prep(self):
        """Register ``GET /versions/`` listing available database versions."""
        model_name = self.model.__name__

        async def get_versions():
            store = self.store
            host = getattr(store, "host", None)

            if host is None:
                # Stores without a host attribute are expected to expose a
                # connection URI instead.
                conn = MongoClient(store.uri)
            else:
                username = getattr(store, "username", "")
                # Authenticate against the store's own database only when a
                # user is configured (what the former db.authenticate did).
                auth = {
                    "username": username,
                    "password": store.password,
                    "authSource": store.database,
                } if username else {}
                conn = MongoClient(host, store.port, **auth)

            try:
                col_names = conn[store.database].list_collection_names()
            finally:
                # A client is created per request; always release it.
                conn.close()

            # After "_" -> ".", the first 15 characters are dropped; presumably
            # this strips a "materials.core." prefix -- TODO confirm naming.
            d = [
                name.replace("_", ".")[15:]
                for name in col_names
                if "materials" in name and name != "materials.core"
            ]

            return {"data": d}

        # An f-string is never a docstring; set the description explicitly.
        get_versions.__doc__ = (
            f"Obtains the database versions for the data in {model_name}\n\n"
            "Returns:\n"
            "    A list of database versions one can use to query on"
        )

        self.router.get(
            "/versions/",
            response_model_exclude_unset=True,
            response_description=f"Get versions of {model_name}",
            tags=self.tags,
        )(get_versions)

    def custom_findstructure_prep(self):
        """Register ``POST /find_structure/`` for structure matching."""
        model_name = self.model.__name__

        async def find_structure(
            structure: Structure = Body(
                ...,
                title="Pymatgen structure object to query with",
            ),
            ltol: float = Query(
                0.2,
                title="Fractional length tolerance. Default is 0.2.",
            ),
            stol: float = Query(
                0.3,
                title="Site tolerance. Defined as the fraction of the average free \
                    length per atom := ( V / Nsites ) ** (1/3). Default is 0.3.",
            ),
            angle_tol: float = Query(
                5,
                title="Angle tolerance in degrees. Default is 5 degrees.",
            ),
            limit: int = Query(
                1,
                title="Maximum number of matches to show. Defaults to 1, only showing the best match.",
            ),
        ):
            """
            Obtains material structures that match a given input structure within some tolerance.

            Returns:
                A list of Material IDs for materials with matched structures alongside the associated RMS values
            """

            try:
                s = PS.from_dict(structure.dict())
            except Exception:
                raise HTTPException(
                    status_code=404,
                    detail="Body cannot be converted to a pymatgen structure object.",
                )

            m = StructureMatcher(
                ltol=ltol,
                stol=stol,
                angle_tol=angle_tol,
                primitive_cell=True,
                scale=True,
                attempt_supercell=False,
                comparator=ElementComparator(),
            )

            # Only candidates with the same reduced composition are compared.
            crit = {"composition_reduced": dict(s.composition.to_reduced_dict)}

            self.store.connect()

            matches = []

            for r in self.store.query(
                criteria=crit, properties=["structure", "task_id"]
            ):
                s2 = PS.from_dict(r["structure"])

                if m.fit(s, s2):
                    rms = m.get_rms_dist(s, s2)

                    matches.append(
                        {
                            "task_id": r["task_id"],
                            "normalized_rms_displacement": rms[0],
                            "max_distance_paired_sites": rms[1],
                        }
                    )

            # Rank ALL matches before truncating so ``limit`` returns the best
            # ones (slicing before sorting returned arbitrary matches).
            matches.sort(
                key=lambda x: (
                    x["normalized_rms_displacement"],
                    x["max_distance_paired_sites"],
                )
            )

            return {"data": matches[:limit]}

        self.router.post(
            "/find_structure/",
            response_model_exclude_unset=True,
            response_description=f"Get matching structures using data from {model_name}",
            tags=self.tags,
        )(find_structure)

    resource = Resource(
        materials_store,
        MaterialsCoreDoc,
        query_operators=[
            VersionQuery(),
            FormulaQuery(),
            MultiTaskIDQuery(),
            SymmetryQuery(),
            DeprecationQuery(),
            MinMaxQuery(),
            PaginationQuery(),
            SparseFieldsQuery(
                MaterialsCoreDoc,
                default_fields=["task_id", "formula_pretty", "last_updated"],
            ),
        ],
        tags=["Materials"],
        custom_endpoint_funcs=[custom_version_prep, custom_findstructure_prep],
    )

    return resource
예제 #25
0
파일: resources.py 프로젝트: jmmshn/api
def materials_resource(materials_store):
    """Build the ``Materials`` resource (``MaterialsCoreDoc``).

    Besides the query operators below, it registers three custom routes:
    ``GET /versions/``, ``POST /find_structure/`` and
    ``GET /formula_autocomplete/``.
    """

    def custom_version_prep(self):
        """Register ``GET /versions/`` listing available database versions."""
        model_name = self.model.__name__

        async def get_versions():
            store = self.store
            host = getattr(store, "host", None)

            if host is None:
                # Stores without a host attribute are expected to expose a
                # connection URI instead.
                conn = MongoClient(store.uri)
            else:
                username = getattr(store, "username", "")
                # Authenticate against the store's own database only when a
                # user is configured (what the former db.authenticate did).
                auth = {
                    "username": username,
                    "password": store.password,
                    "authSource": store.database,
                } if username else {}
                conn = MongoClient(host, store.port, **auth)

            try:
                col_names = conn[store.database].list_collection_names()
            finally:
                # A client is created per request; always release it.
                conn.close()

            # After "_" -> ".", the first 15 characters are dropped; presumably
            # this strips a "materials.core." prefix -- TODO confirm naming.
            d = [
                name.replace("_", ".")[15:] for name in col_names
                if "materials" in name and name != "materials.core"
            ]

            return {"data": d}

        # An f-string is never a docstring; set the description explicitly.
        get_versions.__doc__ = (
            f"Obtains the database versions for the data in {model_name}\n\n"
            "Returns:\n"
            "    A list of database versions one can use to query on")

        self.router.get(
            "/versions/",
            response_model_exclude_unset=True,
            response_description=f"Get versions of {model_name}",
            tags=self.tags,
        )(get_versions)

    def custom_findstructure_prep(self):
        """Register ``POST /find_structure/`` for structure matching."""
        model_name = self.model.__name__

        async def find_structure(
            structure: Structure = Body(
                ...,
                title="Pymatgen structure object to query with",
            ),
            ltol: float = Query(
                0.2,
                title="Fractional length tolerance. Default is 0.2.",
            ),
            stol: float = Query(
                0.3,
                title=
                "Site tolerance. Defined as the fraction of the average free \
                    length per atom := ( V / Nsites ) ** (1/3). Default is 0.3.",
            ),
            angle_tol: float = Query(
                5,
                title="Angle tolerance in degrees. Default is 5 degrees.",
            ),
            limit: int = Query(
                1,
                title=
                "Maximum number of matches to show. Defaults to 1, only showing the best match.",
            ),
        ):
            """
            Obtains material structures that match a given input structure within some tolerance.

            Returns:
                A list of Material IDs for materials with matched structures alongside the associated RMS values
            """

            try:
                s = PS.from_dict(structure.dict())
            except Exception:
                raise HTTPException(
                    status_code=404,
                    detail=
                    "Body cannot be converted to a pymatgen structure object.",
                )

            m = StructureMatcher(
                ltol=ltol,
                stol=stol,
                angle_tol=angle_tol,
                primitive_cell=True,
                scale=True,
                attempt_supercell=False,
                comparator=ElementComparator(),
            )

            # Only candidates with the same reduced composition are compared.
            crit = {"composition_reduced": dict(s.composition.to_reduced_dict)}

            self.store.connect()

            matches = []

            for r in self.store.query(criteria=crit,
                                      properties=["structure", "task_id"]):

                s2 = PS.from_dict(r["structure"])

                if m.fit(s, s2):
                    rms = m.get_rms_dist(s, s2)

                    matches.append({
                        "task_id": r["task_id"],
                        "normalized_rms_displacement": rms[0],
                        "max_distance_paired_sites": rms[1],
                    })

            # Rank ALL matches before truncating so ``limit`` returns the best
            # ones (slicing before sorting returned arbitrary matches).
            matches.sort(key=lambda x: (
                x["normalized_rms_displacement"],
                x["max_distance_paired_sites"],
            ))

            return {"data": matches[:limit]}

        self.router.post(
            "/find_structure/",
            response_model_exclude_unset=True,
            response_description=
            f"Get matching structures using data from {model_name}",
            tags=self.tags,
        )(find_structure)

    def custom_autocomplete_prep(self):
        """Register ``GET /formula_autocomplete/`` backed by an Atlas Search index."""

        async def formula_autocomplete(
            text: str = Query(
                ...,
                description="Text to run against formula autocomplete",
            ),
            limit: int = Query(
                10,
                description="Maximum number of matches to show. Defaults to 10",
            ),
        ):
            """
            Autocompletes a chemical formula against the stored formula_pretty values.

            Returns:
                Up to `limit` distinct matching formulas, shortest first
            """

            comp = Composition(text)

            ind_str = []

            if len(comp) == 1:
                d = comp.get_integer_formula_and_factor()

                s = d[0] + str(int(d[1])) if d[1] != 1 else d[0]

                ind_str.append(s)
            else:

                comp_red = comp.reduced_composition.items()

                for (i, j) in comp_red:

                    if j != 1:
                        ind_str.append(i.name + str(int(j)))
                    else:
                        ind_str.append(i.name)

            # Every element ordering is tried as a query term.
            # NOTE(review): this grows factorially with the number of elements
            # in user-supplied `text`; consider capping it.
            final_terms = ["".join(entry) for entry in permutations(ind_str)]

            pipeline = [
                {
                    "$search": {
                        "index": "formula_autocomplete",
                        "autocomplete": {
                            "path": "formula_pretty",
                            "query": final_terms,
                            "tokenOrder": "any",
                        },
                    }
                },
                {
                    "$group": {
                        "_id": "$formula_pretty",
                    }
                },
                {
                    # Shorter formulas rank first.
                    "$project": {
                        "score": {
                            "$strLenCP": "$_id"
                        }
                    }
                },
                {
                    "$sort": {
                        "score": 1
                    }
                },
                {
                    "$limit": limit
                },
            ]

            self.store.connect()

            data = list(
                self.store._collection.aggregate(pipeline, allowDiskUse=True))

            return {"data": data}

        self.router.get(
            "/formula_autocomplete/",
            response_model_exclude_unset=True,
            response_description="Get autocomplete results for a formula",
            tags=self.tags,
        )(formula_autocomplete)

    resource = Resource(
        materials_store,
        MaterialsCoreDoc,
        query_operators=[
            VersionQuery(),
            FormulaQuery(),
            MultiTaskIDQuery(),
            SymmetryQuery(),
            DeprecationQuery(),
            MinMaxQuery(),
            SortQuery(),
            PaginationQuery(),
            SparseFieldsQuery(
                MaterialsCoreDoc,
                default_fields=["task_id", "formula_pretty", "last_updated"],
            ),
        ],
        tags=["Materials"],
        custom_endpoint_funcs=[
            custom_version_prep,
            custom_findstructure_prep,
            custom_autocomplete_prep,
        ],
    )

    return resource
예제 #26
0
def synth_resource(synth_store):
    """Build the ``Resource`` that serves ``SynthesisDoc`` documents.

    The resource uses the default search endpoint (key lookup disabled) with
    formula, sort, pagination and sparse-field query operators. It also adds
    a custom ``/text_search/`` endpoint that runs a ``$search`` regex query
    over the ``text`` field using the ``synth_descriptions`` index.

    Args:
        synth_store: Store holding the synthesis documents.

    Returns:
        The configured ``Resource``.
    """

    def custom_synth_prep(self):
        # Registered through ``custom_endpoint_funcs``; receives the Resource
        # as ``self`` (it reads ``self.store``, ``self.router``, ``self.tags``
        # and ``self.response_model``).
        async def query_synth_text(
            keywords: str = Query(
                ...,
                description=
                "Comma delimited string keywords to search synthesis description text with",
            ),
            skip: int = Query(
                0,
                ge=0,
                description="Number of entries to skip in the search"),
            limit: int = Query(
                100,
                ge=1,
                le=100,
                description=
                "Max number of entries to return in a single query. Limited to 100",
            ),
        ):
            """Return synthesis descriptions whose text matches any keyword.

            Each comma-separated keyword becomes a prefix pattern
            (``keyword + ".*"``). Results are sorted by search score, highest
            first, and paginated with ``skip``/``limit``. ``skip`` must be
            >= 0 and ``limit`` must be in [1, 100]; the pipeline's
            ``$skip``/``$limit`` stages cannot accept values outside those
            ranges.
            """
            # Trim whitespace and drop empty entries. Otherwise "a, b" yields
            # " b.*", and "a,,b" or a trailing comma yields ".*", which
            # matches every document and silently disables the filter.
            terms = [word.strip() for word in keywords.split(",")]
            terms = [word for word in terms if word]

            if not terms:
                return {"data": []}

            # NOTE(review): keywords are used as regex fragments unescaped, so
            # metacharacters in user input are interpreted by the search
            # engine. Confirm this is intended.
            pipeline = [
                {
                    "$search": {
                        "index": "synth_descriptions",
                        "regex": {
                            "query": [word + ".*" for word in terms],
                            "path": "text",
                            "allowAnalyzedField": True,
                        },
                    }
                },
                {
                    "$project": {
                        "_id": 0,
                        "doi": 1,
                        "formula": 1,
                        "text": 1,
                        "search_score": {
                            "$meta": "searchScore"
                        },
                    }
                },
                {
                    "$sort": {
                        "search_score": -1
                    }
                },
                {
                    "$skip": skip
                },
                {
                    "$limit": limit
                },
            ]

            self.store.connect()

            # NOTE(review): if ``_collection`` is a synchronous (pymongo)
            # collection, this call blocks the event loop inside an
            # ``async def`` endpoint. Consider a plain ``def`` endpoint.
            data = list(
                self.store._collection.aggregate(pipeline, allowDiskUse=True))

            response = {"data": data}

            return response

        self.router.get(
            "/text_search/",
            response_model=self.response_model,
            response_model_exclude_unset=True,
            response_description=
            "Find synthesis description documents through text search.",
            tags=self.tags,
        )(query_synth_text)

    resource = Resource(
        synth_store,
        SynthesisDoc,
        query_operators=[
            SynthFormulaQuery(),
            SortQuery(),
            PaginationQuery(),
            SparseFieldsQuery(SynthesisDoc, default_fields=["formula", "doi"]),
        ],
        tags=["Synthesis"],
        custom_endpoint_funcs=[custom_synth_prep],
        enable_default_search=True,
        enable_get_by_key=False,
    )

    return resource