Example #1
    @classmethod
    def from_config(cls, config: Config, schema: Dict[str, Type], **kwargs):
        # Unpack the config and pull out the file-name fields; whatever is
        # left is forwarded to the constructor as keyword arguments.
        args = config._asdict()
        train_filename = args.pop("train_filename")
        test_filename = args.pop("test_filename")
        eval_filename = args.pop("eval_filename")

        # Resolve each configured path to an absolute path and wrap it for
        # tolerant UTF-8 reading; unset file names stay None.
        train_file = (SafeFileWrapper(get_absolute_path(train_filename),
                                      encoding="utf-8",
                                      errors="replace")
                      if train_filename else None)
        test_file = (SafeFileWrapper(get_absolute_path(test_filename),
                                     encoding="utf-8",
                                     errors="replace")
                     if test_filename else None)
        eval_file = (SafeFileWrapper(get_absolute_path(eval_filename),
                                     encoding="utf-8",
                                     errors="replace")
                     if eval_filename else None)
        return cls(
            train_file=train_file,
            test_file=test_file,
            eval_file=eval_file,
            schema=schema,
            **args,
            **kwargs,
        )
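
For context, `config._asdict()` implies a NamedTuple-style config. The following standalone sketch (field names are illustrative, not the library's real Config) shows the same unpack-and-pop pattern:

from typing import NamedTuple

class DemoConfig(NamedTuple):
    # Illustrative fields only, not the library's actual Config.
    train_filename: str = "train.tsv"
    eval_filename: str = "eval.tsv"
    delimiter: str = "\t"

args = DemoConfig()._asdict()
train_filename = args.pop("train_filename")  # "train.tsv"
eval_filename = args.pop("eval_filename")    # "eval.tsv"
print(dict(args))  # only the leftover fields remain and are forwarded as **args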
Example #2
    def load_vocab(self,
                   vocab_file,
                   vocab_size,
                   lowercase_tokens: bool = False):
        """
        Loads items into a set from a file containing one item per line.
        Items are added to the set from top of the file to bottom.
        So, the items in the file should be ordered by a preference (if any), e.g.,
        it makes sense to order tokens in descending order of frequency in corpus.

        Args:
            vocab_file (str): vocab file to load
            vocab_size (int): maximum tokens to load, will only load the first n if
                the actual vocab size is larger than this parameter
            lowercase_tokens (bool): if the tokens should be lowercased
        """
        vocab: Set[str] = set()
        vocab_file = get_absolute_path(vocab_file)
        if PathManager.isfile(vocab_file):
            with PathManager.open(vocab_file, "r") as f:
                for i, line in enumerate(f):
                    if vocab_size > 0 and len(vocab) == vocab_size:
                        print(f"Read {i+1} items from {vocab_file} "
                              f"to load vocab of size {vocab_size}. "
                              f"Skipping rest of the file")
                        break
                    line = line.strip()
                    vocab.add(line.lower() if lowercase_tokens else line)
        else:
            # The file name is empty or the path does not exist.
            print(
                f"{vocab_file} doesn't exist. Cannot load vocabulary from it")
        return vocab
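
A self-contained sketch of the same truncation and lowercasing behavior, using an in-memory list of lines instead of PathManager (purely illustrative):

from typing import Iterable, Set

def load_vocab_from_lines(lines: Iterable[str], vocab_size: int = 0,
                          lowercase_tokens: bool = False) -> Set[str]:
    # Same logic as above minus the file handling: keep at most `vocab_size`
    # tokens (0 means unlimited), optionally lowercased.
    vocab: Set[str] = set()
    for line in lines:
        if vocab_size > 0 and len(vocab) == vocab_size:
            break
        token = line.strip()
        vocab.add(token.lower() if lowercase_tokens else token)
    return vocab

# Lines are assumed to be ordered by descending corpus frequency,
# so only the first two distinct tokens are kept here.
print(load_vocab_from_lines(["the", "And", "cat"], vocab_size=2, lowercase_tokens=True))
# {'the', 'and'}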
Example #3
def process_squad_tsv(fname, ignore_impossible, max_character_length,
                      min_overlap):
    if not fname:
        print(f"Empty file name!")
        return

    field_names = ["doc", "question", "answers", "answer_starts", "has_answer"]
    tsv_file = SafeFileWrapper(get_absolute_path(fname),
                               encoding="utf-8",
                               errors="replace")
    tsv = TSV(
        tsv_file,
        field_names=field_names,
        delimiter="\t",
        quoted=False,
        drop_incomplete_rows=True,
    )

    for id, row in enumerate(tsv):
        doc, question, answers, answer_starts, has_answer = (
            row[f] for f in field_names)
        answers = json.loads(answers)
        answer_starts = json.loads(answer_starts)
        for piece_dict in _split_document(
                id,
                doc,
                question,
                answers,
                answer_starts,
                has_answer == "True",
                ignore_impossible,
                max_character_length,
                min_overlap,
        ):
            yield piece_dict
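
The `answers` and `answer_starts` columns are JSON-encoded lists. A minimal sketch of the per-row decoding, with made-up row values:

import json

row = {"answers": '["a made-up answer"]', "answer_starts": "[42]", "has_answer": "True"}
answers = json.loads(row["answers"])              # ['a made-up answer']
answer_starts = json.loads(row["answer_starts"])  # [42]
is_answerable = row["has_answer"] == "True"       # True
print(answers, answer_starts, is_answerable)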
Example #4
    def DISABLED_test_load_all_configs(self):
        """
        Try to load all the json files in pytext to make sure we didn't
        break the config API.
        """
        print()
        exclude_json_path = {get_absolute_path(p) for p in EXCLUDE_JSON}
        exclude_json_dir = {get_absolute_path(p) for p in EXCLUDE_DIRS}
        for filename in glob.iglob("./**/*.json", recursive=True):
            filepath = get_absolute_path(filename)
            if filepath in exclude_json_path:
                continue
            if any(filepath.startswith(prefix) for prefix in exclude_json_dir):
                continue
            print("--- loading:", filepath)
            with open(filepath) as file:
                config_json = json.load(file)
                config = parse_config(config_json)
                self.assertIsNotNone(config)
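
A stripped-down, runnable sketch of the same scan, with the pytext-specific parse_config call replaced by a plain json.load and the exclusion sets omitted:

import glob
import json
import os

for filename in glob.iglob("./**/*.json", recursive=True):
    filepath = os.path.abspath(filename)
    print("--- loading:", filepath)
    with open(filepath) as file:
        json.load(file)  # raises ValueError if the file is not valid JSON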
Example #5
    def __init__(
        self,
        path: str,
        column_names: Optional[List[str]] = None,
        delimiter: str = "\t",
        **kwargs,
    ):
        self.path = get_absolute_path(path)
        self.column_names = column_names
        self.delimiter = delimiter
        # `self` is already bound here; passing it again to the parent would
        # hand it a duplicate positional argument.
        super().__init__(**kwargs)
Example #6
    def process_squad_tsv(self, fname):
        if not fname:
            print("Empty file name!")
            return

        field_names = [
            "doc", "question", "answers", "answer_starts", "has_answer"
        ]
        tsv_file = SafeFileWrapper(get_absolute_path(fname),
                                   encoding="utf-8",
                                   errors="replace")
        tsv = TSV(
            tsv_file,
            field_names=field_names,
            delimiter=self.delimiter,
            quoted=self.quoted,
            drop_incomplete_rows=True,
        )

        for id, row in enumerate(tsv):
            parts = (row[f] for f in field_names)
            doc, question, answers, answer_starts, has_answer = parts
            try:
                # if we have paraphrases for question
                question = json.loads(question)
                if isinstance(question, list):
                    question = choice(question)
            except ValueError:
                pass
            answers = json.loads(answers)
            answer_starts = json.loads(answer_starts)

            if has_answer != "True":
                answers = []
                answer_starts = []

            for piece_dict in _split_document(
                    id,
                    doc,
                    question,
                    answers,
                    answer_starts,
                    has_answer == "True",
                    self.ignore_impossible,
                    self.max_character_length,
                    self.min_overlap,
            ):
                yield piece_dict
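
A small sketch of the paraphrase handling above: if the question column holds a JSON list, one paraphrase is chosen at random; anything that is not valid JSON is used as-is (example strings are made up):

import json
from random import choice

def resolve_question(raw: str):
    question = raw
    try:
        question = json.loads(raw)
        if isinstance(question, list):
            question = choice(question)  # pick one paraphrase at random
    except ValueError:
        pass  # plain, non-JSON question text is kept unchanged
    return question

print(resolve_question('["How many?", "What is the count?"]'))
print(resolve_question("Who wrote it?"))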
Example #7
    def __init__(
        self,
        embeddings_path: Optional[str] = None,
        lowercase_tokens: bool = True,
        skip_header: bool = True,
        delimiter: str = " ",
    ) -> None:
        self.lowercase_tokens = lowercase_tokens
        if embeddings_path:
            embeddings_path = get_absolute_path(embeddings_path)
            if PathManager.isdir(embeddings_path):
                serialized_embed_path = os.path.join(
                    embeddings_path, PackageFileName.SERIALIZED_EMBED)
                raw_embeddings_path = os.path.join(embeddings_path,
                                                   PackageFileName.RAW_EMBED)
            elif PathManager.isfile(embeddings_path):
                serialized_embed_path = ""
                raw_embeddings_path = embeddings_path
            else:
                raise FileNotFoundError(
                    f"{embeddings_path} not found. Can't load pretrained embeddings."
                )

            # Prefer the serialized cache when it exists; fall back to parsing
            # the raw embeddings file if the cache is missing or unreadable.
            if PathManager.isfile(serialized_embed_path):
                try:
                    self.load_cached_embeddings(serialized_embed_path)
                except Exception:
                    print(
                        "Failed to load cached embeddings, loading the raw file."
                    )
                    self.load_pretrained_embeddings(
                        raw_embeddings_path,
                        lowercase_tokens=lowercase_tokens,
                        skip_header=skip_header,
                        delimiter=delimiter,
                    )
            else:
                self.load_pretrained_embeddings(
                    raw_embeddings_path,
                    lowercase_tokens=lowercase_tokens,
                    skip_header=skip_header,
                    delimiter=delimiter,
                )
        else:
            self.embed_vocab = []  # type: List[str]
            self.stoi = {}  # type: Dict[str, int]
            self.embedding_vectors = None  # type: torch.Tensor
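
A rough sketch of the directory-versus-file dispatch above using plain os.path; the file names inside the directory are placeholders, not the library's PackageFileName constants:

import os

embeddings_path = "embeddings"  # illustrative path
if os.path.isdir(embeddings_path):
    # A directory is expected to hold both a serialized cache and the raw file.
    serialized_embed_path = os.path.join(embeddings_path, "cached_embeddings.bin")
    raw_embeddings_path = os.path.join(embeddings_path, "raw_embeddings.txt")
elif os.path.isfile(embeddings_path):
    # A single file is treated as the raw embeddings file; no cache is assumed.
    serialized_embed_path = ""
    raw_embeddings_path = embeddings_path
else:
    raise FileNotFoundError(f"{embeddings_path} not found.")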
Example #8
    def read_from_file(
        self, file_name: str, columns_to_use: Union[Dict[str, int], List[str]]
    ) -> Generator[Dict, None, None]:
        """
        Read data from csv file. Input file format is required to be
        tab-separated columns

        Args:
            file_name (str): csv file name
            columns_to_use (Union[Dict[str, int], List[str]]): either a list of
                column names or a dict of column name -> column index in the file
        """
        file_name = get_absolute_path(file_name)
        print("reading data from {}".format(file_name))
        if isinstance(columns_to_use, list):
            columns_to_use = {
                name: idx for idx, name in enumerate(columns_to_use)
            }

        with PathManager.open(file_name,
                              "r",
                              encoding="utf-8",
                              errors="replace") as f_handle:
            csv_reader = csv.reader(f_handle,
                                    delimiter="\t",
                                    quoting=csv.QUOTE_NONE)
            i = 0
            while True:
                i += 1
                try:
                    row = next(csv_reader)
                except csv.Error:
                    print("ignoring line {}".format(i))
                    continue
                except StopIteration:
                    break

                yield {
                    name: row[index] if index < len(row) else ""
                    for name, index in columns_to_use.items()
                }
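
A minimal sketch of how a column-name list becomes a name -> index mapping and is applied to one tab-split row (values are illustrative):

columns_to_use = ["text", "label"]  # illustrative column names
columns_to_use = {name: idx for idx, name in enumerate(columns_to_use)}

row = ["hello world", "greeting"]   # one parsed line
record = {
    name: row[index] if index < len(row) else ""
    for name, index in columns_to_use.items()
}
print(record)  # {'text': 'hello world', 'label': 'greeting'}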
Example #9
    def process_squad_tsv(self, fname):
        # Process SQUAD TSV for KD
        if not fname:
            print("Empty file name!")
            return
        field_names = [
            "id1",
            "doc",
            "question",
            "answers",
            "answer_starts",
            "has_answer",
            "id2",
            "start_logits",
            "end_logits",
            "has_answer_logits",
            "pad_mask",
            "segment_labels",
        ]
        tsv_file = SafeFileWrapper(get_absolute_path(fname),
                                   encoding="utf-8",
                                   errors="replace")
        tsv = TSV(
            tsv_file,
            field_names=field_names,
            delimiter=self.delimiter,
            quoted=self.quoted,
            drop_incomplete_rows=True,
        )

        for id, row in enumerate(tsv):
            parts = (row[f] for f in field_names)
            # All model output for KD are dumped using json serialization.
            (
                id1,
                doc,
                question,
                answers,
                answer_starts,
                has_answer,
                id2,
                start_logits,
                end_logits,
                has_answer_logits,
                pad_mask,
                segment_labels,
            ) = (json.loads(s) for s in parts)
            if isinstance(question, list):
                # if we have paraphrases for question
                question = choice(question)
            for piece_dict in _split_document(
                    id,
                    doc,
                    question,
                    answers,
                    answer_starts,
                    has_answer == "True",
                    self.ignore_impossible,
                    self.max_character_length,
                    self.min_overlap,
            ):
                piece_dict.update({
                    "start_logits": start_logits,
                    "end_logits": end_logits,
                    "has_answer_logits": has_answer_logits,
                    "pad_mask": pad_mask,
                    "segment_labels": segment_labels,
                })
                yield piece_dict
Example #10
    def __init__(self, path: str, **kwargs):
        self.path = get_absolute_path(path)
        # `self` is already bound here; passing it again to the parent would
        # hand it a duplicate positional argument.
        super().__init__(**kwargs)