def run_sent_ret(config):
    train, dev = load_paper_dataset()
    with open('data/preprocessed_data/edocs.bin', 'rb') as rb:
        edocs = pickle.load(rb)
    with open(config['doc_ret_model'], 'rb') as rb:
        dmodel = pickle.load(rb)
    t2jnum = titles_to_jsonl_num()
    try:
        with open(config['sent_ret_model'], 'rb') as rb:
            model = pickle.load(rb)  # load model parameters
    except BaseException:
        try:
            selected = load_selected(config['sent_ret_line'])  # load sampled data
        except BaseException:
            docs = doc_ret(train, edocs, model=dmodel)
            selected = select_lines(docs, t2jnum, train, config['sent_ret_line'])
        model = sent_ret_model()
        X, y = model.process_train(selected, train)
        model.fit(X, y)  # train the model
        with open(config['sent_ret_model'], 'wb') as wb:
            pickle.dump(model, wb)
    docs = doc_ret(dev, edocs, model=dmodel)  # document retrieval
    lines = load_doc_lines(docs, t2jnum)
    evidence = sent_ret(dev, docs, lines, best=config['n_best'], model=model)  # sentence retrieval
    line_hits(dev, evidence)  # evaluate the results
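# A minimal sketch of the config dict run_sent_ret expects, inferred from the
# keys accessed above; the paths and values are hypothetical placeholders, not
# from the source.
example_sent_ret_config = {
    'doc_ret_model': 'data/doc_ret_model.bin',    # pickled document-retrieval model
    'sent_ret_model': 'data/sent_ret_model.bin',  # pickled sentence-retrieval model
    'sent_ret_line': 'data/sent_ret_lines.tsv',   # sampled training lines
    'n_best': 5,                                  # sentences to keep per claim
}
# run_sent_ret(example_sent_ret_config)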
def run_evi_ret(config):
    train, dev = load_paper_dataset(config['train_input'], dev=config['dev_input'])
    for split, data in [('train', train), ('dev', dev)]:
        if split == 'train':
            out_file = config['train_output']
        elif split == 'dev':
            out_file = config['dev_output']
        docs, evidence = evi_ret(data, n_docs=config['n_docs'], n_sents=config['n_sents'])
        pred = to_fever_format(data, docs, evidence)
        with open(out_file, 'w') as w:
            for example in pred:
                w.write(json.dumps(example, cls=NpEncoder) + '\n')
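# A minimal sketch of the config run_evi_ret reads; only the key names come
# from the function body above, the paths are hypothetical placeholders.
example_evi_ret_config = {
    'train_input': 'data/train.jsonl',
    'dev_input': 'data/dev.jsonl',
    'train_output': 'data/train.sentences.jsonl',
    'dev_output': 'data/dev.sentences.jsonl',
    'n_docs': 5,   # documents to retrieve per claim
    'n_sents': 5,  # sentences to retrieve per claim
}
# run_evi_ret(example_evi_ret_config)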
def run_doc_ret(config):
    train, dev = load_paper_dataset()
    if os.path.exists(config['doc_ret_model']):
        with open(config['doc_ret_model'], 'rb') as rb:
            model = pickle.load(rb)
    else:
        if os.path.exists(config['doc_ret_docs']):
            selected = load_selected(config['doc_ret_docs'])
        else:
            selected = sample_docs(train, config['doc_ret_docs'])
        # build the model
        model = doc_ret_model()
        # preprocess the training data
        X, y = model.process_train(selected, train)
        # train the model
        model.fit(X, y)
        # save the trained model
        with open(config['doc_ret_model'], 'wb') as wb:
            pickle.dump(model, wb)
    if os.path.exists('data/preprocessed_data/edocs.bin'):
        with open('data/preprocessed_data/edocs.bin', 'rb') as rb:
            edocs = pickle.load(rb)
    else:
        t2jnum = titles_to_jsonl_num()
        edocs = title_edict(t2jnum)
        with open('data/preprocessed_data/edocs.bin', 'wb') as wb:
            pickle.dump(edocs, wb)
    print(len(model.f2v))
    # retrieve documents for the dev set with the trained model
    docs = doc_ret(dev, edocs, best=config['n_best'], model=model)
    # evaluate the retrieval results
    title_hits(dev, docs)
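# A minimal sketch of the config run_doc_ret expects; key names are taken from
# the function above, the paths are hypothetical placeholders.
example_doc_ret_config = {
    'doc_ret_model': 'data/doc_ret_model.bin',  # cached model, trained if absent
    'doc_ret_docs': 'data/doc_ret_docs.tsv',    # cached sampled documents
    'n_best': 5,                                # documents to keep per claim
}
# run_doc_ret(example_doc_ret_config)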
    with open(fname) as f:
        for line in tqdm(f):
            fields = line.rstrip("\n").split("\t")
            cid = int(fields[0])
            yn = int(fields[1])
            t = fields[2]
            p = fields[3]
            s = int(fields[4])
            if cid not in selected:
                selected[cid] = dict()
            selected[cid][yn] = [t, p, s]
    return selected


if __name__ == "__main__":
    train, dev = load_paper_dataset()
    # train, dev = load_split_trainset(9999)
    try:
        with open("data/doc_ir_model.bin", "rb") as rb:
            model = pickle.load(rb)
    except BaseException:
        try:
            selected = load_selected()
        except BaseException:
            selected = select_docs(train)
        model = doc_ir_model()
        rdocs = dict()
        for example in tqdm(train):
            cid = example["id"]
            if cid in selected:
                claim = example["claim"]
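# The loader above expects a 5-column TSV record per line: a claim id, a 0/1
# flag, then two string fields t and p, and a sentence index s. A hedged
# sketch of writing one record in that format; the field semantics beyond
# what the parsing code shows are assumptions:
def write_selected_line(w, cid, yn, t, p, s):
    # mirrors load_selected's split(): int, int, str, str, int
    w.write("{}\t{}\t{}\t{}\t{}\n".format(cid, yn, t, p, s))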
parser.add_argument("--dev_output")
parser.add_argument("--test_input")
parser.add_argument("--test_output")
parser.add_argument("--n_docs", type=int, default=5, help="how many documents to retrieve")
parser.add_argument("--n_sents", type=int, default=5, help="how many sentences to retrieve")
args = parser.parse_args()
print(args)
train, dev, test = load_paper_dataset(train=args.train_input, dev=args.dev_input, test=args.test_input)
# train, dev = load_split_trainset(9999)
for split, data in [("train", train), ("dev", dev), ("test", test)]:
    if split == "train":
        out_file = args.train_output
    elif split == "dev":
        out_file = args.dev_output
    elif split == "test":
        out_file = args.test_output
    if os.path.exists(out_file):
        print("file {} exists. skipping ir...".format(out_file))
        continue
    docs, evidence = get_evidence(data, n_docs=args.n_docs, n_sents=args.n_sents)
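# A hypothetical invocation of this script's CLI, passing the flags defined
# above programmatically. The --train_input, --dev_input, and --train_output
# flags are assumed from the args.* attributes used above; all paths are
# placeholders, not from the source:
#
#   args = parser.parse_args([
#       "--train_input", "data/train.jsonl",
#       "--dev_input", "data/dev.jsonl",
#       "--test_input", "data/test.jsonl",
#       "--train_output", "data/train.sentences.jsonl",
#       "--dev_output", "data/dev.sentences.jsonl",
#       "--test_output", "data/test.sentences.jsonl",
#       "--n_docs", "5", "--n_sents", "5",
#   ])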