Пример #1
0
    def load_data_to(self, ctxs: Dict[object, BiEncoderPassage], date):
        """Fill ``ctxs`` with passages from the scans whose key contains ``date``.

        Only files whose path contains ``_<year>_`` (the year parsed from
        ``date``) are opened. Every title is passed through
        ``normalize_passage`` and lower-cased; every passage is passed through
        ``take_max_model_paragraphs`` and then ``normalize_passage``.

        Args:
            ctxs: Mapping updated in place, keyed by each layout object's id.
            date: Date string in ``%b-%d-%Y`` format. It is also used as a
                substring filter on the top-level keys of each JSON file.
        """
        parsed_date = datetime.strptime(date, "%b-%d-%Y")
        year_tag = f"_{parsed_date.year}_"

        tokenizer = BartTokenizerFast.from_pretrained("facebook/bart-base")

        print(f"Creating bi-encoder dict for {date}...")
        for path in tqdm(self.file_paths):
            if year_tag not in path:
                continue

            # Build one iterator per matching scan while the file is open;
            # the iterators themselves are consumed after it is closed.
            with open(path, 'rb') as f:
                scan_iters = [
                    self.ocr_text_iter(scan)
                    for key, scan in ijson.kvitems(f, '')
                    if date in key
                ]

            for scan_iter in scan_iters:
                for title, passage, object_id in scan_iter:
                    clean_title = normalize_passage(title).lower()
                    clipped = take_max_model_paragraphs(passage, tokenizer)
                    clean_passage = normalize_passage(clipped)
                    ctxs[object_id] = BiEncoderPassage(clean_passage, clean_title)
Пример #2
0
    def load_data_to(self, ctxs: Dict[object, BiEncoderPassage]):
        """Fill ``ctxs`` with passages read from ``self.file_paths``.

        Each file is streamed with ``ijson.kvitems``; every kept top-level
        value is handed to ``self.ocr_text_iter``, which yields
        ``(title, passage, object_id)`` triples. A scan is kept only when

        * ``self.month_str`` is falsy or is a substring of the scan key, and
        * ``self.n_random_papers`` is falsy or the scan's paper name (from
          ``self.get_paper_name``) is in the random sample.

        When ``self.n_random_papers`` is set, a first pass over all files
        collects the distinct paper names and a sample of that size is drawn
        after seeding the global ``random`` module with 789.

        When ``self.normalize`` is set, titles are passed through
        ``normalize_passage`` and lower-cased, and passages are passed through
        ``take_max_model_paragraphs`` and then ``normalize_passage``.

        Args:
            ctxs: Mapping updated in place, keyed by each layout object's id.
        """
        tokenizer = BartTokenizerFast.from_pretrained("facebook/bart-base")

        # ``None`` means "no paper filter"; otherwise a set for O(1) lookups.
        selected_papers = None
        if self.n_random_papers:
            print("Random newspaper subset...")
            paper_names = set()
            for file_path in tqdm(self.file_paths):
                with open(file_path, 'rb') as f:
                    for scan_name, _ in ijson.kvitems(f, ''):
                        paper_names.add(self.get_paper_name(scan_name))
            # Sort first so the seeded sample does not depend on set ordering.
            papers = sorted(paper_names)
            print(f"{len(papers)} total papers...")

            random.seed(789)
            random_papers = random.sample(papers, self.n_random_papers)
            print(f"Selected random papers: {random_papers}")
            selected_papers = set(random_papers)

        print("Creating bi-encoder dict...")
        for file_path in tqdm(self.file_paths):

            with open(file_path, 'rb') as f:
                ocr_text_generators = []
                for scan_name, scan in ijson.kvitems(f, ''):
                    if self.month_str and self.month_str not in scan_name:
                        continue
                    if (selected_papers is not None
                            and self.get_paper_name(scan_name) not in selected_papers):
                        continue
                    ocr_text_generators.append(self.ocr_text_iter(scan))

            for gen in ocr_text_generators:
                for title, passage, object_id in gen:
                    if self.normalize:
                        title = normalize_passage(title).lower()
                        passage = take_max_model_paragraphs(passage, tokenizer)
                        passage = normalize_passage(passage)
                    ctxs[object_id] = BiEncoderPassage(passage, title)
Пример #3
0
    def load_data_to(self, ctxs: Dict[object, BiEncoderPassage]):
        """Read the JSON-lines file at ``self.file`` into ``ctxs``.

        With ``self.hypotheses`` set, the first entry of each of the
        ``positive_ctxs``, ``negative_ctxs`` and ``hard_negative_ctxs`` lists
        becomes a passage keyed by that entry's title. Otherwise each record's
        ``question`` becomes a passage keyed by the first positive context's
        title with its final character removed.

        Passages go through ``normalize_passage`` when ``self.normalize`` is
        set and are cut to ``self.passage_char_max`` characters. The key is
        also passed as the second ``BiEncoderPassage`` argument.

        Args:
            ctxs: Mapping updated in place.
        """
        ctx_fields = ('positive_ctxs', 'negative_ctxs', 'hard_negative_ctxs')

        with jsonlines.open(self.file, mode="r") as jsonl_reader:
            use_hypotheses = self.hypotheses
            for record in jsonl_reader:
                if use_hypotheses:
                    # Lazy on purpose: a field is only read when its turn
                    # comes, after the previous one has been stored.
                    entries = (
                        (record[field][0]['title'], record[field][0]['text'])
                        for field in ctx_fields
                    )
                else:
                    entries = (
                        (record['positive_ctxs'][0]['title'][:-1], record['question']),
                    )

                for uid, passage in entries:
                    if self.normalize:
                        passage = normalize_passage(passage)
                    ctxs[uid] = BiEncoderPassage(passage[:self.passage_char_max], uid)
Пример #4
0
    def load_data_to(self, ctxs: Dict[object, BiEncoderPassage]):
        """Load one split of a ``datasets`` dataset into ``ctxs``.

        The dataset is fetched with ``load_dataset(self.dataset,
        split=self.split)``. Each example is keyed ``"<index>-class<label>"``
        and that key is also passed as the second ``BiEncoderPassage``
        argument. The example's ``text`` goes through ``normalize_passage``
        when ``self.normalize`` is set and is cut to
        ``self.passage_char_max`` characters.

        Args:
            ctxs: Mapping updated in place.
        """
        # Imported here so `datasets` is only needed when this method runs.
        from datasets import load_dataset

        examples = load_dataset(self.dataset, split=self.split)

        for position, example in enumerate(examples):
            uid = f"{position}-class{example['label']}"
            text = example['text']
            text = normalize_passage(text) if self.normalize else text
            ctxs[uid] = BiEncoderPassage(text[:self.passage_char_max], uid)
Пример #5
0
 def load_data_to(self, ctxs: Dict[object, BiEncoderPassage]):
     """Load passages from the tab-separated file at ``self.file_path``.

     Columns are selected with ``self.id_col``, ``self.text_col`` and
     ``self.title_col``. Rows whose id column is the literal ``"id"`` are
     skipped (presumably the header line). When ``self.id_prefix`` is set it
     is prepended to the id; when ``self.normalize`` is set the text goes
     through ``normalize_passage``.

     Args:
         ctxs: Mapping updated in place, keyed by the (optionally prefixed)
             row id.
     """
     # The csv module requires files opened with newline="" so that it can
     # do its own line-ending handling (e.g. newlines inside quoted fields).
     with open(self.file_path, newline="") as ifile:
         for row in csv.reader(ifile, delimiter="\t"):
             row_id = row[self.id_col]
             if row_id == "id":
                 continue
             sample_id = self.id_prefix + str(row_id) if self.id_prefix else row_id
             passage = row[self.text_col]
             if self.normalize:
                 passage = normalize_passage(passage)
             ctxs[sample_id] = BiEncoderPassage(passage, row[self.title_col])
Пример #6
0
    def load_data_to(self, ctxs: Dict[object, BiEncoderPassage]):
        """Collect every ``'article'`` layout entry from ``self.file_paths``.

        Each file is read with ``ijson.kvitems`` and every top-level value is
        iterated as a sequence of entries. Entries whose ``'label'`` equals
        ``'article'`` are stored under ``"<object_id>_<image_file_name>"``,
        with the image file name passed as the second ``BiEncoderPassage``
        argument. The ``'ocr_text'`` goes through ``normalize_passage`` when
        ``self.normalize`` is set and is cut to ``self.passage_char_max``
        characters.

        Args:
            ctxs: Mapping updated in place.
        """
        for path in self.file_paths:
            # Pull the per-scan entry collections out while the file is open.
            with open(path, 'rb') as f:
                scans = [entries for _, entries in ijson.kvitems(f, '')]

            for entries in scans:
                for entry in entries:
                    if entry['label'] != 'article':
                        continue
                    title = entry['image_file_name']
                    passage = entry['ocr_text']
                    uid = str(entry['object_id']) + '_' + title
                    if self.normalize:
                        passage = normalize_passage(passage)
                    ctxs[uid] = BiEncoderPassage(passage[:self.passage_char_max], title)
Пример #7
0
 def load_data_to(self, ctxs: Dict[object, BiEncoderPassage]):
     """Load passages from the tab-separated file at ``self.file``.

     ``super().load_data()`` is called first. Columns are selected with
     ``self.id_col``, ``self.text_col`` and ``self.title_col``. Rows whose id
     column is the literal ``"id"`` are skipped (presumably the header
     line). When ``self.id_prefix`` is set it is prepended to the id.
     Leading and trailing double quotes are stripped from the text, which
     then goes through ``normalize_passage`` when ``self.normalize`` is set.

     Args:
         ctxs: Mapping updated in place, keyed by the (optionally prefixed)
             row id.
     """
     super().load_data()
     logger.info("Reading file %s", self.file)
     # The csv module requires files opened with newline="" so that it can
     # do its own line-ending handling (e.g. newlines inside quoted fields).
     with open(self.file, newline="") as ifile:
         for row in csv.reader(ifile, delimiter="\t"):
             row_id = row[self.id_col]
             if row_id == "id":
                 continue
             sample_id = self.id_prefix + str(row_id) if self.id_prefix else row_id
             passage = row[self.text_col].strip('"')
             if self.normalize:
                 passage = normalize_passage(passage)
             ctxs[sample_id] = BiEncoderPassage(passage, row[self.title_col])