コード例 #1
0
 def store_symptom2tonguezhi(self, filter=False, filterSet=frozenset()):
     """Create or update symptom -> tongue-body (tongueZhi) relationships.

     Reads ``id``, ``symptomIds`` and ``tongueZhiId`` from the MySQL
     ``treatment`` table. Every symptom id of a treatment (comma-separated
     ``symptomIds``) is linked to the treatment's tongue-body node with a
     ``self.SYMPTOM2TONGUEZHITAG`` relationship. The relationship keeps the
     distinct ids of the supporting treatments in ``self.RELATIONTID`` and
     their count in ``self.RELATIONWEIGHT``. Pairs whose nodes are not found
     in Neo4j are skipped.

     :param filter: when true, only treatments whose id is in ``filterSet``
         are processed.
     :param filterSet: treatment ids to keep when ``filter`` is true. An
         immutable empty set by default, so no state is shared across calls.
     """
     FileUtil.print_string(u"开始存储症状和舌质的关系...", True)
     mysqlDb = MysqlOpt()
     neoDb = Neo4jOpt()
     try:
         treatmentRows = mysqlDb.select('treatment',
                                        'id,symptomIds,tongueZhiId')
         for treatmentRow in treatmentRows:
             tid = int(treatmentRow[0])
             # Check the filter first so excluded rows are never parsed.
             if filter and tid not in filterSet:
                 continue
             # Ignore blank tokens (empty field, trailing or doubled commas)
             # instead of failing on int('').
             symptomIds = [s for s in (treatmentRow[1] or "").split(",")
                           if s.strip()]
             if not symptomIds:
                 continue
             # The tongue-body node is the same for every symptom of this
             # treatment, so query it once instead of once per symptom.
             tongueZhiNodes = neoDb.selectNodeElementsFromDB(
                 self.TONGUEZHITAG,
                 condition=[],
                 properties={self.NODEID: int(treatmentRow[2])})
             if not tongueZhiNodes:
                 continue
             tongueZhiNode = tongueZhiNodes[0]
             for symptomId in symptomIds:
                 symptomNodes = neoDb.selectNodeElementsFromDB(
                     self.SYMPTOMTAG,
                     condition=[],
                     properties={self.NODEID: int(symptomId)})
                 if not symptomNodes:
                     continue
                 symptomNode = symptomNodes[0]
                 # If the relationship exists, add this treatment id and
                 # recompute the weight; otherwise create it with weight 1.
                 relations = neoDb.selectRelationshipsFromDB(
                     symptomNode, self.SYMPTOM2TONGUEZHITAG, tongueZhiNode)
                 if relations:
                     relation = relations[0]
                     # Build a new list rather than mutating the stored one.
                     ntids = list(
                         set(relation[self.RELATIONTID] or []) | {tid})
                     neoDb.updateKeyInRelationship(relation,
                                                   properties={
                                                       self.RELATIONTID:
                                                       ntids,
                                                       self.RELATIONWEIGHT:
                                                       len(ntids)
                                                   })
                 else:
                     neoDb.createRelationship(self.SYMPTOM2TONGUEZHITAG,
                                              symptomNode,
                                              tongueZhiNode,
                                              propertyDic={
                                                  self.RELATIONTID: [tid],
                                                  self.RELATIONWEIGHT: 1
                                              })
     finally:
         # Release the MySQL connection even if a query fails.
         mysqlDb.close()
コード例 #2
0
 def get_pulse_id_set(self, treatmentIdSet):
     """Return the distinct ``pulseId`` values of the given treatments.

     :param treatmentIdSet: treatment ids to include; compared against
         ``int(treatment.id)``.
     :return: set of ``pulseId`` values exactly as returned by MySQL.
     """
     mysqlDb = MysqlOpt()
     try:
         rows = mysqlDb.select('treatment', 'id,pulseId')
         return {row[1] for row in rows if int(row[0]) in treatmentIdSet}
     finally:
         # Always release the connection, also when a conversion fails.
         mysqlDb.close()
コード例 #3
0
 def get_tonguetai_id_set(self, treatmentIdSet):
     """Return the distinct ``tongueTaiId`` values of the given treatments.

     :param treatmentIdSet: treatment ids to include; compared against
         ``int(treatment.id)``.
     :return: set of ``tongueTaiId`` values exactly as returned by MySQL.
     """
     mysqlDb = MysqlOpt()
     try:
         rows = mysqlDb.select('treatment', 'id,tongueTaiId')
         return {row[1] for row in rows if int(row[0]) in treatmentIdSet}
     finally:
         # Always release the connection, also when a conversion fails.
         mysqlDb.close()
コード例 #4
0
 def get_symptom_id_set(self, treatmentIdSet):
     """Return the union of the symptom ids of the given treatments.

     ``symptomIds`` is parsed as a comma-separated list of integers. Blank
     tokens (empty field, trailing or doubled commas) are ignored.

     :param treatmentIdSet: treatment ids to include; compared against
         ``int(treatment.id)``.
     :return: set of ``int`` symptom ids.
     """
     mysqlDb = MysqlOpt()
     try:
         idSet = set()
         for row in mysqlDb.select('treatment', 'id,symptomIds'):
             if int(row[0]) not in treatmentIdSet:
                 continue
             # update() grows the set in place instead of allocating a new
             # union set for every row.
             idSet.update(int(s) for s in (row[1] or "").split(",")
                          if s.strip())
         return idSet
     finally:
         mysqlDb.close()
コード例 #5
0
 def store_pulse(self, filter=False, filterSet=frozenset()):
     """Rebuild the pulse nodes in Neo4j from the MySQL ``pulse`` table.

     All existing ``self.PULSETAG`` nodes are removed first (``DETACH
     DELETE`` also drops their relationships). Then one node carrying
     ``self.NODEID`` and ``self.NODENAME`` is created per ``pulse`` row.

     :param filter: when true, only pulses whose id is in ``filterSet`` are
         created.
     :param filterSet: pulse ids to keep when ``filter`` is true. An
         immutable empty set by default, so no state is shared across calls.
     """
     mysqlDb = MysqlOpt()
     neoDb = Neo4jOpt()
     try:
         neoDb.graph.data("MATCH (n:" + self.PULSETAG + ") DETACH DELETE n")
         FileUtil.print_string(u"开始存储脉搏节点...", True)
         for pulseId, name in mysqlDb.select('pulse', 'id,name'):
             if filter and pulseId not in filterSet:
                 continue
             neoDb.createNode([self.PULSETAG], {
                 self.NODEID: pulseId,
                 self.NODENAME: name
             })
     finally:
         # Release the MySQL connection even if Neo4j or MySQL fails.
         mysqlDb.close()
コード例 #6
0
 def get_medicine_id_set(self, treatmentIdSet):
     """Return the union of the medicine ids prescribed to the given treatments.

     For each selected treatment the ``prescription`` row with id
     ``prescriptionId`` is read. Its ``name`` column is parsed as a
     comma-separated list of integer medicine ids. Blank tokens (empty field,
     trailing or doubled commas) are ignored. Treatments whose prescription
     row is missing contribute nothing.

     :param treatmentIdSet: treatment ids to include; compared against
         ``int(treatment.id)``.
     :return: set of ``int`` medicine ids.
     """
     mysqlDb = MysqlOpt()
     try:
         idSet = set()
         for row in mysqlDb.select('treatment', 'id,prescriptionId'):
             if int(row[0]) not in treatmentIdSet:
                 continue
             prescriptionRows = mysqlDb.select('prescription', 'name',
                                               'id = ' + str(row[1]))
             if not prescriptionRows:
                 continue
             # Filtering blank tokens covers the trailing-comma case as well
             # as doubled commas and an empty/NULL name.
             idSet.update(int(m)
                          for m in (prescriptionRows[0][0] or "").split(",")
                          if m.strip())
         return idSet
     finally:
         mysqlDb.close()
コード例 #7
0
 def store_medicine(self, filter=False, filterSet=frozenset()):
     """Rebuild the medicine nodes in Neo4j from the MySQL ``medicine`` table.

     All existing ``self.MEDICINETAG`` nodes are removed first (``DETACH
     DELETE`` also drops their relationships). Each ``medicine`` row then
     becomes a node with ``self.NODEID``, ``self.NODENAME``, a
     ``self.NODEWEIGHT`` of 0 and an empty ``self.NODETID`` list. Those two
     counters are what ``store_medicine_weight`` updates.

     :param filter: when true, only medicines whose id is in ``filterSet``
         are created.
     :param filterSet: medicine ids to keep when ``filter`` is true. An
         immutable empty set by default, so no state is shared across calls.
     """
     mysqlDb = MysqlOpt()
     neoDb = Neo4jOpt()
     try:
         # Delete existing medicine nodes together with their relationships.
         neoDb.graph.data(
             "MATCH (n:" + self.MEDICINETAG + ") DETACH DELETE n")
         FileUtil.print_string(u"开始存储中药节点...", True)
         for medicineId, name in mysqlDb.select('medicine', 'id,name'):
             if filter and medicineId not in filterSet:
                 continue
             neoDb.createNode(
                 [self.MEDICINETAG], {
                     self.NODEID: medicineId,
                     self.NODENAME: name,
                     self.NODEWEIGHT: 0,
                     # A fresh list per node.
                     self.NODETID: []
                 })
     finally:
         # Release the MySQL connection even if Neo4j or MySQL fails.
         mysqlDb.close()
コード例 #8
0
 def cleanup_data(self):
     """Delete treatments whose prescription row does not exist.

     Every ``treatment`` row whose ``prescriptionId`` matches no row of the
     ``prescription`` table is removed from MySQL.
     """
     mysqlDb = MysqlOpt()
     try:
         for row in mysqlDb.select('treatment', 'id,prescriptionId'):
             tid = int(row[0])
             prescriptionRows = mysqlDb.select('prescription', 'name',
                                               'id = ' + str(row[1]))
             if not prescriptionRows:
                 mysqlDb.delete('treatment', 'id=' + str(tid))
     finally:
         # The connection used to be left open; always release it.
         mysqlDb.close()
コード例 #9
0
 def store_medicine_weight(self, filter=False, filterSet=frozenset()):
     """Record, on each medicine node, which treatments prescribe it.

     For every treatment, each medicine id of its prescription (the
     comma-separated ``prescription.name`` column) gets the treatment id
     added to its ``self.NODETID`` list. ``self.NODEWEIGHT`` is set to the
     number of distinct ids. Treatments without a prescription row and
     medicines without a Neo4j node are skipped.

     :param filter: when true, only treatments whose id is in ``filterSet``
         are processed.
     :param filterSet: treatment ids to keep when ``filter`` is true. An
         immutable empty set by default, so no state is shared across calls.
     """
     FileUtil.print_string(u"开始存储每个中药节点的出现的次数...", True)
     mysqlDb = MysqlOpt()
     neoDb = Neo4jOpt()
     try:
         for row in mysqlDb.select('treatment', 'id,prescriptionId'):
             tid = int(row[0])
             # Filter before converting prescriptionId so excluded rows can
             # never abort the run.
             if filter and tid not in filterSet:
                 continue
             pid = int(row[1])
             prescriptionRows = mysqlDb.select('prescription', 'name',
                                               'id = ' + str(pid))
             if not prescriptionRows:
                 continue
             # Skip blank tokens (trailing or doubled commas, empty name).
             medicineIds = [m for m in (prescriptionRows[0][0] or "").split(",")
                            if m.strip()]
             for medicineId in medicineIds:
                 medicineNodes = neoDb.selectNodeElementsFromDB(
                     self.MEDICINETAG,
                     condition=[],
                     properties={self.NODEID: int(medicineId)})
                 if not medicineNodes:
                     continue
                 medicineNode = medicineNodes[0]
                 # Treat a missing tid property as an empty list.
                 tids = set(int(i) for i in (medicineNode[self.NODETID] or []))
                 tids.add(tid)
                 ntids = list(tids)
                 neoDb.updateKeyInNode(medicineNode,
                                       properties={
                                           self.NODETID: ntids,
                                           self.NODEWEIGHT: len(ntids)
                                       })
     finally:
         mysqlDb.close()
コード例 #10
0
 def store_tonguezhi2medicine(self, filter=False, filterSet=frozenset()):
     """Create or update tongue-body (tongueZhi) -> medicine relationships.

     For each treatment, every medicine id of its prescription (the
     comma-separated ``prescription.name`` column) is linked from the
     treatment's tongue-body node with a ``self.TONGUEZHI2MEDICINETAG``
     relationship. The relationship keeps the distinct supporting treatment
     ids in ``self.RELATIONTID`` and their count in ``self.RELATIONWEIGHT``.
     Treatments without a prescription row, and pairs whose nodes are not
     found in Neo4j, are skipped.

     :param filter: when true, only treatments whose id is in ``filterSet``
         are processed.
     :param filterSet: treatment ids to keep when ``filter`` is true. An
         immutable empty set by default, so no state is shared across calls.
     """
     FileUtil.print_string(u"开始存储舌质和中药的关系...", True)
     mysqlDb = MysqlOpt()
     neoDb = Neo4jOpt()
     try:
         treatmentRows = mysqlDb.select('treatment',
                                        'id,tongueZhiId,prescriptionId')
         for treatmentRow in treatmentRows:
             tid = int(treatmentRow[0])
             if filter and tid not in filterSet:
                 continue
             prescriptionRows = mysqlDb.select(
                 'prescription', 'name', 'id = ' + str(treatmentRow[2]))
             if not prescriptionRows:
                 continue
             # Skip blank tokens (trailing or doubled commas, empty name).
             medicineIds = [m for m in (prescriptionRows[0][0] or "").split(",")
                            if m.strip()]
             if not medicineIds:
                 continue
             # The tongue-body node is the same for every medicine of this
             # treatment, so query it once instead of once per medicine.
             tongueZhiNodes = neoDb.selectNodeElementsFromDB(
                 self.TONGUEZHITAG,
                 condition=[],
                 properties={self.NODEID: int(treatmentRow[1])})
             if not tongueZhiNodes:
                 continue
             tongueZhiNode = tongueZhiNodes[0]
             for medicineId in medicineIds:
                 medicineNodes = neoDb.selectNodeElementsFromDB(
                     self.MEDICINETAG,
                     condition=[],
                     properties={self.NODEID: int(medicineId)})
                 if not medicineNodes:
                     continue
                 medicineNode = medicineNodes[0]
                 # If the relationship exists, add this treatment id and
                 # recompute the weight; otherwise create it with weight 1.
                 relations = neoDb.selectRelationshipsFromDB(
                     tongueZhiNode, self.TONGUEZHI2MEDICINETAG,
                     medicineNode)
                 if relations:
                     relation = relations[0]
                     # Build a new list rather than mutating the stored one.
                     ntids = list(
                         set(relation[self.RELATIONTID] or []) | {tid})
                     neoDb.updateKeyInRelationship(relation,
                                                   properties={
                                                       self.RELATIONTID:
                                                       ntids,
                                                       self.RELATIONWEIGHT:
                                                       len(ntids)
                                                   })
                 else:
                     neoDb.createRelationship(self.TONGUEZHI2MEDICINETAG,
                                              tongueZhiNode,
                                              medicineNode,
                                              propertyDic={
                                                  self.RELATIONTID: [tid],
                                                  self.RELATIONWEIGHT: 1
                                              })
     finally:
         # Release the MySQL connection even if a query fails.
         mysqlDb.close()
コード例 #11
0
ファイル: naive_bayes_dm.py プロジェクト: chenhehong/TCM_KG
 @staticmethod
 def _cv_select_names(mysqlOpt, table, idCsv):
     """Return the ``name`` values of ``table`` rows listed in ``idCsv``.

     ``idCsv`` is a comma-separated id string. Blank tokens (empty field,
     trailing or doubled commas) are dropped, and every id goes through
     ``int`` before it is placed in the SQL text. If no id remains, ``[]`` is
     returned without querying, because ``IN ()`` is not valid SQL.
     """
     ids = [str(int(x)) for x in (idCsv or '').split(',') if x.strip()]
     if not ids:
         return []
     rows = mysqlOpt.select(table, 'name', 'id IN (' + ','.join(ids) + ')')
     return [row[0] for row in rows]

 @staticmethod
 def _cv_select_name(mysqlOpt, table, rowId):
     """Return the ``name`` of the ``table`` row with id ``rowId``.

     Returns ``''`` when no such row exists.
     """
     rows = mysqlOpt.select(table, 'name', 'id = ' + str(rowId))
     return rows[0][0] if rows else ''

 def _cv_load_test_case(self, mysqlOpt, tid):
     """Fetch the ground truth and model inputs of one treatment.

     :return: ``(trueMedicineList, symptomList, tongueZhi, tongueTai, pulse)``.
         The two lists hold names. A missing tongue/pulse name is ``''``.
     :raises IndexError: if the treatment or its prescription row is missing.
     """
     pids, symptomIds, tZhiId, tTaiId, pulseId = mysqlOpt.select(
         'treatment',
         'prescriptionId,symptomIds,tongueZhiId,tongueTaiId,pulseId',
         'id=' + str(tid))[0]
     # The prescription's ``name`` column holds its comma-separated medicine
     # ids, possibly with a trailing comma.
     mids = mysqlOpt.select('prescription', 'name', 'id = ' + str(pids))[0][0]
     trueMedicineList = self._cv_select_names(mysqlOpt, 'medicine', mids)
     symptomList = self._cv_select_names(mysqlOpt, 'symptom', symptomIds)
     tongueZhi = self._cv_select_name(mysqlOpt, 'tongueZhi', tZhiId)
     tongueTai = self._cv_select_name(mysqlOpt, 'tongueTai', tTaiId)
     pulse = self._cv_select_name(mysqlOpt, 'pulse', pulseId)
     return trueMedicineList, symptomList, tongueZhi, tongueTai, pulse

 def cross_validation(self):
     """Evaluate the medicine recommender with k-fold cross-validation.

     Treatment ids are split by ``KFold`` into ``self.CRVNUMFOLDER`` shuffled
     folds with a fixed seed. For each fold:

     - the graph is rebuilt from the training treatments only;
     - every test treatment's symptoms, tongue body, tongue coating and
       pulse are passed to ``self.input_multiple_syptomes``;
     - the predictions are appended to ``pre.txt``;
     - precision, recall and F1 are printed.
     """
     preFile = 'pre.txt'
     randomState = 51
     mysqlOpt = MysqlOpt()
     treamentIds = mysqlOpt.select('treatment', 'id')
     tidList = [row[0] for row in treamentIds]
     kf = KFold(len(tidList),
                n_folds=self.CRVNUMFOLDER,
                shuffle=True,
                random_state=randomState)
     for numFold, (trainIndex, testIndex) in enumerate(kf, 1):
         FileUtil.print_string(str(numFold) + 'th validation:', True)
         trainList = [tidList[i] for i in trainIndex]
         # Rebuild the knowledge graph from the training treatments only.
         self.createGraph(filter=True, treamentIdfilterSet=set(trainList))
         yTrue = []
         yPre = []
         FileUtil.add_string(preFile, str(numFold) + 'th validation:')
         for i in testIndex:
             tid = tidList[i]
             print(str(tid))
             (trueMedicineList, symptomList, tongueZhi, tongueTai,
              pulse) = self._cv_load_test_case(mysqlOpt, tid)
             yTrue.append(trueMedicineList)
             preMedicineList = self.input_multiple_syptomes(
                 symptomList, tongueZhi, tongueTai, pulse)
             # One line per test treatment: "<tid> m1,m2,...," with each
             # predicted name followed by a comma.
             FileUtil.add_string(
                 preFile,
                 str(tid) + ' ' + ''.join(m + ',' for m in preMedicineList))
             yPre.append(preMedicineList)
         precision, recall = self.classification_evaluate(yTrue, yPre)
         FileUtil.print_string('precision:' + str(precision), True)
         FileUtil.print_string('recall:' + str(recall), True)
         # F1 is 0 when precision and recall are both 0. The bare formula
         # would raise ZeroDivisionError for such a fold.
         if precision + recall > 0:
             F1 = 2 * precision * recall / (precision + recall)
         else:
             F1 = 0.0
         FileUtil.print_string('F1:' + str(F1), True)