Example #1
0
def run():
    """
    If a file does not exist yet, create it.
    :return:
    """
    if not os.path.exists('./res'):
        os.makedirs('./res')
    config = get_config()
    if not os.path.exists(config['url']) or not os.path.exists(
            config['title']) or not os.path.exists(config['content']):
        data_load(config)
    if not os.path.exists(config['content_clean']):
        data_clean_content(config)
    if not os.path.exists(config['content_filter']):
        filter_stop_word(config)
    if not os.path.exists(config['content_stemming']):
        stemming(config)
    if not os.path.exists(config['term_list']):
        create_term_list(config)

    documents = get_content(config)
    tf_documents = get_tf(documents)
    if not os.path.exists(config['idf']):
        create_idf(config, documents)
    idf_documents = get_idf(config)
    if not os.path.exists(config['tf_idf']):
        create_tf_idf(config, tf_documents, idf_documents, documents)
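
The helpers that run() calls above (get_tf, create_idf, create_tf_idf) are defined elsewhere in the module; below is a minimal sketch of what the TF and IDF steps might compute, assuming documents is a list of token lists. The bodies are illustrative assumptions, not the original implementations.

import math
from collections import Counter

def get_tf(documents):
    # One term-frequency dict per document: count of term / length of document
    tf_documents = []
    for tokens in documents:
        counts = Counter(tokens)
        total = len(tokens) or 1
        tf_documents.append({term: count / total for term, count in counts.items()})
    return tf_documents

def compute_idf(documents):
    # idf(t) = log(N / df(t)), where df(t) is the number of documents containing t
    n_docs = len(documents)
    df = Counter()
    for tokens in documents:
        df.update(set(tokens))
    return {term: math.log(n_docs / count) for term, count in df.items()}
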
Example #2
0
 def __init__(self,
              args,
              lm,
              recalculate_props=False):  #, directory='searcher'):
     self.lm = lm
     self.args = args
     #self.holo_directory = 'searcher'
     self.config = data_utils_new.get_config(lm)
     self.args.vocab_size = len(self.config.encode) + 1
     self.pred_logs = []
     self.gen_logs = []
     self.pred_rewards = []
     self.gen_rewards = []
     self.expr_pos = []
     if args.interface_model:
         data = torch.load(self.args.interface_model)
         print('load from %s' % (self.args.interface_model))
         args_model = data['args']
         args_model.device = args.device
         device = torch.device("cuda:0")
         self.pred_model = torch_models.PredModel(args_model,
                                                  self.config).to(device)
         self.gen_model = torch_models.GenModel(args_model,
                                                self.config).to(device)
         self.pred_model.load_state_dict(data['models']['pred'])
         self.gen_model.load_state_dict(data['models']['gen'])
         self.pred_model = self.pred_model.to(args.device)
         self.gen_model = self.gen_model.to(args.device)
     else:
         if self.args.interface_pred_model != '':
             args_pred = torch.load(self.args.interface_pred_model)['args']
             args_pred.device = args.device
             args_pred.cat = False
             self.pred_model = torch_models.PredModel(
                 args_pred, self.config).to(args.device)
             self.pred_model.load(self.args.interface_pred_model)
         else:
             self.pred_model = torch_models.PredModel(args, self.config).to(
                 args.device)
         if self.args.interface_gen_model != '':
             args_gen = torch.load(self.args.interface_gen_model)['args']
             args_gen.device = args.device
             self.gen_model = torch_models.GenModel(
                 args_gen, self.config).to(args.device)
             self.gen_model.load(self.args.interface_gen_model)
         else:
             self.gen_model = torch_models.GenModel(args, self.config).to(
                 args.device)
     self.bsi = gen_model_beam_search_rl.BeamSearchInterface(
         self.args, self.gen_model, self.gen_logs, self.gen_rewards)
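
The combined checkpoint read through self.args.interface_model above is expected to hold the training args plus a 'models' dict with 'pred' and 'gen' state dicts (see data['args'] and data['models'] in the code). A sketch of how such a checkpoint could have been written on the training side; the file name and the pred_model/gen_model variables are assumptions:

torch.save({
    'args': args,
    'models': {
        'pred': pred_model.state_dict(),
        'gen': gen_model.state_dict(),
    },
}, 'interface_model.pt')
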
Example #3
0
 def __init__(self, args, lm, recalculate_props=False):#, directory='searcher'):
     self.lm = lm
     self.args = args
     #self.holo_directory = 'searcher'
     self.config = data_utils_new.get_config(lm)
     self.args.vocab_size = len(self.config.encode)+1
     loc = 'cpu' if args.cpu else 'cuda:0'
     if args.interface_model:
         print ('load from %s' % (args.interface_model))
         data = torch.load(self.args.interface_model, map_location=loc)
         args_model = data['args']
         args_model.device = args.device
         args_model.cpu = args.cpu
         self.pred_model = torch_models.PredModel(args_model, self.config).to(args.device)
         self.gen_model = torch_models.GenModel(args_model, self.config).to(args.device)
         self.pred_model.load_state_dict(data['models']['pred'])
         self.gen_model.load_state_dict(data['models']['gen'])
     else:
         if self.args.interface_pred_model != '':
             args_pred = torch.load(self.args.interface_pred_model, map_location=loc)['args']
             args_pred.device = args.device
             args_pred.cpu = args.cpu
             args_pred.cat = False
             args_pred.max_len = args.max_len
             self.pred_model = torch_models.PredModel(args_pred, self.config).to(args.device)
             self.pred_model.load(self.args.interface_pred_model)
         else:
             self.pred_model = torch_models.PredModel(args, self.config).to(args.device)
         if self.args.interface_gen_model != '':
             args_gen = torch.load(self.args.interface_gen_model, map_location=loc)['args']
             args_gen.device = args.device
             args_gen.cpu = args.cpu
             args_gen.max_len = args.max_len
             self.gen_model = torch_models.GenModel(args_gen, self.config).to(args.device)
             self.gen_model.load(self.args.interface_gen_model)
         else:
             self.gen_model = torch_models.GenModel(args, self.config).to(args.device)
     self.bsi = gen_model_beam_search_lm.BeamSearchInterface(self.args, self.gen_model)
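
Unlike Example #2, this variant passes map_location to torch.load, so checkpoints written on a GPU can be restored on a CPU-only machine. A minimal illustration of the pattern; the file name is a placeholder:

import torch

loc = 'cpu'  # or 'cuda:0' when a GPU is available
checkpoint = torch.load('model.pt', map_location=loc)  # remap saved tensors onto the chosen device
args_model = checkpoint['args']
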
Example #4
0
def filter_stop_word(config):
    """
    Filter the cleaned content against the stop-word list.
    :param config: config information
    :return:
    """
    stop_word = get_stop_word()
    with open(config['content_filter'], 'w', encoding='UTF-8') as f:
        for line in open(config['content_clean'], encoding='UTF-8'):
            if line == '\n':
                f.write('\n')
            else:
                line = line.strip().split()
                line = [word for word in line if word not in stop_word]
                line = ' '.join(line)
                if line:
                    f.write(line + '\n')


if __name__ == '__main__':
    config = get_config()
    if not os.path.exists(config['url']) or not os.path.exists(
            config['title']) or not os.path.exists(config['content']):
        data_load(config)
    if not os.path.exists(config['content_clean']):
        data_clean_content(config)
    if not os.path.exists(config['content_filter']):
        filter_stop_word(config)
    if not os.path.exists(config['content_stemming']):
        stemming(config)
    if not os.path.exists(config['term_list']):
        create_term_list(config)
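
filter_stop_word in Example #4 assumes a get_stop_word() helper that returns the stop-word list; membership tests are cheapest against a set. A minimal sketch of such a helper, with a hypothetical file path:

def get_stop_word(path='./res/stop_words.txt'):
    # Hypothetical helper: one stop word per line, returned as a set for fast lookups
    with open(path, encoding='UTF-8') as f:
        return {line.strip() for line in f if line.strip()}
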
Example #5
0
import os
import sys
import log
import params
import constructor
import data_utils
import interface_lm
import time
import random
import torch
sys.setrecursionlimit(10000)

args = params.get_args()
_logger = log.get_logger(__name__, args)
_logger.info(data_utils.print_args(args))

lm = data_utils.load_language_model(new=True, iset=args.iset)
config = data_utils.get_config(lm)
_logger.info('load lm')

try:
    os.mkdir(args.expr_path)
except OSError:
    pass
fl = os.listdir(args.expr_path)
for s in fl:
    if 'finish_%d' % args.prover_id in s:
        os.remove(os.path.join(args.expr_path, s))

_logger.info('remove old logs')

if args.evaluate == 'none':
    interface = interface_lm.LMInterface(args, lm)
Example #6
0
 def __init__(self, args, lm, recalculate_props=True, directory='searcher'):
     self.lm = lm
     self.config = data_utils_new.get_config(lm)
     self.args = args