Example #1
    def text_to_instance(self, text: str, label: str) -> Instance:  # type: ignore
        """
        # Parameters

        text : `str`, required.
            The text to process.
        label : `str`, required.
            The label associated with `text`.

        # Returns

        An `Instance` containing the following fields:
            - anchors (`Union[TextField, ListField[TextField]]`) :
                If `self.sample_spans`, this will be a `ListField[TextField]` object, containing
                each anchor span sampled from `text`. Otherwise, this will be a `TextField` object
                containing the tokenized `text`.
            - positives (`ListField[TextField]`) :
                If `self.sample_spans`, this will be a `ListField[TextField]` object, containing
                each positive span sampled from `text`. Otherwise this field will not be included
                in the returned `Instance`.
            - label (`LabelField`) :
                The label for `text`, stored as a `LabelField`.
        """
        # Some very minimal preprocessing to remove whitespace, newlines and tabs.
        # We perform it here as it will cover both training and predicting with the model.
        # We DON'T lowercase by default, but rather allow `self._tokenizer` to decide.
        text = sanitize(text, lowercase=False)

        fields: Dict[str, Field] = {}
        if self.sample_spans:
            # Choose the anchor/positives at random.
            anchor_text, positive_text = sample_anchor_positive_pairs(
                text=text,
                num_anchors=self._num_anchors,
                num_positives=self._num_positives,
                max_span_len=self._max_span_len,
                min_span_len=self._min_span_len,
                sampling_strategy=self._sampling_strategy,
            )
            anchors: List[Field] = []
            for text in anchor_text:
                tokens = self._tokenizer.tokenize(text)
                anchors.append(TextField(tokens, self._token_indexers))
            fields["anchors"] = ListField(anchors)
        
            positives: List[Field] = []
            for text in positive_text:
                tokens = self._tokenizer.tokenize(text)
                positives.append(TextField(tokens, self._token_indexers))
            fields["positives"] = ListField(positives)
            fields["label"] = LabelField(str(label))
        else:
            tokens = self._tokenizer.tokenize(text)
            fields["anchors"] = TextField(tokens, self._token_indexers)
            fields["label"] = LabelField(str(label))
        return Instance(fields)
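
A minimal usage sketch of this method, assuming `reader` is an already-constructed instance of this dataset reader built with `sample_spans=False` (the input text and label are purely illustrative):

# Illustrative only: `reader` is assumed to be an instance of the dataset reader
# above, constructed with sample_spans=False.
instance = reader.text_to_instance(text="A short example document.", label="0")
# With sample_spans disabled, the instance holds a single tokenized TextField
# under "anchors" and the label under "label".
assert set(instance.fields) == {"anchors", "label"}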
Example #2
    def test_sanitize(self, text: str):
        sanitized_text = sanitize(text)

        # There should be no cases of multiple spaces or tabs
        assert re.search(r"[ ]{2,}", sanitized_text) is None
        assert "\t" not in sanitized_text
        # The beginning and end of the string should be stripped of whitespace
        assert not sanitized_text.startswith(("\n", " "))
        assert not sanitized_text.endswith(("\n", " "))
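
This test only pins down the observable behaviour of `sanitize`. As a reference point, a minimal sketch of an implementation that would satisfy these assertions (not necessarily the project's actual implementation) is:

import re

def sanitize_sketch(text: str, lowercase: bool = False) -> str:
    # Collapse tabs, newlines and runs of whitespace into single spaces,
    # strip leading/trailing whitespace, and optionally lowercase.
    sanitized = re.sub(r"\s+", " ", text).strip()
    return sanitized.lower() if lowercase else sanitized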
Example #3
    def test_sanitize(self, text: str, lowercase: bool) -> None:
        sanitized_text = sanitize(text, lowercase=lowercase)

        # There should be no cases of multiple spaces or tabs
        assert re.search(r"[ ]{2,}", sanitized_text) is None
        assert "\t" not in sanitized_text
        # The beginning and end of the string should be stripped of whitespace
        assert not sanitized_text.startswith(("\n", " "))
        assert not sanitized_text.endswith(("\n", " "))
        if lowercase:
            assert all(not char.isupper() for char in sanitized_text)
Example #4
    def test_sanitize(self, text: str, lowercase: bool) -> None:
        sanitized_text = sanitize(text, lowercase=lowercase)

        # There should be no cases of multiple spaces or tabs
        assert re.search(r"[ ]{2,}", sanitized_text) is None
        assert "\t" not in sanitized_text
        # The beginning and end of the string should be stripped of whitespace
        assert not sanitized_text.startswith(("\n", " "))
        assert not sanitized_text.endswith(("\n", " "))
        # Sometimes, hypothesis generates text that cannot be lowercased (e.g. strings with
        # no cased characters, such as digits), which breaks this check.
        # Only run it if the generated text can be lowercased.
        if lowercase and text.lower().islower():
            assert all(not char.isupper() for char in sanitized_text)
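
For reference, the `text.lower().islower()` guard skips the final check for strings with no cased characters, where `str.islower()` is always `False`:

# Strings made only of digits/punctuation have no cased characters, so the
# uppercase assertion is skipped for them.
assert not "123 !?".lower().islower()
# Ordinary cased text passes the guard and is checked as usual.
assert "Some Text".lower().islower()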
Example #5
def main(
    output_filepath: Union[str, Path],
    segment_sentences: bool = False,
    lowercase: bool = False,
    min_length: Optional[int] = None,
    max_instances: Optional[int] = None,
    pretrained_model_name_or_path: Optional[str] = None,
) -> None:
    """Downloads and lightly preprocesses WikiText-103. If `min_length is not None`, only documents
    with at least this many tokens are retained. If `pretrained_model_name_or_path` is not None, the
    tokenizer will be loaded as `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)`
    using the HuggingFace Transformers library. Otherwise `str.split()` is used. This argument has
    no effect if `min_length is None`. If `segment_sentences` is True, individual sentences
    will be returned instead of documents. You must have the `"en_core_web_sm"` spacy model
    installed to segment sentences.
    """
    # Setup the pre-trained tokenizer, if specified
    if min_length is not None:
        if pretrained_model_name_or_path is not None:
            # Import transformers here so that users who don't need it
            # can avoid an ImportError from a top-level import.
            from transformers import AutoTokenizer

            tokenizer = AutoTokenizer.from_pretrained(
                pretrained_model_name_or_path).tokenize
        else:
            tokenizer = lambda x: x.split()  # noqa
    else:
        tokenizer = None

    # Setup spacy lang object if we are segmenting sentences
    if segment_sentences:
        import spacy

        nlp = spacy.load("en_core_web_sm", disable=["ner"])

    # Download WikiText-103
    r = requests.get(WIKITEXT_103_URL, stream=True)
    z = zipfile.ZipFile(io.BytesIO(r.content))
    partition_filenames = z.namelist()[1:]
    typer.secho(f"{DOWNLOAD} Downloaded WikiText-103", bold=True)

    preprocessed_documents: List[str] = []
    for filename in partition_filenames:
        text = z.open(filename).read().decode("utf-8")

        # Strip out subtitles and split the text into documents
        no_subtitles = re.sub(r"(=\s){2,5}.*(=\s){2,5}", "", text)
        documents = re.split(r"=\s.*\s=", no_subtitles)

        if segment_sentences:
            documents = (sent.text for doc in documents
                         for sent in nlp(doc).sents)  # type: ignore

        with typer.progressbar(documents,
                               length=max_instances,
                               label=typer.style("Preprocessing text",
                                                 bold=True)) as progress:
            for doc in progress:
                doc = sanitize(doc, lowercase=lowercase)
                if not doc:
                    continue

                # Retain the document only if its token count is greater than
                # or equal to the minimum specified length
                if tokenizer is not None:
                    num_tokens = len(tokenizer(doc))
                    if min_length and num_tokens < min_length:
                        continue

                if max_instances and len(
                        preprocessed_documents) >= max_instances:
                    break
                preprocessed_documents.append(doc)
                progress.update(1)

    _write_output_to_disk(preprocessed_documents, output_filepath)
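
`_write_output_to_disk` is referenced above but not shown here. A minimal sketch consistent with how it is called (the helper name comes from the call above; the one-document-per-line format and everything else are assumptions) might look like:

from pathlib import Path
from typing import List, Union


def _write_output_to_disk(text: List[str], output_filepath: Union[str, Path]) -> None:
    # Sketch only: write one preprocessed document per line, creating parent
    # directories as needed.
    output_filepath = Path(output_filepath)
    output_filepath.parent.mkdir(parents=True, exist_ok=True)
    with open(output_filepath, "w") as f:
        f.write("\n".join(text))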
Example #6
    def text_to_instance(self, text: str) -> Instance:  # type: ignore
        """
        # Parameters

        text : `str`, required.
            The text to process.

        # Returns

        An `Instance` containing the following fields:
            - anchors (`Union[TextField, ListField[TextField]]`) :
                If `self.sample_spans`, this will be a `ListField[TextField]` object, containing
                each anchor span sampled from `text`. Otherwise, this will be a `TextField` object
                containing the tokenized `text`.
            - difficulty (`LabelField`) :
                The current difficulty step, derived from the number of instances read so far
                and stored with `skip_indexing=True`.
        """
        # Some very minimal preprocessing to remove whitespace, newlines and tabs.
        # We perform it here as it will cover both training and predicting with the model.
        # We DON'T lowercase by default, but rather allow `self._tokenizer` to decide.
        text = sanitize(text, lowercase=False)

        # Difficulty increases by one for every 49784 instances read; other step
        # sizes were explored in earlier experiments.
        difficulty_step = int(self.instance / 49784) + 1
        self.instance += 1

        fields: Dict[str, Field] = {}
        if self.sample_spans:
            # print("reading instance is", self.instance)
            # difficulty_step = int(self.instance / 40 ) + 1

            # # print("difficulty step is ",difficulty_step)
            # if difficulty_step > 5 :
            # # if difficulty_step > 2 :
            #     # self._num_anchors = 2
            #     # self._num_anchors = int(difficulty_step /2) + 1
            #     # self._num_anchors = int((difficulty_step - 1)/2) - 1
            #     self._num_anchors = difficulty_step_sample
            #     if self._num_anchors > 3:
            #         # print("over anchor!")
            #         self._num_anchors = 3
            #     # self._num_anchors = random.randint(1, self._num_anchors)
            #     # print("num_anchors", self._num_anchors, self.instance, difficulty_step)
            #     # sample_difficulty = difficulty_step
            #     sample_difficulty = 1
            # else:
            #     sample_difficulty = 1
            # self._num_anchors = difficulty_step_sample
            # if difficulty_step_sample <=0 :
            #     self._num_anchors = 1
            # if self._num_anchors > 3:
            #     # print("over anchor!")
            #     self._num_anchors = 3
            sample_difficulty = 1
            # print("anchor num is", self._num_anchors)

            # fields["text"] = LabelField(len(text), skip_indexing=True)
            # Choose the anchor/positives at random.
            # anchor_text, positive_text = sample_anchor_positive_pairs(
            #     text=text,
            #     num_anchors=self._num_anchors,
            #     num_positives=self._num_positives,
            #     max_span_len=self._max_span_len,
            #     min_span_len=self._min_span_len,
            #     difficulty_step = sample_difficulty,
            #     sampling_strategy=self._sampling_strategy,
            # )
            # # print("anchor_text", anchor_text)
            # # print("positive_text", positive_text)
            # anchors: List[Field] = []
            # for text in anchor_text:
            #     tokens = self._tokenizer.tokenize(text)
            #     anchors.append(TextField(tokens, self._token_indexers))
            # fields["anchors"] = ListField(anchors)
            # positives: List[Field] = []
            # for text in positive_text:
            #     tokens = self._tokenizer.tokenize(text)
            #     positives.append(TextField(tokens, self._token_indexers))
            # fields["positives"] = ListField(positives)
            # fields["difficulty"] = LabelField(difficulty_step, skip_indexing=True)
            anchor_text = sample_anchor_positive_pairs(
                text=text,
                num_anchors=self._num_anchors,
                num_positives=self._num_positives,
                max_span_len=self._max_span_len,
                min_span_len=self._min_span_len,
                difficulty_step=sample_difficulty,
                sampling_strategy=self._sampling_strategy,
            )
            # print("anchor_text", anchor_text)
            # print("positive_text", positive_text)
            anchors: List[Field] = []
            for text in anchor_text:
                tokens = self._tokenizer.tokenize(text)
                anchors.append(TextField(tokens, self._token_indexers))
            print("number of token is", len(tokens))
            fields["anchors"] = ListField(anchors)
            fields["difficulty"] = LabelField(difficulty_step,
                                              skip_indexing=True)
        else:
            # print("no sampling")
            tokens = self._tokenizer.tokenize(text)
            print("number of token is", len(tokens))
            fields["anchors"] = TextField(tokens, self._token_indexers)
            fields["difficulty"] = LabelField(difficulty_step,
                                              skip_indexing=True)
        return Instance(fields)
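
The difficulty schedule above is a simple step function of the number of instances read so far. Isolated as a standalone sketch (the step size 49784 is taken from the code above; why that particular value was chosen is not stated):

def difficulty_for_instance(instance_index: int, instances_per_step: int = 49784) -> int:
    # Difficulty starts at 1 and increases by one every `instances_per_step` instances.
    return instance_index // instances_per_step + 1


# Instances 0..49783 map to difficulty 1, 49784..99567 to difficulty 2, and so on.
assert difficulty_for_instance(0) == 1
assert difficulty_for_instance(49784) == 2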
Example #7
    def __call__(self,
                 inputs: Union[str, List[str]],
                 batch_size: Optional[int] = None) -> torch.Tensor:
        """Returns a numpy array of embeddings, one for each item in `inputs`.

        # Parameters

        inputs : `Union[str, List[str]]`, required
            The input text to embed. Can be a string, list of strings, or a filepath/URL to a text
            file with one input per line.
        batch_size : `int`, optional
            If given, the `inputs` will be batched before embedding.
        """
        if isinstance(inputs, str):
            if Path(inputs).is_file() or url(inputs):
                inputs = Path(cached_path(inputs)).read_text().split("\n")
            else:
                inputs = [inputs]

        if batch_size is None:
            unsort = False
            batch_size = len(inputs)
        else:
            # Sort the inputs by length, maintaining the original indices so we can un-sort
            # before returning the embeddings. This speeds up embedding by minimizing the
            # amount of computation performed on pads. Because this sorting happens before
            # tokenization, it is only a proxy of the true lengths of the inputs to the model.
            # In the future, it would be better to use the built-in bucket sort of AllenNLP,
            # which would lead to an even larger speedup.
            unsort = True
            sorted_indices, inputs = zip(
                *sorted(enumerate(inputs), key=itemgetter(1)))
            unsorted_indices, _ = zip(
                *sorted(enumerate(sorted_indices), key=itemgetter(1)))

        inputs = [{"text": sanitize(input_)} for input_ in inputs]

        embeddings = []
        for i in range(0, len(inputs), batch_size):
            outputs = self._predictor.predict_batch_json(inputs[i:i +
                                                                batch_size])
            outputs = torch.as_tensor(
                # Accumulating the tensors on the GPU would quickly lead to OOM.
                [output[self._output_dict_field] for output in outputs],
                device="cpu",
            )
            embeddings.append(outputs)
        embeddings = torch.cat(embeddings)
        # Make sure to unsort the embeddings if they were sorted.
        if unsort:
            unsorted_indices = torch.as_tensor(unsorted_indices,
                                               dtype=torch.long)
            embeddings = torch.index_select(embeddings,
                                            dim=0,
                                            index=unsorted_indices)
        if self._sphereize:
            if embeddings.size(0) > 1:
                centroid = torch.mean(embeddings, dim=0)
                embeddings -= centroid
                embeddings /= torch.norm(embeddings, dim=1, keepdim=True)
            else:
                warnings.warn(
                    "sphereize==True but only a single input sentence was passed."
                    " Inputs will not be sphereized.")

        return embeddings.numpy()
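
A brief usage sketch, assuming `encoder` is an already-constructed instance of this class (the variable name and sentences are illustrative):

# Embed a small batch of sentences. When batch_size is set, inputs are sorted by
# length, embedded in chunks, and un-sorted before being returned as a numpy
# array with one row per input.
sentences = ["A first sentence.", "A slightly longer second sentence."]
embeddings = encoder(sentences, batch_size=2)
assert embeddings.shape[0] == len(sentences)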
Example #8
def main(
    openwebtext_path: Union[str, Path],
    output_filepath: Union[str, Path],
    min_length: Optional[int] = None,
    lowercase: bool = True,
    max_documents: Optional[int] = None,
    pretrained_model_name_or_path: Optional[str] = None,
) -> None:
    """Lightly preprocesses an OpenWebText dump obtained from
    https://skylion007.github.io/OpenWebTextCorpus/. If `min_length is not None`, only documents
    with at least this many tokens are retained. If `pretrained_model_name_or_path` is not None,
    the tokenizer will be loaded as `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)`
    using the HuggingFace Transformers library. Otherwise `str.split()` is used. This argument has
    no effect if `min_length is None`.
    """
    openwebtext_path = Path(openwebtext_path)

    # Setup the pre-trained tokenizer, if specified
    if min_length is not None:
        if pretrained_model_name_or_path is not None:
            # Import transformers here so that users who don't need it
            # can avoid an ImportError from a top-level import.
            from transformers import AutoTokenizer

            tokenizer = AutoTokenizer.from_pretrained(
                pretrained_model_name_or_path).tokenize
        else:
            tokenizer = lambda x: x.split()  # noqa
    else:
        tokenizer = None

    early_exit = False
    documents = []
    skipped_files = 0
    typer.secho(
        (f'{MINING} Scraping {max_documents or "all"} documents'
         f' {f"with a minimum token length of {min_length}" if min_length else ""}'
         ),
        bold=True,
    )

    with typer.progressbar(length=max_documents
                           or len(list(openwebtext_path.iterdir())),
                           label="Preprocessing text") as progress:
        for i, tar_filepath in enumerate(openwebtext_path.iterdir()):
            # A small number (1-2) of tar archives fail to extract. Rather than debugging this,
            # catch the error and report to the user at the end how many files were skipped.
            untared_filepath = Path(tar_filepath.stem)
            try:
                with tarfile.open(tar_filepath) as f:
                    f.extractall(untared_filepath)
            except (tarfile.ReadError, IsADirectoryError):
                skipped_files += 1
                continue

            for text_filepath in untared_filepath.iterdir():
                text = text_filepath.read_text()
                text = sanitize(text, lowercase=lowercase)
                if not text:
                    continue

                # Retain the document only if its token count is greater than
                # or equal to the minimum specified length
                if tokenizer is not None:
                    num_tokens = len(tokenizer(text))
                    if min_length and num_tokens < min_length:
                        continue

                documents.append(text)
                if max_documents:
                    progress.update(1)

                if max_documents and len(documents) == max_documents:
                    early_exit = True
                    break

            shutil.rmtree(untared_filepath)
            if max_documents is None:
                progress.update(1)
            if early_exit:
                break

    if skipped_files > 0:
        typer.secho(
            f"{WARNING} {skipped_files} tar files were skipped because they couldn't be extracted.",
            fg=typer.colors.YELLOW,
            bold=True,
        )

    _write_output_to_disk(documents, output_filepath)
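
A direct-call sketch of `main` (all paths and values are placeholders; in practice the function is presumably also exposed as a command-line interface via typer):

# Illustrative invocation: keep up to 1000 documents of at least 200 tokens,
# with token counts measured by a pretrained tokenizer.
main(
    openwebtext_path="path/to/openwebtext",
    output_filepath="path/to/train.txt",
    min_length=200,
    max_documents=1000,
    pretrained_model_name_or_path="roberta-base",
)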