Example #1
File: Search.py Project: tymrail/certRI
# Imports assumed for this example (Conf is a project-local config helper):
from elasticsearch import Elasticsearch
from pyes import ES, TermQuery

from conf import Conf  # assumed module path


class IRSearch(object):
    def __init__(self):
        self.conf = Conf()
        self.xml_path = self.conf.getConfig('path', 'xml_path')
        self.index_name = self.conf.getConfig('search', 'index_name')
        self.doc_type = self.conf.getConfig('search', 'doc_type')
        self.es = Elasticsearch()
        self.search_body = {}
        self.search_type_support = ['match_all', 'term', 'terms',
                                    'match', 'multi_match', 'bool', 'range', 'prefix', 'wildcard']
        self.search_type_model = self.conf.getSeachModel()
        self.conn = ES('127.0.0.1:9200')
        self.search_result = None
        
        self.conn.default_indices = [self.index_name]

    def makeQuery(self, searchtype, searchfield, keyword, is_sort=False, is_aggs=False, is_multi_match=False, use_bool=""):
        if searchtype not in self.search_type_support:
            print('Oops, your search type is not supported')
            print('Supported search types:\n')
            print(self.search_type_support)
            return
        self.search_body = self.search_type_model[searchtype]
        if is_multi_match:
            self.search_body["query"][searchtype] = {
                "query": keyword,
                "fields": searchfield
            }
        elif use_bool:
            self.search_body["query"][searchtype][use_bool] = [{
                "term": {
                    searchfield: keyword
                }
            }]
        else:
            self.search_body["query"][searchtype][searchfield] = keyword

        print(self.search_body)
        return self.search_body

    # Build the query body via makeQuery and execute the search.
    def Query(self, searchtype, searchfield, keyword, is_sort=False, is_aggs=False, is_multi_match=False, use_bool=""):
        query_body = self.makeQuery(
            searchtype, searchfield, keyword, is_sort, is_aggs, is_multi_match, use_bool)
        result = self.es.search(index=self.index_name,
                                doc_type=self.doc_type, body=query_body)
        return result

    def querySingle(self, searchfield, keyword):
        q = TermQuery(searchfield, keyword)
        self.search_result = self.conn.search(query=q)
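
A minimal usage sketch (hypothetical: assumes a running Elasticsearch node on 127.0.0.1:9200 and a config file supplying index_name and doc_type):

# Hypothetical usage of IRSearch
searcher = IRSearch()
result = searcher.Query("match", "brief_title", "carcinoma")
for hit in result["hits"]["hits"]:
    print(hit["_source"]["id_info"], hit["_score"])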
Example #2
File: test.py Project: tymrail/certRI
import pickle

from elasticsearch import Elasticsearch

from conf import Conf  # assumed module path for the project-local config helper


class Test(object):
    def __init__(self):
        self.conf = Conf()
        self.xml_path = self.conf.getConfig("path", "xml_path")
        self.index_name = self.conf.getConfig("search", "index_name")
        self.doc_type = self.conf.getConfig("search", "doc_type")
        self.es = Elasticsearch(timeout=30, max_retries=10, retry_on_timeout=True)
        self.search_body = {}
        self.search_type_support = [
            "match_all",
            "term",
            "terms",
            "match",
            "multi_match",
            "bool",
            "range",
            "prefix",
            "wildcard",
        ]
        self.search_type_model = self.conf.getSeachModel()

    def getCount(self):
        print(self.es.count(index=self.index_name, doc_type=self.doc_type))

    def searchSingle(self):
        res = self.es.search(
            index=self.index_name,
            doc_type=self.doc_type,
            body={
                "query": {"match": {"id_info": "NCT02065063"}},
                "size": 10000,
            },
        )

        # body={"query": {"match": {"detailed_description": "carcinoma"}}},
        # body={"query": {"match": {"id_info": "NCT00001431"}}},

        # Open the output file once instead of re-opening it for every hit
        with open("carcinoma", "a") as f:
            for r in res["hits"]["hits"]:
                print(r["_source"])
                f.write("{}\n".format(r["_source"]["id_info"]))

    def getPickles(self, pickle_path):
        with open(pickle_path, 'rb') as pf:
            data = pickle.load(pf)
            # pprint.pprint(data)
            return data
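
A minimal usage sketch (hypothetical; the printed count depends on what has been indexed):

# Hypothetical usage of Test
t = Test()
t.getCount()      # prints the document count of the configured index
t.searchSingle()  # appends the id_info of each hit to the file "carcinoma"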
Example #3
import os
import sys

import xmltodict
from elasticsearch import Elasticsearch
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer, WordNetLemmatizer
from nltk.tokenize import RegexpTokenizer

# Conf and NormalAge are project-local helpers; the module paths are assumed.
from conf import Conf
from utils import NormalAge


class DataPreprocessing(object):
    def __init__(self):
        self.conf = Conf()
        self.xml_path = self.conf.getConfig("path", "xml_path")
        self.index_name = self.conf.getConfig("search", "index_name")
        self.doc_type = self.conf.getConfig("search", "doc_type")
        # Load configuration settings

        self.tokenizer = RegexpTokenizer(r"\w+")
        self.lem = WordNetLemmatizer()
        self.stemmer = PorterStemmer()
        self.stopwords = set(stopwords.words("english"))

        self.es = Elasticsearch()
        self.fields = self.conf.getImportant()
        # self.mapping = self.conf.getMapping()

        # An ES index and doc_type roughly correspond to a MySQL database and table.
        # If the target index already exists, delete it first.
        if self.es.indices.exists(index=self.index_name):
            self.es.indices.delete(index=self.index_name)

        # Create the index
        self.es.indices.create(index=self.index_name)
        # self.es.indices.put_mapping(index=self.index_name, doc_type=self.doc_type, body=self.mapping)
        print("created index:" + self.index_name)

    def xml2json(self, xmlpath):
        # Convert the XML data into a dict
        with open(xmlpath, "r") as xmlf:
            xml_str = xmlf.read()
            dict_str = xmltodict.parse(xml_str)
            # json_str = json.dumps(dict_str)
            return dict_str

    def cleanData(self, doc):
        raw_tokens = self.tokenizer.tokenize(doc.lower())
        lem_tokens = [self.stemmer.stem(token) for token in raw_tokens]
        lem_tokens = [
            token for token in lem_tokens if not token.isdigit() and len(token) > 1
        ]
        lem_tokens_without_stopword = filter(
            lambda i: i not in self.stopwords, lem_tokens
        )
        return " ".join(list(lem_tokens_without_stopword))

    def clean(self, json_data):
        if json_data["brief_title"]:
            json_data["brief_title"] = self.cleanData(json_data["brief_title"])
        if json_data["official_title"]:
            json_data["official_title"] = self.cleanData(json_data["official_title"])
        if json_data["brief_summary"]:
            json_data["brief_summary"] = self.cleanData(json_data["brief_summary"])
        if json_data["detailed_description"]:
            json_data["detailed_description"] = self.cleanData(
                json_data["detailed_description"]
            )
        if json_data["eligibility"]["criteria"]["textblock"]:
            json_data["eligibility"]["criteria"]["textblock"] = self.cleanData(
                json_data["eligibility"]["criteria"]["textblock"]
            )
        return json_data

    def oswalk(self):
        count = 0

        # Walk every file in every directory under xml_path
        for os_set in os.walk(self.xml_path, topdown=True):
            for filename in os_set[2]:
                try:
                    filepath = os.path.join(os_set[0], filename)
                    json_data = self.xml2json(filepath)

                    cleaned_json_data = {}

                    default_input_json = {
                        "id_info": "NCT00000000",
                        "brief_title": "",
                        "official_title": "",
                        "brief_summary": "",
                        "detailed_description": "",
                        "intervention": {"intervention_type": "", "intervention_name": ""},
                        "eligibility": {
                            "criteria": {"textblock": ""},
                            "gender": "All",
                            "minimum_age": "6 Months",
                            "maximum_age": "100 Years",
                            "healthy_volunteers": "No",
                        },
                        "keyword": [],
                        "intervention_browse": [],
                        "condition": [],
                    }

                    # Pull the fields listed in important.txt out of the parsed
                    # dict and copy them into the dict to be stored in ES.
                    for field in self.fields:
                        if field in json_data["clinical_study"]:
                            if len(self.fields[field]) > 1 and not isinstance(
                                json_data["clinical_study"][field], str
                            ):
                                cleaned_json_data[field] = json_data["clinical_study"][
                                    field
                                ][self.fields[field]]
                            else:
                                cleaned_json_data[field] = json_data["clinical_study"][
                                    field
                                ]
                        else:
                            cleaned_json_data[field] = default_input_json[field]
                            # if len(self.fields[field]) > 1 and not isinstance(
                            #     default_input_json[field], str
                            # ):
                            #     cleaned_json_data[field] = default_input_json[field][
                            #         self.fields[field]
                            #     ]
                            # else:
                            #     cleaned_json_data[field] = default_input_json[field]

                    # Normalize the age fields
                    # print(default_input_json)
                    # print(cleaned_json_data)
                    if "eligibility" in cleaned_json_data:
                        if "criteria" not in cleaned_json_data["eligibility"]:
                            cleaned_json_data["eligibility"]["criteria"] = {"textblock": ""}

                        for k in default_input_json["eligibility"]:
                            if k not in cleaned_json_data["eligibility"]:
                                cleaned_json_data["eligibility"][k] = default_input_json["eligibility"][k]

                        cleaned_json_data["eligibility"] = NormalAge(
                            cleaned_json_data["eligibility"]
                        )

                    cleaned_json_data = self.clean(cleaned_json_data)

                    # ----------------------------------
                    # print(cleaned_json_data)
                    # return
                    # ----------------------------------

                    # Index the document into ES
                    self.es.index(
                        index=self.index_name,
                        body=cleaned_json_data,
                        doc_type=self.doc_type,
                    )

                    count += 1
                    if count % 1000 == 0:
                        print("Already finished:" + str(count))
                except KeyboardInterrupt:
                    # Handle Ctrl+C so the walk can be aborted cleanly
                    print("Interrupted")
                    try:
                        sys.exit(0)
                    except SystemExit:
                        os._exit(0)
                except Exception as e:
                    print(cleaned_json_data)
                    print(e)
                    with open("errorxml.txt", "a") as f:
                        f.write(str(filepath) + "\n")
                    print("Error in ", str(filename))
Example #4
import json
import pickle

import xmltodict
from elasticsearch import Elasticsearch
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer, WordNetLemmatizer
from nltk.tokenize import RegexpTokenizer

from conf import Conf  # assumed module path for the project-local config helper


class Query(object):
    def __init__(self):
        self.conf = Conf()
        self.query_xml_path = self.conf.getConfig("path", "query_xml_path")
        self.index_name = self.conf.getConfig("search", "index_name")
        self.doc_type = self.conf.getConfig("search", "doc_type")
        self.meshDict = self.getPickles(
            self.conf.getConfig("path", "dict_pickle_path"))
        self.es = Elasticsearch(timeout=30,
                                max_retries=10,
                                retry_on_timeout=True)
        # Raise the ES timeout to 30 seconds (the default is 10) and retry
        # up to 10 times on timeout, to avoid failures on this large data set.
        self.fields = self.conf.getImportant()
        self.extracted = []

        self.tokenizer = RegexpTokenizer(r"\w+")
        self.lem = WordNetLemmatizer()
        self.stemmer = PorterStemmer()
        self.stopwords = set(stopwords.words("english"))

    def getPickles(self, pickle_path):
        with open(pickle_path, "rb") as pf:
            data = pickle.load(pf)
            return data

    def xml2json(self, xmlpath):
        with open(xmlpath, "r") as xmlf:
            xml_str = xmlf.read()
            dict_str = xmltodict.parse(xml_str)
            # json_str = json.dumps(dict_str)
            return dict_str

    def extract_query(self):
        # Parse the query topics
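        # Each topic is assumed to look like:
        #   <topic number="1"><disease>...</disease><gene>...</gene>
        #   <demographic>38-year-old male</demographic><other>...</other></topic>
        # The age is converted from years to days to match the normalized
        # trial ages assumed to be produced by NormalAge.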
        query_xml_data = self.xml2json(self.query_xml_path)["topics"]["topic"]
        for q in query_xml_data:
            new_query = {
                "id": q["@number"],
                "disease": q["disease"],
                "gene": q["gene"],
                "age": int(q["demographic"].split("-")[0]) * 365,
                "gender": q["demographic"].split(" ")[-1],
                "other": q["other"],
            }
            self.extracted.append(new_query)
        with open("query.json", "w") as f:
            f.write(json.dumps(self.extracted, indent=4))

    def cleanData(self, doc):
        raw_tokens = self.tokenizer.tokenize(doc.lower())
        lem_tokens = [self.stemmer.stem(token) for token in raw_tokens]
        lem_tokens = [
            token for token in lem_tokens
            if not token.isdigit() and len(token) > 1
        ]
        lem_tokens_without_stopword = filter(lambda i: i not in self.stopwords,
                                             lem_tokens)
        return list(lem_tokens_without_stopword)

    def query(self, single_query):
        gender_lst = ["male", "female"]
        must_not_gender = gender_lst[abs(
            gender_lst.index(single_query["gender"]) - 1)]
        # Gender is one of male, female, or All; pick the one that does not
        # apply, for use in a must_not filter.

        query_keywords = single_query["disease"].lower().split(" ")
        relevence = single_query["disease"].lower().split(" ")

        for qk in query_keywords:
            # qk = qk.lower()
            if qk in self.meshDict and qk not in [
                    "cancer",
                    "adenocarcinoma",
                    "carcinoma",
            ]:
                relevence += self.meshDict[qk]

        if "mesh_numbers" in relevence:
            relevence.remove("mesh_numbers")
        relevence = list(set(self.cleanData(" ".join(relevence))))

        print(single_query["gene"].replace("(",
                                           " ").replace(")",
                                                        " ").replace(",", ""))

        # for rl in relevence:
        #     if rl in ["mesh_numbers", "cancers", "non", "carcinomas", "tumors", "neoplasms", "pseudocysts", "cysts", "vipomas"]:
        #         # print(rl)
        #         relevence.remove(rl)

        relevence_str = " ".join(relevence)
        # print(relevence_str)

        # query_body = {
        #     "query": {
        #         "multi_match": {
        #             "query": (single_query["disease"] + ' ' + single_query["gene"].replace("(", " ").replace(")", " ").replace(",", "")).lower(),
        #             "type": "cross_fields",
        #             "fields": [
        #                 "brief_title",
        #                 "brief_summary",
        #                 "detailed_description",
        #                 "official_title",
        #                 "keyword",
        #                 "condition",
        #                 "eligibility.criteria.textblock",
        #             ],
        #         }
        #     },
        #     "size": 1000,
        # }
        # p5: 0.3586
        # p10:0.3138
        # p15:0.2704
        # with age: p5: 0.3586 p10:0.3172 p15:0.2805
        # with gender: p5: 0.3655 p10:0.3241 p15:0.2920

        query_body = {
            "query": {
                "multi_match": {
                    "query": (single_query["disease"] + ' ' +
                              single_query["gene"].replace("(", " ").replace(
                                  ")", " ").replace(",", "")).lower(),
                    "type":
                    "cross_fields",
                    "fields": [
                        "brief_title",
                        "brief_summary",
                        "detailed_description",
                        "official_title",
                        "keyword",
                        "condition",
                        "eligibility.criteria.textblock",
                    ],
                }
            },
            "size": 1000,
        }

        # query_body = {
        #     "query": {
        #         "multi_match": {
        #             "query": (single_query["gene"].replace("(", " ").replace(")", " ").replace(",", "")).lower(),
        #             "type": "cross_fields",
        #             "fields": [
        #                 "brief_title",
        #                 "brief_summary",
        #                 "detailed_description",
        #                 "official_title",
        #                 "keyword",
        #                 "condition",
        #                 "eligibility.criteria.textblock",
        #             ],
        #         }
        #     },
        #     "size": 1000,
        # }

        # query_standard = (single_query["gene"].replace("(", " ").replace(")", " ").replace(",", "")).lower()

        # query_body = {
        #     "query": {
        #         "bool": {
        #             "should": [
        #                 {"match": {"brief_title": {"query": query_standard, "boost": 2}}},
        #                 {"match": {"official_title": {"query": query_standard, "boost": 2}}},
        #                 {"match": {"brief_summary": {"query": query_standard, "boost": 1}}},
        #                 {"match": {"detailed_description": {"query": query_standard, "boost": 1}}},
        #                 {"match": {"eligibility.criteria.textblock": {"query": query_standard, "boost": 5}}},
        #                 {"match": {"keyword": {"query": query_standard, "boost": 6}}},
        #                 {"match": {"condition": {"query": query_standard, "boost": 3}}},
        #             ],
        #             "must_not": [{"term": {"gender": must_not_gender}}],
        #         },
        #     },
        #     "size": 1500,
        # }
        # The query body needs more careful design: different query
        # formulations strongly affect the final MAP and P@10.

        query_result = self.es.search(index=self.index_name,
                                      doc_type=self.doc_type,
                                      body=query_body)["hits"]["hits"]
        # Collect the search hits

        # print(query_result)
        # score_max = query_result[0]["_score"]
        rank = 1
        with open("trec_eval/eval/r40.txt", "a") as f:
            try:
                for qr in query_result:
                    # Skip trials whose age range does not match the query
                    if "eligibility" in qr["_source"]:
                        qr_eli = qr["_source"]["eligibility"]
                        if float(qr_eli["minimum_age"]) > single_query["age"] or\
                            single_query["age"] > float(qr_eli["maximum_age"]):
                            continue
                        if qr_eli["gender"].lower().strip() not in [
                                single_query["gender"].lower(), "all"
                        ]:
                            print(qr_eli["gender"].lower())
                            print(single_query["gender"].lower())
                            continue

                    # Write each hit in the required TREC run-file format
                    f.write("{} Q0 {} {} {} certRI\n".format(
                        single_query["id"],
                        qr["_source"]["id_info"],
                        rank,
                        round(qr["_score"], 4),
                    ))
                    rank += 1

                    if rank > 1000:
                        break

            except ValueError:
                print(qr["_source"]["eligibility"])
            except KeyError as ke:
                print(ke)
                print(qr["_source"])

        print("Relative docs:{}".format(rank - 1))

    def run(self):
        self.extract_query()
        for single_query in self.extracted:
            print(single_query)
            self.query(single_query)
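
A minimal usage sketch (hypothetical; assumes the topics XML, the MeSH pickle, and the trec_eval/eval/ output directory exist as configured):

# Hypothetical usage of Query
q = Query()
q.run()  # writes one TREC run line per retained hit to trec_eval/eval/r40.txt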
Example #5
File: w2v.py Project: tymrail/certRI
import os
import sys

import xmltodict
from gensim.models import word2vec
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import RegexpTokenizer

from conf import Conf  # assumed module path for the project-local config helper


class W2V(object):
    def __init__(self):
        self.tokenizer = RegexpTokenizer(r'\w+')
        self.lem = WordNetLemmatizer()
        self.stopwords = set(stopwords.words('english'))
        self.dict = None
        self.corpus = None
        self.bm25model = None
        self.docs_list = []
        self.conf = Conf()
        self.xml_path = self.conf.getConfig('path', 'xml_path')

    def cleanData(self, doc):
        raw_tokens = self.tokenizer.tokenize(doc.lower())
        lem_tokens = [self.lem.lemmatize(token) for token in raw_tokens]
        lem_tokens_without_stopword = filter(lambda i: i not in self.stopwords,
                                             lem_tokens)
        return list(lem_tokens_without_stopword)

    def xml2json(self, xmlpath):
        with open(xmlpath, "r") as xmlf:
            xml_str = xmlf.read()
            dict_str = xmltodict.parse(xml_str)
            # json_str = json.dumps(dict_str)
            return dict_str

    def extractUseful(self, dict_str):
        useful_list = []

        if "official_title" in dict_str["clinical_study"]:
            useful_list.append(dict_str["clinical_study"]["official_title"])
        else:
            useful_list.append(dict_str["clinical_study"]["brief_title"])

        if "brief_summary" in dict_str["clinical_study"]:
            useful_list.append(
                dict_str["clinical_study"]["brief_summary"]["textblock"])

        if "detailed_description" in dict_str["clinical_study"]:
            useful_list.append(dict_str["clinical_study"]
                               ["detailed_description"]["textblock"])

        if "eligibility" in dict_str["clinical_study"]:
            useful_list.append(dict_str["clinical_study"]["eligibility"]
                               ["criteria"]["textblock"])

        return ','.join(useful_list)

    def buildModel(self):
        model = word2vec.Word2Vec(sentences=self.docs_list,
                                  min_count=5,
                                  workers=4)
        model.save("models/w2v.model")

    def run(self):
        count = 0

        for root, _, files in os.walk(self.xml_path, topdown=True):
            for filename in files:
                try:
                    file_path = os.path.join(root, filename)
                    json_data = self.xml2json(file_path)

                    useful_str = self.extractUseful(json_data)
                    useful_tokens = self.cleanData(useful_str)

                    self.docs_list.append(useful_tokens)
                except KeyboardInterrupt:
                    # Handle Ctrl+C so the walk can be aborted cleanly
                    print('Interrupted')
                    try:
                        sys.exit(0)
                    except SystemExit:
                        os._exit(0)
                except Exception as e:
                    print(e)
                    with open('error_w2v_xml.txt', 'a') as f:
                        f.write(str(file_path) + '\n')
                    print('Error in ', str(filename))

                count += 1
                if count % 2000 == 0:
                    print("Already finished {}".format(count))

        print("Start build model")
        self.buildModel()
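
A minimal usage sketch (hypothetical; training reads the whole XML corpus under xml_path and saves the model to models/w2v.model, so that directory must exist):

# Hypothetical usage of W2V
w = W2V()
w.run()

# The saved model can later be reloaded and queried, e.g.:
# model = word2vec.Word2Vec.load("models/w2v.model")
# print(model.wv.most_similar("carcinoma"))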