def __init__(self, model_file, Task=None, Searcher=None,
             Updater=None, beam_width=8, logger=None,
             cmd_args={}, **conf):
    """Initialize the perceptron model.

    If ``model_file`` is given, load an existing gzipped pickled model
    and hand it to ``Task``; otherwise create a fresh ``Parameters``
    store for training a new model.

    :param model_file: path of a gzipped pickle, or None to train anew
    :param Task: task class/factory (features, oracle, codec)
    :param Searcher: decoder class, built as ``Searcher(task, beam_width)``
    :param Updater: parameter-update strategy handed to ``Parameters``
    :param beam_width: beam width used by the searcher
    :param logger: optional logger; a console INFO logger is created if None
    :param cmd_args: parsed command-line options; accepted for interface
        compatibility but not read here (mutable default kept on purpose
        to leave the signature unchanged — it is never mutated)
    :param conf: extra configuration, stored as-is on ``self.conf``
    """
    if logger is None:
        logger = logging.getLogger(__name__)
        console = logging.StreamHandler()
        console.setLevel(logging.INFO)
        logger.addHandler(console)
        logger.setLevel(logging.INFO)
    self.result_logger = logger
    self.beam_width = beam_width  # search beam width
    self.conf = conf
    # Fix: remember the path so save() can fall back to it; the original
    # never set self.model_file although save() reads it.
    self.model_file = model_file
    if model_file is not None:
        # `with` guarantees the handle is closed even if unpickling
        # raises (the original leaked it on error).
        # NOTE(review): pickle.load is unsafe on untrusted model files.
        with gzip.open(model_file, "rb") as model_f:
            self.task = Task(model=pickle.load(model_f), logger=logger)
    else:
        # New model to train: fresh parameter store.
        self.paras = Parameters(Updater)
        self.task = Task(logger=logger, paras=self.paras)
    if hasattr(self.task, 'init'):
        self.task.init()
    self.searcher = Searcher(self.task, beam_width)
    self.step = 0  # number of training sentences processed so far
def isan(**args):
    """Command-line driver: dynamically load Model/Task/Decoder/Updater
    classes named in ``args``, then train, merge, test, or run the model
    as a stdin→stdout decoding filter depending on which options are set.
    Returns the list of records captured by the Recorder log handler.
    """
    orginal_args = args  # NOTE(review): typo for "original_args"; never read
    # Convert the kwargs dict into an attribute-style namespace.
    ns = argparse.Namespace()
    ns.logfile = '/dev/null'  # default; overridden below if args carries one
    for k, v in args.items():
        setattr(ns, k, v)
    args = ns
    info_color = '34'  # ANSI color code (blue) for highlighted log fields
    # Input/output streams: fall back to stdin/stdout when not specified.
    instream = sys.stdin if args.input == None else open(args.input, 'r')
    outstream = sys.stdout if args.output == None else open(
        args.output, 'a' if args.append else 'w')
    # Logging: console (INFO), a debug logfile, and an in-memory Recorder
    # whose contents are returned to the caller. The random logger name
    # avoids handler duplication across repeated calls.
    rec = Recorder()
    logger = logging.getLogger('s' + str(random.random()))
    console = logging.StreamHandler()
    logfile = logging.FileHandler(args.logfile, 'w')
    logfile.setLevel(logging.DEBUG)
    logfile.addFilter(ContextFilter())
    recstream = logging.StreamHandler(rec)
    console.setLevel(logging.INFO)
    logger.addHandler(console)
    logger.addHandler(logfile)
    logger.addHandler(recstream)
    if hasattr(args, 'log_handlers'):
        # Caller-supplied extra handlers (e.g. for embedding in a GUI).
        for handler in args.log_handlers:
            #handler.addFilter(ContextFilter())
            logger.addHandler(handler)
    logger.setLevel(logging.DEBUG)
    # Resolve dotted class paths ("pkg.mod.Cls") to class objects.
    if args.model_module:
        mod, _, cls = args.model_module.rpartition('.')
        Model = getattr(__import__(mod, globals(), locals(), [cls], 0), cls)
    if args.task:
        mod, _, cls = args.task.rpartition('.')
        Task = getattr(__import__(mod, globals(), locals(), [cls], 0), cls)
    if args.decoder:
        mod, _, cls = args.decoder.rpartition('.')
        Decoder = getattr(__import__(mod, globals(), locals(), [cls], 0), cls)
    if args.updater:
        mod, _, cls = args.updater.rpartition('.')
        Updater = getattr(__import__(mod, globals(), locals(), [cls], 0), cls)
    # Display names for the banner; fall back to generic Chinese labels
    # when a class does not define a .name attribute.
    name_model = Model.name if hasattr(Model, 'name') else '给定学习算法'
    name_decoder = Decoder.name if hasattr(Decoder, 'name') else '给定解码算法'
    name_task = Task.name if hasattr(Task, 'name') else '给定任务算法'
    name_updater = Updater.name if hasattr(Updater, 'name') else '某参数更新算法'
    logger.info("""模型: %s 解码器: %s 搜索宽度: %s 任务: %s""" % (
        make_color(name_model, info_color),
        make_color(name_decoder, info_color),
        make_color(args.beam_width, info_color),
        make_color(name_task, info_color),
    ))
    if args.train or args.append_model:
        """如果指定了训练集,就训练模型"""
        # (Train the model when a training set was specified.)
        logger.info(
            """参数更新算法 : %(updater)s batch size : %(bs)s""" % {
                'bs': make_color(args.batch_size, info_color),
                'updater': make_color(name_updater, info_color),
            })
        random.seed(args.seed)
        # model_file=None → Model builds fresh parameters for training.
        model = Model(None, (lambda **x: Task(cmd_args=args, **x)),
                      Decoder,
                      beam_width=int(args.beam_width),
                      Updater=Updater,
                      logger=logger,
                      cmd_args=args)
        if args.train:
            logger.info('随机数种子: %s' % (make_color(str(args.seed))))
            logger.info(
                "由训练语料库%s迭代%s次,训练%s模型保存在%s。" %
                (make_color(' '.join(args.train)),
                 make_color(args.iteration), name_task,
                 make_color(args.model_file)))
            if args.dev_file:
                logger.info("开发集使用%s" % (make_color(' '.join(args.dev_file))))
            model.train(args.train,
                        int(args.iteration),
                        peek=args.peek,
                        batch_size=args.batch_size,
                        dev_files=args.dev_file)
            model.save(args.model_file)
        if args.append_model:
            ### append multiple models
            # Merge the weights of several saved models into one task,
            # then dump the combined weights to args.model_file.
            # NOTE(review): pickle.load on model files is unsafe for
            # untrusted input.
            task = Task(cmd_args=args, paras=Parameters(Updater))
            for m in args.append_model:
                print(m)
                task.add_model(pickle.load(gzip.open(m, 'rb')))
            pickle.dump(task.dump_weights(), gzip.open(args.model_file, 'wb'))
    if args.train and not args.test_file:
        # Trained and nothing to test: return the captured log records.
        del logger
        del model
        return list(rec)
    if not args.train:
        # Not training: load an existing model from args.model_file.
        print("使用模型文件%s进行%s" % (make_color(args.model_file), name_task),
              file=sys.stderr)
        #print(args.model_file)
        model = Model(
            args.model_file,
            (lambda **x: Task(cmd_args=args, **x)),
            Searcher=Decoder,
            beam_width=int(args.beam_width),
            logger=logger,
            cmd_args=args,
        )
    """如果指定了测试集,就测试模型"""
    # (Test the model when a test set was specified.)
    if args.test_file:
        print("使用已经过%s的文件%s作为测试集" % (name_task, make_color(args.test_file)),
              file=sys.stderr)
        model.test(args.test_file)
        return list(rec)
    if not args.test_file and not args.append_model and not args.train:
        # Interactive/filter mode: decode each line of instream and write
        # the result to outstream. With a threshold, emit all candidates
        # whose score margin to the best path is within it.
        threshold = args.threshold
        print("以 %s 作为输入,以 %s 作为输出" % (make_color('标准输入流'), make_color('标准输出流')),
              file=sys.stderr)
        if threshold:
            print("输出分数差距在 %s 之内的候选词" % (make_color(threshold)),
                  file=sys.stderr)
        for line in instream:
            line = line.strip()
            line = model.task.codec.decode(line)
            raw = line.get('raw', '')
            Y = line.get('Y_a', None)  # optional answer-set constraint
            if threshold:
                print(model.task.codec.encode_candidates(
                    model(raw, Y, threshold=threshold)),
                      file=outstream)
            else:
                print(model.task.codec.encode(model(raw, Y)),
                      file=outstream)
    return list(rec)
class Model(object):
    """Beam-search perceptron model (感知器模型).

    Ties together a Task object (feature extraction, oracle, codec, and
    evaluation), a Searcher (beam decoder) and a Parameters store
    (weight accumulation / averaging), and drives training, evaluation,
    saving and decoding.
    """
    name = "感知器"  #: display name used by the CLI banner

    def __init__(self, model_file, Task=None, Searcher=None,
                 Updater=None, beam_width=8, logger=None,
                 cmd_args={}, **conf):
        """Load an existing model (``model_file`` given) or set up a new
        one for training (``model_file`` is None).

        :param model_file: path of a gzipped pickle, or None to train anew
        :param Task: task class/factory (features, oracle, codec)
        :param Searcher: decoder class, built as ``Searcher(task, beam_width)``
        :param Updater: update strategy handed to ``Parameters`` (new model)
        :param beam_width: beam width used by the searcher
        :param logger: optional logger; a console INFO logger is made if None
        :param cmd_args: command-line options; accepted for compatibility
            but not read here (mutable default kept to leave the public
            signature unchanged — it is never mutated)
        """
        if logger is None:
            logger = logging.getLogger(__name__)
            console = logging.StreamHandler()
            console.setLevel(logging.INFO)
            logger.addHandler(console)
            logger.setLevel(logging.INFO)
        self.result_logger = logger
        self.beam_width = beam_width  # search beam width
        self.conf = conf
        # Fix: the original never stored model_file although save() reads
        # self.model_file, so save(None) raised AttributeError.
        self.model_file = model_file
        if model_file is not None:
            # `with` closes the handle even if unpickling raises.
            # NOTE(review): pickle.load is unsafe on untrusted model files.
            with gzip.open(model_file, "rb") as model_f:
                self.task = Task(model=pickle.load(model_f), logger=logger)
            # NOTE(review): self.paras is NOT set on this path, matching
            # the original; develop()/save() only work on trained models.
        else:
            # New model to train: fresh parameter store.
            self.paras = Parameters(Updater)
            self.task = Task(logger=logger, paras=self.paras)
        if hasattr(self.task, 'init'):
            self.task.init()
        self.searcher = Searcher(self.task, beam_width)
        self.step = 0  # number of training sentences processed so far

    def test(self, test_file):
        """Decode ``test_file`` and evaluate against the gold output.

        Each line is decoded by the task codec into a dict with keys
        'raw' (input), 'y' (gold output) and optionally 'Y_a'.
        Returns the task's Eval object after scoring every sentence.
        """
        evaluator = self.task.Eval()  # renamed from `eval` (shadowed builtin)
        with open(test_file) as f:
            for line in f:
                arg = self.task.codec.decode(line.strip())
                raw = arg.get('raw')
                Y = arg.get('Y_a', None)  # NOTE(review): read but unused,
                # as in the original — self(raw) is called without Y.
                y = arg.get('y', None)
                hat_y = self(raw)
                evaluator(y, hat_y)
        if hasattr(evaluator, 'get_result'):
            self.result_logger.info(evaluator.get_result())
        else:
            evaluator.print_result()  # print evaluation results
        return evaluator

    def develop(self, dev_file):
        """Decode a development file using averaged weights.

        Temporarily averages the parameters (paras.final), evaluates,
        then restores the raw weights (paras.un_final) so training can
        continue. Returns the Eval scalar score when available.
        """
        self.paras.final(self.step)  # switch to averaged weights
        evaluator = self.task.Eval()
        with open(dev_file) as f:
            for line in f:
                arg = self.task.codec.decode(line.strip())
                if not arg:
                    continue
                raw = arg.get('raw')
                y = arg.get('y', None)
                hat_y = self(raw)
                evaluator(y, hat_y)
        if hasattr(evaluator, 'get_result'):
            self.result_logger.info(evaluator.get_result())
        else:
            evaluator.print_result()  # print evaluation results
        self.paras.un_final()  # restore raw weights for further training
        if hasattr(evaluator, 'get_scaler'):
            return evaluator.get_scaler()

    def save(self, model_file=None):
        """Average the weights and dump the model as a gzipped pickle.

        Falls back to ``self.model_file`` when no path is given; does
        nothing for a missing path or '/dev/null'.
        """
        if model_file is None:
            model_file = self.model_file
        if model_file is None:
            return
        if model_file == '/dev/null':
            return
        self.paras.final(self.step)  # average weights before dumping
        with gzip.open(model_file, 'wb') as model_f:  # fix: always closed
            pickle.dump(self.task.dump_weights(), model_f)

    def search(self, raw, Y=None):
        """Run the beam search for one raw input; return the best
        action (move) sequence found by the searcher."""
        self.task.set_raw(raw, Y)
        return self.searcher.search()

    def __call__(self, raw, Y=None, threshold=0):
        """Decode a raw sentence and return the predicted output.

        With ``threshold`` > 0, instead return the candidates whose
        score margin to the best path is within the threshold.
        """
        rst_moves = self.search(raw, Y)
        hat_y = self.task.moves_to_result(rst_moves, raw)
        if threshold == 0:
            return hat_y
        margins = self.searcher.cal_margins()
        return self.task.gen_candidates(margins, threshold)

    def _learn_sentence(self, arg):
        """Learn from one training example (raw sentence + gold output).

        Returns (gold output, decoded output) so the caller can score it.
        """
        raw = arg.get('raw')
        self.raw = raw
        y = arg.get('y', None)
        Y_a = arg.get('Y_a', None)
        self.step += 1  # one perceptron step per sentence
        # Set the oracle to obtain the gold (standard) action sequence.
        if hasattr(self.task, 'set_oracle'):
            std_moves = self.task.set_oracle(raw, y)
        rst_moves = self.search(raw, Y_a)  # decoded action sequence
        # Update the weights only when the decode disagrees with gold.
        if not self.task.check(std_moves, rst_moves):
            self.update(std_moves, rst_moves)
        if hasattr(self.task, 'remove_oracle'):
            self.task.remove_oracle()
        hat_y = self.task.moves_to_result(rst_moves, raw)  # decoded result
        return y, hat_y

    def update(self, std_moves, rst_moves):
        """Accumulate the perceptron delta between gold and decoded
        moves; apply the accumulated update every batch_size steps."""
        self.task.cal_delta(std_moves, rst_moves)
        if self.step % self.batch_size == 0:
            self.paras.update(self.step)

    def train(self, training_file, iteration=5, peek=-1,
              dev_files=None, keep_data=True, batch_size=1):
        """Online perceptron training loop.

        :param training_file: one file path or a list of them
        :param iteration: number of passes over the data (-1 together
            with ``peek`` means: iterate until the dev score stops
            improving for ``peek`` rounds)
        :param peek: early-stopping patience on the first dev file
        :param dev_files: optional list of development files
        :param keep_data: load and shuffle all sentences into memory
        :param batch_size: steps between parameter updates
        """
        if iteration <= 0 and peek <= 0:
            peek = 5  # sensible default patience when neither is set
        self.batch_size = batch_size
        if isinstance(training_file, str):  # idiom: was type(...)==str
            training_file = [training_file]
        if keep_data:
            # Load every decodable sentence into memory and shuffle once.
            training_data = []
            for t_file in training_file:
                with open(t_file) as f:
                    for line in f:  # one sentence per line
                        rtn = self.task.codec.decode(line.strip())
                        if not rtn:
                            continue
                        training_data.append(rtn)
            random.shuffle(training_data)

        def gen_data():
            """Yield training examples, with a %-progress display when
            the data is held in memory."""
            if keep_data:
                perc = 0
                print(perc, end='%\r', file=sys.stderr)
                for i, e in enumerate(training_data):
                    p = int(i * 100 / len(training_data))
                    if p != perc:
                        print("%i" % (p), end='%\r', file=sys.stderr)
                        perc = p
                    yield e
            else:
                # Stream from disk; no shuffling in this mode.
                for t_file in training_file:
                    with open(t_file) as f:
                        for line in f:
                            rtn = self.task.codec.decode(line.strip())
                            if not rtn:
                                continue
                            yield rtn

        it = 0
        best_it = None       # iteration index of the best dev score
        best_scaler = None   # best dev score seen so far (first dev file)
        while True:
            if it == iteration:
                break
            self.result_logger.info("训练集第 \033[33;01m%i\033[1;m 次迭代" % (it + 1))
            evaluator = self.task.Eval()  # per-epoch evaluation object
            for rtn in gen_data():
                if rtn is None:
                    continue
                # Learn from (input, output); also get the decode result.
                y, hat_y = self._learn_sentence(rtn)
                evaluator(y, hat_y)  # score decode against gold
            if hasattr(evaluator, 'get_result'):
                self.result_logger.info(evaluator.get_result())
            else:
                evaluator.print_result()  # print evaluation results
            if hasattr(self.task, 'report'):
                self.task.report()
            if dev_files:
                # Early-stopping bookkeeping tracks the FIRST dev file only.
                for dev_id, dev_file in enumerate(dev_files):
                    scaler = self.develop(dev_file)
                    if dev_id == 0:
                        if best_scaler is None or (scaler and best_scaler < scaler):
                            best_it = it
                            best_scaler = scaler
            it += 1
            # Fix: guard best_it — the original raised TypeError when
            # peek >= 0 was used without any dev files.
            if peek >= 0 and best_it is not None and it - best_it > peek:
                break

    def __del__(self):
        # Fix: the original defined __del__ twice; the second definition
        # silently shadowed the first, making `del self.searcher` dead
        # code. Merged here, with guards so partially-constructed
        # instances do not raise during garbage collection.
        if hasattr(self, 'searcher'):
            del self.searcher
        if hasattr(self, 'task'):
            if hasattr(self.task, '__del__'):
                self.task.__del__()
            del self.task