def cleanup_data(self):
    """Delete treatment rows whose prescription no longer exists.

    Reads every ``treatment`` row's ``id`` and ``prescriptionId``; when the
    matching ``prescription`` lookup returns no rows, that treatment row is
    deleted. The MySQL connection is always closed, even on error.
    """
    mysqlDb = MysqlOpt()
    try:
        treatmentRows = mysqlDb.select('treatment', 'id,prescriptionId')
        for treatmentRow in treatmentRows:
            tid = int(treatmentRow[0])
            pid = treatmentRow[1]
            prescriptionRows = mysqlDb.select('prescription', 'name',
                                              'id = ' + str(pid))
            if not prescriptionRows:
                # Orphaned treatment: its prescription is missing.
                mysqlDb.delete('treatment', 'id=' + str(tid))
    finally:
        # The original version never released this connection.
        mysqlDb.close()
def store_symptom2tonguezhi(self, filter=False, filterSet=frozenset()):
    """Create or update symptom -> tongueZhi relationships in Neo4j.

    For every ``treatment`` row, each symptom id listed in its comma-separated
    ``symptomIds`` column is linked to the row's ``tongueZhiId`` node with a
    ``self.SYMPTOM2TONGUEZHITAG`` relationship. If the relationship already
    exists, the treatment id is added to its ``self.RELATIONTID`` list
    (deduplicated) and ``self.RELATIONWEIGHT`` is set to the number of
    distinct ids; otherwise a new relationship is created with that single
    id and weight 1. Pairs whose nodes are missing from the graph are skipped.

    :param filter: when truthy, only treatments whose id is in ``filterSet``
        are processed.
    :param filterSet: treatment ids to keep when ``filter`` is truthy.
    """
    FileUtil.print_string(u"开始存储症状和舌质的关系...", True)
    mysqlDb = MysqlOpt()
    neoDb = Neo4jOpt()
    try:
        treatmentRows = mysqlDb.select('treatment', 'id,symptomIds,tongueZhiId')
        for treatmentRow in treatmentRows:
            tid = int(treatmentRow[0])
            if filter and tid not in filterSet:
                continue
            # Drop empty tokens (e.g. from a trailing comma) instead of
            # crashing on int('').
            symptomIds = [s for s in treatmentRow[1].split(",") if s.strip()]
            if not symptomIds:
                continue
            # The tongueZhi node depends only on the treatment, so look it
            # up once instead of once per symptom.
            tongueZhiNodes = neoDb.selectNodeElementsFromDB(
                self.TONGUEZHITAG, condition=[],
                properties={self.NODEID: int(treatmentRow[2])})
            if not tongueZhiNodes:
                continue
            tongueZhiNode = tongueZhiNodes[0]
            for symptomId in symptomIds:
                symptomNodes = neoDb.selectNodeElementsFromDB(
                    self.SYMPTOMTAG, condition=[],
                    properties={self.NODEID: int(symptomId)})
                if not symptomNodes:
                    continue
                symptomNode = symptomNodes[0]
                # Existing relationship: merge this treatment id into it;
                # otherwise create the relationship.
                relations = neoDb.selectRelationshipsFromDB(
                    symptomNode, self.SYMPTOM2TONGUEZHITAG, tongueZhiNode)
                if relations:
                    relation = relations[0]
                    tids = list(relation[self.RELATIONTID])
                    tids.append(tid)
                    ntids = list(set(tids))
                    neoDb.updateKeyInRelationship(relation, properties={
                        self.RELATIONTID: ntids,
                        self.RELATIONWEIGHT: len(ntids)
                    })
                else:
                    neoDb.createRelationship(
                        self.SYMPTOM2TONGUEZHITAG, symptomNode, tongueZhiNode,
                        propertyDic={
                            self.RELATIONTID: [tid],
                            self.RELATIONWEIGHT: 1
                        })
    finally:
        mysqlDb.close()
def get_pulse_id_set(self, treatmentIdSet):
    """Return the distinct ``pulseId`` values used by the given treatments.

    :param treatmentIdSet: treatment ids (compared after ``int()``
        conversion of the ``treatment.id`` column) to include.
    :return: set of ``pulseId`` values exactly as returned by MySQL.
    """
    mysqlDb = MysqlOpt()
    try:
        rows = mysqlDb.select('treatment', 'id,pulseId')
        return set(row[1] for row in rows if int(row[0]) in treatmentIdSet)
    finally:
        # The original version never released this connection.
        mysqlDb.close()
def get_tonguetai_id_set(self, treatmentIdSet):
    """Return the distinct ``tongueTaiId`` values used by the given treatments.

    :param treatmentIdSet: treatment ids (compared after ``int()``
        conversion of the ``treatment.id`` column) to include.
    :return: set of ``tongueTaiId`` values exactly as returned by MySQL.
    """
    mysqlDb = MysqlOpt()
    try:
        rows = mysqlDb.select('treatment', 'id,tongueTaiId')
        return set(row[1] for row in rows if int(row[0]) in treatmentIdSet)
    finally:
        # The original version never released this connection.
        mysqlDb.close()
def store_medicine_weight(self, filter=False, filterSet=frozenset()):
    """Record on each medicine node which treatments prescribe it.

    For every ``treatment`` row, the prescription's ``name`` column is read
    as a comma-separated list of medicine ids. Each matching medicine node
    gets the treatment id merged into its ``self.NODETID`` list
    (deduplicated), and ``self.NODEWEIGHT`` is set to the number of distinct
    treatment ids. Treatments without a prescription are skipped, as are
    medicine ids with no node.

    :param filter: when truthy, only treatments whose id is in ``filterSet``
        are processed.
    :param filterSet: treatment ids to keep when ``filter`` is truthy.
    """
    FileUtil.print_string(u"开始存储每个中药节点的出现的次数...", True)
    mysqlDb = MysqlOpt()
    neoDb = Neo4jOpt()
    try:
        treatmentRows = mysqlDb.select('treatment', 'id,prescriptionId')
        for treatmentRow in treatmentRows:
            tid = int(treatmentRow[0])
            if filter and tid not in filterSet:
                continue
            # Convert only after filtering, so a bad prescriptionId on a
            # skipped row cannot abort the whole run.
            pid = int(treatmentRow[1])
            prescriptionRows = mysqlDb.select('prescription', 'name',
                                              'id = ' + str(pid))
            if not prescriptionRows:
                continue
            # Drop empty tokens (trailing or doubled commas).
            medicineIds = [m for m in prescriptionRows[0][0].split(",")
                           if m.strip()]
            for medicineId in medicineIds:
                medicineNodes = neoDb.selectNodeElementsFromDB(
                    self.MEDICINETAG, condition=[],
                    properties={self.NODEID: int(medicineId)})
                if not medicineNodes:
                    continue
                medicineNode = medicineNodes[0]
                tids = [int(i) for i in medicineNode[self.NODETID]]
                tids.append(tid)
                ntids = list(set(tids))
                neoDb.updateKeyInNode(medicineNode, properties={
                    self.NODETID: ntids,
                    self.NODEWEIGHT: len(ntids)
                })
    finally:
        mysqlDb.close()
def get_medicine_id_set(self, treatmentIdSet):
    """Return the ids of all medicines prescribed in the given treatments.

    Each selected treatment's prescription ``name`` column is read as a
    comma-separated list of integer medicine ids. Treatments without a
    prescription row are skipped.

    :param treatmentIdSet: treatment ids (compared after ``int()``
        conversion of the ``treatment.id`` column) to include.
    :return: set of int medicine ids.
    """
    mysqlDb = MysqlOpt()
    idSet = set()
    try:
        treatmentRows = mysqlDb.select('treatment', 'id,prescriptionId')
        for treatmentRow in treatmentRows:
            if int(treatmentRow[0]) not in treatmentIdSet:
                continue
            prescriptionRows = mysqlDb.select(
                'prescription', 'name', 'id = ' + str(treatmentRow[1]))
            if not prescriptionRows:
                continue
            # Update in place (the original rebuilt the set on every row),
            # and drop empty tokens from trailing or doubled commas.
            idSet.update(int(m) for m in prescriptionRows[0][0].split(",")
                         if m.strip())
    finally:
        mysqlDb.close()
    return idSet
def get_symptom_id_set(self, treatmentIdSet):
    """Return the ids of all symptoms listed in the given treatments.

    Each selected treatment's ``symptomIds`` column is read as a
    comma-separated list of integer symptom ids.

    :param treatmentIdSet: treatment ids (compared after ``int()``
        conversion of the ``treatment.id`` column) to include.
    :return: set of int symptom ids.
    """
    mysqlDb = MysqlOpt()
    idSet = set()
    try:
        for treatmentRow in mysqlDb.select('treatment', 'id,symptomIds'):
            if int(treatmentRow[0]) not in treatmentIdSet:
                continue
            # Skip empty tokens (trailing or doubled commas) instead of
            # crashing on int(''), matching the medicine-id parsing.
            idSet.update(int(s) for s in treatmentRow[1].split(",")
                         if s.strip())
    finally:
        # The original version never released this connection.
        mysqlDb.close()
    return idSet
def store_pulse(self, filter=False, filterSet=frozenset()):
    """Rebuild the pulse nodes in Neo4j from the MySQL ``pulse`` table.

    All nodes labelled ``self.PULSETAG`` are first removed together with
    their relationships (``DETACH DELETE``). Then one node carrying the row's
    id and name is created for each ``pulse`` row.

    :param filter: when truthy, only pulse ids contained in ``filterSet``
        are stored.
    :param filterSet: pulse ids to keep when ``filter`` is truthy.
    """
    mysqlDb = MysqlOpt()
    neoDb = Neo4jOpt()
    try:
        cql = "MATCH (n:" + self.PULSETAG + ") DETACH DELETE n"
        neoDb.graph.data(cql)
        FileUtil.print_string(u"开始存储脉搏节点...", True)
        for pulseRow in mysqlDb.select('pulse', 'id,name'):
            pulseId = pulseRow[0]  # renamed: `id` shadowed the builtin
            if filter and pulseId not in filterSet:
                continue
            neoDb.createNode([self.PULSETAG], {
                self.NODEID: pulseId,
                self.NODENAME: pulseRow[1]
            })
    finally:
        mysqlDb.close()
def store_medicine(self, filter=False, filterSet=frozenset()):
    """Rebuild the medicine nodes in Neo4j from the MySQL ``medicine`` table.

    All nodes labelled ``self.MEDICINETAG`` are first removed together with
    their relationships (``DETACH DELETE``). Then one node per ``medicine``
    row is created with its id and name, weight 0, and an empty
    treatment-id list.

    :param filter: when truthy, only medicine ids contained in ``filterSet``
        are stored.
    :param filterSet: medicine ids to keep when ``filter`` is truthy.
    """
    mysqlDb = MysqlOpt()
    neoDb = Neo4jOpt()
    try:
        # First delete existing medicine nodes and every relationship
        # attached to them.
        cql = "MATCH (n:" + self.MEDICINETAG + ") DETACH DELETE n"
        neoDb.graph.data(cql)
        FileUtil.print_string(u"开始存储中药节点...", True)
        for medicineRow in mysqlDb.select('medicine', 'id,name'):
            medicineId = medicineRow[0]  # renamed: `id` shadowed the builtin
            if filter and medicineId not in filterSet:
                continue
            neoDb.createNode([self.MEDICINETAG], {
                self.NODEID: medicineId,
                self.NODENAME: medicineRow[1],
                self.NODEWEIGHT: 0,
                self.NODETID: []  # fresh list per node
            })
    finally:
        mysqlDb.close()
def store_tonguezhi2medicine(self, filter=False, filterSet=frozenset()):
    """Create or update tongueZhi -> medicine relationships in Neo4j.

    For every ``treatment`` row, each medicine id listed in its
    prescription's comma-separated ``name`` column is linked from the row's
    ``tongueZhiId`` node with a ``self.TONGUEZHI2MEDICINETAG`` relationship.
    If the relationship already exists, the treatment id is added to its
    ``self.RELATIONTID`` list (deduplicated) and ``self.RELATIONWEIGHT`` is
    set to the number of distinct ids; otherwise it is created with that
    single id and weight 1. Missing prescriptions or nodes are skipped.

    :param filter: when truthy, only treatments whose id is in ``filterSet``
        are processed.
    :param filterSet: treatment ids to keep when ``filter`` is truthy.
    """
    FileUtil.print_string(u"开始存储舌质和中药的关系...", True)
    mysqlDb = MysqlOpt()
    neoDb = Neo4jOpt()
    try:
        treatmentRows = mysqlDb.select('treatment',
                                       'id,tongueZhiId,prescriptionId')
        for treatmentRow in treatmentRows:
            tid = int(treatmentRow[0])
            if filter and tid not in filterSet:
                continue
            prescriptionRows = mysqlDb.select(
                'prescription', 'name', 'id = ' + str(treatmentRow[2]))
            if not prescriptionRows:
                continue
            # Drop empty tokens (trailing or doubled commas).
            medicineIds = [m for m in prescriptionRows[0][0].split(",")
                           if m.strip()]
            if not medicineIds:
                continue
            # The tongueZhi node depends only on the treatment, so look it
            # up once instead of once per medicine.
            tongueZhiNodes = neoDb.selectNodeElementsFromDB(
                self.TONGUEZHITAG, condition=[],
                properties={self.NODEID: int(treatmentRow[1])})
            if not tongueZhiNodes:
                continue
            tongueZhiNode = tongueZhiNodes[0]
            for medicineId in medicineIds:
                medicineNodes = neoDb.selectNodeElementsFromDB(
                    self.MEDICINETAG, condition=[],
                    properties={self.NODEID: int(medicineId)})
                if not medicineNodes:
                    continue
                medicineNode = medicineNodes[0]
                # Existing relationship: merge this treatment id into it;
                # otherwise create the relationship.
                relations = neoDb.selectRelationshipsFromDB(
                    tongueZhiNode, self.TONGUEZHI2MEDICINETAG, medicineNode)
                if relations:
                    relation = relations[0]
                    tids = list(relation[self.RELATIONTID])
                    tids.append(tid)
                    ntids = list(set(tids))
                    neoDb.updateKeyInRelationship(relation, properties={
                        self.RELATIONTID: ntids,
                        self.RELATIONWEIGHT: len(ntids)
                    })
                else:
                    neoDb.createRelationship(
                        self.TONGUEZHI2MEDICINETAG, tongueZhiNode,
                        medicineNode,
                        propertyDic={
                            self.RELATIONTID: [tid],
                            self.RELATIONWEIGHT: 1
                        })
    finally:
        mysqlDb.close()
def cross_validation(self, preFile='pre.txt', randomState=51):
    """Evaluate medicine recommendation with shuffled K-fold cross-validation.

    The ``treatment`` ids are split into ``self.CRVNUMFOLDER`` folds. For each
    fold, the graph is rebuilt from the training treatments only. Each test
    treatment's symptom, tongueZhi, tongueTai and pulse names are then passed
    to ``self.input_multiple_syptomes``. The predicted medicine names are
    appended to ``preFile`` and compared with the prescription's actual
    medicines via ``self.classification_evaluate``. Precision, recall and F1
    are printed for each fold.

    :param preFile: file to which per-treatment predictions are appended.
    :param randomState: seed passed to ``KFold`` for the fold shuffle.
    """
    mysqlOpt = MysqlOpt()
    try:
        tidList = [row[0] for row in mysqlOpt.select('treatment', 'id')]
        # NOTE(review): KFold(n, n_folds=...) is the legacy
        # sklearn.cross_validation API (removed in scikit-learn 0.20); the
        # modern equivalent is model_selection.KFold(n_splits=...).split(X).
        kf = KFold(len(tidList), n_folds=self.CRVNUMFOLDER, shuffle=True,
                   random_state=randomState)
        for numFold, (trainIndex, testIndex) in enumerate(kf, 1):
            FileUtil.print_string(str(numFold) + 'th validation:', True)
            trainSet = set(tidList[i] for i in trainIndex)
            self.createGraph(filter=True, treamentIdfilterSet=trainSet)
            yTrue = []
            yPre = []
            FileUtil.add_string(preFile, str(numFold) + 'th validation:')
            for i in testIndex:
                tid = tidList[i]
                print(str(tid))  # parenthesized: valid in Python 2 and 3
                trueMedicineList, preMedicineList = \
                    self._cv_predict_treatment(mysqlOpt, tid)
                yTrue.append(trueMedicineList)
                # Output format: "<tid> name1,name2,...," (trailing comma kept).
                preString = str(tid) + ' ' + ''.join(
                    m + ',' for m in preMedicineList)
                FileUtil.add_string(preFile, preString)
                yPre.append(preMedicineList)
            self._cv_report_scores(yTrue, yPre)
    finally:
        mysqlOpt.close()


def _cv_predict_treatment(self, mysqlOpt, tid):
    """Return (true medicine names, predicted medicine names) for ``tid``."""
    perTreatment = mysqlOpt.select(
        'treatment',
        'prescriptionId,symptomIds,tongueZhiId,tongueTaiId,pulseId',
        'id=' + str(tid))
    pid, symptomIds, tZhiId, tTaiId, pulseId = perTreatment[0][:5]
    perPrescription = mysqlOpt.select('prescription', 'name',
                                      'id = ' + str(pid))
    trueMedicineList = self._cv_names_by_ids(mysqlOpt, 'medicine',
                                             perPrescription[0][0])
    symptomList = self._cv_names_by_ids(mysqlOpt, 'symptom', symptomIds)
    tongueZhi = self._cv_single_name(mysqlOpt, 'tongueZhi', tZhiId)
    tongueTai = self._cv_single_name(mysqlOpt, 'tongueTai', tTaiId)
    pulse = self._cv_single_name(mysqlOpt, 'pulse', pulseId)
    preMedicineList = self.input_multiple_syptomes(
        symptomList, tongueZhi, tongueTai, pulse)
    return trueMedicineList, preMedicineList


def _cv_names_by_ids(self, mysqlOpt, table, idCsv):
    """Return the ``name`` column of ``table`` rows whose id is in ``idCsv``."""
    # Strip empty tokens: a trailing comma (which other loaders in this file
    # tolerate) would otherwise yield invalid SQL such as "IN (1,2,)", and an
    # empty list would yield "IN ()".
    ids = [x.strip() for x in (idCsv or '').split(',') if x.strip()]
    if not ids:
        return []
    rows = mysqlOpt.select(table, 'name', 'id IN (' + ','.join(ids) + ')')
    return [row[0] for row in rows]


def _cv_single_name(self, mysqlOpt, table, rowId):
    """Return the ``name`` of row ``rowId`` in ``table``, or '' if absent."""
    rows = mysqlOpt.select(table, 'name', 'id = ' + str(rowId))
    return rows[0][0] if rows else ''


def _cv_report_scores(self, yTrue, yPre):
    """Print precision, recall and F1 for one fold's predictions."""
    precision, recall = self.classification_evaluate(yTrue, yPre)
    FileUtil.print_string('precision:' + str(precision), True)
    FileUtil.print_string('recall:' + str(recall), True)
    # Guard the harmonic mean: both scores being 0 used to raise
    # ZeroDivisionError. 2.0 avoids Python 2 integer division.
    if precision + recall > 0:
        F1 = 2.0 * precision * recall / (precision + recall)
    else:
        F1 = 0.0
    FileUtil.print_string('F1:' + str(F1), True)