Example #1
0
def get_valid_session_links(session_id: str, api_router: APIRouter) \
        -> typing.Dict[str, ResourceLink]:
    """Build the HTTP-method-keyed links that apply to one session.

    Args:
        session_id: value substituted into each route's ``session_id`` path
            parameter.
        api_router: router used to reverse-resolve handler names into paths.

    Returns:
        Mapping of "GET", "POST" and "DELETE" to a ResourceLink pointing at
        the handler registered for that method on this session.
    """
    def link_to(handler):
        # Routes are reverse-resolved by the handler function's name.
        path = api_router.url_path_for(handler.__name__,
                                       session_id=session_id)
        return ResourceLink(href=path)

    return {
        "GET": link_to(get_session_handler),
        "POST": link_to(session_command_create_handler),
        "DELETE": link_to(delete_session_handler),
    }
Example #2
0
def get_valid_session_links(session_id: IdentifierType,
                            api_router: APIRouter) \
        -> ResourceLinks:
    """Build the resource links that apply to one session.

    Args:
        session_id: value substituted into the ``sessionId`` path parameter.
        api_router: router used to reverse-resolve handler names into paths.

    Returns:
        Links for the session itself and its command-execute route (resolved
        per call), plus the precomputed sessions-root and session-by-id links.
    """
    def session_link(handler):
        # Handlers are looked up by function name; sessionId fills the path.
        return ResourceLink(
            href=api_router.url_path_for(handler.__name__,
                                         sessionId=session_id))

    links = {
        ResourceLinkKey.self: session_link(get_session_handler),
        ResourceLinkKey.session_command_execute:
            session_link(session_command_execute_handler),
    }
    links[ResourceLinkKey.sessions] = ROOT_RESOURCE
    links[ResourceLinkKey.session_by_id] = SESSIONS_BY_ID_RESOURCE
    return links
Example #3
0
def get_protocol_links(api_router: APIRouter, protocol_id: str) \
        -> ResourceLinks:
    """Build the resource links for a single protocol.

    Args:
        api_router: router used to reverse-resolve the ``get_protocol``
            handler into a path.
        protocol_id: value substituted into the ``protocolId`` path parameter.

    Returns:
        A self link for this protocol plus the precomputed protocols-root and
        protocol-by-id links.
    """
    self_href = api_router.url_path_for(get_protocol.__name__,
                                        protocolId=protocol_id)
    return {
        ResourceLinkKey.self: ResourceLink(href=self_href),
        ResourceLinkKey.protocols: ROOT_RESOURCE,
        ResourceLinkKey.protocol_by_id: PROTOCOL_BY_ID_RESOURCE,
    }
Example #4
0
def get_session(manager: SessionManager, session_id: str,
                api_router: APIRouter) -> CalibrationSession:
    """Look up a calibration session by id.

    Args:
        manager: owner of the ``sessions`` mapping to search.
        session_id: key to look up in ``manager.sessions``.
        api_router: router used to build the session-creation link for the
            error response.

    Returns:
        The session stored under ``session_id``.

    Raises:
        RobotServerError: HTTP 404 when nothing truthy is stored under
            ``session_id``; the error's links point "POST" at the
            session-creation route.
    """
    session = manager.sessions.get(session_id)
    if session:
        return session

    create_path = api_router.url_path_for(create_session_handler.__name__)
    missing = Error(title="No session",
                    detail=f"Cannot find session with id '{session_id}'.",
                    links={"POST": create_path})
    raise RobotServerError(status_code=http_status_codes.HTTP_404_NOT_FOUND,
                           error=missing)
Example #5
0
def standard_json_response(
    task: AsyncResult,
    router: APIRouter,
    get_route: str,
    timeout: float = 0.2,
    api_version: str = config.API_LATEST,
) -> Dict[str, Any]:
    """Summarize an asynchronous task as a plain dict.

    Args:
        task: the task to report on.
        router: router used to reverse-resolve ``get_route``.
        get_route: route name taking a ``task_id`` path parameter, where the
            task's result can later be fetched.
        timeout: passed through to the short wait performed before reporting.
        api_version: string prepended to the resolved route path.

    Returns:
        Dict with ``task_id``, ``status`` and ``result_route``; a ``data`` key
        holding ``task.result`` is added only when ``task.successful()``.
    """
    result_route = (
        f"{api_version}{router.url_path_for(get_route, task_id=task.id)}")

    # Give the task a brief chance to finish. The helper's return value is not
    # used; completion is re-checked below via task.successful().
    wait_a_sec_and_see_if_we_can_return_some_data(task, timeout=timeout)

    payload: Dict[str, Any] = {
        "task_id": task.task_id,
        "status": task.status,
        "result_route": result_route,
    }
    if task.successful():
        payload["data"] = task.result
    return payload
Example #6
0
class Resource(MSONable):
    """
    Implements a REST Compatible Resource as a URL endpoint.

    This class provides a number of convenience features including full
    pagination, field projection, and the MAPI query language:

    - implements custom error handlers to provide MAPI Responses
    - implements a standard metadata response for the class
    - JSON configuration (serializable via ``as_dict``)
    """
    def __init__(
        self,
        store: Store,
        model: Union[BaseModel, str],
        tags: Optional[List[str]] = None,
        query_operators: Optional[List[QueryOperator]] = None,
        route_class: Optional[Type[APIRoute]] = None,
        key_fields: Optional[List[str]] = None,
        custom_endpoint_funcs: Optional[List[Callable]] = None,
        enable_get_by_key: bool = True,
        enable_default_search: bool = True,
    ):
        """
        Args:
            store: The Maggma Store to get data from
            model: the pydantic model to apply to the documents from the Store.
                This can be a string with a full python path to a model or
                an actual pydantic Model if this is being instantiated in
                python code. Serializing this via Monty will autoconvert the
                pydantic model into a python path string
            tags: list of tags for the Endpoint
            query_operators: operators for the query language. Defaults to
                pagination plus sparse-field projection whose default fields
                are the store key and the store's last-updated field.
            route_class: Custom APIRoute class to define post-processing or
                custom validation of response data
            key_fields: List of fields to always project in the get-by-key
                route. Default uses SparseFieldsQuery to allow users to define
                these on-the-fly.
            custom_endpoint_funcs: Custom endpoint preparation functions; each
                is called with this Resource before the default routes are
                registered.
            enable_get_by_key: Enable default key route for endpoint.
            enable_default_search: Enable default endpoint search behavior.

        Raises:
            ValueError: if ``model`` does not resolve to a pydantic BaseModel
                subclass.
        """
        self.store = store
        self.tags = tags or []
        self.key_fields = key_fields
        self.cep = custom_endpoint_funcs
        self.enable_get_by_key = enable_get_by_key
        self.enable_default_search = enable_default_search

        if isinstance(model, str):
            # "package.module.ClassName" -> ("package.module", "ClassName")
            module_path, _, class_name = model.rpartition(".")
            model = dynamic_import(module_path, class_name)
        # Explicit check instead of `assert`: asserts vanish under `python -O`,
        # and issubclass() raises TypeError for non-class objects.
        if not (isinstance(model, type) and issubclass(model, BaseModel)):
            raise ValueError("The resource model has to be a PyDantic Model")
        self.model = model

        self.query_operators = (
            query_operators if query_operators is not None else [
                PaginationQuery(),
                SparseFieldsQuery(
                    self.model,
                    default_fields=[
                        self.store.key, self.store.last_updated_field
                    ],
                ),
            ])

        # A VersionQuery operator switches the routes to per-version
        # collections.
        self.versioned = any(
            isinstance(qop_entry, VersionQuery)
            for qop_entry in self.query_operators)

        if route_class is not None:
            self.router = APIRouter(route_class=route_class)
        else:
            self.router = APIRouter()
        self.response_model = Response[self.model]  # type: ignore
        self.prepare_endpoint()

    def prepare_endpoint(self):
        """
        Internal method to prepare the endpoint by setting up default handlers
        for routes. Custom endpoint functions run first, then the optional
        get-by-key and search routes are registered.
        """

        if self.cep is not None:
            for func in self.cep:
                func(self)

        if self.enable_get_by_key:
            self.build_get_by_key()

        if self.enable_default_search:
            self.set_dynamic_model_search()

    @staticmethod
    def _sandbox_criteria(model_name: str) -> Dict[str, str]:
        """
        Return the extra filter added to every query for ``model_name``.

        MaterialsCoreDoc and ThermoDoc are restricted to ``_sbxn == "core"``,
        TaskDoc to ``sbxn == "core"``; any other model gets no extra filter.
        """
        # NOTE(review): TaskDoc uses "sbxn" without the leading underscore,
        # unlike the other models. Preserved as-is; confirm against the data.
        if model_name in ("MaterialsCoreDoc", "ThermoDoc"):
            return {"_sbxn": "core"}
        if model_name == "TaskDoc":
            return {"sbxn": "core"}
        return {}

    def _select_version(self, version: Optional[str]) -> None:
        """
        Point ``self.store`` at the collection for a database version.

        ``version`` is formatted as YYYY.MM.DD and is rewritten to YYYY_MM_DD.
        When it is None, the DB_VERSION environment variable is used instead.
        The collection name becomes ``<prefix>_<version>``, where prefix is the
        current collection name up to its first underscore. The caller must
        reconnect the store afterwards.
        """
        # NOTE(review): this mutates the store shared by all requests, so
        # concurrent requests for different versions may interfere. Confirm
        # whether that is acceptable for the deployment.
        if version is not None:
            version = version.replace(".", "_")
        else:
            version = os.environ.get("DB_VERSION")

        prefix = self.store.collection_name.split("_")[0]
        self.store.collection_name = f"{prefix}_{version}"

    def _get_item_response(self, key: str, properties, model_name: str):
        """
        Fetch one document whose store key equals ``key``.

        Returns:
            ``{"data": [item]}`` projected onto ``properties``.

        Raises:
            HTTPException: 404 when no matching document exists.
        """
        crit = {self.store.key: key}
        crit.update(self._sandbox_criteria(model_name))

        item = self.store.query_one(criteria=crit, properties=properties)

        if item is None:
            raise HTTPException(
                status_code=404,
                detail=f"Item with {self.store.key} = {key} not found",
            )

        return {"data": [item]}

    def build_get_by_key(self):
        """
        Register ``GET /{<store key>}/`` returning a single document by key.

        When the resource is versioned, the route also accepts a ``version``
        query parameter that selects the database collection.
        """
        key_name = self.store.key
        model_name = self.model.__name__

        if self.key_fields is None:
            field_input = SparseFieldsQuery(
                self.model,
                [self.store.key, self.store.last_updated_field]).query
        else:

            def field_input():
                return {"properties": self.key_fields}

        if not self.versioned:

            async def get_by_key(
                    key: str = Path(
                        ...,
                        alias=key_name,
                        title=f"The {key_name} of the {model_name} to get",
                    ),
                    fields: STORE_PARAMS = Depends(field_input),
            ):
                self.store.connect()
                return self._get_item_response(key, fields["properties"],
                                               model_name)

            endpoint = get_by_key

        else:

            async def get_by_key_versioned(
                key: str = Path(
                    ...,
                    alias=key_name,
                    title=f"The {key_name} of the {model_name} to get",
                ),
                fields: STORE_PARAMS = Depends(field_input),
                version: str = Query(
                    None,
                    description=
                    "Database version to query on formatted as YYYY.MM.DD",
                ),
            ):
                self._select_version(version)
                self.store.connect(force_reset=True)
                return self._get_item_response(key, fields["properties"],
                                               model_name)

            endpoint = get_by_key_versioned

        # An f-string literal as a function's first statement is NOT a
        # docstring, so the route description is attached explicitly. FastAPI
        # reads __doc__ when the route is registered below.
        endpoint.__doc__ = (
            "Gets a document by the primary key in the store\n\n"
            "Args:\n"
            f"    {key_name}: the id of a single {model_name}\n\n"
            "Returns:\n"
            f"    a single {model_name} document")

        self.router.get(
            f"/{{{key_name}}}/",
            response_description=f"Get an {model_name} by {key_name}",
            response_model=self.response_model,
            response_model_exclude_unset=True,
            tags=self.tags,
        )(endpoint)

    def set_dynamic_model_search(self):
        """
        Register ``GET /``, which searches the store using every query
        operator, plus a hidden redirect from the unslashed path.
        """

        model_name = self.model.__name__

        async def search(**queries: STORE_PARAMS):

            request: Request = queries.pop("request")  # type: ignore

            query: STORE_PARAMS = merge_queries(list(queries.values()))

            # Every query-string parameter must be declared by some operator.
            allowed_params = {
                param
                for operator in self.query_operators
                for param in signature(operator.query).parameters
            }

            unsupported = [
                key for key in request.query_params.keys()
                if key not in allowed_params
            ]
            if any(unsupported):
                # NOTE(review): 400 would be the conventional status for a bad
                # parameter; 404 is kept because clients may rely on it.
                raise HTTPException(
                    status_code=404,
                    detail=
                    "Request contains query parameters which cannot be used: {}"
                    .format(", ".join(unsupported)),
                )

            if self.versioned:
                requested_version = query["criteria"].get("version", None)
                if requested_version is not None:
                    # "version" selects the collection; it is not a filter.
                    query["criteria"].pop("version")
                self._select_version(requested_version)

            self.store.connect(force_reset=True)

            sandbox = self._sandbox_criteria(model_name)
            if sandbox:
                query["criteria"].update(sandbox)

            data = list(self.store.query(**query))  # type: ignore
            operator_metas = [
                operator.meta(self.store, query.get("criteria", {}))
                for operator in self.query_operators
            ]
            meta = {k: v for m in operator_metas for k, v in m.items()}

            response = {"data": data, "meta": meta}

            return response

        ann = {
            f"dep{i}": STORE_PARAMS
            for i, _ in enumerate(self.query_operators)
        }
        ann.update({"request": Request})
        attach_signature(
            search,
            annotations=ann,
            defaults={
                f"dep{i}": Depends(dep.query)
                for i, dep in enumerate(self.query_operators)
            },
        )

        self.router.get(
            "/",
            tags=self.tags,
            summary=f"Get {model_name} documents",
            response_model=self.response_model,
            response_description=f"Search for a {model_name}",
            response_model_exclude_unset=True,
        )(search)

        @self.router.get("", include_in_schema=False)
        def redirect_unslashes(request: Request):
            """
            Redirects the unslashed resource URL to the same URL with a
            trailing slash, keeping the query string.
            """
            # Built from the incoming URL: router.url_path_for() resolves
            # routes by *name* (so "/" never matched) and is unaware of any
            # prefix the router is mounted under.
            slashed = request.url.replace(path=request.url.path + "/")
            return RedirectResponse(url=str(slashed), status_code=301)

    def run(self):  # pragma: no cover
        """
        Runs the Endpoint cluster locally
        This is intended for testing not production
        """
        import uvicorn

        app = FastAPI()
        app.include_router(self.router, prefix="")
        uvicorn.run(app)

    def as_dict(self) -> Dict:
        """
        Special as_dict implemented to convert pydantic models into strings
        """

        d = super().as_dict()  # Ensures sub-classes serialize correctly
        d["model"] = f"{self.model.__module__}.{self.model.__name__}"
        return d
Example #7
0
    return route_models.ProtocolResponseDataModel.create(
        attributes=route_models.ProtocolResponseAttributes(
            protocolFile=route_models.FileAttributes(
                basename=meta.protocol_file.path.name
            ),
            supportFiles=[route_models.FileAttributes(
                basename=s.path.name
            ) for s in meta.support_files],
            lastModifiedAt=meta.last_modified_at,
            createdAt=meta.created_at
        ),
        resource_id=meta.identifier
    )


# Link to the protocols collection route, resolved once when this module is
# imported, using the ``get_protocols`` handler's name.
ROOT_RESOURCE = ResourceLink(href=router.url_path_for(get_protocols.__name__))
# Link for a single protocol. PATH_PROTOCOL_ID is defined elsewhere; it is
# presumably a path template with a protocol-id placeholder (verify).
PROTOCOL_BY_ID_RESOURCE = ResourceLink(href=PATH_PROTOCOL_ID)


def get_root_links(api_router: APIRouter) -> ResourceLinks:
    """Return the links advertised by the protocols root handlers.

    Both links are precomputed module-level constants, so ``api_router`` is
    accepted but not used here.
    """
    links = {ResourceLinkKey.self: ROOT_RESOURCE}
    links[ResourceLinkKey.protocol_by_id] = PROTOCOL_BY_ID_RESOURCE
    return links


def get_protocol_links(api_router: APIRouter, protocol_id: str) \
        -> ResourceLinks:
    """Get resource links for specific resource path handlers"""
    return {
Example #8
0
        "privacy": PrivacyKinds
    })


@router.post("/update_user_fullname")
async def update_user_fullname(
        request: Request, session=Depends(get_db)):
    """Overwrite the stored user's full name from the submitted form.

    Reads the ``fullname`` form field, commits it to the user's
    ``full_name``, then redirects (302) to the ``profile`` route.

    NOTE(review): the user is hard-coded as ``id=1`` and the lookup result is
    not checked for None; a missing user or a missing ``fullname`` field
    raises rather than producing a client error.
    """
    user = session.query(User).filter_by(id=1).first()
    form = await request.form()

    user.full_name = form['fullname']
    session.commit()

    return RedirectResponse(url=router.url_path_for("profile"),
                            status_code=HTTP_302_FOUND)


@router.post("/update_user_email")
async def update_user_email(
        request: Request, session=Depends(get_db)):
    user = session.query(User).filter_by(id=1).first()
    data = await request.form()
    new_email = data['email']

    # Update database
    user.email = new_email
    session.commit()

    url = router.url_path_for("profile")
Example #9
0
    log.info(f"Command completed {command}")

    return CommandResponse(data=ResponseDataModel.create(
        attributes=SessionCommand(
            data=command_result.content.data,
            command=command_result.content.name,
            status=command_result.result.status,
            createdAt=command_result.meta.created_at,
            startedAt=command_result.result.started_at,
            completedAt=command_result.result.completed_at),
        resource_id=command_result.meta.identifier),
                           links=get_valid_session_links(sessionId, router))


# Link to the sessions collection route, resolved once when this module is
# imported, using the ``get_sessions_handler`` handler's name.
ROOT_RESOURCE = ResourceLink(
    href=router.url_path_for(get_sessions_handler.__name__))
# Link for a single session. PATH_SESSION_BY_ID is defined elsewhere; it is
# presumably a path template with a session-id placeholder (verify).
SESSIONS_BY_ID_RESOURCE = ResourceLink(href=PATH_SESSION_BY_ID)


def get_valid_session_links(session_id: IdentifierType,
                            api_router: APIRouter) \
        -> ResourceLinks:
    """Get the valid links for a session"""
    return {
        ResourceLinkKey.self:
        ResourceLink(href=api_router.url_path_for(get_session_handler.__name__,
                                                  sessionId=session_id)),
        ResourceLinkKey.session_command_execute:
        ResourceLink(href=api_router.url_path_for(
            session_command_execute_handler.__name__, sessionId=session_id)),
        ResourceLinkKey.sessions:
Example #10
0
class Resource(MSONable, metaclass=ABCMeta):
    """
    Base class for a REST Compatible Resource.

    Subclasses implement ``prepare_endpoint`` to register their routes on
    ``self.router``; a hidden redirect route is added after that.
    """
    def __init__(
        self,
        model: Type[BaseModel],
    ):
        """
        Args:
            model: the pydantic model this Resource represents

        Raises:
            ValueError: if ``model`` is not a pydantic BaseModel subclass.
        """
        # isinstance() guard first: issubclass() raises TypeError for
        # non-class objects instead of the intended ValueError.
        if not (isinstance(model, type) and issubclass(model, BaseModel)):
            raise ValueError("The resource model has to be a PyDantic Model")

        self.model = api_sanitize(model, allow_dict_msonable=True)
        self.logger = logging.getLogger(type(self).__name__)
        self.logger.addHandler(logging.NullHandler())
        self.router = APIRouter()
        self.prepare_endpoint()
        self.setup_redirect()

    def on_startup(self):
        """
        Callback to perform some work on resource initialization.
        The base implementation does nothing.
        """

    @abstractmethod
    def prepare_endpoint(self):
        """
        Internal method to prepare the endpoint by setting up default handlers
        for routes.
        """

    def setup_redirect(self):
        """
        Register a hidden route that redirects the unslashed resource URL to
        the same URL with a trailing slash.
        """
        from fastapi import Request

        # NOTE(review): the "$" path relies on how the installed Starlette
        # compiles route paths. Confirm that this route matches the unslashed
        # URL in the deployed version.
        @self.router.get("$", include_in_schema=False)
        def redirect_unslashed(request: Request):
            """
            Redirects unforward slashed url to resource
            url with the forward slash
            """
            # Built from the incoming URL: router.url_path_for() resolves
            # routes by *name*, so url_path_for("/") never matched any route.
            slashed = request.url.replace(path=request.url.path + "/")
            return RedirectResponse(url=str(slashed), status_code=301)

    def run(self):  # pragma: no cover
        """
        Runs the Endpoint cluster locally
        This is intended for testing not production
        """
        import uvicorn

        app = FastAPI()
        app.include_router(self.router, prefix="")
        uvicorn.run(app)

    def as_dict(self) -> Dict:
        """
        Special as_dict implemented to convert pydantic models into strings
        """

        d = super().as_dict()  # Ensures sub-classes serialize correctly
        d["model"] = f"{self.model.__module__}.{self.model.__name__}"
        return d

    @classmethod
    def from_dict(cls, d: Dict):
        """
        Rebuild a Resource from a dict such as ``as_dict`` output.

        A string ``model`` is resolved with ``dynamic_import``. The remaining
        values are decoded with ``MontyDecoder``. MSONable metadata keys
        (those starting with "@") are dropped, because they are not
        constructor arguments. The input dict is not modified.
        """
        params = {k: v for k, v in d.items() if not k.startswith("@")}
        if isinstance(params["model"], str):
            params["model"] = dynamic_import(params["model"])
        decoder = MontyDecoder()
        return cls(**{k: decoder.process_decoded(v)
                      for k, v in params.items()})
Example #11
0
class MSALAuthorization:
    """
    Router wrapper exposing MSAL authorization-code routes.

    The following routes are registered under ``client_config.path_prefix``
    and delegated to an ``MSALAuthCodeHandler``:
    - a login redirect;
    - a GET token callback;
    - a POST token exchange returning a ``BearerToken``;
    - a logout route.
    """
    def __init__(
        self,
        client_config: MSALClientConfig,
        return_to_path: str = "/",
        tags: OptStrList = None,
    ):
        """
        Args:
            client_config: MSAL settings. Provides the handler configuration,
                the route prefix and paths, and whether routes appear in docs.
            return_to_path: where the user is sent after the GET token
                callback, and after logout when no Referer header is present.
            tags: OpenAPI tags for the routes. Defaults to ["authentication"]
                when None or empty.
        """
        self.handler = MSALAuthCodeHandler(client_config=client_config)
        if not tags:
            tags = ["authentication"]
        self.return_to_path = return_to_path
        self.router = APIRouter(prefix=client_config.path_prefix, tags=tags)
        self.router.add_api_route(
            name="_login_route",
            path=client_config.login_path,
            endpoint=self._login_route,
            methods=["GET"],
            include_in_schema=client_config.show_in_docs,
        )

        self.router.add_api_route(
            name="_get_token_route",
            path=client_config.token_path,
            endpoint=self._get_token_route,
            methods=["GET"],
            include_in_schema=client_config.show_in_docs,
        )

        self.router.add_api_route(
            name="_post_token_route",
            path=client_config.token_path,
            endpoint=self._post_token_route,
            methods=["POST"],
            response_model=BearerToken,
            include_in_schema=client_config.show_in_docs,
        )
        # Named explicitly like the other routes. The name matches the
        # default Starlette derives from the endpoint, so lookups are
        # unchanged.
        self.router.add_api_route(
            name="_logout_route",
            path=client_config.logout_path,
            endpoint=self._logout_route,
            methods=["GET"],
            include_in_schema=client_config.show_in_docs,
        )

    async def _login_route(
        self,
        request: Request,
        redirect_uri: OptStr = None,
        state: OptStr = None,
        client_id: OptStr = None,
    ) -> RedirectResponse:
        """
        Start the authorization-code flow.

        Redirects via the handler; when ``redirect_uri`` is not supplied, the
        GET token callback route is used. ``client_id`` is accepted as a query
        parameter for compatibility but is not used.
        """
        # A debug print(client_id) that leaked the parameter to stdout has
        # been removed; the value is intentionally ignored.
        if not redirect_uri:
            redirect_uri = request.url_for("_get_token_route")
        # NOTE(review): keyword spelled "redirec_uri" as in the original call;
        # presumably it matches MSALAuthCodeHandler.authorize_redirect's
        # parameter name. Verify there before renaming.
        return await self.handler.authorize_redirect(request=request,
                                                     redirec_uri=redirect_uri,
                                                     state=state)

    async def _get_token_route(self, request: Request, code: str,
                               state: Optional[str] = None
                               ) -> RedirectResponse:
        """
        Authorization callback: redeem ``code`` and redirect to
        ``return_to_path``.

        ``state`` defaults to None because the login route treats it as
        optional, so the identity provider may not echo one back.
        """
        await self.handler.authorize_access_token(request=request,
                                                  code=code,
                                                  state=state)
        # NOTE(review): this copies every incoming request header onto the
        # redirect response. Confirm this is intended; request-only headers
        # are not meaningful on a response.
        return RedirectResponse(url=f"{self.return_to_path}",
                                headers=dict(request.headers.items()))

    async def _post_token_route(
        self, request: Request, code: str = Form(...)) -> BearerToken:
        """Redeem a form-posted ``code`` and return its ID token as a bearer."""
        token: AuthToken = await self.handler.authorize_access_token(
            request=request, code=code)
        return BearerToken(access_token=token.id_token)

    async def _logout_route(
        self, request: Request,
        referer: OptStr = Header(None)) -> RedirectResponse:
        """
        Log out via the handler. The callback is the Referer header when
        present, otherwise ``return_to_path``.
        """
        callback_url = referer if referer else str(self.return_to_path)
        return self.handler.logout(request=request, callback_url=callback_url)

    async def get_session_token(self, request: Request) -> Optional[AuthToken]:
        """Return the token the handler has stored in the request session."""
        return await self.handler.get_token_from_session(request=request)

    async def check_authenticated_session(self, request: Request) -> bool:
        """
        Return True when the session holds a token whose ID token the handler
        parses into truthy claims.
        """
        auth_token: Optional[AuthToken] = await self.get_session_token(request)
        if not (auth_token and auth_token.id_token):
            return False
        token_claims = self.handler.parse_id_token(request=request,
                                                   token=auth_token)
        return bool(token_claims)

    @property
    def scheme(self) -> MSALScheme:
        """
        Build an ``MSALScheme`` whose URLs point at this router's login and
        POST-token routes.
        """
        return MSALScheme(
            authorizationUrl=self.router.url_path_for("_login_route"),
            tokenUrl=self.router.url_path_for("_post_token_route"),
            handler=self.handler,
        )