Ejemplo n.º 1
0
    async def _execute_call(
        self,
        kwargs: Dict[str, str],
        http_method: str,
        config: common.UrlMethod,
        message: Dict[str, Any],
        request_headers: Mapping[str, str],
        auth: Optional[MutableMapping[str, str]] = None,
    ) -> common.Response:
        """Dispatch a single API call to its handler and post-process the result.

        The URL keyword arguments are merged into ``message``, the agent session id is checked when the
        method requires one, the arguments are processed and authorized, optionally renamed through the
        handler's ``__protocol_mapping__``, and finally the handler is awaited and its result processed.

        :param kwargs: Arguments extracted from the request URL; merged into ``message``.
        :param http_method: The HTTP method of the request (not used in this method body).
        :param config: The method configuration holding the handler and its properties.
        :param message: The request arguments. Note: it is updated in place with ``kwargs``.
        :param request_headers: The headers of the incoming request.
        :param auth: Authorization data passed on to ``authorize_request``.
        :return: The response produced by ``CallArguments.process_return``.
        :raises exceptions.BadRequest: When a required sid is missing, malformed or not valid.
        :raises exceptions.BaseHttpException: HTTP errors raised while processing are re-raised as-is.
        :raises exceptions.ServerError: On pydantic validation errors and on any other unexpected error.
        """
        import logging

        try:
            if kwargs is None or config is None:
                raise Exception("This method is unknown! This should not occur!")

            # create message that contains all arguments
            message.update(kwargs)

            if config.properties.validate_sid:
                if "sid" not in message:
                    raise exceptions.BadRequest("this is an agent to server call, it should contain an agent session id")

                # A malformed sid is a client error: report it as BadRequest instead of letting the
                # ValueError from uuid.UUID fall through to the generic ServerError handler below.
                try:
                    sid: Optional[uuid.UUID] = uuid.UUID(str(message["sid"]))
                except ValueError:
                    sid = None
                if sid is None or not self.validate_sid(sid):
                    raise exceptions.BadRequest("the sid %s is not valid." % (message["sid"],))

            arguments = CallArguments(config.properties, message, request_headers)
            await arguments.process()
            authorize_request(auth, arguments.metadata, arguments.call_args, config)

            # rename arguments if handler requests this
            call_args = arguments.call_args
            if hasattr(config.handler, "__protocol_mapping__"):
                for k, v in config.handler.__protocol_mapping__.items():
                    if v in call_args:
                        call_args[k] = call_args[v]
                        del call_args[v]

            # Building the summary stringifies every argument, so only do it when it will be logged.
            if LOGGER.isEnabledFor(logging.DEBUG):
                LOGGER.debug(
                    "Calling method %s(%s)",
                    config.handler,
                    ", ".join(["%s='%s'" % (name, common.shorten(str(value))) for name, value in call_args.items()]),
                )

            # Use the (possibly renamed) local dict so logging and the call see the same arguments.
            result = await config.handler(**call_args)
            return await arguments.process_return(config, result)
        except pydantic.ValidationError as e:
            LOGGER.exception(f"The handler {config.handler} caused a validation error in a data model (pydantic).")
            raise exceptions.ServerError("data validation error.") from e

        except exceptions.BaseHttpException:
            LOGGER.debug("An HTTP Error occurred", exc_info=True)
            raise

        except Exception as e:
            LOGGER.exception("An exception occured during the request.")
            raise exceptions.ServerError(str(e.args)) from e
Ejemplo n.º 2
0
    def _validate_union_return(self, arg_type: Type, value: Any) -> None:
        """Validate a return value against a ``Union`` type.

        Exactly one member of the union must match ``value`` with ``isinstance``. Generic members are
        checked against their origin, e.g. ``list`` for ``List[int]``. When the matching member is
        generic, the contents of ``value`` are validated against that member as well.

        :see: protocol.common.MethodProperties._validate_function_types
        :param arg_type: The union type the return value is declared as.
        :param value: The value returned by the handler.
        :raises exceptions.ServerError: When more than one union member matches the value.
        :raises exceptions.BadRequest: When no union member matches the value.
        Errors raised by ``_validate_generic_return`` for a generic match propagate unchanged.
        """
        matching_type = None
        for member in typing_inspect.get_args(arg_type, evaluate=True):
            # isinstance() does not accept subscripted generics, so check against their origin class.
            check_type = typing_inspect.get_origin(member) if typing_inspect.is_generic_type(member) else member

            if isinstance(value, check_type):
                if matching_type is not None:
                    raise exceptions.ServerError(
                        f"Return type is defined as a union {arg_type} for which multiple "
                        f"types match the provided value {value}"
                    )
                matching_type = member

        if matching_type is None:
            raise exceptions.BadRequest(
                f"Invalid return value, no matching type found in union {arg_type} for value type {type(value)}"
            )

        if typing_inspect.is_generic_type(matching_type):
            # Validate the value's contents against the matched generic member (not the whole union).
            self._validate_generic_return(matching_type, value)
Ejemplo n.º 3
0
 def get_default_value(self, arg_name: str, arg_position: int, default_start: int) -> Optional[Any]:
     """
     Get a default value for an argument
     """
     if default_start >= 0 and 0 <= (arg_position - default_start) < len(self._argspec.defaults):
         return self._argspec.defaults[arg_position - default_start]
     else:
         raise exceptions.BadRequest("Field '%s' is required." % arg_name)
Ejemplo n.º 4
0
    def _validate_generic_return(self, arg_type: Type, value: Any) -> None:
        """Validate a return value against a generic ``List`` or ``Dict`` type.

        For ``List[T]`` the value must be a list; for ``Dict[K, T]`` it must be a dict. Unless ``T`` is
        ``Any``, every list element / dict value must match ``T`` and dict keys must be strings. Union
        element types are delegated to ``_validate_union_return``; other element types are checked
        with ``isinstance``.

        :param arg_type: The declared generic return type.
        :param value: The value returned by the handler.
        :raises exceptions.ServerError: When the value, one of its elements, or a dict key does not match.
        :raises exceptions.BadRequest: When ``arg_type`` is not a ``List`` or ``Dict`` generic.
        """
        origin = typing_inspect.get_origin(arg_type)
        # get_origin() may return None (or a non-class special form); issubclass() would then raise a
        # TypeError instead of reaching the "unsupported type" error at the end.
        origin_is_class = isinstance(origin, type)

        def check_element(el_type: Any, element: Any, container: str) -> None:
            """Check a single list element or dict value against the declared element type."""
            if typing_inspect.is_union_type(el_type):
                self._validate_union_return(el_type, element)
            elif not isinstance(element, el_type):
                raise exceptions.ServerError(f"Element {element} of returned {container} is not of type {el_type}.")

        if origin_is_class and issubclass(origin, list):
            if not isinstance(value, list):
                raise exceptions.ServerError(
                    f"Invalid return value, type needs to be a list. Argument type should be {arg_type}"
                )

            el_type = typing_inspect.get_args(arg_type, evaluate=True)[0]
            if el_type is Any:
                return
            for el in value:
                check_element(el_type, el, "list")

        elif origin_is_class and issubclass(origin, dict):
            if not isinstance(value, dict):
                raise exceptions.ServerError(
                    f"Invalid return value, type needs to be a dict. Argument type should be {arg_type}"
                )

            el_type = typing_inspect.get_args(arg_type, evaluate=True)[1]
            if el_type is Any:
                return
            for k, v in value.items():
                if not isinstance(k, str):
                    raise exceptions.ServerError("Keys of return dict need to be strings.")
                check_element(el_type, v, "dict")

        else:
            # This should not happen because of MethodProperties validation
            raise exceptions.BadRequest(
                f"Failed to validate generic type {arg_type} of return value, only List and Dict are supported"
            )
Ejemplo n.º 5
0
 async def prepare_paging_metadata(
     self,
     query_identifier: IQuery,
     dtos: List[T],
     db_query: Mapping[str, Tuple[QueryType, object]],
     limit: int,
     database_order: DatabaseOrder,
 ) -> PagingMetadata:
     """
     Build the paging metadata for one page of query results.

     When ``dtos`` is empty, ``total``, ``before`` and ``after`` are all 0. Otherwise, the page
     boundaries are derived from ``dtos``, and the counts provider is asked for the counts
     matching the same query.

     :param query_identifier: Identifies the query; forwarded to the counts provider.
     :param dtos: The items on the current page.
     :param db_query: The query filter; forwarded as keyword arguments to the counts provider.
     :param limit: The requested page size, reported as ``page_size``.
     :param database_order: The ordering used for the query.
     :return: The paging metadata with ``total``, ``before``, ``after`` and ``page_size``.
     :raises exceptions.BadRequest: When the counts provider rejects a field name or query parameter.
     """
     if not dtos:
         return PagingMetadata(total=0, before=0, after=0, page_size=limit)

     paging_borders = self._get_paging_boundaries(dtos, database_order)
     try:
         paging_counts = await self.counts_provider.count_items_for_paging(
             query_identifier=query_identifier,
             database_order=database_order,
             start=paging_borders.start,
             first_id=paging_borders.first_id,
             end=paging_borders.end,
             last_id=paging_borders.last_id,
             **db_query,
         )
     except (InvalidFieldNameException, InvalidQueryParameter) as e:
         # Chain the original error so the server-side traceback shows what was rejected.
         raise exceptions.BadRequest(f"Invalid query specified: {e.message}") from e

     return PagingMetadata(
         total=paging_counts.total,
         before=paging_counts.before,
         after=paging_counts.after,
         page_size=limit,
     )
Ejemplo n.º 6
0
    async def process(self) -> None:
        """
        Resolve a value for every handler argument and store the validated results.

        Each argument (``self`` excluded) first gets the value from ``_map_headers``. If that is
        ``None``, the value is resolved as follows:
        - header parameters use the argument's default;
        - otherwise, the message field with the same name is used;
        - for GET calls with a dict argument, the ``<arg>.<key>`` message fields are used;
        - failing all of these, the argument's default is used.

        The collected values go through ``validate_arguments`` and the getters, and are stored in
        ``self._call_args``.

        :raises exceptions.BadRequest: When the message contains fields that no argument consumed
            and the handler has no ``**kwargs`` parameter. Errors raised by the helpers (e.g.
            ``get_default_value`` for a required argument) propagate as well.
        """
        params: List[str] = [name for name in self._argspec.args if name != "self"]

        # Fields still in this set at the end were not consumed by any argument.
        unused_fields = set(self._message.keys())

        defaults = self._argspec.defaults
        first_default: int = -1 if defaults is None else len(params) - len(defaults)

        collected = {}
        for position, name in enumerate(params):
            value = self._map_headers(name)
            if value is None:
                if self._is_header_param(name):
                    value = self.get_default_value(name, position, first_default)
                elif name in self._message:
                    value = self._message[name]
                    unused_fields.remove(name)
                elif self._properties.operation == "GET" and self._is_dict_or_optional_dict(
                    self._argspec.annotations.get(name)
                ):
                    # Dict arguments of a GET call arrive as separate "<name>.<key>" fields.
                    prefix = f"{name}."
                    prefixed_fields = {
                        field: field_value
                        for field, field_value in self._message.items()
                        if field.startswith(prefix) and len(field) > len(prefix)
                    }
                    if prefixed_fields:
                        value = await self._get_dict_value_from_message(name, prefix, prefixed_fields)
                    else:
                        value = None
                    for field in prefixed_fields:
                        unused_fields.remove(field)
                else:
                    value = self.get_default_value(name, position, first_default)
            collected[name] = value

        validated = self._properties.validate_arguments(collected)
        for name, value in validated.items():
            self._call_args[name] = await self._run_getters(name, value)

        # The sid is session handling data rather than a handler argument.
        if self._properties.agent_server:
            unused_fields.discard("sid")

        if unused_fields and self._argspec.varkw is None:
            raise exceptions.BadRequest(
                "request contains fields %s that are not declared in method and no kwargs argument is provided." % unused_fields
            )

        self._processed = True