def load_data_to(self, ctxs: Dict[object, BiEncoderPassage], date):
    year = "_" + str(datetime.strptime(date, "%b-%d-%Y").year) + "_"
    tokenizer = BartTokenizerFast.from_pretrained("facebook/bart-base")
    print(f"Creating bi-encoder dict for {date}...")
    for file_path in tqdm(self.file_paths):
        if year in file_path:
            with open(file_path, 'rb') as f:
                items = ijson.kvitems(f, '')
                ocr_text_generators = []
                for k, v in items:
                    if date in k:
                        ocr_text_generators.append(self.ocr_text_iter(v))
                if len(ocr_text_generators) == 0:
                    continue
                for gen in ocr_text_generators:
                    for layobj in gen:
                        title, passage, object_id = layobj
                        uid = object_id
                        title = normalize_passage(title)
                        title = title.lower()
                        passage = take_max_model_paragraphs(passage, tokenizer)
                        passage = normalize_passage(passage)
                        ctxs[uid] = BiEncoderPassage(passage, title)
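# The loader above truncates OCR text with take_max_model_paragraphs before
# normalization. Below is a minimal sketch of what such a helper could look
# like; it is an illustrative assumption (including the 1024-token budget,
# BART's maximum input length), not this repository's implementation.
def _take_max_model_paragraphs_sketch(text, tokenizer, max_tokens=1024):
    # Keep whole paragraphs until adding the next one would exceed the budget.
    kept, total = [], 0
    for paragraph in text.split("\n"):
        n_tokens = len(tokenizer.tokenize(paragraph))
        if kept and total + n_tokens > max_tokens:
            break
        kept.append(paragraph)
        total += n_tokens
    return "\n".join(kept)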
def load_data_to(self, ctxs: Dict[object, BiEncoderPassage]):
    with jsonlines.open(self.file, mode="r") as jsonl_reader:
        if self.hypotheses:
            for jline in jsonl_reader:
                for k in ['positive_ctxs', 'negative_ctxs', 'hard_negative_ctxs']:
                    uid = jline[k][0]['title']
                    passage = jline[k][0]['text']
                    if self.normalize:
                        passage = normalize_passage(passage)
                    ctxs[uid] = BiEncoderPassage(passage[:self.passage_char_max], uid)
        else:
            for jline in jsonl_reader:
                uid = jline['positive_ctxs'][0]['title'][:-1]
                passage = jline['question']
                if self.normalize:
                    passage = normalize_passage(passage)
                ctxs[uid] = BiEncoderPassage(passage[:self.passage_char_max], uid)
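# Illustrative sketch of a single line in the jsonlines file read above, based
# only on the keys the loader accesses ('question', 'positive_ctxs',
# 'negative_ctxs', 'hard_negative_ctxs', each context carrying 'title' and
# 'text'); the values are invented for illustration.
example_jline = {
    "question": "Who composed the opera Carmen?",
    "positive_ctxs": [{"title": "carmen_p0", "text": "Carmen is an opera composed by Georges Bizet."}],
    "negative_ctxs": [{"title": "aida_p0", "text": "Aida is an opera by Giuseppe Verdi."}],
    "hard_negative_ctxs": [{"title": "bizet_p3", "text": "Bizet also wrote the opera The Pearl Fishers."}],
}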
def load_data_to(self, ctxs: Dict[object, BiEncoderPassage]):
    from datasets import load_dataset
    hfdataset = load_dataset(self.dataset, split=self.split)
    for idx, spl in enumerate(hfdataset):
        uid = str(idx) + '-class' + str(spl['label'])
        passage = spl['text']
        if self.normalize:
            passage = normalize_passage(passage)
        ctxs[uid] = BiEncoderPassage(passage[:self.passage_char_max], uid)
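# Hedged usage sketch for the Hugging Face loader above: assuming self.dataset
# and self.split point at a text-classification dataset with 'text' and
# 'label' columns ("ag_news" is only an example here), uids take the form
# "<row index>-class<label>".
from datasets import load_dataset

ds = load_dataset("ag_news", split="test")
first = ds[0]
uid = "0-class" + str(first["label"])  # e.g. "0-class2"
print(uid, first["text"][:80])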
def load_data_to(self, ctxs: Dict[object, BiEncoderPassage]):
    tokenizer = BartTokenizerFast.from_pretrained("facebook/bart-base")
    if self.n_random_papers:
        # Collect every scan name so a fixed random subset of papers can be drawn.
        print("Random newspaper subset...")
        scan_names = []
        for file_path in tqdm(self.file_paths):
            with open(file_path, 'rb') as f:
                items = ijson.kvitems(f, '')
                for k, v in items:
                    scan_names.append(k)
        papers = list(set([self.get_paper_name(scan) for scan in scan_names]))
        papers.sort()
        print(f"{len(papers)} total papers...")
        random.seed(789)
        random_papers = random.sample(papers, self.n_random_papers)
        print(f"Selected random papers: {random_papers}")

    print("Creating bi-encoder dict...")
    for file_path in tqdm(self.file_paths):
        with open(file_path, 'rb') as f:
            items = ijson.kvitems(f, '')
            ocr_text_generators = []
            for k, v in items:
                # Optionally restrict scans to a given month and/or the random paper subset.
                if self.month_str:
                    if self.month_str in k:
                        if self.n_random_papers:
                            if self.get_paper_name(k) in random_papers:
                                ocr_text_generators.append(self.ocr_text_iter(v))
                        else:
                            ocr_text_generators.append(self.ocr_text_iter(v))
                else:
                    if self.n_random_papers:
                        if self.get_paper_name(k) in random_papers:
                            ocr_text_generators.append(self.ocr_text_iter(v))
                    else:
                        ocr_text_generators.append(self.ocr_text_iter(v))
            if len(ocr_text_generators) == 0:
                continue
            for gen in ocr_text_generators:
                for layobj in gen:
                    title, passage, object_id = layobj
                    uid = object_id
                    if self.normalize:
                        title = normalize_passage(title)
                        title = title.lower()
                        passage = take_max_model_paragraphs(passage, tokenizer)
                        passage = normalize_passage(passage)
                    ctxs[uid] = BiEncoderPassage(passage, title)
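# Hedged sketch of the ocr_text_iter helper the two newspaper loaders rely on,
# modeled on the inline generator expression used by a loader further below in
# this file (yield image_file_name, ocr_text, object_id for layout objects
# labeled 'article'); the actual helper may differ.
def _ocr_text_iter_sketch(v):
    for ik in v:
        if ik['label'] == 'article':
            yield ik['image_file_name'], ik['ocr_text'], ik['object_id']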
def load_data_to(self, ctxs: Dict[object, BiEncoderPassage]):
    with open(self.file_path) as ifile:
        reader = csv.reader(ifile, delimiter="\t")
        for row in reader:
            if row[self.id_col] == "id":
                continue
            if self.id_prefix:
                sample_id = self.id_prefix + str(row[self.id_col])
            else:
                sample_id = row[self.id_col]
            passage = row[self.text_col]
            if self.normalize:
                passage = normalize_passage(passage)
            ctxs[sample_id] = BiEncoderPassage(passage, row[self.title_col])
def load_data_to(self, ctxs: Dict[object, BiEncoderPassage]):
    for file_path in self.file_paths:
        with open(file_path, 'rb') as f:
            items = ijson.kvitems(f, '')
            ocr_text_generators = [
                ((ik['image_file_name'], ik['ocr_text'], ik['object_id'])
                 for ik in v if ik['label'] == 'article')
                for k, v in items
            ]
            for gen in ocr_text_generators:
                for layobj in gen:
                    title, passage, object_id = layobj
                    uid = str(object_id) + '_' + title
                    if self.normalize:
                        passage = normalize_passage(passage)
                    ctxs[uid] = BiEncoderPassage(passage[:self.passage_char_max], title)
def load_data_to(self, ctxs: Dict[object, BiEncoderPassage]):
    super().load_data()
    logger.info("Reading file %s", self.file)
    with open(self.file) as ifile:
        reader = csv.reader(ifile, delimiter="\t")
        for row in reader:
            if row[self.id_col] == "id":
                continue
            if self.id_prefix:
                sample_id = self.id_prefix + str(row[self.id_col])
            else:
                sample_id = row[self.id_col]
            passage = row[self.text_col].strip('"')
            if self.normalize:
                passage = normalize_passage(passage)
            ctxs[sample_id] = BiEncoderPassage(passage, row[self.title_col])
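# Minimal sketch of the tab-separated context file both CSV loaders above
# expect, assuming the common DPR-style column order id / text / title;
# the concrete rows and file name are invented for illustration.
import csv

sample_rows = [
    ["id", "text", "title"],
    ["wiki:1", "Aaron Aaron or Aharon is a prophet and high priest ...", "Aaron"],
    ["wiki:2", "God at Sinai granted Aaron the priesthood ...", "Aaron"],
]
with open("ctx_sample.tsv", "w", newline="") as ofile:
    csv.writer(ofile, delimiter="\t").writerows(sample_rows)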
def dump_passages(cfg, encoder, tensorizer):
    logger.info(f"Pair file from {cfg.pair_file}")
    passage_pairs = json.load(open(cfg.pair_file))['data']
    questions = [pp['question'] for pp in passage_pairs]
    answers = [pp['answer'][0] for pp in passage_pairs]
    answers_start = [pp['answer'][1] for pp in passage_pairs]
    titles = [pp['title'] for pp in passage_pairs]
    gold_passages = [pp['gold_passage'] for pp in passage_pairs]
    entail_neg_passages = [pp['entail_neg_passage'] for pp in passage_pairs]
    neg_titles = [pp['neg_title'] for pp in passage_pairs]
    topic_neg_passages = [pp['topic_neg_passage'] for pp in passage_pairs]

    stats = {}

    logger.info("***** Processing gold passages *****")
    gold_passages = [
        (p_id, BiEncoderPassage(title=title, text=passage))
        for p_id, (title, passage) in enumerate(zip(titles, gold_passages))
    ]
    data = gen_ctx_vectors(cfg, gold_passages, questions, encoder, tensorizer, True)
    for p_id, _, _, gold_score in data:
        if p_id not in stats:
            stats[p_id] = {
                'gold_score': -1e9,
                'topic_neg_score': -1e9,
                'entail_neg_score': -1e9,
            }
        stats[p_id]['gold_score'] = max(gold_score, stats[p_id]['gold_score'])

    logger.info("***** Processing topic neg passages *****")
    topic_neg_passages = [
        (p_id, BiEncoderPassage(title=title, text=passage))
        for p_id, (title, passage) in enumerate(zip(neg_titles, topic_neg_passages))
    ]
    data = gen_ctx_vectors(cfg, topic_neg_passages, questions, encoder, tensorizer, True)
    for p_id, _, _, topic_neg_score in data:
        if p_id not in stats:
            continue
        stats[p_id]['topic_neg_score'] = max(topic_neg_score, stats[p_id]['topic_neg_score'])

    logger.info("***** Processing entail neg passages *****")
    entail_neg_passages = [
        (p_id, BiEncoderPassage(title=title, text=passage))
        for p_id, (title, passage) in enumerate(zip(titles, entail_neg_passages))
    ]
    data = gen_ctx_vectors(cfg, entail_neg_passages, questions, encoder, tensorizer, True)
    for p_id, _, _, entail_neg_score in data:
        if p_id not in stats:
            continue
        stats[p_id]['entail_neg_score'] = max(entail_neg_score, stats[p_id]['entail_neg_score'])

    # Every entry should have been overwritten with a real score by now.
    if not all(all(val > -999 for val in score.values()) for score in stats.values()):
        import pdb
        pdb.set_trace()

    gold_mean = sum(stat['gold_score'] for stat in stats.values()) / len(stats)
    topic_mean = sum(stat['topic_neg_score'] for stat in stats.values()) / len(stats)
    entail_mean = sum(stat['entail_neg_score'] for stat in stats.values()) / len(stats)
    # Two-way cross-entropy of the gold passage against each type of negative.
    L_topic = sum(
        -torch.nn.functional.log_softmax(
            torch.Tensor([stat['gold_score'], stat['topic_neg_score']]), dim=-1)[0]
        for stat in stats.values()
    ) / len(stats)
    L_hard = sum(
        -torch.nn.functional.log_softmax(
            torch.Tensor([stat['gold_score'], stat['entail_neg_score']]), dim=-1)[0]
        for stat in stats.values()
    ) / len(stats)

    logger.info(
        f'gold mean: {gold_mean:.2f}, topic mean: {topic_mean:.2f}, entail mean: {entail_mean:.2f}'
    )
    logger.info(
        f'topical relevance: {gold_mean - topic_mean:.2f}, fine-grained entailment: {gold_mean - entail_mean:.2f}'
    )
    logger.info(f'L_topic: {L_topic:.4f}, L_hard: {L_hard:.4f}')
    logger.info(f"Analysis done for {len(passage_pairs)} (processed={len(stats)})")
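# Illustrative sketch of the pair file dump_passages reads, inferred from the
# fields it accesses under the top-level 'data' key; the concrete values are
# invented for illustration.
example_pair_file = {
    "data": [
        {
            "question": "Who composed the opera Carmen?",
            "answer": ["Georges Bizet", 35],  # [answer text, answer start offset]
            "title": "Carmen",
            "gold_passage": "Carmen is an opera in four acts by Georges Bizet.",
            "entail_neg_passage": "Carmen premiered at the Opera-Comique in 1875.",
            "neg_title": "La traviata",
            "topic_neg_passage": "La traviata is an opera in three acts by Verdi.",
        }
    ]
}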