示例#1
0
class DockerCmdBackwardsCompatibility:
    """
    Translate between the legacy ``execution.docker_cmd`` string and the
    ``container.image`` / ``container.arguments`` fields.

    Both directions are applied only when the client requested an endpoint
    version older than ``max_version``.
    """

    max_version = PartialVersion("2.13")
    field = ("execution", "docker_cmd")

    @classmethod
    def prepare_for_save(cls, call: APICall, fields: dict):
        """
        Replace a legacy ``docker_cmd`` with ``container`` fields before saving.

        The text up to the first space becomes ``container.image``; everything
        after it becomes ``container.arguments``. The legacy field is then
        removed from ``fields`` (also when it was not set).
        """
        if call.requested_endpoint_version >= cls.max_version:
            return

        legacy_cmd = nested_get(fields, cls.field)
        if legacy_cmd is not None:
            head, _sep, tail = legacy_cmd.partition(" ")
            nested_set(fields, ("container", "image"), value=head)
            nested_set(fields, ("container", "arguments"), value=tail)

        nested_delete(fields, cls.field)

    @classmethod
    def unprepare_from_saved(cls, call: APICall,
                             tasks_data: Union[Sequence[dict], dict]):
        """
        Rebuild the legacy ``docker_cmd`` for every task that has a container image.

        The command is the image followed by the (non-empty) arguments, joined
        with a single space. Task documents are updated in place.
        """
        if call.requested_endpoint_version >= cls.max_version:
            return

        documents = [tasks_data] if isinstance(tasks_data, dict) else tasks_data

        for doc in documents:
            container = doc.get("container")
            if not (container and container.get("image")):
                continue

            parts = [container.get(key) for key in ("image", "arguments")]
            legacy_cmd = " ".join(part for part in parts if part)
            if legacy_cmd:
                nested_set(doc, cls.field, legacy_cmd)
示例#2
0
 def make_version_number(
         version: PartialVersion) -> Union[None, float, str]:
     """
     Convert an endpoint version into the representation the client expects.

     Clients that request an endpoint version below 2.1 expect endpoint
     versions in float format and throw an exception otherwise. Newer
     clients receive the version as a string.

     NOTE(review): this appears to be a nested helper. ``self`` is not a
     parameter and is taken from the enclosing scope (not visible here); its
     ``requested_endpoint_version`` selects the output format.

     :param version: the version to convert (may be None)
     :return: None if ``version`` is None, a float for pre-2.1 clients,
         otherwise the version string
     """
     if version is None:
         return None
     if self.requested_endpoint_version < PartialVersion("2.1"):
         return float(str(version))
     return str(version)
示例#3
0
def unprepare_from_saved(call: APICall, tasks_data: Union[Sequence[dict], dict]):
    """
    Convert stored task documents into the format expected by the calling client.

    :param call: the API call; clients requesting an endpoint version below 2.9
        get their parameters copied to the legacy fields as well
    :param tasks_data: a single task document or a sequence of them; the
        documents are updated by the called helpers
    """
    if isinstance(tasks_data, dict):
        tasks_data = [tasks_data]

    conform_output_tags(call, tasks_data)

    # The requested version is the same for every document; compute it once
    need_legacy_params = call.requested_endpoint_version < PartialVersion("2.9")
    for data in tasks_data:
        params_unprepare_from_saved(
            fields=data, copy_to_legacy=need_legacy_params,
        )
        artifacts_unprepare_from_saved(fields=data)
示例#4
0
def conform_output_tags(call: APICall, documents: Union[dict, Sequence[dict]]):
    """
    For old clients both tags and system tags are returned in 'tags' field

    Clients requesting an endpoint version below 2.3 get each document's
    'system_tags' merged into its 'tags' (without duplicates). Newer clients
    are left untouched.

    :param call: the API call whose requested endpoint version is checked
    :param documents: a single document or a sequence of documents, modified in place
    """
    if call.requested_endpoint_version >= PartialVersion("2.3"):
        return
    if isinstance(documents, dict):
        documents = [documents]
    for doc in documents:
        system_tags = doc.get("system_tags")
        if system_tags:
            # 'tags' may be missing or stored as None; treat both as empty.
            # dict.fromkeys de-duplicates while keeping a deterministic order
            # (user tags first, then new system tags), unlike a set.
            tags = doc.get("tags") or []
            doc["tags"] = list(dict.fromkeys([*tags, *system_tags]))
示例#5
0
def conform_tags(
    call: APICall, tags: Sequence, system_tags: Sequence, validate=False
) -> Tuple[Sequence, Sequence]:
    """
    Normalize incoming 'tags' / 'system_tags' for the requesting client.

    When ``validate`` is set, the pair is first checked with ``validate_tags``.
    For clients requesting an endpoint version below 2.3 (old SDKs that send
    everything in 'tags'), the pair is passed through ``_upgrade_tags`` so
    that tags are split correctly and duplicates are removed.

    :return: the (possibly upgraded) ``(tags, system_tags)`` pair
    """
    if validate:
        validate_tags(tags, system_tags)

    legacy_client = call.requested_endpoint_version < PartialVersion("2.3")
    if not legacy_client:
        return tags, system_tags

    upgraded_tags, upgraded_system_tags = _upgrade_tags(call, tags, system_tags)
    return upgraded_tags, upgraded_system_tags
示例#6
0
class ModelsBackwardsCompatibility:
    """
    Translate between the legacy single-model fields (``execution.model`` for
    input, ``output.model`` for output) and the ``models`` field.

    Both directions are applied only when the client requested an endpoint
    version older than ``max_version``.
    """

    max_version = PartialVersion("2.13")
    mode_to_fields = {
        TaskModelTypes.input: ("execution", "model"),
        TaskModelTypes.output: ("output", "model"),
    }
    models_field = "models"

    @classmethod
    def prepare_for_save(cls, call: APICall, fields: dict):
        """
        Move legacy model fields into ``models.<mode>`` before saving.

        A truthy legacy value becomes a one-entry list (name, model, update
        time). A falsy but non-None value becomes an empty list. The legacy
        field is then removed. Fields that are absent (None) are left as they are.
        """
        if call.requested_endpoint_version >= cls.max_version:
            return

        for mode, legacy_path in cls.mode_to_fields.items():
            legacy_model = nested_get(fields, legacy_path)
            if legacy_model is None:
                continue

            entries = []
            if legacy_model:
                entries.append(
                    dict(
                        name=TaskModelNames[mode],
                        model=legacy_model,
                        updated=datetime.utcnow(),
                    )
                )
            nested_set(fields, (cls.models_field, mode), value=entries)

            nested_delete(fields, legacy_path)

    @classmethod
    def unprepare_from_saved(
        cls, call: APICall, tasks_data: Union[Sequence[dict], dict]
    ):
        """
        Fill the legacy model fields from ``models.<mode>`` for old clients.

        Input uses the first entry of the list and output uses the last. Task
        documents are updated in place.
        """
        if call.requested_endpoint_version >= cls.max_version:
            return

        documents = [tasks_data] if isinstance(tasks_data, dict) else tasks_data

        for doc in documents:
            for mode, legacy_path in cls.mode_to_fields.items():
                entries = nested_get(doc, (cls.models_field, mode))
                if not entries:
                    continue

                if mode == TaskModelTypes.input:
                    chosen = entries[0]
                else:
                    chosen = entries[-1]
                if chosen:
                    nested_set(doc, legacy_path, chosen.get("model"))
示例#7
0
 def __init__(
     self,
     name: Text,
     func: EndpointFunc,
     min_version: Text = "1.0",
     required_fields: Sequence[Text] = None,
     request_data_model: models.Base = None,
     response_data_model: models.Base = None,
     validate_schema: bool = False,
 ):
     """
     Endpoint configuration
     :param name: full endpoint name, "<service>.<action>"
     :param func: endpoint implementation
     :param min_version: minimum supported version
     :param required_fields: required request fields, can not be used with validate_schema
     :param request_data_model: request jsonschema model, will be validated if validate_schema=False
     :param response_data_model: response jsonschema model, will be validated if validate_schema=False
     :param validate_schema: whether request and response schema should be validated
     :raises RuntimeError: if no schema group exists for the endpoint name
     :raises ValueError: if both required_fields and validate_schema are given
     """
     self.name = name
     self.min_version = PartialVersion(min_version)
     self.func = func
     self.required_fields = required_fields
     self.request_data_model = request_data_model
     self.response_data_model = response_data_model
     service, _, endpoint_name = self.name.partition(".")
     try:
         self.endpoint_group = schema.services[service].endpoint_groups[
             endpoint_name]
     except KeyError as err:
         # Chain the lookup failure so the traceback shows which key was missing
         raise RuntimeError(
             f"schema for endpoint {service}.{endpoint_name} not found") from err
     if validate_schema:
         if self.required_fields:
             raise ValueError(
                 f"endpoint {self.name}: can not use 'required_fields' with 'validate_schema'"
             )
         # Validate against the schema of the endpoint's minimal version
         endpoint = self.endpoint_group.get_for_version(self.min_version)
         request_schema = endpoint.request_schema
         response_schema = endpoint.response_schema
     else:
         request_schema = None
         response_schema = None
     self.request_schema_validator = SchemaValidator(request_schema)
     self.response_schema_validator = SchemaValidator(response_schema)
示例#8
0
def conform_output_tags(call: APICall, documents: Union[dict, Sequence[dict]]):
    """
    Make sure that tags are always returned sorted
    For old clients both tags and system tags are returned in 'tags' field

    :param call: the API call; clients requesting an endpoint version below
        2.3 get 'system_tags' merged into 'tags'
    :param documents: a single document or a sequence of documents, modified in place
    """
    if isinstance(documents, dict):
        documents = [documents]

    merge_tags = call.requested_endpoint_version < PartialVersion("2.3")
    for doc in documents:
        if merge_tags:
            system_tags = doc.get("system_tags")
            if system_tags:
                # 'tags' may be missing or stored as None; treat both as empty
                tags = doc.get("tags") or []
                doc["tags"] = list(set(tags) | set(system_tags))

        for field in ("system_tags", "tags"):
            tags = doc.get(field)
            if tags:
                doc[field] = sorted(tags)
示例#9
0
 def parse_endpoint_path(cls, path: str) -> Tuple[PartialVersion, str]:
     """
     Parse endpoint version, service and action from request path.

     The accepted format is defined by ``cls._endpoint_exp``.

     :param path: the request path
     :return: ``(version, endpoint_name)``
     :raises MalformedPathError: if the path does not match the endpoint expression
     :raises errors.bad_request.RequestPathHasInvalidVersion: if the version
         part can not be parsed
     :raises InvalidVersionError: if max-version checking is enabled and the
         version is above the maximum
     """
     m = cls._endpoint_exp.match(path)
     if not m:
         raise MalformedPathError("Invalid request path %s" % path)
     endpoint_name = m.group("endpoint_name")
     version = m.group("endpoint_version")
     if version is None:
         # No version in the path: default to the maximum supported version
         version = cls._max_version
     else:
         try:
             version = PartialVersion(version)
         except ValueError as e:
             raise errors.bad_request.RequestPathHasInvalidVersion(
                 version=version, reason=e) from e
         if cls._check_max_version and version > cls._max_version:
             raise InvalidVersionError(
                 f"Invalid API version (max. supported version is {cls._max_version})"
             )
     return version, endpoint_name
示例#10
0
def unescape_metadata(call: APICall, documents: Union[dict, Sequence[dict]]):
    """
    Restore the original (unescaped) metadata keys in outgoing documents.

    For clients requesting endpoint version 2.16 or lower, any 'metadata'
    field is replaced with an empty list. Otherwise every key of a non-empty
    metadata dict is passed through ``ParameterKeyEscaper.unescape``.
    Documents are modified in place.
    """
    docs = [documents] if isinstance(documents, dict) else documents

    is_old_client = call.requested_endpoint_version <= PartialVersion("2.16")
    for doc in docs:
        if is_old_client and "metadata" in doc:
            doc["metadata"] = []
            continue

        escaped = doc.get("metadata")
        if escaped:
            doc["metadata"] = {
                ParameterKeyEscaper.unescape(key): value
                for key, value in escaped.items()
            }
示例#11
0
    def __init__(
        self,
        endpoint_name,
        remote_addr=None,
        endpoint_version: PartialVersion = PartialVersion("1.0"),
        data=None,
        batched_data=None,
        headers=None,
        files=None,
        trx=None,
        host=None,
        auth_cookie=None,
    ):
        """
        Initialize the state of a single API call.

        :param endpoint_name: name of the requested endpoint
        :param remote_addr: remote address, stored as-is
        :param endpoint_version: requested endpoint version; must be a
            PartialVersion instance (checked with an assert)
        :param data: request data, forwarded to the base class constructor
        :param batched_data: batched request data, forwarded to the base class
            constructor
        :param headers: optional headers mapping, copied into a case-insensitive dict
        :param files: uploaded files (see the comment on ``self._files``)
        :param trx: if given, set as a header via ``self._transaction_headers``
            (presumably a transaction id — confirm against callers)
        :param host: stored as-is
        :param auth_cookie: stored as-is
        """
        super().__init__(data=data, batched_data=batched_data)

        self._id = database.utils.id()
        self._files = files  # currently a dict of key to flask's FileStorage
        # Timing: start recorded now; end/duration are filled in later
        self._start_ts = time.time()
        self._end_ts = 0
        self._duration = 0
        self._endpoint_name = endpoint_name
        self._remote_addr = remote_addr
        # NOTE(review): assert is stripped under `python -O`, so this check
        # disappears in optimized runs
        assert isinstance(endpoint_version, PartialVersion), endpoint_version
        self._requested_endpoint_version = endpoint_version
        # Resolved later, once an endpoint matching the requested version is found
        self._actual_endpoint_version = None
        self._headers = CaseInsensitiveDict()
        self._kpis = {}
        self._log_api = True
        if headers:
            self._headers.update(headers)
        self._result = APICallResult()
        self._auth = None
        self._impersonation = None
        # Kept after self._headers is created; set_header presumably writes
        # into it — confirm before reordering
        if trx:
            self.set_header(self._transaction_headers, trx)
        self._requires_authorization = True
        self._host = host
        self._auth_cookie = auth_cookie
        self._json_flags = {}
示例#12
0
def unprepare_from_saved(call: APICall, tasks_data: Union[Sequence[dict],
                                                          dict]):
    """
    Convert stored task documents into the format expected by the calling client.

    Steps, in order:
      1. conform output tags
      2. unescape the dict fields listed in ``dict_fields_paths``
      3. apply the models and docker_cmd backwards-compatibility conversions
      4. unprepare parameters (copying to legacy fields for versions below 2.9)
         and artifacts
    """
    task_docs = [tasks_data] if isinstance(tasks_data, dict) else tasks_data

    conform_output_tags(call, task_docs)

    for doc in task_docs:
        for field_path in dict_fields_paths:
            unescape_dict_field(doc, field_path)

    ModelsBackwardsCompatibility.unprepare_from_saved(call, task_docs)
    DockerCmdBackwardsCompatibility.unprepare_from_saved(call, task_docs)

    copy_to_legacy = call.requested_endpoint_version < PartialVersion("2.9")

    for doc in task_docs:
        params_unprepare_from_saved(fields=doc, copy_to_legacy=copy_to_legacy)
        artifacts_unprepare_from_saved(fields=doc)
示例#13
0
class ServiceRepo(object):
    """
    Registry of API endpoints and dispatcher for API calls.

    Endpoints are registered by name (one or more versions per name).
    ``load`` imports the service modules from disk. ``handle_call`` resolves
    the endpoint for a call, validates it, runs it, and converts errors into
    error results.
    """

    _endpoints: MutableMapping[str, List[Endpoint]] = {}
    """
    Registered endpoints, in the format of {endpoint_name: [Endpoint, ...]}
    each list of endpoints is sorted by min_version in descending order
    """

    _version_required = config.get("apiserver.version.required")
    """ If version is required, parsing will fail for endpoint paths that do not contain a valid version """

    _check_max_version = config.get("apiserver.version.check_max_version")
    """If the check is set, parsing will fail for endpoint requests with a version that is greater than the current
    maximum """

    _max_version = PartialVersion("2.13")
    """ Maximum version number (the highest min_version value across all endpoints) """

    _endpoint_exp = (re.compile(
        r"^/?v(?P<endpoint_version>\d+\.?\d+)/(?P<endpoint_name>[a-zA-Z_]\w+\.[a-zA-Z_]\w+)/?$"
    ) if config.get("apiserver.version.required") else re.compile(
        r"^/?(v(?P<endpoint_version>\d+\.?\d+)/)?(?P<endpoint_name>[a-zA-Z_]\w+\.[a-zA-Z_]\w+)/?$"
    ))
    """
        Endpoint structure expressions. We have two expressions, one with optional version part.
        Constraints for the first (strict) expression:
        1. May start with a leading '/'
        2. Followed by a version number (digits with an optional '.', e.g. 2.13) preceded by a 'v'
        3. Followed by a '/'
        4. Followed by a service name, which must start with an english letter (lower or upper case) or underscore,
            and followed by any number of alphanumeric or underscore characters
        5. Followed by a '.'
        6. Followed by an action name, which must start with an english letter (lower or upper case) or underscore,
            and followed by any number of alphanumeric or underscore characters
        7. May end with a trailing '/'

        The second (optional version) expression does not require steps 2 and 3.
    """

    _return_stack = config.get("apiserver.return_stack")
    """ return stack trace on error """

    _return_stack_on_code = parse_return_stack_on_code(
        config.get("apiserver.return_stack_on_code", {}))
    """ if 'return_stack' is true and error contains a return code, return stack trace only for these error codes """

    _credentials = config["secure.credentials.apiserver"]
    """ Api Server credentials used for intra-service communication """

    _token = None
    """ Token for internal calls """

    @classmethod
    def _load_from_path(
        cls,
        root_module: Path,
        module_prefix: Optional[str] = None,
        predicate: Optional[Callable[[Path], bool]] = None,
    ):
        """
        Import every service module/package directly under ``root_module``,
        then raise ``_max_version`` to the highest registered ``min_version``.

        :param root_module: directory whose children are imported
        :param module_prefix: optional package prefix for the import path
        :param predicate: optional filter; children for which it returns
            False are skipped
        """
        log.info(f"Loading services from {str(root_module.absolute())}")
        sub_module = None
        for sub_module in root_module.glob("*"):
            if predicate and not predicate(sub_module):
                continue

            if (sub_module.is_file() and sub_module.suffix == ".py"
                    and not sub_module.stem == "__init__"):
                import_module(".".join(
                    filter(
                        None,
                        (module_prefix, root_module.stem, sub_module.stem))))

            if sub_module.is_dir() and not sub_module.stem == "__pycache__":
                import_module(".".join(
                    filter(
                        None,
                        (module_prefix, root_module.stem, sub_module.stem))))

        # leave no trace of the 'sub_module' local
        del sub_module

        # default= keeps the current maximum when no endpoints are registered;
        # without it max() raises ValueError on an empty sequence
        cls._max_version = max(
            cls._max_version,
            max(
                (ep.min_version
                 for ep in cast(Iterable[Endpoint],
                                chain(*cls._endpoints.values()))),
                default=cls._max_version,
            ),
        )

    @classmethod
    def load(cls, root_module="services"):
        """Import all service modules from ``<package parent>/<root_module>``."""
        cls._load_from_path(
            root_module=Path(__file__).parents[1] / root_module,
            module_prefix="apiserver",
        )

    @classmethod
    def register(cls, endpoint: Endpoint):
        """
        Register an endpoint version.

        :raises Exception: if an endpoint with the same name and min_version
            is already registered
        """
        if cls._endpoints.get(endpoint.name):
            if any(ep.min_version == endpoint.min_version
                   for ep in cls._endpoints[endpoint.name]):
                raise Exception(
                    f"Trying to register an existing endpoint. name={endpoint.name}, version={endpoint.min_version}"
                )
            else:
                cls._endpoints[endpoint.name].append(endpoint)
        else:
            cls._endpoints[endpoint.name] = [endpoint]

        # Highest min_version first, so _get_endpoint picks the newest match
        cls._endpoints[endpoint.name].sort(key=lambda ep: ep.min_version,
                                           reverse=True)

    @classmethod
    def endpoint_names(cls):
        """Return the registered endpoint names, sorted."""
        return sorted(cls._endpoints.keys())

    @classmethod
    def endpoints_summary(cls):
        """Return a dict describing all registered endpoints (all versions)."""
        return {
            "endpoints": {
                name: list(map(Endpoint.to_dict, eps))
                for name, eps in cls._endpoints.items()
            },
            "models": {},
        }

    @classmethod
    def max_endpoint_version(cls) -> PartialVersion:
        """Return the maximum supported endpoint version."""
        return cls._max_version

    @classmethod
    def _get_endpoint(cls, name, version) -> Optional[Endpoint]:
        """
        Return the registered endpoint with the highest ``min_version`` that
        does not exceed ``version``, or None if there is none.
        """
        versions = cls._endpoints.get(name)
        if not versions:
            return None
        try:
            return next(ep for ep in versions if ep.min_version <= version)
        except StopIteration:
            # no appropriate version found
            return None

    @classmethod
    def _resolve_endpoint_from_call(cls, call: APICall) -> Optional[Endpoint]:
        """
        Find the endpoint for ``call`` and record the resolved version on it.

        If no endpoint matches, a 404 error result is set on the call and
        None is returned.
        """
        endpoint = cls._get_endpoint(call.endpoint_name,
                                     call.requested_endpoint_version)
        if endpoint is None:
            call.log_api = False
            call.set_error_result(
                msg=(f"Unable to find endpoint for name {call.endpoint_name} "
                     f"and version {call.requested_endpoint_version}"),
                code=404,
                subcode=0,
            )
            return

        call.actual_endpoint_version = endpoint.min_version
        call.requires_authorization = endpoint.authorize
        return endpoint

    @classmethod
    def parse_endpoint_path(cls, path: str) -> Tuple[PartialVersion, str]:
        """
        Parse endpoint version, service and action from request path.

        :return: ``(version, endpoint_name)``; the version defaults to
            ``_max_version`` when the path does not contain one
        :raises MalformedPathError: if the path does not match ``_endpoint_exp``
        :raises errors.bad_request.RequestPathHasInvalidVersion: if the version
            can not be parsed
        :raises InvalidVersionError: if max-version checking is enabled and
            the version is above ``_max_version``
        """
        m = cls._endpoint_exp.match(path)
        if not m:
            raise MalformedPathError("Invalid request path %s" % path)
        endpoint_name = m.group("endpoint_name")
        version = m.group("endpoint_version")
        if version is None:
            # No version in the path: default to the maximum supported version
            version = cls._max_version
        else:
            try:
                version = PartialVersion(version)
            except ValueError as e:
                raise errors.bad_request.RequestPathHasInvalidVersion(
                    version=version, reason=e) from e
            if cls._check_max_version and version > cls._max_version:
                raise InvalidVersionError(
                    f"Invalid API version (max. supported version is {cls._max_version})"
                )
        return version, endpoint_name

    @classmethod
    def _should_return_stack(cls, code: int, subcode: int) -> bool:
        """
        Decide whether a stack trace should be included for an error.

        Only codes listed in ``_return_stack_on_code`` qualify. If a subcode
        list is configured for the code, the subcode must appear in it.
        """
        if not cls._return_stack or code not in cls._return_stack_on_code:
            return False
        if subcode is None:
            # Code in dict, but no subcode. We'll allow it.
            return True
        subcode_list = cls._return_stack_on_code.get(code)
        if subcode_list is None:
            # if the code is there but we don't have any subcode list, always return stack
            return True
        return subcode in subcode_list

    @classmethod
    def _get_company(cls,
                     call: APICall,
                     endpoint: Endpoint = None,
                     ignore_error: bool = False) -> Optional[str]:
        """
        Return the company of the call's identity.

        For endpoints that do not require authorization, or when
        ``ignore_error`` is set, any failure returns None instead of raising.
        """
        authorize = endpoint and endpoint.authorize
        if ignore_error or not authorize:
            try:
                return call.identity.company
            except Exception:
                return None
        return call.identity.company

    @classmethod
    def handle_call(cls, call: APICall, load_data_callback: Callable = None):
        """
        Resolve, validate and execute ``call``, then return its response.

        Errors are recorded on the call as error results rather than raised.

        :param call: the call to handle
        :param load_data_callback: optional callable invoked with the call
            before data validation
        :return: ``(content, content_type)`` from ``call.get_response()``
        """
        try:
            if call.failed:
                raise CallFailedError()

            endpoint = cls._resolve_endpoint_from_call(call)

            if call.failed:
                raise CallFailedError()

            validate_auth(endpoint, call)
            validate_role(endpoint, call)
            if validate_impersonation(endpoint, call):
                # if impersonating, validate role again
                validate_role(endpoint, call)

            if load_data_callback:
                load_data_callback(call)
                if call.failed:
                    raise CallFailedError()

            validate_data(call, endpoint)

            if call.failed:
                raise CallFailedError()

            # In case call does not require authorization, parsing the identity.company might raise an exception
            company = cls._get_company(call, endpoint)

            ret = endpoint.func(call, company, call.data_model)

            # allow endpoints to return dict or model (instead of setting them explicitly on the call)
            if ret is not None:
                if isinstance(ret, jsonmodels.models.Base):
                    call.result.data_model = ret
                elif isinstance(ret, dict):
                    call.result.data = ret

        except APIError as ex:
            # report stack trace only for the error codes configured in 'return_stack_on_code'
            include_stack = cls._return_stack and cls._should_return_stack(
                ex.code, ex.subcode)
            call.set_error_result(
                code=ex.code,
                subcode=ex.subcode,
                msg=str(ex),
                include_stack=include_stack,
            )
        except CallFailedError:
            # Do nothing, let 'finally' wrap up
            pass
        except Exception as ex:
            log.exception(ex)
            call.set_error_result(code=500,
                                  subcode=0,
                                  msg=str(ex),
                                  include_stack=cls._return_stack)
        finally:
            content, content_type = call.get_response()
            call.mark_end()

            console_msg = f"Returned {call.result.code} for {call.endpoint_name} in {call.duration}ms"
            if call.result.code < 300:
                log.info(console_msg)
            else:
                console_msg = f"{console_msg}, msg={call.result.msg}"
                if call.result.code < 500:
                    # Logger.warn is a deprecated alias of Logger.warning
                    log.warning(console_msg)
                else:
                    log.error(console_msg)

        return content, content_type
示例#14
0
 def parse_version(version):
     """
     Validate a "<digits>.<digits>" version key and convert it to a PartialVersion.

     NOTE(review): ``self`` is not a parameter. It comes from the enclosing
     scope (not visible here) and is used only for the error message.

     :raises ValueError: if ``version`` is not exactly digits, a dot and digits
     """
     # fullmatch also rejects a trailing newline, which "^...$" with
     # re.match would accept
     if not re.fullmatch(r"\d+\.\d+", version):
         raise ValueError(
             f"Encountered unrecognized key {version!r} in {self.service_name}.{self.action_name}"
         )
     return PartialVersion(version)