class DockerCmdBackwardsCompatibility:
    """
    Converts between the legacy 'execution.docker_cmd' string and the
    'container.image' / 'container.arguments' fields.
    Both directions are skipped for calls whose requested endpoint version
    is `max_version` or higher.
    """

    max_version = PartialVersion("2.13")
    field = ("execution", "docker_cmd")

    @classmethod
    def prepare_for_save(cls, call: APICall, fields: dict):
        """
        Move a legacy docker command from `fields` into the container fields.

        The command is split on its first space: the part before it becomes
        the image, the remainder becomes the arguments.

        :param call: the API call, used for its requested endpoint version
        :param fields: the fields about to be saved, updated in place
        """
        if call.requested_endpoint_version >= cls.max_version:
            return

        legacy_cmd = nested_get(fields, cls.field)
        if legacy_cmd is not None:
            image, _, arguments = legacy_cmd.partition(" ")
            nested_set(fields, ("container", "image"), value=image)
            nested_set(fields, ("container", "arguments"), value=arguments)

        # the legacy field is removed whether or not it carried a value
        nested_delete(fields, cls.field)

    @classmethod
    def unprepare_from_saved(cls, call: APICall, tasks_data: Union[Sequence[dict], dict]):
        """
        Rebuild the legacy docker command on each task that has a container image.

        :param call: the API call, used for its requested endpoint version
        :param tasks_data: a single task dict or a sequence of task dicts
        """
        if call.requested_endpoint_version >= cls.max_version:
            return

        tasks = [tasks_data] if isinstance(tasks_data, dict) else tasks_data
        for task in tasks:
            container = task.get("container")
            if container and container.get("image"):
                # join the image and (when present) the arguments with a single space
                parts = [container.get(key) for key in ("image", "arguments")]
                docker_cmd = " ".join(part for part in parts if part)
                if docker_cmd:
                    nested_set(task, cls.field, docker_cmd)
def make_version_number(
    version: PartialVersion,
) -> Union[None, float, str]:
    """
    Format an endpoint version for the response.

    Client versions <=2.0 expect endpoint versions in float format and throw
    an exception otherwise, so calls requesting an endpoint version below 2.1
    get a float; every other call gets the version as a string.
    `self` is read from the enclosing scope.

    :param version: the version to format, may be None
    :return: None when no version is given, otherwise a float or a string
    """
    if version is None:
        return None

    wants_float = self.requested_endpoint_version < PartialVersion("2.1")
    as_text = str(version)
    return float(as_text) if wants_float else as_text
def unprepare_from_saved(call: APICall, tasks_data: Union[Sequence[dict], dict]):
    """
    Post-process stored task documents before they are returned to the caller.

    Output tags are conformed for the whole batch, then each task goes through
    the params and artifacts "unprepare" helpers. Nothing is returned; the
    helpers receive the task dicts themselves.

    :param call: the API call, used for its requested endpoint version
    :param tasks_data: a single task dict or a sequence of task dicts
    """
    if isinstance(tasks_data, dict):
        tasks_data = [tasks_data]

    conform_output_tags(call, tasks_data)

    # The version check is the same for every task, so evaluate it once
    # instead of on each loop iteration.
    need_legacy_params = call.requested_endpoint_version < PartialVersion("2.9")

    for data in tasks_data:
        params_unprepare_from_saved(
            fields=data,
            copy_to_legacy=need_legacy_params,
        )
        artifacts_unprepare_from_saved(fields=data)
def conform_output_tags(call: APICall, documents: Union[dict, Sequence[dict]]):
    """
    For old clients both tags and system tags are returned in 'tags' field

    Calls requesting endpoint version 2.3 or higher are left untouched.
    For older calls, every document with non-empty 'system_tags' gets its
    'tags' replaced by the union of both lists (duplicates removed, order
    not guaranteed).

    :param call: the API call, used for its requested endpoint version
    :param documents: a single document dict or a sequence of document dicts,
        updated in place
    """
    if call.requested_endpoint_version >= PartialVersion("2.3"):
        return

    if isinstance(documents, dict):
        documents = [documents]

    for doc in documents:
        system_tags = doc.get("system_tags")
        if system_tags:
            # 'tags' may be missing or explicitly None; treat both as empty
            # so the merge does not fail with a TypeError
            doc["tags"] = list(set(doc.get("tags") or ()) | set(system_tags))
def conform_tags(
    call: APICall, tags: Sequence, system_tags: Sequence, validate=False
) -> Tuple[Sequence, Sequence]:
    """
    Conform incoming tags to the current 'tags' / 'system_tags' layout.

    When `validate` is set the tags are first passed to `validate_tags`.
    For calls requesting an endpoint version below 2.3 (old SDK clients, which
    send everything in 'tags') the pair is passed through `_upgrade_tags`;
    otherwise it is returned unchanged.

    :param call: the API call, used for its requested endpoint version
    :param tags: the tags sent by the client
    :param system_tags: the system tags sent by the client
    :param validate: whether to validate the tags before conforming them
    :return: the (tags, system_tags) pair to use
    """
    if validate:
        validate_tags(tags, system_tags)

    is_old_client = call.requested_endpoint_version < PartialVersion("2.3")
    if not is_old_client:
        return tags, system_tags

    upgraded_tags, upgraded_system_tags = _upgrade_tags(call, tags, system_tags)
    return upgraded_tags, upgraded_system_tags
class ModelsBackwardsCompatibility:
    """
    Converts between the legacy single-model fields ('execution.model',
    'output.model') and the per-mode lists kept under 'models'.
    Both directions are skipped for calls whose requested endpoint version
    is `max_version` or higher.
    """

    max_version = PartialVersion("2.13")
    mode_to_fields = {
        TaskModelTypes.input: ("execution", "model"),
        TaskModelTypes.output: ("output", "model"),
    }
    models_field = "models"

    @classmethod
    def prepare_for_save(cls, call: APICall, fields: dict):
        """
        Replace each legacy model field present in `fields` with a models list.

        A legacy field that is absent (None) is left alone. A falsy value
        produces an empty list; any other value produces a one-item list.

        :param call: the API call, used for its requested endpoint version
        :param fields: the fields about to be saved, updated in place
        """
        if call.requested_endpoint_version >= cls.max_version:
            return

        for mode, legacy_path in cls.mode_to_fields.items():
            legacy_value = nested_get(fields, legacy_path)
            if legacy_value is None:
                continue

            entries = []
            if legacy_value:
                entries.append(
                    {
                        "name": TaskModelNames[mode],
                        "model": legacy_value,
                        "updated": datetime.utcnow(),
                    }
                )

            nested_set(fields, (cls.models_field, mode), value=entries)
            nested_delete(fields, legacy_path)

    @classmethod
    def unprepare_from_saved(
        cls, call: APICall, tasks_data: Union[Sequence[dict], dict]
    ):
        """
        Fill the legacy model fields of each task from its models lists.

        The first entry is used for the input mode and the last entry for
        every other mode.

        :param call: the API call, used for its requested endpoint version
        :param tasks_data: a single task dict or a sequence of task dicts
        """
        if call.requested_endpoint_version >= cls.max_version:
            return

        tasks = [tasks_data] if isinstance(tasks_data, dict) else tasks_data
        for task in tasks:
            for mode, legacy_path in cls.mode_to_fields.items():
                entries = nested_get(task, (cls.models_field, mode))
                if not entries:
                    continue

                position = 0 if mode == TaskModelTypes.input else -1
                chosen = entries[position]
                if chosen:
                    nested_set(task, legacy_path, chosen.get("model"))
def __init__(
    self,
    name: Text,
    func: EndpointFunc,
    min_version: Text = "1.0",
    required_fields: Sequence[Text] = None,
    request_data_model: models.Base = None,
    response_data_model: models.Base = None,
    validate_schema: bool = False,
):
    """
    Endpoint configuration

    :param name: full endpoint name, in the form "<service>.<endpoint>"
    :param func: endpoint implementation
    :param min_version: minimum supported version
    :param required_fields: required request fields, can not be used with validate_schema
    :param request_data_model: request data model, stored on the endpoint
    :param response_data_model: response data model, stored on the endpoint
    :param validate_schema: whether request and response schema should be validated
    :raises RuntimeError: when the schema has no entry for this endpoint
    :raises ValueError: when both required_fields and validate_schema are used
    """
    self.name = name
    self.min_version = PartialVersion(min_version)
    self.func = func
    self.required_fields = required_fields
    self.request_data_model = request_data_model
    self.response_data_model = response_data_model

    # look up the schema group using the two halves of the endpoint name
    service, _, endpoint_name = self.name.partition(".")
    try:
        service_schema = schema.services[service]
        self.endpoint_group = service_schema.endpoint_groups[endpoint_name]
    except KeyError:
        raise RuntimeError(
            f"schema for endpoint {service}.{endpoint_name} not found")

    # without schema validation both validators are built around None
    request_schema = response_schema = None
    if validate_schema:
        if self.required_fields:
            raise ValueError(
                f"endpoint {self.name}: can not use 'required_fields' with 'validate_schema'"
            )
        versioned = self.endpoint_group.get_for_version(self.min_version)
        request_schema = versioned.request_schema
        response_schema = versioned.response_schema

    self.request_schema_validator = SchemaValidator(request_schema)
    self.response_schema_validator = SchemaValidator(response_schema)
def conform_output_tags(call: APICall, documents: Union[dict, Sequence[dict]]):
    """
    Make sure that tags are always returned sorted
    For old clients both tags and system tags are returned in 'tags' field

    For calls requesting an endpoint version below 2.3, every document with
    non-empty 'system_tags' gets its 'tags' replaced by the union of both
    lists. Afterwards, for every call, non-empty 'system_tags' and 'tags'
    are replaced by sorted lists.

    :param call: the API call, used for its requested endpoint version
    :param documents: a single document dict or a sequence of document dicts,
        updated in place
    """
    if isinstance(documents, dict):
        documents = [documents]

    merge_tags = call.requested_endpoint_version < PartialVersion("2.3")

    for doc in documents:
        if merge_tags:
            system_tags = doc.get("system_tags")
            if system_tags:
                # 'tags' may be missing or explicitly None; treat both as
                # empty so the merge does not fail with a TypeError
                doc["tags"] = list(set(doc.get("tags") or ()) | set(system_tags))

        for field in ("system_tags", "tags"):
            tags = doc.get(field)
            if tags:
                doc[field] = sorted(tags)
def parse_endpoint_path(cls, path: str) -> Tuple[PartialVersion, str]:
    """
    Extract the endpoint version and the endpoint name from a request path.

    :param path: the request path
    :return: a (version, endpoint name) tuple; when the path carries no
        version, the current maximum version is used
    :raises MalformedPathError: when the path does not match the endpoint expression
    :raises errors.bad_request.RequestPathHasInvalidVersion: when the version
        in the path can not be parsed
    :raises InvalidVersionError: when the max version check is enabled and
        the version in the path exceeds the maximum
    """
    match = cls._endpoint_exp.match(path)
    if match is None:
        raise MalformedPathError("Invalid request path %s" % path)

    endpoint_name = match.group("endpoint_name")
    raw_version = match.group("endpoint_version")

    if raw_version is not None:
        try:
            version = PartialVersion(raw_version)
        except ValueError as e:
            raise errors.bad_request.RequestPathHasInvalidVersion(
                version=raw_version, reason=e)
    else:
        # If endpoint is available, use the max version
        version = cls._max_version

    too_new = cls._check_max_version and version > cls._max_version
    if too_new:
        raise InvalidVersionError(
            f"Invalid API version (max. supported version is {cls._max_version})"
        )

    return version, endpoint_name
def unescape_metadata(call: APICall, documents: Union[dict, Sequence[dict]]):
    """
    Unescape special characters in metadata keys

    For calls requesting endpoint version 2.16 or lower, any document that
    has a 'metadata' key gets it replaced by an empty list instead.

    :param call: the API call, used for its requested endpoint version
    :param documents: a single document dict or a sequence of document dicts,
        updated in place
    """
    docs = [documents] if isinstance(documents, dict) else documents
    is_old_client = call.requested_endpoint_version <= PartialVersion("2.16")

    for doc in docs:
        if is_old_client and "metadata" in doc:
            doc["metadata"] = []
            continue

        stored = doc.get("metadata")
        if stored:
            unescaped = {}
            for key, value in stored.items():
                unescaped[ParameterKeyEscaper.unescape(key)] = value
            doc["metadata"] = unescaped
def __init__(
    self,
    endpoint_name,
    remote_addr=None,
    endpoint_version: PartialVersion = PartialVersion("1.0"),
    data=None,
    batched_data=None,
    headers=None,
    files=None,
    trx=None,
    host=None,
    auth_cookie=None,
):
    """
    Initialize the state of a single API call.

    :param endpoint_name: name of the requested endpoint
    :param remote_addr: remote address, stored as given
    :param endpoint_version: requested endpoint version, must be a PartialVersion
    :param data: passed on to the base class initializer
    :param batched_data: passed on to the base class initializer
    :param headers: optional initial headers, copied into a case-insensitive mapping
    :param files: uploaded files, stored as given
    :param trx: optional value; when truthy it is set under the transaction headers
    :param host: stored as given
    :param auth_cookie: stored as given
    """
    super().__init__(data=data, batched_data=batched_data)

    # identity of the call and the uploaded files
    self._id = database.utils.id()
    # currently a dict of key to flask's FileStorage
    self._files = files

    # timing: end and duration stay at 0 until they are set elsewhere
    self._start_ts = time.time()
    self._end_ts = 0
    self._duration = 0

    # what was requested and by whom
    self._endpoint_name = endpoint_name
    self._remote_addr = remote_addr
    assert isinstance(endpoint_version, PartialVersion), endpoint_version
    self._requested_endpoint_version = endpoint_version
    self._actual_endpoint_version = None

    # request headers
    self._headers = CaseInsensitiveDict()
    if headers:
        self._headers.update(headers)

    self._kpis = {}
    self._log_api = True

    # result and authorization state
    self._result = APICallResult()
    self._auth = None
    self._impersonation = None

    if trx:
        self.set_header(self._transaction_headers, trx)

    self._requires_authorization = True
    self._host = host
    self._auth_cookie = auth_cookie
    self._json_flags = {}
def unprepare_from_saved(call: APICall, tasks_data: Union[Sequence[dict], dict]):
    """
    Post-process stored task documents before they are returned to the caller.

    In order: output tags are conformed, the dict fields listed in
    `dict_fields_paths` are unescaped, the models and docker command
    backwards-compatibility conversions run, and finally each task goes
    through the params and artifacts "unprepare" helpers.

    :param call: the API call, used for its requested endpoint version
    :param tasks_data: a single task dict or a sequence of task dicts
    """
    tasks = [tasks_data] if isinstance(tasks_data, dict) else tasks_data

    conform_output_tags(call, tasks)

    for task in tasks:
        for field_path in dict_fields_paths:
            unescape_dict_field(task, field_path)

    # batch-level conversions for older clients
    for compatibility in (ModelsBackwardsCompatibility, DockerCmdBackwardsCompatibility):
        compatibility.unprepare_from_saved(call, tasks)

    copy_to_legacy = call.requested_endpoint_version < PartialVersion(
        "2.9")
    for task in tasks:
        params_unprepare_from_saved(fields=task, copy_to_legacy=copy_to_legacy)
        artifacts_unprepare_from_saved(fields=task)
class ServiceRepo(object):
    """
    Registry of the API endpoints and dispatcher of API calls to them.
    All state is kept on the class; every method is a classmethod.
    """

    _endpoints: MutableMapping[str, List[Endpoint]] = {}
    """
    Registered endpoints, in the format of {endpoint_name: [Endpoint, ...]}
    each list of endpoints is sorted by min_version, highest first
    """

    _version_required = config.get("apiserver.version.required")
    """ If version is required, parsing will fail for endpoint paths that do not contain a valid version """

    _check_max_version = config.get("apiserver.version.check_max_version")
    """
    If the check is set, parsing will fail for endpoint requests with a version
    that is greater than the current maximum
    """

    _max_version = PartialVersion("2.13")
    """ Maximum version number (the highest min_version value across all endpoints) """

    _endpoint_exp = (re.compile(
        r"^/?v(?P<endpoint_version>\d+\.?\d+)/(?P<endpoint_name>[a-zA-Z_]\w+\.[a-zA-Z_]\w+)/?$"
    ) if config.get("apiserver.version.required") else re.compile(
        r"^/?(v(?P<endpoint_version>\d+\.?\d+)/)?(?P<endpoint_name>[a-zA-Z_]\w+\.[a-zA-Z_]\w+)/?$"
    ))
    """
    Endpoint structure expressions. We have two expressions, one with optional version part.
    Constraints for the first (strict) expression:
    1. May start with a leading '/'
    2. Followed by a version number preceded by a leading 'v'. The version is at least two digits,
       optionally separated by a single '.' (e.g. 'v2.13'; a single digit such as 'v2' does not match)
    3. Followed by a '/'
    4. Followed by a service name, which must start with an english letter (lower or upper case) or
       underscore, and followed by one or more alphanumeric or underscore characters
    5. Followed by a '.'
    6. Followed by an action name, which must start with an english letter (lower or upper case) or
       underscore, and followed by one or more alphanumeric or underscore characters
    7. May end with a trailing '/'

    The second (optional version) expression does not require steps 2 and 3.
    """

    _return_stack = config.get("apiserver.return_stack")
    """ return stack trace on error """

    _return_stack_on_code = parse_return_stack_on_code(
        config.get("apiserver.return_stack_on_code", {}))
    """ if 'return_stack' is true and error contains a return code, return stack trace only for these error codes """

    _credentials = config["secure.credentials.apiserver"]
    """ Api Server credentials used for intra-service communication """

    _token = None
    """ Token for internal calls """

    @classmethod
    def _load_from_path(
        cls,
        root_module: Path,
        module_prefix: Optional[str] = None,
        predicate: Optional[Callable[[Path], bool]] = None,
    ):
        """
        Import every service module found directly under `root_module` and
        raise the maximum version to the highest registered min_version.

        :param root_module: directory holding the service modules
        :param module_prefix: optional dotted prefix for the imported module names
        :param predicate: optional filter; entries for which it returns a falsy
            value are skipped
        """
        log.info(f"Loading services from {str(root_module.absolute())}")

        for sub_module in root_module.glob("*"):
            if predicate and not predicate(sub_module):
                continue

            is_py_module = (
                sub_module.is_file()
                and sub_module.suffix == ".py"
                and sub_module.stem != "__init__"
            )
            is_package_dir = sub_module.is_dir() and sub_module.stem != "__pycache__"
            if is_py_module or is_package_dir:
                import_module(".".join(
                    filter(
                        None,
                        (module_prefix, root_module.stem, sub_module.stem))))

        # 'default' keeps this from failing with ValueError when no endpoint
        # has been registered yet
        highest_registered = max(
            (ep.min_version for eps in cls._endpoints.values() for ep in eps),
            default=cls._max_version,
        )
        cls._max_version = max(cls._max_version, highest_registered)

    @classmethod
    def load(cls, root_module="services"):
        """
        Load the services found under `root_module`, resolved relative to the
        parent of this file's directory, using the "apiserver" module prefix.
        """
        cls._load_from_path(
            root_module=Path(__file__).parents[1] / root_module,
            module_prefix="apiserver",
        )

    @classmethod
    def register(cls, endpoint: Endpoint):
        """
        Register an endpoint under its name.

        :raises Exception: when an endpoint with the same name and min_version
            is already registered
        """
        versions = cls._endpoints.setdefault(endpoint.name, [])
        if any(ep.min_version == endpoint.min_version for ep in versions):
            raise Exception(
                f"Trying to register an existing endpoint. name={endpoint.name}, version={endpoint.min_version}"
            )
        versions.append(endpoint)

        # highest min_version first; _get_endpoint relies on this order
        versions.sort(key=lambda ep: ep.min_version, reverse=True)

    @classmethod
    def endpoint_names(cls):
        """ Return the registered endpoint names, sorted """
        return sorted(cls._endpoints.keys())

    @classmethod
    def endpoints_summary(cls):
        """ Return a dict describing every registered endpoint, keyed by endpoint name """
        return {
            "endpoints": {
                name: list(map(Endpoint.to_dict, eps))
                for name, eps in cls._endpoints.items()
            },
            "models": {},
        }

    @classmethod
    def max_endpoint_version(cls) -> PartialVersion:
        """ Return the current maximum endpoint version """
        return cls._max_version

    @classmethod
    def _get_endpoint(cls, name, version) -> Optional[Endpoint]:
        """
        Return the endpoint registered under `name` with the highest
        min_version that does not exceed `version`, or None if there is none.
        """
        versions = cls._endpoints.get(name)
        if not versions:
            return None
        # the list is sorted highest-first, so the first match is the best one
        return next((ep for ep in versions if ep.min_version <= version), None)

    @classmethod
    def _resolve_endpoint_from_call(cls, call: APICall) -> Optional[Endpoint]:
        """
        Find the endpoint serving this call and record its version and
        authorization requirement on the call.
        When no endpoint matches, a 404 error result is set on the call and
        None is returned.
        """
        endpoint = cls._get_endpoint(call.endpoint_name,
                                     call.requested_endpoint_version)
        if endpoint is None:
            call.log_api = False
            call.set_error_result(
                msg=(f"Unable to find endpoint for name {call.endpoint_name} "
                     f"and version {call.requested_endpoint_version}"),
                code=404,
                subcode=0,
            )
            return None

        call.actual_endpoint_version = endpoint.min_version
        call.requires_authorization = endpoint.authorize
        return endpoint

    @classmethod
    def parse_endpoint_path(cls, path: str) -> Tuple[PartialVersion, str]:
        """
        Parse endpoint version, service and action from request path.

        :return: a (version, endpoint name) tuple; when the path carries no
            version, the current maximum version is used
        :raises MalformedPathError: when the path does not match the endpoint expression
        :raises errors.bad_request.RequestPathHasInvalidVersion: when the
            version in the path can not be parsed
        :raises InvalidVersionError: when the max version check is enabled and
            the version in the path exceeds the maximum
        """
        m = cls._endpoint_exp.match(path)
        if not m:
            raise MalformedPathError("Invalid request path %s" % path)

        endpoint_name = m.group("endpoint_name")
        version = m.group("endpoint_version")
        if version is None:
            # If endpoint is available, use the max version
            version = cls._max_version
        else:
            try:
                version = PartialVersion(version)
            except ValueError as e:
                # here 'version' is still the raw string taken from the path
                raise errors.bad_request.RequestPathHasInvalidVersion(
                    version=version, reason=e) from e

        if cls._check_max_version and version > cls._max_version:
            raise InvalidVersionError(
                f"Invalid API version (max. supported version is {cls._max_version})"
            )

        return version, endpoint_name

    @classmethod
    def _should_return_stack(cls, code: int, subcode: int) -> bool:
        """
        Decide whether the stack trace should be returned for an error with
        the given code and subcode, based on the return stack configuration.
        """
        if not cls._return_stack or code not in cls._return_stack_on_code:
            return False

        if subcode is None:
            # Code in dict, but no subcode. We'll allow it.
            return True

        subcode_list = cls._return_stack_on_code.get(code)
        if subcode_list is None:
            # if the code is there but we don't have any subcode list, always return stack
            return True

        return subcode in subcode_list

    @classmethod
    def _get_company(cls,
                     call: APICall,
                     endpoint: Endpoint = None,
                     ignore_error: bool = False) -> Optional[str]:
        """
        Return the company of the call's identity.
        When `ignore_error` is set, or the endpoint is missing or does not
        require authorization, any failure to read it yields None instead of
        raising.
        """
        authorize = endpoint and endpoint.authorize
        if ignore_error or not authorize:
            try:
                return call.identity.company
            except Exception:
                return None
        return call.identity.company

    @classmethod
    def handle_call(cls, call: APICall, load_data_callback: Callable = None):
        """
        Resolve, validate and execute an API call, then log the outcome.

        Errors are recorded on the call rather than raised: an APIError sets
        its own code and subcode, any other exception sets a 500.

        :param call: the call to handle
        :param load_data_callback: optional callable invoked with the call
            after authorization and before data validation
        :return: a (content, content type) tuple taken from the call's response
        """
        try:
            if call.failed:
                raise CallFailedError()

            endpoint = cls._resolve_endpoint_from_call(call)

            if call.failed:
                raise CallFailedError()

            validate_auth(endpoint, call)
            validate_role(endpoint, call)

            if validate_impersonation(endpoint, call):
                # if impersonating, validate role again
                validate_role(endpoint, call)

            if load_data_callback:
                load_data_callback(call)
                if call.failed:
                    raise CallFailedError()

            validate_data(call, endpoint)

            if call.failed:
                raise CallFailedError()

            # In case call does not require authorization, parsing the identity.company might raise an exception
            company = cls._get_company(call, endpoint)

            ret = endpoint.func(call, company, call.data_model)

            # allow endpoints to return dict or model (instead of setting them explicitly on the call)
            if ret is not None:
                if isinstance(ret, jsonmodels.models.Base):
                    call.result.data_model = ret
                elif isinstance(ret, dict):
                    call.result.data = ret

        except APIError as ex:
            # report the stack trace only when it is enabled and allowed for
            # this code and subcode
            include_stack = cls._return_stack and cls._should_return_stack(
                ex.code, ex.subcode)
            call.set_error_result(
                code=ex.code,
                subcode=ex.subcode,
                msg=str(ex),
                include_stack=include_stack,
            )
        except CallFailedError:
            # Do nothing, let 'finally' wrap up
            pass
        except Exception as ex:
            log.exception(ex)
            call.set_error_result(code=500,
                                  subcode=0,
                                  msg=str(ex),
                                  include_stack=cls._return_stack)
        finally:
            content, content_type = call.get_response()
            call.mark_end()
            console_msg = f"Returned {call.result.code} for {call.endpoint_name} in {call.duration}ms"
            if call.result.code < 300:
                log.info(console_msg)
            else:
                console_msg = f"{console_msg}, msg={call.result.msg}"
                if call.result.code < 500:
                    # Logger.warn is a deprecated alias of Logger.warning
                    log.warning(console_msg)
                else:
                    log.error(console_msg)

        return content, content_type
def parse_version(version):
    """
    Turn a "<digits>.<digits>" key into a PartialVersion.
    `self` is read from the enclosing scope and is only used to name the
    service and action in the error message.

    :param version: the key to parse
    :raises ValueError: when the key does not have the "<digits>.<digits>" form
    """
    if re.match(r"^\d+\.\d+$", version):
        return PartialVersion(version)

    raise ValueError(
        f"Encountered unrecognized key {version!r} in {self.service_name}.{self.action_name}"
    )