Example #1
0
def process_squad_tsv(fname, ignore_impossible, max_character_length,
                      min_overlap):
    """Read a SQuAD-style TSV file and yield per-piece example dicts.

    Each TSV row holds a document, a question, JSON-encoded answer and
    answer-start lists, and a ``has_answer`` flag; every row is split into
    pieces via ``_split_document`` and the resulting dicts are yielded.

    Args:
        fname: path to the TSV file; a falsy value aborts with a message.
        ignore_impossible: forwarded to ``_split_document``.
        max_character_length: forwarded to ``_split_document``.
        min_overlap: forwarded to ``_split_document``.

    Yields:
        dict: one piece per document split, as produced by
        ``_split_document``.
    """
    if not fname:
        # Plain literal — the original used an f-string with no placeholders.
        print("Empty file name!")
        return

    field_names = ["doc", "question", "answers", "answer_starts", "has_answer"]
    tsv_file = SafeFileWrapper(get_absolute_path(fname),
                               encoding="utf-8",
                               errors="replace")
    tsv = TSV(
        tsv_file,
        field_names=field_names,
        delimiter="\t",
        quoted=False,
        drop_incomplete_rows=True,
    )

    # ``row_idx`` rather than ``id`` to avoid shadowing the builtin.
    for row_idx, row in enumerate(tsv):
        doc, question, answers, answer_starts, has_answer = (
            row[f] for f in field_names)
        answers = json.loads(answers)
        answer_starts = json.loads(answer_starts)
        for piece_dict in _split_document(
                row_idx,
                doc,
                question,
                answers,
                answer_starts,
                has_answer == "True",
                ignore_impossible,
                max_character_length,
                min_overlap,
        ):
            yield piece_dict
Example #2
0
 def __init__(
     self,
     path: str,
     field_names: List[str] = None,
     delimiter: str = "\t",
     batch_size: int = 1,
     is_shuffle: bool = True,
     transforms_dict: Dict[str, List[Transform]] = None,
     batcher=None,
     collate_fn=None,
     chunk_size: int = 1000,
     is_cycle: bool = False,
     length: Optional[int] = None,
     rank: int = 0,
     num_workers: int = 1,
 ):
     """Open *path* as a TSV row stream and hand it to the base iterator.

     ``field_names`` falls back to ``["text", "label"]`` when not given;
     every other argument is forwarded unchanged to ``super().__init__``.
     """
     if field_names is None:
         field_names = ["text", "label"]
     # Keep the wrapper on ``self`` so the open file lives as long as we do.
     self.file = SafeFileWrapper(path, encoding="utf-8", errors="replace")
     rows = TSV(self.file, field_names=field_names, delimiter=delimiter)
     super().__init__(
         iterable=rows,
         batch_size=batch_size,
         is_shuffle=is_shuffle,
         transforms_dict=transforms_dict,
         batcher=batcher,
         collate_fn=collate_fn,
         chunk_size=chunk_size,
         is_cycle=is_cycle,
         length=length,
         rank=rank,
         num_workers=num_workers,
     )
Example #3
0
    def process_squad_tsv(self, fname):
        """Read a SQuAD-style TSV file and yield per-piece example dicts.

        Rows carry a document, a question (optionally a JSON list of
        paraphrases, one of which is sampled), JSON-encoded answer and
        answer-start lists, and a ``has_answer`` flag. Rows are split into
        pieces by ``_split_document`` using this instance's configuration
        (``self.delimiter``, ``self.quoted``, ``self.ignore_impossible``,
        ``self.max_character_length``, ``self.min_overlap``).

        Args:
            fname: path to the TSV file; a falsy value aborts with a message.

        Yields:
            dict: one piece per document split.
        """
        if not fname:
            print("Empty file name!")
            return

        field_names = [
            "doc", "question", "answers", "answer_starts", "has_answer"
        ]
        tsv_file = SafeFileWrapper(get_absolute_path(fname),
                                   encoding="utf-8",
                                   errors="replace")
        tsv = TSV(
            tsv_file,
            field_names=field_names,
            delimiter=self.delimiter,
            quoted=self.quoted,
            drop_incomplete_rows=True,
        )

        # ``row_idx`` rather than ``id`` to avoid shadowing the builtin.
        for row_idx, row in enumerate(tsv):
            parts = (row[f] for f in field_names)
            doc, question, answers, answer_starts, has_answer = parts
            try:
                # If the question field is a JSON list, it holds paraphrases;
                # pick one at random. Plain strings fail json.loads and are
                # used verbatim (JSONDecodeError subclasses ValueError).
                question = json.loads(question)
                if isinstance(question, list):
                    question = choice(question)
            except ValueError:
                pass
            answers = json.loads(answers)
            answer_starts = json.loads(answer_starts)

            # Impossible questions contribute no gold spans.
            if has_answer != "True":
                answers = []
                answer_starts = []

            for piece_dict in _split_document(
                    row_idx,
                    doc,
                    question,
                    answers,
                    answer_starts,
                    has_answer == "True",
                    self.ignore_impossible,
                    self.max_character_length,
                    self.min_overlap,
            ):
                yield piece_dict
Example #4
0
 def __init__(
     self,
     path: str,
     columns: List[Any] = None,
     column_mapping: Optional[Dict[str, str]] = None,
     delimiter: str = "\t",
     batch_size: Optional[int] = None,
     is_shuffle: bool = True,
     transform: Optional[nn.Module] = None,
     custom_batcher: Optional[Batcher] = None,
     collate_fn: Optional[Callable] = None,
     chunk_size: int = 1000,
     is_cycle: bool = False,
     length: Optional[int] = None,
     rank: int = 0,
     world_size: int = 1,
     *args,
     **kwargs,
 ):
     """Open *path* as a TSV row stream and hand it to the base iterator.

     ``columns`` defaults to ``["text", "label"]``. ``column_mapping`` is
     not yet supported and raises ``NotImplementedError`` when given.
     Remaining keyword arguments are forwarded to ``super().__init__``.

     NOTE(review): ``*args``/``**kwargs`` are accepted but never forwarded
     to the base class — confirm this is intentional.
     """
     # Lazy %-style args: the message is only formatted when DEBUG is on.
     logger.debug("init TsvDataset from: %s", path)
     columns = columns or ["text", "label"]
     if column_mapping:
         raise NotImplementedError(
             "column mapping is not supported for tsv yet!")
     self.file = SafeFileWrapper(path, encoding="utf-8", errors="replace")
     tsv_iterator = TSV(self.file, field_names=columns, delimiter=delimiter)
     super().__init__(
         iterable=tsv_iterator,
         batch_size=batch_size,
         is_shuffle=is_shuffle,
         transform=transform,
         custom_batcher=custom_batcher,
         collate_fn=collate_fn,
         chunk_size=chunk_size,
         is_cycle=is_cycle,
         length=length,
         rank=rank,
         world_size=world_size,
     )
Example #5
0
    def process_squad_tsv(self, fname):
        """Read a SQuAD TSV augmented with teacher outputs for KD.

        Besides the usual SQuAD fields, each row carries JSON-serialized
        teacher-model outputs (start/end/has-answer logits, pad mask,
        segment labels), which are attached to every yielded piece so a
        student model can be distilled against them.

        Args:
            fname: path to the TSV file; a falsy value aborts with a message.

        Yields:
            dict: one piece per document split, augmented with the
            teacher-output fields listed above.
        """
        if not fname:
            print("Empty file name!")
            return
        field_names = [
            "id1",
            "doc",
            "question",
            "answers",
            "answer_starts",
            "has_answer",
            "id2",
            "start_logits",
            "end_logits",
            "has_answer_logits",
            "pad_mask",
            "segment_labels",
        ]
        tsv_file = SafeFileWrapper(get_absolute_path(fname),
                                   encoding="utf-8",
                                   errors="replace")
        tsv = TSV(
            tsv_file,
            field_names=field_names,
            delimiter=self.delimiter,
            quoted=self.quoted,
            drop_incomplete_rows=True,
        )

        # ``row_idx`` rather than ``id`` to avoid shadowing the builtin.
        for row_idx, row in enumerate(tsv):
            parts = (row[f] for f in field_names)
            # All model outputs for KD were dumped with json serialization,
            # so every field is decoded with json.loads.
            # NOTE(review): id1/id2 are decoded but unused below — the
            # enumerate index (row_idx) is what _split_document receives;
            # confirm the row ids are intentionally ignored.
            (
                id1,
                doc,
                question,
                answers,
                answer_starts,
                has_answer,
                id2,
                start_logits,
                end_logits,
                has_answer_logits,
                pad_mask,
                segment_labels,
            ) = (json.loads(s) for s in parts)
            if isinstance(question, list):
                # A JSON list in the question field holds paraphrases.
                question = choice(question)
            for piece_dict in _split_document(
                    row_idx,
                    doc,
                    question,
                    answers,
                    answer_starts,
                    # NOTE(review): has_answer was json-decoded above, so
                    # this expects the dump to contain the string "True".
                    has_answer == "True",
                    self.ignore_impossible,
                    self.max_character_length,
                    self.min_overlap,
            ):
                piece_dict.update({
                    "start_logits": start_logits,
                    "end_logits": end_logits,
                    "has_answer_logits": has_answer_logits,
                    "pad_mask": pad_mask,
                    "segment_labels": segment_labels,
                })
                yield piece_dict