class DataPreprocessing(object): def __init__(self): self.conf = Conf() self.xml_path = self.conf.getConfig("path", "xml_path") self.index_name = self.conf.getConfig("search", "index_name") self.doc_type = self.conf.getConfig("search", "doc_type") # 读取设定 self.tokenizer = RegexpTokenizer(r"\w+") self.lem = WordNetLemmatizer() self.stemmer = PorterStemmer() self.stopwords = set(stopwords.words("english")) self.es = Elasticsearch() self.fields = self.conf.getImportant() # self.mapping = self.conf.getMapping() # es的index和doc_type相当于mysql的db和table # 如果要创建的index已存在,则删除原有index if self.es.indices.exists(index=self.index_name): self.es.indices.delete(index=self.index_name) # 创建index self.es.indices.create(index=self.index_name) # self.es.indices.put_mapping(index=self.index_name, doc_type=self.doc_type, body=self.mapping) print("created index:" + self.index_name) def xml2json(self, xmlpath): # 将xml数据转化为dict with open(xmlpath, "r") as xmlf: xml_str = xmlf.read() dict_str = xmltodict.parse(xml_str) # json_str = json.dumps(dict_str) return dict_str def cleanData(self, doc): raw_tokens = self.tokenizer.tokenize(doc.lower()) lem_tokens = [self.stemmer.stem(token) for token in raw_tokens] lem_tokens = [ token for token in lem_tokens if not token.isdigit() and len(token) > 1 ] lem_tokens_without_stopword = filter( lambda i: i not in self.stopwords, lem_tokens ) return " ".join(list(lem_tokens_without_stopword)) def clean(self, json_data): if json_data["brief_title"]: json_data["brief_title"] = self.cleanData(json_data["brief_title"]) if json_data["official_title"]: json_data["official_title"] = self.cleanData(json_data["official_title"]) if json_data["brief_summary"]: json_data["brief_summary"] = self.cleanData(json_data["brief_summary"]) if json_data["detailed_description"]: json_data["detailed_description"] = self.cleanData( json_data["detailed_description"] ) if json_data["eligibility"]["criteria"]["textblock"]: json_data["eligibility"]["criteria"]["textblock"] = self.cleanData( json_data["eligibility"]["criteria"]["textblock"] ) return json_data def oswalk(self): count = 0 # 遍历xml_path中所有文件夹下的所有文件 for os_set in os.walk(self.xml_path, topdown=True): for filename in os_set[2]: try: filepath = os.path.join(os_set[0], filename) json_data = self.xml2json(filepath) cleaned_json_data = {} default_input_json = { "id_info": "NCT00000000", "brief_title": "", "official_title": "", "brief_summary": "", "detailed_description": "", "intervention": {"intervention_type": "", "intervention_name": ""}, "eligibility": { "criteria": {"textblock": ""}, "gender": "All", "minimum_age": "6 Months", "maximum_age": "100 Years", "healthy_volunteers": "No", }, "keyword": [], "intervention_browse": [], "condition": [], } # 将important.txt中设定好的字段从dict中提取出来,填充到要存进es的dict中 for field in self.fields: if field in json_data["clinical_study"]: if len(self.fields[field]) > 1 and not isinstance( json_data["clinical_study"][field], str ): cleaned_json_data[field] = json_data["clinical_study"][ field ][self.fields[field]] else: cleaned_json_data[field] = json_data["clinical_study"][ field ] else: cleaned_json_data[field] = default_input_json[field] # if len(self.fields[field]) > 1 and not isinstance( # default_input_json[field], str # ): # cleaned_json_data[field] = default_input_json[field][ # self.fields[field] # ] # else: # cleaned_json_data[field] = default_input_json[field] # 处理年龄 # print(default_input_json) # print(cleaned_json_data) if "eligibility" in cleaned_json_data: if "criteria" not in cleaned_json_data["eligibility"]: cleaned_json_data["eligibility"]["criteria"] = {"textblock": ""} for k in default_input_json["eligibility"]: if k not in cleaned_json_data["eligibility"]: cleaned_json_data["eligibility"][k] = default_input_json["eligibility"][k] cleaned_json_data["eligibility"] = NormalAge( cleaned_json_data["eligibility"] ) cleaned_json_data = self.clean(cleaned_json_data) # ---------------------------------- # print(cleaned_json_data) # return # ---------------------------------- # 插入数据 self.es.index( index=self.index_name, body=cleaned_json_data, doc_type=self.doc_type, ) count += 1 if count % 1000 == 0: print("Already finished:" + str(count)) except KeyboardInterrupt: # 处理ctrl+C中断程序的情况 print("Interrupted") try: sys.exit(0) except SystemExit: os._exit(0) except Exception as e: print(cleaned_json_data) print(e) with open("errorxml.txt", "a") as f: f.write(str(filepath) + "\n") print("Error in ", str(filename))
class Query(object): def __init__(self): self.conf = Conf() self.query_xml_path = self.conf.getConfig("path", "query_xml_path") self.index_name = self.conf.getConfig("search", "index_name") self.doc_type = self.conf.getConfig("search", "doc_type") self.meshDict = self.getPickles( self.conf.getConfig("path", "dict_pickle_path")) self.es = Elasticsearch(timeout=30, max_retries=10, retry_on_timeout=True) # 设定es的超时时限为30秒,默认为10秒 # 最大重试次数为10次 # 防止因数据量太大导致的超时 self.fields = self.conf.getImportant() self.extracted = [] self.tokenizer = RegexpTokenizer(r"\w+") self.lem = WordNetLemmatizer() self.stemmer = PorterStemmer() self.stopwords = set(stopwords.words("english")) def getPickles(self, pickle_path): with open(pickle_path, "rb") as pf: data = pickle.load(pf) return data def xml2json(self, xmlpath): with open(xmlpath, "r") as xmlf: xml_str = xmlf.read() dict_str = xmltodict.parse(xml_str) # json_str = json.dumps(dict_str) return dict_str def extract_query(self): # 处理查询字段 query_xml_data = self.xml2json(self.query_xml_path)["topics"]["topic"] for q in query_xml_data: new_query = { "id": q["@number"], "disease": q["disease"], "gene": q["gene"], "age": int(q["demographic"].split("-")[0]) * 365, "gender": q["demographic"].split(" ")[-1], "other": q["other"], } self.extracted.append(new_query) with open("query.json", "w") as f: f.write(json.dumps(self.extracted, indent=4)) def cleanData(self, doc): raw_tokens = self.tokenizer.tokenize(doc.lower()) lem_tokens = [self.stemmer.stem(token) for token in raw_tokens] lem_tokens = [ token for token in lem_tokens if not token.isdigit() and len(token) > 1 ] lem_tokens_without_stopword = filter(lambda i: i not in self.stopwords, lem_tokens) return list(lem_tokens_without_stopword) def query(self, single_query): gender_lst = ["male", "female"] must_not_gender = gender_lst[abs( gender_lst.index(single_query["gender"]) - 1)] # 性别分为male,female和All三种,得到不用的一种 query_keywords = single_query["disease"].lower().split(" ") relevence = single_query["disease"].lower().split(" ") for qk in query_keywords: # qk = qk.lower() if qk in self.meshDict and qk not in [ "cancer", "adenocarcinoma", "carcinoma", ]: relevence += self.meshDict[qk] if "mesh_numbers" in relevence: relevence.remove("mesh_numbers") relevence = list(set(self.cleanData(" ".join(relevence)))) print(single_query["gene"].replace("(", " ").replace(")", " ").replace(",", "")) # for rl in relevence: # if rl in ["mesh_numbers", "cancers", "non", "carcinomas", "tumors", "neoplasms", "pseudocysts", "cysts", "vipomas"]: # # print(rl) # relevence.remove(rl) relevence_str = " ".join(relevence) # print(relevence_str) # query_body = { # "query": { # "multi_match": { # "query": (single_query["disease"] + ' ' + single_query["gene"].replace("(", " ").replace(")", " ").replace(",", "")).lower(), # "type": "cross_fields", # "fields": [ # "brief_title", # "brief_summary", # "detailed_description", # "official_title", # "keyword", # "condition", # "eligibility.criteria.textblock", # ], # } # }, # "size": 1000, # } # p5: 0.3586 # p10:0.3138 # p15:0.2704 # with age: p5: 0.3586 p10:0.3172 p15:0.2805 # with gender: p5: 0.3655 p10:0.3241 p15:0.2920 query_body = { "query": { "multi_match": { "query": (single_query["disease"] + ' ' + single_query["gene"].replace("(", " ").replace( ")", " ").replace(",", "")).lower(), "type": "cross_fields", "fields": [ "brief_title", "brief_summary", "detailed_description", "official_title", "keyword", "condition", "eligibility.criteria.textblock", ], } }, "size": 1000, } # query_body = { # "query": { # "multi_match": { # "query": (single_query["gene"].replace("(", " ").replace(")", " ").replace(",", "")).lower(), # "type": "cross_fields", # "fields": [ # "brief_title", # "brief_summary", # "detailed_description", # "official_title", # "keyword", # "condition", # "eligibility.criteria.textblock", # ], # } # }, # "size": 1000, # } # query_standard = (single_query["gene"].replace("(", " ").replace(")", " ").replace(",", "")).lower() # query_body = { # "query": { # "bool": { # "should": [ # {"match": {"brief_title": {"query": query_standard, "boost": 2}}}, # {"match": {"official_title": {"query": query_standard, "boost": 2}}}, # {"match": {"brief_summary": {"query": query_standard, "boost": 1}}}, # {"match": {"detailed_description": {"query": query_standard, "boost": 1}}}, # {"match": {"eligibility.criteria.textblock": {"query": query_standard, "boost": 5}}}, # {"match": {"keyword": {"query": query_standard, "boost": 6}}}, # {"match": {"condition": {"query": query_standard, "boost": 3}}}, # ], # "must_not": [{"term": {"gender": must_not_gender}}], # }, # }, # "size": 1500, # } # 这里的querybody需要再认真设计下,不同的查询方式对最终结果的MAP和P@10影响很大 query_result = self.es.search(index=self.index_name, doc_type=self.doc_type, body=query_body)["hits"]["hits"] # 获得查询结果 # print(query_result) # score_max = query_result[0]["_score"] rank = 1 with open("trec_eval/eval/r40.txt", "a") as f: try: for qr in query_result: # 过滤年龄不符合的情况 if "eligibility" in qr["_source"]: qr_eli = qr["_source"]["eligibility"] if float(qr_eli["minimum_age"]) > single_query["age"] or\ single_query["age"] > float(qr_eli["maximum_age"]): continue if qr_eli["gender"].lower().strip() not in [ single_query["gender"].lower(), 'all', 'All' ]: print(qr_eli["gender"].lower()) print(single_query["gender"].lower()) continue # 按照要求格式写文件 f.write("{} Q0 {} {} {} certRI\n".format( single_query["id"], qr["_source"]["id_info"], rank, round(qr["_score"], 4), )) rank += 1 if rank > 1000: break except ValueError as _: print(qr["_source"]["eligibility"]) except KeyError as ke: print(ke) print(qr["_source"]) print("Relative docs:{}".format(rank - 1)) def run(self): self.extract_query() for single_query in self.extracted: print(single_query) self.query(single_query)