Exemplo n.º 1
0
 def createGraph(self, filter=False, treamentIdfilterSet=set([])):
     """Run every Mysql2Neo4j store step, in a fixed order, to build the graph.

     The node stores (medicine, symptom, tongue tai, tongue zhi, pulse) run
     first; each one receives the id set returned by the matching
     ``get_*_id_set`` helper for ``treamentIdfilterSet``.  The medicine
     weight store and the ``*2medicine`` relationship stores then receive
     ``treamentIdfilterSet`` itself.

     :param filter: forwarded unchanged as ``filter`` to every store call.
     :param treamentIdfilterSet: id set used to derive the per-entity id
         sets and passed directly to the later steps (presumably treatment
         ids, judging by the name -- confirm against the caller).
     """
     FileUtil.print_string(u"开始创建图谱...", True)
     importer = Mysql2Neo4j()

     # Node steps: resolve the entity id set, then store that node type.
     medicine_ids = importer.get_medicine_id_set(treamentIdfilterSet)
     importer.store_medicine(filter=filter, filterSet=medicine_ids)

     symptom_ids = importer.get_symptom_id_set(treamentIdfilterSet)
     importer.store_symptom(filter=filter, filterSet=symptom_ids)

     tongue_tai_ids = importer.get_tonguetai_id_set(treamentIdfilterSet)
     importer.store_tongue_tai(filter=filter, filterSet=tongue_tai_ids)

     tongue_zhi_ids = importer.get_tonguezhi_id_set(treamentIdfilterSet)
     importer.store_tongue_zhi(filter=filter, filterSet=tongue_zhi_ids)

     pulse_ids = importer.get_pulse_id_set(treamentIdfilterSet)
     importer.store_pulse(filter=filter, filterSet=pulse_ids)

     # Remaining steps take the caller's id set as-is; order is preserved.
     remaining_steps = (
         importer.store_medicine_weight,
         importer.store_symptom2medicine,
         importer.store_tonguetai2medicine,
         importer.store_tonguezhi2medicine,
         importer.store_pulse2medicine,
     )
     for store_step in remaining_steps:
         store_step(filter=filter, filterSet=treamentIdfilterSet)
Exemplo n.º 2
0
 def store_symptom2tonguezhi(self, filter=False, filterSet=set([])):
     """Store symptom -> tongue-zhi relationships from the treatment table.

     For every treatment row, each symptom node listed in ``symptomIds`` is
     linked to the row's tongue-zhi node.  If the relationship already
     exists, the treatment id is merged into its id list and the weight is
     set to the number of distinct ids; otherwise a relationship holding
     only this treatment id, with weight 1, is created.  Rows whose nodes
     cannot be found in Neo4j are skipped.

     :param filter: when equal to True, rows whose id is not in
         ``filterSet`` are skipped.
     :param filterSet: treatment ids to keep when filtering is enabled.
     """
     FileUtil.print_string(u"开始存储症状和舌质的关系...", True)
     mysqlDb = MysqlOpt()
     try:
         neoDb = Neo4jOpt()
         treamentMysql = mysqlDb.select('treatment',
                                        'id,symptomIds,tongueZhiId')
         for treamentUnit in treamentMysql:
             tid = int(treamentUnit[0])
             if filter == True and tid not in filterSet:
                 continue
             # Drop empty tokens so a stray or trailing comma in symptomIds
             # never reaches int('') and aborts the whole import.
             syptomeIds = [s for s in treamentUnit[1].split(",")
                           if s.strip()]
             if not syptomeIds:
                 continue
             # The tongue-zhi node does not depend on the symptom being
             # processed, so it is looked up once per treatment.
             tongueZhiNodes = neoDb.selectNodeElementsFromDB(
                 self.TONGUEZHITAG,
                 condition=[],
                 properties={self.NODEID: int(treamentUnit[2])})
             if len(tongueZhiNodes) == 0:
                 continue
             tongueZhiNode = tongueZhiNodes[0]
             for syptomeId in syptomeIds:
                 syptomeNodes = neoDb.selectNodeElementsFromDB(
                     self.SYMPTOMTAG,
                     condition=[],
                     properties={self.NODEID: int(syptomeId)})
                 if len(syptomeNodes) == 0:
                     continue
                 syptomeNode = syptomeNodes[0]
                 relations = neoDb.selectRelationshipsFromDB(
                     syptomeNode, self.SYMPTOM2TONGUEZHITAG, tongueZhiNode)
                 if len(relations) > 0:
                     # Existing relationship: merge this treatment id and
                     # recompute the weight as the count of distinct ids.
                     relation = relations[0]
                     tids = relation[self.RELATIONTID]
                     tids.append(tid)
                     ntids = list(set(tids))
                     neoDb.updateKeyInRelationship(
                         relation,
                         properties={
                             self.RELATIONTID: ntids,
                             self.RELATIONWEIGHT: len(ntids)
                         })
                 else:
                     neoDb.createRelationship(
                         self.SYMPTOM2TONGUEZHITAG,
                         syptomeNode,
                         tongueZhiNode,
                         propertyDic={
                             self.RELATIONTID: [tid],
                             self.RELATIONWEIGHT: 1
                         })
     finally:
         # Release the MySQL handle even when a lookup or conversion fails.
         mysqlDb.close()
Exemplo n.º 3
0
 def store_pulse(self, filter=False, filterSet=set([])):
     """Recreate the pulse nodes in Neo4j from the MySQL ``pulse`` table.

     All nodes labelled ``self.PULSETAG`` are detach-deleted first, then one
     node (id + name) is created per selected row.

     :param filter: when equal to True, rows whose id is not in
         ``filterSet`` are skipped.
     :param filterSet: pulse ids to keep when filtering is enabled.
     """
     mysqlDb = MysqlOpt()
     try:
         neoDb = Neo4jOpt()
         cql = "MATCH (n:" + self.PULSETAG + ") DETACH DELETE n"
         neoDb.graph.data(cql)
         FileUtil.print_string(u"开始存储脉搏节点...", True)
         mysqlDatas = mysqlDb.select('pulse', 'id,name')
         for unit in mysqlDatas:
             # Named pulseId rather than id to avoid shadowing the builtin.
             pulseId = unit[0]
             name = unit[1]
             if filter == True and pulseId not in filterSet:
                 continue
             neoDb.createNode([self.PULSETAG], {
                 self.NODEID: pulseId,
                 self.NODENAME: name
             })
     finally:
         # Release the MySQL handle even if a Neo4j call raises.
         mysqlDb.close()
Exemplo n.º 4
0
 def classification_evaluate(self, yTrue, yPre):
     """Average per-sample precision and recall over paired label lists.

     For sample ``i`` the overlap between ``yTrue[i]`` and ``yPre[i]`` is
     obtained from ``ListUtil.list_intersect`` (assumed to return the
     elements common to both lists -- TODO confirm).  Precision is
     overlap / len(yPre[i]) and recall is overlap / len(yTrue[i]), each
     rounded to 5 decimals and logged to ``self.LOGFILE``.

     An empty prediction list scores precision 0 and an empty truth list
     scores recall 0 instead of raising ZeroDivisionError; with no samples
     at all the result is ``(0.0, 0.0)``.

     :param yTrue: list of true label lists, one per sample.
     :param yPre: list of predicted label lists, parallel to ``yTrue``.
     :return: ``(precision, recall)`` averaged over the samples and rounded
         to 5 decimals.
     """
     testNum = len(yTrue)
     if testNum == 0:
         # Nothing to average; avoid dividing by zero below.
         return 0.0, 0.0
     totalPrecision = 0
     totalRecall = 0
     for i in range(testNum):
         truth = yTrue[i]
         predicted = yPre[i]
         hits = len(ListUtil.list_intersect([truth, predicted])) * 1.0
         if len(predicted) > 0:
             perPrecision = round(hits / len(predicted), 5)
         else:
             perPrecision = 0.0
         if len(truth) > 0:
             perRecall = round(hits / len(truth), 5)
         else:
             perRecall = 0.0
         FileUtil.print_string(
             str(i) + 'th precision:' + str(perPrecision) + ',recall:' +
             str(perRecall), self.LOGFILE, True)
         totalPrecision += perPrecision
         totalRecall += perRecall
     precison = round(totalPrecision / testNum, 5)
     recall = round(totalRecall / testNum, 5)
     return precison, recall
Exemplo n.º 5
0
 def store_medicine_weight(self, filter=False, filterSet=set([])):
     """Record, on each medicine node, which treatments use that medicine.

     For every treatment row the prescription's ``name`` column is read as a
     comma-separated list of medicine ids.  For each medicine node found in
     Neo4j, the treatment id is merged into the node's id list and the
     node's weight is set to the number of distinct ids.

     :param filter: when equal to True, rows whose id is not in
         ``filterSet`` are skipped.
     :param filterSet: treatment ids to keep when filtering is enabled.
     """
     FileUtil.print_string(u"开始存储每个中药节点的出现的次数...", True)
     mysqlDb = MysqlOpt()
     try:
         neoDb = Neo4jOpt()
         treamentMysql = mysqlDb.select('treatment', 'id,prescriptionId')
         for treamentUnit in treamentMysql:
             tid = int(treamentUnit[0])
             # Filter before touching prescriptionId so excluded rows cost
             # neither a conversion nor a query.
             if filter == True and tid not in filterSet:
                 continue
             pid = int(treamentUnit[1])
             priscriptionMysql = mysqlDb.select('prescription', 'name',
                                                'id = ' + str(pid))
             if len(priscriptionMysql) == 0:
                 continue
             # The id list may end with a comma; drop every empty token,
             # not just a trailing one, before converting to int.
             medicineIds = [m for m in priscriptionMysql[0][0].split(",")
                            if m.strip()]
             for medicineId in medicineIds:
                 medicineNodes = neoDb.selectNodeElementsFromDB(
                     self.MEDICINETAG,
                     condition=[],
                     properties={self.NODEID: int(medicineId)})
                 if len(medicineNodes) == 0:
                     continue
                 medicineNode = medicineNodes[0]
                 # Normalise stored ids to int, add this treatment and
                 # de-duplicate; the weight is the distinct-id count.
                 tids = [int(i) for i in medicineNode[self.NODETID]]
                 tids.append(tid)
                 ntids = list(set(tids))
                 neoDb.updateKeyInNode(medicineNode,
                                       properties={
                                           self.NODETID: ntids,
                                           self.NODEWEIGHT: len(ntids)
                                       })
     finally:
         # Release the MySQL handle even when a row fails to process.
         mysqlDb.close()
Exemplo n.º 6
0
 def store_medicine(self, filter=False, filterSet=set([])):
     """Recreate the medicine nodes in Neo4j from the MySQL ``medicine`` table.

     Existing nodes labelled ``self.MEDICINETAG`` (and the relationships
     attached to them) are detach-deleted first.  Each selected row then
     becomes a node with its id, its name, weight 0 and an empty
     treatment-id list.

     :param filter: when equal to True, rows whose id is not in
         ``filterSet`` are skipped.
     :param filterSet: medicine ids to keep when filtering is enabled.
     """
     mysqlDb = MysqlOpt()
     try:
         neoDb = Neo4jOpt()
         # Remove existing medicine nodes together with their relationships.
         cql = "MATCH (n:" + self.MEDICINETAG + ") DETACH DELETE n"
         neoDb.graph.data(cql)
         FileUtil.print_string(u"开始存储中药节点...", True)
         mysqlDatas = mysqlDb.select('medicine', 'id,name')
         for unit in mysqlDatas:
             # Named medicineId rather than id to avoid shadowing the builtin.
             medicineId = unit[0]
             name = unit[1]
             if filter == True and medicineId not in filterSet:
                 continue
             neoDb.createNode(
                 [self.MEDICINETAG], {
                     self.NODEID: medicineId,
                     self.NODENAME: name,
                     self.NODEWEIGHT: 0,
                     self.NODETID: []
                 })
     finally:
         # Release the MySQL handle even if a Neo4j call raises.
         mysqlDb.close()
Exemplo n.º 7
0
 def store_tonguezhi2medicine(self, filter=False, filterSet=set([])):
     """Store tongue-zhi -> medicine relationships from the treatment table.

     For every treatment row the prescription's ``name`` column is read as a
     comma-separated list of medicine ids, and the row's tongue-zhi node is
     linked to each medicine node found in Neo4j.  If the relationship
     already exists, the treatment id is merged into its id list and the
     weight is set to the number of distinct ids; otherwise a relationship
     holding only this treatment id, with weight 1, is created.

     :param filter: when equal to True, rows whose id is not in
         ``filterSet`` are skipped.
     :param filterSet: treatment ids to keep when filtering is enabled.
     """
     FileUtil.print_string(u"开始存储舌质和中药的关系...", True)
     mysqlDb = MysqlOpt()
     try:
         neoDb = Neo4jOpt()
         treamentMysql = mysqlDb.select('treatment',
                                        'id,tongueZhiId,prescriptionId')
         for treamentUnit in treamentMysql:
             tongueZhiId = treamentUnit[1]
             tid = int(treamentUnit[0])
             if filter == True and tid not in filterSet:
                 continue
             priscriptionMysql = mysqlDb.select(
                 'prescription', 'name', 'id = ' + str(treamentUnit[2]))
             if len(priscriptionMysql) == 0:
                 continue
             # The id list may end with a comma; drop every empty token,
             # not just a trailing one, before converting to int.
             medicineIds = [m for m in priscriptionMysql[0][0].split(",")
                            if m.strip()]
             if not medicineIds:
                 continue
             # The tongue-zhi node does not depend on the medicine being
             # processed, so it is looked up once per treatment.
             tongueZhiNodes = neoDb.selectNodeElementsFromDB(
                 self.TONGUEZHITAG,
                 condition=[],
                 properties={self.NODEID: int(tongueZhiId)})
             if len(tongueZhiNodes) == 0:
                 continue
             tongueZhiNode = tongueZhiNodes[0]
             for medicineId in medicineIds:
                 medicineNodes = neoDb.selectNodeElementsFromDB(
                     self.MEDICINETAG,
                     condition=[],
                     properties={self.NODEID: int(medicineId)})
                 if len(medicineNodes) == 0:
                     continue
                 medicineNode = medicineNodes[0]
                 relations = neoDb.selectRelationshipsFromDB(
                     tongueZhiNode, self.TONGUEZHI2MEDICINETAG,
                     medicineNode)
                 if len(relations) > 0:
                     # Existing relationship: merge this treatment id and
                     # recompute the weight as the count of distinct ids.
                     relation = relations[0]
                     tids = relation[self.RELATIONTID]
                     tids.append(tid)
                     ntids = list(set(tids))
                     neoDb.updateKeyInRelationship(
                         relation,
                         properties={
                             self.RELATIONTID: ntids,
                             self.RELATIONWEIGHT: len(ntids)
                         })
                 else:
                     neoDb.createRelationship(
                         self.TONGUEZHI2MEDICINETAG,
                         tongueZhiNode,
                         medicineNode,
                         propertyDic={
                             self.RELATIONTID: [tid],
                             self.RELATIONWEIGHT: 1
                         })
     finally:
         # Release the MySQL handle even when a row fails to process.
         mysqlDb.close()
Exemplo n.º 8
0
 def cross_validation(self, in_file):
     """K-fold cross-validation of medicine prediction over a text file.

     Each line of ``in_file`` is split on ``"||"`` into an id, a
     comma-separated symptom list and a comma-separated medicine list.  For
     every fold, ``self.predict_medicine`` is called on each test record's
     symptoms; non-empty predictions are appended to ``../data/pre.txt``
     and scored with ``self.classification_evaluate``.  Precision, recall
     and F1 are logged to ``self.LOGFILE``.

     This method does not build the graph itself (the ``createGraph`` calls
     were disabled), so predictions presumably rely on a graph that already
     exists -- confirm before running.

     :param in_file: path of the UTF-8 encoded data file.
     """
     preFile = '../data/pre.txt'
     randomState = 43
     idList = []
     symptomList = []
     medicineList = []
     with codecs.open(in_file, 'r', encoding='utf-8') as f:
         for line in f:
             # NOTE(review): the line is split without stripping its line
             # ending, so the last field keeps any trailing newline --
             # confirm whether the stored names rely on that.
             splits = line.split("||")
             symptoms = splits[1].split(',')
             medicines = splits[2].split(',')
             idList.append(splits[0])
             symptomList.append(symptoms)
             medicineList.append(medicines)
     kf = KFold(len(idList),
                n_folds=self.CRVNUMFOLDER,
                shuffle=True,
                random_state=randomState)
     numFold = 0
     for trainIndex, testIndex in kf:
         numFold += 1
         # NOTE(review): this call omits self.LOGFILE while the next one
         # passes it; both are kept as found -- confirm which is intended.
         FileUtil.print_string(str(numFold) + 'th validation:', True)
         yTrue = []
         yPre = []
         FileUtil.print_string(
             str(numFold) + 'th validation:', self.LOGFILE, True)
         for i in testIndex:
             testId = idList[i]
             # Single-argument print calls behave the same on Python 2 and 3.
             print("**********************************************")
             print(str(testId))
             preMedicineList = self.predict_medicine(symptomList[i])
             # Records with no prediction are excluded from the evaluation.
             if len(preMedicineList) == 0:
                 continue
             # Every medicine is followed by a comma, including the last.
             preString = str(testId) + ' ' + ''.join(
                 m + ',' for m in preMedicineList)
             FileUtil.add_string(preFile, preString)
             yTrue.append(medicineList[i])
             yPre.append(preMedicineList)
         precision, recall = self.classification_evaluate(yTrue, yPre)
         FileUtil.print_string('precision:' + str(precision), self.LOGFILE,
                               True)
         FileUtil.print_string('recall:' + str(recall), self.LOGFILE, True)
         # F1 is defined as 0 when precision and recall are both 0; the
         # plain formula would raise ZeroDivisionError in that case.
         denominator = precision + recall
         if denominator:
             F1 = 2 * precision * recall / denominator
         else:
             F1 = 0.0
         FileUtil.print_string('F1:' + str(F1), self.LOGFILE, True)
Exemplo n.º 9
0
 def createGraph(self, idTrainList, symptomTrainList, medicineTrainList):
     """Rebuild the whole Neo4j graph from three parallel training lists.

     All existing nodes are detach-deleted first.  For record ``i``:

     * every symptom and medicine name gets a node; if the node already
       exists, the record id is merged into its id list and its weight is
       set to the number of distinct ids;
     * every (symptom, medicine) pair gets a symptom->medicine
       relationship, created or merged the same way;
     * every pair of different symptoms, taken once per position pair, gets
       a symptom-symptom relationship, looked up with
       ``bidirectional=True``.

     ``self.relationFilter()`` is called once after all records.

     :param idTrainList: record ids.
     :param symptomTrainList: list of symptom-name lists, parallel to
         ``idTrainList``.
     :param medicineTrainList: list of medicine-name lists, parallel to
         ``idTrainList``.
     """
     FileUtil.print_string(u"开始创建图谱...", self.LOGFILE, True)
     neoDb = Neo4jOpt()
     neoDb.graph.data('MATCH (n) DETACH DELETE n')

     def merge_tid(tids, tid):
         # Add tid to the stored id list and return (distinct ids, weight).
         tids.append(tid)
         ntids = list(set(tids))
         return ntids, len(ntids)

     def find_node(tag, name):
         # Fetch the first node with the given label and name, or None.
         return neoDb.selectFirstNodeElementsFromDB(
             tag, properties={self.NODENAME: name})

     def record_node(tag, name, tid):
         # Create the node for this name, or merge tid into the existing one.
         node = find_node(tag, name)
         if node is None:
             neoDb.createNode(
                 [tag], {
                     self.NODENAME: name,
                     self.NODETID: [tid],
                     self.NODEWEIGHT: 1
                 })
         else:
             ntids, nweight = merge_tid(node[self.NODETID], tid)
             neoDb.updateKeyInNode(node, {
                 self.NODETID: ntids,
                 self.NODEWEIGHT: nweight
             })

     def record_relation(tag, startNode, endNode, tid, **lookupOptions):
         # Create the relationship, or merge tid into the existing one.
         rel = neoDb.selectFirstRelationshipsFromDB(
             startNode, tag, endNode, **lookupOptions)
         if rel is None:
             neoDb.createRelationship(tag,
                                      startNode,
                                      endNode,
                                      propertyDic={
                                          self.RELATIONTID: [tid],
                                          self.RELATIONWEIGHT: 1
                                      })
         else:
             ntids, nweight = merge_tid(rel[self.RELATIONTID], tid)
             neoDb.updateKeyInRelationship(rel,
                                           properties={
                                               self.RELATIONTID: ntids,
                                               self.RELATIONWEIGHT: nweight
                                           })

     for i, tid in enumerate(idTrainList):
         print("读取id:" + str(tid))
         symptoms = symptomTrainList[i]
         medicines = medicineTrainList[i]

         # Nodes first, so the relationship pass below can look them up.
         for s in symptoms:
             record_node(self.SYMPTOMTAG, s, tid)
         for m in medicines:
             record_node(self.MEDICINETAG, m, tid)

         for j, s in enumerate(symptoms):
             symptomNode = find_node(self.SYMPTOMTAG, s)
             # symptom -> medicine
             for m in medicines:
                 medicineNode = find_node(self.MEDICINETAG, m)
                 record_relation(self.SYMPTOM2MEDICINETAG, symptomNode,
                                 medicineNode, tid)
             # symptom - symptom: pair only with later entries so each
             # position pair is handled once; skip identical names.
             for s1 in symptoms[j + 1:]:
                 if s == s1:
                     continue
                 symptomNode1 = find_node(self.SYMPTOMTAG, s1)
                 record_relation(self.SYMPTOM2SYMPTOMTAG,
                                 symptomNode,
                                 symptomNode1,
                                 tid,
                                 bidirectional=True)
     self.relationFilter()
Exemplo n.º 10
0
 def cross_validation(self):
     """K-fold cross-validation of medicine prediction over the treatment table.

     For every fold the graph is rebuilt from the training treatment ids via
     ``self.createGraph(filter=True, ...)``.  For each test treatment the
     true medicine names and the symptom, tongue-zhi, tongue-tai and pulse
     names are read from MySQL, ``self.input_multiple_syptomes`` produces
     the prediction, and the prediction is appended to ``pre.txt``.
     Precision, recall and F1 of each fold are printed through
     ``FileUtil.print_string``.
     """
     preFile = 'pre.txt'
     randomState = 51
     mysqlOpt = MysqlOpt()

     def names_by_ids(table, rawIds):
         # Return the names of the rows whose ids are in a comma-separated
         # list.  Empty tokens are dropped first: a trailing comma would
         # otherwise produce "id IN (1,2,)" and an empty list "id IN ()".
         ids = ','.join(p.strip() for p in rawIds.split(',') if p.strip())
         if not ids:
             return []
         rows = mysqlOpt.select(table, 'name', 'id IN (' + ids + ')')
         return [row[0] for row in rows]

     def name_by_id(table, rowId):
         # Return the name of one row, or '' when the row does not exist.
         rows = mysqlOpt.select(table, 'name', 'id = ' + str(rowId))
         if len(rows) > 0:
             return rows[0][0]
         return ''

     try:
         tidList = [row[0] for row in mysqlOpt.select('treatment', 'id')]
         kf = KFold(len(tidList),
                    n_folds=self.CRVNUMFOLDER,
                    shuffle=True,
                    random_state=randomState)
         numFold = 0
         for trainIndex, testIndex in kf:
             numFold += 1
             FileUtil.print_string(str(numFold) + 'th validation:', True)
             trainList = [tidList[i] for i in trainIndex]
             self.createGraph(filter=True,
                              treamentIdfilterSet=set(trainList))
             yTrue = []
             yPre = []
             FileUtil.add_string(preFile, str(numFold) + 'th validation:')
             for i in testIndex:
                 tid = tidList[i]
                 # Single-argument print call works on Python 2 and 3.
                 print(str(tid))
                 perTreament = mysqlOpt.select(
                     'treatment',
                     'prescriptionId,symptomIds,tongueZhiId,tongueTaiId,pulseId',
                     'id=' + str(tid))
                 pids = perTreament[0][0]
                 symptomIds = perTreament[0][1]
                 tZhiId = perTreament[0][2]
                 tTaiId = perTreament[0][3]
                 pulseId = perTreament[0][4]
                 # prescription.name holds the comma-separated medicine ids.
                 perPrecription = mysqlOpt.select('prescription', 'name',
                                                  'id = ' + str(pids))
                 mids = perPrecription[0][0]
                 yTrue.append(names_by_ids('medicine', mids))
                 symptomList = names_by_ids('symptom', symptomIds)
                 tongueZhi = name_by_id('tongueZhi', tZhiId)
                 tongueTai = name_by_id('tongueTai', tTaiId)
                 pulse = name_by_id('pulse', pulseId)
                 preMedicineList = self.input_multiple_syptomes(
                     symptomList, tongueZhi, tongueTai, pulse)
                 # Every medicine is followed by a comma, including the last.
                 preString = str(tid) + ' ' + ''.join(
                     m + ',' for m in preMedicineList)
                 FileUtil.add_string(preFile, preString)
                 yPre.append(preMedicineList)
             precision, recall = self.classification_evaluate(yTrue, yPre)
             FileUtil.print_string('precision:' + str(precision), True)
             FileUtil.print_string('recall:' + str(recall), True)
             # F1 is defined as 0 when precision and recall are both 0; the
             # plain formula would raise ZeroDivisionError in that case.
             denominator = precision + recall
             if denominator:
                 F1 = 2 * precision * recall / denominator
             else:
                 F1 = 0.0
             FileUtil.print_string('F1:' + str(F1), True)
     finally:
         # This handle was previously never released; close it on every exit.
         mysqlOpt.close()