# -*- coding: utf-8 -*-
# NOTE: the classes below come from several modules of the same project; the
# imports cover all of them. Preprocess, Train, Predict, Rule, ClassTreePedict,
# SampleAdjust, Module, CONFIG_FILE and MODEL_PATH are project-local names and
# are assumed to be importable; Process is assumed to be multiprocessing.Process.
import os
import sys
import json
import random
import logging
import traceback
import ConfigParser
from multiprocessing import Process


class TrainAll(object):

    def __init__(self, cfg_file_name):
        self.features = []
        self.labels = []
        self.cnt = 0
        self.preprocess = Preprocess(cfg_file_name)

    def load_data_path(self, path, level=0):
        if os.path.isdir(path):
            for file_name in os.listdir(path):
                t_path = os.path.join(path, file_name)
                self.load_data_path(t_path, level)
        elif os.path.isfile(path):
            fp = open(path, 'r')
            for line in fp:
                line = line.strip('\n')
                fields = line.split('\t')
                try:
                    cat_name = fields[0].decode('u8')
                    name = fields[1].decode('u8')
                    brand = fields[2].decode('u8')
                    price = fields[3]
                    if len(fields) == 5:
                        ori_cat = fields[4]
                    cat_lst = json.loads(cat_name)
                    if len(cat_lst) < level + 1:
                        continue
                    labels = u'$'.join(cat_lst[:level + 1])
                except Exception, e:
                    print >> sys.stderr, "Error:", line
                    print >> sys.stderr, e
                    continue
                self.cnt += 1
                if self.cnt % 10000 == 0:
                    print >> sys.stderr, "load %d samples..." % self.cnt
                # NOTE: the trailing '$' in u'母婴用品$' looks like a typo that
                # effectively disables that filter entry.
                if cat_lst[0] in [u'服饰鞋帽', u'个人护理', u'母婴用品$']:
                    continue
                features = self.preprocess.process(name, cat_name, brand,
                                                   price, level=level)
                self.features.append(features)
                self.labels.append(labels)
            fp.close()
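
# A minimal usage sketch for TrainAll (illustrative only: 'classifier.cfg' and
# './data/train' are assumed paths, not taken from the original source).
def _demo_train_all():
    trainer = TrainAll('classifier.cfg')
    trainer.load_data_path('./data/train', level=1)
    print >> sys.stderr, "loaded %d samples, %d feature dicts" % (
        trainer.cnt, len(trainer.features))
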
class EvalNode(Process):

    def __init__(self, cfg_file_name, test_size_prob, data_dir, model_dir,
                 model_name, level, dump_file=False, has_feat=False,
                 max_num=0, min_num=0):
        Process.__init__(self)
        self.preprocess = Preprocess(cfg_file_name)
        self.test_size_prob = test_size_prob
        self.data_dir = data_dir
        self.model_dir = model_dir
        self.model_name = model_name
        self.level = level
        self.cnt = 0
        self.train_features = []
        self.train_labels = []
        self.test_features = []
        self.test_labels = []
        self.test_lines = []
        self.train_lines = []
        self.dump_file = dump_file
        self.has_feat = has_feat
        self.load_label = True
        if max_num != 0 and min_num != 0:
            self.sample_adjust = SampleAdjust(min_num, max_num, seed=20)
        #self.cate_trans = {u'运动户外$女装': [u'服装配饰'],
        #                   u'运动户外$男装': [u'服装配饰'],
        #                   u'运动户外$运动女鞋': [u'鞋'],
        #                   u'运动户外$运动男鞋': [u'鞋'],
        #                   u'运动户外$户外鞋': [u'鞋'],
        #                   u'母婴用品$童装': [u'服装配饰'],
        #                   u'母婴用品$宝宝洗护': [u'个护化妆']}

    def cate_lst_transform(self, cate_lst):
        # NOTE: only usable if self.cate_trans above is re-enabled.
        key = u'$'.join(cate_lst[:2])
        if key in self.cate_trans:
            # was self.cate_trans[cate_lst]: a list is not hashable
            return self.cate_trans[key]
        else:
            return cate_lst

    # level: the category level the classifier is trained at
    def load_data_path(self, path, level=0, test_size_prob=0.0, has_feat=False):
        if os.path.isdir(path):
            for file_name in os.listdir(path):
                t_path = os.path.join(path, file_name)
                self.load_data_path(t_path, level, test_size_prob, has_feat)
        elif os.path.isfile(path):
            if has_feat and not path.endswith("feature"):
                return
            fp = open(path, 'r')
            for line in fp:
                line = line.strip('\n')
                fields = line.split('\t')
                try:
                    cat_name = fields[0].decode('u8')
                    name = fields[1].decode('u8')
                    if len(fields) > 3:
                        brand = fields[2].decode('u8')
                        price = fields[3]
                    else:
                        brand = u''
                        price = 0
                    if has_feat:
                        features = json.loads(fields[-1])
                        if not isinstance(features, dict):
                            print >> sys.stderr, "Error:", line
                            continue
                    cat_lst = json.loads(cat_name)
                    # category substitution (disabled)
                    #cat_lst = self.cate_lst_transform(cat_lst)
                    if len(cat_lst) < level + 1:
                        continue
                    labels = u'$'.join(cat_lst[:level + 1])
                except Exception, e:
                    print >> sys.stderr, "Error:", line
                    print >> sys.stderr, e
                    traceback.print_exc()
                    continue
                self.cnt += 1
                if self.cnt % 10000 == 0:
                    print >> sys.stderr, "load %d samples..." % self.cnt
                if cat_lst[0] in [u'服饰鞋帽', u'个人护理', u'母婴用品$']:
                    continue
                if test_size_prob >= 1.0:
                    is_train = False
                elif test_size_prob <= 0.0:
                    is_train = True
                else:
                    is_train = random.random() >= test_size_prob
                if is_train:
                    if not has_feat:
                        if self.load_label:
                            features = self.preprocess.process(
                                name, cat_name, brand, price, level=level)
                        else:
                            features = self.preprocess.process(
                                name, u"[]", brand, price)
                    if len(features) == 0:
                        continue
                    # whether to balance the number of training samples per class
                    if hasattr(self, 'sample_adjust'):
                        self.sample_adjust.add_sample(features, labels)
                    else:
                        self.train_features.append(features)
                        self.train_labels.append(labels)
                    # whether to dump the training set to a file
                    if self.dump_file:
                        self.train_lines.append(line)
                else:
                    if not has_feat:
                        features = self.preprocess.process(name)
                    self.test_features.append(features)
                    self.test_labels.append(labels)
                    self.test_lines.append(line)
            fp.close()
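
# A minimal sketch of how EvalNode might be driven (the config path, data
# directory and 20% hold-out ratio are assumptions, not from the original).
def _demo_eval_node():
    node = EvalNode('classifier.cfg', 0.2, './data', './model', 'root', 0)
    node.load_data_path(node.data_dir, level=node.level, test_size_prob=0.2)
    print >> sys.stderr, "train=%d test=%d" % (
        len(node.train_features), len(node.test_features))
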
class ClassifierRun():

    def __init__(self, cfg_file_name):
        self.config = ConfigParser.ConfigParser()
        self.cur_dir = os.path.dirname(os.path.abspath(cfg_file_name))
        self.cfg_parser(cfg_file_name)
        self.preprocess = Preprocess(cfg_file_name)
        self.cnt = 0
        self.train_features = []
        self.train_labels = []
        self.test_features = []
        self.test_labels = []
        self.test_names = []
        self._train = Train(self.space, self.params)
        self._predict = Predict()
        self._rule = Rule()
        self._tree = ClassTreePedict('./resource/cate_id.cfg', './model')

    def cfg_parser(self, cfg_file_name):
        self.config.read(cfg_file_name)
        section = 'model'
        if self.config.has_option(section, 'model_file'):
            self.model_path = self.config.get(section, 'model_file')
        else:
            self.model_path = './model/testmodel'
        if self.config.has_option(section, 'model_dir'):
            self.model_dir = self.config.get(section, 'model_dir')
        else:
            self.model_dir = './model'
        if (self.config.has_option(section, 'vec_space')
                and self.config.get(section, 'vec_space') == 'topic'):
            self.space = 'topic'
        else:
            self.space = 'word'
        # default so self.params is defined even without a matching section
        self.params = {}
        if self.space == 'topic':
            if self.config.has_section('topic_param'):
                self.params = dict(self.config.items('topic_param'))
        elif self.space == 'word':
            if self.config.has_section('word_param'):
                self.params = dict(self.config.items('word_param'))
        section = 'evaluation'
        self.test_size_prob = 1.0
        if self.config.has_option(section, 'test_size'):
            self.test_size_prob = self.config.getfloat(section, 'test_size')
        if self.config.has_option(section, 'random_state'):
            seed = self.config.getint(section, 'random_state')
            random.seed(seed)
        self.level = 0
        section = 'default'
        if self.config.has_option(section, 'level'):
            self.level = self.config.getint(section, 'level')
        if self.config.has_option(section, 'cate_id_file'):
            self.cate_id_file = self.config.get(section, 'cate_id_file')
        else:
            self.cate_id_file = "resource/cate_id.cfg"
        logging.info('[Done] config parsing')
        logging.info('use %s space, params=%s'
                     % (self.space, json.dumps(self.params)))

    def train(self):
        self._train.train(self.train_features, self.train_labels)
        self._train.dump_model(self.model_path)

    def test(self):
        if self.model_path.endswith('rule'):
            self._rule.load_rule(self.model_path)
            is_rule = True
        else:
            self._predict.load_model(self.model_path)
            is_rule = False
        print len(self.test_features)
        for (features, label, name) in zip(self.test_features,
                                           self.test_labels, self.test_names):
            if is_rule:
                result = self._rule.predict(name, 0)
            else:
                result = self._predict.predict(features)
            # was json.dumps(features, '\t', ...): the second positional
            # argument of json.dumps is skipkeys, not a separator
            print result.encode('u8'), '\t', label.encode('u8'), '\t', \
                name.encode('u8'), '\t', \
                json.dumps(features, ensure_ascii=False).encode('u8')

    def testone(self, name, cat_name, brand, price):
        # was a bare (undefined) model_dir
        tree = ClassTreePedict(self.cate_id_file, self.model_dir)
        features = self.preprocess.process(name, cat_name, brand, price, level=0)
        # leftover debug: overrides the computed features with a fixed sample
        features = json.loads('{"Eden": 1, "Botkier": 1, "Satchel": 1, "马毛": 1,'
                              ' "女士": 1, "柏柯尔": 1, "拼接": 1, "手提包": 1,'
                              ' "Small": 1}')
        result = tree.predict(name, features, indexclass=u"root")
        print result.encode('u8'), name.encode('u8'), \
            json.dumps(features, ensure_ascii=False).encode('u8')

    # map_cfg: category-to-ID mapping file; model_dir: directory holding the
    # model files; data_file_name: data file to classify
    def predict(self, map_cfg, model_dir, data_file_name):
        tree = ClassTreePedict(map_cfg, model_dir)
        data_file = open(data_file_name, 'r')
        for line in data_file:
            line = line.strip()
            try:
                old_cate, cid_cate, name, brand, price = \
                    line.decode('u8').split(u'\t')
            except Exception, e:
                print >> sys.stderr, "Error:", line
                print >> sys.stderr, e
                sys.exit()
            cat_name = json.dumps(cid_cate.split(','))
            price = float(price)
            features = self.preprocess.process(name, cat_name, brand, price,
                                               level=0)
            indexclass = u'root'
            result = tree.predict(name, features, indexclass, price, cat_name)
            print "%s\t%s\t%s" % (result.encode('u8'), old_cate.encode('u8'),
                                  name.encode('u8'))
        data_file.close()
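
# Sketch of batch prediction with ClassifierRun. All paths are hypothetical;
# predict() expects tab-separated lines of old_cate, cid_cate, name, brand, price.
def _demo_classifier_run():
    runner = ClassifierRun('classifier.cfg')
    runner.predict('resource/cate_id.cfg', './model', 'data/predict.txt')
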
class Classifier(Module):

    def __init__(self, context):
        super(Classifier, self).__init__(context)
        logging.info("Classifier module init start")
        cur_dir = os.path.dirname(os.path.abspath(__file__)) or os.getcwd()
        model_dir = cur_dir + "/model"
        map_file = cur_dir + "/resource/cate_id.cfg"
        cfg_file = cur_dir + "/classifier.cfg"
        cid_field_file = cur_dir + "/resource/fields.cfg"
        self.cid_fields = dict()  # some cids need special field handling
        self.resource_process(model_dir, cfg_file, map_file, cid_field_file)
        self.classifier = ClassTreePedict(map_file, model_dir)
        self.preprocess = Preprocess(cfg_file)
        self.start_node = self.get_start_node(cfg_file)

    def get_start_node(self, cfg_file):
        # Node where classification starts; defaults to the root node u'root'.
        # If category mapping is used, a node can be inserted before u'root'
        # and a rule used to jump to the corresponding node.
        start_node = u'root'
        self.config = ConfigParser.ConfigParser()
        self.config.read(cfg_file)
        section = 'default'
        if self.config.has_option(section, 'start_node'):
            start_node = self.config.get(section, 'start_node')
            start_node = start_node.decode('u8')
        return start_node

    def classify(self, cid, name, pid, brand, price):
        features = self.preprocess.process(name, pid, brand, price)
        result = self.classifier.predict(name, features, self.start_node,
                                         price, pid)
        return result

    def load_field_file(self, field_file):
        '''Load the config file mapping each cid to the fields to merge.'''
        with open(field_file, 'r') as rf:
            for line in rf:
                line = line.strip()
                if not line or line.startswith('#'):
                    continue
                try:
                    cid, ori_field, dest_field = line.split('#')
                    fields_lst = self.cid_fields.setdefault(cid, [])
                    fields_lst.append((ori_field, dest_field))
                except:
                    logging.error("wrong field config line: %s" % line)
        logging.info("load file %s done" % field_file)

    def chg_cid_fields(self, item_base):
        '''Adjust some fields: for certain cids, merge e.g. the subtitle
        value into name.'''
        cid = item_base['cid']
        if cid in self.cid_fields:
            for (ori_field, dest_field) in self.cid_fields[cid]:
                if ori_field in item_base and dest_field in item_base:
                    if isinstance(item_base[ori_field], unicode) and isinstance(
                            item_base[dest_field], unicode):
                        item_base[dest_field] += ' '
                        item_base[dest_field] += item_base[ori_field]
                    if isinstance(item_base[ori_field], unicode) and isinstance(
                            item_base[dest_field], list):
                        item_base[dest_field].append(item_base[ori_field])
                    if isinstance(item_base[ori_field], list) and isinstance(
                            item_base[dest_field], list):
                        item_base[dest_field].extend(item_base[ori_field])
                    if isinstance(item_base[ori_field], list) and isinstance(
                            item_base[dest_field], unicode):
                        item_base[dest_field] += ' '
                        item_base[dest_field] += ' '.join(item_base[ori_field])

    def __call__(self, item_base, item_profile):
        try:
            self.chg_cid_fields(item_base)
            cid = item_base["cid"]
            name = item_base.get("name", None)
            pid = item_base.get("pid", None)
            brand = u" ".join(item_base.get("brand", ""))
            cat = []
            cat.extend(pid)
            cat_str = json.dumps(cat)
            price = item_base.get("price", 0)
            result = self.classify(cid, name, cat_str, brand, price)
            result = result.split(u'$')
            item_profile["category_name_new"] = []
            item_profile["category_name_new"].extend(result)
            item_profile["category_id_new"] = []
            ids = []
            for i in range(len(result)):
                key = '$'.join(result[:i + 1])
                if key in self.classifier.mapclass:
                    ids.append(int(self.classifier.mapclass[key]))
                else:
                    logging.error("find category id error, key: %s"
                                  % key.encode('u8'))
            item_profile["category_id_new"] = ids
        except Exception as e:
            # was print_exc(), which returns None and logs nothing useful
            logging.error(traceback.format_exc())
            logging.error("category_name: %s", e)
        return {"status": 0}

    def resource_process(self, model_dir, cfg_file, map_file, cid_field_file):
        self.add_resource_file(model_dir)
        self.add_resource_file(cfg_file)
        self.add_resource_file(map_file)
        self.load_field_file(cid_field_file)
        self.add_resource_file(cid_field_file)

    def test(self):
        cid = u"Cjianyi"
        # name = u"迪士尼(Disney)米妮K金镶钻手机链读卡器 (短)"
        # name = u"佳能(Canon)CL-97彩色墨盒 (适用佳能E568)"
        # name = u"东芝(TOSHIBA) 32A150C 32英寸 高清液晶电视(黑色)"
        # name = u"东芝(TOSHIBA) 55X1000C 55寸 全高清3D LED液晶电视"
        # name = u"东芝(TOSHIBA)42寸液晶电视42A3000C"
        name = u"三星(SAMSUNG)M55 黑色墨盒(适用SF-350)"
        brand = u""
        category = json.dumps([u"0"])
        price = -1
        logging.info("start to test")
        print self.classify(cid, name, category, brand, price)
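
# Sketch of the module's call protocol (the context object comes from the host
# framework and the item field values here are made up): __call__ mutates
# item_profile in place and always returns a status dict.
def _demo_classifier_module(context):
    clf = Classifier(context)
    item_base = {"cid": u"C1", "name": u"三星(SAMSUNG)M55 黑色墨盒",
                 "pid": [u"0"], "brand": [u"SAMSUNG"], "price": 59.0}
    item_profile = {}
    status = clf(item_base, item_profile)
    print status, item_profile.get("category_name_new")
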
class Core_Process(object):
    """
    Func: training and predict process.
    """

    def __init__(self, model_name='root', opt='train'):
        self.model_name = model_name
        self.preprocess = Preprocess(CONFIG_FILE)
        self.d_model_map = self.preprocess.d_models
        if opt == 'predict':
            self.predict_obj = Predict()
            if self.d_model_map.get(self.model_name, None):
                self.predict_obj.load_model(
                    MODEL_PATH + 'model/' + self.model_name + '.model')
            else:
                self.predict_obj.load_model(MODEL_PATH + 'model/' + 'root.model')
                print "\nNote: using the default model--root.model to predict.\n"
        self.train_features = []
        self.train_labels = []
        self.predict_features = []
        self.predict_labels = []
        self.predict_data_id = []
        self.predict_result = []

    def load_data_path(self, data_path):
        """
        Input:
            data_path <string>: the input file path.
        Output:
            None
        """
        print data_path
        fp = open(data_path, 'r')
        for json_line in fp:
            d_line = json.loads(json_line)
            data_id = d_line['id']
            desc_text = ' '.join(d_line['description'].replace('.', ' ').split())
            labels = d_line['label']
            features = self.preprocess.process(title='', content=desc_text,
                                               model_name=self.model_name)
            self.train_features.append(features)
            self.train_labels.append(labels)
            self.predict_data_id.append(data_id)
        fp.close()
        if len(self.train_features) != len(self.train_labels):
            print 'ERROR: len(train_features) != len(train_labels)'

    def train_all(self, train_data_dir, model_name='root'):
        """Train a model on the whole training dataset; model 'root' by default."""
        self.load_data_path(train_data_dir)
        print >> sys.stderr, "train the model", train_data_dir
        space = 'word'
        #space = 'topic'  # known problems?
        _train = Train(space, {})
        _train.train(self.train_features, self.train_labels)
        for sub_dir in ('model', 'report', 'feature'):
            if not os.path.exists(os.path.join(MODEL_PATH, sub_dir)):
                os.makedirs(os.path.join(MODEL_PATH, sub_dir))
        model_path = MODEL_PATH + 'model/' + model_name + ".model"
        print >> sys.stderr, "dump the model", model_path
        _train.dump_model(model_path)
        feature_file = os.path.join(MODEL_PATH,
                                    'feature/' + model_name + ".feature")
        # dump the selected features and their coefficients
        ffile = open(feature_file, 'w')
        feature_coef = _train.get_feature_coef()
        print "----------len feature coef:", len(feature_coef)
        feature_len = 0
        for cate in feature_coef:
            print "-------------", cate
            print >> ffile, "%s" % cate.encode('u8')
            features = sorted(feature_coef[cate].items(),
                              key=lambda x: x[1], reverse=True)
            feature_len = len(features)
            for f_item in features:
                print >> ffile, "\t%s\t%f" % (f_item[0].encode('u8'), f_item[1])
        ffile.close()
        print >> sys.stderr, "%d features have been selected!" % feature_len

    def evaluation(self, predict_file):
        """
        Func: evaluation of batch data.
        Input:
            predict_file <string>: input file path.
        Output:
            precision <float>: the precision of the prediction.
        """
        d_eval = {'corr': 0}
        precision = 0.0
        self.load_data_path(predict_file)
        self.predict_features = self.train_features
        self.predict_labels = self.train_labels
        all_cnt = len(self.predict_labels)
        for features, label in zip(self.predict_features, self.predict_labels):
            result = self.predict_obj.predict(features)
            if result == label:
                d_eval['corr'] += 1
            self.predict_result.append(result)
        if all_cnt == 0:
            print 'ERROR: all_cnt of predict_file: 0 !'
        else:
            precision = d_eval['corr'] * 1.0 / all_cnt
            print '========== all_cnt: ', all_cnt
            print '========== precision: ', precision
        return precision

    def run(self, opt, file_path):
        """
        opt: whether to train or predict.
        file_path: training or prediction data path.
        """
        if opt == 'train':
            # NOTE: train_all() appends to self.train_features without
            # clearing, so successive models see accumulated data.
            for mod_name, values in self.d_model_map.items():
                self.train_all(file_path, mod_name)
        elif opt == 'predict':
            predict_file = file_path
            result = self.evaluation(predict_file)
            report_file = os.path.join(MODEL_PATH,
                                       'report/' + self.model_name + ".report")
            rfile = open(report_file, 'a')
            rfile.write(str(file_path + ' precision: ') + str(result) + '\n')
            rfile.close()
            with open(report_file + '.rep', 'w') as rf:
                for tid, res in zip(self.predict_data_id, self.predict_result):
                    rf.write(tid + '\t' + res + '\n')
        else:
            print 'Nothing to do, please input train or predict.'

    def predict_one(self, desc_text):
        """
        Func: predict single data.
        Input:
            desc_text <string>: description text of the single data.
        Output:
            result <string>: the label of the input text.
        """
        features = self.preprocess.process(title='', content=desc_text,
                                           model_name=self.model_name)
        result = self.predict_obj.predict(features)
        return str(result)
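
# Sketch of the two Core_Process entry points (the MODEL_PATH layout and data
# file names are assumptions): train a model per entry in d_model_map, then
# evaluate a held-out file and classify a single description.
def _demo_core_process():
    Core_Process(opt='train').run('train', 'data/train.json')
    predictor = Core_Process(model_name='root', opt='predict')
    predictor.run('predict', 'data/test.json')
    print predictor.predict_one('some product description text')
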
class Classifier(Module):
    # Variant of the Classifier module above: here classify() returns both the
    # tree node index and the predicted label path.

    def __init__(self, context):
        super(Classifier, self).__init__(context)
        logging.info("Classifier module init start")
        cur_dir = os.path.dirname(os.path.abspath(__file__)) or os.getcwd()
        model_dir = cur_dir + "/model"
        map_file = cur_dir + "/resource/cate_id.cfg"
        cfg_file = cur_dir + "/classifier.cfg"
        cid_field_file = cur_dir + "/resource/fields.cfg"
        self.cid_fields = dict()  # some cids need special field handling
        self.resource_process(model_dir, cfg_file, map_file, cid_field_file)
        self.classifier = ClassTreePedict(map_file, model_dir)
        self.preprocess = Preprocess(cfg_file)
        self.start_node = self.get_start_node(cfg_file)

    def get_start_node(self, cfg_file):
        # Node where classification starts; defaults to the root node u'root'.
        # If category mapping is used, a node can be inserted before u'root'
        # and a rule used to jump to the corresponding node.
        start_node = u'root'
        self.config = ConfigParser.ConfigParser()
        self.config.read(cfg_file)
        section = 'default'
        if self.config.has_option(section, 'start_node'):
            start_node = self.config.get(section, 'start_node')
            start_node = start_node.decode('u8')
        return start_node

    def classify(self, cid, name, pid, brand, price):
        features = self.preprocess.process(name, pid, brand, price)
        index, result = self.classifier.predict(name, features,
                                                self.start_node, price, pid)
        return index, result

    def load_field_file(self, field_file):
        '''Load the config file mapping each cid to the fields to merge.'''
        with open(field_file, 'r') as rf:
            for line in rf:
                line = line.strip()
                if not line or line.startswith('#'):
                    continue
                try:
                    cid, ori_field, dest_field = line.split('#')
                    fields_lst = self.cid_fields.setdefault(cid, [])
                    fields_lst.append((ori_field, dest_field))
                except:
                    logging.error("wrong field config line: %s" % line)
        logging.info("load file %s done" % field_file)

    def chg_cid_fields(self, item_base):
        '''Adjust some fields: for certain cids, merge e.g. the subtitle
        value into name.'''
        cid = item_base['cid']
        if cid in self.cid_fields:
            for (ori_field, dest_field) in self.cid_fields[cid]:
                if ori_field in item_base and dest_field in item_base:
                    if isinstance(item_base[ori_field], unicode) and isinstance(
                            item_base[dest_field], unicode):
                        item_base[dest_field] += ' '
                        item_base[dest_field] += item_base[ori_field]
                    if isinstance(item_base[ori_field], unicode) and isinstance(
                            item_base[dest_field], list):
                        item_base[dest_field].append(item_base[ori_field])
                    if isinstance(item_base[ori_field], list) and isinstance(
                            item_base[dest_field], list):
                        item_base[dest_field].extend(item_base[ori_field])
                    if isinstance(item_base[ori_field], list) and isinstance(
                            item_base[dest_field], unicode):
                        item_base[dest_field] += ' '
                        item_base[dest_field] += ' '.join(item_base[ori_field])

    def __call__(self, item_base, item_profile):
        try:
            self.chg_cid_fields(item_base)
            cid = item_base["cid"]
            name = item_base.get("name", None)
            pid = item_base.get("pid", None)
            brand = u" ".join(item_base.get("brand", ""))
            cat = []
            cat.extend(pid)
            cat_str = json.dumps(cat)
            price = item_base.get("price", 0)
            # was result = self.classify(...): classify() now returns a tuple,
            # so the index must be unpacked before splitting the label path
            index, result = self.classify(cid, name, cat_str, brand, price)
            result = result.split(u'$')
            item_profile["category_name_new"] = []
            item_profile["category_name_new"].extend(result)
            item_profile["category_id_new"] = []
            ids = []
            for i in range(len(result)):
                key = '$'.join(result[:i + 1])
                if key in self.classifier.mapclass:
                    ids.append(int(self.classifier.mapclass[key]))
                else:
                    logging.error("find category id error, key: %s"
                                  % key.encode('u8'))
            item_profile["category_id_new"] = ids
        except Exception as e:
            # was print_exc(), which returns None and logs nothing useful
            logging.error(traceback.format_exc())
            logging.error("category_name: %s", e)
        return {"status": 0}

    def resource_process(self, model_dir, cfg_file, map_file, cid_field_file):
        self.add_resource_file(model_dir)
        self.add_resource_file(cfg_file)
        self.add_resource_file(map_file)
        self.load_field_file(cid_field_file)
        self.add_resource_file(cid_field_file)

    def test(self):
        '''
        cid = u"Cjianyi"
        name = u"46°宋河老窖500ml热销爆款,底价出击!"
        brand = u""
        category = json.dumps([u"0"])
        price = 135.0
        logging.info("start to test")
        print self.classify(cid, name, category, brand, price)
        '''
        split_str = "$$$"
        file_name = "/home/python/wch/classifyData/testErrorData3.txt"
        with open(file_name, 'rb') as f:
            data = f.readlines()
        for line in data:
            prod = line.strip().split("$$")
            name = prod[0]
            category = json.dumps([u"0"])
            index, predictCate = self.classify('', name, category, '', price=-1)
            if predictCate:
                print index, predictCate.encode('u8') + split_str
            else:
                f = open("/home/python/wch/classifyData/error_output_0509.txt",
                         'a')
                errors = prod[0] + split_str
                print >> f, errors
                f.close()