Example #1
    def normalize(self,
                  text: str,
                  verbose: bool = False,
                  punct_pre_process: bool = False,
                  punct_post_process: bool = False) -> str:
        """
        Main function. Normalizes tokens from written to spoken form
            e.g. 12 kg -> twelve kilograms

        Args:
            text: string that may include semiotic classes
            verbose: whether to print intermediate meta information
            punct_pre_process: whether to perform punctuation pre-processing, for example, [25] -> [ 25 ]
            punct_post_process: whether to normalize punctuation

        Returns: spoken form
        """
        original_text = text
        if punct_pre_process:
            text = pre_process(text)
        text = text.strip()
        if not text:
            if verbose:
                print(text)
            return text
        text = pynini.escape(text)
        tagged_lattice = self.find_tags(text)
        tagged_text = self.select_tag(tagged_lattice)
        if verbose:
            print(tagged_text)
        self.parser(tagged_text)
        tokens = self.parser.parse()
        tags_reordered = self.generate_permutations(tokens)
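        # try the tag permutations in order and return the first one whose verbalizer lattice is non-empty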
        for tagged_text in tags_reordered:
            tagged_text = pynini.escape(tagged_text)

            verbalizer_lattice = self.find_verbalizer(tagged_text)
            if verbalizer_lattice.num_states() == 0:
                continue
            output = self.select_verbalizer(verbalizer_lattice)
            if punct_post_process:
                # do post-processing based on Moses detokenizer
                if self.processor:
                    output = self.processor.moses_detokenizer.detokenize(
                        [output], unescape=False)
                    output = post_process_punct(input=original_text,
                                                normalized_text=output)
                else:
                    print(
                        "NEMO_NLP collection is not available: skipping punctuation post_processing"
                    )
            return output
        raise ValueError(f"No verbalization found for: {text}")
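
A minimal usage sketch for the deterministic interface above. The import path and the Normalizer constructor arguments are assumptions (the snippet resembles NeMo's text normalization API); only the normalize() signature is taken from the example itself.

# Assumed import path and constructor; only normalize() above is given.
from nemo_text_processing.text_normalization.normalize import Normalizer

normalizer = Normalizer(input_case="cased", lang="en")  # assumed constructor
spoken = normalizer.normalize(
    "12 kg",
    verbose=False,
    punct_pre_process=True,
    punct_post_process=True,
)
print(spoken)  # expected: "twelve kilograms" (per the docstring example)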
Example #2
    def normalize(
        self,
        text: str,
        n_tagged: int,
        punct_post_process: bool = True,
        verbose: bool = False,
    ) -> str:
        """
        Main function. Normalizes tokens from written to spoken form
            e.g. 12 kg -> twelve kilograms

        Args:
            text: string that may include semiotic classes
            n_tagged: number of tagged options to consider; use -1 to return all possible tagged options
            punct_post_process: whether to normalize punctuation
            verbose: whether to print intermediate meta information

        Returns:
            normalized text options (usually there are multiple ways of normalizing a given semiotic class);
            in LM mode, a tuple of (normalized text options, weights) is returned instead
        """

        assert (
            len(text.split()) < 500
        ), "Your input is too long. Please split up the input into sentences, or strings with fewer than 500 words"
        original_text = text
        text = pre_process(text)  # to handle []

        text = text.strip()
        if not text:
            if verbose:
                print(text)
            return text
        text = pynini.escape(text)

        if self.lm:
            if self.lang != "en":
                raise ValueError(f"{self.lang} is not supported in LM mode")

            try:
                lattice = rewrite.rewrite_lattice(
                    text, self.tagger.fst_no_digits)
            except pynini.lib.rewrite.Error:
                lattice = rewrite.rewrite_lattice(text, self.tagger.fst)
            lattice = rewrite.lattice_to_nshortest(lattice, n_tagged)
            tagged_texts = [(x[1], float(x[2]))
                            for x in lattice.paths().items()]
            tagged_texts.sort(key=lambda x: x[1])
            tagged_texts, weights = list(zip(*tagged_texts))
        else:
            if n_tagged == -1:
                if self.lang == "en":
                    try:
                        tagged_texts = rewrite.rewrites(
                            text, self.tagger.fst_no_digits)
                    except pynini.lib.rewrite.Error:
                        tagged_texts = rewrite.rewrites(text, self.tagger.fst)
                else:
                    tagged_texts = rewrite.rewrites(text, self.tagger.fst)
            else:
                if self.lang == "en":
                    try:
                        tagged_texts = rewrite.top_rewrites(
                            text,
                            self.tagger.fst_no_digits,
                            nshortest=n_tagged)
                    except pynini.lib.rewrite.Error:
                        tagged_texts = rewrite.top_rewrites(text,
                                                            self.tagger.fst,
                                                            nshortest=n_tagged)
                else:
                    tagged_texts = rewrite.top_rewrites(text,
                                                        self.tagger.fst,
                                                        nshortest=n_tagged)

        # non-deterministic English normalization uses the tagger composed with the verbalizer, so no permutation step is needed in between
        if self.lang == "en":
            normalized_texts = tagged_texts
        else:
            normalized_texts = []
            for tagged_text in tagged_texts:
                self._verbalize(tagged_text, normalized_texts, verbose=verbose)

        if len(normalized_texts) == 0:
            raise ValueError(f"Normalization produced no output for: {text}")

        if punct_post_process:
            # do post-processing based on Moses detokenizer
            if self.processor:
                normalized_texts = [
                    self.processor.detokenize([t]) for t in normalized_texts
                ]
                normalized_texts = [
                    post_process_punct(input=original_text, normalized_text=t)
                    for t in normalized_texts
                ]

        if self.lm:
            return normalized_texts, weights

        normalized_texts = set(normalized_texts)
        return normalized_texts
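
A sketch of how the non-deterministic variant above might be called. The NormalizerWithAudio-style class name, import path, and constructor are assumptions; the candidate outputs in the comment are illustrative only.

# Assumed class and constructor; normalize() returns a set of candidates here.
from nemo_text_processing.text_normalization.normalize_with_audio import NormalizerWithAudio

normalizer = NormalizerWithAudio(input_case="cased", lang="en")  # assumed constructor
options = normalizer.normalize("123", n_tagged=10, punct_post_process=True)
for option in sorted(options):
    print(option)  # e.g. "one hundred twenty three", "one two three" (illustrative)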
Example #3
    def normalize(self,
                  text: str,
                  verbose: bool = False,
                  punct_pre_process: bool = False,
                  punct_post_process: bool = False) -> str:
        """
        Main function. Normalizes tokens from written to spoken form
            e.g. 12 kg -> twelve kilograms

        Args:
            text: string that may include semiotic classes
            verbose: whether to print intermediate meta information
            punct_pre_process: whether to perform punctuation pre-processing, for example, [25] -> [ 25 ]
            punct_post_process: whether to normalize punctuation

        Returns: spoken form
        """
        assert (
            len(text.split()) < 500
        ), "Your input is too long. Please split up the input into sentences, or strings with fewer than 500 words"

        original_text = text
        if punct_pre_process:
            text = pre_process(text)
        text = text.strip()
        if not text:
            if verbose:
                print(text)
            return text
        text = pynini.escape(text)
        tagged_lattice = self.find_tags(text)
        tagged_text = self.select_tag(tagged_lattice)
        if verbose:
            print(tagged_text)
        self.parser(tagged_text)
        tokens = self.parser.parse()
        split_tokens = self._split_tokens_to_reduce_number_of_permutations(
            tokens)
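        # splitting the token list keeps the number of tag permutations per chunk manageable for long inputs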
        output = ""
        for s in split_tokens:
            tags_reordered = self.generate_permutations(s)
            verbalizer_lattice = None
            for tagged_text in tags_reordered:
                tagged_text = pynini.escape(tagged_text)

                verbalizer_lattice = self.find_verbalizer(tagged_text)
                if verbalizer_lattice.num_states() != 0:
                    break
            if verbalizer_lattice is None:
                raise ValueError(
                    f"No permutations were generated from tokens {s}")
            output += ' ' + self.select_verbalizer(verbalizer_lattice)
        output = SPACE_DUP.sub(' ', output[1:])

        if self.lang == "en" and hasattr(self, 'post_processor'):
            output = self.post_process(output)

        if punct_post_process:
            # do post-processing based on Moses detokenizer
            if self.processor:
                output = self.processor.moses_detokenizer.detokenize(
                    [output], unescape=False)
                output = post_process_punct(input=original_text,
                                            normalized_text=output)
            else:
                print(
                    "NEMO_NLP collection is not available: skipping punctuation post_processing"
                )

        return output
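
Because of the 500-word assertion above, document-length input should be split into sentences before normalization. A rough sketch, assuming the same Normalizer-style object as in the sketch after Example #1 and a deliberately naive sentence split:

# Rough sketch: split long text into sentences before calling normalize().
# The Normalizer construction is assumed, as in the sketch after Example #1.
normalizer = Normalizer(input_case="cased", lang="en")
long_text = "It weighs 12 kg. The meeting is at 10:30."
sentences = [s for s in long_text.split(". ") if s.strip()]  # naive split for illustration
spoken = " ".join(
    normalizer.normalize(s, punct_pre_process=True, punct_post_process=True)
    for s in sentences
)
print(spoken)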
Example #4
    def normalize(
        self,
        text: str,
        n_tagged: int,
        punct_post_process: bool = True,
        verbose: bool = False,
    ) -> str:
        """
        Main function. Normalizes tokens from written to spoken form
            e.g. 12 kg -> twelve kilograms

        Args:
            text: string that may include semiotic classes
            n_tagged: number of tagged options to consider; use -1 to return all possible tagged options
            punct_post_process: whether to normalize punctuation
            verbose: whether to print intermediate meta information

        Returns:
            normalized text options (usually there are multiple ways of normalizing a given semiotic class)
        """
        original_text = text

        if self.lang == "en":
            text = pre_process(text)
        text = text.strip()
        if not text:
            if verbose:
                print(text)
            return text
        text = pynini.escape(text)

        if n_tagged == -1:
            if self.lang == "en":
                try:
                    tagged_texts = rewrite.rewrites(text,
                                                    self.tagger.fst_no_digits)
                except pynini.lib.rewrite.Error:
                    tagged_texts = rewrite.rewrites(text, self.tagger.fst)
            else:
                tagged_texts = rewrite.rewrites(text, self.tagger.fst)
        else:
            if self.lang == "en":
                try:
                    tagged_texts = rewrite.top_rewrites(
                        text, self.tagger.fst_no_digits, nshortest=n_tagged)
                except pynini.lib.rewrite.Error:
                    tagged_texts = rewrite.top_rewrites(text,
                                                        self.tagger.fst,
                                                        nshortest=n_tagged)
            else:
                tagged_texts = rewrite.top_rewrites(text,
                                                    self.tagger.fst,
                                                    nshortest=n_tagged)

        # non-deterministic English normalization uses the tagger composed with the verbalizer, so no permutation step is needed in between
        if self.lang == "en":
            normalized_texts = tagged_texts
        else:
            normalized_texts = []
            for tagged_text in tagged_texts:
                self._verbalize(tagged_text, normalized_texts, verbose=verbose)

        if len(normalized_texts) == 0:
            raise ValueError(f"Normalization produced no output for: {text}")

        if punct_post_process:
            # do post-processing based on Moses detokenizer
            if self.processor:
                normalized_texts = [
                    self.processor.detokenize([t]) for t in normalized_texts
                ]
                normalized_texts = [
                    post_process_punct(input=original_text, normalized_text=t)
                    for t in normalized_texts
                ]

        normalized_texts = set(normalized_texts)
        return normalized_texts
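
With n_tagged=-1 the variant above enumerates every tagged option rather than the n shortest. A brief sketch, with the NormalizerWithAudio-style construction again assumed as in the sketch after Example #2:

# n_tagged=-1 enumerates all tagged options; the result is a set of strings.
normalizer = NormalizerWithAudio(input_case="cased", lang="en")  # assumed constructor
all_options = normalizer.normalize("12 kg", n_tagged=-1, punct_post_process=True, verbose=False)
print(len(all_options), "candidate normalizations")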
Example #5
    def _infer(self,
               sents: List[str],
               inst_directions: List[str],
               processed=False):
        """
        Main function for Inference

        If the 'joint' mode is used, "sents" will include both the spoken and the written form of each input sentence,
        and "inst_directions" will include both constants.INST_BACKWARD and constants.INST_FORWARD.

        Args:
            sents: A list of input texts.
            inst_directions: A list of str where each str indicates the direction of the corresponding instance \
                (i.e., constants.INST_BACKWARD for ITN or constants.INST_FORWARD for TN).
            processed: set to True when used with TextNormalizationTestDataset, whose data is already tokenized with Moses;
                tokenizing it again could cause a mismatch between the number of tokens and the class spans

        Returns:
            tag_preds: A list of lists where the inner list contains the tag predictions from the tagger for each word in the input text.
            output_spans: A list of lists where each list contains the decoded semiotic spans from the decoder for an input text.
            final_outputs: A list of str where each str is the final output text for an input text.
        """
        original_sents = [s for s in sents]
        # Separate into words
        if not processed:
            sents = [input_preprocessing(x, lang=self.lang) for x in sents]
            sents = [self.decoder.processor.tokenize(x).split() for x in sents]
        else:
            sents = [x.split() for x in sents]

        # Tagging
        # span ends are inclusive; the returned indices refer to words in the input without auxiliary words
        tag_preds, nb_spans, span_starts, span_ends = self.tagger._infer(
            sents, inst_directions)
        output_spans = self.decoder._infer(sents, nb_spans, span_starts,
                                           span_ends, inst_directions)

        # Prepare final outputs
        final_outputs = []
        for ix, (sent, tags) in enumerate(zip(sents, tag_preds)):
            try:
                cur_words, jx, span_idx = [], 0, 0
                cur_spans = output_spans[ix]
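                # keep SAME-tagged words as-is; replace each semiotic span with its decoded output and skip the span's continuation (I-) tokens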
                while jx < len(sent):
                    tag, word = tags[jx], sent[jx]
                    if constants.SAME_TAG in tag:
                        cur_words.append(word)
                        jx += 1
                    else:
                        jx += 1
                        cur_words.append(cur_spans[span_idx])
                        span_idx += 1
                        while jx < len(sent) and tags[
                                jx] == constants.I_PREFIX + constants.TRANSFORM_TAG:
                            jx += 1

                if processed:
                    # for Class-based evaluation, don't apply Moses detokenization
                    cur_output_str = " ".join(cur_words)
                else:
                    # detokenize the output with Moses and fix punctuation marks to match the input
                    # for interactive inference or inference from a file
                    cur_output_str = self.decoder.processor.detokenize(
                        cur_words)
                    cur_output_str = post_process_punct(
                        input=original_sents[ix],
                        normalized_text=cur_output_str)
                final_outputs.append(cur_output_str)
            except IndexError:
                logging.warning(
                    f"Input sent is too long and will be skipped - {original_sents[ix]}"
                )
                final_outputs.append(original_sents[ix])
        return tag_preds, output_spans, final_outputs
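
A sketch of how the tagger/decoder pipeline above might be driven. The "model" object (and how it is constructed) is assumed; the direction constants are the ones named in the docstring.

# Hypothetical driver; "model" exposing _infer() and the "constants" module are
# assumed to be the ones referenced in the example above.
sents = ["The price is $5.", "twelve kilograms"]
inst_directions = [constants.INST_FORWARD, constants.INST_BACKWARD]  # TN, then ITN
tag_preds, output_spans, final_outputs = model._infer(sents, inst_directions, processed=False)
for sent, out in zip(sents, final_outputs):
    print(sent, "->", out)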