def __init__(
    self,
    *args: ServableValidArgs_T,
    download_path: Optional[Path] = None,
    script_loader_cls: Type[FlashServeScriptLoader] = FlashServeScriptLoader,
):
    """Resolve ``args`` to a loadable artifact and store the loaded instance.

    ``args`` ends with ``loc``, a local file path or an http(s) URL. Accepted
    shapes:

    * ``(loc,)`` or ``(script_loader_cls, loc)``: ``loc`` is handed to
      ``script_loader_cls``.
    * ``(klass, loc)`` with any other ``klass``: the instance is created with
      ``klass.load_from_checkpoint(loc)``.

    Remote locations are fetched first with ``download_file`` into
    ``download_path``.

    Raises:
        ValidationError: if ``args`` does not match ``ServableValidArgs_T``,
            unless the first argument is the script loader class itself.
    """
    loader_qualname = script_loader_cls.__qualname__
    try:
        loc = args[-1]  # last element in args is always loc
        parsed = parse_obj_as(ServableValidArgs_T, tuple(args))
    except ValidationError:
        # Only an explicit script-loader class earns a second attempt. Use
        # getattr: args[0] may be a plain str/url with no __qualname__, and the
        # validation error should surface instead of an AttributeError.
        if getattr(args[0], "__qualname__", None) != loader_qualname:
            raise
        parsed = [
            script_loader_cls,
            parse_obj_as(Union[HttpUrl, FilePath], loc),
        ]

    if isinstance(parsed[-1], Path):
        f_path = loc
    else:
        f_path = download_file(loc, download_path=download_path)

    if len(args) == 2 and getattr(args[0], "__qualname__", None) != loader_qualname:
        # A (model class, path/url) pair: restore the class from the checkpoint.
        instance = args[0].load_from_checkpoint(f_path)
    else:
        # Just a path/url (optionally preceded by the script loader class).
        instance = script_loader_cls(f_path)
    self.instance = instance
def validate(cls, v, field: ModelField, **kwargs):
    """Coerce ``v`` into a ``FileType`` tagged with the field's id.

    The field id is the first type argument of the field's first sub-field
    (``FileType[...]``). An existing ``FileType`` is re-tagged and returned.
    Anything else is tried as an existing local ``FilePath`` and as an
    ``AnyHttpUrl``. Whichever succeed are stored on a new ``FileType``.

    Raises:
        SyntaxError: if the field was declared without a field id.
        ValueError: if ``v`` is neither a local path nor a remote url.
    """
    # this is probably a bad idea...
    if not field.sub_fields:
        raise SyntaxError(
            'A field id must be provided when using FileType')
    field_id = get_args(field.sub_fields[0].type_)[0]

    if isinstance(v, FileType):
        v.field_id = field_id
        return v

    # Try each interpretation independently; a failed parse leaves None.
    candidates = {}
    for slot, target_type in (('path', FilePath), ('url', AnyHttpUrl)):
        try:
            candidates[slot] = parse_obj_as(target_type, v)
        except ValidationError:
            candidates[slot] = None

    if candidates['path'] is None and candidates['url'] is None:
        raise ValueError(
            f'Could not parse {v} as a local path or a remote url')
    return FileType(path=candidates['path'], url=candidates['url'], field_id=field_id)
def get_template_by_name_and_version(
    self,
    name: str,
    version: str,
    resource_type: ResourceType,
    parent_service_name: str = None,
) -> Union[ResourceTemplate, UserResourceTemplate]:
    """Return the one template matching ``name``, ``version`` and ``resource_type``.

    For UserResource templates ``parent_service_name`` is mandatory and also
    restricts the match to that parent workspace service.

    Raises:
        Exception: a UserResource lookup without ``parent_service_name``.
        EntityDoesNotExist: zero or more than one template matched.
    """
    conditions = [
        self._template_by_name_query(name, resource_type),
        f'c.version = "{version}"',
    ]
    wants_user_resource = resource_type == ResourceType.UserResource

    if wants_user_resource:
        if not parent_service_name:
            raise Exception(
                "When getting a UserResource template, you must pass in a 'parent_service_name'"
            )
        conditions.append(f'c.parentWorkspaceService = "{parent_service_name}"')

    # NOTE(review): version / parent_service_name are interpolated into the
    # query text unescaped; switch to query parameters if self.query allows.
    templates = self.query(query=" AND ".join(conditions))
    if len(templates) != 1:
        raise EntityDoesNotExist

    template_model = UserResourceTemplate if wants_user_resource else ResourceTemplate
    return parse_obj_as(template_model, templates[0])
def __read_templates() -> Tuple[
    Dict[TemplateId, ItemTemplate],
    Dict[TemplateId, NodeTemplate],
]:
    """Load every item and node template stored under ``db_dir/items``.

    Each file there is read with ujson and is expected to hold a list of
    dicts. Entries whose ``"_type"`` is ``"Item"`` are parsed as
    ``ItemTemplate`` and those whose ``"_type"`` is ``"Node"`` as
    ``NodeTemplate``. Entries with any other ``_type`` are ignored.

    Returns:
        ``(items_by_id, nodes_by_id)``. If two templates share an id, the one
        read last wins.
    """
    item_templates: List[ItemTemplate] = []
    node_templates: List[NodeTemplate] = []

    # Read every file from db/items
    for item_file_path in db_dir.joinpath("items").glob("*"):
        # Close each handle deterministically rather than leaving it to the GC.
        with item_file_path.open("r", encoding="utf8") as fp:
            file_data: List[dict] = ujson.load(fp)
        item_templates.extend(
            pydantic.parse_obj_as(
                List[ItemTemplate],
                [item for item in file_data if item["_type"] == "Item"],
            )
        )
        node_templates.extend(
            pydantic.parse_obj_as(
                List[NodeTemplate],
                [item for item in file_data if item["_type"] == "Node"],
            )
        )

    return (
        {tpl.id: tpl for tpl in item_templates},
        {tpl.id: tpl for tpl in node_templates},
    )
def create_template(
    self,
    template_input: ResourceTemplateInCreate,
    resource_type: ResourceType,
    parent_service_name: str = "",
) -> Union[ResourceTemplate, UserResourceTemplate]:
    """Build, validate, persist and return a new template from ``template_input``.

    Title, description, required and properties come from the input's JSON
    schema. ``pipeline`` is copied when the schema has one. UserResource
    templates also record ``parent_service_name`` as their parent workspace
    service.
    """
    schema = template_input.json_schema
    new_id = str(uuid.uuid4())
    template = dict(
        id=new_id,
        name=template_input.name,
        title=schema["title"],
        description=schema["description"],
        version=template_input.version,
        resourceType=resource_type,
        current=template_input.current,
        required=schema["required"],
        properties=schema["properties"],
        customActions=template_input.customActions,
    )
    if "pipeline" in schema:
        template["pipeline"] = schema["pipeline"]

    if resource_type == ResourceType.UserResource:
        template["parentWorkspaceService"] = parent_service_name
        target_model = UserResourceTemplate
    else:
        target_model = ResourceTemplate

    created = parse_obj_as(target_model, template)
    self.save_item(created)
    return created
def parse_params(
    q_params: dict[str, Any],
    p_params: dict[str, Any],
    types: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Cast query and path parameter values to the types declared for them.

    Args:
        q_params: Query parameter names mapped to their raw values.
        p_params: Path parameter names mapped to their raw values.
        types: Parameter names mapped to the type each should be cast to.
            Names missing here (or mapped to a falsy value) are cast to
            ``Any``, which leaves the value as-is.

    Returns:
        ``(parsed_query_params, parsed_path_params)``.

    Raises:
        ValidationError: pydantic's error for a value that cannot be cast.
    """

    def _cast_all(raw: dict[str, Any]) -> dict[str, Any]:
        # One pass per dict; the lookup falls back to Any for untyped keys.
        return {
            key: parse_obj_as(types.get(key) or Any, value)
            for key, value in raw.items()
        }

    return _cast_all(q_params), _cast_all(p_params)
def deserialize_config(cls, v, values, **kwargs):
    """Turn a JSON string or a list into ``List[List[TimeseriesSchema]]``.

    Falsy input becomes ``None``. Values that are neither ``str`` nor
    ``list`` are returned unchanged.
    """
    if not v:
        return None
    if isinstance(v, str):
        v = json.loads(v)
    elif not isinstance(v, list):
        return v
    return parse_obj_as(List[List[TimeseriesSchema]], v)
def deserialize_config(cls, v, values, **kwargs):
    """Turn a JSON string, dict or list into a validated dict or list.

    Falsy input becomes ``None``. Values of any other type are returned
    unchanged.
    """
    if not v:
        return None
    if isinstance(v, str):
        payload = json.loads(v)
    elif isinstance(v, (dict, list)):
        payload = v
    else:
        return v
    return parse_obj_as(Union[Dict, List], payload)
def validate_cfg(_in, filename: Optional[str] = None) -> Optional[List[UserThingCfg]]:
    """Validate one thing config, or a list of them, into ``UserThingCfg`` objects.

    A non-list input is validated as a single config and wrapped in a list.
    ``filename`` is forwarded as the pydantic ``type_name``. On a validation
    error the error is reported through ``HABAppError`` and ``None`` is
    returned.
    """
    is_many = isinstance(_in, list)
    target = List[UserThingCfg] if is_many else UserThingCfg
    try:
        parsed = parse_obj_as(target, _in, type_name=filename)
    except ValidationError as e:
        HABAppError(log).add_exception(e).dump()
        return None
    return parsed if is_many else [parsed]
def from_raw_event(event: Union[dict, List[List[dict]]]) -> List[RawScheduledEvent]:
    """Parse a single event dict, or a 2-D list of event dicts, into a flat list.

    A lone dict yields a one-element list. A nested list is flattened one
    level (row by row) before every entry is parsed.
    """
    if isinstance(event, dict):
        return [pydantic.parse_obj_as(RawScheduledEvent, event)]

    flat_events: List[dict] = list(itertools.chain.from_iterable(event))
    return pydantic.parse_obj_as(List[RawScheduledEvent], flat_events)
def parse_request(event_cls: t.Type[Event]) -> Event:
    """Parse the enclosing ``request``'s params into an ``event_cls`` instance.

    For ``ServerRequest`` subclasses the request id and this client
    (``self``) are attached to the event. Presumably this is so the caller
    can reply later. ``ServerNotification`` subclasses are returned as parsed.

    Raises:
        TypeError: if ``event_cls`` is neither kind of event.
    """
    if not issubclass(event_cls, ServerRequest):
        if issubclass(event_cls, ServerNotification):
            return parse_obj_as(event_cls, request.params)
        raise TypeError(
            "`event_cls` must be a subclass of ServerRequest"
            " or ServerNotification")

    parsed = parse_obj_as(event_cls, request.params)
    assert request.id is not None
    parsed._id = request.id
    parsed._client = self
    return parsed
def searc_inc(q: str):
    """Search both QA collections for questions containing ``q`` (case-insensitive).

    Complete entries come from ``collection`` (documents that have a
    ``correct`` field). Incomplete entries come from ``collection_incomplete``.
    There only documents with ``is_correct`` true are kept, and their answers
    are grouped per (type, question, title).

    Returns:
        The matching ``QA`` models followed by the matching ``QAINC`` models.
    """
    import re

    # ``q`` is presumably free text from a user. Escape it so it is matched
    # literally: otherwise characters such as "(" or "+" make the regex invalid
    # or let a caller submit pathological patterns to the database.
    question_match = {"$regex": f".*{re.escape(q)}.*", "$options": "i"}

    pipeline = [
        {
            "$match": {
                "correct": {"$exists": True},
                "question": question_match,
            }
        }
    ]
    doc = list(collection.aggregate(pipeline=pipeline))

    pipeline_inc = [
        {
            "$match": {
                "is_correct": True,
                "question": question_match,
            }
        },
        {
            "$group": {
                "_id": {
                    "type": "$type",
                    "question": "$question",
                    "title": "$title",
                },
                "answers": {
                    "$addToSet": {
                        "answer": "$answer",
                        "is_correct": "$is_correct",
                        "id": "$_id",
                    }
                },
            }
        },
        {
            "$project": {
                "answers": 1,
                "type": "$_id.type",
                "question": "$_id.question",
                "title": "$_id.title",
                "_id": 0,
            }
        },
    ]
    docs_inc = list(collection_incomplete.aggregate(pipeline=pipeline_inc))
    return [*parse_obj_as(list[QA], doc), *parse_obj_as(list[QAINC], docs_inc)]
async def test_list_datasets_entrypoint(
    async_client: httpx.AsyncClient,
    pennsieve_subsystem_mock,
    pennsieve_api_headers: Dict[str, str],
):
    """GET v0/datasets answers 200 with a non-empty, parseable page of datasets."""
    resp = await async_client.get("v0/datasets", headers=pennsieve_api_headers)
    assert resp.status_code == status.HTTP_200_OK

    body = resp.json()
    assert body
    # Parsing raises if the payload does not match the page schema.
    parse_obj_as(Page[DatasetMetaData], body)
def upload(update: Update, context: CallbackContext):
    # Telegram handler: check the attached document (size limit and file
    # extension), download it locally as "<user id>.<ext>", then POST it to
    # https://telegra.ph/upload and reply with the resulting URL or the error.
    # NOTE(review): this definition is truncated in the reviewed excerpt; the
    # body of the final ``else:`` and the ``try``'s handler are not visible.
    if not update.effective_message:
        return
    # 5242880 bytes == 5 MiB.
    # NOTE(review): assumes a document is attached; ``document`` being None
    # would raise AttributeError here — confirm the handler's filter.
    if update.effective_message.document.file_size > 5242880:
        return update.effective_message.reply_text(
            "File size is greater than 5MB.",
            quote=True,
        )
    # Only the last three characters are compared, so "peg" admits ".jpeg".
    if update.effective_message.document.file_name[-3:].lower() not in [
        "jpg",
        "peg",
        "png",
        "gif",
        "mp4",
    ]:
        return update.effective_message.reply_text(
            "File type not supported.", quote=True
        )
    # Local file name: the sender's user id plus the original extension.
    user_file = f'{str(update.effective_message.from_user.id)}.{update.effective_message.document.file_name.rsplit(".", 1)[-1]}'
    context.bot.get_file(update.effective_message.document.file_id).download(
        user_file
    )
    if not os.path.exists(user_file):
        return update.effective_message.reply_text(
            "Failed to upload. Reason: File not found.",
            quote=True,
        )
    try:
        with open(user_file, "rb") as file:
            with requests.post(
                "https://telegra.ph/upload", files={"files": file}
            ) as resp:
                # Success: the response body is a list of uploaded items.
                # NOTE(review): when ``resp.ok`` is false the walrus never runs,
                # so ``content`` is unbound in the ``elif`` below (NameError /
                # UnboundLocalError) — looks like a bug; confirm and fix.
                if resp.ok and isinstance(content := resp.json(), list):
                    model = parse_obj_as(list[UploadSuccess], content)
                    update.effective_message.reply_text(
                        model[0].url, quote=True
                    )
                # Failure: the response body is a dict carrying an ``error``.
                elif isinstance(content, dict):
                    model = parse_obj_as(UploadError, resp.json())
                    return update.effective_message.reply_text(
                        f"Failed to upload. Reason: {model.error}",
                        quote=True,
                    )
                else:
def __validate_metadata__(cls, json_str: str) -> List[Tuple[str, str]]:
    """Validate LNURL-pay metadata given as a JSON list of ``[mime_type, value]`` pairs.

    Pairs whose mime type is not in ``cls.valid_metadata_mime_types`` are
    dropped. Exactly one ``text/plain`` entry must remain.

    Returns:
        The kept ``(mime_type, value)`` pairs, both stringified, in input order.

    Raises:
        InvalidLnurlPayMetadata: malformed JSON/shape, or not exactly one
            ``text/plain`` entry after filtering.
    """
    try:
        parse_obj_as(Json[List[Tuple[str, str]]], json_str)
    except ValidationError:
        raise InvalidLnurlPayMetadata

    # Shape is already validated, so every entry has exactly two members.
    pairs = [(str(entry[0]), str(entry[1])) for entry in json.loads(json_str)]
    kept = [pair for pair in pairs if pair[0] in cls.valid_metadata_mime_types]

    # Zero occurrences covers both "nothing kept" and "no text/plain".
    if [pair[0] for pair in kept].count("text/plain") != 1:
        raise InvalidLnurlPayMetadata
    return kept
def _handle_response(self, response: Response) -> Event:
    """Turn a response into an ``Event`` for the request it answers.

    The originating request is looked up by ``response.id`` and removed from
    the unanswered set. For ``initialize`` and ``shutdown`` the client state
    machine is advanced. The method of the originating request selects which
    event is built.

    Raises:
        RuntimeError: if the response carries an error.
        NotImplementedError: if the request method is not handled here.
    """
    assert response.id is not None
    # Pop so each request is answered at most once.
    request = self._unanswered_requests.pop(response.id)

    # FIXME: The errors have meanings.
    if response.error is not None:
        # NOTE(review): debug dump to stdout; the error detail is not carried
        # in the raised exception.
        __import__("pprint").pprint(response.error)
        raise RuntimeError("Response error!")

    event: Event

    if request.method == "initialize":
        assert self._state == ClientState.WAITING_FOR_INITIALIZED
        # Send the "initialized" notification before switching to NORMAL.
        self._send_notification("initialized")
        event = Initialized.parse_obj(response.result)
        self._state = ClientState.NORMAL

    elif request.method == "shutdown":
        assert self._state == ClientState.WAITING_FOR_SHUTDOWN
        event = Shutdown()
        self._state = ClientState.SHUTDOWN

    elif request.method == "textDocument/completion":
        # The result is accepted as a CompletionList, as a bare list of
        # CompletionItem (wrapped into a complete CompletionList), or as None
        # (completion_list stays None).
        completion_list = None
        try:
            completion_list = CompletionList.parse_obj(response.result)
        except ValidationError:
            try:
                completion_list = CompletionList(
                    isIncomplete=False,
                    items=parse_obj_as(t.List[CompletionItem], response.result),
                )
            except ValidationError:
                assert response.result is None
        event = Completion(message_id=response.id, completion_list=completion_list)

    elif request.method == "textDocument/willSaveWaitUntil":
        event = WillSaveWaitUntilEdits(
            edits=parse_obj_as(t.List[TextEdit], response.result))

    else:
        raise NotImplementedError((response, request))

    return event
def _parse_event(self, result) -> BaseEvent:
    """
    Internal use only
    Parse event or message from json to BaseEvent

    For a ``Message`` whose second chain element is a ``Quote``, the element
    at index 2 is removed if it is an ``At``. After that, the element now at
    index 2 is also removed if it is a ``Plain`` consisting of a single space.

    Errors are logged, not raised. If parsing itself fails, the input json is
    returned unchanged.

    :param result: the json
    :return: BaseEvent (or the raw input when parsing failed)
    """
    try:
        result = parse_obj_as(Events, result)
        if isinstance(result, Message):
            # construct message chain; parse quote first
            if len(result.messageChain) > 2 and isinstance(result.messageChain[1], Quote):
                # FIXME: add the first two message part back
                try:
                    if isinstance(result.messageChain[2], At):
                        del result.messageChain[2]  # delete duplicated at
                    if (
                        len(result.messageChain) > 2
                        and isinstance(result.messageChain[2], Plain)
                        and result.messageChain[2].text == ' '
                    ):
                        del result.messageChain[2]  # delete space after duplicated at
                except Exception:
                    # Catch Exception, not everything: KeyboardInterrupt and
                    # SystemExit must still propagate.
                    self.logger.exception(
                        'Please open a github issue to report this error'
                    )
    except Exception:
        self.logger.exception('Unhandled exception')
    return result
def deserialize_message(raw: bytes) -> Union[Message, DescribeMessage]:
    """Decode a pipeline message from bytes.

    Two encodings are accepted: plain UTF-8 JSON (first byte ``{``), or
    zstandard-compressed UTF-8 JSON prefixed with a ``Z`` byte. The JSON
    object's ``"kind"`` selects ``Message`` or ``DescribeMessage``.

    Raises:
        PipelineMessageError: unknown/empty encoding, or a missing or unknown
            ``"kind"``.
    """
    # Slice instead of indexing so empty input reaches the error below
    # rather than raising IndexError.
    marker = raw[:1]
    if marker == b"{":
        payload = raw
    elif marker == b"Z":
        payload = zstandard.decompress(raw[1:])
    else:
        raise PipelineMessageError(
            "Unknown bytes string cannot be deserialized")

    message_dict = json.loads(payload.decode("utf-8"))
    # A missing "kind" (or non-object JSON) is a format error, not a KeyError.
    kind = message_dict.get("kind") if isinstance(message_dict, dict) else None

    if kind == Kind.Message:
        return parse_obj_as(Message, message_dict)
    elif kind == Kind.Describe:
        return parse_obj_as(DescribeMessage, message_dict)
    else:
        raise PipelineMessageError("Unknown format")
def settings_singleton(
    cli_args: "list[str] | None" = None,
    env: "dict[str, str] | None" = None,
    refresh: bool = False,
) -> Settings:
    """Singleton for settings model.

    Args:
        cli_args (list[str], optional): List of all arguments passed to the
            program, e.g. ``sys.argv[1:]``. Args must start with one or two
            dashes and only contain lower case chars, period and underscores.
            Defaults to an empty list.
        env (dict[str, str], optional): Dict with all environment variables.
            Defaults to an empty dict.
        refresh (bool, optional): Should the singleton object settings be
            refreshed? Defaults to `False`.

    Returns:
        Settings: Settings object.
    """
    global _settings
    if _settings and not refresh:
        return _settings

    # Build fresh containers per call. Shared mutable defaults ([] / {}) would
    # carry over any mutation made downstream into every later call.
    settings_dict = setup_raw_settings(
        [] if cli_args is None else cli_args,
        {} if env is None else env,
    )
    _settings = parse_obj_as(Settings, settings_dict)
    _ensure_generic_route_exists(_settings)
    return _settings
def get_project(self, filters: ProjectFilter) -> Project:
    """Fetch the non-trashed project whose uuid is ``filters.uuid``.

    A NULL ``notes`` column is returned as an empty string.

    Raises:
        Things3StorageException: no such project, or a database error
            (the original exception is chained as the cause).
    """
    # Double single quotes so the value cannot terminate the SQL literal.
    # NOTE(review): prefer bound parameters if SqliteQuery supports them.
    project_uuid = str(filters.uuid).replace("'", "''")
    sql = f""" SELECT project.uuid AS uuid, project.title AS title, CASE WHEN project.notes IS NULL THEN '' ELSE project.notes END AS notes, project.area AS area FROM TMTask AS project WHERE project.uuid == '{project_uuid}' AND project.type == {TaskType.project} AND project.trashed == 0 """
    try:
        q = SqliteQuery(connection=self._get_connection(), sql=sql)
        project: dict = q.execute_for_one()
        return parse_obj_as(Project, project)
    except Things3NotFoundException as ex:
        raise Things3StorageException(
            f'There is no any project with id {filters.uuid}') from ex
    except Things3DataBaseException as ex:
        raise Things3StorageException(ex) from ex
def get_projects(self, filters: ProjectFilter) -> List[Project]:
    """List non-trashed projects of ``filters.area`` with the allowed statuses.

    Each project also reports ``tasks``: the number of its non-trashed tasks
    whose status is in the same allowed set. Results are ordered by title,
    case-insensitively.

    Raises:
        Things3StorageException: nothing found, or a database error
            (the original exception is chained as the cause).
    """
    statuses = self._filter_statuses(filters.statuses)
    # Double single quotes so the value cannot terminate the SQL literal.
    # NOTE(review): prefer bound parameters if SqliteQuery supports them.
    area = str(filters.area).replace("'", "''")
    sql = f""" SELECT project.uuid AS uuid, project.title AS title, project.area AS area, ( SELECT COUNT(uuid) FROM TMTask AS task WHERE task.project = project.uuid AND task.trashed = 0 AND task.status IN ({statuses}) ) AS tasks FROM TMTask AS project WHERE project.type == {TaskType.project} AND project.trashed == 0 AND project.status IN ({statuses}) AND project.area == '{area}' ORDER BY project.title COLLATE NOCASE """
    try:
        q = SqliteQuery(connection=self._get_connection(), sql=sql)
        projects: List[dict] = q.execute()
        return [parse_obj_as(Project, row) for row in projects]
    except Things3NotFoundException as ex:
        raise Things3StorageException(
            f'There are no any projects for area {filters.area}') from ex
    except Things3DataBaseException as ex:
        raise Things3StorageException(ex) from ex
def get_areas(self) -> List[Area]:
    """List all areas, ordered by title case-insensitively.

    Each area also reports ``projects``: the number of its non-trashed
    projects whose status is ``TaskStatus.new``.

    Raises:
        Things3StorageException: nothing found, or a database error
            (the original exception is chained as the cause).
    """
    sql = f""" SELECT area.uuid AS uuid, area.title AS title, ( SELECT COUNT(uuid) FROM TMTask AS project WHERE project.area = area.uuid AND project.trashed = 0 AND project.status = {TaskStatus.new} ) AS projects FROM TMArea AS area ORDER BY area.title COLLATE NOCASE """
    try:
        q = SqliteQuery(connection=self._get_connection(), sql=sql)
        areas: List[dict] = q.execute()
        return [parse_obj_as(Area, row) for row in areas]
    except Things3NotFoundException as ex:
        raise Things3StorageException('There are no any Areas') from ex
    except Things3DataBaseException as ex:
        raise Things3StorageException(ex) from ex
def serialize_list_obj(serialize_schema: Type[BaseSchema], obj_list: List[ModelType]) -> List[dict]:
    """Validate ``obj_list`` as ``List[serialize_schema]`` and encode each item.

    Every validated item is passed through ``jsonable_encoder``. The return
    annotation was previously ``List[Type[BaseModel]]``, which described
    classes; the items are encoded values (presumably dicts for schema
    instances).
    """
    validated = parse_obj_as(List[serialize_schema], obj_list)
    return [jsonable_encoder(item) for item in validated]
def convert_all_yaml_to_sdf(yaml_schemas: Sequence[Mapping[str, Any]], library_id: str) -> Mapping[str, Any]:
    """Convert YAML schema library into SDF schema library.

    Args:
        yaml_schemas: YAML schemas.
        library_id: ID of schema collection.

    Returns:
        Data in JSON output format.

    Raises:
        RuntimeError: if re-dumping the parsed schemas does not reproduce the
            input exactly (e.g. misordered or unknown fields).
    """
    parsed_yaml = parse_obj_as(List[Schema], yaml_schemas)
    # Compare as lists: the parameter is any Sequence, and e.g. a tuple never
    # equals a list even when the contents are identical.
    if [p.dict(exclude_none=True) for p in parsed_yaml] != list(yaml_schemas):
        raise RuntimeError(
            "The parsed and raw schemas do not match. The schema might have misordered fields, or there is a bug in this script."
        )

    sdf_schemas = [convert_yaml_to_sdf(yaml_schema) for yaml_schema in parsed_yaml]

    json_data = merge_schemas(sdf_schemas, library_id)
    validate_schemas(json_data)
    return json_data
def search_workspace(self, q: str, limit: int = 10, entry_id: str = None, category: str = None) -> List[Hit]:
    """Search a specific workspace

    Args:
        q (str): What to search for
        limit (int, optional): Limit results. Defaults to 10.
        entry_id (str, optional): Search for specific entry id; ignored when
            falsy. Defaults to None.
        category (str, optional): Limit search to specific category; ignored
            when falsy. Defaults to None.

    Returns:
        List[Hit]: List of hits (empty when the request returns None)
    """
    optional_filters = {"entry_id": entry_id, "category": category}
    query = {
        "query": q,
        "limit": limit,
        **{key: value for key, value in optional_filters.items() if value},
    }

    response = self._make_request("get", f"/api/archive/search/{self.workspace}", queryParams=query)
    return [] if response is None else parse_obj_as(List[Hit], response)
def test_one():
    """Parse a sample placement-constraints document and print it back as JSON."""
    raw_constraints = """\
{
  "constraints": [
    {
      "constraint_name": "alignment",
      "instances": ["i0"],
      "direction": "horizontal",
      "edge": "bottom"
    },
    {
      "constraint_name": "generator",
      "instances": ["i1", "i2"],
      "style": "cc"
    },
    {
      "constraint_name": "orientation",
      "instances": ["i2"],
      "flip_y": true
    },
    {
      "constraint_name": "boundary",
      "subcircuits": ["subcircuit_a"],
      "aspect_ratio_min": 0.3,
      "aspect_ratio_max": 0.6
    }
  ]
}
"""
    parsed = parse_obj_as(ConstraintsPlacement, json.loads(raw_constraints))
    print(parsed.json())
def _get_group_uncached(self, name: Union[str, int],
                        referer: Optional["PermissionGroup"],
                        required: bool
                        ) -> Tuple["PermissionGroup", bool]:
    """Build the permission group ``name`` of this namespace from its config.

    Returns:
        ``(group, True)`` on success. ``(NullPermissionGroup(), False)`` when
        the group is not configured or its description fails to parse. A
        missing group is logged only when ``required``.
    """
    raw_desc = self.config.get(name)
    if raw_desc is None:
        if required and referer:
            logger.error('Permission group {}:{} not found (required from {}:{})',
                         self.name, name, referer.namespace.name, referer.name)
        elif required:
            logger.error('Permission group {}:{} not found', self.name, name)
        return NullPermissionGroup(), False

    try:
        desc = parse_obj_as(GroupDesc, raw_desc)
    except ValueError:
        logger.exception('Failed to parse {}:{} ({})', self.name, name, self.path)
        return NullPermissionGroup(), False

    # Inject plugin presets: a default group in the global namespace also
    # inherits the same-named group from every plugin namespace defining it.
    if self.name == 'global' and name in default_groups:
        desc.inherits.extend(
            f'{namespace.name}:{name}'
            for namespace in plugin_namespaces
            if name in namespace.config
        )

    group = PermissionGroup(self, name)
    # Registered before populate(); presumably so lookups made while
    # populating (e.g. inheritance cycles) find this group — confirm.
    self.groups[name] = group
    group.populate(desc, referer)
    return group, True
async def test_pull_file_from_remote(
    ftpserver: ProcessFTPServer,
    tmp_path: Path,
    faker: Faker,
    mocked_log_publishing_cb: mock.AsyncMock,
):
    """Pulling a file from an FTP URL reproduces its content locally."""
    # Rename the fixture's login keys (passwd/user) to the names handed to
    # fsspec's FTP filesystem (password/username).
    login = cast(dict, ftpserver.get_login_data())
    login["password"] = login.pop("passwd")
    login["username"] = login.pop("user")

    # Seed the remote side with a file of random text.
    remote_fs = fsspec.filesystem("ftp", **login)
    expected_text = faker.text()
    remote_name = faker.file_name()
    with remote_fs.open(remote_name, mode="wt") as handle:
        handle.write(expected_text)

    url_login = ftpserver.get_login_data(style="url")
    src_url = parse_obj_as(AnyUrl, f"{url_login}/{remote_name}")
    dst_path = tmp_path / faker.file_name()

    await pull_file_from_remote(src_url, dst_path, mocked_log_publishing_cb)

    assert dst_path.exists()
    assert dst_path.read_text() == expected_text
    mocked_log_publishing_cb.assert_called()
def fetch_all(self, db: Session, name: str = '') -> ProductsResponse:
    """
    Retrieve all non-deleted product records, ordered by id.

    Args:
        db (Session): The database session.
        name (str): Case-insensitive substring the product name must contain
            (LIKE wildcards in it are escaped). Empty matches every product.

    Raises:
        ItensNotFound: If no item was found.

    Returns:
        ProductsResponse: A dict with products records.
    """
    query = db.query(Product).filter(
        Product.is_deleted == False,  # noqa: E712 - SQLAlchemy needs ==, not `is`
        func.lower(Product.name).contains(name.lower(), autoescape=True),
    )
    rows = query.order_by(Product.id).all()

    # Parsing preserves length, so emptiness can be checked on the raw rows.
    if not rows:
        raise ItensNotFound("No products found")

    return ProductsResponse(records=parse_obj_as(List[ProductResponse], rows))
async def get_messages_for_graph(
    self,
    start_date: datetime.date,
    finish_date: datetime.date,
) -> dict[datetime.date, int]:
    """Get the data for the message-count chart.

    Counts rows of ``bot_init_message`` per calendar day whose ``date`` lies
    between ``start_date`` and ``finish_date`` (SQL ``BETWEEN``, inclusive).

    :param start_date: datetime.date
    :param finish_date: datetime.date
    :return: dict[datetime.date, int] mapping each day to its message count
    """
    query = """ SELECT date::DATE, COUNT(*) AS messages_count FROM bot_init_message WHERE date BETWEEN :start_date AND :finish_date GROUP BY date::DATE ORDER BY date """
    bounds = {
        'start_date': start_date,
        'finish_date': finish_date,
    }
    rows = await self._connection.fetch_all(query, bounds)
    points = parse_obj_as(list[MessageGraphDataItem], rows)
    return {point.date: point.messages_count for point in points}